mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
Compare commits
4 Commits
a1dff6dd47
...
2eef629821
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2eef629821 | ||
|
|
658a04736d | ||
|
|
ef7f677933 | ||
|
|
63f0942452 |
69
.agents/skills/lora-manager-runtime-context/SKILL.md
Normal file
69
.agents/skills/lora-manager-runtime-context/SKILL.md
Normal file
@@ -0,0 +1,69 @@
|
||||
---
|
||||
name: lora-manager-runtime-context
|
||||
description: Inspect ComfyUI LoRA Manager runtime configuration and local diagnostic state. Use when debugging LoRA Manager issues that require locating or reading settings.json, active library paths, model metadata JSON sidecars, recipe metadata JSON files, example image folders, SQLite caches, symlink maps, download history, aria2 state, or other cache files under the LoRA Manager user config directory.
|
||||
---
|
||||
|
||||
# LoRA Manager Runtime Context
|
||||
|
||||
## Core Rules
|
||||
|
||||
- Treat runtime state as local user data. Prefer read-only inspection unless the user explicitly asks for mutation.
|
||||
- Never print secret-like settings values. Redact keys containing `key`, `token`, `secret`, `password`, `auth`, or `credential`, including `civitai_api_key`.
|
||||
- Resolve paths from the runtime configuration before guessing. The settings file is normally `~/.config/ComfyUI-LoRA-Manager/settings.json` (the platform user config directory), but portable settings can override this through the repository `settings.json`.
|
||||
- Use the active library when selecting per-library caches and paths. Read `active_library` from settings; fall back to `default` if missing.
|
||||
- Normalize and expand `~` before comparing paths. Symlinks are common in this repo.
|
||||
|
||||
## Quick Start
|
||||
|
||||
Use the bundled helper for a safe first pass:
|
||||
|
||||
```bash
|
||||
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py summary
|
||||
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py caches
|
||||
```
|
||||
|
||||
The script redacts sensitive settings, opens SQLite databases read-only, and reports inaccessible or locked databases as warnings.
|
||||
|
||||
For focused checks:
|
||||
|
||||
```bash
|
||||
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py recipes
|
||||
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py model --path /path/to/model.safetensors
|
||||
python .agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py sqlite --db /path/to/cache.sqlite --limit 3
|
||||
```
|
||||
|
||||
## Runtime Path Rules
|
||||
|
||||
- Settings directory: use `py/utils/settings_paths.py`. Default platform path is `platformdirs.user_config_dir("ComfyUI-LoRA-Manager", appauthor=False)`.
|
||||
- Settings file: `<settings_dir>/settings.json`.
|
||||
- Cache root: `<settings_dir>/cache`.
|
||||
- Canonical cache files:
|
||||
- Model cache: `cache/model/<active_library>.sqlite`.
|
||||
- Recipe cache: `cache/recipe/<active_library>.sqlite`.
|
||||
- Model update cache: `cache/model_update/<active_library>.sqlite`.
|
||||
- Recipe FTS: `cache/fts/recipe_fts.sqlite`.
|
||||
- Tag FTS: `cache/fts/tag_fts.sqlite`.
|
||||
- Symlink map: `cache/symlink/symlink_map.json`.
|
||||
- Download history: `cache/download_history/downloaded_versions.sqlite`.
|
||||
- aria2 state: `cache/aria2/downloads.json`.
|
||||
- Legacy cache locations may exist; prefer canonical paths unless diagnosing migrations.
|
||||
|
||||
## Data Location Rules
|
||||
|
||||
- Model roots come from `settings.folder_paths` and the active library payload under `settings.libraries[active_library]`.
|
||||
- Model metadata JSON sidecars live next to the model file as `<model basename>.metadata.json`.
|
||||
- Recipes root is `settings.recipes_path` when it is a non-empty string. If empty, use the first configured LoRA root plus `/recipes`.
|
||||
- Recipe JSON files are named `*.recipe.json` under the recipes root and may be nested in folders.
|
||||
- Example image root is `settings.example_images_path`.
|
||||
- If multiple libraries are configured, example images are stored under `<example_images_path>/<sanitized_library>/<sha256>/`; otherwise they are under `<example_images_path>/<sha256>/`.
|
||||
|
||||
## Useful Cache Tables
|
||||
|
||||
- Model cache: `models`, `model_tags`, `hash_index`, `excluded_models`.
|
||||
- Recipe cache: `recipes`, `cache_metadata`.
|
||||
- Model update cache: `model_update_status`, `model_update_versions`.
|
||||
- Tag FTS cache: `tags`, `fts_metadata`, plus FTS internal tables.
|
||||
- Recipe FTS cache: `recipe_rowid`, `fts_metadata`, plus FTS internal tables.
|
||||
- Download history: `downloaded_model_versions`.
|
||||
|
||||
Prefer querying only counts, schema, and a few sample rows unless the user asks for full output.
|
||||
@@ -0,0 +1,4 @@
|
||||
interface:
|
||||
display_name: "LoRA Manager Runtime Context"
|
||||
short_description: "Inspect LoRA Manager runtime state"
|
||||
default_prompt: "Use $lora-manager-runtime-context to inspect LoRA Manager settings, metadata paths, and caches for debugging."
|
||||
381
.agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py
Executable file
381
.agents/skills/lora-manager-runtime-context/scripts/inspect_runtime_context.py
Executable file
@@ -0,0 +1,381 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sqlite3
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Key names that look secret-bearing; matching values are printed as "<redacted>".
SECRET_PATTERN = re.compile(r"(key|token|secret|password|auth|credential)", re.IGNORECASE)
# Directory name used under the platform user config root (e.g. ~/.config).
APP_NAME = "ComfyUI-LoRA-Manager"
# cache name -> (subdirectory under <settings_dir>/cache, filename template).
# "{library}" is filled with the sanitized active library name.
CACHE_SQLITE = {
    "model": ("model", "{library}.sqlite"),
    "recipe": ("recipe", "{library}.sqlite"),
    "model_update": ("model_update", "{library}.sqlite"),
    "recipe_fts": ("fts", "recipe_fts.sqlite"),
    "tag_fts": ("fts", "tag_fts.sqlite"),
    "download_history": ("download_history", "downloaded_versions.sqlite"),
}
# cache name -> (subdirectory under <settings_dir>/cache, fixed JSON filename).
CACHE_JSON = {
    "symlink": ("symlink", "symlink_map.json"),
    "aria2": ("aria2", "downloads.json"),
}
|
||||
|
||||
|
||||
def main() -> int:
    """CLI entry point: parse arguments and print the requested payload as JSON.

    Returns 0 on success; argparse exits on its own for usage errors.
    """
    parser = argparse.ArgumentParser(description="Inspect LoRA Manager runtime state read-only.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    subparsers.add_parser("summary", help="Print redacted settings and resolved paths.")
    subparsers.add_parser("caches", help="Print cache paths and SQLite table summaries.")
    subparsers.add_parser("recipes", help="Print resolved recipes root and recipe JSON count.")

    model_parser = subparsers.add_parser("model", help="Inspect a model metadata sidecar path.")
    model_parser.add_argument("--path", required=True, help="Path to a model file or metadata JSON file.")

    sqlite_parser = subparsers.add_parser("sqlite", help="Inspect a SQLite database read-only.")
    sqlite_parser.add_argument("--db", required=True, help="Path to the SQLite database.")
    sqlite_parser.add_argument("--limit", type=int, default=3, help="Rows to sample from each user table.")

    args = parser.parse_args()
    # Settings/paths are resolved once up front; all subcommands may use them.
    context = build_context()

    if args.command == "summary":
        payload = summary_payload(context)
    elif args.command == "caches":
        payload = caches_payload(context)
    elif args.command == "recipes":
        payload = recipes_payload(context)
    elif args.command == "model":
        payload = model_payload(args.path)
    else:
        # "sqlite" is the only remaining subcommand (argparse enforces the set).
        payload = sqlite_payload(Path(args.db).expanduser(), args.limit)
    print_json(payload)
    return 0
|
||||
|
||||
|
||||
def build_context() -> dict[str, Any]:
    """Resolve settings, the active library, and all cache paths into one dict.

    Returns a context dict consumed by the payload builders; paths are
    stringified so the context is directly JSON-serializable.
    """
    settings_file = resolve_settings_path()
    config = load_json(settings_file)
    config_dir = settings_file.parent
    # Fall back to "default" when the settings omit an active library.
    library = config.get("active_library") or "default"
    library_safe = sanitize_library_name(str(library))
    caches_dir = config_dir / "cache"
    context: dict[str, Any] = {
        "settings_path": str(settings_file),
        "settings_dir": str(config_dir),
        "settings": config,
        "active_library": library,
        "safe_library": library_safe,
        "cache_root": str(caches_dir),
        "cache_paths": resolve_cache_paths(caches_dir, library_safe),
    }
    return context
|
||||
|
||||
|
||||
def resolve_settings_path() -> Path:
    """Locate settings.json, preferring portable settings over the config dir.

    A repository-level settings.json is used only when it explicitly opts in
    with ``use_portable_settings: true``.
    """
    portable_candidate = find_repo_root() / "settings.json"
    if portable_candidate.exists():
        data = load_json(portable_candidate)
        if isinstance(data, dict) and data.get("use_portable_settings") is True:
            return portable_candidate

    # Otherwise use the XDG config location (or ~/.config when unset).
    xdg_home = os.environ.get("XDG_CONFIG_HOME")
    base = Path(xdg_home).expanduser() if xdg_home else Path.home() / ".config"
    return base / APP_NAME / "settings.json"
|
||||
|
||||
|
||||
def find_repo_root() -> Path:
    """Walk up from this script looking for the repository root.

    The root is recognized by having both a ``py/`` directory and a
    ``standalone.py`` file; when not found, the working directory is returned.
    """
    for ancestor in Path(__file__).resolve().parents:
        if (ancestor / "py").is_dir() and (ancestor / "standalone.py").exists():
            return ancestor
    # Not running from inside the repo checkout; best-effort fallback.
    return Path.cwd()
|
||||
|
||||
|
||||
def load_json(path: Path) -> dict[str, Any]:
    """Read a JSON object from *path*; failures become a dict with ``_error``.

    A missing file yields an empty dict so callers can treat it as "no data".
    """
    try:
        raw = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {}
    except OSError as exc:
        return {"_error": f"unreadable: {exc}"}
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError as exc:
        return {"_error": f"invalid JSON: {exc}"}
    # Only object roots are meaningful for settings/metadata files.
    if isinstance(payload, dict):
        return payload
    return {"_error": "JSON root is not an object"}
|
||||
|
||||
|
||||
def resolve_cache_paths(cache_root: Path, library: str) -> dict[str, str]:
    """Map each known cache name to its absolute file path under *cache_root*.

    SQLite cache filenames are templated on the (sanitized) library name;
    JSON cache filenames are fixed.
    """
    resolved: dict[str, str] = {
        name: str(cache_root / subdir / template.format(library=library))
        for name, (subdir, template) in CACHE_SQLITE.items()
    }
    resolved.update(
        (name, str(cache_root / subdir / filename))
        for name, (subdir, filename) in CACHE_JSON.items()
    )
    return resolved
|
||||
|
||||
|
||||
def summary_payload(context: dict[str, Any]) -> dict[str, Any]:
    """Assemble the redacted "summary" view: settings plus resolved paths."""
    settings = context["settings"]
    active = context["active_library"]
    recipes_root = resolve_recipes_root(settings, active)
    return {
        "settings_path": context["settings_path"],
        "settings_dir": context["settings_dir"],
        "active_library": active,
        # Secrets in settings are masked before printing.
        "settings": redact(settings),
        "model_roots": model_roots(settings, active),
        "recipes_root": str(recipes_root or ""),
        "example_images": example_images_payload(settings, active),
        "cache_root": context["cache_root"],
        "cache_paths": context["cache_paths"],
    }
|
||||
|
||||
|
||||
def caches_payload(context: dict[str, Any]) -> dict[str, Any]:
    """Describe every known cache file: existence, size, and a content summary."""
    summaries: dict[str, Any] = {}
    for name, raw_path in context["cache_paths"].items():
        cache_file = Path(raw_path)
        present = cache_file.exists()
        entry: dict[str, Any] = {
            "path": str(cache_file),
            "exists": present,
            "size": cache_file.stat().st_size if present else None,
        }
        # SQLite caches get a schema/count summary; JSON caches a shape summary.
        suffix = cache_file.suffix
        if suffix == ".sqlite":
            entry["sqlite"] = sqlite_payload(cache_file, limit=0)
        elif suffix == ".json":
            entry["json"] = json_file_summary(cache_file)
        summaries[name] = entry
    return {"active_library": context["active_library"], "caches": summaries}
|
||||
|
||||
|
||||
def recipes_payload(context: dict[str, Any]) -> dict[str, Any]:
    """Report the recipes root, the recipe JSON count, and a small sample."""
    root = resolve_recipes_root(context["settings"], context["active_library"])
    root_exists = bool(root and root.exists())
    sample: list[str] = []
    if root_exists:
        # Cap the sample at 20 paths so huge libraries stay readable.
        sample = [str(found) for found in sorted(root.rglob("*.recipe.json"))[:20]]
    return {
        "recipes_root": str(root or ""),
        "exists": root_exists,
        "recipe_json_count": count_recipe_files(root),
        "sample_recipe_json": sample,
        "recipe_cache": context["cache_paths"].get("recipe"),
    }
|
||||
|
||||
|
||||
def model_payload(raw_path: str) -> dict[str, Any]:
    """Describe a model file and its ``.metadata.json`` sidecar.

    *raw_path* may point either at the model file or directly at the sidecar.
    """
    model_file = Path(raw_path).expanduser()
    if model_file.name.endswith(".metadata.json"):
        sidecar = model_file
    else:
        # Sidecars replace the model extension: model.safetensors -> model.metadata.json
        sidecar = model_file.with_suffix(".metadata.json")
    info: dict[str, Any] = {
        "input_path": str(model_file),
        "metadata_path": str(sidecar),
        "model_exists": model_file.exists(),
        "metadata_exists": sidecar.exists(),
    }
    if sidecar.exists():
        info["metadata_summary"] = redact(summarize_value(load_json(sidecar)))
    return info
|
||||
|
||||
|
||||
def sqlite_payload(path: Path, limit: int = 3, allow_copy: bool = True) -> dict[str, Any]:
    """Summarize a SQLite database read-only: tables, columns, counts, samples.

    Args:
        path: Database file to inspect.
        limit: Sample rows per user table; 0 disables sampling.
        allow_copy: When a read fails mid-inspection, retry on a temp copy.

    Per-table errors are recorded inline (``count_error``/``sample_error``);
    a whole-database read error either falls back to a copy or is reported
    under ``error``.
    """
    present = path.exists()
    result: dict[str, Any] = {"path": str(path), "exists": present, "tables": {}}
    if not present:
        return result

    try:
        conn = connect_sqlite_readonly(path)
    except sqlite3.Error as exc:
        result["error"] = str(exc)
        return result

    try:
        names = [
            row["name"]
            for row in conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
            ).fetchall()
        ]
        for name in names:
            quoted = quote_identifier(name)
            columns = [
                row["name"]
                for row in conn.execute(f"PRAGMA table_info({quoted})").fetchall()
            ]
            info: dict[str, Any] = {"columns": columns}
            try:
                info["count"] = conn.execute(f"SELECT COUNT(*) FROM {quoted}").fetchone()[0]
            except sqlite3.Error as exc:
                info["count_error"] = str(exc)
            # Only sample regular user tables; internal/FTS shadow tables are skipped.
            if limit > 0 and columns and not is_internal_sqlite_table(name):
                try:
                    sample_rows = conn.execute(
                        f"SELECT * FROM {quoted} LIMIT ?", (limit,)
                    ).fetchall()
                except sqlite3.Error as exc:
                    info["sample_error"] = str(exc)
                else:
                    info["sample"] = [redact(dict(row)) for row in sample_rows]
            result["tables"][name] = info
    except sqlite3.Error as exc:
        # A locked or corrupt database: optionally retry against a temp copy.
        fallback = sqlite_copy_payload(path, limit, str(exc)) if allow_copy else None
        if fallback is None:
            result["error"] = str(exc)
        else:
            result.update(fallback)
    finally:
        conn.close()
    return result
|
||||
|
||||
|
||||
def connect_sqlite_readonly(path: Path) -> sqlite3.Connection:
    """Open *path* read-only via a SQLite URI.

    Tries plain ``mode=ro`` first, then ``immutable=1``; raises
    ``sqlite3.OperationalError`` with all attempt errors when both fail.
    """
    failures: list[str] = []
    for params in ("mode=ro", "mode=ro&immutable=1"):
        try:
            connection = sqlite3.connect(f"file:{path}?{params}", uri=True)
            # Row factory lets callers access columns by name.
            connection.row_factory = sqlite3.Row
            return connection
        except sqlite3.Error as exc:
            failures.append(f"{params}: {exc}")
    raise sqlite3.OperationalError("; ".join(failures))
|
||||
|
||||
|
||||
def sqlite_copy_payload(path: Path, limit: int, original_error: str) -> dict[str, Any] | None:
    """Copy a locked/unreadable database to a temp dir and inspect the copy.

    Returns the copy's payload (annotated with the original path and error),
    or None when even the copy-based inspection fails.
    """
    try:
        with tempfile.TemporaryDirectory(prefix="lm-cache-inspect-") as workdir:
            duplicate = Path(workdir) / path.name
            shutil.copy2(path, duplicate)
            # allow_copy=False prevents recursing into another copy pass.
            report = sqlite_payload(duplicate, limit, allow_copy=False)
            report["path"] = str(path)
            report["inspected_copy"] = True
            report["original_error"] = original_error
            return report
    except Exception:
        # Best effort: the caller will report the original error instead.
        return None
|
||||
|
||||
|
||||
def json_file_summary(path: Path) -> dict[str, Any]:
    """Summarize a JSON cache file without dumping its full contents."""
    if not path.exists():
        return {"exists": False}
    contents = load_json(path)
    return {"exists": True, "summary": redact(summarize_value(contents))}
|
||||
|
||||
|
||||
def model_roots(settings: dict[str, Any], active_library: str) -> dict[str, list[str]]:
    """Collect model root folders per folder type (loras, checkpoints, ...).

    The active library's ``folder_paths`` take precedence over the global
    settings, and legacy ``default_*_root`` keys still contribute.
    """
    collected: dict[str, list[str]] = {}

    # Library payload first so its paths lead the per-type lists.
    sources: list[dict[str, Any]] = []
    library = settings.get("libraries", {}).get(active_library)
    if isinstance(library, dict):
        sources.append(library)
    sources.append(settings)

    for source in sources:
        folder_paths = source.get("folder_paths")
        if not isinstance(folder_paths, dict):
            continue
        for folder_type, raw in folder_paths.items():
            collected.setdefault(folder_type, []).extend(normalize_path_list(raw))

    # Legacy single-root keys map onto the same folder-type buckets.
    legacy_keys = {
        "default_lora_root": "loras",
        "default_checkpoint_root": "checkpoints",
        "default_embedding_root": "embeddings",
        "default_unet_root": "unet",
    }
    for settings_key, folder_type in legacy_keys.items():
        raw = settings.get(settings_key)
        if isinstance(raw, str) and raw:
            collected.setdefault(folder_type, []).append(expand_path(raw))

    return {folder_type: dedupe(paths) for folder_type, paths in collected.items()}
|
||||
|
||||
|
||||
def resolve_recipes_root(settings: dict[str, Any], active_library: str) -> Path | None:
    """Resolve the recipes folder.

    Preference order: per-library ``recipes_path``, global ``recipes_path``,
    then ``<first lora root>/recipes``; None when nothing is configured.
    """
    configured = settings.get("recipes_path")
    library = settings.get("libraries", {}).get(active_library)
    if isinstance(library, dict) and isinstance(library.get("recipes_path"), str):
        # A non-empty per-library value overrides the global one.
        configured = library["recipes_path"] or configured
    if isinstance(configured, str) and configured.strip():
        return Path(expand_path(configured.strip()))
    lora_roots = model_roots(settings, active_library).get("loras") or []
    if not lora_roots:
        return None
    return Path(lora_roots[0]) / "recipes"
|
||||
|
||||
|
||||
def example_images_payload(settings: dict[str, Any], active_library: str) -> dict[str, Any]:
    """Resolve the example-image root and the active library's subfolder.

    With more than one configured library, images are scoped per library under
    ``<root>/<sanitized library>/``; with a single library they sit directly
    under the root.
    """
    root = settings.get("example_images_path") or ""
    libraries = settings.get("libraries")
    library_count = len(libraries) if isinstance(libraries, dict) else 0
    scoped = library_count > 1
    root_path = Path(expand_path(root)) if isinstance(root, str) and root else None
    # Coerce to str before sanitizing: settings JSON may store a non-string
    # library name, and build_context applies the same str() coercion.
    library_root = (
        root_path / sanitize_library_name(str(active_library))
        if root_path and scoped
        else root_path
    )
    return {
        "root": str(root_path or ""),
        "uses_library_scoped_folders": scoped,
        "library_root": str(library_root or ""),
    }
|
||||
|
||||
|
||||
def count_recipe_files(root: Path | None) -> int:
|
||||
if not root or not root.exists():
|
||||
return 0
|
||||
return sum(1 for _ in root.rglob("*.recipe.json"))
|
||||
|
||||
|
||||
def normalize_path_list(value: Any) -> list[str]:
    """Expand a settings value (string or list of strings) into absolute paths."""
    if isinstance(value, str):
        if not value:
            return []
        return [expand_path(value)]
    if not isinstance(value, list):
        # Anything else (None, numbers, dicts) contributes no paths.
        return []
    return [expand_path(entry) for entry in value if isinstance(entry, str) and entry]
|
||||
|
||||
|
||||
def expand_path(value: str) -> str:
    """Expand ``~`` and absolutize *value* without requiring it to exist."""
    expanded = Path(value).expanduser()
    return str(expanded.resolve(strict=False))


def sanitize_library_name(name: str) -> str:
    """Make a library name safe for use as a folder name; fallback "default"."""
    cleaned = re.sub(r"[^A-Za-z0-9_.-]", "_", name or "default")
    return cleaned or "default"
|
||||
|
||||
|
||||
def dedupe(values: list[str]) -> list[str]:
    """Return *values* with duplicates removed, keeping first-seen order."""
    # dict preserves insertion order, so fromkeys gives a stable dedupe.
    return list(dict.fromkeys(values))
|
||||
|
||||
|
||||
def redact(value: Any, key: str = "") -> Any:
    """Recursively mask values whose key looks secret-bearing.

    *key* is the dict key that produced *value*; list elements carry no key
    and are only recursed into.
    """
    if key and SECRET_PATTERN.search(key):
        return "<redacted>"
    if isinstance(value, list):
        return [redact(entry) for entry in value]
    if isinstance(value, dict):
        return {str(name): redact(entry, str(name)) for name, entry in value.items()}
    return value
|
||||
|
||||
|
||||
def summarize_value(value: Any) -> Any:
    """Recursively compact lists into ``{type, length, first}`` stubs.

    Dicts keep their keys (values summarized); scalars pass through.
    """
    if isinstance(value, list):
        head = summarize_value(value[0]) if value else None
        return {"type": "array", "length": len(value), "first": head}
    if isinstance(value, dict):
        return {key: summarize_value(item) for key, item in value.items()}
    return value
|
||||
|
||||
|
||||
def quote_identifier(identifier: str) -> str:
    """Quote a SQLite identifier, doubling any embedded double quotes."""
    escaped = identifier.replace('"', '""')
    return f'"{escaped}"'


def is_internal_sqlite_table(table: str) -> bool:
    """True for SQLite-internal or FTS shadow tables that should not be sampled."""
    if table.startswith("sqlite_"):
        return True
    return table.endswith(("_data", "_idx", "_docsize", "_config", "_content"))


def print_json(payload: Any) -> None:
    """Write *payload* to stdout as pretty JSON followed by a newline."""
    json.dump(payload, sys.stdout, indent=2, ensure_ascii=False)
    sys.stdout.write("\n")


if __name__ == "__main__":
    raise SystemExit(main())
|
||||
@@ -780,8 +780,10 @@ NODE_EXTRACTORS = {
|
||||
"GGUFLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
||||
"DiffusionModelLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
||||
"CheckpointLoaderKJ": CheckpointLoaderExtractor, # KJNodes
|
||||
"CheckpointLoaderLM": CheckpointLoaderExtractor, # LoRA Manager
|
||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"UNETLoaderLM": UNETLoaderExtractor, # LoRA Manager
|
||||
"LoraLoader": LoraLoaderExtractor,
|
||||
"LoraLoaderLM": LoraLoaderManagerExtractor,
|
||||
"RgthreePowerLoraLoader": RgthreePowerLoraLoaderExtractor,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -36,6 +37,9 @@ class CheckpointScanner(ModelScanner):
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex(),
|
||||
)
|
||||
if not hasattr(self, "_hash_calculation_lock"):
|
||||
self._hash_calculation_lock = asyncio.Lock()
|
||||
self._hash_calculation_tasks: dict[str, asyncio.Task[Optional[str]]] = {}
|
||||
|
||||
async def _create_default_metadata(
|
||||
self, file_path: str
|
||||
@@ -88,7 +92,7 @@ class CheckpointScanner(ModelScanner):
|
||||
return None
|
||||
|
||||
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
|
||||
"""Calculate hash for a checkpoint on-demand.
|
||||
"""Calculate hash for a checkpoint on-demand with per-file singleflight.
|
||||
|
||||
Args:
|
||||
file_path: Path to the model file
|
||||
@@ -96,14 +100,65 @@ class CheckpointScanner(ModelScanner):
|
||||
Returns:
|
||||
SHA256 hash string, or None if calculation failed
|
||||
"""
|
||||
from ..utils.file_utils import calculate_sha256
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not os.path.exists(real_path):
|
||||
logger.error(f"File not found for hash calculation: {file_path}")
|
||||
return None
|
||||
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if (
|
||||
metadata is not None
|
||||
and metadata.hash_status == "completed"
|
||||
and metadata.sha256
|
||||
):
|
||||
return metadata.sha256
|
||||
|
||||
async with self._hash_calculation_lock:
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if (
|
||||
metadata is not None
|
||||
and metadata.hash_status == "completed"
|
||||
and metadata.sha256
|
||||
):
|
||||
return metadata.sha256
|
||||
|
||||
task = self._hash_calculation_tasks.get(real_path)
|
||||
if task is None:
|
||||
task = asyncio.create_task(
|
||||
self._run_hash_calculation_task(file_path, real_path)
|
||||
)
|
||||
self._hash_calculation_tasks[real_path] = task
|
||||
|
||||
return await asyncio.shield(task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||
return None
|
||||
|
||||
async def _run_hash_calculation_task(
|
||||
self, file_path: str, real_path: str
|
||||
) -> Optional[str]:
|
||||
"""Run a hash calculation task and remove it from the in-flight map."""
|
||||
try:
|
||||
return await self._calculate_hash_for_model_uncached(file_path, real_path)
|
||||
finally:
|
||||
task = asyncio.current_task()
|
||||
async with self._hash_calculation_lock:
|
||||
if self._hash_calculation_tasks.get(real_path) is task:
|
||||
del self._hash_calculation_tasks[real_path]
|
||||
|
||||
async def _calculate_hash_for_model_uncached(
|
||||
self, file_path: str, real_path: str
|
||||
) -> Optional[str]:
|
||||
"""Calculate hash for a checkpoint without checking in-flight tasks."""
|
||||
from ..utils.file_utils import calculate_sha256
|
||||
|
||||
try:
|
||||
# Load current metadata
|
||||
metadata, should_skip = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
|
||||
@@ -1815,6 +1815,15 @@ class RecipeScanner:
|
||||
|
||||
return await self._lora_scanner.get_model_info_by_name(name)
|
||||
|
||||
async def get_local_checkpoint(self, name: str) -> Optional[Dict[str, Any]]:
    """Lookup a local checkpoint model by name.

    Returns the checkpoint scanner's model info, or None when no scanner
    is attached or *name* is empty.
    """
    scanner = getattr(self, "_checkpoint_scanner", None)
    if not (scanner and name):
        return None
    return await scanner.get_model_info_by_name(name)
|
||||
|
||||
async def get_paginated_data(
|
||||
self,
|
||||
page: int,
|
||||
|
||||
@@ -508,6 +508,10 @@ class RecipePersistenceService:
|
||||
most_common_base_model = (
|
||||
max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else ""
|
||||
)
|
||||
checkpoint_entry = await self._build_widget_checkpoint_entry(
|
||||
recipe_scanner,
|
||||
metadata.get("checkpoint"),
|
||||
)
|
||||
|
||||
recipe_data = {
|
||||
"id": recipe_id,
|
||||
@@ -515,9 +519,8 @@ class RecipePersistenceService:
|
||||
"title": recipe_name,
|
||||
"modified": time.time(),
|
||||
"created_date": time.time(),
|
||||
"base_model": most_common_base_model,
|
||||
"base_model": most_common_base_model or (checkpoint_entry or {}).get("baseModel", ""),
|
||||
"loras": loras_data,
|
||||
"checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")),
|
||||
"gen_params": {
|
||||
key: value
|
||||
for key, value in metadata.items()
|
||||
@@ -525,6 +528,8 @@ class RecipePersistenceService:
|
||||
},
|
||||
"loras_stack": lora_stack,
|
||||
}
|
||||
if checkpoint_entry:
|
||||
recipe_data["checkpoint"] = checkpoint_entry
|
||||
|
||||
json_filename = f"{recipe_id}.recipe.json"
|
||||
json_path = os.path.join(recipes_dir, json_filename)
|
||||
@@ -546,6 +551,91 @@ class RecipePersistenceService:
|
||||
|
||||
# Helper methods ---------------------------------------------------
|
||||
|
||||
async def _build_widget_checkpoint_entry(
    self,
    recipe_scanner,
    checkpoint_raw: Any,
) -> Optional[dict[str, Any]]:
    """Build recipe checkpoint metadata from widget generation metadata.

    Dict input is sanitized as-is; string input is looked up locally and
    enriched with CivitAI metadata when available. Non-string, non-dict
    input yields None.
    """
    # Already-structured entries only need sanitizing.
    if isinstance(checkpoint_raw, dict):
        return self._sanitize_checkpoint_entry(checkpoint_raw)

    if not isinstance(checkpoint_raw, str):
        return None

    checkpoint_name = checkpoint_raw.strip()
    if not checkpoint_name:
        return None

    stem = os.path.splitext(os.path.basename(checkpoint_name))[0]
    info = await self._lookup_widget_checkpoint(recipe_scanner, checkpoint_name)
    if not info:
        # Unknown locally: record the bare name with an empty hash.
        return {
            "type": "checkpoint",
            "name": checkpoint_name,
            "file_name": stem,
            "hash": "",
        }

    civitai = info.get("civitai") or {}
    civitai_model = civitai.get("model") or {}
    local_path = info.get("file_path") or info.get("path") or ""
    path_stem = os.path.splitext(os.path.basename(local_path))[0] if local_path else ""
    # Prefer the cached file name, then the cached path's stem, then the input's.
    resolved_stem = info.get("file_name") or path_stem or stem

    return {
        "type": "checkpoint",
        "modelId": civitai_model.get("id", 0),
        "modelVersionId": civitai.get("id", 0),
        "name": civitai_model.get("name") or info.get("model_name") or checkpoint_name,
        "version": civitai.get("name", ""),
        "hash": (info.get("sha256") or info.get("hash") or "").lower(),
        "file_name": resolved_stem,
        "modelName": civitai_model.get("name", ""),
        "modelVersionName": civitai.get("name", ""),
        "baseModel": info.get("base_model") or civitai.get("baseModel", ""),
    }
|
||||
|
||||
async def _lookup_widget_checkpoint(
    self,
    recipe_scanner,
    checkpoint_name: str,
) -> Optional[dict[str, Any]]:
    """Try several name variants against the recipe scanner's checkpoint lookup.

    Variants tried in order: full value, basename, basename without extension.
    Returns the first non-empty result, or None.
    """
    lookup = getattr(recipe_scanner, "get_local_checkpoint", None)
    if not callable(lookup):
        return None

    base = os.path.basename(checkpoint_name)
    stem = os.path.splitext(base)[0]
    candidates: list = []
    for variant in (checkpoint_name, base, stem):
        if variant and variant not in candidates:
            candidates.append(variant)

    for variant in candidates:
        try:
            found = await lookup(variant)
        except Exception as exc:
            # Lookup failures are non-fatal; log and try the next variant.
            self._logger.debug(
                "Failed to lookup checkpoint %s while saving widget recipe: %s",
                variant,
                exc,
            )
            continue
        if found:
            return found

    return None
|
||||
|
||||
def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
"""Pull a checkpoint entry from various metadata locations."""
|
||||
|
||||
|
||||
@@ -100,6 +100,7 @@ DEFAULT_PRIORITY_TAG_CONFIG = {
|
||||
# These model types are incorrectly labeled as "checkpoint" by CivitAI but are actually diffusion models
|
||||
DIFFUSION_MODEL_BASE_MODELS = frozenset(
|
||||
[
|
||||
"Anima",
|
||||
"ZImageTurbo",
|
||||
"ZImageBase",
|
||||
"Wan Video 1.3B t2v",
|
||||
|
||||
@@ -354,3 +354,33 @@ def test_lora_manager_cache_updates_when_loras_removed(metadata_registry):
|
||||
metadata = metadata_registry.get_metadata("prompt3")
|
||||
|
||||
assert "lora_node" not in metadata[LORAS]
|
||||
|
||||
|
||||
def test_lora_manager_checkpoint_and_unet_loaders_extract_models(metadata_registry):
    """CheckpointLoaderLM and UNETLoaderLM executions should both be recorded
    as checkpoint-type entries in the collected prompt metadata."""
    metadata_registry.start_collection("prompt1")

    # Simulate the two LoRA Manager loader nodes executing; inputs are
    # list-wrapped, matching how the registry receives node input values.
    metadata_registry.record_node_execution(
        "checkpoint_node",
        "CheckpointLoaderLM",
        {"ckpt_name": ["models/checkpoint.safetensors"]},
        None,
    )
    metadata_registry.record_node_execution(
        "unet_node",
        "UNETLoaderLM",
        {"unet_name": ["models/diffusion_model.safetensors"], "weight_dtype": ["default"]},
        None,
    )

    metadata = metadata_registry.get_metadata("prompt1")

    assert metadata[MODELS]["checkpoint_node"] == {
        "name": "models/checkpoint.safetensors",
        "type": "checkpoint",
        "node_id": "checkpoint_node",
    }
    # UNET loaders are recorded under the "checkpoint" model type as well.
    assert metadata[MODELS]["unet_node"] == {
        "name": "models/diffusion_model.safetensors",
        "type": "checkpoint",
        "node_id": "unet_node",
    }
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for checkpoint lazy hash calculation feature."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -199,6 +200,160 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp
|
||||
mock_calc.assert_not_called(), "Should not recalculate if already completed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_calculate_hash_for_model_singleflight_same_file(
    tmp_path: Path, monkeypatch
):
    """Concurrent calls for the same checkpoint should share one SHA256 task."""
    checkpoints_root = tmp_path / "checkpoints"
    checkpoints_root.mkdir()

    checkpoint_file = checkpoints_root / "test_model.safetensors"
    checkpoint_file.write_text("fake content", encoding="utf-8")

    normalized_root = _normalize(checkpoints_root)
    normalized_file = _normalize(checkpoint_file)
    # In-flight tasks are keyed by os.path.realpath, so resolve the same way
    # here to assert on the recorded hash-call argument.
    real_file = os.path.realpath(normalized_file)

    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [normalized_root],
        raising=False,
    )

    scanner = CheckpointScanner()
    metadata = await scanner._create_default_metadata(normalized_file)
    assert metadata is not None

    calls = []

    # Slow fake hash so all eight concurrent callers overlap the first task.
    async def fake_calculate_sha256(file_path: str) -> str:
        calls.append(file_path)
        await asyncio.sleep(0.01)
        return "a" * 64

    with patch(
        "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
    ):
        results = await asyncio.gather(
            *[scanner.calculate_hash_for_model(normalized_file) for _ in range(8)]
        )

    # Exactly one underlying hash computation, shared by every caller.
    assert calls == [real_file]
    assert results == ["a" * 64] * 8
    # The in-flight task map must be drained once the hash completes.
    assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_cleans_task_after_failure_and_retries(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
"""A failed in-flight task should be removed so later calls can retry."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_file = checkpoints_root / "retry_model.safetensors"
|
||||
checkpoint_file.write_text("fake content", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_file = _normalize(checkpoint_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
metadata = await scanner._create_default_metadata(normalized_file)
|
||||
assert metadata is not None
|
||||
|
||||
attempts = 0
|
||||
|
||||
async def fake_calculate_sha256(_file_path: str) -> str:
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts == 1:
|
||||
raise RuntimeError("hash failed")
|
||||
return "b" * 64
|
||||
|
||||
with patch(
|
||||
"py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
|
||||
):
|
||||
assert await scanner.calculate_hash_for_model(normalized_file) is None
|
||||
assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
hash_result = await scanner.calculate_hash_for_model(normalized_file)
|
||||
|
||||
assert hash_result == "b" * 64
|
||||
assert attempts == 2
|
||||
assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_uses_separate_tasks_for_different_files(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
"""Different checkpoint files should not share the same in-flight task."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_files = [
|
||||
checkpoints_root / "model_a.safetensors",
|
||||
checkpoints_root / "model_b.safetensors",
|
||||
]
|
||||
for checkpoint_file in checkpoint_files:
|
||||
checkpoint_file.write_text(
|
||||
f"fake content for {checkpoint_file.name}", encoding="utf-8"
|
||||
)
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_files = [
|
||||
_normalize(checkpoint_file) for checkpoint_file in checkpoint_files
|
||||
]
|
||||
real_files = {os.path.realpath(file_path) for file_path in normalized_files}
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
for normalized_file in normalized_files:
|
||||
metadata = await scanner._create_default_metadata(normalized_file)
|
||||
assert metadata is not None
|
||||
|
||||
calls = []
|
||||
hashes_by_path = {
|
||||
os.path.realpath(normalized_files[0]): "c" * 64,
|
||||
os.path.realpath(normalized_files[1]): "d" * 64,
|
||||
}
|
||||
|
||||
async def fake_calculate_sha256(file_path: str) -> str:
|
||||
calls.append(file_path)
|
||||
await asyncio.sleep(0.01)
|
||||
return hashes_by_path[file_path]
|
||||
|
||||
with patch(
|
||||
"py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
|
||||
):
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
scanner.calculate_hash_for_model(file_path)
|
||||
for file_path in normalized_files
|
||||
]
|
||||
)
|
||||
|
||||
assert set(calls) == real_files
|
||||
assert len(calls) == 2
|
||||
assert set(results) == {"c" * 64, "d" * 64}
|
||||
assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch):
|
||||
"""Test that calculate_hash_for_model creates pending metadata when it is missing."""
|
||||
|
||||
@@ -491,6 +491,9 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
|
||||
async def get_local_lora(self, name): # pragma: no cover - no lookups expected
|
||||
return None
|
||||
|
||||
async def get_local_checkpoint(self, name):
|
||||
return None
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
self.added.append(recipe_data)
|
||||
|
||||
@@ -518,9 +521,90 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
|
||||
|
||||
assert stored["loras"] == []
|
||||
assert stored["title"] == "recipe"
|
||||
assert stored["checkpoint"] == {
|
||||
"type": "checkpoint",
|
||||
"name": "base-model.safetensors",
|
||||
"file_name": "base-model",
|
||||
"hash": "",
|
||||
}
|
||||
assert scanner.added and scanner.added[0]["loras"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_from_widget_enriches_checkpoint_from_local_cache(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, root):
|
||||
self.recipes_dir = str(root)
|
||||
self.added = []
|
||||
self.checkpoint_queries = []
|
||||
|
||||
async def get_local_lora(self, name): # pragma: no cover - no loras
|
||||
return None
|
||||
|
||||
async def get_local_checkpoint(self, name):
|
||||
self.checkpoint_queries.append(name)
|
||||
if name != "matched-model":
|
||||
return None
|
||||
return {
|
||||
"file_name": "matched-model",
|
||||
"file_path": "/models/checkpoints/folder/matched-model.safetensors",
|
||||
"sha256": "ABC123",
|
||||
"base_model": "Illustrious",
|
||||
"civitai": {
|
||||
"id": 456,
|
||||
"name": "v1.0",
|
||||
"baseModel": "Illustrious",
|
||||
"model": {
|
||||
"id": 123,
|
||||
"name": "Matched Model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
self.added.append(recipe_data)
|
||||
|
||||
scanner = DummyScanner(tmp_path)
|
||||
service = RecipePersistenceService(
|
||||
exif_utils=exif_utils,
|
||||
card_preview_width=512,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
result = await service.save_recipe_from_widget(
|
||||
recipe_scanner=scanner,
|
||||
metadata={
|
||||
"loras": "",
|
||||
"checkpoint": "folder/matched-model.safetensors",
|
||||
"prompt": "a calm scene",
|
||||
},
|
||||
image_bytes=b"image-bytes",
|
||||
)
|
||||
|
||||
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||
|
||||
assert scanner.checkpoint_queries == [
|
||||
"folder/matched-model.safetensors",
|
||||
"matched-model.safetensors",
|
||||
"matched-model",
|
||||
]
|
||||
assert stored["base_model"] == "Illustrious"
|
||||
assert stored["checkpoint"] == {
|
||||
"type": "checkpoint",
|
||||
"modelId": 123,
|
||||
"modelVersionId": 456,
|
||||
"name": "Matched Model",
|
||||
"version": "v1.0",
|
||||
"hash": "abc123",
|
||||
"file_name": "matched-model",
|
||||
"modelName": "Matched Model",
|
||||
"modelVersionName": "v1.0",
|
||||
"baseModel": "Illustrious",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_recipe_updates_paths(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
Reference in New Issue
Block a user