mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 08:26:45 -03:00
fix(download): auto fetch example images after model download
This commit is contained in:
@@ -75,6 +75,65 @@ class DownloadManager:
|
||||
backend = (get_settings_manager().get("download_backend") or "python").strip()
|
||||
return backend.lower() or "python"
|
||||
|
||||
async def _schedule_auto_example_images_download(
|
||||
self,
|
||||
*,
|
||||
metadata,
|
||||
model_type: str,
|
||||
) -> None:
|
||||
settings_manager = get_settings_manager()
|
||||
if not settings_manager.get("auto_download_example_images", False):
|
||||
return
|
||||
|
||||
if not settings_manager.get("example_images_path"):
|
||||
logger.debug(
|
||||
"Skipping automatic example images download; example_images_path is not configured"
|
||||
)
|
||||
return
|
||||
|
||||
raw_hash = getattr(metadata, "sha256", "") or ""
|
||||
model_hash = str(raw_hash).strip().lower()
|
||||
if not model_hash:
|
||||
logger.debug(
|
||||
"Skipping automatic example images download for %s; missing sha256",
|
||||
getattr(metadata, "file_path", ""),
|
||||
)
|
||||
return
|
||||
|
||||
optimize = bool(settings_manager.get("optimize_example_images", True))
|
||||
|
||||
async def _run_auto_example_images_download() -> None:
|
||||
try:
|
||||
from ..utils.example_images_download_manager import (
|
||||
DownloadInProgressError,
|
||||
get_default_download_manager,
|
||||
)
|
||||
|
||||
ws_manager = await ServiceRegistry.get_websocket_manager()
|
||||
example_images_manager = get_default_download_manager(ws_manager)
|
||||
await example_images_manager.start_force_download(
|
||||
{
|
||||
"model_hashes": [model_hash],
|
||||
"optimize": optimize,
|
||||
"model_types": [model_type],
|
||||
"delay": 0,
|
||||
}
|
||||
)
|
||||
except DownloadInProgressError:
|
||||
logger.info(
|
||||
"Skipping automatic example images download for %s; another example images download is already running",
|
||||
model_hash,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Automatic example images download failed for %s: %s",
|
||||
model_hash,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
asyncio.create_task(_run_auto_example_images_download())
|
||||
|
||||
async def _download_model_file(
|
||||
self,
|
||||
download_url: str,
|
||||
@@ -1458,6 +1517,10 @@ class DownloadManager:
|
||||
version_info,
|
||||
model_version_id,
|
||||
)
|
||||
await self._schedule_auto_example_images_download(
|
||||
metadata=metadata,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
# If early_access_msg exists and download failed, replace error message
|
||||
if "early_access_msg" in locals() and not result.get("success", False):
|
||||
|
||||
@@ -233,6 +233,136 @@ async def test_successful_download_uses_defaults(
|
||||
assert captured["download_urls"] == ["https://example.invalid/file.safetensors"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_download_schedules_auto_example_images(
|
||||
monkeypatch, scanners, metadata_provider, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
scheduled = []
|
||||
|
||||
async def fake_execute_download(
|
||||
self,
|
||||
*,
|
||||
download_urls,
|
||||
save_dir,
|
||||
metadata,
|
||||
version_info,
|
||||
relative_path,
|
||||
progress_callback,
|
||||
model_type,
|
||||
download_id,
|
||||
transfer_backend=None,
|
||||
):
|
||||
return {"success": True}
|
||||
|
||||
async def fake_schedule(self, *, metadata, model_type):
|
||||
scheduled.append({"metadata": metadata, "model_type": model_type})
|
||||
|
||||
monkeypatch.setattr(
|
||||
DownloadManager, "_execute_download", fake_execute_download, raising=False
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
DownloadManager,
|
||||
"_schedule_auto_example_images_download",
|
||||
fake_schedule,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
result = await manager.download_from_civitai(
|
||||
model_version_id=99,
|
||||
save_dir=str(tmp_path),
|
||||
use_default_paths=True,
|
||||
progress_callback=None,
|
||||
source=None,
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert len(scheduled) == 1
|
||||
assert scheduled[0]["model_type"] == "lora"
|
||||
assert scheduled[0]["metadata"].sha256 == "sha256"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_example_images_download_uses_settings_payload(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
settings = get_settings_manager()
|
||||
settings.settings["auto_download_example_images"] = True
|
||||
settings.settings["example_images_path"] = str(tmp_path / "examples")
|
||||
settings.settings["optimize_example_images"] = False
|
||||
|
||||
calls = []
|
||||
|
||||
class DummyExampleImagesManager:
|
||||
async def start_force_download(self, payload):
|
||||
calls.append(payload)
|
||||
return {"success": True}
|
||||
|
||||
from py.utils import example_images_download_manager
|
||||
|
||||
monkeypatch.setattr(
|
||||
ServiceRegistry,
|
||||
"get_websocket_manager",
|
||||
AsyncMock(return_value=object()),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
example_images_download_manager,
|
||||
"get_default_download_manager",
|
||||
lambda _ws_manager: DummyExampleImagesManager(),
|
||||
)
|
||||
|
||||
metadata = SimpleNamespace(sha256="ABCDEF", file_path="model.safetensors")
|
||||
await manager._schedule_auto_example_images_download(
|
||||
metadata=metadata,
|
||||
model_type="lora",
|
||||
)
|
||||
|
||||
for _ in range(10):
|
||||
if calls:
|
||||
break
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert calls == [
|
||||
{
|
||||
"model_hashes": ["abcdef"],
|
||||
"optimize": False,
|
||||
"model_types": ["lora"],
|
||||
"delay": 0,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_example_images_download_skips_without_configuration(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
manager = DownloadManager()
|
||||
settings = get_settings_manager()
|
||||
settings.settings["auto_download_example_images"] = True
|
||||
settings.settings["example_images_path"] = ""
|
||||
|
||||
get_ws_manager = AsyncMock(return_value=object())
|
||||
monkeypatch.setattr(ServiceRegistry, "get_websocket_manager", get_ws_manager)
|
||||
|
||||
await manager._schedule_auto_example_images_download(
|
||||
metadata=SimpleNamespace(sha256="abcdef", file_path="model.safetensors"),
|
||||
model_type="lora",
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
get_ws_manager.assert_not_called()
|
||||
|
||||
settings.settings["example_images_path"] = str(tmp_path / "examples")
|
||||
await manager._schedule_auto_example_images_download(
|
||||
metadata=SimpleNamespace(sha256="", file_path="model.safetensors"),
|
||||
model_type="lora",
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
get_ws_manager.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_uses_active_mirrors(
|
||||
monkeypatch, scanners, metadata_provider, tmp_path
|
||||
|
||||
Reference in New Issue
Block a user