Improve LoRA cache initialization with async locking and background task

This commit is contained in:
Will Miao
2025-02-02 23:51:36 +08:00
parent 8e8b80ddcf
commit 9aa0dcfb2b
3 changed files with 79 additions and 25 deletions

View File

@@ -12,9 +12,10 @@ class Config:
def _init_lora_paths(self) -> List[str]: def _init_lora_paths(self) -> List[str]:
"""Initialize and validate LoRA paths from ComfyUI settings""" """Initialize and validate LoRA paths from ComfyUI settings"""
paths = [path.replace(os.sep, "/") paths = list(set(path.replace(os.sep, "/")
for path in folder_paths.get_folder_paths("loras") for path in folder_paths.get_folder_paths("loras")
if os.path.exists(path)] if os.path.exists(path)))
print("Found LoRA roots:", "\n - " + "\n - ".join(paths))
if not paths: if not paths:
raise ValueError("No valid loras folders found in ComfyUI configuration") raise ValueError("No valid loras folders found in ComfyUI configuration")

View File

@@ -1,7 +1,9 @@
import asyncio
from server import PromptServer # type: ignore from server import PromptServer # type: ignore
from .config import config from .config import config
from .routes.lora_routes import LoraRoutes from .routes.lora_routes import LoraRoutes
from .routes.api_routes import ApiRoutes from .routes.api_routes import ApiRoutes
from .services.lora_scanner import LoraScanner
class LoraManager: class LoraManager:
"""Main entry point for LoRA Manager plugin""" """Main entry point for LoRA Manager plugin"""
@@ -20,5 +22,29 @@ class LoraManager:
app.router.add_static('/loras_static', config.static_path) app.router.add_static('/loras_static', config.static_path)
# Setup feature routes # Setup feature routes
routes = LoraRoutes()
api_routes = ApiRoutes()
LoraRoutes.setup_routes(app) LoraRoutes.setup_routes(app)
ApiRoutes.setup_routes(app) ApiRoutes.setup_routes(app)
# Schedule cache initialization using the application's startup handler
app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner))
@classmethod
async def _schedule_cache_init(cls, scanner: LoraScanner):
"""Schedule cache initialization in the running event loop"""
try:
# Create the initialization task
asyncio.create_task(cls._initialize_cache(scanner))
except Exception as e:
print(f"LoRA Manager: Error scheduling cache initialization: {e}")
@classmethod
async def _initialize_cache(cls, scanner: LoraScanner):
"""Initialize cache in background"""
try:
await scanner.get_cached_data(force_refresh=True)
print("LoRA Manager: Cache initialization completed")
except Exception as e:
print(f"LoRA Manager: Error initializing cache: {e}")

View File

@@ -1,6 +1,7 @@
import os import os
import logging import logging
import time import time
import asyncio
from typing import List, Dict, Optional from typing import List, Dict, Optional
from dataclasses import dataclass from dataclasses import dataclass
from operator import itemgetter from operator import itemgetter
@@ -55,34 +56,59 @@ class LoraScanner:
def __init__(self): def __init__(self):
self._cache: Optional[LoraCache] = None self._cache: Optional[LoraCache] = None
self._initialization_lock = asyncio.Lock()
self._initialization_task: Optional[asyncio.Task] = None
self.cache_ttl = 300 # 5 minutes cache TTL self.cache_ttl = 300 # 5 minutes cache TTL
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache: async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
"""Get cached LoRA data, refresh if needed""" """Get cached LoRA data, refresh if needed"""
current_time = time.time() async with self._initialization_lock:
current_time = time.time()
# 如果正在初始化,等待完成
if self._initialization_task and not self._initialization_task.done():
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
self._initialization_task = None
if (self._cache is None or
force_refresh or
current_time - self._cache.timestamp > self.cache_ttl):
# 创建新的初始化任务
if not self._initialization_task or self._initialization_task.done():
self._initialization_task = asyncio.create_task(self._initialize_cache())
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
# 如果缓存已存在,继续使用旧缓存
if self._cache is None:
raise # 如果没有缓存,则抛出异常
return self._cache
async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache"""
# Scan for new data
raw_data = await self.scan_all_loras()
if (self._cache is None or # Create sorted views
force_refresh or sorted_by_name = sorted(raw_data, key=itemgetter('model_name'))
current_time - self._cache.timestamp > self.cache_ttl): sorted_by_date = sorted(raw_data, key=itemgetter('modified'), reverse=True)
folders = sorted(list(set(l['folder'] for l in raw_data)))
# Scan for new data
raw_data = await self.scan_all_loras()
# Create sorted views
sorted_by_name = sorted(raw_data, key=itemgetter('model_name'))
sorted_by_date = sorted(raw_data, key=itemgetter('modified'), reverse=True)
folders = sorted(list(set(l['folder'] for l in raw_data)))
# Update cache
self._cache = LoraCache(
raw_data=raw_data,
sorted_by_name=sorted_by_name,
sorted_by_date=sorted_by_date,
folders=folders,
timestamp=current_time
)
return self._cache # Update cache
self._cache = LoraCache(
raw_data=raw_data,
sorted_by_name=sorted_by_name,
sorted_by_date=sorted_by_date,
folders=folders,
timestamp=time.time()
)
async def get_paginated_data(self, async def get_paginated_data(self,
page: int, page: int,
@@ -90,6 +116,7 @@ class LoraScanner:
sort_by: str = 'date', sort_by: str = 'date',
folder: Optional[str] = None) -> Dict: folder: Optional[str] = None) -> Dict:
"""Get paginated LoRA data""" """Get paginated LoRA data"""
# 确保缓存已初始化
cache = await self.get_cached_data() cache = await self.get_cached_data()
# Select sorted data based on sort_by parameter # Select sorted data based on sort_by parameter