mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
fix(example-images): reuse migrated folders during downloads
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
@@ -226,3 +227,152 @@ async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, t
|
||||
if manager._download_task is not None:
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
|
||||
monkeypatch.setitem(settings.settings, "active_library", "extra")
|
||||
|
||||
model_hash = "d" * 64
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
model_path.write_text("data", encoding="utf-8")
|
||||
|
||||
model = {
|
||||
"sha256": model_hash,
|
||||
"model_name": "Migrated Model",
|
||||
"file_path": str(model_path),
|
||||
"file_name": "model.safetensors",
|
||||
"civitai": {"images": [{"url": "https://example.com/image.png"}]},
|
||||
}
|
||||
|
||||
_patch_scanner(monkeypatch, StubScanner([model]))
|
||||
|
||||
legacy_folder = tmp_path / model_hash
|
||||
legacy_folder.mkdir()
|
||||
(legacy_folder / "image_0.png").write_text("data", encoding="utf-8")
|
||||
|
||||
process_called = False
|
||||
download_called = False
|
||||
|
||||
async def fake_process_local_examples(*_args, **_kwargs):
|
||||
nonlocal process_called
|
||||
process_called = True
|
||||
return False
|
||||
|
||||
async def fake_download_model_images(*_args, **_kwargs):
|
||||
nonlocal download_called
|
||||
download_called = True
|
||||
return True, False
|
||||
|
||||
async def fake_get_downloader():
|
||||
class _Downloader:
|
||||
async def download_to_memory(self, *_a, **_kw):
|
||||
return True, b"", {}
|
||||
|
||||
return _Downloader()
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.ExampleImagesProcessor,
|
||||
"process_local_examples",
|
||||
staticmethod(fake_process_local_examples),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_module.ExampleImagesProcessor,
|
||||
"download_model_images",
|
||||
staticmethod(fake_download_model_images),
|
||||
)
|
||||
monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
|
||||
|
||||
try:
|
||||
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||
assert result["success"] is True
|
||||
|
||||
if manager._download_task is not None:
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
finally:
|
||||
if manager._download_task is not None and not manager._download_task.done():
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
|
||||
library_root = tmp_path / "extra"
|
||||
migrated_folder = library_root / model_hash
|
||||
|
||||
assert migrated_folder.exists()
|
||||
assert not legacy_folder.exists()
|
||||
assert not process_called
|
||||
assert not download_called
|
||||
assert model_hash in manager._progress["processed_models"]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
|
||||
monkeypatch.setitem(settings.settings, "active_library", "extra")
|
||||
|
||||
model_hash = "e" * 64
|
||||
model_path = tmp_path / "model-two.safetensors"
|
||||
model_path.write_text("data", encoding="utf-8")
|
||||
|
||||
legacy_progress = tmp_path / ".download_progress.json"
|
||||
legacy_progress.write_text(json.dumps({"processed_models": [model_hash], "failed_models": []}), encoding="utf-8")
|
||||
|
||||
legacy_folder = tmp_path / model_hash
|
||||
legacy_folder.mkdir()
|
||||
(legacy_folder / "existing.png").write_text("data", encoding="utf-8")
|
||||
|
||||
model = {
|
||||
"sha256": model_hash,
|
||||
"model_name": "Legacy Progress Model",
|
||||
"file_path": str(model_path),
|
||||
"file_name": "model-two.safetensors",
|
||||
"civitai": {"images": [{"url": "https://example.com/image.png"}]},
|
||||
}
|
||||
|
||||
_patch_scanner(monkeypatch, StubScanner([model]))
|
||||
|
||||
async def fake_process_local_examples(*_args, **_kwargs):
|
||||
return False
|
||||
|
||||
async def fake_download_model_images(*_args, **_kwargs):
|
||||
raise AssertionError("Remote download should not be attempted when progress is migrated")
|
||||
|
||||
async def fake_get_downloader():
|
||||
class _Downloader:
|
||||
async def download_to_memory(self, *_a, **_kw):
|
||||
return True, b"", {}
|
||||
|
||||
return _Downloader()
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.ExampleImagesProcessor,
|
||||
"process_local_examples",
|
||||
staticmethod(fake_process_local_examples),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_module.ExampleImagesProcessor,
|
||||
"download_model_images",
|
||||
staticmethod(fake_download_model_images),
|
||||
)
|
||||
monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
|
||||
|
||||
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||
assert result["success"] is True
|
||||
|
||||
if manager._download_task is not None:
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
|
||||
new_progress = (tmp_path / "extra") / ".download_progress.json"
|
||||
|
||||
assert model_hash in manager._progress["processed_models"]
|
||||
assert not legacy_progress.exists()
|
||||
assert new_progress.exists()
|
||||
contents = json.loads(new_progress.read_text(encoding="utf-8"))
|
||||
assert model_hash in contents.get("processed_models", [])
|
||||
|
||||
Reference in New Issue
Block a user