mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Fix symlink checkpoint2
This commit is contained in:
113
config.py
113
config.py
@@ -1,14 +1,90 @@
|
|||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import folder_paths # type: ignore
|
import folder_paths # type: ignore
|
||||||
from typing import List
|
from typing import List
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Global configuration for LoRA Manager"""
|
"""Global configuration for LoRA Manager"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.loras_roots = self._init_lora_paths()
|
|
||||||
self.templates_path = os.path.join(os.path.dirname(__file__), 'templates')
|
self.templates_path = os.path.join(os.path.dirname(__file__), 'templates')
|
||||||
self.static_path = os.path.join(os.path.dirname(__file__), 'static')
|
self.static_path = os.path.join(os.path.dirname(__file__), 'static')
|
||||||
|
# 路径映射字典, target to link mapping
|
||||||
|
self._path_mappings = {}
|
||||||
|
# 静态路由映射字典, target to route mapping
|
||||||
|
self._route_mappings = {}
|
||||||
|
self.loras_roots = self._init_lora_paths()
|
||||||
|
# 在初始化时扫描符号链接
|
||||||
|
self._scan_symbolic_links()
|
||||||
|
|
||||||
|
def _is_link(self, path: str) -> bool:
|
||||||
|
try:
|
||||||
|
if os.path.islink(path):
|
||||||
|
return True
|
||||||
|
if platform.system() == 'Windows':
|
||||||
|
try:
|
||||||
|
import ctypes
|
||||||
|
FILE_ATTRIBUTE_REPARSE_POINT = 0x400
|
||||||
|
attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path))
|
||||||
|
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking Windows reparse point: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking link status for {path}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _scan_symbolic_links(self):
|
||||||
|
"""扫描所有 LoRA 根目录中的符号链接"""
|
||||||
|
for root in self.loras_roots:
|
||||||
|
self._scan_directory_links(root)
|
||||||
|
|
||||||
|
def _scan_directory_links(self, root: str):
|
||||||
|
"""递归扫描目录中的符号链接"""
|
||||||
|
try:
|
||||||
|
with os.scandir(root) as it:
|
||||||
|
for entry in it:
|
||||||
|
if self._is_link(entry.path):
|
||||||
|
target_path = os.path.realpath(entry.path)
|
||||||
|
if os.path.isdir(target_path):
|
||||||
|
self.add_path_mapping(entry.path, target_path)
|
||||||
|
self._scan_directory_links(target_path)
|
||||||
|
elif entry.is_dir(follow_symlinks=False):
|
||||||
|
self._scan_directory_links(entry.path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning links in {root}: {e}")
|
||||||
|
|
||||||
|
def add_path_mapping(self, link_path: str, target_path: str):
|
||||||
|
"""添加符号链接路径映射
|
||||||
|
target_path: 实际目标路径
|
||||||
|
link_path: 符号链接路径
|
||||||
|
"""
|
||||||
|
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||||
|
normalized_target = os.path.normpath(target_path).replace(os.sep, '/')
|
||||||
|
# 保持原有的映射关系:目标路径 -> 链接路径
|
||||||
|
self._path_mappings[normalized_target] = normalized_link
|
||||||
|
logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}")
|
||||||
|
|
||||||
|
def add_route_mapping(self, path: str, route: str):
|
||||||
|
"""添加静态路由映射"""
|
||||||
|
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||||
|
self._route_mappings[normalized_path] = route
|
||||||
|
logger.info(f"Added route mapping: {normalized_path} -> {route}")
|
||||||
|
|
||||||
|
def map_path_to_link(self, path: str) -> str:
|
||||||
|
"""将目标路径映射回符号链接路径"""
|
||||||
|
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||||
|
# 检查路径是否包含在任何映射的目标路径中
|
||||||
|
for target_path, link_path in self._path_mappings.items():
|
||||||
|
if normalized_path.startswith(target_path):
|
||||||
|
# 如果路径以目标路径开头,则替换为链接路径
|
||||||
|
mapped_path = normalized_path.replace(target_path, link_path, 1)
|
||||||
|
logger.info(f"Mapped path {normalized_path} to {mapped_path}")
|
||||||
|
return mapped_path
|
||||||
|
return path
|
||||||
|
|
||||||
def _init_lora_paths(self) -> List[str]:
|
def _init_lora_paths(self) -> List[str]:
|
||||||
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
||||||
@@ -20,6 +96,12 @@ class Config:
|
|||||||
if not paths:
|
if not paths:
|
||||||
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
||||||
|
|
||||||
|
# 初始化路径映射
|
||||||
|
for path in paths:
|
||||||
|
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||||
|
if real_path != path:
|
||||||
|
self.add_path_mapping(real_path, path)
|
||||||
|
|
||||||
return paths
|
return paths
|
||||||
|
|
||||||
def get_preview_static_url(self, preview_path: str) -> str:
|
def get_preview_static_url(self, preview_path: str) -> str:
|
||||||
@@ -27,10 +109,31 @@ class Config:
|
|||||||
if not preview_path:
|
if not preview_path:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
for idx, root in enumerate(self.loras_roots, start=1):
|
# 获取真实路径和规范化路径
|
||||||
if preview_path.startswith(root):
|
real_preview_path = os.path.realpath(preview_path)
|
||||||
relative_path = os.path.relpath(preview_path, root)
|
normalized_real_path = real_preview_path.replace(os.sep, '/')
|
||||||
return f'/loras_static/root{idx}/preview/{relative_path.replace(os.sep, "/")}'
|
normalized_preview_path = preview_path.replace(os.sep, '/')
|
||||||
|
|
||||||
|
# 首先尝试使用原始路径查找路由
|
||||||
|
for root_path, route in self._route_mappings.items():
|
||||||
|
if normalized_preview_path.startswith(root_path):
|
||||||
|
relative_path = os.path.relpath(normalized_preview_path, root_path)
|
||||||
|
return f'{route}/{relative_path.replace(os.sep, "/")}'
|
||||||
|
|
||||||
|
# 如果没找到,尝试使用真实路径查找路由
|
||||||
|
for root_path, route in self._route_mappings.items():
|
||||||
|
if normalized_real_path.startswith(root_path):
|
||||||
|
relative_path = os.path.relpath(normalized_real_path, root_path)
|
||||||
|
return f'{route}/{relative_path.replace(os.sep, "/")}'
|
||||||
|
|
||||||
|
# 如果还没找到,尝试使用路径映射
|
||||||
|
mapped_path = self.map_path_to_link(real_preview_path)
|
||||||
|
normalized_mapped_path = mapped_path.replace(os.sep, '/')
|
||||||
|
|
||||||
|
for root_path, route in self._route_mappings.items():
|
||||||
|
if normalized_mapped_path.startswith(root_path):
|
||||||
|
relative_path = os.path.relpath(normalized_mapped_path, root_path)
|
||||||
|
return f'{route}/{relative_path.replace(os.sep, "/")}'
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from server import PromptServer # type: ignore
|
from server import PromptServer # type: ignore
|
||||||
from .config import config
|
from .config import config
|
||||||
from .routes.lora_routes import LoraRoutes
|
from .routes.lora_routes import LoraRoutes
|
||||||
@@ -6,6 +7,9 @@ from .routes.api_routes import ApiRoutes
|
|||||||
from .services.lora_scanner import LoraScanner
|
from .services.lora_scanner import LoraScanner
|
||||||
from .services.file_monitor import LoraFileMonitor
|
from .services.file_monitor import LoraFileMonitor
|
||||||
from .services.lora_cache import LoraCache
|
from .services.lora_cache import LoraCache
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class LoraManager:
|
class LoraManager:
|
||||||
"""Main entry point for LoRA Manager plugin"""
|
"""Main entry point for LoRA Manager plugin"""
|
||||||
@@ -15,10 +19,32 @@ class LoraManager:
|
|||||||
"""Initialize and register all routes"""
|
"""Initialize and register all routes"""
|
||||||
app = PromptServer.instance.app
|
app = PromptServer.instance.app
|
||||||
|
|
||||||
|
added_targets = set() # 用于跟踪已添加的目标路径
|
||||||
|
|
||||||
# Add static routes for each lora root
|
# Add static routes for each lora root
|
||||||
for idx, root in enumerate(config.loras_roots, start=1):
|
for idx, root in enumerate(config.loras_roots, start=1):
|
||||||
preview_path = f'/loras_static/root{idx}/preview'
|
preview_path = f'/loras_static/root{idx}/preview'
|
||||||
|
|
||||||
|
# 为原始路径添加静态路由
|
||||||
app.router.add_static(preview_path, root)
|
app.router.add_static(preview_path, root)
|
||||||
|
logger.info(f"Added static route {preview_path} -> {root}")
|
||||||
|
|
||||||
|
# 记录路由映射
|
||||||
|
config.add_route_mapping(root, preview_path)
|
||||||
|
added_targets.add(root)
|
||||||
|
|
||||||
|
# 为符号链接的目标路径添加额外的静态路由
|
||||||
|
link_idx = 1
|
||||||
|
|
||||||
|
for target_path, link_path in config._path_mappings.items():
|
||||||
|
if target_path not in added_targets:
|
||||||
|
route_path = f'/loras_static/link_{link_idx}/preview'
|
||||||
|
app.router.add_static(route_path, target_path)
|
||||||
|
logger.info(f"Added static route for link target {route_path} -> {target_path}")
|
||||||
|
config.add_route_mapping(target_path, route_path)
|
||||||
|
config.add_route_mapping(link_path, route_path) # 也为符号链接路径添加路由映射
|
||||||
|
added_targets.add(target_path)
|
||||||
|
link_idx += 1
|
||||||
|
|
||||||
# Add static route for plugin assets
|
# Add static route for plugin assets
|
||||||
app.router.add_static('/loras_static', config.static_path)
|
app.router.add_static('/loras_static', config.static_path)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from typing import List
|
|||||||
from threading import Lock
|
from threading import Lock
|
||||||
from .lora_scanner import LoraScanner
|
from .lora_scanner import LoraScanner
|
||||||
import platform
|
import platform
|
||||||
|
from ..config import config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -23,7 +24,6 @@ class LoraFileHandler(FileSystemEventHandler):
|
|||||||
self._ignore_paths = set() # Add ignore paths set
|
self._ignore_paths = set() # Add ignore paths set
|
||||||
self._min_ignore_timeout = 5 # minimum timeout in seconds
|
self._min_ignore_timeout = 5 # minimum timeout in seconds
|
||||||
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
|
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
|
||||||
self._path_mappings = {} # 添加路径映射字典
|
|
||||||
|
|
||||||
def _should_ignore(self, path: str) -> bool:
|
def _should_ignore(self, path: str) -> bool:
|
||||||
"""Check if path should be ignored"""
|
"""Check if path should be ignored"""
|
||||||
@@ -50,23 +50,6 @@ class LoraFileHandler(FileSystemEventHandler):
|
|||||||
real_path.replace(os.sep, '/')
|
real_path.replace(os.sep, '/')
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_path_mapping(self, link_path: str, target_path: str):
|
|
||||||
"""添加符号链接路径映射"""
|
|
||||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
|
||||||
normalized_target = os.path.normpath(target_path).replace(os.sep, '/')
|
|
||||||
self._path_mappings[normalized_target] = normalized_link
|
|
||||||
logger.debug(f"Added path mapping: {normalized_target} -> {normalized_link}")
|
|
||||||
|
|
||||||
def _map_path_to_link(self, path: str) -> str:
|
|
||||||
"""将目标路径映射回符号链接路径"""
|
|
||||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
|
||||||
for target_prefix, link_prefix in self._path_mappings.items():
|
|
||||||
if normalized_path.startswith(target_prefix):
|
|
||||||
mapped_path = normalized_path.replace(target_prefix, link_prefix, 1)
|
|
||||||
logger.debug(f"Mapped path {normalized_path} to {mapped_path}")
|
|
||||||
return mapped_path
|
|
||||||
return path
|
|
||||||
|
|
||||||
def on_created(self, event):
|
def on_created(self, event):
|
||||||
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
||||||
return
|
return
|
||||||
@@ -86,8 +69,8 @@ class LoraFileHandler(FileSystemEventHandler):
|
|||||||
def _schedule_update(self, action: str, file_path: str):
|
def _schedule_update(self, action: str, file_path: str):
|
||||||
"""Schedule a cache update"""
|
"""Schedule a cache update"""
|
||||||
with self.lock:
|
with self.lock:
|
||||||
# 将目标路径映射回符号链接路径
|
# 使用 config 中的方法映射路径
|
||||||
mapped_path = self._map_path_to_link(file_path)
|
mapped_path = config.map_path_to_link(file_path)
|
||||||
normalized_path = mapped_path.replace(os.sep, '/')
|
normalized_path = mapped_path.replace(os.sep, '/')
|
||||||
self.pending_changes.add((action, normalized_path))
|
self.pending_changes.add((action, normalized_path))
|
||||||
|
|
||||||
@@ -159,72 +142,14 @@ class LoraFileMonitor:
|
|||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
self.handler = LoraFileHandler(scanner, self.loop)
|
self.handler = LoraFileHandler(scanner, self.loop)
|
||||||
|
|
||||||
# 存储所有需要监控的路径(包括链接的目标路径)
|
# 使用已存在的路径映射
|
||||||
self.monitor_paths = set()
|
self.monitor_paths = set()
|
||||||
for root in roots:
|
for root in roots:
|
||||||
real_root = os.path.realpath(root)
|
self.monitor_paths.add(os.path.realpath(root))
|
||||||
self.monitor_paths.add(real_root)
|
|
||||||
# 扫描根目录下的链接
|
|
||||||
self._add_link_targets(root)
|
|
||||||
|
|
||||||
def _is_link(self, path: str) -> bool:
|
# 添加所有已映射的目标路径
|
||||||
"""
|
for target_path in config._path_mappings.keys():
|
||||||
检查路径是否为链接
|
self.monitor_paths.add(target_path)
|
||||||
支持:
|
|
||||||
- Windows: Symbolic Links, Junction Points
|
|
||||||
- Linux: Symbolic Links
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 首先检查通用的符号链接
|
|
||||||
if os.path.islink(path):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Windows 特定的 Junction Points 检测
|
|
||||||
if platform.system() == 'Windows':
|
|
||||||
try:
|
|
||||||
import ctypes
|
|
||||||
FILE_ATTRIBUTE_REPARSE_POINT = 0x400
|
|
||||||
attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path))
|
|
||||||
return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error checking Windows reparse point: {e}")
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error checking link status for {path}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _get_link_target(self, path: str) -> str:
|
|
||||||
"""获取链接目标路径"""
|
|
||||||
try:
|
|
||||||
return os.path.realpath(path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error resolving link target for {path}: {e}")
|
|
||||||
return path
|
|
||||||
|
|
||||||
def _add_link_targets(self, root: str):
|
|
||||||
"""递归扫描目录,添加链接指向的目标路径"""
|
|
||||||
try:
|
|
||||||
with os.scandir(root) as it:
|
|
||||||
for entry in it:
|
|
||||||
logger.debug(f"Checking path: {entry.path}")
|
|
||||||
if self._is_link(entry.path):
|
|
||||||
# 获取链接的目标路径
|
|
||||||
target_path = self._get_link_target(entry.path)
|
|
||||||
if os.path.isdir(target_path):
|
|
||||||
normalized_target = os.path.normpath(target_path)
|
|
||||||
self.monitor_paths.add(normalized_target)
|
|
||||||
# 添加路径映射到处理器
|
|
||||||
self.handler.add_path_mapping(entry.path, target_path)
|
|
||||||
logger.info(f"Found link: {entry.path} -> {normalized_target}")
|
|
||||||
# 递归扫描目标目录中的链接
|
|
||||||
self._add_link_targets(target_path)
|
|
||||||
elif entry.is_dir(follow_symlinks=False):
|
|
||||||
# 递归扫描子目录
|
|
||||||
self._add_link_targets(entry.path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning links in {root}: {e}")
|
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Start monitoring"""
|
"""Start monitoring"""
|
||||||
|
|||||||
Reference in New Issue
Block a user