From 9ce56dd40c7abd8d2ae4713d6041988cc1ccc87a Mon Sep 17 00:00:00 2001 From: Will Miao Date: Wed, 20 May 2026 19:39:12 +0800 Subject: [PATCH] feat(lora): support relative paths in syntax (#917) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Autocomplete, copy/send-to-workflow, and recipe syntax now emit instead of , using relative paths to disambiguate identically-named loras in different subfolders without requiring file renames. Backend: 3-tier hybrid resolution (path → bare → basename fallback) across get_lora_info, get_lora_info_absolute, get_model_preview_url, get_model_civitai_url, get_model_info_by_name, get_lora_metadata_by_filename, and get_hash_by_filename. Also fix get_random_loras and get_cycler_list to return path-prefixed names for randomizer/cycler consistency. Frontend: autocomplete, copyLoraSyntax, handleSendToWorkflow emit folder-prefixed syntax. extract_lora_name preserves relative paths. Saved image metadata ( in EXIF) intentionally keeps basename-only for compatibility with A1111/Forge ecosystem. --- .gitignore | 1 + py/nodes/utils.py | 11 +- py/services/base_model_service.py | 78 +++++++++- py/services/lora_service.py | 27 +++- py/services/model_hash_index.py | 4 +- py/services/model_scanner.py | 35 ++++- py/services/recipe_scanner.py | 11 +- py/utils/utils.py | 121 ++++++++++++--- static/js/components/shared/ModelCard.js | 4 +- static/js/utils/uiHelpers.js | 4 +- .../components/autocomplete.behavior.test.js | 6 +- tests/utils/test_utils.py | 141 ++++++++++++++++++ web/comfyui/autocomplete.js | 9 +- 13 files changed, 404 insertions(+), 48 deletions(-) diff --git a/.gitignore b/.gitignore index 397596d7..76989756 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ model_cache/ .claude/ .sisyphus/ .codex +.omo # Vue widgets development cache (but keep build output) vue-widgets/node_modules/ diff --git a/py/nodes/utils.py b/py/nodes/utils.py index 905126c7..a78a3ff3 100644 --- a/py/nodes/utils.py +++ b/py/nodes/utils.py @@ -45,10 +45,13 @@ logger = logging.getLogger(__name__) def extract_lora_name(lora_path): - """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" - # Get the basename without extension - basename = os.path.basename(lora_path) - return os.path.splitext(basename)[0] + normalized = lora_path.replace("\\", "/") + basename = os.path.basename(normalized) + name_no_ext = os.path.splitext(basename)[0] + dirname = os.path.dirname(normalized) + if dirname and dirname not in (".", "/") and not normalized.startswith("/"): + return f"{dirname}/{name_no_ext}" + return name_no_ext def get_loras_list(kwargs): diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index cda957dd..782143af 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -870,22 +870,75 @@ class BaseModelService(ABC): """Get the static preview URL for a model file""" cache = await self.scanner.get_cached_data() + name_normalized = model_name.replace("\\", "/") + name_no_ext = name_normalized + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if name_no_ext.lower().endswith(ext): + name_no_ext = name_no_ext[: -len(ext)] + break + + has_path = "/" in name_no_ext + basename = os.path.basename(name_no_ext) if has_path else name_no_ext + best_fallback = None + for model in cache.raw_data: - if model["file_name"] == model_name: + file_name = model.get("file_name", "") + folder = model.get("folder", "") + file_name_no_ext = file_name + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if file_name_no_ext.lower().endswith(ext): + file_name_no_ext = file_name_no_ext[: -len(ext)] + break + path_name = f"{folder}/{file_name_no_ext}".replace("\\", "/") if folder else file_name_no_ext + + if name_no_ext == file_name_no_ext or name_no_ext == path_name: preview_url = model.get("preview_url") if preview_url: from ..config import config return config.get_preview_static_url(preview_url) + if has_path and file_name_no_ext == basename: + if folder and name_no_ext.startswith(folder.replace("\\", "/") + "/"): + best_fallback = model + elif best_fallback is None: + best_fallback = model + + if best_fallback: + preview_url = best_fallback.get("preview_url") + if preview_url: + from ..config import config + + return config.get_preview_static_url(preview_url) + return "/loras_static/images/no-preview.png" async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]: """Get the Civitai URL for a model file""" cache = await self.scanner.get_cached_data() + name_normalized = model_name.replace("\\", "/") + name_no_ext = name_normalized + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if name_no_ext.lower().endswith(ext): + name_no_ext = name_no_ext[: -len(ext)] + break + + has_path = "/" in name_no_ext + basename = os.path.basename(name_no_ext) if has_path else name_no_ext + best_fallback = None + for model in cache.raw_data: - if model["file_name"] == model_name: + file_name = model.get("file_name", "") + folder = model.get("folder", "") + file_name_no_ext = file_name + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if file_name_no_ext.lower().endswith(ext): + file_name_no_ext = file_name_no_ext[: -len(ext)] + break + path_name = f"{folder}/{file_name_no_ext}".replace("\\", "/") if folder else file_name_no_ext + + if name_no_ext == file_name_no_ext or name_no_ext == path_name: civitai_data = model.get("civitai", {}) model_id = civitai_data.get("modelId") version_id = civitai_data.get("id") @@ -904,6 +957,27 @@ class BaseModelService(ABC): "version_id": str(version_id) if version_id else None, } + if has_path and file_name_no_ext == basename: + if folder and name_no_ext.startswith(folder.replace("\\", "/") + "/"): + best_fallback = model + elif best_fallback is None: + best_fallback = model + + if best_fallback: + civitai_data = best_fallback.get("civitai", {}) + model_id = civitai_data.get("modelId") + if model_id: + version_id = civitai_data.get("id") + civitai_host = self.settings.get("civitai_host", "civitai.com") + civitai_url = build_civitai_model_page_url( + model_id, version_id, host=civitai_host + ) + return { + "civitai_url": civitai_url, + "model_id": str(model_id), + "version_id": str(version_id) if version_id else None, + } + return {"civitai_url": None, "model_id": None, "version_id": None} async def get_model_metadata(self, file_path: str) -> Optional[Dict]: diff --git a/py/services/lora_service.py b/py/services/lora_service.py index f1a121ac..6819e4fd 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -312,8 +312,23 @@ class LoraService(BaseModelService): """Return cached raw metadata for a LoRA matching the given filename.""" cache = await self.scanner.get_cached_data(force_refresh=False) + fn_normalized = filename.replace("\\", "/") + fn_no_ext = fn_normalized + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if fn_no_ext.lower().endswith(ext): + fn_no_ext = fn_no_ext[: -len(ext)] + break + for lora in cache.raw_data if cache else []: - if lora.get("file_name") == filename: + file_name = lora.get("file_name", "") + folder = lora.get("folder", "") + file_name_no_ext = file_name + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if file_name_no_ext.lower().endswith(ext): + file_name_no_ext = file_name_no_ext[: -len(ext)] + break + path_name = f"{folder}/{file_name_no_ext}".replace("\\", "/") if folder else file_name_no_ext + if fn_no_ext in (file_name_no_ext, path_name): return lora return None @@ -401,7 +416,10 @@ class LoraService(BaseModelService): locked_loras = locked_loras[:target_count] # Filter out locked LoRAs from available pool - locked_names = {lora["name"] for lora in locked_loras} + locked_names = { + os.path.basename(lora["name"]) if "/" in str(lora.get("name", "")) else lora["name"] + for lora in locked_loras + } available_pool = [ l for l in available_loras if l["file_name"] not in locked_names ] @@ -456,7 +474,7 @@ class LoraService(BaseModelService): result_loras.append( { - "name": lora["file_name"], + "name": f"{lora['folder']}/{lora['file_name']}" if lora.get("folder") else lora["file_name"], "strength": model_str, "clipStrength": clip_str, "active": True, @@ -672,8 +690,9 @@ class LoraService(BaseModelService): # Return minimal data needed for cycling return [ { - "file_name": lora["file_name"], + "file_name": f"{lora['folder']}/{lora['file_name']}" if lora.get("folder") else lora["file_name"], "model_name": lora.get("model_name", lora["file_name"]), + "folder": lora.get("folder", ""), } for lora in available_loras ] diff --git a/py/services/model_hash_index.py b/py/services/model_hash_index.py index 7cfbd108..19f62572 100644 --- a/py/services/model_hash_index.py +++ b/py/services/model_hash_index.py @@ -209,7 +209,9 @@ class ModelHashIndex: return self._filename_to_hash.get(filename) def get_hash_by_filename(self, filename: str) -> Optional[str]: - """Get hash for a filename without extension""" + """Get hash for a filename (bare basename or path-prefixed name)""" + if "/" in filename or "\\" in filename: + filename = os.path.splitext(os.path.basename(filename.replace("\\", "/")))[0] return self._filename_to_hash.get(filename) def clear(self) -> None: diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 7d4d659e..34fb6fe1 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -1597,12 +1597,39 @@ class ModelScanner: """Get model information by name""" try: cache = await self.get_cached_data() - + + name_normalized = name.replace("\\", "/") + name_no_ext = name_normalized + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if name_no_ext.lower().endswith(ext): + name_no_ext = name_no_ext[: -len(ext)] + break + + has_path = "/" in name_no_ext + basename = os.path.basename(name_no_ext) if has_path else name_no_ext + best_fallback = None + for model in cache.raw_data: - if model.get("file_name") == name: + file_name = model.get("file_name", "") + folder = model.get("folder", "") + file_name_no_ext = file_name + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if file_name_no_ext.lower().endswith(ext): + file_name_no_ext = file_name_no_ext[: -len(ext)] + break + path_name = f"{folder}/{file_name_no_ext}".replace("\\", "/") if folder else file_name_no_ext + + if name_no_ext == file_name_no_ext or name_no_ext == path_name: return model - - return None + + if has_path and file_name_no_ext == basename: + if folder and name_no_ext.startswith(folder.replace("\\", "/") + "/"): + best_fallback = model + elif best_fallback is None: + best_fallback = model + + return best_fallback + except Exception as e: logger.error(f"Error getting model info by name: {e}", exc_info=True) return None diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index ceeb8732..ca49f7db 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -2517,6 +2517,7 @@ class RecipeScanner: continue file_name = None + folder = "" hash_value = (lora.get("hash") or "").lower() if ( hash_value @@ -2526,6 +2527,11 @@ class RecipeScanner: file_path = self._lora_scanner._hash_index.get_path(hash_value) if file_path: file_name = os.path.splitext(os.path.basename(file_path))[0] + if lora_cache is not None: + for cached_lora in getattr(lora_cache, "raw_data", []): + if cached_lora.get("file_path") == file_path: + folder = cached_lora.get("folder", "") + break if not file_name and lora.get("modelVersionId") and lora_cache is not None: for cached_lora in getattr(lora_cache, "raw_data", []): @@ -2540,13 +2546,16 @@ class RecipeScanner: file_name = os.path.splitext(os.path.basename(cached_path))[ 0 ] + folder = cached_lora.get("folder", "") break if not file_name: file_name = lora.get("file_name", "unknown-lora") + folder = lora.get("folder", "") + lora_name = f"{folder}/{file_name}" if folder else file_name strength = lora.get("strength", 1.0) - syntax_parts.append(f"") + syntax_parts.append(f"") return syntax_parts diff --git a/py/utils/utils.py b/py/utils/utils.py index ef2a333f..8ed20924 100644 --- a/py/utils/utils.py +++ b/py/utils/utils.py @@ -15,30 +15,64 @@ def get_lora_info(lora_name): scanner = await ServiceRegistry.get_lora_scanner() cache = await scanner.get_cached_data() + lora_name_normalized = lora_name.replace("\\", "/") + lora_name_no_ext = lora_name_normalized + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if lora_name_no_ext.lower().endswith(ext): + lora_name_no_ext = lora_name_no_ext[: -len(ext)] + break + + has_path = "/" in lora_name_no_ext + basename = os.path.basename(lora_name_no_ext) if has_path else lora_name_no_ext + best_fallback = None + for item in cache.raw_data: - if item.get("file_name") == lora_name: - file_path = item.get("file_path") - if file_path: - # Check all lora roots including extra paths - all_roots = list(config.loras_roots or []) + list( - config.extra_loras_roots or [] + file_name = item.get("file_name", "") + folder = item.get("folder", "") + file_name_no_ext = file_name + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if file_name_no_ext.lower().endswith(ext): + file_name_no_ext = file_name_no_ext[: -len(ext)] + break + path_name = f"{folder}/{file_name_no_ext}".replace("\\", "/") if folder else file_name_no_ext + + if lora_name_no_ext not in (file_name_no_ext, path_name): + if has_path and file_name_no_ext == basename: + if folder and lora_name_no_ext.startswith(folder.replace("\\", "/") + "/"): + best_fallback = item + elif best_fallback is None: + best_fallback = item + continue + + file_path = item.get("file_path") + if not file_path: + continue + + all_roots = list(config.loras_roots or []) + list( + config.extra_loras_roots or [] + ) + for root in all_roots: + root = root.replace(os.sep, "/") + if file_path.startswith(root): + relative_path = os.path.relpath(file_path, root).replace( + os.sep, "/" ) - for root in all_roots: - root = root.replace(os.sep, "/") - if file_path.startswith(root): - relative_path = os.path.relpath(file_path, root).replace( - os.sep, "/" - ) - # Get trigger words from civitai metadata - civitai = item.get("civitai", {}) - trigger_words = ( - civitai.get("trainedWords", []) if civitai else [] - ) - return relative_path, trigger_words - # If not found in any root, return path with trigger words from cache civitai = item.get("civitai", {}) - trigger_words = civitai.get("trainedWords", []) if civitai else [] - return file_path, trigger_words + trigger_words = ( + civitai.get("trainedWords", []) if civitai else [] + ) + return relative_path, trigger_words + civitai = item.get("civitai", {}) + trigger_words = civitai.get("trainedWords", []) if civitai else [] + return file_path, trigger_words + + if best_fallback: + file_path = best_fallback.get("file_path") + if file_path: + civitai = best_fallback.get("civitai", {}) + trigger_words = civitai.get("trainedWords", []) if civitai else [] + return file_path, trigger_words + return lora_name, [] try: @@ -77,15 +111,54 @@ def get_lora_info_absolute(lora_name): scanner = await ServiceRegistry.get_lora_scanner() cache = await scanner.get_cached_data() + lora_name_normalized = lora_name.replace("\\", "/") + lora_name_no_ext = lora_name_normalized + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if lora_name_no_ext.lower().endswith(ext): + lora_name_no_ext = lora_name_no_ext[: -len(ext)] + break + + has_path = "/" in lora_name_no_ext + basename = os.path.basename(lora_name_no_ext) if has_path else lora_name_no_ext + best_fallback = None + for item in cache.raw_data: - if item.get("file_name") == lora_name: + file_name = item.get("file_name", "") + folder = item.get("folder", "") + file_name_no_ext = file_name + for ext in (".safetensors", ".ckpt", ".pt", ".bin"): + if file_name_no_ext.lower().endswith(ext): + file_name_no_ext = file_name_no_ext[: -len(ext)] + break + path_name = f"{folder}/{file_name_no_ext}".replace("\\", "/") if folder else file_name_no_ext + + if lora_name_no_ext == file_name_no_ext: file_path = item.get("file_path") if file_path: - # Return absolute path directly - # Get trigger words from civitai metadata civitai = item.get("civitai", {}) trigger_words = civitai.get("trainedWords", []) if civitai else [] return file_path, trigger_words + + if lora_name_no_ext == path_name: + file_path = item.get("file_path") + if file_path: + civitai = item.get("civitai", {}) + trigger_words = civitai.get("trainedWords", []) if civitai else [] + return file_path, trigger_words + + if has_path and file_name_no_ext == basename: + if folder and lora_name_no_ext.startswith(folder.replace("\\", "/") + "/"): + best_fallback = item + elif best_fallback is None: + best_fallback = item + + if best_fallback: + file_path = best_fallback.get("file_path") + if file_path: + civitai = best_fallback.get("civitai", {}) + trigger_words = civitai.get("trainedWords", []) if civitai else [] + return file_path, trigger_words + return lora_name, [] try: diff --git a/static/js/components/shared/ModelCard.js b/static/js/components/shared/ModelCard.js index f9cc7e47..e0b12e12 100644 --- a/static/js/components/shared/ModelCard.js +++ b/static/js/components/shared/ModelCard.js @@ -166,7 +166,9 @@ async function toggleFavorite(card) { function handleSendToWorkflow(card, replaceMode, modelType) { if (modelType === MODEL_TYPES.LORA) { const usageTips = JSON.parse(card.dataset.usage_tips || '{}'); - const loraSyntax = buildLoraSyntax(card.dataset.file_name, usageTips); + const folder = card.dataset.folder || ''; + const loraName = folder ? `${folder}/${card.dataset.file_name}` : card.dataset.file_name; + const loraSyntax = buildLoraSyntax(loraName, usageTips); sendLoraToWorkflow(loraSyntax, replaceMode, 'lora'); } else if (modelType === MODEL_TYPES.CHECKPOINT) { const modelPath = card.dataset.filepath; diff --git a/static/js/utils/uiHelpers.js b/static/js/utils/uiHelpers.js index 45548994..11f0e8c4 100644 --- a/static/js/utils/uiHelpers.js +++ b/static/js/utils/uiHelpers.js @@ -430,7 +430,9 @@ export function buildLoraSyntax(fileName, usageTips = {}) { export function copyLoraSyntax(card) { const usageTips = JSON.parse(card.dataset.usage_tips || "{}"); - const baseSyntax = buildLoraSyntax(card.dataset.file_name, usageTips); + const folder = card.dataset.folder || ''; + const loraName = folder ? `${folder}/${card.dataset.file_name}` : card.dataset.file_name; + const baseSyntax = buildLoraSyntax(loraName, usageTips); // Check if trigger words should be included const includeTriggerWords = state.global.settings.include_trigger_words; diff --git a/tests/frontend/components/autocomplete.behavior.test.js b/tests/frontend/components/autocomplete.behavior.test.js index 7c5a7b53..a7b5ac7c 100644 --- a/tests/frontend/components/autocomplete.behavior.test.js +++ b/tests/frontend/components/autocomplete.behavior.test.js @@ -185,7 +185,7 @@ describe('AutoComplete widget interactions', () => { expect(fetchApiMock).toHaveBeenCalledWith( '/lm/loras/usage-tips-by-path?relative_path=models%2Fexample.safetensors', ); - expect(input.value).toContain(','); + expect(input.value).toContain(','); expect(autoComplete.dropdown.style.display).toBe('none'); expect(input.focus).toHaveBeenCalled(); expect(input.setSelectionRange).toHaveBeenCalled(); @@ -1624,8 +1624,8 @@ describe('AutoComplete widget interactions', () => { await autoComplete.insertSelection('models/example.safetensors'); - expect(input.value).toContain(''); - expect(input.value).not.toContain(','); + expect(input.value).toContain(''); + expect(input.value).not.toContain(','); }); it('replaces entire phrase when selected tag ends with underscore version of search term (suffix match)', async () => { diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index a0b34258..5886b519 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -1,13 +1,43 @@ import pytest from py.services.settings_manager import SettingsManager, get_settings_manager +from py.services.service_registry import ServiceRegistry from py.utils.utils import ( calculate_recipe_fingerprint, calculate_relative_path_for_model, + get_lora_info, + get_lora_info_absolute, sanitize_folder_name, ) +class _FakeCache: + def __init__(self, items): + self.raw_data = list(items) + + +class _FakeScanner: + def __init__(self, items): + self._cache = _FakeCache(items) + + async def get_cached_data(self): + return self._cache + + +@pytest.fixture +def mock_lora_scanner(monkeypatch): + def _setup(items): + scanner = _FakeScanner(items) + + async def get_scanner(): + return scanner + + monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", get_scanner) + return scanner + + return _setup + + @pytest.fixture def isolated_settings(monkeypatch): manager = get_settings_manager() @@ -114,3 +144,114 @@ def test_calculate_recipe_fingerprint_empty_input(): ) def test_sanitize_folder_name(original, expected): assert sanitize_folder_name(original) == expected + + +def test_get_lora_info_absolute_bare_name(mock_lora_scanner): + mock_lora_scanner([ + {"file_name": "mylora", "folder": "SDXL", "file_path": "/models/Lora/SDXL/mylora.safetensors", "civitai": {"trainedWords": ["trigger1"]}}, + ]) + + path, triggers = get_lora_info_absolute("mylora") + + assert path == "/models/Lora/SDXL/mylora.safetensors" + assert triggers == ["trigger1"] + + +def test_get_lora_info_absolute_with_path(mock_lora_scanner): + mock_lora_scanner([ + {"file_name": "mylora", "folder": "SDXL/Styles", "file_path": "/models/Lora/SDXL/Styles/mylora.safetensors", "civitai": {"trainedWords": ["artistic"]}}, + {"file_name": "other", "folder": "", "file_path": "/models/Lora/other.safetensors", "civitai": {}}, + ]) + + path, triggers = get_lora_info_absolute("SDXL/Styles/mylora") + + assert path == "/models/Lora/SDXL/Styles/mylora.safetensors" + assert triggers == ["artistic"] + + +def test_get_lora_info_absolute_path_fallback_to_basename(mock_lora_scanner): + mock_lora_scanner([ + {"file_name": "mylora", "folder": "RenamedFolder", "file_path": "/models/Lora/RenamedFolder/mylora.safetensors", "civitai": {"trainedWords": ["trigger1"]}}, + ]) + + path, triggers = get_lora_info_absolute("OldFolder/mylora") + + assert path == "/models/Lora/RenamedFolder/mylora.safetensors" + assert triggers == ["trigger1"] + + +def test_get_lora_info_absolute_prefers_folder_match(mock_lora_scanner): + mock_lora_scanner([ + {"file_name": "mylora", "folder": "V1", "file_path": "/models/Lora/V1/mylora.safetensors", "civitai": {"trainedWords": ["v1"]}}, + {"file_name": "mylora", "folder": "V2", "file_path": "/models/Lora/V2/mylora.safetensors", "civitai": {"trainedWords": ["v2"]}}, + ]) + + path, triggers = get_lora_info_absolute("V2/mylora") + + assert path == "/models/Lora/V2/mylora.safetensors" + assert triggers == ["v2"] + + +def test_get_lora_info_absolute_no_folder_in_cache_no_path_in_name(mock_lora_scanner): + mock_lora_scanner([ + {"file_name": "mylora", "folder": "", "file_path": "/models/Lora/mylora.safetensors", "civitai": {}}, + ]) + + path, triggers = get_lora_info_absolute("mylora") + + assert path == "/models/Lora/mylora.safetensors" + assert triggers == [] + + +def test_get_lora_info_absolute_strips_extension(mock_lora_scanner): + mock_lora_scanner([ + {"file_name": "mylora", "folder": "SDXL", "file_path": "/models/Lora/SDXL/mylora.safetensors", "civitai": {"trainedWords": ["hello"]}}, + ]) + + path, triggers = get_lora_info_absolute("SDXL/mylora.safetensors") + + assert path == "/models/Lora/SDXL/mylora.safetensors" + assert triggers == ["hello"] + + +def test_get_lora_info_absolute_not_found_returns_original(mock_lora_scanner): + mock_lora_scanner([ + {"file_name": "mylora", "folder": "SDXL", "file_path": "/models/Lora/SDXL/mylora.safetensors", "civitai": {}}, + ]) + + path, triggers = get_lora_info_absolute("nonexistent") + + assert path == "nonexistent" + assert triggers == [] + + +def test_get_lora_info_bare_name(mock_lora_scanner): + mock_lora_scanner([ + {"file_name": "mylora", "folder": "SDXL", "file_path": "/models/Lora/SDXL/mylora.safetensors", "civitai": {"trainedWords": ["trigger1"]}}, + ]) + + path, triggers = get_lora_info("mylora") + + assert triggers == ["trigger1"] + + +def test_get_lora_info_with_path(mock_lora_scanner): + mock_lora_scanner([ + {"file_name": "mylora", "folder": "SDXL/Styles", "file_path": "/models/Lora/SDXL/Styles/mylora.safetensors", "civitai": {"trainedWords": ["artistic"]}}, + {"file_name": "other", "folder": "", "file_path": "/models/Lora/other.safetensors", "civitai": {}}, + ]) + + path, triggers = get_lora_info("SDXL/Styles/mylora") + + assert triggers == ["artistic"] + + +def test_get_lora_info_not_found_returns_original(mock_lora_scanner): + mock_lora_scanner([ + {"file_name": "mylora", "folder": "SDXL", "file_path": "/models/Lora/SDXL/mylora.safetensors", "civitai": {}}, + ]) + + path, triggers = get_lora_info("nonexistent") + + assert path == "nonexistent" + assert triggers == [] diff --git a/web/comfyui/autocomplete.js b/web/comfyui/autocomplete.js index c43499dc..4f29f2c9 100644 --- a/web/comfyui/autocomplete.js +++ b/web/comfyui/autocomplete.js @@ -226,7 +226,10 @@ const MODEL_BEHAVIORS = { } }, async getInsertText(_instance, relativePath) { - const fileName = removeLoraExtension(splitRelativePath(relativePath).fileName); + const { directories, fileName } = splitRelativePath(relativePath); + const baseName = removeLoraExtension(fileName); + const folder = directories.length ? directories.join('/') + '/' : ''; + const loraName = folder + baseName; let strength = 1.0; let hasStrength = false; @@ -262,9 +265,9 @@ const MODEL_BEHAVIORS = { } if (clipStrength !== null) { - return formatAutocompleteInsertion(``); + return formatAutocompleteInsertion(``); } - return formatAutocompleteInsertion(``); + return formatAutocompleteInsertion(``); } }, embeddings: {