From 8fdfb687419731171f0f877ac8575d6e09b0c321 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 10 Apr 2025 09:08:36 +0800 Subject: [PATCH] checkpoint --- py/config.py | 2 +- py/lora_manager.py | 107 ++++-- py/routes/checkpoints_routes.py | 156 +++++++-- py/services/checkpoint_scanner.py | 131 +++++++ py/services/file_monitor.py | 90 ++++- py/services/lora_scanner.py | 426 +++++------------------ py/services/model_cache.py | 64 ++++ py/services/model_hash_index.py | 78 +++++ py/services/model_scanner.py | 554 ++++++++++++++++++++++++++++++ py/utils/file_utils.py | 91 +++-- py/utils/lora_metadata.py | 66 +++- py/utils/models.py | 116 ++++--- 12 files changed, 1397 insertions(+), 484 deletions(-) create mode 100644 py/services/checkpoint_scanner.py create mode 100644 py/services/model_cache.py create mode 100644 py/services/model_hash_index.py create mode 100644 py/services/model_scanner.py diff --git a/py/config.py b/py/config.py index a081f472..b0ca1f4c 100644 --- a/py/config.py +++ b/py/config.py @@ -73,7 +73,7 @@ class Config: """添加静态路由映射""" normalized_path = os.path.normpath(path).replace(os.sep, '/') self._route_mappings[normalized_path] = route - logger.info(f"Added route mapping: {normalized_path} -> {route}") + # logger.info(f"Added route mapping: {normalized_path} -> {route}") def map_path_to_link(self, path: str) -> str: """将目标路径映射回符号链接路径""" diff --git a/py/lora_manager.py b/py/lora_manager.py index a7ab3fb8..14556fa4 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -7,10 +7,12 @@ from .routes.api_routes import ApiRoutes from .routes.recipe_routes import RecipeRoutes from .routes.checkpoints_routes import CheckpointsRoutes from .services.lora_scanner import LoraScanner +from .services.checkpoint_scanner import CheckpointScanner from .services.recipe_scanner import RecipeScanner -from .services.file_monitor import LoraFileMonitor +from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor from 
.services.lora_cache import LoraCache from .services.recipe_cache import RecipeCache +from .services.model_cache import ModelCache import logging logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ class LoraManager: """Initialize and register all routes""" app = PromptServer.instance.app - added_targets = set() # 用于跟踪已添加的目标路径 + added_targets = set() # Track already added target paths # Add static routes for each lora root for idx, root in enumerate(config.loras_roots, start=1): @@ -35,15 +37,34 @@ class LoraManager: if link == root: real_root = target break - # 为原始路径添加静态路由 + # Add static route for original path app.router.add_static(preview_path, real_root) logger.info(f"Added static route {preview_path} -> {real_root}") - # 记录路由映射 + # Record route mapping config.add_route_mapping(real_root, preview_path) added_targets.add(real_root) - # 为符号链接的目标路径添加额外的静态路由 + # Add static routes for each checkpoint root + checkpoint_scanner = CheckpointScanner() + for idx, root in enumerate(checkpoint_scanner.get_model_roots(), start=1): + preview_path = f'/checkpoints_static/root{idx}/preview' + + real_root = root + if root in config._path_mappings.values(): + for target, link in config._path_mappings.items(): + if link == root: + real_root = target + break + # Add static route for original path + app.router.add_static(preview_path, real_root) + logger.info(f"Added static route {preview_path} -> {real_root}") + + # Record route mapping + config.add_route_mapping(real_root, preview_path) + added_targets.add(real_root) + + # Add static routes for symlink target paths link_idx = 1 for target_path, link_path in config._path_mappings.items(): @@ -59,37 +80,47 @@ class LoraManager: app.router.add_static('/loras_static', config.static_path) # Setup feature routes - routes = LoraRoutes() + lora_routes = LoraRoutes() checkpoints_routes = CheckpointsRoutes() # Setup file monitoring - monitor = LoraFileMonitor(routes.scanner, config.loras_roots) - monitor.start() + lora_monitor = 
LoraFileMonitor(lora_routes.scanner, config.loras_roots) + lora_monitor.start() - routes.setup_routes(app) + checkpoint_monitor = CheckpointFileMonitor(checkpoints_routes.scanner, checkpoints_routes.scanner.get_model_roots()) + checkpoint_monitor.start() + + lora_routes.setup_routes(app) checkpoints_routes.setup_routes(app) - ApiRoutes.setup_routes(app, monitor) + ApiRoutes.setup_routes(app, lora_monitor) RecipeRoutes.setup_routes(app) - # Store monitor in app for cleanup - app['lora_monitor'] = monitor + # Store monitors in app for cleanup + app['lora_monitor'] = lora_monitor + app['checkpoint_monitor'] = checkpoint_monitor + + logger.info("PromptServer app: ", app) # Schedule cache initialization using the application's startup handler - app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner, routes.recipe_scanner)) + app.on_startup.append(lambda app: cls._schedule_cache_init( + lora_routes.scanner, + checkpoints_routes.scanner, + lora_routes.recipe_scanner + )) # Add cleanup app.on_shutdown.append(cls._cleanup) app.on_shutdown.append(ApiRoutes.cleanup) @classmethod - async def _schedule_cache_init(cls, scanner: LoraScanner, recipe_scanner: RecipeScanner): + async def _schedule_cache_init(cls, lora_scanner, checkpoint_scanner, recipe_scanner): """Schedule cache initialization in the running event loop""" try: - # 创建低优先级的初始化任务 - lora_task = asyncio.create_task(cls._initialize_lora_cache(scanner), name='lora_cache_init') - - # Schedule recipe cache initialization with a delay to let lora scanner initialize first - recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner, delay=2), name='recipe_cache_init') + # Create low-priority initialization tasks + lora_task = asyncio.create_task(cls._initialize_lora_cache(lora_scanner), name='lora_cache_init') + checkpoint_task = asyncio.create_task(cls._initialize_checkpoint_cache(checkpoint_scanner), name='checkpoint_cache_init') + recipe_task = 
asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner), name='recipe_cache_init') + logger.info("Cache initialization tasks scheduled to run in background") except Exception as e: logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}") @@ -97,26 +128,45 @@ class LoraManager: async def _initialize_lora_cache(cls, scanner: LoraScanner): """Initialize lora cache in background""" try: - # 设置初始缓存占位 + # Set initial placeholder cache scanner._cache = LoraCache( raw_data=[], sorted_by_name=[], sorted_by_date=[], folders=[] ) - - # 分阶段加载缓存 - await scanner.get_cached_data(force_refresh=True) + # 使用线程池执行耗时操作 + loop = asyncio.get_event_loop() + await loop.run_in_executor( + None, # 使用默认线程池 + lambda: scanner.get_cached_data_sync(force_refresh=True) # 创建同步版本的方法 + ) + # Load cache in phases + # await scanner.get_cached_data(force_refresh=True) except Exception as e: logger.error(f"LoRA Manager: Error initializing lora cache: {e}") @classmethod - async def _initialize_recipe_cache(cls, scanner: RecipeScanner, delay: float = 2.0): - """Initialize recipe cache in background with a delay""" + async def _initialize_checkpoint_cache(cls, scanner: CheckpointScanner): + """Initialize checkpoint cache in background""" try: - # Wait for the specified delay to let lora scanner initialize first - await asyncio.sleep(delay) + # Set initial placeholder cache + scanner._cache = ModelCache( + raw_data=[], + sorted_by_name=[], + sorted_by_date=[], + folders=[] + ) + # Load cache in phases + await scanner.get_cached_data(force_refresh=True) + except Exception as e: + logger.error(f"LoRA Manager: Error initializing checkpoint cache: {e}") + + @classmethod + async def _initialize_recipe_cache(cls, scanner: RecipeScanner): + """Initialize recipe cache in background with a delay""" + try: # Set initial empty cache scanner._cache = RecipeCache( raw_data=[], @@ -134,3 +184,6 @@ class LoraManager: """Cleanup resources""" if 'lora_monitor' in app: app['lora_monitor'].stop() + + 
if 'checkpoint_monitor' in app: + app['checkpoint_monitor'].stop() diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index 0a79d6f9..f2e2e631 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -1,44 +1,146 @@ import os +import json +import asyncio +import aiohttp from aiohttp import web -import jinja2 import logging +from datetime import datetime + +from ..services.checkpoint_scanner import CheckpointScanner from ..config import config -from ..services.settings_manager import settings logger = logging.getLogger(__name__) -logging.getLogger('asyncio').setLevel(logging.CRITICAL) class CheckpointsRoutes: - """Route handlers for Checkpoints management endpoints""" + """API routes for checkpoint management""" def __init__(self): - self.template_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(config.templates_path), - autoescape=True - ) + self.scanner = CheckpointScanner() + + def setup_routes(self, app): + """Register routes with the aiohttp app""" + app.router.add_get('/lora_manager/api/checkpoints', self.get_checkpoints) + app.router.add_get('/lora_manager/api/checkpoints/scan', self.scan_checkpoints) + app.router.add_get('/lora_manager/api/checkpoints/info/{name}', self.get_checkpoint_info) - async def handle_checkpoints_page(self, request: web.Request) -> web.Response: - """Handle GET /checkpoints request""" + async def get_checkpoints(self, request): + """Get paginated checkpoint data""" try: - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - is_initializing=False, - settings=settings, - request=request + # Parse query parameters + page = int(request.query.get('page', '1')) + page_size = min(int(request.query.get('page_size', '20')), 100) + sort_by = request.query.get('sort', 'name') + folder = request.query.get('folder', None) + search = request.query.get('search', None) + fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 
'true' + base_models = request.query.getall('base_model', []) + tags = request.query.getall('tag', []) + + # Process search options + search_options = { + 'filename': request.query.get('search_filename', 'true').lower() == 'true', + 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', + 'tags': request.query.get('search_tags', 'false').lower() == 'true', + 'recursive': request.query.get('recursive', 'false').lower() == 'true', + } + + # Process hash filters if provided + hash_filters = {} + if 'hash' in request.query: + hash_filters['single_hash'] = request.query['hash'] + elif 'hashes' in request.query: + try: + hash_list = json.loads(request.query['hashes']) + if isinstance(hash_list, list): + hash_filters['multiple_hashes'] = hash_list + except (json.JSONDecodeError, TypeError): + pass + + # Get data from scanner + result = await self.get_paginated_data( + page=page, + page_size=page_size, + sort_by=sort_by, + folder=folder, + search=search, + fuzzy_search=fuzzy_search, + base_models=base_models, + tags=tags, + search_options=search_options, + hash_filters=hash_filters ) - return web.Response( - text=rendered, - content_type='text/html' - ) + # Return as JSON + return web.json_response(result) except Exception as e: - logger.error(f"Error handling checkpoints request: {e}", exc_info=True) - return web.Response( - text="Error loading checkpoints page", - status=500 - ) + logger.error(f"Error in get_checkpoints: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) - def setup_routes(self, app: web.Application): - """Register routes with the application""" - app.router.add_get('/checkpoints', self.handle_checkpoints_page) + async def get_paginated_data(self, page, page_size, sort_by='name', + folder=None, search=None, fuzzy_search=False, + base_models=None, tags=None, + search_options=None, hash_filters=None): + """Get paginated and filtered checkpoint data""" + cache = await self.scanner.get_cached_data() + + # 
Implement similar filtering logic as in LoraScanner + # (Adapt code from LoraScanner.get_paginated_data) + # ... + + # For now, a simplified implementation: + filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name + + # Apply basic folder filtering if needed + if folder is not None: + filtered_data = [ + cp for cp in filtered_data + if cp['folder'] == folder + ] + + # Apply basic search if needed + if search: + filtered_data = [ + cp for cp in filtered_data + if search.lower() in cp['file_name'].lower() or + search.lower() in cp['model_name'].lower() + ] + + # Calculate pagination + total_items = len(filtered_data) + start_idx = (page - 1) * page_size + end_idx = min(start_idx + page_size, total_items) + + result = { + 'items': filtered_data[start_idx:end_idx], + 'total': total_items, + 'page': page, + 'page_size': page_size, + 'total_pages': (total_items + page_size - 1) // page_size + } + + return result + + async def scan_checkpoints(self, request): + """Force a rescan of checkpoint files""" + try: + await self.scanner.get_cached_data(force_refresh=True) + return web.json_response({"status": "success", "message": "Checkpoint scan completed"}) + except Exception as e: + logger.error(f"Error in scan_checkpoints: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) + + async def get_checkpoint_info(self, request): + """Get detailed information for a specific checkpoint by name""" + try: + name = request.match_info.get('name', '') + checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name) + + if checkpoint_info: + return web.json_response(checkpoint_info) + else: + return web.json_response({"error": "Checkpoint not found"}, status=404) + + except Exception as e: + logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py new file mode 100644 index 
00000000..27d15273 --- /dev/null +++ b/py/services/checkpoint_scanner.py @@ -0,0 +1,131 @@ +import os +import logging +import asyncio +from typing import List, Dict, Optional, Set +import folder_paths # type: ignore + +from ..utils.models import CheckpointMetadata +from ..config import config +from .model_scanner import ModelScanner +from .model_hash_index import ModelHashIndex + +logger = logging.getLogger(__name__) + +class CheckpointScanner(ModelScanner): + """Service for scanning and managing checkpoint files""" + + _instance = None + _lock = asyncio.Lock() + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, '_initialized'): + # Define supported file extensions + file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'} + super().__init__( + model_type="checkpoint", + model_class=CheckpointMetadata, + file_extensions=file_extensions, + hash_index=ModelHashIndex() + ) + self._checkpoint_roots = self._init_checkpoint_roots() + self._initialized = True + + @classmethod + async def get_instance(cls): + """Get singleton instance with async support""" + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def _init_checkpoint_roots(self) -> List[str]: + """Initialize checkpoint roots from ComfyUI settings""" + # Get both checkpoint and diffusion_models paths + checkpoint_paths = folder_paths.get_folder_paths("checkpoints") + diffusion_paths = folder_paths.get_folder_paths("diffusion_models") + + # Combine, normalize and deduplicate paths + all_paths = set() + for path in checkpoint_paths + diffusion_paths: + if os.path.exists(path): + norm_path = path.replace(os.sep, "/") + all_paths.add(norm_path) + + # Sort for consistent order + sorted_paths = sorted(all_paths, key=lambda p: p.lower()) + logger.info(f"Found checkpoint roots: {sorted_paths}") + + return sorted_paths + + def get_model_roots(self) -> 
List[str]: + """Get checkpoint root directories""" + return self._checkpoint_roots + + async def scan_all_models(self) -> List[Dict]: + """Scan all checkpoint directories and return metadata""" + all_checkpoints = [] + + # Create scan tasks for each directory + scan_tasks = [] + for root in self._checkpoint_roots: + task = asyncio.create_task(self._scan_directory(root)) + scan_tasks.append(task) + + # Wait for all tasks to complete + for task in scan_tasks: + try: + checkpoints = await task + all_checkpoints.extend(checkpoints) + except Exception as e: + logger.error(f"Error scanning checkpoint directory: {e}") + + return all_checkpoints + + async def _scan_directory(self, root_path: str) -> List[Dict]: + """Scan a directory for checkpoint files""" + checkpoints = [] + original_root = root_path + + async def scan_recursive(path: str, visited_paths: set): + try: + real_path = os.path.realpath(path) + if real_path in visited_paths: + logger.debug(f"Skipping already visited path: {path}") + return + visited_paths.add(real_path) + + with os.scandir(path) as it: + entries = list(it) + for entry in entries: + try: + if entry.is_file(follow_symlinks=True): + # Check if file has supported extension + ext = os.path.splitext(entry.name)[1].lower() + if ext in self.file_extensions: + file_path = entry.path.replace(os.sep, "/") + await self._process_single_file(file_path, original_root, checkpoints) + await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + # For directories, continue scanning with original path + await scan_recursive(entry.path, visited_paths) + except Exception as e: + logger.error(f"Error processing entry {entry.path}: {e}") + except Exception as e: + logger.error(f"Error scanning {path}: {e}") + + await scan_recursive(root_path, set()) + return checkpoints + + async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list): + """Process a single checkpoint file and add to results""" + try: + result = await 
self._process_model_file(file_path, root_path) + if result: + checkpoints.append(result) + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") \ No newline at end of file diff --git a/py/services/file_monitor.py b/py/services/file_monitor.py index 9ed44d0f..1b9438f3 100644 --- a/py/services/file_monitor.py +++ b/py/services/file_monitor.py @@ -7,6 +7,8 @@ from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler from typing import List, Dict, Set from threading import Lock + +from .checkpoint_scanner import CheckpointScanner from .lora_scanner import LoraScanner from ..config import config @@ -330,4 +332,90 @@ class LoraFileMonitor: self.observer.schedule(self.handler, path, recursive=True) logger.info(f"Added new monitoring path: {path}") except Exception as e: - logger.error(f"Error adding new monitor for {path}: {e}") \ No newline at end of file + logger.error(f"Error adding new monitor for {path}: {e}") + +# Add CheckpointFileMonitor class + +class CheckpointFileMonitor(LoraFileMonitor): + """Monitor for checkpoint file changes""" + + def __init__(self, scanner: CheckpointScanner, roots: List[str]): + # Reuse most of the LoraFileMonitor functionality, but with a different handler + self.scanner = scanner + scanner.set_file_monitor(self) + self.observer = Observer() + self.loop = asyncio.get_event_loop() + self.handler = CheckpointFileHandler(scanner, self.loop) + + # Use existing path mappings + self.monitor_paths = set() + for root in roots: + self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/')) + + # Add all mapped target paths + for target_path in config._path_mappings.keys(): + self.monitor_paths.add(target_path) + +class CheckpointFileHandler(LoraFileHandler): + """Handler for checkpoint file system events""" + + def __init__(self, scanner: CheckpointScanner, loop: asyncio.AbstractEventLoop): + super().__init__(scanner, loop) + # Configure supported file extensions + 
self.supported_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'} + + def on_created(self, event): + if event.is_directory: + return + + # Handle supported file extensions directly + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext in self.supported_extensions: + if self._should_ignore(event.src_path): + return + + # Process this file directly + normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/') + if normalized_path not in self.scheduled_files: + logger.info(f"Checkpoint file created: {event.src_path}") + self.scheduled_files.add(normalized_path) + self._schedule_update('add', event.src_path) + + # Ignore modifications for a short period after creation + self.loop.call_later( + self.debounce_delay * 2, + self.scheduled_files.discard, + normalized_path + ) + + def on_modified(self, event): + if event.is_directory: + return + + # Only process supported file types + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext in self.supported_extensions: + super().on_modified(event) + + def on_deleted(self, event): + if event.is_directory: + return + + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext not in self.supported_extensions: + return + + super().on_deleted(event) + + def on_moved(self, event): + """Handle file move/rename events""" + src_ext = os.path.splitext(event.src_path)[1].lower() + dest_ext = os.path.splitext(event.dest_path)[1].lower() + + # If destination has supported extension, treat as new file + if dest_ext in self.supported_extensions: + super().on_moved(event) + + # If source was supported extension, treat as deleted + elif src_ext in self.supported_extensions: + super().on_moved(event) \ No newline at end of file diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index c8142086..c4e10c3d 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -4,13 +4,11 @@ import logging import asyncio import shutil import time -from 
typing import List, Dict, Optional +from typing import List, Dict, Optional, Set from ..utils.models import LoraMetadata from ..config import config -from ..utils.file_utils import load_metadata, get_file_info, normalize_path, find_preview_file, save_metadata -from ..utils.lora_metadata import extract_lora_metadata -from .lora_cache import LoraCache +from .model_scanner import ModelScanner from .lora_hash_index import LoraHashIndex from .settings_manager import settings from ..utils.constants import NSFW_LEVELS @@ -19,7 +17,7 @@ import sys logger = logging.getLogger(__name__) -class LoraScanner: +class LoraScanner(ModelScanner): """Service for scanning and managing LoRA files""" _instance = None @@ -31,20 +29,20 @@ class LoraScanner: return cls._instance def __init__(self): - # 确保初始化只执行一次 + # Ensure initialization happens only once if not hasattr(self, '_initialized'): - self._cache: Optional[LoraCache] = None - self._hash_index = LoraHashIndex() - self._initialization_lock = asyncio.Lock() - self._initialization_task: Optional[asyncio.Task] = None + # Define supported file extensions + file_extensions = {'.safetensors'} + + # Initialize parent class + super().__init__( + model_type="lora", + model_class=LoraMetadata, + file_extensions=file_extensions, + hash_index=LoraHashIndex() + ) self._initialized = True - self.file_monitor = None # Add this line - self._tags_count = {} # Add a dictionary to store tag counts - - def set_file_monitor(self, monitor): - """Set file monitor instance""" - self.file_monitor = monitor - + @classmethod async def get_instance(cls): """Get singleton instance with async support""" @@ -52,89 +50,74 @@ class LoraScanner: if cls._instance is None: cls._instance = cls() return cls._instance - - async def get_cached_data(self, force_refresh: bool = False) -> LoraCache: - """Get cached LoRA data, refresh if needed""" - async with self._initialization_lock: + + def get_model_roots(self) -> List[str]: + """Get lora root directories""" + return 
config.loras_roots + + async def scan_all_models(self) -> List[Dict]: + """Scan all LoRA directories and return metadata""" + all_loras = [] + + # Create scan tasks for each directory + scan_tasks = [] + for lora_root in self.get_model_roots(): + task = asyncio.create_task(self._scan_directory(lora_root)) + scan_tasks.append(task) - # 如果缓存未初始化但需要响应请求,返回空缓存 - if self._cache is None and not force_refresh: - return LoraCache( - raw_data=[], - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - - # 如果正在初始化,等待完成 - if self._initialization_task and not self._initialization_task.done(): - try: - await self._initialization_task - except Exception as e: - logger.error(f"Cache initialization failed: {e}") - self._initialization_task = None - - if (self._cache is None or force_refresh): + # Wait for all tasks to complete + for task in scan_tasks: + try: + loras = await task + all_loras.extend(loras) + except Exception as e: + logger.error(f"Error scanning directory: {e}") - # 创建新的初始化任务 - if not self._initialization_task or self._initialization_task.done(): - self._initialization_task = asyncio.create_task(self._initialize_cache()) + return all_loras + + async def _scan_directory(self, root_path: str) -> List[Dict]: + """Scan a single directory for LoRA files""" + loras = [] + original_root = root_path # Save original root path + + async def scan_recursive(path: str, visited_paths: set): + """Recursively scan directory, avoiding circular symlinks""" + try: + real_path = os.path.realpath(path) + if real_path in visited_paths: + logger.debug(f"Skipping already visited path: {path}") + return + visited_paths.add(real_path) - try: - await self._initialization_task - except Exception as e: - logger.error(f"Cache initialization failed: {e}") - # 如果缓存已存在,继续使用旧缓存 - if self._cache is None: - raise # 如果没有缓存,则抛出异常 - - return self._cache + with os.scandir(path) as it: + entries = list(it) + for entry in entries: + try: + if entry.is_file(follow_symlinks=True) and 
any(entry.name.endswith(ext) for ext in self.file_extensions): + # Use original path instead of real path + file_path = entry.path.replace(os.sep, "/") + await self._process_single_file(file_path, original_root, loras) + await asyncio.sleep(0) + elif entry.is_dir(follow_symlinks=True): + # For directories, continue scanning with original path + await scan_recursive(entry.path, visited_paths) + except Exception as e: + logger.error(f"Error processing entry {entry.path}: {e}") + except Exception as e: + logger.error(f"Error scanning {path}: {e}") - async def _initialize_cache(self) -> None: - """Initialize or refresh the cache""" + await scan_recursive(root_path, set()) + return loras + + async def _process_single_file(self, file_path: str, root_path: str, loras: list): + """Process a single file and add to results list""" try: - start_time = time.time() - # Clear existing hash index - self._hash_index.clear() - - # Clear existing tags count - self._tags_count = {} - - # Scan for new data - raw_data = await self.scan_all_loras() - - # Build hash index and tags count - for lora_data in raw_data: - if 'sha256' in lora_data and 'file_path' in lora_data: - self._hash_index.add_entry(lora_data['sha256'].lower(), lora_data['file_path']) - - # Count tags - if 'tags' in lora_data and lora_data['tags']: - for tag in lora_data['tags']: - self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 - - # Update cache - self._cache = LoraCache( - raw_data=raw_data, - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - - # Call resort_cache to create sorted views - await self._cache.resort() - - self._initialization_task = None - logger.info(f"LoRA Manager: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} loras") + result = await self._process_model_file(file_path, root_path) + if result: + loras.append(result) except Exception as e: - logger.error(f"LoRA Manager: Error initializing cache: {e}") - self._cache = LoraCache( - 
raw_data=[], - sorted_by_name=[], - sorted_by_date=[], - folders=[] - ) - + logger.error(f"Error processing {file_path}: {e}") + async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', folder: str = None, search: str = None, fuzzy_search: bool = False, base_models: list = None, tags: list = None, @@ -280,240 +263,14 @@ class LoraScanner: return result - def invalidate_cache(self): - """Invalidate the current cache""" - self._cache = None - - async def scan_all_loras(self) -> List[Dict]: - """Scan all LoRA directories and return metadata""" - all_loras = [] - - # 分目录异步扫描 - scan_tasks = [] - for loras_root in config.loras_roots: - task = asyncio.create_task(self._scan_directory(loras_root)) - scan_tasks.append(task) - - for task in scan_tasks: - try: - loras = await task - all_loras.extend(loras) - except Exception as e: - logger.error(f"Error scanning directory: {e}") - - return all_loras - - async def _scan_directory(self, root_path: str) -> List[Dict]: - """Scan a single directory for LoRA files""" - loras = [] - original_root = root_path # 保存原始根路径 - - async def scan_recursive(path: str, visited_paths: set): - """递归扫描目录,避免循环链接""" - try: - real_path = os.path.realpath(path) - if real_path in visited_paths: - logger.debug(f"Skipping already visited path: {path}") - return - visited_paths.add(real_path) - - with os.scandir(path) as it: - entries = list(it) - for entry in entries: - try: - if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'): - # 使用原始路径而不是真实路径 - file_path = entry.path.replace(os.sep, "/") - await self._process_single_file(file_path, original_root, loras) - await asyncio.sleep(0) - elif entry.is_dir(follow_symlinks=True): - # 对于目录,使用原始路径继续扫描 - await scan_recursive(entry.path, visited_paths) - except Exception as e: - logger.error(f"Error processing entry {entry.path}: {e}") - except Exception as e: - logger.error(f"Error scanning {path}: {e}") - - await scan_recursive(root_path, set()) - return 
loras - - async def _process_single_file(self, file_path: str, root_path: str, loras: list): - """处理单个文件并添加到结果列表""" - try: - result = await self._process_lora_file(file_path, root_path) - if result: - loras.append(result) - except Exception as e: - logger.error(f"Error processing {file_path}: {e}") - - async def _process_lora_file(self, file_path: str, root_path: str) -> Dict: - """Process a single LoRA file and return its metadata""" - # Try loading existing metadata - metadata = await load_metadata(file_path) - - if metadata is None: - # Try to find and use .civitai.info file first - civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info" - if os.path.exists(civitai_info_path): - try: - with open(civitai_info_path, 'r', encoding='utf-8') as f: - version_info = json.load(f) - - file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) - if file_info: - # Create a minimal file_info with the required fields - file_name = os.path.splitext(os.path.basename(file_path))[0] - file_info['name'] = file_name - - # Use from_civitai_info to create metadata - metadata = LoraMetadata.from_civitai_info(version_info, file_info, file_path) - metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path)) - await save_metadata(file_path, metadata) - logger.debug(f"Created metadata from .civitai.info for {file_path}") - except Exception as e: - logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}") - - # If still no metadata, create new metadata using get_file_info - if metadata is None: - metadata = await get_file_info(file_path) - - # Convert to dict and add folder info - lora_data = metadata.to_dict() - # Try to fetch missing metadata from Civitai if needed - await self._fetch_missing_metadata(file_path, lora_data) - rel_path = os.path.relpath(file_path, root_path) - folder = os.path.dirname(rel_path) - lora_data['folder'] = folder.replace(os.path.sep, '/') - - return lora_data - - async def 
_fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None: - """Fetch missing description and tags from Civitai if needed - - Args: - file_path: Path to the lora file - lora_data: Lora metadata dictionary to update - """ - try: - # Skip if already marked as deleted on Civitai - if lora_data.get('civitai_deleted', False): - logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai") - return - - # Check if we need to fetch additional metadata from Civitai - needs_metadata_update = False - model_id = None - - # Check if we have Civitai model ID but missing metadata - if lora_data.get('civitai'): - # Try to get model ID directly from the correct location - model_id = lora_data['civitai'].get('modelId') - - if model_id: - model_id = str(model_id) - # Check if tags are missing or empty - tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0 - - # Check if description is missing or empty - desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "") - - needs_metadata_update = tags_missing or desc_missing - - # Fetch missing metadata if needed - if needs_metadata_update and model_id: - logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}") - from ..services.civitai_client import CivitaiClient - client = CivitaiClient() - - # Get metadata and status code - model_metadata, status_code = await client.get_model_metadata(model_id) - await client.close() - - # Handle 404 status (model deleted from Civitai) - if status_code == 404: - logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)") - # Mark as deleted to avoid future API calls - lora_data['civitai_deleted'] = True - - # Save the updated metadata back to file - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(lora_data, f, indent=2, ensure_ascii=False) - - # Process valid 
metadata if available - elif model_metadata: - logger.debug(f"Updating metadata for {file_path} with model ID {model_id}") - - # Update tags if they were missing - if model_metadata.get('tags') and (not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0): - lora_data['tags'] = model_metadata['tags'] - - # Update description if it was missing - if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")): - lora_data['modelDescription'] = model_metadata['description'] - - # Save the updated metadata back to file - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(lora_data, f, indent=2, ensure_ascii=False) - except Exception as e: - logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}") - - async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool: - """Update preview URL in cache for a specific lora - - Args: - file_path: The file path of the lora to update - preview_url: The new preview URL - - Returns: - bool: True if the update was successful, False if cache doesn't exist or lora wasn't found - """ - if self._cache is None: - return False - - return await self._cache.update_preview_url(file_path, preview_url) - - async def scan_single_lora(self, file_path: str) -> Optional[Dict]: - """Scan a single LoRA file and return its metadata""" - try: - if not os.path.exists(os.path.realpath(file_path)): - return None - - # 获取基本文件信息 - metadata = await get_file_info(file_path) - if not metadata: - return None - - folder = self._calculate_folder(file_path) - - # 确保 folder 字段存在 - metadata_dict = metadata.to_dict() - metadata_dict['folder'] = folder or '' - - return metadata_dict - - except Exception as e: - logger.error(f"Error scanning {file_path}: {e}") - return None - - def _calculate_folder(self, file_path: str) -> str: - """Calculate the folder path for a LoRA 
file""" - # 使用原始路径计算相对路径 - for root in config.loras_roots: - if file_path.startswith(root): - rel_path = os.path.relpath(file_path, root) - return os.path.dirname(rel_path).replace(os.path.sep, '/') - return '' - async def move_model(self, source_path: str, target_path: str) -> bool: """Move a model and its associated files to a new location""" try: - # 保持原始路径格式 + # Keep original path format source_path = source_path.replace(os.sep, '/') target_path = target_path.replace(os.sep, '/') - # 其余代码保持不变 + # Rest of the code remains unchanged base_name = os.path.splitext(os.path.basename(source_path))[0] source_dir = os.path.dirname(source_path) @@ -521,7 +278,7 @@ class LoraScanner: target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/') - # 使用真实路径进行文件操作 + # Use real paths for file operations real_source = os.path.realpath(source_path) real_target = os.path.realpath(target_lora) @@ -537,7 +294,7 @@ class LoraScanner: file_size ) - # 使用真实路径进行文件操作 + # Use real paths for file operations shutil.move(real_source, real_target) # Move associated files @@ -648,7 +405,7 @@ class LoraScanner: except Exception as e: logger.error(f"Error updating metadata paths: {e}", exc_info=True) - # Add new methods for hash index functionality + # Lora-specific hash index functionality def has_lora_hash(self, sha256: str) -> bool: """Check if a LoRA with given hash exists""" return self._hash_index.has_hash(sha256.lower()) @@ -681,16 +438,8 @@ class LoraScanner: return None - # Add new method to get top tags async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: - """Get top tags sorted by count - - Args: - limit: Maximum number of tags to return - - Returns: - List of dictionaries with tag name and count, sorted by count - """ + """Get top tags sorted by count""" # Make sure cache is initialized await self.get_cached_data() @@ -705,14 +454,7 @@ class LoraScanner: return sorted_tags[:limit] async def get_base_models(self, limit: int = 20) -> 
@dataclass
class ModelCache:
    """In-memory cache of model metadata with pre-sorted views.

    Holds the raw metadata dicts plus two sorted projections (by model name
    and by modification time) and the sorted list of folder names.
    """
    raw_data: List[Dict]
    sorted_by_name: List[Dict]
    sorted_by_date: List[Dict]
    folders: List[str]

    def __post_init__(self):
        # Serializes concurrent mutation of the cached views.
        self._lock = asyncio.Lock()

    async def resort(self, name_only: bool = False):
        """Rebuild the sorted views; skip date/folder views when name_only."""
        async with self._lock:
            self.sorted_by_name = sorted(
                self.raw_data,
                key=lambda entry: entry['model_name'].lower()  # case-insensitive
            )
            if not name_only:
                self.sorted_by_date = sorted(
                    self.raw_data,
                    key=itemgetter('modified'),
                    reverse=True
                )
                # Refresh the folder list from whatever is currently cached.
                folder_names = {entry['folder'] for entry in self.raw_data}
                self.folders = sorted(folder_names, key=str.lower)

    async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
        """Update preview_url for a specific model in all cached data.

        Args:
            file_path: The file path of the model to update.
            preview_url: The new preview URL.

        Returns:
            bool: True if the model was found and updated, False otherwise.
        """
        async with self._lock:
            target = next(
                (entry for entry in self.raw_data if entry['file_path'] == file_path),
                None,
            )
            if target is None:
                return False
            target['preview_url'] = preview_url

            # The sorted views normally alias the same dict objects, but they
            # may have been populated independently — update them explicitly.
            for view in (self.sorted_by_name, self.sorted_by_date):
                for entry in view:
                    if entry['file_path'] == file_path:
                        entry['preview_url'] = preview_url
                        break

            return True
class ModelHashIndex:
    """Bidirectional index between SHA256 hashes and model file paths.

    Hashes are normalized to lowercase on every operation. Each hash maps to
    exactly one path and each path to exactly one hash.
    """

    def __init__(self):
        self._hash_to_path: Dict[str, str] = {}
        self._path_to_hash: Dict[str, str] = {}

    def add_entry(self, sha256: str, file_path: str) -> None:
        """Add or update a hash/path pair, evicting stale mappings on both sides."""
        if not sha256 or not file_path:
            return

        # Ensure hash is lowercase for consistency
        sha256 = sha256.lower()

        # The hash may already point at another path, and the path may carry
        # another hash; drop both stale reverse mappings before inserting.
        stale_path = self._hash_to_path.get(sha256)
        if stale_path is not None:
            self._path_to_hash.pop(stale_path, None)

        stale_hash = self._path_to_hash.get(file_path)
        if stale_hash is not None:
            self._hash_to_path.pop(stale_hash, None)

        self._hash_to_path[sha256] = file_path
        self._path_to_hash[file_path] = sha256

    def remove_by_path(self, file_path: str) -> None:
        """Remove entry by file path (no-op when absent)."""
        stale_hash = self._path_to_hash.pop(file_path, None)
        if stale_hash is not None:
            self._hash_to_path.pop(stale_hash, None)

    def remove_by_hash(self, sha256: str) -> None:
        """Remove entry by hash (no-op when absent)."""
        stale_path = self._hash_to_path.pop(sha256.lower(), None)
        if stale_path is not None:
            self._path_to_hash.pop(stale_path, None)

    def has_hash(self, sha256: str) -> bool:
        """Check if hash exists in index."""
        return sha256.lower() in self._hash_to_path

    def get_path(self, sha256: str) -> Optional[str]:
        """Get file path for a hash."""
        return self._hash_to_path.get(sha256.lower())

    def get_hash(self, file_path: str) -> Optional[str]:
        """Get hash for a file path."""
        return self._path_to_hash.get(file_path)

    def clear(self) -> None:
        """Drop all entries."""
        self._hash_to_path.clear()
        self._path_to_hash.clear()

    def get_all_hashes(self) -> Set[str]:
        """Get all hashes in the index."""
        return set(self._hash_to_path)

    def get_all_paths(self) -> Set[str]:
        """Get all file paths in the index."""
        return set(self._path_to_hash)

    def __len__(self) -> int:
        """Number of hash/path pairs currently indexed."""
        return len(self._hash_to_path)
def set_file_monitor(self, monitor):
    """Set file monitor instance used to suppress self-inflicted FS events."""
    self.file_monitor = monitor

async def get_cached_data(self, force_refresh: bool = False) -> ModelCache:
    """Get cached model data, refresh if needed.

    Args:
        force_refresh: When True, (re)build the cache even if one exists.

    Returns:
        ModelCache: the live cache, or a throwaway empty cache when nothing
        has been initialized yet and no refresh was requested.

    Raises:
        Exception: re-raises the initialization failure when no previous
        cache exists to fall back on.
    """
    async with self._initialization_lock:
        # Return empty cache if not initialized and no refresh requested
        if self._cache is None and not force_refresh:
            return ModelCache(
                raw_data=[],
                sorted_by_name=[],
                sorted_by_date=[],
                folders=[]
            )

        # Wait for ongoing initialization if any
        if self._initialization_task and not self._initialization_task.done():
            try:
                await self._initialization_task
            except Exception as e:
                logger.error(f"Cache initialization failed: {e}")
                # Clear the failed task so a new one can be scheduled below.
                self._initialization_task = None

        if (self._cache is None or force_refresh):
            # Create new initialization task
            if not self._initialization_task or self._initialization_task.done():
                self._initialization_task = asyncio.create_task(self._initialize_cache())

            try:
                await self._initialization_task
            except Exception as e:
                logger.error(f"Cache initialization failed: {e}")
                # Continue using old cache if it exists
                if self._cache is None:
                    raise  # Raise exception if no cache available

        return self._cache
# These methods should be implemented in child classes
async def scan_all_models(self) -> List[Dict]:
    """Scan all model directories and return metadata (abstract)."""
    raise NotImplementedError("Subclasses must implement scan_all_models")

def get_model_roots(self) -> List[str]:
    """Get model root directories (abstract)."""
    raise NotImplementedError("Subclasses must implement get_model_roots")

async def scan_single_model(self, file_path: str) -> Optional[Dict]:
    """Scan a single model file and return its metadata.

    Args:
        file_path: Path to the model file.

    Returns:
        Metadata dict with a guaranteed 'folder' key, or None when the file
        does not exist (after resolving symlinks) or scanning fails.
    """
    try:
        # realpath() so a dangling symlink counts as a missing file
        if not os.path.exists(os.path.realpath(file_path)):
            return None

        # Get basic file info
        metadata = await self._get_file_info(file_path)
        if not metadata:
            return None

        folder = self._calculate_folder(file_path)

        # Ensure folder field exists
        metadata_dict = metadata.to_dict()
        metadata_dict['folder'] = folder or ''

        return metadata_dict

    except Exception as e:
        logger.error(f"Error scanning {file_path}: {e}")
        return None
and metadata (extensible for different model types)""" + # Implementation may vary by model type - override in subclasses if needed + return await get_file_info(file_path, self.model_class) + + def _calculate_folder(self, file_path: str) -> str: + """Calculate the folder path for a model file""" + # Use original path to calculate relative path + for root in self.get_model_roots(): + if file_path.startswith(root): + rel_path = os.path.relpath(file_path, root) + return os.path.dirname(rel_path).replace(os.path.sep, '/') + return '' + + # Common methods shared between scanners + async def _process_model_file(self, file_path: str, root_path: str) -> Dict: + """Process a single model file and return its metadata""" + # Try loading existing metadata + metadata = await load_metadata(file_path, self.model_class) + + if metadata is None: + # Try to find and use .civitai.info file first + civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info" + if os.path.exists(civitai_info_path): + try: + with open(civitai_info_path, 'r', encoding='utf-8') as f: + version_info = json.load(f) + + file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) + if file_info: + # Create a minimal file_info with the required fields + file_name = os.path.splitext(os.path.basename(file_path))[0] + file_info['name'] = file_name + + # Use from_civitai_info to create metadata + metadata = self.model_class.from_civitai_info(version_info, file_info, file_path) + metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path)) + await save_metadata(file_path, metadata) + logger.debug(f"Created metadata from .civitai.info for {file_path}") + except Exception as e: + logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}") + + # If still no metadata, create new metadata + if metadata is None: + metadata = await self._get_file_info(file_path) + + # Convert to dict and add folder info + model_data = metadata.to_dict() + + # Try 
async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None:
    """Fetch missing description and tags from Civitai if needed.

    Args:
        file_path: Path to the model file (used to locate its .metadata.json).
        model_data: Metadata dict; updated in place and persisted on change.
    """
    try:
        # Skip if already marked as deleted on Civitai
        if model_data.get('civitai_deleted', False):
            logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
            return

        # Check if we need to fetch additional metadata from Civitai
        needs_metadata_update = False
        model_id = None

        # Check if we have Civitai model ID but missing metadata
        if model_data.get('civitai'):
            model_id = model_data['civitai'].get('modelId')

            if model_id:
                model_id = str(model_id)
                # Check if tags or description are missing
                tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0
                desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")
                needs_metadata_update = tags_missing or desc_missing

        # Fetch missing metadata if needed
        if needs_metadata_update and model_id:
            logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
            # NOTE(review): import is deliberately local — presumably to avoid
            # a circular import at module load time; confirm before hoisting.
            from ..services.civitai_client import CivitaiClient
            client = CivitaiClient()

            # Get metadata and status code
            model_metadata, status_code = await client.get_model_metadata(model_id)
            await client.close()

            # Handle 404 status (model deleted from Civitai)
            if status_code == 404:
                logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
                # Remember the deletion so future scans skip the API call.
                model_data['civitai_deleted'] = True

                # Save the updated metadata
                metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
                with open(metadata_path, 'w', encoding='utf-8') as f:
                    json.dump(model_data, f, indent=2, ensure_ascii=False)

            # Process valid metadata if available
            elif model_metadata:
                logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")

                # Update tags if they were missing
                if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0):
                    model_data['tags'] = model_metadata['tags']

                # Update description if it was missing
                if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")):
                    model_data['modelDescription'] = model_metadata['description']

                # Save the updated metadata
                metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
                with open(metadata_path, 'w', encoding='utf-8') as f:
                    json.dump(model_data, f, indent=2, ensure_ascii=False)
    except Exception as e:
        logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
scanning {path}: {e}") + + await scan_recursive(root_path, set()) + return models + + async def _process_single_file(self, file_path: str, root_path: str, models_list: list): + """Process a single file and add to results list""" + try: + result = await self._process_model_file(file_path, root_path) + if result: + models_list.append(result) + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") + + async def move_model(self, source_path: str, target_path: str) -> bool: + """Move a model and its associated files to a new location""" + try: + # Keep original path format + source_path = source_path.replace(os.sep, '/') + target_path = target_path.replace(os.sep, '/') + + # Get file extension from source + file_ext = os.path.splitext(source_path)[1] + + # If no extension or not in supported extensions, return False + if not file_ext or file_ext.lower() not in self.file_extensions: + logger.error(f"Invalid file extension for model: {file_ext}") + return False + + base_name = os.path.splitext(os.path.basename(source_path))[0] + source_dir = os.path.dirname(source_path) + + os.makedirs(target_path, exist_ok=True) + + target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/') + + # Use real paths for file operations + real_source = os.path.realpath(source_path) + real_target = os.path.realpath(target_file) + + file_size = os.path.getsize(real_source) + + if self.file_monitor: + self.file_monitor.handler.add_ignore_path( + real_source, + file_size + ) + self.file_monitor.handler.add_ignore_path( + real_target, + file_size + ) + + # Use real paths for file operations + shutil.move(real_source, real_target) + + # Move associated files + source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json") + metadata = None + if os.path.exists(source_metadata): + target_metadata = os.path.join(target_path, f"{base_name}.metadata.json") + shutil.move(source_metadata, target_metadata) + metadata = await 
self._update_metadata_paths(target_metadata, target_file) + + # Move preview file if exists + preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', + '.png', '.jpeg', '.jpg', '.mp4'] + for ext in preview_extensions: + source_preview = os.path.join(source_dir, f"{base_name}{ext}") + if os.path.exists(source_preview): + target_preview = os.path.join(target_path, f"{base_name}{ext}") + shutil.move(source_preview, target_preview) + break + + # Update cache + await self.update_single_model_cache(source_path, target_file, metadata) + + return True + + except Exception as e: + logger.error(f"Error moving model: {e}", exc_info=True) + return False + + async def _update_metadata_paths(self, metadata_path: str, model_path: str) -> Dict: + """Update file paths in metadata file""" + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + metadata = json.load(f) + + # Update file_path + metadata['file_path'] = model_path.replace(os.sep, '/') + + # Update preview_url if exists + if 'preview_url' in metadata: + preview_dir = os.path.dirname(model_path) + preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0] + preview_ext = os.path.splitext(metadata['preview_url'])[1] + new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}") + metadata['preview_url'] = new_preview_path.replace(os.sep, '/') + + # Save updated metadata + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + + return metadata + + except Exception as e: + logger.error(f"Error updating metadata paths: {e}", exc_info=True) + return None + + async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool: + """Update cache after a model has been moved or modified""" + cache = await self.get_cached_data() + + # Find the existing item to remove its tags from count + existing_item = next((item for item in cache.raw_data if item['file_path'] == 
async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
    """Update cache after a model has been moved or modified.

    Args:
        original_path: Previous file path of the model (entry to remove).
        new_path: Current file path (same as original_path for in-place edits).
        metadata: Refreshed metadata dict to insert, or a falsy value to only
            remove the old entry.

    Returns:
        bool: always True once the cache has been updated and resorted.
    """
    cache = await self.get_cached_data()

    # Capture the entry being replaced BEFORE removing it, so its tags can be
    # un-counted and, for in-place updates, its folder preserved.
    existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
    if existing_item and 'tags' in existing_item:
        for tag in existing_item.get('tags', []):
            if tag in self._tags_count:
                self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
                if self._tags_count[tag] == 0:
                    del self._tags_count[tag]

    # Remove old path from hash index if exists
    self._hash_index.remove_by_path(original_path)

    # Remove the old entry from raw_data
    cache.raw_data = [
        item for item in cache.raw_data
        if item['file_path'] != original_path
    ]

    if metadata:
        if original_path == new_path:
            # BUG FIX: the old entry was already removed from raw_data above,
            # so re-searching raw_data for original_path always missed; use
            # the entry captured before removal to preserve its folder.
            if existing_item and existing_item.get('folder') is not None:
                metadata['folder'] = existing_item['folder']
            else:
                metadata['folder'] = self._calculate_folder(new_path)
        else:
            # For moved files, recalculate the folder
            metadata['folder'] = self._calculate_folder(new_path)

        # Add the updated metadata to raw_data
        cache.raw_data.append(metadata)

        # Update hash index with new path
        if 'sha256' in metadata:
            self._hash_index.add_entry(metadata['sha256'].lower(), new_path)

        # Update folders list
        all_folders = set(item['folder'] for item in cache.raw_data)
        cache.folders = sorted(list(all_folders), key=lambda x: x.lower())

        # Update tags count with the new/updated tags
        if 'tags' in metadata:
            for tag in metadata.get('tags', []):
                self._tags_count[tag] = self._tags_count.get(tag, 0) + 1

    # Resort cache
    await cache.resort()

    return True
def get_hash_by_path(self, file_path: str) -> Optional[str]:
    """Get hash for a model by its file path."""
    return self._path_to_hash_lookup(file_path) if False else self._hash_index.get_hash(file_path)

def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
    """Get preview static URL for a model by its hash.

    Returns:
        Static URL (via config.get_preview_static_url) of the first preview
        file found next to the model, or None when no model or preview exists.
    """
    # Get the file path first
    file_path = self._hash_index.get_path(sha256.lower())
    if not file_path:
        return None

    # Determine the preview file path (typically same name with different extension)
    base_name = os.path.splitext(file_path)[0]
    # Order matters: explicit ".preview.*" files take precedence over plain ones.
    preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
                          '.png', '.jpeg', '.jpg', '.mp4']

    for ext in preview_extensions:
        preview_path = f"{base_name}{ext}"
        if os.path.exists(preview_path):
            # Convert to static URL using config
            return config.get_preview_static_url(preview_path)

    return None

async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
    """Get top tags sorted by count.

    Args:
        limit: Maximum number of tags to return.

    Returns:
        List of {"tag": name, "count": n} dicts, sorted by count descending.
    """
    # Make sure cache is initialized
    await self.get_cached_data()

    # Sort tags by count in descending order
    sorted_tags = sorted(
        [{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
        key=lambda x: x['count'],
        reverse=True
    )

    # Return limited number
    return sorted_tags[:limit]
async def calculate_sha256(file_path: str) -> str:
    """Calculate the SHA256 hash of a file, read in 128 KiB chunks.

    Args:
        file_path: Path of the file to hash.

    Returns:
        Lowercase hexadecimal SHA256 digest.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as fh:
        # Large chunks amortize syscall overhead on multi-GB model files.
        while chunk := fh.read(128 * 1024):
            digest.update(chunk)
    return digest.hexdigest()
os.path.realpath(file_path) @@ -74,27 +74,52 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]: try: # If we didn't get SHA256 from the .json file, calculate it if not sha256: + start_time = time.time() sha256 = await calculate_sha256(real_path) + logger.debug(f"Calculated SHA256 for {file_path} in {time.time() - start_time:.2f} seconds") + + # Create default metadata based on model class + if model_class == CheckpointMetadata: + metadata = CheckpointMetadata( + file_name=base_name, + model_name=base_name, + file_path=normalize_path(file_path), + size=os.path.getsize(real_path), + modified=os.path.getmtime(real_path), + sha256=sha256, + base_model="Unknown", # Will be updated later + preview_url=normalize_path(preview_url), + tags=[], + modelDescription="", + model_type="checkpoint" + ) - metadata = LoraMetadata( - file_name=base_name, - model_name=base_name, - file_path=normalize_path(file_path), - size=os.path.getsize(real_path), - modified=os.path.getmtime(real_path), - sha256=sha256, - base_model="Unknown", # Will be updated later - usage_tips="", - notes="", - from_civitai=True, - preview_url=normalize_path(preview_url), - tags=[], - modelDescription="" - ) + # Extract checkpoint-specific metadata + # model_info = await extract_checkpoint_metadata(real_path) + # metadata.base_model = model_info['base_model'] + # if 'model_type' in model_info: + # metadata.model_type = model_info['model_type'] + + else: # Default to LoraMetadata + metadata = LoraMetadata( + file_name=base_name, + model_name=base_name, + file_path=normalize_path(file_path), + size=os.path.getsize(real_path), + modified=os.path.getmtime(real_path), + sha256=sha256, + base_model="Unknown", # Will be updated later + usage_tips="{}", + preview_url=normalize_path(preview_url), + tags=[], + modelDescription="" + ) + + # Extract lora-specific metadata + model_info = await extract_lora_metadata(real_path) + metadata.base_model = model_info['base_model'] - # create metadata file - 
base_model_info = await extract_lora_metadata(real_path) - metadata.base_model = base_model_info['base_model'] + # Save metadata to file await save_metadata(file_path, metadata) return metadata @@ -102,7 +127,7 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]: logger.error(f"Error getting file info for {file_path}: {e}") return None -async def save_metadata(file_path: str, metadata: LoraMetadata) -> None: +async def save_metadata(file_path: str, metadata: BaseModelMetadata) -> None: """Save metadata to .metadata.json file""" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" try: @@ -115,7 +140,7 @@ async def save_metadata(file_path: str, metadata: LoraMetadata) -> None: except Exception as e: print(f"Error saving metadata to {metadata_path}: {str(e)}") -async def load_metadata(file_path: str) -> Optional[LoraMetadata]: +async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]: """Load metadata from .metadata.json file""" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" try: @@ -162,12 +187,22 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]: if 'modelDescription' not in data: data['modelDescription'] = "" needs_update = True + + # For checkpoint metadata + if model_class == CheckpointMetadata and 'model_type' not in data: + data['model_type'] = "checkpoint" + needs_update = True + + # For lora metadata + if model_class == LoraMetadata and 'usage_tips' not in data: + data['usage_tips'] = "{}" + needs_update = True if needs_update: with open(metadata_path, 'w', encoding='utf-8') as f: json.dump(data, f, indent=2, ensure_ascii=False) - return LoraMetadata.from_dict(data) + return model_class.from_dict(data) except Exception as e: print(f"Error loading metadata from {metadata_path}: {str(e)}") diff --git a/py/utils/lora_metadata.py b/py/utils/lora_metadata.py index c04cd829..f221562d 100644 --- a/py/utils/lora_metadata.py +++ 
b/py/utils/lora_metadata.py @@ -1,6 +1,7 @@ from safetensors import safe_open from typing import Dict from .model_utils import determine_base_model +import os async def extract_lora_metadata(file_path: str) -> Dict: """Extract essential metadata from safetensors file""" @@ -13,4 +14,67 @@ async def extract_lora_metadata(file_path: str) -> Dict: return {"base_model": base_model} except Exception as e: print(f"Error reading metadata from {file_path}: {str(e)}") - return {"base_model": "Unknown"} \ No newline at end of file + return {"base_model": "Unknown"} + +async def extract_checkpoint_metadata(file_path: str) -> dict: + """Extract metadata from a checkpoint file to determine model type and base model""" + try: + # Analyze filename for clues about the model + filename = os.path.basename(file_path).lower() + + model_info = { + 'base_model': 'Unknown', + 'model_type': 'checkpoint' + } + + # Detect base model from filename + if 'xl' in filename or 'sdxl' in filename: + model_info['base_model'] = 'SDXL' + elif 'sd3' in filename: + model_info['base_model'] = 'SD3' + elif 'sd2' in filename or 'v2' in filename: + model_info['base_model'] = 'SD2.x' + elif 'sd1' in filename or 'v1' in filename: + model_info['base_model'] = 'SD1.5' + + # Detect model type from filename + if 'inpaint' in filename: + model_info['model_type'] = 'inpainting' + elif 'anime' in filename: + model_info['model_type'] = 'anime' + elif 'realistic' in filename: + model_info['model_type'] = 'realistic' + + # Try to peek at the safetensors file structure if available + if file_path.endswith('.safetensors'): + import json + import struct + + with open(file_path, 'rb') as f: + header_size = struct.unpack(' 'LoraMetadata': - """Create LoraMetadata instance from dictionary""" - # Create a copy of the data to avoid modifying the input + def from_dict(cls, data: Dict) -> 'BaseModelMetadata': + """Create instance from dictionary""" data_copy = data.copy() return cls(**data_copy) - @classmethod - def 
from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata': - """Create LoraMetadata instance from Civitai version info""" - file_name = file_info['name'] - base_model = determine_base_model(version_info.get('baseModel', '')) - - return cls( - file_name=os.path.splitext(file_name)[0], - model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]), - file_path=save_path.replace(os.sep, '/'), - size=file_info.get('sizeKB', 0) * 1024, - modified=datetime.now().timestamp(), - sha256=file_info['hashes'].get('SHA256', '').lower(), - base_model=base_model, - preview_url=None, # Will be updated after preview download - preview_nsfw_level=0, # Will be updated after preview download, it is decided by the nsfw level of the preview image - from_civitai=True, - civitai=version_info - ) - def to_dict(self) -> Dict: """Convert to dictionary for JSON serialization""" return asdict(self) @@ -76,30 +54,54 @@ class LoraMetadata: self.file_path = file_path.replace(os.sep, '/') @dataclass -class CheckpointMetadata: - """Represents the metadata structure for a Checkpoint model""" - file_name: str # The filename without extension - model_name: str # The checkpoint's name defined by the creator - file_path: str # Full path to the model file - size: int # File size in bytes - modified: float # Last modified timestamp - sha256: str # SHA256 hash of the file - base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.) - preview_url: str # Preview image URL - preview_nsfw_level: int = 0 # NSFW level of the preview image - model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.) 
- notes: str = "" # Additional notes - from_civitai: bool = True # Whether from Civitai - civitai: Optional[Dict] = None # Civitai API data if available - tags: List[str] = None # Model tags - modelDescription: str = "" # Full model description - - # Additional checkpoint-specific fields - resolution: Optional[str] = None # Native resolution (e.g., 512x512, 1024x1024) - vae_included: bool = False # Whether VAE is included in the checkpoint - architecture: str = "" # Model architecture (if known) - - def __post_init__(self): - if self.tags is None: - self.tags = [] +class LoraMetadata(BaseModelMetadata): + """Represents the metadata structure for a Lora model""" + usage_tips: str = "{}" # Usage tips for the model, json string + + @classmethod + def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata': + """Create LoraMetadata instance from Civitai version info""" + file_name = file_info['name'] + base_model = determine_base_model(version_info.get('baseModel', '')) + + return cls( + file_name=os.path.splitext(file_name)[0], + model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]), + file_path=save_path.replace(os.sep, '/'), + size=file_info.get('sizeKB', 0) * 1024, + modified=datetime.now().timestamp(), + sha256=file_info['hashes'].get('SHA256', '').lower(), + base_model=base_model, + preview_url=None, # Will be updated after preview download + preview_nsfw_level=0, # Will be updated after preview download + from_civitai=True, + civitai=version_info + ) + +@dataclass +class CheckpointMetadata(BaseModelMetadata): + """Represents the metadata structure for a Checkpoint model""" + model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.) 
+ + @classmethod + def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata': + """Create CheckpointMetadata instance from Civitai version info""" + file_name = file_info['name'] + base_model = determine_base_model(version_info.get('baseModel', '')) + model_type = version_info.get('type', 'checkpoint') + + return cls( + file_name=os.path.splitext(file_name)[0], + model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]), + file_path=save_path.replace(os.sep, '/'), + size=file_info.get('sizeKB', 0) * 1024, + modified=datetime.now().timestamp(), + sha256=file_info['hashes'].get('SHA256', '').lower(), + base_model=base_model, + preview_url=None, # Will be updated after preview download + preview_nsfw_level=0, + from_civitai=True, + civitai=version_info, + model_type=model_type + )