refactor(model-type): complete phase 5 cleanup by removing deprecated model_type field

- Remove backward compatibility code for `model_type` in `ModelScanner._build_cache_entry()`
- Update `CheckpointScanner` to only handle `sub_type` in `adjust_metadata()` and `adjust_cached_entry()`
- Delete deprecated aliases `resolve_civitai_model_type` and `normalize_civitai_model_type` from `model_query.py`
- Update frontend components (`RecipeModal.js`, `ModelCard.js`, etc.) to use `sub_type` instead of `model_type`
- Update API response format to return only `sub_type`, removing `model_type` from service responses
- Revise technical documentation to mark Phase 5 as completed and remove outdated TODO items

All cleanup tasks for the model type refactoring are now complete, ensuring consistent use of `sub_type` across the codebase.
This commit is contained in:
Will Miao
2026-01-30 07:48:31 +08:00
parent 5e91073476
commit 84c62f2954
17 changed files with 115 additions and 209 deletions

View File

@@ -25,64 +25,52 @@
### Phase 2: 查询逻辑更新
- [x] `model_query.py` 新增 `resolve_sub_type()``normalize_sub_type()`
- [x] 保持向后兼容的别名 `resolve_civitai_model_type`, `normalize_civitai_model_type`
- [x] ~~保持向后兼容的别名 `resolve_civitai_model_type`, `normalize_civitai_model_type`~~ (已在 Phase 5 移除)
- [x] `ModelFilterSet.apply()` 更新为使用新的解析函数
### Phase 3: API 响应更新
- [x] `LoraService.format_response()` 返回 `sub_type` + `model_type`
- [x] `CheckpointService.format_response()` 返回 `sub_type` + `model_type`
- [x] `EmbeddingService.format_response()` 返回 `sub_type` + `model_type`
- [x] `LoraService.format_response()` 返回 `sub_type` ~~+ `model_type`~~ (已移除 `model_type`)
- [x] `CheckpointService.format_response()` 返回 `sub_type` ~~+ `model_type`~~ (已移除 `model_type`)
- [x] `EmbeddingService.format_response()` 返回 `sub_type` ~~+ `model_type`~~ (已移除 `model_type`)
### Phase 4: 前端更新
- [x] `constants.js` 新增 `MODEL_SUBTYPE_DISPLAY_NAMES`
- [x] `MODEL_TYPE_DISPLAY_NAMES` 作为别名保留
### Phase 5: 清理废弃代码 ✅
- [x]`ModelScanner._build_cache_entry()` 中移除 `model_type` 向后兼容代码
- [x]`CheckpointScanner` 中移除 `model_type` 兼容处理
- [x]`model_query.py` 中移除 `resolve_civitai_model_type``normalize_civitai_model_type` 别名
- [x] 更新前端 `FilterManager.js` 使用 `sub_type` (已在使用 `MODEL_SUBTYPE_DISPLAY_NAMES`)
- [x] 更新所有相关测试
---
## 遗留工作 ⏳
### Phase 5: 清理废弃代码(建议在下个 major version 进行)
### Phase 5: 清理废弃代码 ✅ **已完成**
#### 5.1 移除 `model_type` 字段的向后兼容代码
所有 Phase 5 的清理工作已完成:
**优先级**: 低
**风险**: 高(需要确保前端和第三方集成不再依赖)
#### 5.1 移除 `model_type` 字段的向后兼容代码 ✅
-`ModelScanner._build_cache_entry()` 中移除了 `model_type` 的设置
- 现在只设置 `sub_type` 字段
```python
# TODO: 从 ModelScanner._build_cache_entry() 中移除
# 当前代码:
if effective_sub_type:
entry['sub_type'] = effective_sub_type
entry['model_type'] = effective_sub_type # 待移除
#### 5.2 移除 CheckpointScanner 的 model_type 兼容处理 ✅
- `adjust_metadata()` 现在只处理 `sub_type`
- `adjust_cached_entry()` 现在只设置 `sub_type`
# 目标代码:
if effective_sub_type:
entry['sub_type'] = effective_sub_type
```
#### 5.3 移除 model_query 中的向后兼容别名 ✅
- 移除了 `resolve_civitai_model_type = resolve_sub_type`
- 移除了 `normalize_civitai_model_type = normalize_sub_type`
#### 5.2 移除 CheckpointScanner 的 model_type 兼容处理
```python
# TODO: 从 checkpoint_scanner.py 中移除对 model_type 的兼容处理
# 当前 adjust_metadata 同时检查 'sub_type' 和 'model_type'
# 目标:只处理 'sub_type'
```
#### 5.3 移除 model_query 中的向后兼容别名
```python
# TODO: 确认所有调用方都使用新函数后,移除这些别名
resolve_civitai_model_type = resolve_sub_type # 待移除
normalize_civitai_model_type = normalize_sub_type # 待移除
```
#### 5.4 前端清理
```javascript
// TODO: 从前端移除对 model_type 的依赖
// FilterManager.js 中仍然使用 model_type 作为内部状态名
// 需要统一改为使用 sub_type
```
#### 5.4 前端清理 ✅
- `FilterManager.js` 已经在使用 `MODEL_SUBTYPE_DISPLAY_NAMES` (通过别名 `MODEL_TYPE_DISPLAY_NAMES`)
- API list endpoint 现在只返回 `sub_type`,不再返回 `model_type`
- `ModelCard.js` 现在设置 `card.dataset.sub_type` (所有模型类型通用)
- `CheckpointContextMenu.js` 现在读取 `card.dataset.sub_type`
- `MoveManager.js` 现在处理 `cache_entry.sub_type`
- `RecipeModal.js` 现在读取 `checkpoint.sub_type`
---
@@ -108,16 +96,19 @@ UPDATE models SET sub_type = civitai_model_type WHERE sub_type IS NULL;
## 测试覆盖率
### 新增测试文件(已全部通过 ✅)
### 新增/更新测试文件(已全部通过 ✅)
| 测试文件 | 数量 | 覆盖内容 |
|---------|------|---------|
| `tests/utils/test_models_sub_type.py` | 7 | Metadata sub_type 字段 |
| `tests/services/test_model_query_sub_type.py` | 23 | sub_type 解析和过滤 |
| `tests/services/test_model_query_sub_type.py` | 19 | sub_type 解析和过滤 |
| `tests/services/test_checkpoint_scanner_sub_type.py` | 6 | CheckpointScanner sub_type |
| `tests/services/test_service_format_response_sub_type.py` | 7 | API 响应 sub_type 包含 |
| `tests/services/test_service_format_response_sub_type.py` | 6 | API 响应 sub_type 包含 |
| `tests/services/test_checkpoint_scanner.py` | 1 | Checkpoint 缓存 sub_type |
| `tests/services/test_model_scanner.py` | 1 | adjust_cached_entry hook |
| `tests/services/test_download_manager.py` | 1 | Checkpoint 下载 sub_type |
### 需要补充的测试(TODO
### 需要补充的测试(可选
- [ ] 集成测试:验证前端过滤使用 sub_type 字段
- [ ] 数据库迁移测试(如果执行可选优化)
@@ -127,13 +118,13 @@ UPDATE models SET sub_type = civitai_model_type WHERE sub_type IS NULL;
## 兼容性检查清单
在移除向后兼容代码前,请确认:
### 已完成 ✅
- [ ] 前端代码已全部改用 `sub_type` 字段
- [ ] ComfyUI Widget 代码不再依赖 `model_type`
- [ ] 移动端/第三方客户端已更新
- [ ] 文档已更新,说明 `model_type` 已弃用
- [ ] 提供至少 1 个版本的弃用警告期
- [x] 前端代码已全部改用 `sub_type` 字段
- [x] API list endpoint 已移除 `model_type`,只返回 `sub_type`
- [x] 后端 cache entry 已移除 `model_type`,只保留 `sub_type`
- [x] 所有测试已更新通过
- [x] 文档已更新
---
@@ -156,6 +147,10 @@ py/services/embedding_service.py
```
static/js/utils/constants.js
static/js/managers/FilterManager.js
static/js/managers/MoveManager.js
static/js/components/shared/ModelCard.js
static/js/components/ContextMenu/CheckpointContextMenu.js
static/js/components/RecipeModal.js
```
### 测试文件
@@ -172,18 +167,20 @@ tests/services/test_service_format_response_sub_type.py
| 风险项 | 影响 | 缓解措施 |
|-------|------|---------|
| 第三方代码依赖 `model_type` | 高 | 保持别名至少 1 个 major 版本 |
| 数据库 schema 变更 | 中 | 暂缓 schema 变更,仅运行时计算 |
| 前端过滤失效 | 中 | 全面的集成测试覆盖 |
| ~~第三方代码依赖 `model_type`~~ | ~~高~~ | ~~保持别名至少 1 个 major 版本~~ ✅ 已完成移除 |
| ~~数据库 schema 变更~~ | ~~中~~ | ~~暂缓 schema 变更,仅运行时计算~~ ✅ 无需变更 |
| ~~前端过滤失效~~ | ~~中~~ | ~~全面的集成测试覆盖~~ ✅ 测试通过 |
| CivitAI API 变化 | 低 | 保持多源解析策略 |
---
## 时间线建议
## 时间线
- **v1.x (当前)**: Phase 1-4 已完成,保持向后兼容
- **v2.0**: 添加弃用警告,开始迁移文档
- **v3.0**: 移除 `model_type` 兼容代码Phase 5
- **v1.x**: Phase 1-4 已完成,保持向后兼容
- **v2.0 (当前)**: ✅ Phase 5 已完成 - `model_type` 兼容代码已移除
- API list endpoint 只返回 `sub_type`
- Cache entry 只保留 `sub_type`
- 移除了 `resolve_civitai_model_type``normalize_civitai_model_type` 别名
---

View File

@@ -36,16 +36,9 @@ class CheckpointScanner(ModelScanner):
def adjust_metadata(self, metadata, file_path, root_path):
"""Adjust metadata during scanning to set sub_type."""
# Support both old 'model_type' and new 'sub_type' for backward compatibility
if hasattr(metadata, "sub_type"):
sub_type = self._resolve_sub_type(root_path)
if sub_type:
metadata.sub_type = sub_type
elif hasattr(metadata, "model_type"):
# Backward compatibility: fallback to model_type if sub_type not available
sub_type = self._resolve_sub_type(root_path)
if sub_type:
metadata.model_type = sub_type
sub_type = self._resolve_sub_type(root_path)
if sub_type:
metadata.sub_type = sub_type
return metadata
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
@@ -55,8 +48,6 @@ class CheckpointScanner(ModelScanner):
)
if sub_type:
entry["sub_type"] = sub_type
# Also set model_type for backward compatibility during transition
entry["model_type"] = sub_type
return entry
def get_model_roots(self) -> List[str]:

View File

@@ -22,8 +22,8 @@ class CheckpointService(BaseModelService):
async def format_response(self, checkpoint_data: Dict) -> Dict:
"""Format Checkpoint data for API response"""
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
sub_type = checkpoint_data.get("sub_type") or checkpoint_data.get("model_type", "checkpoint")
# Get sub_type from cache entry (new canonical field)
sub_type = checkpoint_data.get("sub_type", "checkpoint")
return {
"model_name": checkpoint_data["model_name"],
@@ -40,8 +40,7 @@ class CheckpointService(BaseModelService):
"from_civitai": checkpoint_data.get("from_civitai", True),
"usage_count": checkpoint_data.get("usage_count", 0),
"notes": checkpoint_data.get("notes", ""),
"sub_type": sub_type, # New canonical field
"model_type": sub_type, # Backward compatibility
"sub_type": sub_type,
"favorite": checkpoint_data.get("favorite", False),
"update_available": bool(checkpoint_data.get("update_available", False)),
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)

View File

@@ -22,8 +22,8 @@ class EmbeddingService(BaseModelService):
async def format_response(self, embedding_data: Dict) -> Dict:
"""Format Embedding data for API response"""
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
sub_type = embedding_data.get("sub_type") or embedding_data.get("model_type", "embedding")
# Get sub_type from cache entry (new canonical field)
sub_type = embedding_data.get("sub_type", "embedding")
return {
"model_name": embedding_data["model_name"],
@@ -40,8 +40,7 @@ class EmbeddingService(BaseModelService):
"from_civitai": embedding_data.get("from_civitai", True),
# "usage_count": embedding_data.get("usage_count", 0), # TODO: Enable when embedding usage tracking is implemented
"notes": embedding_data.get("notes", ""),
"sub_type": sub_type, # New canonical field
"model_type": sub_type, # Backward compatibility
"sub_type": sub_type,
"favorite": embedding_data.get("favorite", False),
"update_available": bool(embedding_data.get("update_available", False)),
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)

View File

@@ -3,6 +3,7 @@ import logging
from typing import Dict, List, Optional
from .base_model_service import BaseModelService
from .model_query import resolve_sub_type
from ..utils.models import LoraMetadata
from ..config import config
@@ -23,8 +24,9 @@ class LoraService(BaseModelService):
async def format_response(self, lora_data: Dict) -> Dict:
"""Format LoRA data for API response"""
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
sub_type = lora_data.get("sub_type") or lora_data.get("model_type", "lora")
# Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default
# Normalize to lowercase for consistent API responses
sub_type = resolve_sub_type(lora_data).lower()
return {
"model_name": lora_data["model_name"],
@@ -46,8 +48,7 @@ class LoraService(BaseModelService):
"notes": lora_data.get("notes", ""),
"favorite": lora_data.get("favorite", False),
"update_available": bool(lora_data.get("update_available", False)),
"sub_type": sub_type, # New canonical field
"model_type": sub_type, # Backward compatibility
"sub_type": sub_type,
"civitai": self.filter_civitai_data(
lora_data.get("civitai", {}), minimal=True
),

View File

@@ -39,10 +39,6 @@ def normalize_sub_type(value: Any) -> Optional[str]:
return candidate.lower() if candidate else None
# Backward compatibility alias
normalize_civitai_model_type = normalize_sub_type
def resolve_sub_type(entry: Mapping[str, Any]) -> str:
"""Extract the sub-type from metadata, checking multiple sources.
@@ -77,10 +73,6 @@ def resolve_sub_type(entry: Mapping[str, Any]) -> str:
return DEFAULT_CIVITAI_MODEL_TYPE
# Backward compatibility alias
resolve_civitai_model_type = resolve_sub_type
class SettingsProvider(Protocol):
"""Protocol describing the SettingsManager contract used by query helpers."""

View File

@@ -275,16 +275,10 @@ class ModelScanner:
_, license_flags = resolve_license_info(license_source or {})
entry['license_flags'] = license_flags
# Handle sub_type (new canonical field) and model_type (backward compatibility)
# Handle sub_type (new canonical field)
sub_type = get_value('sub_type', None)
model_type = get_value('model_type', None)
# Prefer sub_type, fallback to model_type for backward compatibility
effective_sub_type = sub_type or model_type
if effective_sub_type:
entry['sub_type'] = effective_sub_type
# Also keep model_type for backward compatibility during transition
entry['model_type'] = effective_sub_type
if sub_type:
entry['sub_type'] = sub_type
return entry

View File

@@ -26,7 +26,7 @@ export class CheckpointContextMenu extends BaseContextMenu {
// Update the "Move to other root" label based on current model type
const moveOtherItem = this.menu.querySelector('[data-action="move-other"]');
if (moveOtherItem) {
const currentType = card.dataset.model_type || 'checkpoint';
const currentType = card.dataset.sub_type || 'checkpoint';
const otherType = currentType === 'checkpoint' ? 'diffusion_model' : 'checkpoint';
const typeLabel = i18n.t(`checkpoints.modelTypes.${otherType}`);
moveOtherItem.innerHTML = `<i class="fas fa-exchange-alt"></i> ${i18n.t('checkpoints.contextMenu.moveToOtherTypeFolder', { otherType: typeLabel })}`;
@@ -65,11 +65,11 @@ export class CheckpointContextMenu extends BaseContextMenu {
apiClient.refreshSingleModelMetadata(this.currentCard.dataset.filepath);
break;
case 'move':
moveManager.showMoveModal(this.currentCard.dataset.filepath, this.currentCard.dataset.model_type);
moveManager.showMoveModal(this.currentCard.dataset.filepath, this.currentCard.dataset.sub_type);
break;
case 'move-other':
{
const currentType = this.currentCard.dataset.model_type || 'checkpoint';
const currentType = this.currentCard.dataset.sub_type || 'checkpoint';
const otherType = currentType === 'checkpoint' ? 'diffusion_model' : 'checkpoint';
moveManager.showMoveModal(this.currentCard.dataset.filepath, otherType);
}

View File

@@ -1075,7 +1075,7 @@ class RecipeModal {
const checkpointName = checkpoint.name || checkpoint.modelName || checkpoint.file_name || 'Checkpoint';
const versionLabel = checkpoint.version || checkpoint.modelVersionName || '';
const baseModel = checkpoint.baseModel || checkpoint.base_model || '';
const modelTypeRaw = (checkpoint.model_type || checkpoint.type || 'checkpoint').toLowerCase();
const modelTypeRaw = (checkpoint.sub_type || checkpoint.type || 'checkpoint').toLowerCase();
const modelTypeLabel = modelTypeRaw === 'diffusion_model' ? 'Diffusion Model' : 'Checkpoint';
const previewMedia = isPreviewVideo ? `
@@ -1172,7 +1172,7 @@ class RecipeModal {
return;
}
const modelType = (checkpoint.model_type || checkpoint.type || 'checkpoint').toLowerCase();
const modelType = (checkpoint.sub_type || checkpoint.type || 'checkpoint').toLowerCase();
const isDiffusionModel = modelType === 'diffusion_model' || modelType === 'unet';
const widgetName = isDiffusionModel ? 'unet_name' : 'ckpt_name';

View File

@@ -176,7 +176,7 @@ function handleSendToWorkflow(card, replaceMode, modelType) {
return;
}
const subtype = (card.dataset.model_type || 'checkpoint').toLowerCase();
const subtype = (card.dataset.sub_type || 'checkpoint').toLowerCase();
const isDiffusionModel = subtype === 'diffusion_model';
const widgetName = isDiffusionModel ? 'unet_name' : 'ckpt_name';
const actionTypeText = translate(
@@ -453,9 +453,9 @@ export function createModelCard(model, modelType) {
card.dataset.usage_tips = model.usage_tips;
}
// checkpoint specific data
if (modelType === MODEL_TYPES.CHECKPOINT) {
card.dataset.model_type = model.model_type; // checkpoint or diffusion_model
// Set sub_type for all model types (lora/locon/dora, checkpoint/diffusion_model, embedding)
if (model.sub_type) {
card.dataset.sub_type = model.sub_type;
}
// Store metadata if available

View File

@@ -340,9 +340,9 @@ class MoveManager {
folder: newRelativeFolder
};
// Only update model_type if it's present in the cache_entry
if (result.cache_entry && result.cache_entry.model_type) {
updateData.model_type = result.cache_entry.model_type;
// Only update sub_type if it's present in the cache_entry
if (result.cache_entry && result.cache_entry.sub_type) {
updateData.sub_type = result.cache_entry.sub_type;
}
state.virtualScroller.updateSingleItem(result.original_file_path, updateData);
@@ -374,9 +374,9 @@ class MoveManager {
folder: newRelativeFolder
};
// Only update model_type if it's present in the cache_entry
if (result.cache_entry && result.cache_entry.model_type) {
updateData.model_type = result.cache_entry.model_type;
// Only update sub_type if it's present in the cache_entry
if (result.cache_entry && result.cache_entry.sub_type) {
updateData.sub_type = result.cache_entry.sub_type;
}
state.virtualScroller.updateSingleItem(this.currentFilePath, updateData);

View File

@@ -126,7 +126,7 @@ async def test_persisted_cache_restores_model_type(tmp_path: Path, monkeypatch):
assert loaded is True
cache = await scanner.get_cached_data()
types_by_path = {item["file_path"]: item.get("model_type") for item in cache.raw_data}
types_by_path = {item["file_path"]: item.get("sub_type") for item in cache.raw_data}
assert types_by_path[normalized_checkpoint_file] == "checkpoint"
assert types_by_path[normalized_unet_file] == "diffusion_model"

View File

@@ -136,8 +136,7 @@ class TestCheckpointScannerSubType:
result = scanner.adjust_cached_entry(entry)
assert result["sub_type"] == "diffusion_model"
# Also sets model_type for backward compatibility
assert result["model_type"] == "diffusion_model"
assert "model_type" not in result # Removed in refactoring
finally:
if original_checkpoints_roots is not None:
config_module.config.checkpoints_roots = original_checkpoints_roots

View File

@@ -479,7 +479,7 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
assert dummy_scanner.calls # ensure cache updated
async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_path):
async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path):
manager = DownloadManager()
root_dir = tmp_path / "checkpoints"
@@ -494,7 +494,7 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
self.file_name = path.stem
self.preview_url = None
self.preview_nsfw_level = 0
self.model_type = "checkpoint"
self.sub_type = "checkpoint"
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
@@ -505,7 +505,7 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
def to_dict(self):
return {
"file_path": self.file_path,
"model_type": self.model_type,
"sub_type": self.sub_type,
"sha256": self.sha256,
}
@@ -538,12 +538,12 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
self, metadata_obj, _file_path: str, root_path: Optional[str]
):
if root_path:
metadata_obj.model_type = "diffusion_model"
metadata_obj.sub_type = "diffusion_model"
return metadata_obj
def adjust_cached_entry(self, entry):
if entry.get("file_path", "").startswith(self.root):
entry["model_type"] = "diffusion_model"
entry["sub_type"] = "diffusion_model"
return entry
async def add_model_to_cache(self, metadata_dict, relative_path):
@@ -570,12 +570,12 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
)
assert result == {"success": True}
assert metadata.model_type == "diffusion_model"
assert metadata.sub_type == "diffusion_model"
saved_metadata = MetadataManager.save_metadata.await_args.args[1]
assert saved_metadata.model_type == "diffusion_model"
assert saved_metadata.sub_type == "diffusion_model"
assert dummy_scanner.add_calls
cached_entry, _ = dummy_scanner.add_calls[0]
assert cached_entry["model_type"] == "diffusion_model"
assert cached_entry["sub_type"] == "diffusion_model"
async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):

View File

@@ -4,9 +4,7 @@ import pytest
from py.services.model_query import (
_coerce_to_str,
normalize_sub_type,
normalize_civitai_model_type,
resolve_sub_type,
resolve_civitai_model_type,
FilterCriteria,
ModelFilterSet,
)
@@ -45,14 +43,6 @@ class TestNormalizeSubType:
assert normalize_sub_type("") is None
class TestNormalizeCivitaiModelTypeAlias:
"""Test normalize_civitai_model_type is alias for normalize_sub_type."""
def test_alias_works_correctly(self):
assert normalize_civitai_model_type("LoRA") == "lora"
assert normalize_civitai_model_type("CHECKPOINT") == "checkpoint"
class TestResolveSubType:
"""Test resolve_sub_type function priority."""
@@ -60,44 +50,35 @@ class TestResolveSubType:
"""Priority 1: entry['sub_type'] should be used first."""
entry = {
"sub_type": "locon",
"model_type": "checkpoint", # Should be ignored
"civitai": {"model": {"type": "dora"}}, # Should be ignored
}
assert resolve_sub_type(entry) == "locon"
def test_priority_2_model_type_field(self):
"""Priority 2: entry['model_type'] as fallback."""
entry = {
"model_type": "checkpoint",
"civitai": {"model": {"type": "dora"}}, # Should be ignored
}
assert resolve_sub_type(entry) == "checkpoint"
def test_priority_3_civitai_model_type(self):
"""Priority 3: civitai.model.type as fallback."""
def test_priority_2_civitai_model_type(self):
"""Priority 2: civitai.model.type as fallback."""
entry = {
"civitai": {"model": {"type": "dora"}},
}
assert resolve_sub_type(entry) == "dora"
def test_priority_4_default(self):
"""Priority 4: default to LORA when nothing found."""
def test_priority_3_default(self):
"""Priority 3: default to LORA when nothing found."""
entry = {}
assert resolve_sub_type(entry) == "LORA"
def test_empty_sub_type_falls_back(self):
"""Empty sub_type should fall back to model_type."""
"""Empty sub_type should fall back to civitai type."""
entry = {
"sub_type": "",
"model_type": "checkpoint",
"civitai": {"model": {"type": "checkpoint"}},
}
assert resolve_sub_type(entry) == "checkpoint"
def test_whitespace_sub_type_falls_back(self):
"""Whitespace sub_type should fall back to model_type."""
"""Whitespace sub_type should fall back to civitai type."""
entry = {
"sub_type": " ",
"model_type": "checkpoint",
"civitai": {"model": {"type": "checkpoint"}},
}
assert resolve_sub_type(entry) == "checkpoint"
@@ -110,14 +91,6 @@ class TestResolveSubType:
assert resolve_sub_type("invalid") == "LORA"
class TestResolveCivitaiModelTypeAlias:
"""Test resolve_civitai_model_type is alias for resolve_sub_type."""
def test_alias_works_correctly(self):
entry = {"sub_type": "locon"}
assert resolve_civitai_model_type(entry) == "locon"
class TestModelFilterSetWithSubType:
"""Test ModelFilterSet applies model_types filtering correctly."""
@@ -145,23 +118,8 @@ class TestModelFilterSetWithSubType:
assert result[0]["model_name"] == "Model 1"
assert result[1]["model_name"] == "Model 2"
def test_filter_falls_back_to_model_type(self):
"""Filter should fall back to model_type field."""
settings = self.create_mock_settings()
filter_set = ModelFilterSet(settings)
data = [
{"model_type": "lora", "model_name": "Model 1"}, # Old field
{"sub_type": "locon", "model_name": "Model 2"}, # New field
]
criteria = FilterCriteria(model_types=["lora", "locon"])
result = filter_set.apply(data, criteria)
assert len(result) == 2
def test_filter_uses_civitai_type(self):
"""Filter should use civitai.model.type as last resort."""
"""Filter should use civitai.model.type as fallback."""
settings = self.create_mock_settings()
filter_set = ModelFilterSet(settings)

View File

@@ -521,7 +521,7 @@ async def test_reconcile_cache_applies_adjust_cached_entry(tmp_path: Path):
def _adjust(self, entry: dict) -> dict:
applied.append(entry["file_path"])
entry["model_type"] = "adjusted"
entry["custom_field"] = "adjusted"
return entry
scanner.adjust_cached_entry = MethodType(_adjust, scanner)
@@ -538,7 +538,7 @@ async def test_reconcile_cache_applies_adjust_cached_entry(tmp_path: Path):
assert normalized_new in applied
new_entry = next(item for item in scanner._cache.raw_data if item["file_path"] == normalized_new)
assert new_entry["model_type"] == "adjusted"
assert new_entry["custom_field"] == "adjusted"
@pytest.mark.asyncio

View File

@@ -42,7 +42,7 @@ class TestLoraServiceFormatResponse:
"usage_tips": "",
"notes": "",
"favorite": False,
"sub_type": "locon", # New field
"sub_type": "locon",
"civitai": {},
}
@@ -50,31 +50,7 @@ class TestLoraServiceFormatResponse:
assert "sub_type" in result
assert result["sub_type"] == "locon"
@pytest.mark.asyncio
async def test_format_response_falls_back_to_model_type(self, lora_service):
"""format_response should fall back to model_type if sub_type missing."""
lora_data = {
"model_name": "Test LoRA",
"file_name": "test_lora",
"preview_url": "test.webp",
"preview_nsfw_level": 0,
"base_model": "SDXL",
"folder": "",
"sha256": "abc123",
"file_path": "/models/test_lora.safetensors",
"size": 1000,
"modified": 1234567890.0,
"tags": [],
"from_civitai": True,
"model_type": "dora", # Old field
"civitai": {},
}
result = await lora_service.format_response(lora_data)
assert result["sub_type"] == "dora"
assert result["model_type"] == "dora" # Both should be set
assert "model_type" not in result # Removed in refactoring
@pytest.mark.asyncio
async def test_format_response_defaults_to_lora(self, lora_service):
@@ -98,7 +74,7 @@ class TestLoraServiceFormatResponse:
result = await lora_service.format_response(lora_data)
assert result["sub_type"] == "lora"
assert result["model_type"] == "lora"
assert "model_type" not in result # Removed in refactoring
class TestCheckpointServiceFormatResponse:
@@ -138,7 +114,7 @@ class TestCheckpointServiceFormatResponse:
assert "sub_type" in result
assert result["sub_type"] == "checkpoint"
assert result["model_type"] == "checkpoint"
assert "model_type" not in result # Removed in refactoring
@pytest.mark.asyncio
async def test_format_response_includes_sub_type_diffusion_model(self, checkpoint_service):
@@ -163,7 +139,7 @@ class TestCheckpointServiceFormatResponse:
result = await checkpoint_service.format_response(checkpoint_data)
assert result["sub_type"] == "diffusion_model"
assert result["model_type"] == "diffusion_model"
assert "model_type" not in result # Removed in refactoring
class TestEmbeddingServiceFormatResponse:
@@ -203,7 +179,7 @@ class TestEmbeddingServiceFormatResponse:
assert "sub_type" in result
assert result["sub_type"] == "embedding"
assert result["model_type"] == "embedding"
assert "model_type" not in result # Removed in refactoring
@pytest.mark.asyncio
async def test_format_response_defaults_to_embedding(self, embedding_service):
@@ -227,4 +203,4 @@ class TestEmbeddingServiceFormatResponse:
result = await embedding_service.format_response(embedding_data)
assert result["sub_type"] == "embedding"
assert result["model_type"] == "embedding"
assert "model_type" not in result # Removed in refactoring