diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py
index 990d28e8..b6c1a0a5 100644
--- a/py/utils/example_images_download_manager.py
+++ b/py/utils/example_images_download_manager.py
@@ -121,100 +121,70 @@ class DownloadManager:
 
     async def start_download(self, options: dict):
         """Start downloading example images for models."""
+        # Fast pre-check: reject before any validation or disk I/O when a download
+        # is already running (the authoritative check below runs under the lock)
+        if self._is_downloading:
+            raise DownloadInProgressError(self._progress.snapshot())
+
+        # Step 1: Parse options (fast, non-blocking)
+        data = options or {}
+        auto_mode = data.get("auto_mode", False)
+        optimize = data.get("optimize", True)
+        model_types = data.get("model_types", ["lora", "checkpoint"])
+        delay = float(data.get("delay", 0.2))
+        force = data.get("force", False)
+
+        # Step 2: Validate configuration (fast lookup)
+        settings_manager = get_settings_manager()
+        base_path = settings_manager.get("example_images_path")
+
+        if not base_path:
+            error_msg = "Example images path not configured in settings"
+            if auto_mode:
+                logger.debug(error_msg)
+                return {
+                    "success": True,
+                    "message": "Example images path not configured, skipping auto download",
+                }
+            raise DownloadConfigurationError(error_msg)
+
+        active_library = settings_manager.get_active_library_name()
+        output_dir = self._resolve_output_dir(active_library)
+        if not output_dir:
+            raise DownloadConfigurationError(
+                "Example images path not configured in settings"
+            )
+
+        # Step 3: Load progress file (I/O operation, done outside lock)
+        processed_models = set()
+        failed_models = set()
+
+        try:
+            progress_file, processed_models, failed_models = await self._load_progress_file(output_dir)
+            logger.debug(
+                "Loaded previous progress, %s models already processed, %s models marked as failed",
+                len(processed_models),
+                len(failed_models),
+            )
+        except Exception as e:
+            logger.error(f"Failed to load progress file: {e}")
+            # Continue with empty sets
+
+        # Step 4: Quick state check and update (minimal lock time)
         async with self._state_lock:
             if self._is_downloading:
                 raise DownloadInProgressError(self._progress.snapshot())
 
             try:
-                data = options or {}
-                auto_mode = data.get("auto_mode", False)
-                optimize = data.get("optimize", True)
-                model_types = data.get("model_types", ["lora", "checkpoint"])
-                delay = float(data.get("delay", 0.2))
-                force = data.get("force", False)
-
-                settings_manager = get_settings_manager()
-                base_path = settings_manager.get("example_images_path")
-
-                if not base_path:
-                    error_msg = "Example images path not configured in settings"
-                    if auto_mode:
-                        logger.debug(error_msg)
-                        return {
-                            "success": True,
-                            "message": "Example images path not configured, skipping auto download",
-                        }
-                    raise DownloadConfigurationError(error_msg)
-
-                active_library = get_settings_manager().get_active_library_name()
-                output_dir = self._resolve_output_dir(active_library)
-                if not output_dir:
-                    raise DownloadConfigurationError(
-                        "Example images path not configured in settings"
-                    )
-
+                # Reset progress with loaded data
                 self._progress.reset()
+                self._progress["processed_models"] = processed_models
+                self._progress["failed_models"] = failed_models
                 self._stop_requested = False
                 self._progress["status"] = "running"
                 self._progress["start_time"] = time.time()
                 self._progress["end_time"] = None
 
-                progress_file = os.path.join(output_dir, ".download_progress.json")
-                progress_source = progress_file
-                if uses_library_scoped_folders():
-                    legacy_root = (
-                        get_settings_manager().get("example_images_path") or ""
-                    )
-                    legacy_progress = (
-                        os.path.join(legacy_root, ".download_progress.json")
-                        if legacy_root
-                        else ""
-                    )
-                    if (
-                        legacy_progress
-                        and os.path.exists(legacy_progress)
-                        and not os.path.exists(progress_file)
-                    ):
-                        try:
-                            os.makedirs(output_dir, exist_ok=True)
-                            shutil.move(legacy_progress, progress_file)
-                            logger.info(
-                                "Migrated legacy download progress file '%s' to '%s'",
-                                legacy_progress,
-                                progress_file,
-                            )
-                        except OSError as exc:
-                            logger.warning(
-                                "Failed to migrate download progress file from '%s' to '%s': %s",
-                                legacy_progress,
-                                progress_file,
-                                exc,
-                            )
-                            progress_source = legacy_progress
-
-                if os.path.exists(progress_source):
-                    try:
-                        with open(progress_source, "r", encoding="utf-8") as f:
-                            saved_progress = json.load(f)
-                            self._progress["processed_models"] = set(
-                                saved_progress.get("processed_models", [])
-                            )
-                            self._progress["failed_models"] = set(
-                                saved_progress.get("failed_models", [])
-                            )
-                            logger.debug(
-                                "Loaded previous progress, %s models already processed, %s models marked as failed",
-                                len(self._progress["processed_models"]),
-                                len(self._progress["failed_models"]),
-                            )
-                    except Exception as e:
-                        logger.error(f"Failed to load progress file: {e}")
-                        self._progress["processed_models"] = set()
-                        self._progress["failed_models"] = set()
-                else:
-                    self._progress["processed_models"] = set()
-                    self._progress["failed_models"] = set()
-
                 self._is_downloading = True
                 snapshot = self._progress.snapshot()
 
@@ -268,7 +238,7 @@ class DownloadManager:
             except Exception as save_error:
                 logger.error(f"Failed to save progress after task failure: {save_error}")
 
-    async def get_status(self, request):
+    async def get_status(self, request) -> dict:
         """Get the current status of example images download."""
 
         return {
@@ -277,6 +247,82 @@ class DownloadManager:
             "status": self._progress.snapshot(),
         }
 
+    async def _load_progress_file(self, output_dir: str) -> tuple[str, set, set]:
+        """Load progress file from disk. Returns (progress_file_path, processed_models, failed_models).
+
+        The blocking file work runs in the default executor so the event loop stays responsive.
+        """
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(
+            None, self._load_progress_file_sync, output_dir
+        )
+
+    def _load_progress_file_sync(self, output_dir: str) -> tuple[str, set, set]:
+        """Synchronous implementation of progress file loading.
+
+        Migrates a legacy progress file into output_dir when library-scoped folders
+        are in use, then reads the processed/failed model sets from it.
+        """
+        progress_file = os.path.join(output_dir, ".download_progress.json")
+        progress_source = progress_file
+
+        # Handle legacy migration if needed
+        if uses_library_scoped_folders():
+            legacy_root = get_settings_manager().get("example_images_path") or ""
+            legacy_progress = (
+                os.path.join(legacy_root, ".download_progress.json")
+                if legacy_root
+                else ""
+            )
+            if (
+                legacy_progress
+                and os.path.exists(legacy_progress)
+                and not os.path.exists(progress_file)
+            ):
+                try:
+                    os.makedirs(output_dir, exist_ok=True)
+                    shutil.move(legacy_progress, progress_file)
+                    logger.info(
+                        "Migrated legacy download progress file '%s' to '%s'",
+                        legacy_progress,
+                        progress_file,
+                    )
+                except OSError as exc:
+                    logger.warning(
+                        "Failed to migrate download progress file from '%s' to '%s': %s",
+                        legacy_progress,
+                        progress_file,
+                        exc,
+                    )
+                    progress_source = legacy_progress
+
+        # Reuse the shared reader so both loaders parse and report errors identically
+        processed_models, failed_models = self._load_progress_sets_sync(progress_source)
+        return progress_file, processed_models, failed_models
+
+    def _load_progress_sets_sync(self, progress_file: str) -> tuple[set, set]:
+        """Load only the processed and failed model sets from progress file.
+
+        This is a lighter version for quick checks without legacy migration.
+        Returns (processed_models, failed_models).
+        """
+        processed_models = set()
+        failed_models = set()
+
+        if os.path.exists(progress_file):
+            try:
+                with open(progress_file, "r", encoding="utf-8") as f:
+                    saved_progress = json.load(f)
+                # Parse both sets before assigning so a failure never leaves a partial load
+                loaded_processed = set(saved_progress.get("processed_models", []))
+                loaded_failed = set(saved_progress.get("failed_models", []))
+                processed_models, failed_models = loaded_processed, loaded_failed
+            except Exception as e:
+                # Best effort: fall back to empty sets, but keep the failure visible
+                logger.warning(f"Failed to load progress file: {e}")
+
+        return processed_models, failed_models
+
     async def check_pending_models(self, model_types: list[str]) -> dict:
         """Quickly check how many models need example images downloaded.
 
@@ -320,62 +366,49 @@ class DownloadManager:
                 embedding_scanner = await ServiceRegistry.get_embedding_scanner()
                 scanners.append(("embedding", embedding_scanner))
 
-            # Load progress file to check processed models
+            # Load progress file to check processed models (async to avoid blocking)
             settings_manager = get_settings_manager()
             active_library = settings_manager.get_active_library_name()
             output_dir = self._resolve_output_dir(active_library)
-            
+
             processed_models: set[str] = set()
             failed_models: set[str] = set()
-            
+
             if output_dir:
                 progress_file = os.path.join(output_dir, ".download_progress.json")
-                if os.path.exists(progress_file):
-                    try:
-                        with open(progress_file, "r", encoding="utf-8") as f:
-                            saved_progress = json.load(f)
-                        processed_models = set(saved_progress.get("processed_models", []))
-                        failed_models = set(saved_progress.get("failed_models", []))
-                    except Exception:
-                        pass  # Ignore progress file errors for quick check
+                loop = asyncio.get_running_loop()
+                processed_models, failed_models = await loop.run_in_executor(
+                    None, self._load_progress_sets_sync, progress_file
+                )
 
-            # Count models
+            # Collect all models and count in a single pass per scanner
             total_models = 0
-            models_with_hash = 0
-            
+            model_hashes: list[str] = []  # lower-cased sha256 of every hashed model
+
             for scanner_type, scanner in scanners:
                 cache = await scanner.get_cached_data()
                 if cache and cache.raw_data:
                     for model in cache.raw_data:
                         total_models += 1
-                        if model.get("sha256"):
-                            models_with_hash += 1
-            
-            # Calculate pending count
-            # A model is pending if it has a hash and is not in processed_models
-            # We also exclude failed_models unless force mode would be used
-            pending_count = models_with_hash - len(processed_models.intersection(
-                {m.get("sha256", "").lower() for scanner_type, scanner in scanners
-                 for m in (await scanner.get_cached_data()).raw_data if m.get("sha256")}
-            ))
-            
-            # More accurate pending count: check which models actually need processing
-            pending_hashes = set()
-            for scanner_type, scanner in scanners:
-                cache = await scanner.get_cached_data()
-                if cache and cache.raw_data:
-                    for model in cache.raw_data:
                         raw_hash = model.get("sha256")
-                        if not raw_hash:
-                            continue
-                        model_hash = raw_hash.lower()
-                        if model_hash not in processed_models:
-                            # Check if model folder exists with files
-                            model_dir = ExampleImagePathResolver.get_model_folder(
-                                model_hash, active_library
-                            )
-                            if not _model_directory_has_files(model_dir):
-                                pending_hashes.add(model_hash)
+                        if raw_hash:
+                            model_hash = raw_hash.lower()
+                            model_hashes.append(model_hash)
+
+            models_with_hash = len(model_hashes)
+
+            # Calculate pending count: check which models actually need processing
+            # A model is pending if it has a hash, is not in processed_models,
+            # and its folder doesn't exist or is empty
+            pending_hashes = set()
+            for model_hash in model_hashes:
+                if model_hash not in processed_models:
+                    # Check if model folder exists with files
+                    model_dir = ExampleImagePathResolver.get_model_folder(
+                        model_hash, active_library
+                    )
+                    if not _model_directory_has_files(model_dir):
+                        pending_hashes.add(model_hash)
 
             pending_count = len(pending_hashes)
 
diff --git a/tests/utils/test_example_images_download_manager_unit.py b/tests/utils/test_example_images_download_manager_unit.py
index 51ed6c9f..fe1d6ffd 100644
--- a/tests/utils/test_example_images_download_manager_unit.py
+++ b/tests/utils/test_example_images_download_manager_unit.py
@@ -203,3 +203,150 @@ async def test_pause_or_resume_without_running_download(
 
     with pytest.raises(download_module.DownloadNotRunningError):
         await manager.stop_download(object())
+
+
+async def test_download_task_callback_executes_on_completion(
+    monkeypatch: pytest.MonkeyPatch, tmp_path
+) -> None:
+    """Test that _handle_download_task_done callback is executed when download completes."""
+    settings_manager = get_settings_manager()
+    settings_manager.settings["example_images_path"] = str(tmp_path)
+    settings_manager.settings["libraries"] = {"default": {}}
+    settings_manager.settings["active_library"] = "default"
+
+    ws_manager = RecordingWebSocketManager()
+    manager = download_module.DownloadManager(ws_manager=ws_manager)
+
+    callback_executed = asyncio.Event()
+    original_callback = manager._handle_download_task_done
+
+    def tracking_callback(task, output_dir):
+        original_callback(task, output_dir)
+        callback_executed.set()
+
+    monkeypatch.setattr(
+        manager, "_handle_download_task_done", tracking_callback
+    )
+
+    async def fake_download(self, *_args):
+        # Simulate successful completion
+        async with self._state_lock:
+            self._progress["status"] = "completed"
+            self._is_downloading = False
+            self._download_task = None
+
+    monkeypatch.setattr(
+        download_module.DownloadManager,
+        "_download_all_example_images",
+        fake_download,
+    )
+
+    result = await manager.start_download({"model_types": ["lora"], "delay": 0})
+    assert result["success"] is True
+
+    # Wait for callback to execute
+    await asyncio.wait_for(callback_executed.wait(), timeout=1)
+    assert manager._progress["status"] == "completed"
+
+
+async def test_download_task_callback_handles_errors(
+    monkeypatch: pytest.MonkeyPatch, tmp_path
+) -> None:
+    """Test that _handle_download_task_done properly handles task errors and saves progress."""
+    settings_manager = get_settings_manager()
+    settings_manager.settings["example_images_path"] = str(tmp_path)
+    settings_manager.settings["libraries"] = {"default": {}}
+    settings_manager.settings["active_library"] = "default"
+
+    ws_manager = RecordingWebSocketManager()
+    manager = download_module.DownloadManager(ws_manager=ws_manager)
+
+    callback_executed = asyncio.Event()
+    progress_saved = False
+    original_save_progress = manager._save_progress
+
+    def tracking_save_progress(output_dir):
+        nonlocal progress_saved
+        progress_saved = True
+        return original_save_progress(output_dir)
+
+    monkeypatch.setattr(manager, "_save_progress", tracking_save_progress)
+
+    original_callback = manager._handle_download_task_done
+
+    def tracking_callback(task, output_dir):
+        original_callback(task, output_dir)
+        callback_executed.set()
+
+    monkeypatch.setattr(
+        manager, "_handle_download_task_done", tracking_callback
+    )
+
+    async def fake_download_with_error(self, *_args):
+        raise RuntimeError("Simulated download error")
+
+    monkeypatch.setattr(
+        download_module.DownloadManager,
+        "_download_all_example_images",
+        fake_download_with_error,
+    )
+
+    result = await manager.start_download({"model_types": ["lora"], "delay": 0})
+    assert result["success"] is True
+
+    # Wait for callback to execute (it should handle the error)
+    await asyncio.wait_for(callback_executed.wait(), timeout=1)
+    # Progress should be saved even on error
+    assert progress_saved is True
+
+
+async def test_get_status_returns_correct_state(
+    monkeypatch: pytest.MonkeyPatch, tmp_path
+) -> None:
+    """Test that get_status returns the correct download state."""
+    settings_manager = get_settings_manager()
+    settings_manager.settings["example_images_path"] = str(tmp_path)
+    settings_manager.settings["libraries"] = {"default": {}}
+    settings_manager.settings["active_library"] = "default"
+
+    ws_manager = RecordingWebSocketManager()
+    manager = download_module.DownloadManager(ws_manager=ws_manager)
+
+    # Test idle state
+    status = await manager.get_status(object())
+    assert status["success"] is True
+    assert status["is_downloading"] is False
+    assert status["status"]["status"] == "idle"
+
+    started = asyncio.Event()
+    release = asyncio.Event()
+
+    async def fake_download(self, *_args):
+        started.set()
+        await release.wait()
+        async with self._state_lock:
+            self._is_downloading = False
+            self._download_task = None
+
+    monkeypatch.setattr(
+        download_module.DownloadManager,
+        "_download_all_example_images",
+        fake_download,
+    )
+
+    # Start download
+    result = await manager.start_download({"model_types": ["lora"], "delay": 0})
+    assert result["success"] is True
+
+    await asyncio.wait_for(started.wait(), timeout=1)
+
+    # Test running state
+    status = await manager.get_status(object())
+    assert status["success"] is True
+    assert status["is_downloading"] is True
+    assert status["status"]["status"] == "running"
+
+    # Cleanup
+    release.set()
+    if manager._download_task:
+        await asyncio.wait_for(manager._download_task, timeout=1)