From c5175bb8702b3288c2a2f0d127a51dd33355f63c Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 21 Oct 2025 22:55:00 +0800 Subject: [PATCH] fix(checkpoints): preserve model type on persisted load --- py/services/checkpoint_scanner.py | 31 ++++- py/services/model_scanner.py | 43 ++++++- tests/services/test_checkpoint_scanner.py | 135 ++++++++++++++++++++++ 3 files changed, 201 insertions(+), 8 deletions(-) create mode 100644 tests/services/test_checkpoint_scanner.py diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index 3f298156..25afce90 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -1,5 +1,5 @@ import logging -from typing import List +from typing import Any, Dict, List, Optional from ..utils.models import CheckpointMetadata from ..config import config @@ -21,14 +21,33 @@ class CheckpointScanner(ModelScanner): hash_index=ModelHashIndex() ) + def _resolve_model_type(self, root_path: Optional[str]) -> Optional[str]: + if not root_path: + return None + + if config.checkpoints_roots and root_path in config.checkpoints_roots: + return "checkpoint" + + if config.unet_roots and root_path in config.unet_roots: + return "diffusion_model" + + return None + def adjust_metadata(self, metadata, file_path, root_path): if hasattr(metadata, "model_type"): - if root_path in config.checkpoints_roots: - metadata.model_type = "checkpoint" - elif root_path in config.unet_roots: - metadata.model_type = "diffusion_model" + model_type = self._resolve_model_type(root_path) + if model_type: + metadata.model_type = model_type return metadata + def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]: + model_type = self._resolve_model_type( + self._find_root_for_file(entry.get("file_path")) + ) + if model_type: + entry["model_type"] = model_type + return entry + def get_model_roots(self) -> List[str]: """Get checkpoint root directories""" - return config.base_models_roots \ No newline at end of file + return config.base_models_roots diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 4859f669..4ec993c3 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -376,12 +376,16 @@ class ModelScanner: hash_index.add_entry(sha_value.lower(), path) tags_count: Dict[str, int] = {} + adjusted_raw_data: List[Dict[str, Any]] = [] for item in persisted.raw_data: - for tag in item.get('tags') or []: + adjusted_item = self.adjust_cached_entry(dict(item)) + adjusted_raw_data.append(adjusted_item) + + for tag in adjusted_item.get('tags') or []: tags_count[tag] = tags_count.get(tag, 0) + 1 scan_result = CacheBuildResult( - raw_data=list(persisted.raw_data), + raw_data=adjusted_raw_data, hash_index=hash_index, tags_count=tags_count, excluded_models=list(persisted.excluded_models) @@ -766,6 +770,41 @@ class ModelScanner: """Hook for subclasses: adjust metadata during scanning""" return metadata + def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]: + """Hook for subclasses: adjust entries loaded from the persisted cache.""" + return entry + + @staticmethod + def _normalize_path_value(path: Optional[str]) -> str: + if not path: + return '' + + normalized = os.path.normpath(path) + if normalized == '.': + return '' + + return normalized.replace('\\', '/') + + def _find_root_for_file(self, file_path: Optional[str]) -> Optional[str]: + """Return the configured root directory that contains ``file_path``.""" + + normalized_path = self._normalize_path_value(file_path) + if not normalized_path: + return None + + for root in self.get_model_roots() or []: + normalized_root = self._normalize_path_value(root) + if not normalized_root: + continue + + if ( + normalized_path == normalized_root + or normalized_path.startswith(f"{normalized_root}/") + ): + return root + + return None + async def _process_model_file( self, file_path: str, diff --git a/tests/services/test_checkpoint_scanner.py b/tests/services/test_checkpoint_scanner.py new file mode 100644 index 00000000..eb5a9944 --- /dev/null +++ b/tests/services/test_checkpoint_scanner.py @@ -0,0 +1,135 @@ +import os +from pathlib import Path +from typing import List + +import pytest + +from py.services import model_scanner +from py.services.checkpoint_scanner import CheckpointScanner +from py.services.model_scanner import ModelScanner +from py.services.persistent_model_cache import PersistedCacheData + + +class RecordingWebSocketManager: + def __init__(self) -> None: + self.payloads: List[dict] = [] + + async def broadcast_init_progress(self, payload: dict) -> None: + self.payloads.append(payload) + + +def _normalize(path: Path) -> str: + return str(path).replace(os.sep, "/") + + +@pytest.fixture(autouse=True) +def reset_model_scanner_singletons(): + ModelScanner._instances.clear() + ModelScanner._locks.clear() + yield + ModelScanner._instances.clear() + ModelScanner._locks.clear() + + +@pytest.mark.asyncio +async def test_persisted_cache_restores_model_type(tmp_path: Path, monkeypatch): + monkeypatch.setenv("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0") + + checkpoints_root = tmp_path / "checkpoints" + unet_root = tmp_path / "unet" + checkpoints_root.mkdir() + unet_root.mkdir() + + checkpoint_file = checkpoints_root / "alpha.safetensors" + unet_file = unet_root / "beta.safetensors" + checkpoint_file.write_text("alpha", encoding="utf-8") + unet_file.write_text("beta", encoding="utf-8") + + normalized_checkpoint_root = _normalize(checkpoints_root) + normalized_unet_root = _normalize(unet_root) + normalized_checkpoint_file = _normalize(checkpoint_file) + normalized_unet_file = _normalize(unet_file) + + monkeypatch.setattr( + model_scanner.config, + "base_models_roots", + [normalized_checkpoint_root, normalized_unet_root], + raising=False, + ) + monkeypatch.setattr( + model_scanner.config, + "checkpoints_roots", + [normalized_checkpoint_root], + raising=False, + ) + monkeypatch.setattr( + model_scanner.config, + "unet_roots", + [normalized_unet_root], + raising=False, + ) + + raw_checkpoint = { + "file_path": normalized_checkpoint_file, + "file_name": "alpha", + "model_name": "alpha", + "folder": "", + "size": 1, + "modified": 1.0, + "sha256": "hash-alpha", + "base_model": "", + "preview_url": "", + "preview_nsfw_level": 0, + "from_civitai": False, + "favorite": False, + "notes": "", + "usage_tips": "", + "metadata_source": None, + "exclude": False, + "db_checked": False, + "last_checked_at": 0.0, + "tags": [], + "civitai": None, + "civitai_deleted": False, + } + + raw_unet = dict(raw_checkpoint) + raw_unet.update( + { + "file_path": normalized_unet_file, + "file_name": "beta", + "model_name": "beta", + "sha256": "hash-beta", + } + ) + + persisted = PersistedCacheData( + raw_data=[raw_checkpoint, raw_unet], + hash_rows=[], + excluded_models=[], + ) + + class FakePersistentCache: + def load_cache(self, model_type: str): + assert model_type == "checkpoint" + return persisted + + fake_cache = FakePersistentCache() + monkeypatch.setattr(model_scanner, "get_persistent_cache", lambda: fake_cache) + + ws_stub = RecordingWebSocketManager() + monkeypatch.setattr(model_scanner, "ws_manager", ws_stub) + + scanner = CheckpointScanner() + + loaded = await scanner._load_persisted_cache("checkpoints") + assert loaded is True + + cache = await scanner.get_cached_data() + types_by_path = {item["file_path"]: item.get("model_type") for item in cache.raw_data} + + assert types_by_path[normalized_checkpoint_file] == "checkpoint" + assert types_by_path[normalized_unet_file] == "diffusion_model" + + assert ws_stub.payloads + assert ws_stub.payloads[-1]["stage"] == "loading_cache"