mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 13:42:12 -03:00
- Remove backward compatibility code for `model_type` in `ModelScanner._build_cache_entry()` - Update `CheckpointScanner` to only handle `sub_type` in `adjust_metadata()` and `adjust_cached_entry()` - Delete deprecated aliases `resolve_civitai_model_type` and `normalize_civitai_model_type` from `model_query.py` - Update frontend components (`RecipeModal.js`, `ModelCard.js`, etc.) to use `sub_type` instead of `model_type` - Update API response format to return only `sub_type`, removing `model_type` from service responses - Revise technical documentation to mark Phase 5 as completed and remove outdated TODO items All cleanup tasks for the model type refactoring are now complete, ensuring consistent use of `sub_type` across the codebase.
136 lines
3.8 KiB
Python
136 lines
3.8 KiB
Python
import os
|
|
from pathlib import Path
|
|
from typing import List
|
|
|
|
import pytest
|
|
|
|
from py.services import model_scanner
|
|
from py.services.checkpoint_scanner import CheckpointScanner
|
|
from py.services.model_scanner import ModelScanner
|
|
from py.services.persistent_model_cache import PersistedCacheData
|
|
|
|
|
|
class RecordingWebSocketManager:
    """Test double standing in for the websocket manager.

    Instead of broadcasting anything, every payload handed to
    ``broadcast_init_progress`` is appended to ``payloads`` in call order so a
    test can inspect what would have been sent.
    """

    def __init__(self) -> None:
        # Captured payloads, oldest first.
        self.payloads: List[dict] = list()

    async def broadcast_init_progress(self, payload: dict) -> None:
        """Capture *payload* rather than sending it to any client."""
        recorded = self.payloads
        recorded.append(payload)
|
|
|
|
|
|
def _normalize(path: Path) -> str:
|
|
return str(path).replace(os.sep, "/")
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_model_scanner_singletons():
    """Wipe ``ModelScanner``'s class-level registries around every test.

    ``ModelScanner._instances`` and ``ModelScanner._locks`` are cleared both
    before the test runs and after it finishes, so scanner state created in
    one test cannot leak into another (the fixture is autouse, so it applies
    to every test in this module).
    """

    def _wipe_registries() -> None:
        ModelScanner._instances.clear()
        ModelScanner._locks.clear()

    _wipe_registries()
    yield
    _wipe_registries()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_persisted_cache_restores_model_type(tmp_path: Path, monkeypatch):
    """Restoring a persisted checkpoint cache yields the right ``sub_type``s.

    Two persisted rows are served, neither carrying a type field: one for a
    file under the checkpoints root and one for a file under the unet root.
    After ``_load_persisted_cache`` the cached entries must report
    ``"checkpoint"`` and ``"diffusion_model"`` respectively, and the last
    websocket progress payload must be at the ``loading_cache`` stage.
    (The test name refers to the pre-rename ``model_type`` field; the value
    checked is now ``sub_type``.)
    """
    # Set the disable flag to "0" so the persistent-cache path is not
    # switched off by the environment (assumes "0" is read as false).
    monkeypatch.setenv("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0")

    checkpoints_root = tmp_path / "checkpoints"
    unet_root = tmp_path / "unet"
    for root in (checkpoints_root, unet_root):
        root.mkdir()

    checkpoint_file = checkpoints_root / "alpha.safetensors"
    unet_file = unet_root / "beta.safetensors"
    for model_file, content in ((checkpoint_file, "alpha"), (unet_file, "beta")):
        model_file.write_text(content, encoding="utf-8")

    ckpt_root_str, unet_root_str, ckpt_file_str, unet_file_str = (
        _normalize(p)
        for p in (checkpoints_root, unet_root, checkpoint_file, unet_file)
    )

    # Point the scanner's config at the temporary roots; raising=False lets
    # monkeypatch create the attributes if config does not define them.
    config_overrides = (
        ("base_models_roots", [ckpt_root_str, unet_root_str]),
        ("checkpoints_roots", [ckpt_root_str]),
        ("unet_roots", [unet_root_str]),
    )
    for attr_name, roots in config_overrides:
        monkeypatch.setattr(model_scanner.config, attr_name, roots, raising=False)

    # Persisted rows carry no type field, so the scanner has to determine
    # sub_type itself (presumably from which configured root each file is in).
    checkpoint_row = {
        "file_path": ckpt_file_str,
        "file_name": "alpha",
        "model_name": "alpha",
        "folder": "",
        "size": 1,
        "modified": 1.0,
        "sha256": "hash-alpha",
        "base_model": "",
        "preview_url": "",
        "preview_nsfw_level": 0,
        "from_civitai": False,
        "favorite": False,
        "notes": "",
        "usage_tips": "",
        "metadata_source": None,
        "exclude": False,
        "db_checked": False,
        "last_checked_at": 0.0,
        "tags": [],
        "civitai": None,
        "civitai_deleted": False,
    }
    unet_row = {
        **checkpoint_row,
        "file_path": unet_file_str,
        "file_name": "beta",
        "model_name": "beta",
        "sha256": "hash-beta",
    }

    persisted = PersistedCacheData(
        raw_data=[checkpoint_row, unet_row],
        hash_rows=[],
        excluded_models=[],
    )

    class FakePersistentCache:
        """Stand-in persistent cache that always serves ``persisted``."""

        def load_cache(self, model_type: str):
            # The checkpoint scanner must only request its own cache.
            assert model_type == "checkpoint"
            return persisted

    fake_cache = FakePersistentCache()
    monkeypatch.setattr(model_scanner, "get_persistent_cache", lambda: fake_cache)

    ws_stub = RecordingWebSocketManager()
    monkeypatch.setattr(model_scanner, "ws_manager", ws_stub)

    scanner = CheckpointScanner()

    restored = await scanner._load_persisted_cache("checkpoints")
    assert restored is True

    cache = await scanner.get_cached_data()
    sub_type_by_path = {
        entry["file_path"]: entry.get("sub_type") for entry in cache.raw_data
    }

    assert sub_type_by_path[ckpt_file_str] == "checkpoint"
    assert sub_type_by_path[unet_file_str] == "diffusion_model"

    # Loading must have broadcast progress, ending at the cache-loading stage.
    assert ws_stub.payloads
    assert ws_stub.payloads[-1]["stage"] == "loading_cache"
|