Refactor LoRA management with improved caching and route handling

This commit is contained in:
Will Miao
2025-02-03 21:23:49 +08:00
parent 3fa6c9e3a3
commit 12cdadb583
5 changed files with 138 additions and 118 deletions

View File

@@ -24,9 +24,8 @@ class LoraManager:
# Setup feature routes
routes = LoraRoutes()
api_routes = ApiRoutes()
LoraRoutes.setup_routes(app)
routes.setup_routes(app)
ApiRoutes.setup_routes(app)
# Setup file monitoring

View File

@@ -51,7 +51,7 @@ class LoraRoutes:
# Get cached data
cache = await self.scanner.get_cached_data()
# Format initial data (first page only)
# Get initial data (first page only)
initial_data = await self.scanner.get_paginated_data(
page=1,
page_size=20,
@@ -83,8 +83,6 @@ class LoraRoutes:
status=500
)
@classmethod
def setup_routes(cls, app: web.Application):
def setup_routes(self, app: web.Application):
"""Register routes with the application"""
routes = cls()
app.router.add_get('/loras', routes.handle_loras_page)
app.router.add_get('/loras', self.handle_loras_page)

View File

@@ -1,10 +1,10 @@
from operator import itemgetter
import os
import time
import logging
import asyncio
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent
from typing import List, Set, Callable
from typing import List
from threading import Lock
from .lora_scanner import LoraScanner
@@ -58,38 +58,52 @@ class LoraFileHandler(FileSystemEventHandler):
if not changes:
return
logger.info(f"Processing {len(changes)} file changes")
# 获取当前缓存
cache = await self.scanner.get_cached_data()
needs_resort = False
for action, file_path in changes:
try:
if action == 'add':
# 扫描新文件
lora_data = await self.scanner.scan_single_lora(file_path)
if lora_data:
cache.raw_data.append(lora_data)
async with self.scanner._cache._lock:
# 获取当前缓存
cache = await self.scanner.get_cached_data()
needs_resort = False
new_folders = set() # 用于收集新的文件夹
for action, file_path in changes:
try:
if action == 'add':
# 扫描新文件
lora_data = await self.scanner.scan_single_lora(file_path)
if lora_data:
cache.raw_data.append(lora_data)
new_folders.add(lora_data['folder']) # 收集新文件夹
needs_resort = True
elif action == 'remove':
# 从缓存中移除
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
]
needs_resort = True
elif action == 'remove':
# 从缓存中移除
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
]
needs_resort = True
except Exception as e:
logger.error(f"Error processing {action} for {file_path}: {e}")
# 如果有变更,更新排序并重置缓存时间
if needs_resort:
await self.scanner.resort_cache()
# 更新缓存时间戳,确保下次获取时能得到最新数据
self.scanner._cache.last_update = time.time()
except Exception as e:
logger.error(f"Error processing {action} for {file_path}: {e}")
if needs_resort:
cache.sorted_by_name = sorted(
self.scanner._cache.raw_data,
key=lambda x: x['model_name'].lower() # Case-insensitive sort
)
cache.sorted_by_date = sorted(
self.scanner._cache.raw_data,
key=itemgetter('modified'),
reverse=True
)
# 更新文件夹列表,包括新添加的文件夹
all_folders = set(cache.folders) | new_folders
cache.folders = sorted(list(all_folders))
except Exception as e:
logger.error(f"Error in process_changes: {e}")

64
services/lora_cache.py Normal file
View File

@@ -0,0 +1,64 @@
import asyncio
from typing import List, Dict
from dataclasses import dataclass
from operator import itemgetter
@dataclass
class LoraCache:
    """In-memory cache of LoRA metadata with pre-sorted views.

    Attributes:
        raw_data: Unordered list of per-LoRA metadata dicts; each dict is
            expected to carry at least 'file_path', 'model_name',
            'modified', 'folder' and 'preview_url' keys (TODO confirm
            against the scanner that builds them).
        sorted_by_name: raw_data sorted case-insensitively by model name.
        sorted_by_date: raw_data sorted by 'modified', newest first.
        folders: Sorted, de-duplicated folder names present in raw_data.
    """
    raw_data: List[Dict]
    sorted_by_name: List[Dict]
    sorted_by_date: List[Dict]
    folders: List[str]

    def __post_init__(self):
        # Serializes async mutations of the cache. NOTE(review): on
        # Python < 3.10 asyncio.Lock() binds to the event loop that is
        # current at construction time — confirm the cache is created
        # from within the running loop.
        self._lock = asyncio.Lock()

    async def resort(self):
        """Rebuild both sorted views and the folder list from raw_data."""
        async with self._lock:
            self.sorted_by_name = sorted(
                self.raw_data,
                key=lambda x: x['model_name'].lower()  # case-insensitive
            )
            self.sorted_by_date = sorted(
                self.raw_data,
                key=itemgetter('modified'),
                reverse=True
            )
            # De-duplicate and sort folder names in one pass.
            self.folders = sorted({item['folder'] for item in self.raw_data})

    async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
        """Update preview_url for a specific lora in all cached data.

        The sorted views are built by sorting raw_data, so they reference
        the very same dict objects; mutating the dict found in raw_data
        therefore updates every view at once — no per-view scans needed
        (the original implementation redundantly re-scanned both sorted
        lists).

        Args:
            file_path: The file path of the lora to update.
            preview_url: The new preview URL.

        Returns:
            bool: True if the update was successful, False if the lora
            wasn't found.
        """
        async with self._lock:
            for item in self.raw_data:
                if item['file_path'] == file_path:
                    item['preview_url'] = preview_url
                    return True
            return False  # Lora not found

View File

@@ -7,48 +7,10 @@ from dataclasses import dataclass
from operator import itemgetter
from ..config import config
from ..utils.file_utils import load_metadata, get_file_info
from .lora_cache import LoraCache
logger = logging.getLogger(__name__)
@dataclass
class LoraCache:
"""Cache structure for LoRA data"""
raw_data: List[Dict]
sorted_by_name: List[Dict]
sorted_by_date: List[Dict]
folders: List[str]
def update_preview_url(self, file_path: str, preview_url: str) -> bool:
"""Update preview_url for a specific lora in all cached data
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if the lora wasn't found
"""
# Update in raw_data
for item in self.raw_data:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
else:
return False # Lora not found
# Update in sorted lists (references to the same dict objects)
for item in self.sorted_by_name:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
for item in self.sorted_by_date:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
return True
class LoraScanner:
"""Service for scanning and managing LoRA files"""
@@ -60,7 +22,6 @@ class LoraScanner:
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
"""Get cached LoRA data, refresh if needed"""
async with self._initialization_lock:
current_time = time.time()
# 如果正在初始化,等待完成
if self._initialization_task and not self._initialization_task.done():
@@ -100,7 +61,7 @@ class LoraScanner:
)
# Call resort_cache to create sorted views
await self.resort_cache()
await self._cache.resort()
async def get_paginated_data(self,
page: int,
@@ -110,27 +71,29 @@ class LoraScanner:
"""Get paginated LoRA data"""
# 确保缓存已初始化
cache = await self.get_cached_data()
async with cache._lock:
# Select sorted data based on sort_by parameter
data = (cache.sorted_by_date if sort_by == 'date'
else cache.sorted_by_name)
# Apply folder filter if specified
if folder is not None:
data = [item for item in data if item['folder'] == folder]
# Calculate pagination
total_items = len(data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
return {
'items': data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
# Select sorted data based on sort_by parameter
data = (cache.sorted_by_date if sort_by == 'date'
else cache.sorted_by_name)
# Apply folder filter if specified
if folder is not None:
data = [item for item in data if item['folder'] == folder]
# Calculate pagination
total_items = len(data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
return {
'items': data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
def invalidate_cache(self):
"""Invalidate the current cache"""
@@ -229,22 +192,4 @@ class LoraScanner:
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None
async def resort_cache(self):
    """Resort cache data.

    Rebuilds the name-sorted view, the date-sorted view and the folder
    list of self._cache from its raw_data. No-op when the cache has not
    been initialized yet.
    """
    if not self._cache:
        return
    self._cache.sorted_by_name = sorted(
        self._cache.raw_data,
        key=lambda x: x['model_name'].lower()  # lower() gives a case-insensitive sort
    )
    self._cache.sorted_by_date = sorted(
        self._cache.raw_data,
        key=itemgetter('modified'),
        reverse=True
    )
    # Rebuild the folder list (de-duplicated, sorted) from current raw data
    self._cache.folders = sorted(list(set(
        l['folder'] for l in self._cache.raw_data
    )))