mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #520 from willmiao/codex/adjust-example-images-download-to-use-library-name
fix: keep example image downloads in initial library
This commit is contained in:
@@ -106,11 +106,10 @@ class DownloadManager:
|
|||||||
self._ws_manager = ws_manager
|
self._ws_manager = ws_manager
|
||||||
self._state_lock = state_lock or asyncio.Lock()
|
self._state_lock = state_lock or asyncio.Lock()
|
||||||
|
|
||||||
def _resolve_output_dir(self) -> str:
|
def _resolve_output_dir(self, library_name: str | None = None) -> str:
|
||||||
base_path = settings.get('example_images_path')
|
base_path = settings.get('example_images_path')
|
||||||
if not base_path:
|
if not base_path:
|
||||||
return ''
|
return ''
|
||||||
library_name = settings.get_active_library_name()
|
|
||||||
return ensure_library_root_exists(library_name)
|
return ensure_library_root_exists(library_name)
|
||||||
|
|
||||||
async def start_download(self, options: dict):
|
async def start_download(self, options: dict):
|
||||||
@@ -139,7 +138,8 @@ class DownloadManager:
|
|||||||
}
|
}
|
||||||
raise DownloadConfigurationError(error_msg)
|
raise DownloadConfigurationError(error_msg)
|
||||||
|
|
||||||
output_dir = self._resolve_output_dir()
|
active_library = settings.get_active_library_name()
|
||||||
|
output_dir = self._resolve_output_dir(active_library)
|
||||||
if not output_dir:
|
if not output_dir:
|
||||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||||
|
|
||||||
@@ -196,7 +196,8 @@ class DownloadManager:
|
|||||||
output_dir,
|
output_dir,
|
||||||
optimize,
|
optimize,
|
||||||
model_types,
|
model_types,
|
||||||
delay
|
delay,
|
||||||
|
active_library,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -261,7 +262,14 @@ class DownloadManager:
|
|||||||
'message': 'Download resumed'
|
'message': 'Download resumed'
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _download_all_example_images(self, output_dir, optimize, model_types, delay):
|
async def _download_all_example_images(
|
||||||
|
self,
|
||||||
|
output_dir,
|
||||||
|
optimize,
|
||||||
|
model_types,
|
||||||
|
delay,
|
||||||
|
library_name,
|
||||||
|
):
|
||||||
"""Download example images for all models."""
|
"""Download example images for all models."""
|
||||||
|
|
||||||
downloader = await get_downloader()
|
downloader = await get_downloader()
|
||||||
@@ -299,8 +307,13 @@ class DownloadManager:
|
|||||||
for i, (scanner_type, model, scanner) in enumerate(all_models):
|
for i, (scanner_type, model, scanner) in enumerate(all_models):
|
||||||
# Main logic for processing model is here, but actual operations are delegated to other classes
|
# Main logic for processing model is here, but actual operations are delegated to other classes
|
||||||
was_remote_download = await self._process_model(
|
was_remote_download = await self._process_model(
|
||||||
scanner_type, model, scanner,
|
scanner_type,
|
||||||
output_dir, optimize, downloader
|
model,
|
||||||
|
scanner,
|
||||||
|
output_dir,
|
||||||
|
optimize,
|
||||||
|
downloader,
|
||||||
|
library_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update progress
|
# Update progress
|
||||||
@@ -342,7 +355,16 @@ class DownloadManager:
|
|||||||
self._is_downloading = False
|
self._is_downloading = False
|
||||||
self._download_task = None
|
self._download_task = None
|
||||||
|
|
||||||
async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
|
async def _process_model(
|
||||||
|
self,
|
||||||
|
scanner_type,
|
||||||
|
model,
|
||||||
|
scanner,
|
||||||
|
output_dir,
|
||||||
|
optimize,
|
||||||
|
downloader,
|
||||||
|
library_name,
|
||||||
|
):
|
||||||
"""Process a single model download."""
|
"""Process a single model download."""
|
||||||
|
|
||||||
# Check if download is paused
|
# Check if download is paused
|
||||||
@@ -369,7 +391,7 @@ class DownloadManager:
|
|||||||
logger.debug(f"Skipping known failed model: {model_name}")
|
logger.debug(f"Skipping known failed model: {model_name}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
model_dir = ExampleImagePathResolver.get_model_folder(model_hash)
|
model_dir = ExampleImagePathResolver.get_model_folder(model_hash, library_name)
|
||||||
existing_files = _model_directory_has_files(model_dir)
|
existing_files = _model_directory_has_files(model_dir)
|
||||||
|
|
||||||
# Skip if already processed AND directory exists with files
|
# Skip if already processed AND directory exists with files
|
||||||
@@ -532,7 +554,8 @@ class DownloadManager:
|
|||||||
|
|
||||||
if not base_path:
|
if not base_path:
|
||||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||||
output_dir = self._resolve_output_dir()
|
active_library = settings.get_active_library_name()
|
||||||
|
output_dir = self._resolve_output_dir(active_library)
|
||||||
if not output_dir:
|
if not output_dir:
|
||||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||||
|
|
||||||
@@ -552,7 +575,8 @@ class DownloadManager:
|
|||||||
output_dir,
|
output_dir,
|
||||||
optimize,
|
optimize,
|
||||||
model_types,
|
model_types,
|
||||||
delay
|
delay,
|
||||||
|
active_library,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._state_lock:
|
async with self._state_lock:
|
||||||
@@ -571,7 +595,15 @@ class DownloadManager:
|
|||||||
await self._broadcast_progress(status='error', extra={'error': str(e)})
|
await self._broadcast_progress(status='error', extra={'error': str(e)})
|
||||||
raise ExampleImagesDownloadError(str(e)) from e
|
raise ExampleImagesDownloadError(str(e)) from e
|
||||||
|
|
||||||
async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay):
|
async def _download_specific_models_example_images_sync(
|
||||||
|
self,
|
||||||
|
model_hashes,
|
||||||
|
output_dir,
|
||||||
|
optimize,
|
||||||
|
model_types,
|
||||||
|
delay,
|
||||||
|
library_name,
|
||||||
|
):
|
||||||
"""Download example images for specific models only - synchronous version."""
|
"""Download example images for specific models only - synchronous version."""
|
||||||
|
|
||||||
downloader = await get_downloader()
|
downloader = await get_downloader()
|
||||||
@@ -612,8 +644,13 @@ class DownloadManager:
|
|||||||
for i, (scanner_type, model, scanner) in enumerate(models_to_process):
|
for i, (scanner_type, model, scanner) in enumerate(models_to_process):
|
||||||
# Force process this model regardless of previous status
|
# Force process this model regardless of previous status
|
||||||
was_successful = await self._process_specific_model(
|
was_successful = await self._process_specific_model(
|
||||||
scanner_type, model, scanner,
|
scanner_type,
|
||||||
output_dir, optimize, downloader
|
model,
|
||||||
|
scanner,
|
||||||
|
output_dir,
|
||||||
|
optimize,
|
||||||
|
downloader,
|
||||||
|
library_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
if was_successful:
|
if was_successful:
|
||||||
@@ -665,7 +702,16 @@ class DownloadManager:
|
|||||||
# No need to close any sessions since we use the global downloader
|
# No need to close any sessions since we use the global downloader
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _process_specific_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
|
async def _process_specific_model(
|
||||||
|
self,
|
||||||
|
scanner_type,
|
||||||
|
model,
|
||||||
|
scanner,
|
||||||
|
output_dir,
|
||||||
|
optimize,
|
||||||
|
downloader,
|
||||||
|
library_name,
|
||||||
|
):
|
||||||
"""Process a specific model for forced download, ignoring previous download status."""
|
"""Process a specific model for forced download, ignoring previous download status."""
|
||||||
|
|
||||||
# Check if download is paused
|
# Check if download is paused
|
||||||
@@ -687,7 +733,7 @@ class DownloadManager:
|
|||||||
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
||||||
await self._broadcast_progress(status='running')
|
await self._broadcast_progress(status='running')
|
||||||
|
|
||||||
model_dir = ExampleImagePathResolver.get_model_folder(model_hash)
|
model_dir = ExampleImagePathResolver.get_model_folder(model_hash, library_name)
|
||||||
if not model_dir:
|
if not model_dir:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Unable to resolve example images folder for model %s (%s)",
|
"Unable to resolve example images folder for model %s (%s)",
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -376,3 +377,80 @@ async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tm
|
|||||||
assert new_progress.exists()
|
assert new_progress.exists()
|
||||||
contents = json.loads(new_progress.read_text(encoding="utf-8"))
|
contents = json.loads(new_progress.read_text(encoding="utf-8"))
|
||||||
assert model_hash in contents.get("processed_models", [])
|
assert model_hash in contents.get("processed_models", [])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_download_remains_in_initial_library(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||||
|
monkeypatch.setitem(settings.settings, "libraries", {"LibraryA": {}, "LibraryB": {}})
|
||||||
|
monkeypatch.setitem(settings.settings, "active_library", "LibraryA")
|
||||||
|
|
||||||
|
state = {"active": "LibraryA"}
|
||||||
|
|
||||||
|
def fake_get_active_library_name(self):
|
||||||
|
return state["active"]
|
||||||
|
|
||||||
|
monkeypatch.setattr(type(settings), "get_active_library_name", fake_get_active_library_name)
|
||||||
|
|
||||||
|
model_hash = "f" * 64
|
||||||
|
model_path = tmp_path / "example-model.safetensors"
|
||||||
|
model_path.write_text("data", encoding="utf-8")
|
||||||
|
|
||||||
|
model = {
|
||||||
|
"sha256": model_hash,
|
||||||
|
"model_name": "Library Switch Model",
|
||||||
|
"file_path": str(model_path),
|
||||||
|
"file_name": "example-model.safetensors",
|
||||||
|
}
|
||||||
|
|
||||||
|
_patch_scanner(monkeypatch, StubScanner([model]))
|
||||||
|
|
||||||
|
async def fake_process_local_examples(
|
||||||
|
_file_path,
|
||||||
|
_file_name,
|
||||||
|
_model_name,
|
||||||
|
model_dir,
|
||||||
|
_optimize,
|
||||||
|
):
|
||||||
|
Path(model_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
(Path(model_dir) / "local.txt").write_text("data", encoding="utf-8")
|
||||||
|
state["active"] = "LibraryB"
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def fake_update_metadata(*_args, **_kwargs):
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def fake_get_downloader():
|
||||||
|
return object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.ExampleImagesProcessor,
|
||||||
|
"process_local_examples",
|
||||||
|
staticmethod(fake_process_local_examples),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.MetadataUpdater,
|
||||||
|
"update_metadata_from_local_examples",
|
||||||
|
staticmethod(fake_update_metadata),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
|
||||||
|
|
||||||
|
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
if manager._download_task is not None:
|
||||||
|
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||||
|
|
||||||
|
library_a_root = tmp_path / "LibraryA"
|
||||||
|
library_b_root = tmp_path / "LibraryB"
|
||||||
|
|
||||||
|
progress_file = library_a_root / ".download_progress.json"
|
||||||
|
model_dir = library_a_root / model_hash
|
||||||
|
|
||||||
|
assert progress_file.exists()
|
||||||
|
assert (model_dir / "local.txt").exists()
|
||||||
|
assert not (library_b_root / ".download_progress.json").exists()
|
||||||
|
assert not (library_b_root / model_hash).exists()
|
||||||
|
|||||||
Reference in New Issue
Block a user