Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git (synced 2026-03-21 21:22:11 -03:00)
feat(scanner): enhance model scanning with cache build result and progress reporting
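Scanning no longer mutates scanner state as a side effect: `_gather_model_data()` collects everything into a new `CacheBuildResult` dataclass (raw model data, a fresh `ModelHashIndex`, tag counts, and excluded models) and reports progress through an optional async callback, while `_apply_scan_result()` installs the result on the cache. A minimal sketch of how the pieces compose; `scanner` and `total_files` below are illustrative placeholders, not part of this commit:

    async def rebuild_cache(scanner, total_files: int) -> None:
        # Hypothetical progress consumer; the real callback in this commit
        # throttles updates and forwards them via ws_manager.broadcast_init_progress.
        async def report(processed: int, expected: int) -> None:
            print(f"{processed}/{expected} files processed")

        result = await scanner._gather_model_data(
            total_files=total_files,
            progress_callback=report,
        )
        # result is a CacheBuildResult(raw_data, hash_index, tags_count, excluded_models)
        await scanner._apply_scan_result(result)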
@@ -4,7 +4,8 @@ import logging
import asyncio
import time
import shutil
from typing import List, Dict, Optional, Type, Set
from dataclasses import dataclass
from typing import Awaitable, Callable, List, Dict, Optional, Type, Set

from ..utils.models import BaseModelMetadata
from ..config import config
@@ -19,6 +20,16 @@ from .websocket_manager import ws_manager

logger = logging.getLogger(__name__)


@dataclass
class CacheBuildResult:
    """Represents the outcome of scanning model files for cache building."""

    raw_data: List[Dict]
    hash_index: ModelHashIndex
    tags_count: Dict[str, int]
    excluded_models: List[str]

class ModelScanner:
    """Base service for scanning and managing model files"""

@@ -130,12 +141,15 @@ class ModelScanner:
        start_time = time.time()

        # Use thread pool to execute CPU-intensive operations with progress reporting
        await loop.run_in_executor(
        scan_result: Optional[CacheBuildResult] = await loop.run_in_executor(
            None,  # Use default thread pool
            self._initialize_cache_sync,  # Run synchronous version in thread
            total_files,  # Pass the total file count for progress reporting
            page_type  # Pass the page type for progress reporting
        )

        if scan_result:
            await self._apply_scan_result(scan_result)

        # Send final progress update
        await ws_manager.broadcast_init_progress({
@@ -204,124 +218,53 @@ class ModelScanner:

        return total_files

    def _initialize_cache_sync(self, total_files=0, page_type='loras'):
    def _initialize_cache_sync(self, total_files: int = 0, page_type: str = 'loras') -> Optional[CacheBuildResult]:
        """Synchronous version of cache initialization for thread pool execution"""

        loop = asyncio.new_event_loop()
        try:
            # Create a new event loop for this thread
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            # Create a synchronous method to bypass the async lock
            def sync_initialize_cache():
                # Track progress
                processed_files = 0
                last_progress_time = time.time()
                last_progress_percent = 0

                # We need a wrapper around scan_all_models to track progress
                # This is a local function that will run in our thread's event loop
                async def scan_with_progress():
                    nonlocal processed_files, last_progress_time, last_progress_percent

                    # For storing raw model data
                    all_models = []

                    # Process each model root
                    for root_path in self.get_model_roots():
                        if not os.path.exists(root_path):
                            continue

                        # Track visited paths to avoid symlink loops
                        visited_paths = set()

                        # Recursively process directory
                        async def scan_dir_with_progress(path):
                            nonlocal processed_files, last_progress_time, last_progress_percent

                            try:
                                real_path = os.path.realpath(path)
                                if real_path in visited_paths:
                                    return
                                visited_paths.add(real_path)

                                with os.scandir(path) as it:
                                    entries = list(it)
                                    for entry in entries:
                                        try:
                                            if entry.is_file(follow_symlinks=True):
                                                ext = os.path.splitext(entry.name)[1].lower()
                                                if ext in self.file_extensions:
                                                    file_path = entry.path.replace(os.sep, "/")
                                                    result = await self._process_model_file(file_path, root_path)
                                                    if result:
                                                        all_models.append(result)

                                                    # Update progress counter
                                                    processed_files += 1

                                                    # Update progress periodically (not every file to avoid excessive updates)
                                                    current_time = time.time()
                                                    if total_files > 0 and (current_time - last_progress_time > 0.5 or processed_files == total_files):
                                                        # Adjusted progress calculation
                                                        progress_percent = min(99, int(1 + (processed_files / total_files) * 98))
                                                        if progress_percent > last_progress_percent:
                                                            last_progress_percent = progress_percent
                                                            last_progress_time = current_time

                                                            # Send progress update through websocket
                                                            await ws_manager.broadcast_init_progress({
                                                                'stage': 'process_models',
                                                                'progress': progress_percent,
                                                                'details': f"Processing {self.model_type} files: {processed_files}/{total_files}",
                                                                'scanner_type': self.model_type,
                                                                'pageType': page_type
                                                            })

                                            elif entry.is_dir(follow_symlinks=True):
                                                await scan_dir_with_progress(entry.path)

                                        except Exception as e:
                                            logger.error(f"Error processing entry {entry.path}: {e}")
                            except Exception as e:
                                logger.error(f"Error scanning {path}: {e}")

                        # Process the root path
                        await scan_dir_with_progress(root_path)

                    return all_models

                # Run the progress-tracking scan function
                raw_data = loop.run_until_complete(scan_with_progress())

                # Update hash index and tags count
                for model_data in raw_data:
                    if 'sha256' in model_data and 'file_path' in model_data:
                        self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])

                    # Count tags
                    if 'tags' in model_data and model_data['tags']:
                        for tag in model_data['tags']:
                            self._tags_count[tag] = self._tags_count.get(tag, 0) + 1

                # Log duplicate filename warnings after building the index
                # duplicate_filenames = self._hash_index.get_duplicate_filenames()
                # if duplicate_filenames:
                # logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
                # for filename, paths in duplicate_filenames.items():
                # logger.warning(f" Duplicate filename '{filename}': {paths}")

                # Update cache
                self._cache.raw_data = raw_data
                loop.run_until_complete(self._cache.resort())

                return self._cache

            # Run our sync initialization that avoids lock conflicts
            return sync_initialize_cache()

            last_progress_time = time.time()
            last_progress_percent = 0

            async def progress_callback(processed_files: int, expected_total: int) -> None:
                nonlocal last_progress_time, last_progress_percent

                if expected_total <= 0:
                    return

                current_time = time.time()
                progress_percent = min(99, int(1 + (processed_files / expected_total) * 98))

                if progress_percent <= last_progress_percent:
                    return

                if current_time - last_progress_time <= 0.5 and processed_files != expected_total:
                    return

                last_progress_percent = progress_percent
                last_progress_time = current_time

                await ws_manager.broadcast_init_progress({
                    'stage': 'process_models',
                    'progress': progress_percent,
                    'details': f"Processing {self.model_type} files: {processed_files}/{expected_total}",
                    'scanner_type': self.model_type,
                    'pageType': page_type
                })

            return loop.run_until_complete(
                self._gather_model_data(
                    total_files=total_files,
                    progress_callback=progress_callback
                )
            )
        except Exception as e:
            logger.error(f"Error in thread-based {self.model_type} cache initialization: {e}")
            return None
        finally:
            # Clean up the event loop
            asyncio.set_event_loop(None)
            loop.close()

    async def get_cached_data(self, force_refresh: bool = False, rebuild_cache: bool = False) -> ModelCache:
@@ -353,45 +296,15 @@ class ModelScanner:
        self._is_initializing = True  # Set flag
        try:
            start_time = time.time()
            # Clear existing hash index
            self._hash_index.clear()

            # Clear existing tags count
            self._tags_count = {}

            # Determine the page type based on model type
            page_type = 'loras' if self.model_type == 'lora' else 'checkpoints'

            # Scan for new data
            raw_data = await self.scan_all_models()

            # Build hash index and tags count
            for model_data in raw_data:
                if 'sha256' in model_data and 'file_path' in model_data:
                    self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])

                # Count tags
                if 'tags' in model_data and model_data['tags']:
                    for tag in model_data['tags']:
                        self._tags_count[tag] = self._tags_count.get(tag, 0) + 1

            # Log duplicate filename warnings after building the index
            # duplicate_filenames = self._hash_index.get_duplicate_filenames()
            # if duplicate_filenames:
            # logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
            # for filename, paths in duplicate_filenames.items():
            # logger.warning(f" Duplicate filename '{filename}': {paths}")

            # Update cache
            self._cache = ModelCache(
                raw_data=raw_data,
                folders=[]
            )

            # Resort cache
            await self._cache.resort()
            scan_result = await self._gather_model_data()
            await self._apply_scan_result(scan_result)

            logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
            logger.info(
                f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
                f"found {len(scan_result.raw_data)} models"
            )
        except Exception as e:
            logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
            # Ensure cache is at least an empty structure on error
@@ -547,23 +460,9 @@ class ModelScanner:

    async def scan_all_models(self) -> List[Dict]:
        """Scan all model directories and return metadata"""
        all_models = []

        # Create scan tasks for each directory
        scan_tasks = []
        for model_root in self.get_model_roots():
            task = asyncio.create_task(self._scan_directory(model_root))
            scan_tasks.append(task)

        # Wait for all tasks to complete
        for task in scan_tasks:
            try:
                models = await task
                all_models.extend(models)
            except Exception as e:
                logger.error(f"Error scanning directory: {e}")

        return all_models
        scan_result = await self._gather_model_data()
        self._excluded_models = scan_result.excluded_models
        return scan_result.raw_data

    async def _scan_directory(self, root_path: str) -> List[Dict]:
        """Scan a single directory for model files"""
@@ -624,8 +523,18 @@ class ModelScanner:
        """Hook for subclasses: adjust metadata during scanning"""
        return metadata

    async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
    async def _process_model_file(
        self,
        file_path: str,
        root_path: str,
        *,
        hash_index: Optional[ModelHashIndex] = None,
        excluded_models: Optional[List[str]] = None
    ) -> Dict:
        """Process a single model file and return its metadata"""
        hash_index = hash_index or self._hash_index
        excluded_models = excluded_models if excluded_models is not None else self._excluded_models

        metadata, should_skip = await MetadataManager.load_metadata(file_path, self.model_class)

        if should_skip:
@@ -689,26 +598,130 @@ class ModelScanner:

        # Skip excluded models
        if model_data.get('exclude', False):
            self._excluded_models.append(model_data['file_path'])
            excluded_models.append(model_data['file_path'])
            return None


        # Check for duplicate filename before adding to hash index
        filename = os.path.splitext(os.path.basename(file_path))[0]
        existing_hash = self._hash_index.get_hash_by_filename(filename)
        if existing_hash and existing_hash != model_data.get('sha256', '').lower():
            existing_path = self._hash_index.get_path(existing_hash)
            if existing_path and existing_path != file_path:
                logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
        # filename = os.path.splitext(os.path.basename(file_path))[0]
        # existing_hash = hash_index.get_hash_by_filename(filename)
        # if existing_hash and existing_hash != model_data.get('sha256', '').lower():
        # existing_path = hash_index.get_path(existing_hash)
        # if existing_path and existing_path != file_path:
        # logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")

        rel_path = os.path.relpath(file_path, root_path)
        folder = os.path.dirname(rel_path)
        model_data['folder'] = folder.replace(os.path.sep, '/')


        return model_data

    async def _apply_scan_result(self, scan_result: CacheBuildResult) -> None:
        """Apply scan results to the cache and associated indexes."""

        if scan_result is None:
            return

        self._hash_index = scan_result.hash_index
        self._tags_count = dict(scan_result.tags_count)
        self._excluded_models = list(scan_result.excluded_models)

        if self._cache is None:
            self._cache = ModelCache(
                raw_data=list(scan_result.raw_data),
                folders=[]
            )
        else:
            self._cache.raw_data = list(scan_result.raw_data)

        await self._cache.resort()

    async def _gather_model_data(
        self,
        *,
        total_files: int = 0,
        progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None
    ) -> CacheBuildResult:
        """Collect metadata for all model files."""

        raw_data: List[Dict] = []
        hash_index = ModelHashIndex()
        tags_count: Dict[str, int] = {}
        excluded_models: List[str] = []
        processed_files = 0

        async def handle_progress() -> None:
            if progress_callback is None:
                return
            try:
                await progress_callback(processed_files, total_files)
            except Exception as exc:  # pragma: no cover - defensive logging
                logger.error(f"Error reporting progress for {self.model_type}: {exc}")

        async def scan_recursive(current_path: str, root_path: str, visited_paths: Set[str]) -> None:
            nonlocal processed_files

            try:
                real_path = os.path.realpath(current_path)
                if real_path in visited_paths:
                    return
                visited_paths.add(real_path)

                with os.scandir(current_path) as iterator:
                    entries = list(iterator)

                    for entry in entries:
                        try:
                            if entry.is_file(follow_symlinks=True):
                                ext = os.path.splitext(entry.name)[1].lower()
                                if ext not in self.file_extensions:
                                    continue

                                file_path = entry.path.replace(os.sep, "/")
                                result = await self._process_model_file(
                                    file_path,
                                    root_path,
                                    hash_index=hash_index,
                                    excluded_models=excluded_models
                                )

                                processed_files += 1

                                if result:
                                    raw_data.append(result)

                                    sha_value = result.get('sha256')
                                    model_path = result.get('file_path')
                                    if sha_value and model_path:
                                        hash_index.add_entry(sha_value.lower(), model_path)

                                    for tag in result.get('tags') or []:
                                        tags_count[tag] = tags_count.get(tag, 0) + 1

                                await handle_progress()
                                await asyncio.sleep(0)
                            elif entry.is_dir(follow_symlinks=True):
                                await scan_recursive(entry.path, root_path, visited_paths)
                        except Exception as entry_error:
                            logger.error(f"Error processing entry {entry.path}: {entry_error}")
            except Exception as scan_error:
                logger.error(f"Error scanning {current_path}: {scan_error}")

        for model_root in self.get_model_roots():
            if not os.path.exists(model_root):
                continue

            await scan_recursive(model_root, model_root, set())

        return CacheBuildResult(
            raw_data=raw_data,
            hash_index=hash_index,
            tags_count=tags_count,
            excluded_models=excluded_models
        )

    async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
        """Add a model to the cache


        Args:
            metadata_dict: The model metadata dictionary
            folder: The relative folder path for the model
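A note on the threading pattern in `_initialize_cache_sync` above: the scan still runs as coroutine code, but on a private event loop owned by the worker thread that `run_in_executor` dispatches to. A minimal sketch of that bridge, using only standard `asyncio` calls (the helper name is illustrative, not part of the commit):

    import asyncio

    def run_coro_in_worker_thread(coro_factory):
        # Each worker thread gets its own event loop; run the coroutine to
        # completion and always close the loop so no state leaks between runs.
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(coro_factory())
        finally:
            asyncio.set_event_loop(None)
            loop.close()

The `await asyncio.sleep(0)` inside `scan_recursive` is a cooperative yield, giving any other tasks scheduled on the same loop a chance to run between files.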