mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 13:12:12 -03:00
Backend fixes: - Add missing API route for /api/lm/recipes/batch-import/progress (GET) - Add missing API route for /api/lm/recipes/batch-import/directory (POST) - Add missing API route for /api/lm/recipes/browse-directory (POST) - Register WebSocket endpoint for batch import progress - Fix skip_no_metadata default value (True -> False) to allow no-LoRA imports - Add items array to BatchImportProgress.to_dict() for detailed results Frontend implementation: - Create BatchImportManager.js with complete batch import workflow - Add directory browser UI for selecting folders - Add batch import modal with URL list and directory input modes - Implement real-time progress tracking (WebSocket + HTTP polling) - Add results summary with success/failed/skipped statistics - Add expandable details view showing individual item status - Auto-refresh recipe list after import completion UI improvements: - Add spinner animation for importing status - Simplify results summary UI to match progress stats styling - Fix current item text alignment - Fix dark theme styling for directory browser button - Fix batch import button styling consistency Translations: - Add batch import related i18n keys to all locale files - Run sync_translation_keys.py to sync all translations Fixes: - Batch import now allows images without LoRAs (matches single import behavior) - Progress endpoint now returns complete items array with status details - Results view correctly displays skipped items with error messages
594 lines
20 KiB
Python
594 lines
20 KiB
Python
"""Batch import service for importing multiple images as recipes."""
|
|
|
|
from __future__ import annotations

import asyncio
import logging
import os
import re
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set

from aiohttp import web

from .recipes import (
    RecipeAnalysisService,
    RecipePersistenceService,
    RecipeValidationError,
    RecipeDownloadError,
    RecipeNotFoundError,
)
|
|
|
|
|
|
class ImportItemType(Enum):
    """Type of import item.

    Distinguishes remote URLs from local filesystem paths so the importer
    can route each item through the matching analysis path.
    """

    URL = "url"
    LOCAL_PATH = "local_path"
|
|
|
class ImportStatus(Enum):
    """Status of an individual import item.

    Items start PENDING, move to PROCESSING while being imported, and end
    in exactly one of SUCCESS, FAILED, or SKIPPED.
    """

    PENDING = "pending"
    PROCESSING = "processing"
    SUCCESS = "success"
    FAILED = "failed"
    SKIPPED = "skipped"
|
|
|
@dataclass
class BatchImportItem:
    """Represents a single item to import.

    One instance is created per source at the start of a batch operation;
    the status fields are mutated in place as the item is processed.
    """

    id: str  # "<operation_id>_<index>"
    source: str  # URL or local filesystem path
    item_type: ImportItemType
    status: ImportStatus = ImportStatus.PENDING
    error_message: Optional[str] = None  # populated when failed or skipped
    recipe_name: Optional[str] = None  # populated on success
    recipe_id: Optional[str] = None  # populated on success
    duration: float = 0.0  # processing time in seconds
|
|
|
@dataclass
class BatchImportProgress:
    """Tracks progress of a batch import operation."""

    operation_id: str
    total: int
    completed: int = 0
    success: int = 0
    failed: int = 0
    skipped: int = 0
    current_item: str = ""
    status: str = "pending"
    started_at: float = field(default_factory=time.time)
    finished_at: Optional[float] = None
    items: List[BatchImportItem] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)
    skip_no_metadata: bool = False
    skip_duplicates: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the operation state, including per-item detail, for
        the progress API and websocket broadcasts."""

        def serialize_item(entry):
            # Flatten one BatchImportItem, converting enums to their values.
            return {
                "id": entry.id,
                "source": entry.source,
                "item_type": entry.item_type.value,
                "status": entry.status.value,
                "error_message": entry.error_message,
                "recipe_name": entry.recipe_name,
                "recipe_id": entry.recipe_id,
                "duration": entry.duration,
            }

        # Guard the percentage against division by zero for empty batches.
        percent: Any = 0
        if self.total > 0:
            percent = round((self.completed / self.total) * 100, 1)

        payload: Dict[str, Any] = {
            "operation_id": self.operation_id,
            "total": self.total,
            "completed": self.completed,
            "success": self.success,
            "failed": self.failed,
            "skipped": self.skipped,
            "current_item": self.current_item,
            "status": self.status,
            "started_at": self.started_at,
            "finished_at": self.finished_at,
            "progress_percent": percent,
            "items": [serialize_item(entry) for entry in self.items],
        }
        return payload
|
|
|
|
class AdaptiveConcurrencyController:
    """Adjusts concurrency based on task performance.

    Fast successes (< 1s) raise the limit, slow successes (> 10s) and any
    failure lower it; the limit always stays within [min, max].
    """

    def __init__(
        self,
        min_concurrency: int = 1,
        max_concurrency: int = 5,
        initial_concurrency: int = 3,
    ) -> None:
        self.min_concurrency = min_concurrency
        self.max_concurrency = max_concurrency
        self.current_concurrency = initial_concurrency
        self._task_durations: List[float] = []
        self._recent_errors = 0
        self._recent_successes = 0

    def record_result(self, duration: float, success: bool) -> None:
        """Fold one task outcome into the window and retune the limit."""
        # Keep only the ten most recent durations.
        self._task_durations = (self._task_durations + [duration])[-10:]

        if not success:
            # Any failure backs off by one step (clamped at the minimum).
            self._recent_errors += 1
            self.current_concurrency = max(
                self.current_concurrency - 1, self.min_concurrency
            )
            return

        self._recent_successes += 1
        if duration < 1.0:
            # Quick win: allow one more concurrent task, up to the cap.
            self.current_concurrency = min(
                self.current_concurrency + 1, self.max_concurrency
            )
        elif duration > 10.0:
            # Sluggish task: throttle down, never below the floor.
            self.current_concurrency = max(
                self.current_concurrency - 1, self.min_concurrency
            )

    def reset_counters(self) -> None:
        """Zero the success/error tallies without touching the limit."""
        self._recent_errors = 0
        self._recent_successes = 0

    def get_semaphore(self) -> asyncio.Semaphore:
        """Return a new semaphore sized to the current concurrency limit."""
        return asyncio.Semaphore(self.current_concurrency)
|
|
|
|
class BatchImportService:
    """Service for batch importing images as recipes.

    Accepts a mixed list of remote URLs and local file paths (or a whole
    directory), analyzes each image through the analysis service, persists
    successful results through the persistence service, and broadcasts
    per-item progress over the websocket manager.  Items are processed
    concurrently behind a single shared semaphore.
    """

    # Extensions (lower-case) accepted when scanning a directory.
    SUPPORTED_EXTENSIONS: Set[str] = {".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp"}

    # Compiled once at class creation (the original recompiled this on every
    # call).  Accepts http(s) URLs with a dotted hostname, "localhost", or a
    # dotted-quad address, an optional port, and an optional path/query.
    _URL_PATTERN = re.compile(
        r"^https?://"
        r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|"
        r"localhost|"
        r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"
        r"(?::\d+)?"
        r"(?:/?|[/?]\S+)$",
        re.IGNORECASE,
    )

    def __init__(
        self,
        *,
        analysis_service: RecipeAnalysisService,
        persistence_service: RecipePersistenceService,
        ws_manager: Any,
        logger: logging.Logger,
    ) -> None:
        self._analysis_service = analysis_service
        self._persistence_service = persistence_service
        self._ws_manager = ws_manager
        self._logger = logger
        # operation_id -> progress; entries are removed by _cleanup_operation.
        self._active_operations: Dict[str, BatchImportProgress] = {}
        # operation_id -> True once cancellation has been requested.
        self._cancellation_flags: Dict[str, bool] = {}
        self._concurrency_controller = AdaptiveConcurrencyController()

    def is_import_running(self, operation_id: Optional[str] = None) -> bool:
        """Return True if the given operation (or, with no id, any known
        operation) is still pending or running."""
        if operation_id:
            progress = self._active_operations.get(operation_id)
            return progress is not None and progress.status in ("pending", "running")
        return any(
            p.status in ("pending", "running") for p in self._active_operations.values()
        )

    def get_progress(self, operation_id: str) -> Optional[BatchImportProgress]:
        """Return progress for an operation, or None once it is unknown or
        has been cleaned up."""
        return self._active_operations.get(operation_id)

    def cancel_import(self, operation_id: str) -> bool:
        """Request cancellation of an operation; returns False when the
        operation id is unknown.

        Cancellation is cooperative: in-flight items finish, items that
        have not yet started are left untouched.
        """
        if operation_id in self._active_operations:
            self._cancellation_flags[operation_id] = True
            return True
        return False

    def _validate_url(self, url: str) -> bool:
        """Return True when *url* looks like a well-formed http(s) URL."""
        return self._URL_PATTERN.match(url) is not None

    def _validate_local_path(self, path: str) -> bool:
        """Return True for an absolute path containing no '..' segments.

        The component-wise check avoids false rejection of filenames that
        merely contain two consecutive dots (e.g. "a..b.png") while still
        blocking parent-directory traversal.
        """
        try:
            normalized = os.path.normpath(path)
            if not os.path.isabs(normalized):
                return False
            if ".." in normalized.split(os.sep):
                return False
            return True
        except Exception:
            return False

    def _is_duplicate_source(
        self,
        source: str,
        item_type: ImportItemType,
        recipe_scanner: Any,
    ) -> bool:
        """Best-effort check whether *source* was already imported.

        Any failure (empty cache, scanner error) is treated as "not a
        duplicate" so a broken cache never blocks imports.
        """
        try:
            cache = recipe_scanner.get_cached_data_sync()
            if not cache:
                return False

            for recipe in getattr(cache, "raw_data", []):
                source_path = recipe.get("source_path") or recipe.get("source_url")
                if source_path and source_path == source:
                    return True
            return False
        except Exception:
            self._logger.warning("Failed to check for duplicates", exc_info=True)
            return False

    async def start_batch_import(
        self,
        *,
        recipe_scanner_getter: Callable[[], Any],
        civitai_client_getter: Callable[[], Any],
        items: List[Dict[str, str]],
        tags: Optional[List[str]] = None,
        skip_no_metadata: bool = False,
        skip_duplicates: bool = False,
    ) -> str:
        """Start a background batch import and return its operation id.

        Each entry in *items* is ``{"source": ..., "type": "url"|"local_path"}``;
        sources starting with http(s) are treated as URLs regardless of the
        declared type.  The actual work runs as a fire-and-forget task.
        """
        operation_id = str(uuid.uuid4())

        import_items = []
        for idx, raw in enumerate(items):
            source = raw.get("source", "")
            item_type_str = raw.get("type", "url")

            if item_type_str == "url" or source.startswith(("http://", "https://")):
                item_type = ImportItemType.URL
            else:
                item_type = ImportItemType.LOCAL_PATH

            import_items.append(
                BatchImportItem(
                    id=f"{operation_id}_{idx}",
                    source=source,
                    item_type=item_type,
                )
            )

        progress = BatchImportProgress(
            operation_id=operation_id,
            total=len(import_items),
            items=import_items,
            tags=tags or [],
            skip_no_metadata=skip_no_metadata,
            skip_duplicates=skip_duplicates,
        )

        self._active_operations[operation_id] = progress
        self._cancellation_flags[operation_id] = False

        asyncio.create_task(
            self._run_batch_import(
                operation_id=operation_id,
                recipe_scanner_getter=recipe_scanner_getter,
                civitai_client_getter=civitai_client_getter,
            )
        )

        return operation_id

    async def start_directory_import(
        self,
        *,
        recipe_scanner_getter: Callable[[], Any],
        civitai_client_getter: Callable[[], Any],
        directory: str,
        recursive: bool = True,
        tags: Optional[List[str]] = None,
        skip_no_metadata: bool = False,
        skip_duplicates: bool = False,
    ) -> str:
        """Discover supported images under *directory* and batch-import them.

        Raises RecipeValidationError (via _discover_images) when the
        directory does not exist.
        """
        image_paths = await self._discover_images(directory, recursive)

        items = [{"source": path, "type": "local_path"} for path in image_paths]

        return await self.start_batch_import(
            recipe_scanner_getter=recipe_scanner_getter,
            civitai_client_getter=civitai_client_getter,
            items=items,
            tags=tags,
            skip_no_metadata=skip_no_metadata,
            skip_duplicates=skip_duplicates,
        )

    async def _discover_images(
        self,
        directory: str,
        recursive: bool = True,
    ) -> List[str]:
        """Return sorted paths of supported images under *directory*.

        Raises RecipeValidationError when *directory* is not a directory.
        """
        if not os.path.isdir(directory):
            raise RecipeValidationError(f"Directory not found: {directory}")

        image_paths: List[str] = []

        if recursive:
            for root, _, files in os.walk(directory):
                for filename in files:
                    if self._is_supported_image(filename):
                        image_paths.append(os.path.join(root, filename))
        else:
            for filename in os.listdir(directory):
                filepath = os.path.join(directory, filename)
                if os.path.isfile(filepath) and self._is_supported_image(filename):
                    image_paths.append(filepath)

        return sorted(image_paths)

    def _is_supported_image(self, filename: str) -> bool:
        """Return True when *filename* has a supported image extension."""
        ext = os.path.splitext(filename)[1].lower()
        return ext in self.SUPPORTED_EXTENSIONS

    async def _run_batch_import(
        self,
        *,
        operation_id: str,
        recipe_scanner_getter: Callable[[], Any],
        civitai_client_getter: Callable[[], Any],
    ) -> None:
        """Drive one batch operation: fan out per-item tasks, update and
        broadcast progress, then clean up after a short grace period."""
        progress = self._active_operations.get(operation_id)
        if not progress:
            return

        progress.status = "running"
        await self._broadcast_progress(progress)

        self._concurrency_controller = AdaptiveConcurrencyController()
        # FIX: one semaphore shared by every task so the concurrency limit
        # is actually enforced.  The original created a fresh semaphore per
        # item, which never blocked anything.  The limit is fixed for this
        # run; the controller still records timings for tuning.
        semaphore = self._concurrency_controller.get_semaphore()

        async def process_item(item: BatchImportItem) -> None:
            # Items that have not started when cancellation arrives stay
            # PENDING and are never counted as completed.
            if self._cancellation_flags.get(operation_id, False):
                return

            progress.current_item = (
                os.path.basename(item.source)
                if item.item_type == ImportItemType.LOCAL_PATH
                else item.source[:50]
            )
            item.status = ImportStatus.PROCESSING
            await self._broadcast_progress(progress)

            start_time = time.time()
            try:
                result = await self._import_single_item(
                    item=item,
                    recipe_scanner_getter=recipe_scanner_getter,
                    civitai_client_getter=civitai_client_getter,
                    tags=progress.tags,
                    skip_no_metadata=progress.skip_no_metadata,
                    skip_duplicates=progress.skip_duplicates,
                    semaphore=semaphore,
                )

                duration = time.time() - start_time
                item.duration = duration
                self._concurrency_controller.record_result(
                    duration, result.get("success", False)
                )

                if result.get("success"):
                    item.status = ImportStatus.SUCCESS
                    item.recipe_name = result.get("recipe_name")
                    item.recipe_id = result.get("recipe_id")
                    progress.success += 1
                elif result.get("skipped"):
                    item.status = ImportStatus.SKIPPED
                    item.error_message = result.get("error")
                    progress.skipped += 1
                else:
                    item.status = ImportStatus.FAILED
                    item.error_message = result.get("error")
                    progress.failed += 1

            except Exception as e:
                self._logger.error(f"Error importing {item.source}: {e}")
                item.status = ImportStatus.FAILED
                item.error_message = str(e)
                item.duration = time.time() - start_time
                progress.failed += 1
                self._concurrency_controller.record_result(item.duration, False)

            progress.completed += 1
            await self._broadcast_progress(progress)

        await asyncio.gather(
            *(process_item(item) for item in progress.items),
            return_exceptions=True,
        )

        progress.status = (
            "cancelled"
            if self._cancellation_flags.get(operation_id, False)
            else "completed"
        )
        progress.finished_at = time.time()
        progress.current_item = ""
        await self._broadcast_progress(progress)

        # Keep the finished progress visible to HTTP pollers briefly, then
        # drop it so the registries do not grow without bound.
        await asyncio.sleep(5)
        self._cleanup_operation(operation_id)

    async def _import_single_item(
        self,
        *,
        item: BatchImportItem,
        recipe_scanner_getter: Callable[[], Any],
        civitai_client_getter: Callable[[], Any],
        tags: List[str],
        skip_no_metadata: bool,
        skip_duplicates: bool,
        semaphore: asyncio.Semaphore,
    ) -> Dict[str, Any]:
        """Validate, analyze, and persist a single item.

        Returns a result dict with "success", and optionally "skipped",
        "error", "recipe_name", and "recipe_id" keys.
        """
        async with semaphore:
            recipe_scanner = recipe_scanner_getter()
            if recipe_scanner is None:
                return {"success": False, "error": "Recipe scanner unavailable"}

            try:
                is_url = item.item_type == ImportItemType.URL

                # Validate the source before doing any expensive work.
                if is_url:
                    if not self._validate_url(item.source):
                        return {
                            "success": False,
                            "error": f"Invalid URL format: {item.source}",
                        }
                else:
                    if not self._validate_local_path(item.source):
                        return {
                            "success": False,
                            "error": f"Invalid or unsafe path: {item.source}",
                        }
                    if not os.path.exists(item.source):
                        return {
                            "success": False,
                            "error": f"File not found: {item.source}",
                        }

                if skip_duplicates and self._is_duplicate_source(
                    item.source, item.item_type, recipe_scanner
                ):
                    return {
                        "success": False,
                        "skipped": True,
                        "error": "Duplicate source URL"
                        if is_url
                        else "Duplicate source path",
                    }

                if is_url:
                    analysis_result = await self._analysis_service.analyze_remote_image(
                        url=item.source,
                        recipe_scanner=recipe_scanner,
                        civitai_client=civitai_client_getter(),
                    )
                else:
                    analysis_result = await self._analysis_service.analyze_local_image(
                        file_path=item.source,
                        recipe_scanner=recipe_scanner,
                    )

                payload = analysis_result.payload

                if payload.get("error"):
                    # "No metadata" analysis failures may be skipped on request.
                    if skip_no_metadata and "No metadata" in payload.get("error", ""):
                        return {
                            "success": False,
                            "skipped": True,
                            "error": payload["error"],
                        }
                    return {"success": False, "error": payload["error"]}

                loras = payload.get("loras", [])
                if not loras and skip_no_metadata:
                    return {
                        "success": False,
                        "skipped": True,
                        "error": "No LoRAs found in image",
                    }
                # Otherwise an image without LoRAs imports with an empty
                # loras list, matching single-import behavior.

                recipe_name = self._generate_recipe_name(item, payload)
                all_tags = list(set(tags + (payload.get("tags", []) or [])))

                metadata = {
                    "base_model": payload.get("base_model", ""),
                    "loras": loras,
                    "gen_params": payload.get("gen_params", {}),
                    "source_path": item.source,
                }
                if payload.get("checkpoint"):
                    metadata["checkpoint"] = payload["checkpoint"]

                # Local files are re-read from disk; remote images arrive
                # base64-encoded from the analysis step.
                image_bytes = None
                image_base64 = payload.get("image_base64")
                if not is_url:
                    with open(item.source, "rb") as f:
                        image_bytes = f.read()
                    image_base64 = None

                save_result = await self._persistence_service.save_recipe(
                    recipe_scanner=recipe_scanner,
                    image_bytes=image_bytes,
                    image_base64=image_base64,
                    name=recipe_name,
                    tags=all_tags,
                    metadata=metadata,
                    extension=payload.get("extension"),
                )

                if save_result.status == 200:
                    return {
                        "success": True,
                        "recipe_name": recipe_name,
                        "recipe_id": save_result.payload.get("id"),
                    }
                return {
                    "success": False,
                    "error": save_result.payload.get("error", "Failed to save recipe"),
                }

            except RecipeValidationError as e:
                return {"success": False, "error": str(e)}
            except RecipeDownloadError as e:
                return {"success": False, "error": str(e)}
            except RecipeNotFoundError as e:
                # A missing upstream resource is reported as skipped, not failed.
                return {"success": False, "skipped": True, "error": str(e)}
            except Exception as e:
                self._logger.error(
                    f"Unexpected error importing {item.source}: {e}", exc_info=True
                )
                return {"success": False, "error": str(e)}

    def _generate_recipe_name(
        self, item: BatchImportItem, payload: Dict[str, Any]
    ) -> str:
        """Derive a recipe name (<= 100 chars) from the filename for local
        items, or from the first LoRA name for URL items."""
        if item.item_type == ImportItemType.LOCAL_PATH:
            base_name = os.path.splitext(os.path.basename(item.source))[0]
            return base_name[:100]
        else:
            loras = payload.get("loras", [])
            if loras:
                first_lora = loras[0].get("name", "Recipe")
                return f"Import - {first_lora}"[:100]
            return f"Imported Recipe {item.id[:8]}"

    async def _broadcast_progress(self, progress: BatchImportProgress) -> None:
        """Push the full progress snapshot to websocket subscribers."""
        await self._ws_manager.broadcast(
            {
                "type": "batch_import_progress",
                **progress.to_dict(),
            }
        )

    def _cleanup_operation(self, operation_id: str) -> None:
        """Forget a finished operation.

        FIX: the original removed only the cancellation flag, leaving the
        progress entry in ``_active_operations`` forever (unbounded growth).
        Both registries are cleared here.
        """
        self._cancellation_flags.pop(operation_id, None)
        self._active_operations.pop(operation_id, None)