mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-07-02 23:41:16 -03:00
fix(cache): prevent corrupted cache rows from breaking model listings (#730)
Cache corruption (NULL model_name/file_name from legacy DB rows or partial writes) caused format_response to raise KeyError/AttributeError, failing the entire /loras/list request and showing no models in the UI. Fix across three layers: - format_response (lora/checkpoint/embedding): replace direct dict[] access with .get() fallbacks; return None for entries missing file_path - handlers: filter None entries from list/excluded/fetch/duplicate/conflict endpoints instead of letting them crash or appear as null in responses - model_scanner: always use validate_batch repaired copies (previously discarded when no invalid entries, leaving None values in raw_data) - persistent_model_cache: add or-empty-string guards on read and write for nullable TEXT columns (model_name, file_name, folder, base_model, etc.)
This commit is contained in:
@@ -203,11 +203,17 @@ class ModelListingHandler:
|
|||||||
result = await self._service.get_paginated_data(**params)
|
result = await self._service.get_paginated_data(**params)
|
||||||
|
|
||||||
format_start = time.perf_counter()
|
format_start = time.perf_counter()
|
||||||
|
formatted_raw = [
|
||||||
|
await self._service.format_response(entry)
|
||||||
|
for entry in result["items"]
|
||||||
|
]
|
||||||
|
# Filter out None entries returned for corrupted cache rows (issue #730).
|
||||||
|
# Note: "total" intentionally remains the pre-filter count to reflect
|
||||||
|
# the true number of models in the cache; corrupted entries are rare
|
||||||
|
# and adjusting total would cause pagination drift on every page.
|
||||||
|
formatted_items = [item for item in formatted_raw if item is not None]
|
||||||
formatted_result = {
|
formatted_result = {
|
||||||
"items": [
|
"items": formatted_items,
|
||||||
await self._service.format_response(item)
|
|
||||||
for item in result["items"]
|
|
||||||
],
|
|
||||||
"total": result["total"],
|
"total": result["total"],
|
||||||
"page": result["page"],
|
"page": result["page"],
|
||||||
"page_size": result["page_size"],
|
"page_size": result["page_size"],
|
||||||
@@ -238,11 +244,15 @@ class ModelListingHandler:
|
|||||||
result = await self._service.get_excluded_paginated_data(**params)
|
result = await self._service.get_excluded_paginated_data(**params)
|
||||||
|
|
||||||
format_start = time.perf_counter()
|
format_start = time.perf_counter()
|
||||||
|
formatted_raw = [
|
||||||
|
await self._service.format_response(entry)
|
||||||
|
for entry in result["items"]
|
||||||
|
]
|
||||||
|
# Filter out None entries returned for corrupted cache rows (issue #730).
|
||||||
|
# "total" stays at the pre-filter count; see get_models for rationale.
|
||||||
|
formatted_items = [item for item in formatted_raw if item is not None]
|
||||||
formatted_result = {
|
formatted_result = {
|
||||||
"items": [
|
"items": formatted_items,
|
||||||
await self._service.format_response(item)
|
|
||||||
for item in result["items"]
|
|
||||||
],
|
|
||||||
"total": result["total"],
|
"total": result["total"],
|
||||||
"page": result["page"],
|
"page": result["page"],
|
||||||
"page_size": result["page_size"],
|
"page_size": result["page_size"],
|
||||||
@@ -533,8 +543,13 @@ class ModelManagementHandler:
|
|||||||
if not success:
|
if not success:
|
||||||
return web.json_response({"success": False, "error": error})
|
return web.json_response({"success": False, "error": error})
|
||||||
|
|
||||||
formatted_metadata = await self._service.format_response(model_data)
|
formatted = await self._service.format_response(model_data)
|
||||||
return web.json_response({"success": True, "metadata": formatted_metadata})
|
if formatted is None:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Model entry is corrupted (missing file_path)"},
|
||||||
|
status=500,
|
||||||
|
)
|
||||||
|
return web.json_response({"success": True, "metadata": formatted})
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if is_expected_offline_error(str(exc)):
|
if is_expected_offline_error(str(exc)):
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
@@ -1091,10 +1106,12 @@ class ModelQueryHandler:
|
|||||||
# Sort: originals first, copies last
|
# Sort: originals first, copies last
|
||||||
sorted_models = self._sort_duplicate_group(filtered)
|
sorted_models = self._sort_duplicate_group(filtered)
|
||||||
|
|
||||||
# Format response
|
# Format response, filtering out corrupted entries (issue #730)
|
||||||
group = {"hash": sha256, "models": []}
|
group = {"hash": sha256, "models": []}
|
||||||
for model in sorted_models:
|
for model in sorted_models:
|
||||||
group["models"].append(await self._service.format_response(model))
|
formatted = await self._service.format_response(model)
|
||||||
|
if formatted is not None:
|
||||||
|
group["models"].append(formatted)
|
||||||
|
|
||||||
# Only include groups with 2+ models after filtering
|
# Only include groups with 2+ models after filtering
|
||||||
if len(group["models"]) > 1:
|
if len(group["models"]) > 1:
|
||||||
@@ -1211,9 +1228,9 @@ class ModelQueryHandler:
|
|||||||
(m for m in cache.raw_data if m["file_path"] == path), None
|
(m for m in cache.raw_data if m["file_path"] == path), None
|
||||||
)
|
)
|
||||||
if model:
|
if model:
|
||||||
group["models"].append(
|
formatted = await self._service.format_response(model)
|
||||||
await self._service.format_response(model)
|
if formatted is not None:
|
||||||
)
|
group["models"].append(formatted)
|
||||||
hash_val = self._service.scanner.get_hash_by_filename(filename)
|
hash_val = self._service.scanner.get_hash_by_filename(filename)
|
||||||
if hash_val:
|
if hash_val:
|
||||||
main_path = self._service.get_path_by_hash(hash_val)
|
main_path = self._service.get_path_by_hash(hash_val)
|
||||||
@@ -1223,9 +1240,9 @@ class ModelQueryHandler:
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if main_model:
|
if main_model:
|
||||||
group["models"].insert(
|
formatted = await self._service.format_response(main_model)
|
||||||
0, await self._service.format_response(main_model)
|
if formatted is not None:
|
||||||
)
|
group["models"].insert(0, formatted)
|
||||||
if group["models"]:
|
if group["models"]:
|
||||||
result.append(group)
|
result.append(group)
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
|
|||||||
@@ -791,8 +791,12 @@ class BaseModelService(ABC):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def format_response(self, model_data: Dict) -> Dict:
|
async def format_response(self, model_data: Dict) -> Optional[Dict]:
|
||||||
"""Format model data for API response - must be implemented by subclasses"""
|
"""Format model data for API response - must be implemented by subclasses.
|
||||||
|
|
||||||
|
Subclasses should return None for corrupted entries so the handler
|
||||||
|
layer can filter them out. See issue #730.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Common service methods that delegate to scanner
|
# Common service methods that delegate to scanner
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from .base_model_service import BaseModelService
|
from .base_model_service import BaseModelService
|
||||||
from .auto_tag_service import extract_auto_tags
|
from .auto_tag_service import extract_auto_tags
|
||||||
@@ -21,20 +21,37 @@ class CheckpointService(BaseModelService):
|
|||||||
"""
|
"""
|
||||||
super().__init__("checkpoint", scanner, CheckpointMetadata, update_service=update_service)
|
super().__init__("checkpoint", scanner, CheckpointMetadata, update_service=update_service)
|
||||||
|
|
||||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
async def format_response(self, checkpoint_data: Dict) -> Optional[Dict]:
|
||||||
"""Format Checkpoint data for API response"""
|
"""Format Checkpoint data for API response.
|
||||||
|
|
||||||
|
Returns None when the entry is missing critical fields (corrupted cache
|
||||||
|
row), so the handler layer can filter it out. See issue #730.
|
||||||
|
"""
|
||||||
|
# Guard against corrupted cache entries missing critical fields
|
||||||
|
file_path = checkpoint_data.get("file_path")
|
||||||
|
if not file_path or not isinstance(file_path, str):
|
||||||
|
logger.warning(
|
||||||
|
"Skipping corrupted checkpoint entry (missing file_path): %s",
|
||||||
|
checkpoint_data.get("file_name", "<unknown>"),
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
# Get sub_type from cache entry (new canonical field)
|
# Get sub_type from cache entry (new canonical field)
|
||||||
sub_type = checkpoint_data.get("sub_type", "checkpoint")
|
sub_type = checkpoint_data.get("sub_type", "checkpoint")
|
||||||
|
|
||||||
|
file_name = checkpoint_data.get("file_name") or ""
|
||||||
|
model_name = checkpoint_data.get("model_name") or file_name
|
||||||
|
folder = checkpoint_data.get("folder") or ""
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": checkpoint_data["model_name"],
|
"model_name": model_name,
|
||||||
"file_name": checkpoint_data["file_name"],
|
"file_name": file_name,
|
||||||
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||||
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||||
"base_model": checkpoint_data.get("base_model", ""),
|
"base_model": checkpoint_data.get("base_model", ""),
|
||||||
"folder": checkpoint_data["folder"],
|
"folder": folder,
|
||||||
"sha256": checkpoint_data.get("sha256", ""),
|
"sha256": checkpoint_data.get("sha256", ""),
|
||||||
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
"file_path": file_path.replace(os.sep, "/"),
|
||||||
"file_size": checkpoint_data.get("size", 0),
|
"file_size": checkpoint_data.get("size", 0),
|
||||||
"modified": checkpoint_data.get("modified", ""),
|
"modified": checkpoint_data.get("modified", ""),
|
||||||
"tags": checkpoint_data.get("tags", []),
|
"tags": checkpoint_data.get("tags", []),
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from .base_model_service import BaseModelService
|
from .base_model_service import BaseModelService
|
||||||
from .auto_tag_service import extract_auto_tags
|
from .auto_tag_service import extract_auto_tags
|
||||||
@@ -21,20 +21,37 @@ class EmbeddingService(BaseModelService):
|
|||||||
"""
|
"""
|
||||||
super().__init__("embedding", scanner, EmbeddingMetadata, update_service=update_service)
|
super().__init__("embedding", scanner, EmbeddingMetadata, update_service=update_service)
|
||||||
|
|
||||||
async def format_response(self, embedding_data: Dict) -> Dict:
|
async def format_response(self, embedding_data: Dict) -> Optional[Dict]:
|
||||||
"""Format Embedding data for API response"""
|
"""Format Embedding data for API response.
|
||||||
|
|
||||||
|
Returns None when the entry is missing critical fields (corrupted cache
|
||||||
|
row), so the handler layer can filter it out. See issue #730.
|
||||||
|
"""
|
||||||
|
# Guard against corrupted cache entries missing critical fields
|
||||||
|
file_path = embedding_data.get("file_path")
|
||||||
|
if not file_path or not isinstance(file_path, str):
|
||||||
|
logger.warning(
|
||||||
|
"Skipping corrupted embedding entry (missing file_path): %s",
|
||||||
|
embedding_data.get("file_name", "<unknown>"),
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
# Get sub_type from cache entry (new canonical field)
|
# Get sub_type from cache entry (new canonical field)
|
||||||
sub_type = embedding_data.get("sub_type", "embedding")
|
sub_type = embedding_data.get("sub_type", "embedding")
|
||||||
|
|
||||||
|
file_name = embedding_data.get("file_name") or ""
|
||||||
|
model_name = embedding_data.get("model_name") or file_name
|
||||||
|
folder = embedding_data.get("folder") or ""
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": embedding_data["model_name"],
|
"model_name": model_name,
|
||||||
"file_name": embedding_data["file_name"],
|
"file_name": file_name,
|
||||||
"preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")),
|
"preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")),
|
||||||
"preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0),
|
"preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0),
|
||||||
"base_model": embedding_data.get("base_model", ""),
|
"base_model": embedding_data.get("base_model", ""),
|
||||||
"folder": embedding_data["folder"],
|
"folder": folder,
|
||||||
"sha256": embedding_data.get("sha256", ""),
|
"sha256": embedding_data.get("sha256", ""),
|
||||||
"file_path": embedding_data["file_path"].replace(os.sep, "/"),
|
"file_path": file_path.replace(os.sep, "/"),
|
||||||
"file_size": embedding_data.get("size", 0),
|
"file_size": embedding_data.get("size", 0),
|
||||||
"modified": embedding_data.get("modified", ""),
|
"modified": embedding_data.get("modified", ""),
|
||||||
"tags": embedding_data.get("tags", []),
|
"tags": embedding_data.get("tags", []),
|
||||||
|
|||||||
@@ -24,23 +24,41 @@ class LoraService(BaseModelService):
|
|||||||
"""
|
"""
|
||||||
super().__init__("lora", scanner, LoraMetadata, update_service=update_service)
|
super().__init__("lora", scanner, LoraMetadata, update_service=update_service)
|
||||||
|
|
||||||
async def format_response(self, lora_data: Dict) -> Dict:
|
async def format_response(self, lora_data: Dict) -> Optional[Dict]:
|
||||||
"""Format LoRA data for API response"""
|
"""Format LoRA data for API response.
|
||||||
|
|
||||||
|
Returns None when the entry is missing critical fields (corrupted cache
|
||||||
|
row), so the handler layer can filter it out instead of crashing the
|
||||||
|
whole listing request. See issue #730.
|
||||||
|
"""
|
||||||
|
# Guard against corrupted cache entries missing critical fields
|
||||||
|
file_path = lora_data.get("file_path")
|
||||||
|
if not file_path or not isinstance(file_path, str):
|
||||||
|
logger.warning(
|
||||||
|
"Skipping corrupted LoRA entry (missing file_path): %s",
|
||||||
|
lora_data.get("file_name", "<unknown>"),
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
# Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default
|
# Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default
|
||||||
# Normalize to lowercase for consistent API responses
|
# Normalize to lowercase for consistent API responses
|
||||||
sub_type = resolve_sub_type(lora_data).lower()
|
sub_type = resolve_sub_type(lora_data).lower()
|
||||||
|
|
||||||
|
file_name = lora_data.get("file_name") or ""
|
||||||
|
model_name = lora_data.get("model_name") or file_name
|
||||||
|
folder = lora_data.get("folder") or ""
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": lora_data["model_name"],
|
"model_name": model_name,
|
||||||
"file_name": lora_data["file_name"],
|
"file_name": file_name,
|
||||||
"preview_url": config.get_preview_static_url(
|
"preview_url": config.get_preview_static_url(
|
||||||
lora_data.get("preview_url", "")
|
lora_data.get("preview_url", "")
|
||||||
),
|
),
|
||||||
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||||
"base_model": lora_data.get("base_model", ""),
|
"base_model": lora_data.get("base_model", ""),
|
||||||
"folder": lora_data["folder"],
|
"folder": folder,
|
||||||
"sha256": lora_data.get("sha256", ""),
|
"sha256": lora_data.get("sha256", ""),
|
||||||
"file_path": lora_data["file_path"].replace(os.sep, "/"),
|
"file_path": file_path.replace(os.sep, "/"),
|
||||||
"file_size": lora_data.get("size", 0),
|
"file_size": lora_data.get("size", 0),
|
||||||
"modified": lora_data.get("modified", ""),
|
"modified": lora_data.get("modified", ""),
|
||||||
"tags": lora_data.get("tags", []),
|
"tags": lora_data.get("tags", []),
|
||||||
|
|||||||
@@ -476,11 +476,20 @@ class ModelScanner:
|
|||||||
for tag in adjusted_item.get('tags') or []:
|
for tag in adjusted_item.get('tags') or []:
|
||||||
tags_count[tag] = tags_count.get(tag, 0) + 1
|
tags_count[tag] = tags_count.get(tag, 0) + 1
|
||||||
|
|
||||||
# Validate cache entries and check health
|
# Validate cache entries and check health.
|
||||||
|
# Always use the validated/repaired entries — even when there are no
|
||||||
|
# invalid entries, auto_repair may have filled in missing optional
|
||||||
|
# fields (model_name, file_name, folder) with safe defaults on a copied
|
||||||
|
# working_entry. Without this unconditional replacement the repaired
|
||||||
|
# copies are discarded and None values propagate to format_response.
|
||||||
|
# See issue #730.
|
||||||
valid_entries, invalid_entries = CacheEntryValidator.validate_batch(
|
valid_entries, invalid_entries = CacheEntryValidator.validate_batch(
|
||||||
adjusted_raw_data, auto_repair=True
|
adjusted_raw_data, auto_repair=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Always use the validated entries (repaired copies)
|
||||||
|
adjusted_raw_data = valid_entries
|
||||||
|
|
||||||
if invalid_entries:
|
if invalid_entries:
|
||||||
monitor = CacheHealthMonitor()
|
monitor = CacheHealthMonitor()
|
||||||
report = monitor.check_health(adjusted_raw_data, auto_repair=True)
|
report = monitor.check_health(adjusted_raw_data, auto_repair=True)
|
||||||
|
|||||||
@@ -165,8 +165,8 @@ class PersistentModelCache:
|
|||||||
|
|
||||||
item = {
|
item = {
|
||||||
"file_path": file_path,
|
"file_path": file_path,
|
||||||
"file_name": row["file_name"],
|
"file_name": row["file_name"] or "",
|
||||||
"model_name": row["model_name"],
|
"model_name": row["model_name"] or "",
|
||||||
"folder": row["folder"] or "",
|
"folder": row["folder"] or "",
|
||||||
"size": row["size"] or 0,
|
"size": row["size"] or 0,
|
||||||
"modified": row["modified"] or 0.0,
|
"modified": row["modified"] or 0.0,
|
||||||
@@ -548,19 +548,19 @@ class PersistentModelCache:
|
|||||||
return (
|
return (
|
||||||
model_type,
|
model_type,
|
||||||
item.get("file_path"),
|
item.get("file_path"),
|
||||||
item.get("file_name"),
|
item.get("file_name") or "",
|
||||||
item.get("model_name"),
|
item.get("model_name") or "",
|
||||||
item.get("folder"),
|
item.get("folder") or "",
|
||||||
int(item.get("size") or 0),
|
int(item.get("size") or 0),
|
||||||
float(item.get("modified") or 0.0),
|
float(item.get("modified") or 0.0),
|
||||||
(item.get("sha256") or "").lower() or None,
|
(item.get("sha256") or "").lower() or None,
|
||||||
item.get("base_model"),
|
item.get("base_model") or "",
|
||||||
item.get("preview_url"),
|
item.get("preview_url") or "",
|
||||||
int(item.get("preview_nsfw_level") or 0),
|
int(item.get("preview_nsfw_level") or 0),
|
||||||
1 if item.get("from_civitai", True) else 0,
|
1 if item.get("from_civitai", True) else 0,
|
||||||
1 if item.get("favorite") else 0,
|
1 if item.get("favorite") else 0,
|
||||||
item.get("notes"),
|
item.get("notes") or "",
|
||||||
item.get("usage_tips"),
|
item.get("usage_tips") or "",
|
||||||
metadata_source,
|
metadata_source,
|
||||||
civitai.get("id"),
|
civitai.get("id"),
|
||||||
civitai.get("modelId"),
|
civitai.get("modelId"),
|
||||||
|
|||||||
@@ -201,6 +201,45 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
|
|||||||
asyncio.run(scenario())
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_models_filters_out_corrupted_entries(mock_service, mock_scanner):
|
||||||
|
"""Corrupted cache entries (format_response returns None) must not appear
|
||||||
|
in the response items nor cause a 500. See issue #730.
|
||||||
|
"""
|
||||||
|
mock_service.paginated_items = [
|
||||||
|
{"file_path": "/tmp/good.safetensors", "name": "Good"},
|
||||||
|
{"file_path": None, "name": "Corrupted"}, # triggers None from format_response
|
||||||
|
{"file_path": "/tmp/also_good.safetensors", "name": "AlsoGood"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Override format_response to return None for corrupted entries
|
||||||
|
original_format = mock_service.format_response
|
||||||
|
|
||||||
|
async def conditional_format(item):
|
||||||
|
if item.get("file_path") is None:
|
||||||
|
return None
|
||||||
|
return await original_format(item)
|
||||||
|
|
||||||
|
mock_service.format_response = conditional_format
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
client = await create_test_client(mock_service)
|
||||||
|
try:
|
||||||
|
response = await client.get("/api/lm/test-models/list")
|
||||||
|
payload = await response.json()
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
# Only the 2 non-corrupted entries should appear
|
||||||
|
assert len(payload["items"]) == 2
|
||||||
|
assert payload["items"][0]["name"] == "Good"
|
||||||
|
assert payload["items"][1]["name"] == "AlsoGood"
|
||||||
|
# None should never appear in the items list
|
||||||
|
assert None not in payload["items"]
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
def test_model_types_endpoint_returns_counts(mock_service, mock_scanner):
|
def test_model_types_endpoint_returns_counts(mock_service, mock_scanner):
|
||||||
mock_service.model_types = [
|
mock_service.model_types = [
|
||||||
{"type": "LoRa", "count": 3},
|
{"type": "LoRa", "count": 3},
|
||||||
|
|||||||
@@ -204,3 +204,102 @@ class TestEmbeddingServiceFormatResponse:
|
|||||||
|
|
||||||
assert result["sub_type"] == "embedding"
|
assert result["sub_type"] == "embedding"
|
||||||
assert "model_type" not in result # Removed in refactoring
|
assert "model_type" not in result # Removed in refactoring
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatResponseCorruptedEntries:
|
||||||
|
"""Test format_response handles corrupted cache entries gracefully (issue #730).
|
||||||
|
|
||||||
|
When cache rows have None/missing critical fields (e.g. from a partially
|
||||||
|
written or legacy DB), format_response must NOT raise KeyError/AttributeError.
|
||||||
|
Instead it returns None so the handler layer can filter the bad entry out
|
||||||
|
instead of failing the entire listing request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_scanner(self):
|
||||||
|
scanner = MagicMock()
|
||||||
|
scanner._hash_index = MagicMock()
|
||||||
|
return scanner
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def lora_service(self, mock_scanner):
|
||||||
|
return LoraService(mock_scanner)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def checkpoint_service(self, mock_scanner):
|
||||||
|
return CheckpointService(mock_scanner)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def embedding_service(self, mock_scanner):
|
||||||
|
return EmbeddingService(mock_scanner)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lora_returns_none_on_missing_file_path(self, lora_service):
|
||||||
|
"""format_response returns None when file_path is missing (corrupted row)."""
|
||||||
|
lora_data = {
|
||||||
|
"model_name": "Test LoRA",
|
||||||
|
"file_name": "test_lora",
|
||||||
|
"file_path": None, # corrupted: missing file_path
|
||||||
|
"folder": "",
|
||||||
|
"sha256": "abc123",
|
||||||
|
"tags": [],
|
||||||
|
"from_civitai": True,
|
||||||
|
"civitai": {},
|
||||||
|
}
|
||||||
|
result = await lora_service.format_response(lora_data)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lora_handles_none_model_name_gracefully(self, lora_service):
|
||||||
|
"""format_response should not crash when model_name is None (legacy DB row)."""
|
||||||
|
lora_data = {
|
||||||
|
"model_name": None, # NULL from old DB row
|
||||||
|
"file_name": "test_lora",
|
||||||
|
"file_path": "/models/test_lora.safetensors",
|
||||||
|
"folder": "",
|
||||||
|
"sha256": "abc123",
|
||||||
|
"tags": [],
|
||||||
|
"from_civitai": True,
|
||||||
|
"civitai": {},
|
||||||
|
}
|
||||||
|
result = await lora_service.format_response(lora_data)
|
||||||
|
# Should not raise; model_name falls back to file_name
|
||||||
|
assert result is not None
|
||||||
|
assert result["model_name"] == "test_lora"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_returns_none_on_missing_file_path(self, checkpoint_service):
|
||||||
|
"""format_response returns None when file_path is missing (corrupted row)."""
|
||||||
|
checkpoint_data = {
|
||||||
|
"model_name": "Test",
|
||||||
|
"file_name": "test",
|
||||||
|
"file_path": "", # empty string == corrupted
|
||||||
|
"folder": "",
|
||||||
|
"sha256": "abc",
|
||||||
|
"tags": [],
|
||||||
|
"from_civitai": True,
|
||||||
|
"civitai": {},
|
||||||
|
"sub_type": "checkpoint",
|
||||||
|
}
|
||||||
|
result = await checkpoint_service.format_response(checkpoint_data)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embedding_handles_none_fields_gracefully(self, embedding_service):
|
||||||
|
"""format_response should not crash when optional fields are None."""
|
||||||
|
embedding_data = {
|
||||||
|
"model_name": None,
|
||||||
|
"file_name": None,
|
||||||
|
"file_path": "/models/test.pt",
|
||||||
|
"folder": None,
|
||||||
|
"sha256": "abc",
|
||||||
|
"tags": [],
|
||||||
|
"from_civitai": True,
|
||||||
|
"civitai": {},
|
||||||
|
"sub_type": "embedding",
|
||||||
|
}
|
||||||
|
result = await embedding_service.format_response(embedding_data)
|
||||||
|
assert result is not None
|
||||||
|
assert result["file_path"] == "/models/test.pt"
|
||||||
|
# model_name falls back to file_name which falls back to ""
|
||||||
|
assert result["model_name"] == ""
|
||||||
|
|||||||
Reference in New Issue
Block a user