diff --git a/py/config.py b/py/config.py index 2d2deabd..bd830fe8 100644 --- a/py/config.py +++ b/py/config.py @@ -1,7 +1,7 @@ import os import platform import folder_paths # type: ignore -from typing import List +from typing import Dict, Iterable, List, Mapping import logging import json import urllib.parse @@ -38,39 +38,48 @@ class Config: self.save_folder_paths_to_settings() def save_folder_paths_to_settings(self): - """Save folder paths to settings.json for standalone mode to use later""" + """Persist ComfyUI-derived folder paths to the multi-library settings.""" try: - # Check if we're running in ComfyUI mode (not standalone) - # Load existing settings - settings_path = ensure_settings_file(logger) - settings = {} - if os.path.exists(settings_path): - with open(settings_path, 'r', encoding='utf-8') as f: - settings = json.load(f) + ensure_settings_file(logger) + from .services.settings_manager import settings as settings_service - # Update settings with paths - settings['folder_paths'] = { - 'loras': self.loras_roots, - 'checkpoints': self.checkpoints_roots, - 'unet': self.unet_roots, - 'embeddings': self.embeddings_roots, - } - - # Add default roots if there's only one item and key doesn't exist - if len(self.loras_roots) == 1 and "default_lora_root" not in settings: - settings["default_lora_root"] = self.loras_roots[0] - - if self.checkpoints_roots and len(self.checkpoints_roots) == 1 and "default_checkpoint_root" not in settings: - settings["default_checkpoint_root"] = self.checkpoints_roots[0] + libraries = settings_service.get_libraries() + comfy_library = libraries.get("comfyui", {}) - if self.embeddings_roots and len(self.embeddings_roots) == 1 and "default_embedding_root" not in settings: - settings["default_embedding_root"] = self.embeddings_roots[0] - - # Save settings - with open(settings_path, 'w', encoding='utf-8') as f: - json.dump(settings, f, indent=2) - - logger.info("Saved folder paths to settings.json") + default_lora_root = comfy_library.get("default_lora_root", "") + if not default_lora_root and len(self.loras_roots) == 1: + default_lora_root = self.loras_roots[0] + + default_checkpoint_root = comfy_library.get("default_checkpoint_root", "") + if (not default_checkpoint_root and self.checkpoints_roots and + len(self.checkpoints_roots) == 1): + default_checkpoint_root = self.checkpoints_roots[0] + + default_embedding_root = comfy_library.get("default_embedding_root", "") + if (not default_embedding_root and self.embeddings_roots and + len(self.embeddings_roots) == 1): + default_embedding_root = self.embeddings_roots[0] + + metadata = dict(comfy_library.get("metadata", {})) + metadata.setdefault("display_name", "ComfyUI") + metadata["source"] = "comfyui" + + settings_service.upsert_library( + "comfyui", + folder_paths={ + 'loras': list(self.loras_roots), + 'checkpoints': list(self.checkpoints_roots or []), + 'unet': list(self.unet_roots or []), + 'embeddings': list(self.embeddings_roots or []), + }, + default_lora_root=default_lora_root, + default_checkpoint_root=default_checkpoint_root, + default_embedding_root=default_embedding_root, + metadata=metadata, + activate=True, + ) + + logger.info("Updated 'comfyui' library with current folder paths") except Exception as e: logger.warning(f"Failed to save folder paths: {e}") @@ -156,31 +165,91 @@ class Config: return mapped_path return link_path + def _dedupe_existing_paths(self, raw_paths: Iterable[str]) -> Dict[str, str]: + dedup: Dict[str, str] = {} + for path in raw_paths: + if not isinstance(path, str): + continue + if not os.path.exists(path): + continue + real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') + normalized = os.path.normpath(path).replace(os.sep, '/') + if real_path not in dedup: + dedup[real_path] = normalized + return dedup + + def _prepare_lora_paths(self, raw_paths: Iterable[str]) -> List[str]: + path_map = self._dedupe_existing_paths(raw_paths) + unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) + + for original_path in unique_paths: + real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') + if real_path != original_path: + self.add_path_mapping(original_path, real_path) + + return unique_paths + + def _prepare_checkpoint_paths( + self, checkpoint_paths: Iterable[str], unet_paths: Iterable[str] + ) -> List[str]: + checkpoint_map = self._dedupe_existing_paths(checkpoint_paths) + unet_map = self._dedupe_existing_paths(unet_paths) + + merged_map: Dict[str, str] = {} + for real_path, original in {**checkpoint_map, **unet_map}.items(): + if real_path not in merged_map: + merged_map[real_path] = original + + unique_paths = sorted(merged_map.values(), key=lambda p: p.lower()) + + checkpoint_values = set(checkpoint_map.values()) + unet_values = set(unet_map.values()) + self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_values] + self.unet_roots = [p for p in unique_paths if p in unet_values] + + for original_path in unique_paths: + real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') + if real_path != original_path: + self.add_path_mapping(original_path, real_path) + + return unique_paths + + def _prepare_embedding_paths(self, raw_paths: Iterable[str]) -> List[str]: + path_map = self._dedupe_existing_paths(raw_paths) + unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) + + for original_path in unique_paths: + real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') + if real_path != original_path: + self.add_path_mapping(original_path, real_path) + + return unique_paths + + def _apply_library_paths(self, folder_paths: Mapping[str, Iterable[str]]) -> None: + self._path_mappings.clear() + + lora_paths = folder_paths.get('loras', []) or [] + checkpoint_paths = folder_paths.get('checkpoints', []) or [] + unet_paths = folder_paths.get('unet', []) or [] + embedding_paths = folder_paths.get('embeddings', []) or [] + + self.loras_roots = self._prepare_lora_paths(lora_paths) + self.base_models_roots = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths) + self.embeddings_roots = self._prepare_embedding_paths(embedding_paths) + + self._scan_symbolic_links() + def _init_lora_paths(self) -> List[str]: """Initialize and validate LoRA paths from ComfyUI settings""" try: raw_paths = folder_paths.get_folder_paths("loras") - - # Normalize and resolve symlinks, store mapping from resolved -> original - path_map = {} - for path in raw_paths: - if os.path.exists(path): - real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') - path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen - - # Now sort and use only the deduplicated real paths - unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) + unique_paths = self._prepare_lora_paths(raw_paths) logger.info("Found LoRA roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) - + if not unique_paths: logger.warning("No valid loras folders found in ComfyUI configuration") return [] - - for original_path in unique_paths: - real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') - if real_path != original_path: - self.add_path_mapping(original_path, real_path) - + return unique_paths except Exception as e: logger.warning(f"Error initializing LoRA paths: {e}") @@ -189,52 +258,17 @@ class Config: def _init_checkpoint_paths(self) -> List[str]: """Initialize and validate checkpoint paths from ComfyUI settings""" try: - # Get checkpoint paths from folder_paths raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints") raw_unet_paths = folder_paths.get_folder_paths("unet") - - # Normalize and resolve symlinks for checkpoints, store mapping from resolved -> original - checkpoint_map = {} - for path in raw_checkpoint_paths: - if os.path.exists(path): - real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') - checkpoint_map[real_path] = checkpoint_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen - - # Normalize and resolve symlinks for unet, store mapping from resolved -> original - unet_map = {} - for path in raw_unet_paths: - if os.path.exists(path): - real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') - unet_map[real_path] = unet_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen - - # Merge both maps and deduplicate by real path - merged_map = {} - for real_path, orig_path in {**checkpoint_map, **unet_map}.items(): - if real_path not in merged_map: - merged_map[real_path] = orig_path + unique_paths = self._prepare_checkpoint_paths(raw_checkpoint_paths, raw_unet_paths) - # Now sort and use only the deduplicated real paths - unique_paths = sorted(merged_map.values(), key=lambda p: p.lower()) - - # Split back into checkpoints and unet roots for class properties - self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_map.values()] - self.unet_roots = [p for p in unique_paths if p in unet_map.values()] - - all_paths = unique_paths - - logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]")) - - if not all_paths: + logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) + + if not unique_paths: logger.warning("No valid checkpoint folders found in ComfyUI configuration") return [] - - # Initialize path mappings - for original_path in all_paths: - real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') - if real_path != original_path: - self.add_path_mapping(original_path, real_path) - - return all_paths + + return unique_paths except Exception as e: logger.warning(f"Error initializing checkpoint paths: {e}") return [] @@ -243,27 +277,13 @@ class Config: """Initialize and validate embedding paths from ComfyUI settings""" try: raw_paths = folder_paths.get_folder_paths("embeddings") - - # Normalize and resolve symlinks, store mapping from resolved -> original - path_map = {} - for path in raw_paths: - if os.path.exists(path): - real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') - path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen - - # Now sort and use only the deduplicated real paths - unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) + unique_paths = self._prepare_embedding_paths(raw_paths) logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]")) - + if not unique_paths: logger.warning("No valid embeddings folders found in ComfyUI configuration") return [] - - for original_path in unique_paths: - real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') - if real_path != original_path: - self.add_path_mapping(original_path, real_path) - + return unique_paths except Exception as e: logger.warning(f"Error initializing embedding paths: {e}") @@ -272,7 +292,7 @@ class Config: def get_preview_static_url(self, preview_path: str) -> str: if not preview_path: return "" - + real_path = os.path.realpath(preview_path).replace(os.sep, '/') # Find longest matching path (most specific match) @@ -289,8 +309,23 @@ class Config: safe_parts = [urllib.parse.quote(part) for part in relative_path.split('/')] safe_path = '/'.join(safe_parts) return f'{best_route}/{safe_path}' - + return "" + def apply_library_settings(self, library_config: Mapping[str, object]) -> None: + """Update runtime paths to match the provided library configuration.""" + folder_paths = library_config.get('folder_paths') if isinstance(library_config, Mapping) else {} + if not isinstance(folder_paths, Mapping): + folder_paths = {} + + self._apply_library_paths(folder_paths) + + logger.info( + "Applied library settings with %d lora roots, %d checkpoint roots, and %d embedding roots", + len(self.loras_roots or []), + len(self.base_models_roots or []), + len(self.embeddings_roots or []), + ) + # Global config instance config = Config() diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 9ea8acd7..61917736 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -82,9 +82,26 @@ class ModelScanner: self._excluded_models = [] # List to track excluded models self._persistent_cache = get_persistent_cache() self._initialized = True - + # Register this service asyncio.create_task(self._register_service()) + + def on_library_changed(self) -> None: + """Reset caches when the active library changes.""" + self._persistent_cache = get_persistent_cache() + self._cache = None + self._hash_index = ModelHashIndex() + self._tags_count = {} + self._excluded_models = [] + self._is_initializing = False + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and not loop.is_closed(): + loop.create_task(self.initialize_in_background()) async def _register_service(self): """Register this instance with the ServiceRegistry""" diff --git a/py/services/persistent_model_cache.py b/py/services/persistent_model_cache.py index e473035b..45e362ab 100644 --- a/py/services/persistent_model_cache.py +++ b/py/services/persistent_model_cache.py @@ -1,6 +1,7 @@ import json import logging import os +import re import sqlite3 import threading from dataclasses import dataclass @@ -24,11 +25,12 @@ class PersistentModelCache: """Persist core model metadata and hash index data in SQLite.""" _DEFAULT_FILENAME = "model_cache.sqlite" - _instance: Optional["PersistentModelCache"] = None + _instances: Dict[str, "PersistentModelCache"] = {} _instance_lock = threading.Lock() - def __init__(self, db_path: Optional[str] = None) -> None: - self._db_path = db_path or self._resolve_default_path() + def __init__(self, library_name: str = "default", db_path: Optional[str] = None) -> None: + self._library_name = library_name or "default" + self._db_path = db_path or self._resolve_default_path(self._library_name) self._db_lock = threading.Lock() self._schema_initialized = False try: @@ -41,11 +43,12 @@ class PersistentModelCache: self._initialize_schema() @classmethod - def get_default(cls) -> "PersistentModelCache": + def get_default(cls, library_name: Optional[str] = None) -> "PersistentModelCache": + name = (library_name or "default") with cls._instance_lock: - if cls._instance is None: - cls._instance = cls() - return cls._instance + if name not in cls._instances: + cls._instances[name] = cls(name) + return cls._instances[name] def is_enabled(self) -> bool: return os.environ.get("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "0") != "1" @@ -203,7 +206,7 @@ class PersistentModelCache: # Internal helpers ------------------------------------------------- - def _resolve_default_path(self) -> str: + def _resolve_default_path(self, library_name: str) -> str: override = os.environ.get("LORA_MANAGER_CACHE_DB") if override: return override @@ -212,7 +215,12 @@ class PersistentModelCache: except Exception as exc: # pragma: no cover - defensive guard logger.warning("Falling back to project directory for cache: %s", exc) settings_dir = os.path.dirname(os.path.dirname(self._db_path)) if hasattr(self, "_db_path") else os.getcwd() - return os.path.join(settings_dir, self._DEFAULT_FILENAME) + safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", library_name or "default") + if safe_name.lower() in ("default", ""): + legacy_path = os.path.join(settings_dir, self._DEFAULT_FILENAME) + if os.path.exists(legacy_path): + return legacy_path + return os.path.join(settings_dir, "model_cache", f"{safe_name}.sqlite") def _initialize_schema(self) -> None: with self._db_lock: @@ -343,4 +351,7 @@ class PersistentModelCache: def get_persistent_cache() -> PersistentModelCache: - return PersistentModelCache.get_default() + from .settings_manager import settings as settings_service # Local import to avoid cycles + + library_name = settings_service.get_active_library_name() + return PersistentModelCache.get_default(library_name) diff --git a/py/services/settings_manager.py b/py/services/settings_manager.py index f2f643c9..be660c99 100644 --- a/py/services/settings_manager.py +++ b/py/services/settings_manager.py @@ -1,7 +1,9 @@ +import copy import json import os import logging -from typing import Any, Dict +from datetime import datetime +from typing import Any, Dict, Iterable, List, Mapping, Optional from ..utils.settings_paths import ensure_settings_file @@ -42,6 +44,7 @@ class SettingsManager: self.settings = self._load_settings() self._migrate_setting_keys() self._ensure_default_settings() + self._migrate_to_library_registry() self._migrate_download_path_template() self._auto_set_default_roots() self._check_environment_variables() @@ -69,6 +72,223 @@ class SettingsManager: if updated: self._save_settings() + def _migrate_to_library_registry(self) -> None: + """Ensure settings include the multi-library registry structure.""" + libraries = self.settings.get("libraries") + active_name = self.settings.get("active_library") + + if not isinstance(libraries, dict) or not libraries: + library_name = active_name or "default" + library_payload = self._build_library_payload( + folder_paths=self.settings.get("folder_paths", {}), + default_lora_root=self.settings.get("default_lora_root", ""), + default_checkpoint_root=self.settings.get("default_checkpoint_root", ""), + default_embedding_root=self.settings.get("default_embedding_root", ""), + ) + libraries = {library_name: library_payload} + self.settings["libraries"] = libraries + self.settings["active_library"] = library_name + self._sync_active_library_to_root(save=False) + self._save_settings() + return + + sanitized_libraries: Dict[str, Dict[str, Any]] = {} + changed = False + for name, data in libraries.items(): + if not isinstance(data, dict): + data = {} + changed = True + payload = self._build_library_payload( + folder_paths=data.get("folder_paths"), + default_lora_root=data.get("default_lora_root"), + default_checkpoint_root=data.get("default_checkpoint_root"), + default_embedding_root=data.get("default_embedding_root"), + metadata=data.get("metadata"), + base=data, + ) + sanitized_libraries[name] = payload + if payload is not data: + changed = True + + if changed: + self.settings["libraries"] = sanitized_libraries + + if not active_name or active_name not in sanitized_libraries: + if sanitized_libraries: + self.settings["active_library"] = next(iter(sanitized_libraries.keys())) + else: + self.settings["active_library"] = "default" + + self._sync_active_library_to_root(save=changed) + + def _sync_active_library_to_root(self, *, save: bool = False) -> None: + """Update top-level folder path settings to mirror the active library.""" + libraries = self.settings.get("libraries", {}) + active_name = self.settings.get("active_library") + if not libraries: + return + + if active_name not in libraries: + active_name = next(iter(libraries.keys())) + self.settings["active_library"] = active_name + + active_library = libraries.get(active_name, {}) + folder_paths = copy.deepcopy(active_library.get("folder_paths", {})) + self.settings["folder_paths"] = folder_paths + self.settings["default_lora_root"] = active_library.get("default_lora_root", "") + self.settings["default_checkpoint_root"] = active_library.get("default_checkpoint_root", "") + self.settings["default_embedding_root"] = active_library.get("default_embedding_root", "") + + if save: + self._save_settings() + + def _current_timestamp(self) -> str: + return datetime.utcnow().replace(microsecond=0).isoformat() + "Z" + + def _build_library_payload( + self, + *, + folder_paths: Optional[Mapping[str, Iterable[str]]] = None, + default_lora_root: Optional[str] = None, + default_checkpoint_root: Optional[str] = None, + default_embedding_root: Optional[str] = None, + metadata: Optional[Mapping[str, Any]] = None, + base: Optional[Mapping[str, Any]] = None, + ) -> Dict[str, Any]: + payload: Dict[str, Any] = dict(base or {}) + timestamp = self._current_timestamp() + + if folder_paths is not None: + payload["folder_paths"] = self._normalize_folder_paths(folder_paths) + else: + payload.setdefault("folder_paths", {}) + + if default_lora_root is not None: + payload["default_lora_root"] = default_lora_root + else: + payload.setdefault("default_lora_root", "") + + if default_checkpoint_root is not None: + payload["default_checkpoint_root"] = default_checkpoint_root + else: + payload.setdefault("default_checkpoint_root", "") + + if default_embedding_root is not None: + payload["default_embedding_root"] = default_embedding_root + else: + payload.setdefault("default_embedding_root", "") + + if metadata: + merged_meta = dict(payload.get("metadata", {})) + merged_meta.update(metadata) + payload["metadata"] = merged_meta + + payload.setdefault("created_at", timestamp) + payload["updated_at"] = timestamp + return payload + + def _normalize_folder_paths( + self, folder_paths: Mapping[str, Iterable[str]] + ) -> Dict[str, List[str]]: + normalized: Dict[str, List[str]] = {} + for key, values in folder_paths.items(): + if not isinstance(values, Iterable): + continue + cleaned: List[str] = [] + seen = set() + for value in values: + if not isinstance(value, str): + continue + stripped = value.strip() + if not stripped: + continue + if stripped not in seen: + cleaned.append(stripped) + seen.add(stripped) + normalized[key] = cleaned + return normalized + + def _validate_folder_paths( + self, + library_name: str, + folder_paths: Mapping[str, Iterable[str]], + ) -> None: + """Ensure folder paths do not overlap with other libraries.""" + libraries = self.settings.get("libraries", {}) + normalized_new: Dict[str, Dict[str, str]] = {} + for key, values in folder_paths.items(): + path_map: Dict[str, str] = {} + for value in values: + if not isinstance(value, str): + continue + stripped = value.strip() + if not stripped: + continue + normalized_value = os.path.normcase(os.path.normpath(stripped)) + path_map[normalized_value] = stripped + if path_map: + normalized_new[key] = path_map + + if not normalized_new: + return + + for other_name, other in libraries.items(): + if other_name == library_name: + continue + other_paths = other.get("folder_paths", {}) + for key, new_paths in normalized_new.items(): + existing = { + os.path.normcase(os.path.normpath(path)) + for path in other_paths.get(key, []) + if isinstance(path, str) and path + } + overlap = existing.intersection(new_paths.keys()) + if overlap: + collisions = ", ".join(sorted(new_paths[value] for value in overlap)) + raise ValueError( + f"Folder path(s) {collisions} already assigned to library '{other_name}'" + ) + + def _update_active_library_entry( + self, + *, + folder_paths: Optional[Mapping[str, Iterable[str]]] = None, + default_lora_root: Optional[str] = None, + default_checkpoint_root: Optional[str] = None, + default_embedding_root: Optional[str] = None, + ) -> bool: + libraries = self.settings.get("libraries", {}) + active_name = self.settings.get("active_library") + if not active_name or active_name not in libraries: + return False + + library = libraries[active_name] + changed = False + + if folder_paths is not None: + normalized_paths = self._normalize_folder_paths(folder_paths) + if library.get("folder_paths") != normalized_paths: + library["folder_paths"] = normalized_paths + changed = True + + if default_lora_root is not None and library.get("default_lora_root") != default_lora_root: + library["default_lora_root"] = default_lora_root + changed = True + + if default_checkpoint_root is not None and library.get("default_checkpoint_root") != default_checkpoint_root: + library["default_checkpoint_root"] = default_checkpoint_root + changed = True + + if default_embedding_root is not None and library.get("default_embedding_root") != default_embedding_root: + library["default_embedding_root"] = default_embedding_root + changed = True + + if changed: + library.setdefault("created_at", self._current_timestamp()) + library["updated_at"] = self._current_timestamp() + + return changed + def _migrate_setting_keys(self) -> None: """Migrate legacy camelCase setting keys to snake_case""" key_migrations = { @@ -138,6 +358,11 @@ class SettingsManager: self.settings['default_embedding_root'] = embeddings[0] updated = True if updated: + self._update_active_library_entry( + default_lora_root=self.settings.get('default_lora_root'), + default_checkpoint_root=self.settings.get('default_checkpoint_root'), + default_embedding_root=self.settings.get('default_embedding_root'), + ) self._save_settings() def _check_environment_variables(self) -> None: @@ -168,6 +393,14 @@ class SettingsManager: def set(self, key: str, value: Any) -> None: """Set setting value and save""" self.settings[key] = value + if key == 'folder_paths' and isinstance(value, Mapping): + self._update_active_library_entry(folder_paths=value) # type: ignore[arg-type] + elif key == 'default_lora_root': + self._update_active_library_entry(default_lora_root=str(value)) + elif key == 'default_checkpoint_root': + self._update_active_library_entry(default_checkpoint_root=str(value)) + elif key == 'default_embedding_root': + self._update_active_library_entry(default_embedding_root=str(value)) self._save_settings() def delete(self, key: str) -> None: @@ -185,6 +418,227 @@ class SettingsManager: except Exception as e: logger.error(f"Error saving settings: {e}") + def get_libraries(self) -> Dict[str, Dict[str, Any]]: + """Return a copy of the registered libraries.""" + libraries = self.settings.get("libraries", {}) + return copy.deepcopy(libraries) + + def get_active_library_name(self) -> str: + """Return the currently active library name.""" + libraries = self.settings.get("libraries", {}) + active_name = self.settings.get("active_library") + if active_name and active_name in libraries: + return active_name + if libraries: + return next(iter(libraries.keys())) + return "default" + + def get_active_library(self) -> Dict[str, Any]: + """Return a copy of the active library configuration.""" + libraries = self.settings.get("libraries", {}) + active_name = self.get_active_library_name() + return copy.deepcopy(libraries.get(active_name, {})) + + def activate_library(self, library_name: str) -> None: + """Activate a library by name and refresh dependent services.""" + libraries = self.settings.get("libraries", {}) + if library_name not in libraries: + raise KeyError(f"Library '{library_name}' does not exist") + + current_active = self.get_active_library_name() + if current_active == library_name: + # Ensure root settings stay in sync even if already active + self._sync_active_library_to_root(save=False) + self._save_settings() + self._notify_library_change(library_name) + return + + self.settings["active_library"] = library_name + self._sync_active_library_to_root(save=False) + self._save_settings() + self._notify_library_change(library_name) + + def upsert_library( + self, + library_name: str, + *, + folder_paths: Optional[Mapping[str, Iterable[str]]] = None, + default_lora_root: Optional[str] = None, + default_checkpoint_root: Optional[str] = None, + default_embedding_root: Optional[str] = None, + metadata: Optional[Mapping[str, Any]] = None, + activate: bool = False, + ) -> Dict[str, Any]: + """Create or update a library definition.""" + + name = library_name.strip() + if not name: + raise ValueError("Library name cannot be empty") + + if folder_paths is not None: + self._validate_folder_paths(name, folder_paths) + + libraries = self.settings.setdefault("libraries", {}) + existing = libraries.get(name, {}) + + payload = self._build_library_payload( + folder_paths=folder_paths if folder_paths is not None else existing.get("folder_paths"), + default_lora_root=default_lora_root if default_lora_root is not None else existing.get("default_lora_root"), + default_checkpoint_root=( + default_checkpoint_root + if default_checkpoint_root is not None + else existing.get("default_checkpoint_root") + ), + default_embedding_root=( + default_embedding_root + if default_embedding_root is not None + else existing.get("default_embedding_root") + ), + metadata=metadata if metadata is not None else existing.get("metadata"), + base=existing, + ) + + libraries[name] = payload + + if activate or not self.settings.get("active_library"): + self.settings["active_library"] = name + + self._sync_active_library_to_root(save=False) + self._save_settings() + + if self.settings.get("active_library") == name: + self._notify_library_change(name) + + return payload + + def create_library( + self, + library_name: str, + *, + folder_paths: Mapping[str, Iterable[str]], + default_lora_root: str = "", + default_checkpoint_root: str = "", + default_embedding_root: str = "", + metadata: Optional[Mapping[str, Any]] = None, + activate: bool = False, + ) -> Dict[str, Any]: + """Create a new library entry.""" + + libraries = self.settings.get("libraries", {}) + if library_name in libraries: + raise ValueError(f"Library '{library_name}' already exists") + + return self.upsert_library( + library_name, + folder_paths=folder_paths, + default_lora_root=default_lora_root, + default_checkpoint_root=default_checkpoint_root, + default_embedding_root=default_embedding_root, + metadata=metadata, + activate=activate, + ) + + def rename_library(self, old_name: str, new_name: str) -> None: + """Rename an existing library.""" + + libraries = self.settings.get("libraries", {}) + if old_name not in libraries: + raise KeyError(f"Library '{old_name}' does not exist") + new_name_stripped = new_name.strip() + if not new_name_stripped: + raise ValueError("New library name cannot be empty") + if new_name_stripped in libraries: + raise ValueError(f"Library '{new_name_stripped}' already exists") + + libraries[new_name_stripped] = libraries.pop(old_name) + if self.settings.get("active_library") == old_name: + self.settings["active_library"] = new_name_stripped + active_name = new_name_stripped + else: + active_name = self.settings.get("active_library") + + self._sync_active_library_to_root(save=False) + self._save_settings() + + if active_name == new_name_stripped: + self._notify_library_change(new_name_stripped) + + def delete_library(self, library_name: str) -> None: + """Remove a library definition.""" + + libraries = self.settings.get("libraries", {}) + if library_name not in libraries: + raise KeyError(f"Library '{library_name}' does not exist") + if len(libraries) == 1: + raise ValueError("At least one library must remain") + + was_active = self.settings.get("active_library") == library_name + libraries.pop(library_name) + + if was_active: + new_active = next(iter(libraries.keys())) + self.settings["active_library"] = new_active + self._sync_active_library_to_root(save=False) + self._save_settings() + + if was_active: + self._notify_library_change(self.settings["active_library"]) + + def update_active_library_paths( + self, + folder_paths: Mapping[str, Iterable[str]], + *, + default_lora_root: Optional[str] = None, + default_checkpoint_root: Optional[str] = None, + default_embedding_root: Optional[str] = None, + ) -> None: + """Update folder paths for the active library.""" + + active_name = self.get_active_library_name() + self.upsert_library( + active_name, + folder_paths=folder_paths, + default_lora_root=default_lora_root, + default_checkpoint_root=default_checkpoint_root, + default_embedding_root=default_embedding_root, + activate=True, + ) + + def _notify_library_change(self, library_name: str) -> None: + """Notify dependent services that the active library changed.""" + libraries = self.settings.get("libraries", {}) + library_config = libraries.get(library_name, {}) + library_snapshot = copy.deepcopy(library_config) + + try: + from ..config import config # Local import to avoid circular dependency + + config.apply_library_settings(library_snapshot) + except Exception as exc: # pragma: no cover - defensive logging + logger.debug("Failed to apply library settings to config: %s", exc) + + try: + from .service_registry import ServiceRegistry # type: ignore + + for service_name in ( + "lora_scanner", + "checkpoint_scanner", + "embedding_scanner", + "recipe_scanner", + ): + service = ServiceRegistry.get_service_sync(service_name) + if service and hasattr(service, "on_library_changed"): + try: + service.on_library_changed() + except Exception as service_exc: # pragma: no cover - defensive logging + logger.debug( + "Service %s failed to handle library change: %s", + service_name, + service_exc, + ) + except Exception as exc: # pragma: no cover - defensive logging + logger.debug("Failed to notify services about library change: %s", exc) + def get_download_path_template(self, model_type: str) -> str: """Get download path template for specific model type diff --git a/settings.json.example b/settings.json.example index 673aa76d..51998c89 100644 --- a/settings.json.example +++ b/settings.json.example @@ -1,5 +1,28 @@ { "civitai_api_key": "your_civitai_api_key_here", + "active_library": "default", + "libraries": { + "default": { + "display_name": "Default Library", + "folder_paths": { + "loras": [ + "C:/path/to/your/loras_folder", + "C:/path/to/another/loras_folder" + ], + "checkpoints": [ + "C:/path/to/your/checkpoints_folder", + "C:/path/to/another/checkpoints_folder" + ], + "embeddings": [ + "C:/path/to/your/embeddings_folder", + "C:/path/to/another/embeddings_folder" + ] + }, + "default_lora_root": "C:/path/to/your/loras_folder", + "default_checkpoint_root": "C:/path/to/your/checkpoints_folder", + "default_embedding_root": "C:/path/to/your/embeddings_folder" + } + }, "folder_paths": { "loras": [ "C:/path/to/your/loras_folder", @@ -13,5 +36,8 @@ "C:/path/to/your/embeddings_folder", "C:/path/to/another/embeddings_folder" ] - } + }, + "default_lora_root": "C:/path/to/your/loras_folder", + "default_checkpoint_root": "C:/path/to/your/checkpoints_folder", + "default_embedding_root": "C:/path/to/your/embeddings_folder" } diff --git a/tests/services/test_model_scanner.py b/tests/services/test_model_scanner.py index 6d87d56e..dba82c41 100644 --- a/tests/services/test_model_scanner.py +++ b/tests/services/test_model_scanner.py @@ -312,7 +312,7 @@ async def test_update_single_model_cache_persists_changes(tmp_path: Path, monkey monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0') db_path = tmp_path / 'cache.sqlite' monkeypatch.setenv('LORA_MANAGER_CACHE_DB', str(db_path)) - monkeypatch.setattr(PersistentModelCache, '_instance', None, raising=False) + monkeypatch.setattr(PersistentModelCache, '_instances', {}, raising=False) _create_files(tmp_path) scanner = DummyScanner(tmp_path) @@ -360,7 +360,7 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch): monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0') db_path = tmp_path / 'cache.sqlite' monkeypatch.setenv('LORA_MANAGER_CACHE_DB', str(db_path)) - monkeypatch.setattr(PersistentModelCache, '_instance', None, raising=False) + monkeypatch.setattr(PersistentModelCache, '_instances', {}, raising=False) first, _, _ = _create_files(tmp_path) scanner = DummyScanner(tmp_path) diff --git a/tests/services/test_settings_manager.py b/tests/services/test_settings_manager.py index 1951c7c1..56eef7b9 100644 --- a/tests/services/test_settings_manager.py +++ b/tests/services/test_settings_manager.py @@ -1,4 +1,5 @@ import json +import os import pytest @@ -88,3 +89,43 @@ def test_migrates_legacy_settings_file(tmp_path, monkeypatch): assert migrated_path == str(target_dir / "settings.json") assert (target_dir / "settings.json").exists() assert not legacy_file.exists() + + +def test_migrate_creates_default_library(manager): + libraries = manager.get_libraries() + assert "default" in libraries + assert manager.get_active_library_name() == "default" + assert libraries["default"].get("folder_paths", {}) == manager.settings.get("folder_paths", {}) + + +def test_upsert_library_creates_entry_and_activates(manager, tmp_path): + lora_dir = tmp_path / "loras" + lora_dir.mkdir() + + manager.upsert_library( + "studio", + folder_paths={"loras": [str(lora_dir)]}, + activate=True, + ) + + assert manager.get_active_library_name() == "studio" + libraries = manager.get_libraries() + stored_paths = libraries["studio"]["folder_paths"]["loras"] + assert str(lora_dir).replace(os.sep, "/") in stored_paths + + +def test_delete_library_switches_active(manager, tmp_path): + other_dir = tmp_path / "other" + other_dir.mkdir() + + manager.create_library( + "other", + folder_paths={"loras": [str(other_dir)]}, + activate=True, + ) + + assert manager.get_active_library_name() == "other" + + manager.delete_library("other") + + assert manager.get_active_library_name() == "default"