mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #582 from willmiao/codex/fix-model-type-adjustment-in-scanner
Fix checkpoint model type when hydrating persisted cache
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
@@ -21,14 +21,33 @@ class CheckpointScanner(ModelScanner):
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def _resolve_model_type(self, root_path: Optional[str]) -> Optional[str]:
|
||||
if not root_path:
|
||||
return None
|
||||
|
||||
if config.checkpoints_roots and root_path in config.checkpoints_roots:
|
||||
return "checkpoint"
|
||||
|
||||
if config.unet_roots and root_path in config.unet_roots:
|
||||
return "diffusion_model"
|
||||
|
||||
return None
|
||||
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
if hasattr(metadata, "model_type"):
|
||||
if root_path in config.checkpoints_roots:
|
||||
metadata.model_type = "checkpoint"
|
||||
elif root_path in config.unet_roots:
|
||||
metadata.model_type = "diffusion_model"
|
||||
model_type = self._resolve_model_type(root_path)
|
||||
if model_type:
|
||||
metadata.model_type = model_type
|
||||
return metadata
|
||||
|
||||
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
model_type = self._resolve_model_type(
|
||||
self._find_root_for_file(entry.get("file_path"))
|
||||
)
|
||||
if model_type:
|
||||
entry["model_type"] = model_type
|
||||
return entry
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get checkpoint root directories"""
|
||||
return config.base_models_roots
|
||||
return config.base_models_roots
|
||||
|
||||
@@ -376,12 +376,16 @@ class ModelScanner:
|
||||
hash_index.add_entry(sha_value.lower(), path)
|
||||
|
||||
tags_count: Dict[str, int] = {}
|
||||
adjusted_raw_data: List[Dict[str, Any]] = []
|
||||
for item in persisted.raw_data:
|
||||
for tag in item.get('tags') or []:
|
||||
adjusted_item = self.adjust_cached_entry(dict(item))
|
||||
adjusted_raw_data.append(adjusted_item)
|
||||
|
||||
for tag in adjusted_item.get('tags') or []:
|
||||
tags_count[tag] = tags_count.get(tag, 0) + 1
|
||||
|
||||
scan_result = CacheBuildResult(
|
||||
raw_data=list(persisted.raw_data),
|
||||
raw_data=adjusted_raw_data,
|
||||
hash_index=hash_index,
|
||||
tags_count=tags_count,
|
||||
excluded_models=list(persisted.excluded_models)
|
||||
@@ -766,6 +770,41 @@ class ModelScanner:
|
||||
"""Hook for subclasses: adjust metadata during scanning"""
|
||||
return metadata
|
||||
|
||||
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Hook for subclasses: adjust entries loaded from the persisted cache."""
|
||||
return entry
|
||||
|
||||
@staticmethod
|
||||
def _normalize_path_value(path: Optional[str]) -> str:
|
||||
if not path:
|
||||
return ''
|
||||
|
||||
normalized = os.path.normpath(path)
|
||||
if normalized == '.':
|
||||
return ''
|
||||
|
||||
return normalized.replace('\\', '/')
|
||||
|
||||
def _find_root_for_file(self, file_path: Optional[str]) -> Optional[str]:
|
||||
"""Return the configured root directory that contains ``file_path``."""
|
||||
|
||||
normalized_path = self._normalize_path_value(file_path)
|
||||
if not normalized_path:
|
||||
return None
|
||||
|
||||
for root in self.get_model_roots() or []:
|
||||
normalized_root = self._normalize_path_value(root)
|
||||
if not normalized_root:
|
||||
continue
|
||||
|
||||
if (
|
||||
normalized_path == normalized_root
|
||||
or normalized_path.startswith(f"{normalized_root}/")
|
||||
):
|
||||
return root
|
||||
|
||||
return None
|
||||
|
||||
async def _process_model_file(
|
||||
self,
|
||||
file_path: str,
|
||||
|
||||
135
tests/services/test_checkpoint_scanner.py
Normal file
135
tests/services/test_checkpoint_scanner.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services import model_scanner
|
||||
from py.services.checkpoint_scanner import CheckpointScanner
|
||||
from py.services.model_scanner import ModelScanner
|
||||
from py.services.persistent_model_cache import PersistedCacheData
|
||||
|
||||
|
||||
class RecordingWebSocketManager:
|
||||
def __init__(self) -> None:
|
||||
self.payloads: List[dict] = []
|
||||
|
||||
async def broadcast_init_progress(self, payload: dict) -> None:
|
||||
self.payloads.append(payload)
|
||||
|
||||
|
||||
def _normalize(path: Path) -> str:
|
||||
return str(path).replace(os.sep, "/")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_model_scanner_singletons():
|
||||
ModelScanner._instances.clear()
|
||||
ModelScanner._locks.clear()
|
||||
yield
|
||||
ModelScanner._instances.clear()
|
||||
ModelScanner._locks.clear()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persisted_cache_restores_model_type(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setenv("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0")
|
||||
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
unet_root = tmp_path / "unet"
|
||||
checkpoints_root.mkdir()
|
||||
unet_root.mkdir()
|
||||
|
||||
checkpoint_file = checkpoints_root / "alpha.safetensors"
|
||||
unet_file = unet_root / "beta.safetensors"
|
||||
checkpoint_file.write_text("alpha", encoding="utf-8")
|
||||
unet_file.write_text("beta", encoding="utf-8")
|
||||
|
||||
normalized_checkpoint_root = _normalize(checkpoints_root)
|
||||
normalized_unet_root = _normalize(unet_root)
|
||||
normalized_checkpoint_file = _normalize(checkpoint_file)
|
||||
normalized_unet_file = _normalize(unet_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_checkpoint_root, normalized_unet_root],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"checkpoints_roots",
|
||||
[normalized_checkpoint_root],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"unet_roots",
|
||||
[normalized_unet_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
raw_checkpoint = {
|
||||
"file_path": normalized_checkpoint_file,
|
||||
"file_name": "alpha",
|
||||
"model_name": "alpha",
|
||||
"folder": "",
|
||||
"size": 1,
|
||||
"modified": 1.0,
|
||||
"sha256": "hash-alpha",
|
||||
"base_model": "",
|
||||
"preview_url": "",
|
||||
"preview_nsfw_level": 0,
|
||||
"from_civitai": False,
|
||||
"favorite": False,
|
||||
"notes": "",
|
||||
"usage_tips": "",
|
||||
"metadata_source": None,
|
||||
"exclude": False,
|
||||
"db_checked": False,
|
||||
"last_checked_at": 0.0,
|
||||
"tags": [],
|
||||
"civitai": None,
|
||||
"civitai_deleted": False,
|
||||
}
|
||||
|
||||
raw_unet = dict(raw_checkpoint)
|
||||
raw_unet.update(
|
||||
{
|
||||
"file_path": normalized_unet_file,
|
||||
"file_name": "beta",
|
||||
"model_name": "beta",
|
||||
"sha256": "hash-beta",
|
||||
}
|
||||
)
|
||||
|
||||
persisted = PersistedCacheData(
|
||||
raw_data=[raw_checkpoint, raw_unet],
|
||||
hash_rows=[],
|
||||
excluded_models=[],
|
||||
)
|
||||
|
||||
class FakePersistentCache:
|
||||
def load_cache(self, model_type: str):
|
||||
assert model_type == "checkpoint"
|
||||
return persisted
|
||||
|
||||
fake_cache = FakePersistentCache()
|
||||
monkeypatch.setattr(model_scanner, "get_persistent_cache", lambda: fake_cache)
|
||||
|
||||
ws_stub = RecordingWebSocketManager()
|
||||
monkeypatch.setattr(model_scanner, "ws_manager", ws_stub)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
|
||||
loaded = await scanner._load_persisted_cache("checkpoints")
|
||||
assert loaded is True
|
||||
|
||||
cache = await scanner.get_cached_data()
|
||||
types_by_path = {item["file_path"]: item.get("model_type") for item in cache.raw_data}
|
||||
|
||||
assert types_by_path[normalized_checkpoint_file] == "checkpoint"
|
||||
assert types_by_path[normalized_unet_file] == "diffusion_model"
|
||||
|
||||
assert ws_stub.payloads
|
||||
assert ws_stub.payloads[-1]["stage"] == "loading_cache"
|
||||
Reference in New Issue
Block a user