feat(example-images): add check pending models endpoint and improve async handling

- Add /api/example-images/check-pending endpoint to quickly check models needing downloads
- Improve DownloadManager.start_download() to return immediately without blocking
- Add _handle_download_task_done callback for proper error handling and progress saving
- Add check_pending_models() method for lightweight pre-download validation
- Update frontend ExampleImagesManager to use new check-pending endpoint
- Add comprehensive tests for new functionality
This commit is contained in:
Will Miao
2026-02-02 12:31:07 +08:00
parent 1daaff6bd4
commit 1da476d858
7 changed files with 734 additions and 17 deletions

View File

@@ -47,6 +47,8 @@ class StubDownloadManager:
self.resume_error: Exception | None = None
self.stop_error: Exception | None = None
self.force_error: Exception | None = None
self.check_pending_result: dict[str, Any] | None = None
self.check_pending_calls: list[list[str]] = []
async def get_status(self, request: web.Request) -> dict[str, Any]:
return {"success": True, "status": "idle"}
@@ -75,6 +77,20 @@ class StubDownloadManager:
raise self.force_error
return {"success": True, "payload": payload}
async def check_pending_models(self, model_types: list[str]) -> dict[str, Any]:
self.check_pending_calls.append(model_types)
if self.check_pending_result is not None:
return self.check_pending_result
return {
"success": True,
"is_downloading": False,
"total_models": 100,
"pending_count": 10,
"processed_count": 90,
"failed_count": 0,
"needs_download": True,
}
class StubImportUseCase:
def __init__(self) -> None:
@@ -236,3 +252,123 @@ async def test_import_route_returns_validation_errors():
assert response.status == 400
body = await _json(response)
assert body == {"success": False, "error": "bad payload"}
async def test_check_example_images_needed_returns_pending_counts():
"""Test that check_example_images_needed endpoint returns pending model counts."""
async with registrar_app() as harness:
harness.download_manager.check_pending_result = {
"success": True,
"is_downloading": False,
"total_models": 5500,
"pending_count": 12,
"processed_count": 5488,
"failed_count": 45,
"needs_download": True,
}
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora", "checkpoint"]},
)
assert response.status == 200
body = await _json(response)
assert body["success"] is True
assert body["total_models"] == 5500
assert body["pending_count"] == 12
assert body["processed_count"] == 5488
assert body["failed_count"] == 45
assert body["needs_download"] is True
assert body["is_downloading"] is False
# Verify the manager was called with correct model types
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint"]]
async def test_check_example_images_needed_handles_download_in_progress():
"""Test that check_example_images_needed returns correct status when download is running."""
async with registrar_app() as harness:
harness.download_manager.check_pending_result = {
"success": True,
"is_downloading": True,
"total_models": 0,
"pending_count": 0,
"processed_count": 0,
"failed_count": 0,
"needs_download": False,
"message": "Download already in progress",
}
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora"]},
)
assert response.status == 200
body = await _json(response)
assert body["success"] is True
assert body["is_downloading"] is True
assert body["needs_download"] is False
async def test_check_example_images_needed_handles_no_pending_models():
"""Test that check_example_images_needed returns correct status when no work is needed."""
async with registrar_app() as harness:
harness.download_manager.check_pending_result = {
"success": True,
"is_downloading": False,
"total_models": 5500,
"pending_count": 0,
"processed_count": 5500,
"failed_count": 0,
"needs_download": False,
}
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora", "checkpoint", "embedding"]},
)
assert response.status == 200
body = await _json(response)
assert body["success"] is True
assert body["pending_count"] == 0
assert body["needs_download"] is False
assert body["processed_count"] == 5500
async def test_check_example_images_needed_uses_default_model_types():
"""Test that check_example_images_needed uses default model types when not specified."""
async with registrar_app() as harness:
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={}, # No model_types specified
)
assert response.status == 200
# Should use default model types
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint", "embedding"]]
async def test_check_example_images_needed_returns_error_on_exception():
"""Test that check_example_images_needed returns 500 on internal error."""
async with registrar_app() as harness:
# Simulate an error by setting result to an error state
# Actually, we need to make the method raise an exception
original_method = harness.download_manager.check_pending_models
async def failing_check(_model_types):
raise RuntimeError("Database connection failed")
harness.download_manager.check_pending_models = failing_check
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora"]},
)
assert response.status == 500
body = await _json(response)
assert body["success"] is False
assert "Database connection failed" in body["error"]

View File

@@ -502,6 +502,7 @@ def test_handler_set_route_mapping_includes_all_handlers() -> None:
"resume_example_images",
"stop_example_images",
"force_download_example_images",
"check_example_images_needed",
"import_example_images",
"delete_example_image",
"set_example_image_nsfw_level",