mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 13:12:12 -03:00
Refactor _prepare_checkpoint_paths() to return a tuple instead of having side effects on instance variables. This prevents extra unet paths from being incorrectly classified as checkpoints when processing extra paths.
- Changed return type from List[str] to Tuple[List[str], List[str], List[str]] (all_paths, checkpoint_roots, unet_roots)
- Updated _init_checkpoint_paths() and _apply_library_paths() callers
- Fixed extra paths processing to properly isolate main and extra roots
- Updated test_checkpoint_path_overlap.py tests for new API
This ensures models in extra unet paths are correctly identified as diffusion_model type and don't appear in checkpoints list.
329 lines
12 KiB
Python
329 lines
12 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from ..utils.models import CheckpointMetadata
|
|
from ..utils.file_utils import find_preview_file, normalize_path
|
|
from ..utils.metadata_manager import MetadataManager
|
|
from ..config import config
|
|
from .model_scanner import ModelScanner
|
|
from .model_hash_index import ModelHashIndex
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CheckpointScanner(ModelScanner):
    """Service for scanning and managing checkpoint files.

    Entries are tagged with a ``sub_type`` of ``"checkpoint"`` or
    ``"diffusion_model"`` depending on which configured root directory they
    live under (see ``_resolve_sub_type``). SHA256 hashes are not computed
    while scanning; they are calculated on demand and tracked through the
    metadata ``hash_status`` field (``"pending"``, ``"calculating"``,
    ``"completed"`` or ``"failed"``).
    """

    def __init__(self):
        """Register the checkpoint model type and its file extensions with the base scanner."""
        # Define supported file extensions
        file_extensions = {
            ".ckpt",
            ".pt",
            ".pt2",
            ".bin",
            ".pth",
            ".safetensors",
            ".pkl",
            ".sft",
            ".gguf",
        }
        super().__init__(
            model_type="checkpoint",
            model_class=CheckpointMetadata,
            file_extensions=file_extensions,
            hash_index=ModelHashIndex(),
        )

    async def _create_default_metadata(
        self, file_path: str
    ) -> Optional[CheckpointMetadata]:
        """Create default metadata for checkpoint without calculating hash (lazy hash).

        Checkpoints are typically large (10GB+), so we skip hash calculation during initial
        scanning to improve startup performance. Hash will be calculated on-demand when
        fetching metadata from Civitai.

        Args:
            file_path: Path to the model file (symlinks are resolved for the
                existence check and the size lookup only)

        Returns:
            The newly created and saved metadata, or None if the file does not
            exist or any step failed (the error is logged)
        """
        try:
            real_path = os.path.realpath(file_path)
            if not os.path.exists(real_path):
                logger.error(f"File not found: {file_path}")
                return None

            base_name = os.path.splitext(os.path.basename(file_path))[0]
            dir_path = os.path.dirname(file_path)

            # Find preview image
            preview_url = find_preview_file(base_name, dir_path)

            # Create metadata WITHOUT calculating hash
            metadata = CheckpointMetadata(
                file_name=base_name,
                model_name=base_name,
                file_path=normalize_path(file_path),
                size=os.path.getsize(real_path),
                # Time the metadata record is created, not the file's mtime
                modified=datetime.now().timestamp(),
                sha256="",  # Empty hash - will be calculated on-demand
                base_model="Unknown",
                preview_url=normalize_path(preview_url),
                tags=[],
                modelDescription="",
                sub_type="checkpoint",
                from_civitai=False,  # Mark as local model since no hash yet
                hash_status="pending",  # Mark hash as pending
            )

            # Save the created metadata
            logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}")
            await MetadataManager.save_metadata(file_path, metadata)

            return metadata

        except Exception as e:
            logger.error(
                f"Error creating default checkpoint metadata for {file_path}: {e}"
            )
            return None

    async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
        """Calculate hash for a checkpoint on-demand.

        The stored metadata is moved through ``"calculating"`` to
        ``"completed"`` (or ``"failed"`` on error), and the hash index is
        updated with the lower-cased hash once it is known.

        Args:
            file_path: Path to the model file

        Returns:
            SHA256 hash string, or None if calculation failed
        """
        # Imported at call time, as in the original implementation
        from ..utils.file_utils import calculate_sha256

        try:
            real_path = os.path.realpath(file_path)
            if not os.path.exists(real_path):
                logger.error(f"File not found for hash calculation: {file_path}")
                return None

            # Load current metadata
            metadata, _ = await MetadataManager.load_metadata(
                file_path, self.model_class
            )
            if metadata is None:
                logger.error(f"No metadata found for {file_path}")
                return None

            # Check if hash is already calculated
            if metadata.hash_status == "completed" and metadata.sha256:
                return metadata.sha256

            # Update status to calculating
            metadata.hash_status = "calculating"
            await MetadataManager.save_metadata(file_path, metadata)

            # Calculate hash
            logger.info(f"Calculating hash for checkpoint: {file_path}")
            sha256 = await calculate_sha256(real_path)

            # Update metadata with hash
            metadata.sha256 = sha256
            metadata.hash_status = "completed"
            await MetadataManager.save_metadata(file_path, metadata)

            # Update hash index
            self._hash_index.add_entry(sha256.lower(), file_path)

            logger.info(f"Hash calculated for checkpoint: {file_path}")
            return sha256

        except Exception as e:
            logger.error(f"Error calculating hash for {file_path}: {e}")
            # Update status to failed. This is best-effort: a second failure
            # must not mask the original error, but it is logged rather than
            # silently dropped so a stuck "calculating" status can be traced.
            try:
                metadata, _ = await MetadataManager.load_metadata(
                    file_path, self.model_class
                )
                if metadata:
                    metadata.hash_status = "failed"
                    await MetadataManager.save_metadata(file_path, metadata)
            except Exception as status_error:
                logger.debug(
                    f"Could not mark hash as failed for {file_path}: {status_error}"
                )
            return None

    async def calculate_all_pending_hashes(
        self, progress_callback=None
    ) -> Dict[str, int]:
        """Calculate hashes for all checkpoints with pending hash status.

        If the cache is unavailable or empty, scans filesystem directly for
        metadata files with hash_status != 'completed'.

        Args:
            progress_callback: Optional awaitable callback(progress, total, current_file),
                invoked after each model is processed. Errors raised by the
                callback are logged and otherwise ignored.

        Returns:
            Dict with 'completed', 'failed', 'total' counts. Entries without a
            file path cannot be processed and are excluded from all three
            counts, so completed + failed == total.
        """
        # Try to get from cache first
        cache = await self.get_cached_data()

        if cache and cache.raw_data:
            # Use cache if available
            pending_models = [
                item
                for item in cache.raw_data
                if item.get("hash_status") != "completed" or not item.get("sha256")
            ]
        else:
            # Cache not initialized, scan filesystem directly
            pending_models = await self._find_pending_models_from_filesystem()

        # Drop entries that carry no path up front so that the reported total
        # and the progress values only cover models that are actually hashed.
        pending_paths = [
            path
            for path in (model.get("file_path") for model in pending_models)
            if path
        ]

        if not pending_paths:
            return {"completed": 0, "failed": 0, "total": 0}

        total = len(pending_paths)
        completed = 0
        failed = 0

        for position, file_path in enumerate(pending_paths, start=1):
            try:
                sha256 = await self.calculate_hash_for_model(file_path)
            except Exception as e:
                logger.error(f"Error calculating hash for {file_path}: {e}")
                sha256 = None

            if sha256:
                completed += 1
            else:
                failed += 1

            if progress_callback:
                try:
                    await progress_callback(position, total, file_path)
                except Exception as callback_error:
                    # A broken progress reporter must not abort the batch
                    logger.debug(
                        f"Progress callback failed for {file_path}: {callback_error}"
                    )

        return {"completed": completed, "failed": failed, "total": total}

    async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]:
        """Scan filesystem for checkpoint metadata files with pending hash status.

        Walks every directory returned by ``get_model_roots()`` and reads each
        ``*.metadata.json`` file. A metadata file is considered pending when
        its ``hash_status`` is not ``"completed"`` or its ``sha256`` is empty,
        and it is reported only if a model file with the same base name and a
        supported extension exists next to it.

        Returns:
            One dict per pending model: the metadata file's content with
            ``file_path`` (forward-slash separated), ``hash_status`` and
            ``sha256`` set from the scan.
        """
        metadata_suffix = ".metadata.json"
        pending_models: List[Dict[str, Any]] = []
        # Sorted so that the same model file is picked on every run when
        # several files share a base name but differ in extension.
        extensions = sorted(self.file_extensions)

        for root_path in self.get_model_roots():
            if not os.path.exists(root_path):
                continue

            for dirpath, _dirnames, filenames in os.walk(root_path):
                for filename in filenames:
                    if not filename.endswith(metadata_suffix):
                        continue

                    metadata_path = os.path.join(dirpath, filename)
                    try:
                        with open(metadata_path, "r", encoding="utf-8") as f:
                            data = json.load(f)

                        # Check if hash is pending
                        hash_status = data.get("hash_status", "completed")
                        sha256 = data.get("sha256", "")
                        if hash_status == "completed" and sha256:
                            continue

                        # Find corresponding model file. Only the trailing
                        # suffix is removed; the base name may itself contain
                        # the same substring.
                        model_name = filename[: -len(metadata_suffix)]
                        candidates = (
                            os.path.join(dirpath, model_name + ext)
                            for ext in extensions
                        )
                        model_path = next(
                            (path for path in candidates if os.path.exists(path)),
                            None,
                        )
                        if model_path is None:
                            continue

                        entry: Dict[str, Any] = {
                            "file_path": model_path.replace(os.sep, "/"),
                            "hash_status": hash_status,
                            "sha256": sha256,
                        }
                        # Carry over the remaining metadata fields without
                        # overriding the three values determined above.
                        entry.update(
                            {k: v for k, v in data.items() if k not in entry}
                        )
                        pending_models.append(entry)
                    except Exception as e:
                        # Unreadable, malformed or unexpectedly shaped metadata
                        # files are skipped; the scan is best-effort.
                        logger.debug(
                            f"Error reading metadata file {metadata_path}: {e}"
                        )
                        continue

        return pending_models

    def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
        """Resolve the sub-type based on the root path.

        Checks both standard ComfyUI paths and LoRA Manager's extra folder paths.
        Checkpoint roots are consulted before unet roots, so a path listed in
        both resolves to ``"checkpoint"``.

        Args:
            root_path: Root directory a model was found under

        Returns:
            ``"checkpoint"``, ``"diffusion_model"``, or None when the root is
            empty or not one of the configured roots
        """
        if not root_path:
            return None

        # Standard and extra checkpoint paths
        if root_path in (config.checkpoints_roots or ()) or root_path in (
            config.extra_checkpoints_roots or ()
        ):
            return "checkpoint"

        # Standard and extra unet paths
        if root_path in (config.unet_roots or ()) or root_path in (
            config.extra_unet_roots or ()
        ):
            return "diffusion_model"

        return None

    def adjust_metadata(self, metadata, file_path, root_path):
        """Adjust metadata during scanning to set sub_type.

        ``metadata.sub_type`` is overwritten only when ``root_path`` resolves
        to a known sub-type; ``file_path`` is not used here. Returns the same
        metadata object.
        """
        sub_type = self._resolve_sub_type(root_path)
        if sub_type:
            metadata.sub_type = sub_type
        return metadata

    def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
        """Adjust entries loaded from the persisted cache to ensure sub_type is set.

        The entry's ``file_path`` is mapped to a root via
        ``_find_root_for_file`` and ``entry["sub_type"]`` is overwritten when
        that root resolves to a known sub-type. Returns the same dict.
        """
        sub_type = self._resolve_sub_type(
            self._find_root_for_file(entry.get("file_path"))
        )
        if sub_type:
            entry["sub_type"] = sub_type
        return entry

    def get_model_roots(self) -> List[str]:
        """Get checkpoint root directories (including extra paths).

        Concatenates ``config.base_models_roots``,
        ``config.extra_checkpoints_roots`` and ``config.extra_unet_roots``
        (each may be unset/empty) and removes duplicates while preserving
        first-seen order.
        """
        roots: List[str] = [
            *(config.base_models_roots or []),
            *(config.extra_checkpoints_roots or []),
            *(config.extra_unet_roots or []),
        ]
        # dict keys keep insertion order, giving an order-preserving de-dup
        return list(dict.fromkeys(roots))