mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #519 from willmiao/codex/update-example-images-download-flow
fix: reuse migrated example image folders before download
This commit is contained in:
@@ -1,14 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict
|
||||
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.example_images_paths import ensure_library_root_exists
|
||||
from ..utils.example_images_paths import (
|
||||
ExampleImagePathResolver,
|
||||
ensure_library_root_exists,
|
||||
uses_library_scoped_folders,
|
||||
)
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .example_images_processor import ExampleImagesProcessor
|
||||
from .example_images_metadata import MetadataUpdater
|
||||
@@ -75,6 +80,22 @@ class _DownloadProgress(dict):
|
||||
snapshot['failed_models'] = list(self['failed_models'])
|
||||
return snapshot
|
||||
|
||||
|
||||
def _model_directory_has_files(path: str) -> bool:
|
||||
"""Return True when the provided directory exists and contains entries."""
|
||||
|
||||
if not path or not os.path.isdir(path):
|
||||
return False
|
||||
|
||||
try:
|
||||
with os.scandir(path) as entries:
|
||||
for _ in entries:
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
class DownloadManager:
|
||||
"""Manages downloading example images for models."""
|
||||
|
||||
@@ -128,9 +149,31 @@ class DownloadManager:
|
||||
self._progress['end_time'] = None
|
||||
|
||||
progress_file = os.path.join(output_dir, '.download_progress.json')
|
||||
if os.path.exists(progress_file):
|
||||
progress_source = progress_file
|
||||
if uses_library_scoped_folders():
|
||||
legacy_root = settings.get('example_images_path') or ''
|
||||
legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else ''
|
||||
if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file):
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
shutil.move(legacy_progress, progress_file)
|
||||
logger.info(
|
||||
"Migrated legacy download progress file '%s' to '%s'",
|
||||
legacy_progress,
|
||||
progress_file,
|
||||
)
|
||||
except OSError as exc:
|
||||
logger.warning(
|
||||
"Failed to migrate download progress file from '%s' to '%s': %s",
|
||||
legacy_progress,
|
||||
progress_file,
|
||||
exc,
|
||||
)
|
||||
progress_source = legacy_progress
|
||||
|
||||
if os.path.exists(progress_source):
|
||||
try:
|
||||
with open(progress_file, 'r', encoding='utf-8') as f:
|
||||
with open(progress_source, 'r', encoding='utf-8') as f:
|
||||
saved_progress = json.load(f)
|
||||
self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
|
||||
self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
|
||||
@@ -326,20 +369,35 @@ class DownloadManager:
|
||||
logger.debug(f"Skipping known failed model: {model_name}")
|
||||
return False
|
||||
|
||||
model_dir = ExampleImagePathResolver.get_model_folder(model_hash)
|
||||
existing_files = _model_directory_has_files(model_dir)
|
||||
|
||||
# Skip if already processed AND directory exists with files
|
||||
if model_hash in self._progress['processed_models']:
|
||||
model_dir = os.path.join(output_dir, model_hash)
|
||||
has_files = os.path.exists(model_dir) and any(os.listdir(model_dir))
|
||||
if has_files:
|
||||
if existing_files:
|
||||
logger.debug(f"Skipping already processed model: {model_name}")
|
||||
return False
|
||||
else:
|
||||
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
|
||||
# Remove from processed models since we need to reprocess
|
||||
self._progress['processed_models'].discard(model_hash)
|
||||
|
||||
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
|
||||
# Remove from processed models since we need to reprocess
|
||||
self._progress['processed_models'].discard(model_hash)
|
||||
|
||||
if existing_files and model_hash not in self._progress['processed_models']:
|
||||
logger.debug(
|
||||
"Model folder already populated for %s, marking as processed without download",
|
||||
model_name,
|
||||
)
|
||||
self._progress['processed_models'].add(model_hash)
|
||||
return False
|
||||
|
||||
if not model_dir:
|
||||
logger.warning(
|
||||
"Unable to resolve example images folder for model %s (%s)",
|
||||
model_name,
|
||||
model_hash,
|
||||
)
|
||||
return False
|
||||
|
||||
# Create model directory
|
||||
model_dir = os.path.join(output_dir, model_hash)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# First check for local example images - local processing doesn't need delay
|
||||
@@ -629,8 +687,15 @@ class DownloadManager:
|
||||
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
||||
await self._broadcast_progress(status='running')
|
||||
|
||||
# Create model directory
|
||||
model_dir = os.path.join(output_dir, model_hash)
|
||||
model_dir = ExampleImagePathResolver.get_model_folder(model_hash)
|
||||
if not model_dir:
|
||||
logger.warning(
|
||||
"Unable to resolve example images folder for model %s (%s)",
|
||||
model_name,
|
||||
model_hash,
|
||||
)
|
||||
return False
|
||||
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# First check for local example images - local processing doesn't need delay
|
||||
|
||||
@@ -102,6 +102,34 @@ def get_model_folder(model_hash: str, library_name: Optional[str] = None) -> str
|
||||
return resolved_folder
|
||||
|
||||
|
||||
class ExampleImagePathResolver:
|
||||
"""Convenience wrapper exposing example image path helpers."""
|
||||
|
||||
@staticmethod
|
||||
def get_model_folder(model_hash: str, library_name: Optional[str] = None) -> str:
|
||||
"""Return the example image folder for a model, migrating legacy paths."""
|
||||
|
||||
return get_model_folder(model_hash, library_name)
|
||||
|
||||
@staticmethod
|
||||
def get_library_root(library_name: Optional[str] = None) -> str:
|
||||
"""Return the configured library root for example images."""
|
||||
|
||||
return get_library_root(library_name)
|
||||
|
||||
@staticmethod
|
||||
def ensure_library_root_exists(library_name: Optional[str] = None) -> str:
|
||||
"""Ensure the library root exists before writing files."""
|
||||
|
||||
return ensure_library_root_exists(library_name)
|
||||
|
||||
@staticmethod
|
||||
def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str:
|
||||
"""Return the relative path to a model folder from the static mount point."""
|
||||
|
||||
return get_model_relative_path(model_hash, library_name)
|
||||
|
||||
|
||||
def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str:
|
||||
"""Return the relative URL path from the static mount to a model folder."""
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
@@ -226,3 +227,152 @@ async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, t
|
||||
if manager._download_task is not None:
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
|
||||
monkeypatch.setitem(settings.settings, "active_library", "extra")
|
||||
|
||||
model_hash = "d" * 64
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
model_path.write_text("data", encoding="utf-8")
|
||||
|
||||
model = {
|
||||
"sha256": model_hash,
|
||||
"model_name": "Migrated Model",
|
||||
"file_path": str(model_path),
|
||||
"file_name": "model.safetensors",
|
||||
"civitai": {"images": [{"url": "https://example.com/image.png"}]},
|
||||
}
|
||||
|
||||
_patch_scanner(monkeypatch, StubScanner([model]))
|
||||
|
||||
legacy_folder = tmp_path / model_hash
|
||||
legacy_folder.mkdir()
|
||||
(legacy_folder / "image_0.png").write_text("data", encoding="utf-8")
|
||||
|
||||
process_called = False
|
||||
download_called = False
|
||||
|
||||
async def fake_process_local_examples(*_args, **_kwargs):
|
||||
nonlocal process_called
|
||||
process_called = True
|
||||
return False
|
||||
|
||||
async def fake_download_model_images(*_args, **_kwargs):
|
||||
nonlocal download_called
|
||||
download_called = True
|
||||
return True, False
|
||||
|
||||
async def fake_get_downloader():
|
||||
class _Downloader:
|
||||
async def download_to_memory(self, *_a, **_kw):
|
||||
return True, b"", {}
|
||||
|
||||
return _Downloader()
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.ExampleImagesProcessor,
|
||||
"process_local_examples",
|
||||
staticmethod(fake_process_local_examples),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_module.ExampleImagesProcessor,
|
||||
"download_model_images",
|
||||
staticmethod(fake_download_model_images),
|
||||
)
|
||||
monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
|
||||
|
||||
try:
|
||||
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||
assert result["success"] is True
|
||||
|
||||
if manager._download_task is not None:
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
finally:
|
||||
if manager._download_task is not None and not manager._download_task.done():
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
|
||||
library_root = tmp_path / "extra"
|
||||
migrated_folder = library_root / model_hash
|
||||
|
||||
assert migrated_folder.exists()
|
||||
assert not legacy_folder.exists()
|
||||
assert not process_called
|
||||
assert not download_called
|
||||
assert model_hash in manager._progress["processed_models"]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
|
||||
monkeypatch.setitem(settings.settings, "active_library", "extra")
|
||||
|
||||
model_hash = "e" * 64
|
||||
model_path = tmp_path / "model-two.safetensors"
|
||||
model_path.write_text("data", encoding="utf-8")
|
||||
|
||||
legacy_progress = tmp_path / ".download_progress.json"
|
||||
legacy_progress.write_text(json.dumps({"processed_models": [model_hash], "failed_models": []}), encoding="utf-8")
|
||||
|
||||
legacy_folder = tmp_path / model_hash
|
||||
legacy_folder.mkdir()
|
||||
(legacy_folder / "existing.png").write_text("data", encoding="utf-8")
|
||||
|
||||
model = {
|
||||
"sha256": model_hash,
|
||||
"model_name": "Legacy Progress Model",
|
||||
"file_path": str(model_path),
|
||||
"file_name": "model-two.safetensors",
|
||||
"civitai": {"images": [{"url": "https://example.com/image.png"}]},
|
||||
}
|
||||
|
||||
_patch_scanner(monkeypatch, StubScanner([model]))
|
||||
|
||||
async def fake_process_local_examples(*_args, **_kwargs):
|
||||
return False
|
||||
|
||||
async def fake_download_model_images(*_args, **_kwargs):
|
||||
raise AssertionError("Remote download should not be attempted when progress is migrated")
|
||||
|
||||
async def fake_get_downloader():
|
||||
class _Downloader:
|
||||
async def download_to_memory(self, *_a, **_kw):
|
||||
return True, b"", {}
|
||||
|
||||
return _Downloader()
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.ExampleImagesProcessor,
|
||||
"process_local_examples",
|
||||
staticmethod(fake_process_local_examples),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_module.ExampleImagesProcessor,
|
||||
"download_model_images",
|
||||
staticmethod(fake_download_model_images),
|
||||
)
|
||||
monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
|
||||
|
||||
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||
assert result["success"] is True
|
||||
|
||||
if manager._download_task is not None:
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
|
||||
new_progress = (tmp_path / "extra") / ".download_progress.json"
|
||||
|
||||
assert model_hash in manager._progress["processed_models"]
|
||||
assert not legacy_progress.exists()
|
||||
assert new_progress.exists()
|
||||
contents = json.loads(new_progress.read_text(encoding="utf-8"))
|
||||
assert model_hash in contents.get("processed_models", [])
|
||||
|
||||
Reference in New Issue
Block a user