feat(scanner): enhance model scanning with cache build result and progress reporting

Will Miao
2025-10-02 21:25:09 +08:00
parent 375b5a49f3
commit 3b1990e97a
2 changed files with 368 additions and 178 deletions
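
The diff below factors the directory walk into a reusable _gather_model_data coroutine that returns a CacheBuildResult and reports progress through an optional async callback; _apply_scan_result then copies that result into the scanner's cache, hash index, and tag counts. A minimal sketch of how a caller could drive the new helpers (the print-based callback and the names scanner and my_scanner are illustrative assumptions, not part of this commit):

import asyncio

async def report_progress(processed: int, total: int) -> None:
    # Illustrative stand-in for ws_manager.broadcast_init_progress
    if total > 0:
        print(f"Processed {processed}/{total} model files")

async def rebuild(scanner) -> None:
    # scanner is assumed to be an instance of a ModelScanner subclass
    result = await scanner._gather_model_data(
        total_files=120,  # hypothetical pre-counted file total
        progress_callback=report_progress,
    )
    await scanner._apply_scan_result(result)
    print(f"{len(result.raw_data)} models cached, {len(result.excluded_models)} excluded")

# asyncio.run(rebuild(my_scanner))

Returning the accumulators in a single CacheBuildResult means a failed or abandoned scan can be discarded without leaving the scanner's hash index or tag counts half-updated: instance state only changes when _apply_scan_result runs.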


@@ -4,7 +4,8 @@ import logging
import asyncio
import time
import shutil
from typing import List, Dict, Optional, Type, Set
from dataclasses import dataclass
from typing import Awaitable, Callable, List, Dict, Optional, Type, Set
from ..utils.models import BaseModelMetadata
from ..config import config
@@ -19,6 +20,16 @@ from .websocket_manager import ws_manager
logger = logging.getLogger(__name__)
@dataclass
class CacheBuildResult:
"""Represents the outcome of scanning model files for cache building."""
raw_data: List[Dict]
hash_index: ModelHashIndex
tags_count: Dict[str, int]
excluded_models: List[str]
class ModelScanner:
"""Base service for scanning and managing model files"""
@@ -130,12 +141,15 @@ class ModelScanner:
start_time = time.time()
# Use a thread pool to execute CPU-intensive operations, with progress reporting
await loop.run_in_executor(
scan_result: Optional[CacheBuildResult] = await loop.run_in_executor(
None, # Use default thread pool
self._initialize_cache_sync, # Run synchronous version in thread
total_files, # Pass the total file count for progress reporting
page_type # Pass the page type for progress reporting
)
if scan_result:
await self._apply_scan_result(scan_result)
# Send final progress update
await ws_manager.broadcast_init_progress({
@@ -204,124 +218,53 @@ class ModelScanner:
return total_files
def _initialize_cache_sync(self, total_files=0, page_type='loras'):
def _initialize_cache_sync(self, total_files: int = 0, page_type: str = 'loras') -> Optional[CacheBuildResult]:
"""Synchronous version of cache initialization for thread pool execution"""
loop = asyncio.new_event_loop()
try:
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Create a synchronous method to bypass the async lock
def sync_initialize_cache():
# Track progress
processed_files = 0
last_progress_time = time.time()
last_progress_percent = 0
# We need a wrapper around scan_all_models to track progress
# This is a local function that will run in our thread's event loop
async def scan_with_progress():
nonlocal processed_files, last_progress_time, last_progress_percent
# For storing raw model data
all_models = []
# Process each model root
for root_path in self.get_model_roots():
if not os.path.exists(root_path):
continue
# Track visited paths to avoid symlink loops
visited_paths = set()
# Recursively process directory
async def scan_dir_with_progress(path):
nonlocal processed_files, last_progress_time, last_progress_percent
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
result = await self._process_model_file(file_path, root_path)
if result:
all_models.append(result)
# Update progress counter
processed_files += 1
# Update progress periodically (not every file to avoid excessive updates)
current_time = time.time()
if total_files > 0 and (current_time - last_progress_time > 0.5 or processed_files == total_files):
# Adjusted progress calculation
progress_percent = min(99, int(1 + (processed_files / total_files) * 98))
if progress_percent > last_progress_percent:
last_progress_percent = progress_percent
last_progress_time = current_time
# Send progress update through websocket
await ws_manager.broadcast_init_progress({
'stage': 'process_models',
'progress': progress_percent,
'details': f"Processing {self.model_type} files: {processed_files}/{total_files}",
'scanner_type': self.model_type,
'pageType': page_type
})
elif entry.is_dir(follow_symlinks=True):
await scan_dir_with_progress(entry.path)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
# Process the root path
await scan_dir_with_progress(root_path)
return all_models
# Run the progress-tracking scan function
raw_data = loop.run_until_complete(scan_with_progress())
# Update hash index and tags count
for model_data in raw_data:
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Count tags
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Log duplicate filename warnings after building the index
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
# if duplicate_filenames:
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
# for filename, paths in duplicate_filenames.items():
# logger.warning(f" Duplicate filename '{filename}': {paths}")
# Update cache
self._cache.raw_data = raw_data
loop.run_until_complete(self._cache.resort())
return self._cache
# Run our sync initialization that avoids lock conflicts
return sync_initialize_cache()
last_progress_time = time.time()
last_progress_percent = 0
async def progress_callback(processed_files: int, expected_total: int) -> None:
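# Throttle broadcasts: only report when the integer percentage advances, and at most every 0.5s except for the final file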
nonlocal last_progress_time, last_progress_percent
if expected_total <= 0:
return
current_time = time.time()
progress_percent = min(99, int(1 + (processed_files / expected_total) * 98))
if progress_percent <= last_progress_percent:
return
if current_time - last_progress_time <= 0.5 and processed_files != expected_total:
return
last_progress_percent = progress_percent
last_progress_time = current_time
await ws_manager.broadcast_init_progress({
'stage': 'process_models',
'progress': progress_percent,
'details': f"Processing {self.model_type} files: {processed_files}/{expected_total}",
'scanner_type': self.model_type,
'pageType': page_type
})
return loop.run_until_complete(
self._gather_model_data(
total_files=total_files,
progress_callback=progress_callback
)
)
except Exception as e:
logger.error(f"Error in thread-based {self.model_type} cache initialization: {e}")
return None
finally:
# Clean up the event loop
asyncio.set_event_loop(None)
loop.close()
async def get_cached_data(self, force_refresh: bool = False, rebuild_cache: bool = False) -> ModelCache:
@@ -353,45 +296,15 @@ class ModelScanner:
self._is_initializing = True # Set flag
try:
start_time = time.time()
# Clear existing hash index
self._hash_index.clear()
# Clear existing tags count
self._tags_count = {}
# Determine the page type based on model type
page_type = 'loras' if self.model_type == 'lora' else 'checkpoints'
# Scan for new data
raw_data = await self.scan_all_models()
# Build hash index and tags count
for model_data in raw_data:
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Count tags
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Log duplicate filename warnings after building the index
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
# if duplicate_filenames:
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
# for filename, paths in duplicate_filenames.items():
# logger.warning(f" Duplicate filename '{filename}': {paths}")
# Update cache
self._cache = ModelCache(
raw_data=raw_data,
folders=[]
)
# Resort cache
await self._cache.resort()
scan_result = await self._gather_model_data()
await self._apply_scan_result(scan_result)
logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
logger.info(
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
f"found {len(scan_result.raw_data)} models"
)
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
# Ensure cache is at least an empty structure on error
@@ -547,23 +460,9 @@ class ModelScanner:
async def scan_all_models(self) -> List[Dict]:
"""Scan all model directories and return metadata"""
all_models = []
# Create scan tasks for each directory
scan_tasks = []
for model_root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(model_root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
models = await task
all_models.extend(models)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_models
scan_result = await self._gather_model_data()
self._excluded_models = scan_result.excluded_models
return scan_result.raw_data
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for model files"""
@@ -624,8 +523,18 @@ class ModelScanner:
"""Hook for subclasses: adjust metadata during scanning"""
return metadata
async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
async def _process_model_file(
self,
file_path: str,
root_path: str,
*,
hash_index: Optional[ModelHashIndex] = None,
excluded_models: Optional[List[str]] = None
) -> Dict:
"""Process a single model file and return its metadata"""
hash_index = hash_index or self._hash_index
excluded_models = excluded_models if excluded_models is not None else self._excluded_models
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.model_class)
if should_skip:
@@ -689,26 +598,130 @@ class ModelScanner:
# Skip excluded models
if model_data.get('exclude', False):
self._excluded_models.append(model_data['file_path'])
excluded_models.append(model_data['file_path'])
return None
# Check for duplicate filename before adding to hash index
filename = os.path.splitext(os.path.basename(file_path))[0]
existing_hash = self._hash_index.get_hash_by_filename(filename)
if existing_hash and existing_hash != model_data.get('sha256', '').lower():
existing_path = self._hash_index.get_path(existing_hash)
if existing_path and existing_path != file_path:
logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
# filename = os.path.splitext(os.path.basename(file_path))[0]
# existing_hash = hash_index.get_hash_by_filename(filename)
# if existing_hash and existing_hash != model_data.get('sha256', '').lower():
# existing_path = hash_index.get_path(existing_hash)
# if existing_path and existing_path != file_path:
# logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
model_data['folder'] = folder.replace(os.path.sep, '/')
return model_data
async def _apply_scan_result(self, scan_result: Optional[CacheBuildResult]) -> None:
"""Apply scan results to the cache and associated indexes."""
if scan_result is None:
return
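# Swap in the freshly built index, tag counts, and exclusion list before refreshing the cached raw data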
self._hash_index = scan_result.hash_index
self._tags_count = dict(scan_result.tags_count)
self._excluded_models = list(scan_result.excluded_models)
if self._cache is None:
self._cache = ModelCache(
raw_data=list(scan_result.raw_data),
folders=[]
)
else:
self._cache.raw_data = list(scan_result.raw_data)
await self._cache.resort()
async def _gather_model_data(
self,
*,
total_files: int = 0,
progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None
) -> CacheBuildResult:
"""Collect metadata for all model files."""
raw_data: List[Dict] = []
hash_index = ModelHashIndex()
tags_count: Dict[str, int] = {}
excluded_models: List[str] = []
processed_files = 0
async def handle_progress() -> None:
if progress_callback is None:
return
try:
await progress_callback(processed_files, total_files)
except Exception as exc: # pragma: no cover - defensive logging
logger.error(f"Error reporting progress for {self.model_type}: {exc}")
async def scan_recursive(current_path: str, root_path: str, visited_paths: Set[str]) -> None:
nonlocal processed_files
try:
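# Resolve the real path and skip directories already visited to avoid symlink loops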
real_path = os.path.realpath(current_path)
if real_path in visited_paths:
return
visited_paths.add(real_path)
with os.scandir(current_path) as iterator:
entries = list(iterator)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
ext = os.path.splitext(entry.name)[1].lower()
if ext not in self.file_extensions:
continue
file_path = entry.path.replace(os.sep, "/")
result = await self._process_model_file(
file_path,
root_path,
hash_index=hash_index,
excluded_models=excluded_models
)
processed_files += 1
if result:
raw_data.append(result)
sha_value = result.get('sha256')
model_path = result.get('file_path')
if sha_value and model_path:
hash_index.add_entry(sha_value.lower(), model_path)
for tag in result.get('tags') or []:
tags_count[tag] = tags_count.get(tag, 0) + 1
await handle_progress()
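# Yield to the event loop so pending tasks (such as websocket sends) are not starved during long scans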
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
await scan_recursive(entry.path, root_path, visited_paths)
except Exception as entry_error:
logger.error(f"Error processing entry {entry.path}: {entry_error}")
except Exception as scan_error:
logger.error(f"Error scanning {current_path}: {scan_error}")
for model_root in self.get_model_roots():
if not os.path.exists(model_root):
continue
await scan_recursive(model_root, model_root, set())
return CacheBuildResult(
raw_data=raw_data,
hash_index=hash_index,
tags_count=tags_count,
excluded_models=excluded_models
)
async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
"""Add a model to the cache
Args:
metadata_dict: The model metadata dictionary
folder: The relative folder path for the model