fix(example-images): reuse migrated folders during downloads

This commit is contained in:
pixelpaws
2025-10-05 08:37:11 +08:00
parent 98425f37b8
commit 67c82ba6ea
3 changed files with 259 additions and 16 deletions

View File

@@ -1,14 +1,19 @@
from __future__ import annotations from __future__ import annotations
import logging
import os
import asyncio import asyncio
import json import json
import time import time
import logging
import os
import shutil
from typing import Any, Dict from typing import Any, Dict
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..utils.example_images_paths import ensure_library_root_exists from ..utils.example_images_paths import (
ExampleImagePathResolver,
ensure_library_root_exists,
uses_library_scoped_folders,
)
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater from .example_images_metadata import MetadataUpdater
@@ -75,6 +80,22 @@ class _DownloadProgress(dict):
snapshot['failed_models'] = list(self['failed_models']) snapshot['failed_models'] = list(self['failed_models'])
return snapshot return snapshot
def _model_directory_has_files(path: str) -> bool:
"""Return True when the provided directory exists and contains entries."""
if not path or not os.path.isdir(path):
return False
try:
with os.scandir(path) as entries:
for _ in entries:
return True
except OSError:
return False
return False
class DownloadManager: class DownloadManager:
"""Manages downloading example images for models.""" """Manages downloading example images for models."""
@@ -128,9 +149,31 @@ class DownloadManager:
self._progress['end_time'] = None self._progress['end_time'] = None
progress_file = os.path.join(output_dir, '.download_progress.json') progress_file = os.path.join(output_dir, '.download_progress.json')
if os.path.exists(progress_file): progress_source = progress_file
if uses_library_scoped_folders():
legacy_root = settings.get('example_images_path') or ''
legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else ''
if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file):
try: try:
with open(progress_file, 'r', encoding='utf-8') as f: os.makedirs(output_dir, exist_ok=True)
shutil.move(legacy_progress, progress_file)
logger.info(
"Migrated legacy download progress file '%s' to '%s'",
legacy_progress,
progress_file,
)
except OSError as exc:
logger.warning(
"Failed to migrate download progress file from '%s' to '%s': %s",
legacy_progress,
progress_file,
exc,
)
progress_source = legacy_progress
if os.path.exists(progress_source):
try:
with open(progress_source, 'r', encoding='utf-8') as f:
saved_progress = json.load(f) saved_progress = json.load(f)
self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
@@ -326,20 +369,35 @@ class DownloadManager:
logger.debug(f"Skipping known failed model: {model_name}") logger.debug(f"Skipping known failed model: {model_name}")
return False return False
model_dir = ExampleImagePathResolver.get_model_folder(model_hash)
existing_files = _model_directory_has_files(model_dir)
# Skip if already processed AND directory exists with files # Skip if already processed AND directory exists with files
if model_hash in self._progress['processed_models']: if model_hash in self._progress['processed_models']:
model_dir = os.path.join(output_dir, model_hash) if existing_files:
has_files = os.path.exists(model_dir) and any(os.listdir(model_dir))
if has_files:
logger.debug(f"Skipping already processed model: {model_name}") logger.debug(f"Skipping already processed model: {model_name}")
return False return False
else:
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing") logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
# Remove from processed models since we need to reprocess # Remove from processed models since we need to reprocess
self._progress['processed_models'].discard(model_hash) self._progress['processed_models'].discard(model_hash)
if existing_files and model_hash not in self._progress['processed_models']:
logger.debug(
"Model folder already populated for %s, marking as processed without download",
model_name,
)
self._progress['processed_models'].add(model_hash)
return False
if not model_dir:
logger.warning(
"Unable to resolve example images folder for model %s (%s)",
model_name,
model_hash,
)
return False
# Create model directory # Create model directory
model_dir = os.path.join(output_dir, model_hash)
os.makedirs(model_dir, exist_ok=True) os.makedirs(model_dir, exist_ok=True)
# First check for local example images - local processing doesn't need delay # First check for local example images - local processing doesn't need delay
@@ -629,8 +687,15 @@ class DownloadManager:
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running') await self._broadcast_progress(status='running')
# Create model directory model_dir = ExampleImagePathResolver.get_model_folder(model_hash)
model_dir = os.path.join(output_dir, model_hash) if not model_dir:
logger.warning(
"Unable to resolve example images folder for model %s (%s)",
model_name,
model_hash,
)
return False
os.makedirs(model_dir, exist_ok=True) os.makedirs(model_dir, exist_ok=True)
# First check for local example images - local processing doesn't need delay # First check for local example images - local processing doesn't need delay

View File

@@ -102,6 +102,34 @@ def get_model_folder(model_hash: str, library_name: Optional[str] = None) -> str
return resolved_folder return resolved_folder
class ExampleImagePathResolver:
    """Namespace-style facade over the module-level example-image path helpers.

    Every static method delegates directly to the same-named module function,
    so callers can import this single symbol instead of four separate helpers.
    Adds no behavior of its own.
    """

    @staticmethod
    def get_model_folder(model_hash: str, library_name: Optional[str] = None) -> str:
        """Return the example image folder for a model, migrating legacy paths."""
        return get_model_folder(model_hash, library_name)

    @staticmethod
    def get_library_root(library_name: Optional[str] = None) -> str:
        """Return the configured library root for example images."""
        return get_library_root(library_name)

    @staticmethod
    def ensure_library_root_exists(library_name: Optional[str] = None) -> str:
        """Ensure the library root exists before writing files."""
        return ensure_library_root_exists(library_name)

    @staticmethod
    def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str:
        """Return the relative path to a model folder from the static mount point."""
        return get_model_relative_path(model_hash, library_name)
def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str: def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str:
"""Return the relative URL path from the static mount to a model folder.""" """Return the relative URL path from the static mount to a model folder."""

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json
from types import SimpleNamespace from types import SimpleNamespace
import pytest import pytest
@@ -226,3 +227,152 @@ async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, t
if manager._download_task is not None: if manager._download_task is not None:
await asyncio.wait_for(manager._download_task, timeout=1) await asyncio.wait_for(manager._download_task, timeout=1)
monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep) monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep)
@pytest.mark.usefixtures("tmp_path")  # NOTE(review): redundant — tmp_path is already a parameter
async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatch, tmp_path):
    """A populated legacy model folder is reused instead of re-downloaded.

    Verifies that a download run moves the pre-library folder into the active
    library root, marks the model hash as processed, and never invokes local
    example processing or the remote image download path.
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)
    # Point the legacy example-images path at tmp_path and activate a
    # non-default library so library-scoped folders are in effect.
    monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
    monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
    monkeypatch.setitem(settings.settings, "active_library", "extra")

    model_hash = "d" * 64
    model_path = tmp_path / "model.safetensors"
    model_path.write_text("data", encoding="utf-8")
    model = {
        "sha256": model_hash,
        "model_name": "Migrated Model",
        "file_path": str(model_path),
        "file_name": "model.safetensors",
        "civitai": {"images": [{"url": "https://example.com/image.png"}]},
    }
    _patch_scanner(monkeypatch, StubScanner([model]))

    # Seed a legacy (pre-library) folder that already contains an image.
    legacy_folder = tmp_path / model_hash
    legacy_folder.mkdir()
    (legacy_folder / "image_0.png").write_text("data", encoding="utf-8")

    # Flags recording whether either processing path was invoked; the test
    # asserts both stay False.
    process_called = False
    download_called = False

    async def fake_process_local_examples(*_args, **_kwargs):
        nonlocal process_called
        process_called = True
        return False

    async def fake_download_model_images(*_args, **_kwargs):
        nonlocal download_called
        download_called = True
        return True, False

    async def fake_get_downloader():
        # Minimal stub downloader; irrelevant here since no download should occur.
        class _Downloader:
            async def download_to_memory(self, *_a, **_kw):
                return True, b"", {}

        return _Downloader()

    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "process_local_examples",
        staticmethod(fake_process_local_examples),
    )
    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "download_model_images",
        staticmethod(fake_download_model_images),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    try:
        result = await manager.start_download({"model_types": ["lora"], "delay": 0})
        assert result["success"] is True
        if manager._download_task is not None:
            await asyncio.wait_for(manager._download_task, timeout=1)
    finally:
        # Ensure the background task is always awaited, even on assertion failure.
        if manager._download_task is not None and not manager._download_task.done():
            await asyncio.wait_for(manager._download_task, timeout=1)

    # The folder must now live under the active library root ("extra"),
    # with the legacy location gone and the hash recorded as processed.
    library_root = tmp_path / "extra"
    migrated_folder = library_root / model_hash
    assert migrated_folder.exists()
    assert not legacy_folder.exists()
    assert not process_called
    assert not download_called
    assert model_hash in manager._progress["processed_models"]
@pytest.mark.usefixtures("tmp_path")  # NOTE(review): redundant — tmp_path is already a parameter
async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tmp_path):
    """The legacy `.download_progress.json` is moved into the active library root.

    A model already listed as processed in the legacy progress file (and whose
    legacy folder is populated) must be honoured after migration: no remote
    download is attempted, and the hash survives into the migrated file.
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)
    # Legacy path at tmp_path; active library "extra" enables library-scoped folders.
    monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
    monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
    monkeypatch.setitem(settings.settings, "active_library", "extra")

    model_hash = "e" * 64
    model_path = tmp_path / "model-two.safetensors"
    model_path.write_text("data", encoding="utf-8")
    # Legacy progress file already records the model as processed.
    legacy_progress = tmp_path / ".download_progress.json"
    legacy_progress.write_text(json.dumps({"processed_models": [model_hash], "failed_models": []}), encoding="utf-8")
    # Populated legacy folder backing that progress entry.
    legacy_folder = tmp_path / model_hash
    legacy_folder.mkdir()
    (legacy_folder / "existing.png").write_text("data", encoding="utf-8")
    model = {
        "sha256": model_hash,
        "model_name": "Legacy Progress Model",
        "file_path": str(model_path),
        "file_name": "model-two.safetensors",
        "civitai": {"images": [{"url": "https://example.com/image.png"}]},
    }
    _patch_scanner(monkeypatch, StubScanner([model]))

    async def fake_process_local_examples(*_args, **_kwargs):
        return False

    async def fake_download_model_images(*_args, **_kwargs):
        # Any remote download attempt is a test failure by construction.
        raise AssertionError("Remote download should not be attempted when progress is migrated")

    async def fake_get_downloader():
        class _Downloader:
            async def download_to_memory(self, *_a, **_kw):
                return True, b"", {}

        return _Downloader()

    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "process_local_examples",
        staticmethod(fake_process_local_examples),
    )
    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "download_model_images",
        staticmethod(fake_download_model_images),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    result = await manager.start_download({"model_types": ["lora"], "delay": 0})
    assert result["success"] is True
    if manager._download_task is not None:
        await asyncio.wait_for(manager._download_task, timeout=1)

    # The progress file must have moved to the library root, still listing the hash.
    new_progress = (tmp_path / "extra") / ".download_progress.json"
    assert model_hash in manager._progress["processed_models"]
    assert not legacy_progress.exists()
    assert new_progress.exists()
    contents = json.loads(new_progress.read_text(encoding="utf-8"))
    assert model_hash in contents.get("processed_models", [])