mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Fix symlink checkpoint1
This commit is contained in:
@@ -7,6 +7,7 @@ from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDelete
|
|||||||
from typing import List
|
from typing import List
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from .lora_scanner import LoraScanner
|
from .lora_scanner import LoraScanner
|
||||||
|
import platform
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -22,14 +23,17 @@ class LoraFileHandler(FileSystemEventHandler):
|
|||||||
self._ignore_paths = set() # Add ignore paths set
|
self._ignore_paths = set() # Add ignore paths set
|
||||||
self._min_ignore_timeout = 5 # minimum timeout in seconds
|
self._min_ignore_timeout = 5 # minimum timeout in seconds
|
||||||
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
|
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
|
||||||
|
self._path_mappings = {} # 添加路径映射字典
|
||||||
|
|
||||||
def _should_ignore(self, path: str) -> bool:
|
def _should_ignore(self, path: str) -> bool:
|
||||||
"""Check if path should be ignored"""
|
"""Check if path should be ignored"""
|
||||||
return path.replace(os.sep, '/') in self._ignore_paths
|
real_path = os.path.realpath(path) # Resolve any symbolic links
|
||||||
|
return real_path.replace(os.sep, '/') in self._ignore_paths
|
||||||
|
|
||||||
def add_ignore_path(self, path: str, file_size: int = 0):
|
def add_ignore_path(self, path: str, file_size: int = 0):
|
||||||
"""Add path to ignore list with dynamic timeout based on file size"""
|
"""Add path to ignore list with dynamic timeout based on file size"""
|
||||||
self._ignore_paths.add(path.replace(os.sep, '/'))
|
real_path = os.path.realpath(path) # Resolve any symbolic links
|
||||||
|
self._ignore_paths.add(real_path.replace(os.sep, '/'))
|
||||||
|
|
||||||
# Calculate timeout based on file size, with a minimum value
|
# Calculate timeout based on file size, with a minimum value
|
||||||
# Assuming average download speed of 1MB/s
|
# Assuming average download speed of 1MB/s
|
||||||
@@ -38,14 +42,31 @@ class LoraFileHandler(FileSystemEventHandler):
|
|||||||
(file_size / self._download_speed) * 1.5 # Add 50% buffer
|
(file_size / self._download_speed) * 1.5 # Add 50% buffer
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Adding {path} to ignore list for {timeout:.1f} seconds")
|
logger.debug(f"Adding {real_path} to ignore list for {timeout:.1f} seconds")
|
||||||
|
|
||||||
asyncio.get_event_loop().call_later(
|
asyncio.get_event_loop().call_later(
|
||||||
timeout,
|
timeout,
|
||||||
self._ignore_paths.discard,
|
self._ignore_paths.discard,
|
||||||
path
|
real_path.replace(os.sep, '/')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def add_path_mapping(self, link_path: str, target_path: str):
|
||||||
|
"""添加符号链接路径映射"""
|
||||||
|
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||||
|
normalized_target = os.path.normpath(target_path).replace(os.sep, '/')
|
||||||
|
self._path_mappings[normalized_target] = normalized_link
|
||||||
|
logger.debug(f"Added path mapping: {normalized_target} -> {normalized_link}")
|
||||||
|
|
||||||
|
def _map_path_to_link(self, path: str) -> str:
|
||||||
|
"""将目标路径映射回符号链接路径"""
|
||||||
|
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||||
|
for target_prefix, link_prefix in self._path_mappings.items():
|
||||||
|
if normalized_path.startswith(target_prefix):
|
||||||
|
mapped_path = normalized_path.replace(target_prefix, link_prefix, 1)
|
||||||
|
logger.debug(f"Mapped path {normalized_path} to {mapped_path}")
|
||||||
|
return mapped_path
|
||||||
|
return path
|
||||||
|
|
||||||
def on_created(self, event):
|
def on_created(self, event):
|
||||||
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
||||||
return
|
return
|
||||||
@@ -65,11 +86,11 @@ class LoraFileHandler(FileSystemEventHandler):
|
|||||||
def _schedule_update(self, action: str, file_path: str):
|
def _schedule_update(self, action: str, file_path: str):
|
||||||
"""Schedule a cache update"""
|
"""Schedule a cache update"""
|
||||||
with self.lock:
|
with self.lock:
|
||||||
# 标准化路径
|
# 将目标路径映射回符号链接路径
|
||||||
normalized_path = file_path.replace(os.sep, '/')
|
mapped_path = self._map_path_to_link(file_path)
|
||||||
|
normalized_path = mapped_path.replace(os.sep, '/')
|
||||||
self.pending_changes.add((action, normalized_path))
|
self.pending_changes.add((action, normalized_path))
|
||||||
|
|
||||||
# 使用 call_soon_threadsafe 在事件循环中安排任务
|
|
||||||
self.loop.call_soon_threadsafe(self._create_update_task)
|
self.loop.call_soon_threadsafe(self._create_update_task)
|
||||||
|
|
||||||
def _create_update_task(self):
|
def _create_update_task(self):
|
||||||
@@ -134,20 +155,92 @@ class LoraFileMonitor:
|
|||||||
def __init__(self, scanner: LoraScanner, roots: List[str]):
|
def __init__(self, scanner: LoraScanner, roots: List[str]):
|
||||||
self.scanner = scanner
|
self.scanner = scanner
|
||||||
scanner.set_file_monitor(self)
|
scanner.set_file_monitor(self)
|
||||||
self.roots = roots
|
|
||||||
self.observer = Observer()
|
self.observer = Observer()
|
||||||
# 获取当前运行的事件循环
|
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
self.handler = LoraFileHandler(scanner, self.loop)
|
self.handler = LoraFileHandler(scanner, self.loop)
|
||||||
|
|
||||||
|
# 存储所有需要监控的路径(包括链接的目标路径)
|
||||||
|
self.monitor_paths = set()
|
||||||
|
for root in roots:
|
||||||
|
real_root = os.path.realpath(root)
|
||||||
|
self.monitor_paths.add(real_root)
|
||||||
|
# 扫描根目录下的链接
|
||||||
|
self._add_link_targets(root)
|
||||||
|
|
||||||
|
def _is_link(self, path: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查路径是否为链接
|
||||||
|
支持:
|
||||||
|
- Windows: Symbolic Links, Junction Points
|
||||||
|
- Linux: Symbolic Links
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 首先检查通用的符号链接
|
||||||
|
if os.path.islink(path):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Windows 特定的 Junction Points 检测
|
||||||
|
if platform.system() == 'Windows':
|
||||||
|
try:
|
||||||
|
import ctypes
|
||||||
|
FILE_ATTRIBUTE_REPARSE_POINT = 0x400
|
||||||
|
attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path))
|
||||||
|
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking Windows reparse point: {e}")
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking link status for {path}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _get_link_target(self, path: str) -> str:
|
||||||
|
"""获取链接目标路径"""
|
||||||
|
try:
|
||||||
|
return os.path.realpath(path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error resolving link target for {path}: {e}")
|
||||||
|
return path
|
||||||
|
|
||||||
|
def _add_link_targets(self, root: str):
|
||||||
|
"""递归扫描目录,添加链接指向的目标路径"""
|
||||||
|
try:
|
||||||
|
with os.scandir(root) as it:
|
||||||
|
for entry in it:
|
||||||
|
logger.debug(f"Checking path: {entry.path}")
|
||||||
|
if self._is_link(entry.path):
|
||||||
|
# 获取链接的目标路径
|
||||||
|
target_path = self._get_link_target(entry.path)
|
||||||
|
if os.path.isdir(target_path):
|
||||||
|
normalized_target = os.path.normpath(target_path)
|
||||||
|
self.monitor_paths.add(normalized_target)
|
||||||
|
# 添加路径映射到处理器
|
||||||
|
self.handler.add_path_mapping(entry.path, target_path)
|
||||||
|
logger.info(f"Found link: {entry.path} -> {normalized_target}")
|
||||||
|
# 递归扫描目标目录中的链接
|
||||||
|
self._add_link_targets(target_path)
|
||||||
|
elif entry.is_dir(follow_symlinks=False):
|
||||||
|
# 递归扫描子目录
|
||||||
|
self._add_link_targets(entry.path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning links in {root}: {e}")
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Start monitoring"""
|
"""Start monitoring"""
|
||||||
for root in self.roots:
|
for path_info in self.monitor_paths:
|
||||||
try:
|
try:
|
||||||
self.observer.schedule(self.handler, root, recursive=True)
|
if isinstance(path_info, tuple):
|
||||||
logger.info(f"Started monitoring: {root}")
|
# 对于链接,监控目标路径
|
||||||
|
_, target_path = path_info
|
||||||
|
self.observer.schedule(self.handler, target_path, recursive=True)
|
||||||
|
logger.info(f"Started monitoring target path: {target_path}")
|
||||||
|
else:
|
||||||
|
# 对于普通路径,直接监控
|
||||||
|
self.observer.schedule(self.handler, path_info, recursive=True)
|
||||||
|
logger.info(f"Started monitoring: {path_info}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error monitoring {root}: {e}")
|
logger.error(f"Error monitoring {path_info}: {e}")
|
||||||
|
|
||||||
self.observer.start()
|
self.observer.start()
|
||||||
|
|
||||||
@@ -155,3 +248,18 @@ class LoraFileMonitor:
|
|||||||
"""Stop monitoring"""
|
"""Stop monitoring"""
|
||||||
self.observer.stop()
|
self.observer.stop()
|
||||||
self.observer.join()
|
self.observer.join()
|
||||||
|
|
||||||
|
def rescan_links(self):
|
||||||
|
"""重新扫描链接(当添加新的链接时调用)"""
|
||||||
|
new_paths = set()
|
||||||
|
for path in self.monitor_paths.copy():
|
||||||
|
self._add_link_targets(path)
|
||||||
|
|
||||||
|
# 添加新发现的路径到监控
|
||||||
|
new_paths = self.monitor_paths - set(self.observer.watches.keys())
|
||||||
|
for path in new_paths:
|
||||||
|
try:
|
||||||
|
self.observer.schedule(self.handler, path, recursive=True)
|
||||||
|
logger.info(f"Added new monitoring path: {path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding new monitor for {path}: {e}")
|
||||||
@@ -231,23 +231,35 @@ class LoraScanner:
|
|||||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||||
"""Scan a single directory for LoRA files"""
|
"""Scan a single directory for LoRA files"""
|
||||||
loras = []
|
loras = []
|
||||||
|
original_root = root_path # 保存原始根路径
|
||||||
|
|
||||||
# 使用异步安全的目录遍历方式
|
async def scan_recursive(path: str, visited_paths: set):
|
||||||
async def scan_recursive(path: str):
|
"""递归扫描目录,避免循环链接"""
|
||||||
try:
|
try:
|
||||||
|
real_path = os.path.realpath(path)
|
||||||
|
if real_path in visited_paths:
|
||||||
|
logger.debug(f"Skipping already visited path: {path}")
|
||||||
|
return
|
||||||
|
visited_paths.add(real_path)
|
||||||
|
|
||||||
with os.scandir(path) as it:
|
with os.scandir(path) as it:
|
||||||
entries = list(it) # 同步获取目录条目
|
entries = list(it)
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
if entry.is_file() and entry.name.endswith('.safetensors'):
|
try:
|
||||||
file_path = entry.path.replace(os.sep, "/")
|
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'):
|
||||||
await self._process_single_file(file_path, root_path, loras)
|
# 使用原始路径而不是真实路径
|
||||||
await asyncio.sleep(0) # 释放事件循环
|
file_path = entry.path.replace(os.sep, "/")
|
||||||
elif entry.is_dir():
|
await self._process_single_file(file_path, original_root, loras)
|
||||||
await scan_recursive(entry.path)
|
await asyncio.sleep(0)
|
||||||
|
elif entry.is_dir(follow_symlinks=True):
|
||||||
|
# 对于目录,使用原始路径继续扫描
|
||||||
|
await scan_recursive(entry.path, visited_paths)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error scanning {path}: {e}")
|
logger.error(f"Error scanning {path}: {e}")
|
||||||
|
|
||||||
await scan_recursive(root_path)
|
await scan_recursive(root_path, set())
|
||||||
return loras
|
return loras
|
||||||
|
|
||||||
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
||||||
@@ -316,6 +328,7 @@ class LoraScanner:
|
|||||||
|
|
||||||
def _calculate_folder(self, file_path: str) -> str:
|
def _calculate_folder(self, file_path: str) -> str:
|
||||||
"""Calculate the folder path for a LoRA file"""
|
"""Calculate the folder path for a LoRA file"""
|
||||||
|
# 使用原始路径计算相对路径
|
||||||
for root in config.loras_roots:
|
for root in config.loras_roots:
|
||||||
if file_path.startswith(root):
|
if file_path.startswith(root):
|
||||||
rel_path = os.path.relpath(file_path, root)
|
rel_path = os.path.relpath(file_path, root)
|
||||||
@@ -323,46 +336,38 @@ class LoraScanner:
|
|||||||
return ''
|
return ''
|
||||||
|
|
||||||
async def move_model(self, source_path: str, target_path: str) -> bool:
|
async def move_model(self, source_path: str, target_path: str) -> bool:
|
||||||
"""Move a model and its associated files to a new location
|
"""Move a model and its associated files to a new location"""
|
||||||
|
|
||||||
Args:
|
|
||||||
source_path: Full path to the source lora file
|
|
||||||
target_path: Full path to the target directory
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Ensure paths are normalized
|
# 保持原始路径格式
|
||||||
source_path = source_path.replace(os.sep, '/')
|
source_path = source_path.replace(os.sep, '/')
|
||||||
target_path = target_path.replace(os.sep, '/')
|
target_path = target_path.replace(os.sep, '/')
|
||||||
|
|
||||||
# Get base name without extension
|
# 其余代码保持不变
|
||||||
base_name = os.path.splitext(os.path.basename(source_path))[0]
|
base_name = os.path.splitext(os.path.basename(source_path))[0]
|
||||||
source_dir = os.path.dirname(source_path)
|
source_dir = os.path.dirname(source_path)
|
||||||
|
|
||||||
# Create target directory if it doesn't exist
|
|
||||||
os.makedirs(target_path, exist_ok=True)
|
os.makedirs(target_path, exist_ok=True)
|
||||||
|
|
||||||
# Calculate target lora path
|
|
||||||
target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/')
|
target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/')
|
||||||
|
|
||||||
# Get source file size for timeout calculation
|
# 使用真实路径进行文件操作
|
||||||
file_size = os.path.getsize(source_path)
|
real_source = os.path.realpath(source_path)
|
||||||
|
real_target = os.path.realpath(target_lora)
|
||||||
|
|
||||||
|
file_size = os.path.getsize(real_source)
|
||||||
|
|
||||||
# Tell file monitor to ignore these paths
|
|
||||||
if self.file_monitor:
|
if self.file_monitor:
|
||||||
self.file_monitor.handler.add_ignore_path(
|
self.file_monitor.handler.add_ignore_path(
|
||||||
source_path,
|
real_source,
|
||||||
file_size
|
file_size
|
||||||
)
|
)
|
||||||
self.file_monitor.handler.add_ignore_path(
|
self.file_monitor.handler.add_ignore_path(
|
||||||
target_lora,
|
real_target,
|
||||||
file_size
|
file_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# Move main lora file
|
# 使用真实路径进行文件操作
|
||||||
shutil.move(source_path, target_lora)
|
shutil.move(real_source, real_target)
|
||||||
|
|
||||||
# Move associated files
|
# Move associated files
|
||||||
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
|
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
|
||||||
|
|||||||
Reference in New Issue
Block a user