mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
Add file monitoring and scanning improvements for LoRA management
This commit is contained in:
@@ -4,6 +4,7 @@ from .config import config
|
|||||||
from .routes.lora_routes import LoraRoutes
|
from .routes.lora_routes import LoraRoutes
|
||||||
from .routes.api_routes import ApiRoutes
|
from .routes.api_routes import ApiRoutes
|
||||||
from .services.lora_scanner import LoraScanner
|
from .services.lora_scanner import LoraScanner
|
||||||
|
from .services.file_monitor import LoraFileMonitor
|
||||||
|
|
||||||
class LoraManager:
|
class LoraManager:
|
||||||
"""Main entry point for LoRA Manager plugin"""
|
"""Main entry point for LoRA Manager plugin"""
|
||||||
@@ -28,8 +29,18 @@ class LoraManager:
|
|||||||
LoraRoutes.setup_routes(app)
|
LoraRoutes.setup_routes(app)
|
||||||
ApiRoutes.setup_routes(app)
|
ApiRoutes.setup_routes(app)
|
||||||
|
|
||||||
|
# Setup file monitoring
|
||||||
|
monitor = LoraFileMonitor(routes.scanner, config.loras_roots)
|
||||||
|
monitor.start()
|
||||||
|
|
||||||
|
# Store monitor in app for cleanup
|
||||||
|
app['lora_monitor'] = monitor
|
||||||
|
|
||||||
# Schedule cache initialization using the application's startup handler
|
# Schedule cache initialization using the application's startup handler
|
||||||
app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner))
|
app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner))
|
||||||
|
|
||||||
|
# Add cleanup
|
||||||
|
app.on_shutdown.append(cls._cleanup)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _schedule_cache_init(cls, scanner: LoraScanner):
|
async def _schedule_cache_init(cls, scanner: LoraScanner):
|
||||||
@@ -47,4 +58,10 @@ class LoraManager:
|
|||||||
await scanner.get_cached_data(force_refresh=True)
|
await scanner.get_cached_data(force_refresh=True)
|
||||||
print("LoRA Manager: Cache initialization completed")
|
print("LoRA Manager: Cache initialization completed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"LoRA Manager: Error initializing cache: {e}")
|
print(f"LoRA Manager: Error initializing cache: {e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _cleanup(cls, app):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if 'lora_monitor' in app:
|
||||||
|
app['lora_monitor'].stop()
|
||||||
123
services/file_monitor.py
Normal file
123
services/file_monitor.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from watchdog.observers import Observer
|
||||||
|
from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent
|
||||||
|
from typing import List, Set, Callable
|
||||||
|
from threading import Lock
|
||||||
|
from .lora_scanner import LoraScanner
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class LoraFileHandler(FileSystemEventHandler):
|
||||||
|
"""Handler for LoRA file system events"""
|
||||||
|
|
||||||
|
def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop):
|
||||||
|
self.scanner = scanner
|
||||||
|
self.loop = loop # 存储事件循环引用
|
||||||
|
self.pending_changes = set() # 待处理的变更
|
||||||
|
self.lock = Lock() # 线程安全锁
|
||||||
|
self.update_task = None # 异步更新任务
|
||||||
|
|
||||||
|
def on_created(self, event):
|
||||||
|
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
||||||
|
return
|
||||||
|
logger.info(f"LoRA file created: {event.src_path}")
|
||||||
|
self._schedule_update('add', event.src_path)
|
||||||
|
|
||||||
|
def on_deleted(self, event):
|
||||||
|
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
||||||
|
return
|
||||||
|
logger.info(f"LoRA file deleted: {event.src_path}")
|
||||||
|
self._schedule_update('remove', event.src_path)
|
||||||
|
|
||||||
|
def _schedule_update(self, action: str, file_path: str):
|
||||||
|
"""Schedule a cache update"""
|
||||||
|
with self.lock:
|
||||||
|
# 标准化路径
|
||||||
|
normalized_path = file_path.replace(os.sep, '/')
|
||||||
|
self.pending_changes.add((action, normalized_path))
|
||||||
|
|
||||||
|
# 使用 call_soon_threadsafe 在事件循环中安排任务
|
||||||
|
self.loop.call_soon_threadsafe(self._create_update_task)
|
||||||
|
|
||||||
|
def _create_update_task(self):
|
||||||
|
"""Create update task in the event loop"""
|
||||||
|
if self.update_task is None or self.update_task.done():
|
||||||
|
self.update_task = asyncio.create_task(self._process_changes())
|
||||||
|
|
||||||
|
async def _process_changes(self, delay: float = 2.0):
|
||||||
|
"""Process pending changes with debouncing"""
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with self.lock:
|
||||||
|
changes = self.pending_changes.copy()
|
||||||
|
self.pending_changes.clear()
|
||||||
|
|
||||||
|
if not changes:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Processing {len(changes)} file changes")
|
||||||
|
|
||||||
|
# 获取当前缓存
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
needs_resort = False
|
||||||
|
|
||||||
|
for action, file_path in changes:
|
||||||
|
try:
|
||||||
|
if action == 'add':
|
||||||
|
# 扫描新文件
|
||||||
|
lora_data = await self.scanner.scan_single_lora(file_path)
|
||||||
|
if lora_data:
|
||||||
|
cache.raw_data.append(lora_data)
|
||||||
|
needs_resort = True
|
||||||
|
|
||||||
|
elif action == 'remove':
|
||||||
|
# 从缓存中移除
|
||||||
|
cache.raw_data = [
|
||||||
|
item for item in cache.raw_data
|
||||||
|
if item['file_path'] != file_path
|
||||||
|
]
|
||||||
|
needs_resort = True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing {action} for {file_path}: {e}")
|
||||||
|
|
||||||
|
# 如果有变更,更新排序并重置缓存时间
|
||||||
|
if needs_resort:
|
||||||
|
await self.scanner.resort_cache()
|
||||||
|
# 更新缓存时间戳,确保下次获取时能得到最新数据
|
||||||
|
self.scanner._cache.last_update = time.time()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in process_changes: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class LoraFileMonitor:
|
||||||
|
"""Monitor for LoRA file changes"""
|
||||||
|
|
||||||
|
def __init__(self, scanner: LoraScanner, roots: List[str]):
|
||||||
|
self.scanner = scanner
|
||||||
|
self.roots = roots
|
||||||
|
self.observer = Observer()
|
||||||
|
# 获取当前运行的事件循环
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
self.handler = LoraFileHandler(scanner, self.loop)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Start monitoring"""
|
||||||
|
for root in self.roots:
|
||||||
|
try:
|
||||||
|
self.observer.schedule(self.handler, root, recursive=True)
|
||||||
|
logger.info(f"Started monitoring: {root}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error monitoring {root}: {e}")
|
||||||
|
|
||||||
|
self.observer.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stop monitoring"""
|
||||||
|
self.observer.stop()
|
||||||
|
self.observer.join()
|
||||||
@@ -6,8 +6,7 @@ from typing import List, Dict, Optional
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.file_utils import load_metadata, get_file_info, save_metadata
|
from ..utils.file_utils import load_metadata, get_file_info
|
||||||
from ..utils.lora_metadata import extract_lora_metadata
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -181,9 +180,6 @@ class LoraScanner:
|
|||||||
if metadata is None:
|
if metadata is None:
|
||||||
# Create new metadata if none exists
|
# Create new metadata if none exists
|
||||||
metadata = await get_file_info(file_path)
|
metadata = await get_file_info(file_path)
|
||||||
base_model_info = await extract_lora_metadata(file_path)
|
|
||||||
metadata.base_model = base_model_info['base_model']
|
|
||||||
await save_metadata(file_path, metadata)
|
|
||||||
|
|
||||||
# Convert to dict and add folder info
|
# Convert to dict and add folder info
|
||||||
lora_data = metadata.to_dict()
|
lora_data = metadata.to_dict()
|
||||||
@@ -207,3 +203,55 @@ class LoraScanner:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
return self._cache.update_preview_url(file_path, preview_url)
|
return self._cache.update_preview_url(file_path, preview_url)
|
||||||
|
|
||||||
|
async def scan_single_lora(self, file_path: str) -> Optional[Dict]:
|
||||||
|
"""Scan a single LoRA file and return its metadata"""
|
||||||
|
try:
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取基本文件信息
|
||||||
|
metadata = await get_file_info(file_path)
|
||||||
|
if not metadata:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 计算相对于 lora_roots 的文件夹路径
|
||||||
|
folder = None
|
||||||
|
file_dir = os.path.dirname(file_path)
|
||||||
|
for root in config.loras_roots:
|
||||||
|
if file_dir.startswith(root):
|
||||||
|
rel_path = os.path.relpath(file_dir, root)
|
||||||
|
if rel_path == '.':
|
||||||
|
folder = '' # 根目录
|
||||||
|
else:
|
||||||
|
folder = rel_path.replace(os.sep, '/')
|
||||||
|
break
|
||||||
|
|
||||||
|
# 确保 folder 字段存在
|
||||||
|
metadata_dict = metadata.to_dict()
|
||||||
|
metadata_dict['folder'] = folder or ''
|
||||||
|
|
||||||
|
return metadata_dict
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning {file_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def resort_cache(self):
|
||||||
|
"""Resort cache data"""
|
||||||
|
if not self._cache:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._cache.sorted_by_name = sorted(
|
||||||
|
self._cache.raw_data,
|
||||||
|
key=itemgetter('model_name')
|
||||||
|
)
|
||||||
|
self._cache.sorted_by_date = sorted(
|
||||||
|
self._cache.raw_data,
|
||||||
|
key=itemgetter('modified'),
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
# 更新文件夹列表
|
||||||
|
self._cache.folders = sorted(list(set(
|
||||||
|
l['folder'] for l in self._cache.raw_data
|
||||||
|
)))
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ import os
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from .lora_metadata import extract_lora_metadata
|
||||||
from .models import LoraMetadata
|
from .models import LoraMetadata
|
||||||
|
|
||||||
async def calculate_sha256(file_path: str) -> str:
|
async def calculate_sha256(file_path: str) -> str:
|
||||||
@@ -41,8 +43,8 @@ async def get_file_info(file_path: str) -> LoraMetadata:
|
|||||||
dir_path = os.path.dirname(file_path)
|
dir_path = os.path.dirname(file_path)
|
||||||
|
|
||||||
preview_url = _find_preview_file(base_name, dir_path)
|
preview_url = _find_preview_file(base_name, dir_path)
|
||||||
|
|
||||||
return LoraMetadata(
|
metadata = LoraMetadata(
|
||||||
file_name=base_name,
|
file_name=base_name,
|
||||||
model_name=base_name,
|
model_name=base_name,
|
||||||
file_path=normalize_path(file_path),
|
file_path=normalize_path(file_path),
|
||||||
@@ -54,6 +56,13 @@ async def get_file_info(file_path: str) -> LoraMetadata:
|
|||||||
preview_url=normalize_path(preview_url),
|
preview_url=normalize_path(preview_url),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# create metadata file
|
||||||
|
base_model_info = await extract_lora_metadata(file_path)
|
||||||
|
metadata.base_model = base_model_info['base_model']
|
||||||
|
await save_metadata(file_path, metadata)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
|
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
|
||||||
"""Save metadata to .metadata.json file"""
|
"""Save metadata to .metadata.json file"""
|
||||||
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||||
|
|||||||
Reference in New Issue
Block a user