mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
Refactor LoRA management with improved caching and route handling
This commit is contained in:
@@ -24,9 +24,8 @@ class LoraManager:
|
|||||||
|
|
||||||
# Setup feature routes
|
# Setup feature routes
|
||||||
routes = LoraRoutes()
|
routes = LoraRoutes()
|
||||||
api_routes = ApiRoutes()
|
|
||||||
|
|
||||||
LoraRoutes.setup_routes(app)
|
routes.setup_routes(app)
|
||||||
ApiRoutes.setup_routes(app)
|
ApiRoutes.setup_routes(app)
|
||||||
|
|
||||||
# Setup file monitoring
|
# Setup file monitoring
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class LoraRoutes:
|
|||||||
# Get cached data
|
# Get cached data
|
||||||
cache = await self.scanner.get_cached_data()
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
# Format initial data (first page only)
|
# Get initial data (first page only)
|
||||||
initial_data = await self.scanner.get_paginated_data(
|
initial_data = await self.scanner.get_paginated_data(
|
||||||
page=1,
|
page=1,
|
||||||
page_size=20,
|
page_size=20,
|
||||||
@@ -83,8 +83,6 @@ class LoraRoutes:
|
|||||||
status=500
|
status=500
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
def setup_routes(self, app: web.Application):
|
||||||
def setup_routes(cls, app: web.Application):
|
|
||||||
"""Register routes with the application"""
|
"""Register routes with the application"""
|
||||||
routes = cls()
|
app.router.add_get('/loras', self.handle_loras_page)
|
||||||
app.router.add_get('/loras', routes.handle_loras_page)
|
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
|
from operator import itemgetter
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
from watchdog.observers import Observer
|
from watchdog.observers import Observer
|
||||||
from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent
|
from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent
|
||||||
from typing import List, Set, Callable
|
from typing import List
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from .lora_scanner import LoraScanner
|
from .lora_scanner import LoraScanner
|
||||||
|
|
||||||
@@ -59,37 +59,51 @@ class LoraFileHandler(FileSystemEventHandler):
|
|||||||
if not changes:
|
if not changes:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
logger.info(f"Processing {len(changes)} file changes")
|
logger.info(f"Processing {len(changes)} file changes")
|
||||||
|
|
||||||
# 获取当前缓存
|
async with self.scanner._cache._lock:
|
||||||
cache = await self.scanner.get_cached_data()
|
# 获取当前缓存
|
||||||
needs_resort = False
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
for action, file_path in changes:
|
needs_resort = False
|
||||||
try:
|
new_folders = set() # 用于收集新的文件夹
|
||||||
if action == 'add':
|
|
||||||
# 扫描新文件
|
for action, file_path in changes:
|
||||||
lora_data = await self.scanner.scan_single_lora(file_path)
|
try:
|
||||||
if lora_data:
|
if action == 'add':
|
||||||
cache.raw_data.append(lora_data)
|
# 扫描新文件
|
||||||
|
lora_data = await self.scanner.scan_single_lora(file_path)
|
||||||
|
if lora_data:
|
||||||
|
cache.raw_data.append(lora_data)
|
||||||
|
new_folders.add(lora_data['folder']) # 收集新文件夹
|
||||||
|
needs_resort = True
|
||||||
|
|
||||||
|
elif action == 'remove':
|
||||||
|
# 从缓存中移除
|
||||||
|
cache.raw_data = [
|
||||||
|
item for item in cache.raw_data
|
||||||
|
if item['file_path'] != file_path
|
||||||
|
]
|
||||||
needs_resort = True
|
needs_resort = True
|
||||||
|
|
||||||
elif action == 'remove':
|
except Exception as e:
|
||||||
# 从缓存中移除
|
logger.error(f"Error processing {action} for {file_path}: {e}")
|
||||||
cache.raw_data = [
|
|
||||||
item for item in cache.raw_data
|
|
||||||
if item['file_path'] != file_path
|
|
||||||
]
|
|
||||||
needs_resort = True
|
|
||||||
|
|
||||||
except Exception as e:
|
if needs_resort:
|
||||||
logger.error(f"Error processing {action} for {file_path}: {e}")
|
cache.sorted_by_name = sorted(
|
||||||
|
self.scanner._cache.raw_data,
|
||||||
|
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||||
|
)
|
||||||
|
cache.sorted_by_date = sorted(
|
||||||
|
self.scanner._cache.raw_data,
|
||||||
|
key=itemgetter('modified'),
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
|
||||||
# 如果有变更,更新排序并重置缓存时间
|
# 更新文件夹列表,包括新添加的文件夹
|
||||||
if needs_resort:
|
all_folders = set(cache.folders) | new_folders
|
||||||
await self.scanner.resort_cache()
|
cache.folders = sorted(list(all_folders))
|
||||||
# 更新缓存时间戳,确保下次获取时能得到最新数据
|
|
||||||
self.scanner._cache.last_update = time.time()
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in process_changes: {e}")
|
logger.error(f"Error in process_changes: {e}")
|
||||||
|
|||||||
64
services/lora_cache.py
Normal file
64
services/lora_cache.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import List, Dict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoraCache:
|
||||||
|
"""Cache structure for LoRA data"""
|
||||||
|
raw_data: List[Dict]
|
||||||
|
sorted_by_name: List[Dict]
|
||||||
|
sorted_by_date: List[Dict]
|
||||||
|
folders: List[str]
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def resort(self):
|
||||||
|
"""Resort all cached data views"""
|
||||||
|
async with self._lock:
|
||||||
|
self.sorted_by_name = sorted(
|
||||||
|
self.raw_data,
|
||||||
|
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||||
|
)
|
||||||
|
self.sorted_by_date = sorted(
|
||||||
|
self.raw_data,
|
||||||
|
key=itemgetter('modified'),
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
# Update folder list
|
||||||
|
self.folders = sorted(list(set(
|
||||||
|
l['folder'] for l in self.raw_data
|
||||||
|
)))
|
||||||
|
|
||||||
|
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
|
||||||
|
"""Update preview_url for a specific lora in all cached data
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: The file path of the lora to update
|
||||||
|
preview_url: The new preview URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the update was successful, False if the lora wasn't found
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
# Update in raw_data
|
||||||
|
for item in self.raw_data:
|
||||||
|
if item['file_path'] == file_path:
|
||||||
|
item['preview_url'] = preview_url
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return False # Lora not found
|
||||||
|
|
||||||
|
# Update in sorted lists (references to the same dict objects)
|
||||||
|
for item in self.sorted_by_name:
|
||||||
|
if item['file_path'] == file_path:
|
||||||
|
item['preview_url'] = preview_url
|
||||||
|
break
|
||||||
|
|
||||||
|
for item in self.sorted_by_date:
|
||||||
|
if item['file_path'] == file_path:
|
||||||
|
item['preview_url'] = preview_url
|
||||||
|
break
|
||||||
|
|
||||||
|
return True
|
||||||
@@ -7,48 +7,10 @@ from dataclasses import dataclass
|
|||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.file_utils import load_metadata, get_file_info
|
from ..utils.file_utils import load_metadata, get_file_info
|
||||||
|
from .lora_cache import LoraCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LoraCache:
|
|
||||||
"""Cache structure for LoRA data"""
|
|
||||||
raw_data: List[Dict]
|
|
||||||
sorted_by_name: List[Dict]
|
|
||||||
sorted_by_date: List[Dict]
|
|
||||||
folders: List[str]
|
|
||||||
|
|
||||||
def update_preview_url(self, file_path: str, preview_url: str) -> bool:
|
|
||||||
"""Update preview_url for a specific lora in all cached data
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: The file path of the lora to update
|
|
||||||
preview_url: The new preview URL
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the update was successful, False if the lora wasn't found
|
|
||||||
"""
|
|
||||||
# Update in raw_data
|
|
||||||
for item in self.raw_data:
|
|
||||||
if item['file_path'] == file_path:
|
|
||||||
item['preview_url'] = preview_url
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
return False # Lora not found
|
|
||||||
|
|
||||||
# Update in sorted lists (references to the same dict objects)
|
|
||||||
for item in self.sorted_by_name:
|
|
||||||
if item['file_path'] == file_path:
|
|
||||||
item['preview_url'] = preview_url
|
|
||||||
break
|
|
||||||
|
|
||||||
for item in self.sorted_by_date:
|
|
||||||
if item['file_path'] == file_path:
|
|
||||||
item['preview_url'] = preview_url
|
|
||||||
break
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
class LoraScanner:
|
class LoraScanner:
|
||||||
"""Service for scanning and managing LoRA files"""
|
"""Service for scanning and managing LoRA files"""
|
||||||
|
|
||||||
@@ -60,7 +22,6 @@ class LoraScanner:
|
|||||||
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
|
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
|
||||||
"""Get cached LoRA data, refresh if needed"""
|
"""Get cached LoRA data, refresh if needed"""
|
||||||
async with self._initialization_lock:
|
async with self._initialization_lock:
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
# 如果正在初始化,等待完成
|
# 如果正在初始化,等待完成
|
||||||
if self._initialization_task and not self._initialization_task.done():
|
if self._initialization_task and not self._initialization_task.done():
|
||||||
@@ -100,7 +61,7 @@ class LoraScanner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Call resort_cache to create sorted views
|
# Call resort_cache to create sorted views
|
||||||
await self.resort_cache()
|
await self._cache.resort()
|
||||||
|
|
||||||
async def get_paginated_data(self,
|
async def get_paginated_data(self,
|
||||||
page: int,
|
page: int,
|
||||||
@@ -111,26 +72,28 @@ class LoraScanner:
|
|||||||
# 确保缓存已初始化
|
# 确保缓存已初始化
|
||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
|
|
||||||
# Select sorted data based on sort_by parameter
|
async with cache._lock:
|
||||||
data = (cache.sorted_by_date if sort_by == 'date'
|
|
||||||
else cache.sorted_by_name)
|
|
||||||
|
|
||||||
# Apply folder filter if specified
|
# Select sorted data based on sort_by parameter
|
||||||
if folder is not None:
|
data = (cache.sorted_by_date if sort_by == 'date'
|
||||||
data = [item for item in data if item['folder'] == folder]
|
else cache.sorted_by_name)
|
||||||
|
|
||||||
# Calculate pagination
|
# Apply folder filter if specified
|
||||||
total_items = len(data)
|
if folder is not None:
|
||||||
start_idx = (page - 1) * page_size
|
data = [item for item in data if item['folder'] == folder]
|
||||||
end_idx = min(start_idx + page_size, total_items)
|
|
||||||
|
|
||||||
return {
|
# Calculate pagination
|
||||||
'items': data[start_idx:end_idx],
|
total_items = len(data)
|
||||||
'total': total_items,
|
start_idx = (page - 1) * page_size
|
||||||
'page': page,
|
end_idx = min(start_idx + page_size, total_items)
|
||||||
'page_size': page_size,
|
|
||||||
'total_pages': (total_items + page_size - 1) // page_size
|
return {
|
||||||
}
|
'items': data[start_idx:end_idx],
|
||||||
|
'total': total_items,
|
||||||
|
'page': page,
|
||||||
|
'page_size': page_size,
|
||||||
|
'total_pages': (total_items + page_size - 1) // page_size
|
||||||
|
}
|
||||||
|
|
||||||
def invalidate_cache(self):
|
def invalidate_cache(self):
|
||||||
"""Invalidate the current cache"""
|
"""Invalidate the current cache"""
|
||||||
@@ -230,21 +193,3 @@ class LoraScanner:
|
|||||||
logger.error(f"Error scanning {file_path}: {e}")
|
logger.error(f"Error scanning {file_path}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def resort_cache(self):
|
|
||||||
"""Resort cache data"""
|
|
||||||
if not self._cache:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._cache.sorted_by_name = sorted(
|
|
||||||
self._cache.raw_data,
|
|
||||||
key=lambda x: x['model_name'].lower() # 使用 lower() 来实现不区分大小写的排序
|
|
||||||
)
|
|
||||||
self._cache.sorted_by_date = sorted(
|
|
||||||
self._cache.raw_data,
|
|
||||||
key=itemgetter('modified'),
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
# 更新文件夹列表
|
|
||||||
self._cache.folders = sorted(list(set(
|
|
||||||
l['folder'] for l in self._cache.raw_data
|
|
||||||
)))
|
|
||||||
|
|||||||
Reference in New Issue
Block a user