From 8d336320c0c25e0cdd1bebe8a47df6ae3bba477e Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Thu, 23 Oct 2025 07:34:35 +0800 Subject: [PATCH] fix(scanner): apply metadata adjustments during reconciliation --- py/services/download_manager.py | 32 +++++++-- py/services/model_scanner.py | 3 + tests/services/test_download_manager.py | 92 +++++++++++++++++++++++++ tests/services/test_model_scanner.py | 32 +++++++++ 4 files changed, 155 insertions(+), 4 deletions(-) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 198631a6..a0b7b6fe 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -643,10 +643,10 @@ class DownloadManager: # 4. Update file information (size and modified time) metadata.update_file_info(save_path) - # 5. Final metadata update - await MetadataManager.save_metadata(save_path, metadata) + scanner = None + adjust_root: Optional[str] = None - # 6. Update cache based on model type + # 5. Determine scanner and adjust metadata for cache consistency if model_type == "checkpoint": scanner = await self._get_checkpoint_scanner() logger.info(f"Updating checkpoint cache for {save_path}") @@ -656,9 +656,33 @@ class DownloadManager: elif model_type == "embedding": scanner = await ServiceRegistry.get_embedding_scanner() logger.info(f"Updating embedding cache for {save_path}") - + + if scanner is not None: + file_path_for_adjust = getattr(metadata, "file_path", save_path) + if isinstance(file_path_for_adjust, str): + normalized_file_path = file_path_for_adjust.replace(os.sep, "/") + else: + normalized_file_path = str(file_path_for_adjust) + + find_root = getattr(scanner, "_find_root_for_file", None) + if callable(find_root): + try: + adjust_root = find_root(normalized_file_path) + except TypeError: + adjust_root = None + + adjust_metadata = getattr(scanner, "adjust_metadata", None) + if callable(adjust_metadata): + metadata = adjust_metadata(metadata, normalized_file_path, adjust_root) + + # 6. Persist metadata with any adjustments + await MetadataManager.save_metadata(save_path, metadata) + # Convert metadata to dictionary metadata_dict = metadata.to_dict() + adjust_cached_entry = getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None + if callable(adjust_cached_entry): + metadata_dict = adjust_cached_entry(metadata_dict) # Add model to cache and save to disk in a single operation await scanner.add_model_to_cache(metadata_dict, relative_path) diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 4ec993c3..f035ea34 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -678,6 +678,9 @@ class ModelScanner: if root_path: model_data = await self._process_model_file(path, root_path) if model_data: + model_data = self.adjust_cached_entry(dict(model_data)) + if not model_data: + continue # Add to cache self._cache.raw_data.append(model_data) self._cache.add_to_version_index(model_data) diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py index 4306adf7..488b9a61 100644 --- a/tests/services/test_download_manager.py +++ b/tests/services/test_download_manager.py @@ -1,6 +1,7 @@ import asyncio import os from pathlib import Path +from typing import Optional from types import SimpleNamespace from unittest.mock import AsyncMock @@ -431,6 +432,97 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path): assert dummy_scanner.calls # ensure cache updated +async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_path): + manager = DownloadManager() + + root_dir = tmp_path / "checkpoints" + root_dir.mkdir() + save_dir = root_dir + target_path = save_dir / "model.safetensors" + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = path.as_posix() + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + self.preview_nsfw_level = 0 + self.model_type = "checkpoint" + + def generate_unique_filename(self, *_args, **_kwargs): + return os.path.basename(self.file_path) + + def update_file_info(self, updated_path): + self.file_path = Path(updated_path).as_posix() + + def to_dict(self): + return { + "file_path": self.file_path, + "model_type": self.model_type, + "sha256": self.sha256, + } + + metadata = DummyMetadata(target_path) + version_info = {"images": []} + download_urls = ["https://example.invalid/model.safetensors"] + + class DummyDownloader: + async def download_file(self, _url, path, progress_callback=None, use_auth=None): + Path(path).write_text("content") + return True, "ok" + + monkeypatch.setattr( + download_manager, + "get_downloader", + AsyncMock(return_value=DummyDownloader()), + ) + + class DummyCheckpointScanner: + def __init__(self, root: Path): + self.root = root.as_posix() + self.add_calls = [] + + def _find_root_for_file(self, file_path: str): + return self.root if file_path.startswith(self.root) else None + + def adjust_metadata(self, metadata_obj, _file_path: str, root_path: Optional[str]): + if root_path: + metadata_obj.model_type = "diffusion_model" + return metadata_obj + + def adjust_cached_entry(self, entry): + if entry.get("file_path", "").startswith(self.root): + entry["model_type"] = "diffusion_model" + return entry + + async def add_model_to_cache(self, metadata_dict, relative_path): + self.add_calls.append((metadata_dict, relative_path)) + return True + + dummy_scanner = DummyCheckpointScanner(root_dir) + monkeypatch.setattr(DownloadManager, "_get_checkpoint_scanner", AsyncMock(return_value=dummy_scanner)) + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + + result = await manager._execute_download( + download_urls=download_urls, + save_dir=str(save_dir), + metadata=metadata, + version_info=version_info, + relative_path="", + progress_callback=None, + model_type="checkpoint", + download_id=None, + ) + + assert result == {"success": True} + assert metadata.model_type == "diffusion_model" + saved_metadata = MetadataManager.save_metadata.await_args.args[1] + assert saved_metadata.model_type == "diffusion_model" + assert dummy_scanner.add_calls + cached_entry, _ = dummy_scanner.add_calls[0] + assert cached_entry["model_type"] == "diffusion_model" + + async def test_pause_download_updates_state(): manager = DownloadManager() diff --git a/tests/services/test_model_scanner.py b/tests/services/test_model_scanner.py index 78505f3c..972d5f5a 100644 --- a/tests/services/test_model_scanner.py +++ b/tests/services/test_model_scanner.py @@ -3,6 +3,7 @@ import os import sqlite3 from pathlib import Path from typing import List +from types import MethodType import pytest @@ -470,6 +471,37 @@ async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: P assert cache.folders == [""] +@pytest.mark.asyncio +async def test_reconcile_cache_applies_adjust_cached_entry(tmp_path: Path): + existing = tmp_path / "one.txt" + existing.write_text("one", encoding="utf-8") + + scanner = DummyScanner(tmp_path) + + applied: List[str] = [] + + def _adjust(self, entry: dict) -> dict: + applied.append(entry["file_path"]) + entry["model_type"] = "adjusted" + return entry + + scanner.adjust_cached_entry = MethodType(_adjust, scanner) + + await scanner._initialize_cache() + applied.clear() + + new_file = tmp_path / "two.txt" + new_file.write_text("two", encoding="utf-8") + + await scanner._reconcile_cache() + + normalized_new = _normalize_path(new_file) + assert normalized_new in applied + + new_entry = next(item for item in scanner._cache.raw_data if item["file_path"] == normalized_new) + assert new_entry["model_type"] == "adjusted" + + @pytest.mark.asyncio async def test_count_model_files_handles_symlink_loops(tmp_path: Path): scanner = DummyScanner(tmp_path)