mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
Merge branch 'sort-by-usage-count' into main
This commit is contained in:
131
py/config.py
131
py/config.py
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
import platform
|
||||
import threading
|
||||
from pathlib import Path
|
||||
import folder_paths # type: ignore
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple
|
||||
import logging
|
||||
import json
|
||||
import urllib.parse
|
||||
import time
|
||||
|
||||
from .utils.settings_paths import ensure_settings_file, get_settings_dir, load_settings_template
|
||||
|
||||
@@ -80,6 +82,8 @@ class Config:
|
||||
self._path_mappings: Dict[str, str] = {}
|
||||
# Normalized preview root directories used to validate preview access
|
||||
self._preview_root_paths: Set[Path] = set()
|
||||
# Optional background rescan thread
|
||||
self._rescan_thread: Optional[threading.Thread] = None
|
||||
self.loras_roots = self._init_lora_paths()
|
||||
self.checkpoints_roots = None
|
||||
self.unet_roots = None
|
||||
@@ -282,58 +286,25 @@ class Config:
|
||||
def _load_symlink_cache(self) -> bool:
|
||||
cache_path = self._get_symlink_cache_path()
|
||||
if not cache_path.exists():
|
||||
logger.info("Symlink cache not found at %s", cache_path)
|
||||
return False
|
||||
|
||||
try:
|
||||
with cache_path.open("r", encoding="utf-8") as handle:
|
||||
payload = json.load(handle)
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to load symlink cache %s: %s", cache_path, exc)
|
||||
logger.info("Failed to load symlink cache %s: %s", cache_path, exc)
|
||||
return False
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
logger.info("Symlink cache payload is not a dict: %s", type(payload))
|
||||
return False
|
||||
|
||||
cached_fingerprint = payload.get("fingerprint")
|
||||
cached_mappings = payload.get("path_mappings")
|
||||
if not isinstance(cached_fingerprint, dict) or not isinstance(cached_mappings, Mapping):
|
||||
if not isinstance(cached_mappings, Mapping):
|
||||
logger.info("Symlink cache missing path mappings")
|
||||
return False
|
||||
|
||||
current_fingerprint = self._build_symlink_fingerprint()
|
||||
cached_roots = cached_fingerprint.get("roots")
|
||||
cached_stats = cached_fingerprint.get("stats")
|
||||
if (
|
||||
not isinstance(cached_roots, list)
|
||||
or not isinstance(cached_stats, Mapping)
|
||||
or sorted(cached_roots) != sorted(current_fingerprint["roots"]) # type: ignore[index]
|
||||
):
|
||||
return False
|
||||
|
||||
for root in current_fingerprint["roots"]: # type: ignore[assignment]
|
||||
cached_stat = cached_stats.get(root) if isinstance(cached_stats, Mapping) else None
|
||||
current_stat = current_fingerprint["stats"].get(root) # type: ignore[index]
|
||||
if not isinstance(cached_stat, Mapping) or not current_stat:
|
||||
return False
|
||||
|
||||
cached_mtime = cached_stat.get("mtime_ns")
|
||||
cached_inode = cached_stat.get("inode")
|
||||
current_mtime = current_stat.get("mtime_ns")
|
||||
current_inode = current_stat.get("inode")
|
||||
|
||||
if cached_inode != current_inode:
|
||||
return False
|
||||
|
||||
if cached_mtime != current_mtime:
|
||||
cached_noise = cached_stat.get("noise_mtime_ns")
|
||||
current_noise = current_stat.get("noise_mtime_ns")
|
||||
if not (
|
||||
cached_noise
|
||||
and current_noise
|
||||
and cached_mtime == cached_noise
|
||||
and current_mtime == current_noise
|
||||
):
|
||||
return False
|
||||
|
||||
normalized_mappings: Dict[str, str] = {}
|
||||
for target, link in cached_mappings.items():
|
||||
if not isinstance(target, str) or not isinstance(link, str):
|
||||
@@ -341,6 +312,7 @@ class Config:
|
||||
normalized_mappings[self._normalize_path(target)] = self._normalize_path(link)
|
||||
|
||||
self._path_mappings = normalized_mappings
|
||||
logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings))
|
||||
return True
|
||||
|
||||
def _save_symlink_cache(self) -> None:
|
||||
@@ -353,22 +325,75 @@ class Config:
|
||||
try:
|
||||
with cache_path.open("w", encoding="utf-8") as handle:
|
||||
json.dump(payload, handle, ensure_ascii=False, indent=2)
|
||||
logger.info("Symlink cache saved to %s with %d mappings", cache_path, len(self._path_mappings))
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to write symlink cache %s: %s", cache_path, exc)
|
||||
logger.info("Failed to write symlink cache %s: %s", cache_path, exc)
|
||||
|
||||
def _initialize_symlink_mappings(self) -> None:
|
||||
if not self._load_symlink_cache():
|
||||
self._scan_symbolic_links()
|
||||
self._save_symlink_cache()
|
||||
else:
|
||||
logger.info("Loaded symlink mappings from cache")
|
||||
start = time.perf_counter()
|
||||
cache_loaded = self._load_symlink_cache()
|
||||
|
||||
if cache_loaded:
|
||||
logger.info(
|
||||
"Symlink mappings restored from cache in %.2f ms",
|
||||
(time.perf_counter() - start) * 1000,
|
||||
)
|
||||
self._rebuild_preview_roots()
|
||||
self._schedule_symlink_rescan()
|
||||
return
|
||||
|
||||
self._scan_symbolic_links()
|
||||
self._save_symlink_cache()
|
||||
self._rebuild_preview_roots()
|
||||
logger.info(
|
||||
"Symlink mappings rebuilt and cached in %.2f ms",
|
||||
(time.perf_counter() - start) * 1000,
|
||||
)
|
||||
|
||||
def _scan_symbolic_links(self):
|
||||
"""Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
|
||||
start = time.perf_counter()
|
||||
# Reset mappings before rescanning to avoid stale entries
|
||||
self._path_mappings.clear()
|
||||
self._seed_root_symlink_mappings()
|
||||
visited_dirs: Set[str] = set()
|
||||
for root in self._symlink_roots():
|
||||
self._scan_directory_links(root, visited_dirs)
|
||||
logger.info(
|
||||
"Symlink scan finished in %.2f ms with %d mappings",
|
||||
(time.perf_counter() - start) * 1000,
|
||||
len(self._path_mappings),
|
||||
)
|
||||
|
||||
def _schedule_symlink_rescan(self) -> None:
|
||||
"""Trigger a best-effort background rescan to refresh stale caches."""
|
||||
|
||||
if self._rescan_thread and self._rescan_thread.is_alive():
|
||||
return
|
||||
|
||||
def worker():
|
||||
try:
|
||||
self._scan_symbolic_links()
|
||||
self._save_symlink_cache()
|
||||
self._rebuild_preview_roots()
|
||||
logger.info("Background symlink rescan completed")
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.info("Background symlink rescan failed: %s", exc)
|
||||
|
||||
thread = threading.Thread(
|
||||
target=worker,
|
||||
name="lora-manager-symlink-rescan",
|
||||
daemon=True,
|
||||
)
|
||||
self._rescan_thread = thread
|
||||
thread.start()
|
||||
|
||||
def _wait_for_rescan(self, timeout: Optional[float] = None) -> None:
|
||||
"""Block until the background rescan completes (testing convenience)."""
|
||||
|
||||
thread = self._rescan_thread
|
||||
if thread:
|
||||
thread.join(timeout=timeout)
|
||||
|
||||
def _scan_directory_links(self, root: str, visited_dirs: Set[str]):
|
||||
"""Iteratively scan directory symlinks to avoid deep recursion."""
|
||||
@@ -434,6 +459,22 @@ class Config:
|
||||
self._preview_root_paths.update(self._expand_preview_root(normalized_target))
|
||||
self._preview_root_paths.update(self._expand_preview_root(normalized_link))
|
||||
|
||||
def _seed_root_symlink_mappings(self) -> None:
|
||||
"""Ensure symlinked root folders are recorded before deep scanning."""
|
||||
|
||||
for root in self._symlink_roots():
|
||||
if not root:
|
||||
continue
|
||||
try:
|
||||
if not self._is_link(root):
|
||||
continue
|
||||
target_path = os.path.realpath(root)
|
||||
if not os.path.isdir(target_path):
|
||||
continue
|
||||
self.add_path_mapping(root, target_path)
|
||||
except Exception as exc:
|
||||
logger.debug("Skipping root symlink %s: %s", root, exc)
|
||||
|
||||
def _expand_preview_root(self, path: str) -> Set[Path]:
|
||||
"""Return normalized ``Path`` objects representing a preview root."""
|
||||
|
||||
|
||||
@@ -39,8 +39,39 @@ class MetadataProcessor:
|
||||
if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False):
|
||||
candidate_samplers[node_id] = metadata[SAMPLING][node_id]
|
||||
|
||||
# If we found candidate samplers, apply primary sampler logic to these candidates only
|
||||
if candidate_samplers:
|
||||
# If we found candidate samplers, apply primary sampler logic to these candidates only
|
||||
|
||||
# PRE-PROCESS: Ensure all candidate samplers have their parameters populated
|
||||
# This is especially important for SamplerCustomAdvanced which needs tracing
|
||||
prompt = metadata.get("current_prompt")
|
||||
for node_id in candidate_samplers:
|
||||
# If a sampler is missing common parameters like steps or denoise,
|
||||
# try to populate them using tracing before ranking
|
||||
sampler_info = candidate_samplers[node_id]
|
||||
params = sampler_info.get("parameters", {})
|
||||
|
||||
if prompt and (params.get("steps") is None or params.get("denoise") is None):
|
||||
# Create a temporary params dict to use the handler
|
||||
temp_params = {
|
||||
"steps": params.get("steps"),
|
||||
"denoise": params.get("denoise"),
|
||||
"sampler": params.get("sampler_name"),
|
||||
"scheduler": params.get("scheduler")
|
||||
}
|
||||
|
||||
# Check if it's SamplerCustomAdvanced
|
||||
if prompt.original_prompt and node_id in prompt.original_prompt:
|
||||
if prompt.original_prompt[node_id].get("class_type") == "SamplerCustomAdvanced":
|
||||
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, node_id, temp_params)
|
||||
|
||||
# Update the actual parameters with found values
|
||||
params["steps"] = temp_params.get("steps")
|
||||
params["denoise"] = temp_params.get("denoise")
|
||||
if temp_params.get("sampler"):
|
||||
params["sampler_name"] = temp_params.get("sampler")
|
||||
if temp_params.get("scheduler"):
|
||||
params["scheduler"] = temp_params.get("scheduler")
|
||||
|
||||
# Collect potential primary samplers based on different criteria
|
||||
custom_advanced_samplers = []
|
||||
advanced_add_noise_samplers = []
|
||||
@@ -49,7 +80,6 @@ class MetadataProcessor:
|
||||
high_denoise_id = None
|
||||
|
||||
# First, check for SamplerCustomAdvanced among candidates
|
||||
prompt = metadata.get("current_prompt")
|
||||
if prompt and prompt.original_prompt:
|
||||
for node_id in candidate_samplers:
|
||||
node_info = prompt.original_prompt.get(node_id, {})
|
||||
@@ -77,15 +107,16 @@ class MetadataProcessor:
|
||||
# Combine all potential primary samplers
|
||||
potential_samplers = custom_advanced_samplers + advanced_add_noise_samplers + high_denoise_samplers
|
||||
|
||||
# Find the most recent potential primary sampler (closest to downstream node)
|
||||
for i in range(downstream_index - 1, -1, -1):
|
||||
# Find the first potential primary sampler (prefer base sampler over refine)
|
||||
# Use forward search to prioritize the first one in execution order
|
||||
for i in range(downstream_index):
|
||||
node_id = execution_order[i]
|
||||
if node_id in potential_samplers:
|
||||
return node_id, candidate_samplers[node_id]
|
||||
|
||||
# If no potential sampler found from our criteria, return the most recent sampler
|
||||
# If no potential sampler found from our criteria, return the first sampler
|
||||
if candidate_samplers:
|
||||
for i in range(downstream_index - 1, -1, -1):
|
||||
for i in range(downstream_index):
|
||||
node_id = execution_order[i]
|
||||
if node_id in candidate_samplers:
|
||||
return node_id, candidate_samplers[node_id]
|
||||
@@ -176,8 +207,11 @@ class MetadataProcessor:
|
||||
found_node_id = input_value[0] # Connected node_id
|
||||
|
||||
# If we're looking for a specific node class
|
||||
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
|
||||
return found_node_id
|
||||
if target_class:
|
||||
if found_node_id not in prompt.original_prompt:
|
||||
return None
|
||||
if prompt.original_prompt[found_node_id].get("class_type") == target_class:
|
||||
return found_node_id
|
||||
|
||||
# If we're not looking for a specific class, update the last valid node
|
||||
if not target_class:
|
||||
@@ -185,11 +219,19 @@ class MetadataProcessor:
|
||||
|
||||
# Continue tracing through intermediate nodes
|
||||
current_node_id = found_node_id
|
||||
# For most conditioning nodes, the input we want to follow is named "conditioning"
|
||||
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
|
||||
|
||||
# Check if current source node exists
|
||||
if current_node_id not in prompt.original_prompt:
|
||||
return found_node_id if not target_class else None
|
||||
|
||||
# Determine which input to follow next on the source node
|
||||
source_node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
|
||||
if input_name in source_node_inputs:
|
||||
current_input = input_name
|
||||
elif "conditioning" in source_node_inputs:
|
||||
current_input = "conditioning"
|
||||
else:
|
||||
# If there's no "conditioning" input, return the current node
|
||||
# If there's no suitable input to follow, return the current node
|
||||
# if we're not looking for a specific target_class
|
||||
return found_node_id if not target_class else None
|
||||
else:
|
||||
@@ -202,12 +244,89 @@ class MetadataProcessor:
|
||||
return last_valid_node if not target_class else None
|
||||
|
||||
@staticmethod
|
||||
def find_primary_checkpoint(metadata):
|
||||
"""Find the primary checkpoint model in the workflow"""
|
||||
if not metadata.get(MODELS):
|
||||
def trace_model_path(metadata, prompt, start_node_id):
|
||||
"""
|
||||
Trace the model connection path upstream to find the checkpoint
|
||||
"""
|
||||
if not prompt or not prompt.original_prompt:
|
||||
return None
|
||||
|
||||
# In most workflows, there's only one checkpoint, so we can just take the first one
|
||||
current_node_id = start_node_id
|
||||
depth = 0
|
||||
max_depth = 50
|
||||
|
||||
while depth < max_depth:
|
||||
# Check if current node is a registered checkpoint in our metadata
|
||||
# This handles cached nodes correctly because metadata contains info for all nodes in the graph
|
||||
if current_node_id in metadata.get(MODELS, {}):
|
||||
if metadata[MODELS][current_node_id].get("type") == "checkpoint":
|
||||
return current_node_id
|
||||
|
||||
if current_node_id not in prompt.original_prompt:
|
||||
return None
|
||||
|
||||
node = prompt.original_prompt[current_node_id]
|
||||
inputs = node.get("inputs", {})
|
||||
class_type = node.get("class_type", "")
|
||||
|
||||
# Determine which input to follow next
|
||||
next_input_name = "model"
|
||||
|
||||
# Special handling for initial node
|
||||
if depth == 0:
|
||||
if class_type == "SamplerCustomAdvanced":
|
||||
next_input_name = "guider"
|
||||
|
||||
# If the specific input doesn't exist, try generic 'model'
|
||||
if next_input_name not in inputs:
|
||||
if "model" in inputs:
|
||||
next_input_name = "model"
|
||||
elif "basic_pipe" in inputs:
|
||||
# Handle pipe nodes like FromBasicPipe by following the pipeline
|
||||
next_input_name = "basic_pipe"
|
||||
else:
|
||||
# Dead end - no model input to follow
|
||||
return None
|
||||
|
||||
# Get connected node
|
||||
input_val = inputs[next_input_name]
|
||||
if isinstance(input_val, list) and len(input_val) > 0:
|
||||
current_node_id = input_val[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
depth += 1
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def find_primary_checkpoint(metadata, downstream_id=None, primary_sampler_id=None):
|
||||
"""
|
||||
Find the primary checkpoint model in the workflow
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
|
||||
- primary_sampler_id: Optional ID of the primary sampler if already known
|
||||
"""
|
||||
if not metadata.get(MODELS):
|
||||
return None
|
||||
|
||||
# Method 1: Topology-based tracing (More accurate for complex workflows)
|
||||
# First, find the primary sampler if not provided
|
||||
if not primary_sampler_id:
|
||||
primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id)
|
||||
|
||||
if primary_sampler_id:
|
||||
prompt = metadata.get("current_prompt")
|
||||
if prompt:
|
||||
# Trace back from the sampler to find the checkpoint
|
||||
checkpoint_id = MetadataProcessor.trace_model_path(metadata, prompt, primary_sampler_id)
|
||||
if checkpoint_id and checkpoint_id in metadata.get(MODELS, {}):
|
||||
return metadata[MODELS][checkpoint_id].get("name")
|
||||
|
||||
# Method 2: Fallback to the first available checkpoint (Original behavior)
|
||||
# In most simple workflows, there's only one checkpoint, so we can just take the first one
|
||||
for node_id, model_info in metadata.get(MODELS, {}).items():
|
||||
if model_info.get("type") == "checkpoint":
|
||||
return model_info.get("name")
|
||||
@@ -311,7 +430,8 @@ class MetadataProcessor:
|
||||
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
|
||||
|
||||
# Directly get checkpoint from metadata instead of tracing
|
||||
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
|
||||
# Pass primary_sampler_id to avoid redundant calculation
|
||||
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id, primary_sampler_id)
|
||||
if checkpoint:
|
||||
params["checkpoint"] = checkpoint
|
||||
|
||||
@@ -445,6 +565,7 @@ class MetadataProcessor:
|
||||
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||
params["steps"] = scheduler_params.get("steps")
|
||||
params["scheduler"] = scheduler_params.get("scheduler")
|
||||
params["denoise"] = scheduler_params.get("denoise")
|
||||
|
||||
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
|
||||
if "sampler" in sampler_inputs:
|
||||
|
||||
@@ -9,7 +9,7 @@ from ..metadata_collector import get_metadata
|
||||
from PIL import Image, PngImagePlugin
|
||||
import piexif
|
||||
|
||||
class SaveImage:
|
||||
class SaveImageLM:
|
||||
NAME = "Save Image (LoraManager)"
|
||||
CATEGORY = "Lora Manager/utils"
|
||||
DESCRIPTION = "Save images with embedded generation metadata in compatible format"
|
||||
|
||||
@@ -103,7 +103,7 @@ def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||
|
||||
# Save the transformer temporarily
|
||||
model_wrapper.model = None
|
||||
ret_model = copy.deepcopy(model) # copy everything except the model
|
||||
ret_model = model.clone()
|
||||
ret_model_wrapper = ret_model.model.diffusion_model
|
||||
|
||||
# Restore the model and set it for the copy
|
||||
|
||||
@@ -37,7 +37,8 @@ class RecipeMetadataParser(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
|
||||
@staticmethod
|
||||
async def populate_lora_from_civitai(lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
|
||||
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Populate a lora entry with information from Civitai API response
|
||||
@@ -148,8 +149,9 @@ class RecipeMetadataParser(ABC):
|
||||
logger.error(f"Error populating lora from Civitai info: {e}")
|
||||
|
||||
return lora_entry
|
||||
|
||||
async def populate_checkpoint_from_civitai(self, checkpoint: Dict[str, Any], civitai_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
async def populate_checkpoint_from_civitai(checkpoint: Dict[str, Any], civitai_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Populate checkpoint information from Civitai API response
|
||||
|
||||
@@ -187,6 +189,7 @@ class RecipeMetadataParser(ABC):
|
||||
checkpoint['downloadUrl'] = civitai_data.get('downloadUrl', '')
|
||||
|
||||
checkpoint['modelId'] = civitai_data.get('modelId', checkpoint.get('modelId', 0))
|
||||
checkpoint['id'] = civitai_data.get('id', 0)
|
||||
|
||||
if 'files' in civitai_data:
|
||||
model_file = next(
|
||||
|
||||
216
py/recipes/enrichment.py
Normal file
216
py/recipes/enrichment.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from .merger import GenParamsMerger
|
||||
from .base import RecipeMetadataParser
|
||||
from ..services.metadata_service import get_default_metadata_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RecipeEnricher:
|
||||
"""Service to enrich recipe metadata from multiple sources (Civitai, Embedded, User)."""
|
||||
|
||||
@staticmethod
|
||||
async def enrich_recipe(
|
||||
recipe: Dict[str, Any],
|
||||
civitai_client: Any,
|
||||
request_params: Optional[Dict[str, Any]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Enrich a recipe dictionary in-place with metadata from Civitai and embedded params.
|
||||
|
||||
Args:
|
||||
recipe: The recipe dictionary to enrich. Must have 'gen_params' initialized.
|
||||
civitai_client: Authenticated Civitai client instance.
|
||||
request_params: (Optional) Parameters from a user request (e.g. import).
|
||||
|
||||
Returns:
|
||||
bool: True if the recipe was modified, False otherwise.
|
||||
"""
|
||||
updated = False
|
||||
gen_params = recipe.get("gen_params", {})
|
||||
|
||||
# 1. Fetch Civitai Info if available
|
||||
civitai_meta = None
|
||||
model_version_id = None
|
||||
|
||||
source_url = recipe.get("source_url") or recipe.get("source_path", "")
|
||||
|
||||
# Check if it's a Civitai image URL
|
||||
image_id_match = re.search(r'civitai\.com/images/(\d+)', str(source_url))
|
||||
if image_id_match:
|
||||
image_id = image_id_match.group(1)
|
||||
try:
|
||||
image_info = await civitai_client.get_image_info(image_id)
|
||||
if image_info:
|
||||
# Handle nested meta often found in Civitai API responses
|
||||
raw_meta = image_info.get("meta")
|
||||
if isinstance(raw_meta, dict):
|
||||
if "meta" in raw_meta and isinstance(raw_meta["meta"], dict):
|
||||
civitai_meta = raw_meta["meta"]
|
||||
else:
|
||||
civitai_meta = raw_meta
|
||||
|
||||
model_version_id = image_info.get("modelVersionId")
|
||||
|
||||
# If not at top level, check resources in meta
|
||||
if not model_version_id and civitai_meta:
|
||||
resources = civitai_meta.get("civitaiResources", [])
|
||||
for res in resources:
|
||||
if res.get("type") == "checkpoint":
|
||||
model_version_id = res.get("modelVersionId")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch Civitai image info: {e}")
|
||||
|
||||
# 2. Merge Parameters
|
||||
# Priority: request_params > civitai_meta > embedded (existing gen_params)
|
||||
new_gen_params = GenParamsMerger.merge(
|
||||
request_params=request_params,
|
||||
civitai_meta=civitai_meta,
|
||||
embedded_metadata=gen_params
|
||||
)
|
||||
|
||||
if new_gen_params != gen_params:
|
||||
recipe["gen_params"] = new_gen_params
|
||||
updated = True
|
||||
|
||||
# 3. Checkpoint Enrichment
|
||||
# If we have a checkpoint entry, or we can find one
|
||||
# Use 'id' (from Civitai version) as a marker that it's been enriched
|
||||
checkpoint_entry = recipe.get("checkpoint")
|
||||
has_full_checkpoint = checkpoint_entry and checkpoint_entry.get("name") and checkpoint_entry.get("id")
|
||||
|
||||
if not has_full_checkpoint:
|
||||
# Helper to look up values in priority order
|
||||
def start_lookup(keys):
|
||||
for source in [request_params, civitai_meta, gen_params]:
|
||||
if source:
|
||||
if isinstance(keys, list):
|
||||
for k in keys:
|
||||
if k in source: return source[k]
|
||||
else:
|
||||
if keys in source: return source[keys]
|
||||
return None
|
||||
|
||||
target_version_id = model_version_id or start_lookup("modelVersionId")
|
||||
|
||||
# Also check existing checkpoint entry
|
||||
if not target_version_id and checkpoint_entry:
|
||||
target_version_id = checkpoint_entry.get("modelVersionId") or checkpoint_entry.get("id")
|
||||
|
||||
# Check for version ID in resources (which might be a string in gen_params)
|
||||
if not target_version_id:
|
||||
# Look in all sources for "Civitai resources"
|
||||
resources_val = start_lookup(["Civitai resources", "civitai_resources", "resources"])
|
||||
if resources_val:
|
||||
target_version_id = RecipeEnricher._extract_version_id_from_resources({"Civitai resources": resources_val})
|
||||
|
||||
target_hash = start_lookup(["Model hash", "checkpoint_hash", "hashes"])
|
||||
if not target_hash and checkpoint_entry:
|
||||
target_hash = checkpoint_entry.get("hash") or checkpoint_entry.get("model_hash")
|
||||
|
||||
# Look for 'Model' which sometimes is the hash or name
|
||||
model_val = start_lookup("Model")
|
||||
|
||||
# Look for Checkpoint name fallback
|
||||
checkpoint_val = checkpoint_entry.get("name") if checkpoint_entry else None
|
||||
if not checkpoint_val:
|
||||
checkpoint_val = start_lookup(["Checkpoint", "checkpoint"])
|
||||
|
||||
checkpoint_updated = await RecipeEnricher._resolve_and_populate_checkpoint(
|
||||
recipe, target_version_id, target_hash, model_val, checkpoint_val
|
||||
)
|
||||
if checkpoint_updated:
|
||||
updated = True
|
||||
else:
|
||||
# Checkpoint exists, no need to sync to gen_params anymore.
|
||||
pass
|
||||
# base_model resolution moved to _resolve_and_populate_checkpoint to support strict formatting
|
||||
return updated
|
||||
|
||||
@staticmethod
|
||||
def _extract_version_id_from_resources(gen_params: Dict[str, Any]) -> Optional[Any]:
|
||||
"""Try to find modelVersionId in Civitai resources parameter."""
|
||||
civitai_resources_raw = gen_params.get("Civitai resources")
|
||||
if not civitai_resources_raw:
|
||||
return None
|
||||
|
||||
resources_list = None
|
||||
if isinstance(civitai_resources_raw, str):
|
||||
try:
|
||||
resources_list = json.loads(civitai_resources_raw)
|
||||
except Exception:
|
||||
pass
|
||||
elif isinstance(civitai_resources_raw, list):
|
||||
resources_list = civitai_resources_raw
|
||||
|
||||
if isinstance(resources_list, list):
|
||||
for res in resources_list:
|
||||
if res.get("type") == "checkpoint":
|
||||
return res.get("modelVersionId")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_and_populate_checkpoint(
|
||||
recipe: Dict[str, Any],
|
||||
target_version_id: Optional[Any],
|
||||
target_hash: Optional[str],
|
||||
model_val: Optional[str],
|
||||
checkpoint_val: Optional[str]
|
||||
) -> bool:
|
||||
"""Find checkpoint metadata and populate it in the recipe."""
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
civitai_info = None
|
||||
|
||||
if target_version_id:
|
||||
civitai_info = await metadata_provider.get_model_version_info(str(target_version_id))
|
||||
elif target_hash:
|
||||
civitai_info = await metadata_provider.get_model_by_hash(target_hash)
|
||||
else:
|
||||
# Look for 'Model' which sometimes is the hash or name
|
||||
if model_val and len(model_val) == 10: # Likely a short hash
|
||||
civitai_info = await metadata_provider.get_model_by_hash(model_val)
|
||||
|
||||
if civitai_info and not (isinstance(civitai_info, tuple) and civitai_info[1] == "Model not found"):
|
||||
# If we already have a partial checkpoint, use it as base
|
||||
existing_cp = recipe.get("checkpoint")
|
||||
if existing_cp is None:
|
||||
existing_cp = {}
|
||||
checkpoint_data = await RecipeMetadataParser.populate_checkpoint_from_civitai(existing_cp, civitai_info)
|
||||
# 1. First, resolve base_model using full data before we format it away
|
||||
current_base_model = recipe.get("base_model")
|
||||
resolved_base_model = checkpoint_data.get("baseModel")
|
||||
if resolved_base_model:
|
||||
# Update if empty OR if it matches our generic prefix but is less specific
|
||||
is_generic = not current_base_model or current_base_model.lower() in ["flux", "sdxl", "sd15"]
|
||||
if is_generic and resolved_base_model != current_base_model:
|
||||
recipe["base_model"] = resolved_base_model
|
||||
|
||||
# 2. Format according to requirements: type, modelId, modelVersionId, modelName, modelVersionName
|
||||
formatted_checkpoint = {
|
||||
"type": "checkpoint",
|
||||
"modelId": checkpoint_data.get("modelId"),
|
||||
"modelVersionId": checkpoint_data.get("id") or checkpoint_data.get("modelVersionId"),
|
||||
"modelName": checkpoint_data.get("name"), # In base.py, 'name' is populated from civitai_data['model']['name']
|
||||
"modelVersionName": checkpoint_data.get("version") # In base.py, 'version' is populated from civitai_data['name']
|
||||
}
|
||||
# Remove None values
|
||||
recipe["checkpoint"] = {k: v for k, v in formatted_checkpoint.items() if v is not None}
|
||||
|
||||
return True
|
||||
else:
|
||||
# Fallback to name extraction if we don't already have one
|
||||
existing_cp = recipe.get("checkpoint")
|
||||
if not existing_cp or not existing_cp.get("modelName"):
|
||||
cp_name = checkpoint_val
|
||||
if cp_name:
|
||||
recipe["checkpoint"] = {
|
||||
"type": "checkpoint",
|
||||
"modelName": cp_name
|
||||
}
|
||||
return True
|
||||
|
||||
return False
|
||||
98
py/recipes/merger.py
Normal file
98
py/recipes/merger.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from typing import Any, Dict, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GenParamsMerger:
|
||||
"""Utility to merge generation parameters from multiple sources with priority."""
|
||||
|
||||
BLACKLISTED_KEYS = {
|
||||
"id", "url", "userId", "username", "createdAt", "updatedAt", "hash", "meta",
|
||||
"draft", "extra", "width", "height", "process", "quantity", "workflow",
|
||||
"baseModel", "resources", "disablePoi", "aspectRatio", "Created Date",
|
||||
"experimental", "civitaiResources", "civitai_resources", "Civitai resources",
|
||||
"modelVersionId", "modelId", "hashes", "Model", "Model hash", "checkpoint_hash",
|
||||
"checkpoint", "checksum", "model_checksum"
|
||||
}
|
||||
|
||||
NORMALIZATION_MAPPING = {
|
||||
# Civitai specific
|
||||
"cfgScale": "cfg_scale",
|
||||
"clipSkip": "clip_skip",
|
||||
"negativePrompt": "negative_prompt",
|
||||
# Case variations
|
||||
"Sampler": "sampler",
|
||||
"Steps": "steps",
|
||||
"Seed": "seed",
|
||||
"Size": "size",
|
||||
"Prompt": "prompt",
|
||||
"Negative prompt": "negative_prompt",
|
||||
"Cfg scale": "cfg_scale",
|
||||
"Clip skip": "clip_skip",
|
||||
"Denoising strength": "denoising_strength",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def merge(
|
||||
request_params: Optional[Dict[str, Any]] = None,
|
||||
civitai_meta: Optional[Dict[str, Any]] = None,
|
||||
embedded_metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Merge generation parameters from three sources.
|
||||
|
||||
Priority: request_params > civitai_meta > embedded_metadata
|
||||
|
||||
Args:
|
||||
request_params: Params provided directly in the import request
|
||||
civitai_meta: Params from Civitai Image API 'meta' field
|
||||
embedded_metadata: Params extracted from image EXIF/embedded metadata
|
||||
|
||||
Returns:
|
||||
Merged parameters dictionary
|
||||
"""
|
||||
result = {}
|
||||
|
||||
# 1. Start with embedded metadata (lowest priority)
|
||||
if embedded_metadata:
|
||||
# If it's a full recipe metadata, we use its gen_params
|
||||
if "gen_params" in embedded_metadata and isinstance(embedded_metadata["gen_params"], dict):
|
||||
GenParamsMerger._update_normalized(result, embedded_metadata["gen_params"])
|
||||
else:
|
||||
# Otherwise assume the dict itself contains gen_params
|
||||
GenParamsMerger._update_normalized(result, embedded_metadata)
|
||||
|
||||
# 2. Layer Civitai meta (medium priority)
|
||||
if civitai_meta:
|
||||
GenParamsMerger._update_normalized(result, civitai_meta)
|
||||
|
||||
# 3. Layer request params (highest priority)
|
||||
if request_params:
|
||||
GenParamsMerger._update_normalized(result, request_params)
|
||||
|
||||
# Filter out blacklisted keys and also the original camelCase keys if they were normalized
|
||||
final_result = {}
|
||||
for k, v in result.items():
|
||||
if k in GenParamsMerger.BLACKLISTED_KEYS:
|
||||
continue
|
||||
if k in GenParamsMerger.NORMALIZATION_MAPPING:
|
||||
continue
|
||||
final_result[k] = v
|
||||
|
||||
return final_result
|
||||
|
||||
@staticmethod
|
||||
def _update_normalized(target: Dict[str, Any], source: Dict[str, Any]) -> None:
|
||||
"""Update target dict with normalized keys from source."""
|
||||
for k, v in source.items():
|
||||
normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(k, k)
|
||||
target[normalized_key] = v
|
||||
# Also keep the original key for now if it's not the same,
|
||||
# so we can filter at the end or avoid losing it if it wasn't supposed to be renamed?
|
||||
# Actually, if we rename it, we should probably NOT keep both in 'target'
|
||||
# because we want to filter them out at the end anyway.
|
||||
if normalized_key != k:
|
||||
# If we are overwriting an existing snake_case key with a camelCase one's value,
|
||||
# that's fine because of the priority order of calls to _update_normalized.
|
||||
pass
|
||||
target[k] = v
|
||||
@@ -36,9 +36,6 @@ class ComfyMetadataParser(RecipeMetadataParser):
|
||||
# Find all LoraLoader nodes
|
||||
lora_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'LoraLoader'}
|
||||
|
||||
if not lora_nodes:
|
||||
return {"error": "No LoRA information found in this ComfyUI workflow", "loras": []}
|
||||
|
||||
# Process each LoraLoader node
|
||||
for node_id, node in lora_nodes.items():
|
||||
if 'inputs' not in node or 'lora_name' not in node['inputs']:
|
||||
|
||||
@@ -79,26 +79,8 @@ class BaseRecipeRoutes:
|
||||
return
|
||||
|
||||
app.on_startup.append(self.attach_dependencies)
|
||||
app.on_startup.append(self.prewarm_cache)
|
||||
self._startup_hooks_registered = True
|
||||
|
||||
async def prewarm_cache(self, app: web.Application | None = None) -> None:
|
||||
"""Pre-load recipe and LoRA caches on startup."""
|
||||
|
||||
try:
|
||||
await self.attach_dependencies(app)
|
||||
|
||||
if self.lora_scanner is not None:
|
||||
await self.lora_scanner.get_cached_data()
|
||||
hash_index = getattr(self.lora_scanner, "_hash_index", None)
|
||||
if hash_index is not None and hasattr(hash_index, "_hash_to_path"):
|
||||
_ = len(hash_index._hash_to_path)
|
||||
|
||||
if self.recipe_scanner is not None:
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=True)
|
||||
except Exception as exc:
|
||||
logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True)
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable]:
|
||||
"""Return a mapping of handler name to coroutine for registrar binding."""
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
|
||||
|
||||
@@ -61,6 +62,37 @@ class ModelPageView:
|
||||
self._settings = settings_service
|
||||
self._server_i18n = server_i18n
|
||||
self._logger = logger
|
||||
self._app_version = self._get_app_version()
|
||||
|
||||
def _get_app_version(self) -> str:
|
||||
version = "1.0.0"
|
||||
short_hash = "stable"
|
||||
try:
|
||||
import toml
|
||||
current_file = os.path.abspath(__file__)
|
||||
# Navigate up from py/routes/handlers/model_handlers.py to project root
|
||||
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
|
||||
pyproject_path = os.path.join(root_dir, 'pyproject.toml')
|
||||
|
||||
if os.path.exists(pyproject_path):
|
||||
with open(pyproject_path, 'r', encoding='utf-8') as f:
|
||||
data = toml.load(f)
|
||||
version = data.get('project', {}).get('version', '1.0.0').replace('v', '')
|
||||
|
||||
# Try to get git info for granular cache busting
|
||||
git_dir = os.path.join(root_dir, '.git')
|
||||
if os.path.exists(git_dir):
|
||||
try:
|
||||
import git
|
||||
repo = git.Repo(root_dir)
|
||||
short_hash = repo.head.commit.hexsha[:7]
|
||||
except Exception:
|
||||
# Fallback if git is not available or not a repo
|
||||
pass
|
||||
except Exception as e:
|
||||
self._logger.debug(f"Failed to read version info for cache busting: {e}")
|
||||
|
||||
return f"{version}-{short_hash}"
|
||||
|
||||
async def handle(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
@@ -96,6 +128,7 @@ class ModelPageView:
|
||||
"request": request,
|
||||
"folders": [],
|
||||
"t": self._server_i18n.get_translation,
|
||||
"version": self._app_version,
|
||||
}
|
||||
|
||||
if not is_initializing:
|
||||
@@ -128,9 +161,12 @@ class ModelListingHandler:
|
||||
self._logger = logger
|
||||
|
||||
async def get_models(self, request: web.Request) -> web.Response:
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
params = self._parse_common_params(request)
|
||||
result = await self._service.get_paginated_data(**params)
|
||||
|
||||
format_start = time.perf_counter()
|
||||
formatted_result = {
|
||||
"items": [await self._service.format_response(item) for item in result["items"]],
|
||||
"total": result["total"],
|
||||
@@ -138,6 +174,13 @@ class ModelListingHandler:
|
||||
"page_size": result["page_size"],
|
||||
"total_pages": result["total_pages"],
|
||||
}
|
||||
format_duration = time.perf_counter() - format_start
|
||||
|
||||
duration = time.perf_counter() - start_time
|
||||
self._logger.info(
|
||||
"Request for %s/list took %.3fs (formatting: %.3fs)",
|
||||
self._service.model_type, duration, format_duration
|
||||
)
|
||||
return web.json_response(formatted_result)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving %ss: %s", self._service.model_type, exc, exc_info=True)
|
||||
|
||||
@@ -5,6 +5,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional
|
||||
@@ -23,6 +24,11 @@ from ...services.recipes import (
|
||||
RecipeValidationError,
|
||||
)
|
||||
from ...services.metadata_service import get_default_metadata_provider
|
||||
from ...utils.civitai_utils import rewrite_preview_url
|
||||
from ...utils.exif_utils import ExifUtils
|
||||
from ...recipes.merger import GenParamsMerger
|
||||
from ...recipes.enrichment import RecipeEnricher
|
||||
from ...services.websocket_manager import ws_manager as default_ws_manager
|
||||
|
||||
Logger = logging.Logger
|
||||
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
|
||||
@@ -55,16 +61,25 @@ class RecipeHandlerSet:
|
||||
"delete_recipe": self.management.delete_recipe,
|
||||
"get_top_tags": self.query.get_top_tags,
|
||||
"get_base_models": self.query.get_base_models,
|
||||
"get_roots": self.query.get_roots,
|
||||
"get_folders": self.query.get_folders,
|
||||
"get_folder_tree": self.query.get_folder_tree,
|
||||
"get_unified_folder_tree": self.query.get_unified_folder_tree,
|
||||
"share_recipe": self.sharing.share_recipe,
|
||||
"download_shared_recipe": self.sharing.download_shared_recipe,
|
||||
"get_recipe_syntax": self.query.get_recipe_syntax,
|
||||
"update_recipe": self.management.update_recipe,
|
||||
"reconnect_lora": self.management.reconnect_lora,
|
||||
"find_duplicates": self.query.find_duplicates,
|
||||
"move_recipes_bulk": self.management.move_recipes_bulk,
|
||||
"bulk_delete": self.management.bulk_delete,
|
||||
"save_recipe_from_widget": self.management.save_recipe_from_widget,
|
||||
"get_recipes_for_lora": self.query.get_recipes_for_lora,
|
||||
"scan_recipes": self.query.scan_recipes,
|
||||
"move_recipe": self.management.move_recipe,
|
||||
"repair_recipes": self.management.repair_recipes,
|
||||
"repair_recipe": self.management.repair_recipe,
|
||||
"get_repair_progress": self.management.get_repair_progress,
|
||||
}
|
||||
|
||||
|
||||
@@ -148,12 +163,15 @@ class RecipeListingHandler:
|
||||
page_size = int(request.query.get("page_size", "20"))
|
||||
sort_by = request.query.get("sort_by", "date")
|
||||
search = request.query.get("search")
|
||||
folder = request.query.get("folder")
|
||||
recursive = request.query.get("recursive", "true").lower() == "true"
|
||||
|
||||
search_options = {
|
||||
"title": request.query.get("search_title", "true").lower() == "true",
|
||||
"tags": request.query.get("search_tags", "true").lower() == "true",
|
||||
"lora_name": request.query.get("search_lora_name", "true").lower() == "true",
|
||||
"lora_model": request.query.get("search_lora_model", "true").lower() == "true",
|
||||
"prompt": request.query.get("search_prompt", "true").lower() == "true",
|
||||
}
|
||||
|
||||
filters: Dict[str, Any] = {}
|
||||
@@ -161,6 +179,9 @@ class RecipeListingHandler:
|
||||
if base_models:
|
||||
filters["base_model"] = base_models.split(",")
|
||||
|
||||
if request.query.get("favorite", "false").lower() == "true":
|
||||
filters["favorite"] = True
|
||||
|
||||
tag_filters: Dict[str, str] = {}
|
||||
legacy_tags = request.query.get("tags")
|
||||
if legacy_tags:
|
||||
@@ -192,6 +213,8 @@ class RecipeListingHandler:
|
||||
filters=filters,
|
||||
search_options=search_options,
|
||||
lora_hash=lora_hash,
|
||||
folder=folder,
|
||||
recursive=recursive,
|
||||
)
|
||||
|
||||
for item in result.get("items", []):
|
||||
@@ -298,6 +321,58 @@ class RecipeQueryHandler:
|
||||
self._logger.error("Error retrieving base models: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_roots(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
roots = [recipe_scanner.recipes_dir] if recipe_scanner.recipes_dir else []
|
||||
return web.json_response({"success": True, "roots": roots})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipe roots: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_folders(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
folders = await recipe_scanner.get_folders()
|
||||
return web.json_response({"success": True, "folders": folders})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipe folders: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_folder_tree(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
folder_tree = await recipe_scanner.get_folder_tree()
|
||||
return web.json_response({"success": True, "tree": folder_tree})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipe folder tree: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_unified_folder_tree(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
folder_tree = await recipe_scanner.get_folder_tree()
|
||||
return web.json_response({"success": True, "tree": folder_tree})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving unified recipe folder tree: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
@@ -410,6 +485,7 @@ class RecipeManagementHandler:
|
||||
analysis_service: RecipeAnalysisService,
|
||||
downloader_factory,
|
||||
civitai_client_getter: CivitaiClientGetter,
|
||||
ws_manager=default_ws_manager,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
@@ -418,6 +494,7 @@ class RecipeManagementHandler:
|
||||
self._analysis_service = analysis_service
|
||||
self._downloader_factory = downloader_factory
|
||||
self._civitai_client_getter = civitai_client_getter
|
||||
self._ws_manager = ws_manager
|
||||
|
||||
async def save_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
@@ -436,6 +513,7 @@ class RecipeManagementHandler:
|
||||
name=payload["name"],
|
||||
tags=payload["tags"],
|
||||
metadata=payload["metadata"],
|
||||
extension=payload.get("extension"),
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
@@ -444,17 +522,84 @@ class RecipeManagementHandler:
|
||||
self._logger.error("Error saving recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def repair_recipes(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503)
|
||||
|
||||
# Check if already running
|
||||
if self._ws_manager.get_recipe_repair_progress():
|
||||
return web.json_response({"success": False, "error": "Recipe repair already in progress"}, status=409)
|
||||
|
||||
async def progress_callback(data):
|
||||
await self._ws_manager.broadcast_recipe_repair_progress(data)
|
||||
|
||||
# Run in background to avoid timeout
|
||||
async def run_repair():
|
||||
try:
|
||||
await recipe_scanner.repair_all_recipes(
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error in recipe repair task: {e}", exc_info=True)
|
||||
await self._ws_manager.broadcast_recipe_repair_progress({
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
})
|
||||
finally:
|
||||
# Keep the final status for a while so the UI can see it
|
||||
await asyncio.sleep(5)
|
||||
self._ws_manager.cleanup_recipe_repair_progress()
|
||||
|
||||
asyncio.create_task(run_repair())
|
||||
|
||||
return web.json_response({"success": True, "message": "Recipe repair started"})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error starting recipe repair: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def repair_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503)
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
result = await recipe_scanner.repair_recipe_by_id(recipe_id)
|
||||
return web.json_response(result)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error repairing single recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_repair_progress(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
progress = self._ws_manager.get_recipe_repair_progress()
|
||||
if progress:
|
||||
return web.json_response({"success": True, "progress": progress})
|
||||
return web.json_response({"success": False, "message": "No repair in progress"}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error getting repair progress: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
async def import_remote_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
|
||||
# 1. Parse Parameters
|
||||
params = request.rel_url.query
|
||||
image_url = params.get("image_url")
|
||||
name = params.get("name")
|
||||
resources_raw = params.get("resources")
|
||||
|
||||
if not image_url:
|
||||
raise RecipeValidationError("Missing required field: image_url")
|
||||
if not name:
|
||||
@@ -463,27 +608,93 @@ class RecipeManagementHandler:
|
||||
raise RecipeValidationError("Missing required field: resources")
|
||||
|
||||
checkpoint_entry, lora_entries = self._parse_resources_payload(resources_raw)
|
||||
gen_params = self._parse_gen_params(params.get("gen_params"))
|
||||
gen_params_request = self._parse_gen_params(params.get("gen_params"))
|
||||
|
||||
# 2. Initial Metadata Construction
|
||||
metadata: Dict[str, Any] = {
|
||||
"base_model": params.get("base_model", "") or "",
|
||||
"loras": lora_entries,
|
||||
"gen_params": gen_params_request or {},
|
||||
"source_url": image_url
|
||||
}
|
||||
|
||||
source_path = params.get("source_path")
|
||||
if source_path:
|
||||
metadata["source_path"] = source_path
|
||||
if gen_params is not None:
|
||||
metadata["gen_params"] = gen_params
|
||||
|
||||
# Checkpoint handling
|
||||
if checkpoint_entry:
|
||||
metadata["checkpoint"] = checkpoint_entry
|
||||
gen_params_ref = metadata.setdefault("gen_params", {})
|
||||
if "checkpoint" not in gen_params_ref:
|
||||
gen_params_ref["checkpoint"] = checkpoint_entry
|
||||
base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry)
|
||||
if base_model_from_metadata:
|
||||
metadata["base_model"] = base_model_from_metadata
|
||||
# Ensure checkpoint is also in gen_params for consistency if needed by enricher?
|
||||
# Actually enricher looks at metadata['checkpoint'], so this is fine.
|
||||
|
||||
# Try to resolve base model from checkpoint if not explicitly provided
|
||||
if not metadata["base_model"]:
|
||||
base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry)
|
||||
if base_model_from_metadata:
|
||||
metadata["base_model"] = base_model_from_metadata
|
||||
|
||||
tags = self._parse_tags(params.get("tags"))
|
||||
image_bytes = await self._download_image_bytes(image_url)
|
||||
|
||||
# 3. Download Image
|
||||
image_bytes, extension, civitai_meta_from_download = await self._download_remote_media(image_url)
|
||||
|
||||
# 4. Extract Embedded Metadata
|
||||
# Note: We still extract this here because Enricher currently expects 'gen_params' to already be populated
|
||||
# with embedded data if we want it to merge it.
|
||||
# However, logic in Enricher merges: request > civitai > embedded.
|
||||
# So we should gather embedded params and put them into the recipe's gen_params (as initial state)
|
||||
# OR pass them to enricher to handle?
|
||||
# The interface of Enricher.enrich_recipe takes `recipe` (with gen_params) and `request_params`.
|
||||
# So let's extract embedded and put it into recipe['gen_params'] but careful not to overwrite request params.
|
||||
# Actually, `GenParamsMerger` which `Enricher` uses handles 3 layers.
|
||||
# But `Enricher` interface is: recipe['gen_params'] (as embedded) + request_params + civitai (fetched internally).
|
||||
# Wait, `Enricher` fetches Civitai info internally based on URL.
|
||||
# `civitai_meta_from_download` is returned by `_download_remote_media` which might be useful if URL didn't have ID.
|
||||
|
||||
# Let's extract embedded metadata first
|
||||
embedded_gen_params = {}
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_img:
|
||||
temp_img.write(image_bytes)
|
||||
temp_img_path = temp_img.name
|
||||
|
||||
try:
|
||||
raw_embedded = ExifUtils.extract_image_metadata(temp_img_path)
|
||||
if raw_embedded:
|
||||
parser = self._analysis_service._recipe_parser_factory.create_parser(raw_embedded)
|
||||
if parser:
|
||||
parsed_embedded = await parser.parse_metadata(raw_embedded, recipe_scanner=recipe_scanner)
|
||||
if parsed_embedded and "gen_params" in parsed_embedded:
|
||||
embedded_gen_params = parsed_embedded["gen_params"]
|
||||
else:
|
||||
embedded_gen_params = {"raw_metadata": raw_embedded}
|
||||
finally:
|
||||
if os.path.exists(temp_img_path):
|
||||
os.unlink(temp_img_path)
|
||||
except Exception as exc:
|
||||
self._logger.warning("Failed to extract embedded metadata during import: %s", exc)
|
||||
|
||||
# Pre-populate gen_params with embedded data so Enricher treats it as the "base" layer
|
||||
if embedded_gen_params:
|
||||
# Merge embedded into existing gen_params (which currently only has request params if any)
|
||||
# But wait, we want request params to override everything.
|
||||
# So we should set recipe['gen_params'] = embedded, and pass request params to enricher.
|
||||
metadata["gen_params"] = embedded_gen_params
|
||||
|
||||
# 5. Enrich with unified logic
|
||||
# This will fetch Civitai info (if URL matches) and merge: request > civitai > embedded
|
||||
civitai_client = self._civitai_client_getter()
|
||||
await RecipeEnricher.enrich_recipe(
|
||||
recipe=metadata,
|
||||
civitai_client=civitai_client,
|
||||
request_params=gen_params_request # Pass explicit request params here to override
|
||||
)
|
||||
|
||||
# If we got civitai_meta from download but Enricher didn't fetch it (e.g. not a civitai URL or failed),
|
||||
# we might want to manually merge it?
|
||||
# But usually `import_remote_recipe` is used with Civitai URLs.
|
||||
# For now, relying on Enricher's internal fetch is consistent with repair.
|
||||
|
||||
result = await self._persistence_service.save_recipe(
|
||||
recipe_scanner=recipe_scanner,
|
||||
@@ -492,6 +703,7 @@ class RecipeManagementHandler:
|
||||
name=name,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
extension=extension,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
@@ -541,6 +753,64 @@ class RecipeManagementHandler:
|
||||
self._logger.error("Error updating recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def move_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
recipe_id = data.get("recipe_id")
|
||||
target_path = data.get("target_path")
|
||||
if not recipe_id or not target_path:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "recipe_id and target_path are required"}, status=400
|
||||
)
|
||||
|
||||
result = await self._persistence_service.move_recipe(
|
||||
recipe_scanner=recipe_scanner,
|
||||
recipe_id=str(recipe_id),
|
||||
target_path=str(target_path),
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error moving recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def move_recipes_bulk(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
recipe_ids = data.get("recipe_ids") or []
|
||||
target_path = data.get("target_path")
|
||||
if not recipe_ids or not target_path:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "recipe_ids and target_path are required"}, status=400
|
||||
)
|
||||
|
||||
result = await self._persistence_service.move_recipes_bulk(
|
||||
recipe_scanner=recipe_scanner,
|
||||
recipe_ids=recipe_ids,
|
||||
target_path=str(target_path),
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error moving recipes in bulk: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def reconnect_lora(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
@@ -622,6 +892,7 @@ class RecipeManagementHandler:
|
||||
name: Optional[str] = None
|
||||
tags: list[str] = []
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
extension: Optional[str] = None
|
||||
|
||||
while True:
|
||||
field = await reader.next()
|
||||
@@ -652,6 +923,8 @@ class RecipeManagementHandler:
|
||||
metadata = json.loads(metadata_text)
|
||||
except Exception:
|
||||
metadata = {}
|
||||
elif field.name == "extension":
|
||||
extension = await field.text()
|
||||
|
||||
return {
|
||||
"image_bytes": image_bytes,
|
||||
@@ -659,6 +932,7 @@ class RecipeManagementHandler:
|
||||
"name": name,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
"extension": extension,
|
||||
}
|
||||
|
||||
def _parse_tags(self, tag_text: Optional[str]) -> list[str]:
|
||||
@@ -729,7 +1003,7 @@ class RecipeManagementHandler:
|
||||
"exclude": False,
|
||||
}
|
||||
|
||||
async def _download_image_bytes(self, image_url: str) -> bytes:
|
||||
async def _download_remote_media(self, image_url: str) -> tuple[bytes, str]:
|
||||
civitai_client = self._civitai_client_getter()
|
||||
downloader = await self._downloader_factory()
|
||||
temp_path = None
|
||||
@@ -744,15 +1018,31 @@ class RecipeManagementHandler:
|
||||
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
||||
if not image_info:
|
||||
raise RecipeDownloadError("Failed to fetch image information from Civitai")
|
||||
download_url = image_info.get("url")
|
||||
if not download_url:
|
||||
|
||||
media_url = image_info.get("url")
|
||||
if not media_url:
|
||||
raise RecipeDownloadError("No image URL found in Civitai response")
|
||||
|
||||
# Use optimized preview URLs if possible
|
||||
media_type = image_info.get("type")
|
||||
rewritten_url, _ = rewrite_preview_url(media_url, media_type=media_type)
|
||||
if rewritten_url:
|
||||
download_url = rewritten_url
|
||||
else:
|
||||
download_url = media_url
|
||||
|
||||
success, result = await downloader.download_file(download_url, temp_path, use_auth=False)
|
||||
if not success:
|
||||
raise RecipeDownloadError(f"Failed to download image: {result}")
|
||||
|
||||
# Extract extension from URL
|
||||
url_path = download_url.split('?')[0].split('#')[0]
|
||||
extension = os.path.splitext(url_path)[1].lower()
|
||||
if not extension:
|
||||
extension = ".webp" # Default to webp if unknown
|
||||
|
||||
with open(temp_path, "rb") as file_obj:
|
||||
return file_obj.read()
|
||||
return file_obj.read(), extension, image_info.get("meta") if civitai_match and image_info else None
|
||||
except RecipeDownloadError:
|
||||
raise
|
||||
except RecipeValidationError:
|
||||
@@ -766,6 +1056,7 @@ class RecipeManagementHandler:
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def _safe_int(self, value: Any) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
|
||||
@@ -27,16 +27,25 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/base-models", "get_base_models"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/roots", "get_roots"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/folders", "get_folders"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/folder-tree", "get_folder_tree"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/unified-folder-tree", "get_unified_folder_tree"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/syntax", "get_recipe_syntax"),
|
||||
RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipe/move", "move_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/move-bulk", "move_recipes_bulk"),
|
||||
RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/repair", "repair_recipes"),
|
||||
RouteDefinition("POST", "/api/lm/recipe/{recipe_id}/repair", "repair_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/repair-progress", "get_repair_progress"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import asyncio
|
||||
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
from ..utils.models import BaseModelMetadata
|
||||
@@ -80,13 +81,20 @@ class BaseModelService(ABC):
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""Get paginated and filtered model data"""
|
||||
overall_start = time.perf_counter()
|
||||
|
||||
sort_params = self.cache_repository.parse_sort(sort_by)
|
||||
if sort_params.key == 'usage':
|
||||
sorted_data = await self._fetch_with_usage_sort(sort_params)
|
||||
else:
|
||||
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
||||
fetch_duration = time.perf_counter() - t0
|
||||
initial_count = len(sorted_data)
|
||||
|
||||
t1 = time.perf_counter()
|
||||
if hash_filters:
|
||||
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
|
||||
else:
|
||||
@@ -116,17 +124,25 @@ class BaseModelService(ABC):
|
||||
|
||||
if allow_selling_generated_content is not None:
|
||||
filtered_data = await self._apply_allow_selling_filter(filtered_data, allow_selling_generated_content)
|
||||
filter_duration = time.perf_counter() - t1
|
||||
post_filter_count = len(filtered_data)
|
||||
|
||||
annotated_for_filter: Optional[List[Dict]] = None
|
||||
t2 = time.perf_counter()
|
||||
if update_available_only:
|
||||
annotated_for_filter = await self._annotate_update_flags(filtered_data)
|
||||
filtered_data = [
|
||||
item for item in annotated_for_filter
|
||||
if item.get('update_available')
|
||||
]
|
||||
update_filter_duration = time.perf_counter() - t2
|
||||
final_count = len(filtered_data)
|
||||
|
||||
t3 = time.perf_counter()
|
||||
paginated = self._paginate(filtered_data, page, page_size)
|
||||
pagination_duration = time.perf_counter() - t3
|
||||
|
||||
t4 = time.perf_counter()
|
||||
if update_available_only:
|
||||
# Items already include update flags thanks to the pre-filter annotation.
|
||||
paginated['items'] = list(paginated['items'])
|
||||
@@ -134,6 +150,16 @@ class BaseModelService(ABC):
|
||||
paginated['items'] = await self._annotate_update_flags(
|
||||
paginated['items'],
|
||||
)
|
||||
annotate_duration = time.perf_counter() - t4
|
||||
|
||||
overall_duration = time.perf_counter() - overall_start
|
||||
logger.info(
|
||||
"%s.get_paginated_data took %.3fs (fetch: %.3fs, filter: %.3fs, update_filter: %.3fs, pagination: %.3fs, annotate: %.3fs). "
|
||||
"Counts: initial=%d, post_filter=%d, final=%d",
|
||||
self.__class__.__name__, overall_duration, fetch_duration, filter_duration,
|
||||
update_filter_duration, pagination_duration, annotate_duration,
|
||||
initial_count, post_filter_count, final_count
|
||||
)
|
||||
return paginated
|
||||
|
||||
async def _fetch_with_usage_sort(self, sort_params):
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from operator import itemgetter
|
||||
@@ -215,24 +219,25 @@ class ModelCache:
|
||||
|
||||
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
||||
"""Sort data by sort_key and order"""
|
||||
start_time = time.perf_counter()
|
||||
reverse = (order == 'desc')
|
||||
if sort_key == 'name':
|
||||
# Natural sort by configured display name, case-insensitive
|
||||
return natsorted(
|
||||
result = natsorted(
|
||||
data,
|
||||
key=lambda x: self._get_display_name(x).lower(),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'date':
|
||||
# Sort by modified timestamp
|
||||
return sorted(
|
||||
result = sorted(
|
||||
data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'size':
|
||||
# Sort by file size
|
||||
return sorted(
|
||||
result = sorted(
|
||||
data,
|
||||
key=itemgetter('size'),
|
||||
reverse=reverse
|
||||
@@ -249,16 +254,28 @@ class ModelCache:
|
||||
)
|
||||
else:
|
||||
# Fallback: no sort
|
||||
return list(data)
|
||||
result = list(data)
|
||||
|
||||
duration = time.perf_counter() - start_time
|
||||
if duration > 0.05:
|
||||
logger.info("ModelCache._sort_data(%s, %s) for %d items took %.3fs", sort_key, order, len(data), duration)
|
||||
return result
|
||||
|
||||
async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]:
|
||||
"""Get sorted data by sort_key and order, using cache if possible"""
|
||||
async with self._lock:
|
||||
if (sort_key, order) == self._last_sort:
|
||||
return self._last_sorted_data
|
||||
|
||||
start_time = time.perf_counter()
|
||||
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||
self._last_sort = (sort_key, order)
|
||||
self._last_sorted_data = sorted_data
|
||||
|
||||
duration = time.perf_counter() - start_time
|
||||
if duration > 0.1:
|
||||
logger.debug("ModelCache.get_sorted_data(%s, %s) took %.3fs", sort_key, order, duration)
|
||||
|
||||
return sorted_data
|
||||
|
||||
async def update_name_display_mode(self, display_mode: str) -> None:
|
||||
|
||||
@@ -5,6 +5,10 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match as default_fuzzy_match
|
||||
import time
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_CIVITAI_MODEL_TYPE = "LORA"
|
||||
@@ -115,22 +119,33 @@ class ModelFilterSet:
|
||||
|
||||
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
|
||||
"""Return items that satisfy the provided criteria."""
|
||||
overall_start = time.perf_counter()
|
||||
items = list(data)
|
||||
initial_count = len(items)
|
||||
|
||||
if self._settings.get("show_only_sfw", False):
|
||||
t0 = time.perf_counter()
|
||||
threshold = self._nsfw_levels.get("R", 0)
|
||||
items = [
|
||||
item for item in items
|
||||
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold
|
||||
]
|
||||
sfw_duration = time.perf_counter() - t0
|
||||
else:
|
||||
sfw_duration = 0
|
||||
|
||||
favorites_duration = 0
|
||||
if criteria.favorites_only:
|
||||
t0 = time.perf_counter()
|
||||
items = [item for item in items if item.get("favorite", False)]
|
||||
favorites_duration = time.perf_counter() - t0
|
||||
|
||||
folder_duration = 0
|
||||
folder = criteria.folder
|
||||
options = criteria.search_options or {}
|
||||
recursive = bool(options.get("recursive", True))
|
||||
if folder is not None:
|
||||
t0 = time.perf_counter()
|
||||
if recursive:
|
||||
if folder:
|
||||
folder_with_sep = f"{folder}/"
|
||||
@@ -140,51 +155,82 @@ class ModelFilterSet:
|
||||
]
|
||||
else:
|
||||
items = [item for item in items if item.get("folder") == folder]
|
||||
folder_duration = time.perf_counter() - t0
|
||||
|
||||
base_models_duration = 0
|
||||
base_models = criteria.base_models or []
|
||||
if base_models:
|
||||
t0 = time.perf_counter()
|
||||
base_model_set = set(base_models)
|
||||
items = [item for item in items if item.get("base_model") in base_model_set]
|
||||
base_models_duration = time.perf_counter() - t0
|
||||
|
||||
tags_duration = 0
|
||||
tag_filters = criteria.tags or {}
|
||||
include_tags = set()
|
||||
exclude_tags = set()
|
||||
if isinstance(tag_filters, dict):
|
||||
for tag, state in tag_filters.items():
|
||||
if not tag:
|
||||
continue
|
||||
if state == "exclude":
|
||||
exclude_tags.add(tag)
|
||||
else:
|
||||
include_tags.add(tag)
|
||||
else:
|
||||
include_tags = {tag for tag in tag_filters if tag}
|
||||
if tag_filters:
|
||||
t0 = time.perf_counter()
|
||||
include_tags = set()
|
||||
exclude_tags = set()
|
||||
if isinstance(tag_filters, dict):
|
||||
for tag, state in tag_filters.items():
|
||||
if not tag:
|
||||
continue
|
||||
if state == "exclude":
|
||||
exclude_tags.add(tag)
|
||||
else:
|
||||
include_tags.add(tag)
|
||||
else:
|
||||
include_tags = {tag for tag in tag_filters if tag}
|
||||
|
||||
if include_tags:
|
||||
items = [
|
||||
item for item in items
|
||||
if any(tag in include_tags for tag in (item.get("tags", []) or []))
|
||||
]
|
||||
if include_tags:
|
||||
def matches_include(item_tags):
|
||||
if not item_tags and "__no_tags__" in include_tags:
|
||||
return True
|
||||
return any(tag in include_tags for tag in (item_tags or []))
|
||||
|
||||
if exclude_tags:
|
||||
items = [
|
||||
item for item in items
|
||||
if not any(tag in exclude_tags for tag in (item.get("tags", []) or []))
|
||||
]
|
||||
items = [
|
||||
item for item in items
|
||||
if matches_include(item.get("tags"))
|
||||
]
|
||||
|
||||
if exclude_tags:
|
||||
def matches_exclude(item_tags):
|
||||
if not item_tags and "__no_tags__" in exclude_tags:
|
||||
return True
|
||||
return any(tag in exclude_tags for tag in (item_tags or []))
|
||||
|
||||
items = [
|
||||
item for item in items
|
||||
if not matches_exclude(item.get("tags"))
|
||||
]
|
||||
tags_duration = time.perf_counter() - t0
|
||||
|
||||
model_types_duration = 0
|
||||
model_types = criteria.model_types or []
|
||||
normalized_model_types = {
|
||||
model_type for model_type in (
|
||||
normalize_civitai_model_type(value) for value in model_types
|
||||
)
|
||||
if model_type
|
||||
}
|
||||
if normalized_model_types:
|
||||
items = [
|
||||
item for item in items
|
||||
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
|
||||
]
|
||||
if model_types:
|
||||
t0 = time.perf_counter()
|
||||
normalized_model_types = {
|
||||
model_type for model_type in (
|
||||
normalize_civitai_model_type(value) for value in model_types
|
||||
)
|
||||
if model_type
|
||||
}
|
||||
if normalized_model_types:
|
||||
items = [
|
||||
item for item in items
|
||||
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
|
||||
]
|
||||
model_types_duration = time.perf_counter() - t0
|
||||
|
||||
duration = time.perf_counter() - overall_start
|
||||
if duration > 0.1: # Only log if it's potentially slow
|
||||
logger.info(
|
||||
"ModelFilterSet.apply took %.3fs (sfw: %.3fs, fav: %.3fs, folder: %.3fs, base: %.3fs, tags: %.3fs, types: %.3fs). "
|
||||
"Count: %d -> %d",
|
||||
duration, sfw_duration, favorites_duration, folder_duration,
|
||||
base_models_duration, tags_duration, model_types_duration,
|
||||
initial_count, len(items)
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
|
||||
@@ -7,12 +7,18 @@ from natsort import natsorted
|
||||
@dataclass
|
||||
class RecipeCache:
|
||||
"""Cache structure for Recipe data"""
|
||||
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
folders: List[str] | None = None
|
||||
folder_tree: Dict | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
# Normalize optional metadata containers
|
||||
self.folders = self.folders or []
|
||||
self.folder_tree = self.folder_tree or {}
|
||||
|
||||
async def resort(self, name_only: bool = False):
|
||||
"""Resort all cached data views"""
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import os
|
||||
import logging
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from ..config import config
|
||||
@@ -14,6 +16,9 @@ from .recipes.errors import RecipeNotFoundError
|
||||
from ..utils.utils import calculate_recipe_fingerprint, fuzzy_match
|
||||
from natsort import natsorted
|
||||
import sys
|
||||
import re
|
||||
from ..recipes.merger import GenParamsMerger
|
||||
from ..recipes.enrichment import RecipeEnricher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,6 +57,8 @@ class RecipeScanner:
|
||||
cls._instance._civitai_client = None # Will be lazily initialized
|
||||
return cls._instance
|
||||
|
||||
REPAIR_VERSION = 3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lora_scanner: Optional[LoraScanner] = None,
|
||||
@@ -64,6 +71,7 @@ class RecipeScanner:
|
||||
self._initialization_task: Optional[asyncio.Task] = None
|
||||
self._is_initializing = False
|
||||
self._mutation_lock = asyncio.Lock()
|
||||
self._post_scan_task: Optional[asyncio.Task] = None
|
||||
self._resort_tasks: Set[asyncio.Task] = set()
|
||||
if lora_scanner:
|
||||
self._lora_scanner = lora_scanner
|
||||
@@ -84,6 +92,10 @@ class RecipeScanner:
|
||||
task.cancel()
|
||||
self._resort_tasks.clear()
|
||||
|
||||
if self._post_scan_task and not self._post_scan_task.done():
|
||||
self._post_scan_task.cancel()
|
||||
self._post_scan_task = None
|
||||
|
||||
self._cache = None
|
||||
self._initialization_task = None
|
||||
self._is_initializing = False
|
||||
@@ -102,19 +114,223 @@ class RecipeScanner:
|
||||
self._civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
return self._civitai_client
|
||||
|
||||
async def repair_all_recipes(
|
||||
self,
|
||||
progress_callback: Optional[Callable[[Dict], Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Repair all recipes by enrichment with Civitai and embedded metadata.
|
||||
|
||||
Args:
|
||||
persistence_service: Service for saving updated recipes
|
||||
progress_callback: Optional callback for progress updates
|
||||
|
||||
Returns:
|
||||
Dict summary of repair results
|
||||
"""
|
||||
async with self._mutation_lock:
|
||||
cache = await self.get_cached_data()
|
||||
all_recipes = list(cache.raw_data)
|
||||
total = len(all_recipes)
|
||||
repaired_count = 0
|
||||
skipped_count = 0
|
||||
errors_count = 0
|
||||
|
||||
civitai_client = await self._get_civitai_client()
|
||||
|
||||
for i, recipe in enumerate(all_recipes):
|
||||
try:
|
||||
# Report progress
|
||||
if progress_callback:
|
||||
await progress_callback({
|
||||
"status": "processing",
|
||||
"current": i + 1,
|
||||
"total": total,
|
||||
"recipe_name": recipe.get("name", "Unknown")
|
||||
})
|
||||
|
||||
if await self._repair_single_recipe(recipe, civitai_client):
|
||||
repaired_count += 1
|
||||
else:
|
||||
skipped_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error repairing recipe {recipe.get('file_path')}: {e}")
|
||||
errors_count += 1
|
||||
|
||||
# Final progress update
|
||||
if progress_callback:
|
||||
await progress_callback({
|
||||
"status": "completed",
|
||||
"repaired": repaired_count,
|
||||
"skipped": skipped_count,
|
||||
"errors": errors_count,
|
||||
"total": total
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"repaired": repaired_count,
|
||||
"skipped": skipped_count,
|
||||
"errors": errors_count,
|
||||
"total": total
|
||||
}
|
||||
|
||||
async def repair_recipe_by_id(self, recipe_id: str) -> Dict[str, Any]:
|
||||
"""Repair a single recipe by its ID.
|
||||
|
||||
Args:
|
||||
recipe_id: ID of the recipe to repair
|
||||
|
||||
Returns:
|
||||
Dict summary of repair result
|
||||
"""
|
||||
async with self._mutation_lock:
|
||||
# Get raw recipe from cache directly to avoid formatted fields
|
||||
cache = await self.get_cached_data()
|
||||
recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None)
|
||||
|
||||
if not recipe:
|
||||
raise RecipeNotFoundError(f"Recipe {recipe_id} not found")
|
||||
|
||||
civitai_client = await self._get_civitai_client()
|
||||
success = await self._repair_single_recipe(recipe, civitai_client)
|
||||
|
||||
# If successfully repaired, we should return the formatted version for the UI
|
||||
return {
|
||||
"success": True,
|
||||
"repaired": 1 if success else 0,
|
||||
"skipped": 0 if success else 1,
|
||||
"recipe": await self.get_recipe_by_id(recipe_id) if success else recipe
|
||||
}
|
||||
|
||||
async def _repair_single_recipe(self, recipe: Dict[str, Any], civitai_client: Any) -> bool:
|
||||
"""Internal helper to repair a single recipe object.
|
||||
|
||||
Args:
|
||||
recipe: The recipe dictionary to repair (modified in-place)
|
||||
civitai_client: Authenticated Civitai client
|
||||
|
||||
Returns:
|
||||
bool: True if recipe was repaired or updated, False if skipped
|
||||
"""
|
||||
# 1. Skip if already at latest repair version
|
||||
if recipe.get("repair_version", 0) >= self.REPAIR_VERSION:
|
||||
return False
|
||||
|
||||
# 2. Identification: Is repair needed?
|
||||
has_checkpoint = "checkpoint" in recipe and recipe["checkpoint"] and recipe["checkpoint"].get("name")
|
||||
gen_params = recipe.get("gen_params", {})
|
||||
has_prompt = bool(gen_params.get("prompt"))
|
||||
|
||||
needs_repair = not has_checkpoint or not has_prompt
|
||||
|
||||
if not needs_repair:
|
||||
# Even if no repair needed, we mark it with version if it was processed
|
||||
# Always update and save because if we are here, the version is old (checked in step 1)
|
||||
recipe["repair_version"] = self.REPAIR_VERSION
|
||||
await self._save_recipe_persistently(recipe)
|
||||
return True
|
||||
|
||||
# 3. Use Enricher to repair/enrich
|
||||
try:
|
||||
updated = await RecipeEnricher.enrich_recipe(recipe, civitai_client)
|
||||
except Exception as e:
|
||||
logger.error(f"Error enriching recipe {recipe.get('id')}: {e}")
|
||||
updated = False
|
||||
|
||||
# 4. Mark version and save if updated or just marking version
|
||||
# If we updated it, OR if the version is old (which we know it is if we are here), save it.
|
||||
# Actually, if we are here and updated is False, it means we tried to repair but couldn't/didn't need to.
|
||||
# But we still want to mark it as processed so we don't try again until version bump.
|
||||
if updated or recipe.get("repair_version", 0) < self.REPAIR_VERSION:
|
||||
recipe["repair_version"] = self.REPAIR_VERSION
|
||||
await self._save_recipe_persistently(recipe)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _save_recipe_persistently(self, recipe: Dict[str, Any]) -> bool:
|
||||
"""Helper to save a recipe to both JSON and EXIF metadata."""
|
||||
recipe_id = recipe.get("id")
|
||||
if not recipe_id:
|
||||
return False
|
||||
|
||||
recipe_json_path = await self.get_recipe_json_path(recipe_id)
|
||||
if not recipe_json_path:
|
||||
return False
|
||||
|
||||
try:
|
||||
# 1. Sanitize for storage (remove runtime convenience fields)
|
||||
clean_recipe = self._sanitize_recipe_for_storage(recipe)
|
||||
|
||||
# 2. Update the original dictionary so that we persist the clean version
|
||||
# globally if needed, effectively overwriting it in-place.
|
||||
recipe.clear()
|
||||
recipe.update(clean_recipe)
|
||||
|
||||
# 3. Save JSON
|
||||
with open(recipe_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(recipe, f, indent=4, ensure_ascii=False)
|
||||
|
||||
# 4. Update EXIF if image exists
|
||||
image_path = recipe.get('file_path')
|
||||
if image_path and os.path.exists(image_path):
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
ExifUtils.append_recipe_metadata(image_path, recipe)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error persisting recipe {recipe_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _sanitize_recipe_for_storage(self, recipe: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create a clean copy of the recipe without runtime convenience fields."""
|
||||
import copy
|
||||
clean = copy.deepcopy(recipe)
|
||||
|
||||
# 0. Clean top-level runtime fields
|
||||
for key in ("file_url", "created_date_formatted", "modified_formatted"):
|
||||
clean.pop(key, None)
|
||||
|
||||
# 1. Clean LORAs
|
||||
if "loras" in clean and isinstance(clean["loras"], list):
|
||||
for lora in clean["loras"]:
|
||||
# Fields to remove (runtime only)
|
||||
for key in ("inLibrary", "preview_url", "localPath"):
|
||||
lora.pop(key, None)
|
||||
|
||||
# Normalize weight/strength if mapping is desired (standard in persistence_service)
|
||||
if "weight" in lora and "strength" not in lora:
|
||||
lora["strength"] = float(lora.pop("weight"))
|
||||
|
||||
# 2. Clean Checkpoint
|
||||
if "checkpoint" in clean and isinstance(clean["checkpoint"], dict):
|
||||
cp = clean["checkpoint"]
|
||||
# Fields to remove (runtime only)
|
||||
for key in ("inLibrary", "localPath", "preview_url", "thumbnailUrl", "size", "downloadUrl"):
|
||||
cp.pop(key, None)
|
||||
|
||||
return clean
|
||||
|
||||
async def initialize_in_background(self) -> None:
|
||||
"""Initialize cache in background using thread pool"""
|
||||
try:
|
||||
await self._wait_for_lora_scanner()
|
||||
|
||||
# Set initial empty cache to avoid None reference errors
|
||||
if self._cache is None:
|
||||
self._cache = RecipeCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[]
|
||||
sorted_by_date=[],
|
||||
folders=[],
|
||||
folder_tree={},
|
||||
)
|
||||
|
||||
# Mark as initializing to prevent concurrent initializations
|
||||
self._is_initializing = True
|
||||
self._initialization_task = asyncio.current_task()
|
||||
|
||||
try:
|
||||
# Start timer
|
||||
@@ -126,11 +342,14 @@ class RecipeScanner:
|
||||
None, # Use default thread pool
|
||||
self._initialize_recipe_cache_sync # Run synchronous version in thread
|
||||
)
|
||||
if cache is not None:
|
||||
self._cache = cache
|
||||
|
||||
# Calculate elapsed time and log it
|
||||
elapsed_time = time.time() - start_time
|
||||
recipe_count = len(cache.raw_data) if cache and hasattr(cache, 'raw_data') else 0
|
||||
logger.info(f"Recipe cache initialized in {elapsed_time:.2f} seconds. Found {recipe_count} recipes")
|
||||
self._schedule_post_scan_enrichment()
|
||||
finally:
|
||||
# Mark initialization as complete regardless of outcome
|
||||
self._is_initializing = False
|
||||
@@ -207,6 +426,7 @@ class RecipeScanner:
|
||||
|
||||
# Update cache with the collected data
|
||||
self._cache.raw_data = recipes
|
||||
self._update_folder_metadata(self._cache)
|
||||
|
||||
# Create a simplified resort function that doesn't use await
|
||||
if hasattr(self._cache, "resort"):
|
||||
@@ -237,12 +457,97 @@ class RecipeScanner:
|
||||
# Clean up the event loop
|
||||
loop.close()
|
||||
|
||||
async def _wait_for_lora_scanner(self) -> None:
|
||||
"""Ensure the LoRA scanner has initialized before recipe enrichment."""
|
||||
|
||||
if not getattr(self, "_lora_scanner", None):
|
||||
return
|
||||
|
||||
lora_scanner = self._lora_scanner
|
||||
cache_ready = getattr(lora_scanner, "_cache", None) is not None
|
||||
|
||||
# If cache is already available, we can proceed
|
||||
if cache_ready:
|
||||
return
|
||||
|
||||
# Await an existing initialization task if present
|
||||
task = getattr(lora_scanner, "_initialization_task", None)
|
||||
if task and hasattr(task, "done") and not task.done():
|
||||
try:
|
||||
await task
|
||||
except Exception: # pragma: no cover - defensive guard
|
||||
pass
|
||||
if getattr(lora_scanner, "_cache", None) is not None:
|
||||
return
|
||||
|
||||
# Otherwise, request initialization and proceed once it completes
|
||||
try:
|
||||
await lora_scanner.initialize_in_background()
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("Recipe Scanner: LoRA init request failed: %s", exc)
|
||||
|
||||
def _schedule_post_scan_enrichment(self) -> None:
|
||||
"""Kick off a non-blocking enrichment pass to fill remote metadata."""
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
|
||||
if self._post_scan_task and not self._post_scan_task.done():
|
||||
return
|
||||
|
||||
async def _run_enrichment():
|
||||
try:
|
||||
await self._enrich_cache_metadata()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.error("Recipe Scanner: error during post-scan enrichment: %s", exc, exc_info=True)
|
||||
|
||||
self._post_scan_task = loop.create_task(_run_enrichment(), name="recipe_cache_enrichment")
|
||||
|
||||
async def _enrich_cache_metadata(self) -> None:
|
||||
"""Perform remote metadata enrichment after the initial scan."""
|
||||
|
||||
cache = self._cache
|
||||
if cache is None or not getattr(cache, "raw_data", None):
|
||||
return
|
||||
|
||||
for index, recipe in enumerate(list(cache.raw_data)):
|
||||
try:
|
||||
metadata_updated = await self._update_lora_information(recipe)
|
||||
if metadata_updated:
|
||||
recipe_id = recipe.get("id")
|
||||
if recipe_id:
|
||||
recipe_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if os.path.exists(recipe_path):
|
||||
try:
|
||||
self._write_recipe_file(recipe_path, recipe)
|
||||
except Exception as exc: # pragma: no cover - best-effort persistence
|
||||
logger.debug("Recipe Scanner: could not persist recipe %s: %s", recipe_id, exc)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Recipe Scanner: error enriching recipe %s: %s", recipe.get("id"), exc, exc_info=True)
|
||||
|
||||
if index % 10 == 0:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
try:
|
||||
await cache.resort()
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.debug("Recipe Scanner: error resorting cache after enrichment: %s", exc)
|
||||
|
||||
def _schedule_resort(self, *, name_only: bool = False) -> None:
|
||||
"""Schedule a background resort of the recipe cache."""
|
||||
|
||||
if not self._cache:
|
||||
return
|
||||
|
||||
# Keep folder metadata up to date alongside sort order
|
||||
self._update_folder_metadata()
|
||||
|
||||
async def _resort_wrapper() -> None:
|
||||
try:
|
||||
await self._cache.resort(name_only=name_only)
|
||||
@@ -253,6 +558,75 @@ class RecipeScanner:
|
||||
self._resort_tasks.add(task)
|
||||
task.add_done_callback(lambda finished: self._resort_tasks.discard(finished))
|
||||
|
||||
def _calculate_folder(self, recipe_path: str) -> str:
|
||||
"""Calculate a normalized folder path relative to ``recipes_dir``."""
|
||||
|
||||
recipes_dir = self.recipes_dir
|
||||
if not recipes_dir:
|
||||
return ""
|
||||
|
||||
try:
|
||||
recipe_dir = os.path.dirname(os.path.normpath(recipe_path))
|
||||
relative_dir = os.path.relpath(recipe_dir, recipes_dir)
|
||||
if relative_dir in (".", ""):
|
||||
return ""
|
||||
return relative_dir.replace(os.path.sep, "/")
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _build_folder_tree(self, folders: list[str]) -> dict:
|
||||
"""Build a nested folder tree structure from relative folder paths."""
|
||||
|
||||
tree: dict[str, dict] = {}
|
||||
for folder in folders:
|
||||
if not folder:
|
||||
continue
|
||||
|
||||
parts = folder.split("/")
|
||||
current_level = tree
|
||||
|
||||
for part in parts:
|
||||
if part not in current_level:
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
return tree
|
||||
|
||||
def _update_folder_metadata(self, cache: RecipeCache | None = None) -> None:
|
||||
"""Ensure folder lists and tree metadata are synchronized with cache contents."""
|
||||
|
||||
cache = cache or self._cache
|
||||
if cache is None:
|
||||
return
|
||||
|
||||
folders: set[str] = set()
|
||||
for item in cache.raw_data:
|
||||
folder_value = item.get("folder", "")
|
||||
if folder_value is None:
|
||||
folder_value = ""
|
||||
if folder_value == ".":
|
||||
folder_value = ""
|
||||
normalized = str(folder_value).replace("\\", "/")
|
||||
item["folder"] = normalized
|
||||
folders.add(normalized)
|
||||
|
||||
cache.folders = sorted(folders, key=lambda entry: entry.lower())
|
||||
cache.folder_tree = self._build_folder_tree(cache.folders)
|
||||
|
||||
async def get_folders(self) -> list[str]:
|
||||
"""Return a sorted list of recipe folders relative to the recipes root."""
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
self._update_folder_metadata(cache)
|
||||
return cache.folders
|
||||
|
||||
async def get_folder_tree(self) -> dict:
|
||||
"""Return a hierarchical tree of recipe folders for sidebar navigation."""
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
self._update_folder_metadata(cache)
|
||||
return cache.folder_tree
|
||||
|
||||
@property
|
||||
def recipes_dir(self) -> str:
|
||||
"""Get path to recipes directory"""
|
||||
@@ -269,11 +643,14 @@ class RecipeScanner:
|
||||
"""Get cached recipe data, refresh if needed"""
|
||||
# If cache is already initialized and no refresh is needed, return it immediately
|
||||
if self._cache is not None and not force_refresh:
|
||||
self._update_folder_metadata()
|
||||
return self._cache
|
||||
|
||||
# If another initialization is already in progress, wait for it to complete
|
||||
if self._is_initializing and not force_refresh:
|
||||
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
||||
return self._cache or RecipeCache(
|
||||
raw_data=[], sorted_by_name=[], sorted_by_date=[], folders=[], folder_tree={}
|
||||
)
|
||||
|
||||
# If force refresh is requested, initialize the cache directly
|
||||
if force_refresh:
|
||||
@@ -291,11 +668,14 @@ class RecipeScanner:
|
||||
self._cache = RecipeCache(
|
||||
raw_data=raw_data,
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[]
|
||||
sorted_by_date=[],
|
||||
folders=[],
|
||||
folder_tree={},
|
||||
)
|
||||
|
||||
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
self._update_folder_metadata(self._cache)
|
||||
|
||||
return self._cache
|
||||
|
||||
@@ -305,7 +685,9 @@ class RecipeScanner:
|
||||
self._cache = RecipeCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[]
|
||||
sorted_by_date=[],
|
||||
folders=[],
|
||||
folder_tree={},
|
||||
)
|
||||
return self._cache
|
||||
finally:
|
||||
@@ -316,7 +698,9 @@ class RecipeScanner:
|
||||
logger.error(f"Unexpected error in get_cached_data: {e}")
|
||||
|
||||
# Return the cache (may be empty or partially initialized)
|
||||
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
||||
return self._cache or RecipeCache(
|
||||
raw_data=[], sorted_by_name=[], sorted_by_date=[], folders=[], folder_tree={}
|
||||
)
|
||||
|
||||
async def refresh_cache(self, force: bool = False) -> RecipeCache:
|
||||
"""Public helper to refresh or return the recipe cache."""
|
||||
@@ -331,6 +715,7 @@ class RecipeScanner:
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
await cache.add_recipe(recipe_data, resort=False)
|
||||
self._update_folder_metadata(cache)
|
||||
self._schedule_resort()
|
||||
|
||||
async def remove_recipe(self, recipe_id: str) -> bool:
|
||||
@@ -344,6 +729,7 @@ class RecipeScanner:
|
||||
if removed is None:
|
||||
return False
|
||||
|
||||
self._update_folder_metadata(cache)
|
||||
self._schedule_resort()
|
||||
return True
|
||||
|
||||
@@ -428,6 +814,9 @@ class RecipeScanner:
|
||||
|
||||
if path_updated:
|
||||
self._write_recipe_file(recipe_path, recipe_data)
|
||||
|
||||
# Track folder placement relative to recipes directory
|
||||
recipe_data['folder'] = recipe_data.get('folder') or self._calculate_folder(recipe_path)
|
||||
|
||||
# Ensure loras array exists
|
||||
if 'loras' not in recipe_data:
|
||||
@@ -438,7 +827,7 @@ class RecipeScanner:
|
||||
recipe_data['gen_params'] = {}
|
||||
|
||||
# Update lora information with local paths and availability
|
||||
await self._update_lora_information(recipe_data)
|
||||
lora_metadata_updated = await self._update_lora_information(recipe_data)
|
||||
|
||||
if recipe_data.get('checkpoint'):
|
||||
checkpoint_entry = self._normalize_checkpoint_entry(recipe_data['checkpoint'])
|
||||
@@ -459,6 +848,12 @@ class RecipeScanner:
|
||||
logger.info(f"Added fingerprint to recipe: {recipe_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing updated recipe with fingerprint: {e}")
|
||||
elif lora_metadata_updated:
|
||||
# Persist updates such as marking invalid entries as deleted
|
||||
try:
|
||||
self._write_recipe_file(recipe_path, recipe_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing updated recipe metadata: {e}")
|
||||
|
||||
return recipe_data
|
||||
except Exception as e:
|
||||
@@ -519,7 +914,13 @@ class RecipeScanner:
|
||||
logger.warning(f"Marked lora with modelVersionId {model_version_id} as deleted")
|
||||
metadata_updated = True
|
||||
else:
|
||||
logger.debug(f"Could not get hash for modelVersionId {model_version_id}")
|
||||
# No hash returned; mark as deleted to avoid repeated lookups
|
||||
lora['isDeleted'] = True
|
||||
metadata_updated = True
|
||||
logger.warning(
|
||||
"Marked lora with modelVersionId %s as deleted after failed hash lookup",
|
||||
model_version_id,
|
||||
)
|
||||
|
||||
# If has hash but no file_name, look up in lora library
|
||||
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):
|
||||
@@ -809,7 +1210,7 @@ class RecipeScanner:
|
||||
|
||||
return await self._lora_scanner.get_model_info_by_name(name)
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True):
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True, folder: str | None = None, recursive: bool = True):
|
||||
"""Get paginated and filtered recipe data
|
||||
|
||||
Args:
|
||||
@@ -821,11 +1222,20 @@ class RecipeScanner:
|
||||
search_options: Dictionary of search options to apply
|
||||
lora_hash: Optional SHA256 hash of a LoRA to filter recipes by
|
||||
bypass_filters: If True, ignore other filters when a lora_hash is provided
|
||||
folder: Optional folder filter relative to recipes directory
|
||||
recursive: Whether to include recipes in subfolders of the selected folder
|
||||
"""
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Get base dataset
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
sort_field = sort_by.split(':')[0] if ':' in sort_by else sort_by
|
||||
|
||||
if sort_field == 'date':
|
||||
filtered_data = list(cache.sorted_by_date)
|
||||
elif sort_field == 'name':
|
||||
filtered_data = list(cache.sorted_by_name)
|
||||
else:
|
||||
filtered_data = list(cache.raw_data)
|
||||
|
||||
# Apply SFW filtering if enabled
|
||||
from .settings_manager import get_settings_manager
|
||||
@@ -856,6 +1266,22 @@ class RecipeScanner:
|
||||
|
||||
# Skip further filtering if we're only filtering by LoRA hash with bypass enabled
|
||||
if not (lora_hash and bypass_filters):
|
||||
# Apply folder filter before other criteria
|
||||
if folder is not None:
|
||||
normalized_folder = folder.strip("/")
|
||||
def matches_folder(item_folder: str) -> bool:
|
||||
item_path = (item_folder or "").strip("/")
|
||||
if recursive:
|
||||
if not normalized_folder:
|
||||
return True
|
||||
return item_path == normalized_folder or item_path.startswith(f"{normalized_folder}/")
|
||||
return item_path == normalized_folder
|
||||
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if matches_folder(item.get('folder', ''))
|
||||
]
|
||||
|
||||
# Apply search filter
|
||||
if search:
|
||||
# Default search options if none provided
|
||||
@@ -892,6 +1318,14 @@ class RecipeScanner:
|
||||
if fuzzy_match(str(lora.get('modelName', '')), search):
|
||||
return True
|
||||
|
||||
# Search in prompt and negative_prompt if enabled
|
||||
if search_options.get('prompt', True) and 'gen_params' in item:
|
||||
gen_params = item['gen_params']
|
||||
if fuzzy_match(str(gen_params.get('prompt', '')), search):
|
||||
return True
|
||||
if fuzzy_match(str(gen_params.get('negative_prompt', '')), search):
|
||||
return True
|
||||
|
||||
# No match found
|
||||
return False
|
||||
|
||||
@@ -907,6 +1341,13 @@ class RecipeScanner:
|
||||
if item.get('base_model', '') in filters['base_model']
|
||||
]
|
||||
|
||||
# Filter by favorite
|
||||
if 'favorite' in filters and filters['favorite']:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if item.get('favorite') is True
|
||||
]
|
||||
|
||||
# Filter by tags
|
||||
if 'tags' in filters and filters['tags']:
|
||||
tag_spec = filters['tags']
|
||||
@@ -925,17 +1366,41 @@ class RecipeScanner:
|
||||
include_tags = {tag for tag in tag_spec if tag}
|
||||
|
||||
if include_tags:
|
||||
def matches_include(item_tags):
|
||||
if not item_tags and "__no_tags__" in include_tags:
|
||||
return True
|
||||
return any(tag in include_tags for tag in (item_tags or []))
|
||||
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if any(tag in include_tags for tag in (item.get('tags', []) or []))
|
||||
if matches_include(item.get('tags'))
|
||||
]
|
||||
|
||||
if exclude_tags:
|
||||
def matches_exclude(item_tags):
|
||||
if not item_tags and "__no_tags__" in exclude_tags:
|
||||
return True
|
||||
return any(tag in exclude_tags for tag in (item_tags or []))
|
||||
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if not any(tag in exclude_tags for tag in (item.get('tags', []) or []))
|
||||
if not matches_exclude(item.get('tags'))
|
||||
]
|
||||
|
||||
|
||||
# Apply sorting if not already handled by pre-sorted cache
|
||||
if ':' in sort_by or sort_field == 'loras_count':
|
||||
field, order = (sort_by.split(':') + ['desc'])[:2]
|
||||
reverse = order.lower() == 'desc'
|
||||
|
||||
if field == 'name':
|
||||
filtered_data = natsorted(filtered_data, key=lambda x: x.get('title', '').lower(), reverse=reverse)
|
||||
elif field == 'date':
|
||||
# Use modified if available, falling back to created_date
|
||||
filtered_data.sort(key=lambda x: (x.get('modified', x.get('created_date', 0)), x.get('file_path', '')), reverse=reverse)
|
||||
elif field == 'loras_count':
|
||||
filtered_data.sort(key=lambda x: len(x.get('loras', [])), reverse=reverse)
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
@@ -1031,6 +1496,30 @@ class RecipeScanner:
|
||||
from datetime import datetime
|
||||
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
async def get_recipe_json_path(self, recipe_id: str) -> Optional[str]:
|
||||
"""Locate the recipe JSON file, accounting for folder placement."""
|
||||
|
||||
recipes_dir = self.recipes_dir
|
||||
if not recipes_dir:
|
||||
return None
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
folder = ""
|
||||
for item in cache.raw_data:
|
||||
if str(item.get("id")) == str(recipe_id):
|
||||
folder = item.get("folder") or ""
|
||||
break
|
||||
|
||||
candidate = os.path.normpath(os.path.join(recipes_dir, folder, f"{recipe_id}.recipe.json"))
|
||||
if os.path.exists(candidate):
|
||||
return candidate
|
||||
|
||||
for root, _, files in os.walk(recipes_dir):
|
||||
if f"{recipe_id}.recipe.json" in files:
|
||||
return os.path.join(root, f"{recipe_id}.recipe.json")
|
||||
|
||||
return None
|
||||
|
||||
async def update_recipe_metadata(self, recipe_id: str, metadata: dict) -> bool:
|
||||
"""Update recipe metadata (like title and tags) in both file system and cache
|
||||
|
||||
@@ -1041,13 +1530,9 @@ class RecipeScanner:
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
|
||||
# First, find the recipe JSON file path
|
||||
recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
|
||||
|
||||
if not os.path.exists(recipe_json_path):
|
||||
recipe_json_path = await self.get_recipe_json_path(recipe_id)
|
||||
if not recipe_json_path or not os.path.exists(recipe_json_path):
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -1096,8 +1581,8 @@ class RecipeScanner:
|
||||
if target_name is None:
|
||||
raise ValueError("target_name must be provided")
|
||||
|
||||
recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_json_path):
|
||||
recipe_json_path = await self.get_recipe_json_path(recipe_id)
|
||||
if not recipe_json_path or not os.path.exists(recipe_json_path):
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
async with self._mutation_lock:
|
||||
@@ -1228,71 +1713,56 @@ class RecipeScanner:
|
||||
# Always use lowercase hash for consistency
|
||||
hash_value = hash_value.lower()
|
||||
|
||||
# Get recipes directory
|
||||
recipes_dir = self.recipes_dir
|
||||
if not recipes_dir or not os.path.exists(recipes_dir):
|
||||
logger.warning(f"Recipes directory not found: {recipes_dir}")
|
||||
# Get cache
|
||||
cache = await self.get_cached_data()
|
||||
if not cache or not cache.raw_data:
|
||||
return 0, 0
|
||||
|
||||
file_updated_count = 0
|
||||
cache_updated_count = 0
|
||||
|
||||
# Find recipes that need updating from the cache
|
||||
recipes_to_update = []
|
||||
for recipe in cache.raw_data:
|
||||
loras = recipe.get('loras', [])
|
||||
if not isinstance(loras, list):
|
||||
continue
|
||||
|
||||
has_match = False
|
||||
for lora in loras:
|
||||
if not isinstance(lora, dict):
|
||||
continue
|
||||
if (lora.get('hash') or '').lower() == hash_value:
|
||||
if lora.get('file_name') != new_file_name:
|
||||
lora['file_name'] = new_file_name
|
||||
has_match = True
|
||||
|
||||
if has_match:
|
||||
recipes_to_update.append(recipe)
|
||||
cache_updated_count += 1
|
||||
|
||||
if not recipes_to_update:
|
||||
return 0, 0
|
||||
|
||||
# Check if cache is initialized
|
||||
cache_initialized = self._cache is not None
|
||||
cache_updated_count = 0
|
||||
file_updated_count = 0
|
||||
|
||||
# Get all recipe JSON files in the recipes directory
|
||||
recipe_files = []
|
||||
for root, _, files in os.walk(recipes_dir):
|
||||
for file in files:
|
||||
if file.lower().endswith('.recipe.json'):
|
||||
recipe_files.append(os.path.join(root, file))
|
||||
|
||||
# Process each recipe file
|
||||
for recipe_path in recipe_files:
|
||||
try:
|
||||
# Load the recipe data
|
||||
with open(recipe_path, 'r', encoding='utf-8') as f:
|
||||
recipe_data = json.load(f)
|
||||
|
||||
# Skip if no loras or invalid structure
|
||||
if not recipe_data or not isinstance(recipe_data, dict) or 'loras' not in recipe_data:
|
||||
# Persist changes to disk
|
||||
async with self._mutation_lock:
|
||||
for recipe in recipes_to_update:
|
||||
recipe_id = recipe.get('id')
|
||||
if not recipe_id:
|
||||
continue
|
||||
|
||||
# Check if any lora has matching hash
|
||||
file_updated = False
|
||||
for lora in recipe_data.get('loras', []):
|
||||
if 'hash' in lora and lora['hash'].lower() == hash_value:
|
||||
# Update file_name
|
||||
old_file_name = lora.get('file_name', '')
|
||||
lora['file_name'] = new_file_name
|
||||
file_updated = True
|
||||
logger.info(f"Updated file_name in recipe {recipe_path}: {old_file_name} -> {new_file_name}")
|
||||
|
||||
# If updated, save the file
|
||||
if file_updated:
|
||||
with open(recipe_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
|
||||
file_updated_count += 1
|
||||
|
||||
# Also update in cache if it exists
|
||||
if cache_initialized:
|
||||
recipe_id = recipe_data.get('id')
|
||||
if recipe_id:
|
||||
for cache_item in self._cache.raw_data:
|
||||
if cache_item.get('id') == recipe_id:
|
||||
# Replace loras array with updated version
|
||||
cache_item['loras'] = recipe_data['loras']
|
||||
cache_updated_count += 1
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating recipe file {recipe_path}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
recipe_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
|
||||
try:
|
||||
self._write_recipe_file(recipe_path, recipe)
|
||||
file_updated_count += 1
|
||||
logger.info(f"Updated file_name in recipe {recipe_path}: -> {new_file_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating recipe file {recipe_path}: {e}")
|
||||
|
||||
# Resort cache if updates were made
|
||||
if cache_initialized and cache_updated_count > 0:
|
||||
await self._cache.resort()
|
||||
logger.info(f"Resorted recipe cache after updating {cache_updated_count} items")
|
||||
# We don't necessarily need to resort because LoRA file_name isn't a sort key,
|
||||
# but we might want to schedule a resort if we're paranoid or if searching relies on sorted state.
|
||||
# Given it's a rename of a dependency, search results might change if searching by LoRA name.
|
||||
self._schedule_resort()
|
||||
|
||||
return file_updated_count, cache_updated_count
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ...utils.utils import calculate_recipe_fingerprint
|
||||
from ...utils.civitai_utils import rewrite_preview_url
|
||||
from .errors import (
|
||||
RecipeDownloadError,
|
||||
RecipeNotFoundError,
|
||||
@@ -94,18 +95,39 @@ class RecipeAnalysisService:
|
||||
if civitai_client is None:
|
||||
raise RecipeServiceError("Civitai client unavailable")
|
||||
|
||||
temp_path = self._create_temp_path()
|
||||
temp_path = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
is_video = False
|
||||
extension = ".jpg" # Default
|
||||
|
||||
try:
|
||||
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", url)
|
||||
if civitai_match:
|
||||
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
||||
if not image_info:
|
||||
raise RecipeDownloadError("Failed to fetch image information from Civitai")
|
||||
|
||||
image_url = image_info.get("url")
|
||||
if not image_url:
|
||||
raise RecipeDownloadError("No image URL found in Civitai response")
|
||||
|
||||
is_video = image_info.get("type") == "video"
|
||||
|
||||
# Use optimized preview URLs if possible
|
||||
rewritten_url, _ = rewrite_preview_url(image_url, media_type=image_info.get("type"))
|
||||
if rewritten_url:
|
||||
image_url = rewritten_url
|
||||
|
||||
if is_video:
|
||||
# Extract extension from URL
|
||||
url_path = image_url.split('?')[0].split('#')[0]
|
||||
extension = os.path.splitext(url_path)[1].lower() or ".mp4"
|
||||
else:
|
||||
extension = ".jpg"
|
||||
|
||||
temp_path = self._create_temp_path(suffix=extension)
|
||||
await self._download_image(image_url, temp_path)
|
||||
|
||||
metadata = image_info.get("meta") if "meta" in image_info else None
|
||||
if (
|
||||
isinstance(metadata, dict)
|
||||
@@ -114,22 +136,31 @@ class RecipeAnalysisService:
|
||||
):
|
||||
metadata = metadata["meta"]
|
||||
else:
|
||||
# Basic extension detection for non-Civitai URLs
|
||||
url_path = url.split('?')[0].split('#')[0]
|
||||
extension = os.path.splitext(url_path)[1].lower()
|
||||
if extension in [".mp4", ".webm"]:
|
||||
is_video = True
|
||||
else:
|
||||
extension = ".jpg"
|
||||
|
||||
temp_path = self._create_temp_path(suffix=extension)
|
||||
await self._download_image(url, temp_path)
|
||||
|
||||
if metadata is None:
|
||||
if metadata is None and not is_video:
|
||||
metadata = self._exif_utils.extract_image_metadata(temp_path)
|
||||
|
||||
if not metadata:
|
||||
return self._metadata_not_found_response(temp_path)
|
||||
|
||||
return await self._parse_metadata(
|
||||
metadata,
|
||||
metadata or {},
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_path=temp_path,
|
||||
include_image_base64=True,
|
||||
is_video=is_video,
|
||||
extension=extension,
|
||||
)
|
||||
finally:
|
||||
self._safe_cleanup(temp_path)
|
||||
if temp_path:
|
||||
self._safe_cleanup(temp_path)
|
||||
|
||||
async def analyze_local_image(
|
||||
self,
|
||||
@@ -198,12 +229,16 @@ class RecipeAnalysisService:
|
||||
recipe_scanner,
|
||||
image_path: Optional[str],
|
||||
include_image_base64: bool,
|
||||
is_video: bool = False,
|
||||
extension: str = ".jpg",
|
||||
) -> AnalysisResult:
|
||||
parser = self._recipe_parser_factory.create_parser(metadata)
|
||||
if parser is None:
|
||||
payload = {"error": "No parser found for this image", "loras": []}
|
||||
if include_image_base64 and image_path:
|
||||
payload["image_base64"] = self._encode_file(image_path)
|
||||
payload["is_video"] = is_video
|
||||
payload["extension"] = extension
|
||||
return AnalysisResult(payload)
|
||||
|
||||
result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner)
|
||||
@@ -211,6 +246,9 @@ class RecipeAnalysisService:
|
||||
if include_image_base64 and image_path:
|
||||
result["image_base64"] = self._encode_file(image_path)
|
||||
|
||||
result["is_video"] = is_video
|
||||
result["extension"] = extension
|
||||
|
||||
if "error" in result and not result.get("loras"):
|
||||
return AnalysisResult(result)
|
||||
|
||||
@@ -241,8 +279,8 @@ class RecipeAnalysisService:
|
||||
temp_file.write(data)
|
||||
return temp_file.name
|
||||
|
||||
def _create_temp_path(self) -> str:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||
def _create_temp_path(self, suffix: str = ".jpg") -> str:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
|
||||
return temp_file.name
|
||||
|
||||
def _safe_cleanup(self, path: Optional[str]) -> None:
|
||||
|
||||
@@ -5,6 +5,7 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
@@ -46,6 +47,7 @@ class RecipePersistenceService:
|
||||
name: str | None,
|
||||
tags: Iterable[str],
|
||||
metadata: Optional[dict[str, Any]],
|
||||
extension: str | None = None,
|
||||
) -> PersistenceResult:
|
||||
"""Persist a user uploaded recipe."""
|
||||
|
||||
@@ -64,13 +66,21 @@ class RecipePersistenceService:
|
||||
os.makedirs(recipes_dir, exist_ok=True)
|
||||
|
||||
recipe_id = str(uuid.uuid4())
|
||||
optimized_image, extension = self._exif_utils.optimize_image(
|
||||
image_data=resolved_image_bytes,
|
||||
target_width=self._card_preview_width,
|
||||
format="webp",
|
||||
quality=85,
|
||||
preserve_metadata=True,
|
||||
)
|
||||
|
||||
# Handle video formats by bypassing optimization and metadata embedding
|
||||
is_video = extension in [".mp4", ".webm"]
|
||||
if is_video:
|
||||
optimized_image = resolved_image_bytes
|
||||
# extension is already set
|
||||
else:
|
||||
optimized_image, extension = self._exif_utils.optimize_image(
|
||||
image_data=resolved_image_bytes,
|
||||
target_width=self._card_preview_width,
|
||||
format="webp",
|
||||
quality=85,
|
||||
preserve_metadata=True,
|
||||
)
|
||||
|
||||
image_filename = f"{recipe_id}{extension}"
|
||||
image_path = os.path.join(recipes_dir, image_filename)
|
||||
normalized_image_path = os.path.normpath(image_path)
|
||||
@@ -126,7 +136,8 @@ class RecipePersistenceService:
|
||||
with open(json_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||
|
||||
self._exif_utils.append_recipe_metadata(normalized_image_path, recipe_data)
|
||||
if not is_video:
|
||||
self._exif_utils.append_recipe_metadata(normalized_image_path, recipe_data)
|
||||
|
||||
matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id)
|
||||
await recipe_scanner.add_recipe(recipe_data)
|
||||
@@ -144,12 +155,8 @@ class RecipePersistenceService:
|
||||
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
|
||||
"""Delete an existing recipe."""
|
||||
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
if not recipes_dir or not os.path.exists(recipes_dir):
|
||||
raise RecipeNotFoundError("Recipes directory not found")
|
||||
|
||||
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_json_path):
|
||||
recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id)
|
||||
if not recipe_json_path or not os.path.exists(recipe_json_path):
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
with open(recipe_json_path, "r", encoding="utf-8") as file_obj:
|
||||
@@ -166,9 +173,9 @@ class RecipePersistenceService:
|
||||
async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult:
|
||||
"""Update persisted metadata for a recipe."""
|
||||
|
||||
if not any(key in updates for key in ("title", "tags", "source_path", "preview_nsfw_level")):
|
||||
if not any(key in updates for key in ("title", "tags", "source_path", "preview_nsfw_level", "favorite")):
|
||||
raise RecipeValidationError(
|
||||
"At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)"
|
||||
"At least one field to update must be provided (title or tags or source_path or preview_nsfw_level or favorite)"
|
||||
)
|
||||
|
||||
success = await recipe_scanner.update_recipe_metadata(recipe_id, updates)
|
||||
@@ -177,6 +184,163 @@ class RecipePersistenceService:
|
||||
|
||||
return PersistenceResult({"success": True, "recipe_id": recipe_id, "updates": updates})
|
||||
|
||||
def _normalize_target_path(self, recipe_scanner, target_path: str) -> tuple[str, str]:
|
||||
"""Normalize and validate the target path for recipe moves."""
|
||||
|
||||
if not target_path:
|
||||
raise RecipeValidationError("Target path is required")
|
||||
|
||||
recipes_root = recipe_scanner.recipes_dir
|
||||
if not recipes_root:
|
||||
raise RecipeNotFoundError("Recipes directory not found")
|
||||
|
||||
normalized_target = os.path.normpath(target_path)
|
||||
recipes_root = os.path.normpath(recipes_root)
|
||||
if not os.path.isabs(normalized_target):
|
||||
normalized_target = os.path.normpath(os.path.join(recipes_root, normalized_target))
|
||||
|
||||
try:
|
||||
common_root = os.path.commonpath([normalized_target, recipes_root])
|
||||
except ValueError as exc:
|
||||
raise RecipeValidationError("Invalid target path") from exc
|
||||
|
||||
if common_root != recipes_root:
|
||||
raise RecipeValidationError("Target path must be inside the recipes directory")
|
||||
|
||||
return normalized_target, recipes_root
|
||||
|
||||
async def _move_recipe_files(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
recipe_id: str,
|
||||
normalized_target: str,
|
||||
recipes_root: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Move the recipe's JSON and preview image into the normalized target."""
|
||||
|
||||
recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id)
|
||||
if not recipe_json_path or not os.path.exists(recipe_json_path):
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
recipe_data = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
if not recipe_data:
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
current_json_dir = os.path.dirname(recipe_json_path)
|
||||
normalized_image_path = os.path.normpath(recipe_data.get("file_path") or "") if recipe_data.get("file_path") else None
|
||||
|
||||
os.makedirs(normalized_target, exist_ok=True)
|
||||
|
||||
if os.path.normpath(current_json_dir) == normalized_target:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Recipe is already in the target folder",
|
||||
"recipe_id": recipe_id,
|
||||
"original_file_path": recipe_data.get("file_path"),
|
||||
"new_file_path": recipe_data.get("file_path"),
|
||||
}
|
||||
|
||||
new_json_path = os.path.normpath(os.path.join(normalized_target, os.path.basename(recipe_json_path)))
|
||||
shutil.move(recipe_json_path, new_json_path)
|
||||
|
||||
new_image_path = normalized_image_path
|
||||
if normalized_image_path:
|
||||
target_image_path = os.path.normpath(os.path.join(normalized_target, os.path.basename(normalized_image_path)))
|
||||
if os.path.exists(normalized_image_path) and normalized_image_path != target_image_path:
|
||||
shutil.move(normalized_image_path, target_image_path)
|
||||
new_image_path = target_image_path
|
||||
|
||||
relative_folder = os.path.relpath(normalized_target, recipes_root)
|
||||
if relative_folder in (".", ""):
|
||||
relative_folder = ""
|
||||
updates = {"file_path": new_image_path or recipe_data.get("file_path"), "folder": relative_folder.replace(os.path.sep, "/")}
|
||||
|
||||
updated = await recipe_scanner.update_recipe_metadata(recipe_id, updates)
|
||||
if not updated:
|
||||
raise RecipeNotFoundError("Recipe not found after move")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"recipe_id": recipe_id,
|
||||
"original_file_path": recipe_data.get("file_path"),
|
||||
"new_file_path": updates["file_path"],
|
||||
"json_path": new_json_path,
|
||||
"folder": updates["folder"],
|
||||
}
|
||||
|
||||
async def move_recipe(self, *, recipe_scanner, recipe_id: str, target_path: str) -> PersistenceResult:
|
||||
"""Move a recipe's assets into a new folder under the recipes root."""
|
||||
|
||||
normalized_target, recipes_root = self._normalize_target_path(recipe_scanner, target_path)
|
||||
result = await self._move_recipe_files(
|
||||
recipe_scanner=recipe_scanner,
|
||||
recipe_id=recipe_id,
|
||||
normalized_target=normalized_target,
|
||||
recipes_root=recipes_root,
|
||||
)
|
||||
return PersistenceResult(result)
|
||||
|
||||
async def move_recipes_bulk(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
recipe_ids: Iterable[str],
|
||||
target_path: str,
|
||||
) -> PersistenceResult:
|
||||
"""Move multiple recipes to a new folder."""
|
||||
|
||||
recipe_ids = list(recipe_ids)
|
||||
if not recipe_ids:
|
||||
raise RecipeValidationError("No recipe IDs provided")
|
||||
|
||||
normalized_target, recipes_root = self._normalize_target_path(recipe_scanner, target_path)
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
|
||||
for recipe_id in recipe_ids:
|
||||
try:
|
||||
move_result = await self._move_recipe_files(
|
||||
recipe_scanner=recipe_scanner,
|
||||
recipe_id=str(recipe_id),
|
||||
normalized_target=normalized_target,
|
||||
recipes_root=recipes_root,
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"recipe_id": recipe_id,
|
||||
"original_file_path": move_result.get("original_file_path"),
|
||||
"new_file_path": move_result.get("new_file_path"),
|
||||
"success": True,
|
||||
"message": move_result.get("message", ""),
|
||||
"folder": move_result.get("folder", ""),
|
||||
}
|
||||
)
|
||||
success_count += 1
|
||||
except Exception as exc: # pragma: no cover - per-item error handling
|
||||
results.append(
|
||||
{
|
||||
"recipe_id": recipe_id,
|
||||
"original_file_path": None,
|
||||
"new_file_path": None,
|
||||
"success": False,
|
||||
"message": str(exc),
|
||||
}
|
||||
)
|
||||
failure_count += 1
|
||||
|
||||
return PersistenceResult(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Moved {success_count} of {len(recipe_ids)} recipes",
|
||||
"results": results,
|
||||
"success_count": success_count,
|
||||
"failure_count": failure_count,
|
||||
}
|
||||
)
|
||||
|
||||
async def reconnect_lora(
|
||||
self,
|
||||
*,
|
||||
@@ -187,8 +351,8 @@ class RecipePersistenceService:
|
||||
) -> PersistenceResult:
|
||||
"""Reconnect a LoRA entry within an existing recipe."""
|
||||
|
||||
recipe_path = os.path.join(recipe_scanner.recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_path):
|
||||
recipe_path = await recipe_scanner.get_recipe_json_path(recipe_id)
|
||||
if not recipe_path or not os.path.exists(recipe_path):
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
target_lora = await recipe_scanner.get_local_lora(target_name)
|
||||
@@ -233,16 +397,12 @@ class RecipePersistenceService:
|
||||
if not recipe_ids:
|
||||
raise RecipeValidationError("No recipe IDs provided")
|
||||
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
if not recipes_dir or not os.path.exists(recipes_dir):
|
||||
raise RecipeNotFoundError("Recipes directory not found")
|
||||
|
||||
deleted_recipes: list[str] = []
|
||||
failed_recipes: list[dict[str, Any]] = []
|
||||
|
||||
for recipe_id in recipe_ids:
|
||||
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_json_path):
|
||||
recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id)
|
||||
if not recipe_json_path or not os.path.exists(recipe_json_path):
|
||||
failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"})
|
||||
continue
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ class WebSocketManager:
|
||||
self._last_init_progress: Dict[str, Dict] = {}
|
||||
# Add auto-organize progress tracking
|
||||
self._auto_organize_progress: Optional[Dict] = None
|
||||
# Add recipe repair progress tracking
|
||||
self._recipe_repair_progress: Optional[Dict] = None
|
||||
self._auto_organize_lock = asyncio.Lock()
|
||||
|
||||
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
@@ -189,6 +191,14 @@ class WebSocketManager:
|
||||
# Broadcast via WebSocket
|
||||
await self.broadcast(data)
|
||||
|
||||
async def broadcast_recipe_repair_progress(self, data: Dict):
|
||||
"""Broadcast recipe repair progress to connected clients"""
|
||||
# Store progress data in memory
|
||||
self._recipe_repair_progress = data
|
||||
|
||||
# Broadcast via WebSocket
|
||||
await self.broadcast(data)
|
||||
|
||||
def get_auto_organize_progress(self) -> Optional[Dict]:
|
||||
"""Get current auto-organize progress"""
|
||||
return self._auto_organize_progress
|
||||
@@ -197,6 +207,14 @@ class WebSocketManager:
|
||||
"""Clear auto-organize progress data"""
|
||||
self._auto_organize_progress = None
|
||||
|
||||
def get_recipe_repair_progress(self) -> Optional[Dict]:
|
||||
"""Get current recipe repair progress"""
|
||||
return self._recipe_repair_progress
|
||||
|
||||
def cleanup_recipe_repair_progress(self):
|
||||
"""Clear recipe repair progress data"""
|
||||
self._recipe_repair_progress = None
|
||||
|
||||
def is_auto_organize_running(self) -> bool:
|
||||
"""Check if auto-organize is currently running"""
|
||||
if not self._auto_organize_progress:
|
||||
|
||||
Reference in New Issue
Block a user