Add file monitoring and scanning improvements for LoRA management

This commit is contained in:
Will Miao
2025-02-03 14:58:04 +08:00
parent e528438a57
commit 270182b5cd
4 changed files with 205 additions and 8 deletions

View File

@@ -4,6 +4,7 @@ from .config import config
from .routes.lora_routes import LoraRoutes from .routes.lora_routes import LoraRoutes
from .routes.api_routes import ApiRoutes from .routes.api_routes import ApiRoutes
from .services.lora_scanner import LoraScanner from .services.lora_scanner import LoraScanner
from .services.file_monitor import LoraFileMonitor
class LoraManager: class LoraManager:
"""Main entry point for LoRA Manager plugin""" """Main entry point for LoRA Manager plugin"""
@@ -28,9 +29,19 @@ class LoraManager:
LoraRoutes.setup_routes(app) LoraRoutes.setup_routes(app)
ApiRoutes.setup_routes(app) ApiRoutes.setup_routes(app)
# Setup file monitoring
monitor = LoraFileMonitor(routes.scanner, config.loras_roots)
monitor.start()
# Store monitor in app for cleanup
app['lora_monitor'] = monitor
# Schedule cache initialization using the application's startup handler # Schedule cache initialization using the application's startup handler
app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner)) app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner))
# Add cleanup
app.on_shutdown.append(cls._cleanup)
@classmethod @classmethod
async def _schedule_cache_init(cls, scanner: LoraScanner): async def _schedule_cache_init(cls, scanner: LoraScanner):
"""Schedule cache initialization in the running event loop""" """Schedule cache initialization in the running event loop"""
@@ -48,3 +59,9 @@ class LoraManager:
print("LoRA Manager: Cache initialization completed") print("LoRA Manager: Cache initialization completed")
except Exception as e: except Exception as e:
print(f"LoRA Manager: Error initializing cache: {e}") print(f"LoRA Manager: Error initializing cache: {e}")
@classmethod
async def _cleanup(cls, app):
    """Shutdown hook: stop the LoRA file monitor stored on the app, if any."""
    monitor = app.get('lora_monitor')
    if monitor is not None:
        monitor.stop()

123
services/file_monitor.py Normal file
View File

@@ -0,0 +1,123 @@
import os
import time
import logging
import asyncio
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent
from typing import List, Set, Callable
from threading import Lock
from .lora_scanner import LoraScanner
logger = logging.getLogger(__name__)
class LoraFileHandler(FileSystemEventHandler):
    """Handler for LoRA file system events.

    Watchdog invokes the on_* hooks from its observer thread; changes are
    queued under a lock and handed to the asyncio loop via
    call_soon_threadsafe, then applied to the scanner cache in a single
    debounced batch.
    """

    def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop):
        self.scanner = scanner
        self.loop = loop  # event loop on which cache updates must run
        self.pending_changes = set()  # pending (action, path) tuples
        self.lock = Lock()  # guards pending_changes across threads
        self.update_task = None  # current debounced update task, if any

    def on_created(self, event):
        """Queue an 'add' when a new .safetensors file appears."""
        if event.is_directory or not event.src_path.endswith('.safetensors'):
            return
        logger.info(f"LoRA file created: {event.src_path}")
        self._schedule_update('add', event.src_path)

    def on_deleted(self, event):
        """Queue a 'remove' when a .safetensors file is deleted."""
        if event.is_directory or not event.src_path.endswith('.safetensors'):
            return
        logger.info(f"LoRA file deleted: {event.src_path}")
        self._schedule_update('remove', event.src_path)

    def _schedule_update(self, action: str, file_path: str):
        """Schedule a cache update (thread-safe; called from the watchdog thread)."""
        with self.lock:
            # Normalize separators so paths compare equal to cache entries.
            normalized_path = file_path.replace(os.sep, '/')
            self.pending_changes.add((action, normalized_path))
            # Hop from the watchdog thread onto the asyncio event loop.
            self.loop.call_soon_threadsafe(self._create_update_task)

    def _create_update_task(self):
        """Create the update task in the event loop (runs on the loop thread)."""
        # Reuse an in-flight task so rapid bursts collapse into one batch.
        if self.update_task is None or self.update_task.done():
            self.update_task = asyncio.create_task(self._process_changes())

    async def _process_changes(self, delay: float = 2.0):
        """Process pending changes after a debounce delay of `delay` seconds."""
        await asyncio.sleep(delay)
        try:
            # Take a snapshot of the queued changes under the lock, then
            # release it before doing any async work.
            with self.lock:
                changes = self.pending_changes.copy()
                self.pending_changes.clear()

            if not changes:
                return

            logger.info(f"Processing {len(changes)} file changes")

            # Current cache object owned by the scanner.
            cache = await self.scanner.get_cached_data()
            needs_resort = False

            for action, file_path in changes:
                try:
                    if action == 'add':
                        # Scan the newly created file and append it to the cache.
                        lora_data = await self.scanner.scan_single_lora(file_path)
                        if lora_data:
                            cache.raw_data.append(lora_data)
                            needs_resort = True
                    elif action == 'remove':
                        # Drop the matching entry from the cache.
                        cache.raw_data = [
                            item for item in cache.raw_data
                            if item['file_path'] != file_path
                        ]
                        needs_resort = True
                except Exception as e:
                    logger.error(f"Error processing {action} for {file_path}: {e}")

            # Rebuild sorted views and bump the timestamp so readers see
            # fresh data on their next access.
            if needs_resort:
                await self.scanner.resort_cache()
                self.scanner._cache.last_update = time.time()

        except Exception as e:
            logger.error(f"Error in process_changes: {e}")
class LoraFileMonitor:
    """Monitor for LoRA file changes across the configured root directories."""

    def __init__(self, scanner: LoraScanner, roots: List[str]):
        self.scanner = scanner
        self.roots = roots
        self.observer = Observer()
        # NOTE(review): asyncio.get_event_loop() is deprecated when no loop is
        # running (3.10+) and may not be the loop the web app ultimately runs
        # on -- confirm this matches the loop used by the aiohttp application.
        self.loop = asyncio.get_event_loop()
        self.handler = LoraFileHandler(scanner, self.loop)

    def start(self):
        """Begin recursive watching of every configured root, then start the observer."""
        for root in self.roots:
            try:
                self.observer.schedule(self.handler, root, recursive=True)
                logger.info(f"Started monitoring: {root}")
            except Exception as e:
                logger.error(f"Error monitoring {root}: {e}")
        self.observer.start()

    def stop(self):
        """Stop the observer thread and block until it has terminated."""
        self.observer.stop()
        self.observer.join()

View File

@@ -6,8 +6,7 @@ from typing import List, Dict, Optional
from dataclasses import dataclass from dataclasses import dataclass
from operator import itemgetter from operator import itemgetter
from ..config import config from ..config import config
from ..utils.file_utils import load_metadata, get_file_info, save_metadata from ..utils.file_utils import load_metadata, get_file_info
from ..utils.lora_metadata import extract_lora_metadata
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -181,9 +180,6 @@ class LoraScanner:
if metadata is None: if metadata is None:
# Create new metadata if none exists # Create new metadata if none exists
metadata = await get_file_info(file_path) metadata = await get_file_info(file_path)
base_model_info = await extract_lora_metadata(file_path)
metadata.base_model = base_model_info['base_model']
await save_metadata(file_path, metadata)
# Convert to dict and add folder info # Convert to dict and add folder info
lora_data = metadata.to_dict() lora_data = metadata.to_dict()
@@ -207,3 +203,55 @@ class LoraScanner:
return False return False
return self._cache.update_preview_url(file_path, preview_url) return self._cache.update_preview_url(file_path, preview_url)
async def scan_single_lora(self, file_path: str) -> Optional[Dict]:
    """Scan a single LoRA file and return its metadata as a dict.

    Returns None when the file does not exist or metadata extraction fails.
    The returned dict always carries a 'folder' key: the file's directory
    relative to its LoRA root, '/'-separated, '' for the root itself.
    """
    try:
        if not os.path.exists(file_path):
            return None

        # Basic file info (name, paths, preview, ...) from the shared helper.
        metadata = await get_file_info(file_path)
        if not metadata:
            return None

        metadata_dict = metadata.to_dict()
        metadata_dict['folder'] = self._folder_relative_to_root(
            os.path.dirname(file_path)
        )
        return metadata_dict

    except Exception as e:
        logger.error(f"Error scanning {file_path}: {e}")
        return None

def _folder_relative_to_root(self, file_dir: str) -> str:
    """Return file_dir relative to the first matching configured LoRA root.

    Uses a separator-aware prefix test: the previous plain str.startswith
    check let a root like '/loras' spuriously match a sibling directory
    such as '/loras2'. Returns '' when no root matches or the file sits
    directly in a root.
    """
    dir_norm = os.path.normpath(file_dir)
    for root in config.loras_roots:
        root_norm = os.path.normpath(root)
        if dir_norm == root_norm or dir_norm.startswith(root_norm + os.sep):
            rel_path = os.path.relpath(dir_norm, root_norm)
            # '.' means the file lives directly in the root directory.
            return '' if rel_path == '.' else rel_path.replace(os.sep, '/')
    return ''
async def resort_cache(self):
    """Rebuild the cache's sorted views and folder list from raw_data."""
    cache = self._cache
    if not cache:
        return

    # Name view: ascending by model name.
    cache.sorted_by_name = sorted(cache.raw_data, key=itemgetter('model_name'))

    # Date view: newest first.
    cache.sorted_by_date = sorted(
        cache.raw_data,
        key=itemgetter('modified'),
        reverse=True,
    )

    # Refresh the distinct, sorted folder list.
    cache.folders = sorted({entry['folder'] for entry in cache.raw_data})

View File

@@ -2,6 +2,8 @@ import os
import hashlib import hashlib
import json import json
from typing import Dict, Optional from typing import Dict, Optional
from .lora_metadata import extract_lora_metadata
from .models import LoraMetadata from .models import LoraMetadata
async def calculate_sha256(file_path: str) -> str: async def calculate_sha256(file_path: str) -> str:
@@ -42,7 +44,7 @@ async def get_file_info(file_path: str) -> LoraMetadata:
preview_url = _find_preview_file(base_name, dir_path) preview_url = _find_preview_file(base_name, dir_path)
return LoraMetadata( metadata = LoraMetadata(
file_name=base_name, file_name=base_name,
model_name=base_name, model_name=base_name,
file_path=normalize_path(file_path), file_path=normalize_path(file_path),
@@ -54,6 +56,13 @@ async def get_file_info(file_path: str) -> LoraMetadata:
preview_url=normalize_path(preview_url), preview_url=normalize_path(preview_url),
) )
# create metadata file
base_model_info = await extract_lora_metadata(file_path)
metadata.base_model = base_model_info['base_model']
await save_metadata(file_path, metadata)
return metadata
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None: async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
"""Save metadata to .metadata.json file""" """Save metadata to .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"