mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 13:12:12 -03:00
refactor(model-type): complete phase 5 cleanup by removing deprecated model_type field
- Remove backward compatibility code for `model_type` in `ModelScanner._build_cache_entry()` - Update `CheckpointScanner` to only handle `sub_type` in `adjust_metadata()` and `adjust_cached_entry()` - Delete deprecated aliases `resolve_civitai_model_type` and `normalize_civitai_model_type` from `model_query.py` - Update frontend components (`RecipeModal.js`, `ModelCard.js`, etc.) to use `sub_type` instead of `model_type` - Update API response format to return only `sub_type`, removing `model_type` from service responses - Revise technical documentation to mark Phase 5 as completed and remove outdated TODO items All cleanup tasks for the model type refactoring are now complete, ensuring consistent use of `sub_type` across the codebase.
This commit is contained in:
@@ -25,64 +25,52 @@
|
||||
|
||||
### Phase 2: 查询逻辑更新
|
||||
- [x] `model_query.py` 新增 `resolve_sub_type()` 和 `normalize_sub_type()`
|
||||
- [x] 保持向后兼容的别名 `resolve_civitai_model_type`, `normalize_civitai_model_type`
|
||||
- [x] ~~保持向后兼容的别名 `resolve_civitai_model_type`, `normalize_civitai_model_type`~~ (已在 Phase 5 移除)
|
||||
- [x] `ModelFilterSet.apply()` 更新为使用新的解析函数
|
||||
|
||||
### Phase 3: API 响应更新
|
||||
- [x] `LoraService.format_response()` 返回 `sub_type` + `model_type`
|
||||
- [x] `CheckpointService.format_response()` 返回 `sub_type` + `model_type`
|
||||
- [x] `EmbeddingService.format_response()` 返回 `sub_type` + `model_type`
|
||||
- [x] `LoraService.format_response()` 返回 `sub_type` ~~+ `model_type`~~ (已移除 `model_type`)
|
||||
- [x] `CheckpointService.format_response()` 返回 `sub_type` ~~+ `model_type`~~ (已移除 `model_type`)
|
||||
- [x] `EmbeddingService.format_response()` 返回 `sub_type` ~~+ `model_type`~~ (已移除 `model_type`)
|
||||
|
||||
### Phase 4: 前端更新
|
||||
- [x] `constants.js` 新增 `MODEL_SUBTYPE_DISPLAY_NAMES`
|
||||
- [x] `MODEL_TYPE_DISPLAY_NAMES` 作为别名保留
|
||||
|
||||
### Phase 5: 清理废弃代码 ✅
|
||||
- [x] 从 `ModelScanner._build_cache_entry()` 中移除 `model_type` 向后兼容代码
|
||||
- [x] 从 `CheckpointScanner` 中移除 `model_type` 兼容处理
|
||||
- [x] 从 `model_query.py` 中移除 `resolve_civitai_model_type` 和 `normalize_civitai_model_type` 别名
|
||||
- [x] 更新前端 `FilterManager.js` 使用 `sub_type` (已在使用 `MODEL_SUBTYPE_DISPLAY_NAMES`)
|
||||
- [x] 更新所有相关测试
|
||||
|
||||
---
|
||||
|
||||
## 遗留工作 ⏳
|
||||
|
||||
### Phase 5: 清理废弃代码(建议在下个 major version 进行)
|
||||
### Phase 5: 清理废弃代码 ✅ **已完成**
|
||||
|
||||
#### 5.1 移除 `model_type` 字段的向后兼容代码
|
||||
所有 Phase 5 的清理工作已完成:
|
||||
|
||||
**优先级**: 低
|
||||
**风险**: 高(需要确保前端和第三方集成不再依赖)
|
||||
#### 5.1 移除 `model_type` 字段的向后兼容代码 ✅
|
||||
- 从 `ModelScanner._build_cache_entry()` 中移除了 `model_type` 的设置
|
||||
- 现在只设置 `sub_type` 字段
|
||||
|
||||
```python
|
||||
# TODO: 从 ModelScanner._build_cache_entry() 中移除
|
||||
# 当前代码:
|
||||
if effective_sub_type:
|
||||
entry['sub_type'] = effective_sub_type
|
||||
entry['model_type'] = effective_sub_type # 待移除
|
||||
#### 5.2 移除 CheckpointScanner 的 model_type 兼容处理 ✅
|
||||
- `adjust_metadata()` 现在只处理 `sub_type`
|
||||
- `adjust_cached_entry()` 现在只设置 `sub_type`
|
||||
|
||||
# 目标代码:
|
||||
if effective_sub_type:
|
||||
entry['sub_type'] = effective_sub_type
|
||||
```
|
||||
#### 5.3 移除 model_query 中的向后兼容别名 ✅
|
||||
- 移除了 `resolve_civitai_model_type = resolve_sub_type`
|
||||
- 移除了 `normalize_civitai_model_type = normalize_sub_type`
|
||||
|
||||
#### 5.2 移除 CheckpointScanner 的 model_type 兼容处理
|
||||
|
||||
```python
|
||||
# TODO: 从 checkpoint_scanner.py 中移除对 model_type 的兼容处理
|
||||
# 当前 adjust_metadata 同时检查 'sub_type' 和 'model_type'
|
||||
# 目标:只处理 'sub_type'
|
||||
```
|
||||
|
||||
#### 5.3 移除 model_query 中的向后兼容别名
|
||||
|
||||
```python
|
||||
# TODO: 确认所有调用方都使用新函数后,移除这些别名
|
||||
resolve_civitai_model_type = resolve_sub_type # 待移除
|
||||
normalize_civitai_model_type = normalize_sub_type # 待移除
|
||||
```
|
||||
|
||||
#### 5.4 前端清理
|
||||
|
||||
```javascript
|
||||
// TODO: 从前端移除对 model_type 的依赖
|
||||
// FilterManager.js 中仍然使用 model_type 作为内部状态名
|
||||
// 需要统一改为使用 sub_type
|
||||
```
|
||||
#### 5.4 前端清理 ✅
|
||||
- `FilterManager.js` 已经在使用 `MODEL_SUBTYPE_DISPLAY_NAMES` (通过别名 `MODEL_TYPE_DISPLAY_NAMES`)
|
||||
- API list endpoint 现在只返回 `sub_type`,不再返回 `model_type`
|
||||
- `ModelCard.js` 现在设置 `card.dataset.sub_type` (所有模型类型通用)
|
||||
- `CheckpointContextMenu.js` 现在读取 `card.dataset.sub_type`
|
||||
- `MoveManager.js` 现在处理 `cache_entry.sub_type`
|
||||
- `RecipeModal.js` 现在读取 `checkpoint.sub_type`
|
||||
|
||||
---
|
||||
|
||||
@@ -108,16 +96,19 @@ UPDATE models SET sub_type = civitai_model_type WHERE sub_type IS NULL;
|
||||
|
||||
## 测试覆盖率
|
||||
|
||||
### 新增测试文件(已全部通过 ✅)
|
||||
### 新增/更新测试文件(已全部通过 ✅)
|
||||
|
||||
| 测试文件 | 数量 | 覆盖内容 |
|
||||
|---------|------|---------|
|
||||
| `tests/utils/test_models_sub_type.py` | 7 | Metadata sub_type 字段 |
|
||||
| `tests/services/test_model_query_sub_type.py` | 23 | sub_type 解析和过滤 |
|
||||
| `tests/services/test_model_query_sub_type.py` | 19 | sub_type 解析和过滤 |
|
||||
| `tests/services/test_checkpoint_scanner_sub_type.py` | 6 | CheckpointScanner sub_type |
|
||||
| `tests/services/test_service_format_response_sub_type.py` | 7 | API 响应 sub_type 包含 |
|
||||
| `tests/services/test_service_format_response_sub_type.py` | 6 | API 响应 sub_type 包含 |
|
||||
| `tests/services/test_checkpoint_scanner.py` | 1 | Checkpoint 缓存 sub_type |
|
||||
| `tests/services/test_model_scanner.py` | 1 | adjust_cached_entry hook |
|
||||
| `tests/services/test_download_manager.py` | 1 | Checkpoint 下载 sub_type |
|
||||
|
||||
### 需要补充的测试(TODO)
|
||||
### 需要补充的测试(可选)
|
||||
|
||||
- [ ] 集成测试:验证前端过滤使用 sub_type 字段
|
||||
- [ ] 数据库迁移测试(如果执行可选优化)
|
||||
@@ -127,13 +118,13 @@ UPDATE models SET sub_type = civitai_model_type WHERE sub_type IS NULL;
|
||||
|
||||
## 兼容性检查清单
|
||||
|
||||
在移除向后兼容代码前,请确认:
|
||||
### 已完成 ✅
|
||||
|
||||
- [ ] 前端代码已全部改用 `sub_type` 字段
|
||||
- [ ] ComfyUI Widget 代码不再依赖 `model_type`
|
||||
- [ ] 移动端/第三方客户端已更新
|
||||
- [ ] 文档已更新,说明 `model_type` 已弃用
|
||||
- [ ] 提供至少 1 个版本的弃用警告期
|
||||
- [x] 前端代码已全部改用 `sub_type` 字段
|
||||
- [x] API list endpoint 已移除 `model_type`,只返回 `sub_type`
|
||||
- [x] 后端 cache entry 已移除 `model_type`,只保留 `sub_type`
|
||||
- [x] 所有测试已更新通过
|
||||
- [x] 文档已更新
|
||||
|
||||
---
|
||||
|
||||
@@ -156,6 +147,10 @@ py/services/embedding_service.py
|
||||
```
|
||||
static/js/utils/constants.js
|
||||
static/js/managers/FilterManager.js
|
||||
static/js/managers/MoveManager.js
|
||||
static/js/components/shared/ModelCard.js
|
||||
static/js/components/ContextMenu/CheckpointContextMenu.js
|
||||
static/js/components/RecipeModal.js
|
||||
```
|
||||
|
||||
### 测试文件
|
||||
@@ -172,18 +167,20 @@ tests/services/test_service_format_response_sub_type.py
|
||||
|
||||
| 风险项 | 影响 | 缓解措施 |
|
||||
|-------|------|---------|
|
||||
| 第三方代码依赖 `model_type` | 高 | 保持别名至少 1 个 major 版本 |
|
||||
| 数据库 schema 变更 | 中 | 暂缓 schema 变更,仅运行时计算 |
|
||||
| 前端过滤失效 | 中 | 全面的集成测试覆盖 |
|
||||
| ~~第三方代码依赖 `model_type`~~ | ~~高~~ | ~~保持别名至少 1 个 major 版本~~ ✅ 已完成移除 |
|
||||
| ~~数据库 schema 变更~~ | ~~中~~ | ~~暂缓 schema 变更,仅运行时计算~~ ✅ 无需变更 |
|
||||
| ~~前端过滤失效~~ | ~~中~~ | ~~全面的集成测试覆盖~~ ✅ 测试通过 |
|
||||
| CivitAI API 变化 | 低 | 保持多源解析策略 |
|
||||
|
||||
---
|
||||
|
||||
## 时间线建议
|
||||
## 时间线
|
||||
|
||||
- **v1.x (当前)**: Phase 1-4 已完成,保持向后兼容
|
||||
- **v2.0**: 添加弃用警告,开始迁移文档
|
||||
- **v3.0**: 移除 `model_type` 兼容代码(Phase 5)
|
||||
- **v1.x**: Phase 1-4 已完成,保持向后兼容
|
||||
- **v2.0 (当前)**: ✅ Phase 5 已完成 - `model_type` 兼容代码已移除
|
||||
- API list endpoint 只返回 `sub_type`
|
||||
- Cache entry 只保留 `sub_type`
|
||||
- 移除了 `resolve_civitai_model_type` 和 `normalize_civitai_model_type` 别名
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -36,16 +36,9 @@ class CheckpointScanner(ModelScanner):
|
||||
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
"""Adjust metadata during scanning to set sub_type."""
|
||||
# Support both old 'model_type' and new 'sub_type' for backward compatibility
|
||||
if hasattr(metadata, "sub_type"):
|
||||
sub_type = self._resolve_sub_type(root_path)
|
||||
if sub_type:
|
||||
metadata.sub_type = sub_type
|
||||
elif hasattr(metadata, "model_type"):
|
||||
# Backward compatibility: fallback to model_type if sub_type not available
|
||||
sub_type = self._resolve_sub_type(root_path)
|
||||
if sub_type:
|
||||
metadata.model_type = sub_type
|
||||
sub_type = self._resolve_sub_type(root_path)
|
||||
if sub_type:
|
||||
metadata.sub_type = sub_type
|
||||
return metadata
|
||||
|
||||
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -55,8 +48,6 @@ class CheckpointScanner(ModelScanner):
|
||||
)
|
||||
if sub_type:
|
||||
entry["sub_type"] = sub_type
|
||||
# Also set model_type for backward compatibility during transition
|
||||
entry["model_type"] = sub_type
|
||||
return entry
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
|
||||
@@ -22,8 +22,8 @@ class CheckpointService(BaseModelService):
|
||||
|
||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||
"""Format Checkpoint data for API response"""
|
||||
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
|
||||
sub_type = checkpoint_data.get("sub_type") or checkpoint_data.get("model_type", "checkpoint")
|
||||
# Get sub_type from cache entry (new canonical field)
|
||||
sub_type = checkpoint_data.get("sub_type", "checkpoint")
|
||||
|
||||
return {
|
||||
"model_name": checkpoint_data["model_name"],
|
||||
@@ -40,8 +40,7 @@ class CheckpointService(BaseModelService):
|
||||
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||
"usage_count": checkpoint_data.get("usage_count", 0),
|
||||
"notes": checkpoint_data.get("notes", ""),
|
||||
"sub_type": sub_type, # New canonical field
|
||||
"model_type": sub_type, # Backward compatibility
|
||||
"sub_type": sub_type,
|
||||
"favorite": checkpoint_data.get("favorite", False),
|
||||
"update_available": bool(checkpoint_data.get("update_available", False)),
|
||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||
|
||||
@@ -22,8 +22,8 @@ class EmbeddingService(BaseModelService):
|
||||
|
||||
async def format_response(self, embedding_data: Dict) -> Dict:
|
||||
"""Format Embedding data for API response"""
|
||||
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
|
||||
sub_type = embedding_data.get("sub_type") or embedding_data.get("model_type", "embedding")
|
||||
# Get sub_type from cache entry (new canonical field)
|
||||
sub_type = embedding_data.get("sub_type", "embedding")
|
||||
|
||||
return {
|
||||
"model_name": embedding_data["model_name"],
|
||||
@@ -40,8 +40,7 @@ class EmbeddingService(BaseModelService):
|
||||
"from_civitai": embedding_data.get("from_civitai", True),
|
||||
# "usage_count": embedding_data.get("usage_count", 0), # TODO: Enable when embedding usage tracking is implemented
|
||||
"notes": embedding_data.get("notes", ""),
|
||||
"sub_type": sub_type, # New canonical field
|
||||
"model_type": sub_type, # Backward compatibility
|
||||
"sub_type": sub_type,
|
||||
"favorite": embedding_data.get("favorite", False),
|
||||
"update_available": bool(embedding_data.get("update_available", False)),
|
||||
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from .model_query import resolve_sub_type
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
|
||||
@@ -23,8 +24,9 @@ class LoraService(BaseModelService):
|
||||
|
||||
async def format_response(self, lora_data: Dict) -> Dict:
|
||||
"""Format LoRA data for API response"""
|
||||
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
|
||||
sub_type = lora_data.get("sub_type") or lora_data.get("model_type", "lora")
|
||||
# Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default
|
||||
# Normalize to lowercase for consistent API responses
|
||||
sub_type = resolve_sub_type(lora_data).lower()
|
||||
|
||||
return {
|
||||
"model_name": lora_data["model_name"],
|
||||
@@ -46,8 +48,7 @@ class LoraService(BaseModelService):
|
||||
"notes": lora_data.get("notes", ""),
|
||||
"favorite": lora_data.get("favorite", False),
|
||||
"update_available": bool(lora_data.get("update_available", False)),
|
||||
"sub_type": sub_type, # New canonical field
|
||||
"model_type": sub_type, # Backward compatibility
|
||||
"sub_type": sub_type,
|
||||
"civitai": self.filter_civitai_data(
|
||||
lora_data.get("civitai", {}), minimal=True
|
||||
),
|
||||
|
||||
@@ -39,10 +39,6 @@ def normalize_sub_type(value: Any) -> Optional[str]:
|
||||
return candidate.lower() if candidate else None
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
normalize_civitai_model_type = normalize_sub_type
|
||||
|
||||
|
||||
def resolve_sub_type(entry: Mapping[str, Any]) -> str:
|
||||
"""Extract the sub-type from metadata, checking multiple sources.
|
||||
|
||||
@@ -77,10 +73,6 @@ def resolve_sub_type(entry: Mapping[str, Any]) -> str:
|
||||
return DEFAULT_CIVITAI_MODEL_TYPE
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
resolve_civitai_model_type = resolve_sub_type
|
||||
|
||||
|
||||
class SettingsProvider(Protocol):
|
||||
"""Protocol describing the SettingsManager contract used by query helpers."""
|
||||
|
||||
|
||||
@@ -275,16 +275,10 @@ class ModelScanner:
|
||||
_, license_flags = resolve_license_info(license_source or {})
|
||||
entry['license_flags'] = license_flags
|
||||
|
||||
# Handle sub_type (new canonical field) and model_type (backward compatibility)
|
||||
# Handle sub_type (new canonical field)
|
||||
sub_type = get_value('sub_type', None)
|
||||
model_type = get_value('model_type', None)
|
||||
|
||||
# Prefer sub_type, fallback to model_type for backward compatibility
|
||||
effective_sub_type = sub_type or model_type
|
||||
if effective_sub_type:
|
||||
entry['sub_type'] = effective_sub_type
|
||||
# Also keep model_type for backward compatibility during transition
|
||||
entry['model_type'] = effective_sub_type
|
||||
if sub_type:
|
||||
entry['sub_type'] = sub_type
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ export class CheckpointContextMenu extends BaseContextMenu {
|
||||
// Update the "Move to other root" label based on current model type
|
||||
const moveOtherItem = this.menu.querySelector('[data-action="move-other"]');
|
||||
if (moveOtherItem) {
|
||||
const currentType = card.dataset.model_type || 'checkpoint';
|
||||
const currentType = card.dataset.sub_type || 'checkpoint';
|
||||
const otherType = currentType === 'checkpoint' ? 'diffusion_model' : 'checkpoint';
|
||||
const typeLabel = i18n.t(`checkpoints.modelTypes.${otherType}`);
|
||||
moveOtherItem.innerHTML = `<i class="fas fa-exchange-alt"></i> ${i18n.t('checkpoints.contextMenu.moveToOtherTypeFolder', { otherType: typeLabel })}`;
|
||||
@@ -65,11 +65,11 @@ export class CheckpointContextMenu extends BaseContextMenu {
|
||||
apiClient.refreshSingleModelMetadata(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'move':
|
||||
moveManager.showMoveModal(this.currentCard.dataset.filepath, this.currentCard.dataset.model_type);
|
||||
moveManager.showMoveModal(this.currentCard.dataset.filepath, this.currentCard.dataset.sub_type);
|
||||
break;
|
||||
case 'move-other':
|
||||
{
|
||||
const currentType = this.currentCard.dataset.model_type || 'checkpoint';
|
||||
const currentType = this.currentCard.dataset.sub_type || 'checkpoint';
|
||||
const otherType = currentType === 'checkpoint' ? 'diffusion_model' : 'checkpoint';
|
||||
moveManager.showMoveModal(this.currentCard.dataset.filepath, otherType);
|
||||
}
|
||||
|
||||
@@ -1075,7 +1075,7 @@ class RecipeModal {
|
||||
const checkpointName = checkpoint.name || checkpoint.modelName || checkpoint.file_name || 'Checkpoint';
|
||||
const versionLabel = checkpoint.version || checkpoint.modelVersionName || '';
|
||||
const baseModel = checkpoint.baseModel || checkpoint.base_model || '';
|
||||
const modelTypeRaw = (checkpoint.model_type || checkpoint.type || 'checkpoint').toLowerCase();
|
||||
const modelTypeRaw = (checkpoint.sub_type || checkpoint.type || 'checkpoint').toLowerCase();
|
||||
const modelTypeLabel = modelTypeRaw === 'diffusion_model' ? 'Diffusion Model' : 'Checkpoint';
|
||||
|
||||
const previewMedia = isPreviewVideo ? `
|
||||
@@ -1172,7 +1172,7 @@ class RecipeModal {
|
||||
return;
|
||||
}
|
||||
|
||||
const modelType = (checkpoint.model_type || checkpoint.type || 'checkpoint').toLowerCase();
|
||||
const modelType = (checkpoint.sub_type || checkpoint.type || 'checkpoint').toLowerCase();
|
||||
const isDiffusionModel = modelType === 'diffusion_model' || modelType === 'unet';
|
||||
const widgetName = isDiffusionModel ? 'unet_name' : 'ckpt_name';
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ function handleSendToWorkflow(card, replaceMode, modelType) {
|
||||
return;
|
||||
}
|
||||
|
||||
const subtype = (card.dataset.model_type || 'checkpoint').toLowerCase();
|
||||
const subtype = (card.dataset.sub_type || 'checkpoint').toLowerCase();
|
||||
const isDiffusionModel = subtype === 'diffusion_model';
|
||||
const widgetName = isDiffusionModel ? 'unet_name' : 'ckpt_name';
|
||||
const actionTypeText = translate(
|
||||
@@ -453,9 +453,9 @@ export function createModelCard(model, modelType) {
|
||||
card.dataset.usage_tips = model.usage_tips;
|
||||
}
|
||||
|
||||
// checkpoint specific data
|
||||
if (modelType === MODEL_TYPES.CHECKPOINT) {
|
||||
card.dataset.model_type = model.model_type; // checkpoint or diffusion_model
|
||||
// Set sub_type for all model types (lora/locon/dora, checkpoint/diffusion_model, embedding)
|
||||
if (model.sub_type) {
|
||||
card.dataset.sub_type = model.sub_type;
|
||||
}
|
||||
|
||||
// Store metadata if available
|
||||
|
||||
@@ -340,9 +340,9 @@ class MoveManager {
|
||||
folder: newRelativeFolder
|
||||
};
|
||||
|
||||
// Only update model_type if it's present in the cache_entry
|
||||
if (result.cache_entry && result.cache_entry.model_type) {
|
||||
updateData.model_type = result.cache_entry.model_type;
|
||||
// Only update sub_type if it's present in the cache_entry
|
||||
if (result.cache_entry && result.cache_entry.sub_type) {
|
||||
updateData.sub_type = result.cache_entry.sub_type;
|
||||
}
|
||||
|
||||
state.virtualScroller.updateSingleItem(result.original_file_path, updateData);
|
||||
@@ -374,9 +374,9 @@ class MoveManager {
|
||||
folder: newRelativeFolder
|
||||
};
|
||||
|
||||
// Only update model_type if it's present in the cache_entry
|
||||
if (result.cache_entry && result.cache_entry.model_type) {
|
||||
updateData.model_type = result.cache_entry.model_type;
|
||||
// Only update sub_type if it's present in the cache_entry
|
||||
if (result.cache_entry && result.cache_entry.sub_type) {
|
||||
updateData.sub_type = result.cache_entry.sub_type;
|
||||
}
|
||||
|
||||
state.virtualScroller.updateSingleItem(this.currentFilePath, updateData);
|
||||
|
||||
@@ -126,7 +126,7 @@ async def test_persisted_cache_restores_model_type(tmp_path: Path, monkeypatch):
|
||||
assert loaded is True
|
||||
|
||||
cache = await scanner.get_cached_data()
|
||||
types_by_path = {item["file_path"]: item.get("model_type") for item in cache.raw_data}
|
||||
types_by_path = {item["file_path"]: item.get("sub_type") for item in cache.raw_data}
|
||||
|
||||
assert types_by_path[normalized_checkpoint_file] == "checkpoint"
|
||||
assert types_by_path[normalized_unet_file] == "diffusion_model"
|
||||
|
||||
@@ -136,8 +136,7 @@ class TestCheckpointScannerSubType:
|
||||
|
||||
result = scanner.adjust_cached_entry(entry)
|
||||
assert result["sub_type"] == "diffusion_model"
|
||||
# Also sets model_type for backward compatibility
|
||||
assert result["model_type"] == "diffusion_model"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
finally:
|
||||
if original_checkpoints_roots is not None:
|
||||
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||
|
||||
@@ -479,7 +479,7 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
||||
assert dummy_scanner.calls # ensure cache updated
|
||||
|
||||
|
||||
async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_path):
|
||||
async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
|
||||
root_dir = tmp_path / "checkpoints"
|
||||
@@ -494,7 +494,7 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
||||
self.file_name = path.stem
|
||||
self.preview_url = None
|
||||
self.preview_nsfw_level = 0
|
||||
self.model_type = "checkpoint"
|
||||
self.sub_type = "checkpoint"
|
||||
|
||||
def generate_unique_filename(self, *_args, **_kwargs):
|
||||
return os.path.basename(self.file_path)
|
||||
@@ -505,7 +505,7 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
||||
def to_dict(self):
|
||||
return {
|
||||
"file_path": self.file_path,
|
||||
"model_type": self.model_type,
|
||||
"sub_type": self.sub_type,
|
||||
"sha256": self.sha256,
|
||||
}
|
||||
|
||||
@@ -538,12 +538,12 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
||||
self, metadata_obj, _file_path: str, root_path: Optional[str]
|
||||
):
|
||||
if root_path:
|
||||
metadata_obj.model_type = "diffusion_model"
|
||||
metadata_obj.sub_type = "diffusion_model"
|
||||
return metadata_obj
|
||||
|
||||
def adjust_cached_entry(self, entry):
|
||||
if entry.get("file_path", "").startswith(self.root):
|
||||
entry["model_type"] = "diffusion_model"
|
||||
entry["sub_type"] = "diffusion_model"
|
||||
return entry
|
||||
|
||||
async def add_model_to_cache(self, metadata_dict, relative_path):
|
||||
@@ -570,12 +570,12 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
||||
)
|
||||
|
||||
assert result == {"success": True}
|
||||
assert metadata.model_type == "diffusion_model"
|
||||
assert metadata.sub_type == "diffusion_model"
|
||||
saved_metadata = MetadataManager.save_metadata.await_args.args[1]
|
||||
assert saved_metadata.model_type == "diffusion_model"
|
||||
assert saved_metadata.sub_type == "diffusion_model"
|
||||
assert dummy_scanner.add_calls
|
||||
cached_entry, _ = dummy_scanner.add_calls[0]
|
||||
assert cached_entry["model_type"] == "diffusion_model"
|
||||
assert cached_entry["sub_type"] == "diffusion_model"
|
||||
|
||||
|
||||
async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):
|
||||
|
||||
@@ -4,9 +4,7 @@ import pytest
|
||||
from py.services.model_query import (
|
||||
_coerce_to_str,
|
||||
normalize_sub_type,
|
||||
normalize_civitai_model_type,
|
||||
resolve_sub_type,
|
||||
resolve_civitai_model_type,
|
||||
FilterCriteria,
|
||||
ModelFilterSet,
|
||||
)
|
||||
@@ -45,14 +43,6 @@ class TestNormalizeSubType:
|
||||
assert normalize_sub_type("") is None
|
||||
|
||||
|
||||
class TestNormalizeCivitaiModelTypeAlias:
|
||||
"""Test normalize_civitai_model_type is alias for normalize_sub_type."""
|
||||
|
||||
def test_alias_works_correctly(self):
|
||||
assert normalize_civitai_model_type("LoRA") == "lora"
|
||||
assert normalize_civitai_model_type("CHECKPOINT") == "checkpoint"
|
||||
|
||||
|
||||
class TestResolveSubType:
|
||||
"""Test resolve_sub_type function priority."""
|
||||
|
||||
@@ -60,44 +50,35 @@ class TestResolveSubType:
|
||||
"""Priority 1: entry['sub_type'] should be used first."""
|
||||
entry = {
|
||||
"sub_type": "locon",
|
||||
"model_type": "checkpoint", # Should be ignored
|
||||
"civitai": {"model": {"type": "dora"}}, # Should be ignored
|
||||
}
|
||||
assert resolve_sub_type(entry) == "locon"
|
||||
|
||||
def test_priority_2_model_type_field(self):
|
||||
"""Priority 2: entry['model_type'] as fallback."""
|
||||
entry = {
|
||||
"model_type": "checkpoint",
|
||||
"civitai": {"model": {"type": "dora"}}, # Should be ignored
|
||||
}
|
||||
assert resolve_sub_type(entry) == "checkpoint"
|
||||
|
||||
def test_priority_3_civitai_model_type(self):
|
||||
"""Priority 3: civitai.model.type as fallback."""
|
||||
def test_priority_2_civitai_model_type(self):
|
||||
"""Priority 2: civitai.model.type as fallback."""
|
||||
entry = {
|
||||
"civitai": {"model": {"type": "dora"}},
|
||||
}
|
||||
assert resolve_sub_type(entry) == "dora"
|
||||
|
||||
def test_priority_4_default(self):
|
||||
"""Priority 4: default to LORA when nothing found."""
|
||||
def test_priority_3_default(self):
|
||||
"""Priority 3: default to LORA when nothing found."""
|
||||
entry = {}
|
||||
assert resolve_sub_type(entry) == "LORA"
|
||||
|
||||
def test_empty_sub_type_falls_back(self):
|
||||
"""Empty sub_type should fall back to model_type."""
|
||||
"""Empty sub_type should fall back to civitai type."""
|
||||
entry = {
|
||||
"sub_type": "",
|
||||
"model_type": "checkpoint",
|
||||
"civitai": {"model": {"type": "checkpoint"}},
|
||||
}
|
||||
assert resolve_sub_type(entry) == "checkpoint"
|
||||
|
||||
def test_whitespace_sub_type_falls_back(self):
|
||||
"""Whitespace sub_type should fall back to model_type."""
|
||||
"""Whitespace sub_type should fall back to civitai type."""
|
||||
entry = {
|
||||
"sub_type": " ",
|
||||
"model_type": "checkpoint",
|
||||
"civitai": {"model": {"type": "checkpoint"}},
|
||||
}
|
||||
assert resolve_sub_type(entry) == "checkpoint"
|
||||
|
||||
@@ -110,14 +91,6 @@ class TestResolveSubType:
|
||||
assert resolve_sub_type("invalid") == "LORA"
|
||||
|
||||
|
||||
class TestResolveCivitaiModelTypeAlias:
|
||||
"""Test resolve_civitai_model_type is alias for resolve_sub_type."""
|
||||
|
||||
def test_alias_works_correctly(self):
|
||||
entry = {"sub_type": "locon"}
|
||||
assert resolve_civitai_model_type(entry) == "locon"
|
||||
|
||||
|
||||
class TestModelFilterSetWithSubType:
|
||||
"""Test ModelFilterSet applies model_types filtering correctly."""
|
||||
|
||||
@@ -145,23 +118,8 @@ class TestModelFilterSetWithSubType:
|
||||
assert result[0]["model_name"] == "Model 1"
|
||||
assert result[1]["model_name"] == "Model 2"
|
||||
|
||||
def test_filter_falls_back_to_model_type(self):
|
||||
"""Filter should fall back to model_type field."""
|
||||
settings = self.create_mock_settings()
|
||||
filter_set = ModelFilterSet(settings)
|
||||
|
||||
data = [
|
||||
{"model_type": "lora", "model_name": "Model 1"}, # Old field
|
||||
{"sub_type": "locon", "model_name": "Model 2"}, # New field
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(model_types=["lora", "locon"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
def test_filter_uses_civitai_type(self):
|
||||
"""Filter should use civitai.model.type as last resort."""
|
||||
"""Filter should use civitai.model.type as fallback."""
|
||||
settings = self.create_mock_settings()
|
||||
filter_set = ModelFilterSet(settings)
|
||||
|
||||
|
||||
@@ -521,7 +521,7 @@ async def test_reconcile_cache_applies_adjust_cached_entry(tmp_path: Path):
|
||||
|
||||
def _adjust(self, entry: dict) -> dict:
|
||||
applied.append(entry["file_path"])
|
||||
entry["model_type"] = "adjusted"
|
||||
entry["custom_field"] = "adjusted"
|
||||
return entry
|
||||
|
||||
scanner.adjust_cached_entry = MethodType(_adjust, scanner)
|
||||
@@ -538,7 +538,7 @@ async def test_reconcile_cache_applies_adjust_cached_entry(tmp_path: Path):
|
||||
assert normalized_new in applied
|
||||
|
||||
new_entry = next(item for item in scanner._cache.raw_data if item["file_path"] == normalized_new)
|
||||
assert new_entry["model_type"] == "adjusted"
|
||||
assert new_entry["custom_field"] == "adjusted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -42,7 +42,7 @@ class TestLoraServiceFormatResponse:
|
||||
"usage_tips": "",
|
||||
"notes": "",
|
||||
"favorite": False,
|
||||
"sub_type": "locon", # New field
|
||||
"sub_type": "locon",
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
@@ -50,31 +50,7 @@ class TestLoraServiceFormatResponse:
|
||||
|
||||
assert "sub_type" in result
|
||||
assert result["sub_type"] == "locon"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_falls_back_to_model_type(self, lora_service):
|
||||
"""format_response should fall back to model_type if sub_type missing."""
|
||||
lora_data = {
|
||||
"model_name": "Test LoRA",
|
||||
"file_name": "test_lora",
|
||||
"preview_url": "test.webp",
|
||||
"preview_nsfw_level": 0,
|
||||
"base_model": "SDXL",
|
||||
"folder": "",
|
||||
"sha256": "abc123",
|
||||
"file_path": "/models/test_lora.safetensors",
|
||||
"size": 1000,
|
||||
"modified": 1234567890.0,
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"model_type": "dora", # Old field
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
result = await lora_service.format_response(lora_data)
|
||||
|
||||
assert result["sub_type"] == "dora"
|
||||
assert result["model_type"] == "dora" # Both should be set
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_defaults_to_lora(self, lora_service):
|
||||
@@ -98,7 +74,7 @@ class TestLoraServiceFormatResponse:
|
||||
result = await lora_service.format_response(lora_data)
|
||||
|
||||
assert result["sub_type"] == "lora"
|
||||
assert result["model_type"] == "lora"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
|
||||
class TestCheckpointServiceFormatResponse:
|
||||
@@ -138,7 +114,7 @@ class TestCheckpointServiceFormatResponse:
|
||||
|
||||
assert "sub_type" in result
|
||||
assert result["sub_type"] == "checkpoint"
|
||||
assert result["model_type"] == "checkpoint"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_includes_sub_type_diffusion_model(self, checkpoint_service):
|
||||
@@ -163,7 +139,7 @@ class TestCheckpointServiceFormatResponse:
|
||||
result = await checkpoint_service.format_response(checkpoint_data)
|
||||
|
||||
assert result["sub_type"] == "diffusion_model"
|
||||
assert result["model_type"] == "diffusion_model"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
|
||||
class TestEmbeddingServiceFormatResponse:
|
||||
@@ -203,7 +179,7 @@ class TestEmbeddingServiceFormatResponse:
|
||||
|
||||
assert "sub_type" in result
|
||||
assert result["sub_type"] == "embedding"
|
||||
assert result["model_type"] == "embedding"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_defaults_to_embedding(self, embedding_service):
|
||||
@@ -227,4 +203,4 @@ class TestEmbeddingServiceFormatResponse:
|
||||
result = await embedding_service.format_response(embedding_data)
|
||||
|
||||
assert result["sub_type"] == "embedding"
|
||||
assert result["model_type"] == "embedding"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
Reference in New Issue
Block a user