From 67c82ba6ea31b2c1631e6a17cf15dcd06d47a7d4 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Sun, 5 Oct 2025 08:37:11 +0800 Subject: [PATCH] fix(example-images): reuse migrated folders during downloads --- py/utils/example_images_download_manager.py | 97 +++++++++-- py/utils/example_images_paths.py | 28 ++++ ...t_example_images_download_manager_async.py | 150 ++++++++++++++++++ 3 files changed, 259 insertions(+), 16 deletions(-) diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index 6afe82d5..590e050a 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -1,14 +1,19 @@ from __future__ import annotations -import logging -import os import asyncio import json import time +import logging +import os +import shutil from typing import Any, Dict from ..services.service_registry import ServiceRegistry -from ..utils.example_images_paths import ensure_library_root_exists +from ..utils.example_images_paths import ( + ExampleImagePathResolver, + ensure_library_root_exists, + uses_library_scoped_folders, +) from ..utils.metadata_manager import MetadataManager from .example_images_processor import ExampleImagesProcessor from .example_images_metadata import MetadataUpdater @@ -75,6 +80,22 @@ class _DownloadProgress(dict): snapshot['failed_models'] = list(self['failed_models']) return snapshot + +def _model_directory_has_files(path: str) -> bool: + """Return True when the provided directory exists and contains entries.""" + + if not path or not os.path.isdir(path): + return False + + try: + with os.scandir(path) as entries: + for _ in entries: + return True + except OSError: + return False + + return False + class DownloadManager: """Manages downloading example images for models.""" @@ -128,9 +149,31 @@ class DownloadManager: self._progress['end_time'] = None progress_file = os.path.join(output_dir, '.download_progress.json') - if os.path.exists(progress_file): + progress_source = progress_file + if uses_library_scoped_folders(): + legacy_root = settings.get('example_images_path') or '' + legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else '' + if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file): + try: + os.makedirs(output_dir, exist_ok=True) + shutil.move(legacy_progress, progress_file) + logger.info( + "Migrated legacy download progress file '%s' to '%s'", + legacy_progress, + progress_file, + ) + except OSError as exc: + logger.warning( + "Failed to migrate download progress file from '%s' to '%s': %s", + legacy_progress, + progress_file, + exc, + ) + progress_source = legacy_progress + + if os.path.exists(progress_source): try: - with open(progress_file, 'r', encoding='utf-8') as f: + with open(progress_source, 'r', encoding='utf-8') as f: saved_progress = json.load(f) self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) @@ -326,20 +369,35 @@ class DownloadManager: logger.debug(f"Skipping known failed model: {model_name}") return False + model_dir = ExampleImagePathResolver.get_model_folder(model_hash) + existing_files = _model_directory_has_files(model_dir) + # Skip if already processed AND directory exists with files if model_hash in self._progress['processed_models']: - model_dir = os.path.join(output_dir, model_hash) - has_files = os.path.exists(model_dir) and any(os.listdir(model_dir)) - if has_files: + if existing_files: logger.debug(f"Skipping already processed model: {model_name}") return False - else: - logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing") - # Remove from processed models since we need to reprocess - self._progress['processed_models'].discard(model_hash) - + logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing") + # Remove from processed models since we need to reprocess + self._progress['processed_models'].discard(model_hash) + + if existing_files and model_hash not in self._progress['processed_models']: + logger.debug( + "Model folder already populated for %s, marking as processed without download", + model_name, + ) + self._progress['processed_models'].add(model_hash) + return False + + if not model_dir: + logger.warning( + "Unable to resolve example images folder for model %s (%s)", + model_name, + model_hash, + ) + return False + # Create model directory - model_dir = os.path.join(output_dir, model_hash) os.makedirs(model_dir, exist_ok=True) # First check for local example images - local processing doesn't need delay @@ -629,8 +687,15 @@ class DownloadManager: self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" await self._broadcast_progress(status='running') - # Create model directory - model_dir = os.path.join(output_dir, model_hash) + model_dir = ExampleImagePathResolver.get_model_folder(model_hash) + if not model_dir: + logger.warning( + "Unable to resolve example images folder for model %s (%s)", + model_name, + model_hash, + ) + return False + os.makedirs(model_dir, exist_ok=True) # First check for local example images - local processing doesn't need delay diff --git a/py/utils/example_images_paths.py b/py/utils/example_images_paths.py index 768dec1d..b272f72e 100644 --- a/py/utils/example_images_paths.py +++ b/py/utils/example_images_paths.py @@ -102,6 +102,34 @@ def get_model_folder(model_hash: str, library_name: Optional[str] = None) -> str return resolved_folder +class ExampleImagePathResolver: + """Convenience wrapper exposing example image path helpers.""" + + @staticmethod + def get_model_folder(model_hash: str, library_name: Optional[str] = None) -> str: + """Return the example image folder for a model, migrating legacy paths.""" + + return get_model_folder(model_hash, library_name) + + @staticmethod + def get_library_root(library_name: Optional[str] = None) -> str: + """Return the configured library root for example images.""" + + return get_library_root(library_name) + + @staticmethod + def ensure_library_root_exists(library_name: Optional[str] = None) -> str: + """Ensure the library root exists before writing files.""" + + return ensure_library_root_exists(library_name) + + @staticmethod + def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str: + """Return the relative path to a model folder from the static mount point.""" + + return get_model_relative_path(model_hash, library_name) + + def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str: """Return the relative URL path from the static mount to a model folder.""" diff --git a/tests/services/test_example_images_download_manager_async.py b/tests/services/test_example_images_download_manager_async.py index 7eef56fb..a0833800 100644 --- a/tests/services/test_example_images_download_manager_async.py +++ b/tests/services/test_example_images_download_manager_async.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json from types import SimpleNamespace import pytest @@ -226,3 +227,152 @@ async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, t if manager._download_task is not None: await asyncio.wait_for(manager._download_task, timeout=1) monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep) + + +@pytest.mark.usefixtures("tmp_path") +async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatch, tmp_path): + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) + monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}}) + monkeypatch.setitem(settings.settings, "active_library", "extra") + + model_hash = "d" * 64 + model_path = tmp_path / "model.safetensors" + model_path.write_text("data", encoding="utf-8") + + model = { + "sha256": model_hash, + "model_name": "Migrated Model", + "file_path": str(model_path), + "file_name": "model.safetensors", + "civitai": {"images": [{"url": "https://example.com/image.png"}]}, + } + + _patch_scanner(monkeypatch, StubScanner([model])) + + legacy_folder = tmp_path / model_hash + legacy_folder.mkdir() + (legacy_folder / "image_0.png").write_text("data", encoding="utf-8") + + process_called = False + download_called = False + + async def fake_process_local_examples(*_args, **_kwargs): + nonlocal process_called + process_called = True + return False + + async def fake_download_model_images(*_args, **_kwargs): + nonlocal download_called + download_called = True + return True, False + + async def fake_get_downloader(): + class _Downloader: + async def download_to_memory(self, *_a, **_kw): + return True, b"", {} + + return _Downloader() + + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "process_local_examples", + staticmethod(fake_process_local_examples), + ) + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "download_model_images", + staticmethod(fake_download_model_images), + ) + monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader) + + try: + result = await manager.start_download({"model_types": ["lora"], "delay": 0}) + assert result["success"] is True + + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + finally: + if manager._download_task is not None and not manager._download_task.done(): + await asyncio.wait_for(manager._download_task, timeout=1) + + library_root = tmp_path / "extra" + migrated_folder = library_root / model_hash + + assert migrated_folder.exists() + assert not legacy_folder.exists() + assert not process_called + assert not download_called + assert model_hash in manager._progress["processed_models"] + + +@pytest.mark.usefixtures("tmp_path") +async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tmp_path): + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) + monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}}) + monkeypatch.setitem(settings.settings, "active_library", "extra") + + model_hash = "e" * 64 + model_path = tmp_path / "model-two.safetensors" + model_path.write_text("data", encoding="utf-8") + + legacy_progress = tmp_path / ".download_progress.json" + legacy_progress.write_text(json.dumps({"processed_models": [model_hash], "failed_models": []}), encoding="utf-8") + + legacy_folder = tmp_path / model_hash + legacy_folder.mkdir() + (legacy_folder / "existing.png").write_text("data", encoding="utf-8") + + model = { + "sha256": model_hash, + "model_name": "Legacy Progress Model", + "file_path": str(model_path), + "file_name": "model-two.safetensors", + "civitai": {"images": [{"url": "https://example.com/image.png"}]}, + } + + _patch_scanner(monkeypatch, StubScanner([model])) + + async def fake_process_local_examples(*_args, **_kwargs): + return False + + async def fake_download_model_images(*_args, **_kwargs): + raise AssertionError("Remote download should not be attempted when progress is migrated") + + async def fake_get_downloader(): + class _Downloader: + async def download_to_memory(self, *_a, **_kw): + return True, b"", {} + + return _Downloader() + + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "process_local_examples", + staticmethod(fake_process_local_examples), + ) + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "download_model_images", + staticmethod(fake_download_model_images), + ) + monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader) + + result = await manager.start_download({"model_types": ["lora"], "delay": 0}) + assert result["success"] is True + + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + + new_progress = (tmp_path / "extra") / ".download_progress.json" + + assert model_hash in manager._progress["processed_models"] + assert not legacy_progress.exists() + assert new_progress.exists() + contents = json.loads(new_progress.read_text(encoding="utf-8")) + assert model_hash in contents.get("processed_models", [])