mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 08:26:45 -03:00
fix(recipes): save widget checkpoint metadata as dict
This commit is contained in:
@@ -780,8 +780,10 @@ NODE_EXTRACTORS = {
|
||||
"GGUFLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
||||
"DiffusionModelLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
||||
"CheckpointLoaderKJ": CheckpointLoaderExtractor, # KJNodes
|
||||
"CheckpointLoaderLM": CheckpointLoaderExtractor, # LoRA Manager
|
||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"UNETLoaderLM": UNETLoaderExtractor, # LoRA Manager
|
||||
"LoraLoader": LoraLoaderExtractor,
|
||||
"LoraLoaderLM": LoraLoaderManagerExtractor,
|
||||
"RgthreePowerLoraLoader": RgthreePowerLoraLoaderExtractor,
|
||||
|
||||
@@ -1815,6 +1815,15 @@ class RecipeScanner:
|
||||
|
||||
return await self._lora_scanner.get_model_info_by_name(name)
|
||||
|
||||
async def get_local_checkpoint(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Lookup a local checkpoint model by name."""
|
||||
|
||||
checkpoint_scanner = getattr(self, "_checkpoint_scanner", None)
|
||||
if not checkpoint_scanner or not name:
|
||||
return None
|
||||
|
||||
return await checkpoint_scanner.get_model_info_by_name(name)
|
||||
|
||||
async def get_paginated_data(
|
||||
self,
|
||||
page: int,
|
||||
|
||||
@@ -508,6 +508,10 @@ class RecipePersistenceService:
|
||||
most_common_base_model = (
|
||||
max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else ""
|
||||
)
|
||||
checkpoint_entry = await self._build_widget_checkpoint_entry(
|
||||
recipe_scanner,
|
||||
metadata.get("checkpoint"),
|
||||
)
|
||||
|
||||
recipe_data = {
|
||||
"id": recipe_id,
|
||||
@@ -515,9 +519,8 @@ class RecipePersistenceService:
|
||||
"title": recipe_name,
|
||||
"modified": time.time(),
|
||||
"created_date": time.time(),
|
||||
"base_model": most_common_base_model,
|
||||
"base_model": most_common_base_model or (checkpoint_entry or {}).get("baseModel", ""),
|
||||
"loras": loras_data,
|
||||
"checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")),
|
||||
"gen_params": {
|
||||
key: value
|
||||
for key, value in metadata.items()
|
||||
@@ -525,6 +528,8 @@ class RecipePersistenceService:
|
||||
},
|
||||
"loras_stack": lora_stack,
|
||||
}
|
||||
if checkpoint_entry:
|
||||
recipe_data["checkpoint"] = checkpoint_entry
|
||||
|
||||
json_filename = f"{recipe_id}.recipe.json"
|
||||
json_path = os.path.join(recipes_dir, json_filename)
|
||||
@@ -546,6 +551,91 @@ class RecipePersistenceService:
|
||||
|
||||
# Helper methods ---------------------------------------------------
|
||||
|
||||
async def _build_widget_checkpoint_entry(
|
||||
self,
|
||||
recipe_scanner,
|
||||
checkpoint_raw: Any,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""Build recipe checkpoint metadata from widget generation metadata."""
|
||||
|
||||
if isinstance(checkpoint_raw, dict):
|
||||
return self._sanitize_checkpoint_entry(checkpoint_raw)
|
||||
|
||||
if not isinstance(checkpoint_raw, str):
|
||||
return None
|
||||
|
||||
checkpoint_name = checkpoint_raw.strip()
|
||||
if not checkpoint_name:
|
||||
return None
|
||||
|
||||
file_name = os.path.splitext(os.path.basename(checkpoint_name))[0]
|
||||
checkpoint_info = await self._lookup_widget_checkpoint(
|
||||
recipe_scanner,
|
||||
checkpoint_name,
|
||||
)
|
||||
if not checkpoint_info:
|
||||
return {
|
||||
"type": "checkpoint",
|
||||
"name": checkpoint_name,
|
||||
"file_name": file_name,
|
||||
"hash": "",
|
||||
}
|
||||
|
||||
civitai = checkpoint_info.get("civitai") or {}
|
||||
civitai_model = civitai.get("model") or {}
|
||||
file_path = checkpoint_info.get("file_path") or checkpoint_info.get("path") or ""
|
||||
cached_file_name = (
|
||||
checkpoint_info.get("file_name")
|
||||
or (os.path.splitext(os.path.basename(file_path))[0] if file_path else "")
|
||||
or file_name
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "checkpoint",
|
||||
"modelId": civitai_model.get("id", 0),
|
||||
"modelVersionId": civitai.get("id", 0),
|
||||
"name": civitai_model.get("name") or checkpoint_info.get("model_name") or checkpoint_name,
|
||||
"version": civitai.get("name", ""),
|
||||
"hash": (checkpoint_info.get("sha256") or checkpoint_info.get("hash") or "").lower(),
|
||||
"file_name": cached_file_name,
|
||||
"modelName": civitai_model.get("name", ""),
|
||||
"modelVersionName": civitai.get("name", ""),
|
||||
"baseModel": checkpoint_info.get("base_model") or civitai.get("baseModel", ""),
|
||||
}
|
||||
|
||||
async def _lookup_widget_checkpoint(
|
||||
self,
|
||||
recipe_scanner,
|
||||
checkpoint_name: str,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
lookup = getattr(recipe_scanner, "get_local_checkpoint", None)
|
||||
if not callable(lookup):
|
||||
return None
|
||||
|
||||
candidates = []
|
||||
for candidate in (
|
||||
checkpoint_name,
|
||||
os.path.basename(checkpoint_name),
|
||||
os.path.splitext(os.path.basename(checkpoint_name))[0],
|
||||
):
|
||||
if candidate and candidate not in candidates:
|
||||
candidates.append(candidate)
|
||||
|
||||
for candidate in candidates:
|
||||
try:
|
||||
checkpoint_info = await lookup(candidate)
|
||||
except Exception as exc:
|
||||
self._logger.debug(
|
||||
"Failed to lookup checkpoint %s while saving widget recipe: %s",
|
||||
candidate,
|
||||
exc,
|
||||
)
|
||||
continue
|
||||
if checkpoint_info:
|
||||
return checkpoint_info
|
||||
|
||||
return None
|
||||
|
||||
def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
"""Pull a checkpoint entry from various metadata locations."""
|
||||
|
||||
|
||||
@@ -354,3 +354,33 @@ def test_lora_manager_cache_updates_when_loras_removed(metadata_registry):
|
||||
metadata = metadata_registry.get_metadata("prompt3")
|
||||
|
||||
assert "lora_node" not in metadata[LORAS]
|
||||
|
||||
|
||||
def test_lora_manager_checkpoint_and_unet_loaders_extract_models(metadata_registry):
|
||||
metadata_registry.start_collection("prompt1")
|
||||
|
||||
metadata_registry.record_node_execution(
|
||||
"checkpoint_node",
|
||||
"CheckpointLoaderLM",
|
||||
{"ckpt_name": ["models/checkpoint.safetensors"]},
|
||||
None,
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"unet_node",
|
||||
"UNETLoaderLM",
|
||||
{"unet_name": ["models/diffusion_model.safetensors"], "weight_dtype": ["default"]},
|
||||
None,
|
||||
)
|
||||
|
||||
metadata = metadata_registry.get_metadata("prompt1")
|
||||
|
||||
assert metadata[MODELS]["checkpoint_node"] == {
|
||||
"name": "models/checkpoint.safetensors",
|
||||
"type": "checkpoint",
|
||||
"node_id": "checkpoint_node",
|
||||
}
|
||||
assert metadata[MODELS]["unet_node"] == {
|
||||
"name": "models/diffusion_model.safetensors",
|
||||
"type": "checkpoint",
|
||||
"node_id": "unet_node",
|
||||
}
|
||||
|
||||
@@ -491,6 +491,9 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
|
||||
async def get_local_lora(self, name): # pragma: no cover - no lookups expected
|
||||
return None
|
||||
|
||||
async def get_local_checkpoint(self, name):
|
||||
return None
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
self.added.append(recipe_data)
|
||||
|
||||
@@ -518,9 +521,90 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
|
||||
|
||||
assert stored["loras"] == []
|
||||
assert stored["title"] == "recipe"
|
||||
assert stored["checkpoint"] == {
|
||||
"type": "checkpoint",
|
||||
"name": "base-model.safetensors",
|
||||
"file_name": "base-model",
|
||||
"hash": "",
|
||||
}
|
||||
assert scanner.added and scanner.added[0]["loras"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_from_widget_enriches_checkpoint_from_local_cache(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, root):
|
||||
self.recipes_dir = str(root)
|
||||
self.added = []
|
||||
self.checkpoint_queries = []
|
||||
|
||||
async def get_local_lora(self, name): # pragma: no cover - no loras
|
||||
return None
|
||||
|
||||
async def get_local_checkpoint(self, name):
|
||||
self.checkpoint_queries.append(name)
|
||||
if name != "matched-model":
|
||||
return None
|
||||
return {
|
||||
"file_name": "matched-model",
|
||||
"file_path": "/models/checkpoints/folder/matched-model.safetensors",
|
||||
"sha256": "ABC123",
|
||||
"base_model": "Illustrious",
|
||||
"civitai": {
|
||||
"id": 456,
|
||||
"name": "v1.0",
|
||||
"baseModel": "Illustrious",
|
||||
"model": {
|
||||
"id": 123,
|
||||
"name": "Matched Model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
self.added.append(recipe_data)
|
||||
|
||||
scanner = DummyScanner(tmp_path)
|
||||
service = RecipePersistenceService(
|
||||
exif_utils=exif_utils,
|
||||
card_preview_width=512,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
result = await service.save_recipe_from_widget(
|
||||
recipe_scanner=scanner,
|
||||
metadata={
|
||||
"loras": "",
|
||||
"checkpoint": "folder/matched-model.safetensors",
|
||||
"prompt": "a calm scene",
|
||||
},
|
||||
image_bytes=b"image-bytes",
|
||||
)
|
||||
|
||||
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||
|
||||
assert scanner.checkpoint_queries == [
|
||||
"folder/matched-model.safetensors",
|
||||
"matched-model.safetensors",
|
||||
"matched-model",
|
||||
]
|
||||
assert stored["base_model"] == "Illustrious"
|
||||
assert stored["checkpoint"] == {
|
||||
"type": "checkpoint",
|
||||
"modelId": 123,
|
||||
"modelVersionId": 456,
|
||||
"name": "Matched Model",
|
||||
"version": "v1.0",
|
||||
"hash": "abc123",
|
||||
"file_name": "matched-model",
|
||||
"modelName": "Matched Model",
|
||||
"modelVersionName": "v1.0",
|
||||
"baseModel": "Illustrious",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_move_recipe_updates_paths(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
Reference in New Issue
Block a user