From 99bdf9a3b3462d58d27655eb8d91f81800c2dadd Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sun, 23 Feb 2025 06:25:05 +0800 Subject: [PATCH] Fix symlink checkpoint2 --- config.py | 115 +++++++++++++++++++++++++++++++++++++-- lora_manager.py | 26 +++++++++ services/file_monitor.py | 91 +++---------------------------- 3 files changed, 143 insertions(+), 89 deletions(-) diff --git a/config.py b/config.py index 57e00f39..23f821b1 100644 --- a/config.py +++ b/config.py @@ -1,14 +1,90 @@ import os +import platform import folder_paths # type: ignore from typing import List +import logging + +logger = logging.getLogger(__name__) class Config: """Global configuration for LoRA Manager""" def __init__(self): - self.loras_roots = self._init_lora_paths() self.templates_path = os.path.join(os.path.dirname(__file__), 'templates') self.static_path = os.path.join(os.path.dirname(__file__), 'static') + # 路径映射字典, target to link mapping + self._path_mappings = {} + # 静态路由映射字典, target to route mapping + self._route_mappings = {} + self.loras_roots = self._init_lora_paths() + # 在初始化时扫描符号链接 + self._scan_symbolic_links() + + def _is_link(self, path: str) -> bool: + try: + if os.path.islink(path): + return True + if platform.system() == 'Windows': + try: + import ctypes + FILE_ATTRIBUTE_REPARSE_POINT = 0x400 + attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path)) + return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT) + except Exception as e: + logger.error(f"Error checking Windows reparse point: {e}") + return False + except Exception as e: + logger.error(f"Error checking link status for {path}: {e}") + return False + + def _scan_symbolic_links(self): + """扫描所有 LoRA 根目录中的符号链接""" + for root in self.loras_roots: + self._scan_directory_links(root) + + def _scan_directory_links(self, root: str): + """递归扫描目录中的符号链接""" + try: + with os.scandir(root) as it: + for entry in it: + if self._is_link(entry.path): + target_path = os.path.realpath(entry.path) + if os.path.isdir(target_path): + self.add_path_mapping(entry.path, target_path) + self._scan_directory_links(target_path) + elif entry.is_dir(follow_symlinks=False): + self._scan_directory_links(entry.path) + except Exception as e: + logger.error(f"Error scanning links in {root}: {e}") + + def add_path_mapping(self, link_path: str, target_path: str): + """添加符号链接路径映射 + target_path: 实际目标路径 + link_path: 符号链接路径 + """ + normalized_link = os.path.normpath(link_path).replace(os.sep, '/') + normalized_target = os.path.normpath(target_path).replace(os.sep, '/') + # 保持原有的映射关系:目标路径 -> 链接路径 + self._path_mappings[normalized_target] = normalized_link + logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}") + + def add_route_mapping(self, path: str, route: str): + """添加静态路由映射""" + normalized_path = os.path.normpath(path).replace(os.sep, '/') + self._route_mappings[normalized_path] = route + logger.info(f"Added route mapping: {normalized_path} -> {route}") + + def map_path_to_link(self, path: str) -> str: + """将目标路径映射回符号链接路径""" + normalized_path = os.path.normpath(path).replace(os.sep, '/') + # 检查路径是否包含在任何映射的目标路径中 + for target_path, link_path in self._path_mappings.items(): + if normalized_path.startswith(target_path): + # 如果路径以目标路径开头,则替换为链接路径 + mapped_path = normalized_path.replace(target_path, link_path, 1) + logger.info(f"Mapped path {normalized_path} to {mapped_path}") + return mapped_path + return path def _init_lora_paths(self) -> List[str]: """Initialize and validate LoRA paths from ComfyUI settings""" @@ -20,17 +96,44 @@ class Config: if not paths: raise ValueError("No valid loras folders found in ComfyUI configuration") + # 初始化路径映射 + for path in paths: + real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') + if real_path != path: + self.add_path_mapping(real_path, path) + return paths def get_preview_static_url(self, preview_path: str) -> str: """Convert local preview path to static URL""" if not preview_path: return "" - - for idx, root in enumerate(self.loras_roots, start=1): - if preview_path.startswith(root): - relative_path = os.path.relpath(preview_path, root) - return f'/loras_static/root{idx}/preview/{relative_path.replace(os.sep, "/")}' + + # 获取真实路径和规范化路径 + real_preview_path = os.path.realpath(preview_path) + normalized_real_path = real_preview_path.replace(os.sep, '/') + normalized_preview_path = preview_path.replace(os.sep, '/') + + # 首先尝试使用原始路径查找路由 + for root_path, route in self._route_mappings.items(): + if normalized_preview_path.startswith(root_path): + relative_path = os.path.relpath(normalized_preview_path, root_path) + return f'{route}/{relative_path.replace(os.sep, "/")}' + + # 如果没找到,尝试使用真实路径查找路由 + for root_path, route in self._route_mappings.items(): + if normalized_real_path.startswith(root_path): + relative_path = os.path.relpath(normalized_real_path, root_path) + return f'{route}/{relative_path.replace(os.sep, "/")}' + + # 如果还没找到,尝试使用路径映射 + mapped_path = self.map_path_to_link(real_preview_path) + normalized_mapped_path = mapped_path.replace(os.sep, '/') + + for root_path, route in self._route_mappings.items(): + if normalized_mapped_path.startswith(root_path): + relative_path = os.path.relpath(normalized_mapped_path, root_path) + return f'{route}/{relative_path.replace(os.sep, "/")}' return "" diff --git a/lora_manager.py b/lora_manager.py index da80ec5b..3e32fc90 100644 --- a/lora_manager.py +++ b/lora_manager.py @@ -1,4 +1,5 @@ import asyncio +import os from server import PromptServer # type: ignore from .config import config from .routes.lora_routes import LoraRoutes @@ -6,6 +7,9 @@ from .routes.api_routes import ApiRoutes from .services.lora_scanner import LoraScanner from .services.file_monitor import LoraFileMonitor from .services.lora_cache import LoraCache +import logging + +logger = logging.getLogger(__name__) class LoraManager: """Main entry point for LoRA Manager plugin""" @@ -14,11 +18,33 @@ class LoraManager: def add_routes(cls): """Initialize and register all routes""" app = PromptServer.instance.app + + added_targets = set() # 用于跟踪已添加的目标路径 # Add static routes for each lora root for idx, root in enumerate(config.loras_roots, start=1): preview_path = f'/loras_static/root{idx}/preview' + + # 为原始路径添加静态路由 app.router.add_static(preview_path, root) + logger.info(f"Added static route {preview_path} -> {root}") + + # 记录路由映射 + config.add_route_mapping(root, preview_path) + added_targets.add(root) + + # 为符号链接的目标路径添加额外的静态路由 + link_idx = 1 + + for target_path, link_path in config._path_mappings.items(): + if target_path not in added_targets: + route_path = f'/loras_static/link_{link_idx}/preview' + app.router.add_static(route_path, target_path) + logger.info(f"Added static route for link target {route_path} -> {target_path}") + config.add_route_mapping(target_path, route_path) + config.add_route_mapping(link_path, route_path) # 也为符号链接路径添加路由映射 + added_targets.add(target_path) + link_idx += 1 # Add static route for plugin assets app.router.add_static('/loras_static', config.static_path) diff --git a/services/file_monitor.py b/services/file_monitor.py index 09a03882..6fd7027f 100644 --- a/services/file_monitor.py +++ b/services/file_monitor.py @@ -8,6 +8,7 @@ from typing import List from threading import Lock from .lora_scanner import LoraScanner import platform +from ..config import config logger = logging.getLogger(__name__) @@ -23,7 +24,6 @@ class LoraFileHandler(FileSystemEventHandler): self._ignore_paths = set() # Add ignore paths set self._min_ignore_timeout = 5 # minimum timeout in seconds self._download_speed = 1024 * 1024 # assume 1MB/s as base speed - self._path_mappings = {} # 添加路径映射字典 def _should_ignore(self, path: str) -> bool: """Check if path should be ignored""" @@ -50,23 +50,6 @@ class LoraFileHandler(FileSystemEventHandler): real_path.replace(os.sep, '/') ) - def add_path_mapping(self, link_path: str, target_path: str): - """添加符号链接路径映射""" - normalized_link = os.path.normpath(link_path).replace(os.sep, '/') - normalized_target = os.path.normpath(target_path).replace(os.sep, '/') - self._path_mappings[normalized_target] = normalized_link - logger.debug(f"Added path mapping: {normalized_target} -> {normalized_link}") - - def _map_path_to_link(self, path: str) -> str: - """将目标路径映射回符号链接路径""" - normalized_path = os.path.normpath(path).replace(os.sep, '/') - for target_prefix, link_prefix in self._path_mappings.items(): - if normalized_path.startswith(target_prefix): - mapped_path = normalized_path.replace(target_prefix, link_prefix, 1) - logger.debug(f"Mapped path {normalized_path} to {mapped_path}") - return mapped_path - return path - def on_created(self, event): if event.is_directory or not event.src_path.endswith('.safetensors'): return @@ -86,8 +69,8 @@ class LoraFileHandler(FileSystemEventHandler): def _schedule_update(self, action: str, file_path: str): """Schedule a cache update""" with self.lock: - # 将目标路径映射回符号链接路径 - mapped_path = self._map_path_to_link(file_path) + # 使用 config 中的方法映射路径 + mapped_path = config.map_path_to_link(file_path) normalized_path = mapped_path.replace(os.sep, '/') self.pending_changes.add((action, normalized_path)) @@ -159,72 +142,14 @@ class LoraFileMonitor: self.loop = asyncio.get_event_loop() self.handler = LoraFileHandler(scanner, self.loop) - # 存储所有需要监控的路径(包括链接的目标路径) + # 使用已存在的路径映射 self.monitor_paths = set() for root in roots: - real_root = os.path.realpath(root) - self.monitor_paths.add(real_root) - # 扫描根目录下的链接 - self._add_link_targets(root) + self.monitor_paths.add(os.path.realpath(root)) - def _is_link(self, path: str) -> bool: - """ - 检查路径是否为链接 - 支持: - - Windows: Symbolic Links, Junction Points - - Linux: Symbolic Links - """ - try: - # 首先检查通用的符号链接 - if os.path.islink(path): - return True - - # Windows 特定的 Junction Points 检测 - if platform.system() == 'Windows': - try: - import ctypes - FILE_ATTRIBUTE_REPARSE_POINT = 0x400 - attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path)) - return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT) - except Exception as e: - logger.error(f"Error checking Windows reparse point: {e}") - - return False - - except Exception as e: - logger.error(f"Error checking link status for {path}: {e}") - return False - - def _get_link_target(self, path: str) -> str: - """获取链接目标路径""" - try: - return os.path.realpath(path) - except Exception as e: - logger.error(f"Error resolving link target for {path}: {e}") - return path - - def _add_link_targets(self, root: str): - """递归扫描目录,添加链接指向的目标路径""" - try: - with os.scandir(root) as it: - for entry in it: - logger.debug(f"Checking path: {entry.path}") - if self._is_link(entry.path): - # 获取链接的目标路径 - target_path = self._get_link_target(entry.path) - if os.path.isdir(target_path): - normalized_target = os.path.normpath(target_path) - self.monitor_paths.add(normalized_target) - # 添加路径映射到处理器 - self.handler.add_path_mapping(entry.path, target_path) - logger.info(f"Found link: {entry.path} -> {normalized_target}") - # 递归扫描目标目录中的链接 - self._add_link_targets(target_path) - elif entry.is_dir(follow_symlinks=False): - # 递归扫描子目录 - self._add_link_targets(entry.path) - except Exception as e: - logger.error(f"Error scanning links in {root}: {e}") + # 添加所有已映射的目标路径 + for target_path in config._path_mappings.keys(): + self.monitor_paths.add(target_path) def start(self): """Start monitoring"""