refactor: Update model retrieval methods in RecipeRoutes and streamline CheckpointScanner and LoraScanner initialization

Will Miao
2025-07-23 23:27:18 +08:00
parent 68d00ce289
commit bf9aa9356b
4 changed files with 113 additions and 225 deletions
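The rename from get_lora_info_by_name to get_model_info_by_name points every caller at one lookup shared by all scanner types. The shared method itself is not shown in this commit, but judging from the per-type methods it deletes, the base-class version presumably looks like this sketch (its placement on ModelScanner is an assumption):

    # Hypothetical shared lookup on ModelScanner, inferred from the removed
    # get_lora_info_by_name / get_checkpoint_info_by_name bodies below
    async def get_model_info_by_name(self, name):
        """Get model information by file name"""
        try:
            cache = await self.get_cached_data()
            for model in cache.raw_data:
                if model.get("file_name") == name:
                    return model
            return None
        except Exception as e:
            logger.error(f"Error getting model info by name: {e}", exc_info=True)
            return None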

View File

@@ -1148,7 +1148,7 @@ class RecipeRoutes:
             for lora_name, lora_strength in lora_matches:
                 try:
                     # Get lora info from scanner
-                    lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name)
+                    lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name)

                     # Create lora entry
                     lora_entry = {
@@ -1167,7 +1167,7 @@ class RecipeRoutes:
             # Get base model from lora scanner for the available loras
             base_model_counts = {}
             for lora in loras_data:
-                lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", ""))
+                lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", ""))
                 if lora_info and "base_model" in lora_info:
                     base_model = lora_info["base_model"]
                     base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
@@ -1365,7 +1365,7 @@ class RecipeRoutes:
                 return web.json_response({"error": "Recipe not found"}, status=404)

             # Find target LoRA by name
-            target_lora = await lora_scanner.get_lora_info_by_name(target_name)
+            target_lora = await lora_scanner.get_model_info_by_name(target_name)
             if not target_lora:
                 return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)
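Only the method name changes at each call site; the await pattern and the returned metadata dict are unchanged. A minimal before/after sketch (variable names illustrative):

    # Before: LoRA-specific lookup defined on LoraScanner
    lora_info = await lora_scanner.get_lora_info_by_name("my_lora")
    # After: generic lookup inherited from ModelScanner
    lora_info = await lora_scanner.get_model_info_by_name("my_lora")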

View File

@@ -1,7 +1,5 @@
 import os
 import logging
-import asyncio
-from typing import List, Dict
+from typing import List
 from ..utils.models import CheckpointMetadata
 from ..config import config
@@ -13,101 +11,19 @@ logger = logging.getLogger(__name__)
 class CheckpointScanner(ModelScanner):
     """Service for scanning and managing checkpoint files"""
-    _instance = None
-    _lock = asyncio.Lock()
-
-    def __new__(cls):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance

     def __init__(self):
-        if not hasattr(self, '_initialized'):
-            # Define supported file extensions
-            file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
-            super().__init__(
-                model_type="checkpoint",
-                model_class=CheckpointMetadata,
-                file_extensions=file_extensions,
-                hash_index=ModelHashIndex()
-            )
-            self._initialized = True
-
-    @classmethod
-    async def get_instance(cls):
-        """Get singleton instance with async support"""
-        async with cls._lock:
-            if cls._instance is None:
-                cls._instance = cls()
-            return cls._instance
+        # Define supported file extensions
+        file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
+        super().__init__(
+            model_type="checkpoint",
+            model_class=CheckpointMetadata,
+            file_extensions=file_extensions,
+            hash_index=ModelHashIndex()
+        )

     def get_model_roots(self) -> List[str]:
         """Get checkpoint root directories"""
         return config.base_models_roots
-
-    async def scan_all_models(self) -> List[Dict]:
-        """Scan all checkpoint directories and return metadata"""
-        all_checkpoints = []
-
-        # Create scan tasks for each directory
-        scan_tasks = []
-        for root in self.get_model_roots():
-            task = asyncio.create_task(self._scan_directory(root))
-            scan_tasks.append(task)
-
-        # Wait for all tasks to complete
-        for task in scan_tasks:
-            try:
-                checkpoints = await task
-                all_checkpoints.extend(checkpoints)
-            except Exception as e:
-                logger.error(f"Error scanning checkpoint directory: {e}")
-
-        return all_checkpoints
-
-    async def _scan_directory(self, root_path: str) -> List[Dict]:
-        """Scan a directory for checkpoint files"""
-        checkpoints = []
-        original_root = root_path
-
-        async def scan_recursive(path: str, visited_paths: set):
-            try:
-                real_path = os.path.realpath(path)
-                if real_path in visited_paths:
-                    logger.debug(f"Skipping already visited path: {path}")
-                    return
-                visited_paths.add(real_path)
-
-                with os.scandir(path) as it:
-                    entries = list(it)
-                    for entry in entries:
-                        try:
-                            if entry.is_file(follow_symlinks=True):
-                                # Check if file has supported extension
-                                ext = os.path.splitext(entry.name)[1].lower()
-                                if ext in self.file_extensions:
-                                    file_path = entry.path.replace(os.sep, "/")
-                                    await self._process_single_file(file_path, original_root, checkpoints)
-                                    await asyncio.sleep(0)
-                            elif entry.is_dir(follow_symlinks=True):
-                                # For directories, continue scanning with original path
-                                await scan_recursive(entry.path, visited_paths)
-                        except Exception as e:
-                            logger.error(f"Error processing entry {entry.path}: {e}")
-            except Exception as e:
-                logger.error(f"Error scanning {path}: {e}")
-
-        await scan_recursive(root_path, set())
-        return checkpoints
-
-    async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
-        """Process a single checkpoint file and add to results"""
-        try:
-            result = await self._process_model_file(file_path, root_path)
-            if result:
-                checkpoints.append(result)
-        except Exception as e:
-            logger.error(f"Error processing {file_path}: {e}")

     # Checkpoint-specific hash index functionality
     def has_checkpoint_hash(self, sha256: str) -> bool:
@@ -120,18 +36,4 @@ class CheckpointScanner(ModelScanner):
     def get_checkpoint_hash_by_path(self, file_path: str) -> str:
         """Get hash for a checkpoint by its file path"""
         return self.get_hash_by_path(file_path)
-
-    async def get_checkpoint_info_by_name(self, name):
-        """Get checkpoint information by name"""
-        try:
-            cache = await self.get_cached_data()
-
-            for checkpoint in cache.raw_data:
-                if checkpoint.get("file_name") == name:
-                    return checkpoint
-
-            return None
-        except Exception as e:
-            logger.error(f"Error getting checkpoint info by name: {e}", exc_info=True)
-            return None
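With the singleton plumbing deleted, CheckpointScanner keeps only its type-specific pieces (extensions, metadata class, root directories, hash helpers); instance management now comes from the base class. A usage sketch under that assumption:

    import asyncio

    async def main():
        # get_instance() is inherited from ModelScanner (see the last file
        # in this commit); CheckpointScanner no longer defines it itself
        scanner = await CheckpointScanner.get_instance()
        again = await CheckpointScanner.get_instance()
        assert scanner is again  # one shared instance per scanner class

    asyncio.run(main())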

View File

@@ -1,7 +1,5 @@
 import os
 import logging
-import asyncio
-from typing import List, Dict, Optional
+from typing import List, Optional
 from ..utils.models import LoraMetadata
 from ..config import config
@@ -14,103 +12,21 @@ logger = logging.getLogger(__name__)
 class LoraScanner(ModelScanner):
     """Service for scanning and managing LoRA files"""
-    _instance = None
-    _lock = asyncio.Lock()
-
-    def __new__(cls):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance

     def __init__(self):
-        # Ensure initialization happens only once
-        if not hasattr(self, '_initialized'):
-            # Define supported file extensions
-            file_extensions = {'.safetensors'}
-            # Initialize parent class with ModelHashIndex
-            super().__init__(
-                model_type="lora",
-                model_class=LoraMetadata,
-                file_extensions=file_extensions,
-                hash_index=ModelHashIndex()  # Changed from LoraHashIndex to ModelHashIndex
-            )
-            self._initialized = True
-
-    @classmethod
-    async def get_instance(cls):
-        """Get singleton instance with async support"""
-        async with cls._lock:
-            if cls._instance is None:
-                cls._instance = cls()
-            return cls._instance
+        # Define supported file extensions
+        file_extensions = {'.safetensors'}
+        # Initialize parent class with ModelHashIndex
+        super().__init__(
+            model_type="lora",
+            model_class=LoraMetadata,
+            file_extensions=file_extensions,
+            hash_index=ModelHashIndex()  # Changed from LoraHashIndex to ModelHashIndex
+        )

     def get_model_roots(self) -> List[str]:
         """Get lora root directories"""
         return config.loras_roots
-
-    async def scan_all_models(self) -> List[Dict]:
-        """Scan all LoRA directories and return metadata"""
-        all_loras = []
-
-        # Create scan tasks for each directory
-        scan_tasks = []
-        for lora_root in self.get_model_roots():
-            task = asyncio.create_task(self._scan_directory(lora_root))
-            scan_tasks.append(task)
-
-        # Wait for all tasks to complete
-        for task in scan_tasks:
-            try:
-                loras = await task
-                all_loras.extend(loras)
-            except Exception as e:
-                logger.error(f"Error scanning directory: {e}")
-
-        return all_loras
-
-    async def _scan_directory(self, root_path: str) -> List[Dict]:
-        """Scan a single directory for LoRA files"""
-        loras = []
-        original_root = root_path  # Save original root path
-
-        async def scan_recursive(path: str, visited_paths: set):
-            """Recursively scan directory, avoiding circular symlinks"""
-            try:
-                real_path = os.path.realpath(path)
-                if real_path in visited_paths:
-                    logger.debug(f"Skipping already visited path: {path}")
-                    return
-                visited_paths.add(real_path)
-
-                with os.scandir(path) as it:
-                    entries = list(it)
-                    for entry in entries:
-                        try:
-                            if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
-                                # Use original path instead of real path
-                                file_path = entry.path.replace(os.sep, "/")
-                                await self._process_single_file(file_path, original_root, loras)
-                                await asyncio.sleep(0)
-                            elif entry.is_dir(follow_symlinks=True):
-                                # For directories, continue scanning with original path
-                                await scan_recursive(entry.path, visited_paths)
-                        except Exception as e:
-                            logger.error(f"Error processing entry {entry.path}: {e}")
-            except Exception as e:
-                logger.error(f"Error scanning {path}: {e}")
-
-        await scan_recursive(root_path, set())
-        return loras
-
-    async def _process_single_file(self, file_path: str, root_path: str, loras: list):
-        """Process a single file and add to results list"""
-        try:
-            result = await self._process_model_file(file_path, root_path)
-            if result:
-                loras.append(result)
-        except Exception as e:
-            logger.error(f"Error processing {file_path}: {e}")

     # Lora-specific hash index functionality
     def has_lora_hash(self, sha256: str) -> bool:
@@ -160,19 +76,3 @@ class LoraScanner(ModelScanner):
         test_hash_result = self._hash_index.get_hash(test_path)
         print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
-
-    async def get_lora_info_by_name(self, name):
-        """Get LoRA information by name"""
-        try:
-            # Get cached data
-            cache = await self.get_cached_data()
-
-            # Find the LoRA by name
-            for lora in cache.raw_data:
-                if lora.get("file_name") == name:
-                    return lora
-
-            return None
-        except Exception as e:
-            logger.error(f"Error getting LoRA info by name: {e}", exc_info=True)
-            return None
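Because the base class keys its registry by subclass (the _instances dict in the next file), LoraScanner and CheckpointScanner share the machinery but not the instance. A short sketch of the expected behavior:

    async def demo():
        lora = await LoraScanner.get_instance()
        ckpt = await CheckpointScanner.get_instance()
        assert lora is not ckpt                           # one instance per subclass
        assert lora is await LoraScanner.get_instance()   # stable across calls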

View File

@@ -29,7 +29,30 @@ CACHE_VERSION = 3
 class ModelScanner:
     """Base service for scanning and managing model files"""
-    _lock = asyncio.Lock()
+    _instances = {}  # Dictionary to store instances by class
+    _locks = {}  # Dictionary to store locks by class
+
+    def __new__(cls, *args, **kwargs):
+        """Implement singleton pattern for each subclass"""
+        if cls not in cls._instances:
+            cls._instances[cls] = super().__new__(cls)
+        return cls._instances[cls]
+
+    @classmethod
+    def _get_lock(cls):
+        """Get or create a lock for this class"""
+        if cls not in cls._locks:
+            cls._locks[cls] = asyncio.Lock()
+        return cls._locks[cls]
+
+    @classmethod
+    async def get_instance(cls):
+        """Get singleton instance with async support"""
+        lock = cls._get_lock()
+        async with lock:
+            if cls not in cls._instances:
+                cls._instances[cls] = cls()
+            return cls._instances[cls]

     def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
         """Initialize the scanner
@@ -40,6 +63,10 @@ class ModelScanner:
             file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
             hash_index: Hash index instance (optional)
         """
+        # Ensure initialization happens only once per instance
+        if hasattr(self, '_initialized'):
+            return
+
         self.model_type = model_type
         self.model_class = model_class
         self.file_extensions = file_extensions
@@ -50,6 +77,7 @@ class ModelScanner:
         self._excluded_models = []  # List to track excluded models
         self._dirs_last_modified = {}  # Track directory modification times
         self._use_cache_files = False  # Flag to control cache file usage, default to disabled
+        self._initialized = True

         # Clear cache files if disabled
         if not self._use_cache_files:
@@ -744,10 +772,68 @@ class ModelScanner:
         finally:
             self._is_initializing = False  # Unset flag

-    # These methods should be implemented in child classes
     async def scan_all_models(self) -> List[Dict]:
         """Scan all model directories and return metadata"""
-        raise NotImplementedError("Subclasses must implement scan_all_models")
+        all_models = []
+
+        # Create scan tasks for each directory
+        scan_tasks = []
+        for model_root in self.get_model_roots():
+            task = asyncio.create_task(self._scan_directory(model_root))
+            scan_tasks.append(task)
+
+        # Wait for all tasks to complete
+        for task in scan_tasks:
+            try:
+                models = await task
+                all_models.extend(models)
+            except Exception as e:
+                logger.error(f"Error scanning directory: {e}")
+
+        return all_models
+
+    async def _scan_directory(self, root_path: str) -> List[Dict]:
+        """Scan a single directory for model files"""
+        models = []
+        original_root = root_path  # Save original root path
+
+        async def scan_recursive(path: str, visited_paths: set):
+            """Recursively scan directory, avoiding circular symlinks"""
+            try:
+                real_path = os.path.realpath(path)
+                if real_path in visited_paths:
+                    logger.debug(f"Skipping already visited path: {path}")
+                    return
+                visited_paths.add(real_path)
+
+                with os.scandir(path) as it:
+                    entries = list(it)
+                    for entry in entries:
+                        try:
+                            if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
+                                # Use original path instead of real path
+                                file_path = entry.path.replace(os.sep, "/")
+                                await self._process_single_file(file_path, original_root, models)
+                                await asyncio.sleep(0)
+                            elif entry.is_dir(follow_symlinks=True):
+                                # For directories, continue scanning with original path
+                                await scan_recursive(entry.path, visited_paths)
+                        except Exception as e:
+                            logger.error(f"Error processing entry {entry.path}: {e}")
+            except Exception as e:
+                logger.error(f"Error scanning {path}: {e}")
+
+        await scan_recursive(root_path, set())
+        return models
+
+    async def _process_single_file(self, file_path: str, root_path: str, models: list):
+        """Process a single file and add to results list"""
+        try:
+            result = await self._process_model_file(file_path, root_path)
+            if result:
+                models.append(result)
+        except Exception as e:
+            logger.error(f"Error processing {file_path}: {e}")

     def is_initializing(self) -> bool:
         """Check if the scanner is currently initializing"""