refactor: Update model retrieval methods in RecipeRoutes and streamline CheckpointScanner and LoraScanner initialization

This commit is contained in:
Will Miao
2025-07-23 23:27:18 +08:00
parent 68d00ce289
commit bf9aa9356b
4 changed files with 113 additions and 225 deletions

View File

@@ -1148,7 +1148,7 @@ class RecipeRoutes:
for lora_name, lora_strength in lora_matches: for lora_name, lora_strength in lora_matches:
try: try:
# Get lora info from scanner # Get lora info from scanner
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name) lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name)
# Create lora entry # Create lora entry
lora_entry = { lora_entry = {
@@ -1167,7 +1167,7 @@ class RecipeRoutes:
# Get base model from lora scanner for the available loras # Get base model from lora scanner for the available loras
base_model_counts = {} base_model_counts = {}
for lora in loras_data: for lora in loras_data:
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", "")) lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", ""))
if lora_info and "base_model" in lora_info: if lora_info and "base_model" in lora_info:
base_model = lora_info["base_model"] base_model = lora_info["base_model"]
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
@@ -1365,7 +1365,7 @@ class RecipeRoutes:
return web.json_response({"error": "Recipe not found"}, status=404) return web.json_response({"error": "Recipe not found"}, status=404)
# Find target LoRA by name # Find target LoRA by name
target_lora = await lora_scanner.get_lora_info_by_name(target_name) target_lora = await lora_scanner.get_model_info_by_name(target_name)
if not target_lora: if not target_lora:
return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404) return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)

View File

@@ -1,7 +1,5 @@
import os
import logging import logging
import asyncio from typing import List
from typing import List, Dict
from ..utils.models import CheckpointMetadata from ..utils.models import CheckpointMetadata
from ..config import config from ..config import config
@@ -13,16 +11,7 @@ logger = logging.getLogger(__name__)
class CheckpointScanner(ModelScanner): class CheckpointScanner(ModelScanner):
"""Service for scanning and managing checkpoint files""" """Service for scanning and managing checkpoint files"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self): def __init__(self):
if not hasattr(self, '_initialized'):
# Define supported file extensions # Define supported file extensions
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'} file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
super().__init__( super().__init__(
@@ -31,84 +20,11 @@ class CheckpointScanner(ModelScanner):
file_extensions=file_extensions, file_extensions=file_extensions,
hash_index=ModelHashIndex() hash_index=ModelHashIndex()
) )
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def get_model_roots(self) -> List[str]: def get_model_roots(self) -> List[str]:
"""Get checkpoint root directories""" """Get checkpoint root directories"""
return config.base_models_roots return config.base_models_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all checkpoint directories and return metadata"""
all_checkpoints = []
# Create scan tasks for each directory
scan_tasks = []
for root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
checkpoints = await task
all_checkpoints.extend(checkpoints)
except Exception as e:
logger.error(f"Error scanning checkpoint directory: {e}")
return all_checkpoints
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a directory for checkpoint files"""
checkpoints = []
original_root = root_path
async def scan_recursive(path: str, visited_paths: set):
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
# Check if file has supported extension
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, checkpoints)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return checkpoints
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
"""Process a single checkpoint file and add to results"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
checkpoints.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
# Checkpoint-specific hash index functionality # Checkpoint-specific hash index functionality
def has_checkpoint_hash(self, sha256: str) -> bool: def has_checkpoint_hash(self, sha256: str) -> bool:
"""Check if a checkpoint with given hash exists""" """Check if a checkpoint with given hash exists"""
@@ -121,17 +37,3 @@ class CheckpointScanner(ModelScanner):
def get_checkpoint_hash_by_path(self, file_path: str) -> str: def get_checkpoint_hash_by_path(self, file_path: str) -> str:
"""Get hash for a checkpoint by its file path""" """Get hash for a checkpoint by its file path"""
return self.get_hash_by_path(file_path) return self.get_hash_by_path(file_path)
async def get_checkpoint_info_by_name(self, name):
"""Get checkpoint information by name"""
try:
cache = await self.get_cached_data()
for checkpoint in cache.raw_data:
if checkpoint.get("file_name") == name:
return checkpoint
return None
except Exception as e:
logger.error(f"Error getting checkpoint info by name: {e}", exc_info=True)
return None

View File

@@ -1,7 +1,5 @@
import os
import logging import logging
import asyncio from typing import List, Optional
from typing import List, Dict, Optional
from ..utils.models import LoraMetadata from ..utils.models import LoraMetadata
from ..config import config from ..config import config
@@ -14,17 +12,7 @@ logger = logging.getLogger(__name__)
class LoraScanner(ModelScanner): class LoraScanner(ModelScanner):
"""Service for scanning and managing LoRA files""" """Service for scanning and managing LoRA files"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self): def __init__(self):
# Ensure initialization happens only once
if not hasattr(self, '_initialized'):
# Define supported file extensions # Define supported file extensions
file_extensions = {'.safetensors'} file_extensions = {'.safetensors'}
@@ -35,83 +23,11 @@ class LoraScanner(ModelScanner):
file_extensions=file_extensions, file_extensions=file_extensions,
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
) )
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def get_model_roots(self) -> List[str]: def get_model_roots(self) -> List[str]:
"""Get lora root directories""" """Get lora root directories"""
return config.loras_roots return config.loras_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all LoRA directories and return metadata"""
all_loras = []
# Create scan tasks for each directory
scan_tasks = []
for lora_root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(lora_root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
loras = await task
all_loras.extend(loras)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_loras
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for LoRA files"""
loras = []
original_root = root_path # Save original root path
async def scan_recursive(path: str, visited_paths: set):
"""Recursively scan directory, avoiding circular symlinks"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
# Use original path instead of real path
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, loras)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return loras
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
"""Process a single file and add to results list"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
loras.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
# Lora-specific hash index functionality # Lora-specific hash index functionality
def has_lora_hash(self, sha256: str) -> bool: def has_lora_hash(self, sha256: str) -> bool:
"""Check if a LoRA with given hash exists""" """Check if a LoRA with given hash exists"""
@@ -160,19 +76,3 @@ class LoraScanner(ModelScanner):
test_hash_result = self._hash_index.get_hash(test_path) test_hash_result = self._hash_index.get_hash(test_path)
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr) print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
async def get_lora_info_by_name(self, name):
"""Get LoRA information by name"""
try:
# Get cached data
cache = await self.get_cached_data()
# Find the LoRA by name
for lora in cache.raw_data:
if lora.get("file_name") == name:
return lora
return None
except Exception as e:
logger.error(f"Error getting LoRA info by name: {e}", exc_info=True)
return None

View File

@@ -29,7 +29,30 @@ CACHE_VERSION = 3
class ModelScanner: class ModelScanner:
"""Base service for scanning and managing model files""" """Base service for scanning and managing model files"""
_lock = asyncio.Lock() _instances = {} # Dictionary to store instances by class
_locks = {} # Dictionary to store locks by class
def __new__(cls, *args, **kwargs):
"""Implement singleton pattern for each subclass"""
if cls not in cls._instances:
cls._instances[cls] = super().__new__(cls)
return cls._instances[cls]
@classmethod
def _get_lock(cls):
"""Get or create a lock for this class"""
if cls not in cls._locks:
cls._locks[cls] = asyncio.Lock()
return cls._locks[cls]
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
lock = cls._get_lock()
async with lock:
if cls not in cls._instances:
cls._instances[cls] = cls()
return cls._instances[cls]
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None): def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
"""Initialize the scanner """Initialize the scanner
@@ -40,6 +63,10 @@ class ModelScanner:
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'}) file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
hash_index: Hash index instance (optional) hash_index: Hash index instance (optional)
""" """
# Ensure initialization happens only once per instance
if hasattr(self, '_initialized'):
return
self.model_type = model_type self.model_type = model_type
self.model_class = model_class self.model_class = model_class
self.file_extensions = file_extensions self.file_extensions = file_extensions
@@ -50,6 +77,7 @@ class ModelScanner:
self._excluded_models = [] # List to track excluded models self._excluded_models = [] # List to track excluded models
self._dirs_last_modified = {} # Track directory modification times self._dirs_last_modified = {} # Track directory modification times
self._use_cache_files = False # Flag to control cache file usage, default to disabled self._use_cache_files = False # Flag to control cache file usage, default to disabled
self._initialized = True
# Clear cache files if disabled # Clear cache files if disabled
if not self._use_cache_files: if not self._use_cache_files:
@@ -744,10 +772,68 @@ class ModelScanner:
finally: finally:
self._is_initializing = False # Unset flag self._is_initializing = False # Unset flag
# These methods should be implemented in child classes
async def scan_all_models(self) -> List[Dict]: async def scan_all_models(self) -> List[Dict]:
"""Scan all model directories and return metadata""" """Scan all model directories and return metadata"""
raise NotImplementedError("Subclasses must implement scan_all_models") all_models = []
# Create scan tasks for each directory
scan_tasks = []
for model_root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(model_root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
models = await task
all_models.extend(models)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_models
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for model files"""
models = []
original_root = root_path # Save original root path
async def scan_recursive(path: str, visited_paths: set):
"""Recursively scan directory, avoiding circular symlinks"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
# Use original path instead of real path
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, models)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return models
async def _process_single_file(self, file_path: str, root_path: str, models: list):
"""Process a single file and add to results list"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
models.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
def is_initializing(self) -> bool: def is_initializing(self) -> bool:
"""Check if the scanner is currently initializing""" """Check if the scanner is currently initializing"""