feat(example-images): add check pending models endpoint and improve async handling

- Add /api/example-images/check-pending endpoint to quickly check models needing downloads
- Improve DownloadManager.start_download() to return immediately without blocking
- Add _handle_download_task_done callback for proper error handling and progress saving
- Add check_pending_models() method for lightweight pre-download validation
- Update frontend ExampleImagesManager to use new check-pending endpoint
- Add comprehensive tests for new functionality
This commit is contained in:
Will Miao
2026-02-02 12:31:07 +08:00
parent 1daaff6bd4
commit 1da476d858
7 changed files with 734 additions and 17 deletions

View File

@@ -30,6 +30,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
RouteDefinition("POST", "/api/lm/example-images/set-nsfw-level", "set_example_image_nsfw_level"),
RouteDefinition("POST", "/api/lm/check-example-images-needed", "check_example_images_needed"),
)

View File

@@ -92,6 +92,19 @@ class ExampleImagesDownloadHandler:
except ExampleImagesDownloadError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=500)
async def check_example_images_needed(self, request: web.Request) -> web.StreamResponse:
"""Lightweight check to see if any models need example images downloaded."""
try:
payload = await request.json()
model_types = payload.get('model_types', ['lora', 'checkpoint', 'embedding'])
result = await self._download_manager.check_pending_models(model_types)
return web.json_response(result)
except Exception as exc:
return web.json_response(
{'success': False, 'error': str(exc)},
status=500
)
class ExampleImagesManagementHandler:
"""HTTP adapters for import/delete endpoints."""
@@ -161,6 +174,7 @@ class ExampleImagesHandlerSet:
"resume_example_images": self.download.resume_example_images,
"stop_example_images": self.download.stop_example_images,
"force_download_example_images": self.download.force_download_example_images,
"check_example_images_needed": self.download.check_example_images_needed,
"import_example_images": self.management.import_example_images,
"delete_example_image": self.management.delete_example_image,
"set_example_image_nsfw_level": self.management.set_example_image_nsfw_level,

View File

@@ -216,6 +216,11 @@ class DownloadManager:
self._progress["failed_models"] = set()
self._is_downloading = True
snapshot = self._progress.snapshot()
# Create the download task without awaiting it
# This ensures the HTTP response is returned immediately
# while the actual processing happens in the background
self._download_task = asyncio.create_task(
self._download_all_example_images(
output_dir,
@@ -227,7 +232,10 @@ class DownloadManager:
)
)
snapshot = self._progress.snapshot()
# Add a callback to handle task completion/errors
self._download_task.add_done_callback(
lambda t: self._handle_download_task_done(t, output_dir)
)
except ExampleImagesDownloadError:
# Re-raise our own exception types without wrapping
self._is_downloading = False
@@ -241,10 +249,25 @@ class DownloadManager:
)
raise ExampleImagesDownloadError(str(e)) from e
await self._broadcast_progress(status="running")
# Broadcast progress in the background without blocking the response
# This ensures the HTTP response is returned immediately
asyncio.create_task(self._broadcast_progress(status="running"))
return {"success": True, "message": "Download started", "status": snapshot}
def _handle_download_task_done(self, task: asyncio.Task, output_dir: str) -> None:
"""Handle download task completion, including saving progress on error."""
try:
# This will re-raise any exception from the task
task.result()
except Exception as e:
logger.error(f"Download task failed with error: {e}", exc_info=True)
# Ensure progress is saved even on failure
try:
self._save_progress(output_dir)
except Exception as save_error:
logger.error(f"Failed to save progress after task failure: {save_error}")
async def get_status(self, request):
"""Get the current status of example images download."""
@@ -254,6 +277,130 @@ class DownloadManager:
"status": self._progress.snapshot(),
}
async def check_pending_models(self, model_types: list[str]) -> dict:
"""Quickly check how many models need example images downloaded.
This is a lightweight check that avoids the overhead of starting
a full download task when no work is needed.
Returns:
dict with keys:
- total_models: Total number of models across specified types
- pending_count: Number of models needing example images
- processed_count: Number of already processed models
- failed_count: Number of models marked as failed
- needs_download: True if there are pending models to process
"""
from ..services.service_registry import ServiceRegistry
if self._is_downloading:
return {
"success": True,
"is_downloading": True,
"total_models": 0,
"pending_count": 0,
"processed_count": 0,
"failed_count": 0,
"needs_download": False,
"message": "Download already in progress",
}
try:
# Get scanners
scanners = []
if "lora" in model_types:
lora_scanner = await ServiceRegistry.get_lora_scanner()
scanners.append(("lora", lora_scanner))
if "checkpoint" in model_types:
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
scanners.append(("checkpoint", checkpoint_scanner))
if "embedding" in model_types:
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
scanners.append(("embedding", embedding_scanner))
# Load progress file to check processed models
settings_manager = get_settings_manager()
active_library = settings_manager.get_active_library_name()
output_dir = self._resolve_output_dir(active_library)
processed_models: set[str] = set()
failed_models: set[str] = set()
if output_dir:
progress_file = os.path.join(output_dir, ".download_progress.json")
if os.path.exists(progress_file):
try:
with open(progress_file, "r", encoding="utf-8") as f:
saved_progress = json.load(f)
processed_models = set(saved_progress.get("processed_models", []))
failed_models = set(saved_progress.get("failed_models", []))
except Exception:
pass # Ignore progress file errors for quick check
# Count models
total_models = 0
models_with_hash = 0
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
total_models += 1
if model.get("sha256"):
models_with_hash += 1
# Calculate pending count
# A model is pending if it has a hash and is not in processed_models
# We also exclude failed_models unless force mode would be used
pending_count = models_with_hash - len(processed_models.intersection(
{m.get("sha256", "").lower() for scanner_type, scanner in scanners
for m in (await scanner.get_cached_data()).raw_data if m.get("sha256")}
))
# More accurate pending count: check which models actually need processing
pending_hashes = set()
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
raw_hash = model.get("sha256")
if not raw_hash:
continue
model_hash = raw_hash.lower()
if model_hash not in processed_models:
# Check if model folder exists with files
model_dir = ExampleImagePathResolver.get_model_folder(
model_hash, active_library
)
if not _model_directory_has_files(model_dir):
pending_hashes.add(model_hash)
pending_count = len(pending_hashes)
return {
"success": True,
"is_downloading": False,
"total_models": total_models,
"pending_count": pending_count,
"processed_count": len(processed_models),
"failed_count": len(failed_models),
"needs_download": pending_count > 0,
}
except Exception as e:
logger.error(f"Error checking pending models: {e}", exc_info=True)
return {
"success": False,
"error": str(e),
"total_models": 0,
"pending_count": 0,
"processed_count": 0,
"failed_count": 0,
"needs_download": False,
}
async def pause_download(self, request):
"""Pause the example images download."""