Files
ComfyUI-Lora-Manager/py/utils/example_images_download_manager.py
Will Miao 1da476d858 feat(example-images): add check pending models endpoint and improve async handling
- Add /api/example-images/check-pending endpoint to quickly check models needing downloads
- Improve DownloadManager.start_download() to return immediately without blocking
- Add _handle_download_task_done callback for proper error handling and progress saving
- Add check_pending_models() method for lightweight pre-download validation
- Update frontend ExampleImagesManager to use new check-pending endpoint
- Add comprehensive tests for new functionality
2026-02-02 12:31:07 +08:00

1359 lines
53 KiB
Python

from __future__ import annotations
import asyncio
import json
import time
import logging
import os
import re
import shutil
import uuid
from typing import Any, Dict, Iterable, List, Set, Tuple
from ..services.service_registry import ServiceRegistry
from ..utils.example_images_paths import (
ExampleImagePathResolver,
ensure_library_root_exists,
uses_library_scoped_folders,
)
from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater
from ..services.downloader import get_downloader
from ..services.settings_manager import get_settings_manager
class ExampleImagesDownloadError(RuntimeError):
"""Base error for example image download operations."""
class DownloadInProgressError(ExampleImagesDownloadError):
"""Raised when a download is already running."""
def __init__(self, progress_snapshot: dict) -> None:
super().__init__("Download already in progress")
self.progress_snapshot = progress_snapshot
class DownloadNotRunningError(ExampleImagesDownloadError):
"""Raised when pause/resume is requested without an active download."""
def __init__(self, message: str = "No download in progress") -> None:
super().__init__(message)
class DownloadConfigurationError(ExampleImagesDownloadError):
"""Raised when configuration prevents starting a download."""
logger = logging.getLogger(__name__)
class _DownloadProgress(dict):
"""Mutable mapping maintaining download progress with set-aware serialisation."""
def __init__(self) -> None:
super().__init__()
self.reset()
def reset(self) -> None:
"""Reset the progress dictionary to its initial state."""
self.update(
total=0,
completed=0,
current_model="",
status="idle",
errors=[],
last_error=None,
start_time=None,
end_time=None,
processed_models=set(),
refreshed_models=set(),
failed_models=set(),
reprocessed_models=set(),
)
def snapshot(self) -> dict:
"""Return a JSON-serialisable snapshot of the current progress."""
snapshot = dict(self)
snapshot["processed_models"] = list(self["processed_models"])
snapshot["refreshed_models"] = list(self["refreshed_models"])
snapshot["failed_models"] = list(self["failed_models"])
snapshot["reprocessed_models"] = list(self.get("reprocessed_models", set()))
return snapshot
def _model_directory_has_files(path: str) -> bool:
"""Return True when the provided directory exists and contains entries."""
if not path or not os.path.isdir(path):
return False
try:
with os.scandir(path) as entries:
for _ in entries:
return True
except OSError:
return False
return False
class DownloadManager:
"""Manages downloading example images for models."""
def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None:
self._download_task: asyncio.Task | None = None
self._is_downloading = False
self._progress = _DownloadProgress()
self._ws_manager = ws_manager
self._state_lock = state_lock or asyncio.Lock()
self._stop_requested = False
def _resolve_output_dir(self, library_name: str | None = None) -> str:
base_path = get_settings_manager().get("example_images_path")
if not base_path:
return ""
return ensure_library_root_exists(library_name)
async def start_download(self, options: dict):
"""Start downloading example images for models."""
async with self._state_lock:
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
try:
data = options or {}
auto_mode = data.get("auto_mode", False)
optimize = data.get("optimize", True)
model_types = data.get("model_types", ["lora", "checkpoint"])
delay = float(data.get("delay", 0.2))
force = data.get("force", False)
settings_manager = get_settings_manager()
base_path = settings_manager.get("example_images_path")
if not base_path:
error_msg = "Example images path not configured in settings"
if auto_mode:
logger.debug(error_msg)
return {
"success": True,
"message": "Example images path not configured, skipping auto download",
}
raise DownloadConfigurationError(error_msg)
active_library = get_settings_manager().get_active_library_name()
output_dir = self._resolve_output_dir(active_library)
if not output_dir:
raise DownloadConfigurationError(
"Example images path not configured in settings"
)
self._progress.reset()
self._stop_requested = False
self._progress["status"] = "running"
self._progress["start_time"] = time.time()
self._progress["end_time"] = None
progress_file = os.path.join(output_dir, ".download_progress.json")
progress_source = progress_file
if uses_library_scoped_folders():
legacy_root = (
get_settings_manager().get("example_images_path") or ""
)
legacy_progress = (
os.path.join(legacy_root, ".download_progress.json")
if legacy_root
else ""
)
if (
legacy_progress
and os.path.exists(legacy_progress)
and not os.path.exists(progress_file)
):
try:
os.makedirs(output_dir, exist_ok=True)
shutil.move(legacy_progress, progress_file)
logger.info(
"Migrated legacy download progress file '%s' to '%s'",
legacy_progress,
progress_file,
)
except OSError as exc:
logger.warning(
"Failed to migrate download progress file from '%s' to '%s': %s",
legacy_progress,
progress_file,
exc,
)
progress_source = legacy_progress
if os.path.exists(progress_source):
try:
with open(progress_source, "r", encoding="utf-8") as f:
saved_progress = json.load(f)
self._progress["processed_models"] = set(
saved_progress.get("processed_models", [])
)
self._progress["failed_models"] = set(
saved_progress.get("failed_models", [])
)
logger.debug(
"Loaded previous progress, %s models already processed, %s models marked as failed",
len(self._progress["processed_models"]),
len(self._progress["failed_models"]),
)
except Exception as e:
logger.error(f"Failed to load progress file: {e}")
self._progress["processed_models"] = set()
self._progress["failed_models"] = set()
else:
self._progress["processed_models"] = set()
self._progress["failed_models"] = set()
self._is_downloading = True
snapshot = self._progress.snapshot()
# Create the download task without awaiting it
# This ensures the HTTP response is returned immediately
# while the actual processing happens in the background
self._download_task = asyncio.create_task(
self._download_all_example_images(
output_dir,
optimize,
model_types,
delay,
active_library,
force,
)
)
# Add a callback to handle task completion/errors
self._download_task.add_done_callback(
lambda t: self._handle_download_task_done(t, output_dir)
)
except ExampleImagesDownloadError:
# Re-raise our own exception types without wrapping
self._is_downloading = False
self._download_task = None
raise
except Exception as e:
self._is_downloading = False
self._download_task = None
logger.error(
f"Failed to start example images download: {e}", exc_info=True
)
raise ExampleImagesDownloadError(str(e)) from e
# Broadcast progress in the background without blocking the response
# This ensures the HTTP response is returned immediately
asyncio.create_task(self._broadcast_progress(status="running"))
return {"success": True, "message": "Download started", "status": snapshot}
def _handle_download_task_done(self, task: asyncio.Task, output_dir: str) -> None:
"""Handle download task completion, including saving progress on error."""
try:
# This will re-raise any exception from the task
task.result()
except Exception as e:
logger.error(f"Download task failed with error: {e}", exc_info=True)
# Ensure progress is saved even on failure
try:
self._save_progress(output_dir)
except Exception as save_error:
logger.error(f"Failed to save progress after task failure: {save_error}")
async def get_status(self, request):
"""Get the current status of example images download."""
return {
"success": True,
"is_downloading": self._is_downloading,
"status": self._progress.snapshot(),
}
async def check_pending_models(self, model_types: list[str]) -> dict:
"""Quickly check how many models need example images downloaded.
This is a lightweight check that avoids the overhead of starting
a full download task when no work is needed.
Returns:
dict with keys:
- total_models: Total number of models across specified types
- pending_count: Number of models needing example images
- processed_count: Number of already processed models
- failed_count: Number of models marked as failed
- needs_download: True if there are pending models to process
"""
from ..services.service_registry import ServiceRegistry
if self._is_downloading:
return {
"success": True,
"is_downloading": True,
"total_models": 0,
"pending_count": 0,
"processed_count": 0,
"failed_count": 0,
"needs_download": False,
"message": "Download already in progress",
}
try:
# Get scanners
scanners = []
if "lora" in model_types:
lora_scanner = await ServiceRegistry.get_lora_scanner()
scanners.append(("lora", lora_scanner))
if "checkpoint" in model_types:
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
scanners.append(("checkpoint", checkpoint_scanner))
if "embedding" in model_types:
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
scanners.append(("embedding", embedding_scanner))
# Load progress file to check processed models
settings_manager = get_settings_manager()
active_library = settings_manager.get_active_library_name()
output_dir = self._resolve_output_dir(active_library)
processed_models: set[str] = set()
failed_models: set[str] = set()
if output_dir:
progress_file = os.path.join(output_dir, ".download_progress.json")
if os.path.exists(progress_file):
try:
with open(progress_file, "r", encoding="utf-8") as f:
saved_progress = json.load(f)
processed_models = set(saved_progress.get("processed_models", []))
failed_models = set(saved_progress.get("failed_models", []))
except Exception:
pass # Ignore progress file errors for quick check
# Count models
total_models = 0
models_with_hash = 0
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
total_models += 1
if model.get("sha256"):
models_with_hash += 1
# Calculate pending count
# A model is pending if it has a hash and is not in processed_models
# We also exclude failed_models unless force mode would be used
pending_count = models_with_hash - len(processed_models.intersection(
{m.get("sha256", "").lower() for scanner_type, scanner in scanners
for m in (await scanner.get_cached_data()).raw_data if m.get("sha256")}
))
# More accurate pending count: check which models actually need processing
pending_hashes = set()
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
raw_hash = model.get("sha256")
if not raw_hash:
continue
model_hash = raw_hash.lower()
if model_hash not in processed_models:
# Check if model folder exists with files
model_dir = ExampleImagePathResolver.get_model_folder(
model_hash, active_library
)
if not _model_directory_has_files(model_dir):
pending_hashes.add(model_hash)
pending_count = len(pending_hashes)
return {
"success": True,
"is_downloading": False,
"total_models": total_models,
"pending_count": pending_count,
"processed_count": len(processed_models),
"failed_count": len(failed_models),
"needs_download": pending_count > 0,
}
except Exception as e:
logger.error(f"Error checking pending models: {e}", exc_info=True)
return {
"success": False,
"error": str(e),
"total_models": 0,
"pending_count": 0,
"processed_count": 0,
"failed_count": 0,
"needs_download": False,
}
async def pause_download(self, request):
"""Pause the example images download."""
async with self._state_lock:
if not self._is_downloading:
raise DownloadNotRunningError()
self._progress["status"] = "paused"
await self._broadcast_progress(status="paused")
return {"success": True, "message": "Download paused"}
async def resume_download(self, request):
"""Resume the example images download."""
async with self._state_lock:
if not self._is_downloading:
raise DownloadNotRunningError()
if self._progress["status"] == "paused":
self._progress["status"] = "running"
else:
raise DownloadNotRunningError(
f"Download is in '{self._progress['status']}' state, cannot resume"
)
await self._broadcast_progress(status="running")
return {"success": True, "message": "Download resumed"}
async def stop_download(self, request):
"""Stop the example images download after the current model completes."""
async with self._state_lock:
if not self._is_downloading:
raise DownloadNotRunningError()
if self._progress["status"] in {"completed", "error", "stopped"}:
raise DownloadNotRunningError()
if self._progress["status"] != "stopping":
self._stop_requested = True
self._progress["status"] = "stopping"
await self._broadcast_progress(status="stopping")
return {"success": True, "message": "Download stopping"}
async def _download_all_example_images(
self,
output_dir,
optimize,
model_types,
delay,
library_name,
force: bool = False,
):
"""Download example images for all models."""
downloader = await get_downloader()
try:
# Get scanners
scanners = []
if "lora" in model_types:
lora_scanner = await ServiceRegistry.get_lora_scanner()
scanners.append(("lora", lora_scanner))
if "checkpoint" in model_types:
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
scanners.append(("checkpoint", checkpoint_scanner))
if "embedding" in model_types:
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
scanners.append(("embedding", embedding_scanner))
# Get all models
all_models = []
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
if model.get("sha256"):
all_models.append((scanner_type, model, scanner))
# Update total count
self._progress["total"] = len(all_models)
logger.debug(f"Found {self._progress['total']} models to process")
await self._broadcast_progress(status="running")
# Process each model
for i, (scanner_type, model, scanner) in enumerate(all_models):
async with self._state_lock:
current_status = self._progress["status"]
if current_status not in {"running", "paused", "stopping"}:
break
# Main logic for processing model is here, but actual operations are delegated to other classes
was_remote_download = await self._process_model(
scanner_type,
model,
scanner,
output_dir,
optimize,
downloader,
library_name,
force,
)
# Update progress
self._progress["completed"] += 1
async with self._state_lock:
current_status = self._progress["status"]
should_stop = self._stop_requested and current_status == "stopping"
broadcast_status = (
"running" if current_status == "running" else current_status
)
await self._broadcast_progress(status=broadcast_status)
if should_stop:
break
# Only add delay after remote download of models, and not after processing the last model
if (
was_remote_download
and i < len(all_models) - 1
and current_status == "running"
):
await asyncio.sleep(delay)
async with self._state_lock:
if self._stop_requested and self._progress["status"] == "stopping":
self._progress["status"] = "stopped"
self._progress["end_time"] = time.time()
self._stop_requested = False
final_status = "stopped"
elif self._progress["status"] not in {"error", "stopped"}:
self._progress["status"] = "completed"
self._progress["end_time"] = time.time()
self._stop_requested = False
final_status = "completed"
else:
final_status = self._progress["status"]
self._stop_requested = False
if self._progress["end_time"] is None:
self._progress["end_time"] = time.time()
if final_status == "completed":
logger.debug(
"Example images download completed: %s/%s models processed",
self._progress["completed"],
self._progress["total"],
)
elif final_status == "stopped":
logger.debug(
"Example images download stopped: %s/%s models processed",
self._progress["completed"],
self._progress["total"],
)
reprocessed = self._progress.get("reprocessed_models", set())
if reprocessed:
logger.info(
"Detected %s models with missing or empty example image folders; reprocessing triggered for those models",
len(reprocessed),
)
await self._broadcast_progress(status=final_status)
except Exception as e:
error_msg = f"Error during example images download: {str(e)}"
logger.error(error_msg, exc_info=True)
self._progress["errors"].append(error_msg)
self._progress["last_error"] = error_msg
self._progress["status"] = "error"
self._progress["end_time"] = time.time()
await self._broadcast_progress(status="error", extra={"error": error_msg})
finally:
# Save final progress to file
try:
self._save_progress(output_dir)
except Exception as e:
logger.error(f"Failed to save progress file: {e}")
# Set download status to not downloading
async with self._state_lock:
self._is_downloading = False
self._download_task = None
self._stop_requested = False
async def _process_model(
self,
scanner_type,
model,
scanner,
output_dir,
optimize,
downloader,
library_name,
force: bool = False,
):
"""Process a single model download."""
# Check if download is paused
while self._progress["status"] == "paused":
await asyncio.sleep(1)
# Check if download should continue
if self._progress["status"] not in {"running", "stopping"}:
logger.info(f"Download stopped: {self._progress['status']}")
return False # Return False to indicate no remote download happened
model_hash = model.get("sha256", "").lower()
model_name = model.get("model_name", "Unknown")
model_file_path = model.get("file_path", "")
model_file_name = model.get("file_name", "")
try:
# Update current model info
self._progress["current_model"] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status="running")
# Skip if already in failed models (unless force mode is enabled)
if not force and model_hash in self._progress["failed_models"]:
logger.debug(f"Skipping known failed model: {model_name}")
return False
model_dir = ExampleImagePathResolver.get_model_folder(
model_hash, library_name
)
existing_files = _model_directory_has_files(model_dir)
# Skip if already processed AND directory exists with files
if model_hash in self._progress["processed_models"]:
if existing_files:
logger.debug(f"Skipping already processed model: {model_name}")
return False
logger.debug(
"Model %s (%s) marked as processed but folder empty or missing, reprocessing triggered",
model_name,
model_hash,
)
# Track that we are reprocessing this model for summary logging
self._progress["reprocessed_models"].add(model_hash)
# Remove from processed models since we need to reprocess
self._progress["processed_models"].discard(model_hash)
if existing_files and model_hash not in self._progress["processed_models"]:
logger.debug(
"Model folder already populated for %s, marking as processed without download",
model_name,
)
self._progress["processed_models"].add(model_hash)
return False
if not model_dir:
logger.warning(
"Unable to resolve example images folder for model %s (%s)",
model_name,
model_hash,
)
return False
# Create model directory
os.makedirs(model_dir, exist_ok=True)
# First check for local example images - local processing doesn't need delay
local_images_processed = (
await ExampleImagesProcessor.process_local_examples(
model_file_path, model_file_name, model_name, model_dir, optimize
)
)
# If we processed local images, update metadata
if local_images_processed:
await MetadataUpdater.update_metadata_from_local_examples(
model_hash, model, scanner_type, scanner, model_dir
)
self._progress["processed_models"].add(model_hash)
return False # Return False to indicate no remote download happened
full_model = await MetadataUpdater.get_updated_model(model_hash, scanner)
civitai_payload = (full_model or {}).get("civitai") if full_model else None
civitai_payload = civitai_payload or {}
# If no local images, try to download from remote
if civitai_payload.get("images"):
images = civitai_payload.get("images", [])
(
success,
is_stale,
failed_images,
) = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash, model_name, images, model_dir, optimize, downloader
)
failed_urls: Set[str] = set(failed_images)
# If metadata is stale, try to refresh it
if is_stale and model_hash not in self._progress["refreshed_models"]:
await MetadataUpdater.refresh_model_metadata(
model_hash, model_name, scanner_type, scanner, self._progress
)
# Get the updated model data
updated_model = await MetadataUpdater.get_updated_model(
model_hash, scanner
)
updated_civitai = (
(updated_model or {}).get("civitai") if updated_model else None
)
updated_civitai = updated_civitai or {}
if updated_civitai.get("images"):
# Retry download with updated metadata
updated_images = updated_civitai.get("images", [])
(
success,
_,
additional_failed,
) = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash,
model_name,
updated_images,
model_dir,
optimize,
downloader,
)
failed_urls.update(additional_failed)
self._progress["refreshed_models"].add(model_hash)
if failed_urls:
await self._remove_failed_images_from_metadata(
model_hash,
model_name,
model_dir,
failed_urls,
scanner,
)
if failed_urls:
self._progress["failed_models"].add(model_hash)
self._progress["processed_models"].add(model_hash)
logger.info(
"Removed %s failed example images for %s",
len(failed_urls),
model_name,
)
elif success:
self._progress["processed_models"].add(model_hash)
# Remove from failed_models if force mode enabled and model was previously failed
if force and model_hash in self._progress["failed_models"]:
self._progress["failed_models"].discard(model_hash)
logger.info(
f"Removed {model_name} from failed_models after successful force retry"
)
else:
self._progress["failed_models"].add(model_hash)
logger.info(
"Example images download failed for %s despite metadata refresh",
model_name,
)
return True # Return True to indicate a remote download happened
else:
# No civitai data or images available, mark as failed to avoid future attempts
self._progress["failed_models"].add(model_hash)
logger.debug(
f"No civitai images available for model {model_name}, marking as failed"
)
# Save progress periodically
if (
self._progress["completed"] % 10 == 0
or self._progress["completed"] == self._progress["total"] - 1
):
self._save_progress(output_dir)
return False # Default return if no conditions met
except Exception as e:
error_msg = f"Error processing model {model.get('model_name')} ({model_hash}): {str(e)}"
logger.error(error_msg, exc_info=True)
self._progress["errors"].append(error_msg)
self._progress["last_error"] = error_msg
# Ensure model is marked as failed so we don't try again in this run
self._progress["failed_models"].add(model_hash)
return False
def _save_progress(self, output_dir):
"""Save download progress to file."""
try:
progress_file = os.path.join(output_dir, ".download_progress.json")
# Read existing progress file if it exists
existing_data = {}
if os.path.exists(progress_file):
try:
with open(progress_file, "r", encoding="utf-8") as f:
existing_data = json.load(f)
except Exception as e:
logger.warning(f"Failed to read existing progress file: {e}")
# Create new progress data
progress_data = {
"processed_models": list(self._progress["processed_models"]),
"refreshed_models": list(self._progress["refreshed_models"]),
"failed_models": list(self._progress["failed_models"]),
"completed": self._progress["completed"],
"total": self._progress["total"],
"last_update": time.time(),
}
# Preserve existing fields (especially naming_version)
for key, value in existing_data.items():
if key not in progress_data:
progress_data[key] = value
# Write updated progress data
with open(progress_file, "w", encoding="utf-8") as f:
json.dump(progress_data, f, indent=2)
except Exception as e:
logger.error(f"Failed to save progress file: {e}")
async def start_force_download(self, options: dict):
"""Force download example images for specific models."""
async with self._state_lock:
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
data = options or {}
model_hashes = data.get("model_hashes", [])
optimize = data.get("optimize", True)
model_types = data.get("model_types", ["lora", "checkpoint"])
delay = float(data.get("delay", 0.2))
if not model_hashes:
raise DownloadConfigurationError("Missing model_hashes parameter")
settings_manager = get_settings_manager()
base_path = settings_manager.get("example_images_path")
if not base_path:
raise DownloadConfigurationError(
"Example images path not configured in settings"
)
active_library = settings_manager.get_active_library_name()
output_dir = self._resolve_output_dir(active_library)
if not output_dir:
raise DownloadConfigurationError(
"Example images path not configured in settings"
)
self._progress.reset()
self._stop_requested = False
self._progress["total"] = len(model_hashes)
self._progress["status"] = "running"
self._progress["start_time"] = time.time()
self._progress["end_time"] = None
self._is_downloading = True
await self._broadcast_progress(status="running")
try:
result = await self._download_specific_models_example_images_sync(
model_hashes,
output_dir,
optimize,
model_types,
delay,
active_library,
)
async with self._state_lock:
self._is_downloading = False
final_status = self._progress["status"]
message = "Force download completed"
if final_status == "stopped":
message = "Force download stopped"
return {"success": True, "message": message, "result": result}
except Exception as e:
async with self._state_lock:
self._is_downloading = False
logger.error(
f"Failed during forced example images download: {e}", exc_info=True
)
await self._broadcast_progress(status="error", extra={"error": str(e)})
raise ExampleImagesDownloadError(str(e)) from e
async def _download_specific_models_example_images_sync(
self,
model_hashes,
output_dir,
optimize,
model_types,
delay,
library_name,
):
"""Download example images for specific models only - synchronous version."""
downloader = await get_downloader()
try:
# Get scanners
scanners = []
if "lora" in model_types:
lora_scanner = await ServiceRegistry.get_lora_scanner()
scanners.append(("lora", lora_scanner))
if "checkpoint" in model_types:
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
scanners.append(("checkpoint", checkpoint_scanner))
if "embedding" in model_types:
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
scanners.append(("embedding", embedding_scanner))
# Find the specified models
models_to_process = []
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
if model.get("sha256") in model_hashes:
models_to_process.append((scanner_type, model, scanner))
# Update total count based on found models
self._progress["total"] = len(models_to_process)
logger.debug(f"Found {self._progress['total']} models to process")
# Send initial progress via WebSocket
await self._broadcast_progress(status="running")
# Process each model
success_count = 0
for i, (scanner_type, model, scanner) in enumerate(models_to_process):
async with self._state_lock:
current_status = self._progress["status"]
if current_status not in {"running", "paused", "stopping"}:
break
# Force process this model regardless of previous status
was_successful = await self._process_specific_model(
scanner_type,
model,
scanner,
output_dir,
optimize,
downloader,
library_name,
)
if was_successful:
success_count += 1
# Update progress
self._progress["completed"] += 1
async with self._state_lock:
current_status = self._progress["status"]
should_stop = self._stop_requested and current_status == "stopping"
broadcast_status = (
"running" if current_status == "running" else current_status
)
# Send progress update via WebSocket
await self._broadcast_progress(status=broadcast_status)
if should_stop:
break
# Only add delay after remote download, and not after processing the last model
if (
was_successful
and i < len(models_to_process) - 1
and current_status == "running"
):
await asyncio.sleep(delay)
async with self._state_lock:
if self._stop_requested and self._progress["status"] == "stopping":
self._progress["status"] = "stopped"
self._progress["end_time"] = time.time()
self._stop_requested = False
final_status = "stopped"
elif self._progress["status"] not in {"error", "stopped"}:
self._progress["status"] = "completed"
self._progress["end_time"] = time.time()
self._stop_requested = False
final_status = "completed"
else:
final_status = self._progress["status"]
self._stop_requested = False
if self._progress["end_time"] is None:
self._progress["end_time"] = time.time()
if final_status == "completed":
logger.debug(
"Forced example images download completed: %s/%s models processed",
self._progress["completed"],
self._progress["total"],
)
elif final_status == "stopped":
logger.debug(
"Forced example images download stopped: %s/%s models processed",
self._progress["completed"],
self._progress["total"],
)
# Send final progress via WebSocket
await self._broadcast_progress(status=final_status)
return {
"total": self._progress["total"],
"processed": self._progress["completed"],
"successful": success_count,
"errors": self._progress["errors"],
}
except Exception as e:
error_msg = f"Error during forced example images download: {str(e)}"
logger.error(error_msg, exc_info=True)
self._progress["errors"].append(error_msg)
self._progress["last_error"] = error_msg
self._progress["status"] = "error"
self._progress["end_time"] = time.time()
# Send error status via WebSocket
await self._broadcast_progress(status="error", extra={"error": error_msg})
raise
finally:
# No need to close any sessions since we use the global downloader
pass
async def _process_specific_model(
self,
scanner_type,
model,
scanner,
output_dir,
optimize,
downloader,
library_name,
):
"""Process a specific model for forced download, ignoring previous download status."""
# Check if download is paused
while self._progress["status"] == "paused":
await asyncio.sleep(1)
# Check if download should continue
if self._progress["status"] not in {"running", "stopping"}:
logger.info(f"Download stopped: {self._progress['status']}")
return False
model_hash = model.get("sha256", "").lower()
model_name = model.get("model_name", "Unknown")
model_file_path = model.get("file_path", "")
model_file_name = model.get("file_name", "")
try:
# Update current model info
self._progress["current_model"] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status="running")
model_dir = ExampleImagePathResolver.get_model_folder(
model_hash, library_name
)
if not model_dir:
logger.warning(
"Unable to resolve example images folder for model %s (%s)",
model_name,
model_hash,
)
return False
os.makedirs(model_dir, exist_ok=True)
# First check for local example images - local processing doesn't need delay
local_images_processed = (
await ExampleImagesProcessor.process_local_examples(
model_file_path, model_file_name, model_name, model_dir, optimize
)
)
# If we processed local images, update metadata
if local_images_processed:
await MetadataUpdater.update_metadata_from_local_examples(
model_hash, model, scanner_type, scanner, model_dir
)
self._progress["processed_models"].add(model_hash)
return False # Return False to indicate no remote download happened
full_model = await MetadataUpdater.get_updated_model(model_hash, scanner)
civitai_payload = (full_model or {}).get("civitai") if full_model else None
civitai_payload = civitai_payload or {}
# If no local images, try to download from remote
if civitai_payload.get("images"):
images = civitai_payload.get("images", [])
(
success,
is_stale,
failed_images,
) = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash, model_name, images, model_dir, optimize, downloader
)
failed_urls: Set[str] = set(failed_images)
# If metadata is stale, try to refresh it
if is_stale and model_hash not in self._progress["refreshed_models"]:
await MetadataUpdater.refresh_model_metadata(
model_hash, model_name, scanner_type, scanner, self._progress
)
# Get the updated model data
updated_model = await MetadataUpdater.get_updated_model(
model_hash, scanner
)
updated_civitai = (
(updated_model or {}).get("civitai") if updated_model else None
)
updated_civitai = updated_civitai or {}
if updated_civitai.get("images"):
# Retry download with updated metadata
updated_images = updated_civitai.get("images", [])
(
success,
_,
additional_failed_images,
) = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash,
model_name,
updated_images,
model_dir,
optimize,
downloader,
)
# Combine failed images from both attempts
failed_urls.update(additional_failed_images)
self._progress["refreshed_models"].add(model_hash)
# For forced downloads, remove failed images from metadata
if failed_urls:
await self._remove_failed_images_from_metadata(
model_hash, model_name, model_dir, failed_urls, scanner
)
# Mark as processed
if (
success or failed_urls
): # Mark as processed if we successfully downloaded some images or removed failed ones
self._progress["processed_models"].add(model_hash)
return True # Return True to indicate a remote download happened
else:
logger.debug(f"No civitai images available for model {model_name}")
return False
except Exception as e:
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
logger.error(error_msg, exc_info=True)
self._progress["errors"].append(error_msg)
self._progress["last_error"] = error_msg
return False # Return False on exception
async def _remove_failed_images_from_metadata(
self,
model_hash: str,
model_name: str,
model_dir: str,
failed_images: Iterable[str],
scanner,
) -> None:
"""Mark failed images in model metadata so they won't be retried."""
failed_set: Set[str] = {url for url in failed_images if url}
if not failed_set:
return
try:
# Get current model data
model_data = await MetadataUpdater.get_updated_model(model_hash, scanner)
if not model_data:
logger.warning(
f"Could not find model data for {model_name} to remove failed images"
)
return
civitai_payload = model_data.get("civitai") or {}
current_images = civitai_payload.get("images") or []
if not current_images:
logger.warning(f"No images in metadata for {model_name}")
return
updated = False
for image in current_images:
image_url = image.get("url")
optimized_url = (
ExampleImagesProcessor.get_civitai_optimized_url(image_url)
if image_url and "civitai.com" in image_url
else None
)
if image_url not in failed_set and optimized_url not in failed_set:
continue
if image.get("downloadFailed"):
continue
image["downloadFailed"] = True
image.setdefault("downloadError", "not_found")
logger.debug(
"Marked example image %s for %s as failed due to missing remote asset",
image_url,
model_name,
)
updated = True
if not updated:
return
file_path = model_data.get("file_path")
if file_path:
model_copy = model_data.copy()
model_copy.pop("folder", None)
await MetadataManager.save_metadata(file_path, model_copy)
try:
await scanner.update_single_model_cache(
file_path, file_path, model_data
)
except AttributeError:
logger.debug(
"Scanner does not expose cache update for %s", model_name
)
except Exception as exc: # pragma: no cover - defensive logging
logger.error(
"Error removing failed images from metadata for %s: %s",
model_name,
exc,
exc_info=True,
)
def _renumber_example_image_files(self, model_dir: str) -> None:
if not model_dir or not os.path.isdir(model_dir):
return
pattern = re.compile(r"^image_(\d+)(\.[^.]+)$", re.IGNORECASE)
matches: List[Tuple[int, str, str]] = []
for entry in os.listdir(model_dir):
match = pattern.match(entry)
if match:
matches.append((int(match.group(1)), entry, match.group(2)))
if not matches:
return
matches.sort(key=lambda item: item[0])
staged_paths: List[Tuple[str, str]] = []
for _, original_name, extension in matches:
source_path = os.path.join(model_dir, original_name)
temp_name = f"tmp_{uuid.uuid4().hex}_{original_name}"
temp_path = os.path.join(model_dir, temp_name)
try:
os.rename(source_path, temp_path)
staged_paths.append((temp_path, extension))
except OSError as exc:
logger.warning("Failed to stage rename for %s: %s", source_path, exc)
for new_index, (temp_path, extension) in enumerate(staged_paths):
final_name = f"image_{new_index}{extension}"
final_path = os.path.join(model_dir, final_name)
try:
os.rename(temp_path, final_path)
except OSError as exc:
logger.warning("Failed to finalise rename for %s: %s", final_path, exc)
async def _broadcast_progress(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> None:
payload = self._build_progress_payload(status=status, extra=extra)
try:
await self._ws_manager.broadcast(payload)
except Exception as exc: # pragma: no cover - defensive logging
logger.warning("Failed to broadcast example image progress: %s", exc)
def _build_progress_payload(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"type": "example_images_progress",
"processed": self._progress["completed"],
"total": self._progress["total"],
"status": status or self._progress["status"],
"current_model": self._progress["current_model"],
}
if self._progress["errors"]:
payload["errors"] = list(self._progress["errors"])
if self._progress["last_error"]:
payload["last_error"] = self._progress["last_error"]
if extra:
payload.update(extra)
return payload
_default_download_manager: DownloadManager | None = None
def get_default_download_manager(ws_manager) -> DownloadManager:
"""Return the singleton download manager used by default routes."""
global _default_download_manager
if (
_default_download_manager is None
or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager
):
_default_download_manager = DownloadManager(ws_manager=ws_manager)
return _default_download_manager