diff --git a/config.py b/config.py index 684a4f05..57e00f39 100644 --- a/config.py +++ b/config.py @@ -12,9 +12,10 @@ class Config: def _init_lora_paths(self) -> List[str]: """Initialize and validate LoRA paths from ComfyUI settings""" - paths = [path.replace(os.sep, "/") + paths = list(set(path.replace(os.sep, "/") for path in folder_paths.get_folder_paths("loras") - if os.path.exists(path)] + if os.path.exists(path))) + print("Found LoRA roots:", "\n - " + "\n - ".join(paths)) if not paths: raise ValueError("No valid loras folders found in ComfyUI configuration") diff --git a/lora_manager.py b/lora_manager.py index a27511fd..c9801b26 100644 --- a/lora_manager.py +++ b/lora_manager.py @@ -1,7 +1,9 @@ +import asyncio from server import PromptServer # type: ignore from .config import config from .routes.lora_routes import LoraRoutes from .routes.api_routes import ApiRoutes +from .services.lora_scanner import LoraScanner class LoraManager: """Main entry point for LoRA Manager plugin""" @@ -20,5 +22,29 @@ class LoraManager: app.router.add_static('/loras_static', config.static_path) # Setup feature routes + routes = LoraRoutes() + api_routes = ApiRoutes() + LoraRoutes.setup_routes(app) - ApiRoutes.setup_routes(app) \ No newline at end of file + ApiRoutes.setup_routes(app) + + # Schedule cache initialization using the application's startup handler + app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner)) + + @classmethod + async def _schedule_cache_init(cls, scanner: LoraScanner): + """Schedule cache initialization in the running event loop""" + try: + # Create the initialization task + asyncio.create_task(cls._initialize_cache(scanner)) + except Exception as e: + print(f"LoRA Manager: Error scheduling cache initialization: {e}") + + @classmethod + async def _initialize_cache(cls, scanner: LoraScanner): + """Initialize cache in background""" + try: + await scanner.get_cached_data(force_refresh=True) + print("LoRA Manager: Cache initialization completed") + except Exception as e: + print(f"LoRA Manager: Error initializing cache: {e}") \ No newline at end of file diff --git a/services/lora_scanner.py b/services/lora_scanner.py index 55aa8177..b039ca19 100644 --- a/services/lora_scanner.py +++ b/services/lora_scanner.py @@ -1,6 +1,7 @@ import os import logging import time +import asyncio from typing import List, Dict, Optional from dataclasses import dataclass from operator import itemgetter @@ -55,34 +56,59 @@ class LoraScanner: def __init__(self): self._cache: Optional[LoraCache] = None + self._initialization_lock = asyncio.Lock() + self._initialization_task: Optional[asyncio.Task] = None self.cache_ttl = 300 # 5 minutes cache TTL async def get_cached_data(self, force_refresh: bool = False) -> LoraCache: """Get cached LoRA data, refresh if needed""" - current_time = time.time() + async with self._initialization_lock: + current_time = time.time() + + # 如果正在初始化,等待完成 + if self._initialization_task and not self._initialization_task.done(): + try: + await self._initialization_task + except Exception as e: + logger.error(f"Cache initialization failed: {e}") + self._initialization_task = None + + if (self._cache is None or + force_refresh or + current_time - self._cache.timestamp > self.cache_ttl): + + # 创建新的初始化任务 + if not self._initialization_task or self._initialization_task.done(): + self._initialization_task = asyncio.create_task(self._initialize_cache()) + + try: + await self._initialization_task + except Exception as e: + logger.error(f"Cache initialization failed: {e}") + # 如果缓存已存在,继续使用旧缓存 + if self._cache is None: + raise # 如果没有缓存,则抛出异常 + + return self._cache + + async def _initialize_cache(self) -> None: + """Initialize or refresh the cache""" + # Scan for new data + raw_data = await self.scan_all_loras() - if (self._cache is None or - force_refresh or - current_time - self._cache.timestamp > self.cache_ttl): - - # Scan for new data - raw_data = await self.scan_all_loras() - - # Create sorted views - sorted_by_name = sorted(raw_data, key=itemgetter('model_name')) - sorted_by_date = sorted(raw_data, key=itemgetter('modified'), reverse=True) - folders = sorted(list(set(l['folder'] for l in raw_data))) - - # Update cache - self._cache = LoraCache( - raw_data=raw_data, - sorted_by_name=sorted_by_name, - sorted_by_date=sorted_by_date, - folders=folders, - timestamp=current_time - ) + # Create sorted views + sorted_by_name = sorted(raw_data, key=itemgetter('model_name')) + sorted_by_date = sorted(raw_data, key=itemgetter('modified'), reverse=True) + folders = sorted(list(set(l['folder'] for l in raw_data))) - return self._cache + # Update cache + self._cache = LoraCache( + raw_data=raw_data, + sorted_by_name=sorted_by_name, + sorted_by_date=sorted_by_date, + folders=folders, + timestamp=time.time() + ) async def get_paginated_data(self, page: int, @@ -90,6 +116,7 @@ class LoraScanner: sort_by: str = 'date', folder: Optional[str] = None) -> Dict: """Get paginated LoRA data""" + # 确保缓存已初始化 cache = await self.get_cached_data() # Select sorted data based on sort_by parameter