Compare commits

..

4 Commits

Author SHA1 Message Date
Will Miao
2eef629821 fix(checkpoints): singleflight pending hash calculation 2026-04-23 11:36:32 +08:00
Will Miao
658a04736d fix(recipes): save widget checkpoint metadata as dict 2026-04-23 11:20:20 +08:00
Will Miao
ef7f677933 chore(skills): add lora manager runtime context 2026-04-23 09:42:47 +08:00
Will Miao
63f0942452 fix(models): classify Anima as diffusion model 2026-04-23 07:35:34 +08:00
11 changed files with 885 additions and 5 deletions

View File

@@ -0,0 +1,69 @@
---
name: lora-manager-runtime-context
description: Inspect ComfyUI LoRA Manager runtime configuration and local diagnostic state. Use when debugging LoRA Manager issues that require locating or reading settings.json, active library paths, model metadata JSON sidecars, recipe metadata JSON files, example image folders, SQLite caches, symlink maps, download history, aria2 state, or other cache files under the LoRA Manager user config directory.
---
# LoRA Manager Runtime Context
## Core Rules
- Treat runtime state as local user data. Prefer read-only inspection unless the user explicitly asks for mutation.
- Never print secret-like settings values. Redact keys containing `key`, `token`, `secret`, `password`, `auth`, or `credential`, including `civitai_api_key`.
- Resolve paths from the runtime configuration before guessing. In this environment the settings file is normally `/home/miao/.config/ComfyUI-LoRA-Manager/settings.json`, but portable settings can override this through the repository `settings.json`.
- Use the active library when selecting per-library caches and paths. Read `active_library` from settings; fall back to `default` if missing.
- Normalize and expand `~` before comparing paths. Symlinks are common in this repo.
## Quick Start
Use the bundled helper for a safe first pass:
```bash
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py summary
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py caches
```
The script redacts sensitive settings, opens SQLite databases read-only, and reports inaccessible or locked databases as warnings.
For focused checks:
```bash
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py recipes
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py model --path /path/to/model.safetensors
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py sqlite --db /path/to/cache.sqlite --limit 3
```
## Runtime Path Rules
- Settings directory: use `py/utils/settings_paths.py`. Default platform path is `platformdirs.user_config_dir("ComfyUI-LoRA-Manager", appauthor=False)`.
- Settings file: `<settings_dir>/settings.json`.
- Cache root: `<settings_dir>/cache`.
- Canonical cache files:
- Model cache: `cache/model/<active_library>.sqlite`.
- Recipe cache: `cache/recipe/<active_library>.sqlite`.
- Model update cache: `cache/model_update/<active_library>.sqlite`.
- Recipe FTS: `cache/fts/recipe_fts.sqlite`.
- Tag FTS: `cache/fts/tag_fts.sqlite`.
- Symlink map: `cache/symlink/symlink_map.json`.
- Download history: `cache/download_history/downloaded_versions.sqlite`.
- aria2 state: `cache/aria2/downloads.json`.
- Legacy cache locations may exist; prefer canonical paths unless diagnosing migrations.
## Data Location Rules
- Model roots come from `settings.folder_paths` and the active library payload under `settings.libraries[active_library]`.
- Model metadata JSON sidecars live next to the model file as `<model basename>.metadata.json`.
- Recipes root is `settings.recipes_path` when it is a non-empty string. If empty, use the first configured LoRA root plus `/recipes`.
- Recipe JSON files are named `*.recipe.json` under the recipes root and may be nested in folders.
- Example image root is `settings.example_images_path`.
- If multiple libraries are configured, example images are stored under `<example_images_path>/<sanitized_library>/<sha256>/`; otherwise they are under `<example_images_path>/<sha256>/`.
## Useful Cache Tables
- Model cache: `models`, `model_tags`, `hash_index`, `excluded_models`.
- Recipe cache: `recipes`, `cache_metadata`.
- Model update cache: `model_update_status`, `model_update_versions`.
- Tag FTS cache: `tags`, `fts_metadata`, plus FTS internal tables.
- Recipe FTS cache: `recipe_rowid`, `fts_metadata`, plus FTS internal tables.
- Download history: `downloaded_model_versions`.
Prefer querying only counts, schema, and a few sample rows unless the user asks for full output.

View File

@@ -0,0 +1,4 @@
interface:
display_name: "LoRA Manager Runtime Context"
short_description: "Inspect LoRA Manager runtime state"
default_prompt: "Use $lora-manager-runtime-context to inspect LoRA Manager settings, metadata paths, and caches for debugging."

View File

@@ -0,0 +1,381 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
import re
import shutil
import sqlite3
import sys
import tempfile
from pathlib import Path
from typing import Any
SECRET_PATTERN = re.compile(r"(key|token|secret|password|auth|credential)", re.IGNORECASE)
APP_NAME = "ComfyUI-LoRA-Manager"
CACHE_SQLITE = {
"model": ("model", "{library}.sqlite"),
"recipe": ("recipe", "{library}.sqlite"),
"model_update": ("model_update", "{library}.sqlite"),
"recipe_fts": ("fts", "recipe_fts.sqlite"),
"tag_fts": ("fts", "tag_fts.sqlite"),
"download_history": ("download_history", "downloaded_versions.sqlite"),
}
CACHE_JSON = {
"symlink": ("symlink", "symlink_map.json"),
"aria2": ("aria2", "downloads.json"),
}
def main() -> int:
parser = argparse.ArgumentParser(description="Inspect LoRA Manager runtime state read-only.")
subparsers = parser.add_subparsers(dest="command", required=True)
subparsers.add_parser("summary", help="Print redacted settings and resolved paths.")
subparsers.add_parser("caches", help="Print cache paths and SQLite table summaries.")
subparsers.add_parser("recipes", help="Print resolved recipes root and recipe JSON count.")
model_parser = subparsers.add_parser("model", help="Inspect a model metadata sidecar path.")
model_parser.add_argument("--path", required=True, help="Path to a model file or metadata JSON file.")
sqlite_parser = subparsers.add_parser("sqlite", help="Inspect a SQLite database read-only.")
sqlite_parser.add_argument("--db", required=True, help="Path to the SQLite database.")
sqlite_parser.add_argument("--limit", type=int, default=3, help="Rows to sample from each user table.")
args = parser.parse_args()
context = build_context()
if args.command == "summary":
print_json(summary_payload(context))
elif args.command == "caches":
print_json(caches_payload(context))
elif args.command == "recipes":
print_json(recipes_payload(context))
elif args.command == "model":
print_json(model_payload(args.path))
elif args.command == "sqlite":
print_json(sqlite_payload(Path(args.db).expanduser(), args.limit))
return 0
def build_context() -> dict[str, Any]:
settings_path = resolve_settings_path()
settings = load_json(settings_path)
settings_dir = settings_path.parent
active_library = settings.get("active_library") or "default"
safe_library = sanitize_library_name(str(active_library))
cache_root = settings_dir / "cache"
return {
"settings_path": str(settings_path),
"settings_dir": str(settings_dir),
"settings": settings,
"active_library": active_library,
"safe_library": safe_library,
"cache_root": str(cache_root),
"cache_paths": resolve_cache_paths(cache_root, safe_library),
}
def resolve_settings_path() -> Path:
repo_root = find_repo_root()
portable = repo_root / "settings.json"
if portable.exists():
payload = load_json(portable)
if isinstance(payload, dict) and payload.get("use_portable_settings") is True:
return portable
config_home = os.environ.get("XDG_CONFIG_HOME")
if config_home:
return Path(config_home).expanduser() / APP_NAME / "settings.json"
return Path.home() / ".config" / APP_NAME / "settings.json"
def find_repo_root() -> Path:
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "py").is_dir() and (parent / "standalone.py").exists():
return parent
return Path.cwd()
def load_json(path: Path) -> dict[str, Any]:
try:
with path.open("r", encoding="utf-8") as handle:
payload = json.load(handle)
except FileNotFoundError:
return {}
except json.JSONDecodeError as exc:
return {"_error": f"invalid JSON: {exc}"}
except OSError as exc:
return {"_error": f"unreadable: {exc}"}
return payload if isinstance(payload, dict) else {"_error": "JSON root is not an object"}
def resolve_cache_paths(cache_root: Path, library: str) -> dict[str, str]:
paths: dict[str, str] = {}
for name, (subdir, filename) in CACHE_SQLITE.items():
paths[name] = str(cache_root / subdir / filename.format(library=library))
for name, (subdir, filename) in CACHE_JSON.items():
paths[name] = str(cache_root / subdir / filename)
return paths
def summary_payload(context: dict[str, Any]) -> dict[str, Any]:
settings = context["settings"]
return {
"settings_path": context["settings_path"],
"settings_dir": context["settings_dir"],
"active_library": context["active_library"],
"settings": redact(settings),
"model_roots": model_roots(settings, context["active_library"]),
"recipes_root": str(resolve_recipes_root(settings, context["active_library"]) or ""),
"example_images": example_images_payload(settings, context["active_library"]),
"cache_root": context["cache_root"],
"cache_paths": context["cache_paths"],
}
def caches_payload(context: dict[str, Any]) -> dict[str, Any]:
caches: dict[str, Any] = {}
for name, path_string in context["cache_paths"].items():
path = Path(path_string)
item: dict[str, Any] = {
"path": str(path),
"exists": path.exists(),
"size": path.stat().st_size if path.exists() else None,
}
if path.suffix == ".sqlite":
item["sqlite"] = sqlite_payload(path, limit=0)
elif path.suffix == ".json":
item["json"] = json_file_summary(path)
caches[name] = item
return {"active_library": context["active_library"], "caches": caches}
def recipes_payload(context: dict[str, Any]) -> dict[str, Any]:
root = resolve_recipes_root(context["settings"], context["active_library"])
files: list[str] = []
if root and root.exists():
files = [str(path) for path in sorted(root.rglob("*.recipe.json"))[:20]]
return {
"recipes_root": str(root or ""),
"exists": bool(root and root.exists()),
"recipe_json_count": count_recipe_files(root),
"sample_recipe_json": files,
"recipe_cache": context["cache_paths"].get("recipe"),
}
def model_payload(raw_path: str) -> dict[str, Any]:
path = Path(raw_path).expanduser()
metadata_path = path if path.name.endswith(".metadata.json") else path.with_suffix(".metadata.json")
payload = {
"input_path": str(path),
"metadata_path": str(metadata_path),
"model_exists": path.exists(),
"metadata_exists": metadata_path.exists(),
}
if metadata_path.exists():
data = load_json(metadata_path)
payload["metadata_summary"] = redact(summarize_value(data))
return payload
def sqlite_payload(path: Path, limit: int = 3, allow_copy: bool = True) -> dict[str, Any]:
result: dict[str, Any] = {"path": str(path), "exists": path.exists(), "tables": {}}
if not path.exists():
return result
try:
conn = connect_sqlite_readonly(path)
except sqlite3.Error as exc:
result["error"] = str(exc)
return result
try:
table_rows = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
).fetchall()
for table_row in table_rows:
table = table_row["name"]
columns = [
row["name"]
for row in conn.execute(f"PRAGMA table_info({quote_identifier(table)})").fetchall()
]
table_info: dict[str, Any] = {"columns": columns}
try:
table_info["count"] = conn.execute(
f"SELECT COUNT(*) FROM {quote_identifier(table)}"
).fetchone()[0]
except sqlite3.Error as exc:
table_info["count_error"] = str(exc)
if limit > 0 and columns and not is_internal_sqlite_table(table):
try:
rows = conn.execute(
f"SELECT * FROM {quote_identifier(table)} LIMIT ?", (limit,)
).fetchall()
table_info["sample"] = [redact(dict(row)) for row in rows]
except sqlite3.Error as exc:
table_info["sample_error"] = str(exc)
result["tables"][table] = table_info
except sqlite3.Error as exc:
fallback = sqlite_copy_payload(path, limit, str(exc)) if allow_copy else None
if fallback is not None:
result.update(fallback)
else:
result["error"] = str(exc)
finally:
conn.close()
return result
def connect_sqlite_readonly(path: Path) -> sqlite3.Connection:
errors: list[str] = []
for query in ("mode=ro", "mode=ro&immutable=1"):
try:
conn = sqlite3.connect(f"file:{path}?{query}", uri=True)
conn.row_factory = sqlite3.Row
return conn
except sqlite3.Error as exc:
errors.append(f"{query}: {exc}")
raise sqlite3.OperationalError("; ".join(errors))
def sqlite_copy_payload(path: Path, limit: int, original_error: str) -> dict[str, Any] | None:
try:
with tempfile.TemporaryDirectory(prefix="lm-cache-inspect-") as temp_dir:
copy_path = Path(temp_dir) / path.name
shutil.copy2(path, copy_path)
payload = sqlite_payload(copy_path, limit, allow_copy=False)
payload["path"] = str(path)
payload["inspected_copy"] = True
payload["original_error"] = original_error
return payload
except Exception:
return None
def json_file_summary(path: Path) -> dict[str, Any]:
if not path.exists():
return {"exists": False}
data = load_json(path)
return {"exists": True, "summary": redact(summarize_value(data))}
def model_roots(settings: dict[str, Any], active_library: str) -> dict[str, list[str]]:
roots: dict[str, list[str]] = {}
sources = [settings]
library = settings.get("libraries", {}).get(active_library)
if isinstance(library, dict):
sources.insert(0, library)
for source in sources:
folder_paths = source.get("folder_paths")
if isinstance(folder_paths, dict):
for key, value in folder_paths.items():
roots.setdefault(key, []).extend(normalize_path_list(value))
for default_key, folder_key in (
("default_lora_root", "loras"),
("default_checkpoint_root", "checkpoints"),
("default_embedding_root", "embeddings"),
("default_unet_root", "unet"),
):
value = settings.get(default_key)
if isinstance(value, str) and value:
roots.setdefault(folder_key, []).append(expand_path(value))
return {key: dedupe(values) for key, values in roots.items()}
def resolve_recipes_root(settings: dict[str, Any], active_library: str) -> Path | None:
recipes_path = settings.get("recipes_path")
library = settings.get("libraries", {}).get(active_library)
if isinstance(library, dict) and isinstance(library.get("recipes_path"), str):
recipes_path = library["recipes_path"] or recipes_path
if isinstance(recipes_path, str) and recipes_path.strip():
return Path(expand_path(recipes_path.strip()))
lora_roots = model_roots(settings, active_library).get("loras") or []
return Path(lora_roots[0]) / "recipes" if lora_roots else None
def example_images_payload(settings: dict[str, Any], active_library: str) -> dict[str, Any]:
root = settings.get("example_images_path") or ""
libraries = settings.get("libraries")
library_count = len(libraries) if isinstance(libraries, dict) else 0
scoped = library_count > 1
root_path = Path(expand_path(root)) if isinstance(root, str) and root else None
library_root = root_path / sanitize_library_name(active_library) if root_path and scoped else root_path
return {
"root": str(root_path or ""),
"uses_library_scoped_folders": scoped,
"library_root": str(library_root or ""),
}
def count_recipe_files(root: Path | None) -> int:
if not root or not root.exists():
return 0
return sum(1 for _ in root.rglob("*.recipe.json"))
def normalize_path_list(value: Any) -> list[str]:
if isinstance(value, str):
return [expand_path(value)] if value else []
if isinstance(value, list):
return [expand_path(item) for item in value if isinstance(item, str) and item]
return []
def expand_path(value: str) -> str:
return str(Path(value).expanduser().resolve(strict=False))
def sanitize_library_name(name: str) -> str:
safe = re.sub(r"[^A-Za-z0-9_.-]", "_", name or "default")
return safe or "default"
def dedupe(values: list[str]) -> list[str]:
seen: set[str] = set()
result: list[str] = []
for value in values:
if value not in seen:
result.append(value)
seen.add(value)
return result
def redact(value: Any, key: str = "") -> Any:
if key and SECRET_PATTERN.search(key):
return "<redacted>"
if isinstance(value, dict):
return {str(k): redact(v, str(k)) for k, v in value.items()}
if isinstance(value, list):
return [redact(item) for item in value]
return value
def summarize_value(value: Any) -> Any:
if isinstance(value, dict):
return {key: summarize_value(item) for key, item in value.items()}
if isinstance(value, list):
return {
"type": "array",
"length": len(value),
"first": summarize_value(value[0]) if value else None,
}
return value
def quote_identifier(identifier: str) -> str:
return '"' + identifier.replace('"', '""') + '"'
def is_internal_sqlite_table(table: str) -> bool:
return table.startswith("sqlite_") or table.endswith(("_data", "_idx", "_docsize", "_config", "_content"))
def print_json(payload: Any) -> None:
json.dump(payload, sys.stdout, indent=2, ensure_ascii=False)
sys.stdout.write("\n")
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -780,8 +780,10 @@ NODE_EXTRACTORS = {
"GGUFLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
"DiffusionModelLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
"CheckpointLoaderKJ": CheckpointLoaderExtractor, # KJNodes
"CheckpointLoaderLM": CheckpointLoaderExtractor, # LoRA Manager
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
"UNETLoaderLM": UNETLoaderExtractor, # LoRA Manager
"LoraLoader": LoraLoaderExtractor,
"LoraLoaderLM": LoraLoaderManagerExtractor,
"RgthreePowerLoraLoader": RgthreePowerLoraLoaderExtractor,

View File

@@ -1,3 +1,4 @@
import asyncio
import json
import logging
import os
@@ -36,6 +37,9 @@ class CheckpointScanner(ModelScanner):
file_extensions=file_extensions,
hash_index=ModelHashIndex(),
)
if not hasattr(self, "_hash_calculation_lock"):
self._hash_calculation_lock = asyncio.Lock()
self._hash_calculation_tasks: dict[str, asyncio.Task[Optional[str]]] = {}
async def _create_default_metadata(
self, file_path: str
@@ -88,7 +92,7 @@ class CheckpointScanner(ModelScanner):
return None
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
"""Calculate hash for a checkpoint on-demand.
"""Calculate hash for a checkpoint on-demand with per-file singleflight.
Args:
file_path: Path to the model file
@@ -96,14 +100,65 @@ class CheckpointScanner(ModelScanner):
Returns:
SHA256 hash string, or None if calculation failed
"""
from ..utils.file_utils import calculate_sha256
try:
real_path = os.path.realpath(file_path)
if not os.path.exists(real_path):
logger.error(f"File not found for hash calculation: {file_path}")
return None
metadata, _ = await MetadataManager.load_metadata(
file_path, self.model_class
)
if (
metadata is not None
and metadata.hash_status == "completed"
and metadata.sha256
):
return metadata.sha256
async with self._hash_calculation_lock:
metadata, _ = await MetadataManager.load_metadata(
file_path, self.model_class
)
if (
metadata is not None
and metadata.hash_status == "completed"
and metadata.sha256
):
return metadata.sha256
task = self._hash_calculation_tasks.get(real_path)
if task is None:
task = asyncio.create_task(
self._run_hash_calculation_task(file_path, real_path)
)
self._hash_calculation_tasks[real_path] = task
return await asyncio.shield(task)
except Exception as e:
logger.error(f"Error calculating hash for {file_path}: {e}")
return None
async def _run_hash_calculation_task(
self, file_path: str, real_path: str
) -> Optional[str]:
"""Run a hash calculation task and remove it from the in-flight map."""
try:
return await self._calculate_hash_for_model_uncached(file_path, real_path)
finally:
task = asyncio.current_task()
async with self._hash_calculation_lock:
if self._hash_calculation_tasks.get(real_path) is task:
del self._hash_calculation_tasks[real_path]
async def _calculate_hash_for_model_uncached(
self, file_path: str, real_path: str
) -> Optional[str]:
"""Calculate hash for a checkpoint without checking in-flight tasks."""
from ..utils.file_utils import calculate_sha256
try:
# Load current metadata
metadata, should_skip = await MetadataManager.load_metadata(
file_path, self.model_class

View File

@@ -1815,6 +1815,15 @@ class RecipeScanner:
return await self._lora_scanner.get_model_info_by_name(name)
async def get_local_checkpoint(self, name: str) -> Optional[Dict[str, Any]]:
"""Lookup a local checkpoint model by name."""
checkpoint_scanner = getattr(self, "_checkpoint_scanner", None)
if not checkpoint_scanner or not name:
return None
return await checkpoint_scanner.get_model_info_by_name(name)
async def get_paginated_data(
self,
page: int,

View File

@@ -508,6 +508,10 @@ class RecipePersistenceService:
most_common_base_model = (
max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else ""
)
checkpoint_entry = await self._build_widget_checkpoint_entry(
recipe_scanner,
metadata.get("checkpoint"),
)
recipe_data = {
"id": recipe_id,
@@ -515,9 +519,8 @@ class RecipePersistenceService:
"title": recipe_name,
"modified": time.time(),
"created_date": time.time(),
"base_model": most_common_base_model,
"base_model": most_common_base_model or (checkpoint_entry or {}).get("baseModel", ""),
"loras": loras_data,
"checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")),
"gen_params": {
key: value
for key, value in metadata.items()
@@ -525,6 +528,8 @@ class RecipePersistenceService:
},
"loras_stack": lora_stack,
}
if checkpoint_entry:
recipe_data["checkpoint"] = checkpoint_entry
json_filename = f"{recipe_id}.recipe.json"
json_path = os.path.join(recipes_dir, json_filename)
@@ -546,6 +551,91 @@ class RecipePersistenceService:
# Helper methods ---------------------------------------------------
async def _build_widget_checkpoint_entry(
self,
recipe_scanner,
checkpoint_raw: Any,
) -> Optional[dict[str, Any]]:
"""Build recipe checkpoint metadata from widget generation metadata."""
if isinstance(checkpoint_raw, dict):
return self._sanitize_checkpoint_entry(checkpoint_raw)
if not isinstance(checkpoint_raw, str):
return None
checkpoint_name = checkpoint_raw.strip()
if not checkpoint_name:
return None
file_name = os.path.splitext(os.path.basename(checkpoint_name))[0]
checkpoint_info = await self._lookup_widget_checkpoint(
recipe_scanner,
checkpoint_name,
)
if not checkpoint_info:
return {
"type": "checkpoint",
"name": checkpoint_name,
"file_name": file_name,
"hash": "",
}
civitai = checkpoint_info.get("civitai") or {}
civitai_model = civitai.get("model") or {}
file_path = checkpoint_info.get("file_path") or checkpoint_info.get("path") or ""
cached_file_name = (
checkpoint_info.get("file_name")
or (os.path.splitext(os.path.basename(file_path))[0] if file_path else "")
or file_name
)
return {
"type": "checkpoint",
"modelId": civitai_model.get("id", 0),
"modelVersionId": civitai.get("id", 0),
"name": civitai_model.get("name") or checkpoint_info.get("model_name") or checkpoint_name,
"version": civitai.get("name", ""),
"hash": (checkpoint_info.get("sha256") or checkpoint_info.get("hash") or "").lower(),
"file_name": cached_file_name,
"modelName": civitai_model.get("name", ""),
"modelVersionName": civitai.get("name", ""),
"baseModel": checkpoint_info.get("base_model") or civitai.get("baseModel", ""),
}
async def _lookup_widget_checkpoint(
self,
recipe_scanner,
checkpoint_name: str,
) -> Optional[dict[str, Any]]:
lookup = getattr(recipe_scanner, "get_local_checkpoint", None)
if not callable(lookup):
return None
candidates = []
for candidate in (
checkpoint_name,
os.path.basename(checkpoint_name),
os.path.splitext(os.path.basename(checkpoint_name))[0],
):
if candidate and candidate not in candidates:
candidates.append(candidate)
for candidate in candidates:
try:
checkpoint_info = await lookup(candidate)
except Exception as exc:
self._logger.debug(
"Failed to lookup checkpoint %s while saving widget recipe: %s",
candidate,
exc,
)
continue
if checkpoint_info:
return checkpoint_info
return None
def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]:
"""Pull a checkpoint entry from various metadata locations."""

View File

@@ -100,6 +100,7 @@ DEFAULT_PRIORITY_TAG_CONFIG = {
# These model types are incorrectly labeled as "checkpoint" by CivitAI but are actually diffusion models
DIFFUSION_MODEL_BASE_MODELS = frozenset(
[
"Anima",
"ZImageTurbo",
"ZImageBase",
"Wan Video 1.3B t2v",

View File

@@ -354,3 +354,33 @@ def test_lora_manager_cache_updates_when_loras_removed(metadata_registry):
metadata = metadata_registry.get_metadata("prompt3")
assert "lora_node" not in metadata[LORAS]
def test_lora_manager_checkpoint_and_unet_loaders_extract_models(metadata_registry):
metadata_registry.start_collection("prompt1")
metadata_registry.record_node_execution(
"checkpoint_node",
"CheckpointLoaderLM",
{"ckpt_name": ["models/checkpoint.safetensors"]},
None,
)
metadata_registry.record_node_execution(
"unet_node",
"UNETLoaderLM",
{"unet_name": ["models/diffusion_model.safetensors"], "weight_dtype": ["default"]},
None,
)
metadata = metadata_registry.get_metadata("prompt1")
assert metadata[MODELS]["checkpoint_node"] == {
"name": "models/checkpoint.safetensors",
"type": "checkpoint",
"node_id": "checkpoint_node",
}
assert metadata[MODELS]["unet_node"] == {
"name": "models/diffusion_model.safetensors",
"type": "checkpoint",
"node_id": "unet_node",
}

View File

@@ -1,5 +1,6 @@
"""Tests for checkpoint lazy hash calculation feature."""
import asyncio
import json
import os
from pathlib import Path
@@ -199,6 +200,160 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp
mock_calc.assert_not_called(), "Should not recalculate if already completed"
@pytest.mark.asyncio
async def test_calculate_hash_for_model_singleflight_same_file(
tmp_path: Path, monkeypatch
):
"""Concurrent calls for the same checkpoint should share one SHA256 task."""
checkpoints_root = tmp_path / "checkpoints"
checkpoints_root.mkdir()
checkpoint_file = checkpoints_root / "test_model.safetensors"
checkpoint_file.write_text("fake content", encoding="utf-8")
normalized_root = _normalize(checkpoints_root)
normalized_file = _normalize(checkpoint_file)
real_file = os.path.realpath(normalized_file)
monkeypatch.setattr(
model_scanner.config,
"base_models_roots",
[normalized_root],
raising=False,
)
scanner = CheckpointScanner()
metadata = await scanner._create_default_metadata(normalized_file)
assert metadata is not None
calls = []
async def fake_calculate_sha256(file_path: str) -> str:
calls.append(file_path)
await asyncio.sleep(0.01)
return "a" * 64
with patch(
"py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
):
results = await asyncio.gather(
*[scanner.calculate_hash_for_model(normalized_file) for _ in range(8)]
)
assert calls == [real_file]
assert results == ["a" * 64] * 8
assert scanner._hash_calculation_tasks == {}
@pytest.mark.asyncio
async def test_calculate_hash_for_model_cleans_task_after_failure_and_retries(
tmp_path: Path, monkeypatch
):
"""A failed in-flight task should be removed so later calls can retry."""
checkpoints_root = tmp_path / "checkpoints"
checkpoints_root.mkdir()
checkpoint_file = checkpoints_root / "retry_model.safetensors"
checkpoint_file.write_text("fake content", encoding="utf-8")
normalized_root = _normalize(checkpoints_root)
normalized_file = _normalize(checkpoint_file)
monkeypatch.setattr(
model_scanner.config,
"base_models_roots",
[normalized_root],
raising=False,
)
scanner = CheckpointScanner()
metadata = await scanner._create_default_metadata(normalized_file)
assert metadata is not None
attempts = 0
async def fake_calculate_sha256(_file_path: str) -> str:
nonlocal attempts
attempts += 1
if attempts == 1:
raise RuntimeError("hash failed")
return "b" * 64
with patch(
"py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
):
assert await scanner.calculate_hash_for_model(normalized_file) is None
assert scanner._hash_calculation_tasks == {}
hash_result = await scanner.calculate_hash_for_model(normalized_file)
assert hash_result == "b" * 64
assert attempts == 2
assert scanner._hash_calculation_tasks == {}
@pytest.mark.asyncio
async def test_calculate_hash_for_model_uses_separate_tasks_for_different_files(
tmp_path: Path, monkeypatch
):
"""Different checkpoint files should not share the same in-flight task."""
checkpoints_root = tmp_path / "checkpoints"
checkpoints_root.mkdir()
checkpoint_files = [
checkpoints_root / "model_a.safetensors",
checkpoints_root / "model_b.safetensors",
]
for checkpoint_file in checkpoint_files:
checkpoint_file.write_text(
f"fake content for {checkpoint_file.name}", encoding="utf-8"
)
normalized_root = _normalize(checkpoints_root)
normalized_files = [
_normalize(checkpoint_file) for checkpoint_file in checkpoint_files
]
real_files = {os.path.realpath(file_path) for file_path in normalized_files}
monkeypatch.setattr(
model_scanner.config,
"base_models_roots",
[normalized_root],
raising=False,
)
scanner = CheckpointScanner()
for normalized_file in normalized_files:
metadata = await scanner._create_default_metadata(normalized_file)
assert metadata is not None
calls = []
hashes_by_path = {
os.path.realpath(normalized_files[0]): "c" * 64,
os.path.realpath(normalized_files[1]): "d" * 64,
}
async def fake_calculate_sha256(file_path: str) -> str:
calls.append(file_path)
await asyncio.sleep(0.01)
return hashes_by_path[file_path]
with patch(
"py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
):
results = await asyncio.gather(
*[
scanner.calculate_hash_for_model(file_path)
for file_path in normalized_files
]
)
assert set(calls) == real_files
assert len(calls) == 2
assert set(results) == {"c" * 64, "d" * 64}
assert scanner._hash_calculation_tasks == {}
@pytest.mark.asyncio
async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch):
"""Test that calculate_hash_for_model creates pending metadata when it is missing."""

View File

@@ -491,6 +491,9 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
async def get_local_lora(self, name): # pragma: no cover - no lookups expected
return None
async def get_local_checkpoint(self, name):
return None
async def add_recipe(self, recipe_data):
self.added.append(recipe_data)
@@ -518,9 +521,90 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
assert stored["loras"] == []
assert stored["title"] == "recipe"
assert stored["checkpoint"] == {
"type": "checkpoint",
"name": "base-model.safetensors",
"file_name": "base-model",
"hash": "",
}
assert scanner.added and scanner.added[0]["loras"] == []
@pytest.mark.asyncio
async def test_save_recipe_from_widget_enriches_checkpoint_from_local_cache(tmp_path):
exif_utils = DummyExifUtils()
class DummyScanner:
def __init__(self, root):
self.recipes_dir = str(root)
self.added = []
self.checkpoint_queries = []
async def get_local_lora(self, name): # pragma: no cover - no loras
return None
async def get_local_checkpoint(self, name):
self.checkpoint_queries.append(name)
if name != "matched-model":
return None
return {
"file_name": "matched-model",
"file_path": "/models/checkpoints/folder/matched-model.safetensors",
"sha256": "ABC123",
"base_model": "Illustrious",
"civitai": {
"id": 456,
"name": "v1.0",
"baseModel": "Illustrious",
"model": {
"id": 123,
"name": "Matched Model",
},
},
}
async def add_recipe(self, recipe_data):
self.added.append(recipe_data)
scanner = DummyScanner(tmp_path)
service = RecipePersistenceService(
exif_utils=exif_utils,
card_preview_width=512,
logger=logging.getLogger("test"),
)
result = await service.save_recipe_from_widget(
recipe_scanner=scanner,
metadata={
"loras": "",
"checkpoint": "folder/matched-model.safetensors",
"prompt": "a calm scene",
},
image_bytes=b"image-bytes",
)
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert scanner.checkpoint_queries == [
"folder/matched-model.safetensors",
"matched-model.safetensors",
"matched-model",
]
assert stored["base_model"] == "Illustrious"
assert stored["checkpoint"] == {
"type": "checkpoint",
"modelId": 123,
"modelVersionId": 456,
"name": "Matched Model",
"version": "v1.0",
"hash": "abc123",
"file_name": "matched-model",
"modelName": "Matched Model",
"modelVersionName": "v1.0",
"baseModel": "Illustrious",
}
@pytest.mark.asyncio
async def test_move_recipe_updates_paths(tmp_path):
exif_utils = DummyExifUtils()