mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Compare commits
24 Commits
b5a0725d2c
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4000b7f7e7 | ||
|
|
76c15105e6 | ||
|
|
b11c90e19b | ||
|
|
9f5d2d0c18 | ||
|
|
a0dc5229f4 | ||
|
|
61c31ecbd0 | ||
|
|
1ae1b0d607 | ||
|
|
8dd849892d | ||
|
|
03e1fa75c5 | ||
|
|
fefcaa4a45 | ||
|
|
701a6a6c44 | ||
|
|
0ef414d17e | ||
|
|
75dccaef87 | ||
|
|
7e87ec9521 | ||
|
|
46522edb1b | ||
|
|
2dae4c1291 | ||
|
|
a32325402e | ||
|
|
70c150bd80 | ||
|
|
9e81c33f8a | ||
|
|
22c0dbd734 | ||
|
|
d0c58472be | ||
|
|
b3c530bf36 | ||
|
|
05ebd7493d | ||
|
|
90986bd795 |
20
__init__.py
20
__init__.py
@@ -1,6 +1,8 @@
|
||||
try: # pragma: no cover - import fallback for pytest collection
|
||||
from .py.lora_manager import LoraManager
|
||||
from .py.nodes.lora_loader import LoraLoaderLM, LoraTextLoaderLM
|
||||
from .py.nodes.checkpoint_loader import CheckpointLoaderLM
|
||||
from .py.nodes.unet_loader import UNETLoaderLM
|
||||
from .py.nodes.trigger_word_toggle import TriggerWordToggleLM
|
||||
from .py.nodes.prompt import PromptLM
|
||||
from .py.nodes.text import TextLM
|
||||
@@ -27,12 +29,12 @@ except (
|
||||
PromptLM = importlib.import_module("py.nodes.prompt").PromptLM
|
||||
TextLM = importlib.import_module("py.nodes.text").TextLM
|
||||
LoraManager = importlib.import_module("py.lora_manager").LoraManager
|
||||
LoraLoaderLM = importlib.import_module(
|
||||
"py.nodes.lora_loader"
|
||||
).LoraLoaderLM
|
||||
LoraTextLoaderLM = importlib.import_module(
|
||||
"py.nodes.lora_loader"
|
||||
).LoraTextLoaderLM
|
||||
LoraLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraLoaderLM
|
||||
LoraTextLoaderLM = importlib.import_module("py.nodes.lora_loader").LoraTextLoaderLM
|
||||
CheckpointLoaderLM = importlib.import_module(
|
||||
"py.nodes.checkpoint_loader"
|
||||
).CheckpointLoaderLM
|
||||
UNETLoaderLM = importlib.import_module("py.nodes.unet_loader").UNETLoaderLM
|
||||
TriggerWordToggleLM = importlib.import_module(
|
||||
"py.nodes.trigger_word_toggle"
|
||||
).TriggerWordToggleLM
|
||||
@@ -49,9 +51,7 @@ except (
|
||||
LoraRandomizerLM = importlib.import_module(
|
||||
"py.nodes.lora_randomizer"
|
||||
).LoraRandomizerLM
|
||||
LoraCyclerLM = importlib.import_module(
|
||||
"py.nodes.lora_cycler"
|
||||
).LoraCyclerLM
|
||||
LoraCyclerLM = importlib.import_module("py.nodes.lora_cycler").LoraCyclerLM
|
||||
init_metadata_collector = importlib.import_module("py.metadata_collector").init
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
@@ -59,6 +59,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
TextLM.NAME: TextLM,
|
||||
LoraLoaderLM.NAME: LoraLoaderLM,
|
||||
LoraTextLoaderLM.NAME: LoraTextLoaderLM,
|
||||
CheckpointLoaderLM.NAME: CheckpointLoaderLM,
|
||||
UNETLoaderLM.NAME: UNETLoaderLM,
|
||||
TriggerWordToggleLM.NAME: TriggerWordToggleLM,
|
||||
LoraStackerLM.NAME: LoraStackerLM,
|
||||
SaveImageLM.NAME: SaveImageLM,
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
"backToTop": "Nach oben",
|
||||
"settings": "Einstellungen",
|
||||
"help": "Hilfe",
|
||||
"add": "Hinzufügen"
|
||||
"add": "Hinzufügen",
|
||||
"close": "Schließen"
|
||||
},
|
||||
"status": {
|
||||
"loading": "Wird geladen...",
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
"backToTop": "Back to top",
|
||||
"settings": "Settings",
|
||||
"help": "Help",
|
||||
"add": "Add"
|
||||
"add": "Add",
|
||||
"close": "Close"
|
||||
},
|
||||
"status": {
|
||||
"loading": "Loading...",
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
"backToTop": "Volver arriba",
|
||||
"settings": "Configuración",
|
||||
"help": "Ayuda",
|
||||
"add": "Añadir"
|
||||
"add": "Añadir",
|
||||
"close": "Cerrar"
|
||||
},
|
||||
"status": {
|
||||
"loading": "Cargando...",
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
"backToTop": "Retour en haut",
|
||||
"settings": "Paramètres",
|
||||
"help": "Aide",
|
||||
"add": "Ajouter"
|
||||
"add": "Ajouter",
|
||||
"close": "Fermer"
|
||||
},
|
||||
"status": {
|
||||
"loading": "Chargement...",
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
"backToTop": "חזרה למעלה",
|
||||
"settings": "הגדרות",
|
||||
"help": "עזרה",
|
||||
"add": "הוספה"
|
||||
"add": "הוספה",
|
||||
"close": "סגור"
|
||||
},
|
||||
"status": {
|
||||
"loading": "טוען...",
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
"backToTop": "トップへ戻る",
|
||||
"settings": "設定",
|
||||
"help": "ヘルプ",
|
||||
"add": "追加"
|
||||
"add": "追加",
|
||||
"close": "閉じる"
|
||||
},
|
||||
"status": {
|
||||
"loading": "読み込み中...",
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
"backToTop": "맨 위로",
|
||||
"settings": "설정",
|
||||
"help": "도움말",
|
||||
"add": "추가"
|
||||
"add": "추가",
|
||||
"close": "닫기"
|
||||
},
|
||||
"status": {
|
||||
"loading": "로딩 중...",
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
"backToTop": "Наверх",
|
||||
"settings": "Настройки",
|
||||
"help": "Справка",
|
||||
"add": "Добавить"
|
||||
"add": "Добавить",
|
||||
"close": "Закрыть"
|
||||
},
|
||||
"status": {
|
||||
"loading": "Загрузка...",
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
"backToTop": "返回顶部",
|
||||
"settings": "设置",
|
||||
"help": "帮助",
|
||||
"add": "添加"
|
||||
"add": "添加",
|
||||
"close": "关闭"
|
||||
},
|
||||
"status": {
|
||||
"loading": "加载中...",
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
"backToTop": "回到頂部",
|
||||
"settings": "設定",
|
||||
"help": "說明",
|
||||
"add": "新增"
|
||||
"add": "新增",
|
||||
"close": "關閉"
|
||||
},
|
||||
"status": {
|
||||
"loading": "載入中...",
|
||||
|
||||
3
package-lock.json
generated
3
package-lock.json
generated
@@ -114,7 +114,6 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
@@ -138,7 +137,6 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
}
|
||||
@@ -1613,7 +1611,6 @@
|
||||
"integrity": "sha512-MyL55p3Ut3cXbeBEG7Hcv0mVM8pp8PBNWxRqchZnSfAiES1v1mRnMeFfaHWIPULpwsYfvO+ZmMZz5tGCnjzDUQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"cssstyle": "^4.0.1",
|
||||
"data-urls": "^5.0.0",
|
||||
|
||||
47
py/config.py
47
py/config.py
@@ -707,7 +707,13 @@ class Config:
|
||||
|
||||
def _prepare_checkpoint_paths(
|
||||
self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str]
|
||||
) -> List[str]:
|
||||
) -> Tuple[List[str], List[str], List[str]]:
|
||||
"""Prepare checkpoint paths and return (all_roots, checkpoint_roots, unet_roots).
|
||||
|
||||
Returns:
|
||||
Tuple of (all_unique_paths, checkpoint_only_paths, unet_only_paths)
|
||||
This method does NOT modify instance variables - callers must set them.
|
||||
"""
|
||||
checkpoint_map = self._dedupe_existing_paths(checkpoint_paths)
|
||||
unet_map = self._dedupe_existing_paths(unet_paths)
|
||||
|
||||
@@ -737,8 +743,8 @@ class Config:
|
||||
|
||||
checkpoint_values = set(checkpoint_map.values())
|
||||
unet_values = set(unet_map.values())
|
||||
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_values]
|
||||
self.unet_roots = [p for p in unique_paths if p in unet_values]
|
||||
checkpoint_roots = [p for p in unique_paths if p in checkpoint_values]
|
||||
unet_roots = [p for p in unique_paths if p in unet_values]
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(
|
||||
@@ -747,7 +753,7 @@ class Config:
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
return unique_paths, checkpoint_roots, unet_roots
|
||||
|
||||
def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]:
|
||||
path_map = self._dedupe_existing_paths(raw_paths)
|
||||
@@ -776,9 +782,11 @@ class Config:
|
||||
embedding_paths = folder_paths.get("embeddings", []) or []
|
||||
|
||||
self.loras_roots = self._prepare_lora_paths(lora_paths)
|
||||
self.base_models_roots = self._prepare_checkpoint_paths(
|
||||
checkpoint_paths, unet_paths
|
||||
)
|
||||
(
|
||||
self.base_models_roots,
|
||||
self.checkpoints_roots,
|
||||
self.unet_roots,
|
||||
) = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths)
|
||||
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
|
||||
|
||||
# Process extra paths (only for LoRA Manager, not shared with ComfyUI)
|
||||
@@ -789,18 +797,11 @@ class Config:
|
||||
extra_embedding_paths = extra_paths.get("embeddings", []) or []
|
||||
|
||||
self.extra_loras_roots = self._prepare_lora_paths(extra_lora_paths)
|
||||
# Save main paths before processing extra paths ( _prepare_checkpoint_paths overwrites them)
|
||||
saved_checkpoints_roots = self.checkpoints_roots
|
||||
saved_unet_roots = self.unet_roots
|
||||
self.extra_checkpoints_roots = self._prepare_checkpoint_paths(
|
||||
extra_checkpoint_paths, extra_unet_paths
|
||||
)
|
||||
self.extra_unet_roots = (
|
||||
self.unet_roots if self.unet_roots is not None else []
|
||||
) # unet_roots was set by _prepare_checkpoint_paths
|
||||
# Restore main paths
|
||||
self.checkpoints_roots = saved_checkpoints_roots
|
||||
self.unet_roots = saved_unet_roots
|
||||
(
|
||||
_,
|
||||
self.extra_checkpoints_roots,
|
||||
self.extra_unet_roots,
|
||||
) = self._prepare_checkpoint_paths(extra_checkpoint_paths, extra_unet_paths)
|
||||
self.extra_embeddings_roots = self._prepare_embedding_paths(
|
||||
extra_embedding_paths
|
||||
)
|
||||
@@ -857,9 +858,11 @@ class Config:
|
||||
try:
|
||||
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
||||
unique_paths = self._prepare_checkpoint_paths(
|
||||
raw_checkpoint_paths, raw_unet_paths
|
||||
)
|
||||
(
|
||||
unique_paths,
|
||||
self.checkpoints_roots,
|
||||
self.unet_roots,
|
||||
) = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths)
|
||||
|
||||
logger.info(
|
||||
"Found checkpoint roots:"
|
||||
|
||||
118
py/nodes/checkpoint_loader.py
Normal file
118
py/nodes/checkpoint_loader.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
import comfy.sd # type: ignore
|
||||
import folder_paths # type: ignore
|
||||
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CheckpointLoaderLM:
|
||||
"""Checkpoint Loader with support for extra folder paths
|
||||
|
||||
Loads checkpoints from both standard ComfyUI folders and LoRA Manager's
|
||||
extra folder paths, providing a unified interface for checkpoint loading.
|
||||
"""
|
||||
|
||||
NAME = "Checkpoint Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
# Get list of checkpoint names from scanner (includes extra folder paths)
|
||||
checkpoint_names = s._get_checkpoint_names()
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (
|
||||
checkpoint_names,
|
||||
{"tooltip": "The name of the checkpoint (model) to load."},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "VAE")
|
||||
OUTPUT_TOOLTIPS = (
|
||||
"The model used for denoising latents.",
|
||||
"The CLIP model used for encoding text prompts.",
|
||||
"The VAE model used for encoding and decoding images to and from latent space.",
|
||||
)
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoint_names(cls) -> List[str]:
|
||||
"""Get list of checkpoint names from scanner cache in ComfyUI format (relative path with extension)"""
|
||||
try:
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
import asyncio
|
||||
|
||||
async def _get_names():
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
# Get all model roots for calculating relative paths
|
||||
model_roots = scanner.get_model_roots()
|
||||
|
||||
# Filter only checkpoint type (not diffusion_model) and format names
|
||||
names = []
|
||||
for item in cache.raw_data:
|
||||
if item.get("sub_type") == "checkpoint":
|
||||
file_path = item.get("file_path", "")
|
||||
if file_path:
|
||||
# Format using relative path with OS-native separator
|
||||
formatted_name = _format_model_name_for_comfyui(
|
||||
file_path, model_roots
|
||||
)
|
||||
if formatted_name:
|
||||
names.append(formatted_name)
|
||||
|
||||
return sorted(names)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(_get_names())
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(_get_names())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint names: {e}")
|
||||
return []
|
||||
|
||||
def load_checkpoint(self, ckpt_name: str) -> Tuple:
|
||||
"""Load a checkpoint by name, supporting extra folder paths
|
||||
|
||||
Args:
|
||||
ckpt_name: The name of the checkpoint to load (relative path with extension)
|
||||
|
||||
Returns:
|
||||
Tuple of (MODEL, CLIP, VAE)
|
||||
"""
|
||||
# Get absolute path from cache using ComfyUI-style name
|
||||
ckpt_path, metadata = get_checkpoint_info_absolute(ckpt_name)
|
||||
|
||||
if metadata is None:
|
||||
raise FileNotFoundError(
|
||||
f"Checkpoint '{ckpt_name}' not found in LoRA Manager cache. "
|
||||
"Make sure the checkpoint is indexed and try again."
|
||||
)
|
||||
|
||||
# Load regular checkpoint using ComfyUI's API
|
||||
logger.info(f"Loading checkpoint from: {ckpt_path}")
|
||||
out = comfy.sd.load_checkpoint_guess_config(
|
||||
ckpt_path,
|
||||
output_vae=True,
|
||||
output_clip=True,
|
||||
embedding_directory=folder_paths.get_folder_paths("embeddings"),
|
||||
)
|
||||
return out[:3]
|
||||
161
py/nodes/gguf_import_helper.py
Normal file
161
py/nodes/gguf_import_helper.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Helper module to safely import ComfyUI-GGUF modules.
|
||||
|
||||
This module provides a robust way to import ComfyUI-GGUF functionality
|
||||
regardless of how ComfyUI loaded it.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import importlib.util
|
||||
import logging
|
||||
from typing import Optional, Tuple, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_gguf_path() -> str:
|
||||
"""Get the path to ComfyUI-GGUF based on this file's location.
|
||||
|
||||
Since ComfyUI-Lora-Manager and ComfyUI-GGUF are both in custom_nodes/,
|
||||
we can derive the GGUF path from our own location.
|
||||
"""
|
||||
# This file is at: custom_nodes/ComfyUI-Lora-Manager/py/nodes/gguf_import_helper.py
|
||||
# ComfyUI-GGUF is at: custom_nodes/ComfyUI-GGUF
|
||||
current_file = os.path.abspath(__file__)
|
||||
# Go up 4 levels: nodes -> py -> ComfyUI-Lora-Manager -> custom_nodes
|
||||
custom_nodes_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
|
||||
)
|
||||
return os.path.join(custom_nodes_dir, "ComfyUI-GGUF")
|
||||
|
||||
|
||||
def _find_gguf_module() -> Optional[Any]:
|
||||
"""Find ComfyUI-GGUF module in sys.modules.
|
||||
|
||||
ComfyUI registers modules using the full path with dots replaced by _x_.
|
||||
"""
|
||||
gguf_path = _get_gguf_path()
|
||||
sys_module_name = gguf_path.replace(".", "_x_")
|
||||
|
||||
logger.debug(f"[GGUF Import] Looking for module '{sys_module_name}' in sys.modules")
|
||||
if sys_module_name in sys.modules:
|
||||
logger.info(f"[GGUF Import] Found module: '{sys_module_name}'")
|
||||
return sys.modules[sys_module_name]
|
||||
|
||||
logger.debug(f"[GGUF Import] Module not found: '{sys_module_name}'")
|
||||
return None
|
||||
|
||||
|
||||
def _load_gguf_modules_directly() -> Optional[Any]:
|
||||
"""Load ComfyUI-GGUF modules directly from file paths."""
|
||||
gguf_path = _get_gguf_path()
|
||||
|
||||
logger.info(f"[GGUF Import] Direct Load: Attempting to load from '{gguf_path}'")
|
||||
|
||||
if not os.path.exists(gguf_path):
|
||||
logger.warning(f"[GGUF Import] Path does not exist: {gguf_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
namespace = "ComfyUI_GGUF_Dynamic"
|
||||
init_path = os.path.join(gguf_path, "__init__.py")
|
||||
|
||||
if not os.path.exists(init_path):
|
||||
logger.warning(f"[GGUF Import] __init__.py not found at '{init_path}'")
|
||||
return None
|
||||
|
||||
logger.debug(f"[GGUF Import] Loading from '{init_path}'")
|
||||
spec = importlib.util.spec_from_file_location(namespace, init_path)
|
||||
if not spec or not spec.loader:
|
||||
logger.error(f"[GGUF Import] Failed to create spec for '{init_path}'")
|
||||
return None
|
||||
|
||||
package = importlib.util.module_from_spec(spec)
|
||||
package.__path__ = [gguf_path]
|
||||
sys.modules[namespace] = package
|
||||
spec.loader.exec_module(package)
|
||||
logger.debug(f"[GGUF Import] Loaded main package '{namespace}'")
|
||||
|
||||
# Load submodules
|
||||
loaded = []
|
||||
for submod_name in ["loader", "ops", "nodes"]:
|
||||
submod_path = os.path.join(gguf_path, f"{submod_name}.py")
|
||||
if os.path.exists(submod_path):
|
||||
submod_spec = importlib.util.spec_from_file_location(
|
||||
f"{namespace}.{submod_name}", submod_path
|
||||
)
|
||||
if submod_spec and submod_spec.loader:
|
||||
submod = importlib.util.module_from_spec(submod_spec)
|
||||
submod.__package__ = namespace
|
||||
sys.modules[f"{namespace}.{submod_name}"] = submod
|
||||
submod_spec.loader.exec_module(submod)
|
||||
setattr(package, submod_name, submod)
|
||||
loaded.append(submod_name)
|
||||
logger.debug(f"[GGUF Import] Loaded submodule '{submod_name}'")
|
||||
|
||||
logger.info(f"[GGUF Import] Direct Load success: {loaded}")
|
||||
return package
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GGUF Import] Direct Load failed: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def get_gguf_modules() -> Tuple[Any, Any, Any]:
|
||||
"""Get ComfyUI-GGUF modules (loader, ops, nodes).
|
||||
|
||||
Returns:
|
||||
Tuple of (loader_module, ops_module, nodes_module)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If ComfyUI-GGUF cannot be found or loaded.
|
||||
"""
|
||||
logger.debug("[GGUF Import] Starting module search...")
|
||||
|
||||
# Try to find already loaded module first
|
||||
gguf_module = _find_gguf_module()
|
||||
|
||||
if gguf_module is None:
|
||||
logger.info("[GGUF Import] Not found in sys.modules, trying direct load...")
|
||||
gguf_module = _load_gguf_modules_directly()
|
||||
|
||||
if gguf_module is None:
|
||||
raise RuntimeError(
|
||||
"ComfyUI-GGUF is not installed. "
|
||||
"Please install from https://github.com/city96/ComfyUI-GGUF"
|
||||
)
|
||||
|
||||
# Extract submodules
|
||||
loader = getattr(gguf_module, "loader", None)
|
||||
ops = getattr(gguf_module, "ops", None)
|
||||
nodes = getattr(gguf_module, "nodes", None)
|
||||
|
||||
if loader is None or ops is None or nodes is None:
|
||||
missing = [
|
||||
name
|
||||
for name, mod in [("loader", loader), ("ops", ops), ("nodes", nodes)]
|
||||
if mod is None
|
||||
]
|
||||
raise RuntimeError(f"ComfyUI-GGUF missing submodules: {missing}")
|
||||
|
||||
logger.debug("[GGUF Import] All modules loaded successfully")
|
||||
return loader, ops, nodes
|
||||
|
||||
|
||||
def get_gguf_sd_loader():
|
||||
"""Get the gguf_sd_loader function from ComfyUI-GGUF."""
|
||||
loader, _, _ = get_gguf_modules()
|
||||
return getattr(loader, "gguf_sd_loader")
|
||||
|
||||
|
||||
def get_ggml_ops():
|
||||
"""Get the GGMLOps class from ComfyUI-GGUF."""
|
||||
_, ops, _ = get_gguf_modules()
|
||||
return getattr(ops, "GGMLOps")
|
||||
|
||||
|
||||
def get_gguf_model_patcher():
|
||||
"""Get the GGUFModelPatcher class from ComfyUI-GGUF."""
|
||||
_, _, nodes = get_gguf_modules()
|
||||
return getattr(nodes, "GGUFModelPatcher")
|
||||
@@ -56,6 +56,9 @@ class LoraCyclerLM:
|
||||
clip_strength = float(cycler_config.get("clip_strength", 1.0))
|
||||
sort_by = "filename"
|
||||
|
||||
# Include "no lora" option
|
||||
include_no_lora = cycler_config.get("include_no_lora", False)
|
||||
|
||||
# Dual-index mechanism for batch queue synchronization
|
||||
execution_index = cycler_config.get("execution_index") # Can be None
|
||||
# next_index_from_config = cycler_config.get("next_index") # Not used on backend
|
||||
@@ -71,7 +74,10 @@ class LoraCyclerLM:
|
||||
|
||||
total_count = len(lora_list)
|
||||
|
||||
if total_count == 0:
|
||||
# Calculate effective total count (includes no lora option if enabled)
|
||||
effective_total_count = total_count + 1 if include_no_lora else total_count
|
||||
|
||||
if total_count == 0 and not include_no_lora:
|
||||
logger.warning("[LoraCyclerLM] No LoRAs available in pool")
|
||||
return {
|
||||
"result": ([],),
|
||||
@@ -93,42 +99,66 @@ class LoraCyclerLM:
|
||||
else:
|
||||
actual_index = current_index
|
||||
|
||||
# Clamp index to valid range (1-based)
|
||||
clamped_index = max(1, min(actual_index, total_count))
|
||||
# Clamp index to valid range (1-based, includes no lora if enabled)
|
||||
clamped_index = max(1, min(actual_index, effective_total_count))
|
||||
|
||||
# Get LoRA at current index (convert to 0-based for list access)
|
||||
current_lora = lora_list[clamped_index - 1]
|
||||
# Check if current index is the "no lora" option (last position when include_no_lora is True)
|
||||
is_no_lora = include_no_lora and clamped_index == effective_total_count
|
||||
|
||||
# Build LORA_STACK with single LoRA
|
||||
lora_path, _ = get_lora_info(current_lora["file_name"])
|
||||
if not lora_path:
|
||||
logger.warning(
|
||||
f"[LoraCyclerLM] Could not find path for LoRA: {current_lora['file_name']}"
|
||||
)
|
||||
if is_no_lora:
|
||||
# "No LoRA" option - return empty stack
|
||||
lora_stack = []
|
||||
current_lora_name = "No LoRA"
|
||||
current_lora_filename = "No LoRA"
|
||||
else:
|
||||
# Normalize path separators
|
||||
lora_path = lora_path.replace("/", os.sep)
|
||||
lora_stack = [(lora_path, model_strength, clip_strength)]
|
||||
# Get LoRA at current index (convert to 0-based for list access)
|
||||
current_lora = lora_list[clamped_index - 1]
|
||||
current_lora_name = current_lora["file_name"]
|
||||
current_lora_filename = current_lora["file_name"]
|
||||
|
||||
# Build LORA_STACK with single LoRA
|
||||
if current_lora["file_name"] == "None":
|
||||
lora_path = None
|
||||
else:
|
||||
lora_path, _ = get_lora_info(current_lora["file_name"])
|
||||
|
||||
if not lora_path:
|
||||
if current_lora["file_name"] != "None":
|
||||
logger.warning(
|
||||
f"[LoraCyclerLM] Could not find path for LoRA: {current_lora['file_name']}"
|
||||
)
|
||||
lora_stack = []
|
||||
else:
|
||||
# Normalize path separators
|
||||
lora_path = lora_path.replace("/", os.sep)
|
||||
lora_stack = [(lora_path, model_strength, clip_strength)]
|
||||
|
||||
# Calculate next index (wrap to 1 if at end)
|
||||
next_index = clamped_index + 1
|
||||
if next_index > total_count:
|
||||
if next_index > effective_total_count:
|
||||
next_index = 1
|
||||
|
||||
# Get next LoRA for UI display (what will be used next generation)
|
||||
next_lora = lora_list[next_index - 1]
|
||||
next_display_name = next_lora["file_name"]
|
||||
is_next_no_lora = include_no_lora and next_index == effective_total_count
|
||||
if is_next_no_lora:
|
||||
next_display_name = "No LoRA"
|
||||
next_lora_filename = "No LoRA"
|
||||
else:
|
||||
next_lora = lora_list[next_index - 1]
|
||||
next_display_name = next_lora["file_name"]
|
||||
next_lora_filename = next_lora["file_name"]
|
||||
|
||||
return {
|
||||
"result": (lora_stack,),
|
||||
"ui": {
|
||||
"current_index": [clamped_index],
|
||||
"next_index": [next_index],
|
||||
"total_count": [total_count],
|
||||
"current_lora_name": [current_lora["file_name"]],
|
||||
"current_lora_filename": [current_lora["file_name"]],
|
||||
"total_count": [
|
||||
total_count
|
||||
], # Return actual LoRA count, not effective_total_count
|
||||
"current_lora_name": [current_lora_name],
|
||||
"current_lora_filename": [current_lora_filename],
|
||||
"next_lora_name": [next_display_name],
|
||||
"next_lora_filename": [next_lora["file_name"]],
|
||||
"next_lora_filename": [next_lora_filename],
|
||||
},
|
||||
}
|
||||
|
||||
@@ -82,6 +82,7 @@ class LoraPoolLM:
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"favoritesOnly": False,
|
||||
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||
"namePatterns": {"include": [], "exclude": [], "useRegex": False},
|
||||
},
|
||||
"preview": {"matchCount": 0, "lastUpdated": 0},
|
||||
}
|
||||
|
||||
@@ -7,10 +7,8 @@ and tracks the last used combination for reuse.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import os
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import extract_lora_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
205
py/nodes/unet_loader.py
Normal file
205
py/nodes/unet_loader.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
import comfy.sd # type: ignore
|
||||
from ..utils.utils import get_checkpoint_info_absolute, _format_model_name_for_comfyui
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UNETLoaderLM:
|
||||
"""UNET Loader with support for extra folder paths
|
||||
|
||||
Loads diffusion models/UNets from both standard ComfyUI folders and LoRA Manager's
|
||||
extra folder paths, providing a unified interface for UNET loading.
|
||||
Supports both regular diffusion models and GGUF format models.
|
||||
"""
|
||||
|
||||
NAME = "Unet Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
# Get list of unet names from scanner (includes extra folder paths)
|
||||
unet_names = s._get_unet_names()
|
||||
return {
|
||||
"required": {
|
||||
"unet_name": (
|
||||
unet_names,
|
||||
{"tooltip": "The name of the diffusion model to load."},
|
||||
),
|
||||
"weight_dtype": (
|
||||
["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],
|
||||
{"tooltip": "The dtype to use for the model weights."},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("MODEL",)
|
||||
OUTPUT_TOOLTIPS = ("The model used for denoising latents.",)
|
||||
FUNCTION = "load_unet"
|
||||
|
||||
@classmethod
|
||||
def _get_unet_names(cls) -> List[str]:
|
||||
"""Get list of diffusion model names from scanner cache in ComfyUI format (relative path with extension)"""
|
||||
try:
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
import asyncio
|
||||
|
||||
async def _get_names():
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
# Get all model roots for calculating relative paths
|
||||
model_roots = scanner.get_model_roots()
|
||||
|
||||
# Filter only diffusion_model type and format names
|
||||
names = []
|
||||
for item in cache.raw_data:
|
||||
if item.get("sub_type") == "diffusion_model":
|
||||
file_path = item.get("file_path", "")
|
||||
if file_path:
|
||||
# Format using relative path with OS-native separator
|
||||
formatted_name = _format_model_name_for_comfyui(
|
||||
file_path, model_roots
|
||||
)
|
||||
if formatted_name:
|
||||
names.append(formatted_name)
|
||||
|
||||
return sorted(names)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(_get_names())
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
return asyncio.run(_get_names())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unet names: {e}")
|
||||
return []
|
||||
|
||||
def load_unet(self, unet_name: str, weight_dtype: str) -> Tuple:
|
||||
"""Load a diffusion model by name, supporting extra folder paths
|
||||
|
||||
Args:
|
||||
unet_name: The name of the diffusion model to load (relative path with extension)
|
||||
weight_dtype: The dtype to use for model weights
|
||||
|
||||
Returns:
|
||||
Tuple of (MODEL,)
|
||||
"""
|
||||
import torch
|
||||
|
||||
# Get absolute path from cache using ComfyUI-style name
|
||||
unet_path, metadata = get_checkpoint_info_absolute(unet_name)
|
||||
|
||||
if metadata is None:
|
||||
raise FileNotFoundError(
|
||||
f"Diffusion model '{unet_name}' not found in LoRA Manager cache. "
|
||||
"Make sure the model is indexed and try again."
|
||||
)
|
||||
|
||||
# Check if it's a GGUF model
|
||||
if unet_path.endswith(".gguf"):
|
||||
return self._load_gguf_unet(unet_path, unet_name, weight_dtype)
|
||||
|
||||
# Load regular diffusion model using ComfyUI's API
|
||||
logger.info(f"Loading diffusion model from: {unet_path}")
|
||||
|
||||
# Build model options based on weight_dtype
|
||||
model_options = {}
|
||||
if weight_dtype == "fp8_e4m3fn":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
elif weight_dtype == "fp8_e4m3fn_fast":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
model_options["fp8_optimizations"] = True
|
||||
elif weight_dtype == "fp8_e5m2":
|
||||
model_options["dtype"] = torch.float8_e5m2
|
||||
|
||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||
return (model,)
|
||||
|
||||
def _load_gguf_unet(
|
||||
self, unet_path: str, unet_name: str, weight_dtype: str
|
||||
) -> Tuple:
|
||||
"""Load a GGUF format diffusion model
|
||||
|
||||
Args:
|
||||
unet_path: Absolute path to the GGUF file
|
||||
unet_name: Name of the model for error messages
|
||||
weight_dtype: The dtype to use for model weights
|
||||
|
||||
Returns:
|
||||
Tuple of (MODEL,)
|
||||
"""
|
||||
import torch
|
||||
from .gguf_import_helper import get_gguf_modules
|
||||
|
||||
# Get ComfyUI-GGUF modules using helper (handles various import scenarios)
|
||||
try:
|
||||
loader_module, ops_module, nodes_module = get_gguf_modules()
|
||||
gguf_sd_loader = getattr(loader_module, "gguf_sd_loader")
|
||||
GGMLOps = getattr(ops_module, "GGMLOps")
|
||||
GGUFModelPatcher = getattr(nodes_module, "GGUFModelPatcher")
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(f"Cannot load GGUF model '{unet_name}'. {str(e)}")
|
||||
|
||||
logger.info(f"Loading GGUF diffusion model from: {unet_path}")
|
||||
|
||||
try:
|
||||
# Load GGUF state dict
|
||||
sd, extra = gguf_sd_loader(unet_path)
|
||||
|
||||
# Prepare kwargs for metadata if supported
|
||||
kwargs = {}
|
||||
import inspect
|
||||
|
||||
valid_params = inspect.signature(
|
||||
comfy.sd.load_diffusion_model_state_dict
|
||||
).parameters
|
||||
if "metadata" in valid_params:
|
||||
kwargs["metadata"] = extra.get("metadata", {})
|
||||
|
||||
# Setup custom operations with GGUF support
|
||||
ops = GGMLOps()
|
||||
|
||||
# Handle weight_dtype for GGUF models
|
||||
if weight_dtype in ("default", None):
|
||||
ops.Linear.dequant_dtype = None
|
||||
elif weight_dtype in ["target"]:
|
||||
ops.Linear.dequant_dtype = weight_dtype
|
||||
else:
|
||||
ops.Linear.dequant_dtype = getattr(torch, weight_dtype, None)
|
||||
|
||||
# Load the model
|
||||
model = comfy.sd.load_diffusion_model_state_dict(
|
||||
sd, model_options={"custom_operations": ops}, **kwargs
|
||||
)
|
||||
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
f"Could not detect model type for GGUF diffusion model: {unet_path}"
|
||||
)
|
||||
|
||||
# Wrap with GGUFModelPatcher
|
||||
model = GGUFModelPatcher.clone(model)
|
||||
|
||||
return (model,)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading GGUF diffusion model '{unet_name}': {e}")
|
||||
raise RuntimeError(
|
||||
f"Failed to load GGUF diffusion model '{unet_name}': {str(e)}"
|
||||
)
|
||||
@@ -309,6 +309,13 @@ class ModelListingHandler:
|
||||
else:
|
||||
allow_selling_generated_content = None # None means no filter applied
|
||||
|
||||
# Name pattern filters for LoRA Pool
|
||||
name_pattern_include = request.query.getall("name_pattern_include", [])
|
||||
name_pattern_exclude = request.query.getall("name_pattern_exclude", [])
|
||||
name_pattern_use_regex = (
|
||||
request.query.get("name_pattern_use_regex", "false").lower() == "true"
|
||||
)
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
@@ -328,6 +335,9 @@ class ModelListingHandler:
|
||||
"credit_required": credit_required,
|
||||
"allow_selling_generated_content": allow_selling_generated_content,
|
||||
"model_types": model_types,
|
||||
"name_pattern_include": name_pattern_include,
|
||||
"name_pattern_exclude": name_pattern_exclude,
|
||||
"name_pattern_use_regex": name_pattern_use_regex,
|
||||
**self._parse_specific_params(request),
|
||||
}
|
||||
|
||||
|
||||
@@ -208,7 +208,11 @@ class BaseModelService(ABC):
|
||||
|
||||
reverse = sort_params.order == "desc"
|
||||
annotated.sort(
|
||||
key=lambda x: (x.get("usage_count", 0), x.get("model_name", "").lower()),
|
||||
key=lambda x: (
|
||||
x.get("usage_count", 0),
|
||||
x.get("model_name", "").lower(),
|
||||
x.get("file_path", "").lower()
|
||||
),
|
||||
reverse=reverse,
|
||||
)
|
||||
return annotated
|
||||
|
||||
@@ -58,6 +58,7 @@ class CacheEntryValidator:
|
||||
'preview_nsfw_level': (0, False),
|
||||
'notes': ('', False),
|
||||
'usage_tips': ('', False),
|
||||
'hash_status': ('completed', False),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -90,13 +91,31 @@ class CacheEntryValidator:
|
||||
|
||||
errors: List[str] = []
|
||||
repaired = False
|
||||
|
||||
# If auto_repair is on, we work on a copy. If not, we still need a safe way to check fields.
|
||||
working_entry = dict(entry) if auto_repair else entry
|
||||
|
||||
# Determine effective hash_status for validation logic
|
||||
hash_status = entry.get('hash_status')
|
||||
if hash_status is None:
|
||||
if auto_repair:
|
||||
working_entry['hash_status'] = 'completed'
|
||||
repaired = True
|
||||
hash_status = 'completed'
|
||||
|
||||
for field_name, (default_value, is_required) in cls.CORE_FIELDS.items():
|
||||
value = working_entry.get(field_name)
|
||||
# Get current value from the original entry to avoid side effects during validation
|
||||
value = entry.get(field_name)
|
||||
|
||||
# Check if field is missing or None
|
||||
if value is None:
|
||||
# Special case: sha256 can be None/empty if hash_status is pending
|
||||
if field_name == 'sha256' and hash_status == 'pending':
|
||||
if auto_repair:
|
||||
working_entry[field_name] = ''
|
||||
repaired = True
|
||||
continue
|
||||
|
||||
if is_required:
|
||||
errors.append(f"Required field '{field_name}' is missing or None")
|
||||
if auto_repair:
|
||||
@@ -107,6 +126,10 @@ class CacheEntryValidator:
|
||||
# Validate field type and value
|
||||
field_error = cls._validate_field(field_name, value, default_value)
|
||||
if field_error:
|
||||
# Special case: allow empty string for sha256 if pending
|
||||
if field_name == 'sha256' and hash_status == 'pending' and value == '':
|
||||
continue
|
||||
|
||||
errors.append(field_error)
|
||||
if auto_repair:
|
||||
working_entry[field_name] = cls._get_default_copy(default_value)
|
||||
@@ -127,7 +150,7 @@ class CacheEntryValidator:
|
||||
# Special validation: sha256 must not be empty for required field
|
||||
# BUT allow empty sha256 when hash_status is pending (lazy hash calculation)
|
||||
sha256 = working_entry.get('sha256', '')
|
||||
hash_status = working_entry.get('hash_status', 'completed')
|
||||
# Use the effective hash_status we determined earlier
|
||||
if not sha256 or (isinstance(sha256, str) and not sha256.strip()):
|
||||
# Allow empty sha256 for lazy hash calculation (checkpoints)
|
||||
if hash_status != 'pending':
|
||||
@@ -144,8 +167,13 @@ class CacheEntryValidator:
|
||||
if isinstance(sha256, str):
|
||||
normalized_sha = sha256.lower().strip()
|
||||
if normalized_sha != sha256:
|
||||
working_entry['sha256'] = normalized_sha
|
||||
repaired = True
|
||||
if auto_repair:
|
||||
working_entry['sha256'] = normalized_sha
|
||||
repaired = True
|
||||
else:
|
||||
# If not auto-repairing, we don't consider case difference as a "critical error"
|
||||
# that invalidates the entry, but we also don't mark it repaired.
|
||||
pass
|
||||
|
||||
# Determine if entry is valid
|
||||
# Entry is valid if no critical required field errors remain after repair
|
||||
|
||||
@@ -13,22 +13,35 @@ from .model_hash_index import ModelHashIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CheckpointScanner(ModelScanner):
|
||||
"""Service for scanning and managing checkpoint files"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
|
||||
file_extensions = {
|
||||
".ckpt",
|
||||
".pt",
|
||||
".pt2",
|
||||
".bin",
|
||||
".pth",
|
||||
".safetensors",
|
||||
".pkl",
|
||||
".sft",
|
||||
".gguf",
|
||||
}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
hash_index=ModelHashIndex(),
|
||||
)
|
||||
|
||||
async def _create_default_metadata(self, file_path: str) -> Optional[CheckpointMetadata]:
|
||||
async def _create_default_metadata(
|
||||
self, file_path: str
|
||||
) -> Optional[CheckpointMetadata]:
|
||||
"""Create default metadata for checkpoint without calculating hash (lazy hash).
|
||||
|
||||
|
||||
Checkpoints are typically large (10GB+), so we skip hash calculation during initial
|
||||
scanning to improve startup performance. Hash will be calculated on-demand when
|
||||
fetching metadata from Civitai.
|
||||
@@ -38,13 +51,13 @@ class CheckpointScanner(ModelScanner):
|
||||
if not os.path.exists(real_path):
|
||||
logger.error(f"File not found: {file_path}")
|
||||
return None
|
||||
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
dir_path = os.path.dirname(file_path)
|
||||
|
||||
|
||||
# Find preview image
|
||||
preview_url = find_preview_file(base_name, dir_path)
|
||||
|
||||
|
||||
# Create metadata WITHOUT calculating hash
|
||||
metadata = CheckpointMetadata(
|
||||
file_name=base_name,
|
||||
@@ -59,70 +72,76 @@ class CheckpointScanner(ModelScanner):
|
||||
modelDescription="",
|
||||
sub_type="checkpoint",
|
||||
from_civitai=False, # Mark as local model since no hash yet
|
||||
hash_status="pending" # Mark hash as pending
|
||||
hash_status="pending", # Mark hash as pending
|
||||
)
|
||||
|
||||
|
||||
# Save the created metadata
|
||||
logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}")
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating default checkpoint metadata for {file_path}: {e}")
|
||||
logger.error(
|
||||
f"Error creating default checkpoint metadata for {file_path}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
|
||||
"""Calculate hash for a checkpoint on-demand.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the model file
|
||||
|
||||
|
||||
Returns:
|
||||
SHA256 hash string, or None if calculation failed
|
||||
"""
|
||||
from ..utils.file_utils import calculate_sha256
|
||||
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not os.path.exists(real_path):
|
||||
logger.error(f"File not found for hash calculation: {file_path}")
|
||||
return None
|
||||
|
||||
|
||||
# Load current metadata
|
||||
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if metadata is None:
|
||||
logger.error(f"No metadata found for {file_path}")
|
||||
return None
|
||||
|
||||
|
||||
# Check if hash is already calculated
|
||||
if metadata.hash_status == "completed" and metadata.sha256:
|
||||
return metadata.sha256
|
||||
|
||||
|
||||
# Update status to calculating
|
||||
metadata.hash_status = "calculating"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
|
||||
# Calculate hash
|
||||
logger.info(f"Calculating hash for checkpoint: {file_path}")
|
||||
sha256 = await calculate_sha256(real_path)
|
||||
|
||||
|
||||
# Update metadata with hash
|
||||
metadata.sha256 = sha256
|
||||
metadata.hash_status = "completed"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
|
||||
# Update hash index
|
||||
self._hash_index.add_entry(sha256.lower(), file_path)
|
||||
|
||||
|
||||
logger.info(f"Hash calculated for checkpoint: {file_path}")
|
||||
return sha256
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||
# Update status to failed
|
||||
try:
|
||||
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if metadata:
|
||||
metadata.hash_status = "failed"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
@@ -130,43 +149,46 @@ class CheckpointScanner(ModelScanner):
|
||||
pass
|
||||
return None
|
||||
|
||||
async def calculate_all_pending_hashes(self, progress_callback=None) -> Dict[str, int]:
|
||||
async def calculate_all_pending_hashes(
|
||||
self, progress_callback=None
|
||||
) -> Dict[str, int]:
|
||||
"""Calculate hashes for all checkpoints with pending hash status.
|
||||
|
||||
|
||||
If cache is not initialized, scans filesystem directly for metadata files
|
||||
with hash_status != 'completed'.
|
||||
|
||||
|
||||
Args:
|
||||
progress_callback: Optional callback(progress, total, current_file)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with 'completed', 'failed', 'total' counts
|
||||
"""
|
||||
# Try to get from cache first
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
|
||||
if cache and cache.raw_data:
|
||||
# Use cache if available
|
||||
pending_models = [
|
||||
item for item in cache.raw_data
|
||||
if item.get('hash_status') != 'completed' or not item.get('sha256')
|
||||
item
|
||||
for item in cache.raw_data
|
||||
if item.get("hash_status") != "completed" or not item.get("sha256")
|
||||
]
|
||||
else:
|
||||
# Cache not initialized, scan filesystem directly
|
||||
pending_models = await self._find_pending_models_from_filesystem()
|
||||
|
||||
|
||||
if not pending_models:
|
||||
return {'completed': 0, 'failed': 0, 'total': 0}
|
||||
|
||||
return {"completed": 0, "failed": 0, "total": 0}
|
||||
|
||||
total = len(pending_models)
|
||||
completed = 0
|
||||
failed = 0
|
||||
|
||||
|
||||
for i, model_data in enumerate(pending_models):
|
||||
file_path = model_data.get('file_path')
|
||||
file_path = model_data.get("file_path")
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
sha256 = await self.calculate_hash_for_model(file_path)
|
||||
if sha256:
|
||||
@@ -176,77 +198,102 @@ class CheckpointScanner(ModelScanner):
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||
failed += 1
|
||||
|
||||
|
||||
if progress_callback:
|
||||
try:
|
||||
await progress_callback(i + 1, total, file_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
'completed': completed,
|
||||
'failed': failed,
|
||||
'total': total
|
||||
}
|
||||
|
||||
|
||||
return {"completed": completed, "failed": failed, "total": total}
|
||||
|
||||
async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]:
|
||||
"""Scan filesystem for checkpoint metadata files with pending hash status."""
|
||||
pending_models = []
|
||||
|
||||
|
||||
for root_path in self.get_model_roots():
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
|
||||
for dirpath, _dirnames, filenames in os.walk(root_path):
|
||||
for filename in filenames:
|
||||
if not filename.endswith('.metadata.json'):
|
||||
if not filename.endswith(".metadata.json"):
|
||||
continue
|
||||
|
||||
|
||||
metadata_path = os.path.join(dirpath, filename)
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
with open(metadata_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
# Check if hash is pending
|
||||
hash_status = data.get('hash_status', 'completed')
|
||||
sha256 = data.get('sha256', '')
|
||||
|
||||
if hash_status != 'completed' or not sha256:
|
||||
hash_status = data.get("hash_status", "completed")
|
||||
sha256 = data.get("sha256", "")
|
||||
|
||||
if hash_status != "completed" or not sha256:
|
||||
# Find corresponding model file
|
||||
model_name = filename.replace('.metadata.json', '')
|
||||
model_name = filename.replace(".metadata.json", "")
|
||||
model_path = None
|
||||
|
||||
|
||||
# Look for model file with matching name
|
||||
for ext in self.file_extensions:
|
||||
potential_path = os.path.join(dirpath, model_name + ext)
|
||||
if os.path.exists(potential_path):
|
||||
model_path = potential_path
|
||||
break
|
||||
|
||||
|
||||
if model_path:
|
||||
pending_models.append({
|
||||
'file_path': model_path.replace(os.sep, '/'),
|
||||
'hash_status': hash_status,
|
||||
'sha256': sha256,
|
||||
**{k: v for k, v in data.items() if k not in ['file_path', 'hash_status', 'sha256']}
|
||||
})
|
||||
pending_models.append(
|
||||
{
|
||||
"file_path": model_path.replace(os.sep, "/"),
|
||||
"hash_status": hash_status,
|
||||
"sha256": sha256,
|
||||
**{
|
||||
k: v
|
||||
for k, v in data.items()
|
||||
if k
|
||||
not in [
|
||||
"file_path",
|
||||
"hash_status",
|
||||
"sha256",
|
||||
]
|
||||
},
|
||||
}
|
||||
)
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.debug(f"Error reading metadata file {metadata_path}: {e}")
|
||||
logger.debug(
|
||||
f"Error reading metadata file {metadata_path}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
return pending_models
|
||||
|
||||
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
|
||||
"""Resolve the sub-type based on the root path."""
|
||||
"""Resolve the sub-type based on the root path.
|
||||
|
||||
Checks both standard ComfyUI paths and LoRA Manager's extra folder paths.
|
||||
"""
|
||||
if not root_path:
|
||||
return None
|
||||
|
||||
# Check standard ComfyUI checkpoint paths
|
||||
if config.checkpoints_roots and root_path in config.checkpoints_roots:
|
||||
return "checkpoint"
|
||||
|
||||
# Check extra checkpoint paths
|
||||
if (
|
||||
config.extra_checkpoints_roots
|
||||
and root_path in config.extra_checkpoints_roots
|
||||
):
|
||||
return "checkpoint"
|
||||
|
||||
# Check standard ComfyUI unet paths
|
||||
if config.unet_roots and root_path in config.unet_roots:
|
||||
return "diffusion_model"
|
||||
|
||||
# Check extra unet paths
|
||||
if config.extra_unet_roots and root_path in config.extra_unet_roots:
|
||||
return "diffusion_model"
|
||||
|
||||
return None
|
||||
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
|
||||
@@ -27,7 +27,7 @@ class LoraService(BaseModelService):
|
||||
# Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default
|
||||
# Normalize to lowercase for consistent API responses
|
||||
sub_type = resolve_sub_type(lora_data).lower()
|
||||
|
||||
|
||||
return {
|
||||
"model_name": lora_data["model_name"],
|
||||
"file_name": lora_data["file_name"],
|
||||
@@ -48,7 +48,9 @@ class LoraService(BaseModelService):
|
||||
"notes": lora_data.get("notes", ""),
|
||||
"favorite": lora_data.get("favorite", False),
|
||||
"update_available": bool(lora_data.get("update_available", False)),
|
||||
"skip_metadata_refresh": bool(lora_data.get("skip_metadata_refresh", False)),
|
||||
"skip_metadata_refresh": bool(
|
||||
lora_data.get("skip_metadata_refresh", False)
|
||||
),
|
||||
"sub_type": sub_type,
|
||||
"civitai": self.filter_civitai_data(
|
||||
lora_data.get("civitai", {}), minimal=True
|
||||
@@ -62,6 +64,68 @@ class LoraService(BaseModelService):
|
||||
if first_letter:
|
||||
data = self._filter_by_first_letter(data, first_letter)
|
||||
|
||||
# Handle name pattern filters
|
||||
name_pattern_include = kwargs.get("name_pattern_include", [])
|
||||
name_pattern_exclude = kwargs.get("name_pattern_exclude", [])
|
||||
name_pattern_use_regex = kwargs.get("name_pattern_use_regex", False)
|
||||
|
||||
if name_pattern_include or name_pattern_exclude:
|
||||
import re
|
||||
|
||||
def matches_pattern(name, pattern, use_regex):
|
||||
"""Check if name matches pattern (regex or substring)"""
|
||||
if not name:
|
||||
return False
|
||||
if use_regex:
|
||||
try:
|
||||
return bool(re.search(pattern, name, re.IGNORECASE))
|
||||
except re.error:
|
||||
# Invalid regex, fall back to substring match
|
||||
return pattern.lower() in name.lower()
|
||||
else:
|
||||
return pattern.lower() in name.lower()
|
||||
|
||||
def matches_any_pattern(name, patterns, use_regex):
|
||||
"""Check if name matches any of the patterns"""
|
||||
if not patterns:
|
||||
return True
|
||||
return any(matches_pattern(name, p, use_regex) for p in patterns)
|
||||
|
||||
filtered = []
|
||||
for lora in data:
|
||||
model_name = lora.get("model_name", "")
|
||||
file_name = lora.get("file_name", "")
|
||||
names_to_check = [n for n in [model_name, file_name] if n]
|
||||
|
||||
# Check exclude patterns first
|
||||
excluded = False
|
||||
if name_pattern_exclude:
|
||||
for name in names_to_check:
|
||||
if matches_any_pattern(
|
||||
name, name_pattern_exclude, name_pattern_use_regex
|
||||
):
|
||||
excluded = True
|
||||
break
|
||||
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
# Check include patterns
|
||||
if name_pattern_include:
|
||||
included = False
|
||||
for name in names_to_check:
|
||||
if matches_any_pattern(
|
||||
name, name_pattern_include, name_pattern_use_regex
|
||||
):
|
||||
included = True
|
||||
break
|
||||
if not included:
|
||||
continue
|
||||
|
||||
filtered.append(lora)
|
||||
|
||||
data = filtered
|
||||
|
||||
return data
|
||||
|
||||
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||
@@ -368,9 +432,7 @@ class LoraService(BaseModelService):
|
||||
rng.uniform(clip_strength_min, clip_strength_max), 2
|
||||
)
|
||||
else:
|
||||
clip_str = round(
|
||||
rng.uniform(clip_strength_min, clip_strength_max), 2
|
||||
)
|
||||
clip_str = round(rng.uniform(clip_strength_min, clip_strength_max), 2)
|
||||
|
||||
result_loras.append(
|
||||
{
|
||||
@@ -485,12 +547,69 @@ class LoraService(BaseModelService):
|
||||
if bool(lora.get("license_flags", 127) & (1 << 1))
|
||||
]
|
||||
|
||||
# Apply name pattern filters
|
||||
name_patterns = filter_section.get("namePatterns", {})
|
||||
include_patterns = name_patterns.get("include", [])
|
||||
exclude_patterns = name_patterns.get("exclude", [])
|
||||
use_regex = name_patterns.get("useRegex", False)
|
||||
|
||||
if include_patterns or exclude_patterns:
|
||||
import re
|
||||
|
||||
def matches_pattern(name, pattern, use_regex):
|
||||
"""Check if name matches pattern (regex or substring)"""
|
||||
if not name:
|
||||
return False
|
||||
if use_regex:
|
||||
try:
|
||||
return bool(re.search(pattern, name, re.IGNORECASE))
|
||||
except re.error:
|
||||
# Invalid regex, fall back to substring match
|
||||
return pattern.lower() in name.lower()
|
||||
else:
|
||||
return pattern.lower() in name.lower()
|
||||
|
||||
def matches_any_pattern(name, patterns, use_regex):
|
||||
"""Check if name matches any of the patterns"""
|
||||
if not patterns:
|
||||
return True
|
||||
return any(matches_pattern(name, p, use_regex) for p in patterns)
|
||||
|
||||
filtered = []
|
||||
for lora in available_loras:
|
||||
model_name = lora.get("model_name", "")
|
||||
file_name = lora.get("file_name", "")
|
||||
names_to_check = [n for n in [model_name, file_name] if n]
|
||||
|
||||
# Check exclude patterns first
|
||||
excluded = False
|
||||
if exclude_patterns:
|
||||
for name in names_to_check:
|
||||
if matches_any_pattern(name, exclude_patterns, use_regex):
|
||||
excluded = True
|
||||
break
|
||||
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
# Check include patterns
|
||||
if include_patterns:
|
||||
included = False
|
||||
for name in names_to_check:
|
||||
if matches_any_pattern(name, include_patterns, use_regex):
|
||||
included = True
|
||||
break
|
||||
if not included:
|
||||
continue
|
||||
|
||||
filtered.append(lora)
|
||||
|
||||
available_loras = filtered
|
||||
|
||||
return available_loras
|
||||
|
||||
async def get_cycler_list(
|
||||
self,
|
||||
pool_config: Optional[Dict] = None,
|
||||
sort_by: str = "filename"
|
||||
self, pool_config: Optional[Dict] = None, sort_by: str = "filename"
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Get filtered and sorted LoRA list for cycling.
|
||||
@@ -516,12 +635,18 @@ class LoraService(BaseModelService):
|
||||
if sort_by == "model_name":
|
||||
available_loras = sorted(
|
||||
available_loras,
|
||||
key=lambda x: (x.get("model_name") or x.get("file_name", "")).lower()
|
||||
key=lambda x: (
|
||||
(x.get("model_name") or x.get("file_name", "")).lower(),
|
||||
x.get("file_path", "").lower(),
|
||||
),
|
||||
)
|
||||
else: # Default to filename
|
||||
available_loras = sorted(
|
||||
available_loras,
|
||||
key=lambda x: x.get("file_name", "").lower()
|
||||
key=lambda x: (
|
||||
x.get("file_name", "").lower(),
|
||||
x.get("file_path", "").lower(),
|
||||
),
|
||||
)
|
||||
|
||||
# Return minimal data needed for cycling
|
||||
|
||||
@@ -221,33 +221,45 @@ class ModelCache:
|
||||
start_time = time.perf_counter()
|
||||
reverse = (order == 'desc')
|
||||
if sort_key == 'name':
|
||||
# Natural sort by configured display name, case-insensitive
|
||||
# Natural sort by configured display name, case-insensitive, with file_path as tie-breaker
|
||||
result = natsorted(
|
||||
data,
|
||||
key=lambda x: self._get_display_name(x).lower(),
|
||||
key=lambda x: (
|
||||
self._get_display_name(x).lower(),
|
||||
x.get('file_path', '').lower()
|
||||
),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'date':
|
||||
# Sort by modified timestamp (use .get() with default to handle missing fields)
|
||||
# Sort by modified timestamp, fallback to name and path for stability
|
||||
result = sorted(
|
||||
data,
|
||||
key=lambda x: x.get('modified', 0.0),
|
||||
key=lambda x: (
|
||||
x.get('modified', 0.0),
|
||||
self._get_display_name(x).lower(),
|
||||
x.get('file_path', '').lower()
|
||||
),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'size':
|
||||
# Sort by file size (use .get() with default to handle missing fields)
|
||||
# Sort by file size, fallback to name and path for stability
|
||||
result = sorted(
|
||||
data,
|
||||
key=lambda x: x.get('size', 0),
|
||||
key=lambda x: (
|
||||
x.get('size', 0),
|
||||
self._get_display_name(x).lower(),
|
||||
x.get('file_path', '').lower()
|
||||
),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'usage':
|
||||
# Sort by usage count, fallback to 0, then name for stability
|
||||
# Sort by usage count, fallback to 0, then name and path for stability
|
||||
return sorted(
|
||||
data,
|
||||
key=lambda x: (
|
||||
x.get('usage_count', 0),
|
||||
self._get_display_name(x).lower()
|
||||
self._get_display_name(x).lower(),
|
||||
x.get('file_path', '').lower()
|
||||
),
|
||||
reverse=reverse
|
||||
)
|
||||
|
||||
@@ -14,7 +14,6 @@ from ..utils.metadata_manager import MetadataManager
|
||||
from ..utils.civitai_utils import resolve_license_info
|
||||
from .model_cache import ModelCache
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||
from .model_lifecycle_service import delete_model_artifacts
|
||||
from .service_registry import ServiceRegistry
|
||||
from .websocket_manager import ws_manager
|
||||
@@ -1442,14 +1441,13 @@ class ModelScanner:
|
||||
file_path = self._hash_index.get_path(sha256.lower())
|
||||
if not file_path:
|
||||
return None
|
||||
|
||||
base_name = os.path.splitext(file_path)[0]
|
||||
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
preview_path = f"{base_name}{ext}"
|
||||
if os.path.exists(preview_path):
|
||||
return config.get_preview_static_url(preview_path)
|
||||
|
||||
|
||||
dir_path = os.path.dirname(file_path)
|
||||
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
preview_path = find_preview_file(base_name, dir_path)
|
||||
if preview_path:
|
||||
return config.get_preview_static_url(preview_path)
|
||||
|
||||
return None
|
||||
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
|
||||
@@ -56,6 +56,7 @@ class PersistentModelCache:
|
||||
"exclude",
|
||||
"db_checked",
|
||||
"last_checked_at",
|
||||
"hash_status",
|
||||
)
|
||||
_MODEL_UPDATE_COLUMNS: Tuple[str, ...] = _MODEL_COLUMNS[2:]
|
||||
_instances: Dict[str, "PersistentModelCache"] = {}
|
||||
@@ -186,6 +187,7 @@ class PersistentModelCache:
|
||||
"civitai_deleted": bool(row["civitai_deleted"]),
|
||||
"skip_metadata_refresh": bool(row["skip_metadata_refresh"]),
|
||||
"license_flags": int(license_value),
|
||||
"hash_status": row["hash_status"] or "completed",
|
||||
}
|
||||
raw_data.append(item)
|
||||
|
||||
@@ -449,6 +451,7 @@ class PersistentModelCache:
|
||||
exclude INTEGER,
|
||||
db_checked INTEGER,
|
||||
last_checked_at REAL,
|
||||
hash_status TEXT,
|
||||
PRIMARY KEY (model_type, file_path)
|
||||
);
|
||||
|
||||
@@ -496,6 +499,7 @@ class PersistentModelCache:
|
||||
"skip_metadata_refresh": "INTEGER DEFAULT 0",
|
||||
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
|
||||
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
|
||||
"hash_status": "TEXT DEFAULT 'completed'",
|
||||
}
|
||||
|
||||
for column, definition in required_columns.items():
|
||||
@@ -570,6 +574,7 @@ class PersistentModelCache:
|
||||
1 if item.get("exclude") else 0,
|
||||
1 if item.get("db_checked") else 0,
|
||||
float(item.get("last_checked_at") or 0.0),
|
||||
item.get("hash_status", "completed"),
|
||||
)
|
||||
|
||||
def _insert_model_sql(self) -> str:
|
||||
|
||||
@@ -135,7 +135,8 @@ class RecipeCache:
|
||||
"""Sort cached views. Caller must hold ``_lock``."""
|
||||
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data, key=lambda x: x.get("title", "").lower()
|
||||
self.raw_data,
|
||||
key=lambda x: (x.get("title", "").lower(), x.get("file_path", "").lower()),
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Services responsible for recipe metadata analysis."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
@@ -69,7 +70,9 @@ class RecipeAnalysisService:
|
||||
try:
|
||||
metadata = self._exif_utils.extract_image_metadata(temp_path)
|
||||
if not metadata:
|
||||
return AnalysisResult({"error": "No metadata found in this image", "loras": []})
|
||||
return AnalysisResult(
|
||||
{"error": "No metadata found in this image", "loras": []}
|
||||
)
|
||||
|
||||
return await self._parse_metadata(
|
||||
metadata,
|
||||
@@ -105,29 +108,33 @@ class RecipeAnalysisService:
|
||||
if civitai_match:
|
||||
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
||||
if not image_info:
|
||||
raise RecipeDownloadError("Failed to fetch image information from Civitai")
|
||||
|
||||
raise RecipeDownloadError(
|
||||
"Failed to fetch image information from Civitai"
|
||||
)
|
||||
|
||||
image_url = image_info.get("url")
|
||||
if not image_url:
|
||||
raise RecipeDownloadError("No image URL found in Civitai response")
|
||||
|
||||
|
||||
is_video = image_info.get("type") == "video"
|
||||
|
||||
|
||||
# Use optimized preview URLs if possible
|
||||
rewritten_url, _ = rewrite_preview_url(image_url, media_type=image_info.get("type"))
|
||||
rewritten_url, _ = rewrite_preview_url(
|
||||
image_url, media_type=image_info.get("type")
|
||||
)
|
||||
if rewritten_url:
|
||||
image_url = rewritten_url
|
||||
|
||||
if is_video:
|
||||
# Extract extension from URL
|
||||
url_path = image_url.split('?')[0].split('#')[0]
|
||||
url_path = image_url.split("?")[0].split("#")[0]
|
||||
extension = os.path.splitext(url_path)[1].lower() or ".mp4"
|
||||
else:
|
||||
extension = ".jpg"
|
||||
|
||||
temp_path = self._create_temp_path(suffix=extension)
|
||||
await self._download_image(image_url, temp_path)
|
||||
|
||||
|
||||
metadata = image_info.get("meta") if "meta" in image_info else None
|
||||
if (
|
||||
isinstance(metadata, dict)
|
||||
@@ -135,15 +142,23 @@ class RecipeAnalysisService:
|
||||
and isinstance(metadata["meta"], dict)
|
||||
):
|
||||
metadata = metadata["meta"]
|
||||
|
||||
# Validate that metadata contains meaningful recipe fields
|
||||
# If not, treat as None to trigger EXIF extraction from downloaded image
|
||||
if isinstance(metadata, dict) and not self._has_recipe_fields(metadata):
|
||||
self._logger.debug(
|
||||
"Civitai API metadata lacks recipe fields, will extract from EXIF"
|
||||
)
|
||||
metadata = None
|
||||
else:
|
||||
# Basic extension detection for non-Civitai URLs
|
||||
url_path = url.split('?')[0].split('#')[0]
|
||||
url_path = url.split("?")[0].split("#")[0]
|
||||
extension = os.path.splitext(url_path)[1].lower()
|
||||
if extension in [".mp4", ".webm"]:
|
||||
is_video = True
|
||||
else:
|
||||
extension = ".jpg"
|
||||
|
||||
|
||||
temp_path = self._create_temp_path(suffix=extension)
|
||||
await self._download_image(url, temp_path)
|
||||
|
||||
@@ -211,7 +226,9 @@ class RecipeAnalysisService:
|
||||
|
||||
image_bytes = self._convert_tensor_to_png_bytes(latest_image)
|
||||
if image_bytes is None:
|
||||
raise RecipeValidationError("Cannot handle this data shape from metadata registry")
|
||||
raise RecipeValidationError(
|
||||
"Cannot handle this data shape from metadata registry"
|
||||
)
|
||||
|
||||
return AnalysisResult(
|
||||
{
|
||||
@@ -222,6 +239,22 @@ class RecipeAnalysisService:
|
||||
|
||||
# Internal helpers -------------------------------------------------
|
||||
|
||||
def _has_recipe_fields(self, metadata: dict[str, Any]) -> bool:
|
||||
"""Check if metadata contains meaningful recipe-related fields."""
|
||||
recipe_fields = {
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
"resources",
|
||||
"hashes",
|
||||
"params",
|
||||
"generationData",
|
||||
"Workflow",
|
||||
"prompt_type",
|
||||
"positive",
|
||||
"negative",
|
||||
}
|
||||
return any(field in metadata for field in recipe_fields)
|
||||
|
||||
async def _parse_metadata(
|
||||
self,
|
||||
metadata: dict[str, Any],
|
||||
@@ -234,7 +267,12 @@ class RecipeAnalysisService:
|
||||
) -> AnalysisResult:
|
||||
parser = self._recipe_parser_factory.create_parser(metadata)
|
||||
if parser is None:
|
||||
payload = {"error": "No parser found for this image", "loras": []}
|
||||
# Provide more specific error message based on metadata source
|
||||
if not metadata:
|
||||
error_msg = "This image does not contain any generation metadata (prompt, models, or parameters)"
|
||||
else:
|
||||
error_msg = "No parser found for this image"
|
||||
payload = {"error": error_msg, "loras": []}
|
||||
if include_image_base64 and image_path:
|
||||
payload["image_base64"] = self._encode_file(image_path)
|
||||
payload["is_video"] = is_video
|
||||
@@ -257,7 +295,9 @@ class RecipeAnalysisService:
|
||||
|
||||
matching_recipes: list[str] = []
|
||||
if fingerprint:
|
||||
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint)
|
||||
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(
|
||||
fingerprint
|
||||
)
|
||||
result["matching_recipes"] = matching_recipes
|
||||
|
||||
return AnalysisResult(result)
|
||||
@@ -269,7 +309,10 @@ class RecipeAnalysisService:
|
||||
raise RecipeDownloadError(f"Failed to download image from URL: {result}")
|
||||
|
||||
def _metadata_not_found_response(self, path: str) -> AnalysisResult:
|
||||
payload: dict[str, Any] = {"error": "No metadata found in this image", "loras": []}
|
||||
payload: dict[str, Any] = {
|
||||
"error": "No metadata found in this image",
|
||||
"loras": [],
|
||||
}
|
||||
if os.path.exists(path):
|
||||
payload["image_base64"] = self._encode_file(path)
|
||||
return AnalysisResult(payload)
|
||||
@@ -305,7 +348,9 @@ class RecipeAnalysisService:
|
||||
|
||||
if hasattr(tensor_image, "shape"):
|
||||
self._logger.debug(
|
||||
"Tensor shape: %s, dtype: %s", tensor_image.shape, getattr(tensor_image, "dtype", None)
|
||||
"Tensor shape: %s, dtype: %s",
|
||||
tensor_image.shape,
|
||||
getattr(tensor_image, "dtype", None),
|
||||
)
|
||||
|
||||
import torch # type: ignore[import-not-found]
|
||||
|
||||
@@ -40,49 +40,39 @@ async def calculate_sha256(file_path: str) -> str:
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
def find_preview_file(base_name: str, dir_path: str) -> str:
|
||||
"""Find preview file for given base name in directory"""
|
||||
|
||||
"""Find preview file for given base name in directory.
|
||||
|
||||
Performs an exact-case check first (fast path), then falls back to a
|
||||
case-insensitive scan so that files like ``model.WEBP`` or ``model.Png``
|
||||
are discovered on case-sensitive filesystems.
|
||||
"""
|
||||
|
||||
temp_extensions = PREVIEW_EXTENSIONS.copy()
|
||||
# Add example extension for compatibility
|
||||
# https://github.com/willmiao/ComfyUI-Lora-Manager/issues/225
|
||||
# The preview image will be optimized to lora-name.webp, so it won't affect other logic
|
||||
temp_extensions.append(".example.0.jpeg")
|
||||
|
||||
# Fast path: exact-case match
|
||||
for ext in temp_extensions:
|
||||
full_pattern = os.path.join(dir_path, f"{base_name}{ext}")
|
||||
if os.path.exists(full_pattern):
|
||||
# Check if this is an image and not already webp
|
||||
# TODO: disable the optimization for now, maybe add a config option later
|
||||
# if ext.lower().endswith(('.jpg', '.jpeg', '.png')) and not ext.lower().endswith('.webp'):
|
||||
# try:
|
||||
# # Optimize the image to webp format
|
||||
# webp_path = os.path.join(dir_path, f"{base_name}.webp")
|
||||
|
||||
# # Use ExifUtils to optimize the image
|
||||
# with open(full_pattern, 'rb') as f:
|
||||
# image_data = f.read()
|
||||
|
||||
# optimized_data, _ = ExifUtils.optimize_image(
|
||||
# image_data=image_data,
|
||||
# target_width=CARD_PREVIEW_WIDTH,
|
||||
# format='webp',
|
||||
# quality=85,
|
||||
# preserve_metadata=False
|
||||
# )
|
||||
|
||||
# # Save the optimized webp file
|
||||
# with open(webp_path, 'wb') as f:
|
||||
# f.write(optimized_data)
|
||||
|
||||
# logger.debug(f"Optimized preview image from {full_pattern} to {webp_path}")
|
||||
# return webp_path.replace(os.sep, "/")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error optimizing preview image {full_pattern}: {e}")
|
||||
# # Fall back to original file if optimization fails
|
||||
# return full_pattern.replace(os.sep, "/")
|
||||
|
||||
# Return the original path for webp images or non-image files
|
||||
return full_pattern.replace(os.sep, "/")
|
||||
|
||||
|
||||
# Slow path: case-insensitive match for systems with mixed-case extensions
|
||||
# (e.g. .WEBP, .Png, .JPG placed manually or by external tools)
|
||||
try:
|
||||
dir_entries = os.listdir(dir_path)
|
||||
except OSError:
|
||||
return ""
|
||||
|
||||
base_lower = base_name.lower()
|
||||
for ext in temp_extensions:
|
||||
target = f"{base_lower}{ext}" # ext is already lowercase
|
||||
for entry in dir_entries:
|
||||
if entry.lower() == target:
|
||||
return os.path.join(dir_path, entry).replace(os.sep, "/")
|
||||
|
||||
return ""
|
||||
|
||||
def get_preview_extension(preview_path: str) -> str:
|
||||
|
||||
@@ -112,6 +112,115 @@ def get_lora_info_absolute(lora_name):
|
||||
return asyncio.run(_get_lora_info_absolute_async())
|
||||
|
||||
|
||||
def get_checkpoint_info_absolute(checkpoint_name):
|
||||
"""Get the absolute checkpoint path and metadata from cache
|
||||
|
||||
Supports ComfyUI-style model names (e.g., "folder/model_name.ext")
|
||||
|
||||
Args:
|
||||
checkpoint_name: The model name, can be:
|
||||
- ComfyUI format: "folder/model_name.safetensors"
|
||||
- Simple name: "model_name"
|
||||
|
||||
Returns:
|
||||
tuple: (absolute_path, metadata) where absolute_path is the full
|
||||
file system path to the checkpoint file, or original checkpoint_name if not found,
|
||||
metadata is the full model metadata dict or None
|
||||
"""
|
||||
|
||||
async def _get_checkpoint_info_absolute_async():
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
# Get model roots for matching
|
||||
model_roots = scanner.get_model_roots()
|
||||
|
||||
# Normalize the checkpoint name
|
||||
normalized_name = checkpoint_name.replace(os.sep, "/")
|
||||
|
||||
for item in cache.raw_data:
|
||||
file_path = item.get("file_path", "")
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
# Format the stored path as ComfyUI-style name
|
||||
formatted_name = _format_model_name_for_comfyui(file_path, model_roots)
|
||||
|
||||
# Match by formatted name (normalize separators for robust comparison)
|
||||
if formatted_name.replace(os.sep, "/") == normalized_name or formatted_name == checkpoint_name:
|
||||
return file_path, item
|
||||
|
||||
# Also try matching by basename only (for backward compatibility)
|
||||
file_name = item.get("file_name", "")
|
||||
if (
|
||||
file_name == checkpoint_name
|
||||
or file_name == os.path.splitext(normalized_name)[0]
|
||||
):
|
||||
return file_path, item
|
||||
|
||||
return checkpoint_name, None
|
||||
|
||||
try:
|
||||
# Check if we're already in an event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in a running loop, we need to use a different approach
|
||||
# Create a new thread to run the async code
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(
|
||||
_get_checkpoint_info_absolute_async()
|
||||
)
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result()
|
||||
|
||||
except RuntimeError:
|
||||
# No event loop is running, we can use asyncio.run()
|
||||
return asyncio.run(_get_checkpoint_info_absolute_async())
|
||||
|
||||
|
||||
def _format_model_name_for_comfyui(file_path: str, model_roots: list) -> str:
|
||||
"""Format file path to ComfyUI-style model name (relative path with extension)
|
||||
|
||||
Example: /path/to/checkpoints/Illustrious/model.safetensors -> Illustrious/model.safetensors
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to the model file
|
||||
model_roots: List of model root directories
|
||||
|
||||
Returns:
|
||||
ComfyUI-style model name with relative path and extension
|
||||
"""
|
||||
# Find the matching root and get relative path
|
||||
for root in model_roots:
|
||||
try:
|
||||
# Normalize paths for comparison
|
||||
norm_file = os.path.normcase(os.path.abspath(file_path))
|
||||
norm_root = os.path.normcase(os.path.abspath(root))
|
||||
|
||||
# Add trailing separator for prefix check
|
||||
if not norm_root.endswith(os.sep):
|
||||
norm_root += os.sep
|
||||
|
||||
if norm_file.startswith(norm_root):
|
||||
# Use os.path.relpath to get relative path with OS-native separator
|
||||
return os.path.relpath(file_path, root)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
# If no root matches, just return the basename with extension
|
||||
return os.path.basename(file_path)
|
||||
|
||||
|
||||
def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool:
|
||||
"""
|
||||
Check if text matches pattern using fuzzy matching.
|
||||
@@ -173,10 +282,13 @@ def sanitize_folder_name(name: str, replacement: str = "_") -> str:
|
||||
# Collapse repeated replacement characters to a single instance
|
||||
if replacement:
|
||||
sanitized = re.sub(f"{re.escape(replacement)}+", replacement, sanitized)
|
||||
sanitized = sanitized.strip(replacement)
|
||||
|
||||
# Remove trailing spaces or periods which are invalid on Windows
|
||||
sanitized = sanitized.rstrip(" .")
|
||||
# Combine stripping to be idempotent:
|
||||
# Right side: strip replacement, space, and dot (Windows restriction)
|
||||
# Left side: strip replacement and space (leading dots are allowed)
|
||||
sanitized = sanitized.rstrip(" ." + replacement).lstrip(" " + replacement)
|
||||
else:
|
||||
# If no replacement, just strip spaces and dots from right, spaces from left
|
||||
sanitized = sanitized.rstrip(" .").lstrip(" ")
|
||||
|
||||
if not sanitized:
|
||||
return "unnamed"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[pytest]
|
||||
addopts = -v --import-mode=importlib -m "not performance"
|
||||
addopts = -v --import-mode=importlib -m "not performance" --ignore=__init__.py
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
|
||||
@@ -251,7 +251,7 @@ export class BaseModelApiClient {
|
||||
replaceModelPreview(filePath) {
|
||||
const input = document.createElement('input');
|
||||
input.type = 'file';
|
||||
input.accept = 'image/*,video/mp4';
|
||||
input.accept = 'image/*,image/webp,video/mp4';
|
||||
|
||||
input.onchange = async () => {
|
||||
if (!input.files || !input.files[0]) return;
|
||||
|
||||
@@ -2,6 +2,7 @@ import { modalManager } from './ModalManager.js';
|
||||
import { showToast } from '../utils/uiHelpers.js';
|
||||
import { translate } from '../utils/i18nHelpers.js';
|
||||
import { WS_ENDPOINTS } from '../api/apiConfig.js';
|
||||
import { getStorageItem, setStorageItem } from '../utils/storageHelpers.js';
|
||||
|
||||
/**
|
||||
* Manager for batch importing recipes from multiple images
|
||||
@@ -34,6 +35,14 @@ export class BatchImportManager {
|
||||
*/
|
||||
initialize() {
|
||||
this.initialized = true;
|
||||
|
||||
// Add event listener for persisting "Skip images without metadata" choice
|
||||
const skipNoMetadata = document.getElementById('batchSkipNoMetadata');
|
||||
if (skipNoMetadata) {
|
||||
skipNoMetadata.addEventListener('change', (e) => {
|
||||
setStorageItem('batch_import_skip_no_metadata', e.target.checked);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -61,7 +70,10 @@ export class BatchImportManager {
|
||||
if (tagsInput) tagsInput.value = '';
|
||||
|
||||
const skipNoMetadata = document.getElementById('batchSkipNoMetadata');
|
||||
if (skipNoMetadata) skipNoMetadata.checked = true;
|
||||
if (skipNoMetadata) {
|
||||
// Load preference from storage, defaulting to true
|
||||
skipNoMetadata.checked = getStorageItem('batch_import_skip_no_metadata', true);
|
||||
}
|
||||
|
||||
const recursiveCheck = document.getElementById('batchRecursiveCheck');
|
||||
if (recursiveCheck) recursiveCheck.checked = true;
|
||||
@@ -92,6 +104,14 @@ export class BatchImportManager {
|
||||
|
||||
// Clean up any existing connections
|
||||
this.cleanupConnections();
|
||||
|
||||
// Focus on the URL input field for better UX
|
||||
setTimeout(() => {
|
||||
const urlInput = document.getElementById('batchUrlInput');
|
||||
if (urlInput) {
|
||||
urlInput.focus();
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -36,8 +36,8 @@ class TestCheckpointPathOverlap:
|
||||
config._preview_root_paths = set()
|
||||
config._cached_fingerprint = None
|
||||
|
||||
# Call the method under test
|
||||
result = config._prepare_checkpoint_paths(
|
||||
# Call the method under test - now returns a tuple
|
||||
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||
[str(checkpoints_link)], [str(unet_link)]
|
||||
)
|
||||
|
||||
@@ -50,21 +50,27 @@ class TestCheckpointPathOverlap:
|
||||
]
|
||||
assert len(warning_messages) == 1
|
||||
assert "checkpoints" in warning_messages[0].lower()
|
||||
assert "diffusion_models" in warning_messages[0].lower() or "unet" in warning_messages[0].lower()
|
||||
assert (
|
||||
"diffusion_models" in warning_messages[0].lower()
|
||||
or "unet" in warning_messages[0].lower()
|
||||
)
|
||||
# Verify warning mentions backward compatibility fallback
|
||||
assert "falling back" in warning_messages[0].lower() or "backward compatibility" in warning_messages[0].lower()
|
||||
assert (
|
||||
"falling back" in warning_messages[0].lower()
|
||||
or "backward compatibility" in warning_messages[0].lower()
|
||||
)
|
||||
|
||||
# Verify only one path is returned (deduplication still works)
|
||||
assert len(result) == 1
|
||||
assert len(all_paths) == 1
|
||||
# Prioritizes checkpoints path for backward compatibility
|
||||
assert _normalize(result[0]) == _normalize(str(checkpoints_link))
|
||||
assert _normalize(all_paths[0]) == _normalize(str(checkpoints_link))
|
||||
|
||||
# Verify checkpoints_roots has the path (prioritized)
|
||||
assert len(config.checkpoints_roots) == 1
|
||||
assert _normalize(config.checkpoints_roots[0]) == _normalize(str(checkpoints_link))
|
||||
# Verify checkpoint_roots has the path (prioritized)
|
||||
assert len(checkpoint_roots) == 1
|
||||
assert _normalize(checkpoint_roots[0]) == _normalize(str(checkpoints_link))
|
||||
|
||||
# Verify unet_roots is empty (overlapping paths removed)
|
||||
assert config.unet_roots == []
|
||||
assert unet_roots == []
|
||||
|
||||
def test_non_overlapping_paths_no_warning(
|
||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
|
||||
@@ -83,7 +89,7 @@ class TestCheckpointPathOverlap:
|
||||
config._preview_root_paths = set()
|
||||
config._cached_fingerprint = None
|
||||
|
||||
result = config._prepare_checkpoint_paths(
|
||||
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||
[str(checkpoints_dir)], [str(unet_dir)]
|
||||
)
|
||||
|
||||
@@ -97,14 +103,14 @@ class TestCheckpointPathOverlap:
|
||||
assert len(warning_messages) == 0
|
||||
|
||||
# Verify both paths are returned
|
||||
assert len(result) == 2
|
||||
normalized_result = [_normalize(p) for p in result]
|
||||
assert len(all_paths) == 2
|
||||
normalized_result = [_normalize(p) for p in all_paths]
|
||||
assert _normalize(str(checkpoints_dir)) in normalized_result
|
||||
assert _normalize(str(unet_dir)) in normalized_result
|
||||
|
||||
# Verify both roots are properly set
|
||||
assert len(config.checkpoints_roots) == 1
|
||||
assert len(config.unet_roots) == 1
|
||||
assert len(checkpoint_roots) == 1
|
||||
assert len(unet_roots) == 1
|
||||
|
||||
def test_partial_overlap_prioritizes_checkpoints(
|
||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
|
||||
@@ -129,9 +135,9 @@ class TestCheckpointPathOverlap:
|
||||
config._cached_fingerprint = None
|
||||
|
||||
# One checkpoint path overlaps with one unet path
|
||||
result = config._prepare_checkpoint_paths(
|
||||
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||
[str(shared_link), str(separate_checkpoint)],
|
||||
[str(shared_link), str(separate_unet)]
|
||||
[str(shared_link), str(separate_unet)],
|
||||
)
|
||||
|
||||
# Verify warning was logged for the overlapping path
|
||||
@@ -144,17 +150,20 @@ class TestCheckpointPathOverlap:
|
||||
assert len(warning_messages) == 1
|
||||
|
||||
# Verify 3 unique paths (shared counted once as checkpoint, plus separate ones)
|
||||
assert len(result) == 3
|
||||
assert len(all_paths) == 3
|
||||
|
||||
# Verify the overlapping path appears in warning message
|
||||
assert str(shared_link.name) in warning_messages[0] or str(shared_dir.name) in warning_messages[0]
|
||||
assert (
|
||||
str(shared_link.name) in warning_messages[0]
|
||||
or str(shared_dir.name) in warning_messages[0]
|
||||
)
|
||||
|
||||
# Verify checkpoints_roots includes both checkpoint paths (including the shared one)
|
||||
assert len(config.checkpoints_roots) == 2
|
||||
checkpoint_normalized = [_normalize(p) for p in config.checkpoints_roots]
|
||||
# Verify checkpoint_roots includes both checkpoint paths (including the shared one)
|
||||
assert len(checkpoint_roots) == 2
|
||||
checkpoint_normalized = [_normalize(p) for p in checkpoint_roots]
|
||||
assert _normalize(str(shared_link)) in checkpoint_normalized
|
||||
assert _normalize(str(separate_checkpoint)) in checkpoint_normalized
|
||||
|
||||
# Verify unet_roots only includes the non-overlapping unet path
|
||||
assert len(config.unet_roots) == 1
|
||||
assert _normalize(config.unet_roots[0]) == _normalize(str(separate_unet))
|
||||
assert len(unet_roots) == 1
|
||||
assert _normalize(unet_roots[0]) == _normalize(str(separate_unet))
|
||||
|
||||
@@ -267,4 +267,431 @@ describe('AutoComplete widget interactions', () => {
|
||||
const scrollTopAfter = autoComplete.scrollContainer?.scrollTop || 0;
|
||||
expect(scrollTopAfter).toBeGreaterThanOrEqual(scrollTopBefore);
|
||||
});
|
||||
|
||||
it('replaces entire multi-word phrase when it matches selected tag (Danbooru convention)', async () => {
|
||||
const mockTags = [
|
||||
{ tag_name: 'looking_to_the_side', category: 0, post_count: 1234 },
|
||||
{ tag_name: 'looking_away', category: 0, post_count: 5678 },
|
||||
];
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: mockTags }),
|
||||
});
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('looking to the side');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = 'looking to the side';
|
||||
input.selectionStart = input.value.length;
|
||||
input.focus = vi.fn();
|
||||
input.setSelectionRange = vi.fn();
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
autoComplete.searchType = 'custom_words';
|
||||
autoComplete.activeCommand = null;
|
||||
autoComplete.items = mockTags;
|
||||
autoComplete.selectedIndex = 0;
|
||||
|
||||
await autoComplete.insertSelection('looking_to_the_side');
|
||||
|
||||
expect(input.value).toBe('looking_to_the_side, ');
|
||||
expect(autoComplete.dropdown.style.display).toBe('none');
|
||||
expect(input.focus).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('replaces only last token when typing partial match (e.g., "hello 1gi" -> "1girl")', async () => {
|
||||
const mockTags = [
|
||||
{ tag_name: '1girl', category: 4, post_count: 500000 },
|
||||
{ tag_name: '1boy', category: 4, post_count: 300000 },
|
||||
];
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: mockTags }),
|
||||
});
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('hello 1gi');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = 'hello 1gi';
|
||||
input.selectionStart = input.value.length;
|
||||
input.focus = vi.fn();
|
||||
input.setSelectionRange = vi.fn();
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
autoComplete.searchType = 'custom_words';
|
||||
autoComplete.activeCommand = null;
|
||||
autoComplete.items = mockTags;
|
||||
autoComplete.selectedIndex = 0;
|
||||
autoComplete.currentSearchTerm = 'hello 1gi';
|
||||
|
||||
await autoComplete.insertSelection('1girl');
|
||||
|
||||
expect(input.value).toBe('hello 1girl, ');
|
||||
});
|
||||
|
||||
it('replaces entire phrase for underscore tag match (e.g., "blue hair" -> "blue_hair")', async () => {
|
||||
const mockTags = [
|
||||
{ tag_name: 'blue_hair', category: 0, post_count: 45000 },
|
||||
{ tag_name: 'blue_eyes', category: 0, post_count: 80000 },
|
||||
];
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: mockTags }),
|
||||
});
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('blue hair');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = 'blue hair';
|
||||
input.selectionStart = input.value.length;
|
||||
input.focus = vi.fn();
|
||||
input.setSelectionRange = vi.fn();
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
autoComplete.searchType = 'custom_words';
|
||||
autoComplete.activeCommand = null;
|
||||
autoComplete.items = mockTags;
|
||||
autoComplete.selectedIndex = 0;
|
||||
autoComplete.currentSearchTerm = 'blue hair';
|
||||
|
||||
await autoComplete.insertSelection('blue_hair');
|
||||
|
||||
expect(input.value).toBe('blue_hair, ');
|
||||
});
|
||||
|
||||
it('handles multi-word phrase with preceding text correctly', async () => {
|
||||
const mockTags = [
|
||||
{ tag_name: 'looking_to_the_side', category: 0, post_count: 1234 },
|
||||
];
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: mockTags }),
|
||||
});
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('1girl, looking to the side');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = '1girl, looking to the side';
|
||||
input.selectionStart = input.value.length;
|
||||
input.focus = vi.fn();
|
||||
input.setSelectionRange = vi.fn();
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
autoComplete.searchType = 'custom_words';
|
||||
autoComplete.activeCommand = null;
|
||||
autoComplete.items = mockTags;
|
||||
autoComplete.selectedIndex = 0;
|
||||
autoComplete.currentSearchTerm = 'looking to the side';
|
||||
|
||||
await autoComplete.insertSelection('looking_to_the_side');
|
||||
|
||||
expect(input.value).toBe('1girl, looking_to_the_side, ');
|
||||
});
|
||||
|
||||
it('replaces entire command and search term when using command mode with multi-word phrase', async () => {
|
||||
const mockTags = [
|
||||
{ tag_name: 'looking_to_the_side', category: 4, post_count: 1234 },
|
||||
{ tag_name: 'looking_away', category: 4, post_count: 5678 },
|
||||
];
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: mockTags }),
|
||||
});
|
||||
|
||||
// Simulate "/char looking to the side" input
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('/char looking to the side');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = '/char looking to the side';
|
||||
input.selectionStart = input.value.length;
|
||||
input.focus = vi.fn();
|
||||
input.setSelectionRange = vi.fn();
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
// Set up command mode state
|
||||
autoComplete.searchType = 'custom_words';
|
||||
autoComplete.activeCommand = { categories: [4, 11], label: 'Character' };
|
||||
autoComplete.items = mockTags;
|
||||
autoComplete.selectedIndex = 0;
|
||||
autoComplete.currentSearchTerm = '/char looking to the side';
|
||||
|
||||
await autoComplete.insertSelection('looking_to_the_side');
|
||||
|
||||
// Command part should be replaced along with search term
|
||||
expect(input.value).toBe('looking_to_the_side, ');
|
||||
});
|
||||
|
||||
it('replaces only last token when multi-word query does not exactly match selected tag', async () => {
|
||||
const mockTags = [
|
||||
{ tag_name: 'blue_hair', category: 0, post_count: 45000 },
|
||||
{ tag_name: 'blue_eyes', category: 0, post_count: 80000 },
|
||||
];
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: mockTags }),
|
||||
});
|
||||
|
||||
// User types "looking to the blue" but selects "blue_hair" (doesn't match entire phrase)
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('looking to the blue');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = 'looking to the blue';
|
||||
input.selectionStart = input.value.length;
|
||||
input.focus = vi.fn();
|
||||
input.setSelectionRange = vi.fn();
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
autoComplete.searchType = 'custom_words';
|
||||
autoComplete.activeCommand = null;
|
||||
autoComplete.items = mockTags;
|
||||
autoComplete.selectedIndex = 0;
|
||||
autoComplete.currentSearchTerm = 'looking to the blue';
|
||||
|
||||
await autoComplete.insertSelection('blue_hair');
|
||||
|
||||
// Only "blue" should be replaced, not the entire phrase
|
||||
expect(input.value).toBe('looking to the blue_hair, ');
|
||||
});
|
||||
|
||||
it('handles multiple consecutive spaces in multi-word phrase correctly', async () => {
|
||||
const mockTags = [
|
||||
{ tag_name: 'looking_to_the_side', category: 0, post_count: 1234 },
|
||||
];
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: mockTags }),
|
||||
});
|
||||
|
||||
// Input with multiple spaces between words
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('looking to the side');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = 'looking to the side';
|
||||
input.selectionStart = input.value.length;
|
||||
input.focus = vi.fn();
|
||||
input.setSelectionRange = vi.fn();
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
autoComplete.searchType = 'custom_words';
|
||||
autoComplete.activeCommand = null;
|
||||
autoComplete.items = mockTags;
|
||||
autoComplete.selectedIndex = 0;
|
||||
autoComplete.currentSearchTerm = 'looking to the side';
|
||||
|
||||
await autoComplete.insertSelection('looking_to_the_side');
|
||||
|
||||
// Multiple spaces should be normalized to single underscores for matching
|
||||
expect(input.value).toBe('looking_to_the_side, ');
|
||||
});
|
||||
|
||||
it('handles command mode with partial match replacing only last token', async () => {
|
||||
const mockTags = [
|
||||
{ tag_name: 'blue_hair', category: 0, post_count: 45000 },
|
||||
];
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: mockTags }),
|
||||
});
|
||||
|
||||
// Command mode but selected tag doesn't match entire search phrase
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('/general looking to the blue');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = '/general looking to the blue';
|
||||
input.selectionStart = input.value.length;
|
||||
input.focus = vi.fn();
|
||||
input.setSelectionRange = vi.fn();
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
// Command mode with activeCommand
|
||||
autoComplete.searchType = 'custom_words';
|
||||
autoComplete.activeCommand = { categories: [0, 7], label: 'General' };
|
||||
autoComplete.items = mockTags;
|
||||
autoComplete.selectedIndex = 0;
|
||||
autoComplete.currentSearchTerm = '/general looking to the blue';
|
||||
|
||||
await autoComplete.insertSelection('blue_hair');
|
||||
|
||||
// In command mode, the entire command + search term should be replaced
|
||||
expect(input.value).toBe('blue_hair, ');
|
||||
});
|
||||
|
||||
it('replaces entire phrase when selected tag starts with underscore version of search term (prefix match)', async () => {
|
||||
const mockTags = [
|
||||
{ tag_name: 'looking_to_the_side', category: 0, post_count: 1234 },
|
||||
];
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: mockTags }),
|
||||
});
|
||||
|
||||
// User types partial phrase "looking to the" and selects "looking_to_the_side"
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('looking to the');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = 'looking to the';
|
||||
input.selectionStart = input.value.length;
|
||||
input.focus = vi.fn();
|
||||
input.setSelectionRange = vi.fn();
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
autoComplete.searchType = 'custom_words';
|
||||
autoComplete.activeCommand = null;
|
||||
autoComplete.items = mockTags;
|
||||
autoComplete.selectedIndex = 0;
|
||||
autoComplete.currentSearchTerm = 'looking to the';
|
||||
|
||||
await autoComplete.insertSelection('looking_to_the_side');
|
||||
|
||||
// Entire phrase should be replaced with selected tag (with underscores)
|
||||
expect(input.value).toBe('looking_to_the_side, ');
|
||||
});
|
||||
|
||||
it('inserts tag with underscores regardless of space replacement setting', async () => {
|
||||
const mockTags = [
|
||||
{ tag_name: 'blue_hair', category: 0, post_count: 45000 },
|
||||
];
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: mockTags }),
|
||||
});
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('blue');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = 'blue';
|
||||
input.selectionStart = input.value.length;
|
||||
input.focus = vi.fn();
|
||||
input.setSelectionRange = vi.fn();
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
autoComplete.searchType = 'custom_words';
|
||||
autoComplete.activeCommand = null;
|
||||
autoComplete.items = mockTags;
|
||||
autoComplete.selectedIndex = 0;
|
||||
|
||||
await autoComplete.insertSelection('blue_hair');
|
||||
|
||||
// Tag should be inserted with underscores, not spaces
|
||||
expect(input.value).toBe('blue_hair, ');
|
||||
});
|
||||
|
||||
it('replaces entire phrase when selected tag ends with underscore version of search term (suffix match)', async () => {
|
||||
const mockTags = [
|
||||
{ tag_name: 'looking_to_the_side', category: 0, post_count: 1234 },
|
||||
];
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: mockTags }),
|
||||
});
|
||||
|
||||
// User types suffix "to the side" and selects "looking_to_the_side"
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('to the side');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = 'to the side';
|
||||
input.selectionStart = input.value.length;
|
||||
input.focus = vi.fn();
|
||||
input.setSelectionRange = vi.fn();
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
autoComplete.searchType = 'custom_words';
|
||||
autoComplete.activeCommand = null;
|
||||
autoComplete.items = mockTags;
|
||||
autoComplete.selectedIndex = 0;
|
||||
autoComplete.currentSearchTerm = 'to the side';
|
||||
|
||||
await autoComplete.insertSelection('looking_to_the_side');
|
||||
|
||||
// Entire phrase should be replaced with selected tag
|
||||
expect(input.value).toBe('looking_to_the_side, ');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -194,6 +194,7 @@ class TestCacheHealthMonitor:
|
||||
'preview_nsfw_level': 0,
|
||||
'notes': '',
|
||||
'usage_tips': '',
|
||||
'hash_status': 'completed',
|
||||
}
|
||||
incomplete_entry = {
|
||||
'file_path': '/models/test2.safetensors',
|
||||
|
||||
@@ -369,3 +369,289 @@ async def test_pool_filter_combined_all_filters(lora_service):
|
||||
# - tags: tag1 ✓
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["file_name"] == "match_all.safetensors"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_name_patterns_include_text(lora_service):
|
||||
"""Test filtering by name patterns with text matching (useRegex=False)."""
|
||||
sample_loras = [
|
||||
{
|
||||
"file_name": "character_anime_v1.safetensors",
|
||||
"model_name": "Anime Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "character_realistic_v1.safetensors",
|
||||
"model_name": "Realistic Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "style_watercolor_v1.safetensors",
|
||||
"model_name": "Watercolor Style",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
]
|
||||
|
||||
# Test include patterns with text matching
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||
"namePatterns": {"include": ["character"], "exclude": [], "useRegex": False},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
assert len(filtered) == 2
|
||||
file_names = {lora["file_name"] for lora in filtered}
|
||||
assert file_names == {
|
||||
"character_anime_v1.safetensors",
|
||||
"character_realistic_v1.safetensors",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_name_patterns_exclude_text(lora_service):
|
||||
"""Test excluding by name patterns with text matching (useRegex=False)."""
|
||||
sample_loras = [
|
||||
{
|
||||
"file_name": "character_anime_v1.safetensors",
|
||||
"model_name": "Anime Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "character_realistic_v1.safetensors",
|
||||
"model_name": "Realistic Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "style_watercolor_v1.safetensors",
|
||||
"model_name": "Watercolor Style",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
]
|
||||
|
||||
# Test exclude patterns with text matching
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||
"namePatterns": {"include": [], "exclude": ["anime"], "useRegex": False},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
assert len(filtered) == 2
|
||||
file_names = {lora["file_name"] for lora in filtered}
|
||||
assert file_names == {
|
||||
"character_realistic_v1.safetensors",
|
||||
"style_watercolor_v1.safetensors",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_name_patterns_include_regex(lora_service):
|
||||
"""Test filtering by name patterns with regex matching (useRegex=True)."""
|
||||
sample_loras = [
|
||||
{
|
||||
"file_name": "character_anime_v1.safetensors",
|
||||
"model_name": "Anime Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "character_realistic_v1.safetensors",
|
||||
"model_name": "Realistic Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "style_watercolor_v1.safetensors",
|
||||
"model_name": "Watercolor Style",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
]
|
||||
|
||||
# Test include patterns with regex matching - match files starting with "character_"
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||
"namePatterns": {"include": ["^character_"], "exclude": [], "useRegex": True},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
assert len(filtered) == 2
|
||||
file_names = {lora["file_name"] for lora in filtered}
|
||||
assert file_names == {
|
||||
"character_anime_v1.safetensors",
|
||||
"character_realistic_v1.safetensors",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_name_patterns_exclude_regex(lora_service):
|
||||
"""Test excluding by name patterns with regex matching (useRegex=True)."""
|
||||
sample_loras = [
|
||||
{
|
||||
"file_name": "character_anime_v1.safetensors",
|
||||
"model_name": "Anime Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "character_realistic_v1.safetensors",
|
||||
"model_name": "Realistic Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "style_watercolor_v1.safetensors",
|
||||
"model_name": "Watercolor Style",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
]
|
||||
|
||||
# Test exclude patterns with regex matching - exclude files ending with "_v1.safetensors"
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||
"namePatterns": {
|
||||
"include": [],
|
||||
"exclude": ["_v1\\.safetensors$"],
|
||||
"useRegex": True,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
assert len(filtered) == 0 # All files match the exclude pattern
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_name_patterns_combined(lora_service):
|
||||
"""Test combining include and exclude name patterns."""
|
||||
sample_loras = [
|
||||
{
|
||||
"file_name": "character_anime_v1.safetensors",
|
||||
"model_name": "Anime Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "character_realistic_v1.safetensors",
|
||||
"model_name": "Realistic Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "style_watercolor_v1.safetensors",
|
||||
"model_name": "Watercolor Style",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
]
|
||||
|
||||
# Test include "character" but exclude "anime"
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||
"namePatterns": {
|
||||
"include": ["character"],
|
||||
"exclude": ["anime"],
|
||||
"useRegex": False,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["file_name"] == "character_realistic_v1.safetensors"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_name_patterns_model_name_fallback(lora_service):
|
||||
"""Test that name pattern filtering falls back to model_name when file_name doesn't match."""
|
||||
sample_loras = [
|
||||
{
|
||||
"file_name": "abc123.safetensors",
|
||||
"model_name": "Super Anime Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "def456.safetensors",
|
||||
"model_name": "Realistic Portrait",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
]
|
||||
|
||||
# Should match model_name even if file_name doesn't contain the pattern
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||
"namePatterns": {"include": ["anime"], "exclude": [], "useRegex": False},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["file_name"] == "abc123.safetensors"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_name_patterns_invalid_regex(lora_service):
|
||||
"""Test that invalid regex falls back to substring matching."""
|
||||
sample_loras = [
|
||||
{
|
||||
"file_name": "character_anime[test]_v1.safetensors",
|
||||
"model_name": "Anime Character",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
]
|
||||
|
||||
# Invalid regex pattern (unclosed character class) should fall back to substring matching
|
||||
# The pattern "anime[" is invalid regex but valid substring - it exists in the filename
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {"noCreditRequired": False, "allowSelling": False},
|
||||
"namePatterns": {"include": ["anime["], "exclude": [], "useRegex": True},
|
||||
}
|
||||
|
||||
# Should not crash and should match using substring fallback
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
assert len(filtered) == 1 # Substring match works even with invalid regex
|
||||
|
||||
158
tests/test_checkpoint_loaders.py
Normal file
158
tests/test_checkpoint_loaders.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Tests for checkpoint and unet loaders with extra folder paths support"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
|
||||
|
||||
# Get project root directory (ComfyUI-Lora-Manager folder)
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class TestCheckpointLoaderLM:
|
||||
"""Test CheckpointLoaderLM node"""
|
||||
|
||||
def test_class_attributes(self):
|
||||
"""Test that CheckpointLoaderLM has required class attributes"""
|
||||
# Import in a way that doesn't require ComfyUI
|
||||
import ast
|
||||
|
||||
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "checkpoint_loader.py")
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
tree = ast.parse(f.read())
|
||||
|
||||
# Find CheckpointLoaderLM class
|
||||
classes = {
|
||||
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
|
||||
}
|
||||
assert "CheckpointLoaderLM" in classes
|
||||
|
||||
cls = classes["CheckpointLoaderLM"]
|
||||
|
||||
# Check for NAME attribute
|
||||
name_attr = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.Assign)
|
||||
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
|
||||
]
|
||||
assert len(name_attr) > 0, "CheckpointLoaderLM should have NAME attribute"
|
||||
|
||||
# Check for CATEGORY attribute
|
||||
cat_attr = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.Assign)
|
||||
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
|
||||
]
|
||||
assert len(cat_attr) > 0, "CheckpointLoaderLM should have CATEGORY attribute"
|
||||
|
||||
# Check for INPUT_TYPES method
|
||||
input_types = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
|
||||
]
|
||||
assert len(input_types) > 0, "CheckpointLoaderLM should have INPUT_TYPES method"
|
||||
|
||||
# Check for load_checkpoint method
|
||||
load_method = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "load_checkpoint"
|
||||
]
|
||||
assert len(load_method) > 0, (
|
||||
"CheckpointLoaderLM should have load_checkpoint method"
|
||||
)
|
||||
|
||||
|
||||
class TestUNETLoaderLM:
|
||||
"""Test UNETLoaderLM node"""
|
||||
|
||||
def test_class_attributes(self):
|
||||
"""Test that UNETLoaderLM has required class attributes"""
|
||||
# Import in a way that doesn't require ComfyUI
|
||||
import ast
|
||||
|
||||
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "unet_loader.py")
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
tree = ast.parse(f.read())
|
||||
|
||||
# Find UNETLoaderLM class
|
||||
classes = {
|
||||
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
|
||||
}
|
||||
assert "UNETLoaderLM" in classes
|
||||
|
||||
cls = classes["UNETLoaderLM"]
|
||||
|
||||
# Check for NAME attribute
|
||||
name_attr = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.Assign)
|
||||
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
|
||||
]
|
||||
assert len(name_attr) > 0, "UNETLoaderLM should have NAME attribute"
|
||||
|
||||
# Check for CATEGORY attribute
|
||||
cat_attr = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.Assign)
|
||||
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
|
||||
]
|
||||
assert len(cat_attr) > 0, "UNETLoaderLM should have CATEGORY attribute"
|
||||
|
||||
# Check for INPUT_TYPES method
|
||||
input_types = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
|
||||
]
|
||||
assert len(input_types) > 0, "UNETLoaderLM should have INPUT_TYPES method"
|
||||
|
||||
# Check for load_unet method
|
||||
load_method = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "load_unet"
|
||||
]
|
||||
assert len(load_method) > 0, "UNETLoaderLM should have load_unet method"
|
||||
|
||||
|
||||
class TestUtils:
|
||||
"""Test utility functions"""
|
||||
|
||||
def test_get_checkpoint_info_absolute_exists(self):
|
||||
"""Test that get_checkpoint_info_absolute function exists in utils"""
|
||||
import ast
|
||||
|
||||
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
tree = ast.parse(f.read())
|
||||
|
||||
functions = [
|
||||
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
|
||||
]
|
||||
assert "get_checkpoint_info_absolute" in functions, (
|
||||
"get_checkpoint_info_absolute should exist"
|
||||
)
|
||||
|
||||
def test_format_model_name_for_comfyui_exists(self):
|
||||
"""Test that _format_model_name_for_comfyui function exists in utils"""
|
||||
import ast
|
||||
|
||||
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
tree = ast.parse(f.read())
|
||||
|
||||
functions = [
|
||||
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
|
||||
]
|
||||
assert "_format_model_name_for_comfyui" in functions, (
|
||||
"_format_model_name_for_comfyui should exist"
|
||||
)
|
||||
@@ -242,36 +242,70 @@ class TestTagFTSIndexSearch:
|
||||
)
|
||||
|
||||
def test_search_pagination_ordering_consistency(self, populated_fts):
|
||||
"""Test that pagination maintains consistent ordering."""
|
||||
"""Test that pagination maintains consistent ordering by post_count."""
|
||||
page1 = populated_fts.search("1", limit=10, offset=0)
|
||||
page2 = populated_fts.search("1", limit=10, offset=10)
|
||||
|
||||
assert len(page1) > 0, "Page 1 should have results"
|
||||
assert len(page2) > 0, "Page 2 should have results"
|
||||
|
||||
# Page 2 scores should all be <= Page 1 min score
|
||||
page1_min_score = min(r["rank_score"] for r in page1)
|
||||
page2_max_score = max(r["rank_score"] for r in page2)
|
||||
# Page 2 max post_count should be <= Page 1 min post_count
|
||||
page1_min_posts = min(r["post_count"] for r in page1)
|
||||
page2_max_posts = max(r["post_count"] for r in page2)
|
||||
|
||||
assert page2_max_score <= page1_min_score, (
|
||||
f"Page 2 max score ({page2_max_score}) should be <= Page 1 min score ({page1_min_score})"
|
||||
assert page2_max_posts <= page1_min_posts, (
|
||||
f"Page 2 max post_count ({page2_max_posts}) should be <= Page 1 min post_count ({page1_min_posts})"
|
||||
)
|
||||
|
||||
def test_search_rank_score_includes_popularity_weight(self, populated_fts):
|
||||
"""Test that rank_score includes post_count popularity weighting."""
|
||||
def test_search_returns_popular_tags_higher(self, populated_fts):
|
||||
"""Test that search returns popular tags (higher post_count) first."""
|
||||
results = populated_fts.search("1", limit=5)
|
||||
|
||||
assert len(results) >= 2, "Need at least 2 results to compare"
|
||||
|
||||
# 1girl has 6M posts, should have higher rank_score than tags with fewer posts
|
||||
# 1girl has 6M posts, should be ranked first
|
||||
girl_result = next((r for r in results if r["tag_name"] == "1girl"), None)
|
||||
assert girl_result is not None, "1girl should be in results"
|
||||
assert results[0]["tag_name"] == "1girl", (
|
||||
"1girl should be first due to highest post_count"
|
||||
)
|
||||
|
||||
# Find a tag with significantly fewer posts
|
||||
low_post_result = next((r for r in results if r["post_count"] < 10000), None)
|
||||
if low_post_result:
|
||||
assert girl_result["rank_score"] > low_post_result["rank_score"], (
|
||||
f"1girl (6M posts) should have higher score than {low_post_result['tag_name']} ({low_post_result['post_count']} posts)"
|
||||
assert girl_result["post_count"] > low_post_result["post_count"], (
|
||||
f"1girl (6M posts) should have higher post_count than {low_post_result['tag_name']} ({low_post_result['post_count']} posts)"
|
||||
)
|
||||
|
||||
def test_search_popularity_ordering(self, populated_fts):
|
||||
"""Test that results are ordered by post_count (popularity)."""
|
||||
results = populated_fts.search("1", limit=20)
|
||||
|
||||
# Get 1girl and 1boy results for comparison
|
||||
girl_result = next((r for r in results if r["tag_name"] == "1girl"), None)
|
||||
boy_result = next((r for r in results if r["tag_name"] == "1boy"), None)
|
||||
|
||||
assert girl_result is not None, "1girl should be in results"
|
||||
assert boy_result is not None, "1boy should be in results"
|
||||
|
||||
# 1girl: 6M posts, 1boy: 1.4M posts
|
||||
assert girl_result["post_count"] == 6008644, "1girl should have 6M posts"
|
||||
assert boy_result["post_count"] == 1405457, "1boy should have 1.4M posts"
|
||||
|
||||
# 1girl should rank higher due to higher post_count
|
||||
girl_rank = results.index(girl_result)
|
||||
boy_rank = results.index(boy_result)
|
||||
assert girl_rank < boy_rank, (
|
||||
f"1girl should rank higher than 1boy due to higher post_count "
|
||||
f"(girl rank: {girl_rank}, boy rank: {boy_rank})"
|
||||
)
|
||||
|
||||
# Verify results are sorted by post_count descending
|
||||
for i in range(len(results) - 1):
|
||||
assert results[i]["post_count"] >= results[i + 1]["post_count"], (
|
||||
f"Results should be sorted by post_count descending: "
|
||||
f"{results[i]['tag_name']} ({results[i]['post_count']}) >= "
|
||||
f"{results[i + 1]['tag_name']} ({results[i + 1]['post_count']})"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
<div class="lora-cycler-widget">
|
||||
<LoraCyclerSettingsView
|
||||
:current-index="state.currentIndex.value"
|
||||
:total-count="state.totalCount.value"
|
||||
:current-lora-name="state.currentLoraName.value"
|
||||
:total-count="displayTotalCount"
|
||||
:current-lora-name="displayLoraName"
|
||||
:current-lora-filename="state.currentLoraFilename.value"
|
||||
:model-strength="state.modelStrength.value"
|
||||
:clip-strength="state.clipStrength.value"
|
||||
@@ -16,11 +16,14 @@
|
||||
:is-pause-disabled="hasQueuedPrompts"
|
||||
:is-workflow-executing="state.isWorkflowExecuting.value"
|
||||
:executing-repeat-step="state.executingRepeatStep.value"
|
||||
:include-no-lora="state.includeNoLora.value"
|
||||
:is-no-lora="isNoLora"
|
||||
@update:current-index="handleIndexUpdate"
|
||||
@update:model-strength="state.modelStrength.value = $event"
|
||||
@update:clip-strength="state.clipStrength.value = $event"
|
||||
@update:use-custom-clip-range="handleUseCustomClipRangeChange"
|
||||
@update:repeat-count="handleRepeatCountChange"
|
||||
@update:include-no-lora="handleIncludeNoLoraChange"
|
||||
@toggle-pause="handleTogglePause"
|
||||
@reset-index="handleResetIndex"
|
||||
@open-lora-selector="isModalOpen = true"
|
||||
@@ -30,6 +33,7 @@
|
||||
:visible="isModalOpen"
|
||||
:lora-list="cachedLoraList"
|
||||
:current-index="state.currentIndex.value"
|
||||
:include-no-lora="state.includeNoLora.value"
|
||||
@close="isModalOpen = false"
|
||||
@select="handleModalSelect"
|
||||
/>
|
||||
@@ -37,7 +41,7 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { onMounted, ref } from 'vue'
|
||||
import { onMounted, ref, computed } from 'vue'
|
||||
import LoraCyclerSettingsView from './lora-cycler/LoraCyclerSettingsView.vue'
|
||||
import LoraListModal from './lora-cycler/LoraListModal.vue'
|
||||
import { useLoraCyclerState } from '../composables/useLoraCyclerState'
|
||||
@@ -102,6 +106,31 @@ const isModalOpen = ref(false)
|
||||
// Cache for LoRA list (used by modal)
|
||||
const cachedLoraList = ref<LoraItem[]>([])
|
||||
|
||||
// Computed: display total count (includes no lora option if enabled)
|
||||
const displayTotalCount = computed(() => {
|
||||
const baseCount = state.totalCount.value
|
||||
return state.includeNoLora.value ? baseCount + 1 : baseCount
|
||||
})
|
||||
|
||||
// Computed: display LoRA name (shows "No LoRA" if on the last index and includeNoLora is enabled)
|
||||
const displayLoraName = computed(() => {
|
||||
const currentIndex = state.currentIndex.value
|
||||
const totalCount = state.totalCount.value
|
||||
|
||||
// If includeNoLora is enabled and we're on the last position (no lora slot)
|
||||
if (state.includeNoLora.value && currentIndex === totalCount + 1) {
|
||||
return 'No LoRA'
|
||||
}
|
||||
|
||||
// Otherwise show the normal LoRA name
|
||||
return state.currentLoraName.value
|
||||
})
|
||||
|
||||
// Computed: check if currently on "No LoRA" option
|
||||
const isNoLora = computed(() => {
|
||||
return state.includeNoLora.value && state.currentIndex.value === state.totalCount.value + 1
|
||||
})
|
||||
|
||||
// Get pool config from connected node
|
||||
const getPoolConfig = (): LoraPoolConfig | null => {
|
||||
// Check if getPoolConfig method exists on node (added by main.ts)
|
||||
@@ -113,7 +142,17 @@ const getPoolConfig = (): LoraPoolConfig | null => {
|
||||
|
||||
// Update display from LoRA list and index
|
||||
const updateDisplayFromLoraList = (loraList: LoraItem[], index: number) => {
|
||||
if (loraList.length > 0 && index > 0 && index <= loraList.length) {
|
||||
const actualLoraCount = loraList.length
|
||||
|
||||
// If index is beyond actual LoRA count, it means we're on the "no lora" option
|
||||
if (state.includeNoLora.value && index === actualLoraCount + 1) {
|
||||
state.currentLoraName.value = 'No LoRA'
|
||||
state.currentLoraFilename.value = 'No LoRA'
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, show normal LoRA info
|
||||
if (actualLoraCount > 0 && index > 0 && index <= actualLoraCount) {
|
||||
const currentLora = loraList[index - 1]
|
||||
if (currentLora) {
|
||||
state.currentLoraName.value = currentLora.file_name
|
||||
@@ -124,6 +163,14 @@ const updateDisplayFromLoraList = (loraList: LoraItem[], index: number) => {
|
||||
|
||||
// Handle index update from user
|
||||
const handleIndexUpdate = async (newIndex: number) => {
|
||||
// Calculate max valid index (includes no lora slot if enabled)
|
||||
const maxIndex = state.includeNoLora.value
|
||||
? state.totalCount.value + 1
|
||||
: state.totalCount.value
|
||||
|
||||
// Clamp index to valid range
|
||||
const clampedIndex = Math.max(1, Math.min(newIndex, maxIndex || 1))
|
||||
|
||||
// Reset execution state when user manually changes index
|
||||
// This ensures the next execution starts from the user-set index
|
||||
;(props.widget as any)[HAS_EXECUTED] = false
|
||||
@@ -134,14 +181,14 @@ const handleIndexUpdate = async (newIndex: number) => {
|
||||
executionQueue.length = 0
|
||||
hasQueuedPrompts.value = false
|
||||
|
||||
state.setIndex(newIndex)
|
||||
state.setIndex(clampedIndex)
|
||||
|
||||
// Refresh list to update current LoRA display
|
||||
try {
|
||||
const poolConfig = getPoolConfig()
|
||||
const loraList = await state.fetchCyclerList(poolConfig)
|
||||
cachedLoraList.value = loraList
|
||||
updateDisplayFromLoraList(loraList, newIndex)
|
||||
updateDisplayFromLoraList(loraList, clampedIndex)
|
||||
} catch (error) {
|
||||
console.error('[LoraCyclerWidget] Error updating index:', error)
|
||||
}
|
||||
@@ -169,6 +216,17 @@ const handleRepeatCountChange = (newValue: number) => {
|
||||
state.displayRepeatUsed.value = 0
|
||||
}
|
||||
|
||||
// Handle include no lora toggle
|
||||
const handleIncludeNoLoraChange = (newValue: boolean) => {
|
||||
state.includeNoLora.value = newValue
|
||||
|
||||
// If turning off and current index is beyond the actual LoRA count,
|
||||
// clamp it to the last valid LoRA index
|
||||
if (!newValue && state.currentIndex.value > state.totalCount.value) {
|
||||
state.currentIndex.value = Math.max(1, state.totalCount.value)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle pause toggle
|
||||
const handleTogglePause = () => {
|
||||
state.togglePause()
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
:exclude-tags="state.excludeTags.value"
|
||||
:include-folders="state.includeFolders.value"
|
||||
:exclude-folders="state.excludeFolders.value"
|
||||
:include-patterns="state.includePatterns.value"
|
||||
:exclude-patterns="state.excludePatterns.value"
|
||||
:use-regex="state.useRegex.value"
|
||||
:no-credit-required="state.noCreditRequired.value"
|
||||
:allow-selling="state.allowSelling.value"
|
||||
:preview-items="state.previewItems.value"
|
||||
@@ -16,6 +19,9 @@
|
||||
@open-modal="openModal"
|
||||
@update:include-folders="state.includeFolders.value = $event"
|
||||
@update:exclude-folders="state.excludeFolders.value = $event"
|
||||
@update:include-patterns="state.includePatterns.value = $event"
|
||||
@update:exclude-patterns="state.excludePatterns.value = $event"
|
||||
@update:use-regex="state.useRegex.value = $event"
|
||||
@update:no-credit-required="state.noCreditRequired.value = $event"
|
||||
@update:allow-selling="state.allowSelling.value = $event"
|
||||
@refresh="state.refreshPreview"
|
||||
|
||||
@@ -13,7 +13,9 @@
|
||||
@click="handleOpenSelector"
|
||||
>
|
||||
<span class="progress-label">{{ isWorkflowExecuting ? 'Using LoRA:' : 'Next LoRA:' }}</span>
|
||||
<span class="progress-name clickable" :class="{ disabled: isPauseDisabled }" :title="currentLoraFilename">
|
||||
<span class="progress-name clickable"
|
||||
:class="{ disabled: isPauseDisabled, 'no-lora': isNoLora }"
|
||||
:title="currentLoraFilename">
|
||||
{{ currentLoraName || 'None' }}
|
||||
<svg class="selector-icon" viewBox="0 0 24 24" fill="currentColor">
|
||||
<path d="M7 10l5 5 5-5z"/>
|
||||
@@ -160,6 +162,27 @@
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Include No LoRA Toggle -->
|
||||
<div class="setting-section">
|
||||
<div class="section-header-with-toggle">
|
||||
<label class="setting-label">
|
||||
Add "No LoRA" step
|
||||
</label>
|
||||
<button
|
||||
type="button"
|
||||
class="toggle-switch"
|
||||
:class="{ 'toggle-switch--active': includeNoLora }"
|
||||
@click="$emit('update:includeNoLora', !includeNoLora)"
|
||||
role="switch"
|
||||
:aria-checked="includeNoLora"
|
||||
title="Add an iteration without LoRA for comparison"
|
||||
>
|
||||
<span class="toggle-switch__track"></span>
|
||||
<span class="toggle-switch__thumb"></span>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -182,6 +205,8 @@ const props = defineProps<{
|
||||
isPauseDisabled: boolean
|
||||
isWorkflowExecuting: boolean
|
||||
executingRepeatStep: number
|
||||
includeNoLora: boolean
|
||||
isNoLora?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
@@ -190,6 +215,7 @@ const emit = defineEmits<{
|
||||
'update:clipStrength': [value: number]
|
||||
'update:useCustomClipRange': [value: boolean]
|
||||
'update:repeatCount': [value: number]
|
||||
'update:includeNoLora': [value: boolean]
|
||||
'toggle-pause': []
|
||||
'reset-index': []
|
||||
'open-lora-selector': []
|
||||
@@ -346,6 +372,16 @@ const onRepeatBlur = (event: Event) => {
|
||||
color: rgba(191, 219, 254, 1);
|
||||
}
|
||||
|
||||
.progress-name.no-lora {
|
||||
font-style: italic;
|
||||
color: rgba(226, 232, 240, 0.6);
|
||||
}
|
||||
|
||||
.progress-name.clickable.no-lora:hover:not(.disabled) {
|
||||
background: rgba(160, 174, 192, 0.2);
|
||||
color: rgba(226, 232, 240, 0.8);
|
||||
}
|
||||
|
||||
.progress-name.clickable.disabled {
|
||||
cursor: not-allowed;
|
||||
opacity: 0.5;
|
||||
|
||||
@@ -35,7 +35,10 @@
|
||||
v-for="item in filteredList"
|
||||
:key="item.index"
|
||||
class="lora-item"
|
||||
:class="{ active: currentIndex === item.index }"
|
||||
:class="{
|
||||
active: currentIndex === item.index,
|
||||
'no-lora-item': item.lora.file_name === 'No LoRA'
|
||||
}"
|
||||
@mouseenter="showPreview(item.lora.file_name, $event)"
|
||||
@mouseleave="hidePreview"
|
||||
@click="selectLora(item.index)"
|
||||
@@ -65,6 +68,7 @@ const props = defineProps<{
|
||||
visible: boolean
|
||||
loraList: LoraItem[]
|
||||
currentIndex: number
|
||||
includeNoLora?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
@@ -79,7 +83,8 @@ const searchInputRef = ref<HTMLInputElement | null>(null)
|
||||
let previewTooltip: any = null
|
||||
|
||||
const subtitleText = computed(() => {
|
||||
const total = props.loraList.length
|
||||
const baseTotal = props.loraList.length
|
||||
const total = props.includeNoLora ? baseTotal + 1 : baseTotal
|
||||
const filtered = filteredList.value.length
|
||||
if (filtered === total) {
|
||||
return `Total: ${total} LoRA${total !== 1 ? 's' : ''}`
|
||||
@@ -88,11 +93,19 @@ const subtitleText = computed(() => {
|
||||
})
|
||||
|
||||
const filteredList = computed<LoraListItem[]>(() => {
|
||||
const list = props.loraList.map((lora, idx) => ({
|
||||
const list: LoraListItem[] = props.loraList.map((lora, idx) => ({
|
||||
index: idx + 1,
|
||||
lora
|
||||
}))
|
||||
|
||||
// Add "No LoRA" option at the end if includeNoLora is enabled
|
||||
if (props.includeNoLora) {
|
||||
list.push({
|
||||
index: list.length + 1,
|
||||
lora: { file_name: 'No LoRA' } as LoraItem
|
||||
})
|
||||
}
|
||||
|
||||
if (!searchQuery.value.trim()) {
|
||||
return list
|
||||
}
|
||||
@@ -303,6 +316,15 @@ onUnmounted(() => {
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.lora-item.no-lora-item .lora-name {
|
||||
font-style: italic;
|
||||
color: rgba(226, 232, 240, 0.6);
|
||||
}
|
||||
|
||||
.lora-item.no-lora-item:hover .lora-name {
|
||||
color: rgba(226, 232, 240, 0.8);
|
||||
}
|
||||
|
||||
.no-results {
|
||||
padding: 32px 20px;
|
||||
text-align: center;
|
||||
|
||||
@@ -24,6 +24,15 @@
|
||||
@edit-exclude="$emit('open-modal', 'excludeFolders')"
|
||||
/>
|
||||
|
||||
<NamePatternsSection
|
||||
:include-patterns="includePatterns"
|
||||
:exclude-patterns="excludePatterns"
|
||||
:use-regex="useRegex"
|
||||
@update:include-patterns="$emit('update:includePatterns', $event)"
|
||||
@update:exclude-patterns="$emit('update:excludePatterns', $event)"
|
||||
@update:use-regex="$emit('update:useRegex', $event)"
|
||||
/>
|
||||
|
||||
<LicenseSection
|
||||
:no-credit-required="noCreditRequired"
|
||||
:allow-selling="allowSelling"
|
||||
@@ -46,6 +55,7 @@
|
||||
import BaseModelSection from './sections/BaseModelSection.vue'
|
||||
import TagsSection from './sections/TagsSection.vue'
|
||||
import FoldersSection from './sections/FoldersSection.vue'
|
||||
import NamePatternsSection from './sections/NamePatternsSection.vue'
|
||||
import LicenseSection from './sections/LicenseSection.vue'
|
||||
import LoraPoolPreview from './LoraPoolPreview.vue'
|
||||
import type { BaseModelOption, LoraItem } from '../../composables/types'
|
||||
@@ -61,6 +71,10 @@ defineProps<{
|
||||
// Folders
|
||||
includeFolders: string[]
|
||||
excludeFolders: string[]
|
||||
// Name patterns
|
||||
includePatterns: string[]
|
||||
excludePatterns: string[]
|
||||
useRegex: boolean
|
||||
// License
|
||||
noCreditRequired: boolean
|
||||
allowSelling: boolean
|
||||
@@ -74,6 +88,9 @@ defineEmits<{
|
||||
'open-modal': [modal: ModalType]
|
||||
'update:includeFolders': [value: string[]]
|
||||
'update:excludeFolders': [value: string[]]
|
||||
'update:includePatterns': [value: string[]]
|
||||
'update:excludePatterns': [value: string[]]
|
||||
'update:useRegex': [value: boolean]
|
||||
'update:noCreditRequired': [value: boolean]
|
||||
'update:allowSelling': [value: boolean]
|
||||
refresh: []
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
<template>
|
||||
<div class="section">
|
||||
<div class="section__header">
|
||||
<span class="section__title">NAME PATTERNS</span>
|
||||
<label class="section__toggle">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="useRegex"
|
||||
@change="$emit('update:useRegex', ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="section__toggle-label">Use Regex</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="section__columns">
|
||||
<!-- Include column -->
|
||||
<div class="section__column">
|
||||
<div class="section__column-header">
|
||||
<span class="section__column-title section__column-title--include">INCLUDE</span>
|
||||
</div>
|
||||
<div class="section__input-wrapper">
|
||||
<input
|
||||
type="text"
|
||||
v-model="includeInput"
|
||||
:placeholder="useRegex ? 'Add regex pattern...' : 'Add text pattern...'"
|
||||
class="section__input"
|
||||
@keydown.enter="addInclude"
|
||||
/>
|
||||
<button type="button" class="section__add-btn" @click="addInclude">+</button>
|
||||
</div>
|
||||
<div class="section__patterns">
|
||||
<FilterChip
|
||||
v-for="pattern in includePatterns"
|
||||
:key="pattern"
|
||||
:label="pattern"
|
||||
variant="include"
|
||||
removable
|
||||
@remove="removeInclude(pattern)"
|
||||
/>
|
||||
<div v-if="includePatterns.length === 0" class="section__empty">
|
||||
{{ useRegex ? 'No regex patterns' : 'No text patterns' }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Exclude column -->
|
||||
<div class="section__column">
|
||||
<div class="section__column-header">
|
||||
<span class="section__column-title section__column-title--exclude">EXCLUDE</span>
|
||||
</div>
|
||||
<div class="section__input-wrapper">
|
||||
<input
|
||||
type="text"
|
||||
v-model="excludeInput"
|
||||
:placeholder="useRegex ? 'Add regex pattern...' : 'Add text pattern...'"
|
||||
class="section__input"
|
||||
@keydown.enter="addExclude"
|
||||
/>
|
||||
<button type="button" class="section__add-btn" @click="addExclude">+</button>
|
||||
</div>
|
||||
<div class="section__patterns">
|
||||
<FilterChip
|
||||
v-for="pattern in excludePatterns"
|
||||
:key="pattern"
|
||||
:label="pattern"
|
||||
variant="exclude"
|
||||
removable
|
||||
@remove="removeExclude(pattern)"
|
||||
/>
|
||||
<div v-if="excludePatterns.length === 0" class="section__empty">
|
||||
{{ useRegex ? 'No regex patterns' : 'No text patterns' }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
import FilterChip from '../shared/FilterChip.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
includePatterns: string[]
|
||||
excludePatterns: string[]
|
||||
useRegex: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:includePatterns': [value: string[]]
|
||||
'update:excludePatterns': [value: string[]]
|
||||
'update:useRegex': [value: boolean]
|
||||
}>()
|
||||
|
||||
const includeInput = ref('')
|
||||
const excludeInput = ref('')
|
||||
|
||||
const addInclude = () => {
|
||||
const pattern = includeInput.value.trim()
|
||||
if (pattern && !props.includePatterns.includes(pattern)) {
|
||||
emit('update:includePatterns', [...props.includePatterns, pattern])
|
||||
includeInput.value = ''
|
||||
}
|
||||
}
|
||||
|
||||
const addExclude = () => {
|
||||
const pattern = excludeInput.value.trim()
|
||||
if (pattern && !props.excludePatterns.includes(pattern)) {
|
||||
emit('update:excludePatterns', [...props.excludePatterns, pattern])
|
||||
excludeInput.value = ''
|
||||
}
|
||||
}
|
||||
|
||||
const removeInclude = (pattern: string) => {
|
||||
emit('update:includePatterns', props.includePatterns.filter(p => p !== pattern))
|
||||
}
|
||||
|
||||
const removeExclude = (pattern: string) => {
|
||||
emit('update:excludePatterns', props.excludePatterns.filter(p => p !== pattern))
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.section {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.section__header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.section__title {
|
||||
font-size: 10px;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
color: var(--fg-color, #fff);
|
||||
opacity: 0.6;
|
||||
}
|
||||
|
||||
.section__toggle {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
cursor: pointer;
|
||||
font-size: 11px;
|
||||
color: var(--fg-color, #fff);
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.section__toggle input[type="checkbox"] {
|
||||
margin: 0;
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.section__toggle-label {
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.section__columns {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.section__column {
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.section__column-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 6px;
|
||||
}
|
||||
|
||||
.section__column-title {
|
||||
font-size: 9px;
|
||||
font-weight: 500;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.03em;
|
||||
}
|
||||
|
||||
.section__column-title--include {
|
||||
color: #4299e1;
|
||||
}
|
||||
|
||||
.section__column-title--exclude {
|
||||
color: #ef4444;
|
||||
}
|
||||
|
||||
.section__input-wrapper {
|
||||
display: flex;
|
||||
gap: 4px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.section__input {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
padding: 6px 8px;
|
||||
background: var(--comfy-input-bg, #333);
|
||||
border: 1px solid var(--comfy-input-border, #444);
|
||||
border-radius: 4px;
|
||||
color: var(--fg-color, #fff);
|
||||
font-size: 12px;
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.section__input:focus {
|
||||
border-color: #4299e1;
|
||||
}
|
||||
|
||||
.section__add-btn {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background: var(--comfy-input-bg, #333);
|
||||
border: 1px solid var(--comfy-input-border, #444);
|
||||
border-radius: 4px;
|
||||
color: var(--fg-color, #fff);
|
||||
font-size: 16px;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.15s;
|
||||
}
|
||||
|
||||
.section__add-btn:hover {
|
||||
background: var(--comfy-input-bg-hover, #444);
|
||||
border-color: #4299e1;
|
||||
}
|
||||
|
||||
.section__patterns {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 4px;
|
||||
min-height: 22px;
|
||||
}
|
||||
|
||||
.section__empty {
|
||||
font-size: 10px;
|
||||
color: var(--fg-color, #fff);
|
||||
opacity: 0.3;
|
||||
font-style: italic;
|
||||
min-height: 22px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
</style>
|
||||
@@ -10,6 +10,12 @@ export interface LoraPoolConfig {
|
||||
noCreditRequired: boolean
|
||||
allowSelling: boolean
|
||||
}
|
||||
namePatterns: {
|
||||
include: string[]
|
||||
exclude: string[]
|
||||
useRegex: boolean
|
||||
}
|
||||
includeEmptyLora?: boolean // Optional, deprecated (moved to Cycler)
|
||||
}
|
||||
preview: { matchCount: number; lastUpdated: number }
|
||||
}
|
||||
@@ -84,6 +90,8 @@ export interface CyclerConfig {
|
||||
repeat_count: number // How many times each LoRA should repeat (default: 1)
|
||||
repeat_used: number // How many times current index has been used
|
||||
is_paused: boolean // Whether iteration is paused
|
||||
// Include "no LoRA" option in cycle
|
||||
include_no_lora: boolean // Whether to include empty LoRA option
|
||||
}
|
||||
|
||||
// Widget config union type
|
||||
|
||||
@@ -4,6 +4,7 @@ import type { ComponentWidget, CyclerConfig, LoraPoolConfig } from './types'
|
||||
export interface CyclerLoraItem {
|
||||
file_name: string
|
||||
model_name: string
|
||||
file_path: string
|
||||
}
|
||||
|
||||
export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
@@ -34,6 +35,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
const repeatUsed = ref(0) // How many times current index has been used (internal tracking)
|
||||
const displayRepeatUsed = ref(0) // For UI display, deferred updates like currentIndex
|
||||
const isPaused = ref(false) // Whether iteration is paused
|
||||
const includeNoLora = ref(false) // Whether to include empty LoRA option in cycle
|
||||
|
||||
// Execution progress tracking (visual feedback)
|
||||
const isWorkflowExecuting = ref(false) // Workflow is currently running
|
||||
@@ -58,6 +60,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
repeat_count: repeatCount.value,
|
||||
repeat_used: repeatUsed.value,
|
||||
is_paused: isPaused.value,
|
||||
include_no_lora: includeNoLora.value,
|
||||
}
|
||||
}
|
||||
return {
|
||||
@@ -75,6 +78,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
repeat_count: repeatCount.value,
|
||||
repeat_used: repeatUsed.value,
|
||||
is_paused: isPaused.value,
|
||||
include_no_lora: includeNoLora.value,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,12 +97,13 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
sortBy.value = config.sort_by || 'filename'
|
||||
currentLoraName.value = config.current_lora_name || ''
|
||||
currentLoraFilename.value = config.current_lora_filename || ''
|
||||
// Advanced index control features
|
||||
repeatCount.value = config.repeat_count ?? 1
|
||||
repeatUsed.value = config.repeat_used ?? 0
|
||||
isPaused.value = config.is_paused ?? false
|
||||
// Note: execution_index and next_index are not restored from config
|
||||
// as they are transient values used only during batch execution
|
||||
// Advanced index control features
|
||||
repeatCount.value = config.repeat_count ?? 1
|
||||
repeatUsed.value = config.repeat_used ?? 0
|
||||
isPaused.value = config.is_paused ?? false
|
||||
includeNoLora.value = config.include_no_lora ?? false
|
||||
// Note: execution_index and next_index are not restored from config
|
||||
// as they are transient values used only during batch execution
|
||||
} finally {
|
||||
isRestoring = false
|
||||
}
|
||||
@@ -111,7 +116,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
// Calculate the next index (wrap to 1 if at end)
|
||||
const current = executionIndex.value ?? currentIndex.value
|
||||
let next = current + 1
|
||||
if (totalCount.value > 0 && next > totalCount.value) {
|
||||
// Total count includes no lora option if enabled
|
||||
const effectiveTotalCount = includeNoLora.value ? totalCount.value + 1 : totalCount.value
|
||||
if (effectiveTotalCount > 0 && next > effectiveTotalCount) {
|
||||
next = 1
|
||||
}
|
||||
nextIndex.value = next
|
||||
@@ -122,7 +129,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
if (nextIndex.value === null) {
|
||||
// First execution uses current_index, so next is current + 1
|
||||
let next = currentIndex.value + 1
|
||||
if (totalCount.value > 0 && next > totalCount.value) {
|
||||
// Total count includes no lora option if enabled
|
||||
const effectiveTotalCount = includeNoLora.value ? totalCount.value + 1 : totalCount.value
|
||||
if (effectiveTotalCount > 0 && next > effectiveTotalCount) {
|
||||
next = 1
|
||||
}
|
||||
nextIndex.value = next
|
||||
@@ -230,7 +239,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
|
||||
// Set index manually
|
||||
const setIndex = (index: number) => {
|
||||
if (index >= 1 && index <= totalCount.value) {
|
||||
// Total count includes no lora option if enabled
|
||||
const effectiveTotalCount = includeNoLora.value ? totalCount.value + 1 : totalCount.value
|
||||
if (index >= 1 && index <= effectiveTotalCount) {
|
||||
currentIndex.value = index
|
||||
}
|
||||
}
|
||||
@@ -272,6 +283,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
repeatCount,
|
||||
repeatUsed,
|
||||
isPaused,
|
||||
includeNoLora,
|
||||
], () => {
|
||||
widget.value = buildConfig()
|
||||
}, { deep: true })
|
||||
@@ -294,6 +306,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
repeatUsed,
|
||||
displayRepeatUsed,
|
||||
isPaused,
|
||||
includeNoLora,
|
||||
isWorkflowExecuting,
|
||||
executingRepeatStep,
|
||||
|
||||
|
||||
@@ -62,6 +62,9 @@ export function useLoraPoolApi() {
|
||||
foldersExclude?: string[]
|
||||
noCreditRequired?: boolean
|
||||
allowSelling?: boolean
|
||||
namePatternsInclude?: string[]
|
||||
namePatternsExclude?: string[]
|
||||
namePatternsUseRegex?: boolean
|
||||
page?: number
|
||||
pageSize?: number
|
||||
}
|
||||
@@ -92,6 +95,13 @@ export function useLoraPoolApi() {
|
||||
urlParams.set('allow_selling_generated_content', String(params.allowSelling))
|
||||
}
|
||||
|
||||
// Name pattern filters
|
||||
params.namePatternsInclude?.forEach(pattern => urlParams.append('name_pattern_include', pattern))
|
||||
params.namePatternsExclude?.forEach(pattern => urlParams.append('name_pattern_exclude', pattern))
|
||||
if (params.namePatternsUseRegex !== undefined) {
|
||||
urlParams.set('name_pattern_use_regex', String(params.namePatternsUseRegex))
|
||||
}
|
||||
|
||||
const response = await fetch(`/api/lm/loras/list?${urlParams}`)
|
||||
const data = await response.json()
|
||||
|
||||
|
||||
@@ -24,6 +24,9 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
|
||||
const excludeFolders = ref<string[]>([])
|
||||
const noCreditRequired = ref(false)
|
||||
const allowSelling = ref(false)
|
||||
const includePatterns = ref<string[]>([])
|
||||
const excludePatterns = ref<string[]>([])
|
||||
const useRegex = ref(false)
|
||||
|
||||
// Available options from API
|
||||
const availableBaseModels = ref<BaseModelOption[]>([])
|
||||
@@ -52,6 +55,11 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
|
||||
license: {
|
||||
noCreditRequired: noCreditRequired.value,
|
||||
allowSelling: allowSelling.value
|
||||
},
|
||||
namePatterns: {
|
||||
include: includePatterns.value,
|
||||
exclude: excludePatterns.value,
|
||||
useRegex: useRegex.value
|
||||
}
|
||||
},
|
||||
preview: {
|
||||
@@ -94,6 +102,9 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
|
||||
updateIfChanged(excludeFolders, filters.folders?.exclude || [])
|
||||
updateIfChanged(noCreditRequired, filters.license?.noCreditRequired ?? false)
|
||||
updateIfChanged(allowSelling, filters.license?.allowSelling ?? false)
|
||||
updateIfChanged(includePatterns, filters.namePatterns?.include || [])
|
||||
updateIfChanged(excludePatterns, filters.namePatterns?.exclude || [])
|
||||
updateIfChanged(useRegex, filters.namePatterns?.useRegex ?? false)
|
||||
|
||||
// matchCount doesn't trigger watchers, so direct assignment is fine
|
||||
matchCount.value = preview?.matchCount || 0
|
||||
@@ -125,6 +136,9 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
|
||||
foldersExclude: excludeFolders.value,
|
||||
noCreditRequired: noCreditRequired.value || undefined,
|
||||
allowSelling: allowSelling.value || undefined,
|
||||
namePatternsInclude: includePatterns.value,
|
||||
namePatternsExclude: excludePatterns.value,
|
||||
namePatternsUseRegex: useRegex.value,
|
||||
pageSize: 6
|
||||
})
|
||||
|
||||
@@ -150,7 +164,10 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
|
||||
includeFolders,
|
||||
excludeFolders,
|
||||
noCreditRequired,
|
||||
allowSelling
|
||||
allowSelling,
|
||||
includePatterns,
|
||||
excludePatterns,
|
||||
useRegex
|
||||
], onFilterChange, { deep: true })
|
||||
|
||||
return {
|
||||
@@ -162,6 +179,9 @@ export function useLoraPoolState(widget: ComponentWidget<LoraPoolConfig>) {
|
||||
excludeFolders,
|
||||
noCreditRequired,
|
||||
allowSelling,
|
||||
includePatterns,
|
||||
excludePatterns,
|
||||
useRegex,
|
||||
|
||||
// Available options
|
||||
availableBaseModels,
|
||||
|
||||
@@ -13,12 +13,12 @@ import {
|
||||
} from './mode-change-handler'
|
||||
|
||||
const LORA_POOL_WIDGET_MIN_WIDTH = 500
|
||||
const LORA_POOL_WIDGET_MIN_HEIGHT = 400
|
||||
const LORA_POOL_WIDGET_MIN_HEIGHT = 520
|
||||
const LORA_RANDOMIZER_WIDGET_MIN_WIDTH = 500
|
||||
const LORA_RANDOMIZER_WIDGET_MIN_HEIGHT = 448
|
||||
const LORA_RANDOMIZER_WIDGET_MAX_HEIGHT = LORA_RANDOMIZER_WIDGET_MIN_HEIGHT
|
||||
const LORA_CYCLER_WIDGET_MIN_WIDTH = 380
|
||||
const LORA_CYCLER_WIDGET_MIN_HEIGHT = 314
|
||||
const LORA_CYCLER_WIDGET_MIN_HEIGHT = 344
|
||||
const LORA_CYCLER_WIDGET_MAX_HEIGHT = LORA_CYCLER_WIDGET_MIN_HEIGHT
|
||||
const JSON_DISPLAY_WIDGET_MIN_WIDTH = 300
|
||||
const JSON_DISPLAY_WIDGET_MIN_HEIGHT = 200
|
||||
|
||||
@@ -84,7 +84,8 @@ describe('useLoraCyclerState', () => {
|
||||
current_lora_filename: '',
|
||||
repeat_count: 1,
|
||||
repeat_used: 0,
|
||||
is_paused: false
|
||||
is_paused: false,
|
||||
include_no_lora: false
|
||||
})
|
||||
|
||||
expect(state.currentIndex.value).toBe(5)
|
||||
|
||||
4
vue-widgets/tests/fixtures/mockConfigs.ts
vendored
4
vue-widgets/tests/fixtures/mockConfigs.ts
vendored
@@ -24,6 +24,7 @@ export function createMockCyclerConfig(overrides: Partial<CyclerConfig> = {}): C
|
||||
repeat_count: 1,
|
||||
repeat_used: 0,
|
||||
is_paused: false,
|
||||
include_no_lora: false,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
@@ -54,7 +55,8 @@ export function createMockPoolConfig(overrides: Partial<LoraPoolConfig> = {}): L
|
||||
export function createMockLoraList(count: number = 5): CyclerLoraItem[] {
|
||||
return Array.from({ length: count }, (_, i) => ({
|
||||
file_name: `lora${i + 1}.safetensors`,
|
||||
model_name: `LoRA Model ${i + 1}`
|
||||
model_name: `LoRA Model ${i + 1}`,
|
||||
file_path: `/models/loras/lora${i + 1}.safetensors`
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -1905,10 +1905,38 @@ class AutoComplete {
|
||||
|
||||
// For regular tag autocomplete (no command), only replace the last space-separated token
|
||||
// This allows "hello 1gi" + selecting "1girl" to become "hello 1girl, "
|
||||
// However, if the user typed a multi-word phrase that matches a tag (e.g., "looking to the side"
|
||||
// matching "looking_to_the_side"), replace the entire phrase instead of just the last word.
|
||||
// Command mode (e.g., "/char miku") should replace the entire command+search
|
||||
let searchTerm = fullSearchTerm;
|
||||
if (this.modelType === 'prompt' && this.searchType === 'custom_words' && !this.activeCommand) {
|
||||
searchTerm = this._getLastSpaceToken(fullSearchTerm);
|
||||
// Check if the selectedItem exists and its tag_name matches the full search term
|
||||
// when converted to underscore format (Danbooru convention)
|
||||
const selectedItem = this.selectedIndex >= 0 ? this.items[this.selectedIndex] : null;
|
||||
const selectedTagName = selectedItem && typeof selectedItem === 'object' && 'tag_name'
|
||||
? selectedItem.tag_name
|
||||
: null;
|
||||
|
||||
// Convert full search term to underscore format and check if it matches selected tag
|
||||
// Normalize multiple spaces to single underscore for matching (e.g., "looking to the side" -> "looking_to_the_side")
|
||||
const underscoreVersion = fullSearchTerm.replace(/ +/g, '_').toLowerCase();
|
||||
const selectedTagLower = selectedTagName?.toLowerCase() ?? '';
|
||||
|
||||
// If multi-word search term is a prefix or suffix of the selected tag,
|
||||
// replace the entire phrase. This handles cases where user types partial tag name.
|
||||
// Examples:
|
||||
// - "looking to the" -> "looking_to_the_side" (prefix match)
|
||||
// - "to the side" -> "looking_to_the_side" (suffix match)
|
||||
// - "looking to the side" -> "looking_to_the_side" (exact match)
|
||||
if (fullSearchTerm.includes(' ') && (
|
||||
selectedTagLower.startsWith(underscoreVersion) ||
|
||||
selectedTagLower.endsWith(underscoreVersion) ||
|
||||
underscoreVersion === selectedTagLower
|
||||
)) {
|
||||
searchTerm = fullSearchTerm;
|
||||
} else {
|
||||
searchTerm = this._getLastSpaceToken(fullSearchTerm);
|
||||
}
|
||||
}
|
||||
|
||||
const searchStartPos = caretPos - searchTerm.length;
|
||||
|
||||
@@ -14,6 +14,7 @@ import { initDrag, createContextMenu, initHeaderDrag, initReorderDrag, handleKey
|
||||
import { forwardMiddleMouseToCanvas } from "./utils.js";
|
||||
import { PreviewTooltip } from "./preview_tooltip.js";
|
||||
import { ensureLmStyles } from "./lm_styles_loader.js";
|
||||
import { getStrengthStepPreference } from "./settings.js";
|
||||
|
||||
export function addLorasWidget(node, name, opts, callback) {
|
||||
ensureLmStyles();
|
||||
@@ -416,7 +417,7 @@ export function addLorasWidget(node, name, opts, callback) {
|
||||
const loraIndex = lorasData.findIndex(l => l.name === name);
|
||||
|
||||
if (loraIndex >= 0) {
|
||||
lorasData[loraIndex].strength = (parseFloat(lorasData[loraIndex].strength) - 0.05).toFixed(2);
|
||||
lorasData[loraIndex].strength = (parseFloat(lorasData[loraIndex].strength) - getStrengthStepPreference()).toFixed(2);
|
||||
// Sync clipStrength if collapsed
|
||||
syncClipStrengthIfCollapsed(lorasData[loraIndex]);
|
||||
|
||||
@@ -488,7 +489,7 @@ export function addLorasWidget(node, name, opts, callback) {
|
||||
const loraIndex = lorasData.findIndex(l => l.name === name);
|
||||
|
||||
if (loraIndex >= 0) {
|
||||
lorasData[loraIndex].strength = (parseFloat(lorasData[loraIndex].strength) + 0.05).toFixed(2);
|
||||
lorasData[loraIndex].strength = (parseFloat(lorasData[loraIndex].strength) + getStrengthStepPreference()).toFixed(2);
|
||||
// Sync clipStrength if collapsed
|
||||
syncClipStrengthIfCollapsed(lorasData[loraIndex]);
|
||||
|
||||
@@ -541,7 +542,7 @@ export function addLorasWidget(node, name, opts, callback) {
|
||||
const loraIndex = lorasData.findIndex(l => l.name === name);
|
||||
|
||||
if (loraIndex >= 0) {
|
||||
lorasData[loraIndex].clipStrength = (parseFloat(lorasData[loraIndex].clipStrength) - 0.05).toFixed(2);
|
||||
lorasData[loraIndex].clipStrength = (parseFloat(lorasData[loraIndex].clipStrength) - getStrengthStepPreference()).toFixed(2);
|
||||
|
||||
const newValue = formatLoraValue(lorasData);
|
||||
updateWidgetValue(newValue);
|
||||
@@ -611,7 +612,7 @@ export function addLorasWidget(node, name, opts, callback) {
|
||||
const loraIndex = lorasData.findIndex(l => l.name === name);
|
||||
|
||||
if (loraIndex >= 0) {
|
||||
lorasData[loraIndex].clipStrength = (parseFloat(lorasData[loraIndex].clipStrength) + 0.05).toFixed(2);
|
||||
lorasData[loraIndex].clipStrength = (parseFloat(lorasData[loraIndex].clipStrength) + getStrengthStepPreference()).toFixed(2);
|
||||
|
||||
const newValue = formatLoraValue(lorasData);
|
||||
updateWidgetValue(newValue);
|
||||
|
||||
@@ -24,6 +24,9 @@ const NEW_TAB_TEMPLATE_DEFAULT = "Default";
|
||||
|
||||
const NEW_TAB_ZOOM_LEVEL = 0.8;
|
||||
|
||||
const STRENGTH_STEP_SETTING_ID = "loramanager.strength_step";
|
||||
const STRENGTH_STEP_DEFAULT = 0.05;
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
@@ -232,6 +235,32 @@ const getNewTabTemplatePreference = (() => {
|
||||
};
|
||||
})();
|
||||
|
||||
const getStrengthStepPreference = (() => {
|
||||
let settingsUnavailableLogged = false;
|
||||
|
||||
return () => {
|
||||
const settingManager = app?.extensionManager?.setting;
|
||||
if (!settingManager || typeof settingManager.get !== "function") {
|
||||
if (!settingsUnavailableLogged) {
|
||||
console.warn("LoRA Manager: settings API unavailable, using default strength step.");
|
||||
settingsUnavailableLogged = true;
|
||||
}
|
||||
return STRENGTH_STEP_DEFAULT;
|
||||
}
|
||||
|
||||
try {
|
||||
const value = settingManager.get(STRENGTH_STEP_SETTING_ID);
|
||||
return value ?? STRENGTH_STEP_DEFAULT;
|
||||
} catch (error) {
|
||||
if (!settingsUnavailableLogged) {
|
||||
console.warn("LoRA Manager: unable to read strength step setting, using default.", error);
|
||||
settingsUnavailableLogged = true;
|
||||
}
|
||||
return STRENGTH_STEP_DEFAULT;
|
||||
}
|
||||
};
|
||||
})();
|
||||
|
||||
// ============================================================================
|
||||
// Register Extension with All Settings
|
||||
// ============================================================================
|
||||
@@ -293,6 +322,19 @@ app.registerExtension({
|
||||
tooltip: "Choose a template workflow to load when creating a new workflow tab. 'Default (Blank)' keeps ComfyUI's original blank workflow behavior.",
|
||||
category: ["LoRA Manager", "Workflow", "New Tab Template"],
|
||||
},
|
||||
{
|
||||
id: STRENGTH_STEP_SETTING_ID,
|
||||
name: "Strength Adjustment Step",
|
||||
type: "slider",
|
||||
attrs: {
|
||||
min: 0.01,
|
||||
max: 0.1,
|
||||
step: 0.01,
|
||||
},
|
||||
defaultValue: STRENGTH_STEP_DEFAULT,
|
||||
tooltip: "Step size for adjusting LoRA strength via arrow buttons or keyboard (default: 0.05)",
|
||||
category: ["LoRA Manager", "LoRA Widget", "Strength Step"],
|
||||
},
|
||||
],
|
||||
async setup() {
|
||||
await loadWorkflowOptions();
|
||||
@@ -375,4 +417,5 @@ export {
|
||||
getTagSpaceReplacementPreference,
|
||||
getUsageStatisticsPreference,
|
||||
getNewTabTemplatePreference,
|
||||
getStrengthStepPreference,
|
||||
};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user