From e7c626eb5fe267cf0add5d985252f8deaa85b30c Mon Sep 17 00:00:00 2001
From: Will Miao <13051207myq@gmail.com>
Date: Wed, 28 May 2025 22:30:06 +0800
Subject: [PATCH] Add MessagePack support for efficient cache serialization and
 update dependencies

---
Review changes applied to the added code (hunk sizes are unchanged):
 - _get_cache_file_path: return None with a warning when the cache
   directory cannot be created, so the Optional return type and the
   callers' "if not cache_path" checks are meaningful.
 - _save_cache_to_disk: bind temp_path before the try block instead of
   probing locals(); always use os.replace(); narrow the bare except.
 - _get_dirs_last_modified: read the root mtime inside the try block and
   catch OSError only.
 - _is_cache_valid: treat any mtime difference as a change, not only a
   newer mtime.
 - Log tracebacks (exc_info=True) when saving or loading the cache fails.

 .gitignore                   |   1 +
 py/services/model_scanner.py | 183 ++++++++++++++++++++++++++++++++++-
 pyproject.toml               |   3 +-
 requirements.txt             |   3 +-
 4 files changed, 187 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index 183cf2cd..584ff547 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@ settings.json
 output/*
 py/run_test.py
 .vscode/
+cache/
diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py
index dfa4fb66..1690c485 100644
--- a/py/services/model_scanner.py
+++ b/py/services/model_scanner.py
@@ -5,6 +5,7 @@ import asyncio
 import time
 import shutil
 from typing import List, Dict, Optional, Type, Set
+import msgpack  # Add MessagePack import for efficient serialization
 
 from ..utils.models import BaseModelMetadata
 from ..config import config
@@ -17,6 +18,9 @@ from .websocket_manager import ws_manager
 
 logger = logging.getLogger(__name__)
 
+# Define cache version to handle future format changes
+CACHE_VERSION = 1
+
 class ModelScanner:
     """Base service for scanning and managing model files"""
 
@@ -39,6 +43,7 @@ class ModelScanner:
         self._tags_count = {}  # Dictionary to store tag counts
         self._is_initializing = False  # Flag to track initialization state
         self._excluded_models = []  # List to track excluded models
+        self._dirs_last_modified = {}  # Track directory modification times
 
         # Register this service
         asyncio.create_task(self._register_service())
@@ -48,6 +53,149 @@ class ModelScanner:
         service_name = f"{self.model_type}_scanner"
         await ServiceRegistry.register_service(service_name, self)
 
+    def _get_cache_file_path(self) -> Optional[str]:
+        """Return the cache file path for this model type, or None if the cache directory cannot be created"""
+        # The project root is three directory levels above this module
+        current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
+        cache_dir = os.path.join(current_dir, "cache")
+        try:
+            os.makedirs(cache_dir, exist_ok=True)
+        except OSError as e:
+            logger.warning(f"Cannot create {self.model_type} cache directory {cache_dir}: {e}")
+            return None
+
+        return os.path.join(cache_dir, f"lm_{self.model_type}_cache.msgpack")
+
+    async def _save_cache_to_disk(self) -> bool:
+        """Save cache data to disk using MessagePack; return True on success, False otherwise"""
+        if self._cache is None or not self._cache.raw_data:
+            logger.debug(f"No {self.model_type} cache data to save")
+            return False
+
+        cache_path = self._get_cache_file_path()
+        if not cache_path:
+            logger.warning(f"Cannot determine {self.model_type} cache file location")
+            return False
+
+        temp_path = f"{cache_path}.tmp"
+        try:
+            cache_data = {
+                "version": CACHE_VERSION,
+                "timestamp": time.time(),
+                "model_type": self.model_type,
+                "raw_data": self._cache.raw_data,
+                "hash_index": {
+                    "hash_to_path": self._hash_index._hash_to_path,
+                    "filename_to_hash": self._hash_index._filename_to_hash
+                },
+                "tags_count": self._tags_count,
+                "dirs_last_modified": self._get_dirs_last_modified()
+            }
+
+            # Write to a temporary file first so that a failure mid-write
+            # never leaves a truncated cache file at cache_path
+            with open(temp_path, 'wb') as f:
+                msgpack.pack(cache_data, f)
+
+            # Move the temp file into place. os.replace() overwrites an
+            # existing destination on every platform and also works when
+            # the destination does not exist yet, so no exists()/rename()
+            # branch (and the race window it opens) is needed.
+            os.replace(temp_path, cache_path)
+
+            logger.info(f"Saved {self.model_type} cache with {len(self._cache.raw_data)} models to {cache_path}")
+            return True
+        except Exception as e:
+            logger.error(f"Error saving {self.model_type} cache to disk: {e}", exc_info=True)
+            # Try to clean up temp file if it exists
+            if os.path.exists(temp_path):
+                try:
+                    os.remove(temp_path)
+                except OSError:
+                    logger.debug(f"Could not remove temporary cache file {temp_path}")
+            return False
+
+    def _get_dirs_last_modified(self) -> Dict[str, float]:
+        """Return a mapping of each model root and its immediate subdirectories to their modification times"""
+        dirs_info = {}
+        for root in self.get_model_roots():
+            if os.path.exists(root):
+                # The root can vanish after the exists() check, so stat it inside the try as well
+                try:
+                    dirs_info[root] = os.path.getmtime(root)
+                    with os.scandir(root) as it:
+                        for entry in it:
+                            if entry.is_dir(follow_symlinks=True):
+                                dirs_info[entry.path] = entry.stat().st_mtime
+                except OSError as e:
+                    logger.error(f"Error getting directory info for {root}: {e}")
+        return dirs_info
+
+    def _is_cache_valid(self, cache_data: Dict) -> bool:
+        """Return True if the loaded cache matches the current version, model type and directory state"""
+        if not cache_data or cache_data.get("version") != CACHE_VERSION:
+            return False
+
+        if cache_data.get("model_type") != self.model_type:
+            return False
+
+        # Check if directories have changed
+        stored_dirs = cache_data.get("dirs_last_modified", {})
+        current_dirs = self._get_dirs_last_modified()
+
+        # If directory structure has changed, cache is invalid
+        if set(stored_dirs.keys()) != set(current_dirs.keys()):
+            return False
+
+        # Any mtime difference (newer or older) means the directory changed
+        for dir_path, stored_time in stored_dirs.items():
+            current_time = current_dirs.get(dir_path)
+            if current_time is None or current_time != stored_time:
+                return False
+
+        return True
+
+    async def _load_cache_from_disk(self) -> bool:
+        """Load cache data from disk using MessagePack; return True only if a valid cache was loaded"""
+        start_time = time.time()
+        cache_path = self._get_cache_file_path()
+        if not cache_path or not os.path.exists(cache_path):
+            return False
+
+        try:
+            with open(cache_path, 'rb') as f:
+                cache_data = msgpack.unpack(f)
+
+            # Validate cache data
+            if not self._is_cache_valid(cache_data):
+                logger.info(f"{self.model_type.capitalize()} cache file found but invalid or outdated")
+                return False
+
+            # Load data into memory
+            self._cache = ModelCache(
+                raw_data=cache_data["raw_data"],
+                sorted_by_name=[],
+                sorted_by_date=[],
+                folders=[]
+            )
+
+            # Load hash index
+            hash_index_data = cache_data.get("hash_index", {})
+            self._hash_index._hash_to_path = hash_index_data.get("hash_to_path", {})
+            self._hash_index._filename_to_hash = hash_index_data.get("filename_to_hash", {})
+
+            # Load tags count
+            self._tags_count = cache_data.get("tags_count", {})
+
+            # Resort the cache
+            await self._cache.resort()
+
+            logger.info(f"Loaded {self.model_type} cache from disk with {len(self._cache.raw_data)} models in {time.time() - start_time:.2f} seconds")
+            return True
+        except Exception as e:
+            logger.error(f"Error loading {self.model_type} cache from disk: {e}", exc_info=True)
+            return False
+
     async def initialize_in_background(self) -> None:
         """Initialize cache in background using thread pool"""
         try:
@@ -66,7 +214,31 @@ class ModelScanner:
             # Determine the page type based on model type
             page_type = 'loras' if self.model_type == 'lora' else 'checkpoints'
 
-            # First, count all model files to track progress
+            # First, try to load from cache
+            await ws_manager.broadcast_init_progress({
+                'stage': 'loading_cache',
+                'progress': 0,
+                'details': f"Loading {self.model_type} cache...",
+                'scanner_type': self.model_type,
+                'pageType': page_type
+            })
+
+            cache_loaded = await self._load_cache_from_disk()
+
+            if cache_loaded:
+                # Cache loaded successfully, broadcast complete message
+                await ws_manager.broadcast_init_progress({
+                    'stage': 'finalizing',
+                    'progress': 100,
+                    'status': 'complete',
+                    'details': f"Loaded {len(self._cache.raw_data)} {self.model_type} files from cache.",
+                    'scanner_type': self.model_type,
+                    'pageType': page_type
+                })
+                self._is_initializing = False
+                return
+
+            # If cache loading failed, proceed with full scan
             await ws_manager.broadcast_init_progress({
                 'stage': 'scan_folders',
                 'progress': 0,
@@ -111,6 +283,9 @@ class ModelScanner:
 
             logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
 
+            # Save the cache to disk after initialization
+            await self._save_cache_to_disk()
+
             # Send completion message
             await asyncio.sleep(0.5)  # Small delay to ensure final progress message is sent
             await ws_manager.broadcast_init_progress({
@@ -484,6 +659,9 @@ class ModelScanner:
             # Resort cache
             await self._cache.resort()
 
+            # Save updated cache to disk
+            await self._save_cache_to_disk()
+
             logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
         except Exception as e:
             logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
@@ -854,6 +1032,9 @@ class ModelScanner:
 
         await cache.resort()
 
+        # Save the updated cache
+        await self._save_cache_to_disk()
+
         return True
 
     def has_hash(self, sha256: str) -> bool:
diff --git a/pyproject.toml b/pyproject.toml
index 377a53e8..771f9db4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,8 @@ dependencies = [
     "olefile", # for getting rid of warning message
     "requests",
     "toml",
-    "natsort"
+    "natsort",
+    "msgpack"
 ]
 
 [project.urls]
diff --git a/requirements.txt b/requirements.txt
index 6a90387f..1a0a199d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,5 @@ requests
 toml
 numpy
 torch
-natsort
\ No newline at end of file
+natsort
+msgpack
\ No newline at end of file