mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(example-images): add check pending models endpoint and improve async handling
- Add /api/example-images/check-pending endpoint to quickly check models needing downloads - Improve DownloadManager.start_download() to return immediately without blocking - Add _handle_download_task_done callback for proper error handling and progress saving - Add check_pending_models() method for lightweight pre-download validation - Update frontend ExampleImagesManager to use new check-pending endpoint - Add comprehensive tests for new functionality
This commit is contained in:
@@ -30,6 +30,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
|
||||
RouteDefinition("POST", "/api/lm/example-images/set-nsfw-level", "set_example_image_nsfw_level"),
|
||||
RouteDefinition("POST", "/api/lm/check-example-images-needed", "check_example_images_needed"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -92,6 +92,19 @@ class ExampleImagesDownloadHandler:
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def check_example_images_needed(self, request: web.Request) -> web.StreamResponse:
|
||||
"""Lightweight check to see if any models need example images downloaded."""
|
||||
try:
|
||||
payload = await request.json()
|
||||
model_types = payload.get('model_types', ['lora', 'checkpoint', 'embedding'])
|
||||
result = await self._download_manager.check_pending_models(model_types)
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
return web.json_response(
|
||||
{'success': False, 'error': str(exc)},
|
||||
status=500
|
||||
)
|
||||
|
||||
|
||||
class ExampleImagesManagementHandler:
|
||||
"""HTTP adapters for import/delete endpoints."""
|
||||
@@ -161,6 +174,7 @@ class ExampleImagesHandlerSet:
|
||||
"resume_example_images": self.download.resume_example_images,
|
||||
"stop_example_images": self.download.stop_example_images,
|
||||
"force_download_example_images": self.download.force_download_example_images,
|
||||
"check_example_images_needed": self.download.check_example_images_needed,
|
||||
"import_example_images": self.management.import_example_images,
|
||||
"delete_example_image": self.management.delete_example_image,
|
||||
"set_example_image_nsfw_level": self.management.set_example_image_nsfw_level,
|
||||
|
||||
@@ -216,6 +216,11 @@ class DownloadManager:
|
||||
self._progress["failed_models"] = set()
|
||||
|
||||
self._is_downloading = True
|
||||
snapshot = self._progress.snapshot()
|
||||
|
||||
# Create the download task without awaiting it
|
||||
# This ensures the HTTP response is returned immediately
|
||||
# while the actual processing happens in the background
|
||||
self._download_task = asyncio.create_task(
|
||||
self._download_all_example_images(
|
||||
output_dir,
|
||||
@@ -227,7 +232,10 @@ class DownloadManager:
|
||||
)
|
||||
)
|
||||
|
||||
snapshot = self._progress.snapshot()
|
||||
# Add a callback to handle task completion/errors
|
||||
self._download_task.add_done_callback(
|
||||
lambda t: self._handle_download_task_done(t, output_dir)
|
||||
)
|
||||
except ExampleImagesDownloadError:
|
||||
# Re-raise our own exception types without wrapping
|
||||
self._is_downloading = False
|
||||
@@ -241,10 +249,25 @@ class DownloadManager:
|
||||
)
|
||||
raise ExampleImagesDownloadError(str(e)) from e
|
||||
|
||||
await self._broadcast_progress(status="running")
|
||||
# Broadcast progress in the background without blocking the response
|
||||
# This ensures the HTTP response is returned immediately
|
||||
asyncio.create_task(self._broadcast_progress(status="running"))
|
||||
|
||||
return {"success": True, "message": "Download started", "status": snapshot}
|
||||
|
||||
def _handle_download_task_done(self, task: asyncio.Task, output_dir: str) -> None:
|
||||
"""Handle download task completion, including saving progress on error."""
|
||||
try:
|
||||
# This will re-raise any exception from the task
|
||||
task.result()
|
||||
except Exception as e:
|
||||
logger.error(f"Download task failed with error: {e}", exc_info=True)
|
||||
# Ensure progress is saved even on failure
|
||||
try:
|
||||
self._save_progress(output_dir)
|
||||
except Exception as save_error:
|
||||
logger.error(f"Failed to save progress after task failure: {save_error}")
|
||||
|
||||
async def get_status(self, request):
|
||||
"""Get the current status of example images download."""
|
||||
|
||||
@@ -254,6 +277,130 @@ class DownloadManager:
|
||||
"status": self._progress.snapshot(),
|
||||
}
|
||||
|
||||
async def check_pending_models(self, model_types: list[str]) -> dict:
|
||||
"""Quickly check how many models need example images downloaded.
|
||||
|
||||
This is a lightweight check that avoids the overhead of starting
|
||||
a full download task when no work is needed.
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
- total_models: Total number of models across specified types
|
||||
- pending_count: Number of models needing example images
|
||||
- processed_count: Number of already processed models
|
||||
- failed_count: Number of models marked as failed
|
||||
- needs_download: True if there are pending models to process
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
if self._is_downloading:
|
||||
return {
|
||||
"success": True,
|
||||
"is_downloading": True,
|
||||
"total_models": 0,
|
||||
"pending_count": 0,
|
||||
"processed_count": 0,
|
||||
"failed_count": 0,
|
||||
"needs_download": False,
|
||||
"message": "Download already in progress",
|
||||
}
|
||||
|
||||
try:
|
||||
# Get scanners
|
||||
scanners = []
|
||||
if "lora" in model_types:
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
scanners.append(("lora", lora_scanner))
|
||||
|
||||
if "checkpoint" in model_types:
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
scanners.append(("checkpoint", checkpoint_scanner))
|
||||
|
||||
if "embedding" in model_types:
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
scanners.append(("embedding", embedding_scanner))
|
||||
|
||||
# Load progress file to check processed models
|
||||
settings_manager = get_settings_manager()
|
||||
active_library = settings_manager.get_active_library_name()
|
||||
output_dir = self._resolve_output_dir(active_library)
|
||||
|
||||
processed_models: set[str] = set()
|
||||
failed_models: set[str] = set()
|
||||
|
||||
if output_dir:
|
||||
progress_file = os.path.join(output_dir, ".download_progress.json")
|
||||
if os.path.exists(progress_file):
|
||||
try:
|
||||
with open(progress_file, "r", encoding="utf-8") as f:
|
||||
saved_progress = json.load(f)
|
||||
processed_models = set(saved_progress.get("processed_models", []))
|
||||
failed_models = set(saved_progress.get("failed_models", []))
|
||||
except Exception:
|
||||
pass # Ignore progress file errors for quick check
|
||||
|
||||
# Count models
|
||||
total_models = 0
|
||||
models_with_hash = 0
|
||||
|
||||
for scanner_type, scanner in scanners:
|
||||
cache = await scanner.get_cached_data()
|
||||
if cache and cache.raw_data:
|
||||
for model in cache.raw_data:
|
||||
total_models += 1
|
||||
if model.get("sha256"):
|
||||
models_with_hash += 1
|
||||
|
||||
# Calculate pending count
|
||||
# A model is pending if it has a hash and is not in processed_models
|
||||
# We also exclude failed_models unless force mode would be used
|
||||
pending_count = models_with_hash - len(processed_models.intersection(
|
||||
{m.get("sha256", "").lower() for scanner_type, scanner in scanners
|
||||
for m in (await scanner.get_cached_data()).raw_data if m.get("sha256")}
|
||||
))
|
||||
|
||||
# More accurate pending count: check which models actually need processing
|
||||
pending_hashes = set()
|
||||
for scanner_type, scanner in scanners:
|
||||
cache = await scanner.get_cached_data()
|
||||
if cache and cache.raw_data:
|
||||
for model in cache.raw_data:
|
||||
raw_hash = model.get("sha256")
|
||||
if not raw_hash:
|
||||
continue
|
||||
model_hash = raw_hash.lower()
|
||||
if model_hash not in processed_models:
|
||||
# Check if model folder exists with files
|
||||
model_dir = ExampleImagePathResolver.get_model_folder(
|
||||
model_hash, active_library
|
||||
)
|
||||
if not _model_directory_has_files(model_dir):
|
||||
pending_hashes.add(model_hash)
|
||||
|
||||
pending_count = len(pending_hashes)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"is_downloading": False,
|
||||
"total_models": total_models,
|
||||
"pending_count": pending_count,
|
||||
"processed_count": len(processed_models),
|
||||
"failed_count": len(failed_models),
|
||||
"needs_download": pending_count > 0,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking pending models: {e}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"total_models": 0,
|
||||
"pending_count": 0,
|
||||
"processed_count": 0,
|
||||
"failed_count": 0,
|
||||
"needs_download": False,
|
||||
}
|
||||
|
||||
async def pause_download(self, request):
|
||||
"""Pause the example images download."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user