Files
ComfyUI-Lora-Manager/tests/services/test_checkpoint_scanner.py
Will Miao 84c62f2954 refactor(model-type): complete phase 5 cleanup by removing deprecated model_type field
- Remove backward compatibility code for `model_type` in `ModelScanner._build_cache_entry()`
- Update `CheckpointScanner` to only handle `sub_type` in `adjust_metadata()` and `adjust_cached_entry()`
- Delete deprecated aliases `resolve_civitai_model_type` and `normalize_civitai_model_type` from `model_query.py`
- Update frontend components (`RecipeModal.js`, `ModelCard.js`, etc.) to use `sub_type` instead of `model_type`
- Update API response format to return only `sub_type`, removing `model_type` from service responses
- Revise technical documentation to mark Phase 5 as completed and remove outdated TODO items

All cleanup tasks for the model type refactoring are now complete, ensuring consistent use of `sub_type` across the codebase.
2026-01-30 07:48:31 +08:00

136 lines
3.8 KiB
Python

import os
from pathlib import Path
from typing import List
import pytest
from py.services import model_scanner
from py.services.checkpoint_scanner import CheckpointScanner
from py.services.model_scanner import ModelScanner
from py.services.persistent_model_cache import PersistedCacheData
class RecordingWebSocketManager:
    """Stand-in WebSocket manager that records init-progress payloads.

    Nothing is sent anywhere; each payload handed to
    ``broadcast_init_progress`` is kept on ``payloads`` in call order so a
    test can inspect what would have been broadcast.
    """

    def __init__(self) -> None:
        # Starts empty; grows by one entry per broadcast call.
        self.payloads: List[dict] = list()

    async def broadcast_init_progress(self, payload: dict) -> None:
        """Store ``payload`` at the end of ``payloads``."""
        self.payloads.extend((payload,))
def _normalize(path: Path) -> str:
return str(path).replace(os.sep, "/")
@pytest.fixture(autouse=True)
def reset_model_scanner_singletons():
    """Clear ``ModelScanner``'s class-level registries around every test.

    ``_instances`` and ``_locks`` are emptied before the test body runs and
    again after it finishes, so state created by one test is not visible to
    the next.
    """

    def _clear_registries() -> None:
        ModelScanner._instances.clear()
        ModelScanner._locks.clear()

    _clear_registries()
    yield
    _clear_registries()
@pytest.mark.asyncio
async def test_persisted_cache_restores_model_type(tmp_path: Path, monkeypatch):
    """Entries loaded from the persisted cache expose the expected ``sub_type``.

    One model file is created under a checkpoints root and one under a unet
    root. The persisted rows for both carry no ``sub_type`` key; after
    ``_load_persisted_cache`` the cached entries are expected to report
    ``"checkpoint"`` and ``"diffusion_model"`` respectively, and the last
    progress payload must be for the ``"loading_cache"`` stage.
    """
    # NOTE(review): "0" presumably keeps the persistent cache enabled —
    # confirm against the scanner's handling of this variable.
    monkeypatch.setenv("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0")

    # Lay out one model file per root on disk.
    ckpt_dir = tmp_path / "checkpoints"
    unet_dir = tmp_path / "unet"
    ckpt_dir.mkdir()
    unet_dir.mkdir()

    ckpt_path = ckpt_dir / "alpha.safetensors"
    unet_path = unet_dir / "beta.safetensors"
    ckpt_path.write_text("alpha", encoding="utf-8")
    unet_path.write_text("beta", encoding="utf-8")

    # Forward-slash string forms of the roots and files.
    ckpt_root, unet_root, ckpt_key, unet_key = (
        _normalize(location)
        for location in (ckpt_dir, unet_dir, ckpt_path, unet_path)
    )

    # Point the scanner configuration at the temporary roots.
    root_settings = {
        "base_models_roots": [ckpt_root, unet_root],
        "checkpoints_roots": [ckpt_root],
        "unet_roots": [unet_root],
    }
    for attr_name, roots in root_settings.items():
        monkeypatch.setattr(model_scanner.config, attr_name, roots, raising=False)

    # Persisted row for the checkpoint file; the unet row reuses it and
    # overrides only the identifying fields.
    checkpoint_row = {
        "file_path": ckpt_key,
        "file_name": "alpha",
        "model_name": "alpha",
        "folder": "",
        "size": 1,
        "modified": 1.0,
        "sha256": "hash-alpha",
        "base_model": "",
        "preview_url": "",
        "preview_nsfw_level": 0,
        "from_civitai": False,
        "favorite": False,
        "notes": "",
        "usage_tips": "",
        "metadata_source": None,
        "exclude": False,
        "db_checked": False,
        "last_checked_at": 0.0,
        "tags": [],
        "civitai": None,
        "civitai_deleted": False,
    }
    unet_row = {
        **checkpoint_row,
        "file_path": unet_key,
        "file_name": "beta",
        "model_name": "beta",
        "sha256": "hash-beta",
    }

    snapshot = PersistedCacheData(
        raw_data=[checkpoint_row, unet_row],
        hash_rows=[],
        excluded_models=[],
    )

    class StubPersistentCache:
        """Returns the prepared snapshot and checks the requested model type."""

        def load_cache(self, model_type: str):
            assert model_type == "checkpoint"
            return snapshot

    stub_cache = StubPersistentCache()
    monkeypatch.setattr(model_scanner, "get_persistent_cache", lambda: stub_cache)

    recorder = RecordingWebSocketManager()
    monkeypatch.setattr(model_scanner, "ws_manager", recorder)

    scanner = CheckpointScanner()
    loaded = await scanner._load_persisted_cache("checkpoints")
    assert loaded is True

    cache = await scanner.get_cached_data()

    # Index the reported sub_type of every cached entry by its file path.
    sub_type_by_path = {}
    for entry in cache.raw_data:
        sub_type_by_path[entry["file_path"]] = entry.get("sub_type")

    assert sub_type_by_path[ckpt_key] == "checkpoint"
    assert sub_type_by_path[unet_key] == "diffusion_model"

    # Loading must have broadcast progress, ending on the cache-loading stage.
    assert recorder.payloads
    assert recorder.payloads[-1]["stage"] == "loading_cache"