mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Refactor LoRA management with improved caching and route handling
This commit is contained in:
@@ -24,9 +24,8 @@ class LoraManager:
|
||||
|
||||
# Setup feature routes
|
||||
routes = LoraRoutes()
|
||||
api_routes = ApiRoutes()
|
||||
|
||||
LoraRoutes.setup_routes(app)
|
||||
routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app)
|
||||
|
||||
# Setup file monitoring
|
||||
|
||||
@@ -51,7 +51,7 @@ class LoraRoutes:
|
||||
# Get cached data
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Format initial data (first page only)
|
||||
# Get initial data (first page only)
|
||||
initial_data = await self.scanner.get_paginated_data(
|
||||
page=1,
|
||||
page_size=20,
|
||||
@@ -83,8 +83,6 @@ class LoraRoutes:
|
||||
status=500
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application):
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
routes = cls()
|
||||
app.router.add_get('/loras', routes.handle_loras_page)
|
||||
app.router.add_get('/loras', self.handle_loras_page)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from operator import itemgetter
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import asyncio
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent
|
||||
from typing import List, Set, Callable
|
||||
from typing import List
|
||||
from threading import Lock
|
||||
from .lora_scanner import LoraScanner
|
||||
|
||||
@@ -58,38 +58,52 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
|
||||
if not changes:
|
||||
return
|
||||
|
||||
|
||||
|
||||
logger.info(f"Processing {len(changes)} file changes")
|
||||
|
||||
# 获取当前缓存
|
||||
cache = await self.scanner.get_cached_data()
|
||||
needs_resort = False
|
||||
|
||||
for action, file_path in changes:
|
||||
try:
|
||||
if action == 'add':
|
||||
# 扫描新文件
|
||||
lora_data = await self.scanner.scan_single_lora(file_path)
|
||||
if lora_data:
|
||||
cache.raw_data.append(lora_data)
|
||||
|
||||
async with self.scanner._cache._lock:
|
||||
# 获取当前缓存
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
needs_resort = False
|
||||
new_folders = set() # 用于收集新的文件夹
|
||||
|
||||
for action, file_path in changes:
|
||||
try:
|
||||
if action == 'add':
|
||||
# 扫描新文件
|
||||
lora_data = await self.scanner.scan_single_lora(file_path)
|
||||
if lora_data:
|
||||
cache.raw_data.append(lora_data)
|
||||
new_folders.add(lora_data['folder']) # 收集新文件夹
|
||||
needs_resort = True
|
||||
|
||||
elif action == 'remove':
|
||||
# 从缓存中移除
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data
|
||||
if item['file_path'] != file_path
|
||||
]
|
||||
needs_resort = True
|
||||
|
||||
elif action == 'remove':
|
||||
# 从缓存中移除
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data
|
||||
if item['file_path'] != file_path
|
||||
]
|
||||
needs_resort = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {action} for {file_path}: {e}")
|
||||
|
||||
# 如果有变更,更新排序并重置缓存时间
|
||||
if needs_resort:
|
||||
await self.scanner.resort_cache()
|
||||
# 更新缓存时间戳,确保下次获取时能得到最新数据
|
||||
self.scanner._cache.last_update = time.time()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {action} for {file_path}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
cache.sorted_by_name = sorted(
|
||||
self.scanner._cache.raw_data,
|
||||
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||
)
|
||||
cache.sorted_by_date = sorted(
|
||||
self.scanner._cache.raw_data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# 更新文件夹列表,包括新添加的文件夹
|
||||
all_folders = set(cache.folders) | new_folders
|
||||
cache.folders = sorted(list(all_folders))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_changes: {e}")
|
||||
|
||||
64
services/lora_cache.py
Normal file
64
services/lora_cache.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
|
||||
@dataclass
|
||||
class LoraCache:
|
||||
"""Cache structure for LoRA data"""
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def resort(self):
|
||||
"""Resort all cached data views"""
|
||||
async with self._lock:
|
||||
self.sorted_by_name = sorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||
)
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=True
|
||||
)
|
||||
# Update folder list
|
||||
self.folders = sorted(list(set(
|
||||
l['folder'] for l in self.raw_data
|
||||
)))
|
||||
|
||||
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
|
||||
"""Update preview_url for a specific lora in all cached data
|
||||
|
||||
Args:
|
||||
file_path: The file path of the lora to update
|
||||
preview_url: The new preview URL
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if the lora wasn't found
|
||||
"""
|
||||
async with self._lock:
|
||||
# Update in raw_data
|
||||
for item in self.raw_data:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
else:
|
||||
return False # Lora not found
|
||||
|
||||
# Update in sorted lists (references to the same dict objects)
|
||||
for item in self.sorted_by_name:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
for item in self.sorted_by_date:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
return True
|
||||
@@ -7,48 +7,10 @@ from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from ..config import config
|
||||
from ..utils.file_utils import load_metadata, get_file_info
|
||||
from .lora_cache import LoraCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class LoraCache:
|
||||
"""Cache structure for LoRA data"""
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
def update_preview_url(self, file_path: str, preview_url: str) -> bool:
|
||||
"""Update preview_url for a specific lora in all cached data
|
||||
|
||||
Args:
|
||||
file_path: The file path of the lora to update
|
||||
preview_url: The new preview URL
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if the lora wasn't found
|
||||
"""
|
||||
# Update in raw_data
|
||||
for item in self.raw_data:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
else:
|
||||
return False # Lora not found
|
||||
|
||||
# Update in sorted lists (references to the same dict objects)
|
||||
for item in self.sorted_by_name:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
for item in self.sorted_by_date:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
return True
|
||||
|
||||
class LoraScanner:
|
||||
"""Service for scanning and managing LoRA files"""
|
||||
|
||||
@@ -60,7 +22,6 @@ class LoraScanner:
|
||||
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
|
||||
"""Get cached LoRA data, refresh if needed"""
|
||||
async with self._initialization_lock:
|
||||
current_time = time.time()
|
||||
|
||||
# 如果正在初始化,等待完成
|
||||
if self._initialization_task and not self._initialization_task.done():
|
||||
@@ -100,7 +61,7 @@ class LoraScanner:
|
||||
)
|
||||
|
||||
# Call resort_cache to create sorted views
|
||||
await self.resort_cache()
|
||||
await self._cache.resort()
|
||||
|
||||
async def get_paginated_data(self,
|
||||
page: int,
|
||||
@@ -110,27 +71,29 @@ class LoraScanner:
|
||||
"""Get paginated LoRA data"""
|
||||
# 确保缓存已初始化
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
async with cache._lock:
|
||||
|
||||
# Select sorted data based on sort_by parameter
|
||||
data = (cache.sorted_by_date if sort_by == 'date'
|
||||
else cache.sorted_by_name)
|
||||
|
||||
# Apply folder filter if specified
|
||||
if folder is not None:
|
||||
data = [item for item in data if item['folder'] == folder]
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
return {
|
||||
'items': data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
# Select sorted data based on sort_by parameter
|
||||
data = (cache.sorted_by_date if sort_by == 'date'
|
||||
else cache.sorted_by_name)
|
||||
|
||||
# Apply folder filter if specified
|
||||
if folder is not None:
|
||||
data = [item for item in data if item['folder'] == folder]
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
return {
|
||||
'items': data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
def invalidate_cache(self):
|
||||
"""Invalidate the current cache"""
|
||||
@@ -229,22 +192,4 @@ class LoraScanner:
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {file_path}: {e}")
|
||||
return None
|
||||
|
||||
async def resort_cache(self):
|
||||
"""Resort cache data"""
|
||||
if not self._cache:
|
||||
return
|
||||
|
||||
self._cache.sorted_by_name = sorted(
|
||||
self._cache.raw_data,
|
||||
key=lambda x: x['model_name'].lower() # 使用 lower() 来实现不区分大小写的排序
|
||||
)
|
||||
self._cache.sorted_by_date = sorted(
|
||||
self._cache.raw_data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=True
|
||||
)
|
||||
# 更新文件夹列表
|
||||
self._cache.folders = sorted(list(set(
|
||||
l['folder'] for l in self._cache.raw_data
|
||||
)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user