fix(download): auto fetch example images after model download

This commit is contained in:
Will Miao
2026-04-21 22:48:06 +08:00
parent 7fa40023b0
commit a1dff6dd47
2 changed files with 193 additions and 0 deletions

View File

@@ -75,6 +75,65 @@ class DownloadManager:
backend = (get_settings_manager().get("download_backend") or "python").strip()
return backend.lower() or "python"
async def _schedule_auto_example_images_download(
self,
*,
metadata,
model_type: str,
) -> None:
settings_manager = get_settings_manager()
if not settings_manager.get("auto_download_example_images", False):
return
if not settings_manager.get("example_images_path"):
logger.debug(
"Skipping automatic example images download; example_images_path is not configured"
)
return
raw_hash = getattr(metadata, "sha256", "") or ""
model_hash = str(raw_hash).strip().lower()
if not model_hash:
logger.debug(
"Skipping automatic example images download for %s; missing sha256",
getattr(metadata, "file_path", ""),
)
return
optimize = bool(settings_manager.get("optimize_example_images", True))
async def _run_auto_example_images_download() -> None:
try:
from ..utils.example_images_download_manager import (
DownloadInProgressError,
get_default_download_manager,
)
ws_manager = await ServiceRegistry.get_websocket_manager()
example_images_manager = get_default_download_manager(ws_manager)
await example_images_manager.start_force_download(
{
"model_hashes": [model_hash],
"optimize": optimize,
"model_types": [model_type],
"delay": 0,
}
)
except DownloadInProgressError:
logger.info(
"Skipping automatic example images download for %s; another example images download is already running",
model_hash,
)
except Exception as exc:
logger.warning(
"Automatic example images download failed for %s: %s",
model_hash,
exc,
exc_info=True,
)
asyncio.create_task(_run_auto_example_images_download())
async def _download_model_file(
self,
download_url: str,
@@ -1458,6 +1517,10 @@ class DownloadManager:
version_info,
model_version_id,
)
await self._schedule_auto_example_images_download(
metadata=metadata,
model_type=model_type,
)
# If early_access_msg exists and download failed, replace error message
if "early_access_msg" in locals() and not result.get("success", False):

View File

@@ -233,6 +233,136 @@ async def test_successful_download_uses_defaults(
assert captured["download_urls"] == ["https://example.invalid/file.safetensors"]
@pytest.mark.asyncio
async def test_successful_download_schedules_auto_example_images(
monkeypatch, scanners, metadata_provider, tmp_path
):
manager = DownloadManager()
scheduled = []
async def fake_execute_download(
self,
*,
download_urls,
save_dir,
metadata,
version_info,
relative_path,
progress_callback,
model_type,
download_id,
transfer_backend=None,
):
return {"success": True}
async def fake_schedule(self, *, metadata, model_type):
scheduled.append({"metadata": metadata, "model_type": model_type})
monkeypatch.setattr(
DownloadManager, "_execute_download", fake_execute_download, raising=False
)
monkeypatch.setattr(
DownloadManager,
"_schedule_auto_example_images_download",
fake_schedule,
raising=False,
)
result = await manager.download_from_civitai(
model_version_id=99,
save_dir=str(tmp_path),
use_default_paths=True,
progress_callback=None,
source=None,
)
assert result["success"] is True
assert len(scheduled) == 1
assert scheduled[0]["model_type"] == "lora"
assert scheduled[0]["metadata"].sha256 == "sha256"
@pytest.mark.asyncio
async def test_auto_example_images_download_uses_settings_payload(
monkeypatch, tmp_path
):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["auto_download_example_images"] = True
settings.settings["example_images_path"] = str(tmp_path / "examples")
settings.settings["optimize_example_images"] = False
calls = []
class DummyExampleImagesManager:
async def start_force_download(self, payload):
calls.append(payload)
return {"success": True}
from py.utils import example_images_download_manager
monkeypatch.setattr(
ServiceRegistry,
"get_websocket_manager",
AsyncMock(return_value=object()),
)
monkeypatch.setattr(
example_images_download_manager,
"get_default_download_manager",
lambda _ws_manager: DummyExampleImagesManager(),
)
metadata = SimpleNamespace(sha256="ABCDEF", file_path="model.safetensors")
await manager._schedule_auto_example_images_download(
metadata=metadata,
model_type="lora",
)
for _ in range(10):
if calls:
break
await asyncio.sleep(0)
assert calls == [
{
"model_hashes": ["abcdef"],
"optimize": False,
"model_types": ["lora"],
"delay": 0,
}
]
@pytest.mark.asyncio
async def test_auto_example_images_download_skips_without_configuration(
monkeypatch, tmp_path
):
manager = DownloadManager()
settings = get_settings_manager()
settings.settings["auto_download_example_images"] = True
settings.settings["example_images_path"] = ""
get_ws_manager = AsyncMock(return_value=object())
monkeypatch.setattr(ServiceRegistry, "get_websocket_manager", get_ws_manager)
await manager._schedule_auto_example_images_download(
metadata=SimpleNamespace(sha256="abcdef", file_path="model.safetensors"),
model_type="lora",
)
await asyncio.sleep(0)
get_ws_manager.assert_not_called()
settings.settings["example_images_path"] = str(tmp_path / "examples")
await manager._schedule_auto_example_images_download(
metadata=SimpleNamespace(sha256="", file_path="model.safetensors"),
model_type="lora",
)
await asyncio.sleep(0)
get_ws_manager.assert_not_called()
@pytest.mark.asyncio
async def test_download_uses_active_mirrors(
monkeypatch, scanners, metadata_provider, tmp_path