mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
Fix symlink checkpoint2
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import List
|
||||
from threading import Lock
|
||||
from .lora_scanner import LoraScanner
|
||||
import platform
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,7 +24,6 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
self._ignore_paths = set() # Add ignore paths set
|
||||
self._min_ignore_timeout = 5 # minimum timeout in seconds
|
||||
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
|
||||
self._path_mappings = {} # 添加路径映射字典
|
||||
|
||||
def _should_ignore(self, path: str) -> bool:
|
||||
"""Check if path should be ignored"""
|
||||
@@ -50,23 +50,6 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
real_path.replace(os.sep, '/')
|
||||
)
|
||||
|
||||
def add_path_mapping(self, link_path: str, target_path: str):
|
||||
"""添加符号链接路径映射"""
|
||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||
normalized_target = os.path.normpath(target_path).replace(os.sep, '/')
|
||||
self._path_mappings[normalized_target] = normalized_link
|
||||
logger.debug(f"Added path mapping: {normalized_target} -> {normalized_link}")
|
||||
|
||||
def _map_path_to_link(self, path: str) -> str:
|
||||
"""将目标路径映射回符号链接路径"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
for target_prefix, link_prefix in self._path_mappings.items():
|
||||
if normalized_path.startswith(target_prefix):
|
||||
mapped_path = normalized_path.replace(target_prefix, link_prefix, 1)
|
||||
logger.debug(f"Mapped path {normalized_path} to {mapped_path}")
|
||||
return mapped_path
|
||||
return path
|
||||
|
||||
def on_created(self, event):
|
||||
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
||||
return
|
||||
@@ -86,8 +69,8 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
def _schedule_update(self, action: str, file_path: str):
|
||||
"""Schedule a cache update"""
|
||||
with self.lock:
|
||||
# 将目标路径映射回符号链接路径
|
||||
mapped_path = self._map_path_to_link(file_path)
|
||||
# 使用 config 中的方法映射路径
|
||||
mapped_path = config.map_path_to_link(file_path)
|
||||
normalized_path = mapped_path.replace(os.sep, '/')
|
||||
self.pending_changes.add((action, normalized_path))
|
||||
|
||||
@@ -159,72 +142,14 @@ class LoraFileMonitor:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.handler = LoraFileHandler(scanner, self.loop)
|
||||
|
||||
# 存储所有需要监控的路径(包括链接的目标路径)
|
||||
# 使用已存在的路径映射
|
||||
self.monitor_paths = set()
|
||||
for root in roots:
|
||||
real_root = os.path.realpath(root)
|
||||
self.monitor_paths.add(real_root)
|
||||
# 扫描根目录下的链接
|
||||
self._add_link_targets(root)
|
||||
self.monitor_paths.add(os.path.realpath(root))
|
||||
|
||||
def _is_link(self, path: str) -> bool:
|
||||
"""
|
||||
检查路径是否为链接
|
||||
支持:
|
||||
- Windows: Symbolic Links, Junction Points
|
||||
- Linux: Symbolic Links
|
||||
"""
|
||||
try:
|
||||
# 首先检查通用的符号链接
|
||||
if os.path.islink(path):
|
||||
return True
|
||||
|
||||
# Windows 特定的 Junction Points 检测
|
||||
if platform.system() == 'Windows':
|
||||
try:
|
||||
import ctypes
|
||||
FILE_ATTRIBUTE_REPARSE_POINT = 0x400
|
||||
attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path))
|
||||
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking Windows reparse point: {e}")
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking link status for {path}: {e}")
|
||||
return False
|
||||
|
||||
def _get_link_target(self, path: str) -> str:
|
||||
"""获取链接目标路径"""
|
||||
try:
|
||||
return os.path.realpath(path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error resolving link target for {path}: {e}")
|
||||
return path
|
||||
|
||||
def _add_link_targets(self, root: str):
|
||||
"""递归扫描目录,添加链接指向的目标路径"""
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
logger.debug(f"Checking path: {entry.path}")
|
||||
if self._is_link(entry.path):
|
||||
# 获取链接的目标路径
|
||||
target_path = self._get_link_target(entry.path)
|
||||
if os.path.isdir(target_path):
|
||||
normalized_target = os.path.normpath(target_path)
|
||||
self.monitor_paths.add(normalized_target)
|
||||
# 添加路径映射到处理器
|
||||
self.handler.add_path_mapping(entry.path, target_path)
|
||||
logger.info(f"Found link: {entry.path} -> {normalized_target}")
|
||||
# 递归扫描目标目录中的链接
|
||||
self._add_link_targets(target_path)
|
||||
elif entry.is_dir(follow_symlinks=False):
|
||||
# 递归扫描子目录
|
||||
self._add_link_targets(entry.path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning links in {root}: {e}")
|
||||
# 添加所有已映射的目标路径
|
||||
for target_path in config._path_mappings.keys():
|
||||
self.monitor_paths.add(target_path)
|
||||
|
||||
def start(self):
|
||||
"""Start monitoring"""
|
||||
|
||||
Reference in New Issue
Block a user