diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index fa4d078c..2ccc0436 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -383,10 +383,28 @@ class ModelManagementHandler: return web.json_response( {"success": False, "error": "Model not found in cache"}, status=404 ) - if not model_data.get("sha256"): - return web.json_response( - {"success": False, "error": "No SHA256 hash found"}, status=400 - ) + + # Check if hash needs to be calculated (lazy hash for checkpoints) + sha256 = model_data.get("sha256") + hash_status = model_data.get("hash_status", "completed") + + if not sha256 or hash_status != "completed": + # For checkpoints, calculate hash on-demand + scanner = self._service.scanner + if hasattr(scanner, 'calculate_hash_for_model'): + self._logger.info(f"Lazy hash calculation triggered for {file_path}") + sha256 = await scanner.calculate_hash_for_model(file_path) + if not sha256: + return web.json_response( + {"success": False, "error": "Failed to calculate SHA256 hash"}, status=500 + ) + # Update model_data with new hash + model_data["sha256"] = sha256 + model_data["hash_status"] = "completed" + else: + return web.json_response( + {"success": False, "error": "No SHA256 hash found"}, status=400 + ) await MetadataManager.hydrate_model_data(model_data) diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index 075f9129..f28fa439 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -1,7 +1,12 @@ +import json import logging +import os +from datetime import datetime from typing import Any, Dict, List, Optional from ..utils.models import CheckpointMetadata +from ..utils.file_utils import find_preview_file, normalize_path +from ..utils.metadata_manager import MetadataManager from ..config import config from .model_scanner import ModelScanner from .model_hash_index import ModelHashIndex @@ -21,6 +26,216 @@ class 
CheckpointScanner(ModelScanner): hash_index=ModelHashIndex() ) + async def _create_default_metadata(self, file_path: str) -> Optional[CheckpointMetadata]: + """Create default metadata for checkpoint without calculating hash (lazy hash). + + Checkpoints are typically large (10GB+), so we skip hash calculation during initial + scanning to improve startup performance. Hash will be calculated on-demand when + fetching metadata from Civitai. + """ + try: + real_path = os.path.realpath(file_path) + if not os.path.exists(real_path): + logger.error(f"File not found: {file_path}") + return None + + base_name = os.path.splitext(os.path.basename(file_path))[0] + dir_path = os.path.dirname(file_path) + + # Find preview image + preview_url = find_preview_file(base_name, dir_path) + + # Create metadata WITHOUT calculating hash + metadata = CheckpointMetadata( + file_name=base_name, + model_name=base_name, + file_path=normalize_path(file_path), + size=os.path.getsize(real_path), + modified=os.path.getmtime(real_path), + sha256="", # Empty hash - will be calculated on-demand + base_model="Unknown", + preview_url=normalize_path(preview_url), + tags=[], + modelDescription="", + sub_type="checkpoint", + from_civitai=False, # Mark as local model since no hash yet + hash_status="pending" # Mark hash as pending + ) + + # Save the created metadata + logger.info(f"Creating checkpoint metadata (hash pending) for {file_path}") + await MetadataManager.save_metadata(file_path, metadata) + + return metadata + + except Exception as e: + logger.error(f"Error creating default checkpoint metadata for {file_path}: {e}") + return None + + async def calculate_hash_for_model(self, file_path: str) -> Optional[str]: + """Calculate hash for a checkpoint on-demand. 
+ + Args: + file_path: Path to the model file + + Returns: + SHA256 hash string, or None if calculation failed + """ + from ..utils.file_utils import calculate_sha256 + + try: + real_path = os.path.realpath(file_path) + if not os.path.exists(real_path): + logger.error(f"File not found for hash calculation: {file_path}") + return None + + # Load current metadata + metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class) + if metadata is None: + logger.error(f"No metadata found for {file_path}") + return None + + # Check if hash is already calculated + if metadata.hash_status == "completed" and metadata.sha256: + return metadata.sha256 + + # Update status to calculating + metadata.hash_status = "calculating" + await MetadataManager.save_metadata(file_path, metadata) + + # Calculate hash + logger.info(f"Calculating hash for checkpoint: {file_path}") + sha256 = await calculate_sha256(real_path) + + # Update metadata with hash + metadata.sha256 = sha256 + metadata.hash_status = "completed" + await MetadataManager.save_metadata(file_path, metadata) + + # Update hash index + self._hash_index.add_entry(sha256.lower(), file_path) + + logger.info(f"Hash calculated for checkpoint: {file_path}") + return sha256 + + except Exception as e: + logger.error(f"Error calculating hash for {file_path}: {e}") + # Update status to failed + try: + metadata, _ = await MetadataManager.load_metadata(file_path, self.model_class) + if metadata: + metadata.hash_status = "failed" + await MetadataManager.save_metadata(file_path, metadata) + except Exception: + pass + return None + + async def calculate_all_pending_hashes(self, progress_callback=None) -> Dict[str, int]: + """Calculate hashes for all checkpoints with pending hash status. + + If cache is not initialized, scans filesystem directly for metadata files + with hash_status != 'completed'. 
+ + Args: + progress_callback: Optional callback(progress, total, current_file) + + Returns: + Dict with 'completed', 'failed', 'total' counts + """ + # Try to get from cache first + cache = await self.get_cached_data() + + if cache and cache.raw_data: + # Use cache if available + pending_models = [ + item for item in cache.raw_data + if item.get('hash_status') != 'completed' or not item.get('sha256') + ] + else: + # Cache not initialized, scan filesystem directly + pending_models = await self._find_pending_models_from_filesystem() + + if not pending_models: + return {'completed': 0, 'failed': 0, 'total': 0} + + total = len(pending_models) + completed = 0 + failed = 0 + + for i, model_data in enumerate(pending_models): + file_path = model_data.get('file_path') + if not file_path: + continue + + try: + sha256 = await self.calculate_hash_for_model(file_path) + if sha256: + completed += 1 + else: + failed += 1 + except Exception as e: + logger.error(f"Error calculating hash for {file_path}: {e}") + failed += 1 + + if progress_callback: + try: + await progress_callback(i + 1, total, file_path) + except Exception: + pass + + return { + 'completed': completed, + 'failed': failed, + 'total': total + } + + async def _find_pending_models_from_filesystem(self) -> List[Dict[str, Any]]: + """Scan filesystem for checkpoint metadata files with pending hash status.""" + pending_models = [] + + for root_path in self.get_model_roots(): + if not os.path.exists(root_path): + continue + + for dirpath, _dirnames, filenames in os.walk(root_path): + for filename in filenames: + if not filename.endswith('.metadata.json'): + continue + + metadata_path = os.path.join(dirpath, filename) + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + # Check if hash is pending + hash_status = data.get('hash_status', 'completed') + sha256 = data.get('sha256', '') + + if hash_status != 'completed' or not sha256: + # Find corresponding model file + model_name = 
filename.replace('.metadata.json', '') + model_path = None + + # Look for model file with matching name + for ext in self.file_extensions: + potential_path = os.path.join(dirpath, model_name + ext) + if os.path.exists(potential_path): + model_path = potential_path + break + + if model_path: + pending_models.append({ + 'file_path': model_path.replace(os.sep, '/'), + 'hash_status': hash_status, + 'sha256': sha256, + **{k: v for k, v in data.items() if k not in ['file_path', 'hash_status', 'sha256']} + }) + except (json.JSONDecodeError, Exception) as e: + logger.debug(f"Error reading metadata file {metadata_path}: {e}") + continue + + return pending_models + def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]: """Resolve the sub-type based on the root path.""" if not root_path: diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index fc0e6ec2..3c427790 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -282,6 +282,11 @@ class ModelScanner: sub_type = get_value('sub_type', None) if sub_type: entry['sub_type'] = sub_type + + # Handle hash_status for lazy hash calculation (checkpoints) + hash_status = get_value('hash_status', 'completed') + if hash_status: + entry['hash_status'] = hash_status return entry diff --git a/py/utils/models.py b/py/utils/models.py index eec140db..378b567b 100644 --- a/py/utils/models.py +++ b/py/utils/models.py @@ -28,6 +28,7 @@ class BaseModelMetadata: skip_metadata_refresh: bool = False # Whether to skip this model during bulk metadata refresh metadata_source: Optional[str] = None # Last provider that supplied metadata last_checked_at: float = 0 # Last checked timestamp + hash_status: str = "completed" # Hash calculation status: pending | calculating | completed | failed _unknown_fields: Dict[str, Any] = field(default_factory=dict, repr=False, compare=False) # Store unknown fields def __post_init__(self): diff --git a/tests/services/test_checkpoint_lazy_hash.py 
b/tests/services/test_checkpoint_lazy_hash.py new file mode 100644 index 00000000..67a8462f --- /dev/null +++ b/tests/services/test_checkpoint_lazy_hash.py @@ -0,0 +1,271 @@ +"""Tests for checkpoint lazy hash calculation feature.""" + +import json +import os +from pathlib import Path +from typing import List +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from py.services import model_scanner +from py.services.checkpoint_scanner import CheckpointScanner +from py.services.model_scanner import ModelScanner +from py.utils.models import CheckpointMetadata + + +class RecordingWebSocketManager: + def __init__(self) -> None: + self.payloads: List[dict] = [] + + async def broadcast_init_progress(self, payload: dict) -> None: + self.payloads.append(payload) + + +def _normalize(path: Path) -> str: + return str(path).replace(os.sep, "/") + + +@pytest.fixture(autouse=True) +def reset_model_scanner_singletons(): + ModelScanner._instances.clear() + ModelScanner._locks.clear() + yield + ModelScanner._instances.clear() + ModelScanner._locks.clear() + + +@pytest.mark.asyncio +async def test_checkpoint_default_metadata_has_pending_hash(tmp_path: Path, monkeypatch): + """Test that checkpoint metadata is created with hash_status='pending' and empty sha256.""" + checkpoints_root = tmp_path / "checkpoints" + checkpoints_root.mkdir() + + # Create a fake checkpoint file (small for testing) + checkpoint_file = checkpoints_root / "test_model.safetensors" + checkpoint_file.write_text("fake checkpoint content", encoding="utf-8") + + normalized_root = _normalize(checkpoints_root) + normalized_file = _normalize(checkpoint_file) + + monkeypatch.setattr( + model_scanner.config, + "base_models_roots", + [normalized_root], + raising=False, + ) + monkeypatch.setattr( + model_scanner.config, + "checkpoints_roots", + [normalized_root], + raising=False, + ) + + scanner = CheckpointScanner() + + # Create default metadata + metadata = await 
scanner._create_default_metadata(normalized_file) + + assert metadata is not None + assert metadata.sha256 == "", "sha256 should be empty for lazy hash" + assert metadata.hash_status == "pending", "hash_status should be 'pending'" + assert metadata.from_civitai is False, "from_civitai should be False for local models" + assert metadata.file_name == "test_model" + + +@pytest.mark.asyncio +async def test_checkpoint_metadata_saved_to_disk_with_pending_status(tmp_path: Path, monkeypatch): + """Test that pending metadata is saved to .metadata.json file.""" + checkpoints_root = tmp_path / "checkpoints" + checkpoints_root.mkdir() + + checkpoint_file = checkpoints_root / "test_model.safetensors" + checkpoint_file.write_text("fake content", encoding="utf-8") + + normalized_root = _normalize(checkpoints_root) + normalized_file = _normalize(checkpoint_file) + + monkeypatch.setattr( + model_scanner.config, + "base_models_roots", + [normalized_root], + raising=False, + ) + + scanner = CheckpointScanner() + + # Create metadata + metadata = await scanner._create_default_metadata(normalized_file) + assert metadata is not None + + # Verify the metadata file was created + metadata_file = checkpoints_root / "test_model.metadata.json" + assert metadata_file.exists(), "Metadata file should be created" + + # Load and verify content + with open(metadata_file, "r", encoding="utf-8") as f: + saved_data = json.load(f) + + assert saved_data.get("sha256") == "", "Saved sha256 should be empty" + assert saved_data.get("hash_status") == "pending", "Saved hash_status should be 'pending'" + + +@pytest.mark.asyncio +async def test_calculate_hash_for_model_completes_pending(tmp_path: Path, monkeypatch): + """Test that calculate_hash_for_model updates status to 'completed'.""" + checkpoints_root = tmp_path / "checkpoints" + checkpoints_root.mkdir() + + checkpoint_file = checkpoints_root / "test_model.safetensors" + checkpoint_file.write_text("fake content for hashing", encoding="utf-8") + + 
normalized_root = _normalize(checkpoints_root) + normalized_file = _normalize(checkpoint_file) + + monkeypatch.setattr( + model_scanner.config, + "base_models_roots", + [normalized_root], + raising=False, + ) + + scanner = CheckpointScanner() + + # Create pending metadata + metadata = await scanner._create_default_metadata(normalized_file) + assert metadata is not None + assert metadata.hash_status == "pending" + + # Calculate hash + hash_result = await scanner.calculate_hash_for_model(normalized_file) + + assert hash_result is not None, "Hash calculation should succeed" + assert len(hash_result) == 64, "SHA256 should be 64 hex characters" + + # Verify metadata was updated + metadata_file = checkpoints_root / "test_model.metadata.json" + with open(metadata_file, "r", encoding="utf-8") as f: + saved_data = json.load(f) + + assert saved_data.get("sha256") == hash_result, "sha256 should be updated" + assert saved_data.get("hash_status") == "completed", "hash_status should be 'completed'" + + +@pytest.mark.asyncio +async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeypatch): + """Test that calculate_hash_for_model skips calculation if already completed.""" + checkpoints_root = tmp_path / "checkpoints" + checkpoints_root.mkdir() + + checkpoint_file = checkpoints_root / "test_model.safetensors" + checkpoint_file.write_text("fake content", encoding="utf-8") + + normalized_root = _normalize(checkpoints_root) + normalized_file = _normalize(checkpoint_file) + + monkeypatch.setattr( + model_scanner.config, + "base_models_roots", + [normalized_root], + raising=False, + ) + + scanner = CheckpointScanner() + + # Create metadata with completed hash + metadata = CheckpointMetadata( + file_name="test_model", + model_name="test_model", + file_path=normalized_file, + size=100, + modified=1234567890.0, + sha256="existing_hash_value", + base_model="Unknown", + preview_url="", + hash_status="completed", + from_civitai=False, + ) + + # Save metadata first + from 
py.utils.metadata_manager import MetadataManager + await MetadataManager.save_metadata(normalized_file, metadata) + + # Calculate hash should return existing value + with patch("py.utils.file_utils.calculate_sha256", new_callable=AsyncMock) as mock_calc: + mock_calc.return_value = "new_calculated_hash" + hash_result = await scanner.calculate_hash_for_model(normalized_file) + + assert hash_result == "existing_hash_value", "Should return existing hash" + mock_calc.assert_not_called() # Should not recalculate if already completed + + +@pytest.mark.asyncio +async def test_calculate_all_pending_hashes(tmp_path: Path, monkeypatch): + """Test bulk hash calculation for all pending checkpoints.""" + checkpoints_root = tmp_path / "checkpoints" + checkpoints_root.mkdir() + + # Create multiple checkpoint files + for i in range(3): + checkpoint_file = checkpoints_root / f"model_{i}.safetensors" + checkpoint_file.write_text(f"content {i}", encoding="utf-8") + + normalized_root = _normalize(checkpoints_root) + + monkeypatch.setattr( + model_scanner.config, + "base_models_roots", + [normalized_root], + raising=False, + ) + + scanner = CheckpointScanner() + + # Create pending metadata for all models + for i in range(3): + checkpoint_file = checkpoints_root / f"model_{i}.safetensors" + await scanner._create_default_metadata(_normalize(checkpoint_file)) + + # Mock progress callback + progress_calls = [] + async def progress_callback(current, total, file_path): + progress_calls.append((current, total, file_path)) + + # Calculate all pending hashes + result = await scanner.calculate_all_pending_hashes(progress_callback) + + assert result["total"] == 3, "Should find 3 pending models" + assert result["completed"] == 3, "Should complete all 3" + assert result["failed"] == 0, "Should not fail any" + assert len(progress_calls) == 3, "Progress callback should be called 3 times" + + +@pytest.mark.asyncio +async def test_lora_scanner_not_affected(tmp_path: Path, monkeypatch): + """Test that LoraScanner still 
calculates hash during initial scan.""" + from py.services.lora_scanner import LoraScanner + + loras_root = tmp_path / "loras" + loras_root.mkdir() + + lora_file = loras_root / "test_lora.safetensors" + lora_file.write_text("fake lora content", encoding="utf-8") + + normalized_root = _normalize(loras_root) + + monkeypatch.setattr( + model_scanner.config, + "loras_roots", + [normalized_root], + raising=False, + ) + + # Reset singleton for LoraScanner + if LoraScanner in ModelScanner._instances: + del ModelScanner._instances[LoraScanner] + + scanner = LoraScanner() + + # LoraScanner should use parent's _create_default_metadata which calculates hash + # We verify this by checking that it doesn't override the method + assert scanner._create_default_metadata.__qualname__ == "ModelScanner._create_default_metadata"