From 5b564cd8a302b2b93182ac5c9b0f583ae6746483 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Sun, 5 Oct 2025 09:10:25 +0800 Subject: [PATCH] fix(example-images): pin downloads to start library --- py/utils/example_images_download_manager.py | 78 +++++++++++++++---- ...t_example_images_download_manager_async.py | 78 +++++++++++++++++++ 2 files changed, 140 insertions(+), 16 deletions(-) diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index 590e050a..ed6ca69a 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -106,11 +106,10 @@ class DownloadManager: self._ws_manager = ws_manager self._state_lock = state_lock or asyncio.Lock() - def _resolve_output_dir(self) -> str: + def _resolve_output_dir(self, library_name: str | None = None) -> str: base_path = settings.get('example_images_path') if not base_path: return '' - library_name = settings.get_active_library_name() return ensure_library_root_exists(library_name) async def start_download(self, options: dict): @@ -139,7 +138,8 @@ class DownloadManager: } raise DownloadConfigurationError(error_msg) - output_dir = self._resolve_output_dir() + active_library = settings.get_active_library_name() + output_dir = self._resolve_output_dir(active_library) if not output_dir: raise DownloadConfigurationError('Example images path not configured in settings') @@ -196,7 +196,8 @@ class DownloadManager: output_dir, optimize, model_types, - delay + delay, + active_library, ) ) @@ -261,7 +262,14 @@ class DownloadManager: 'message': 'Download resumed' } - async def _download_all_example_images(self, output_dir, optimize, model_types, delay): + async def _download_all_example_images( + self, + output_dir, + optimize, + model_types, + delay, + library_name, + ): """Download example images for all models.""" downloader = await get_downloader() @@ -299,8 +307,13 @@ class DownloadManager: for i, (scanner_type, model, scanner) in enumerate(all_models): # Main logic for processing model is here, but actual operations are delegated to other classes was_remote_download = await self._process_model( - scanner_type, model, scanner, - output_dir, optimize, downloader + scanner_type, + model, + scanner, + output_dir, + optimize, + downloader, + library_name, ) # Update progress @@ -342,7 +355,16 @@ class DownloadManager: self._is_downloading = False self._download_task = None - async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): + async def _process_model( + self, + scanner_type, + model, + scanner, + output_dir, + optimize, + downloader, + library_name, + ): """Process a single model download.""" # Check if download is paused @@ -369,7 +391,7 @@ class DownloadManager: logger.debug(f"Skipping known failed model: {model_name}") return False - model_dir = ExampleImagePathResolver.get_model_folder(model_hash) + model_dir = ExampleImagePathResolver.get_model_folder(model_hash, library_name) existing_files = _model_directory_has_files(model_dir) # Skip if already processed AND directory exists with files @@ -532,7 +554,8 @@ class DownloadManager: if not base_path: raise DownloadConfigurationError('Example images path not configured in settings') - output_dir = self._resolve_output_dir() + active_library = settings.get_active_library_name() + output_dir = self._resolve_output_dir(active_library) if not output_dir: raise DownloadConfigurationError('Example images path not configured in settings') @@ -552,7 +575,8 @@ class DownloadManager: output_dir, optimize, model_types, - delay + delay, + active_library, ) async with self._state_lock: @@ -571,7 +595,15 @@ class DownloadManager: await self._broadcast_progress(status='error', extra={'error': str(e)}) raise ExampleImagesDownloadError(str(e)) from e - async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay): + async def _download_specific_models_example_images_sync( + self, + model_hashes, + output_dir, + optimize, + model_types, + delay, + library_name, + ): """Download example images for specific models only - synchronous version.""" downloader = await get_downloader() @@ -612,8 +644,13 @@ class DownloadManager: for i, (scanner_type, model, scanner) in enumerate(models_to_process): # Force process this model regardless of previous status was_successful = await self._process_specific_model( - scanner_type, model, scanner, - output_dir, optimize, downloader + scanner_type, + model, + scanner, + output_dir, + optimize, + downloader, + library_name, ) if was_successful: @@ -665,7 +702,16 @@ class DownloadManager: # No need to close any sessions since we use the global downloader pass - async def _process_specific_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): + async def _process_specific_model( + self, + scanner_type, + model, + scanner, + output_dir, + optimize, + downloader, + library_name, + ): """Process a specific model for forced download, ignoring previous download status.""" # Check if download is paused @@ -687,7 +733,7 @@ class DownloadManager: self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" await self._broadcast_progress(status='running') - model_dir = ExampleImagePathResolver.get_model_folder(model_hash) + model_dir = ExampleImagePathResolver.get_model_folder(model_hash, library_name) if not model_dir: logger.warning( "Unable to resolve example images folder for model %s (%s)", diff --git a/tests/services/test_example_images_download_manager_async.py b/tests/services/test_example_images_download_manager_async.py index a0833800..929862ae 100644 --- a/tests/services/test_example_images_download_manager_async.py +++ b/tests/services/test_example_images_download_manager_async.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import json +from pathlib import Path from types import SimpleNamespace import pytest @@ -376,3 +377,80 @@ async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tm assert new_progress.exists() contents = json.loads(new_progress.read_text(encoding="utf-8")) assert model_hash in contents.get("processed_models", []) + + +@pytest.mark.usefixtures("tmp_path") +async def test_download_remains_in_initial_library(monkeypatch: pytest.MonkeyPatch, tmp_path): + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) + monkeypatch.setitem(settings.settings, "libraries", {"LibraryA": {}, "LibraryB": {}}) + monkeypatch.setitem(settings.settings, "active_library", "LibraryA") + + state = {"active": "LibraryA"} + + def fake_get_active_library_name(self): + return state["active"] + + monkeypatch.setattr(type(settings), "get_active_library_name", fake_get_active_library_name) + + model_hash = "f" * 64 + model_path = tmp_path / "example-model.safetensors" + model_path.write_text("data", encoding="utf-8") + + model = { + "sha256": model_hash, + "model_name": "Library Switch Model", + "file_path": str(model_path), + "file_name": "example-model.safetensors", + } + + _patch_scanner(monkeypatch, StubScanner([model])) + + async def fake_process_local_examples( + _file_path, + _file_name, + _model_name, + model_dir, + _optimize, + ): + Path(model_dir).mkdir(parents=True, exist_ok=True) + (Path(model_dir) / "local.txt").write_text("data", encoding="utf-8") + state["active"] = "LibraryB" + return True + + async def fake_update_metadata(*_args, **_kwargs): + return True + + async def fake_get_downloader(): + return object() + + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "process_local_examples", + staticmethod(fake_process_local_examples), + ) + monkeypatch.setattr( + download_module.MetadataUpdater, + "update_metadata_from_local_examples", + staticmethod(fake_update_metadata), + ) + monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader) + + result = await manager.start_download({"model_types": ["lora"], "delay": 0}) + assert result["success"] is True + + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + + library_a_root = tmp_path / "LibraryA" + library_b_root = tmp_path / "LibraryB" + + progress_file = library_a_root / ".download_progress.json" + model_dir = library_a_root / model_hash + + assert progress_file.exists() + assert (model_dir / "local.txt").exists() + assert not (library_b_root / ".download_progress.json").exists() + assert not (library_b_root / model_hash).exists()