diff --git a/py/config.py b/py/config.py index c4e0c1ae..27876d9f 100644 --- a/py/config.py +++ b/py/config.py @@ -7,7 +7,7 @@ import logging import json import urllib.parse -from .utils.settings_paths import ensure_settings_file, load_settings_template +from .utils.settings_paths import ensure_settings_file, get_settings_dir, load_settings_template # Use an environment variable to control standalone mode standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" @@ -87,8 +87,7 @@ class Config: self.base_models_roots = self._init_checkpoint_paths() self.embeddings_roots = self._init_embedding_paths() # Scan symbolic links during initialization - self._scan_symbolic_links() - self._rebuild_preview_roots() + self._initialize_symlink_mappings() if not standalone_mode: # Save the paths to settings.json when running in ComfyUI mode @@ -220,39 +219,212 @@ class Config: logger.error(f"Error checking link status for {path}: {e}") return False + def _normalize_path(self, path: str) -> str: + return os.path.normpath(path).replace(os.sep, '/') + + def _get_symlink_cache_path(self) -> Path: + cache_dir = Path(get_settings_dir(create=True)) / "cache" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir / "symlink_map.json" + + def _compute_noise_mtime(self, root: str) -> Optional[int]: + """Return the latest mtime of known noisy paths inside ``root``.""" + + normalized_root = self._normalize_path(root) + noise_paths: List[str] = [] + + # The first LoRA root hosts recipes and stats files which routinely + # update without changing symlink layout. + first_lora_root = self._normalize_path(self.loras_roots[0]) if self.loras_roots else None + if first_lora_root and normalized_root == first_lora_root: + recipes_dir = os.path.join(root, "recipes") + stats_file = os.path.join(root, "lora_manager_stats.json") + noise_paths.extend([recipes_dir, stats_file]) + + mtimes: List[int] = [] + for path in noise_paths: + try: + stat_result = os.stat(path) + mtimes.append(getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1e9))) + except OSError: + continue + + if not mtimes: + return None + return max(mtimes) + + def _symlink_roots(self) -> List[str]: + roots: List[str] = [] + roots.extend(self.loras_roots or []) + roots.extend(self.base_models_roots or []) + roots.extend(self.embeddings_roots or []) + return roots + + def _build_symlink_fingerprint(self) -> Dict[str, object]: + roots = [self._normalize_path(path) for path in self._symlink_roots() if path] + unique_roots = sorted(set(roots)) + + stats: Dict[str, Dict[str, int]] = {} + for root in unique_roots: + try: + root_stat = os.stat(root) + noise_mtime = self._compute_noise_mtime(root) + stats[root] = { + "mtime_ns": getattr(root_stat, "st_mtime_ns", int(root_stat.st_mtime * 1e9)), + "inode": getattr(root_stat, "st_ino", 0), + "noise_mtime_ns": noise_mtime, + } + except OSError: + continue + + return {"roots": unique_roots, "stats": stats} + + def _load_symlink_cache(self) -> bool: + cache_path = self._get_symlink_cache_path() + if not cache_path.exists(): + return False + + try: + with cache_path.open("r", encoding="utf-8") as handle: + payload = json.load(handle) + except Exception as exc: + logger.debug("Failed to load symlink cache %s: %s", cache_path, exc) + return False + + if not isinstance(payload, dict): + return False + + cached_fingerprint = payload.get("fingerprint") + cached_mappings = payload.get("path_mappings") + if not isinstance(cached_fingerprint, dict) or not isinstance(cached_mappings, Mapping): + return False + + current_fingerprint = self._build_symlink_fingerprint() + cached_roots = cached_fingerprint.get("roots") + cached_stats = cached_fingerprint.get("stats") + if ( + not isinstance(cached_roots, list) + or not isinstance(cached_stats, Mapping) + or sorted(cached_roots) != sorted(current_fingerprint["roots"]) # type: ignore[index] + ): + return False + + for root in current_fingerprint["roots"]: # type: ignore[assignment] + cached_stat = cached_stats.get(root) if isinstance(cached_stats, Mapping) else None + current_stat = current_fingerprint["stats"].get(root) # type: ignore[index] + if not isinstance(cached_stat, Mapping) or not current_stat: + return False + + cached_mtime = cached_stat.get("mtime_ns") + cached_inode = cached_stat.get("inode") + current_mtime = current_stat.get("mtime_ns") + current_inode = current_stat.get("inode") + + if cached_inode != current_inode: + return False + + if cached_mtime != current_mtime: + cached_noise = cached_stat.get("noise_mtime_ns") + current_noise = current_stat.get("noise_mtime_ns") + if not ( + cached_noise + and current_noise + and cached_mtime == cached_noise + and current_mtime == current_noise + ): + return False + + normalized_mappings: Dict[str, str] = {} + for target, link in cached_mappings.items(): + if not isinstance(target, str) or not isinstance(link, str): + continue + normalized_mappings[self._normalize_path(target)] = self._normalize_path(link) + + self._path_mappings = normalized_mappings + return True + + def _save_symlink_cache(self) -> None: + cache_path = self._get_symlink_cache_path() + payload = { + "fingerprint": self._build_symlink_fingerprint(), + "path_mappings": self._path_mappings, + } + + try: + with cache_path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, ensure_ascii=False, indent=2) + except Exception as exc: + logger.debug("Failed to write symlink cache %s: %s", cache_path, exc) + + def _initialize_symlink_mappings(self) -> None: + if not self._load_symlink_cache(): + self._scan_symbolic_links() + self._save_symlink_cache() + else: + logger.info("Loaded symlink mappings from cache") + self._rebuild_preview_roots() + def _scan_symbolic_links(self): """Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories""" - for root in self.loras_roots: - self._scan_directory_links(root) - - for root in self.base_models_roots: - self._scan_directory_links(root) - - for root in self.embeddings_roots: - self._scan_directory_links(root) + visited_dirs: Set[str] = set() + for root in self._symlink_roots(): + self._scan_directory_links(root, visited_dirs) - def _scan_directory_links(self, root: str): - """Recursively scan symbolic links in a directory""" + def _scan_directory_links(self, root: str, visited_dirs: Set[str]): + """Iteratively scan directory symlinks to avoid deep recursion.""" try: - with os.scandir(root) as it: - for entry in it: - if self._is_link(entry.path): - target_path = os.path.realpath(entry.path) - if os.path.isdir(target_path): - self.add_path_mapping(entry.path, target_path) - self._scan_directory_links(target_path) - elif entry.is_dir(follow_symlinks=False): - self._scan_directory_links(entry.path) - except Exception as e: - logger.error(f"Error scanning links in {root}: {e}") + root_real = self._normalize_path(os.path.realpath(root)) + except OSError: + root_real = self._normalize_path(root) + + if root_real in visited_dirs: + return + + visited_dirs.add(root_real) + stack: List[str] = [root] + + while stack: + current = stack.pop() + try: + with os.scandir(current) as it: + for entry in it: + try: + entry_path = entry.path + if self._is_link(entry_path): + target_path = os.path.realpath(entry_path) + if not os.path.isdir(target_path): + continue + + normalized_target = self._normalize_path(target_path) + if normalized_target in visited_dirs: + continue + visited_dirs.add(normalized_target) + self.add_path_mapping(entry_path, target_path) + stack.append(target_path) + continue + + if not entry.is_dir(follow_symlinks=False): + continue + + normalized_real = self._normalize_path(os.path.realpath(entry_path)) + if normalized_real in visited_dirs: + continue + visited_dirs.add(normalized_real) + stack.append(entry_path) + except Exception as inner_exc: + logger.debug( + "Error processing directory entry %s: %s", entry.path, inner_exc + ) + except Exception as e: + logger.error(f"Error scanning links in {current}: {e}") def add_path_mapping(self, link_path: str, target_path: str): """Add a symbolic link path mapping target_path: actual target path link_path: symbolic link path """ - normalized_link = os.path.normpath(link_path).replace(os.sep, '/') - normalized_target = os.path.normpath(target_path).replace(os.sep, '/') + normalized_link = self._normalize_path(link_path) + normalized_target = self._normalize_path(target_path) # Keep the original mapping: target path -> link path self._path_mappings[normalized_target] = normalized_link logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}") @@ -411,8 +583,7 @@ class Config: self.base_models_roots = self._prepare_checkpoint_paths(checkpoint_paths, unet_paths) self.embeddings_roots = self._prepare_embedding_paths(embedding_paths) - self._scan_symbolic_links() - self._rebuild_preview_roots() + self._initialize_symlink_mappings() def _init_lora_paths(self) -> List[str]: """Initialize and validate LoRA paths from ComfyUI settings""" diff --git a/py/lora_manager.py b/py/lora_manager.py index 3a8c47c2..b6b1fccd 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -158,8 +158,6 @@ class LoraManager: # Add cleanup app.on_shutdown.append(cls._cleanup) - logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}") - @classmethod async def _initialize_services(cls): """Initialize all services using the ServiceRegistry""" diff --git a/py/services/model_service_factory.py b/py/services/model_service_factory.py index 4d655eed..0e1da069 100644 --- a/py/services/model_service_factory.py +++ b/py/services/model_service_factory.py @@ -22,7 +22,6 @@ class ModelServiceFactory: """ cls._services[model_type] = service_class cls._routes[model_type] = route_class - logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}") @classmethod def get_service_class(cls, model_type: str) -> Type: @@ -80,13 +79,10 @@ class ModelServiceFactory: Args: app: The aiohttp application instance """ - logger.info(f"Setting up routes for {len(cls._services)} registered model types") - for model_type in cls._services.keys(): try: routes_instance = cls.get_route_instance(model_type) routes_instance.setup_routes(app) - logger.info(f"Successfully set up routes for {model_type}") except Exception as e: logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True) @@ -137,6 +133,4 @@ def register_default_model_types(): ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes) # Register Embedding model type - ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes) - - logger.info("Registered default model types: lora, checkpoint, embedding") \ No newline at end of file + ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes) \ No newline at end of file diff --git a/tests/config/test_symlink_cache.py b/tests/config/test_symlink_cache.py new file mode 100644 index 00000000..b0e46ff7 --- /dev/null +++ b/tests/config/test_symlink_cache.py @@ -0,0 +1,111 @@ +import os + +import pytest + +from py import config as config_module + + +def _normalize(path: str) -> str: + return os.path.normpath(path).replace(os.sep, "/") + + +def _setup_paths(monkeypatch: pytest.MonkeyPatch, tmp_path): + settings_dir = tmp_path / "settings" + loras_dir = tmp_path / "loras" + loras_dir.mkdir() + checkpoint_dir = tmp_path / "checkpoints" + checkpoint_dir.mkdir() + embedding_dir = tmp_path / "embeddings" + embedding_dir.mkdir() + + def fake_get_folder_paths(kind: str): + mapping = { + "loras": [str(loras_dir)], + "checkpoints": [str(checkpoint_dir)], + "unet": [], + "embeddings": [str(embedding_dir)], + } + return mapping.get(kind, []) + + monkeypatch.setattr(config_module.folder_paths, "get_folder_paths", fake_get_folder_paths) + monkeypatch.setattr(config_module, "standalone_mode", True) + monkeypatch.setattr(config_module, "get_settings_dir", lambda create=True: str(settings_dir)) + + return loras_dir, settings_dir + + +def test_symlink_scan_skips_file_links(monkeypatch: pytest.MonkeyPatch, tmp_path): + loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path) + + target_dir = loras_dir / "target" + target_dir.mkdir() + dir_link = loras_dir / "dir_link" + dir_link.symlink_to(target_dir, target_is_directory=True) + + file_target = loras_dir / "model.safetensors" + file_target.write_text("content", encoding="utf-8") + file_link = loras_dir / "file_link" + file_link.symlink_to(file_target) + + cfg = config_module.Config() + + normalized_target_dir = _normalize(os.path.realpath(target_dir)) + normalized_link_dir = _normalize(str(dir_link)) + assert cfg._path_mappings[normalized_target_dir] == normalized_link_dir + + normalized_file_real = _normalize(os.path.realpath(file_target)) + assert normalized_file_real not in cfg._path_mappings + + cache_path = settings_dir / "cache" / "symlink_map.json" + assert cache_path.exists() + + +def test_symlink_cache_reuses_previous_scan(monkeypatch: pytest.MonkeyPatch, tmp_path): + loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path) + + target_dir = loras_dir / "target" + target_dir.mkdir() + dir_link = loras_dir / "dir_link" + dir_link.symlink_to(target_dir, target_is_directory=True) + + first_cfg = config_module.Config() + cached_mappings = dict(first_cfg._path_mappings) + cache_path = settings_dir / "cache" / "symlink_map.json" + assert cache_path.exists() + + def fail_scan(self): + raise AssertionError("Cache should bypass directory scan") + + monkeypatch.setattr(config_module.Config, "_scan_symbolic_links", fail_scan) + + second_cfg = config_module.Config() + assert second_cfg._path_mappings == cached_mappings + assert second_cfg.map_path_to_link(str(target_dir)) == _normalize(str(dir_link)) + + +def test_symlink_cache_survives_noise_mtime(monkeypatch: pytest.MonkeyPatch, tmp_path): + loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path) + + target_dir = loras_dir / "target" + target_dir.mkdir() + dir_link = loras_dir / "dir_link" + dir_link.symlink_to(target_dir, target_is_directory=True) + + recipes_dir = loras_dir / "recipes" + recipes_dir.mkdir() + noise_file = recipes_dir / "touchme.txt" + + first_cfg = config_module.Config() + cache_path = settings_dir / "cache" / "symlink_map.json" + assert cache_path.exists() + + # Update a noisy path to bump parent directory mtime + noise_file.write_text("hi", encoding="utf-8") + + def fail_scan(self): + raise AssertionError("Cache should bypass directory scan despite noise mtime") + + monkeypatch.setattr(config_module.Config, "_scan_symbolic_links", fail_scan) + + second_cfg = config_module.Config() + assert second_cfg.map_path_to_link(str(target_dir)) == _normalize(str(dir_link))