mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
Refactor cache initialization in LoraManager and RecipeScanner for improved background processing and error handling
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
from server import PromptServer # type: ignore
|
from server import PromptServer # type: ignore
|
||||||
from .config import config
|
from .config import config
|
||||||
from .routes.lora_routes import LoraRoutes
|
from .routes.lora_routes import LoraRoutes
|
||||||
@@ -10,9 +9,6 @@ from .services.lora_scanner import LoraScanner
|
|||||||
from .services.checkpoint_scanner import CheckpointScanner
|
from .services.checkpoint_scanner import CheckpointScanner
|
||||||
from .services.recipe_scanner import RecipeScanner
|
from .services.recipe_scanner import RecipeScanner
|
||||||
from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor
|
from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor
|
||||||
from .services.lora_cache import LoraCache
|
|
||||||
from .services.recipe_cache import RecipeCache
|
|
||||||
from .services.model_cache import ModelCache
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -117,68 +113,12 @@ class LoraManager:
|
|||||||
"""Schedule cache initialization in the running event loop"""
|
"""Schedule cache initialization in the running event loop"""
|
||||||
try:
|
try:
|
||||||
# Create low-priority initialization tasks
|
# Create low-priority initialization tasks
|
||||||
lora_task = asyncio.create_task(cls._initialize_lora_cache(lora_scanner), name='lora_cache_init')
|
lora_task = asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
|
||||||
checkpoint_task = asyncio.create_task(cls._initialize_checkpoint_cache(checkpoint_scanner), name='checkpoint_cache_init')
|
checkpoint_task = asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
|
||||||
recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner), name='recipe_cache_init')
|
recipe_task = asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||||
logger.info("Cache initialization tasks scheduled to run in background")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}")
|
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}")
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _initialize_lora_cache(cls, scanner: LoraScanner):
|
|
||||||
"""Initialize lora cache in background"""
|
|
||||||
try:
|
|
||||||
# Set initial placeholder cache
|
|
||||||
scanner._cache = LoraCache(
|
|
||||||
raw_data=[],
|
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[],
|
|
||||||
folders=[]
|
|
||||||
)
|
|
||||||
# 使用线程池执行耗时操作
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
await loop.run_in_executor(
|
|
||||||
None, # 使用默认线程池
|
|
||||||
lambda: scanner.get_cached_data_sync(force_refresh=True) # 创建同步版本的方法
|
|
||||||
)
|
|
||||||
# Load cache in phases
|
|
||||||
# await scanner.get_cached_data(force_refresh=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"LoRA Manager: Error initializing lora cache: {e}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _initialize_checkpoint_cache(cls, scanner: CheckpointScanner):
|
|
||||||
"""Initialize checkpoint cache in background"""
|
|
||||||
try:
|
|
||||||
# Set initial placeholder cache
|
|
||||||
scanner._cache = ModelCache(
|
|
||||||
raw_data=[],
|
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[],
|
|
||||||
folders=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load cache in phases
|
|
||||||
await scanner.get_cached_data(force_refresh=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"LoRA Manager: Error initializing checkpoint cache: {e}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _initialize_recipe_cache(cls, scanner: RecipeScanner):
|
|
||||||
"""Initialize recipe cache in background with a delay"""
|
|
||||||
try:
|
|
||||||
# Set initial empty cache
|
|
||||||
scanner._cache = RecipeCache(
|
|
||||||
raw_data=[],
|
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Force refresh to load the actual data
|
|
||||||
await scanner.get_cached_data(force_refresh=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"LoRA Manager: Error initializing recipe cache: {e}")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _cleanup(cls, app):
|
async def _cleanup(cls, app):
|
||||||
"""Cleanup resources"""
|
"""Cleanup resources"""
|
||||||
|
|||||||
@@ -2,12 +2,14 @@ import os
|
|||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import jinja2
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from ..services.checkpoint_scanner import CheckpointScanner
|
from ..services.checkpoint_scanner import CheckpointScanner
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
from ..services.settings_manager import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -16,6 +18,10 @@ class CheckpointsRoutes:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scanner = CheckpointScanner()
|
self.scanner = CheckpointScanner()
|
||||||
|
self.template_env = jinja2.Environment(
|
||||||
|
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||||
|
autoescape=True
|
||||||
|
)
|
||||||
|
|
||||||
def setup_routes(self, app):
|
def setup_routes(self, app):
|
||||||
"""Register routes with the aiohttp app"""
|
"""Register routes with the aiohttp app"""
|
||||||
@@ -144,3 +150,59 @@ class CheckpointsRoutes:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||||
return web.json_response({"error": str(e)}, status=500)
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle GET /checkpoints request"""
|
||||||
|
try:
|
||||||
|
# 检查缓存初始化状态,根据initialize_in_background的工作方式调整判断逻辑
|
||||||
|
is_initializing = (
|
||||||
|
self.scanner._cache is None or
|
||||||
|
len(self.scanner._cache.raw_data) == 0 or
|
||||||
|
hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_initializing:
|
||||||
|
# 如果正在初始化,返回一个只包含加载提示的页面
|
||||||
|
template = self.template_env.get_template('checkpoints.html')
|
||||||
|
rendered = template.render(
|
||||||
|
folders=[], # 空文件夹列表
|
||||||
|
is_initializing=True, # 新增标志
|
||||||
|
settings=settings, # Pass settings to template
|
||||||
|
request=request # Pass the request object to the template
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Checkpoints page is initializing, returning loading page")
|
||||||
|
else:
|
||||||
|
# 正常流程 - 获取已经初始化好的缓存数据
|
||||||
|
try:
|
||||||
|
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||||
|
template = self.template_env.get_template('checkpoints.html')
|
||||||
|
rendered = template.render(
|
||||||
|
folders=cache.folders,
|
||||||
|
is_initializing=False,
|
||||||
|
settings=settings, # Pass settings to template
|
||||||
|
request=request # Pass the request object to the template
|
||||||
|
)
|
||||||
|
logger.debug(f"Checkpoints page loaded successfully with {len(cache.raw_data)} items")
|
||||||
|
except Exception as cache_error:
|
||||||
|
logger.error(f"Error loading checkpoints cache data: {cache_error}")
|
||||||
|
# 如果获取缓存失败,也显示初始化页面
|
||||||
|
template = self.template_env.get_template('checkpoints.html')
|
||||||
|
rendered = template.render(
|
||||||
|
folders=[],
|
||||||
|
is_initializing=True,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
logger.info("Checkpoints cache error, returning initialization page")
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
text=rendered,
|
||||||
|
content_type='text/html'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
|
||||||
|
return web.Response(
|
||||||
|
text="Error loading checkpoints page",
|
||||||
|
status=500
|
||||||
|
)
|
||||||
|
|||||||
@@ -58,13 +58,11 @@ class LoraRoutes:
|
|||||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
||||||
"""Handle GET /loras request"""
|
"""Handle GET /loras request"""
|
||||||
try:
|
try:
|
||||||
# 检查缓存初始化状态,增强判断条件
|
# 检查缓存初始化状态,根据initialize_in_background的工作方式调整判断逻辑
|
||||||
is_initializing = (
|
is_initializing = (
|
||||||
self.scanner._cache is None or
|
self.scanner._cache is None or
|
||||||
(self.scanner._initialization_task is not None and
|
len(self.scanner._cache.raw_data) == 0 or
|
||||||
not self.scanner._initialization_task.done()) or
|
hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing
|
||||||
(self.scanner._cache is not None and len(self.scanner._cache.raw_data) == 0 and
|
|
||||||
self.scanner._initialization_task is not None)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_initializing:
|
if is_initializing:
|
||||||
@@ -79,7 +77,7 @@ class LoraRoutes:
|
|||||||
|
|
||||||
logger.info("Loras page is initializing, returning loading page")
|
logger.info("Loras page is initializing, returning loading page")
|
||||||
else:
|
else:
|
||||||
# 正常流程 - 但不要等待缓存刷新
|
# 正常流程 - 获取已经初始化好的缓存数据
|
||||||
try:
|
try:
|
||||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||||
template = self.template_env.get_template('loras.html')
|
template = self.template_env.get_template('loras.html')
|
||||||
@@ -117,32 +115,45 @@ class LoraRoutes:
|
|||||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||||
"""Handle GET /loras/recipes request"""
|
"""Handle GET /loras/recipes request"""
|
||||||
try:
|
try:
|
||||||
# Check cache initialization status
|
# 检查缓存初始化状态,与handle_loras_page保持一致的逻辑
|
||||||
is_initializing = (
|
is_initializing = (
|
||||||
self.recipe_scanner._cache is None and
|
self.recipe_scanner._cache is None or
|
||||||
(self.recipe_scanner._initialization_task is not None and
|
len(self.recipe_scanner._cache.raw_data) == 0 or
|
||||||
not self.recipe_scanner._initialization_task.done())
|
hasattr(self.recipe_scanner, '_is_initializing') and self.recipe_scanner._is_initializing
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_initializing:
|
if is_initializing:
|
||||||
# If initializing, return a loading page
|
# 如果正在初始化,返回一个只包含加载提示的页面
|
||||||
template = self.template_env.get_template('recipes.html')
|
template = self.template_env.get_template('recipes.html')
|
||||||
rendered = template.render(
|
rendered = template.render(
|
||||||
is_initializing=True,
|
is_initializing=True,
|
||||||
settings=settings,
|
settings=settings,
|
||||||
request=request # Pass the request object to the template
|
request=request # Pass the request object to the template
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# return empty recipes
|
|
||||||
recipes_data = []
|
|
||||||
|
|
||||||
template = self.template_env.get_template('recipes.html')
|
logger.info("Recipes page is initializing, returning loading page")
|
||||||
rendered = template.render(
|
else:
|
||||||
recipes=recipes_data,
|
# 正常流程 - 获取已经初始化好的缓存数据
|
||||||
is_initializing=False,
|
try:
|
||||||
settings=settings,
|
cache = await self.recipe_scanner.get_cached_data(force_refresh=False)
|
||||||
request=request # Pass the request object to the template
|
template = self.template_env.get_template('recipes.html')
|
||||||
)
|
rendered = template.render(
|
||||||
|
recipes=[], # Frontend will load recipes via API
|
||||||
|
is_initializing=False,
|
||||||
|
settings=settings,
|
||||||
|
request=request # Pass the request object to the template
|
||||||
|
)
|
||||||
|
logger.debug(f"Recipes page loaded successfully with {len(cache.raw_data)} items")
|
||||||
|
except Exception as cache_error:
|
||||||
|
logger.error(f"Error loading recipe cache data: {cache_error}")
|
||||||
|
# 如果获取缓存失败,也显示初始化页面
|
||||||
|
template = self.template_env.get_template('recipes.html')
|
||||||
|
rendered = template.render(
|
||||||
|
is_initializing=True,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
logger.info("Recipe cache error, returning initialization page")
|
||||||
|
|
||||||
return web.Response(
|
return web.Response(
|
||||||
text=rendered,
|
text=rendered,
|
||||||
|
|||||||
@@ -1146,7 +1146,7 @@ class RecipeRoutes:
|
|||||||
return web.json_response({'success': False, 'error': 'Lora hash is required'}, status=400)
|
return web.json_response({'success': False, 'error': 'Lora hash is required'}, status=400)
|
||||||
|
|
||||||
# Log the search parameters
|
# Log the search parameters
|
||||||
logger.info(f"Getting recipes for Lora by hash: {lora_hash}")
|
logger.debug(f"Getting recipes for Lora by hash: {lora_hash}")
|
||||||
|
|
||||||
# Get all recipes from cache
|
# Get all recipes from cache
|
||||||
cache = await self.recipe_scanner.get_cached_data()
|
cache = await self.recipe_scanner.get_cached_data()
|
||||||
|
|||||||
@@ -34,49 +34,96 @@ class ModelScanner:
|
|||||||
self.file_extensions = file_extensions
|
self.file_extensions = file_extensions
|
||||||
self._cache = None
|
self._cache = None
|
||||||
self._hash_index = hash_index or ModelHashIndex()
|
self._hash_index = hash_index or ModelHashIndex()
|
||||||
self._initialization_lock = asyncio.Lock()
|
|
||||||
self._initialization_task = None
|
|
||||||
self.file_monitor = None
|
self.file_monitor = None
|
||||||
self._tags_count = {} # Dictionary to store tag counts
|
self._tags_count = {} # Dictionary to store tag counts
|
||||||
|
self._is_initializing = False # Flag to track initialization state
|
||||||
|
|
||||||
def set_file_monitor(self, monitor):
|
def set_file_monitor(self, monitor):
|
||||||
"""Set file monitor instance"""
|
"""Set file monitor instance"""
|
||||||
self.file_monitor = monitor
|
self.file_monitor = monitor
|
||||||
|
|
||||||
async def get_cached_data(self, force_refresh: bool = False) -> ModelCache:
|
async def initialize_in_background(self) -> None:
|
||||||
"""Get cached model data, refresh if needed"""
|
"""Initialize cache in background using thread pool"""
|
||||||
async with self._initialization_lock:
|
try:
|
||||||
# Return empty cache if not initialized and no refresh requested
|
# Set initial empty cache to avoid None reference errors
|
||||||
if self._cache is None and not force_refresh:
|
if self._cache is None:
|
||||||
return ModelCache(
|
self._cache = ModelCache(
|
||||||
raw_data=[],
|
raw_data=[],
|
||||||
sorted_by_name=[],
|
sorted_by_name=[],
|
||||||
sorted_by_date=[],
|
sorted_by_date=[],
|
||||||
folders=[]
|
folders=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for ongoing initialization if any
|
|
||||||
if self._initialization_task and not self._initialization_task.done():
|
|
||||||
try:
|
|
||||||
await self._initialization_task
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Cache initialization failed: {e}")
|
|
||||||
self._initialization_task = None
|
|
||||||
|
|
||||||
if (self._cache is None or force_refresh):
|
# Set initializing flag to true
|
||||||
# Create new initialization task
|
self._is_initializing = True
|
||||||
if not self._initialization_task or self._initialization_task.done():
|
|
||||||
self._initialization_task = asyncio.create_task(self._initialize_cache())
|
start_time = time.time()
|
||||||
|
# Use thread pool to execute CPU-intensive operations
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None, # Use default thread pool
|
||||||
|
self._initialize_cache_sync # Run synchronous version in thread
|
||||||
|
)
|
||||||
|
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache in background: {e}")
|
||||||
|
finally:
|
||||||
|
# Always clear the initializing flag when done
|
||||||
|
self._is_initializing = False
|
||||||
|
|
||||||
|
def _initialize_cache_sync(self):
|
||||||
|
"""Synchronous version of cache initialization for thread pool execution"""
|
||||||
|
try:
|
||||||
|
# Create a new event loop for this thread
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
# Create a synchronous method to bypass the async lock
|
||||||
|
def sync_initialize_cache():
|
||||||
|
# Directly call the scan method to avoid lock issues
|
||||||
|
raw_data = loop.run_until_complete(self.scan_all_models())
|
||||||
|
|
||||||
try:
|
# Update hash index and tags count
|
||||||
await self._initialization_task
|
for model_data in raw_data:
|
||||||
except Exception as e:
|
if 'sha256' in model_data and 'file_path' in model_data:
|
||||||
logger.error(f"Cache initialization failed: {e}")
|
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
||||||
# Continue using old cache if it exists
|
|
||||||
if self._cache is None:
|
# Count tags
|
||||||
raise # Raise exception if no cache available
|
if 'tags' in model_data and model_data['tags']:
|
||||||
|
for tag in model_data['tags']:
|
||||||
|
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
self._cache.raw_data = raw_data
|
||||||
|
loop.run_until_complete(self._cache.resort())
|
||||||
|
|
||||||
|
return self._cache
|
||||||
|
|
||||||
return self._cache
|
# Run our sync initialization that avoids lock conflicts
|
||||||
|
return sync_initialize_cache()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in thread-based {self.model_type} cache initialization: {e}")
|
||||||
|
finally:
|
||||||
|
# Clean up the event loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
async def get_cached_data(self, force_refresh: bool = False) -> ModelCache:
|
||||||
|
"""Get cached model data, refresh if needed"""
|
||||||
|
# If cache is not initialized, return an empty cache
|
||||||
|
# Actual initialization should be done via initialize_in_background
|
||||||
|
if self._cache is None and not force_refresh:
|
||||||
|
return ModelCache(
|
||||||
|
raw_data=[],
|
||||||
|
sorted_by_name=[],
|
||||||
|
sorted_by_date=[],
|
||||||
|
folders=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# If force refresh is requested, initialize the cache directly
|
||||||
|
if force_refresh:
|
||||||
|
await self._initialize_cache()
|
||||||
|
|
||||||
|
return self._cache
|
||||||
|
|
||||||
async def _initialize_cache(self) -> None:
|
async def _initialize_cache(self) -> None:
|
||||||
"""Initialize or refresh the cache"""
|
"""Initialize or refresh the cache"""
|
||||||
@@ -112,7 +159,6 @@ class ModelScanner:
|
|||||||
# Resort cache
|
# Resort cache
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
self._initialization_task = None
|
|
||||||
logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
|
logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
|
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
|
||||||
@@ -157,12 +203,10 @@ class ModelScanner:
|
|||||||
|
|
||||||
async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
|
async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
|
||||||
"""Get model file info and metadata (extensible for different model types)"""
|
"""Get model file info and metadata (extensible for different model types)"""
|
||||||
# Implementation may vary by model type - override in subclasses if needed
|
|
||||||
return await get_file_info(file_path, self.model_class)
|
return await get_file_info(file_path, self.model_class)
|
||||||
|
|
||||||
def _calculate_folder(self, file_path: str) -> str:
|
def _calculate_folder(self, file_path: str) -> str:
|
||||||
"""Calculate the folder path for a model file"""
|
"""Calculate the folder path for a model file"""
|
||||||
# Use original path to calculate relative path
|
|
||||||
for root in self.get_model_roots():
|
for root in self.get_model_roots():
|
||||||
if file_path.startswith(root):
|
if file_path.startswith(root):
|
||||||
rel_path = os.path.relpath(file_path, root)
|
rel_path = os.path.relpath(file_path, root)
|
||||||
@@ -172,11 +216,9 @@ class ModelScanner:
|
|||||||
# Common methods shared between scanners
|
# Common methods shared between scanners
|
||||||
async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
|
async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
|
||||||
"""Process a single model file and return its metadata"""
|
"""Process a single model file and return its metadata"""
|
||||||
# Try loading existing metadata
|
|
||||||
metadata = await load_metadata(file_path, self.model_class)
|
metadata = await load_metadata(file_path, self.model_class)
|
||||||
|
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
# Try to find and use .civitai.info file first
|
|
||||||
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
|
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
|
||||||
if os.path.exists(civitai_info_path):
|
if os.path.exists(civitai_info_path):
|
||||||
try:
|
try:
|
||||||
@@ -185,11 +227,9 @@ class ModelScanner:
|
|||||||
|
|
||||||
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
||||||
if file_info:
|
if file_info:
|
||||||
# Create a minimal file_info with the required fields
|
|
||||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
file_info['name'] = file_name
|
file_info['name'] = file_name
|
||||||
|
|
||||||
# Use from_civitai_info to create metadata
|
|
||||||
metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
|
metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
|
||||||
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
|
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
|
||||||
await save_metadata(file_path, metadata)
|
await save_metadata(file_path, metadata)
|
||||||
@@ -197,14 +237,11 @@ class ModelScanner:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
|
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
|
||||||
|
|
||||||
# If still no metadata, create new metadata
|
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = await self._get_file_info(file_path)
|
metadata = await self._get_file_info(file_path)
|
||||||
|
|
||||||
# Convert to dict and add folder info
|
|
||||||
model_data = metadata.to_dict()
|
model_data = metadata.to_dict()
|
||||||
|
|
||||||
# Try to fetch missing metadata from Civitai if needed
|
|
||||||
await self._fetch_missing_metadata(file_path, model_data)
|
await self._fetch_missing_metadata(file_path, model_data)
|
||||||
rel_path = os.path.relpath(file_path, root_path)
|
rel_path = os.path.relpath(file_path, root_path)
|
||||||
folder = os.path.dirname(rel_path)
|
folder = os.path.dirname(rel_path)
|
||||||
@@ -215,59 +252,47 @@ class ModelScanner:
|
|||||||
async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None:
|
async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None:
|
||||||
"""Fetch missing description and tags from Civitai if needed"""
|
"""Fetch missing description and tags from Civitai if needed"""
|
||||||
try:
|
try:
|
||||||
# Skip if already marked as deleted on Civitai
|
|
||||||
if model_data.get('civitai_deleted', False):
|
if model_data.get('civitai_deleted', False):
|
||||||
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
|
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if we need to fetch additional metadata from Civitai
|
|
||||||
needs_metadata_update = False
|
needs_metadata_update = False
|
||||||
model_id = None
|
model_id = None
|
||||||
|
|
||||||
# Check if we have Civitai model ID but missing metadata
|
|
||||||
if model_data.get('civitai'):
|
if model_data.get('civitai'):
|
||||||
model_id = model_data['civitai'].get('modelId')
|
model_id = model_data['civitai'].get('modelId')
|
||||||
|
|
||||||
if model_id:
|
if model_id:
|
||||||
model_id = str(model_id)
|
model_id = str(model_id)
|
||||||
# Check if tags or description are missing
|
|
||||||
tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0
|
tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0
|
||||||
desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")
|
desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")
|
||||||
needs_metadata_update = tags_missing or desc_missing
|
needs_metadata_update = tags_missing or desc_missing
|
||||||
|
|
||||||
# Fetch missing metadata if needed
|
|
||||||
if needs_metadata_update and model_id:
|
if needs_metadata_update and model_id:
|
||||||
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
|
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
|
||||||
from ..services.civitai_client import CivitaiClient
|
from ..services.civitai_client import CivitaiClient
|
||||||
client = CivitaiClient()
|
client = CivitaiClient()
|
||||||
|
|
||||||
# Get metadata and status code
|
|
||||||
model_metadata, status_code = await client.get_model_metadata(model_id)
|
model_metadata, status_code = await client.get_model_metadata(model_id)
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
# Handle 404 status (model deleted from Civitai)
|
|
||||||
if status_code == 404:
|
if status_code == 404:
|
||||||
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
|
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
|
||||||
model_data['civitai_deleted'] = True
|
model_data['civitai_deleted'] = True
|
||||||
|
|
||||||
# Save the updated metadata
|
|
||||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(model_data, f, indent=2, ensure_ascii=False)
|
json.dump(model_data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
# Process valid metadata if available
|
|
||||||
elif model_metadata:
|
elif model_metadata:
|
||||||
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
|
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
|
||||||
|
|
||||||
# Update tags if they were missing
|
|
||||||
if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0):
|
if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0):
|
||||||
model_data['tags'] = model_metadata['tags']
|
model_data['tags'] = model_metadata['tags']
|
||||||
|
|
||||||
# Update description if it was missing
|
|
||||||
if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")):
|
if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")):
|
||||||
model_data['modelDescription'] = model_metadata['description']
|
model_data['modelDescription'] = model_metadata['description']
|
||||||
|
|
||||||
# Save the updated metadata
|
|
||||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(model_data, f, indent=2, ensure_ascii=False)
|
json.dump(model_data, f, indent=2, ensure_ascii=False)
|
||||||
@@ -292,14 +317,12 @@ class ModelScanner:
|
|||||||
for entry in entries:
|
for entry in entries:
|
||||||
try:
|
try:
|
||||||
if entry.is_file(follow_symlinks=True):
|
if entry.is_file(follow_symlinks=True):
|
||||||
# Check if file has supported extension
|
|
||||||
ext = os.path.splitext(entry.name)[1].lower()
|
ext = os.path.splitext(entry.name)[1].lower()
|
||||||
if ext in self.file_extensions:
|
if ext in self.file_extensions:
|
||||||
file_path = entry.path.replace(os.sep, "/")
|
file_path = entry.path.replace(os.sep, "/")
|
||||||
await self._process_single_file(file_path, original_root, models)
|
await self._process_single_file(file_path, original_root, models)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
elif entry.is_dir(follow_symlinks=True):
|
elif entry.is_dir(follow_symlinks=True):
|
||||||
# For directories, continue scanning with original path
|
|
||||||
await scan_recursive(entry.path, visited_paths)
|
await scan_recursive(entry.path, visited_paths)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||||
@@ -321,14 +344,11 @@ class ModelScanner:
|
|||||||
async def move_model(self, source_path: str, target_path: str) -> bool:
|
async def move_model(self, source_path: str, target_path: str) -> bool:
|
||||||
"""Move a model and its associated files to a new location"""
|
"""Move a model and its associated files to a new location"""
|
||||||
try:
|
try:
|
||||||
# Keep original path format
|
|
||||||
source_path = source_path.replace(os.sep, '/')
|
source_path = source_path.replace(os.sep, '/')
|
||||||
target_path = target_path.replace(os.sep, '/')
|
target_path = target_path.replace(os.sep, '/')
|
||||||
|
|
||||||
# Get file extension from source
|
|
||||||
file_ext = os.path.splitext(source_path)[1]
|
file_ext = os.path.splitext(source_path)[1]
|
||||||
|
|
||||||
# If no extension or not in supported extensions, return False
|
|
||||||
if not file_ext or file_ext.lower() not in self.file_extensions:
|
if not file_ext or file_ext.lower() not in self.file_extensions:
|
||||||
logger.error(f"Invalid file extension for model: {file_ext}")
|
logger.error(f"Invalid file extension for model: {file_ext}")
|
||||||
return False
|
return False
|
||||||
@@ -340,7 +360,6 @@ class ModelScanner:
|
|||||||
|
|
||||||
target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/')
|
target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/')
|
||||||
|
|
||||||
# Use real paths for file operations
|
|
||||||
real_source = os.path.realpath(source_path)
|
real_source = os.path.realpath(source_path)
|
||||||
real_target = os.path.realpath(target_file)
|
real_target = os.path.realpath(target_file)
|
||||||
|
|
||||||
@@ -356,10 +375,8 @@ class ModelScanner:
|
|||||||
file_size
|
file_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use real paths for file operations
|
|
||||||
shutil.move(real_source, real_target)
|
shutil.move(real_source, real_target)
|
||||||
|
|
||||||
# Move associated files
|
|
||||||
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
|
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
|
||||||
metadata = None
|
metadata = None
|
||||||
if os.path.exists(source_metadata):
|
if os.path.exists(source_metadata):
|
||||||
@@ -367,7 +384,6 @@ class ModelScanner:
|
|||||||
shutil.move(source_metadata, target_metadata)
|
shutil.move(source_metadata, target_metadata)
|
||||||
metadata = await self._update_metadata_paths(target_metadata, target_file)
|
metadata = await self._update_metadata_paths(target_metadata, target_file)
|
||||||
|
|
||||||
# Move preview file if exists
|
|
||||||
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
|
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
|
||||||
'.png', '.jpeg', '.jpg', '.mp4']
|
'.png', '.jpeg', '.jpg', '.mp4']
|
||||||
for ext in preview_extensions:
|
for ext in preview_extensions:
|
||||||
@@ -377,7 +393,6 @@ class ModelScanner:
|
|||||||
shutil.move(source_preview, target_preview)
|
shutil.move(source_preview, target_preview)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Update cache
|
|
||||||
await self.update_single_model_cache(source_path, target_file, metadata)
|
await self.update_single_model_cache(source_path, target_file, metadata)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@@ -392,10 +407,8 @@ class ModelScanner:
|
|||||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||||
metadata = json.load(f)
|
metadata = json.load(f)
|
||||||
|
|
||||||
# Update file_path
|
|
||||||
metadata['file_path'] = model_path.replace(os.sep, '/')
|
metadata['file_path'] = model_path.replace(os.sep, '/')
|
||||||
|
|
||||||
# Update preview_url if exists
|
|
||||||
if 'preview_url' in metadata:
|
if 'preview_url' in metadata:
|
||||||
preview_dir = os.path.dirname(model_path)
|
preview_dir = os.path.dirname(model_path)
|
||||||
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
|
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
|
||||||
@@ -403,7 +416,6 @@ class ModelScanner:
|
|||||||
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
|
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
|
||||||
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
|
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
|
||||||
|
|
||||||
# Save updated metadata
|
|
||||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
@@ -417,7 +429,6 @@ class ModelScanner:
|
|||||||
"""Update cache after a model has been moved or modified"""
|
"""Update cache after a model has been moved or modified"""
|
||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
|
|
||||||
# Find the existing item to remove its tags from count
|
|
||||||
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
|
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
|
||||||
if existing_item and 'tags' in existing_item:
|
if existing_item and 'tags' in existing_item:
|
||||||
for tag in existing_item.get('tags', []):
|
for tag in existing_item.get('tags', []):
|
||||||
@@ -426,19 +437,15 @@ class ModelScanner:
|
|||||||
if self._tags_count[tag] == 0:
|
if self._tags_count[tag] == 0:
|
||||||
del self._tags_count[tag]
|
del self._tags_count[tag]
|
||||||
|
|
||||||
# Remove old path from hash index if exists
|
|
||||||
self._hash_index.remove_by_path(original_path)
|
self._hash_index.remove_by_path(original_path)
|
||||||
|
|
||||||
# Remove the old entry from raw_data
|
|
||||||
cache.raw_data = [
|
cache.raw_data = [
|
||||||
item for item in cache.raw_data
|
item for item in cache.raw_data
|
||||||
if item['file_path'] != original_path
|
if item['file_path'] != original_path
|
||||||
]
|
]
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
# If this is an update to an existing path (not a move), ensure folder is preserved
|
|
||||||
if original_path == new_path:
|
if original_path == new_path:
|
||||||
# Find the folder from existing entries or calculate it
|
|
||||||
existing_folder = next((item['folder'] for item in cache.raw_data
|
existing_folder = next((item['folder'] for item in cache.raw_data
|
||||||
if item['file_path'] == original_path), None)
|
if item['file_path'] == original_path), None)
|
||||||
if existing_folder:
|
if existing_folder:
|
||||||
@@ -446,31 +453,24 @@ class ModelScanner:
|
|||||||
else:
|
else:
|
||||||
metadata['folder'] = self._calculate_folder(new_path)
|
metadata['folder'] = self._calculate_folder(new_path)
|
||||||
else:
|
else:
|
||||||
# For moved files, recalculate the folder
|
|
||||||
metadata['folder'] = self._calculate_folder(new_path)
|
metadata['folder'] = self._calculate_folder(new_path)
|
||||||
|
|
||||||
# Add the updated metadata to raw_data
|
|
||||||
cache.raw_data.append(metadata)
|
cache.raw_data.append(metadata)
|
||||||
|
|
||||||
# Update hash index with new path
|
|
||||||
if 'sha256' in metadata:
|
if 'sha256' in metadata:
|
||||||
self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
|
self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
|
||||||
|
|
||||||
# Update folders list
|
|
||||||
all_folders = set(item['folder'] for item in cache.raw_data)
|
all_folders = set(item['folder'] for item in cache.raw_data)
|
||||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||||
|
|
||||||
# Update tags count with the new/updated tags
|
|
||||||
if 'tags' in metadata:
|
if 'tags' in metadata:
|
||||||
for tag in metadata.get('tags', []):
|
for tag in metadata.get('tags', []):
|
||||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||||
|
|
||||||
# Resort cache
|
|
||||||
await cache.resort()
|
await cache.resort()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Hash index functionality (common for all model types)
|
|
||||||
def has_hash(self, sha256: str) -> bool:
|
def has_hash(self, sha256: str) -> bool:
|
||||||
"""Check if a model with given hash exists"""
|
"""Check if a model with given hash exists"""
|
||||||
return self._hash_index.has_hash(sha256.lower())
|
return self._hash_index.has_hash(sha256.lower())
|
||||||
@@ -485,12 +485,10 @@ class ModelScanner:
|
|||||||
|
|
||||||
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
|
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
|
||||||
"""Get preview static URL for a model by its hash"""
|
"""Get preview static URL for a model by its hash"""
|
||||||
# Get the file path first
|
|
||||||
file_path = self._hash_index.get_path(sha256.lower())
|
file_path = self._hash_index.get_path(sha256.lower())
|
||||||
if not file_path:
|
if not file_path:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Determine the preview file path (typically same name with different extension)
|
|
||||||
base_name = os.path.splitext(file_path)[0]
|
base_name = os.path.splitext(file_path)[0]
|
||||||
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
|
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
|
||||||
'.png', '.jpeg', '.jpg', '.mp4']
|
'.png', '.jpeg', '.jpg', '.mp4']
|
||||||
@@ -498,52 +496,42 @@ class ModelScanner:
|
|||||||
for ext in preview_extensions:
|
for ext in preview_extensions:
|
||||||
preview_path = f"{base_name}{ext}"
|
preview_path = f"{base_name}{ext}"
|
||||||
if os.path.exists(preview_path):
|
if os.path.exists(preview_path):
|
||||||
# Convert to static URL using config
|
|
||||||
return config.get_preview_static_url(preview_path)
|
return config.get_preview_static_url(preview_path)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||||
"""Get top tags sorted by count"""
|
"""Get top tags sorted by count"""
|
||||||
# Make sure cache is initialized
|
|
||||||
await self.get_cached_data()
|
await self.get_cached_data()
|
||||||
|
|
||||||
# Sort tags by count in descending order
|
|
||||||
sorted_tags = sorted(
|
sorted_tags = sorted(
|
||||||
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
|
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
|
||||||
key=lambda x: x['count'],
|
key=lambda x: x['count'],
|
||||||
reverse=True
|
reverse=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return limited number
|
|
||||||
return sorted_tags[:limit]
|
return sorted_tags[:limit]
|
||||||
|
|
||||||
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||||
"""Get base models sorted by frequency"""
|
"""Get base models sorted by frequency"""
|
||||||
# Make sure cache is initialized
|
|
||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
|
|
||||||
# Count base model occurrences
|
|
||||||
base_model_counts = {}
|
base_model_counts = {}
|
||||||
for model in cache.raw_data:
|
for model in cache.raw_data:
|
||||||
if 'base_model' in model and model['base_model']:
|
if 'base_model' in model and model['base_model']:
|
||||||
base_model = model['base_model']
|
base_model = model['base_model']
|
||||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||||
|
|
||||||
# Sort base models by count
|
|
||||||
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
|
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
|
||||||
sorted_models.sort(key=lambda x: x['count'], reverse=True)
|
sorted_models.sort(key=lambda x: x['count'], reverse=True)
|
||||||
|
|
||||||
# Return limited number
|
|
||||||
return sorted_models[:limit]
|
return sorted_models[:limit]
|
||||||
|
|
||||||
async def get_model_info_by_name(self, name):
|
async def get_model_info_by_name(self, name):
|
||||||
"""Get model information by name"""
|
"""Get model information by name"""
|
||||||
try:
|
try:
|
||||||
# Get cached data
|
|
||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
|
|
||||||
# Find the model by name
|
|
||||||
for model in cache.raw_data:
|
for model in cache.raw_data:
|
||||||
if model.get("file_name") == name:
|
if model.get("file_name") == name:
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -35,9 +35,61 @@ class RecipeScanner:
|
|||||||
if lora_scanner:
|
if lora_scanner:
|
||||||
self._lora_scanner = lora_scanner
|
self._lora_scanner = lora_scanner
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
# Initialization will be scheduled by LoraManager
|
|
||||||
|
|
||||||
|
async def initialize_in_background(self) -> None:
|
||||||
|
"""Initialize cache in background using thread pool"""
|
||||||
|
try:
|
||||||
|
# Set initial empty cache to avoid None reference errors
|
||||||
|
if self._cache is None:
|
||||||
|
self._cache = RecipeCache(
|
||||||
|
raw_data=[],
|
||||||
|
sorted_by_name=[],
|
||||||
|
sorted_by_date=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark as initializing to prevent concurrent initializations
|
||||||
|
self._is_initializing = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use thread pool to execute CPU-intensive operations
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None, # Use default thread pool
|
||||||
|
self._initialize_recipe_cache_sync # Run synchronous version in thread
|
||||||
|
)
|
||||||
|
logger.info("Recipe cache initialization completed in background thread")
|
||||||
|
finally:
|
||||||
|
# Mark initialization as complete regardless of outcome
|
||||||
|
self._is_initializing = False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Recipe Scanner: Error initializing cache in background: {e}")
|
||||||
|
|
||||||
|
def _initialize_recipe_cache_sync(self):
|
||||||
|
"""Synchronous version of recipe cache initialization for thread pool execution"""
|
||||||
|
try:
|
||||||
|
# Create a new event loop for this thread
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
# Create a synchronous method to bypass the async lock
|
||||||
|
def sync_initialize_cache():
|
||||||
|
# Directly call the internal scan method to avoid lock issues
|
||||||
|
raw_data = loop.run_until_complete(self.scan_all_recipes())
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
self._cache.raw_data = raw_data
|
||||||
|
loop.run_until_complete(self._cache.resort())
|
||||||
|
|
||||||
|
return self._cache
|
||||||
|
|
||||||
|
# Run our sync initialization that avoids lock conflicts
|
||||||
|
return sync_initialize_cache()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in thread-based recipe cache initialization: {e}")
|
||||||
|
finally:
|
||||||
|
# Clean up the event loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def recipes_dir(self) -> str:
|
def recipes_dir(self) -> str:
|
||||||
"""Get path to recipes directory"""
|
"""Get path to recipes directory"""
|
||||||
@@ -60,49 +112,48 @@ class RecipeScanner:
|
|||||||
if self._is_initializing and not force_refresh:
|
if self._is_initializing and not force_refresh:
|
||||||
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
||||||
|
|
||||||
# Try to acquire the lock with a timeout to prevent deadlocks
|
# If force refresh is requested, initialize the cache directly
|
||||||
try:
|
if force_refresh:
|
||||||
async with self._initialization_lock:
|
# Try to acquire the lock with a timeout to prevent deadlocks
|
||||||
# Check again after acquiring the lock
|
try:
|
||||||
if self._cache is not None and not force_refresh:
|
async with self._initialization_lock:
|
||||||
return self._cache
|
# Mark as initializing to prevent concurrent initializations
|
||||||
|
self._is_initializing = True
|
||||||
# Mark as initializing to prevent concurrent initializations
|
|
||||||
self._is_initializing = True
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Remove dependency on lora scanner initialization
|
|
||||||
# Scan for recipe data directly
|
|
||||||
raw_data = await self.scan_all_recipes()
|
|
||||||
|
|
||||||
# Update cache
|
try:
|
||||||
self._cache = RecipeCache(
|
# Scan for recipe data directly
|
||||||
raw_data=raw_data,
|
raw_data = await self.scan_all_recipes()
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[]
|
# Update cache
|
||||||
)
|
self._cache = RecipeCache(
|
||||||
|
raw_data=raw_data,
|
||||||
|
sorted_by_name=[],
|
||||||
|
sorted_by_date=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resort cache
|
||||||
|
await self._cache.resort()
|
||||||
|
|
||||||
|
return self._cache
|
||||||
|
|
||||||
# Resort cache
|
except Exception as e:
|
||||||
await self._cache.resort()
|
logger.error(f"Recipe Manager: Error initializing cache: {e}", exc_info=True)
|
||||||
|
# Create empty cache on error
|
||||||
return self._cache
|
self._cache = RecipeCache(
|
||||||
|
raw_data=[],
|
||||||
except Exception as e:
|
sorted_by_name=[],
|
||||||
logger.error(f"Recipe Manager: Error initializing cache: {e}", exc_info=True)
|
sorted_by_date=[]
|
||||||
# Create empty cache on error
|
)
|
||||||
self._cache = RecipeCache(
|
return self._cache
|
||||||
raw_data=[],
|
finally:
|
||||||
sorted_by_name=[],
|
# Mark initialization as complete
|
||||||
sorted_by_date=[]
|
self._is_initializing = False
|
||||||
)
|
|
||||||
return self._cache
|
except Exception as e:
|
||||||
finally:
|
logger.error(f"Unexpected error in get_cached_data: {e}")
|
||||||
# Mark initialization as complete
|
|
||||||
self._is_initializing = False
|
|
||||||
|
|
||||||
except Exception as e:
|
# Return the cache (may be empty or partially initialized)
|
||||||
logger.error(f"Unexpected error in get_cached_data: {e}")
|
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
||||||
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
|
||||||
|
|
||||||
async def scan_all_recipes(self) -> List[Dict]:
|
async def scan_all_recipes(self) -> List[Dict]:
|
||||||
"""Scan all recipe JSON files and return metadata"""
|
"""Scan all recipe JSON files and return metadata"""
|
||||||
|
|||||||
Reference in New Issue
Block a user