diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 51aa4507..ccba3b2c 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -4,7 +4,8 @@ import logging import asyncio import time import shutil -from typing import List, Dict, Optional, Type, Set +from dataclasses import dataclass +from typing import Awaitable, Callable, List, Dict, Optional, Type, Set from ..utils.models import BaseModelMetadata from ..config import config @@ -19,6 +20,16 @@ from .websocket_manager import ws_manager logger = logging.getLogger(__name__) + +@dataclass +class CacheBuildResult: + """Represents the outcome of scanning model files for cache building.""" + + raw_data: List[Dict] + hash_index: ModelHashIndex + tags_count: Dict[str, int] + excluded_models: List[str] + class ModelScanner: """Base service for scanning and managing model files""" @@ -130,12 +141,15 @@ class ModelScanner: start_time = time.time() # Use thread pool to execute CPU-intensive operations with progress reporting - await loop.run_in_executor( + scan_result: Optional[CacheBuildResult] = await loop.run_in_executor( None, # Use default thread pool self._initialize_cache_sync, # Run synchronous version in thread total_files, # Pass the total file count for progress reporting page_type # Pass the page type for progress reporting ) + + if scan_result: + await self._apply_scan_result(scan_result) # Send final progress update await ws_manager.broadcast_init_progress({ @@ -204,124 +218,53 @@ class ModelScanner: return total_files - def _initialize_cache_sync(self, total_files=0, page_type='loras'): + def _initialize_cache_sync(self, total_files: int = 0, page_type: str = 'loras') -> Optional[CacheBuildResult]: """Synchronous version of cache initialization for thread pool execution""" + + loop = asyncio.new_event_loop() try: - # Create a new event loop for this thread - loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - - # Create a synchronous method to bypass the async lock - def sync_initialize_cache(): - # Track progress - processed_files = 0 - last_progress_time = time.time() - last_progress_percent = 0 - - # We need a wrapper around scan_all_models to track progress - # This is a local function that will run in our thread's event loop - async def scan_with_progress(): - nonlocal processed_files, last_progress_time, last_progress_percent - - # For storing raw model data - all_models = [] - - # Process each model root - for root_path in self.get_model_roots(): - if not os.path.exists(root_path): - continue - - # Track visited paths to avoid symlink loops - visited_paths = set() - - # Recursively process directory - async def scan_dir_with_progress(path): - nonlocal processed_files, last_progress_time, last_progress_percent - - try: - real_path = os.path.realpath(path) - if real_path in visited_paths: - return - visited_paths.add(real_path) - - with os.scandir(path) as it: - entries = list(it) - for entry in entries: - try: - if entry.is_file(follow_symlinks=True): - ext = os.path.splitext(entry.name)[1].lower() - if ext in self.file_extensions: - file_path = entry.path.replace(os.sep, "/") - result = await self._process_model_file(file_path, root_path) - if result: - all_models.append(result) - - # Update progress counter - processed_files += 1 - - # Update progress periodically (not every file to avoid excessive updates) - current_time = time.time() - if total_files > 0 and (current_time - last_progress_time > 0.5 or processed_files == total_files): - # Adjusted progress calculation 
- progress_percent = min(99, int(1 + (processed_files / total_files) * 98)) - if progress_percent > last_progress_percent: - last_progress_percent = progress_percent - last_progress_time = current_time - - # Send progress update through websocket - await ws_manager.broadcast_init_progress({ - 'stage': 'process_models', - 'progress': progress_percent, - 'details': f"Processing {self.model_type} files: {processed_files}/{total_files}", - 'scanner_type': self.model_type, - 'pageType': page_type - }) - - elif entry.is_dir(follow_symlinks=True): - await scan_dir_with_progress(entry.path) - - except Exception as e: - logger.error(f"Error processing entry {entry.path}: {e}") - except Exception as e: - logger.error(f"Error scanning {path}: {e}") - - # Process the root path - await scan_dir_with_progress(root_path) - - return all_models - - # Run the progress-tracking scan function - raw_data = loop.run_until_complete(scan_with_progress()) - - # Update hash index and tags count - for model_data in raw_data: - if 'sha256' in model_data and 'file_path' in model_data: - self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path']) - - # Count tags - if 'tags' in model_data and model_data['tags']: - for tag in model_data['tags']: - self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 - - # Log duplicate filename warnings after building the index - # duplicate_filenames = self._hash_index.get_duplicate_filenames() - # if duplicate_filenames: - # logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:") - # for filename, paths in duplicate_filenames.items(): - # logger.warning(f" Duplicate filename '{filename}': {paths}") - - # Update cache - self._cache.raw_data = raw_data - loop.run_until_complete(self._cache.resort()) - - return self._cache - - # Run our sync initialization that avoids lock conflicts - return sync_initialize_cache() + + last_progress_time = time.time() + last_progress_percent = 0 + + async def progress_callback(processed_files: int, expected_total: int) -> None: + nonlocal last_progress_time, last_progress_percent + + if expected_total <= 0: + return + + current_time = time.time() + progress_percent = min(99, int(1 + (processed_files / expected_total) * 98)) + + if progress_percent <= last_progress_percent: + return + + if current_time - last_progress_time <= 0.5 and processed_files != expected_total: + return + + last_progress_percent = progress_percent + last_progress_time = current_time + + await ws_manager.broadcast_init_progress({ + 'stage': 'process_models', + 'progress': progress_percent, + 'details': f"Processing {self.model_type} files: {processed_files}/{expected_total}", + 'scanner_type': self.model_type, + 'pageType': page_type + }) + + return loop.run_until_complete( + self._gather_model_data( + total_files=total_files, + progress_callback=progress_callback + ) + ) except Exception as e: logger.error(f"Error in thread-based {self.model_type} cache initialization: {e}") + return None finally: - # Clean up the event loop + asyncio.set_event_loop(None) loop.close() async def get_cached_data(self, force_refresh: bool = False, rebuild_cache: bool = False) -> ModelCache: @@ -353,45 +296,15 @@ class ModelScanner: self._is_initializing = True # Set flag try: start_time = time.time() - # Clear existing hash index - self._hash_index.clear() - - # Clear existing tags count - self._tags_count = {} - # Determine the page type based on model type - page_type = 'loras' if self.model_type == 'lora' else 
'checkpoints' - # Scan for new data - raw_data = await self.scan_all_models() - - # Build hash index and tags count - for model_data in raw_data: - if 'sha256' in model_data and 'file_path' in model_data: - self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path']) - - # Count tags - if 'tags' in model_data and model_data['tags']: - for tag in model_data['tags']: - self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 - - # Log duplicate filename warnings after building the index - # duplicate_filenames = self._hash_index.get_duplicate_filenames() - # if duplicate_filenames: - # logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:") - # for filename, paths in duplicate_filenames.items(): - # logger.warning(f" Duplicate filename '{filename}': {paths}") - - # Update cache - self._cache = ModelCache( - raw_data=raw_data, - folders=[] - ) - - # Resort cache - await self._cache.resort() + scan_result = await self._gather_model_data() + await self._apply_scan_result(scan_result) - logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models") + logger.info( + f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, " + f"found {len(scan_result.raw_data)} models" + ) except Exception as e: logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}") # Ensure cache is at least an empty structure on error @@ -547,23 +460,9 @@ class ModelScanner: async def scan_all_models(self) -> List[Dict]: """Scan all model directories and return metadata""" - all_models = [] - - # Create scan tasks for each directory - scan_tasks = [] - for model_root in self.get_model_roots(): - task = asyncio.create_task(self._scan_directory(model_root)) - scan_tasks.append(task) - - # Wait for all tasks to complete - for task in scan_tasks: - try: - models = await task - all_models.extend(models) - except Exception as e: - logger.error(f"Error scanning directory: {e}") - - return all_models + scan_result = await self._gather_model_data() + self._excluded_models = scan_result.excluded_models + return scan_result.raw_data async def _scan_directory(self, root_path: str) -> List[Dict]: """Scan a single directory for model files""" @@ -624,8 +523,18 @@ class ModelScanner: """Hook for subclasses: adjust metadata during scanning""" return metadata - async def _process_model_file(self, file_path: str, root_path: str) -> Dict: + async def _process_model_file( + self, + file_path: str, + root_path: str, + *, + hash_index: Optional[ModelHashIndex] = None, + excluded_models: Optional[List[str]] = None + ) -> Dict: """Process a single model file and return its metadata""" + hash_index = hash_index or self._hash_index + excluded_models = excluded_models if excluded_models is not None else self._excluded_models + metadata, should_skip = await MetadataManager.load_metadata(file_path, self.model_class) if should_skip: @@ -689,26 +598,130 @@ class ModelScanner: # Skip excluded models if model_data.get('exclude', False): - self._excluded_models.append(model_data['file_path']) + excluded_models.append(model_data['file_path']) return None - + # Check for duplicate filename before adding to hash index - filename = os.path.splitext(os.path.basename(file_path))[0] - existing_hash = self._hash_index.get_hash_by_filename(filename) - if existing_hash and existing_hash != 
model_data.get('sha256', '').lower(): - existing_path = self._hash_index.get_path(existing_hash) - if existing_path and existing_path != file_path: - logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'") + # filename = os.path.splitext(os.path.basename(file_path))[0] + # existing_hash = hash_index.get_hash_by_filename(filename) + # if existing_hash and existing_hash != model_data.get('sha256', '').lower(): + # existing_path = hash_index.get_path(existing_hash) + # if existing_path and existing_path != file_path: + # logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'") rel_path = os.path.relpath(file_path, root_path) folder = os.path.dirname(rel_path) model_data['folder'] = folder.replace(os.path.sep, '/') - + return model_data + async def _apply_scan_result(self, scan_result: CacheBuildResult) -> None: + """Apply scan results to the cache and associated indexes.""" + + if scan_result is None: + return + + self._hash_index = scan_result.hash_index + self._tags_count = dict(scan_result.tags_count) + self._excluded_models = list(scan_result.excluded_models) + + if self._cache is None: + self._cache = ModelCache( + raw_data=list(scan_result.raw_data), + folders=[] + ) + else: + self._cache.raw_data = list(scan_result.raw_data) + + await self._cache.resort() + + async def _gather_model_data( + self, + *, + total_files: int = 0, + progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None + ) -> CacheBuildResult: + """Collect metadata for all model files.""" + + raw_data: List[Dict] = [] + hash_index = ModelHashIndex() + tags_count: Dict[str, int] = {} + excluded_models: List[str] = [] + processed_files = 0 + + async def handle_progress() -> None: + if progress_callback is None: + return + try: + await progress_callback(processed_files, total_files) + except Exception as exc: # pragma: no cover - defensive logging + logger.error(f"Error reporting progress for {self.model_type}: {exc}") + + async def scan_recursive(current_path: str, root_path: str, visited_paths: Set[str]) -> None: + nonlocal processed_files + + try: + real_path = os.path.realpath(current_path) + if real_path in visited_paths: + return + visited_paths.add(real_path) + + with os.scandir(current_path) as iterator: + entries = list(iterator) + + for entry in entries: + try: + if entry.is_file(follow_symlinks=True): + ext = os.path.splitext(entry.name)[1].lower() + if ext not in self.file_extensions: + continue + + file_path = entry.path.replace(os.sep, "/") + result = await self._process_model_file( + file_path, + root_path, + hash_index=hash_index, + excluded_models=excluded_models + ) + + processed_files += 1 + + if result: + raw_data.append(result) + + sha_value = result.get('sha256') + model_path = result.get('file_path') + if sha_value and model_path: + hash_index.add_entry(sha_value.lower(), model_path) + + for tag in result.get('tags') or []: + tags_count[tag] = tags_count.get(tag, 0) + 1 + + await handle_progress() + await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + await scan_recursive(entry.path, root_path, visited_paths) + except Exception as entry_error: + logger.error(f"Error processing entry {entry.path}: {entry_error}") + except Exception as scan_error: + logger.error(f"Error scanning {current_path}: {scan_error}") + + for model_root in self.get_model_roots(): + if not os.path.exists(model_root): + continue + + await scan_recursive(model_root, model_root, set()) + + return 
CacheBuildResult( + raw_data=raw_data, + hash_index=hash_index, + tags_count=tags_count, + excluded_models=excluded_models + ) + async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool: """Add a model to the cache - + Args: metadata_dict: The model metadata dictionary folder: The relative folder path for the model diff --git a/tests/services/test_model_scanner.py b/tests/services/test_model_scanner.py new file mode 100644 index 00000000..7eff5e78 --- /dev/null +++ b/tests/services/test_model_scanner.py @@ -0,0 +1,177 @@ +import asyncio +import os +from pathlib import Path +from typing import List + +import pytest + +from py.services import model_scanner +from py.services.model_cache import ModelCache +from py.services.model_hash_index import ModelHashIndex +from py.services.model_scanner import CacheBuildResult, ModelScanner +from py.utils.models import BaseModelMetadata + + +class RecordingWebSocketManager: + def __init__(self) -> None: + self.payloads: List[dict] = [] + + async def broadcast_init_progress(self, payload: dict) -> None: + self.payloads.append(payload) + + +def _normalize_path(path: Path) -> str: + return str(path).replace(os.sep, "/") + + +class DummyScanner(ModelScanner): + def __init__(self, root: Path): + self._root = str(root) + super().__init__( + model_type="dummy", + model_class=BaseModelMetadata, + file_extensions={".txt"}, + hash_index=ModelHashIndex(), + ) + + def get_model_roots(self) -> List[str]: + return [self._root] + + async def _process_model_file( + self, + file_path: str, + root_path: str, + *, + hash_index: ModelHashIndex | None = None, + excluded_models: List[str] | None = None, + ) -> dict: + hash_index = hash_index or self._hash_index + excluded_models = excluded_models if excluded_models is not None else self._excluded_models + + rel_path = os.path.relpath(file_path, root_path) + folder = os.path.dirname(rel_path).replace(os.path.sep, "/") + name = os.path.splitext(os.path.basename(file_path))[0] + + if name.startswith("skip"): + excluded_models.append(file_path.replace(os.sep, "/")) + return None + + tags = ["alpha"] if "one" in name else ["beta"] + + return { + "file_path": file_path.replace(os.sep, "/"), + "folder": folder, + "sha256": f"hash-{name}", + "tags": tags, + "model_name": name, + "size": 1, + "modified": 1.0, + } + + +@pytest.fixture(autouse=True) +def reset_model_scanner_singletons(): + ModelScanner._instances.clear() + ModelScanner._locks.clear() + yield + ModelScanner._instances.clear() + ModelScanner._locks.clear() + + +@pytest.fixture(autouse=True) +def stub_register_service(monkeypatch): + async def noop(*_args, **_kwargs): + return None + + monkeypatch.setattr(model_scanner.ServiceRegistry, "register_service", noop) + + +def _create_files(root: Path) -> tuple[Path, Path, Path]: + first = root / "one.txt" + first.write_text("one", encoding="utf-8") + + nested_dir = root / "nested" + nested_dir.mkdir() + second = nested_dir / "two.txt" + second.write_text("two", encoding="utf-8") + + skipped = root / "skip-file.txt" + skipped.write_text("skip", encoding="utf-8") + + return first, second, skipped + + +@pytest.mark.asyncio +async def test_initialize_cache_populates_cache(tmp_path: Path): + _create_files(tmp_path) + scanner = DummyScanner(tmp_path) + + await scanner._initialize_cache() + cache = await scanner.get_cached_data() + + cached_paths = {item["file_path"] for item in cache.raw_data} + assert cached_paths == { + _normalize_path(tmp_path / "one.txt"), + _normalize_path(tmp_path / "nested" / 
"two.txt"), + } + + assert scanner._hash_index.get_path("hash-one") == _normalize_path(tmp_path / "one.txt") + assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt") + assert scanner._tags_count == {"alpha": 1, "beta": 1} + assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")] + assert sorted(cache.folders) == ["", "nested"] + + +@pytest.mark.asyncio +async def test_initialize_cache_sync_returns_result_without_mutating_state(tmp_path: Path, monkeypatch): + _create_files(tmp_path) + scanner = DummyScanner(tmp_path) + + ws_stub = RecordingWebSocketManager() + monkeypatch.setattr(model_scanner, "ws_manager", ws_stub) + + scanner._cache = ModelCache(raw_data=[{"file_path": "sentinel", "folder": ""}], folders=["existing"]) + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, scanner._initialize_cache_sync, 2, "dummy") + + assert isinstance(result, CacheBuildResult) + assert {item["file_path"] for item in result.raw_data} == { + _normalize_path(tmp_path / "one.txt"), + _normalize_path(tmp_path / "nested" / "two.txt"), + } + assert result.tags_count == {"alpha": 1, "beta": 1} + assert ws_stub.payloads, "expected progress updates from websocket manager" + + assert scanner._cache.raw_data == [{"file_path": "sentinel", "folder": ""}] + assert scanner._hash_index.get_path("hash-one") is None + + +@pytest.mark.asyncio +async def test_initialize_in_background_applies_scan_result(tmp_path: Path, monkeypatch): + _create_files(tmp_path) + scanner = DummyScanner(tmp_path) + + ws_stub = RecordingWebSocketManager() + monkeypatch.setattr(model_scanner, "ws_manager", ws_stub) + + original_sleep = asyncio.sleep + + async def fast_sleep(duration: float) -> None: + await original_sleep(0) + + monkeypatch.setattr(model_scanner.asyncio, "sleep", fast_sleep) + + await scanner.initialize_in_background() + + cache = await scanner.get_cached_data() + cached_paths = {item["file_path"] for item in cache.raw_data} + + assert cached_paths == { + _normalize_path(tmp_path / "one.txt"), + _normalize_path(tmp_path / "nested" / "two.txt"), + } + assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt") + assert scanner._tags_count == {"alpha": 1, "beta": 1} + assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")] + assert ws_stub.payloads[-1]["progress"] == 100