mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
refactor: unify model_type semantics by introducing sub_type field
This commit resolves the semantic confusion around the model_type field by clearly distinguishing between: - scanner_type: architecture-level (lora/checkpoint/embedding) - sub_type: business-level subtype (lora/locon/dora/checkpoint/diffusion_model/embedding) Backend Changes: - Rename model_type to sub_type in CheckpointMetadata and EmbeddingMetadata - Add resolve_sub_type() and normalize_sub_type() in model_query.py - Update checkpoint_scanner to use _resolve_sub_type() - Update service format_response to include both sub_type and model_type - Add VALID_*_SUB_TYPES constants with backward compatible aliases Frontend Changes: - Add MODEL_SUBTYPE_DISPLAY_NAMES constants - Keep MODEL_TYPE_DISPLAY_NAMES as backward compatible alias Testing: - Add 43 new tests covering sub_type resolution and API response Documentation: - Add refactoring todo document to docs/technical/ BREAKING CHANGE: None - full backward compatibility maintained
This commit is contained in:
145
tests/services/test_checkpoint_scanner_sub_type.py
Normal file
145
tests/services/test_checkpoint_scanner_sub_type.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Tests for CheckpointScanner sub_type resolution."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from py.services.checkpoint_scanner import CheckpointScanner
|
||||
from py.utils.models import CheckpointMetadata
|
||||
|
||||
|
||||
class TestCheckpointScannerSubType:
|
||||
"""Test CheckpointScanner sub_type resolution logic."""
|
||||
|
||||
def create_scanner(self):
|
||||
"""Create scanner with no async initialization."""
|
||||
# Create scanner without calling __init__ to avoid async issues
|
||||
scanner = object.__new__(CheckpointScanner)
|
||||
scanner.model_type = "checkpoint"
|
||||
scanner.model_class = CheckpointMetadata
|
||||
scanner.file_extensions = {'.ckpt', '.safetensors'}
|
||||
scanner._hash_index = MagicMock()
|
||||
return scanner
|
||||
|
||||
def test_resolve_sub_type_checkpoint_root(self):
|
||||
"""_resolve_sub_type should return 'checkpoint' for checkpoints_roots."""
|
||||
scanner = self.create_scanner()
|
||||
|
||||
from py import config as config_module
|
||||
original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None)
|
||||
original_unet_roots = getattr(config_module.config, 'unet_roots', None)
|
||||
|
||||
try:
|
||||
config_module.config.checkpoints_roots = ["/models/checkpoints"]
|
||||
config_module.config.unet_roots = ["/models/unet"]
|
||||
|
||||
result = scanner._resolve_sub_type("/models/checkpoints")
|
||||
assert result == "checkpoint"
|
||||
finally:
|
||||
if original_checkpoints_roots is not None:
|
||||
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||
if original_unet_roots is not None:
|
||||
config_module.config.unet_roots = original_unet_roots
|
||||
|
||||
def test_resolve_sub_type_unet_root(self):
|
||||
"""_resolve_sub_type should return 'diffusion_model' for unet_roots."""
|
||||
scanner = self.create_scanner()
|
||||
|
||||
from py import config as config_module
|
||||
original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None)
|
||||
original_unet_roots = getattr(config_module.config, 'unet_roots', None)
|
||||
|
||||
try:
|
||||
config_module.config.checkpoints_roots = ["/models/checkpoints"]
|
||||
config_module.config.unet_roots = ["/models/unet"]
|
||||
|
||||
result = scanner._resolve_sub_type("/models/unet")
|
||||
assert result == "diffusion_model"
|
||||
finally:
|
||||
if original_checkpoints_roots is not None:
|
||||
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||
if original_unet_roots is not None:
|
||||
config_module.config.unet_roots = original_unet_roots
|
||||
|
||||
def test_resolve_sub_type_none_root(self):
|
||||
"""_resolve_sub_type should return None for None input."""
|
||||
scanner = self.create_scanner()
|
||||
result = scanner._resolve_sub_type(None)
|
||||
assert result is None
|
||||
|
||||
def test_resolve_sub_type_unknown_root(self):
|
||||
"""_resolve_sub_type should return None for unknown root."""
|
||||
scanner = self.create_scanner()
|
||||
|
||||
from py import config as config_module
|
||||
original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None)
|
||||
original_unet_roots = getattr(config_module.config, 'unet_roots', None)
|
||||
|
||||
try:
|
||||
config_module.config.checkpoints_roots = ["/models/checkpoints"]
|
||||
config_module.config.unet_roots = ["/models/unet"]
|
||||
|
||||
result = scanner._resolve_sub_type("/models/unknown")
|
||||
assert result is None
|
||||
finally:
|
||||
if original_checkpoints_roots is not None:
|
||||
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||
if original_unet_roots is not None:
|
||||
config_module.config.unet_roots = original_unet_roots
|
||||
|
||||
def test_adjust_metadata_sets_sub_type(self):
|
||||
"""adjust_metadata should set sub_type on metadata."""
|
||||
scanner = self.create_scanner()
|
||||
|
||||
metadata = CheckpointMetadata(
|
||||
file_name="test",
|
||||
model_name="Test",
|
||||
file_path="/models/checkpoints/model.safetensors",
|
||||
size=1000,
|
||||
modified=1234567890.0,
|
||||
sha256="abc123",
|
||||
base_model="SDXL",
|
||||
preview_url="",
|
||||
)
|
||||
|
||||
from py import config as config_module
|
||||
original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None)
|
||||
|
||||
try:
|
||||
config_module.config.checkpoints_roots = ["/models/checkpoints"]
|
||||
config_module.config.unet_roots = []
|
||||
|
||||
result = scanner.adjust_metadata(metadata, "/models/checkpoints/model.safetensors", "/models/checkpoints")
|
||||
assert result.sub_type == "checkpoint"
|
||||
finally:
|
||||
if original_checkpoints_roots is not None:
|
||||
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||
|
||||
def test_adjust_cached_entry_sets_sub_type(self):
|
||||
"""adjust_cached_entry should set sub_type on entry."""
|
||||
scanner = self.create_scanner()
|
||||
# Mock get_model_roots to return the expected roots
|
||||
scanner.get_model_roots = lambda: ["/models/unet"]
|
||||
|
||||
entry = {
|
||||
"file_path": "/models/unet/model.safetensors",
|
||||
"model_name": "Test",
|
||||
}
|
||||
|
||||
from py import config as config_module
|
||||
original_checkpoints_roots = getattr(config_module.config, 'checkpoints_roots', None)
|
||||
original_unet_roots = getattr(config_module.config, 'unet_roots', None)
|
||||
|
||||
try:
|
||||
config_module.config.checkpoints_roots = []
|
||||
config_module.config.unet_roots = ["/models/unet"]
|
||||
|
||||
result = scanner.adjust_cached_entry(entry)
|
||||
assert result["sub_type"] == "diffusion_model"
|
||||
# Also sets model_type for backward compatibility
|
||||
assert result["model_type"] == "diffusion_model"
|
||||
finally:
|
||||
if original_checkpoints_roots is not None:
|
||||
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||
if original_unet_roots is not None:
|
||||
config_module.config.unet_roots = original_unet_roots
|
||||
203
tests/services/test_model_query_sub_type.py
Normal file
203
tests/services/test_model_query_sub_type.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Tests for model_query sub_type resolution."""
|
||||
|
||||
import pytest
|
||||
from py.services.model_query import (
|
||||
_coerce_to_str,
|
||||
normalize_sub_type,
|
||||
normalize_civitai_model_type,
|
||||
resolve_sub_type,
|
||||
resolve_civitai_model_type,
|
||||
FilterCriteria,
|
||||
ModelFilterSet,
|
||||
)
|
||||
|
||||
|
||||
class TestCoerceToStr:
|
||||
"""Test _coerce_to_str helper."""
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert _coerce_to_str(None) is None
|
||||
|
||||
def test_string_returns_stripped(self):
|
||||
assert _coerce_to_str(" test ") == "test"
|
||||
|
||||
def test_empty_returns_none(self):
|
||||
assert _coerce_to_str(" ") is None
|
||||
|
||||
def test_number_converts_to_str(self):
|
||||
assert _coerce_to_str(123) == "123"
|
||||
|
||||
|
||||
class TestNormalizeSubType:
|
||||
"""Test normalize_sub_type function."""
|
||||
|
||||
def test_normalizes_to_lowercase(self):
|
||||
assert normalize_sub_type("LoRA") == "lora"
|
||||
assert normalize_sub_type("CHECKPOINT") == "checkpoint"
|
||||
|
||||
def test_strips_whitespace(self):
|
||||
assert normalize_sub_type(" LoRA ") == "lora"
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert normalize_sub_type(None) is None
|
||||
|
||||
def test_empty_returns_none(self):
|
||||
assert normalize_sub_type("") is None
|
||||
|
||||
|
||||
class TestNormalizeCivitaiModelTypeAlias:
|
||||
"""Test normalize_civitai_model_type is alias for normalize_sub_type."""
|
||||
|
||||
def test_alias_works_correctly(self):
|
||||
assert normalize_civitai_model_type("LoRA") == "lora"
|
||||
assert normalize_civitai_model_type("CHECKPOINT") == "checkpoint"
|
||||
|
||||
|
||||
class TestResolveSubType:
|
||||
"""Test resolve_sub_type function priority."""
|
||||
|
||||
def test_priority_1_sub_type_field(self):
|
||||
"""Priority 1: entry['sub_type'] should be used first."""
|
||||
entry = {
|
||||
"sub_type": "locon",
|
||||
"model_type": "checkpoint", # Should be ignored
|
||||
"civitai": {"model": {"type": "dora"}}, # Should be ignored
|
||||
}
|
||||
assert resolve_sub_type(entry) == "locon"
|
||||
|
||||
def test_priority_2_model_type_field(self):
|
||||
"""Priority 2: entry['model_type'] as fallback."""
|
||||
entry = {
|
||||
"model_type": "checkpoint",
|
||||
"civitai": {"model": {"type": "dora"}}, # Should be ignored
|
||||
}
|
||||
assert resolve_sub_type(entry) == "checkpoint"
|
||||
|
||||
def test_priority_3_civitai_model_type(self):
|
||||
"""Priority 3: civitai.model.type as fallback."""
|
||||
entry = {
|
||||
"civitai": {"model": {"type": "dora"}},
|
||||
}
|
||||
assert resolve_sub_type(entry) == "dora"
|
||||
|
||||
def test_priority_4_default(self):
|
||||
"""Priority 4: default to LORA when nothing found."""
|
||||
entry = {}
|
||||
assert resolve_sub_type(entry) == "LORA"
|
||||
|
||||
def test_empty_sub_type_falls_back(self):
|
||||
"""Empty sub_type should fall back to model_type."""
|
||||
entry = {
|
||||
"sub_type": "",
|
||||
"model_type": "checkpoint",
|
||||
}
|
||||
assert resolve_sub_type(entry) == "checkpoint"
|
||||
|
||||
def test_whitespace_sub_type_falls_back(self):
|
||||
"""Whitespace sub_type should fall back to model_type."""
|
||||
entry = {
|
||||
"sub_type": " ",
|
||||
"model_type": "checkpoint",
|
||||
}
|
||||
assert resolve_sub_type(entry) == "checkpoint"
|
||||
|
||||
def test_none_entry_returns_default(self):
|
||||
"""None entry should return default."""
|
||||
assert resolve_sub_type(None) == "LORA"
|
||||
|
||||
def test_non_mapping_returns_default(self):
|
||||
"""Non-mapping entry should return default."""
|
||||
assert resolve_sub_type("invalid") == "LORA"
|
||||
|
||||
|
||||
class TestResolveCivitaiModelTypeAlias:
|
||||
"""Test resolve_civitai_model_type is alias for resolve_sub_type."""
|
||||
|
||||
def test_alias_works_correctly(self):
|
||||
entry = {"sub_type": "locon"}
|
||||
assert resolve_civitai_model_type(entry) == "locon"
|
||||
|
||||
|
||||
class TestModelFilterSetWithSubType:
|
||||
"""Test ModelFilterSet applies model_types filtering correctly."""
|
||||
|
||||
def create_mock_settings(self):
|
||||
class MockSettings:
|
||||
def get(self, key, default=None):
|
||||
return default
|
||||
return MockSettings()
|
||||
|
||||
def test_filter_by_sub_type(self):
|
||||
"""Filter should work with sub_type field."""
|
||||
settings = self.create_mock_settings()
|
||||
filter_set = ModelFilterSet(settings)
|
||||
|
||||
data = [
|
||||
{"sub_type": "lora", "model_name": "Model 1"},
|
||||
{"sub_type": "locon", "model_name": "Model 2"},
|
||||
{"sub_type": "dora", "model_name": "Model 3"},
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(model_types=["lora", "locon"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["model_name"] == "Model 1"
|
||||
assert result[1]["model_name"] == "Model 2"
|
||||
|
||||
def test_filter_falls_back_to_model_type(self):
|
||||
"""Filter should fall back to model_type field."""
|
||||
settings = self.create_mock_settings()
|
||||
filter_set = ModelFilterSet(settings)
|
||||
|
||||
data = [
|
||||
{"model_type": "lora", "model_name": "Model 1"}, # Old field
|
||||
{"sub_type": "locon", "model_name": "Model 2"}, # New field
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(model_types=["lora", "locon"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
def test_filter_uses_civitai_type(self):
|
||||
"""Filter should use civitai.model.type as last resort."""
|
||||
settings = self.create_mock_settings()
|
||||
filter_set = ModelFilterSet(settings)
|
||||
|
||||
data = [
|
||||
{"civitai": {"model": {"type": "dora"}}, "model_name": "Model 1"},
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(model_types=["dora"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_filter_case_insensitive(self):
|
||||
"""Filter should be case insensitive."""
|
||||
settings = self.create_mock_settings()
|
||||
filter_set = ModelFilterSet(settings)
|
||||
|
||||
data = [
|
||||
{"sub_type": "LoRA", "model_name": "Model 1"},
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(model_types=["lora"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_filter_no_match_returns_empty(self):
|
||||
"""Filter with no match should return empty list."""
|
||||
settings = self.create_mock_settings()
|
||||
filter_set = ModelFilterSet(settings)
|
||||
|
||||
data = [
|
||||
{"sub_type": "lora", "model_name": "Model 1"},
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(model_types=["checkpoint"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert len(result) == 0
|
||||
230
tests/services/test_service_format_response_sub_type.py
Normal file
230
tests/services/test_service_format_response_sub_type.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""Tests for service format_response sub_type inclusion."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from py.services.lora_service import LoraService
|
||||
from py.services.checkpoint_service import CheckpointService
|
||||
from py.services.embedding_service import EmbeddingService
|
||||
|
||||
|
||||
class TestLoraServiceFormatResponse:
|
||||
"""Test LoraService.format_response includes sub_type."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scanner(self):
|
||||
scanner = MagicMock()
|
||||
scanner._hash_index = MagicMock()
|
||||
return scanner
|
||||
|
||||
@pytest.fixture
|
||||
def lora_service(self, mock_scanner):
|
||||
return LoraService(mock_scanner)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_includes_sub_type(self, lora_service):
|
||||
"""format_response should include sub_type field."""
|
||||
lora_data = {
|
||||
"model_name": "Test LoRA",
|
||||
"file_name": "test_lora",
|
||||
"preview_url": "test.webp",
|
||||
"preview_nsfw_level": 0,
|
||||
"base_model": "SDXL",
|
||||
"folder": "",
|
||||
"sha256": "abc123",
|
||||
"file_path": "/models/test_lora.safetensors",
|
||||
"size": 1000,
|
||||
"modified": 1234567890.0,
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"usage_count": 0,
|
||||
"usage_tips": "",
|
||||
"notes": "",
|
||||
"favorite": False,
|
||||
"sub_type": "locon", # New field
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
result = await lora_service.format_response(lora_data)
|
||||
|
||||
assert "sub_type" in result
|
||||
assert result["sub_type"] == "locon"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_falls_back_to_model_type(self, lora_service):
|
||||
"""format_response should fall back to model_type if sub_type missing."""
|
||||
lora_data = {
|
||||
"model_name": "Test LoRA",
|
||||
"file_name": "test_lora",
|
||||
"preview_url": "test.webp",
|
||||
"preview_nsfw_level": 0,
|
||||
"base_model": "SDXL",
|
||||
"folder": "",
|
||||
"sha256": "abc123",
|
||||
"file_path": "/models/test_lora.safetensors",
|
||||
"size": 1000,
|
||||
"modified": 1234567890.0,
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"model_type": "dora", # Old field
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
result = await lora_service.format_response(lora_data)
|
||||
|
||||
assert result["sub_type"] == "dora"
|
||||
assert result["model_type"] == "dora" # Both should be set
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_defaults_to_lora(self, lora_service):
|
||||
"""format_response should default to 'lora' if no type field."""
|
||||
lora_data = {
|
||||
"model_name": "Test LoRA",
|
||||
"file_name": "test_lora",
|
||||
"preview_url": "test.webp",
|
||||
"preview_nsfw_level": 0,
|
||||
"base_model": "SDXL",
|
||||
"folder": "",
|
||||
"sha256": "abc123",
|
||||
"file_path": "/models/test_lora.safetensors",
|
||||
"size": 1000,
|
||||
"modified": 1234567890.0,
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
result = await lora_service.format_response(lora_data)
|
||||
|
||||
assert result["sub_type"] == "lora"
|
||||
assert result["model_type"] == "lora"
|
||||
|
||||
|
||||
class TestCheckpointServiceFormatResponse:
|
||||
"""Test CheckpointService.format_response includes sub_type."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scanner(self):
|
||||
scanner = MagicMock()
|
||||
scanner._hash_index = MagicMock()
|
||||
return scanner
|
||||
|
||||
@pytest.fixture
|
||||
def checkpoint_service(self, mock_scanner):
|
||||
return CheckpointService(mock_scanner)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_includes_sub_type_checkpoint(self, checkpoint_service):
|
||||
"""format_response should include sub_type field for checkpoint."""
|
||||
checkpoint_data = {
|
||||
"model_name": "Test Checkpoint",
|
||||
"file_name": "test_ckpt",
|
||||
"preview_url": "test.webp",
|
||||
"preview_nsfw_level": 0,
|
||||
"base_model": "SDXL",
|
||||
"folder": "",
|
||||
"sha256": "abc123",
|
||||
"file_path": "/models/test.safetensors",
|
||||
"size": 1000,
|
||||
"modified": 1234567890.0,
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"sub_type": "checkpoint",
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
result = await checkpoint_service.format_response(checkpoint_data)
|
||||
|
||||
assert "sub_type" in result
|
||||
assert result["sub_type"] == "checkpoint"
|
||||
assert result["model_type"] == "checkpoint"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_includes_sub_type_diffusion_model(self, checkpoint_service):
|
||||
"""format_response should include sub_type field for diffusion_model."""
|
||||
checkpoint_data = {
|
||||
"model_name": "Test Diffusion Model",
|
||||
"file_name": "test_unet",
|
||||
"preview_url": "test.webp",
|
||||
"preview_nsfw_level": 0,
|
||||
"base_model": "SDXL",
|
||||
"folder": "",
|
||||
"sha256": "abc123",
|
||||
"file_path": "/models/test.safetensors",
|
||||
"size": 1000,
|
||||
"modified": 1234567890.0,
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"sub_type": "diffusion_model",
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
result = await checkpoint_service.format_response(checkpoint_data)
|
||||
|
||||
assert result["sub_type"] == "diffusion_model"
|
||||
assert result["model_type"] == "diffusion_model"
|
||||
|
||||
|
||||
class TestEmbeddingServiceFormatResponse:
|
||||
"""Test EmbeddingService.format_response includes sub_type."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scanner(self):
|
||||
scanner = MagicMock()
|
||||
scanner._hash_index = MagicMock()
|
||||
return scanner
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_service(self, mock_scanner):
|
||||
return EmbeddingService(mock_scanner)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_includes_sub_type(self, embedding_service):
|
||||
"""format_response should include sub_type field."""
|
||||
embedding_data = {
|
||||
"model_name": "Test Embedding",
|
||||
"file_name": "test_emb",
|
||||
"preview_url": "test.webp",
|
||||
"preview_nsfw_level": 0,
|
||||
"base_model": "SD1.5",
|
||||
"folder": "",
|
||||
"sha256": "abc123",
|
||||
"file_path": "/models/test.pt",
|
||||
"size": 1000,
|
||||
"modified": 1234567890.0,
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"sub_type": "embedding",
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
result = await embedding_service.format_response(embedding_data)
|
||||
|
||||
assert "sub_type" in result
|
||||
assert result["sub_type"] == "embedding"
|
||||
assert result["model_type"] == "embedding"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_defaults_to_embedding(self, embedding_service):
|
||||
"""format_response should default to 'embedding' if no type field."""
|
||||
embedding_data = {
|
||||
"model_name": "Test Embedding",
|
||||
"file_name": "test_emb",
|
||||
"preview_url": "test.webp",
|
||||
"preview_nsfw_level": 0,
|
||||
"base_model": "SD1.5",
|
||||
"folder": "",
|
||||
"sha256": "abc123",
|
||||
"file_path": "/models/test.pt",
|
||||
"size": 1000,
|
||||
"modified": 1234567890.0,
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
result = await embedding_service.format_response(embedding_data)
|
||||
|
||||
assert result["sub_type"] == "embedding"
|
||||
assert result["model_type"] == "embedding"
|
||||
127
tests/utils/test_models_sub_type.py
Normal file
127
tests/utils/test_models_sub_type.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Tests for model sub_type field refactoring."""
|
||||
|
||||
import pytest
|
||||
from py.utils.models import (
|
||||
BaseModelMetadata,
|
||||
LoraMetadata,
|
||||
CheckpointMetadata,
|
||||
EmbeddingMetadata,
|
||||
)
|
||||
|
||||
|
||||
class TestCheckpointMetadataSubType:
|
||||
"""Test CheckpointMetadata uses sub_type field."""
|
||||
|
||||
def test_checkpoint_has_sub_type_field(self):
|
||||
"""CheckpointMetadata should have sub_type field."""
|
||||
metadata = CheckpointMetadata(
|
||||
file_name="test",
|
||||
model_name="Test Model",
|
||||
file_path="/test/model.safetensors",
|
||||
size=1000,
|
||||
modified=1234567890.0,
|
||||
sha256="abc123",
|
||||
base_model="SDXL",
|
||||
preview_url="",
|
||||
)
|
||||
assert hasattr(metadata, "sub_type")
|
||||
assert metadata.sub_type == "checkpoint"
|
||||
|
||||
def test_checkpoint_sub_type_can_be_diffusion_model(self):
|
||||
"""CheckpointMetadata sub_type can be set to diffusion_model."""
|
||||
metadata = CheckpointMetadata(
|
||||
file_name="test",
|
||||
model_name="Test Model",
|
||||
file_path="/test/model.safetensors",
|
||||
size=1000,
|
||||
modified=1234567890.0,
|
||||
sha256="abc123",
|
||||
base_model="SDXL",
|
||||
preview_url="",
|
||||
sub_type="diffusion_model",
|
||||
)
|
||||
assert metadata.sub_type == "diffusion_model"
|
||||
|
||||
def test_checkpoint_from_civitai_info_uses_sub_type(self):
|
||||
"""from_civitai_info should use sub_type from version_info."""
|
||||
version_info = {
|
||||
"baseModel": "SDXL",
|
||||
"model": {"name": "Test", "description": "", "tags": []},
|
||||
"files": [{"name": "model.safetensors", "sizeKB": 1000, "hashes": {"SHA256": "abc123"}, "primary": True}],
|
||||
}
|
||||
file_info = version_info["files"][0]
|
||||
save_path = "/test/model.safetensors"
|
||||
|
||||
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
|
||||
assert hasattr(metadata, "sub_type")
|
||||
# When type is missing from version_info, defaults to "checkpoint"
|
||||
assert metadata.sub_type == "checkpoint"
|
||||
|
||||
|
||||
class TestEmbeddingMetadataSubType:
|
||||
"""Test EmbeddingMetadata uses sub_type field."""
|
||||
|
||||
def test_embedding_has_sub_type_field(self):
|
||||
"""EmbeddingMetadata should have sub_type field."""
|
||||
metadata = EmbeddingMetadata(
|
||||
file_name="test",
|
||||
model_name="Test Model",
|
||||
file_path="/test/model.pt",
|
||||
size=1000,
|
||||
modified=1234567890.0,
|
||||
sha256="abc123",
|
||||
base_model="SD1.5",
|
||||
preview_url="",
|
||||
)
|
||||
assert hasattr(metadata, "sub_type")
|
||||
assert metadata.sub_type == "embedding"
|
||||
|
||||
def test_embedding_from_civitai_info_uses_sub_type(self):
|
||||
"""from_civitai_info should use sub_type from version_info."""
|
||||
version_info = {
|
||||
"baseModel": "SD1.5",
|
||||
"model": {"name": "Test", "description": "", "tags": []},
|
||||
"files": [{"name": "model.pt", "sizeKB": 1000, "hashes": {"SHA256": "abc123"}, "primary": True}],
|
||||
}
|
||||
file_info = version_info["files"][0]
|
||||
save_path = "/test/model.pt"
|
||||
|
||||
metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
|
||||
assert hasattr(metadata, "sub_type")
|
||||
assert metadata.sub_type == "embedding"
|
||||
|
||||
|
||||
class TestLoraMetadataConsistency:
|
||||
"""Test LoraMetadata consistency (no sub_type field, uses civitai data)."""
|
||||
|
||||
def test_lora_does_not_have_sub_type_field(self):
|
||||
"""LoraMetadata should not have sub_type field (uses civitai.model.type)."""
|
||||
metadata = LoraMetadata(
|
||||
file_name="test",
|
||||
model_name="Test Model",
|
||||
file_path="/test/model.safetensors",
|
||||
size=1000,
|
||||
modified=1234567890.0,
|
||||
sha256="abc123",
|
||||
base_model="SDXL",
|
||||
preview_url="",
|
||||
)
|
||||
# Lora doesn't have sub_type field - it uses civitai data
|
||||
assert not hasattr(metadata, "sub_type")
|
||||
|
||||
def test_lora_from_civitai_info_extracts_type(self):
|
||||
"""from_civitai_info should extract type from civitai data."""
|
||||
version_info = {
|
||||
"baseModel": "SDXL",
|
||||
"model": {"name": "Test", "description": "", "tags": [], "type": "Lora"},
|
||||
"files": [{"name": "model.safetensors", "sizeKB": 1000, "hashes": {"SHA256": "abc123"}, "primary": True}],
|
||||
}
|
||||
file_info = version_info["files"][0]
|
||||
save_path = "/test/model.safetensors"
|
||||
|
||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
|
||||
# Type is stored in civitai dict
|
||||
assert metadata.civitai.get("model", {}).get("type") == "Lora"
|
||||
Reference in New Issue
Block a user