Merge branch 'sort-by-usage-count' into main

Authored by pixelpaws on 2025-12-26 22:17:03 +08:00; committed by GitHub.
85 changed files with 6030 additions and 1550 deletions

View File

@@ -1,11 +1,13 @@
import os
import platform
import threading
from pathlib import Path
import folder_paths # type: ignore
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple
import logging
import json
import urllib.parse
import time
from .utils.settings_paths import ensure_settings_file, get_settings_dir, load_settings_template
@@ -80,6 +82,8 @@ class Config:
self._path_mappings: Dict[str, str] = {}
# Normalized preview root directories used to validate preview access
self._preview_root_paths: Set[Path] = set()
# Optional background rescan thread
self._rescan_thread: Optional[threading.Thread] = None
self.loras_roots = self._init_lora_paths()
self.checkpoints_roots = None
self.unet_roots = None
@@ -282,58 +286,25 @@ class Config:
def _load_symlink_cache(self) -> bool:
cache_path = self._get_symlink_cache_path()
if not cache_path.exists():
logger.info("Symlink cache not found at %s", cache_path)
return False
try:
with cache_path.open("r", encoding="utf-8") as handle:
payload = json.load(handle)
except Exception as exc:
logger.debug("Failed to load symlink cache %s: %s", cache_path, exc)
logger.info("Failed to load symlink cache %s: %s", cache_path, exc)
return False
if not isinstance(payload, dict):
logger.info("Symlink cache payload is not a dict: %s", type(payload))
return False
cached_fingerprint = payload.get("fingerprint")
cached_mappings = payload.get("path_mappings")
if not isinstance(cached_fingerprint, dict) or not isinstance(cached_mappings, Mapping):
if not isinstance(cached_mappings, Mapping):
logger.info("Symlink cache missing path mappings")
return False
current_fingerprint = self._build_symlink_fingerprint()
cached_roots = cached_fingerprint.get("roots")
cached_stats = cached_fingerprint.get("stats")
if (
not isinstance(cached_roots, list)
or not isinstance(cached_stats, Mapping)
or sorted(cached_roots) != sorted(current_fingerprint["roots"]) # type: ignore[index]
):
return False
for root in current_fingerprint["roots"]: # type: ignore[assignment]
cached_stat = cached_stats.get(root) if isinstance(cached_stats, Mapping) else None
current_stat = current_fingerprint["stats"].get(root) # type: ignore[index]
if not isinstance(cached_stat, Mapping) or not current_stat:
return False
cached_mtime = cached_stat.get("mtime_ns")
cached_inode = cached_stat.get("inode")
current_mtime = current_stat.get("mtime_ns")
current_inode = current_stat.get("inode")
if cached_inode != current_inode:
return False
if cached_mtime != current_mtime:
cached_noise = cached_stat.get("noise_mtime_ns")
current_noise = current_stat.get("noise_mtime_ns")
if not (
cached_noise
and current_noise
and cached_mtime == cached_noise
and current_mtime == current_noise
):
return False
normalized_mappings: Dict[str, str] = {}
for target, link in cached_mappings.items():
if not isinstance(target, str) or not isinstance(link, str):
@@ -341,6 +312,7 @@ class Config:
normalized_mappings[self._normalize_path(target)] = self._normalize_path(link)
self._path_mappings = normalized_mappings
logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings))
return True
def _save_symlink_cache(self) -> None:
@@ -353,22 +325,75 @@ class Config:
try:
with cache_path.open("w", encoding="utf-8") as handle:
json.dump(payload, handle, ensure_ascii=False, indent=2)
logger.info("Symlink cache saved to %s with %d mappings", cache_path, len(self._path_mappings))
except Exception as exc:
logger.debug("Failed to write symlink cache %s: %s", cache_path, exc)
logger.info("Failed to write symlink cache %s: %s", cache_path, exc)
def _initialize_symlink_mappings(self) -> None:
if not self._load_symlink_cache():
self._scan_symbolic_links()
self._save_symlink_cache()
else:
logger.info("Loaded symlink mappings from cache")
start = time.perf_counter()
cache_loaded = self._load_symlink_cache()
if cache_loaded:
logger.info(
"Symlink mappings restored from cache in %.2f ms",
(time.perf_counter() - start) * 1000,
)
self._rebuild_preview_roots()
self._schedule_symlink_rescan()
return
self._scan_symbolic_links()
self._save_symlink_cache()
self._rebuild_preview_roots()
logger.info(
"Symlink mappings rebuilt and cached in %.2f ms",
(time.perf_counter() - start) * 1000,
)
def _scan_symbolic_links(self):
"""Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
start = time.perf_counter()
# Reset mappings before rescanning to avoid stale entries
self._path_mappings.clear()
self._seed_root_symlink_mappings()
visited_dirs: Set[str] = set()
for root in self._symlink_roots():
self._scan_directory_links(root, visited_dirs)
logger.info(
"Symlink scan finished in %.2f ms with %d mappings",
(time.perf_counter() - start) * 1000,
len(self._path_mappings),
)
def _schedule_symlink_rescan(self) -> None:
"""Trigger a best-effort background rescan to refresh stale caches."""
if self._rescan_thread and self._rescan_thread.is_alive():
return
def worker():
try:
self._scan_symbolic_links()
self._save_symlink_cache()
self._rebuild_preview_roots()
logger.info("Background symlink rescan completed")
except Exception as exc: # pragma: no cover - defensive logging
logger.info("Background symlink rescan failed: %s", exc)
thread = threading.Thread(
target=worker,
name="lora-manager-symlink-rescan",
daemon=True,
)
self._rescan_thread = thread
thread.start()
def _wait_for_rescan(self, timeout: Optional[float] = None) -> None:
"""Block until the background rescan completes (testing convenience)."""
thread = self._rescan_thread
if thread:
thread.join(timeout=timeout)
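For context, a minimal standalone sketch of the serve-from-cache-then-rescan pattern that _initialize_symlink_mappings and _schedule_symlink_rescan implement above (names, paths, and timings below are illustrative, not taken from this diff):

import threading
import time

path_mappings = {"loras": "/mnt/models/loras"}  # pretend this came from the on-disk cache

def rescan():
    # Stand-in for the real filesystem walk that refreshes stale mappings.
    time.sleep(0.1)
    path_mappings["loras"] = "/mnt/models/loras"

rescan_thread = threading.Thread(target=rescan, name="symlink-rescan", daemon=True)
rescan_thread.start()

print(path_mappings)   # callers read the cached mappings immediately
rescan_thread.join()   # tests can block on completion, like _wait_for_rescan()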
def _scan_directory_links(self, root: str, visited_dirs: Set[str]):
"""Iteratively scan directory symlinks to avoid deep recursion."""
@@ -434,6 +459,22 @@ class Config:
self._preview_root_paths.update(self._expand_preview_root(normalized_target))
self._preview_root_paths.update(self._expand_preview_root(normalized_link))
def _seed_root_symlink_mappings(self) -> None:
"""Ensure symlinked root folders are recorded before deep scanning."""
for root in self._symlink_roots():
if not root:
continue
try:
if not self._is_link(root):
continue
target_path = os.path.realpath(root)
if not os.path.isdir(target_path):
continue
self.add_path_mapping(root, target_path)
except Exception as exc:
logger.debug("Skipping root symlink %s: %s", root, exc)
def _expand_preview_root(self, path: str) -> Set[Path]:
"""Return normalized ``Path`` objects representing a preview root."""

View File

@@ -39,8 +39,39 @@ class MetadataProcessor:
if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False):
candidate_samplers[node_id] = metadata[SAMPLING][node_id]
# If we found candidate samplers, apply primary sampler logic to these candidates only
if candidate_samplers:
# If we found candidate samplers, apply primary sampler logic to these candidates only
# PRE-PROCESS: Ensure all candidate samplers have their parameters populated
# This is especially important for SamplerCustomAdvanced which needs tracing
prompt = metadata.get("current_prompt")
for node_id in candidate_samplers:
# If a sampler is missing common parameters like steps or denoise,
# try to populate them using tracing before ranking
sampler_info = candidate_samplers[node_id]
params = sampler_info.get("parameters", {})
if prompt and (params.get("steps") is None or params.get("denoise") is None):
# Create a temporary params dict to use the handler
temp_params = {
"steps": params.get("steps"),
"denoise": params.get("denoise"),
"sampler": params.get("sampler_name"),
"scheduler": params.get("scheduler")
}
# Check if it's SamplerCustomAdvanced
if prompt.original_prompt and node_id in prompt.original_prompt:
if prompt.original_prompt[node_id].get("class_type") == "SamplerCustomAdvanced":
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, node_id, temp_params)
# Update the actual parameters with found values
params["steps"] = temp_params.get("steps")
params["denoise"] = temp_params.get("denoise")
if temp_params.get("sampler"):
params["sampler_name"] = temp_params.get("sampler")
if temp_params.get("scheduler"):
params["scheduler"] = temp_params.get("scheduler")
# Collect potential primary samplers based on different criteria
custom_advanced_samplers = []
advanced_add_noise_samplers = []
@@ -49,7 +80,6 @@ class MetadataProcessor:
high_denoise_id = None
# First, check for SamplerCustomAdvanced among candidates
prompt = metadata.get("current_prompt")
if prompt and prompt.original_prompt:
for node_id in candidate_samplers:
node_info = prompt.original_prompt.get(node_id, {})
@@ -77,15 +107,16 @@ class MetadataProcessor:
# Combine all potential primary samplers
potential_samplers = custom_advanced_samplers + advanced_add_noise_samplers + high_denoise_samplers
# Find the most recent potential primary sampler (closest to downstream node)
for i in range(downstream_index - 1, -1, -1):
# Find the first potential primary sampler (prefer base sampler over refine)
# Use forward search to prioritize the first one in execution order
for i in range(downstream_index):
node_id = execution_order[i]
if node_id in potential_samplers:
return node_id, candidate_samplers[node_id]
# If no potential sampler found from our criteria, return the most recent sampler
# If no potential sampler found from our criteria, return the first sampler
if candidate_samplers:
for i in range(downstream_index - 1, -1, -1):
for i in range(downstream_index):
node_id = execution_order[i]
if node_id in candidate_samplers:
return node_id, candidate_samplers[node_id]
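A toy illustration (made-up node ids) of the search-order change above: the old backward scan picked the sampler closest to the downstream node (often the refiner), while the new forward scan prefers the first sampler in execution order (the base pass):

execution_order = ["4", "10", "22"]   # "4" = base KSampler, "10" = refiner, "22" = downstream save node
downstream_index = 2                  # index of the downstream node in execution_order
potential_samplers = {"4", "10"}

# Old behaviour: backward search from the downstream node returns the refiner ("10").
backward = next(execution_order[i] for i in range(downstream_index - 1, -1, -1)
                if execution_order[i] in potential_samplers)

# New behaviour: forward search returns the base sampler ("4").
forward = next(execution_order[i] for i in range(downstream_index)
               if execution_order[i] in potential_samplers)

assert (backward, forward) == ("10", "4")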
@@ -176,8 +207,11 @@ class MetadataProcessor:
found_node_id = input_value[0] # Connected node_id
# If we're looking for a specific node class
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
return found_node_id
if target_class:
if found_node_id not in prompt.original_prompt:
return None
if prompt.original_prompt[found_node_id].get("class_type") == target_class:
return found_node_id
# If we're not looking for a specific class, update the last valid node
if not target_class:
@@ -185,11 +219,19 @@ class MetadataProcessor:
# Continue tracing through intermediate nodes
current_node_id = found_node_id
# For most conditioning nodes, the input we want to follow is named "conditioning"
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
# Check if current source node exists
if current_node_id not in prompt.original_prompt:
return found_node_id if not target_class else None
# Determine which input to follow next on the source node
source_node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
if input_name in source_node_inputs:
current_input = input_name
elif "conditioning" in source_node_inputs:
current_input = "conditioning"
else:
# If there's no "conditioning" input, return the current node
# If there's no suitable input to follow, return the current node
# if we're not looking for a specific target_class
return found_node_id if not target_class else None
else:
@@ -202,12 +244,89 @@ class MetadataProcessor:
return last_valid_node if not target_class else None
@staticmethod
def find_primary_checkpoint(metadata):
"""Find the primary checkpoint model in the workflow"""
if not metadata.get(MODELS):
def trace_model_path(metadata, prompt, start_node_id):
"""
Trace the model connection path upstream to find the checkpoint
"""
if not prompt or not prompt.original_prompt:
return None
# In most workflows, there's only one checkpoint, so we can just take the first one
current_node_id = start_node_id
depth = 0
max_depth = 50
while depth < max_depth:
# Check if current node is a registered checkpoint in our metadata
# This handles cached nodes correctly because metadata contains info for all nodes in the graph
if current_node_id in metadata.get(MODELS, {}):
if metadata[MODELS][current_node_id].get("type") == "checkpoint":
return current_node_id
if current_node_id not in prompt.original_prompt:
return None
node = prompt.original_prompt[current_node_id]
inputs = node.get("inputs", {})
class_type = node.get("class_type", "")
# Determine which input to follow next
next_input_name = "model"
# Special handling for initial node
if depth == 0:
if class_type == "SamplerCustomAdvanced":
next_input_name = "guider"
# If the specific input doesn't exist, try generic 'model'
if next_input_name not in inputs:
if "model" in inputs:
next_input_name = "model"
elif "basic_pipe" in inputs:
# Handle pipe nodes like FromBasicPipe by following the pipeline
next_input_name = "basic_pipe"
else:
# Dead end - no model input to follow
return None
# Get connected node
input_val = inputs[next_input_name]
if isinstance(input_val, list) and len(input_val) > 0:
current_node_id = input_val[0]
else:
return None
depth += 1
return None
@staticmethod
def find_primary_checkpoint(metadata, downstream_id=None, primary_sampler_id=None):
"""
Find the primary checkpoint model in the workflow
Parameters:
- metadata: The workflow metadata
- downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
- primary_sampler_id: Optional ID of the primary sampler if already known
"""
if not metadata.get(MODELS):
return None
# Method 1: Topology-based tracing (More accurate for complex workflows)
# First, find the primary sampler if not provided
if not primary_sampler_id:
primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id)
if primary_sampler_id:
prompt = metadata.get("current_prompt")
if prompt:
# Trace back from the sampler to find the checkpoint
checkpoint_id = MetadataProcessor.trace_model_path(metadata, prompt, primary_sampler_id)
if checkpoint_id and checkpoint_id in metadata.get(MODELS, {}):
return metadata[MODELS][checkpoint_id].get("name")
# Method 2: Fallback to the first available checkpoint (Original behavior)
# In most simple workflows, there's only one checkpoint, so we can just take the first one
for node_id, model_info in metadata.get(MODELS, {}).items():
if model_info.get("type") == "checkpoint":
return model_info.get("name")
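A minimal toy workflow (made-up node ids) showing the upstream hop that trace_model_path performs; the real method also follows "guider" and "basic_pipe" inputs and stops at max_depth:

workflow = {
    "3": {"class_type": "CheckpointLoaderSimple", "inputs": {}},
    "5": {"class_type": "LoraLoader", "inputs": {"model": ["3", 0]}},
    "7": {"class_type": "KSampler", "inputs": {"model": ["5", 0]}},
}

node_id = "7"                                             # start at the primary sampler
while workflow[node_id]["class_type"] != "CheckpointLoaderSimple":
    node_id = workflow[node_id]["inputs"]["model"][0]     # follow the "model" link upstream

assert node_id == "3"                                     # the checkpoint loader feeding the sampler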
@@ -311,7 +430,8 @@ class MetadataProcessor:
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
# Directly get checkpoint from metadata instead of tracing
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
# Pass primary_sampler_id to avoid redundant calculation
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id, primary_sampler_id)
if checkpoint:
params["checkpoint"] = checkpoint
@@ -445,6 +565,7 @@ class MetadataProcessor:
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
params["steps"] = scheduler_params.get("steps")
params["scheduler"] = scheduler_params.get("scheduler")
params["denoise"] = scheduler_params.get("denoise")
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
if "sampler" in sampler_inputs:

View File

@@ -9,7 +9,7 @@ from ..metadata_collector import get_metadata
from PIL import Image, PngImagePlugin
import piexif
class SaveImage:
class SaveImageLM:
NAME = "Save Image (LoraManager)"
CATEGORY = "Lora Manager/utils"
DESCRIPTION = "Save images with embedded generation metadata in compatible format"

View File

@@ -103,7 +103,7 @@ def nunchaku_load_lora(model, lora_name, lora_strength):
# Save the transformer temporarily
model_wrapper.model = None
ret_model = copy.deepcopy(model) # copy everything except the model
ret_model = model.clone()
ret_model_wrapper = ret_model.model.diffusion_model
# Restore the model and set it for the copy

View File

@@ -37,7 +37,8 @@ class RecipeMetadataParser(ABC):
"""
pass
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
@staticmethod
async def populate_lora_from_civitai(lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Optional[Dict[str, Any]]:
"""
Populate a lora entry with information from Civitai API response
@@ -148,8 +149,9 @@ class RecipeMetadataParser(ABC):
logger.error(f"Error populating lora from Civitai info: {e}")
return lora_entry
async def populate_checkpoint_from_civitai(self, checkpoint: Dict[str, Any], civitai_info: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
async def populate_checkpoint_from_civitai(checkpoint: Dict[str, Any], civitai_info: Dict[str, Any]) -> Dict[str, Any]:
"""
Populate checkpoint information from Civitai API response
@@ -187,6 +189,7 @@ class RecipeMetadataParser(ABC):
checkpoint['downloadUrl'] = civitai_data.get('downloadUrl', '')
checkpoint['modelId'] = civitai_data.get('modelId', checkpoint.get('modelId', 0))
checkpoint['id'] = civitai_data.get('id', 0)
if 'files' in civitai_data:
model_file = next(

py/recipes/enrichment.py (new file, 216 lines)
View File

@@ -0,0 +1,216 @@
import logging
import json
import re
import os
from typing import Any, Dict, Optional
from .merger import GenParamsMerger
from .base import RecipeMetadataParser
from ..services.metadata_service import get_default_metadata_provider
logger = logging.getLogger(__name__)
class RecipeEnricher:
"""Service to enrich recipe metadata from multiple sources (Civitai, Embedded, User)."""
@staticmethod
async def enrich_recipe(
recipe: Dict[str, Any],
civitai_client: Any,
request_params: Optional[Dict[str, Any]] = None
) -> bool:
"""
Enrich a recipe dictionary in-place with metadata from Civitai and embedded params.
Args:
recipe: The recipe dictionary to enrich. Must have 'gen_params' initialized.
civitai_client: Authenticated Civitai client instance.
request_params: (Optional) Parameters from a user request (e.g. import).
Returns:
bool: True if the recipe was modified, False otherwise.
"""
updated = False
gen_params = recipe.get("gen_params", {})
# 1. Fetch Civitai Info if available
civitai_meta = None
model_version_id = None
source_url = recipe.get("source_url") or recipe.get("source_path", "")
# Check if it's a Civitai image URL
image_id_match = re.search(r'civitai\.com/images/(\d+)', str(source_url))
if image_id_match:
image_id = image_id_match.group(1)
try:
image_info = await civitai_client.get_image_info(image_id)
if image_info:
# Handle nested meta often found in Civitai API responses
raw_meta = image_info.get("meta")
if isinstance(raw_meta, dict):
if "meta" in raw_meta and isinstance(raw_meta["meta"], dict):
civitai_meta = raw_meta["meta"]
else:
civitai_meta = raw_meta
model_version_id = image_info.get("modelVersionId")
# If not at top level, check resources in meta
if not model_version_id and civitai_meta:
resources = civitai_meta.get("civitaiResources", [])
for res in resources:
if res.get("type") == "checkpoint":
model_version_id = res.get("modelVersionId")
break
except Exception as e:
logger.warning(f"Failed to fetch Civitai image info: {e}")
# 2. Merge Parameters
# Priority: request_params > civitai_meta > embedded (existing gen_params)
new_gen_params = GenParamsMerger.merge(
request_params=request_params,
civitai_meta=civitai_meta,
embedded_metadata=gen_params
)
if new_gen_params != gen_params:
recipe["gen_params"] = new_gen_params
updated = True
# 3. Checkpoint Enrichment
# If we have a checkpoint entry, or we can find one
# Use 'id' (from Civitai version) as a marker that it's been enriched
checkpoint_entry = recipe.get("checkpoint")
has_full_checkpoint = checkpoint_entry and checkpoint_entry.get("name") and checkpoint_entry.get("id")
if not has_full_checkpoint:
# Helper to look up values in priority order
def start_lookup(keys):
for source in [request_params, civitai_meta, gen_params]:
if source:
if isinstance(keys, list):
for k in keys:
if k in source: return source[k]
else:
if keys in source: return source[keys]
return None
target_version_id = model_version_id or start_lookup("modelVersionId")
# Also check existing checkpoint entry
if not target_version_id and checkpoint_entry:
target_version_id = checkpoint_entry.get("modelVersionId") or checkpoint_entry.get("id")
# Check for version ID in resources (which might be a string in gen_params)
if not target_version_id:
# Look in all sources for "Civitai resources"
resources_val = start_lookup(["Civitai resources", "civitai_resources", "resources"])
if resources_val:
target_version_id = RecipeEnricher._extract_version_id_from_resources({"Civitai resources": resources_val})
target_hash = start_lookup(["Model hash", "checkpoint_hash", "hashes"])
if not target_hash and checkpoint_entry:
target_hash = checkpoint_entry.get("hash") or checkpoint_entry.get("model_hash")
# Look for 'Model' which sometimes is the hash or name
model_val = start_lookup("Model")
# Look for Checkpoint name fallback
checkpoint_val = checkpoint_entry.get("name") if checkpoint_entry else None
if not checkpoint_val:
checkpoint_val = start_lookup(["Checkpoint", "checkpoint"])
checkpoint_updated = await RecipeEnricher._resolve_and_populate_checkpoint(
recipe, target_version_id, target_hash, model_val, checkpoint_val
)
if checkpoint_updated:
updated = True
else:
# Checkpoint exists, no need to sync to gen_params anymore.
pass
# base_model resolution moved to _resolve_and_populate_checkpoint to support strict formatting
return updated
@staticmethod
def _extract_version_id_from_resources(gen_params: Dict[str, Any]) -> Optional[Any]:
"""Try to find modelVersionId in Civitai resources parameter."""
civitai_resources_raw = gen_params.get("Civitai resources")
if not civitai_resources_raw:
return None
resources_list = None
if isinstance(civitai_resources_raw, str):
try:
resources_list = json.loads(civitai_resources_raw)
except Exception:
pass
elif isinstance(civitai_resources_raw, list):
resources_list = civitai_resources_raw
if isinstance(resources_list, list):
for res in resources_list:
if res.get("type") == "checkpoint":
return res.get("modelVersionId")
return None
@staticmethod
async def _resolve_and_populate_checkpoint(
recipe: Dict[str, Any],
target_version_id: Optional[Any],
target_hash: Optional[str],
model_val: Optional[str],
checkpoint_val: Optional[str]
) -> bool:
"""Find checkpoint metadata and populate it in the recipe."""
metadata_provider = await get_default_metadata_provider()
civitai_info = None
if target_version_id:
civitai_info = await metadata_provider.get_model_version_info(str(target_version_id))
elif target_hash:
civitai_info = await metadata_provider.get_model_by_hash(target_hash)
else:
# Look for 'Model' which sometimes is the hash or name
if model_val and len(model_val) == 10: # Likely a short hash
civitai_info = await metadata_provider.get_model_by_hash(model_val)
if civitai_info and not (isinstance(civitai_info, tuple) and civitai_info[1] == "Model not found"):
# If we already have a partial checkpoint, use it as base
existing_cp = recipe.get("checkpoint")
if existing_cp is None:
existing_cp = {}
checkpoint_data = await RecipeMetadataParser.populate_checkpoint_from_civitai(existing_cp, civitai_info)
# 1. First, resolve base_model using full data before we format it away
current_base_model = recipe.get("base_model")
resolved_base_model = checkpoint_data.get("baseModel")
if resolved_base_model:
# Update if empty OR if it matches our generic prefix but is less specific
is_generic = not current_base_model or current_base_model.lower() in ["flux", "sdxl", "sd15"]
if is_generic and resolved_base_model != current_base_model:
recipe["base_model"] = resolved_base_model
# 2. Format according to requirements: type, modelId, modelVersionId, modelName, modelVersionName
formatted_checkpoint = {
"type": "checkpoint",
"modelId": checkpoint_data.get("modelId"),
"modelVersionId": checkpoint_data.get("id") or checkpoint_data.get("modelVersionId"),
"modelName": checkpoint_data.get("name"), # In base.py, 'name' is populated from civitai_data['model']['name']
"modelVersionName": checkpoint_data.get("version") # In base.py, 'version' is populated from civitai_data['name']
}
# Remove None values
recipe["checkpoint"] = {k: v for k, v in formatted_checkpoint.items() if v is not None}
return True
else:
# Fallback to name extraction if we don't already have one
existing_cp = recipe.get("checkpoint")
if not existing_cp or not existing_cp.get("modelName"):
cp_name = checkpoint_val
if cp_name:
recipe["checkpoint"] = {
"type": "checkpoint",
"modelName": cp_name
}
return True
return False
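A hedged usage sketch of the new enricher; the import path and the stub client below are assumptions (a real caller would pass an authenticated Civitai client, and a Civitai image URL would trigger the internal metadata fetch):

import asyncio
from py.recipes.enrichment import RecipeEnricher  # import path assumed from the repo layout

class StubCivitaiClient:
    async def get_image_info(self, image_id):
        return None  # no network in this sketch

recipe = {
    "source_url": "https://example.com/image.png",        # not a Civitai URL, so nothing is fetched
    "gen_params": {"steps": 30, "sampler": "Euler"},       # embedded (lowest-priority) layer
    "checkpoint": {"name": "someCheckpoint", "id": 1234},  # already enriched, so no lookup happens
}

updated = asyncio.run(RecipeEnricher.enrich_recipe(
    recipe,
    civitai_client=StubCivitaiClient(),
    request_params={"steps": 25},                           # highest-priority layer
))
print(updated, recipe["gen_params"])                        # True {'steps': 25, 'sampler': 'Euler'}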

py/recipes/merger.py (new file, 98 lines)
View File

@@ -0,0 +1,98 @@
from typing import Any, Dict, Optional
import logging
logger = logging.getLogger(__name__)
class GenParamsMerger:
"""Utility to merge generation parameters from multiple sources with priority."""
BLACKLISTED_KEYS = {
"id", "url", "userId", "username", "createdAt", "updatedAt", "hash", "meta",
"draft", "extra", "width", "height", "process", "quantity", "workflow",
"baseModel", "resources", "disablePoi", "aspectRatio", "Created Date",
"experimental", "civitaiResources", "civitai_resources", "Civitai resources",
"modelVersionId", "modelId", "hashes", "Model", "Model hash", "checkpoint_hash",
"checkpoint", "checksum", "model_checksum"
}
NORMALIZATION_MAPPING = {
# Civitai specific
"cfgScale": "cfg_scale",
"clipSkip": "clip_skip",
"negativePrompt": "negative_prompt",
# Case variations
"Sampler": "sampler",
"Steps": "steps",
"Seed": "seed",
"Size": "size",
"Prompt": "prompt",
"Negative prompt": "negative_prompt",
"Cfg scale": "cfg_scale",
"Clip skip": "clip_skip",
"Denoising strength": "denoising_strength",
}
@staticmethod
def merge(
request_params: Optional[Dict[str, Any]] = None,
civitai_meta: Optional[Dict[str, Any]] = None,
embedded_metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Merge generation parameters from three sources.
Priority: request_params > civitai_meta > embedded_metadata
Args:
request_params: Params provided directly in the import request
civitai_meta: Params from Civitai Image API 'meta' field
embedded_metadata: Params extracted from image EXIF/embedded metadata
Returns:
Merged parameters dictionary
"""
result = {}
# 1. Start with embedded metadata (lowest priority)
if embedded_metadata:
# If it's a full recipe metadata, we use its gen_params
if "gen_params" in embedded_metadata and isinstance(embedded_metadata["gen_params"], dict):
GenParamsMerger._update_normalized(result, embedded_metadata["gen_params"])
else:
# Otherwise assume the dict itself contains gen_params
GenParamsMerger._update_normalized(result, embedded_metadata)
# 2. Layer Civitai meta (medium priority)
if civitai_meta:
GenParamsMerger._update_normalized(result, civitai_meta)
# 3. Layer request params (highest priority)
if request_params:
GenParamsMerger._update_normalized(result, request_params)
# Filter out blacklisted keys and also the original camelCase keys if they were normalized
final_result = {}
for k, v in result.items():
if k in GenParamsMerger.BLACKLISTED_KEYS:
continue
if k in GenParamsMerger.NORMALIZATION_MAPPING:
continue
final_result[k] = v
return final_result
@staticmethod
def _update_normalized(target: Dict[str, Any], source: Dict[str, Any]) -> None:
"""Update target dict with normalized keys from source."""
for k, v in source.items():
normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(k, k)
target[normalized_key] = v
# Keep the original key too; merge() drops any key found in NORMALIZATION_MAPPING
# or BLACKLISTED_KEYS in its final filtering pass, so the duplicate never reaches
# the merged result. Later sources overwriting earlier values is intentional:
# priority comes from the order in which merge() calls _update_normalized.
target[k] = v
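A usage sketch of the three-layer merge (the import path is assumed from the repo layout): request params override Civitai meta, which overrides embedded metadata, and camelCase / display-case keys are normalized and then dropped from the result.

from py.recipes.merger import GenParamsMerger  # import path assumed

merged = GenParamsMerger.merge(
    request_params={"steps": 25},
    civitai_meta={"cfgScale": 7, "seed": 123},
    embedded_metadata={"Steps": 30, "Sampler": "Euler"},
)
assert merged == {"steps": 25, "sampler": "Euler", "cfg_scale": 7, "seed": 123}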

View File

@@ -36,9 +36,6 @@ class ComfyMetadataParser(RecipeMetadataParser):
# Find all LoraLoader nodes
lora_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'LoraLoader'}
if not lora_nodes:
return {"error": "No LoRA information found in this ComfyUI workflow", "loras": []}
# Process each LoraLoader node
for node_id, node in lora_nodes.items():
if 'inputs' not in node or 'lora_name' not in node['inputs']:

View File

@@ -79,26 +79,8 @@ class BaseRecipeRoutes:
return
app.on_startup.append(self.attach_dependencies)
app.on_startup.append(self.prewarm_cache)
self._startup_hooks_registered = True
async def prewarm_cache(self, app: web.Application | None = None) -> None:
"""Pre-load recipe and LoRA caches on startup."""
try:
await self.attach_dependencies(app)
if self.lora_scanner is not None:
await self.lora_scanner.get_cached_data()
hash_index = getattr(self.lora_scanner, "_hash_index", None)
if hash_index is not None and hasattr(hash_index, "_hash_to_path"):
_ = len(hash_index._hash_to_path)
if self.recipe_scanner is not None:
await self.recipe_scanner.get_cached_data(force_refresh=True)
except Exception as exc:
logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True)
def to_route_mapping(self) -> Mapping[str, Callable]:
"""Return a mapping of handler name to coroutine for registrar binding."""

View File

@@ -5,6 +5,7 @@ import asyncio
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
@@ -61,6 +62,37 @@ class ModelPageView:
self._settings = settings_service
self._server_i18n = server_i18n
self._logger = logger
self._app_version = self._get_app_version()
def _get_app_version(self) -> str:
version = "1.0.0"
short_hash = "stable"
try:
import toml
current_file = os.path.abspath(__file__)
# Navigate up from py/routes/handlers/model_handlers.py to project root
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
pyproject_path = os.path.join(root_dir, 'pyproject.toml')
if os.path.exists(pyproject_path):
with open(pyproject_path, 'r', encoding='utf-8') as f:
data = toml.load(f)
version = data.get('project', {}).get('version', '1.0.0').replace('v', '')
# Try to get git info for granular cache busting
git_dir = os.path.join(root_dir, '.git')
if os.path.exists(git_dir):
try:
import git
repo = git.Repo(root_dir)
short_hash = repo.head.commit.hexsha[:7]
except Exception:
# Fallback if git is not available or not a repo
pass
except Exception as e:
self._logger.debug(f"Failed to read version info for cache busting: {e}")
return f"{version}-{short_hash}"
async def handle(self, request: web.Request) -> web.Response:
try:
@@ -96,6 +128,7 @@ class ModelPageView:
"request": request,
"folders": [],
"t": self._server_i18n.get_translation,
"version": self._app_version,
}
if not is_initializing:
@@ -128,9 +161,12 @@ class ModelListingHandler:
self._logger = logger
async def get_models(self, request: web.Request) -> web.Response:
start_time = time.perf_counter()
try:
params = self._parse_common_params(request)
result = await self._service.get_paginated_data(**params)
format_start = time.perf_counter()
formatted_result = {
"items": [await self._service.format_response(item) for item in result["items"]],
"total": result["total"],
@@ -138,6 +174,13 @@ class ModelListingHandler:
"page_size": result["page_size"],
"total_pages": result["total_pages"],
}
format_duration = time.perf_counter() - format_start
duration = time.perf_counter() - start_time
self._logger.info(
"Request for %s/list took %.3fs (formatting: %.3fs)",
self._service.model_type, duration, format_duration
)
return web.json_response(formatted_result)
except Exception as exc:
self._logger.error("Error retrieving %ss: %s", self._service.model_type, exc, exc_info=True)

View File

@@ -5,6 +5,7 @@ import json
import logging
import os
import re
import asyncio
import tempfile
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional
@@ -23,6 +24,11 @@ from ...services.recipes import (
RecipeValidationError,
)
from ...services.metadata_service import get_default_metadata_provider
from ...utils.civitai_utils import rewrite_preview_url
from ...utils.exif_utils import ExifUtils
from ...recipes.merger import GenParamsMerger
from ...recipes.enrichment import RecipeEnricher
from ...services.websocket_manager import ws_manager as default_ws_manager
Logger = logging.Logger
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
@@ -55,16 +61,25 @@ class RecipeHandlerSet:
"delete_recipe": self.management.delete_recipe,
"get_top_tags": self.query.get_top_tags,
"get_base_models": self.query.get_base_models,
"get_roots": self.query.get_roots,
"get_folders": self.query.get_folders,
"get_folder_tree": self.query.get_folder_tree,
"get_unified_folder_tree": self.query.get_unified_folder_tree,
"share_recipe": self.sharing.share_recipe,
"download_shared_recipe": self.sharing.download_shared_recipe,
"get_recipe_syntax": self.query.get_recipe_syntax,
"update_recipe": self.management.update_recipe,
"reconnect_lora": self.management.reconnect_lora,
"find_duplicates": self.query.find_duplicates,
"move_recipes_bulk": self.management.move_recipes_bulk,
"bulk_delete": self.management.bulk_delete,
"save_recipe_from_widget": self.management.save_recipe_from_widget,
"get_recipes_for_lora": self.query.get_recipes_for_lora,
"scan_recipes": self.query.scan_recipes,
"move_recipe": self.management.move_recipe,
"repair_recipes": self.management.repair_recipes,
"repair_recipe": self.management.repair_recipe,
"get_repair_progress": self.management.get_repair_progress,
}
@@ -148,12 +163,15 @@ class RecipeListingHandler:
page_size = int(request.query.get("page_size", "20"))
sort_by = request.query.get("sort_by", "date")
search = request.query.get("search")
folder = request.query.get("folder")
recursive = request.query.get("recursive", "true").lower() == "true"
search_options = {
"title": request.query.get("search_title", "true").lower() == "true",
"tags": request.query.get("search_tags", "true").lower() == "true",
"lora_name": request.query.get("search_lora_name", "true").lower() == "true",
"lora_model": request.query.get("search_lora_model", "true").lower() == "true",
"prompt": request.query.get("search_prompt", "true").lower() == "true",
}
filters: Dict[str, Any] = {}
@@ -161,6 +179,9 @@ class RecipeListingHandler:
if base_models:
filters["base_model"] = base_models.split(",")
if request.query.get("favorite", "false").lower() == "true":
filters["favorite"] = True
tag_filters: Dict[str, str] = {}
legacy_tags = request.query.get("tags")
if legacy_tags:
@@ -192,6 +213,8 @@ class RecipeListingHandler:
filters=filters,
search_options=search_options,
lora_hash=lora_hash,
folder=folder,
recursive=recursive,
)
for item in result.get("items", []):
@@ -298,6 +321,58 @@ class RecipeQueryHandler:
self._logger.error("Error retrieving base models: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_roots(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
roots = [recipe_scanner.recipes_dir] if recipe_scanner.recipes_dir else []
return web.json_response({"success": True, "roots": roots})
except Exception as exc:
self._logger.error("Error retrieving recipe roots: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_folders(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
folders = await recipe_scanner.get_folders()
return web.json_response({"success": True, "folders": folders})
except Exception as exc:
self._logger.error("Error retrieving recipe folders: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_folder_tree(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
folder_tree = await recipe_scanner.get_folder_tree()
return web.json_response({"success": True, "tree": folder_tree})
except Exception as exc:
self._logger.error("Error retrieving recipe folder tree: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_unified_folder_tree(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
folder_tree = await recipe_scanner.get_folder_tree()
return web.json_response({"success": True, "tree": folder_tree})
except Exception as exc:
self._logger.error("Error retrieving unified recipe folder tree: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
@@ -410,6 +485,7 @@ class RecipeManagementHandler:
analysis_service: RecipeAnalysisService,
downloader_factory,
civitai_client_getter: CivitaiClientGetter,
ws_manager=default_ws_manager,
) -> None:
self._ensure_dependencies_ready = ensure_dependencies_ready
self._recipe_scanner_getter = recipe_scanner_getter
@@ -418,6 +494,7 @@ class RecipeManagementHandler:
self._analysis_service = analysis_service
self._downloader_factory = downloader_factory
self._civitai_client_getter = civitai_client_getter
self._ws_manager = ws_manager
async def save_recipe(self, request: web.Request) -> web.Response:
try:
@@ -436,6 +513,7 @@ class RecipeManagementHandler:
name=payload["name"],
tags=payload["tags"],
metadata=payload["metadata"],
extension=payload.get("extension"),
)
return web.json_response(result.payload, status=result.status)
except RecipeValidationError as exc:
@@ -444,17 +522,84 @@ class RecipeManagementHandler:
self._logger.error("Error saving recipe: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
async def repair_recipes(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503)
# Check if already running
if self._ws_manager.get_recipe_repair_progress():
return web.json_response({"success": False, "error": "Recipe repair already in progress"}, status=409)
async def progress_callback(data):
await self._ws_manager.broadcast_recipe_repair_progress(data)
# Run in background to avoid timeout
async def run_repair():
try:
await recipe_scanner.repair_all_recipes(
progress_callback=progress_callback
)
except Exception as e:
self._logger.error(f"Error in recipe repair task: {e}", exc_info=True)
await self._ws_manager.broadcast_recipe_repair_progress({
"status": "error",
"error": str(e)
})
finally:
# Keep the final status for a while so the UI can see it
await asyncio.sleep(5)
self._ws_manager.cleanup_recipe_repair_progress()
asyncio.create_task(run_repair())
return web.json_response({"success": True, "message": "Recipe repair started"})
except Exception as exc:
self._logger.error("Error starting recipe repair: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def repair_recipe(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503)
recipe_id = request.match_info["recipe_id"]
result = await recipe_scanner.repair_recipe_by_id(recipe_id)
return web.json_response(result)
except RecipeNotFoundError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=404)
except Exception as exc:
self._logger.error("Error repairing single recipe: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_repair_progress(self, request: web.Request) -> web.Response:
try:
progress = self._ws_manager.get_recipe_repair_progress()
if progress:
return web.json_response({"success": True, "progress": progress})
return web.json_response({"success": False, "message": "No repair in progress"}, status=404)
except Exception as exc:
self._logger.error("Error getting repair progress: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
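A standalone sketch of the fire-and-forget pattern repair_recipes uses above: start the long job, return the response immediately, and keep the final progress visible for a grace period so a late poll can still observe it (timings are illustrative):

import asyncio

progress: dict = {}

async def run_job():
    progress["status"] = "running"
    await asyncio.sleep(0.1)        # stand-in for repair_all_recipes()
    progress["status"] = "done"
    await asyncio.sleep(0.2)        # grace period, like the 5 s sleep above
    progress.clear()                # cleanup_recipe_repair_progress()

async def main():
    job = asyncio.create_task(run_job())   # fire and forget: the HTTP handler would return here
    print({"success": True, "message": "repair started"})
    await asyncio.sleep(0.15)
    print(progress)                 # a poll during the grace period still sees {"status": "done"}
    await job                       # only so this sketch exits cleanly

asyncio.run(main())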
async def import_remote_recipe(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
# 1. Parse Parameters
params = request.rel_url.query
image_url = params.get("image_url")
name = params.get("name")
resources_raw = params.get("resources")
if not image_url:
raise RecipeValidationError("Missing required field: image_url")
if not name:
@@ -463,27 +608,93 @@ class RecipeManagementHandler:
raise RecipeValidationError("Missing required field: resources")
checkpoint_entry, lora_entries = self._parse_resources_payload(resources_raw)
gen_params = self._parse_gen_params(params.get("gen_params"))
gen_params_request = self._parse_gen_params(params.get("gen_params"))
# 2. Initial Metadata Construction
metadata: Dict[str, Any] = {
"base_model": params.get("base_model", "") or "",
"loras": lora_entries,
"gen_params": gen_params_request or {},
"source_url": image_url
}
source_path = params.get("source_path")
if source_path:
metadata["source_path"] = source_path
if gen_params is not None:
metadata["gen_params"] = gen_params
# Checkpoint handling
if checkpoint_entry:
metadata["checkpoint"] = checkpoint_entry
gen_params_ref = metadata.setdefault("gen_params", {})
if "checkpoint" not in gen_params_ref:
gen_params_ref["checkpoint"] = checkpoint_entry
base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry)
if base_model_from_metadata:
metadata["base_model"] = base_model_from_metadata
# The enricher reads metadata['checkpoint'] directly, so the gen_params mirror above
# is only for downstream consumers that expect the checkpoint there.
# Try to resolve base model from checkpoint if not explicitly provided
if not metadata["base_model"]:
base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry)
if base_model_from_metadata:
metadata["base_model"] = base_model_from_metadata
tags = self._parse_tags(params.get("tags"))
image_bytes = await self._download_image_bytes(image_url)
# 3. Download Image
image_bytes, extension, civitai_meta_from_download = await self._download_remote_media(image_url)
# 4. Extract Embedded Metadata
# RecipeEnricher.enrich_recipe expects the recipe's gen_params to hold the embedded
# (lowest-priority) layer; it fetches Civitai meta itself from the source URL and
# receives the request params separately, then merges request > civitai > embedded.
# So extract the embedded metadata here and seed recipe['gen_params'] with it.
# civitai_meta_from_download is kept as a fallback for URLs without a Civitai image ID.
embedded_gen_params = {}
try:
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_img:
temp_img.write(image_bytes)
temp_img_path = temp_img.name
try:
raw_embedded = ExifUtils.extract_image_metadata(temp_img_path)
if raw_embedded:
parser = self._analysis_service._recipe_parser_factory.create_parser(raw_embedded)
if parser:
parsed_embedded = await parser.parse_metadata(raw_embedded, recipe_scanner=recipe_scanner)
if parsed_embedded and "gen_params" in parsed_embedded:
embedded_gen_params = parsed_embedded["gen_params"]
else:
embedded_gen_params = {"raw_metadata": raw_embedded}
finally:
if os.path.exists(temp_img_path):
os.unlink(temp_img_path)
except Exception as exc:
self._logger.warning("Failed to extract embedded metadata during import: %s", exc)
# Pre-populate gen_params with embedded data so Enricher treats it as the "base" layer
if embedded_gen_params:
# Request params are passed to the enricher separately, so they still override this base layer.
metadata["gen_params"] = embedded_gen_params
# 5. Enrich with unified logic
# This will fetch Civitai info (if URL matches) and merge: request > civitai > embedded
civitai_client = self._civitai_client_getter()
await RecipeEnricher.enrich_recipe(
recipe=metadata,
civitai_client=civitai_client,
request_params=gen_params_request # Pass explicit request params here to override
)
# civitai_meta_from_download is not merged manually here: import_remote_recipe is normally
# given Civitai URLs, so relying on the enricher's internal fetch keeps this path consistent
# with the repair flow.
result = await self._persistence_service.save_recipe(
recipe_scanner=recipe_scanner,
@@ -492,6 +703,7 @@ class RecipeManagementHandler:
name=name,
tags=tags,
metadata=metadata,
extension=extension,
)
return web.json_response(result.payload, status=result.status)
except RecipeValidationError as exc:
@@ -541,6 +753,64 @@ class RecipeManagementHandler:
self._logger.error("Error updating recipe: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
async def move_recipe(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
data = await request.json()
recipe_id = data.get("recipe_id")
target_path = data.get("target_path")
if not recipe_id or not target_path:
return web.json_response(
{"success": False, "error": "recipe_id and target_path are required"}, status=400
)
result = await self._persistence_service.move_recipe(
recipe_scanner=recipe_scanner,
recipe_id=str(recipe_id),
target_path=str(target_path),
)
return web.json_response(result.payload, status=result.status)
except RecipeValidationError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except RecipeNotFoundError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=404)
except Exception as exc:
self._logger.error("Error moving recipe: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def move_recipes_bulk(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
data = await request.json()
recipe_ids = data.get("recipe_ids") or []
target_path = data.get("target_path")
if not recipe_ids or not target_path:
return web.json_response(
{"success": False, "error": "recipe_ids and target_path are required"}, status=400
)
result = await self._persistence_service.move_recipes_bulk(
recipe_scanner=recipe_scanner,
recipe_ids=recipe_ids,
target_path=str(target_path),
)
return web.json_response(result.payload, status=result.status)
except RecipeValidationError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except RecipeNotFoundError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=404)
except Exception as exc:
self._logger.error("Error moving recipes in bulk: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def reconnect_lora(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
@@ -622,6 +892,7 @@ class RecipeManagementHandler:
name: Optional[str] = None
tags: list[str] = []
metadata: Optional[Dict[str, Any]] = None
extension: Optional[str] = None
while True:
field = await reader.next()
@@ -652,6 +923,8 @@ class RecipeManagementHandler:
metadata = json.loads(metadata_text)
except Exception:
metadata = {}
elif field.name == "extension":
extension = await field.text()
return {
"image_bytes": image_bytes,
@@ -659,6 +932,7 @@ class RecipeManagementHandler:
"name": name,
"tags": tags,
"metadata": metadata,
"extension": extension,
}
def _parse_tags(self, tag_text: Optional[str]) -> list[str]:
@@ -729,7 +1003,7 @@ class RecipeManagementHandler:
"exclude": False,
}
async def _download_image_bytes(self, image_url: str) -> bytes:
async def _download_remote_media(self, image_url: str) -> tuple[bytes, str, Optional[Dict[str, Any]]]:
civitai_client = self._civitai_client_getter()
downloader = await self._downloader_factory()
temp_path = None
@@ -744,15 +1018,31 @@ class RecipeManagementHandler:
image_info = await civitai_client.get_image_info(civitai_match.group(1))
if not image_info:
raise RecipeDownloadError("Failed to fetch image information from Civitai")
download_url = image_info.get("url")
if not download_url:
media_url = image_info.get("url")
if not media_url:
raise RecipeDownloadError("No image URL found in Civitai response")
# Use optimized preview URLs if possible
media_type = image_info.get("type")
rewritten_url, _ = rewrite_preview_url(media_url, media_type=media_type)
if rewritten_url:
download_url = rewritten_url
else:
download_url = media_url
success, result = await downloader.download_file(download_url, temp_path, use_auth=False)
if not success:
raise RecipeDownloadError(f"Failed to download image: {result}")
# Extract extension from URL
url_path = download_url.split('?')[0].split('#')[0]
extension = os.path.splitext(url_path)[1].lower()
if not extension:
extension = ".webp" # Default to webp if unknown
with open(temp_path, "rb") as file_obj:
return file_obj.read()
return file_obj.read(), extension, image_info.get("meta") if civitai_match and image_info else None
except RecipeDownloadError:
raise
except RecipeValidationError:
@@ -766,6 +1056,7 @@ class RecipeManagementHandler:
except FileNotFoundError:
pass
def _safe_int(self, value: Any) -> int:
try:
return int(value)

View File

@@ -27,16 +27,25 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"),
RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"),
RouteDefinition("GET", "/api/lm/recipes/base-models", "get_base_models"),
RouteDefinition("GET", "/api/lm/recipes/roots", "get_roots"),
RouteDefinition("GET", "/api/lm/recipes/folders", "get_folders"),
RouteDefinition("GET", "/api/lm/recipes/folder-tree", "get_folder_tree"),
RouteDefinition("GET", "/api/lm/recipes/unified-folder-tree", "get_unified_folder_tree"),
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"),
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"),
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"),
RouteDefinition("GET", "/api/lm/recipes/syntax", "get_recipe_syntax"),
RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"),
RouteDefinition("POST", "/api/lm/recipe/move", "move_recipe"),
RouteDefinition("POST", "/api/lm/recipes/move-bulk", "move_recipes_bulk"),
RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"),
RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"),
RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"),
RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"),
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
RouteDefinition("POST", "/api/lm/recipes/repair", "repair_recipes"),
RouteDefinition("POST", "/api/lm/recipe/{recipe_id}/repair", "repair_recipe"),
RouteDefinition("GET", "/api/lm/recipes/repair-progress", "get_repair_progress"),
)
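A hedged example of calling two of the new routes from Python, assuming a locally running server on ComfyUI's default port 8188:

import asyncio
import aiohttp

async def main():
    async with aiohttp.ClientSession() as session:
        async with session.get("http://127.0.0.1:8188/api/lm/recipes/folders") as resp:
            print(await resp.json())   # {"success": True, "folders": [...]}
        async with session.post("http://127.0.0.1:8188/api/lm/recipes/repair") as resp:
            print(resp.status)         # 200, or 409 if a repair is already running

asyncio.run(main())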

View File

@@ -3,6 +3,7 @@ import asyncio
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
import logging
import os
import time
from ..utils.constants import VALID_LORA_TYPES
from ..utils.models import BaseModelMetadata
@@ -80,13 +81,20 @@ class BaseModelService(ABC):
**kwargs,
) -> Dict:
"""Get paginated and filtered model data"""
overall_start = time.perf_counter()
sort_params = self.cache_repository.parse_sort(sort_by)
if sort_params.key == 'usage':
sorted_data = await self._fetch_with_usage_sort(sort_params)
else:
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
t0 = time.perf_counter()
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
fetch_duration = time.perf_counter() - t0
initial_count = len(sorted_data)
t1 = time.perf_counter()
if hash_filters:
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
else:
@@ -116,17 +124,25 @@ class BaseModelService(ABC):
if allow_selling_generated_content is not None:
filtered_data = await self._apply_allow_selling_filter(filtered_data, allow_selling_generated_content)
filter_duration = time.perf_counter() - t1
post_filter_count = len(filtered_data)
annotated_for_filter: Optional[List[Dict]] = None
t2 = time.perf_counter()
if update_available_only:
annotated_for_filter = await self._annotate_update_flags(filtered_data)
filtered_data = [
item for item in annotated_for_filter
if item.get('update_available')
]
update_filter_duration = time.perf_counter() - t2
final_count = len(filtered_data)
t3 = time.perf_counter()
paginated = self._paginate(filtered_data, page, page_size)
pagination_duration = time.perf_counter() - t3
t4 = time.perf_counter()
if update_available_only:
# Items already include update flags thanks to the pre-filter annotation.
paginated['items'] = list(paginated['items'])
@@ -134,6 +150,16 @@ class BaseModelService(ABC):
paginated['items'] = await self._annotate_update_flags(
paginated['items'],
)
annotate_duration = time.perf_counter() - t4
overall_duration = time.perf_counter() - overall_start
logger.info(
"%s.get_paginated_data took %.3fs (fetch: %.3fs, filter: %.3fs, update_filter: %.3fs, pagination: %.3fs, annotate: %.3fs). "
"Counts: initial=%d, post_filter=%d, final=%d",
self.__class__.__name__, overall_duration, fetch_duration, filter_duration,
update_filter_duration, pagination_duration, annotate_duration,
initial_count, post_filter_count, final_count
)
return paginated
async def _fetch_with_usage_sort(self, sort_params):

View File

@@ -1,4 +1,8 @@
import asyncio
import time
import logging
logger = logging.getLogger(__name__)
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from operator import itemgetter
@@ -215,24 +219,25 @@ class ModelCache:
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
"""Sort data by sort_key and order"""
start_time = time.perf_counter()
reverse = (order == 'desc')
if sort_key == 'name':
# Natural sort by configured display name, case-insensitive
return natsorted(
result = natsorted(
data,
key=lambda x: self._get_display_name(x).lower(),
reverse=reverse
)
elif sort_key == 'date':
# Sort by modified timestamp
return sorted(
result = sorted(
data,
key=itemgetter('modified'),
reverse=reverse
)
elif sort_key == 'size':
# Sort by file size
return sorted(
result = sorted(
data,
key=itemgetter('size'),
reverse=reverse
@@ -249,16 +254,28 @@ class ModelCache:
)
else:
# Fallback: no sort
return list(data)
result = list(data)
duration = time.perf_counter() - start_time
if duration > 0.05:
logger.info("ModelCache._sort_data(%s, %s) for %d items took %.3fs", sort_key, order, len(data), duration)
return result
async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]:
"""Get sorted data by sort_key and order, using cache if possible"""
async with self._lock:
if (sort_key, order) == self._last_sort:
return self._last_sorted_data
start_time = time.perf_counter()
sorted_data = self._sort_data(self.raw_data, sort_key, order)
self._last_sort = (sort_key, order)
self._last_sorted_data = sorted_data
duration = time.perf_counter() - start_time
if duration > 0.1:
logger.debug("ModelCache.get_sorted_data(%s, %s) took %.3fs", sort_key, order, duration)
return sorted_data
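A minimal usage sketch of the memoization above (the `model_cache` instance is illustrative):

# First call sorts raw_data and memoizes the (sort_key, order) pair.
items = await model_cache.get_sorted_data('name', 'asc')
# Repeating the same key/order returns the memoized list without re-sorting.
items_again = await model_cache.get_sorted_data('name', 'asc')
# Changing either parameter triggers a fresh sort and replaces the memo.
by_date = await model_cache.get_sorted_data('date', 'desc')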
async def update_name_display_mode(self, display_mode: str) -> None:

View File

@@ -5,6 +5,10 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match as default_fuzzy_match
import time
import logging
logger = logging.getLogger(__name__)
DEFAULT_CIVITAI_MODEL_TYPE = "LORA"
@@ -115,22 +119,33 @@ class ModelFilterSet:
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
"""Return items that satisfy the provided criteria."""
overall_start = time.perf_counter()
items = list(data)
initial_count = len(items)
if self._settings.get("show_only_sfw", False):
t0 = time.perf_counter()
threshold = self._nsfw_levels.get("R", 0)
items = [
item for item in items
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold
]
sfw_duration = time.perf_counter() - t0
else:
sfw_duration = 0
favorites_duration = 0
if criteria.favorites_only:
t0 = time.perf_counter()
items = [item for item in items if item.get("favorite", False)]
favorites_duration = time.perf_counter() - t0
folder_duration = 0
folder = criteria.folder
options = criteria.search_options or {}
recursive = bool(options.get("recursive", True))
if folder is not None:
t0 = time.perf_counter()
if recursive:
if folder:
folder_with_sep = f"{folder}/"
@@ -140,51 +155,82 @@ class ModelFilterSet:
]
else:
items = [item for item in items if item.get("folder") == folder]
folder_duration = time.perf_counter() - t0
base_models_duration = 0
base_models = criteria.base_models or []
if base_models:
t0 = time.perf_counter()
base_model_set = set(base_models)
items = [item for item in items if item.get("base_model") in base_model_set]
base_models_duration = time.perf_counter() - t0
tags_duration = 0
tag_filters = criteria.tags or {}
include_tags = set()
exclude_tags = set()
if isinstance(tag_filters, dict):
for tag, state in tag_filters.items():
if not tag:
continue
if state == "exclude":
exclude_tags.add(tag)
else:
include_tags.add(tag)
else:
include_tags = {tag for tag in tag_filters if tag}
if tag_filters:
t0 = time.perf_counter()
include_tags = set()
exclude_tags = set()
if isinstance(tag_filters, dict):
for tag, state in tag_filters.items():
if not tag:
continue
if state == "exclude":
exclude_tags.add(tag)
else:
include_tags.add(tag)
else:
include_tags = {tag for tag in tag_filters if tag}
if include_tags:
items = [
item for item in items
if any(tag in include_tags for tag in (item.get("tags", []) or []))
]
if include_tags:
def matches_include(item_tags):
if not item_tags and "__no_tags__" in include_tags:
return True
return any(tag in include_tags for tag in (item_tags or []))
if exclude_tags:
items = [
item for item in items
if not any(tag in exclude_tags for tag in (item.get("tags", []) or []))
]
items = [
item for item in items
if matches_include(item.get("tags"))
]
if exclude_tags:
def matches_exclude(item_tags):
if not item_tags and "__no_tags__" in exclude_tags:
return True
return any(tag in exclude_tags for tag in (item_tags or []))
items = [
item for item in items
if not matches_exclude(item.get("tags"))
]
tags_duration = time.perf_counter() - t0
model_types_duration = 0
model_types = criteria.model_types or []
normalized_model_types = {
model_type for model_type in (
normalize_civitai_model_type(value) for value in model_types
)
if model_type
}
if normalized_model_types:
items = [
item for item in items
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
]
if model_types:
t0 = time.perf_counter()
normalized_model_types = {
model_type for model_type in (
normalize_civitai_model_type(value) for value in model_types
)
if model_type
}
if normalized_model_types:
items = [
item for item in items
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
]
model_types_duration = time.perf_counter() - t0
duration = time.perf_counter() - overall_start
if duration > 0.1: # Only log if it's potentially slow
logger.info(
"ModelFilterSet.apply took %.3fs (sfw: %.3fs, fav: %.3fs, folder: %.3fs, base: %.3fs, tags: %.3fs, types: %.3fs). "
"Count: %d -> %d",
duration, sfw_duration, favorites_duration, folder_duration,
base_models_duration, tags_duration, model_types_duration,
initial_count, len(items)
)
return items
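A small illustration of the `__no_tags__` sentinel handling above; constructing `FilterCriteria` with keyword arguments is an assumption, while the tag-state mapping follows this hunk:

# Including the sentinel keeps untagged items alongside normal tag matches.
criteria = FilterCriteria(tags={"style": "include", "__no_tags__": "include"})
items = filter_set.apply(all_items, criteria)
# Excluding it drops untagged items as well as anything carrying an excluded tag.
criteria = FilterCriteria(tags={"nsfw": "exclude", "__no_tags__": "exclude"})
items = filter_set.apply(all_items, criteria)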

View File

@@ -7,12 +7,18 @@ from natsort import natsorted
@dataclass
class RecipeCache:
"""Cache structure for Recipe data"""
raw_data: List[Dict]
sorted_by_name: List[Dict]
sorted_by_date: List[Dict]
folders: List[str] | None = None
folder_tree: Dict | None = None
def __post_init__(self):
self._lock = asyncio.Lock()
# Normalize optional metadata containers
self.folders = self.folders or []
self.folder_tree = self.folder_tree or {}
async def resort(self, name_only: bool = False):
"""Resort all cached data views"""

View File

@@ -1,7 +1,9 @@
import os
import logging
from __future__ import annotations
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from ..config import config
@@ -14,6 +16,9 @@ from .recipes.errors import RecipeNotFoundError
from ..utils.utils import calculate_recipe_fingerprint, fuzzy_match
from natsort import natsorted
import sys
import re
from ..recipes.merger import GenParamsMerger
from ..recipes.enrichment import RecipeEnricher
logger = logging.getLogger(__name__)
@@ -52,6 +57,8 @@ class RecipeScanner:
cls._instance._civitai_client = None # Will be lazily initialized
return cls._instance
REPAIR_VERSION = 3
def __init__(
self,
lora_scanner: Optional[LoraScanner] = None,
@@ -64,6 +71,7 @@ class RecipeScanner:
self._initialization_task: Optional[asyncio.Task] = None
self._is_initializing = False
self._mutation_lock = asyncio.Lock()
self._post_scan_task: Optional[asyncio.Task] = None
self._resort_tasks: Set[asyncio.Task] = set()
if lora_scanner:
self._lora_scanner = lora_scanner
@@ -84,6 +92,10 @@ class RecipeScanner:
task.cancel()
self._resort_tasks.clear()
if self._post_scan_task and not self._post_scan_task.done():
self._post_scan_task.cancel()
self._post_scan_task = None
self._cache = None
self._initialization_task = None
self._is_initializing = False
@@ -102,19 +114,223 @@ class RecipeScanner:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def repair_all_recipes(
self,
progress_callback: Optional[Callable[[Dict], Any]] = None
) -> Dict[str, Any]:
"""Repair all recipes by enrichment with Civitai and embedded metadata.
Args:
persistence_service: Service for saving updated recipes
progress_callback: Optional callback for progress updates
Returns:
Dict summary of repair results
"""
async with self._mutation_lock:
cache = await self.get_cached_data()
all_recipes = list(cache.raw_data)
total = len(all_recipes)
repaired_count = 0
skipped_count = 0
errors_count = 0
civitai_client = await self._get_civitai_client()
for i, recipe in enumerate(all_recipes):
try:
# Report progress
if progress_callback:
await progress_callback({
"status": "processing",
"current": i + 1,
"total": total,
"recipe_name": recipe.get("name", "Unknown")
})
if await self._repair_single_recipe(recipe, civitai_client):
repaired_count += 1
else:
skipped_count += 1
except Exception as e:
logger.error(f"Error repairing recipe {recipe.get('file_path')}: {e}")
errors_count += 1
# Final progress update
if progress_callback:
await progress_callback({
"status": "completed",
"repaired": repaired_count,
"skipped": skipped_count,
"errors": errors_count,
"total": total
})
return {
"success": True,
"repaired": repaired_count,
"skipped": skipped_count,
"errors": errors_count,
"total": total
}
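A minimal sketch of a progress callback wired to the WebSocket broadcast added later in this change; the `ws_manager` handle and the extra `type` field are assumptions:

async def on_repair_progress(update: Dict[str, Any]) -> None:
    # Payload keys follow the progress dicts above (status, current, total, ...).
    await ws_manager.broadcast_recipe_repair_progress({"type": "recipe_repair_progress", **update})

summary = await recipe_scanner.repair_all_recipes(progress_callback=on_repair_progress)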
async def repair_recipe_by_id(self, recipe_id: str) -> Dict[str, Any]:
"""Repair a single recipe by its ID.
Args:
recipe_id: ID of the recipe to repair
Returns:
Dict summary of repair result
"""
async with self._mutation_lock:
# Get raw recipe from cache directly to avoid formatted fields
cache = await self.get_cached_data()
recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None)
if not recipe:
raise RecipeNotFoundError(f"Recipe {recipe_id} not found")
civitai_client = await self._get_civitai_client()
success = await self._repair_single_recipe(recipe, civitai_client)
# If successfully repaired, we should return the formatted version for the UI
return {
"success": True,
"repaired": 1 if success else 0,
"skipped": 0 if success else 1,
"recipe": await self.get_recipe_by_id(recipe_id) if success else recipe
}
async def _repair_single_recipe(self, recipe: Dict[str, Any], civitai_client: Any) -> bool:
"""Internal helper to repair a single recipe object.
Args:
recipe: The recipe dictionary to repair (modified in-place)
civitai_client: Authenticated Civitai client
Returns:
bool: True if recipe was repaired or updated, False if skipped
"""
# 1. Skip if already at latest repair version
if recipe.get("repair_version", 0) >= self.REPAIR_VERSION:
return False
# 2. Identification: Is repair needed?
has_checkpoint = "checkpoint" in recipe and recipe["checkpoint"] and recipe["checkpoint"].get("name")
gen_params = recipe.get("gen_params", {})
has_prompt = bool(gen_params.get("prompt"))
needs_repair = not has_checkpoint or not has_prompt
if not needs_repair:
# No repair needed, but step 1 guarantees the stored version is outdated,
# so stamp and persist the current repair version.
recipe["repair_version"] = self.REPAIR_VERSION
await self._save_recipe_persistently(recipe)
return True
# 3. Use Enricher to repair/enrich
try:
updated = await RecipeEnricher.enrich_recipe(recipe, civitai_client)
except Exception as e:
logger.error(f"Error enriching recipe {recipe.get('id')}: {e}")
updated = False
# 4. Stamp the repair version and persist. Even when enrichment made no changes,
# record the version so this recipe is not retried until the next version bump.
if updated or recipe.get("repair_version", 0) < self.REPAIR_VERSION:
recipe["repair_version"] = self.REPAIR_VERSION
await self._save_recipe_persistently(recipe)
return True
return False
async def _save_recipe_persistently(self, recipe: Dict[str, Any]) -> bool:
"""Helper to save a recipe to both JSON and EXIF metadata."""
recipe_id = recipe.get("id")
if not recipe_id:
return False
recipe_json_path = await self.get_recipe_json_path(recipe_id)
if not recipe_json_path:
return False
try:
# 1. Sanitize for storage (remove runtime convenience fields)
clean_recipe = self._sanitize_recipe_for_storage(recipe)
# 2. Overwrite the original dictionary in place so callers also see the sanitized version.
recipe.clear()
recipe.update(clean_recipe)
# 3. Save JSON
with open(recipe_json_path, 'w', encoding='utf-8') as f:
json.dump(recipe, f, indent=4, ensure_ascii=False)
# 4. Update EXIF if image exists
image_path = recipe.get('file_path')
if image_path and os.path.exists(image_path):
from ..utils.exif_utils import ExifUtils
ExifUtils.append_recipe_metadata(image_path, recipe)
return True
except Exception as e:
logger.error(f"Error persisting recipe {recipe_id}: {e}")
return False
def _sanitize_recipe_for_storage(self, recipe: Dict[str, Any]) -> Dict[str, Any]:
"""Create a clean copy of the recipe without runtime convenience fields."""
import copy
clean = copy.deepcopy(recipe)
# 0. Clean top-level runtime fields
for key in ("file_url", "created_date_formatted", "modified_formatted"):
clean.pop(key, None)
# 1. Clean LORAs
if "loras" in clean and isinstance(clean["loras"], list):
for lora in clean["loras"]:
# Fields to remove (runtime only)
for key in ("inLibrary", "preview_url", "localPath"):
lora.pop(key, None)
# Normalize weight to strength to match the persistence_service storage convention
if "weight" in lora and "strength" not in lora:
lora["strength"] = float(lora.pop("weight"))
# 2. Clean Checkpoint
if "checkpoint" in clean and isinstance(clean["checkpoint"], dict):
cp = clean["checkpoint"]
# Fields to remove (runtime only)
for key in ("inLibrary", "localPath", "preview_url", "thumbnailUrl", "size", "downloadUrl"):
cp.pop(key, None)
return clean
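For illustration, a before/after of `_sanitize_recipe_for_storage` on a minimal recipe (all values are made up):

raw = {
    "id": "abc",
    "file_url": "/api/previews/abc.webp",  # runtime-only field, dropped
    "loras": [{"file_name": "styleA", "weight": 0.8, "inLibrary": True}],
    "checkpoint": {"name": "base-model", "thumbnailUrl": "https://example.invalid/t.jpg"},
}
clean = scanner._sanitize_recipe_for_storage(raw)
# clean["loras"][0] -> {"file_name": "styleA", "strength": 0.8}
# clean["checkpoint"] -> {"name": "base-model"}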
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
await self._wait_for_lora_scanner()
# Set initial empty cache to avoid None reference errors
if self._cache is None:
self._cache = RecipeCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[]
sorted_by_date=[],
folders=[],
folder_tree={},
)
# Mark as initializing to prevent concurrent initializations
self._is_initializing = True
self._initialization_task = asyncio.current_task()
try:
# Start timer
@@ -126,11 +342,14 @@ class RecipeScanner:
None, # Use default thread pool
self._initialize_recipe_cache_sync # Run synchronous version in thread
)
if cache is not None:
self._cache = cache
# Calculate elapsed time and log it
elapsed_time = time.time() - start_time
recipe_count = len(cache.raw_data) if cache and hasattr(cache, 'raw_data') else 0
logger.info(f"Recipe cache initialized in {elapsed_time:.2f} seconds. Found {recipe_count} recipes")
self._schedule_post_scan_enrichment()
finally:
# Mark initialization as complete regardless of outcome
self._is_initializing = False
@@ -207,6 +426,7 @@ class RecipeScanner:
# Update cache with the collected data
self._cache.raw_data = recipes
self._update_folder_metadata(self._cache)
# Create a simplified resort function that doesn't use await
if hasattr(self._cache, "resort"):
@@ -237,12 +457,97 @@ class RecipeScanner:
# Clean up the event loop
loop.close()
async def _wait_for_lora_scanner(self) -> None:
"""Ensure the LoRA scanner has initialized before recipe enrichment."""
if not getattr(self, "_lora_scanner", None):
return
lora_scanner = self._lora_scanner
cache_ready = getattr(lora_scanner, "_cache", None) is not None
# If cache is already available, we can proceed
if cache_ready:
return
# Await an existing initialization task if present
task = getattr(lora_scanner, "_initialization_task", None)
if task and hasattr(task, "done") and not task.done():
try:
await task
except Exception: # pragma: no cover - defensive guard
pass
if getattr(lora_scanner, "_cache", None) is not None:
return
# Otherwise, request initialization and proceed once it completes
try:
await lora_scanner.initialize_in_background()
except Exception as exc: # pragma: no cover - defensive guard
logger.debug("Recipe Scanner: LoRA init request failed: %s", exc)
def _schedule_post_scan_enrichment(self) -> None:
"""Kick off a non-blocking enrichment pass to fill remote metadata."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
if self._post_scan_task and not self._post_scan_task.done():
return
async def _run_enrichment():
try:
await self._enrich_cache_metadata()
except asyncio.CancelledError:
raise
except Exception as exc: # pragma: no cover - defensive guard
logger.error("Recipe Scanner: error during post-scan enrichment: %s", exc, exc_info=True)
self._post_scan_task = loop.create_task(_run_enrichment(), name="recipe_cache_enrichment")
async def _enrich_cache_metadata(self) -> None:
"""Perform remote metadata enrichment after the initial scan."""
cache = self._cache
if cache is None or not getattr(cache, "raw_data", None):
return
for index, recipe in enumerate(list(cache.raw_data)):
try:
metadata_updated = await self._update_lora_information(recipe)
if metadata_updated:
recipe_id = recipe.get("id")
if recipe_id:
recipe_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
if os.path.exists(recipe_path):
try:
self._write_recipe_file(recipe_path, recipe)
except Exception as exc: # pragma: no cover - best-effort persistence
logger.debug("Recipe Scanner: could not persist recipe %s: %s", recipe_id, exc)
except asyncio.CancelledError:
raise
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Recipe Scanner: error enriching recipe %s: %s", recipe.get("id"), exc, exc_info=True)
if index % 10 == 0:
await asyncio.sleep(0)
try:
await cache.resort()
except Exception as exc: # pragma: no cover - defensive logging
logger.debug("Recipe Scanner: error resorting cache after enrichment: %s", exc)
def _schedule_resort(self, *, name_only: bool = False) -> None:
"""Schedule a background resort of the recipe cache."""
if not self._cache:
return
# Keep folder metadata up to date alongside sort order
self._update_folder_metadata()
async def _resort_wrapper() -> None:
try:
await self._cache.resort(name_only=name_only)
@@ -253,6 +558,75 @@ class RecipeScanner:
self._resort_tasks.add(task)
task.add_done_callback(lambda finished: self._resort_tasks.discard(finished))
def _calculate_folder(self, recipe_path: str) -> str:
"""Calculate a normalized folder path relative to ``recipes_dir``."""
recipes_dir = self.recipes_dir
if not recipes_dir:
return ""
try:
recipe_dir = os.path.dirname(os.path.normpath(recipe_path))
relative_dir = os.path.relpath(recipe_dir, recipes_dir)
if relative_dir in (".", ""):
return ""
return relative_dir.replace(os.path.sep, "/")
except Exception:
return ""
def _build_folder_tree(self, folders: list[str]) -> dict:
"""Build a nested folder tree structure from relative folder paths."""
tree: dict[str, dict] = {}
for folder in folders:
if not folder:
continue
parts = folder.split("/")
current_level = tree
for part in parts:
if part not in current_level:
current_level[part] = {}
current_level = current_level[part]
return tree
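An illustrative input/output pair for `_build_folder_tree` (folder names are made up):

folders = ["characters", "characters/anime", "styles"]
tree = scanner._build_folder_tree(folders)
# tree == {"characters": {"anime": {}}, "styles": {}}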
def _update_folder_metadata(self, cache: RecipeCache | None = None) -> None:
"""Ensure folder lists and tree metadata are synchronized with cache contents."""
cache = cache or self._cache
if cache is None:
return
folders: set[str] = set()
for item in cache.raw_data:
folder_value = item.get("folder", "")
if folder_value is None:
folder_value = ""
if folder_value == ".":
folder_value = ""
normalized = str(folder_value).replace("\\", "/")
item["folder"] = normalized
folders.add(normalized)
cache.folders = sorted(folders, key=lambda entry: entry.lower())
cache.folder_tree = self._build_folder_tree(cache.folders)
async def get_folders(self) -> list[str]:
"""Return a sorted list of recipe folders relative to the recipes root."""
cache = await self.get_cached_data()
self._update_folder_metadata(cache)
return cache.folders
async def get_folder_tree(self) -> dict:
"""Return a hierarchical tree of recipe folders for sidebar navigation."""
cache = await self.get_cached_data()
self._update_folder_metadata(cache)
return cache.folder_tree
@property
def recipes_dir(self) -> str:
"""Get path to recipes directory"""
@@ -269,11 +643,14 @@ class RecipeScanner:
"""Get cached recipe data, refresh if needed"""
# If cache is already initialized and no refresh is needed, return it immediately
if self._cache is not None and not force_refresh:
self._update_folder_metadata()
return self._cache
# If another initialization is already in progress, wait for it to complete
if self._is_initializing and not force_refresh:
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
return self._cache or RecipeCache(
raw_data=[], sorted_by_name=[], sorted_by_date=[], folders=[], folder_tree={}
)
# If force refresh is requested, initialize the cache directly
if force_refresh:
@@ -291,11 +668,14 @@ class RecipeScanner:
self._cache = RecipeCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[]
sorted_by_date=[],
folders=[],
folder_tree={},
)
# Resort cache
await self._cache.resort()
self._update_folder_metadata(self._cache)
return self._cache
@@ -305,7 +685,9 @@ class RecipeScanner:
self._cache = RecipeCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[]
sorted_by_date=[],
folders=[],
folder_tree={},
)
return self._cache
finally:
@@ -316,7 +698,9 @@ class RecipeScanner:
logger.error(f"Unexpected error in get_cached_data: {e}")
# Return the cache (may be empty or partially initialized)
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
return self._cache or RecipeCache(
raw_data=[], sorted_by_name=[], sorted_by_date=[], folders=[], folder_tree={}
)
async def refresh_cache(self, force: bool = False) -> RecipeCache:
"""Public helper to refresh or return the recipe cache."""
@@ -331,6 +715,7 @@ class RecipeScanner:
cache = await self.get_cached_data()
await cache.add_recipe(recipe_data, resort=False)
self._update_folder_metadata(cache)
self._schedule_resort()
async def remove_recipe(self, recipe_id: str) -> bool:
@@ -344,6 +729,7 @@ class RecipeScanner:
if removed is None:
return False
self._update_folder_metadata(cache)
self._schedule_resort()
return True
@@ -428,6 +814,9 @@ class RecipeScanner:
if path_updated:
self._write_recipe_file(recipe_path, recipe_data)
# Track folder placement relative to recipes directory
recipe_data['folder'] = recipe_data.get('folder') or self._calculate_folder(recipe_path)
# Ensure loras array exists
if 'loras' not in recipe_data:
@@ -438,7 +827,7 @@ class RecipeScanner:
recipe_data['gen_params'] = {}
# Update lora information with local paths and availability
await self._update_lora_information(recipe_data)
lora_metadata_updated = await self._update_lora_information(recipe_data)
if recipe_data.get('checkpoint'):
checkpoint_entry = self._normalize_checkpoint_entry(recipe_data['checkpoint'])
@@ -459,6 +848,12 @@ class RecipeScanner:
logger.info(f"Added fingerprint to recipe: {recipe_path}")
except Exception as e:
logger.error(f"Error writing updated recipe with fingerprint: {e}")
elif lora_metadata_updated:
# Persist updates such as marking invalid entries as deleted
try:
self._write_recipe_file(recipe_path, recipe_data)
except Exception as e:
logger.error(f"Error writing updated recipe metadata: {e}")
return recipe_data
except Exception as e:
@@ -519,7 +914,13 @@ class RecipeScanner:
logger.warning(f"Marked lora with modelVersionId {model_version_id} as deleted")
metadata_updated = True
else:
logger.debug(f"Could not get hash for modelVersionId {model_version_id}")
# No hash returned; mark as deleted to avoid repeated lookups
lora['isDeleted'] = True
metadata_updated = True
logger.warning(
"Marked lora with modelVersionId %s as deleted after failed hash lookup",
model_version_id,
)
# If has hash but no file_name, look up in lora library
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):
@@ -809,7 +1210,7 @@ class RecipeScanner:
return await self._lora_scanner.get_model_info_by_name(name)
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True):
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True, folder: str | None = None, recursive: bool = True):
"""Get paginated and filtered recipe data
Args:
@@ -821,11 +1222,20 @@ class RecipeScanner:
search_options: Dictionary of search options to apply
lora_hash: Optional SHA256 hash of a LoRA to filter recipes by
bypass_filters: If True, ignore other filters when a lora_hash is provided
folder: Optional folder filter relative to recipes directory
recursive: Whether to include recipes in subfolders of the selected folder
"""
cache = await self.get_cached_data()
# Get base dataset
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
sort_field = sort_by.split(':')[0] if ':' in sort_by else sort_by
if sort_field == 'date':
filtered_data = list(cache.sorted_by_date)
elif sort_field == 'name':
filtered_data = list(cache.sorted_by_name)
else:
filtered_data = list(cache.raw_data)
# Apply SFW filtering if enabled
from .settings_manager import get_settings_manager
@@ -856,6 +1266,22 @@ class RecipeScanner:
# Skip further filtering if we're only filtering by LoRA hash with bypass enabled
if not (lora_hash and bypass_filters):
# Apply folder filter before other criteria
if folder is not None:
normalized_folder = folder.strip("/")
def matches_folder(item_folder: str) -> bool:
item_path = (item_folder or "").strip("/")
if recursive:
if not normalized_folder:
return True
return item_path == normalized_folder or item_path.startswith(f"{normalized_folder}/")
return item_path == normalized_folder
filtered_data = [
item for item in filtered_data
if matches_folder(item.get('folder', ''))
]
# Apply search filter
if search:
# Default search options if none provided
@@ -892,6 +1318,14 @@ class RecipeScanner:
if fuzzy_match(str(lora.get('modelName', '')), search):
return True
# Search in prompt and negative_prompt if enabled
if search_options.get('prompt', True) and 'gen_params' in item:
gen_params = item['gen_params']
if fuzzy_match(str(gen_params.get('prompt', '')), search):
return True
if fuzzy_match(str(gen_params.get('negative_prompt', '')), search):
return True
# No match found
return False
@@ -907,6 +1341,13 @@ class RecipeScanner:
if item.get('base_model', '') in filters['base_model']
]
# Filter by favorite
if 'favorite' in filters and filters['favorite']:
filtered_data = [
item for item in filtered_data
if item.get('favorite') is True
]
# Filter by tags
if 'tags' in filters and filters['tags']:
tag_spec = filters['tags']
@@ -925,17 +1366,41 @@ class RecipeScanner:
include_tags = {tag for tag in tag_spec if tag}
if include_tags:
def matches_include(item_tags):
if not item_tags and "__no_tags__" in include_tags:
return True
return any(tag in include_tags for tag in (item_tags or []))
filtered_data = [
item for item in filtered_data
if any(tag in include_tags for tag in (item.get('tags', []) or []))
if matches_include(item.get('tags'))
]
if exclude_tags:
def matches_exclude(item_tags):
if not item_tags and "__no_tags__" in exclude_tags:
return True
return any(tag in exclude_tags for tag in (item_tags or []))
filtered_data = [
item for item in filtered_data
if not any(tag in exclude_tags for tag in (item.get('tags', []) or []))
if not matches_exclude(item.get('tags'))
]
# Apply sorting if not already handled by pre-sorted cache
if ':' in sort_by or sort_field == 'loras_count':
field, order = (sort_by.split(':') + ['desc'])[:2]
reverse = order.lower() == 'desc'
if field == 'name':
filtered_data = natsorted(filtered_data, key=lambda x: x.get('title', '').lower(), reverse=reverse)
elif field == 'date':
# Use modified if available, falling back to created_date
filtered_data.sort(key=lambda x: (x.get('modified', x.get('created_date', 0)), x.get('file_path', '')), reverse=reverse)
elif field == 'loras_count':
filtered_data.sort(key=lambda x: len(x.get('loras', [])), reverse=reverse)
# Calculate pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
@@ -1031,6 +1496,30 @@ class RecipeScanner:
from datetime import datetime
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
async def get_recipe_json_path(self, recipe_id: str) -> Optional[str]:
"""Locate the recipe JSON file, accounting for folder placement."""
recipes_dir = self.recipes_dir
if not recipes_dir:
return None
cache = await self.get_cached_data()
folder = ""
for item in cache.raw_data:
if str(item.get("id")) == str(recipe_id):
folder = item.get("folder") or ""
break
candidate = os.path.normpath(os.path.join(recipes_dir, folder, f"{recipe_id}.recipe.json"))
if os.path.exists(candidate):
return candidate
for root, _, files in os.walk(recipes_dir):
if f"{recipe_id}.recipe.json" in files:
return os.path.join(root, f"{recipe_id}.recipe.json")
return None
async def update_recipe_metadata(self, recipe_id: str, metadata: dict) -> bool:
"""Update recipe metadata (like title and tags) in both file system and cache
@@ -1041,13 +1530,9 @@ class RecipeScanner:
Returns:
bool: True if successful, False otherwise
"""
import os
import json
# First, find the recipe JSON file path
recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
if not os.path.exists(recipe_json_path):
recipe_json_path = await self.get_recipe_json_path(recipe_id)
if not recipe_json_path or not os.path.exists(recipe_json_path):
return False
try:
@@ -1096,8 +1581,8 @@ class RecipeScanner:
if target_name is None:
raise ValueError("target_name must be provided")
recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
if not os.path.exists(recipe_json_path):
recipe_json_path = await self.get_recipe_json_path(recipe_id)
if not recipe_json_path or not os.path.exists(recipe_json_path):
raise RecipeNotFoundError("Recipe not found")
async with self._mutation_lock:
@@ -1228,71 +1713,56 @@ class RecipeScanner:
# Always use lowercase hash for consistency
hash_value = hash_value.lower()
# Get recipes directory
recipes_dir = self.recipes_dir
if not recipes_dir or not os.path.exists(recipes_dir):
logger.warning(f"Recipes directory not found: {recipes_dir}")
# Get cache
cache = await self.get_cached_data()
if not cache or not cache.raw_data:
return 0, 0
file_updated_count = 0
cache_updated_count = 0
# Find recipes that need updating from the cache
recipes_to_update = []
for recipe in cache.raw_data:
loras = recipe.get('loras', [])
if not isinstance(loras, list):
continue
has_match = False
for lora in loras:
if not isinstance(lora, dict):
continue
if (lora.get('hash') or '').lower() == hash_value:
if lora.get('file_name') != new_file_name:
lora['file_name'] = new_file_name
has_match = True
if has_match:
recipes_to_update.append(recipe)
cache_updated_count += 1
if not recipes_to_update:
return 0, 0
# Check if cache is initialized
cache_initialized = self._cache is not None
cache_updated_count = 0
file_updated_count = 0
# Get all recipe JSON files in the recipes directory
recipe_files = []
for root, _, files in os.walk(recipes_dir):
for file in files:
if file.lower().endswith('.recipe.json'):
recipe_files.append(os.path.join(root, file))
# Process each recipe file
for recipe_path in recipe_files:
try:
# Load the recipe data
with open(recipe_path, 'r', encoding='utf-8') as f:
recipe_data = json.load(f)
# Skip if no loras or invalid structure
if not recipe_data or not isinstance(recipe_data, dict) or 'loras' not in recipe_data:
# Persist changes to disk
async with self._mutation_lock:
for recipe in recipes_to_update:
recipe_id = recipe.get('id')
if not recipe_id:
continue
# Check if any lora has matching hash
file_updated = False
for lora in recipe_data.get('loras', []):
if 'hash' in lora and lora['hash'].lower() == hash_value:
# Update file_name
old_file_name = lora.get('file_name', '')
lora['file_name'] = new_file_name
file_updated = True
logger.info(f"Updated file_name in recipe {recipe_path}: {old_file_name} -> {new_file_name}")
# If updated, save the file
if file_updated:
with open(recipe_path, 'w', encoding='utf-8') as f:
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
file_updated_count += 1
# Also update in cache if it exists
if cache_initialized:
recipe_id = recipe_data.get('id')
if recipe_id:
for cache_item in self._cache.raw_data:
if cache_item.get('id') == recipe_id:
# Replace loras array with updated version
cache_item['loras'] = recipe_data['loras']
cache_updated_count += 1
break
except Exception as e:
logger.error(f"Error updating recipe file {recipe_path}: {e}")
import traceback
traceback.print_exc(file=sys.stderr)
recipe_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
try:
self._write_recipe_file(recipe_path, recipe)
file_updated_count += 1
logger.info(f"Updated file_name in recipe {recipe_path}: -> {new_file_name}")
except Exception as e:
logger.error(f"Error updating recipe file {recipe_path}: {e}")
# Resort cache if updates were made
if cache_initialized and cache_updated_count > 0:
await self._cache.resort()
logger.info(f"Resorted recipe cache after updating {cache_updated_count} items")
# file_name is not a sort key, but searches can match on LoRA names,
# so schedule a background resort to keep derived views consistent.
self._schedule_resort()
return file_updated_count, cache_updated_count

View File

@@ -13,6 +13,7 @@ import numpy as np
from PIL import Image
from ...utils.utils import calculate_recipe_fingerprint
from ...utils.civitai_utils import rewrite_preview_url
from .errors import (
RecipeDownloadError,
RecipeNotFoundError,
@@ -94,18 +95,39 @@ class RecipeAnalysisService:
if civitai_client is None:
raise RecipeServiceError("Civitai client unavailable")
temp_path = self._create_temp_path()
temp_path = None
metadata: Optional[dict[str, Any]] = None
is_video = False
extension = ".jpg" # Default
try:
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", url)
if civitai_match:
image_info = await civitai_client.get_image_info(civitai_match.group(1))
if not image_info:
raise RecipeDownloadError("Failed to fetch image information from Civitai")
image_url = image_info.get("url")
if not image_url:
raise RecipeDownloadError("No image URL found in Civitai response")
is_video = image_info.get("type") == "video"
# Use optimized preview URLs if possible
rewritten_url, _ = rewrite_preview_url(image_url, media_type=image_info.get("type"))
if rewritten_url:
image_url = rewritten_url
if is_video:
# Extract extension from URL
url_path = image_url.split('?')[0].split('#')[0]
extension = os.path.splitext(url_path)[1].lower() or ".mp4"
else:
extension = ".jpg"
temp_path = self._create_temp_path(suffix=extension)
await self._download_image(image_url, temp_path)
metadata = image_info.get("meta") if "meta" in image_info else None
if (
isinstance(metadata, dict)
@@ -114,22 +136,31 @@ class RecipeAnalysisService:
):
metadata = metadata["meta"]
else:
# Basic extension detection for non-Civitai URLs
url_path = url.split('?')[0].split('#')[0]
extension = os.path.splitext(url_path)[1].lower()
if extension in [".mp4", ".webm"]:
is_video = True
else:
extension = ".jpg"
temp_path = self._create_temp_path(suffix=extension)
await self._download_image(url, temp_path)
if metadata is None:
if metadata is None and not is_video:
metadata = self._exif_utils.extract_image_metadata(temp_path)
if not metadata:
return self._metadata_not_found_response(temp_path)
return await self._parse_metadata(
metadata,
metadata or {},
recipe_scanner=recipe_scanner,
image_path=temp_path,
include_image_base64=True,
is_video=is_video,
extension=extension,
)
finally:
self._safe_cleanup(temp_path)
if temp_path:
self._safe_cleanup(temp_path)
async def analyze_local_image(
self,
@@ -198,12 +229,16 @@ class RecipeAnalysisService:
recipe_scanner,
image_path: Optional[str],
include_image_base64: bool,
is_video: bool = False,
extension: str = ".jpg",
) -> AnalysisResult:
parser = self._recipe_parser_factory.create_parser(metadata)
if parser is None:
payload = {"error": "No parser found for this image", "loras": []}
if include_image_base64 and image_path:
payload["image_base64"] = self._encode_file(image_path)
payload["is_video"] = is_video
payload["extension"] = extension
return AnalysisResult(payload)
result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner)
@@ -211,6 +246,9 @@ class RecipeAnalysisService:
if include_image_base64 and image_path:
result["image_base64"] = self._encode_file(image_path)
result["is_video"] = is_video
result["extension"] = extension
if "error" in result and not result.get("loras"):
return AnalysisResult(result)
@@ -241,8 +279,8 @@ class RecipeAnalysisService:
temp_file.write(data)
return temp_file.name
def _create_temp_path(self) -> str:
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
def _create_temp_path(self, suffix: str = ".jpg") -> str:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
return temp_file.name
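A small sketch of the URL extension handling used in the video branch above (the URL is illustrative):

url = "https://example.com/media/clip.mp4?width=450#preview"
url_path = url.split('?')[0].split('#')[0]
extension = os.path.splitext(url_path)[1].lower() or ".mp4"  # -> ".mp4"
temp_path = service._create_temp_path(suffix=extension)      # temp file ends in .mp4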
def _safe_cleanup(self, path: Optional[str]) -> None:

View File

@@ -5,6 +5,7 @@ import base64
import json
import os
import re
import shutil
import time
import uuid
from dataclasses import dataclass
@@ -46,6 +47,7 @@ class RecipePersistenceService:
name: str | None,
tags: Iterable[str],
metadata: Optional[dict[str, Any]],
extension: str | None = None,
) -> PersistenceResult:
"""Persist a user uploaded recipe."""
@@ -64,13 +66,21 @@ class RecipePersistenceService:
os.makedirs(recipes_dir, exist_ok=True)
recipe_id = str(uuid.uuid4())
optimized_image, extension = self._exif_utils.optimize_image(
image_data=resolved_image_bytes,
target_width=self._card_preview_width,
format="webp",
quality=85,
preserve_metadata=True,
)
# Handle video formats by bypassing optimization and metadata embedding
is_video = extension in [".mp4", ".webm"]
if is_video:
optimized_image = resolved_image_bytes
# extension is already set
else:
optimized_image, extension = self._exif_utils.optimize_image(
image_data=resolved_image_bytes,
target_width=self._card_preview_width,
format="webp",
quality=85,
preserve_metadata=True,
)
image_filename = f"{recipe_id}{extension}"
image_path = os.path.join(recipes_dir, image_filename)
normalized_image_path = os.path.normpath(image_path)
@@ -126,7 +136,8 @@ class RecipePersistenceService:
with open(json_path, "w", encoding="utf-8") as file_obj:
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
self._exif_utils.append_recipe_metadata(normalized_image_path, recipe_data)
if not is_video:
self._exif_utils.append_recipe_metadata(normalized_image_path, recipe_data)
matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id)
await recipe_scanner.add_recipe(recipe_data)
@@ -144,12 +155,8 @@ class RecipePersistenceService:
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
"""Delete an existing recipe."""
recipes_dir = recipe_scanner.recipes_dir
if not recipes_dir or not os.path.exists(recipes_dir):
raise RecipeNotFoundError("Recipes directory not found")
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
if not os.path.exists(recipe_json_path):
recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id)
if not recipe_json_path or not os.path.exists(recipe_json_path):
raise RecipeNotFoundError("Recipe not found")
with open(recipe_json_path, "r", encoding="utf-8") as file_obj:
@@ -166,9 +173,9 @@ class RecipePersistenceService:
async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult:
"""Update persisted metadata for a recipe."""
if not any(key in updates for key in ("title", "tags", "source_path", "preview_nsfw_level")):
if not any(key in updates for key in ("title", "tags", "source_path", "preview_nsfw_level", "favorite")):
raise RecipeValidationError(
"At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)"
"At least one field to update must be provided (title or tags or source_path or preview_nsfw_level or favorite)"
)
success = await recipe_scanner.update_recipe_metadata(recipe_id, updates)
@@ -177,6 +184,163 @@ class RecipePersistenceService:
return PersistenceResult({"success": True, "recipe_id": recipe_id, "updates": updates})
def _normalize_target_path(self, recipe_scanner, target_path: str) -> tuple[str, str]:
"""Normalize and validate the target path for recipe moves."""
if not target_path:
raise RecipeValidationError("Target path is required")
recipes_root = recipe_scanner.recipes_dir
if not recipes_root:
raise RecipeNotFoundError("Recipes directory not found")
normalized_target = os.path.normpath(target_path)
recipes_root = os.path.normpath(recipes_root)
if not os.path.isabs(normalized_target):
normalized_target = os.path.normpath(os.path.join(recipes_root, normalized_target))
try:
common_root = os.path.commonpath([normalized_target, recipes_root])
except ValueError as exc:
raise RecipeValidationError("Invalid target path") from exc
if common_root != recipes_root:
raise RecipeValidationError("Target path must be inside the recipes directory")
return normalized_target, recipes_root
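Illustrative outcomes of `_normalize_target_path`, assuming a POSIX-style recipes root of /data/recipes (paths are made up):

# Relative targets are resolved under the recipes root.
service._normalize_target_path(scanner, "characters/anime")
# -> ("/data/recipes/characters/anime", "/data/recipes")
# Targets that escape the root fail the commonpath check.
service._normalize_target_path(scanner, "../outside")       # raises RecipeValidationError
service._normalize_target_path(scanner, "/tmp/elsewhere")   # raises RecipeValidationError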
async def _move_recipe_files(
self,
*,
recipe_scanner,
recipe_id: str,
normalized_target: str,
recipes_root: str,
) -> dict[str, Any]:
"""Move the recipe's JSON and preview image into the normalized target."""
recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id)
if not recipe_json_path or not os.path.exists(recipe_json_path):
raise RecipeNotFoundError("Recipe not found")
recipe_data = await recipe_scanner.get_recipe_by_id(recipe_id)
if not recipe_data:
raise RecipeNotFoundError("Recipe not found")
current_json_dir = os.path.dirname(recipe_json_path)
normalized_image_path = os.path.normpath(recipe_data.get("file_path") or "") if recipe_data.get("file_path") else None
os.makedirs(normalized_target, exist_ok=True)
if os.path.normpath(current_json_dir) == normalized_target:
return {
"success": True,
"message": "Recipe is already in the target folder",
"recipe_id": recipe_id,
"original_file_path": recipe_data.get("file_path"),
"new_file_path": recipe_data.get("file_path"),
}
new_json_path = os.path.normpath(os.path.join(normalized_target, os.path.basename(recipe_json_path)))
shutil.move(recipe_json_path, new_json_path)
new_image_path = normalized_image_path
if normalized_image_path:
target_image_path = os.path.normpath(os.path.join(normalized_target, os.path.basename(normalized_image_path)))
if os.path.exists(normalized_image_path) and normalized_image_path != target_image_path:
shutil.move(normalized_image_path, target_image_path)
new_image_path = target_image_path
relative_folder = os.path.relpath(normalized_target, recipes_root)
if relative_folder in (".", ""):
relative_folder = ""
updates = {"file_path": new_image_path or recipe_data.get("file_path"), "folder": relative_folder.replace(os.path.sep, "/")}
updated = await recipe_scanner.update_recipe_metadata(recipe_id, updates)
if not updated:
raise RecipeNotFoundError("Recipe not found after move")
return {
"success": True,
"recipe_id": recipe_id,
"original_file_path": recipe_data.get("file_path"),
"new_file_path": updates["file_path"],
"json_path": new_json_path,
"folder": updates["folder"],
}
async def move_recipe(self, *, recipe_scanner, recipe_id: str, target_path: str) -> PersistenceResult:
"""Move a recipe's assets into a new folder under the recipes root."""
normalized_target, recipes_root = self._normalize_target_path(recipe_scanner, target_path)
result = await self._move_recipe_files(
recipe_scanner=recipe_scanner,
recipe_id=recipe_id,
normalized_target=normalized_target,
recipes_root=recipes_root,
)
return PersistenceResult(result)
async def move_recipes_bulk(
self,
*,
recipe_scanner,
recipe_ids: Iterable[str],
target_path: str,
) -> PersistenceResult:
"""Move multiple recipes to a new folder."""
recipe_ids = list(recipe_ids)
if not recipe_ids:
raise RecipeValidationError("No recipe IDs provided")
normalized_target, recipes_root = self._normalize_target_path(recipe_scanner, target_path)
results: list[dict[str, Any]] = []
success_count = 0
failure_count = 0
for recipe_id in recipe_ids:
try:
move_result = await self._move_recipe_files(
recipe_scanner=recipe_scanner,
recipe_id=str(recipe_id),
normalized_target=normalized_target,
recipes_root=recipes_root,
)
results.append(
{
"recipe_id": recipe_id,
"original_file_path": move_result.get("original_file_path"),
"new_file_path": move_result.get("new_file_path"),
"success": True,
"message": move_result.get("message", ""),
"folder": move_result.get("folder", ""),
}
)
success_count += 1
except Exception as exc: # pragma: no cover - per-item error handling
results.append(
{
"recipe_id": recipe_id,
"original_file_path": None,
"new_file_path": None,
"success": False,
"message": str(exc),
}
)
failure_count += 1
return PersistenceResult(
{
"success": True,
"message": f"Moved {success_count} of {len(recipe_ids)} recipes",
"results": results,
"success_count": success_count,
"failure_count": failure_count,
}
)
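A minimal call sketch for the bulk move API (the `persistence` instance, IDs, and target are illustrative):

result = await persistence.move_recipes_bulk(
    recipe_scanner=scanner,
    recipe_ids=["id-1", "id-2"],
    target_path="characters/anime",  # resolved under the recipes root
)
# result carries per-recipe entries plus success_count / failure_count totals.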
async def reconnect_lora(
self,
*,
@@ -187,8 +351,8 @@ class RecipePersistenceService:
) -> PersistenceResult:
"""Reconnect a LoRA entry within an existing recipe."""
recipe_path = os.path.join(recipe_scanner.recipes_dir, f"{recipe_id}.recipe.json")
if not os.path.exists(recipe_path):
recipe_path = await recipe_scanner.get_recipe_json_path(recipe_id)
if not recipe_path or not os.path.exists(recipe_path):
raise RecipeNotFoundError("Recipe not found")
target_lora = await recipe_scanner.get_local_lora(target_name)
@@ -233,16 +397,12 @@ class RecipePersistenceService:
if not recipe_ids:
raise RecipeValidationError("No recipe IDs provided")
recipes_dir = recipe_scanner.recipes_dir
if not recipes_dir or not os.path.exists(recipes_dir):
raise RecipeNotFoundError("Recipes directory not found")
deleted_recipes: list[str] = []
failed_recipes: list[dict[str, Any]] = []
for recipe_id in recipe_ids:
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
if not os.path.exists(recipe_json_path):
recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id)
if not recipe_json_path or not os.path.exists(recipe_json_path):
failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"})
continue

View File

@@ -20,6 +20,8 @@ class WebSocketManager:
self._last_init_progress: Dict[str, Dict] = {}
# Add auto-organize progress tracking
self._auto_organize_progress: Optional[Dict] = None
# Add recipe repair progress tracking
self._recipe_repair_progress: Optional[Dict] = None
self._auto_organize_lock = asyncio.Lock()
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
@@ -189,6 +191,14 @@ class WebSocketManager:
# Broadcast via WebSocket
await self.broadcast(data)
async def broadcast_recipe_repair_progress(self, data: Dict):
"""Broadcast recipe repair progress to connected clients"""
# Store progress data in memory
self._recipe_repair_progress = data
# Broadcast via WebSocket
await self.broadcast(data)
def get_auto_organize_progress(self) -> Optional[Dict]:
"""Get current auto-organize progress"""
return self._auto_organize_progress
@@ -197,6 +207,14 @@ class WebSocketManager:
"""Clear auto-organize progress data"""
self._auto_organize_progress = None
def get_recipe_repair_progress(self) -> Optional[Dict]:
"""Get current recipe repair progress"""
return self._recipe_repair_progress
def cleanup_recipe_repair_progress(self):
"""Clear recipe repair progress data"""
self._recipe_repair_progress = None
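A brief usage sketch for the repair-progress accessors above (the `ws_manager` handle is illustrative):

progress = ws_manager.get_recipe_repair_progress()
if progress and progress.get("status") == "completed":
    # Once the completed summary has been observed, the stored snapshot can be cleared.
    ws_manager.cleanup_recipe_repair_progress()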
def is_auto_organize_running(self) -> bool:
"""Check if auto-organize is currently running"""
if not self._auto_organize_progress: