mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 08:26:45 -03:00
fix(download): auto fetch example images after model download
This commit is contained in:
@@ -75,6 +75,65 @@ class DownloadManager:
|
|||||||
backend = (get_settings_manager().get("download_backend") or "python").strip()
|
backend = (get_settings_manager().get("download_backend") or "python").strip()
|
||||||
return backend.lower() or "python"
|
return backend.lower() or "python"
|
||||||
|
|
||||||
|
async def _schedule_auto_example_images_download(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
metadata,
|
||||||
|
model_type: str,
|
||||||
|
) -> None:
|
||||||
|
settings_manager = get_settings_manager()
|
||||||
|
if not settings_manager.get("auto_download_example_images", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
if not settings_manager.get("example_images_path"):
|
||||||
|
logger.debug(
|
||||||
|
"Skipping automatic example images download; example_images_path is not configured"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
raw_hash = getattr(metadata, "sha256", "") or ""
|
||||||
|
model_hash = str(raw_hash).strip().lower()
|
||||||
|
if not model_hash:
|
||||||
|
logger.debug(
|
||||||
|
"Skipping automatic example images download for %s; missing sha256",
|
||||||
|
getattr(metadata, "file_path", ""),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
optimize = bool(settings_manager.get("optimize_example_images", True))
|
||||||
|
|
||||||
|
async def _run_auto_example_images_download() -> None:
|
||||||
|
try:
|
||||||
|
from ..utils.example_images_download_manager import (
|
||||||
|
DownloadInProgressError,
|
||||||
|
get_default_download_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
ws_manager = await ServiceRegistry.get_websocket_manager()
|
||||||
|
example_images_manager = get_default_download_manager(ws_manager)
|
||||||
|
await example_images_manager.start_force_download(
|
||||||
|
{
|
||||||
|
"model_hashes": [model_hash],
|
||||||
|
"optimize": optimize,
|
||||||
|
"model_types": [model_type],
|
||||||
|
"delay": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except DownloadInProgressError:
|
||||||
|
logger.info(
|
||||||
|
"Skipping automatic example images download for %s; another example images download is already running",
|
||||||
|
model_hash,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Automatic example images download failed for %s: %s",
|
||||||
|
model_hash,
|
||||||
|
exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.create_task(_run_auto_example_images_download())
|
||||||
|
|
||||||
async def _download_model_file(
|
async def _download_model_file(
|
||||||
self,
|
self,
|
||||||
download_url: str,
|
download_url: str,
|
||||||
@@ -1458,6 +1517,10 @@ class DownloadManager:
|
|||||||
version_info,
|
version_info,
|
||||||
model_version_id,
|
model_version_id,
|
||||||
)
|
)
|
||||||
|
await self._schedule_auto_example_images_download(
|
||||||
|
metadata=metadata,
|
||||||
|
model_type=model_type,
|
||||||
|
)
|
||||||
|
|
||||||
# If early_access_msg exists and download failed, replace error message
|
# If early_access_msg exists and download failed, replace error message
|
||||||
if "early_access_msg" in locals() and not result.get("success", False):
|
if "early_access_msg" in locals() and not result.get("success", False):
|
||||||
|
|||||||
@@ -233,6 +233,136 @@ async def test_successful_download_uses_defaults(
|
|||||||
assert captured["download_urls"] == ["https://example.invalid/file.safetensors"]
|
assert captured["download_urls"] == ["https://example.invalid/file.safetensors"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_successful_download_schedules_auto_example_images(
|
||||||
|
monkeypatch, scanners, metadata_provider, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
scheduled = []
|
||||||
|
|
||||||
|
async def fake_execute_download(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
download_urls,
|
||||||
|
save_dir,
|
||||||
|
metadata,
|
||||||
|
version_info,
|
||||||
|
relative_path,
|
||||||
|
progress_callback,
|
||||||
|
model_type,
|
||||||
|
download_id,
|
||||||
|
transfer_backend=None,
|
||||||
|
):
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
async def fake_schedule(self, *, metadata, model_type):
|
||||||
|
scheduled.append({"metadata": metadata, "model_type": model_type})
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager, "_execute_download", fake_execute_download, raising=False
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager,
|
||||||
|
"_schedule_auto_example_images_download",
|
||||||
|
fake_schedule,
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.download_from_civitai(
|
||||||
|
model_version_id=99,
|
||||||
|
save_dir=str(tmp_path),
|
||||||
|
use_default_paths=True,
|
||||||
|
progress_callback=None,
|
||||||
|
source=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert len(scheduled) == 1
|
||||||
|
assert scheduled[0]["model_type"] == "lora"
|
||||||
|
assert scheduled[0]["metadata"].sha256 == "sha256"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_example_images_download_uses_settings_payload(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
settings = get_settings_manager()
|
||||||
|
settings.settings["auto_download_example_images"] = True
|
||||||
|
settings.settings["example_images_path"] = str(tmp_path / "examples")
|
||||||
|
settings.settings["optimize_example_images"] = False
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
class DummyExampleImagesManager:
|
||||||
|
async def start_force_download(self, payload):
|
||||||
|
calls.append(payload)
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
from py.utils import example_images_download_manager
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ServiceRegistry,
|
||||||
|
"get_websocket_manager",
|
||||||
|
AsyncMock(return_value=object()),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
example_images_download_manager,
|
||||||
|
"get_default_download_manager",
|
||||||
|
lambda _ws_manager: DummyExampleImagesManager(),
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = SimpleNamespace(sha256="ABCDEF", file_path="model.safetensors")
|
||||||
|
await manager._schedule_auto_example_images_download(
|
||||||
|
metadata=metadata,
|
||||||
|
model_type="lora",
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
if calls:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert calls == [
|
||||||
|
{
|
||||||
|
"model_hashes": ["abcdef"],
|
||||||
|
"optimize": False,
|
||||||
|
"model_types": ["lora"],
|
||||||
|
"delay": 0,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_example_images_download_skips_without_configuration(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
):
|
||||||
|
manager = DownloadManager()
|
||||||
|
settings = get_settings_manager()
|
||||||
|
settings.settings["auto_download_example_images"] = True
|
||||||
|
settings.settings["example_images_path"] = ""
|
||||||
|
|
||||||
|
get_ws_manager = AsyncMock(return_value=object())
|
||||||
|
monkeypatch.setattr(ServiceRegistry, "get_websocket_manager", get_ws_manager)
|
||||||
|
|
||||||
|
await manager._schedule_auto_example_images_download(
|
||||||
|
metadata=SimpleNamespace(sha256="abcdef", file_path="model.safetensors"),
|
||||||
|
model_type="lora",
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
get_ws_manager.assert_not_called()
|
||||||
|
|
||||||
|
settings.settings["example_images_path"] = str(tmp_path / "examples")
|
||||||
|
await manager._schedule_auto_example_images_download(
|
||||||
|
metadata=SimpleNamespace(sha256="", file_path="model.safetensors"),
|
||||||
|
model_type="lora",
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
get_ws_manager.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_uses_active_mirrors(
|
async def test_download_uses_active_mirrors(
|
||||||
monkeypatch, scanners, metadata_provider, tmp_path
|
monkeypatch, scanners, metadata_provider, tmp_path
|
||||||
|
|||||||
Reference in New Issue
Block a user