feat: lazy hash calculation for checkpoints

Checkpoints are typically large (10GB+). This change delays SHA256
hash calculation until metadata fetch from Civitai is requested,
significantly improving initial scan performance.

- Add hash_status field to BaseModelMetadata
- CheckpointScanner skips hash during initial scan
- On-demand hash calculation during Civitai fetch
- Background bulk hash calculation support
This commit is contained in:
Will Miao
2026-02-26 22:39:11 +08:00
parent 9f15c1fc06
commit 40d9f8d0aa
5 changed files with 514 additions and 4 deletions

View File

@@ -383,10 +383,28 @@ class ModelManagementHandler:
return web.json_response(
{"success": False, "error": "Model not found in cache"}, status=404
)
if not model_data.get("sha256"):
return web.json_response(
{"success": False, "error": "No SHA256 hash found"}, status=400
)
# Check if hash needs to be calculated (lazy hash for checkpoints)
sha256 = model_data.get("sha256")
hash_status = model_data.get("hash_status", "completed")
if not sha256 or hash_status != "completed":
# For checkpoints, calculate hash on-demand
scanner = self._service.scanner
if hasattr(scanner, 'calculate_hash_for_model'):
self._logger.info(f"Lazy hash calculation triggered for {file_path}")
sha256 = await scanner.calculate_hash_for_model(file_path)
if not sha256:
return web.json_response(
{"success": False, "error": "Failed to calculate SHA256 hash"}, status=500
)
# Update model_data with new hash
model_data["sha256"] = sha256
model_data["hash_status"] = "completed"
else:
return web.json_response(
{"success": False, "error": "No SHA256 hash found"}, status=400
)
await MetadataManager.hydrate_model_data(model_data)

View File

@@ -1,7 +1,12 @@
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
from ..utils.models import CheckpointMetadata
from ..utils.file_utils import find_preview_file, normalize_path
from ..utils.metadata_manager import MetadataManager
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
@@ -21,6 +26,216 @@ class CheckpointScanner(ModelScanner):
hash_index=ModelHashIndex()
)
async def _create_default_metadata(self, file_path: str) -> Optional[CheckpointMetadata]:
    """Create default metadata for checkpoint without calculating hash (lazy hash).

    Checkpoints are typically large (10GB+), so we skip hash calculation during initial
    scanning to improve startup performance. Hash will be calculated on-demand when
    fetching metadata from Civitai.

    Args:
        file_path: Path to the checkpoint file (may be a symlink; resolved below).

    Returns:
        The freshly created and persisted CheckpointMetadata, or None if the
        file is missing or metadata creation/saving raised.
    """
    try:
        # Resolve symlinks so size/existence checks hit the real file.
        real_path = os.path.realpath(file_path)
        if not os.path.exists(real_path):
            logger.error(f"File not found: {file_path}")
            return None
        base_name = os.path.splitext(os.path.basename(file_path))[0]
        dir_path = os.path.dirname(file_path)
        # Find preview image
        # NOTE(review): find_preview_file may return a falsy value when no
        # preview exists — confirm normalize_path tolerates that.
        preview_url = find_preview_file(base_name, dir_path)
        # Create metadata WITHOUT calculating hash
        metadata = CheckpointMetadata(
            file_name=base_name,
            model_name=base_name,
            file_path=normalize_path(file_path),
            size=os.path.getsize(real_path),
            modified=datetime.now().timestamp(),
            sha256="",  # Empty hash - will be calculated on-demand
            base_model="Unknown",
            preview_url=normalize_path(preview_url),
            tags=[],
            modelDescription="",
            sub_type="checkpoint",
            from_civitai=False,  # Mark as local model since no hash yet
            hash_status="pending"  # Mark hash as pending
        )
        # Save the created metadata so the pending status survives restarts.
        logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}")
        await MetadataManager.save_metadata(file_path, metadata)
        return metadata
    except Exception as e:
        # Broad catch: a single unreadable file must not abort the whole scan.
        logger.error(f"Error creating default checkpoint metadata for {file_path}: {e}")
        return None
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
    """Calculate hash for a checkpoint on-demand.

    Idempotent: returns the stored hash immediately when metadata already has
    hash_status == "completed" and a non-empty sha256. Otherwise transitions
    the metadata through "calculating" -> "completed" (or "failed" on error).

    Args:
        file_path: Path to the model file

    Returns:
        SHA256 hash string, or None if calculation failed
    """
    # Function-scope import keeps the heavy hashing helper off the module
    # import path until a hash is actually needed.
    from ..utils.file_utils import calculate_sha256
    try:
        real_path = os.path.realpath(file_path)
        if not os.path.exists(real_path):
            logger.error(f"File not found for hash calculation: {file_path}")
            return None
        # Load current metadata
        metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
        if metadata is None:
            logger.error(f"No metadata found for {file_path}")
            return None
        # Check if hash is already calculated — skip the expensive work.
        if metadata.hash_status == "completed" and metadata.sha256:
            return metadata.sha256
        # Persist "calculating" first so observers (UI, bulk job) can see
        # progress. NOTE(review): no lock here — two concurrent callers may
        # both compute the hash; the result is identical, only wasted work.
        metadata.hash_status = "calculating"
        await MetadataManager.save_metadata(file_path, metadata)
        # Calculate hash (the long-running step for 10GB+ files).
        logger.info(f"Calculating hash for checkpoint: {file_path}")
        sha256 = await calculate_sha256(real_path)
        # Update metadata with hash
        metadata.sha256 = sha256
        metadata.hash_status = "completed"
        await MetadataManager.save_metadata(file_path, metadata)
        # Update hash index (lowercased keys, matching index convention).
        self._hash_index.add_entry(sha256.lower(), file_path)
        logger.info(f"Hash calculated for checkpoint: {file_path}")
        return sha256
    except Exception as e:
        logger.error(f"Error calculating hash for {file_path}: {e}")
        # Best-effort: record "failed" so bulk jobs can retry/report it;
        # swallow secondary errors — the original failure is already logged.
        try:
            metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class)
            if metadata:
                metadata.hash_status = "failed"
                await MetadataManager.save_metadata(file_path, metadata)
        except Exception:
            pass
        return None
async def calculate_all_pending_hashes(self, progress_callback=None) -> Dict[str, int]:
    """Calculate hashes for all checkpoints with pending hash status.

    If cache is not initialized, scans filesystem directly for metadata files
    with hash_status != 'completed'.

    Args:
        progress_callback: Optional async callback(progress, total, current_file).
            Exceptions raised by the callback are swallowed so reporting can
            never abort the bulk job.

    Returns:
        Dict with 'completed', 'failed', 'total' counts. completed + failed
        always equals total.
    """
    # Prefer the in-memory cache; fall back to a filesystem scan when the
    # cache has not been built yet (e.g. bulk job triggered before startup scan).
    cache = await self.get_cached_data()
    if cache and cache.raw_data:
        pending_models = [
            item for item in cache.raw_data
            if item.get('hash_status') != 'completed' or not item.get('sha256')
        ]
    else:
        pending_models = await self._find_pending_models_from_filesystem()
    if not pending_models:
        return {'completed': 0, 'failed': 0, 'total': 0}
    total = len(pending_models)
    completed = 0
    failed = 0
    for i, model_data in enumerate(pending_models):
        file_path = model_data.get('file_path')
        if not file_path:
            # FIX: previously this entry was skipped with `continue`, so it was
            # counted in `total` but in neither `completed` nor `failed`, and
            # the progress callback was never invoked for it. Count it as a
            # failure and still report progress below.
            failed += 1
        else:
            try:
                sha256 = await self.calculate_hash_for_model(file_path)
                if sha256:
                    completed += 1
                else:
                    failed += 1
            except Exception as e:
                logger.error(f"Error calculating hash for {file_path}: {e}")
                failed += 1
        if progress_callback:
            try:
                await progress_callback(i + 1, total, file_path)
            except Exception:
                pass
    return {
        'completed': completed,
        'failed': failed,
        'total': total
    }
async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]:
    """Scan filesystem for checkpoint metadata files with pending hash status.

    Walks every configured model root, reads each ``*.metadata.json`` sidecar,
    and collects entries whose hash is not completed, pairing each with the
    model file it describes (matched by base name across self.file_extensions).

    Returns:
        List of metadata dicts, each guaranteed to contain 'file_path'
        (forward-slash normalized), 'hash_status' and 'sha256' keys.
    """
    suffix = '.metadata.json'
    pending_models = []
    for root_path in self.get_model_roots():
        if not os.path.exists(root_path):
            continue
        for dirpath, _dirnames, filenames in os.walk(root_path):
            for filename in filenames:
                if not filename.endswith(suffix):
                    continue
                metadata_path = os.path.join(dirpath, filename)
                try:
                    with open(metadata_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    # Treat a missing hash_status as "completed" for
                    # backward compatibility with pre-lazy-hash metadata.
                    hash_status = data.get('hash_status', 'completed')
                    sha256 = data.get('sha256', '')
                    if hash_status != 'completed' or not sha256:
                        # FIX: was filename.replace(suffix, ''), which would
                        # also strip the suffix string if it appeared
                        # mid-name; slice off the trailing suffix only.
                        model_name = filename[:-len(suffix)]
                        model_path = None
                        # Look for a sibling model file with a matching name.
                        for ext in self.file_extensions:
                            potential_path = os.path.join(dirpath, model_name + ext)
                            if os.path.exists(potential_path):
                                model_path = potential_path
                                break
                        if model_path:
                            pending_models.append({
                                'file_path': model_path.replace(os.sep, '/'),
                                'hash_status': hash_status,
                                'sha256': sha256,
                                **{k: v for k, v in data.items() if k not in ['file_path', 'hash_status', 'sha256']}
                            })
                # FIX: the original `except (json.JSONDecodeError, Exception)`
                # was redundant — Exception already subsumes JSONDecodeError.
                except Exception as e:
                    logger.debug(f"Error reading metadata file {metadata_path}: {e}")
                    continue
    return pending_models
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
"""Resolve the sub-type based on the root path."""
if not root_path:

View File

@@ -282,6 +282,11 @@ class ModelScanner:
sub_type = get_value('sub_type', None)
if sub_type:
entry['sub_type'] = sub_type
# Handle hash_status for lazy hash calculation (checkpoints)
hash_status = get_value('hash_status', 'completed')
if hash_status:
entry['hash_status'] = hash_status
return entry

View File

@@ -28,6 +28,7 @@ class BaseModelMetadata:
skip_metadata_refresh: bool = False # Whether to skip this model during bulk metadata refresh
metadata_source: Optional[str] = None # Last provider that supplied metadata
last_checked_at: float = 0 # Last checked timestamp
hash_status: str = "completed" # Hash calculation status: pending | calculating | completed | failed
_unknown_fields: Dict[str, Any] = field(default_factory=dict, repr=False, compare=False) # Store unknown fields
def __post_init__(self):

View File

@@ -0,0 +1,271 @@
"""Tests for checkpoint lazy hash calculation feature."""
import json
import os
from pathlib import Path
from typing import List
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from py.services import model_scanner
from py.services.checkpoint_scanner import CheckpointScanner
from py.services.model_scanner import ModelScanner
from py.utils.models import CheckpointMetadata
class RecordingWebSocketManager:
    """Test double for the websocket manager: records broadcasts instead of sending them."""

    def __init__(self) -> None:
        # Every payload passed to broadcast_init_progress, in call order.
        self.payloads: List[dict] = []

    async def broadcast_init_progress(self, payload: dict) -> None:
        """Capture *payload* for later inspection by assertions."""
        self.payloads.append(payload)
def _normalize(path: Path) -> str:
return str(path).replace(os.sep, "/")
@pytest.fixture(autouse=True)
def reset_model_scanner_singletons():
    """Clear ModelScanner singleton registries before and after every test.

    ModelScanner caches instances per subclass; without this reset, a scanner
    built in one test would leak its config/roots into the next.
    """
    ModelScanner._instances.clear()
    ModelScanner._locks.clear()
    yield
    # Clear again on teardown so this module leaves no scanner state behind.
    ModelScanner._instances.clear()
    ModelScanner._locks.clear()
@pytest.mark.asyncio
async def test_checkpoint_default_metadata_has_pending_hash(tmp_path: Path, monkeypatch):
    """Test that checkpoint metadata is created with hash_status='pending' and empty sha256."""
    checkpoints_root = tmp_path / "checkpoints"
    checkpoints_root.mkdir()
    # Create a fake checkpoint file (small for testing)
    checkpoint_file = checkpoints_root / "test_model.safetensors"
    checkpoint_file.write_text("fake checkpoint content", encoding="utf-8")
    normalized_root = _normalize(checkpoints_root)
    normalized_file = _normalize(checkpoint_file)
    # Point both root settings at the temp dir; raising=False tolerates
    # config objects that don't predefine these attributes.
    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [normalized_root],
        raising=False,
    )
    monkeypatch.setattr(
        model_scanner.config,
        "checkpoints_roots",
        [normalized_root],
        raising=False,
    )
    scanner = CheckpointScanner()
    # Create default metadata — the lazy-hash path under test.
    metadata = await scanner._create_default_metadata(normalized_file)
    assert metadata is not None
    assert metadata.sha256 == "", "sha256 should be empty for lazy hash"
    assert metadata.hash_status == "pending", "hash_status should be 'pending'"
    assert metadata.from_civitai is False, "from_civitai should be False for local models"
    assert metadata.file_name == "test_model"
@pytest.mark.asyncio
async def test_checkpoint_metadata_saved_to_disk_with_pending_status(tmp_path: Path, monkeypatch):
    """Test that pending metadata is saved to .metadata.json file."""
    checkpoints_root = tmp_path / "checkpoints"
    checkpoints_root.mkdir()
    checkpoint_file = checkpoints_root / "test_model.safetensors"
    checkpoint_file.write_text("fake content", encoding="utf-8")
    normalized_root = _normalize(checkpoints_root)
    normalized_file = _normalize(checkpoint_file)
    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [normalized_root],
        raising=False,
    )
    scanner = CheckpointScanner()
    # Create metadata — should also persist it as a sidecar JSON file.
    metadata = await scanner._create_default_metadata(normalized_file)
    assert metadata is not None
    # Verify the metadata file was created next to the model file.
    metadata_file = checkpoints_root / "test_model.metadata.json"
    assert metadata_file.exists(), "Metadata file should be created"
    # Load and verify content round-tripped with the pending markers intact.
    with open(metadata_file, "r", encoding="utf-8") as f:
        saved_data = json.load(f)
    assert saved_data.get("sha256") == "", "Saved sha256 should be empty"
    assert saved_data.get("hash_status") == "pending", "Saved hash_status should be 'pending'"
@pytest.mark.asyncio
async def test_calculate_hash_for_model_completes_pending(tmp_path: Path, monkeypatch):
    """Test that calculate_hash_for_model updates status to 'completed'."""
    checkpoints_root = tmp_path / "checkpoints"
    checkpoints_root.mkdir()
    checkpoint_file = checkpoints_root / "test_model.safetensors"
    checkpoint_file.write_text("fake content for hashing", encoding="utf-8")
    normalized_root = _normalize(checkpoints_root)
    normalized_file = _normalize(checkpoint_file)
    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [normalized_root],
        raising=False,
    )
    scanner = CheckpointScanner()
    # Create pending metadata first, then drive it through the on-demand path.
    metadata = await scanner._create_default_metadata(normalized_file)
    assert metadata is not None
    assert metadata.hash_status == "pending"
    # Calculate hash for real (file is tiny, so this is fast).
    hash_result = await scanner.calculate_hash_for_model(normalized_file)
    assert hash_result is not None, "Hash calculation should succeed"
    assert len(hash_result) == 64, "SHA256 should be 64 hex characters"
    # Verify the persisted sidecar was updated, not just the in-memory object.
    metadata_file = checkpoints_root / "test_model.metadata.json"
    with open(metadata_file, "r", encoding="utf-8") as f:
        saved_data = json.load(f)
    assert saved_data.get("sha256") == hash_result, "sha256 should be updated"
    assert saved_data.get("hash_status") == "completed", "hash_status should be 'completed'"
@pytest.mark.asyncio
async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeypatch):
    """Test that calculate_hash_for_model skips calculation if already completed."""
    checkpoints_root = tmp_path / "checkpoints"
    checkpoints_root.mkdir()
    checkpoint_file = checkpoints_root / "test_model.safetensors"
    checkpoint_file.write_text("fake content", encoding="utf-8")
    normalized_root = _normalize(checkpoints_root)
    normalized_file = _normalize(checkpoint_file)
    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [normalized_root],
        raising=False,
    )
    scanner = CheckpointScanner()
    # Create metadata with completed hash
    metadata = CheckpointMetadata(
        file_name="test_model",
        model_name="test_model",
        file_path=normalized_file,
        size=100,
        modified=1234567890.0,
        sha256="existing_hash_value",
        base_model="Unknown",
        preview_url="",
        hash_status="completed",
        from_civitai=False,
    )
    # Save metadata first so the scanner sees a completed record on disk.
    from py.utils.metadata_manager import MetadataManager
    await MetadataManager.save_metadata(normalized_file, metadata)
    # Calculate hash should return existing value without touching the mock.
    # FIX: use AsyncMock so that if the code regresses and DOES call
    # calculate_sha256, the awaited return value is still awaitable and the
    # test fails on the assertion below rather than on a TypeError.
    with patch(
        "py.utils.file_utils.calculate_sha256",
        new_callable=AsyncMock,
        return_value="new_calculated_hash",
    ) as mock_calc:
        hash_result = await scanner.calculate_hash_for_model(normalized_file)
    assert hash_result == "existing_hash_value", "Should return existing hash"
    # FIX: original read `mock_calc.assert_not_called(), "Should not ..."`,
    # which builds a throwaway tuple — the trailing message was dead code.
    mock_calc.assert_not_called()
@pytest.mark.asyncio
async def test_calculate_all_pending_hashes(tmp_path: Path, monkeypatch):
    """Test bulk hash calculation for all pending checkpoints."""
    checkpoints_root = tmp_path / "checkpoints"
    checkpoints_root.mkdir()
    # Create multiple checkpoint files with distinct contents (distinct hashes).
    for i in range(3):
        checkpoint_file = checkpoints_root / f"model_{i}.safetensors"
        checkpoint_file.write_text(f"content {i}", encoding="utf-8")
    normalized_root = _normalize(checkpoints_root)
    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [normalized_root],
        raising=False,
    )
    scanner = CheckpointScanner()
    # Create pending metadata for all models so the bulk job has work to find.
    for i in range(3):
        checkpoint_file = checkpoints_root / f"model_{i}.safetensors"
        await scanner._create_default_metadata(_normalize(checkpoint_file))
    # Record every progress callback invocation for later assertions.
    progress_calls = []

    async def progress_callback(current, total, file_path):
        progress_calls.append((current, total, file_path))
    # Calculate all pending hashes via the bulk entry point.
    result = await scanner.calculate_all_pending_hashes(progress_callback)
    assert result["total"] == 3, "Should find 3 pending models"
    assert result["completed"] == 3, "Should complete all 3"
    assert result["failed"] == 0, "Should not fail any"
    assert len(progress_calls) == 3, "Progress callback should be called 3 times"
@pytest.mark.asyncio
async def test_lora_scanner_not_affected(tmp_path: Path, monkeypatch):
    """Test that LoraScanner still calculates hash during initial scan."""
    from py.services.lora_scanner import LoraScanner
    loras_root = tmp_path / "loras"
    loras_root.mkdir()
    lora_file = loras_root / "test_lora.safetensors"
    lora_file.write_text("fake lora content", encoding="utf-8")
    normalized_root = _normalize(loras_root)
    monkeypatch.setattr(
        model_scanner.config,
        "loras_roots",
        [normalized_root],
        raising=False,
    )
    # Reset singleton for LoraScanner (the autouse fixture clears all scanners,
    # but this guards against an instance created earlier in this test module).
    if LoraScanner in ModelScanner._instances:
        del ModelScanner._instances[LoraScanner]
    scanner = LoraScanner()
    # LoraScanner should use parent's _create_default_metadata which calculates hash
    # We verify this by checking that it doesn't override the method
    assert scanner._create_default_metadata.__qualname__ == "ModelScanner._create_default_metadata"