From bf9aa9356bdd180f02e398dac6d18c809b15f636 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 23:27:18 +0800 Subject: [PATCH] refactor: Update model retrieval methods in RecipeRoutes and streamline CheckpointScanner and LoraScanner initialization --- py/routes/recipe_routes.py | 6 +- py/services/checkpoint_scanner.py | 118 +++-------------------------- py/services/lora_scanner.py | 122 +++--------------------------- py/services/model_scanner.py | 92 +++++++++++++++++++++- 4 files changed, 113 insertions(+), 225 deletions(-) diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index d181ed65..07fcb287 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1148,7 +1148,7 @@ class RecipeRoutes: for lora_name, lora_strength in lora_matches: try: # Get lora info from scanner - lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name) + lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name) # Create lora entry lora_entry = { @@ -1167,7 +1167,7 @@ class RecipeRoutes: # Get base model from lora scanner for the available loras base_model_counts = {} for lora in loras_data: - lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", "")) + lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", "")) if lora_info and "base_model" in lora_info: base_model = lora_info["base_model"] base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 @@ -1365,7 +1365,7 @@ class RecipeRoutes: return web.json_response({"error": "Recipe not found"}, status=404) # Find target LoRA by name - target_lora = await lora_scanner.get_lora_info_by_name(target_name) + target_lora = await lora_scanner.get_model_info_by_name(target_name) if not target_lora: return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404) diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index 32da3dbf..95569d4f 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -1,7 +1,5 @@ -import os import logging -import asyncio -from typing import List, Dict +from typing import List from ..utils.models import CheckpointMetadata from ..config import config @@ -13,101 +11,19 @@ logger = logging.getLogger(__name__) class CheckpointScanner(ModelScanner): """Service for scanning and managing checkpoint files""" - _instance = None - _lock = asyncio.Lock() - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - def __init__(self): - if not hasattr(self, '_initialized'): - # Define supported file extensions - file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'} - super().__init__( - model_type="checkpoint", - model_class=CheckpointMetadata, - file_extensions=file_extensions, - hash_index=ModelHashIndex() - ) - self._initialized = True + # Define supported file extensions + file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'} + super().__init__( + model_type="checkpoint", + model_class=CheckpointMetadata, + file_extensions=file_extensions, + hash_index=ModelHashIndex() + ) - @classmethod - async def get_instance(cls): - """Get singleton instance with async support""" - async with cls._lock: - if cls._instance is None: - cls._instance = cls() - return cls._instance - def get_model_roots(self) -> List[str]: """Get checkpoint root directories""" return config.base_models_roots - - async def scan_all_models(self) -> List[Dict]: - """Scan all checkpoint directories and return metadata""" - all_checkpoints = [] - - # Create scan tasks for each directory - scan_tasks = [] - for root in self.get_model_roots(): - task = asyncio.create_task(self._scan_directory(root)) - scan_tasks.append(task) - - # Wait for all tasks to complete - for task in scan_tasks: - try: - checkpoints = await task - all_checkpoints.extend(checkpoints) - except Exception as e: - logger.error(f"Error scanning checkpoint directory: {e}") - - return all_checkpoints - - async def _scan_directory(self, root_path: str) -> List[Dict]: - """Scan a directory for checkpoint files""" - checkpoints = [] - original_root = root_path - - async def scan_recursive(path: str, visited_paths: set): - try: - real_path = os.path.realpath(path) - if real_path in visited_paths: - logger.debug(f"Skipping already visited path: {path}") - return - visited_paths.add(real_path) - - with os.scandir(path) as it: - entries = list(it) - for entry in entries: - try: - if entry.is_file(follow_symlinks=True): - # Check if file has supported extension - ext = os.path.splitext(entry.name)[1].lower() - if ext in self.file_extensions: - file_path = entry.path.replace(os.sep, "/") - await self._process_single_file(file_path, original_root, checkpoints) - await asyncio.sleep(0) - elif entry.is_dir(follow_symlinks=True): - # For directories, continue scanning with original path - await scan_recursive(entry.path, visited_paths) - except Exception as e: - logger.error(f"Error processing entry {entry.path}: {e}") - except Exception as e: - logger.error(f"Error scanning {path}: {e}") - - await scan_recursive(root_path, set()) - return checkpoints - - async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list): - """Process a single checkpoint file and add to results""" - try: - result = await self._process_model_file(file_path, root_path) - if result: - checkpoints.append(result) - except Exception as e: - logger.error(f"Error processing {file_path}: {e}") # Checkpoint-specific hash index functionality def has_checkpoint_hash(self, sha256: str) -> bool: @@ -120,18 +36,4 @@ class CheckpointScanner(ModelScanner): def get_checkpoint_hash_by_path(self, file_path: str) -> str: """Get hash for a checkpoint by its file path""" - return self.get_hash_by_path(file_path) - - async def get_checkpoint_info_by_name(self, name): - """Get checkpoint information by name""" - try: - cache = await self.get_cached_data() - - for checkpoint in cache.raw_data: - if checkpoint.get("file_name") == name: - return checkpoint - - return None - except Exception as e: - logger.error(f"Error getting checkpoint info by name: {e}", exc_info=True) - return None \ No newline at end of file + return self.get_hash_by_path(file_path) \ No newline at end of file diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index fd2694db..f4066dbe 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -1,7 +1,5 @@ -import os import logging -import asyncio -from typing import List, Dict, Optional +from typing import List, Optional from ..utils.models import LoraMetadata from ..config import config @@ -14,103 +12,21 @@ logger = logging.getLogger(__name__) class LoraScanner(ModelScanner): """Service for scanning and managing LoRA files""" - _instance = None - _lock = asyncio.Lock() - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - def __init__(self): - # Ensure initialization happens only once - if not hasattr(self, '_initialized'): - # Define supported file extensions - file_extensions = {'.safetensors'} - - # Initialize parent class with ModelHashIndex - super().__init__( - model_type="lora", - model_class=LoraMetadata, - file_extensions=file_extensions, - hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex - ) - self._initialized = True - - @classmethod - async def get_instance(cls): - """Get singleton instance with async support""" - async with cls._lock: - if cls._instance is None: - cls._instance = cls() - return cls._instance + # Define supported file extensions + file_extensions = {'.safetensors'} + + # Initialize parent class with ModelHashIndex + super().__init__( + model_type="lora", + model_class=LoraMetadata, + file_extensions=file_extensions, + hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex + ) def get_model_roots(self) -> List[str]: """Get lora root directories""" return config.loras_roots - - async def scan_all_models(self) -> List[Dict]: - """Scan all LoRA directories and return metadata""" - all_loras = [] - - # Create scan tasks for each directory - scan_tasks = [] - for lora_root in self.get_model_roots(): - task = asyncio.create_task(self._scan_directory(lora_root)) - scan_tasks.append(task) - - # Wait for all tasks to complete - for task in scan_tasks: - try: - loras = await task - all_loras.extend(loras) - except Exception as e: - logger.error(f"Error scanning directory: {e}") - - return all_loras - - async def _scan_directory(self, root_path: str) -> List[Dict]: - """Scan a single directory for LoRA files""" - loras = [] - original_root = root_path # Save original root path - - async def scan_recursive(path: str, visited_paths: set): - """Recursively scan directory, avoiding circular symlinks""" - try: - real_path = os.path.realpath(path) - if real_path in visited_paths: - logger.debug(f"Skipping already visited path: {path}") - return - visited_paths.add(real_path) - - with os.scandir(path) as it: - entries = list(it) - for entry in entries: - try: - if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions): - # Use original path instead of real path - file_path = entry.path.replace(os.sep, "/") - await self._process_single_file(file_path, original_root, loras) - await asyncio.sleep(0) - elif entry.is_dir(follow_symlinks=True): - # For directories, continue scanning with original path - await scan_recursive(entry.path, visited_paths) - except Exception as e: - logger.error(f"Error processing entry {entry.path}: {e}") - except Exception as e: - logger.error(f"Error scanning {path}: {e}") - - await scan_recursive(root_path, set()) - return loras - - async def _process_single_file(self, file_path: str, root_path: str, loras: list): - """Process a single file and add to results list""" - try: - result = await self._process_model_file(file_path, root_path) - if result: - loras.append(result) - except Exception as e: - logger.error(f"Error processing {file_path}: {e}") # Lora-specific hash index functionality def has_lora_hash(self, sha256: str) -> bool: @@ -160,19 +76,3 @@ class LoraScanner(ModelScanner): test_hash_result = self._hash_index.get_hash(test_path) print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr) - async def get_lora_info_by_name(self, name): - """Get LoRA information by name""" - try: - # Get cached data - cache = await self.get_cached_data() - - # Find the LoRA by name - for lora in cache.raw_data: - if lora.get("file_name") == name: - return lora - - return None - except Exception as e: - logger.error(f"Error getting LoRA info by name: {e}", exc_info=True) - return None - diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index fdd9c020..a31bff42 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -29,7 +29,30 @@ CACHE_VERSION = 3 class ModelScanner: """Base service for scanning and managing model files""" - _lock = asyncio.Lock() + _instances = {} # Dictionary to store instances by class + _locks = {} # Dictionary to store locks by class + + def __new__(cls, *args, **kwargs): + """Implement singleton pattern for each subclass""" + if cls not in cls._instances: + cls._instances[cls] = super().__new__(cls) + return cls._instances[cls] + + @classmethod + def _get_lock(cls): + """Get or create a lock for this class""" + if cls not in cls._locks: + cls._locks[cls] = asyncio.Lock() + return cls._locks[cls] + + @classmethod + async def get_instance(cls): + """Get singleton instance with async support""" + lock = cls._get_lock() + async with lock: + if cls not in cls._instances: + cls._instances[cls] = cls() + return cls._instances[cls] def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None): """Initialize the scanner @@ -40,6 +63,10 @@ class ModelScanner: file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'}) hash_index: Hash index instance (optional) """ + # Ensure initialization happens only once per instance + if hasattr(self, '_initialized'): + return + self.model_type = model_type self.model_class = model_class self.file_extensions = file_extensions @@ -50,6 +77,7 @@ class ModelScanner: self._excluded_models = [] # List to track excluded models self._dirs_last_modified = {} # Track directory modification times self._use_cache_files = False # Flag to control cache file usage, default to disabled + self._initialized = True # Clear cache files if disabled if not self._use_cache_files: @@ -744,10 +772,68 @@ class ModelScanner: finally: self._is_initializing = False # Unset flag - # These methods should be implemented in child classes async def scan_all_models(self) -> List[Dict]: """Scan all model directories and return metadata""" - raise NotImplementedError("Subclasses must implement scan_all_models") + all_models = [] + + # Create scan tasks for each directory + scan_tasks = [] + for model_root in self.get_model_roots(): + task = asyncio.create_task(self._scan_directory(model_root)) + scan_tasks.append(task) + + # Wait for all tasks to complete + for task in scan_tasks: + try: + models = await task + all_models.extend(models) + except Exception as e: + logger.error(f"Error scanning directory: {e}") + + return all_models + + async def _scan_directory(self, root_path: str) -> List[Dict]: + """Scan a single directory for model files""" + models = [] + original_root = root_path # Save original root path + + async def scan_recursive(path: str, visited_paths: set): + """Recursively scan directory, avoiding circular symlinks""" + try: + real_path = os.path.realpath(path) + if real_path in visited_paths: + logger.debug(f"Skipping already visited path: {path}") + return + visited_paths.add(real_path) + + with os.scandir(path) as it: + entries = list(it) + for entry in entries: + try: + if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions): + # Use original path instead of real path + file_path = entry.path.replace(os.sep, "/") + await self._process_single_file(file_path, original_root, models) + await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + # For directories, continue scanning with original path + await scan_recursive(entry.path, visited_paths) + except Exception as e: + logger.error(f"Error processing entry {entry.path}: {e}") + except Exception as e: + logger.error(f"Error scanning {path}: {e}") + + await scan_recursive(root_path, set()) + return models + + async def _process_single_file(self, file_path: str, root_path: str, models: list): + """Process a single file and add to results list""" + try: + result = await self._process_model_file(file_path, root_path) + if result: + models.append(result) + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") def is_initializing(self) -> bool: """Check if the scanner is currently initializing"""