From 8fdfb687419731171f0f877ac8575d6e09b0c321 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 10 Apr 2025 09:08:36 +0800 Subject: [PATCH 01/36] checkpoint --- py/config.py | 2 +- py/lora_manager.py | 107 ++++-- py/routes/checkpoints_routes.py | 156 +++++++-- py/services/checkpoint_scanner.py | 131 +++++++ py/services/file_monitor.py | 90 ++++- py/services/lora_scanner.py | 426 +++++------------------ py/services/model_cache.py | 64 ++++ py/services/model_hash_index.py | 78 +++++ py/services/model_scanner.py | 554 ++++++++++++++++++++++++++++++ py/utils/file_utils.py | 91 +++-- py/utils/lora_metadata.py | 66 +++- py/utils/models.py | 116 ++++--- 12 files changed, 1397 insertions(+), 484 deletions(-) create mode 100644 py/services/checkpoint_scanner.py create mode 100644 py/services/model_cache.py create mode 100644 py/services/model_hash_index.py create mode 100644 py/services/model_scanner.py diff --git a/py/config.py b/py/config.py index a081f472..b0ca1f4c 100644 --- a/py/config.py +++ b/py/config.py @@ -73,7 +73,7 @@ class Config: """添加静态路由映射""" normalized_path = os.path.normpath(path).replace(os.sep, '/') self._route_mappings[normalized_path] = route - logger.info(f"Added route mapping: {normalized_path} -> {route}") + # logger.info(f"Added route mapping: {normalized_path} -> {route}") def map_path_to_link(self, path: str) -> str: """将目标路径映射回符号链接路径""" diff --git a/py/lora_manager.py b/py/lora_manager.py index a7ab3fb8..14556fa4 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -7,10 +7,12 @@ from .routes.api_routes import ApiRoutes from .routes.recipe_routes import RecipeRoutes from .routes.checkpoints_routes import CheckpointsRoutes from .services.lora_scanner import LoraScanner +from .services.checkpoint_scanner import CheckpointScanner from .services.recipe_scanner import RecipeScanner -from .services.file_monitor import LoraFileMonitor +from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor from 
.services.lora_cache import LoraCache from .services.recipe_cache import RecipeCache +from .services.model_cache import ModelCache import logging logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ class LoraManager: """Initialize and register all routes""" app = PromptServer.instance.app - added_targets = set() # 用于跟踪已添加的目标路径 + added_targets = set() # Track already added target paths # Add static routes for each lora root for idx, root in enumerate(config.loras_roots, start=1): @@ -35,15 +37,34 @@ class LoraManager: if link == root: real_root = target break - # 为原始路径添加静态路由 + # Add static route for original path app.router.add_static(preview_path, real_root) logger.info(f"Added static route {preview_path} -> {real_root}") - # 记录路由映射 + # Record route mapping config.add_route_mapping(real_root, preview_path) added_targets.add(real_root) - # 为符号链接的目标路径添加额外的静态路由 + # Add static routes for each checkpoint root + checkpoint_scanner = CheckpointScanner() + for idx, root in enumerate(checkpoint_scanner.get_model_roots(), start=1): + preview_path = f'/checkpoints_static/root{idx}/preview' + + real_root = root + if root in config._path_mappings.values(): + for target, link in config._path_mappings.items(): + if link == root: + real_root = target + break + # Add static route for original path + app.router.add_static(preview_path, real_root) + logger.info(f"Added static route {preview_path} -> {real_root}") + + # Record route mapping + config.add_route_mapping(real_root, preview_path) + added_targets.add(real_root) + + # Add static routes for symlink target paths link_idx = 1 for target_path, link_path in config._path_mappings.items(): @@ -59,37 +80,47 @@ class LoraManager: app.router.add_static('/loras_static', config.static_path) # Setup feature routes - routes = LoraRoutes() + lora_routes = LoraRoutes() checkpoints_routes = CheckpointsRoutes() # Setup file monitoring - monitor = LoraFileMonitor(routes.scanner, config.loras_roots) - monitor.start() + lora_monitor = 
LoraFileMonitor(lora_routes.scanner, config.loras_roots) + lora_monitor.start() - routes.setup_routes(app) + checkpoint_monitor = CheckpointFileMonitor(checkpoints_routes.scanner, checkpoints_routes.scanner.get_model_roots()) + checkpoint_monitor.start() + + lora_routes.setup_routes(app) checkpoints_routes.setup_routes(app) - ApiRoutes.setup_routes(app, monitor) + ApiRoutes.setup_routes(app, lora_monitor) RecipeRoutes.setup_routes(app) - # Store monitor in app for cleanup - app['lora_monitor'] = monitor + # Store monitors in app for cleanup + app['lora_monitor'] = lora_monitor + app['checkpoint_monitor'] = checkpoint_monitor + + logger.info("PromptServer app: ", app) # Schedule cache initialization using the application's startup handler - app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner, routes.recipe_scanner)) + app.on_startup.append(lambda app: cls._schedule_cache_init( + lora_routes.scanner, + checkpoints_routes.scanner, + lora_routes.recipe_scanner + )) # Add cleanup app.on_shutdown.append(cls._cleanup) app.on_shutdown.append(ApiRoutes.cleanup) @classmethod - async def _schedule_cache_init(cls, scanner: LoraScanner, recipe_scanner: RecipeScanner): + async def _schedule_cache_init(cls, lora_scanner, checkpoint_scanner, recipe_scanner): """Schedule cache initialization in the running event loop""" try: - # 创建低优先级的初始化任务 - lora_task = asyncio.create_task(cls._initialize_lora_cache(scanner), name='lora_cache_init') - - # Schedule recipe cache initialization with a delay to let lora scanner initialize first - recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner, delay=2), name='recipe_cache_init') + # Create low-priority initialization tasks + lora_task = asyncio.create_task(cls._initialize_lora_cache(lora_scanner), name='lora_cache_init') + checkpoint_task = asyncio.create_task(cls._initialize_checkpoint_cache(checkpoint_scanner), name='checkpoint_cache_init') + recipe_task = 
asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner), name='recipe_cache_init') + logger.info("Cache initialization tasks scheduled to run in background") except Exception as e: logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}") @@ -97,26 +128,45 @@ class LoraManager: async def _initialize_lora_cache(cls, scanner: LoraScanner): """Initialize lora cache in background""" try: - # 设置初始缓存占位 + # Set initial placeholder cache scanner._cache = LoraCache( raw_data=[], sorted_by_name=[], sorted_by_date=[], folders=[] ) - - # 分阶段加载缓存 - await scanner.get_cached_data(force_refresh=True) + # 使用线程池执行耗时操作 + loop = asyncio.get_event_loop() + await loop.run_in_executor( + None, # 使用默认线程池 + lambda: scanner.get_cached_data_sync(force_refresh=True) # 创建同步版本的方法 + ) + # Load cache in phases + # await scanner.get_cached_data(force_refresh=True) except Exception as e: logger.error(f"LoRA Manager: Error initializing lora cache: {e}") @classmethod - async def _initialize_recipe_cache(cls, scanner: RecipeScanner, delay: float = 2.0): - """Initialize recipe cache in background with a delay""" + async def _initialize_checkpoint_cache(cls, scanner: CheckpointScanner): + """Initialize checkpoint cache in background""" try: - # Wait for the specified delay to let lora scanner initialize first - await asyncio.sleep(delay) + # Set initial placeholder cache + scanner._cache = ModelCache( + raw_data=[], + sorted_by_name=[], + sorted_by_date=[], + folders=[] + ) + # Load cache in phases + await scanner.get_cached_data(force_refresh=True) + except Exception as e: + logger.error(f"LoRA Manager: Error initializing checkpoint cache: {e}") + + @classmethod + async def _initialize_recipe_cache(cls, scanner: RecipeScanner): + """Initialize recipe cache in background with a delay""" + try: # Set initial empty cache scanner._cache = RecipeCache( raw_data=[], @@ -134,3 +184,6 @@ class LoraManager: """Cleanup resources""" if 'lora_monitor' in app: app['lora_monitor'].stop() + + 
if 'checkpoint_monitor' in app: + app['checkpoint_monitor'].stop() diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index 0a79d6f9..f2e2e631 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -1,44 +1,146 @@ import os +import json +import asyncio +import aiohttp from aiohttp import web -import jinja2 import logging +from datetime import datetime + +from ..services.checkpoint_scanner import CheckpointScanner from ..config import config -from ..services.settings_manager import settings logger = logging.getLogger(__name__) -logging.getLogger('asyncio').setLevel(logging.CRITICAL) class CheckpointsRoutes: - """Route handlers for Checkpoints management endpoints""" + """API routes for checkpoint management""" def __init__(self): - self.template_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(config.templates_path), - autoescape=True - ) + self.scanner = CheckpointScanner() + + def setup_routes(self, app): + """Register routes with the aiohttp app""" + app.router.add_get('/lora_manager/api/checkpoints', self.get_checkpoints) + app.router.add_get('/lora_manager/api/checkpoints/scan', self.scan_checkpoints) + app.router.add_get('/lora_manager/api/checkpoints/info/{name}', self.get_checkpoint_info) - async def handle_checkpoints_page(self, request: web.Request) -> web.Response: - """Handle GET /checkpoints request""" + async def get_checkpoints(self, request): + """Get paginated checkpoint data""" try: - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - is_initializing=False, - settings=settings, - request=request + # Parse query parameters + page = int(request.query.get('page', '1')) + page_size = min(int(request.query.get('page_size', '20')), 100) + sort_by = request.query.get('sort', 'name') + folder = request.query.get('folder', None) + search = request.query.get('search', None) + fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 
'true' + base_models = request.query.getall('base_model', []) + tags = request.query.getall('tag', []) + + # Process search options + search_options = { + 'filename': request.query.get('search_filename', 'true').lower() == 'true', + 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', + 'tags': request.query.get('search_tags', 'false').lower() == 'true', + 'recursive': request.query.get('recursive', 'false').lower() == 'true', + } + + # Process hash filters if provided + hash_filters = {} + if 'hash' in request.query: + hash_filters['single_hash'] = request.query['hash'] + elif 'hashes' in request.query: + try: + hash_list = json.loads(request.query['hashes']) + if isinstance(hash_list, list): + hash_filters['multiple_hashes'] = hash_list + except (json.JSONDecodeError, TypeError): + pass + + # Get data from scanner + result = await self.get_paginated_data( + page=page, + page_size=page_size, + sort_by=sort_by, + folder=folder, + search=search, + fuzzy_search=fuzzy_search, + base_models=base_models, + tags=tags, + search_options=search_options, + hash_filters=hash_filters ) - return web.Response( - text=rendered, - content_type='text/html' - ) + # Return as JSON + return web.json_response(result) except Exception as e: - logger.error(f"Error handling checkpoints request: {e}", exc_info=True) - return web.Response( - text="Error loading checkpoints page", - status=500 - ) + logger.error(f"Error in get_checkpoints: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) - def setup_routes(self, app: web.Application): - """Register routes with the application""" - app.router.add_get('/checkpoints', self.handle_checkpoints_page) + async def get_paginated_data(self, page, page_size, sort_by='name', + folder=None, search=None, fuzzy_search=False, + base_models=None, tags=None, + search_options=None, hash_filters=None): + """Get paginated and filtered checkpoint data""" + cache = await self.scanner.get_cached_data() + + # 
Implement similar filtering logic as in LoraScanner + # (Adapt code from LoraScanner.get_paginated_data) + # ... + + # For now, a simplified implementation: + filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name + + # Apply basic folder filtering if needed + if folder is not None: + filtered_data = [ + cp for cp in filtered_data + if cp['folder'] == folder + ] + + # Apply basic search if needed + if search: + filtered_data = [ + cp for cp in filtered_data + if search.lower() in cp['file_name'].lower() or + search.lower() in cp['model_name'].lower() + ] + + # Calculate pagination + total_items = len(filtered_data) + start_idx = (page - 1) * page_size + end_idx = min(start_idx + page_size, total_items) + + result = { + 'items': filtered_data[start_idx:end_idx], + 'total': total_items, + 'page': page, + 'page_size': page_size, + 'total_pages': (total_items + page_size - 1) // page_size + } + + return result + + async def scan_checkpoints(self, request): + """Force a rescan of checkpoint files""" + try: + await self.scanner.get_cached_data(force_refresh=True) + return web.json_response({"status": "success", "message": "Checkpoint scan completed"}) + except Exception as e: + logger.error(f"Error in scan_checkpoints: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + async def get_checkpoint_info(self, request): + """Get detailed information for a specific checkpoint by name""" + try: + name = request.match_info.get('name', '') + checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name) + + if checkpoint_info: + return web.json_response(checkpoint_info) + else: + return web.json_response({"error": "Checkpoint not found"}, status=404) + + except Exception as e: + logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py new file mode 100644 index 
00000000..27d15273 --- /dev/null +++ b/py/services/checkpoint_scanner.py @@ -0,0 +1,131 @@ +import os +import logging +import asyncio +from typing import List, Dict, Optional, Set +import folder_paths # type: ignore + +from ..utils.models import CheckpointMetadata +from ..config import config +from .model_scanner import ModelScanner +from .model_hash_index import ModelHashIndex + +logger = logging.getLogger(__name__) + +class CheckpointScanner(ModelScanner): + """Service for scanning and managing checkpoint files""" + + _instance = None + _lock = asyncio.Lock() + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, '_initialized'): + # Define supported file extensions + file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'} + super().__init__( + model_type="checkpoint", + model_class=CheckpointMetadata, + file_extensions=file_extensions, + hash_index=ModelHashIndex() + ) + self._checkpoint_roots = self._init_checkpoint_roots() + self._initialized = True + + @classmethod + async def get_instance(cls): + """Get singleton instance with async support""" + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def _init_checkpoint_roots(self) -> List[str]: + """Initialize checkpoint roots from ComfyUI settings""" + # Get both checkpoint and diffusion_models paths + checkpoint_paths = folder_paths.get_folder_paths("checkpoints") + diffusion_paths = folder_paths.get_folder_paths("diffusion_models") + + # Combine, normalize and deduplicate paths + all_paths = set() + for path in checkpoint_paths + diffusion_paths: + if os.path.exists(path): + norm_path = path.replace(os.sep, "/") + all_paths.add(norm_path) + + # Sort for consistent order + sorted_paths = sorted(all_paths, key=lambda p: p.lower()) + logger.info(f"Found checkpoint roots: {sorted_paths}") + + return sorted_paths + + def get_model_roots(self) -> 
List[str]: + """Get checkpoint root directories""" + return self._checkpoint_roots + + async def scan_all_models(self) -> List[Dict]: + """Scan all checkpoint directories and return metadata""" + all_checkpoints = [] + + # Create scan tasks for each directory + scan_tasks = [] + for root in self._checkpoint_roots: + task = asyncio.create_task(self._scan_directory(root)) + scan_tasks.append(task) + + # Wait for all tasks to complete + for task in scan_tasks: + try: + checkpoints = await task + all_checkpoints.extend(checkpoints) + except Exception as e: + logger.error(f"Error scanning checkpoint directory: {e}") + + return all_checkpoints + + async def _scan_directory(self, root_path: str) -> List[Dict]: + """Scan a directory for checkpoint files""" + checkpoints = [] + original_root = root_path + + async def scan_recursive(path: str, visited_paths: set): + try: + real_path = os.path.realpath(path) + if real_path in visited_paths: + logger.debug(f"Skipping already visited path: {path}") + return + visited_paths.add(real_path) + + with os.scandir(path) as it: + entries = list(it) + for entry in entries: + try: + if entry.is_file(follow_symlinks=True): + # Check if file has supported extension + ext = os.path.splitext(entry.name)[1].lower() + if ext in self.file_extensions: + file_path = entry.path.replace(os.sep, "/") + await self._process_single_file(file_path, original_root, checkpoints) + await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + # For directories, continue scanning with original path + await scan_recursive(entry.path, visited_paths) + except Exception as e: + logger.error(f"Error processing entry {entry.path}: {e}") + except Exception as e: + logger.error(f"Error scanning {path}: {e}") + + await scan_recursive(root_path, set()) + return checkpoints + + async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list): + """Process a single checkpoint file and add to results""" + try: + result = await 
self._process_model_file(file_path, root_path) + if result: + checkpoints.append(result) + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") \ No newline at end of file diff --git a/py/services/file_monitor.py b/py/services/file_monitor.py index 9ed44d0f..1b9438f3 100644 --- a/py/services/file_monitor.py +++ b/py/services/file_monitor.py @@ -7,6 +7,8 @@ from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler from typing import List, Dict, Set from threading import Lock + +from .checkpoint_scanner import CheckpointScanner from .lora_scanner import LoraScanner from ..config import config @@ -330,4 +332,90 @@ class LoraFileMonitor: self.observer.schedule(self.handler, path, recursive=True) logger.info(f"Added new monitoring path: {path}") except Exception as e: - logger.error(f"Error adding new monitor for {path}: {e}") \ No newline at end of file + logger.error(f"Error adding new monitor for {path}: {e}") + +# Add CheckpointFileMonitor class + +class CheckpointFileMonitor(LoraFileMonitor): + """Monitor for checkpoint file changes""" + + def __init__(self, scanner: CheckpointScanner, roots: List[str]): + # Reuse most of the LoraFileMonitor functionality, but with a different handler + self.scanner = scanner + scanner.set_file_monitor(self) + self.observer = Observer() + self.loop = asyncio.get_event_loop() + self.handler = CheckpointFileHandler(scanner, self.loop) + + # Use existing path mappings + self.monitor_paths = set() + for root in roots: + self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/')) + + # Add all mapped target paths + for target_path in config._path_mappings.keys(): + self.monitor_paths.add(target_path) + +class CheckpointFileHandler(LoraFileHandler): + """Handler for checkpoint file system events""" + + def __init__(self, scanner: CheckpointScanner, loop: asyncio.AbstractEventLoop): + super().__init__(scanner, loop) + # Configure supported file extensions + 
self.supported_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'} + + def on_created(self, event): + if event.is_directory: + return + + # Handle supported file extensions directly + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext in self.supported_extensions: + if self._should_ignore(event.src_path): + return + + # Process this file directly + normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/') + if normalized_path not in self.scheduled_files: + logger.info(f"Checkpoint file created: {event.src_path}") + self.scheduled_files.add(normalized_path) + self._schedule_update('add', event.src_path) + + # Ignore modifications for a short period after creation + self.loop.call_later( + self.debounce_delay * 2, + self.scheduled_files.discard, + normalized_path + ) + + def on_modified(self, event): + if event.is_directory: + return + + # Only process supported file types + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext in self.supported_extensions: + super().on_modified(event) + + def on_deleted(self, event): + if event.is_directory: + return + + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext not in self.supported_extensions: + return + + super().on_deleted(event) + + def on_moved(self, event): + """Handle file move/rename events""" + src_ext = os.path.splitext(event.src_path)[1].lower() + dest_ext = os.path.splitext(event.dest_path)[1].lower() + + # If destination has supported extension, treat as new file + if dest_ext in self.supported_extensions: + super().on_moved(event) + + # If source was supported extension, treat as deleted + elif src_ext in self.supported_extensions: + super().on_moved(event) \ No newline at end of file diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index c8142086..c4e10c3d 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -4,13 +4,11 @@ import logging import asyncio import shutil import time -from 
typing import List, Dict, Optional +from typing import List, Dict, Optional, Set from ..utils.models import LoraMetadata from ..config import config -from ..utils.file_utils import load_metadata, get_file_info, normalize_path, find_preview_file, save_metadata -from ..utils.lora_metadata import extract_lora_metadata -from .lora_cache import LoraCache +from .model_scanner import ModelScanner from .lora_hash_index import LoraHashIndex from .settings_manager import settings from ..utils.constants import NSFW_LEVELS @@ -19,7 +17,7 @@ import sys logger = logging.getLogger(__name__) -class LoraScanner: +class LoraScanner(ModelScanner): """Service for scanning and managing LoRA files""" _instance = None @@ -31,20 +29,20 @@ class LoraScanner: return cls._instance def __init__(self): - # 确保初始化只执行一次 + # Ensure initialization happens only once if not hasattr(self, '_initialized'): - self._cache: Optional[LoraCache] = None - self._hash_index = LoraHashIndex() - self._initialization_lock = asyncio.Lock() - self._initialization_task: Optional[asyncio.Task] = None + # Define supported file extensions + file_extensions = {'.safetensors'} + + # Initialize parent class + super().__init__( + model_type="lora", + model_class=LoraMetadata, + file_extensions=file_extensions, + hash_index=LoraHashIndex() + ) self._initialized = True - self.file_monitor = None # Add this line - self._tags_count = {} # Add a dictionary to store tag counts - - def set_file_monitor(self, monitor): - """Set file monitor instance""" - self.file_monitor = monitor - + @classmethod async def get_instance(cls): """Get singleton instance with async support""" @@ -52,89 +50,74 @@ class LoraScanner: if cls._instance is None: cls._instance = cls() return cls._instance - - async def get_cached_data(self, force_refresh: bool = False) -> LoraCache: - """Get cached LoRA data, refresh if needed""" - async with self._initialization_lock: + + def get_model_roots(self) -> List[str]: + """Get lora root directories""" + return 
config.loras_roots + + async def scan_all_models(self) -> List[Dict]: + """Scan all LoRA directories and return metadata""" + all_loras = [] + + # Create scan tasks for each directory + scan_tasks = [] + for lora_root in self.get_model_roots(): + task = asyncio.create_task(self._scan_directory(lora_root)) + scan_tasks.append(task) - # 如果缓存未初始化但需要响应请求,返回空缓存 - if self._cache is None and not force_refresh: - return LoraCache( - raw_data=[], - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - - # 如果正在初始化,等待完成 - if self._initialization_task and not self._initialization_task.done(): - try: - await self._initialization_task - except Exception as e: - logger.error(f"Cache initialization failed: {e}") - self._initialization_task = None - - if (self._cache is None or force_refresh): + # Wait for all tasks to complete + for task in scan_tasks: + try: + loras = await task + all_loras.extend(loras) + except Exception as e: + logger.error(f"Error scanning directory: {e}") - # 创建新的初始化任务 - if not self._initialization_task or self._initialization_task.done(): - self._initialization_task = asyncio.create_task(self._initialize_cache()) + return all_loras + + async def _scan_directory(self, root_path: str) -> List[Dict]: + """Scan a single directory for LoRA files""" + loras = [] + original_root = root_path # Save original root path + + async def scan_recursive(path: str, visited_paths: set): + """Recursively scan directory, avoiding circular symlinks""" + try: + real_path = os.path.realpath(path) + if real_path in visited_paths: + logger.debug(f"Skipping already visited path: {path}") + return + visited_paths.add(real_path) - try: - await self._initialization_task - except Exception as e: - logger.error(f"Cache initialization failed: {e}") - # 如果缓存已存在,继续使用旧缓存 - if self._cache is None: - raise # 如果没有缓存,则抛出异常 - - return self._cache + with os.scandir(path) as it: + entries = list(it) + for entry in entries: + try: + if entry.is_file(follow_symlinks=True) and 
any(entry.name.endswith(ext) for ext in self.file_extensions): + # Use original path instead of real path + file_path = entry.path.replace(os.sep, "/") + await self._process_single_file(file_path, original_root, loras) + await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + # For directories, continue scanning with original path + await scan_recursive(entry.path, visited_paths) + except Exception as e: + logger.error(f"Error processing entry {entry.path}: {e}") + except Exception as e: + logger.error(f"Error scanning {path}: {e}") - async def _initialize_cache(self) -> None: - """Initialize or refresh the cache""" + await scan_recursive(root_path, set()) + return loras + + async def _process_single_file(self, file_path: str, root_path: str, loras: list): + """Process a single file and add to results list""" try: - start_time = time.time() - # Clear existing hash index - self._hash_index.clear() - - # Clear existing tags count - self._tags_count = {} - - # Scan for new data - raw_data = await self.scan_all_loras() - - # Build hash index and tags count - for lora_data in raw_data: - if 'sha256' in lora_data and 'file_path' in lora_data: - self._hash_index.add_entry(lora_data['sha256'].lower(), lora_data['file_path']) - - # Count tags - if 'tags' in lora_data and lora_data['tags']: - for tag in lora_data['tags']: - self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 - - # Update cache - self._cache = LoraCache( - raw_data=raw_data, - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - - # Call resort_cache to create sorted views - await self._cache.resort() - - self._initialization_task = None - logger.info(f"LoRA Manager: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} loras") + result = await self._process_model_file(file_path, root_path) + if result: + loras.append(result) except Exception as e: - logger.error(f"LoRA Manager: Error initializing cache: {e}") - self._cache = LoraCache( - 
raw_data=[], - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - + logger.error(f"Error processing {file_path}: {e}") + async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', folder: str = None, search: str = None, fuzzy_search: bool = False, base_models: list = None, tags: list = None, @@ -280,240 +263,14 @@ class LoraScanner: return result - def invalidate_cache(self): - """Invalidate the current cache""" - self._cache = None - - async def scan_all_loras(self) -> List[Dict]: - """Scan all LoRA directories and return metadata""" - all_loras = [] - - # 分目录异步扫描 - scan_tasks = [] - for loras_root in config.loras_roots: - task = asyncio.create_task(self._scan_directory(loras_root)) - scan_tasks.append(task) - - for task in scan_tasks: - try: - loras = await task - all_loras.extend(loras) - except Exception as e: - logger.error(f"Error scanning directory: {e}") - - return all_loras - - async def _scan_directory(self, root_path: str) -> List[Dict]: - """Scan a single directory for LoRA files""" - loras = [] - original_root = root_path # 保存原始根路径 - - async def scan_recursive(path: str, visited_paths: set): - """递归扫描目录,避免循环链接""" - try: - real_path = os.path.realpath(path) - if real_path in visited_paths: - logger.debug(f"Skipping already visited path: {path}") - return - visited_paths.add(real_path) - - with os.scandir(path) as it: - entries = list(it) - for entry in entries: - try: - if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'): - # 使用原始路径而不是真实路径 - file_path = entry.path.replace(os.sep, "/") - await self._process_single_file(file_path, original_root, loras) - await asyncio.sleep(0) - elif entry.is_dir(follow_symlinks=True): - # 对于目录,使用原始路径继续扫描 - await scan_recursive(entry.path, visited_paths) - except Exception as e: - logger.error(f"Error processing entry {entry.path}: {e}") - except Exception as e: - logger.error(f"Error scanning {path}: {e}") - - await scan_recursive(root_path, set()) - return 
loras - - async def _process_single_file(self, file_path: str, root_path: str, loras: list): - """处理单个文件并添加到结果列表""" - try: - result = await self._process_lora_file(file_path, root_path) - if result: - loras.append(result) - except Exception as e: - logger.error(f"Error processing {file_path}: {e}") - - async def _process_lora_file(self, file_path: str, root_path: str) -> Dict: - """Process a single LoRA file and return its metadata""" - # Try loading existing metadata - metadata = await load_metadata(file_path) - - if metadata is None: - # Try to find and use .civitai.info file first - civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info" - if os.path.exists(civitai_info_path): - try: - with open(civitai_info_path, 'r', encoding='utf-8') as f: - version_info = json.load(f) - - file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) - if file_info: - # Create a minimal file_info with the required fields - file_name = os.path.splitext(os.path.basename(file_path))[0] - file_info['name'] = file_name - - # Use from_civitai_info to create metadata - metadata = LoraMetadata.from_civitai_info(version_info, file_info, file_path) - metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path)) - await save_metadata(file_path, metadata) - logger.debug(f"Created metadata from .civitai.info for {file_path}") - except Exception as e: - logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}") - - # If still no metadata, create new metadata using get_file_info - if metadata is None: - metadata = await get_file_info(file_path) - - # Convert to dict and add folder info - lora_data = metadata.to_dict() - # Try to fetch missing metadata from Civitai if needed - await self._fetch_missing_metadata(file_path, lora_data) - rel_path = os.path.relpath(file_path, root_path) - folder = os.path.dirname(rel_path) - lora_data['folder'] = folder.replace(os.path.sep, '/') - - return lora_data - - async def 
_fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None: - """Fetch missing description and tags from Civitai if needed - - Args: - file_path: Path to the lora file - lora_data: Lora metadata dictionary to update - """ - try: - # Skip if already marked as deleted on Civitai - if lora_data.get('civitai_deleted', False): - logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai") - return - - # Check if we need to fetch additional metadata from Civitai - needs_metadata_update = False - model_id = None - - # Check if we have Civitai model ID but missing metadata - if lora_data.get('civitai'): - # Try to get model ID directly from the correct location - model_id = lora_data['civitai'].get('modelId') - - if model_id: - model_id = str(model_id) - # Check if tags are missing or empty - tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0 - - # Check if description is missing or empty - desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "") - - needs_metadata_update = tags_missing or desc_missing - - # Fetch missing metadata if needed - if needs_metadata_update and model_id: - logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}") - from ..services.civitai_client import CivitaiClient - client = CivitaiClient() - - # Get metadata and status code - model_metadata, status_code = await client.get_model_metadata(model_id) - await client.close() - - # Handle 404 status (model deleted from Civitai) - if status_code == 404: - logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)") - # Mark as deleted to avoid future API calls - lora_data['civitai_deleted'] = True - - # Save the updated metadata back to file - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(lora_data, f, indent=2, ensure_ascii=False) - - # Process valid 
metadata if available - elif model_metadata: - logger.debug(f"Updating metadata for {file_path} with model ID {model_id}") - - # Update tags if they were missing - if model_metadata.get('tags') and (not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0): - lora_data['tags'] = model_metadata['tags'] - - # Update description if it was missing - if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")): - lora_data['modelDescription'] = model_metadata['description'] - - # Save the updated metadata back to file - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(lora_data, f, indent=2, ensure_ascii=False) - except Exception as e: - logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}") - - async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool: - """Update preview URL in cache for a specific lora - - Args: - file_path: The file path of the lora to update - preview_url: The new preview URL - - Returns: - bool: True if the update was successful, False if cache doesn't exist or lora wasn't found - """ - if self._cache is None: - return False - - return await self._cache.update_preview_url(file_path, preview_url) - - async def scan_single_lora(self, file_path: str) -> Optional[Dict]: - """Scan a single LoRA file and return its metadata""" - try: - if not os.path.exists(os.path.realpath(file_path)): - return None - - # 获取基本文件信息 - metadata = await get_file_info(file_path) - if not metadata: - return None - - folder = self._calculate_folder(file_path) - - # 确保 folder 字段存在 - metadata_dict = metadata.to_dict() - metadata_dict['folder'] = folder or '' - - return metadata_dict - - except Exception as e: - logger.error(f"Error scanning {file_path}: {e}") - return None - - def _calculate_folder(self, file_path: str) -> str: - """Calculate the folder path for a LoRA 
file""" - # 使用原始路径计算相对路径 - for root in config.loras_roots: - if file_path.startswith(root): - rel_path = os.path.relpath(file_path, root) - return os.path.dirname(rel_path).replace(os.path.sep, '/') - return '' - async def move_model(self, source_path: str, target_path: str) -> bool: """Move a model and its associated files to a new location""" try: - # 保持原始路径格式 + # Keep original path format source_path = source_path.replace(os.sep, '/') target_path = target_path.replace(os.sep, '/') - # 其余代码保持不变 + # Rest of the code remains unchanged base_name = os.path.splitext(os.path.basename(source_path))[0] source_dir = os.path.dirname(source_path) @@ -521,7 +278,7 @@ class LoraScanner: target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/') - # 使用真实路径进行文件操作 + # Use real paths for file operations real_source = os.path.realpath(source_path) real_target = os.path.realpath(target_lora) @@ -537,7 +294,7 @@ class LoraScanner: file_size ) - # 使用真实路径进行文件操作 + # Use real paths for file operations shutil.move(real_source, real_target) # Move associated files @@ -648,7 +405,7 @@ class LoraScanner: except Exception as e: logger.error(f"Error updating metadata paths: {e}", exc_info=True) - # Add new methods for hash index functionality + # Lora-specific hash index functionality def has_lora_hash(self, sha256: str) -> bool: """Check if a LoRA with given hash exists""" return self._hash_index.has_hash(sha256.lower()) @@ -681,16 +438,8 @@ class LoraScanner: return None - # Add new method to get top tags async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: - """Get top tags sorted by count - - Args: - limit: Maximum number of tags to return - - Returns: - List of dictionaries with tag name and count, sorted by count - """ + """Get top tags sorted by count""" # Make sure cache is initialized await self.get_cached_data() @@ -705,14 +454,7 @@ class LoraScanner: return sorted_tags[:limit] async def get_base_models(self, limit: int = 20) -> 
List[Dict[str, any]]: - """Get base models used in loras sorted by frequency - - Args: - limit: Maximum number of base models to return - - Returns: - List of dictionaries with base model name and count, sorted by count - """ + """Get base models used in loras sorted by frequency""" # Make sure cache is initialized cache = await self.get_cached_data() diff --git a/py/services/model_cache.py b/py/services/model_cache.py new file mode 100644 index 00000000..b652d919 --- /dev/null +++ b/py/services/model_cache.py @@ -0,0 +1,64 @@ +import asyncio +from typing import List, Dict +from dataclasses import dataclass +from operator import itemgetter + +@dataclass +class ModelCache: + """Cache structure for model data""" + raw_data: List[Dict] + sorted_by_name: List[Dict] + sorted_by_date: List[Dict] + folders: List[str] + + def __post_init__(self): + self._lock = asyncio.Lock() + + async def resort(self, name_only: bool = False): + """Resort all cached data views""" + async with self._lock: + self.sorted_by_name = sorted( + self.raw_data, + key=lambda x: x['model_name'].lower() # Case-insensitive sort + ) + if not name_only: + self.sorted_by_date = sorted( + self.raw_data, + key=itemgetter('modified'), + reverse=True + ) + # Update folder list + all_folders = set(l['folder'] for l in self.raw_data) + self.folders = sorted(list(all_folders), key=lambda x: x.lower()) + + async def update_preview_url(self, file_path: str, preview_url: str) -> bool: + """Update preview_url for a specific model in all cached data + + Args: + file_path: The file path of the model to update + preview_url: The new preview URL + + Returns: + bool: True if the update was successful, False if the model wasn't found + """ + async with self._lock: + # Update in raw_data + for item in self.raw_data: + if item['file_path'] == file_path: + item['preview_url'] = preview_url + break + else: + return False # Model not found + + # Update in sorted lists (references to the same dict objects) + for item in 
self.sorted_by_name: + if item['file_path'] == file_path: + item['preview_url'] = preview_url + break + + for item in self.sorted_by_date: + if item['file_path'] == file_path: + item['preview_url'] = preview_url + break + + return True \ No newline at end of file diff --git a/py/services/model_hash_index.py b/py/services/model_hash_index.py new file mode 100644 index 00000000..2f8ef0eb --- /dev/null +++ b/py/services/model_hash_index.py @@ -0,0 +1,78 @@ +from typing import Dict, Optional, Set + +class ModelHashIndex: + """Index for looking up models by hash or path""" + + def __init__(self): + self._hash_to_path: Dict[str, str] = {} + self._path_to_hash: Dict[str, str] = {} + + def add_entry(self, sha256: str, file_path: str) -> None: + """Add or update hash index entry""" + if not sha256 or not file_path: + return + + # Ensure hash is lowercase for consistency + sha256 = sha256.lower() + + # Remove old path mapping if hash exists + if sha256 in self._hash_to_path: + old_path = self._hash_to_path[sha256] + if old_path in self._path_to_hash: + del self._path_to_hash[old_path] + + # Remove old hash mapping if path exists + if file_path in self._path_to_hash: + old_hash = self._path_to_hash[file_path] + if old_hash in self._hash_to_path: + del self._hash_to_path[old_hash] + + # Add new mappings + self._hash_to_path[sha256] = file_path + self._path_to_hash[file_path] = sha256 + + def remove_by_path(self, file_path: str) -> None: + """Remove entry by file path""" + if file_path in self._path_to_hash: + hash_val = self._path_to_hash[file_path] + if hash_val in self._hash_to_path: + del self._hash_to_path[hash_val] + del self._path_to_hash[file_path] + + def remove_by_hash(self, sha256: str) -> None: + """Remove entry by hash""" + sha256 = sha256.lower() + if sha256 in self._hash_to_path: + path = self._hash_to_path[sha256] + if path in self._path_to_hash: + del self._path_to_hash[path] + del self._hash_to_path[sha256] + + def has_hash(self, sha256: str) -> bool: + 
"""Check if hash exists in index""" + return sha256.lower() in self._hash_to_path + + def get_path(self, sha256: str) -> Optional[str]: + """Get file path for a hash""" + return self._hash_to_path.get(sha256.lower()) + + def get_hash(self, file_path: str) -> Optional[str]: + """Get hash for a file path""" + return self._path_to_hash.get(file_path) + + def clear(self) -> None: + """Clear all entries""" + self._hash_to_path.clear() + self._path_to_hash.clear() + + def get_all_hashes(self) -> Set[str]: + """Get all hashes in the index""" + return set(self._hash_to_path.keys()) + + def get_all_paths(self) -> Set[str]: + """Get all file paths in the index""" + return set(self._path_to_hash.keys()) + + def __len__(self) -> int: + """Get number of entries""" + return len(self._hash_to_path) \ No newline at end of file diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py new file mode 100644 index 00000000..b15e35c9 --- /dev/null +++ b/py/services/model_scanner.py @@ -0,0 +1,554 @@ +import json +import os +import logging +import asyncio +import time +import shutil +from typing import List, Dict, Optional, Type, Set + +from ..utils.models import BaseModelMetadata +from ..config import config +from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, save_metadata +from .model_cache import ModelCache +from .model_hash_index import ModelHashIndex + +logger = logging.getLogger(__name__) + +class ModelScanner: + """Base service for scanning and managing model files""" + + _instance = None + _lock = asyncio.Lock() + + def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None): + """Initialize the scanner + + Args: + model_type: Type of model (lora, checkpoint, etc.) + model_class: Class used to create metadata instances + file_extensions: Set of supported file extensions including the dot (e.g. 
{'.safetensors'}) + hash_index: Hash index instance (optional) + """ + self.model_type = model_type + self.model_class = model_class + self.file_extensions = file_extensions + self._cache = None + self._hash_index = hash_index or ModelHashIndex() + self._initialization_lock = asyncio.Lock() + self._initialization_task = None + self.file_monitor = None + self._tags_count = {} # Dictionary to store tag counts + + def set_file_monitor(self, monitor): + """Set file monitor instance""" + self.file_monitor = monitor + + async def get_cached_data(self, force_refresh: bool = False) -> ModelCache: + """Get cached model data, refresh if needed""" + async with self._initialization_lock: + # Return empty cache if not initialized and no refresh requested + if self._cache is None and not force_refresh: + return ModelCache( + raw_data=[], + sorted_by_name=[], + sorted_by_date=[], + folders=[] + ) + + # Wait for ongoing initialization if any + if self._initialization_task and not self._initialization_task.done(): + try: + await self._initialization_task + except Exception as e: + logger.error(f"Cache initialization failed: {e}") + self._initialization_task = None + + if (self._cache is None or force_refresh): + # Create new initialization task + if not self._initialization_task or self._initialization_task.done(): + self._initialization_task = asyncio.create_task(self._initialize_cache()) + + try: + await self._initialization_task + except Exception as e: + logger.error(f"Cache initialization failed: {e}") + # Continue using old cache if it exists + if self._cache is None: + raise # Raise exception if no cache available + + return self._cache + + async def _initialize_cache(self) -> None: + """Initialize or refresh the cache""" + try: + start_time = time.time() + # Clear existing hash index + self._hash_index.clear() + + # Clear existing tags count + self._tags_count = {} + + # Scan for new data + raw_data = await self.scan_all_models() + + # Build hash index and tags count + for 
model_data in raw_data: + if 'sha256' in model_data and 'file_path' in model_data: + self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path']) + + # Count tags + if 'tags' in model_data and model_data['tags']: + for tag in model_data['tags']: + self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 + + # Update cache + self._cache = ModelCache( + raw_data=raw_data, + sorted_by_name=[], + sorted_by_date=[], + folders=[] + ) + + # Resort cache + await self._cache.resort() + + self._initialization_task = None + logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models") + except Exception as e: + logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}") + self._cache = ModelCache( + raw_data=[], + sorted_by_name=[], + sorted_by_date=[], + folders=[] + ) + + # These methods should be implemented in child classes + async def scan_all_models(self) -> List[Dict]: + """Scan all model directories and return metadata""" + raise NotImplementedError("Subclasses must implement scan_all_models") + + def get_model_roots(self) -> List[str]: + """Get model root directories""" + raise NotImplementedError("Subclasses must implement get_model_roots") + + async def scan_single_model(self, file_path: str) -> Optional[Dict]: + """Scan a single model file and return its metadata""" + try: + if not os.path.exists(os.path.realpath(file_path)): + return None + + # Get basic file info + metadata = await self._get_file_info(file_path) + if not metadata: + return None + + folder = self._calculate_folder(file_path) + + # Ensure folder field exists + metadata_dict = metadata.to_dict() + metadata_dict['folder'] = folder or '' + + return metadata_dict + + except Exception as e: + logger.error(f"Error scanning {file_path}: {e}") + return None + + async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]: + """Get model file info 
and metadata (extensible for different model types)""" + # Implementation may vary by model type - override in subclasses if needed + return await get_file_info(file_path, self.model_class) + + def _calculate_folder(self, file_path: str) -> str: + """Calculate the folder path for a model file""" + # Use original path to calculate relative path + for root in self.get_model_roots(): + if file_path.startswith(root): + rel_path = os.path.relpath(file_path, root) + return os.path.dirname(rel_path).replace(os.path.sep, '/') + return '' + + # Common methods shared between scanners + async def _process_model_file(self, file_path: str, root_path: str) -> Dict: + """Process a single model file and return its metadata""" + # Try loading existing metadata + metadata = await load_metadata(file_path, self.model_class) + + if metadata is None: + # Try to find and use .civitai.info file first + civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info" + if os.path.exists(civitai_info_path): + try: + with open(civitai_info_path, 'r', encoding='utf-8') as f: + version_info = json.load(f) + + file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) + if file_info: + # Create a minimal file_info with the required fields + file_name = os.path.splitext(os.path.basename(file_path))[0] + file_info['name'] = file_name + + # Use from_civitai_info to create metadata + metadata = self.model_class.from_civitai_info(version_info, file_info, file_path) + metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path)) + await save_metadata(file_path, metadata) + logger.debug(f"Created metadata from .civitai.info for {file_path}") + except Exception as e: + logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}") + + # If still no metadata, create new metadata + if metadata is None: + metadata = await self._get_file_info(file_path) + + # Convert to dict and add folder info + model_data = metadata.to_dict() + + # Try 
to fetch missing metadata from Civitai if needed + await self._fetch_missing_metadata(file_path, model_data) + rel_path = os.path.relpath(file_path, root_path) + folder = os.path.dirname(rel_path) + model_data['folder'] = folder.replace(os.path.sep, '/') + + return model_data + + async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None: + """Fetch missing description and tags from Civitai if needed""" + try: + # Skip if already marked as deleted on Civitai + if model_data.get('civitai_deleted', False): + logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai") + return + + # Check if we need to fetch additional metadata from Civitai + needs_metadata_update = False + model_id = None + + # Check if we have Civitai model ID but missing metadata + if model_data.get('civitai'): + model_id = model_data['civitai'].get('modelId') + + if model_id: + model_id = str(model_id) + # Check if tags or description are missing + tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0 + desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "") + needs_metadata_update = tags_missing or desc_missing + + # Fetch missing metadata if needed + if needs_metadata_update and model_id: + logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}") + from ..services.civitai_client import CivitaiClient + client = CivitaiClient() + + # Get metadata and status code + model_metadata, status_code = await client.get_model_metadata(model_id) + await client.close() + + # Handle 404 status (model deleted from Civitai) + if status_code == 404: + logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)") + model_data['civitai_deleted'] = True + + # Save the updated metadata + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(model_data, f, 
indent=2, ensure_ascii=False) + + # Process valid metadata if available + elif model_metadata: + logger.debug(f"Updating metadata for {file_path} with model ID {model_id}") + + # Update tags if they were missing + if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0): + model_data['tags'] = model_metadata['tags'] + + # Update description if it was missing + if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")): + model_data['modelDescription'] = model_metadata['description'] + + # Save the updated metadata + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(model_data, f, indent=2, ensure_ascii=False) + except Exception as e: + logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}") + + async def _scan_directory(self, root_path: str) -> List[Dict]: + """Base implementation for directory scanning""" + models = [] + original_root = root_path + + async def scan_recursive(path: str, visited_paths: set): + try: + real_path = os.path.realpath(path) + if real_path in visited_paths: + logger.debug(f"Skipping already visited path: {path}") + return + visited_paths.add(real_path) + + with os.scandir(path) as it: + entries = list(it) + for entry in entries: + try: + if entry.is_file(follow_symlinks=True): + # Check if file has supported extension + ext = os.path.splitext(entry.name)[1].lower() + if ext in self.file_extensions: + file_path = entry.path.replace(os.sep, "/") + await self._process_single_file(file_path, original_root, models) + await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + # For directories, continue scanning with original path + await scan_recursive(entry.path, visited_paths) + except Exception as e: + logger.error(f"Error processing entry {entry.path}: {e}") + except Exception as e: + logger.error(f"Error 
scanning {path}: {e}") + + await scan_recursive(root_path, set()) + return models + + async def _process_single_file(self, file_path: str, root_path: str, models_list: list): + """Process a single file and add to results list""" + try: + result = await self._process_model_file(file_path, root_path) + if result: + models_list.append(result) + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") + + async def move_model(self, source_path: str, target_path: str) -> bool: + """Move a model and its associated files to a new location""" + try: + # Keep original path format + source_path = source_path.replace(os.sep, '/') + target_path = target_path.replace(os.sep, '/') + + # Get file extension from source + file_ext = os.path.splitext(source_path)[1] + + # If no extension or not in supported extensions, return False + if not file_ext or file_ext.lower() not in self.file_extensions: + logger.error(f"Invalid file extension for model: {file_ext}") + return False + + base_name = os.path.splitext(os.path.basename(source_path))[0] + source_dir = os.path.dirname(source_path) + + os.makedirs(target_path, exist_ok=True) + + target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/') + + # Use real paths for file operations + real_source = os.path.realpath(source_path) + real_target = os.path.realpath(target_file) + + file_size = os.path.getsize(real_source) + + if self.file_monitor: + self.file_monitor.handler.add_ignore_path( + real_source, + file_size + ) + self.file_monitor.handler.add_ignore_path( + real_target, + file_size + ) + + # Use real paths for file operations + shutil.move(real_source, real_target) + + # Move associated files + source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json") + metadata = None + if os.path.exists(source_metadata): + target_metadata = os.path.join(target_path, f"{base_name}.metadata.json") + shutil.move(source_metadata, target_metadata) + metadata = await 
self._update_metadata_paths(target_metadata, target_file) + + # Move preview file if exists + preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', + '.png', '.jpeg', '.jpg', '.mp4'] + for ext in preview_extensions: + source_preview = os.path.join(source_dir, f"{base_name}{ext}") + if os.path.exists(source_preview): + target_preview = os.path.join(target_path, f"{base_name}{ext}") + shutil.move(source_preview, target_preview) + break + + # Update cache + await self.update_single_model_cache(source_path, target_file, metadata) + + return True + + except Exception as e: + logger.error(f"Error moving model: {e}", exc_info=True) + return False + + async def _update_metadata_paths(self, metadata_path: str, model_path: str) -> Dict: + """Update file paths in metadata file""" + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + metadata = json.load(f) + + # Update file_path + metadata['file_path'] = model_path.replace(os.sep, '/') + + # Update preview_url if exists + if 'preview_url' in metadata: + preview_dir = os.path.dirname(model_path) + preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0] + preview_ext = os.path.splitext(metadata['preview_url'])[1] + new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}") + metadata['preview_url'] = new_preview_path.replace(os.sep, '/') + + # Save updated metadata + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + + return metadata + + except Exception as e: + logger.error(f"Error updating metadata paths: {e}", exc_info=True) + return None + + async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool: + """Update cache after a model has been moved or modified""" + cache = await self.get_cached_data() + + # Find the existing item to remove its tags from count + existing_item = next((item for item in cache.raw_data if item['file_path'] == 
original_path), None) + if existing_item and 'tags' in existing_item: + for tag in existing_item.get('tags', []): + if tag in self._tags_count: + self._tags_count[tag] = max(0, self._tags_count[tag] - 1) + if self._tags_count[tag] == 0: + del self._tags_count[tag] + + # Remove old path from hash index if exists + self._hash_index.remove_by_path(original_path) + + # Remove the old entry from raw_data + cache.raw_data = [ + item for item in cache.raw_data + if item['file_path'] != original_path + ] + + if metadata: + # If this is an update to an existing path (not a move), ensure folder is preserved + if original_path == new_path: + # Find the folder from existing entries or calculate it + existing_folder = next((item['folder'] for item in cache.raw_data + if item['file_path'] == original_path), None) + if existing_folder: + metadata['folder'] = existing_folder + else: + metadata['folder'] = self._calculate_folder(new_path) + else: + # For moved files, recalculate the folder + metadata['folder'] = self._calculate_folder(new_path) + + # Add the updated metadata to raw_data + cache.raw_data.append(metadata) + + # Update hash index with new path + if 'sha256' in metadata: + self._hash_index.add_entry(metadata['sha256'].lower(), new_path) + + # Update folders list + all_folders = set(item['folder'] for item in cache.raw_data) + cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) + + # Update tags count with the new/updated tags + if 'tags' in metadata: + for tag in metadata.get('tags', []): + self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 + + # Resort cache + await cache.resort() + + return True + + # Hash index functionality (common for all model types) + def has_hash(self, sha256: str) -> bool: + """Check if a model with given hash exists""" + return self._hash_index.has_hash(sha256.lower()) + + def get_path_by_hash(self, sha256: str) -> Optional[str]: + """Get file path for a model by its hash""" + return 
self._hash_index.get_path(sha256.lower()) + + def get_hash_by_path(self, file_path: str) -> Optional[str]: + """Get hash for a model by its file path""" + return self._hash_index.get_hash(file_path) + + def get_preview_url_by_hash(self, sha256: str) -> Optional[str]: + """Get preview static URL for a model by its hash""" + # Get the file path first + file_path = self._hash_index.get_path(sha256.lower()) + if not file_path: + return None + + # Determine the preview file path (typically same name with different extension) + base_name = os.path.splitext(file_path)[0] + preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', + '.png', '.jpeg', '.jpg', '.mp4'] + + for ext in preview_extensions: + preview_path = f"{base_name}{ext}" + if os.path.exists(preview_path): + # Convert to static URL using config + return config.get_preview_static_url(preview_path) + + return None + + async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: + """Get top tags sorted by count""" + # Make sure cache is initialized + await self.get_cached_data() + + # Sort tags by count in descending order + sorted_tags = sorted( + [{"tag": tag, "count": count} for tag, count in self._tags_count.items()], + key=lambda x: x['count'], + reverse=True + ) + + # Return limited number + return sorted_tags[:limit] + + async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]: + """Get base models sorted by frequency""" + # Make sure cache is initialized + cache = await self.get_cached_data() + + # Count base model occurrences + base_model_counts = {} + for model in cache.raw_data: + if 'base_model' in model and model['base_model']: + base_model = model['base_model'] + base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 + + # Sort base models by count + sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()] + sorted_models.sort(key=lambda x: x['count'], reverse=True) + + # Return limited 
number + return sorted_models[:limit] + + async def get_model_info_by_name(self, name): + """Get model information by name""" + try: + # Get cached data + cache = await self.get_cached_data() + + # Find the model by name + for model in cache.raw_data: + if model.get("file_name") == name: + return model + + return None + except Exception as e: + logger.error(f"Error getting model info by name: {e}", exc_info=True) + return None \ No newline at end of file diff --git a/py/utils/file_utils.py b/py/utils/file_utils.py index 0f282051..859e86ae 100644 --- a/py/utils/file_utils.py +++ b/py/utils/file_utils.py @@ -2,12 +2,12 @@ import logging import os import hashlib import json -from typing import Dict, Optional +import time +from typing import Dict, Optional, Type from .model_utils import determine_base_model - -from .lora_metadata import extract_lora_metadata -from .models import LoraMetadata +from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata +from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata logger = logging.getLogger(__name__) @@ -15,7 +15,7 @@ async def calculate_sha256(file_path: str) -> str: """Calculate SHA256 hash of a file""" sha256_hash = hashlib.sha256() with open(file_path, "rb") as f: - for byte_block in iter(lambda: f.read(4096), b""): + for byte_block in iter(lambda: f.read(128 * 1024), b""): sha256_hash.update(byte_block) return sha256_hash.hexdigest() @@ -42,8 +42,8 @@ def normalize_path(path: str) -> str: """Normalize file path to use forward slashes""" return path.replace(os.sep, "/") if path else path -async def get_file_info(file_path: str) -> Optional[LoraMetadata]: - """Get basic file information as LoraMetadata object""" +async def get_file_info(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]: + """Get basic file information as a model metadata object""" # First check if file actually exists and resolve symlinks try: real_path = 
os.path.realpath(file_path) @@ -74,27 +74,52 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]: try: # If we didn't get SHA256 from the .json file, calculate it if not sha256: + start_time = time.time() sha256 = await calculate_sha256(real_path) + logger.debug(f"Calculated SHA256 for {file_path} in {time.time() - start_time:.2f} seconds") + + # Create default metadata based on model class + if model_class == CheckpointMetadata: + metadata = CheckpointMetadata( + file_name=base_name, + model_name=base_name, + file_path=normalize_path(file_path), + size=os.path.getsize(real_path), + modified=os.path.getmtime(real_path), + sha256=sha256, + base_model="Unknown", # Will be updated later + preview_url=normalize_path(preview_url), + tags=[], + modelDescription="", + model_type="checkpoint" + ) - metadata = LoraMetadata( - file_name=base_name, - model_name=base_name, - file_path=normalize_path(file_path), - size=os.path.getsize(real_path), - modified=os.path.getmtime(real_path), - sha256=sha256, - base_model="Unknown", # Will be updated later - usage_tips="", - notes="", - from_civitai=True, - preview_url=normalize_path(preview_url), - tags=[], - modelDescription="" - ) + # Extract checkpoint-specific metadata + # model_info = await extract_checkpoint_metadata(real_path) + # metadata.base_model = model_info['base_model'] + # if 'model_type' in model_info: + # metadata.model_type = model_info['model_type'] + + else: # Default to LoraMetadata + metadata = LoraMetadata( + file_name=base_name, + model_name=base_name, + file_path=normalize_path(file_path), + size=os.path.getsize(real_path), + modified=os.path.getmtime(real_path), + sha256=sha256, + base_model="Unknown", # Will be updated later + usage_tips="{}", + preview_url=normalize_path(preview_url), + tags=[], + modelDescription="" + ) + + # Extract lora-specific metadata + model_info = await extract_lora_metadata(real_path) + metadata.base_model = model_info['base_model'] - # create metadata file - 
base_model_info = await extract_lora_metadata(real_path) - metadata.base_model = base_model_info['base_model'] + # Save metadata to file await save_metadata(file_path, metadata) return metadata @@ -102,7 +127,7 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]: logger.error(f"Error getting file info for {file_path}: {e}") return None -async def save_metadata(file_path: str, metadata: LoraMetadata) -> None: +async def save_metadata(file_path: str, metadata: BaseModelMetadata) -> None: """Save metadata to .metadata.json file""" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" try: @@ -115,7 +140,7 @@ async def save_metadata(file_path: str, metadata: LoraMetadata) -> None: except Exception as e: print(f"Error saving metadata to {metadata_path}: {str(e)}") -async def load_metadata(file_path: str) -> Optional[LoraMetadata]: +async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]: """Load metadata from .metadata.json file""" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" try: @@ -162,12 +187,22 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]: if 'modelDescription' not in data: data['modelDescription'] = "" needs_update = True + + # For checkpoint metadata + if model_class == CheckpointMetadata and 'model_type' not in data: + data['model_type'] = "checkpoint" + needs_update = True + + # For lora metadata + if model_class == LoraMetadata and 'usage_tips' not in data: + data['usage_tips'] = "{}" + needs_update = True if needs_update: with open(metadata_path, 'w', encoding='utf-8') as f: json.dump(data, f, indent=2, ensure_ascii=False) - return LoraMetadata.from_dict(data) + return model_class.from_dict(data) except Exception as e: print(f"Error loading metadata from {metadata_path}: {str(e)}") diff --git a/py/utils/lora_metadata.py b/py/utils/lora_metadata.py index c04cd829..f221562d 100644 --- a/py/utils/lora_metadata.py +++ 
b/py/utils/lora_metadata.py @@ -1,6 +1,7 @@ from safetensors import safe_open from typing import Dict from .model_utils import determine_base_model +import os async def extract_lora_metadata(file_path: str) -> Dict: """Extract essential metadata from safetensors file""" @@ -13,4 +14,67 @@ async def extract_lora_metadata(file_path: str) -> Dict: return {"base_model": base_model} except Exception as e: print(f"Error reading metadata from {file_path}: {str(e)}") - return {"base_model": "Unknown"} \ No newline at end of file + return {"base_model": "Unknown"} + +async def extract_checkpoint_metadata(file_path: str) -> dict: + """Extract metadata from a checkpoint file to determine model type and base model""" + try: + # Analyze filename for clues about the model + filename = os.path.basename(file_path).lower() + + model_info = { + 'base_model': 'Unknown', + 'model_type': 'checkpoint' + } + + # Detect base model from filename + if 'xl' in filename or 'sdxl' in filename: + model_info['base_model'] = 'SDXL' + elif 'sd3' in filename: + model_info['base_model'] = 'SD3' + elif 'sd2' in filename or 'v2' in filename: + model_info['base_model'] = 'SD2.x' + elif 'sd1' in filename or 'v1' in filename: + model_info['base_model'] = 'SD1.5' + + # Detect model type from filename + if 'inpaint' in filename: + model_info['model_type'] = 'inpainting' + elif 'anime' in filename: + model_info['model_type'] = 'anime' + elif 'realistic' in filename: + model_info['model_type'] = 'realistic' + + # Try to peek at the safetensors file structure if available + if file_path.endswith('.safetensors'): + import json + import struct + + with open(file_path, 'rb') as f: + header_size = struct.unpack(' 'LoraMetadata': - """Create LoraMetadata instance from dictionary""" - # Create a copy of the data to avoid modifying the input + def from_dict(cls, data: Dict) -> 'BaseModelMetadata': + """Create instance from dictionary""" data_copy = data.copy() return cls(**data_copy) - @classmethod - def 
from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata': - """Create LoraMetadata instance from Civitai version info""" - file_name = file_info['name'] - base_model = determine_base_model(version_info.get('baseModel', '')) - - return cls( - file_name=os.path.splitext(file_name)[0], - model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]), - file_path=save_path.replace(os.sep, '/'), - size=file_info.get('sizeKB', 0) * 1024, - modified=datetime.now().timestamp(), - sha256=file_info['hashes'].get('SHA256', '').lower(), - base_model=base_model, - preview_url=None, # Will be updated after preview download - preview_nsfw_level=0, # Will be updated after preview download, it is decided by the nsfw level of the preview image - from_civitai=True, - civitai=version_info - ) - def to_dict(self) -> Dict: """Convert to dictionary for JSON serialization""" return asdict(self) @@ -76,30 +54,54 @@ class LoraMetadata: self.file_path = file_path.replace(os.sep, '/') @dataclass -class CheckpointMetadata: - """Represents the metadata structure for a Checkpoint model""" - file_name: str # The filename without extension - model_name: str # The checkpoint's name defined by the creator - file_path: str # Full path to the model file - size: int # File size in bytes - modified: float # Last modified timestamp - sha256: str # SHA256 hash of the file - base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.) - preview_url: str # Preview image URL - preview_nsfw_level: int = 0 # NSFW level of the preview image - model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.) 
- notes: str = "" # Additional notes - from_civitai: bool = True # Whether from Civitai - civitai: Optional[Dict] = None # Civitai API data if available - tags: List[str] = None # Model tags - modelDescription: str = "" # Full model description - - # Additional checkpoint-specific fields - resolution: Optional[str] = None # Native resolution (e.g., 512x512, 1024x1024) - vae_included: bool = False # Whether VAE is included in the checkpoint - architecture: str = "" # Model architecture (if known) - - def __post_init__(self): - if self.tags is None: - self.tags = [] +class LoraMetadata(BaseModelMetadata): + """Represents the metadata structure for a Lora model""" + usage_tips: str = "{}" # Usage tips for the model, json string + + @classmethod + def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata': + """Create LoraMetadata instance from Civitai version info""" + file_name = file_info['name'] + base_model = determine_base_model(version_info.get('baseModel', '')) + + return cls( + file_name=os.path.splitext(file_name)[0], + model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]), + file_path=save_path.replace(os.sep, '/'), + size=file_info.get('sizeKB', 0) * 1024, + modified=datetime.now().timestamp(), + sha256=file_info['hashes'].get('SHA256', '').lower(), + base_model=base_model, + preview_url=None, # Will be updated after preview download + preview_nsfw_level=0, # Will be updated after preview download + from_civitai=True, + civitai=version_info + ) + +@dataclass +class CheckpointMetadata(BaseModelMetadata): + """Represents the metadata structure for a Checkpoint model""" + model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.) 
+ + @classmethod + def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata': + """Create CheckpointMetadata instance from Civitai version info""" + file_name = file_info['name'] + base_model = determine_base_model(version_info.get('baseModel', '')) + model_type = version_info.get('type', 'checkpoint') + + return cls( + file_name=os.path.splitext(file_name)[0], + model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]), + file_path=save_path.replace(os.sep, '/'), + size=file_info.get('sizeKB', 0) * 1024, + modified=datetime.now().timestamp(), + sha256=file_info['hashes'].get('SHA256', '').lower(), + base_model=base_model, + preview_url=None, # Will be updated after preview download + preview_nsfw_level=0, + from_civitai=True, + civitai=version_info, + model_type=model_type + ) From 048d486fa6fefa02fcf71534025ef94e91d37767 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 10 Apr 2025 11:34:19 +0800 Subject: [PATCH 02/36] Refactor cache initialization in LoraManager and RecipeScanner for improved background processing and error handling --- py/lora_manager.py | 66 +------------ py/routes/checkpoints_routes.py | 62 ++++++++++++ py/routes/lora_routes.py | 53 ++++++----- py/routes/recipe_routes.py | 2 +- py/services/model_scanner.py | 162 +++++++++++++++----------------- py/services/recipe_scanner.py | 135 +++++++++++++++++--------- 6 files changed, 266 insertions(+), 214 deletions(-) diff --git a/py/lora_manager.py b/py/lora_manager.py index 14556fa4..b37d46a2 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -1,5 +1,4 @@ import asyncio -import os from server import PromptServer # type: ignore from .config import config from .routes.lora_routes import LoraRoutes @@ -10,9 +9,6 @@ from .services.lora_scanner import LoraScanner from .services.checkpoint_scanner import CheckpointScanner from .services.recipe_scanner import RecipeScanner from .services.file_monitor 
import LoraFileMonitor, CheckpointFileMonitor -from .services.lora_cache import LoraCache -from .services.recipe_cache import RecipeCache -from .services.model_cache import ModelCache import logging logger = logging.getLogger(__name__) @@ -117,68 +113,12 @@ class LoraManager: """Schedule cache initialization in the running event loop""" try: # Create low-priority initialization tasks - lora_task = asyncio.create_task(cls._initialize_lora_cache(lora_scanner), name='lora_cache_init') - checkpoint_task = asyncio.create_task(cls._initialize_checkpoint_cache(checkpoint_scanner), name='checkpoint_cache_init') - recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner), name='recipe_cache_init') - logger.info("Cache initialization tasks scheduled to run in background") + lora_task = asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init') + checkpoint_task = asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init') + recipe_task = asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init') except Exception as e: logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}") - @classmethod - async def _initialize_lora_cache(cls, scanner: LoraScanner): - """Initialize lora cache in background""" - try: - # Set initial placeholder cache - scanner._cache = LoraCache( - raw_data=[], - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - # 使用线程池执行耗时操作 - loop = asyncio.get_event_loop() - await loop.run_in_executor( - None, # 使用默认线程池 - lambda: scanner.get_cached_data_sync(force_refresh=True) # 创建同步版本的方法 - ) - # Load cache in phases - # await scanner.get_cached_data(force_refresh=True) - except Exception as e: - logger.error(f"LoRA Manager: Error initializing lora cache: {e}") - - @classmethod - async def _initialize_checkpoint_cache(cls, scanner: CheckpointScanner): - """Initialize checkpoint cache in background""" - try: - # Set initial 
placeholder cache - scanner._cache = ModelCache( - raw_data=[], - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - - # Load cache in phases - await scanner.get_cached_data(force_refresh=True) - except Exception as e: - logger.error(f"LoRA Manager: Error initializing checkpoint cache: {e}") - - @classmethod - async def _initialize_recipe_cache(cls, scanner: RecipeScanner): - """Initialize recipe cache in background with a delay""" - try: - # Set initial empty cache - scanner._cache = RecipeCache( - raw_data=[], - sorted_by_name=[], - sorted_by_date=[] - ) - - # Force refresh to load the actual data - await scanner.get_cached_data(force_refresh=True) - except Exception as e: - logger.error(f"LoRA Manager: Error initializing recipe cache: {e}") - @classmethod async def _cleanup(cls, app): """Cleanup resources""" diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index f2e2e631..12a8aeb1 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -2,12 +2,14 @@ import os import json import asyncio import aiohttp +import jinja2 from aiohttp import web import logging from datetime import datetime from ..services.checkpoint_scanner import CheckpointScanner from ..config import config +from ..services.settings_manager import settings logger = logging.getLogger(__name__) @@ -16,6 +18,10 @@ class CheckpointsRoutes: def __init__(self): self.scanner = CheckpointScanner() + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True + ) def setup_routes(self, app): """Register routes with the aiohttp app""" @@ -144,3 +150,59 @@ class CheckpointsRoutes: except Exception as e: logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True) return web.json_response({"error": str(e)}, status=500) + + async def handle_checkpoints_page(self, request: web.Request) -> web.Response: + """Handle GET /checkpoints request""" + try: + # 
检查缓存初始化状态,根据initialize_in_background的工作方式调整判断逻辑 + is_initializing = ( + self.scanner._cache is None or + len(self.scanner._cache.raw_data) == 0 or + hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing + ) + + if is_initializing: + # 如果正在初始化,返回一个只包含加载提示的页面 + template = self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=[], # 空文件夹列表 + is_initializing=True, # 新增标志 + settings=settings, # Pass settings to template + request=request # Pass the request object to the template + ) + + logger.info("Checkpoints page is initializing, returning loading page") + else: + # 正常流程 - 获取已经初始化好的缓存数据 + try: + cache = await self.scanner.get_cached_data(force_refresh=False) + template = self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=cache.folders, + is_initializing=False, + settings=settings, # Pass settings to template + request=request # Pass the request object to the template + ) + logger.debug(f"Checkpoints page loaded successfully with {len(cache.raw_data)} items") + except Exception as cache_error: + logger.error(f"Error loading checkpoints cache data: {cache_error}") + # 如果获取缓存失败,也显示初始化页面 + template = self.template_env.get_template('checkpoints.html') + rendered = template.render( + folders=[], + is_initializing=True, + settings=settings, + request=request + ) + logger.info("Checkpoints cache error, returning initialization page") + + return web.Response( + text=rendered, + content_type='text/html' + ) + except Exception as e: + logger.error(f"Error handling checkpoints request: {e}", exc_info=True) + return web.Response( + text="Error loading checkpoints page", + status=500 + ) diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 196bfa25..448c424f 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -58,13 +58,11 @@ class LoraRoutes: async def handle_loras_page(self, request: web.Request) -> web.Response: """Handle GET /loras request""" 
try: - # 检查缓存初始化状态,增强判断条件 + # 检查缓存初始化状态,根据initialize_in_background的工作方式调整判断逻辑 is_initializing = ( self.scanner._cache is None or - (self.scanner._initialization_task is not None and - not self.scanner._initialization_task.done()) or - (self.scanner._cache is not None and len(self.scanner._cache.raw_data) == 0 and - self.scanner._initialization_task is not None) + len(self.scanner._cache.raw_data) == 0 or + hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing ) if is_initializing: @@ -79,7 +77,7 @@ class LoraRoutes: logger.info("Loras page is initializing, returning loading page") else: - # 正常流程 - 但不要等待缓存刷新 + # 正常流程 - 获取已经初始化好的缓存数据 try: cache = await self.scanner.get_cached_data(force_refresh=False) template = self.template_env.get_template('loras.html') @@ -117,32 +115,45 @@ class LoraRoutes: async def handle_recipes_page(self, request: web.Request) -> web.Response: """Handle GET /loras/recipes request""" try: - # Check cache initialization status + # 检查缓存初始化状态,与handle_loras_page保持一致的逻辑 is_initializing = ( - self.recipe_scanner._cache is None and - (self.recipe_scanner._initialization_task is not None and - not self.recipe_scanner._initialization_task.done()) + self.recipe_scanner._cache is None or + len(self.recipe_scanner._cache.raw_data) == 0 or + hasattr(self.recipe_scanner, '_is_initializing') and self.recipe_scanner._is_initializing ) if is_initializing: - # If initializing, return a loading page + # 如果正在初始化,返回一个只包含加载提示的页面 template = self.template_env.get_template('recipes.html') rendered = template.render( is_initializing=True, settings=settings, request=request # Pass the request object to the template ) - else: - # return empty recipes - recipes_data = [] - template = self.template_env.get_template('recipes.html') - rendered = template.render( - recipes=recipes_data, - is_initializing=False, - settings=settings, - request=request # Pass the request object to the template - ) + logger.info("Recipes page is initializing, returning 
loading page") + else: + # 正常流程 - 获取已经初始化好的缓存数据 + try: + cache = await self.recipe_scanner.get_cached_data(force_refresh=False) + template = self.template_env.get_template('recipes.html') + rendered = template.render( + recipes=[], # Frontend will load recipes via API + is_initializing=False, + settings=settings, + request=request # Pass the request object to the template + ) + logger.debug(f"Recipes page loaded successfully with {len(cache.raw_data)} items") + except Exception as cache_error: + logger.error(f"Error loading recipe cache data: {cache_error}") + # 如果获取缓存失败,也显示初始化页面 + template = self.template_env.get_template('recipes.html') + rendered = template.render( + is_initializing=True, + settings=settings, + request=request + ) + logger.info("Recipe cache error, returning initialization page") return web.Response( text=rendered, diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 32de5722..796df6a1 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1146,7 +1146,7 @@ class RecipeRoutes: return web.json_response({'success': False, 'error': 'Lora hash is required'}, status=400) # Log the search parameters - logger.info(f"Getting recipes for Lora by hash: {lora_hash}") + logger.debug(f"Getting recipes for Lora by hash: {lora_hash}") # Get all recipes from cache cache = await self.recipe_scanner.get_cached_data() diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index b15e35c9..a4b83f6f 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -34,49 +34,96 @@ class ModelScanner: self.file_extensions = file_extensions self._cache = None self._hash_index = hash_index or ModelHashIndex() - self._initialization_lock = asyncio.Lock() - self._initialization_task = None self.file_monitor = None self._tags_count = {} # Dictionary to store tag counts + self._is_initializing = False # Flag to track initialization state def set_file_monitor(self, monitor): """Set file monitor 
instance""" self.file_monitor = monitor - async def get_cached_data(self, force_refresh: bool = False) -> ModelCache: - """Get cached model data, refresh if needed""" - async with self._initialization_lock: - # Return empty cache if not initialized and no refresh requested - if self._cache is None and not force_refresh: - return ModelCache( + async def initialize_in_background(self) -> None: + """Initialize cache in background using thread pool""" + try: + # Set initial empty cache to avoid None reference errors + if self._cache is None: + self._cache = ModelCache( raw_data=[], sorted_by_name=[], sorted_by_date=[], folders=[] ) - - # Wait for ongoing initialization if any - if self._initialization_task and not self._initialization_task.done(): - try: - await self._initialization_task - except Exception as e: - logger.error(f"Cache initialization failed: {e}") - self._initialization_task = None - if (self._cache is None or force_refresh): - # Create new initialization task - if not self._initialization_task or self._initialization_task.done(): - self._initialization_task = asyncio.create_task(self._initialize_cache()) + # Set initializing flag to true + self._is_initializing = True + + start_time = time.time() + # Use thread pool to execute CPU-intensive operations + loop = asyncio.get_event_loop() + await loop.run_in_executor( + None, # Use default thread pool + self._initialize_cache_sync # Run synchronous version in thread + ) + logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. 
Found {len(self._cache.raw_data)} models") + except Exception as e: + logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache in background: {e}") + finally: + # Always clear the initializing flag when done + self._is_initializing = False + + def _initialize_cache_sync(self): + """Synchronous version of cache initialization for thread pool execution""" + try: + # Create a new event loop for this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Create a synchronous method to bypass the async lock + def sync_initialize_cache(): + # Directly call the scan method to avoid lock issues + raw_data = loop.run_until_complete(self.scan_all_models()) - try: - await self._initialization_task - except Exception as e: - logger.error(f"Cache initialization failed: {e}") - # Continue using old cache if it exists - if self._cache is None: - raise # Raise exception if no cache available + # Update hash index and tags count + for model_data in raw_data: + if 'sha256' in model_data and 'file_path' in model_data: + self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path']) + + # Count tags + if 'tags' in model_data and model_data['tags']: + for tag in model_data['tags']: + self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 + + # Update cache + self._cache.raw_data = raw_data + loop.run_until_complete(self._cache.resort()) + + return self._cache - return self._cache + # Run our sync initialization that avoids lock conflicts + return sync_initialize_cache() + except Exception as e: + logger.error(f"Error in thread-based {self.model_type} cache initialization: {e}") + finally: + # Clean up the event loop + loop.close() + + async def get_cached_data(self, force_refresh: bool = False) -> ModelCache: + """Get cached model data, refresh if needed""" + # If cache is not initialized, return an empty cache + # Actual initialization should be done via initialize_in_background + if self._cache is None and not 
force_refresh: + return ModelCache( + raw_data=[], + sorted_by_name=[], + sorted_by_date=[], + folders=[] + ) + + # If force refresh is requested, initialize the cache directly + if force_refresh: + await self._initialize_cache() + + return self._cache async def _initialize_cache(self) -> None: """Initialize or refresh the cache""" @@ -112,7 +159,6 @@ class ModelScanner: # Resort cache await self._cache.resort() - self._initialization_task = None logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models") except Exception as e: logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}") @@ -157,12 +203,10 @@ class ModelScanner: async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]: """Get model file info and metadata (extensible for different model types)""" - # Implementation may vary by model type - override in subclasses if needed return await get_file_info(file_path, self.model_class) def _calculate_folder(self, file_path: str) -> str: """Calculate the folder path for a model file""" - # Use original path to calculate relative path for root in self.get_model_roots(): if file_path.startswith(root): rel_path = os.path.relpath(file_path, root) @@ -172,11 +216,9 @@ class ModelScanner: # Common methods shared between scanners async def _process_model_file(self, file_path: str, root_path: str) -> Dict: """Process a single model file and return its metadata""" - # Try loading existing metadata metadata = await load_metadata(file_path, self.model_class) if metadata is None: - # Try to find and use .civitai.info file first civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info" if os.path.exists(civitai_info_path): try: @@ -185,11 +227,9 @@ class ModelScanner: file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) if file_info: - # Create a minimal file_info with the required 
fields file_name = os.path.splitext(os.path.basename(file_path))[0] file_info['name'] = file_name - # Use from_civitai_info to create metadata metadata = self.model_class.from_civitai_info(version_info, file_info, file_path) metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path)) await save_metadata(file_path, metadata) @@ -197,14 +237,11 @@ class ModelScanner: except Exception as e: logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}") - # If still no metadata, create new metadata if metadata is None: metadata = await self._get_file_info(file_path) - # Convert to dict and add folder info model_data = metadata.to_dict() - # Try to fetch missing metadata from Civitai if needed await self._fetch_missing_metadata(file_path, model_data) rel_path = os.path.relpath(file_path, root_path) folder = os.path.dirname(rel_path) @@ -215,59 +252,47 @@ class ModelScanner: async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None: """Fetch missing description and tags from Civitai if needed""" try: - # Skip if already marked as deleted on Civitai if model_data.get('civitai_deleted', False): logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai") return - # Check if we need to fetch additional metadata from Civitai needs_metadata_update = False model_id = None - # Check if we have Civitai model ID but missing metadata if model_data.get('civitai'): model_id = model_data['civitai'].get('modelId') if model_id: model_id = str(model_id) - # Check if tags or description are missing tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0 desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "") needs_metadata_update = tags_missing or desc_missing - # Fetch missing metadata if needed if needs_metadata_update and model_id: logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}") from 
..services.civitai_client import CivitaiClient client = CivitaiClient() - # Get metadata and status code model_metadata, status_code = await client.get_model_metadata(model_id) await client.close() - # Handle 404 status (model deleted from Civitai) if status_code == 404: logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)") model_data['civitai_deleted'] = True - # Save the updated metadata metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' with open(metadata_path, 'w', encoding='utf-8') as f: json.dump(model_data, f, indent=2, ensure_ascii=False) - # Process valid metadata if available elif model_metadata: logger.debug(f"Updating metadata for {file_path} with model ID {model_id}") - # Update tags if they were missing if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0): model_data['tags'] = model_metadata['tags'] - # Update description if it was missing if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")): model_data['modelDescription'] = model_metadata['description'] - # Save the updated metadata metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' with open(metadata_path, 'w', encoding='utf-8') as f: json.dump(model_data, f, indent=2, ensure_ascii=False) @@ -292,14 +317,12 @@ class ModelScanner: for entry in entries: try: if entry.is_file(follow_symlinks=True): - # Check if file has supported extension ext = os.path.splitext(entry.name)[1].lower() if ext in self.file_extensions: file_path = entry.path.replace(os.sep, "/") await self._process_single_file(file_path, original_root, models) await asyncio.sleep(0) elif entry.is_dir(follow_symlinks=True): - # For directories, continue scanning with original path await scan_recursive(entry.path, visited_paths) except Exception as e: logger.error(f"Error processing entry {entry.path}: {e}") @@ -321,14 +344,11 @@ class ModelScanner: 
async def move_model(self, source_path: str, target_path: str) -> bool: """Move a model and its associated files to a new location""" try: - # Keep original path format source_path = source_path.replace(os.sep, '/') target_path = target_path.replace(os.sep, '/') - # Get file extension from source file_ext = os.path.splitext(source_path)[1] - # If no extension or not in supported extensions, return False if not file_ext or file_ext.lower() not in self.file_extensions: logger.error(f"Invalid file extension for model: {file_ext}") return False @@ -340,7 +360,6 @@ class ModelScanner: target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/') - # Use real paths for file operations real_source = os.path.realpath(source_path) real_target = os.path.realpath(target_file) @@ -356,10 +375,8 @@ class ModelScanner: file_size ) - # Use real paths for file operations shutil.move(real_source, real_target) - # Move associated files source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json") metadata = None if os.path.exists(source_metadata): @@ -367,7 +384,6 @@ class ModelScanner: shutil.move(source_metadata, target_metadata) metadata = await self._update_metadata_paths(target_metadata, target_file) - # Move preview file if exists preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', '.png', '.jpeg', '.jpg', '.mp4'] for ext in preview_extensions: @@ -377,7 +393,6 @@ class ModelScanner: shutil.move(source_preview, target_preview) break - # Update cache await self.update_single_model_cache(source_path, target_file, metadata) return True @@ -392,10 +407,8 @@ class ModelScanner: with open(metadata_path, 'r', encoding='utf-8') as f: metadata = json.load(f) - # Update file_path metadata['file_path'] = model_path.replace(os.sep, '/') - # Update preview_url if exists if 'preview_url' in metadata: preview_dir = os.path.dirname(model_path) preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0] @@ 
-403,7 +416,6 @@ class ModelScanner: new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}") metadata['preview_url'] = new_preview_path.replace(os.sep, '/') - # Save updated metadata with open(metadata_path, 'w', encoding='utf-8') as f: json.dump(metadata, f, indent=2, ensure_ascii=False) @@ -417,7 +429,6 @@ class ModelScanner: """Update cache after a model has been moved or modified""" cache = await self.get_cached_data() - # Find the existing item to remove its tags from count existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None) if existing_item and 'tags' in existing_item: for tag in existing_item.get('tags', []): @@ -426,19 +437,15 @@ class ModelScanner: if self._tags_count[tag] == 0: del self._tags_count[tag] - # Remove old path from hash index if exists self._hash_index.remove_by_path(original_path) - # Remove the old entry from raw_data cache.raw_data = [ item for item in cache.raw_data if item['file_path'] != original_path ] if metadata: - # If this is an update to an existing path (not a move), ensure folder is preserved if original_path == new_path: - # Find the folder from existing entries or calculate it existing_folder = next((item['folder'] for item in cache.raw_data if item['file_path'] == original_path), None) if existing_folder: @@ -446,31 +453,24 @@ class ModelScanner: else: metadata['folder'] = self._calculate_folder(new_path) else: - # For moved files, recalculate the folder metadata['folder'] = self._calculate_folder(new_path) - # Add the updated metadata to raw_data cache.raw_data.append(metadata) - # Update hash index with new path if 'sha256' in metadata: self._hash_index.add_entry(metadata['sha256'].lower(), new_path) - # Update folders list all_folders = set(item['folder'] for item in cache.raw_data) cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) - # Update tags count with the new/updated tags if 'tags' in metadata: for tag in metadata.get('tags', []): 
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 - # Resort cache await cache.resort() return True - # Hash index functionality (common for all model types) def has_hash(self, sha256: str) -> bool: """Check if a model with given hash exists""" return self._hash_index.has_hash(sha256.lower()) @@ -485,12 +485,10 @@ class ModelScanner: def get_preview_url_by_hash(self, sha256: str) -> Optional[str]: """Get preview static URL for a model by its hash""" - # Get the file path first file_path = self._hash_index.get_path(sha256.lower()) if not file_path: return None - # Determine the preview file path (typically same name with different extension) base_name = os.path.splitext(file_path)[0] preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', '.png', '.jpeg', '.jpg', '.mp4'] @@ -498,52 +496,42 @@ class ModelScanner: for ext in preview_extensions: preview_path = f"{base_name}{ext}" if os.path.exists(preview_path): - # Convert to static URL using config return config.get_preview_static_url(preview_path) return None async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: """Get top tags sorted by count""" - # Make sure cache is initialized await self.get_cached_data() - # Sort tags by count in descending order sorted_tags = sorted( [{"tag": tag, "count": count} for tag, count in self._tags_count.items()], key=lambda x: x['count'], reverse=True ) - # Return limited number return sorted_tags[:limit] async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]: """Get base models sorted by frequency""" - # Make sure cache is initialized cache = await self.get_cached_data() - # Count base model occurrences base_model_counts = {} for model in cache.raw_data: if 'base_model' in model and model['base_model']: base_model = model['base_model'] base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 - # Sort base models by count sorted_models = [{'name': model, 'count': count} for model, count in 
base_model_counts.items()] sorted_models.sort(key=lambda x: x['count'], reverse=True) - # Return limited number return sorted_models[:limit] async def get_model_info_by_name(self, name): """Get model information by name""" try: - # Get cached data cache = await self.get_cached_data() - # Find the model by name for model in cache.raw_data: if model.get("file_name") == name: return model diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index ec3310ee..588e4e43 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -35,9 +35,61 @@ class RecipeScanner: if lora_scanner: self._lora_scanner = lora_scanner self._initialized = True - - # Initialization will be scheduled by LoraManager + async def initialize_in_background(self) -> None: + """Initialize cache in background using thread pool""" + try: + # Set initial empty cache to avoid None reference errors + if self._cache is None: + self._cache = RecipeCache( + raw_data=[], + sorted_by_name=[], + sorted_by_date=[] + ) + + # Mark as initializing to prevent concurrent initializations + self._is_initializing = True + + try: + # Use thread pool to execute CPU-intensive operations + loop = asyncio.get_event_loop() + await loop.run_in_executor( + None, # Use default thread pool + self._initialize_recipe_cache_sync # Run synchronous version in thread + ) + logger.info("Recipe cache initialization completed in background thread") + finally: + # Mark initialization as complete regardless of outcome + self._is_initializing = False + except Exception as e: + logger.error(f"Recipe Scanner: Error initializing cache in background: {e}") + + def _initialize_recipe_cache_sync(self): + """Synchronous version of recipe cache initialization for thread pool execution""" + try: + # Create a new event loop for this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Create a synchronous method to bypass the async lock + def sync_initialize_cache(): + # Directly call 
the internal scan method to avoid lock issues + raw_data = loop.run_until_complete(self.scan_all_recipes()) + + # Update cache + self._cache.raw_data = raw_data + loop.run_until_complete(self._cache.resort()) + + return self._cache + + # Run our sync initialization that avoids lock conflicts + return sync_initialize_cache() + except Exception as e: + logger.error(f"Error in thread-based recipe cache initialization: {e}") + finally: + # Clean up the event loop + loop.close() + @property def recipes_dir(self) -> str: """Get path to recipes directory""" @@ -60,49 +112,48 @@ class RecipeScanner: if self._is_initializing and not force_refresh: return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) - # Try to acquire the lock with a timeout to prevent deadlocks - try: - async with self._initialization_lock: - # Check again after acquiring the lock - if self._cache is not None and not force_refresh: - return self._cache - - # Mark as initializing to prevent concurrent initializations - self._is_initializing = True - - try: - # Remove dependency on lora scanner initialization - # Scan for recipe data directly - raw_data = await self.scan_all_recipes() + # If force refresh is requested, initialize the cache directly + if force_refresh: + # Try to acquire the lock with a timeout to prevent deadlocks + try: + async with self._initialization_lock: + # Mark as initializing to prevent concurrent initializations + self._is_initializing = True - # Update cache - self._cache = RecipeCache( - raw_data=raw_data, - sorted_by_name=[], - sorted_by_date=[] - ) + try: + # Scan for recipe data directly + raw_data = await self.scan_all_recipes() + + # Update cache + self._cache = RecipeCache( + raw_data=raw_data, + sorted_by_name=[], + sorted_by_date=[] + ) + + # Resort cache + await self._cache.resort() + + return self._cache - # Resort cache - await self._cache.resort() - - return self._cache - - except Exception as e: - logger.error(f"Recipe Manager: Error 
initializing cache: {e}", exc_info=True) - # Create empty cache on error - self._cache = RecipeCache( - raw_data=[], - sorted_by_name=[], - sorted_by_date=[] - ) - return self._cache - finally: - # Mark initialization as complete - self._is_initializing = False + except Exception as e: + logger.error(f"Recipe Manager: Error initializing cache: {e}", exc_info=True) + # Create empty cache on error + self._cache = RecipeCache( + raw_data=[], + sorted_by_name=[], + sorted_by_date=[] + ) + return self._cache + finally: + # Mark initialization as complete + self._is_initializing = False + + except Exception as e: + logger.error(f"Unexpected error in get_cached_data: {e}") - except Exception as e: - logger.error(f"Unexpected error in get_cached_data: {e}") - return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) + # Return the cache (may be empty or partially initialized) + return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) async def scan_all_recipes(self) -> List[Dict]: """Scan all recipe JSON files and return metadata""" From 252e90a633441c53038920fc0d0d3ab0a0305863 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 10 Apr 2025 16:04:08 +0800 Subject: [PATCH 03/36] Enhance Checkpoints Manager: Implement API integration for checkpoints, add filtering and sorting options, and improve UI components for better user experience --- py/routes/checkpoints_routes.py | 225 ++++++++++++++++-- static/js/api/checkpointApi.js | 247 ++++++++++++++++++++ static/js/checkpoints.js | 120 ++++++++-- static/js/components/CheckpointCard.js | 147 ++++++++++++ static/js/components/CheckpointModal.js | 120 ++++++++++ static/js/managers/FilterManager.js | 5 +- static/js/managers/SearchManager.js | 1 + templates/checkpoints.html | 17 +- templates/components/checkpoint_modals.html | 35 +++ templates/components/header.html | 1 + 10 files changed, 877 insertions(+), 41 deletions(-) create mode 100644 
static/js/api/checkpointApi.js create mode 100644 static/js/components/CheckpointCard.js create mode 100644 static/js/components/CheckpointModal.js create mode 100644 templates/components/checkpoint_modals.html diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index 12a8aeb1..efd480dd 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -10,6 +10,7 @@ from datetime import datetime from ..services.checkpoint_scanner import CheckpointScanner from ..config import config from ..services.settings_manager import settings +from ..utils.utils import fuzzy_match logger = logging.getLogger(__name__) @@ -25,9 +26,12 @@ class CheckpointsRoutes: def setup_routes(self, app): """Register routes with the aiohttp app""" - app.router.add_get('/lora_manager/api/checkpoints', self.get_checkpoints) - app.router.add_get('/lora_manager/api/checkpoints/scan', self.scan_checkpoints) - app.router.add_get('/lora_manager/api/checkpoints/info/{name}', self.get_checkpoint_info) + app.router.add_get('/checkpoints', self.handle_checkpoints_page) + app.router.add_get('/api/checkpoints', self.get_checkpoints) + app.router.add_get('/api/checkpoints/base-models', self.get_base_models) + app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags) + app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) + app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info) async def get_checkpoints(self, request): """Get paginated checkpoint data""" @@ -76,8 +80,17 @@ class CheckpointsRoutes: hash_filters=hash_filters ) + # Format response items + formatted_result = { + 'items': [self._format_checkpoint_response(cp) for cp in result['items']], + 'total': result['total'], + 'page': result['page'], + 'page_size': result['page_size'], + 'total_pages': result['total_pages'] + } + # Return as JSON - return web.json_response(result) + return web.json_response(formatted_result) except Exception as e: logger.error(f"Error 
in get_checkpoints: {e}", exc_info=True) @@ -90,28 +103,122 @@ class CheckpointsRoutes: """Get paginated and filtered checkpoint data""" cache = await self.scanner.get_cached_data() - # Implement similar filtering logic as in LoraScanner - # (Adapt code from LoraScanner.get_paginated_data) - # ... - - # For now, a simplified implementation: + # Get default search options if not provided + if search_options is None: + search_options = { + 'filename': True, + 'modelname': True, + 'tags': False, + 'recursive': False, + } + + # Get the base data set filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name - # Apply basic folder filtering if needed + # Apply hash filtering if provided (highest priority) + if hash_filters: + single_hash = hash_filters.get('single_hash') + multiple_hashes = hash_filters.get('multiple_hashes') + + if single_hash: + # Filter by single hash + single_hash = single_hash.lower() # Ensure lowercase for matching + filtered_data = [ + cp for cp in filtered_data + if cp.get('sha256', '').lower() == single_hash + ] + elif multiple_hashes: + # Filter by multiple hashes + hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup + filtered_data = [ + cp for cp in filtered_data + if cp.get('sha256', '').lower() in hash_set + ] + + # Jump to pagination + total_items = len(filtered_data) + start_idx = (page - 1) * page_size + end_idx = min(start_idx + page_size, total_items) + + result = { + 'items': filtered_data[start_idx:end_idx], + 'total': total_items, + 'page': page, + 'page_size': page_size, + 'total_pages': (total_items + page_size - 1) // page_size + } + + return result + + # Apply SFW filtering if enabled in settings + if settings.get('show_only_sfw', False): + filtered_data = [ + cp for cp in filtered_data + if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R'] + ] + + # Apply folder filtering if folder is not None: + if 
search_options.get('recursive', False): + # Recursive folder filtering - include all subfolders + filtered_data = [ + cp for cp in filtered_data + if cp['folder'].startswith(folder) + ] + else: + # Exact folder filtering + filtered_data = [ + cp for cp in filtered_data + if cp['folder'] == folder + ] + + # Apply base model filtering + if base_models and len(base_models) > 0: filtered_data = [ cp for cp in filtered_data - if cp['folder'] == folder + if cp.get('base_model') in base_models ] - # Apply basic search if needed + # Apply tag filtering + if tags and len(tags) > 0: + filtered_data = [ + cp for cp in filtered_data + if any(tag in cp.get('tags', []) for tag in tags) + ] + + # Apply search filtering if search: - filtered_data = [ - cp for cp in filtered_data - if search.lower() in cp['file_name'].lower() or - search.lower() in cp['model_name'].lower() - ] - + search_results = [] + + for cp in filtered_data: + # Search by file name + if search_options.get('filename', True): + if fuzzy_search: + if fuzzy_match(cp.get('file_name', ''), search): + search_results.append(cp) + continue + elif search.lower() in cp.get('file_name', '').lower(): + search_results.append(cp) + continue + + # Search by model name + if search_options.get('modelname', True): + if fuzzy_search: + if fuzzy_match(cp.get('model_name', ''), search): + search_results.append(cp) + continue + elif search.lower() in cp.get('model_name', '').lower(): + search_results.append(cp) + continue + + # Search by tags + if search_options.get('tags', False) and 'tags' in cp: + if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']): + search_results.append(cp) + continue + + filtered_data = search_results + # Calculate pagination total_items = len(filtered_data) start_idx = (page - 1) * page_size @@ -127,6 +234,88 @@ class CheckpointsRoutes: return result + def _format_checkpoint_response(self, checkpoint): + """Format checkpoint data for API response""" + 
return { + "model_name": checkpoint["model_name"], + "file_name": checkpoint["file_name"], + "preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")), + "preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0), + "base_model": checkpoint.get("base_model", ""), + "folder": checkpoint["folder"], + "sha256": checkpoint.get("sha256", ""), + "file_path": checkpoint["file_path"].replace(os.sep, "/"), + "file_size": checkpoint.get("size", 0), + "modified": checkpoint.get("modified", ""), + "tags": checkpoint.get("tags", []), + "modelDescription": checkpoint.get("modelDescription", ""), + "from_civitai": checkpoint.get("from_civitai", True), + "notes": checkpoint.get("notes", ""), + "model_type": checkpoint.get("model_type", "checkpoint"), + "civitai": self._filter_civitai_data(checkpoint.get("civitai", {})) + } + + def _filter_civitai_data(self, data): + """Filter relevant fields from CivitAI data""" + if not data: + return {} + + fields = [ + "id", "modelId", "name", "createdAt", "updatedAt", + "publishedAt", "trainedWords", "baseModel", "description", + "model", "images" + ] + return {k: data[k] for k in fields if k in data} + + async def get_top_tags(self, request: web.Request) -> web.Response: + """Handle request for top tags sorted by frequency""" + try: + # Parse query parameters + limit = int(request.query.get('limit', '20')) + + # Validate limit + if limit < 1 or limit > 100: + limit = 20 # Default to a reasonable limit + + # Get top tags + top_tags = await self.scanner.get_top_tags(limit) + + return web.json_response({ + 'success': True, + 'tags': top_tags + }) + + except Exception as e: + logger.error(f"Error getting top tags: {str(e)}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': 'Internal server error' + }, status=500) + + async def get_base_models(self, request: web.Request) -> web.Response: + """Get base models used in loras""" + try: + # Parse query parameters + limit = 
int(request.query.get('limit', '20')) + + # Validate limit + if limit < 1 or limit > 100: + limit = 20 # Default to a reasonable limit + + # Get base models + base_models = await self.scanner.get_base_models(limit) + + return web.json_response({ + 'success': True, + 'base_models': base_models + }) + except Exception as e: + logger.error(f"Error retrieving base models: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + async def scan_checkpoints(self, request): """Force a rescan of checkpoint files""" try: diff --git a/static/js/api/checkpointApi.js b/static/js/api/checkpointApi.js new file mode 100644 index 00000000..cbf1ca0b --- /dev/null +++ b/static/js/api/checkpointApi.js @@ -0,0 +1,247 @@ +import { state, getCurrentPageState } from '../state/index.js'; +import { showToast } from '../utils/uiHelpers.js'; +import { confirmDelete } from '../utils/modalUtils.js'; +import { createCheckpointCard } from '../components/CheckpointCard.js'; + +// Load more checkpoints with pagination +export async function loadMoreCheckpoints(resetPagination = true) { + try { + const pageState = getCurrentPageState(); + + // Don't load if we're already loading or there are no more items + if (pageState.isLoading || (!resetPagination && !pageState.hasMore)) { + return; + } + + // Set loading state + pageState.isLoading = true; + document.body.classList.add('loading'); + + // Reset pagination if requested + if (resetPagination) { + pageState.currentPage = 1; + const grid = document.getElementById('checkpointGrid'); + if (grid) grid.innerHTML = ''; + } + + // Build API URL with parameters + const params = new URLSearchParams({ + page: pageState.currentPage, + page_size: pageState.pageSize || 20, + sort: pageState.sortBy || 'name' + }); + + // Add folder filter if active + if (pageState.activeFolder) { + params.append('folder', pageState.activeFolder); + } + + // Add search if available + if (pageState.filters && pageState.filters.search) { + 
params.append('search', pageState.filters.search); + + // Add search options + if (pageState.searchOptions) { + params.append('search_filename', pageState.searchOptions.filename.toString()); + params.append('search_modelname', pageState.searchOptions.modelname.toString()); + params.append('recursive', pageState.searchOptions.recursive.toString()); + } + } + + // Add base model filters + if (pageState.filters && pageState.filters.baseModel && pageState.filters.baseModel.length > 0) { + pageState.filters.baseModel.forEach(model => { + params.append('base_model', model); + }); + } + + // Add tags filters + if (pageState.filters && pageState.filters.tags && pageState.filters.tags.length > 0) { + pageState.filters.tags.forEach(tag => { + params.append('tag', tag); + }); + } + + // Execute fetch + const response = await fetch(`/api/checkpoints?${params.toString()}`); + + if (!response.ok) { + throw new Error(`Failed to load checkpoints: ${response.status} ${response.statusText}`); + } + + const data = await response.json(); + + // Update state with response data + pageState.hasMore = data.page < data.total_pages; + + // Update UI with checkpoints + const grid = document.getElementById('checkpointGrid'); + if (!grid) { + return; + } + + // Clear grid if this is the first page + if (resetPagination) { + grid.innerHTML = ''; + } + + // Check for empty result + if (data.items.length === 0 && resetPagination) { + grid.innerHTML = ` +
+

No checkpoints found

+

Add checkpoints to your models folders to see them here.

+
+ `; + return; + } + + // Render checkpoint cards + data.items.forEach(checkpoint => { + const card = createCheckpointCard(checkpoint); + grid.appendChild(card); + }); + } catch (error) { + console.error('Error loading checkpoints:', error); + showToast('Failed to load checkpoints', 'error'); + } finally { + // Clear loading state + const pageState = getCurrentPageState(); + pageState.isLoading = false; + document.body.classList.remove('loading'); + } +} + +// Reset and reload checkpoints +export async function resetAndReload() { + const pageState = getCurrentPageState(); + pageState.currentPage = 1; + pageState.hasMore = true; + await loadMoreCheckpoints(true); +} + +// Refresh checkpoints +export async function refreshCheckpoints() { + try { + showToast('Scanning for checkpoints...', 'info'); + const response = await fetch('/api/checkpoints/scan'); + + if (!response.ok) { + throw new Error(`Failed to scan checkpoints: ${response.status} ${response.statusText}`); + } + + await resetAndReload(); + showToast('Checkpoints refreshed successfully', 'success'); + } catch (error) { + console.error('Error refreshing checkpoints:', error); + showToast('Failed to refresh checkpoints', 'error'); + } +} + +// Delete a checkpoint +export function deleteCheckpoint(filePath) { + confirmDelete('Are you sure you want to delete this checkpoint?', () => { + _performDelete(filePath); + }); +} + +// Private function to perform the delete operation +async function _performDelete(filePath) { + try { + showToast('Deleting checkpoint...', 'info'); + + const response = await fetch('/api/model/delete', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + file_path: filePath, + model_type: 'checkpoint' + }) + }); + + if (!response.ok) { + throw new Error(`Failed to delete checkpoint: ${response.status} ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success) { + // Remove the card from UI + const card = 
document.querySelector(`.lora-card[data-filepath="${filePath}"]`); + if (card) { + card.remove(); + } + + showToast('Checkpoint deleted successfully', 'success'); + } else { + throw new Error(data.error || 'Failed to delete checkpoint'); + } + } catch (error) { + console.error('Error deleting checkpoint:', error); + showToast(`Failed to delete checkpoint: ${error.message}`, 'error'); + } +} + +// Replace checkpoint preview +export function replaceCheckpointPreview(filePath) { + // Open file picker + const input = document.createElement('input'); + input.type = 'file'; + input.accept = 'image/*'; + input.onchange = async (e) => { + if (!e.target.files.length) return; + + const file = e.target.files[0]; + await _uploadPreview(filePath, file); + }; + input.click(); +} + +// Upload a preview image +async function _uploadPreview(filePath, file) { + try { + showToast('Uploading preview...', 'info'); + + const formData = new FormData(); + formData.append('file', file); + formData.append('file_path', filePath); + formData.append('model_type', 'checkpoint'); + + const response = await fetch('/api/model/preview', { + method: 'POST', + body: formData + }); + + if (!response.ok) { + throw new Error(`Failed to upload preview: ${response.status} ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success) { + // Update the preview in UI + const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); + if (card) { + const img = card.querySelector('.card-preview img'); + if (img) { + // Add timestamp to prevent caching + const timestamp = new Date().getTime(); + if (data.preview_url) { + img.src = `${data.preview_url}?t=${timestamp}`; + } else { + img.src = `/api/model/preview_image?path=${encodeURIComponent(filePath)}&t=${timestamp}`; + } + } + } + + showToast('Preview updated successfully', 'success'); + } else { + throw new Error(data.error || 'Failed to update preview'); + } + } catch (error) { + console.error('Error updating 
preview:', error); + showToast(`Failed to update preview: ${error.message}`, 'error'); + } +} \ No newline at end of file diff --git a/static/js/checkpoints.js b/static/js/checkpoints.js index ea149a2f..8b563f8c 100644 --- a/static/js/checkpoints.js +++ b/static/js/checkpoints.js @@ -1,36 +1,128 @@ import { appCore } from './core.js'; -import { state, initPageState } from './state/index.js'; +import { state, getCurrentPageState } from './state/index.js'; +import { + loadMoreCheckpoints, + resetAndReload, + refreshCheckpoints, + deleteCheckpoint, + replaceCheckpointPreview +} from './api/checkpointApi.js'; +import { + restoreFolderFilter, + toggleFolder, + openCivitai, + showToast +} from './utils/uiHelpers.js'; +import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; +import { toggleApiKeyVisibility } from './managers/SettingsManager.js'; +import { initializeInfiniteScroll } from './utils/infiniteScroll.js'; +import { setStorageItem, getStorageItem } from './utils/storageHelpers.js'; // Initialize the Checkpoints page class CheckpointsPageManager { constructor() { - // Initialize any necessary state - this.initialized = false; + // Get page state + this.pageState = getCurrentPageState(); + + // Set default values + this.pageState.pageSize = 20; + this.pageState.isLoading = false; + this.pageState.hasMore = true; + + // Expose functions to window object + this._exposeGlobalFunctions(); + } + + _exposeGlobalFunctions() { + // API functions + window.loadCheckpoints = (reset = true) => this.loadCheckpoints(reset); + window.refreshCheckpoints = refreshCheckpoints; + window.deleteCheckpoint = deleteCheckpoint; + window.replaceCheckpointPreview = replaceCheckpointPreview; + + // UI helper functions + window.toggleFolder = toggleFolder; + window.openCivitai = openCivitai; + window.confirmDelete = confirmDelete; + window.closeDeleteModal = closeDeleteModal; + window.toggleApiKeyVisibility = toggleApiKeyVisibility; + + // Add reference to this manager + 
window.checkpointManager = this; } async initialize() { - if (this.initialized) return; + // Initialize event listeners + this._initEventListeners(); - // Initialize page state - initPageState('checkpoints'); + // Restore folder filters if available + restoreFolderFilter('checkpoints'); - // Initialize core application - await appCore.initialize(); + // Load sort preference + this._loadSortPreference(); - // Initialize page-specific components - this._initializeWorkInProgress(); + // Load initial checkpoints + await this.loadCheckpoints(); - this.initialized = true; + // Initialize infinite scroll + initializeInfiniteScroll('checkpoints'); + + // Initialize common page features + appCore.initializePageFeatures(); + + console.log('Checkpoints Manager initialized'); } - _initializeWorkInProgress() { - // Add any work-in-progress specific initialization here - console.log('Checkpoints Manager is under development'); + _initEventListeners() { + // Sort select handler + const sortSelect = document.getElementById('sortSelect'); + if (sortSelect) { + sortSelect.addEventListener('change', async (e) => { + this.pageState.sortBy = e.target.value; + this._saveSortPreference(e.target.value); + await resetAndReload(); + }); + } + + // Folder tags handler + document.querySelectorAll('.folder-tags .tag').forEach(tag => { + tag.addEventListener('click', toggleFolder); + }); + + // Refresh button handler + const refreshBtn = document.getElementById('refreshBtn'); + if (refreshBtn) { + refreshBtn.addEventListener('click', () => refreshCheckpoints()); + } + } + + _loadSortPreference() { + const savedSort = getStorageItem('checkpoints_sort'); + if (savedSort) { + this.pageState.sortBy = savedSort; + const sortSelect = document.getElementById('sortSelect'); + if (sortSelect) { + sortSelect.value = savedSort; + } + } + } + + _saveSortPreference(sortValue) { + setStorageItem('checkpoints_sort', sortValue); + } + + // Load checkpoints with optional pagination reset + async 
loadCheckpoints(resetPage = true) { + await loadMoreCheckpoints(resetPage); } } // Initialize everything when DOM is ready document.addEventListener('DOMContentLoaded', async () => { + // Initialize core application + await appCore.initialize(); + + // Initialize checkpoints page const checkpointsPage = new CheckpointsPageManager(); await checkpointsPage.initialize(); }); diff --git a/static/js/components/CheckpointCard.js b/static/js/components/CheckpointCard.js new file mode 100644 index 00000000..a9246b8e --- /dev/null +++ b/static/js/components/CheckpointCard.js @@ -0,0 +1,147 @@ +import { showToast } from '../utils/uiHelpers.js'; +import { state } from '../state/index.js'; +import { CheckpointModal } from './CheckpointModal.js'; + +// Create an instance of the modal +const checkpointModal = new CheckpointModal(); + +export function createCheckpointCard(checkpoint) { + const card = document.createElement('div'); + card.className = 'lora-card'; // Reuse the same class for styling + card.dataset.sha256 = checkpoint.sha256; + card.dataset.filepath = checkpoint.file_path; + card.dataset.name = checkpoint.model_name; + card.dataset.file_name = checkpoint.file_name; + card.dataset.folder = checkpoint.folder; + card.dataset.modified = checkpoint.modified; + card.dataset.file_size = checkpoint.file_size; + card.dataset.from_civitai = checkpoint.from_civitai; + card.dataset.base_model = checkpoint.base_model || 'Unknown'; + + // Store metadata if available + if (checkpoint.civitai) { + card.dataset.meta = JSON.stringify(checkpoint.civitai || {}); + } + + // Store tags if available + if (checkpoint.tags && Array.isArray(checkpoint.tags)) { + card.dataset.tags = JSON.stringify(checkpoint.tags); + } + + // Determine preview URL + const previewUrl = checkpoint.preview_url || '/loras_static/images/no-preview.png'; + const version = state.previewVersions ? state.previewVersions.get(checkpoint.file_path) : null; + const versionedPreviewUrl = version ? 
`${previewUrl}?t=${version}` : previewUrl; + + card.innerHTML = ` +
+ ${checkpoint.model_name} +
+ + ${checkpoint.base_model || 'Unknown'} + +
+ + + + +
+
+ +
+ `; + + // Main card click event + card.addEventListener('click', () => { + // Show checkpoint details modal + const checkpointMeta = { + sha256: card.dataset.sha256, + file_path: card.dataset.filepath, + model_name: card.dataset.name, + file_name: card.dataset.file_name, + folder: card.dataset.folder, + modified: card.dataset.modified, + file_size: parseInt(card.dataset.file_size || '0'), + from_civitai: card.dataset.from_civitai === 'true', + base_model: card.dataset.base_model, + preview_url: versionedPreviewUrl, + // Parse civitai metadata from the card's dataset + civitai: (() => { + try { + return JSON.parse(card.dataset.meta || '{}'); + } catch (e) { + console.error('Failed to parse civitai metadata:', e); + return {}; // Return empty object on error + } + })(), + tags: (() => { + try { + return JSON.parse(card.dataset.tags || '[]'); + } catch (e) { + console.error('Failed to parse tags:', e); + return []; // Return empty array on error + } + })() + }; + checkpointModal.showCheckpointDetails(checkpointMeta); + }); + + // Civitai button click event + if (checkpoint.from_civitai) { + card.querySelector('.fa-globe')?.addEventListener('click', e => { + e.stopPropagation(); + openCivitai(checkpoint.model_name); + }); + } + + // Delete button click event + card.querySelector('.fa-trash')?.addEventListener('click', e => { + e.stopPropagation(); + deleteCheckpoint(checkpoint.file_path); + }); + + // Replace preview button click event + card.querySelector('.fa-image')?.addEventListener('click', e => { + e.stopPropagation(); + replaceCheckpointPreview(checkpoint.file_path); + }); + + return card; +} + +// These functions will be implemented in checkpointApi.js +function openCivitai(modelName) { + if (window.openCivitai) { + window.openCivitai(modelName); + } else { + console.log('Opening Civitai for:', modelName); + } +} + +function deleteCheckpoint(filePath) { + if (window.deleteCheckpoint) { + window.deleteCheckpoint(filePath); + } else { + console.log('Delete 
checkpoint:', filePath); + } +} + +function replaceCheckpointPreview(filePath) { + if (window.replaceCheckpointPreview) { + window.replaceCheckpointPreview(filePath); + } else { + console.log('Replace checkpoint preview:', filePath); + } +} \ No newline at end of file diff --git a/static/js/components/CheckpointModal.js b/static/js/components/CheckpointModal.js new file mode 100644 index 00000000..54dd2d33 --- /dev/null +++ b/static/js/components/CheckpointModal.js @@ -0,0 +1,120 @@ +import { showToast } from '../utils/uiHelpers.js'; +import { modalManager } from '../managers/ModalManager.js'; + +/** + * CheckpointModal - Component for displaying checkpoint details + * This is a basic implementation that can be expanded in the future + */ +export class CheckpointModal { + constructor() { + this.modal = document.getElementById('checkpointModal'); + this.modalTitle = document.getElementById('checkpointModalTitle'); + this.modalContent = document.getElementById('checkpointModalContent'); + this.currentCheckpoint = null; + + // Initialize close events + this._initCloseEvents(); + } + + _initCloseEvents() { + if (!this.modal) return; + + // Close button + const closeBtn = this.modal.querySelector('.close'); + if (closeBtn) { + closeBtn.addEventListener('click', () => this.close()); + } + + // Click outside to close + this.modal.addEventListener('click', (e) => { + if (e.target === this.modal) { + this.close(); + } + }); + } + + /** + * Show checkpoint details in the modal + * @param {Object} checkpoint - Checkpoint data + */ + showCheckpointDetails(checkpoint) { + if (!this.modal || !this.modalContent) { + console.error('Checkpoint modal elements not found'); + return; + } + + this.currentCheckpoint = checkpoint; + + // Set modal title + if (this.modalTitle) { + this.modalTitle.textContent = checkpoint.model_name || 'Checkpoint Details'; + } + + // This is a basic implementation that can be expanded with more details + // For now, just display some basic information + 
this.modalContent.innerHTML = ` +
+
+ ${checkpoint.model_name} +
+
+

${checkpoint.model_name}

+
+
+ File Name: + ${checkpoint.file_name} +
+
+ Location: + ${checkpoint.folder} +
+
+ Base Model: + ${checkpoint.base_model || 'Unknown'} +
+
+ File Size: + ${this._formatFileSize(checkpoint.file_size)} +
+
+ SHA256: + ${checkpoint.sha256 || 'Unknown'} +
+
+
+
+
+

Detailed checkpoint information will be implemented in a future update.

+
+ `; + + // Show the modal + this.modal.style.display = 'block'; + } + + /** + * Close the modal + */ + close() { + if (this.modal) { + this.modal.style.display = 'none'; + this.currentCheckpoint = null; + } + } + + /** + * Format file size for display + * @param {number} bytes - File size in bytes + * @returns {string} - Formatted file size + */ + _formatFileSize(bytes) { + if (!bytes) return 'Unknown'; + + const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB']; + if (bytes === 0) return '0 Bytes'; + const i = Math.floor(Math.log(bytes) / Math.log(1024)); + if (i === 0) return `${bytes} ${sizes[i]}`; + return `${(bytes / Math.pow(1024, i)).toFixed(2)} ${sizes[i]}`; + } +} \ No newline at end of file diff --git a/static/js/managers/FilterManager.js b/static/js/managers/FilterManager.js index 96d31388..57cbbf35 100644 --- a/static/js/managers/FilterManager.js +++ b/static/js/managers/FilterManager.js @@ -70,6 +70,8 @@ export class FilterManager { let tagsEndpoint = '/api/loras/top-tags?limit=20'; if (this.currentPage === 'recipes') { tagsEndpoint = '/api/recipes/top-tags?limit=20'; + } else if (this.currentPage === 'checkpoints') { + tagsEndpoint = '/api/checkpoints/top-tags?limit=20'; } const response = await fetch(tagsEndpoint); @@ -143,7 +145,8 @@ export class FilterManager { apiEndpoint = '/api/loras/base-models'; } else if (this.currentPage === 'recipes') { apiEndpoint = '/api/recipes/base-models'; - } else { + } else if (this.currentPage === 'checkpoints') { + apiEndpoint = '/api/checkpoints/base-models'; return; // No API endpoint for other pages } diff --git a/static/js/managers/SearchManager.js b/static/js/managers/SearchManager.js index 4d909b36..49e6749b 100644 --- a/static/js/managers/SearchManager.js +++ b/static/js/managers/SearchManager.js @@ -302,6 +302,7 @@ export class SearchManager { pageState.searchOptions = { filename: options.filename || false, modelname: options.modelname || false, + tags: options.tags || false, recursive: recursive }; } diff --git 
a/templates/checkpoints.html b/templates/checkpoints.html index f4a4d6b8..4ff483df 100644 --- a/templates/checkpoints.html +++ b/templates/checkpoints.html @@ -8,18 +8,19 @@ {% endblock %} {% block init_title %}Initializing Checkpoints Manager{% endblock %} -{% block init_message %}Setting up checkpoints interface. This may take a few moments...{% endblock %} +{% block init_message %}Scanning and building checkpoints cache. This may take a few moments...{% endblock %} {% block init_check_url %}/api/checkpoints?page=1&page_size=1{% endblock %} +{% block additional_components %} +{% include 'components/checkpoint_modals.html' %} +{% endblock %} + {% block content %} -
-
- -

Checkpoints Manager

-

This feature is currently under development and will be available soon.

-

Please check back later for updates!

+ {% include 'components/controls.html' %} + +
+
-
{% endblock %} {% block main_script %} diff --git a/templates/components/checkpoint_modals.html b/templates/components/checkpoint_modals.html new file mode 100644 index 00000000..af971475 --- /dev/null +++ b/templates/components/checkpoint_modals.html @@ -0,0 +1,35 @@ + + + + + + \ No newline at end of file diff --git a/templates/components/header.html b/templates/components/header.html index f9ddc624..51a7ba06 100644 --- a/templates/components/header.html +++ b/templates/components/header.html @@ -74,6 +74,7 @@ {% elif request.path == '/checkpoints' %}
Filename
Checkpoint Name
+
Tags
{% else %}
Filename
From ee04df40c34d38dc74d73fa71e2a370dc83e232d Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 10 Apr 2025 19:41:02 +0800 Subject: [PATCH 04/36] Refactor controls and pagination for Checkpoints and LoRAs: Implement unified PageControls, enhance API integration, and improve event handling for better user experience. --- static/js/api/checkpointApi.js | 5 + static/js/api/loraApi.js | 16 +- static/js/checkpoints.js | 103 +---- .../controls/CheckpointsControls.js | 46 +++ .../js/components/controls/LorasControls.js | 147 +++++++ static/js/components/controls/PageControls.js | 391 ++++++++++++++++++ static/js/components/controls/index.js | 23 ++ static/js/loras.js | 164 +------- static/js/recipes.js | 5 + static/js/utils/infiniteScroll.js | 13 +- templates/components/controls.html | 18 +- 11 files changed, 667 insertions(+), 264 deletions(-) create mode 100644 static/js/components/controls/CheckpointsControls.js create mode 100644 static/js/components/controls/LorasControls.js create mode 100644 static/js/components/controls/PageControls.js create mode 100644 static/js/components/controls/index.js diff --git a/static/js/api/checkpointApi.js b/static/js/api/checkpointApi.js index cbf1ca0b..1dd0c8fe 100644 --- a/static/js/api/checkpointApi.js +++ b/static/js/api/checkpointApi.js @@ -101,6 +101,11 @@ export async function loadMoreCheckpoints(resetPagination = true) { const card = createCheckpointCard(checkpoint); grid.appendChild(card); }); + + // Increment the page number AFTER successful loading + if (data.items.length > 0) { + pageState.currentPage++; + } } catch (error) { console.error('Error loading checkpoints:', error); showToast('Failed to load checkpoints', 'error'); diff --git a/static/js/api/loraApi.js b/static/js/api/loraApi.js index 8ebee93c..c344e930 100644 --- a/static/js/api/loraApi.js +++ b/static/js/api/loraApi.js @@ -19,7 +19,6 @@ export async function loadMoreLoras(resetPage = false, updateFolders = false) { // Clear grid if 
resetting const grid = document.getElementById('loraGrid'); if (grid) grid.innerHTML = ''; - initializeInfiniteScroll(); } const params = new URLSearchParams({ @@ -62,9 +61,6 @@ export async function loadMoreLoras(resetPage = false, updateFolders = false) { const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash'); const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes'); - console.log('Filter Lora Hash:', filterLoraHash); - console.log('Filter Lora Hashes:', filterLoraHashes); - // Add hash filter parameter if present if (filterLoraHash) { params.append('lora_hash', filterLoraHash); @@ -93,13 +89,10 @@ export async function loadMoreLoras(resetPage = false, updateFolders = false) { pageState.hasMore = false; } else if (data.items.length > 0) { pageState.hasMore = pageState.currentPage < data.total_pages; - pageState.currentPage++; appendLoraCards(data.items); - const sentinel = document.getElementById('scroll-sentinel'); - if (sentinel && state.observer) { - state.observer.observe(sentinel); - } + // Increment the page number AFTER successful loading + pageState.currentPage++; } else { pageState.hasMore = false; } @@ -303,10 +296,7 @@ export async function resetAndReload(updateFolders = false) { const pageState = getCurrentPageState(); console.log('Resetting with state:', { ...pageState }); - // Initialize infinite scroll - will reset the observer - initializeInfiniteScroll(); - - // Load more loras with reset flag + // Reset pagination and load more loras await loadMoreLoras(true, updateFolders); } diff --git a/static/js/checkpoints.js b/static/js/checkpoints.js index 8b563f8c..fc09463d 100644 --- a/static/js/checkpoints.js +++ b/static/js/checkpoints.js @@ -1,68 +1,30 @@ import { appCore } from './core.js'; -import { state, getCurrentPageState } from './state/index.js'; -import { - loadMoreCheckpoints, - resetAndReload, - refreshCheckpoints, - deleteCheckpoint, - replaceCheckpointPreview -} from './api/checkpointApi.js'; -import { - 
restoreFolderFilter, - toggleFolder, - openCivitai, - showToast -} from './utils/uiHelpers.js'; -import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; import { toggleApiKeyVisibility } from './managers/SettingsManager.js'; import { initializeInfiniteScroll } from './utils/infiniteScroll.js'; -import { setStorageItem, getStorageItem } from './utils/storageHelpers.js'; +import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; +import { createPageControls } from './components/controls/index.js'; // Initialize the Checkpoints page class CheckpointsPageManager { constructor() { - // Get page state - this.pageState = getCurrentPageState(); + // Initialize page controls + this.pageControls = createPageControls('checkpoints'); - // Set default values - this.pageState.pageSize = 20; - this.pageState.isLoading = false; - this.pageState.hasMore = true; - - // Expose functions to window object - this._exposeGlobalFunctions(); + // Expose only necessary functions to global scope + this._exposeRequiredGlobalFunctions(); } - _exposeGlobalFunctions() { - // API functions - window.loadCheckpoints = (reset = true) => this.loadCheckpoints(reset); - window.refreshCheckpoints = refreshCheckpoints; - window.deleteCheckpoint = deleteCheckpoint; - window.replaceCheckpointPreview = replaceCheckpointPreview; - - // UI helper functions - window.toggleFolder = toggleFolder; - window.openCivitai = openCivitai; + _exposeRequiredGlobalFunctions() { + // Minimal set of functions that need to remain global window.confirmDelete = confirmDelete; window.closeDeleteModal = closeDeleteModal; window.toggleApiKeyVisibility = toggleApiKeyVisibility; - - // Add reference to this manager - window.checkpointManager = this; } async initialize() { - // Initialize event listeners - this._initEventListeners(); - - // Restore folder filters if available - restoreFolderFilter('checkpoints'); - - // Load sort preference - this._loadSortPreference(); - - // Load initial checkpoints 
- await this.loadCheckpoints(); + // Initialize page-specific components + this.pageControls.restoreFolderFilter(); + this.pageControls.initFolderTagsVisibility(); // Initialize infinite scroll initializeInfiniteScroll('checkpoints'); @@ -72,49 +34,6 @@ class CheckpointsPageManager { console.log('Checkpoints Manager initialized'); } - - _initEventListeners() { - // Sort select handler - const sortSelect = document.getElementById('sortSelect'); - if (sortSelect) { - sortSelect.addEventListener('change', async (e) => { - this.pageState.sortBy = e.target.value; - this._saveSortPreference(e.target.value); - await resetAndReload(); - }); - } - - // Folder tags handler - document.querySelectorAll('.folder-tags .tag').forEach(tag => { - tag.addEventListener('click', toggleFolder); - }); - - // Refresh button handler - const refreshBtn = document.getElementById('refreshBtn'); - if (refreshBtn) { - refreshBtn.addEventListener('click', () => refreshCheckpoints()); - } - } - - _loadSortPreference() { - const savedSort = getStorageItem('checkpoints_sort'); - if (savedSort) { - this.pageState.sortBy = savedSort; - const sortSelect = document.getElementById('sortSelect'); - if (sortSelect) { - sortSelect.value = savedSort; - } - } - } - - _saveSortPreference(sortValue) { - setStorageItem('checkpoints_sort', sortValue); - } - - // Load checkpoints with optional pagination reset - async loadCheckpoints(resetPage = true) { - await loadMoreCheckpoints(resetPage); - } } // Initialize everything when DOM is ready diff --git a/static/js/components/controls/CheckpointsControls.js b/static/js/components/controls/CheckpointsControls.js new file mode 100644 index 00000000..2dd968e2 --- /dev/null +++ b/static/js/components/controls/CheckpointsControls.js @@ -0,0 +1,46 @@ +// CheckpointsControls.js - Specific implementation for the Checkpoints page +import { PageControls } from './PageControls.js'; +import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints } from 
'../../api/checkpointApi.js'; +import { showToast } from '../../utils/uiHelpers.js'; + +/** + * CheckpointsControls class - Extends PageControls for Checkpoint-specific functionality + */ +export class CheckpointsControls extends PageControls { + constructor() { + // Initialize with 'checkpoints' page type + super('checkpoints'); + + // Register API methods specific to the Checkpoints page + this.registerCheckpointsAPI(); + } + + /** + * Register Checkpoint-specific API methods + */ + registerCheckpointsAPI() { + const checkpointsAPI = { + // Core API functions + loadMoreModels: async (resetPage = false, updateFolders = false) => { + return await loadMoreCheckpoints(resetPage, updateFolders); + }, + + resetAndReload: async (updateFolders = false) => { + return await resetAndReload(updateFolders); + }, + + refreshModels: async () => { + return await refreshCheckpoints(); + }, + + // No clearCustomFilter implementation is needed for checkpoints + // as custom filters are currently only used for LoRAs + clearCustomFilter: async () => { + showToast('No custom filter to clear', 'info'); + } + }; + + // Register the API + this.registerAPI(checkpointsAPI); + } +} \ No newline at end of file diff --git a/static/js/components/controls/LorasControls.js b/static/js/components/controls/LorasControls.js new file mode 100644 index 00000000..4d0cc9eb --- /dev/null +++ b/static/js/components/controls/LorasControls.js @@ -0,0 +1,147 @@ +// LorasControls.js - Specific implementation for the LoRAs page +import { PageControls } from './PageControls.js'; +import { loadMoreLoras, fetchCivitai, resetAndReload, refreshLoras } from '../../api/loraApi.js'; +import { getSessionItem, removeSessionItem } from '../../utils/storageHelpers.js'; +import { showToast } from '../../utils/uiHelpers.js'; + +/** + * LorasControls class - Extends PageControls for LoRA-specific functionality + */ +export class LorasControls extends PageControls { + constructor() { + // Initialize with 'loras' page type + 
super('loras'); + + // Register API methods specific to the LoRAs page + this.registerLorasAPI(); + + // Check for custom filters (e.g., from recipe navigation) + this.checkCustomFilters(); + } + + /** + * Register LoRA-specific API methods + */ + registerLorasAPI() { + const lorasAPI = { + // Core API functions + loadMoreModels: async (resetPage = false, updateFolders = false) => { + return await loadMoreLoras(resetPage, updateFolders); + }, + + resetAndReload: async (updateFolders = false) => { + return await resetAndReload(updateFolders); + }, + + refreshModels: async () => { + return await refreshLoras(); + }, + + // LoRA-specific API functions + fetchFromCivitai: async () => { + return await fetchCivitai(); + }, + + showDownloadModal: () => { + if (window.downloadManager) { + window.downloadManager.showDownloadModal(); + } else { + console.error('Download manager not available'); + } + }, + + toggleBulkMode: () => { + if (window.bulkManager) { + window.bulkManager.toggleBulkMode(); + } else { + console.error('Bulk manager not available'); + } + }, + + clearCustomFilter: async () => { + await this.clearCustomFilter(); + } + }; + + // Register the API + this.registerAPI(lorasAPI); + } + + /** + * Check for custom filter parameters in session storage (e.g., from recipe page navigation) + */ + checkCustomFilters() { + const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash'); + const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes'); + const filterRecipeName = getSessionItem('filterRecipeName'); + const viewLoraDetail = getSessionItem('viewLoraDetail'); + + if ((filterLoraHash || filterLoraHashes) && filterRecipeName) { + // Found custom filter parameters, set up the custom filter + + // Show the filter indicator + const indicator = document.getElementById('customFilterIndicator'); + const filterText = indicator?.querySelector('.customFilterText'); + + if (indicator && filterText) { + indicator.classList.remove('hidden'); + + // Set 
text content with recipe name + const filterType = filterLoraHash && viewLoraDetail ? "Viewing LoRA from" : "Viewing LoRAs from"; + const displayText = `${filterType}: ${filterRecipeName}`; + + filterText.textContent = this._truncateText(displayText, 30); + filterText.setAttribute('title', displayText); + + // Add pulse animation + const filterElement = indicator.querySelector('.filter-active'); + if (filterElement) { + filterElement.classList.add('animate'); + setTimeout(() => filterElement.classList.remove('animate'), 600); + } + } + + // If we're viewing a specific LoRA detail, set up to open the modal + if (filterLoraHash && viewLoraDetail) { + this.pageState.pendingLoraHash = filterLoraHash; + } + } + } + + /** + * Clear the custom filter and reload the page + */ + async clearCustomFilter() { + console.log("Clearing custom filter..."); + // Remove filter parameters from session storage + removeSessionItem('recipe_to_lora_filterLoraHash'); + removeSessionItem('recipe_to_lora_filterLoraHashes'); + removeSessionItem('filterRecipeName'); + removeSessionItem('viewLoraDetail'); + + // Hide the filter indicator + const indicator = document.getElementById('customFilterIndicator'); + if (indicator) { + indicator.classList.add('hidden'); + } + + // Reset state + if (this.pageState.pendingLoraHash) { + delete this.pageState.pendingLoraHash; + } + + // Reload the loras + await resetAndReload(); + showToast('Filter cleared', 'info'); + } + + /** + * Helper to truncate text with ellipsis + * @param {string} text - Text to truncate + * @param {number} maxLength - Maximum length before truncating + * @returns {string} - Truncated text + */ + _truncateText(text, maxLength) { + return text.length > maxLength ? text.substring(0, maxLength - 3) + '...' 
: text; + } +} \ No newline at end of file diff --git a/static/js/components/controls/PageControls.js b/static/js/components/controls/PageControls.js new file mode 100644 index 00000000..43498f12 --- /dev/null +++ b/static/js/components/controls/PageControls.js @@ -0,0 +1,391 @@ +// PageControls.js - Manages controls for both LoRAs and Checkpoints pages +import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js'; +import { getStorageItem, setStorageItem } from '../../utils/storageHelpers.js'; +import { showToast } from '../../utils/uiHelpers.js'; + +/** + * PageControls class - Unified control management for model pages + */ +export class PageControls { + constructor(pageType) { + // Set the current page type in state + setCurrentPageType(pageType); + + // Store the page type + this.pageType = pageType; + + // Get the current page state + this.pageState = getCurrentPageState(); + + // Initialize state based on page type + this.initializeState(); + + // Store API methods + this.api = null; + + // Initialize event listeners + this.initEventListeners(); + + console.log(`PageControls initialized for ${pageType} page`); + } + + /** + * Initialize state based on page type + */ + initializeState() { + // Set default values + this.pageState.pageSize = 20; + this.pageState.isLoading = false; + this.pageState.hasMore = true; + + // Load sort preference + this.loadSortPreference(); + } + + /** + * Register API methods for the page + * @param {Object} api - API methods for the page + */ + registerAPI(api) { + this.api = api; + console.log(`API methods registered for ${this.pageType} page`); + } + + /** + * Initialize event listeners for controls + */ + initEventListeners() { + // Sort select handler + const sortSelect = document.getElementById('sortSelect'); + if (sortSelect) { + sortSelect.value = this.pageState.sortBy; + sortSelect.addEventListener('change', async (e) => { + this.pageState.sortBy = e.target.value; + 
this.saveSortPreference(e.target.value); + await this.resetAndReload(); + }); + } + + // Folder tags handler + document.querySelectorAll('.folder-tags .tag').forEach(tag => { + tag.addEventListener('click', (e) => this.handleFolderClick(e.currentTarget)); + }); + + // Refresh button handler + const refreshBtn = document.querySelector('[data-action="refresh"]'); + if (refreshBtn) { + refreshBtn.addEventListener('click', () => this.refreshModels()); + } + + // Toggle folders button + const toggleFoldersBtn = document.querySelector('.toggle-folders-btn'); + if (toggleFoldersBtn) { + toggleFoldersBtn.addEventListener('click', () => this.toggleFolderTags()); + } + + // Clear custom filter handler + const clearFilterBtn = document.querySelector('.clear-filter'); + if (clearFilterBtn) { + clearFilterBtn.addEventListener('click', () => this.clearCustomFilter()); + } + + // Page-specific event listeners + this.initPageSpecificListeners(); + } + + /** + * Initialize page-specific event listeners + */ + initPageSpecificListeners() { + if (this.pageType === 'loras') { + // Fetch from Civitai button + const fetchButton = document.querySelector('[data-action="fetch"]'); + if (fetchButton) { + fetchButton.addEventListener('click', () => this.fetchFromCivitai()); + } + + // Download button + const downloadButton = document.querySelector('[data-action="download"]'); + if (downloadButton) { + downloadButton.addEventListener('click', () => this.showDownloadModal()); + } + + // Bulk operations button + const bulkButton = document.querySelector('[data-action="bulk"]'); + if (bulkButton) { + bulkButton.addEventListener('click', () => this.toggleBulkMode()); + } + } + } + + /** + * Toggle folder selection + * @param {HTMLElement} tagElement - The folder tag element that was clicked + */ + handleFolderClick(tagElement) { + const folder = tagElement.dataset.folder; + const wasActive = tagElement.classList.contains('active'); + + document.querySelectorAll('.folder-tags .tag').forEach(t => { 
+ t.classList.remove('active'); + }); + + if (!wasActive) { + tagElement.classList.add('active'); + this.pageState.activeFolder = folder; + setStorageItem(`${this.pageType}_activeFolder`, folder); + } else { + this.pageState.activeFolder = null; + setStorageItem(`${this.pageType}_activeFolder`, null); + } + + this.resetAndReload(); + } + + /** + * Restore folder filter from storage + */ + restoreFolderFilter() { + const activeFolder = getStorageItem(`${this.pageType}_activeFolder`); + const folderTag = activeFolder && document.querySelector(`.tag[data-folder="${activeFolder}"]`); + + if (folderTag) { + folderTag.classList.add('active'); + this.pageState.activeFolder = activeFolder; + this.filterByFolder(activeFolder); + } + } + + /** + * Filter displayed cards by folder + * @param {string} folderPath - Folder path to filter by + */ + filterByFolder(folderPath) { + const cardSelector = this.pageType === 'loras' ? '.lora-card' : '.checkpoint-card'; + document.querySelectorAll(cardSelector).forEach(card => { + card.style.display = card.dataset.folder === folderPath ? '' : 'none'; + }); + } + + /** + * Update the folder tags display with new folder list + * @param {Array} folders - List of folder names + */ + updateFolderTags(folders) { + const folderTagsContainer = document.querySelector('.folder-tags'); + if (!folderTagsContainer) return; + + // Keep track of currently selected folder + const currentFolder = this.pageState.activeFolder; + + // Create HTML for folder tags + const tagsHTML = folders.map(folder => { + const isActive = folder === currentFolder; + return `
${folder}
`; + }).join(''); + + // Update the container + folderTagsContainer.innerHTML = tagsHTML; + + // Reattach click handlers + const tags = folderTagsContainer.querySelectorAll('.tag'); + tags.forEach(tag => { + tag.addEventListener('click', (e) => this.handleFolderClick(e.currentTarget)); + if (tag.dataset.folder === currentFolder) { + tag.scrollIntoView({ behavior: 'smooth', block: 'nearest' }); + } + }); + } + + /** + * Toggle visibility of folder tags + */ + toggleFolderTags() { + const folderTags = document.querySelector('.folder-tags'); + const toggleBtn = document.querySelector('.toggle-folders-btn i'); + + if (folderTags) { + folderTags.classList.toggle('collapsed'); + + if (folderTags.classList.contains('collapsed')) { + // Change icon to indicate folders are hidden + toggleBtn.className = 'fas fa-folder-plus'; + toggleBtn.parentElement.title = 'Show folder tags'; + setStorageItem('folderTagsCollapsed', 'true'); + } else { + // Change icon to indicate folders are visible + toggleBtn.className = 'fas fa-folder-minus'; + toggleBtn.parentElement.title = 'Hide folder tags'; + setStorageItem('folderTagsCollapsed', 'false'); + } + } + } + + /** + * Initialize folder tags visibility based on stored preference + */ + initFolderTagsVisibility() { + const isCollapsed = getStorageItem('folderTagsCollapsed'); + if (isCollapsed) { + const folderTags = document.querySelector('.folder-tags'); + const toggleBtn = document.querySelector('.toggle-folders-btn i'); + if (folderTags) { + folderTags.classList.add('collapsed'); + } + if (toggleBtn) { + toggleBtn.className = 'fas fa-folder-plus'; + toggleBtn.parentElement.title = 'Show folder tags'; + } + } else { + const toggleBtn = document.querySelector('.toggle-folders-btn i'); + if (toggleBtn) { + toggleBtn.className = 'fas fa-folder-minus'; + toggleBtn.parentElement.title = 'Hide folder tags'; + } + } + } + + /** + * Load sort preference from storage + */ + loadSortPreference() { + const savedSort = 
getStorageItem(`${this.pageType}_sort`); + if (savedSort) { + this.pageState.sortBy = savedSort; + const sortSelect = document.getElementById('sortSelect'); + if (sortSelect) { + sortSelect.value = savedSort; + } + } + } + + /** + * Save sort preference to storage + * @param {string} sortValue - The sort value to save + */ + saveSortPreference(sortValue) { + setStorageItem(`${this.pageType}_sort`, sortValue); + } + + /** + * Open model page on Civitai + * @param {string} modelName - Name of the model + */ + openCivitai(modelName) { + // Get card selector based on page type + const cardSelector = this.pageType === 'loras' + ? `.lora-card[data-name="${modelName}"]` + : `.checkpoint-card[data-name="${modelName}"]`; + + const card = document.querySelector(cardSelector); + if (!card) return; + + const metaData = JSON.parse(card.dataset.meta); + const civitaiId = metaData.modelId; + const versionId = metaData.id; + + // Build URL + if (civitaiId) { + let url = `https://civitai.com/models/${civitaiId}`; + if (versionId) { + url += `?modelVersionId=${versionId}`; + } + window.open(url, '_blank'); + } else { + // If no ID, try searching by name + window.open(`https://civitai.com/models?query=${encodeURIComponent(modelName)}`, '_blank'); + } + } + + /** + * Reset and reload the models list + */ + async resetAndReload(updateFolders = false) { + if (!this.api) { + console.error('API methods not registered'); + return; + } + + try { + await this.api.resetAndReload(updateFolders); + } catch (error) { + console.error(`Error reloading ${this.pageType}:`, error); + showToast(`Failed to reload ${this.pageType}: ${error.message}`, 'error'); + } + } + + /** + * Refresh models list + */ + async refreshModels() { + if (!this.api) { + console.error('API methods not registered'); + return; + } + + try { + await this.api.refreshModels(); + } catch (error) { + console.error(`Error refreshing ${this.pageType}:`, error); + showToast(`Failed to refresh ${this.pageType}: ${error.message}`, 
'error'); + } + } + + /** + * Fetch metadata from Civitai (LoRAs only) + */ + async fetchFromCivitai() { + if (this.pageType !== 'loras' || !this.api) { + console.error('Fetch from Civitai is only available for LoRAs'); + return; + } + + try { + await this.api.fetchFromCivitai(); + } catch (error) { + console.error('Error fetching metadata:', error); + showToast('Failed to fetch metadata: ' + error.message, 'error'); + } + } + + /** + * Show download modal (LoRAs only) + */ + showDownloadModal() { + if (this.pageType !== 'loras' || !this.api) { + console.error('Download modal is only available for LoRAs'); + return; + } + + this.api.showDownloadModal(); + } + + /** + * Toggle bulk mode (LoRAs only) + */ + toggleBulkMode() { + if (this.pageType !== 'loras' || !this.api) { + console.error('Bulk mode is only available for LoRAs'); + return; + } + + this.api.toggleBulkMode(); + } + + /** + * Clear custom filter + */ + async clearCustomFilter() { + if (!this.api) { + console.error('API methods not registered'); + return; + } + + try { + await this.api.clearCustomFilter(); + } catch (error) { + console.error('Error clearing custom filter:', error); + showToast('Failed to clear custom filter: ' + error.message, 'error'); + } + } +} \ No newline at end of file diff --git a/static/js/components/controls/index.js b/static/js/components/controls/index.js new file mode 100644 index 00000000..c767c62f --- /dev/null +++ b/static/js/components/controls/index.js @@ -0,0 +1,23 @@ +// Controls components index file +import { PageControls } from './PageControls.js'; +import { LorasControls } from './LorasControls.js'; +import { CheckpointsControls } from './CheckpointsControls.js'; + +// Export the classes +export { PageControls, LorasControls, CheckpointsControls }; + +/** + * Factory function to create the appropriate controls based on page type + * @param {string} pageType - The type of page ('loras' or 'checkpoints') + * @returns {PageControls} - The appropriate controls instance 
+ */ +export function createPageControls(pageType) { + if (pageType === 'loras') { + return new LorasControls(); + } else if (pageType === 'checkpoints') { + return new CheckpointsControls(); + } else { + console.error(`Unknown page type: ${pageType}`); + return null; + } +} \ No newline at end of file diff --git a/static/js/loras.js b/static/js/loras.js index d9750786..03c4dfa3 100644 --- a/static/js/loras.js +++ b/static/js/loras.js @@ -1,23 +1,14 @@ import { appCore } from './core.js'; import { state } from './state/index.js'; import { showLoraModal, toggleShowcase, scrollToTop } from './components/loraModal/index.js'; -import { loadMoreLoras, fetchCivitai, deleteModel, replacePreview, resetAndReload, refreshLoras } from './api/loraApi.js'; -import { - restoreFolderFilter, - toggleFolder, - copyTriggerWord, - openCivitai, - toggleFolderTags, - initFolderTagsVisibility, -} from './utils/uiHelpers.js'; -import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; -import { DownloadManager } from './managers/DownloadManager.js'; -import { toggleApiKeyVisibility } from './managers/SettingsManager.js'; -import { LoraContextMenu } from './components/ContextMenu.js'; -import { moveManager } from './managers/MoveManager.js'; import { updateCardsForBulkMode } from './components/LoraCard.js'; import { bulkManager } from './managers/BulkManager.js'; -import { setStorageItem, getStorageItem, getSessionItem, removeSessionItem } from './utils/storageHelpers.js'; +import { DownloadManager } from './managers/DownloadManager.js'; +import { toggleApiKeyVisibility } from './managers/SettingsManager.js'; +import { moveManager } from './managers/MoveManager.js'; +import { LoraContextMenu } from './components/ContextMenu.js'; +import { createPageControls } from './components/controls/index.js'; +import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; // Initialize the LoRA page class LoraPageManager { @@ -29,24 +20,20 @@ class LoraPageManager { // 
Initialize managers this.downloadManager = new DownloadManager(); - // Expose necessary functions to the page - this._exposeGlobalFunctions(); + // Initialize page controls + this.pageControls = createPageControls('loras'); + + // Expose necessary functions to the page that still need global access + // These will be refactored in future updates + this._exposeRequiredGlobalFunctions(); } - _exposeGlobalFunctions() { - // Only expose what's needed for the page - window.loadMoreLoras = loadMoreLoras; - window.fetchCivitai = fetchCivitai; - window.deleteModel = deleteModel; - window.replacePreview = replacePreview; - window.toggleFolder = toggleFolder; - window.copyTriggerWord = copyTriggerWord; + _exposeRequiredGlobalFunctions() { + // Only expose what's still needed globally + // Most functionality is now handled by the PageControls component window.showLoraModal = showLoraModal; window.confirmDelete = confirmDelete; window.closeDeleteModal = closeDeleteModal; - window.refreshLoras = refreshLoras; - window.openCivitai = openCivitai; - window.toggleFolderTags = toggleFolderTags; window.toggleApiKeyVisibility = toggleApiKeyVisibility; window.downloadManager = this.downloadManager; window.moveManager = moveManager; @@ -64,14 +51,10 @@ class LoraPageManager { async initialize() { // Initialize page-specific components - this.initEventListeners(); - restoreFolderFilter(); - initFolderTagsVisibility(); + this.pageControls.restoreFolderFilter(); + this.pageControls.initFolderTagsVisibility(); new LoraContextMenu(); - // Check for custom filters from recipe page navigation - this.checkCustomFilters(); - // Initialize cards for current bulk mode state (should be false initially) updateCardsForBulkMode(state.bulkMode); @@ -81,119 +64,6 @@ class LoraPageManager { // Initialize common page features (lazy loading, infinite scroll) appCore.initializePageFeatures(); } - - // Check for custom filter parameters in session storage - checkCustomFilters() { - const filterLoraHash = 
getSessionItem('recipe_to_lora_filterLoraHash'); - const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes'); - const filterRecipeName = getSessionItem('filterRecipeName'); - const viewLoraDetail = getSessionItem('viewLoraDetail'); - - console.log("Checking custom filters..."); - console.log("filterLoraHash:", filterLoraHash); - console.log("filterLoraHashes:", filterLoraHashes); - console.log("filterRecipeName:", filterRecipeName); - console.log("viewLoraDetail:", viewLoraDetail); - - if ((filterLoraHash || filterLoraHashes) && filterRecipeName) { - // Found custom filter parameters, set up the custom filter - - // Show the filter indicator - const indicator = document.getElementById('customFilterIndicator'); - const filterText = indicator.querySelector('.customFilterText'); - - if (indicator && filterText) { - indicator.classList.remove('hidden'); - - // Set text content with recipe name - const filterType = filterLoraHash && viewLoraDetail ? "Viewing LoRA from" : "Viewing LoRAs from"; - const displayText = `${filterType}: ${filterRecipeName}`; - - filterText.textContent = this._truncateText(displayText, 30); - filterText.setAttribute('title', displayText); - - // Add click handler for the clear button - const clearBtn = indicator.querySelector('.clear-filter'); - if (clearBtn) { - clearBtn.addEventListener('click', this.clearCustomFilter); - } - - // Add pulse animation - const filterElement = indicator.querySelector('.filter-active'); - if (filterElement) { - filterElement.classList.add('animate'); - setTimeout(() => filterElement.classList.remove('animate'), 600); - } - } - - // If we're viewing a specific LoRA detail, set up to open the modal - if (filterLoraHash && viewLoraDetail) { - // Store this to fetch after initial load completes - state.pendingLoraHash = filterLoraHash; - } - } - } - - // Helper to truncate text with ellipsis - _truncateText(text, maxLength) { - return text.length > maxLength ? text.substring(0, maxLength - 3) + '...' 
: text; - } - - // Clear the custom filter and reload the page - clearCustomFilter = async () => { - console.log("Clearing custom filter..."); - // Remove filter parameters from session storage - removeSessionItem('recipe_to_lora_filterLoraHash'); - removeSessionItem('recipe_to_lora_filterLoraHashes'); - removeSessionItem('filterRecipeName'); - removeSessionItem('viewLoraDetail'); - - // Hide the filter indicator - const indicator = document.getElementById('customFilterIndicator'); - if (indicator) { - indicator.classList.add('hidden'); - } - - // Reset state - if (state.pendingLoraHash) { - delete state.pendingLoraHash; - } - - // Reload the loras - await resetAndReload(); - } - - loadSortPreference() { - const savedSort = getStorageItem('loras_sort'); - if (savedSort) { - state.sortBy = savedSort; - const sortSelect = document.getElementById('sortSelect'); - if (sortSelect) { - sortSelect.value = savedSort; - } - } - } - - saveSortPreference(sortValue) { - setStorageItem('loras_sort', sortValue); - } - - initEventListeners() { - const sortSelect = document.getElementById('sortSelect'); - if (sortSelect) { - sortSelect.value = state.sortBy; - this.loadSortPreference(); - sortSelect.addEventListener('change', async (e) => { - state.sortBy = e.target.value; - this.saveSortPreference(e.target.value); - await resetAndReload(); - }); - } - - document.querySelectorAll('.folder-tags .tag').forEach(tag => { - tag.addEventListener('click', toggleFolder); - }); - } } // Initialize everything when DOM is ready diff --git a/static/js/recipes.js b/static/js/recipes.js index ba55e62c..3c190e56 100644 --- a/static/js/recipes.js +++ b/static/js/recipes.js @@ -251,6 +251,11 @@ class RecipeManager { // Update pagination state based on current page and total pages this.pageState.hasMore = data.page < data.total_pages; + // Increment the page number AFTER successful loading + if (data.items.length > 0) { + this.pageState.currentPage++; + } + } catch (error) { console.error('Error 
loading recipes:', error); appCore.showToast('Failed to load recipes', 'error'); diff --git a/static/js/utils/infiniteScroll.js b/static/js/utils/infiniteScroll.js index ddb0e341..60795e14 100644 --- a/static/js/utils/infiniteScroll.js +++ b/static/js/utils/infiniteScroll.js @@ -1,5 +1,6 @@ import { state, getCurrentPageState } from '../state/index.js'; import { loadMoreLoras } from '../api/loraApi.js'; +import { loadMoreCheckpoints } from '../api/checkpointApi.js'; import { debounce } from './debounce.js'; export function initializeInfiniteScroll(pageType = 'loras') { @@ -21,7 +22,6 @@ export function initializeInfiniteScroll(pageType = 'loras') { case 'recipes': loadMoreFunction = () => { if (!pageState.isLoading && pageState.hasMore) { - pageState.currentPage++; window.recipeManager.loadRecipes(false); // false to not reset pagination } }; @@ -30,15 +30,18 @@ export function initializeInfiniteScroll(pageType = 'loras') { case 'checkpoints': loadMoreFunction = () => { if (!pageState.isLoading && pageState.hasMore) { - pageState.currentPage++; - window.checkpointManager.loadCheckpoints(false); // false to not reset pagination + loadMoreCheckpoints(false); // false to not reset } }; gridId = 'checkpointGrid'; break; case 'loras': default: - loadMoreFunction = () => loadMoreLoras(false); // false to not reset + loadMoreFunction = () => { + if (!pageState.isLoading && pageState.hasMore) { + loadMoreLoras(false); // false to not reset + } + }; gridId = 'loraGrid'; break; } @@ -85,4 +88,4 @@ export function initializeInfiniteScroll(pageType = 'loras') { state.observer.observe(sentinel); } -} \ No newline at end of file +} \ No newline at end of file diff --git a/templates/components/controls.html b/templates/components/controls.html index 45acc7a4..56921cd5 100644 --- a/templates/components/controls.html +++ b/templates/components/controls.html @@ -16,21 +16,25 @@
- + +
+ +
+
- -
-
-
+ + {% if request.path == '/loras' %}
-
+ {% endif %}
-
From 152ec0da0d9d7f16a361e7295dae3bf70c94ae68 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 10 Apr 2025 19:57:04 +0800 Subject: [PATCH 05/36] Refactor Checkpoints functionality: Integrate loadMoreCheckpoints API, remove CheckpointSearchManager, and enhance FilterManager for improved checkpoint loading and filtering. --- static/js/checkpoints.js | 6 + static/js/managers/CheckpointSearchManager.js | 150 ------------------ static/js/managers/FilterManager.js | 14 +- 3 files changed, 13 insertions(+), 157 deletions(-) delete mode 100644 static/js/managers/CheckpointSearchManager.js diff --git a/static/js/checkpoints.js b/static/js/checkpoints.js index fc09463d..3daa4b5c 100644 --- a/static/js/checkpoints.js +++ b/static/js/checkpoints.js @@ -3,6 +3,7 @@ import { toggleApiKeyVisibility } from './managers/SettingsManager.js'; import { initializeInfiniteScroll } from './utils/infiniteScroll.js'; import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; import { createPageControls } from './components/controls/index.js'; +import { loadMoreCheckpoints } from './api/checkpointApi.js'; // Initialize the Checkpoints page class CheckpointsPageManager { @@ -19,6 +20,11 @@ class CheckpointsPageManager { window.confirmDelete = confirmDelete; window.closeDeleteModal = closeDeleteModal; window.toggleApiKeyVisibility = toggleApiKeyVisibility; + + // Add loadCheckpoints function to window for FilterManager compatibility + window.checkpointManager = { + loadCheckpoints: (reset) => loadMoreCheckpoints(reset) + }; } async initialize() { diff --git a/static/js/managers/CheckpointSearchManager.js b/static/js/managers/CheckpointSearchManager.js deleted file mode 100644 index 7567bc7b..00000000 --- a/static/js/managers/CheckpointSearchManager.js +++ /dev/null @@ -1,150 +0,0 @@ -/** - * CheckpointSearchManager - Specialized search manager for the Checkpoints page - * Extends the base SearchManager with checkpoint-specific functionality - */ -import 
{ SearchManager } from './SearchManager.js'; -import { state } from '../state/index.js'; -import { showToast } from '../utils/uiHelpers.js'; - -export class CheckpointSearchManager extends SearchManager { - constructor(options = {}) { - super({ - page: 'checkpoints', - ...options - }); - - this.currentSearchTerm = ''; - - // Store this instance in the state - if (state) { - state.searchManager = this; - } - } - - async performSearch() { - const searchTerm = this.searchInput.value.trim().toLowerCase(); - - if (searchTerm === this.currentSearchTerm && !this.isSearching) { - return; // Avoid duplicate searches - } - - this.currentSearchTerm = searchTerm; - - const grid = document.getElementById('checkpointGrid'); - - if (!searchTerm) { - if (state) { - state.currentPage = 1; - } - this.resetAndReloadCheckpoints(); - return; - } - - try { - this.isSearching = true; - if (state && state.loadingManager) { - state.loadingManager.showSimpleLoading('Searching checkpoints...'); - } - - // Store current scroll position - const scrollPosition = window.pageYOffset || document.documentElement.scrollTop; - - if (state) { - state.currentPage = 1; - state.hasMore = true; - } - - const url = new URL('/api/checkpoints', window.location.origin); - url.searchParams.set('page', '1'); - url.searchParams.set('page_size', '20'); - url.searchParams.set('sort_by', state ? state.sortBy : 'name'); - url.searchParams.set('search', searchTerm); - url.searchParams.set('fuzzy', 'true'); - - // Add search options - const searchOptions = this.getActiveSearchOptions(); - url.searchParams.set('search_filename', searchOptions.filename.toString()); - url.searchParams.set('search_modelname', searchOptions.modelname.toString()); - - // Always send folder parameter if there is an active folder - if (state && state.activeFolder) { - url.searchParams.set('folder', state.activeFolder); - // Add recursive parameter when recursive search is enabled - const recursive = this.recursiveSearchToggle ? 
this.recursiveSearchToggle.checked : false; - url.searchParams.set('recursive', recursive.toString()); - } - - const response = await fetch(url); - - if (!response.ok) { - throw new Error('Search failed'); - } - - const data = await response.json(); - - if (searchTerm === this.currentSearchTerm && grid) { - grid.innerHTML = ''; - - if (data.items.length === 0) { - grid.innerHTML = '
No matching checkpoints found
'; - if (state) { - state.hasMore = false; - } - } else { - this.appendCheckpointCards(data.items); - if (state) { - state.hasMore = state.currentPage < data.total_pages; - state.currentPage++; - } - } - - // Restore scroll position after content is loaded - setTimeout(() => { - window.scrollTo({ - top: scrollPosition, - behavior: 'instant' // Use 'instant' to prevent animation - }); - }, 10); - } - } catch (error) { - console.error('Checkpoint search error:', error); - showToast('Checkpoint search failed', 'error'); - } finally { - this.isSearching = false; - if (state && state.loadingManager) { - state.loadingManager.hide(); - } - } - } - - resetAndReloadCheckpoints() { - // This function would be implemented in the checkpoints page - if (typeof window.loadCheckpoints === 'function') { - window.loadCheckpoints(); - } else { - // Fallback to reloading the page - window.location.reload(); - } - } - - appendCheckpointCards(checkpoints) { - // This function would be implemented in the checkpoints page - const grid = document.getElementById('checkpointGrid'); - if (!grid) return; - - if (typeof window.appendCheckpointCards === 'function') { - window.appendCheckpointCards(checkpoints); - } else { - // Fallback implementation - checkpoints.forEach(checkpoint => { - const card = document.createElement('div'); - card.className = 'checkpoint-card'; - card.innerHTML = ` -

${checkpoint.name}

-

${checkpoint.filename || 'No filename'}

- `; - grid.appendChild(card); - }); - } - } -} \ No newline at end of file diff --git a/static/js/managers/FilterManager.js b/static/js/managers/FilterManager.js index 57cbbf35..1ebdc9e2 100644 --- a/static/js/managers/FilterManager.js +++ b/static/js/managers/FilterManager.js @@ -1,7 +1,8 @@ -import { BASE_MODELS, BASE_MODEL_CLASSES } from '../utils/constants.js'; -import { state, getCurrentPageState } from '../state/index.js'; +import { BASE_MODEL_CLASSES } from '../utils/constants.js'; +import { getCurrentPageState } from '../state/index.js'; import { showToast, updatePanelPositions } from '../utils/uiHelpers.js'; import { loadMoreLoras } from '../api/loraApi.js'; +import { loadMoreCheckpoints } from '../api/checkpointApi.js'; import { removeStorageItem, setStorageItem, getStorageItem } from '../utils/storageHelpers.js'; export class FilterManager { @@ -73,7 +74,7 @@ export class FilterManager { } else if (this.currentPage === 'checkpoints') { tagsEndpoint = '/api/checkpoints/top-tags?limit=20'; } - + const response = await fetch(tagsEndpoint); if (!response.ok) throw new Error('Failed to fetch tags'); @@ -147,7 +148,6 @@ export class FilterManager { apiEndpoint = '/api/recipes/base-models'; } else if (this.currentPage === 'checkpoints') { apiEndpoint = '/api/checkpoints/base-models'; - return; // No API endpoint for other pages } // Fetch base models @@ -283,7 +283,7 @@ export class FilterManager { // For loras page, reset the page and reload await loadMoreLoras(true, true); } else if (this.currentPage === 'checkpoints' && window.checkpointManager) { - await window.checkpointManager.loadCheckpoints(true); + await loadMoreCheckpoints(true); } // Update filter button to show active state @@ -339,8 +339,8 @@ export class FilterManager { await window.recipeManager.loadRecipes(true); } else if (this.currentPage === 'loras') { await loadMoreLoras(true, true); - } else if (this.currentPage === 'checkpoints' && window.checkpointManager) { - await 
window.checkpointManager.loadCheckpoints(true); + } else if (this.currentPage === 'checkpoints') { + await loadMoreCheckpoints(true); } showToast(`Filters cleared`, 'info'); From 131c3cc324a699a0d0ee9da060c7ad77ed12cba4 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 10 Apr 2025 21:07:17 +0800 Subject: [PATCH 06/36] Add Civitai metadata fetching functionality for checkpoints - Implement fetchCivitai API method to retrieve metadata from Civitai. - Enhance CheckpointsControls to include fetch from Civitai functionality. - Update PageControls to register fetch from Civitai event listener for both LoRAs and Checkpoints. --- py/routes/checkpoints_routes.py | 176 ++++++++++++++++++ static/js/api/checkpointApi.js | 78 ++++++++ .../controls/CheckpointsControls.js | 7 +- static/js/components/controls/PageControls.js | 22 +-- 4 files changed, 271 insertions(+), 12 deletions(-) diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index efd480dd..fcd47c60 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -1,12 +1,18 @@ import os import json import asyncio +from typing import Dict import aiohttp import jinja2 from aiohttp import web import logging from datetime import datetime +from ..utils.model_utils import determine_base_model + +from ..utils.constants import NSFW_LEVELS +from ..services.civitai_client import CivitaiClient +from ..services.websocket_manager import ws_manager from ..services.checkpoint_scanner import CheckpointScanner from ..config import config from ..services.settings_manager import settings @@ -28,6 +34,7 @@ class CheckpointsRoutes: """Register routes with the aiohttp app""" app.router.add_get('/checkpoints', self.handle_checkpoints_page) app.router.add_get('/api/checkpoints', self.get_checkpoints) + app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai) app.router.add_get('/api/checkpoints/base-models', self.get_base_models) 
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags) app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) @@ -267,6 +274,175 @@ class CheckpointsRoutes: ] return {k: data[k] for k in fields if k in data} + async def fetch_all_civitai(self, request: web.Request) -> web.Response: + """Fetch CivitAI metadata for all checkpoints in the background""" + try: + cache = await self.scanner.get_cached_data() + total = len(cache.raw_data) + processed = 0 + success = 0 + needs_resort = False + + # Prepare checkpoints to process + to_process = [ + cp for cp in cache.raw_data + if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True) + ] + total_to_process = len(to_process) + + # Send initial progress + await ws_manager.broadcast({ + 'status': 'started', + 'total': total_to_process, + 'processed': 0, + 'success': 0 + }) + + # Process each checkpoint + for cp in to_process: + try: + original_name = cp.get('model_name') + if await self._fetch_and_update_single_checkpoint( + sha256=cp['sha256'], + file_path=cp['file_path'], + checkpoint=cp + ): + success += 1 + if original_name != cp.get('model_name'): + needs_resort = True + + processed += 1 + + # Send progress update + await ws_manager.broadcast({ + 'status': 'processing', + 'total': total_to_process, + 'processed': processed, + 'success': success, + 'current_name': cp.get('model_name', 'Unknown') + }) + + except Exception as e: + logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}") + + if needs_resort: + await cache.resort(name_only=True) + + # Send completion message + await ws_manager.broadcast({ + 'status': 'completed', + 'total': total_to_process, + 'processed': processed, + 'success': success + }) + + return web.json_response({ + "success": True, + "message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})" + }) + + except Exception as e: + # Send error message + await 
ws_manager.broadcast({ + 'status': 'error', + 'error': str(e) + }) + logger.error(f"Error in fetch_all_civitai for checkpoints: {e}") + return web.Response(text=str(e), status=500) + + async def _fetch_and_update_single_checkpoint(self, sha256: str, file_path: str, checkpoint: dict) -> bool: + """Fetch and update metadata for a single checkpoint without sorting""" + client = CivitaiClient() + try: + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + + # Load local metadata + local_metadata = self._load_local_metadata(metadata_path) + + # Fetch metadata from Civitai + civitai_metadata = await client.get_model_by_hash(sha256) + if not civitai_metadata: + # Mark as not from CivitAI if not found + local_metadata['from_civitai'] = False + checkpoint['from_civitai'] = False + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + return False + + # Update metadata with Civitai data + await self._update_model_metadata( + metadata_path, + local_metadata, + civitai_metadata, + client + ) + + # Update cache object directly + checkpoint.update({ + 'model_name': local_metadata.get('model_name'), + 'preview_url': local_metadata.get('preview_url'), + 'from_civitai': True, + 'civitai': civitai_metadata + }) + + return True + + except Exception as e: + logger.error(f"Error fetching CivitAI data for checkpoint: {e}") + return False + finally: + await client.close() + + def _load_local_metadata(self, metadata_path: str) -> Dict: + """Load local metadata file""" + if os.path.exists(metadata_path): + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.error(f"Error loading metadata from {metadata_path}: {e}") + return {} + + async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict, + civitai_metadata: Dict, client: CivitaiClient) -> None: + """Update local metadata with CivitAI data""" + local_metadata['civitai'] = 
civitai_metadata + + # Update model name if available + if 'model' in civitai_metadata: + if civitai_metadata.get('model', {}).get('name'): + local_metadata['model_name'] = civitai_metadata['model']['name'] + + # Fetch additional model metadata (description and tags) if we have model ID + model_id = civitai_metadata['modelId'] + if model_id: + model_metadata, _ = await client.get_model_metadata(str(model_id)) + if model_metadata: + local_metadata['modelDescription'] = model_metadata.get('description', '') + local_metadata['tags'] = model_metadata.get('tags', []) + + # Update base model + local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel')) + + # Update preview if needed + if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): + first_preview = next((img for img in civitai_metadata.get('images', [])), None) + if first_preview: + preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1] + base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] + preview_filename = base_name + preview_ext + preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) + + if await client.download_preview_image(first_preview['url'], preview_path): + local_metadata['preview_url'] = preview_path.replace(os.sep, '/') + local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + + # Save updated metadata + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + + await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata) + async def get_top_tags(self, request: web.Request) -> web.Response: """Handle request for top tags sorted by frequency""" try: diff --git a/static/js/api/checkpointApi.js b/static/js/api/checkpointApi.js index 1dd0c8fe..e0ed5d4f 100644 --- a/static/js/api/checkpointApi.js 
+++ b/static/js/api/checkpointApi.js @@ -249,4 +249,82 @@ async function _uploadPreview(filePath, file) { console.error('Error updating preview:', error); showToast(`Failed to update preview: ${error.message}`, 'error'); } +} + +// Fetch metadata from Civitai for checkpoints +export async function fetchCivitai() { + let ws = null; + + await state.loadingManager.showWithProgress(async (loading) => { + try { + const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; + const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); + + const operationComplete = new Promise((resolve, reject) => { + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + + switch(data.status) { + case 'started': + loading.setStatus('Starting metadata fetch...'); + break; + + case 'processing': + const percent = ((data.processed / data.total) * 100).toFixed(1); + loading.setProgress(percent); + loading.setStatus( + `Processing (${data.processed}/${data.total}) ${data.current_name}` + ); + break; + + case 'completed': + loading.setProgress(100); + loading.setStatus( + `Completed: Updated ${data.success} of ${data.processed} checkpoints` + ); + resolve(); + break; + + case 'error': + reject(new Error(data.error)); + break; + } + }; + + ws.onerror = (error) => { + reject(new Error('WebSocket error: ' + error.message)); + }; + }); + + await new Promise((resolve, reject) => { + ws.onopen = resolve; + ws.onerror = reject; + }); + + const response = await fetch('/api/checkpoints/fetch-all-civitai', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ model_type: 'checkpoint' }) // Specify we're fetching checkpoint metadata + }); + + if (!response.ok) { + throw new Error('Failed to fetch metadata'); + } + + await operationComplete; + + await resetAndReload(); + + } catch (error) { + console.error('Error fetching metadata:', error); + showToast('Failed to fetch metadata: ' + error.message, 'error'); 
+ } finally { + if (ws) { + ws.close(); + } + } + }, { + initialMessage: 'Connecting...', + completionMessage: 'Metadata update complete' + }); } \ No newline at end of file diff --git a/static/js/components/controls/CheckpointsControls.js b/static/js/components/controls/CheckpointsControls.js index 2dd968e2..44c6104a 100644 --- a/static/js/components/controls/CheckpointsControls.js +++ b/static/js/components/controls/CheckpointsControls.js @@ -1,6 +1,6 @@ // CheckpointsControls.js - Specific implementation for the Checkpoints page import { PageControls } from './PageControls.js'; -import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints } from '../../api/checkpointApi.js'; +import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js'; import { showToast } from '../../utils/uiHelpers.js'; /** @@ -33,6 +33,11 @@ export class CheckpointsControls extends PageControls { return await refreshCheckpoints(); }, + // Add fetch from Civitai functionality for checkpoints + fetchFromCivitai: async () => { + return await fetchCivitai(); + }, + // No clearCustomFilter implementation is needed for checkpoints // as custom filters are currently only used for LoRAs clearCustomFilter: async () => { diff --git a/static/js/components/controls/PageControls.js b/static/js/components/controls/PageControls.js index 43498f12..0bc4f64e 100644 --- a/static/js/components/controls/PageControls.js +++ b/static/js/components/controls/PageControls.js @@ -97,20 +97,20 @@ export class PageControls { * Initialize page-specific event listeners */ initPageSpecificListeners() { + // Fetch from Civitai button - available for both loras and checkpoints + const fetchButton = document.querySelector('[data-action="fetch"]'); + if (fetchButton) { + fetchButton.addEventListener('click', () => this.fetchFromCivitai()); + } + if (this.pageType === 'loras') { - // Fetch from Civitai button - const fetchButton = 
document.querySelector('[data-action="fetch"]'); - if (fetchButton) { - fetchButton.addEventListener('click', () => this.fetchFromCivitai()); - } - - // Download button + // Download button - LoRAs only const downloadButton = document.querySelector('[data-action="download"]'); if (downloadButton) { downloadButton.addEventListener('click', () => this.showDownloadModal()); } - // Bulk operations button + // Bulk operations button - LoRAs only const bulkButton = document.querySelector('[data-action="bulk"]'); if (bulkButton) { bulkButton.addEventListener('click', () => this.toggleBulkMode()); @@ -332,11 +332,11 @@ export class PageControls { } /** - * Fetch metadata from Civitai (LoRAs only) + * Fetch metadata from Civitai (available for both LoRAs and Checkpoints) */ async fetchFromCivitai() { - if (this.pageType !== 'loras' || !this.api) { - console.error('Fetch from Civitai is only available for LoRAs'); + if (!this.api) { + console.error('API methods not registered'); return; } From 311bf1f1575ec478311e072c1fd42d8dcf3d5ec8 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 10 Apr 2025 21:15:12 +0800 Subject: [PATCH 07/36] Add support for '.gguf' file extension in CheckpointScanner --- py/services/checkpoint_scanner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index 27d15273..4cc77b6a 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -25,7 +25,7 @@ class CheckpointScanner(ModelScanner): def __init__(self): if not hasattr(self, '_initialized'): # Define supported file extensions - file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'} + file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'} super().__init__( model_type="checkpoint", model_class=CheckpointMetadata, From 559e57ca4663e868556c0cec2ba439bd58af00e7 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> 
Date: Thu, 10 Apr 2025 21:28:34 +0800 Subject: [PATCH 08/36] Enhance CheckpointCard: Implement NSFW content handling, toggle blur functionality, and improve video autoplay behavior --- static/js/components/CheckpointCard.js | 168 ++++++++++++++++++++++++- 1 file changed, 163 insertions(+), 5 deletions(-) diff --git a/static/js/components/CheckpointCard.js b/static/js/components/CheckpointCard.js index a9246b8e..b8925b52 100644 --- a/static/js/components/CheckpointCard.js +++ b/static/js/components/CheckpointCard.js @@ -1,6 +1,7 @@ import { showToast } from '../utils/uiHelpers.js'; import { state } from '../state/index.js'; import { CheckpointModal } from './CheckpointModal.js'; +import { NSFW_LEVELS } from '../utils/constants.js'; // Create an instance of the modal const checkpointModal = new CheckpointModal(); @@ -28,28 +29,73 @@ export function createCheckpointCard(checkpoint) { card.dataset.tags = JSON.stringify(checkpoint.tags); } + // Store NSFW level if available + const nsfwLevel = checkpoint.preview_nsfw_level !== undefined ? checkpoint.preview_nsfw_level : 0; + card.dataset.nsfwLevel = nsfwLevel; + + // Determine if the preview should be blurred based on NSFW level and user settings + const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13; + if (shouldBlur) { + card.classList.add('nsfw-content'); + } + // Determine preview URL const previewUrl = checkpoint.preview_url || '/loras_static/images/no-preview.png'; const version = state.previewVersions ? state.previewVersions.get(checkpoint.file_path) : null; const versionedPreviewUrl = version ? 
`${previewUrl}?t=${version}` : previewUrl; + // Determine NSFW warning text based on level + let nsfwText = "Mature Content"; + if (nsfwLevel >= NSFW_LEVELS.XXX) { + nsfwText = "XXX-rated Content"; + } else if (nsfwLevel >= NSFW_LEVELS.X) { + nsfwText = "X-rated Content"; + } else if (nsfwLevel >= NSFW_LEVELS.R) { + nsfwText = "R-rated Content"; + } + + // Check if autoplayOnHover is enabled for video previews + const autoplayOnHover = state.global?.settings?.autoplayOnHover || false; + const isVideo = previewUrl.endsWith('.mp4'); + const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop'; + card.innerHTML = ` -
- ${checkpoint.model_name} +
+ ${isVideo ? + `` : + `${checkpoint.model_name}` + }
- - ${checkpoint.base_model || 'Unknown'} + ${shouldBlur ? + `` : ''} + + ${checkpoint.base_model}
+ +
+ ${shouldBlur ? ` +
+
+

${nsfwText}

+ +
+
+ ` : ''} \ No newline at end of file +
\ No newline at end of file From 129ca9da81e5930b66fe8d8d610349df553c9d07 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 10 Apr 2025 22:59:09 +0800 Subject: [PATCH 11/36] feat: Implement checkpoint modal functionality with metadata editing, showcase display, and utility functions - Added ModelMetadata.js for handling model metadata editing, including model name, base model, and file name. - Introduced ShowcaseView.js to manage the display of images and videos in the checkpoint modal, including NSFW filtering and lazy loading. - Created index.js as the main entry point for the checkpoint modal, integrating various components and functionalities. - Developed utils.js for utility functions related to file size formatting and tag rendering. - Enhanced user experience with editable fields, toast notifications, and improved showcase scrolling. --- static/js/components/CheckpointCard.js | 14 +- static/js/components/CheckpointModal.js | 990 ------------------ .../checkpointModal/ModelDescription.js | 102 ++ .../checkpointModal/ModelMetadata.js | 492 +++++++++ .../checkpointModal/ShowcaseView.js | 489 +++++++++ static/js/components/checkpointModal/index.js | 219 ++++ static/js/components/checkpointModal/utils.js | 74 ++ 7 files changed, 1384 insertions(+), 996 deletions(-) delete mode 100644 static/js/components/CheckpointModal.js create mode 100644 static/js/components/checkpointModal/ModelDescription.js create mode 100644 static/js/components/checkpointModal/ModelMetadata.js create mode 100644 static/js/components/checkpointModal/ShowcaseView.js create mode 100644 static/js/components/checkpointModal/index.js create mode 100644 static/js/components/checkpointModal/utils.js diff --git a/static/js/components/CheckpointCard.js b/static/js/components/CheckpointCard.js index b8925b52..f7031c15 100644 --- a/static/js/components/CheckpointCard.js +++ b/static/js/components/CheckpointCard.js @@ -1,11 +1,8 @@ import { showToast } from 
'../utils/uiHelpers.js'; import { state } from '../state/index.js'; -import { CheckpointModal } from './CheckpointModal.js'; +import { showCheckpointModal } from './checkpointModal/index.js'; import { NSFW_LEVELS } from '../utils/constants.js'; -// Create an instance of the modal -const checkpointModal = new CheckpointModal(); - export function createCheckpointCard(checkpoint) { const card = document.createElement('div'); card.className = 'lora-card'; // Reuse the same class for styling @@ -29,6 +26,10 @@ export function createCheckpointCard(checkpoint) { card.dataset.tags = JSON.stringify(checkpoint.tags); } + if (checkpoint.modelDescription) { + card.dataset.modelDescription = checkpoint.modelDescription; + } + // Store NSFW level if available const nsfwLevel = checkpoint.preview_nsfw_level !== undefined ? checkpoint.preview_nsfw_level : 0; card.dataset.nsfwLevel = nsfwLevel; @@ -139,9 +140,10 @@ export function createCheckpointCard(checkpoint) { console.error('Failed to parse tags:', e); return []; // Return empty array on error } - })() + })(), + modelDescription: card.dataset.modelDescription || '' }; - checkpointModal.showCheckpointDetails(checkpointMeta); + showCheckpointModal(checkpointMeta); }); // Toggle blur button functionality diff --git a/static/js/components/CheckpointModal.js b/static/js/components/CheckpointModal.js deleted file mode 100644 index 160b274d..00000000 --- a/static/js/components/CheckpointModal.js +++ /dev/null @@ -1,990 +0,0 @@ -import { showToast } from '../utils/uiHelpers.js'; -import { BASE_MODELS } from '../utils/constants.js'; - -/** - * CheckpointModal - Component for displaying checkpoint details - * Similar to LoraModal but customized for checkpoint models - */ -export class CheckpointModal { - constructor() { - this.modal = document.getElementById('checkpointModal'); - this.modalTitle = document.getElementById('checkpointModalTitle'); - this.modalContent = document.getElementById('checkpointModalContent'); - 
this.currentCheckpoint = null; - - // Initialize close events - this._initCloseEvents(); - } - - _initCloseEvents() { - if (!this.modal) return; - - // Close button - const closeBtn = this.modal.querySelector('.close'); - if (closeBtn) { - closeBtn.addEventListener('click', () => this.close()); - } - - // Click outside to close - this.modal.addEventListener('click', (e) => { - if (e.target === this.modal) { - this.close(); - } - }); - } - - /** - * Format file size for display - * @param {number} bytes - File size in bytes - * @returns {string} - Formatted file size - */ - _formatFileSize(bytes) { - if (!bytes) return 'Unknown'; - - const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB']; - if (bytes === 0) return '0 Bytes'; - const i = Math.floor(Math.log(bytes) / Math.log(1024)); - if (i === 0) return `${bytes} ${sizes[i]}`; - return `${(bytes / Math.pow(1024, i)).toFixed(2)} ${sizes[i]}`; - } - - /** - * Render compact tags for the checkpoint - * @param {Array} tags - Array of tags - * @returns {string} - HTML for tags - */ - _renderCompactTags(tags) { - if (!tags || tags.length === 0) return ''; - - // Display up to 5 tags, with a count if there are more - const visibleTags = tags.slice(0, 5); - const remainingCount = Math.max(0, tags.length - 5); - - return ` -
-
- ${visibleTags.map(tag => `${tag}`).join('')} - ${remainingCount > 0 ? - `+${remainingCount}` : - ''} -
- ${tags.length > 0 ? - `
-
- ${tags.map(tag => `${tag}`).join('')} -
-
` : - ''} -
- `; - } - - /** - * Set up tag tooltip functionality - */ - _setupTagTooltip() { - const tagsContainer = document.querySelector('.model-tags-container'); - const tooltip = document.querySelector('.model-tags-tooltip'); - - if (tagsContainer && tooltip) { - tagsContainer.addEventListener('mouseenter', () => { - tooltip.classList.add('visible'); - }); - - tagsContainer.addEventListener('mouseleave', () => { - tooltip.classList.remove('visible'); - }); - } - } - - /** - * Render showcase content (example images) - * @param {Array} images - Array of image data - * @returns {string} - HTML content - */ - _renderShowcaseContent(images) { - if (!images?.length) return '
No example images available
'; - - return ` -
- - Scroll or click to show ${images.length} examples -
- - `; - } - - /** - * Show checkpoint details in the modal - * @param {Object} checkpoint - Checkpoint data - */ - showCheckpointDetails(checkpoint) { - if (!this.modal) { - console.error('Checkpoint modal element not found'); - return; - } - - this.currentCheckpoint = checkpoint; - - const content = ` - - `; - - this.modal.innerHTML = content; - this.modal.style.display = 'block'; - - this._setupEditableFields(); - this._setupShowcaseScroll(); - this._setupTabSwitching(); - this._setupTagTooltip(); - this._setupModelNameEditing(); - this._setupBaseModelEditing(); - this._setupFileNameEditing(); - - // If we have a model ID but no description, fetch it - if (checkpoint.civitai?.modelId && !checkpoint.modelDescription) { - this._loadModelDescription(checkpoint.civitai.modelId, checkpoint.file_path); - } - } - - /** - * Close the checkpoint modal - */ - close() { - if (this.modal) { - this.modal.style.display = 'none'; - this.currentCheckpoint = null; - } - } - - /** - * Set up editable fields in the modal - */ - _setupEditableFields() { - const editableFields = this.modal.querySelectorAll('.editable-field [contenteditable]'); - - editableFields.forEach(field => { - field.addEventListener('focus', function() { - if (this.textContent === 'Add your notes here...') { - this.textContent = ''; - } - }); - - field.addEventListener('blur', function() { - if (this.textContent.trim() === '') { - if (this.classList.contains('notes-content')) { - this.textContent = 'Add your notes here...'; - } - } - }); - }); - - // Add keydown event listeners for notes - const notesContent = this.modal.querySelector('.notes-content'); - if (notesContent) { - notesContent.addEventListener('keydown', async (e) => { - if (e.key === 'Enter') { - if (e.shiftKey) { - // Allow shift+enter for new line - return; - } - e.preventDefault(); - const filePath = this.modal.querySelector('.file-path').textContent + - this.modal.querySelector('#file-name').textContent; - await this._saveNotes(filePath); - } 
- }); - } - } - - /** - * Save notes for the checkpoint - * @param {string} filePath - Path to the checkpoint file - */ - async _saveNotes(filePath) { - const content = this.modal.querySelector('.notes-content').textContent; - try { - // This would typically call an API endpoint to save the notes - // For now we'll just show a success message - console.log('Would save notes:', content, 'for file:', filePath); - - showToast('Notes saved successfully', 'success'); - } catch (error) { - showToast('Failed to save notes', 'error'); - } - } - - /** - * Set up model name editing functionality - */ - _setupModelNameEditing() { - const modelNameContent = this.modal.querySelector('.model-name-content'); - const editBtn = this.modal.querySelector('.edit-model-name-btn'); - - if (!modelNameContent || !editBtn) return; - - // Show edit button on hover - const modelNameHeader = this.modal.querySelector('.model-name-header'); - modelNameHeader.addEventListener('mouseenter', () => { - editBtn.classList.add('visible'); - }); - - modelNameHeader.addEventListener('mouseleave', () => { - if (!modelNameContent.getAttribute('data-editing')) { - editBtn.classList.remove('visible'); - } - }); - - // Handle edit button click - editBtn.addEventListener('click', () => { - modelNameContent.setAttribute('data-editing', 'true'); - modelNameContent.focus(); - - // Place cursor at the end - const range = document.createRange(); - const sel = window.getSelection(); - if (modelNameContent.childNodes.length > 0) { - range.setStart(modelNameContent.childNodes[0], modelNameContent.textContent.length); - range.collapse(true); - sel.removeAllRanges(); - sel.addRange(range); - } - - editBtn.classList.add('visible'); - }); - - // Handle focus out - modelNameContent.addEventListener('blur', function() { - this.removeAttribute('data-editing'); - editBtn.classList.remove('visible'); - - if (this.textContent.trim() === '') { - // Restore original model name if empty - this.textContent = 'Checkpoint Details'; 
- } - }); - - // Handle enter key - modelNameContent.addEventListener('keydown', (e) => { - if (e.key === 'Enter') { - e.preventDefault(); - modelNameContent.blur(); - // Save model name here (would call an API endpoint) - showToast('Model name updated', 'success'); - } - }); - - // Limit model name length - modelNameContent.addEventListener('input', function() { - if (this.textContent.length > 100) { - this.textContent = this.textContent.substring(0, 100); - // Place cursor at the end - const range = document.createRange(); - const sel = window.getSelection(); - range.setStart(this.childNodes[0], 100); - range.collapse(true); - sel.removeAllRanges(); - sel.addRange(range); - - showToast('Model name is limited to 100 characters', 'warning'); - } - }); - } - - /** - * Set up base model editing functionality - */ - _setupBaseModelEditing() { - const baseModelContent = this.modal.querySelector('.base-model-content'); - const editBtn = this.modal.querySelector('.edit-base-model-btn'); - - if (!baseModelContent || !editBtn) return; - - // Show edit button on hover - const baseModelDisplay = this.modal.querySelector('.base-model-display'); - baseModelDisplay.addEventListener('mouseenter', () => { - editBtn.classList.add('visible'); - }); - - baseModelDisplay.addEventListener('mouseleave', () => { - if (!baseModelDisplay.classList.contains('editing')) { - editBtn.classList.remove('visible'); - } - }); - - // Handle edit button click - editBtn.addEventListener('click', () => { - baseModelDisplay.classList.add('editing'); - - // Store the original value to check for changes later - const originalValue = baseModelContent.textContent.trim(); - - // Create dropdown selector to replace the base model content - const currentValue = originalValue; - const dropdown = document.createElement('select'); - dropdown.className = 'base-model-selector'; - - // Flag to track if a change was made - let valueChanged = false; - - // Add options from BASE_MODELS constants - const 
baseModelCategories = { - 'Stable Diffusion 1.x': [BASE_MODELS.SD_1_4, BASE_MODELS.SD_1_5, BASE_MODELS.SD_1_5_LCM, BASE_MODELS.SD_1_5_HYPER], - 'Stable Diffusion 2.x': [BASE_MODELS.SD_2_0, BASE_MODELS.SD_2_1], - 'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO], - 'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER], - 'Video Models': [BASE_MODELS.SVD, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO], - 'Other Models': [ - BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW, - BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1, - BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI, - BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.UNKNOWN - ] - }; - - // Create option groups for better organization - Object.entries(baseModelCategories).forEach(([category, models]) => { - const group = document.createElement('optgroup'); - group.label = category; - - models.forEach(model => { - const option = document.createElement('option'); - option.value = model; - option.textContent = model; - option.selected = model === currentValue; - group.appendChild(option); - }); - - dropdown.appendChild(group); - }); - - // Replace content with dropdown - baseModelContent.style.display = 'none'; - baseModelDisplay.insertBefore(dropdown, editBtn); - - // Hide edit button during editing - editBtn.style.display = 'none'; - - // Focus the dropdown - dropdown.focus(); - - // Handle dropdown change - dropdown.addEventListener('change', function() { - const selectedModel = this.value; - baseModelContent.textContent = selectedModel; - - // Mark that a change was made if the value differs from original - if (selectedModel !== originalValue) { - valueChanged = true; - } else { - valueChanged = false; - } - }); - - // Function to save changes and exit edit mode - const saveAndExit = function() { - // Check if dropdown still exists and 
remove it - if (dropdown && dropdown.parentNode === baseModelDisplay) { - baseModelDisplay.removeChild(dropdown); - } - - // Show the content and edit button - baseModelContent.style.display = ''; - editBtn.style.display = ''; - - // Remove editing class - baseModelDisplay.classList.remove('editing'); - - // Only save if the value has actually changed - if (valueChanged || baseModelContent.textContent.trim() !== originalValue) { - // Get file path for saving - const filePath = document.querySelector('#checkpointModal .modal-content') - .querySelector('.file-path').textContent + - document.querySelector('#checkpointModal .modal-content') - .querySelector('#file-name').textContent + '.safetensors'; - - // Save the changes (would call API to save model base change) - showToast('Base model updated successfully', 'success'); - } - - // Remove this event listener - document.removeEventListener('click', outsideClickHandler); - }; - - // Handle outside clicks to save and exit - const outsideClickHandler = function(e) { - // If click is outside the dropdown and base model display - if (!baseModelDisplay.contains(e.target)) { - saveAndExit(); - } - }; - - // Add delayed event listener for outside clicks - setTimeout(() => { - document.addEventListener('click', outsideClickHandler); - }, 0); - - // Also handle dropdown blur event - dropdown.addEventListener('blur', function(e) { - // Only save if the related target is not the edit button or inside the baseModelDisplay - if (!baseModelDisplay.contains(e.relatedTarget)) { - saveAndExit(); - } - }); - }); - } - - /** - * Set up file name editing functionality - */ - _setupFileNameEditing() { - const fileNameContent = this.modal.querySelector('.file-name-content'); - const editBtn = this.modal.querySelector('.edit-file-name-btn'); - - if (!fileNameContent || !editBtn) return; - - // Show edit button on hover - const fileNameWrapper = this.modal.querySelector('.file-name-wrapper'); - fileNameWrapper.addEventListener('mouseenter', 
() => { - editBtn.classList.add('visible'); - }); - - fileNameWrapper.addEventListener('mouseleave', () => { - if (!fileNameWrapper.classList.contains('editing')) { - editBtn.classList.remove('visible'); - } - }); - - // Handle edit button click - editBtn.addEventListener('click', () => { - fileNameWrapper.classList.add('editing'); - fileNameContent.setAttribute('contenteditable', 'true'); - fileNameContent.focus(); - - // Store original value - fileNameContent.dataset.originalValue = fileNameContent.textContent.trim(); - - // Place cursor at the end - const range = document.createRange(); - const sel = window.getSelection(); - range.selectNodeContents(fileNameContent); - range.collapse(false); - sel.removeAllRanges(); - sel.addRange(range); - - editBtn.classList.add('visible'); - }); - - // Handle keyboard events - fileNameContent.addEventListener('keydown', function(e) { - if (!this.getAttribute('contenteditable')) return; - - if (e.key === 'Enter') { - e.preventDefault(); - this.blur(); - } else if (e.key === 'Escape') { - e.preventDefault(); - // Restore original value - this.textContent = this.dataset.originalValue; - exitEditMode(); - } - }); - - // Handle input validation - fileNameContent.addEventListener('input', function() { - if (!this.getAttribute('contenteditable')) return; - - // Replace invalid characters for filenames - const invalidChars = /[\\/:*?"<>|]/g; - if (invalidChars.test(this.textContent)) { - const cursorPos = window.getSelection().getRangeAt(0).startOffset; - this.textContent = this.textContent.replace(invalidChars, ''); - - // Restore cursor position - const range = document.createRange(); - const sel = window.getSelection(); - const newPos = Math.min(cursorPos, this.textContent.length); - - if (this.firstChild) { - range.setStart(this.firstChild, newPos); - range.collapse(true); - sel.removeAllRanges(); - sel.addRange(range); - } - - showToast('Invalid characters removed from filename', 'warning'); - } - }); - - // Handle focus out - 
save changes - fileNameContent.addEventListener('blur', function() { - if (!this.getAttribute('contenteditable')) return; - - const newFileName = this.textContent.trim(); - const originalValue = this.dataset.originalValue; - - // Validation - if (!newFileName) { - this.textContent = originalValue; - showToast('File name cannot be empty', 'error'); - exitEditMode(); - return; - } - - if (newFileName !== originalValue) { - // Would call API to rename file - showToast(`File would be renamed to: ${newFileName}`, 'success'); - } - - exitEditMode(); - }); - - function exitEditMode() { - fileNameContent.removeAttribute('contenteditable'); - fileNameWrapper.classList.remove('editing'); - editBtn.classList.remove('visible'); - } - } - - /** - * Set up showcase scroll functionality - */ - _setupShowcaseScroll() { - // Initialize scroll listeners for showcase section - const showcaseSection = this.modal.querySelector('.showcase-section'); - if (!showcaseSection) return; - - // Set up back-to-top button - const backToTopBtn = showcaseSection.querySelector('.back-to-top'); - const modalContent = this.modal.querySelector('.modal-content'); - - if (backToTopBtn && modalContent) { - modalContent.addEventListener('scroll', () => { - if (modalContent.scrollTop > 300) { - backToTopBtn.classList.add('visible'); - } else { - backToTopBtn.classList.remove('visible'); - } - }); - } - - // Set up scroll to toggle showcase - document.addEventListener('wheel', (event) => { - if (this.modal.style.display !== 'block') return; - - const showcase = this.modal.querySelector('.showcase-section'); - if (!showcase) return; - - const carousel = showcase.querySelector('.carousel'); - const scrollIndicator = showcase.querySelector('.scroll-indicator'); - - if (carousel?.classList.contains('collapsed') && event.deltaY > 0) { - const isNearBottom = modalContent.scrollHeight - modalContent.scrollTop - modalContent.clientHeight < 100; - - if (isNearBottom) { - this._toggleShowcase(scrollIndicator); - 
event.preventDefault(); - } - } - }, { passive: false }); - } - - /** - * Toggle showcase expansion - * @param {HTMLElement} element - The scroll indicator element - */ - _toggleShowcase(element) { - const carousel = element.nextElementSibling; - const isCollapsed = carousel.classList.contains('collapsed'); - const indicator = element.querySelector('span'); - const icon = element.querySelector('i'); - - carousel.classList.toggle('collapsed'); - - if (isCollapsed) { - const count = carousel.querySelectorAll('.media-wrapper').length; - indicator.textContent = `Scroll or click to hide examples`; - icon.classList.replace('fa-chevron-down', 'fa-chevron-up'); - this._initLazyLoading(carousel); - this._initMetadataPanelHandlers(carousel); - } else { - const count = carousel.querySelectorAll('.media-wrapper').length; - indicator.textContent = `Scroll or click to show ${count} examples`; - icon.classList.replace('fa-chevron-up', 'fa-chevron-down'); - } - } - - /** - * Initialize lazy loading for images - * @param {HTMLElement} container - Container with lazy-load images - */ - _initLazyLoading(container) { - const lazyImages = container.querySelectorAll('img.lazy'); - - const lazyLoad = (image) => { - image.src = image.dataset.src; - image.classList.remove('lazy'); - }; - - const observer = new IntersectionObserver((entries) => { - entries.forEach(entry => { - if (entry.isIntersecting) { - lazyLoad(entry.target); - observer.unobserve(entry.target); - } - }); - }); - - lazyImages.forEach(image => observer.observe(image)); - } - - /** - * Initialize metadata panel handlers - * @param {HTMLElement} container - Container with metadata panels - */ - _initMetadataPanelHandlers(container) { - const mediaWrappers = container.querySelectorAll('.media-wrapper'); - - mediaWrappers.forEach(wrapper => { - const metadataPanel = wrapper.querySelector('.image-metadata-panel'); - if (!metadataPanel) return; - - // Prevent events from bubbling - metadataPanel.addEventListener('click', (e) => 
{ - e.stopPropagation(); - }); - - // Handle copy prompt buttons - const copyBtns = metadataPanel.querySelectorAll('.copy-prompt-btn'); - copyBtns.forEach(copyBtn => { - const promptIndex = copyBtn.dataset.promptIndex; - const promptElement = wrapper.querySelector(`#prompt-${promptIndex}`); - - copyBtn.addEventListener('click', async (e) => { - e.stopPropagation(); - - if (!promptElement) return; - - try { - await navigator.clipboard.writeText(promptElement.textContent); - showToast('Prompt copied to clipboard', 'success'); - } catch (err) { - console.error('Copy failed:', err); - showToast('Copy failed', 'error'); - } - }); - }); - - // Prevent panel scroll from causing modal scroll - metadataPanel.addEventListener('wheel', (e) => { - e.stopPropagation(); - }); - }); - } - - /** - * Set up tab switching functionality - */ - _setupTabSwitching() { - const tabButtons = this.modal.querySelectorAll('.showcase-tabs .tab-btn'); - - tabButtons.forEach(button => { - button.addEventListener('click', () => { - // Remove active class from all tabs - this.modal.querySelectorAll('.showcase-tabs .tab-btn').forEach(btn => - btn.classList.remove('active') - ); - this.modal.querySelectorAll('.tab-content .tab-pane').forEach(tab => - tab.classList.remove('active') - ); - - // Add active class to clicked tab - button.classList.add('active'); - const tabId = `${button.dataset.tab}-tab`; - this.modal.querySelector(`#${tabId}`).classList.add('active'); - - // If switching to description tab, handle content - if (button.dataset.tab === 'description') { - const descriptionContent = this.modal.querySelector('.model-description-content'); - if (descriptionContent) { - const hasContent = descriptionContent.innerHTML.trim() !== ''; - this.modal.querySelector('.model-description-loading')?.classList.add('hidden'); - - if (!hasContent) { - descriptionContent.innerHTML = '
No model description available
'; - descriptionContent.classList.remove('hidden'); - } - } - } - }); - }); - } - - /** - * Load model description from API - * @param {string} modelId - Model ID - * @param {string} filePath - File path - */ - async _loadModelDescription(modelId, filePath) { - try { - const descriptionContainer = this.modal.querySelector('.model-description-content'); - const loadingElement = this.modal.querySelector('.model-description-loading'); - - if (!descriptionContainer || !loadingElement) return; - - // Show loading indicator - loadingElement.classList.remove('hidden'); - descriptionContainer.classList.add('hidden'); - - // In production, this would fetch from the API - // For now, just simulate loading - setTimeout(() => { - descriptionContainer.innerHTML = '

This is a placeholder for the checkpoint model description.

'; - - // Show the description and hide loading indicator - descriptionContainer.classList.remove('hidden'); - loadingElement.classList.add('hidden'); - }, 500); - } catch (error) { - console.error('Error loading model description:', error); - const loadingElement = this.modal.querySelector('.model-description-loading'); - if (loadingElement) { - loadingElement.innerHTML = `
Failed to load model description. ${error.message}
`; - } - - // Show empty state message - const descriptionContainer = this.modal.querySelector('.model-description-content'); - if (descriptionContainer) { - descriptionContainer.innerHTML = '
No model description available
'; - descriptionContainer.classList.remove('hidden'); - } - } - } - - /** - * Scroll to top of modal content - * @param {HTMLElement} button - The back to top button - */ - scrollToTop(button) { - const modalContent = button.closest('.modal-content'); - if (modalContent) { - modalContent.scrollTo({ - top: 0, - behavior: 'smooth' - }); - } - } -} - -// Create and export global instance -export const checkpointModal = new CheckpointModal(); - -// Add global functions for use in HTML -window.toggleShowcase = function(element) { - checkpointModal._toggleShowcase(element); -}; - -window.scrollToTopCheckpoint = function(button) { - checkpointModal.scrollToTop(button); -}; - -window.saveCheckpointNotes = function(filePath) { - checkpointModal._saveNotes(filePath); -}; \ No newline at end of file diff --git a/static/js/components/checkpointModal/ModelDescription.js b/static/js/components/checkpointModal/ModelDescription.js new file mode 100644 index 00000000..0f50fe8a --- /dev/null +++ b/static/js/components/checkpointModal/ModelDescription.js @@ -0,0 +1,102 @@ +/** + * ModelDescription.js + * Handles checkpoint model descriptions + */ +import { showToast } from '../../utils/uiHelpers.js'; + +/** + * Set up tab switching functionality + */ +export function setupTabSwitching() { + const tabButtons = document.querySelectorAll('.showcase-tabs .tab-btn'); + + tabButtons.forEach(button => { + button.addEventListener('click', () => { + // Remove active class from all tabs + document.querySelectorAll('.showcase-tabs .tab-btn').forEach(btn => + btn.classList.remove('active') + ); + document.querySelectorAll('.tab-content .tab-pane').forEach(tab => + tab.classList.remove('active') + ); + + // Add active class to clicked tab + button.classList.add('active'); + const tabId = `${button.dataset.tab}-tab`; + document.getElementById(tabId).classList.add('active'); + + // If switching to description tab, make sure content is properly loaded and displayed + if (button.dataset.tab === 
'description') { + const descriptionContent = document.querySelector('.model-description-content'); + if (descriptionContent) { + const hasContent = descriptionContent.innerHTML.trim() !== ''; + document.querySelector('.model-description-loading')?.classList.add('hidden'); + + // If no content, show a message + if (!hasContent) { + descriptionContent.innerHTML = '
No model description available
'; + descriptionContent.classList.remove('hidden'); + } + } + } + }); + }); +} + +/** + * Load model description from API + * @param {string} modelId - The Civitai model ID + * @param {string} filePath - File path for the model + */ +export async function loadModelDescription(modelId, filePath) { + try { + const descriptionContainer = document.querySelector('.model-description-content'); + const loadingElement = document.querySelector('.model-description-loading'); + + if (!descriptionContainer || !loadingElement) return; + + // Show loading indicator + loadingElement.classList.remove('hidden'); + descriptionContainer.classList.add('hidden'); + + // Try to get model description from API + const response = await fetch(`/api/checkpoint-model-description?model_id=${modelId}&file_path=${encodeURIComponent(filePath)}`); + + if (!response.ok) { + throw new Error(`Failed to fetch model description: ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success && data.description) { + // Update the description content + descriptionContainer.innerHTML = data.description; + + // Process any links in the description to open in new tab + const links = descriptionContainer.querySelectorAll('a'); + links.forEach(link => { + link.setAttribute('target', '_blank'); + link.setAttribute('rel', 'noopener noreferrer'); + }); + + // Show the description and hide loading indicator + descriptionContainer.classList.remove('hidden'); + loadingElement.classList.add('hidden'); + } else { + throw new Error(data.error || 'No description available'); + } + } catch (error) { + console.error('Error loading model description:', error); + const loadingElement = document.querySelector('.model-description-loading'); + if (loadingElement) { + loadingElement.innerHTML = `
Failed to load model description. ${error.message}
`; + } + + // Show empty state message in the description container + const descriptionContainer = document.querySelector('.model-description-content'); + if (descriptionContainer) { + descriptionContainer.innerHTML = '
No model description available
/**
 * Persist checkpoint metadata to the backend.
 * @param {string} filePath - Path to the model file
 * @param {Object} data - Metadata fields to merge into the stored record
 * @returns {Promise} Resolves with the parsed JSON response body
 * @throws {Error} When the server responds with a non-2xx status
 */
export async function saveModelMetadata(filePath, data) {
    const response = await fetch('/checkpoints/api/save-metadata', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        },
        body: JSON.stringify({ file_path: filePath, ...data })
    });

    if (!response.ok) {
        throw new Error('Failed to save metadata');
    }

    return response.json();
}

/**
 * Wire up inline editing of the model name in the checkpoint modal.
 * Hovering the header reveals the edit button, clicking enters edit mode,
 * Enter saves, and leaving the field empty restores the original name.
 */
export function setupModelNameEditing() {
    const nameField = document.querySelector('.model-name-content');
    const editButton = document.querySelector('.edit-model-name-btn');

    if (!nameField || !editButton) return;

    // Full path of the current file as shown in the modal (directory + name).
    const resolveFilePath = () =>
        document.querySelector('#checkpointModal .modal-content')
            .querySelector('.file-path').textContent +
        document.querySelector('#checkpointModal .modal-content')
            .querySelector('#file-name').textContent;

    // Reveal the edit button while hovering the name header.
    const nameHeader = document.querySelector('.model-name-header');
    nameHeader.addEventListener('mouseenter', () => {
        editButton.classList.add('visible');
    });
    nameHeader.addEventListener('mouseleave', () => {
        // Keep the button visible while an edit is in progress.
        if (!nameField.getAttribute('data-editing')) {
            editButton.classList.remove('visible');
        }
    });

    // Enter edit mode and move the caret to the end of the text.
    editButton.addEventListener('click', () => {
        nameField.setAttribute('data-editing', 'true');
        nameField.focus();

        const caret = document.createRange();
        const selection = window.getSelection();
        if (nameField.childNodes.length > 0) {
            caret.setStart(nameField.childNodes[0], nameField.textContent.length);
            caret.collapse(true);
            selection.removeAllRanges();
            selection.addRange(caret);
        }

        editButton.classList.add('visible');
    });

    // Leaving the field exits edit mode; an empty name is rolled back.
    nameField.addEventListener('blur', function() {
        this.removeAttribute('data-editing');
        editButton.classList.remove('visible');

        if (this.textContent.trim() === '') {
            // Restore the original model name from the matching card.
            const card = document.querySelector(`.checkpoint-card[data-filepath="${resolveFilePath()}"]`);
            if (card) {
                this.textContent = card.dataset.model_name;
            }
        }
    });

    // Enter commits the edit.
    nameField.addEventListener('keydown', function(e) {
        if (e.key === 'Enter') {
            e.preventDefault();
            saveModelName(resolveFilePath());
            this.blur();
        }
    });

    // Hard cap the name at 100 characters while typing.
    nameField.addEventListener('input', function() {
        if (this.textContent.length > 100) {
            this.textContent = this.textContent.substring(0, 100);

            // Re-place the caret after the forced truncation.
            const caret = document.createRange();
            const selection = window.getSelection();
            caret.setStart(this.childNodes[0], 100);
            caret.collapse(true);
            selection.removeAllRanges();
            selection.addRange(caret);

            showToast('Model name is limited to 100 characters', 'warning');
        }
    });
}
/**
 * Validate and persist an edited model name, then refresh the UI.
 * @param {string} filePath - Full path of the checkpoint file
 */
async function saveModelName(filePath) {
    const nameField = document.querySelector('.model-name-content');
    const newModelName = nameField.textContent.trim();

    if (!newModelName) {
        showToast('Model name cannot be empty', 'error');
        return;
    }

    if (newModelName.length > 100) {
        showToast('Model name is too long (maximum 100 characters)', 'error');
        // Truncate what is displayed, but do not save.
        nameField.textContent = newModelName.substring(0, 100);
        return;
    }

    try {
        await saveModelMetadata(filePath, { model_name: newModelName });

        // Mirror the change onto the matching checkpoint card.
        const card = document.querySelector(`.checkpoint-card[data-filepath="${filePath}"]`);
        if (card) {
            card.dataset.model_name = newModelName;
            const title = card.querySelector('.card-title');
            if (title) {
                title.textContent = newModelName;
            }
        }

        showToast('Model name updated successfully', 'success');

        // Reload so the grid reflects the new sort order.
        setTimeout(() => window.location.reload(), 1500);
    } catch (error) {
        showToast('Failed to update model name', 'error');
    }
}

/**
 * Wire up inline editing of the base model via a temporary <select> that
 * replaces the text while editing. Changes are saved when the editor loses
 * focus or a click lands outside of it.
 */
export function setupBaseModelEditing() {
    const baseModelContent = document.querySelector('.base-model-content');
    const editBtn = document.querySelector('.edit-base-model-btn');

    if (!baseModelContent || !editBtn) return;

    const baseModelDisplay = document.querySelector('.base-model-display');

    // Reveal the edit button on hover, unless an edit is in progress.
    baseModelDisplay.addEventListener('mouseenter', () => {
        editBtn.classList.add('visible');
    });
    baseModelDisplay.addEventListener('mouseleave', () => {
        if (!baseModelDisplay.classList.contains('editing')) {
            editBtn.classList.remove('visible');
        }
    });

    editBtn.addEventListener('click', () => {
        baseModelDisplay.classList.add('editing');

        // Remember the pre-edit value so we only save real changes.
        const originalValue = baseModelContent.textContent.trim();
        let valueChanged = false;

        const dropdown = document.createElement('select');
        dropdown.className = 'base-model-selector';

        // Option groups sourced from the BASE_MODELS constants.
        const baseModelCategories = {
            'Stable Diffusion 1.x': [BASE_MODELS.SD_1_4, BASE_MODELS.SD_1_5, BASE_MODELS.SD_1_5_LCM, BASE_MODELS.SD_1_5_HYPER],
            'Stable Diffusion 2.x': [BASE_MODELS.SD_2_0, BASE_MODELS.SD_2_1],
            'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
            'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
            'Video Models': [BASE_MODELS.SVD, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO],
            'Other Models': [
                BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW,
                BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
                BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
                BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.UNKNOWN
            ]
        };

        for (const [category, models] of Object.entries(baseModelCategories)) {
            const group = document.createElement('optgroup');
            group.label = category;
            for (const model of models) {
                const option = document.createElement('option');
                option.value = model;
                option.textContent = model;
                option.selected = model === originalValue;
                group.appendChild(option);
            }
            dropdown.appendChild(group);
        }

        // Swap the static text for the dropdown and focus it.
        baseModelContent.style.display = 'none';
        baseModelDisplay.insertBefore(dropdown, editBtn);
        editBtn.style.display = 'none';
        dropdown.focus();

        dropdown.addEventListener('change', function() {
            baseModelContent.textContent = this.value;
            valueChanged = this.value !== originalValue;
        });

        // Tear down the editor and persist the value if it changed.
        const saveAndExit = () => {
            if (dropdown && dropdown.parentNode === baseModelDisplay) {
                baseModelDisplay.removeChild(dropdown);
            }

            baseModelContent.style.display = '';
            editBtn.style.display = '';
            baseModelDisplay.classList.remove('editing');

            if (valueChanged || baseModelContent.textContent.trim() !== originalValue) {
                const modalContent = document.querySelector('#checkpointModal .modal-content');
                const filePath = modalContent.querySelector('.file-path').textContent +
                    modalContent.querySelector('#file-name').textContent;
                saveBaseModel(filePath, originalValue);
            }

            document.removeEventListener('click', outsideClickHandler);
        };

        const outsideClickHandler = (e) => {
            if (!baseModelDisplay.contains(e.target)) {
                saveAndExit();
            }
        };

        // Register after the current click finishes so it does not
        // immediately close the editor it just opened.
        setTimeout(() => {
            document.addEventListener('click', outsideClickHandler);
        }, 0);

        dropdown.addEventListener('blur', (e) => {
            // Ignore focus moving to elements inside the display itself.
            if (!baseModelDisplay.contains(e.relatedTarget)) {
                saveAndExit();
            }
        });
    });
}

/**
 * Persist an edited base model value.
 * @param {string} filePath - Full path of the checkpoint file
 * @param {string} originalValue - Pre-edit value, used to skip no-op saves
 */
async function saveBaseModel(filePath, originalValue) {
    const newBaseModel = document.querySelector('.base-model-content').textContent.trim();

    if (newBaseModel === originalValue) {
        return; // nothing changed
    }

    try {
        await saveModelMetadata(filePath, { base_model: newBaseModel });

        const card = document.querySelector(`.checkpoint-card[data-filepath="${filePath}"]`);
        if (card) {
            card.dataset.base_model = newBaseModel;
        }

        showToast('Base model updated successfully', 'success');
    } catch (error) {
        showToast('Failed to update base model', 'error');
    }
}

/**
 * Wire up inline renaming of the checkpoint file. The field becomes
 * contenteditable while editing; invalid filename characters are stripped
 * as they are typed, Enter saves, Escape cancels, and a successful rename
 * triggers a page reload.
 */
export function setupFileNameEditing() {
    const fileNameContent = document.querySelector('.file-name-content');
    const editBtn = document.querySelector('.edit-file-name-btn');

    if (!fileNameContent || !editBtn) return;

    const fileNameWrapper = document.querySelector('.file-name-wrapper');

    // Leave edit mode and restore the idle UI state.
    function exitEditMode() {
        fileNameContent.removeAttribute('contenteditable');
        fileNameWrapper.classList.remove('editing');
        editBtn.classList.remove('visible');
    }

    fileNameWrapper.addEventListener('mouseenter', () => {
        editBtn.classList.add('visible');
    });
    fileNameWrapper.addEventListener('mouseleave', () => {
        if (!fileNameWrapper.classList.contains('editing')) {
            editBtn.classList.remove('visible');
        }
    });

    editBtn.addEventListener('click', () => {
        fileNameWrapper.classList.add('editing');
        fileNameContent.setAttribute('contenteditable', 'true');
        fileNameContent.focus();

        // Keep the pre-edit value for cancel / no-op detection.
        fileNameContent.dataset.originalValue = fileNameContent.textContent.trim();

        // Move the caret to the end of the text.
        const range = document.createRange();
        const sel = window.getSelection();
        range.selectNodeContents(fileNameContent);
        range.collapse(false);
        sel.removeAllRanges();
        sel.addRange(range);

        editBtn.classList.add('visible');
    });

    fileNameContent.addEventListener('keydown', function(e) {
        if (!this.getAttribute('contenteditable')) return;

        if (e.key === 'Enter') {
            e.preventDefault();
            this.blur(); // the blur handler performs the save
        } else if (e.key === 'Escape') {
            e.preventDefault();
            this.textContent = this.dataset.originalValue;
            exitEditMode();
        }
    });

    fileNameContent.addEventListener('input', function() {
        if (!this.getAttribute('contenteditable')) return;

        // Strip characters that are not legal in filenames.
        const invalidChars = /[\\/:*?"<>|]/g;
        if (invalidChars.test(this.textContent)) {
            const cursorPos = window.getSelection().getRangeAt(0).startOffset;
            this.textContent = this.textContent.replace(invalidChars, '');

            // Put the caret back where it was (clamped to the new length).
            const newPos = Math.min(cursorPos, this.textContent.length);
            if (this.firstChild) {
                const range = document.createRange();
                const sel = window.getSelection();
                range.setStart(this.firstChild, newPos);
                range.collapse(true);
                sel.removeAllRanges();
                sel.addRange(range);
            }

            showToast('Invalid characters removed from filename', 'warning');
        }
    });

    fileNameContent.addEventListener('blur', async function() {
        if (!this.getAttribute('contenteditable')) return;

        const newFileName = this.textContent.trim();
        const originalValue = this.dataset.originalValue;

        if (!newFileName) {
            this.textContent = originalValue;
            showToast('File name cannot be empty', 'error');
            exitEditMode();
            return;
        }

        if (newFileName === originalValue) {
            exitEditMode();
            return;
        }

        try {
            const filePath = document.querySelector('#checkpointModal .modal-content')
                .querySelector('.file-path').textContent + originalValue;

            const response = await fetch('/api/rename_checkpoint', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                },
                body: JSON.stringify({
                    file_path: filePath,
                    new_file_name: newFileName
                })
            });

            const result = await response.json();

            if (result.success) {
                showToast('File name updated successfully', 'success');

                const checkpointCard = document.querySelector(`.checkpoint-card[data-filepath="${filePath}"]`);
                if (checkpointCard) {
                    checkpointCard.dataset.filepath = filePath.replace(originalValue, newFileName);
                }

                // Reload shortly after so the listing reflects the rename.
                setTimeout(() => {
                    window.location.reload();
                }, 1500);
            } else {
                throw new Error(result.error || 'Unknown error');
            }
        } catch (error) {
            console.error('Error renaming file:', error);
            this.textContent = originalValue;
            showToast(`Failed to rename file: ${error.message}`, 'error');
        } finally {
            exitEditMode();
        }
    });
}
'../../utils/constants.js'; + +/** + * Render showcase content + * @param {Array} images - Array of images/videos to show + * @returns {string} HTML content + */ +export function renderShowcaseContent(images) { + if (!images?.length) return '
No example images available
'; + + // Filter images based on SFW setting + const showOnlySFW = state.settings.show_only_sfw; + let filteredImages = images; + let hiddenCount = 0; + + if (showOnlySFW) { + filteredImages = images.filter(img => { + const nsfwLevel = img.nsfwLevel !== undefined ? img.nsfwLevel : 0; + const isSfw = nsfwLevel < NSFW_LEVELS.R; + if (!isSfw) hiddenCount++; + return isSfw; + }); + } + + // Show message if no images are available after filtering + if (filteredImages.length === 0) { + return ` +
+

All example images are filtered due to NSFW content settings

+

Your settings are currently set to show only safe-for-work content

+

You can change this in Settings

+
+ `; + } + + // Show hidden content notification if applicable + const hiddenNotification = hiddenCount > 0 ? + `
+ ${hiddenCount} ${hiddenCount === 1 ? 'image' : 'images'} hidden due to SFW-only setting +
` : ''; + + return ` +
+ + Scroll or click to show ${filteredImages.length} examples +
/**
 * Build the HTML wrapper for a single showcase image or video.
 * @param {Object} media - Media object with image or video data
 * @returns {string} HTML content
 */
function generateMediaWrapper(media) {
    // Height as a percentage of the container width:
    //  - keep the media's native aspect ratio,
    //  - cap at ~60% of the viewport height,
    //  - never shrink below 40% of the container width.
    const aspectRatio = (media.height / media.width) * 100;
    const containerWidth = 800; // modal content maximum width
    const minHeightPercent = 40;
    const maxHeightPercent = (window.innerHeight * 0.6 / containerWidth) * 100;
    const heightPercent = Math.max(
        minHeightPercent,
        Math.min(maxHeightPercent, aspectRatio)
    );

    // Blur anything above PG-13 when the setting is enabled.
    const nsfwLevel = media.nsfwLevel !== undefined ? media.nsfwLevel : 0;
    const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13;

    // Warning label for the highest matching NSFW tier.
    let nsfwText = "Mature Content";
    if (nsfwLevel >= NSFW_LEVELS.XXX) {
        nsfwText = "XXX-rated Content";
    } else if (nsfwLevel >= NSFW_LEVELS.X) {
        nsfwText = "X-rated Content";
    } else if (nsfwLevel >= NSFW_LEVELS.R) {
        nsfwText = "R-rated Content";
    }

    // Pull generation parameters out of the media metadata (both
    // snake_case and camelCase variants of the negative prompt occur).
    const meta = media.meta || {};
    const prompt = meta.prompt || '';
    const negativePrompt = meta.negative_prompt || meta.negativePrompt || '';
    const size = meta.Size || `${media.width}x${media.height}`;
    const seed = meta.seed || '';
    const model = meta.Model || '';
    const steps = meta.steps || '';
    const sampler = meta.sampler || '';
    const cfgScale = meta.cfgScale || '';
    const clipSkip = meta.clipSkip || '';

    const hasParams = seed || model || steps || sampler || cfgScale || clipSkip;
    const hasPrompts = prompt || negativePrompt;

    const metadataPanel = generateMetadataPanel(
        hasParams, hasPrompts,
        prompt, negativePrompt,
        size, seed, model, steps, sampler, cfgScale, clipSkip
    );

    return media.type === 'video'
        ? generateVideoWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel)
        : generateImageWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel);
}

/**
 * Generate metadata panel HTML.
 *
 * NOTE(review): in this revision the panel markup is never assembled — the
 * function always returns an empty string, leaving the generated prompt-copy
 * IDs and every parameter unused. This looks like a stub from an in-progress
 * commit; confirm against a later revision before relying on it.
 */
function generateMetadataPanel(hasParams, hasPrompts, prompt, negativePrompt, size, seed, model, steps, sampler, cfgScale, clipSkip) {
    // Unique IDs intended for the prompt copy buttons.
    const promptIndex = Math.random().toString(36).substring(2, 15);
    const negPromptIndex = Math.random().toString(36).substring(2, 15);

    let content = '';
    return content;
}
+ ${shouldBlur ? ` + + ` : ''} + + ${shouldBlur ? ` +
+
+

${nsfwText}

+ +
+
+ ` : ''} + ${metadataPanel} +
+ `; +} + +/** + * Generate image wrapper HTML + */ +function generateImageWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel) { + return ` +
+ ${shouldBlur ? ` + + ` : ''} + Preview + ${shouldBlur ? ` +
+
+

${nsfwText}

+ +
+
+ ` : ''} + ${metadataPanel} +
/**
 * Toggle showcase expansion: expanding lazily loads media and wires up the
 * NSFW-blur and metadata-panel handlers; collapsing restores the indicator.
 * @param {HTMLElement} element - The scroll-indicator element that was clicked
 */
export function toggleShowcase(element) {
    const carousel = element.nextElementSibling;
    const isCollapsed = carousel.classList.contains('collapsed');
    const indicator = element.querySelector('span');
    const icon = element.querySelector('i');

    carousel.classList.toggle('collapsed');

    if (isCollapsed) {
        indicator.textContent = `Scroll or click to hide examples`;
        icon.classList.replace('fa-chevron-down', 'fa-chevron-up');
        initLazyLoading(carousel);

        // Initialize NSFW content blur toggle handlers
        initNsfwBlurHandlers(carousel);

        // Initialize metadata panel interaction handlers
        initMetadataPanelHandlers(carousel);
    } else {
        const count = carousel.querySelectorAll('.media-wrapper').length;
        indicator.textContent = `Scroll or click to show ${count} examples`;
        icon.classList.replace('fa-chevron-up', 'fa-chevron-down');
    }
}

/**
 * Wire up interaction handlers for each media item's metadata panel:
 * stop click/wheel events from leaking to the modal and hook up the
 * copy-prompt buttons.
 * @param {HTMLElement} container - The carousel element
 */
function initMetadataPanelHandlers(container) {
    const mediaWrappers = container.querySelectorAll('.media-wrapper');

    mediaWrappers.forEach(wrapper => {
        const metadataPanel = wrapper.querySelector('.image-metadata-panel');
        if (!metadataPanel) return;

        // Prevent events from bubbling up to the wrapper/modal.
        metadataPanel.addEventListener('click', (e) => {
            e.stopPropagation();
        });

        // Copy the associated prompt text to the clipboard.
        const copyBtns = metadataPanel.querySelectorAll('.copy-prompt-btn');
        copyBtns.forEach(copyBtn => {
            const promptIndex = copyBtn.dataset.promptIndex;
            const promptElement = wrapper.querySelector(`#prompt-${promptIndex}`);

            copyBtn.addEventListener('click', async (e) => {
                e.stopPropagation();

                if (!promptElement) return;

                try {
                    await navigator.clipboard.writeText(promptElement.textContent);
                    showToast('Prompt copied to clipboard', 'success');
                } catch (err) {
                    console.error('Copy failed:', err);
                    showToast('Copy failed', 'error');
                }
            });
        });

        // Prevent panel scroll from causing modal scroll.
        metadataPanel.addEventListener('wheel', (e) => {
            e.stopPropagation();
        });
    });
}

/**
 * Wire up the blur/unblur controls on NSFW media items.
 * @param {HTMLElement} container - The carousel element
 */
function initNsfwBlurHandlers(container) {
    // Eye icon toggles the blur and the warning overlay together.
    const toggleButtons = container.querySelectorAll('.toggle-blur-btn');
    toggleButtons.forEach(btn => {
        btn.addEventListener('click', (e) => {
            e.stopPropagation();
            const wrapper = btn.closest('.media-wrapper');
            const media = wrapper.querySelector('img, video');
            const isBlurred = media.classList.toggle('blurred');
            const icon = btn.querySelector('i');

            if (isBlurred) {
                icon.className = 'fas fa-eye';
            } else {
                icon.className = 'fas fa-eye-slash';
            }

            const overlay = wrapper.querySelector('.nsfw-overlay');
            if (overlay) {
                overlay.style.display = isBlurred ? 'flex' : 'none';
            }
        });
    });

    // "Show" buttons inside the overlays unblur unconditionally.
    const showButtons = container.querySelectorAll('.show-content-btn');
    showButtons.forEach(btn => {
        btn.addEventListener('click', (e) => {
            e.stopPropagation();
            const wrapper = btn.closest('.media-wrapper');
            const media = wrapper.querySelector('img, video');
            media.classList.remove('blurred');

            const toggleBtn = wrapper.querySelector('.toggle-blur-btn');
            if (toggleBtn) {
                toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
            }

            const overlay = wrapper.querySelector('.nsfw-overlay');
            if (overlay) {
                overlay.style.display = 'none';
            }
        });
    });
}

/**
 * Lazy-load images and videos marked with the `.lazy` class once they
 * scroll into view.
 * @param {HTMLElement} container - The carousel element
 */
function initLazyLoading(container) {
    const lazyElements = container.querySelectorAll('.lazy');

    const lazyLoad = (element) => {
        if (element.tagName.toLowerCase() === 'video') {
            element.src = element.dataset.src;
            // FIX: a <video> may carry its URL on the element itself and have
            // no <source> child — the original dereferenced querySelector's
            // result unconditionally and threw a TypeError in that case.
            const sourceEl = element.querySelector('source');
            if (sourceEl) {
                sourceEl.src = element.dataset.src;
            }
            element.load();
        } else {
            element.src = element.dataset.src;
        }
        element.classList.remove('lazy');
    };

    const observer = new IntersectionObserver((entries) => {
        entries.forEach(entry => {
            if (entry.isIntersecting) {
                lazyLoad(entry.target);
                observer.unobserve(entry.target);
            }
        });
    });

    lazyElements.forEach(element => observer.observe(element));
}

/**
 * Set up showcase scroll functionality: wheel-to-expand near the bottom of
 * the modal, plus back-to-top button management.
 */
export function setupShowcaseScroll() {
    // Expand the collapsed showcase when the user scrolls down near the
    // bottom of the modal content.
    document.addEventListener('wheel', (event) => {
        const modalContent = document.querySelector('#checkpointModal .modal-content');
        if (!modalContent) return;

        const showcase = modalContent.querySelector('.showcase-section');
        if (!showcase) return;

        const carousel = showcase.querySelector('.carousel');
        const scrollIndicator = showcase.querySelector('.scroll-indicator');

        if (carousel?.classList.contains('collapsed') && event.deltaY > 0) {
            const isNearBottom = modalContent.scrollHeight - modalContent.scrollTop - modalContent.clientHeight < 100;

            if (isNearBottom) {
                toggleShowcase(scrollIndicator);
                event.preventDefault();
            }
        }
    }, { passive: false });

    // Re-attach the back-to-top button whenever modal content is (re)added.
    const observer = new MutationObserver((mutations) => {
        for (const mutation of mutations) {
            if (mutation.type === 'childList' && mutation.addedNodes.length) {
                const checkpointModal = document.getElementById('checkpointModal');
                if (checkpointModal && checkpointModal.querySelector('.modal-content')) {
                    setupBackToTopButton(checkpointModal.querySelector('.modal-content'));
                }
            }
        }
    });

    observer.observe(document.body, { childList: true, subtree: true });

    // Also try immediately, in case the modal is already open.
    const modalContent = document.querySelector('#checkpointModal .modal-content');
    if (modalContent) {
        setupBackToTopButton(modalContent);
    }
}

/**
 * Show/hide the back-to-top button based on the modal's scroll position.
 *
 * FIX: the original cleared `onscroll` but then registered through
 * addEventListener, so every MutationObserver callback stacked another
 * duplicate listener. Assigning the handler via `.onscroll` makes repeated
 * calls idempotent, matching the stated intent.
 * @param {HTMLElement} modalContent - The modal's scrollable content element
 */
function setupBackToTopButton(modalContent) {
    modalContent.onscroll = () => {
        const backToTopBtn = modalContent.querySelector('.back-to-top');
        if (backToTopBtn) {
            if (modalContent.scrollTop > 300) {
                backToTopBtn.classList.add('visible');
            } else {
                backToTopBtn.classList.remove('visible');
            }
        }
    };

    // Trigger a scroll event to check the initial position.
    modalContent.dispatchEvent(new Event('scroll'));
}

/**
 * Smooth-scroll the enclosing modal content back to the top.
 * @param {HTMLElement} button - The back-to-top button that was clicked
 */
export function scrollToTop(button) {
    const modalContent = button.closest('.modal-content');
    if (modalContent) {
        modalContent.scrollTo({
            top: 0,
            behavior: 'smooth'
        });
    }
}
/**
 * Display the checkpoint modal with the given checkpoint data.
 *
 * NOTE(review): the modal body template appears to have been stripped from
 * this revision of the file — `content` is effectively empty here. Confirm
 * the markup against the repository before relying on this function.
 * @param {Object} checkpoint - Checkpoint data object
 */
export function showCheckpointModal(checkpoint) {
    const content = `
    `;

    modalManager.showModal('checkpointModal', content);
    setupEditableFields();
    setupShowcaseScroll();
    setupTabSwitching();
    setupTagTooltip();
    setupModelNameEditing();
    setupBaseModelEditing();
    setupFileNameEditing();

    // Fetch the description lazily when we know the model but not its text.
    if (checkpoint.civitai?.modelId && !checkpoint.modelDescription) {
        loadModelDescription(checkpoint.civitai.modelId, checkpoint.file_path);
    }
}

/**
 * Set up editable fields in the checkpoint modal: placeholder handling for
 * the notes area and Enter-to-save behavior.
 */
function setupEditableFields() {
    const editableFields = document.querySelectorAll('.editable-field [contenteditable]');

    editableFields.forEach(field => {
        // Clear the placeholder on focus…
        field.addEventListener('focus', function() {
            if (this.textContent === 'Add your notes here...') {
                this.textContent = '';
            }
        });

        // …and restore it when the notes are left empty.
        field.addEventListener('blur', function() {
            if (this.textContent.trim() === '') {
                if (this.classList.contains('notes-content')) {
                    this.textContent = 'Add your notes here...';
                }
            }
        });
    });

    const notesContent = document.querySelector('.notes-content');
    if (notesContent) {
        notesContent.addEventListener('keydown', async function(e) {
            if (e.key === 'Enter') {
                // Shift+Enter inserts a newline instead of saving.
                if (e.shiftKey) {
                    return;
                }
                e.preventDefault();
                const modalContent = document.querySelector('#checkpointModal .modal-content');
                const filePath = modalContent.querySelector('.file-path').textContent +
                    modalContent.querySelector('#file-name').textContent;
                await saveNotes(filePath);
            }
        });
    }
}

/**
 * Save checkpoint notes.
 * @param {string} filePath - Path to the checkpoint file
 */
async function saveNotes(filePath) {
    const content = document.querySelector('.notes-content').textContent;
    try {
        await saveModelMetadata(filePath, { notes: content });

        // Mirror the change onto the matching checkpoint card.
        const checkpointCard = document.querySelector(`.checkpoint-card[data-filepath="${filePath}"]`);
        if (checkpointCard) {
            checkpointCard.dataset.notes = content;
        }

        showToast('Notes saved successfully', 'success');
    } catch (error) {
        showToast('Failed to save notes', 'error');
    }
}

// Public API of the checkpoint modal component.
const checkpointModal = {
    show: showCheckpointModal,
    toggleShowcase,
    scrollToTop
};

export { checkpointModal };

// Global hooks used by inline HTML attributes.
window.toggleShowcase = function(element) {
    toggleShowcase(element);
};

window.scrollToTopCheckpoint = function(button) {
    scrollToTop(button);
};

window.saveCheckpointNotes = function(filePath) {
    saveNotes(filePath);
};

/**
 * Format a byte count for display, scaling to the largest unit below 1024.
 * @param {number} bytes - File size in bytes
 * @returns {string} Formatted file size, or 'N/A' for falsy input
 */
export function formatFileSize(bytes) {
    if (!bytes) return 'N/A';

    const units = ['B', 'KB', 'MB', 'GB', 'TB'];
    let size = bytes;
    let unitIndex = 0;

    while (size >= 1024 && unitIndex < units.length - 1) {
        size /= 1024;
        unitIndex++;
    }

    return `${size.toFixed(1)} ${units[unitIndex]}`;
}
remainingCount = Math.max(0, tags.length - 5); + + return ` +
+
+ ${visibleTags.map(tag => `${tag}`).join('')} + ${remainingCount > 0 ? + `+${remainingCount}` : + ''} +
+ ${tags.length > 0 ? + `
+
+ ${tags.map(tag => `${tag}`).join('')} +
+
` : + ''} +
/**
 * Show the full tag tooltip while the compact tag list is hovered.
 */
export function setupTagTooltip() {
    const tagsContainer = document.querySelector('.model-tags-container');
    const tooltip = document.querySelector('.model-tags-tooltip');

    // Both elements must exist; the tooltip is only rendered when tags do.
    if (!tagsContainer || !tooltip) return;

    tagsContainer.addEventListener('mouseenter', () => {
        tooltip.classList.add('visible');
    });

    tagsContainer.addEventListener('mouseleave', () => {
        tooltip.classList.remove('visible');
    });
}
scanner cache if metadata: - await self.scanner.update_single_lora_cache(file_path, new_file_path, metadata) + await self.scanner.update_single_model_cache(file_path, new_file_path, metadata) # Update recipe files and cache if hash is available if hash_value: diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index c4e10c3d..fee5f30b 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -315,7 +315,7 @@ class LoraScanner(ModelScanner): break # Update cache - await self.update_single_lora_cache(source_path, target_lora, metadata) + await self.update_single_model_cache(source_path, target_lora, metadata) return True @@ -323,62 +323,6 @@ class LoraScanner(ModelScanner): logger.error(f"Error moving model: {e}", exc_info=True) return False - async def update_single_lora_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool: - cache = await self.get_cached_data() - - # Find the existing item to remove its tags from count - existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None) - if existing_item and 'tags' in existing_item: - for tag in existing_item.get('tags', []): - if tag in self._tags_count: - self._tags_count[tag] = max(0, self._tags_count[tag] - 1) - if self._tags_count[tag] == 0: - del self._tags_count[tag] - - # Remove old path from hash index if exists - self._hash_index.remove_by_path(original_path) - - # Remove the old entry from raw_data - cache.raw_data = [ - item for item in cache.raw_data - if item['file_path'] != original_path - ] - - if metadata: - # If this is an update to an existing path (not a move), ensure folder is preserved - if original_path == new_path: - # Find the folder from existing entries or calculate it - existing_folder = next((item['folder'] for item in cache.raw_data - if item['file_path'] == original_path), None) - if existing_folder: - metadata['folder'] = existing_folder - else: - metadata['folder'] = 
self._calculate_folder(new_path) - else: - # For moved files, recalculate the folder - metadata['folder'] = self._calculate_folder(new_path) - - # Add the updated metadata to raw_data - cache.raw_data.append(metadata) - - # Update hash index with new path - if 'sha256' in metadata: - self._hash_index.add_entry(metadata['sha256'].lower(), new_path) - - # Update folders list - all_folders = set(item['folder'] for item in cache.raw_data) - cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) - - # Update tags count with the new/updated tags - if 'tags' in metadata: - for tag in metadata.get('tags', []): - self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 - - # Resort cache - await cache.resort() - - return True - async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict: """Update file paths in metadata file""" try: From 18aa8d11adaa13f5172daea883584b8c6cc963f5 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 11 Apr 2025 05:59:32 +0800 Subject: [PATCH 13/36] refactor: Remove showToast call from clearCustomFilter method in LorasControls --- py/services/lora_scanner.py | 20 ------------------- .../js/components/controls/LorasControls.js | 1 - 2 files changed, 21 deletions(-) diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index fee5f30b..654a8ac1 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -362,26 +362,6 @@ class LoraScanner(ModelScanner): """Get hash for a LoRA by its file path""" return self._hash_index.get_hash(file_path) - def get_preview_url_by_hash(self, sha256: str) -> Optional[str]: - """Get preview static URL for a LoRA by its hash""" - # Get the file path first - file_path = self._hash_index.get_path(sha256.lower()) - if not file_path: - return None - - # Determine the preview file path (typically same name with different extension) - base_name = os.path.splitext(file_path)[0] - preview_extensions = ['.preview.png', '.preview.jpeg', 
'.preview.jpg', '.preview.mp4', - '.png', '.jpeg', '.jpg', '.mp4'] - - for ext in preview_extensions: - preview_path = f"{base_name}{ext}" - if os.path.exists(preview_path): - # Convert to static URL using config - return config.get_preview_static_url(preview_path) - - return None - async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: """Get top tags sorted by count""" # Make sure cache is initialized diff --git a/static/js/components/controls/LorasControls.js b/static/js/components/controls/LorasControls.js index 4d0cc9eb..0a45a9e6 100644 --- a/static/js/components/controls/LorasControls.js +++ b/static/js/components/controls/LorasControls.js @@ -132,7 +132,6 @@ export class LorasControls extends PageControls { // Reload the loras await resetAndReload(); - showToast('Filter cleared', 'info'); } /** From 86810d9f03b0853d58ffdfa59d41b26139a7769b Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 11 Apr 2025 06:05:19 +0800 Subject: [PATCH 14/36] refactor: Remove move_model method from LoraScanner class to streamline code --- py/services/lora_scanner.py | 60 ------------------------------------ py/services/model_scanner.py | 3 +- 2 files changed, 2 insertions(+), 61 deletions(-) diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index 654a8ac1..a57c44e0 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -263,66 +263,6 @@ class LoraScanner(ModelScanner): return result - async def move_model(self, source_path: str, target_path: str) -> bool: - """Move a model and its associated files to a new location""" - try: - # Keep original path format - source_path = source_path.replace(os.sep, '/') - target_path = target_path.replace(os.sep, '/') - - # Rest of the code remains unchanged - base_name = os.path.splitext(os.path.basename(source_path))[0] - source_dir = os.path.dirname(source_path) - - os.makedirs(target_path, exist_ok=True) - - target_lora = os.path.join(target_path, 
f"{base_name}.safetensors").replace(os.sep, '/') - - # Use real paths for file operations - real_source = os.path.realpath(source_path) - real_target = os.path.realpath(target_lora) - - file_size = os.path.getsize(real_source) - - if self.file_monitor: - self.file_monitor.handler.add_ignore_path( - real_source, - file_size - ) - self.file_monitor.handler.add_ignore_path( - real_target, - file_size - ) - - # Use real paths for file operations - shutil.move(real_source, real_target) - - # Move associated files - source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json") - if os.path.exists(source_metadata): - target_metadata = os.path.join(target_path, f"{base_name}.metadata.json") - shutil.move(source_metadata, target_metadata) - metadata = await self._update_metadata_paths(target_metadata, target_lora) - - # Move preview file if exists - preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', - '.png', '.jpeg', '.jpg', '.mp4'] - for ext in preview_extensions: - source_preview = os.path.join(source_dir, f"{base_name}{ext}") - if os.path.exists(source_preview): - target_preview = os.path.join(target_path, f"{base_name}{ext}") - shutil.move(source_preview, target_preview) - break - - # Update cache - await self.update_single_model_cache(source_path, target_lora, metadata) - - return True - - except Exception as e: - logger.error(f"Error moving model: {e}", exc_info=True) - return False - async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict: """Update file paths in metadata file""" try: diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index a4b83f6f..97c41f57 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -482,7 +482,8 @@ class ModelScanner: def get_hash_by_path(self, file_path: str) -> Optional[str]: """Get hash for a model by its file path""" return self._hash_index.get_hash(file_path) - + + # TODO: Adjust this method to use metadata 
instead of finding the file def get_preview_url_by_hash(self, sha256: str) -> Optional[str]: """Get preview static URL for a model by its hash""" file_path = self._hash_index.get_path(sha256.lower()) From 7393e92b21b56aed4db1a793f39850a2d1fbfca2 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 11 Apr 2025 06:19:15 +0800 Subject: [PATCH 15/36] refactor: Consolidate preview file extensions into constants for improved maintainability --- py/routes/api_routes.py | 29 +++++++++-------------------- py/services/model_scanner.py | 9 +++------ py/utils/constants.py | 16 +++++++++++++++- py/utils/file_utils.py | 15 +++------------ 4 files changed, 30 insertions(+), 39 deletions(-) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index af5aef44..76ca11a2 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -17,6 +17,7 @@ from ..services.settings_manager import settings import asyncio from .update_routes import UpdateRoutes from ..services.recipe_scanner import RecipeScanner +from ..utils.constants import PREVIEW_EXTENSIONS logger = logging.getLogger(__name__) @@ -244,18 +245,12 @@ class ApiRoutes: patterns = [ f"{file_name}.safetensors", # Required f"{file_name}.metadata.json", - f"{file_name}.preview.png", - f"{file_name}.preview.jpg", - f"{file_name}.preview.jpeg", - f"{file_name}.preview.webp", - f"{file_name}.preview.mp4", - f"{file_name}.png", - f"{file_name}.jpg", - f"{file_name}.jpeg", - f"{file_name}.webp", - f"{file_name}.mp4" ] + # 添加所有预览文件扩展名 + for ext in PREVIEW_EXTENSIONS: + patterns.append(f"{file_name}{ext}") + deleted = [] main_file = patterns[0] main_path = os.path.join(target_dir, main_file).replace(os.sep, '/') @@ -1054,18 +1049,12 @@ class ApiRoutes: patterns = [ f"{old_file_name}.safetensors", # Required f"{old_file_name}.metadata.json", - f"{old_file_name}.preview.png", - f"{old_file_name}.preview.jpg", - f"{old_file_name}.preview.jpeg", - f"{old_file_name}.preview.webp", - 
f"{old_file_name}.preview.mp4", - f"{old_file_name}.png", - f"{old_file_name}.jpg", - f"{old_file_name}.jpeg", - f"{old_file_name}.webp", - f"{old_file_name}.mp4" ] + # 添加所有预览文件扩展名 + for ext in PREVIEW_EXTENSIONS: + patterns.append(f"{old_file_name}{ext}") + # Find all matching files existing_files = [] for pattern in patterns: diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 97c41f57..d9134b2a 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -11,6 +11,7 @@ from ..config import config from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, save_metadata from .model_cache import ModelCache from .model_hash_index import ModelHashIndex +from ..utils.constants import PREVIEW_EXTENSIONS logger = logging.getLogger(__name__) @@ -384,9 +385,7 @@ class ModelScanner: shutil.move(source_metadata, target_metadata) metadata = await self._update_metadata_paths(target_metadata, target_file) - preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', - '.png', '.jpeg', '.jpg', '.mp4'] - for ext in preview_extensions: + for ext in PREVIEW_EXTENSIONS: source_preview = os.path.join(source_dir, f"{base_name}{ext}") if os.path.exists(source_preview): target_preview = os.path.join(target_path, f"{base_name}{ext}") @@ -491,10 +490,8 @@ class ModelScanner: return None base_name = os.path.splitext(file_path)[0] - preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', - '.png', '.jpeg', '.jpg', '.mp4'] - for ext in preview_extensions: + for ext in PREVIEW_EXTENSIONS: preview_path = f"{base_name}{ext}" if os.path.exists(preview_path): return config.get_preview_static_url(preview_path) diff --git a/py/utils/constants.py b/py/utils/constants.py index 69a96ca2..ac4e68e3 100644 --- a/py/utils/constants.py +++ b/py/utils/constants.py @@ -5,4 +5,18 @@ NSFW_LEVELS = { "X": 8, "XXX": 16, "Blocked": 32, # Probably not actually visible through the API 
without being logged in on model owner account? -} \ No newline at end of file +} + +# 预览文件扩展名 +PREVIEW_EXTENSIONS = [ + '.preview.png', + '.preview.jpeg', + '.preview.jpg', + '.preview.webp', + '.preview.mp4', + '.png', + '.jpeg', + '.jpg', + '.webp', + '.mp4' +] \ No newline at end of file diff --git a/py/utils/file_utils.py b/py/utils/file_utils.py index 859e86ae..058469d6 100644 --- a/py/utils/file_utils.py +++ b/py/utils/file_utils.py @@ -8,6 +8,7 @@ from typing import Dict, Optional, Type from .model_utils import determine_base_model from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata +from .constants import PREVIEW_EXTENSIONS logger = logging.getLogger(__name__) @@ -21,19 +22,9 @@ async def calculate_sha256(file_path: str) -> str: def find_preview_file(base_name: str, dir_path: str) -> str: """Find preview file for given base name in directory""" - preview_patterns = [ - f"{base_name}.preview.png", - f"{base_name}.preview.jpg", - f"{base_name}.preview.jpeg", - f"{base_name}.preview.mp4", - f"{base_name}.png", - f"{base_name}.jpg", - f"{base_name}.jpeg", - f"{base_name}.mp4" - ] - for pattern in preview_patterns: - full_pattern = os.path.join(dir_path, pattern) + for ext in PREVIEW_EXTENSIONS: + full_pattern = os.path.join(dir_path, f"{base_name}{ext}") if os.path.exists(full_pattern): return full_pattern.replace(os.sep, "/") return "" From ac244e6ad99002922ef3103d8bd26dc32606be1d Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 11 Apr 2025 08:19:19 +0800 Subject: [PATCH 16/36] refactor: Replace hardcoded image width with CARD_PREVIEW_WIDTH constant for consistency --- py/routes/recipe_routes.py | 5 +++-- py/utils/constants.py | 11 +++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 796df6a1..bffc208e 100644 --- a/py/routes/recipe_routes.py +++ 
b/py/routes/recipe_routes.py @@ -9,6 +9,7 @@ import asyncio from ..utils.exif_utils import ExifUtils from ..utils.recipe_parsers import RecipeParserFactory from ..services.civitai_client import CivitaiClient +from ..utils.constants import CARD_PREVIEW_WIDTH from ..services.recipe_scanner import RecipeScanner from ..services.lora_scanner import LoraScanner @@ -424,7 +425,7 @@ class RecipeRoutes: # Optimize the image (resize and convert to WebP) optimized_image, extension = ExifUtils.optimize_image( image_data=image, - target_width=480, + target_width=CARD_PREVIEW_WIDTH, format='webp', quality=85, preserve_metadata=True @@ -828,7 +829,7 @@ class RecipeRoutes: # Optimize the image (resize and convert to WebP) optimized_image, extension = ExifUtils.optimize_image( image_data=image, - target_width=480, + target_width=CARD_PREVIEW_WIDTH, format='webp', quality=85, preserve_metadata=True diff --git a/py/utils/constants.py b/py/utils/constants.py index ac4e68e3..0521c5d9 100644 --- a/py/utils/constants.py +++ b/py/utils/constants.py @@ -7,16 +7,19 @@ NSFW_LEVELS = { "Blocked": 32, # Probably not actually visible through the API without being logged in on model owner account? 
} -# 预览文件扩展名 +# preview extensions PREVIEW_EXTENSIONS = [ + '.webp', + '.preview.webp', '.preview.png', '.preview.jpeg', '.preview.jpg', - '.preview.webp', '.preview.mp4', '.png', '.jpeg', '.jpg', - '.webp', '.mp4' -] \ No newline at end of file +] + +# Card preview image width +CARD_PREVIEW_WIDTH = 480 \ No newline at end of file From b0a5b48fb27aa0fdb801a109cff602855a0c7fc3 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 11 Apr 2025 08:43:21 +0800 Subject: [PATCH 17/36] refactor: Enhance preview file handling and add update_preview_in_cache method for ModelScanner --- py/routes/api_routes.py | 28 ++++++++++++++++++++-------- py/services/model_scanner.py | 17 ++++++++++++++++- static/js/components/LoraCard.js | 1 + 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 76ca11a2..a35bf2b2 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -17,7 +17,8 @@ from ..services.settings_manager import settings import asyncio from .update_routes import UpdateRoutes from ..services.recipe_scanner import RecipeScanner -from ..utils.constants import PREVIEW_EXTENSIONS +from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH +from ..utils.exif_utils import ExifUtils logger = logging.getLogger(__name__) @@ -302,18 +303,29 @@ class ApiRoutes: async def _save_preview_file(self, model_path: str, preview_data: bytes, content_type: str) -> str: """Save preview file and return its path""" - # Determine file extension based on content type - if content_type.startswith('video/'): - extension = '.preview.mp4' - else: - extension = '.preview.png' - base_name = os.path.splitext(os.path.basename(model_path))[0] folder = os.path.dirname(model_path) + + # Determine if content is video or image + if content_type.startswith('video/'): + # For videos, keep original format and use .mp4 extension + extension = '.mp4' + optimized_data = preview_data + else: + # For images, 
optimize and convert to WebP + optimized_data, _ = ExifUtils.optimize_image( + image_data=preview_data, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=True + ) + extension = '.webp' # Use .webp without .preview part + preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/') with open(preview_path, 'wb') as f: - f.write(preview_data) + f.write(optimized_data) return preview_path diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index d9134b2a..04d0cb2a 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -537,4 +537,19 @@ class ModelScanner: return None except Exception as e: logger.error(f"Error getting model info by name: {e}", exc_info=True) - return None \ No newline at end of file + return None + + async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool: + """Update preview URL in cache for a specific lora + + Args: + file_path: The file path of the lora to update + preview_url: The new preview URL + + Returns: + bool: True if the update was successful, False if cache doesn't exist or lora wasn't found + """ + if self._cache is None: + return False + + return await self._cache.update_preview_url(file_path, preview_url) \ No newline at end of file diff --git a/static/js/components/LoraCard.js b/static/js/components/LoraCard.js index c4baee23..2e0b9bd0 100644 --- a/static/js/components/LoraCard.js +++ b/static/js/components/LoraCard.js @@ -3,6 +3,7 @@ import { state } from '../state/index.js'; import { showLoraModal } from './loraModal/index.js'; import { bulkManager } from '../managers/BulkManager.js'; import { NSFW_LEVELS } from '../utils/constants.js'; +import { replacePreview } from '../api/loraApi.js' export function createLoraCard(lora) { const card = document.createElement('div'); From 297ff0dd25f25b6818df1faa29fe066ee2144900 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 11 Apr 2025 09:00:58 
+0800 Subject: [PATCH 18/36] refactor: Improve download handling for previews and optimize image conversion in DownloadManager --- py/services/download_manager.py | 60 ++++++++++++++++++++++++++++---- static/js/components/LoraCard.js | 2 +- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index df0d3af2..1dc2a945 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -5,6 +5,11 @@ from typing import Optional, Dict from .civitai_client import CivitaiClient from .file_monitor import LoraFileMonitor from ..utils.models import LoraMetadata +from ..utils.constants import CARD_PREVIEW_WIDTH +from ..utils.exif_utils import ExifUtils + +# Download to temporary file first +import tempfile logger = logging.getLogger(__name__) @@ -128,13 +133,54 @@ class DownloadManager: if progress_callback: await progress_callback(1) # 1% progress for starting preview download - preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png' - preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext - if await self.civitai_client.download_preview_image(images[0]['url'], preview_path): - metadata.preview_url = preview_path.replace(os.sep, '/') - metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) + # Check if it's a video or an image + is_video = images[0].get('type') == 'video' + + if is_video: + # For videos, use .mp4 extension + preview_ext = '.mp4' + preview_path = os.path.splitext(save_path)[0] + preview_ext + + # Download video directly + if await self.civitai_client.download_preview_image(images[0]['url'], preview_path): + metadata.preview_url = preview_path.replace(os.sep, '/') + metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata.to_dict(), f, 
indent=2, ensure_ascii=False) + else: + # For images, use WebP format for better performance + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file: + temp_path = temp_file.name + + # Download the original image to temp path + if await self.civitai_client.download_preview_image(images[0]['url'], temp_path): + # Optimize and convert to WebP + preview_path = os.path.splitext(save_path)[0] + '.webp' + + # Use ExifUtils to optimize and convert the image + optimized_data, _ = ExifUtils.optimize_image( + image_data=temp_path, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=True + ) + + # Save the optimized image + with open(preview_path, 'wb') as f: + f.write(optimized_data) + + # Update metadata + metadata.preview_url = preview_path.replace(os.sep, '/') + metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) + + # Remove temporary file + try: + os.unlink(temp_path) + except Exception as e: + logger.warning(f"Failed to delete temp file: {e}") # Report preview download completion if progress_callback: diff --git a/static/js/components/LoraCard.js b/static/js/components/LoraCard.js index 2e0b9bd0..e6534081 100644 --- a/static/js/components/LoraCard.js +++ b/static/js/components/LoraCard.js @@ -3,7 +3,7 @@ import { state } from '../state/index.js'; import { showLoraModal } from './loraModal/index.js'; import { bulkManager } from '../managers/BulkManager.js'; import { NSFW_LEVELS } from '../utils/constants.js'; -import { replacePreview } from '../api/loraApi.js' +import { replacePreview, deleteModel } from '../api/loraApi.js' export function createLoraCard(lora) { const card = document.createElement('div'); From 31d27ff3fa1567b6ace7df216a3bab3116ebe3e0 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 11 Apr 2025 10:54:19 +0800 Subject: [PATCH 19/36] refactor: 
Extract model-related utility functions into ModelRouteUtils for better code organization --- py/routes/api_routes.py | 296 +++++++------------------------ py/routes/checkpoints_routes.py | 119 +------------ py/utils/routes_common.py | 252 ++++++++++++++++++++++++++ static/js/components/LoraCard.js | 2 +- 4 files changed, 318 insertions(+), 351 deletions(-) create mode 100644 py/utils/routes_common.py diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index a35bf2b2..331cadb4 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -2,9 +2,9 @@ import os import json import logging from aiohttp import web -from typing import Dict, List +from typing import Dict -from ..utils.model_utils import determine_base_model +from ..utils.routes_common import ModelRouteUtils from ..services.file_monitor import LoraFileMonitor from ..services.download_manager import DownloadManager @@ -72,7 +72,19 @@ class ApiRoutes: target_dir = os.path.dirname(file_path) file_name = os.path.splitext(os.path.basename(file_path))[0] - deleted_files = await self._delete_model_files(target_dir, file_name) + deleted_files = await ModelRouteUtils.delete_model_files( + target_dir, + file_name, + self.download_manager.file_monitor + ) + + # Remove from cache + cache = await self.scanner.get_cached_data() + cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path] + await cache.resort() + + # update hash index + self.scanner._hash_index.remove_by_path(file_path) return web.json_response({ 'success': True, @@ -90,14 +102,18 @@ class ApiRoutes: metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json' # Check if model is from CivitAI - local_metadata = await self._load_local_metadata(metadata_path) + local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path) # Fetch and update metadata civitai_metadata = await self.civitai_client.get_model_by_hash(local_metadata["sha256"]) if not civitai_metadata: - return await 
self._handle_not_found_on_civitai(metadata_path, local_metadata) + await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata) + return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404) - await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client) + await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client) + + # Update the cache + await self.scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata) return web.json_response({"success": True}) @@ -139,10 +155,12 @@ class ApiRoutes: fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true' # Parse search options - search_filename = request.query.get('search_filename', 'true').lower() == 'true' - search_modelname = request.query.get('search_modelname', 'true').lower() == 'true' - search_tags = request.query.get('search_tags', 'false').lower() == 'true' - recursive = request.query.get('recursive', 'false').lower() == 'true' + search_options = { + 'filename': request.query.get('search_filename', 'true').lower() == 'true', + 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', + 'tags': request.query.get('search_tags', 'false').lower() == 'true', + 'recursive': request.query.get('recursive', 'false').lower() == 'true' + } # Get filter parameters base_models = request.query.get('base_models', None) @@ -159,14 +177,6 @@ class ApiRoutes: if tags: filters['tags'] = tags.split(',') - # Add search options to filters - search_options = { - 'filename': search_filename, - 'modelname': search_modelname, - 'tags': search_tags, - 'recursive': recursive - } - # Add lora hash filtering options hash_filters = {} if lora_hash: @@ -225,67 +235,10 @@ class ApiRoutes: "from_civitai": lora.get("from_civitai", True), "usage_tips": lora.get("usage_tips", ""), "notes": lora.get("notes", ""), - "civitai": 
self._filter_civitai_data(lora.get("civitai", {})) + "civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {})) } - def _filter_civitai_data(self, data: Dict) -> Dict: - """Filter relevant fields from CivitAI data""" - if not data: - return {} - - fields = [ - "id", "modelId", "name", "createdAt", "updatedAt", - "publishedAt", "trainedWords", "baseModel", "description", - "model", "images" - ] - return {k: data[k] for k in fields if k in data} - # Private helper methods - async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]: - """Delete model and associated files""" - patterns = [ - f"{file_name}.safetensors", # Required - f"{file_name}.metadata.json", - ] - - # 添加所有预览文件扩展名 - for ext in PREVIEW_EXTENSIONS: - patterns.append(f"{file_name}{ext}") - - deleted = [] - main_file = patterns[0] - main_path = os.path.join(target_dir, main_file).replace(os.sep, '/') - - if os.path.exists(main_path): - # Notify file monitor to ignore delete event - self.download_manager.file_monitor.handler.add_ignore_path(main_path, 0) - - # Delete file - os.remove(main_path) - deleted.append(main_path) - else: - logger.warning(f"Model file not found: {main_file}") - - # Remove from cache - cache = await self.scanner.get_cached_data() - cache.raw_data = [item for item in cache.raw_data if item['file_path'] != main_path] - await cache.resort() - - # update hash index - self.scanner._hash_index.remove_by_path(main_path) - - # Delete optional files - for pattern in patterns[1:]: - path = os.path.join(target_dir, pattern) - if os.path.exists(path): - try: - os.remove(path) - deleted.append(pattern) - except Exception as e: - logger.warning(f"Failed to delete {pattern}: {e}") - - return deleted - async def _read_preview_file(self, reader) -> tuple[bytes, str]: """Read preview file and content type from multipart request""" field = await reader.next() @@ -345,66 +298,6 @@ class ApiRoutes: except Exception as e: logger.error(f"Error updating metadata: {e}") 
- async def _load_local_metadata(self, metadata_path: str) -> Dict: - """Load local metadata file""" - if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception as e: - logger.error(f"Error loading metadata from {metadata_path}: {e}") - return {} - - async def _handle_not_found_on_civitai(self, metadata_path: str, local_metadata: Dict) -> web.Response: - """Handle case when model is not found on CivitAI""" - local_metadata['from_civitai'] = False - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - return web.json_response( - {"success": False, "error": "Not found on CivitAI"}, - status=404 - ) - - async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict, - civitai_metadata: Dict, client: CivitaiClient) -> None: - """Update local metadata with CivitAI data""" - local_metadata['civitai'] = civitai_metadata - - # Update model name if available - if 'model' in civitai_metadata: - if civitai_metadata.get('model', {}).get('name'): - local_metadata['model_name'] = civitai_metadata['model']['name'] - - # Fetch additional model metadata (description and tags) if we have model ID - model_id = civitai_metadata['modelId'] - if model_id: - model_metadata, _ = await client.get_model_metadata(str(model_id)) - if model_metadata: - local_metadata['modelDescription'] = model_metadata.get('description', '') - local_metadata['tags'] = model_metadata.get('tags', []) - - # Update base model - local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel')) - - # Update preview if needed - if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): - first_preview = next((img for img in civitai_metadata.get('images', [])), None) - if first_preview: - preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1] - base_name 
= os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] - preview_filename = base_name + preview_ext - preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) - - if await client.download_preview_image(first_preview['url'], preview_path): - local_metadata['preview_url'] = preview_path.replace(os.sep, '/') - local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) - - # Save updated metadata - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - - await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata) - async def fetch_all_civitai(self, request: web.Request) -> web.Response: """Fetch CivitAI metadata for all loras in the background""" try: @@ -414,14 +307,14 @@ class ApiRoutes: success = 0 needs_resort = False - # 准备要处理的 loras + # Prepare loras to process to_process = [ lora for lora in cache.raw_data - if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai') # TODO: for lora not from CivitAI but added traineWords + if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai', True) # TODO: for lora not from CivitAI but added traineWords ] total_to_process = len(to_process) - # 发送初始进度 + # Send initial progress await ws_manager.broadcast({ 'status': 'started', 'total': total_to_process, @@ -432,10 +325,11 @@ class ApiRoutes: for lora in to_process: try: original_name = lora.get('model_name') - if await self._fetch_and_update_single_lora( + if await ModelRouteUtils.fetch_and_update_model( sha256=lora['sha256'], file_path=lora['file_path'], - lora=lora + model_data=lora, + update_cache_func=self.scanner.update_single_model_cache ): success += 1 if original_name != lora.get('model_name'): @@ -443,7 +337,7 @@ class ApiRoutes: processed += 1 - # 每处理一个就发送进度更新 + # Send progress 
update await ws_manager.broadcast({ 'status': 'processing', 'total': total_to_process, @@ -458,7 +352,7 @@ class ApiRoutes: if needs_resort: await cache.resort(name_only=True) - # 发送完成消息 + # Send completion message await ws_manager.broadcast({ 'status': 'completed', 'total': total_to_process, @@ -472,7 +366,7 @@ class ApiRoutes: }) except Exception as e: - # 发送错误消息 + # Send error message await ws_manager.broadcast({ 'status': 'error', 'error': str(e) @@ -480,58 +374,6 @@ class ApiRoutes: logger.error(f"Error in fetch_all_civitai: {e}") return web.Response(text=str(e), status=500) - async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool: - """Fetch and update metadata for a single lora without sorting - - Args: - sha256: SHA256 hash of the lora file - file_path: Path to the lora file - lora: The lora object in cache to update - - Returns: - bool: True if successful, False otherwise - """ - client = CivitaiClient() - try: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - - # Check if model is from CivitAI - local_metadata = await self._load_local_metadata(metadata_path) - - # Fetch metadata - civitai_metadata = await client.get_model_by_hash(sha256) - if not civitai_metadata: - # Mark as not from CivitAI if not found - local_metadata['from_civitai'] = False - lora['from_civitai'] = False - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - return False - - # Update metadata - await self._update_model_metadata( - metadata_path, - local_metadata, - civitai_metadata, - client - ) - - # Update cache object directly - lora.update({ - 'model_name': local_metadata.get('model_name'), - 'preview_url': local_metadata.get('preview_url'), - 'from_civitai': True, - 'civitai': civitai_metadata - }) - - return True - - except Exception as e: - logger.error(f"Error fetching CivitAI data: {e}") - return False - finally: - await client.close() - async def 
get_lora_roots(self, request: web.Request) -> web.Response: """Get all configured LoRA root directories""" return web.json_response({ @@ -669,7 +511,7 @@ class ApiRoutes: return web.json_response({'success': True}) except Exception as e: - logger.error(f"Error updating settings: {e}", exc_info=True) # 添加 exc_info=True 以获取完整堆栈 + logger.error(f"Error updating settings: {e}", exc_info=True) return web.Response(status=500, text=str(e)) async def move_model(self, request: web.Request) -> web.Response: @@ -731,11 +573,7 @@ class ApiRoutes: metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' # Load existing metadata - if os.path.exists(metadata_path): - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - else: - metadata = {} + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) # Handle nested updates (for civitai.trainedWords) for key, value in metadata_updates.items(): @@ -798,7 +636,10 @@ class ApiRoutes: except Exception as e: logger.error(f"Error getting lora preview URL: {e}", exc_info=True) - return web.Response(text=str(e), status=500) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) async def get_lora_civitai_url(self, request: web.Request) -> web.Response: """Get the Civitai URL for a LoRA file""" @@ -921,14 +762,9 @@ class ApiRoutes: tags = [] if file_path: metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - description = metadata.get('modelDescription') - tags = metadata.get('tags', []) - except Exception as e: - logger.error(f"Error loading metadata from {metadata_path}: {e}") + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + description = metadata.get('modelDescription') + tags = metadata.get('tags', []) # If description is not in metadata, fetch from CivitAI if not description: @@ -943,16 +779,14 @@ 
class ApiRoutes: if file_path: try: metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - if os.path.exists(metadata_path): - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - - metadata['modelDescription'] = description - metadata['tags'] = tags - - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(metadata, f, indent=2, ensure_ascii=False) - logger.info(f"Saved model metadata to file for {file_path}") + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + metadata['modelDescription'] = description + metadata['tags'] = tags + + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + logger.info(f"Saved model metadata to file for {file_path}") except Exception as e: logger.error(f"Error saving model metadata: {e}") @@ -1018,12 +852,6 @@ class ApiRoutes: 'error': str(e) }, status=500) - def get_multipart_ext(self, filename): - parts = filename.split(".") - if len(parts) > 2: # 如果包含多级扩展名 - return "." 
+ ".".join(parts[-2:]) # 取最后两部分,如 ".metadata.json" - return os.path.splitext(filename)[1] # 否则取普通扩展名,如 ".safetensors" - async def rename_lora(self, request: web.Request) -> web.Response: """Handle renaming a LoRA file and its associated files""" try: @@ -1063,7 +891,7 @@ class ApiRoutes: f"{old_file_name}.metadata.json", ] - # 添加所有预览文件扩展名 + # Add all preview file extensions for ext in PREVIEW_EXTENSIONS: patterns.append(f"{old_file_name}{ext}") @@ -1080,12 +908,8 @@ class ApiRoutes: metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json") if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - hash_value = metadata.get('sha256') - except Exception as e: - logger.error(f"Error loading metadata for rename: {e}") + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + hash_value = metadata.get('sha256') # Rename all files renamed_files = [] @@ -1101,7 +925,7 @@ class ApiRoutes: for old_path, pattern in existing_files: # Get the file extension like .safetensors or .metadata.json - ext = self.get_multipart_ext(pattern) + ext = ModelRouteUtils.get_multipart_ext(pattern) # Create the new path new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') @@ -1123,7 +947,7 @@ class ApiRoutes: # Update preview_url if it exists if 'preview_url' in metadata and metadata['preview_url']: old_preview = metadata['preview_url'] - ext = self.get_multipart_ext(old_preview) + ext = ModelRouteUtils.get_multipart_ext(old_preview) new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') metadata['preview_url'] = new_preview diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index fcd47c60..af16f732 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -1,15 +1,10 @@ import os import json -import asyncio -from typing import Dict -import aiohttp import jinja2 from aiohttp import web 
import logging -from datetime import datetime - -from ..utils.model_utils import determine_base_model +from ..utils.routes_common import ModelRouteUtils from ..utils.constants import NSFW_LEVELS from ..services.civitai_client import CivitaiClient from ..services.websocket_manager import ws_manager @@ -259,21 +254,9 @@ class CheckpointsRoutes: "from_civitai": checkpoint.get("from_civitai", True), "notes": checkpoint.get("notes", ""), "model_type": checkpoint.get("model_type", "checkpoint"), - "civitai": self._filter_civitai_data(checkpoint.get("civitai", {})) + "civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {})) } - def _filter_civitai_data(self, data): - """Filter relevant fields from CivitAI data""" - if not data: - return {} - - fields = [ - "id", "modelId", "name", "createdAt", "updatedAt", - "publishedAt", "trainedWords", "baseModel", "description", - "model", "images" - ] - return {k: data[k] for k in fields if k in data} - async def fetch_all_civitai(self, request: web.Request) -> web.Response: """Fetch CivitAI metadata for all checkpoints in the background""" try: @@ -302,10 +285,11 @@ class CheckpointsRoutes: for cp in to_process: try: original_name = cp.get('model_name') - if await self._fetch_and_update_single_checkpoint( + if await ModelRouteUtils.fetch_and_update_model( sha256=cp['sha256'], file_path=cp['file_path'], - checkpoint=cp + model_data=cp, + update_cache_func=self.scanner.update_single_model_cache ): success += 1 if original_name != cp.get('model_name'): @@ -350,99 +334,6 @@ class CheckpointsRoutes: logger.error(f"Error in fetch_all_civitai for checkpoints: {e}") return web.Response(text=str(e), status=500) - async def _fetch_and_update_single_checkpoint(self, sha256: str, file_path: str, checkpoint: dict) -> bool: - """Fetch and update metadata for a single checkpoint without sorting""" - client = CivitaiClient() - try: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - - # Load local metadata - 
local_metadata = self._load_local_metadata(metadata_path) - - # Fetch metadata from Civitai - civitai_metadata = await client.get_model_by_hash(sha256) - if not civitai_metadata: - # Mark as not from CivitAI if not found - local_metadata['from_civitai'] = False - checkpoint['from_civitai'] = False - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - return False - - # Update metadata with Civitai data - await self._update_model_metadata( - metadata_path, - local_metadata, - civitai_metadata, - client - ) - - # Update cache object directly - checkpoint.update({ - 'model_name': local_metadata.get('model_name'), - 'preview_url': local_metadata.get('preview_url'), - 'from_civitai': True, - 'civitai': civitai_metadata - }) - - return True - - except Exception as e: - logger.error(f"Error fetching CivitAI data for checkpoint: {e}") - return False - finally: - await client.close() - - def _load_local_metadata(self, metadata_path: str) -> Dict: - """Load local metadata file""" - if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception as e: - logger.error(f"Error loading metadata from {metadata_path}: {e}") - return {} - - async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict, - civitai_metadata: Dict, client: CivitaiClient) -> None: - """Update local metadata with CivitAI data""" - local_metadata['civitai'] = civitai_metadata - - # Update model name if available - if 'model' in civitai_metadata: - if civitai_metadata.get('model', {}).get('name'): - local_metadata['model_name'] = civitai_metadata['model']['name'] - - # Fetch additional model metadata (description and tags) if we have model ID - model_id = civitai_metadata['modelId'] - if model_id: - model_metadata, _ = await client.get_model_metadata(str(model_id)) - if model_metadata: - local_metadata['modelDescription'] = 
model_metadata.get('description', '') - local_metadata['tags'] = model_metadata.get('tags', []) - - # Update base model - local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel')) - - # Update preview if needed - if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): - first_preview = next((img for img in civitai_metadata.get('images', [])), None) - if first_preview: - preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1] - base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] - preview_filename = base_name + preview_ext - preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) - - if await client.download_preview_image(first_preview['url'], preview_path): - local_metadata['preview_url'] = preview_path.replace(os.sep, '/') - local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) - - # Save updated metadata - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - - await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata) - async def get_top_tags(self, request: web.Request) -> web.Response: """Handle request for top tags sorted by frequency""" try: diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py new file mode 100644 index 00000000..69ea63a1 --- /dev/null +++ b/py/utils/routes_common.py @@ -0,0 +1,252 @@ +import os +import json +import logging +from typing import Dict, List, Callable, Awaitable + +from .model_utils import determine_base_model +from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH +from ..config import config +from ..services.civitai_client import CivitaiClient +from ..utils.exif_utils import ExifUtils + +logger = logging.getLogger(__name__) + + +class ModelRouteUtils: + """Shared utilities for model routes (LoRAs, 
Checkpoints, etc.)""" + + @staticmethod + async def load_local_metadata(metadata_path: str) -> Dict: + """Load local metadata file""" + if os.path.exists(metadata_path): + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.error(f"Error loading metadata from {metadata_path}: {e}") + return {} + + @staticmethod + async def handle_not_found_on_civitai(metadata_path: str, local_metadata: Dict) -> None: + """Handle case when model is not found on CivitAI""" + local_metadata['from_civitai'] = False + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + + @staticmethod + async def update_model_metadata(metadata_path: str, local_metadata: Dict, + civitai_metadata: Dict, client: CivitaiClient) -> None: + """Update local metadata with CivitAI data""" + local_metadata['civitai'] = civitai_metadata + + # Update model name if available + if 'model' in civitai_metadata: + if civitai_metadata.get('model', {}).get('name'): + local_metadata['model_name'] = civitai_metadata['model']['name'] + + # Fetch additional model metadata (description and tags) if we have model ID + model_id = civitai_metadata['modelId'] + if model_id: + model_metadata, _ = await client.get_model_metadata(str(model_id)) + if model_metadata: + local_metadata['modelDescription'] = model_metadata.get('description', '') + local_metadata['tags'] = model_metadata.get('tags', []) + + # Update base model + local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel')) + + # Update preview if needed + if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): + first_preview = next((img for img in civitai_metadata.get('images', [])), None) + if first_preview: + # Determine if content is video or image + is_video = first_preview['type'] == 'video' + + if is_video: + # For videos use .mp4 extension + preview_ext = '.mp4' + else: + # 
For images use .webp extension + preview_ext = '.webp' + + base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] + preview_filename = base_name + preview_ext + preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) + + if is_video: + # Download video as is + if await client.download_preview_image(first_preview['url'], preview_path): + local_metadata['preview_url'] = preview_path.replace(os.sep, '/') + local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + else: + # For images, download and then optimize to WebP + temp_path = preview_path + ".temp" + if await client.download_preview_image(first_preview['url'], temp_path): + try: + # Read the downloaded image + with open(temp_path, 'rb') as f: + image_data = f.read() + + # Optimize and convert to WebP + optimized_data, _ = ExifUtils.optimize_image( + image_data=image_data, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=True + ) + + # Save the optimized WebP image + with open(preview_path, 'wb') as f: + f.write(optimized_data) + + # Update metadata + local_metadata['preview_url'] = preview_path.replace(os.sep, '/') + local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + + # Remove the temporary file + if os.path.exists(temp_path): + os.remove(temp_path) + + except Exception as e: + logger.error(f"Error optimizing preview image: {e}") + # If optimization fails, try to use the downloaded image directly + if os.path.exists(temp_path): + os.rename(temp_path, preview_path) + local_metadata['preview_url'] = preview_path.replace(os.sep, '/') + local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + + # Save updated metadata + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + + @staticmethod + async def fetch_and_update_model( + sha256: str, + file_path: str, + model_data: dict, + update_cache_func: 
Callable[[str, str, Dict], Awaitable[bool]] + ) -> bool: + """Fetch and update metadata for a single model + + Args: + sha256: SHA256 hash of the model file + file_path: Path to the model file + model_data: The model object in cache to update + update_cache_func: Function to update the cache with new metadata + + Returns: + bool: True if successful, False otherwise + """ + client = CivitaiClient() + try: + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + + # Check if model metadata exists + local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + # Fetch metadata from Civitai + civitai_metadata = await client.get_model_by_hash(sha256) + if not civitai_metadata: + # Mark as not from CivitAI if not found + local_metadata['from_civitai'] = False + model_data['from_civitai'] = False + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + return False + + # Update metadata + await ModelRouteUtils.update_model_metadata( + metadata_path, + local_metadata, + civitai_metadata, + client + ) + + # Update cache object directly + model_data.update({ + 'model_name': local_metadata.get('model_name'), + 'preview_url': local_metadata.get('preview_url'), + 'from_civitai': True, + 'civitai': civitai_metadata + }) + + # Update cache using the provided function + await update_cache_func(file_path, file_path, local_metadata) + + return True + + except Exception as e: + logger.error(f"Error fetching CivitAI data: {e}") + return False + finally: + await client.close() + + @staticmethod + def filter_civitai_data(data: Dict) -> Dict: + """Filter relevant fields from CivitAI data""" + if not data: + return {} + + fields = [ + "id", "modelId", "name", "createdAt", "updatedAt", + "publishedAt", "trainedWords", "baseModel", "description", + "model", "images" + ] + return {k: data[k] for k in fields if k in data} + + @staticmethod + async def delete_model_files(target_dir: str, file_name: str, 
file_monitor=None) -> List[str]: + """Delete model and associated files + + Args: + target_dir: Directory containing the model files + file_name: Base name of the model file without extension + file_monitor: Optional file monitor to ignore delete events + + Returns: + List of deleted file paths + """ + patterns = [ + f"{file_name}.safetensors", # Required + f"{file_name}.metadata.json", + ] + + # Add all preview file extensions + for ext in PREVIEW_EXTENSIONS: + patterns.append(f"{file_name}{ext}") + + deleted = [] + main_file = patterns[0] + main_path = os.path.join(target_dir, main_file).replace(os.sep, '/') + + if os.path.exists(main_path): + # Notify file monitor to ignore delete event if available + if file_monitor: + file_monitor.handler.add_ignore_path(main_path, 0) + + # Delete file + os.remove(main_path) + deleted.append(main_path) + else: + logger.warning(f"Model file not found: {main_file}") + + # Delete optional files + for pattern in patterns[1:]: + path = os.path.join(target_dir, pattern) + if os.path.exists(path): + try: + os.remove(path) + deleted.append(pattern) + except Exception as e: + logger.warning(f"Failed to delete {pattern}: {e}") + + return deleted + + @staticmethod + def get_multipart_ext(filename): + """Get extension that may have multiple parts like .metadata.json""" + parts = filename.split(".") + if len(parts) > 2: # If contains multi-part extension + return "." 
+ ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json" + return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors" \ No newline at end of file diff --git a/static/js/components/LoraCard.js b/static/js/components/LoraCard.js index e6534081..86737b59 100644 --- a/static/js/components/LoraCard.js +++ b/static/js/components/LoraCard.js @@ -1,4 +1,4 @@ -import { showToast } from '../utils/uiHelpers.js'; +import { showToast, openCivitai } from '../utils/uiHelpers.js'; import { state } from '../state/index.js'; import { showLoraModal } from './loraModal/index.js'; import { bulkManager } from '../managers/BulkManager.js'; From 56670066c77bc7105b7a416fd3f817e17aa1b8c5 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 11 Apr 2025 11:17:49 +0800 Subject: [PATCH 20/36] refactor: Optimize preview image handling by converting to webp format and improving error logging --- py/routes/checkpoints_routes.py | 1 - py/routes/lora_routes.py | 2 -- py/utils/file_utils.py | 35 ++++++++++++++++++++++++++++++++- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index af16f732..7b591a65 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -439,7 +439,6 @@ class CheckpointsRoutes: settings=settings, # Pass settings to template request=request # Pass the request object to the template ) - logger.debug(f"Checkpoints page loaded successfully with {len(cache.raw_data)} items") except Exception as cache_error: logger.error(f"Error loading checkpoints cache data: {cache_error}") # 如果获取缓存失败,也显示初始化页面 diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 448c424f..a2c392fc 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -87,7 +87,6 @@ class LoraRoutes: settings=settings, # Pass settings to template request=request # Pass the request object to the template ) - 
logger.debug(f"Loras page loaded successfully with {len(cache.raw_data)} items") except Exception as cache_error: logger.error(f"Error loading cache data: {cache_error}") # 如果获取缓存失败,也显示初始化页面 @@ -143,7 +142,6 @@ class LoraRoutes: settings=settings, request=request # Pass the request object to the template ) - logger.debug(f"Recipes page loaded successfully with {len(cache.raw_data)} items") except Exception as cache_error: logger.error(f"Error loading recipe cache data: {cache_error}") # 如果获取缓存失败,也显示初始化页面 diff --git a/py/utils/file_utils.py b/py/utils/file_utils.py index 058469d6..1a9825a2 100644 --- a/py/utils/file_utils.py +++ b/py/utils/file_utils.py @@ -8,7 +8,8 @@ from typing import Dict, Optional, Type from .model_utils import determine_base_model from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata -from .constants import PREVIEW_EXTENSIONS +from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH +from .exif_utils import ExifUtils logger = logging.getLogger(__name__) @@ -26,7 +27,38 @@ def find_preview_file(base_name: str, dir_path: str) -> str: for ext in PREVIEW_EXTENSIONS: full_pattern = os.path.join(dir_path, f"{base_name}{ext}") if os.path.exists(full_pattern): + # Check if this is an image and not already webp + if ext.lower().endswith(('.jpg', '.jpeg', '.png')) and not ext.lower().endswith('.webp'): + try: + # Optimize the image to webp format + webp_path = os.path.join(dir_path, f"{base_name}.webp") + + # Use ExifUtils to optimize the image + with open(full_pattern, 'rb') as f: + image_data = f.read() + + optimized_data, _ = ExifUtils.optimize_image( + image_data=image_data, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=True + ) + + # Save the optimized webp file + with open(webp_path, 'wb') as f: + f.write(optimized_data) + + logger.debug(f"Optimized preview image from {full_pattern} to {webp_path}") + return 
webp_path.replace(os.sep, "/") + except Exception as e: + logger.error(f"Error optimizing preview image {full_pattern}: {e}") + # Fall back to original file if optimization fails + return full_pattern.replace(os.sep, "/") + + # Return the original path for webp images or non-image files return full_pattern.replace(os.sep, "/") + return "" def normalize_path(path: str) -> str: @@ -154,6 +186,7 @@ async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = L data['file_path'] = normalize_path(file_path) needs_update = True + # TODO: optimize preview image to webp format if not already done preview_url = data.get('preview_url', '') if not preview_url or not os.path.exists(preview_url): base_name = os.path.splitext(os.path.basename(file_path))[0] From e991dc061d20583a6413c7c97da85c1aeda32ba6 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 11 Apr 2025 12:06:05 +0800 Subject: [PATCH 21/36] refactor: Implement common endpoint handlers for model management in ModelRouteUtils and update routes in CheckpointsRoutes --- py/routes/api_routes.py | 76 +---------- py/routes/checkpoints_routes.py | 17 +++ py/utils/routes_common.py | 174 ++++++++++++++++++++++++- static/js/components/CheckpointCard.js | 5 +- static/js/utils/modalUtils.js | 18 ++- 5 files changed, 210 insertions(+), 80 deletions(-) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 331cadb4..2c63827e 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -63,85 +63,15 @@ class ApiRoutes: async def delete_model(self, request: web.Request) -> web.Response: """Handle model deletion request""" - try: - data = await request.json() - file_path = data.get('file_path') - if not file_path: - return web.Response(text='Model path is required', status=400) - - target_dir = os.path.dirname(file_path) - file_name = os.path.splitext(os.path.basename(file_path))[0] - - deleted_files = await ModelRouteUtils.delete_model_files( - target_dir, - file_name, 
- self.download_manager.file_monitor - ) - - # Remove from cache - cache = await self.scanner.get_cached_data() - cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path] - await cache.resort() - - # update hash index - self.scanner._hash_index.remove_by_path(file_path) - - return web.json_response({ - 'success': True, - 'deleted_files': deleted_files - }) - - except Exception as e: - logger.error(f"Error deleting model: {e}", exc_info=True) - return web.Response(text=str(e), status=500) + return await ModelRouteUtils.handle_delete_model(request, self.scanner) async def fetch_civitai(self, request: web.Request) -> web.Response: """Handle CivitAI metadata fetch request""" - try: - data = await request.json() - metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json' - - # Check if model is from CivitAI - local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - - # Fetch and update metadata - civitai_metadata = await self.civitai_client.get_model_by_hash(local_metadata["sha256"]) - if not civitai_metadata: - await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata) - return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404) - - await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client) - - # Update the cache - await self.scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata) - - return web.json_response({"success": True}) - - except Exception as e: - logger.error(f"Error fetching from CivitAI: {e}", exc_info=True) - return web.json_response({"success": False, "error": str(e)}, status=500) + return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner) async def replace_preview(self, request: web.Request) -> web.Response: """Handle preview image replacement request""" - try: - reader = await request.multipart() - preview_data, content_type = await 
self._read_preview_file(reader) - model_path = await self._read_model_path(reader) - - preview_path = await self._save_preview_file(model_path, preview_data, content_type) - await self._update_preview_metadata(model_path, preview_path) - - # Update preview URL in scanner cache - await self.scanner.update_preview_in_cache(model_path, preview_path) - - return web.json_response({ - "success": True, - "preview_url": config.get_preview_static_url(preview_path) - }) - - except Exception as e: - logger.error(f"Error replacing preview: {e}", exc_info=True) - return web.Response(text=str(e), status=500) + return await ModelRouteUtils.handle_replace_preview(request, self.scanner) async def get_loras(self, request: web.Request) -> web.Response: """Handle paginated LoRA data request""" diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index 7b591a65..e65bea5b 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -34,6 +34,11 @@ class CheckpointsRoutes: app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags) app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info) + + # Add new routes for model management similar to LoRA routes + app.router.add_post('/api/checkpoints/delete', self.delete_model) + app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai) + app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview) async def get_checkpoints(self, request): """Get paginated checkpoint data""" @@ -461,3 +466,15 @@ class CheckpointsRoutes: text="Error loading checkpoints page", status=500 ) + + async def delete_model(self, request: web.Request) -> web.Response: + """Handle checkpoint model deletion request""" + return await ModelRouteUtils.handle_delete_model(request, self.scanner) + + async def fetch_civitai(self, request: web.Request) -> web.Response: + """Handle CivitAI metadata fetch 
request for checkpoints""" + return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner) + + async def replace_preview(self, request: web.Request) -> web.Response: + """Handle preview image replacement for checkpoints""" + return await ModelRouteUtils.handle_replace_preview(request, self.scanner) diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 69ea63a1..5b68b368 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -2,6 +2,7 @@ import os import json import logging from typing import Dict, List, Callable, Awaitable +from aiohttp import web from .model_utils import determine_base_model from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH @@ -249,4 +250,175 @@ class ModelRouteUtils: parts = filename.split(".") if len(parts) > 2: # If contains multi-part extension return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json" - return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors" \ No newline at end of file + return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors" + + # New common endpoint handlers + + @staticmethod + async def handle_delete_model(request: web.Request, scanner) -> web.Response: + """Handle model deletion request + + Args: + request: The aiohttp request + scanner: The model scanner instance with cache management methods + + Returns: + web.Response: The HTTP response + """ + try: + data = await request.json() + file_path = data.get('file_path') + if not file_path: + return web.Response(text='Model path is required', status=400) + + target_dir = os.path.dirname(file_path) + file_name = os.path.splitext(os.path.basename(file_path))[0] + + # Get the file monitor from the scanner if available + file_monitor = getattr(scanner, 'file_monitor', None) + + deleted_files = await ModelRouteUtils.delete_model_files( + target_dir, + file_name, + file_monitor + ) + + # Remove from cache + cache = 
await scanner.get_cached_data() + cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path] + await cache.resort() + + # Update hash index if available + if hasattr(scanner, '_hash_index') and scanner._hash_index: + scanner._hash_index.remove_by_path(file_path) + + return web.json_response({ + 'success': True, + 'deleted_files': deleted_files + }) + + except Exception as e: + logger.error(f"Error deleting model: {e}", exc_info=True) + return web.Response(text=str(e), status=500) + + @staticmethod + async def handle_fetch_civitai(request: web.Request, scanner) -> web.Response: + """Handle CivitAI metadata fetch request + + Args: + request: The aiohttp request + scanner: The model scanner instance with cache management methods + + Returns: + web.Response: The HTTP response + """ + try: + data = await request.json() + metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json' + + # Check if model metadata exists + local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + if not local_metadata or not local_metadata.get('sha256'): + return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400) + + # Create a client for fetching from Civitai + client = CivitaiClient() + try: + # Fetch and update metadata + civitai_metadata = await client.get_model_by_hash(local_metadata["sha256"]) + if not civitai_metadata: + await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata) + return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404) + + await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, client) + + # Update the cache + await scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata) + + return web.json_response({"success": True}) + finally: + await client.close() + + except Exception as e: + logger.error(f"Error fetching from CivitAI: {e}", exc_info=True) + return 
web.json_response({"success": False, "error": str(e)}, status=500) + + @staticmethod + async def handle_replace_preview(request: web.Request, scanner) -> web.Response: + """Handle preview image replacement request + + Args: + request: The aiohttp request + scanner: The model scanner instance with methods to update cache + + Returns: + web.Response: The HTTP response + """ + try: + reader = await request.multipart() + + # Read preview file data + field = await reader.next() + if field.name != 'preview_file': + raise ValueError("Expected 'preview_file' field") + content_type = field.headers.get('Content-Type', 'image/png') + preview_data = await field.read() + + # Read model path + field = await reader.next() + if field.name != 'model_path': + raise ValueError("Expected 'model_path' field") + model_path = (await field.read()).decode() + + # Save preview file + base_name = os.path.splitext(os.path.basename(model_path))[0] + folder = os.path.dirname(model_path) + + # Determine if content is video or image + if content_type.startswith('video/'): + # For videos, keep original format and use .mp4 extension + extension = '.mp4' + optimized_data = preview_data + else: + # For images, optimize and convert to WebP + optimized_data, _ = ExifUtils.optimize_image( + image_data=preview_data, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=True + ) + extension = '.webp' # Use .webp without .preview part + + preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/') + + with open(preview_path, 'wb') as f: + f.write(optimized_data) + + # Update preview path in metadata + metadata_path = os.path.splitext(model_path)[0] + '.metadata.json' + if os.path.exists(metadata_path): + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + metadata = json.load(f) + + # Update preview_url directly in the metadata dict + metadata['preview_url'] = preview_path + + with open(metadata_path, 'w', encoding='utf-8') as f: + 
json.dump(metadata, f, indent=2, ensure_ascii=False) + except Exception as e: + logger.error(f"Error updating metadata: {e}") + + # Update preview URL in scanner cache + if hasattr(scanner, 'update_preview_in_cache'): + await scanner.update_preview_in_cache(model_path, preview_path) + + return web.json_response({ + "success": True, + "preview_url": config.get_preview_static_url(preview_path) + }) + + except Exception as e: + logger.error(f"Error replacing preview: {e}", exc_info=True) + return web.Response(text=str(e), status=500) \ No newline at end of file diff --git a/static/js/components/CheckpointCard.js b/static/js/components/CheckpointCard.js index f7031c15..abb95afe 100644 --- a/static/js/components/CheckpointCard.js +++ b/static/js/components/CheckpointCard.js @@ -294,7 +294,10 @@ function deleteCheckpoint(filePath) { if (window.deleteCheckpoint) { window.deleteCheckpoint(filePath); } else { - console.log('Delete checkpoint:', filePath); + // Use the modal delete functionality + import('../utils/modalUtils.js').then(({ showDeleteModal }) => { + showDeleteModal(filePath, 'checkpoint'); + }); } } diff --git a/static/js/utils/modalUtils.js b/static/js/utils/modalUtils.js index 49dfe15d..be5fe25d 100644 --- a/static/js/utils/modalUtils.js +++ b/static/js/utils/modalUtils.js @@ -1,10 +1,12 @@ import { modalManager } from '../managers/ModalManager.js'; let pendingDeletePath = null; +let pendingModelType = null; -export function showDeleteModal(filePath) { - event.stopPropagation(); +export function showDeleteModal(filePath, modelType = 'lora') { + // event.stopPropagation(); pendingDeletePath = filePath; + pendingModelType = modelType; const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); const modelName = card.dataset.name; @@ -23,11 +25,15 @@ export function showDeleteModal(filePath) { export async function confirmDelete() { if (!pendingDeletePath) return; - const modal = document.getElementById('deleteModal'); const card = 
document.querySelector(`.lora-card[data-filepath="${pendingDeletePath}"]`); try { - const response = await fetch('/api/delete_model', { + // Use the appropriate endpoint based on model type + const endpoint = pendingModelType === 'checkpoint' ? + '/api/checkpoints/delete' : + '/api/delete_model'; + + const response = await fetch(endpoint, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -53,4 +59,6 @@ export async function confirmDelete() { export function closeDeleteModal() { modalManager.closeModal('deleteModal'); -} \ No newline at end of file + pendingDeletePath = null; + pendingModelType = null; +} \ No newline at end of file From 3df96034a1edda66f0582912cd13e0fbcf0446b9 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 11 Apr 2025 14:35:56 +0800 Subject: [PATCH 22/36] refactor: Consolidate model handling functions into baseModelApi for better code reuse and organization --- static/js/api/baseModelApi.js | 512 +++++++++++++++++++++++++ static/js/api/checkpointApi.js | 335 ++-------------- static/js/api/loraApi.js | 348 ++--------------- static/js/components/CheckpointCard.js | 3 +- 4 files changed, 580 insertions(+), 618 deletions(-) create mode 100644 static/js/api/baseModelApi.js diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js new file mode 100644 index 00000000..20a669a3 --- /dev/null +++ b/static/js/api/baseModelApi.js @@ -0,0 +1,512 @@ +// filepath: d:\Workspace\ComfyUI\custom_nodes\ComfyUI-Lora-Manager\static\js\api\baseModelApi.js +import { state, getCurrentPageState } from '../state/index.js'; +import { showToast } from '../utils/uiHelpers.js'; +import { showDeleteModal, confirmDelete } from '../utils/modalUtils.js'; + +/** + * Shared functionality for handling models (loras and checkpoints) + */ + +// Generic function to load more models with pagination +export async function loadMoreModels(options = {}) { + const { + resetPage = false, + updateFolders = false, + modelType = 
'lora', // 'lora' or 'checkpoint' + createCardFunction, + endpoint = '/api/loras' + } = options; + + const pageState = getCurrentPageState(); + + if (pageState.isLoading || (!pageState.hasMore && !resetPage)) return; + + pageState.isLoading = true; + document.body.classList.add('loading'); + + try { + // Reset to first page if requested + if (resetPage) { + pageState.currentPage = 1; + // Clear grid if resetting + const gridId = modelType === 'checkpoint' ? 'checkpointGrid' : 'loraGrid'; + const grid = document.getElementById(gridId); + if (grid) grid.innerHTML = ''; + } + + const params = new URLSearchParams({ + page: pageState.currentPage, + page_size: pageState.pageSize || 20, + sort_by: pageState.sortBy + }); + + if (pageState.activeFolder !== null) { + params.append('folder', pageState.activeFolder); + } + + // Add search parameters if there's a search term + if (pageState.filters?.search) { + params.append('search', pageState.filters.search); + params.append('fuzzy', 'true'); + + // Add search option parameters if available + if (pageState.searchOptions) { + params.append('search_filename', pageState.searchOptions.filename.toString()); + params.append('search_modelname', pageState.searchOptions.modelname.toString()); + if (pageState.searchOptions.tags !== undefined) { + params.append('search_tags', pageState.searchOptions.tags.toString()); + } + params.append('recursive', (pageState.searchOptions?.recursive ?? 
false).toString()); + } + } + + // Add filter parameters if active + if (pageState.filters) { + // Handle tags filters + if (pageState.filters.tags && pageState.filters.tags.length > 0) { + // Checkpoints API expects individual 'tag' parameters, Loras API expects comma-separated 'tags' + if (modelType === 'checkpoint') { + pageState.filters.tags.forEach(tag => { + params.append('tag', tag); + }); + } else { + params.append('tags', pageState.filters.tags.join(',')); + } + } + + // Handle base model filters + if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) { + if (modelType === 'checkpoint') { + pageState.filters.baseModel.forEach(model => { + params.append('base_model', model); + }); + } else { + params.append('base_models', pageState.filters.baseModel.join(',')); + } + } + } + + // Add model-specific parameters + if (modelType === 'lora') { + // Check for recipe-based filtering parameters from session storage + const filterLoraHash = getSessionItem ? getSessionItem('recipe_to_lora_filterLoraHash') : null; + const filterLoraHashes = getSessionItem ? getSessionItem('recipe_to_lora_filterLoraHashes') : null; + + // Add hash filter parameter if present + if (filterLoraHash) { + params.append('lora_hash', filterLoraHash); + } + // Add multiple hashes filter if present + else if (filterLoraHashes) { + try { + if (Array.isArray(filterLoraHashes) && filterLoraHashes.length > 0) { + params.append('lora_hashes', filterLoraHashes.join(',')); + } + } catch (error) { + console.error('Error parsing lora hashes from session storage:', error); + } + } + } + + const response = await fetch(`${endpoint}?${params}`); + if (!response.ok) { + throw new Error(`Failed to fetch models: ${response.statusText}`); + } + + const data = await response.json(); + + const gridId = modelType === 'checkpoint' ? 'checkpointGrid' : 'loraGrid'; + const grid = document.getElementById(gridId); + + if (data.items.length === 0 && pageState.currentPage === 1) { + grid.innerHTML = `
No ${modelType}s found in this folder
`; + pageState.hasMore = false; + } else if (data.items.length > 0) { + pageState.hasMore = pageState.currentPage < data.total_pages; + + // Append model cards using the provided card creation function + data.items.forEach(model => { + const card = createCardFunction(model); + grid.appendChild(card); + }); + + // Increment the page number AFTER successful loading + pageState.currentPage++; + } else { + pageState.hasMore = false; + } + + if (updateFolders && data.folders) { + updateFolderTags(data.folders); + } + + } catch (error) { + console.error(`Error loading ${modelType}s:`, error); + showToast(`Failed to load ${modelType}s: ${error.message}`, 'error'); + } finally { + pageState.isLoading = false; + document.body.classList.remove('loading'); + } +} + +// Update folder tags in the UI +export function updateFolderTags(folders) { + const folderTagsContainer = document.querySelector('.folder-tags'); + if (!folderTagsContainer) return; + + // Keep track of currently selected folder + const pageState = getCurrentPageState(); + const currentFolder = pageState.activeFolder; + + // Create HTML for folder tags + const tagsHTML = folders.map(folder => { + const isActive = folder === currentFolder; + return `
${folder}
`; + }).join(''); + + // Update the container + folderTagsContainer.innerHTML = tagsHTML; + + // Reattach click handlers and ensure the active tag is visible + const tags = folderTagsContainer.querySelectorAll('.tag'); + tags.forEach(tag => { + if (typeof toggleFolder === 'function') { + tag.addEventListener('click', toggleFolder); + } + if (tag.dataset.folder === currentFolder) { + tag.scrollIntoView({ behavior: 'smooth', block: 'nearest' }); + } + }); +} + +// Generic function to replace a model preview +export function replaceModelPreview(filePath, modelType = 'lora') { + // Open file picker + const input = document.createElement('input'); + input.type = 'file'; + input.accept ='image/*,video/mp4'; + + input.onchange = async function() { + if (!input.files || !input.files[0]) return; + + const file = input.files[0]; + await uploadPreview(filePath, file, modelType); + }; + + input.click(); +} + +// Delete a model (generic) +export function deleteModel(filePath, modelType = 'lora') { + if (modelType === 'checkpoint') { + confirmDelete('Are you sure you want to delete this checkpoint?', () => { + performDelete(filePath, modelType); + }); + } else { + showDeleteModal(filePath); + } +} + +// Reset and reload models +export async function resetAndReload(options = {}) { + const { + updateFolders = false, + modelType = 'lora', + loadMoreFunction + } = options; + + const pageState = getCurrentPageState(); + console.log('Resetting with state:', { ...pageState }); + + // Reset pagination and load more models + if (typeof loadMoreFunction === 'function') { + await loadMoreFunction(true, updateFolders); + } +} + +// Generic function to refresh models +export async function refreshModels(options = {}) { + const { + modelType = 'lora', + scanEndpoint = '/api/loras/scan', + resetAndReloadFunction + } = options; + + try { + state.loadingManager.showSimpleLoading(`Refreshing ${modelType}s...`); + + const response = await fetch(scanEndpoint); + + if (!response.ok) { + throw new 
Error(`Failed to refresh ${modelType}s: ${response.status} ${response.statusText}`); + } + + if (typeof resetAndReloadFunction === 'function') { + await resetAndReloadFunction(); + } + + showToast(`Refresh complete`, 'success'); + } catch (error) { + console.error(`Refresh failed:`, error); + showToast(`Failed to refresh ${modelType}s`, 'error'); + } finally { + state.loadingManager.hide(); + state.loadingManager.restoreProgressBar(); + } +} + +// Generic fetch from Civitai +export async function fetchCivitaiMetadata(options = {}) { + const { + modelType = 'lora', + fetchEndpoint = '/api/fetch-all-civitai', + resetAndReloadFunction + } = options; + + let ws = null; + + await state.loadingManager.showWithProgress(async (loading) => { + try { + const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; + ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); + + const operationComplete = new Promise((resolve, reject) => { + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + + switch(data.status) { + case 'started': + loading.setStatus('Starting metadata fetch...'); + break; + + case 'processing': + const percent = ((data.processed / data.total) * 100).toFixed(1); + loading.setProgress(percent); + loading.setStatus( + `Processing (${data.processed}/${data.total}) ${data.current_name}` + ); + break; + + case 'completed': + loading.setProgress(100); + loading.setStatus( + `Completed: Updated ${data.success} of ${data.processed} ${modelType}s` + ); + resolve(); + break; + + case 'error': + reject(new Error(data.error)); + break; + } + }; + + ws.onerror = (error) => { + reject(new Error('WebSocket error: ' + error.message)); + }; + }); + + await new Promise((resolve, reject) => { + ws.onopen = resolve; + ws.onerror = reject; + }); + + const requestBody = modelType === 'checkpoint' + ? 
JSON.stringify({ model_type: 'checkpoint' }) + : JSON.stringify({}); + + const response = await fetch(fetchEndpoint, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: requestBody + }); + + if (!response.ok) { + throw new Error('Failed to fetch metadata'); + } + + await operationComplete; + + if (typeof resetAndReloadFunction === 'function') { + await resetAndReloadFunction(); + } + + } catch (error) { + console.error('Error fetching metadata:', error); + showToast('Failed to fetch metadata: ' + error.message, 'error'); + } finally { + if (ws) { + ws.close(); + } + } + }, { + initialMessage: 'Connecting...', + completionMessage: 'Metadata update complete' + }); +} + +// Generic function to refresh single model metadata +export async function refreshSingleModelMetadata(filePath, modelType = 'lora') { + try { + state.loadingManager.showSimpleLoading('Refreshing metadata...'); + + const endpoint = modelType === 'checkpoint' + ? '/api/checkpoints/fetch-civitai' + : '/api/fetch-civitai'; + + const response = await fetch(endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ file_path: filePath }) + }); + + if (!response.ok) { + throw new Error('Failed to refresh metadata'); + } + + const data = await response.json(); + + if (data.success) { + showToast('Metadata refreshed successfully', 'success'); + return true; + } else { + throw new Error(data.error || 'Failed to refresh metadata'); + } + } catch (error) { + console.error('Error refreshing metadata:', error); + showToast(error.message, 'error'); + return false; + } finally { + state.loadingManager.hide(); + state.loadingManager.restoreProgressBar(); + } +} + +// Private methods + +// Upload a preview image +async function uploadPreview(filePath, file, modelType = 'lora') { + const loadingOverlay = document.getElementById('loading-overlay'); + const loadingStatus = document.querySelector('.loading-status'); + + try { + if 
(loadingOverlay) loadingOverlay.style.display = 'flex'; + if (loadingStatus) loadingStatus.textContent = 'Uploading preview...'; + + const formData = new FormData(); + + // Use appropriate parameter names and endpoint based on model type + // Prepare common form data + formData.append('preview_file', file); + formData.append('model_path', filePath); + + // Set endpoint based on model type + const endpoint = modelType === 'checkpoint' + ? '/api/checkpoints/replace-preview' + : '/api/replace_preview'; + + const response = await fetch(endpoint, { + method: 'POST', + body: formData + }); + + if (!response.ok) { + throw new Error('Upload failed'); + } + + const data = await response.json(); + + // Update the card preview in UI + const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); + if (card) { + const previewContainer = card.querySelector('.card-preview'); + const oldPreview = previewContainer.querySelector('img, video'); + + // For LoRA models, use timestamp to prevent caching + if (modelType === 'lora') { + state.previewVersions?.set(filePath, Date.now()); + } + + const timestamp = Date.now(); + const previewUrl = data.preview_url ? 
+ `${data.preview_url}?t=${timestamp}` : + `/api/model/preview_image?path=${encodeURIComponent(filePath)}&t=${timestamp}`; + + // Create appropriate element based on file type + if (file.type.startsWith('video/')) { + const video = document.createElement('video'); + video.controls = true; + video.autoplay = true; + video.muted = true; + video.loop = true; + video.src = previewUrl; + oldPreview.replaceWith(video); + } else { + const img = document.createElement('img'); + img.src = previewUrl; + oldPreview.replaceWith(img); + } + + showToast('Preview updated successfully', 'success'); + } + } catch (error) { + console.error('Error uploading preview:', error); + showToast('Failed to upload preview image', 'error'); + } finally { + if (loadingOverlay) loadingOverlay.style.display = 'none'; + } +} + +// Private function to perform the delete operation +async function performDelete(filePath, modelType = 'lora') { + try { + showToast(`Deleting ${modelType}...`, 'info'); + + const response = await fetch('/api/model/delete', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + file_path: filePath, + model_type: modelType + }) + }); + + if (!response.ok) { + throw new Error(`Failed to delete ${modelType}: ${response.status} ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success) { + // Remove the card from UI + const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); + if (card) { + card.remove(); + } + + showToast(`${modelType} deleted successfully`, 'success'); + } else { + throw new Error(data.error || `Failed to delete ${modelType}`); + } + } catch (error) { + console.error(`Error deleting ${modelType}:`, error); + showToast(`Failed to delete ${modelType}: ${error.message}`, 'error'); + } +} + +// Helper function to get session item - import if available, otherwise provide fallback +function getSessionItem(key) { + if (typeof window !== 'undefined' && 
window.sessionStorage) { + const item = window.sessionStorage.getItem(key); + try { + return item ? JSON.parse(item) : null; + } catch (e) { + return item; + } + } + return null; +} \ No newline at end of file diff --git a/static/js/api/checkpointApi.js b/static/js/api/checkpointApi.js index e0ed5d4f..8b243be9 100644 --- a/static/js/api/checkpointApi.js +++ b/static/js/api/checkpointApi.js @@ -1,330 +1,57 @@ -import { state, getCurrentPageState } from '../state/index.js'; -import { showToast } from '../utils/uiHelpers.js'; -import { confirmDelete } from '../utils/modalUtils.js'; import { createCheckpointCard } from '../components/CheckpointCard.js'; +import { + loadMoreModels, + resetAndReload as baseResetAndReload, + refreshModels as baseRefreshModels, + deleteModel as baseDeleteModel, + replaceModelPreview, + fetchCivitaiMetadata +} from './baseModelApi.js'; // Load more checkpoints with pagination export async function loadMoreCheckpoints(resetPagination = true) { - try { - const pageState = getCurrentPageState(); - - // Don't load if we're already loading or there are no more items - if (pageState.isLoading || (!resetPagination && !pageState.hasMore)) { - return; - } - - // Set loading state - pageState.isLoading = true; - document.body.classList.add('loading'); - - // Reset pagination if requested - if (resetPagination) { - pageState.currentPage = 1; - const grid = document.getElementById('checkpointGrid'); - if (grid) grid.innerHTML = ''; - } - - // Build API URL with parameters - const params = new URLSearchParams({ - page: pageState.currentPage, - page_size: pageState.pageSize || 20, - sort: pageState.sortBy || 'name' - }); - - // Add folder filter if active - if (pageState.activeFolder) { - params.append('folder', pageState.activeFolder); - } - - // Add search if available - if (pageState.filters && pageState.filters.search) { - params.append('search', pageState.filters.search); - - // Add search options - if (pageState.searchOptions) { - 
params.append('search_filename', pageState.searchOptions.filename.toString()); - params.append('search_modelname', pageState.searchOptions.modelname.toString()); - params.append('recursive', pageState.searchOptions.recursive.toString()); - } - } - - // Add base model filters - if (pageState.filters && pageState.filters.baseModel && pageState.filters.baseModel.length > 0) { - pageState.filters.baseModel.forEach(model => { - params.append('base_model', model); - }); - } - - // Add tags filters - if (pageState.filters && pageState.filters.tags && pageState.filters.tags.length > 0) { - pageState.filters.tags.forEach(tag => { - params.append('tag', tag); - }); - } - - // Execute fetch - const response = await fetch(`/api/checkpoints?${params.toString()}`); - - if (!response.ok) { - throw new Error(`Failed to load checkpoints: ${response.status} ${response.statusText}`); - } - - const data = await response.json(); - - // Update state with response data - pageState.hasMore = data.page < data.total_pages; - - // Update UI with checkpoints - const grid = document.getElementById('checkpointGrid'); - if (!grid) { - return; - } - - // Clear grid if this is the first page - if (resetPagination) { - grid.innerHTML = ''; - } - - // Check for empty result - if (data.items.length === 0 && resetPagination) { - grid.innerHTML = ` -
-

No checkpoints found

-

Add checkpoints to your models folders to see them here.

-
- `; - return; - } - - // Render checkpoint cards - data.items.forEach(checkpoint => { - const card = createCheckpointCard(checkpoint); - grid.appendChild(card); - }); - - // Increment the page number AFTER successful loading - if (data.items.length > 0) { - pageState.currentPage++; - } - } catch (error) { - console.error('Error loading checkpoints:', error); - showToast('Failed to load checkpoints', 'error'); - } finally { - // Clear loading state - const pageState = getCurrentPageState(); - pageState.isLoading = false; - document.body.classList.remove('loading'); - } + return loadMoreModels({ + resetPage: resetPagination, + updateFolders: true, + modelType: 'checkpoint', + createCardFunction: createCheckpointCard, + endpoint: '/api/checkpoints' + }); } // Reset and reload checkpoints export async function resetAndReload() { - const pageState = getCurrentPageState(); - pageState.currentPage = 1; - pageState.hasMore = true; - await loadMoreCheckpoints(true); + return baseResetAndReload({ + updateFolders: true, + modelType: 'checkpoint', + loadMoreFunction: loadMoreCheckpoints + }); } // Refresh checkpoints export async function refreshCheckpoints() { - try { - showToast('Scanning for checkpoints...', 'info'); - const response = await fetch('/api/checkpoints/scan'); - - if (!response.ok) { - throw new Error(`Failed to scan checkpoints: ${response.status} ${response.statusText}`); - } - - await resetAndReload(); - showToast('Checkpoints refreshed successfully', 'success'); - } catch (error) { - console.error('Error refreshing checkpoints:', error); - showToast('Failed to refresh checkpoints', 'error'); - } + return baseRefreshModels({ + modelType: 'checkpoint', + scanEndpoint: '/api/checkpoints/scan', + resetAndReloadFunction: resetAndReload + }); } // Delete a checkpoint export function deleteCheckpoint(filePath) { - confirmDelete('Are you sure you want to delete this checkpoint?', () => { - _performDelete(filePath); - }); -} - -// Private function to perform the 
delete operation -async function _performDelete(filePath) { - try { - showToast('Deleting checkpoint...', 'info'); - - const response = await fetch('/api/model/delete', { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ - file_path: filePath, - model_type: 'checkpoint' - }) - }); - - if (!response.ok) { - throw new Error(`Failed to delete checkpoint: ${response.status} ${response.statusText}`); - } - - const data = await response.json(); - - if (data.success) { - // Remove the card from UI - const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (card) { - card.remove(); - } - - showToast('Checkpoint deleted successfully', 'success'); - } else { - throw new Error(data.error || 'Failed to delete checkpoint'); - } - } catch (error) { - console.error('Error deleting checkpoint:', error); - showToast(`Failed to delete checkpoint: ${error.message}`, 'error'); - } + return baseDeleteModel(filePath, 'checkpoint'); } // Replace checkpoint preview export function replaceCheckpointPreview(filePath) { - // Open file picker - const input = document.createElement('input'); - input.type = 'file'; - input.accept = 'image/*'; - input.onchange = async (e) => { - if (!e.target.files.length) return; - - const file = e.target.files[0]; - await _uploadPreview(filePath, file); - }; - input.click(); -} - -// Upload a preview image -async function _uploadPreview(filePath, file) { - try { - showToast('Uploading preview...', 'info'); - - const formData = new FormData(); - formData.append('file', file); - formData.append('file_path', filePath); - formData.append('model_type', 'checkpoint'); - - const response = await fetch('/api/model/preview', { - method: 'POST', - body: formData - }); - - if (!response.ok) { - throw new Error(`Failed to upload preview: ${response.status} ${response.statusText}`); - } - - const data = await response.json(); - - if (data.success) { - // Update the preview in UI - const card = 
document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - if (card) { - const img = card.querySelector('.card-preview img'); - if (img) { - // Add timestamp to prevent caching - const timestamp = new Date().getTime(); - if (data.preview_url) { - img.src = `${data.preview_url}?t=${timestamp}`; - } else { - img.src = `/api/model/preview_image?path=${encodeURIComponent(filePath)}&t=${timestamp}`; - } - } - } - - showToast('Preview updated successfully', 'success'); - } else { - throw new Error(data.error || 'Failed to update preview'); - } - } catch (error) { - console.error('Error updating preview:', error); - showToast(`Failed to update preview: ${error.message}`, 'error'); - } + return replaceModelPreview(filePath, 'checkpoint'); } // Fetch metadata from Civitai for checkpoints export async function fetchCivitai() { - let ws = null; - - await state.loadingManager.showWithProgress(async (loading) => { - try { - const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; - const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); - - const operationComplete = new Promise((resolve, reject) => { - ws.onmessage = (event) => { - const data = JSON.parse(event.data); - - switch(data.status) { - case 'started': - loading.setStatus('Starting metadata fetch...'); - break; - - case 'processing': - const percent = ((data.processed / data.total) * 100).toFixed(1); - loading.setProgress(percent); - loading.setStatus( - `Processing (${data.processed}/${data.total}) ${data.current_name}` - ); - break; - - case 'completed': - loading.setProgress(100); - loading.setStatus( - `Completed: Updated ${data.success} of ${data.processed} checkpoints` - ); - resolve(); - break; - - case 'error': - reject(new Error(data.error)); - break; - } - }; - - ws.onerror = (error) => { - reject(new Error('WebSocket error: ' + error.message)); - }; - }); - - await new Promise((resolve, reject) => { - ws.onopen = resolve; - ws.onerror = reject; - 
}); - - const response = await fetch('/api/checkpoints/fetch-all-civitai', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ model_type: 'checkpoint' }) // Specify we're fetching checkpoint metadata - }); - - if (!response.ok) { - throw new Error('Failed to fetch metadata'); - } - - await operationComplete; - - await resetAndReload(); - - } catch (error) { - console.error('Error fetching metadata:', error); - showToast('Failed to fetch metadata: ' + error.message, 'error'); - } finally { - if (ws) { - ws.close(); - } - } - }, { - initialMessage: 'Connecting...', - completionMessage: 'Metadata update complete' + return fetchCivitaiMetadata({ + modelType: 'checkpoint', + fetchEndpoint: '/api/checkpoints/fetch-all-civitai', + resetAndReloadFunction: resetAndReload }); } \ No newline at end of file diff --git a/static/js/api/loraApi.js b/static/js/api/loraApi.js index c344e930..5e433799 100644 --- a/static/js/api/loraApi.js +++ b/static/js/api/loraApi.js @@ -1,285 +1,38 @@ -import { state, getCurrentPageState } from '../state/index.js'; -import { showToast } from '../utils/uiHelpers.js'; import { createLoraCard } from '../components/LoraCard.js'; -import { initializeInfiniteScroll } from '../utils/infiniteScroll.js'; -import { showDeleteModal } from '../utils/modalUtils.js'; -import { toggleFolder } from '../utils/uiHelpers.js'; -import { getSessionItem } from '../utils/storageHelpers.js'; +import { + loadMoreModels, + resetAndReload as baseResetAndReload, + refreshModels as baseRefreshModels, + deleteModel as baseDeleteModel, + replaceModelPreview, + fetchCivitaiMetadata, + refreshSingleModelMetadata +} from './baseModelApi.js'; export async function loadMoreLoras(resetPage = false, updateFolders = false) { - const pageState = getCurrentPageState(); - - if (pageState.isLoading || (!pageState.hasMore && !resetPage)) return; - - pageState.isLoading = true; - try { - // Reset to first page if requested - if (resetPage) { - 
pageState.currentPage = 1; - // Clear grid if resetting - const grid = document.getElementById('loraGrid'); - if (grid) grid.innerHTML = ''; - } - - const params = new URLSearchParams({ - page: pageState.currentPage, - page_size: 20, - sort_by: pageState.sortBy - }); - - if (pageState.activeFolder !== null) { - params.append('folder', pageState.activeFolder); - } - - // Add search parameters if there's a search term - if (pageState.filters?.search) { - params.append('search', pageState.filters.search); - params.append('fuzzy', 'true'); - - // Add search option parameters if available - if (pageState.searchOptions) { - params.append('search_filename', pageState.searchOptions.filename.toString()); - params.append('search_modelname', pageState.searchOptions.modelname.toString()); - params.append('search_tags', (pageState.searchOptions.tags || false).toString()); - params.append('recursive', (pageState.searchOptions?.recursive ?? false).toString()); - } - } - - // Add filter parameters if active - if (pageState.filters) { - if (pageState.filters.tags && pageState.filters.tags.length > 0) { - // Convert the array of tags to a comma-separated string - params.append('tags', pageState.filters.tags.join(',')); - } - if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) { - // Convert the array of base models to a comma-separated string - params.append('base_models', pageState.filters.baseModel.join(',')); - } - } - - // Check for recipe-based filtering parameters from session storage - const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash'); - const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes'); - - // Add hash filter parameter if present - if (filterLoraHash) { - params.append('lora_hash', filterLoraHash); - } - // Add multiple hashes filter if present - else if (filterLoraHashes) { - try { - if (Array.isArray(filterLoraHashes) && filterLoraHashes.length > 0) { - params.append('lora_hashes', 
filterLoraHashes.join(',')); - } - } catch (error) { - console.error('Error parsing lora hashes from session storage:', error); - } - } - - const response = await fetch(`/api/loras?${params}`); - if (!response.ok) { - throw new Error(`Failed to fetch loras: ${response.statusText}`); - } - - const data = await response.json(); - - if (data.items.length === 0 && pageState.currentPage === 1) { - const grid = document.getElementById('loraGrid'); - grid.innerHTML = '
No loras found in this folder
'; - pageState.hasMore = false; - } else if (data.items.length > 0) { - pageState.hasMore = pageState.currentPage < data.total_pages; - appendLoraCards(data.items); - - // Increment the page number AFTER successful loading - pageState.currentPage++; - } else { - pageState.hasMore = false; - } - - if (updateFolders && data.folders) { - updateFolderTags(data.folders); - } - - } catch (error) { - console.error('Error loading loras:', error); - showToast('Failed to load loras: ' + error.message, 'error'); - } finally { - pageState.isLoading = false; - } -} - -function updateFolderTags(folders) { - const folderTagsContainer = document.querySelector('.folder-tags'); - if (!folderTagsContainer) return; - - // Keep track of currently selected folder - const pageState = getCurrentPageState(); - const currentFolder = pageState.activeFolder; - - // Create HTML for folder tags - const tagsHTML = folders.map(folder => { - const isActive = folder === currentFolder; - return `
${folder}
`; - }).join(''); - - // Update the container - folderTagsContainer.innerHTML = tagsHTML; - - // Reattach click handlers and ensure the active tag is visible - const tags = folderTagsContainer.querySelectorAll('.tag'); - tags.forEach(tag => { - tag.addEventListener('click', toggleFolder); - if (tag.dataset.folder === currentFolder) { - tag.scrollIntoView({ behavior: 'smooth', block: 'nearest' }); - } + return loadMoreModels({ + resetPage, + updateFolders, + modelType: 'lora', + createCardFunction: createLoraCard, + endpoint: '/api/loras' }); } export async function fetchCivitai() { - let ws = null; - - await state.loadingManager.showWithProgress(async (loading) => { - try { - const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; - const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); - - const operationComplete = new Promise((resolve, reject) => { - ws.onmessage = (event) => { - const data = JSON.parse(event.data); - - switch(data.status) { - case 'started': - loading.setStatus('Starting metadata fetch...'); - break; - - case 'processing': - const percent = ((data.processed / data.total) * 100).toFixed(1); - loading.setProgress(percent); - loading.setStatus( - `Processing (${data.processed}/${data.total}) ${data.current_name}` - ); - break; - - case 'completed': - loading.setProgress(100); - loading.setStatus( - `Completed: Updated ${data.success} of ${data.processed} loras` - ); - resolve(); - break; - - case 'error': - reject(new Error(data.error)); - break; - } - }; - - ws.onerror = (error) => { - reject(new Error('WebSocket error: ' + error.message)); - }; - }); - - await new Promise((resolve, reject) => { - ws.onopen = resolve; - ws.onerror = reject; - }); - - const response = await fetch('/api/fetch-all-civitai', { - method: 'POST', - headers: { 'Content-Type': 'application/json' } - }); - - if (!response.ok) { - throw new Error('Failed to fetch metadata'); - } - - await operationComplete; - - await 
resetAndReload(); - - } catch (error) { - console.error('Error fetching metadata:', error); - showToast('Failed to fetch metadata: ' + error.message, 'error'); - } finally { - if (ws) { - ws.close(); - } - } - }, { - initialMessage: 'Connecting...', - completionMessage: 'Metadata update complete' + return fetchCivitaiMetadata({ + modelType: 'lora', + fetchEndpoint: '/api/fetch-all-civitai', + resetAndReloadFunction: resetAndReload }); } export async function deleteModel(filePath) { - showDeleteModal(filePath); + return baseDeleteModel(filePath, 'lora'); } export async function replacePreview(filePath) { - const loadingOverlay = document.getElementById('loading-overlay'); - const loadingStatus = document.querySelector('.loading-status'); - - const input = document.createElement('input'); - input.type = 'file'; - input.accept = 'image/*,video/mp4'; - - input.onchange = async function() { - if (!input.files || !input.files[0]) return; - - const file = input.files[0]; - const formData = new FormData(); - formData.append('preview_file', file); - formData.append('model_path', filePath); - - try { - loadingOverlay.style.display = 'flex'; - loadingStatus.textContent = 'Uploading preview...'; - - const response = await fetch('/api/replace_preview', { - method: 'POST', - body: formData - }); - - if (!response.ok) { - throw new Error('Upload failed'); - } - - const data = await response.json(); - - // 更新预览版本 - state.previewVersions.set(filePath, Date.now()); - - // 更新卡片显示 - const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); - const previewContainer = card.querySelector('.card-preview'); - const oldPreview = previewContainer.querySelector('img, video'); - - const previewUrl = `${data.preview_url}?t=${state.previewVersions.get(filePath)}`; - - if (file.type.startsWith('video/')) { - const video = document.createElement('video'); - video.controls = true; - video.autoplay = true; - video.muted = true; - video.loop = true; - video.src = previewUrl; - 
oldPreview.replaceWith(video); - } else { - const img = document.createElement('img'); - img.src = previewUrl; - oldPreview.replaceWith(img); - } - - } catch (error) { - console.error('Error uploading preview:', error); - alert('Failed to upload preview image'); - } finally { - loadingOverlay.style.display = 'none'; - } - }; - - input.click(); + return replaceModelPreview(filePath, 'lora'); } export function appendLoraCards(loras) { @@ -293,57 +46,26 @@ export function appendLoraCards(loras) { } export async function resetAndReload(updateFolders = false) { - const pageState = getCurrentPageState(); - console.log('Resetting with state:', { ...pageState }); - - // Reset pagination and load more loras - await loadMoreLoras(true, updateFolders); + return baseResetAndReload({ + updateFolders, + modelType: 'lora', + loadMoreFunction: loadMoreLoras + }); } export async function refreshLoras() { - try { - state.loadingManager.showSimpleLoading('Refreshing loras...'); - await resetAndReload(); - showToast('Refresh complete', 'success'); - } catch (error) { - console.error('Refresh failed:', error); - showToast('Failed to refresh loras', 'error'); - } finally { - state.loadingManager.hide(); - state.loadingManager.restoreProgressBar(); - } + return baseRefreshModels({ + modelType: 'lora', + scanEndpoint: '/api/loras/scan', + resetAndReloadFunction: resetAndReload + }); } export async function refreshSingleLoraMetadata(filePath) { - try { - state.loadingManager.showSimpleLoading('Refreshing metadata...'); - const response = await fetch('/api/fetch-civitai', { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ file_path: filePath }) - }); - - if (!response.ok) { - throw new Error('Failed to refresh metadata'); - } - - const data = await response.json(); - - if (data.success) { - showToast('Metadata refreshed successfully', 'success'); - // Reload the current view to show updated data - await resetAndReload(); - } else { - throw 
new Error(data.error || 'Failed to refresh metadata'); - } - } catch (error) { - console.error('Error refreshing metadata:', error); - showToast(error.message, 'error'); - } finally { - state.loadingManager.hide(); - state.loadingManager.restoreProgressBar(); + const success = await refreshSingleModelMetadata(filePath, 'lora'); + if (success) { + // Reload the current view to show updated data + await resetAndReload(); } } diff --git a/static/js/components/CheckpointCard.js b/static/js/components/CheckpointCard.js index abb95afe..c000ecfc 100644 --- a/static/js/components/CheckpointCard.js +++ b/static/js/components/CheckpointCard.js @@ -2,6 +2,7 @@ import { showToast } from '../utils/uiHelpers.js'; import { state } from '../state/index.js'; import { showCheckpointModal } from './checkpointModal/index.js'; import { NSFW_LEVELS } from '../utils/constants.js'; +import { replaceCheckpointPreview as apiReplaceCheckpointPreview } from '../api/checkpointApi.js'; export function createCheckpointCard(checkpoint) { const card = document.createElement('div'); @@ -305,6 +306,6 @@ function replaceCheckpointPreview(filePath) { if (window.replaceCheckpointPreview) { window.replaceCheckpointPreview(filePath); } else { - console.log('Replace checkpoint preview:', filePath); + apiReplaceCheckpointPreview(filePath); } } \ No newline at end of file From 1db49a4dd43643a4c9917d3ccfaf5945609a7f43 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 11 Apr 2025 18:25:37 +0800 Subject: [PATCH 23/36] refactor: Enhance checkpoint download functionality with new modal and manager integration --- py/routes/checkpoints_routes.py | 36 ++ py/services/download_manager.py | 87 ++-- py/utils/routes_common.py | 81 +++- static/js/checkpoints.js | 4 + .../controls/CheckpointsControls.js | 9 + static/js/components/controls/PageControls.js | 18 +- .../js/managers/CheckpointDownloadManager.js | 423 ++++++++++++++++++ static/js/managers/ModalManager.js | 13 + 
templates/components/checkpoint_modals.html | 69 +++ templates/components/lora_modals.html | 2 +- 10 files changed, 699 insertions(+), 43 deletions(-) create mode 100644 static/js/managers/CheckpointDownloadManager.js diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index e65bea5b..1d3627a5 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -3,12 +3,14 @@ import json import jinja2 from aiohttp import web import logging +import asyncio from ..utils.routes_common import ModelRouteUtils from ..utils.constants import NSFW_LEVELS from ..services.civitai_client import CivitaiClient from ..services.websocket_manager import ws_manager from ..services.checkpoint_scanner import CheckpointScanner +from ..services.download_manager import DownloadManager from ..config import config from ..services.settings_manager import settings from ..utils.utils import fuzzy_match @@ -24,6 +26,8 @@ class CheckpointsRoutes: loader=jinja2.FileSystemLoader(config.templates_path), autoescape=True ) + self.download_manager = DownloadManager() + self._download_lock = asyncio.Lock() def setup_routes(self, app): """Register routes with the aiohttp app""" @@ -34,11 +38,13 @@ class CheckpointsRoutes: app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags) app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info) + app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots) # Add new routes for model management similar to LoRA routes app.router.add_post('/api/checkpoints/delete', self.delete_model) app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai) app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview) + app.router.add_post('/api/checkpoints/download', self.download_checkpoint) async def get_checkpoints(self, request): """Get paginated checkpoint data""" @@ -478,3 +484,33 @@ class 
CheckpointsRoutes: async def replace_preview(self, request: web.Request) -> web.Response: """Handle preview image replacement for checkpoints""" return await ModelRouteUtils.handle_replace_preview(request, self.scanner) + + async def download_checkpoint(self, request: web.Request) -> web.Response: + """Handle checkpoint download request""" + async with self._download_lock: + # Initialize DownloadManager with the file monitor if the scanner has one + if not hasattr(self, 'download_manager') or self.download_manager is None: + file_monitor = getattr(self.scanner, 'file_monitor', None) + self.download_manager = DownloadManager(file_monitor) + + # Use the common download handler with model_type="checkpoint" + return await ModelRouteUtils.handle_download_model( + request=request, + download_manager=self.download_manager, + model_type="checkpoint" + ) + + async def get_checkpoint_roots(self, request): + """Return the checkpoint root directories""" + try: + roots = self.scanner.get_model_roots() + return web.json_response({ + "success": True, + "roots": roots + }) + except Exception as e: + logger.error(f"Error getting checkpoint roots: {e}", exc_info=True) + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 1dc2a945..91786336 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -4,7 +4,7 @@ import json from typing import Optional, Dict from .civitai_client import CivitaiClient from .file_monitor import LoraFileMonitor -from ..utils.models import LoraMetadata +from ..utils.models import LoraMetadata, CheckpointMetadata from ..utils.constants import CARD_PREVIEW_WIDTH from ..utils.exif_utils import ExifUtils @@ -20,7 +20,22 @@ class DownloadManager: async def download_from_civitai(self, download_url: str = None, model_hash: str = None, model_version_id: str = None, save_dir: str = None, - relative_path: str = '', 
progress_callback=None) -> Dict: + relative_path: str = '', progress_callback=None, + model_type: str = "lora") -> Dict: + """Download model from Civitai + + Args: + download_url: Direct download URL for the model + model_hash: SHA256 hash of the model + model_version_id: Civitai model version ID + save_dir: Directory to save the model to + relative_path: Relative path within save_dir + progress_callback: Callback function for progress updates + model_type: Type of model ('lora' or 'checkpoint') + + Returns: + Dict with download result + """ try: # Update save directory with relative path if provided if relative_path: @@ -46,7 +61,7 @@ class DownloadManager: if not version_info: return {'success': False, 'error': 'Failed to fetch model metadata'} - # Check if this is an early access LoRA + # Check if this is an early access model if version_info.get('earlyAccessEndsAt'): early_access_date = version_info.get('earlyAccessEndsAt', '') # Convert to a readable date if possible @@ -54,12 +69,12 @@ class DownloadManager: from datetime import datetime date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00')) formatted_date = date_obj.strftime('%Y-%m-%d') - early_access_msg = f"This LoRA requires early access payment (until {formatted_date}). " + early_access_msg = f"This model requires early access payment (until {formatted_date}). " except: - early_access_msg = "This LoRA requires early access payment. " + early_access_msg = "This model requires early access payment. " early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai." - logger.warning(f"Early access LoRA detected: {version_info.get('name', 'Unknown')}") + logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}") # We'll still try to download, but log a warning and prepare for potential failure if progress_callback: @@ -69,26 +84,32 @@ class DownloadManager: if progress_callback: await progress_callback(0) - # 2. 获取文件信息 + # 2. 
Get file information file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) if not file_info: return {'success': False, 'error': 'No primary file found in metadata'} - # 3. 准备下载 + # 3. Prepare download file_name = file_info['name'] save_path = os.path.join(save_dir, file_name) file_size = file_info.get('sizeKB', 0) * 1024 - # 4. 通知文件监控系统 - 使用规范化路径和文件大小 - self.file_monitor.handler.add_ignore_path( - save_path.replace(os.sep, '/'), - file_size - ) + # 4. Notify file monitor - use normalized path and file size + if self.file_monitor and self.file_monitor.handler: + self.file_monitor.handler.add_ignore_path( + save_path.replace(os.sep, '/'), + file_size + ) - # 5. 准备元数据 - metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) + # 5. Prepare metadata based on model type + if model_type == "checkpoint": + metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path) + logger.info(f"Creating CheckpointMetadata for {file_name}") + else: + metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) + logger.info(f"Creating LoraMetadata for {file_name}") - # 5.1 获取并更新模型标签和描述信息 + # 5.1 Get and update model tags and description model_id = version_info.get('modelId') if model_id: model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id)) @@ -98,14 +119,15 @@ class DownloadManager: if model_metadata.get("description"): metadata.modelDescription = model_metadata.get("description", "") - # 6. 开始下载流程 + # 6. 
Start download process result = await self._execute_download( download_url=file_info.get('downloadUrl', ''), save_dir=save_dir, metadata=metadata, version_info=version_info, relative_path=relative_path, - progress_callback=progress_callback + progress_callback=progress_callback, + model_type=model_type ) return result @@ -119,8 +141,9 @@ class DownloadManager: return {'success': False, 'error': str(e)} async def _execute_download(self, download_url: str, save_dir: str, - metadata: LoraMetadata, version_info: Dict, - relative_path: str, progress_callback=None) -> Dict: + metadata, version_info: Dict, + relative_path: str, progress_callback=None, + model_type: str = "lora") -> Dict: """Execute the actual download process including preview images and model files""" try: save_path = metadata.file_path @@ -201,15 +224,21 @@ class DownloadManager: os.remove(path) return {'success': False, 'error': result} - # 4. 更新文件信息(大小和修改时间) + # 4. Update file information (size and modified time) metadata.update_file_info(save_path) - # 5. 最终更新元数据 + # 5. Final metadata update with open(metadata_path, 'w', encoding='utf-8') as f: json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) - # 6. update lora cache - cache = await self.file_monitor.scanner.get_cached_data() + # 6. 
Update cache based on model type + if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"): + cache = await self.file_monitor.checkpoint_scanner.get_cached_data() + logger.info(f"Updating checkpoint cache for {save_path}") + else: + cache = await self.file_monitor.scanner.get_cached_data() + logger.info(f"Updating lora cache for {save_path}") + metadata_dict = metadata.to_dict() metadata_dict['folder'] = relative_path cache.raw_data.append(metadata_dict) @@ -218,11 +247,11 @@ class DownloadManager: all_folders.add(relative_path) cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) - # Update the hash index with the new LoRA entry - self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) - - # Update the hash index with the new LoRA entry - self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) + # Update the hash index with the new model entry + if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"): + self.file_monitor.checkpoint_scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) + else: + self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) # Report 100% completion if progress_callback: diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 5b68b368..6c0dc8d7 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -9,6 +9,7 @@ from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH from ..config import config from ..services.civitai_client import CivitaiClient from ..utils.exif_utils import ExifUtils +from ..services.download_manager import DownloadManager logger = logging.getLogger(__name__) @@ -421,4 +422,82 @@ class ModelRouteUtils: except Exception as e: logger.error(f"Error replacing preview: {e}", exc_info=True) - return web.Response(text=str(e), status=500) \ No newline at end of 
file + return web.Response(text=str(e), status=500) + + @staticmethod + async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response: + """Handle model download request + + Args: + request: The aiohttp request + download_manager: Instance of DownloadManager + model_type: Type of model ('lora' or 'checkpoint') + + Returns: + web.Response: The HTTP response + """ + try: + data = await request.json() + + # Create progress callback + async def progress_callback(progress): + from ..services.websocket_manager import ws_manager + await ws_manager.broadcast({ + 'status': 'progress', + 'progress': progress + }) + + # Check which identifier is provided + download_url = data.get('download_url') + model_hash = data.get('model_hash') + model_version_id = data.get('model_version_id') + + # Validate that at least one identifier is provided + if not any([download_url, model_hash, model_version_id]): + return web.Response( + status=400, + text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'" + ) + + # Use the correct root directory based on model type + root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root' + save_dir = data.get(root_key) + + result = await download_manager.download_from_civitai( + download_url=download_url, + model_hash=model_hash, + model_version_id=model_version_id, + save_dir=save_dir, + relative_path=data.get('relative_path', ''), + progress_callback=progress_callback, + model_type=model_type + ) + + if not result.get('success', False): + error_message = result.get('error', 'Unknown error') + + # Return 401 for early access errors + if 'early access' in error_message.lower(): + logger.warning(f"Early access download failed: {error_message}") + return web.Response( + status=401, # Use 401 status code to match Civitai's response + text=f"Early Access Restriction: {error_message}" + ) + + return web.Response(status=500, 
text=error_message) + + return web.json_response(result) + + except Exception as e: + error_message = str(e) + + # Check if this might be an early access error + if '401' in error_message: + logger.warning(f"Early access error (401): {error_message}") + return web.Response( + status=401, + text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com." + ) + + logger.error(f"Error downloading {model_type}: {error_message}") + return web.Response(status=500, text=error_message) \ No newline at end of file diff --git a/static/js/checkpoints.js b/static/js/checkpoints.js index 72342ed7..2f1d316f 100644 --- a/static/js/checkpoints.js +++ b/static/js/checkpoints.js @@ -3,6 +3,7 @@ import { initializeInfiniteScroll } from './utils/infiniteScroll.js'; import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; import { createPageControls } from './components/controls/index.js'; import { loadMoreCheckpoints } from './api/checkpointApi.js'; +import { CheckpointDownloadManager } from './managers/CheckpointDownloadManager.js'; // Initialize the Checkpoints page class CheckpointsPageManager { @@ -10,6 +11,9 @@ class CheckpointsPageManager { // Initialize page controls this.pageControls = createPageControls('checkpoints'); + // Initialize checkpoint download manager + window.checkpointDownloadManager = new CheckpointDownloadManager(); + // Expose only necessary functions to global scope this._exposeRequiredGlobalFunctions(); } diff --git a/static/js/components/controls/CheckpointsControls.js b/static/js/components/controls/CheckpointsControls.js index 44c6104a..8cc323f1 100644 --- a/static/js/components/controls/CheckpointsControls.js +++ b/static/js/components/controls/CheckpointsControls.js @@ -2,6 +2,7 @@ import { PageControls } from './PageControls.js'; import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js'; import { showToast } from '../../utils/uiHelpers.js'; 
+import { CheckpointDownloadManager } from '../../managers/CheckpointDownloadManager.js'; /** * CheckpointsControls class - Extends PageControls for Checkpoint-specific functionality @@ -11,6 +12,9 @@ export class CheckpointsControls extends PageControls { // Initialize with 'checkpoints' page type super('checkpoints'); + // Initialize checkpoint download manager + this.downloadManager = new CheckpointDownloadManager(); + // Register API methods specific to the Checkpoints page this.registerCheckpointsAPI(); } @@ -38,6 +42,11 @@ export class CheckpointsControls extends PageControls { return await fetchCivitai(); }, + // Add show download modal functionality + showDownloadModal: () => { + this.downloadManager.showDownloadModal(); + }, + // No clearCustomFilter implementation is needed for checkpoints // as custom filters are currently only used for LoRAs clearCustomFilter: async () => { diff --git a/static/js/components/controls/PageControls.js b/static/js/components/controls/PageControls.js index 0bc4f64e..77599fca 100644 --- a/static/js/components/controls/PageControls.js +++ b/static/js/components/controls/PageControls.js @@ -103,13 +103,12 @@ export class PageControls { fetchButton.addEventListener('click', () => this.fetchFromCivitai()); } + const downloadButton = document.querySelector('[data-action="download"]'); + if (downloadButton) { + downloadButton.addEventListener('click', () => this.showDownloadModal()); + } + if (this.pageType === 'loras') { - // Download button - LoRAs only - const downloadButton = document.querySelector('[data-action="download"]'); - if (downloadButton) { - downloadButton.addEventListener('click', () => this.showDownloadModal()); - } - // Bulk operations button - LoRAs only const bulkButton = document.querySelector('[data-action="bulk"]'); if (bulkButton) { @@ -349,14 +348,9 @@ export class PageControls { } /** - * Show download modal (LoRAs only) + * Show download modal */ showDownloadModal() { - if (this.pageType !== 'loras' || 
!this.api) { - console.error('Download modal is only available for LoRAs'); - return; - } - this.api.showDownloadModal(); } diff --git a/static/js/managers/CheckpointDownloadManager.js b/static/js/managers/CheckpointDownloadManager.js new file mode 100644 index 00000000..5dfb235c --- /dev/null +++ b/static/js/managers/CheckpointDownloadManager.js @@ -0,0 +1,423 @@ +import { modalManager } from './ModalManager.js'; +import { showToast } from '../utils/uiHelpers.js'; +import { LoadingManager } from './LoadingManager.js'; +import { state } from '../state/index.js'; +import { resetAndReload } from '../api/checkpointApi.js'; +import { getStorageItem } from '../utils/storageHelpers.js'; + +export class CheckpointDownloadManager { + constructor() { + this.currentVersion = null; + this.versions = []; + this.modelInfo = null; + this.modelVersionId = null; + + this.initialized = false; + this.selectedFolder = ''; + + this.loadingManager = new LoadingManager(); + this.folderClickHandler = null; + this.updateTargetPath = this.updateTargetPath.bind(this); + } + + showDownloadModal() { + console.log('Showing checkpoint download modal...'); + if (!this.initialized) { + const modal = document.getElementById('checkpointDownloadModal'); + if (!modal) { + console.error('Checkpoint download modal element not found'); + return; + } + this.initialized = true; + } + + modalManager.showModal('checkpointDownloadModal', null, () => { + // Cleanup handler when modal closes + this.cleanupFolderBrowser(); + }); + this.resetSteps(); + } + + resetSteps() { + document.querySelectorAll('#checkpointDownloadModal .download-step').forEach(step => step.style.display = 'none'); + document.getElementById('cpUrlStep').style.display = 'block'; + document.getElementById('checkpointUrl').value = ''; + document.getElementById('cpUrlError').textContent = ''; + + // Clear new folder input + const newFolderInput = document.getElementById('cpNewFolder'); + if (newFolderInput) { + newFolderInput.value = ''; + } + 
+ this.currentVersion = null; + this.versions = []; + this.modelInfo = null; + this.modelVersionId = null; + + // Clear selected folder and remove selection from UI + this.selectedFolder = ''; + const folderBrowser = document.getElementById('cpFolderBrowser'); + if (folderBrowser) { + folderBrowser.querySelectorAll('.folder-item').forEach(f => + f.classList.remove('selected')); + } + } + + async validateAndFetchVersions() { + const url = document.getElementById('checkpointUrl').value.trim(); + const errorElement = document.getElementById('cpUrlError'); + + try { + this.loadingManager.showSimpleLoading('Fetching model versions...'); + + const modelId = this.extractModelId(url); + if (!modelId) { + throw new Error('Invalid Civitai URL format'); + } + + const response = await fetch(`/api/civitai/versions/${modelId}`); + if (!response.ok) { + throw new Error('Failed to fetch model versions'); + } + + this.versions = await response.json(); + if (!this.versions.length) { + throw new Error('No versions available for this model'); + } + + // If we have a version ID from URL, pre-select it + if (this.modelVersionId) { + this.currentVersion = this.versions.find(v => v.id.toString() === this.modelVersionId); + } + + this.showVersionStep(); + } catch (error) { + errorElement.textContent = error.message; + } finally { + this.loadingManager.hide(); + } + } + + extractModelId(url) { + const modelMatch = url.match(/civitai\.com\/models\/(\d+)/); + const versionMatch = url.match(/modelVersionId=(\d+)/); + + if (modelMatch) { + this.modelVersionId = versionMatch ? 
versionMatch[1] : null; + return modelMatch[1]; + } + return null; + } + + showVersionStep() { + document.getElementById('cpUrlStep').style.display = 'none'; + document.getElementById('cpVersionStep').style.display = 'block'; + + const versionList = document.getElementById('cpVersionList'); + versionList.innerHTML = this.versions.map(version => { + const firstImage = version.images?.find(img => !img.url.endsWith('.mp4')); + const thumbnailUrl = firstImage ? firstImage.url : '/loras_static/images/no-preview.png'; + + // Use version-level size or fallback to first file + const fileSize = version.modelSizeKB ? + (version.modelSizeKB / 1024).toFixed(2) : + (version.files[0]?.sizeKB / 1024).toFixed(2); + + // Use version-level existsLocally flag + const existsLocally = version.existsLocally; + const localPath = version.localPath; + + // Check if this is an early access version + const isEarlyAccess = version.availability === 'EarlyAccess'; + + // Create early access badge if needed + let earlyAccessBadge = ''; + if (isEarlyAccess) { + earlyAccessBadge = ` +
+ Early Access +
+ `; + } + + // Status badge for local models + const localStatus = existsLocally ? + `
+ In Library +
${localPath || ''}
+
` : ''; + + return ` +
+
+ Version preview +
+
+
+

${version.name}

+ ${localStatus} +
+
+ ${version.baseModel ? `
${version.baseModel}
` : ''} + ${earlyAccessBadge} +
+
+ ${new Date(version.createdAt).toLocaleDateString()} + ${fileSize} MB +
+
+
+ `; + }).join(''); + + // Update Next button state based on initial selection + this.updateNextButtonState(); + } + + selectVersion(versionId) { + this.currentVersion = this.versions.find(v => v.id.toString() === versionId.toString()); + if (!this.currentVersion) return; + + document.querySelectorAll('#cpVersionList .version-item').forEach(item => { + item.classList.toggle('selected', item.querySelector('h3').textContent === this.currentVersion.name); + }); + + // Update Next button state after selection + this.updateNextButtonState(); + } + + updateNextButtonState() { + const nextButton = document.querySelector('#cpVersionStep .primary-btn'); + if (!nextButton) return; + + const existsLocally = this.currentVersion?.existsLocally; + + if (existsLocally) { + nextButton.disabled = true; + nextButton.classList.add('disabled'); + nextButton.textContent = 'Already in Library'; + } else { + nextButton.disabled = false; + nextButton.classList.remove('disabled'); + nextButton.textContent = 'Next'; + } + } + + async proceedToLocation() { + if (!this.currentVersion) { + showToast('Please select a version', 'error'); + return; + } + + // Double-check if the version exists locally + const existsLocally = this.currentVersion.existsLocally; + if (existsLocally) { + showToast('This version already exists in your library', 'info'); + return; + } + + document.getElementById('cpVersionStep').style.display = 'none'; + document.getElementById('cpLocationStep').style.display = 'block'; + + try { + // Use checkpoint roots endpoint instead of lora roots + const response = await fetch('/api/checkpoints/roots'); + if (!response.ok) { + throw new Error('Failed to fetch checkpoint roots'); + } + + const data = await response.json(); + const checkpointRoot = document.getElementById('checkpointRoot'); + checkpointRoot.innerHTML = data.roots.map(root => + `` + ).join(''); + + // Set default checkpoint root if available + const defaultRoot = getStorageItem('settings', 
{}).default_checkpoints_root; + if (defaultRoot && data.roots.includes(defaultRoot)) { + checkpointRoot.value = defaultRoot; + } + + // Initialize folder browser after loading roots + this.initializeFolderBrowser(); + } catch (error) { + showToast(error.message, 'error'); + } + } + + backToUrl() { + document.getElementById('cpVersionStep').style.display = 'none'; + document.getElementById('cpUrlStep').style.display = 'block'; + } + + backToVersions() { + document.getElementById('cpLocationStep').style.display = 'none'; + document.getElementById('cpVersionStep').style.display = 'block'; + } + + async startDownload() { + const checkpointRoot = document.getElementById('checkpointRoot').value; + const newFolder = document.getElementById('cpNewFolder').value.trim(); + + if (!checkpointRoot) { + showToast('Please select a checkpoint root directory', 'error'); + return; + } + + // Construct relative path + let targetFolder = ''; + if (this.selectedFolder) { + targetFolder = this.selectedFolder; + } + if (newFolder) { + targetFolder = targetFolder ? + `${targetFolder}/${newFolder}` : newFolder; + } + + try { + const downloadUrl = this.currentVersion.downloadUrl; + if (!downloadUrl) { + throw new Error('No download URL available'); + } + + // Show enhanced loading with progress details + const updateProgress = this.loadingManager.showDownloadProgress(1); + updateProgress(0, 0, this.currentVersion.name); + + // Setup WebSocket for progress updates + const wsProtocol = window.location.protocol === 'https:' ? 
'wss://' : 'ws://'; + const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); + + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + if (data.status === 'progress') { + // Update progress display with current progress + updateProgress(data.progress, 0, this.currentVersion.name); + + // Add more detailed status messages based on progress + if (data.progress < 3) { + this.loadingManager.setStatus(`Preparing download...`); + } else if (data.progress === 3) { + this.loadingManager.setStatus(`Downloaded preview image`); + } else if (data.progress > 3 && data.progress < 100) { + this.loadingManager.setStatus(`Downloading checkpoint file`); + } else { + this.loadingManager.setStatus(`Finalizing download...`); + } + } + }; + + ws.onerror = (error) => { + console.error('WebSocket error:', error); + // Continue with download even if WebSocket fails + }; + + // Start download using checkpoint download endpoint + const response = await fetch('/api/checkpoints/download', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + download_url: downloadUrl, + checkpoint_root: checkpointRoot, + relative_path: targetFolder + }) + }); + + if (!response.ok) { + throw new Error(await response.text()); + } + + showToast('Download completed successfully', 'success'); + modalManager.closeModal('checkpointDownloadModal'); + + // Update state and trigger reload with folder update + state.activeFolder = targetFolder; + await resetAndReload(true); // Pass true to update folders + + } catch (error) { + showToast(error.message, 'error'); + } finally { + this.loadingManager.hide(); + } + } + + initializeFolderBrowser() { + const folderBrowser = document.getElementById('cpFolderBrowser'); + if (!folderBrowser) return; + + // Cleanup existing handler if any + this.cleanupFolderBrowser(); + + // Create new handler + this.folderClickHandler = (event) => { + const folderItem = event.target.closest('.folder-item'); + 
if (!folderItem) return; + + if (folderItem.classList.contains('selected')) { + folderItem.classList.remove('selected'); + this.selectedFolder = ''; + } else { + folderBrowser.querySelectorAll('.folder-item').forEach(f => + f.classList.remove('selected')); + folderItem.classList.add('selected'); + this.selectedFolder = folderItem.dataset.folder; + } + + // Update path display after folder selection + this.updateTargetPath(); + }; + + // Add the new handler + folderBrowser.addEventListener('click', this.folderClickHandler); + + // Add event listeners for path updates + const checkpointRoot = document.getElementById('checkpointRoot'); + const newFolder = document.getElementById('cpNewFolder'); + + checkpointRoot.addEventListener('change', this.updateTargetPath); + newFolder.addEventListener('input', this.updateTargetPath); + + // Update initial path + this.updateTargetPath(); + } + + cleanupFolderBrowser() { + if (this.folderClickHandler) { + const folderBrowser = document.getElementById('cpFolderBrowser'); + if (folderBrowser) { + folderBrowser.removeEventListener('click', this.folderClickHandler); + this.folderClickHandler = null; + } + } + + // Remove path update listeners + const checkpointRoot = document.getElementById('checkpointRoot'); + const newFolder = document.getElementById('cpNewFolder'); + + if (checkpointRoot) checkpointRoot.removeEventListener('change', this.updateTargetPath); + if (newFolder) newFolder.removeEventListener('input', this.updateTargetPath); + } + + updateTargetPath() { + const pathDisplay = document.getElementById('cpTargetPathDisplay'); + const checkpointRoot = document.getElementById('checkpointRoot').value; + const newFolder = document.getElementById('cpNewFolder').value.trim(); + + let fullPath = checkpointRoot || 'Select a checkpoint root directory'; + + if (checkpointRoot) { + if (this.selectedFolder) { + fullPath += '/' + this.selectedFolder; + } + if (newFolder) { + fullPath += '/' + newFolder; + } + } + + pathDisplay.innerHTML 
= `${fullPath}`; + } +} \ No newline at end of file diff --git a/static/js/managers/ModalManager.js b/static/js/managers/ModalManager.js index 5c2a2700..989a2806 100644 --- a/static/js/managers/ModalManager.js +++ b/static/js/managers/ModalManager.js @@ -35,6 +35,19 @@ export class ModalManager { closeOnOutsideClick: true }); } + + // Add checkpointDownloadModal registration + const checkpointDownloadModal = document.getElementById('checkpointDownloadModal'); + if (checkpointDownloadModal) { + this.registerModal('checkpointDownloadModal', { + element: checkpointDownloadModal, + onClose: () => { + this.getModal('checkpointDownloadModal').element.style.display = 'none'; + document.body.classList.remove('modal-open'); + }, + closeOnOutsideClick: true + }); + } const deleteModal = document.getElementById('deleteModal'); if (deleteModal) { diff --git a/templates/components/checkpoint_modals.html b/templates/components/checkpoint_modals.html index af971475..0a0ea127 100644 --- a/templates/components/checkpoint_modals.html +++ b/templates/components/checkpoint_modals.html @@ -32,4 +32,73 @@
+ + + + \ No newline at end of file diff --git a/templates/components/lora_modals.html b/templates/components/lora_modals.html index 3bc0ab30..8eeb438b 100644 --- a/templates/components/lora_modals.html +++ b/templates/components/lora_modals.html @@ -4,7 +4,7 @@ - +