fix(checkpoints): preserve model type on persisted load

This commit is contained in:
pixelpaws
2025-10-21 22:55:00 +08:00
parent e63ef8d031
commit c5175bb870
3 changed files with 201 additions and 8 deletions

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import List from typing import Any, Dict, List, Optional
from ..utils.models import CheckpointMetadata from ..utils.models import CheckpointMetadata
from ..config import config from ..config import config
@@ -21,14 +21,33 @@ class CheckpointScanner(ModelScanner):
hash_index=ModelHashIndex() hash_index=ModelHashIndex()
) )
def _resolve_model_type(self, root_path: Optional[str]) -> Optional[str]:
    """Map a configured root directory to the model type stored under it.

    Returns ``"checkpoint"`` when *root_path* is one of
    ``config.checkpoints_roots``, ``"diffusion_model"`` when it is one of
    ``config.unet_roots``, and ``None`` when *root_path* is falsy or
    matches neither list.
    """
    if not root_path:
        return None
    # Check the two configured root lists in priority order.
    for roots, model_type in (
        (config.checkpoints_roots, "checkpoint"),
        (config.unet_roots, "diffusion_model"),
    ):
        if roots and root_path in roots:
            return model_type
    return None
def adjust_metadata(self, metadata, file_path, root_path):
    """Hook override: stamp the correct ``model_type`` onto scanned metadata.

    Only touches objects that actually expose a ``model_type`` attribute,
    and only when ``_resolve_model_type`` maps *root_path* to a known type;
    otherwise the metadata passes through unchanged.
    """
    if hasattr(metadata, "model_type"):
        model_type = self._resolve_model_type(root_path)
        if model_type:
            metadata.model_type = model_type
    return metadata
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
    """Re-derive ``model_type`` for an entry restored from the persisted cache.

    The containing root is looked up from the entry's ``file_path`` and
    mapped to a model type; the entry dict is mutated in place when a type
    is found and is returned either way.
    """
    containing_root = self._find_root_for_file(entry.get("file_path"))
    resolved_type = self._resolve_model_type(containing_root)
    if resolved_type:
        entry["model_type"] = resolved_type
    return entry
def get_model_roots(self) -> List[str]:
    """Return all configured base-model root directories.

    NOTE(review): this appears to cover both checkpoint and unet roots
    (``config.base_models_roots``), which is why per-root type resolution
    exists elsewhere in this scanner — confirm against config definition.
    """
    return config.base_models_roots

View File

@@ -376,12 +376,16 @@ class ModelScanner:
hash_index.add_entry(sha_value.lower(), path) hash_index.add_entry(sha_value.lower(), path)
tags_count: Dict[str, int] = {} tags_count: Dict[str, int] = {}
adjusted_raw_data: List[Dict[str, Any]] = []
for item in persisted.raw_data: for item in persisted.raw_data:
for tag in item.get('tags') or []: adjusted_item = self.adjust_cached_entry(dict(item))
adjusted_raw_data.append(adjusted_item)
for tag in adjusted_item.get('tags') or []:
tags_count[tag] = tags_count.get(tag, 0) + 1 tags_count[tag] = tags_count.get(tag, 0) + 1
scan_result = CacheBuildResult( scan_result = CacheBuildResult(
raw_data=list(persisted.raw_data), raw_data=adjusted_raw_data,
hash_index=hash_index, hash_index=hash_index,
tags_count=tags_count, tags_count=tags_count,
excluded_models=list(persisted.excluded_models) excluded_models=list(persisted.excluded_models)
@@ -766,6 +770,41 @@ class ModelScanner:
"""Hook for subclasses: adjust metadata during scanning""" """Hook for subclasses: adjust metadata during scanning"""
return metadata return metadata
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
    """Subclass hook applied to each persisted-cache entry; base is a no-op."""
    return entry
@staticmethod
def _normalize_path_value(path: Optional[str]) -> str:
if not path:
return ''
normalized = os.path.normpath(path)
if normalized == '.':
return ''
return normalized.replace('\\', '/')
def _find_root_for_file(self, file_path: Optional[str]) -> Optional[str]:
"""Return the configured root directory that contains ``file_path``."""
normalized_path = self._normalize_path_value(file_path)
if not normalized_path:
return None
for root in self.get_model_roots() or []:
normalized_root = self._normalize_path_value(root)
if not normalized_root:
continue
if (
normalized_path == normalized_root
or normalized_path.startswith(f"{normalized_root}/")
):
return root
return None
async def _process_model_file( async def _process_model_file(
self, self,
file_path: str, file_path: str,

View File

@@ -0,0 +1,135 @@
import os
from pathlib import Path
from typing import List
import pytest
from py.services import model_scanner
from py.services.checkpoint_scanner import CheckpointScanner
from py.services.model_scanner import ModelScanner
from py.services.persistent_model_cache import PersistedCacheData
class RecordingWebSocketManager:
    """Test double for the websocket manager that records broadcast payloads."""

    def __init__(self) -> None:
        # Payloads are kept in broadcast order; tests inspect this list.
        self.payloads: List[dict] = []

    async def broadcast_init_progress(self, payload: dict) -> None:
        """Capture *payload* instead of broadcasting it anywhere."""
        self.payloads.append(payload)
def _normalize(path: Path) -> str:
return str(path).replace(os.sep, "/")
@pytest.fixture(autouse=True)
def reset_model_scanner_singletons():
    """Ensure every test starts and ends with pristine ModelScanner singletons."""

    def _clear() -> None:
        # Singleton registry and its locks live as class attributes.
        ModelScanner._instances.clear()
        ModelScanner._locks.clear()

    _clear()
    yield
    _clear()
@pytest.mark.asyncio
async def test_persisted_cache_restores_model_type(tmp_path: Path, monkeypatch):
    """Entries loaded from the persisted cache get model_type re-derived.

    Builds one file under a checkpoints root and one under a unet root,
    serves both through a fake persisted cache, and asserts the scanner
    tags them "checkpoint" and "diffusion_model" respectively.
    """
    # Make sure the persistent-cache code path is enabled for this test.
    monkeypatch.setenv("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0")

    # Two distinct roots: one configured as checkpoints, one as unet.
    checkpoints_root = tmp_path / "checkpoints"
    unet_root = tmp_path / "unet"
    checkpoints_root.mkdir()
    unet_root.mkdir()

    checkpoint_file = checkpoints_root / "alpha.safetensors"
    unet_file = unet_root / "beta.safetensors"
    checkpoint_file.write_text("alpha", encoding="utf-8")
    unet_file.write_text("beta", encoding="utf-8")

    # Scanner config compares forward-slash paths, so normalise up front.
    normalized_checkpoint_root = _normalize(checkpoints_root)
    normalized_unet_root = _normalize(unet_root)
    normalized_checkpoint_file = _normalize(checkpoint_file)
    normalized_unet_file = _normalize(unet_file)

    # Patch the scanner's config in place; raising=False tolerates attrs
    # that may not exist on the config object yet.
    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [normalized_checkpoint_root, normalized_unet_root],
        raising=False,
    )
    monkeypatch.setattr(
        model_scanner.config,
        "checkpoints_roots",
        [normalized_checkpoint_root],
        raising=False,
    )
    monkeypatch.setattr(
        model_scanner.config,
        "unet_roots",
        [normalized_unet_root],
        raising=False,
    )

    # Minimal raw cache entry. "model_type" is deliberately absent so the
    # adjust_cached_entry hook has to supply it on load.
    raw_checkpoint = {
        "file_path": normalized_checkpoint_file,
        "file_name": "alpha",
        "model_name": "alpha",
        "folder": "",
        "size": 1,
        "modified": 1.0,
        "sha256": "hash-alpha",
        "base_model": "",
        "preview_url": "",
        "preview_nsfw_level": 0,
        "from_civitai": False,
        "favorite": False,
        "notes": "",
        "usage_tips": "",
        "metadata_source": None,
        "exclude": False,
        "db_checked": False,
        "last_checked_at": 0.0,
        "tags": [],
        "civitai": None,
        "civitai_deleted": False,
    }
    # Second entry mirrors the first but lives under the unet root.
    raw_unet = dict(raw_checkpoint)
    raw_unet.update(
        {
            "file_path": normalized_unet_file,
            "file_name": "beta",
            "model_name": "beta",
            "sha256": "hash-beta",
        }
    )

    persisted = PersistedCacheData(
        raw_data=[raw_checkpoint, raw_unet],
        hash_rows=[],
        excluded_models=[],
    )

    class FakePersistentCache:
        # Serves the canned payload and verifies the scanner requests the
        # "checkpoint" cache specifically.
        def load_cache(self, model_type: str):
            assert model_type == "checkpoint"
            return persisted

    fake_cache = FakePersistentCache()
    monkeypatch.setattr(model_scanner, "get_persistent_cache", lambda: fake_cache)

    # Swap in the recording websocket stub so progress broadcasts are captured.
    ws_stub = RecordingWebSocketManager()
    monkeypatch.setattr(model_scanner, "ws_manager", ws_stub)

    scanner = CheckpointScanner()
    loaded = await scanner._load_persisted_cache("checkpoints")
    assert loaded is True

    cache = await scanner.get_cached_data()
    types_by_path = {item["file_path"]: item.get("model_type") for item in cache.raw_data}
    # Each entry's type must match the root it lives under.
    assert types_by_path[normalized_checkpoint_file] == "checkpoint"
    assert types_by_path[normalized_unet_file] == "diffusion_model"
    # Progress updates were broadcast, ending with the cache-loading stage.
    assert ws_stub.payloads
    assert ws_stub.payloads[-1]["stage"] == "loading_cache"