mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Improve LoRA cache initialization with async locking and background task
This commit is contained in:
@@ -12,9 +12,10 @@ class Config:
|
|||||||
|
|
||||||
def _init_lora_paths(self) -> List[str]:
|
def _init_lora_paths(self) -> List[str]:
|
||||||
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
||||||
paths = [path.replace(os.sep, "/")
|
paths = list(set(path.replace(os.sep, "/")
|
||||||
for path in folder_paths.get_folder_paths("loras")
|
for path in folder_paths.get_folder_paths("loras")
|
||||||
if os.path.exists(path)]
|
if os.path.exists(path)))
|
||||||
|
print("Found LoRA roots:", "\n - " + "\n - ".join(paths))
|
||||||
|
|
||||||
if not paths:
|
if not paths:
|
||||||
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
import asyncio
|
||||||
from server import PromptServer # type: ignore
|
from server import PromptServer # type: ignore
|
||||||
from .config import config
|
from .config import config
|
||||||
from .routes.lora_routes import LoraRoutes
|
from .routes.lora_routes import LoraRoutes
|
||||||
from .routes.api_routes import ApiRoutes
|
from .routes.api_routes import ApiRoutes
|
||||||
|
from .services.lora_scanner import LoraScanner
|
||||||
|
|
||||||
class LoraManager:
|
class LoraManager:
|
||||||
"""Main entry point for LoRA Manager plugin"""
|
"""Main entry point for LoRA Manager plugin"""
|
||||||
@@ -20,5 +22,29 @@ class LoraManager:
|
|||||||
app.router.add_static('/loras_static', config.static_path)
|
app.router.add_static('/loras_static', config.static_path)
|
||||||
|
|
||||||
# Setup feature routes
|
# Setup feature routes
|
||||||
|
routes = LoraRoutes()
|
||||||
|
api_routes = ApiRoutes()
|
||||||
|
|
||||||
LoraRoutes.setup_routes(app)
|
LoraRoutes.setup_routes(app)
|
||||||
ApiRoutes.setup_routes(app)
|
ApiRoutes.setup_routes(app)
|
||||||
|
|
||||||
|
# Schedule cache initialization using the application's startup handler
|
||||||
|
app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _schedule_cache_init(cls, scanner: LoraScanner):
|
||||||
|
"""Schedule cache initialization in the running event loop"""
|
||||||
|
try:
|
||||||
|
# Create the initialization task
|
||||||
|
asyncio.create_task(cls._initialize_cache(scanner))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"LoRA Manager: Error scheduling cache initialization: {e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _initialize_cache(cls, scanner: LoraScanner):
|
||||||
|
"""Initialize cache in background"""
|
||||||
|
try:
|
||||||
|
await scanner.get_cached_data(force_refresh=True)
|
||||||
|
print("LoRA Manager: Cache initialization completed")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"LoRA Manager: Error initializing cache: {e}")
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import asyncio
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
@@ -55,34 +56,59 @@ class LoraScanner:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._cache: Optional[LoraCache] = None
|
self._cache: Optional[LoraCache] = None
|
||||||
|
self._initialization_lock = asyncio.Lock()
|
||||||
|
self._initialization_task: Optional[asyncio.Task] = None
|
||||||
self.cache_ttl = 300 # 5 minutes cache TTL
|
self.cache_ttl = 300 # 5 minutes cache TTL
|
||||||
|
|
||||||
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
|
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
|
||||||
"""Get cached LoRA data, refresh if needed"""
|
"""Get cached LoRA data, refresh if needed"""
|
||||||
current_time = time.time()
|
async with self._initialization_lock:
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
if (self._cache is None or
|
# 如果正在初始化,等待完成
|
||||||
force_refresh or
|
if self._initialization_task and not self._initialization_task.done():
|
||||||
current_time - self._cache.timestamp > self.cache_ttl):
|
try:
|
||||||
|
await self._initialization_task
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Cache initialization failed: {e}")
|
||||||
|
self._initialization_task = None
|
||||||
|
|
||||||
# Scan for new data
|
if (self._cache is None or
|
||||||
raw_data = await self.scan_all_loras()
|
force_refresh or
|
||||||
|
current_time - self._cache.timestamp > self.cache_ttl):
|
||||||
|
|
||||||
# Create sorted views
|
# 创建新的初始化任务
|
||||||
sorted_by_name = sorted(raw_data, key=itemgetter('model_name'))
|
if not self._initialization_task or self._initialization_task.done():
|
||||||
sorted_by_date = sorted(raw_data, key=itemgetter('modified'), reverse=True)
|
self._initialization_task = asyncio.create_task(self._initialize_cache())
|
||||||
folders = sorted(list(set(l['folder'] for l in raw_data)))
|
|
||||||
|
|
||||||
# Update cache
|
try:
|
||||||
self._cache = LoraCache(
|
await self._initialization_task
|
||||||
raw_data=raw_data,
|
except Exception as e:
|
||||||
sorted_by_name=sorted_by_name,
|
logger.error(f"Cache initialization failed: {e}")
|
||||||
sorted_by_date=sorted_by_date,
|
# 如果缓存已存在,继续使用旧缓存
|
||||||
folders=folders,
|
if self._cache is None:
|
||||||
timestamp=current_time
|
raise # 如果没有缓存,则抛出异常
|
||||||
)
|
|
||||||
|
|
||||||
return self._cache
|
return self._cache
|
||||||
|
|
||||||
|
async def _initialize_cache(self) -> None:
|
||||||
|
"""Initialize or refresh the cache"""
|
||||||
|
# Scan for new data
|
||||||
|
raw_data = await self.scan_all_loras()
|
||||||
|
|
||||||
|
# Create sorted views
|
||||||
|
sorted_by_name = sorted(raw_data, key=itemgetter('model_name'))
|
||||||
|
sorted_by_date = sorted(raw_data, key=itemgetter('modified'), reverse=True)
|
||||||
|
folders = sorted(list(set(l['folder'] for l in raw_data)))
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
self._cache = LoraCache(
|
||||||
|
raw_data=raw_data,
|
||||||
|
sorted_by_name=sorted_by_name,
|
||||||
|
sorted_by_date=sorted_by_date,
|
||||||
|
folders=folders,
|
||||||
|
timestamp=time.time()
|
||||||
|
)
|
||||||
|
|
||||||
async def get_paginated_data(self,
|
async def get_paginated_data(self,
|
||||||
page: int,
|
page: int,
|
||||||
@@ -90,6 +116,7 @@ class LoraScanner:
|
|||||||
sort_by: str = 'date',
|
sort_by: str = 'date',
|
||||||
folder: Optional[str] = None) -> Dict:
|
folder: Optional[str] = None) -> Dict:
|
||||||
"""Get paginated LoRA data"""
|
"""Get paginated LoRA data"""
|
||||||
|
# 确保缓存已初始化
|
||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
|
|
||||||
# Select sorted data based on sort_by parameter
|
# Select sorted data based on sort_by parameter
|
||||||
|
|||||||
Reference in New Issue
Block a user