diff --git a/py/config.py b/py/config.py index a081f472..b0ca1f4c 100644 --- a/py/config.py +++ b/py/config.py @@ -73,7 +73,7 @@ class Config: """添加静态路由映射""" normalized_path = os.path.normpath(path).replace(os.sep, '/') self._route_mappings[normalized_path] = route - logger.info(f"Added route mapping: {normalized_path} -> {route}") + # logger.info(f"Added route mapping: {normalized_path} -> {route}") def map_path_to_link(self, path: str) -> str: """将目标路径映射回符号链接路径""" diff --git a/py/lora_manager.py b/py/lora_manager.py index a7ab3fb8..3e9a3bbf 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -1,16 +1,11 @@ import asyncio -import os from server import PromptServer # type: ignore from .config import config from .routes.lora_routes import LoraRoutes from .routes.api_routes import ApiRoutes from .routes.recipe_routes import RecipeRoutes from .routes.checkpoints_routes import CheckpointsRoutes -from .services.lora_scanner import LoraScanner -from .services.recipe_scanner import RecipeScanner -from .services.file_monitor import LoraFileMonitor -from .services.lora_cache import LoraCache -from .services.recipe_cache import RecipeCache +from .services.service_registry import ServiceRegistry import logging logger = logging.getLogger(__name__) @@ -23,7 +18,7 @@ class LoraManager: """Initialize and register all routes""" app = PromptServer.instance.app - added_targets = set() # 用于跟踪已添加的目标路径 + added_targets = set() # Track already added target paths # Add static routes for each lora root for idx, root in enumerate(config.loras_roots, start=1): @@ -35,15 +30,36 @@ class LoraManager: if link == root: real_root = target break - # 为原始路径添加静态路由 + # Add static route for original path app.router.add_static(preview_path, real_root) logger.info(f"Added static route {preview_path} -> {real_root}") - # 记录路由映射 + # Record route mapping config.add_route_mapping(real_root, preview_path) added_targets.add(real_root) - # 为符号链接的目标路径添加额外的静态路由 + # Get checkpoint scanner instance + 
checkpoint_scanner = asyncio.run(ServiceRegistry.get_checkpoint_scanner()) + + # Add static routes for each checkpoint root + for idx, root in enumerate(checkpoint_scanner.get_model_roots(), start=1): + preview_path = f'/checkpoints_static/root{idx}/preview' + + real_root = root + if root in config._path_mappings.values(): + for target, link in config._path_mappings.items(): + if link == root: + real_root = target + break + # Add static route for original path + app.router.add_static(preview_path, real_root) + logger.info(f"Added static route {preview_path} -> {real_root}") + + # Record route mapping + config.add_route_mapping(real_root, preview_path) + added_targets.add(real_root) + + # Add static routes for symlink target paths link_idx = 1 for target_path, link_path in config._path_mappings.items(): @@ -59,78 +75,89 @@ class LoraManager: app.router.add_static('/loras_static', config.static_path) # Setup feature routes - routes = LoraRoutes() + lora_routes = LoraRoutes() checkpoints_routes = CheckpointsRoutes() - # Setup file monitoring - monitor = LoraFileMonitor(routes.scanner, config.loras_roots) - monitor.start() - - routes.setup_routes(app) + # Initialize routes + lora_routes.setup_routes(app) checkpoints_routes.setup_routes(app) - ApiRoutes.setup_routes(app, monitor) + ApiRoutes.setup_routes(app) RecipeRoutes.setup_routes(app) - # Store monitor in app for cleanup - app['lora_monitor'] = monitor - - # Schedule cache initialization using the application's startup handler - app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner, routes.recipe_scanner)) + # Schedule service initialization + app.on_startup.append(lambda app: cls._initialize_services()) # Add cleanup app.on_shutdown.append(cls._cleanup) app.on_shutdown.append(ApiRoutes.cleanup) @classmethod - async def _schedule_cache_init(cls, scanner: LoraScanner, recipe_scanner: RecipeScanner): - """Schedule cache initialization in the running event loop""" + async def 
_initialize_services(cls): + """Initialize all services using the ServiceRegistry""" try: - # 创建低优先级的初始化任务 - lora_task = asyncio.create_task(cls._initialize_lora_cache(scanner), name='lora_cache_init') + logger.info("LoRA Manager: Initializing services via ServiceRegistry") - # Schedule recipe cache initialization with a delay to let lora scanner initialize first - recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner, delay=2), name='recipe_cache_init') + # Initialize CivitaiClient first to ensure it's ready for other services + civitai_client = await ServiceRegistry.get_civitai_client() + + # Get file monitors through ServiceRegistry + lora_monitor = await ServiceRegistry.get_lora_monitor() + checkpoint_monitor = await ServiceRegistry.get_checkpoint_monitor() + + # Start monitors + lora_monitor.start() + logger.info("Lora monitor started") + + # Make sure checkpoint monitor has paths before starting + await checkpoint_monitor.initialize_paths() + checkpoint_monitor.start() + logger.info("Checkpoint monitor started") + + # Register DownloadManager with ServiceRegistry + download_manager = await ServiceRegistry.get_download_manager() + + # Initialize WebSocket manager + ws_manager = await ServiceRegistry.get_websocket_manager() + + # Initialize scanners in background + lora_scanner = await ServiceRegistry.get_lora_scanner() + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + + # Initialize recipe scanner if needed + recipe_scanner = await ServiceRegistry.get_recipe_scanner() + + # Create low-priority initialization tasks + asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init') + asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init') + asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init') + + logger.info("LoRA Manager: All services initialized and background tasks scheduled") + except Exception as e: - 
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}") - - @classmethod - async def _initialize_lora_cache(cls, scanner: LoraScanner): - """Initialize lora cache in background""" - try: - # 设置初始缓存占位 - scanner._cache = LoraCache( - raw_data=[], - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - - # 分阶段加载缓存 - await scanner.get_cached_data(force_refresh=True) - except Exception as e: - logger.error(f"LoRA Manager: Error initializing lora cache: {e}") - - @classmethod - async def _initialize_recipe_cache(cls, scanner: RecipeScanner, delay: float = 2.0): - """Initialize recipe cache in background with a delay""" - try: - # Wait for the specified delay to let lora scanner initialize first - await asyncio.sleep(delay) - - # Set initial empty cache - scanner._cache = RecipeCache( - raw_data=[], - sorted_by_name=[], - sorted_by_date=[] - ) - - # Force refresh to load the actual data - await scanner.get_cached_data(force_refresh=True) - except Exception as e: - logger.error(f"LoRA Manager: Error initializing recipe cache: {e}") + logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True) @classmethod async def _cleanup(cls, app): - """Cleanup resources""" - if 'lora_monitor' in app: - app['lora_monitor'].stop() + """Cleanup resources using ServiceRegistry""" + try: + logger.info("LoRA Manager: Cleaning up services") + + # Get monitors from ServiceRegistry + lora_monitor = await ServiceRegistry.get_service("lora_monitor") + if lora_monitor: + lora_monitor.stop() + logger.info("Stopped LoRA monitor") + + checkpoint_monitor = await ServiceRegistry.get_service("checkpoint_monitor") + if checkpoint_monitor: + checkpoint_monitor.stop() + logger.info("Stopped checkpoint monitor") + + # Close CivitaiClient gracefully + civitai_client = await ServiceRegistry.get_service("civitai_client") + if civitai_client: + await civitai_client.close() + logger.info("Closed CivitaiClient connection") + + except Exception as e: + logger.error(f"Error 
during cleanup: {e}", exc_info=True) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 0dff5524..c064d66b 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -2,37 +2,44 @@ import os import json import logging from aiohttp import web -from typing import Dict, List +from typing import Dict -from ..utils.model_utils import determine_base_model +from ..utils.routes_common import ModelRouteUtils -from ..services.file_monitor import LoraFileMonitor -from ..services.download_manager import DownloadManager -from ..services.civitai_client import CivitaiClient from ..config import config -from ..services.lora_scanner import LoraScanner -from operator import itemgetter from ..services.websocket_manager import ws_manager from ..services.settings_manager import settings import asyncio from .update_routes import UpdateRoutes -from ..services.recipe_scanner import RecipeScanner +from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH +from ..utils.exif_utils import ExifUtils +from ..services.service_registry import ServiceRegistry logger = logging.getLogger(__name__) class ApiRoutes: """API route handlers for LoRA management""" - def __init__(self, file_monitor: LoraFileMonitor): - self.scanner = LoraScanner() - self.civitai_client = CivitaiClient() - self.download_manager = DownloadManager(file_monitor) + def __init__(self): + self.scanner = None # Will be initialized in setup_routes + self.civitai_client = None # Will be initialized in setup_routes + self.download_manager = None # Will be initialized in setup_routes self._download_lock = asyncio.Lock() + async def initialize_services(self): + """Initialize services from ServiceRegistry""" + self.scanner = await ServiceRegistry.get_lora_scanner() + self.civitai_client = await ServiceRegistry.get_civitai_client() + self.download_manager = await ServiceRegistry.get_download_manager() + @classmethod - def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor): + def 
setup_routes(cls, app: web.Application): """Register API routes""" - routes = cls(monitor) + routes = cls() + + # Schedule service initialization on app startup + app.on_startup.append(lambda _: routes.initialize_services()) + app.router.add_post('/api/delete_model', routes.delete_model) app.router.add_post('/api/fetch-civitai', routes.fetch_civitai) app.router.add_post('/api/replace_preview', routes.replace_preview) @@ -48,86 +55,51 @@ class ApiRoutes: app.router.add_post('/api/settings', routes.update_settings) app.router.add_post('/api/move_model', routes.move_model) app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route - app.router.add_post('/loras/api/save-metadata', routes.save_metadata) + app.router.add_post('/api/loras/save-metadata', routes.save_metadata) app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route app.router.add_post('/api/move_models_bulk', routes.move_models_bulk) app.router.add_get('/api/loras/top-tags', routes.get_top_tags) # Add new route for top tags app.router.add_get('/api/loras/base-models', routes.get_base_models) # Add new route for base models app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files + app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files # Add update check routes UpdateRoutes.setup_routes(app) async def delete_model(self, request: web.Request) -> web.Response: """Handle model deletion request""" - try: - data = await request.json() - file_path = data.get('file_path') - if not file_path: - return web.Response(text='Model path is required', status=400) - - target_dir = os.path.dirname(file_path) - file_name = os.path.splitext(os.path.basename(file_path))[0] - - deleted_files = await self._delete_model_files(target_dir, file_name) - - return 
web.json_response({ - 'success': True, - 'deleted_files': deleted_files - }) - - except Exception as e: - logger.error(f"Error deleting model: {e}", exc_info=True) - return web.Response(text=str(e), status=500) + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + return await ModelRouteUtils.handle_delete_model(request, self.scanner) async def fetch_civitai(self, request: web.Request) -> web.Response: """Handle CivitAI metadata fetch request""" - try: - data = await request.json() - metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json' - - # Check if model is from CivitAI - local_metadata = await self._load_local_metadata(metadata_path) - - # Fetch and update metadata - civitai_metadata = await self.civitai_client.get_model_by_hash(local_metadata["sha256"]) - if not civitai_metadata: - return await self._handle_not_found_on_civitai(metadata_path, local_metadata) - - await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client) - - return web.json_response({"success": True}) - - except Exception as e: - logger.error(f"Error fetching from CivitAI: {e}", exc_info=True) - return web.json_response({"success": False, "error": str(e)}, status=500) + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner) async def replace_preview(self, request: web.Request) -> web.Response: """Handle preview image replacement request""" - try: - reader = await request.multipart() - preview_data, content_type = await self._read_preview_file(reader) - model_path = await self._read_model_path(reader) - - preview_path = await self._save_preview_file(model_path, preview_data, content_type) - await self._update_preview_metadata(model_path, preview_path) - - # Update preview URL in scanner cache - await self.scanner.update_preview_in_cache(model_path, preview_path) - - return web.json_response({ - 
"success": True, - "preview_url": config.get_preview_static_url(preview_path) - }) - + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + return await ModelRouteUtils.handle_replace_preview(request, self.scanner) + + async def scan_loras(self, request: web.Request) -> web.Response: + """Force a rescan of LoRA files""" + try: + await self.scanner.get_cached_data(force_refresh=True) + return web.json_response({"status": "success", "message": "LoRA scan completed"}) except Exception as e: - logger.error(f"Error replacing preview: {e}", exc_info=True) - return web.Response(text=str(e), status=500) + logger.error(f"Error in scan_loras: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) async def get_loras(self, request: web.Request) -> web.Response: """Handle paginated LoRA data request""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + # Parse query parameters page = int(request.query.get('page', '1')) page_size = int(request.query.get('page_size', '20')) @@ -137,10 +109,12 @@ class ApiRoutes: fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true' # Parse search options - search_filename = request.query.get('search_filename', 'true').lower() == 'true' - search_modelname = request.query.get('search_modelname', 'true').lower() == 'true' - search_tags = request.query.get('search_tags', 'false').lower() == 'true' - recursive = request.query.get('recursive', 'false').lower() == 'true' + search_options = { + 'filename': request.query.get('search_filename', 'true').lower() == 'true', + 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', + 'tags': request.query.get('search_tags', 'false').lower() == 'true', + 'recursive': request.query.get('recursive', 'false').lower() == 'true' + } # Get filter parameters base_models = request.query.get('base_models', None) @@ -157,14 +131,6 @@ class ApiRoutes: if tags: filters['tags'] = 
tags.split(',') - # Add search options to filters - search_options = { - 'filename': search_filename, - 'modelname': search_modelname, - 'tags': search_tags, - 'recursive': recursive - } - # Add lora hash filtering options hash_filters = {} if lora_hash: @@ -223,73 +189,10 @@ class ApiRoutes: "from_civitai": lora.get("from_civitai", True), "usage_tips": lora.get("usage_tips", ""), "notes": lora.get("notes", ""), - "civitai": self._filter_civitai_data(lora.get("civitai", {})) + "civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {})) } - def _filter_civitai_data(self, data: Dict) -> Dict: - """Filter relevant fields from CivitAI data""" - if not data: - return {} - - fields = [ - "id", "modelId", "name", "createdAt", "updatedAt", - "publishedAt", "trainedWords", "baseModel", "description", - "model", "images" - ] - return {k: data[k] for k in fields if k in data} - # Private helper methods - async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]: - """Delete model and associated files""" - patterns = [ - f"{file_name}.safetensors", # Required - f"{file_name}.metadata.json", - f"{file_name}.preview.png", - f"{file_name}.preview.jpg", - f"{file_name}.preview.jpeg", - f"{file_name}.preview.webp", - f"{file_name}.preview.mp4", - f"{file_name}.png", - f"{file_name}.jpg", - f"{file_name}.jpeg", - f"{file_name}.webp", - f"{file_name}.mp4" - ] - - deleted = [] - main_file = patterns[0] - main_path = os.path.join(target_dir, main_file).replace(os.sep, '/') - - if os.path.exists(main_path): - # Notify file monitor to ignore delete event - self.download_manager.file_monitor.handler.add_ignore_path(main_path, 0) - - # Delete file - os.remove(main_path) - deleted.append(main_path) - else: - logger.warning(f"Model file not found: {main_file}") - - # Remove from cache - cache = await self.scanner.get_cached_data() - cache.raw_data = [item for item in cache.raw_data if item['file_path'] != main_path] - await cache.resort() - - # update hash 
index - self.scanner._hash_index.remove_by_path(main_path) - - # Delete optional files - for pattern in patterns[1:]: - path = os.path.join(target_dir, pattern) - if os.path.exists(path): - try: - os.remove(path) - deleted.append(pattern) - except Exception as e: - logger.warning(f"Failed to delete {pattern}: {e}") - - return deleted - async def _read_preview_file(self, reader) -> tuple[bytes, str]: """Read preview file and content type from multipart request""" field = await reader.next() @@ -307,18 +210,29 @@ class ApiRoutes: async def _save_preview_file(self, model_path: str, preview_data: bytes, content_type: str) -> str: """Save preview file and return its path""" - # Determine file extension based on content type - if content_type.startswith('video/'): - extension = '.preview.mp4' - else: - extension = '.preview.png' - base_name = os.path.splitext(os.path.basename(model_path))[0] folder = os.path.dirname(model_path) + + # Determine if content is video or image + if content_type.startswith('video/'): + # For videos, keep original format and use .mp4 extension + extension = '.mp4' + optimized_data = preview_data + else: + # For images, optimize and convert to WebP + optimized_data, _ = ExifUtils.optimize_image( + image_data=preview_data, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=True + ) + extension = '.webp' # Use .webp without .preview part + preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/') with open(preview_path, 'wb') as f: - f.write(preview_data) + f.write(optimized_data) return preview_path @@ -338,83 +252,26 @@ class ApiRoutes: except Exception as e: logger.error(f"Error updating metadata: {e}") - async def _load_local_metadata(self, metadata_path: str) -> Dict: - """Load local metadata file""" - if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception as e: - logger.error(f"Error loading metadata 
from {metadata_path}: {e}") - return {} - - async def _handle_not_found_on_civitai(self, metadata_path: str, local_metadata: Dict) -> web.Response: - """Handle case when model is not found on CivitAI""" - local_metadata['from_civitai'] = False - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - return web.json_response( - {"success": False, "error": "Not found on CivitAI"}, - status=404 - ) - - async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict, - civitai_metadata: Dict, client: CivitaiClient) -> None: - """Update local metadata with CivitAI data""" - local_metadata['civitai'] = civitai_metadata - - # Update model name if available - if 'model' in civitai_metadata: - if civitai_metadata.get('model', {}).get('name'): - local_metadata['model_name'] = civitai_metadata['model']['name'] - - # Fetch additional model metadata (description and tags) if we have model ID - model_id = civitai_metadata['modelId'] - if model_id: - model_metadata, _ = await client.get_model_metadata(str(model_id)) - if model_metadata: - local_metadata['modelDescription'] = model_metadata.get('description', '') - local_metadata['tags'] = model_metadata.get('tags', []) - - # Update base model - local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel')) - - # Update preview if needed - if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): - first_preview = next((img for img in civitai_metadata.get('images', [])), None) - if first_preview: - preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1] - base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] - preview_filename = base_name + preview_ext - preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) - - if await client.download_preview_image(first_preview['url'], preview_path): - 
local_metadata['preview_url'] = preview_path.replace(os.sep, '/') - local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) - - # Save updated metadata - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - - await self.scanner.update_single_lora_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata) - async def fetch_all_civitai(self, request: web.Request) -> web.Response: """Fetch CivitAI metadata for all loras in the background""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + cache = await self.scanner.get_cached_data() total = len(cache.raw_data) processed = 0 success = 0 needs_resort = False - # 准备要处理的 loras + # Prepare loras to process to_process = [ lora for lora in cache.raw_data - if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai') # TODO: for lora not from CivitAI but added traineWords + if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai', True) # TODO: for lora not from CivitAI but added traineWords ] total_to_process = len(to_process) - # 发送初始进度 + # Send initial progress await ws_manager.broadcast({ 'status': 'started', 'total': total_to_process, @@ -425,10 +282,11 @@ class ApiRoutes: for lora in to_process: try: original_name = lora.get('model_name') - if await self._fetch_and_update_single_lora( + if await ModelRouteUtils.fetch_and_update_model( sha256=lora['sha256'], file_path=lora['file_path'], - lora=lora + model_data=lora, + update_cache_func=self.scanner.update_single_model_cache ): success += 1 if original_name != lora.get('model_name'): @@ -436,7 +294,7 @@ class ApiRoutes: processed += 1 - # 每处理一个就发送进度更新 + # Send progress update await ws_manager.broadcast({ 'status': 'processing', 'total': total_to_process, @@ -451,7 +309,7 @@ class ApiRoutes: if needs_resort: await 
cache.resort(name_only=True) - # 发送完成消息 + # Send completion message await ws_manager.broadcast({ 'status': 'completed', 'total': total_to_process, @@ -465,7 +323,7 @@ class ApiRoutes: }) except Exception as e: - # 发送错误消息 + # Send error message await ws_manager.broadcast({ 'status': 'error', 'error': str(e) @@ -473,58 +331,6 @@ class ApiRoutes: logger.error(f"Error in fetch_all_civitai: {e}") return web.Response(text=str(e), status=500) - async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool: - """Fetch and update metadata for a single lora without sorting - - Args: - sha256: SHA256 hash of the lora file - file_path: Path to the lora file - lora: The lora object in cache to update - - Returns: - bool: True if successful, False otherwise - """ - client = CivitaiClient() - try: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - - # Check if model is from CivitAI - local_metadata = await self._load_local_metadata(metadata_path) - - # Fetch metadata - civitai_metadata = await client.get_model_by_hash(sha256) - if not civitai_metadata: - # Mark as not from CivitAI if not found - local_metadata['from_civitai'] = False - lora['from_civitai'] = False - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - return False - - # Update metadata - await self._update_model_metadata( - metadata_path, - local_metadata, - civitai_metadata, - client - ) - - # Update cache object directly - lora.update({ - 'model_name': local_metadata.get('model_name'), - 'preview_url': local_metadata.get('preview_url'), - 'from_civitai': True, - 'civitai': civitai_metadata - }) - - return True - - except Exception as e: - logger.error(f"Error fetching CivitAI data: {e}") - return False - finally: - await client.close() - async def get_lora_roots(self, request: web.Request) -> web.Response: """Get all configured LoRA root directories""" return web.json_response({ @@ -533,6 +339,9 
@@ class ApiRoutes: async def get_folders(self, request: web.Request) -> web.Response: """Get all folders in the cache""" + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + cache = await self.scanner.get_cached_data() return web.json_response({ 'folders': cache.folders @@ -541,6 +350,12 @@ class ApiRoutes: async def get_civitai_versions(self, request: web.Request) -> web.Response: """Get available versions for a Civitai model with local availability info""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + + if self.civitai_client is None: + self.civitai_client = await ServiceRegistry.get_civitai_client() + model_id = request.match_info['model_id'] versions = await self.civitai_client.get_model_versions(model_id) if not versions: @@ -574,9 +389,12 @@ class ApiRoutes: async def get_civitai_model(self, request: web.Request) -> web.Response: """Get CivitAI model details by model version ID or hash""" try: - model_version_id = request.match_info['modelVersionId'] + if self.civitai_client is None: + self.civitai_client = await ServiceRegistry.get_civitai_client() + + model_version_id = request.match_info.get('modelVersionId') if not model_version_id: - hash = request.match_info['hash'] + hash = request.match_info.get('hash') model = await self.civitai_client.get_model_by_hash(hash) return web.json_response(model) @@ -591,6 +409,9 @@ class ApiRoutes: async def download_lora(self, request: web.Request) -> web.Response: async with self._download_lock: try: + if self.download_manager is None: + self.download_manager = await ServiceRegistry.get_download_manager() + data = await request.json() # Create progress callback @@ -662,12 +483,15 @@ class ApiRoutes: return web.json_response({'success': True}) except Exception as e: - logger.error(f"Error updating settings: {e}", exc_info=True) # 添加 exc_info=True 以获取完整堆栈 + logger.error(f"Error updating settings: {e}", exc_info=True) return 
web.Response(status=500, text=str(e)) async def move_model(self, request: web.Request) -> web.Response: """Handle model move request""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + data = await request.json() file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors target_path = data.get('target_path') # folder path to move the model to, e.g. /path/to/target_folder @@ -706,12 +530,17 @@ class ApiRoutes: @classmethod async def cleanup(cls): """Add cleanup method for application shutdown""" - if hasattr(cls, '_instance'): - await cls._instance.civitai_client.close() + # Now we don't need to store an instance, as services are managed by ServiceRegistry + civitai_client = await ServiceRegistry.get_civitai_client() + if civitai_client: + await civitai_client.close() async def save_metadata(self, request: web.Request) -> web.Response: """Handle saving metadata updates""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + data = await request.json() file_path = data.get('file_path') if not file_path: @@ -724,11 +553,7 @@ class ApiRoutes: metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' # Load existing metadata - if os.path.exists(metadata_path): - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - else: - metadata = {} + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) # Handle nested updates (for civitai.trainedWords) for key, value in metadata_updates.items(): @@ -745,7 +570,7 @@ class ApiRoutes: json.dump(metadata, f, indent=2, ensure_ascii=False) # Update cache - await self.scanner.update_single_lora_cache(file_path, file_path, metadata) + await self.scanner.update_single_model_cache(file_path, file_path, metadata) # If model_name was updated, resort the cache if 'model_name' in metadata_updates: @@ -761,6 +586,9 @@ class ApiRoutes: async def get_lora_preview_url(self, 
request: web.Request) -> web.Response: """Get the static preview URL for a LoRA file""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + # Get lora file name from query parameters lora_name = request.query.get('name') if not lora_name: @@ -791,11 +619,17 @@ class ApiRoutes: except Exception as e: logger.error(f"Error getting lora preview URL: {e}", exc_info=True) - return web.Response(text=str(e), status=500) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) async def get_lora_civitai_url(self, request: web.Request) -> web.Response: """Get the Civitai URL for a LoRA file""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + # Get lora file name from query parameters lora_name = request.query.get('name') if not lora_name: @@ -841,6 +675,9 @@ class ApiRoutes: async def move_models_bulk(self, request: web.Request) -> web.Response: """Handle bulk model move request""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + data = await request.json() file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"] target_path = data.get('target_path') # folder path to move the models to, e.g. 
"/path/to/target_folder" @@ -899,6 +736,9 @@ class ApiRoutes: async def get_lora_model_description(self, request: web.Request) -> web.Response: """Get model description for a Lora model""" try: + if self.civitai_client is None: + self.civitai_client = await ServiceRegistry.get_civitai_client() + # Get parameters model_id = request.query.get('model_id') file_path = request.query.get('file_path') @@ -914,14 +754,9 @@ class ApiRoutes: tags = [] if file_path: metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - description = metadata.get('modelDescription') - tags = metadata.get('tags', []) - except Exception as e: - logger.error(f"Error loading metadata from {metadata_path}: {e}") + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + description = metadata.get('modelDescription') + tags = metadata.get('tags', []) # If description is not in metadata, fetch from CivitAI if not description: @@ -936,16 +771,14 @@ class ApiRoutes: if file_path: try: metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - if os.path.exists(metadata_path): - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - - metadata['modelDescription'] = description - metadata['tags'] = tags - - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(metadata, f, indent=2, ensure_ascii=False) - logger.info(f"Saved model metadata to file for {file_path}") + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + metadata['modelDescription'] = description + metadata['tags'] = tags + + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + logger.info(f"Saved model metadata to file for {file_path}") except Exception as e: logger.error(f"Error saving model metadata: {e}") @@ -956,7 +789,7 @@ class ApiRoutes: }) except Exception as 
e: - logger.error(f"Error getting model metadata: {e}", exc_info=True) + logger.error(f"Error getting model metadata: {e}") return web.json_response({ 'success': False, 'error': str(e) @@ -965,6 +798,9 @@ class ApiRoutes: async def get_top_tags(self, request: web.Request) -> web.Response: """Handle request for top tags sorted by frequency""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + # Parse query parameters limit = int(request.query.get('limit', '20')) @@ -990,6 +826,9 @@ class ApiRoutes: async def get_base_models(self, request: web.Request) -> web.Response: """Get base models used in loras""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + # Parse query parameters limit = int(request.query.get('limit', '20')) @@ -1011,15 +850,15 @@ class ApiRoutes: 'error': str(e) }, status=500) - def get_multipart_ext(self, filename): - parts = filename.split(".") - if len(parts) > 2: # 如果包含多级扩展名 - return "." + ".".join(parts[-2:]) # 取最后两部分,如 ".metadata.json" - return os.path.splitext(filename)[1] # 否则取普通扩展名,如 ".safetensors" - async def rename_lora(self, request: web.Request) -> web.Response: """Handle renaming a LoRA file and its associated files""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + + if self.download_manager is None: + self.download_manager = await ServiceRegistry.get_download_manager() + data = await request.json() file_path = data.get('file_path') new_file_name = data.get('new_file_name') @@ -1054,18 +893,12 @@ class ApiRoutes: patterns = [ f"{old_file_name}.safetensors", # Required f"{old_file_name}.metadata.json", - f"{old_file_name}.preview.png", - f"{old_file_name}.preview.jpg", - f"{old_file_name}.preview.jpeg", - f"{old_file_name}.preview.webp", - f"{old_file_name}.preview.mp4", - f"{old_file_name}.png", - f"{old_file_name}.jpg", - f"{old_file_name}.jpeg", - f"{old_file_name}.webp", - f"{old_file_name}.mp4" ] + # Add all 
preview file extensions + for ext in PREVIEW_EXTENSIONS: + patterns.append(f"{old_file_name}{ext}") + # Find all matching files existing_files = [] for pattern in patterns: @@ -1079,12 +912,8 @@ class ApiRoutes: metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json") if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - hash_value = metadata.get('sha256') - except Exception as e: - logger.error(f"Error loading metadata for rename: {e}") + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + hash_value = metadata.get('sha256') # Rename all files renamed_files = [] @@ -1092,15 +921,18 @@ class ApiRoutes: # Notify file monitor to ignore these events main_file_path = os.path.join(target_dir, f"{old_file_name}.safetensors") - if os.path.exists(main_file_path) and self.download_manager.file_monitor: - # Add old and new paths to ignore list - file_size = os.path.getsize(main_file_path) - self.download_manager.file_monitor.handler.add_ignore_path(main_file_path, file_size) - self.download_manager.file_monitor.handler.add_ignore_path(new_file_path, file_size) + if os.path.exists(main_file_path): + # Get lora monitor through ServiceRegistry instead of download_manager + lora_monitor = await ServiceRegistry.get_lora_monitor() + if lora_monitor: + # Add old and new paths to ignore list + file_size = os.path.getsize(main_file_path) + lora_monitor.handler.add_ignore_path(main_file_path, file_size) + lora_monitor.handler.add_ignore_path(new_file_path, file_size) for old_path, pattern in existing_files: # Get the file extension like .safetensors or .metadata.json - ext = self.get_multipart_ext(pattern) + ext = ModelRouteUtils.get_multipart_ext(pattern) # Create the new path new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') @@ -1122,7 +954,7 @@ class ApiRoutes: # Update preview_url if it exists if 'preview_url' in metadata and 
metadata['preview_url']: old_preview = metadata['preview_url'] - ext = self.get_multipart_ext(old_preview) + ext = ModelRouteUtils.get_multipart_ext(old_preview) new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') metadata['preview_url'] = new_preview @@ -1132,11 +964,11 @@ class ApiRoutes: # Update the scanner cache if metadata: - await self.scanner.update_single_lora_cache(file_path, new_file_path, metadata) + await self.scanner.update_single_model_cache(file_path, new_file_path, metadata) # Update recipe files and cache if hash is available if hash_value: - recipe_scanner = RecipeScanner(self.scanner) + recipe_scanner = await ServiceRegistry.get_recipe_scanner() recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name) logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed LoRA") diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index 0a79d6f9..8a947bd6 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -1,37 +1,478 @@ import os -from aiohttp import web +import json import jinja2 +from aiohttp import web import logging +import asyncio + +from ..utils.routes_common import ModelRouteUtils +from ..utils.constants import NSFW_LEVELS +from ..services.websocket_manager import ws_manager +from ..services.service_registry import ServiceRegistry from ..config import config from ..services.settings_manager import settings +from ..utils.utils import fuzzy_match logger = logging.getLogger(__name__) -logging.getLogger('asyncio').setLevel(logging.CRITICAL) class CheckpointsRoutes: - """Route handlers for Checkpoints management endpoints""" + """API routes for checkpoint management""" def __init__(self): + self.scanner = None # Will be initialized in setup_routes self.template_env = jinja2.Environment( loader=jinja2.FileSystemLoader(config.templates_path), autoescape=True ) + 
self.download_manager = None # Will be initialized in setup_routes + self._download_lock = asyncio.Lock() + + async def initialize_services(self): + """Initialize services from ServiceRegistry""" + self.scanner = await ServiceRegistry.get_checkpoint_scanner() + self.download_manager = await ServiceRegistry.get_download_manager() + + def setup_routes(self, app): + """Register routes with the aiohttp app""" + # Schedule service initialization on app startup + app.on_startup.append(lambda _: self.initialize_services()) + + app.router.add_get('/checkpoints', self.handle_checkpoints_page) + app.router.add_get('/api/checkpoints', self.get_checkpoints) + app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai) + app.router.add_get('/api/checkpoints/base-models', self.get_base_models) + app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags) + app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) + app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info) + app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots) + + # Add new routes for model management similar to LoRA routes + app.router.add_post('/api/checkpoints/delete', self.delete_model) + app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai) + app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview) + app.router.add_post('/api/checkpoints/download', self.download_checkpoint) + app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route + + async def get_checkpoints(self, request): + """Get paginated checkpoint data""" + try: + # Parse query parameters + page = int(request.query.get('page', '1')) + page_size = min(int(request.query.get('page_size', '20')), 100) + sort_by = request.query.get('sort_by', 'name') + folder = request.query.get('folder', None) + search = request.query.get('search', None) + fuzzy_search = request.query.get('fuzzy_search', 'false').lower() 
== 'true' + base_models = request.query.getall('base_model', []) + tags = request.query.getall('tag', []) + + # Process search options + search_options = { + 'filename': request.query.get('search_filename', 'true').lower() == 'true', + 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', + 'tags': request.query.get('search_tags', 'false').lower() == 'true', + 'recursive': request.query.get('recursive', 'false').lower() == 'true', + } + + # Process hash filters if provided + hash_filters = {} + if 'hash' in request.query: + hash_filters['single_hash'] = request.query['hash'] + elif 'hashes' in request.query: + try: + hash_list = json.loads(request.query['hashes']) + if isinstance(hash_list, list): + hash_filters['multiple_hashes'] = hash_list + except (json.JSONDecodeError, TypeError): + pass + + # Get data from scanner + result = await self.get_paginated_data( + page=page, + page_size=page_size, + sort_by=sort_by, + folder=folder, + search=search, + fuzzy_search=fuzzy_search, + base_models=base_models, + tags=tags, + search_options=search_options, + hash_filters=hash_filters + ) + + # Format response items + formatted_result = { + 'items': [self._format_checkpoint_response(cp) for cp in result['items']], + 'total': result['total'], + 'page': result['page'], + 'page_size': result['page_size'], + 'total_pages': result['total_pages'] + } + + # Return as JSON + return web.json_response(formatted_result) + + except Exception as e: + logger.error(f"Error in get_checkpoints: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + async def get_paginated_data(self, page, page_size, sort_by='name', + folder=None, search=None, fuzzy_search=False, + base_models=None, tags=None, + search_options=None, hash_filters=None): + """Get paginated and filtered checkpoint data""" + cache = await self.scanner.get_cached_data() + + # Get default search options if not provided + if search_options is None: + search_options = { + 
'filename': True, + 'modelname': True, + 'tags': False, + 'recursive': False, + } + + # Get the base data set + filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name + + # Apply hash filtering if provided (highest priority) + if hash_filters: + single_hash = hash_filters.get('single_hash') + multiple_hashes = hash_filters.get('multiple_hashes') + + if single_hash: + # Filter by single hash + single_hash = single_hash.lower() # Ensure lowercase for matching + filtered_data = [ + cp for cp in filtered_data + if cp.get('sha256', '').lower() == single_hash + ] + elif multiple_hashes: + # Filter by multiple hashes + hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup + filtered_data = [ + cp for cp in filtered_data + if cp.get('sha256', '').lower() in hash_set + ] + + # Jump to pagination + total_items = len(filtered_data) + start_idx = (page - 1) * page_size + end_idx = min(start_idx + page_size, total_items) + + result = { + 'items': filtered_data[start_idx:end_idx], + 'total': total_items, + 'page': page, + 'page_size': page_size, + 'total_pages': (total_items + page_size - 1) // page_size + } + + return result + + # Apply SFW filtering if enabled in settings + if settings.get('show_only_sfw', False): + filtered_data = [ + cp for cp in filtered_data + if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R'] + ] + + # Apply folder filtering + if folder is not None: + if search_options.get('recursive', False): + # Recursive folder filtering - include all subfolders + filtered_data = [ + cp for cp in filtered_data + if cp['folder'].startswith(folder) + ] + else: + # Exact folder filtering + filtered_data = [ + cp for cp in filtered_data + if cp['folder'] == folder + ] + + # Apply base model filtering + if base_models and len(base_models) > 0: + filtered_data = [ + cp for cp in filtered_data + if cp.get('base_model') in base_models + ] + + # Apply tag filtering + if tags 
and len(tags) > 0: + filtered_data = [ + cp for cp in filtered_data + if any(tag in cp.get('tags', []) for tag in tags) + ] + + # Apply search filtering + if search: + search_results = [] + + for cp in filtered_data: + # Search by file name + if search_options.get('filename', True): + if fuzzy_search: + if fuzzy_match(cp.get('file_name', ''), search): + search_results.append(cp) + continue + elif search.lower() in cp.get('file_name', '').lower(): + search_results.append(cp) + continue + + # Search by model name + if search_options.get('modelname', True): + if fuzzy_search: + if fuzzy_match(cp.get('model_name', ''), search): + search_results.append(cp) + continue + elif search.lower() in cp.get('model_name', '').lower(): + search_results.append(cp) + continue + + # Search by tags + if search_options.get('tags', False) and 'tags' in cp: + if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']): + search_results.append(cp) + continue + + filtered_data = search_results + + # Calculate pagination + total_items = len(filtered_data) + start_idx = (page - 1) * page_size + end_idx = min(start_idx + page_size, total_items) + + result = { + 'items': filtered_data[start_idx:end_idx], + 'total': total_items, + 'page': page, + 'page_size': page_size, + 'total_pages': (total_items + page_size - 1) // page_size + } + + return result + + def _format_checkpoint_response(self, checkpoint): + """Format checkpoint data for API response""" + return { + "model_name": checkpoint["model_name"], + "file_name": checkpoint["file_name"], + "preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")), + "preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0), + "base_model": checkpoint.get("base_model", ""), + "folder": checkpoint["folder"], + "sha256": checkpoint.get("sha256", ""), + "file_path": checkpoint["file_path"].replace(os.sep, "/"), + "file_size": checkpoint.get("size", 0), + "modified": 
checkpoint.get("modified", ""), + "tags": checkpoint.get("tags", []), + "modelDescription": checkpoint.get("modelDescription", ""), + "from_civitai": checkpoint.get("from_civitai", True), + "notes": checkpoint.get("notes", ""), + "model_type": checkpoint.get("model_type", "checkpoint"), + "civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {})) + } + + async def fetch_all_civitai(self, request: web.Request) -> web.Response: + """Fetch CivitAI metadata for all checkpoints in the background""" + try: + cache = await self.scanner.get_cached_data() + total = len(cache.raw_data) + processed = 0 + success = 0 + needs_resort = False + + # Prepare checkpoints to process + to_process = [ + cp for cp in cache.raw_data + if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True) + ] + total_to_process = len(to_process) + + # Send initial progress + await ws_manager.broadcast({ + 'status': 'started', + 'total': total_to_process, + 'processed': 0, + 'success': 0 + }) + + # Process each checkpoint + for cp in to_process: + try: + original_name = cp.get('model_name') + if await ModelRouteUtils.fetch_and_update_model( + sha256=cp['sha256'], + file_path=cp['file_path'], + model_data=cp, + update_cache_func=self.scanner.update_single_model_cache + ): + success += 1 + if original_name != cp.get('model_name'): + needs_resort = True + + processed += 1 + + # Send progress update + await ws_manager.broadcast({ + 'status': 'processing', + 'total': total_to_process, + 'processed': processed, + 'success': success, + 'current_name': cp.get('model_name', 'Unknown') + }) + + except Exception as e: + logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}") + + if needs_resort: + await cache.resort(name_only=True) + + # Send completion message + await ws_manager.broadcast({ + 'status': 'completed', + 'total': total_to_process, + 'processed': processed, + 'success': success + }) + + return 
web.json_response({ + "success": True, + "message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})" + }) + + except Exception as e: + # Send error message + await ws_manager.broadcast({ + 'status': 'error', + 'error': str(e) + }) + logger.error(f"Error in fetch_all_civitai for checkpoints: {e}") + return web.Response(text=str(e), status=500) + + async def get_top_tags(self, request: web.Request) -> web.Response: + """Handle request for top tags sorted by frequency""" + try: + # Parse query parameters + limit = int(request.query.get('limit', '20')) + + # Validate limit + if limit < 1 or limit > 100: + limit = 20 # Default to a reasonable limit + + # Get top tags + top_tags = await self.scanner.get_top_tags(limit) + + return web.json_response({ + 'success': True, + 'tags': top_tags + }) + + except Exception as e: + logger.error(f"Error getting top tags: {str(e)}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': 'Internal server error' + }, status=500) + + async def get_base_models(self, request: web.Request) -> web.Response: + """Get base models used in loras""" + try: + # Parse query parameters + limit = int(request.query.get('limit', '20')) + + # Validate limit + if limit < 1 or limit > 100: + limit = 20 # Default to a reasonable limit + + # Get base models + base_models = await self.scanner.get_base_models(limit) + + return web.json_response({ + 'success': True, + 'base_models': base_models + }) + except Exception as e: + logger.error(f"Error retrieving base models: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def scan_checkpoints(self, request): + """Force a rescan of checkpoint files""" + try: + await self.scanner.get_cached_data(force_refresh=True) + return web.json_response({"status": "success", "message": "Checkpoint scan completed"}) + except Exception as e: + logger.error(f"Error in scan_checkpoints: {e}", exc_info=True) + return 
web.json_response({"error": str(e)}, status=500) + + async def get_checkpoint_info(self, request): + """Get detailed information for a specific checkpoint by name""" + try: + name = request.match_info.get('name', '') + checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name) + + if checkpoint_info: + return web.json_response(checkpoint_info) + else: + return web.json_response({"error": "Checkpoint not found"}, status=404) + + except Exception as e: + logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) async def handle_checkpoints_page(self, request: web.Request) -> web.Response: """Handle GET /checkpoints request""" try: - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - is_initializing=False, - settings=settings, - request=request + # 检查缓存初始化状态,根据initialize_in_background的工作方式调整判断逻辑 + is_initializing = ( + self.scanner._cache is None or + len(self.scanner._cache.raw_data) == 0 or + hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing ) + + if is_initializing: + # 如果正在初始化,返回一个只包含加载提示的页面 + template = self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=[], # 空文件夹列表 + is_initializing=True, # 新增标志 + settings=settings, # Pass settings to template + request=request # Pass the request object to the template + ) + + logger.info("Checkpoints page is initializing, returning loading page") + else: + # 正常流程 - 获取已经初始化好的缓存数据 + try: + cache = await self.scanner.get_cached_data(force_refresh=False) + template = self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=cache.folders, + is_initializing=False, + settings=settings, # Pass settings to template + request=request # Pass the request object to the template + ) + except Exception as cache_error: + logger.error(f"Error loading checkpoints cache data: {cache_error}") + # 如果获取缓存失败,也显示初始化页面 + template = 
self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=[], + is_initializing=True, + settings=settings, + request=request + ) + logger.info("Checkpoints cache error, returning initialization page") return web.Response( text=rendered, content_type='text/html' ) - except Exception as e: logger.error(f"Error handling checkpoints request: {e}", exc_info=True) return web.Response( @@ -39,6 +480,87 @@ class CheckpointsRoutes: status=500 ) - def setup_routes(self, app: web.Application): - """Register routes with the application""" - app.router.add_get('/checkpoints', self.handle_checkpoints_page) + async def delete_model(self, request: web.Request) -> web.Response: + """Handle checkpoint model deletion request""" + return await ModelRouteUtils.handle_delete_model(request, self.scanner) + + async def fetch_civitai(self, request: web.Request) -> web.Response: + """Handle CivitAI metadata fetch request for checkpoints""" + return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner) + + async def replace_preview(self, request: web.Request) -> web.Response: + """Handle preview image replacement for checkpoints""" + return await ModelRouteUtils.handle_replace_preview(request, self.scanner) + + async def download_checkpoint(self, request: web.Request) -> web.Response: + """Handle checkpoint download request""" + async with self._download_lock: + # Get the download manager from service registry if not already initialized + if self.download_manager is None: + self.download_manager = await ServiceRegistry.get_download_manager() + + # Use the common download handler with model_type="checkpoint" + return await ModelRouteUtils.handle_download_model( + request=request, + download_manager=self.download_manager, + model_type="checkpoint" + ) + + async def get_checkpoint_roots(self, request): + """Return the checkpoint root directories""" + try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_checkpoint_scanner() + + 
roots = self.scanner.get_model_roots() + return web.json_response({ + "success": True, + "roots": roots + }) + except Exception as e: + logger.error(f"Error getting checkpoint roots: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + + async def save_metadata(self, request: web.Request) -> web.Response: + """Handle saving metadata updates for checkpoints""" + try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_checkpoint_scanner() + + data = await request.json() + file_path = data.get('file_path') + if not file_path: + return web.Response(text='File path is required', status=400) + + # Remove file path from data to avoid saving it + metadata_updates = {k: v for k, v in data.items() if k != 'file_path'} + + # Get metadata file path + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + + # Load existing metadata + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + # Update metadata + metadata.update(metadata_updates) + + # Save updated metadata + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + + # Update cache + await self.scanner.update_single_model_cache(file_path, file_path, metadata) + + # If model_name was updated, resort the cache + if 'model_name' in metadata_updates: + cache = await self.scanner.get_cached_data() + await cache.resort(name_only=True) + + return web.json_response({'success': True}) + + except Exception as e: + logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True) + return web.Response(text=str(e), status=500) diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 196bfa25..f22071b1 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -1,12 +1,11 @@ import os from aiohttp import web import jinja2 -from typing import Dict, List +from typing import Dict import logging -from ..services.lora_scanner import LoraScanner -from 
..services.recipe_scanner import RecipeScanner from ..config import config -from ..services.settings_manager import settings # Add this import +from ..services.settings_manager import settings +from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import logger = logging.getLogger(__name__) logging.getLogger('asyncio').setLevel(logging.CRITICAL) @@ -15,13 +14,19 @@ class LoraRoutes: """Route handlers for LoRA management endpoints""" def __init__(self): - self.scanner = LoraScanner() - self.recipe_scanner = RecipeScanner(self.scanner) + # Initialize service references as None, will be set during async init + self.scanner = None + self.recipe_scanner = None self.template_env = jinja2.Environment( loader=jinja2.FileSystemLoader(config.templates_path), autoescape=True ) + async def init_services(self): + """Initialize services from ServiceRegistry""" + self.scanner = await ServiceRegistry.get_lora_scanner() + self.recipe_scanner = await ServiceRegistry.get_recipe_scanner() + def format_lora_data(self, lora: Dict) -> Dict: """Format LoRA data for template rendering""" return { @@ -58,41 +63,40 @@ class LoraRoutes: async def handle_loras_page(self, request: web.Request) -> web.Response: """Handle GET /loras request""" try: - # 检查缓存初始化状态,增强判断条件 + # Ensure services are initialized + await self.init_services() + + # Check if the LoraScanner is initializing is_initializing = ( self.scanner._cache is None or - (self.scanner._initialization_task is not None and - not self.scanner._initialization_task.done()) or - (self.scanner._cache is not None and len(self.scanner._cache.raw_data) == 0 and - self.scanner._initialization_task is not None) + len(self.scanner._cache.raw_data) == 0 or + hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing ) if is_initializing: - # 如果正在初始化,返回一个只包含加载提示的页面 + # If still initializing, return loading page template = self.template_env.get_template('loras.html') rendered = template.render( - folders=[], # 
空文件夹列表 - is_initializing=True, # 新增标志 - settings=settings, # Pass settings to template - request=request # Pass the request object to the template + folders=[], + is_initializing=True, + settings=settings, + request=request ) logger.info("Loras page is initializing, returning loading page") else: - # 正常流程 - 但不要等待缓存刷新 + # Normal flow - get data from initialized cache try: cache = await self.scanner.get_cached_data(force_refresh=False) template = self.template_env.get_template('loras.html') rendered = template.render( folders=cache.folders, is_initializing=False, - settings=settings, # Pass settings to template - request=request # Pass the request object to the template + settings=settings, + request=request ) - logger.debug(f"Loras page loaded successfully with {len(cache.raw_data)} items") except Exception as cache_error: logger.error(f"Error loading cache data: {cache_error}") - # 如果获取缓存失败,也显示初始化页面 template = self.template_env.get_template('loras.html') rendered = template.render( folders=[], @@ -117,32 +121,47 @@ class LoraRoutes: async def handle_recipes_page(self, request: web.Request) -> web.Response: """Handle GET /loras/recipes request""" try: - # Check cache initialization status + # Ensure services are initialized + await self.init_services() + + # Check if the RecipeScanner is initializing is_initializing = ( - self.recipe_scanner._cache is None and - (self.recipe_scanner._initialization_task is not None and - not self.recipe_scanner._initialization_task.done()) + self.recipe_scanner._cache is None or + len(self.recipe_scanner._cache.raw_data) == 0 or + hasattr(self.recipe_scanner, '_is_initializing') and self.recipe_scanner._is_initializing ) if is_initializing: - # If initializing, return a loading page + # 如果正在初始化,返回一个只包含加载提示的页面 template = self.template_env.get_template('recipes.html') rendered = template.render( is_initializing=True, settings=settings, request=request # Pass the request object to the template ) - else: - # return empty recipes - 
recipes_data = [] - template = self.template_env.get_template('recipes.html') - rendered = template.render( - recipes=recipes_data, - is_initializing=False, - settings=settings, - request=request # Pass the request object to the template - ) + logger.info("Recipes page is initializing, returning loading page") + else: + # 正常流程 - 获取已经初始化好的缓存数据 + try: + cache = await self.recipe_scanner.get_cached_data(force_refresh=False) + template = self.template_env.get_template('recipes.html') + rendered = template.render( + recipes=[], # Frontend will load recipes via API + is_initializing=False, + settings=settings, + request=request # Pass the request object to the template + ) + except Exception as cache_error: + logger.error(f"Error loading recipe cache data: {cache_error}") + # 如果获取缓存失败,也显示初始化页面 + template = self.template_env.get_template('recipes.html') + rendered = template.render( + is_initializing=True, + settings=settings, + request=request + ) + logger.info("Recipe cache error, returning initialization page") return web.Response( text=rendered, @@ -174,5 +193,13 @@ class LoraRoutes: def setup_routes(self, app: web.Application): """Register routes with the application""" + # Add an app startup handler to initialize services + app.on_startup.append(self._on_startup) + + # Register routes app.router.add_get('/loras', self.handle_loras_page) app.router.add_get('/loras/recipes', self.handle_recipes_page) + + async def _on_startup(self, app): + """Initialize services when the app starts""" + await self.init_services() diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 32de5722..48328537 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -8,13 +8,12 @@ import json import asyncio from ..utils.exif_utils import ExifUtils from ..utils.recipe_parsers import RecipeParserFactory -from ..services.civitai_client import CivitaiClient +from ..utils.constants import CARD_PREVIEW_WIDTH -from ..services.recipe_scanner import RecipeScanner 
-from ..services.lora_scanner import LoraScanner from ..config import config from ..workflow.parser import WorkflowParser from ..utils.utils import download_civitai_image +from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import logger = logging.getLogger(__name__) @@ -22,13 +21,19 @@ class RecipeRoutes: """API route handlers for Recipe management""" def __init__(self): - self.recipe_scanner = RecipeScanner(LoraScanner()) - self.civitai_client = CivitaiClient() + # Initialize service references as None, will be set during async init + self.recipe_scanner = None + self.civitai_client = None self.parser = WorkflowParser() # Pre-warm the cache self._init_cache_task = None + async def init_services(self): + """Initialize services from ServiceRegistry""" + self.recipe_scanner = await ServiceRegistry.get_recipe_scanner() + self.civitai_client = await ServiceRegistry.get_civitai_client() + @classmethod def setup_routes(cls, app: web.Application): """Register API routes""" @@ -67,7 +72,10 @@ class RecipeRoutes: async def _init_cache(self, app): """Initialize cache on startup""" try: - # First, ensure the lora scanner is fully initialized + # Initialize services first + await self.init_services() + + # Now that services are initialized, get the lora scanner lora_scanner = self.recipe_scanner._lora_scanner # Get lora cache to ensure it's initialized @@ -85,6 +93,9 @@ class RecipeRoutes: async def get_recipes(self, request: web.Request) -> web.Response: """API endpoint for getting paginated recipes""" try: + # Ensure services are initialized + await self.init_services() + # Get query parameters with defaults page = int(request.query.get('page', '1')) page_size = int(request.query.get('page_size', '20')) @@ -154,6 +165,9 @@ class RecipeRoutes: async def get_recipe_detail(self, request: web.Request) -> web.Response: """Get detailed information about a specific recipe""" try: + # Ensure services are initialized + await self.init_services() + recipe_id 
= request.match_info['recipe_id'] # Use the new get_recipe_by_id method from recipe_scanner @@ -207,6 +221,9 @@ class RecipeRoutes: """Analyze an uploaded image or URL for recipe metadata""" temp_path = None try: + # Ensure services are initialized + await self.init_services() + # Check if request contains multipart data (image) or JSON data (url) content_type = request.headers.get('Content-Type', '') @@ -325,6 +342,9 @@ class RecipeRoutes: async def save_recipe(self, request: web.Request) -> web.Response: """Save a recipe to the recipes folder""" try: + # Ensure services are initialized + await self.init_services() + reader = await request.multipart() # Process form data @@ -424,7 +444,7 @@ class RecipeRoutes: # Optimize the image (resize and convert to WebP) optimized_image, extension = ExifUtils.optimize_image( image_data=image, - target_width=480, + target_width=CARD_PREVIEW_WIDTH, format='webp', quality=85, preserve_metadata=True @@ -526,6 +546,9 @@ class RecipeRoutes: async def delete_recipe(self, request: web.Request) -> web.Response: """Delete a recipe by ID""" try: + # Ensure services are initialized + await self.init_services() + recipe_id = request.match_info['recipe_id'] # Get recipes directory @@ -573,6 +596,9 @@ class RecipeRoutes: async def get_top_tags(self, request: web.Request) -> web.Response: """Get top tags used in recipes""" try: + # Ensure services are initialized + await self.init_services() + # Get limit parameter with default limit = int(request.query.get('limit', '20')) @@ -605,6 +631,9 @@ class RecipeRoutes: async def get_base_models(self, request: web.Request) -> web.Response: """Get base models used in recipes""" try: + # Ensure services are initialized + await self.init_services() + # Get all recipes from cache cache = await self.recipe_scanner.get_cached_data() @@ -633,6 +662,9 @@ class RecipeRoutes: async def share_recipe(self, request: web.Request) -> web.Response: """Process a recipe image for sharing by adding metadata to EXIF""" 
try: + # Ensure services are initialized + await self.init_services() + recipe_id = request.match_info['recipe_id'] # Get all recipes from cache @@ -692,6 +724,9 @@ class RecipeRoutes: async def download_shared_recipe(self, request: web.Request) -> web.Response: """Serve a processed recipe image for download""" try: + # Ensure services are initialized + await self.init_services() + recipe_id = request.match_info['recipe_id'] # Check if we have this shared recipe @@ -748,6 +783,9 @@ class RecipeRoutes: async def save_recipe_from_widget(self, request: web.Request) -> web.Response: """Save a recipe from the LoRAs widget""" try: + # Ensure services are initialized + await self.init_services() + reader = await request.multipart() # Process form data @@ -828,7 +866,7 @@ class RecipeRoutes: # Optimize the image (resize and convert to WebP) optimized_image, extension = ExifUtils.optimize_image( image_data=image, - target_width=480, + target_width=CARD_PREVIEW_WIDTH, format='webp', quality=85, preserve_metadata=True @@ -922,6 +960,9 @@ class RecipeRoutes: async def get_recipe_syntax(self, request: web.Request) -> web.Response: """Generate recipe syntax for LoRAs in the recipe, looking up proper file names using hash_index""" try: + # Ensure services are initialized + await self.init_services() + recipe_id = request.match_info['recipe_id'] # Get all recipes from cache @@ -1002,6 +1043,9 @@ class RecipeRoutes: async def update_recipe(self, request: web.Request) -> web.Response: """Update recipe metadata (name and tags)""" try: + # Ensure services are initialized + await self.init_services() + recipe_id = request.match_info['recipe_id'] data = await request.json() @@ -1029,6 +1073,9 @@ class RecipeRoutes: async def reconnect_lora(self, request: web.Request) -> web.Response: """Reconnect a deleted LoRA in a recipe to a local LoRA file""" try: + # Ensure services are initialized + await self.init_services() + # Parse request data data = await request.json() @@ -1139,6 +1186,9 @@ 
class RecipeRoutes: async def get_recipes_for_lora(self, request: web.Request) -> web.Response: """Get recipes that use a specific Lora""" try: + # Ensure services are initialized + await self.init_services() + lora_hash = request.query.get('hash') # Hash is required @@ -1146,7 +1196,7 @@ class RecipeRoutes: return web.json_response({'success': False, 'error': 'Lora hash is required'}, status=400) # Log the search parameters - logger.info(f"Getting recipes for Lora by hash: {lora_hash}") + logger.debug(f"Getting recipes for Lora by hash: {lora_hash}") # Get all recipes from cache cache = await self.recipe_scanner.get_cached_data() diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py new file mode 100644 index 00000000..acaaa461 --- /dev/null +++ b/py/services/checkpoint_scanner.py @@ -0,0 +1,132 @@ +import os +import logging +import asyncio +from typing import List, Dict, Optional, Set +import folder_paths # type: ignore + +from ..utils.models import CheckpointMetadata +from ..config import config +from .model_scanner import ModelScanner +from .model_hash_index import ModelHashIndex +from .service_registry import ServiceRegistry + +logger = logging.getLogger(__name__) + +class CheckpointScanner(ModelScanner): + """Service for scanning and managing checkpoint files""" + + _instance = None + _lock = asyncio.Lock() + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, '_initialized'): + # Define supported file extensions + file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'} + super().__init__( + model_type="checkpoint", + model_class=CheckpointMetadata, + file_extensions=file_extensions, + hash_index=ModelHashIndex() + ) + self._checkpoint_roots = self._init_checkpoint_roots() + self._initialized = True + + @classmethod + async def get_instance(cls): + """Get singleton instance with async support""" + async with 
cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def _init_checkpoint_roots(self) -> List[str]: + """Initialize checkpoint roots from ComfyUI settings""" + # Get both checkpoint and diffusion_models paths + checkpoint_paths = folder_paths.get_folder_paths("checkpoints") + diffusion_paths = folder_paths.get_folder_paths("diffusion_models") + + # Combine, normalize and deduplicate paths + all_paths = set() + for path in checkpoint_paths + diffusion_paths: + if os.path.exists(path): + norm_path = path.replace(os.sep, "/") + all_paths.add(norm_path) + + # Sort for consistent order + sorted_paths = sorted(all_paths, key=lambda p: p.lower()) + logger.info(f"Found checkpoint roots: {sorted_paths}") + + return sorted_paths + + def get_model_roots(self) -> List[str]: + """Get checkpoint root directories""" + return self._checkpoint_roots + + async def scan_all_models(self) -> List[Dict]: + """Scan all checkpoint directories and return metadata""" + all_checkpoints = [] + + # Create scan tasks for each directory + scan_tasks = [] + for root in self._checkpoint_roots: + task = asyncio.create_task(self._scan_directory(root)) + scan_tasks.append(task) + + # Wait for all tasks to complete + for task in scan_tasks: + try: + checkpoints = await task + all_checkpoints.extend(checkpoints) + except Exception as e: + logger.error(f"Error scanning checkpoint directory: {e}") + + return all_checkpoints + + async def _scan_directory(self, root_path: str) -> List[Dict]: + """Scan a directory for checkpoint files""" + checkpoints = [] + original_root = root_path + + async def scan_recursive(path: str, visited_paths: set): + try: + real_path = os.path.realpath(path) + if real_path in visited_paths: + logger.debug(f"Skipping already visited path: {path}") + return + visited_paths.add(real_path) + + with os.scandir(path) as it: + entries = list(it) + for entry in entries: + try: + if entry.is_file(follow_symlinks=True): + # Check if file has supported 
extension + ext = os.path.splitext(entry.name)[1].lower() + if ext in self.file_extensions: + file_path = entry.path.replace(os.sep, "/") + await self._process_single_file(file_path, original_root, checkpoints) + await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + # For directories, continue scanning with original path + await scan_recursive(entry.path, visited_paths) + except Exception as e: + logger.error(f"Error processing entry {entry.path}: {e}") + except Exception as e: + logger.error(f"Error scanning {path}: {e}") + + await scan_recursive(root_path, set()) + return checkpoints + + async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list): + """Process a single checkpoint file and add to results""" + try: + result = await self._process_model_file(file_path, root_path) + if result: + checkpoints.append(result) + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") \ No newline at end of file diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index fbd77739..24227f8d 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -3,6 +3,7 @@ import aiohttp import os import json import logging +import asyncio from email.parser import Parser from typing import Optional, Dict, Tuple, List from urllib.parse import unquote @@ -11,7 +12,23 @@ from ..utils.models import LoraMetadata logger = logging.getLogger(__name__) class CivitaiClient: + _instance = None + _lock = asyncio.Lock() + + @classmethod + async def get_instance(cls): + """Get singleton instance of CivitaiClient""" + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + def __init__(self): + # Check if already initialized for singleton pattern + if hasattr(self, '_initialized'): + return + self._initialized = True + self.base_url = "https://civitai.com/api/v1" self.headers = { 'User-Agent': 'ComfyUI-LoRA-Manager/1.0' diff --git 
a/py/services/download_manager.py b/py/services/download_manager.py index df0d3af2..7231fdfb 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -1,21 +1,79 @@ import logging import os import json -from typing import Optional, Dict +import asyncio +from typing import Optional, Dict, Any from .civitai_client import CivitaiClient -from .file_monitor import LoraFileMonitor -from ..utils.models import LoraMetadata +from ..utils.models import LoraMetadata, CheckpointMetadata +from ..utils.constants import CARD_PREVIEW_WIDTH +from ..utils.exif_utils import ExifUtils +from .service_registry import ServiceRegistry + +# Download to temporary file first +import tempfile logger = logging.getLogger(__name__) class DownloadManager: - def __init__(self, file_monitor: Optional[LoraFileMonitor] = None): - self.civitai_client = CivitaiClient() - self.file_monitor = file_monitor + _instance = None + _lock = asyncio.Lock() + + @classmethod + async def get_instance(cls): + """Get singleton instance of DownloadManager""" + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + # Check if already initialized for singleton pattern + if hasattr(self, '_initialized'): + return + self._initialized = True + + self._civitai_client = None # Will be lazily initialized + + async def _get_civitai_client(self): + """Lazily initialize CivitaiClient from registry""" + if self._civitai_client is None: + self._civitai_client = await ServiceRegistry.get_civitai_client() + return self._civitai_client + + async def _get_lora_monitor(self): + """Get the lora file monitor from registry""" + return await ServiceRegistry.get_lora_monitor() + + async def _get_checkpoint_monitor(self): + """Get the checkpoint file monitor from registry""" + return await ServiceRegistry.get_checkpoint_monitor() + + async def _get_lora_scanner(self): + """Get the lora scanner from registry""" + return await 
ServiceRegistry.get_lora_scanner() + + async def _get_checkpoint_scanner(self): + """Get the checkpoint scanner from registry""" + return await ServiceRegistry.get_checkpoint_scanner() async def download_from_civitai(self, download_url: str = None, model_hash: str = None, model_version_id: str = None, save_dir: str = None, - relative_path: str = '', progress_callback=None) -> Dict: + relative_path: str = '', progress_callback=None, + model_type: str = "lora") -> Dict: + """Download model from Civitai + + Args: + download_url: Direct download URL for the model + model_hash: SHA256 hash of the model + model_version_id: Civitai model version ID + save_dir: Directory to save the model to + relative_path: Relative path within save_dir + progress_callback: Callback function for progress updates + model_type: Type of model ('lora' or 'checkpoint') + + Returns: + Dict with download result + """ try: # Update save directory with relative path if provided if relative_path: @@ -23,25 +81,28 @@ class DownloadManager: # Create directory if it doesn't exist os.makedirs(save_dir, exist_ok=True) + # Get civitai client + civitai_client = await self._get_civitai_client() + # Get version info based on the provided identifier version_info = None if download_url: # Extract version ID from download URL version_id = download_url.split('/')[-1] - version_info = await self.civitai_client.get_model_version_info(version_id) + version_info = await civitai_client.get_model_version_info(version_id) elif model_version_id: # Use model version ID directly - version_info = await self.civitai_client.get_model_version_info(model_version_id) + version_info = await civitai_client.get_model_version_info(model_version_id) elif model_hash: # Get model by hash - version_info = await self.civitai_client.get_model_by_hash(model_hash) + version_info = await civitai_client.get_model_by_hash(model_hash) if not version_info: return {'success': False, 'error': 'Failed to fetch model metadata'} - # Check if this 
is an early access LoRA + # Check if this is an early access model if version_info.get('earlyAccessEndsAt'): early_access_date = version_info.get('earlyAccessEndsAt', '') # Convert to a readable date if possible @@ -49,12 +110,12 @@ class DownloadManager: from datetime import datetime date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00')) formatted_date = date_obj.strftime('%Y-%m-%d') - early_access_msg = f"This LoRA requires early access payment (until {formatted_date}). " + early_access_msg = f"This model requires early access payment (until {formatted_date}). " except: - early_access_msg = "This LoRA requires early access payment. " + early_access_msg = "This model requires early access payment. " early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai." - logger.warning(f"Early access LoRA detected: {version_info.get('name', 'Unknown')}") + logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}") # We'll still try to download, but log a warning and prepare for potential failure if progress_callback: @@ -64,43 +125,51 @@ class DownloadManager: if progress_callback: await progress_callback(0) - # 2. 获取文件信息 + # 2. Get file information file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) if not file_info: return {'success': False, 'error': 'No primary file found in metadata'} - # 3. 准备下载 + # 3. Prepare download file_name = file_info['name'] save_path = os.path.join(save_dir, file_name) file_size = file_info.get('sizeKB', 0) * 1024 - # 4. 通知文件监控系统 - 使用规范化路径和文件大小 - self.file_monitor.handler.add_ignore_path( - save_path.replace(os.sep, '/'), - file_size - ) + # 4. 
Notify file monitor - use normalized path and file size + file_monitor = await self._get_lora_monitor() if model_type == "lora" else await self._get_checkpoint_monitor() + if file_monitor and file_monitor.handler: + file_monitor.handler.add_ignore_path( + save_path.replace(os.sep, '/'), + file_size + ) - # 5. 准备元数据 - metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) + # 5. Prepare metadata based on model type + if model_type == "checkpoint": + metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path) + logger.info(f"Creating CheckpointMetadata for {file_name}") + else: + metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) + logger.info(f"Creating LoraMetadata for {file_name}") - # 5.1 获取并更新模型标签和描述信息 + # 5.1 Get and update model tags and description model_id = version_info.get('modelId') if model_id: - model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id)) + model_metadata, _ = await civitai_client.get_model_metadata(str(model_id)) if model_metadata: if model_metadata.get("tags"): metadata.tags = model_metadata.get("tags", []) if model_metadata.get("description"): metadata.modelDescription = model_metadata.get("description", "") - # 6. 开始下载流程 + # 6. 
Start download process result = await self._execute_download( download_url=file_info.get('downloadUrl', ''), save_dir=save_dir, metadata=metadata, version_info=version_info, relative_path=relative_path, - progress_callback=progress_callback + progress_callback=progress_callback, + model_type=model_type ) return result @@ -114,10 +183,12 @@ class DownloadManager: return {'success': False, 'error': str(e)} async def _execute_download(self, download_url: str, save_dir: str, - metadata: LoraMetadata, version_info: Dict, - relative_path: str, progress_callback=None) -> Dict: + metadata, version_info: Dict, + relative_path: str, progress_callback=None, + model_type: str = "lora") -> Dict: """Execute the actual download process including preview images and model files""" try: + civitai_client = await self._get_civitai_client() save_path = metadata.file_path metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' @@ -128,20 +199,61 @@ class DownloadManager: if progress_callback: await progress_callback(1) # 1% progress for starting preview download - preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png' - preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext - if await self.civitai_client.download_preview_image(images[0]['url'], preview_path): - metadata.preview_url = preview_path.replace(os.sep, '/') - metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) + # Check if it's a video or an image + is_video = images[0].get('type') == 'video' + + if is_video: + # For videos, use .mp4 extension + preview_ext = '.mp4' + preview_path = os.path.splitext(save_path)[0] + preview_ext + + # Download video directly + if await civitai_client.download_preview_image(images[0]['url'], preview_path): + metadata.preview_url = preview_path.replace(os.sep, '/') + metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) + 
with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) + else: + # For images, use WebP format for better performance + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file: + temp_path = temp_file.name + + # Download the original image to temp path + if await civitai_client.download_preview_image(images[0]['url'], temp_path): + # Optimize and convert to WebP + preview_path = os.path.splitext(save_path)[0] + '.webp' + + # Use ExifUtils to optimize and convert the image + optimized_data, _ = ExifUtils.optimize_image( + image_data=temp_path, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=True + ) + + # Save the optimized image + with open(preview_path, 'wb') as f: + f.write(optimized_data) + + # Update metadata + metadata.preview_url = preview_path.replace(os.sep, '/') + metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) + + # Remove temporary file + try: + os.unlink(temp_path) + except Exception as e: + logger.warning(f"Failed to delete temp file: {e}") # Report preview download completion if progress_callback: await progress_callback(3) # 3% progress after preview download # Download model file with progress tracking - success, result = await self.civitai_client._download_file( + success, result = await civitai_client._download_file( download_url, save_dir, os.path.basename(save_path), @@ -155,15 +267,22 @@ class DownloadManager: os.remove(path) return {'success': False, 'error': result} - # 4. 更新文件信息(大小和修改时间) + # 4. Update file information (size and modified time) metadata.update_file_info(save_path) - # 5. 最终更新元数据 + # 5. Final metadata update with open(metadata_path, 'w', encoding='utf-8') as f: json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) - # 6. 
update lora cache - cache = await self.file_monitor.scanner.get_cached_data() + # 6. Update cache based on model type + if model_type == "checkpoint": + scanner = await self._get_checkpoint_scanner() + logger.info(f"Updating checkpoint cache for {save_path}") + else: + scanner = await self._get_lora_scanner() + logger.info(f"Updating lora cache for {save_path}") + + cache = await scanner.get_cached_data() metadata_dict = metadata.to_dict() metadata_dict['folder'] = relative_path cache.raw_data.append(metadata_dict) @@ -172,11 +291,8 @@ class DownloadManager: all_folders.add(relative_path) cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) - # Update the hash index with the new LoRA entry - self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) - - # Update the hash index with the new LoRA entry - self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) + # Update the hash index with the new model entry + scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) # Report 100% completion if progress_callback: diff --git a/py/services/file_monitor.py b/py/services/file_monitor.py index 9ed44d0f..fc43914d 100644 --- a/py/services/file_monitor.py +++ b/py/services/file_monitor.py @@ -1,37 +1,42 @@ -from operator import itemgetter import os import logging import asyncio import time from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler -from typing import List, Dict, Set +from typing import List, Dict, Set, Optional from threading import Lock -from .lora_scanner import LoraScanner + from ..config import config +from .service_registry import ServiceRegistry logger = logging.getLogger(__name__) -class LoraFileHandler(FileSystemEventHandler): - """Handler for LoRA file system events""" +# Configuration constant to control file monitoring functionality +ENABLE_FILE_MONITORING = False + +class 
BaseFileHandler(FileSystemEventHandler): + """Base handler for file system events""" - def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop): - self.scanner = scanner - self.loop = loop # 存储事件循环引用 - self.pending_changes = set() # 待处理的变更 - self.lock = Lock() # 线程安全锁 - self.update_task = None # 异步更新任务 - self._ignore_paths = set() # Add ignore paths set - self._min_ignore_timeout = 5 # minimum timeout in seconds - self._download_speed = 1024 * 1024 # assume 1MB/s as base speed + def __init__(self, loop: asyncio.AbstractEventLoop): + self.loop = loop # Store event loop reference + self.pending_changes = set() # Pending changes + self.lock = Lock() # Thread-safe lock + self.update_task = None # Async update task + self._ignore_paths = set() # Paths to ignore + self._min_ignore_timeout = 5 # Minimum timeout in seconds + self._download_speed = 1024 * 1024 # Assume 1MB/s as base speed # Track modified files with timestamps for debouncing self.modified_files: Dict[str, float] = {} self.debounce_timer = None - self.debounce_delay = 3.0 # seconds to wait after last modification + self.debounce_delay = 3.0 # Seconds to wait after last modification - # Track files that are already scheduled for processing + # Track files already scheduled for processing self.scheduled_files: Set[str] = set() + + # File extensions to monitor - should be overridden by subclasses + self.file_extensions = set() def _should_ignore(self, path: str) -> bool: """Check if path should be ignored""" @@ -56,35 +61,33 @@ class LoraFileHandler(FileSystemEventHandler): if event.is_directory: return - # Handle safetensors files directly - if event.src_path.endswith('.safetensors'): + # Handle appropriate files based on extensions + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext in self.file_extensions: if self._should_ignore(event.src_path): return - # We'll process this file directly and ignore subsequent modifications - # to prevent duplicate processing + # Process 
this file directly and ignore subsequent modifications normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/') if normalized_path not in self.scheduled_files: - logger.info(f"LoRA file created: {event.src_path}") + logger.info(f"File created: {event.src_path}") self.scheduled_files.add(normalized_path) self._schedule_update('add', event.src_path) # Ignore modifications for a short period after creation - # This helps avoid duplicate processing self.loop.call_later( self.debounce_delay * 2, self.scheduled_files.discard, normalized_path ) - # For browser downloads, we'll catch them when they're renamed to .safetensors - def on_modified(self, event): if event.is_directory: return - # Only process safetensors files - if event.src_path.endswith('.safetensors'): + # Only process files with supported extensions + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext in self.file_extensions: if self._should_ignore(event.src_path): return @@ -132,12 +135,17 @@ class LoraFileHandler(FileSystemEventHandler): # Process stable files for file_path in files_to_process: - logger.info(f"Processing modified LoRA file: {file_path}") + logger.info(f"Processing modified file: {file_path}") self._schedule_update('add', file_path) def on_deleted(self, event): - if event.is_directory or not event.src_path.endswith('.safetensors'): + if event.is_directory: return + + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext not in self.file_extensions: + return + if self._should_ignore(event.src_path): return @@ -145,14 +153,17 @@ class LoraFileHandler(FileSystemEventHandler): normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/') self.scheduled_files.discard(normalized_path) - logger.info(f"LoRA file deleted: {event.src_path}") + logger.info(f"File deleted: {event.src_path}") self._schedule_update('remove', event.src_path) def on_moved(self, event): """Handle file move/rename events""" - # If destination is a safetensors file, 
treat it as a new file - if event.dest_path.endswith('.safetensors'): + src_ext = os.path.splitext(event.src_path)[1].lower() + dest_ext = os.path.splitext(event.dest_path)[1].lower() + + # If destination has supported extension, treat as new file + if dest_ext in self.file_extensions: if self._should_ignore(event.dest_path): return @@ -160,7 +171,7 @@ class LoraFileHandler(FileSystemEventHandler): # Only process if not already scheduled if normalized_path not in self.scheduled_files: - logger.info(f"LoRA file renamed/moved to: {event.dest_path}") + logger.info(f"File renamed/moved to: {event.dest_path}") self.scheduled_files.add(normalized_path) self._schedule_update('add', event.dest_path) @@ -171,21 +182,21 @@ class LoraFileHandler(FileSystemEventHandler): normalized_path ) - # If source was a safetensors file, treat it as deleted - if event.src_path.endswith('.safetensors'): + # If source was a supported file, treat it as deleted + if src_ext in self.file_extensions: if self._should_ignore(event.src_path): return normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/') self.scheduled_files.discard(normalized_path) - logger.info(f"LoRA file moved/renamed from: {event.src_path}") + logger.info(f"File moved/renamed from: {event.src_path}") self._schedule_update('remove', event.src_path) - def _schedule_update(self, action: str, file_path: str): #file_path is a real path + def _schedule_update(self, action: str, file_path: str): """Schedule a cache update""" with self.lock: - # 使用 config 中的方法映射路径 + # Use config method to map path mapped_path = config.map_path_to_link(file_path) normalized_path = mapped_path.replace(os.sep, '/') self.pending_changes.add((action, normalized_path)) @@ -196,7 +207,20 @@ class LoraFileHandler(FileSystemEventHandler): """Create update task in the event loop""" if self.update_task is None or self.update_task.done(): self.update_task = asyncio.create_task(self._process_changes()) + + async def _process_changes(self, delay: 
float = 2.0): + """Process pending changes with debouncing - should be implemented by subclasses""" + raise NotImplementedError("Subclasses must implement _process_changes") + +class LoraFileHandler(BaseFileHandler): + """Handler for LoRA file system events""" + + def __init__(self, loop: asyncio.AbstractEventLoop): + super().__init__(loop) + # Set supported file extensions for LoRAs + self.file_extensions = {'.safetensors'} + async def _process_changes(self, delay: float = 2.0): """Process pending changes with debouncing""" await asyncio.sleep(delay) @@ -209,9 +233,11 @@ class LoraFileHandler(FileSystemEventHandler): if not changes: return - logger.info(f"Processing {len(changes)} file changes") + logger.info(f"Processing {len(changes)} LoRA file changes") - cache = await self.scanner.get_cached_data() + # Get scanner through ServiceRegistry + scanner = await ServiceRegistry.get_lora_scanner() + cache = await scanner.get_cached_data() needs_resort = False new_folders = set() @@ -225,36 +251,36 @@ class LoraFileHandler(FileSystemEventHandler): continue # Scan new file - lora_data = await self.scanner.scan_single_lora(file_path) - if lora_data: + model_data = await scanner.scan_single_model(file_path) + if model_data: # Update tags count - for tag in lora_data.get('tags', []): - self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1 + for tag in model_data.get('tags', []): + scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1 - cache.raw_data.append(lora_data) - new_folders.add(lora_data['folder']) + cache.raw_data.append(model_data) + new_folders.add(model_data['folder']) # Update hash index - if 'sha256' in lora_data: - self.scanner._hash_index.add_entry( - lora_data['sha256'], - lora_data['file_path'] + if 'sha256' in model_data: + scanner._hash_index.add_entry( + model_data['sha256'], + model_data['file_path'] ) needs_resort = True elif action == 'remove': - # Find the lora to remove so we can update tags count - lora_to_remove = 
next((item for item in cache.raw_data if item['file_path'] == file_path), None) - if lora_to_remove: + # Find the model to remove so we can update tags count + model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None) + if model_to_remove: # Update tags count by reducing counts - for tag in lora_to_remove.get('tags', []): - if tag in self.scanner._tags_count: - self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1) - if self.scanner._tags_count[tag] == 0: - del self.scanner._tags_count[tag] + for tag in model_to_remove.get('tags', []): + if tag in scanner._tags_count: + scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1) + if scanner._tags_count[tag] == 0: + del scanner._tags_count[tag] # Remove from cache and hash index logger.info(f"Removing {file_path} from cache") - self.scanner._hash_index.remove_by_path(file_path) + scanner._hash_index.remove_by_path(file_path) cache.raw_data = [ item for item in cache.raw_data if item['file_path'] != file_path @@ -272,62 +298,245 @@ class LoraFileHandler(FileSystemEventHandler): cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) except Exception as e: - logger.error(f"Error in process_changes: {e}") + logger.error(f"Error in process_changes for LoRA: {e}") -class LoraFileMonitor: - """Monitor for LoRA file changes""" +class CheckpointFileHandler(BaseFileHandler): + """Handler for checkpoint file system events""" - def __init__(self, scanner: LoraScanner, roots: List[str]): - self.scanner = scanner - scanner.set_file_monitor(self) + def __init__(self, loop: asyncio.AbstractEventLoop): + super().__init__(loop) + # Set supported file extensions for checkpoints + self.file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'} + + async def _process_changes(self, delay: float = 2.0): + """Process pending changes with debouncing for checkpoint files""" + await asyncio.sleep(delay) + + try: + with self.lock: + changes = 
self.pending_changes.copy() + self.pending_changes.clear() + + if not changes: + return + + logger.info(f"Processing {len(changes)} checkpoint file changes") + + # Get scanner through ServiceRegistry + scanner = await ServiceRegistry.get_checkpoint_scanner() + cache = await scanner.get_cached_data() + needs_resort = False + new_folders = set() + + for action, file_path in changes: + try: + if action == 'add': + # Check if file already exists in cache + existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None) + if existing: + logger.info(f"File {file_path} already in cache, skipping") + continue + + # Scan new file + model_data = await scanner.scan_single_model(file_path) + if model_data: + # Update tags count if applicable + for tag in model_data.get('tags', []): + scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1 + + cache.raw_data.append(model_data) + new_folders.add(model_data['folder']) + # Update hash index + if 'sha256' in model_data: + scanner._hash_index.add_entry( + model_data['sha256'], + model_data['file_path'] + ) + needs_resort = True + + elif action == 'remove': + # Find the model to remove so we can update tags count + model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None) + if model_to_remove: + # Update tags count by reducing counts + for tag in model_to_remove.get('tags', []): + if tag in scanner._tags_count: + scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1) + if scanner._tags_count[tag] == 0: + del scanner._tags_count[tag] + + # Remove from cache and hash index + logger.info(f"Removing {file_path} from checkpoint cache") + scanner._hash_index.remove_by_path(file_path) + cache.raw_data = [ + item for item in cache.raw_data + if item['file_path'] != file_path + ] + needs_resort = True + + except Exception as e: + logger.error(f"Error processing checkpoint {action} for {file_path}: {e}") + + if needs_resort: + await cache.resort() + + # Update 
folder list + all_folders = set(cache.folders) | new_folders + cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) + + except Exception as e: + logger.error(f"Error in process_changes for checkpoint: {e}") + + +class BaseFileMonitor: + """Base class for file monitoring""" + + def __init__(self, monitor_paths: List[str]): self.observer = Observer() self.loop = asyncio.get_event_loop() - self.handler = LoraFileHandler(scanner, self.loop) - - # 使用已存在的路径映射 self.monitor_paths = set() - for root in roots: - self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/')) + + # Process monitor paths + for path in monitor_paths: + self.monitor_paths.add(os.path.realpath(path).replace(os.sep, '/')) - # 添加所有已映射的目标路径 + # Add mapped paths from config for target_path in config._path_mappings.keys(): self.monitor_paths.add(target_path) - + def start(self): - """Start monitoring""" - for path_info in self.monitor_paths: + """Start file monitoring""" + if not ENABLE_FILE_MONITORING: + logger.info("File monitoring is disabled via ENABLE_FILE_MONITORING setting") + return + + for path in self.monitor_paths: try: - if isinstance(path_info, tuple): - # 对于链接,监控目标路径 - _, target_path = path_info - self.observer.schedule(self.handler, target_path, recursive=True) - logger.info(f"Started monitoring target path: {target_path}") - else: - # 对于普通路径,直接监控 - self.observer.schedule(self.handler, path_info, recursive=True) - logger.info(f"Started monitoring: {path_info}") + self.observer.schedule(self.handler, path, recursive=True) + logger.info(f"Started monitoring: {path}") except Exception as e: - logger.error(f"Error monitoring {path_info}: {e}") + logger.error(f"Error monitoring {path}: {e}") self.observer.start() - + def stop(self): - """Stop monitoring""" + """Stop file monitoring""" + if not ENABLE_FILE_MONITORING: + return + self.observer.stop() self.observer.join() - + def rescan_links(self): - """重新扫描链接(当添加新的链接时调用)""" + """Rescan links when new ones are added""" + if 
not ENABLE_FILE_MONITORING: + return + + # Find new paths not yet being monitored new_paths = set() - for path in self.monitor_paths.copy(): - self._add_link_targets(path) + for path in config._path_mappings.keys(): + real_path = os.path.realpath(path).replace(os.sep, '/') + if real_path not in self.monitor_paths: + new_paths.add(real_path) + self.monitor_paths.add(real_path) - # 添加新发现的路径到监控 - new_paths = self.monitor_paths - set(self.observer.watches.keys()) + # Add new paths to monitoring for path in new_paths: try: self.observer.schedule(self.handler, path, recursive=True) logger.info(f"Added new monitoring path: {path}") except Exception as e: - logger.error(f"Error adding new monitor for {path}: {e}") \ No newline at end of file + logger.error(f"Error adding new monitor for {path}: {e}") + + +class LoraFileMonitor(BaseFileMonitor): + """Monitor for LoRA file changes""" + + _instance = None + _lock = asyncio.Lock() + + def __new__(cls, monitor_paths=None): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, monitor_paths=None): + if not hasattr(self, '_initialized'): + if monitor_paths is None: + from ..config import config + monitor_paths = config.loras_roots + + super().__init__(monitor_paths) + self.handler = LoraFileHandler(self.loop) + self._initialized = True + + @classmethod + async def get_instance(cls): + """Get singleton instance with async support""" + async with cls._lock: + if cls._instance is None: + from ..config import config + cls._instance = cls(config.loras_roots) + return cls._instance + + +class CheckpointFileMonitor(BaseFileMonitor): + """Monitor for checkpoint file changes""" + + _instance = None + _lock = asyncio.Lock() + + def __new__(cls, monitor_paths=None): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, monitor_paths=None): + if not hasattr(self, '_initialized'): + if monitor_paths is None: + # Get checkpoint 
roots from scanner + monitor_paths = [] + # We'll initialize monitor paths later when scanner is available + + super().__init__(monitor_paths or []) + self.handler = CheckpointFileHandler(self.loop) + self._initialized = True + + @classmethod + async def get_instance(cls): + """Get singleton instance with async support""" + async with cls._lock: + if cls._instance is None: + cls._instance = cls([]) + + # Now get checkpoint roots from scanner + from .checkpoint_scanner import CheckpointScanner + scanner = await CheckpointScanner.get_instance() + monitor_paths = scanner.get_model_roots() + + # Update monitor paths - but don't actually monitor them + for path in monitor_paths: + real_path = os.path.realpath(path).replace(os.sep, '/') + cls._instance.monitor_paths.add(real_path) + + return cls._instance + + def start(self): + """Override start to check global enable flag""" + if not ENABLE_FILE_MONITORING: + logger.info("Checkpoint file monitoring is disabled via ENABLE_FILE_MONITORING setting") + return + + logger.info("Checkpoint file monitoring is temporarily disabled") + # Skip the actual monitoring setup + pass + + async def initialize_paths(self): + """Initialize monitor paths from scanner - currently disabled""" + if not ENABLE_FILE_MONITORING: + logger.info("Checkpoint path initialization skipped (monitoring disabled)") + return + + logger.info("Checkpoint file path initialization skipped (monitoring disabled)") + pass \ No newline at end of file diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index c8142086..29908ef9 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -4,22 +4,21 @@ import logging import asyncio import shutil import time -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Set from ..utils.models import LoraMetadata from ..config import config -from ..utils.file_utils import load_metadata, get_file_info, normalize_path, find_preview_file, save_metadata -from 
..utils.lora_metadata import extract_lora_metadata -from .lora_cache import LoraCache +from .model_scanner import ModelScanner from .lora_hash_index import LoraHashIndex from .settings_manager import settings from ..utils.constants import NSFW_LEVELS from ..utils.utils import fuzzy_match +from .service_registry import ServiceRegistry import sys logger = logging.getLogger(__name__) -class LoraScanner: +class LoraScanner(ModelScanner): """Service for scanning and managing LoRA files""" _instance = None @@ -31,20 +30,20 @@ class LoraScanner: return cls._instance def __init__(self): - # 确保初始化只执行一次 + # Ensure initialization happens only once if not hasattr(self, '_initialized'): - self._cache: Optional[LoraCache] = None - self._hash_index = LoraHashIndex() - self._initialization_lock = asyncio.Lock() - self._initialization_task: Optional[asyncio.Task] = None + # Define supported file extensions + file_extensions = {'.safetensors'} + + # Initialize parent class + super().__init__( + model_type="lora", + model_class=LoraMetadata, + file_extensions=file_extensions, + hash_index=LoraHashIndex() + ) self._initialized = True - self.file_monitor = None # Add this line - self._tags_count = {} # Add a dictionary to store tag counts - - def set_file_monitor(self, monitor): - """Set file monitor instance""" - self.file_monitor = monitor - + @classmethod async def get_instance(cls): """Get singleton instance with async support""" @@ -52,89 +51,74 @@ class LoraScanner: if cls._instance is None: cls._instance = cls() return cls._instance - - async def get_cached_data(self, force_refresh: bool = False) -> LoraCache: - """Get cached LoRA data, refresh if needed""" - async with self._initialization_lock: + + def get_model_roots(self) -> List[str]: + """Get lora root directories""" + return config.loras_roots + + async def scan_all_models(self) -> List[Dict]: + """Scan all LoRA directories and return metadata""" + all_loras = [] + + # Create scan tasks for each directory + scan_tasks = 
[] + for lora_root in self.get_model_roots(): + task = asyncio.create_task(self._scan_directory(lora_root)) + scan_tasks.append(task) - # 如果缓存未初始化但需要响应请求,返回空缓存 - if self._cache is None and not force_refresh: - return LoraCache( - raw_data=[], - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - - # 如果正在初始化,等待完成 - if self._initialization_task and not self._initialization_task.done(): - try: - await self._initialization_task - except Exception as e: - logger.error(f"Cache initialization failed: {e}") - self._initialization_task = None - - if (self._cache is None or force_refresh): + # Wait for all tasks to complete + for task in scan_tasks: + try: + loras = await task + all_loras.extend(loras) + except Exception as e: + logger.error(f"Error scanning directory: {e}") - # 创建新的初始化任务 - if not self._initialization_task or self._initialization_task.done(): - self._initialization_task = asyncio.create_task(self._initialize_cache()) + return all_loras + + async def _scan_directory(self, root_path: str) -> List[Dict]: + """Scan a single directory for LoRA files""" + loras = [] + original_root = root_path # Save original root path + + async def scan_recursive(path: str, visited_paths: set): + """Recursively scan directory, avoiding circular symlinks""" + try: + real_path = os.path.realpath(path) + if real_path in visited_paths: + logger.debug(f"Skipping already visited path: {path}") + return + visited_paths.add(real_path) - try: - await self._initialization_task - except Exception as e: - logger.error(f"Cache initialization failed: {e}") - # 如果缓存已存在,继续使用旧缓存 - if self._cache is None: - raise # 如果没有缓存,则抛出异常 - - return self._cache + with os.scandir(path) as it: + entries = list(it) + for entry in entries: + try: + if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions): + # Use original path instead of real path + file_path = entry.path.replace(os.sep, "/") + await self._process_single_file(file_path, original_root, loras) 
+ await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + # For directories, continue scanning with original path + await scan_recursive(entry.path, visited_paths) + except Exception as e: + logger.error(f"Error processing entry {entry.path}: {e}") + except Exception as e: + logger.error(f"Error scanning {path}: {e}") - async def _initialize_cache(self) -> None: - """Initialize or refresh the cache""" + await scan_recursive(root_path, set()) + return loras + + async def _process_single_file(self, file_path: str, root_path: str, loras: list): + """Process a single file and add to results list""" try: - start_time = time.time() - # Clear existing hash index - self._hash_index.clear() - - # Clear existing tags count - self._tags_count = {} - - # Scan for new data - raw_data = await self.scan_all_loras() - - # Build hash index and tags count - for lora_data in raw_data: - if 'sha256' in lora_data and 'file_path' in lora_data: - self._hash_index.add_entry(lora_data['sha256'].lower(), lora_data['file_path']) - - # Count tags - if 'tags' in lora_data and lora_data['tags']: - for tag in lora_data['tags']: - self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 - - # Update cache - self._cache = LoraCache( - raw_data=raw_data, - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - - # Call resort_cache to create sorted views - await self._cache.resort() - - self._initialization_task = None - logger.info(f"LoRA Manager: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} loras") + result = await self._process_model_file(file_path, root_path) + if result: + loras.append(result) except Exception as e: - logger.error(f"LoRA Manager: Error initializing cache: {e}") - self._cache = LoraCache( - raw_data=[], - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - + logger.error(f"Error processing {file_path}: {e}") + async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', folder: str 
= None, search: str = None, fuzzy_search: bool = False, base_models: list = None, tags: list = None, @@ -280,348 +264,6 @@ class LoraScanner: return result - def invalidate_cache(self): - """Invalidate the current cache""" - self._cache = None - - async def scan_all_loras(self) -> List[Dict]: - """Scan all LoRA directories and return metadata""" - all_loras = [] - - # 分目录异步扫描 - scan_tasks = [] - for loras_root in config.loras_roots: - task = asyncio.create_task(self._scan_directory(loras_root)) - scan_tasks.append(task) - - for task in scan_tasks: - try: - loras = await task - all_loras.extend(loras) - except Exception as e: - logger.error(f"Error scanning directory: {e}") - - return all_loras - - async def _scan_directory(self, root_path: str) -> List[Dict]: - """Scan a single directory for LoRA files""" - loras = [] - original_root = root_path # 保存原始根路径 - - async def scan_recursive(path: str, visited_paths: set): - """递归扫描目录,避免循环链接""" - try: - real_path = os.path.realpath(path) - if real_path in visited_paths: - logger.debug(f"Skipping already visited path: {path}") - return - visited_paths.add(real_path) - - with os.scandir(path) as it: - entries = list(it) - for entry in entries: - try: - if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'): - # 使用原始路径而不是真实路径 - file_path = entry.path.replace(os.sep, "/") - await self._process_single_file(file_path, original_root, loras) - await asyncio.sleep(0) - elif entry.is_dir(follow_symlinks=True): - # 对于目录,使用原始路径继续扫描 - await scan_recursive(entry.path, visited_paths) - except Exception as e: - logger.error(f"Error processing entry {entry.path}: {e}") - except Exception as e: - logger.error(f"Error scanning {path}: {e}") - - await scan_recursive(root_path, set()) - return loras - - async def _process_single_file(self, file_path: str, root_path: str, loras: list): - """处理单个文件并添加到结果列表""" - try: - result = await self._process_lora_file(file_path, root_path) - if result: - loras.append(result) - except 
Exception as e: - logger.error(f"Error processing {file_path}: {e}") - - async def _process_lora_file(self, file_path: str, root_path: str) -> Dict: - """Process a single LoRA file and return its metadata""" - # Try loading existing metadata - metadata = await load_metadata(file_path) - - if metadata is None: - # Try to find and use .civitai.info file first - civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info" - if os.path.exists(civitai_info_path): - try: - with open(civitai_info_path, 'r', encoding='utf-8') as f: - version_info = json.load(f) - - file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) - if file_info: - # Create a minimal file_info with the required fields - file_name = os.path.splitext(os.path.basename(file_path))[0] - file_info['name'] = file_name - - # Use from_civitai_info to create metadata - metadata = LoraMetadata.from_civitai_info(version_info, file_info, file_path) - metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path)) - await save_metadata(file_path, metadata) - logger.debug(f"Created metadata from .civitai.info for {file_path}") - except Exception as e: - logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}") - - # If still no metadata, create new metadata using get_file_info - if metadata is None: - metadata = await get_file_info(file_path) - - # Convert to dict and add folder info - lora_data = metadata.to_dict() - # Try to fetch missing metadata from Civitai if needed - await self._fetch_missing_metadata(file_path, lora_data) - rel_path = os.path.relpath(file_path, root_path) - folder = os.path.dirname(rel_path) - lora_data['folder'] = folder.replace(os.path.sep, '/') - - return lora_data - - async def _fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None: - """Fetch missing description and tags from Civitai if needed - - Args: - file_path: Path to the lora file - lora_data: Lora metadata dictionary to update - """ - 
try: - # Skip if already marked as deleted on Civitai - if lora_data.get('civitai_deleted', False): - logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai") - return - - # Check if we need to fetch additional metadata from Civitai - needs_metadata_update = False - model_id = None - - # Check if we have Civitai model ID but missing metadata - if lora_data.get('civitai'): - # Try to get model ID directly from the correct location - model_id = lora_data['civitai'].get('modelId') - - if model_id: - model_id = str(model_id) - # Check if tags are missing or empty - tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0 - - # Check if description is missing or empty - desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "") - - needs_metadata_update = tags_missing or desc_missing - - # Fetch missing metadata if needed - if needs_metadata_update and model_id: - logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}") - from ..services.civitai_client import CivitaiClient - client = CivitaiClient() - - # Get metadata and status code - model_metadata, status_code = await client.get_model_metadata(model_id) - await client.close() - - # Handle 404 status (model deleted from Civitai) - if status_code == 404: - logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)") - # Mark as deleted to avoid future API calls - lora_data['civitai_deleted'] = True - - # Save the updated metadata back to file - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(lora_data, f, indent=2, ensure_ascii=False) - - # Process valid metadata if available - elif model_metadata: - logger.debug(f"Updating metadata for {file_path} with model ID {model_id}") - - # Update tags if they were missing - if model_metadata.get('tags') and (not lora_data.get('tags') or 
len(lora_data.get('tags', [])) == 0): - lora_data['tags'] = model_metadata['tags'] - - # Update description if it was missing - if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")): - lora_data['modelDescription'] = model_metadata['description'] - - # Save the updated metadata back to file - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(lora_data, f, indent=2, ensure_ascii=False) - except Exception as e: - logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}") - - async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool: - """Update preview URL in cache for a specific lora - - Args: - file_path: The file path of the lora to update - preview_url: The new preview URL - - Returns: - bool: True if the update was successful, False if cache doesn't exist or lora wasn't found - """ - if self._cache is None: - return False - - return await self._cache.update_preview_url(file_path, preview_url) - - async def scan_single_lora(self, file_path: str) -> Optional[Dict]: - """Scan a single LoRA file and return its metadata""" - try: - if not os.path.exists(os.path.realpath(file_path)): - return None - - # 获取基本文件信息 - metadata = await get_file_info(file_path) - if not metadata: - return None - - folder = self._calculate_folder(file_path) - - # 确保 folder 字段存在 - metadata_dict = metadata.to_dict() - metadata_dict['folder'] = folder or '' - - return metadata_dict - - except Exception as e: - logger.error(f"Error scanning {file_path}: {e}") - return None - - def _calculate_folder(self, file_path: str) -> str: - """Calculate the folder path for a LoRA file""" - # 使用原始路径计算相对路径 - for root in config.loras_roots: - if file_path.startswith(root): - rel_path = os.path.relpath(file_path, root) - return os.path.dirname(rel_path).replace(os.path.sep, '/') - return '' - - async def 
move_model(self, source_path: str, target_path: str) -> bool: - """Move a model and its associated files to a new location""" - try: - # 保持原始路径格式 - source_path = source_path.replace(os.sep, '/') - target_path = target_path.replace(os.sep, '/') - - # 其余代码保持不变 - base_name = os.path.splitext(os.path.basename(source_path))[0] - source_dir = os.path.dirname(source_path) - - os.makedirs(target_path, exist_ok=True) - - target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/') - - # 使用真实路径进行文件操作 - real_source = os.path.realpath(source_path) - real_target = os.path.realpath(target_lora) - - file_size = os.path.getsize(real_source) - - if self.file_monitor: - self.file_monitor.handler.add_ignore_path( - real_source, - file_size - ) - self.file_monitor.handler.add_ignore_path( - real_target, - file_size - ) - - # 使用真实路径进行文件操作 - shutil.move(real_source, real_target) - - # Move associated files - source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json") - if os.path.exists(source_metadata): - target_metadata = os.path.join(target_path, f"{base_name}.metadata.json") - shutil.move(source_metadata, target_metadata) - metadata = await self._update_metadata_paths(target_metadata, target_lora) - - # Move preview file if exists - preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', - '.png', '.jpeg', '.jpg', '.mp4'] - for ext in preview_extensions: - source_preview = os.path.join(source_dir, f"{base_name}{ext}") - if os.path.exists(source_preview): - target_preview = os.path.join(target_path, f"{base_name}{ext}") - shutil.move(source_preview, target_preview) - break - - # Update cache - await self.update_single_lora_cache(source_path, target_lora, metadata) - - return True - - except Exception as e: - logger.error(f"Error moving model: {e}", exc_info=True) - return False - - async def update_single_lora_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool: - cache = await 
self.get_cached_data() - - # Find the existing item to remove its tags from count - existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None) - if existing_item and 'tags' in existing_item: - for tag in existing_item.get('tags', []): - if tag in self._tags_count: - self._tags_count[tag] = max(0, self._tags_count[tag] - 1) - if self._tags_count[tag] == 0: - del self._tags_count[tag] - - # Remove old path from hash index if exists - self._hash_index.remove_by_path(original_path) - - # Remove the old entry from raw_data - cache.raw_data = [ - item for item in cache.raw_data - if item['file_path'] != original_path - ] - - if metadata: - # If this is an update to an existing path (not a move), ensure folder is preserved - if original_path == new_path: - # Find the folder from existing entries or calculate it - existing_folder = next((item['folder'] for item in cache.raw_data - if item['file_path'] == original_path), None) - if existing_folder: - metadata['folder'] = existing_folder - else: - metadata['folder'] = self._calculate_folder(new_path) - else: - # For moved files, recalculate the folder - metadata['folder'] = self._calculate_folder(new_path) - - # Add the updated metadata to raw_data - cache.raw_data.append(metadata) - - # Update hash index with new path - if 'sha256' in metadata: - self._hash_index.add_entry(metadata['sha256'].lower(), new_path) - - # Update folders list - all_folders = set(item['folder'] for item in cache.raw_data) - cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) - - # Update tags count with the new/updated tags - if 'tags' in metadata: - for tag in metadata.get('tags', []): - self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 - - # Resort cache - await cache.resort() - - return True - async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict: """Update file paths in metadata file""" try: @@ -648,7 +290,7 @@ class LoraScanner: except Exception as e: 
logger.error(f"Error updating metadata paths: {e}", exc_info=True) - # Add new methods for hash index functionality + # Lora-specific hash index functionality def has_lora_hash(self, sha256: str) -> bool: """Check if a LoRA with given hash exists""" return self._hash_index.has_hash(sha256.lower()) @@ -661,36 +303,8 @@ class LoraScanner: """Get hash for a LoRA by its file path""" return self._hash_index.get_hash(file_path) - def get_preview_url_by_hash(self, sha256: str) -> Optional[str]: - """Get preview static URL for a LoRA by its hash""" - # Get the file path first - file_path = self._hash_index.get_path(sha256.lower()) - if not file_path: - return None - - # Determine the preview file path (typically same name with different extension) - base_name = os.path.splitext(file_path)[0] - preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', - '.png', '.jpeg', '.jpg', '.mp4'] - - for ext in preview_extensions: - preview_path = f"{base_name}{ext}" - if os.path.exists(preview_path): - # Convert to static URL using config - return config.get_preview_static_url(preview_path) - - return None - - # Add new method to get top tags async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: - """Get top tags sorted by count - - Args: - limit: Maximum number of tags to return - - Returns: - List of dictionaries with tag name and count, sorted by count - """ + """Get top tags sorted by count""" # Make sure cache is initialized await self.get_cached_data() @@ -705,14 +319,7 @@ class LoraScanner: return sorted_tags[:limit] async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]: - """Get base models used in loras sorted by frequency - - Args: - limit: Maximum number of base models to return - - Returns: - List of dictionaries with base model name and count, sorted by count - """ + """Get base models used in loras sorted by frequency""" # Make sure cache is initialized cache = await self.get_cached_data() diff --git 
a/py/services/model_cache.py b/py/services/model_cache.py new file mode 100644 index 00000000..b652d919 --- /dev/null +++ b/py/services/model_cache.py @@ -0,0 +1,64 @@ +import asyncio +from typing import List, Dict +from dataclasses import dataclass +from operator import itemgetter + +@dataclass +class ModelCache: + """Cache structure for model data""" + raw_data: List[Dict] + sorted_by_name: List[Dict] + sorted_by_date: List[Dict] + folders: List[str] + + def __post_init__(self): + self._lock = asyncio.Lock() + + async def resort(self, name_only: bool = False): + """Resort all cached data views""" + async with self._lock: + self.sorted_by_name = sorted( + self.raw_data, + key=lambda x: x['model_name'].lower() # Case-insensitive sort + ) + if not name_only: + self.sorted_by_date = sorted( + self.raw_data, + key=itemgetter('modified'), + reverse=True + ) + # Update folder list + all_folders = set(l['folder'] for l in self.raw_data) + self.folders = sorted(list(all_folders), key=lambda x: x.lower()) + + async def update_preview_url(self, file_path: str, preview_url: str) -> bool: + """Update preview_url for a specific model in all cached data + + Args: + file_path: The file path of the model to update + preview_url: The new preview URL + + Returns: + bool: True if the update was successful, False if the model wasn't found + """ + async with self._lock: + # Update in raw_data + for item in self.raw_data: + if item['file_path'] == file_path: + item['preview_url'] = preview_url + break + else: + return False # Model not found + + # Update in sorted lists (references to the same dict objects) + for item in self.sorted_by_name: + if item['file_path'] == file_path: + item['preview_url'] = preview_url + break + + for item in self.sorted_by_date: + if item['file_path'] == file_path: + item['preview_url'] = preview_url + break + + return True \ No newline at end of file diff --git a/py/services/model_hash_index.py b/py/services/model_hash_index.py new file mode 100644 index 
from typing import Dict, Optional, Set


class ModelHashIndex:
    """Bidirectional index between model SHA256 hashes and file paths.

    Hashes are normalized to lowercase on every operation, so hash lookups
    are case-insensitive. Paths are matched verbatim. The forward and
    reverse maps are kept in lockstep by every mutator.
    """

    def __init__(self):
        self._hash_to_path: Dict[str, str] = {}
        self._path_to_hash: Dict[str, str] = {}

    def add_entry(self, sha256: str, file_path: str) -> None:
        """Insert or replace the mapping between a hash and a path."""
        if not sha256 or not file_path:
            return

        # Ensure hash is lowercase for consistency
        sha256 = sha256.lower()

        # Evict the stale reverse mapping left behind by a previous entry
        # for either side of the pair, then install the new pair.
        stale_path = self._hash_to_path.get(sha256)
        if stale_path is not None:
            self._path_to_hash.pop(stale_path, None)

        stale_hash = self._path_to_hash.get(file_path)
        if stale_hash is not None:
            self._hash_to_path.pop(stale_hash, None)

        self._hash_to_path[sha256] = file_path
        self._path_to_hash[file_path] = sha256

    def remove_by_path(self, file_path: str) -> None:
        """Drop the entry associated with ``file_path``, if any."""
        hash_val = self._path_to_hash.pop(file_path, None)
        if hash_val is not None:
            self._hash_to_path.pop(hash_val, None)

    def remove_by_hash(self, sha256: str) -> None:
        """Drop the entry associated with ``sha256``, if any."""
        path = self._hash_to_path.pop(sha256.lower(), None)
        if path is not None:
            self._path_to_hash.pop(path, None)

    def has_hash(self, sha256: str) -> bool:
        """Return True when ``sha256`` is present in the index."""
        return sha256.lower() in self._hash_to_path

    def get_path(self, sha256: str) -> Optional[str]:
        """Return the file path registered for ``sha256``, or None."""
        return self._hash_to_path.get(sha256.lower())

    def get_hash(self, file_path: str) -> Optional[str]:
        """Return the hash registered for ``file_path``, or None."""
        return self._path_to_hash.get(file_path)

    def clear(self) -> None:
        """Remove every entry from the index."""
        self._hash_to_path.clear()
        self._path_to_hash.clear()

    def get_all_hashes(self) -> Set[str]:
        """Return a snapshot of all indexed hashes."""
        return set(self._hash_to_path)

    def get_all_paths(self) -> Set[str]:
        """Return a snapshot of all indexed file paths."""
        return set(self._path_to_hash)

    def __len__(self) -> int:
        """Number of hash/path pairs currently indexed."""
        return len(self._hash_to_path)
class ModelScanner:
    """Base service for scanning and managing model files.

    Subclasses provide the model roots and the full-scan implementation
    (``get_model_roots`` / ``scan_all_models``); this base class owns the
    shared cache, hash index, tag counting, and file-move bookkeeping.
    """

    _lock = asyncio.Lock()

    def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
        """Initialize the scanner

        Args:
            model_type: Type of model (lora, checkpoint, etc.)
            model_class: Class used to create metadata instances
            file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
            hash_index: Hash index instance (optional)
        """
        self.model_type = model_type
        self.model_class = model_class
        self.file_extensions = file_extensions
        self._cache = None
        self._hash_index = hash_index or ModelHashIndex()
        self._tags_count = {}  # Dictionary to store tag counts
        self._is_initializing = False  # Flag to track initialization state

        # Register this service.
        # NOTE(review): asyncio.create_task requires a running event loop at
        # construction time and raises RuntimeError otherwise — confirm all
        # call sites construct scanners from within a running loop.
        asyncio.create_task(self._register_service())

    async def _register_service(self):
        """Register this instance with the ServiceRegistry"""
        service_name = f"{self.model_type}_scanner"
        await ServiceRegistry.register_service(service_name, self)

    async def initialize_in_background(self) -> None:
        """Initialize cache in background using thread pool"""
        try:
            # Set initial empty cache to avoid None reference errors
            if self._cache is None:
                self._cache = ModelCache(
                    raw_data=[],
                    sorted_by_name=[],
                    sorted_by_date=[],
                    folders=[]
                )

            # Set initializing flag to true
            self._is_initializing = True

            start_time = time.time()
            # Use thread pool to execute CPU-intensive operations
            loop = asyncio.get_event_loop()
            await loop.run_in_executor(
                None,  # Use default thread pool
                self._initialize_cache_sync  # Run synchronous version in thread
            )
            logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
        except Exception as e:
            logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache in background: {e}")
        finally:
            # Always clear the initializing flag when done
            self._is_initializing = False

    def _initialize_cache_sync(self):
        """Synchronous version of cache initialization for thread pool execution"""
        # FIX: bind 'loop' before the try so the finally clause cannot raise
        # NameError if asyncio.new_event_loop() itself fails.
        loop = None
        try:
            # Create a new event loop for this thread
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            # Create a synchronous method to bypass the async lock
            def sync_initialize_cache():
                # Directly call the scan method to avoid lock issues
                raw_data = loop.run_until_complete(self.scan_all_models())

                # Update hash index and tags count
                for model_data in raw_data:
                    if 'sha256' in model_data and 'file_path' in model_data:
                        self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])

                    # Count tags
                    if 'tags' in model_data and model_data['tags']:
                        for tag in model_data['tags']:
                            self._tags_count[tag] = self._tags_count.get(tag, 0) + 1

                # Update cache
                self._cache.raw_data = raw_data
                loop.run_until_complete(self._cache.resort())

                return self._cache

            # Run our sync initialization that avoids lock conflicts
            return sync_initialize_cache()
        except Exception as e:
            logger.error(f"Error in thread-based {self.model_type} cache initialization: {e}")
        finally:
            # Clean up the event loop (only if it was actually created)
            if loop is not None:
                loop.close()

    async def get_cached_data(self, force_refresh: bool = False) -> ModelCache:
        """Get cached model data, refresh if needed"""
        # If cache is not initialized, return an empty cache
        # Actual initialization should be done via initialize_in_background
        if self._cache is None and not force_refresh:
            return ModelCache(
                raw_data=[],
                sorted_by_name=[],
                sorted_by_date=[],
                folders=[]
            )

        # If force refresh is requested, initialize the cache directly
        if force_refresh:
            if self._cache is None:
                # For initial creation, do a full initialization
                await self._initialize_cache()
            else:
                # For subsequent refreshes, use fast reconciliation
                await self._reconcile_cache()

        return self._cache

    async def _initialize_cache(self) -> None:
        """Initialize or refresh the cache"""
        try:
            start_time = time.time()
            # Clear existing hash index
            self._hash_index.clear()

            # Clear existing tags count
            self._tags_count = {}

            # Scan for new data
            raw_data = await self.scan_all_models()

            # Build hash index and tags count
            for model_data in raw_data:
                if 'sha256' in model_data and 'file_path' in model_data:
                    self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])

                # Count tags
                if 'tags' in model_data and model_data['tags']:
                    for tag in model_data['tags']:
                        self._tags_count[tag] = self._tags_count.get(tag, 0) + 1

            # Update cache
            self._cache = ModelCache(
                raw_data=raw_data,
                sorted_by_name=[],
                sorted_by_date=[],
                folders=[]
            )

            # Resort cache
            await self._cache.resort()

            logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
        except Exception as e:
            logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
            self._cache = ModelCache(
                raw_data=[],
                sorted_by_name=[],
                sorted_by_date=[],
                folders=[]
            )

    async def _reconcile_cache(self) -> None:
        """Fast cache reconciliation - only process differences between cache and filesystem"""
        try:
            start_time = time.time()
            logger.info(f"{self.model_type.capitalize()} Scanner: Starting fast cache reconciliation...")

            # Get current cached file paths
            cached_paths = {item['file_path'] for item in self._cache.raw_data}
            path_to_item = {item['file_path']: item for item in self._cache.raw_data}

            # Track found files and new files
            found_paths = set()
            new_files = []

            # Scan all model roots
            for root_path in self.get_model_roots():
                if not os.path.exists(root_path):
                    continue

                # Track visited real paths to avoid symlink loops
                visited_real_paths = set()

                # Recursively scan directory
                for root, _, files in os.walk(root_path, followlinks=True):
                    real_root = os.path.realpath(root)
                    if real_root in visited_real_paths:
                        continue
                    visited_real_paths.add(real_root)

                    for file in files:
                        ext = os.path.splitext(file)[1].lower()
                        if ext in self.file_extensions:
                            # Construct paths exactly as they would be in cache
                            file_path = os.path.join(root, file).replace(os.sep, '/')

                            # Check if this file is already in cache
                            if file_path in cached_paths:
                                found_paths.add(file_path)
                                continue

                            # Try case-insensitive match on Windows
                            if os.name == 'nt':
                                lower_path = file_path.lower()
                                matched = False
                                for cached_path in cached_paths:
                                    if cached_path.lower() == lower_path:
                                        found_paths.add(cached_path)
                                        matched = True
                                        break
                                if matched:
                                    continue

                            # This is a new file to process
                            new_files.append(file_path)

                    # Yield control periodically
                    await asyncio.sleep(0)

            # Process new files in batches
            total_added = 0
            if new_files:
                logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(new_files)} new files to process")
                batch_size = 50
                for i in range(0, len(new_files), batch_size):
                    batch = new_files[i:i+batch_size]
                    for path in batch:
                        try:
                            model_data = await self.scan_single_model(path)
                            if model_data:
                                # Add to cache
                                self._cache.raw_data.append(model_data)

                                # Update hash index if available
                                if 'sha256' in model_data and 'file_path' in model_data:
                                    self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])

                                # Update tags count
                                if 'tags' in model_data and model_data['tags']:
                                    for tag in model_data['tags']:
                                        self._tags_count[tag] = self._tags_count.get(tag, 0) + 1

                                total_added += 1
                        except Exception as e:
                            logger.error(f"Error adding {path} to cache: {e}")

                    # Yield control after each batch
                    await asyncio.sleep(0)

            # Find missing files (in cache but not in filesystem)
            missing_files = cached_paths - found_paths
            total_removed = 0

            if missing_files:
                logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(missing_files)} files to remove from cache")

                # Process files to remove
                for path in missing_files:
                    try:
                        model_to_remove = path_to_item[path]

                        # Update tags count
                        for tag in model_to_remove.get('tags', []):
                            if tag in self._tags_count:
                                self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
                                if self._tags_count[tag] == 0:
                                    del self._tags_count[tag]

                        # Remove from hash index
                        self._hash_index.remove_by_path(path)
                        total_removed += 1
                    except Exception as e:
                        logger.error(f"Error removing {path} from cache: {e}")

            # Update cache data
            self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in missing_files]

            # Resort cache if changes were made
            if total_added > 0 or total_removed > 0:
                # Update folders list
                all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
                self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())

                # Resort cache
                await self._cache.resort()

            logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
        except Exception as e:
            logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)

    # These methods should be implemented in child classes
    async def scan_all_models(self) -> List[Dict]:
        """Scan all model directories and return metadata"""
        raise NotImplementedError("Subclasses must implement scan_all_models")

    def get_model_roots(self) -> List[str]:
        """Get model root directories"""
        raise NotImplementedError("Subclasses must implement get_model_roots")

    async def scan_single_model(self, file_path: str) -> Optional[Dict]:
        """Scan a single model file and return its metadata"""
        try:
            if not os.path.exists(os.path.realpath(file_path)):
                return None

            # Get basic file info
            metadata = await self._get_file_info(file_path)
            if not metadata:
                return None

            folder = self._calculate_folder(file_path)

            # Ensure folder field exists
            metadata_dict = metadata.to_dict()
            metadata_dict['folder'] = folder or ''

            return metadata_dict

        except Exception as e:
            logger.error(f"Error scanning {file_path}: {e}")
            return None

    async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
        """Get model file info and metadata (extensible for different model types)"""
        return await get_file_info(file_path, self.model_class)

    def _calculate_folder(self, file_path: str) -> str:
        """Calculate the folder path for a model file"""
        # NOTE(review): startswith is a plain string-prefix test; a root that
        # is a prefix of another root's name could mis-match — confirm roots
        # are normalized with trailing separators upstream.
        for root in self.get_model_roots():
            if file_path.startswith(root):
                rel_path = os.path.relpath(file_path, root)
                return os.path.dirname(rel_path).replace(os.path.sep, '/')
        return ''

    # Common methods shared between scanners
    async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
        """Process a single model file and return its metadata"""
        metadata = await load_metadata(file_path, self.model_class)

        if metadata is None:
            # Fall back to a sibling .civitai.info file when no metadata exists
            civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
            if os.path.exists(civitai_info_path):
                try:
                    with open(civitai_info_path, 'r', encoding='utf-8') as f:
                        version_info = json.load(f)

                    file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
                    if file_info:
                        file_name = os.path.splitext(os.path.basename(file_path))[0]
                        file_info['name'] = file_name

                        metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
                        metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
                        await save_metadata(file_path, metadata)
                        logger.debug(f"Created metadata from .civitai.info for {file_path}")
                except Exception as e:
                    logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")

        if metadata is None:
            metadata = await self._get_file_info(file_path)

        model_data = metadata.to_dict()

        await self._fetch_missing_metadata(file_path, model_data)
        rel_path = os.path.relpath(file_path, root_path)
        folder = os.path.dirname(rel_path)
        model_data['folder'] = folder.replace(os.path.sep, '/')

        return model_data

    async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None:
        """Fetch missing description and tags from Civitai if needed"""
        try:
            if model_data.get('civitai_deleted', False):
                logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
                return

            needs_metadata_update = False
            model_id = None

            if model_data.get('civitai'):
                model_id = model_data['civitai'].get('modelId')

            if model_id:
                model_id = str(model_id)
                tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0
                desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")
                needs_metadata_update = tags_missing or desc_missing

            if needs_metadata_update and model_id:
                logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
                # Imported lazily to avoid a hard dependency at module load.
                from ..services.civitai_client import CivitaiClient
                client = CivitaiClient()

                model_metadata, status_code = await client.get_model_metadata(model_id)
                await client.close()

                if status_code == 404:
                    # Model was removed from Civitai; remember that so we stop retrying.
                    logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
                    model_data['civitai_deleted'] = True

                    metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
                    with open(metadata_path, 'w', encoding='utf-8') as f:
                        json.dump(model_data, f, indent=2, ensure_ascii=False)

                elif model_metadata:
                    logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")

                    if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0):
                        model_data['tags'] = model_metadata['tags']

                    if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")):
                        model_data['modelDescription'] = model_metadata['description']

                    metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
                    with open(metadata_path, 'w', encoding='utf-8') as f:
                        json.dump(model_data, f, indent=2, ensure_ascii=False)
        except Exception as e:
            logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")

    async def _scan_directory(self, root_path: str) -> List[Dict]:
        """Base implementation for directory scanning"""
        models = []
        original_root = root_path

        async def scan_recursive(path: str, visited_paths: set):
            try:
                real_path = os.path.realpath(path)
                if real_path in visited_paths:
                    logger.debug(f"Skipping already visited path: {path}")
                    return
                visited_paths.add(real_path)

                with os.scandir(path) as it:
                    entries = list(it)
                    for entry in entries:
                        try:
                            if entry.is_file(follow_symlinks=True):
                                ext = os.path.splitext(entry.name)[1].lower()
                                if ext in self.file_extensions:
                                    file_path = entry.path.replace(os.sep, "/")
                                    await self._process_single_file(file_path, original_root, models)
                                    await asyncio.sleep(0)
                            elif entry.is_dir(follow_symlinks=True):
                                await scan_recursive(entry.path, visited_paths)
                        except Exception as e:
                            logger.error(f"Error processing entry {entry.path}: {e}")
            except Exception as e:
                logger.error(f"Error scanning {path}: {e}")

        await scan_recursive(root_path, set())
        return models

    async def _process_single_file(self, file_path: str, root_path: str, models_list: list):
        """Process a single file and add to results list"""
        try:
            result = await self._process_model_file(file_path, root_path)
            if result:
                models_list.append(result)
        except Exception as e:
            logger.error(f"Error processing {file_path}: {e}")

    async def move_model(self, source_path: str, target_path: str) -> bool:
        """Move a model and its associated files to a new location"""
        try:
            source_path = source_path.replace(os.sep, '/')
            target_path = target_path.replace(os.sep, '/')

            file_ext = os.path.splitext(source_path)[1]

            if not file_ext or file_ext.lower() not in self.file_extensions:
                logger.error(f"Invalid file extension for model: {file_ext}")
                return False

            base_name = os.path.splitext(os.path.basename(source_path))[0]
            source_dir = os.path.dirname(source_path)

            os.makedirs(target_path, exist_ok=True)

            target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/')

            real_source = os.path.realpath(source_path)
            real_target = os.path.realpath(target_file)

            file_size = os.path.getsize(real_source)

            # Get the appropriate file monitor through ServiceRegistry
            if self.model_type == "lora":
                monitor = await ServiceRegistry.get_lora_monitor()
            elif self.model_type == "checkpoint":
                monitor = await ServiceRegistry.get_checkpoint_monitor()
            else:
                monitor = None

            # Tell the monitor to ignore the move so it is not re-scanned
            # as an add/remove pair.
            if monitor:
                monitor.handler.add_ignore_path(
                    real_source,
                    file_size
                )
                monitor.handler.add_ignore_path(
                    real_target,
                    file_size
                )

            shutil.move(real_source, real_target)

            # Move the sidecar metadata file, if present, and rewrite its paths.
            source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
            metadata = None
            if os.path.exists(source_metadata):
                target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
                shutil.move(source_metadata, target_metadata)
                metadata = await self._update_metadata_paths(target_metadata, target_file)

            # Move the first matching preview file, if any.
            for ext in PREVIEW_EXTENSIONS:
                source_preview = os.path.join(source_dir, f"{base_name}{ext}")
                if os.path.exists(source_preview):
                    target_preview = os.path.join(target_path, f"{base_name}{ext}")
                    shutil.move(source_preview, target_preview)
                    break

            await self.update_single_model_cache(source_path, target_file, metadata)

            return True

        except Exception as e:
            logger.error(f"Error moving model: {e}", exc_info=True)
            return False

    async def _update_metadata_paths(self, metadata_path: str, model_path: str) -> Dict:
        """Update file paths in metadata file"""
        try:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            metadata['file_path'] = model_path.replace(os.sep, '/')

            if 'preview_url' in metadata:
                preview_dir = os.path.dirname(model_path)
                preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
                preview_ext = os.path.splitext(metadata['preview_url'])[1]
                new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
                metadata['preview_url'] = new_preview_path.replace(os.sep, '/')

            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2, ensure_ascii=False)

            return metadata

        except Exception as e:
            logger.error(f"Error updating metadata paths: {e}", exc_info=True)
            return None

    async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
        """Update cache after a model has been moved or modified"""
        cache = await self.get_cached_data()

        # Decrement tag counts for the entry being replaced.
        existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
        if existing_item and 'tags' in existing_item:
            for tag in existing_item.get('tags', []):
                if tag in self._tags_count:
                    self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
                    if self._tags_count[tag] == 0:
                        del self._tags_count[tag]

        self._hash_index.remove_by_path(original_path)

        cache.raw_data = [
            item for item in cache.raw_data
            if item['file_path'] != original_path
        ]

        if metadata:
            if original_path == new_path:
                existing_folder = next((item['folder'] for item in cache.raw_data
                                        if item['file_path'] == original_path), None)
                if existing_folder:
                    metadata['folder'] = existing_folder
                else:
                    metadata['folder'] = self._calculate_folder(new_path)
            else:
                metadata['folder'] = self._calculate_folder(new_path)

            cache.raw_data.append(metadata)

            if 'sha256' in metadata:
                self._hash_index.add_entry(metadata['sha256'].lower(), new_path)

            all_folders = set(item['folder'] for item in cache.raw_data)
            cache.folders = sorted(list(all_folders), key=lambda x: x.lower())

            if 'tags' in metadata:
                for tag in metadata.get('tags', []):
                    self._tags_count[tag] = self._tags_count.get(tag, 0) + 1

        await cache.resort()

        return True

    def has_hash(self, sha256: str) -> bool:
        """Check if a model with given hash exists"""
        return self._hash_index.has_hash(sha256.lower())

    def get_path_by_hash(self, sha256: str) -> Optional[str]:
        """Get file path for a model by its hash"""
        return self._hash_index.get_path(sha256.lower())

    def get_hash_by_path(self, file_path: str) -> Optional[str]:
        """Get hash for a model by its file path"""
        return self._hash_index.get_hash(file_path)

    # TODO: Adjust this method to use metadata instead of finding the file
    def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
        """Get preview static URL for a model by its hash"""
        file_path = self._hash_index.get_path(sha256.lower())
        if not file_path:
            return None

        base_name = os.path.splitext(file_path)[0]

        for ext in PREVIEW_EXTENSIONS:
            preview_path = f"{base_name}{ext}"
            if os.path.exists(preview_path):
                return config.get_preview_static_url(preview_path)

        return None

    async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
        """Get top tags sorted by count"""
        # Make sure cache is initialized
        await self.get_cached_data()

        sorted_tags = sorted(
            [{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
            key=lambda x: x['count'],
            reverse=True
        )

        return sorted_tags[:limit]

    async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
        """Get base models sorted by frequency"""
        cache = await self.get_cached_data()

        base_model_counts = {}
        for model in cache.raw_data:
            if 'base_model' in model and model['base_model']:
                base_model = model['base_model']
                base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1

        sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
        sorted_models.sort(key=lambda x: x['count'], reverse=True)

        return sorted_models[:limit]

    async def get_model_info_by_name(self, name):
        """Get model information by name"""
        try:
            cache = await self.get_cached_data()

            for model in cache.raw_data:
                if model.get("file_name") == name:
                    return model

            return None
        except Exception as e:
            logger.error(f"Error getting model info by name: {e}", exc_info=True)
            return None

    async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
        """Update preview URL in cache for a specific model

        Args:
            file_path: The file path of the model to update
            preview_url: The new preview URL

        Returns:
            bool: True if the update was successful, False if cache doesn't exist or the model wasn't found
        """
        if self._cache is None:
            return False

        return await self._cache.update_preview_url(file_path, preview_url)
b/py/services/recipe_scanner.py @@ -5,8 +5,8 @@ import json from typing import List, Dict, Optional, Any, Tuple from ..config import config from .recipe_cache import RecipeCache +from .service_registry import ServiceRegistry from .lora_scanner import LoraScanner -from .civitai_client import CivitaiClient from ..utils.utils import fuzzy_match import sys @@ -18,11 +18,22 @@ class RecipeScanner: _instance = None _lock = asyncio.Lock() + @classmethod + async def get_instance(cls, lora_scanner: Optional[LoraScanner] = None): + """Get singleton instance of RecipeScanner""" + async with cls._lock: + if cls._instance is None: + if not lora_scanner: + # Get lora scanner from service registry if not provided + lora_scanner = await ServiceRegistry.get_lora_scanner() + cls._instance = cls(lora_scanner) + return cls._instance + def __new__(cls, lora_scanner: Optional[LoraScanner] = None): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._lora_scanner = lora_scanner - cls._instance._civitai_client = CivitaiClient() + cls._instance._civitai_client = None # Will be lazily initialized return cls._instance def __init__(self, lora_scanner: Optional[LoraScanner] = None): @@ -35,9 +46,67 @@ class RecipeScanner: if lora_scanner: self._lora_scanner = lora_scanner self._initialized = True - - # Initialization will be scheduled by LoraManager + async def _get_civitai_client(self): + """Lazily initialize CivitaiClient from registry""" + if self._civitai_client is None: + self._civitai_client = await ServiceRegistry.get_civitai_client() + return self._civitai_client + + async def initialize_in_background(self) -> None: + """Initialize cache in background using thread pool""" + try: + # Set initial empty cache to avoid None reference errors + if self._cache is None: + self._cache = RecipeCache( + raw_data=[], + sorted_by_name=[], + sorted_by_date=[] + ) + + # Mark as initializing to prevent concurrent initializations + self._is_initializing = True + + try: + # Use 
thread pool to execute CPU-intensive operations + loop = asyncio.get_event_loop() + await loop.run_in_executor( + None, # Use default thread pool + self._initialize_recipe_cache_sync # Run synchronous version in thread + ) + logger.info("Recipe cache initialization completed in background thread") + finally: + # Mark initialization as complete regardless of outcome + self._is_initializing = False + except Exception as e: + logger.error(f"Recipe Scanner: Error initializing cache in background: {e}") + + def _initialize_recipe_cache_sync(self): + """Synchronous version of recipe cache initialization for thread pool execution""" + try: + # Create a new event loop for this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Create a synchronous method to bypass the async lock + def sync_initialize_cache(): + # Directly call the internal scan method to avoid lock issues + raw_data = loop.run_until_complete(self.scan_all_recipes()) + + # Update cache + self._cache.raw_data = raw_data + loop.run_until_complete(self._cache.resort()) + + return self._cache + + # Run our sync initialization that avoids lock conflicts + return sync_initialize_cache() + except Exception as e: + logger.error(f"Error in thread-based recipe cache initialization: {e}") + finally: + # Clean up the event loop + loop.close() + @property def recipes_dir(self) -> str: """Get path to recipes directory""" @@ -60,49 +129,48 @@ class RecipeScanner: if self._is_initializing and not force_refresh: return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) - # Try to acquire the lock with a timeout to prevent deadlocks - try: - async with self._initialization_lock: - # Check again after acquiring the lock - if self._cache is not None and not force_refresh: - return self._cache - - # Mark as initializing to prevent concurrent initializations - self._is_initializing = True - - try: - # Remove dependency on lora scanner initialization - # Scan for recipe data 
directly - raw_data = await self.scan_all_recipes() + # If force refresh is requested, initialize the cache directly + if force_refresh: + # Try to acquire the lock with a timeout to prevent deadlocks + try: + async with self._initialization_lock: + # Mark as initializing to prevent concurrent initializations + self._is_initializing = True - # Update cache - self._cache = RecipeCache( - raw_data=raw_data, - sorted_by_name=[], - sorted_by_date=[] - ) + try: + # Scan for recipe data directly + raw_data = await self.scan_all_recipes() + + # Update cache + self._cache = RecipeCache( + raw_data=raw_data, + sorted_by_name=[], + sorted_by_date=[] + ) + + # Resort cache + await self._cache.resort() + + return self._cache - # Resort cache - await self._cache.resort() - - return self._cache - - except Exception as e: - logger.error(f"Recipe Manager: Error initializing cache: {e}", exc_info=True) - # Create empty cache on error - self._cache = RecipeCache( - raw_data=[], - sorted_by_name=[], - sorted_by_date=[] - ) - return self._cache - finally: - # Mark initialization as complete - self._is_initializing = False + except Exception as e: + logger.error(f"Recipe Manager: Error initializing cache: {e}", exc_info=True) + # Create empty cache on error + self._cache = RecipeCache( + raw_data=[], + sorted_by_name=[], + sorted_by_date=[] + ) + return self._cache + finally: + # Mark initialization as complete + self._is_initializing = False + + except Exception as e: + logger.error(f"Unexpected error in get_cached_data: {e}") - except Exception as e: - logger.error(f"Unexpected error in get_cached_data: {e}") - return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) + # Return the cache (may be empty or partially initialized) + return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) async def scan_all_recipes(self) -> List[Dict]: """Scan all recipe JSON files and return metadata""" @@ -255,10 +323,13 @@ class RecipeScanner: 
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]: """Get hash from Civitai API""" try: - if not self._civitai_client: + # Get CivitaiClient from ServiceRegistry + civitai_client = await self._get_civitai_client() + if not civitai_client: + logger.error("Failed to get CivitaiClient from ServiceRegistry") return None - version_info = await self._civitai_client.get_model_version_info(model_version_id) + version_info = await civitai_client.get_model_version_info(model_version_id) if not version_info or not version_info.get('files'): logger.debug(f"No files found in version info for ID: {model_version_id}") @@ -278,10 +349,12 @@ class RecipeScanner: async def _get_model_version_name(self, model_version_id: str) -> Optional[str]: """Get model version name from Civitai API""" try: - if not self._civitai_client: + # Get CivitaiClient from ServiceRegistry + civitai_client = await self._get_civitai_client() + if not civitai_client: return None - version_info = await self._civitai_client.get_model_version_info(model_version_id) + version_info = await civitai_client.get_model_version_info(model_version_id) if version_info and 'name' in version_info: return version_info['name'] diff --git a/py/services/service_registry.py b/py/services/service_registry.py new file mode 100644 index 00000000..5940a1ed --- /dev/null +++ b/py/services/service_registry.py @@ -0,0 +1,124 @@ +import asyncio +import logging +from typing import Optional, Dict, Any, TypeVar, Type + +logger = logging.getLogger(__name__) + +T = TypeVar('T') # Define a type variable for service types + +class ServiceRegistry: + """Centralized registry for service singletons""" + + _instance = None + _services: Dict[str, Any] = {} + _lock = asyncio.Lock() + + @classmethod + def get_instance(cls): + """Get singleton instance of the registry""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + async def register_service(cls, service_name: str, 
service_instance: Any) -> None: + """Register a service instance with the registry""" + registry = cls.get_instance() + async with cls._lock: + registry._services[service_name] = service_instance + logger.debug(f"Registered service: {service_name}") + + @classmethod + async def get_service(cls, service_name: str) -> Any: + """Get a service instance by name""" + registry = cls.get_instance() + async with cls._lock: + if service_name not in registry._services: + logger.debug(f"Service {service_name} not found in registry") + return None + return registry._services[service_name] + + # Convenience methods for common services + @classmethod + async def get_lora_scanner(cls): + """Get the LoraScanner instance""" + from .lora_scanner import LoraScanner + scanner = await cls.get_service("lora_scanner") + if scanner is None: + scanner = await LoraScanner.get_instance() + await cls.register_service("lora_scanner", scanner) + return scanner + + @classmethod + async def get_checkpoint_scanner(cls): + """Get the CheckpointScanner instance""" + from .checkpoint_scanner import CheckpointScanner + scanner = await cls.get_service("checkpoint_scanner") + if scanner is None: + scanner = await CheckpointScanner.get_instance() + await cls.register_service("checkpoint_scanner", scanner) + return scanner + + @classmethod + async def get_lora_monitor(cls): + """Get the LoraFileMonitor instance""" + from .file_monitor import LoraFileMonitor + monitor = await cls.get_service("lora_monitor") + if monitor is None: + monitor = await LoraFileMonitor.get_instance() + await cls.register_service("lora_monitor", monitor) + return monitor + + @classmethod + async def get_checkpoint_monitor(cls): + """Get the CheckpointFileMonitor instance""" + from .file_monitor import CheckpointFileMonitor + monitor = await cls.get_service("checkpoint_monitor") + if monitor is None: + monitor = await CheckpointFileMonitor.get_instance() + await cls.register_service("checkpoint_monitor", monitor) + return monitor + 
+ @classmethod + async def get_civitai_client(cls): + """Get the CivitaiClient instance""" + from .civitai_client import CivitaiClient + client = await cls.get_service("civitai_client") + if client is None: + client = await CivitaiClient.get_instance() + await cls.register_service("civitai_client", client) + return client + + @classmethod + async def get_download_manager(cls): + """Get the DownloadManager instance""" + from .download_manager import DownloadManager + manager = await cls.get_service("download_manager") + if manager is None: + # We'll let DownloadManager.get_instance handle file_monitor parameter + manager = await DownloadManager.get_instance() + await cls.register_service("download_manager", manager) + return manager + + @classmethod + async def get_recipe_scanner(cls): + """Get the RecipeScanner instance""" + from .recipe_scanner import RecipeScanner + scanner = await cls.get_service("recipe_scanner") + if scanner is None: + lora_scanner = await cls.get_lora_scanner() + scanner = RecipeScanner(lora_scanner) + await cls.register_service("recipe_scanner", scanner) + return scanner + + @classmethod + async def get_websocket_manager(cls): + """Get the WebSocketManager instance""" + from .websocket_manager import ws_manager + manager = await cls.get_service("websocket_manager") + if manager is None: + # ws_manager is already a global instance in websocket_manager.py + from .websocket_manager import ws_manager + await cls.register_service("websocket_manager", ws_manager) + manager = ws_manager + return manager \ No newline at end of file diff --git a/py/utils/constants.py b/py/utils/constants.py index 69a96ca2..0521c5d9 100644 --- a/py/utils/constants.py +++ b/py/utils/constants.py @@ -5,4 +5,21 @@ NSFW_LEVELS = { "X": 8, "XXX": 16, "Blocked": 32, # Probably not actually visible through the API without being logged in on model owner account? 
-} \ No newline at end of file +} + +# preview extensions +PREVIEW_EXTENSIONS = [ + '.webp', + '.preview.webp', + '.preview.png', + '.preview.jpeg', + '.preview.jpg', + '.preview.mp4', + '.png', + '.jpeg', + '.jpg', + '.mp4' +] + +# Card preview image width +CARD_PREVIEW_WIDTH = 480 \ No newline at end of file diff --git a/py/utils/file_utils.py b/py/utils/file_utils.py index 0f282051..366d1c56 100644 --- a/py/utils/file_utils.py +++ b/py/utils/file_utils.py @@ -2,12 +2,14 @@ import logging import os import hashlib import json -from typing import Dict, Optional +import time +from typing import Dict, Optional, Type from .model_utils import determine_base_model - -from .lora_metadata import extract_lora_metadata -from .models import LoraMetadata +from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata +from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata +from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH +from .exif_utils import ExifUtils logger = logging.getLogger(__name__) @@ -15,35 +17,56 @@ async def calculate_sha256(file_path: str) -> str: """Calculate SHA256 hash of a file""" sha256_hash = hashlib.sha256() with open(file_path, "rb") as f: - for byte_block in iter(lambda: f.read(4096), b""): + for byte_block in iter(lambda: f.read(128 * 1024), b""): sha256_hash.update(byte_block) return sha256_hash.hexdigest() def find_preview_file(base_name: str, dir_path: str) -> str: """Find preview file for given base name in directory""" - preview_patterns = [ - f"{base_name}.preview.png", - f"{base_name}.preview.jpg", - f"{base_name}.preview.jpeg", - f"{base_name}.preview.mp4", - f"{base_name}.png", - f"{base_name}.jpg", - f"{base_name}.jpeg", - f"{base_name}.mp4" - ] - for pattern in preview_patterns: - full_pattern = os.path.join(dir_path, pattern) + for ext in PREVIEW_EXTENSIONS: + full_pattern = os.path.join(dir_path, f"{base_name}{ext}") if os.path.exists(full_pattern): + # Check if this is an image and not already 
webp + if ext.lower().endswith(('.jpg', '.jpeg', '.png')) and not ext.lower().endswith('.webp'): + try: + # Optimize the image to webp format + webp_path = os.path.join(dir_path, f"{base_name}.webp") + + # Use ExifUtils to optimize the image + with open(full_pattern, 'rb') as f: + image_data = f.read() + + optimized_data, _ = ExifUtils.optimize_image( + image_data=image_data, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=True + ) + + # Save the optimized webp file + with open(webp_path, 'wb') as f: + f.write(optimized_data) + + logger.debug(f"Optimized preview image from {full_pattern} to {webp_path}") + return webp_path.replace(os.sep, "/") + except Exception as e: + logger.error(f"Error optimizing preview image {full_pattern}: {e}") + # Fall back to original file if optimization fails + return full_pattern.replace(os.sep, "/") + + # Return the original path for webp images or non-image files return full_pattern.replace(os.sep, "/") + return "" def normalize_path(path: str) -> str: """Normalize file path to use forward slashes""" return path.replace(os.sep, "/") if path else path -async def get_file_info(file_path: str) -> Optional[LoraMetadata]: - """Get basic file information as LoraMetadata object""" +async def get_file_info(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]: + """Get basic file information as a model metadata object""" # First check if file actually exists and resolve symlinks try: real_path = os.path.realpath(file_path) @@ -70,31 +93,67 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]: logger.debug(f"Using SHA256 from .json file for {file_path}") except Exception as e: logger.error(f"Error reading .json file for {file_path}: {e}") + + # If SHA256 is still not found, check for a .sha256 file + if sha256 is None: + sha256_file = f"{os.path.splitext(file_path)[0]}.sha256" + if os.path.exists(sha256_file): + try: + with open(sha256_file, 
'r', encoding='utf-8') as f: + sha256 = f.read().strip().lower() + logger.debug(f"Using SHA256 from .sha256 file for {file_path}") + except Exception as e: + logger.error(f"Error reading .sha256 file for {file_path}: {e}") try: # If we didn't get SHA256 from the .json file, calculate it if not sha256: + start_time = time.time() sha256 = await calculate_sha256(real_path) + logger.debug(f"Calculated SHA256 for {file_path} in {time.time() - start_time:.2f} seconds") + + # Create default metadata based on model class + if model_class == CheckpointMetadata: + metadata = CheckpointMetadata( + file_name=base_name, + model_name=base_name, + file_path=normalize_path(file_path), + size=os.path.getsize(real_path), + modified=os.path.getmtime(real_path), + sha256=sha256, + base_model="Unknown", # Will be updated later + preview_url=normalize_path(preview_url), + tags=[], + modelDescription="", + model_type="checkpoint" + ) - metadata = LoraMetadata( - file_name=base_name, - model_name=base_name, - file_path=normalize_path(file_path), - size=os.path.getsize(real_path), - modified=os.path.getmtime(real_path), - sha256=sha256, - base_model="Unknown", # Will be updated later - usage_tips="", - notes="", - from_civitai=True, - preview_url=normalize_path(preview_url), - tags=[], - modelDescription="" - ) + # Extract checkpoint-specific metadata + # model_info = await extract_checkpoint_metadata(real_path) + # metadata.base_model = model_info['base_model'] + # if 'model_type' in model_info: + # metadata.model_type = model_info['model_type'] + + else: # Default to LoraMetadata + metadata = LoraMetadata( + file_name=base_name, + model_name=base_name, + file_path=normalize_path(file_path), + size=os.path.getsize(real_path), + modified=os.path.getmtime(real_path), + sha256=sha256, + base_model="Unknown", # Will be updated later + usage_tips="{}", + preview_url=normalize_path(preview_url), + tags=[], + modelDescription="" + ) + + # Extract lora-specific metadata + model_info = await 
extract_lora_metadata(real_path) + metadata.base_model = model_info['base_model'] - # create metadata file - base_model_info = await extract_lora_metadata(real_path) - metadata.base_model = base_model_info['base_model'] + # Save metadata to file await save_metadata(file_path, metadata) return metadata @@ -102,7 +161,7 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]: logger.error(f"Error getting file info for {file_path}: {e}") return None -async def save_metadata(file_path: str, metadata: LoraMetadata) -> None: +async def save_metadata(file_path: str, metadata: BaseModelMetadata) -> None: """Save metadata to .metadata.json file""" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" try: @@ -115,7 +174,7 @@ async def save_metadata(file_path: str, metadata: LoraMetadata) -> None: except Exception as e: print(f"Error saving metadata to {metadata_path}: {str(e)}") -async def load_metadata(file_path: str) -> Optional[LoraMetadata]: +async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]: """Load metadata from .metadata.json file""" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" try: @@ -138,6 +197,7 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]: data['file_path'] = normalize_path(file_path) needs_update = True + # TODO: optimize preview image to webp format if not already done preview_url = data.get('preview_url', '') if not preview_url or not os.path.exists(preview_url): base_name = os.path.splitext(os.path.basename(file_path))[0] @@ -162,12 +222,22 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]: if 'modelDescription' not in data: data['modelDescription'] = "" needs_update = True + + # For checkpoint metadata + if model_class == CheckpointMetadata and 'model_type' not in data: + data['model_type'] = "checkpoint" + needs_update = True + + # For lora metadata + if model_class == LoraMetadata and 
'usage_tips' not in data: + data['usage_tips'] = "{}" + needs_update = True if needs_update: with open(metadata_path, 'w', encoding='utf-8') as f: json.dump(data, f, indent=2, ensure_ascii=False) - return LoraMetadata.from_dict(data) + return model_class.from_dict(data) except Exception as e: print(f"Error loading metadata from {metadata_path}: {str(e)}") diff --git a/py/utils/lora_metadata.py b/py/utils/lora_metadata.py index c04cd829..f221562d 100644 --- a/py/utils/lora_metadata.py +++ b/py/utils/lora_metadata.py @@ -1,6 +1,7 @@ from safetensors import safe_open from typing import Dict from .model_utils import determine_base_model +import os async def extract_lora_metadata(file_path: str) -> Dict: """Extract essential metadata from safetensors file""" @@ -13,4 +14,67 @@ async def extract_lora_metadata(file_path: str) -> Dict: return {"base_model": base_model} except Exception as e: print(f"Error reading metadata from {file_path}: {str(e)}") - return {"base_model": "Unknown"} \ No newline at end of file + return {"base_model": "Unknown"} + +async def extract_checkpoint_metadata(file_path: str) -> dict: + """Extract metadata from a checkpoint file to determine model type and base model""" + try: + # Analyze filename for clues about the model + filename = os.path.basename(file_path).lower() + + model_info = { + 'base_model': 'Unknown', + 'model_type': 'checkpoint' + } + + # Detect base model from filename + if 'xl' in filename or 'sdxl' in filename: + model_info['base_model'] = 'SDXL' + elif 'sd3' in filename: + model_info['base_model'] = 'SD3' + elif 'sd2' in filename or 'v2' in filename: + model_info['base_model'] = 'SD2.x' + elif 'sd1' in filename or 'v1' in filename: + model_info['base_model'] = 'SD1.5' + + # Detect model type from filename + if 'inpaint' in filename: + model_info['model_type'] = 'inpainting' + elif 'anime' in filename: + model_info['model_type'] = 'anime' + elif 'realistic' in filename: + model_info['model_type'] = 'realistic' + + # Try to peek 
at the safetensors file structure if available + if file_path.endswith('.safetensors'): + import json + import struct + + with open(file_path, 'rb') as f: + header_size = struct.unpack('<Q', f.read(8))[0] [NOTE(review): diff content lost in extraction here — the remainder of extract_checkpoint_metadata, the end of the py/utils/lora_metadata.py hunk, the py/utils/models.py diff header, and the BaseModelMetadata class definition are missing] - def from_dict(cls, data: Dict) -> 'LoraMetadata': - """Create LoraMetadata instance from dictionary""" - # Create a copy of the data to avoid modifying the input + def from_dict(cls, data: Dict) -> 'BaseModelMetadata': + """Create instance from dictionary""" data_copy = data.copy() return cls(**data_copy) - @classmethod - def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata': - """Create LoraMetadata instance from Civitai version info""" - file_name = file_info['name'] - base_model = determine_base_model(version_info.get('baseModel', '')) - - return cls( - file_name=os.path.splitext(file_name)[0], - model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]), - file_path=save_path.replace(os.sep, '/'), - size=file_info.get('sizeKB', 0) * 1024, - modified=datetime.now().timestamp(), - sha256=file_info['hashes'].get('SHA256', '').lower(), - base_model=base_model, - preview_url=None, # Will be updated after preview download - preview_nsfw_level=0, # Will be updated after preview download, it is decided by the nsfw level of the preview image - from_civitai=True, - civitai=version_info - ) - def to_dict(self) -> Dict: """Convert to dictionary for JSON serialization""" return asdict(self) @@ -76,30 +54,54 @@ class LoraMetadata: self.file_path = file_path.replace(os.sep, '/') @dataclass -class CheckpointMetadata: - """Represents the metadata structure for a Checkpoint model""" - file_name: str # The filename without extension - model_name: str # The checkpoint's name defined by the creator - file_path: str # Full path to the model file - size: int # File size in bytes - modified: float # Last modified timestamp - sha256: str # SHA256 hash of the file - base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.) 
- preview_url: str # Preview image URL - preview_nsfw_level: int = 0 # NSFW level of the preview image - model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.) - notes: str = "" # Additional notes - from_civitai: bool = True # Whether from Civitai - civitai: Optional[Dict] = None # Civitai API data if available - tags: List[str] = None # Model tags - modelDescription: str = "" # Full model description - - # Additional checkpoint-specific fields - resolution: Optional[str] = None # Native resolution (e.g., 512x512, 1024x1024) - vae_included: bool = False # Whether VAE is included in the checkpoint - architecture: str = "" # Model architecture (if known) - - def __post_init__(self): - if self.tags is None: - self.tags = [] +class LoraMetadata(BaseModelMetadata): + """Represents the metadata structure for a Lora model""" + usage_tips: str = "{}" # Usage tips for the model, json string + + @classmethod + def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata': + """Create LoraMetadata instance from Civitai version info""" + file_name = file_info['name'] + base_model = determine_base_model(version_info.get('baseModel', '')) + + return cls( + file_name=os.path.splitext(file_name)[0], + model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]), + file_path=save_path.replace(os.sep, '/'), + size=file_info.get('sizeKB', 0) * 1024, + modified=datetime.now().timestamp(), + sha256=file_info['hashes'].get('SHA256', '').lower(), + base_model=base_model, + preview_url=None, # Will be updated after preview download + preview_nsfw_level=0, # Will be updated after preview download + from_civitai=True, + civitai=version_info + ) + +@dataclass +class CheckpointMetadata(BaseModelMetadata): + """Represents the metadata structure for a Checkpoint model""" + model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.) 
+ + @classmethod + def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata': + """Create CheckpointMetadata instance from Civitai version info""" + file_name = file_info['name'] + base_model = determine_base_model(version_info.get('baseModel', '')) + model_type = version_info.get('type', 'checkpoint') + + return cls( + file_name=os.path.splitext(file_name)[0], + model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]), + file_path=save_path.replace(os.sep, '/'), + size=file_info.get('sizeKB', 0) * 1024, + modified=datetime.now().timestamp(), + sha256=file_info['hashes'].get('SHA256', '').lower(), + base_model=base_model, + preview_url=None, # Will be updated after preview download + preview_nsfw_level=0, + from_civitai=True, + civitai=version_info, + model_type=model_type + ) diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py new file mode 100644 index 00000000..6c0dc8d7 --- /dev/null +++ b/py/utils/routes_common.py @@ -0,0 +1,503 @@ +import os +import json +import logging +from typing import Dict, List, Callable, Awaitable +from aiohttp import web + +from .model_utils import determine_base_model +from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH +from ..config import config +from ..services.civitai_client import CivitaiClient +from ..utils.exif_utils import ExifUtils +from ..services.download_manager import DownloadManager + +logger = logging.getLogger(__name__) + + +class ModelRouteUtils: + """Shared utilities for model routes (LoRAs, Checkpoints, etc.)""" + + @staticmethod + async def load_local_metadata(metadata_path: str) -> Dict: + """Load local metadata file""" + if os.path.exists(metadata_path): + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.error(f"Error loading metadata from {metadata_path}: {e}") + return {} + + @staticmethod + async def handle_not_found_on_civitai(metadata_path: str, 
local_metadata: Dict) -> None: + """Handle case when model is not found on CivitAI""" + local_metadata['from_civitai'] = False + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + + @staticmethod + async def update_model_metadata(metadata_path: str, local_metadata: Dict, + civitai_metadata: Dict, client: CivitaiClient) -> None: + """Update local metadata with CivitAI data""" + local_metadata['civitai'] = civitai_metadata + + # Update model name if available + if 'model' in civitai_metadata: + if civitai_metadata.get('model', {}).get('name'): + local_metadata['model_name'] = civitai_metadata['model']['name'] + + # Fetch additional model metadata (description and tags) if we have model ID + model_id = civitai_metadata['modelId'] + if model_id: + model_metadata, _ = await client.get_model_metadata(str(model_id)) + if model_metadata: + local_metadata['modelDescription'] = model_metadata.get('description', '') + local_metadata['tags'] = model_metadata.get('tags', []) + + # Update base model + local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel')) + + # Update preview if needed + if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): + first_preview = next((img for img in civitai_metadata.get('images', [])), None) + if first_preview: + # Determine if content is video or image + is_video = first_preview['type'] == 'video' + + if is_video: + # For videos use .mp4 extension + preview_ext = '.mp4' + else: + # For images use .webp extension + preview_ext = '.webp' + + base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] + preview_filename = base_name + preview_ext + preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) + + if is_video: + # Download video as is + if await client.download_preview_image(first_preview['url'], preview_path): + local_metadata['preview_url'] = 
preview_path.replace(os.sep, '/') + local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + else: + # For images, download and then optimize to WebP + temp_path = preview_path + ".temp" + if await client.download_preview_image(first_preview['url'], temp_path): + try: + # Read the downloaded image + with open(temp_path, 'rb') as f: + image_data = f.read() + + # Optimize and convert to WebP + optimized_data, _ = ExifUtils.optimize_image( + image_data=image_data, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=True + ) + + # Save the optimized WebP image + with open(preview_path, 'wb') as f: + f.write(optimized_data) + + # Update metadata + local_metadata['preview_url'] = preview_path.replace(os.sep, '/') + local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + + # Remove the temporary file + if os.path.exists(temp_path): + os.remove(temp_path) + + except Exception as e: + logger.error(f"Error optimizing preview image: {e}") + # If optimization fails, try to use the downloaded image directly + if os.path.exists(temp_path): + os.rename(temp_path, preview_path) + local_metadata['preview_url'] = preview_path.replace(os.sep, '/') + local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + + # Save updated metadata + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + + @staticmethod + async def fetch_and_update_model( + sha256: str, + file_path: str, + model_data: dict, + update_cache_func: Callable[[str, str, Dict], Awaitable[bool]] + ) -> bool: + """Fetch and update metadata for a single model + + Args: + sha256: SHA256 hash of the model file + file_path: Path to the model file + model_data: The model object in cache to update + update_cache_func: Function to update the cache with new metadata + + Returns: + bool: True if successful, False otherwise + """ + client = CivitaiClient() + try: + metadata_path = 
os.path.splitext(file_path)[0] + '.metadata.json' + + # Check if model metadata exists + local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + # Fetch metadata from Civitai + civitai_metadata = await client.get_model_by_hash(sha256) + if not civitai_metadata: + # Mark as not from CivitAI if not found + local_metadata['from_civitai'] = False + model_data['from_civitai'] = False + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + return False + + # Update metadata + await ModelRouteUtils.update_model_metadata( + metadata_path, + local_metadata, + civitai_metadata, + client + ) + + # Update cache object directly + model_data.update({ + 'model_name': local_metadata.get('model_name'), + 'preview_url': local_metadata.get('preview_url'), + 'from_civitai': True, + 'civitai': civitai_metadata + }) + + # Update cache using the provided function + await update_cache_func(file_path, file_path, local_metadata) + + return True + + except Exception as e: + logger.error(f"Error fetching CivitAI data: {e}") + return False + finally: + await client.close() + + @staticmethod + def filter_civitai_data(data: Dict) -> Dict: + """Filter relevant fields from CivitAI data""" + if not data: + return {} + + fields = [ + "id", "modelId", "name", "createdAt", "updatedAt", + "publishedAt", "trainedWords", "baseModel", "description", + "model", "images" + ] + return {k: data[k] for k in fields if k in data} + + @staticmethod + async def delete_model_files(target_dir: str, file_name: str, file_monitor=None) -> List[str]: + """Delete model and associated files + + Args: + target_dir: Directory containing the model files + file_name: Base name of the model file without extension + file_monitor: Optional file monitor to ignore delete events + + Returns: + List of deleted file paths + """ + patterns = [ + f"{file_name}.safetensors", # Required + f"{file_name}.metadata.json", + ] + + # Add all preview file 
extensions + for ext in PREVIEW_EXTENSIONS: + patterns.append(f"{file_name}{ext}") + + deleted = [] + main_file = patterns[0] + main_path = os.path.join(target_dir, main_file).replace(os.sep, '/') + + if os.path.exists(main_path): + # Notify file monitor to ignore delete event if available + if file_monitor: + file_monitor.handler.add_ignore_path(main_path, 0) + + # Delete file + os.remove(main_path) + deleted.append(main_path) + else: + logger.warning(f"Model file not found: {main_file}") + + # Delete optional files + for pattern in patterns[1:]: + path = os.path.join(target_dir, pattern) + if os.path.exists(path): + try: + os.remove(path) + deleted.append(pattern) + except Exception as e: + logger.warning(f"Failed to delete {pattern}: {e}") + + return deleted + + @staticmethod + def get_multipart_ext(filename): + """Get extension that may have multiple parts like .metadata.json""" + parts = filename.split(".") + if len(parts) > 2: # If contains multi-part extension + return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json" + return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors" + + # New common endpoint handlers + + @staticmethod + async def handle_delete_model(request: web.Request, scanner) -> web.Response: + """Handle model deletion request + + Args: + request: The aiohttp request + scanner: The model scanner instance with cache management methods + + Returns: + web.Response: The HTTP response + """ + try: + data = await request.json() + file_path = data.get('file_path') + if not file_path: + return web.Response(text='Model path is required', status=400) + + target_dir = os.path.dirname(file_path) + file_name = os.path.splitext(os.path.basename(file_path))[0] + + # Get the file monitor from the scanner if available + file_monitor = getattr(scanner, 'file_monitor', None) + + deleted_files = await ModelRouteUtils.delete_model_files( + target_dir, + file_name, + file_monitor + ) + + # Remove from 
cache + cache = await scanner.get_cached_data() + cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path] + await cache.resort() + + # Update hash index if available + if hasattr(scanner, '_hash_index') and scanner._hash_index: + scanner._hash_index.remove_by_path(file_path) + + return web.json_response({ + 'success': True, + 'deleted_files': deleted_files + }) + + except Exception as e: + logger.error(f"Error deleting model: {e}", exc_info=True) + return web.Response(text=str(e), status=500) + + @staticmethod + async def handle_fetch_civitai(request: web.Request, scanner) -> web.Response: + """Handle CivitAI metadata fetch request + + Args: + request: The aiohttp request + scanner: The model scanner instance with cache management methods + + Returns: + web.Response: The HTTP response + """ + try: + data = await request.json() + metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json' + + # Check if model metadata exists + local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + if not local_metadata or not local_metadata.get('sha256'): + return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400) + + # Create a client for fetching from Civitai + client = CivitaiClient() + try: + # Fetch and update metadata + civitai_metadata = await client.get_model_by_hash(local_metadata["sha256"]) + if not civitai_metadata: + await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata) + return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404) + + await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, client) + + # Update the cache + await scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata) + + return web.json_response({"success": True}) + finally: + await client.close() + + except Exception as e: + logger.error(f"Error fetching from CivitAI: {e}", 
exc_info=True) + return web.json_response({"success": False, "error": str(e)}, status=500) + + @staticmethod + async def handle_replace_preview(request: web.Request, scanner) -> web.Response: + """Handle preview image replacement request + + Args: + request: The aiohttp request + scanner: The model scanner instance with methods to update cache + + Returns: + web.Response: The HTTP response + """ + try: + reader = await request.multipart() + + # Read preview file data + field = await reader.next() + if field.name != 'preview_file': + raise ValueError("Expected 'preview_file' field") + content_type = field.headers.get('Content-Type', 'image/png') + preview_data = await field.read() + + # Read model path + field = await reader.next() + if field.name != 'model_path': + raise ValueError("Expected 'model_path' field") + model_path = (await field.read()).decode() + + # Save preview file + base_name = os.path.splitext(os.path.basename(model_path))[0] + folder = os.path.dirname(model_path) + + # Determine if content is video or image + if content_type.startswith('video/'): + # For videos, keep original format and use .mp4 extension + extension = '.mp4' + optimized_data = preview_data + else: + # For images, optimize and convert to WebP + optimized_data, _ = ExifUtils.optimize_image( + image_data=preview_data, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=True + ) + extension = '.webp' # Use .webp without .preview part + + preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/') + + with open(preview_path, 'wb') as f: + f.write(optimized_data) + + # Update preview path in metadata + metadata_path = os.path.splitext(model_path)[0] + '.metadata.json' + if os.path.exists(metadata_path): + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + metadata = json.load(f) + + # Update preview_url directly in the metadata dict + metadata['preview_url'] = preview_path + + with open(metadata_path, 'w', 
encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + except Exception as e: + logger.error(f"Error updating metadata: {e}") + + # Update preview URL in scanner cache + if hasattr(scanner, 'update_preview_in_cache'): + await scanner.update_preview_in_cache(model_path, preview_path) + + return web.json_response({ + "success": True, + "preview_url": config.get_preview_static_url(preview_path) + }) + + except Exception as e: + logger.error(f"Error replacing preview: {e}", exc_info=True) + return web.Response(text=str(e), status=500) + + @staticmethod + async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response: + """Handle model download request + + Args: + request: The aiohttp request + download_manager: Instance of DownloadManager + model_type: Type of model ('lora' or 'checkpoint') + + Returns: + web.Response: The HTTP response + """ + try: + data = await request.json() + + # Create progress callback + async def progress_callback(progress): + from ..services.websocket_manager import ws_manager + await ws_manager.broadcast({ + 'status': 'progress', + 'progress': progress + }) + + # Check which identifier is provided + download_url = data.get('download_url') + model_hash = data.get('model_hash') + model_version_id = data.get('model_version_id') + + # Validate that at least one identifier is provided + if not any([download_url, model_hash, model_version_id]): + return web.Response( + status=400, + text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'" + ) + + # Use the correct root directory based on model type + root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root' + save_dir = data.get(root_key) + + result = await download_manager.download_from_civitai( + download_url=download_url, + model_hash=model_hash, + model_version_id=model_version_id, + save_dir=save_dir, + relative_path=data.get('relative_path', ''), 
+ progress_callback=progress_callback, + model_type=model_type + ) + + if not result.get('success', False): + error_message = result.get('error', 'Unknown error') + + # Return 401 for early access errors + if 'early access' in error_message.lower(): + logger.warning(f"Early access download failed: {error_message}") + return web.Response( + status=401, # Use 401 status code to match Civitai's response + text=f"Early Access Restriction: {error_message}" + ) + + return web.Response(status=500, text=error_message) + + return web.json_response(result) + + except Exception as e: + error_message = str(e) + + # Check if this might be an early access error + if '401' in error_message: + logger.warning(f"Early access error (401): {error_message}") + return web.Response( + status=401, + text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com." + ) + + logger.error(f"Error downloading {model_type}: {error_message}") + return web.Response(status=500, text=error_message) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index bb4cbc5b..b1dac0b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "beautifulsoup4", "piexif", "Pillow", + "olefile", # for getting rid of warning message "requests" ] diff --git a/requirements.txt b/requirements.txt index 88799022..c3c5cb67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ watchdog beautifulsoup4 piexif Pillow +olefile requests \ No newline at end of file diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js new file mode 100644 index 00000000..befc0471 --- /dev/null +++ b/static/js/api/baseModelApi.js @@ -0,0 +1,499 @@ +// filepath: d:\Workspace\ComfyUI\custom_nodes\ComfyUI-Lora-Manager\static\js\api\baseModelApi.js +import { state, getCurrentPageState } from '../state/index.js'; +import { showToast } from '../utils/uiHelpers.js'; +import { showDeleteModal, confirmDelete } from 
'../utils/modalUtils.js'; +import { getSessionItem } from '../utils/storageHelpers.js'; + +/** + * Shared functionality for handling models (loras and checkpoints) + */ + +// Generic function to load more models with pagination +export async function loadMoreModels(options = {}) { + const { + resetPage = false, + updateFolders = false, + modelType = 'lora', // 'lora' or 'checkpoint' + createCardFunction, + endpoint = '/api/loras' + } = options; + + const pageState = getCurrentPageState(); + + if (pageState.isLoading || (!pageState.hasMore && !resetPage)) return; + + pageState.isLoading = true; + document.body.classList.add('loading'); + + try { + // Reset to first page if requested + if (resetPage) { + pageState.currentPage = 1; + // Clear grid if resetting + const gridId = modelType === 'checkpoint' ? 'checkpointGrid' : 'loraGrid'; + const grid = document.getElementById(gridId); + if (grid) grid.innerHTML = ''; + } + + const params = new URLSearchParams({ + page: pageState.currentPage, + page_size: pageState.pageSize || 20, + sort_by: pageState.sortBy + }); + + if (pageState.activeFolder !== null) { + params.append('folder', pageState.activeFolder); + } + + // Add search parameters if there's a search term + if (pageState.filters?.search) { + params.append('search', pageState.filters.search); + params.append('fuzzy', 'true'); + + // Add search option parameters if available + if (pageState.searchOptions) { + params.append('search_filename', pageState.searchOptions.filename.toString()); + params.append('search_modelname', pageState.searchOptions.modelname.toString()); + if (pageState.searchOptions.tags !== undefined) { + params.append('search_tags', pageState.searchOptions.tags.toString()); + } + params.append('recursive', (pageState.searchOptions?.recursive ?? 
false).toString()); + } + } + + // Add filter parameters if active + if (pageState.filters) { + // Handle tags filters + if (pageState.filters.tags && pageState.filters.tags.length > 0) { + // Checkpoints API expects individual 'tag' parameters, Loras API expects comma-separated 'tags' + if (modelType === 'checkpoint') { + pageState.filters.tags.forEach(tag => { + params.append('tag', tag); + }); + } else { + params.append('tags', pageState.filters.tags.join(',')); + } + } + + // Handle base model filters + if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) { + if (modelType === 'checkpoint') { + pageState.filters.baseModel.forEach(model => { + params.append('base_model', model); + }); + } else { + params.append('base_models', pageState.filters.baseModel.join(',')); + } + } + } + + // Add model-specific parameters + if (modelType === 'lora') { + // Check for recipe-based filtering parameters from session storage + const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash'); + const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes'); + + // Add hash filter parameter if present + if (filterLoraHash) { + params.append('lora_hash', filterLoraHash); + } + // Add multiple hashes filter if present + else if (filterLoraHashes) { + try { + if (Array.isArray(filterLoraHashes) && filterLoraHashes.length > 0) { + params.append('lora_hashes', filterLoraHashes.join(',')); + } + } catch (error) { + console.error('Error parsing lora hashes from session storage:', error); + } + } + } + + const response = await fetch(`${endpoint}?${params}`); + if (!response.ok) { + throw new Error(`Failed to fetch models: ${response.statusText}`); + } + + const data = await response.json(); + + const gridId = modelType === 'checkpoint' ? 'checkpointGrid' : 'loraGrid'; + const grid = document.getElementById(gridId); + + if (data.items.length === 0 && pageState.currentPage === 1) { + grid.innerHTML = `
No ${modelType}s found in this folder
`; + pageState.hasMore = false; + } else if (data.items.length > 0) { + pageState.hasMore = pageState.currentPage < data.total_pages; + + // Append model cards using the provided card creation function + data.items.forEach(model => { + const card = createCardFunction(model); + grid.appendChild(card); + }); + + // Increment the page number AFTER successful loading + pageState.currentPage++; + } else { + pageState.hasMore = false; + } + + if (updateFolders && data.folders) { + updateFolderTags(data.folders); + } + + } catch (error) { + console.error(`Error loading ${modelType}s:`, error); + showToast(`Failed to load ${modelType}s: ${error.message}`, 'error'); + } finally { + pageState.isLoading = false; + document.body.classList.remove('loading'); + } +} + +// Update folder tags in the UI +export function updateFolderTags(folders) { + const folderTagsContainer = document.querySelector('.folder-tags'); + if (!folderTagsContainer) return; + + // Keep track of currently selected folder + const pageState = getCurrentPageState(); + const currentFolder = pageState.activeFolder; + + // Create HTML for folder tags + const tagsHTML = folders.map(folder => { + const isActive = folder === currentFolder; + return `
${folder}
`; + }).join(''); + + // Update the container + folderTagsContainer.innerHTML = tagsHTML; + + // Reattach click handlers and ensure the active tag is visible + const tags = folderTagsContainer.querySelectorAll('.tag'); + tags.forEach(tag => { + if (typeof toggleFolder === 'function') { + tag.addEventListener('click', toggleFolder); + } + if (tag.dataset.folder === currentFolder) { + tag.scrollIntoView({ behavior: 'smooth', block: 'nearest' }); + } + }); +} + +// Generic function to replace a model preview +export function replaceModelPreview(filePath, modelType = 'lora') { + // Open file picker + const input = document.createElement('input'); + input.type = 'file'; + input.accept ='image/*,video/mp4'; + + input.onchange = async function() { + if (!input.files || !input.files[0]) return; + + const file = input.files[0]; + await uploadPreview(filePath, file, modelType); + }; + + input.click(); +} + +// Delete a model (generic) +export function deleteModel(filePath, modelType = 'lora') { + if (modelType === 'checkpoint') { + confirmDelete('Are you sure you want to delete this checkpoint?', () => { + performDelete(filePath, modelType); + }); + } else { + showDeleteModal(filePath); + } +} + +// Reset and reload models +export async function resetAndReload(options = {}) { + const { + updateFolders = false, + modelType = 'lora', + loadMoreFunction + } = options; + + const pageState = getCurrentPageState(); + + // Reset pagination and load more models + if (typeof loadMoreFunction === 'function') { + await loadMoreFunction(true, updateFolders); + } +} + +// Generic function to refresh models +export async function refreshModels(options = {}) { + const { + modelType = 'lora', + scanEndpoint = '/api/loras/scan', + resetAndReloadFunction + } = options; + + try { + state.loadingManager.showSimpleLoading(`Refreshing ${modelType}s...`); + + const response = await fetch(scanEndpoint); + + if (!response.ok) { + throw new Error(`Failed to refresh ${modelType}s: ${response.status} 
${response.statusText}`); + } + + if (typeof resetAndReloadFunction === 'function') { + await resetAndReloadFunction(); + } + + showToast(`Refresh complete`, 'success'); + } catch (error) { + console.error(`Refresh failed:`, error); + showToast(`Failed to refresh ${modelType}s`, 'error'); + } finally { + state.loadingManager.hide(); + state.loadingManager.restoreProgressBar(); + } +} + +// Generic fetch from Civitai +export async function fetchCivitaiMetadata(options = {}) { + const { + modelType = 'lora', + fetchEndpoint = '/api/fetch-all-civitai', + resetAndReloadFunction + } = options; + + let ws = null; + + await state.loadingManager.showWithProgress(async (loading) => { + try { + const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; + ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); + + const operationComplete = new Promise((resolve, reject) => { + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + + switch(data.status) { + case 'started': + loading.setStatus('Starting metadata fetch...'); + break; + + case 'processing': + const percent = ((data.processed / data.total) * 100).toFixed(1); + loading.setProgress(percent); + loading.setStatus( + `Processing (${data.processed}/${data.total}) ${data.current_name}` + ); + break; + + case 'completed': + loading.setProgress(100); + loading.setStatus( + `Completed: Updated ${data.success} of ${data.processed} ${modelType}s` + ); + resolve(); + break; + + case 'error': + reject(new Error(data.error)); + break; + } + }; + + ws.onerror = (error) => { + reject(new Error('WebSocket error: ' + error.message)); + }; + }); + + await new Promise((resolve, reject) => { + ws.onopen = resolve; + ws.onerror = reject; + }); + + const requestBody = modelType === 'checkpoint' + ? 
JSON.stringify({ model_type: 'checkpoint' }) + : JSON.stringify({}); + + const response = await fetch(fetchEndpoint, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: requestBody + }); + + if (!response.ok) { + throw new Error('Failed to fetch metadata'); + } + + await operationComplete; + + if (typeof resetAndReloadFunction === 'function') { + await resetAndReloadFunction(); + } + + } catch (error) { + console.error('Error fetching metadata:', error); + showToast('Failed to fetch metadata: ' + error.message, 'error'); + } finally { + if (ws) { + ws.close(); + } + } + }, { + initialMessage: 'Connecting...', + completionMessage: 'Metadata update complete' + }); +} + +// Generic function to refresh single model metadata +export async function refreshSingleModelMetadata(filePath, modelType = 'lora') { + try { + state.loadingManager.showSimpleLoading('Refreshing metadata...'); + + const endpoint = modelType === 'checkpoint' + ? '/api/checkpoints/fetch-civitai' + : '/api/fetch-civitai'; + + const response = await fetch(endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ file_path: filePath }) + }); + + if (!response.ok) { + throw new Error('Failed to refresh metadata'); + } + + const data = await response.json(); + + if (data.success) { + showToast('Metadata refreshed successfully', 'success'); + return true; + } else { + throw new Error(data.error || 'Failed to refresh metadata'); + } + } catch (error) { + console.error('Error refreshing metadata:', error); + showToast(error.message, 'error'); + return false; + } finally { + state.loadingManager.hide(); + state.loadingManager.restoreProgressBar(); + } +} + +// Private methods + +// Upload a preview image +async function uploadPreview(filePath, file, modelType = 'lora') { + const loadingOverlay = document.getElementById('loading-overlay'); + const loadingStatus = document.querySelector('.loading-status'); + + try { + if 
(loadingOverlay) loadingOverlay.style.display = 'flex'; + if (loadingStatus) loadingStatus.textContent = 'Uploading preview...'; + + const formData = new FormData(); + + // Use appropriate parameter names and endpoint based on model type + // Prepare common form data + formData.append('preview_file', file); + formData.append('model_path', filePath); + + // Set endpoint based on model type + const endpoint = modelType === 'checkpoint' + ? '/api/checkpoints/replace-preview' + : '/api/replace_preview'; + + const response = await fetch(endpoint, { + method: 'POST', + body: formData + }); + + if (!response.ok) { + throw new Error('Upload failed'); + } + + const data = await response.json(); + + // Update the card preview in UI + const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); + if (card) { + const previewContainer = card.querySelector('.card-preview'); + const oldPreview = previewContainer.querySelector('img, video'); + + // For LoRA models, use timestamp to prevent caching + if (modelType === 'lora') { + state.previewVersions?.set(filePath, Date.now()); + } + + const timestamp = Date.now(); + const previewUrl = data.preview_url ? 
+ `${data.preview_url}?t=${timestamp}` : + `/api/model/preview_image?path=${encodeURIComponent(filePath)}&t=${timestamp}`; + + // Create appropriate element based on file type + if (file.type.startsWith('video/')) { + const video = document.createElement('video'); + video.controls = true; + video.autoplay = true; + video.muted = true; + video.loop = true; + video.src = previewUrl; + oldPreview.replaceWith(video); + } else { + const img = document.createElement('img'); + img.src = previewUrl; + oldPreview.replaceWith(img); + } + + showToast('Preview updated successfully', 'success'); + } + } catch (error) { + console.error('Error uploading preview:', error); + showToast('Failed to upload preview image', 'error'); + } finally { + if (loadingOverlay) loadingOverlay.style.display = 'none'; + } +} + +// Private function to perform the delete operation +async function performDelete(filePath, modelType = 'lora') { + try { + showToast(`Deleting ${modelType}...`, 'info'); + + const response = await fetch('/api/model/delete', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + file_path: filePath, + model_type: modelType + }) + }); + + if (!response.ok) { + throw new Error(`Failed to delete ${modelType}: ${response.status} ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success) { + // Remove the card from UI + const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); + if (card) { + card.remove(); + } + + showToast(`${modelType} deleted successfully`, 'success'); + } else { + throw new Error(data.error || `Failed to delete ${modelType}`); + } + } catch (error) { + console.error(`Error deleting ${modelType}:`, error); + showToast(`Failed to delete ${modelType}: ${error.message}`, 'error'); + } +} \ No newline at end of file diff --git a/static/js/api/checkpointApi.js b/static/js/api/checkpointApi.js new file mode 100644 index 00000000..8b243be9 --- /dev/null +++ 
b/static/js/api/checkpointApi.js @@ -0,0 +1,57 @@ +import { createCheckpointCard } from '../components/CheckpointCard.js'; +import { + loadMoreModels, + resetAndReload as baseResetAndReload, + refreshModels as baseRefreshModels, + deleteModel as baseDeleteModel, + replaceModelPreview, + fetchCivitaiMetadata +} from './baseModelApi.js'; + +// Load more checkpoints with pagination +export async function loadMoreCheckpoints(resetPagination = true) { + return loadMoreModels({ + resetPage: resetPagination, + updateFolders: true, + modelType: 'checkpoint', + createCardFunction: createCheckpointCard, + endpoint: '/api/checkpoints' + }); +} + +// Reset and reload checkpoints +export async function resetAndReload() { + return baseResetAndReload({ + updateFolders: true, + modelType: 'checkpoint', + loadMoreFunction: loadMoreCheckpoints + }); +} + +// Refresh checkpoints +export async function refreshCheckpoints() { + return baseRefreshModels({ + modelType: 'checkpoint', + scanEndpoint: '/api/checkpoints/scan', + resetAndReloadFunction: resetAndReload + }); +} + +// Delete a checkpoint +export function deleteCheckpoint(filePath) { + return baseDeleteModel(filePath, 'checkpoint'); +} + +// Replace checkpoint preview +export function replaceCheckpointPreview(filePath) { + return replaceModelPreview(filePath, 'checkpoint'); +} + +// Fetch metadata from Civitai for checkpoints +export async function fetchCivitai() { + return fetchCivitaiMetadata({ + modelType: 'checkpoint', + fetchEndpoint: '/api/checkpoints/fetch-all-civitai', + resetAndReloadFunction: resetAndReload + }); +} \ No newline at end of file diff --git a/static/js/api/loraApi.js b/static/js/api/loraApi.js index 8ebee93c..5e433799 100644 --- a/static/js/api/loraApi.js +++ b/static/js/api/loraApi.js @@ -1,292 +1,38 @@ -import { state, getCurrentPageState } from '../state/index.js'; -import { showToast } from '../utils/uiHelpers.js'; import { createLoraCard } from '../components/LoraCard.js'; -import { 
initializeInfiniteScroll } from '../utils/infiniteScroll.js'; -import { showDeleteModal } from '../utils/modalUtils.js'; -import { toggleFolder } from '../utils/uiHelpers.js'; -import { getSessionItem } from '../utils/storageHelpers.js'; +import { + loadMoreModels, + resetAndReload as baseResetAndReload, + refreshModels as baseRefreshModels, + deleteModel as baseDeleteModel, + replaceModelPreview, + fetchCivitaiMetadata, + refreshSingleModelMetadata +} from './baseModelApi.js'; export async function loadMoreLoras(resetPage = false, updateFolders = false) { - const pageState = getCurrentPageState(); - - if (pageState.isLoading || (!pageState.hasMore && !resetPage)) return; - - pageState.isLoading = true; - try { - // Reset to first page if requested - if (resetPage) { - pageState.currentPage = 1; - // Clear grid if resetting - const grid = document.getElementById('loraGrid'); - if (grid) grid.innerHTML = ''; - initializeInfiniteScroll(); - } - - const params = new URLSearchParams({ - page: pageState.currentPage, - page_size: 20, - sort_by: pageState.sortBy - }); - - if (pageState.activeFolder !== null) { - params.append('folder', pageState.activeFolder); - } - - // Add search parameters if there's a search term - if (pageState.filters?.search) { - params.append('search', pageState.filters.search); - params.append('fuzzy', 'true'); - - // Add search option parameters if available - if (pageState.searchOptions) { - params.append('search_filename', pageState.searchOptions.filename.toString()); - params.append('search_modelname', pageState.searchOptions.modelname.toString()); - params.append('search_tags', (pageState.searchOptions.tags || false).toString()); - params.append('recursive', (pageState.searchOptions?.recursive ?? 
false).toString()); - } - } - - // Add filter parameters if active - if (pageState.filters) { - if (pageState.filters.tags && pageState.filters.tags.length > 0) { - // Convert the array of tags to a comma-separated string - params.append('tags', pageState.filters.tags.join(',')); - } - if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) { - // Convert the array of base models to a comma-separated string - params.append('base_models', pageState.filters.baseModel.join(',')); - } - } - - // Check for recipe-based filtering parameters from session storage - const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash'); - const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes'); - - console.log('Filter Lora Hash:', filterLoraHash); - console.log('Filter Lora Hashes:', filterLoraHashes); - - // Add hash filter parameter if present - if (filterLoraHash) { - params.append('lora_hash', filterLoraHash); - } - // Add multiple hashes filter if present - else if (filterLoraHashes) { - try { - if (Array.isArray(filterLoraHashes) && filterLoraHashes.length > 0) { - params.append('lora_hashes', filterLoraHashes.join(',')); - } - } catch (error) { - console.error('Error parsing lora hashes from session storage:', error); - } - } - - const response = await fetch(`/api/loras?${params}`); - if (!response.ok) { - throw new Error(`Failed to fetch loras: ${response.statusText}`); - } - - const data = await response.json(); - - if (data.items.length === 0 && pageState.currentPage === 1) { - const grid = document.getElementById('loraGrid'); - grid.innerHTML = '
No loras found in this folder
'; - pageState.hasMore = false; - } else if (data.items.length > 0) { - pageState.hasMore = pageState.currentPage < data.total_pages; - pageState.currentPage++; - appendLoraCards(data.items); - - const sentinel = document.getElementById('scroll-sentinel'); - if (sentinel && state.observer) { - state.observer.observe(sentinel); - } - } else { - pageState.hasMore = false; - } - - if (updateFolders && data.folders) { - updateFolderTags(data.folders); - } - - } catch (error) { - console.error('Error loading loras:', error); - showToast('Failed to load loras: ' + error.message, 'error'); - } finally { - pageState.isLoading = false; - } -} - -function updateFolderTags(folders) { - const folderTagsContainer = document.querySelector('.folder-tags'); - if (!folderTagsContainer) return; - - // Keep track of currently selected folder - const pageState = getCurrentPageState(); - const currentFolder = pageState.activeFolder; - - // Create HTML for folder tags - const tagsHTML = folders.map(folder => { - const isActive = folder === currentFolder; - return `
${folder}
`; - }).join(''); - - // Update the container - folderTagsContainer.innerHTML = tagsHTML; - - // Reattach click handlers and ensure the active tag is visible - const tags = folderTagsContainer.querySelectorAll('.tag'); - tags.forEach(tag => { - tag.addEventListener('click', toggleFolder); - if (tag.dataset.folder === currentFolder) { - tag.scrollIntoView({ behavior: 'smooth', block: 'nearest' }); - } + return loadMoreModels({ + resetPage, + updateFolders, + modelType: 'lora', + createCardFunction: createLoraCard, + endpoint: '/api/loras' }); } export async function fetchCivitai() { - let ws = null; - - await state.loadingManager.showWithProgress(async (loading) => { - try { - const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; - const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); - - const operationComplete = new Promise((resolve, reject) => { - ws.onmessage = (event) => { - const data = JSON.parse(event.data); - - switch(data.status) { - case 'started': - loading.setStatus('Starting metadata fetch...'); - break; - - case 'processing': - const percent = ((data.processed / data.total) * 100).toFixed(1); - loading.setProgress(percent); - loading.setStatus( - `Processing (${data.processed}/${data.total}) ${data.current_name}` - ); - break; - - case 'completed': - loading.setProgress(100); - loading.setStatus( - `Completed: Updated ${data.success} of ${data.processed} loras` - ); - resolve(); - break; - - case 'error': - reject(new Error(data.error)); - break; - } - }; - - ws.onerror = (error) => { - reject(new Error('WebSocket error: ' + error.message)); - }; - }); - - await new Promise((resolve, reject) => { - ws.onopen = resolve; - ws.onerror = reject; - }); - - const response = await fetch('/api/fetch-all-civitai', { - method: 'POST', - headers: { 'Content-Type': 'application/json' } - }); - - if (!response.ok) { - throw new Error('Failed to fetch metadata'); - } - - await operationComplete; - - await 
resetAndReload(); - - } catch (error) { - console.error('Error fetching metadata:', error); - showToast('Failed to fetch metadata: ' + error.message, 'error'); - } finally { - if (ws) { - ws.close(); - } - } - }, { - initialMessage: 'Connecting...', - completionMessage: 'Metadata update complete' + return fetchCivitaiMetadata({ + modelType: 'lora', + fetchEndpoint: '/api/fetch-all-civitai', + resetAndReloadFunction: resetAndReload }); } export async function deleteModel(filePath) { - showDeleteModal(filePath); + return baseDeleteModel(filePath, 'lora'); } export async function replacePreview(filePath) { - const loadingOverlay = document.getElementById('loading-overlay'); - const loadingStatus = document.querySelector('.loading-status'); - - const input = document.createElement('input'); - input.type = 'file'; - input.accept = 'image/*,video/mp4'; - - input.onchange = async function() { - if (!input.files || !input.files[0]) return; - - const file = input.files[0]; - const formData = new FormData(); - formData.append('preview_file', file); - formData.append('model_path', filePath); - - try { - loadingOverlay.style.display = 'flex'; - loadingStatus.textContent = 'Uploading preview...'; - - const response = await fetch('/api/replace_preview', { - method: 'POST', - body: formData - }); - - if (!response.ok) { - throw new Error('Upload failed'); - } - - const data = await response.json(); - - // 更新预览版本 - state.previewVersions.set(filePath, Date.now()); - - // 更新卡片显示 - const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - const previewContainer = card.querySelector('.card-preview'); - const oldPreview = previewContainer.querySelector('img, video'); - - const previewUrl = `${data.preview_url}?t=${state.previewVersions.get(filePath)}`; - - if (file.type.startsWith('video/')) { - const video = document.createElement('video'); - video.controls = true; - video.autoplay = true; - video.muted = true; - video.loop = true; - video.src = previewUrl; - 
oldPreview.replaceWith(video); - } else { - const img = document.createElement('img'); - img.src = previewUrl; - oldPreview.replaceWith(img); - } - - } catch (error) { - console.error('Error uploading preview:', error); - alert('Failed to upload preview image'); - } finally { - loadingOverlay.style.display = 'none'; - } - }; - - input.click(); + return replaceModelPreview(filePath, 'lora'); } export function appendLoraCards(loras) { @@ -300,60 +46,26 @@ export function appendLoraCards(loras) { } export async function resetAndReload(updateFolders = false) { - const pageState = getCurrentPageState(); - console.log('Resetting with state:', { ...pageState }); - - // Initialize infinite scroll - will reset the observer - initializeInfiniteScroll(); - - // Load more loras with reset flag - await loadMoreLoras(true, updateFolders); + return baseResetAndReload({ + updateFolders, + modelType: 'lora', + loadMoreFunction: loadMoreLoras + }); } export async function refreshLoras() { - try { - state.loadingManager.showSimpleLoading('Refreshing loras...'); - await resetAndReload(); - showToast('Refresh complete', 'success'); - } catch (error) { - console.error('Refresh failed:', error); - showToast('Failed to refresh loras', 'error'); - } finally { - state.loadingManager.hide(); - state.loadingManager.restoreProgressBar(); - } + return baseRefreshModels({ + modelType: 'lora', + scanEndpoint: '/api/loras/scan', + resetAndReloadFunction: resetAndReload + }); } export async function refreshSingleLoraMetadata(filePath) { - try { - state.loadingManager.showSimpleLoading('Refreshing metadata...'); - const response = await fetch('/api/fetch-civitai', { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ file_path: filePath }) - }); - - if (!response.ok) { - throw new Error('Failed to refresh metadata'); - } - - const data = await response.json(); - - if (data.success) { - showToast('Metadata refreshed successfully', 'success'); - // Reload 
the current view to show updated data - await resetAndReload(); - } else { - throw new Error(data.error || 'Failed to refresh metadata'); - } - } catch (error) { - console.error('Error refreshing metadata:', error); - showToast(error.message, 'error'); - } finally { - state.loadingManager.hide(); - state.loadingManager.restoreProgressBar(); + const success = await refreshSingleModelMetadata(filePath, 'lora'); + if (success) { + // Reload the current view to show updated data + await resetAndReload(); } } diff --git a/static/js/checkpoints.js b/static/js/checkpoints.js index ea149a2f..2f1d316f 100644 --- a/static/js/checkpoints.js +++ b/static/js/checkpoints.js @@ -1,36 +1,55 @@ import { appCore } from './core.js'; -import { state, initPageState } from './state/index.js'; +import { initializeInfiniteScroll } from './utils/infiniteScroll.js'; +import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; +import { createPageControls } from './components/controls/index.js'; +import { loadMoreCheckpoints } from './api/checkpointApi.js'; +import { CheckpointDownloadManager } from './managers/CheckpointDownloadManager.js'; // Initialize the Checkpoints page class CheckpointsPageManager { constructor() { - // Initialize any necessary state - this.initialized = false; + // Initialize page controls + this.pageControls = createPageControls('checkpoints'); + + // Initialize checkpoint download manager + window.checkpointDownloadManager = new CheckpointDownloadManager(); + + // Expose only necessary functions to global scope + this._exposeRequiredGlobalFunctions(); + } + + _exposeRequiredGlobalFunctions() { + // Minimal set of functions that need to remain global + window.confirmDelete = confirmDelete; + window.closeDeleteModal = closeDeleteModal; + + // Add loadCheckpoints function to window for FilterManager compatibility + window.checkpointManager = { + loadCheckpoints: (reset) => loadMoreCheckpoints(reset) + }; } async initialize() { - if (this.initialized) 
return; - - // Initialize page state - initPageState('checkpoints'); - - // Initialize core application - await appCore.initialize(); - // Initialize page-specific components - this._initializeWorkInProgress(); + this.pageControls.restoreFolderFilter(); + this.pageControls.initFolderTagsVisibility(); - this.initialized = true; - } - - _initializeWorkInProgress() { - // Add any work-in-progress specific initialization here - console.log('Checkpoints Manager is under development'); + // Initialize infinite scroll + initializeInfiniteScroll('checkpoints'); + + // Initialize common page features + appCore.initializePageFeatures(); + + console.log('Checkpoints Manager initialized'); } } // Initialize everything when DOM is ready document.addEventListener('DOMContentLoaded', async () => { + // Initialize core application + await appCore.initialize(); + + // Initialize checkpoints page const checkpointsPage = new CheckpointsPageManager(); await checkpointsPage.initialize(); }); diff --git a/static/js/components/CheckpointCard.js b/static/js/components/CheckpointCard.js new file mode 100644 index 00000000..0ff1271d --- /dev/null +++ b/static/js/components/CheckpointCard.js @@ -0,0 +1,313 @@ +import { showToast } from '../utils/uiHelpers.js'; +import { state } from '../state/index.js'; +import { showCheckpointModal } from './checkpointModal/index.js'; +import { NSFW_LEVELS } from '../utils/constants.js'; +import { replaceCheckpointPreview as apiReplaceCheckpointPreview } from '../api/checkpointApi.js'; + +export function createCheckpointCard(checkpoint) { + const card = document.createElement('div'); + card.className = 'lora-card'; // Reuse the same class for styling + card.dataset.sha256 = checkpoint.sha256; + card.dataset.filepath = checkpoint.file_path; + card.dataset.name = checkpoint.model_name; + card.dataset.file_name = checkpoint.file_name; + card.dataset.folder = checkpoint.folder; + card.dataset.modified = checkpoint.modified; + card.dataset.file_size = 
checkpoint.file_size; + card.dataset.from_civitai = checkpoint.from_civitai; + card.dataset.notes = checkpoint.notes || ''; + card.dataset.base_model = checkpoint.base_model || 'Unknown'; + + // Store metadata if available + if (checkpoint.civitai) { + card.dataset.meta = JSON.stringify(checkpoint.civitai || {}); + } + + // Store tags if available + if (checkpoint.tags && Array.isArray(checkpoint.tags)) { + card.dataset.tags = JSON.stringify(checkpoint.tags); + } + + if (checkpoint.modelDescription) { + card.dataset.modelDescription = checkpoint.modelDescription; + } + + // Store NSFW level if available + const nsfwLevel = checkpoint.preview_nsfw_level !== undefined ? checkpoint.preview_nsfw_level : 0; + card.dataset.nsfwLevel = nsfwLevel; + + // Determine if the preview should be blurred based on NSFW level and user settings + const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13; + if (shouldBlur) { + card.classList.add('nsfw-content'); + } + + // Determine preview URL + const previewUrl = checkpoint.preview_url || '/loras_static/images/no-preview.png'; + const version = state.previewVersions ? state.previewVersions.get(checkpoint.file_path) : null; + const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl; + + // Determine NSFW warning text based on level + let nsfwText = "Mature Content"; + if (nsfwLevel >= NSFW_LEVELS.XXX) { + nsfwText = "XXX-rated Content"; + } else if (nsfwLevel >= NSFW_LEVELS.X) { + nsfwText = "X-rated Content"; + } else if (nsfwLevel >= NSFW_LEVELS.R) { + nsfwText = "R-rated Content"; + } + + // Check if autoplayOnHover is enabled for video previews + const autoplayOnHover = state.global?.settings?.autoplayOnHover || false; + const isVideo = previewUrl.endsWith('.mp4'); + const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop'; + + card.innerHTML = ` +
+ ${isVideo ? + `` : + `${checkpoint.model_name}` + } +
+ ${shouldBlur ? + `` : ''} + + ${checkpoint.base_model} + +
+ + + + + + +
+
+ ${shouldBlur ? ` +
+
+

${nsfwText}

+ +
+
+ ` : ''} + +
+ `; + + // Main card click event + card.addEventListener('click', () => { + // Show checkpoint details modal + const checkpointMeta = { + sha256: card.dataset.sha256, + file_path: card.dataset.filepath, + model_name: card.dataset.name, + file_name: card.dataset.file_name, + folder: card.dataset.folder, + modified: card.dataset.modified, + file_size: parseInt(card.dataset.file_size || '0'), + from_civitai: card.dataset.from_civitai === 'true', + base_model: card.dataset.base_model, + notes: card.dataset.notes || '', + preview_url: versionedPreviewUrl, + // Parse civitai metadata from the card's dataset + civitai: (() => { + try { + return JSON.parse(card.dataset.meta || '{}'); + } catch (e) { + console.error('Failed to parse civitai metadata:', e); + return {}; // Return empty object on error + } + })(), + tags: (() => { + try { + return JSON.parse(card.dataset.tags || '[]'); + } catch (e) { + console.error('Failed to parse tags:', e); + return []; // Return empty array on error + } + })(), + modelDescription: card.dataset.modelDescription || '' + }; + showCheckpointModal(checkpointMeta); + }); + + // Toggle blur button functionality + const toggleBlurBtn = card.querySelector('.toggle-blur-btn'); + if (toggleBlurBtn) { + toggleBlurBtn.addEventListener('click', (e) => { + e.stopPropagation(); + const preview = card.querySelector('.card-preview'); + const isBlurred = preview.classList.toggle('blurred'); + const icon = toggleBlurBtn.querySelector('i'); + + // Update the icon based on blur state + if (isBlurred) { + icon.className = 'fas fa-eye'; + } else { + icon.className = 'fas fa-eye-slash'; + } + + // Toggle the overlay visibility + const overlay = card.querySelector('.nsfw-overlay'); + if (overlay) { + overlay.style.display = isBlurred ? 
'flex' : 'none'; + } + }); + } + + // Show content button functionality + const showContentBtn = card.querySelector('.show-content-btn'); + if (showContentBtn) { + showContentBtn.addEventListener('click', (e) => { + e.stopPropagation(); + const preview = card.querySelector('.card-preview'); + preview.classList.remove('blurred'); + + // Update the toggle button icon + const toggleBtn = card.querySelector('.toggle-blur-btn'); + if (toggleBtn) { + toggleBtn.querySelector('i').className = 'fas fa-eye-slash'; + } + + // Hide the overlay + const overlay = card.querySelector('.nsfw-overlay'); + if (overlay) { + overlay.style.display = 'none'; + } + }); + } + + // Copy button click event + card.querySelector('.fa-copy')?.addEventListener('click', async e => { + e.stopPropagation(); + const checkpointName = card.dataset.file_name; + + try { + // Modern clipboard API + if (navigator.clipboard && window.isSecureContext) { + await navigator.clipboard.writeText(checkpointName); + } else { + // Fallback for older browsers + const textarea = document.createElement('textarea'); + textarea.value = checkpointName; + textarea.style.position = 'absolute'; + textarea.style.left = '-99999px'; + document.body.appendChild(textarea); + textarea.select(); + document.execCommand('copy'); + document.body.removeChild(textarea); + } + showToast('Checkpoint name copied', 'success'); + } catch (err) { + console.error('Copy failed:', err); + showToast('Copy failed', 'error'); + } + }); + + // Civitai button click event + if (checkpoint.from_civitai) { + card.querySelector('.fa-globe')?.addEventListener('click', e => { + e.stopPropagation(); + openCivitai(checkpoint.model_name); + }); + } + + // Delete button click event + card.querySelector('.fa-trash')?.addEventListener('click', e => { + e.stopPropagation(); + deleteCheckpoint(checkpoint.file_path); + }); + + // Replace preview button click event + card.querySelector('.fa-image')?.addEventListener('click', e => { + e.stopPropagation(); + 
replaceCheckpointPreview(checkpoint.file_path); + }); + + // Add autoplayOnHover handlers for video elements if needed + const videoElement = card.querySelector('video'); + if (videoElement && autoplayOnHover) { + const cardPreview = card.querySelector('.card-preview'); + + // Remove autoplay attribute and pause initially + videoElement.removeAttribute('autoplay'); + videoElement.pause(); + + // Add mouse events to trigger play/pause + cardPreview.addEventListener('mouseenter', () => { + videoElement.play(); + }); + + cardPreview.addEventListener('mouseleave', () => { + videoElement.pause(); + videoElement.currentTime = 0; + }); + } + + return card; +} + +// These functions will be implemented in checkpointApi.js +function openCivitai(modelName) { + // Check if the global function exists (registered by PageControls) + if (window.openCivitai) { + window.openCivitai(modelName); + } else { + // Fallback implementation + const card = document.querySelector(`.lora-card[data-name="${modelName}"]`); + if (!card) return; + + const metaData = JSON.parse(card.dataset.meta || '{}'); + const civitaiId = metaData.modelId; + const versionId = metaData.id; + + // Build URL + if (civitaiId) { + let url = `https://civitai.com/models/${civitaiId}`; + if (versionId) { + url += `?modelVersionId=${versionId}`; + } + window.open(url, '_blank'); + } else { + // If no ID, try searching by name + window.open(`https://civitai.com/models?query=${encodeURIComponent(modelName)}`, '_blank'); + } + } +} + +function deleteCheckpoint(filePath) { + if (window.deleteCheckpoint) { + window.deleteCheckpoint(filePath); + } else { + // Use the modal delete functionality + import('../utils/modalUtils.js').then(({ showDeleteModal }) => { + showDeleteModal(filePath, 'checkpoint'); + }); + } +} + +function replaceCheckpointPreview(filePath) { + if (window.replaceCheckpointPreview) { + window.replaceCheckpointPreview(filePath); + } else { + apiReplaceCheckpointPreview(filePath); + } +} \ No newline at end of 
file diff --git a/static/js/components/ContextMenu.js b/static/js/components/ContextMenu.js index 77380b1d..cce09f61 100644 --- a/static/js/components/ContextMenu.js +++ b/static/js/components/ContextMenu.js @@ -130,7 +130,7 @@ export class LoraContextMenu { } async saveModelMetadata(filePath, data) { - const response = await fetch('/loras/api/save-metadata', { + const response = await fetch('/api/loras/save-metadata', { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/static/js/components/LoraCard.js b/static/js/components/LoraCard.js index c4baee23..86737b59 100644 --- a/static/js/components/LoraCard.js +++ b/static/js/components/LoraCard.js @@ -1,8 +1,9 @@ -import { showToast } from '../utils/uiHelpers.js'; +import { showToast, openCivitai } from '../utils/uiHelpers.js'; import { state } from '../state/index.js'; import { showLoraModal } from './loraModal/index.js'; import { bulkManager } from '../managers/BulkManager.js'; import { NSFW_LEVELS } from '../utils/constants.js'; +import { replacePreview, deleteModel } from '../api/loraApi.js' export function createLoraCard(lora) { const card = document.createElement('div'); diff --git a/static/js/components/LoraModal.js b/static/js/components/LoraModal.js deleted file mode 100644 index 186158b4..00000000 --- a/static/js/components/LoraModal.js +++ /dev/null @@ -1,1733 +0,0 @@ -import { showToast } from '../utils/uiHelpers.js'; -import { state } from '../state/index.js'; -import { modalManager } from '../managers/ModalManager.js'; -import { NSFW_LEVELS, BASE_MODELS } from '../utils/constants.js'; - -export function showLoraModal(lora) { - const escapedWords = lora.civitai?.trainedWords?.length ? 
- lora.civitai.trainedWords.map(word => word.replace(/'/g, '\\\'')) : []; - - const content = ` - - `; - - modalManager.showModal('loraModal', content); - setupEditableFields(); - setupShowcaseScroll(); - setupTabSwitching(); - setupTagTooltip(); - setupTriggerWordsEditMode(); - setupModelNameEditing(); - setupBaseModelEditing(); - setupFileNameEditing(); - - // If we have a model ID but no description, fetch it - if (lora.civitai?.modelId && !lora.modelDescription) { - loadModelDescription(lora.civitai.modelId, lora.file_path); - } -} - -// Function to render showcase content -function renderShowcaseContent(images) { - if (!images?.length) return '
No example images available
'; - - // Filter images based on SFW setting - const showOnlySFW = state.settings.show_only_sfw; - let filteredImages = images; - let hiddenCount = 0; - - if (showOnlySFW) { - filteredImages = images.filter(img => { - const nsfwLevel = img.nsfwLevel !== undefined ? img.nsfwLevel : 0; - const isSfw = nsfwLevel < NSFW_LEVELS.R; - if (!isSfw) hiddenCount++; - return isSfw; - }); - } - - // Show message if no images are available after filtering - if (filteredImages.length === 0) { - return ` -
-

All example images are filtered due to NSFW content settings

-

Your settings are currently set to show only safe-for-work content

-

You can change this in Settings

-
- `; - } - - // Show hidden content notification if applicable - const hiddenNotification = hiddenCount > 0 ? - `
- ${hiddenCount} ${hiddenCount === 1 ? 'image' : 'images'} hidden due to SFW-only setting -
` : ''; - - return ` -
- - Scroll or click to show ${filteredImages.length} examples -
- - `; -} - -// Helper function to generate video wrapper HTML -function generateVideoWrapper(img, heightPercent, shouldBlur, nsfwText, metadataPanel) { - return ` -
- ${shouldBlur ? ` - - ` : ''} - - ${shouldBlur ? ` -
-
-

${nsfwText}

- -
-
- ` : ''} - ${metadataPanel} -
- `; -} - -// Helper function to generate image wrapper HTML -function generateImageWrapper(img, heightPercent, shouldBlur, nsfwText, metadataPanel) { - return ` -
- ${shouldBlur ? ` - - ` : ''} - Preview - ${shouldBlur ? ` -
-
-

${nsfwText}

- -
-
- ` : ''} - ${metadataPanel} -
- `; -} - -// New function to handle tab switching -function setupTabSwitching() { - const tabButtons = document.querySelectorAll('.showcase-tabs .tab-btn'); - - tabButtons.forEach(button => { - button.addEventListener('click', () => { - // Remove active class from all tabs - document.querySelectorAll('.showcase-tabs .tab-btn').forEach(btn => - btn.classList.remove('active') - ); - document.querySelectorAll('.tab-content .tab-pane').forEach(tab => - tab.classList.remove('active') - ); - - // Add active class to clicked tab - button.classList.add('active'); - const tabId = `${button.dataset.tab}-tab`; - document.getElementById(tabId).classList.add('active'); - - // If switching to description tab, make sure content is properly sized - if (button.dataset.tab === 'description') { - const descriptionContent = document.querySelector('.model-description-content'); - if (descriptionContent) { - const hasContent = descriptionContent.innerHTML.trim() !== ''; - document.querySelector('.model-description-loading')?.classList.add('hidden'); - - // If no content, show a message - if (!hasContent) { - descriptionContent.innerHTML = '
No model description available
'; - descriptionContent.classList.remove('hidden'); - } - } - } - }); - }); -} - -// New function to load model description -async function loadModelDescription(modelId, filePath) { - try { - const descriptionContainer = document.querySelector('.model-description-content'); - const loadingElement = document.querySelector('.model-description-loading'); - - if (!descriptionContainer || !loadingElement) return; - - // Show loading indicator - loadingElement.classList.remove('hidden'); - descriptionContainer.classList.add('hidden'); - - // Try to get model description from API - const response = await fetch(`/api/lora-model-description?model_id=${modelId}&file_path=${encodeURIComponent(filePath)}`); - - if (!response.ok) { - throw new Error(`Failed to fetch model description: ${response.statusText}`); - } - - const data = await response.json(); - - if (data.success && data.description) { - // Update the description content - descriptionContainer.innerHTML = data.description; - - // Process any links in the description to open in new tab - const links = descriptionContainer.querySelectorAll('a'); - links.forEach(link => { - link.setAttribute('target', '_blank'); - link.setAttribute('rel', 'noopener noreferrer'); - }); - - // Show the description and hide loading indicator - descriptionContainer.classList.remove('hidden'); - loadingElement.classList.add('hidden'); - } else { - throw new Error(data.error || 'No description available'); - } - } catch (error) { - console.error('Error loading model description:', error); - const loadingElement = document.querySelector('.model-description-loading'); - if (loadingElement) { - loadingElement.innerHTML = `
Failed to load model description. ${error.message}
`; - } - - // Show empty state message in the description container - const descriptionContainer = document.querySelector('.model-description-content'); - if (descriptionContainer) { - descriptionContainer.innerHTML = '
No model description available
'; - descriptionContainer.classList.remove('hidden'); - } - } -} - -// 添加复制文件名的函数 -window.copyFileName = async function(fileName) { - try { - await navigator.clipboard.writeText(fileName); - showToast('File name copied', 'success'); - } catch (err) { - console.error('Copy failed:', err); - showToast('Copy failed', 'error'); - } -}; - -// Add function to save model name -window.saveModelName = async function(filePath) { - const modelNameElement = document.querySelector('.model-name-content'); - const newModelName = modelNameElement.textContent.trim(); - - // Validate model name - if (!newModelName) { - showToast('Model name cannot be empty', 'error'); - return; - } - - // Check if model name is too long (limit to 100 characters) - if (newModelName.length > 100) { - showToast('Model name is too long (maximum 100 characters)', 'error'); - // Truncate the displayed text - modelNameElement.textContent = newModelName.substring(0, 100); - return; - } - - try { - await saveModelMetadata(filePath, { model_name: newModelName }); - - // Update the corresponding lora card's dataset and display - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (loraCard) { - loraCard.dataset.model_name = newModelName; - const titleElement = loraCard.querySelector('.card-title'); - if (titleElement) { - titleElement.textContent = newModelName; - } - } - - showToast('Model name updated successfully', 'success'); - - // Reload the page to reflect the sorted order - setTimeout(() => { - window.location.reload(); - }, 1500); - } catch (error) { - showToast('Failed to update model name', 'error'); - } -}; - -function setupEditableFields() { - const editableFields = document.querySelectorAll('.editable-field [contenteditable]'); - - editableFields.forEach(field => { - field.addEventListener('focus', function() { - if (this.textContent === 'Add your notes here...') { - this.textContent = ''; - } - }); - - field.addEventListener('blur', function() { - if 
(this.textContent.trim() === '') { - if (this.classList.contains('notes-content')) { - this.textContent = 'Add your notes here...'; - } - } - }); - }); - - const presetSelector = document.getElementById('preset-selector'); - const presetValue = document.getElementById('preset-value'); - const addPresetBtn = document.querySelector('.add-preset-btn'); - const presetTags = document.querySelector('.preset-tags'); - - presetSelector.addEventListener('change', function() { - const selected = this.value; - if (selected) { - presetValue.style.display = 'inline-block'; - presetValue.min = selected.includes('strength') ? -10 : 0; - presetValue.max = selected.includes('strength') ? 10 : 10; - presetValue.step = 0.5; - if (selected === 'clip_skip') { - presetValue.type = 'number'; - presetValue.step = 1; - } - // Add auto-focus - setTimeout(() => presetValue.focus(), 0); - } else { - presetValue.style.display = 'none'; - } - }); - - addPresetBtn.addEventListener('click', async function() { - const key = presetSelector.value; - const value = presetValue.value; - - if (!key || !value) return; - - const filePath = document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#loraModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; - - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - const currentPresets = parsePresets(loraCard.dataset.usage_tips); - - currentPresets[key] = parseFloat(value); - const newPresetsJson = JSON.stringify(currentPresets); - - await saveModelMetadata(filePath, { - usage_tips: newPresetsJson - }); - - loraCard.dataset.usage_tips = newPresetsJson; - presetTags.innerHTML = renderPresetTags(currentPresets); - - presetSelector.value = ''; - presetValue.value = ''; - presetValue.style.display = 'none'; - }); - - // Add keydown event listeners for notes - const notesContent = document.querySelector('.notes-content'); - if 
(notesContent) { - notesContent.addEventListener('keydown', async function(e) { - if (e.key === 'Enter') { - if (e.shiftKey) { - // Allow shift+enter for new line - return; - } - e.preventDefault(); - const filePath = document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#loraModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; - await saveNotes(filePath); - } - }); - } - - // Add keydown event for preset value - presetValue.addEventListener('keydown', function(e) { - if (e.key === 'Enter') { - e.preventDefault(); - addPresetBtn.click(); - } - }); -} - -window.saveNotes = async function(filePath) { - const content = document.querySelector('.notes-content').textContent; - try { - await saveModelMetadata(filePath, { notes: content }); - - // Update the corresponding lora card's dataset - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (loraCard) { - loraCard.dataset.notes = content; - } - - showToast('Notes saved successfully', 'success'); - } catch (error) { - showToast('Failed to save notes', 'error'); - } -}; - -async function saveModelMetadata(filePath, data) { - const response = await fetch('/loras/api/save-metadata', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - file_path: filePath, - ...data - }) - }); - - if (!response.ok) { - throw new Error('Failed to save metadata'); - } -} - -function renderTriggerWords(words, filePath) { - if (!words.length) return ` -
-
- - -
-
- No trigger word needed - -
- - -
- `; - - return ` -
-
- - -
-
-
- ${words.map(word => ` -
- ${word} - - - - -
- `).join('')} -
-
- - -
- `; -} - -export function toggleShowcase(element) { - const carousel = element.nextElementSibling; - const isCollapsed = carousel.classList.contains('collapsed'); - const indicator = element.querySelector('span'); - const icon = element.querySelector('i'); - - carousel.classList.toggle('collapsed'); - - if (isCollapsed) { - const count = carousel.querySelectorAll('.media-wrapper').length; - indicator.textContent = `Scroll or click to hide examples`; - icon.classList.replace('fa-chevron-down', 'fa-chevron-up'); - initLazyLoading(carousel); - - // Initialize NSFW content blur toggle handlers - initNsfwBlurHandlers(carousel); - - // Initialize metadata panel interaction handlers - initMetadataPanelHandlers(carousel); - } else { - const count = carousel.querySelectorAll('.media-wrapper').length; - indicator.textContent = `Scroll or click to show ${count} examples`; - icon.classList.replace('fa-chevron-up', 'fa-chevron-down'); - - // Make sure any open metadata panels get closed - const carouselContainer = carousel.querySelector('.carousel-container'); - if (carouselContainer) { - carouselContainer.style.height = '0'; - setTimeout(() => { - carouselContainer.style.height = ''; - }, 300); - } - } -} - -// Function to initialize metadata panel interactions -function initMetadataPanelHandlers(container) { - // Find all media wrappers - const mediaWrappers = container.querySelectorAll('.media-wrapper'); - - mediaWrappers.forEach(wrapper => { - // Get the metadata panel - const metadataPanel = wrapper.querySelector('.image-metadata-panel'); - if (!metadataPanel) return; - - // Prevent events from the metadata panel from bubbling - metadataPanel.addEventListener('click', (e) => { - e.stopPropagation(); - }); - - // Handle copy prompt button clicks - const copyBtns = metadataPanel.querySelectorAll('.copy-prompt-btn'); - copyBtns.forEach(copyBtn => { - const promptIndex = copyBtn.dataset.promptIndex; - const promptElement = wrapper.querySelector(`#prompt-${promptIndex}`); - - 
copyBtn.addEventListener('click', async (e) => { - e.stopPropagation(); // Prevent bubbling - - if (!promptElement) return; - - try { - await navigator.clipboard.writeText(promptElement.textContent); - showToast('Prompt copied to clipboard', 'success'); - } catch (err) { - console.error('Copy failed:', err); - showToast('Copy failed', 'error'); - } - }); - }); - - // Prevent scrolling in the metadata panel from scrolling the whole modal - metadataPanel.addEventListener('wheel', (e) => { - const isAtTop = metadataPanel.scrollTop === 0; - const isAtBottom = metadataPanel.scrollHeight - metadataPanel.scrollTop === metadataPanel.clientHeight; - - // Only prevent default if scrolling would cause the panel to scroll - if ((e.deltaY < 0 && !isAtTop) || (e.deltaY > 0 && !isAtBottom)) { - e.stopPropagation(); - } - }, { passive: true }); - }); -} - -// New function to initialize blur toggle handlers for showcase images/videos -function initNsfwBlurHandlers(container) { - // Handle toggle blur buttons - const toggleButtons = container.querySelectorAll('.toggle-blur-btn'); - toggleButtons.forEach(btn => { - btn.addEventListener('click', (e) => { - e.stopPropagation(); - const wrapper = btn.closest('.media-wrapper'); - const media = wrapper.querySelector('img, video'); - const isBlurred = media.classList.toggle('blurred'); - const icon = btn.querySelector('i'); - - // Update the icon based on blur state - if (isBlurred) { - icon.className = 'fas fa-eye'; - } else { - icon.className = 'fas fa-eye-slash'; - } - - // Toggle the overlay visibility - const overlay = wrapper.querySelector('.nsfw-overlay'); - if (overlay) { - overlay.style.display = isBlurred ? 
'flex' : 'none'; - } - }); - }); - - // Handle "Show" buttons in overlays - const showButtons = container.querySelectorAll('.show-content-btn'); - showButtons.forEach(btn => { - btn.addEventListener('click', (e) => { - e.stopPropagation(); - const wrapper = btn.closest('.media-wrapper'); - const media = wrapper.querySelector('img, video'); - media.classList.remove('blurred'); - - // Update the toggle button icon - const toggleBtn = wrapper.querySelector('.toggle-blur-btn'); - if (toggleBtn) { - toggleBtn.querySelector('i').className = 'fas fa-eye-slash'; - } - - // Hide the overlay - const overlay = wrapper.querySelector('.nsfw-overlay'); - if (overlay) { - overlay.style.display = 'none'; - } - }); - }); -} - -// Add lazy loading initialization -function initLazyLoading(container) { - const lazyElements = container.querySelectorAll('.lazy'); - - const lazyLoad = (element) => { - if (element.tagName.toLowerCase() === 'video') { - element.src = element.dataset.src; - element.querySelector('source').src = element.dataset.src; - element.load(); - } else { - element.src = element.dataset.src; - } - element.classList.remove('lazy'); - }; - - const observer = new IntersectionObserver((entries) => { - entries.forEach(entry => { - if (entry.isIntersecting) { - lazyLoad(entry.target); - observer.unobserve(entry.target); - } - }); - }); - - lazyElements.forEach(element => observer.observe(element)); -} - -export function setupShowcaseScroll() { - // Add event listener to document for wheel events - document.addEventListener('wheel', (event) => { - // Find the active modal content - const modalContent = document.querySelector('#loraModal .modal-content'); - if (!modalContent) return; - - const showcase = modalContent.querySelector('.showcase-section'); - if (!showcase) return; - - const carousel = showcase.querySelector('.carousel'); - const scrollIndicator = showcase.querySelector('.scroll-indicator'); - - if (carousel?.classList.contains('collapsed') && event.deltaY > 0) { - 
const isNearBottom = modalContent.scrollHeight - modalContent.scrollTop - modalContent.clientHeight < 100; - - if (isNearBottom) { - toggleShowcase(scrollIndicator); - event.preventDefault(); - } - } - }, { passive: false }); - - // Use MutationObserver instead of deprecated DOMNodeInserted - const observer = new MutationObserver((mutations) => { - for (const mutation of mutations) { - if (mutation.type === 'childList' && mutation.addedNodes.length) { - // Check if loraModal content was added - const loraModal = document.getElementById('loraModal'); - if (loraModal && loraModal.querySelector('.modal-content')) { - setupBackToTopButton(loraModal.querySelector('.modal-content')); - } - } - } - }); - - // Start observing the document body for changes - observer.observe(document.body, { childList: true, subtree: true }); - - // Also try to set up the button immediately in case the modal is already open - const modalContent = document.querySelector('#loraModal .modal-content'); - if (modalContent) { - setupBackToTopButton(modalContent); - } -} - -// New helper function to set up the back to top button -function setupBackToTopButton(modalContent) { - // Remove any existing scroll listeners to avoid duplicates - modalContent.onscroll = null; - - // Add new scroll listener - modalContent.addEventListener('scroll', () => { - const backToTopBtn = modalContent.querySelector('.back-to-top'); - if (backToTopBtn) { - if (modalContent.scrollTop > 300) { - backToTopBtn.classList.add('visible'); - } else { - backToTopBtn.classList.remove('visible'); - } - } - }); - - // Trigger a scroll event to check initial position - modalContent.dispatchEvent(new Event('scroll')); -} - -export function scrollToTop(button) { - const modalContent = button.closest('.modal-content'); - if (modalContent) { - modalContent.scrollTo({ - top: 0, - behavior: 'smooth' - }); - } -} - -function parsePresets(usageTips) { - if (!usageTips) return {}; - try { - return JSON.parse(usageTips); - } catch { - 
return {}; - } -} - -function renderPresetTags(presets) { - return Object.entries(presets).map(([key, value]) => ` -
- ${formatPresetKey(key)}: ${value} - -
- `).join(''); -} - -function formatPresetKey(key) { - return key.split('_').map(word => - word.charAt(0).toUpperCase() + word.slice(1) - ).join(' '); -} - -window.removePreset = async function(key) { - const filePath = document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#loraModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - const currentPresets = parsePresets(loraCard.dataset.usage_tips); - - delete currentPresets[key]; - const newPresetsJson = JSON.stringify(currentPresets); - - await saveModelMetadata(filePath, { - usage_tips: newPresetsJson - }); - - loraCard.dataset.usage_tips = newPresetsJson; - document.querySelector('.preset-tags').innerHTML = renderPresetTags(currentPresets); -}; - -// 添加文件大小格式化函数 -function formatFileSize(bytes) { - if (!bytes) return 'N/A'; - const units = ['B', 'KB', 'MB', 'GB']; - let size = bytes; - let unitIndex = 0; - - while (size >= 1024 && unitIndex < units.length - 1) { - size /= 1024; - unitIndex++; - } - - return `${size.toFixed(1)} ${units[unitIndex]}`; -} - -// New function to render compact tags with tooltip -function renderCompactTags(tags) { - if (!tags || tags.length === 0) return ''; - - // Display up to 5 tags, with a tooltip indicator if there are more - const visibleTags = tags.slice(0, 5); - const remainingCount = Math.max(0, tags.length - 5); - - return ` -
-
- ${visibleTags.map(tag => `${tag}`).join('')} - ${remainingCount > 0 ? - `+${remainingCount}` : - ''} -
- ${tags.length > 0 ? - `
-
- ${tags.map(tag => `${tag}`).join('')} -
-
` : - ''} -
- `; -} - -// Setup tooltip functionality -function setupTagTooltip() { - const tagsContainer = document.querySelector('.model-tags-container'); - const tooltip = document.querySelector('.model-tags-tooltip'); - - if (tagsContainer && tooltip) { - tagsContainer.addEventListener('mouseenter', () => { - tooltip.classList.add('visible'); - }); - - tagsContainer.addEventListener('mouseleave', () => { - tooltip.classList.remove('visible'); - }); - } -} - -// Set up trigger words edit mode -function setupTriggerWordsEditMode() { - const editBtn = document.querySelector('.edit-trigger-words-btn'); - if (!editBtn) return; - - editBtn.addEventListener('click', function() { - const triggerWordsSection = this.closest('.trigger-words'); - const isEditMode = triggerWordsSection.classList.toggle('edit-mode'); - - // Toggle edit mode UI elements - const triggerWordTags = triggerWordsSection.querySelectorAll('.trigger-word-tag'); - const editControls = triggerWordsSection.querySelector('.trigger-words-edit-controls'); - const noTriggerWords = triggerWordsSection.querySelector('.no-trigger-words'); - const tagsContainer = triggerWordsSection.querySelector('.trigger-words-tags'); - - if (isEditMode) { - this.innerHTML = ''; // Change to cancel icon - this.title = "Cancel editing"; - editControls.style.display = 'flex'; - - // If we have no trigger words yet, hide the "No trigger word needed" text - // and show the empty tags container - if (noTriggerWords) { - noTriggerWords.style.display = 'none'; - if (tagsContainer) tagsContainer.style.display = 'flex'; - } - - // Disable click-to-copy and show delete buttons - triggerWordTags.forEach(tag => { - tag.onclick = null; - tag.querySelector('.trigger-word-copy').style.display = 'none'; - tag.querySelector('.delete-trigger-word-btn').style.display = 'block'; - }); - } else { - this.innerHTML = ''; // Change back to edit icon - this.title = "Edit trigger words"; - editControls.style.display = 'none'; - - // If we have no trigger words, 
show the "No trigger word needed" text - // and hide the empty tags container - const currentTags = triggerWordsSection.querySelectorAll('.trigger-word-tag'); - if (noTriggerWords && currentTags.length === 0) { - noTriggerWords.style.display = ''; - if (tagsContainer) tagsContainer.style.display = 'none'; - } - - // Restore original state - triggerWordTags.forEach(tag => { - const word = tag.dataset.word; - tag.onclick = () => copyTriggerWord(word); - tag.querySelector('.trigger-word-copy').style.display = 'flex'; - tag.querySelector('.delete-trigger-word-btn').style.display = 'none'; - }); - - // Hide add form if open - triggerWordsSection.querySelector('.add-trigger-word-form').style.display = 'none'; - } - }); - - // Set up add trigger word button - const addBtn = document.querySelector('.add-trigger-word-btn'); - if (addBtn) { - addBtn.addEventListener('click', function() { - const triggerWordsSection = this.closest('.trigger-words'); - const addForm = triggerWordsSection.querySelector('.add-trigger-word-form'); - addForm.style.display = 'flex'; - addForm.querySelector('input').focus(); - }); - } - - // Set up confirm and cancel add buttons - const confirmAddBtn = document.querySelector('.confirm-add-trigger-word-btn'); - const cancelAddBtn = document.querySelector('.cancel-add-trigger-word-btn'); - const triggerWordInput = document.querySelector('.new-trigger-word-input'); - - if (confirmAddBtn && triggerWordInput) { - confirmAddBtn.addEventListener('click', function() { - addNewTriggerWord(triggerWordInput.value); - }); - - // Add keydown event to input - triggerWordInput.addEventListener('keydown', function(e) { - if (e.key === 'Enter') { - e.preventDefault(); - addNewTriggerWord(this.value); - } - }); - } - - if (cancelAddBtn) { - cancelAddBtn.addEventListener('click', function() { - const addForm = this.closest('.add-trigger-word-form'); - addForm.style.display = 'none'; - addForm.querySelector('input').value = ''; - }); - } - - // Set up save button - 
const saveBtn = document.querySelector('.save-trigger-words-btn'); - if (saveBtn) { - saveBtn.addEventListener('click', saveTriggerWords); - } - - // Set up delete buttons - document.querySelectorAll('.delete-trigger-word-btn').forEach(btn => { - btn.addEventListener('click', function(e) { - e.stopPropagation(); - const tag = this.closest('.trigger-word-tag'); - tag.remove(); - }); - }); -} - -// Function to add a new trigger word -function addNewTriggerWord(word) { - word = word.trim(); - if (!word) return; - - const triggerWordsSection = document.querySelector('.trigger-words'); - let tagsContainer = document.querySelector('.trigger-words-tags'); - - // Ensure tags container exists and is visible - if (tagsContainer) { - tagsContainer.style.display = 'flex'; - } else { - // Create tags container if it doesn't exist - const contentDiv = triggerWordsSection.querySelector('.trigger-words-content'); - if (contentDiv) { - tagsContainer = document.createElement('div'); - tagsContainer.className = 'trigger-words-tags'; - contentDiv.appendChild(tagsContainer); - } - } - - if (!tagsContainer) return; - - // Hide "no trigger words" message if it exists - const noTriggerWordsMsg = triggerWordsSection.querySelector('.no-trigger-words'); - if (noTriggerWordsMsg) { - noTriggerWordsMsg.style.display = 'none'; - } - - // Validation: Check length - if (word.split(/\s+/).length > 30) { - showToast('Trigger word should not exceed 30 words', 'error'); - return; - } - - // Validation: Check total number - const currentTags = tagsContainer.querySelectorAll('.trigger-word-tag'); - if (currentTags.length >= 10) { - showToast('Maximum 10 trigger words allowed', 'error'); - return; - } - - // Validation: Check for duplicates - const existingWords = Array.from(currentTags).map(tag => tag.dataset.word); - if (existingWords.includes(word)) { - showToast('This trigger word already exists', 'error'); - return; - } - - // Create new tag - const newTag = document.createElement('div'); - 
newTag.className = 'trigger-word-tag'; - newTag.dataset.word = word; - newTag.innerHTML = ` - ${word} - - - `; - - // Add event listener to delete button - const deleteBtn = newTag.querySelector('.delete-trigger-word-btn'); - deleteBtn.addEventListener('click', function() { - newTag.remove(); - }); - - tagsContainer.appendChild(newTag); - - // Clear and hide the input form - const triggerWordInput = document.querySelector('.new-trigger-word-input'); - triggerWordInput.value = ''; - document.querySelector('.add-trigger-word-form').style.display = 'none'; -} - -// Function to save updated trigger words -async function saveTriggerWords() { - const filePath = document.querySelector('.edit-trigger-words-btn').dataset.filePath; - const triggerWordTags = document.querySelectorAll('.trigger-word-tag'); - const words = Array.from(triggerWordTags).map(tag => tag.dataset.word); - - try { - // Special format for updating nested civitai.trainedWords - await saveModelMetadata(filePath, { - civitai: { trainedWords: words } - }); - - // Update UI - const editBtn = document.querySelector('.edit-trigger-words-btn'); - editBtn.click(); // Exit edit mode - - // Update the LoRA card's dataset - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (loraCard) { - try { - // Create a proper structure for civitai data - let civitaiData = {}; - - // Parse existing data if available - if (loraCard.dataset.meta) { - civitaiData = JSON.parse(loraCard.dataset.meta); - } - - // Update trainedWords property - civitaiData.trainedWords = words; - - // Update the meta dataset attribute with the full civitai data - loraCard.dataset.meta = JSON.stringify(civitaiData); - - // For debugging, log the updated data to verify it's correct - console.log("Updated civitai data:", civitaiData); - } catch (e) { - console.error('Error updating civitai data:', e); - } - } - - // If we saved an empty array and there's a no-trigger-words element, show it - const noTriggerWords = 
document.querySelector('.no-trigger-words'); - const tagsContainer = document.querySelector('.trigger-words-tags'); - if (words.length === 0 && noTriggerWords) { - noTriggerWords.style.display = ''; - if (tagsContainer) tagsContainer.style.display = 'none'; - } - - showToast('Trigger words updated successfully', 'success'); - } catch (error) { - console.error('Error saving trigger words:', error); - showToast('Failed to update trigger words', 'error'); - } -} - -// Add copy trigger word function -window.copyTriggerWord = async function(word) { - try { - await navigator.clipboard.writeText(word); - showToast('Trigger word copied', 'success'); - } catch (err) { - console.error('Copy failed:', err); - showToast('Copy failed', 'error'); - } -}; - -// New function to handle model name editing -function setupModelNameEditing() { - const modelNameContent = document.querySelector('.model-name-content'); - const editBtn = document.querySelector('.edit-model-name-btn'); - - if (!modelNameContent || !editBtn) return; - - // Show edit button on hover - const modelNameHeader = document.querySelector('.model-name-header'); - modelNameHeader.addEventListener('mouseenter', () => { - editBtn.classList.add('visible'); - }); - - modelNameHeader.addEventListener('mouseleave', () => { - if (!modelNameContent.getAttribute('data-editing')) { - editBtn.classList.remove('visible'); - } - }); - - // Handle edit button click - editBtn.addEventListener('click', () => { - modelNameContent.setAttribute('data-editing', 'true'); - modelNameContent.focus(); - - // Place cursor at the end - const range = document.createRange(); - const sel = window.getSelection(); - if (modelNameContent.childNodes.length > 0) { - range.setStart(modelNameContent.childNodes[0], modelNameContent.textContent.length); - range.collapse(true); - sel.removeAllRanges(); - sel.addRange(range); - } - - editBtn.classList.add('visible'); - }); - - // Handle focus out - modelNameContent.addEventListener('blur', function() { - 
this.removeAttribute('data-editing'); - editBtn.classList.remove('visible'); - - if (this.textContent.trim() === '') { - // Restore original model name if empty - const filePath = document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#loraModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (loraCard) { - this.textContent = loraCard.dataset.model_name; - } - } - }); - - // Handle enter key - modelNameContent.addEventListener('keydown', function(e) { - if (e.key === 'Enter') { - e.preventDefault(); - const filePath = document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#loraModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; - saveModelName(filePath); - this.blur(); - } - }); - - // Limit model name length - modelNameContent.addEventListener('input', function() { - // Limit model name length - if (this.textContent.length > 100) { - this.textContent = this.textContent.substring(0, 100); - // Place cursor at the end - const range = document.createRange(); - const sel = window.getSelection(); - range.setStart(this.childNodes[0], 100); - range.collapse(true); - sel.removeAllRanges(); - sel.addRange(range); - - showToast('Model name is limited to 100 characters', 'warning'); - } - }); -} - -// Add save model base model function -window.saveBaseModel = async function(filePath, originalValue) { - const baseModelElement = document.querySelector('.base-model-content'); - const newBaseModel = baseModelElement.textContent.trim(); - - // Only save if the value has actually changed - if (newBaseModel === originalValue) { - return; // No change, no need to save - } - - try { - await saveModelMetadata(filePath, { base_model: newBaseModel }); - - // Update the corresponding lora card's dataset 
- const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (loraCard) { - loraCard.dataset.base_model = newBaseModel; - } - - showToast('Base model updated successfully', 'success'); - } catch (error) { - showToast('Failed to update base model', 'error'); - } -}; - -// New function to handle base model editing -function setupBaseModelEditing() { - const baseModelContent = document.querySelector('.base-model-content'); - const editBtn = document.querySelector('.edit-base-model-btn'); - - if (!baseModelContent || !editBtn) return; - - // Show edit button on hover - const baseModelDisplay = document.querySelector('.base-model-display'); - baseModelDisplay.addEventListener('mouseenter', () => { - editBtn.classList.add('visible'); - }); - - baseModelDisplay.addEventListener('mouseleave', () => { - if (!baseModelDisplay.classList.contains('editing')) { - editBtn.classList.remove('visible'); - } - }); - - // Handle edit button click - editBtn.addEventListener('click', () => { - baseModelDisplay.classList.add('editing'); - - // Store the original value to check for changes later - const originalValue = baseModelContent.textContent.trim(); - - // Create dropdown selector to replace the base model content - const currentValue = originalValue; - const dropdown = document.createElement('select'); - dropdown.className = 'base-model-selector'; - - // Flag to track if a change was made - let valueChanged = false; - - // Add options from BASE_MODELS constants - const baseModelCategories = { - 'Stable Diffusion 1.x': [BASE_MODELS.SD_1_4, BASE_MODELS.SD_1_5, BASE_MODELS.SD_1_5_LCM, BASE_MODELS.SD_1_5_HYPER], - 'Stable Diffusion 2.x': [BASE_MODELS.SD_2_0, BASE_MODELS.SD_2_1], - 'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO], - 'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER], - 'Video Models': [BASE_MODELS.SVD, 
BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO], - 'Other Models': [ - BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW, - BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1, - BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI, - BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.UNKNOWN - ] - }; - - // Create option groups for better organization - Object.entries(baseModelCategories).forEach(([category, models]) => { - const group = document.createElement('optgroup'); - group.label = category; - - models.forEach(model => { - const option = document.createElement('option'); - option.value = model; - option.textContent = model; - option.selected = model === currentValue; - group.appendChild(option); - }); - - dropdown.appendChild(group); - }); - - // Replace content with dropdown - baseModelContent.style.display = 'none'; - baseModelDisplay.insertBefore(dropdown, editBtn); - - // Hide edit button during editing - editBtn.style.display = 'none'; - - // Focus the dropdown - dropdown.focus(); - - // Handle dropdown change - dropdown.addEventListener('change', function() { - const selectedModel = this.value; - baseModelContent.textContent = selectedModel; - - // Mark that a change was made if the value differs from original - if (selectedModel !== originalValue) { - valueChanged = true; - } else { - valueChanged = false; - } - }); - - // Function to save changes and exit edit mode - const saveAndExit = function() { - // Check if dropdown still exists and remove it - if (dropdown && dropdown.parentNode === baseModelDisplay) { - baseModelDisplay.removeChild(dropdown); - } - - // Show the content and edit button - baseModelContent.style.display = ''; - editBtn.style.display = ''; - - // Remove editing class - baseModelDisplay.classList.remove('editing'); - - // Only save if the value has actually changed - if (valueChanged || baseModelContent.textContent.trim() !== originalValue) { - // Get file path for saving - const filePath 
= document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#loraModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; - - // Save the changes, passing the original value for comparison - saveBaseModel(filePath, originalValue); - } - - // Remove this event listener - document.removeEventListener('click', outsideClickHandler); - }; - - // Handle outside clicks to save and exit - const outsideClickHandler = function(e) { - // If click is outside the dropdown and base model display - if (!baseModelDisplay.contains(e.target)) { - saveAndExit(); - } - }; - - // Add delayed event listener for outside clicks - setTimeout(() => { - document.addEventListener('click', outsideClickHandler); - }, 0); - - // Also handle dropdown blur event - dropdown.addEventListener('blur', function(e) { - // Only save if the related target is not the edit button or inside the baseModelDisplay - if (!baseModelDisplay.contains(e.relatedTarget)) { - saveAndExit(); - } - }); - }); -} - -// New function to handle file name editing -function setupFileNameEditing() { - const fileNameContent = document.querySelector('.file-name-content'); - const editBtn = document.querySelector('.edit-file-name-btn'); - - if (!fileNameContent || !editBtn) return; - - // Show edit button on hover - const fileNameWrapper = document.querySelector('.file-name-wrapper'); - fileNameWrapper.addEventListener('mouseenter', () => { - editBtn.classList.add('visible'); - }); - - fileNameWrapper.addEventListener('mouseleave', () => { - if (!fileNameWrapper.classList.contains('editing')) { - editBtn.classList.remove('visible'); - } - }); - - // Handle edit button click - editBtn.addEventListener('click', () => { - fileNameWrapper.classList.add('editing'); - fileNameContent.setAttribute('contenteditable', 'true'); - fileNameContent.focus(); - - // Store original value for comparison later - fileNameContent.dataset.originalValue = 
fileNameContent.textContent.trim(); - - // Place cursor at the end - const range = document.createRange(); - const sel = window.getSelection(); - range.selectNodeContents(fileNameContent); - range.collapse(false); - sel.removeAllRanges(); - sel.addRange(range); - - editBtn.classList.add('visible'); - }); - - // Handle keyboard events in edit mode - fileNameContent.addEventListener('keydown', function(e) { - if (!this.getAttribute('contenteditable')) return; - - if (e.key === 'Enter') { - e.preventDefault(); - this.blur(); // Trigger save on Enter - } else if (e.key === 'Escape') { - e.preventDefault(); - // Restore original value - this.textContent = this.dataset.originalValue; - exitEditMode(); - } - }); - - // Handle input validation - fileNameContent.addEventListener('input', function() { - if (!this.getAttribute('contenteditable')) return; - - // Replace invalid characters for filenames - const invalidChars = /[\\/:*?"<>|]/g; - if (invalidChars.test(this.textContent)) { - const cursorPos = window.getSelection().getRangeAt(0).startOffset; - this.textContent = this.textContent.replace(invalidChars, ''); - - // Restore cursor position - const range = document.createRange(); - const sel = window.getSelection(); - const newPos = Math.min(cursorPos, this.textContent.length); - - if (this.firstChild) { - range.setStart(this.firstChild, newPos); - range.collapse(true); - sel.removeAllRanges(); - sel.addRange(range); - } - - showToast('Invalid characters removed from filename', 'warning'); - } - }); - - // Handle focus out - save changes - fileNameContent.addEventListener('blur', async function() { - if (!this.getAttribute('contenteditable')) return; - - const newFileName = this.textContent.trim(); - const originalValue = this.dataset.originalValue; - - // Basic validation - if (!newFileName) { - // Restore original value if empty - this.textContent = originalValue; - showToast('File name cannot be empty', 'error'); - exitEditMode(); - return; - } - - if (newFileName 
=== originalValue) { - // No changes, just exit edit mode - exitEditMode(); - return; - } - - try { - // Get the full file path - const filePath = document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + originalValue + '.safetensors'; - - // Call API to rename the file - const response = await fetch('/api/rename_lora', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - file_path: filePath, - new_file_name: newFileName - }) - }); - - const result = await response.json(); - - if (result.success) { - showToast('File name updated successfully', 'success'); - - // Update card in the gallery - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (loraCard) { - // Update the card's filepath attribute to the new path - loraCard.dataset.filepath = result.new_file_path; - loraCard.dataset.file_name = newFileName; - - // Update the filename display in the card - const cardFileName = loraCard.querySelector('.card-filename'); - if (cardFileName) { - cardFileName.textContent = newFileName; - } - } - - // Handle the case where we need to reload the page - if (result.reload_required) { - showToast('Reloading page to apply changes...', 'info'); - setTimeout(() => { - window.location.reload(); - }, 1500); - } - } else { - // Show error and restore original filename - showToast(result.error || 'Failed to update file name', 'error'); - this.textContent = originalValue; - } - } catch (error) { - console.error('Error saving filename:', error); - showToast('Failed to update file name', 'error'); - this.textContent = originalValue; - } finally { - exitEditMode(); - } - }); - - function exitEditMode() { - fileNameContent.removeAttribute('contenteditable'); - fileNameWrapper.classList.remove('editing'); - editBtn.classList.remove('visible'); - } -} diff --git a/static/js/components/checkpointModal/ModelDescription.js 
b/static/js/components/checkpointModal/ModelDescription.js new file mode 100644 index 00000000..0f50fe8a --- /dev/null +++ b/static/js/components/checkpointModal/ModelDescription.js @@ -0,0 +1,102 @@ +/** + * ModelDescription.js + * Handles checkpoint model descriptions + */ +import { showToast } from '../../utils/uiHelpers.js'; + +/** + * Set up tab switching functionality + */ +export function setupTabSwitching() { + const tabButtons = document.querySelectorAll('.showcase-tabs .tab-btn'); + + tabButtons.forEach(button => { + button.addEventListener('click', () => { + // Remove active class from all tabs + document.querySelectorAll('.showcase-tabs .tab-btn').forEach(btn => + btn.classList.remove('active') + ); + document.querySelectorAll('.tab-content .tab-pane').forEach(tab => + tab.classList.remove('active') + ); + + // Add active class to clicked tab + button.classList.add('active'); + const tabId = `${button.dataset.tab}-tab`; + document.getElementById(tabId).classList.add('active'); + + // If switching to description tab, make sure content is properly loaded and displayed + if (button.dataset.tab === 'description') { + const descriptionContent = document.querySelector('.model-description-content'); + if (descriptionContent) { + const hasContent = descriptionContent.innerHTML.trim() !== ''; + document.querySelector('.model-description-loading')?.classList.add('hidden'); + + // If no content, show a message + if (!hasContent) { + descriptionContent.innerHTML = '
No model description available
'; + descriptionContent.classList.remove('hidden'); + } + } + } + }); + }); +} + +/** + * Load model description from API + * @param {string} modelId - The Civitai model ID + * @param {string} filePath - File path for the model + */ +export async function loadModelDescription(modelId, filePath) { + try { + const descriptionContainer = document.querySelector('.model-description-content'); + const loadingElement = document.querySelector('.model-description-loading'); + + if (!descriptionContainer || !loadingElement) return; + + // Show loading indicator + loadingElement.classList.remove('hidden'); + descriptionContainer.classList.add('hidden'); + + // Try to get model description from API + const response = await fetch(`/api/checkpoint-model-description?model_id=${modelId}&file_path=${encodeURIComponent(filePath)}`); + + if (!response.ok) { + throw new Error(`Failed to fetch model description: ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success && data.description) { + // Update the description content + descriptionContainer.innerHTML = data.description; + + // Process any links in the description to open in new tab + const links = descriptionContainer.querySelectorAll('a'); + links.forEach(link => { + link.setAttribute('target', '_blank'); + link.setAttribute('rel', 'noopener noreferrer'); + }); + + // Show the description and hide loading indicator + descriptionContainer.classList.remove('hidden'); + loadingElement.classList.add('hidden'); + } else { + throw new Error(data.error || 'No description available'); + } + } catch (error) { + console.error('Error loading model description:', error); + const loadingElement = document.querySelector('.model-description-loading'); + if (loadingElement) { + loadingElement.innerHTML = `
Failed to load model description. ${error.message}
`; + } + + // Show empty state message in the description container + const descriptionContainer = document.querySelector('.model-description-content'); + if (descriptionContainer) { + descriptionContainer.innerHTML = '
No model description available
'; + descriptionContainer.classList.remove('hidden'); + } + } +} \ No newline at end of file diff --git a/static/js/components/checkpointModal/ModelMetadata.js b/static/js/components/checkpointModal/ModelMetadata.js new file mode 100644 index 00000000..4bf7b2b3 --- /dev/null +++ b/static/js/components/checkpointModal/ModelMetadata.js @@ -0,0 +1,484 @@ +/** + * ModelMetadata.js + * Handles checkpoint model metadata editing functionality + */ +import { showToast } from '../../utils/uiHelpers.js'; +import { BASE_MODELS } from '../../utils/constants.js'; +import { updateCheckpointCard } from '../../utils/cardUpdater.js'; + +/** + * Save model metadata to the server + * @param {string} filePath - Path to the model file + * @param {Object} data - Metadata to save + * @returns {Promise} - Promise that resolves with the server response + */ +export async function saveModelMetadata(filePath, data) { + const response = await fetch('/api/checkpoints/save-metadata', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + file_path: filePath, + ...data + }) + }); + + if (!response.ok) { + throw new Error('Failed to save metadata'); + } + + return response.json(); +} + +/** + * Set up model name editing functionality + * @param {string} filePath - The full file path of the model. 
+ */ +export function setupModelNameEditing(filePath) { + const modelNameContent = document.querySelector('.model-name-content'); + const editBtn = document.querySelector('.edit-model-name-btn'); + + if (!modelNameContent || !editBtn) return; + + // Show edit button on hover + const modelNameHeader = document.querySelector('.model-name-header'); + modelNameHeader.addEventListener('mouseenter', () => { + editBtn.classList.add('visible'); + }); + + modelNameHeader.addEventListener('mouseleave', () => { + if (!modelNameContent.getAttribute('data-editing')) { + editBtn.classList.remove('visible'); + } + }); + + // Handle edit button click + editBtn.addEventListener('click', () => { + modelNameContent.setAttribute('data-editing', 'true'); + modelNameContent.focus(); + + // Place cursor at the end + const range = document.createRange(); + const sel = window.getSelection(); + if (modelNameContent.childNodes.length > 0) { + range.setStart(modelNameContent.childNodes[0], modelNameContent.textContent.length); + range.collapse(true); + sel.removeAllRanges(); + sel.addRange(range); + } + + editBtn.classList.add('visible'); + }); + + // Handle focus out + modelNameContent.addEventListener('blur', function() { + this.removeAttribute('data-editing'); + editBtn.classList.remove('visible'); + + if (this.textContent.trim() === '') { + // Restore original model name if empty + // Use the passed filePath to find the card + const checkpointCard = document.querySelector(`.checkpoint-card[data-filepath="${filePath}"]`); + if (checkpointCard) { + this.textContent = checkpointCard.dataset.model_name; + } + } + }); + + // Handle enter key + modelNameContent.addEventListener('keydown', function(e) { + if (e.key === 'Enter') { + e.preventDefault(); + // Use the passed filePath + saveModelName(filePath); + this.blur(); + } + }); + + // Limit model name length + modelNameContent.addEventListener('input', function() { + if (this.textContent.length > 100) { + this.textContent = 
this.textContent.substring(0, 100); + // Place cursor at the end + const range = document.createRange(); + const sel = window.getSelection(); + range.setStart(this.childNodes[0], 100); + range.collapse(true); + sel.removeAllRanges(); + sel.addRange(range); + + showToast('Model name is limited to 100 characters', 'warning'); + } + }); +} + +/** + * Save model name + * @param {string} filePath - File path + */ +async function saveModelName(filePath) { + const modelNameElement = document.querySelector('.model-name-content'); + const newModelName = modelNameElement.textContent.trim(); + + // Validate model name + if (!newModelName) { + showToast('Model name cannot be empty', 'error'); + return; + } + + // Check if model name is too long + if (newModelName.length > 100) { + showToast('Model name is too long (maximum 100 characters)', 'error'); + // Truncate the displayed text + modelNameElement.textContent = newModelName.substring(0, 100); + return; + } + + try { + await saveModelMetadata(filePath, { model_name: newModelName }); + + // Update the card with the new model name + updateCheckpointCard(filePath, { name: newModelName }); + + showToast('Model name updated successfully', 'success'); + + // No need to reload the entire page + // setTimeout(() => { + // window.location.reload(); + // }, 1500); + } catch (error) { + showToast('Failed to update model name', 'error'); + } +} + +/** + * Set up base model editing functionality + * @param {string} filePath - The full file path of the model. 
+ */ +export function setupBaseModelEditing(filePath) { + const baseModelContent = document.querySelector('.base-model-content'); + const editBtn = document.querySelector('.edit-base-model-btn'); + + if (!baseModelContent || !editBtn) return; + + // Show edit button on hover + const baseModelDisplay = document.querySelector('.base-model-display'); + baseModelDisplay.addEventListener('mouseenter', () => { + editBtn.classList.add('visible'); + }); + + baseModelDisplay.addEventListener('mouseleave', () => { + if (!baseModelDisplay.classList.contains('editing')) { + editBtn.classList.remove('visible'); + } + }); + + // Handle edit button click + editBtn.addEventListener('click', () => { + baseModelDisplay.classList.add('editing'); + + // Store the original value to check for changes later + const originalValue = baseModelContent.textContent.trim(); + + // Create dropdown selector to replace the base model content + const currentValue = originalValue; + const dropdown = document.createElement('select'); + dropdown.className = 'base-model-selector'; + + // Flag to track if a change was made + let valueChanged = false; + + // Add options from BASE_MODELS constants + const baseModelCategories = { + 'Stable Diffusion 1.x': [BASE_MODELS.SD_1_4, BASE_MODELS.SD_1_5, BASE_MODELS.SD_1_5_LCM, BASE_MODELS.SD_1_5_HYPER], + 'Stable Diffusion 2.x': [BASE_MODELS.SD_2_0, BASE_MODELS.SD_2_1], + 'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO], + 'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER], + 'Video Models': [BASE_MODELS.SVD, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO], + 'Other Models': [ + BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW, + BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1, + BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI, + BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.UNKNOWN + 
] + }; + + // Create option groups for better organization + Object.entries(baseModelCategories).forEach(([category, models]) => { + const group = document.createElement('optgroup'); + group.label = category; + + models.forEach(model => { + const option = document.createElement('option'); + option.value = model; + option.textContent = model; + option.selected = model === currentValue; + group.appendChild(option); + }); + + dropdown.appendChild(group); + }); + + // Replace content with dropdown + baseModelContent.style.display = 'none'; + baseModelDisplay.insertBefore(dropdown, editBtn); + + // Hide edit button during editing + editBtn.style.display = 'none'; + + // Focus the dropdown + dropdown.focus(); + + // Handle dropdown change + dropdown.addEventListener('change', function() { + const selectedModel = this.value; + baseModelContent.textContent = selectedModel; + + // Mark that a change was made if the value differs from original + if (selectedModel !== originalValue) { + valueChanged = true; + } else { + valueChanged = false; + } + }); + + // Function to save changes and exit edit mode + const saveAndExit = function() { + // Check if dropdown still exists and remove it + if (dropdown && dropdown.parentNode === baseModelDisplay) { + baseModelDisplay.removeChild(dropdown); + } + + // Show the content and edit button + baseModelContent.style.display = ''; + editBtn.style.display = ''; + + // Remove editing class + baseModelDisplay.classList.remove('editing'); + + // Only save if the value has actually changed + if (valueChanged || baseModelContent.textContent.trim() !== originalValue) { + // Use the passed filePath for saving + saveBaseModel(filePath, originalValue); + } + + // Remove this event listener + document.removeEventListener('click', outsideClickHandler); + }; + + // Handle outside clicks to save and exit + const outsideClickHandler = function(e) { + // If click is outside the dropdown and base model display + if (!baseModelDisplay.contains(e.target)) { 
+ saveAndExit(); + } + }; + + // Add delayed event listener for outside clicks + setTimeout(() => { + document.addEventListener('click', outsideClickHandler); + }, 0); + + // Also handle dropdown blur event + dropdown.addEventListener('blur', function(e) { + // Only save if the related target is not the edit button or inside the baseModelDisplay + if (!baseModelDisplay.contains(e.relatedTarget)) { + saveAndExit(); + } + }); + }); +} + +/** + * Save base model + * @param {string} filePath - File path + * @param {string} originalValue - Original value (for comparison) + */ +async function saveBaseModel(filePath, originalValue) { + const baseModelElement = document.querySelector('.base-model-content'); + const newBaseModel = baseModelElement.textContent.trim(); + + // Only save if the value has actually changed + if (newBaseModel === originalValue) { + return; // No change, no need to save + } + + try { + await saveModelMetadata(filePath, { base_model: newBaseModel }); + + // Update the card with the new base model + updateCheckpointCard(filePath, { base_model: newBaseModel }); + + showToast('Base model updated successfully', 'success'); + } catch (error) { + showToast('Failed to update base model', 'error'); + } +} + +/** + * Set up file name editing functionality + * @param {string} filePath - The full file path of the model. 
+ */ +export function setupFileNameEditing(filePath) { + const fileNameContent = document.querySelector('.file-name-content'); + const editBtn = document.querySelector('.edit-file-name-btn'); + + if (!fileNameContent || !editBtn) return; + + // Show edit button on hover + const fileNameWrapper = document.querySelector('.file-name-wrapper'); + fileNameWrapper.addEventListener('mouseenter', () => { + editBtn.classList.add('visible'); + }); + + fileNameWrapper.addEventListener('mouseleave', () => { + if (!fileNameWrapper.classList.contains('editing')) { + editBtn.classList.remove('visible'); + } + }); + + // Handle edit button click + editBtn.addEventListener('click', () => { + fileNameWrapper.classList.add('editing'); + fileNameContent.setAttribute('contenteditable', 'true'); + fileNameContent.focus(); + + // Store original value for comparison later + fileNameContent.dataset.originalValue = fileNameContent.textContent.trim(); + + // Place cursor at the end + const range = document.createRange(); + const sel = window.getSelection(); + range.selectNodeContents(fileNameContent); + range.collapse(false); + sel.removeAllRanges(); + sel.addRange(range); + + editBtn.classList.add('visible'); + }); + + // Handle keyboard events in edit mode + fileNameContent.addEventListener('keydown', function(e) { + if (!this.getAttribute('contenteditable')) return; + + if (e.key === 'Enter') { + e.preventDefault(); + this.blur(); // Trigger save on Enter + } else if (e.key === 'Escape') { + e.preventDefault(); + // Restore original value + this.textContent = this.dataset.originalValue; + exitEditMode(); + } + }); + + // Handle input validation + fileNameContent.addEventListener('input', function() { + if (!this.getAttribute('contenteditable')) return; + + // Replace invalid characters for filenames + const invalidChars = /[\\/:*?"<>|]/g; + if (invalidChars.test(this.textContent)) { + const cursorPos = window.getSelection().getRangeAt(0).startOffset; + this.textContent = 
this.textContent.replace(invalidChars, ''); + + // Restore cursor position + const range = document.createRange(); + const sel = window.getSelection(); + const newPos = Math.min(cursorPos, this.textContent.length); + + if (this.firstChild) { + range.setStart(this.firstChild, newPos); + range.collapse(true); + sel.removeAllRanges(); + sel.addRange(range); + } + + showToast('Invalid characters removed from filename', 'warning'); + } + }); + + // Handle focus out - save changes + fileNameContent.addEventListener('blur', async function() { + if (!this.getAttribute('contenteditable')) return; + + const newFileName = this.textContent.trim(); + const originalValue = this.dataset.originalValue; + + // Basic validation + if (!newFileName) { + // Restore original value if empty + this.textContent = originalValue; + showToast('File name cannot be empty', 'error'); + exitEditMode(); + return; + } + + if (newFileName === originalValue) { + // No changes, just exit edit mode + exitEditMode(); + return; + } + + try { + // Use the passed filePath (which includes the original filename) + // Call API to rename the file + const response = await fetch('/api/rename_checkpoint', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + file_path: filePath, // Use the full original path + new_file_name: newFileName + }) + }); + + const result = await response.json(); + + if (result.success) { + showToast('File name updated successfully', 'success'); + + // Get the new file path from the result + const pathParts = filePath.split(/[\\/]/); + pathParts.pop(); // Remove old filename + const newFilePath = [...pathParts, newFileName].join('/'); + + // Update the checkpoint card with new file path + updateCheckpointCard(filePath, { + filepath: newFilePath, + file_name: newFileName + }); + + // Update the file name display in the modal + document.querySelector('#file-name').textContent = newFileName; + + // Update the modal's data-filepath attribute + 
const modalContent = document.querySelector('#checkpointModal .modal-content'); + if (modalContent) { + modalContent.dataset.filepath = newFilePath; + } + + // Reload the page after a short delay to reflect changes + setTimeout(() => { + window.location.reload(); + }, 1500); + } else { + throw new Error(result.error || 'Unknown error'); + } + } catch (error) { + console.error('Error renaming file:', error); + this.textContent = originalValue; // Restore original file name + showToast(`Failed to rename file: ${error.message}`, 'error'); + } finally { + exitEditMode(); + } + }); + + function exitEditMode() { + fileNameContent.removeAttribute('contenteditable'); + fileNameWrapper.classList.remove('editing'); + editBtn.classList.remove('visible'); + } +} \ No newline at end of file diff --git a/static/js/components/checkpointModal/ShowcaseView.js b/static/js/components/checkpointModal/ShowcaseView.js new file mode 100644 index 00000000..d9843fc3 --- /dev/null +++ b/static/js/components/checkpointModal/ShowcaseView.js @@ -0,0 +1,489 @@ +/** + * ShowcaseView.js + * Handles showcase content (images, videos) display for checkpoint modal + */ +import { showToast } from '../../utils/uiHelpers.js'; +import { state } from '../../state/index.js'; +import { NSFW_LEVELS } from '../../utils/constants.js'; + +/** + * Render showcase content + * @param {Array} images - Array of images/videos to show + * @returns {string} HTML content + */ +export function renderShowcaseContent(images) { + if (!images?.length) return '
No example images available
'; + + // Filter images based on SFW setting + const showOnlySFW = state.settings.show_only_sfw; + let filteredImages = images; + let hiddenCount = 0; + + if (showOnlySFW) { + filteredImages = images.filter(img => { + const nsfwLevel = img.nsfwLevel !== undefined ? img.nsfwLevel : 0; + const isSfw = nsfwLevel < NSFW_LEVELS.R; + if (!isSfw) hiddenCount++; + return isSfw; + }); + } + + // Show message if no images are available after filtering + if (filteredImages.length === 0) { + return ` +
+

All example images are filtered due to NSFW content settings

+

Your settings are currently set to show only safe-for-work content

+

You can change this in Settings

+
+ `; + } + + // Show hidden content notification if applicable + const hiddenNotification = hiddenCount > 0 ? + `
+ ${hiddenCount} ${hiddenCount === 1 ? 'image' : 'images'} hidden due to SFW-only setting +
` : ''; + + return ` +
+ + Scroll or click to show ${filteredImages.length} examples +
+ + `; +} + +/** + * Generate media wrapper HTML for an image or video + * @param {Object} media - Media object with image or video data + * @returns {string} HTML content + */ +function generateMediaWrapper(media) { + // Calculate appropriate aspect ratio: + // 1. Keep original aspect ratio + // 2. Limit maximum height to 60% of viewport height + // 3. Ensure minimum height is 40% of container width + const aspectRatio = (media.height / media.width) * 100; + const containerWidth = 800; // modal content maximum width + const minHeightPercent = 40; + const maxHeightPercent = (window.innerHeight * 0.6 / containerWidth) * 100; + const heightPercent = Math.max( + minHeightPercent, + Math.min(maxHeightPercent, aspectRatio) + ); + + // Check if media should be blurred + const nsfwLevel = media.nsfwLevel !== undefined ? media.nsfwLevel : 0; + const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13; + + // Determine NSFW warning text based on level + let nsfwText = "Mature Content"; + if (nsfwLevel >= NSFW_LEVELS.XXX) { + nsfwText = "XXX-rated Content"; + } else if (nsfwLevel >= NSFW_LEVELS.X) { + nsfwText = "X-rated Content"; + } else if (nsfwLevel >= NSFW_LEVELS.R) { + nsfwText = "R-rated Content"; + } + + // Extract metadata from the media + const meta = media.meta || {}; + const prompt = meta.prompt || ''; + const negativePrompt = meta.negative_prompt || meta.negativePrompt || ''; + const size = meta.Size || `${media.width}x${media.height}`; + const seed = meta.seed || ''; + const model = meta.Model || ''; + const steps = meta.steps || ''; + const sampler = meta.sampler || ''; + const cfgScale = meta.cfgScale || ''; + const clipSkip = meta.clipSkip || ''; + + // Check if we have any meaningful generation parameters + const hasParams = seed || model || steps || sampler || cfgScale || clipSkip; + const hasPrompts = prompt || negativePrompt; + + // Create metadata panel content + const metadataPanel = generateMetadataPanel( + hasParams, 
hasPrompts, + prompt, negativePrompt, + size, seed, model, steps, sampler, cfgScale, clipSkip + ); + + // Check if this is a video or image + if (media.type === 'video') { + return generateVideoWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel); + } + + return generateImageWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel); +} + +/** + * Generate metadata panel HTML + */ +function generateMetadataPanel(hasParams, hasPrompts, prompt, negativePrompt, size, seed, model, steps, sampler, cfgScale, clipSkip) { + // Create unique IDs for prompt copying + const promptIndex = Math.random().toString(36).substring(2, 15); + const negPromptIndex = Math.random().toString(36).substring(2, 15); + + let content = '
'; + return content; +} + +/** + * Generate video wrapper HTML + */ +function generateVideoWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel) { + return ` +
+ ${shouldBlur ? ` + + ` : ''} + + ${shouldBlur ? ` +
+
+

${nsfwText}

+ +
+
+ ` : ''} + ${metadataPanel} +
+ `; +} + +/** + * Generate image wrapper HTML + */ +function generateImageWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel) { + return ` +
+ ${shouldBlur ? ` + + ` : ''} + Preview + ${shouldBlur ? ` +
+
+

${nsfwText}

+ +
+
+ ` : ''} + ${metadataPanel} +
+ `; +} + +/** + * Toggle showcase expansion + */ +export function toggleShowcase(element) { + const carousel = element.nextElementSibling; + const isCollapsed = carousel.classList.contains('collapsed'); + const indicator = element.querySelector('span'); + const icon = element.querySelector('i'); + + carousel.classList.toggle('collapsed'); + + if (isCollapsed) { + const count = carousel.querySelectorAll('.media-wrapper').length; + indicator.textContent = `Scroll or click to hide examples`; + icon.classList.replace('fa-chevron-down', 'fa-chevron-up'); + initLazyLoading(carousel); + + // Initialize NSFW content blur toggle handlers + initNsfwBlurHandlers(carousel); + + // Initialize metadata panel interaction handlers + initMetadataPanelHandlers(carousel); + } else { + const count = carousel.querySelectorAll('.media-wrapper').length; + indicator.textContent = `Scroll or click to show ${count} examples`; + icon.classList.replace('fa-chevron-up', 'fa-chevron-down'); + } +} + +/** + * Initialize metadata panel interaction handlers + */ +function initMetadataPanelHandlers(container) { + const mediaWrappers = container.querySelectorAll('.media-wrapper'); + + mediaWrappers.forEach(wrapper => { + const metadataPanel = wrapper.querySelector('.image-metadata-panel'); + if (!metadataPanel) return; + + // Prevent events from bubbling + metadataPanel.addEventListener('click', (e) => { + e.stopPropagation(); + }); + + // Handle copy prompt buttons + const copyBtns = metadataPanel.querySelectorAll('.copy-prompt-btn'); + copyBtns.forEach(copyBtn => { + const promptIndex = copyBtn.dataset.promptIndex; + const promptElement = wrapper.querySelector(`#prompt-${promptIndex}`); + + copyBtn.addEventListener('click', async (e) => { + e.stopPropagation(); + + if (!promptElement) return; + + try { + await navigator.clipboard.writeText(promptElement.textContent); + showToast('Prompt copied to clipboard', 'success'); + } catch (err) { + console.error('Copy failed:', err); + showToast('Copy 
failed', 'error'); + } + }); + }); + + // Prevent panel scroll from causing modal scroll + metadataPanel.addEventListener('wheel', (e) => { + e.stopPropagation(); + }); + }); +} + +/** + * Initialize blur toggle handlers + */ +function initNsfwBlurHandlers(container) { + // Handle toggle blur buttons + const toggleButtons = container.querySelectorAll('.toggle-blur-btn'); + toggleButtons.forEach(btn => { + btn.addEventListener('click', (e) => { + e.stopPropagation(); + const wrapper = btn.closest('.media-wrapper'); + const media = wrapper.querySelector('img, video'); + const isBlurred = media.classList.toggle('blurred'); + const icon = btn.querySelector('i'); + + // Update the icon based on blur state + if (isBlurred) { + icon.className = 'fas fa-eye'; + } else { + icon.className = 'fas fa-eye-slash'; + } + + // Toggle the overlay visibility + const overlay = wrapper.querySelector('.nsfw-overlay'); + if (overlay) { + overlay.style.display = isBlurred ? 'flex' : 'none'; + } + }); + }); + + // Handle "Show" buttons in overlays + const showButtons = container.querySelectorAll('.show-content-btn'); + showButtons.forEach(btn => { + btn.addEventListener('click', (e) => { + e.stopPropagation(); + const wrapper = btn.closest('.media-wrapper'); + const media = wrapper.querySelector('img, video'); + media.classList.remove('blurred'); + + // Update the toggle button icon + const toggleBtn = wrapper.querySelector('.toggle-blur-btn'); + if (toggleBtn) { + toggleBtn.querySelector('i').className = 'fas fa-eye-slash'; + } + + // Hide the overlay + const overlay = wrapper.querySelector('.nsfw-overlay'); + if (overlay) { + overlay.style.display = 'none'; + } + }); + }); +} + +/** + * Initialize lazy loading for images and videos + */ +function initLazyLoading(container) { + const lazyElements = container.querySelectorAll('.lazy'); + + const lazyLoad = (element) => { + if (element.tagName.toLowerCase() === 'video') { + element.src = element.dataset.src; + 
element.querySelector('source').src = element.dataset.src; + element.load(); + } else { + element.src = element.dataset.src; + } + element.classList.remove('lazy'); + }; + + const observer = new IntersectionObserver((entries) => { + entries.forEach(entry => { + if (entry.isIntersecting) { + lazyLoad(entry.target); + observer.unobserve(entry.target); + } + }); + }); + + lazyElements.forEach(element => observer.observe(element)); +} + +/** + * Set up showcase scroll functionality + */ +export function setupShowcaseScroll() { + // Listen for wheel events + document.addEventListener('wheel', (event) => { + const modalContent = document.querySelector('#checkpointModal .modal-content'); + if (!modalContent) return; + + const showcase = modalContent.querySelector('.showcase-section'); + if (!showcase) return; + + const carousel = showcase.querySelector('.carousel'); + const scrollIndicator = showcase.querySelector('.scroll-indicator'); + + if (carousel?.classList.contains('collapsed') && event.deltaY > 0) { + const isNearBottom = modalContent.scrollHeight - modalContent.scrollTop - modalContent.clientHeight < 100; + + if (isNearBottom) { + toggleShowcase(scrollIndicator); + event.preventDefault(); + } + } + }, { passive: false }); + + // Use MutationObserver to set up back-to-top button when modal content is added + const observer = new MutationObserver((mutations) => { + for (const mutation of mutations) { + if (mutation.type === 'childList' && mutation.addedNodes.length) { + const checkpointModal = document.getElementById('checkpointModal'); + if (checkpointModal && checkpointModal.querySelector('.modal-content')) { + setupBackToTopButton(checkpointModal.querySelector('.modal-content')); + } + } + } + }); + + // Start observing the document body for changes + observer.observe(document.body, { childList: true, subtree: true }); + + // Also try to set up the button immediately in case the modal is already open + const modalContent = 
document.querySelector('#checkpointModal .modal-content'); + if (modalContent) { + setupBackToTopButton(modalContent); + } +} + +/** + * Set up back-to-top button + */ +function setupBackToTopButton(modalContent) { + // Remove any existing scroll listeners to avoid duplicates + modalContent.onscroll = null; + + // Add new scroll listener + modalContent.addEventListener('scroll', () => { + const backToTopBtn = modalContent.querySelector('.back-to-top'); + if (backToTopBtn) { + if (modalContent.scrollTop > 300) { + backToTopBtn.classList.add('visible'); + } else { + backToTopBtn.classList.remove('visible'); + } + } + }); + + // Trigger a scroll event to check initial position + modalContent.dispatchEvent(new Event('scroll')); +} + +/** + * Scroll to top of modal content + */ +export function scrollToTop(button) { + const modalContent = button.closest('.modal-content'); + if (modalContent) { + modalContent.scrollTo({ + top: 0, + behavior: 'smooth' + }); + } +} \ No newline at end of file diff --git a/static/js/components/checkpointModal/index.js b/static/js/components/checkpointModal/index.js new file mode 100644 index 00000000..879f9218 --- /dev/null +++ b/static/js/components/checkpointModal/index.js @@ -0,0 +1,214 @@ +/** + * CheckpointModal - Main entry point + * + * Modularized checkpoint modal component that handles checkpoint model details display + */ +import { showToast } from '../../utils/uiHelpers.js'; +import { state } from '../../state/index.js'; +import { modalManager } from '../../managers/ModalManager.js'; +import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js'; +import { setupTabSwitching, loadModelDescription } from './ModelDescription.js'; +import { + setupModelNameEditing, + setupBaseModelEditing, + setupFileNameEditing, + saveModelMetadata +} from './ModelMetadata.js'; +import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js'; +import { updateCheckpointCard } from 
'../../utils/cardUpdater.js'; + +/** + * Display the checkpoint modal with the given checkpoint data + * @param {Object} checkpoint - Checkpoint data object + */ +export function showCheckpointModal(checkpoint) { + const content = ` + + `; + + modalManager.showModal('checkpointModal', content); + setupEditableFields(checkpoint.file_path); + setupShowcaseScroll(); + setupTabSwitching(); + setupTagTooltip(); + setupModelNameEditing(checkpoint.file_path); + setupBaseModelEditing(checkpoint.file_path); + setupFileNameEditing(checkpoint.file_path); + + // If we have a model ID but no description, fetch it + if (checkpoint.civitai?.modelId && !checkpoint.modelDescription) { + loadModelDescription(checkpoint.civitai.modelId, checkpoint.file_path); + } +} + +/** + * Set up editable fields in the checkpoint modal +* @param {string} filePath - The full file path of the model. + */ +function setupEditableFields(filePath) { + const editableFields = document.querySelectorAll('.editable-field [contenteditable]'); + + editableFields.forEach(field => { + field.addEventListener('focus', function() { + if (this.textContent === 'Add your notes here...') { + this.textContent = ''; + } + }); + + field.addEventListener('blur', function() { + if (this.textContent.trim() === '') { + if (this.classList.contains('notes-content')) { + this.textContent = 'Add your notes here...'; + } + } + }); + }); + + // Add keydown event listeners for notes + const notesContent = document.querySelector('.notes-content'); + if (notesContent) { + notesContent.addEventListener('keydown', async function(e) { + if (e.key === 'Enter') { + if (e.shiftKey) { + // Allow shift+enter for new line + return; + } + e.preventDefault(); + await saveNotes(filePath); + } + }); + } +} + +/** + * Save checkpoint notes + * @param {string} filePath - Path to the checkpoint file + */ +async function saveNotes(filePath) { + const content = document.querySelector('.notes-content').textContent; + try { + await 
saveModelMetadata(filePath, { notes: content }); + + // Update the corresponding checkpoint card's dataset + updateCheckpointCard(filePath, { notes: content }); + + showToast('Notes saved successfully', 'success'); + } catch (error) { + showToast('Failed to save notes', 'error'); + } +} + +// Export the checkpoint modal API +const checkpointModal = { + show: showCheckpointModal, + toggleShowcase, + scrollToTop +}; + +export { checkpointModal }; + +// Define global functions for use in HTML +window.toggleShowcase = function(element) { + toggleShowcase(element); +}; + +window.scrollToTopCheckpoint = function(button) { + scrollToTop(button); +}; + +window.saveCheckpointNotes = function(filePath) { + saveNotes(filePath); +}; \ No newline at end of file diff --git a/static/js/components/checkpointModal/utils.js b/static/js/components/checkpointModal/utils.js new file mode 100644 index 00000000..62bc59fe --- /dev/null +++ b/static/js/components/checkpointModal/utils.js @@ -0,0 +1,74 @@ +/** + * utils.js + * CheckpointModal component utility functions + */ +import { showToast } from '../../utils/uiHelpers.js'; + +/** + * Format file size for display + * @param {number} bytes - File size in bytes + * @returns {string} - Formatted file size + */ +export function formatFileSize(bytes) { + if (!bytes) return 'N/A'; + + const units = ['B', 'KB', 'MB', 'GB', 'TB']; + let size = bytes; + let unitIndex = 0; + + while (size >= 1024 && unitIndex < units.length - 1) { + size /= 1024; + unitIndex++; + } + + return `${size.toFixed(1)} ${units[unitIndex]}`; +} + +/** + * Render compact tags + * @param {Array} tags - Array of tags + * @returns {string} HTML content + */ +export function renderCompactTags(tags) { + if (!tags || tags.length === 0) return ''; + + // Display up to 5 tags, with a tooltip indicator if there are more + const visibleTags = tags.slice(0, 5); + const remainingCount = Math.max(0, tags.length - 5); + + return ` +
+
+ ${visibleTags.map(tag => `${tag}`).join('')} + ${remainingCount > 0 ? + `+${remainingCount}` : + ''} +
+ ${tags.length > 0 ? + `
+
+ ${tags.map(tag => `${tag}`).join('')} +
+
` : + ''} +
+ `; +} + +/** + * Set up tag tooltip functionality + */ +export function setupTagTooltip() { + const tagsContainer = document.querySelector('.model-tags-container'); + const tooltip = document.querySelector('.model-tags-tooltip'); + + if (tagsContainer && tooltip) { + tagsContainer.addEventListener('mouseenter', () => { + tooltip.classList.add('visible'); + }); + + tagsContainer.addEventListener('mouseleave', () => { + tooltip.classList.remove('visible'); + }); + } +} \ No newline at end of file diff --git a/static/js/components/controls/CheckpointsControls.js b/static/js/components/controls/CheckpointsControls.js new file mode 100644 index 00000000..8cc323f1 --- /dev/null +++ b/static/js/components/controls/CheckpointsControls.js @@ -0,0 +1,60 @@ +// CheckpointsControls.js - Specific implementation for the Checkpoints page +import { PageControls } from './PageControls.js'; +import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js'; +import { showToast } from '../../utils/uiHelpers.js'; +import { CheckpointDownloadManager } from '../../managers/CheckpointDownloadManager.js'; + +/** + * CheckpointsControls class - Extends PageControls for Checkpoint-specific functionality + */ +export class CheckpointsControls extends PageControls { + constructor() { + // Initialize with 'checkpoints' page type + super('checkpoints'); + + // Initialize checkpoint download manager + this.downloadManager = new CheckpointDownloadManager(); + + // Register API methods specific to the Checkpoints page + this.registerCheckpointsAPI(); + } + + /** + * Register Checkpoint-specific API methods + */ + registerCheckpointsAPI() { + const checkpointsAPI = { + // Core API functions + loadMoreModels: async (resetPage = false, updateFolders = false) => { + return await loadMoreCheckpoints(resetPage, updateFolders); + }, + + resetAndReload: async (updateFolders = false) => { + return await resetAndReload(updateFolders); + }, + + 
refreshModels: async () => { + return await refreshCheckpoints(); + }, + + // Add fetch from Civitai functionality for checkpoints + fetchFromCivitai: async () => { + return await fetchCivitai(); + }, + + // Add show download modal functionality + showDownloadModal: () => { + this.downloadManager.showDownloadModal(); + }, + + // No clearCustomFilter implementation is needed for checkpoints + // as custom filters are currently only used for LoRAs + clearCustomFilter: async () => { + showToast('No custom filter to clear', 'info'); + } + }; + + // Register the API + this.registerAPI(checkpointsAPI); + } +} \ No newline at end of file diff --git a/static/js/components/controls/LorasControls.js b/static/js/components/controls/LorasControls.js new file mode 100644 index 00000000..0a45a9e6 --- /dev/null +++ b/static/js/components/controls/LorasControls.js @@ -0,0 +1,146 @@ +// LorasControls.js - Specific implementation for the LoRAs page +import { PageControls } from './PageControls.js'; +import { loadMoreLoras, fetchCivitai, resetAndReload, refreshLoras } from '../../api/loraApi.js'; +import { getSessionItem, removeSessionItem } from '../../utils/storageHelpers.js'; +import { showToast } from '../../utils/uiHelpers.js'; + +/** + * LorasControls class - Extends PageControls for LoRA-specific functionality + */ +export class LorasControls extends PageControls { + constructor() { + // Initialize with 'loras' page type + super('loras'); + + // Register API methods specific to the LoRAs page + this.registerLorasAPI(); + + // Check for custom filters (e.g., from recipe navigation) + this.checkCustomFilters(); + } + + /** + * Register LoRA-specific API methods + */ + registerLorasAPI() { + const lorasAPI = { + // Core API functions + loadMoreModels: async (resetPage = false, updateFolders = false) => { + return await loadMoreLoras(resetPage, updateFolders); + }, + + resetAndReload: async (updateFolders = false) => { + return await resetAndReload(updateFolders); + }, + + 
refreshModels: async () => { + return await refreshLoras(); + }, + + // LoRA-specific API functions + fetchFromCivitai: async () => { + return await fetchCivitai(); + }, + + showDownloadModal: () => { + if (window.downloadManager) { + window.downloadManager.showDownloadModal(); + } else { + console.error('Download manager not available'); + } + }, + + toggleBulkMode: () => { + if (window.bulkManager) { + window.bulkManager.toggleBulkMode(); + } else { + console.error('Bulk manager not available'); + } + }, + + clearCustomFilter: async () => { + await this.clearCustomFilter(); + } + }; + + // Register the API + this.registerAPI(lorasAPI); + } + + /** + * Check for custom filter parameters in session storage (e.g., from recipe page navigation) + */ + checkCustomFilters() { + const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash'); + const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes'); + const filterRecipeName = getSessionItem('filterRecipeName'); + const viewLoraDetail = getSessionItem('viewLoraDetail'); + + if ((filterLoraHash || filterLoraHashes) && filterRecipeName) { + // Found custom filter parameters, set up the custom filter + + // Show the filter indicator + const indicator = document.getElementById('customFilterIndicator'); + const filterText = indicator?.querySelector('.customFilterText'); + + if (indicator && filterText) { + indicator.classList.remove('hidden'); + + // Set text content with recipe name + const filterType = filterLoraHash && viewLoraDetail ? 
"Viewing LoRA from" : "Viewing LoRAs from"; + const displayText = `${filterType}: ${filterRecipeName}`; + + filterText.textContent = this._truncateText(displayText, 30); + filterText.setAttribute('title', displayText); + + // Add pulse animation + const filterElement = indicator.querySelector('.filter-active'); + if (filterElement) { + filterElement.classList.add('animate'); + setTimeout(() => filterElement.classList.remove('animate'), 600); + } + } + + // If we're viewing a specific LoRA detail, set up to open the modal + if (filterLoraHash && viewLoraDetail) { + this.pageState.pendingLoraHash = filterLoraHash; + } + } + } + + /** + * Clear the custom filter and reload the page + */ + async clearCustomFilter() { + console.log("Clearing custom filter..."); + // Remove filter parameters from session storage + removeSessionItem('recipe_to_lora_filterLoraHash'); + removeSessionItem('recipe_to_lora_filterLoraHashes'); + removeSessionItem('filterRecipeName'); + removeSessionItem('viewLoraDetail'); + + // Hide the filter indicator + const indicator = document.getElementById('customFilterIndicator'); + if (indicator) { + indicator.classList.add('hidden'); + } + + // Reset state + if (this.pageState.pendingLoraHash) { + delete this.pageState.pendingLoraHash; + } + + // Reload the loras + await resetAndReload(); + } + + /** + * Helper to truncate text with ellipsis + * @param {string} text - Text to truncate + * @param {number} maxLength - Maximum length before truncating + * @returns {string} - Truncated text + */ + _truncateText(text, maxLength) { + return text.length > maxLength ? text.substring(0, maxLength - 3) + '...' 
: text; + } +} \ No newline at end of file diff --git a/static/js/components/controls/PageControls.js b/static/js/components/controls/PageControls.js new file mode 100644 index 00000000..b30151c1 --- /dev/null +++ b/static/js/components/controls/PageControls.js @@ -0,0 +1,388 @@ +// PageControls.js - Manages controls for both LoRAs and Checkpoints pages +import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js'; +import { getStorageItem, setStorageItem } from '../../utils/storageHelpers.js'; +import { showToast } from '../../utils/uiHelpers.js'; + +/** + * PageControls class - Unified control management for model pages + */ +export class PageControls { + constructor(pageType) { + // Set the current page type in state + setCurrentPageType(pageType); + + // Store the page type + this.pageType = pageType; + + // Get the current page state + this.pageState = getCurrentPageState(); + + // Initialize state based on page type + this.initializeState(); + + // Store API methods + this.api = null; + + // Initialize event listeners + this.initEventListeners(); + + console.log(`PageControls initialized for ${pageType} page`); + } + + /** + * Initialize state based on page type + */ + initializeState() { + // Set default values + this.pageState.pageSize = 20; + this.pageState.isLoading = false; + this.pageState.hasMore = true; + + // Load sort preference + this.loadSortPreference(); + } + + /** + * Register API methods for the page + * @param {Object} api - API methods for the page + */ + registerAPI(api) { + this.api = api; + console.log(`API methods registered for ${this.pageType} page`); + } + + /** + * Initialize event listeners for controls + */ + initEventListeners() { + // Sort select handler + const sortSelect = document.getElementById('sortSelect'); + if (sortSelect) { + sortSelect.value = this.pageState.sortBy; + sortSelect.addEventListener('change', async (e) => { + this.pageState.sortBy = e.target.value; + 
this.saveSortPreference(e.target.value); + await this.resetAndReload(); + }); + } + + // Use event delegation for folder tags - this is the key fix + const folderTagsContainer = document.querySelector('.folder-tags-container'); + if (folderTagsContainer) { + folderTagsContainer.addEventListener('click', (e) => { + const tag = e.target.closest('.tag'); + if (tag) { + this.handleFolderClick(tag); + } + }); + } + + // Refresh button handler + const refreshBtn = document.querySelector('[data-action="refresh"]'); + if (refreshBtn) { + refreshBtn.addEventListener('click', () => this.refreshModels()); + } + + // Toggle folders button + const toggleFoldersBtn = document.querySelector('.toggle-folders-btn'); + if (toggleFoldersBtn) { + toggleFoldersBtn.addEventListener('click', () => this.toggleFolderTags()); + } + + // Clear custom filter handler + const clearFilterBtn = document.querySelector('.clear-filter'); + if (clearFilterBtn) { + clearFilterBtn.addEventListener('click', () => this.clearCustomFilter()); + } + + // Page-specific event listeners + this.initPageSpecificListeners(); + } + + /** + * Initialize page-specific event listeners + */ + initPageSpecificListeners() { + // Fetch from Civitai button - available for both loras and checkpoints + const fetchButton = document.querySelector('[data-action="fetch"]'); + if (fetchButton) { + fetchButton.addEventListener('click', () => this.fetchFromCivitai()); + } + + const downloadButton = document.querySelector('[data-action="download"]'); + if (downloadButton) { + downloadButton.addEventListener('click', () => this.showDownloadModal()); + } + + if (this.pageType === 'loras') { + // Bulk operations button - LoRAs only + const bulkButton = document.querySelector('[data-action="bulk"]'); + if (bulkButton) { + bulkButton.addEventListener('click', () => this.toggleBulkMode()); + } + } + } + + /** + * Toggle folder selection + * @param {HTMLElement} tagElement - The folder tag element that was clicked + */ + 
handleFolderClick(tagElement) { + const folder = tagElement.dataset.folder; + const wasActive = tagElement.classList.contains('active'); + + document.querySelectorAll('.folder-tags .tag').forEach(t => { + t.classList.remove('active'); + }); + + if (!wasActive) { + tagElement.classList.add('active'); + this.pageState.activeFolder = folder; + setStorageItem(`${this.pageType}_activeFolder`, folder); + } else { + this.pageState.activeFolder = null; + setStorageItem(`${this.pageType}_activeFolder`, null); + } + + this.resetAndReload(); + } + + /** + * Restore folder filter from storage + */ + restoreFolderFilter() { + const activeFolder = getStorageItem(`${this.pageType}_activeFolder`); + const folderTag = activeFolder && document.querySelector(`.tag[data-folder="${activeFolder}"]`); + + if (folderTag) { + folderTag.classList.add('active'); + this.pageState.activeFolder = activeFolder; + this.filterByFolder(activeFolder); + } + } + + /** + * Filter displayed cards by folder + * @param {string} folderPath - Folder path to filter by + */ + filterByFolder(folderPath) { + const cardSelector = this.pageType === 'loras' ? '.lora-card' : '.checkpoint-card'; + document.querySelectorAll(cardSelector).forEach(card => { + card.style.display = card.dataset.folder === folderPath ? '' : 'none'; + }); + } + + /** + * Update the folder tags display with new folder list + * @param {Array} folders - List of folder names + */ + updateFolderTags(folders) { + const folderTagsContainer = document.querySelector('.folder-tags'); + if (!folderTagsContainer) return; + + // Keep track of currently selected folder + const currentFolder = this.pageState.activeFolder; + + // Create HTML for folder tags + const tagsHTML = folders.map(folder => { + const isActive = folder === currentFolder; + return `
${folder}
`; + }).join(''); + + // Update the container + folderTagsContainer.innerHTML = tagsHTML; + + // Scroll active folder into view (no need to reattach click handlers) + const activeTag = folderTagsContainer.querySelector(`.tag[data-folder="${currentFolder}"]`); + if (activeTag) { + activeTag.scrollIntoView({ behavior: 'smooth', block: 'nearest' }); + } + } + + /** + * Toggle visibility of folder tags + */ + toggleFolderTags() { + const folderTags = document.querySelector('.folder-tags'); + const toggleBtn = document.querySelector('.toggle-folders-btn i'); + + if (folderTags) { + folderTags.classList.toggle('collapsed'); + + if (folderTags.classList.contains('collapsed')) { + // Change icon to indicate folders are hidden + toggleBtn.className = 'fas fa-folder-plus'; + toggleBtn.parentElement.title = 'Show folder tags'; + setStorageItem('folderTagsCollapsed', 'true'); + } else { + // Change icon to indicate folders are visible + toggleBtn.className = 'fas fa-folder-minus'; + toggleBtn.parentElement.title = 'Hide folder tags'; + setStorageItem('folderTagsCollapsed', 'false'); + } + } + } + + /** + * Initialize folder tags visibility based on stored preference + */ + initFolderTagsVisibility() { + const isCollapsed = getStorageItem('folderTagsCollapsed'); + if (isCollapsed) { + const folderTags = document.querySelector('.folder-tags'); + const toggleBtn = document.querySelector('.toggle-folders-btn i'); + if (folderTags) { + folderTags.classList.add('collapsed'); + } + if (toggleBtn) { + toggleBtn.className = 'fas fa-folder-plus'; + toggleBtn.parentElement.title = 'Show folder tags'; + } + } else { + const toggleBtn = document.querySelector('.toggle-folders-btn i'); + if (toggleBtn) { + toggleBtn.className = 'fas fa-folder-minus'; + toggleBtn.parentElement.title = 'Hide folder tags'; + } + } + } + + /** + * Load sort preference from storage + */ + loadSortPreference() { + const savedSort = getStorageItem(`${this.pageType}_sort`); + if (savedSort) { + 
this.pageState.sortBy = savedSort; + const sortSelect = document.getElementById('sortSelect'); + if (sortSelect) { + sortSelect.value = savedSort; + } + } + } + + /** + * Save sort preference to storage + * @param {string} sortValue - The sort value to save + */ + saveSortPreference(sortValue) { + setStorageItem(`${this.pageType}_sort`, sortValue); + } + + /** + * Open model page on Civitai + * @param {string} modelName - Name of the model + */ + openCivitai(modelName) { + // Get card selector based on page type + const cardSelector = this.pageType === 'loras' + ? `.lora-card[data-name="${modelName}"]` + : `.checkpoint-card[data-name="${modelName}"]`; + + const card = document.querySelector(cardSelector); + if (!card) return; + + const metaData = JSON.parse(card.dataset.meta); + const civitaiId = metaData.modelId; + const versionId = metaData.id; + + // Build URL + if (civitaiId) { + let url = `https://civitai.com/models/${civitaiId}`; + if (versionId) { + url += `?modelVersionId=${versionId}`; + } + window.open(url, '_blank'); + } else { + // If no ID, try searching by name + window.open(`https://civitai.com/models?query=${encodeURIComponent(modelName)}`, '_blank'); + } + } + + /** + * Reset and reload the models list + */ + async resetAndReload(updateFolders = false) { + if (!this.api) { + console.error('API methods not registered'); + return; + } + + try { + await this.api.resetAndReload(updateFolders); + } catch (error) { + console.error(`Error reloading ${this.pageType}:`, error); + showToast(`Failed to reload ${this.pageType}: ${error.message}`, 'error'); + } + } + + /** + * Refresh models list + */ + async refreshModels() { + if (!this.api) { + console.error('API methods not registered'); + return; + } + + try { + await this.api.refreshModels(); + } catch (error) { + console.error(`Error refreshing ${this.pageType}:`, error); + showToast(`Failed to refresh ${this.pageType}: ${error.message}`, 'error'); + } + } + + /** + * Fetch metadata from Civitai 
(available for both LoRAs and Checkpoints) + */ + async fetchFromCivitai() { + if (!this.api) { + console.error('API methods not registered'); + return; + } + + try { + await this.api.fetchFromCivitai(); + } catch (error) { + console.error('Error fetching metadata:', error); + showToast('Failed to fetch metadata: ' + error.message, 'error'); + } + } + + /** + * Show download modal + */ + showDownloadModal() { + this.api.showDownloadModal(); + } + + /** + * Toggle bulk mode (LoRAs only) + */ + toggleBulkMode() { + if (this.pageType !== 'loras' || !this.api) { + console.error('Bulk mode is only available for LoRAs'); + return; + } + + this.api.toggleBulkMode(); + } + + /** + * Clear custom filter + */ + async clearCustomFilter() { + if (!this.api) { + console.error('API methods not registered'); + return; + } + + try { + await this.api.clearCustomFilter(); + } catch (error) { + console.error('Error clearing custom filter:', error); + showToast('Failed to clear custom filter: ' + error.message, 'error'); + } + } +} \ No newline at end of file diff --git a/static/js/components/controls/index.js b/static/js/components/controls/index.js new file mode 100644 index 00000000..c767c62f --- /dev/null +++ b/static/js/components/controls/index.js @@ -0,0 +1,23 @@ +// Controls components index file +import { PageControls } from './PageControls.js'; +import { LorasControls } from './LorasControls.js'; +import { CheckpointsControls } from './CheckpointsControls.js'; + +// Export the classes +export { PageControls, LorasControls, CheckpointsControls }; + +/** + * Factory function to create the appropriate controls based on page type + * @param {string} pageType - The type of page ('loras' or 'checkpoints') + * @returns {PageControls} - The appropriate controls instance + */ +export function createPageControls(pageType) { + if (pageType === 'loras') { + return new LorasControls(); + } else if (pageType === 'checkpoints') { + return new CheckpointsControls(); + } else { + 
console.error(`Unknown page type: ${pageType}`); + return null; + } +} \ No newline at end of file diff --git a/static/js/components/loraModal/ModelMetadata.js b/static/js/components/loraModal/ModelMetadata.js index 879ec861..eb0a3d18 100644 --- a/static/js/components/loraModal/ModelMetadata.js +++ b/static/js/components/loraModal/ModelMetadata.js @@ -4,6 +4,7 @@ */ import { showToast } from '../../utils/uiHelpers.js'; import { BASE_MODELS } from '../../utils/constants.js'; +import { updateLoraCard } from '../../utils/cardUpdater.js'; /** * 保存模型元数据到服务器 @@ -12,7 +13,7 @@ import { BASE_MODELS } from '../../utils/constants.js'; * @returns {Promise} 保存操作的Promise */ export async function saveModelMetadata(filePath, data) { - const response = await fetch('/loras/api/save-metadata', { + const response = await fetch('/api/loras/save-metadata', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -32,13 +33,17 @@ export async function saveModelMetadata(filePath, data) { /** * 设置模型名称编辑功能 + * @param {string} filePath - 文件路径 */ -export function setupModelNameEditing() { +export function setupModelNameEditing(filePath) { const modelNameContent = document.querySelector('.model-name-content'); const editBtn = document.querySelector('.edit-model-name-btn'); if (!modelNameContent || !editBtn) return; + // Store the file path in a data attribute for later use + modelNameContent.dataset.filePath = filePath; + // Show edit button on hover const modelNameHeader = document.querySelector('.model-name-header'); modelNameHeader.addEventListener('mouseenter', () => { @@ -76,10 +81,7 @@ export function setupModelNameEditing() { if (this.textContent.trim() === '') { // Restore original model name if empty - const filePath = document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#loraModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; + const filePath = this.dataset.filePath; 
const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); if (loraCard) { this.textContent = loraCard.dataset.model_name; @@ -91,10 +93,7 @@ export function setupModelNameEditing() { modelNameContent.addEventListener('keydown', function(e) { if (e.key === 'Enter') { e.preventDefault(); - const filePath = document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#loraModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; + const filePath = this.dataset.filePath; saveModelName(filePath); this.blur(); } @@ -144,21 +143,9 @@ async function saveModelName(filePath) { await saveModelMetadata(filePath, { model_name: newModelName }); // Update the corresponding lora card's dataset and display - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (loraCard) { - loraCard.dataset.model_name = newModelName; - const titleElement = loraCard.querySelector('.card-title'); - if (titleElement) { - titleElement.textContent = newModelName; - } - } + updateLoraCard(filePath, { model_name: newModelName }); showToast('Model name updated successfully', 'success'); - - // Reload the page to reflect the sorted order - setTimeout(() => { - window.location.reload(); - }, 1500); } catch (error) { showToast('Failed to update model name', 'error'); } @@ -166,13 +153,17 @@ async function saveModelName(filePath) { /** * 设置基础模型编辑功能 + * @param {string} filePath - 文件路径 */ -export function setupBaseModelEditing() { +export function setupBaseModelEditing(filePath) { const baseModelContent = document.querySelector('.base-model-content'); const editBtn = document.querySelector('.edit-base-model-btn'); if (!baseModelContent || !editBtn) return; + // Store the file path in a data attribute for later use + baseModelContent.dataset.filePath = filePath; + // Show edit button on hover const baseModelDisplay = document.querySelector('.base-model-display'); 
baseModelDisplay.addEventListener('mouseenter', () => { @@ -270,11 +261,8 @@ export function setupBaseModelEditing() { // Only save if the value has actually changed if (valueChanged || baseModelContent.textContent.trim() !== originalValue) { - // Get file path for saving - const filePath = document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#loraModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; + // Get file path from the dataset + const filePath = baseModelContent.dataset.filePath; // Save the changes, passing the original value for comparison saveBaseModel(filePath, originalValue); @@ -325,10 +313,7 @@ async function saveBaseModel(filePath, originalValue) { await saveModelMetadata(filePath, { base_model: newBaseModel }); // Update the corresponding lora card's dataset - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (loraCard) { - loraCard.dataset.base_model = newBaseModel; - } + updateLoraCard(filePath, { base_model: newBaseModel }); showToast('Base model updated successfully', 'success'); } catch (error) { @@ -338,13 +323,17 @@ async function saveBaseModel(filePath, originalValue) { /** * 设置文件名编辑功能 + * @param {string} filePath - 文件路径 */ -export function setupFileNameEditing() { +export function setupFileNameEditing(filePath) { const fileNameContent = document.querySelector('.file-name-content'); const editBtn = document.querySelector('.edit-file-name-btn'); if (!fileNameContent || !editBtn) return; + // Store the original file path + fileNameContent.dataset.filePath = filePath; + // Show edit button on hover const fileNameWrapper = document.querySelector('.file-name-wrapper'); fileNameWrapper.addEventListener('mouseenter', () => { @@ -441,9 +430,8 @@ export function setupFileNameEditing() { } try { - // Get the full file path - const filePath = document.querySelector('#loraModal .modal-content') - 
.querySelector('.file-path').textContent + originalValue + '.safetensors'; + // Get the file path from the dataset + const filePath = this.dataset.filePath; // Call API to rename the file const response = await fetch('/api/rename_lora', { @@ -462,17 +450,10 @@ export function setupFileNameEditing() { if (result.success) { showToast('File name updated successfully', 'success'); - // Update the LoRA card with new file path - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (loraCard) { - const newFilePath = filePath.replace(originalValue, newFileName); - loraCard.dataset.filepath = newFilePath; - } - - // Reload the page after a short delay to reflect changes - setTimeout(() => { - window.location.reload(); - }, 1500); + // Get the new file path and update the card + const newFilePath = filePath.replace(originalValue, newFileName); + // Pass the new file_name in the updates object for proper card update + updateLoraCard(filePath, { file_name: newFileName }, newFilePath); } else { throw new Error(result.error || 'Unknown error'); } diff --git a/static/js/components/loraModal/index.js b/static/js/components/loraModal/index.js index 0cc44203..017a1b2a 100644 --- a/static/js/components/loraModal/index.js +++ b/static/js/components/loraModal/index.js @@ -18,6 +18,7 @@ import { saveModelMetadata } from './ModelMetadata.js'; import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js'; +import { updateLoraCard } from '../../utils/cardUpdater.js'; /** * 显示LoRA模型弹窗 @@ -152,14 +153,14 @@ export function showLoraModal(lora) { `; modalManager.showModal('loraModal', content); - setupEditableFields(); + setupEditableFields(lora.file_path); setupShowcaseScroll(); setupTabSwitching(); setupTagTooltip(); setupTriggerWordsEditMode(); - setupModelNameEditing(); - setupBaseModelEditing(); - setupFileNameEditing(); + setupModelNameEditing(lora.file_path); + setupBaseModelEditing(lora.file_path); + 
setupFileNameEditing(lora.file_path); // If we have a model ID but no description, fetch it if (lora.civitai?.modelId && !lora.modelDescription) { @@ -188,10 +189,7 @@ window.saveNotes = async function(filePath) { await saveModelMetadata(filePath, { notes: content }); // Update the corresponding lora card's dataset - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (loraCard) { - loraCard.dataset.notes = content; - } + updateLoraCard(filePath, { notes: content }); showToast('Notes saved successfully', 'success'); } catch (error) { @@ -199,7 +197,7 @@ window.saveNotes = async function(filePath) { } }; -function setupEditableFields() { +function setupEditableFields(filePath) { const editableFields = document.querySelectorAll('.editable-field [contenteditable]'); editableFields.forEach(field => { @@ -247,11 +245,6 @@ function setupEditableFields() { if (!key || !value) return; - const filePath = document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#loraModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; - const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); const currentPresets = parsePresets(loraCard.dataset.usage_tips); @@ -262,7 +255,9 @@ function setupEditableFields() { usage_tips: newPresetsJson }); - loraCard.dataset.usage_tips = newPresetsJson; + // Update the card with the new usage tips + updateLoraCard(filePath, { usage_tips: newPresetsJson }); + presetTags.innerHTML = renderPresetTags(currentPresets); presetSelector.value = ''; @@ -280,10 +275,6 @@ function setupEditableFields() { return; } e.preventDefault(); - const filePath = document.querySelector('#loraModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#loraModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; await saveNotes(filePath); } }); diff --git 
a/static/js/loras.js b/static/js/loras.js index d9750786..24543725 100644 --- a/static/js/loras.js +++ b/static/js/loras.js @@ -1,23 +1,14 @@ import { appCore } from './core.js'; import { state } from './state/index.js'; import { showLoraModal, toggleShowcase, scrollToTop } from './components/loraModal/index.js'; -import { loadMoreLoras, fetchCivitai, deleteModel, replacePreview, resetAndReload, refreshLoras } from './api/loraApi.js'; -import { - restoreFolderFilter, - toggleFolder, - copyTriggerWord, - openCivitai, - toggleFolderTags, - initFolderTagsVisibility, -} from './utils/uiHelpers.js'; -import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; -import { DownloadManager } from './managers/DownloadManager.js'; -import { toggleApiKeyVisibility } from './managers/SettingsManager.js'; -import { LoraContextMenu } from './components/ContextMenu.js'; -import { moveManager } from './managers/MoveManager.js'; +import { loadMoreLoras } from './api/loraApi.js'; import { updateCardsForBulkMode } from './components/LoraCard.js'; import { bulkManager } from './managers/BulkManager.js'; -import { setStorageItem, getStorageItem, getSessionItem, removeSessionItem } from './utils/storageHelpers.js'; +import { DownloadManager } from './managers/DownloadManager.js'; +import { moveManager } from './managers/MoveManager.js'; +import { LoraContextMenu } from './components/ContextMenu.js'; +import { createPageControls } from './components/controls/index.js'; +import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; // Initialize the LoRA page class LoraPageManager { @@ -29,25 +20,21 @@ class LoraPageManager { // Initialize managers this.downloadManager = new DownloadManager(); - // Expose necessary functions to the page - this._exposeGlobalFunctions(); + // Initialize page controls + this.pageControls = createPageControls('loras'); + + // Expose necessary functions to the page that still need global access + // These will be refactored in future 
updates + this._exposeRequiredGlobalFunctions(); } - _exposeGlobalFunctions() { - // Only expose what's needed for the page + _exposeRequiredGlobalFunctions() { + // Only expose what's still needed globally + // Most functionality is now handled by the PageControls component window.loadMoreLoras = loadMoreLoras; - window.fetchCivitai = fetchCivitai; - window.deleteModel = deleteModel; - window.replacePreview = replacePreview; - window.toggleFolder = toggleFolder; - window.copyTriggerWord = copyTriggerWord; window.showLoraModal = showLoraModal; window.confirmDelete = confirmDelete; window.closeDeleteModal = closeDeleteModal; - window.refreshLoras = refreshLoras; - window.openCivitai = openCivitai; - window.toggleFolderTags = toggleFolderTags; - window.toggleApiKeyVisibility = toggleApiKeyVisibility; window.downloadManager = this.downloadManager; window.moveManager = moveManager; window.toggleShowcase = toggleShowcase; @@ -64,14 +51,10 @@ class LoraPageManager { async initialize() { // Initialize page-specific components - this.initEventListeners(); - restoreFolderFilter(); - initFolderTagsVisibility(); + this.pageControls.restoreFolderFilter(); + this.pageControls.initFolderTagsVisibility(); new LoraContextMenu(); - // Check for custom filters from recipe page navigation - this.checkCustomFilters(); - // Initialize cards for current bulk mode state (should be false initially) updateCardsForBulkMode(state.bulkMode); @@ -81,119 +64,6 @@ class LoraPageManager { // Initialize common page features (lazy loading, infinite scroll) appCore.initializePageFeatures(); } - - // Check for custom filter parameters in session storage - checkCustomFilters() { - const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash'); - const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes'); - const filterRecipeName = getSessionItem('filterRecipeName'); - const viewLoraDetail = getSessionItem('viewLoraDetail'); - - console.log("Checking custom filters..."); - 
console.log("filterLoraHash:", filterLoraHash); - console.log("filterLoraHashes:", filterLoraHashes); - console.log("filterRecipeName:", filterRecipeName); - console.log("viewLoraDetail:", viewLoraDetail); - - if ((filterLoraHash || filterLoraHashes) && filterRecipeName) { - // Found custom filter parameters, set up the custom filter - - // Show the filter indicator - const indicator = document.getElementById('customFilterIndicator'); - const filterText = indicator.querySelector('.customFilterText'); - - if (indicator && filterText) { - indicator.classList.remove('hidden'); - - // Set text content with recipe name - const filterType = filterLoraHash && viewLoraDetail ? "Viewing LoRA from" : "Viewing LoRAs from"; - const displayText = `${filterType}: ${filterRecipeName}`; - - filterText.textContent = this._truncateText(displayText, 30); - filterText.setAttribute('title', displayText); - - // Add click handler for the clear button - const clearBtn = indicator.querySelector('.clear-filter'); - if (clearBtn) { - clearBtn.addEventListener('click', this.clearCustomFilter); - } - - // Add pulse animation - const filterElement = indicator.querySelector('.filter-active'); - if (filterElement) { - filterElement.classList.add('animate'); - setTimeout(() => filterElement.classList.remove('animate'), 600); - } - } - - // If we're viewing a specific LoRA detail, set up to open the modal - if (filterLoraHash && viewLoraDetail) { - // Store this to fetch after initial load completes - state.pendingLoraHash = filterLoraHash; - } - } - } - - // Helper to truncate text with ellipsis - _truncateText(text, maxLength) { - return text.length > maxLength ? text.substring(0, maxLength - 3) + '...' 
: text; - } - - // Clear the custom filter and reload the page - clearCustomFilter = async () => { - console.log("Clearing custom filter..."); - // Remove filter parameters from session storage - removeSessionItem('recipe_to_lora_filterLoraHash'); - removeSessionItem('recipe_to_lora_filterLoraHashes'); - removeSessionItem('filterRecipeName'); - removeSessionItem('viewLoraDetail'); - - // Hide the filter indicator - const indicator = document.getElementById('customFilterIndicator'); - if (indicator) { - indicator.classList.add('hidden'); - } - - // Reset state - if (state.pendingLoraHash) { - delete state.pendingLoraHash; - } - - // Reload the loras - await resetAndReload(); - } - - loadSortPreference() { - const savedSort = getStorageItem('loras_sort'); - if (savedSort) { - state.sortBy = savedSort; - const sortSelect = document.getElementById('sortSelect'); - if (sortSelect) { - sortSelect.value = savedSort; - } - } - } - - saveSortPreference(sortValue) { - setStorageItem('loras_sort', sortValue); - } - - initEventListeners() { - const sortSelect = document.getElementById('sortSelect'); - if (sortSelect) { - sortSelect.value = state.sortBy; - this.loadSortPreference(); - sortSelect.addEventListener('change', async (e) => { - state.sortBy = e.target.value; - this.saveSortPreference(e.target.value); - await resetAndReload(); - }); - } - - document.querySelectorAll('.folder-tags .tag').forEach(tag => { - tag.addEventListener('click', toggleFolder); - }); - } } // Initialize everything when DOM is ready diff --git a/static/js/managers/CheckpointDownloadManager.js b/static/js/managers/CheckpointDownloadManager.js new file mode 100644 index 00000000..5dfb235c --- /dev/null +++ b/static/js/managers/CheckpointDownloadManager.js @@ -0,0 +1,423 @@ +import { modalManager } from './ModalManager.js'; +import { showToast } from '../utils/uiHelpers.js'; +import { LoadingManager } from './LoadingManager.js'; +import { state } from '../state/index.js'; +import { resetAndReload } 
from '../api/checkpointApi.js'; +import { getStorageItem } from '../utils/storageHelpers.js'; + +export class CheckpointDownloadManager { + constructor() { + this.currentVersion = null; + this.versions = []; + this.modelInfo = null; + this.modelVersionId = null; + + this.initialized = false; + this.selectedFolder = ''; + + this.loadingManager = new LoadingManager(); + this.folderClickHandler = null; + this.updateTargetPath = this.updateTargetPath.bind(this); + } + + showDownloadModal() { + console.log('Showing checkpoint download modal...'); + if (!this.initialized) { + const modal = document.getElementById('checkpointDownloadModal'); + if (!modal) { + console.error('Checkpoint download modal element not found'); + return; + } + this.initialized = true; + } + + modalManager.showModal('checkpointDownloadModal', null, () => { + // Cleanup handler when modal closes + this.cleanupFolderBrowser(); + }); + this.resetSteps(); + } + + resetSteps() { + document.querySelectorAll('#checkpointDownloadModal .download-step').forEach(step => step.style.display = 'none'); + document.getElementById('cpUrlStep').style.display = 'block'; + document.getElementById('checkpointUrl').value = ''; + document.getElementById('cpUrlError').textContent = ''; + + // Clear new folder input + const newFolderInput = document.getElementById('cpNewFolder'); + if (newFolderInput) { + newFolderInput.value = ''; + } + + this.currentVersion = null; + this.versions = []; + this.modelInfo = null; + this.modelVersionId = null; + + // Clear selected folder and remove selection from UI + this.selectedFolder = ''; + const folderBrowser = document.getElementById('cpFolderBrowser'); + if (folderBrowser) { + folderBrowser.querySelectorAll('.folder-item').forEach(f => + f.classList.remove('selected')); + } + } + + async validateAndFetchVersions() { + const url = document.getElementById('checkpointUrl').value.trim(); + const errorElement = document.getElementById('cpUrlError'); + + try { + 
this.loadingManager.showSimpleLoading('Fetching model versions...'); + + const modelId = this.extractModelId(url); + if (!modelId) { + throw new Error('Invalid Civitai URL format'); + } + + const response = await fetch(`/api/civitai/versions/${modelId}`); + if (!response.ok) { + throw new Error('Failed to fetch model versions'); + } + + this.versions = await response.json(); + if (!this.versions.length) { + throw new Error('No versions available for this model'); + } + + // If we have a version ID from URL, pre-select it + if (this.modelVersionId) { + this.currentVersion = this.versions.find(v => v.id.toString() === this.modelVersionId); + } + + this.showVersionStep(); + } catch (error) { + errorElement.textContent = error.message; + } finally { + this.loadingManager.hide(); + } + } + + extractModelId(url) { + const modelMatch = url.match(/civitai\.com\/models\/(\d+)/); + const versionMatch = url.match(/modelVersionId=(\d+)/); + + if (modelMatch) { + this.modelVersionId = versionMatch ? versionMatch[1] : null; + return modelMatch[1]; + } + return null; + } + + showVersionStep() { + document.getElementById('cpUrlStep').style.display = 'none'; + document.getElementById('cpVersionStep').style.display = 'block'; + + const versionList = document.getElementById('cpVersionList'); + versionList.innerHTML = this.versions.map(version => { + const firstImage = version.images?.find(img => !img.url.endsWith('.mp4')); + const thumbnailUrl = firstImage ? firstImage.url : '/loras_static/images/no-preview.png'; + + // Use version-level size or fallback to first file + const fileSize = version.modelSizeKB ? 
+ (version.modelSizeKB / 1024).toFixed(2) : + (version.files[0]?.sizeKB / 1024).toFixed(2); + + // Use version-level existsLocally flag + const existsLocally = version.existsLocally; + const localPath = version.localPath; + + // Check if this is an early access version + const isEarlyAccess = version.availability === 'EarlyAccess'; + + // Create early access badge if needed + let earlyAccessBadge = ''; + if (isEarlyAccess) { + earlyAccessBadge = ` +
+ Early Access +
+ `; + } + + // Status badge for local models + const localStatus = existsLocally ? + `
+ In Library +
${localPath || ''}
+
` : ''; + + return ` +
+
+ Version preview +
+
+
+

${version.name}

+ ${localStatus} +
+
+ ${version.baseModel ? `
${version.baseModel}
` : ''} + ${earlyAccessBadge} +
+
+ ${new Date(version.createdAt).toLocaleDateString()} + ${fileSize} MB +
+
+
+ `; + }).join(''); + + // Update Next button state based on initial selection + this.updateNextButtonState(); + } + + selectVersion(versionId) { + this.currentVersion = this.versions.find(v => v.id.toString() === versionId.toString()); + if (!this.currentVersion) return; + + document.querySelectorAll('#cpVersionList .version-item').forEach(item => { + item.classList.toggle('selected', item.querySelector('h3').textContent === this.currentVersion.name); + }); + + // Update Next button state after selection + this.updateNextButtonState(); + } + + updateNextButtonState() { + const nextButton = document.querySelector('#cpVersionStep .primary-btn'); + if (!nextButton) return; + + const existsLocally = this.currentVersion?.existsLocally; + + if (existsLocally) { + nextButton.disabled = true; + nextButton.classList.add('disabled'); + nextButton.textContent = 'Already in Library'; + } else { + nextButton.disabled = false; + nextButton.classList.remove('disabled'); + nextButton.textContent = 'Next'; + } + } + + async proceedToLocation() { + if (!this.currentVersion) { + showToast('Please select a version', 'error'); + return; + } + + // Double-check if the version exists locally + const existsLocally = this.currentVersion.existsLocally; + if (existsLocally) { + showToast('This version already exists in your library', 'info'); + return; + } + + document.getElementById('cpVersionStep').style.display = 'none'; + document.getElementById('cpLocationStep').style.display = 'block'; + + try { + // Use checkpoint roots endpoint instead of lora roots + const response = await fetch('/api/checkpoints/roots'); + if (!response.ok) { + throw new Error('Failed to fetch checkpoint roots'); + } + + const data = await response.json(); + const checkpointRoot = document.getElementById('checkpointRoot'); + checkpointRoot.innerHTML = data.roots.map(root => + `` + ).join(''); + + // Set default checkpoint root if available + const defaultRoot = getStorageItem('settings', 
{}).default_checkpoints_root; + if (defaultRoot && data.roots.includes(defaultRoot)) { + checkpointRoot.value = defaultRoot; + } + + // Initialize folder browser after loading roots + this.initializeFolderBrowser(); + } catch (error) { + showToast(error.message, 'error'); + } + } + + backToUrl() { + document.getElementById('cpVersionStep').style.display = 'none'; + document.getElementById('cpUrlStep').style.display = 'block'; + } + + backToVersions() { + document.getElementById('cpLocationStep').style.display = 'none'; + document.getElementById('cpVersionStep').style.display = 'block'; + } + + async startDownload() { + const checkpointRoot = document.getElementById('checkpointRoot').value; + const newFolder = document.getElementById('cpNewFolder').value.trim(); + + if (!checkpointRoot) { + showToast('Please select a checkpoint root directory', 'error'); + return; + } + + // Construct relative path + let targetFolder = ''; + if (this.selectedFolder) { + targetFolder = this.selectedFolder; + } + if (newFolder) { + targetFolder = targetFolder ? + `${targetFolder}/${newFolder}` : newFolder; + } + + try { + const downloadUrl = this.currentVersion.downloadUrl; + if (!downloadUrl) { + throw new Error('No download URL available'); + } + + // Show enhanced loading with progress details + const updateProgress = this.loadingManager.showDownloadProgress(1); + updateProgress(0, 0, this.currentVersion.name); + + // Setup WebSocket for progress updates + const wsProtocol = window.location.protocol === 'https:' ? 
'wss://' : 'ws://'; + const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); + + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + if (data.status === 'progress') { + // Update progress display with current progress + updateProgress(data.progress, 0, this.currentVersion.name); + + // Add more detailed status messages based on progress + if (data.progress < 3) { + this.loadingManager.setStatus(`Preparing download...`); + } else if (data.progress === 3) { + this.loadingManager.setStatus(`Downloaded preview image`); + } else if (data.progress > 3 && data.progress < 100) { + this.loadingManager.setStatus(`Downloading checkpoint file`); + } else { + this.loadingManager.setStatus(`Finalizing download...`); + } + } + }; + + ws.onerror = (error) => { + console.error('WebSocket error:', error); + // Continue with download even if WebSocket fails + }; + + // Start download using checkpoint download endpoint + const response = await fetch('/api/checkpoints/download', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + download_url: downloadUrl, + checkpoint_root: checkpointRoot, + relative_path: targetFolder + }) + }); + + if (!response.ok) { + throw new Error(await response.text()); + } + + showToast('Download completed successfully', 'success'); + modalManager.closeModal('checkpointDownloadModal'); + + // Update state and trigger reload with folder update + state.activeFolder = targetFolder; + await resetAndReload(true); // Pass true to update folders + + } catch (error) { + showToast(error.message, 'error'); + } finally { + this.loadingManager.hide(); + } + } + + initializeFolderBrowser() { + const folderBrowser = document.getElementById('cpFolderBrowser'); + if (!folderBrowser) return; + + // Cleanup existing handler if any + this.cleanupFolderBrowser(); + + // Create new handler + this.folderClickHandler = (event) => { + const folderItem = event.target.closest('.folder-item'); + 
if (!folderItem) return; + + if (folderItem.classList.contains('selected')) { + folderItem.classList.remove('selected'); + this.selectedFolder = ''; + } else { + folderBrowser.querySelectorAll('.folder-item').forEach(f => + f.classList.remove('selected')); + folderItem.classList.add('selected'); + this.selectedFolder = folderItem.dataset.folder; + } + + // Update path display after folder selection + this.updateTargetPath(); + }; + + // Add the new handler + folderBrowser.addEventListener('click', this.folderClickHandler); + + // Add event listeners for path updates + const checkpointRoot = document.getElementById('checkpointRoot'); + const newFolder = document.getElementById('cpNewFolder'); + + checkpointRoot.addEventListener('change', this.updateTargetPath); + newFolder.addEventListener('input', this.updateTargetPath); + + // Update initial path + this.updateTargetPath(); + } + + cleanupFolderBrowser() { + if (this.folderClickHandler) { + const folderBrowser = document.getElementById('cpFolderBrowser'); + if (folderBrowser) { + folderBrowser.removeEventListener('click', this.folderClickHandler); + this.folderClickHandler = null; + } + } + + // Remove path update listeners + const checkpointRoot = document.getElementById('checkpointRoot'); + const newFolder = document.getElementById('cpNewFolder'); + + if (checkpointRoot) checkpointRoot.removeEventListener('change', this.updateTargetPath); + if (newFolder) newFolder.removeEventListener('input', this.updateTargetPath); + } + + updateTargetPath() { + const pathDisplay = document.getElementById('cpTargetPathDisplay'); + const checkpointRoot = document.getElementById('checkpointRoot').value; + const newFolder = document.getElementById('cpNewFolder').value.trim(); + + let fullPath = checkpointRoot || 'Select a checkpoint root directory'; + + if (checkpointRoot) { + if (this.selectedFolder) { + fullPath += '/' + this.selectedFolder; + } + if (newFolder) { + fullPath += '/' + newFolder; + } + } + + pathDisplay.innerHTML 
= `${fullPath}`; + } +} \ No newline at end of file diff --git a/static/js/managers/CheckpointSearchManager.js b/static/js/managers/CheckpointSearchManager.js deleted file mode 100644 index 7567bc7b..00000000 --- a/static/js/managers/CheckpointSearchManager.js +++ /dev/null @@ -1,150 +0,0 @@ -/** - * CheckpointSearchManager - Specialized search manager for the Checkpoints page - * Extends the base SearchManager with checkpoint-specific functionality - */ -import { SearchManager } from './SearchManager.js'; -import { state } from '../state/index.js'; -import { showToast } from '../utils/uiHelpers.js'; - -export class CheckpointSearchManager extends SearchManager { - constructor(options = {}) { - super({ - page: 'checkpoints', - ...options - }); - - this.currentSearchTerm = ''; - - // Store this instance in the state - if (state) { - state.searchManager = this; - } - } - - async performSearch() { - const searchTerm = this.searchInput.value.trim().toLowerCase(); - - if (searchTerm === this.currentSearchTerm && !this.isSearching) { - return; // Avoid duplicate searches - } - - this.currentSearchTerm = searchTerm; - - const grid = document.getElementById('checkpointGrid'); - - if (!searchTerm) { - if (state) { - state.currentPage = 1; - } - this.resetAndReloadCheckpoints(); - return; - } - - try { - this.isSearching = true; - if (state && state.loadingManager) { - state.loadingManager.showSimpleLoading('Searching checkpoints...'); - } - - // Store current scroll position - const scrollPosition = window.pageYOffset || document.documentElement.scrollTop; - - if (state) { - state.currentPage = 1; - state.hasMore = true; - } - - const url = new URL('/api/checkpoints', window.location.origin); - url.searchParams.set('page', '1'); - url.searchParams.set('page_size', '20'); - url.searchParams.set('sort_by', state ? 
state.sortBy : 'name'); - url.searchParams.set('search', searchTerm); - url.searchParams.set('fuzzy', 'true'); - - // Add search options - const searchOptions = this.getActiveSearchOptions(); - url.searchParams.set('search_filename', searchOptions.filename.toString()); - url.searchParams.set('search_modelname', searchOptions.modelname.toString()); - - // Always send folder parameter if there is an active folder - if (state && state.activeFolder) { - url.searchParams.set('folder', state.activeFolder); - // Add recursive parameter when recursive search is enabled - const recursive = this.recursiveSearchToggle ? this.recursiveSearchToggle.checked : false; - url.searchParams.set('recursive', recursive.toString()); - } - - const response = await fetch(url); - - if (!response.ok) { - throw new Error('Search failed'); - } - - const data = await response.json(); - - if (searchTerm === this.currentSearchTerm && grid) { - grid.innerHTML = ''; - - if (data.items.length === 0) { - grid.innerHTML = '
No matching checkpoints found
'; - if (state) { - state.hasMore = false; - } - } else { - this.appendCheckpointCards(data.items); - if (state) { - state.hasMore = state.currentPage < data.total_pages; - state.currentPage++; - } - } - - // Restore scroll position after content is loaded - setTimeout(() => { - window.scrollTo({ - top: scrollPosition, - behavior: 'instant' // Use 'instant' to prevent animation - }); - }, 10); - } - } catch (error) { - console.error('Checkpoint search error:', error); - showToast('Checkpoint search failed', 'error'); - } finally { - this.isSearching = false; - if (state && state.loadingManager) { - state.loadingManager.hide(); - } - } - } - - resetAndReloadCheckpoints() { - // This function would be implemented in the checkpoints page - if (typeof window.loadCheckpoints === 'function') { - window.loadCheckpoints(); - } else { - // Fallback to reloading the page - window.location.reload(); - } - } - - appendCheckpointCards(checkpoints) { - // This function would be implemented in the checkpoints page - const grid = document.getElementById('checkpointGrid'); - if (!grid) return; - - if (typeof window.appendCheckpointCards === 'function') { - window.appendCheckpointCards(checkpoints); - } else { - // Fallback implementation - checkpoints.forEach(checkpoint => { - const card = document.createElement('div'); - card.className = 'checkpoint-card'; - card.innerHTML = ` -

${checkpoint.name}

-

${checkpoint.filename || 'No filename'}

- `; - grid.appendChild(card); - }); - } - } -} \ No newline at end of file diff --git a/static/js/managers/FilterManager.js b/static/js/managers/FilterManager.js index 96d31388..1ebdc9e2 100644 --- a/static/js/managers/FilterManager.js +++ b/static/js/managers/FilterManager.js @@ -1,7 +1,8 @@ -import { BASE_MODELS, BASE_MODEL_CLASSES } from '../utils/constants.js'; -import { state, getCurrentPageState } from '../state/index.js'; +import { BASE_MODEL_CLASSES } from '../utils/constants.js'; +import { getCurrentPageState } from '../state/index.js'; import { showToast, updatePanelPositions } from '../utils/uiHelpers.js'; import { loadMoreLoras } from '../api/loraApi.js'; +import { loadMoreCheckpoints } from '../api/checkpointApi.js'; import { removeStorageItem, setStorageItem, getStorageItem } from '../utils/storageHelpers.js'; export class FilterManager { @@ -70,8 +71,10 @@ export class FilterManager { let tagsEndpoint = '/api/loras/top-tags?limit=20'; if (this.currentPage === 'recipes') { tagsEndpoint = '/api/recipes/top-tags?limit=20'; + } else if (this.currentPage === 'checkpoints') { + tagsEndpoint = '/api/checkpoints/top-tags?limit=20'; } - + const response = await fetch(tagsEndpoint); if (!response.ok) throw new Error('Failed to fetch tags'); @@ -143,8 +146,8 @@ export class FilterManager { apiEndpoint = '/api/loras/base-models'; } else if (this.currentPage === 'recipes') { apiEndpoint = '/api/recipes/base-models'; - } else { - return; // No API endpoint for other pages + } else if (this.currentPage === 'checkpoints') { + apiEndpoint = '/api/checkpoints/base-models'; } // Fetch base models @@ -280,7 +283,7 @@ export class FilterManager { // For loras page, reset the page and reload await loadMoreLoras(true, true); } else if (this.currentPage === 'checkpoints' && window.checkpointManager) { - await window.checkpointManager.loadCheckpoints(true); + await loadMoreCheckpoints(true); } // Update filter button to show active state @@ -336,8 +339,8 @@ export class 
FilterManager { await window.recipeManager.loadRecipes(true); } else if (this.currentPage === 'loras') { await loadMoreLoras(true, true); - } else if (this.currentPage === 'checkpoints' && window.checkpointManager) { - await window.checkpointManager.loadCheckpoints(true); + } else if (this.currentPage === 'checkpoints') { + await loadMoreCheckpoints(true); } showToast(`Filters cleared`, 'info'); diff --git a/static/js/managers/ModalManager.js b/static/js/managers/ModalManager.js index 9b6ecf91..989a2806 100644 --- a/static/js/managers/ModalManager.js +++ b/static/js/managers/ModalManager.js @@ -23,6 +23,32 @@ export class ModalManager { }); } + // Add checkpointModal registration + const checkpointModal = document.getElementById('checkpointModal'); + if (checkpointModal) { + this.registerModal('checkpointModal', { + element: checkpointModal, + onClose: () => { + this.getModal('checkpointModal').element.style.display = 'none'; + document.body.classList.remove('modal-open'); + }, + closeOnOutsideClick: true + }); + } + + // Add checkpointDownloadModal registration + const checkpointDownloadModal = document.getElementById('checkpointDownloadModal'); + if (checkpointDownloadModal) { + this.registerModal('checkpointDownloadModal', { + element: checkpointDownloadModal, + onClose: () => { + this.getModal('checkpointDownloadModal').element.style.display = 'none'; + document.body.classList.remove('modal-open'); + }, + closeOnOutsideClick: true + }); + } + const deleteModal = document.getElementById('deleteModal'); if (deleteModal) { this.registerModal('deleteModal', { diff --git a/static/js/managers/SearchManager.js b/static/js/managers/SearchManager.js index 4d909b36..49e6749b 100644 --- a/static/js/managers/SearchManager.js +++ b/static/js/managers/SearchManager.js @@ -302,6 +302,7 @@ export class SearchManager { pageState.searchOptions = { filename: options.filename || false, modelname: options.modelname || false, + tags: options.tags || false, recursive: recursive }; } 
diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index 1f4f7cda..eff5655a 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -50,6 +50,11 @@ export class SettingsManager { observer.observe(settingsModal, { attributes: true }); } + // Add event listeners for all toggle-visibility buttons + document.querySelectorAll('.toggle-visibility').forEach(button => { + button.addEventListener('click', () => this.toggleInputVisibility(button)); + }); + this.initialized = true; } @@ -271,6 +276,19 @@ export class SettingsManager { } } + toggleInputVisibility(button) { + const input = button.parentElement.querySelector('input'); + const icon = button.querySelector('i'); + + if (input.type === 'password') { + input.type = 'text'; + icon.className = 'fas fa-eye-slash'; + } else { + input.type = 'password'; + icon.className = 'fas fa-eye'; + } + } + async reloadContent() { if (this.currentPage === 'loras') { // Reload the loras without updating folders @@ -387,17 +405,3 @@ export class SettingsManager { // Create singleton instance export const settingsManager = new SettingsManager(); - -// Helper function for toggling API key visibility -export function toggleApiKeyVisibility(button) { - const input = button.parentElement.querySelector('input'); - const icon = button.querySelector('i'); - - if (input.type === 'password') { - input.type = 'text'; - icon.className = 'fas fa-eye-slash'; - } else { - input.type = 'password'; - icon.className = 'fas fa-eye'; - } -} diff --git a/static/js/recipes.js b/static/js/recipes.js index ba55e62c..be87a914 100644 --- a/static/js/recipes.js +++ b/static/js/recipes.js @@ -4,7 +4,6 @@ import { ImportManager } from './managers/ImportManager.js'; import { RecipeCard } from './components/RecipeCard.js'; import { RecipeModal } from './components/RecipeModal.js'; import { getCurrentPageState } from './state/index.js'; -import { toggleApiKeyVisibility } from 
'./managers/SettingsManager.js'; import { getSessionItem, removeSessionItem } from './utils/storageHelpers.js'; class RecipeManager { @@ -67,7 +66,6 @@ class RecipeManager { // Only expose what's needed for the page window.recipeManager = this; window.importManager = this.importManager; - window.toggleApiKeyVisibility = toggleApiKeyVisibility; } _checkCustomFilter() { @@ -251,6 +249,11 @@ class RecipeManager { // Update pagination state based on current page and total pages this.pageState.hasMore = data.page < data.total_pages; + // Increment the page number AFTER successful loading + if (data.items.length > 0) { + this.pageState.currentPage++; + } + } catch (error) { console.error('Error loading recipes:', error); appCore.showToast('Failed to load recipes', 'error'); diff --git a/static/js/utils/cardUpdater.js b/static/js/utils/cardUpdater.js new file mode 100644 index 00000000..ce099f98 --- /dev/null +++ b/static/js/utils/cardUpdater.js @@ -0,0 +1,128 @@ +/** + * Utility functions to update checkpoint cards after modal edits + */ + +/** + * Update the checkpoint card after metadata edits in the modal + * @param {string} filePath - Path to the checkpoint file + * @param {Object} updates - Object containing the updates (model_name, base_model, etc) + */ +export function updateCheckpointCard(filePath, updates) { + // Find the card with matching filepath + const checkpointCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); + if (!checkpointCard) return; + + // Update card dataset and visual elements based on the updates object + Object.entries(updates).forEach(([key, value]) => { + // Update dataset + checkpointCard.dataset[key] = value; + + // Update visual elements based on the property + switch(key) { + case 'name': // model_name + // Update the model name in the footer + const modelNameElement = checkpointCard.querySelector('.model-name'); + if (modelNameElement) modelNameElement.textContent = value; + break; + + case 'base_model': + // 
Update the base model label in the card header + const baseModelLabel = checkpointCard.querySelector('.base-model-label'); + if (baseModelLabel) { + baseModelLabel.textContent = value; + baseModelLabel.title = value; + } + break; + + case 'filepath': + // The filepath was changed (file renamed), update the dataset + checkpointCard.dataset.filepath = value; + break; + + case 'tags': + // Update tags if they're displayed on the card + try { + checkpointCard.dataset.tags = JSON.stringify(value); + } catch (e) { + console.error('Failed to update tags:', e); + } + break; + + // Add other properties as needed + } + }); +} + +/** + * Update the Lora card after metadata edits in the modal + * @param {string} filePath - Path to the Lora file + * @param {Object} updates - Object containing the updates (model_name, base_model, notes, usage_tips, etc) + * @param {string} [newFilePath] - Optional new file path if the file has been renamed + */ +export function updateLoraCard(filePath, updates, newFilePath) { + // Find the card with matching filepath + const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); + if (!loraCard) return; + + // If file was renamed, update the filepath first + if (newFilePath) { + loraCard.dataset.filepath = newFilePath; + } + + // Update card dataset and visual elements based on the updates object + Object.entries(updates).forEach(([key, value]) => { + // Update dataset + loraCard.dataset[key] = value; + + // Update visual elements based on the property + switch(key) { + case 'model_name': + // Update the model name in the card title + const titleElement = loraCard.querySelector('.card-title'); + if (titleElement) titleElement.textContent = value; + + // Also update the model name in the footer if it exists + const modelNameElement = loraCard.querySelector('.model-name'); + if (modelNameElement) modelNameElement.textContent = value; + break; + + case 'file_name': + // Update the file_name in the dataset + 
loraCard.dataset.file_name = value; + break; + + case 'base_model': + // Update the base model label in the card header if it exists + const baseModelLabel = loraCard.querySelector('.base-model-label'); + if (baseModelLabel) { + baseModelLabel.textContent = value; + baseModelLabel.title = value; + } + break; + + case 'tags': + // Update tags if they're displayed on the card + try { + if (typeof value === 'string') { + loraCard.dataset.tags = value; + } else { + loraCard.dataset.tags = JSON.stringify(value); + } + + // If there's a tag container, update its content + const tagContainer = loraCard.querySelector('.card-tags'); + if (tagContainer) { + // This depends on how your tags are rendered + // You may need to update this logic based on your tag rendering function + } + } catch (e) { + console.error('Failed to update tags:', e); + } + break; + + // No visual updates needed for notes, usage_tips as they're typically not shown on cards + } + }); + + return loraCard; // Return the updated card element for chaining +} \ No newline at end of file diff --git a/static/js/utils/infiniteScroll.js b/static/js/utils/infiniteScroll.js index ddb0e341..60795e14 100644 --- a/static/js/utils/infiniteScroll.js +++ b/static/js/utils/infiniteScroll.js @@ -1,5 +1,6 @@ import { state, getCurrentPageState } from '../state/index.js'; import { loadMoreLoras } from '../api/loraApi.js'; +import { loadMoreCheckpoints } from '../api/checkpointApi.js'; import { debounce } from './debounce.js'; export function initializeInfiniteScroll(pageType = 'loras') { @@ -21,7 +22,6 @@ export function initializeInfiniteScroll(pageType = 'loras') { case 'recipes': loadMoreFunction = () => { if (!pageState.isLoading && pageState.hasMore) { - pageState.currentPage++; window.recipeManager.loadRecipes(false); // false to not reset pagination } }; @@ -30,15 +30,18 @@ export function initializeInfiniteScroll(pageType = 'loras') { case 'checkpoints': loadMoreFunction = () => { if (!pageState.isLoading && 
pageState.hasMore) { - pageState.currentPage++; - window.checkpointManager.loadCheckpoints(false); // false to not reset pagination + loadMoreCheckpoints(false); // false to not reset } }; gridId = 'checkpointGrid'; break; case 'loras': default: - loadMoreFunction = () => loadMoreLoras(false); // false to not reset + loadMoreFunction = () => { + if (!pageState.isLoading && pageState.hasMore) { + loadMoreLoras(false); // false to not reset + } + }; gridId = 'loraGrid'; break; } @@ -85,4 +88,4 @@ export function initializeInfiniteScroll(pageType = 'loras') { state.observer.observe(sentinel); } -} \ No newline at end of file +} \ No newline at end of file diff --git a/static/js/utils/modalUtils.js b/static/js/utils/modalUtils.js index 49dfe15d..be5fe25d 100644 --- a/static/js/utils/modalUtils.js +++ b/static/js/utils/modalUtils.js @@ -1,10 +1,12 @@ import { modalManager } from '../managers/ModalManager.js'; let pendingDeletePath = null; +let pendingModelType = null; -export function showDeleteModal(filePath) { - event.stopPropagation(); +export function showDeleteModal(filePath, modelType = 'lora') { + // event.stopPropagation(); pendingDeletePath = filePath; + pendingModelType = modelType; const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); const modelName = card.dataset.name; @@ -23,11 +25,15 @@ export function showDeleteModal(filePath) { export async function confirmDelete() { if (!pendingDeletePath) return; - const modal = document.getElementById('deleteModal'); const card = document.querySelector(`.lora-card[data-filepath="${pendingDeletePath}"]`); try { - const response = await fetch('/api/delete_model', { + // Use the appropriate endpoint based on model type + const endpoint = pendingModelType === 'checkpoint' ? 
+ '/api/checkpoints/delete' : + '/api/delete_model'; + + const response = await fetch(endpoint, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -53,4 +59,6 @@ export async function confirmDelete() { export function closeDeleteModal() { modalManager.closeModal('deleteModal'); -} \ No newline at end of file + pendingDeletePath = null; + pendingModelType = null; +} \ No newline at end of file diff --git a/templates/checkpoints.html b/templates/checkpoints.html index f4a4d6b8..4ff483df 100644 --- a/templates/checkpoints.html +++ b/templates/checkpoints.html @@ -8,18 +8,19 @@ {% endblock %} {% block init_title %}Initializing Checkpoints Manager{% endblock %} -{% block init_message %}Setting up checkpoints interface. This may take a few moments...{% endblock %} +{% block init_message %}Scanning and building checkpoints cache. This may take a few moments...{% endblock %} {% block init_check_url %}/api/checkpoints?page=1&page_size=1{% endblock %} +{% block additional_components %} +{% include 'components/checkpoint_modals.html' %} +{% endblock %} + {% block content %} -
-
- -

Checkpoints Manager

-

This feature is currently under development and will be available soon.

-

Please check back later for updates!

+ {% include 'components/controls.html' %} + +
+
-
{% endblock %} {% block main_script %} diff --git a/templates/components/checkpoint_modals.html b/templates/components/checkpoint_modals.html new file mode 100644 index 00000000..0a0ea127 --- /dev/null +++ b/templates/components/checkpoint_modals.html @@ -0,0 +1,104 @@ + + + + + + + + + \ No newline at end of file diff --git a/templates/components/controls.html b/templates/components/controls.html index 45acc7a4..56921cd5 100644 --- a/templates/components/controls.html +++ b/templates/components/controls.html @@ -16,21 +16,25 @@
- + +
+ +
+
- -
-
-
+ + {% if request.path == '/loras' %}
-
+ {% endif %}
-
diff --git a/templates/components/header.html b/templates/components/header.html index f9ddc624..51a7ba06 100644 --- a/templates/components/header.html +++ b/templates/components/header.html @@ -74,6 +74,7 @@ {% elif request.path == '/checkpoints' %}
Filename
Checkpoint Name
+
Tags
{% else %}
Filename
diff --git a/templates/components/lora_modals.html b/templates/components/lora_modals.html index a6ab03d1..8eeb438b 100644 --- a/templates/components/lora_modals.html +++ b/templates/components/lora_modals.html @@ -1,7 +1,10 @@ - + + + + - \ No newline at end of file + \ No newline at end of file diff --git a/templates/components/modals.html b/templates/components/modals.html index 2e489d44..aec1246a 100644 --- a/templates/components/modals.html +++ b/templates/components/modals.html @@ -30,7 +30,7 @@ value="{{ settings.get('civitai_api_key', '') }}" onblur="settingsManager.saveInputSetting('civitaiApiKey', 'civitai_api_key')" onkeydown="if(event.key === 'Enter') { this.blur(); }" /> -