mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 13:12:12 -03:00
feat: lazy hash calculation for checkpoints
Checkpoints are typically large (10GB+). This change delays SHA256 hash calculation until metadata fetch from Civitai is requested, significantly improving initial scan performance. - Add hash_status field to BaseModelMetadata - CheckpointScanner skips hash during initial scan - On-demand hash calculation during Civitai fetch - Background bulk hash calculation support
This commit is contained in:
@@ -383,10 +383,28 @@ class ModelManagementHandler:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Model not found in cache"}, status=404
|
||||
)
|
||||
if not model_data.get("sha256"):
|
||||
return web.json_response(
|
||||
{"success": False, "error": "No SHA256 hash found"}, status=400
|
||||
)
|
||||
|
||||
# Check if hash needs to be calculated (lazy hash for checkpoints)
|
||||
sha256 = model_data.get("sha256")
|
||||
hash_status = model_data.get("hash_status", "completed")
|
||||
|
||||
if not sha256 or hash_status != "completed":
|
||||
# For checkpoints, calculate hash on-demand
|
||||
scanner = self._service.scanner
|
||||
if hasattr(scanner, 'calculate_hash_for_model'):
|
||||
self._logger.info(f"Lazy hash calculation triggered for {file_path}")
|
||||
sha256 = await scanner.calculate_hash_for_model(file_path)
|
||||
if not sha256:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Failed to calculate SHA256 hash"}, status=500
|
||||
)
|
||||
# Update model_data with new hash
|
||||
model_data["sha256"] = sha256
|
||||
model_data["hash_status"] = "completed"
|
||||
else:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "No SHA256 hash found"}, status=400
|
||||
)
|
||||
|
||||
await MetadataManager.hydrate_model_data(model_data)
|
||||
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..utils.file_utils import find_preview_file, normalize_path
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
@@ -21,6 +26,216 @@ class CheckpointScanner(ModelScanner):
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
async def _create_default_metadata(self, file_path: str) -> Optional[CheckpointMetadata]:
|
||||
"""Create default metadata for checkpoint without calculating hash (lazy hash).
|
||||
|
||||
Checkpoints are typically large (10GB+), so we skip hash calculation during initial
|
||||
scanning to improve startup performance. Hash will be calculated on-demand when
|
||||
fetching metadata from Civitai.
|
||||
"""
|
||||
try:
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not os.path.exists(real_path):
|
||||
logger.error(f"File not found: {file_path}")
|
||||
return None
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
dir_path = os.path.dirname(file_path)
|
||||
|
||||
# Find preview image
|
||||
preview_url = find_preview_file(base_name, dir_path)
|
||||
|
||||
# Create metadata WITHOUT calculating hash
|
||||
metadata = CheckpointMetadata(
|
||||
file_name=base_name,
|
||||
model_name=base_name,
|
||||
file_path=normalize_path(file_path),
|
||||
size=os.path.getsize(real_path),
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256="", # Empty hash - will be calculated on-demand
|
||||
base_model="Unknown",
|
||||
preview_url=normalize_path(preview_url),
|
||||
tags=[],
|
||||
modelDescription="",
|
||||
sub_type="checkpoint",
|
||||
from_civitai=False, # Mark as local model since no hash yet
|
||||
hash_status="pending" # Mark hash as pending
|
||||
)
|
||||
|
||||
# Save the created metadata
|
||||
logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}")
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
return metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating default checkpoint metadata for {file_path}: {e}")
|
||||
return None
|
||||
|
||||
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
|
||||
"""Calculate hash for a checkpoint on-demand.
|
||||
|
||||
Args:
|
||||
file_path: Path to the model file
|
||||
|
||||
Returns:
|
||||
SHA256 hash string, or None if calculation failed
|
||||
"""
|
||||
from ..utils.file_utils import calculate_sha256
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not os.path.exists(real_path):
|
||||
logger.error(f"File not found for hash calculation: {file_path}")
|
||||
return None
|
||||
|
||||
# Load current metadata
|
||||
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
if metadata is None:
|
||||
logger.error(f"No metadata found for {file_path}")
|
||||
return None
|
||||
|
||||
# Check if hash is already calculated
|
||||
if metadata.hash_status == "completed" and metadata.sha256:
|
||||
return metadata.sha256
|
||||
|
||||
# Update status to calculating
|
||||
metadata.hash_status = "calculating"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
# Calculate hash
|
||||
logger.info(f"Calculating hash for checkpoint: {file_path}")
|
||||
sha256 = await calculate_sha256(real_path)
|
||||
|
||||
# Update metadata with hash
|
||||
metadata.sha256 = sha256
|
||||
metadata.hash_status = "completed"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
# Update hash index
|
||||
self._hash_index.add_entry(sha256.lower(), file_path)
|
||||
|
||||
logger.info(f"Hash calculated for checkpoint: {file_path}")
|
||||
return sha256
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||
# Update status to failed
|
||||
try:
|
||||
metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
if metadata:
|
||||
metadata.hash_status = "failed"
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
async def calculate_all_pending_hashes(self, progress_callback=None) -> Dict[str, int]:
|
||||
"""Calculate hashes for all checkpoints with pending hash status.
|
||||
|
||||
If cache is not initialized, scans filesystem directly for metadata files
|
||||
with hash_status != 'completed'.
|
||||
|
||||
Args:
|
||||
progress_callback: Optional callback(progress, total, current_file)
|
||||
|
||||
Returns:
|
||||
Dict with 'completed', 'failed', 'total' counts
|
||||
"""
|
||||
# Try to get from cache first
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
if cache and cache.raw_data:
|
||||
# Use cache if available
|
||||
pending_models = [
|
||||
item for item in cache.raw_data
|
||||
if item.get('hash_status') != 'completed' or not item.get('sha256')
|
||||
]
|
||||
else:
|
||||
# Cache not initialized, scan filesystem directly
|
||||
pending_models = await self._find_pending_models_from_filesystem()
|
||||
|
||||
if not pending_models:
|
||||
return {'completed': 0, 'failed': 0, 'total': 0}
|
||||
|
||||
total = len(pending_models)
|
||||
completed = 0
|
||||
failed = 0
|
||||
|
||||
for i, model_data in enumerate(pending_models):
|
||||
file_path = model_data.get('file_path')
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
try:
|
||||
sha256 = await self.calculate_hash_for_model(file_path)
|
||||
if sha256:
|
||||
completed += 1
|
||||
else:
|
||||
failed += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||
failed += 1
|
||||
|
||||
if progress_callback:
|
||||
try:
|
||||
await progress_callback(i + 1, total, file_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
'completed': completed,
|
||||
'failed': failed,
|
||||
'total': total
|
||||
}
|
||||
|
||||
async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]:
|
||||
"""Scan filesystem for checkpoint metadata files with pending hash status."""
|
||||
pending_models = []
|
||||
|
||||
for root_path in self.get_model_roots():
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
for dirpath, _dirnames, filenames in os.walk(root_path):
|
||||
for filename in filenames:
|
||||
if not filename.endswith('.metadata.json'):
|
||||
continue
|
||||
|
||||
metadata_path = os.path.join(dirpath, filename)
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Check if hash is pending
|
||||
hash_status = data.get('hash_status', 'completed')
|
||||
sha256 = data.get('sha256', '')
|
||||
|
||||
if hash_status != 'completed' or not sha256:
|
||||
# Find corresponding model file
|
||||
model_name = filename.replace('.metadata.json', '')
|
||||
model_path = None
|
||||
|
||||
# Look for model file with matching name
|
||||
for ext in self.file_extensions:
|
||||
potential_path = os.path.join(dirpath, model_name + ext)
|
||||
if os.path.exists(potential_path):
|
||||
model_path = potential_path
|
||||
break
|
||||
|
||||
if model_path:
|
||||
pending_models.append({
|
||||
'file_path': model_path.replace(os.sep, '/'),
|
||||
'hash_status': hash_status,
|
||||
'sha256': sha256,
|
||||
**{k: v for k, v in data.items() if k not in ['file_path', 'hash_status', 'sha256']}
|
||||
})
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.debug(f"Error reading metadata file {metadata_path}: {e}")
|
||||
continue
|
||||
|
||||
return pending_models
|
||||
|
||||
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
|
||||
"""Resolve the sub-type based on the root path."""
|
||||
if not root_path:
|
||||
|
||||
@@ -282,6 +282,11 @@ class ModelScanner:
|
||||
sub_type = get_value('sub_type', None)
|
||||
if sub_type:
|
||||
entry['sub_type'] = sub_type
|
||||
|
||||
# Handle hash_status for lazy hash calculation (checkpoints)
|
||||
hash_status = get_value('hash_status', 'completed')
|
||||
if hash_status:
|
||||
entry['hash_status'] = hash_status
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ class BaseModelMetadata:
|
||||
skip_metadata_refresh: bool = False # Whether to skip this model during bulk metadata refresh
|
||||
metadata_source: Optional[str] = None # Last provider that supplied metadata
|
||||
last_checked_at: float = 0 # Last checked timestamp
|
||||
hash_status: str = "completed" # Hash calculation status: pending | calculating | completed | failed
|
||||
_unknown_fields: Dict[str, Any] = field(default_factory=dict, repr=False, compare=False) # Store unknown fields
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
271
tests/services/test_checkpoint_lazy_hash.py
Normal file
271
tests/services/test_checkpoint_lazy_hash.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""Tests for checkpoint lazy hash calculation feature."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services import model_scanner
|
||||
from py.services.checkpoint_scanner import CheckpointScanner
|
||||
from py.services.model_scanner import ModelScanner
|
||||
from py.utils.models import CheckpointMetadata
|
||||
|
||||
|
||||
class RecordingWebSocketManager:
|
||||
def __init__(self) -> None:
|
||||
self.payloads: List[dict] = []
|
||||
|
||||
async def broadcast_init_progress(self, payload: dict) -> None:
|
||||
self.payloads.append(payload)
|
||||
|
||||
|
||||
def _normalize(path: Path) -> str:
|
||||
return str(path).replace(os.sep, "/")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_model_scanner_singletons():
|
||||
ModelScanner._instances.clear()
|
||||
ModelScanner._locks.clear()
|
||||
yield
|
||||
ModelScanner._instances.clear()
|
||||
ModelScanner._locks.clear()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint_default_metadata_has_pending_hash(tmp_path: Path, monkeypatch):
|
||||
"""Test that checkpoint metadata is created with hash_status='pending' and empty sha256."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
# Create a fake checkpoint file (small for testing)
|
||||
checkpoint_file = checkpoints_root / "test_model.safetensors"
|
||||
checkpoint_file.write_text("fake checkpoint content", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_file = _normalize(checkpoint_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"checkpoints_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
|
||||
# Create default metadata
|
||||
metadata = await scanner._create_default_metadata(normalized_file)
|
||||
|
||||
assert metadata is not None
|
||||
assert metadata.sha256 == "", "sha256 should be empty for lazy hash"
|
||||
assert metadata.hash_status == "pending", "hash_status should be 'pending'"
|
||||
assert metadata.from_civitai is False, "from_civitai should be False for local models"
|
||||
assert metadata.file_name == "test_model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint_metadata_saved_to_disk_with_pending_status(tmp_path: Path, monkeypatch):
|
||||
"""Test that pending metadata is saved to .metadata.json file."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_file = checkpoints_root / "test_model.safetensors"
|
||||
checkpoint_file.write_text("fake content", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_file = _normalize(checkpoint_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
|
||||
# Create metadata
|
||||
metadata = await scanner._create_default_metadata(normalized_file)
|
||||
assert metadata is not None
|
||||
|
||||
# Verify the metadata file was created
|
||||
metadata_file = checkpoints_root / "test_model.metadata.json"
|
||||
assert metadata_file.exists(), "Metadata file should be created"
|
||||
|
||||
# Load and verify content
|
||||
with open(metadata_file, "r", encoding="utf-8") as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
assert saved_data.get("sha256") == "", "Saved sha256 should be empty"
|
||||
assert saved_data.get("hash_status") == "pending", "Saved hash_status should be 'pending'"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_completes_pending(tmp_path: Path, monkeypatch):
|
||||
"""Test that calculate_hash_for_model updates status to 'completed'."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_file = checkpoints_root / "test_model.safetensors"
|
||||
checkpoint_file.write_text("fake content for hashing", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_file = _normalize(checkpoint_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
|
||||
# Create pending metadata
|
||||
metadata = await scanner._create_default_metadata(normalized_file)
|
||||
assert metadata is not None
|
||||
assert metadata.hash_status == "pending"
|
||||
|
||||
# Calculate hash
|
||||
hash_result = await scanner.calculate_hash_for_model(normalized_file)
|
||||
|
||||
assert hash_result is not None, "Hash calculation should succeed"
|
||||
assert len(hash_result) == 64, "SHA256 should be 64 hex characters"
|
||||
|
||||
# Verify metadata was updated
|
||||
metadata_file = checkpoints_root / "test_model.metadata.json"
|
||||
with open(metadata_file, "r", encoding="utf-8") as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
assert saved_data.get("sha256") == hash_result, "sha256 should be updated"
|
||||
assert saved_data.get("hash_status") == "completed", "hash_status should be 'completed'"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeypatch):
|
||||
"""Test that calculate_hash_for_model skips calculation if already completed."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_file = checkpoints_root / "test_model.safetensors"
|
||||
checkpoint_file.write_text("fake content", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_file = _normalize(checkpoint_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
|
||||
# Create metadata with completed hash
|
||||
metadata = CheckpointMetadata(
|
||||
file_name="test_model",
|
||||
model_name="test_model",
|
||||
file_path=normalized_file,
|
||||
size=100,
|
||||
modified=1234567890.0,
|
||||
sha256="existing_hash_value",
|
||||
base_model="Unknown",
|
||||
preview_url="",
|
||||
hash_status="completed",
|
||||
from_civitai=False,
|
||||
)
|
||||
|
||||
# Save metadata first
|
||||
from py.utils.metadata_manager import MetadataManager
|
||||
await MetadataManager.save_metadata(normalized_file, metadata)
|
||||
|
||||
# Calculate hash should return existing value
|
||||
with patch("py.utils.file_utils.calculate_sha256") as mock_calc:
|
||||
mock_calc.return_value = "new_calculated_hash"
|
||||
hash_result = await scanner.calculate_hash_for_model(normalized_file)
|
||||
|
||||
assert hash_result == "existing_hash_value", "Should return existing hash"
|
||||
mock_calc.assert_not_called(), "Should not recalculate if already completed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_all_pending_hashes(tmp_path: Path, monkeypatch):
|
||||
"""Test bulk hash calculation for all pending checkpoints."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
# Create multiple checkpoint files
|
||||
for i in range(3):
|
||||
checkpoint_file = checkpoints_root / f"model_{i}.safetensors"
|
||||
checkpoint_file.write_text(f"content {i}", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
|
||||
# Create pending metadata for all models
|
||||
for i in range(3):
|
||||
checkpoint_file = checkpoints_root / f"model_{i}.safetensors"
|
||||
await scanner._create_default_metadata(_normalize(checkpoint_file))
|
||||
|
||||
# Mock progress callback
|
||||
progress_calls = []
|
||||
async def progress_callback(current, total, file_path):
|
||||
progress_calls.append((current, total, file_path))
|
||||
|
||||
# Calculate all pending hashes
|
||||
result = await scanner.calculate_all_pending_hashes(progress_callback)
|
||||
|
||||
assert result["total"] == 3, "Should find 3 pending models"
|
||||
assert result["completed"] == 3, "Should complete all 3"
|
||||
assert result["failed"] == 0, "Should not fail any"
|
||||
assert len(progress_calls) == 3, "Progress callback should be called 3 times"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lora_scanner_not_affected(tmp_path: Path, monkeypatch):
|
||||
"""Test that LoraScanner still calculates hash during initial scan."""
|
||||
from py.services.lora_scanner import LoraScanner
|
||||
|
||||
loras_root = tmp_path / "loras"
|
||||
loras_root.mkdir()
|
||||
|
||||
lora_file = loras_root / "test_lora.safetensors"
|
||||
lora_file.write_text("fake lora content", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(loras_root)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"loras_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
# Reset singleton for LoraScanner
|
||||
if LoraScanner in ModelScanner._instances:
|
||||
del ModelScanner._instances[LoraScanner]
|
||||
|
||||
scanner = LoraScanner()
|
||||
|
||||
# LoraScanner should use parent's _create_default_metadata which calculates hash
|
||||
# We verify this by checking that it doesn't override the method
|
||||
assert scanner._create_default_metadata.__qualname__ == "ModelScanner._create_default_metadata"
|
||||
Reference in New Issue
Block a user