From 270182b5cdcf02c77d92f80df674ba67d067b0c0 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 3 Feb 2025 14:58:04 +0800 Subject: [PATCH] Add file monitoring and scanning improvements for LoRA management --- lora_manager.py | 19 +++++- services/file_monitor.py | 123 +++++++++++++++++++++++++++++++++++++++ services/lora_scanner.py | 58 ++++++++++++++++-- utils/file_utils.py | 13 ++++- 4 files changed, 205 insertions(+), 8 deletions(-) create mode 100644 services/file_monitor.py diff --git a/lora_manager.py b/lora_manager.py index c9801b26..e647c5a4 100644 --- a/lora_manager.py +++ b/lora_manager.py @@ -4,6 +4,7 @@ from .config import config from .routes.lora_routes import LoraRoutes from .routes.api_routes import ApiRoutes from .services.lora_scanner import LoraScanner +from .services.file_monitor import LoraFileMonitor class LoraManager: """Main entry point for LoRA Manager plugin""" @@ -28,8 +29,18 @@ class LoraManager: LoraRoutes.setup_routes(app) ApiRoutes.setup_routes(app) + # Setup file monitoring + monitor = LoraFileMonitor(routes.scanner, config.loras_roots) + monitor.start() + + # Store monitor in app for cleanup + app['lora_monitor'] = monitor + # Schedule cache initialization using the application's startup handler app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner)) + + # Add cleanup + app.on_shutdown.append(cls._cleanup) @classmethod async def _schedule_cache_init(cls, scanner: LoraScanner): @@ -47,4 +58,10 @@ class LoraManager: await scanner.get_cached_data(force_refresh=True) print("LoRA Manager: Cache initialization completed") except Exception as e: - print(f"LoRA Manager: Error initializing cache: {e}") \ No newline at end of file + print(f"LoRA Manager: Error initializing cache: {e}") + + @classmethod + async def _cleanup(cls, app): + """Cleanup resources""" + if 'lora_monitor' in app: + app['lora_monitor'].stop() \ No newline at end of file diff --git a/services/file_monitor.py b/services/file_monitor.py new file mode 100644 index 00000000..ac102473 --- /dev/null +++ b/services/file_monitor.py @@ -0,0 +1,123 @@ +import os +import time +import logging +import asyncio +from watchdog.observers import Observer +from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent +from typing import List, Set, Callable +from threading import Lock +from .lora_scanner import LoraScanner + +logger = logging.getLogger(__name__) + +class LoraFileHandler(FileSystemEventHandler): + """Handler for LoRA file system events""" + + def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop): + self.scanner = scanner + self.loop = loop # 存储事件循环引用 + self.pending_changes = set() # 待处理的变更 + self.lock = Lock() # 线程安全锁 + self.update_task = None # 异步更新任务 + + def on_created(self, event): + if event.is_directory or not event.src_path.endswith('.safetensors'): + return + logger.info(f"LoRA file created: {event.src_path}") + self._schedule_update('add', event.src_path) + + def on_deleted(self, event): + if event.is_directory or not event.src_path.endswith('.safetensors'): + return + logger.info(f"LoRA file deleted: {event.src_path}") + self._schedule_update('remove', event.src_path) + + def _schedule_update(self, action: str, file_path: str): + """Schedule a cache update""" + with self.lock: + # 标准化路径 + normalized_path = file_path.replace(os.sep, '/') + self.pending_changes.add((action, normalized_path)) + + # 使用 call_soon_threadsafe 在事件循环中安排任务 + self.loop.call_soon_threadsafe(self._create_update_task) + + def _create_update_task(self): + """Create update task in the event loop""" + if self.update_task is None or self.update_task.done(): + self.update_task = asyncio.create_task(self._process_changes()) + + async def _process_changes(self, delay: float = 2.0): + """Process pending changes with debouncing""" + await asyncio.sleep(delay) + + try: + with self.lock: + changes = self.pending_changes.copy() + self.pending_changes.clear() + + if not changes: + return + + logger.info(f"Processing {len(changes)} file changes") + + # 获取当前缓存 + cache = await self.scanner.get_cached_data() + needs_resort = False + + for action, file_path in changes: + try: + if action == 'add': + # 扫描新文件 + lora_data = await self.scanner.scan_single_lora(file_path) + if lora_data: + cache.raw_data.append(lora_data) + needs_resort = True + + elif action == 'remove': + # 从缓存中移除 + cache.raw_data = [ + item for item in cache.raw_data + if item['file_path'] != file_path + ] + needs_resort = True + + except Exception as e: + logger.error(f"Error processing {action} for {file_path}: {e}") + + # 如果有变更,更新排序并重置缓存时间 + if needs_resort: + await self.scanner.resort_cache() + # 更新缓存时间戳,确保下次获取时能得到最新数据 + self.scanner._cache.last_update = time.time() + + except Exception as e: + logger.error(f"Error in process_changes: {e}") + + +class LoraFileMonitor: + """Monitor for LoRA file changes""" + + def __init__(self, scanner: LoraScanner, roots: List[str]): + self.scanner = scanner + self.roots = roots + self.observer = Observer() + # 获取当前运行的事件循环 + self.loop = asyncio.get_event_loop() + self.handler = LoraFileHandler(scanner, self.loop) + + def start(self): + """Start monitoring""" + for root in self.roots: + try: + self.observer.schedule(self.handler, root, recursive=True) + logger.info(f"Started monitoring: {root}") + except Exception as e: + logger.error(f"Error monitoring {root}: {e}") + + self.observer.start() + + def stop(self): + """Stop monitoring""" + self.observer.stop() + self.observer.join() \ No newline at end of file diff --git a/services/lora_scanner.py b/services/lora_scanner.py index b039ca19..653138e8 100644 --- a/services/lora_scanner.py +++ b/services/lora_scanner.py @@ -6,8 +6,7 @@ from typing import List, Dict, Optional from dataclasses import dataclass from operator import itemgetter from ..config import config -from ..utils.file_utils import load_metadata, get_file_info, save_metadata -from ..utils.lora_metadata import extract_lora_metadata +from ..utils.file_utils import load_metadata, get_file_info logger = logging.getLogger(__name__) @@ -181,9 +180,6 @@ class LoraScanner: if metadata is None: # Create new metadata if none exists metadata = await get_file_info(file_path) - base_model_info = await extract_lora_metadata(file_path) - metadata.base_model = base_model_info['base_model'] - await save_metadata(file_path, metadata) # Convert to dict and add folder info lora_data = metadata.to_dict() @@ -207,3 +203,55 @@ class LoraScanner: return False return self._cache.update_preview_url(file_path, preview_url) + + async def scan_single_lora(self, file_path: str) -> Optional[Dict]: + """Scan a single LoRA file and return its metadata""" + try: + if not os.path.exists(file_path): + return None + + # 获取基本文件信息 + metadata = await get_file_info(file_path) + if not metadata: + return None + + # 计算相对于 lora_roots 的文件夹路径 + folder = None + file_dir = os.path.dirname(file_path) + for root in config.loras_roots: + if file_dir.startswith(root): + rel_path = os.path.relpath(file_dir, root) + if rel_path == '.': + folder = '' # 根目录 + else: + folder = rel_path.replace(os.sep, '/') + break + + # 确保 folder 字段存在 + metadata_dict = metadata.to_dict() + metadata_dict['folder'] = folder or '' + + return metadata_dict + + except Exception as e: + logger.error(f"Error scanning {file_path}: {e}") + return None + + async def resort_cache(self): + """Resort cache data""" + if not self._cache: + return + + self._cache.sorted_by_name = sorted( + self._cache.raw_data, + key=itemgetter('model_name') + ) + self._cache.sorted_by_date = sorted( + self._cache.raw_data, + key=itemgetter('modified'), + reverse=True + ) + # 更新文件夹列表 + self._cache.folders = sorted(list(set( + l['folder'] for l in self._cache.raw_data + ))) diff --git a/utils/file_utils.py b/utils/file_utils.py index b3c9af76..72c192fd 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -2,6 +2,8 @@ import os import hashlib import json from typing import Dict, Optional + +from .lora_metadata import extract_lora_metadata from .models import LoraMetadata async def calculate_sha256(file_path: str) -> str: @@ -41,8 +43,8 @@ async def get_file_info(file_path: str) -> LoraMetadata: dir_path = os.path.dirname(file_path) preview_url = _find_preview_file(base_name, dir_path) - - return LoraMetadata( + + metadata = LoraMetadata( file_name=base_name, model_name=base_name, file_path=normalize_path(file_path), @@ -54,6 +56,13 @@ async def get_file_info(file_path: str) -> LoraMetadata: preview_url=normalize_path(preview_url), ) + # create metadata file + base_model_info = await extract_lora_metadata(file_path) + metadata.base_model = base_model_info['base_model'] + await save_metadata(file_path, metadata) + + return metadata + async def save_metadata(file_path: str, metadata: LoraMetadata) -> None: """Save metadata to .metadata.json file""" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"