mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(scanner): enhance model scanning with cache build result and progress reporting
This commit is contained in:
@@ -4,7 +4,8 @@ import logging
|
||||
import asyncio
|
||||
import time
|
||||
import shutil
|
||||
from typing import List, Dict, Optional, Type, Set
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, List, Dict, Optional, Type, Set
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..config import config
|
||||
@@ -19,6 +20,16 @@ from .websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheBuildResult:
|
||||
"""Represents the outcome of scanning model files for cache building."""
|
||||
|
||||
raw_data: List[Dict]
|
||||
hash_index: ModelHashIndex
|
||||
tags_count: Dict[str, int]
|
||||
excluded_models: List[str]
|
||||
|
||||
class ModelScanner:
|
||||
"""Base service for scanning and managing model files"""
|
||||
|
||||
@@ -130,12 +141,15 @@ class ModelScanner:
|
||||
start_time = time.time()
|
||||
|
||||
# Use thread pool to execute CPU-intensive operations with progress reporting
|
||||
await loop.run_in_executor(
|
||||
scan_result: Optional[CacheBuildResult] = await loop.run_in_executor(
|
||||
None, # Use default thread pool
|
||||
self._initialize_cache_sync, # Run synchronous version in thread
|
||||
total_files, # Pass the total file count for progress reporting
|
||||
page_type # Pass the page type for progress reporting
|
||||
)
|
||||
|
||||
if scan_result:
|
||||
await self._apply_scan_result(scan_result)
|
||||
|
||||
# Send final progress update
|
||||
await ws_manager.broadcast_init_progress({
|
||||
@@ -204,124 +218,53 @@ class ModelScanner:
|
||||
|
||||
return total_files
|
||||
|
||||
def _initialize_cache_sync(self, total_files=0, page_type='loras'):
|
||||
def _initialize_cache_sync(self, total_files: int = 0, page_type: str = 'loras') -> Optional[CacheBuildResult]:
|
||||
"""Synchronous version of cache initialization for thread pool execution"""
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Create a new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Create a synchronous method to bypass the async lock
|
||||
def sync_initialize_cache():
|
||||
# Track progress
|
||||
processed_files = 0
|
||||
last_progress_time = time.time()
|
||||
last_progress_percent = 0
|
||||
|
||||
# We need a wrapper around scan_all_models to track progress
|
||||
# This is a local function that will run in our thread's event loop
|
||||
async def scan_with_progress():
|
||||
nonlocal processed_files, last_progress_time, last_progress_percent
|
||||
|
||||
# For storing raw model data
|
||||
all_models = []
|
||||
|
||||
# Process each model root
|
||||
for root_path in self.get_model_roots():
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
# Track visited paths to avoid symlink loops
|
||||
visited_paths = set()
|
||||
|
||||
# Recursively process directory
|
||||
async def scan_dir_with_progress(path):
|
||||
nonlocal processed_files, last_progress_time, last_progress_percent
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True):
|
||||
ext = os.path.splitext(entry.name)[1].lower()
|
||||
if ext in self.file_extensions:
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
all_models.append(result)
|
||||
|
||||
# Update progress counter
|
||||
processed_files += 1
|
||||
|
||||
# Update progress periodically (not every file to avoid excessive updates)
|
||||
current_time = time.time()
|
||||
if total_files > 0 and (current_time - last_progress_time > 0.5 or processed_files == total_files):
|
||||
# Adjusted progress calculation
|
||||
progress_percent = min(99, int(1 + (processed_files / total_files) * 98))
|
||||
if progress_percent > last_progress_percent:
|
||||
last_progress_percent = progress_percent
|
||||
last_progress_time = current_time
|
||||
|
||||
# Send progress update through websocket
|
||||
await ws_manager.broadcast_init_progress({
|
||||
'stage': 'process_models',
|
||||
'progress': progress_percent,
|
||||
'details': f"Processing {self.model_type} files: {processed_files}/{total_files}",
|
||||
'scanner_type': self.model_type,
|
||||
'pageType': page_type
|
||||
})
|
||||
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
await scan_dir_with_progress(entry.path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
# Process the root path
|
||||
await scan_dir_with_progress(root_path)
|
||||
|
||||
return all_models
|
||||
|
||||
# Run the progress-tracking scan function
|
||||
raw_data = loop.run_until_complete(scan_with_progress())
|
||||
|
||||
# Update hash index and tags count
|
||||
for model_data in raw_data:
|
||||
if 'sha256' in model_data and 'file_path' in model_data:
|
||||
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
||||
|
||||
# Count tags
|
||||
if 'tags' in model_data and model_data['tags']:
|
||||
for tag in model_data['tags']:
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
# Log duplicate filename warnings after building the index
|
||||
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
|
||||
# if duplicate_filenames:
|
||||
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
|
||||
# for filename, paths in duplicate_filenames.items():
|
||||
# logger.warning(f" Duplicate filename '{filename}': {paths}")
|
||||
|
||||
# Update cache
|
||||
self._cache.raw_data = raw_data
|
||||
loop.run_until_complete(self._cache.resort())
|
||||
|
||||
return self._cache
|
||||
|
||||
# Run our sync initialization that avoids lock conflicts
|
||||
return sync_initialize_cache()
|
||||
|
||||
last_progress_time = time.time()
|
||||
last_progress_percent = 0
|
||||
|
||||
async def progress_callback(processed_files: int, expected_total: int) -> None:
|
||||
nonlocal last_progress_time, last_progress_percent
|
||||
|
||||
if expected_total <= 0:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
progress_percent = min(99, int(1 + (processed_files / expected_total) * 98))
|
||||
|
||||
if progress_percent <= last_progress_percent:
|
||||
return
|
||||
|
||||
if current_time - last_progress_time <= 0.5 and processed_files != expected_total:
|
||||
return
|
||||
|
||||
last_progress_percent = progress_percent
|
||||
last_progress_time = current_time
|
||||
|
||||
await ws_manager.broadcast_init_progress({
|
||||
'stage': 'process_models',
|
||||
'progress': progress_percent,
|
||||
'details': f"Processing {self.model_type} files: {processed_files}/{expected_total}",
|
||||
'scanner_type': self.model_type,
|
||||
'pageType': page_type
|
||||
})
|
||||
|
||||
return loop.run_until_complete(
|
||||
self._gather_model_data(
|
||||
total_files=total_files,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in thread-based {self.model_type} cache initialization: {e}")
|
||||
return None
|
||||
finally:
|
||||
# Clean up the event loop
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
async def get_cached_data(self, force_refresh: bool = False, rebuild_cache: bool = False) -> ModelCache:
|
||||
@@ -353,45 +296,15 @@ class ModelScanner:
|
||||
self._is_initializing = True # Set flag
|
||||
try:
|
||||
start_time = time.time()
|
||||
# Clear existing hash index
|
||||
self._hash_index.clear()
|
||||
|
||||
# Clear existing tags count
|
||||
self._tags_count = {}
|
||||
|
||||
# Determine the page type based on model type
|
||||
page_type = 'loras' if self.model_type == 'lora' else 'checkpoints'
|
||||
|
||||
# Scan for new data
|
||||
raw_data = await self.scan_all_models()
|
||||
|
||||
# Build hash index and tags count
|
||||
for model_data in raw_data:
|
||||
if 'sha256' in model_data and 'file_path' in model_data:
|
||||
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
||||
|
||||
# Count tags
|
||||
if 'tags' in model_data and model_data['tags']:
|
||||
for tag in model_data['tags']:
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
# Log duplicate filename warnings after building the index
|
||||
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
|
||||
# if duplicate_filenames:
|
||||
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
|
||||
# for filename, paths in duplicate_filenames.items():
|
||||
# logger.warning(f" Duplicate filename '{filename}': {paths}")
|
||||
|
||||
# Update cache
|
||||
self._cache = ModelCache(
|
||||
raw_data=raw_data,
|
||||
folders=[]
|
||||
)
|
||||
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
scan_result = await self._gather_model_data()
|
||||
await self._apply_scan_result(scan_result)
|
||||
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
|
||||
logger.info(
|
||||
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
||||
f"found {len(scan_result.raw_data)} models"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
|
||||
# Ensure cache is at least an empty structure on error
|
||||
@@ -547,23 +460,9 @@ class ModelScanner:
|
||||
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all model directories and return metadata"""
|
||||
all_models = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for model_root in self.get_model_roots():
|
||||
task = asyncio.create_task(self._scan_directory(model_root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
models = await task
|
||||
all_models.extend(models)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory: {e}")
|
||||
|
||||
return all_models
|
||||
scan_result = await self._gather_model_data()
|
||||
self._excluded_models = scan_result.excluded_models
|
||||
return scan_result.raw_data
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a single directory for model files"""
|
||||
@@ -624,8 +523,18 @@ class ModelScanner:
|
||||
"""Hook for subclasses: adjust metadata during scanning"""
|
||||
return metadata
|
||||
|
||||
async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
|
||||
async def _process_model_file(
|
||||
self,
|
||||
file_path: str,
|
||||
root_path: str,
|
||||
*,
|
||||
hash_index: Optional[ModelHashIndex] = None,
|
||||
excluded_models: Optional[List[str]] = None
|
||||
) -> Dict:
|
||||
"""Process a single model file and return its metadata"""
|
||||
hash_index = hash_index or self._hash_index
|
||||
excluded_models = excluded_models if excluded_models is not None else self._excluded_models
|
||||
|
||||
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
|
||||
if should_skip:
|
||||
@@ -689,26 +598,130 @@ class ModelScanner:
|
||||
|
||||
# Skip excluded models
|
||||
if model_data.get('exclude', False):
|
||||
self._excluded_models.append(model_data['file_path'])
|
||||
excluded_models.append(model_data['file_path'])
|
||||
return None
|
||||
|
||||
|
||||
# Check for duplicate filename before adding to hash index
|
||||
filename = os.path.splitext(os.path.basename(file_path))[0]
|
||||
existing_hash = self._hash_index.get_hash_by_filename(filename)
|
||||
if existing_hash and existing_hash != model_data.get('sha256', '').lower():
|
||||
existing_path = self._hash_index.get_path(existing_hash)
|
||||
if existing_path and existing_path != file_path:
|
||||
logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
|
||||
# filename = os.path.splitext(os.path.basename(file_path))[0]
|
||||
# existing_hash = hash_index.get_hash_by_filename(filename)
|
||||
# if existing_hash and existing_hash != model_data.get('sha256', '').lower():
|
||||
# existing_path = hash_index.get_path(existing_hash)
|
||||
# if existing_path and existing_path != file_path:
|
||||
# logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
|
||||
|
||||
rel_path = os.path.relpath(file_path, root_path)
|
||||
folder = os.path.dirname(rel_path)
|
||||
model_data['folder'] = folder.replace(os.path.sep, '/')
|
||||
|
||||
|
||||
return model_data
|
||||
|
||||
async def _apply_scan_result(self, scan_result: CacheBuildResult) -> None:
|
||||
"""Apply scan results to the cache and associated indexes."""
|
||||
|
||||
if scan_result is None:
|
||||
return
|
||||
|
||||
self._hash_index = scan_result.hash_index
|
||||
self._tags_count = dict(scan_result.tags_count)
|
||||
self._excluded_models = list(scan_result.excluded_models)
|
||||
|
||||
if self._cache is None:
|
||||
self._cache = ModelCache(
|
||||
raw_data=list(scan_result.raw_data),
|
||||
folders=[]
|
||||
)
|
||||
else:
|
||||
self._cache.raw_data = list(scan_result.raw_data)
|
||||
|
||||
await self._cache.resort()
|
||||
|
||||
async def _gather_model_data(
|
||||
self,
|
||||
*,
|
||||
total_files: int = 0,
|
||||
progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None
|
||||
) -> CacheBuildResult:
|
||||
"""Collect metadata for all model files."""
|
||||
|
||||
raw_data: List[Dict] = []
|
||||
hash_index = ModelHashIndex()
|
||||
tags_count: Dict[str, int] = {}
|
||||
excluded_models: List[str] = []
|
||||
processed_files = 0
|
||||
|
||||
async def handle_progress() -> None:
|
||||
if progress_callback is None:
|
||||
return
|
||||
try:
|
||||
await progress_callback(processed_files, total_files)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error(f"Error reporting progress for {self.model_type}: {exc}")
|
||||
|
||||
async def scan_recursive(current_path: str, root_path: str, visited_paths: Set[str]) -> None:
|
||||
nonlocal processed_files
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(current_path)
|
||||
if real_path in visited_paths:
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(current_path) as iterator:
|
||||
entries = list(iterator)
|
||||
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True):
|
||||
ext = os.path.splitext(entry.name)[1].lower()
|
||||
if ext not in self.file_extensions:
|
||||
continue
|
||||
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
result = await self._process_model_file(
|
||||
file_path,
|
||||
root_path,
|
||||
hash_index=hash_index,
|
||||
excluded_models=excluded_models
|
||||
)
|
||||
|
||||
processed_files += 1
|
||||
|
||||
if result:
|
||||
raw_data.append(result)
|
||||
|
||||
sha_value = result.get('sha256')
|
||||
model_path = result.get('file_path')
|
||||
if sha_value and model_path:
|
||||
hash_index.add_entry(sha_value.lower(), model_path)
|
||||
|
||||
for tag in result.get('tags') or []:
|
||||
tags_count[tag] = tags_count.get(tag, 0) + 1
|
||||
|
||||
await handle_progress()
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
await scan_recursive(entry.path, root_path, visited_paths)
|
||||
except Exception as entry_error:
|
||||
logger.error(f"Error processing entry {entry.path}: {entry_error}")
|
||||
except Exception as scan_error:
|
||||
logger.error(f"Error scanning {current_path}: {scan_error}")
|
||||
|
||||
for model_root in self.get_model_roots():
|
||||
if not os.path.exists(model_root):
|
||||
continue
|
||||
|
||||
await scan_recursive(model_root, model_root, set())
|
||||
|
||||
return CacheBuildResult(
|
||||
raw_data=raw_data,
|
||||
hash_index=hash_index,
|
||||
tags_count=tags_count,
|
||||
excluded_models=excluded_models
|
||||
)
|
||||
|
||||
async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
|
||||
"""Add a model to the cache
|
||||
|
||||
|
||||
Args:
|
||||
metadata_dict: The model metadata dictionary
|
||||
folder: The relative folder path for the model
|
||||
|
||||
177
tests/services/test_model_scanner.py
Normal file
177
tests/services/test_model_scanner.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services import model_scanner
|
||||
from py.services.model_cache import ModelCache
|
||||
from py.services.model_hash_index import ModelHashIndex
|
||||
from py.services.model_scanner import CacheBuildResult, ModelScanner
|
||||
from py.utils.models import BaseModelMetadata
|
||||
|
||||
|
||||
class RecordingWebSocketManager:
|
||||
def __init__(self) -> None:
|
||||
self.payloads: List[dict] = []
|
||||
|
||||
async def broadcast_init_progress(self, payload: dict) -> None:
|
||||
self.payloads.append(payload)
|
||||
|
||||
|
||||
def _normalize_path(path: Path) -> str:
|
||||
return str(path).replace(os.sep, "/")
|
||||
|
||||
|
||||
class DummyScanner(ModelScanner):
|
||||
def __init__(self, root: Path):
|
||||
self._root = str(root)
|
||||
super().__init__(
|
||||
model_type="dummy",
|
||||
model_class=BaseModelMetadata,
|
||||
file_extensions={".txt"},
|
||||
hash_index=ModelHashIndex(),
|
||||
)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
return [self._root]
|
||||
|
||||
async def _process_model_file(
|
||||
self,
|
||||
file_path: str,
|
||||
root_path: str,
|
||||
*,
|
||||
hash_index: ModelHashIndex | None = None,
|
||||
excluded_models: List[str] | None = None,
|
||||
) -> dict:
|
||||
hash_index = hash_index or self._hash_index
|
||||
excluded_models = excluded_models if excluded_models is not None else self._excluded_models
|
||||
|
||||
rel_path = os.path.relpath(file_path, root_path)
|
||||
folder = os.path.dirname(rel_path).replace(os.path.sep, "/")
|
||||
name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
if name.startswith("skip"):
|
||||
excluded_models.append(file_path.replace(os.sep, "/"))
|
||||
return None
|
||||
|
||||
tags = ["alpha"] if "one" in name else ["beta"]
|
||||
|
||||
return {
|
||||
"file_path": file_path.replace(os.sep, "/"),
|
||||
"folder": folder,
|
||||
"sha256": f"hash-{name}",
|
||||
"tags": tags,
|
||||
"model_name": name,
|
||||
"size": 1,
|
||||
"modified": 1.0,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_model_scanner_singletons():
|
||||
ModelScanner._instances.clear()
|
||||
ModelScanner._locks.clear()
|
||||
yield
|
||||
ModelScanner._instances.clear()
|
||||
ModelScanner._locks.clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def stub_register_service(monkeypatch):
|
||||
async def noop(*_args, **_kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(model_scanner.ServiceRegistry, "register_service", noop)
|
||||
|
||||
|
||||
def _create_files(root: Path) -> tuple[Path, Path, Path]:
|
||||
first = root / "one.txt"
|
||||
first.write_text("one", encoding="utf-8")
|
||||
|
||||
nested_dir = root / "nested"
|
||||
nested_dir.mkdir()
|
||||
second = nested_dir / "two.txt"
|
||||
second.write_text("two", encoding="utf-8")
|
||||
|
||||
skipped = root / "skip-file.txt"
|
||||
skipped.write_text("skip", encoding="utf-8")
|
||||
|
||||
return first, second, skipped
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_cache_populates_cache(tmp_path: Path):
|
||||
_create_files(tmp_path)
|
||||
scanner = DummyScanner(tmp_path)
|
||||
|
||||
await scanner._initialize_cache()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
cached_paths = {item["file_path"] for item in cache.raw_data}
|
||||
assert cached_paths == {
|
||||
_normalize_path(tmp_path / "one.txt"),
|
||||
_normalize_path(tmp_path / "nested" / "two.txt"),
|
||||
}
|
||||
|
||||
assert scanner._hash_index.get_path("hash-one") == _normalize_path(tmp_path / "one.txt")
|
||||
assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt")
|
||||
assert scanner._tags_count == {"alpha": 1, "beta": 1}
|
||||
assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")]
|
||||
assert sorted(cache.folders) == ["", "nested"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_cache_sync_returns_result_without_mutating_state(tmp_path: Path, monkeypatch):
|
||||
_create_files(tmp_path)
|
||||
scanner = DummyScanner(tmp_path)
|
||||
|
||||
ws_stub = RecordingWebSocketManager()
|
||||
monkeypatch.setattr(model_scanner, "ws_manager", ws_stub)
|
||||
|
||||
scanner._cache = ModelCache(raw_data=[{"file_path": "sentinel", "folder": ""}], folders=["existing"])
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, scanner._initialize_cache_sync, 2, "dummy")
|
||||
|
||||
assert isinstance(result, CacheBuildResult)
|
||||
assert {item["file_path"] for item in result.raw_data} == {
|
||||
_normalize_path(tmp_path / "one.txt"),
|
||||
_normalize_path(tmp_path / "nested" / "two.txt"),
|
||||
}
|
||||
assert result.tags_count == {"alpha": 1, "beta": 1}
|
||||
assert ws_stub.payloads, "expected progress updates from websocket manager"
|
||||
|
||||
assert scanner._cache.raw_data == [{"file_path": "sentinel", "folder": ""}]
|
||||
assert scanner._hash_index.get_path("hash-one") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_in_background_applies_scan_result(tmp_path: Path, monkeypatch):
|
||||
_create_files(tmp_path)
|
||||
scanner = DummyScanner(tmp_path)
|
||||
|
||||
ws_stub = RecordingWebSocketManager()
|
||||
monkeypatch.setattr(model_scanner, "ws_manager", ws_stub)
|
||||
|
||||
original_sleep = asyncio.sleep
|
||||
|
||||
async def fast_sleep(duration: float) -> None:
|
||||
await original_sleep(0)
|
||||
|
||||
monkeypatch.setattr(model_scanner.asyncio, "sleep", fast_sleep)
|
||||
|
||||
await scanner.initialize_in_background()
|
||||
|
||||
cache = await scanner.get_cached_data()
|
||||
cached_paths = {item["file_path"] for item in cache.raw_data}
|
||||
|
||||
assert cached_paths == {
|
||||
_normalize_path(tmp_path / "one.txt"),
|
||||
_normalize_path(tmp_path / "nested" / "two.txt"),
|
||||
}
|
||||
assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt")
|
||||
assert scanner._tags_count == {"alpha": 1, "beta": 1}
|
||||
assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")]
|
||||
assert ws_stub.payloads[-1]["progress"] == 100
|
||||
Reference in New Issue
Block a user