diff --git a/lora_manager.py b/lora_manager.py index e647c5a4..1afa1a1c 100644 --- a/lora_manager.py +++ b/lora_manager.py @@ -24,9 +24,8 @@ class LoraManager: # Setup feature routes routes = LoraRoutes() - api_routes = ApiRoutes() - LoraRoutes.setup_routes(app) + routes.setup_routes(app) ApiRoutes.setup_routes(app) # Setup file monitoring diff --git a/routes/lora_routes.py b/routes/lora_routes.py index 3c61055e..350e785a 100644 --- a/routes/lora_routes.py +++ b/routes/lora_routes.py @@ -51,7 +51,7 @@ class LoraRoutes: # Get cached data cache = await self.scanner.get_cached_data() - # Format initial data (first page only) + # Get initial data (first page only) initial_data = await self.scanner.get_paginated_data( page=1, page_size=20, @@ -83,8 +83,6 @@ class LoraRoutes: status=500 ) - @classmethod - def setup_routes(cls, app: web.Application): + def setup_routes(self, app: web.Application): """Register routes with the application""" - routes = cls() - app.router.add_get('/loras', routes.handle_loras_page) + app.router.add_get('/loras', self.handle_loras_page) diff --git a/services/file_monitor.py b/services/file_monitor.py index ac102473..d9b1c06a 100644 --- a/services/file_monitor.py +++ b/services/file_monitor.py @@ -1,10 +1,10 @@ +from operator import itemgetter import os -import time import logging import asyncio from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent -from typing import List, Set, Callable +from typing import List from threading import Lock from .lora_scanner import LoraScanner @@ -58,38 +58,52 @@ class LoraFileHandler(FileSystemEventHandler): if not changes: return - + + logger.info(f"Processing {len(changes)} file changes") - - # 获取当前缓存 - cache = await self.scanner.get_cached_data() - needs_resort = False - - for action, file_path in changes: - try: - if action == 'add': - # 扫描新文件 - lora_data = await self.scanner.scan_single_lora(file_path) - if lora_data: - cache.raw_data.append(lora_data) + + async with self.scanner._cache._lock: + # 获取当前缓存 + cache = await self.scanner.get_cached_data() + + needs_resort = False + new_folders = set() # 用于收集新的文件夹 + + for action, file_path in changes: + try: + if action == 'add': + # 扫描新文件 + lora_data = await self.scanner.scan_single_lora(file_path) + if lora_data: + cache.raw_data.append(lora_data) + new_folders.add(lora_data['folder']) # 收集新文件夹 + needs_resort = True + + elif action == 'remove': + # 从缓存中移除 + cache.raw_data = [ + item for item in cache.raw_data + if item['file_path'] != file_path + ] needs_resort = True - elif action == 'remove': - # 从缓存中移除 - cache.raw_data = [ - item for item in cache.raw_data - if item['file_path'] != file_path - ] - needs_resort = True - - except Exception as e: - logger.error(f"Error processing {action} for {file_path}: {e}") - - # 如果有变更,更新排序并重置缓存时间 - if needs_resort: - await self.scanner.resort_cache() - # 更新缓存时间戳,确保下次获取时能得到最新数据 - self.scanner._cache.last_update = time.time() + except Exception as e: + logger.error(f"Error processing {action} for {file_path}: {e}") + + if needs_resort: + cache.sorted_by_name = sorted( + self.scanner._cache.raw_data, + key=lambda x: x['model_name'].lower() # Case-insensitive sort + ) + cache.sorted_by_date = sorted( + self.scanner._cache.raw_data, + key=itemgetter('modified'), + reverse=True + ) + + # 更新文件夹列表,包括新添加的文件夹 + all_folders = set(cache.folders) | new_folders + cache.folders = sorted(list(all_folders)) except Exception as e: logger.error(f"Error in process_changes: {e}") diff --git a/services/lora_cache.py b/services/lora_cache.py new file mode 100644 index 00000000..d97a6567 --- /dev/null +++ b/services/lora_cache.py @@ -0,0 +1,64 @@ +import asyncio +from typing import List, Dict +from dataclasses import dataclass +from operator import itemgetter + +@dataclass +class LoraCache: + """Cache structure for LoRA data""" + raw_data: List[Dict] + sorted_by_name: List[Dict] + sorted_by_date: List[Dict] + folders: List[str] + + def __post_init__(self): + self._lock = asyncio.Lock() + + async def resort(self): + """Resort all cached data views""" + async with self._lock: + self.sorted_by_name = sorted( + self.raw_data, + key=lambda x: x['model_name'].lower() # Case-insensitive sort + ) + self.sorted_by_date = sorted( + self.raw_data, + key=itemgetter('modified'), + reverse=True + ) + # Update folder list + self.folders = sorted(list(set( + l['folder'] for l in self.raw_data + ))) + + async def update_preview_url(self, file_path: str, preview_url: str) -> bool: + """Update preview_url for a specific lora in all cached data + + Args: + file_path: The file path of the lora to update + preview_url: The new preview URL + + Returns: + bool: True if the update was successful, False if the lora wasn't found + """ + async with self._lock: + # Update in raw_data + for item in self.raw_data: + if item['file_path'] == file_path: + item['preview_url'] = preview_url + break + else: + return False # Lora not found + + # Update in sorted lists (references to the same dict objects) + for item in self.sorted_by_name: + if item['file_path'] == file_path: + item['preview_url'] = preview_url + break + + for item in self.sorted_by_date: + if item['file_path'] == file_path: + item['preview_url'] = preview_url + break + + return True \ No newline at end of file diff --git a/services/lora_scanner.py b/services/lora_scanner.py index 501ba4bb..b055cc64 100644 --- a/services/lora_scanner.py +++ b/services/lora_scanner.py @@ -7,48 +7,10 @@ from dataclasses import dataclass from operator import itemgetter from ..config import config from ..utils.file_utils import load_metadata, get_file_info +from .lora_cache import LoraCache logger = logging.getLogger(__name__) -@dataclass -class LoraCache: - """Cache structure for LoRA data""" - raw_data: List[Dict] - sorted_by_name: List[Dict] - sorted_by_date: List[Dict] - folders: List[str] - - def update_preview_url(self, file_path: str, preview_url: str) -> bool: - """Update preview_url for a specific lora in all cached data - - Args: - file_path: The file path of the lora to update - preview_url: The new preview URL - - Returns: - bool: True if the update was successful, False if the lora wasn't found - """ - # Update in raw_data - for item in self.raw_data: - if item['file_path'] == file_path: - item['preview_url'] = preview_url - break - else: - return False # Lora not found - - # Update in sorted lists (references to the same dict objects) - for item in self.sorted_by_name: - if item['file_path'] == file_path: - item['preview_url'] = preview_url - break - - for item in self.sorted_by_date: - if item['file_path'] == file_path: - item['preview_url'] = preview_url - break - - return True - class LoraScanner: """Service for scanning and managing LoRA files""" @@ -60,7 +22,6 @@ class LoraScanner: async def get_cached_data(self, force_refresh: bool = False) -> LoraCache: """Get cached LoRA data, refresh if needed""" async with self._initialization_lock: - current_time = time.time() # 如果正在初始化,等待完成 if self._initialization_task and not self._initialization_task.done(): @@ -100,7 +61,7 @@ class LoraScanner: ) # Call resort_cache to create sorted views - await self.resort_cache() + await self._cache.resort() async def get_paginated_data(self, page: int, @@ -110,27 +71,29 @@ class LoraScanner: """Get paginated LoRA data""" # 确保缓存已初始化 cache = await self.get_cached_data() + + async with cache._lock: - # Select sorted data based on sort_by parameter - data = (cache.sorted_by_date if sort_by == 'date' - else cache.sorted_by_name) - - # Apply folder filter if specified - if folder is not None: - data = [item for item in data if item['folder'] == folder] - - # Calculate pagination - total_items = len(data) - start_idx = (page - 1) * page_size - end_idx = min(start_idx + page_size, total_items) - - return { - 'items': data[start_idx:end_idx], - 'total': total_items, - 'page': page, - 'page_size': page_size, - 'total_pages': (total_items + page_size - 1) // page_size - } + # Select sorted data based on sort_by parameter + data = (cache.sorted_by_date if sort_by == 'date' + else cache.sorted_by_name) + + # Apply folder filter if specified + if folder is not None: + data = [item for item in data if item['folder'] == folder] + + # Calculate pagination + total_items = len(data) + start_idx = (page - 1) * page_size + end_idx = min(start_idx + page_size, total_items) + + return { + 'items': data[start_idx:end_idx], + 'total': total_items, + 'page': page, + 'page_size': page_size, + 'total_pages': (total_items + page_size - 1) // page_size + } def invalidate_cache(self): """Invalidate the current cache""" @@ -229,22 +192,4 @@ class LoraScanner: except Exception as e: logger.error(f"Error scanning {file_path}: {e}") return None - - async def resort_cache(self): - """Resort cache data""" - if not self._cache: - return - - self._cache.sorted_by_name = sorted( - self._cache.raw_data, - key=lambda x: x['model_name'].lower() # 使用 lower() 来实现不区分大小写的排序 - ) - self._cache.sorted_by_date = sorted( - self._cache.raw_data, - key=itemgetter('modified'), - reverse=True - ) - # 更新文件夹列表 - self._cache.folders = sorted(list(set( - l['folder'] for l in self._cache.raw_data - ))) +