Merge pull request #522 from willmiao/codex/add-tests-for-example-images-pipeline

test: add example images route and utility coverage
This commit is contained in:
pixelpaws
2025-10-05 15:02:07 +08:00
committed by GitHub
6 changed files with 765 additions and 1 deletions

View File

@@ -0,0 +1,220 @@
from __future__ import annotations
import json
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
from py.routes.example_images_route_registrar import ExampleImagesRouteRegistrar
from py.routes.handlers.example_images_handlers import (
ExampleImagesDownloadHandler,
ExampleImagesFileHandler,
ExampleImagesHandlerSet,
ExampleImagesManagementHandler,
)
from py.services.use_cases.example_images import (
DownloadExampleImagesInProgressError,
ImportExampleImagesValidationError,
)
from py.utils.example_images_download_manager import (
DownloadInProgressError,
DownloadNotRunningError,
)
class StubDownloadUseCase:
def __init__(self) -> None:
self.payloads: list[dict[str, Any]] = []
self.error: Exception | None = None
async def execute(self, payload: dict[str, Any]) -> dict[str, Any]:
self.payloads.append(payload)
if self.error:
raise self.error
return {"success": True, "payload": payload}
class StubDownloadManager:
def __init__(self) -> None:
self.pause_calls = 0
self.resume_calls = 0
self.force_payloads: list[dict[str, Any]] = []
self.pause_error: Exception | None = None
self.resume_error: Exception | None = None
self.force_error: Exception | None = None
async def get_status(self, request: web.Request) -> dict[str, Any]:
return {"success": True, "status": "idle"}
async def pause_download(self, request: web.Request) -> dict[str, Any]:
self.pause_calls += 1
if self.pause_error:
raise self.pause_error
return {"success": True, "message": "paused"}
async def resume_download(self, request: web.Request) -> dict[str, Any]:
self.resume_calls += 1
if self.resume_error:
raise self.resume_error
return {"success": True, "message": "resumed"}
async def start_force_download(self, payload: dict[str, Any]) -> dict[str, Any]:
self.force_payloads.append(payload)
if self.force_error:
raise self.force_error
return {"success": True, "payload": payload}
class StubImportUseCase:
def __init__(self) -> None:
self.requests: list[web.Request] = []
self.error: Exception | None = None
async def execute(self, request: web.Request) -> dict[str, Any]:
self.requests.append(request)
if self.error:
raise self.error
return {"success": True}
class StubProcessor:
def __init__(self) -> None:
self.delete_calls: list[web.Request] = []
async def delete_custom_image(self, request: web.Request) -> web.Response:
self.delete_calls.append(request)
return web.json_response({"deleted": True})
class StubCleanupService:
def __init__(self) -> None:
self.calls = 0
async def cleanup_example_image_folders(self) -> dict[str, Any]:
self.calls += 1
return {"success": True}
class StubFileManager:
async def open_folder(self, request: web.Request) -> web.Response:
return web.json_response({"opened": True})
async def get_files(self, request: web.Request) -> web.Response:
return web.json_response({"files": []})
async def has_images(self, request: web.Request) -> web.Response:
return web.json_response({"has": False})
@dataclass
class RegistrarHarness:
client: TestClient
download_use_case: StubDownloadUseCase
download_manager: StubDownloadManager
import_use_case: StubImportUseCase
@asynccontextmanager
async def registrar_app() -> RegistrarHarness:
app = web.Application()
download_use_case = StubDownloadUseCase()
download_manager = StubDownloadManager()
import_use_case = StubImportUseCase()
processor = StubProcessor()
cleanup_service = StubCleanupService()
file_manager = StubFileManager()
handler_set = ExampleImagesHandlerSet(
download=ExampleImagesDownloadHandler(download_use_case, download_manager),
management=ExampleImagesManagementHandler(import_use_case, processor, cleanup_service),
files=ExampleImagesFileHandler(file_manager),
)
registrar = ExampleImagesRouteRegistrar(app)
registrar.register_routes(handler_set.to_route_mapping())
server = TestServer(app)
client = TestClient(server)
await client.start_server()
try:
yield RegistrarHarness(
client=client,
download_use_case=download_use_case,
download_manager=download_manager,
import_use_case=import_use_case,
)
finally:
await client.close()
async def _json(response: web.StreamResponse) -> Dict[str, Any]:
text = await response.text()
return json.loads(text) if text else {}
async def test_download_route_surfaces_in_progress_error():
async with registrar_app() as harness:
progress = {"status": "running"}
harness.download_use_case.error = DownloadExampleImagesInProgressError(progress)
response = await harness.client.post(
"/api/lm/download-example-images",
json={"model_types": ["lora"]},
)
assert response.status == 400
body = await _json(response)
assert body["status"] == progress
assert body["error"] == "Download already in progress"
async def test_force_download_translates_manager_errors():
async with registrar_app() as harness:
snapshot = {"status": "running"}
harness.download_manager.force_error = DownloadInProgressError(snapshot)
response = await harness.client.post(
"/api/lm/force-download-example-images",
json={"model_hashes": ["abc"]},
)
assert response.status == 400
body = await _json(response)
assert body["status"] == snapshot
assert body["error"] == "Download already in progress"
async def test_pause_and_resume_return_client_errors_when_not_running():
async with registrar_app() as harness:
harness.download_manager.pause_error = DownloadNotRunningError()
harness.download_manager.resume_error = DownloadNotRunningError("Stopped")
pause_response = await harness.client.post("/api/lm/pause-example-images")
resume_response = await harness.client.post("/api/lm/resume-example-images")
assert pause_response.status == 400
assert resume_response.status == 400
pause_body = await _json(pause_response)
resume_body = await _json(resume_response)
assert pause_body == {"success": False, "error": "No download in progress"}
assert resume_body == {"success": False, "error": "Stopped"}
async def test_import_route_returns_validation_errors():
async with registrar_app() as harness:
harness.import_use_case.error = ImportExampleImagesValidationError("bad payload")
response = await harness.client.post(
"/api/lm/import-example-images",
json={"model_hash": "missing"},
)
assert response.status == 400
body = await _json(response)
assert body == {"success": False, "error": "bad payload"}

View File

@@ -0,0 +1,132 @@
from __future__ import annotations
import asyncio
from typing import Any, Dict
import pytest
from py.services.settings_manager import settings
from py.utils import example_images_download_manager as download_module
class RecordingWebSocketManager:
def __init__(self) -> None:
self.payloads: list[Dict[str, Any]] = []
async def broadcast(self, payload: Dict[str, Any]) -> None:
self.payloads.append(payload)
@pytest.fixture(autouse=True)
def restore_settings() -> None:
original = settings.settings.copy()
try:
yield
finally:
settings.settings.clear()
settings.settings.update(original)
async def test_start_download_requires_configured_path(monkeypatch: pytest.MonkeyPatch) -> None:
manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())
with pytest.raises(download_module.ExampleImagesDownloadError) as exc_info:
await manager.start_download({})
assert isinstance(exc_info.value.__cause__, download_module.DownloadConfigurationError)
assert "not configured" in str(exc_info.value)
result = await manager.start_download({"auto_mode": True})
assert result["success"] is True
assert "skipping auto download" in result["message"]
async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
settings.settings["libraries"] = {"default": {}}
settings.settings["active_library"] = "default"
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
started = asyncio.Event()
release = asyncio.Event()
async def fake_download(self, output_dir, optimize, model_types, delay, library_name):
started.set()
await release.wait()
async with self._state_lock:
self._is_downloading = False
self._download_task = None
self._progress["status"] = "completed"
monkeypatch.setattr(
download_module.DownloadManager,
"_download_all_example_images",
fake_download,
)
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
assert result["success"] is True
assert manager._is_downloading is True
await asyncio.wait_for(started.wait(), timeout=1)
assert ws_manager.payloads[0]["status"] == "running"
task = manager._download_task
assert task is not None
release.set()
await asyncio.wait_for(task, timeout=1)
assert manager._is_downloading is False
assert manager._progress["status"] == "completed"
async def test_pause_and_resume_flow(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
settings.settings["libraries"] = {"default": {}}
settings.settings["active_library"] = "default"
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
started = asyncio.Event()
release = asyncio.Event()
async def fake_download(self, *_args):
started.set()
await release.wait()
async with self._state_lock:
self._is_downloading = False
self._download_task = None
monkeypatch.setattr(
download_module.DownloadManager,
"_download_all_example_images",
fake_download,
)
await manager.start_download({})
await asyncio.wait_for(started.wait(), timeout=1)
pause_response = await manager.pause_download(object())
assert pause_response == {"success": True, "message": "Download paused"}
assert manager._progress["status"] == "paused"
resume_response = await manager.resume_download(object())
assert resume_response == {"success": True, "message": "Download resumed"}
assert manager._progress["status"] == "running"
task = manager._download_task
assert task is not None
release.set()
await asyncio.wait_for(task, timeout=1)
async def test_pause_or_resume_without_running_download(monkeypatch: pytest.MonkeyPatch) -> None:
manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())
with pytest.raises(download_module.DownloadNotRunningError):
await manager.pause_download(object())
with pytest.raises(download_module.DownloadNotRunningError):
await manager.resume_download(object())

View File

@@ -0,0 +1,113 @@
from __future__ import annotations
import json
import subprocess
from typing import Any, Dict
import pytest
from py.services.settings_manager import settings
from py.utils.example_images_file_manager import ExampleImagesFileManager
class JsonRequest:
def __init__(self, payload: Dict[str, Any], query: Dict[str, str] | None = None) -> None:
self._payload = payload
self.query = query or {}
async def json(self) -> Dict[str, Any]:
return self._payload
@pytest.fixture(autouse=True)
def restore_settings() -> None:
original = settings.settings.copy()
try:
yield
finally:
settings.settings.clear()
settings.settings.update(original)
async def test_open_folder_requires_existing_model_directory(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
model_hash = "a" * 64
model_folder = tmp_path / model_hash
model_folder.mkdir()
(model_folder / "image.png").write_text("data", encoding="utf-8")
popen_calls: list[list[str]] = []
class DummyPopen:
def __init__(self, cmd, *_args, **_kwargs):
popen_calls.append(cmd)
monkeypatch.setattr("subprocess.Popen", DummyPopen)
request = JsonRequest({"model_hash": model_hash})
response = await ExampleImagesFileManager.open_folder(request)
body = json.loads(response.text)
assert body["success"] is True
assert popen_calls
assert model_hash in popen_calls[0][-1]
async def test_open_folder_rejects_invalid_paths(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
def fake_get_model_folder(_hash):
return str(tmp_path.parent / "outside")
monkeypatch.setattr("py.utils.example_images_file_manager.get_model_folder", fake_get_model_folder)
request = JsonRequest({"model_hash": "a" * 64})
response = await ExampleImagesFileManager.open_folder(request)
body = json.loads(response.text)
assert response.status == 400
assert body["success"] is False
async def test_get_files_lists_supported_media(tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
model_hash = "b" * 64
model_folder = tmp_path / model_hash
model_folder.mkdir()
(model_folder / "image.png").write_text("data", encoding="utf-8")
(model_folder / "video.webm").write_text("data", encoding="utf-8")
(model_folder / "notes.txt").write_text("skip", encoding="utf-8")
request = JsonRequest({}, {"model_hash": model_hash})
response = await ExampleImagesFileManager.get_files(request)
body = json.loads(response.text)
assert response.status == 200
names = {entry["name"] for entry in body["files"]}
assert names == {"image.png", "video.webm"}
async def test_has_images_reports_presence(tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
model_hash = "c" * 64
model_folder = tmp_path / model_hash
model_folder.mkdir()
(model_folder / "image.png").write_text("data", encoding="utf-8")
request = JsonRequest({}, {"model_hash": model_hash})
response = await ExampleImagesFileManager.has_images(request)
body = json.loads(response.text)
assert body["has_images"] is True
empty_request = JsonRequest({}, {"model_hash": "missing"})
empty_response = await ExampleImagesFileManager.has_images(empty_request)
empty_body = json.loads(empty_response.text)
assert empty_body["has_images"] is False
async def test_has_images_requires_model_hash() -> None:
response = await ExampleImagesFileManager.has_images(JsonRequest({}, {}))
body = json.loads(response.text)
assert response.status == 400
assert body["success"] is False

View File

@@ -0,0 +1,115 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from typing import Any, Dict, List, Tuple
import pytest
from py.utils import example_images_metadata as metadata_module
class StubScanner:
def __init__(self, cache_items: List[Dict[str, Any]]) -> None:
self.cache = SimpleNamespace(raw_data=cache_items)
self.updates: List[Tuple[str, str, Dict[str, Any]]] = []
async def get_cached_data(self):
return self.cache
async def update_single_model_cache(self, old_path: str, new_path: str, metadata: Dict[str, Any]) -> bool:
self.updates.append((old_path, new_path, metadata))
return True
@pytest.fixture(autouse=True)
def patch_metadata_manager(monkeypatch: pytest.MonkeyPatch):
saved: List[Tuple[str, Dict[str, Any]]] = []
async def fake_save(path: str, metadata: Dict[str, Any]) -> bool:
saved.append((path, metadata.copy()))
return True
monkeypatch.setattr(metadata_module.MetadataManager, "save_metadata", staticmethod(fake_save))
return saved
async def test_update_metadata_after_import_enriches_entries(monkeypatch: pytest.MonkeyPatch, tmp_path, patch_metadata_manager):
model_hash = "a" * 64
model_file = tmp_path / "model.safetensors"
model_file.write_text("content", encoding="utf-8")
model_data = {
"model_name": "Example",
"file_path": str(model_file),
"civitai": {},
}
scanner = StubScanner([model_data])
image_path = tmp_path / "custom.png"
image_path.write_bytes(b"fakepng")
monkeypatch.setattr(metadata_module.ExifUtils, "extract_image_metadata", staticmethod(lambda _path: "Prompt text Negative prompt: bad Steps: 20, Sampler: Euler"))
monkeypatch.setattr(metadata_module.MetadataUpdater, "_parse_image_metadata", staticmethod(lambda payload: {"prompt": "Prompt text", "negativePrompt": "bad", "parameters": {"Steps": "20"}}))
regular, custom = await metadata_module.MetadataUpdater.update_metadata_after_import(
model_hash,
model_data,
scanner,
[(str(image_path), "short-id")],
)
assert isinstance(custom, list)
assert custom[0]["id"] == "short-id"
assert custom[0]["meta"]["prompt"] == "Prompt text"
assert custom[0]["hasMeta"] is True
assert custom[0]["type"] == "image"
assert patch_metadata_manager[0][0] == str(model_file)
assert scanner.updates
async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.MonkeyPatch, tmp_path):
model_hash = "b" * 64
model_file = tmp_path / "model.safetensors"
model_file.write_text("content", encoding="utf-8")
cache_item = {"sha256": model_hash, "file_path": str(model_file)}
scanner = StubScanner([cache_item])
class StubMetadataSync:
async def fetch_and_update_model(self, **_kwargs):
return True, None
monkeypatch.setattr(metadata_module, "_metadata_sync_service", StubMetadataSync())
result = await metadata_module.MetadataUpdater.refresh_model_metadata(
model_hash,
"Example",
"lora",
scanner,
{"refreshed_models": set(), "errors": [], "last_error": None},
)
assert result is True
async def test_update_metadata_from_local_examples_generates_entries(monkeypatch: pytest.MonkeyPatch, tmp_path):
model_hash = "c" * 64
model_dir = tmp_path / model_hash
model_dir.mkdir()
(model_dir / "image.png").write_text("data", encoding="utf-8")
model_data = {"model_name": "Local", "civitai": {}, "file_path": str(tmp_path / "model.safetensors")}
async def fake_save(path, metadata):
return True
monkeypatch.setattr(metadata_module.MetadataManager, "save_metadata", staticmethod(fake_save))
monkeypatch.setattr(metadata_module.ExifUtils, "extract_image_metadata", staticmethod(lambda _path: None))
success = await metadata_module.MetadataUpdater.update_metadata_from_local_examples(
model_hash,
model_data,
"lora",
StubScanner([model_data]),
str(model_dir),
)
assert success is True
assert model_data["civitai"]["images"]

View File

@@ -7,7 +7,13 @@ from pathlib import Path
import pytest
from py.services.settings_manager import settings
from py.utils.example_images_paths import get_model_folder, get_model_relative_path
from py.utils.example_images_paths import (
ensure_library_root_exists,
get_model_folder,
get_model_relative_path,
is_valid_example_images_root,
iter_library_roots,
)
@pytest.fixture(autouse=True)
@@ -74,3 +80,36 @@ def test_get_model_folder_migrates_legacy_structure(tmp_path):
assert relative == os.path.join('extra', model_hash).replace('\\', '/')
assert not legacy_folder.exists()
assert (expected_folder / 'image.png').exists()
def test_ensure_library_root_exists_creates_directories(tmp_path):
settings.settings['example_images_path'] = str(tmp_path)
settings.settings['libraries'] = {'default': {}, 'secondary': {}}
settings.settings['active_library'] = 'secondary'
resolved = ensure_library_root_exists('secondary')
assert Path(resolved) == tmp_path / 'secondary'
assert (tmp_path / 'secondary').is_dir()
def test_iter_library_roots_returns_all_configured(tmp_path):
settings.settings['example_images_path'] = str(tmp_path)
settings.settings['libraries'] = {'default': {}, 'alt': {}}
settings.settings['active_library'] = 'alt'
roots = dict(iter_library_roots())
assert roots['default'] == str(tmp_path / 'default')
assert roots['alt'] == str(tmp_path / 'alt')
def test_is_valid_example_images_root_accepts_hash_directories(tmp_path):
settings.settings['example_images_path'] = str(tmp_path)
hash_folder = tmp_path / ('d' * 64)
hash_folder.mkdir()
(hash_folder / 'image.png').write_text('data', encoding='utf-8')
assert is_valid_example_images_root(str(tmp_path)) is True
invalid_folder = tmp_path / 'not_hash'
invalid_folder.mkdir()
assert is_valid_example_images_root(str(tmp_path)) is False

View File

@@ -0,0 +1,145 @@
from __future__ import annotations
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, Tuple
import pytest
from py.services.settings_manager import settings
from py.utils import example_images_processor as processor_module
@pytest.fixture(autouse=True)
def restore_settings() -> None:
original = settings.settings.copy()
try:
yield
finally:
settings.settings.clear()
settings.settings.update(original)
def test_get_file_extension_from_magic_bytes() -> None:
jpg_bytes = b"\xff\xd8\xff" + b"rest"
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
jpg_bytes, {}, None
)
assert ext == ".jpg"
def test_get_file_extension_from_headers() -> None:
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
b"", {"content-type": "image/png"}, None
)
assert ext == ".png"
def test_get_file_extension_from_url_fallback() -> None:
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
b"", {}, "https://example.com/file.webm?query=1"
)
assert ext == ".webm"
def test_get_file_extension_defaults_to_jpg() -> None:
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
b"", {}, None
)
assert ext == ".jpg"
class StubScanner:
def __init__(self, models: list[Dict[str, Any]]) -> None:
self._cache = SimpleNamespace(raw_data=models)
self.updated: list[Tuple[str, str, Dict[str, Any]]] = []
async def get_cached_data(self):
return self._cache
async def update_single_model_cache(self, old_path: str, new_path: str, metadata: Dict[str, Any]) -> bool:
self.updated.append((old_path, new_path, metadata))
return True
def has_hash(self, _hash: str) -> bool:
return True
@pytest.fixture
def stub_scanners(monkeypatch: pytest.MonkeyPatch, tmp_path) -> StubScanner:
model_hash = "a" * 64
model_path = tmp_path / "model.safetensors"
model_path.write_text("content", encoding="utf-8")
model_data = {
"sha256": model_hash,
"model_name": "Example",
"file_path": str(model_path),
"civitai": {},
}
scanner = StubScanner([model_data])
async def _return_scanner(cls=None):
return scanner
monkeypatch.setattr(processor_module.ServiceRegistry, "get_lora_scanner", classmethod(_return_scanner))
monkeypatch.setattr(processor_module.ServiceRegistry, "get_checkpoint_scanner", classmethod(_return_scanner))
monkeypatch.setattr(processor_module.ServiceRegistry, "get_embedding_scanner", classmethod(_return_scanner))
return scanner
async def test_import_images_creates_hash_directory(monkeypatch: pytest.MonkeyPatch, tmp_path, stub_scanners: StubScanner) -> None:
settings.settings["example_images_path"] = str(tmp_path / "examples")
settings.settings["libraries"] = {"default": {}}
settings.settings["active_library"] = "default"
source_file = tmp_path / "upload.png"
source_file.write_bytes(b"PNG data")
monkeypatch.setattr(processor_module.ExampleImagesProcessor, "generate_short_id", staticmethod(lambda: "short"))
recorded: Dict[str, Any] = {}
async def fake_update_metadata(model_hash, model_data, scanner, paths):
recorded["args"] = (model_hash, list(paths))
return ["regular"], ["custom"]
monkeypatch.setattr(processor_module.MetadataUpdater, "update_metadata_after_import", staticmethod(fake_update_metadata))
result = await processor_module.ExampleImagesProcessor.import_images("a" * 64, [str(source_file)])
assert result["success"] is True
assert result["files"][0]["name"].startswith("custom_short")
model_folder = Path(settings.settings["example_images_path"]) / ("a" * 64)
assert model_folder.exists()
created_files = list(model_folder.glob("custom_short*.png"))
assert len(created_files) == 1
assert created_files[0].read_bytes() == source_file.read_bytes()
model_hash, paths = recorded["args"]
assert model_hash == "a" * 64
assert paths[0][0].startswith(str(model_folder))
async def test_import_images_rejects_missing_parameters(monkeypatch: pytest.MonkeyPatch) -> None:
with pytest.raises(processor_module.ExampleImagesValidationError):
await processor_module.ExampleImagesProcessor.import_images("", [])
with pytest.raises(processor_module.ExampleImagesValidationError):
await processor_module.ExampleImagesProcessor.import_images("abc", [])
async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
async def _empty_scanner(cls=None):
return StubScanner([])
monkeypatch.setattr(processor_module.ServiceRegistry, "get_lora_scanner", classmethod(_empty_scanner))
monkeypatch.setattr(processor_module.ServiceRegistry, "get_checkpoint_scanner", classmethod(_empty_scanner))
monkeypatch.setattr(processor_module.ServiceRegistry, "get_embedding_scanner", classmethod(_empty_scanner))
with pytest.raises(processor_module.ExampleImagesImportError):
await processor_module.ExampleImagesProcessor.import_images("a" * 64, [str(tmp_path / "missing.png")])