Add MessagePack support for efficient cache serialization and update dependencies

This commit is contained in:
Will Miao
2025-05-28 22:30:06 +08:00
parent a0b0d40a19
commit e7c626eb5f
4 changed files with 187 additions and 3 deletions

1
.gitignore vendored
View File

@@ -3,3 +3,4 @@ settings.json
output/*
py/run_test.py
.vscode/
cache/

View File

@@ -5,6 +5,7 @@ import asyncio
import time
import shutil
from typing import List, Dict, Optional, Type, Set
import msgpack # Add MessagePack import for efficient serialization
from ..utils.models import BaseModelMetadata
from ..config import config
@@ -17,6 +18,9 @@ from .websocket_manager import ws_manager
logger = logging.getLogger(__name__)
# On-disk cache format version. It is written into every cache file under the
# "version" key and compared again when a cache file is validated on load; a
# file carrying a different version is treated as invalid. Bump this value
# whenever the serialized cache structure changes.
CACHE_VERSION = 1
class ModelScanner:
"""Base service for scanning and managing model files"""
@@ -39,6 +43,7 @@ class ModelScanner:
self._tags_count = {} # Dictionary to store tag counts
self._is_initializing = False # Flag to track initialization state
self._excluded_models = [] # List to track excluded models
self._dirs_last_modified = {} # Track directory modification times
# Register this service
asyncio.create_task(self._register_service())
@@ -48,6 +53,149 @@ class ModelScanner:
service_name = f"{self.model_type}_scanner"
await ServiceRegistry.register_service(service_name, self)
def _get_cache_file_path(self) -> Optional[str]:
    """Return the path of the MessagePack cache file for this scanner.

    The file is named ``lm_<model_type>_cache.msgpack`` and lives in a
    ``cache`` directory three levels above this module's resolved location
    (presumably the project root -- confirm against the package layout).
    The directory is created if it does not exist yet.

    Returns:
        The absolute cache file path, or ``None`` when the cache directory
        cannot be created. Callers treat a falsy result as "no cache file
        available".
    """
    # Resolve symlinks first so the cache always lands next to the real code.
    module_file = os.path.realpath(__file__)
    base_dir = os.path.dirname(os.path.dirname(os.path.dirname(module_file)))
    cache_dir = os.path.join(base_dir, "cache")

    try:
        os.makedirs(cache_dir, exist_ok=True)
    except OSError as e:
        # The cache is only an optimisation: report the problem and return
        # None (as the signature promises) instead of raising into callers
        # that invoke this method outside their own try blocks.
        logging.getLogger(__name__).warning(
            f"Cannot create cache directory {cache_dir}: {e}"
        )
        return None

    return os.path.join(cache_dir, f"lm_{self.model_type}_cache.msgpack")
async def _save_cache_to_disk(self) -> bool:
    """Persist the in-memory cache to disk as a MessagePack file.

    The payload (format version, timestamp, model type, raw model data,
    hash index, tag counts and a snapshot of directory modification times)
    is first written to ``<cache_path>.tmp`` and then moved over the final
    path, so an interrupted write does not leave a truncated cache file.

    Returns:
        True if the cache file was written. False if there is no cache data
        to save, no cache path is available, or writing failed (the error
        is logged and a leftover temp file is removed when possible).
    """
    if self._cache is None or not self._cache.raw_data:
        logger.debug(f"No {self.model_type} cache data to save")
        return False

    cache_path = self._get_cache_file_path()
    if not cache_path:
        logger.warning(f"Cannot determine {self.model_type} cache file location")
        return False

    # Defined before the try block so the cleanup path can always refer to it.
    temp_path = f"{cache_path}.tmp"
    try:
        cache_data = {
            "version": CACHE_VERSION,
            "timestamp": time.time(),
            "model_type": self.model_type,
            "raw_data": self._cache.raw_data,
            "hash_index": {
                "hash_to_path": self._hash_index._hash_to_path,
                "filename_to_hash": self._hash_index._filename_to_hash,
            },
            "tags_count": self._tags_count,
            # Snapshot used later to decide whether the cache is stale.
            "dirs_last_modified": self._get_dirs_last_modified(),
        }

        with open(temp_path, 'wb') as f:
            msgpack.pack(cache_data, f)

        # os.replace overwrites an existing destination on every platform,
        # so no exists() check (and no race between check and move) is needed.
        os.replace(temp_path, cache_path)

        logger.info(f"Saved {self.model_type} cache with {len(self._cache.raw_data)} models to {cache_path}")
        return True
    except Exception as e:
        logger.error(f"Error saving {self.model_type} cache to disk: {e}")
        # Best-effort cleanup of a partially written temp file.
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            pass  # Nothing was written, nothing to clean up.
        except OSError as cleanup_error:
            logger.debug(f"Could not remove temporary cache file {temp_path}: {cleanup_error}")
        return False
def _get_dirs_last_modified(self) -> Dict[str, float]:
    """Collect modification times for the model roots and their direct children.

    Every path returned by ``get_model_roots()`` that exists is recorded with
    its mtime, together with each immediate subdirectory (symlinked
    directories included). Deeper levels are not visited.

    Returns:
        Mapping of directory path to modification time. A failure while
        listing a root is logged and does not stop the remaining roots
        from being processed.
    """
    mtimes = {}
    for root in self.get_model_roots():
        # Roots that are missing on disk contribute nothing to the snapshot.
        if not os.path.exists(root):
            continue

        mtimes[root] = os.path.getmtime(root)

        # One level down only: record each child directory's own mtime.
        try:
            with os.scandir(root) as children:
                for child in children:
                    if not child.is_dir(follow_symlinks=True):
                        continue
                    mtimes[child.path] = child.stat().st_mtime
        except Exception as e:
            logger.error(f"Error getting directory info for {root}: {e}")

    return mtimes
def _is_cache_valid(self, cache_data: Dict) -> bool:
    """Check whether a cache payload loaded from disk can still be used.

    The payload is accepted only when all of the following hold:

    * it is a non-empty dict whose ``"version"`` equals ``CACHE_VERSION``;
    * its ``"model_type"`` matches this scanner's ``model_type``;
    * its ``"dirs_last_modified"`` snapshot is identical (same directories,
      same modification times) to what ``_get_dirs_last_modified()``
      reports now.

    Args:
        cache_data: Deserialized cache payload.

    Returns:
        True if the cache may be used, False if it must be rebuilt.
    """
    # A corrupt file may deserialize to something other than a dict.
    if not isinstance(cache_data, dict) or not cache_data:
        return False
    if cache_data.get("version") != CACHE_VERSION:
        return False
    if cache_data.get("model_type") != self.model_type:
        return False

    stored_dirs = cache_data.get("dirs_last_modified", {})
    if not isinstance(stored_dirs, dict):
        return False

    # Any difference invalidates the cache: a directory added or removed, or
    # a modification time that changed in either direction. An mtime that
    # moved backwards (e.g. a folder restored or replaced by an older copy)
    # means the contents may differ just as much as a newer one does.
    return stored_dirs == self._get_dirs_last_modified()
async def _load_cache_from_disk(self) -> bool:
    """Restore the in-memory cache from the MessagePack file on disk.

    On success this replaces ``self._cache``, the two hash-index mappings
    and ``self._tags_count`` with the stored values, then re-sorts the cache.

    Returns:
        True when a valid cache file was loaded. False when no cache file
        exists, validation via ``_is_cache_valid`` fails, or any error occurs
        while reading or applying the file (the error is logged).
    """
    start_time = time.time()

    cache_path = self._get_cache_file_path()
    if not (cache_path and os.path.exists(cache_path)):
        return False

    try:
        with open(cache_path, 'rb') as cache_file:
            payload = msgpack.unpack(cache_file)

        # Reject files written for another format version, model type or
        # directory state before touching any in-memory structure.
        if not self._is_cache_valid(payload):
            logger.info(f"{self.model_type.capitalize()} cache file found but invalid or outdated")
            return False

        # Sorted views and folders start empty here; resort() is awaited below.
        self._cache = ModelCache(
            raw_data=payload["raw_data"],
            sorted_by_name=[],
            sorted_by_date=[],
            folders=[],
        )

        stored_index = payload.get("hash_index", {})
        self._hash_index._hash_to_path = stored_index.get("hash_to_path", {})
        self._hash_index._filename_to_hash = stored_index.get("filename_to_hash", {})

        self._tags_count = payload.get("tags_count", {})

        await self._cache.resort()

        logger.info(f"Loaded {self.model_type} cache from disk with {len(self._cache.raw_data)} models in {time.time() - start_time:.2f} seconds")
        return True
    except Exception as e:
        logger.error(f"Error loading {self.model_type} cache from disk: {e}")
        return False
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
@@ -66,7 +214,31 @@ class ModelScanner:
# Determine the page type based on model type
page_type = 'loras' if self.model_type == 'lora' else 'checkpoints'
# First, count all model files to track progress
# First, try to load from cache
await ws_manager.broadcast_init_progress({
'stage': 'loading_cache',
'progress': 0,
'details': f"Loading {self.model_type} cache...",
'scanner_type': self.model_type,
'pageType': page_type
})
cache_loaded = await self._load_cache_from_disk()
if cache_loaded:
# Cache loaded successfully, broadcast complete message
await ws_manager.broadcast_init_progress({
'stage': 'finalizing',
'progress': 100,
'status': 'complete',
'details': f"Loaded {len(self._cache.raw_data)} {self.model_type} files from cache.",
'scanner_type': self.model_type,
'pageType': page_type
})
self._is_initializing = False
return
# If cache loading failed, proceed with full scan
await ws_manager.broadcast_init_progress({
'stage': 'scan_folders',
'progress': 0,
@@ -111,6 +283,9 @@ class ModelScanner:
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
# Save the cache to disk after initialization
await self._save_cache_to_disk()
# Send completion message
await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent
await ws_manager.broadcast_init_progress({
@@ -484,6 +659,9 @@ class ModelScanner:
# Resort cache
await self._cache.resort()
# Save updated cache to disk
await self._save_cache_to_disk()
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
@@ -854,6 +1032,9 @@ class ModelScanner:
await cache.resort()
# Save the updated cache
await self._save_cache_to_disk()
return True
def has_hash(self, sha256: str) -> bool:

View File

@@ -14,7 +14,8 @@ dependencies = [
"olefile", # for getting rid of warning message
"requests",
"toml",
"natsort"
"natsort",
"msgpack"
]
[project.urls]

View File

@@ -11,3 +11,4 @@ toml
numpy
torch
natsort
msgpack