Refactor cache initialization in LoraManager and RecipeScanner for improved background processing and error handling

This commit is contained in:
Will Miao
2025-04-10 11:34:19 +08:00
parent 8fdfb68741
commit 048d486fa6
6 changed files with 266 additions and 214 deletions

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
import os
from server import PromptServer # type: ignore from server import PromptServer # type: ignore
from .config import config from .config import config
from .routes.lora_routes import LoraRoutes from .routes.lora_routes import LoraRoutes
@@ -10,9 +9,6 @@ from .services.lora_scanner import LoraScanner
from .services.checkpoint_scanner import CheckpointScanner from .services.checkpoint_scanner import CheckpointScanner
from .services.recipe_scanner import RecipeScanner from .services.recipe_scanner import RecipeScanner
from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor
from .services.lora_cache import LoraCache
from .services.recipe_cache import RecipeCache
from .services.model_cache import ModelCache
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -117,68 +113,12 @@ class LoraManager:
"""Schedule cache initialization in the running event loop""" """Schedule cache initialization in the running event loop"""
try: try:
# Create low-priority initialization tasks # Create low-priority initialization tasks
lora_task = asyncio.create_task(cls._initialize_lora_cache(lora_scanner), name='lora_cache_init') lora_task = asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
checkpoint_task = asyncio.create_task(cls._initialize_checkpoint_cache(checkpoint_scanner), name='checkpoint_cache_init') checkpoint_task = asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner), name='recipe_cache_init') recipe_task = asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
logger.info("Cache initialization tasks scheduled to run in background")
except Exception as e: except Exception as e:
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}") logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}")
@classmethod
async def _initialize_lora_cache(cls, scanner: LoraScanner):
"""Initialize lora cache in background"""
try:
# Set initial placeholder cache
scanner._cache = LoraCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# 使用线程池执行耗时操作
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None, # 使用默认线程池
lambda: scanner.get_cached_data_sync(force_refresh=True) # 创建同步版本的方法
)
# Load cache in phases
# await scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing lora cache: {e}")
@classmethod
async def _initialize_checkpoint_cache(cls, scanner: CheckpointScanner):
"""Initialize checkpoint cache in background"""
try:
# Set initial placeholder cache
scanner._cache = ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Load cache in phases
await scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing checkpoint cache: {e}")
@classmethod
async def _initialize_recipe_cache(cls, scanner: RecipeScanner):
"""Initialize recipe cache in background with a delay"""
try:
# Set initial empty cache
scanner._cache = RecipeCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[]
)
# Force refresh to load the actual data
await scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing recipe cache: {e}")
@classmethod @classmethod
async def _cleanup(cls, app): async def _cleanup(cls, app):
"""Cleanup resources""" """Cleanup resources"""

View File

@@ -2,12 +2,14 @@ import os
import json import json
import asyncio import asyncio
import aiohttp import aiohttp
import jinja2
from aiohttp import web from aiohttp import web
import logging import logging
from datetime import datetime from datetime import datetime
from ..services.checkpoint_scanner import CheckpointScanner from ..services.checkpoint_scanner import CheckpointScanner
from ..config import config from ..config import config
from ..services.settings_manager import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,6 +18,10 @@ class CheckpointsRoutes:
def __init__(self): def __init__(self):
self.scanner = CheckpointScanner() self.scanner = CheckpointScanner()
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
def setup_routes(self, app): def setup_routes(self, app):
"""Register routes with the aiohttp app""" """Register routes with the aiohttp app"""
@@ -144,3 +150,59 @@ class CheckpointsRoutes:
except Exception as e: except Exception as e:
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True) logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500) return web.json_response({"error": str(e)}, status=500)
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
"""Handle GET /checkpoints request"""
try:
# 检查缓存初始化状态根据initialize_in_background的工作方式调整判断逻辑
is_initializing = (
self.scanner._cache is None or
len(self.scanner._cache.raw_data) == 0 or
hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing
)
if is_initializing:
# 如果正在初始化,返回一个只包含加载提示的页面
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[], # 空文件夹列表
is_initializing=True, # 新增标志
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
logger.info("Checkpoints page is initializing, returning loading page")
else:
# 正常流程 - 获取已经初始化好的缓存数据
try:
cache = await self.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
logger.debug(f"Checkpoints page loaded successfully with {len(cache.raw_data)} items")
except Exception as cache_error:
logger.error(f"Error loading checkpoints cache data: {cache_error}")
# 如果获取缓存失败,也显示初始化页面
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Checkpoints cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
return web.Response(
text="Error loading checkpoints page",
status=500
)

View File

@@ -58,13 +58,11 @@ class LoraRoutes:
async def handle_loras_page(self, request: web.Request) -> web.Response: async def handle_loras_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras request""" """Handle GET /loras request"""
try: try:
# 检查缓存初始化状态,增强判断条件 # 检查缓存初始化状态,根据initialize_in_background的工作方式调整判断逻辑
is_initializing = ( is_initializing = (
self.scanner._cache is None or self.scanner._cache is None or
(self.scanner._initialization_task is not None and len(self.scanner._cache.raw_data) == 0 or
not self.scanner._initialization_task.done()) or hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing
(self.scanner._cache is not None and len(self.scanner._cache.raw_data) == 0 and
self.scanner._initialization_task is not None)
) )
if is_initializing: if is_initializing:
@@ -79,7 +77,7 @@ class LoraRoutes:
logger.info("Loras page is initializing, returning loading page") logger.info("Loras page is initializing, returning loading page")
else: else:
# 正常流程 - 但不要等待缓存刷新 # 正常流程 - 获取已经初始化好的缓存数据
try: try:
cache = await self.scanner.get_cached_data(force_refresh=False) cache = await self.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('loras.html') template = self.template_env.get_template('loras.html')
@@ -117,32 +115,45 @@ class LoraRoutes:
async def handle_recipes_page(self, request: web.Request) -> web.Response: async def handle_recipes_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras/recipes request""" """Handle GET /loras/recipes request"""
try: try:
# Check cache initialization status # 检查缓存初始化状态与handle_loras_page保持一致的逻辑
is_initializing = ( is_initializing = (
self.recipe_scanner._cache is None and self.recipe_scanner._cache is None or
(self.recipe_scanner._initialization_task is not None and len(self.recipe_scanner._cache.raw_data) == 0 or
not self.recipe_scanner._initialization_task.done()) hasattr(self.recipe_scanner, '_is_initializing') and self.recipe_scanner._is_initializing
) )
if is_initializing: if is_initializing:
# If initializing, return a loading page # 如果正在初始化,返回一个只包含加载提示的页面
template = self.template_env.get_template('recipes.html') template = self.template_env.get_template('recipes.html')
rendered = template.render( rendered = template.render(
is_initializing=True, is_initializing=True,
settings=settings, settings=settings,
request=request # Pass the request object to the template request=request # Pass the request object to the template
) )
else:
# return empty recipes
recipes_data = []
template = self.template_env.get_template('recipes.html') logger.info("Recipes page is initializing, returning loading page")
rendered = template.render( else:
recipes=recipes_data, # 正常流程 - 获取已经初始化好的缓存数据
is_initializing=False, try:
settings=settings, cache = await self.recipe_scanner.get_cached_data(force_refresh=False)
request=request # Pass the request object to the template template = self.template_env.get_template('recipes.html')
) rendered = template.render(
recipes=[], # Frontend will load recipes via API
is_initializing=False,
settings=settings,
request=request # Pass the request object to the template
)
logger.debug(f"Recipes page loaded successfully with {len(cache.raw_data)} items")
except Exception as cache_error:
logger.error(f"Error loading recipe cache data: {cache_error}")
# 如果获取缓存失败,也显示初始化页面
template = self.template_env.get_template('recipes.html')
rendered = template.render(
is_initializing=True,
settings=settings,
request=request
)
logger.info("Recipe cache error, returning initialization page")
return web.Response( return web.Response(
text=rendered, text=rendered,

View File

@@ -1146,7 +1146,7 @@ class RecipeRoutes:
return web.json_response({'success': False, 'error': 'Lora hash is required'}, status=400) return web.json_response({'success': False, 'error': 'Lora hash is required'}, status=400)
# Log the search parameters # Log the search parameters
logger.info(f"Getting recipes for Lora by hash: {lora_hash}") logger.debug(f"Getting recipes for Lora by hash: {lora_hash}")
# Get all recipes from cache # Get all recipes from cache
cache = await self.recipe_scanner.get_cached_data() cache = await self.recipe_scanner.get_cached_data()

View File

@@ -34,49 +34,96 @@ class ModelScanner:
self.file_extensions = file_extensions self.file_extensions = file_extensions
self._cache = None self._cache = None
self._hash_index = hash_index or ModelHashIndex() self._hash_index = hash_index or ModelHashIndex()
self._initialization_lock = asyncio.Lock()
self._initialization_task = None
self.file_monitor = None self.file_monitor = None
self._tags_count = {} # Dictionary to store tag counts self._tags_count = {} # Dictionary to store tag counts
self._is_initializing = False # Flag to track initialization state
def set_file_monitor(self, monitor): def set_file_monitor(self, monitor):
"""Set file monitor instance""" """Set file monitor instance"""
self.file_monitor = monitor self.file_monitor = monitor
async def get_cached_data(self, force_refresh: bool = False) -> ModelCache: async def initialize_in_background(self) -> None:
"""Get cached model data, refresh if needed""" """Initialize cache in background using thread pool"""
async with self._initialization_lock: try:
# Return empty cache if not initialized and no refresh requested # Set initial empty cache to avoid None reference errors
if self._cache is None and not force_refresh: if self._cache is None:
return ModelCache( self._cache = ModelCache(
raw_data=[], raw_data=[],
sorted_by_name=[], sorted_by_name=[],
sorted_by_date=[], sorted_by_date=[],
folders=[] folders=[]
) )
# Wait for ongoing initialization if any
if self._initialization_task and not self._initialization_task.done():
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
self._initialization_task = None
if (self._cache is None or force_refresh): # Set initializing flag to true
# Create new initialization task self._is_initializing = True
if not self._initialization_task or self._initialization_task.done():
self._initialization_task = asyncio.create_task(self._initialize_cache()) start_time = time.time()
# Use thread pool to execute CPU-intensive operations
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None, # Use default thread pool
self._initialize_cache_sync # Run synchronous version in thread
)
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache in background: {e}")
finally:
# Always clear the initializing flag when done
self._is_initializing = False
def _initialize_cache_sync(self):
"""Synchronous version of cache initialization for thread pool execution"""
try:
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Create a synchronous method to bypass the async lock
def sync_initialize_cache():
# Directly call the scan method to avoid lock issues
raw_data = loop.run_until_complete(self.scan_all_models())
try: # Update hash index and tags count
await self._initialization_task for model_data in raw_data:
except Exception as e: if 'sha256' in model_data and 'file_path' in model_data:
logger.error(f"Cache initialization failed: {e}") self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Continue using old cache if it exists
if self._cache is None: # Count tags
raise # Raise exception if no cache available if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Update cache
self._cache.raw_data = raw_data
loop.run_until_complete(self._cache.resort())
return self._cache
return self._cache # Run our sync initialization that avoids lock conflicts
return sync_initialize_cache()
except Exception as e:
logger.error(f"Error in thread-based {self.model_type} cache initialization: {e}")
finally:
# Clean up the event loop
loop.close()
async def get_cached_data(self, force_refresh: bool = False) -> ModelCache:
"""Get cached model data, refresh if needed"""
# If cache is not initialized, return an empty cache
# Actual initialization should be done via initialize_in_background
if self._cache is None and not force_refresh:
return ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# If force refresh is requested, initialize the cache directly
if force_refresh:
await self._initialize_cache()
return self._cache
async def _initialize_cache(self) -> None: async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache""" """Initialize or refresh the cache"""
@@ -112,7 +159,6 @@ class ModelScanner:
# Resort cache # Resort cache
await self._cache.resort() await self._cache.resort()
self._initialization_task = None
logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models") logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
except Exception as e: except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}") logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
@@ -157,12 +203,10 @@ class ModelScanner:
async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]: async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
"""Get model file info and metadata (extensible for different model types)""" """Get model file info and metadata (extensible for different model types)"""
# Implementation may vary by model type - override in subclasses if needed
return await get_file_info(file_path, self.model_class) return await get_file_info(file_path, self.model_class)
def _calculate_folder(self, file_path: str) -> str: def _calculate_folder(self, file_path: str) -> str:
"""Calculate the folder path for a model file""" """Calculate the folder path for a model file"""
# Use original path to calculate relative path
for root in self.get_model_roots(): for root in self.get_model_roots():
if file_path.startswith(root): if file_path.startswith(root):
rel_path = os.path.relpath(file_path, root) rel_path = os.path.relpath(file_path, root)
@@ -172,11 +216,9 @@ class ModelScanner:
# Common methods shared between scanners # Common methods shared between scanners
async def _process_model_file(self, file_path: str, root_path: str) -> Dict: async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
"""Process a single model file and return its metadata""" """Process a single model file and return its metadata"""
# Try loading existing metadata
metadata = await load_metadata(file_path, self.model_class) metadata = await load_metadata(file_path, self.model_class)
if metadata is None: if metadata is None:
# Try to find and use .civitai.info file first
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info" civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
if os.path.exists(civitai_info_path): if os.path.exists(civitai_info_path):
try: try:
@@ -185,11 +227,9 @@ class ModelScanner:
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if file_info: if file_info:
# Create a minimal file_info with the required fields
file_name = os.path.splitext(os.path.basename(file_path))[0] file_name = os.path.splitext(os.path.basename(file_path))[0]
file_info['name'] = file_name file_info['name'] = file_name
# Use from_civitai_info to create metadata
metadata = self.model_class.from_civitai_info(version_info, file_info, file_path) metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path)) metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
await save_metadata(file_path, metadata) await save_metadata(file_path, metadata)
@@ -197,14 +237,11 @@ class ModelScanner:
except Exception as e: except Exception as e:
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}") logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
# If still no metadata, create new metadata
if metadata is None: if metadata is None:
metadata = await self._get_file_info(file_path) metadata = await self._get_file_info(file_path)
# Convert to dict and add folder info
model_data = metadata.to_dict() model_data = metadata.to_dict()
# Try to fetch missing metadata from Civitai if needed
await self._fetch_missing_metadata(file_path, model_data) await self._fetch_missing_metadata(file_path, model_data)
rel_path = os.path.relpath(file_path, root_path) rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path) folder = os.path.dirname(rel_path)
@@ -215,59 +252,47 @@ class ModelScanner:
async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None: async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None:
"""Fetch missing description and tags from Civitai if needed""" """Fetch missing description and tags from Civitai if needed"""
try: try:
# Skip if already marked as deleted on Civitai
if model_data.get('civitai_deleted', False): if model_data.get('civitai_deleted', False):
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai") logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
return return
# Check if we need to fetch additional metadata from Civitai
needs_metadata_update = False needs_metadata_update = False
model_id = None model_id = None
# Check if we have Civitai model ID but missing metadata
if model_data.get('civitai'): if model_data.get('civitai'):
model_id = model_data['civitai'].get('modelId') model_id = model_data['civitai'].get('modelId')
if model_id: if model_id:
model_id = str(model_id) model_id = str(model_id)
# Check if tags or description are missing
tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0 tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0
desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "") desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")
needs_metadata_update = tags_missing or desc_missing needs_metadata_update = tags_missing or desc_missing
# Fetch missing metadata if needed
if needs_metadata_update and model_id: if needs_metadata_update and model_id:
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}") logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
from ..services.civitai_client import CivitaiClient from ..services.civitai_client import CivitaiClient
client = CivitaiClient() client = CivitaiClient()
# Get metadata and status code
model_metadata, status_code = await client.get_model_metadata(model_id) model_metadata, status_code = await client.get_model_metadata(model_id)
await client.close() await client.close()
# Handle 404 status (model deleted from Civitai)
if status_code == 404: if status_code == 404:
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)") logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
model_data['civitai_deleted'] = True model_data['civitai_deleted'] = True
# Save the updated metadata
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f: with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(model_data, f, indent=2, ensure_ascii=False) json.dump(model_data, f, indent=2, ensure_ascii=False)
# Process valid metadata if available
elif model_metadata: elif model_metadata:
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}") logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
# Update tags if they were missing
if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0): if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0):
model_data['tags'] = model_metadata['tags'] model_data['tags'] = model_metadata['tags']
# Update description if it was missing
if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")): if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")):
model_data['modelDescription'] = model_metadata['description'] model_data['modelDescription'] = model_metadata['description']
# Save the updated metadata
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f: with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(model_data, f, indent=2, ensure_ascii=False) json.dump(model_data, f, indent=2, ensure_ascii=False)
@@ -292,14 +317,12 @@ class ModelScanner:
for entry in entries: for entry in entries:
try: try:
if entry.is_file(follow_symlinks=True): if entry.is_file(follow_symlinks=True):
# Check if file has supported extension
ext = os.path.splitext(entry.name)[1].lower() ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions: if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/") file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, models) await self._process_single_file(file_path, original_root, models)
await asyncio.sleep(0) await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True): elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths) await scan_recursive(entry.path, visited_paths)
except Exception as e: except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}") logger.error(f"Error processing entry {entry.path}: {e}")
@@ -321,14 +344,11 @@ class ModelScanner:
async def move_model(self, source_path: str, target_path: str) -> bool: async def move_model(self, source_path: str, target_path: str) -> bool:
"""Move a model and its associated files to a new location""" """Move a model and its associated files to a new location"""
try: try:
# Keep original path format
source_path = source_path.replace(os.sep, '/') source_path = source_path.replace(os.sep, '/')
target_path = target_path.replace(os.sep, '/') target_path = target_path.replace(os.sep, '/')
# Get file extension from source
file_ext = os.path.splitext(source_path)[1] file_ext = os.path.splitext(source_path)[1]
# If no extension or not in supported extensions, return False
if not file_ext or file_ext.lower() not in self.file_extensions: if not file_ext or file_ext.lower() not in self.file_extensions:
logger.error(f"Invalid file extension for model: {file_ext}") logger.error(f"Invalid file extension for model: {file_ext}")
return False return False
@@ -340,7 +360,6 @@ class ModelScanner:
target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/') target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/')
# Use real paths for file operations
real_source = os.path.realpath(source_path) real_source = os.path.realpath(source_path)
real_target = os.path.realpath(target_file) real_target = os.path.realpath(target_file)
@@ -356,10 +375,8 @@ class ModelScanner:
file_size file_size
) )
# Use real paths for file operations
shutil.move(real_source, real_target) shutil.move(real_source, real_target)
# Move associated files
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json") source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
metadata = None metadata = None
if os.path.exists(source_metadata): if os.path.exists(source_metadata):
@@ -367,7 +384,6 @@ class ModelScanner:
shutil.move(source_metadata, target_metadata) shutil.move(source_metadata, target_metadata)
metadata = await self._update_metadata_paths(target_metadata, target_file) metadata = await self._update_metadata_paths(target_metadata, target_file)
# Move preview file if exists
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4'] '.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions: for ext in preview_extensions:
@@ -377,7 +393,6 @@ class ModelScanner:
shutil.move(source_preview, target_preview) shutil.move(source_preview, target_preview)
break break
# Update cache
await self.update_single_model_cache(source_path, target_file, metadata) await self.update_single_model_cache(source_path, target_file, metadata)
return True return True
@@ -392,10 +407,8 @@ class ModelScanner:
with open(metadata_path, 'r', encoding='utf-8') as f: with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f) metadata = json.load(f)
# Update file_path
metadata['file_path'] = model_path.replace(os.sep, '/') metadata['file_path'] = model_path.replace(os.sep, '/')
# Update preview_url if exists
if 'preview_url' in metadata: if 'preview_url' in metadata:
preview_dir = os.path.dirname(model_path) preview_dir = os.path.dirname(model_path)
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0] preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
@@ -403,7 +416,6 @@ class ModelScanner:
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}") new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
metadata['preview_url'] = new_preview_path.replace(os.sep, '/') metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f: with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False) json.dump(metadata, f, indent=2, ensure_ascii=False)
@@ -417,7 +429,6 @@ class ModelScanner:
"""Update cache after a model has been moved or modified""" """Update cache after a model has been moved or modified"""
cache = await self.get_cached_data() cache = await self.get_cached_data()
# Find the existing item to remove its tags from count
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None) existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item and 'tags' in existing_item: if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []): for tag in existing_item.get('tags', []):
@@ -426,19 +437,15 @@ class ModelScanner:
if self._tags_count[tag] == 0: if self._tags_count[tag] == 0:
del self._tags_count[tag] del self._tags_count[tag]
# Remove old path from hash index if exists
self._hash_index.remove_by_path(original_path) self._hash_index.remove_by_path(original_path)
# Remove the old entry from raw_data
cache.raw_data = [ cache.raw_data = [
item for item in cache.raw_data item for item in cache.raw_data
if item['file_path'] != original_path if item['file_path'] != original_path
] ]
if metadata: if metadata:
# If this is an update to an existing path (not a move), ensure folder is preserved
if original_path == new_path: if original_path == new_path:
# Find the folder from existing entries or calculate it
existing_folder = next((item['folder'] for item in cache.raw_data existing_folder = next((item['folder'] for item in cache.raw_data
if item['file_path'] == original_path), None) if item['file_path'] == original_path), None)
if existing_folder: if existing_folder:
@@ -446,31 +453,24 @@ class ModelScanner:
else: else:
metadata['folder'] = self._calculate_folder(new_path) metadata['folder'] = self._calculate_folder(new_path)
else: else:
# For moved files, recalculate the folder
metadata['folder'] = self._calculate_folder(new_path) metadata['folder'] = self._calculate_folder(new_path)
# Add the updated metadata to raw_data
cache.raw_data.append(metadata) cache.raw_data.append(metadata)
# Update hash index with new path
if 'sha256' in metadata: if 'sha256' in metadata:
self._hash_index.add_entry(metadata['sha256'].lower(), new_path) self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
# Update folders list
all_folders = set(item['folder'] for item in cache.raw_data) all_folders = set(item['folder'] for item in cache.raw_data)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update tags count with the new/updated tags
if 'tags' in metadata: if 'tags' in metadata:
for tag in metadata.get('tags', []): for tag in metadata.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Resort cache
await cache.resort() await cache.resort()
return True return True
# Hash index functionality (common for all model types)
def has_hash(self, sha256: str) -> bool:
    """Return True when a model with the given SHA-256 hash is already indexed."""
    normalized = sha256.lower()
    return self._hash_index.has_hash(normalized)
@@ -485,12 +485,10 @@ class ModelScanner:
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
    """Get preview static URL for a model by its hash"""
    # Resolve the model's file path from the hash index first.
    model_path = self._hash_index.get_path(sha256.lower())
    if not model_path:
        return None

    # Preview files live next to the model, sharing its stem with one of
    # several well-known extensions; dedicated ".preview.*" names win.
    stem = os.path.splitext(model_path)[0]
    candidate_suffixes = ('.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
                          '.png', '.jpeg', '.jpg', '.mp4')

    for suffix in candidate_suffixes:
        candidate = stem + suffix
        if os.path.exists(candidate):
            # Convert the filesystem path into a servable static URL.
            return config.get_preview_static_url(candidate)

    return None
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
    """Return up to *limit* tags with their usage counts, most frequent first."""
    # Ensure the cache (and therefore the tag counter) is populated.
    await self.get_cached_data()

    # Build {"tag", "count"} entries, then order them by descending count.
    tag_entries = [{"tag": name, "count": total} for name, total in self._tags_count.items()]
    tag_entries.sort(key=lambda entry: entry['count'], reverse=True)

    return tag_entries[:limit]
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
    """Return up to *limit* base models with usage counts, most frequent first.

    Counts how many cached models declare each non-empty ``base_model``
    value and returns entries shaped like ``{"name": ..., "count": ...}``.
    """
    # Local import: this file's import header is not guaranteed to include it.
    from collections import Counter

    # Make sure cache is initialized before reading from it.
    cache = await self.get_cached_data()

    # Counter replaces the hand-rolled dict-increment loop; most_common()
    # keeps first-seen order among equal counts, matching the previous
    # stable-sort behavior.
    counts = Counter(
        model['base_model']
        for model in cache.raw_data
        if model.get('base_model')
    )

    return [{'name': name, 'count': count} for name, count in counts.most_common(limit)]
async def get_model_info_by_name(self, name): async def get_model_info_by_name(self, name):
"""Get model information by name""" """Get model information by name"""
try: try:
# Get cached data
cache = await self.get_cached_data() cache = await self.get_cached_data()
# Find the model by name
for model in cache.raw_data: for model in cache.raw_data:
if model.get("file_name") == name: if model.get("file_name") == name:
return model return model

View File

@@ -35,9 +35,61 @@ class RecipeScanner:
if lora_scanner: if lora_scanner:
self._lora_scanner = lora_scanner self._lora_scanner = lora_scanner
self._initialized = True self._initialized = True
# Initialization will be scheduled by LoraManager
async def initialize_in_background(self) -> None:
    """Initialize the recipe cache in the background without blocking the loop.

    The CPU/IO-heavy scan runs on the default thread-pool executor via
    ``_initialize_recipe_cache_sync``. An empty placeholder cache is
    installed first so concurrent readers never see ``None``.
    """
    try:
        # Set initial empty cache to avoid None reference errors
        if self._cache is None:
            self._cache = RecipeCache(
                raw_data=[],
                sorted_by_name=[],
                sorted_by_date=[]
            )

        # Mark as initializing to prevent concurrent initializations
        self._is_initializing = True
        try:
            # get_running_loop() is the correct call from inside a coroutine;
            # get_event_loop() here is deprecated since Python 3.10.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,  # Use default thread pool
                self._initialize_recipe_cache_sync  # Run synchronous version in thread
            )
            logger.info("Recipe cache initialization completed in background thread")
        finally:
            # Mark initialization as complete regardless of outcome
            self._is_initializing = False
    except Exception as e:
        logger.error(f"Recipe Scanner: Error initializing cache in background: {e}")
def _initialize_recipe_cache_sync(self):
"""Synchronous version of recipe cache initialization for thread pool execution"""
try:
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Create a synchronous method to bypass the async lock
def sync_initialize_cache():
# Directly call the internal scan method to avoid lock issues
raw_data = loop.run_until_complete(self.scan_all_recipes())
# Update cache
self._cache.raw_data = raw_data
loop.run_until_complete(self._cache.resort())
return self._cache
# Run our sync initialization that avoids lock conflicts
return sync_initialize_cache()
except Exception as e:
logger.error(f"Error in thread-based recipe cache initialization: {e}")
finally:
# Clean up the event loop
loop.close()
@property @property
def recipes_dir(self) -> str: def recipes_dir(self) -> str:
"""Get path to recipes directory""" """Get path to recipes directory"""
@@ -60,49 +112,48 @@ class RecipeScanner:
if self._is_initializing and not force_refresh: if self._is_initializing and not force_refresh:
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
# Try to acquire the lock with a timeout to prevent deadlocks # If force refresh is requested, initialize the cache directly
try: if force_refresh:
async with self._initialization_lock: # Try to acquire the lock with a timeout to prevent deadlocks
# Check again after acquiring the lock try:
if self._cache is not None and not force_refresh: async with self._initialization_lock:
return self._cache # Mark as initializing to prevent concurrent initializations
self._is_initializing = True
# Mark as initializing to prevent concurrent initializations
self._is_initializing = True
try:
# Remove dependency on lora scanner initialization
# Scan for recipe data directly
raw_data = await self.scan_all_recipes()
# Update cache try:
self._cache = RecipeCache( # Scan for recipe data directly
raw_data=raw_data, raw_data = await self.scan_all_recipes()
sorted_by_name=[],
sorted_by_date=[] # Update cache
) self._cache = RecipeCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[]
)
# Resort cache
await self._cache.resort()
return self._cache
# Resort cache except Exception as e:
await self._cache.resort() logger.error(f"Recipe Manager: Error initializing cache: {e}", exc_info=True)
# Create empty cache on error
return self._cache self._cache = RecipeCache(
raw_data=[],
except Exception as e: sorted_by_name=[],
logger.error(f"Recipe Manager: Error initializing cache: {e}", exc_info=True) sorted_by_date=[]
# Create empty cache on error )
self._cache = RecipeCache( return self._cache
raw_data=[], finally:
sorted_by_name=[], # Mark initialization as complete
sorted_by_date=[] self._is_initializing = False
)
return self._cache except Exception as e:
finally: logger.error(f"Unexpected error in get_cached_data: {e}")
# Mark initialization as complete
self._is_initializing = False
except Exception as e: # Return the cache (may be empty or partially initialized)
logger.error(f"Unexpected error in get_cached_data: {e}") return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
async def scan_all_recipes(self) -> List[Dict]: async def scan_all_recipes(self) -> List[Dict]:
"""Scan all recipe JSON files and return metadata""" """Scan all recipe JSON files and return metadata"""