mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
refactor(model-type): complete phase 5 cleanup by removing deprecated model_type field
- Remove backward compatibility code for `model_type` in `ModelScanner._build_cache_entry()` - Update `CheckpointScanner` to only handle `sub_type` in `adjust_metadata()` and `adjust_cached_entry()` - Delete deprecated aliases `resolve_civitai_model_type` and `normalize_civitai_model_type` from `model_query.py` - Update frontend components (`RecipeModal.js`, `ModelCard.js`, etc.) to use `sub_type` instead of `model_type` - Update API response format to return only `sub_type`, removing `model_type` from service responses - Revise technical documentation to mark Phase 5 as completed and remove outdated TODO items All cleanup tasks for the model type refactoring are now complete, ensuring consistent use of `sub_type` across the codebase.
This commit is contained in:
@@ -126,7 +126,7 @@ async def test_persisted_cache_restores_model_type(tmp_path: Path, monkeypatch):
|
||||
assert loaded is True
|
||||
|
||||
cache = await scanner.get_cached_data()
|
||||
types_by_path = {item["file_path"]: item.get("model_type") for item in cache.raw_data}
|
||||
types_by_path = {item["file_path"]: item.get("sub_type") for item in cache.raw_data}
|
||||
|
||||
assert types_by_path[normalized_checkpoint_file] == "checkpoint"
|
||||
assert types_by_path[normalized_unet_file] == "diffusion_model"
|
||||
|
||||
@@ -136,8 +136,7 @@ class TestCheckpointScannerSubType:
|
||||
|
||||
result = scanner.adjust_cached_entry(entry)
|
||||
assert result["sub_type"] == "diffusion_model"
|
||||
# Also sets model_type for backward compatibility
|
||||
assert result["model_type"] == "diffusion_model"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
finally:
|
||||
if original_checkpoints_roots is not None:
|
||||
config_module.config.checkpoints_roots = original_checkpoints_roots
|
||||
|
||||
@@ -479,7 +479,7 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
||||
assert dummy_scanner.calls # ensure cache updated
|
||||
|
||||
|
||||
async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_path):
|
||||
async def test_execute_download_adjusts_checkpoint_sub_type(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
|
||||
root_dir = tmp_path / "checkpoints"
|
||||
@@ -494,7 +494,7 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
||||
self.file_name = path.stem
|
||||
self.preview_url = None
|
||||
self.preview_nsfw_level = 0
|
||||
self.model_type = "checkpoint"
|
||||
self.sub_type = "checkpoint"
|
||||
|
||||
def generate_unique_filename(self, *_args, **_kwargs):
|
||||
return os.path.basename(self.file_path)
|
||||
@@ -505,7 +505,7 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
||||
def to_dict(self):
|
||||
return {
|
||||
"file_path": self.file_path,
|
||||
"model_type": self.model_type,
|
||||
"sub_type": self.sub_type,
|
||||
"sha256": self.sha256,
|
||||
}
|
||||
|
||||
@@ -538,12 +538,12 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
||||
self, metadata_obj, _file_path: str, root_path: Optional[str]
|
||||
):
|
||||
if root_path:
|
||||
metadata_obj.model_type = "diffusion_model"
|
||||
metadata_obj.sub_type = "diffusion_model"
|
||||
return metadata_obj
|
||||
|
||||
def adjust_cached_entry(self, entry):
|
||||
if entry.get("file_path", "").startswith(self.root):
|
||||
entry["model_type"] = "diffusion_model"
|
||||
entry["sub_type"] = "diffusion_model"
|
||||
return entry
|
||||
|
||||
async def add_model_to_cache(self, metadata_dict, relative_path):
|
||||
@@ -570,12 +570,12 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
||||
)
|
||||
|
||||
assert result == {"success": True}
|
||||
assert metadata.model_type == "diffusion_model"
|
||||
assert metadata.sub_type == "diffusion_model"
|
||||
saved_metadata = MetadataManager.save_metadata.await_args.args[1]
|
||||
assert saved_metadata.model_type == "diffusion_model"
|
||||
assert saved_metadata.sub_type == "diffusion_model"
|
||||
assert dummy_scanner.add_calls
|
||||
cached_entry, _ = dummy_scanner.add_calls[0]
|
||||
assert cached_entry["model_type"] == "diffusion_model"
|
||||
assert cached_entry["sub_type"] == "diffusion_model"
|
||||
|
||||
|
||||
async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):
|
||||
|
||||
@@ -4,9 +4,7 @@ import pytest
|
||||
from py.services.model_query import (
|
||||
_coerce_to_str,
|
||||
normalize_sub_type,
|
||||
normalize_civitai_model_type,
|
||||
resolve_sub_type,
|
||||
resolve_civitai_model_type,
|
||||
FilterCriteria,
|
||||
ModelFilterSet,
|
||||
)
|
||||
@@ -45,14 +43,6 @@ class TestNormalizeSubType:
|
||||
assert normalize_sub_type("") is None
|
||||
|
||||
|
||||
class TestNormalizeCivitaiModelTypeAlias:
|
||||
"""Test normalize_civitai_model_type is alias for normalize_sub_type."""
|
||||
|
||||
def test_alias_works_correctly(self):
|
||||
assert normalize_civitai_model_type("LoRA") == "lora"
|
||||
assert normalize_civitai_model_type("CHECKPOINT") == "checkpoint"
|
||||
|
||||
|
||||
class TestResolveSubType:
|
||||
"""Test resolve_sub_type function priority."""
|
||||
|
||||
@@ -60,44 +50,35 @@ class TestResolveSubType:
|
||||
"""Priority 1: entry['sub_type'] should be used first."""
|
||||
entry = {
|
||||
"sub_type": "locon",
|
||||
"model_type": "checkpoint", # Should be ignored
|
||||
"civitai": {"model": {"type": "dora"}}, # Should be ignored
|
||||
}
|
||||
assert resolve_sub_type(entry) == "locon"
|
||||
|
||||
def test_priority_2_model_type_field(self):
|
||||
"""Priority 2: entry['model_type'] as fallback."""
|
||||
entry = {
|
||||
"model_type": "checkpoint",
|
||||
"civitai": {"model": {"type": "dora"}}, # Should be ignored
|
||||
}
|
||||
assert resolve_sub_type(entry) == "checkpoint"
|
||||
|
||||
def test_priority_3_civitai_model_type(self):
|
||||
"""Priority 3: civitai.model.type as fallback."""
|
||||
def test_priority_2_civitai_model_type(self):
|
||||
"""Priority 2: civitai.model.type as fallback."""
|
||||
entry = {
|
||||
"civitai": {"model": {"type": "dora"}},
|
||||
}
|
||||
assert resolve_sub_type(entry) == "dora"
|
||||
|
||||
def test_priority_4_default(self):
|
||||
"""Priority 4: default to LORA when nothing found."""
|
||||
def test_priority_3_default(self):
|
||||
"""Priority 3: default to LORA when nothing found."""
|
||||
entry = {}
|
||||
assert resolve_sub_type(entry) == "LORA"
|
||||
|
||||
def test_empty_sub_type_falls_back(self):
|
||||
"""Empty sub_type should fall back to model_type."""
|
||||
"""Empty sub_type should fall back to civitai type."""
|
||||
entry = {
|
||||
"sub_type": "",
|
||||
"model_type": "checkpoint",
|
||||
"civitai": {"model": {"type": "checkpoint"}},
|
||||
}
|
||||
assert resolve_sub_type(entry) == "checkpoint"
|
||||
|
||||
def test_whitespace_sub_type_falls_back(self):
|
||||
"""Whitespace sub_type should fall back to model_type."""
|
||||
"""Whitespace sub_type should fall back to civitai type."""
|
||||
entry = {
|
||||
"sub_type": " ",
|
||||
"model_type": "checkpoint",
|
||||
"civitai": {"model": {"type": "checkpoint"}},
|
||||
}
|
||||
assert resolve_sub_type(entry) == "checkpoint"
|
||||
|
||||
@@ -110,14 +91,6 @@ class TestResolveSubType:
|
||||
assert resolve_sub_type("invalid") == "LORA"
|
||||
|
||||
|
||||
class TestResolveCivitaiModelTypeAlias:
|
||||
"""Test resolve_civitai_model_type is alias for resolve_sub_type."""
|
||||
|
||||
def test_alias_works_correctly(self):
|
||||
entry = {"sub_type": "locon"}
|
||||
assert resolve_civitai_model_type(entry) == "locon"
|
||||
|
||||
|
||||
class TestModelFilterSetWithSubType:
|
||||
"""Test ModelFilterSet applies model_types filtering correctly."""
|
||||
|
||||
@@ -145,23 +118,8 @@ class TestModelFilterSetWithSubType:
|
||||
assert result[0]["model_name"] == "Model 1"
|
||||
assert result[1]["model_name"] == "Model 2"
|
||||
|
||||
def test_filter_falls_back_to_model_type(self):
|
||||
"""Filter should fall back to model_type field."""
|
||||
settings = self.create_mock_settings()
|
||||
filter_set = ModelFilterSet(settings)
|
||||
|
||||
data = [
|
||||
{"model_type": "lora", "model_name": "Model 1"}, # Old field
|
||||
{"sub_type": "locon", "model_name": "Model 2"}, # New field
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(model_types=["lora", "locon"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
def test_filter_uses_civitai_type(self):
|
||||
"""Filter should use civitai.model.type as last resort."""
|
||||
"""Filter should use civitai.model.type as fallback."""
|
||||
settings = self.create_mock_settings()
|
||||
filter_set = ModelFilterSet(settings)
|
||||
|
||||
|
||||
@@ -521,7 +521,7 @@ async def test_reconcile_cache_applies_adjust_cached_entry(tmp_path: Path):
|
||||
|
||||
def _adjust(self, entry: dict) -> dict:
|
||||
applied.append(entry["file_path"])
|
||||
entry["model_type"] = "adjusted"
|
||||
entry["custom_field"] = "adjusted"
|
||||
return entry
|
||||
|
||||
scanner.adjust_cached_entry = MethodType(_adjust, scanner)
|
||||
@@ -538,7 +538,7 @@ async def test_reconcile_cache_applies_adjust_cached_entry(tmp_path: Path):
|
||||
assert normalized_new in applied
|
||||
|
||||
new_entry = next(item for item in scanner._cache.raw_data if item["file_path"] == normalized_new)
|
||||
assert new_entry["model_type"] == "adjusted"
|
||||
assert new_entry["custom_field"] == "adjusted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -42,7 +42,7 @@ class TestLoraServiceFormatResponse:
|
||||
"usage_tips": "",
|
||||
"notes": "",
|
||||
"favorite": False,
|
||||
"sub_type": "locon", # New field
|
||||
"sub_type": "locon",
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
@@ -50,31 +50,7 @@ class TestLoraServiceFormatResponse:
|
||||
|
||||
assert "sub_type" in result
|
||||
assert result["sub_type"] == "locon"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_falls_back_to_model_type(self, lora_service):
|
||||
"""format_response should fall back to model_type if sub_type missing."""
|
||||
lora_data = {
|
||||
"model_name": "Test LoRA",
|
||||
"file_name": "test_lora",
|
||||
"preview_url": "test.webp",
|
||||
"preview_nsfw_level": 0,
|
||||
"base_model": "SDXL",
|
||||
"folder": "",
|
||||
"sha256": "abc123",
|
||||
"file_path": "/models/test_lora.safetensors",
|
||||
"size": 1000,
|
||||
"modified": 1234567890.0,
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"model_type": "dora", # Old field
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
result = await lora_service.format_response(lora_data)
|
||||
|
||||
assert result["sub_type"] == "dora"
|
||||
assert result["model_type"] == "dora" # Both should be set
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_defaults_to_lora(self, lora_service):
|
||||
@@ -98,7 +74,7 @@ class TestLoraServiceFormatResponse:
|
||||
result = await lora_service.format_response(lora_data)
|
||||
|
||||
assert result["sub_type"] == "lora"
|
||||
assert result["model_type"] == "lora"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
|
||||
class TestCheckpointServiceFormatResponse:
|
||||
@@ -138,7 +114,7 @@ class TestCheckpointServiceFormatResponse:
|
||||
|
||||
assert "sub_type" in result
|
||||
assert result["sub_type"] == "checkpoint"
|
||||
assert result["model_type"] == "checkpoint"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_includes_sub_type_diffusion_model(self, checkpoint_service):
|
||||
@@ -163,7 +139,7 @@ class TestCheckpointServiceFormatResponse:
|
||||
result = await checkpoint_service.format_response(checkpoint_data)
|
||||
|
||||
assert result["sub_type"] == "diffusion_model"
|
||||
assert result["model_type"] == "diffusion_model"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
|
||||
class TestEmbeddingServiceFormatResponse:
|
||||
@@ -203,7 +179,7 @@ class TestEmbeddingServiceFormatResponse:
|
||||
|
||||
assert "sub_type" in result
|
||||
assert result["sub_type"] == "embedding"
|
||||
assert result["model_type"] == "embedding"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_response_defaults_to_embedding(self, embedding_service):
|
||||
@@ -227,4 +203,4 @@ class TestEmbeddingServiceFormatResponse:
|
||||
result = await embedding_service.format_response(embedding_data)
|
||||
|
||||
assert result["sub_type"] == "embedding"
|
||||
assert result["model_type"] == "embedding"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
Reference in New Issue
Block a user