refactor: unify model_type semantics by introducing sub_type field

This commit resolves the semantic confusion around the model_type field by
clearly distinguishing between:
- scanner_type: architecture-level (lora/checkpoint/embedding)
- sub_type: business-level subtype (lora/locon/dora/checkpoint/diffusion_model/embedding)

Backend Changes:
- Rename model_type to sub_type in CheckpointMetadata and EmbeddingMetadata
- Add resolve_sub_type() and normalize_sub_type() in model_query.py
- Update checkpoint_scanner to use _resolve_sub_type()
- Update service format_response to include both sub_type and model_type
- Add VALID_*_SUB_TYPES constants with backward-compatible aliases

Frontend Changes:
- Add MODEL_SUBTYPE_DISPLAY_NAMES constants
- Keep MODEL_TYPE_DISPLAY_NAMES as a backward-compatible alias

Testing:
- Add 43 new tests covering sub_type resolution and API response

Documentation:
- Add refactoring todo document to docs/technical/

No breaking changes — full backward compatibility maintained.
(Note: the literal "BREAKING CHANGE:" footer token is avoided here because
Conventional Commits tooling treats its mere presence as marking the commit
as breaking, regardless of the text that follows.)
This commit is contained in:
Will Miao
2026-01-30 06:56:10 +08:00
parent 08267cdb48
commit 5e91073476
15 changed files with 1014 additions and 42 deletions

View File

@@ -5,7 +5,7 @@ import logging
import os
import time
from ..utils.constants import VALID_LORA_TYPES
from ..utils.constants import VALID_LORA_SUB_TYPES, VALID_CHECKPOINT_SUB_TYPES
from ..utils.models import BaseModelMetadata
from ..utils.metadata_manager import MetadataManager
from ..utils.usage_stats import UsageStats
@@ -15,8 +15,8 @@ from .model_query import (
ModelFilterSet,
SearchStrategy,
SettingsProvider,
normalize_civitai_model_type,
resolve_civitai_model_type,
normalize_sub_type,
resolve_sub_type,
)
from .settings_manager import get_settings_manager
@@ -568,16 +568,21 @@ class BaseModelService(ABC):
return await self.scanner.get_base_models(limit)
async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]:
"""Get counts of normalized CivitAI model types present in the cache."""
"""Get counts of sub-types present in the cache."""
cache = await self.scanner.get_cached_data()
type_counts: Dict[str, int] = {}
for entry in cache.raw_data:
normalized_type = normalize_civitai_model_type(
resolve_civitai_model_type(entry)
)
if not normalized_type or normalized_type not in VALID_LORA_TYPES:
normalized_type = normalize_sub_type(resolve_sub_type(entry))
if not normalized_type:
continue
# Filter by valid sub-types based on scanner type
if self.model_type == "lora" and normalized_type not in VALID_LORA_SUB_TYPES:
continue
if self.model_type == "checkpoint" and normalized_type not in VALID_CHECKPOINT_SUB_TYPES:
continue
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
sorted_types = sorted(

View File

@@ -21,7 +21,8 @@ class CheckpointScanner(ModelScanner):
hash_index=ModelHashIndex()
)
def _resolve_model_type(self, root_path: Optional[str]) -> Optional[str]:
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
"""Resolve the sub-type based on the root path."""
if not root_path:
return None
@@ -34,18 +35,28 @@ class CheckpointScanner(ModelScanner):
return None
def adjust_metadata(self, metadata, file_path, root_path):
if hasattr(metadata, "model_type"):
model_type = self._resolve_model_type(root_path)
if model_type:
metadata.model_type = model_type
"""Adjust metadata during scanning to set sub_type."""
# Support both old 'model_type' and new 'sub_type' for backward compatibility
if hasattr(metadata, "sub_type"):
sub_type = self._resolve_sub_type(root_path)
if sub_type:
metadata.sub_type = sub_type
elif hasattr(metadata, "model_type"):
# Backward compatibility: fallback to model_type if sub_type not available
sub_type = self._resolve_sub_type(root_path)
if sub_type:
metadata.model_type = sub_type
return metadata
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
model_type = self._resolve_model_type(
"""Adjust entries loaded from the persisted cache to ensure sub_type is set."""
sub_type = self._resolve_sub_type(
self._find_root_for_file(entry.get("file_path"))
)
if model_type:
entry["model_type"] = model_type
if sub_type:
entry["sub_type"] = sub_type
# Also set model_type for backward compatibility during transition
entry["model_type"] = sub_type
return entry
def get_model_roots(self) -> List[str]:

View File

@@ -22,6 +22,9 @@ class CheckpointService(BaseModelService):
async def format_response(self, checkpoint_data: Dict) -> Dict:
"""Format Checkpoint data for API response"""
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
sub_type = checkpoint_data.get("sub_type") or checkpoint_data.get("model_type", "checkpoint")
return {
"model_name": checkpoint_data["model_name"],
"file_name": checkpoint_data["file_name"],
@@ -37,7 +40,8 @@ class CheckpointService(BaseModelService):
"from_civitai": checkpoint_data.get("from_civitai", True),
"usage_count": checkpoint_data.get("usage_count", 0),
"notes": checkpoint_data.get("notes", ""),
"model_type": checkpoint_data.get("model_type", "checkpoint"),
"sub_type": sub_type, # New canonical field
"model_type": sub_type, # Backward compatibility
"favorite": checkpoint_data.get("favorite", False),
"update_available": bool(checkpoint_data.get("update_available", False)),
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)

View File

@@ -22,6 +22,9 @@ class EmbeddingService(BaseModelService):
async def format_response(self, embedding_data: Dict) -> Dict:
"""Format Embedding data for API response"""
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
sub_type = embedding_data.get("sub_type") or embedding_data.get("model_type", "embedding")
return {
"model_name": embedding_data["model_name"],
"file_name": embedding_data["file_name"],
@@ -37,7 +40,8 @@ class EmbeddingService(BaseModelService):
"from_civitai": embedding_data.get("from_civitai", True),
# "usage_count": embedding_data.get("usage_count", 0), # TODO: Enable when embedding usage tracking is implemented
"notes": embedding_data.get("notes", ""),
"model_type": embedding_data.get("model_type", "embedding"),
"sub_type": sub_type, # New canonical field
"model_type": sub_type, # Backward compatibility
"favorite": embedding_data.get("favorite", False),
"update_available": bool(embedding_data.get("update_available", False)),
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)

View File

@@ -23,6 +23,9 @@ class LoraService(BaseModelService):
async def format_response(self, lora_data: Dict) -> Dict:
"""Format LoRA data for API response"""
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
sub_type = lora_data.get("sub_type") or lora_data.get("model_type", "lora")
return {
"model_name": lora_data["model_name"],
"file_name": lora_data["file_name"],
@@ -43,6 +46,8 @@ class LoraService(BaseModelService):
"notes": lora_data.get("notes", ""),
"favorite": lora_data.get("favorite", False),
"update_available": bool(lora_data.get("update_available", False)),
"sub_type": sub_type, # New canonical field
"model_type": sub_type, # Backward compatibility
"civitai": self.filter_civitai_data(
lora_data.get("civitai", {}), minimal=True
),

View File

@@ -33,32 +33,54 @@ def _coerce_to_str(value: Any) -> Optional[str]:
return candidate if candidate else None
def normalize_civitai_model_type(value: Any) -> Optional[str]:
"""Return a lowercase string suitable for comparisons."""
def normalize_sub_type(value: Any) -> Optional[str]:
"""Return a lowercase string suitable for sub_type comparisons."""
candidate = _coerce_to_str(value)
return candidate.lower() if candidate else None
def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str:
"""Extract the model type from CivitAI metadata, defaulting to LORA."""
# Backward compatibility alias
normalize_civitai_model_type = normalize_sub_type
def resolve_sub_type(entry: Mapping[str, Any]) -> str:
"""Extract the sub-type from metadata, checking multiple sources.
Priority:
1. entry['sub_type'] - new canonical field
2. entry['model_type'] - backward compatibility
3. civitai.model.type - CivitAI API data
4. DEFAULT_CIVITAI_MODEL_TYPE - fallback
"""
if not isinstance(entry, Mapping):
return DEFAULT_CIVITAI_MODEL_TYPE
civitai = entry.get("civitai")
if isinstance(civitai, Mapping):
civitai_model = civitai.get("model")
if isinstance(civitai_model, Mapping):
model_type = _coerce_to_str(civitai_model.get("type"))
if model_type:
return model_type
# Priority 1: Check new canonical field 'sub_type'
sub_type = _coerce_to_str(entry.get("sub_type"))
if sub_type:
return sub_type
# Priority 2: Backward compatibility - check 'model_type' field
model_type = _coerce_to_str(entry.get("model_type"))
if model_type:
return model_type
# Priority 3: Extract from CivitAI metadata
civitai = entry.get("civitai")
if isinstance(civitai, Mapping):
civitai_model = civitai.get("model")
if isinstance(civitai_model, Mapping):
civitai_type = _coerce_to_str(civitai_model.get("type"))
if civitai_type:
return civitai_type
return DEFAULT_CIVITAI_MODEL_TYPE
# Backward compatibility alias
resolve_civitai_model_type = resolve_sub_type
class SettingsProvider(Protocol):
"""Protocol describing the SettingsManager contract used by query helpers."""
@@ -313,7 +335,7 @@ class ModelFilterSet:
normalized_model_types = {
model_type
for model_type in (
normalize_civitai_model_type(value) for value in model_types
normalize_sub_type(value) for value in model_types
)
if model_type
}
@@ -321,7 +343,7 @@ class ModelFilterSet:
items = [
item
for item in items
if normalize_civitai_model_type(resolve_civitai_model_type(item))
if normalize_sub_type(resolve_sub_type(item))
in normalized_model_types
]
model_types_duration = time.perf_counter() - t0

View File

@@ -275,9 +275,16 @@ class ModelScanner:
_, license_flags = resolve_license_info(license_source or {})
entry['license_flags'] = license_flags
# Handle sub_type (new canonical field) and model_type (backward compatibility)
sub_type = get_value('sub_type', None)
model_type = get_value('model_type', None)
if model_type:
entry['model_type'] = model_type
# Prefer sub_type, fallback to model_type for backward compatibility
effective_sub_type = sub_type or model_type
if effective_sub_type:
entry['sub_type'] = effective_sub_type
# Also keep model_type for backward compatibility during transition
entry['model_type'] = effective_sub_type
return entry

View File

@@ -45,8 +45,13 @@ SUPPORTED_MEDIA_EXTENSIONS = {
"videos": [".mp4", ".webm"],
}
# Valid Lora types
VALID_LORA_TYPES = ["lora", "locon", "dora"]
# Valid sub-types for each scanner type
VALID_LORA_SUB_TYPES = ["lora", "locon", "dora"]
VALID_CHECKPOINT_SUB_TYPES = ["checkpoint", "diffusion_model"]
VALID_EMBEDDING_SUB_TYPES = ["embedding"]
# Backward compatibility alias
VALID_LORA_TYPES = VALID_LORA_SUB_TYPES
# Supported Civitai model types for user model queries (case-insensitive)
CIVITAI_USER_MODEL_TYPES = [

View File

@@ -173,14 +173,14 @@ class LoraMetadata(BaseModelMetadata):
@dataclass
class CheckpointMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Checkpoint model"""
model_type: str = "checkpoint" # Model type (checkpoint, diffusion_model, etc.)
sub_type: str = "checkpoint" # Model sub-type (checkpoint, diffusion_model, etc.)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
"""Create CheckpointMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
model_type = version_info.get('type', 'checkpoint')
sub_type = version_info.get('type', 'checkpoint')
# Extract tags and description if available
tags = []
@@ -203,7 +203,7 @@ class CheckpointMetadata(BaseModelMetadata):
preview_nsfw_level=0,
from_civitai=True,
civitai=version_info,
model_type=model_type,
sub_type=sub_type,
tags=tags,
modelDescription=description
)
@@ -211,14 +211,14 @@ class CheckpointMetadata(BaseModelMetadata):
@dataclass
class EmbeddingMetadata(BaseModelMetadata):
"""Represents the metadata structure for an Embedding model"""
model_type: str = "embedding" # Model type (embedding, textual_inversion, etc.)
sub_type: str = "embedding"
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'EmbeddingMetadata':
"""Create EmbeddingMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
model_type = version_info.get('type', 'embedding')
sub_type = version_info.get('type', 'embedding')
# Extract tags and description if available
tags = []
@@ -241,7 +241,7 @@ class EmbeddingMetadata(BaseModelMetadata):
preview_nsfw_level=0,
from_civitai=True,
civitai=version_info,
model_type=model_type,
sub_type=sub_type,
tags=tags,
modelDescription=description
)