mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-26 07:35:44 -03:00
refactor: Update model retrieval methods in RecipeRoutes and streamline CheckpointScanner and LoraScanner initialization
This commit is contained in:
@@ -1148,7 +1148,7 @@ class RecipeRoutes:
|
|||||||
for lora_name, lora_strength in lora_matches:
|
for lora_name, lora_strength in lora_matches:
|
||||||
try:
|
try:
|
||||||
# Get lora info from scanner
|
# Get lora info from scanner
|
||||||
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name)
|
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name)
|
||||||
|
|
||||||
# Create lora entry
|
# Create lora entry
|
||||||
lora_entry = {
|
lora_entry = {
|
||||||
@@ -1167,7 +1167,7 @@ class RecipeRoutes:
|
|||||||
# Get base model from lora scanner for the available loras
|
# Get base model from lora scanner for the available loras
|
||||||
base_model_counts = {}
|
base_model_counts = {}
|
||||||
for lora in loras_data:
|
for lora in loras_data:
|
||||||
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", ""))
|
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", ""))
|
||||||
if lora_info and "base_model" in lora_info:
|
if lora_info and "base_model" in lora_info:
|
||||||
base_model = lora_info["base_model"]
|
base_model = lora_info["base_model"]
|
||||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||||
@@ -1365,7 +1365,7 @@ class RecipeRoutes:
|
|||||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||||
|
|
||||||
# Find target LoRA by name
|
# Find target LoRA by name
|
||||||
target_lora = await lora_scanner.get_lora_info_by_name(target_name)
|
target_lora = await lora_scanner.get_model_info_by_name(target_name)
|
||||||
if not target_lora:
|
if not target_lora:
|
||||||
return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)
|
return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
from typing import List
|
||||||
from typing import List, Dict
|
|
||||||
|
|
||||||
from ..utils.models import CheckpointMetadata
|
from ..utils.models import CheckpointMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
@@ -13,16 +11,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class CheckpointScanner(ModelScanner):
|
class CheckpointScanner(ModelScanner):
|
||||||
"""Service for scanning and managing checkpoint files"""
|
"""Service for scanning and managing checkpoint files"""
|
||||||
|
|
||||||
_instance = None
|
|
||||||
_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
def __new__(cls):
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not hasattr(self, '_initialized'):
|
|
||||||
# Define supported file extensions
|
# Define supported file extensions
|
||||||
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -31,84 +20,11 @@ class CheckpointScanner(ModelScanner):
|
|||||||
file_extensions=file_extensions,
|
file_extensions=file_extensions,
|
||||||
hash_index=ModelHashIndex()
|
hash_index=ModelHashIndex()
|
||||||
)
|
)
|
||||||
self._initialized = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_instance(cls):
|
|
||||||
"""Get singleton instance with async support"""
|
|
||||||
async with cls._lock:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = cls()
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def get_model_roots(self) -> List[str]:
|
def get_model_roots(self) -> List[str]:
|
||||||
"""Get checkpoint root directories"""
|
"""Get checkpoint root directories"""
|
||||||
return config.base_models_roots
|
return config.base_models_roots
|
||||||
|
|
||||||
async def scan_all_models(self) -> List[Dict]:
|
|
||||||
"""Scan all checkpoint directories and return metadata"""
|
|
||||||
all_checkpoints = []
|
|
||||||
|
|
||||||
# Create scan tasks for each directory
|
|
||||||
scan_tasks = []
|
|
||||||
for root in self.get_model_roots():
|
|
||||||
task = asyncio.create_task(self._scan_directory(root))
|
|
||||||
scan_tasks.append(task)
|
|
||||||
|
|
||||||
# Wait for all tasks to complete
|
|
||||||
for task in scan_tasks:
|
|
||||||
try:
|
|
||||||
checkpoints = await task
|
|
||||||
all_checkpoints.extend(checkpoints)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning checkpoint directory: {e}")
|
|
||||||
|
|
||||||
return all_checkpoints
|
|
||||||
|
|
||||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
|
||||||
"""Scan a directory for checkpoint files"""
|
|
||||||
checkpoints = []
|
|
||||||
original_root = root_path
|
|
||||||
|
|
||||||
async def scan_recursive(path: str, visited_paths: set):
|
|
||||||
try:
|
|
||||||
real_path = os.path.realpath(path)
|
|
||||||
if real_path in visited_paths:
|
|
||||||
logger.debug(f"Skipping already visited path: {path}")
|
|
||||||
return
|
|
||||||
visited_paths.add(real_path)
|
|
||||||
|
|
||||||
with os.scandir(path) as it:
|
|
||||||
entries = list(it)
|
|
||||||
for entry in entries:
|
|
||||||
try:
|
|
||||||
if entry.is_file(follow_symlinks=True):
|
|
||||||
# Check if file has supported extension
|
|
||||||
ext = os.path.splitext(entry.name)[1].lower()
|
|
||||||
if ext in self.file_extensions:
|
|
||||||
file_path = entry.path.replace(os.sep, "/")
|
|
||||||
await self._process_single_file(file_path, original_root, checkpoints)
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
elif entry.is_dir(follow_symlinks=True):
|
|
||||||
# For directories, continue scanning with original path
|
|
||||||
await scan_recursive(entry.path, visited_paths)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning {path}: {e}")
|
|
||||||
|
|
||||||
await scan_recursive(root_path, set())
|
|
||||||
return checkpoints
|
|
||||||
|
|
||||||
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
|
|
||||||
"""Process a single checkpoint file and add to results"""
|
|
||||||
try:
|
|
||||||
result = await self._process_model_file(file_path, root_path)
|
|
||||||
if result:
|
|
||||||
checkpoints.append(result)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing {file_path}: {e}")
|
|
||||||
|
|
||||||
# Checkpoint-specific hash index functionality
|
# Checkpoint-specific hash index functionality
|
||||||
def has_checkpoint_hash(self, sha256: str) -> bool:
|
def has_checkpoint_hash(self, sha256: str) -> bool:
|
||||||
"""Check if a checkpoint with given hash exists"""
|
"""Check if a checkpoint with given hash exists"""
|
||||||
@@ -121,17 +37,3 @@ class CheckpointScanner(ModelScanner):
|
|||||||
def get_checkpoint_hash_by_path(self, file_path: str) -> str:
|
def get_checkpoint_hash_by_path(self, file_path: str) -> str:
|
||||||
"""Get hash for a checkpoint by its file path"""
|
"""Get hash for a checkpoint by its file path"""
|
||||||
return self.get_hash_by_path(file_path)
|
return self.get_hash_by_path(file_path)
|
||||||
|
|
||||||
async def get_checkpoint_info_by_name(self, name):
|
|
||||||
"""Get checkpoint information by name"""
|
|
||||||
try:
|
|
||||||
cache = await self.get_cached_data()
|
|
||||||
|
|
||||||
for checkpoint in cache.raw_data:
|
|
||||||
if checkpoint.get("file_name") == name:
|
|
||||||
return checkpoint
|
|
||||||
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting checkpoint info by name: {e}", exc_info=True)
|
|
||||||
return None
|
|
||||||
@@ -1,7 +1,5 @@
|
|||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
from typing import List, Optional
|
||||||
from typing import List, Dict, Optional
|
|
||||||
|
|
||||||
from ..utils.models import LoraMetadata
|
from ..utils.models import LoraMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
@@ -14,17 +12,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class LoraScanner(ModelScanner):
|
class LoraScanner(ModelScanner):
|
||||||
"""Service for scanning and managing LoRA files"""
|
"""Service for scanning and managing LoRA files"""
|
||||||
|
|
||||||
_instance = None
|
|
||||||
_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
def __new__(cls):
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Ensure initialization happens only once
|
|
||||||
if not hasattr(self, '_initialized'):
|
|
||||||
# Define supported file extensions
|
# Define supported file extensions
|
||||||
file_extensions = {'.safetensors'}
|
file_extensions = {'.safetensors'}
|
||||||
|
|
||||||
@@ -35,83 +23,11 @@ class LoraScanner(ModelScanner):
|
|||||||
file_extensions=file_extensions,
|
file_extensions=file_extensions,
|
||||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||||
)
|
)
|
||||||
self._initialized = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_instance(cls):
|
|
||||||
"""Get singleton instance with async support"""
|
|
||||||
async with cls._lock:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = cls()
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def get_model_roots(self) -> List[str]:
|
def get_model_roots(self) -> List[str]:
|
||||||
"""Get lora root directories"""
|
"""Get lora root directories"""
|
||||||
return config.loras_roots
|
return config.loras_roots
|
||||||
|
|
||||||
async def scan_all_models(self) -> List[Dict]:
|
|
||||||
"""Scan all LoRA directories and return metadata"""
|
|
||||||
all_loras = []
|
|
||||||
|
|
||||||
# Create scan tasks for each directory
|
|
||||||
scan_tasks = []
|
|
||||||
for lora_root in self.get_model_roots():
|
|
||||||
task = asyncio.create_task(self._scan_directory(lora_root))
|
|
||||||
scan_tasks.append(task)
|
|
||||||
|
|
||||||
# Wait for all tasks to complete
|
|
||||||
for task in scan_tasks:
|
|
||||||
try:
|
|
||||||
loras = await task
|
|
||||||
all_loras.extend(loras)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning directory: {e}")
|
|
||||||
|
|
||||||
return all_loras
|
|
||||||
|
|
||||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
|
||||||
"""Scan a single directory for LoRA files"""
|
|
||||||
loras = []
|
|
||||||
original_root = root_path # Save original root path
|
|
||||||
|
|
||||||
async def scan_recursive(path: str, visited_paths: set):
|
|
||||||
"""Recursively scan directory, avoiding circular symlinks"""
|
|
||||||
try:
|
|
||||||
real_path = os.path.realpath(path)
|
|
||||||
if real_path in visited_paths:
|
|
||||||
logger.debug(f"Skipping already visited path: {path}")
|
|
||||||
return
|
|
||||||
visited_paths.add(real_path)
|
|
||||||
|
|
||||||
with os.scandir(path) as it:
|
|
||||||
entries = list(it)
|
|
||||||
for entry in entries:
|
|
||||||
try:
|
|
||||||
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
|
||||||
# Use original path instead of real path
|
|
||||||
file_path = entry.path.replace(os.sep, "/")
|
|
||||||
await self._process_single_file(file_path, original_root, loras)
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
elif entry.is_dir(follow_symlinks=True):
|
|
||||||
# For directories, continue scanning with original path
|
|
||||||
await scan_recursive(entry.path, visited_paths)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning {path}: {e}")
|
|
||||||
|
|
||||||
await scan_recursive(root_path, set())
|
|
||||||
return loras
|
|
||||||
|
|
||||||
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
|
||||||
"""Process a single file and add to results list"""
|
|
||||||
try:
|
|
||||||
result = await self._process_model_file(file_path, root_path)
|
|
||||||
if result:
|
|
||||||
loras.append(result)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing {file_path}: {e}")
|
|
||||||
|
|
||||||
# Lora-specific hash index functionality
|
# Lora-specific hash index functionality
|
||||||
def has_lora_hash(self, sha256: str) -> bool:
|
def has_lora_hash(self, sha256: str) -> bool:
|
||||||
"""Check if a LoRA with given hash exists"""
|
"""Check if a LoRA with given hash exists"""
|
||||||
@@ -160,19 +76,3 @@ class LoraScanner(ModelScanner):
|
|||||||
test_hash_result = self._hash_index.get_hash(test_path)
|
test_hash_result = self._hash_index.get_hash(test_path)
|
||||||
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
|
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
|
||||||
|
|
||||||
async def get_lora_info_by_name(self, name):
|
|
||||||
"""Get LoRA information by name"""
|
|
||||||
try:
|
|
||||||
# Get cached data
|
|
||||||
cache = await self.get_cached_data()
|
|
||||||
|
|
||||||
# Find the LoRA by name
|
|
||||||
for lora in cache.raw_data:
|
|
||||||
if lora.get("file_name") == name:
|
|
||||||
return lora
|
|
||||||
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting LoRA info by name: {e}", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,30 @@ CACHE_VERSION = 3
|
|||||||
class ModelScanner:
|
class ModelScanner:
|
||||||
"""Base service for scanning and managing model files"""
|
"""Base service for scanning and managing model files"""
|
||||||
|
|
||||||
_lock = asyncio.Lock()
|
_instances = {} # Dictionary to store instances by class
|
||||||
|
_locks = {} # Dictionary to store locks by class
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
"""Implement singleton pattern for each subclass"""
|
||||||
|
if cls not in cls._instances:
|
||||||
|
cls._instances[cls] = super().__new__(cls)
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_lock(cls):
|
||||||
|
"""Get or create a lock for this class"""
|
||||||
|
if cls not in cls._locks:
|
||||||
|
cls._locks[cls] = asyncio.Lock()
|
||||||
|
return cls._locks[cls]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_instance(cls):
|
||||||
|
"""Get singleton instance with async support"""
|
||||||
|
lock = cls._get_lock()
|
||||||
|
async with lock:
|
||||||
|
if cls not in cls._instances:
|
||||||
|
cls._instances[cls] = cls()
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
|
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
|
||||||
"""Initialize the scanner
|
"""Initialize the scanner
|
||||||
@@ -40,6 +63,10 @@ class ModelScanner:
|
|||||||
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
|
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
|
||||||
hash_index: Hash index instance (optional)
|
hash_index: Hash index instance (optional)
|
||||||
"""
|
"""
|
||||||
|
# Ensure initialization happens only once per instance
|
||||||
|
if hasattr(self, '_initialized'):
|
||||||
|
return
|
||||||
|
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.model_class = model_class
|
self.model_class = model_class
|
||||||
self.file_extensions = file_extensions
|
self.file_extensions = file_extensions
|
||||||
@@ -50,6 +77,7 @@ class ModelScanner:
|
|||||||
self._excluded_models = [] # List to track excluded models
|
self._excluded_models = [] # List to track excluded models
|
||||||
self._dirs_last_modified = {} # Track directory modification times
|
self._dirs_last_modified = {} # Track directory modification times
|
||||||
self._use_cache_files = False # Flag to control cache file usage, default to disabled
|
self._use_cache_files = False # Flag to control cache file usage, default to disabled
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
# Clear cache files if disabled
|
# Clear cache files if disabled
|
||||||
if not self._use_cache_files:
|
if not self._use_cache_files:
|
||||||
@@ -744,10 +772,68 @@ class ModelScanner:
|
|||||||
finally:
|
finally:
|
||||||
self._is_initializing = False # Unset flag
|
self._is_initializing = False # Unset flag
|
||||||
|
|
||||||
# These methods should be implemented in child classes
|
|
||||||
async def scan_all_models(self) -> List[Dict]:
|
async def scan_all_models(self) -> List[Dict]:
|
||||||
"""Scan all model directories and return metadata"""
|
"""Scan all model directories and return metadata"""
|
||||||
raise NotImplementedError("Subclasses must implement scan_all_models")
|
all_models = []
|
||||||
|
|
||||||
|
# Create scan tasks for each directory
|
||||||
|
scan_tasks = []
|
||||||
|
for model_root in self.get_model_roots():
|
||||||
|
task = asyncio.create_task(self._scan_directory(model_root))
|
||||||
|
scan_tasks.append(task)
|
||||||
|
|
||||||
|
# Wait for all tasks to complete
|
||||||
|
for task in scan_tasks:
|
||||||
|
try:
|
||||||
|
models = await task
|
||||||
|
all_models.extend(models)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning directory: {e}")
|
||||||
|
|
||||||
|
return all_models
|
||||||
|
|
||||||
|
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||||
|
"""Scan a single directory for model files"""
|
||||||
|
models = []
|
||||||
|
original_root = root_path # Save original root path
|
||||||
|
|
||||||
|
async def scan_recursive(path: str, visited_paths: set):
|
||||||
|
"""Recursively scan directory, avoiding circular symlinks"""
|
||||||
|
try:
|
||||||
|
real_path = os.path.realpath(path)
|
||||||
|
if real_path in visited_paths:
|
||||||
|
logger.debug(f"Skipping already visited path: {path}")
|
||||||
|
return
|
||||||
|
visited_paths.add(real_path)
|
||||||
|
|
||||||
|
with os.scandir(path) as it:
|
||||||
|
entries = list(it)
|
||||||
|
for entry in entries:
|
||||||
|
try:
|
||||||
|
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
||||||
|
# Use original path instead of real path
|
||||||
|
file_path = entry.path.replace(os.sep, "/")
|
||||||
|
await self._process_single_file(file_path, original_root, models)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
elif entry.is_dir(follow_symlinks=True):
|
||||||
|
# For directories, continue scanning with original path
|
||||||
|
await scan_recursive(entry.path, visited_paths)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning {path}: {e}")
|
||||||
|
|
||||||
|
await scan_recursive(root_path, set())
|
||||||
|
return models
|
||||||
|
|
||||||
|
async def _process_single_file(self, file_path: str, root_path: str, models: list):
|
||||||
|
"""Process a single file and add to results list"""
|
||||||
|
try:
|
||||||
|
result = await self._process_model_file(file_path, root_path)
|
||||||
|
if result:
|
||||||
|
models.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing {file_path}: {e}")
|
||||||
|
|
||||||
def is_initializing(self) -> bool:
|
def is_initializing(self) -> bool:
|
||||||
"""Check if the scanner is currently initializing"""
|
"""Check if the scanner is currently initializing"""
|
||||||
|
|||||||
Reference in New Issue
Block a user