From 5e91073476effc6280ea61c2a6414c473537cc0d Mon Sep 17 00:00:00 2001 From: Will Miao Date: Fri, 30 Jan 2026 06:56:10 +0800 Subject: [PATCH] refactor: unify model_type semantics by introducing sub_type field This commit resolves the semantic confusion around the model_type field by clearly distinguishing between: - scanner_type: architecture-level (lora/checkpoint/embedding) - sub_type: business-level subtype (lora/locon/dora/checkpoint/diffusion_model/embedding) Backend Changes: - Rename model_type to sub_type in CheckpointMetadata and EmbeddingMetadata - Add resolve_sub_type() and normalize_sub_type() in model_query.py - Update checkpoint_scanner to use _resolve_sub_type() - Update service format_response to include both sub_type and model_type - Add VALID_*_SUB_TYPES constants with backward compatible aliases Frontend Changes: - Add MODEL_SUBTYPE_DISPLAY_NAMES constants - Keep MODEL_TYPE_DISPLAY_NAMES as backward compatible alias Testing: - Add 43 new tests covering sub_type resolution and API response Documentation: - Add refactoring todo document to docs/technical/ BREAKING CHANGE: None - full backward compatibility maintained --- docs/technical/model_type_refactoring_todo.md | 194 +++++++++++++++ py/services/base_model_service.py | 21 +- py/services/checkpoint_scanner.py | 27 +- py/services/checkpoint_service.py | 6 +- py/services/embedding_service.py | 6 +- py/services/lora_service.py | 5 + py/services/model_query.py | 48 +++- py/services/model_scanner.py | 11 +- py/utils/constants.py | 9 +- py/utils/models.py | 12 +- static/js/utils/constants.js | 12 +- .../test_checkpoint_scanner_sub_type.py | 145 +++++++++++ tests/services/test_model_query_sub_type.py | 203 ++++++++++++++++ .../test_service_format_response_sub_type.py | 230 ++++++++++++++++++ tests/utils/test_models_sub_type.py | 127 ++++++++++ 15 files changed, 1014 insertions(+), 42 deletions(-) create mode 100644 docs/technical/model_type_refactoring_todo.md create mode 100644 
tests/services/test_checkpoint_scanner_sub_type.py create mode 100644 tests/services/test_model_query_sub_type.py create mode 100644 tests/services/test_service_format_response_sub_type.py create mode 100644 tests/utils/test_models_sub_type.py diff --git a/docs/technical/model_type_refactoring_todo.md b/docs/technical/model_type_refactoring_todo.md new file mode 100644 index 00000000..490bf7f7 --- /dev/null +++ b/docs/technical/model_type_refactoring_todo.md @@ -0,0 +1,194 @@ +# Model Type 字段重构 - 遗留工作清单 + +> **状态**: Phase 1-4 已完成 | **创建日期**: 2026-01-30 +> **相关文件**: `py/utils/models.py`, `py/services/model_query.py`, `py/services/checkpoint_scanner.py`, etc. + +--- + +## 概述 + +本次重构旨在解决 `model_type` 字段语义不统一的问题。系统中有两个层面的"类型"概念: + +1. **Scanner Type** (`scanner_type`): 架构层面的大类 - `lora`, `checkpoint`, `embedding` +2. **Sub Type** (`sub_type`): 业务层面的细分类型 - `lora`/`locon`/`dora`, `checkpoint`/`diffusion_model`, `embedding` + +重构目标是统一使用 `sub_type` 表示细分类型,保留 `model_type` 作为向后兼容的别名。 + +--- + +## 已完成工作 ✅ + +### Phase 1: 后端字段重命名 +- [x] `CheckpointMetadata.model_type` → `sub_type` +- [x] `EmbeddingMetadata.model_type` → `sub_type` +- [x] `model_scanner.py` `_build_cache_entry()` 同时处理 `sub_type` 和 `model_type` + +### Phase 2: 查询逻辑更新 +- [x] `model_query.py` 新增 `resolve_sub_type()` 和 `normalize_sub_type()` +- [x] 保持向后兼容的别名 `resolve_civitai_model_type`, `normalize_civitai_model_type` +- [x] `ModelFilterSet.apply()` 更新为使用新的解析函数 + +### Phase 3: API 响应更新 +- [x] `LoraService.format_response()` 返回 `sub_type` + `model_type` +- [x] `CheckpointService.format_response()` 返回 `sub_type` + `model_type` +- [x] `EmbeddingService.format_response()` 返回 `sub_type` + `model_type` + +### Phase 4: 前端更新 +- [x] `constants.js` 新增 `MODEL_SUBTYPE_DISPLAY_NAMES` +- [x] `MODEL_TYPE_DISPLAY_NAMES` 作为别名保留 + +--- + +## 遗留工作 ⏳ + +### Phase 5: 清理废弃代码(建议在 v3.0 进行,v2.0 先提供弃用警告) + +#### 5.1 移除 `model_type` 字段的向后兼容代码 + +**优先级**: 低 +**风险**: 高(需要确保前端和第三方集成不再依赖) + +```python +# TODO: 从 ModelScanner._build_cache_entry() 
中移除 +# 当前代码: +if effective_sub_type: + entry['sub_type'] = effective_sub_type + entry['model_type'] = effective_sub_type # 待移除 + +# 目标代码: +if effective_sub_type: + entry['sub_type'] = effective_sub_type +``` + +#### 5.2 移除 CheckpointScanner 的 model_type 兼容处理 + +```python +# TODO: 从 checkpoint_scanner.py 中移除对 model_type 的兼容处理 +# 当前 adjust_metadata 同时检查 'sub_type' 和 'model_type' +# 目标:只处理 'sub_type' +``` + +#### 5.3 移除 model_query 中的向后兼容别名 + +```python +# TODO: 确认所有调用方都使用新函数后,移除这些别名 +resolve_civitai_model_type = resolve_sub_type # 待移除 +normalize_civitai_model_type = normalize_sub_type # 待移除 +``` + +#### 5.4 前端清理 + +```javascript +// TODO: 从前端移除对 model_type 的依赖 +// FilterManager.js 中仍然使用 model_type 作为内部状态名 +// 需要统一改为使用 sub_type +``` + +--- + +## 数据库迁移评估 + +### 当前状态 +- `persistent_model_cache.py` 使用 `civitai_model_type` 列存储 CivitAI 原始类型 +- 缓存 entry 中的 `sub_type` 在运行期动态计算 +- 数据库 schema **无需立即修改** + +### 未来可选优化 +```sql +-- 可选:在 models 表中添加 sub_type 列(与 civitai_model_type 保持一致但语义更清晰) +ALTER TABLE models ADD COLUMN sub_type TEXT; + +-- 数据迁移 +UPDATE models SET sub_type = civitai_model_type WHERE sub_type IS NULL; +``` + +**建议**: 如果决定添加 `sub_type` 列,应与 Phase 5 一起进行。 + +--- + +## 测试覆盖率 + +### 新增测试文件(已全部通过 ✅) + +| 测试文件 | 数量 | 覆盖内容 | +|---------|------|---------| +| `tests/utils/test_models_sub_type.py` | 7 | Metadata sub_type 字段 | +| `tests/services/test_model_query_sub_type.py` | 23 | sub_type 解析和过滤 | +| `tests/services/test_checkpoint_scanner_sub_type.py` | 6 | CheckpointScanner sub_type | +| `tests/services/test_service_format_response_sub_type.py` | 7 | API 响应 sub_type 包含 | + +### 需要补充的测试(TODO) + +- [ ] 集成测试:验证前端过滤使用 sub_type 字段 +- [ ] 数据库迁移测试(如果执行可选优化) +- [ ] 性能测试:确认 resolve_sub_type 的优先级查找没有显著性能影响 + +--- + +## 兼容性检查清单 + +在移除向后兼容代码前,请确认: + +- [ ] 前端代码已全部改用 `sub_type` 字段 +- [ ] ComfyUI Widget 代码不再依赖 `model_type` +- [ ] 移动端/第三方客户端已更新 +- [ ] 文档已更新,说明 `model_type` 已弃用 +- [ ] 提供至少 1 个版本的弃用警告期 + +--- + +## 相关文件清单 + +### 核心文件 +``` +py/utils/models.py +py/utils/constants.py 
+py/services/model_scanner.py +py/services/model_query.py +py/services/checkpoint_scanner.py +py/services/base_model_service.py +py/services/lora_service.py +py/services/checkpoint_service.py +py/services/embedding_service.py +``` + +### 前端文件 +``` +static/js/utils/constants.js +static/js/managers/FilterManager.js +``` + +### 测试文件 +``` +tests/utils/test_models_sub_type.py +tests/services/test_model_query_sub_type.py +tests/services/test_checkpoint_scanner_sub_type.py +tests/services/test_service_format_response_sub_type.py +``` + +--- + +## 风险评估 + +| 风险项 | 影响 | 缓解措施 | +|-------|------|---------| +| 第三方代码依赖 `model_type` | 高 | 保持别名至少 1 个 major 版本 | +| 数据库 schema 变更 | 中 | 暂缓 schema 变更,仅运行时计算 | +| 前端过滤失效 | 中 | 全面的集成测试覆盖 | +| CivitAI API 变化 | 低 | 保持多源解析策略 | + +--- + +## 时间线建议 + +- **v1.x (当前)**: Phase 1-4 已完成,保持向后兼容 +- **v2.0**: 添加弃用警告,开始迁移文档 +- **v3.0**: 移除 `model_type` 兼容代码(Phase 5) + +--- + +## 备注 + +- 重构期间发现 `civitai_model_type` 数据库列命名尚可,但语义上应理解为存储 CivitAI API 返回的原始类型值 +- Checkpoint 的 `diffusion_model` sub_type 不能通过 CivitAI API 获取,必须通过文件路径(model root)判断 +- LoRA 的 sub_type(lora/locon/dora)直接来自 CivitAI API 的 `version_info.model.type` diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 84a86bb8..7faab2d2 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -5,7 +5,7 @@ import logging import os import time -from ..utils.constants import VALID_LORA_TYPES +from ..utils.constants import VALID_LORA_SUB_TYPES, VALID_CHECKPOINT_SUB_TYPES from ..utils.models import BaseModelMetadata from ..utils.metadata_manager import MetadataManager from ..utils.usage_stats import UsageStats @@ -15,8 +15,8 @@ from .model_query import ( ModelFilterSet, SearchStrategy, SettingsProvider, - normalize_civitai_model_type, - resolve_civitai_model_type, + normalize_sub_type, + resolve_sub_type, ) from .settings_manager import get_settings_manager @@ -568,16 +568,21 @@ class BaseModelService(ABC): return await 
self.scanner.get_base_models(limit) async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]: - """Get counts of normalized CivitAI model types present in the cache.""" + """Get counts of sub-types present in the cache.""" cache = await self.scanner.get_cached_data() type_counts: Dict[str, int] = {} for entry in cache.raw_data: - normalized_type = normalize_civitai_model_type( - resolve_civitai_model_type(entry) - ) - if not normalized_type or normalized_type not in VALID_LORA_TYPES: + normalized_type = normalize_sub_type(resolve_sub_type(entry)) + if not normalized_type: continue + + # Filter by valid sub-types based on scanner type + if self.model_type == "lora" and normalized_type not in VALID_LORA_SUB_TYPES: + continue + if self.model_type == "checkpoint" and normalized_type not in VALID_CHECKPOINT_SUB_TYPES: + continue + type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1 sorted_types = sorted( diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index 25afce90..6a9d5129 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -21,7 +21,8 @@ class CheckpointScanner(ModelScanner): hash_index=ModelHashIndex() ) - def _resolve_model_type(self, root_path: Optional[str]) -> Optional[str]: + def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]: + """Resolve the sub-type based on the root path.""" if not root_path: return None @@ -34,18 +35,28 @@ class CheckpointScanner(ModelScanner): return None def adjust_metadata(self, metadata, file_path, root_path): - if hasattr(metadata, "model_type"): - model_type = self._resolve_model_type(root_path) - if model_type: - metadata.model_type = model_type + """Adjust metadata during scanning to set sub_type.""" + # Support both old 'model_type' and new 'sub_type' for backward compatibility + if hasattr(metadata, "sub_type"): + sub_type = self._resolve_sub_type(root_path) + if sub_type: + metadata.sub_type = sub_type 
+ elif hasattr(metadata, "model_type"): + # Backward compatibility: fallback to model_type if sub_type not available + sub_type = self._resolve_sub_type(root_path) + if sub_type: + metadata.model_type = sub_type return metadata def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]: - model_type = self._resolve_model_type( + """Adjust entries loaded from the persisted cache to ensure sub_type is set.""" + sub_type = self._resolve_sub_type( self._find_root_for_file(entry.get("file_path")) ) - if model_type: - entry["model_type"] = model_type + if sub_type: + entry["sub_type"] = sub_type + # Also set model_type for backward compatibility during transition + entry["model_type"] = sub_type return entry def get_model_roots(self) -> List[str]: diff --git a/py/services/checkpoint_service.py b/py/services/checkpoint_service.py index 924f250a..bdb7e97b 100644 --- a/py/services/checkpoint_service.py +++ b/py/services/checkpoint_service.py @@ -22,6 +22,9 @@ class CheckpointService(BaseModelService): async def format_response(self, checkpoint_data: Dict) -> Dict: """Format Checkpoint data for API response""" + # Get sub_type from cache entry (new field) or fallback to model_type (old field) + sub_type = checkpoint_data.get("sub_type") or checkpoint_data.get("model_type", "checkpoint") + return { "model_name": checkpoint_data["model_name"], "file_name": checkpoint_data["file_name"], @@ -37,7 +40,8 @@ class CheckpointService(BaseModelService): "from_civitai": checkpoint_data.get("from_civitai", True), "usage_count": checkpoint_data.get("usage_count", 0), "notes": checkpoint_data.get("notes", ""), - "model_type": checkpoint_data.get("model_type", "checkpoint"), + "sub_type": sub_type, # New canonical field + "model_type": sub_type, # Backward compatibility "favorite": checkpoint_data.get("favorite", False), "update_available": bool(checkpoint_data.get("update_available", False)), "civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True) 
diff --git a/py/services/embedding_service.py b/py/services/embedding_service.py index bfa51d15..881f4b6b 100644 --- a/py/services/embedding_service.py +++ b/py/services/embedding_service.py @@ -22,6 +22,9 @@ class EmbeddingService(BaseModelService): async def format_response(self, embedding_data: Dict) -> Dict: """Format Embedding data for API response""" + # Get sub_type from cache entry (new field) or fallback to model_type (old field) + sub_type = embedding_data.get("sub_type") or embedding_data.get("model_type", "embedding") + return { "model_name": embedding_data["model_name"], "file_name": embedding_data["file_name"], @@ -37,7 +40,8 @@ class EmbeddingService(BaseModelService): "from_civitai": embedding_data.get("from_civitai", True), # "usage_count": embedding_data.get("usage_count", 0), # TODO: Enable when embedding usage tracking is implemented "notes": embedding_data.get("notes", ""), - "model_type": embedding_data.get("model_type", "embedding"), + "sub_type": sub_type, # New canonical field + "model_type": sub_type, # Backward compatibility "favorite": embedding_data.get("favorite", False), "update_available": bool(embedding_data.get("update_available", False)), "civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True) diff --git a/py/services/lora_service.py b/py/services/lora_service.py index 40e85330..00ca3bf2 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -23,6 +23,9 @@ class LoraService(BaseModelService): async def format_response(self, lora_data: Dict) -> Dict: """Format LoRA data for API response""" + # Get sub_type from cache entry (new field) or fallback to model_type (old field) + sub_type = lora_data.get("sub_type") or lora_data.get("model_type", "lora") + return { "model_name": lora_data["model_name"], "file_name": lora_data["file_name"], @@ -43,6 +46,8 @@ class LoraService(BaseModelService): "notes": lora_data.get("notes", ""), "favorite": lora_data.get("favorite", False), 
"update_available": bool(lora_data.get("update_available", False)), + "sub_type": sub_type, # New canonical field + "model_type": sub_type, # Backward compatibility "civitai": self.filter_civitai_data( lora_data.get("civitai", {}), minimal=True ), diff --git a/py/services/model_query.py b/py/services/model_query.py index 842c39df..88e4439f 100644 --- a/py/services/model_query.py +++ b/py/services/model_query.py @@ -33,32 +33,54 @@ def _coerce_to_str(value: Any) -> Optional[str]: return candidate if candidate else None -def normalize_civitai_model_type(value: Any) -> Optional[str]: - """Return a lowercase string suitable for comparisons.""" +def normalize_sub_type(value: Any) -> Optional[str]: + """Return a lowercase string suitable for sub_type comparisons.""" candidate = _coerce_to_str(value) return candidate.lower() if candidate else None -def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str: - """Extract the model type from CivitAI metadata, defaulting to LORA.""" +# Backward compatibility alias +normalize_civitai_model_type = normalize_sub_type + + +def resolve_sub_type(entry: Mapping[str, Any]) -> str: + """Extract the sub-type from metadata, checking multiple sources. + + Priority: + 1. entry['sub_type'] - new canonical field + 2. entry['model_type'] - backward compatibility + 3. civitai.model.type - CivitAI API data + 4. 
DEFAULT_CIVITAI_MODEL_TYPE - fallback + """ if not isinstance(entry, Mapping): return DEFAULT_CIVITAI_MODEL_TYPE - civitai = entry.get("civitai") - if isinstance(civitai, Mapping): - civitai_model = civitai.get("model") - if isinstance(civitai_model, Mapping): - model_type = _coerce_to_str(civitai_model.get("type")) - if model_type: - return model_type + # Priority 1: Check new canonical field 'sub_type' + sub_type = _coerce_to_str(entry.get("sub_type")) + if sub_type: + return sub_type + # Priority 2: Backward compatibility - check 'model_type' field model_type = _coerce_to_str(entry.get("model_type")) if model_type: return model_type + # Priority 3: Extract from CivitAI metadata + civitai = entry.get("civitai") + if isinstance(civitai, Mapping): + civitai_model = civitai.get("model") + if isinstance(civitai_model, Mapping): + civitai_type = _coerce_to_str(civitai_model.get("type")) + if civitai_type: + return civitai_type + return DEFAULT_CIVITAI_MODEL_TYPE +# Backward compatibility alias +resolve_civitai_model_type = resolve_sub_type + + class SettingsProvider(Protocol): """Protocol describing the SettingsManager contract used by query helpers.""" @@ -313,7 +335,7 @@ class ModelFilterSet: normalized_model_types = { model_type for model_type in ( - normalize_civitai_model_type(value) for value in model_types + normalize_sub_type(value) for value in model_types ) if model_type } @@ -321,7 +343,7 @@ class ModelFilterSet: items = [ item for item in items - if normalize_civitai_model_type(resolve_civitai_model_type(item)) + if normalize_sub_type(resolve_sub_type(item)) in normalized_model_types ] model_types_duration = time.perf_counter() - t0 diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 40015516..8ec53b3d 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -275,9 +275,16 @@ class ModelScanner: _, license_flags = resolve_license_info(license_source or {}) entry['license_flags'] = license_flags + # Handle 
sub_type (new canonical field) and model_type (backward compatibility) + sub_type = get_value('sub_type', None) model_type = get_value('model_type', None) - if model_type: - entry['model_type'] = model_type + + # Prefer sub_type, fallback to model_type for backward compatibility + effective_sub_type = sub_type or model_type + if effective_sub_type: + entry['sub_type'] = effective_sub_type + # Also keep model_type for backward compatibility during transition + entry['model_type'] = effective_sub_type return entry diff --git a/py/utils/constants.py b/py/utils/constants.py index 7646c962..e70ffe6b 100644 --- a/py/utils/constants.py +++ b/py/utils/constants.py @@ -45,8 +45,13 @@ SUPPORTED_MEDIA_EXTENSIONS = { "videos": [".mp4", ".webm"], } -# Valid Lora types -VALID_LORA_TYPES = ["lora", "locon", "dora"] +# Valid sub-types for each scanner type +VALID_LORA_SUB_TYPES = ["lora", "locon", "dora"] +VALID_CHECKPOINT_SUB_TYPES = ["checkpoint", "diffusion_model"] +VALID_EMBEDDING_SUB_TYPES = ["embedding"] + +# Backward compatibility alias +VALID_LORA_TYPES = VALID_LORA_SUB_TYPES # Supported Civitai model types for user model queries (case-insensitive) CIVITAI_USER_MODEL_TYPES = [ diff --git a/py/utils/models.py b/py/utils/models.py index 75acb840..c55bd29e 100644 --- a/py/utils/models.py +++ b/py/utils/models.py @@ -173,14 +173,14 @@ class LoraMetadata(BaseModelMetadata): @dataclass class CheckpointMetadata(BaseModelMetadata): """Represents the metadata structure for a Checkpoint model""" - model_type: str = "checkpoint" # Model type (checkpoint, diffusion_model, etc.) + sub_type: str = "checkpoint" # Model sub-type (checkpoint, diffusion_model, etc.) 
@classmethod def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata': """Create CheckpointMetadata instance from Civitai version info""" file_name = file_info['name'] base_model = determine_base_model(version_info.get('baseModel', '')) - model_type = version_info.get('type', 'checkpoint') + sub_type = version_info.get('type', 'checkpoint') # Extract tags and description if available tags = [] @@ -203,7 +203,7 @@ class CheckpointMetadata(BaseModelMetadata): preview_nsfw_level=0, from_civitai=True, civitai=version_info, - model_type=model_type, + sub_type=sub_type, tags=tags, modelDescription=description ) @@ -211,14 +211,14 @@ class CheckpointMetadata(BaseModelMetadata): @dataclass class EmbeddingMetadata(BaseModelMetadata): """Represents the metadata structure for an Embedding model""" - model_type: str = "embedding" # Model type (embedding, textual_inversion, etc.) + sub_type: str = "embedding" @classmethod def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'EmbeddingMetadata': """Create EmbeddingMetadata instance from Civitai version info""" file_name = file_info['name'] base_model = determine_base_model(version_info.get('baseModel', '')) - model_type = version_info.get('type', 'embedding') + sub_type = version_info.get('type', 'embedding') # Extract tags and description if available tags = [] @@ -241,7 +241,7 @@ class EmbeddingMetadata(BaseModelMetadata): preview_nsfw_level=0, from_civitai=True, civitai=version_info, - model_type=model_type, + sub_type=sub_type, tags=tags, modelDescription=description ) diff --git a/static/js/utils/constants.js b/static/js/utils/constants.js index 9d36e91f..4a5a2081 100644 --- a/static/js/utils/constants.js +++ b/static/js/utils/constants.js @@ -57,12 +57,22 @@ export const BASE_MODELS = { UNKNOWN: "Other" }; -export const MODEL_TYPE_DISPLAY_NAMES = { +// Model sub-type display names (new canonical field: sub_type) +export const 
MODEL_SUBTYPE_DISPLAY_NAMES = { + // LoRA sub-types lora: "LoRA", locon: "LyCORIS", dora: "DoRA", + // Checkpoint sub-types + checkpoint: "Checkpoint", + diffusion_model: "Diffusion Model", + // Embedding sub-types + embedding: "Embedding", }; +// Backward compatibility alias +export const MODEL_TYPE_DISPLAY_NAMES = MODEL_SUBTYPE_DISPLAY_NAMES; + export const BASE_MODEL_ABBREVIATIONS = { // Stable Diffusion 1.x models [BASE_MODELS.SD_1_4]: 'SD1', diff --git a/tests/services/test_checkpoint_scanner_sub_type.py b/tests/services/test_checkpoint_scanner_sub_type.py new file mode 100644 index 00000000..443b8f3f --- /dev/null +++ b/tests/services/test_checkpoint_scanner_sub_type.py @@ -0,0 +1,145 @@ +"""Tests for CheckpointScanner sub_type resolution.""" + +import pytest +import asyncio +from unittest.mock import MagicMock, patch + +from py.services.checkpoint_scanner import CheckpointScanner +from py.utils.models import CheckpointMetadata + + +class TestCheckpointScannerSubType: + """Test CheckpointScanner sub_type resolution logic.""" + + def create_scanner(self): + """Create scanner with no async initialization.""" + # Create scanner without calling __init__ to avoid async issues + scanner = object.__new__(CheckpointScanner) + scanner.model_type = "checkpoint" + scanner.model_class = CheckpointMetadata + scanner.file_extensions = {'.ckpt', '.safetensors'} + scanner._hash_index = MagicMock() + return scanner + + def test_resolve_sub_type_checkpoint_root(self): + """_resolve_sub_type should return 'checkpoint' for checkpoints_roots.""" + scanner = self.create_scanner() + + from py import config as config_module + original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None) + original_unet_roots = getattr(config_module.config, 'unet_roots', None) + + try: + config_module.config.checkpoints_roots = ["/models/checkpoints"] + config_module.config.unet_roots = ["/models/unet"] + + result = scanner._resolve_sub_type("/models/checkpoints") + assert 
result == "checkpoint" + finally: + if original_checkpoints_roots is not None: + config_module.config.checkpoints_roots = original_checkpoints_roots + if original_unet_roots is not None: + config_module.config.unet_roots = original_unet_roots + + def test_resolve_sub_type_unet_root(self): + """_resolve_sub_type should return 'diffusion_model' for unet_roots.""" + scanner = self.create_scanner() + + from py import config as config_module + original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None) + original_unet_roots = getattr(config_module.config, 'unet_roots', None) + + try: + config_module.config.checkpoints_roots = ["/models/checkpoints"] + config_module.config.unet_roots = ["/models/unet"] + + result = scanner._resolve_sub_type("/models/unet") + assert result == "diffusion_model" + finally: + if original_checkpoints_roots is not None: + config_module.config.checkpoints_roots = original_checkpoints_roots + if original_unet_roots is not None: + config_module.config.unet_roots = original_unet_roots + + def test_resolve_sub_type_none_root(self): + """_resolve_sub_type should return None for None input.""" + scanner = self.create_scanner() + result = scanner._resolve_sub_type(None) + assert result is None + + def test_resolve_sub_type_unknown_root(self): + """_resolve_sub_type should return None for unknown root.""" + scanner = self.create_scanner() + + from py import config as config_module + original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None) + original_unet_roots = getattr(config_module.config, 'unet_roots', None) + + try: + config_module.config.checkpoints_roots = ["/models/checkpoints"] + config_module.config.unet_roots = ["/models/unet"] + + result = scanner._resolve_sub_type("/models/unknown") + assert result is None + finally: + if original_checkpoints_roots is not None: + config_module.config.checkpoints_roots = original_checkpoints_roots + if original_unet_roots is not None: + 
config_module.config.unet_roots = original_unet_roots + + def test_adjust_metadata_sets_sub_type(self): + """adjust_metadata should set sub_type on metadata.""" + scanner = self.create_scanner() + + metadata = CheckpointMetadata( + file_name="test", + model_name="Test", + file_path="/models/checkpoints/model.safetensors", + size=1000, + modified=1234567890.0, + sha256="abc123", + base_model="SDXL", + preview_url="", + ) + + from py import config as config_module + original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None) + + try: + config_module.config.checkpoints_roots = ["/models/checkpoints"] + config_module.config.unet_roots = [] + + result = scanner.adjust_metadata(metadata, "/models/checkpoints/model.safetensors", "/models/checkpoints") + assert result.sub_type == "checkpoint" + finally: + if original_checkpoints_roots is not None: + config_module.config.checkpoints_roots = original_checkpoints_roots + + def test_adjust_cached_entry_sets_sub_type(self): + """adjust_cached_entry should set sub_type on entry.""" + scanner = self.create_scanner() + # Mock get_model_roots to return the expected roots + scanner.get_model_roots = lambda: ["/models/unet"] + + entry = { + "file_path": "/models/unet/model.safetensors", + "model_name": "Test", + } + + from py import config as config_module + original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None) + original_unet_roots = getattr(config_module.config, 'unet_roots', None) + + try: + config_module.config.checkpoints_roots = [] + config_module.config.unet_roots = ["/models/unet"] + + result = scanner.adjust_cached_entry(entry) + assert result["sub_type"] == "diffusion_model" + # Also sets model_type for backward compatibility + assert result["model_type"] == "diffusion_model" + finally: + if original_checkpoints_roots is not None: + config_module.config.checkpoints_roots = original_checkpoints_roots + if original_unet_roots is not None: + 
config_module.config.unet_roots = original_unet_roots diff --git a/tests/services/test_model_query_sub_type.py b/tests/services/test_model_query_sub_type.py new file mode 100644 index 00000000..69df6889 --- /dev/null +++ b/tests/services/test_model_query_sub_type.py @@ -0,0 +1,203 @@ +"""Tests for model_query sub_type resolution.""" + +import pytest +from py.services.model_query import ( + _coerce_to_str, + normalize_sub_type, + normalize_civitai_model_type, + resolve_sub_type, + resolve_civitai_model_type, + FilterCriteria, + ModelFilterSet, +) + + +class TestCoerceToStr: + """Test _coerce_to_str helper.""" + + def test_none_returns_none(self): + assert _coerce_to_str(None) is None + + def test_string_returns_stripped(self): + assert _coerce_to_str(" test ") == "test" + + def test_empty_returns_none(self): + assert _coerce_to_str(" ") is None + + def test_number_converts_to_str(self): + assert _coerce_to_str(123) == "123" + + +class TestNormalizeSubType: + """Test normalize_sub_type function.""" + + def test_normalizes_to_lowercase(self): + assert normalize_sub_type("LoRA") == "lora" + assert normalize_sub_type("CHECKPOINT") == "checkpoint" + + def test_strips_whitespace(self): + assert normalize_sub_type(" LoRA ") == "lora" + + def test_none_returns_none(self): + assert normalize_sub_type(None) is None + + def test_empty_returns_none(self): + assert normalize_sub_type("") is None + + +class TestNormalizeCivitaiModelTypeAlias: + """Test normalize_civitai_model_type is alias for normalize_sub_type.""" + + def test_alias_works_correctly(self): + assert normalize_civitai_model_type("LoRA") == "lora" + assert normalize_civitai_model_type("CHECKPOINT") == "checkpoint" + + +class TestResolveSubType: + """Test resolve_sub_type function priority.""" + + def test_priority_1_sub_type_field(self): + """Priority 1: entry['sub_type'] should be used first.""" + entry = { + "sub_type": "locon", + "model_type": "checkpoint", # Should be ignored + "civitai": {"model": {"type": 
"dora"}}, # Should be ignored + } + assert resolve_sub_type(entry) == "locon" + + def test_priority_2_model_type_field(self): + """Priority 2: entry['model_type'] as fallback.""" + entry = { + "model_type": "checkpoint", + "civitai": {"model": {"type": "dora"}}, # Should be ignored + } + assert resolve_sub_type(entry) == "checkpoint" + + def test_priority_3_civitai_model_type(self): + """Priority 3: civitai.model.type as fallback.""" + entry = { + "civitai": {"model": {"type": "dora"}}, + } + assert resolve_sub_type(entry) == "dora" + + def test_priority_4_default(self): + """Priority 4: default to LORA when nothing found.""" + entry = {} + assert resolve_sub_type(entry) == "LORA" + + def test_empty_sub_type_falls_back(self): + """Empty sub_type should fall back to model_type.""" + entry = { + "sub_type": "", + "model_type": "checkpoint", + } + assert resolve_sub_type(entry) == "checkpoint" + + def test_whitespace_sub_type_falls_back(self): + """Whitespace sub_type should fall back to model_type.""" + entry = { + "sub_type": " ", + "model_type": "checkpoint", + } + assert resolve_sub_type(entry) == "checkpoint" + + def test_none_entry_returns_default(self): + """None entry should return default.""" + assert resolve_sub_type(None) == "LORA" + + def test_non_mapping_returns_default(self): + """Non-mapping entry should return default.""" + assert resolve_sub_type("invalid") == "LORA" + + +class TestResolveCivitaiModelTypeAlias: + """Test resolve_civitai_model_type is alias for resolve_sub_type.""" + + def test_alias_works_correctly(self): + entry = {"sub_type": "locon"} + assert resolve_civitai_model_type(entry) == "locon" + + +class TestModelFilterSetWithSubType: + """Test ModelFilterSet applies model_types filtering correctly.""" + + def create_mock_settings(self): + class MockSettings: + def get(self, key, default=None): + return default + return MockSettings() + + def test_filter_by_sub_type(self): + """Filter should work with sub_type field.""" + settings = 
self.create_mock_settings() + filter_set = ModelFilterSet(settings) + + data = [ + {"sub_type": "lora", "model_name": "Model 1"}, + {"sub_type": "locon", "model_name": "Model 2"}, + {"sub_type": "dora", "model_name": "Model 3"}, + ] + + criteria = FilterCriteria(model_types=["lora", "locon"]) + result = filter_set.apply(data, criteria) + + assert len(result) == 2 + assert result[0]["model_name"] == "Model 1" + assert result[1]["model_name"] == "Model 2" + + def test_filter_falls_back_to_model_type(self): + """Filter should fall back to model_type field.""" + settings = self.create_mock_settings() + filter_set = ModelFilterSet(settings) + + data = [ + {"model_type": "lora", "model_name": "Model 1"}, # Old field + {"sub_type": "locon", "model_name": "Model 2"}, # New field + ] + + criteria = FilterCriteria(model_types=["lora", "locon"]) + result = filter_set.apply(data, criteria) + + assert len(result) == 2 + + def test_filter_uses_civitai_type(self): + """Filter should use civitai.model.type as last resort.""" + settings = self.create_mock_settings() + filter_set = ModelFilterSet(settings) + + data = [ + {"civitai": {"model": {"type": "dora"}}, "model_name": "Model 1"}, + ] + + criteria = FilterCriteria(model_types=["dora"]) + result = filter_set.apply(data, criteria) + + assert len(result) == 1 + + def test_filter_case_insensitive(self): + """Filter should be case insensitive.""" + settings = self.create_mock_settings() + filter_set = ModelFilterSet(settings) + + data = [ + {"sub_type": "LoRA", "model_name": "Model 1"}, + ] + + criteria = FilterCriteria(model_types=["lora"]) + result = filter_set.apply(data, criteria) + + assert len(result) == 1 + + def test_filter_no_match_returns_empty(self): + """Filter with no match should return empty list.""" + settings = self.create_mock_settings() + filter_set = ModelFilterSet(settings) + + data = [ + {"sub_type": "lora", "model_name": "Model 1"}, + ] + + criteria = FilterCriteria(model_types=["checkpoint"]) + result = 
filter_set.apply(data, criteria) + + assert len(result) == 0 diff --git a/tests/services/test_service_format_response_sub_type.py b/tests/services/test_service_format_response_sub_type.py new file mode 100644 index 00000000..2929347b --- /dev/null +++ b/tests/services/test_service_format_response_sub_type.py @@ -0,0 +1,230 @@ +"""Tests for service format_response sub_type inclusion.""" + +import pytest +import asyncio +from unittest.mock import MagicMock, AsyncMock + +from py.services.lora_service import LoraService +from py.services.checkpoint_service import CheckpointService +from py.services.embedding_service import EmbeddingService + + +class TestLoraServiceFormatResponse: + """Test LoraService.format_response includes sub_type.""" + + @pytest.fixture + def mock_scanner(self): + scanner = MagicMock() + scanner._hash_index = MagicMock() + return scanner + + @pytest.fixture + def lora_service(self, mock_scanner): + return LoraService(mock_scanner) + + @pytest.mark.asyncio + async def test_format_response_includes_sub_type(self, lora_service): + """format_response should include sub_type field.""" + lora_data = { + "model_name": "Test LoRA", + "file_name": "test_lora", + "preview_url": "test.webp", + "preview_nsfw_level": 0, + "base_model": "SDXL", + "folder": "", + "sha256": "abc123", + "file_path": "/models/test_lora.safetensors", + "size": 1000, + "modified": 1234567890.0, + "tags": [], + "from_civitai": True, + "usage_count": 0, + "usage_tips": "", + "notes": "", + "favorite": False, + "sub_type": "locon", # New field + "civitai": {}, + } + + result = await lora_service.format_response(lora_data) + + assert "sub_type" in result + assert result["sub_type"] == "locon" + + @pytest.mark.asyncio + async def test_format_response_falls_back_to_model_type(self, lora_service): + """format_response should fall back to model_type if sub_type missing.""" + lora_data = { + "model_name": "Test LoRA", + "file_name": "test_lora", + "preview_url": "test.webp", + 
"preview_nsfw_level": 0, + "base_model": "SDXL", + "folder": "", + "sha256": "abc123", + "file_path": "/models/test_lora.safetensors", + "size": 1000, + "modified": 1234567890.0, + "tags": [], + "from_civitai": True, + "model_type": "dora", # Old field + "civitai": {}, + } + + result = await lora_service.format_response(lora_data) + + assert result["sub_type"] == "dora" + assert result["model_type"] == "dora" # Both should be set + + @pytest.mark.asyncio + async def test_format_response_defaults_to_lora(self, lora_service): + """format_response should default to 'lora' if no type field.""" + lora_data = { + "model_name": "Test LoRA", + "file_name": "test_lora", + "preview_url": "test.webp", + "preview_nsfw_level": 0, + "base_model": "SDXL", + "folder": "", + "sha256": "abc123", + "file_path": "/models/test_lora.safetensors", + "size": 1000, + "modified": 1234567890.0, + "tags": [], + "from_civitai": True, + "civitai": {}, + } + + result = await lora_service.format_response(lora_data) + + assert result["sub_type"] == "lora" + assert result["model_type"] == "lora" + + +class TestCheckpointServiceFormatResponse: + """Test CheckpointService.format_response includes sub_type.""" + + @pytest.fixture + def mock_scanner(self): + scanner = MagicMock() + scanner._hash_index = MagicMock() + return scanner + + @pytest.fixture + def checkpoint_service(self, mock_scanner): + return CheckpointService(mock_scanner) + + @pytest.mark.asyncio + async def test_format_response_includes_sub_type_checkpoint(self, checkpoint_service): + """format_response should include sub_type field for checkpoint.""" + checkpoint_data = { + "model_name": "Test Checkpoint", + "file_name": "test_ckpt", + "preview_url": "test.webp", + "preview_nsfw_level": 0, + "base_model": "SDXL", + "folder": "", + "sha256": "abc123", + "file_path": "/models/test.safetensors", + "size": 1000, + "modified": 1234567890.0, + "tags": [], + "from_civitai": True, + "sub_type": "checkpoint", + "civitai": {}, + } + + result = 
await checkpoint_service.format_response(checkpoint_data) + + assert "sub_type" in result + assert result["sub_type"] == "checkpoint" + assert result["model_type"] == "checkpoint" + + @pytest.mark.asyncio + async def test_format_response_includes_sub_type_diffusion_model(self, checkpoint_service): + """format_response should include sub_type field for diffusion_model.""" + checkpoint_data = { + "model_name": "Test Diffusion Model", + "file_name": "test_unet", + "preview_url": "test.webp", + "preview_nsfw_level": 0, + "base_model": "SDXL", + "folder": "", + "sha256": "abc123", + "file_path": "/models/test.safetensors", + "size": 1000, + "modified": 1234567890.0, + "tags": [], + "from_civitai": True, + "sub_type": "diffusion_model", + "civitai": {}, + } + + result = await checkpoint_service.format_response(checkpoint_data) + + assert result["sub_type"] == "diffusion_model" + assert result["model_type"] == "diffusion_model" + + +class TestEmbeddingServiceFormatResponse: + """Test EmbeddingService.format_response includes sub_type.""" + + @pytest.fixture + def mock_scanner(self): + scanner = MagicMock() + scanner._hash_index = MagicMock() + return scanner + + @pytest.fixture + def embedding_service(self, mock_scanner): + return EmbeddingService(mock_scanner) + + @pytest.mark.asyncio + async def test_format_response_includes_sub_type(self, embedding_service): + """format_response should include sub_type field.""" + embedding_data = { + "model_name": "Test Embedding", + "file_name": "test_emb", + "preview_url": "test.webp", + "preview_nsfw_level": 0, + "base_model": "SD1.5", + "folder": "", + "sha256": "abc123", + "file_path": "/models/test.pt", + "size": 1000, + "modified": 1234567890.0, + "tags": [], + "from_civitai": True, + "sub_type": "embedding", + "civitai": {}, + } + + result = await embedding_service.format_response(embedding_data) + + assert "sub_type" in result + assert result["sub_type"] == "embedding" + assert result["model_type"] == "embedding" + + 
@pytest.mark.asyncio + async def test_format_response_defaults_to_embedding(self, embedding_service): + """format_response should default to 'embedding' if no type field.""" + embedding_data = { + "model_name": "Test Embedding", + "file_name": "test_emb", + "preview_url": "test.webp", + "preview_nsfw_level": 0, + "base_model": "SD1.5", + "folder": "", + "sha256": "abc123", + "file_path": "/models/test.pt", + "size": 1000, + "modified": 1234567890.0, + "tags": [], + "from_civitai": True, + "civitai": {}, + } + + result = await embedding_service.format_response(embedding_data) + + assert result["sub_type"] == "embedding" + assert result["model_type"] == "embedding" diff --git a/tests/utils/test_models_sub_type.py b/tests/utils/test_models_sub_type.py new file mode 100644 index 00000000..6eb79676 --- /dev/null +++ b/tests/utils/test_models_sub_type.py @@ -0,0 +1,127 @@ +"""Tests for model sub_type field refactoring.""" + +import pytest +from py.utils.models import ( + BaseModelMetadata, + LoraMetadata, + CheckpointMetadata, + EmbeddingMetadata, +) + + +class TestCheckpointMetadataSubType: + """Test CheckpointMetadata uses sub_type field.""" + + def test_checkpoint_has_sub_type_field(self): + """CheckpointMetadata should have sub_type field.""" + metadata = CheckpointMetadata( + file_name="test", + model_name="Test Model", + file_path="/test/model.safetensors", + size=1000, + modified=1234567890.0, + sha256="abc123", + base_model="SDXL", + preview_url="", + ) + assert hasattr(metadata, "sub_type") + assert metadata.sub_type == "checkpoint" + + def test_checkpoint_sub_type_can_be_diffusion_model(self): + """CheckpointMetadata sub_type can be set to diffusion_model.""" + metadata = CheckpointMetadata( + file_name="test", + model_name="Test Model", + file_path="/test/model.safetensors", + size=1000, + modified=1234567890.0, + sha256="abc123", + base_model="SDXL", + preview_url="", + sub_type="diffusion_model", + ) + assert metadata.sub_type == "diffusion_model" + + def 
test_checkpoint_from_civitai_info_uses_sub_type(self): + """from_civitai_info should use sub_type from version_info.""" + version_info = { + "baseModel": "SDXL", + "model": {"name": "Test", "description": "", "tags": []}, + "files": [{"name": "model.safetensors", "sizeKB": 1000, "hashes": {"SHA256": "abc123"}, "primary": True}], + } + file_info = version_info["files"][0] + save_path = "/test/model.safetensors" + + metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path) + + assert hasattr(metadata, "sub_type") + # When type is missing from version_info, defaults to "checkpoint" + assert metadata.sub_type == "checkpoint" + + +class TestEmbeddingMetadataSubType: + """Test EmbeddingMetadata uses sub_type field.""" + + def test_embedding_has_sub_type_field(self): + """EmbeddingMetadata should have sub_type field.""" + metadata = EmbeddingMetadata( + file_name="test", + model_name="Test Model", + file_path="/test/model.pt", + size=1000, + modified=1234567890.0, + sha256="abc123", + base_model="SD1.5", + preview_url="", + ) + assert hasattr(metadata, "sub_type") + assert metadata.sub_type == "embedding" + + def test_embedding_from_civitai_info_uses_sub_type(self): + """from_civitai_info should use sub_type from version_info.""" + version_info = { + "baseModel": "SD1.5", + "model": {"name": "Test", "description": "", "tags": []}, + "files": [{"name": "model.pt", "sizeKB": 1000, "hashes": {"SHA256": "abc123"}, "primary": True}], + } + file_info = version_info["files"][0] + save_path = "/test/model.pt" + + metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path) + + assert hasattr(metadata, "sub_type") + assert metadata.sub_type == "embedding" + + +class TestLoraMetadataConsistency: + """Test LoraMetadata consistency (no sub_type field, uses civitai data).""" + + def test_lora_does_not_have_sub_type_field(self): + """LoraMetadata should not have sub_type field (uses civitai.model.type).""" + metadata = LoraMetadata( + 
file_name="test", + model_name="Test Model", + file_path="/test/model.safetensors", + size=1000, + modified=1234567890.0, + sha256="abc123", + base_model="SDXL", + preview_url="", + ) + # Lora doesn't have sub_type field - it uses civitai data + assert not hasattr(metadata, "sub_type") + + def test_lora_from_civitai_info_extracts_type(self): + """from_civitai_info should extract type from civitai data.""" + version_info = { + "baseModel": "SDXL", + "model": {"name": "Test", "description": "", "tags": [], "type": "Lora"}, + "files": [{"name": "model.safetensors", "sizeKB": 1000, "hashes": {"SHA256": "abc123"}, "primary": True}], + } + file_info = version_info["files"][0] + save_path = "/test/model.safetensors" + + metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) + + # Type is stored in civitai dict + assert metadata.civitai.get("model", {}).get("type") == "Lora"