feat(example-images): add check pending models endpoint and improve async handling

- Add /api/example-images/check-pending endpoint to quickly check models needing downloads
- Improve DownloadManager.start_download() to return immediately without blocking (sketched below)
- Add _handle_download_task_done callback for proper error handling and progress saving
- Add check_pending_models() method for lightweight pre-download validation
- Update frontend ExampleImagesManager to use new check-pending endpoint
- Add comprehensive tests for new functionality
Author: Will Miao
Date: 2026-02-02 12:31:07 +08:00
Parent: 1daaff6bd4
Commit: 1da476d858
7 changed files with 734 additions and 17 deletions
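
This commit's diff covers only the tests, so the reworked `start_download()`/`_handle_download_task_done` pair is not shown. A minimal sketch of the fire-and-forget pattern the bullets describe follows; the internals (`_run_download`, `_save_progress`, attribute names) are assumptions for illustration, not code from this commit:

```python
# Hypothetical sketch: schedule the download as a background task and return
# immediately, handling errors and progress persistence in a done callback.
import asyncio
import logging

logger = logging.getLogger(__name__)


class DownloadManager:
    def __init__(self) -> None:
        self._is_downloading = False
        self._download_task: asyncio.Task | None = None

    async def start_download(self, model_types: list[str]) -> dict:
        if self._is_downloading:
            return {"success": False, "error": "Download already in progress"}
        self._is_downloading = True
        # Schedule the work instead of awaiting it, so the caller gets an
        # immediate response while the download runs in the background.
        self._download_task = asyncio.create_task(self._run_download(model_types))
        self._download_task.add_done_callback(self._handle_download_task_done)
        return {"success": True}

    def _handle_download_task_done(self, task: asyncio.Task) -> None:
        # Runs when the background task finishes: clear the flag, log any
        # failure, and persist progress so a later run can resume.
        self._is_downloading = False
        if not task.cancelled() and task.exception() is not None:
            logger.error("Example image download failed", exc_info=task.exception())
        self._save_progress()  # assumed persistence hook

    async def _run_download(self, model_types: list[str]) -> None:
        ...  # download loop elided

    def _save_progress(self) -> None:
        ...  # write .download_progress.json
```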

@@ -47,6 +47,8 @@ class StubDownloadManager:
self.resume_error: Exception | None = None
self.stop_error: Exception | None = None
self.force_error: Exception | None = None
self.check_pending_result: dict[str, Any] | None = None
self.check_pending_calls: list[list[str]] = []
async def get_status(self, request: web.Request) -> dict[str, Any]:
return {"success": True, "status": "idle"}
@@ -75,6 +77,20 @@ class StubDownloadManager:
raise self.force_error
return {"success": True, "payload": payload}
async def check_pending_models(self, model_types: list[str]) -> dict[str, Any]:
self.check_pending_calls.append(model_types)
if self.check_pending_result is not None:
return self.check_pending_result
return {
"success": True,
"is_downloading": False,
"total_models": 100,
"pending_count": 10,
"processed_count": 90,
"failed_count": 0,
"needs_download": True,
}
class StubImportUseCase:
def __init__(self) -> None:
@@ -236,3 +252,123 @@ async def test_import_route_returns_validation_errors():
assert response.status == 400
body = await _json(response)
assert body == {"success": False, "error": "bad payload"}
async def test_check_example_images_needed_returns_pending_counts():
"""Test that check_example_images_needed endpoint returns pending model counts."""
async with registrar_app() as harness:
harness.download_manager.check_pending_result = {
"success": True,
"is_downloading": False,
"total_models": 5500,
"pending_count": 12,
"processed_count": 5488,
"failed_count": 45,
"needs_download": True,
}
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora", "checkpoint"]},
)
assert response.status == 200
body = await _json(response)
assert body["success"] is True
assert body["total_models"] == 5500
assert body["pending_count"] == 12
assert body["processed_count"] == 5488
assert body["failed_count"] == 45
assert body["needs_download"] is True
assert body["is_downloading"] is False
# Verify the manager was called with correct model types
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint"]]
async def test_check_example_images_needed_handles_download_in_progress():
"""Test that check_example_images_needed returns correct status when download is running."""
async with registrar_app() as harness:
harness.download_manager.check_pending_result = {
"success": True,
"is_downloading": True,
"total_models": 0,
"pending_count": 0,
"processed_count": 0,
"failed_count": 0,
"needs_download": False,
"message": "Download already in progress",
}
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora"]},
)
assert response.status == 200
body = await _json(response)
assert body["success"] is True
assert body["is_downloading"] is True
assert body["needs_download"] is False
async def test_check_example_images_needed_handles_no_pending_models():
"""Test that check_example_images_needed returns correct status when no work is needed."""
async with registrar_app() as harness:
harness.download_manager.check_pending_result = {
"success": True,
"is_downloading": False,
"total_models": 5500,
"pending_count": 0,
"processed_count": 5500,
"failed_count": 0,
"needs_download": False,
}
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora", "checkpoint", "embedding"]},
)
assert response.status == 200
body = await _json(response)
assert body["success"] is True
assert body["pending_count"] == 0
assert body["needs_download"] is False
assert body["processed_count"] == 5500
async def test_check_example_images_needed_uses_default_model_types():
"""Test that check_example_images_needed uses default model types when not specified."""
async with registrar_app() as harness:
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={}, # No model_types specified
)
assert response.status == 200
# Should use default model types
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint", "embedding"]]
async def test_check_example_images_needed_returns_error_on_exception():
"""Test that check_example_images_needed returns 500 on internal error."""
async with registrar_app() as harness:
        # Replace the stubbed method with one that raises, exercising the error path
async def failing_check(_model_types):
raise RuntimeError("Database connection failed")
harness.download_manager.check_pending_models = failing_check
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora"]},
)
assert response.status == 500
body = await _json(response)
assert body["success"] is False
assert "Database connection failed" in body["error"]

@@ -502,6 +502,7 @@ def test_handler_set_route_mapping_includes_all_handlers() -> None:
"resume_example_images",
"stop_example_images",
"force_download_example_images",
"check_example_images_needed",
"import_example_images",
"delete_example_image",
"set_example_image_nsfw_level",

@@ -0,0 +1,368 @@
"""Tests for the check_pending_models lightweight pre-check functionality."""
from __future__ import annotations
import json
from types import SimpleNamespace
import pytest
from py.services.settings_manager import get_settings_manager
from py.utils import example_images_download_manager as download_module
class StubScanner:
"""Scanner double returning predetermined cache contents."""
def __init__(self, models: list[dict]) -> None:
self._cache = SimpleNamespace(raw_data=models)
async def get_cached_data(self):
return self._cache
def _patch_scanners(
monkeypatch: pytest.MonkeyPatch,
lora_scanner: StubScanner | None = None,
checkpoint_scanner: StubScanner | None = None,
embedding_scanner: StubScanner | None = None,
) -> None:
"""Patch ServiceRegistry to return stub scanners."""
async def _get_lora_scanner(cls):
return lora_scanner or StubScanner([])
async def _get_checkpoint_scanner(cls):
return checkpoint_scanner or StubScanner([])
async def _get_embedding_scanner(cls):
return embedding_scanner or StubScanner([])
monkeypatch.setattr(
download_module.ServiceRegistry,
"get_lora_scanner",
classmethod(_get_lora_scanner),
)
monkeypatch.setattr(
download_module.ServiceRegistry,
"get_checkpoint_scanner",
classmethod(_get_checkpoint_scanner),
)
monkeypatch.setattr(
download_module.ServiceRegistry,
"get_embedding_scanner",
classmethod(_get_embedding_scanner),
)
class RecordingWebSocketManager:
"""Collects broadcast payloads for assertions."""
def __init__(self) -> None:
self.payloads: list[dict] = []
async def broadcast(self, payload: dict) -> None:
self.payloads.append(payload)
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_returns_zero_when_all_processed(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models returns 0 pending when all models are processed."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
# Create processed models
processed_hashes = ["a" * 64, "b" * 64, "c" * 64]
models = [
{"sha256": h, "model_name": f"Model {i}"}
for i, h in enumerate(processed_hashes)
]
# Create progress file with all models processed
progress_file = tmp_path / ".download_progress.json"
progress_file.write_text(
json.dumps({"processed_models": processed_hashes, "failed_models": []}),
encoding="utf-8",
)
# Create model directories with files (simulating completed downloads)
for h in processed_hashes:
model_dir = tmp_path / h
model_dir.mkdir()
(model_dir / "image_0.png").write_text("data")
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["is_downloading"] is False
assert result["total_models"] == 3
assert result["pending_count"] == 0
assert result["processed_count"] == 3
assert result["needs_download"] is False
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_finds_unprocessed_models(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models correctly identifies unprocessed models."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
# Create models - some processed, some not
processed_hash = "a" * 64
unprocessed_hash = "b" * 64
models = [
{"sha256": processed_hash, "model_name": "Processed Model"},
{"sha256": unprocessed_hash, "model_name": "Unprocessed Model"},
]
# Create progress file with only one model processed
progress_file = tmp_path / ".download_progress.json"
progress_file.write_text(
json.dumps({"processed_models": [processed_hash], "failed_models": []}),
encoding="utf-8",
)
# Create directory only for processed model
processed_dir = tmp_path / processed_hash
processed_dir.mkdir()
(processed_dir / "image_0.png").write_text("data")
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["total_models"] == 2
assert result["pending_count"] == 1
assert result["processed_count"] == 1
assert result["needs_download"] is True
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_skips_models_without_hash(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that models without sha256 are not counted as pending."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
# Models - one with hash, one without
models = [
{"sha256": "a" * 64, "model_name": "Hashed Model"},
{"sha256": None, "model_name": "No Hash Model"},
{"model_name": "Missing Hash Model"}, # No sha256 key at all
]
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["total_models"] == 3
assert result["pending_count"] == 1 # Only the one with hash
assert result["needs_download"] is True
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_handles_multiple_model_types(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models aggregates counts across multiple model types."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
lora_models = [
{"sha256": "a" * 64, "model_name": "Lora 1"},
{"sha256": "b" * 64, "model_name": "Lora 2"},
]
checkpoint_models = [
{"sha256": "c" * 64, "model_name": "Checkpoint 1"},
]
embedding_models = [
{"sha256": "d" * 64, "model_name": "Embedding 1"},
{"sha256": "e" * 64, "model_name": "Embedding 2"},
{"sha256": "f" * 64, "model_name": "Embedding 3"},
]
_patch_scanners(
monkeypatch,
lora_scanner=StubScanner(lora_models),
checkpoint_scanner=StubScanner(checkpoint_models),
embedding_scanner=StubScanner(embedding_models),
)
result = await manager.check_pending_models(["lora", "checkpoint", "embedding"])
assert result["success"] is True
assert result["total_models"] == 6 # 2 + 1 + 3
assert result["pending_count"] == 6 # All unprocessed
assert result["needs_download"] is True
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_returns_error_when_download_in_progress(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models returns special response when download is running."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
# Simulate download in progress
manager._is_downloading = True
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["is_downloading"] is True
assert result["needs_download"] is False
assert result["pending_count"] == 0
assert "already in progress" in result["message"].lower()
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_handles_empty_library(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models handles empty model library."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
_patch_scanners(monkeypatch, lora_scanner=StubScanner([]))
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["total_models"] == 0
assert result["pending_count"] == 0
assert result["processed_count"] == 0
assert result["needs_download"] is False
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_reads_failed_models(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models correctly reports failed model count."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
models = [{"sha256": "a" * 64, "model_name": "Model"}]
# Create progress file with failed models
progress_file = tmp_path / ".download_progress.json"
progress_file.write_text(
json.dumps({"processed_models": [], "failed_models": ["a" * 64, "b" * 64]}),
encoding="utf-8",
)
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["failed_count"] == 2
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_handles_missing_progress_file(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models works correctly when no progress file exists."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
models = [
{"sha256": "a" * 64, "model_name": "Model 1"},
{"sha256": "b" * 64, "model_name": "Model 2"},
]
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
# No progress file created
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["total_models"] == 2
assert result["pending_count"] == 2 # All pending since no progress
assert result["processed_count"] == 0
assert result["failed_count"] == 0
assert result["needs_download"] is True
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_handles_corrupted_progress_file(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models handles corrupted progress file gracefully."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
models = [{"sha256": "a" * 64, "model_name": "Model"}]
# Create corrupted progress file
progress_file = tmp_path / ".download_progress.json"
progress_file.write_text("not valid json", encoding="utf-8")
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
result = await manager.check_pending_models(["lora"])
# Should still succeed, treating all as unprocessed
assert result["success"] is True
assert result["total_models"] == 1
assert result["pending_count"] == 1
@pytest.fixture
def settings_manager():
return get_settings_manager()