Merge pull request #519 from willmiao/codex/update-example-images-download-flow

fix: reuse migrated example image folders before download
This commit is contained in:
pixelpaws
2025-10-05 09:04:06 +08:00
committed by GitHub
3 changed files with 259 additions and 16 deletions

View File

@@ -1,14 +1,19 @@
from __future__ import annotations
import logging
import os
import asyncio
import json
import time
import logging
import os
import shutil
from typing import Any, Dict
from ..services.service_registry import ServiceRegistry
from ..utils.example_images_paths import ensure_library_root_exists
from ..utils.example_images_paths import (
ExampleImagePathResolver,
ensure_library_root_exists,
uses_library_scoped_folders,
)
from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater
@@ -75,6 +80,22 @@ class _DownloadProgress(dict):
snapshot['failed_models'] = list(self['failed_models'])
return snapshot
def _model_directory_has_files(path: str) -> bool:
    """Return True when ``path`` is a readable directory with at least one entry.

    Args:
        path: Directory to inspect. Empty or ``None`` values are treated as
            "no directory".

    Returns:
        ``True`` if the directory could be scanned and holds at least one
        entry (file or sub-directory). ``False`` when the path is empty,
        missing, not a directory, unreadable, or contains nothing.
    """
    if not path:
        return False

    try:
        # Scan directly instead of checking os.path.isdir() first: a missing
        # path or a regular file surfaces as an OSError subclass here, so the
        # pre-check only added a stat call and a window in which the directory
        # could disappear between the check and the scan.
        with os.scandir(path) as entries:
            # One entry is enough to answer the question; stop immediately.
            return next(entries, None) is not None
    except (OSError, ValueError):
        # ValueError covers paths rejected before reaching the filesystem
        # (for example an embedded NUL character); os.path.isdir() reports
        # those as "not a directory" as well.
        return False
class DownloadManager:
"""Manages downloading example images for models."""
@@ -128,9 +149,31 @@ class DownloadManager:
self._progress['end_time'] = None
progress_file = os.path.join(output_dir, '.download_progress.json')
if os.path.exists(progress_file):
progress_source = progress_file
if uses_library_scoped_folders():
legacy_root = settings.get('example_images_path') or ''
legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else ''
if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file):
try:
os.makedirs(output_dir, exist_ok=True)
shutil.move(legacy_progress, progress_file)
logger.info(
"Migrated legacy download progress file '%s' to '%s'",
legacy_progress,
progress_file,
)
except OSError as exc:
logger.warning(
"Failed to migrate download progress file from '%s' to '%s': %s",
legacy_progress,
progress_file,
exc,
)
progress_source = legacy_progress
if os.path.exists(progress_source):
try:
with open(progress_file, 'r', encoding='utf-8') as f:
with open(progress_source, 'r', encoding='utf-8') as f:
saved_progress = json.load(f)
self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
@@ -326,20 +369,35 @@ class DownloadManager:
logger.debug(f"Skipping known failed model: {model_name}")
return False
model_dir = ExampleImagePathResolver.get_model_folder(model_hash)
existing_files = _model_directory_has_files(model_dir)
# Skip if already processed AND directory exists with files
if model_hash in self._progress['processed_models']:
model_dir = os.path.join(output_dir, model_hash)
has_files = os.path.exists(model_dir) and any(os.listdir(model_dir))
if has_files:
if existing_files:
logger.debug(f"Skipping already processed model: {model_name}")
return False
else:
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
# Remove from processed models since we need to reprocess
self._progress['processed_models'].discard(model_hash)
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
# Remove from processed models since we need to reprocess
self._progress['processed_models'].discard(model_hash)
if existing_files and model_hash not in self._progress['processed_models']:
logger.debug(
"Model folder already populated for %s, marking as processed without download",
model_name,
)
self._progress['processed_models'].add(model_hash)
return False
if not model_dir:
logger.warning(
"Unable to resolve example images folder for model %s (%s)",
model_name,
model_hash,
)
return False
# Create model directory
model_dir = os.path.join(output_dir, model_hash)
os.makedirs(model_dir, exist_ok=True)
# First check for local example images - local processing doesn't need delay
@@ -629,8 +687,15 @@ class DownloadManager:
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running')
# Create model directory
model_dir = os.path.join(output_dir, model_hash)
model_dir = ExampleImagePathResolver.get_model_folder(model_hash)
if not model_dir:
logger.warning(
"Unable to resolve example images folder for model %s (%s)",
model_name,
model_hash,
)
return False
os.makedirs(model_dir, exist_ok=True)
# First check for local example images - local processing doesn't need delay

View File

@@ -102,6 +102,34 @@ def get_model_folder(model_hash: str, library_name: Optional[str] = None) -> str
return resolved_folder
class ExampleImagePathResolver:
    """Facade grouping this module's example image path helpers.

    Every static method forwards its arguments unchanged to the module-level
    function of the same name and returns that function's result.
    """

    @staticmethod
    def get_model_folder(model_hash: str, library_name: Optional[str] = None) -> str:
        """Resolve a model's example image folder, migrating legacy paths."""
        model_folder = get_model_folder(model_hash, library_name)
        return model_folder

    @staticmethod
    def get_library_root(library_name: Optional[str] = None) -> str:
        """Look up the configured example image root for a library."""
        library_root = get_library_root(library_name)
        return library_root

    @staticmethod
    def ensure_library_root_exists(library_name: Optional[str] = None) -> str:
        """Make sure the library root is present before files are written."""
        ensured_root = ensure_library_root_exists(library_name)
        return ensured_root

    @staticmethod
    def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str:
        """Give the model folder path relative to the static mount point."""
        relative_path = get_model_relative_path(model_hash, library_name)
        return relative_path
def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str:
"""Return the relative URL path from the static mount to a model folder."""

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import json
from types import SimpleNamespace
import pytest
@@ -226,3 +227,152 @@ async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, t
if manager._download_task is not None:
await asyncio.wait_for(manager._download_task, timeout=1)
monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep)
@pytest.mark.usefixtures("tmp_path")
async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatch, tmp_path):
    """Check that a populated legacy model folder is moved and not re-downloaded.

    ``example_images_path`` points at ``tmp_path`` and ``extra`` is the active
    library. A pre-populated ``tmp_path/<hash>`` folder is expected to end up
    at ``tmp_path/extra/<hash>``; neither the local-example processor nor the
    remote image download may run, and the hash must be recorded in
    ``processed_models``.
    """
    manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())

    setting_overrides = (
        ("example_images_path", str(tmp_path)),
        ("libraries", {"default": {}, "extra": {}}),
        ("active_library", "extra"),
    )
    for key, value in setting_overrides:
        monkeypatch.setitem(settings.settings, key, value)

    model_hash = "d" * 64
    weights_file = tmp_path / "model.safetensors"
    weights_file.write_text("data", encoding="utf-8")

    scanned_model = {
        "sha256": model_hash,
        "model_name": "Migrated Model",
        "file_path": str(weights_file),
        "file_name": "model.safetensors",
        "civitai": {"images": [{"url": "https://example.com/image.png"}]},
    }
    _patch_scanner(monkeypatch, StubScanner([scanned_model]))

    # Legacy layout: the model folder sits directly under the images root.
    legacy_folder = tmp_path / model_hash
    legacy_folder.mkdir()
    (legacy_folder / "image_0.png").write_text("data", encoding="utf-8")

    # Records whether either processing path was reached during the run.
    invoked = {"process": False, "download": False}

    async def fake_process_local_examples(*_args, **_kwargs):
        invoked["process"] = True
        return False

    async def fake_download_model_images(*_args, **_kwargs):
        invoked["download"] = True
        return True, False

    async def fake_get_downloader():
        class _Downloader:
            async def download_to_memory(self, *_a, **_kw):
                return True, b"", {}

        return _Downloader()

    processor_stubs = (
        ("process_local_examples", fake_process_local_examples),
        ("download_model_images", fake_download_model_images),
    )
    for attribute, stub in processor_stubs:
        monkeypatch.setattr(
            download_module.ExampleImagesProcessor,
            attribute,
            staticmethod(stub),
        )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    try:
        result = await manager.start_download({"model_types": ["lora"], "delay": 0})
        assert result["success"] is True

        running_task = manager._download_task
        if running_task is not None:
            await asyncio.wait_for(running_task, timeout=1)
    finally:
        # Give a still-running background task one more chance to finish.
        leftover_task = manager._download_task
        if leftover_task is not None and not leftover_task.done():
            await asyncio.wait_for(leftover_task, timeout=1)

    migrated_folder = tmp_path / "extra" / model_hash

    assert migrated_folder.exists()
    assert not legacy_folder.exists()
    assert not invoked["process"]
    assert not invoked["download"]
    assert model_hash in manager._progress["processed_models"]
@pytest.mark.usefixtures("tmp_path")
async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tmp_path):
    """Check that a legacy ``.download_progress.json`` moves to the library root.

    A progress file listing the model as processed is written directly under
    ``tmp_path`` together with a populated legacy model folder. After the run
    the progress file is expected at ``tmp_path/extra`` only, it must still
    list the model hash, and no remote download may be attempted.
    """
    manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())

    setting_overrides = (
        ("example_images_path", str(tmp_path)),
        ("libraries", {"default": {}, "extra": {}}),
        ("active_library", "extra"),
    )
    for key, value in setting_overrides:
        monkeypatch.setitem(settings.settings, key, value)

    model_hash = "e" * 64
    weights_file = tmp_path / "model-two.safetensors"
    weights_file.write_text("data", encoding="utf-8")

    # Progress saved in the legacy location already marks the model as done.
    legacy_progress = tmp_path / ".download_progress.json"
    saved_state = {"processed_models": [model_hash], "failed_models": []}
    legacy_progress.write_text(json.dumps(saved_state), encoding="utf-8")

    legacy_folder = tmp_path / model_hash
    legacy_folder.mkdir()
    (legacy_folder / "existing.png").write_text("data", encoding="utf-8")

    scanned_model = {
        "sha256": model_hash,
        "model_name": "Legacy Progress Model",
        "file_path": str(weights_file),
        "file_name": "model-two.safetensors",
        "civitai": {"images": [{"url": "https://example.com/image.png"}]},
    }
    _patch_scanner(monkeypatch, StubScanner([scanned_model]))

    async def fake_process_local_examples(*_args, **_kwargs):
        return False

    async def fake_download_model_images(*_args, **_kwargs):
        raise AssertionError("Remote download should not be attempted when progress is migrated")

    async def fake_get_downloader():
        class _Downloader:
            async def download_to_memory(self, *_a, **_kw):
                return True, b"", {}

        return _Downloader()

    processor_stubs = (
        ("process_local_examples", fake_process_local_examples),
        ("download_model_images", fake_download_model_images),
    )
    for attribute, stub in processor_stubs:
        monkeypatch.setattr(
            download_module.ExampleImagesProcessor,
            attribute,
            staticmethod(stub),
        )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    result = await manager.start_download({"model_types": ["lora"], "delay": 0})
    assert result["success"] is True

    running_task = manager._download_task
    if running_task is not None:
        await asyncio.wait_for(running_task, timeout=1)

    new_progress = tmp_path / "extra" / ".download_progress.json"

    assert model_hash in manager._progress["processed_models"]
    assert not legacy_progress.exists()
    assert new_progress.exists()

    persisted = json.loads(new_progress.read_text(encoding="utf-8"))
    assert model_hash in persisted.get("processed_models", [])