mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: Introduce recipe management with data models, scanning, enrichment, and repair for generation configurations.
This commit is contained in:
@@ -37,7 +37,8 @@ class RecipeMetadataParser(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
|
||||
@staticmethod
|
||||
async def populate_lora_from_civitai(lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
|
||||
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Populate a lora entry with information from Civitai API response
|
||||
@@ -148,8 +149,9 @@ class RecipeMetadataParser(ABC):
|
||||
logger.error(f"Error populating lora from Civitai info: {e}")
|
||||
|
||||
return lora_entry
|
||||
|
||||
async def populate_checkpoint_from_civitai(self, checkpoint: Dict[str, Any], civitai_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
async def populate_checkpoint_from_civitai(checkpoint: Dict[str, Any], civitai_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Populate checkpoint information from Civitai API response
|
||||
|
||||
@@ -187,6 +189,7 @@ class RecipeMetadataParser(ABC):
|
||||
checkpoint['downloadUrl'] = civitai_data.get('downloadUrl', '')
|
||||
|
||||
checkpoint['modelId'] = civitai_data.get('modelId', checkpoint.get('modelId', 0))
|
||||
checkpoint['id'] = civitai_data.get('id', 0)
|
||||
|
||||
if 'files' in civitai_data:
|
||||
model_file = next(
|
||||
|
||||
224
py/recipes/enrichment.py
Normal file
224
py/recipes/enrichment.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from .merger import GenParamsMerger
|
||||
from .base import RecipeMetadataParser
|
||||
from ..services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RecipeEnricher:
    """Service to enrich recipe metadata from multiple sources (Civitai, Embedded, User)."""

    # Base-model labels treated as "generic": they are replaced by the value
    # reported on the resolved checkpoint (comparison is case-insensitive).
    _GENERIC_BASE_MODELS = ("flux", "sdxl", "sd15")

    @staticmethod
    async def enrich_recipe(
        recipe: Dict[str, Any],
        civitai_client: Any,
        request_params: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Enrich a recipe dictionary in-place with metadata from Civitai and embedded params.

        Steps:
            1. If the recipe's ``source_url`` / ``source_path`` looks like a Civitai
               image URL, fetch the image info and extract its ``meta`` block and
               ``modelVersionId``.
            2. Merge generation parameters with priority
               request_params > civitai_meta > existing ``recipe["gen_params"]``.
            3. If the checkpoint entry lacks ``name`` or ``id``, try to resolve it
               (by version id, hash, or name) and mirror it into ``gen_params``.
            4. Replace an empty or generic ``base_model`` with the checkpoint's
               ``baseModel``.

        Args:
            recipe: The recipe dictionary to enrich; modified in place.
            civitai_client: Client object exposing an awaitable ``get_image_info(image_id)``.
            request_params: (Optional) Parameters from a user request (e.g. import).

        Returns:
            bool: True if the recipe was modified, False otherwise.
        """
        updated = False
        gen_params = recipe.get("gen_params", {})

        # 1. Fetch Civitai info if the source points at a Civitai image.
        civitai_meta = None
        model_version_id = None

        source_url = recipe.get("source_url") or recipe.get("source_path", "")
        image_id_match = re.search(r'civitai\.com/images/(\d+)', str(source_url))
        if image_id_match:
            image_id = image_id_match.group(1)
            try:
                image_info = await civitai_client.get_image_info(image_id)
                if image_info:
                    # The API response sometimes nests the real metadata under meta["meta"].
                    raw_meta = image_info.get("meta")
                    if isinstance(raw_meta, dict):
                        nested_meta = raw_meta.get("meta")
                        civitai_meta = nested_meta if isinstance(nested_meta, dict) else raw_meta

                    model_version_id = image_info.get("modelVersionId")

                    # If not at top level, look for a checkpoint resource in the meta.
                    # The shared helper tolerates malformed / non-list resource data.
                    if not model_version_id and civitai_meta:
                        model_version_id = RecipeEnricher._extract_version_id_from_resources(
                            {"Civitai resources": civitai_meta.get("civitaiResources")}
                        )
            except Exception as e:
                # Best effort: enrichment continues with whatever was gathered so far.
                logger.warning("Failed to fetch Civitai image info: %s", e)

        # 2. Merge parameters.
        # Priority: request_params > civitai_meta > embedded (existing gen_params)
        new_gen_params = GenParamsMerger.merge(
            request_params=request_params,
            civitai_meta=civitai_meta,
            embedded_metadata=gen_params
        )
        if new_gen_params != gen_params:
            recipe["gen_params"] = new_gen_params
            updated = True

        # 3. Checkpoint enrichment.
        # A checkpoint carrying both 'name' and 'id' (the Civitai version id) is
        # treated as already enriched.
        checkpoint_entry = recipe.get("checkpoint")
        has_full_checkpoint = bool(
            checkpoint_entry and checkpoint_entry.get("name") and checkpoint_entry.get("id")
        )

        if not has_full_checkpoint:
            # Lookups use the pre-merge gen_params so blacklisted keys are still visible.
            sources = [request_params, civitai_meta, gen_params]

            target_version_id = model_version_id or RecipeEnricher._lookup_first(sources, "modelVersionId")

            # Also check the existing (partial) checkpoint entry.
            if not target_version_id and checkpoint_entry:
                target_version_id = checkpoint_entry.get("modelVersionId") or checkpoint_entry.get("id")

            # Finally look for a version id inside a resources list (may be a JSON string).
            if not target_version_id:
                resources_val = RecipeEnricher._lookup_first(
                    sources, ["Civitai resources", "civitai_resources", "resources"]
                )
                if resources_val:
                    target_version_id = RecipeEnricher._extract_version_id_from_resources(
                        {"Civitai resources": resources_val}
                    )

            target_hash = RecipeEnricher._normalize_hash(
                RecipeEnricher._lookup_first(sources, ["Model hash", "checkpoint_hash", "hashes"])
            )
            if not target_hash and checkpoint_entry:
                target_hash = checkpoint_entry.get("hash") or checkpoint_entry.get("model_hash")

            # 'Model' sometimes holds the hash or the name.
            model_val = RecipeEnricher._lookup_first(sources, "Model")

            # Checkpoint name fallback.
            checkpoint_val = checkpoint_entry.get("name") if checkpoint_entry else None
            if not checkpoint_val:
                checkpoint_val = RecipeEnricher._lookup_first(sources, ["Checkpoint", "checkpoint"])

            checkpoint_updated = await RecipeEnricher._resolve_and_populate_checkpoint(
                recipe, target_version_id, target_hash, model_val, checkpoint_val
            )
            if checkpoint_updated:
                # Sync to gen_params for consistency with legacy usage.
                recipe.setdefault("gen_params", {})["checkpoint"] = recipe["checkpoint"]
                updated = True
        else:
            # The checkpoint is complete; only make sure gen_params mirrors it.
            synced_params = recipe.setdefault("gen_params", {})
            if "checkpoint" not in synced_params:
                synced_params["checkpoint"] = recipe["checkpoint"]
                updated = True

        # 4. If base_model is empty or generic, prefer what the checkpoint reports.
        current_base_model = recipe.get("base_model")
        checkpoint_after = recipe.get("checkpoint")
        if checkpoint_after and checkpoint_after.get("baseModel"):
            resolved_base_model = checkpoint_after["baseModel"]
            is_generic = (
                not current_base_model
                or str(current_base_model).lower() in RecipeEnricher._GENERIC_BASE_MODELS
            )
            if is_generic and resolved_base_model != current_base_model:
                recipe["base_model"] = resolved_base_model
                updated = True

        return updated

    @staticmethod
    def _lookup_first(sources: Any, keys: Any) -> Optional[Any]:
        """Return the value of the first matching key, scanning sources in priority order.

        Args:
            sources: Iterable of dicts (or None / empty values, which are skipped),
                highest priority first.
            keys: A single key or a list/tuple of alternative keys, tried in order
                within each source.

        Returns:
            The stored value of the first key found (returned as-is, even if it is
            empty), or None when no source contains any of the keys.
        """
        candidates = keys if isinstance(keys, (list, tuple)) else [keys]
        for source in sources:
            if not source or not isinstance(source, dict):
                continue
            for key in candidates:
                if key in source:
                    return source[key]
        return None

    @staticmethod
    def _normalize_hash(value: Any) -> Optional[str]:
        """Reduce a looked-up hash value to a string, or None if it is not usable.

        The "hashes" key can hold a mapping rather than a plain string; passing a
        mapping on to the metadata provider as a hash would be meaningless.
        """
        if isinstance(value, dict):
            # NOTE(review): presumably an A1111-style mapping whose checkpoint hash
            # lives under "model" - confirm against real Civitai metadata.
            value = value.get("model")
        return value if isinstance(value, str) and value else None

    @staticmethod
    def _extract_version_id_from_resources(gen_params: Dict[str, Any]) -> Optional[Any]:
        """Try to find modelVersionId in Civitai resources parameter.

        Args:
            gen_params: Dict whose "Civitai resources" entry is either a list of
                resource dicts or a JSON string encoding such a list.

        Returns:
            The ``modelVersionId`` of the first resource whose ``type`` is
            "checkpoint", or None if the entry is missing, cannot be parsed, is not
            a list, or contains no checkpoint resource.
        """
        civitai_resources_raw = gen_params.get("Civitai resources")
        if not civitai_resources_raw:
            return None

        resources_list = civitai_resources_raw
        if isinstance(civitai_resources_raw, str):
            try:
                resources_list = json.loads(civitai_resources_raw)
            except (ValueError, TypeError):
                # Not valid JSON (json.JSONDecodeError is a ValueError): nothing to extract.
                return None

        if not isinstance(resources_list, list):
            return None

        for res in resources_list:
            # Skip malformed entries instead of failing on a missing .get().
            if isinstance(res, dict) and res.get("type") == "checkpoint":
                return res.get("modelVersionId")
        return None

    @staticmethod
    async def _resolve_and_populate_checkpoint(
        recipe: Dict[str, Any],
        target_version_id: Optional[Any],
        target_hash: Optional[str],
        model_val: Optional[str],
        checkpoint_val: Optional[str]
    ) -> bool:
        """Find checkpoint metadata and populate it in the recipe.

        Lookup order: model version id, then hash, then ``model_val`` when it is a
        10-character string (likely a short hash). If no metadata is found, a
        minimal checkpoint entry is built from ``checkpoint_val`` unless the recipe
        already has a named checkpoint.

        Args:
            recipe: Recipe dict; ``recipe["checkpoint"]`` is written in place.
            target_version_id: Civitai model version id, if known.
            target_hash: Checkpoint hash, if known.
            model_val: Value of the 'Model' parameter (hash or name), if any.
            checkpoint_val: Checkpoint file/display name used as a last resort.

        Returns:
            bool: True if ``recipe["checkpoint"]`` was written, False otherwise.
        """
        metadata_provider = await get_default_metadata_provider()
        civitai_info = None

        if target_version_id:
            civitai_info = await metadata_provider.get_model_version_info(str(target_version_id))
        elif target_hash:
            civitai_info = await metadata_provider.get_model_by_hash(target_hash)
        elif isinstance(model_val, str) and len(model_val) == 10:  # Likely a short hash
            civitai_info = await metadata_provider.get_model_by_hash(model_val)

        not_found = isinstance(civitai_info, tuple) and civitai_info[1] == "Model not found"
        if civitai_info and not not_found:
            # If we already have a partial checkpoint, use it as base.
            existing_cp = recipe.get("checkpoint")
            if existing_cp is None:
                existing_cp = {}
            checkpoint_data = await RecipeMetadataParser.populate_checkpoint_from_civitai(existing_cp, civitai_info)
            recipe["checkpoint"] = checkpoint_data

            # Ensure the modelVersionId is stored if we found it.
            if target_version_id and "modelVersionId" not in recipe["checkpoint"]:
                try:
                    version_id = int(target_version_id)
                except (TypeError, ValueError):
                    # Keep a non-numeric id verbatim rather than failing the enrichment.
                    version_id = target_version_id
                recipe["checkpoint"]["modelVersionId"] = version_id
            return True

        # Fallback to name extraction if we don't already have a named checkpoint.
        existing_cp = recipe.get("checkpoint")
        if existing_cp and existing_cp.get("name"):
            return False

        # Only a real string can serve as a name; gen_params["checkpoint"] may hold
        # a previously synced checkpoint dict, which must not be used here.
        if isinstance(checkpoint_val, str) and checkpoint_val:
            recipe["checkpoint"] = {
                "type": "checkpoint",
                "name": checkpoint_val,
                "modelName": checkpoint_val,
                "file_name": os.path.splitext(checkpoint_val)[0]
            }
            return True

        return False
|
||||
@@ -6,7 +6,31 @@ logger = logging.getLogger(__name__)
|
||||
class GenParamsMerger:
|
||||
"""Utility to merge generation parameters from multiple sources with priority."""
|
||||
|
||||
BLACKLISTED_KEYS = {"id", "url", "userId", "username", "createdAt", "updatedAt", "hash"}
|
||||
BLACKLISTED_KEYS = {
|
||||
"id", "url", "userId", "username", "createdAt", "updatedAt", "hash", "meta",
|
||||
"draft", "extra", "width", "height", "process", "quantity", "workflow",
|
||||
"baseModel", "resources", "disablePoi", "aspectRatio", "Created Date",
|
||||
"experimental", "civitaiResources", "civitai_resources", "Civitai resources",
|
||||
"modelVersionId", "modelId", "hashes", "Model", "Model hash", "checkpoint_hash",
|
||||
"checksum", "model_checksum"
|
||||
}
|
||||
|
||||
NORMALIZATION_MAPPING = {
|
||||
# Civitai specific
|
||||
"cfgScale": "cfg_scale",
|
||||
"clipSkip": "clip_skip",
|
||||
"negativePrompt": "negative_prompt",
|
||||
# Case variations
|
||||
"Sampler": "sampler",
|
||||
"Steps": "steps",
|
||||
"Seed": "seed",
|
||||
"Size": "size",
|
||||
"Prompt": "prompt",
|
||||
"Negative prompt": "negative_prompt",
|
||||
"Cfg scale": "cfg_scale",
|
||||
"Clip skip": "clip_skip",
|
||||
"Denoising strength": "denoising_strength",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def merge(
|
||||
@@ -33,18 +57,42 @@ class GenParamsMerger:
|
||||
if embedded_metadata:
|
||||
# If it's a full recipe metadata, we use its gen_params
|
||||
if "gen_params" in embedded_metadata and isinstance(embedded_metadata["gen_params"], dict):
|
||||
result.update(embedded_metadata["gen_params"])
|
||||
GenParamsMerger._update_normalized(result, embedded_metadata["gen_params"])
|
||||
else:
|
||||
# Otherwise assume the dict itself contains gen_params
|
||||
result.update(embedded_metadata)
|
||||
GenParamsMerger._update_normalized(result, embedded_metadata)
|
||||
|
||||
# 2. Layer Civitai meta (medium priority)
|
||||
if civitai_meta:
|
||||
result.update(civitai_meta)
|
||||
GenParamsMerger._update_normalized(result, civitai_meta)
|
||||
|
||||
# 3. Layer request params (highest priority)
|
||||
if request_params:
|
||||
result.update(request_params)
|
||||
GenParamsMerger._update_normalized(result, request_params)
|
||||
|
||||
# Filter out blacklisted keys
|
||||
return {k: v for k, v in result.items() if k not in GenParamsMerger.BLACKLISTED_KEYS}
|
||||
# Filter out blacklisted keys and also the original camelCase keys if they were normalized
|
||||
final_result = {}
|
||||
for k, v in result.items():
|
||||
if k in GenParamsMerger.BLACKLISTED_KEYS:
|
||||
continue
|
||||
if k in GenParamsMerger.NORMALIZATION_MAPPING:
|
||||
continue
|
||||
final_result[k] = v
|
||||
|
||||
return final_result
|
||||
|
||||
@staticmethod
def _update_normalized(target: Dict[str, Any], source: Dict[str, Any]) -> None:
    """Update target dict with normalized keys from source.

    Each key of ``source`` is translated through
    ``GenParamsMerger.NORMALIZATION_MAPPING`` (keys without a mapping are kept
    unchanged) and its value is written to ``target`` under the translated key.
    The value is also written under the original key, so later layers override
    earlier ones under either spelling.

    Args:
        target: Dict that is updated in place.
        source: Dict whose entries are copied into ``target``.
    """
    mapping = GenParamsMerger.NORMALIZATION_MAPPING
    for key, value in source.items():
        target[mapping.get(key, key)] = value
        # Redundant when the key has no alias; for aliased keys this keeps the
        # original spelling alongside the normalized one.
        target[key] = value
|
||||
|
||||
@@ -27,6 +27,7 @@ from ...services.metadata_service import get_default_metadata_provider
|
||||
from ...utils.civitai_utils import rewrite_preview_url
|
||||
from ...utils.exif_utils import ExifUtils
|
||||
from ...recipes.merger import GenParamsMerger
|
||||
from ...recipes.enrichment import RecipeEnricher
|
||||
from ...services.websocket_manager import ws_manager as default_ws_manager
|
||||
|
||||
Logger = logging.Logger
|
||||
@@ -585,13 +586,15 @@ class RecipeManagementHandler:
|
||||
self._logger.error("Error getting repair progress: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
async def import_remote_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
|
||||
# 1. Parse Parameters
|
||||
params = request.rel_url.query
|
||||
image_url = params.get("image_url")
|
||||
name = params.get("name")
|
||||
@@ -605,30 +608,52 @@ class RecipeManagementHandler:
|
||||
raise RecipeValidationError("Missing required field: resources")
|
||||
|
||||
checkpoint_entry, lora_entries = self._parse_resources_payload(resources_raw)
|
||||
gen_params = self._parse_gen_params(params.get("gen_params"))
|
||||
gen_params_request = self._parse_gen_params(params.get("gen_params"))
|
||||
|
||||
# 2. Initial Metadata Construction
|
||||
metadata: Dict[str, Any] = {
|
||||
"base_model": params.get("base_model", "") or "",
|
||||
"loras": lora_entries,
|
||||
"gen_params": gen_params_request or {},
|
||||
"source_url": image_url
|
||||
}
|
||||
|
||||
source_path = params.get("source_path")
|
||||
if source_path:
|
||||
metadata["source_path"] = source_path
|
||||
if gen_params is not None:
|
||||
metadata["gen_params"] = gen_params
|
||||
|
||||
# Checkpoint handling
|
||||
if checkpoint_entry:
|
||||
metadata["checkpoint"] = checkpoint_entry
|
||||
gen_params_ref = metadata.setdefault("gen_params", {})
|
||||
if "checkpoint" not in gen_params_ref:
|
||||
gen_params_ref["checkpoint"] = checkpoint_entry
|
||||
base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry)
|
||||
if base_model_from_metadata:
|
||||
metadata["base_model"] = base_model_from_metadata
|
||||
# Ensure checkpoint is also in gen_params for consistency if needed by enricher?
|
||||
# Actually enricher looks at metadata['checkpoint'], so this is fine.
|
||||
|
||||
# Try to resolve base model from checkpoint if not explicitly provided
|
||||
if not metadata["base_model"]:
|
||||
base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry)
|
||||
if base_model_from_metadata:
|
||||
metadata["base_model"] = base_model_from_metadata
|
||||
|
||||
tags = self._parse_tags(params.get("tags"))
|
||||
image_bytes, extension, civitai_meta = await self._download_remote_media(image_url)
|
||||
|
||||
# 3. Download Image
|
||||
image_bytes, extension, civitai_meta_from_download = await self._download_remote_media(image_url)
|
||||
|
||||
# Extract embedded metadata from the downloaded image
|
||||
embedded_metadata = None
|
||||
# 4. Extract Embedded Metadata
|
||||
# Note: We still extract this here because Enricher currently expects 'gen_params' to already be populated
|
||||
# with embedded data if we want it to merge it.
|
||||
# However, logic in Enricher merges: request > civitai > embedded.
|
||||
# So we should gather embedded params and put them into the recipe's gen_params (as initial state)
|
||||
# OR pass them to enricher to handle?
|
||||
# The interface of Enricher.enrich_recipe takes `recipe` (with gen_params) and `request_params`.
|
||||
# So let's extract embedded and put it into recipe['gen_params'] but careful not to overwrite request params.
|
||||
# Actually, `GenParamsMerger` which `Enricher` uses handles 3 layers.
|
||||
# But `Enricher` interface is: recipe['gen_params'] (as embedded) + request_params + civitai (fetched internally).
|
||||
# Wait, `Enricher` fetches Civitai info internally based on URL.
|
||||
# `civitai_meta_from_download` is returned by `_download_remote_media` which might be useful if URL didn't have ID.
|
||||
|
||||
# Let's extract embedded metadata first
|
||||
embedded_gen_params = {}
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_img:
|
||||
temp_img.write(image_bytes)
|
||||
@@ -637,29 +662,39 @@ class RecipeManagementHandler:
|
||||
try:
|
||||
raw_embedded = ExifUtils.extract_image_metadata(temp_img_path)
|
||||
if raw_embedded:
|
||||
# Try to parse it using standard parsers if it looks like a recipe
|
||||
parser = self._analysis_service._recipe_parser_factory.create_parser(raw_embedded)
|
||||
if parser:
|
||||
parsed_embedded = await parser.parse_metadata(raw_embedded, recipe_scanner=recipe_scanner)
|
||||
embedded_metadata = parsed_embedded
|
||||
if parsed_embedded and "gen_params" in parsed_embedded:
|
||||
embedded_gen_params = parsed_embedded["gen_params"]
|
||||
else:
|
||||
# Fallback to raw string if no parser matches (might be simple params)
|
||||
embedded_metadata = {"gen_params": {"raw_metadata": raw_embedded}}
|
||||
embedded_gen_params = {"raw_metadata": raw_embedded}
|
||||
finally:
|
||||
if os.path.exists(temp_img_path):
|
||||
os.unlink(temp_img_path)
|
||||
except Exception as exc:
|
||||
self._logger.warning("Failed to extract embedded metadata during import: %s", exc)
|
||||
|
||||
# Merge gen_params from all sources
|
||||
merged_gen_params = GenParamsMerger.merge(
|
||||
request_params=gen_params,
|
||||
civitai_meta=civitai_meta,
|
||||
embedded_metadata=embedded_metadata
|
||||
# Pre-populate gen_params with embedded data so Enricher treats it as the "base" layer
|
||||
if embedded_gen_params:
|
||||
# Merge embedded into existing gen_params (which currently only has request params if any)
|
||||
# But wait, we want request params to override everything.
|
||||
# So we should set recipe['gen_params'] = embedded, and pass request params to enricher.
|
||||
metadata["gen_params"] = embedded_gen_params
|
||||
|
||||
# 5. Enrich with unified logic
|
||||
# This will fetch Civitai info (if URL matches) and merge: request > civitai > embedded
|
||||
civitai_client = self._civitai_client_getter()
|
||||
await RecipeEnricher.enrich_recipe(
|
||||
recipe=metadata,
|
||||
civitai_client=civitai_client,
|
||||
request_params=gen_params_request # Pass explicit request params here to override
|
||||
)
|
||||
|
||||
if merged_gen_params:
|
||||
metadata["gen_params"] = merged_gen_params
|
||||
|
||||
# If we got civitai_meta from download but Enricher didn't fetch it (e.g. not a civitai URL or failed),
|
||||
# we might want to manually merge it?
|
||||
# But usually `import_remote_recipe` is used with Civitai URLs.
|
||||
# For now, relying on Enricher's internal fetch is consistent with repair.
|
||||
|
||||
result = await self._persistence_service.save_recipe(
|
||||
recipe_scanner=recipe_scanner,
|
||||
|
||||
@@ -18,6 +18,7 @@ from natsort import natsorted
|
||||
import sys
|
||||
import re
|
||||
from ..recipes.merger import GenParamsMerger
|
||||
from ..recipes.enrichment import RecipeEnricher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -184,18 +185,22 @@ class RecipeScanner:
|
||||
Dict summary of repair result
|
||||
"""
|
||||
async with self._mutation_lock:
|
||||
recipe = await self.get_recipe_by_id(recipe_id)
|
||||
# Get raw recipe from cache directly to avoid formatted fields
|
||||
cache = await self.get_cached_data()
|
||||
recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None)
|
||||
|
||||
if not recipe:
|
||||
raise RecipeNotFoundError(f"Recipe {recipe_id} not found")
|
||||
|
||||
civitai_client = await self._get_civitai_client()
|
||||
success = await self._repair_single_recipe(recipe, civitai_client)
|
||||
|
||||
# If successfully repaired, we should return the formatted version for the UI
|
||||
return {
|
||||
"success": True,
|
||||
"repaired": 1 if success else 0,
|
||||
"skipped": 0 if success else 1,
|
||||
"recipe": recipe
|
||||
"recipe": await self.get_recipe_by_id(recipe_id) if success else recipe
|
||||
}
|
||||
|
||||
async def _repair_single_recipe(self, recipe: Dict[str, Any], civitai_client: Any) -> bool:
|
||||
@@ -221,68 +226,28 @@ class RecipeScanner:
|
||||
|
||||
if not needs_repair:
|
||||
# Even if no repair needed, we mark it with version if it was processed
|
||||
if "repair_version" not in recipe:
|
||||
recipe["repair_version"] = self.REPAIR_VERSION
|
||||
await self._save_recipe_persistently(recipe)
|
||||
return True
|
||||
return False
|
||||
# Always update and save because if we are here, the version is old (checked in step 1)
|
||||
recipe["repair_version"] = self.REPAIR_VERSION
|
||||
await self._save_recipe_persistently(recipe)
|
||||
return True
|
||||
|
||||
# 3. Data Fetching & Merging
|
||||
source_url = recipe.get("source_url", "")
|
||||
civitai_meta = None
|
||||
model_version_id = None
|
||||
|
||||
# Check if it's a Civitai image URL
|
||||
image_id_match = re.search(r'civitai\.com/images/(\d+)', source_url)
|
||||
if image_id_match:
|
||||
image_id = image_id_match.group(1)
|
||||
image_info = await civitai_client.get_image_info(image_id)
|
||||
if image_info:
|
||||
if "meta" in image_info:
|
||||
civitai_meta = image_info["meta"]
|
||||
model_version_id = image_info.get("modelVersionId")
|
||||
|
||||
# Merge with existing data
|
||||
new_gen_params = GenParamsMerger.merge(
|
||||
civitai_meta=civitai_meta,
|
||||
embedded_metadata=gen_params
|
||||
)
|
||||
|
||||
updated = False
|
||||
if new_gen_params != gen_params:
|
||||
recipe["gen_params"] = new_gen_params
|
||||
updated = True
|
||||
# 3. Use Enricher to repair/enrich
|
||||
try:
|
||||
updated = await RecipeEnricher.enrich_recipe(recipe, civitai_client)
|
||||
except Exception as e:
|
||||
logger.error(f"Error enriching recipe {recipe.get('id')}: {e}")
|
||||
updated = False
|
||||
|
||||
# 4. Mark version and save if updated or just marking version
|
||||
# If we updated it, OR if the version is old (which we know it is if we are here), save it.
|
||||
# Actually, if we are here and updated is False, it means we tried to repair but couldn't/didn't need to.
|
||||
# But we still want to mark it as processed so we don't try again until version bump.
|
||||
if updated or recipe.get("repair_version", 0) < self.REPAIR_VERSION:
|
||||
recipe["repair_version"] = self.REPAIR_VERSION
|
||||
await self._save_recipe_persistently(recipe)
|
||||
return True
|
||||
|
||||
# 4. Update checkpoint if missing or repairable
|
||||
if not has_checkpoint:
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
target_version_id = model_version_id or new_gen_params.get("modelVersionId")
|
||||
target_hash = new_gen_params.get("Model hash")
|
||||
|
||||
civitai_info = None
|
||||
if target_version_id:
|
||||
civitai_info = await metadata_provider.get_model_version_info(str(target_version_id))
|
||||
elif target_hash:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(target_hash)
|
||||
|
||||
if civitai_info and not (isinstance(civitai_info, tuple) and civitai_info[1] == "Model not found"):
|
||||
recipe["checkpoint"] = await self._populate_checkpoint(civitai_info)
|
||||
updated = True
|
||||
else:
|
||||
# Fallback to name extraction
|
||||
cp_name = new_gen_params.get("Checkpoint") or new_gen_params.get("checkpoint")
|
||||
if cp_name:
|
||||
recipe["checkpoint"] = {
|
||||
"name": cp_name,
|
||||
"file_name": os.path.splitext(cp_name)[0]
|
||||
}
|
||||
updated = True
|
||||
|
||||
# 5. Mark version and save
|
||||
recipe["repair_version"] = self.REPAIR_VERSION
|
||||
await self._save_recipe_persistently(recipe)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _save_recipe_persistently(self, recipe: Dict[str, Any]) -> bool:
|
||||
"""Helper to save a recipe to both JSON and EXIF metadata."""
|
||||
@@ -318,58 +283,16 @@ class RecipeScanner:
|
||||
logger.error(f"Error persisting recipe {recipe_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _populate_checkpoint(self, civitai_info_tuple: Any) -> Dict[str, Any]:
|
||||
"""Helper to populate checkpoint info using common logic."""
|
||||
civitai_data, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||
|
||||
checkpoint = {
|
||||
"name": "",
|
||||
"file_name": "",
|
||||
"isDeleted": False,
|
||||
"hash": ""
|
||||
}
|
||||
|
||||
if not civitai_data or error_msg == "Model not found":
|
||||
checkpoint["isDeleted"] = True
|
||||
return checkpoint
|
||||
|
||||
try:
|
||||
if "model" in civitai_data and "name" in civitai_data["model"]:
|
||||
checkpoint["name"] = civitai_data["model"]["name"]
|
||||
|
||||
if "name" in civitai_data:
|
||||
checkpoint["version"] = civitai_data.get("name", "")
|
||||
|
||||
if "images" in civitai_data and civitai_data["images"]:
|
||||
from ..utils.civitai_utils import rewrite_preview_url
|
||||
image_url = civitai_data["images"][0].get("url")
|
||||
if image_url:
|
||||
rewritten_url, _ = rewrite_preview_url(image_url, media_type="image")
|
||||
checkpoint["thumbnailUrl"] = rewritten_url or image_url
|
||||
|
||||
checkpoint["baseModel"] = civitai_data.get("baseModel", "")
|
||||
checkpoint["modelId"] = civitai_data.get("modelId", 0)
|
||||
checkpoint["id"] = civitai_data.get("id", 0)
|
||||
|
||||
if "files" in civitai_data:
|
||||
model_file = next((f for f in civitai_data.get("files", []) if f.get("type") == "Model"), None)
|
||||
if model_file:
|
||||
sha256 = model_file.get("hashes", {}).get("SHA256")
|
||||
if sha256:
|
||||
checkpoint["hash"] = sha256.lower()
|
||||
f_name = model_file.get("name", "")
|
||||
if f_name:
|
||||
checkpoint["file_name"] = os.path.splitext(f_name)[0]
|
||||
except Exception as e:
|
||||
logger.error(f"Error populating checkpoint: {e}")
|
||||
|
||||
return checkpoint
|
||||
|
||||
def _sanitize_recipe_for_storage(self, recipe: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create a clean copy of the recipe without runtime convenience fields."""
|
||||
import copy
|
||||
clean = copy.deepcopy(recipe)
|
||||
|
||||
# 0. Clean top-level runtime fields
|
||||
for key in ("file_url", "created_date_formatted", "modified_formatted"):
|
||||
clean.pop(key, None)
|
||||
|
||||
# 1. Clean LORAs
|
||||
if "loras" in clean and isinstance(clean["loras"], list):
|
||||
for lora in clean["loras"]:
|
||||
|
||||
@@ -338,7 +338,7 @@ async def test_move_recipe_invokes_persistence(monkeypatch, tmp_path: Path) -> N
|
||||
|
||||
|
||||
async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
provider_calls: list[int] = []
|
||||
provider_calls: list[str | int] = []
|
||||
|
||||
class Provider:
|
||||
async def get_model_version_info(self, model_version_id):
|
||||
@@ -348,7 +348,7 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
async def fake_get_default_metadata_provider():
|
||||
return Provider()
|
||||
|
||||
monkeypatch.setattr(recipe_handlers, "get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
resources = [
|
||||
@@ -390,7 +390,7 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
assert call["tags"] == ["foo", "bar"]
|
||||
metadata = call["metadata"]
|
||||
assert metadata["base_model"] == "Flux Provider"
|
||||
assert provider_calls == [33]
|
||||
assert provider_calls == ["33"]
|
||||
assert metadata["checkpoint"]["modelVersionId"] == 33
|
||||
assert metadata["loras"][0]["weight"] == 0.25
|
||||
assert metadata["gen_params"]["prompt"] == "hello world"
|
||||
@@ -399,7 +399,7 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
|
||||
|
||||
async def test_import_remote_recipe_falls_back_to_request_base_model(monkeypatch, tmp_path: Path) -> None:
|
||||
provider_calls: list[int] = []
|
||||
provider_calls: list[str | int] = []
|
||||
|
||||
class Provider:
|
||||
async def get_model_version_info(self, model_version_id):
|
||||
@@ -409,7 +409,7 @@ async def test_import_remote_recipe_falls_back_to_request_base_model(monkeypatch
|
||||
async def fake_get_default_metadata_provider():
|
||||
return Provider()
|
||||
|
||||
monkeypatch.setattr(recipe_handlers, "get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
resources = [
|
||||
@@ -438,14 +438,14 @@ async def test_import_remote_recipe_falls_back_to_request_base_model(monkeypatch
|
||||
|
||||
metadata = harness.persistence.save_calls[-1]["metadata"]
|
||||
assert metadata["base_model"] == "Flux"
|
||||
assert provider_calls == [77]
|
||||
assert provider_calls == ["77"]
|
||||
|
||||
|
||||
async def test_import_remote_video_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
async def fake_get_default_metadata_provider():
|
||||
return SimpleNamespace(get_model_version_info=lambda id: ({}, None))
|
||||
|
||||
monkeypatch.setattr(recipe_handlers, "get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
harness.civitai.image_info["12345"] = {
|
||||
@@ -537,7 +537,7 @@ async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path)
|
||||
async def fake_get_default_metadata_provider():
|
||||
return Provider()
|
||||
|
||||
monkeypatch.setattr(recipe_handlers, "get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
|
||||
# 2. Mock ExifUtils to return some embedded metadata
|
||||
class MockExifUtils:
|
||||
|
||||
@@ -57,3 +57,38 @@ def test_merge_filters_blacklisted_keys():
|
||||
assert "id" not in merged
|
||||
assert "url" not in merged
|
||||
assert "hash" not in merged
|
||||
|
||||
def test_merge_filters_meta_and_normalizes_keys():
    """merge() drops filtered keys and exposes aliased keys under normalized names."""
    civitai_meta = {
        "prompt": "masterpiece",
        "cfgScale": 5,
        "clipSkip": 2,
        "negativePrompt": "low quality",
        "meta": {"irrelevant": "data"},
        "Size": "1024x1024",
        "draft": False,
        "workflow": "txt2img",
        "civitaiResources": [{"type": "checkpoint"}]
    }
    request_params = {
        "cfg_scale": 5.0,
        "clip_skip": "2",
        "Steps": 30
    }

    merged = GenParamsMerger.merge(request_params, civitai_meta)

    # Neither the filtered keys nor the un-normalized spellings may survive the merge.
    absent_keys = (
        "meta",
        "cfgScale",
        "clipSkip",
        "negativePrompt",
        "Size",
        "draft",
        "workflow",
        "civitaiResources",
    )
    for absent in absent_keys:
        assert absent not in merged

    # Values supplied by request_params win over the Civitai ones.
    assert merged["cfg_scale"] == 5.0
    assert merged["clip_skip"] == "2"
    # Values only present in civitai_meta appear under their normalized key.
    assert merged["negative_prompt"] == "low quality"
    assert merged["size"] == "1024x1024"
    # A request_params key is normalized as well.
    assert merged["steps"] == 30
|
||||
|
||||
@@ -43,7 +43,7 @@ def setup_scanner(recipe_scanner, mock_civitai_client, mock_metadata_provider, m
|
||||
mock_save = AsyncMock(side_effect=real_save)
|
||||
monkeypatch.setattr(recipe_scanner, "_save_recipe_persistently", mock_save)
|
||||
|
||||
monkeypatch.setattr("py.services.recipe_scanner.get_default_metadata_provider", AsyncMock(return_value=mock_metadata_provider))
|
||||
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", AsyncMock(return_value=mock_metadata_provider))
|
||||
|
||||
# Mock get_recipe_json_path to avoid file system issues in tests
|
||||
recipe_scanner.get_recipe_json_path = AsyncMock(return_value="/tmp/test_recipe.json")
|
||||
@@ -259,11 +259,7 @@ async def test_repair_all_recipes_strips_runtime_fields(setup_scanner):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sanitize_recipe_for_storage(recipe_scanner):
|
||||
import sys
|
||||
import py.services.recipe_scanner
|
||||
print(f"\nDEBUG_ENV: sys.path: {sys.path}")
|
||||
print(f"DEBUG_ENV: recipe_scanner file: {py.services.recipe_scanner.__file__}")
|
||||
|
||||
|
||||
recipe = {
|
||||
"loras": [{"name": "L1", "inLibrary": True, "weight": 0.5}],
|
||||
"checkpoint": {"name": "CP", "localPath": "/tmp/cp"}
|
||||
|
||||
Reference in New Issue
Block a user