feat(download): add configurable base model download exclusions

This commit is contained in:
Will Miao
2026-03-26 23:06:14 +08:00
parent 5b065b47d4
commit a5191414cc
22 changed files with 988 additions and 4 deletions

View File

@@ -38,6 +38,7 @@ def isolate_settings(monkeypatch, tmp_path):
"embedding": "{base_model}/{first_tag}",
},
"base_model_path_mappings": {"BaseModel": "MappedModel"},
"download_skip_base_models": [],
}
)
monkeypatch.setattr(manager, "settings", default_settings)
@@ -443,3 +444,49 @@ def test_distribute_preview_to_entries_keeps_existing_file(tmp_path):
assert targets[0] == str(existing_preview)
assert Path(targets[1]).read_bytes() == b"preview"
@pytest.mark.asyncio
async def test_download_skips_excluded_base_model(monkeypatch, scanners, metadata_provider):
manager = DownloadManager()
get_settings_manager().settings["download_skip_base_models"] = ["SDXL 1.0"]
metadata_provider.get_model_version = AsyncMock(
return_value={
"id": 42,
"model": {"type": "LoRA", "tags": ["fantasy"]},
"baseModel": "SDXL 1.0",
"creator": {"username": "Author"},
"files": [
{
"type": "Model",
"primary": True,
"downloadUrl": "https://example.invalid/file.safetensors",
"name": "file.safetensors",
}
],
}
)
execute_download = AsyncMock()
monkeypatch.setattr(
DownloadManager, "_execute_download", execute_download, raising=False
)
result = await manager.download_from_civitai(
model_version_id=99,
use_default_paths=True,
progress_callback=None,
source=None,
)
assert result["success"] is True
assert result["skipped"] is True
assert result["status"] == "skipped"
assert result["reason"] == "base_model_excluded"
assert result["base_model"] == "SDXL 1.0"
assert result["file_name"] == "file.safetensors"
assert "file.safetensors" in result["message"]
execute_download.assert_not_called()
assert manager._active_downloads[result["download_id"]]["status"] == "skipped"

View File

@@ -605,3 +605,28 @@ def test_delete_library_switches_active(manager, tmp_path):
manager.delete_library("other")
assert manager.get_active_library_name() == "default"
def test_download_skip_base_models_are_normalized(manager):
manager.settings["download_skip_base_models"] = [
"SDXL 1.0",
"Invalid",
"SDXL 1.0",
"Pony",
"Other",
]
result = manager.get_download_skip_base_models()
assert result == ["SDXL 1.0", "Pony"]
assert manager.settings["download_skip_base_models"] == ["SDXL 1.0", "Pony"]
def test_setting_download_skip_base_models_normalizes_string_input(manager):
manager.set(
"download_skip_base_models",
"SDXL 1.0, Pony; Invalid\nSDXL 1.0"
)
assert manager.get("download_skip_base_models") == ["SDXL 1.0", "Pony"]