mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-04-12 05:42:14 -03:00
fix(usage): resolve checkpoint hashes from disk
This commit is contained in:
@@ -105,12 +105,18 @@ class CheckpointScanner(ModelScanner):
|
||||
return None
|
||||
|
||||
# Load current metadata
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
metadata, should_skip = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if metadata is None:
|
||||
if should_skip:
|
||||
logger.error(f"Invalid metadata found for {file_path}")
|
||||
return None
|
||||
created_metadata = await self._create_default_metadata(file_path)
|
||||
if created_metadata is None:
|
||||
logger.error(f"No metadata found for {file_path}")
|
||||
return None
|
||||
metadata = created_metadata
|
||||
|
||||
# Check if hash is already calculated
|
||||
if metadata.hash_status == "completed" and metadata.sha256:
|
||||
|
||||
@@ -29,6 +29,18 @@ if not standalone_mode:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_CHECKPOINT_EXTENSIONS = {
|
||||
".ckpt",
|
||||
".pt",
|
||||
".pt2",
|
||||
".bin",
|
||||
".pth",
|
||||
".safetensors",
|
||||
".pkl",
|
||||
".sft",
|
||||
".gguf",
|
||||
}
|
||||
|
||||
class UsageStats:
|
||||
"""Track usage statistics for models and save to JSON"""
|
||||
|
||||
@@ -335,6 +347,55 @@ class UsageStats:
|
||||
|
||||
return None
|
||||
|
||||
async def _find_checkpoint_file_on_disk(self, checkpoint_scanner, model_name: str):
|
||||
"""Search checkpoint roots directly for a matching file.
|
||||
|
||||
This is used when usage tracking sees a checkpoint name before the cache has
|
||||
been refreshed. The lookup is intentionally exact: we only match the model
|
||||
basename and supported checkpoint extensions.
|
||||
"""
|
||||
get_model_roots = getattr(checkpoint_scanner, "get_model_roots", None)
|
||||
if not callable(get_model_roots):
|
||||
return None
|
||||
|
||||
roots = [root for root in get_model_roots() if root]
|
||||
if not roots:
|
||||
return None
|
||||
|
||||
supported_extensions = getattr(
|
||||
checkpoint_scanner, "file_extensions", _DEFAULT_CHECKPOINT_EXTENSIONS
|
||||
)
|
||||
if not isinstance(supported_extensions, (set, frozenset, list, tuple)):
|
||||
supported_extensions = _DEFAULT_CHECKPOINT_EXTENSIONS
|
||||
|
||||
normalized_name = self._normalize_model_lookup_name(model_name)
|
||||
matches: list[str] = []
|
||||
|
||||
for root_path in roots:
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
for dirpath, _dirnames, filenames in os.walk(root_path):
|
||||
for filename in filenames:
|
||||
extension = os.path.splitext(filename)[1].lower()
|
||||
if extension not in supported_extensions:
|
||||
continue
|
||||
|
||||
if os.path.splitext(filename)[0] != normalized_name:
|
||||
continue
|
||||
|
||||
matches.append(os.path.join(dirpath, filename).replace(os.sep, "/"))
|
||||
|
||||
if len(matches) > 1:
|
||||
logger.warning(
|
||||
"Multiple checkpoint files matched '%s'; skipping usage tracking: %s",
|
||||
normalized_name,
|
||||
", ".join(matches),
|
||||
)
|
||||
return None
|
||||
|
||||
return matches[0] if matches else None
|
||||
|
||||
async def _resolve_checkpoint_hash(self, checkpoint_scanner, model_name: str):
|
||||
"""Resolve a checkpoint hash, calculating pending hashes on demand when needed."""
|
||||
model_filename = self._normalize_model_lookup_name(model_name)
|
||||
@@ -343,27 +404,49 @@ class UsageStats:
|
||||
return model_hash
|
||||
|
||||
cached_entry = await self._find_cached_checkpoint_entry(checkpoint_scanner, model_name)
|
||||
if not cached_entry:
|
||||
logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
|
||||
return None
|
||||
|
||||
if cached_entry:
|
||||
cached_hash = cached_entry.get("sha256")
|
||||
if cached_hash:
|
||||
return cached_hash
|
||||
|
||||
if cached_entry.get("hash_status") == "pending":
|
||||
calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
|
||||
file_path = cached_entry.get("file_path")
|
||||
if callable(calculate_hash) and file_path:
|
||||
calculated_hash = await calculate_hash(file_path)
|
||||
if calculated_hash:
|
||||
return calculated_hash
|
||||
hash_status = cached_entry.get("hash_status")
|
||||
if hash_status and hash_status != "pending":
|
||||
logger.warning(
|
||||
f"Failed to calculate pending hash for checkpoint '{model_filename}', skipping usage tracking"
|
||||
"Checkpoint '%s' has hash_status=%s; skipping usage tracking",
|
||||
model_filename,
|
||||
hash_status,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
|
||||
file_path = cached_entry.get("file_path") if cached_entry else None
|
||||
if not file_path:
|
||||
file_path = await self._find_checkpoint_file_on_disk(
|
||||
checkpoint_scanner, model_name
|
||||
)
|
||||
|
||||
if not file_path:
|
||||
logger.warning(
|
||||
f"No hash found for checkpoint '{model_filename}', skipping usage tracking"
|
||||
)
|
||||
return None
|
||||
|
||||
calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
|
||||
if not callable(calculate_hash):
|
||||
logger.warning("Checkpoint scanner not available for usage tracking")
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
"Calculating hash for checkpoint '%s' from %s",
|
||||
model_filename,
|
||||
file_path,
|
||||
)
|
||||
calculated_hash = await calculate_hash(file_path)
|
||||
if calculated_hash:
|
||||
return calculated_hash
|
||||
|
||||
logger.warning(
|
||||
f"Failed to calculate hash for checkpoint '{model_filename}', skipping usage tracking"
|
||||
)
|
||||
return None
|
||||
|
||||
async def _process_checkpoints(self, models_data, today_date):
|
||||
|
||||
@@ -199,6 +199,47 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp
|
||||
mock_calc.assert_not_called(), "Should not recalculate if already completed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch):
|
||||
"""Test that calculate_hash_for_model creates pending metadata when it is missing."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_file = checkpoints_root / "bootstrap_model.gguf"
|
||||
checkpoint_file.write_text("fake content for hashing", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_file = _normalize(checkpoint_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"checkpoints_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
|
||||
hash_result = await scanner.calculate_hash_for_model(normalized_file)
|
||||
|
||||
assert hash_result is not None, "Hash calculation should succeed without existing metadata"
|
||||
assert len(hash_result) == 64, "SHA256 should be 64 hex characters"
|
||||
assert scanner.get_hash_by_filename("bootstrap_model") == hash_result
|
||||
|
||||
metadata_file = checkpoints_root / "bootstrap_model.metadata.json"
|
||||
with open(metadata_file, "r", encoding="utf-8") as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
assert saved_data.get("sha256") == hash_result, "sha256 should be updated"
|
||||
assert saved_data.get("hash_status") == "completed", "hash_status should be 'completed'"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_all_pending_hashes(tmp_path: Path, monkeypatch):
|
||||
"""Test bulk hash calculation for all pending checkpoints."""
|
||||
|
||||
@@ -194,6 +194,84 @@ async def test_usage_stats_calculates_pending_checkpoint_hash_on_demand(tmp_path
|
||||
await _finalize_usage_stats(tasks)
|
||||
|
||||
|
||||
async def test_usage_stats_skips_failed_checkpoint_hash_retry(tmp_path, monkeypatch):
|
||||
stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)
|
||||
|
||||
metadata_payload = {
|
||||
"models": {
|
||||
"1": {"type": "checkpoint", "name": "failed_model.safetensors"},
|
||||
},
|
||||
"loras": {},
|
||||
}
|
||||
|
||||
checkpoint_cache = SimpleNamespace(
|
||||
raw_data=[
|
||||
{
|
||||
"file_name": "failed_model",
|
||||
"model_name": "failed_model",
|
||||
"file_path": "/models/failed_model.safetensors",
|
||||
"sha256": "",
|
||||
"hash_status": "failed",
|
||||
}
|
||||
]
|
||||
)
|
||||
checkpoint_scanner = SimpleNamespace(
|
||||
get_hash_by_filename=lambda name: None,
|
||||
get_cached_data=AsyncMock(return_value=checkpoint_cache),
|
||||
calculate_hash_for_model=AsyncMock(return_value=None),
|
||||
)
|
||||
lora_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None)
|
||||
|
||||
monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
|
||||
monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner))
|
||||
|
||||
await stats._process_metadata(metadata_payload)
|
||||
|
||||
checkpoint_scanner.calculate_hash_for_model.assert_not_awaited()
|
||||
assert stats.stats["checkpoints"] == {}
|
||||
|
||||
await _finalize_usage_stats(tasks)
|
||||
|
||||
|
||||
async def test_usage_stats_resolves_manually_copied_checkpoint_from_disk(tmp_path, monkeypatch):
|
||||
stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)
|
||||
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
checkpoint_file = checkpoints_root / "Chroma1-HD-Q8_0.gguf"
|
||||
checkpoint_file.write_text("manual checkpoint content", encoding="utf-8")
|
||||
|
||||
metadata_payload = {
|
||||
"models": {
|
||||
"1": {"type": "checkpoint", "name": "Chroma1-HD-Q8_0"},
|
||||
},
|
||||
"loras": {},
|
||||
}
|
||||
|
||||
checkpoint_cache = SimpleNamespace(raw_data=[])
|
||||
checkpoint_scanner = SimpleNamespace(
|
||||
get_hash_by_filename=lambda name: None,
|
||||
get_cached_data=AsyncMock(return_value=checkpoint_cache),
|
||||
get_model_roots=lambda: [str(checkpoints_root)],
|
||||
file_extensions={".ckpt", ".pt", ".pt2", ".bin", ".pth", ".safetensors", ".pkl", ".sft", ".gguf"},
|
||||
calculate_hash_for_model=AsyncMock(return_value="resolved-hash"),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
|
||||
monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=None))
|
||||
|
||||
await stats._process_metadata(metadata_payload)
|
||||
|
||||
checkpoint_scanner.calculate_hash_for_model.assert_awaited_once_with(
|
||||
str(checkpoint_file).replace(os.sep, "/")
|
||||
)
|
||||
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
assert stats.stats["checkpoints"]["resolved-hash"]["history"][today] == 1
|
||||
|
||||
await _finalize_usage_stats(tasks)
|
||||
|
||||
|
||||
async def test_usage_stats_skips_name_fallback_for_missing_lora_hash(tmp_path, monkeypatch):
|
||||
stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user