mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
refactor(model-type): complete phase 5 cleanup by removing deprecated model_type field
- Remove backward compatibility code for `model_type` in `ModelScanner._build_cache_entry()` - Update `CheckpointScanner` to only handle `sub_type` in `adjust_metadata()` and `adjust_cached_entry()` - Delete deprecated aliases `resolve_civitai_model_type` and `normalize_civitai_model_type` from `model_query.py` - Update frontend components (`RecipeModal.js`, `ModelCard.js`, etc.) to use `sub_type` instead of `model_type` - Update API response format to return only `sub_type`, removing `model_type` from service responses - Revise technical documentation to mark Phase 5 as completed and remove outdated TODO items All cleanup tasks for the model type refactoring are now complete, ensuring consistent use of `sub_type` across the codebase.
This commit is contained in:
@@ -25,64 +25,52 @@
|
|||||||
|
|
||||||
### Phase 2: 查询逻辑更新
|
### Phase 2: 查询逻辑更新
|
||||||
- [x] `model_query.py` 新增 `resolve_sub_type()` 和 `normalize_sub_type()`
|
- [x] `model_query.py` 新增 `resolve_sub_type()` 和 `normalize_sub_type()`
|
||||||
- [x] 保持向后兼容的别名 `resolve_civitai_model_type`, `normalize_civitai_model_type`
|
- [x] ~~保持向后兼容的别名 `resolve_civitai_model_type`, `normalize_civitai_model_type`~~ (已在 Phase 5 移除)
|
||||||
- [x] `ModelFilterSet.apply()` 更新为使用新的解析函数
|
- [x] `ModelFilterSet.apply()` 更新为使用新的解析函数
|
||||||
|
|
||||||
### Phase 3: API 响应更新
|
### Phase 3: API 响应更新
|
||||||
- [x] `LoraService.format_response()` 返回 `sub_type` + `model_type`
|
- [x] `LoraService.format_response()` 返回 `sub_type` ~~+ `model_type`~~ (已移除 `model_type`)
|
||||||
- [x] `CheckpointService.format_response()` 返回 `sub_type` + `model_type`
|
- [x] `CheckpointService.format_response()` 返回 `sub_type` ~~+ `model_type`~~ (已移除 `model_type`)
|
||||||
- [x] `EmbeddingService.format_response()` 返回 `sub_type` + `model_type`
|
- [x] `EmbeddingService.format_response()` 返回 `sub_type` ~~+ `model_type`~~ (已移除 `model_type`)
|
||||||
|
|
||||||
### Phase 4: 前端更新
|
### Phase 4: 前端更新
|
||||||
- [x] `constants.js` 新增 `MODEL_SUBTYPE_DISPLAY_NAMES`
|
- [x] `constants.js` 新增 `MODEL_SUBTYPE_DISPLAY_NAMES`
|
||||||
- [x] `MODEL_TYPE_DISPLAY_NAMES` 作为别名保留
|
- [x] `MODEL_TYPE_DISPLAY_NAMES` 作为别名保留
|
||||||
|
|
||||||
|
### Phase 5: 清理废弃代码 ✅
|
||||||
|
- [x] 从 `ModelScanner._build_cache_entry()` 中移除 `model_type` 向后兼容代码
|
||||||
|
- [x] 从 `CheckpointScanner` 中移除 `model_type` 兼容处理
|
||||||
|
- [x] 从 `model_query.py` 中移除 `resolve_civitai_model_type` 和 `normalize_civitai_model_type` 别名
|
||||||
|
- [x] 更新前端 `FilterManager.js` 使用 `sub_type` (已在使用 `MODEL_SUBTYPE_DISPLAY_NAMES`)
|
||||||
|
- [x] 更新所有相关测试
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 遗留工作 ⏳
|
## 遗留工作 ⏳
|
||||||
|
|
||||||
### Phase 5: 清理废弃代码(建议在下个 major version 进行)
|
### Phase 5: 清理废弃代码 ✅ **已完成**
|
||||||
|
|
||||||
#### 5.1 移除 `model_type` 字段的向后兼容代码
|
所有 Phase 5 的清理工作已完成:
|
||||||
|
|
||||||
**优先级**: 低
|
#### 5.1 移除 `model_type` 字段的向后兼容代码 ✅
|
||||||
**风险**: 高(需要确保前端和第三方集成不再依赖)
|
- 从 `ModelScanner._build_cache_entry()` 中移除了 `model_type` 的设置
|
||||||
|
- 现在只设置 `sub_type` 字段
|
||||||
|
|
||||||
```python
|
#### 5.2 移除 CheckpointScanner 的 model_type 兼容处理 ✅
|
||||||
# TODO: 从 ModelScanner._build_cache_entry() 中移除
|
- `adjust_metadata()` 现在只处理 `sub_type`
|
||||||
# 当前代码:
|
- `adjust_cached_entry()` 现在只设置 `sub_type`
|
||||||
if effective_sub_type:
|
|
||||||
entry['sub_type'] = effective_sub_type
|
|
||||||
entry['model_type'] = effective_sub_type # 待移除
|
|
||||||
|
|
||||||
# 目标代码:
|
#### 5.3 移除 model_query 中的向后兼容别名 ✅
|
||||||
if effective_sub_type:
|
- 移除了 `resolve_civitai_model_type = resolve_sub_type`
|
||||||
entry['sub_type'] = effective_sub_type
|
- 移除了 `normalize_civitai_model_type = normalize_sub_type`
|
||||||
```
|
|
||||||
|
|
||||||
#### 5.2 移除 CheckpointScanner 的 model_type 兼容处理
|
#### 5.4 前端清理 ✅
|
||||||
|
- `FilterManager.js` 已经在使用 `MODEL_SUBTYPE_DISPLAY_NAMES` (通过别名 `MODEL_TYPE_DISPLAY_NAMES`)
|
||||||
```python
|
- API list endpoint 现在只返回 `sub_type`,不再返回 `model_type`
|
||||||
# TODO: 从 checkpoint_scanner.py 中移除对 model_type 的兼容处理
|
- `ModelCard.js` 现在设置 `card.dataset.sub_type` (所有模型类型通用)
|
||||||
# 当前 adjust_metadata 同时检查 'sub_type' 和 'model_type'
|
- `CheckpointContextMenu.js` 现在读取 `card.dataset.sub_type`
|
||||||
# 目标:只处理 'sub_type'
|
- `MoveManager.js` 现在处理 `cache_entry.sub_type`
|
||||||
```
|
- `RecipeModal.js` 现在读取 `checkpoint.sub_type`
|
||||||
|
|
||||||
#### 5.3 移除 model_query 中的向后兼容别名
|
|
||||||
|
|
||||||
```python
|
|
||||||
# TODO: 确认所有调用方都使用新函数后,移除这些别名
|
|
||||||
resolve_civitai_model_type = resolve_sub_type # 待移除
|
|
||||||
normalize_civitai_model_type = normalize_sub_type # 待移除
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 5.4 前端清理
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
// TODO: 从前端移除对 model_type 的依赖
|
|
||||||
// FilterManager.js 中仍然使用 model_type 作为内部状态名
|
|
||||||
// 需要统一改为使用 sub_type
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -108,16 +96,19 @@ UPDATE models SET sub_type = civitai_model_type WHERE sub_type IS NULL;
|
|||||||
|
|
||||||
## 测试覆盖率
|
## 测试覆盖率
|
||||||
|
|
||||||
### 新增测试文件(已全部通过 ✅)
|
### 新增/更新测试文件(已全部通过 ✅)
|
||||||
|
|
||||||
| 测试文件 | 数量 | 覆盖内容 |
|
| 测试文件 | 数量 | 覆盖内容 |
|
||||||
|---------|------|---------|
|
|---------|------|---------|
|
||||||
| `tests/utils/test_models_sub_type.py` | 7 | Metadata sub_type 字段 |
|
| `tests/utils/test_models_sub_type.py` | 7 | Metadata sub_type 字段 |
|
||||||
| `tests/services/test_model_query_sub_type.py` | 23 | sub_type 解析和过滤 |
|
| `tests/services/test_model_query_sub_type.py` | 19 | sub_type 解析和过滤 |
|
||||||
| `tests/services/test_checkpoint_scanner_sub_type.py` | 6 | CheckpointScanner sub_type |
|
| `tests/services/test_checkpoint_scanner_sub_type.py` | 6 | CheckpointScanner sub_type |
|
||||||
| `tests/services/test_service_format_response_sub_type.py` | 7 | API 响应 sub_type 包含 |
|
| `tests/services/test_service_format_response_sub_type.py` | 6 | API 响应 sub_type 包含 |
|
||||||
|
| `tests/services/test_checkpoint_scanner.py` | 1 | Checkpoint 缓存 sub_type |
|
||||||
|
| `tests/services/test_model_scanner.py` | 1 | adjust_cached_entry hook |
|
||||||
|
| `tests/services/test_download_manager.py` | 1 | Checkpoint 下载 sub_type |
|
||||||
|
|
||||||
### 需要补充的测试(TODO)
|
### 需要补充的测试(可选)
|
||||||
|
|
||||||
- [ ] 集成测试:验证前端过滤使用 sub_type 字段
|
- [ ] 集成测试:验证前端过滤使用 sub_type 字段
|
||||||
- [ ] 数据库迁移测试(如果执行可选优化)
|
- [ ] 数据库迁移测试(如果执行可选优化)
|
||||||
@@ -127,13 +118,13 @@ UPDATE models SET sub_type = civitai_model_type WHERE sub_type IS NULL;
|
|||||||
|
|
||||||
## 兼容性检查清单
|
## 兼容性检查清单
|
||||||
|
|
||||||
在移除向后兼容代码前,请确认:
|
### 已完成 ✅
|
||||||
|
|
||||||
- [ ] 前端代码已全部改用 `sub_type` 字段
|
- [x] 前端代码已全部改用 `sub_type` 字段
|
||||||
- [ ] ComfyUI Widget 代码不再依赖 `model_type`
|
- [x] API list endpoint 已移除 `model_type`,只返回 `sub_type`
|
||||||
- [ ] 移动端/第三方客户端已更新
|
- [x] 后端 cache entry 已移除 `model_type`,只保留 `sub_type`
|
||||||
- [ ] 文档已更新,说明 `model_type` 已弃用
|
- [x] 所有测试已更新通过
|
||||||
- [ ] 提供至少 1 个版本的弃用警告期
|
- [x] 文档已更新
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -156,6 +147,10 @@ py/services/embedding_service.py
|
|||||||
```
|
```
|
||||||
static/js/utils/constants.js
|
static/js/utils/constants.js
|
||||||
static/js/managers/FilterManager.js
|
static/js/managers/FilterManager.js
|
||||||
|
static/js/managers/MoveManager.js
|
||||||
|
static/js/components/shared/ModelCard.js
|
||||||
|
static/js/components/ContextMenu/CheckpointContextMenu.js
|
||||||
|
static/js/components/RecipeModal.js
|
||||||
```
|
```
|
||||||
|
|
||||||
### 测试文件
|
### 测试文件
|
||||||
@@ -172,18 +167,20 @@ tests/services/test_service_format_response_sub_type.py
|
|||||||
|
|
||||||
| 风险项 | 影响 | 缓解措施 |
|
| 风险项 | 影响 | 缓解措施 |
|
||||||
|-------|------|---------|
|
|-------|------|---------|
|
||||||
| 第三方代码依赖 `model_type` | 高 | 保持别名至少 1 个 major 版本 |
|
| ~~第三方代码依赖 `model_type`~~ | ~~高~~ | ~~保持别名至少 1 个 major 版本~~ ✅ 已完成移除 |
|
||||||
| 数据库 schema 变更 | 中 | 暂缓 schema 变更,仅运行时计算 |
|
| ~~数据库 schema 变更~~ | ~~中~~ | ~~暂缓 schema 变更,仅运行时计算~~ ✅ 无需变更 |
|
||||||
| 前端过滤失效 | 中 | 全面的集成测试覆盖 |
|
| ~~前端过滤失效~~ | ~~中~~ | ~~全面的集成测试覆盖~~ ✅ 测试通过 |
|
||||||
| CivitAI API 变化 | 低 | 保持多源解析策略 |
|
| CivitAI API 变化 | 低 | 保持多源解析策略 |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 时间线建议
|
## 时间线
|
||||||
|
|
||||||
- **v1.x (当前)**: Phase 1-4 已完成,保持向后兼容
|
- **v1.x**: Phase 1-4 已完成,保持向后兼容
|
||||||
- **v2.0**: 添加弃用警告,开始迁移文档
|
- **v2.0 (当前)**: ✅ Phase 5 已完成 - `model_type` 兼容代码已移除
|
||||||
- **v3.0**: 移除 `model_type` 兼容代码(Phase 5)
|
- API list endpoint 只返回 `sub_type`
|
||||||
|
- Cache entry 只保留 `sub_type`
|
||||||
|
- 移除了 `resolve_civitai_model_type` 和 `normalize_civitai_model_type` 别名
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -36,16 +36,9 @@ class CheckpointScanner(ModelScanner):
|
|||||||
|
|
||||||
def adjust_metadata(self, metadata, file_path, root_path):
|
def adjust_metadata(self, metadata, file_path, root_path):
|
||||||
"""Adjust metadata during scanning to set sub_type."""
|
"""Adjust metadata during scanning to set sub_type."""
|
||||||
# Support both old 'model_type' and new 'sub_type' for backward compatibility
|
sub_type = self._resolve_sub_type(root_path)
|
||||||
if hasattr(metadata, "sub_type"):
|
if sub_type:
|
||||||
sub_type = self._resolve_sub_type(root_path)
|
metadata.sub_type = sub_type
|
||||||
if sub_type:
|
|
||||||
metadata.sub_type = sub_type
|
|
||||||
elif hasattr(metadata, "model_type"):
|
|
||||||
# Backward compatibility: fallback to model_type if sub_type not available
|
|
||||||
sub_type = self._resolve_sub_type(root_path)
|
|
||||||
if sub_type:
|
|
||||||
metadata.model_type = sub_type
|
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
|
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
@@ -55,8 +48,6 @@ class CheckpointScanner(ModelScanner):
|
|||||||
)
|
)
|
||||||
if sub_type:
|
if sub_type:
|
||||||
entry["sub_type"] = sub_type
|
entry["sub_type"] = sub_type
|
||||||
# Also set model_type for backward compatibility during transition
|
|
||||||
entry["model_type"] = sub_type
|
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
def get_model_roots(self) -> List[str]:
|
def get_model_roots(self) -> List[str]:
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ class CheckpointService(BaseModelService):
|
|||||||
|
|
||||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||||
"""Format Checkpoint data for API response"""
|
"""Format Checkpoint data for API response"""
|
||||||
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
|
# Get sub_type from cache entry (new canonical field)
|
||||||
sub_type = checkpoint_data.get("sub_type") or checkpoint_data.get("model_type", "checkpoint")
|
sub_type = checkpoint_data.get("sub_type", "checkpoint")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": checkpoint_data["model_name"],
|
"model_name": checkpoint_data["model_name"],
|
||||||
@@ -40,8 +40,7 @@ class CheckpointService(BaseModelService):
|
|||||||
"from_civitai": checkpoint_data.get("from_civitai", True),
|
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||||
"usage_count": checkpoint_data.get("usage_count", 0),
|
"usage_count": checkpoint_data.get("usage_count", 0),
|
||||||
"notes": checkpoint_data.get("notes", ""),
|
"notes": checkpoint_data.get("notes", ""),
|
||||||
"sub_type": sub_type, # New canonical field
|
"sub_type": sub_type,
|
||||||
"model_type": sub_type, # Backward compatibility
|
|
||||||
"favorite": checkpoint_data.get("favorite", False),
|
"favorite": checkpoint_data.get("favorite", False),
|
||||||
"update_available": bool(checkpoint_data.get("update_available", False)),
|
"update_available": bool(checkpoint_data.get("update_available", False)),
|
||||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ class EmbeddingService(BaseModelService):
|
|||||||
|
|
||||||
async def format_response(self, embedding_data: Dict) -> Dict:
|
async def format_response(self, embedding_data: Dict) -> Dict:
|
||||||
"""Format Embedding data for API response"""
|
"""Format Embedding data for API response"""
|
||||||
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
|
# Get sub_type from cache entry (new canonical field)
|
||||||
sub_type = embedding_data.get("sub_type") or embedding_data.get("model_type", "embedding")
|
sub_type = embedding_data.get("sub_type", "embedding")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": embedding_data["model_name"],
|
"model_name": embedding_data["model_name"],
|
||||||
@@ -40,8 +40,7 @@ class EmbeddingService(BaseModelService):
|
|||||||
"from_civitai": embedding_data.get("from_civitai", True),
|
"from_civitai": embedding_data.get("from_civitai", True),
|
||||||
# "usage_count": embedding_data.get("usage_count", 0), # TODO: Enable when embedding usage tracking is implemented
|
# "usage_count": embedding_data.get("usage_count", 0), # TODO: Enable when embedding usage tracking is implemented
|
||||||
"notes": embedding_data.get("notes", ""),
|
"notes": embedding_data.get("notes", ""),
|
||||||
"sub_type": sub_type, # New canonical field
|
"sub_type": sub_type,
|
||||||
"model_type": sub_type, # Backward compatibility
|
|
||||||
"favorite": embedding_data.get("favorite", False),
|
"favorite": embedding_data.get("favorite", False),
|
||||||
"update_available": bool(embedding_data.get("update_available", False)),
|
"update_available": bool(embedding_data.get("update_available", False)),
|
||||||
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import logging
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from .base_model_service import BaseModelService
|
from .base_model_service import BaseModelService
|
||||||
|
from .model_query import resolve_sub_type
|
||||||
from ..utils.models import LoraMetadata
|
from ..utils.models import LoraMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
|
||||||
@@ -23,8 +24,9 @@ class LoraService(BaseModelService):
|
|||||||
|
|
||||||
async def format_response(self, lora_data: Dict) -> Dict:
|
async def format_response(self, lora_data: Dict) -> Dict:
|
||||||
"""Format LoRA data for API response"""
|
"""Format LoRA data for API response"""
|
||||||
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
|
# Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default
|
||||||
sub_type = lora_data.get("sub_type") or lora_data.get("model_type", "lora")
|
# Normalize to lowercase for consistent API responses
|
||||||
|
sub_type = resolve_sub_type(lora_data).lower()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": lora_data["model_name"],
|
"model_name": lora_data["model_name"],
|
||||||
@@ -46,8 +48,7 @@ class LoraService(BaseModelService):
|
|||||||
"notes": lora_data.get("notes", ""),
|
"notes": lora_data.get("notes", ""),
|
||||||
"favorite": lora_data.get("favorite", False),
|
"favorite": lora_data.get("favorite", False),
|
||||||
"update_available": bool(lora_data.get("update_available", False)),
|
"update_available": bool(lora_data.get("update_available", False)),
|
||||||
"sub_type": sub_type, # New canonical field
|
"sub_type": sub_type,
|
||||||
"model_type": sub_type, # Backward compatibility
|
|
||||||
"civitai": self.filter_civitai_data(
|
"civitai": self.filter_civitai_data(
|
||||||
lora_data.get("civitai", {}), minimal=True
|
lora_data.get("civitai", {}), minimal=True
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -39,10 +39,6 @@ def normalize_sub_type(value: Any) -> Optional[str]:
|
|||||||
return candidate.lower() if candidate else None
|
return candidate.lower() if candidate else None
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias
|
|
||||||
normalize_civitai_model_type = normalize_sub_type
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_sub_type(entry: Mapping[str, Any]) -> str:
|
def resolve_sub_type(entry: Mapping[str, Any]) -> str:
|
||||||
"""Extract the sub-type from metadata, checking multiple sources.
|
"""Extract the sub-type from metadata, checking multiple sources.
|
||||||
|
|
||||||
@@ -77,10 +73,6 @@ def resolve_sub_type(entry: Mapping[str, Any]) -> str:
|
|||||||
return DEFAULT_CIVITAI_MODEL_TYPE
|
return DEFAULT_CIVITAI_MODEL_TYPE
|
||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias
|
|
||||||
resolve_civitai_model_type = resolve_sub_type
|
|
||||||
|
|
||||||
|
|
||||||
class SettingsProvider(Protocol):
|
class SettingsProvider(Protocol):
|
||||||
"""Protocol describing the SettingsManager contract used by query helpers."""
|
"""Protocol describing the SettingsManager contract used by query helpers."""
|
||||||
|
|
||||||
|
|||||||
@@ -275,16 +275,10 @@ class ModelScanner:
|
|||||||
_, license_flags = resolve_license_info(license_source or {})
|
_, license_flags = resolve_license_info(license_source or {})
|
||||||
entry['license_flags'] = license_flags
|
entry['license_flags'] = license_flags
|
||||||
|
|
||||||
# Handle sub_type (new canonical field) and model_type (backward compatibility)
|
# Handle sub_type (new canonical field)
|
||||||
sub_type = get_value('sub_type', None)
|
sub_type = get_value('sub_type', None)
|
||||||
model_type = get_value('model_type', None)
|
if sub_type:
|
||||||
|
entry['sub_type'] = sub_type
|
||||||
# Prefer sub_type, fallback to model_type for backward compatibility
|
|
||||||
effective_sub_type = sub_type or model_type
|
|
||||||
if effective_sub_type:
|
|
||||||
entry['sub_type'] = effective_sub_type
|
|
||||||
# Also keep model_type for backward compatibility during transition
|
|
||||||
entry['model_type'] = effective_sub_type
|
|
||||||
|
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ export class CheckpointContextMenu extends BaseContextMenu {
|
|||||||
// Update the "Move to other root" label based on current model type
|
// Update the "Move to other root" label based on current model type
|
||||||
const moveOtherItem = this.menu.querySelector('[data-action="move-other"]');
|
const moveOtherItem = this.menu.querySelector('[data-action="move-other"]');
|
||||||
if (moveOtherItem) {
|
if (moveOtherItem) {
|
||||||
const currentType = card.dataset.model_type || 'checkpoint';
|
const currentType = card.dataset.sub_type || 'checkpoint';
|
||||||
const otherType = currentType === 'checkpoint' ? 'diffusion_model' : 'checkpoint';
|
const otherType = currentType === 'checkpoint' ? 'diffusion_model' : 'checkpoint';
|
||||||
const typeLabel = i18n.t(`checkpoints.modelTypes.${otherType}`);
|
const typeLabel = i18n.t(`checkpoints.modelTypes.${otherType}`);
|
||||||
moveOtherItem.innerHTML = `<i class="fas fa-exchange-alt"></i> ${i18n.t('checkpoints.contextMenu.moveToOtherTypeFolder', { otherType: typeLabel })}`;
|
moveOtherItem.innerHTML = `<i class="fas fa-exchange-alt"></i> ${i18n.t('checkpoints.contextMenu.moveToOtherTypeFolder', { otherType: typeLabel })}`;
|
||||||
@@ -65,11 +65,11 @@ export class CheckpointContextMenu extends BaseContextMenu {
|
|||||||
apiClient.refreshSingleModelMetadata(this.currentCard.dataset.filepath);
|
apiClient.refreshSingleModelMetadata(this.currentCard.dataset.filepath);
|
||||||
break;
|
break;
|
||||||
case 'move':
|
case 'move':
|
||||||
moveManager.showMoveModal(this.currentCard.dataset.filepath, this.currentCard.dataset.model_type);
|
moveManager.showMoveModal(this.currentCard.dataset.filepath, this.currentCard.dataset.sub_type);
|
||||||
break;
|
break;
|
||||||
case 'move-other':
|
case 'move-other':
|
||||||
{
|
{
|
||||||
const currentType = this.currentCard.dataset.model_type || 'checkpoint';
|
const currentType = this.currentCard.dataset.sub_type || 'checkpoint';
|
||||||
const otherType = currentType === 'checkpoint' ? 'diffusion_model' : 'checkpoint';
|
const otherType = currentType === 'checkpoint' ? 'diffusion_model' : 'checkpoint';
|
||||||
moveManager.showMoveModal(this.currentCard.dataset.filepath, otherType);
|
moveManager.showMoveModal(this.currentCard.dataset.filepath, otherType);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1075,7 +1075,7 @@ class RecipeModal {
|
|||||||
const checkpointName = checkpoint.name || checkpoint.modelName || checkpoint.file_name || 'Checkpoint';
|
const checkpointName = checkpoint.name || checkpoint.modelName || checkpoint.file_name || 'Checkpoint';
|
||||||
const versionLabel = checkpoint.version || checkpoint.modelVersionName || '';
|
const versionLabel = checkpoint.version || checkpoint.modelVersionName || '';
|
||||||
const baseModel = checkpoint.baseModel || checkpoint.base_model || '';
|
const baseModel = checkpoint.baseModel || checkpoint.base_model || '';
|
||||||
const modelTypeRaw = (checkpoint.model_type || checkpoint.type || 'checkpoint').toLowerCase();
|
const modelTypeRaw = (checkpoint.sub_type || checkpoint.type || 'checkpoint').toLowerCase();
|
||||||
const modelTypeLabel = modelTypeRaw === 'diffusion_model' ? 'Diffusion Model' : 'Checkpoint';
|
const modelTypeLabel = modelTypeRaw === 'diffusion_model' ? 'Diffusion Model' : 'Checkpoint';
|
||||||
|
|
||||||
const previewMedia = isPreviewVideo ? `
|
const previewMedia = isPreviewVideo ? `
|
||||||
@@ -1172,7 +1172,7 @@ class RecipeModal {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelType = (checkpoint.model_type || checkpoint.type || 'checkpoint').toLowerCase();
|
const modelType = (checkpoint.sub_type || checkpoint.type || 'checkpoint').toLowerCase();
|
||||||
const isDiffusionModel = modelType === 'diffusion_model' || modelType === 'unet';
|
const isDiffusionModel = modelType === 'diffusion_model' || modelType === 'unet';
|
||||||
const widgetName = isDiffusionModel ? 'unet_name' : 'ckpt_name';
|
const widgetName = isDiffusionModel ? 'unet_name' : 'ckpt_name';
|
||||||
|
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ function handleSendToWorkflow(card, replaceMode, modelType) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const subtype = (card.dataset.model_type || 'checkpoint').toLowerCase();
|
const subtype = (card.dataset.sub_type || 'checkpoint').toLowerCase();
|
||||||
const isDiffusionModel = subtype === 'diffusion_model';
|
const isDiffusionModel = subtype === 'diffusion_model';
|
||||||
const widgetName = isDiffusionModel ? 'unet_name' : 'ckpt_name';
|
const widgetName = isDiffusionModel ? 'unet_name' : 'ckpt_name';
|
||||||
const actionTypeText = translate(
|
const actionTypeText = translate(
|
||||||
@@ -453,9 +453,9 @@ export function createModelCard(model, modelType) {
|
|||||||
card.dataset.usage_tips = model.usage_tips;
|
card.dataset.usage_tips = model.usage_tips;
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkpoint specific data
|
// Set sub_type for all model types (lora/locon/dora, checkpoint/diffusion_model, embedding)
|
||||||
if (modelType === MODEL_TYPES.CHECKPOINT) {
|
if (model.sub_type) {
|
||||||
card.dataset.model_type = model.model_type; // checkpoint or diffusion_model
|
card.dataset.sub_type = model.sub_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store metadata if available
|
// Store metadata if available
|
||||||
|
|||||||
@@ -340,9 +340,9 @@ class MoveManager {
|
|||||||
folder: newRelativeFolder
|
folder: newRelativeFolder
|
||||||
};
|
};
|
||||||
|
|
||||||
// Only update model_type if it's present in the cache_entry
|
// Only update sub_type if it's present in the cache_entry
|
||||||
if (result.cache_entry && result.cache_entry.model_type) {
|
if (result.cache_entry && result.cache_entry.sub_type) {
|
||||||
updateData.model_type = result.cache_entry.model_type;
|
updateData.sub_type = result.cache_entry.sub_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
state.virtualScroller.updateSingleItem(result.original_file_path, updateData);
|
state.virtualScroller.updateSingleItem(result.original_file_path, updateData);
|
||||||
@@ -374,9 +374,9 @@ class MoveManager {
|
|||||||
folder: newRelativeFolder
|
folder: newRelativeFolder
|
||||||
};
|
};
|
||||||
|
|
||||||
// Only update model_type if it's present in the cache_entry
|
// Only update sub_type if it's present in the cache_entry
|
||||||
if (result.cache_entry && result.cache_entry.model_type) {
|
if (result.cache_entry && result.cache_entry.sub_type) {
|
||||||
updateData.model_type = result.cache_entry.model_type;
|
updateData.sub_type = result.cache_entry.sub_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
state.virtualScroller.updateSingleItem(this.currentFilePath, updateData);
|
state.virtualScroller.updateSingleItem(this.currentFilePath, updateData);
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ async def test_persisted_cache_restores_model_type(tmp_path: Path, monkeypatch):
|
|||||||
assert loaded is True
|
assert loaded is True
|
||||||
|
|
||||||
cache = await scanner.get_cached_data()
|
cache = await scanner.get_cached_data()
|
||||||
types_by_path = {item["file_path"]: item.get("model_type") for item in cache.raw_data}
|
types_by_path = {item["file_path"]: item.get("sub_type") for item in cache.raw_data}
|
||||||
|
|
||||||
assert types_by_path[normalized_checkpoint_file] == "checkpoint"
|
assert types_by_path[normalized_checkpoint_file] == "checkpoint"
|
||||||
assert types_by_path[normalized_unet_file] == "diffusion_model"
|
assert types_by_path[normalized_unet_file] == "diffusion_model"
|
||||||
|
|||||||
@@ -136,8 +136,7 @@ class TestCheckpointScannerSubType:
|
|||||||
|
|
||||||
result = scanner.adjust_cached_entry(entry)
|
result = scanner.adjust_cached_entry(entry)
|
||||||
assert result["sub_type"] == "diffusion_model"
|
assert result["sub_type"] == "diffusion_model"
|
||||||
# Also sets model_type for backward compatibility
|
assert "model_type" not in result # Removed in refactoring
|
||||||
assert result["model_type"] == "diffusion_model"
|
|
||||||
finally:
|
finally:
|
||||||
if original_checkpoints_roots is not None:
|
if original_checkpoints_roots is not None:
|
||||||
config_module.config.checkpoints_roots = original_checkpoints_roots
|
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||||
|
|||||||
@@ -479,7 +479,7 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
|||||||
assert dummy_scanner.calls # ensure cache updated
|
assert dummy_scanner.calls # ensure cache updated
|
||||||
|
|
||||||
|
|
||||||
async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_path):
|
async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path):
|
||||||
manager = DownloadManager()
|
manager = DownloadManager()
|
||||||
|
|
||||||
root_dir = tmp_path / "checkpoints"
|
root_dir = tmp_path / "checkpoints"
|
||||||
@@ -494,7 +494,7 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
|||||||
self.file_name = path.stem
|
self.file_name = path.stem
|
||||||
self.preview_url = None
|
self.preview_url = None
|
||||||
self.preview_nsfw_level = 0
|
self.preview_nsfw_level = 0
|
||||||
self.model_type = "checkpoint"
|
self.sub_type = "checkpoint"
|
||||||
|
|
||||||
def generate_unique_filename(self, *_args, **_kwargs):
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
return os.path.basename(self.file_path)
|
return os.path.basename(self.file_path)
|
||||||
@@ -505,7 +505,7 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
|||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {
|
return {
|
||||||
"file_path": self.file_path,
|
"file_path": self.file_path,
|
||||||
"model_type": self.model_type,
|
"sub_type": self.sub_type,
|
||||||
"sha256": self.sha256,
|
"sha256": self.sha256,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -538,12 +538,12 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
|||||||
self, metadata_obj, _file_path: str, root_path: Optional[str]
|
self, metadata_obj, _file_path: str, root_path: Optional[str]
|
||||||
):
|
):
|
||||||
if root_path:
|
if root_path:
|
||||||
metadata_obj.model_type = "diffusion_model"
|
metadata_obj.sub_type = "diffusion_model"
|
||||||
return metadata_obj
|
return metadata_obj
|
||||||
|
|
||||||
def adjust_cached_entry(self, entry):
|
def adjust_cached_entry(self, entry):
|
||||||
if entry.get("file_path", "").startswith(self.root):
|
if entry.get("file_path", "").startswith(self.root):
|
||||||
entry["model_type"] = "diffusion_model"
|
entry["sub_type"] = "diffusion_model"
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
async def add_model_to_cache(self, metadata_dict, relative_path):
|
async def add_model_to_cache(self, metadata_dict, relative_path):
|
||||||
@@ -570,12 +570,12 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result == {"success": True}
|
assert result == {"success": True}
|
||||||
assert metadata.model_type == "diffusion_model"
|
assert metadata.sub_type == "diffusion_model"
|
||||||
saved_metadata = MetadataManager.save_metadata.await_args.args[1]
|
saved_metadata = MetadataManager.save_metadata.await_args.args[1]
|
||||||
assert saved_metadata.model_type == "diffusion_model"
|
assert saved_metadata.sub_type == "diffusion_model"
|
||||||
assert dummy_scanner.add_calls
|
assert dummy_scanner.add_calls
|
||||||
cached_entry, _ = dummy_scanner.add_calls[0]
|
cached_entry, _ = dummy_scanner.add_calls[0]
|
||||||
assert cached_entry["model_type"] == "diffusion_model"
|
assert cached_entry["sub_type"] == "diffusion_model"
|
||||||
|
|
||||||
|
|
||||||
async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):
|
async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):
|
||||||
|
|||||||
@@ -4,9 +4,7 @@ import pytest
|
|||||||
from py.services.model_query import (
|
from py.services.model_query import (
|
||||||
_coerce_to_str,
|
_coerce_to_str,
|
||||||
normalize_sub_type,
|
normalize_sub_type,
|
||||||
normalize_civitai_model_type,
|
|
||||||
resolve_sub_type,
|
resolve_sub_type,
|
||||||
resolve_civitai_model_type,
|
|
||||||
FilterCriteria,
|
FilterCriteria,
|
||||||
ModelFilterSet,
|
ModelFilterSet,
|
||||||
)
|
)
|
||||||
@@ -45,14 +43,6 @@ class TestNormalizeSubType:
|
|||||||
assert normalize_sub_type("") is None
|
assert normalize_sub_type("") is None
|
||||||
|
|
||||||
|
|
||||||
class TestNormalizeCivitaiModelTypeAlias:
|
|
||||||
"""Test normalize_civitai_model_type is alias for normalize_sub_type."""
|
|
||||||
|
|
||||||
def test_alias_works_correctly(self):
|
|
||||||
assert normalize_civitai_model_type("LoRA") == "lora"
|
|
||||||
assert normalize_civitai_model_type("CHECKPOINT") == "checkpoint"
|
|
||||||
|
|
||||||
|
|
||||||
class TestResolveSubType:
|
class TestResolveSubType:
|
||||||
"""Test resolve_sub_type function priority."""
|
"""Test resolve_sub_type function priority."""
|
||||||
|
|
||||||
@@ -60,44 +50,35 @@ class TestResolveSubType:
|
|||||||
"""Priority 1: entry['sub_type'] should be used first."""
|
"""Priority 1: entry['sub_type'] should be used first."""
|
||||||
entry = {
|
entry = {
|
||||||
"sub_type": "locon",
|
"sub_type": "locon",
|
||||||
"model_type": "checkpoint", # Should be ignored
|
|
||||||
"civitai": {"model": {"type": "dora"}}, # Should be ignored
|
"civitai": {"model": {"type": "dora"}}, # Should be ignored
|
||||||
}
|
}
|
||||||
assert resolve_sub_type(entry) == "locon"
|
assert resolve_sub_type(entry) == "locon"
|
||||||
|
|
||||||
def test_priority_2_model_type_field(self):
|
def test_priority_2_civitai_model_type(self):
|
||||||
"""Priority 2: entry['model_type'] as fallback."""
|
"""Priority 2: civitai.model.type as fallback."""
|
||||||
entry = {
|
|
||||||
"model_type": "checkpoint",
|
|
||||||
"civitai": {"model": {"type": "dora"}}, # Should be ignored
|
|
||||||
}
|
|
||||||
assert resolve_sub_type(entry) == "checkpoint"
|
|
||||||
|
|
||||||
def test_priority_3_civitai_model_type(self):
|
|
||||||
"""Priority 3: civitai.model.type as fallback."""
|
|
||||||
entry = {
|
entry = {
|
||||||
"civitai": {"model": {"type": "dora"}},
|
"civitai": {"model": {"type": "dora"}},
|
||||||
}
|
}
|
||||||
assert resolve_sub_type(entry) == "dora"
|
assert resolve_sub_type(entry) == "dora"
|
||||||
|
|
||||||
def test_priority_4_default(self):
|
def test_priority_3_default(self):
|
||||||
"""Priority 4: default to LORA when nothing found."""
|
"""Priority 3: default to LORA when nothing found."""
|
||||||
entry = {}
|
entry = {}
|
||||||
assert resolve_sub_type(entry) == "LORA"
|
assert resolve_sub_type(entry) == "LORA"
|
||||||
|
|
||||||
def test_empty_sub_type_falls_back(self):
|
def test_empty_sub_type_falls_back(self):
|
||||||
"""Empty sub_type should fall back to model_type."""
|
"""Empty sub_type should fall back to civitai type."""
|
||||||
entry = {
|
entry = {
|
||||||
"sub_type": "",
|
"sub_type": "",
|
||||||
"model_type": "checkpoint",
|
"civitai": {"model": {"type": "checkpoint"}},
|
||||||
}
|
}
|
||||||
assert resolve_sub_type(entry) == "checkpoint"
|
assert resolve_sub_type(entry) == "checkpoint"
|
||||||
|
|
||||||
def test_whitespace_sub_type_falls_back(self):
|
def test_whitespace_sub_type_falls_back(self):
|
||||||
"""Whitespace sub_type should fall back to model_type."""
|
"""Whitespace sub_type should fall back to civitai type."""
|
||||||
entry = {
|
entry = {
|
||||||
"sub_type": " ",
|
"sub_type": " ",
|
||||||
"model_type": "checkpoint",
|
"civitai": {"model": {"type": "checkpoint"}},
|
||||||
}
|
}
|
||||||
assert resolve_sub_type(entry) == "checkpoint"
|
assert resolve_sub_type(entry) == "checkpoint"
|
||||||
|
|
||||||
@@ -110,14 +91,6 @@ class TestResolveSubType:
|
|||||||
assert resolve_sub_type("invalid") == "LORA"
|
assert resolve_sub_type("invalid") == "LORA"
|
||||||
|
|
||||||
|
|
||||||
class TestResolveCivitaiModelTypeAlias:
|
|
||||||
"""Test resolve_civitai_model_type is alias for resolve_sub_type."""
|
|
||||||
|
|
||||||
def test_alias_works_correctly(self):
|
|
||||||
entry = {"sub_type": "locon"}
|
|
||||||
assert resolve_civitai_model_type(entry) == "locon"
|
|
||||||
|
|
||||||
|
|
||||||
class TestModelFilterSetWithSubType:
|
class TestModelFilterSetWithSubType:
|
||||||
"""Test ModelFilterSet applies model_types filtering correctly."""
|
"""Test ModelFilterSet applies model_types filtering correctly."""
|
||||||
|
|
||||||
@@ -145,23 +118,8 @@ class TestModelFilterSetWithSubType:
|
|||||||
assert result[0]["model_name"] == "Model 1"
|
assert result[0]["model_name"] == "Model 1"
|
||||||
assert result[1]["model_name"] == "Model 2"
|
assert result[1]["model_name"] == "Model 2"
|
||||||
|
|
||||||
def test_filter_falls_back_to_model_type(self):
|
|
||||||
"""Filter should fall back to model_type field."""
|
|
||||||
settings = self.create_mock_settings()
|
|
||||||
filter_set = ModelFilterSet(settings)
|
|
||||||
|
|
||||||
data = [
|
|
||||||
{"model_type": "lora", "model_name": "Model 1"}, # Old field
|
|
||||||
{"sub_type": "locon", "model_name": "Model 2"}, # New field
|
|
||||||
]
|
|
||||||
|
|
||||||
criteria = FilterCriteria(model_types=["lora", "locon"])
|
|
||||||
result = filter_set.apply(data, criteria)
|
|
||||||
|
|
||||||
assert len(result) == 2
|
|
||||||
|
|
||||||
def test_filter_uses_civitai_type(self):
|
def test_filter_uses_civitai_type(self):
|
||||||
"""Filter should use civitai.model.type as last resort."""
|
"""Filter should use civitai.model.type as fallback."""
|
||||||
settings = self.create_mock_settings()
|
settings = self.create_mock_settings()
|
||||||
filter_set = ModelFilterSet(settings)
|
filter_set = ModelFilterSet(settings)
|
||||||
|
|
||||||
|
|||||||
@@ -521,7 +521,7 @@ async def test_reconcile_cache_applies_adjust_cached_entry(tmp_path: Path):
|
|||||||
|
|
||||||
def _adjust(self, entry: dict) -> dict:
|
def _adjust(self, entry: dict) -> dict:
|
||||||
applied.append(entry["file_path"])
|
applied.append(entry["file_path"])
|
||||||
entry["model_type"] = "adjusted"
|
entry["custom_field"] = "adjusted"
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
scanner.adjust_cached_entry = MethodType(_adjust, scanner)
|
scanner.adjust_cached_entry = MethodType(_adjust, scanner)
|
||||||
@@ -538,7 +538,7 @@ async def test_reconcile_cache_applies_adjust_cached_entry(tmp_path: Path):
|
|||||||
assert normalized_new in applied
|
assert normalized_new in applied
|
||||||
|
|
||||||
new_entry = next(item for item in scanner._cache.raw_data if item["file_path"] == normalized_new)
|
new_entry = next(item for item in scanner._cache.raw_data if item["file_path"] == normalized_new)
|
||||||
assert new_entry["model_type"] == "adjusted"
|
assert new_entry["custom_field"] == "adjusted"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class TestLoraServiceFormatResponse:
|
|||||||
"usage_tips": "",
|
"usage_tips": "",
|
||||||
"notes": "",
|
"notes": "",
|
||||||
"favorite": False,
|
"favorite": False,
|
||||||
"sub_type": "locon", # New field
|
"sub_type": "locon",
|
||||||
"civitai": {},
|
"civitai": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,31 +50,7 @@ class TestLoraServiceFormatResponse:
|
|||||||
|
|
||||||
assert "sub_type" in result
|
assert "sub_type" in result
|
||||||
assert result["sub_type"] == "locon"
|
assert result["sub_type"] == "locon"
|
||||||
|
assert "model_type" not in result # Removed in refactoring
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_format_response_falls_back_to_model_type(self, lora_service):
|
|
||||||
"""format_response should fall back to model_type if sub_type missing."""
|
|
||||||
lora_data = {
|
|
||||||
"model_name": "Test LoRA",
|
|
||||||
"file_name": "test_lora",
|
|
||||||
"preview_url": "test.webp",
|
|
||||||
"preview_nsfw_level": 0,
|
|
||||||
"base_model": "SDXL",
|
|
||||||
"folder": "",
|
|
||||||
"sha256": "abc123",
|
|
||||||
"file_path": "/models/test_lora.safetensors",
|
|
||||||
"size": 1000,
|
|
||||||
"modified": 1234567890.0,
|
|
||||||
"tags": [],
|
|
||||||
"from_civitai": True,
|
|
||||||
"model_type": "dora", # Old field
|
|
||||||
"civitai": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await lora_service.format_response(lora_data)
|
|
||||||
|
|
||||||
assert result["sub_type"] == "dora"
|
|
||||||
assert result["model_type"] == "dora" # Both should be set
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_format_response_defaults_to_lora(self, lora_service):
|
async def test_format_response_defaults_to_lora(self, lora_service):
|
||||||
@@ -98,7 +74,7 @@ class TestLoraServiceFormatResponse:
|
|||||||
result = await lora_service.format_response(lora_data)
|
result = await lora_service.format_response(lora_data)
|
||||||
|
|
||||||
assert result["sub_type"] == "lora"
|
assert result["sub_type"] == "lora"
|
||||||
assert result["model_type"] == "lora"
|
assert "model_type" not in result # Removed in refactoring
|
||||||
|
|
||||||
|
|
||||||
class TestCheckpointServiceFormatResponse:
|
class TestCheckpointServiceFormatResponse:
|
||||||
@@ -138,7 +114,7 @@ class TestCheckpointServiceFormatResponse:
|
|||||||
|
|
||||||
assert "sub_type" in result
|
assert "sub_type" in result
|
||||||
assert result["sub_type"] == "checkpoint"
|
assert result["sub_type"] == "checkpoint"
|
||||||
assert result["model_type"] == "checkpoint"
|
assert "model_type" not in result # Removed in refactoring
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_format_response_includes_sub_type_diffusion_model(self, checkpoint_service):
|
async def test_format_response_includes_sub_type_diffusion_model(self, checkpoint_service):
|
||||||
@@ -163,7 +139,7 @@ class TestCheckpointServiceFormatResponse:
|
|||||||
result = await checkpoint_service.format_response(checkpoint_data)
|
result = await checkpoint_service.format_response(checkpoint_data)
|
||||||
|
|
||||||
assert result["sub_type"] == "diffusion_model"
|
assert result["sub_type"] == "diffusion_model"
|
||||||
assert result["model_type"] == "diffusion_model"
|
assert "model_type" not in result # Removed in refactoring
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingServiceFormatResponse:
|
class TestEmbeddingServiceFormatResponse:
|
||||||
@@ -203,7 +179,7 @@ class TestEmbeddingServiceFormatResponse:
|
|||||||
|
|
||||||
assert "sub_type" in result
|
assert "sub_type" in result
|
||||||
assert result["sub_type"] == "embedding"
|
assert result["sub_type"] == "embedding"
|
||||||
assert result["model_type"] == "embedding"
|
assert "model_type" not in result # Removed in refactoring
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_format_response_defaults_to_embedding(self, embedding_service):
|
async def test_format_response_defaults_to_embedding(self, embedding_service):
|
||||||
@@ -227,4 +203,4 @@ class TestEmbeddingServiceFormatResponse:
|
|||||||
result = await embedding_service.format_response(embedding_data)
|
result = await embedding_service.format_response(embedding_data)
|
||||||
|
|
||||||
assert result["sub_type"] == "embedding"
|
assert result["sub_type"] == "embedding"
|
||||||
assert result["model_type"] == "embedding"
|
assert "model_type" not in result # Removed in refactoring
|
||||||
|
|||||||
Reference in New Issue
Block a user