From 32d94be08a7d8558d743e83d4bd3f769886f0009 Mon Sep 17 00:00:00 2001 From: stone9k <134336732+stone9k@users.noreply.github.com> Date: Fri, 12 Dec 2025 18:50:28 +0100 Subject: [PATCH 01/35] fix(trigger_word_toggle): missing consumeExistingState after refactor --- web/comfyui/trigger_word_toggle.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/comfyui/trigger_word_toggle.js b/web/comfyui/trigger_word_toggle.js index b548a96f..7ff85f78 100644 --- a/web/comfyui/trigger_word_toggle.js +++ b/web/comfyui/trigger_word_toggle.js @@ -311,7 +311,7 @@ app.registerExtension({ }); } else { // If no ',,' delimiter, treat the entire message as one group - const existing = existingTagMap[message.trim()]; + const existing = consumeExistingState(message.trim()); tagArray = [{ text: message.trim(), // Use existing values if available, otherwise use defaults From 5359129fad02f78a736909026a29d3edfbf15d33 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sun, 14 Dec 2025 15:58:58 +0800 Subject: [PATCH 02/35] feat(config): improve symlink cache logging and add performance timing - Add `time` import for performance measurement - Change debug logs to info level for better visibility of cache operations - Add detailed logging for cache validation failures and successes - Include timing metrics for symlink initialization and scanning - Log cache save/load operations with mapping counts --- py/config.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/py/config.py b/py/config.py index 6212c3b9..e6ffea41 100644 --- a/py/config.py +++ b/py/config.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Set import logging import json import urllib.parse +import time from .utils.settings_paths import ensure_settings_file, get_settings_dir, load_settings_template @@ -282,21 +283,24 @@ class Config: def _load_symlink_cache(self) -> bool: cache_path = self._get_symlink_cache_path() if not cache_path.exists(): + logger.info("Symlink cache not found at %s", cache_path) return False try: with cache_path.open("r", encoding="utf-8") as handle: payload = json.load(handle) except Exception as exc: - logger.debug("Failed to load symlink cache %s: %s", cache_path, exc) + logger.info("Failed to load symlink cache %s: %s", cache_path, exc) return False if not isinstance(payload, dict): + logger.info("Symlink cache payload is not a dict: %s", type(payload)) return False cached_fingerprint = payload.get("fingerprint") cached_mappings = payload.get("path_mappings") if not isinstance(cached_fingerprint, dict) or not isinstance(cached_mappings, Mapping): + logger.info("Symlink cache missing fingerprint or path mappings") return False current_fingerprint = self._build_symlink_fingerprint() @@ -307,12 +311,14 @@ class Config: or not isinstance(cached_stats, Mapping) or sorted(cached_roots) != sorted(current_fingerprint["roots"]) # type: ignore[index] ): + logger.info("Symlink cache invalidated: roots changed") return False for root in current_fingerprint["roots"]: # type: ignore[assignment] cached_stat = cached_stats.get(root) if isinstance(cached_stats, Mapping) else None current_stat = current_fingerprint["stats"].get(root) # type: ignore[index] if not isinstance(cached_stat, Mapping) or not current_stat: + logger.info("Symlink cache invalidated: missing stats for %s", root) return False cached_mtime = cached_stat.get("mtime_ns") @@ -321,6 +327,7 @@ class Config: current_inode = current_stat.get("inode") if cached_inode != current_inode: + logger.info("Symlink cache invalidated: inode changed for %s", root) return False if cached_mtime != current_mtime: @@ -332,6 +339,7 @@ class Config: and cached_mtime == cached_noise and current_mtime == current_noise ): + logger.info("Symlink cache invalidated: mtime changed for %s", root) return False normalized_mappings: Dict[str, str] = {} @@ -341,6 +349,7 @@ class Config: normalized_mappings[self._normalize_path(target)] = self._normalize_path(link) self._path_mappings = normalized_mappings + logger.info("Symlink cache loaded with %d mappings", len(self._path_mappings)) return True def _save_symlink_cache(self) -> None: @@ -353,22 +362,37 @@ class Config: try: with cache_path.open("w", encoding="utf-8") as handle: json.dump(payload, handle, ensure_ascii=False, indent=2) + logger.info("Symlink cache saved to %s with %d mappings", cache_path, len(self._path_mappings)) except Exception as exc: - logger.debug("Failed to write symlink cache %s: %s", cache_path, exc) + logger.info("Failed to write symlink cache %s: %s", cache_path, exc) def _initialize_symlink_mappings(self) -> None: + start = time.perf_counter() if not self._load_symlink_cache(): self._scan_symbolic_links() self._save_symlink_cache() + logger.info( + "Symlink mappings rebuilt and cached in %.2f ms", + (time.perf_counter() - start) * 1000, + ) else: - logger.info("Loaded symlink mappings from cache") + logger.info( + "Symlink mappings restored from cache in %.2f ms", + (time.perf_counter() - start) * 1000, + ) self._rebuild_preview_roots() def _scan_symbolic_links(self): """Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories""" + start = time.perf_counter() visited_dirs: Set[str] = set() for root in self._symlink_roots(): self._scan_directory_links(root, visited_dirs) + logger.info( + "Symlink scan finished in %.2f ms with %d mappings", + (time.perf_counter() - start) * 1000, + len(self._path_mappings), + ) def _scan_directory_links(self, root: str, visited_dirs: Set[str]): """Iteratively scan directory symlinks to avoid deep recursion.""" From 2494fa19a6758686cf3842c92ad74dde28fad879 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Mon, 15 Dec 2025 18:46:23 +0800 Subject: [PATCH 03/35] feat(config): add background symlink rescan and simplify cache validation - Added threading import and optional `_rescan_thread` for background operations - Simplified `_load_symlink_cache` to only validate path mappings, removing fingerprint checks - Updated `_initialize_symlink_mappings` to rebuild preview roots and schedule rescan when cache is loaded - Added `_schedule_symlink_rescan` method to perform background validation of symlinks - Cleared `_path_mappings` at start of `_scan_symbolic_links` to prevent stale entries - Background rescan improves performance by deferring symlink validation after cache load --- py/config.py | 102 ++++++++++++++--------------- tests/config/test_symlink_cache.py | 34 ++++++++++ 2 files changed, 85 insertions(+), 51 deletions(-) diff --git a/py/config.py b/py/config.py index e6ffea41..bdd9a591 100644 --- a/py/config.py +++ b/py/config.py @@ -1,8 +1,9 @@ import os import platform +import threading from pathlib import Path import folder_paths # type: ignore -from typing import Any, Dict, Iterable, List, Mapping, Optional, Set +from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple import logging import json import urllib.parse @@ -81,6 +82,8 @@ class Config: self._path_mappings: Dict[str, str] = {} # Normalized preview root directories used to validate preview access self._preview_root_paths: Set[Path] = set() + # Optional background rescan thread + self._rescan_thread: Optional[threading.Thread] = None self.loras_roots = self._init_lora_paths() self.checkpoints_roots = None self.unet_roots = None @@ -297,51 +300,11 @@ class Config: logger.info("Symlink cache payload is not a dict: %s", type(payload)) return False - cached_fingerprint = payload.get("fingerprint") cached_mappings = payload.get("path_mappings") - if not isinstance(cached_fingerprint, dict) or not isinstance(cached_mappings, Mapping): - logger.info("Symlink cache missing fingerprint or path mappings") + if not isinstance(cached_mappings, Mapping): + logger.info("Symlink cache missing path mappings") return False - current_fingerprint = self._build_symlink_fingerprint() - cached_roots = cached_fingerprint.get("roots") - cached_stats = cached_fingerprint.get("stats") - if ( - not isinstance(cached_roots, list) - or not isinstance(cached_stats, Mapping) - or sorted(cached_roots) != sorted(current_fingerprint["roots"]) # type: ignore[index] - ): - logger.info("Symlink cache invalidated: roots changed") - return False - - for root in current_fingerprint["roots"]: # type: ignore[assignment] - cached_stat = cached_stats.get(root) if isinstance(cached_stats, Mapping) else None - current_stat = current_fingerprint["stats"].get(root) # type: ignore[index] - if not isinstance(cached_stat, Mapping) or not current_stat: - logger.info("Symlink cache invalidated: missing stats for %s", root) - return False - - cached_mtime = cached_stat.get("mtime_ns") - cached_inode = cached_stat.get("inode") - current_mtime = current_stat.get("mtime_ns") - current_inode = current_stat.get("inode") - - if cached_inode != current_inode: - logger.info("Symlink cache invalidated: inode changed for %s", root) - return False - - if cached_mtime != current_mtime: - cached_noise = cached_stat.get("noise_mtime_ns") - current_noise = current_stat.get("noise_mtime_ns") - if not ( - cached_noise - and current_noise - and cached_mtime == cached_noise - and current_mtime == current_noise - ): - logger.info("Symlink cache invalidated: mtime changed for %s", root) - return False - normalized_mappings: Dict[str, str] = {} for target, link in cached_mappings.items(): if not isinstance(target, str) or not isinstance(link, str): @@ -368,23 +331,30 @@ class Config: def _initialize_symlink_mappings(self) -> None: start = time.perf_counter() - if not self._load_symlink_cache(): - self._scan_symbolic_links() - self._save_symlink_cache() - logger.info( - "Symlink mappings rebuilt and cached in %.2f ms", - (time.perf_counter() - start) * 1000, - ) - else: + cache_loaded = self._load_symlink_cache() + + if cache_loaded: logger.info( "Symlink mappings restored from cache in %.2f ms", (time.perf_counter() - start) * 1000, ) + self._rebuild_preview_roots() + self._schedule_symlink_rescan() + return + + self._scan_symbolic_links() + self._save_symlink_cache() self._rebuild_preview_roots() + logger.info( + "Symlink mappings rebuilt and cached in %.2f ms", + (time.perf_counter() - start) * 1000, + ) def _scan_symbolic_links(self): """Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories""" start = time.perf_counter() + # Reset mappings before rescanning to avoid stale entries + self._path_mappings.clear() visited_dirs: Set[str] = set() for root in self._symlink_roots(): self._scan_directory_links(root, visited_dirs) @@ -394,6 +364,36 @@ class Config: len(self._path_mappings), ) + def _schedule_symlink_rescan(self) -> None: + """Trigger a best-effort background rescan to refresh stale caches.""" + + if self._rescan_thread and self._rescan_thread.is_alive(): + return + + def worker(): + try: + self._scan_symbolic_links() + self._save_symlink_cache() + self._rebuild_preview_roots() + logger.info("Background symlink rescan completed") + except Exception as exc: # pragma: no cover - defensive logging + logger.info("Background symlink rescan failed: %s", exc) + + thread = threading.Thread( + target=worker, + name="lora-manager-symlink-rescan", + daemon=True, + ) + self._rescan_thread = thread + thread.start() + + def _wait_for_rescan(self, timeout: Optional[float] = None) -> None: + """Block until the background rescan completes (testing convenience).""" + + thread = self._rescan_thread + if thread: + thread.join(timeout=timeout) + def _scan_directory_links(self, root: str, visited_dirs: Set[str]): """Iteratively scan directory symlinks to avoid deep recursion.""" try: diff --git a/tests/config/test_symlink_cache.py b/tests/config/test_symlink_cache.py index b0e46ff7..caf50195 100644 --- a/tests/config/test_symlink_cache.py +++ b/tests/config/test_symlink_cache.py @@ -62,6 +62,7 @@ def test_symlink_scan_skips_file_links(monkeypatch: pytest.MonkeyPatch, tmp_path def test_symlink_cache_reuses_previous_scan(monkeypatch: pytest.MonkeyPatch, tmp_path): loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path) + monkeypatch.setattr(config_module.Config, "_schedule_symlink_rescan", lambda self: None) target_dir = loras_dir / "target" target_dir.mkdir() @@ -85,6 +86,7 @@ def test_symlink_cache_reuses_previous_scan(monkeypatch: pytest.MonkeyPatch, tmp def test_symlink_cache_survives_noise_mtime(monkeypatch: pytest.MonkeyPatch, tmp_path): loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path) + monkeypatch.setattr(config_module.Config, "_schedule_symlink_rescan", lambda self: None) target_dir = loras_dir / "target" target_dir.mkdir() @@ -109,3 +111,35 @@ def test_symlink_cache_survives_noise_mtime(monkeypatch: pytest.MonkeyPatch, tmp second_cfg = config_module.Config() assert second_cfg.map_path_to_link(str(target_dir)) == _normalize(str(dir_link)) + + +def test_background_rescan_refreshes_cache(monkeypatch: pytest.MonkeyPatch, tmp_path): + loras_dir, _ = _setup_paths(monkeypatch, tmp_path) + + target_dir = loras_dir / "target" + target_dir.mkdir() + dir_link = loras_dir / "dir_link" + dir_link.symlink_to(target_dir, target_is_directory=True) + + # Build initial cache pointing at the first target + first_cfg = config_module.Config() + old_real = _normalize(os.path.realpath(target_dir)) + assert first_cfg.map_path_to_link(str(target_dir)) == _normalize(str(dir_link)) + + # Retarget the symlink to a new directory without touching the cache file + new_target = loras_dir / "target_v2" + new_target.mkdir() + dir_link.unlink() + dir_link.symlink_to(new_target, target_is_directory=True) + + second_cfg = config_module.Config() + + # Cache may still point at the old real path immediately after load + initial_mapping = second_cfg.map_path_to_link(str(new_target)) + assert initial_mapping in {str(new_target), _normalize(str(dir_link))} + + # Background rescan should refresh the mapping to the new target and update the cache file + second_cfg._wait_for_rescan(timeout=2.0) + new_real = _normalize(os.path.realpath(new_target)) + assert second_cfg._path_mappings.get(new_real) == _normalize(str(dir_link)) + assert second_cfg.map_path_to_link(str(new_target)) == _normalize(str(dir_link)) From 7e133e4b9d9014394c5b7eae21bd9fba597c03c8 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Mon, 15 Dec 2025 22:09:21 +0800 Subject: [PATCH 04/35] feat: rename SaveImage class to SaveImageLM for clarity The SaveImage class has been renamed to SaveImageLM to better reflect its purpose within the Lora Manager module. This change ensures consistent naming across import statements, class mappings, and the actual class definition, improving code readability and maintainability. --- __init__.py | 6 +++--- py/nodes/save_image.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/__init__.py b/__init__.py index 8e8fbd26..d41a18b4 100644 --- a/__init__.py +++ b/__init__.py @@ -4,7 +4,7 @@ try: # pragma: no cover - import fallback for pytest collection from .py.nodes.trigger_word_toggle import TriggerWordToggle from .py.nodes.prompt import PromptLoraManager from .py.nodes.lora_stacker import LoraStacker - from .py.nodes.save_image import SaveImage + from .py.nodes.save_image import SaveImageLM from .py.nodes.debug_metadata import DebugMetadata from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect from .py.nodes.wanvideo_lora_select_from_text import WanVideoLoraSelectFromText @@ -24,7 +24,7 @@ except ImportError: # pragma: no cover - allows running under pytest without pa LoraManagerTextLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerTextLoader TriggerWordToggle = importlib.import_module("py.nodes.trigger_word_toggle").TriggerWordToggle LoraStacker = importlib.import_module("py.nodes.lora_stacker").LoraStacker - SaveImage = importlib.import_module("py.nodes.save_image").SaveImage + SaveImageLM = importlib.import_module("py.nodes.save_image").SaveImageLM DebugMetadata = importlib.import_module("py.nodes.debug_metadata").DebugMetadata WanVideoLoraSelect = importlib.import_module("py.nodes.wanvideo_lora_select").WanVideoLoraSelect WanVideoLoraSelectFromText = importlib.import_module("py.nodes.wanvideo_lora_select_from_text").WanVideoLoraSelectFromText @@ -36,7 +36,7 @@ NODE_CLASS_MAPPINGS = { LoraManagerTextLoader.NAME: LoraManagerTextLoader, TriggerWordToggle.NAME: TriggerWordToggle, LoraStacker.NAME: LoraStacker, - SaveImage.NAME: SaveImage, + SaveImageLM.NAME: SaveImageLM, DebugMetadata.NAME: DebugMetadata, WanVideoLoraSelect.NAME: WanVideoLoraSelect, WanVideoLoraSelectFromText.NAME: WanVideoLoraSelectFromText diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index dbf44d07..e11f031a 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -9,7 +9,7 @@ from ..metadata_collector import get_metadata from PIL import Image, PngImagePlugin import piexif -class SaveImage: +class SaveImageLM: NAME = "Save Image (LoraManager)" CATEGORY = "Lora Manager/utils" DESCRIPTION = "Save images with embedded generation metadata in compatible format" From 3382d83aee1baf81789c510ac1ea2e8b5fb1176c Mon Sep 17 00:00:00 2001 From: Will Miao Date: Tue, 16 Dec 2025 21:00:04 +0800 Subject: [PATCH 05/35] feat: remove prewarm cache and improve recipe scanner initialization - Remove prewarm_cache startup hook from BaseRecipeRoutes - Add post-scan task management to RecipeScanner for proper cleanup - Ensure LoRA scanner initialization completes before recipe enrichment - Schedule post-scan enrichment after cache initialization - Improve error handling and task cancellation during shutdown --- py/routes/base_recipe_routes.py | 18 --- py/services/recipe_scanner.py | 109 +++++++++++++++++- tests/routes/test_recipe_route_scaffolding.py | 5 +- tests/services/test_recipe_scanner.py | 96 +++++++++++++++ 4 files changed, 205 insertions(+), 23 deletions(-) diff --git a/py/routes/base_recipe_routes.py b/py/routes/base_recipe_routes.py index c598a6d2..162f3491 100644 --- a/py/routes/base_recipe_routes.py +++ b/py/routes/base_recipe_routes.py @@ -79,26 +79,8 @@ class BaseRecipeRoutes: return app.on_startup.append(self.attach_dependencies) - app.on_startup.append(self.prewarm_cache) self._startup_hooks_registered = True - async def prewarm_cache(self, app: web.Application | None = None) -> None: - """Pre-load recipe and LoRA caches on startup.""" - - try: - await self.attach_dependencies(app) - - if self.lora_scanner is not None: - await self.lora_scanner.get_cached_data() - hash_index = getattr(self.lora_scanner, "_hash_index", None) - if hash_index is not None and hasattr(hash_index, "_hash_to_path"): - _ = len(hash_index._hash_to_path) - - if self.recipe_scanner is not None: - await self.recipe_scanner.get_cached_data(force_refresh=True) - except Exception as exc: - logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True) - def to_route_mapping(self) -> Mapping[str, Callable]: """Return a mapping of handler name to coroutine for registrar binding.""" diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index ccaf2395..3b41643e 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -64,6 +64,7 @@ class RecipeScanner: self._initialization_task: Optional[asyncio.Task] = None self._is_initializing = False self._mutation_lock = asyncio.Lock() + self._post_scan_task: Optional[asyncio.Task] = None self._resort_tasks: Set[asyncio.Task] = set() if lora_scanner: self._lora_scanner = lora_scanner @@ -84,6 +85,10 @@ class RecipeScanner: task.cancel() self._resort_tasks.clear() + if self._post_scan_task and not self._post_scan_task.done(): + self._post_scan_task.cancel() + self._post_scan_task = None + self._cache = None self._initialization_task = None self._is_initializing = False @@ -105,6 +110,8 @@ class RecipeScanner: async def initialize_in_background(self) -> None: """Initialize cache in background using thread pool""" try: + await self._wait_for_lora_scanner() + # Set initial empty cache to avoid None reference errors if self._cache is None: self._cache = RecipeCache( @@ -115,6 +122,7 @@ class RecipeScanner: # Mark as initializing to prevent concurrent initializations self._is_initializing = True + self._initialization_task = asyncio.current_task() try: # Start timer @@ -126,11 +134,14 @@ class RecipeScanner: None, # Use default thread pool self._initialize_recipe_cache_sync # Run synchronous version in thread ) + if cache is not None: + self._cache = cache # Calculate elapsed time and log it elapsed_time = time.time() - start_time recipe_count = len(cache.raw_data) if cache and hasattr(cache, 'raw_data') else 0 logger.info(f"Recipe cache initialized in {elapsed_time:.2f} seconds. Found {recipe_count} recipes") + self._schedule_post_scan_enrichment() finally: # Mark initialization as complete regardless of outcome self._is_initializing = False @@ -237,6 +248,88 @@ class RecipeScanner: # Clean up the event loop loop.close() + async def _wait_for_lora_scanner(self) -> None: + """Ensure the LoRA scanner has initialized before recipe enrichment.""" + + if not getattr(self, "_lora_scanner", None): + return + + lora_scanner = self._lora_scanner + cache_ready = getattr(lora_scanner, "_cache", None) is not None + + # If cache is already available, we can proceed + if cache_ready: + return + + # Await an existing initialization task if present + task = getattr(lora_scanner, "_initialization_task", None) + if task and hasattr(task, "done") and not task.done(): + try: + await task + except Exception: # pragma: no cover - defensive guard + pass + if getattr(lora_scanner, "_cache", None) is not None: + return + + # Otherwise, request initialization and proceed once it completes + try: + await lora_scanner.initialize_in_background() + except Exception as exc: # pragma: no cover - defensive guard + logger.debug("Recipe Scanner: LoRA init request failed: %s", exc) + + def _schedule_post_scan_enrichment(self) -> None: + """Kick off a non-blocking enrichment pass to fill remote metadata.""" + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + + if self._post_scan_task and not self._post_scan_task.done(): + return + + async def _run_enrichment(): + try: + await self._enrich_cache_metadata() + except asyncio.CancelledError: + raise + except Exception as exc: # pragma: no cover - defensive guard + logger.error("Recipe Scanner: error during post-scan enrichment: %s", exc, exc_info=True) + + self._post_scan_task = loop.create_task(_run_enrichment(), name="recipe_cache_enrichment") + + async def _enrich_cache_metadata(self) -> None: + """Perform remote metadata enrichment after the initial scan.""" + + cache = self._cache + if cache is None or not getattr(cache, "raw_data", None): + return + + for index, recipe in enumerate(list(cache.raw_data)): + try: + metadata_updated = await self._update_lora_information(recipe) + if metadata_updated: + recipe_id = recipe.get("id") + if recipe_id: + recipe_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json") + if os.path.exists(recipe_path): + try: + self._write_recipe_file(recipe_path, recipe) + except Exception as exc: # pragma: no cover - best-effort persistence + logger.debug("Recipe Scanner: could not persist recipe %s: %s", recipe_id, exc) + except asyncio.CancelledError: + raise + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Recipe Scanner: error enriching recipe %s: %s", recipe.get("id"), exc, exc_info=True) + + if index % 10 == 0: + await asyncio.sleep(0) + + try: + await cache.resort() + except Exception as exc: # pragma: no cover - defensive logging + logger.debug("Recipe Scanner: error resorting cache after enrichment: %s", exc) + def _schedule_resort(self, *, name_only: bool = False) -> None: """Schedule a background resort of the recipe cache.""" @@ -438,7 +531,7 @@ class RecipeScanner: recipe_data['gen_params'] = {} # Update lora information with local paths and availability - await self._update_lora_information(recipe_data) + lora_metadata_updated = await self._update_lora_information(recipe_data) if recipe_data.get('checkpoint'): checkpoint_entry = self._normalize_checkpoint_entry(recipe_data['checkpoint']) @@ -459,6 +552,12 @@ class RecipeScanner: logger.info(f"Added fingerprint to recipe: {recipe_path}") except Exception as e: logger.error(f"Error writing updated recipe with fingerprint: {e}") + elif lora_metadata_updated: + # Persist updates such as marking invalid entries as deleted + try: + self._write_recipe_file(recipe_path, recipe_data) + except Exception as e: + logger.error(f"Error writing updated recipe metadata: {e}") return recipe_data except Exception as e: @@ -519,7 +618,13 @@ class RecipeScanner: logger.warning(f"Marked lora with modelVersionId {model_version_id} as deleted") metadata_updated = True else: - logger.debug(f"Could not get hash for modelVersionId {model_version_id}") + # No hash returned; mark as deleted to avoid repeated lookups + lora['isDeleted'] = True + metadata_updated = True + logger.warning( + "Marked lora with modelVersionId %s as deleted after failed hash lookup", + model_version_id, + ) # If has hash but no file_name, look up in lora library if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']): diff --git a/tests/routes/test_recipe_route_scaffolding.py b/tests/routes/test_recipe_route_scaffolding.py index 4baebfa4..0ff95da9 100644 --- a/tests/routes/test_recipe_route_scaffolding.py +++ b/tests/routes/test_recipe_route_scaffolding.py @@ -103,8 +103,7 @@ def test_register_startup_hooks_appends_once(): ] assert routes.attach_dependencies in startup_bound_to_routes - assert routes.prewarm_cache in startup_bound_to_routes - assert len(startup_bound_to_routes) == 2 + assert len(startup_bound_to_routes) == 1 def test_to_route_mapping_uses_handler_set(): @@ -212,4 +211,4 @@ def test_recipe_routes_setup_routes_uses_registrar(monkeypatch: pytest.MonkeyPat if isinstance(getattr(cb, "__self__", None), recipe_routes.RecipeRoutes) } assert {type(cb.__self__) for cb in recipe_callbacks} == {recipe_routes.RecipeRoutes} - assert {cb.__name__ for cb in recipe_callbacks} == {"attach_dependencies", "prewarm_cache"} + assert {cb.__name__ for cb in recipe_callbacks} == {"attach_dependencies"} diff --git a/tests/services/test_recipe_scanner.py b/tests/services/test_recipe_scanner.py index b6b5352c..cdca7c46 100644 --- a/tests/services/test_recipe_scanner.py +++ b/tests/services/test_recipe_scanner.py @@ -349,3 +349,99 @@ def test_enrich_formats_absolute_preview_paths(recipe_scanner, tmp_path): enriched = scanner._enrich_lora_entry(dict(lora)) assert enriched["preview_url"] == config.get_preview_static_url(str(preview_path)) + + +@pytest.mark.asyncio +async def test_initialize_waits_for_lora_scanner(monkeypatch): + ready_flag = asyncio.Event() + call_count = 0 + + class StubLoraScanner: + def __init__(self): + self._cache = None + self._is_initializing = True + + async def initialize_in_background(self): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0) + self._cache = SimpleNamespace(raw_data=[]) + self._is_initializing = False + ready_flag.set() + + lora_scanner = StubLoraScanner() + scanner = RecipeScanner(lora_scanner=lora_scanner) + + await scanner.initialize_in_background() + + assert ready_flag.is_set() + assert call_count == 1 + assert scanner._cache is not None + + +@pytest.mark.asyncio +async def test_invalid_model_version_marked_deleted_and_not_retried(monkeypatch, recipe_scanner): + scanner, _ = recipe_scanner + recipes_dir = Path(config.loras_roots[0]) / "recipes" + recipes_dir.mkdir(parents=True, exist_ok=True) + + recipe = { + "id": "invalid-version", + "file_path": str(recipes_dir / "invalid-version.webp"), + "title": "Invalid", + "modified": 0.0, + "created_date": 0.0, + "loras": [{"modelVersionId": 999, "file_name": "", "hash": ""}], + } + await scanner.add_recipe(dict(recipe)) + + call_count = 0 + + async def fake_get_hash(model_version_id): + nonlocal call_count + call_count += 1 + return None + + monkeypatch.setattr(scanner, "_get_hash_from_civitai", fake_get_hash) + + metadata_updated = await scanner._update_lora_information(recipe) + + assert metadata_updated is True + assert recipe["loras"][0]["isDeleted"] is True + assert call_count == 1 + + # Subsequent calls should skip remote lookup once marked deleted + metadata_updated_again = await scanner._update_lora_information(recipe) + assert metadata_updated_again is False + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_load_recipe_persists_deleted_flag_on_invalid_version(monkeypatch, recipe_scanner, tmp_path): + scanner, _ = recipe_scanner + recipes_dir = Path(config.loras_roots[0]) / "recipes" + recipes_dir.mkdir(parents=True, exist_ok=True) + + recipe_id = "persist-invalid" + recipe_path = recipes_dir / f"{recipe_id}.recipe.json" + recipe_data = { + "id": recipe_id, + "file_path": str(recipes_dir / f"{recipe_id}.webp"), + "title": "Invalid", + "modified": 0.0, + "created_date": 0.0, + "loras": [{"modelVersionId": 1234, "file_name": "", "hash": ""}], + } + recipe_path.write_text(json.dumps(recipe_data)) + + async def fake_get_hash(model_version_id): + return None + + monkeypatch.setattr(scanner, "_get_hash_from_civitai", fake_get_hash) + + loaded = await scanner._load_recipe_file(str(recipe_path)) + + assert loaded["loras"][0]["isDeleted"] is True + + persisted = json.loads(recipe_path.read_text()) + assert persisted["loras"][0]["isDeleted"] is True From 099a71b2cc50f8dc2e462a8b9c065048ed106ebd Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Tue, 16 Dec 2025 22:05:40 +0800 Subject: [PATCH 06/35] feat(config): seed root symlink mappings before deep scanning Add `_seed_root_symlink_mappings` method to ensure symlinked root folders are recorded before deep scanning, preventing them from being missed during directory traversal. This ensures that root symlinks are properly captured in the path mappings. Additionally, normalize separators in relative paths for cross-platform consistency in `BaseModelService`, and update tests to verify root symlinks are preserved in the cache. --- py/config.py | 17 +++++++++++++ py/services/base_model_service.py | 2 ++ tests/config/test_symlink_cache.py | 38 ++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+) diff --git a/py/config.py b/py/config.py index bdd9a591..c34952d5 100644 --- a/py/config.py +++ b/py/config.py @@ -355,6 +355,7 @@ class Config: start = time.perf_counter() # Reset mappings before rescanning to avoid stale entries self._path_mappings.clear() + self._seed_root_symlink_mappings() visited_dirs: Set[str] = set() for root in self._symlink_roots(): self._scan_directory_links(root, visited_dirs) @@ -458,6 +459,22 @@ class Config: self._preview_root_paths.update(self._expand_preview_root(normalized_target)) self._preview_root_paths.update(self._expand_preview_root(normalized_link)) + def _seed_root_symlink_mappings(self) -> None: + """Ensure symlinked root folders are recorded before deep scanning.""" + + for root in self._symlink_roots(): + if not root: + continue + try: + if not self._is_link(root): + continue + target_path = os.path.realpath(root) + if not os.path.isdir(target_path): + continue + self.add_path_mapping(root, target_path) + except Exception as exc: + logger.debug("Skipping root symlink %s: %s", root, exc) + def _expand_preview_root(self, path: str) -> Set[Path]: """Return normalized ``Path`` objects representing a preview root.""" diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 84db592b..7f1d9bb6 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -716,6 +716,8 @@ class BaseModelService(ABC): if normalized_file.startswith(normalized_root): # Remove root and leading separator to get relative path relative_path = normalized_file[len(normalized_root):].lstrip(os.sep) + # Normalize separators so results are stable across platforms + relative_path = relative_path.replace(os.sep, "/") break if not relative_path: diff --git a/tests/config/test_symlink_cache.py b/tests/config/test_symlink_cache.py index caf50195..5ba9820b 100644 --- a/tests/config/test_symlink_cache.py +++ b/tests/config/test_symlink_cache.py @@ -1,3 +1,4 @@ +import json import os import pytest @@ -143,3 +144,40 @@ def test_background_rescan_refreshes_cache(monkeypatch: pytest.MonkeyPatch, tmp_ new_real = _normalize(os.path.realpath(new_target)) assert second_cfg._path_mappings.get(new_real) == _normalize(str(dir_link)) assert second_cfg.map_path_to_link(str(new_target)) == _normalize(str(dir_link)) + + +def test_symlink_roots_are_preserved(monkeypatch: pytest.MonkeyPatch, tmp_path): + settings_dir = tmp_path / "settings" + real_loras = tmp_path / "loras_real" + real_loras.mkdir() + loras_link = tmp_path / "loras_link" + loras_link.symlink_to(real_loras, target_is_directory=True) + + checkpoints_dir = tmp_path / "checkpoints" + checkpoints_dir.mkdir() + embedding_dir = tmp_path / "embeddings" + embedding_dir.mkdir() + + def fake_get_folder_paths(kind: str): + mapping = { + "loras": [str(loras_link)], + "checkpoints": [str(checkpoints_dir)], + "unet": [], + "embeddings": [str(embedding_dir)], + } + return mapping.get(kind, []) + + monkeypatch.setattr(config_module.folder_paths, "get_folder_paths", fake_get_folder_paths) + monkeypatch.setattr(config_module, "standalone_mode", True) + monkeypatch.setattr(config_module, "get_settings_dir", lambda create=True: str(settings_dir)) + monkeypatch.setattr(config_module.Config, "_schedule_symlink_rescan", lambda self: None) + + cfg = config_module.Config() + + normalized_real = _normalize(os.path.realpath(real_loras)) + normalized_link = _normalize(str(loras_link)) + assert cfg._path_mappings[normalized_real] == normalized_link + + cache_path = settings_dir / "cache" / "symlink_map.json" + payload = json.loads(cache_path.read_text(encoding="utf-8")) + assert payload["path_mappings"][normalized_real] == normalized_link From bdb4422cbc19d4b69d72120ba8296d3771f6746e Mon Sep 17 00:00:00 2001 From: Will Miao Date: Wed, 17 Dec 2025 10:34:04 +0800 Subject: [PATCH 07/35] feat(ui): adjust modal header width and enhance close button z-index, fixes #729 - Decrease modal header width from 85% to 84% for better visual alignment - Add z-index: 10 to close button to ensure it remains above other modal elements --- static/css/components/lora-modal/lora-modal.css | 2 +- static/css/components/modal/_base.css | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/static/css/components/lora-modal/lora-modal.css b/static/css/components/lora-modal/lora-modal.css index cdfe5aaf..71667045 100644 --- a/static/css/components/lora-modal/lora-modal.css +++ b/static/css/components/lora-modal/lora-modal.css @@ -20,7 +20,7 @@ } .modal-header-row { - width: 85%; + width: 84%; display: flex; align-items: flex-start; gap: var(--space-2); diff --git a/static/css/components/modal/_base.css b/static/css/components/modal/_base.css index c869e493..eeadb450 100644 --- a/static/css/components/modal/_base.css +++ b/static/css/components/modal/_base.css @@ -122,6 +122,7 @@ body.modal-open { cursor: pointer; opacity: 0.7; transition: opacity 0.2s; + z-index: 10; } .close:hover { From a07720a3bf1eb1be41ff10e3b85a1eb3428093dd Mon Sep 17 00:00:00 2001 From: Will Miao Date: Wed, 17 Dec 2025 12:52:52 +0800 Subject: [PATCH 08/35] feat: Add model path tracing to accurately identify the primary checkpoint in workflows and include new tests. --- py/metadata_collector/metadata_processor.py | 82 +++++++++- tests/metadata_collector/test_tracer.py | 172 ++++++++++++++++++++ 2 files changed, 249 insertions(+), 5 deletions(-) create mode 100644 tests/metadata_collector/test_tracer.py diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 587bcf12..c74cd23a 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -202,12 +202,84 @@ class MetadataProcessor: return last_valid_node if not target_class else None @staticmethod - def find_primary_checkpoint(metadata): - """Find the primary checkpoint model in the workflow""" - if not metadata.get(MODELS): + def trace_model_path(metadata, prompt, start_node_id): + """ + Trace the model connection path upstream to find the checkpoint + """ + if not prompt or not prompt.original_prompt: return None - # In most workflows, there's only one checkpoint, so we can just take the first one + current_node_id = start_node_id + depth = 0 + max_depth = 50 + + while depth < max_depth: + # Check if current node is a registered checkpoint in our metadata + # This handles cached nodes correctly because metadata contains info for all nodes in the graph + if current_node_id in metadata.get(MODELS, {}): + if metadata[MODELS][current_node_id].get("type") == "checkpoint": + return current_node_id + + if current_node_id not in prompt.original_prompt: + return None + + node = prompt.original_prompt[current_node_id] + inputs = node.get("inputs", {}) + class_type = node.get("class_type", "") + + # Determine which input to follow next + next_input_name = "model" + + # Special handling for initial node + if depth == 0: + if class_type == "SamplerCustomAdvanced": + next_input_name = "guider" + + # If the specific input doesn't exist, try generic 'model' + if next_input_name not in inputs: + if "model" in inputs: + next_input_name = "model" + else: + # Dead end - no model input to follow + return None + + # Get connected node + input_val = inputs[next_input_name] + if isinstance(input_val, list) and len(input_val) > 0: + current_node_id = input_val[0] + else: + return None + + depth += 1 + + return None + + @staticmethod + def find_primary_checkpoint(metadata, downstream_id=None): + """ + Find the primary checkpoint model in the workflow + + Parameters: + - metadata: The workflow metadata + - downstream_id: Optional ID of a downstream node to help identify the specific primary sampler + """ + if not metadata.get(MODELS): + return None + + # Method 1: Topology-based tracing (More accurate for complex workflows) + # First, find the primary sampler + primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id) + + if primary_sampler_id: + prompt = metadata.get("current_prompt") + if prompt: + # Trace back from the sampler to find the checkpoint + checkpoint_id = MetadataProcessor.trace_model_path(metadata, prompt, primary_sampler_id) + if checkpoint_id and checkpoint_id in metadata.get(MODELS, {}): + return metadata[MODELS][checkpoint_id].get("name") + + # Method 2: Fallback to the first available checkpoint (Original behavior) + # In most simple workflows, there's only one checkpoint, so we can just take the first one for node_id, model_info in metadata.get(MODELS, {}).items(): if model_info.get("type") == "checkpoint": return model_info.get("name") @@ -311,7 +383,7 @@ class MetadataProcessor: primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id) # Directly get checkpoint from metadata instead of tracing - checkpoint = MetadataProcessor.find_primary_checkpoint(metadata) + checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id) if checkpoint: params["checkpoint"] = checkpoint diff --git a/tests/metadata_collector/test_tracer.py b/tests/metadata_collector/test_tracer.py new file mode 100644 index 00000000..5fca8c5c --- /dev/null +++ b/tests/metadata_collector/test_tracer.py @@ -0,0 +1,172 @@ + +import pytest +from types import SimpleNamespace +from py.metadata_collector.metadata_processor import MetadataProcessor +from py.metadata_collector.constants import MODELS, SAMPLING, IS_SAMPLER + +class TestMetadataTracer: + + @pytest.fixture + def mock_workflow_metadata(self): + """ + Creates a mock metadata structure with a complex workflow graph. + Structure: + Sampler(246) -> Guider(241) -> LoraLoader(264) -> CheckpointLoader(238) + + Also includes a "Decoy" checkpoint (ID 999) that is NOT connected, + to verify we found the *connected* one, not just *any* one. + """ + + # 1. Define the Graph (Original Prompt) + # Using IDs as strings to match typical ComfyUI behavior in metadata + original_prompt = { + "246": { + "class_type": "SamplerCustomAdvanced", + "inputs": { + "guider": ["241", 0], + "noise": ["255", 0], + "sampler": ["247", 0], + "sigmas": ["248", 0], + "latent_image": ["153", 0] + } + }, + "241": { + "class_type": "CFGGuider", + "inputs": { + "model": ["264", 0], + "positive": ["239", 0], + "negative": ["240", 0] + } + }, + "264": { + "class_type": "LoraLoader", # Simplified name + "inputs": { + "model": ["238", 0], + "lora_name": "some_style_lora.safetensors" + } + }, + "238": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "Correct_Model.safetensors" + } + }, + + # unconnected / decoy nodes + "999": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "Decoy_Model.safetensors" + } + }, + "154": { # Downstream VAE Decode + "class_type": "VAEDecode", + "inputs": { + "samples": ["246", 0] + } + } + } + + # 2. Define the Metadata (Collected execution data) + metadata = { + "current_prompt": SimpleNamespace(original_prompt=original_prompt), + "execution_order": ["238", "264", "241", "246", "154", "999"], # 999 execs last or separately + + # Models Registry + MODELS: { + "238": { + "type": "checkpoint", + "name": "Correct_Model.safetensors" + }, + "999": { + "type": "checkpoint", + "name": "Decoy_Model.safetensors" + } + }, + + # Sampling Registry + SAMPLING: { + "246": { + IS_SAMPLER: True, + "parameters": { + "sampler_name": "euler", + "scheduler": "normal" + } + } + }, + "images": { + "first_decode": { + "node_id": "154" + } + } + } + + return metadata + + def test_find_primary_sampler_identifies_correct_node(self, mock_workflow_metadata): + """Verify find_primary_sampler correctly identifies the sampler connected to the downstream decode.""" + sampler_id, sampler_info = MetadataProcessor.find_primary_sampler(mock_workflow_metadata, downstream_id="154") + + assert sampler_id == "246" + assert sampler_info is not None + assert sampler_info["parameters"]["sampler_name"] == "euler" + + def test_trace_model_path_follows_topology(self, mock_workflow_metadata): + """Verify trace_model_path follows: Sampler -> Guider -> Lora -> Checkpoint.""" + prompt = mock_workflow_metadata["current_prompt"] + + # Start trace from Sampler (246) + # Should find Checkpoint (238) + ckpt_id = MetadataProcessor.trace_model_path(mock_workflow_metadata, prompt, "246") + + assert ckpt_id == "238" # Should be the ID of the connected checkpoint + + def test_find_primary_checkpoint_prioritizes_connected_model(self, mock_workflow_metadata): + """Verify find_primary_checkpoint returns the NAME of the topologically connected checkpoint, honoring the graph.""" + name = MetadataProcessor.find_primary_checkpoint(mock_workflow_metadata, downstream_id="154") + + assert name == "Correct_Model.safetensors" + assert name != "Decoy_Model.safetensors" + + def test_trace_model_path_simple_direct_connection(self): + """Verify it works for a simple Sampler -> Checkpoint connection.""" + original_prompt = { + "100": { # Sampler + "class_type": "KSampler", + "inputs": { + "model": ["101", 0] + } + }, + "101": { # Checkpoint + "class_type": "CheckpointLoaderSimple", + "inputs": {} + } + } + + metadata = { + "current_prompt": SimpleNamespace(original_prompt=original_prompt), + MODELS: { + "101": {"type": "checkpoint", "name": "Simple_Model.safetensors"} + } + } + + ckpt_id = MetadataProcessor.trace_model_path(metadata, metadata["current_prompt"], "100") + assert ckpt_id == "101" + + def test_trace_stops_at_max_depth(self): + """Verify logic halts if graph is infinitely cyclic or too deep.""" + # Create a cycle: Node 1 -> Node 2 -> Node 1 + original_prompt = { + "1": {"inputs": {"model": ["2", 0]}}, + "2": {"inputs": {"model": ["1", 0]}} + } + + metadata = { + "current_prompt": SimpleNamespace(original_prompt=original_prompt), + MODELS: {} # No checkpoints registered + } + + # Should return None, not hang forever + ckpt_id = MetadataProcessor.trace_model_path(metadata, metadata["current_prompt"], "1") + assert ckpt_id is None + From ca6bb43406770f8b70141401f5d819dc8f43875e Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 17 Dec 2025 19:07:08 +0800 Subject: [PATCH 09/35] feat: remove path separator normalization for cross-platform compatibility Removed the forced normalization of path separators to forward slashes in BaseModelService to maintain platform-specific separators. Updated test cases to use os.sep for constructing expected paths, ensuring tests work correctly across different operating systems while preserving native path representations. --- py/services/base_model_service.py | 2 -- tests/services/test_relative_path_search.py | 7 ++++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 7f1d9bb6..84db592b 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -716,8 +716,6 @@ class BaseModelService(ABC): if normalized_file.startswith(normalized_root): # Remove root and leading separator to get relative path relative_path = normalized_file[len(normalized_root):].lstrip(os.sep) - # Normalize separators so results are stable across platforms - relative_path = relative_path.replace(os.sep, "/") break if not relative_path: diff --git a/tests/services/test_relative_path_search.py b/tests/services/test_relative_path_search.py index e26c98ea..0e10039c 100644 --- a/tests/services/test_relative_path_search.py +++ b/tests/services/test_relative_path_search.py @@ -1,3 +1,4 @@ +import os import pytest from py.services.base_model_service import BaseModelService @@ -42,8 +43,8 @@ async def test_search_relative_paths_supports_multiple_tokens(): matching = await service.search_relative_paths("flux detail") assert matching == [ - "flux/detail-model.safetensors", - "detail/flux-trained.safetensors", + f"flux{os.sep}detail-model.safetensors", + f"detail{os.sep}flux-trained.safetensors", ] @@ -60,4 +61,4 @@ async def test_search_relative_paths_excludes_tokens(): matching = await service.search_relative_paths("flux -detail") - assert matching == ["flux/keep-me.safetensors"] + assert matching == [f"flux{os.sep}keep-me.safetensors"] From c8a179488aa2f8723bb7bbd040ee2c8475aa4724 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 18 Dec 2025 22:30:41 +0800 Subject: [PATCH 10/35] feat(metadata): enhance primary sampler detection and workflow tracing - Add support for `basic_pipe` nodes in metadata processor to handle pipeline nodes like FromBasicPipe - Optimize `find_primary_checkpoint` by accepting optional `primary_sampler_id` to avoid redundant calculations - Update `get_workflow_trace` to pass known primary sampler ID for improved efficiency --- py/metadata_collector/metadata_processor.py | 14 ++- tests/metadata_collector/test_pipe_tracer.py | 98 ++++++++++++++++++++ 2 files changed, 108 insertions(+), 4 deletions(-) create mode 100644 tests/metadata_collector/test_pipe_tracer.py diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index c74cd23a..9dd85542 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -239,6 +239,9 @@ class MetadataProcessor: if next_input_name not in inputs: if "model" in inputs: next_input_name = "model" + elif "basic_pipe" in inputs: + # Handle pipe nodes like FromBasicPipe by following the pipeline + next_input_name = "basic_pipe" else: # Dead end - no model input to follow return None @@ -255,20 +258,22 @@ class MetadataProcessor: return None @staticmethod - def find_primary_checkpoint(metadata, downstream_id=None): + def find_primary_checkpoint(metadata, downstream_id=None, primary_sampler_id=None): """ Find the primary checkpoint model in the workflow Parameters: - metadata: The workflow metadata - downstream_id: Optional ID of a downstream node to help identify the specific primary sampler + - primary_sampler_id: Optional ID of the primary sampler if already known """ if not metadata.get(MODELS): return None # Method 1: Topology-based tracing (More accurate for complex workflows) - # First, find the primary sampler - primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id) + # First, find the primary sampler if not provided + if not primary_sampler_id: + primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id) if primary_sampler_id: prompt = metadata.get("current_prompt") @@ -383,7 +388,8 @@ class MetadataProcessor: primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id) # Directly get checkpoint from metadata instead of tracing - checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id) + # Pass primary_sampler_id to avoid redundant calculation + checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id, primary_sampler_id) if checkpoint: params["checkpoint"] = checkpoint diff --git a/tests/metadata_collector/test_pipe_tracer.py b/tests/metadata_collector/test_pipe_tracer.py new file mode 100644 index 00000000..ddad5b98 --- /dev/null +++ b/tests/metadata_collector/test_pipe_tracer.py @@ -0,0 +1,98 @@ + +import pytest +from types import SimpleNamespace +from py.metadata_collector.metadata_processor import MetadataProcessor +from py.metadata_collector.constants import MODELS, SAMPLING, IS_SAMPLER + +class TestPipeTracer: + + @pytest.fixture + def pipe_workflow_metadata(self): + """ + Creates a mock metadata structure matching the one provided in refs/tmp. + Structure: + Load Checkpoint(28) -> Lora Loader(52) -> ToBasicPipe(69) -> FromBasicPipe(71) -> KSampler(32) + """ + + original_prompt = { + '28': { + 'inputs': {'ckpt_name': 'Illustrious\\bananaSplitzXL_vee5PointOh.safetensors'}, + 'class_type': 'CheckpointLoaderSimple' + }, + '52': { + 'inputs': { + 'model': ['28', 0], + 'clip': ['28', 1] + }, + 'class_type': 'Lora Loader (LoraManager)' + }, + '69': { + 'inputs': { + 'model': ['52', 0], + 'clip': ['52', 1], + 'vae': ['28', 2], + 'positive': ['75', 0], + 'negative': ['30', 0] + }, + 'class_type': 'ToBasicPipe' + }, + '71': { + 'inputs': {'basic_pipe': ['69', 0]}, + 'class_type': 'FromBasicPipe' + }, + '32': { + 'inputs': { + 'seed': 131755205602911, + 'steps': 5, + 'cfg': 8.0, + 'sampler_name': 'euler_ancestral', + 'scheduler': 'karras', + 'denoise': 1.0, + 'model': ['71', 0], + 'positive': ['71', 3], + 'negative': ['71', 4], + 'latent_image': ['76', 0] + }, + 'class_type': 'KSampler' + }, + '75': {'inputs': {'text': 'positive', 'clip': ['52', 1]}, 'class_type': 'CLIPTextEncode'}, + '30': {'inputs': {'text': 'negative', 'clip': ['52', 1]}, 'class_type': 'CLIPTextEncode'}, + '76': {'inputs': {'width': 832, 'height': 1216, 'batch_size': 1}, 'class_type': 'EmptyLatentImage'} + } + + metadata = { + "current_prompt": SimpleNamespace(original_prompt=original_prompt), + MODELS: { + "28": { + "type": "checkpoint", + "name": "bananaSplitzXL_vee5PointOh.safetensors" + } + }, + SAMPLING: { + "32": { + IS_SAMPLER: True, + "parameters": { + "sampler_name": "euler_ancestral", + "scheduler": "karras" + } + } + } + } + + return metadata + + def test_trace_model_path_through_pipe(self, pipe_workflow_metadata): + """Verify trace_model_path can follow: KSampler -> FromBasicPipe -> ToBasicPipe -> Lora -> Checkpoint.""" + prompt = pipe_workflow_metadata["current_prompt"] + + # Start trace from KSampler (32) + ckpt_id = MetadataProcessor.trace_model_path(pipe_workflow_metadata, prompt, "32") + + assert ckpt_id == "28" + + def test_find_primary_checkpoint_with_pipe(self, pipe_workflow_metadata): + """Verify find_primary_checkpoint returns the correct name even with pipe nodes.""" + # Providing sampler_id to test the optimization as well + name = MetadataProcessor.find_primary_checkpoint(pipe_workflow_metadata, primary_sampler_id="32") + + assert name == "bananaSplitzXL_vee5PointOh.safetensors" From 154ae825193a696d348a722ca5b98884c9e28073 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 19 Dec 2025 01:30:08 +0800 Subject: [PATCH 11/35] feat(metadata_processor): enhance primary sampler selection logic - Add pre-processing step to populate missing parameters for candidate samplers, especially for SamplerCustomAdvanced requiring tracing - Change sampler selection from most recent (closest to downstream) to first in execution order to prioritize base samplers over refine samplers - Improve parameter handling by updating sampler parameters with traced values before ranking - Maintain backward compatibility with fallback to first sampler if no criteria match --- py/metadata_collector/metadata_processor.py | 67 +++++++++++++++++---- 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 9dd85542..2d39f2ba 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -39,8 +39,39 @@ class MetadataProcessor: if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False): candidate_samplers[node_id] = metadata[SAMPLING][node_id] - # If we found candidate samplers, apply primary sampler logic to these candidates only - if candidate_samplers: + # If we found candidate samplers, apply primary sampler logic to these candidates only + + # PRE-PROCESS: Ensure all candidate samplers have their parameters populated + # This is especially important for SamplerCustomAdvanced which needs tracing + prompt = metadata.get("current_prompt") + for node_id in candidate_samplers: + # If a sampler is missing common parameters like steps or denoise, + # try to populate them using tracing before ranking + sampler_info = candidate_samplers[node_id] + params = sampler_info.get("parameters", {}) + + if prompt and (params.get("steps") is None or params.get("denoise") is None): + # Create a temporary params dict to use the handler + temp_params = { + "steps": params.get("steps"), + "denoise": params.get("denoise"), + "sampler": params.get("sampler_name"), + "scheduler": params.get("scheduler") + } + + # Check if it's SamplerCustomAdvanced + if prompt.original_prompt and node_id in prompt.original_prompt: + if prompt.original_prompt[node_id].get("class_type") == "SamplerCustomAdvanced": + MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, node_id, temp_params) + + # Update the actual parameters with found values + params["steps"] = temp_params.get("steps") + params["denoise"] = temp_params.get("denoise") + if temp_params.get("sampler"): + params["sampler_name"] = temp_params.get("sampler") + if temp_params.get("scheduler"): + params["scheduler"] = temp_params.get("scheduler") + # Collect potential primary samplers based on different criteria custom_advanced_samplers = [] advanced_add_noise_samplers = [] @@ -49,7 +80,6 @@ class MetadataProcessor: high_denoise_id = None # First, check for SamplerCustomAdvanced among candidates - prompt = metadata.get("current_prompt") if prompt and prompt.original_prompt: for node_id in candidate_samplers: node_info = prompt.original_prompt.get(node_id, {}) @@ -77,15 +107,16 @@ class MetadataProcessor: # Combine all potential primary samplers potential_samplers = custom_advanced_samplers + advanced_add_noise_samplers + high_denoise_samplers - # Find the most recent potential primary sampler (closest to downstream node) - for i in range(downstream_index - 1, -1, -1): + # Find the first potential primary sampler (prefer base sampler over refine) + # Use forward search to prioritize the first one in execution order + for i in range(downstream_index): node_id = execution_order[i] if node_id in potential_samplers: return node_id, candidate_samplers[node_id] - # If no potential sampler found from our criteria, return the most recent sampler + # If no potential sampler found from our criteria, return the first sampler if candidate_samplers: - for i in range(downstream_index - 1, -1, -1): + for i in range(downstream_index): node_id = execution_order[i] if node_id in candidate_samplers: return node_id, candidate_samplers[node_id] @@ -176,8 +207,11 @@ class MetadataProcessor: found_node_id = input_value[0] # Connected node_id # If we're looking for a specific node class - if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class: - return found_node_id + if target_class: + if found_node_id not in prompt.original_prompt: + return None + if prompt.original_prompt[found_node_id].get("class_type") == target_class: + return found_node_id # If we're not looking for a specific class, update the last valid node if not target_class: @@ -185,11 +219,19 @@ class MetadataProcessor: # Continue tracing through intermediate nodes current_node_id = found_node_id - # For most conditioning nodes, the input we want to follow is named "conditioning" - if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}): + + # Check if current source node exists + if current_node_id not in prompt.original_prompt: + return found_node_id if not target_class else None + + # Determine which input to follow next on the source node + source_node_inputs = prompt.original_prompt[current_node_id].get("inputs", {}) + if input_name in source_node_inputs: + current_input = input_name + elif "conditioning" in source_node_inputs: current_input = "conditioning" else: - # If there's no "conditioning" input, return the current node + # If there's no suitable input to follow, return the current node # if we're not looking for a specific target_class return found_node_id if not target_class else None else: @@ -523,6 +565,7 @@ class MetadataProcessor: scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {}) params["steps"] = scheduler_params.get("steps") params["scheduler"] = scheduler_params.get("scheduler") + params["denoise"] = scheduler_params.get("denoise") # 2. Trace sampler input to find KSamplerSelect (only if sampler input exists) if "sampler" in sampler_inputs: From 63b087fc80d89ba913b635cdda0b36579d7688d8 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Fri, 19 Dec 2025 22:40:36 +0800 Subject: [PATCH 12/35] feat: Implement cache busting for static assets, remove client-side version mismatch banner, and add project overview documentation. --- GEMINI.md | 84 +++++++++++++++++++++++++++ py/routes/handlers/model_handlers.py | 32 ++++++++++ static/js/core.js | 5 +- static/js/managers/BannerService.js | 2 +- static/js/managers/UpdateService.js | 87 +--------------------------- templates/base.html | 6 +- templates/checkpoints.html | 2 +- templates/embeddings.html | 2 +- templates/loras.html | 2 +- templates/recipes.html | 8 +-- templates/statistics.html | 2 +- tests/frontend/core/appCore.test.js | 10 ---- 12 files changed, 132 insertions(+), 110 deletions(-) create mode 100644 GEMINI.md diff --git a/GEMINI.md b/GEMINI.md new file mode 100644 index 00000000..c0239352 --- /dev/null +++ b/GEMINI.md @@ -0,0 +1,84 @@ +# ComfyUI LoRA Manager + +## Project Overview + +ComfyUI LoRA Manager is a comprehensive extension for ComfyUI that streamlines the organization, downloading, and application of LoRA models. It functions as both a custom node within ComfyUI and a standalone application. + +**Key Features:** +* **Model Management:** Browse, organize, and download LoRA models (and Checkpoints/Embeddings) from Civitai and CivArchive. +* **Visualization:** Preview images, videos, and trigger words. +* **Workflow Integration:** "One-click" integration into ComfyUI workflows, preserving generation parameters. +* **Recipe System:** Save and share LoRA combinations as "recipes". +* **Architecture:** Hybrid Python backend (API, file management) and JavaScript/HTML frontend (Web UI). + +## Directory Structure + +* `py/`: Core Python backend source code. + * `lora_manager.py`: Main entry point for the ComfyUI node. + * `routes/`: API route definitions (using `aiohttp` in standalone, or ComfyUI's server). + * `services/`: Business logic (downloading, metadata, scanning). + * `nodes/`: ComfyUI custom node implementations. +* `static/`: Frontend static assets (CSS, JS, Images). +* `templates/`: HTML templates (Jinja2). +* `locales/`: Internationalization JSON files. +* `web/comfyui/`: JavaScript extensions specifically for the ComfyUI interface. +* `standalone.py`: Entry point for running the manager as a standalone web app. +* `tests/`: Backend tests. +* `requirements.txt`: Python runtime dependencies. +* `package.json`: Frontend development dependencies and test scripts. + +## Building and Running + +### Prerequisites +* Python 3.8+ +* Node.js (only for running frontend tests) + +### Backend Setup +1. Install Python dependencies: + ```bash + pip install -r requirements.txt + ``` + +### Running in Standalone Mode +You can run the manager independently of ComfyUI for development or management purposes. +```bash +python standalone.py --port 8188 +``` + +### Running in ComfyUI +Ensure the folder is located in `ComfyUI/custom_nodes/`. ComfyUI will automatically load it upon startup. + +## Testing + +### Backend Tests (Pytest) +1. Install development dependencies: + ```bash + pip install -r requirements-dev.txt + ``` +2. Run tests: + ```bash + pytest + ``` + * Coverage reports are generated in `coverage/backend/`. + +### Frontend Tests (Vitest) +1. Install Node dependencies: + ```bash + npm install + ``` +2. Run tests: + ```bash + npm run test + ``` +3. Run coverage: + ```bash + npm run test:coverage + ``` + +## Development Conventions + +* **Python Style:** Follow PEP 8. Use snake_case for files/functions and PascalCase for classes. +* **Frontend:** Standard ES modules. UI components often end in `_widget.js`. +* **Configuration:** User settings are stored in `settings.json`. Developers should reference `settings.json.example`. +* **Localization:** Update `locales/.json` and run `scripts/sync_translation_keys.py` when changing UI text. +* **Documentation:** Architecture details are in `docs/architecture/` and `IFLOW.md`. diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index b07bac01..babf5b63 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -61,6 +61,37 @@ class ModelPageView: self._settings = settings_service self._server_i18n = server_i18n self._logger = logger + self._app_version = self._get_app_version() + + def _get_app_version(self) -> str: + version = "1.0.0" + short_hash = "stable" + try: + import toml + current_file = os.path.abspath(__file__) + # Navigate up from py/routes/handlers/model_handlers.py to project root + root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file)))) + pyproject_path = os.path.join(root_dir, 'pyproject.toml') + + if os.path.exists(pyproject_path): + with open(pyproject_path, 'r', encoding='utf-8') as f: + data = toml.load(f) + version = data.get('project', {}).get('version', '1.0.0').replace('v', '') + + # Try to get git info for granular cache busting + git_dir = os.path.join(root_dir, '.git') + if os.path.exists(git_dir): + try: + import git + repo = git.Repo(root_dir) + short_hash = repo.head.commit.hexsha[:7] + except Exception: + # Fallback if git is not available or not a repo + pass + except Exception as e: + self._logger.debug(f"Failed to read version info for cache busting: {e}") + + return f"{version}-{short_hash}" async def handle(self, request: web.Request) -> web.Response: try: @@ -96,6 +127,7 @@ class ModelPageView: "request": request, "folders": [], "t": self._server_i18n.get_translation, + "version": self._app_version, } if not is_initializing: diff --git a/static/js/core.js b/static/js/core.js index 9701a32f..11c09a1a 100644 --- a/static/js/core.js +++ b/static/js/core.js @@ -84,10 +84,7 @@ export class AppCore { // Start onboarding if needed (after everything is initialized) setTimeout(() => { - // Do not show onboarding if version-mismatch banner is visible - if (!bannerService.isBannerVisible('version-mismatch')) { - onboardingManager.start(); - } + onboardingManager.start(); }, 1000); // Small delay to ensure all elements are rendered // Return the core instance for chaining diff --git a/static/js/managers/BannerService.js b/static/js/managers/BannerService.js index f4461fdc..5b0947a9 100644 --- a/static/js/managers/BannerService.js +++ b/static/js/managers/BannerService.js @@ -17,7 +17,7 @@ const AFDIAN_URL = 'https://afdian.com/a/pixelpawsai'; const BANNER_HISTORY_KEY = 'banner_history'; const BANNER_HISTORY_VIEWED_AT_KEY = 'banner_history_viewed_at'; const BANNER_HISTORY_LIMIT = 20; -const HISTORY_EXCLUDED_IDS = new Set(['version-mismatch']); +const HISTORY_EXCLUDED_IDS = new Set([]); /** * Banner Service for managing notification banners diff --git a/static/js/managers/UpdateService.js b/static/js/managers/UpdateService.js index d93784fa..0cfb7619 100644 --- a/static/js/managers/UpdateService.js +++ b/static/js/managers/UpdateService.js @@ -4,8 +4,7 @@ import { setStorageItem, getStoredVersionInfo, setStoredVersionInfo, - isVersionMatch, - resetDismissedBanner + isVersionMatch } from '../utils/storageHelpers.js'; import { bannerService } from './BannerService.js'; import { translate } from '../utils/i18nHelpers.js'; @@ -753,94 +752,14 @@ export class UpdateService { stored: getStoredVersionInfo() }); - // Reset dismissed status for version mismatch banner - resetDismissedBanner('version-mismatch'); - - // Register and show the version mismatch banner - this.registerVersionMismatchBanner(); + // Silently update stored version info as cache busting handles the resource updates + setStoredVersionInfo(this.currentVersionInfo); } } } catch (error) { console.error('Failed to check version info:', error); } } - - registerVersionMismatchBanner() { - // Get stored and current version for display - const storedVersion = getStoredVersionInfo() || translate('common.status.unknown'); - const currentVersion = this.currentVersionInfo || translate('common.status.unknown'); - - bannerService.registerBanner('version-mismatch', { - id: 'version-mismatch', - title: translate('banners.versionMismatch.title', {}, 'Application Update Detected'), - content: translate('banners.versionMismatch.content', { - storedVersion, - currentVersion - }, `Your browser is running an outdated version of LoRA Manager (${storedVersion}). The server has been updated to version ${currentVersion}. Please refresh to ensure proper functionality.`), - actions: [ - { - text: translate('banners.versionMismatch.refreshNow', {}, 'Refresh Now'), - icon: 'fas fa-sync', - action: 'hardRefresh', - type: 'primary' - } - ], - dismissible: false, - priority: 10, - countdown: 15, - onRegister: (bannerElement) => { - // Add countdown element - const countdownEl = document.createElement('div'); - countdownEl.className = 'banner-countdown'; - countdownEl.innerHTML = `${translate('banners.versionMismatch.refreshingIn', {}, 'Refreshing in')} 15 ${translate('banners.versionMismatch.seconds', {}, 'seconds')}...`; - bannerElement.querySelector('.banner-content').appendChild(countdownEl); - - // Start countdown - let seconds = 15; - const countdownInterval = setInterval(() => { - seconds--; - const strongEl = countdownEl.querySelector('strong'); - if (strongEl) strongEl.textContent = seconds; - - if (seconds <= 0) { - clearInterval(countdownInterval); - this.performHardRefresh(); - } - }, 1000); - - // Store interval ID for cleanup - bannerElement.dataset.countdownInterval = countdownInterval; - - // Add action button event handler - const actionBtn = bannerElement.querySelector('.banner-action[data-action="hardRefresh"]'); - if (actionBtn) { - actionBtn.addEventListener('click', (e) => { - e.preventDefault(); - clearInterval(countdownInterval); - this.performHardRefresh(); - }); - } - }, - onRemove: (bannerElement) => { - // Clear any existing interval - const intervalId = bannerElement.dataset.countdownInterval; - if (intervalId) { - clearInterval(parseInt(intervalId)); - } - } - }); - } - - performHardRefresh() { - // Update stored version info before refreshing - setStoredVersionInfo(this.currentVersionInfo); - - // Force a hard refresh by adding cache-busting parameter - const cacheBuster = new Date().getTime(); - window.location.href = window.location.pathname + - (window.location.search ? window.location.search + '&' : '?') + - `cache=${cacheBuster}`; - } } // Create and export singleton instance diff --git a/templates/base.html b/templates/base.html index b3beeff2..573e2112 100644 --- a/templates/base.html +++ b/templates/base.html @@ -4,8 +4,8 @@ {% block title %}{{ t('header.appTitle') }}{% endblock %} - - + + {% block page_css %}{% endblock %} - + {% else %} {% block main_script %}{% endblock %} {% endif %} diff --git a/templates/checkpoints.html b/templates/checkpoints.html index ffa30aff..fd3498da 100644 --- a/templates/checkpoints.html +++ b/templates/checkpoints.html @@ -40,5 +40,5 @@ {% endblock %} {% block main_script %} - + {% endblock %} diff --git a/templates/embeddings.html b/templates/embeddings.html index badf12e3..de8b807a 100644 --- a/templates/embeddings.html +++ b/templates/embeddings.html @@ -40,5 +40,5 @@ {% endblock %} {% block main_script %} - + {% endblock %} diff --git a/templates/loras.html b/templates/loras.html index 3ece6a10..5ede68d9 100644 --- a/templates/loras.html +++ b/templates/loras.html @@ -24,6 +24,6 @@ {% block main_script %} {% if not is_initializing %} - + {% endif %} {% endblock %} \ No newline at end of file diff --git a/templates/recipes.html b/templates/recipes.html index 58003607..202791a2 100644 --- a/templates/recipes.html +++ b/templates/recipes.html @@ -4,9 +4,9 @@ {% block page_id %}recipes{% endblock %} {% block page_css %} - - - + + + {% endblock %} {% block additional_components %} @@ -84,5 +84,5 @@ {% endblock %} {% block main_script %} - + {% endblock %} \ No newline at end of file diff --git a/templates/statistics.html b/templates/statistics.html index e8d73ed7..1986ac3e 100644 --- a/templates/statistics.html +++ b/templates/statistics.html @@ -192,6 +192,6 @@ {% block main_script %} {% if not is_initializing %} - + {% endif %} {% endblock %} \ No newline at end of file diff --git a/tests/frontend/core/appCore.test.js b/tests/frontend/core/appCore.test.js index 094200f1..ad3d4bf6 100644 --- a/tests/frontend/core/appCore.test.js +++ b/tests/frontend/core/appCore.test.js @@ -234,7 +234,6 @@ describe('AppCore initialization flow', () => { await vi.runAllTimersAsync(); expect(onboardingManager.start).toHaveBeenCalledTimes(1); - expect(bannerService.isBannerVisible).toHaveBeenCalledWith('version-mismatch'); }); it('does not reinitialize once initialized', async () => { @@ -262,13 +261,4 @@ describe('AppCore initialization flow', () => { expect(BulkContextMenu).not.toHaveBeenCalled(); expect(bulkManager.setBulkContextMenu).not.toHaveBeenCalled(); }); - - it('suppresses onboarding when version mismatch banner is visible', async () => { - bannerService.isBannerVisible.mockReturnValueOnce(true); - - await appCore.initialize(); - await vi.runAllTimersAsync(); - - expect(onboardingManager.start).not.toHaveBeenCalled(); - }); }); From 30fd0470deeca6c9a5aaabafdbdd105030763f52 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sun, 21 Dec 2025 20:00:44 +0800 Subject: [PATCH 13/35] feat: Add support for video recipe previews by conditionally optimizing media during persistence and updating UI components to display videos. --- py/routes/handlers/recipe_handlers.py | 30 +++- py/services/recipes/persistence_service.py | 26 ++-- static/js/components/RecipeCard.js | 159 +++++++++++++-------- static/js/components/shared/ModelCard.js | 96 ++++++------- tests/routes/test_recipe_routes.py | 53 ++++++- tests/services/test_recipe_services.py | 45 ++++++ 6 files changed, 283 insertions(+), 126 deletions(-) diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index cee3ad0c..911de839 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -23,6 +23,7 @@ from ...services.recipes import ( RecipeValidationError, ) from ...services.metadata_service import get_default_metadata_provider +from ...utils.civitai_utils import rewrite_preview_url Logger = logging.Logger EnsureDependenciesCallable = Callable[[], Awaitable[None]] @@ -455,6 +456,7 @@ class RecipeManagementHandler: image_url = params.get("image_url") name = params.get("name") resources_raw = params.get("resources") + if not image_url: raise RecipeValidationError("Missing required field: image_url") if not name: @@ -483,7 +485,7 @@ class RecipeManagementHandler: metadata["base_model"] = base_model_from_metadata tags = self._parse_tags(params.get("tags")) - image_bytes = await self._download_image_bytes(image_url) + image_bytes, extension = await self._download_remote_media(image_url) result = await self._persistence_service.save_recipe( recipe_scanner=recipe_scanner, @@ -492,6 +494,7 @@ class RecipeManagementHandler: name=name, tags=tags, metadata=metadata, + extension=extension, ) return web.json_response(result.payload, status=result.status) except RecipeValidationError as exc: @@ -729,7 +732,7 @@ class RecipeManagementHandler: "exclude": False, } - async def _download_image_bytes(self, image_url: str) -> bytes: + async def _download_remote_media(self, image_url: str) -> tuple[bytes, str]: civitai_client = self._civitai_client_getter() downloader = await self._downloader_factory() temp_path = None @@ -744,15 +747,31 @@ class RecipeManagementHandler: image_info = await civitai_client.get_image_info(civitai_match.group(1)) if not image_info: raise RecipeDownloadError("Failed to fetch image information from Civitai") - download_url = image_info.get("url") - if not download_url: + + media_url = image_info.get("url") + if not media_url: raise RecipeDownloadError("No image URL found in Civitai response") + + # Use optimized preview URLs if possible + media_type = image_info.get("type") + rewritten_url, _ = rewrite_preview_url(media_url, media_type=media_type) + if rewritten_url: + download_url = rewritten_url + else: + download_url = media_url success, result = await downloader.download_file(download_url, temp_path, use_auth=False) if not success: raise RecipeDownloadError(f"Failed to download image: {result}") + + # Extract extension from URL + url_path = download_url.split('?')[0].split('#')[0] + extension = os.path.splitext(url_path)[1].lower() + if not extension: + extension = ".webp" # Default to webp if unknown + with open(temp_path, "rb") as file_obj: - return file_obj.read() + return file_obj.read(), extension except RecipeDownloadError: raise except RecipeValidationError: @@ -766,6 +785,7 @@ class RecipeManagementHandler: except FileNotFoundError: pass + def _safe_int(self, value: Any) -> int: try: return int(value) diff --git a/py/services/recipes/persistence_service.py b/py/services/recipes/persistence_service.py index 2640035e..535f0853 100644 --- a/py/services/recipes/persistence_service.py +++ b/py/services/recipes/persistence_service.py @@ -46,6 +46,7 @@ class RecipePersistenceService: name: str | None, tags: Iterable[str], metadata: Optional[dict[str, Any]], + extension: str | None = None, ) -> PersistenceResult: """Persist a user uploaded recipe.""" @@ -64,13 +65,21 @@ class RecipePersistenceService: os.makedirs(recipes_dir, exist_ok=True) recipe_id = str(uuid.uuid4()) - optimized_image, extension = self._exif_utils.optimize_image( - image_data=resolved_image_bytes, - target_width=self._card_preview_width, - format="webp", - quality=85, - preserve_metadata=True, - ) + + # Handle video formats by bypassing optimization and metadata embedding + is_video = extension in [".mp4", ".webm"] + if is_video: + optimized_image = resolved_image_bytes + # extension is already set + else: + optimized_image, extension = self._exif_utils.optimize_image( + image_data=resolved_image_bytes, + target_width=self._card_preview_width, + format="webp", + quality=85, + preserve_metadata=True, + ) + image_filename = f"{recipe_id}{extension}" image_path = os.path.join(recipes_dir, image_filename) normalized_image_path = os.path.normpath(image_path) @@ -126,7 +135,8 @@ class RecipePersistenceService: with open(json_path, "w", encoding="utf-8") as file_obj: json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) - self._exif_utils.append_recipe_metadata(normalized_image_path, recipe_data) + if not is_video: + self._exif_utils.append_recipe_metadata(normalized_image_path, recipe_data) matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id) await recipe_scanner.add_recipe(recipe_data) diff --git a/static/js/components/RecipeCard.js b/static/js/components/RecipeCard.js index dec61fa0..2c42ddf8 100644 --- a/static/js/components/RecipeCard.js +++ b/static/js/components/RecipeCard.js @@ -1,5 +1,6 @@ // Recipe Card Component import { showToast, copyToClipboard, sendLoraToWorkflow } from '../utils/uiHelpers.js'; +import { configureModelCardVideo } from './shared/ModelCard.js'; import { modalManager } from '../managers/ModalManager.js'; import { getCurrentPageState } from '../state/index.js'; import { state } from '../state/index.js'; @@ -10,11 +11,11 @@ class RecipeCard { this.recipe = recipe; this.clickHandler = clickHandler; this.element = this.createCardElement(); - + // Store reference to this instance on the DOM element for updates this.element._recipeCardInstance = this; } - + createCardElement() { const card = document.createElement('div'); card.className = 'model-card'; @@ -23,24 +24,40 @@ class RecipeCard { card.dataset.nsfwLevel = this.recipe.preview_nsfw_level || 0; card.dataset.created = this.recipe.created_date; card.dataset.id = this.recipe.id || ''; - + // Get base model with fallback const baseModelLabel = (this.recipe.base_model || '').trim() || 'Unknown'; const baseModelAbbreviation = getBaseModelAbbreviation(baseModelLabel); const baseModelDisplay = baseModelLabel === 'Unknown' ? 'Unknown' : baseModelAbbreviation; - + // Ensure loras array exists const loras = this.recipe.loras || []; const lorasCount = loras.length; - + // Check if all LoRAs are available in the library const missingLorasCount = loras.filter(lora => !lora.inLibrary && !lora.isDeleted).length; const allLorasAvailable = missingLorasCount === 0 && lorasCount > 0; - + // Ensure file_url exists, fallback to file_path if needed - const imageUrl = this.recipe.file_url || - (this.recipe.file_path ? `/loras_static/root1/preview/${this.recipe.file_path.split('/').pop()}` : - '/loras_static/images/no-preview.png'); + const previewUrl = this.recipe.file_url || + (this.recipe.file_path ? `/loras_static/root1/preview/${this.recipe.file_path.split('/').pop()}` : + '/loras_static/images/no-preview.png'); + + // Video preview logic + const autoplayOnHover = state.settings.autoplay_on_hover || false; + const isVideo = previewUrl.endsWith('.mp4') || previewUrl.endsWith('.webm'); + const videoAttrs = [ + 'controls', + 'muted', + 'loop', + 'playsinline', + 'preload="none"', + `data-src="${previewUrl}"` + ]; + + if (!autoplayOnHover) { + videoAttrs.push('data-autoplay="true"'); + } // Check if in duplicates mode const pageState = getCurrentPageState(); @@ -49,7 +66,7 @@ class RecipeCard { // NSFW blur logic - similar to LoraCard const nsfwLevel = this.recipe.preview_nsfw_level !== undefined ? this.recipe.preview_nsfw_level : 0; const shouldBlur = state.settings.blur_mature_content && nsfwLevel > NSFW_LEVELS.PG13; - + if (shouldBlur) { card.classList.add('nsfw-content'); } @@ -66,11 +83,14 @@ class RecipeCard { card.innerHTML = `
- ${this.recipe.title} + ${isVideo ? + `` : + `${this.recipe.title}` + } ${!isDuplicatesMode ? `
- ${shouldBlur ? - `` : ''} ${baseModelDisplay} @@ -102,30 +122,37 @@ class RecipeCard {
`; - + this.attachEventListeners(card, isDuplicatesMode, shouldBlur); + + // Add video auto-play on hover functionality if needed + const videoElement = card.querySelector('video'); + if (videoElement) { + configureModelCardVideo(videoElement, autoplayOnHover); + } + return card; } - + getLoraStatusTitle(totalCount, missingCount) { if (totalCount === 0) return "No LoRAs in this recipe"; if (missingCount === 0) return "All LoRAs available - Ready to use"; return `${missingCount} of ${totalCount} LoRAs missing`; } - + attachEventListeners(card, isDuplicatesMode, shouldBlur) { // Add blur toggle functionality if content should be blurred if (shouldBlur) { const toggleBtn = card.querySelector('.toggle-blur-btn'); const showBtn = card.querySelector('.show-content-btn'); - + if (toggleBtn) { toggleBtn.addEventListener('click', (e) => { e.stopPropagation(); this.toggleBlurContent(card); }); } - + if (showBtn) { showBtn.addEventListener('click', (e) => { e.stopPropagation(); @@ -139,19 +166,19 @@ class RecipeCard { card.addEventListener('click', () => { this.clickHandler(this.recipe); }); - + // Share button click event - prevent propagation to card card.querySelector('.fa-share-alt')?.addEventListener('click', (e) => { e.stopPropagation(); this.shareRecipe(); }); - + // Send button click event - prevent propagation to card card.querySelector('.fa-paper-plane')?.addEventListener('click', (e) => { e.stopPropagation(); this.sendRecipeToWorkflow(e.shiftKey); }); - + // Delete button click event - prevent propagation to card card.querySelector('.fa-trash')?.addEventListener('click', (e) => { e.stopPropagation(); @@ -159,19 +186,19 @@ class RecipeCard { }); } } - + toggleBlurContent(card) { const preview = card.querySelector('.card-preview'); const isBlurred = preview.classList.toggle('blurred'); const icon = card.querySelector('.toggle-blur-btn i'); - + // Update the icon based on blur state if (isBlurred) { icon.className = 'fas fa-eye'; } else { icon.className = 'fas fa-eye-slash'; } - + // Toggle the overlay visibility const overlay = card.querySelector('.nsfw-overlay'); if (overlay) { @@ -182,13 +209,13 @@ class RecipeCard { showBlurredContent(card) { const preview = card.querySelector('.card-preview'); preview.classList.remove('blurred'); - + // Update the toggle button icon const toggleBtn = card.querySelector('.toggle-blur-btn'); if (toggleBtn) { toggleBtn.querySelector('i').className = 'fas fa-eye-slash'; } - + // Hide the overlay const overlay = card.querySelector('.nsfw-overlay'); if (overlay) { @@ -223,7 +250,7 @@ class RecipeCard { showToast('toast.recipes.sendError', {}, 'error'); } } - + showDeleteConfirmation() { try { // Get recipe ID @@ -233,15 +260,21 @@ class RecipeCard { showToast('toast.recipes.cannotDelete', {}, 'error'); return; } - + // Create delete modal content + const previewUrl = this.recipe.file_url || '/loras_static/images/no-preview.png'; + const isVideo = previewUrl.endsWith('.mp4') || previewUrl.endsWith('.webm'); + const deleteModalContent = ` `; - + // Insert before the buttons container buttonsContainer.parentNode.insertBefore(warningContainer, buttonsContainer); } - + // Check for duplicates but don't change button actions const missingNotDeleted = this.importManager.recipeData.loras.filter( lora => !lora.existsLocally && !lora.isDeleted ).length; - + // Standard button behavior regardless of duplicates nextButton.classList.remove('warning-btn'); - + if (missingNotDeleted > 0) { nextButton.textContent = translate('recipes.controls.import.downloadMissingLoras', {}, 'Download Missing LoRAs'); } else { @@ -372,30 +383,30 @@ export class RecipeDataManager { addTag() { const tagInput = document.getElementById('tagInput'); const tag = tagInput.value.trim(); - + if (!tag) return; - + if (!this.importManager.recipeTags.includes(tag)) { this.importManager.recipeTags.push(tag); this.updateTagsDisplay(); } - + tagInput.value = ''; } - + removeTag(tag) { this.importManager.recipeTags = this.importManager.recipeTags.filter(t => t !== tag); this.updateTagsDisplay(); } - + updateTagsDisplay() { const tagsContainer = document.getElementById('tagsContainer'); - + if (this.importManager.recipeTags.length === 0) { tagsContainer.innerHTML = `
${translate('recipes.controls.import.noTagsAdded', {}, 'No tags added')}
`; return; } - + tagsContainer.innerHTML = this.importManager.recipeTags.map(tag => `
${tag} @@ -410,7 +421,7 @@ export class RecipeDataManager { showToast('toast.recipes.enterRecipeName', {}, 'error'); return; } - + // Automatically mark all deleted LoRAs as excluded if (this.importManager.recipeData && this.importManager.recipeData.loras) { this.importManager.recipeData.loras.forEach(lora => { @@ -419,11 +430,11 @@ export class RecipeDataManager { } }); } - + // Update missing LoRAs list to exclude deleted LoRAs - this.importManager.missingLoras = this.importManager.recipeData.loras.filter(lora => + this.importManager.missingLoras = this.importManager.recipeData.loras.filter(lora => !lora.existsLocally && !lora.isDeleted); - + // If we have downloadable missing LoRAs, go to location step if (this.importManager.missingLoras.length > 0) { // Store only downloadable LoRAs for the download step diff --git a/tests/services/test_recipe_services.py b/tests/services/test_recipe_services.py index de72fb32..8e33e1e0 100644 --- a/tests/services/test_recipe_services.py +++ b/tests/services/test_recipe_services.py @@ -95,7 +95,7 @@ async def test_analyze_remote_image_download_failure_cleans_temp(tmp_path, monke temp_path = tmp_path / "temp.jpg" - def create_temp_path(): + def create_temp_path(suffix=".jpg"): temp_path.write_bytes(b"") return str(temp_path) @@ -401,3 +401,55 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path): assert stored["loras"] == [] assert stored["title"] == "recipe" assert scanner.added and scanner.added[0]["loras"] == [] + + +@pytest.mark.asyncio +async def test_analyze_remote_video(tmp_path): + exif_utils = DummyExifUtils() + + class DummyFactory: + def create_parser(self, metadata): + async def parse_metadata(m, recipe_scanner): + return {"loras": []} + return SimpleNamespace(parse_metadata=parse_metadata) + + async def downloader_factory(): + class Downloader: + async def download_file(self, url, path, use_auth=False): + Path(path).write_bytes(b"video-content") + return True, "success" + + return Downloader() + + service = RecipeAnalysisService( + exif_utils=exif_utils, + recipe_parser_factory=DummyFactory(), + downloader_factory=downloader_factory, + metadata_collector=None, + metadata_processor_cls=None, + metadata_registry_cls=None, + standalone_mode=False, + logger=logging.getLogger("test"), + ) + + class DummyClient: + async def get_image_info(self, image_id): + return { + "url": "https://civitai.com/video.mp4", + "type": "video", + "meta": {"prompt": "video prompt"}, + } + + class DummyScanner: + async def find_recipes_by_fingerprint(self, fingerprint): + return [] + + result = await service.analyze_remote_image( + url="https://civitai.com/images/123", + recipe_scanner=DummyScanner(), + civitai_client=DummyClient(), + ) + + assert result.payload["is_video"] is True + assert result.payload["extension"] == ".mp4" + assert result.payload["image_base64"] is not None From 3ba5c4c2ab9ef9ec6cd0a1fbf87b981052985298 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Mon, 22 Dec 2025 17:58:04 +0800 Subject: [PATCH 15/35] refactor: improve `update_lora_filename_by_hash` logic and add a test to verify recipe updates. --- py/services/recipe_scanner.py | 105 +++++++++++--------------- tests/services/test_recipe_scanner.py | 61 +++++++++++++++ 2 files changed, 106 insertions(+), 60 deletions(-) diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index 3b41643e..c7ac2842 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -1333,71 +1333,56 @@ class RecipeScanner: # Always use lowercase hash for consistency hash_value = hash_value.lower() - # Get recipes directory - recipes_dir = self.recipes_dir - if not recipes_dir or not os.path.exists(recipes_dir): - logger.warning(f"Recipes directory not found: {recipes_dir}") + # Get cache + cache = await self.get_cached_data() + if not cache or not cache.raw_data: + return 0, 0 + + file_updated_count = 0 + cache_updated_count = 0 + + # Find recipes that need updating from the cache + recipes_to_update = [] + for recipe in cache.raw_data: + loras = recipe.get('loras', []) + if not isinstance(loras, list): + continue + + has_match = False + for lora in loras: + if not isinstance(lora, dict): + continue + if (lora.get('hash') or '').lower() == hash_value: + if lora.get('file_name') != new_file_name: + lora['file_name'] = new_file_name + has_match = True + + if has_match: + recipes_to_update.append(recipe) + cache_updated_count += 1 + + if not recipes_to_update: return 0, 0 - # Check if cache is initialized - cache_initialized = self._cache is not None - cache_updated_count = 0 - file_updated_count = 0 - - # Get all recipe JSON files in the recipes directory - recipe_files = [] - for root, _, files in os.walk(recipes_dir): - for file in files: - if file.lower().endswith('.recipe.json'): - recipe_files.append(os.path.join(root, file)) - - # Process each recipe file - for recipe_path in recipe_files: - try: - # Load the recipe data - with open(recipe_path, 'r', encoding='utf-8') as f: - recipe_data = json.load(f) - - # Skip if no loras or invalid structure - if not recipe_data or not isinstance(recipe_data, dict) or 'loras' not in recipe_data: + # Persist changes to disk + async with self._mutation_lock: + for recipe in recipes_to_update: + recipe_id = recipe.get('id') + if not recipe_id: continue - - # Check if any lora has matching hash - file_updated = False - for lora in recipe_data.get('loras', []): - if 'hash' in lora and lora['hash'].lower() == hash_value: - # Update file_name - old_file_name = lora.get('file_name', '') - lora['file_name'] = new_file_name - file_updated = True - logger.info(f"Updated file_name in recipe {recipe_path}: {old_file_name} -> {new_file_name}") - - # If updated, save the file - if file_updated: - with open(recipe_path, 'w', encoding='utf-8') as f: - json.dump(recipe_data, f, indent=4, ensure_ascii=False) - file_updated_count += 1 - # Also update in cache if it exists - if cache_initialized: - recipe_id = recipe_data.get('id') - if recipe_id: - for cache_item in self._cache.raw_data: - if cache_item.get('id') == recipe_id: - # Replace loras array with updated version - cache_item['loras'] = recipe_data['loras'] - cache_updated_count += 1 - break - - except Exception as e: - logger.error(f"Error updating recipe file {recipe_path}: {e}") - import traceback - traceback.print_exc(file=sys.stderr) + recipe_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json") + try: + self._write_recipe_file(recipe_path, recipe) + file_updated_count += 1 + logger.info(f"Updated file_name in recipe {recipe_path}: -> {new_file_name}") + except Exception as e: + logger.error(f"Error updating recipe file {recipe_path}: {e}") - # Resort cache if updates were made - if cache_initialized and cache_updated_count > 0: - await self._cache.resort() - logger.info(f"Resorted recipe cache after updating {cache_updated_count} items") + # We don't necessarily need to resort because LoRA file_name isn't a sort key, + # but we might want to schedule a resort if we're paranoid or if searching relies on sorted state. + # Given it's a rename of a dependency, search results might change if searching by LoRA name. + self._schedule_resort() return file_updated_count, cache_updated_count diff --git a/tests/services/test_recipe_scanner.py b/tests/services/test_recipe_scanner.py index cdca7c46..822eef43 100644 --- a/tests/services/test_recipe_scanner.py +++ b/tests/services/test_recipe_scanner.py @@ -445,3 +445,64 @@ async def test_load_recipe_persists_deleted_flag_on_invalid_version(monkeypatch, persisted = json.loads(recipe_path.read_text()) assert persisted["loras"][0]["isDeleted"] is True + + +@pytest.mark.asyncio +async def test_update_lora_filename_by_hash_updates_affected_recipes(tmp_path: Path, recipe_scanner): + scanner, _ = recipe_scanner + recipes_dir = Path(config.loras_roots[0]) / "recipes" + recipes_dir.mkdir(parents=True, exist_ok=True) + + # Recipe 1: Contains the LoRA with hash "hash1" + recipe1_id = "recipe1" + recipe1_path = recipes_dir / f"{recipe1_id}.recipe.json" + recipe1_data = { + "id": recipe1_id, + "file_path": str(tmp_path / "img1.png"), + "title": "Recipe 1", + "modified": 0.0, + "created_date": 0.0, + "loras": [ + {"file_name": "old_name", "hash": "hash1"}, + {"file_name": "other_lora", "hash": "hash2"} + ], + } + recipe1_path.write_text(json.dumps(recipe1_data)) + await scanner.add_recipe(dict(recipe1_data)) + + # Recipe 2: Does NOT contain the LoRA + recipe2_id = "recipe2" + recipe2_path = recipes_dir / f"{recipe2_id}.recipe.json" + recipe2_data = { + "id": recipe2_id, + "file_path": str(tmp_path / "img2.png"), + "title": "Recipe 2", + "modified": 0.0, + "created_date": 0.0, + "loras": [ + {"file_name": "other_lora", "hash": "hash2"} + ], + } + recipe2_path.write_text(json.dumps(recipe2_data)) + await scanner.add_recipe(dict(recipe2_data)) + + # Update LoRA name for "hash1" (using different case to test normalization) + new_name = "new_name" + file_count, cache_count = await scanner.update_lora_filename_by_hash("HASH1", new_name) + + assert file_count == 1 + assert cache_count == 1 + + # Check file on disk + persisted1 = json.loads(recipe1_path.read_text()) + assert persisted1["loras"][0]["file_name"] == new_name + assert persisted1["loras"][1]["file_name"] == "other_lora" + + # Verify Recipe 2 unchanged + persisted2 = json.loads(recipe2_path.read_text()) + assert persisted2["loras"][0]["file_name"] == "other_lora" + + # Check cache + cache = await scanner.get_cached_data() + cached1 = next(r for r in cache.raw_data if r["id"] == recipe1_id) + assert cached1["loras"][0]["file_name"] == new_name From dd89aa49c14850f6f17d145e43732d6f13f1ce0f Mon Sep 17 00:00:00 2001 From: Will Miao Date: Tue, 23 Dec 2025 08:47:15 +0800 Subject: [PATCH 16/35] feat: Add HTML and attribute escaping for trigger words and class tokens to prevent XSS vulnerabilities, along with new frontend tests. Fixes #732 --- static/js/components/shared/ModelModal.js | 84 +++---- static/js/components/shared/TriggerWords.js | 211 +++++++++--------- .../components/triggerWords.escaping.test.js | 55 +++++ 3 files changed, 208 insertions(+), 142 deletions(-) create mode 100644 tests/frontend/components/triggerWords.escaping.test.js diff --git a/static/js/components/shared/ModelModal.js b/static/js/components/shared/ModelModal.js index 84e1dbb3..7e8261b1 100644 --- a/static/js/components/shared/ModelModal.js +++ b/static/js/components/shared/ModelModal.js @@ -1,15 +1,15 @@ import { showToast, openCivitai } from '../../utils/uiHelpers.js'; import { modalManager } from '../../managers/ModalManager.js'; -import { +import { toggleShowcase, - setupShowcaseScroll, + setupShowcaseScroll, scrollToTop, loadExampleImages } from './showcase/ShowcaseView.js'; import { setupTabSwitching } from './ModelDescription.js'; -import { - setupModelNameEditing, - setupBaseModelEditing, +import { + setupModelNameEditing, + setupBaseModelEditing, setupFileNameEditing } from './ModelMetadata.js'; import { setupTagEditMode } from './ModelTags.js'; @@ -242,7 +242,7 @@ export async function showModelModal(model, modelType) { const modalTitle = model.model_name; cleanupNavigationShortcuts(); detachModalHandlers(modalId); - + // Fetch complete civitai metadata let completeCivitaiData = model.civitai || {}; if (model.file_path) { @@ -254,7 +254,7 @@ export async function showModelModal(model, modelType) { // Continue with existing data if fetch fails } } - + // Update model with complete civitai data const modelWithFullData = { ...model, @@ -269,14 +269,14 @@ export async function showModelModal(model, modelType) {
`.trim() : ''; const creatorInfoAction = modelWithFullData.civitai?.creator ? `
- ${modelWithFullData.civitai.creator.image ? - `
+ ${modelWithFullData.civitai.creator.image ? + `
${modelWithFullData.civitai.creator.username} -
` : - `
+
` : + `
` - } + } ${modelWithFullData.civitai.creator.username}
`.trim() : ''; const creatorActionItems = []; @@ -310,10 +310,10 @@ export async function showModelModal(model, modelType) { const hasUpdateAvailable = Boolean(modelWithFullData.update_available); const updateAvailabilityState = { hasUpdateAvailable }; const updateBadgeTooltip = translate('modelCard.badges.updateAvailable', {}, 'Update available'); - + // Prepare LoRA specific data with complete civitai data - const escapedWords = (modelType === 'loras' || modelType === 'embeddings') && modelWithFullData.civitai?.trainedWords?.length ? - modelWithFullData.civitai.trainedWords.map(word => word.replace(/'/g, '\\\'')) : []; + const escapedWords = (modelType === 'loras' || modelType === 'embeddings') && modelWithFullData.civitai?.trainedWords?.length ? + modelWithFullData.civitai.trainedWords : []; // Generate model type specific content let typeSpecificContent; @@ -343,7 +343,7 @@ export async function showModelModal(model, modelType) { ${versionsTabBadge} `.trim(); - const tabsContent = modelType === 'loras' ? + const tabsContent = modelType === 'loras' ? ` ${versionsTabButton} @@ -351,12 +351,12 @@ export async function showModelModal(model, modelType) { ` ${versionsTabButton}`; - + const loadingExampleImagesText = translate('modals.model.loading.exampleImages', {}, 'Loading example images...'); const loadingDescriptionText = translate('modals.model.loading.description', {}, 'Loading model description...'); const loadingRecipesText = translate('modals.model.loading.recipes', {}, 'Loading recipes...'); const loadingExamplesText = translate('modals.model.loading.examples', {}, 'Loading examples...'); - + const loadingVersionsText = translate('modals.model.loading.versions', {}, 'Loading versions...'); const civitaiModelId = modelWithFullData.civitai?.modelId || ''; const civitaiVersionId = modelWithFullData.civitai?.id || ''; @@ -373,7 +373,7 @@ export async function showModelModal(model, modelType) {
`.trim(); - const tabPanesContent = modelType === 'loras' ? + const tabPanesContent = modelType === 'loras' ? `
${loadingExampleImagesText} @@ -518,7 +518,7 @@ export async function showModelModal(model, modelType) {
`; - + function updateVersionsTabBadge(hasUpdate) { const modalElement = document.getElementById(modalId); if (!modalElement) return; @@ -594,10 +594,10 @@ export async function showModelModal(model, modelType) { updateVersionsTabBadge(hasUpdate); updateCardUpdateAvailability(hasUpdate); } - + let showcaseCleanup; - const onCloseCallback = function() { + const onCloseCallback = function () { // Clean up all handlers when modal closes for LoRA const modalElement = document.getElementById(modalId); if (modalElement && modalElement._clickHandler) { @@ -610,7 +610,7 @@ export async function showModelModal(model, modelType) { } cleanupNavigationShortcuts(); }; - + modalManager.showModal(modalId, content, null, onCloseCallback); const activeModalElement = document.getElementById(modalId); if (activeModalElement) { @@ -643,17 +643,17 @@ export async function showModelModal(model, modelType) { setupEventHandlers(modelWithFullData.file_path, modelType); setupNavigationShortcuts(modelType); updateNavigationControls(); - + // LoRA specific setup if (modelType === 'loras' || modelType === 'embeddings') { setupTriggerWordsEditMode(); - + if (modelType == 'loras') { // Load recipes for this LoRA loadRecipesForLora(modelWithFullData.model_name, modelWithFullData.sha256); } } - + // Load example images asynchronously - merge regular and custom images const regularImages = modelWithFullData.civitai?.images || []; const customImages = modelWithFullData.civitai?.customImages || []; @@ -707,17 +707,17 @@ function detachModalHandlers(modalId) { */ function setupEventHandlers(filePath, modelType) { const modalElement = document.getElementById('modelModal'); - + // Remove existing event listeners first modalElement.removeEventListener('click', handleModalClick); - + // Create and store the handler function function handleModalClick(event) { const target = event.target.closest('[data-action]'); if (!target) return; - + const action = target.dataset.action; - + switch (action) { case 'close-modal': modalManager.closeModal('modelModal'); @@ -748,10 +748,10 @@ function setupEventHandlers(filePath, modelType) { break; } } - + // Add the event listener with the named function modalElement.addEventListener('click', handleModalClick); - + // Store reference to the handler on the element for potential cleanup modalElement._clickHandler = handleModalClick; } @@ -763,15 +763,15 @@ function setupEventHandlers(filePath, modelType) { */ function setupEditableFields(filePath, modelType) { const editableFields = document.querySelectorAll('.editable-field [contenteditable]'); - + editableFields.forEach(field => { - field.addEventListener('focus', function() { + field.addEventListener('focus', function () { if (this.textContent === 'Add your notes here...') { this.textContent = ''; } }); - field.addEventListener('blur', function() { + field.addEventListener('blur', function () { if (this.textContent.trim() === '') { if (this.classList.contains('notes-content')) { this.textContent = 'Add your notes here...'; @@ -783,7 +783,7 @@ function setupEditableFields(filePath, modelType) { // Add keydown event listeners for notes const notesContent = document.querySelector('.notes-content'); if (notesContent) { - notesContent.addEventListener('keydown', async function(e) { + notesContent.addEventListener('keydown', async function (e) { if (e.key === 'Enter') { if (e.shiftKey) { // Allow shift+enter for new line @@ -810,7 +810,7 @@ function setupLoraSpecificFields(filePath) { if (!presetSelector || !presetValue || !addPresetBtn || !presetTags) return; - presetSelector.addEventListener('change', function() { + presetSelector.addEventListener('change', function () { const selected = this.value; if (selected) { presetValue.style.display = 'inline-block'; @@ -828,10 +828,10 @@ function setupLoraSpecificFields(filePath) { } }); - addPresetBtn.addEventListener('click', async function() { + addPresetBtn.addEventListener('click', async function () { const key = presetSelector.value; const value = presetValue.value; - + if (!key || !value) return; const currentPath = resolveFilePath(); @@ -839,21 +839,21 @@ function setupLoraSpecificFields(filePath) { const loraCard = document.querySelector(`.model-card[data-filepath="${currentPath}"]`) || document.querySelector(`.model-card[data-filepath="${filePath}"]`); const currentPresets = parsePresets(loraCard?.dataset.usage_tips); - + currentPresets[key] = parseFloat(value); const newPresetsJson = JSON.stringify(currentPresets); await getModelApiClient().saveModelMetadata(currentPath, { usage_tips: newPresetsJson }); presetTags.innerHTML = renderPresetTags(currentPresets); - + presetSelector.value = ''; presetValue.value = ''; presetValue.style.display = 'none'; }); // Add keydown event for preset value - presetValue.addEventListener('keydown', function(e) { + presetValue.addEventListener('keydown', function (e) { if (e.key === 'Enter') { e.preventDefault(); addPresetBtn.click(); diff --git a/static/js/components/shared/TriggerWords.js b/static/js/components/shared/TriggerWords.js index dfc1cb17..ea9d929a 100644 --- a/static/js/components/shared/TriggerWords.js +++ b/static/js/components/shared/TriggerWords.js @@ -6,7 +6,7 @@ import { showToast, copyToClipboard } from '../../utils/uiHelpers.js'; import { translate } from '../../utils/i18nHelpers.js'; import { getModelApiClient } from '../../api/modelApiFactory.js'; -import { escapeAttribute } from './utils.js'; +import { escapeAttribute, escapeHtml } from './utils.js'; /** * Fetch trained words for a model @@ -17,7 +17,7 @@ async function fetchTrainedWords(filePath) { try { const response = await fetch(`/api/lm/trained-words?file_path=${encodeURIComponent(filePath)}`); const data = await response.json(); - + if (data.success) { return { trainedWords: data.trained_words || [], // Returns array of [word, frequency] pairs @@ -43,11 +43,11 @@ async function fetchTrainedWords(filePath) { function createSuggestionDropdown(trainedWords, classTokens, existingWords = []) { const dropdown = document.createElement('div'); dropdown.className = 'metadata-suggestions-dropdown'; - + // Create header const header = document.createElement('div'); header.className = 'metadata-suggestions-header'; - + // No suggestions case if ((!trainedWords || trainedWords.length === 0) && !classTokens) { header.innerHTML = `${translate('modals.model.triggerWords.suggestions.noSuggestions')}`; @@ -55,12 +55,12 @@ function createSuggestionDropdown(trainedWords, classTokens, existingWords = []) dropdown.innerHTML += `
${translate('modals.model.triggerWords.suggestions.noTrainedWords')}
`; return dropdown; } - + // Sort trained words by frequency (highest first) if available if (trainedWords && trainedWords.length > 0) { trainedWords.sort((a, b) => b[1] - a[1]); } - + // Add class tokens section if available if (classTokens) { // Add class tokens header @@ -71,45 +71,47 @@ function createSuggestionDropdown(trainedWords, classTokens, existingWords = []) ${translate('modals.model.triggerWords.suggestions.classTokenDescription')} `; dropdown.appendChild(classTokensHeader); - + // Add class tokens container const classTokensContainer = document.createElement('div'); classTokensContainer.className = 'class-tokens-container'; - + // Create a special item for the class token const tokenItem = document.createElement('div'); tokenItem.className = `metadata-suggestion-item class-token-item ${existingWords.includes(classTokens) ? 'already-added' : ''}`; tokenItem.title = `${translate('modals.model.triggerWords.suggestions.classToken')}: ${classTokens}`; + + const escapedToken = escapeHtml(classTokens); tokenItem.innerHTML = ` - + `; - + // Add click handler if not already added if (!existingWords.includes(classTokens)) { tokenItem.addEventListener('click', () => { // Automatically add this word addNewTriggerWord(classTokens); - + // Also populate the input field for potential editing const input = document.querySelector('.metadata-input'); if (input) input.value = classTokens; - + // Focus on the input if (input) input.focus(); - + // Update dropdown without removing it updateTrainedWordsDropdown(); }); } - + classTokensContainer.appendChild(tokenItem); dropdown.appendChild(classTokensContainer); - + // Add separator if we also have trained words if (trainedWords && trainedWords.length > 0) { const separator = document.createElement('div'); @@ -117,7 +119,7 @@ function createSuggestionDropdown(trainedWords, classTokens, existingWords = []) dropdown.appendChild(separator); } } - + // Add trained words header if we have any if (trainedWords && trainedWords.length > 0) { header.innerHTML = ` @@ -125,52 +127,54 @@ function createSuggestionDropdown(trainedWords, classTokens, existingWords = []) ${translate('modals.model.triggerWords.suggestions.wordsFound', { count: trainedWords.length })} `; dropdown.appendChild(header); - + // Create tag container for trained words const container = document.createElement('div'); container.className = 'metadata-suggestions-container'; - + // Add each trained word as a tag trainedWords.forEach(([word, frequency]) => { const isAdded = existingWords.includes(word); - + const item = document.createElement('div'); item.className = `metadata-suggestion-item ${isAdded ? 'already-added' : ''}`; item.title = word; // Show full word on hover if truncated + + const escapedWord = escapeHtml(word); item.innerHTML = ` - + `; - + if (!isAdded) { item.addEventListener('click', () => { // Automatically add this word addNewTriggerWord(word); - + // Also populate the input field for potential editing const input = document.querySelector('.metadata-input'); if (input) input.value = word; - + // Focus on the input if (input) input.focus(); - + // Update dropdown without removing it updateTrainedWordsDropdown(); }); } - + container.appendChild(item); }); - + dropdown.appendChild(container); } else if (!classTokens) { // If we have neither class tokens nor trained words dropdown.innerHTML += `
${translate('modals.model.triggerWords.suggestions.noTrainedWords')}
`; } - + return dropdown; } @@ -204,7 +208,7 @@ export function renderTriggerWords(words, filePath) { `; - + return `
@@ -215,9 +219,12 @@ export function renderTriggerWords(words, filePath) {
- ${words.map(word => ` -
- ${word} + ${words.map(word => { + const escapedWord = escapeHtml(word); + const escapedAttr = escapeAttribute(word); + return ` +
+ ${escapedWord} @@ -225,7 +232,7 @@ export function renderTriggerWords(words, filePath) {
- `).join('')} + `}).join('')}
- + + {% include 'components/folder_sidebar.html' %} +
From 3f646aa0c91253e3f7db5a9c5adcc05821d741b4 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Tue, 25 Nov 2025 17:41:24 +0800 Subject: [PATCH 18/35] feat: add recipe root directory and move recipe endpoints - Add GET /api/lm/recipes/roots endpoint to retrieve recipe root directories - Add POST /api/lm/recipe/move endpoint to move recipes between directories - Register new endpoints in route definitions - Implement error handling for both new endpoints with proper status codes - Enable recipe management operations for better file organization --- locales/de.json | 8 +- locales/en.json | 18 ++-- locales/es.json | 8 +- locales/fr.json | 8 +- locales/he.json | 8 +- locales/ja.json | 8 +- locales/ko.json | 8 +- locales/ru.json | 8 +- locales/zh-CN.json | 8 +- locales/zh-TW.json | 8 +- py/routes/handlers/recipe_handlers.py | 44 +++++++++ py/routes/recipe_route_registrar.py | 2 + py/services/recipe_scanner.py | 36 +++++-- py/services/recipes/persistence_service.py | 98 ++++++++++++++++--- static/js/api/recipeApi.js | 91 +++++++++++++++-- .../ContextMenu/RecipeContextMenu.js | 4 + static/js/managers/MoveManager.js | 20 +++- templates/recipes.html | 97 ++++++++++-------- tests/frontend/pages/recipesPage.test.js | 20 ++++ tests/routes/test_recipe_routes.py | 26 +++++ tests/services/test_recipe_services.py | 83 ++++++++++++++++ 21 files changed, 501 insertions(+), 110 deletions(-) diff --git a/locales/de.json b/locales/de.json index 873d2284..94826973 100644 --- a/locales/de.json +++ b/locales/de.json @@ -638,7 +638,8 @@ "recursiveUnavailable": "Rekursive Suche ist nur in der Baumansicht verfügbar", "collapseAllDisabled": "Im Listenmodus nicht verfügbar", "dragDrop": { - "unableToResolveRoot": "Zielpfad für das Verschieben konnte nicht ermittelt werden." + "unableToResolveRoot": "Zielpfad für das Verschieben konnte nicht ermittelt werden.", + "moveUnsupported": "Move is not supported for this item." } }, "statistics": { @@ -1460,7 +1461,8 @@ "bulkMoveFailures": "Fehlgeschlagene Verschiebungen:\n{failures}", "bulkMoveSuccess": "{successCount} {type}s erfolgreich verschoben", "exampleImagesDownloadSuccess": "Beispielbilder erfolgreich heruntergeladen!", - "exampleImagesDownloadFailed": "Fehler beim Herunterladen der Beispielbilder: {message}" + "exampleImagesDownloadFailed": "Fehler beim Herunterladen der Beispielbilder: {message}", + "moveFailed": "Failed to move item: {message}" } }, "banners": { @@ -1478,4 +1480,4 @@ "learnMore": "LM Civitai Extension Tutorial" } } -} +} \ No newline at end of file diff --git a/locales/en.json b/locales/en.json index df25d8e9..3345567f 100644 --- a/locales/en.json +++ b/locales/en.json @@ -32,7 +32,7 @@ "korean": "한국어", "french": "Français", "spanish": "Español", - "Hebrew": "עברית" + "Hebrew": "עברית" }, "fileSize": { "zero": "0 Bytes", @@ -336,7 +336,7 @@ "templateOptions": { "flatStructure": "Flat Structure", "byBaseModel": "By Base Model", - "byAuthor": "By Author", + "byAuthor": "By Author", "byFirstTag": "By First Tag", "baseModelFirstTag": "Base Model + First Tag", "baseModelAuthor": "Base Model + Author", @@ -347,7 +347,7 @@ "customTemplatePlaceholder": "Enter custom template (e.g., {base_model}/{author}/{first_tag})", "modelTypes": { "lora": "LoRA", - "checkpoint": "Checkpoint", + "checkpoint": "Checkpoint", "embedding": "Embedding" }, "baseModelPathMappings": "Base Model Path Mappings", @@ -420,11 +420,11 @@ "proxyHost": "Proxy Host", "proxyHostPlaceholder": "proxy.example.com", "proxyHostHelp": "The hostname or IP address of your proxy server", - "proxyPort": "Proxy Port", + "proxyPort": "Proxy Port", "proxyPortPlaceholder": "8080", "proxyPortHelp": "The port number of your proxy server", "proxyUsername": "Username (Optional)", - "proxyUsernamePlaceholder": "username", + "proxyUsernamePlaceholder": "username", "proxyUsernameHelp": "Username for proxy authentication (if required)", "proxyPassword": "Password (Optional)", "proxyPasswordPlaceholder": "password", @@ -638,7 +638,8 @@ "recursiveUnavailable": "Recursive search is available in tree view only", "collapseAllDisabled": "Not available in list view", "dragDrop": { - "unableToResolveRoot": "Unable to determine destination path for move." + "unableToResolveRoot": "Unable to determine destination path for move.", + "moveUnsupported": "Move is not supported for this item." } }, "statistics": { @@ -1460,7 +1461,8 @@ "bulkMoveFailures": "Failed moves:\n{failures}", "bulkMoveSuccess": "Successfully moved {successCount} {type}s", "exampleImagesDownloadSuccess": "Successfully downloaded example images!", - "exampleImagesDownloadFailed": "Failed to download example images: {message}" + "exampleImagesDownloadFailed": "Failed to download example images: {message}", + "moveFailed": "Failed to move item: {message}" } }, "banners": { @@ -1478,4 +1480,4 @@ "learnMore": "LM Civitai Extension Tutorial" } } -} +} \ No newline at end of file diff --git a/locales/es.json b/locales/es.json index ff0e9e16..d05018a5 100644 --- a/locales/es.json +++ b/locales/es.json @@ -638,7 +638,8 @@ "recursiveUnavailable": "La búsqueda recursiva solo está disponible en la vista en árbol", "collapseAllDisabled": "No disponible en vista de lista", "dragDrop": { - "unableToResolveRoot": "No se puede determinar la ruta de destino para el movimiento." + "unableToResolveRoot": "No se puede determinar la ruta de destino para el movimiento.", + "moveUnsupported": "Move is not supported for this item." } }, "statistics": { @@ -1460,7 +1461,8 @@ "bulkMoveFailures": "Movimientos fallidos:\n{failures}", "bulkMoveSuccess": "Movidos exitosamente {successCount} {type}s", "exampleImagesDownloadSuccess": "¡Imágenes de ejemplo descargadas exitosamente!", - "exampleImagesDownloadFailed": "Error al descargar imágenes de ejemplo: {message}" + "exampleImagesDownloadFailed": "Error al descargar imágenes de ejemplo: {message}", + "moveFailed": "Failed to move item: {message}" } }, "banners": { @@ -1478,4 +1480,4 @@ "learnMore": "LM Civitai Extension Tutorial" } } -} +} \ No newline at end of file diff --git a/locales/fr.json b/locales/fr.json index d7c82004..1d36be36 100644 --- a/locales/fr.json +++ b/locales/fr.json @@ -638,7 +638,8 @@ "recursiveUnavailable": "La recherche récursive n'est disponible qu'en vue arborescente", "collapseAllDisabled": "Non disponible en vue liste", "dragDrop": { - "unableToResolveRoot": "Impossible de déterminer le chemin de destination pour le déplacement." + "unableToResolveRoot": "Impossible de déterminer le chemin de destination pour le déplacement.", + "moveUnsupported": "Move is not supported for this item." } }, "statistics": { @@ -1460,7 +1461,8 @@ "bulkMoveFailures": "Échecs de déplacement :\n{failures}", "bulkMoveSuccess": "{successCount} {type}s déplacés avec succès", "exampleImagesDownloadSuccess": "Images d'exemple téléchargées avec succès !", - "exampleImagesDownloadFailed": "Échec du téléchargement des images d'exemple : {message}" + "exampleImagesDownloadFailed": "Échec du téléchargement des images d'exemple : {message}", + "moveFailed": "Failed to move item: {message}" } }, "banners": { @@ -1478,4 +1480,4 @@ "learnMore": "LM Civitai Extension Tutorial" } } -} +} \ No newline at end of file diff --git a/locales/he.json b/locales/he.json index 4afa4aa4..f7dace98 100644 --- a/locales/he.json +++ b/locales/he.json @@ -638,7 +638,8 @@ "recursiveUnavailable": "חיפוש רקורסיבי זמין רק בתצוגת עץ", "collapseAllDisabled": "לא זמין בתצוגת רשימה", "dragDrop": { - "unableToResolveRoot": "לא ניתן לקבוע את נתיב היעד להעברה." + "unableToResolveRoot": "לא ניתן לקבוע את נתיב היעד להעברה.", + "moveUnsupported": "Move is not supported for this item." } }, "statistics": { @@ -1460,7 +1461,8 @@ "bulkMoveFailures": "העברות שנכשלו:\n{failures}", "bulkMoveSuccess": "הועברו בהצלחה {successCount} {type}s", "exampleImagesDownloadSuccess": "תמונות הדוגמה הורדו בהצלחה!", - "exampleImagesDownloadFailed": "הורדת תמונות הדוגמה נכשלה: {message}" + "exampleImagesDownloadFailed": "הורדת תמונות הדוגמה נכשלה: {message}", + "moveFailed": "Failed to move item: {message}" } }, "banners": { @@ -1478,4 +1480,4 @@ "learnMore": "LM Civitai Extension Tutorial" } } -} +} \ No newline at end of file diff --git a/locales/ja.json b/locales/ja.json index 7b83ec8f..336cc856 100644 --- a/locales/ja.json +++ b/locales/ja.json @@ -638,7 +638,8 @@ "recursiveUnavailable": "再帰検索はツリービューでのみ利用できます", "collapseAllDisabled": "リストビューでは利用できません", "dragDrop": { - "unableToResolveRoot": "移動先のパスを特定できません。" + "unableToResolveRoot": "移動先のパスを特定できません。", + "moveUnsupported": "Move is not supported for this item." } }, "statistics": { @@ -1460,7 +1461,8 @@ "bulkMoveFailures": "失敗した移動:\n{failures}", "bulkMoveSuccess": "{successCount} {type}が正常に移動されました", "exampleImagesDownloadSuccess": "例画像が正常にダウンロードされました!", - "exampleImagesDownloadFailed": "例画像のダウンロードに失敗しました:{message}" + "exampleImagesDownloadFailed": "例画像のダウンロードに失敗しました:{message}", + "moveFailed": "Failed to move item: {message}" } }, "banners": { @@ -1478,4 +1480,4 @@ "learnMore": "LM Civitai Extension Tutorial" } } -} +} \ No newline at end of file diff --git a/locales/ko.json b/locales/ko.json index 9750f070..262b5c30 100644 --- a/locales/ko.json +++ b/locales/ko.json @@ -638,7 +638,8 @@ "recursiveUnavailable": "재귀 검색은 트리 보기에서만 사용할 수 있습니다", "collapseAllDisabled": "목록 보기에서는 사용할 수 없습니다", "dragDrop": { - "unableToResolveRoot": "이동할 대상 경로를 확인할 수 없습니다." + "unableToResolveRoot": "이동할 대상 경로를 확인할 수 없습니다.", + "moveUnsupported": "Move is not supported for this item." } }, "statistics": { @@ -1460,7 +1461,8 @@ "bulkMoveFailures": "실패한 이동:\n{failures}", "bulkMoveSuccess": "{successCount}개 {type}이(가) 성공적으로 이동되었습니다", "exampleImagesDownloadSuccess": "예시 이미지가 성공적으로 다운로드되었습니다!", - "exampleImagesDownloadFailed": "예시 이미지 다운로드 실패: {message}" + "exampleImagesDownloadFailed": "예시 이미지 다운로드 실패: {message}", + "moveFailed": "Failed to move item: {message}" } }, "banners": { @@ -1478,4 +1480,4 @@ "learnMore": "LM Civitai Extension Tutorial" } } -} +} \ No newline at end of file diff --git a/locales/ru.json b/locales/ru.json index 9c22651a..851bccfa 100644 --- a/locales/ru.json +++ b/locales/ru.json @@ -638,7 +638,8 @@ "recursiveUnavailable": "Рекурсивный поиск доступен только в режиме дерева", "collapseAllDisabled": "Недоступно в виде списка", "dragDrop": { - "unableToResolveRoot": "Не удалось определить путь назначения для перемещения." + "unableToResolveRoot": "Не удалось определить путь назначения для перемещения.", + "moveUnsupported": "Move is not supported for this item." } }, "statistics": { @@ -1460,7 +1461,8 @@ "bulkMoveFailures": "Неудачные перемещения:\n{failures}", "bulkMoveSuccess": "Успешно перемещено {successCount} {type}s", "exampleImagesDownloadSuccess": "Примеры изображений успешно загружены!", - "exampleImagesDownloadFailed": "Не удалось загрузить примеры изображений: {message}" + "exampleImagesDownloadFailed": "Не удалось загрузить примеры изображений: {message}", + "moveFailed": "Failed to move item: {message}" } }, "banners": { @@ -1478,4 +1480,4 @@ "learnMore": "LM Civitai Extension Tutorial" } } -} +} \ No newline at end of file diff --git a/locales/zh-CN.json b/locales/zh-CN.json index df02db1d..2d9c5295 100644 --- a/locales/zh-CN.json +++ b/locales/zh-CN.json @@ -638,7 +638,8 @@ "recursiveUnavailable": "仅在树形视图中可使用递归搜索", "collapseAllDisabled": "列表视图下不可用", "dragDrop": { - "unableToResolveRoot": "无法确定移动的目标路径。" + "unableToResolveRoot": "无法确定移动的目标路径。", + "moveUnsupported": "Move is not supported for this item." } }, "statistics": { @@ -1460,7 +1461,8 @@ "bulkMoveFailures": "移动失败:\n{failures}", "bulkMoveSuccess": "成功移动 {successCount} 个 {type}", "exampleImagesDownloadSuccess": "示例图片下载成功!", - "exampleImagesDownloadFailed": "示例图片下载失败:{message}" + "exampleImagesDownloadFailed": "示例图片下载失败:{message}", + "moveFailed": "Failed to move item: {message}" } }, "banners": { @@ -1478,4 +1480,4 @@ "learnMore": "浏览器插件教程" } } -} +} \ No newline at end of file diff --git a/locales/zh-TW.json b/locales/zh-TW.json index 0d5a8dae..e6a1967b 100644 --- a/locales/zh-TW.json +++ b/locales/zh-TW.json @@ -638,7 +638,8 @@ "recursiveUnavailable": "遞迴搜尋僅能在樹狀檢視中使用", "collapseAllDisabled": "列表檢視下不可用", "dragDrop": { - "unableToResolveRoot": "無法確定移動的目標路徑。" + "unableToResolveRoot": "無法確定移動的目標路徑。", + "moveUnsupported": "Move is not supported for this item." } }, "statistics": { @@ -1460,7 +1461,8 @@ "bulkMoveFailures": "移動失敗:\n{failures}", "bulkMoveSuccess": "已成功移動 {successCount} 個 {type}", "exampleImagesDownloadSuccess": "範例圖片下載成功!", - "exampleImagesDownloadFailed": "下載範例圖片失敗:{message}" + "exampleImagesDownloadFailed": "下載範例圖片失敗:{message}", + "moveFailed": "Failed to move item: {message}" } }, "banners": { @@ -1478,4 +1480,4 @@ "learnMore": "LM Civitai Extension Tutorial" } } -} +} \ No newline at end of file diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index f6751c4a..cf582bf7 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -56,6 +56,7 @@ class RecipeHandlerSet: "delete_recipe": self.management.delete_recipe, "get_top_tags": self.query.get_top_tags, "get_base_models": self.query.get_base_models, + "get_roots": self.query.get_roots, "get_folders": self.query.get_folders, "get_folder_tree": self.query.get_folder_tree, "get_unified_folder_tree": self.query.get_unified_folder_tree, @@ -69,6 +70,7 @@ class RecipeHandlerSet: "save_recipe_from_widget": self.management.save_recipe_from_widget, "get_recipes_for_lora": self.query.get_recipes_for_lora, "scan_recipes": self.query.scan_recipes, + "move_recipe": self.management.move_recipe, } @@ -306,6 +308,19 @@ class RecipeQueryHandler: self._logger.error("Error retrieving base models: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) + async def get_roots(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + roots = [recipe_scanner.recipes_dir] if recipe_scanner.recipes_dir else [] + return web.json_response({"success": True, "roots": roots}) + except Exception as exc: + self._logger.error("Error retrieving recipe roots: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + async def get_folders(self, request: web.Request) -> web.Response: try: await self._ensure_dependencies_ready() @@ -591,6 +606,35 @@ class RecipeManagementHandler: self._logger.error("Error updating recipe: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) + async def move_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + data = await request.json() + recipe_id = data.get("recipe_id") + target_path = data.get("target_path") + if not recipe_id or not target_path: + return web.json_response( + {"success": False, "error": "recipe_id and target_path are required"}, status=400 + ) + + result = await self._persistence_service.move_recipe( + recipe_scanner=recipe_scanner, + recipe_id=str(recipe_id), + target_path=str(target_path), + ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except RecipeNotFoundError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=404) + except Exception as exc: + self._logger.error("Error moving recipe: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + async def reconnect_lora(self, request: web.Request) -> web.Response: try: await self._ensure_dependencies_ready() diff --git a/py/routes/recipe_route_registrar.py b/py/routes/recipe_route_registrar.py index 22f18d88..f397f501 100644 --- a/py/routes/recipe_route_registrar.py +++ b/py/routes/recipe_route_registrar.py @@ -27,6 +27,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"), RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"), RouteDefinition("GET", "/api/lm/recipes/base-models", "get_base_models"), + RouteDefinition("GET", "/api/lm/recipes/roots", "get_roots"), RouteDefinition("GET", "/api/lm/recipes/folders", "get_folders"), RouteDefinition("GET", "/api/lm/recipes/folder-tree", "get_folder_tree"), RouteDefinition("GET", "/api/lm/recipes/unified-folder-tree", "get_unified_folder_tree"), @@ -34,6 +35,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"), RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"), RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"), + RouteDefinition("POST", "/api/lm/recipe/move", "move_recipe"), RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"), RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"), RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"), diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index efa77119..1ffb30b3 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -1246,6 +1246,30 @@ class RecipeScanner: from datetime import datetime return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S') + async def get_recipe_json_path(self, recipe_id: str) -> Optional[str]: + """Locate the recipe JSON file, accounting for folder placement.""" + + recipes_dir = self.recipes_dir + if not recipes_dir: + return None + + cache = await self.get_cached_data() + folder = "" + for item in cache.raw_data: + if str(item.get("id")) == str(recipe_id): + folder = item.get("folder") or "" + break + + candidate = os.path.normpath(os.path.join(recipes_dir, folder, f"{recipe_id}.recipe.json")) + if os.path.exists(candidate): + return candidate + + for root, _, files in os.walk(recipes_dir): + if f"{recipe_id}.recipe.json" in files: + return os.path.join(root, f"{recipe_id}.recipe.json") + + return None + async def update_recipe_metadata(self, recipe_id: str, metadata: dict) -> bool: """Update recipe metadata (like title and tags) in both file system and cache @@ -1256,13 +1280,9 @@ class RecipeScanner: Returns: bool: True if successful, False otherwise """ - import os - import json - # First, find the recipe JSON file path - recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json") - - if not os.path.exists(recipe_json_path): + recipe_json_path = await self.get_recipe_json_path(recipe_id) + if not recipe_json_path or not os.path.exists(recipe_json_path): return False try: @@ -1311,8 +1331,8 @@ class RecipeScanner: if target_name is None: raise ValueError("target_name must be provided") - recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json") - if not os.path.exists(recipe_json_path): + recipe_json_path = await self.get_recipe_json_path(recipe_id) + if not recipe_json_path or not os.path.exists(recipe_json_path): raise RecipeNotFoundError("Recipe not found") async with self._mutation_lock: diff --git a/py/services/recipes/persistence_service.py b/py/services/recipes/persistence_service.py index 535f0853..98d7e7d5 100644 --- a/py/services/recipes/persistence_service.py +++ b/py/services/recipes/persistence_service.py @@ -5,6 +5,7 @@ import base64 import json import os import re +import shutil import time import uuid from dataclasses import dataclass @@ -154,12 +155,8 @@ class RecipePersistenceService: async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult: """Delete an existing recipe.""" - recipes_dir = recipe_scanner.recipes_dir - if not recipes_dir or not os.path.exists(recipes_dir): - raise RecipeNotFoundError("Recipes directory not found") - - recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") - if not os.path.exists(recipe_json_path): + recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id) + if not recipe_json_path or not os.path.exists(recipe_json_path): raise RecipeNotFoundError("Recipe not found") with open(recipe_json_path, "r", encoding="utf-8") as file_obj: @@ -187,6 +184,83 @@ class RecipePersistenceService: return PersistenceResult({"success": True, "recipe_id": recipe_id, "updates": updates}) + async def move_recipe(self, *, recipe_scanner, recipe_id: str, target_path: str) -> PersistenceResult: + """Move a recipe's assets into a new folder under the recipes root.""" + + if not target_path: + raise RecipeValidationError("Target path is required") + + recipes_root = recipe_scanner.recipes_dir + if not recipes_root: + raise RecipeNotFoundError("Recipes directory not found") + + normalized_target = os.path.normpath(target_path) + recipes_root = os.path.normpath(recipes_root) + if not os.path.isabs(normalized_target): + normalized_target = os.path.normpath(os.path.join(recipes_root, normalized_target)) + + try: + common_root = os.path.commonpath([normalized_target, recipes_root]) + except ValueError as exc: + raise RecipeValidationError("Invalid target path") from exc + + if common_root != recipes_root: + raise RecipeValidationError("Target path must be inside the recipes directory") + + recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id) + if not recipe_json_path or not os.path.exists(recipe_json_path): + raise RecipeNotFoundError("Recipe not found") + + recipe_data = await recipe_scanner.get_recipe_by_id(recipe_id) + if not recipe_data: + raise RecipeNotFoundError("Recipe not found") + + current_json_dir = os.path.dirname(recipe_json_path) + normalized_image_path = os.path.normpath(recipe_data.get("file_path") or "") if recipe_data.get("file_path") else None + + os.makedirs(normalized_target, exist_ok=True) + + if os.path.normpath(current_json_dir) == normalized_target: + return PersistenceResult( + { + "success": True, + "message": "Recipe is already in the target folder", + "recipe_id": recipe_id, + "original_file_path": recipe_data.get("file_path"), + "new_file_path": recipe_data.get("file_path"), + } + ) + + new_json_path = os.path.normpath(os.path.join(normalized_target, os.path.basename(recipe_json_path))) + shutil.move(recipe_json_path, new_json_path) + + new_image_path = normalized_image_path + if normalized_image_path: + target_image_path = os.path.normpath(os.path.join(normalized_target, os.path.basename(normalized_image_path))) + if os.path.exists(normalized_image_path) and normalized_image_path != target_image_path: + shutil.move(normalized_image_path, target_image_path) + new_image_path = target_image_path + + relative_folder = os.path.relpath(normalized_target, recipes_root) + if relative_folder in (".", ""): + relative_folder = "" + updates = {"file_path": new_image_path or recipe_data.get("file_path"), "folder": relative_folder.replace(os.path.sep, "/")} + + updated = await recipe_scanner.update_recipe_metadata(recipe_id, updates) + if not updated: + raise RecipeNotFoundError("Recipe not found after move") + + return PersistenceResult( + { + "success": True, + "recipe_id": recipe_id, + "original_file_path": recipe_data.get("file_path"), + "new_file_path": updates["file_path"], + "json_path": new_json_path, + "folder": updates["folder"], + } + ) + async def reconnect_lora( self, *, @@ -197,8 +271,8 @@ class RecipePersistenceService: ) -> PersistenceResult: """Reconnect a LoRA entry within an existing recipe.""" - recipe_path = os.path.join(recipe_scanner.recipes_dir, f"{recipe_id}.recipe.json") - if not os.path.exists(recipe_path): + recipe_path = await recipe_scanner.get_recipe_json_path(recipe_id) + if not recipe_path or not os.path.exists(recipe_path): raise RecipeNotFoundError("Recipe not found") target_lora = await recipe_scanner.get_local_lora(target_name) @@ -243,16 +317,12 @@ class RecipePersistenceService: if not recipe_ids: raise RecipeValidationError("No recipe IDs provided") - recipes_dir = recipe_scanner.recipes_dir - if not recipes_dir or not os.path.exists(recipes_dir): - raise RecipeNotFoundError("Recipes directory not found") - deleted_recipes: list[str] = [] failed_recipes: list[dict[str, Any]] = [] for recipe_id in recipe_ids: - recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") - if not os.path.exists(recipe_json_path): + recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id) + if not recipe_json_path or not os.path.exists(recipe_json_path): failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"}) continue diff --git a/static/js/api/recipeApi.js b/static/js/api/recipeApi.js index 1421e569..632edf49 100644 --- a/static/js/api/recipeApi.js +++ b/static/js/api/recipeApi.js @@ -7,19 +7,28 @@ const RECIPE_ENDPOINTS = { detail: '/api/lm/recipe', scan: '/api/lm/recipes/scan', update: '/api/lm/recipe', + roots: '/api/lm/recipes/roots', folders: '/api/lm/recipes/folders', folderTree: '/api/lm/recipes/folder-tree', unifiedFolderTree: '/api/lm/recipes/unified-folder-tree', + move: '/api/lm/recipe/move', }; const RECIPE_SIDEBAR_CONFIG = { config: { displayName: 'Recipes', - supportsMove: false, + supportsMove: true, }, endpoints: RECIPE_ENDPOINTS, }; +function extractRecipeId(filePath) { + if (!filePath) return null; + const basename = filePath.split('/').pop().split('\\').pop(); + const dotIndex = basename.lastIndexOf('.'); + return dotIndex > 0 ? basename.substring(0, dotIndex) : basename; +} + /** * Fetch recipes with pagination for virtual scrolling * @param {number} page - Page number to fetch @@ -302,8 +311,10 @@ export async function updateRecipeMetadata(filePath, updates) { state.loadingManager.showSimpleLoading('Saving metadata...'); // Extract recipeId from filePath (basename without extension) - const basename = filePath.split('/').pop().split('\\').pop(); - const recipeId = basename.substring(0, basename.lastIndexOf('.')); + const recipeId = extractRecipeId(filePath); + if (!recipeId) { + throw new Error('Unable to determine recipe ID'); + } const response = await fetch(`${RECIPE_ENDPOINTS.update}/${recipeId}/update`, { method: 'PUT', @@ -345,6 +356,14 @@ export class RecipeSidebarApiClient { return response.json(); } + async fetchModelRoots() { + const response = await fetch(this.apiConfig.endpoints.roots); + if (!response.ok) { + throw new Error('Failed to fetch recipe roots'); + } + return response.json(); + } + async fetchModelFolders() { const response = await fetch(this.apiConfig.endpoints.folders); if (!response.ok) { @@ -353,11 +372,69 @@ export class RecipeSidebarApiClient { return response.json(); } - async moveBulkModels() { - throw new Error('Recipe move operations are not supported.'); + async moveBulkModels(filePaths, targetPath) { + const results = []; + for (const path of filePaths) { + try { + const result = await this.moveSingleModel(path, targetPath); + results.push({ + original_file_path: path, + new_file_path: result?.new_file_path, + success: !!result, + message: result?.message, + }); + } catch (error) { + results.push({ + original_file_path: path, + new_file_path: null, + success: false, + message: error.message, + }); + } + } + return results; } - async moveSingleModel() { - throw new Error('Recipe move operations are not supported.'); + async moveSingleModel(filePath, targetPath) { + if (!this.apiConfig.config.supportsMove) { + showToast('toast.api.moveNotSupported', { type: this.apiConfig.config.displayName }, 'warning'); + return null; + } + + const recipeId = extractRecipeId(filePath); + if (!recipeId) { + showToast('toast.api.moveFailed', { message: 'Recipe ID missing' }, 'error'); + return null; + } + + const response = await fetch(this.apiConfig.endpoints.move, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + recipe_id: recipeId, + target_path: targetPath, + }), + }); + + const result = await response.json(); + + if (!response.ok || !result.success) { + throw new Error(result.error || `Failed to move ${this.apiConfig.config.displayName}`); + } + + if (result.message) { + showToast('toast.api.moveInfo', { message: result.message }, 'info'); + } else { + showToast('toast.api.moveSuccess', { type: this.apiConfig.config.displayName }, 'success'); + } + + return { + original_file_path: result.original_file_path || filePath, + new_file_path: result.new_file_path || filePath, + folder: result.folder || '', + message: result.message, + }; } } diff --git a/static/js/components/ContextMenu/RecipeContextMenu.js b/static/js/components/ContextMenu/RecipeContextMenu.js index f9cb9719..6dcb709b 100644 --- a/static/js/components/ContextMenu/RecipeContextMenu.js +++ b/static/js/components/ContextMenu/RecipeContextMenu.js @@ -4,6 +4,7 @@ import { showToast, copyToClipboard, sendLoraToWorkflow } from '../../utils/uiHe import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js'; import { updateRecipeMetadata } from '../../api/recipeApi.js'; import { state } from '../../state/index.js'; +import { moveManager } from '../../managers/MoveManager.js'; export class RecipeContextMenu extends BaseContextMenu { constructor() { @@ -77,6 +78,9 @@ export class RecipeContextMenu extends BaseContextMenu { // Share recipe this.currentCard.querySelector('.fa-share-alt')?.click(); break; + case 'move': + moveManager.showMoveModal(this.currentCard.dataset.filepath); + break; case 'delete': // Delete recipe this.currentCard.querySelector('.fa-trash')?.click(); diff --git a/static/js/managers/MoveManager.js b/static/js/managers/MoveManager.js index 1b23a827..88f62839 100644 --- a/static/js/managers/MoveManager.js +++ b/static/js/managers/MoveManager.js @@ -3,6 +3,7 @@ import { state, getCurrentPageState } from '../state/index.js'; import { modalManager } from './ModalManager.js'; import { bulkManager } from './BulkManager.js'; import { getModelApiClient } from '../api/modelApiFactory.js'; +import { RecipeSidebarApiClient } from '../api/recipeApi.js'; import { FolderTreeManager } from '../components/FolderTreeManager.js'; import { sidebarManager } from '../components/SidebarManager.js'; @@ -12,11 +13,22 @@ class MoveManager { this.bulkFilePaths = null; this.folderTreeManager = new FolderTreeManager(); this.initialized = false; + this.recipeApiClient = null; // Bind methods this.updateTargetPath = this.updateTargetPath.bind(this); } + _getApiClient(modelType = null) { + if (state.currentPageType === 'recipes') { + if (!this.recipeApiClient) { + this.recipeApiClient = new RecipeSidebarApiClient(); + } + return this.recipeApiClient; + } + return getModelApiClient(modelType); + } + initializeEventListeners() { if (this.initialized) return; @@ -36,7 +48,7 @@ class MoveManager { this.currentFilePath = null; this.bulkFilePaths = null; - const apiClient = getModelApiClient(); + const apiClient = this._getApiClient(modelType); const currentPageType = state.currentPageType; const modelConfig = apiClient.apiConfig.config; @@ -121,7 +133,7 @@ class MoveManager { async initializeFolderTree() { try { - const apiClient = getModelApiClient(); + const apiClient = this._getApiClient(); // Fetch unified folder tree const treeData = await apiClient.fetchUnifiedFolderTree(); @@ -141,7 +153,7 @@ class MoveManager { updateTargetPath() { const pathDisplay = document.getElementById('moveTargetPathDisplay'); const modelRoot = document.getElementById('moveModelRoot').value; - const apiClient = getModelApiClient(); + const apiClient = this._getApiClient(); const config = apiClient.apiConfig.config; let fullPath = modelRoot || `Select a ${config.displayName.toLowerCase()} root directory`; @@ -158,7 +170,7 @@ class MoveManager { async moveModel() { const selectedRoot = document.getElementById('moveModelRoot').value; - const apiClient = getModelApiClient(); + const apiClient = this._getApiClient(); const config = apiClient.apiConfig.config; if (!selectedRoot) { diff --git a/templates/recipes.html b/templates/recipes.html index aba5ce63..f4f1d5bd 100644 --- a/templates/recipes.html +++ b/templates/recipes.html @@ -15,17 +15,26 @@ {% endblock %} @@ -34,55 +43,59 @@ {% block init_check_url %}/api/recipes?page=1&page_size=1{% endblock %} {% block content %} - -
-
-
- -
-
- -
- -
- -
- -