From cb054953aaf0e4c883a7f8fff3f41691b80a9096 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sat, 22 Feb 2025 20:28:47 +0800 Subject: [PATCH] Fix symlink checkpoint1 --- services/file_monitor.py | 136 +++++++++++++++++++++++++++++++++++---- services/lora_scanner.py | 65 ++++++++++--------- 2 files changed, 157 insertions(+), 44 deletions(-) diff --git a/services/file_monitor.py b/services/file_monitor.py index 3387ea8b..09a03882 100644 --- a/services/file_monitor.py +++ b/services/file_monitor.py @@ -7,6 +7,7 @@ from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDelete from typing import List from threading import Lock from .lora_scanner import LoraScanner +import platform logger = logging.getLogger(__name__) @@ -22,14 +23,17 @@ class LoraFileHandler(FileSystemEventHandler): self._ignore_paths = set() # Add ignore paths set self._min_ignore_timeout = 5 # minimum timeout in seconds self._download_speed = 1024 * 1024 # assume 1MB/s as base speed + self._path_mappings = {} # 添加路径映射字典 def _should_ignore(self, path: str) -> bool: """Check if path should be ignored""" - return path.replace(os.sep, '/') in self._ignore_paths + real_path = os.path.realpath(path) # Resolve any symbolic links + return real_path.replace(os.sep, '/') in self._ignore_paths def add_ignore_path(self, path: str, file_size: int = 0): """Add path to ignore list with dynamic timeout based on file size""" - self._ignore_paths.add(path.replace(os.sep, '/')) + real_path = os.path.realpath(path) # Resolve any symbolic links + self._ignore_paths.add(real_path.replace(os.sep, '/')) # Calculate timeout based on file size, with a minimum value # Assuming average download speed of 1MB/s @@ -38,14 +42,31 @@ class LoraFileHandler(FileSystemEventHandler): (file_size / self._download_speed) * 1.5 # Add 50% buffer ) - logger.debug(f"Adding {path} to ignore list for {timeout:.1f} seconds") + logger.debug(f"Adding {real_path} to ignore list for {timeout:.1f} seconds") asyncio.get_event_loop().call_later( timeout, self._ignore_paths.discard, - path + real_path.replace(os.sep, '/') ) + def add_path_mapping(self, link_path: str, target_path: str): + """添加符号链接路径映射""" + normalized_link = os.path.normpath(link_path).replace(os.sep, '/') + normalized_target = os.path.normpath(target_path).replace(os.sep, '/') + self._path_mappings[normalized_target] = normalized_link + logger.debug(f"Added path mapping: {normalized_target} -> {normalized_link}") + + def _map_path_to_link(self, path: str) -> str: + """将目标路径映射回符号链接路径""" + normalized_path = os.path.normpath(path).replace(os.sep, '/') + for target_prefix, link_prefix in self._path_mappings.items(): + if normalized_path.startswith(target_prefix): + mapped_path = normalized_path.replace(target_prefix, link_prefix, 1) + logger.debug(f"Mapped path {normalized_path} to {mapped_path}") + return mapped_path + return path + def on_created(self, event): if event.is_directory or not event.src_path.endswith('.safetensors'): return @@ -65,11 +86,11 @@ class LoraFileHandler(FileSystemEventHandler): def _schedule_update(self, action: str, file_path: str): """Schedule a cache update""" with self.lock: - # 标准化路径 - normalized_path = file_path.replace(os.sep, '/') + # 将目标路径映射回符号链接路径 + mapped_path = self._map_path_to_link(file_path) + normalized_path = mapped_path.replace(os.sep, '/') self.pending_changes.add((action, normalized_path)) - # 使用 call_soon_threadsafe 在事件循环中安排任务 self.loop.call_soon_threadsafe(self._create_update_task) def _create_update_task(self): @@ -134,24 +155,111 @@ class LoraFileMonitor: def __init__(self, scanner: LoraScanner, roots: List[str]): self.scanner = scanner scanner.set_file_monitor(self) - self.roots = roots self.observer = Observer() - # 获取当前运行的事件循环 self.loop = asyncio.get_event_loop() self.handler = LoraFileHandler(scanner, self.loop) + # 存储所有需要监控的路径(包括链接的目标路径) + self.monitor_paths = set() + for root in roots: + real_root = os.path.realpath(root) + self.monitor_paths.add(real_root) + # 扫描根目录下的链接 + self._add_link_targets(root) + + def _is_link(self, path: str) -> bool: + """ + 检查路径是否为链接 + 支持: + - Windows: Symbolic Links, Junction Points + - Linux: Symbolic Links + """ + try: + # 首先检查通用的符号链接 + if os.path.islink(path): + return True + + # Windows 特定的 Junction Points 检测 + if platform.system() == 'Windows': + try: + import ctypes + FILE_ATTRIBUTE_REPARSE_POINT = 0x400 + attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path)) + return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT) + except Exception as e: + logger.error(f"Error checking Windows reparse point: {e}") + + return False + + except Exception as e: + logger.error(f"Error checking link status for {path}: {e}") + return False + + def _get_link_target(self, path: str) -> str: + """获取链接目标路径""" + try: + return os.path.realpath(path) + except Exception as e: + logger.error(f"Error resolving link target for {path}: {e}") + return path + + def _add_link_targets(self, root: str): + """递归扫描目录,添加链接指向的目标路径""" + try: + with os.scandir(root) as it: + for entry in it: + logger.debug(f"Checking path: {entry.path}") + if self._is_link(entry.path): + # 获取链接的目标路径 + target_path = self._get_link_target(entry.path) + if os.path.isdir(target_path): + normalized_target = os.path.normpath(target_path) + self.monitor_paths.add(normalized_target) + # 添加路径映射到处理器 + self.handler.add_path_mapping(entry.path, target_path) + logger.info(f"Found link: {entry.path} -> {normalized_target}") + # 递归扫描目标目录中的链接 + self._add_link_targets(target_path) + elif entry.is_dir(follow_symlinks=False): + # 递归扫描子目录 + self._add_link_targets(entry.path) + except Exception as e: + logger.error(f"Error scanning links in {root}: {e}") + def start(self): """Start monitoring""" - for root in self.roots: + for path_info in self.monitor_paths: try: - self.observer.schedule(self.handler, root, recursive=True) - logger.info(f"Started monitoring: {root}") + if isinstance(path_info, tuple): + # 对于链接,监控目标路径 + _, target_path = path_info + self.observer.schedule(self.handler, target_path, recursive=True) + logger.info(f"Started monitoring target path: {target_path}") + else: + # 对于普通路径,直接监控 + self.observer.schedule(self.handler, path_info, recursive=True) + logger.info(f"Started monitoring: {path_info}") except Exception as e: - logger.error(f"Error monitoring {root}: {e}") + logger.error(f"Error monitoring {path_info}: {e}") self.observer.start() def stop(self): """Stop monitoring""" self.observer.stop() - self.observer.join() \ No newline at end of file + self.observer.join() + + def rescan_links(self): + """重新扫描链接(当添加新的链接时调用)""" + new_paths = set() + for path in self.monitor_paths.copy(): + self._add_link_targets(path) + + # 添加新发现的路径到监控 + new_paths = self.monitor_paths - set(self.observer.watches.keys()) + for path in new_paths: + try: + self.observer.schedule(self.handler, path, recursive=True) + logger.info(f"Added new monitoring path: {path}") + except Exception as e: + logger.error(f"Error adding new monitor for {path}: {e}") \ No newline at end of file diff --git a/services/lora_scanner.py b/services/lora_scanner.py index 5f6b09c6..ef7cc209 100644 --- a/services/lora_scanner.py +++ b/services/lora_scanner.py @@ -231,23 +231,35 @@ class LoraScanner: async def _scan_directory(self, root_path: str) -> List[Dict]: """Scan a single directory for LoRA files""" loras = [] + original_root = root_path # 保存原始根路径 - # 使用异步安全的目录遍历方式 - async def scan_recursive(path: str): + async def scan_recursive(path: str, visited_paths: set): + """递归扫描目录,避免循环链接""" try: + real_path = os.path.realpath(path) + if real_path in visited_paths: + logger.debug(f"Skipping already visited path: {path}") + return + visited_paths.add(real_path) + with os.scandir(path) as it: - entries = list(it) # 同步获取目录条目 + entries = list(it) for entry in entries: - if entry.is_file() and entry.name.endswith('.safetensors'): - file_path = entry.path.replace(os.sep, "/") - await self._process_single_file(file_path, root_path, loras) - await asyncio.sleep(0) # 释放事件循环 - elif entry.is_dir(): - await scan_recursive(entry.path) + try: + if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'): + # 使用原始路径而不是真实路径 + file_path = entry.path.replace(os.sep, "/") + await self._process_single_file(file_path, original_root, loras) + await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + # 对于目录,使用原始路径继续扫描 + await scan_recursive(entry.path, visited_paths) + except Exception as e: + logger.error(f"Error processing entry {entry.path}: {e}") except Exception as e: logger.error(f"Error scanning {path}: {e}") - await scan_recursive(root_path) + await scan_recursive(root_path, set()) return loras async def _process_single_file(self, file_path: str, root_path: str, loras: list): @@ -316,6 +328,7 @@ class LoraScanner: def _calculate_folder(self, file_path: str) -> str: """Calculate the folder path for a LoRA file""" + # 使用原始路径计算相对路径 for root in config.loras_roots: if file_path.startswith(root): rel_path = os.path.relpath(file_path, root) @@ -323,46 +336,38 @@ class LoraScanner: return '' async def move_model(self, source_path: str, target_path: str) -> bool: - """Move a model and its associated files to a new location - - Args: - source_path: Full path to the source lora file - target_path: Full path to the target directory - - Returns: - bool: True if successful, False otherwise - """ + """Move a model and its associated files to a new location""" try: - # Ensure paths are normalized + # 保持原始路径格式 source_path = source_path.replace(os.sep, '/') target_path = target_path.replace(os.sep, '/') - # Get base name without extension + # 其余代码保持不变 base_name = os.path.splitext(os.path.basename(source_path))[0] source_dir = os.path.dirname(source_path) - # Create target directory if it doesn't exist os.makedirs(target_path, exist_ok=True) - # Calculate target lora path target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/') - # Get source file size for timeout calculation - file_size = os.path.getsize(source_path) + # 使用真实路径进行文件操作 + real_source = os.path.realpath(source_path) + real_target = os.path.realpath(target_lora) + + file_size = os.path.getsize(real_source) - # Tell file monitor to ignore these paths if self.file_monitor: self.file_monitor.handler.add_ignore_path( - source_path, + real_source, file_size ) self.file_monitor.handler.add_ignore_path( - target_lora, + real_target, file_size ) - # Move main lora file - shutil.move(source_path, target_lora) + # 使用真实路径进行文件操作 + shutil.move(real_source, real_target) # Move associated files source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")