mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 08:26:45 -03:00
fix(recipes): save widget checkpoint metadata as dict
This commit is contained in:
@@ -780,8 +780,10 @@ NODE_EXTRACTORS = {
|
|||||||
"GGUFLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
"GGUFLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
||||||
"DiffusionModelLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
"DiffusionModelLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
|
||||||
"CheckpointLoaderKJ": CheckpointLoaderExtractor, # KJNodes
|
"CheckpointLoaderKJ": CheckpointLoaderExtractor, # KJNodes
|
||||||
|
"CheckpointLoaderLM": CheckpointLoaderExtractor, # LoRA Manager
|
||||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||||
|
"UNETLoaderLM": UNETLoaderExtractor, # LoRA Manager
|
||||||
"LoraLoader": LoraLoaderExtractor,
|
"LoraLoader": LoraLoaderExtractor,
|
||||||
"LoraLoaderLM": LoraLoaderManagerExtractor,
|
"LoraLoaderLM": LoraLoaderManagerExtractor,
|
||||||
"RgthreePowerLoraLoader": RgthreePowerLoraLoaderExtractor,
|
"RgthreePowerLoraLoader": RgthreePowerLoraLoaderExtractor,
|
||||||
|
|||||||
@@ -1815,6 +1815,15 @@ class RecipeScanner:
|
|||||||
|
|
||||||
return await self._lora_scanner.get_model_info_by_name(name)
|
return await self._lora_scanner.get_model_info_by_name(name)
|
||||||
|
|
||||||
|
async def get_local_checkpoint(self, name: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Lookup a local checkpoint model by name."""
|
||||||
|
|
||||||
|
checkpoint_scanner = getattr(self, "_checkpoint_scanner", None)
|
||||||
|
if not checkpoint_scanner or not name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await checkpoint_scanner.get_model_info_by_name(name)
|
||||||
|
|
||||||
async def get_paginated_data(
|
async def get_paginated_data(
|
||||||
self,
|
self,
|
||||||
page: int,
|
page: int,
|
||||||
|
|||||||
@@ -508,6 +508,10 @@ class RecipePersistenceService:
|
|||||||
most_common_base_model = (
|
most_common_base_model = (
|
||||||
max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else ""
|
max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else ""
|
||||||
)
|
)
|
||||||
|
checkpoint_entry = await self._build_widget_checkpoint_entry(
|
||||||
|
recipe_scanner,
|
||||||
|
metadata.get("checkpoint"),
|
||||||
|
)
|
||||||
|
|
||||||
recipe_data = {
|
recipe_data = {
|
||||||
"id": recipe_id,
|
"id": recipe_id,
|
||||||
@@ -515,9 +519,8 @@ class RecipePersistenceService:
|
|||||||
"title": recipe_name,
|
"title": recipe_name,
|
||||||
"modified": time.time(),
|
"modified": time.time(),
|
||||||
"created_date": time.time(),
|
"created_date": time.time(),
|
||||||
"base_model": most_common_base_model,
|
"base_model": most_common_base_model or (checkpoint_entry or {}).get("baseModel", ""),
|
||||||
"loras": loras_data,
|
"loras": loras_data,
|
||||||
"checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")),
|
|
||||||
"gen_params": {
|
"gen_params": {
|
||||||
key: value
|
key: value
|
||||||
for key, value in metadata.items()
|
for key, value in metadata.items()
|
||||||
@@ -525,6 +528,8 @@ class RecipePersistenceService:
|
|||||||
},
|
},
|
||||||
"loras_stack": lora_stack,
|
"loras_stack": lora_stack,
|
||||||
}
|
}
|
||||||
|
if checkpoint_entry:
|
||||||
|
recipe_data["checkpoint"] = checkpoint_entry
|
||||||
|
|
||||||
json_filename = f"{recipe_id}.recipe.json"
|
json_filename = f"{recipe_id}.recipe.json"
|
||||||
json_path = os.path.join(recipes_dir, json_filename)
|
json_path = os.path.join(recipes_dir, json_filename)
|
||||||
@@ -546,6 +551,91 @@ class RecipePersistenceService:
|
|||||||
|
|
||||||
# Helper methods ---------------------------------------------------
|
# Helper methods ---------------------------------------------------
|
||||||
|
|
||||||
|
async def _build_widget_checkpoint_entry(
|
||||||
|
self,
|
||||||
|
recipe_scanner,
|
||||||
|
checkpoint_raw: Any,
|
||||||
|
) -> Optional[dict[str, Any]]:
|
||||||
|
"""Build recipe checkpoint metadata from widget generation metadata."""
|
||||||
|
|
||||||
|
if isinstance(checkpoint_raw, dict):
|
||||||
|
return self._sanitize_checkpoint_entry(checkpoint_raw)
|
||||||
|
|
||||||
|
if not isinstance(checkpoint_raw, str):
|
||||||
|
return None
|
||||||
|
|
||||||
|
checkpoint_name = checkpoint_raw.strip()
|
||||||
|
if not checkpoint_name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
file_name = os.path.splitext(os.path.basename(checkpoint_name))[0]
|
||||||
|
checkpoint_info = await self._lookup_widget_checkpoint(
|
||||||
|
recipe_scanner,
|
||||||
|
checkpoint_name,
|
||||||
|
)
|
||||||
|
if not checkpoint_info:
|
||||||
|
return {
|
||||||
|
"type": "checkpoint",
|
||||||
|
"name": checkpoint_name,
|
||||||
|
"file_name": file_name,
|
||||||
|
"hash": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
civitai = checkpoint_info.get("civitai") or {}
|
||||||
|
civitai_model = civitai.get("model") or {}
|
||||||
|
file_path = checkpoint_info.get("file_path") or checkpoint_info.get("path") or ""
|
||||||
|
cached_file_name = (
|
||||||
|
checkpoint_info.get("file_name")
|
||||||
|
or (os.path.splitext(os.path.basename(file_path))[0] if file_path else "")
|
||||||
|
or file_name
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "checkpoint",
|
||||||
|
"modelId": civitai_model.get("id", 0),
|
||||||
|
"modelVersionId": civitai.get("id", 0),
|
||||||
|
"name": civitai_model.get("name") or checkpoint_info.get("model_name") or checkpoint_name,
|
||||||
|
"version": civitai.get("name", ""),
|
||||||
|
"hash": (checkpoint_info.get("sha256") or checkpoint_info.get("hash") or "").lower(),
|
||||||
|
"file_name": cached_file_name,
|
||||||
|
"modelName": civitai_model.get("name", ""),
|
||||||
|
"modelVersionName": civitai.get("name", ""),
|
||||||
|
"baseModel": checkpoint_info.get("base_model") or civitai.get("baseModel", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _lookup_widget_checkpoint(
|
||||||
|
self,
|
||||||
|
recipe_scanner,
|
||||||
|
checkpoint_name: str,
|
||||||
|
) -> Optional[dict[str, Any]]:
|
||||||
|
lookup = getattr(recipe_scanner, "get_local_checkpoint", None)
|
||||||
|
if not callable(lookup):
|
||||||
|
return None
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
for candidate in (
|
||||||
|
checkpoint_name,
|
||||||
|
os.path.basename(checkpoint_name),
|
||||||
|
os.path.splitext(os.path.basename(checkpoint_name))[0],
|
||||||
|
):
|
||||||
|
if candidate and candidate not in candidates:
|
||||||
|
candidates.append(candidate)
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
try:
|
||||||
|
checkpoint_info = await lookup(candidate)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.debug(
|
||||||
|
"Failed to lookup checkpoint %s while saving widget recipe: %s",
|
||||||
|
candidate,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if checkpoint_info:
|
||||||
|
return checkpoint_info
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]:
|
def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||||
"""Pull a checkpoint entry from various metadata locations."""
|
"""Pull a checkpoint entry from various metadata locations."""
|
||||||
|
|
||||||
|
|||||||
@@ -354,3 +354,33 @@ def test_lora_manager_cache_updates_when_loras_removed(metadata_registry):
|
|||||||
metadata = metadata_registry.get_metadata("prompt3")
|
metadata = metadata_registry.get_metadata("prompt3")
|
||||||
|
|
||||||
assert "lora_node" not in metadata[LORAS]
|
assert "lora_node" not in metadata[LORAS]
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_manager_checkpoint_and_unet_loaders_extract_models(metadata_registry):
|
||||||
|
metadata_registry.start_collection("prompt1")
|
||||||
|
|
||||||
|
metadata_registry.record_node_execution(
|
||||||
|
"checkpoint_node",
|
||||||
|
"CheckpointLoaderLM",
|
||||||
|
{"ckpt_name": ["models/checkpoint.safetensors"]},
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
metadata_registry.record_node_execution(
|
||||||
|
"unet_node",
|
||||||
|
"UNETLoaderLM",
|
||||||
|
{"unet_name": ["models/diffusion_model.safetensors"], "weight_dtype": ["default"]},
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = metadata_registry.get_metadata("prompt1")
|
||||||
|
|
||||||
|
assert metadata[MODELS]["checkpoint_node"] == {
|
||||||
|
"name": "models/checkpoint.safetensors",
|
||||||
|
"type": "checkpoint",
|
||||||
|
"node_id": "checkpoint_node",
|
||||||
|
}
|
||||||
|
assert metadata[MODELS]["unet_node"] == {
|
||||||
|
"name": "models/diffusion_model.safetensors",
|
||||||
|
"type": "checkpoint",
|
||||||
|
"node_id": "unet_node",
|
||||||
|
}
|
||||||
|
|||||||
@@ -491,6 +491,9 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
|
|||||||
async def get_local_lora(self, name): # pragma: no cover - no lookups expected
|
async def get_local_lora(self, name): # pragma: no cover - no lookups expected
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_local_checkpoint(self, name):
|
||||||
|
return None
|
||||||
|
|
||||||
async def add_recipe(self, recipe_data):
|
async def add_recipe(self, recipe_data):
|
||||||
self.added.append(recipe_data)
|
self.added.append(recipe_data)
|
||||||
|
|
||||||
@@ -518,9 +521,90 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
|
|||||||
|
|
||||||
assert stored["loras"] == []
|
assert stored["loras"] == []
|
||||||
assert stored["title"] == "recipe"
|
assert stored["title"] == "recipe"
|
||||||
|
assert stored["checkpoint"] == {
|
||||||
|
"type": "checkpoint",
|
||||||
|
"name": "base-model.safetensors",
|
||||||
|
"file_name": "base-model",
|
||||||
|
"hash": "",
|
||||||
|
}
|
||||||
assert scanner.added and scanner.added[0]["loras"] == []
|
assert scanner.added and scanner.added[0]["loras"] == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_recipe_from_widget_enriches_checkpoint_from_local_cache(tmp_path):
|
||||||
|
exif_utils = DummyExifUtils()
|
||||||
|
|
||||||
|
class DummyScanner:
|
||||||
|
def __init__(self, root):
|
||||||
|
self.recipes_dir = str(root)
|
||||||
|
self.added = []
|
||||||
|
self.checkpoint_queries = []
|
||||||
|
|
||||||
|
async def get_local_lora(self, name): # pragma: no cover - no loras
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_local_checkpoint(self, name):
|
||||||
|
self.checkpoint_queries.append(name)
|
||||||
|
if name != "matched-model":
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"file_name": "matched-model",
|
||||||
|
"file_path": "/models/checkpoints/folder/matched-model.safetensors",
|
||||||
|
"sha256": "ABC123",
|
||||||
|
"base_model": "Illustrious",
|
||||||
|
"civitai": {
|
||||||
|
"id": 456,
|
||||||
|
"name": "v1.0",
|
||||||
|
"baseModel": "Illustrious",
|
||||||
|
"model": {
|
||||||
|
"id": 123,
|
||||||
|
"name": "Matched Model",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def add_recipe(self, recipe_data):
|
||||||
|
self.added.append(recipe_data)
|
||||||
|
|
||||||
|
scanner = DummyScanner(tmp_path)
|
||||||
|
service = RecipePersistenceService(
|
||||||
|
exif_utils=exif_utils,
|
||||||
|
card_preview_width=512,
|
||||||
|
logger=logging.getLogger("test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await service.save_recipe_from_widget(
|
||||||
|
recipe_scanner=scanner,
|
||||||
|
metadata={
|
||||||
|
"loras": "",
|
||||||
|
"checkpoint": "folder/matched-model.safetensors",
|
||||||
|
"prompt": "a calm scene",
|
||||||
|
},
|
||||||
|
image_bytes=b"image-bytes",
|
||||||
|
)
|
||||||
|
|
||||||
|
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||||
|
|
||||||
|
assert scanner.checkpoint_queries == [
|
||||||
|
"folder/matched-model.safetensors",
|
||||||
|
"matched-model.safetensors",
|
||||||
|
"matched-model",
|
||||||
|
]
|
||||||
|
assert stored["base_model"] == "Illustrious"
|
||||||
|
assert stored["checkpoint"] == {
|
||||||
|
"type": "checkpoint",
|
||||||
|
"modelId": 123,
|
||||||
|
"modelVersionId": 456,
|
||||||
|
"name": "Matched Model",
|
||||||
|
"version": "v1.0",
|
||||||
|
"hash": "abc123",
|
||||||
|
"file_name": "matched-model",
|
||||||
|
"modelName": "Matched Model",
|
||||||
|
"modelVersionName": "v1.0",
|
||||||
|
"baseModel": "Illustrious",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_move_recipe_updates_paths(tmp_path):
|
async def test_move_recipe_updates_paths(tmp_path):
|
||||||
exif_utils = DummyExifUtils()
|
exif_utils = DummyExifUtils()
|
||||||
|
|||||||
Reference in New Issue
Block a user