Refactor LoRA management with improved caching and route handling

This commit is contained in:
Will Miao
2025-02-03 21:23:49 +08:00
parent 3fa6c9e3a3
commit 12cdadb583
5 changed files with 138 additions and 118 deletions

View File

@@ -7,48 +7,10 @@ from dataclasses import dataclass
from operator import itemgetter
from ..config import config
from ..utils.file_utils import load_metadata, get_file_info
from .lora_cache import LoraCache
logger = logging.getLogger(__name__)
@dataclass
class LoraCache:
"""Cache structure for LoRA data"""
raw_data: List[Dict]
sorted_by_name: List[Dict]
sorted_by_date: List[Dict]
folders: List[str]
def update_preview_url(self, file_path: str, preview_url: str) -> bool:
"""Update preview_url for a specific lora in all cached data
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if the lora wasn't found
"""
# Update in raw_data
for item in self.raw_data:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
else:
return False # Lora not found
# Update in sorted lists (references to the same dict objects)
for item in self.sorted_by_name:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
for item in self.sorted_by_date:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
return True
class LoraScanner:
"""Service for scanning and managing LoRA files"""
@@ -60,7 +22,6 @@ class LoraScanner:
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
"""Get cached LoRA data, refresh if needed"""
async with self._initialization_lock:
current_time = time.time()
# 如果正在初始化,等待完成
if self._initialization_task and not self._initialization_task.done():
@@ -100,7 +61,7 @@ class LoraScanner:
)
# Call resort_cache to create sorted views
await self.resort_cache()
await self._cache.resort()
async def get_paginated_data(self,
page: int,
@@ -110,27 +71,29 @@ class LoraScanner:
"""Get paginated LoRA data"""
# 确保缓存已初始化
cache = await self.get_cached_data()
async with cache._lock:
# Select sorted data based on sort_by parameter
data = (cache.sorted_by_date if sort_by == 'date'
else cache.sorted_by_name)
# Apply folder filter if specified
if folder is not None:
data = [item for item in data if item['folder'] == folder]
# Calculate pagination
total_items = len(data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
return {
'items': data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
# Select sorted data based on sort_by parameter
data = (cache.sorted_by_date if sort_by == 'date'
else cache.sorted_by_name)
# Apply folder filter if specified
if folder is not None:
data = [item for item in data if item['folder'] == folder]
# Calculate pagination
total_items = len(data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
return {
'items': data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
def invalidate_cache(self):
"""Invalidate the current cache"""
@@ -229,22 +192,4 @@ class LoraScanner:
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None
async def resort_cache(self):
"""Resort cache data"""
if not self._cache:
return
self._cache.sorted_by_name = sorted(
self._cache.raw_data,
key=lambda x: x['model_name'].lower() # 使用 lower() 来实现不区分大小写的排序
)
self._cache.sorted_by_date = sorted(
self._cache.raw_data,
key=itemgetter('modified'),
reverse=True
)
# 更新文件夹列表
self._cache.folders = sorted(list(set(
l['folder'] for l in self._cache.raw_data
)))