diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 8e21415d..54f4f45f 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -203,11 +203,17 @@ class ModelListingHandler: result = await self._service.get_paginated_data(**params) format_start = time.perf_counter() + formatted_raw = [ + await self._service.format_response(entry) + for entry in result["items"] + ] + # Filter out None entries returned for corrupted cache rows (issue #730). + # Note: "total" intentionally remains the pre-filter count to reflect + # the true number of models in the cache; corrupted entries are rare + # and adjusting total would cause pagination drift on every page. + formatted_items = [item for item in formatted_raw if item is not None] formatted_result = { - "items": [ - await self._service.format_response(item) - for item in result["items"] - ], + "items": formatted_items, "total": result["total"], "page": result["page"], "page_size": result["page_size"], @@ -238,11 +244,15 @@ class ModelListingHandler: result = await self._service.get_excluded_paginated_data(**params) format_start = time.perf_counter() + formatted_raw = [ + await self._service.format_response(entry) + for entry in result["items"] + ] + # Filter out None entries returned for corrupted cache rows (issue #730). + # "total" stays at the pre-filter count; see get_models for rationale. + formatted_items = [item for item in formatted_raw if item is not None] formatted_result = { - "items": [ - await self._service.format_response(item) - for item in result["items"] - ], + "items": formatted_items, "total": result["total"], "page": result["page"], "page_size": result["page_size"], @@ -533,8 +543,13 @@ class ModelManagementHandler: if not success: return web.json_response({"success": False, "error": error}) - formatted_metadata = await self._service.format_response(model_data) - return web.json_response({"success": True, "metadata": formatted_metadata}) + formatted = await self._service.format_response(model_data) + if formatted is None: + return web.json_response( + {"success": False, "error": "Model entry is corrupted (missing file_path)"}, + status=500, + ) + return web.json_response({"success": True, "metadata": formatted}) except Exception as exc: if is_expected_offline_error(str(exc)): return web.json_response( @@ -1091,10 +1106,12 @@ class ModelQueryHandler: # Sort: originals first, copies last sorted_models = self._sort_duplicate_group(filtered) - # Format response + # Format response, filtering out corrupted entries (issue #730) group = {"hash": sha256, "models": []} for model in sorted_models: - group["models"].append(await self._service.format_response(model)) + formatted = await self._service.format_response(model) + if formatted is not None: + group["models"].append(formatted) # Only include groups with 2+ models after filtering if len(group["models"]) > 1: @@ -1211,9 +1228,9 @@ class ModelQueryHandler: (m for m in cache.raw_data if m["file_path"] == path), None ) if model: - group["models"].append( - await self._service.format_response(model) - ) + formatted = await self._service.format_response(model) + if formatted is not None: + group["models"].append(formatted) hash_val = self._service.scanner.get_hash_by_filename(filename) if hash_val: main_path = self._service.get_path_by_hash(hash_val) @@ -1223,9 +1240,9 @@ class ModelQueryHandler: None, ) if main_model: - group["models"].insert( - 0, await self._service.format_response(main_model) - ) + formatted = await self._service.format_response(main_model) + if formatted is not None: + group["models"].insert(0, formatted) if group["models"]: result.append(group) return web.json_response( diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 9b91a714..58361448 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -791,8 +791,12 @@ class BaseModelService(ABC): } @abstractmethod - async def format_response(self, model_data: Dict) -> Dict: - """Format model data for API response - must be implemented by subclasses""" + async def format_response(self, model_data: Dict) -> Optional[Dict]: + """Format model data for API response - must be implemented by subclasses. + + Subclasses should return None for corrupted entries so the handler + layer can filter them out. See issue #730. + """ pass # Common service methods that delegate to scanner diff --git a/py/services/checkpoint_service.py b/py/services/checkpoint_service.py index 82cf8aaa..4c171669 100644 --- a/py/services/checkpoint_service.py +++ b/py/services/checkpoint_service.py @@ -1,6 +1,6 @@ import os import logging -from typing import Dict +from typing import Dict, Optional from .base_model_service import BaseModelService from .auto_tag_service import extract_auto_tags @@ -21,20 +21,37 @@ class CheckpointService(BaseModelService): """ super().__init__("checkpoint", scanner, CheckpointMetadata, update_service=update_service) - async def format_response(self, checkpoint_data: Dict) -> Dict: - """Format Checkpoint data for API response""" + async def format_response(self, checkpoint_data: Dict) -> Optional[Dict]: + """Format Checkpoint data for API response. + + Returns None when the entry is missing critical fields (corrupted cache + row), so the handler layer can filter it out. See issue #730. + """ + # Guard against corrupted cache entries missing critical fields + file_path = checkpoint_data.get("file_path") + if not file_path or not isinstance(file_path, str): + logger.warning( + "Skipping corrupted checkpoint entry (missing file_path): %s", + checkpoint_data.get("file_name", ""), + ) + return None + # Get sub_type from cache entry (new canonical field) sub_type = checkpoint_data.get("sub_type", "checkpoint") - + + file_name = checkpoint_data.get("file_name") or "" + model_name = checkpoint_data.get("model_name") or file_name + folder = checkpoint_data.get("folder") or "" + return { - "model_name": checkpoint_data["model_name"], - "file_name": checkpoint_data["file_name"], + "model_name": model_name, + "file_name": file_name, "preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")), "preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0), "base_model": checkpoint_data.get("base_model", ""), - "folder": checkpoint_data["folder"], + "folder": folder, "sha256": checkpoint_data.get("sha256", ""), - "file_path": checkpoint_data["file_path"].replace(os.sep, "/"), + "file_path": file_path.replace(os.sep, "/"), "file_size": checkpoint_data.get("size", 0), "modified": checkpoint_data.get("modified", ""), "tags": checkpoint_data.get("tags", []), diff --git a/py/services/embedding_service.py b/py/services/embedding_service.py index 779cbee3..85666f17 100644 --- a/py/services/embedding_service.py +++ b/py/services/embedding_service.py @@ -1,6 +1,6 @@ import os import logging -from typing import Dict +from typing import Dict, Optional from .base_model_service import BaseModelService from .auto_tag_service import extract_auto_tags @@ -21,20 +21,37 @@ class EmbeddingService(BaseModelService): """ super().__init__("embedding", scanner, EmbeddingMetadata, update_service=update_service) - async def format_response(self, embedding_data: Dict) -> Dict: - """Format Embedding data for API response""" + async def format_response(self, embedding_data: Dict) -> Optional[Dict]: + """Format Embedding data for API response. + + Returns None when the entry is missing critical fields (corrupted cache + row), so the handler layer can filter it out. See issue #730. + """ + # Guard against corrupted cache entries missing critical fields + file_path = embedding_data.get("file_path") + if not file_path or not isinstance(file_path, str): + logger.warning( + "Skipping corrupted embedding entry (missing file_path): %s", + embedding_data.get("file_name", ""), + ) + return None + # Get sub_type from cache entry (new canonical field) sub_type = embedding_data.get("sub_type", "embedding") - + + file_name = embedding_data.get("file_name") or "" + model_name = embedding_data.get("model_name") or file_name + folder = embedding_data.get("folder") or "" + return { - "model_name": embedding_data["model_name"], - "file_name": embedding_data["file_name"], + "model_name": model_name, + "file_name": file_name, "preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")), "preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0), "base_model": embedding_data.get("base_model", ""), - "folder": embedding_data["folder"], + "folder": folder, "sha256": embedding_data.get("sha256", ""), - "file_path": embedding_data["file_path"].replace(os.sep, "/"), + "file_path": file_path.replace(os.sep, "/"), "file_size": embedding_data.get("size", 0), "modified": embedding_data.get("modified", ""), "tags": embedding_data.get("tags", []), diff --git a/py/services/lora_service.py b/py/services/lora_service.py index 75c21bf7..7d99d245 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -24,23 +24,41 @@ class LoraService(BaseModelService): """ super().__init__("lora", scanner, LoraMetadata, update_service=update_service) - async def format_response(self, lora_data: Dict) -> Dict: - """Format LoRA data for API response""" + async def format_response(self, lora_data: Dict) -> Optional[Dict]: + """Format LoRA data for API response. + + Returns None when the entry is missing critical fields (corrupted cache + row), so the handler layer can filter it out instead of crashing the + whole listing request. See issue #730. + """ + # Guard against corrupted cache entries missing critical fields + file_path = lora_data.get("file_path") + if not file_path or not isinstance(file_path, str): + logger.warning( + "Skipping corrupted LoRA entry (missing file_path): %s", + lora_data.get("file_name", ""), + ) + return None + # Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default # Normalize to lowercase for consistent API responses sub_type = resolve_sub_type(lora_data).lower() + file_name = lora_data.get("file_name") or "" + model_name = lora_data.get("model_name") or file_name + folder = lora_data.get("folder") or "" + return { - "model_name": lora_data["model_name"], - "file_name": lora_data["file_name"], + "model_name": model_name, + "file_name": file_name, "preview_url": config.get_preview_static_url( lora_data.get("preview_url", "") ), "preview_nsfw_level": lora_data.get("preview_nsfw_level", 0), "base_model": lora_data.get("base_model", ""), - "folder": lora_data["folder"], + "folder": folder, "sha256": lora_data.get("sha256", ""), - "file_path": lora_data["file_path"].replace(os.sep, "/"), + "file_path": file_path.replace(os.sep, "/"), "file_size": lora_data.get("size", 0), "modified": lora_data.get("modified", ""), "tags": lora_data.get("tags", []), diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 5380ae08..576da54b 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -476,11 +476,20 @@ class ModelScanner: for tag in adjusted_item.get('tags') or []: tags_count[tag] = tags_count.get(tag, 0) + 1 - # Validate cache entries and check health + # Validate cache entries and check health. + # Always use the validated/repaired entries — even when there are no + # invalid entries, auto_repair may have filled in missing optional + # fields (model_name, file_name, folder) with safe defaults on a copied + # working_entry. Without this unconditional replacement the repaired + # copies are discarded and None values propagate to format_response. + # See issue #730. valid_entries, invalid_entries = CacheEntryValidator.validate_batch( adjusted_raw_data, auto_repair=True ) + # Always use the validated entries (repaired copies) + adjusted_raw_data = valid_entries + if invalid_entries: monitor = CacheHealthMonitor() report = monitor.check_health(adjusted_raw_data, auto_repair=True) diff --git a/py/services/persistent_model_cache.py b/py/services/persistent_model_cache.py index e2cc0dd0..0f4e84f0 100644 --- a/py/services/persistent_model_cache.py +++ b/py/services/persistent_model_cache.py @@ -165,8 +165,8 @@ class PersistentModelCache: item = { "file_path": file_path, - "file_name": row["file_name"], - "model_name": row["model_name"], + "file_name": row["file_name"] or "", + "model_name": row["model_name"] or "", "folder": row["folder"] or "", "size": row["size"] or 0, "modified": row["modified"] or 0.0, @@ -548,19 +548,19 @@ class PersistentModelCache: return ( model_type, item.get("file_path"), - item.get("file_name"), - item.get("model_name"), - item.get("folder"), + item.get("file_name") or "", + item.get("model_name") or "", + item.get("folder") or "", int(item.get("size") or 0), float(item.get("modified") or 0.0), (item.get("sha256") or "").lower() or None, - item.get("base_model"), - item.get("preview_url"), + item.get("base_model") or "", + item.get("preview_url") or "", int(item.get("preview_nsfw_level") or 0), 1 if item.get("from_civitai", True) else 0, 1 if item.get("favorite") else 0, - item.get("notes"), - item.get("usage_tips"), + item.get("notes") or "", + item.get("usage_tips") or "", metadata_source, civitai.get("id"), civitai.get("modelId"), diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index 5a41e502..f580d855 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -201,6 +201,45 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner): asyncio.run(scenario()) +def test_list_models_filters_out_corrupted_entries(mock_service, mock_scanner): + """Corrupted cache entries (format_response returns None) must not appear + in the response items nor cause a 500. See issue #730. + """ + mock_service.paginated_items = [ + {"file_path": "/tmp/good.safetensors", "name": "Good"}, + {"file_path": None, "name": "Corrupted"}, # triggers None from format_response + {"file_path": "/tmp/also_good.safetensors", "name": "AlsoGood"}, + ] + + # Override format_response to return None for corrupted entries + original_format = mock_service.format_response + + async def conditional_format(item): + if item.get("file_path") is None: + return None + return await original_format(item) + + mock_service.format_response = conditional_format + + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.get("/api/lm/test-models/list") + payload = await response.json() + + assert response.status == 200 + # Only the 2 non-corrupted entries should appear + assert len(payload["items"]) == 2 + assert payload["items"][0]["name"] == "Good" + assert payload["items"][1]["name"] == "AlsoGood" + # None should never appear in the items list + assert None not in payload["items"] + finally: + await client.close() + + asyncio.run(scenario()) + + def test_model_types_endpoint_returns_counts(mock_service, mock_scanner): mock_service.model_types = [ {"type": "LoRa", "count": 3}, diff --git a/tests/services/test_service_format_response_sub_type.py b/tests/services/test_service_format_response_sub_type.py index 89c139e6..b1f3e192 100644 --- a/tests/services/test_service_format_response_sub_type.py +++ b/tests/services/test_service_format_response_sub_type.py @@ -199,8 +199,107 @@ class TestEmbeddingServiceFormatResponse: "from_civitai": True, "civitai": {}, } - + result = await embedding_service.format_response(embedding_data) - + assert result["sub_type"] == "embedding" assert "model_type" not in result # Removed in refactoring + + +class TestFormatResponseCorruptedEntries: + """Test format_response handles corrupted cache entries gracefully (issue #730). + + When cache rows have None/missing critical fields (e.g. from a partially + written or legacy DB), format_response must NOT raise KeyError/AttributeError. + Instead it returns None so the handler layer can filter the bad entry out + instead of failing the entire listing request. + """ + + @pytest.fixture + def mock_scanner(self): + scanner = MagicMock() + scanner._hash_index = MagicMock() + return scanner + + @pytest.fixture + def lora_service(self, mock_scanner): + return LoraService(mock_scanner) + + @pytest.fixture + def checkpoint_service(self, mock_scanner): + return CheckpointService(mock_scanner) + + @pytest.fixture + def embedding_service(self, mock_scanner): + return EmbeddingService(mock_scanner) + + @pytest.mark.asyncio + async def test_lora_returns_none_on_missing_file_path(self, lora_service): + """format_response returns None when file_path is missing (corrupted row).""" + lora_data = { + "model_name": "Test LoRA", + "file_name": "test_lora", + "file_path": None, # corrupted: missing file_path + "folder": "", + "sha256": "abc123", + "tags": [], + "from_civitai": True, + "civitai": {}, + } + result = await lora_service.format_response(lora_data) + assert result is None + + @pytest.mark.asyncio + async def test_lora_handles_none_model_name_gracefully(self, lora_service): + """format_response should not crash when model_name is None (legacy DB row).""" + lora_data = { + "model_name": None, # NULL from old DB row + "file_name": "test_lora", + "file_path": "/models/test_lora.safetensors", + "folder": "", + "sha256": "abc123", + "tags": [], + "from_civitai": True, + "civitai": {}, + } + result = await lora_service.format_response(lora_data) + # Should not raise; model_name falls back to file_name + assert result is not None + assert result["model_name"] == "test_lora" + + @pytest.mark.asyncio + async def test_checkpoint_returns_none_on_missing_file_path(self, checkpoint_service): + """format_response returns None when file_path is missing (corrupted row).""" + checkpoint_data = { + "model_name": "Test", + "file_name": "test", + "file_path": "", # empty string == corrupted + "folder": "", + "sha256": "abc", + "tags": [], + "from_civitai": True, + "civitai": {}, + "sub_type": "checkpoint", + } + result = await checkpoint_service.format_response(checkpoint_data) + assert result is None + + @pytest.mark.asyncio + async def test_embedding_handles_none_fields_gracefully(self, embedding_service): + """format_response should not crash when optional fields are None.""" + embedding_data = { + "model_name": None, + "file_name": None, + "file_path": "/models/test.pt", + "folder": None, + "sha256": "abc", + "tags": [], + "from_civitai": True, + "civitai": {}, + "sub_type": "embedding", + } + result = await embedding_service.format_response(embedding_data) + assert result is not None + assert result["file_path"] == "/models/test.pt" + # model_name falls back to file_name which falls back to "" + assert result["model_name"] == ""