fix(cache): prevent corrupted cache rows from breaking model listings (#730)

Cache corruption (NULL model_name/file_name from legacy DB rows or partial
writes) caused format_response to raise KeyError/AttributeError, failing the
entire /loras/list request and showing no models in the UI.

Fix across four layers:
- format_response (lora/checkpoint/embedding): replace direct dict[] access
  with .get() fallbacks; return None for entries missing file_path
- handlers: filter None entries from list/excluded/fetch/duplicate/conflict
  endpoints instead of letting them crash or appear as null in responses
- model_scanner: always use validate_batch repaired copies (previously
  discarded when no invalid entries, leaving None values in raw_data)
- persistent_model_cache: add or-empty-string guards on read and write for
  nullable TEXT columns (model_name, file_name, folder, base_model, etc.)
This commit is contained in:
Will Miao
2026-06-30 09:02:42 +08:00
parent 28e7c04b37
commit 16f5222efd
9 changed files with 274 additions and 54 deletions

View File

@@ -203,11 +203,17 @@ class ModelListingHandler:
result = await self._service.get_paginated_data(**params) result = await self._service.get_paginated_data(**params)
format_start = time.perf_counter() format_start = time.perf_counter()
formatted_raw = [
await self._service.format_response(entry)
for entry in result["items"]
]
# Filter out None entries returned for corrupted cache rows (issue #730).
# Note: "total" intentionally remains the pre-filter count to reflect
# the true number of models in the cache; corrupted entries are rare
# and adjusting total would cause pagination drift on every page.
formatted_items = [item for item in formatted_raw if item is not None]
formatted_result = { formatted_result = {
"items": [ "items": formatted_items,
await self._service.format_response(item)
for item in result["items"]
],
"total": result["total"], "total": result["total"],
"page": result["page"], "page": result["page"],
"page_size": result["page_size"], "page_size": result["page_size"],
@@ -238,11 +244,15 @@ class ModelListingHandler:
result = await self._service.get_excluded_paginated_data(**params) result = await self._service.get_excluded_paginated_data(**params)
format_start = time.perf_counter() format_start = time.perf_counter()
formatted_raw = [
await self._service.format_response(entry)
for entry in result["items"]
]
# Filter out None entries returned for corrupted cache rows (issue #730).
# "total" stays at the pre-filter count; see get_models for rationale.
formatted_items = [item for item in formatted_raw if item is not None]
formatted_result = { formatted_result = {
"items": [ "items": formatted_items,
await self._service.format_response(item)
for item in result["items"]
],
"total": result["total"], "total": result["total"],
"page": result["page"], "page": result["page"],
"page_size": result["page_size"], "page_size": result["page_size"],
@@ -533,8 +543,13 @@ class ModelManagementHandler:
if not success: if not success:
return web.json_response({"success": False, "error": error}) return web.json_response({"success": False, "error": error})
formatted_metadata = await self._service.format_response(model_data) formatted = await self._service.format_response(model_data)
return web.json_response({"success": True, "metadata": formatted_metadata}) if formatted is None:
return web.json_response(
{"success": False, "error": "Model entry is corrupted (missing file_path)"},
status=500,
)
return web.json_response({"success": True, "metadata": formatted})
except Exception as exc: except Exception as exc:
if is_expected_offline_error(str(exc)): if is_expected_offline_error(str(exc)):
return web.json_response( return web.json_response(
@@ -1091,10 +1106,12 @@ class ModelQueryHandler:
# Sort: originals first, copies last # Sort: originals first, copies last
sorted_models = self._sort_duplicate_group(filtered) sorted_models = self._sort_duplicate_group(filtered)
# Format response # Format response, filtering out corrupted entries (issue #730)
group = {"hash": sha256, "models": []} group = {"hash": sha256, "models": []}
for model in sorted_models: for model in sorted_models:
group["models"].append(await self._service.format_response(model)) formatted = await self._service.format_response(model)
if formatted is not None:
group["models"].append(formatted)
# Only include groups with 2+ models after filtering # Only include groups with 2+ models after filtering
if len(group["models"]) > 1: if len(group["models"]) > 1:
@@ -1211,9 +1228,9 @@ class ModelQueryHandler:
(m for m in cache.raw_data if m["file_path"] == path), None (m for m in cache.raw_data if m["file_path"] == path), None
) )
if model: if model:
group["models"].append( formatted = await self._service.format_response(model)
await self._service.format_response(model) if formatted is not None:
) group["models"].append(formatted)
hash_val = self._service.scanner.get_hash_by_filename(filename) hash_val = self._service.scanner.get_hash_by_filename(filename)
if hash_val: if hash_val:
main_path = self._service.get_path_by_hash(hash_val) main_path = self._service.get_path_by_hash(hash_val)
@@ -1223,9 +1240,9 @@ class ModelQueryHandler:
None, None,
) )
if main_model: if main_model:
group["models"].insert( formatted = await self._service.format_response(main_model)
0, await self._service.format_response(main_model) if formatted is not None:
) group["models"].insert(0, formatted)
if group["models"]: if group["models"]:
result.append(group) result.append(group)
return web.json_response( return web.json_response(

View File

@@ -791,8 +791,12 @@ class BaseModelService(ABC):
} }
@abstractmethod @abstractmethod
async def format_response(self, model_data: Dict) -> Dict: async def format_response(self, model_data: Dict) -> Optional[Dict]:
"""Format model data for API response - must be implemented by subclasses""" """Format model data for API response - must be implemented by subclasses.
Subclasses should return None for corrupted entries so the handler
layer can filter them out. See issue #730.
"""
pass pass
# Common service methods that delegate to scanner # Common service methods that delegate to scanner

View File

@@ -1,6 +1,6 @@
import os import os
import logging import logging
from typing import Dict from typing import Dict, Optional
from .base_model_service import BaseModelService from .base_model_service import BaseModelService
from .auto_tag_service import extract_auto_tags from .auto_tag_service import extract_auto_tags
@@ -21,20 +21,37 @@ class CheckpointService(BaseModelService):
""" """
super().__init__("checkpoint", scanner, CheckpointMetadata, update_service=update_service) super().__init__("checkpoint", scanner, CheckpointMetadata, update_service=update_service)
async def format_response(self, checkpoint_data: Dict) -> Dict: async def format_response(self, checkpoint_data: Dict) -> Optional[Dict]:
"""Format Checkpoint data for API response""" """Format Checkpoint data for API response.
Returns None when the entry is missing critical fields (corrupted cache
row), so the handler layer can filter it out. See issue #730.
"""
# Guard against corrupted cache entries missing critical fields
file_path = checkpoint_data.get("file_path")
if not file_path or not isinstance(file_path, str):
logger.warning(
"Skipping corrupted checkpoint entry (missing file_path): %s",
checkpoint_data.get("file_name", "<unknown>"),
)
return None
# Get sub_type from cache entry (new canonical field) # Get sub_type from cache entry (new canonical field)
sub_type = checkpoint_data.get("sub_type", "checkpoint") sub_type = checkpoint_data.get("sub_type", "checkpoint")
file_name = checkpoint_data.get("file_name") or ""
model_name = checkpoint_data.get("model_name") or file_name
folder = checkpoint_data.get("folder") or ""
return { return {
"model_name": checkpoint_data["model_name"], "model_name": model_name,
"file_name": checkpoint_data["file_name"], "file_name": file_name,
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")), "preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0), "preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
"base_model": checkpoint_data.get("base_model", ""), "base_model": checkpoint_data.get("base_model", ""),
"folder": checkpoint_data["folder"], "folder": folder,
"sha256": checkpoint_data.get("sha256", ""), "sha256": checkpoint_data.get("sha256", ""),
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"), "file_path": file_path.replace(os.sep, "/"),
"file_size": checkpoint_data.get("size", 0), "file_size": checkpoint_data.get("size", 0),
"modified": checkpoint_data.get("modified", ""), "modified": checkpoint_data.get("modified", ""),
"tags": checkpoint_data.get("tags", []), "tags": checkpoint_data.get("tags", []),

View File

@@ -1,6 +1,6 @@
import os import os
import logging import logging
from typing import Dict from typing import Dict, Optional
from .base_model_service import BaseModelService from .base_model_service import BaseModelService
from .auto_tag_service import extract_auto_tags from .auto_tag_service import extract_auto_tags
@@ -21,20 +21,37 @@ class EmbeddingService(BaseModelService):
""" """
super().__init__("embedding", scanner, EmbeddingMetadata, update_service=update_service) super().__init__("embedding", scanner, EmbeddingMetadata, update_service=update_service)
async def format_response(self, embedding_data: Dict) -> Dict: async def format_response(self, embedding_data: Dict) -> Optional[Dict]:
"""Format Embedding data for API response""" """Format Embedding data for API response.
Returns None when the entry is missing critical fields (corrupted cache
row), so the handler layer can filter it out. See issue #730.
"""
# Guard against corrupted cache entries missing critical fields
file_path = embedding_data.get("file_path")
if not file_path or not isinstance(file_path, str):
logger.warning(
"Skipping corrupted embedding entry (missing file_path): %s",
embedding_data.get("file_name", "<unknown>"),
)
return None
# Get sub_type from cache entry (new canonical field) # Get sub_type from cache entry (new canonical field)
sub_type = embedding_data.get("sub_type", "embedding") sub_type = embedding_data.get("sub_type", "embedding")
file_name = embedding_data.get("file_name") or ""
model_name = embedding_data.get("model_name") or file_name
folder = embedding_data.get("folder") or ""
return { return {
"model_name": embedding_data["model_name"], "model_name": model_name,
"file_name": embedding_data["file_name"], "file_name": file_name,
"preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")), "preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")),
"preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0), "preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0),
"base_model": embedding_data.get("base_model", ""), "base_model": embedding_data.get("base_model", ""),
"folder": embedding_data["folder"], "folder": folder,
"sha256": embedding_data.get("sha256", ""), "sha256": embedding_data.get("sha256", ""),
"file_path": embedding_data["file_path"].replace(os.sep, "/"), "file_path": file_path.replace(os.sep, "/"),
"file_size": embedding_data.get("size", 0), "file_size": embedding_data.get("size", 0),
"modified": embedding_data.get("modified", ""), "modified": embedding_data.get("modified", ""),
"tags": embedding_data.get("tags", []), "tags": embedding_data.get("tags", []),

View File

@@ -24,23 +24,41 @@ class LoraService(BaseModelService):
""" """
super().__init__("lora", scanner, LoraMetadata, update_service=update_service) super().__init__("lora", scanner, LoraMetadata, update_service=update_service)
async def format_response(self, lora_data: Dict) -> Dict: async def format_response(self, lora_data: Dict) -> Optional[Dict]:
"""Format LoRA data for API response""" """Format LoRA data for API response.
Returns None when the entry is missing critical fields (corrupted cache
row), so the handler layer can filter it out instead of crashing the
whole listing request. See issue #730.
"""
# Guard against corrupted cache entries missing critical fields
file_path = lora_data.get("file_path")
if not file_path or not isinstance(file_path, str):
logger.warning(
"Skipping corrupted LoRA entry (missing file_path): %s",
lora_data.get("file_name", "<unknown>"),
)
return None
# Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default # Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default
# Normalize to lowercase for consistent API responses # Normalize to lowercase for consistent API responses
sub_type = resolve_sub_type(lora_data).lower() sub_type = resolve_sub_type(lora_data).lower()
file_name = lora_data.get("file_name") or ""
model_name = lora_data.get("model_name") or file_name
folder = lora_data.get("folder") or ""
return { return {
"model_name": lora_data["model_name"], "model_name": model_name,
"file_name": lora_data["file_name"], "file_name": file_name,
"preview_url": config.get_preview_static_url( "preview_url": config.get_preview_static_url(
lora_data.get("preview_url", "") lora_data.get("preview_url", "")
), ),
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0), "preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
"base_model": lora_data.get("base_model", ""), "base_model": lora_data.get("base_model", ""),
"folder": lora_data["folder"], "folder": folder,
"sha256": lora_data.get("sha256", ""), "sha256": lora_data.get("sha256", ""),
"file_path": lora_data["file_path"].replace(os.sep, "/"), "file_path": file_path.replace(os.sep, "/"),
"file_size": lora_data.get("size", 0), "file_size": lora_data.get("size", 0),
"modified": lora_data.get("modified", ""), "modified": lora_data.get("modified", ""),
"tags": lora_data.get("tags", []), "tags": lora_data.get("tags", []),

View File

@@ -476,11 +476,20 @@ class ModelScanner:
for tag in adjusted_item.get('tags') or []: for tag in adjusted_item.get('tags') or []:
tags_count[tag] = tags_count.get(tag, 0) + 1 tags_count[tag] = tags_count.get(tag, 0) + 1
# Validate cache entries and check health # Validate cache entries and check health.
# Always use the validated/repaired entries — even when there are no
# invalid entries, auto_repair may have filled in missing optional
# fields (model_name, file_name, folder) with safe defaults on a copied
# working_entry. Without this unconditional replacement the repaired
# copies are discarded and None values propagate to format_response.
# See issue #730.
valid_entries, invalid_entries = CacheEntryValidator.validate_batch( valid_entries, invalid_entries = CacheEntryValidator.validate_batch(
adjusted_raw_data, auto_repair=True adjusted_raw_data, auto_repair=True
) )
# Always use the validated entries (repaired copies)
adjusted_raw_data = valid_entries
if invalid_entries: if invalid_entries:
monitor = CacheHealthMonitor() monitor = CacheHealthMonitor()
report = monitor.check_health(adjusted_raw_data, auto_repair=True) report = monitor.check_health(adjusted_raw_data, auto_repair=True)

View File

@@ -165,8 +165,8 @@ class PersistentModelCache:
item = { item = {
"file_path": file_path, "file_path": file_path,
"file_name": row["file_name"], "file_name": row["file_name"] or "",
"model_name": row["model_name"], "model_name": row["model_name"] or "",
"folder": row["folder"] or "", "folder": row["folder"] or "",
"size": row["size"] or 0, "size": row["size"] or 0,
"modified": row["modified"] or 0.0, "modified": row["modified"] or 0.0,
@@ -548,19 +548,19 @@ class PersistentModelCache:
return ( return (
model_type, model_type,
item.get("file_path"), item.get("file_path"),
item.get("file_name"), item.get("file_name") or "",
item.get("model_name"), item.get("model_name") or "",
item.get("folder"), item.get("folder") or "",
int(item.get("size") or 0), int(item.get("size") or 0),
float(item.get("modified") or 0.0), float(item.get("modified") or 0.0),
(item.get("sha256") or "").lower() or None, (item.get("sha256") or "").lower() or None,
item.get("base_model"), item.get("base_model") or "",
item.get("preview_url"), item.get("preview_url") or "",
int(item.get("preview_nsfw_level") or 0), int(item.get("preview_nsfw_level") or 0),
1 if item.get("from_civitai", True) else 0, 1 if item.get("from_civitai", True) else 0,
1 if item.get("favorite") else 0, 1 if item.get("favorite") else 0,
item.get("notes"), item.get("notes") or "",
item.get("usage_tips"), item.get("usage_tips") or "",
metadata_source, metadata_source,
civitai.get("id"), civitai.get("id"),
civitai.get("modelId"), civitai.get("modelId"),

View File

@@ -201,6 +201,45 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
asyncio.run(scenario()) asyncio.run(scenario())
def test_list_models_filters_out_corrupted_entries(mock_service, mock_scanner):
    """Listing must drop entries that format_response rejects (issue #730).

    When format_response yields None for a corrupted cache row, that row must
    be absent from the response items and the request must still answer 200.
    """
    mock_service.paginated_items = [
        {"file_path": "/tmp/good.safetensors", "name": "Good"},
        # file_path=None makes the patched formatter below return None
        {"file_path": None, "name": "Corrupted"},
        {"file_path": "/tmp/also_good.safetensors", "name": "AlsoGood"},
    ]

    # Wrap the stock formatter so rows without a file_path come back as None.
    stock_format = mock_service.format_response

    async def format_or_none(entry):
        if entry.get("file_path") is not None:
            return await stock_format(entry)
        return None

    mock_service.format_response = format_or_none

    async def run_request():
        client = await create_test_client(mock_service)
        try:
            response = await client.get("/api/lm/test-models/list")
            payload = await response.json()

            assert response.status == 200
            items = payload["items"]
            # Exactly the two healthy entries survive, in their original order
            assert len(items) == 2
            assert items[0]["name"] == "Good"
            assert items[1]["name"] == "AlsoGood"
            # No null placeholder may be left behind for the dropped row
            assert None not in items
        finally:
            await client.close()

    asyncio.run(run_request())
def test_model_types_endpoint_returns_counts(mock_service, mock_scanner): def test_model_types_endpoint_returns_counts(mock_service, mock_scanner):
mock_service.model_types = [ mock_service.model_types = [
{"type": "LoRa", "count": 3}, {"type": "LoRa", "count": 3},

View File

@@ -204,3 +204,102 @@ class TestEmbeddingServiceFormatResponse:
assert result["sub_type"] == "embedding" assert result["sub_type"] == "embedding"
assert "model_type" not in result # Removed in refactoring assert "model_type" not in result # Removed in refactoring
class TestFormatResponseCorruptedEntries:
    """Verify format_response tolerates corrupted cache entries (issue #730).

    A partially written or legacy DB row can carry None/missing values in
    critical fields. format_response must not raise KeyError/AttributeError
    for such rows; it returns None instead, which lets the handler layer drop
    the bad entry rather than fail the whole listing request.
    """

    @pytest.fixture
    def mock_scanner(self):
        stub = MagicMock()
        stub._hash_index = MagicMock()
        return stub

    @pytest.fixture
    def lora_service(self, mock_scanner):
        return LoraService(mock_scanner)

    @pytest.fixture
    def checkpoint_service(self, mock_scanner):
        return CheckpointService(mock_scanner)

    @pytest.fixture
    def embedding_service(self, mock_scanner):
        return EmbeddingService(mock_scanner)

    @pytest.mark.asyncio
    async def test_lora_returns_none_on_missing_file_path(self, lora_service):
        """A LoRA row whose file_path is None is reported as None."""
        corrupted_row = {
            "model_name": "Test LoRA",
            "file_name": "test_lora",
            "file_path": None,  # the corruption under test
            "folder": "",
            "sha256": "abc123",
            "tags": [],
            "from_civitai": True,
            "civitai": {},
        }

        assert await lora_service.format_response(corrupted_row) is None

    @pytest.mark.asyncio
    async def test_lora_handles_none_model_name_gracefully(self, lora_service):
        """A None model_name (legacy DB row) must not crash format_response."""
        legacy_row = {
            "model_name": None,  # NULL as read from an old DB row
            "file_name": "test_lora",
            "file_path": "/models/test_lora.safetensors",
            "folder": "",
            "sha256": "abc123",
            "tags": [],
            "from_civitai": True,
            "civitai": {},
        }

        formatted = await lora_service.format_response(legacy_row)

        # No exception, and the display name is taken from file_name instead
        assert formatted is not None
        assert formatted["model_name"] == "test_lora"

    @pytest.mark.asyncio
    async def test_checkpoint_returns_none_on_missing_file_path(self, checkpoint_service):
        """A checkpoint row with an empty file_path is reported as None."""
        corrupted_row = {
            "model_name": "Test",
            "file_name": "test",
            "file_path": "",  # an empty string counts as corrupted too
            "folder": "",
            "sha256": "abc",
            "tags": [],
            "from_civitai": True,
            "civitai": {},
            "sub_type": "checkpoint",
        }

        assert await checkpoint_service.format_response(corrupted_row) is None

    @pytest.mark.asyncio
    async def test_embedding_handles_none_fields_gracefully(self, embedding_service):
        """None in the optional fields must not crash format_response."""
        sparse_row = {
            "model_name": None,
            "file_name": None,
            "file_path": "/models/test.pt",
            "folder": None,
            "sha256": "abc",
            "tags": [],
            "from_civitai": True,
            "civitai": {},
            "sub_type": "embedding",
        }

        formatted = await embedding_service.format_response(sparse_row)

        assert formatted is not None
        assert formatted["file_path"] == "/models/test.pt"
        # model_name -> file_name -> "" is the fallback chain being checked
        assert formatted["model_name"] == ""