fix(example-images): reuse migrated folders during downloads

This commit is contained in:
pixelpaws
2025-10-05 08:37:11 +08:00
parent 98425f37b8
commit 67c82ba6ea
3 changed files with 259 additions and 16 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import json
from types import SimpleNamespace
import pytest
@@ -226,3 +227,152 @@ async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, t
if manager._download_task is not None:
await asyncio.wait_for(manager._download_task, timeout=1)
monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep)
@pytest.mark.usefixtures("tmp_path")
async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatch, tmp_path):
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
monkeypatch.setitem(settings.settings, "active_library", "extra")
model_hash = "d" * 64
model_path = tmp_path / "model.safetensors"
model_path.write_text("data", encoding="utf-8")
model = {
"sha256": model_hash,
"model_name": "Migrated Model",
"file_path": str(model_path),
"file_name": "model.safetensors",
"civitai": {"images": [{"url": "https://example.com/image.png"}]},
}
_patch_scanner(monkeypatch, StubScanner([model]))
legacy_folder = tmp_path / model_hash
legacy_folder.mkdir()
(legacy_folder / "image_0.png").write_text("data", encoding="utf-8")
process_called = False
download_called = False
async def fake_process_local_examples(*_args, **_kwargs):
nonlocal process_called
process_called = True
return False
async def fake_download_model_images(*_args, **_kwargs):
nonlocal download_called
download_called = True
return True, False
async def fake_get_downloader():
class _Downloader:
async def download_to_memory(self, *_a, **_kw):
return True, b"", {}
return _Downloader()
monkeypatch.setattr(
download_module.ExampleImagesProcessor,
"process_local_examples",
staticmethod(fake_process_local_examples),
)
monkeypatch.setattr(
download_module.ExampleImagesProcessor,
"download_model_images",
staticmethod(fake_download_model_images),
)
monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
try:
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
assert result["success"] is True
if manager._download_task is not None:
await asyncio.wait_for(manager._download_task, timeout=1)
finally:
if manager._download_task is not None and not manager._download_task.done():
await asyncio.wait_for(manager._download_task, timeout=1)
library_root = tmp_path / "extra"
migrated_folder = library_root / model_hash
assert migrated_folder.exists()
assert not legacy_folder.exists()
assert not process_called
assert not download_called
assert model_hash in manager._progress["processed_models"]
@pytest.mark.usefixtures("tmp_path")
async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tmp_path):
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
monkeypatch.setitem(settings.settings, "active_library", "extra")
model_hash = "e" * 64
model_path = tmp_path / "model-two.safetensors"
model_path.write_text("data", encoding="utf-8")
legacy_progress = tmp_path / ".download_progress.json"
legacy_progress.write_text(json.dumps({"processed_models": [model_hash], "failed_models": []}), encoding="utf-8")
legacy_folder = tmp_path / model_hash
legacy_folder.mkdir()
(legacy_folder / "existing.png").write_text("data", encoding="utf-8")
model = {
"sha256": model_hash,
"model_name": "Legacy Progress Model",
"file_path": str(model_path),
"file_name": "model-two.safetensors",
"civitai": {"images": [{"url": "https://example.com/image.png"}]},
}
_patch_scanner(monkeypatch, StubScanner([model]))
async def fake_process_local_examples(*_args, **_kwargs):
return False
async def fake_download_model_images(*_args, **_kwargs):
raise AssertionError("Remote download should not be attempted when progress is migrated")
async def fake_get_downloader():
class _Downloader:
async def download_to_memory(self, *_a, **_kw):
return True, b"", {}
return _Downloader()
monkeypatch.setattr(
download_module.ExampleImagesProcessor,
"process_local_examples",
staticmethod(fake_process_local_examples),
)
monkeypatch.setattr(
download_module.ExampleImagesProcessor,
"download_model_images",
staticmethod(fake_download_model_images),
)
monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
assert result["success"] is True
if manager._download_task is not None:
await asyncio.wait_for(manager._download_task, timeout=1)
new_progress = (tmp_path / "extra") / ".download_progress.json"
assert model_hash in manager._progress["processed_models"]
assert not legacy_progress.exists()
assert new_progress.exists()
contents = json.loads(new_progress.read_text(encoding="utf-8"))
assert model_hash in contents.get("processed_models", [])