mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 23:25:43 -03:00
refactor: unify model_type semantics by introducing sub_type field
This commit resolves the semantic confusion around the model_type field by clearly distinguishing between: - scanner_type: architecture-level (lora/checkpoint/embedding) - sub_type: business-level subtype (lora/locon/dora/checkpoint/diffusion_model/embedding) Backend Changes: - Rename model_type to sub_type in CheckpointMetadata and EmbeddingMetadata - Add resolve_sub_type() and normalize_sub_type() in model_query.py - Update checkpoint_scanner to use _resolve_sub_type() - Update service format_response to include both sub_type and model_type - Add VALID_*_SUB_TYPES constants with backward compatible aliases Frontend Changes: - Add MODEL_SUBTYPE_DISPLAY_NAMES constants - Keep MODEL_TYPE_DISPLAY_NAMES as backward compatible alias Testing: - Add 43 new tests covering sub_type resolution and API response Documentation: - Add refactoring todo document to docs/technical/ BREAKING CHANGE: None - full backward compatibility maintained
This commit is contained in:
194
docs/technical/model_type_refactoring_todo.md
Normal file
194
docs/technical/model_type_refactoring_todo.md
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
# Model Type 字段重构 - 遗留工作清单
|
||||||
|
|
||||||
|
> **状态**: Phase 1-4 已完成 | **创建日期**: 2026-01-30
|
||||||
|
> **相关文件**: `py/utils/models.py`, `py/services/model_query.py`, `py/services/checkpoint_scanner.py`, etc.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
本次重构旨在解决 `model_type` 字段语义不统一的问题。系统中有两个层面的"类型"概念:
|
||||||
|
|
||||||
|
1. **Scanner Type** (`scanner_type`): 架构层面的大类 - `lora`, `checkpoint`, `embedding`
|
||||||
|
2. **Sub Type** (`sub_type`): 业务层面的细分类型 - `lora`/`locon`/`dora`, `checkpoint`/`diffusion_model`, `embedding`
|
||||||
|
|
||||||
|
重构目标是统一使用 `sub_type` 表示细分类型,保留 `model_type` 作为向后兼容的别名。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 已完成工作 ✅
|
||||||
|
|
||||||
|
### Phase 1: 后端字段重命名
|
||||||
|
- [x] `CheckpointMetadata.model_type` → `sub_type`
|
||||||
|
- [x] `EmbeddingMetadata.model_type` → `sub_type`
|
||||||
|
- [x] `model_scanner.py` `_build_cache_entry()` 同时处理 `sub_type` 和 `model_type`
|
||||||
|
|
||||||
|
### Phase 2: 查询逻辑更新
|
||||||
|
- [x] `model_query.py` 新增 `resolve_sub_type()` 和 `normalize_sub_type()`
|
||||||
|
- [x] 保持向后兼容的别名 `resolve_civitai_model_type`, `normalize_civitai_model_type`
|
||||||
|
- [x] `ModelFilterSet.apply()` 更新为使用新的解析函数
|
||||||
|
|
||||||
|
### Phase 3: API 响应更新
|
||||||
|
- [x] `LoraService.format_response()` 返回 `sub_type` + `model_type`
|
||||||
|
- [x] `CheckpointService.format_response()` 返回 `sub_type` + `model_type`
|
||||||
|
- [x] `EmbeddingService.format_response()` 返回 `sub_type` + `model_type`
|
||||||
|
|
||||||
|
### Phase 4: 前端更新
|
||||||
|
- [x] `constants.js` 新增 `MODEL_SUBTYPE_DISPLAY_NAMES`
|
||||||
|
- [x] `MODEL_TYPE_DISPLAY_NAMES` 作为别名保留
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 遗留工作 ⏳
|
||||||
|
|
||||||
|
### Phase 5: 清理废弃代码(建议在下个 major version 进行)
|
||||||
|
|
||||||
|
#### 5.1 移除 `model_type` 字段的向后兼容代码
|
||||||
|
|
||||||
|
**优先级**: 低
|
||||||
|
**风险**: 高(需要确保前端和第三方集成不再依赖)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# TODO: 从 ModelScanner._build_cache_entry() 中移除
|
||||||
|
# 当前代码:
|
||||||
|
if effective_sub_type:
|
||||||
|
entry['sub_type'] = effective_sub_type
|
||||||
|
entry['model_type'] = effective_sub_type # 待移除
|
||||||
|
|
||||||
|
# 目标代码:
|
||||||
|
if effective_sub_type:
|
||||||
|
entry['sub_type'] = effective_sub_type
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 5.2 移除 CheckpointScanner 的 model_type 兼容处理
|
||||||
|
|
||||||
|
```python
|
||||||
|
# TODO: 从 checkpoint_scanner.py 中移除对 model_type 的兼容处理
|
||||||
|
# 当前 adjust_metadata 同时检查 'sub_type' 和 'model_type'
|
||||||
|
# 目标:只处理 'sub_type'
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 5.3 移除 model_query 中的向后兼容别名
|
||||||
|
|
||||||
|
```python
|
||||||
|
# TODO: 确认所有调用方都使用新函数后,移除这些别名
|
||||||
|
resolve_civitai_model_type = resolve_sub_type # 待移除
|
||||||
|
normalize_civitai_model_type = normalize_sub_type # 待移除
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 5.4 前端清理
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
// TODO: 从前端移除对 model_type 的依赖
|
||||||
|
// FilterManager.js 中仍然使用 model_type 作为内部状态名
|
||||||
|
// 需要统一改为使用 sub_type
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 数据库迁移评估
|
||||||
|
|
||||||
|
### 当前状态
|
||||||
|
- `persistent_model_cache.py` 使用 `civitai_model_type` 列存储 CivitAI 原始类型
|
||||||
|
- 缓存 entry 中的 `sub_type` 在运行期动态计算
|
||||||
|
- 数据库 schema **无需立即修改**
|
||||||
|
|
||||||
|
### 未来可选优化
|
||||||
|
```sql
|
||||||
|
-- 可选:在 models 表中添加 sub_type 列(与 civitai_model_type 保持一致但语义更清晰)
|
||||||
|
ALTER TABLE models ADD COLUMN sub_type TEXT;
|
||||||
|
|
||||||
|
-- 数据迁移
|
||||||
|
UPDATE models SET sub_type = civitai_model_type WHERE sub_type IS NULL;
|
||||||
|
```
|
||||||
|
|
||||||
|
**建议**: 如果决定添加 `sub_type` 列,应与 Phase 5 一起进行。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 测试覆盖率
|
||||||
|
|
||||||
|
### 新增测试文件(已全部通过 ✅)
|
||||||
|
|
||||||
|
| 测试文件 | 数量 | 覆盖内容 |
|
||||||
|
|---------|------|---------|
|
||||||
|
| `tests/utils/test_models_sub_type.py` | 7 | Metadata sub_type 字段 |
|
||||||
|
| `tests/services/test_model_query_sub_type.py` | 23 | sub_type 解析和过滤 |
|
||||||
|
| `tests/services/test_checkpoint_scanner_sub_type.py` | 6 | CheckpointScanner sub_type |
|
||||||
|
| `tests/services/test_service_format_response_sub_type.py` | 7 | API 响应 sub_type 包含 |
|
||||||
|
|
||||||
|
### 需要补充的测试(TODO)
|
||||||
|
|
||||||
|
- [ ] 集成测试:验证前端过滤使用 sub_type 字段
|
||||||
|
- [ ] 数据库迁移测试(如果执行可选优化)
|
||||||
|
- [ ] 性能测试:确认 resolve_sub_type 的优先级查找没有显著性能影响
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 兼容性检查清单
|
||||||
|
|
||||||
|
在移除向后兼容代码前,请确认:
|
||||||
|
|
||||||
|
- [ ] 前端代码已全部改用 `sub_type` 字段
|
||||||
|
- [ ] ComfyUI Widget 代码不再依赖 `model_type`
|
||||||
|
- [ ] 移动端/第三方客户端已更新
|
||||||
|
- [ ] 文档已更新,说明 `model_type` 已弃用
|
||||||
|
- [ ] 提供至少 1 个版本的弃用警告期
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 相关文件清单
|
||||||
|
|
||||||
|
### 核心文件
|
||||||
|
```
|
||||||
|
py/utils/models.py
|
||||||
|
py/utils/constants.py
|
||||||
|
py/services/model_scanner.py
|
||||||
|
py/services/model_query.py
|
||||||
|
py/services/checkpoint_scanner.py
|
||||||
|
py/services/base_model_service.py
|
||||||
|
py/services/lora_service.py
|
||||||
|
py/services/checkpoint_service.py
|
||||||
|
py/services/embedding_service.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 前端文件
|
||||||
|
```
|
||||||
|
static/js/utils/constants.js
|
||||||
|
static/js/managers/FilterManager.js
|
||||||
|
```
|
||||||
|
|
||||||
|
### 测试文件
|
||||||
|
```
|
||||||
|
tests/utils/test_models_sub_type.py
|
||||||
|
tests/services/test_model_query_sub_type.py
|
||||||
|
tests/services/test_checkpoint_scanner_sub_type.py
|
||||||
|
tests/services/test_service_format_response_sub_type.py
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 风险评估
|
||||||
|
|
||||||
|
| 风险项 | 影响 | 缓解措施 |
|
||||||
|
|-------|------|---------|
|
||||||
|
| 第三方代码依赖 `model_type` | 高 | 保持别名至少 1 个 major 版本 |
|
||||||
|
| 数据库 schema 变更 | 中 | 暂缓 schema 变更,仅运行时计算 |
|
||||||
|
| 前端过滤失效 | 中 | 全面的集成测试覆盖 |
|
||||||
|
| CivitAI API 变化 | 低 | 保持多源解析策略 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 时间线建议
|
||||||
|
|
||||||
|
- **v1.x (当前)**: Phase 1-4 已完成,保持向后兼容
|
||||||
|
- **v2.0**: 添加弃用警告,开始迁移文档
|
||||||
|
- **v3.0**: 移除 `model_type` 兼容代码(Phase 5)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 备注
|
||||||
|
|
||||||
|
- 重构期间发现 `civitai_model_type` 数据库列命名尚可,但语义上应理解为存储 CivitAI API 返回的原始类型值
|
||||||
|
- Checkpoint 的 `diffusion_model` sub_type 不能通过 CivitAI API 获取,必须通过文件路径(model root)判断
|
||||||
|
- LoRA 的 sub_type(lora/locon/dora)直接来自 CivitAI API 的 `version_info.model.type`
|
||||||
@@ -5,7 +5,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from ..utils.constants import VALID_LORA_TYPES
|
from ..utils.constants import VALID_LORA_SUB_TYPES, VALID_CHECKPOINT_SUB_TYPES
|
||||||
from ..utils.models import BaseModelMetadata
|
from ..utils.models import BaseModelMetadata
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from ..utils.usage_stats import UsageStats
|
from ..utils.usage_stats import UsageStats
|
||||||
@@ -15,8 +15,8 @@ from .model_query import (
|
|||||||
ModelFilterSet,
|
ModelFilterSet,
|
||||||
SearchStrategy,
|
SearchStrategy,
|
||||||
SettingsProvider,
|
SettingsProvider,
|
||||||
normalize_civitai_model_type,
|
normalize_sub_type,
|
||||||
resolve_civitai_model_type,
|
resolve_sub_type,
|
||||||
)
|
)
|
||||||
from .settings_manager import get_settings_manager
|
from .settings_manager import get_settings_manager
|
||||||
|
|
||||||
@@ -568,16 +568,21 @@ class BaseModelService(ABC):
|
|||||||
return await self.scanner.get_base_models(limit)
|
return await self.scanner.get_base_models(limit)
|
||||||
|
|
||||||
async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]:
|
async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]:
|
||||||
"""Get counts of normalized CivitAI model types present in the cache."""
|
"""Get counts of sub-types present in the cache."""
|
||||||
cache = await self.scanner.get_cached_data()
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
type_counts: Dict[str, int] = {}
|
type_counts: Dict[str, int] = {}
|
||||||
for entry in cache.raw_data:
|
for entry in cache.raw_data:
|
||||||
normalized_type = normalize_civitai_model_type(
|
normalized_type = normalize_sub_type(resolve_sub_type(entry))
|
||||||
resolve_civitai_model_type(entry)
|
if not normalized_type:
|
||||||
)
|
|
||||||
if not normalized_type or normalized_type not in VALID_LORA_TYPES:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Filter by valid sub-types based on scanner type
|
||||||
|
if self.model_type == "lora" and normalized_type not in VALID_LORA_SUB_TYPES:
|
||||||
|
continue
|
||||||
|
if self.model_type == "checkpoint" and normalized_type not in VALID_CHECKPOINT_SUB_TYPES:
|
||||||
|
continue
|
||||||
|
|
||||||
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
|
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
|
||||||
|
|
||||||
sorted_types = sorted(
|
sorted_types = sorted(
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ class CheckpointScanner(ModelScanner):
|
|||||||
hash_index=ModelHashIndex()
|
hash_index=ModelHashIndex()
|
||||||
)
|
)
|
||||||
|
|
||||||
def _resolve_model_type(self, root_path: Optional[str]) -> Optional[str]:
|
def _resolve_sub_type(self, root_path: Optional[str]) -> Optional[str]:
|
||||||
|
"""Resolve the sub-type based on the root path."""
|
||||||
if not root_path:
|
if not root_path:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -34,18 +35,28 @@ class CheckpointScanner(ModelScanner):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def adjust_metadata(self, metadata, file_path, root_path):
|
def adjust_metadata(self, metadata, file_path, root_path):
|
||||||
if hasattr(metadata, "model_type"):
|
"""Adjust metadata during scanning to set sub_type."""
|
||||||
model_type = self._resolve_model_type(root_path)
|
# Support both old 'model_type' and new 'sub_type' for backward compatibility
|
||||||
if model_type:
|
if hasattr(metadata, "sub_type"):
|
||||||
metadata.model_type = model_type
|
sub_type = self._resolve_sub_type(root_path)
|
||||||
|
if sub_type:
|
||||||
|
metadata.sub_type = sub_type
|
||||||
|
elif hasattr(metadata, "model_type"):
|
||||||
|
# Backward compatibility: fallback to model_type if sub_type not available
|
||||||
|
sub_type = self._resolve_sub_type(root_path)
|
||||||
|
if sub_type:
|
||||||
|
metadata.model_type = sub_type
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
|
def adjust_cached_entry(self, entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
model_type = self._resolve_model_type(
|
"""Adjust entries loaded from the persisted cache to ensure sub_type is set."""
|
||||||
|
sub_type = self._resolve_sub_type(
|
||||||
self._find_root_for_file(entry.get("file_path"))
|
self._find_root_for_file(entry.get("file_path"))
|
||||||
)
|
)
|
||||||
if model_type:
|
if sub_type:
|
||||||
entry["model_type"] = model_type
|
entry["sub_type"] = sub_type
|
||||||
|
# Also set model_type for backward compatibility during transition
|
||||||
|
entry["model_type"] = sub_type
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
def get_model_roots(self) -> List[str]:
|
def get_model_roots(self) -> List[str]:
|
||||||
|
|||||||
@@ -22,6 +22,9 @@ class CheckpointService(BaseModelService):
|
|||||||
|
|
||||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||||
"""Format Checkpoint data for API response"""
|
"""Format Checkpoint data for API response"""
|
||||||
|
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
|
||||||
|
sub_type = checkpoint_data.get("sub_type") or checkpoint_data.get("model_type", "checkpoint")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": checkpoint_data["model_name"],
|
"model_name": checkpoint_data["model_name"],
|
||||||
"file_name": checkpoint_data["file_name"],
|
"file_name": checkpoint_data["file_name"],
|
||||||
@@ -37,7 +40,8 @@ class CheckpointService(BaseModelService):
|
|||||||
"from_civitai": checkpoint_data.get("from_civitai", True),
|
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||||
"usage_count": checkpoint_data.get("usage_count", 0),
|
"usage_count": checkpoint_data.get("usage_count", 0),
|
||||||
"notes": checkpoint_data.get("notes", ""),
|
"notes": checkpoint_data.get("notes", ""),
|
||||||
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
"sub_type": sub_type, # New canonical field
|
||||||
|
"model_type": sub_type, # Backward compatibility
|
||||||
"favorite": checkpoint_data.get("favorite", False),
|
"favorite": checkpoint_data.get("favorite", False),
|
||||||
"update_available": bool(checkpoint_data.get("update_available", False)),
|
"update_available": bool(checkpoint_data.get("update_available", False)),
|
||||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||||
|
|||||||
@@ -22,6 +22,9 @@ class EmbeddingService(BaseModelService):
|
|||||||
|
|
||||||
async def format_response(self, embedding_data: Dict) -> Dict:
|
async def format_response(self, embedding_data: Dict) -> Dict:
|
||||||
"""Format Embedding data for API response"""
|
"""Format Embedding data for API response"""
|
||||||
|
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
|
||||||
|
sub_type = embedding_data.get("sub_type") or embedding_data.get("model_type", "embedding")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": embedding_data["model_name"],
|
"model_name": embedding_data["model_name"],
|
||||||
"file_name": embedding_data["file_name"],
|
"file_name": embedding_data["file_name"],
|
||||||
@@ -37,7 +40,8 @@ class EmbeddingService(BaseModelService):
|
|||||||
"from_civitai": embedding_data.get("from_civitai", True),
|
"from_civitai": embedding_data.get("from_civitai", True),
|
||||||
# "usage_count": embedding_data.get("usage_count", 0), # TODO: Enable when embedding usage tracking is implemented
|
# "usage_count": embedding_data.get("usage_count", 0), # TODO: Enable when embedding usage tracking is implemented
|
||||||
"notes": embedding_data.get("notes", ""),
|
"notes": embedding_data.get("notes", ""),
|
||||||
"model_type": embedding_data.get("model_type", "embedding"),
|
"sub_type": sub_type, # New canonical field
|
||||||
|
"model_type": sub_type, # Backward compatibility
|
||||||
"favorite": embedding_data.get("favorite", False),
|
"favorite": embedding_data.get("favorite", False),
|
||||||
"update_available": bool(embedding_data.get("update_available", False)),
|
"update_available": bool(embedding_data.get("update_available", False)),
|
||||||
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ class LoraService(BaseModelService):
|
|||||||
|
|
||||||
async def format_response(self, lora_data: Dict) -> Dict:
|
async def format_response(self, lora_data: Dict) -> Dict:
|
||||||
"""Format LoRA data for API response"""
|
"""Format LoRA data for API response"""
|
||||||
|
# Get sub_type from cache entry (new field) or fallback to model_type (old field)
|
||||||
|
sub_type = lora_data.get("sub_type") or lora_data.get("model_type", "lora")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": lora_data["model_name"],
|
"model_name": lora_data["model_name"],
|
||||||
"file_name": lora_data["file_name"],
|
"file_name": lora_data["file_name"],
|
||||||
@@ -43,6 +46,8 @@ class LoraService(BaseModelService):
|
|||||||
"notes": lora_data.get("notes", ""),
|
"notes": lora_data.get("notes", ""),
|
||||||
"favorite": lora_data.get("favorite", False),
|
"favorite": lora_data.get("favorite", False),
|
||||||
"update_available": bool(lora_data.get("update_available", False)),
|
"update_available": bool(lora_data.get("update_available", False)),
|
||||||
|
"sub_type": sub_type, # New canonical field
|
||||||
|
"model_type": sub_type, # Backward compatibility
|
||||||
"civitai": self.filter_civitai_data(
|
"civitai": self.filter_civitai_data(
|
||||||
lora_data.get("civitai", {}), minimal=True
|
lora_data.get("civitai", {}), minimal=True
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -33,32 +33,54 @@ def _coerce_to_str(value: Any) -> Optional[str]:
|
|||||||
return candidate if candidate else None
|
return candidate if candidate else None
|
||||||
|
|
||||||
|
|
||||||
def normalize_civitai_model_type(value: Any) -> Optional[str]:
|
def normalize_sub_type(value: Any) -> Optional[str]:
|
||||||
"""Return a lowercase string suitable for comparisons."""
|
"""Return a lowercase string suitable for sub_type comparisons."""
|
||||||
candidate = _coerce_to_str(value)
|
candidate = _coerce_to_str(value)
|
||||||
return candidate.lower() if candidate else None
|
return candidate.lower() if candidate else None
|
||||||
|
|
||||||
|
|
||||||
def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str:
|
# Backward compatibility alias
|
||||||
"""Extract the model type from CivitAI metadata, defaulting to LORA."""
|
normalize_civitai_model_type = normalize_sub_type
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_sub_type(entry: Mapping[str, Any]) -> str:
|
||||||
|
"""Extract the sub-type from metadata, checking multiple sources.
|
||||||
|
|
||||||
|
Priority:
|
||||||
|
1. entry['sub_type'] - new canonical field
|
||||||
|
2. entry['model_type'] - backward compatibility
|
||||||
|
3. civitai.model.type - CivitAI API data
|
||||||
|
4. DEFAULT_CIVITAI_MODEL_TYPE - fallback
|
||||||
|
"""
|
||||||
if not isinstance(entry, Mapping):
|
if not isinstance(entry, Mapping):
|
||||||
return DEFAULT_CIVITAI_MODEL_TYPE
|
return DEFAULT_CIVITAI_MODEL_TYPE
|
||||||
|
|
||||||
civitai = entry.get("civitai")
|
# Priority 1: Check new canonical field 'sub_type'
|
||||||
if isinstance(civitai, Mapping):
|
sub_type = _coerce_to_str(entry.get("sub_type"))
|
||||||
civitai_model = civitai.get("model")
|
if sub_type:
|
||||||
if isinstance(civitai_model, Mapping):
|
return sub_type
|
||||||
model_type = _coerce_to_str(civitai_model.get("type"))
|
|
||||||
if model_type:
|
|
||||||
return model_type
|
|
||||||
|
|
||||||
|
# Priority 2: Backward compatibility - check 'model_type' field
|
||||||
model_type = _coerce_to_str(entry.get("model_type"))
|
model_type = _coerce_to_str(entry.get("model_type"))
|
||||||
if model_type:
|
if model_type:
|
||||||
return model_type
|
return model_type
|
||||||
|
|
||||||
|
# Priority 3: Extract from CivitAI metadata
|
||||||
|
civitai = entry.get("civitai")
|
||||||
|
if isinstance(civitai, Mapping):
|
||||||
|
civitai_model = civitai.get("model")
|
||||||
|
if isinstance(civitai_model, Mapping):
|
||||||
|
civitai_type = _coerce_to_str(civitai_model.get("type"))
|
||||||
|
if civitai_type:
|
||||||
|
return civitai_type
|
||||||
|
|
||||||
return DEFAULT_CIVITAI_MODEL_TYPE
|
return DEFAULT_CIVITAI_MODEL_TYPE
|
||||||
|
|
||||||
|
|
||||||
|
# Backward compatibility alias
|
||||||
|
resolve_civitai_model_type = resolve_sub_type
|
||||||
|
|
||||||
|
|
||||||
class SettingsProvider(Protocol):
|
class SettingsProvider(Protocol):
|
||||||
"""Protocol describing the SettingsManager contract used by query helpers."""
|
"""Protocol describing the SettingsManager contract used by query helpers."""
|
||||||
|
|
||||||
@@ -313,7 +335,7 @@ class ModelFilterSet:
|
|||||||
normalized_model_types = {
|
normalized_model_types = {
|
||||||
model_type
|
model_type
|
||||||
for model_type in (
|
for model_type in (
|
||||||
normalize_civitai_model_type(value) for value in model_types
|
normalize_sub_type(value) for value in model_types
|
||||||
)
|
)
|
||||||
if model_type
|
if model_type
|
||||||
}
|
}
|
||||||
@@ -321,7 +343,7 @@ class ModelFilterSet:
|
|||||||
items = [
|
items = [
|
||||||
item
|
item
|
||||||
for item in items
|
for item in items
|
||||||
if normalize_civitai_model_type(resolve_civitai_model_type(item))
|
if normalize_sub_type(resolve_sub_type(item))
|
||||||
in normalized_model_types
|
in normalized_model_types
|
||||||
]
|
]
|
||||||
model_types_duration = time.perf_counter() - t0
|
model_types_duration = time.perf_counter() - t0
|
||||||
|
|||||||
@@ -275,9 +275,16 @@ class ModelScanner:
|
|||||||
_, license_flags = resolve_license_info(license_source or {})
|
_, license_flags = resolve_license_info(license_source or {})
|
||||||
entry['license_flags'] = license_flags
|
entry['license_flags'] = license_flags
|
||||||
|
|
||||||
|
# Handle sub_type (new canonical field) and model_type (backward compatibility)
|
||||||
|
sub_type = get_value('sub_type', None)
|
||||||
model_type = get_value('model_type', None)
|
model_type = get_value('model_type', None)
|
||||||
if model_type:
|
|
||||||
entry['model_type'] = model_type
|
# Prefer sub_type, fallback to model_type for backward compatibility
|
||||||
|
effective_sub_type = sub_type or model_type
|
||||||
|
if effective_sub_type:
|
||||||
|
entry['sub_type'] = effective_sub_type
|
||||||
|
# Also keep model_type for backward compatibility during transition
|
||||||
|
entry['model_type'] = effective_sub_type
|
||||||
|
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
|||||||
@@ -45,8 +45,13 @@ SUPPORTED_MEDIA_EXTENSIONS = {
|
|||||||
"videos": [".mp4", ".webm"],
|
"videos": [".mp4", ".webm"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Valid Lora types
|
# Valid sub-types for each scanner type
|
||||||
VALID_LORA_TYPES = ["lora", "locon", "dora"]
|
VALID_LORA_SUB_TYPES = ["lora", "locon", "dora"]
|
||||||
|
VALID_CHECKPOINT_SUB_TYPES = ["checkpoint", "diffusion_model"]
|
||||||
|
VALID_EMBEDDING_SUB_TYPES = ["embedding"]
|
||||||
|
|
||||||
|
# Backward compatibility alias
|
||||||
|
VALID_LORA_TYPES = VALID_LORA_SUB_TYPES
|
||||||
|
|
||||||
# Supported Civitai model types for user model queries (case-insensitive)
|
# Supported Civitai model types for user model queries (case-insensitive)
|
||||||
CIVITAI_USER_MODEL_TYPES = [
|
CIVITAI_USER_MODEL_TYPES = [
|
||||||
|
|||||||
@@ -173,14 +173,14 @@ class LoraMetadata(BaseModelMetadata):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class CheckpointMetadata(BaseModelMetadata):
|
class CheckpointMetadata(BaseModelMetadata):
|
||||||
"""Represents the metadata structure for a Checkpoint model"""
|
"""Represents the metadata structure for a Checkpoint model"""
|
||||||
model_type: str = "checkpoint" # Model type (checkpoint, diffusion_model, etc.)
|
sub_type: str = "checkpoint" # Model sub-type (checkpoint, diffusion_model, etc.)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
|
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
|
||||||
"""Create CheckpointMetadata instance from Civitai version info"""
|
"""Create CheckpointMetadata instance from Civitai version info"""
|
||||||
file_name = file_info['name']
|
file_name = file_info['name']
|
||||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||||
model_type = version_info.get('type', 'checkpoint')
|
sub_type = version_info.get('type', 'checkpoint')
|
||||||
|
|
||||||
# Extract tags and description if available
|
# Extract tags and description if available
|
||||||
tags = []
|
tags = []
|
||||||
@@ -203,7 +203,7 @@ class CheckpointMetadata(BaseModelMetadata):
|
|||||||
preview_nsfw_level=0,
|
preview_nsfw_level=0,
|
||||||
from_civitai=True,
|
from_civitai=True,
|
||||||
civitai=version_info,
|
civitai=version_info,
|
||||||
model_type=model_type,
|
sub_type=sub_type,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
modelDescription=description
|
modelDescription=description
|
||||||
)
|
)
|
||||||
@@ -211,14 +211,14 @@ class CheckpointMetadata(BaseModelMetadata):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingMetadata(BaseModelMetadata):
|
class EmbeddingMetadata(BaseModelMetadata):
|
||||||
"""Represents the metadata structure for an Embedding model"""
|
"""Represents the metadata structure for an Embedding model"""
|
||||||
model_type: str = "embedding" # Model type (embedding, textual_inversion, etc.)
|
sub_type: str = "embedding"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'EmbeddingMetadata':
|
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'EmbeddingMetadata':
|
||||||
"""Create EmbeddingMetadata instance from Civitai version info"""
|
"""Create EmbeddingMetadata instance from Civitai version info"""
|
||||||
file_name = file_info['name']
|
file_name = file_info['name']
|
||||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||||
model_type = version_info.get('type', 'embedding')
|
sub_type = version_info.get('type', 'embedding')
|
||||||
|
|
||||||
# Extract tags and description if available
|
# Extract tags and description if available
|
||||||
tags = []
|
tags = []
|
||||||
@@ -241,7 +241,7 @@ class EmbeddingMetadata(BaseModelMetadata):
|
|||||||
preview_nsfw_level=0,
|
preview_nsfw_level=0,
|
||||||
from_civitai=True,
|
from_civitai=True,
|
||||||
civitai=version_info,
|
civitai=version_info,
|
||||||
model_type=model_type,
|
sub_type=sub_type,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
modelDescription=description
|
modelDescription=description
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -57,12 +57,22 @@ export const BASE_MODELS = {
|
|||||||
UNKNOWN: "Other"
|
UNKNOWN: "Other"
|
||||||
};
|
};
|
||||||
|
|
||||||
export const MODEL_TYPE_DISPLAY_NAMES = {
|
// Model sub-type display names (new canonical field: sub_type)
|
||||||
|
export const MODEL_SUBTYPE_DISPLAY_NAMES = {
|
||||||
|
// LoRA sub-types
|
||||||
lora: "LoRA",
|
lora: "LoRA",
|
||||||
locon: "LyCORIS",
|
locon: "LyCORIS",
|
||||||
dora: "DoRA",
|
dora: "DoRA",
|
||||||
|
// Checkpoint sub-types
|
||||||
|
checkpoint: "Checkpoint",
|
||||||
|
diffusion_model: "Diffusion Model",
|
||||||
|
// Embedding sub-types
|
||||||
|
embedding: "Embedding",
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Backward compatibility alias
|
||||||
|
export const MODEL_TYPE_DISPLAY_NAMES = MODEL_SUBTYPE_DISPLAY_NAMES;
|
||||||
|
|
||||||
export const BASE_MODEL_ABBREVIATIONS = {
|
export const BASE_MODEL_ABBREVIATIONS = {
|
||||||
// Stable Diffusion 1.x models
|
// Stable Diffusion 1.x models
|
||||||
[BASE_MODELS.SD_1_4]: 'SD1',
|
[BASE_MODELS.SD_1_4]: 'SD1',
|
||||||
|
|||||||
145
tests/services/test_checkpoint_scanner_sub_type.py
Normal file
145
tests/services/test_checkpoint_scanner_sub_type.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""Tests for CheckpointScanner sub_type resolution."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from py.services.checkpoint_scanner import CheckpointScanner
|
||||||
|
from py.utils.models import CheckpointMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointScannerSubType:
|
||||||
|
"""Test CheckpointScanner sub_type resolution logic."""
|
||||||
|
|
||||||
|
def create_scanner(self):
|
||||||
|
"""Create scanner with no async initialization."""
|
||||||
|
# Create scanner without calling __init__ to avoid async issues
|
||||||
|
scanner = object.__new__(CheckpointScanner)
|
||||||
|
scanner.model_type = "checkpoint"
|
||||||
|
scanner.model_class = CheckpointMetadata
|
||||||
|
scanner.file_extensions = {'.ckpt', '.safetensors'}
|
||||||
|
scanner._hash_index = MagicMock()
|
||||||
|
return scanner
|
||||||
|
|
||||||
|
def test_resolve_sub_type_checkpoint_root(self):
|
||||||
|
"""_resolve_sub_type should return 'checkpoint' for checkpoints_roots."""
|
||||||
|
scanner = self.create_scanner()
|
||||||
|
|
||||||
|
from py import config as config_module
|
||||||
|
original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None)
|
||||||
|
original_unet_roots = getattr(config_module.config, 'unet_roots', None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
config_module.config.checkpoints_roots = ["/models/checkpoints"]
|
||||||
|
config_module.config.unet_roots = ["/models/unet"]
|
||||||
|
|
||||||
|
result = scanner._resolve_sub_type("/models/checkpoints")
|
||||||
|
assert result == "checkpoint"
|
||||||
|
finally:
|
||||||
|
if original_checkpoints_roots is not None:
|
||||||
|
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||||
|
if original_unet_roots is not None:
|
||||||
|
config_module.config.unet_roots = original_unet_roots
|
||||||
|
|
||||||
|
def test_resolve_sub_type_unet_root(self):
|
||||||
|
"""_resolve_sub_type should return 'diffusion_model' for unet_roots."""
|
||||||
|
scanner = self.create_scanner()
|
||||||
|
|
||||||
|
from py import config as config_module
|
||||||
|
original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None)
|
||||||
|
original_unet_roots = getattr(config_module.config, 'unet_roots', None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
config_module.config.checkpoints_roots = ["/models/checkpoints"]
|
||||||
|
config_module.config.unet_roots = ["/models/unet"]
|
||||||
|
|
||||||
|
result = scanner._resolve_sub_type("/models/unet")
|
||||||
|
assert result == "diffusion_model"
|
||||||
|
finally:
|
||||||
|
if original_checkpoints_roots is not None:
|
||||||
|
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||||
|
if original_unet_roots is not None:
|
||||||
|
config_module.config.unet_roots = original_unet_roots
|
||||||
|
|
||||||
|
def test_resolve_sub_type_none_root(self):
|
||||||
|
"""_resolve_sub_type should return None for None input."""
|
||||||
|
scanner = self.create_scanner()
|
||||||
|
result = scanner._resolve_sub_type(None)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_resolve_sub_type_unknown_root(self):
|
||||||
|
"""_resolve_sub_type should return None for unknown root."""
|
||||||
|
scanner = self.create_scanner()
|
||||||
|
|
||||||
|
from py import config as config_module
|
||||||
|
original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None)
|
||||||
|
original_unet_roots = getattr(config_module.config, 'unet_roots', None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
config_module.config.checkpoints_roots = ["/models/checkpoints"]
|
||||||
|
config_module.config.unet_roots = ["/models/unet"]
|
||||||
|
|
||||||
|
result = scanner._resolve_sub_type("/models/unknown")
|
||||||
|
assert result is None
|
||||||
|
finally:
|
||||||
|
if original_checkpoints_roots is not None:
|
||||||
|
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||||
|
if original_unet_roots is not None:
|
||||||
|
config_module.config.unet_roots = original_unet_roots
|
||||||
|
|
||||||
|
def test_adjust_metadata_sets_sub_type(self):
|
||||||
|
"""adjust_metadata should set sub_type on metadata."""
|
||||||
|
scanner = self.create_scanner()
|
||||||
|
|
||||||
|
metadata = CheckpointMetadata(
|
||||||
|
file_name="test",
|
||||||
|
model_name="Test",
|
||||||
|
file_path="/models/checkpoints/model.safetensors",
|
||||||
|
size=1000,
|
||||||
|
modified=1234567890.0,
|
||||||
|
sha256="abc123",
|
||||||
|
base_model="SDXL",
|
||||||
|
preview_url="",
|
||||||
|
)
|
||||||
|
|
||||||
|
from py import config as config_module
|
||||||
|
original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
config_module.config.checkpoints_roots = ["/models/checkpoints"]
|
||||||
|
config_module.config.unet_roots = []
|
||||||
|
|
||||||
|
result = scanner.adjust_metadata(metadata, "/models/checkpoints/model.safetensors", "/models/checkpoints")
|
||||||
|
assert result.sub_type == "checkpoint"
|
||||||
|
finally:
|
||||||
|
if original_checkpoints_roots is not None:
|
||||||
|
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||||
|
|
||||||
|
def test_adjust_cached_entry_sets_sub_type(self):
|
||||||
|
"""adjust_cached_entry should set sub_type on entry."""
|
||||||
|
scanner = self.create_scanner()
|
||||||
|
# Mock get_model_roots to return the expected roots
|
||||||
|
scanner.get_model_roots = lambda: ["/models/unet"]
|
||||||
|
|
||||||
|
entry = {
|
||||||
|
"file_path": "/models/unet/model.safetensors",
|
||||||
|
"model_name": "Test",
|
||||||
|
}
|
||||||
|
|
||||||
|
from py import config as config_module
|
||||||
|
original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None)
|
||||||
|
original_unet_roots = getattr(config_module.config, 'unet_roots', None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
config_module.config.checkpoints_roots = []
|
||||||
|
config_module.config.unet_roots = ["/models/unet"]
|
||||||
|
|
||||||
|
result = scanner.adjust_cached_entry(entry)
|
||||||
|
assert result["sub_type"] == "diffusion_model"
|
||||||
|
# Also sets model_type for backward compatibility
|
||||||
|
assert result["model_type"] == "diffusion_model"
|
||||||
|
finally:
|
||||||
|
if original_checkpoints_roots is not None:
|
||||||
|
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||||
|
if original_unet_roots is not None:
|
||||||
|
config_module.config.unet_roots = original_unet_roots
|
||||||
203
tests/services/test_model_query_sub_type.py
Normal file
203
tests/services/test_model_query_sub_type.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""Tests for model_query sub_type resolution."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from py.services.model_query import (
|
||||||
|
_coerce_to_str,
|
||||||
|
normalize_sub_type,
|
||||||
|
normalize_civitai_model_type,
|
||||||
|
resolve_sub_type,
|
||||||
|
resolve_civitai_model_type,
|
||||||
|
FilterCriteria,
|
||||||
|
ModelFilterSet,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCoerceToStr:
|
||||||
|
"""Test _coerce_to_str helper."""
|
||||||
|
|
||||||
|
def test_none_returns_none(self):
|
||||||
|
assert _coerce_to_str(None) is None
|
||||||
|
|
||||||
|
def test_string_returns_stripped(self):
|
||||||
|
assert _coerce_to_str(" test ") == "test"
|
||||||
|
|
||||||
|
def test_empty_returns_none(self):
|
||||||
|
assert _coerce_to_str(" ") is None
|
||||||
|
|
||||||
|
def test_number_converts_to_str(self):
|
||||||
|
assert _coerce_to_str(123) == "123"
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeSubType:
|
||||||
|
"""Test normalize_sub_type function."""
|
||||||
|
|
||||||
|
def test_normalizes_to_lowercase(self):
|
||||||
|
assert normalize_sub_type("LoRA") == "lora"
|
||||||
|
assert normalize_sub_type("CHECKPOINT") == "checkpoint"
|
||||||
|
|
||||||
|
def test_strips_whitespace(self):
|
||||||
|
assert normalize_sub_type(" LoRA ") == "lora"
|
||||||
|
|
||||||
|
def test_none_returns_none(self):
|
||||||
|
assert normalize_sub_type(None) is None
|
||||||
|
|
||||||
|
def test_empty_returns_none(self):
|
||||||
|
assert normalize_sub_type("") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeCivitaiModelTypeAlias:
|
||||||
|
"""Test normalize_civitai_model_type is alias for normalize_sub_type."""
|
||||||
|
|
||||||
|
def test_alias_works_correctly(self):
|
||||||
|
assert normalize_civitai_model_type("LoRA") == "lora"
|
||||||
|
assert normalize_civitai_model_type("CHECKPOINT") == "checkpoint"
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveSubType:
|
||||||
|
"""Test resolve_sub_type function priority."""
|
||||||
|
|
||||||
|
def test_priority_1_sub_type_field(self):
|
||||||
|
"""Priority 1: entry['sub_type'] should be used first."""
|
||||||
|
entry = {
|
||||||
|
"sub_type": "locon",
|
||||||
|
"model_type": "checkpoint", # Should be ignored
|
||||||
|
"civitai": {"model": {"type": "dora"}}, # Should be ignored
|
||||||
|
}
|
||||||
|
assert resolve_sub_type(entry) == "locon"
|
||||||
|
|
||||||
|
def test_priority_2_model_type_field(self):
|
||||||
|
"""Priority 2: entry['model_type'] as fallback."""
|
||||||
|
entry = {
|
||||||
|
"model_type": "checkpoint",
|
||||||
|
"civitai": {"model": {"type": "dora"}}, # Should be ignored
|
||||||
|
}
|
||||||
|
assert resolve_sub_type(entry) == "checkpoint"
|
||||||
|
|
||||||
|
def test_priority_3_civitai_model_type(self):
|
||||||
|
"""Priority 3: civitai.model.type as fallback."""
|
||||||
|
entry = {
|
||||||
|
"civitai": {"model": {"type": "dora"}},
|
||||||
|
}
|
||||||
|
assert resolve_sub_type(entry) == "dora"
|
||||||
|
|
||||||
|
def test_priority_4_default(self):
|
||||||
|
"""Priority 4: default to LORA when nothing found."""
|
||||||
|
entry = {}
|
||||||
|
assert resolve_sub_type(entry) == "LORA"
|
||||||
|
|
||||||
|
def test_empty_sub_type_falls_back(self):
|
||||||
|
"""Empty sub_type should fall back to model_type."""
|
||||||
|
entry = {
|
||||||
|
"sub_type": "",
|
||||||
|
"model_type": "checkpoint",
|
||||||
|
}
|
||||||
|
assert resolve_sub_type(entry) == "checkpoint"
|
||||||
|
|
||||||
|
def test_whitespace_sub_type_falls_back(self):
|
||||||
|
"""Whitespace sub_type should fall back to model_type."""
|
||||||
|
entry = {
|
||||||
|
"sub_type": " ",
|
||||||
|
"model_type": "checkpoint",
|
||||||
|
}
|
||||||
|
assert resolve_sub_type(entry) == "checkpoint"
|
||||||
|
|
||||||
|
def test_none_entry_returns_default(self):
|
||||||
|
"""None entry should return default."""
|
||||||
|
assert resolve_sub_type(None) == "LORA"
|
||||||
|
|
||||||
|
def test_non_mapping_returns_default(self):
|
||||||
|
"""Non-mapping entry should return default."""
|
||||||
|
assert resolve_sub_type("invalid") == "LORA"
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveCivitaiModelTypeAlias:
|
||||||
|
"""Test resolve_civitai_model_type is alias for resolve_sub_type."""
|
||||||
|
|
||||||
|
def test_alias_works_correctly(self):
|
||||||
|
entry = {"sub_type": "locon"}
|
||||||
|
assert resolve_civitai_model_type(entry) == "locon"
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelFilterSetWithSubType:
|
||||||
|
"""Test ModelFilterSet applies model_types filtering correctly."""
|
||||||
|
|
||||||
|
def create_mock_settings(self):
|
||||||
|
class MockSettings:
|
||||||
|
def get(self, key, default=None):
|
||||||
|
return default
|
||||||
|
return MockSettings()
|
||||||
|
|
||||||
|
def test_filter_by_sub_type(self):
|
||||||
|
"""Filter should work with sub_type field."""
|
||||||
|
settings = self.create_mock_settings()
|
||||||
|
filter_set = ModelFilterSet(settings)
|
||||||
|
|
||||||
|
data = [
|
||||||
|
{"sub_type": "lora", "model_name": "Model 1"},
|
||||||
|
{"sub_type": "locon", "model_name": "Model 2"},
|
||||||
|
{"sub_type": "dora", "model_name": "Model 3"},
|
||||||
|
]
|
||||||
|
|
||||||
|
criteria = FilterCriteria(model_types=["lora", "locon"])
|
||||||
|
result = filter_set.apply(data, criteria)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0]["model_name"] == "Model 1"
|
||||||
|
assert result[1]["model_name"] == "Model 2"
|
||||||
|
|
||||||
|
def test_filter_falls_back_to_model_type(self):
|
||||||
|
"""Filter should fall back to model_type field."""
|
||||||
|
settings = self.create_mock_settings()
|
||||||
|
filter_set = ModelFilterSet(settings)
|
||||||
|
|
||||||
|
data = [
|
||||||
|
{"model_type": "lora", "model_name": "Model 1"}, # Old field
|
||||||
|
{"sub_type": "locon", "model_name": "Model 2"}, # New field
|
||||||
|
]
|
||||||
|
|
||||||
|
criteria = FilterCriteria(model_types=["lora", "locon"])
|
||||||
|
result = filter_set.apply(data, criteria)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
|
||||||
|
def test_filter_uses_civitai_type(self):
|
||||||
|
"""Filter should use civitai.model.type as last resort."""
|
||||||
|
settings = self.create_mock_settings()
|
||||||
|
filter_set = ModelFilterSet(settings)
|
||||||
|
|
||||||
|
data = [
|
||||||
|
{"civitai": {"model": {"type": "dora"}}, "model_name": "Model 1"},
|
||||||
|
]
|
||||||
|
|
||||||
|
criteria = FilterCriteria(model_types=["dora"])
|
||||||
|
result = filter_set.apply(data, criteria)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
def test_filter_case_insensitive(self):
|
||||||
|
"""Filter should be case insensitive."""
|
||||||
|
settings = self.create_mock_settings()
|
||||||
|
filter_set = ModelFilterSet(settings)
|
||||||
|
|
||||||
|
data = [
|
||||||
|
{"sub_type": "LoRA", "model_name": "Model 1"},
|
||||||
|
]
|
||||||
|
|
||||||
|
criteria = FilterCriteria(model_types=["lora"])
|
||||||
|
result = filter_set.apply(data, criteria)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
def test_filter_no_match_returns_empty(self):
|
||||||
|
"""Filter with no match should return empty list."""
|
||||||
|
settings = self.create_mock_settings()
|
||||||
|
filter_set = ModelFilterSet(settings)
|
||||||
|
|
||||||
|
data = [
|
||||||
|
{"sub_type": "lora", "model_name": "Model 1"},
|
||||||
|
]
|
||||||
|
|
||||||
|
criteria = FilterCriteria(model_types=["checkpoint"])
|
||||||
|
result = filter_set.apply(data, criteria)
|
||||||
|
|
||||||
|
assert len(result) == 0
|
||||||
230
tests/services/test_service_format_response_sub_type.py
Normal file
230
tests/services/test_service_format_response_sub_type.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""Tests for service format_response sub_type inclusion."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, AsyncMock
|
||||||
|
|
||||||
|
from py.services.lora_service import LoraService
|
||||||
|
from py.services.checkpoint_service import CheckpointService
|
||||||
|
from py.services.embedding_service import EmbeddingService
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoraServiceFormatResponse:
|
||||||
|
"""Test LoraService.format_response includes sub_type."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_scanner(self):
|
||||||
|
scanner = MagicMock()
|
||||||
|
scanner._hash_index = MagicMock()
|
||||||
|
return scanner
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def lora_service(self, mock_scanner):
|
||||||
|
return LoraService(mock_scanner)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_response_includes_sub_type(self, lora_service):
|
||||||
|
"""format_response should include sub_type field."""
|
||||||
|
lora_data = {
|
||||||
|
"model_name": "Test LoRA",
|
||||||
|
"file_name": "test_lora",
|
||||||
|
"preview_url": "test.webp",
|
||||||
|
"preview_nsfw_level": 0,
|
||||||
|
"base_model": "SDXL",
|
||||||
|
"folder": "",
|
||||||
|
"sha256": "abc123",
|
||||||
|
"file_path": "/models/test_lora.safetensors",
|
||||||
|
"size": 1000,
|
||||||
|
"modified": 1234567890.0,
|
||||||
|
"tags": [],
|
||||||
|
"from_civitai": True,
|
||||||
|
"usage_count": 0,
|
||||||
|
"usage_tips": "",
|
||||||
|
"notes": "",
|
||||||
|
"favorite": False,
|
||||||
|
"sub_type": "locon", # New field
|
||||||
|
"civitai": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await lora_service.format_response(lora_data)
|
||||||
|
|
||||||
|
assert "sub_type" in result
|
||||||
|
assert result["sub_type"] == "locon"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_response_falls_back_to_model_type(self, lora_service):
|
||||||
|
"""format_response should fall back to model_type if sub_type missing."""
|
||||||
|
lora_data = {
|
||||||
|
"model_name": "Test LoRA",
|
||||||
|
"file_name": "test_lora",
|
||||||
|
"preview_url": "test.webp",
|
||||||
|
"preview_nsfw_level": 0,
|
||||||
|
"base_model": "SDXL",
|
||||||
|
"folder": "",
|
||||||
|
"sha256": "abc123",
|
||||||
|
"file_path": "/models/test_lora.safetensors",
|
||||||
|
"size": 1000,
|
||||||
|
"modified": 1234567890.0,
|
||||||
|
"tags": [],
|
||||||
|
"from_civitai": True,
|
||||||
|
"model_type": "dora", # Old field
|
||||||
|
"civitai": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await lora_service.format_response(lora_data)
|
||||||
|
|
||||||
|
assert result["sub_type"] == "dora"
|
||||||
|
assert result["model_type"] == "dora" # Both should be set
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_response_defaults_to_lora(self, lora_service):
|
||||||
|
"""format_response should default to 'lora' if no type field."""
|
||||||
|
lora_data = {
|
||||||
|
"model_name": "Test LoRA",
|
||||||
|
"file_name": "test_lora",
|
||||||
|
"preview_url": "test.webp",
|
||||||
|
"preview_nsfw_level": 0,
|
||||||
|
"base_model": "SDXL",
|
||||||
|
"folder": "",
|
||||||
|
"sha256": "abc123",
|
||||||
|
"file_path": "/models/test_lora.safetensors",
|
||||||
|
"size": 1000,
|
||||||
|
"modified": 1234567890.0,
|
||||||
|
"tags": [],
|
||||||
|
"from_civitai": True,
|
||||||
|
"civitai": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await lora_service.format_response(lora_data)
|
||||||
|
|
||||||
|
assert result["sub_type"] == "lora"
|
||||||
|
assert result["model_type"] == "lora"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointServiceFormatResponse:
|
||||||
|
"""Test CheckpointService.format_response includes sub_type."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_scanner(self):
|
||||||
|
scanner = MagicMock()
|
||||||
|
scanner._hash_index = MagicMock()
|
||||||
|
return scanner
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def checkpoint_service(self, mock_scanner):
|
||||||
|
return CheckpointService(mock_scanner)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_response_includes_sub_type_checkpoint(self, checkpoint_service):
|
||||||
|
"""format_response should include sub_type field for checkpoint."""
|
||||||
|
checkpoint_data = {
|
||||||
|
"model_name": "Test Checkpoint",
|
||||||
|
"file_name": "test_ckpt",
|
||||||
|
"preview_url": "test.webp",
|
||||||
|
"preview_nsfw_level": 0,
|
||||||
|
"base_model": "SDXL",
|
||||||
|
"folder": "",
|
||||||
|
"sha256": "abc123",
|
||||||
|
"file_path": "/models/test.safetensors",
|
||||||
|
"size": 1000,
|
||||||
|
"modified": 1234567890.0,
|
||||||
|
"tags": [],
|
||||||
|
"from_civitai": True,
|
||||||
|
"sub_type": "checkpoint",
|
||||||
|
"civitai": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await checkpoint_service.format_response(checkpoint_data)
|
||||||
|
|
||||||
|
assert "sub_type" in result
|
||||||
|
assert result["sub_type"] == "checkpoint"
|
||||||
|
assert result["model_type"] == "checkpoint"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_response_includes_sub_type_diffusion_model(self, checkpoint_service):
|
||||||
|
"""format_response should include sub_type field for diffusion_model."""
|
||||||
|
checkpoint_data = {
|
||||||
|
"model_name": "Test Diffusion Model",
|
||||||
|
"file_name": "test_unet",
|
||||||
|
"preview_url": "test.webp",
|
||||||
|
"preview_nsfw_level": 0,
|
||||||
|
"base_model": "SDXL",
|
||||||
|
"folder": "",
|
||||||
|
"sha256": "abc123",
|
||||||
|
"file_path": "/models/test.safetensors",
|
||||||
|
"size": 1000,
|
||||||
|
"modified": 1234567890.0,
|
||||||
|
"tags": [],
|
||||||
|
"from_civitai": True,
|
||||||
|
"sub_type": "diffusion_model",
|
||||||
|
"civitai": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await checkpoint_service.format_response(checkpoint_data)
|
||||||
|
|
||||||
|
assert result["sub_type"] == "diffusion_model"
|
||||||
|
assert result["model_type"] == "diffusion_model"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingServiceFormatResponse:
|
||||||
|
"""Test EmbeddingService.format_response includes sub_type."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_scanner(self):
|
||||||
|
scanner = MagicMock()
|
||||||
|
scanner._hash_index = MagicMock()
|
||||||
|
return scanner
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def embedding_service(self, mock_scanner):
|
||||||
|
return EmbeddingService(mock_scanner)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_response_includes_sub_type(self, embedding_service):
|
||||||
|
"""format_response should include sub_type field."""
|
||||||
|
embedding_data = {
|
||||||
|
"model_name": "Test Embedding",
|
||||||
|
"file_name": "test_emb",
|
||||||
|
"preview_url": "test.webp",
|
||||||
|
"preview_nsfw_level": 0,
|
||||||
|
"base_model": "SD1.5",
|
||||||
|
"folder": "",
|
||||||
|
"sha256": "abc123",
|
||||||
|
"file_path": "/models/test.pt",
|
||||||
|
"size": 1000,
|
||||||
|
"modified": 1234567890.0,
|
||||||
|
"tags": [],
|
||||||
|
"from_civitai": True,
|
||||||
|
"sub_type": "embedding",
|
||||||
|
"civitai": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await embedding_service.format_response(embedding_data)
|
||||||
|
|
||||||
|
assert "sub_type" in result
|
||||||
|
assert result["sub_type"] == "embedding"
|
||||||
|
assert result["model_type"] == "embedding"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_response_defaults_to_embedding(self, embedding_service):
|
||||||
|
"""format_response should default to 'embedding' if no type field."""
|
||||||
|
embedding_data = {
|
||||||
|
"model_name": "Test Embedding",
|
||||||
|
"file_name": "test_emb",
|
||||||
|
"preview_url": "test.webp",
|
||||||
|
"preview_nsfw_level": 0,
|
||||||
|
"base_model": "SD1.5",
|
||||||
|
"folder": "",
|
||||||
|
"sha256": "abc123",
|
||||||
|
"file_path": "/models/test.pt",
|
||||||
|
"size": 1000,
|
||||||
|
"modified": 1234567890.0,
|
||||||
|
"tags": [],
|
||||||
|
"from_civitai": True,
|
||||||
|
"civitai": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await embedding_service.format_response(embedding_data)
|
||||||
|
|
||||||
|
assert result["sub_type"] == "embedding"
|
||||||
|
assert result["model_type"] == "embedding"
|
||||||
127
tests/utils/test_models_sub_type.py
Normal file
127
tests/utils/test_models_sub_type.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""Tests for model sub_type field refactoring."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from py.utils.models import (
|
||||||
|
BaseModelMetadata,
|
||||||
|
LoraMetadata,
|
||||||
|
CheckpointMetadata,
|
||||||
|
EmbeddingMetadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointMetadataSubType:
|
||||||
|
"""Test CheckpointMetadata uses sub_type field."""
|
||||||
|
|
||||||
|
def test_checkpoint_has_sub_type_field(self):
|
||||||
|
"""CheckpointMetadata should have sub_type field."""
|
||||||
|
metadata = CheckpointMetadata(
|
||||||
|
file_name="test",
|
||||||
|
model_name="Test Model",
|
||||||
|
file_path="/test/model.safetensors",
|
||||||
|
size=1000,
|
||||||
|
modified=1234567890.0,
|
||||||
|
sha256="abc123",
|
||||||
|
base_model="SDXL",
|
||||||
|
preview_url="",
|
||||||
|
)
|
||||||
|
assert hasattr(metadata, "sub_type")
|
||||||
|
assert metadata.sub_type == "checkpoint"
|
||||||
|
|
||||||
|
def test_checkpoint_sub_type_can_be_diffusion_model(self):
|
||||||
|
"""CheckpointMetadata sub_type can be set to diffusion_model."""
|
||||||
|
metadata = CheckpointMetadata(
|
||||||
|
file_name="test",
|
||||||
|
model_name="Test Model",
|
||||||
|
file_path="/test/model.safetensors",
|
||||||
|
size=1000,
|
||||||
|
modified=1234567890.0,
|
||||||
|
sha256="abc123",
|
||||||
|
base_model="SDXL",
|
||||||
|
preview_url="",
|
||||||
|
sub_type="diffusion_model",
|
||||||
|
)
|
||||||
|
assert metadata.sub_type == "diffusion_model"
|
||||||
|
|
||||||
|
def test_checkpoint_from_civitai_info_uses_sub_type(self):
|
||||||
|
"""from_civitai_info should use sub_type from version_info."""
|
||||||
|
version_info = {
|
||||||
|
"baseModel": "SDXL",
|
||||||
|
"model": {"name": "Test", "description": "", "tags": []},
|
||||||
|
"files": [{"name": "model.safetensors", "sizeKB": 1000, "hashes": {"SHA256": "abc123"}, "primary": True}],
|
||||||
|
}
|
||||||
|
file_info = version_info["files"][0]
|
||||||
|
save_path = "/test/model.safetensors"
|
||||||
|
|
||||||
|
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||||
|
|
||||||
|
assert hasattr(metadata, "sub_type")
|
||||||
|
# When type is missing from version_info, defaults to "checkpoint"
|
||||||
|
assert metadata.sub_type == "checkpoint"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingMetadataSubType:
|
||||||
|
"""Test EmbeddingMetadata uses sub_type field."""
|
||||||
|
|
||||||
|
def test_embedding_has_sub_type_field(self):
|
||||||
|
"""EmbeddingMetadata should have sub_type field."""
|
||||||
|
metadata = EmbeddingMetadata(
|
||||||
|
file_name="test",
|
||||||
|
model_name="Test Model",
|
||||||
|
file_path="/test/model.pt",
|
||||||
|
size=1000,
|
||||||
|
modified=1234567890.0,
|
||||||
|
sha256="abc123",
|
||||||
|
base_model="SD1.5",
|
||||||
|
preview_url="",
|
||||||
|
)
|
||||||
|
assert hasattr(metadata, "sub_type")
|
||||||
|
assert metadata.sub_type == "embedding"
|
||||||
|
|
||||||
|
def test_embedding_from_civitai_info_uses_sub_type(self):
|
||||||
|
"""from_civitai_info should use sub_type from version_info."""
|
||||||
|
version_info = {
|
||||||
|
"baseModel": "SD1.5",
|
||||||
|
"model": {"name": "Test", "description": "", "tags": []},
|
||||||
|
"files": [{"name": "model.pt", "sizeKB": 1000, "hashes": {"SHA256": "abc123"}, "primary": True}],
|
||||||
|
}
|
||||||
|
file_info = version_info["files"][0]
|
||||||
|
save_path = "/test/model.pt"
|
||||||
|
|
||||||
|
metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||||
|
|
||||||
|
assert hasattr(metadata, "sub_type")
|
||||||
|
assert metadata.sub_type == "embedding"
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoraMetadataConsistency:
|
||||||
|
"""Test LoraMetadata consistency (no sub_type field, uses civitai data)."""
|
||||||
|
|
||||||
|
def test_lora_does_not_have_sub_type_field(self):
|
||||||
|
"""LoraMetadata should not have sub_type field (uses civitai.model.type)."""
|
||||||
|
metadata = LoraMetadata(
|
||||||
|
file_name="test",
|
||||||
|
model_name="Test Model",
|
||||||
|
file_path="/test/model.safetensors",
|
||||||
|
size=1000,
|
||||||
|
modified=1234567890.0,
|
||||||
|
sha256="abc123",
|
||||||
|
base_model="SDXL",
|
||||||
|
preview_url="",
|
||||||
|
)
|
||||||
|
# Lora doesn't have sub_type field - it uses civitai data
|
||||||
|
assert not hasattr(metadata, "sub_type")
|
||||||
|
|
||||||
|
def test_lora_from_civitai_info_extracts_type(self):
|
||||||
|
"""from_civitai_info should extract type from civitai data."""
|
||||||
|
version_info = {
|
||||||
|
"baseModel": "SDXL",
|
||||||
|
"model": {"name": "Test", "description": "", "tags": [], "type": "Lora"},
|
||||||
|
"files": [{"name": "model.safetensors", "sizeKB": 1000, "hashes": {"SHA256": "abc123"}, "primary": True}],
|
||||||
|
}
|
||||||
|
file_info = version_info["files"][0]
|
||||||
|
save_path = "/test/model.safetensors"
|
||||||
|
|
||||||
|
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||||
|
|
||||||
|
# Type is stored in civitai dict
|
||||||
|
assert metadata.civitai.get("model", {}).get("type") == "Lora"
|
||||||
Reference in New Issue
Block a user