import json import logging import os from datetime import datetime from typing import Any, Dict, List, Optional from ..utils.models import CheckpointMetadata from ..utils.file_utils import find_preview_file, normalize_path from ..utils.metadata_manager import MetadataManager from ..config import config from .model_scanner import ModelScanner from .model_hash_index import ModelHashIndex logger = logging.getLogger(__name__) class CheckpointScanner(ModelScanner): """Service for scanning and managing checkpoint files""" def __init__(self): # Define supported file extensions file_extensions = { ".ckpt", ".pt", ".pt2", ".bin", ".pth", ".safetensors", ".pkl", ".sft", ".gguf", } super().__init__( model_type="checkpoint", model_class=CheckpointMetadata, file_extensions=file_extensions, hash_index=ModelHashIndex(), ) async def _create_default_metadata( self, file_path: str ) -> Optional[CheckpointMetadata]: """Create default metadata for checkpoint without calculating hash (lazy hash). Checkpoints are typically large (10GB+), so we skip hash calculation during initial scanning to improve startup performance. Hash will be calculated on-demand when fetching metadata from Civitai. """ try: real_path = os.path.realpath(file_path) if not os.path.exists(real_path): logger.error(f"File not found: {file_path}") return None base_name = os.path.splitext(os.path.basename(file_path))[0] dir_path = os.path.dirname(file_path) # Find preview image preview_url = find_preview_file(base_name, dir_path) # Create metadata WITHOUT calculating hash metadata = CheckpointMetadata( file_name=base_name, model_name=base_name, file_path=normalize_path(file_path), size=os.path.getsize(real_path), modified=datetime.now().timestamp(), sha256="", # Empty hash - will be calculated on-demand base_model="Unknown", preview_url=normalize_path(preview_url), tags=[], modelDescription="", sub_type="checkpoint", from_civitai=False, # Mark as local model since no hash yet hash_status="pending", # Mark hash as pending ) # Save the created metadata logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}") await MetadataManager.save_metadata(file_path, metadata) return metadata except Exception as e: logger.error( f"Error creating default checkpoint metadata for {file_path}: {e}" ) return None async def calculate_hash_for_model(self, file_path: str) -> Optional[str]: """Calculate hash for a checkpoint on-demand. Args: file_path: Path to the model file Returns: SHA256 hash string, or None if calculation failed """ from ..utils.file_utils import calculate_sha256 try: real_path = os.path.realpath(file_path) if not os.path.exists(real_path): logger.error(f"File not found for hash calculation: {file_path}") return None # Load current metadata metadata, _ = await MetadataManager.load_metadata( file_path, self.model_class ) if metadata is None: logger.error(f"No metadata found for {file_path}") return None # Check if hash is already calculated if metadata.hash_status == "completed" and metadata.sha256: return metadata.sha256 # Update status to calculating metadata.hash_status = "calculating" await MetadataManager.save_metadata(file_path, metadata) # Calculate hash logger.info(f"Calculating hash for checkpoint: {file_path}") sha256 = await calculate_sha256(real_path) # Update metadata with hash metadata.sha256 = sha256 metadata.hash_status = "completed" await MetadataManager.save_metadata(file_path, metadata) # Update hash index self._hash_index.add_entry(sha256.lower(), file_path) logger.info(f"Hash calculated for checkpoint: {file_path}") return sha256 except Exception as e: logger.error(f"Error calculating hash for {file_path}: {e}") # Update status to failed try: metadata, _ = await MetadataManager.load_metadata( file_path, self.model_class ) if metadata: metadata.hash_status = "failed" await MetadataManager.save_metadata(file_path, metadata) except Exception: pass return None async def calculate_all_pending_hashes( self, progress_callback=None ) -> Dict[str, int]: """Calculate hashes for all checkpoints with pending hash status. If cache is not initialized, scans filesystem directly for metadata files with hash_status != 'completed'. Args: progress_callback: Optional callback(progress, total, current_file) Returns: Dict with 'completed', 'failed', 'total' counts """ # Try to get from cache first cache = await self.get_cached_data() if cache and cache.raw_data: # Use cache if available pending_models = [ item for item in cache.raw_data if item.get("hash_status") != "completed" or not item.get("sha256") ] else: # Cache not initialized, scan filesystem directly pending_models = await self._find_pending_models_from_filesystem() if not pending_models: return {"completed": 0, "failed": 0, "total": 0} total = len(pending_models) completed = 0 failed = 0 for i, model_data in enumerate(pending_models): file_path = model_data.get("file_path") if not file_path: continue try: sha256 = await self.calculate_hash_for_model(file_path) if sha256: completed += 1 else: failed += 1 except Exception as e: logger.error(f"Error calculating hash for {file_path}: {e}") failed += 1 if progress_callback: try: await progress_callback(i + 1, total, file_path) except Exception: pass return {"completed": completed, "failed": failed, "total": total} async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]: """Scan filesystem for checkpoint metadata files with pending hash status.""" pending_models = [] for root_path in self.get_model_roots(): if not os.path.exists(root_path): continue for dirpath, _dirnames, filenames in os.walk(root_path): for filename in filenames: if not filename.endswith(".metadata.json"): continue metadata_path = os.path.join(dirpath, filename) try: with open(metadata_path, "r", encoding="utf-8") as f: data = json.load(f) # Check if hash is pending hash_status = data.get("hash_status", "completed") sha256 = data.get("sha256", "") if hash_status != "completed" or not sha256: # Find corresponding model file model_name = filename.replace(".metadata.json", "") model_path = None # Look for model file with matching name for ext in self.file_extensions: potential_path = os.path.join(dirpath, model_name + ext) if os.path.exists(potential_path): model_path = potential_path break if model_path: pending_models.append( { "file_path": model_path.replace(os.sep, "/"), "hash_status": hash_status, "sha256": sha256, **{ k: v for k, v in data.items() if k not in [ "file_path", "hash_status", "sha256", ] }, } ) except (json.JSONDecodeError, Exception) as e: logger.debug( f"Error reading metadata file {metadata_path}: {e}" ) continue return pending_models def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]: """Resolve the sub-type based on the root path. Checks both standard ComfyUI paths and LoRA Manager's extra folder paths. """ if not root_path: return None # Check standard ComfyUI checkpoint paths if config.checkpoints_roots and root_path in config.checkpoints_roots: return "checkpoint" # Check extra checkpoint paths if ( config.extra_checkpoints_roots and root_path in config.extra_checkpoints_roots ): return "checkpoint" # Check standard ComfyUI unet paths if config.unet_roots and root_path in config.unet_roots: return "diffusion_model" # Check extra unet paths if config.extra_unet_roots and root_path in config.extra_unet_roots: return "diffusion_model" return None def adjust_metadata(self, metadata, file_path, root_path): """Adjust metadata during scanning to set sub_type.""" sub_type = self._resolve_sub_type(root_path) if sub_type: metadata.sub_type = sub_type return metadata def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]: """Adjust entries loaded from the persisted cache to ensure sub_type is set.""" sub_type = self._resolve_sub_type( self._find_root_for_file(entry.get("file_path")) ) if sub_type: entry["sub_type"] = sub_type return entry def get_model_roots(self) -> List[str]: """Get checkpoint root directories (including extra paths)""" roots: List[str] = [] roots.extend(config.base_models_roots or []) roots.extend(config.extra_checkpoints_roots or []) roots.extend(config.extra_unet_roots or []) # Remove duplicates while preserving order seen: set = set() unique_roots: List[str] = [] for root in roots: if root not in seen: seen.add(root) unique_roots.append(root) return unique_roots