From 1da476d858382d9a0bdc507afcb0c89ef829b41a Mon Sep 17 00:00:00 2001 From: Will Miao Date: Mon, 2 Feb 2026 12:31:07 +0800 Subject: [PATCH] feat(example-images): add check pending models endpoint and improve async handling - Add /api/example-images/check-pending endpoint to quickly check models needing downloads - Improve DownloadManager.start_download() to return immediately without blocking - Add _handle_download_task_done callback for proper error handling and progress saving - Add check_pending_models() method for lightweight pre-download validation - Update frontend ExampleImagesManager to use new check-pending endpoint - Add comprehensive tests for new functionality --- py/routes/example_images_route_registrar.py | 1 + py/routes/handlers/example_images_handlers.py | 14 + py/utils/example_images_download_manager.py | 151 ++++++- static/js/managers/ExampleImagesManager.js | 80 +++- ...example_images_route_registrar_handlers.py | 136 +++++++ tests/routes/test_example_images_routes.py | 1 + tests/services/test_check_pending_models.py | 368 ++++++++++++++++++ 7 files changed, 734 insertions(+), 17 deletions(-) create mode 100644 tests/services/test_check_pending_models.py diff --git a/py/routes/example_images_route_registrar.py b/py/routes/example_images_route_registrar.py index e3e4b564..0c146c6e 100644 --- a/py/routes/example_images_route_registrar.py +++ b/py/routes/example_images_route_registrar.py @@ -30,6 +30,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"), RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"), RouteDefinition("POST", "/api/lm/example-images/set-nsfw-level", "set_example_image_nsfw_level"), + RouteDefinition("POST", "/api/lm/check-example-images-needed", "check_example_images_needed"), ) diff --git a/py/routes/handlers/example_images_handlers.py b/py/routes/handlers/example_images_handlers.py index 8fe293d6..9db58188 100644 --- a/py/routes/handlers/example_images_handlers.py +++ b/py/routes/handlers/example_images_handlers.py @@ -92,6 +92,19 @@ class ExampleImagesDownloadHandler: except ExampleImagesDownloadError as exc: return web.json_response({'success': False, 'error': str(exc)}, status=500) + async def check_example_images_needed(self, request: web.Request) -> web.StreamResponse: + """Lightweight check to see if any models need example images downloaded.""" + try: + payload = await request.json() + model_types = payload.get('model_types', ['lora', 'checkpoint', 'embedding']) + result = await self._download_manager.check_pending_models(model_types) + return web.json_response(result) + except Exception as exc: + return web.json_response( + {'success': False, 'error': str(exc)}, + status=500 + ) + class ExampleImagesManagementHandler: """HTTP adapters for import/delete endpoints.""" @@ -161,6 +174,7 @@ class ExampleImagesHandlerSet: "resume_example_images": self.download.resume_example_images, "stop_example_images": self.download.stop_example_images, "force_download_example_images": self.download.force_download_example_images, + "check_example_images_needed": self.download.check_example_images_needed, "import_example_images": self.management.import_example_images, "delete_example_image": self.management.delete_example_image, "set_example_image_nsfw_level": self.management.set_example_image_nsfw_level, diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index 56022eba..990d28e8 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -216,6 +216,11 @@ class DownloadManager: self._progress["failed_models"] = set() self._is_downloading = True + snapshot = self._progress.snapshot() + + # Create the download task without awaiting it + # This ensures the HTTP response is returned immediately + # while the actual processing happens in the background self._download_task = asyncio.create_task( self._download_all_example_images( output_dir, @@ -227,7 +232,10 @@ class DownloadManager: ) ) - snapshot = self._progress.snapshot() + # Add a callback to handle task completion/errors + self._download_task.add_done_callback( + lambda t: self._handle_download_task_done(t, output_dir) + ) except ExampleImagesDownloadError: # Re-raise our own exception types without wrapping self._is_downloading = False @@ -241,10 +249,25 @@ class DownloadManager: ) raise ExampleImagesDownloadError(str(e)) from e - await self._broadcast_progress(status="running") + # Broadcast progress in the background without blocking the response + # This ensures the HTTP response is returned immediately + asyncio.create_task(self._broadcast_progress(status="running")) return {"success": True, "message": "Download started", "status": snapshot} + def _handle_download_task_done(self, task: asyncio.Task, output_dir: str) -> None: + """Handle download task completion, including saving progress on error.""" + try: + # This will re-raise any exception from the task + task.result() + except Exception as e: + logger.error(f"Download task failed with error: {e}", exc_info=True) + # Ensure progress is saved even on failure + try: + self._save_progress(output_dir) + except Exception as save_error: + logger.error(f"Failed to save progress after task failure: {save_error}") + async def get_status(self, request): """Get the current status of example images download.""" @@ -254,6 +277,130 @@ class DownloadManager: "status": self._progress.snapshot(), } + async def check_pending_models(self, model_types: list[str]) -> dict: + """Quickly check how many models need example images downloaded. + + This is a lightweight check that avoids the overhead of starting + a full download task when no work is needed. + + Returns: + dict with keys: + - total_models: Total number of models across specified types + - pending_count: Number of models needing example images + - processed_count: Number of already processed models + - failed_count: Number of models marked as failed + - needs_download: True if there are pending models to process + """ + from ..services.service_registry import ServiceRegistry + + if self._is_downloading: + return { + "success": True, + "is_downloading": True, + "total_models": 0, + "pending_count": 0, + "processed_count": 0, + "failed_count": 0, + "needs_download": False, + "message": "Download already in progress", + } + + try: + # Get scanners + scanners = [] + if "lora" in model_types: + lora_scanner = await ServiceRegistry.get_lora_scanner() + scanners.append(("lora", lora_scanner)) + + if "checkpoint" in model_types: + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + scanners.append(("checkpoint", checkpoint_scanner)) + + if "embedding" in model_types: + embedding_scanner = await ServiceRegistry.get_embedding_scanner() + scanners.append(("embedding", embedding_scanner)) + + # Load progress file to check processed models + settings_manager = get_settings_manager() + active_library = settings_manager.get_active_library_name() + output_dir = self._resolve_output_dir(active_library) + + processed_models: set[str] = set() + failed_models: set[str] = set() + + if output_dir: + progress_file = os.path.join(output_dir, ".download_progress.json") + if os.path.exists(progress_file): + try: + with open(progress_file, "r", encoding="utf-8") as f: + saved_progress = json.load(f) + processed_models = set(saved_progress.get("processed_models", [])) + failed_models = set(saved_progress.get("failed_models", [])) + except Exception: + pass # Ignore progress file errors for quick check + + # Count models + total_models = 0 + models_with_hash = 0 + + for scanner_type, scanner in scanners: + cache = await scanner.get_cached_data() + if cache and cache.raw_data: + for model in cache.raw_data: + total_models += 1 + if model.get("sha256"): + models_with_hash += 1 + + # Calculate pending count + # A model is pending if it has a hash and is not in processed_models + # We also exclude failed_models unless force mode would be used + pending_count = models_with_hash - len(processed_models.intersection( + {m.get("sha256", "").lower() for scanner_type, scanner in scanners + for m in (await scanner.get_cached_data()).raw_data if m.get("sha256")} + )) + + # More accurate pending count: check which models actually need processing + pending_hashes = set() + for scanner_type, scanner in scanners: + cache = await scanner.get_cached_data() + if cache and cache.raw_data: + for model in cache.raw_data: + raw_hash = model.get("sha256") + if not raw_hash: + continue + model_hash = raw_hash.lower() + if model_hash not in processed_models: + # Check if model folder exists with files + model_dir = ExampleImagePathResolver.get_model_folder( + model_hash, active_library + ) + if not _model_directory_has_files(model_dir): + pending_hashes.add(model_hash) + + pending_count = len(pending_hashes) + + return { + "success": True, + "is_downloading": False, + "total_models": total_models, + "pending_count": pending_count, + "processed_count": len(processed_models), + "failed_count": len(failed_models), + "needs_download": pending_count > 0, + } + + except Exception as e: + logger.error(f"Error checking pending models: {e}", exc_info=True) + return { + "success": False, + "error": str(e), + "total_models": 0, + "pending_count": 0, + "processed_count": 0, + "failed_count": 0, + "needs_download": False, + } + async def pause_download(self, request): """Pause the example images download.""" diff --git a/static/js/managers/ExampleImagesManager.js b/static/js/managers/ExampleImagesManager.js index c93dd456..9b2f60d8 100644 --- a/static/js/managers/ExampleImagesManager.js +++ b/static/js/managers/ExampleImagesManager.js @@ -21,7 +21,7 @@ export class ExampleImagesManager { // Auto download properties this.autoDownloadInterval = null; this.lastAutoDownloadCheck = 0; - this.autoDownloadCheckInterval = 10 * 60 * 1000; // 10 minutes in milliseconds + this.autoDownloadCheckInterval = 30 * 60 * 1000; // 30 minutes in milliseconds this.pageInitTime = Date.now(); // Track when page was initialized // Initialize download path field and check download status @@ -808,19 +808,58 @@ export class ExampleImagesManager { return; } - this.lastAutoDownloadCheck = now; - if (!this.canAutoDownload()) { console.log('Auto download conditions not met, skipping check'); return; } try { - console.log('Performing auto download check...'); + console.log('Performing auto download pre-check...'); + // Step 1: Lightweight pre-check to see if any work is needed + const checkResponse = await fetch('/api/lm/check-example-images-needed', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + model_types: ['lora', 'checkpoint', 'embedding'] + }) + }); + + if (!checkResponse.ok) { + console.warn('Auto download pre-check HTTP error:', checkResponse.status); + return; + } + + const checkData = await checkResponse.json(); + + if (!checkData.success) { + console.warn('Auto download pre-check failed:', checkData.error); + return; + } + + // Update the check timestamp only after successful pre-check + this.lastAutoDownloadCheck = now; + + // If download already in progress, skip + if (checkData.is_downloading) { + console.log('Download already in progress, skipping auto check'); + return; + } + + // If no models need downloading, skip + if (!checkData.needs_download || checkData.pending_count === 0) { + console.log(`Auto download pre-check complete: ${checkData.processed_count}/${checkData.total_models} models already processed, no work needed`); + return; + } + + console.log(`Auto download pre-check: ${checkData.pending_count} models need processing, starting download...`); + + // Step 2: Start the actual download (fire-and-forget) const optimize = state.global.settings.optimize_example_images; - const response = await fetch('/api/lm/download-example-images', { + fetch('/api/lm/download-example-images', { method: 'POST', headers: { 'Content-Type': 'application/json' @@ -830,18 +869,29 @@ export class ExampleImagesManager { model_types: ['lora', 'checkpoint', 'embedding'], auto_mode: true // Flag to indicate this is an automatic download }) + }).then(response => { + if (!response.ok) { + console.warn('Auto download start HTTP error:', response.status); + return null; + } + return response.json(); + }).then(data => { + if (data && !data.success) { + console.warn('Auto download start failed:', data.error); + // If already in progress, push back the next check to avoid hammering the API + if (data.error && data.error.includes('already in progress')) { + console.log('Download already in progress, backing off next check'); + this.lastAutoDownloadCheck = now + (5 * 60 * 1000); // Back off for 5 extra minutes + } + } else if (data && data.success) { + console.log('Auto download started:', data.message || 'Download started'); + } + }).catch(error => { + console.error('Auto download start error:', error); }); - const data = await response.json(); - - if (!data.success) { - console.warn('Auto download check failed:', data.error); - // If already in progress, push back the next check to avoid hammering the API - if (data.error && data.error.includes('already in progress')) { - console.log('Download already in progress, backing off next check'); - this.lastAutoDownloadCheck = now + (5 * 60 * 1000); // Back off for 5 extra minutes - } - } + // Immediately return without waiting for the download fetch to complete + // This keeps the UI responsive } catch (error) { console.error('Auto download check error:', error); } diff --git a/tests/routes/test_example_images_route_registrar_handlers.py b/tests/routes/test_example_images_route_registrar_handlers.py index 9f119fe9..a214dbbc 100644 --- a/tests/routes/test_example_images_route_registrar_handlers.py +++ b/tests/routes/test_example_images_route_registrar_handlers.py @@ -47,6 +47,8 @@ class StubDownloadManager: self.resume_error: Exception | None = None self.stop_error: Exception | None = None self.force_error: Exception | None = None + self.check_pending_result: dict[str, Any] | None = None + self.check_pending_calls: list[list[str]] = [] async def get_status(self, request: web.Request) -> dict[str, Any]: return {"success": True, "status": "idle"} @@ -75,6 +77,20 @@ class StubDownloadManager: raise self.force_error return {"success": True, "payload": payload} + async def check_pending_models(self, model_types: list[str]) -> dict[str, Any]: + self.check_pending_calls.append(model_types) + if self.check_pending_result is not None: + return self.check_pending_result + return { + "success": True, + "is_downloading": False, + "total_models": 100, + "pending_count": 10, + "processed_count": 90, + "failed_count": 0, + "needs_download": True, + } + class StubImportUseCase: def __init__(self) -> None: @@ -236,3 +252,123 @@ async def test_import_route_returns_validation_errors(): assert response.status == 400 body = await _json(response) assert body == {"success": False, "error": "bad payload"} + + +async def test_check_example_images_needed_returns_pending_counts(): + """Test that check_example_images_needed endpoint returns pending model counts.""" + async with registrar_app() as harness: + harness.download_manager.check_pending_result = { + "success": True, + "is_downloading": False, + "total_models": 5500, + "pending_count": 12, + "processed_count": 5488, + "failed_count": 45, + "needs_download": True, + } + + response = await harness.client.post( + "/api/lm/check-example-images-needed", + json={"model_types": ["lora", "checkpoint"]}, + ) + + assert response.status == 200 + body = await _json(response) + assert body["success"] is True + assert body["total_models"] == 5500 + assert body["pending_count"] == 12 + assert body["processed_count"] == 5488 + assert body["failed_count"] == 45 + assert body["needs_download"] is True + assert body["is_downloading"] is False + + # Verify the manager was called with correct model types + assert harness.download_manager.check_pending_calls == [["lora", "checkpoint"]] + + +async def test_check_example_images_needed_handles_download_in_progress(): + """Test that check_example_images_needed returns correct status when download is running.""" + async with registrar_app() as harness: + harness.download_manager.check_pending_result = { + "success": True, + "is_downloading": True, + "total_models": 0, + "pending_count": 0, + "processed_count": 0, + "failed_count": 0, + "needs_download": False, + "message": "Download already in progress", + } + + response = await harness.client.post( + "/api/lm/check-example-images-needed", + json={"model_types": ["lora"]}, + ) + + assert response.status == 200 + body = await _json(response) + assert body["success"] is True + assert body["is_downloading"] is True + assert body["needs_download"] is False + + +async def test_check_example_images_needed_handles_no_pending_models(): + """Test that check_example_images_needed returns correct status when no work is needed.""" + async with registrar_app() as harness: + harness.download_manager.check_pending_result = { + "success": True, + "is_downloading": False, + "total_models": 5500, + "pending_count": 0, + "processed_count": 5500, + "failed_count": 0, + "needs_download": False, + } + + response = await harness.client.post( + "/api/lm/check-example-images-needed", + json={"model_types": ["lora", "checkpoint", "embedding"]}, + ) + + assert response.status == 200 + body = await _json(response) + assert body["success"] is True + assert body["pending_count"] == 0 + assert body["needs_download"] is False + assert body["processed_count"] == 5500 + + +async def test_check_example_images_needed_uses_default_model_types(): + """Test that check_example_images_needed uses default model types when not specified.""" + async with registrar_app() as harness: + response = await harness.client.post( + "/api/lm/check-example-images-needed", + json={}, # No model_types specified + ) + + assert response.status == 200 + # Should use default model types + assert harness.download_manager.check_pending_calls == [["lora", "checkpoint", "embedding"]] + + +async def test_check_example_images_needed_returns_error_on_exception(): + """Test that check_example_images_needed returns 500 on internal error.""" + async with registrar_app() as harness: + # Simulate an error by setting result to an error state + # Actually, we need to make the method raise an exception + original_method = harness.download_manager.check_pending_models + + async def failing_check(_model_types): + raise RuntimeError("Database connection failed") + + harness.download_manager.check_pending_models = failing_check + + response = await harness.client.post( + "/api/lm/check-example-images-needed", + json={"model_types": ["lora"]}, + ) + + assert response.status == 500 + body = await _json(response) + assert body["success"] is False + assert "Database connection failed" in body["error"] diff --git a/tests/routes/test_example_images_routes.py b/tests/routes/test_example_images_routes.py index e380c967..bc926e65 100644 --- a/tests/routes/test_example_images_routes.py +++ b/tests/routes/test_example_images_routes.py @@ -502,6 +502,7 @@ def test_handler_set_route_mapping_includes_all_handlers() -> None: "resume_example_images", "stop_example_images", "force_download_example_images", + "check_example_images_needed", "import_example_images", "delete_example_image", "set_example_image_nsfw_level", diff --git a/tests/services/test_check_pending_models.py b/tests/services/test_check_pending_models.py new file mode 100644 index 00000000..518e665f --- /dev/null +++ b/tests/services/test_check_pending_models.py @@ -0,0 +1,368 @@ +"""Tests for the check_pending_models lightweight pre-check functionality.""" +from __future__ import annotations + +import json +from types import SimpleNamespace + +import pytest + +from py.services.settings_manager import get_settings_manager +from py.utils import example_images_download_manager as download_module + + +class StubScanner: + """Scanner double returning predetermined cache contents.""" + + def __init__(self, models: list[dict]) -> None: + self._cache = SimpleNamespace(raw_data=models) + + async def get_cached_data(self): + return self._cache + + +def _patch_scanners( + monkeypatch: pytest.MonkeyPatch, + lora_scanner: StubScanner | None = None, + checkpoint_scanner: StubScanner | None = None, + embedding_scanner: StubScanner | None = None, +) -> None: + """Patch ServiceRegistry to return stub scanners.""" + + async def _get_lora_scanner(cls): + return lora_scanner or StubScanner([]) + + async def _get_checkpoint_scanner(cls): + return checkpoint_scanner or StubScanner([]) + + async def _get_embedding_scanner(cls): + return embedding_scanner or StubScanner([]) + + monkeypatch.setattr( + download_module.ServiceRegistry, + "get_lora_scanner", + classmethod(_get_lora_scanner), + ) + monkeypatch.setattr( + download_module.ServiceRegistry, + "get_checkpoint_scanner", + classmethod(_get_checkpoint_scanner), + ) + monkeypatch.setattr( + download_module.ServiceRegistry, + "get_embedding_scanner", + classmethod(_get_embedding_scanner), + ) + + +class RecordingWebSocketManager: + """Collects broadcast payloads for assertions.""" + + def __init__(self) -> None: + self.payloads: list[dict] = [] + + async def broadcast(self, payload: dict) -> None: + self.payloads.append(payload) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("tmp_path") +async def test_check_pending_models_returns_zero_when_all_processed( + monkeypatch: pytest.MonkeyPatch, + tmp_path, + settings_manager, +): + """Test that check_pending_models returns 0 pending when all models are processed.""" + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path)) + + # Create processed models + processed_hashes = ["a" * 64, "b" * 64, "c" * 64] + models = [ + {"sha256": h, "model_name": f"Model {i}"} + for i, h in enumerate(processed_hashes) + ] + + # Create progress file with all models processed + progress_file = tmp_path / ".download_progress.json" + progress_file.write_text( + json.dumps({"processed_models": processed_hashes, "failed_models": []}), + encoding="utf-8", + ) + + # Create model directories with files (simulating completed downloads) + for h in processed_hashes: + model_dir = tmp_path / h + model_dir.mkdir() + (model_dir / "image_0.png").write_text("data") + + _patch_scanners(monkeypatch, lora_scanner=StubScanner(models)) + + result = await manager.check_pending_models(["lora"]) + + assert result["success"] is True + assert result["is_downloading"] is False + assert result["total_models"] == 3 + assert result["pending_count"] == 0 + assert result["processed_count"] == 3 + assert result["needs_download"] is False + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("tmp_path") +async def test_check_pending_models_finds_unprocessed_models( + monkeypatch: pytest.MonkeyPatch, + tmp_path, + settings_manager, +): + """Test that check_pending_models correctly identifies unprocessed models.""" + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path)) + + # Create models - some processed, some not + processed_hash = "a" * 64 + unprocessed_hash = "b" * 64 + models = [ + {"sha256": processed_hash, "model_name": "Processed Model"}, + {"sha256": unprocessed_hash, "model_name": "Unprocessed Model"}, + ] + + # Create progress file with only one model processed + progress_file = tmp_path / ".download_progress.json" + progress_file.write_text( + json.dumps({"processed_models": [processed_hash], "failed_models": []}), + encoding="utf-8", + ) + + # Create directory only for processed model + processed_dir = tmp_path / processed_hash + processed_dir.mkdir() + (processed_dir / "image_0.png").write_text("data") + + _patch_scanners(monkeypatch, lora_scanner=StubScanner(models)) + + result = await manager.check_pending_models(["lora"]) + + assert result["success"] is True + assert result["total_models"] == 2 + assert result["pending_count"] == 1 + assert result["processed_count"] == 1 + assert result["needs_download"] is True + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("tmp_path") +async def test_check_pending_models_skips_models_without_hash( + monkeypatch: pytest.MonkeyPatch, + tmp_path, + settings_manager, +): + """Test that models without sha256 are not counted as pending.""" + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path)) + + # Models - one with hash, one without + models = [ + {"sha256": "a" * 64, "model_name": "Hashed Model"}, + {"sha256": None, "model_name": "No Hash Model"}, + {"model_name": "Missing Hash Model"}, # No sha256 key at all + ] + + _patch_scanners(monkeypatch, lora_scanner=StubScanner(models)) + + result = await manager.check_pending_models(["lora"]) + + assert result["success"] is True + assert result["total_models"] == 3 + assert result["pending_count"] == 1 # Only the one with hash + assert result["needs_download"] is True + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("tmp_path") +async def test_check_pending_models_handles_multiple_model_types( + monkeypatch: pytest.MonkeyPatch, + tmp_path, + settings_manager, +): + """Test that check_pending_models aggregates counts across multiple model types.""" + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path)) + + lora_models = [ + {"sha256": "a" * 64, "model_name": "Lora 1"}, + {"sha256": "b" * 64, "model_name": "Lora 2"}, + ] + checkpoint_models = [ + {"sha256": "c" * 64, "model_name": "Checkpoint 1"}, + ] + embedding_models = [ + {"sha256": "d" * 64, "model_name": "Embedding 1"}, + {"sha256": "e" * 64, "model_name": "Embedding 2"}, + {"sha256": "f" * 64, "model_name": "Embedding 3"}, + ] + + _patch_scanners( + monkeypatch, + lora_scanner=StubScanner(lora_models), + checkpoint_scanner=StubScanner(checkpoint_models), + embedding_scanner=StubScanner(embedding_models), + ) + + result = await manager.check_pending_models(["lora", "checkpoint", "embedding"]) + + assert result["success"] is True + assert result["total_models"] == 6 # 2 + 1 + 3 + assert result["pending_count"] == 6 # All unprocessed + assert result["needs_download"] is True + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("tmp_path") +async def test_check_pending_models_returns_error_when_download_in_progress( + monkeypatch: pytest.MonkeyPatch, + tmp_path, + settings_manager, +): + """Test that check_pending_models returns special response when download is running.""" + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path)) + + # Simulate download in progress + manager._is_downloading = True + + result = await manager.check_pending_models(["lora"]) + + assert result["success"] is True + assert result["is_downloading"] is True + assert result["needs_download"] is False + assert result["pending_count"] == 0 + assert "already in progress" in result["message"].lower() + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("tmp_path") +async def test_check_pending_models_handles_empty_library( + monkeypatch: pytest.MonkeyPatch, + tmp_path, + settings_manager, +): + """Test that check_pending_models handles empty model library.""" + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path)) + + _patch_scanners(monkeypatch, lora_scanner=StubScanner([])) + + result = await manager.check_pending_models(["lora"]) + + assert result["success"] is True + assert result["total_models"] == 0 + assert result["pending_count"] == 0 + assert result["processed_count"] == 0 + assert result["needs_download"] is False + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("tmp_path") +async def test_check_pending_models_reads_failed_models( + monkeypatch: pytest.MonkeyPatch, + tmp_path, + settings_manager, +): + """Test that check_pending_models correctly reports failed model count.""" + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path)) + + models = [{"sha256": "a" * 64, "model_name": "Model"}] + + # Create progress file with failed models + progress_file = tmp_path / ".download_progress.json" + progress_file.write_text( + json.dumps({"processed_models": [], "failed_models": ["a" * 64, "b" * 64]}), + encoding="utf-8", + ) + + _patch_scanners(monkeypatch, lora_scanner=StubScanner(models)) + + result = await manager.check_pending_models(["lora"]) + + assert result["success"] is True + assert result["failed_count"] == 2 + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("tmp_path") +async def test_check_pending_models_handles_missing_progress_file( + monkeypatch: pytest.MonkeyPatch, + tmp_path, + settings_manager, +): + """Test that check_pending_models works correctly when no progress file exists.""" + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path)) + + models = [ + {"sha256": "a" * 64, "model_name": "Model 1"}, + {"sha256": "b" * 64, "model_name": "Model 2"}, + ] + + _patch_scanners(monkeypatch, lora_scanner=StubScanner(models)) + + # No progress file created + result = await manager.check_pending_models(["lora"]) + + assert result["success"] is True + assert result["total_models"] == 2 + assert result["pending_count"] == 2 # All pending since no progress + assert result["processed_count"] == 0 + assert result["failed_count"] == 0 + assert result["needs_download"] is True + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("tmp_path") +async def test_check_pending_models_handles_corrupted_progress_file( + monkeypatch: pytest.MonkeyPatch, + tmp_path, + settings_manager, +): + """Test that check_pending_models handles corrupted progress file gracefully.""" + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path)) + + models = [{"sha256": "a" * 64, "model_name": "Model"}] + + # Create corrupted progress file + progress_file = tmp_path / ".download_progress.json" + progress_file.write_text("not valid json", encoding="utf-8") + + _patch_scanners(monkeypatch, lora_scanner=StubScanner(models)) + + result = await manager.check_pending_models(["lora"]) + + # Should still succeed, treating all as unprocessed + assert result["success"] is True + assert result["total_models"] == 1 + assert result["pending_count"] == 1 + + +@pytest.fixture +def settings_manager(): + return get_settings_manager()