mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Merge pull request #519 from willmiao/codex/update-example-images-download-flow
fix: reuse migrated example image folders before download
This commit is contained in:
@@ -1,14 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..utils.example_images_paths import ensure_library_root_exists
|
from ..utils.example_images_paths import (
|
||||||
|
ExampleImagePathResolver,
|
||||||
|
ensure_library_root_exists,
|
||||||
|
uses_library_scoped_folders,
|
||||||
|
)
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from .example_images_processor import ExampleImagesProcessor
|
from .example_images_processor import ExampleImagesProcessor
|
||||||
from .example_images_metadata import MetadataUpdater
|
from .example_images_metadata import MetadataUpdater
|
||||||
@@ -75,6 +80,22 @@ class _DownloadProgress(dict):
|
|||||||
snapshot['failed_models'] = list(self['failed_models'])
|
snapshot['failed_models'] = list(self['failed_models'])
|
||||||
return snapshot
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def _model_directory_has_files(path: str) -> bool:
|
||||||
|
"""Return True when the provided directory exists and contains entries."""
|
||||||
|
|
||||||
|
if not path or not os.path.isdir(path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
with os.scandir(path) as entries:
|
||||||
|
for _ in entries:
|
||||||
|
return True
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
class DownloadManager:
|
class DownloadManager:
|
||||||
"""Manages downloading example images for models."""
|
"""Manages downloading example images for models."""
|
||||||
|
|
||||||
@@ -128,9 +149,31 @@ class DownloadManager:
|
|||||||
self._progress['end_time'] = None
|
self._progress['end_time'] = None
|
||||||
|
|
||||||
progress_file = os.path.join(output_dir, '.download_progress.json')
|
progress_file = os.path.join(output_dir, '.download_progress.json')
|
||||||
if os.path.exists(progress_file):
|
progress_source = progress_file
|
||||||
|
if uses_library_scoped_folders():
|
||||||
|
legacy_root = settings.get('example_images_path') or ''
|
||||||
|
legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else ''
|
||||||
|
if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file):
|
||||||
|
try:
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
shutil.move(legacy_progress, progress_file)
|
||||||
|
logger.info(
|
||||||
|
"Migrated legacy download progress file '%s' to '%s'",
|
||||||
|
legacy_progress,
|
||||||
|
progress_file,
|
||||||
|
)
|
||||||
|
except OSError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to migrate download progress file from '%s' to '%s': %s",
|
||||||
|
legacy_progress,
|
||||||
|
progress_file,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
progress_source = legacy_progress
|
||||||
|
|
||||||
|
if os.path.exists(progress_source):
|
||||||
try:
|
try:
|
||||||
with open(progress_file, 'r', encoding='utf-8') as f:
|
with open(progress_source, 'r', encoding='utf-8') as f:
|
||||||
saved_progress = json.load(f)
|
saved_progress = json.load(f)
|
||||||
self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
|
self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
|
||||||
self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
|
self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
|
||||||
@@ -326,20 +369,35 @@ class DownloadManager:
|
|||||||
logger.debug(f"Skipping known failed model: {model_name}")
|
logger.debug(f"Skipping known failed model: {model_name}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
model_dir = ExampleImagePathResolver.get_model_folder(model_hash)
|
||||||
|
existing_files = _model_directory_has_files(model_dir)
|
||||||
|
|
||||||
# Skip if already processed AND directory exists with files
|
# Skip if already processed AND directory exists with files
|
||||||
if model_hash in self._progress['processed_models']:
|
if model_hash in self._progress['processed_models']:
|
||||||
model_dir = os.path.join(output_dir, model_hash)
|
if existing_files:
|
||||||
has_files = os.path.exists(model_dir) and any(os.listdir(model_dir))
|
|
||||||
if has_files:
|
|
||||||
logger.debug(f"Skipping already processed model: {model_name}")
|
logger.debug(f"Skipping already processed model: {model_name}")
|
||||||
return False
|
return False
|
||||||
else:
|
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
|
||||||
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
|
# Remove from processed models since we need to reprocess
|
||||||
# Remove from processed models since we need to reprocess
|
self._progress['processed_models'].discard(model_hash)
|
||||||
self._progress['processed_models'].discard(model_hash)
|
|
||||||
|
if existing_files and model_hash not in self._progress['processed_models']:
|
||||||
|
logger.debug(
|
||||||
|
"Model folder already populated for %s, marking as processed without download",
|
||||||
|
model_name,
|
||||||
|
)
|
||||||
|
self._progress['processed_models'].add(model_hash)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not model_dir:
|
||||||
|
logger.warning(
|
||||||
|
"Unable to resolve example images folder for model %s (%s)",
|
||||||
|
model_name,
|
||||||
|
model_hash,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
# Create model directory
|
# Create model directory
|
||||||
model_dir = os.path.join(output_dir, model_hash)
|
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
# First check for local example images - local processing doesn't need delay
|
# First check for local example images - local processing doesn't need delay
|
||||||
@@ -629,8 +687,15 @@ class DownloadManager:
|
|||||||
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
||||||
await self._broadcast_progress(status='running')
|
await self._broadcast_progress(status='running')
|
||||||
|
|
||||||
# Create model directory
|
model_dir = ExampleImagePathResolver.get_model_folder(model_hash)
|
||||||
model_dir = os.path.join(output_dir, model_hash)
|
if not model_dir:
|
||||||
|
logger.warning(
|
||||||
|
"Unable to resolve example images folder for model %s (%s)",
|
||||||
|
model_name,
|
||||||
|
model_hash,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
# First check for local example images - local processing doesn't need delay
|
# First check for local example images - local processing doesn't need delay
|
||||||
|
|||||||
@@ -102,6 +102,34 @@ def get_model_folder(model_hash: str, library_name: Optional[str] = None) -> str
|
|||||||
return resolved_folder
|
return resolved_folder
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleImagePathResolver:
|
||||||
|
"""Convenience wrapper exposing example image path helpers."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_model_folder(model_hash: str, library_name: Optional[str] = None) -> str:
|
||||||
|
"""Return the example image folder for a model, migrating legacy paths."""
|
||||||
|
|
||||||
|
return get_model_folder(model_hash, library_name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_library_root(library_name: Optional[str] = None) -> str:
|
||||||
|
"""Return the configured library root for example images."""
|
||||||
|
|
||||||
|
return get_library_root(library_name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ensure_library_root_exists(library_name: Optional[str] = None) -> str:
|
||||||
|
"""Ensure the library root exists before writing files."""
|
||||||
|
|
||||||
|
return ensure_library_root_exists(library_name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str:
|
||||||
|
"""Return the relative path to a model folder from the static mount point."""
|
||||||
|
|
||||||
|
return get_model_relative_path(model_hash, library_name)
|
||||||
|
|
||||||
|
|
||||||
def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str:
|
def get_model_relative_path(model_hash: str, library_name: Optional[str] = None) -> str:
|
||||||
"""Return the relative URL path from the static mount to a model folder."""
|
"""Return the relative URL path from the static mount to a model folder."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -226,3 +227,152 @@ async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, t
|
|||||||
if manager._download_task is not None:
|
if manager._download_task is not None:
|
||||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||||
monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep)
|
monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||||
|
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
|
||||||
|
monkeypatch.setitem(settings.settings, "active_library", "extra")
|
||||||
|
|
||||||
|
model_hash = "d" * 64
|
||||||
|
model_path = tmp_path / "model.safetensors"
|
||||||
|
model_path.write_text("data", encoding="utf-8")
|
||||||
|
|
||||||
|
model = {
|
||||||
|
"sha256": model_hash,
|
||||||
|
"model_name": "Migrated Model",
|
||||||
|
"file_path": str(model_path),
|
||||||
|
"file_name": "model.safetensors",
|
||||||
|
"civitai": {"images": [{"url": "https://example.com/image.png"}]},
|
||||||
|
}
|
||||||
|
|
||||||
|
_patch_scanner(monkeypatch, StubScanner([model]))
|
||||||
|
|
||||||
|
legacy_folder = tmp_path / model_hash
|
||||||
|
legacy_folder.mkdir()
|
||||||
|
(legacy_folder / "image_0.png").write_text("data", encoding="utf-8")
|
||||||
|
|
||||||
|
process_called = False
|
||||||
|
download_called = False
|
||||||
|
|
||||||
|
async def fake_process_local_examples(*_args, **_kwargs):
|
||||||
|
nonlocal process_called
|
||||||
|
process_called = True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def fake_download_model_images(*_args, **_kwargs):
|
||||||
|
nonlocal download_called
|
||||||
|
download_called = True
|
||||||
|
return True, False
|
||||||
|
|
||||||
|
async def fake_get_downloader():
|
||||||
|
class _Downloader:
|
||||||
|
async def download_to_memory(self, *_a, **_kw):
|
||||||
|
return True, b"", {}
|
||||||
|
|
||||||
|
return _Downloader()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.ExampleImagesProcessor,
|
||||||
|
"process_local_examples",
|
||||||
|
staticmethod(fake_process_local_examples),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.ExampleImagesProcessor,
|
||||||
|
"download_model_images",
|
||||||
|
staticmethod(fake_download_model_images),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
if manager._download_task is not None:
|
||||||
|
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||||
|
finally:
|
||||||
|
if manager._download_task is not None and not manager._download_task.done():
|
||||||
|
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||||
|
|
||||||
|
library_root = tmp_path / "extra"
|
||||||
|
migrated_folder = library_root / model_hash
|
||||||
|
|
||||||
|
assert migrated_folder.exists()
|
||||||
|
assert not legacy_folder.exists()
|
||||||
|
assert not process_called
|
||||||
|
assert not download_called
|
||||||
|
assert model_hash in manager._progress["processed_models"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||||
|
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
|
||||||
|
monkeypatch.setitem(settings.settings, "active_library", "extra")
|
||||||
|
|
||||||
|
model_hash = "e" * 64
|
||||||
|
model_path = tmp_path / "model-two.safetensors"
|
||||||
|
model_path.write_text("data", encoding="utf-8")
|
||||||
|
|
||||||
|
legacy_progress = tmp_path / ".download_progress.json"
|
||||||
|
legacy_progress.write_text(json.dumps({"processed_models": [model_hash], "failed_models": []}), encoding="utf-8")
|
||||||
|
|
||||||
|
legacy_folder = tmp_path / model_hash
|
||||||
|
legacy_folder.mkdir()
|
||||||
|
(legacy_folder / "existing.png").write_text("data", encoding="utf-8")
|
||||||
|
|
||||||
|
model = {
|
||||||
|
"sha256": model_hash,
|
||||||
|
"model_name": "Legacy Progress Model",
|
||||||
|
"file_path": str(model_path),
|
||||||
|
"file_name": "model-two.safetensors",
|
||||||
|
"civitai": {"images": [{"url": "https://example.com/image.png"}]},
|
||||||
|
}
|
||||||
|
|
||||||
|
_patch_scanner(monkeypatch, StubScanner([model]))
|
||||||
|
|
||||||
|
async def fake_process_local_examples(*_args, **_kwargs):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def fake_download_model_images(*_args, **_kwargs):
|
||||||
|
raise AssertionError("Remote download should not be attempted when progress is migrated")
|
||||||
|
|
||||||
|
async def fake_get_downloader():
|
||||||
|
class _Downloader:
|
||||||
|
async def download_to_memory(self, *_a, **_kw):
|
||||||
|
return True, b"", {}
|
||||||
|
|
||||||
|
return _Downloader()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.ExampleImagesProcessor,
|
||||||
|
"process_local_examples",
|
||||||
|
staticmethod(fake_process_local_examples),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.ExampleImagesProcessor,
|
||||||
|
"download_model_images",
|
||||||
|
staticmethod(fake_download_model_images),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
|
||||||
|
|
||||||
|
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
if manager._download_task is not None:
|
||||||
|
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||||
|
|
||||||
|
new_progress = (tmp_path / "extra") / ".download_progress.json"
|
||||||
|
|
||||||
|
assert model_hash in manager._progress["processed_models"]
|
||||||
|
assert not legacy_progress.exists()
|
||||||
|
assert new_progress.exists()
|
||||||
|
contents = json.loads(new_progress.read_text(encoding="utf-8"))
|
||||||
|
assert model_hash in contents.get("processed_models", [])
|
||||||
|
|||||||
Reference in New Issue
Block a user