diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 5b297356..38c21e6f 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -75,6 +75,65 @@ class DownloadManager: backend = (get_settings_manager().get("download_backend") or "python").strip() return backend.lower() or "python" + async def _schedule_auto_example_images_download( + self, + *, + metadata, + model_type: str, + ) -> None: + settings_manager = get_settings_manager() + if not settings_manager.get("auto_download_example_images", False): + return + + if not settings_manager.get("example_images_path"): + logger.debug( + "Skipping automatic example images download; example_images_path is not configured" + ) + return + + raw_hash = getattr(metadata, "sha256", "") or "" + model_hash = str(raw_hash).strip().lower() + if not model_hash: + logger.debug( + "Skipping automatic example images download for %s; missing sha256", + getattr(metadata, "file_path", ""), + ) + return + + optimize = bool(settings_manager.get("optimize_example_images", True)) + + async def _run_auto_example_images_download() -> None: + try: + from ..utils.example_images_download_manager import ( + DownloadInProgressError, + get_default_download_manager, + ) + + ws_manager = await ServiceRegistry.get_websocket_manager() + example_images_manager = get_default_download_manager(ws_manager) + await example_images_manager.start_force_download( + { + "model_hashes": [model_hash], + "optimize": optimize, + "model_types": [model_type], + "delay": 0, + } + ) + except DownloadInProgressError: + logger.info( + "Skipping automatic example images download for %s; another example images download is already running", + model_hash, + ) + except Exception as exc: + logger.warning( + "Automatic example images download failed for %s: %s", + model_hash, + exc, + exc_info=True, + ) + + asyncio.create_task(_run_auto_example_images_download()) + async def _download_model_file( self, download_url: str, @@ -1458,6 +1517,10 @@ class DownloadManager: version_info, model_version_id, ) + await self._schedule_auto_example_images_download( + metadata=metadata, + model_type=model_type, + ) # If early_access_msg exists and download failed, replace error message if "early_access_msg" in locals() and not result.get("success", False): diff --git a/tests/services/test_download_manager_basic.py b/tests/services/test_download_manager_basic.py index 3117d612..22c7e42d 100644 --- a/tests/services/test_download_manager_basic.py +++ b/tests/services/test_download_manager_basic.py @@ -233,6 +233,136 @@ async def test_successful_download_uses_defaults( assert captured["download_urls"] == ["https://example.invalid/file.safetensors"] +@pytest.mark.asyncio +async def test_successful_download_schedules_auto_example_images( + monkeypatch, scanners, metadata_provider, tmp_path +): + manager = DownloadManager() + scheduled = [] + + async def fake_execute_download( + self, + *, + download_urls, + save_dir, + metadata, + version_info, + relative_path, + progress_callback, + model_type, + download_id, + transfer_backend=None, + ): + return {"success": True} + + async def fake_schedule(self, *, metadata, model_type): + scheduled.append({"metadata": metadata, "model_type": model_type}) + + monkeypatch.setattr( + DownloadManager, "_execute_download", fake_execute_download, raising=False + ) + monkeypatch.setattr( + DownloadManager, + "_schedule_auto_example_images_download", + fake_schedule, + raising=False, + ) + + result = await manager.download_from_civitai( + model_version_id=99, + save_dir=str(tmp_path), + use_default_paths=True, + progress_callback=None, + source=None, + ) + + assert result["success"] is True + assert len(scheduled) == 1 + assert scheduled[0]["model_type"] == "lora" + assert scheduled[0]["metadata"].sha256 == "sha256" + + +@pytest.mark.asyncio +async def test_auto_example_images_download_uses_settings_payload( + monkeypatch, tmp_path +): + manager = DownloadManager() + settings = get_settings_manager() + settings.settings["auto_download_example_images"] = True + settings.settings["example_images_path"] = str(tmp_path / "examples") + settings.settings["optimize_example_images"] = False + + calls = [] + + class DummyExampleImagesManager: + async def start_force_download(self, payload): + calls.append(payload) + return {"success": True} + + from py.utils import example_images_download_manager + + monkeypatch.setattr( + ServiceRegistry, + "get_websocket_manager", + AsyncMock(return_value=object()), + ) + monkeypatch.setattr( + example_images_download_manager, + "get_default_download_manager", + lambda _ws_manager: DummyExampleImagesManager(), + ) + + metadata = SimpleNamespace(sha256="ABCDEF", file_path="model.safetensors") + await manager._schedule_auto_example_images_download( + metadata=metadata, + model_type="lora", + ) + + for _ in range(10): + if calls: + break + await asyncio.sleep(0) + + assert calls == [ + { + "model_hashes": ["abcdef"], + "optimize": False, + "model_types": ["lora"], + "delay": 0, + } + ] + + +@pytest.mark.asyncio +async def test_auto_example_images_download_skips_without_configuration( + monkeypatch, tmp_path +): + manager = DownloadManager() + settings = get_settings_manager() + settings.settings["auto_download_example_images"] = True + settings.settings["example_images_path"] = "" + + get_ws_manager = AsyncMock(return_value=object()) + monkeypatch.setattr(ServiceRegistry, "get_websocket_manager", get_ws_manager) + + await manager._schedule_auto_example_images_download( + metadata=SimpleNamespace(sha256="abcdef", file_path="model.safetensors"), + model_type="lora", + ) + await asyncio.sleep(0) + + get_ws_manager.assert_not_called() + + settings.settings["example_images_path"] = str(tmp_path / "examples") + await manager._schedule_auto_example_images_download( + metadata=SimpleNamespace(sha256="", file_path="model.safetensors"), + model_type="lora", + ) + await asyncio.sleep(0) + + get_ws_manager.assert_not_called() + + @pytest.mark.asyncio async def test_download_uses_active_mirrors( monkeypatch, scanners, metadata_provider, tmp_path