Compare commits

...

22 Commits

Author SHA1 Message Date
Will Miao
9112cd3b62 chore: Add .claude/ to gitignore
Exclude Claude Code personal configuration directory containing:
- settings.local.json (personal permissions and local paths)
- skills/ (personal skills)

These contain machine-specific paths and personal preferences
that should not be shared across the team.
2026-03-22 14:17:15 +08:00
Will Miao
7df4e8d037 fix(metadata_hook): correct function signature to fix bound method error
Fix issue #866 where the metadata hook's async wrapper used *args/**kwargs
which caused AttributeError when ComfyUI's make_locked_method_func tried
to access __func__ on the func parameter.

The async_map_node_over_list_with_metadata wrapper now uses the exact
same signature as ComfyUI's _async_map_node_over_list:
- Removed: *args, **kwargs
- Added: explicit v3_data=None parameter

This ensures the func parameter (always a string like obj.FUNCTION) is
passed correctly to make_locked_method_func without any type conversion.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-22 13:25:04 +08:00
Will Miao
4000b7f7e7 feat: Add configurable LoRA strength adjustment step setting
Implements issue #808 - Allow users to customize the strength
variation range for LoRA widget arrow buttons.

Changes:
- Add 'Strength Adjustment Step' setting (0.01-0.1) in settings.js
- Replace hardcoded 0.05 increments with configurable step value
- Apply to both LoRA strength and CLIP strength controls

Fixes #808
2026-03-19 17:33:18 +08:00
Will Miao
76c15105e6 feat(lora-pool): add regex include/exclude name pattern filtering (#839)
Add name pattern filtering to LoRA Pool node allowing users to filter
LoRAs by filename or model name using either plain text or regex patterns.

Features:
- Include patterns: only show LoRAs matching at least one pattern
- Exclude patterns: exclude LoRAs matching any pattern
- Regex toggle: switch between substring and regex matching
- Case-insensitive matching for both modes
- Invalid regex automatically falls back to substring matching
- Filters apply to both file_name and model_name fields

Backend:
- Update LoraPoolLM._default_config() with namePatterns structure
- Add name pattern filtering to _apply_pool_filters() and _apply_specific_filters()
- Add API parameter parsing for name_pattern_include/exclude/use_regex
- Update LoraPoolConfig type with namePatterns field

Frontend:
- Add NamePatternsSection.vue component with pattern input UI
- Update useLoraPoolState to manage pattern state and API integration
- Update LoraPoolSummaryView to display NamePatternsSection
- Increase LORA_POOL_WIDGET_MIN_HEIGHT to accommodate new UI

Tests:
- Add 7 test cases covering text/regex include, exclude, combined
  filtering, model name fallback, and invalid regex handling

Closes #839
2026-03-19 17:15:05 +08:00
Will Miao
b11c90e19b feat: add type ignore comments and remove unused imports
- Add `# type: ignore` comments to comfy.sd and folder_paths imports
- Remove unused imports: os, random, and extract_lora_name
- Clean up import statements across checkpoint_loader, lora_randomizer, and unet_loader nodes
2026-03-19 15:54:49 +08:00
pixelpaws
9f5d2d0c18 Merge pull request #862 from EnragedAntelope/claude/add-webp-image-support-t8kG9
Improve webp image support
2026-03-19 15:35:16 +08:00
Will Miao
a0dc5229f4 feat(unet_loader): move torch import inside methods for lazy loading
- Delay torch import until needed in load_unet and load_unet_gguf methods
- This improves module loading performance by avoiding unnecessary imports
- Maintains functionality while reducing initial import overhead
2026-03-19 15:29:41 +08:00
Will Miao
61c31ecbd0 fix: exclude __init__.py from pytest collection to prevent CI import errors 2026-03-19 14:43:45 +08:00
Will Miao
1ae1b0d607 refactor: move No LoRA feature from LoRA Pool to Lora Cycler widget
Move the 'empty/no LoRA' cycling functionality from the LoRA Pool node
to the Lora Cycler widget for cleaner architecture:

Frontend changes:
- Add include_no_lora field to CyclerConfig interface
- Add includeNoLora state and logic to useLoraCyclerState composable
- Add toggle UI in LoraCyclerSettingsView with special styling
- Show 'No LoRA' entry in LoraListModal when enabled
- Update LoraCyclerWidget to integrate new logic

Backend changes:
- lora_cycler.py reads include_no_lora from config
- Calculate effective_total_count (actual count + 1 when enabled)
- Return empty lora_stack when on No LoRA position
- Return actual LoRA count in total_count (not effective count)

Reverted files to pre-PR state:
- lora_loader.py, lora_pool.py, lora_randomizer.py, lora_stacker.py
- lora_routes.py, lora_service.py
- LoraPoolWidget.vue and related files

Related to PR #861

Co-authored-by: dogatech <dogatech@dogatech.home>
2026-03-19 14:19:49 +08:00
dogatech
8dd849892d Allow for empty lora (no loras option) in Lora Pool 2026-03-19 09:23:03 +08:00
Will Miao
03e1fa75c5 feat: auto-focus URL input when batch import modal opens 2026-03-18 22:33:45 +08:00
Will Miao
fefcaa4a45 fix: improve Civitai recipe import by extracting EXIF when API metadata is empty
- Add validation to check if Civitai API metadata contains recipe fields
- Fall back to EXIF extraction when API returns empty metadata (meta.meta=null)
- Improve error messages to distinguish between missing metadata and unsupported format
- Add _has_recipe_fields() helper method to validate metadata content

This fixes import failures for Civitai images where the API returns
metadata wrapper but no actual generation parameters (e.g., images
edited in Photoshop that lost their original generation metadata)
2026-03-18 22:30:36 +08:00
Will Miao
701a6a6c44 refactor: remove GGUF loading logic from CheckpointLoaderLM
GGUF models are pure Unet models and should be handled by UNETLoaderLM.
2026-03-18 21:36:07 +08:00
Will Miao
0ef414d17e feat: standardize Checkpoint/Unet loader names and use OS-native path separators
- Rename nodes to 'Checkpoint Loader (LoraManager)' and 'Unet Loader (LoraManager)'\n- Use os.sep for relative path formatting in model COMBO inputs\n- Update path matching to be robust across OS separators\n- Update docstrings and comments
2026-03-18 21:33:19 +08:00
Will Miao
75dccaef87 test: fix cache validator tests to account for new hash_status field and side effects 2026-03-18 21:10:56 +08:00
Will Miao
7e87ec9521 fix: persist hash_status in model cache to support lazy hashing on restart 2026-03-18 21:07:40 +08:00
Will Miao
46522edb1b refactor: simplify GGUF import helper with dynamic path detection
- Add _get_gguf_path() to dynamically derive ComfyUI-GGUF path from current file location
- Remove Strategy 2 and 3, keeping only Strategy 1 (sys.modules path-based lookup)
- Remove hard-coded absolute paths
- Streamline logging output
- Code cleanup: reduced from 235 to 154 lines
2026-03-18 19:55:54 +08:00
Will Miao
2dae4c1291 fix: isolate extra unet paths from checkpoints to prevent type misclassification
Refactor _prepare_checkpoint_paths() to return a tuple instead of having
side effects on instance variables. This prevents extra unet paths from
being incorrectly classified as checkpoints when processing extra paths.

- Changed return type from List[str] to Tuple[List[str], List[str], List[str]]
  (all_paths, checkpoint_roots, unet_roots)
- Updated _init_checkpoint_paths() and _apply_library_paths() callers
- Fixed extra paths processing to properly isolate main and extra roots
- Updated test_checkpoint_path_overlap.py tests for new API

This ensures models in extra unet paths are correctly identified as
diffusion_model type and don't appear in checkpoints list.
2026-03-17 22:03:57 +08:00
EnragedAntelope
a32325402e Merge branch 'willmiao:main' into claude/add-webp-image-support-t8kG9 2026-03-17 08:37:46 -04:00
Will Miao
70c150bd80 fix(services): implement stable sorting for model and recipe caches
Add file_path as a tie-breaker for all sort modes in ModelCache, BaseModelService, LoraService, and RecipeCache to ensure deterministic ordering when primary keys are identical. Resolves issue #859.
2026-03-17 14:20:23 +08:00
Claude
05ebd7493d chore: update package-lock.json after npm install
https://claude.ai/code/session_01SgT2pkisi27bEQELX5EeXZ
2026-03-17 01:33:34 +00:00
Claude
90986bd795 feat: add case-insensitive webp support for lora cover photos
Make preview file discovery case-insensitive so files with uppercase
extensions like .WEBP are found on case-sensitive filesystems. Also
explicitly list image/webp in the file picker accept attribute for
broader browser compatibility.

https://claude.ai/code/session_01SgT2pkisi27bEQELX5EeXZ
2026-03-17 01:32:48 +00:00
47 changed files with 2708 additions and 475 deletions

1
.gitignore vendored
View File

@@ -14,6 +14,7 @@ model_cache/
# agent # agent
.opencode/ .opencode/
.claude/
# Vue widgets development cache (but keep build output) # Vue widgets development cache (but keep build output)
vue-widgets/node_modules/ vue-widgets/node_modules/

View File

@@ -1,6 +1,8 @@
try: # pragma: no cover - import fallback for pytest collection try: # pragma: no cover - import fallback for pytest collection
from .py.lora_manager import LoraManager from .py.lora_manager import LoraManager
from .py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM from .py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM
from .py.nodes.checkpoint_loader import CheckpointLoaderLM
from .py.nodes.unet_loader import UNETLoaderLM
from .py.nodes.trigger_word_toggle import TriggerWordToggleLM from .py.nodes.trigger_word_toggle import TriggerWordToggleLM
from .py.nodes.prompt import PromptLM from .py.nodes.prompt import PromptLM
from .py.nodes.text import TextLM from .py.nodes.text import TextLM
@@ -27,12 +29,12 @@ except (
PromptLM = importlib.import_module("py.nodes.prompt").PromptLM PromptLM = importlib.import_module("py.nodes.prompt").PromptLM
TextLM = importlib.import_module("py.nodes.text").TextLM TextLM = importlib.import_module("py.nodes.text").TextLM
LoraManager = importlib.import_module("py.lora_manager").LoraManager LoraManager = importlib.import_module("py.lora_manager").LoraManager
LoraLoaderLM = importlib.import_module( LoraLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraLoaderLM
"py.nodes.lora_loader" LoraTextLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraTextLoaderLM
).LoraLoaderLM CheckpointLoaderLM = importlib.import_module(
LoraTextLoaderLM = importlib.import_module( "py.nodes.checkpoint_loader"
"py.nodes.lora_loader" ).CheckpointLoaderLM
).LoraTextLoaderLM UNETLoaderLM = importlib.import_module("py.nodes.unet_loader").UNETLoaderLM
TriggerWordToggleLM = importlib.import_module( TriggerWordToggleLM = importlib.import_module(
"py.nodes.trigger_word_toggle" "py.nodes.trigger_word_toggle"
).TriggerWordToggleLM ).TriggerWordToggleLM
@@ -49,9 +51,7 @@ except (
LoraRandomizerLM = importlib.import_module( LoraRandomizerLM = importlib.import_module(
"py.nodes.lora_randomizer" "py.nodes.lora_randomizer"
).LoraRandomizerLM ).LoraRandomizerLM
LoraCyclerLM = importlib.import_module( LoraCyclerLM = importlib.import_module("py.nodes.lora_cycler").LoraCyclerLM
"py.nodes.lora_cycler"
).LoraCyclerLM
init_metadata_collector = importlib.import_module("py.metadata_collector").init init_metadata_collector = importlib.import_module("py.metadata_collector").init
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
@@ -59,6 +59,8 @@ NODE_CLASS_MAPPINGS = {
TextLM.NAME: TextLM, TextLM.NAME: TextLM,
LoraLoaderLM.NAME: LoraLoaderLM, LoraLoaderLM.NAME: LoraLoaderLM,
LoraTextLoaderLM.NAME: LoraTextLoaderLM, LoraTextLoaderLM.NAME: LoraTextLoaderLM,
CheckpointLoaderLM.NAME: CheckpointLoaderLM,
UNETLoaderLM.NAME: UNETLoaderLM,
TriggerWordToggleLM.NAME: TriggerWordToggleLM, TriggerWordToggleLM.NAME: TriggerWordToggleLM,
LoraStackerLM.NAME: LoraStackerLM, LoraStackerLM.NAME: LoraStackerLM,
SaveImageLM.NAME: SaveImageLM, SaveImageLM.NAME: SaveImageLM,

3
package-lock.json generated
View File

@@ -114,7 +114,6 @@
} }
], ],
"license": "MIT", "license": "MIT",
"peer": true,
"engines": { "engines": {
"node": ">=18" "node": ">=18"
}, },
@@ -138,7 +137,6 @@
} }
], ],
"license": "MIT", "license": "MIT",
"peer": true,
"engines": { "engines": {
"node": ">=18" "node": ">=18"
} }
@@ -1613,7 +1611,6 @@
"integrity": "sha512-MyL55p3Ut3cXbeBEG7Hcv0mVM8pp8PBNWxRqchZnSfAiES1v1mRnMeFfaHWIPULpwsYfvO+ZmMZz5tGCnjzDUQ==", "integrity": "sha512-MyL55p3Ut3cXbeBEG7Hcv0mVM8pp8PBNWxRqchZnSfAiES1v1mRnMeFfaHWIPULpwsYfvO+ZmMZz5tGCnjzDUQ==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"peer": true,
"dependencies": { "dependencies": {
"cssstyle": "^4.0.1", "cssstyle": "^4.0.1",
"data-urls": "^5.0.0", "data-urls": "^5.0.0",

View File

@@ -707,7 +707,13 @@ class Config:
def _prepare_checkpoint_paths( def _prepare_checkpoint_paths(
self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str] self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str]
) -> List[str]: ) -> Tuple[List[str], List[str], List[str]]:
"""Prepare checkpoint paths and return (all_roots, checkpoint_roots, unet_roots).
Returns:
Tuple of (all_unique_paths, checkpoint_only_paths, unet_only_paths)
This method does NOT modify instance variables - callers must set them.
"""
checkpoint_map = self._dedupe_existing_paths(checkpoint_paths) checkpoint_map = self._dedupe_existing_paths(checkpoint_paths)
unet_map = self._dedupe_existing_paths(unet_paths) unet_map = self._dedupe_existing_paths(unet_paths)
@@ -737,8 +743,8 @@ class Config:
checkpoint_values = set(checkpoint_map.values()) checkpoint_values = set(checkpoint_map.values())
unet_values = set(unet_map.values()) unet_values = set(unet_map.values())
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_values] checkpoint_roots = [p for p in unique_paths if p in checkpoint_values]
self.unet_roots = [p for p in unique_paths if p in unet_values] unet_roots = [p for p in unique_paths if p in unet_values]
for original_path in unique_paths: for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace( real_path = os.path.normpath(os.path.realpath(original_path)).replace(
@@ -747,7 +753,7 @@ class Config:
if real_path != original_path: if real_path != original_path:
self.add_path_mapping(original_path, real_path) self.add_path_mapping(original_path, real_path)
return unique_paths return unique_paths, checkpoint_roots, unet_roots
def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]: def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]:
path_map = self._dedupe_existing_paths(raw_paths) path_map = self._dedupe_existing_paths(raw_paths)
@@ -776,9 +782,11 @@ class Config:
embedding_paths = folder_paths.get("embeddings", []) or [] embedding_paths = folder_paths.get("embeddings", []) or []
self.loras_roots = self._prepare_lora_paths(lora_paths) self.loras_roots = self._prepare_lora_paths(lora_paths)
self.base_models_roots = self._prepare_checkpoint_paths( (
checkpoint_paths, unet_paths self.base_models_roots,
) self.checkpoints_roots,
self.unet_roots,
) = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths)
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths) self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
# Process extra paths (only for LoRA Manager, not shared with ComfyUI) # Process extra paths (only for LoRA Manager, not shared with ComfyUI)
@@ -789,18 +797,11 @@ class Config:
extra_embedding_paths = extra_paths.get("embeddings", []) or [] extra_embedding_paths = extra_paths.get("embeddings", []) or []
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths) self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
# Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them) (
saved_checkpoints_roots = self.checkpoints_roots _,
saved_unet_roots = self.unet_roots self.extra_checkpoints_roots,
self.extra_checkpoints_roots = self._prepare_checkpoint_paths( self.extra_unet_roots,
extra_checkpoint_paths, extra_unet_paths ) = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths)
)
self.extra_unet_roots = (
self.unet_roots if self.unet_roots is not None else []
) # unet_roots was set by _prepare_checkpoint_paths
# Restore main paths
self.checkpoints_roots = saved_checkpoints_roots
self.unet_roots = saved_unet_roots
self.extra_embeddings_roots = self._prepare_embedding_paths( self.extra_embeddings_roots = self._prepare_embedding_paths(
extra_embedding_paths extra_embedding_paths
) )
@@ -857,9 +858,11 @@ class Config:
try: try:
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints") raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
raw_unet_paths = folder_paths.get_folder_paths("unet") raw_unet_paths = folder_paths.get_folder_paths("unet")
unique_paths = self._prepare_checkpoint_paths( (
raw_checkpoint_paths, raw_unet_paths unique_paths,
) self.checkpoints_roots,
self.unet_roots,
) = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths)
logger.info( logger.info(
"Found checkpoint roots:" "Found checkpoint roots:"

View File

@@ -148,10 +148,13 @@ class MetadataHook:
"""Install hooks for asynchronous execution model""" """Install hooks for asynchronous execution model"""
# Store the original _async_map_node_over_list function # Store the original _async_map_node_over_list function
original_map_node_over_list = getattr(execution, map_node_func_name) original_map_node_over_list = getattr(execution, map_node_func_name)
# Wrapped async function, compatible with both stable and nightly # Wrapped async function - signature must exactly match _async_map_node_over_list
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, *args, **kwargs): async def async_map_node_over_list_with_metadata(
hidden_inputs = kwargs.get('hidden_inputs', None) prompt_id, unique_id, obj, input_data_all, func,
allow_interrupt=False, execution_block_cb=None,
pre_execute_cb=None, v3_data=None
):
# Only collect metadata when calling the main function of nodes # Only collect metadata when calling the main function of nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'): if func == obj.FUNCTION and hasattr(obj, '__class__'):
try: try:
@@ -163,13 +166,13 @@ class MetadataHook:
registry.record_node_execution(node_id, class_type, input_data_all, None) registry.record_node_execution(node_id, class_type, input_data_all, None)
except Exception as e: except Exception as e:
logger.error(f"Error collecting metadata (pre-execution): {str(e)}") logger.error(f"Error collecting metadata (pre-execution): {str(e)}")
# Call original function with all args/kwargs # Call original function with exact parameters
results = await original_map_node_over_list( results = await original_map_node_over_list(
prompt_id, unique_id, obj, input_data_all, func, prompt_id, unique_id, obj, input_data_all, func,
allow_interrupt, execution_block_cb, pre_execute_cb, *args, **kwargs allow_interrupt, execution_block_cb, pre_execute_cb, v3_data=v3_data
) )
if func == obj.FUNCTION and hasattr(obj, '__class__'): if func == obj.FUNCTION and hasattr(obj, '__class__'):
try: try:
registry = MetadataRegistry() registry = MetadataRegistry()
@@ -180,28 +183,28 @@ class MetadataHook:
registry.update_node_execution(node_id, class_type, results) registry.update_node_execution(node_id, class_type, results)
except Exception as e: except Exception as e:
logger.error(f"Error collecting metadata (post-execution): {str(e)}") logger.error(f"Error collecting metadata (post-execution): {str(e)}")
return results return results
# Also hook the execute function to track the current prompt_id # Also hook the execute function to track the current prompt_id
original_execute = execution.execute original_execute = execution.execute
async def async_execute_with_prompt_tracking(*args, **kwargs): async def async_execute_with_prompt_tracking(*args, **kwargs):
if len(args) >= 7: # Check if we have enough arguments if len(args) >= 7: # Check if we have enough arguments
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7] server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
registry = MetadataRegistry() registry = MetadataRegistry()
# Start collection if this is a new prompt # Start collection if this is a new prompt
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id: if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
registry.start_collection(prompt_id) registry.start_collection(prompt_id)
# Store the dynprompt reference for node lookups # Store the dynprompt reference for node lookups
if hasattr(prompt, 'original_prompt'): if hasattr(prompt, 'original_prompt'):
registry.set_current_prompt(prompt) registry.set_current_prompt(prompt)
# Execute the original function # Execute the original function
return await original_execute(*args, **kwargs) return await original_execute(*args, **kwargs)
# Replace the functions with async versions # Replace the functions with async versions
setattr(execution, map_node_func_name, async_map_node_over_list_with_metadata) setattr(execution, map_node_func_name, async_map_node_over_list_with_metadata)
execution.execute = async_execute_with_prompt_tracking execution.execute = async_execute_with_prompt_tracking

View File

@@ -0,0 +1,118 @@
import logging
from typing import List, Tuple
import comfy.sd # type: ignore
import folder_paths # type: ignore
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
logger = logging.getLogger(__name__)
class CheckpointLoaderLM:
"""Checkpoint Loader with support for extra folder paths
Loads checkpoints from both standard ComfyUI folders and LoRA Manager's
extra folder paths, providing a unified interface for checkpoint loading.
"""
NAME = "Checkpoint Loader (LoraManager)"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(s):
# Get list of checkpoint names from scanner (includes extra folder paths)
checkpoint_names = s._get_checkpoint_names()
return {
"required": {
"ckpt_name": (
checkpoint_names,
{"tooltip": "The name of the checkpoint (model) to load."},
),
}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
RETURN_NAMES = ("MODEL", "CLIP", "VAE")
OUTPUT_TOOLTIPS = (
"The model used for denoising latents.",
"The CLIP model used for encoding text prompts.",
"The VAE model used for encoding and decoding images to and from latent space.",
)
FUNCTION = "load_checkpoint"
@classmethod
def _get_checkpoint_names(cls) -> List[str]:
"""Get list of checkpoint names from scanner cache in ComfyUI format (relative path with extension)"""
try:
from ..services.service_registry import ServiceRegistry
import asyncio
async def _get_names():
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
# Get all model roots for calculating relative paths
model_roots = scanner.get_model_roots()
# Filter only checkpoint type (not diffusion_model) and format names
names = []
for item in cache.raw_data:
if item.get("sub_type") == "checkpoint":
file_path = item.get("file_path", "")
if file_path:
# Format using relative path with OS-native separator
formatted_name = _format_model_name_for_comfyui(
file_path, model_roots
)
if formatted_name:
names.append(formatted_name)
return sorted(names)
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(_get_names())
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()
except RuntimeError:
return asyncio.run(_get_names())
except Exception as e:
logger.error(f"Error getting checkpoint names: {e}")
return []
def load_checkpoint(self, ckpt_name: str) -> Tuple:
"""Load a checkpoint by name, supporting extra folder paths
Args:
ckpt_name: The name of the checkpoint to load (relative path with extension)
Returns:
Tuple of (MODEL, CLIP, VAE)
"""
# Get absolute path from cache using ComfyUI-style name
ckpt_path, metadata = get_checkpoint_info_absolute(ckpt_name)
if metadata is None:
raise FileNotFoundError(
f"Checkpoint '{ckpt_name}' not found in LoRA Manager cache. "
"Make sure the checkpoint is indexed and try again."
)
# Load regular checkpoint using ComfyUI's API
logger.info(f"Loading checkpoint from: {ckpt_path}")
out = comfy.sd.load_checkpoint_guess_config(
ckpt_path,
output_vae=True,
output_clip=True,
embedding_directory=folder_paths.get_folder_paths("embeddings"),
)
return out[:3]

View File

@@ -0,0 +1,161 @@
"""
Helper module to safely import ComfyUI-GGUF modules.
This module provides a robust way to import ComfyUI-GGUF functionality
regardless of how ComfyUI loaded it.
"""
import sys
import os
import importlib.util
import logging
from typing import Optional, Tuple, Any
logger = logging.getLogger(__name__)
def _get_gguf_path() -> str:
"""Get the path to ComfyUI-GGUF based on this file's location.
Since ComfyUI-Lora-Manager and ComfyUI-GGUF are both in custom_nodes/,
we can derive the GGUF path from our own location.
"""
# This file is at: custom_nodes/ComfyUI-Lora-Manager/py/nodes/gguf_import_helper.py
# ComfyUI-GGUF is at: custom_nodes/ComfyUI-GGUF
current_file = os.path.abspath(__file__)
# Go up 4 levels: nodes -> py -> ComfyUI-Lora-Manager -> custom_nodes
custom_nodes_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
)
return os.path.join(custom_nodes_dir, "ComfyUI-GGUF")
def _find_gguf_module() -> Optional[Any]:
"""Find ComfyUI-GGUF module in sys.modules.
ComfyUI registers modules using the full path with dots replaced by _x_.
"""
gguf_path = _get_gguf_path()
sys_module_name = gguf_path.replace(".", "_x_")
logger.debug(f"[GGUF Import] Looking for module '{sys_module_name}' in sys.modules")
if sys_module_name in sys.modules:
logger.info(f"[GGUF Import] Found module: '{sys_module_name}'")
return sys.modules[sys_module_name]
logger.debug(f"[GGUF Import] Module not found: '{sys_module_name}'")
return None
def _load_gguf_modules_directly() -> Optional[Any]:
"""Load ComfyUI-GGUF modules directly from file paths."""
gguf_path = _get_gguf_path()
logger.info(f"[GGUF Import] Direct Load: Attempting to load from '{gguf_path}'")
if not os.path.exists(gguf_path):
logger.warning(f"[GGUF Import] Path does not exist: {gguf_path}")
return None
try:
namespace = "ComfyUI_GGUF_Dynamic"
init_path = os.path.join(gguf_path, "__init__.py")
if not os.path.exists(init_path):
logger.warning(f"[GGUF Import] __init__.py not found at '{init_path}'")
return None
logger.debug(f"[GGUF Import] Loading from '{init_path}'")
spec = importlib.util.spec_from_file_location(namespace, init_path)
if not spec or not spec.loader:
logger.error(f"[GGUF Import] Failed to create spec for '{init_path}'")
return None
package = importlib.util.module_from_spec(spec)
package.__path__ = [gguf_path]
sys.modules[namespace] = package
spec.loader.exec_module(package)
logger.debug(f"[GGUF Import] Loaded main package '{namespace}'")
# Load submodules
loaded = []
for submod_name in ["loader", "ops", "nodes"]:
submod_path = os.path.join(gguf_path, f"{submod_name}.py")
if os.path.exists(submod_path):
submod_spec = importlib.util.spec_from_file_location(
f"{namespace}.{submod_name}", submod_path
)
if submod_spec and submod_spec.loader:
submod = importlib.util.module_from_spec(submod_spec)
submod.__package__ = namespace
sys.modules[f"{namespace}.{submod_name}"] = submod
submod_spec.loader.exec_module(submod)
setattr(package, submod_name, submod)
loaded.append(submod_name)
logger.debug(f"[GGUF Import] Loaded submodule '{submod_name}'")
logger.info(f"[GGUF Import] Direct Load success: {loaded}")
return package
except Exception as e:
logger.error(f"[GGUF Import] Direct Load failed: {e}", exc_info=True)
return None
def get_gguf_modules() -> Tuple[Any, Any, Any]:
"""Get ComfyUI-GGUF modules (loader, ops, nodes).
Returns:
Tuple of (loader_module, ops_module, nodes_module)
Raises:
RuntimeError: If ComfyUI-GGUF cannot be found or loaded.
"""
logger.debug("[GGUF Import] Starting module search...")
# Try to find already loaded module first
gguf_module = _find_gguf_module()
if gguf_module is None:
logger.info("[GGUF Import] Not found in sys.modules, trying direct load...")
gguf_module = _load_gguf_modules_directly()
if gguf_module is None:
raise RuntimeError(
"ComfyUI-GGUF is not installed. "
"Please install from https://github.com/city96/ComfyUI-GGUF"
)
# Extract submodules
loader = getattr(gguf_module, "loader", None)
ops = getattr(gguf_module, "ops", None)
nodes = getattr(gguf_module, "nodes", None)
if loader is None or ops is None or nodes is None:
missing = [
name
for name, mod in [("loader", loader), ("ops", ops), ("nodes", nodes)]
if mod is None
]
raise RuntimeError(f"ComfyUI-GGUF missing submodules: {missing}")
logger.debug("[GGUF Import] All modules loaded successfully")
return loader, ops, nodes
def get_gguf_sd_loader():
"""Get the gguf_sd_loader function from ComfyUI-GGUF."""
loader, _, _ = get_gguf_modules()
return getattr(loader, "gguf_sd_loader")
def get_ggml_ops():
"""Get the GGMLOps class from ComfyUI-GGUF."""
_, ops, _ = get_gguf_modules()
return getattr(ops, "GGMLOps")
def get_gguf_model_patcher():
"""Get the GGUFModelPatcher class from ComfyUI-GGUF."""
_, _, nodes = get_gguf_modules()
return getattr(nodes, "GGUFModelPatcher")

View File

@@ -56,6 +56,9 @@ class LoraCyclerLM:
clip_strength = float(cycler_config.get("clip_strength", 1.0)) clip_strength = float(cycler_config.get("clip_strength", 1.0))
sort_by = "filename" sort_by = "filename"
# Include "no lora" option
include_no_lora = cycler_config.get("include_no_lora", False)
# Dual-index mechanism for batch queue synchronization # Dual-index mechanism for batch queue synchronization
execution_index = cycler_config.get("execution_index") # Can be None execution_index = cycler_config.get("execution_index") # Can be None
# next_index_from_config = cycler_config.get("next_index") # Not used on backend # next_index_from_config = cycler_config.get("next_index") # Not used on backend
@@ -71,7 +74,10 @@ class LoraCyclerLM:
total_count = len(lora_list) total_count = len(lora_list)
if total_count == 0: # Calculate effective total count (includes no lora option if enabled)
effective_total_count = total_count + 1 if include_no_lora else total_count
if total_count == 0 and not include_no_lora:
logger.warning("[LoraCyclerLM] No LoRAs available in pool") logger.warning("[LoraCyclerLM] No LoRAs available in pool")
return { return {
"result": ([],), "result": ([],),
@@ -93,42 +99,66 @@ class LoraCyclerLM:
else: else:
actual_index = current_index actual_index = current_index
# Clamp index to valid range (1-based) # Clamp index to valid range (1-based, includes no lora if enabled)
clamped_index = max(1, min(actual_index, total_count)) clamped_index = max(1, min(actual_index, effective_total_count))
# Get LoRA at current index (convert to 0-based for list access) # Check if current index is the "no lora" option (last position when include_no_lora is True)
current_lora = lora_list[clamped_index - 1] is_no_lora = include_no_lora and clamped_index == effective_total_count
# Build LORA_STACK with single LoRA if is_no_lora:
lora_path, _ = get_lora_info(current_lora["file_name"]) # "No LoRA" option - return empty stack
if not lora_path:
logger.warning(
f"[LoraCyclerLM] Could not find path for LoRA: {current_lora['file_name']}"
)
lora_stack = [] lora_stack = []
current_lora_name = "No LoRA"
current_lora_filename = "No LoRA"
else: else:
# Normalize path separators # Get LoRA at current index (convert to 0-based for list access)
lora_path = lora_path.replace("/", os.sep) current_lora = lora_list[clamped_index - 1]
lora_stack = [(lora_path, model_strength, clip_strength)] current_lora_name = current_lora["file_name"]
current_lora_filename = current_lora["file_name"]
# Build LORA_STACK with single LoRA
if current_lora["file_name"] == "None":
lora_path = None
else:
lora_path, _ = get_lora_info(current_lora["file_name"])
if not lora_path:
if current_lora["file_name"] != "None":
logger.warning(
f"[LoraCyclerLM] Could not find path for LoRA: {current_lora['file_name']}"
)
lora_stack = []
else:
# Normalize path separators
lora_path = lora_path.replace("/", os.sep)
lora_stack = [(lora_path, model_strength, clip_strength)]
# Calculate next index (wrap to 1 if at end) # Calculate next index (wrap to 1 if at end)
next_index = clamped_index + 1 next_index = clamped_index + 1
if next_index > total_count: if next_index > effective_total_count:
next_index = 1 next_index = 1
# Get next LoRA for UI display (what will be used next generation) # Get next LoRA for UI display (what will be used next generation)
next_lora = lora_list[next_index - 1] is_next_no_lora = include_no_lora and next_index == effective_total_count
next_display_name = next_lora["file_name"] if is_next_no_lora:
next_display_name = "No LoRA"
next_lora_filename = "No LoRA"
else:
next_lora = lora_list[next_index - 1]
next_display_name = next_lora["file_name"]
next_lora_filename = next_lora["file_name"]
return { return {
"result": (lora_stack,), "result": (lora_stack,),
"ui": { "ui": {
"current_index": [clamped_index], "current_index": [clamped_index],
"next_index": [next_index], "next_index": [next_index],
"total_count": [total_count], "total_count": [
"current_lora_name": [current_lora["file_name"]], total_count
"current_lora_filename": [current_lora["file_name"]], ], # Return actual LoRA count, not effective_total_count
"current_lora_name": [current_lora_name],
"current_lora_filename": [current_lora_filename],
"next_lora_name": [next_display_name], "next_lora_name": [next_display_name],
"next_lora_filename": [next_lora["file_name"]], "next_lora_filename": [next_lora_filename],
}, },
} }

View File

@@ -82,6 +82,7 @@ class LoraPoolLM:
"folders": {"include": [], "exclude": []}, "folders": {"include": [], "exclude": []},
"favoritesOnly": False, "favoritesOnly": False,
"license": {"noCreditRequired": False, "allowSelling": False}, "license": {"noCreditRequired": False, "allowSelling": False},
"namePatterns": {"include": [], "exclude": [], "useRegex": False},
}, },
"preview": {"matchCount": 0, "lastUpdated": 0}, "preview": {"matchCount": 0, "lastUpdated": 0},
} }

View File

@@ -7,10 +7,8 @@ and tracks the last used combination for reuse.
""" """
import logging import logging
import random
import os import os
from ..utils.utils import get_lora_info from ..utils.utils import get_lora_info
from .utils import extract_lora_name
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

205
py/nodes/unet_loader.py Normal file
View File

@@ -0,0 +1,205 @@
import logging
import os
from typing import List, Tuple
import comfy.sd # type: ignore
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
logger = logging.getLogger(__name__)
class UNETLoaderLM:
"""UNET Loader with support for extra folder paths
Loads diffusion models/UNets from both standard ComfyUI folders and LoRA Manager's
extra folder paths, providing a unified interface for UNET loading.
Supports both regular diffusion models and GGUF format models.
"""
NAME = "Unet Loader (LoraManager)"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(s):
# Get list of unet names from scanner (includes extra folder paths)
unet_names = s._get_unet_names()
return {
"required": {
"unet_name": (
unet_names,
{"tooltip": "The name of the diffusion model to load."},
),
"weight_dtype": (
["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],
{"tooltip": "The dtype to use for the model weights."},
),
}
}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("MODEL",)
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",)
FUNCTION = "load_unet"
@classmethod
def _get_unet_names(cls) -> List[str]:
"""Get list of diffusion model names from scanner cache in ComfyUI format (relative path with extension)"""
try:
from ..services.service_registry import ServiceRegistry
import asyncio
async def _get_names():
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
# Get all model roots for calculating relative paths
model_roots = scanner.get_model_roots()
# Filter only diffusion_model type and format names
names = []
for item in cache.raw_data:
if item.get("sub_type") == "diffusion_model":
file_path = item.get("file_path", "")
if file_path:
# Format using relative path with OS-native separator
formatted_name = _format_model_name_for_comfyui(
file_path, model_roots
)
if formatted_name:
names.append(formatted_name)
return sorted(names)
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(_get_names())
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()
except RuntimeError:
return asyncio.run(_get_names())
except Exception as e:
logger.error(f"Error getting unet names: {e}")
return []
def load_unet(self, unet_name: str, weight_dtype: str) -> Tuple:
"""Load a diffusion model by name, supporting extra folder paths
Args:
unet_name: The name of the diffusion model to load (relative path with extension)
weight_dtype: The dtype to use for model weights
Returns:
Tuple of (MODEL,)
"""
import torch
# Get absolute path from cache using ComfyUI-style name
unet_path, metadata = get_checkpoint_info_absolute(unet_name)
if metadata is None:
raise FileNotFoundError(
f"Diffusion model '{unet_name}' not found in LoRA Manager cache. "
"Make sure the model is indexed and try again."
)
# Check if it's a GGUF model
if unet_path.endswith(".gguf"):
return self._load_gguf_unet(unet_path, unet_name, weight_dtype)
# Load regular diffusion model using ComfyUI's API
logger.info(f"Loading diffusion model from: {unet_path}")
# Build model options based on weight_dtype
model_options = {}
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e4m3fn_fast":
model_options["dtype"] = torch.float8_e4m3fn
model_options["fp8_optimizations"] = True
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
def _load_gguf_unet(
self, unet_path: str, unet_name: str, weight_dtype: str
) -> Tuple:
"""Load a GGUF format diffusion model
Args:
unet_path: Absolute path to the GGUF file
unet_name: Name of the model for error messages
weight_dtype: The dtype to use for model weights
Returns:
Tuple of (MODEL,)
"""
import torch
from .gguf_import_helper import get_gguf_modules
# Get ComfyUI-GGUF modules using helper (handles various import scenarios)
try:
loader_module, ops_module, nodes_module = get_gguf_modules()
gguf_sd_loader = getattr(loader_module, "gguf_sd_loader")
GGMLOps = getattr(ops_module, "GGMLOps")
GGUFModelPatcher = getattr(nodes_module, "GGUFModelPatcher")
except RuntimeError as e:
raise RuntimeError(f"Cannot load GGUF model '{unet_name}'. {str(e)}")
logger.info(f"Loading GGUF diffusion model from: {unet_path}")
try:
# Load GGUF state dict
sd, extra = gguf_sd_loader(unet_path)
# Prepare kwargs for metadata if supported
kwargs = {}
import inspect
valid_params = inspect.signature(
comfy.sd.load_diffusion_model_state_dict
).parameters
if "metadata" in valid_params:
kwargs["metadata"] = extra.get("metadata", {})
# Setup custom operations with GGUF support
ops = GGMLOps()
# Handle weight_dtype for GGUF models
if weight_dtype in ("default", None):
ops.Linear.dequant_dtype = None
elif weight_dtype in ["target"]:
ops.Linear.dequant_dtype = weight_dtype
else:
ops.Linear.dequant_dtype = getattr(torch, weight_dtype, None)
# Load the model
model = comfy.sd.load_diffusion_model_state_dict(
sd, model_options={"custom_operations": ops}, **kwargs
)
if model is None:
raise RuntimeError(
f"Could not detect model type for GGUF diffusion model: {unet_path}"
)
# Wrap with GGUFModelPatcher
model = GGUFModelPatcher.clone(model)
return (model,)
except Exception as e:
logger.error(f"Error loading GGUF diffusion model '{unet_name}': {e}")
raise RuntimeError(
f"Failed to load GGUF diffusion model '{unet_name}': {str(e)}"
)

View File

@@ -309,6 +309,13 @@ class ModelListingHandler:
else: else:
allow_selling_generated_content = None # None means no filter applied allow_selling_generated_content = None # None means no filter applied
# Name pattern filters for LoRA Pool
name_pattern_include = request.query.getall("name_pattern_include", [])
name_pattern_exclude = request.query.getall("name_pattern_exclude", [])
name_pattern_use_regex = (
request.query.get("name_pattern_use_regex", "false").lower() == "true"
)
return { return {
"page": page, "page": page,
"page_size": page_size, "page_size": page_size,
@@ -328,6 +335,9 @@ class ModelListingHandler:
"credit_required": credit_required, "credit_required": credit_required,
"allow_selling_generated_content": allow_selling_generated_content, "allow_selling_generated_content": allow_selling_generated_content,
"model_types": model_types, "model_types": model_types,
"name_pattern_include": name_pattern_include,
"name_pattern_exclude": name_pattern_exclude,
"name_pattern_use_regex": name_pattern_use_regex,
**self._parse_specific_params(request), **self._parse_specific_params(request),
} }

View File

@@ -208,7 +208,11 @@ class BaseModelService(ABC):
reverse = sort_params.order == "desc" reverse = sort_params.order == "desc"
annotated.sort( annotated.sort(
key=lambda x: (x.get("usage_count", 0), x.get("model_name", "").lower()), key=lambda x: (
x.get("usage_count", 0),
x.get("model_name", "").lower(),
x.get("file_path", "").lower()
),
reverse=reverse, reverse=reverse,
) )
return annotated return annotated

View File

@@ -58,6 +58,7 @@ class CacheEntryValidator:
'preview_nsfw_level': (0, False), 'preview_nsfw_level': (0, False),
'notes': ('', False), 'notes': ('', False),
'usage_tips': ('', False), 'usage_tips': ('', False),
'hash_status': ('completed', False),
} }
@classmethod @classmethod
@@ -90,13 +91,31 @@ class CacheEntryValidator:
errors: List[str] = [] errors: List[str] = []
repaired = False repaired = False
# If auto_repair is on, we work on a copy. If not, we still need a safe way to check fields.
working_entry = dict(entry) if auto_repair else entry working_entry = dict(entry) if auto_repair else entry
# Determine effective hash_status for validation logic
hash_status = entry.get('hash_status')
if hash_status is None:
if auto_repair:
working_entry['hash_status'] = 'completed'
repaired = True
hash_status = 'completed'
for field_name, (default_value, is_required) in cls.CORE_FIELDS.items(): for field_name, (default_value, is_required) in cls.CORE_FIELDS.items():
value = working_entry.get(field_name) # Get current value from the original entry to avoid side effects during validation
value = entry.get(field_name)
# Check if field is missing or None # Check if field is missing or None
if value is None: if value is None:
# Special case: sha256 can be None/empty if hash_status is pending
if field_name == 'sha256' and hash_status == 'pending':
if auto_repair:
working_entry[field_name] = ''
repaired = True
continue
if is_required: if is_required:
errors.append(f"Required field '{field_name}' is missing or None") errors.append(f"Required field '{field_name}' is missing or None")
if auto_repair: if auto_repair:
@@ -107,6 +126,10 @@ class CacheEntryValidator:
# Validate field type and value # Validate field type and value
field_error = cls._validate_field(field_name, value, default_value) field_error = cls._validate_field(field_name, value, default_value)
if field_error: if field_error:
# Special case: allow empty string for sha256 if pending
if field_name == 'sha256' and hash_status == 'pending' and value == '':
continue
errors.append(field_error) errors.append(field_error)
if auto_repair: if auto_repair:
working_entry[field_name] = cls._get_default_copy(default_value) working_entry[field_name] = cls._get_default_copy(default_value)
@@ -127,7 +150,7 @@ class CacheEntryValidator:
# Special validation: sha256 must not be empty for required field # Special validation: sha256 must not be empty for required field
# BUT allow empty sha256 when hash_status is pending (lazy hash calculation) # BUT allow empty sha256 when hash_status is pending (lazy hash calculation)
sha256 = working_entry.get('sha256', '') sha256 = working_entry.get('sha256', '')
hash_status = working_entry.get('hash_status', 'completed') # Use the effective hash_status we determined earlier
if not sha256 or (isinstance(sha256, str) and not sha256.strip()): if not sha256 or (isinstance(sha256, str) and not sha256.strip()):
# Allow empty sha256 for lazy hash calculation (checkpoints) # Allow empty sha256 for lazy hash calculation (checkpoints)
if hash_status != 'pending': if hash_status != 'pending':
@@ -144,8 +167,13 @@ class CacheEntryValidator:
if isinstance(sha256, str): if isinstance(sha256, str):
normalized_sha = sha256.lower().strip() normalized_sha = sha256.lower().strip()
if normalized_sha != sha256: if normalized_sha != sha256:
working_entry['sha256'] = normalized_sha if auto_repair:
repaired = True working_entry['sha256'] = normalized_sha
repaired = True
else:
# If not auto-repairing, we don't consider case difference as a "critical error"
# that invalidates the entry, but we also don't mark it repaired.
pass
# Determine if entry is valid # Determine if entry is valid
# Entry is valid if no critical required field errors remain after repair # Entry is valid if no critical required field errors remain after repair

View File

@@ -13,22 +13,35 @@ from .model_hash_index import ModelHashIndex
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CheckpointScanner(ModelScanner): class CheckpointScanner(ModelScanner):
"""Service for scanning and managing checkpoint files""" """Service for scanning and managing checkpoint files"""
def __init__(self): def __init__(self):
# Define supported file extensions # Define supported file extensions
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'} file_extensions = {
".ckpt",
".pt",
".pt2",
".bin",
".pth",
".safetensors",
".pkl",
".sft",
".gguf",
}
super().__init__( super().__init__(
model_type="checkpoint", model_type="checkpoint",
model_class=CheckpointMetadata, model_class=CheckpointMetadata,
file_extensions=file_extensions, file_extensions=file_extensions,
hash_index=ModelHashIndex() hash_index=ModelHashIndex(),
) )
async def _create_default_metadata(self, file_path: str) -> Optional[CheckpointMetadata]: async def _create_default_metadata(
self, file_path: str
) -> Optional[CheckpointMetadata]:
"""Create default metadata for checkpoint without calculating hash (lazy hash). """Create default metadata for checkpoint without calculating hash (lazy hash).
Checkpoints are typically large (10GB+), so we skip hash calculation during initial Checkpoints are typically large (10GB+), so we skip hash calculation during initial
scanning to improve startup performance. Hash will be calculated on-demand when scanning to improve startup performance. Hash will be calculated on-demand when
fetching metadata from Civitai. fetching metadata from Civitai.
@@ -38,13 +51,13 @@ class CheckpointScanner(ModelScanner):
if not os.path.exists(real_path): if not os.path.exists(real_path):
logger.error(f"File not found: {file_path}") logger.error(f"File not found: {file_path}")
return None return None
base_name = os.path.splitext(os.path.basename(file_path))[0] base_name = os.path.splitext(os.path.basename(file_path))[0]
dir_path = os.path.dirname(file_path) dir_path = os.path.dirname(file_path)
# Find preview image # Find preview image
preview_url = find_preview_file(base_name, dir_path) preview_url = find_preview_file(base_name, dir_path)
# Create metadata WITHOUT calculating hash # Create metadata WITHOUT calculating hash
metadata = CheckpointMetadata( metadata = CheckpointMetadata(
file_name=base_name, file_name=base_name,
@@ -59,70 +72,76 @@ class CheckpointScanner(ModelScanner):
modelDescription="", modelDescription="",
sub_type="checkpoint", sub_type="checkpoint",
from_civitai=False, # Mark as local model since no hash yet from_civitai=False, # Mark as local model since no hash yet
hash_status="pending" # Mark hash as pending hash_status="pending", # Mark hash as pending
) )
# Save the created metadata # Save the created metadata
logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}") logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}")
await MetadataManager.save_metadata(file_path, metadata) await MetadataManager.save_metadata(file_path, metadata)
return metadata return metadata
except Exception as e: except Exception as e:
logger.error(f"Error creating default checkpoint metadata for {file_path}: {e}") logger.error(
f"Error creating default checkpoint metadata for {file_path}: {e}"
)
return None return None
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]: async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
"""Calculate hash for a checkpoint on-demand. """Calculate hash for a checkpoint on-demand.
Args: Args:
file_path: Path to the model file file_path: Path to the model file
Returns: Returns:
SHA256 hash string, or None if calculation failed SHA256 hash string, or None if calculation failed
""" """
from ..utils.file_utils import calculate_sha256 from ..utils.file_utils import calculate_sha256
try: try:
real_path = os.path.realpath(file_path) real_path = os.path.realpath(file_path)
if not os.path.exists(real_path): if not os.path.exists(real_path):
logger.error(f"File not found for hash calculation: {file_path}") logger.error(f"File not found for hash calculation: {file_path}")
return None return None
# Load current metadata # Load current metadata
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class) metadata, _ = await MetadataManager.load_metadata(
file_path, self.model_class
)
if metadata is None: if metadata is None:
logger.error(f"No metadata found for {file_path}") logger.error(f"No metadata found for {file_path}")
return None return None
# Check if hash is already calculated # Check if hash is already calculated
if metadata.hash_status == "completed" and metadata.sha256: if metadata.hash_status == "completed" and metadata.sha256:
return metadata.sha256 return metadata.sha256
# Update status to calculating # Update status to calculating
metadata.hash_status = "calculating" metadata.hash_status = "calculating"
await MetadataManager.save_metadata(file_path, metadata) await MetadataManager.save_metadata(file_path, metadata)
# Calculate hash # Calculate hash
logger.info(f"Calculating hash for checkpoint: {file_path}") logger.info(f"Calculating hash for checkpoint: {file_path}")
sha256 = await calculate_sha256(real_path) sha256 = await calculate_sha256(real_path)
# Update metadata with hash # Update metadata with hash
metadata.sha256 = sha256 metadata.sha256 = sha256
metadata.hash_status = "completed" metadata.hash_status = "completed"
await MetadataManager.save_metadata(file_path, metadata) await MetadataManager.save_metadata(file_path, metadata)
# Update hash index # Update hash index
self._hash_index.add_entry(sha256.lower(), file_path) self._hash_index.add_entry(sha256.lower(), file_path)
logger.info(f"Hash calculated for checkpoint: {file_path}") logger.info(f"Hash calculated for checkpoint: {file_path}")
return sha256 return sha256
except Exception as e: except Exception as e:
logger.error(f"Error calculating hash for {file_path}: {e}") logger.error(f"Error calculating hash for {file_path}: {e}")
# Update status to failed # Update status to failed
try: try:
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class) metadata, _ = await MetadataManager.load_metadata(
file_path, self.model_class
)
if metadata: if metadata:
metadata.hash_status = "failed" metadata.hash_status = "failed"
await MetadataManager.save_metadata(file_path, metadata) await MetadataManager.save_metadata(file_path, metadata)
@@ -130,43 +149,46 @@ class CheckpointScanner(ModelScanner):
pass pass
return None return None
async def calculate_all_pending_hashes(self, progress_callback=None) -> Dict[str, int]: async def calculate_all_pending_hashes(
self, progress_callback=None
) -> Dict[str, int]:
"""Calculate hashes for all checkpoints with pending hash status. """Calculate hashes for all checkpoints with pending hash status.
If cache is not initialized, scans filesystem directly for metadata files If cache is not initialized, scans filesystem directly for metadata files
with hash_status != 'completed'. with hash_status != 'completed'.
Args: Args:
progress_callback: Optional callback(progress, total, current_file) progress_callback: Optional callback(progress, total, current_file)
Returns: Returns:
Dict with 'completed', 'failed', 'total' counts Dict with 'completed', 'failed', 'total' counts
""" """
# Try to get from cache first # Try to get from cache first
cache = await self.get_cached_data() cache = await self.get_cached_data()
if cache and cache.raw_data: if cache and cache.raw_data:
# Use cache if available # Use cache if available
pending_models = [ pending_models = [
item for item in cache.raw_data item
if item.get('hash_status') != 'completed' or not item.get('sha256') for item in cache.raw_data
if item.get("hash_status") != "completed" or not item.get("sha256")
] ]
else: else:
# Cache not initialized, scan filesystem directly # Cache not initialized, scan filesystem directly
pending_models = await self._find_pending_models_from_filesystem() pending_models = await self._find_pending_models_from_filesystem()
if not pending_models: if not pending_models:
return {'completed': 0, 'failed': 0, 'total': 0} return {"completed": 0, "failed": 0, "total": 0}
total = len(pending_models) total = len(pending_models)
completed = 0 completed = 0
failed = 0 failed = 0
for i, model_data in enumerate(pending_models): for i, model_data in enumerate(pending_models):
file_path = model_data.get('file_path') file_path = model_data.get("file_path")
if not file_path: if not file_path:
continue continue
try: try:
sha256 = await self.calculate_hash_for_model(file_path) sha256 = await self.calculate_hash_for_model(file_path)
if sha256: if sha256:
@@ -176,77 +198,102 @@ class CheckpointScanner(ModelScanner):
except Exception as e: except Exception as e:
logger.error(f"Error calculating hash for {file_path}: {e}") logger.error(f"Error calculating hash for {file_path}: {e}")
failed += 1 failed += 1
if progress_callback: if progress_callback:
try: try:
await progress_callback(i + 1, total, file_path) await progress_callback(i + 1, total, file_path)
except Exception: except Exception:
pass pass
return { return {"completed": completed, "failed": failed, "total": total}
'completed': completed,
'failed': failed,
'total': total
}
async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]: async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]:
"""Scan filesystem for checkpoint metadata files with pending hash status.""" """Scan filesystem for checkpoint metadata files with pending hash status."""
pending_models = [] pending_models = []
for root_path in self.get_model_roots(): for root_path in self.get_model_roots():
if not os.path.exists(root_path): if not os.path.exists(root_path):
continue continue
for dirpath, _dirnames, filenames in os.walk(root_path): for dirpath, _dirnames, filenames in os.walk(root_path):
for filename in filenames: for filename in filenames:
if not filename.endswith('.metadata.json'): if not filename.endswith(".metadata.json"):
continue continue
metadata_path = os.path.join(dirpath, filename) metadata_path = os.path.join(dirpath, filename)
try: try:
with open(metadata_path, 'r', encoding='utf-8') as f: with open(metadata_path, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
# Check if hash is pending # Check if hash is pending
hash_status = data.get('hash_status', 'completed') hash_status = data.get("hash_status", "completed")
sha256 = data.get('sha256', '') sha256 = data.get("sha256", "")
if hash_status != 'completed' or not sha256: if hash_status != "completed" or not sha256:
# Find corresponding model file # Find corresponding model file
model_name = filename.replace('.metadata.json', '') model_name = filename.replace(".metadata.json", "")
model_path = None model_path = None
# Look for model file with matching name # Look for model file with matching name
for ext in self.file_extensions: for ext in self.file_extensions:
potential_path = os.path.join(dirpath, model_name + ext) potential_path = os.path.join(dirpath, model_name + ext)
if os.path.exists(potential_path): if os.path.exists(potential_path):
model_path = potential_path model_path = potential_path
break break
if model_path: if model_path:
pending_models.append({ pending_models.append(
'file_path': model_path.replace(os.sep, '/'), {
'hash_status': hash_status, "file_path": model_path.replace(os.sep, "/"),
'sha256': sha256, "hash_status": hash_status,
**{k: v for k, v in data.items() if k not in ['file_path', 'hash_status', 'sha256']} "sha256": sha256,
}) **{
k: v
for k, v in data.items()
if k
not in [
"file_path",
"hash_status",
"sha256",
]
},
}
)
except (json.JSONDecodeError, Exception) as e: except (json.JSONDecodeError, Exception) as e:
logger.debug(f"Error reading metadata file {metadata_path}: {e}") logger.debug(
f"Error reading metadata file {metadata_path}: {e}"
)
continue continue
return pending_models return pending_models
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]: def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
"""Resolve the sub-type based on the root path.""" """Resolve the sub-type based on the root path.
Checks both standard ComfyUI paths and LoRA Manager's extra folder paths.
"""
if not root_path: if not root_path:
return None return None
# Check standard ComfyUI checkpoint paths
if config.checkpoints_roots and root_path in config.checkpoints_roots: if config.checkpoints_roots and root_path in config.checkpoints_roots:
return "checkpoint" return "checkpoint"
# Check extra checkpoint paths
if (
config.extra_checkpoints_roots
and root_path in config.extra_checkpoints_roots
):
return "checkpoint"
# Check standard ComfyUI unet paths
if config.unet_roots and root_path in config.unet_roots: if config.unet_roots and root_path in config.unet_roots:
return "diffusion_model" return "diffusion_model"
# Check extra unet paths
if config.extra_unet_roots and root_path in config.extra_unet_roots:
return "diffusion_model"
return None return None
def adjust_metadata(self, metadata, file_path, root_path): def adjust_metadata(self, metadata, file_path, root_path):

View File

@@ -27,7 +27,7 @@ class LoraService(BaseModelService):
# Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default # Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default
# Normalize to lowercase for consistent API responses # Normalize to lowercase for consistent API responses
sub_type = resolve_sub_type(lora_data).lower() sub_type = resolve_sub_type(lora_data).lower()
return { return {
"model_name": lora_data["model_name"], "model_name": lora_data["model_name"],
"file_name": lora_data["file_name"], "file_name": lora_data["file_name"],
@@ -48,7 +48,9 @@ class LoraService(BaseModelService):
"notes": lora_data.get("notes", ""), "notes": lora_data.get("notes", ""),
"favorite": lora_data.get("favorite", False), "favorite": lora_data.get("favorite", False),
"update_available": bool(lora_data.get("update_available", False)), "update_available": bool(lora_data.get("update_available", False)),
"skip_metadata_refresh": bool(lora_data.get("skip_metadata_refresh", False)), "skip_metadata_refresh": bool(
lora_data.get("skip_metadata_refresh", False)
),
"sub_type": sub_type, "sub_type": sub_type,
"civitai": self.filter_civitai_data( "civitai": self.filter_civitai_data(
lora_data.get("civitai", {}), minimal=True lora_data.get("civitai", {}), minimal=True
@@ -62,6 +64,68 @@ class LoraService(BaseModelService):
if first_letter: if first_letter:
data = self._filter_by_first_letter(data, first_letter) data = self._filter_by_first_letter(data, first_letter)
# Handle name pattern filters
name_pattern_include = kwargs.get("name_pattern_include", [])
name_pattern_exclude = kwargs.get("name_pattern_exclude", [])
name_pattern_use_regex = kwargs.get("name_pattern_use_regex", False)
if name_pattern_include or name_pattern_exclude:
import re
def matches_pattern(name, pattern, use_regex):
"""Check if name matches pattern (regex or substring)"""
if not name:
return False
if use_regex:
try:
return bool(re.search(pattern, name, re.IGNORECASE))
except re.error:
# Invalid regex, fall back to substring match
return pattern.lower() in name.lower()
else:
return pattern.lower() in name.lower()
def matches_any_pattern(name, patterns, use_regex):
"""Check if name matches any of the patterns"""
if not patterns:
return True
return any(matches_pattern(name, p, use_regex) for p in patterns)
filtered = []
for lora in data:
model_name = lora.get("model_name", "")
file_name = lora.get("file_name", "")
names_to_check = [n for n in [model_name, file_name] if n]
# Check exclude patterns first
excluded = False
if name_pattern_exclude:
for name in names_to_check:
if matches_any_pattern(
name, name_pattern_exclude, name_pattern_use_regex
):
excluded = True
break
if excluded:
continue
# Check include patterns
if name_pattern_include:
included = False
for name in names_to_check:
if matches_any_pattern(
name, name_pattern_include, name_pattern_use_regex
):
included = True
break
if not included:
continue
filtered.append(lora)
data = filtered
return data return data
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]: def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
@@ -368,9 +432,7 @@ class LoraService(BaseModelService):
rng.uniform(clip_strength_min, clip_strength_max), 2 rng.uniform(clip_strength_min, clip_strength_max), 2
) )
else: else:
clip_str = round( clip_str = round(rng.uniform(clip_strength_min, clip_strength_max), 2)
rng.uniform(clip_strength_min, clip_strength_max), 2
)
result_loras.append( result_loras.append(
{ {
@@ -485,12 +547,69 @@ class LoraService(BaseModelService):
if bool(lora.get("license_flags", 127) & (1 << 1)) if bool(lora.get("license_flags", 127) & (1 << 1))
] ]
# Apply name pattern filters
name_patterns = filter_section.get("namePatterns", {})
include_patterns = name_patterns.get("include", [])
exclude_patterns = name_patterns.get("exclude", [])
use_regex = name_patterns.get("useRegex", False)
if include_patterns or exclude_patterns:
import re
def matches_pattern(name, pattern, use_regex):
"""Check if name matches pattern (regex or substring)"""
if not name:
return False
if use_regex:
try:
return bool(re.search(pattern, name, re.IGNORECASE))
except re.error:
# Invalid regex, fall back to substring match
return pattern.lower() in name.lower()
else:
return pattern.lower() in name.lower()
def matches_any_pattern(name, patterns, use_regex):
"""Check if name matches any of the patterns"""
if not patterns:
return True
return any(matches_pattern(name, p, use_regex) for p in patterns)
filtered = []
for lora in available_loras:
model_name = lora.get("model_name", "")
file_name = lora.get("file_name", "")
names_to_check = [n for n in [model_name, file_name] if n]
# Check exclude patterns first
excluded = False
if exclude_patterns:
for name in names_to_check:
if matches_any_pattern(name, exclude_patterns, use_regex):
excluded = True
break
if excluded:
continue
# Check include patterns
if include_patterns:
included = False
for name in names_to_check:
if matches_any_pattern(name, include_patterns, use_regex):
included = True
break
if not included:
continue
filtered.append(lora)
available_loras = filtered
return available_loras return available_loras
async def get_cycler_list( async def get_cycler_list(
self, self, pool_config: Optional[Dict] = None, sort_by: str = "filename"
pool_config: Optional[Dict] = None,
sort_by: str = "filename"
) -> List[Dict]: ) -> List[Dict]:
""" """
Get filtered and sorted LoRA list for cycling. Get filtered and sorted LoRA list for cycling.
@@ -516,12 +635,18 @@ class LoraService(BaseModelService):
if sort_by == "model_name": if sort_by == "model_name":
available_loras = sorted( available_loras = sorted(
available_loras, available_loras,
key=lambda x: (x.get("model_name") or x.get("file_name", "")).lower() key=lambda x: (
(x.get("model_name") or x.get("file_name", "")).lower(),
x.get("file_path", "").lower(),
),
) )
else: # Default to filename else: # Default to filename
available_loras = sorted( available_loras = sorted(
available_loras, available_loras,
key=lambda x: x.get("file_name", "").lower() key=lambda x: (
x.get("file_name", "").lower(),
x.get("file_path", "").lower(),
),
) )
# Return minimal data needed for cycling # Return minimal data needed for cycling

View File

@@ -221,33 +221,45 @@ class ModelCache:
start_time = time.perf_counter() start_time = time.perf_counter()
reverse = (order == 'desc') reverse = (order == 'desc')
if sort_key == 'name': if sort_key == 'name':
# Natural sort by configured display name, case-insensitive # Natural sort by configured display name, case-insensitive, with file_path as tie-breaker
result = natsorted( result = natsorted(
data, data,
key=lambda x: self._get_display_name(x).lower(), key=lambda x: (
self._get_display_name(x).lower(),
x.get('file_path', '').lower()
),
reverse=reverse reverse=reverse
) )
elif sort_key == 'date': elif sort_key == 'date':
# Sort by modified timestamp (use .get() with default to handle missing fields) # Sort by modified timestamp, fallback to name and path for stability
result = sorted( result = sorted(
data, data,
key=lambda x: x.get('modified', 0.0), key=lambda x: (
x.get('modified', 0.0),
self._get_display_name(x).lower(),
x.get('file_path', '').lower()
),
reverse=reverse reverse=reverse
) )
elif sort_key == 'size': elif sort_key == 'size':
# Sort by file size (use .get() with default to handle missing fields) # Sort by file size, fallback to name and path for stability
result = sorted( result = sorted(
data, data,
key=lambda x: x.get('size', 0), key=lambda x: (
x.get('size', 0),
self._get_display_name(x).lower(),
x.get('file_path', '').lower()
),
reverse=reverse reverse=reverse
) )
elif sort_key == 'usage': elif sort_key == 'usage':
# Sort by usage count, fallback to 0, then name for stability # Sort by usage count, fallback to 0, then name and path for stability
return sorted( return sorted(
data, data,
key=lambda x: ( key=lambda x: (
x.get('usage_count', 0), x.get('usage_count', 0),
self._get_display_name(x).lower() self._get_display_name(x).lower(),
x.get('file_path', '').lower()
), ),
reverse=reverse reverse=reverse
) )

View File

@@ -14,7 +14,6 @@ from ..utils.metadata_manager import MetadataManager
from ..utils.civitai_utils import resolve_license_info from ..utils.civitai_utils import resolve_license_info
from .model_cache import ModelCache from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS
from .model_lifecycle_service import delete_model_artifacts from .model_lifecycle_service import delete_model_artifacts
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
from .websocket_manager import ws_manager from .websocket_manager import ws_manager
@@ -1442,14 +1441,13 @@ class ModelScanner:
file_path = self._hash_index.get_path(sha256.lower()) file_path = self._hash_index.get_path(sha256.lower())
if not file_path: if not file_path:
return None return None
base_name = os.path.splitext(file_path)[0] dir_path = os.path.dirname(file_path)
base_name = os.path.splitext(os.path.basename(file_path))[0]
for ext in PREVIEW_EXTENSIONS: preview_path = find_preview_file(base_name, dir_path)
preview_path = f"{base_name}{ext}" if preview_path:
if os.path.exists(preview_path): return config.get_preview_static_url(preview_path)
return config.get_preview_static_url(preview_path)
return None return None
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:

View File

@@ -56,6 +56,7 @@ class PersistentModelCache:
"exclude", "exclude",
"db_checked", "db_checked",
"last_checked_at", "last_checked_at",
"hash_status",
) )
_MODEL_UPDATE_COLUMNS: Tuple[str, ...] = _MODEL_COLUMNS[2:] _MODEL_UPDATE_COLUMNS: Tuple[str, ...] = _MODEL_COLUMNS[2:]
_instances: Dict[str, "PersistentModelCache"] = {} _instances: Dict[str, "PersistentModelCache"] = {}
@@ -186,6 +187,7 @@ class PersistentModelCache:
"civitai_deleted": bool(row["civitai_deleted"]), "civitai_deleted": bool(row["civitai_deleted"]),
"skip_metadata_refresh": bool(row["skip_metadata_refresh"]), "skip_metadata_refresh": bool(row["skip_metadata_refresh"]),
"license_flags": int(license_value), "license_flags": int(license_value),
"hash_status": row["hash_status"] or "completed",
} }
raw_data.append(item) raw_data.append(item)
@@ -449,6 +451,7 @@ class PersistentModelCache:
exclude INTEGER, exclude INTEGER,
db_checked INTEGER, db_checked INTEGER,
last_checked_at REAL, last_checked_at REAL,
hash_status TEXT,
PRIMARY KEY (model_type, file_path) PRIMARY KEY (model_type, file_path)
); );
@@ -496,6 +499,7 @@ class PersistentModelCache:
"skip_metadata_refresh": "INTEGER DEFAULT 0", "skip_metadata_refresh": "INTEGER DEFAULT 0",
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57). # Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}", "license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
"hash_status": "TEXT DEFAULT 'completed'",
} }
for column, definition in required_columns.items(): for column, definition in required_columns.items():
@@ -570,6 +574,7 @@ class PersistentModelCache:
1 if item.get("exclude") else 0, 1 if item.get("exclude") else 0,
1 if item.get("db_checked") else 0, 1 if item.get("db_checked") else 0,
float(item.get("last_checked_at") or 0.0), float(item.get("last_checked_at") or 0.0),
item.get("hash_status", "completed"),
) )
def _insert_model_sql(self) -> str: def _insert_model_sql(self) -> str:

View File

@@ -135,7 +135,8 @@ class RecipeCache:
"""Sort cached views. Caller must hold ``_lock``.""" """Sort cached views. Caller must hold ``_lock``."""
self.sorted_by_name = natsorted( self.sorted_by_name = natsorted(
self.raw_data, key=lambda x: x.get("title", "").lower() self.raw_data,
key=lambda x: (x.get("title", "").lower(), x.get("file_path", "").lower()),
) )
if not name_only: if not name_only:
self.sorted_by_date = sorted( self.sorted_by_date = sorted(

View File

@@ -1,4 +1,5 @@
"""Services responsible for recipe metadata analysis.""" """Services responsible for recipe metadata analysis."""
from __future__ import annotations from __future__ import annotations
import base64 import base64
@@ -69,7 +70,9 @@ class RecipeAnalysisService:
try: try:
metadata = self._exif_utils.extract_image_metadata(temp_path) metadata = self._exif_utils.extract_image_metadata(temp_path)
if not metadata: if not metadata:
return AnalysisResult({"error": "No metadata found in this image", "loras": []}) return AnalysisResult(
{"error": "No metadata found in this image", "loras": []}
)
return await self._parse_metadata( return await self._parse_metadata(
metadata, metadata,
@@ -105,29 +108,33 @@ class RecipeAnalysisService:
if civitai_match: if civitai_match:
image_info = await civitai_client.get_image_info(civitai_match.group(1)) image_info = await civitai_client.get_image_info(civitai_match.group(1))
if not image_info: if not image_info:
raise RecipeDownloadError("Failed to fetch image information from Civitai") raise RecipeDownloadError(
"Failed to fetch image information from Civitai"
)
image_url = image_info.get("url") image_url = image_info.get("url")
if not image_url: if not image_url:
raise RecipeDownloadError("No image URL found in Civitai response") raise RecipeDownloadError("No image URL found in Civitai response")
is_video = image_info.get("type") == "video" is_video = image_info.get("type") == "video"
# Use optimized preview URLs if possible # Use optimized preview URLs if possible
rewritten_url, _ = rewrite_preview_url(image_url, media_type=image_info.get("type")) rewritten_url, _ = rewrite_preview_url(
image_url, media_type=image_info.get("type")
)
if rewritten_url: if rewritten_url:
image_url = rewritten_url image_url = rewritten_url
if is_video: if is_video:
# Extract extension from URL # Extract extension from URL
url_path = image_url.split('?')[0].split('#')[0] url_path = image_url.split("?")[0].split("#")[0]
extension = os.path.splitext(url_path)[1].lower() or ".mp4" extension = os.path.splitext(url_path)[1].lower() or ".mp4"
else: else:
extension = ".jpg" extension = ".jpg"
temp_path = self._create_temp_path(suffix=extension) temp_path = self._create_temp_path(suffix=extension)
await self._download_image(image_url, temp_path) await self._download_image(image_url, temp_path)
metadata = image_info.get("meta") if "meta" in image_info else None metadata = image_info.get("meta") if "meta" in image_info else None
if ( if (
isinstance(metadata, dict) isinstance(metadata, dict)
@@ -135,15 +142,23 @@ class RecipeAnalysisService:
and isinstance(metadata["meta"], dict) and isinstance(metadata["meta"], dict)
): ):
metadata = metadata["meta"] metadata = metadata["meta"]
# Validate that metadata contains meaningful recipe fields
# If not, treat as None to trigger EXIF extraction from downloaded image
if isinstance(metadata, dict) and not self._has_recipe_fields(metadata):
self._logger.debug(
"Civitai API metadata lacks recipe fields, will extract from EXIF"
)
metadata = None
else: else:
# Basic extension detection for non-Civitai URLs # Basic extension detection for non-Civitai URLs
url_path = url.split('?')[0].split('#')[0] url_path = url.split("?")[0].split("#")[0]
extension = os.path.splitext(url_path)[1].lower() extension = os.path.splitext(url_path)[1].lower()
if extension in [".mp4", ".webm"]: if extension in [".mp4", ".webm"]:
is_video = True is_video = True
else: else:
extension = ".jpg" extension = ".jpg"
temp_path = self._create_temp_path(suffix=extension) temp_path = self._create_temp_path(suffix=extension)
await self._download_image(url, temp_path) await self._download_image(url, temp_path)
@@ -211,7 +226,9 @@ class RecipeAnalysisService:
image_bytes = self._convert_tensor_to_png_bytes(latest_image) image_bytes = self._convert_tensor_to_png_bytes(latest_image)
if image_bytes is None: if image_bytes is None:
raise RecipeValidationError("Cannot handle this data shape from metadata registry") raise RecipeValidationError(
"Cannot handle this data shape from metadata registry"
)
return AnalysisResult( return AnalysisResult(
{ {
@@ -222,6 +239,22 @@ class RecipeAnalysisService:
# Internal helpers ------------------------------------------------- # Internal helpers -------------------------------------------------
def _has_recipe_fields(self, metadata: dict[str, Any]) -> bool:
"""Check if metadata contains meaningful recipe-related fields."""
recipe_fields = {
"prompt",
"negative_prompt",
"resources",
"hashes",
"params",
"generationData",
"Workflow",
"prompt_type",
"positive",
"negative",
}
return any(field in metadata for field in recipe_fields)
async def _parse_metadata( async def _parse_metadata(
self, self,
metadata: dict[str, Any], metadata: dict[str, Any],
@@ -234,7 +267,12 @@ class RecipeAnalysisService:
) -> AnalysisResult: ) -> AnalysisResult:
parser = self._recipe_parser_factory.create_parser(metadata) parser = self._recipe_parser_factory.create_parser(metadata)
if parser is None: if parser is None:
payload = {"error": "No parser found for this image", "loras": []} # Provide more specific error message based on metadata source
if not metadata:
error_msg = "This image does not contain any generation metadata (prompt, models, or parameters)"
else:
error_msg = "No parser found for this image"
payload = {"error": error_msg, "loras": []}
if include_image_base64 and image_path: if include_image_base64 and image_path:
payload["image_base64"] = self._encode_file(image_path) payload["image_base64"] = self._encode_file(image_path)
payload["is_video"] = is_video payload["is_video"] = is_video
@@ -257,7 +295,9 @@ class RecipeAnalysisService:
matching_recipes: list[str] = [] matching_recipes: list[str] = []
if fingerprint: if fingerprint:
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(
fingerprint
)
result["matching_recipes"] = matching_recipes result["matching_recipes"] = matching_recipes
return AnalysisResult(result) return AnalysisResult(result)
@@ -269,7 +309,10 @@ class RecipeAnalysisService:
raise RecipeDownloadError(f"Failed to download image from URL: {result}") raise RecipeDownloadError(f"Failed to download image from URL: {result}")
def _metadata_not_found_response(self, path: str) -> AnalysisResult: def _metadata_not_found_response(self, path: str) -> AnalysisResult:
payload: dict[str, Any] = {"error": "No metadata found in this image", "loras": []} payload: dict[str, Any] = {
"error": "No metadata found in this image",
"loras": [],
}
if os.path.exists(path): if os.path.exists(path):
payload["image_base64"] = self._encode_file(path) payload["image_base64"] = self._encode_file(path)
return AnalysisResult(payload) return AnalysisResult(payload)
@@ -305,7 +348,9 @@ class RecipeAnalysisService:
if hasattr(tensor_image, "shape"): if hasattr(tensor_image, "shape"):
self._logger.debug( self._logger.debug(
"Tensor shape: %s, dtype: %s", tensor_image.shape, getattr(tensor_image, "dtype", None) "Tensor shape: %s, dtype: %s",
tensor_image.shape,
getattr(tensor_image, "dtype", None),
) )
import torch # type: ignore[import-not-found] import torch # type: ignore[import-not-found]

View File

@@ -40,49 +40,39 @@ async def calculate_sha256(file_path: str) -> str:
return sha256_hash.hexdigest() return sha256_hash.hexdigest()
def find_preview_file(base_name: str, dir_path: str) -> str: def find_preview_file(base_name: str, dir_path: str) -> str:
"""Find preview file for given base name in directory""" """Find preview file for given base name in directory.
Performs an exact-case check first (fast path), then falls back to a
case-insensitive scan so that files like ``model.WEBP`` or ``model.Png``
are discovered on case-sensitive filesystems.
"""
temp_extensions = PREVIEW_EXTENSIONS.copy() temp_extensions = PREVIEW_EXTENSIONS.copy()
# Add example extension for compatibility # Add example extension for compatibility
# https://github.com/willmiao/ComfyUI-Lora-Manager/issues/225 # https://github.com/willmiao/ComfyUI-Lora-Manager/issues/225
# The preview image will be optimized to lora-name.webp, so it won't affect other logic # The preview image will be optimized to lora-name.webp, so it won't affect other logic
temp_extensions.append(".example.0.jpeg") temp_extensions.append(".example.0.jpeg")
# Fast path: exact-case match
for ext in temp_extensions: for ext in temp_extensions:
full_pattern = os.path.join(dir_path, f"{base_name}{ext}") full_pattern = os.path.join(dir_path, f"{base_name}{ext}")
if os.path.exists(full_pattern): if os.path.exists(full_pattern):
# Check if this is an image and not already webp
# TODO: disable the optimization for now, maybe add a config option later
# if ext.lower().endswith(('.jpg', '.jpeg', '.png')) and not ext.lower().endswith('.webp'):
# try:
# # Optimize the image to webp format
# webp_path = os.path.join(dir_path, f"{base_name}.webp")
# # Use ExifUtils to optimize the image
# with open(full_pattern, 'rb') as f:
# image_data = f.read()
# optimized_data, _ = ExifUtils.optimize_image(
# image_data=image_data,
# target_width=CARD_PREVIEW_WIDTH,
# format='webp',
# quality=85,
# preserve_metadata=False
# )
# # Save the optimized webp file
# with open(webp_path, 'wb') as f:
# f.write(optimized_data)
# logger.debug(f"Optimized preview image from {full_pattern} to {webp_path}")
# return webp_path.replace(os.sep, "/")
# except Exception as e:
# logger.error(f"Error optimizing preview image {full_pattern}: {e}")
# # Fall back to original file if optimization fails
# return full_pattern.replace(os.sep, "/")
# Return the original path for webp images or non-image files
return full_pattern.replace(os.sep, "/") return full_pattern.replace(os.sep, "/")
# Slow path: case-insensitive match for systems with mixed-case extensions
# (e.g. .WEBP, .Png, .JPG placed manually or by external tools)
try:
dir_entries = os.listdir(dir_path)
except OSError:
return ""
base_lower = base_name.lower()
for ext in temp_extensions:
target = f"{base_lower}{ext}" # ext is already lowercase
for entry in dir_entries:
if entry.lower() == target:
return os.path.join(dir_path, entry).replace(os.sep, "/")
return "" return ""
def get_preview_extension(preview_path: str) -> str: def get_preview_extension(preview_path: str) -> str:

View File

@@ -112,6 +112,115 @@ def get_lora_info_absolute(lora_name):
return asyncio.run(_get_lora_info_absolute_async()) return asyncio.run(_get_lora_info_absolute_async())
def get_checkpoint_info_absolute(checkpoint_name):
"""Get the absolute checkpoint path and metadata from cache
Supports ComfyUI-style model names (e.g., "folder/model_name.ext")
Args:
checkpoint_name: The model name, can be:
- ComfyUI format: "folder/model_name.safetensors"
- Simple name: "model_name"
Returns:
tuple: (absolute_path, metadata) where absolute_path is the full
file system path to the checkpoint file, or original checkpoint_name if not found,
metadata is the full model metadata dict or None
"""
async def _get_checkpoint_info_absolute_async():
from ..services.service_registry import ServiceRegistry
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
# Get model roots for matching
model_roots = scanner.get_model_roots()
# Normalize the checkpoint name
normalized_name = checkpoint_name.replace(os.sep, "/")
for item in cache.raw_data:
file_path = item.get("file_path", "")
if not file_path:
continue
# Format the stored path as ComfyUI-style name
formatted_name = _format_model_name_for_comfyui(file_path, model_roots)
# Match by formatted name (normalize separators for robust comparison)
if formatted_name.replace(os.sep, "/") == normalized_name or formatted_name == checkpoint_name:
return file_path, item
# Also try matching by basename only (for backward compatibility)
file_name = item.get("file_name", "")
if (
file_name == checkpoint_name
or file_name == os.path.splitext(normalized_name)[0]
):
return file_path, item
return checkpoint_name, None
try:
# Check if we're already in an event loop
loop = asyncio.get_running_loop()
# If we're in a running loop, we need to use a different approach
# Create a new thread to run the async code
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
_get_checkpoint_info_absolute_async()
)
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()
except RuntimeError:
# No event loop is running, we can use asyncio.run()
return asyncio.run(_get_checkpoint_info_absolute_async())
def _format_model_name_for_comfyui(file_path: str, model_roots: list) -> str:
"""Format file path to ComfyUI-style model name (relative path with extension)
Example: /path/to/checkpoints/Illustrious/model.safetensors -> Illustrious/model.safetensors
Args:
file_path: Absolute path to the model file
model_roots: List of model root directories
Returns:
ComfyUI-style model name with relative path and extension
"""
# Find the matching root and get relative path
for root in model_roots:
try:
# Normalize paths for comparison
norm_file = os.path.normcase(os.path.abspath(file_path))
norm_root = os.path.normcase(os.path.abspath(root))
# Add trailing separator for prefix check
if not norm_root.endswith(os.sep):
norm_root += os.sep
if norm_file.startswith(norm_root):
# Use os.path.relpath to get relative path with OS-native separator
return os.path.relpath(file_path, root)
except (ValueError, TypeError):
continue
# If no root matches, just return the basename with extension
return os.path.basename(file_path)
def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool: def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool:
""" """
Check if text matches pattern using fuzzy matching. Check if text matches pattern using fuzzy matching.

View File

@@ -1,5 +1,5 @@
[pytest] [pytest]
addopts = -v --import-mode=importlib -m "not performance" addopts = -v --import-mode=importlib -m "not performance" --ignore=__init__.py
testpaths = tests testpaths = tests
python_files = test_*.py python_files = test_*.py
python_classes = Test* python_classes = Test*

View File

@@ -251,7 +251,7 @@ export class BaseModelApiClient {
replaceModelPreview(filePath) { replaceModelPreview(filePath) {
const input = document.createElement('input'); const input = document.createElement('input');
input.type = 'file'; input.type = 'file';
input.accept = 'image/*,video/mp4'; input.accept = 'image/*,image/webp,video/mp4';
input.onchange = async () => { input.onchange = async () => {
if (!input.files || !input.files[0]) return; if (!input.files || !input.files[0]) return;

View File

@@ -104,6 +104,14 @@ export class BatchImportManager {
// Clean up any existing connections // Clean up any existing connections
this.cleanupConnections(); this.cleanupConnections();
// Focus on the URL input field for better UX
setTimeout(() => {
const urlInput = document.getElementById('batchUrlInput');
if (urlInput) {
urlInput.focus();
}
}, 100);
} }
/** /**

View File

@@ -36,8 +36,8 @@ class TestCheckpointPathOverlap:
config._preview_root_paths = set() config._preview_root_paths = set()
config._cached_fingerprint = None config._cached_fingerprint = None
# Call the method under test # Call the method under test - now returns a tuple
result = config._prepare_checkpoint_paths( all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
[str(checkpoints_link)], [str(unet_link)] [str(checkpoints_link)], [str(unet_link)]
) )
@@ -50,21 +50,27 @@ class TestCheckpointPathOverlap:
] ]
assert len(warning_messages) == 1 assert len(warning_messages) == 1
assert "checkpoints" in warning_messages[0].lower() assert "checkpoints" in warning_messages[0].lower()
assert "diffusion_models" in warning_messages[0].lower() or "unet" in warning_messages[0].lower() assert (
"diffusion_models" in warning_messages[0].lower()
or "unet" in warning_messages[0].lower()
)
# Verify warning mentions backward compatibility fallback # Verify warning mentions backward compatibility fallback
assert "falling back" in warning_messages[0].lower() or "backward compatibility" in warning_messages[0].lower() assert (
"falling back" in warning_messages[0].lower()
or "backward compatibility" in warning_messages[0].lower()
)
# Verify only one path is returned (deduplication still works) # Verify only one path is returned (deduplication still works)
assert len(result) == 1 assert len(all_paths) == 1
# Prioritizes checkpoints path for backward compatibility # Prioritizes checkpoints path for backward compatibility
assert _normalize(result[0]) == _normalize(str(checkpoints_link)) assert _normalize(all_paths[0]) == _normalize(str(checkpoints_link))
# Verify checkpoints_roots has the path (prioritized) # Verify checkpoint_roots has the path (prioritized)
assert len(config.checkpoints_roots) == 1 assert len(checkpoint_roots) == 1
assert _normalize(config.checkpoints_roots[0]) == _normalize(str(checkpoints_link)) assert _normalize(checkpoint_roots[0]) == _normalize(str(checkpoints_link))
# Verify unet_roots is empty (overlapping paths removed) # Verify unet_roots is empty (overlapping paths removed)
assert config.unet_roots == [] assert unet_roots == []
def test_non_overlapping_paths_no_warning( def test_non_overlapping_paths_no_warning(
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
@@ -83,7 +89,7 @@ class TestCheckpointPathOverlap:
config._preview_root_paths = set() config._preview_root_paths = set()
config._cached_fingerprint = None config._cached_fingerprint = None
result = config._prepare_checkpoint_paths( all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
[str(checkpoints_dir)], [str(unet_dir)] [str(checkpoints_dir)], [str(unet_dir)]
) )
@@ -97,14 +103,14 @@ class TestCheckpointPathOverlap:
assert len(warning_messages) == 0 assert len(warning_messages) == 0
# Verify both paths are returned # Verify both paths are returned
assert len(result) == 2 assert len(all_paths) == 2
normalized_result = [_normalize(p) for p in result] normalized_result = [_normalize(p) for p in all_paths]
assert _normalize(str(checkpoints_dir)) in normalized_result assert _normalize(str(checkpoints_dir)) in normalized_result
assert _normalize(str(unet_dir)) in normalized_result assert _normalize(str(unet_dir)) in normalized_result
# Verify both roots are properly set # Verify both roots are properly set
assert len(config.checkpoints_roots) == 1 assert len(checkpoint_roots) == 1
assert len(config.unet_roots) == 1 assert len(unet_roots) == 1
def test_partial_overlap_prioritizes_checkpoints( def test_partial_overlap_prioritizes_checkpoints(
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
@@ -129,9 +135,9 @@ class TestCheckpointPathOverlap:
config._cached_fingerprint = None config._cached_fingerprint = None
# One checkpoint path overlaps with one unet path # One checkpoint path overlaps with one unet path
result = config._prepare_checkpoint_paths( all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
[str(shared_link), str(separate_checkpoint)], [str(shared_link), str(separate_checkpoint)],
[str(shared_link), str(separate_unet)] [str(shared_link), str(separate_unet)],
) )
# Verify warning was logged for the overlapping path # Verify warning was logged for the overlapping path
@@ -144,17 +150,20 @@ class TestCheckpointPathOverlap:
assert len(warning_messages) == 1 assert len(warning_messages) == 1
# Verify 3 unique paths (shared counted once as checkpoint, plus separate ones) # Verify 3 unique paths (shared counted once as checkpoint, plus separate ones)
assert len(result) == 3 assert len(all_paths) == 3
# Verify the overlapping path appears in warning message # Verify the overlapping path appears in warning message
assert str(shared_link.name) in warning_messages[0] or str(shared_dir.name) in warning_messages[0] assert (
str(shared_link.name) in warning_messages[0]
or str(shared_dir.name) in warning_messages[0]
)
# Verify checkpoints_roots includes both checkpoint paths (including the shared one) # Verify checkpoint_roots includes both checkpoint paths (including the shared one)
assert len(config.checkpoints_roots) == 2 assert len(checkpoint_roots) == 2
checkpoint_normalized = [_normalize(p) for p in config.checkpoints_roots] checkpoint_normalized = [_normalize(p) for p in checkpoint_roots]
assert _normalize(str(shared_link)) in checkpoint_normalized assert _normalize(str(shared_link)) in checkpoint_normalized
assert _normalize(str(separate_checkpoint)) in checkpoint_normalized assert _normalize(str(separate_checkpoint)) in checkpoint_normalized
# Verify unet_roots only includes the non-overlapping unet path # Verify unet_roots only includes the non-overlapping unet path
assert len(config.unet_roots) == 1 assert len(unet_roots) == 1
assert _normalize(config.unet_roots[0]) == _normalize(str(separate_unet)) assert _normalize(unet_roots[0]) == _normalize(str(separate_unet))

View File

@@ -194,6 +194,7 @@ class TestCacheHealthMonitor:
'preview_nsfw_level': 0, 'preview_nsfw_level': 0,
'notes': '', 'notes': '',
'usage_tips': '', 'usage_tips': '',
'hash_status': 'completed',
} }
incomplete_entry = { incomplete_entry = {
'file_path': '/models/test2.safetensors', 'file_path': '/models/test2.safetensors',

View File

@@ -369,3 +369,289 @@ async def test_pool_filter_combined_all_filters(lora_service):
# - tags: tag1 ✓ # - tags: tag1 ✓
assert len(filtered) == 1 assert len(filtered) == 1
assert filtered[0]["file_name"] == "match_all.safetensors" assert filtered[0]["file_name"] == "match_all.safetensors"
@pytest.mark.asyncio
async def test_pool_filter_name_patterns_include_text(lora_service):
"""Test filtering by name patterns with text matching (useRegex=False)."""
sample_loras = [
{
"file_name": "character_anime_v1.safetensors",
"model_name": "Anime Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
{
"file_name": "character_realistic_v1.safetensors",
"model_name": "Realistic Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
{
"file_name": "style_watercolor_v1.safetensors",
"model_name": "Watercolor Style",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
]
# Test include patterns with text matching
pool_config = {
"baseModels": [],
"tags": {"include": [], "exclude": []},
"folders": {"include": [], "exclude": []},
"license": {"noCreditRequired": False, "allowSelling": False},
"namePatterns": {"include": ["character"], "exclude": [], "useRegex": False},
}
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
assert len(filtered) == 2
file_names = {lora["file_name"] for lora in filtered}
assert file_names == {
"character_anime_v1.safetensors",
"character_realistic_v1.safetensors",
}
@pytest.mark.asyncio
async def test_pool_filter_name_patterns_exclude_text(lora_service):
"""Test excluding by name patterns with text matching (useRegex=False)."""
sample_loras = [
{
"file_name": "character_anime_v1.safetensors",
"model_name": "Anime Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
{
"file_name": "character_realistic_v1.safetensors",
"model_name": "Realistic Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
{
"file_name": "style_watercolor_v1.safetensors",
"model_name": "Watercolor Style",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
]
# Test exclude patterns with text matching
pool_config = {
"baseModels": [],
"tags": {"include": [], "exclude": []},
"folders": {"include": [], "exclude": []},
"license": {"noCreditRequired": False, "allowSelling": False},
"namePatterns": {"include": [], "exclude": ["anime"], "useRegex": False},
}
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
assert len(filtered) == 2
file_names = {lora["file_name"] for lora in filtered}
assert file_names == {
"character_realistic_v1.safetensors",
"style_watercolor_v1.safetensors",
}
@pytest.mark.asyncio
async def test_pool_filter_name_patterns_include_regex(lora_service):
"""Test filtering by name patterns with regex matching (useRegex=True)."""
sample_loras = [
{
"file_name": "character_anime_v1.safetensors",
"model_name": "Anime Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
{
"file_name": "character_realistic_v1.safetensors",
"model_name": "Realistic Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
{
"file_name": "style_watercolor_v1.safetensors",
"model_name": "Watercolor Style",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
]
# Test include patterns with regex matching - match files starting with "character_"
pool_config = {
"baseModels": [],
"tags": {"include": [], "exclude": []},
"folders": {"include": [], "exclude": []},
"license": {"noCreditRequired": False, "allowSelling": False},
"namePatterns": {"include": ["^character_"], "exclude": [], "useRegex": True},
}
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
assert len(filtered) == 2
file_names = {lora["file_name"] for lora in filtered}
assert file_names == {
"character_anime_v1.safetensors",
"character_realistic_v1.safetensors",
}
@pytest.mark.asyncio
async def test_pool_filter_name_patterns_exclude_regex(lora_service):
"""Test excluding by name patterns with regex matching (useRegex=True)."""
sample_loras = [
{
"file_name": "character_anime_v1.safetensors",
"model_name": "Anime Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
{
"file_name": "character_realistic_v1.safetensors",
"model_name": "Realistic Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
{
"file_name": "style_watercolor_v1.safetensors",
"model_name": "Watercolor Style",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
]
# Test exclude patterns with regex matching - exclude files ending with "_v1.safetensors"
pool_config = {
"baseModels": [],
"tags": {"include": [], "exclude": []},
"folders": {"include": [], "exclude": []},
"license": {"noCreditRequired": False, "allowSelling": False},
"namePatterns": {
"include": [],
"exclude": ["_v1\\.safetensors$"],
"useRegex": True,
},
}
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
assert len(filtered) == 0 # All files match the exclude pattern
@pytest.mark.asyncio
async def test_pool_filter_name_patterns_combined(lora_service):
"""Test combining include and exclude name patterns."""
sample_loras = [
{
"file_name": "character_anime_v1.safetensors",
"model_name": "Anime Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
{
"file_name": "character_realistic_v1.safetensors",
"model_name": "Realistic Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
{
"file_name": "style_watercolor_v1.safetensors",
"model_name": "Watercolor Style",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
]
# Test include "character" but exclude "anime"
pool_config = {
"baseModels": [],
"tags": {"include": [], "exclude": []},
"folders": {"include": [], "exclude": []},
"license": {"noCreditRequired": False, "allowSelling": False},
"namePatterns": {
"include": ["character"],
"exclude": ["anime"],
"useRegex": False,
},
}
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
assert len(filtered) == 1
assert filtered[0]["file_name"] == "character_realistic_v1.safetensors"
@pytest.mark.asyncio
async def test_pool_filter_name_patterns_model_name_fallback(lora_service):
"""Test that name pattern filtering falls back to model_name when file_name doesn't match."""
sample_loras = [
{
"file_name": "abc123.safetensors",
"model_name": "Super Anime Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
{
"file_name": "def456.safetensors",
"model_name": "Realistic Portrait",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
]
# Should match model_name even if file_name doesn't contain the pattern
pool_config = {
"baseModels": [],
"tags": {"include": [], "exclude": []},
"folders": {"include": [], "exclude": []},
"license": {"noCreditRequired": False, "allowSelling": False},
"namePatterns": {"include": ["anime"], "exclude": [], "useRegex": False},
}
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
assert len(filtered) == 1
assert filtered[0]["file_name"] == "abc123.safetensors"
@pytest.mark.asyncio
async def test_pool_filter_name_patterns_invalid_regex(lora_service):
"""Test that invalid regex falls back to substring matching."""
sample_loras = [
{
"file_name": "character_anime[test]_v1.safetensors",
"model_name": "Anime Character",
"base_model": "Illustrious",
"folder": "",
"license_flags": build_license_flags(None),
},
]
# Invalid regex pattern (unclosed character class) should fall back to substring matching
# The pattern "anime[" is invalid regex but valid substring - it exists in the filename
pool_config = {
"baseModels": [],
"tags": {"include": [], "exclude": []},
"folders": {"include": [], "exclude": []},
"license": {"noCreditRequired": False, "allowSelling": False},
"namePatterns": {"include": ["anime["], "exclude": [], "useRegex": True},
}
# Should not crash and should match using substring fallback
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
assert len(filtered) == 1 # Substring match works even with invalid regex

View File

@@ -0,0 +1,158 @@
"""Tests for checkpoint and unet loaders with extra folder paths support"""
import pytest
import os
# Get project root directory (ComfyUI-Lora-Manager folder)
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class TestCheckpointLoaderLM:
"""Test CheckpointLoaderLM node"""
def test_class_attributes(self):
"""Test that CheckpointLoaderLM has required class attributes"""
# Import in a way that doesn't require ComfyUI
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "checkpoint_loader.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
# Find CheckpointLoaderLM class
classes = {
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
}
assert "CheckpointLoaderLM" in classes
cls = classes["CheckpointLoaderLM"]
# Check for NAME attribute
name_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
]
assert len(name_attr) > 0, "CheckpointLoaderLM should have NAME attribute"
# Check for CATEGORY attribute
cat_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
]
assert len(cat_attr) > 0, "CheckpointLoaderLM should have CATEGORY attribute"
# Check for INPUT_TYPES method
input_types = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
]
assert len(input_types) > 0, "CheckpointLoaderLM should have INPUT_TYPES method"
# Check for load_checkpoint method
load_method = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "load_checkpoint"
]
assert len(load_method) > 0, (
"CheckpointLoaderLM should have load_checkpoint method"
)
class TestUNETLoaderLM:
"""Test UNETLoaderLM node"""
def test_class_attributes(self):
"""Test that UNETLoaderLM has required class attributes"""
# Import in a way that doesn't require ComfyUI
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "unet_loader.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
# Find UNETLoaderLM class
classes = {
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
}
assert "UNETLoaderLM" in classes
cls = classes["UNETLoaderLM"]
# Check for NAME attribute
name_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
]
assert len(name_attr) > 0, "UNETLoaderLM should have NAME attribute"
# Check for CATEGORY attribute
cat_attr = [
n
for n in cls.body
if isinstance(n, ast.Assign)
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
]
assert len(cat_attr) > 0, "UNETLoaderLM should have CATEGORY attribute"
# Check for INPUT_TYPES method
input_types = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
]
assert len(input_types) > 0, "UNETLoaderLM should have INPUT_TYPES method"
# Check for load_unet method
load_method = [
n
for n in cls.body
if isinstance(n, ast.FunctionDef) and n.name == "load_unet"
]
assert len(load_method) > 0, "UNETLoaderLM should have load_unet method"
class TestUtils:
"""Test utility functions"""
def test_get_checkpoint_info_absolute_exists(self):
"""Test that get_checkpoint_info_absolute function exists in utils"""
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
functions = [
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
]
assert "get_checkpoint_info_absolute" in functions, (
"get_checkpoint_info_absolute should exist"
)
def test_format_model_name_for_comfyui_exists(self):
"""Test that _format_model_name_for_comfyui function exists in utils"""
import ast
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
with open(filepath, "r") as f:
tree = ast.parse(f.read())
functions = [
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
]
assert "_format_model_name_for_comfyui" in functions, (
"_format_model_name_for_comfyui should exist"
)

View File

@@ -2,8 +2,8 @@
<div class="lora-cycler-widget"> <div class="lora-cycler-widget">
<LoraCyclerSettingsView <LoraCyclerSettingsView
:current-index="state.currentIndex.value" :current-index="state.currentIndex.value"
:total-count="state.totalCount.value" :total-count="displayTotalCount"
:current-lora-name="state.currentLoraName.value" :current-lora-name="displayLoraName"
:current-lora-filename="state.currentLoraFilename.value" :current-lora-filename="state.currentLoraFilename.value"
:model-strength="state.modelStrength.value" :model-strength="state.modelStrength.value"
:clip-strength="state.clipStrength.value" :clip-strength="state.clipStrength.value"
@@ -16,11 +16,14 @@
:is-pause-disabled="hasQueuedPrompts" :is-pause-disabled="hasQueuedPrompts"
:is-workflow-executing="state.isWorkflowExecuting.value" :is-workflow-executing="state.isWorkflowExecuting.value"
:executing-repeat-step="state.executingRepeatStep.value" :executing-repeat-step="state.executingRepeatStep.value"
:include-no-lora="state.includeNoLora.value"
:is-no-lora="isNoLora"
@update:current-index="handleIndexUpdate" @update:current-index="handleIndexUpdate"
@update:model-strength="state.modelStrength.value = $event" @update:model-strength="state.modelStrength.value = $event"
@update:clip-strength="state.clipStrength.value = $event" @update:clip-strength="state.clipStrength.value = $event"
@update:use-custom-clip-range="handleUseCustomClipRangeChange" @update:use-custom-clip-range="handleUseCustomClipRangeChange"
@update:repeat-count="handleRepeatCountChange" @update:repeat-count="handleRepeatCountChange"
@update:include-no-lora="handleIncludeNoLoraChange"
@toggle-pause="handleTogglePause" @toggle-pause="handleTogglePause"
@reset-index="handleResetIndex" @reset-index="handleResetIndex"
@open-lora-selector="isModalOpen = true" @open-lora-selector="isModalOpen = true"
@@ -30,6 +33,7 @@
:visible="isModalOpen" :visible="isModalOpen"
:lora-list="cachedLoraList" :lora-list="cachedLoraList"
:current-index="state.currentIndex.value" :current-index="state.currentIndex.value"
:include-no-lora="state.includeNoLora.value"
@close="isModalOpen = false" @close="isModalOpen = false"
@select="handleModalSelect" @select="handleModalSelect"
/> />
@@ -37,7 +41,7 @@
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { onMounted, ref } from 'vue' import { onMounted, ref, computed } from 'vue'
import LoraCyclerSettingsView from './lora-cycler/LoraCyclerSettingsView.vue' import LoraCyclerSettingsView from './lora-cycler/LoraCyclerSettingsView.vue'
import LoraListModal from './lora-cycler/LoraListModal.vue' import LoraListModal from './lora-cycler/LoraListModal.vue'
import { useLoraCyclerState } from '../composables/useLoraCyclerState' import { useLoraCyclerState } from '../composables/useLoraCyclerState'
@@ -102,6 +106,31 @@ const isModalOpen = ref(false)
// Cache for LoRA list (used by modal) // Cache for LoRA list (used by modal)
const cachedLoraList = ref<LoraItem[]>([]) const cachedLoraList = ref<LoraItem[]>([])
// Computed: display total count (includes no lora option if enabled)
const displayTotalCount = computed(() => {
const baseCount = state.totalCount.value
return state.includeNoLora.value ? baseCount + 1 : baseCount
})
// Computed: display LoRA name (shows "No LoRA" if on the last index and includeNoLora is enabled)
const displayLoraName = computed(() => {
const currentIndex = state.currentIndex.value
const totalCount = state.totalCount.value
// If includeNoLora is enabled and we're on the last position (no lora slot)
if (state.includeNoLora.value && currentIndex === totalCount + 1) {
return 'No LoRA'
}
// Otherwise show the normal LoRA name
return state.currentLoraName.value
})
// Computed: check if currently on "No LoRA" option
const isNoLora = computed(() => {
return state.includeNoLora.value && state.currentIndex.value === state.totalCount.value + 1
})
// Get pool config from connected node // Get pool config from connected node
const getPoolConfig = (): LoraPoolConfig | null => { const getPoolConfig = (): LoraPoolConfig | null => {
// Check if getPoolConfig method exists on node (added by main.ts) // Check if getPoolConfig method exists on node (added by main.ts)
@@ -113,7 +142,17 @@ const getPoolConfig = (): LoraPoolConfig | null => {
// Update display from LoRA list and index // Update display from LoRA list and index
const updateDisplayFromLoraList = (loraList: LoraItem[], index: number) => { const updateDisplayFromLoraList = (loraList: LoraItem[], index: number) => {
if (loraList.length > 0 && index > 0 && index <= loraList.length) { const actualLoraCount = loraList.length
// If index is beyond actual LoRA count, it means we're on the "no lora" option
if (state.includeNoLora.value && index === actualLoraCount + 1) {
state.currentLoraName.value = 'No LoRA'
state.currentLoraFilename.value = 'No LoRA'
return
}
// Otherwise, show normal LoRA info
if (actualLoraCount > 0 && index > 0 && index <= actualLoraCount) {
const currentLora = loraList[index - 1] const currentLora = loraList[index - 1]
if (currentLora) { if (currentLora) {
state.currentLoraName.value = currentLora.file_name state.currentLoraName.value = currentLora.file_name
@@ -124,6 +163,14 @@ const updateDisplayFromLoraList = (loraList: LoraItem[], index: number) => {
// Handle index update from user // Handle index update from user
const handleIndexUpdate = async (newIndex: number) => { const handleIndexUpdate = async (newIndex: number) => {
// Calculate max valid index (includes no lora slot if enabled)
const maxIndex = state.includeNoLora.value
? state.totalCount.value + 1
: state.totalCount.value
// Clamp index to valid range
const clampedIndex = Math.max(1, Math.min(newIndex, maxIndex || 1))
// Reset execution state when user manually changes index // Reset execution state when user manually changes index
// This ensures the next execution starts from the user-set index // This ensures the next execution starts from the user-set index
;(props.widget as any)[HAS_EXECUTED] = false ;(props.widget as any)[HAS_EXECUTED] = false
@@ -134,14 +181,14 @@ const handleIndexUpdate = async (newIndex: number) => {
executionQueue.length = 0 executionQueue.length = 0
hasQueuedPrompts.value = false hasQueuedPrompts.value = false
state.setIndex(newIndex) state.setIndex(clampedIndex)
// Refresh list to update current LoRA display // Refresh list to update current LoRA display
try { try {
const poolConfig = getPoolConfig() const poolConfig = getPoolConfig()
const loraList = await state.fetchCyclerList(poolConfig) const loraList = await state.fetchCyclerList(poolConfig)
cachedLoraList.value = loraList cachedLoraList.value = loraList
updateDisplayFromLoraList(loraList, newIndex) updateDisplayFromLoraList(loraList, clampedIndex)
} catch (error) { } catch (error) {
console.error('[LoraCyclerWidget] Error updating index:', error) console.error('[LoraCyclerWidget] Error updating index:', error)
} }
@@ -169,6 +216,17 @@ const handleRepeatCountChange = (newValue: number) => {
state.displayRepeatUsed.value = 0 state.displayRepeatUsed.value = 0
} }
// Handle include no lora toggle
const handleIncludeNoLoraChange = (newValue: boolean) => {
state.includeNoLora.value = newValue
// If turning off and current index is beyond the actual LoRA count,
// clamp it to the last valid LoRA index
if (!newValue && state.currentIndex.value > state.totalCount.value) {
state.currentIndex.value = Math.max(1, state.totalCount.value)
}
}
// Handle pause toggle // Handle pause toggle
const handleTogglePause = () => { const handleTogglePause = () => {
state.togglePause() state.togglePause()

View File

@@ -8,6 +8,9 @@
:exclude-tags="state.excludeTags.value" :exclude-tags="state.excludeTags.value"
:include-folders="state.includeFolders.value" :include-folders="state.includeFolders.value"
:exclude-folders="state.excludeFolders.value" :exclude-folders="state.excludeFolders.value"
:include-patterns="state.includePatterns.value"
:exclude-patterns="state.excludePatterns.value"
:use-regex="state.useRegex.value"
:no-credit-required="state.noCreditRequired.value" :no-credit-required="state.noCreditRequired.value"
:allow-selling="state.allowSelling.value" :allow-selling="state.allowSelling.value"
:preview-items="state.previewItems.value" :preview-items="state.previewItems.value"
@@ -16,6 +19,9 @@
@open-modal="openModal" @open-modal="openModal"
@update:include-folders="state.includeFolders.value = $event" @update:include-folders="state.includeFolders.value = $event"
@update:exclude-folders="state.excludeFolders.value = $event" @update:exclude-folders="state.excludeFolders.value = $event"
@update:include-patterns="state.includePatterns.value = $event"
@update:exclude-patterns="state.excludePatterns.value = $event"
@update:use-regex="state.useRegex.value = $event"
@update:no-credit-required="state.noCreditRequired.value = $event" @update:no-credit-required="state.noCreditRequired.value = $event"
@update:allow-selling="state.allowSelling.value = $event" @update:allow-selling="state.allowSelling.value = $event"
@refresh="state.refreshPreview" @refresh="state.refreshPreview"

View File

@@ -13,7 +13,9 @@
@click="handleOpenSelector" @click="handleOpenSelector"
> >
<span class="progress-label">{{ isWorkflowExecuting ? 'Using LoRA:' : 'Next LoRA:' }}</span> <span class="progress-label">{{ isWorkflowExecuting ? 'Using LoRA:' : 'Next LoRA:' }}</span>
<span class="progress-name clickable" :class="{ disabled: isPauseDisabled }" :title="currentLoraFilename"> <span class="progress-name clickable"
:class="{ disabled: isPauseDisabled, 'no-lora': isNoLora }"
:title="currentLoraFilename">
{{ currentLoraName || 'None' }} {{ currentLoraName || 'None' }}
<svg class="selector-icon" viewBox="0 0 24 24" fill="currentColor"> <svg class="selector-icon" viewBox="0 0 24 24" fill="currentColor">
<path d="M7 10l5 5 5-5z"/> <path d="M7 10l5 5 5-5z"/>
@@ -160,6 +162,27 @@
/> />
</div> </div>
</div> </div>
<!-- Include No LoRA Toggle -->
<div class="setting-section">
<div class="section-header-with-toggle">
<label class="setting-label">
Add "No LoRA" step
</label>
<button
type="button"
class="toggle-switch"
:class="{ 'toggle-switch--active': includeNoLora }"
@click="$emit('update:includeNoLora', !includeNoLora)"
role="switch"
:aria-checked="includeNoLora"
title="Add an iteration without LoRA for comparison"
>
<span class="toggle-switch__track"></span>
<span class="toggle-switch__thumb"></span>
</button>
</div>
</div>
</div> </div>
</template> </template>
@@ -182,6 +205,8 @@ const props = defineProps<{
isPauseDisabled: boolean isPauseDisabled: boolean
isWorkflowExecuting: boolean isWorkflowExecuting: boolean
executingRepeatStep: number executingRepeatStep: number
includeNoLora: boolean
isNoLora?: boolean
}>() }>()
const emit = defineEmits<{ const emit = defineEmits<{
@@ -190,6 +215,7 @@ const emit = defineEmits<{
'update:clipStrength': [value: number] 'update:clipStrength': [value: number]
'update:useCustomClipRange': [value: boolean] 'update:useCustomClipRange': [value: boolean]
'update:repeatCount': [value: number] 'update:repeatCount': [value: number]
'update:includeNoLora': [value: boolean]
'toggle-pause': [] 'toggle-pause': []
'reset-index': [] 'reset-index': []
'open-lora-selector': [] 'open-lora-selector': []
@@ -346,6 +372,16 @@ const onRepeatBlur = (event: Event) => {
color: rgba(191, 219, 254, 1); color: rgba(191, 219, 254, 1);
} }
.progress-name.no-lora {
font-style: italic;
color: rgba(226, 232, 240, 0.6);
}
.progress-name.clickable.no-lora:hover:not(.disabled) {
background: rgba(160, 174, 192, 0.2);
color: rgba(226, 232, 240, 0.8);
}
.progress-name.clickable.disabled { .progress-name.clickable.disabled {
cursor: not-allowed; cursor: not-allowed;
opacity: 0.5; opacity: 0.5;

View File

@@ -35,7 +35,10 @@
v-for="item in filteredList" v-for="item in filteredList"
:key="item.index" :key="item.index"
class="lora-item" class="lora-item"
:class="{ active: currentIndex === item.index }" :class="{
active: currentIndex === item.index,
'no-lora-item': item.lora.file_name === 'No LoRA'
}"
@mouseenter="showPreview(item.lora.file_name, $event)" @mouseenter="showPreview(item.lora.file_name, $event)"
@mouseleave="hidePreview" @mouseleave="hidePreview"
@click="selectLora(item.index)" @click="selectLora(item.index)"
@@ -65,6 +68,7 @@ const props = defineProps<{
visible: boolean visible: boolean
loraList: LoraItem[] loraList: LoraItem[]
currentIndex: number currentIndex: number
includeNoLora?: boolean
}>() }>()
const emit = defineEmits<{ const emit = defineEmits<{
@@ -79,7 +83,8 @@ const searchInputRef = ref<HTMLInputElement | null>(null)
let previewTooltip: any = null let previewTooltip: any = null
const subtitleText = computed(() => { const subtitleText = computed(() => {
const total = props.loraList.length const baseTotal = props.loraList.length
const total = props.includeNoLora ? baseTotal + 1 : baseTotal
const filtered = filteredList.value.length const filtered = filteredList.value.length
if (filtered === total) { if (filtered === total) {
return `Total: ${total} LoRA${total !== 1 ? 's' : ''}` return `Total: ${total} LoRA${total !== 1 ? 's' : ''}`
@@ -88,11 +93,19 @@ const subtitleText = computed(() => {
}) })
const filteredList = computed<LoraListItem[]>(() => { const filteredList = computed<LoraListItem[]>(() => {
const list = props.loraList.map((lora, idx) => ({ const list: LoraListItem[] = props.loraList.map((lora, idx) => ({
index: idx + 1, index: idx + 1,
lora lora
})) }))
// Add "No LoRA" option at the end if includeNoLora is enabled
if (props.includeNoLora) {
list.push({
index: list.length + 1,
lora: { file_name: 'No LoRA' } as LoraItem
})
}
if (!searchQuery.value.trim()) { if (!searchQuery.value.trim()) {
return list return list
} }
@@ -303,6 +316,15 @@ onUnmounted(() => {
font-weight: 500; font-weight: 500;
} }
.lora-item.no-lora-item .lora-name {
font-style: italic;
color: rgba(226, 232, 240, 0.6);
}
.lora-item.no-lora-item:hover .lora-name {
color: rgba(226, 232, 240, 0.8);
}
.no-results { .no-results {
padding: 32px 20px; padding: 32px 20px;
text-align: center; text-align: center;

View File

@@ -24,6 +24,15 @@
@edit-exclude="$emit('open-modal', 'excludeFolders')" @edit-exclude="$emit('open-modal', 'excludeFolders')"
/> />
<NamePatternsSection
:include-patterns="includePatterns"
:exclude-patterns="excludePatterns"
:use-regex="useRegex"
@update:include-patterns="$emit('update:includePatterns', $event)"
@update:exclude-patterns="$emit('update:excludePatterns', $event)"
@update:use-regex="$emit('update:useRegex', $event)"
/>
<LicenseSection <LicenseSection
:no-credit-required="noCreditRequired" :no-credit-required="noCreditRequired"
:allow-selling="allowSelling" :allow-selling="allowSelling"
@@ -46,6 +55,7 @@
import BaseModelSection from './sections/BaseModelSection.vue' import BaseModelSection from './sections/BaseModelSection.vue'
import TagsSection from './sections/TagsSection.vue' import TagsSection from './sections/TagsSection.vue'
import FoldersSection from './sections/FoldersSection.vue' import FoldersSection from './sections/FoldersSection.vue'
import NamePatternsSection from './sections/NamePatternsSection.vue'
import LicenseSection from './sections/LicenseSection.vue' import LicenseSection from './sections/LicenseSection.vue'
import LoraPoolPreview from './LoraPoolPreview.vue' import LoraPoolPreview from './LoraPoolPreview.vue'
import type { BaseModelOption, LoraItem } from '../../composables/types' import type { BaseModelOption, LoraItem } from '../../composables/types'
@@ -61,6 +71,10 @@ defineProps<{
// Folders // Folders
includeFolders: string[] includeFolders: string[]
excludeFolders: string[] excludeFolders: string[]
// Name patterns
includePatterns: string[]
excludePatterns: string[]
useRegex: boolean
// License // License
noCreditRequired: boolean noCreditRequired: boolean
allowSelling: boolean allowSelling: boolean
@@ -74,6 +88,9 @@ defineEmits<{
'open-modal': [modal: ModalType] 'open-modal': [modal: ModalType]
'update:includeFolders': [value: string[]] 'update:includeFolders': [value: string[]]
'update:excludeFolders': [value: string[]] 'update:excludeFolders': [value: string[]]
'update:includePatterns': [value: string[]]
'update:excludePatterns': [value: string[]]
'update:useRegex': [value: boolean]
'update:noCreditRequired': [value: boolean] 'update:noCreditRequired': [value: boolean]
'update:allowSelling': [value: boolean] 'update:allowSelling': [value: boolean]
refresh: [] refresh: []

View File

@@ -0,0 +1,255 @@
<template>
<div class="section">
<div class="section__header">
<span class="section__title">NAME PATTERNS</span>
<label class="section__toggle">
<input
type="checkbox"
:checked="useRegex"
@change="$emit('update:useRegex', ($event.target as HTMLInputElement).checked)"
/>
<span class="section__toggle-label">Use Regex</span>
</label>
</div>
<div class="section__columns">
<!-- Include column -->
<div class="section__column">
<div class="section__column-header">
<span class="section__column-title section__column-title--include">INCLUDE</span>
</div>
<div class="section__input-wrapper">
<input
type="text"
v-model="includeInput"
:placeholder="useRegex ? 'Add regex pattern...' : 'Add text pattern...'"
class="section__input"
@keydown.enter="addInclude"
/>
<button type="button" class="section__add-btn" @click="addInclude">+</button>
</div>
<div class="section__patterns">
<FilterChip
v-for="pattern in includePatterns"
:key="pattern"
:label="pattern"
variant="include"
removable
@remove="removeInclude(pattern)"
/>
<div v-if="includePatterns.length === 0" class="section__empty">
{{ useRegex ? 'No regex patterns' : 'No text patterns' }}
</div>
</div>
</div>
<!-- Exclude column -->
<div class="section__column">
<div class="section__column-header">
<span class="section__column-title section__column-title--exclude">EXCLUDE</span>
</div>
<div class="section__input-wrapper">
<input
type="text"
v-model="excludeInput"
:placeholder="useRegex ? 'Add regex pattern...' : 'Add text pattern...'"
class="section__input"
@keydown.enter="addExclude"
/>
<button type="button" class="section__add-btn" @click="addExclude">+</button>
</div>
<div class="section__patterns">
<FilterChip
v-for="pattern in excludePatterns"
:key="pattern"
:label="pattern"
variant="exclude"
removable
@remove="removeExclude(pattern)"
/>
<div v-if="excludePatterns.length === 0" class="section__empty">
{{ useRegex ? 'No regex patterns' : 'No text patterns' }}
</div>
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref } from 'vue'
import FilterChip from '../shared/FilterChip.vue'
const props = defineProps<{
includePatterns: string[]
excludePatterns: string[]
useRegex: boolean
}>()
const emit = defineEmits<{
'update:includePatterns': [value: string[]]
'update:excludePatterns': [value: string[]]
'update:useRegex': [value: boolean]
}>()
const includeInput = ref('')
const excludeInput = ref('')
const addInclude = () => {
const pattern = includeInput.value.trim()
if (pattern && !props.includePatterns.includes(pattern)) {
emit('update:includePatterns', [...props.includePatterns, pattern])
includeInput.value = ''
}
}
const addExclude = () => {
const pattern = excludeInput.value.trim()
if (pattern && !props.excludePatterns.includes(pattern)) {
emit('update:excludePatterns', [...props.excludePatterns, pattern])
excludeInput.value = ''
}
}
const removeInclude = (pattern: string) => {
emit('update:includePatterns', props.includePatterns.filter(p => p !== pattern))
}
const removeExclude = (pattern: string) => {
emit('update:excludePatterns', props.excludePatterns.filter(p => p !== pattern))
}
</script>
<style scoped>
.section {
margin-bottom: 16px;
}
.section__header {
display: flex;
align-items: center;
justify-content: space-between;
margin-bottom: 8px;
}
.section__title {
font-size: 10px;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.05em;
color: var(--fg-color, #fff);
opacity: 0.6;
}
.section__toggle {
display: flex;
align-items: center;
gap: 6px;
cursor: pointer;
font-size: 11px;
color: var(--fg-color, #fff);
opacity: 0.7;
}
.section__toggle input[type="checkbox"] {
margin: 0;
width: 14px;
height: 14px;
cursor: pointer;
}
.section__toggle-label {
font-weight: 500;
}
.section__columns {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 12px;
}
.section__column {
min-width: 0;
}
.section__column-header {
display: flex;
align-items: center;
justify-content: space-between;
margin-bottom: 6px;
}
.section__column-title {
font-size: 9px;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.03em;
}
.section__column-title--include {
color: #4299e1;
}
.section__column-title--exclude {
color: #ef4444;
}
.section__input-wrapper {
display: flex;
gap: 4px;
margin-bottom: 8px;
}
.section__input {
flex: 1;
min-width: 0;
padding: 6px 8px;
background: var(--comfy-input-bg, #333);
border: 1px solid var(--comfy-input-border, #444);
border-radius: 4px;
color: var(--fg-color, #fff);
font-size: 12px;
outline: none;
}
.section__input:focus {
border-color: #4299e1;
}
.section__add-btn {
width: 28px;
height: 28px;
display: flex;
align-items: center;
justify-content: center;
background: var(--comfy-input-bg, #333);
border: 1px solid var(--comfy-input-border, #444);
border-radius: 4px;
color: var(--fg-color, #fff);
font-size: 16px;
font-weight: 500;
cursor: pointer;
transition: all 0.15s;
}
.section__add-btn:hover {
background: var(--comfy-input-bg-hover, #444);
border-color: #4299e1;
}
.section__patterns {
display: flex;
flex-wrap: wrap;
gap: 4px;
min-height: 22px;
}
.section__empty {
font-size: 10px;
color: var(--fg-color, #fff);
opacity: 0.3;
font-style: italic;
min-height: 22px;
display: flex;
align-items: center;
}
</style>

View File

@@ -10,6 +10,12 @@ export interface LoraPoolConfig {
noCreditRequired: boolean noCreditRequired: boolean
allowSelling: boolean allowSelling: boolean
} }
namePatterns: {
include: string[]
exclude: string[]
useRegex: boolean
}
includeEmptyLora?: boolean // Optional, deprecated (moved to Cycler)
} }
preview: { matchCount: number; lastUpdated: number } preview: { matchCount: number; lastUpdated: number }
} }
@@ -84,6 +90,8 @@ export interface CyclerConfig {
repeat_count: number // How many times each LoRA should repeat (default: 1) repeat_count: number // How many times each LoRA should repeat (default: 1)
repeat_used: number // How many times current index has been used repeat_used: number // How many times current index has been used
is_paused: boolean // Whether iteration is paused is_paused: boolean // Whether iteration is paused
// Include "no LoRA" option in cycle
include_no_lora: boolean // Whether to include empty LoRA option
} }
// Widget config union type // Widget config union type

View File

@@ -4,6 +4,7 @@ import type { ComponentWidget, CyclerConfig, LoraPoolConfig } from './types'
export interface CyclerLoraItem { export interface CyclerLoraItem {
file_name: string file_name: string
model_name: string model_name: string
file_path: string
} }
export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) { export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
@@ -34,6 +35,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
const repeatUsed = ref(0) // How many times current index has been used (internal tracking) const repeatUsed = ref(0) // How many times current index has been used (internal tracking)
const displayRepeatUsed = ref(0) // For UI display, deferred updates like currentIndex const displayRepeatUsed = ref(0) // For UI display, deferred updates like currentIndex
const isPaused = ref(false) // Whether iteration is paused const isPaused = ref(false) // Whether iteration is paused
const includeNoLora = ref(false) // Whether to include empty LoRA option in cycle
// Execution progress tracking (visual feedback) // Execution progress tracking (visual feedback)
const isWorkflowExecuting = ref(false) // Workflow is currently running const isWorkflowExecuting = ref(false) // Workflow is currently running
@@ -58,6 +60,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
repeat_count: repeatCount.value, repeat_count: repeatCount.value,
repeat_used: repeatUsed.value, repeat_used: repeatUsed.value,
is_paused: isPaused.value, is_paused: isPaused.value,
include_no_lora: includeNoLora.value,
} }
} }
return { return {
@@ -75,6 +78,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
repeat_count: repeatCount.value, repeat_count: repeatCount.value,
repeat_used: repeatUsed.value, repeat_used: repeatUsed.value,
is_paused: isPaused.value, is_paused: isPaused.value,
include_no_lora: includeNoLora.value,
} }
} }
@@ -93,12 +97,13 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
sortBy.value = config.sort_by || 'filename' sortBy.value = config.sort_by || 'filename'
currentLoraName.value = config.current_lora_name || '' currentLoraName.value = config.current_lora_name || ''
currentLoraFilename.value = config.current_lora_filename || '' currentLoraFilename.value = config.current_lora_filename || ''
// Advanced index control features // Advanced index control features
repeatCount.value = config.repeat_count ?? 1 repeatCount.value = config.repeat_count ?? 1
repeatUsed.value = config.repeat_used ?? 0 repeatUsed.value = config.repeat_used ?? 0
isPaused.value = config.is_paused ?? false isPaused.value = config.is_paused ?? false
// Note: execution_index and next_index are not restored from config includeNoLora.value = config.include_no_lora ?? false
// as they are transient values used only during batch execution // Note: execution_index and next_index are not restored from config
// as they are transient values used only during batch execution
} finally { } finally {
isRestoring = false isRestoring = false
} }
@@ -111,7 +116,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
// Calculate the next index (wrap to 1 if at end) // Calculate the next index (wrap to 1 if at end)
const current = executionIndex.value ?? currentIndex.value const current = executionIndex.value ?? currentIndex.value
let next = current + 1 let next = current + 1
if (totalCount.value > 0 && next > totalCount.value) { // Total count includes no lora option if enabled
const effectiveTotalCount = includeNoLora.value ? totalCount.value + 1 : totalCount.value
if (effectiveTotalCount > 0 && next > effectiveTotalCount) {
next = 1 next = 1
} }
nextIndex.value = next nextIndex.value = next
@@ -122,7 +129,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
if (nextIndex.value === null) { if (nextIndex.value === null) {
// First execution uses current_index, so next is current + 1 // First execution uses current_index, so next is current + 1
let next = currentIndex.value + 1 let next = currentIndex.value + 1
if (totalCount.value > 0 && next > totalCount.value) { // Total count includes no lora option if enabled
const effectiveTotalCount = includeNoLora.value ? totalCount.value + 1 : totalCount.value
if (effectiveTotalCount > 0 && next > effectiveTotalCount) {
next = 1 next = 1
} }
nextIndex.value = next nextIndex.value = next
@@ -230,7 +239,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
// Set index manually // Set index manually
const setIndex = (index: number) => { const setIndex = (index: number) => {
if (index >= 1 && index <= totalCount.value) { // Total count includes no lora option if enabled
const effectiveTotalCount = includeNoLora.value ? totalCount.value + 1 : totalCount.value
if (index >= 1 && index <= effectiveTotalCount) {
currentIndex.value = index currentIndex.value = index
} }
} }
@@ -272,6 +283,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
repeatCount, repeatCount,
repeatUsed, repeatUsed,
isPaused, isPaused,
includeNoLora,
], () => { ], () => {
widget.value = buildConfig() widget.value = buildConfig()
}, { deep: true }) }, { deep: true })
@@ -294,6 +306,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
repeatUsed, repeatUsed,
displayRepeatUsed, displayRepeatUsed,
isPaused, isPaused,
includeNoLora,
isWorkflowExecuting, isWorkflowExecuting,
executingRepeatStep, executingRepeatStep,

View File

@@ -62,6 +62,9 @@ export function useLoraPoolApi() {
foldersExclude?: string[] foldersExclude?: string[]
noCreditRequired?: boolean noCreditRequired?: boolean
allowSelling?: boolean allowSelling?: boolean
namePatternsInclude?: string[]
namePatternsExclude?: string[]
namePatternsUseRegex?: boolean
page?: number page?: number
pageSize?: number pageSize?: number
} }
@@ -92,6 +95,13 @@ export function useLoraPoolApi() {
urlParams.set('allow_selling_generated_content', String(params.allowSelling)) urlParams.set('allow_selling_generated_content', String(params.allowSelling))
} }
// Name pattern filters
params.namePatternsInclude?.forEach(pattern => urlParams.append('name_pattern_include', pattern))
params.namePatternsExclude?.forEach(pattern => urlParams.append('name_pattern_exclude', pattern))
if (params.namePatternsUseRegex !== undefined) {
urlParams.set('name_pattern_use_regex', String(params.namePatternsUseRegex))
}
const response = await fetch(`/api/lm/loras/list?${urlParams}`) const response = await fetch(`/api/lm/loras/list?${urlParams}`)
const data = await response.json() const data = await response.json()

View File

@@ -24,6 +24,9 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
const excludeFolders = ref<string[]>([]) const excludeFolders = ref<string[]>([])
const noCreditRequired = ref(false) const noCreditRequired = ref(false)
const allowSelling = ref(false) const allowSelling = ref(false)
const includePatterns = ref<string[]>([])
const excludePatterns = ref<string[]>([])
const useRegex = ref(false)
// Available options from API // Available options from API
const availableBaseModels = ref<BaseModelOption[]>([]) const availableBaseModels = ref<BaseModelOption[]>([])
@@ -52,6 +55,11 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
license: { license: {
noCreditRequired: noCreditRequired.value, noCreditRequired: noCreditRequired.value,
allowSelling: allowSelling.value allowSelling: allowSelling.value
},
namePatterns: {
include: includePatterns.value,
exclude: excludePatterns.value,
useRegex: useRegex.value
} }
}, },
preview: { preview: {
@@ -94,6 +102,9 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
updateIfChanged(excludeFolders, filters.folders?.exclude || []) updateIfChanged(excludeFolders, filters.folders?.exclude || [])
updateIfChanged(noCreditRequired, filters.license?.noCreditRequired ?? false) updateIfChanged(noCreditRequired, filters.license?.noCreditRequired ?? false)
updateIfChanged(allowSelling, filters.license?.allowSelling ?? false) updateIfChanged(allowSelling, filters.license?.allowSelling ?? false)
updateIfChanged(includePatterns, filters.namePatterns?.include || [])
updateIfChanged(excludePatterns, filters.namePatterns?.exclude || [])
updateIfChanged(useRegex, filters.namePatterns?.useRegex ?? false)
// matchCount doesn't trigger watchers, so direct assignment is fine // matchCount doesn't trigger watchers, so direct assignment is fine
matchCount.value = preview?.matchCount || 0 matchCount.value = preview?.matchCount || 0
@@ -125,6 +136,9 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
foldersExclude: excludeFolders.value, foldersExclude: excludeFolders.value,
noCreditRequired: noCreditRequired.value || undefined, noCreditRequired: noCreditRequired.value || undefined,
allowSelling: allowSelling.value || undefined, allowSelling: allowSelling.value || undefined,
namePatternsInclude: includePatterns.value,
namePatternsExclude: excludePatterns.value,
namePatternsUseRegex: useRegex.value,
pageSize: 6 pageSize: 6
}) })
@@ -150,7 +164,10 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
includeFolders, includeFolders,
excludeFolders, excludeFolders,
noCreditRequired, noCreditRequired,
allowSelling allowSelling,
includePatterns,
excludePatterns,
useRegex
], onFilterChange, { deep: true }) ], onFilterChange, { deep: true })
return { return {
@@ -162,6 +179,9 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
excludeFolders, excludeFolders,
noCreditRequired, noCreditRequired,
allowSelling, allowSelling,
includePatterns,
excludePatterns,
useRegex,
// Available options // Available options
availableBaseModels, availableBaseModels,

View File

@@ -13,12 +13,12 @@ import {
} from './mode-change-handler' } from './mode-change-handler'
const LORA_POOL_WIDGET_MIN_WIDTH = 500 const LORA_POOL_WIDGET_MIN_WIDTH = 500
const LORA_POOL_WIDGET_MIN_HEIGHT = 400 const LORA_POOL_WIDGET_MIN_HEIGHT = 520
const LORA_RANDOMIZER_WIDGET_MIN_WIDTH = 500 const LORA_RANDOMIZER_WIDGET_MIN_WIDTH = 500
const LORA_RANDOMIZER_WIDGET_MIN_HEIGHT = 448 const LORA_RANDOMIZER_WIDGET_MIN_HEIGHT = 448
const LORA_RANDOMIZER_WIDGET_MAX_HEIGHT = LORA_RANDOMIZER_WIDGET_MIN_HEIGHT const LORA_RANDOMIZER_WIDGET_MAX_HEIGHT = LORA_RANDOMIZER_WIDGET_MIN_HEIGHT
const LORA_CYCLER_WIDGET_MIN_WIDTH = 380 const LORA_CYCLER_WIDGET_MIN_WIDTH = 380
const LORA_CYCLER_WIDGET_MIN_HEIGHT = 314 const LORA_CYCLER_WIDGET_MIN_HEIGHT = 344
const LORA_CYCLER_WIDGET_MAX_HEIGHT = LORA_CYCLER_WIDGET_MIN_HEIGHT const LORA_CYCLER_WIDGET_MAX_HEIGHT = LORA_CYCLER_WIDGET_MIN_HEIGHT
const JSON_DISPLAY_WIDGET_MIN_WIDTH = 300 const JSON_DISPLAY_WIDGET_MIN_WIDTH = 300
const JSON_DISPLAY_WIDGET_MIN_HEIGHT = 200 const JSON_DISPLAY_WIDGET_MIN_HEIGHT = 200

View File

@@ -84,7 +84,8 @@ describe('useLoraCyclerState', () => {
current_lora_filename: '', current_lora_filename: '',
repeat_count: 1, repeat_count: 1,
repeat_used: 0, repeat_used: 0,
is_paused: false is_paused: false,
include_no_lora: false
}) })
expect(state.currentIndex.value).toBe(5) expect(state.currentIndex.value).toBe(5)

View File

@@ -24,6 +24,7 @@ export function createMockCyclerConfig(overrides: Partial<CyclerConfig> = {}): C
repeat_count: 1, repeat_count: 1,
repeat_used: 0, repeat_used: 0,
is_paused: false, is_paused: false,
include_no_lora: false,
...overrides ...overrides
} }
} }
@@ -54,7 +55,8 @@ export function createMockPoolConfig(overrides: Partial<LoraPoolConfig> = {}): L
export function createMockLoraList(count: number = 5): CyclerLoraItem[] { export function createMockLoraList(count: number = 5): CyclerLoraItem[] {
return Array.from({ length: count }, (_, i) => ({ return Array.from({ length: count }, (_, i) => ({
file_name: `lora${i + 1}.safetensors`, file_name: `lora${i + 1}.safetensors`,
model_name: `LoRA Model ${i + 1}` model_name: `LoRA Model ${i + 1}`,
file_path: `/models/loras/lora${i + 1}.safetensors`
})) }))
} }

View File

@@ -14,6 +14,7 @@ import { initDrag, createContextMenu, initHeaderDrag, initReorderDrag, handleKey
import { forwardMiddleMouseToCanvas } from "./utils.js"; import { forwardMiddleMouseToCanvas } from "./utils.js";
import { PreviewTooltip } from "./preview_tooltip.js"; import { PreviewTooltip } from "./preview_tooltip.js";
import { ensureLmStyles } from "./lm_styles_loader.js"; import { ensureLmStyles } from "./lm_styles_loader.js";
import { getStrengthStepPreference } from "./settings.js";
export function addLorasWidget(node, name, opts, callback) { export function addLorasWidget(node, name, opts, callback) {
ensureLmStyles(); ensureLmStyles();
@@ -416,7 +417,7 @@ export function addLorasWidget(node, name, opts, callback) {
const loraIndex = lorasData.findIndex(l => l.name === name); const loraIndex = lorasData.findIndex(l => l.name === name);
if (loraIndex >= 0) { if (loraIndex >= 0) {
lorasData[loraIndex].strength = (parseFloat(lorasData[loraIndex].strength) - 0.05).toFixed(2); lorasData[loraIndex].strength = (parseFloat(lorasData[loraIndex].strength) - getStrengthStepPreference()).toFixed(2);
// Sync clipStrength if collapsed // Sync clipStrength if collapsed
syncClipStrengthIfCollapsed(lorasData[loraIndex]); syncClipStrengthIfCollapsed(lorasData[loraIndex]);
@@ -488,7 +489,7 @@ export function addLorasWidget(node, name, opts, callback) {
const loraIndex = lorasData.findIndex(l => l.name === name); const loraIndex = lorasData.findIndex(l => l.name === name);
if (loraIndex >= 0) { if (loraIndex >= 0) {
lorasData[loraIndex].strength = (parseFloat(lorasData[loraIndex].strength) + 0.05).toFixed(2); lorasData[loraIndex].strength = (parseFloat(lorasData[loraIndex].strength) + getStrengthStepPreference()).toFixed(2);
// Sync clipStrength if collapsed // Sync clipStrength if collapsed
syncClipStrengthIfCollapsed(lorasData[loraIndex]); syncClipStrengthIfCollapsed(lorasData[loraIndex]);
@@ -541,7 +542,7 @@ export function addLorasWidget(node, name, opts, callback) {
const loraIndex = lorasData.findIndex(l => l.name === name); const loraIndex = lorasData.findIndex(l => l.name === name);
if (loraIndex >= 0) { if (loraIndex >= 0) {
lorasData[loraIndex].clipStrength = (parseFloat(lorasData[loraIndex].clipStrength) - 0.05).toFixed(2); lorasData[loraIndex].clipStrength = (parseFloat(lorasData[loraIndex].clipStrength) - getStrengthStepPreference()).toFixed(2);
const newValue = formatLoraValue(lorasData); const newValue = formatLoraValue(lorasData);
updateWidgetValue(newValue); updateWidgetValue(newValue);
@@ -611,7 +612,7 @@ export function addLorasWidget(node, name, opts, callback) {
const loraIndex = lorasData.findIndex(l => l.name === name); const loraIndex = lorasData.findIndex(l => l.name === name);
if (loraIndex >= 0) { if (loraIndex >= 0) {
lorasData[loraIndex].clipStrength = (parseFloat(lorasData[loraIndex].clipStrength) + 0.05).toFixed(2); lorasData[loraIndex].clipStrength = (parseFloat(lorasData[loraIndex].clipStrength) + getStrengthStepPreference()).toFixed(2);
const newValue = formatLoraValue(lorasData); const newValue = formatLoraValue(lorasData);
updateWidgetValue(newValue); updateWidgetValue(newValue);

View File

@@ -24,6 +24,9 @@ const NEW_TAB_TEMPLATE_DEFAULT = "Default";
const NEW_TAB_ZOOM_LEVEL = 0.8; const NEW_TAB_ZOOM_LEVEL = 0.8;
const STRENGTH_STEP_SETTING_ID = "loramanager.strength_step";
const STRENGTH_STEP_DEFAULT = 0.05;
// ============================================================================ // ============================================================================
// Helper Functions // Helper Functions
// ============================================================================ // ============================================================================
@@ -232,6 +235,32 @@ const getNewTabTemplatePreference = (() => {
}; };
})(); })();
const getStrengthStepPreference = (() => {
let settingsUnavailableLogged = false;
return () => {
const settingManager = app?.extensionManager?.setting;
if (!settingManager || typeof settingManager.get !== "function") {
if (!settingsUnavailableLogged) {
console.warn("LoRA Manager: settings API unavailable, using default strength step.");
settingsUnavailableLogged = true;
}
return STRENGTH_STEP_DEFAULT;
}
try {
const value = settingManager.get(STRENGTH_STEP_SETTING_ID);
return value ?? STRENGTH_STEP_DEFAULT;
} catch (error) {
if (!settingsUnavailableLogged) {
console.warn("LoRA Manager: unable to read strength step setting, using default.", error);
settingsUnavailableLogged = true;
}
return STRENGTH_STEP_DEFAULT;
}
};
})();
// ============================================================================ // ============================================================================
// Register Extension with All Settings // Register Extension with All Settings
// ============================================================================ // ============================================================================
@@ -293,6 +322,19 @@ app.registerExtension({
tooltip: "Choose a template workflow to load when creating a new workflow tab. 'Default (Blank)' keeps ComfyUI's original blank workflow behavior.", tooltip: "Choose a template workflow to load when creating a new workflow tab. 'Default (Blank)' keeps ComfyUI's original blank workflow behavior.",
category: ["LoRA Manager", "Workflow", "New Tab Template"], category: ["LoRA Manager", "Workflow", "New Tab Template"],
}, },
{
id: STRENGTH_STEP_SETTING_ID,
name: "Strength Adjustment Step",
type: "slider",
attrs: {
min: 0.01,
max: 0.1,
step: 0.01,
},
defaultValue: STRENGTH_STEP_DEFAULT,
tooltip: "Step size for adjusting LoRA strength via arrow buttons or keyboard (default: 0.05)",
category: ["LoRA Manager", "LoRA Widget", "Strength Step"],
},
], ],
async setup() { async setup() {
await loadWorkflowOptions(); await loadWorkflowOptions();
@@ -375,4 +417,5 @@ export {
getTagSpaceReplacementPreference, getTagSpaceReplacementPreference,
getUsageStatisticsPreference, getUsageStatisticsPreference,
getNewTabTemplatePreference, getNewTabTemplatePreference,
getStrengthStepPreference,
}; };

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long