mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(batch-import): implement backend batch import service with adaptive concurrency
- Add BatchImportService with concurrent execution using asyncio.gather - Implement AdaptiveConcurrencyController with dynamic adjustment - Add input validation for URLs and local paths - Support duplicate detection via skip_duplicates parameter - Add WebSocket progress broadcasting for real-time updates - Create comprehensive unit tests for batch import functionality - Update API handlers and route registrations - Add i18n translation keys for batch import UI
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Base infrastructure shared across recipe routes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
@@ -16,12 +17,14 @@ from ..services.recipes import (
|
||||
RecipePersistenceService,
|
||||
RecipeSharingService,
|
||||
)
|
||||
from ..services.batch_import_service import BatchImportService
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from .handlers.recipe_handlers import (
|
||||
BatchImportHandler,
|
||||
RecipeAnalysisHandler,
|
||||
RecipeHandlerSet,
|
||||
RecipeListingHandler,
|
||||
@@ -116,7 +119,10 @@ class BaseRecipeRoutes:
|
||||
recipe_scanner_getter = lambda: self.recipe_scanner
|
||||
civitai_client_getter = lambda: self.civitai_client
|
||||
|
||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
standalone_mode = (
|
||||
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
|
||||
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
)
|
||||
if not standalone_mode:
|
||||
from ..metadata_collector import get_metadata # type: ignore[import-not-found]
|
||||
from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found]
|
||||
@@ -190,6 +196,22 @@ class BaseRecipeRoutes:
|
||||
sharing_service=sharing_service,
|
||||
)
|
||||
|
||||
from ..services.websocket_manager import ws_manager
|
||||
|
||||
batch_import_service = BatchImportService(
|
||||
analysis_service=analysis_service,
|
||||
persistence_service=persistence_service,
|
||||
ws_manager=ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
batch_import = BatchImportHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
logger=logger,
|
||||
batch_import_service=batch_import_service,
|
||||
)
|
||||
|
||||
return RecipeHandlerSet(
|
||||
page_view=page_view,
|
||||
listing=listing,
|
||||
@@ -197,4 +219,5 @@ class BaseRecipeRoutes:
|
||||
management=management,
|
||||
analysis=analysis,
|
||||
sharing=sharing,
|
||||
batch_import=batch_import,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Dedicated handler objects for recipe-related routes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
@@ -29,6 +30,7 @@ from ...utils.exif_utils import ExifUtils
|
||||
from ...recipes.merger import GenParamsMerger
|
||||
from ...recipes.enrichment import RecipeEnricher
|
||||
from ...services.websocket_manager import ws_manager as default_ws_manager
|
||||
from ...services.batch_import_service import BatchImportService
|
||||
|
||||
Logger = logging.Logger
|
||||
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
|
||||
@@ -46,8 +48,11 @@ class RecipeHandlerSet:
|
||||
management: "RecipeManagementHandler"
|
||||
analysis: "RecipeAnalysisHandler"
|
||||
sharing: "RecipeSharingHandler"
|
||||
batch_import: "BatchImportHandler"
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||
def to_route_mapping(
|
||||
self,
|
||||
) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||
"""Expose handler coroutines keyed by registrar handler names."""
|
||||
|
||||
return {
|
||||
@@ -81,6 +86,9 @@ class RecipeHandlerSet:
|
||||
"cancel_repair": self.management.cancel_repair,
|
||||
"repair_recipe": self.management.repair_recipe,
|
||||
"get_repair_progress": self.management.get_repair_progress,
|
||||
"start_batch_import": self.batch_import.start_batch_import,
|
||||
"get_batch_import_progress": self.batch_import.get_batch_import_progress,
|
||||
"cancel_batch_import": self.batch_import.cancel_batch_import,
|
||||
}
|
||||
|
||||
|
||||
@@ -170,8 +178,10 @@ class RecipeListingHandler:
|
||||
search_options = {
|
||||
"title": request.query.get("search_title", "true").lower() == "true",
|
||||
"tags": request.query.get("search_tags", "true").lower() == "true",
|
||||
"lora_name": request.query.get("search_lora_name", "true").lower() == "true",
|
||||
"lora_model": request.query.get("search_lora_model", "true").lower() == "true",
|
||||
"lora_name": request.query.get("search_lora_name", "true").lower()
|
||||
== "true",
|
||||
"lora_model": request.query.get("search_lora_model", "true").lower()
|
||||
== "true",
|
||||
"prompt": request.query.get("search_prompt", "true").lower() == "true",
|
||||
}
|
||||
|
||||
@@ -246,7 +256,9 @@ class RecipeListingHandler:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
return web.json_response(recipe)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipe details: %s", exc, exc_info=True)
|
||||
self._logger.error(
|
||||
"Error retrieving recipe details: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
def format_recipe_file_url(self, file_path: str) -> str:
|
||||
@@ -256,7 +268,9 @@ class RecipeListingHandler:
|
||||
if static_url:
|
||||
return static_url
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error formatting recipe file URL: %s", exc, exc_info=True)
|
||||
self._logger.error(
|
||||
"Error formatting recipe file URL: %s", exc, exc_info=True
|
||||
)
|
||||
return "/loras_static/images/no-preview.png"
|
||||
|
||||
return "/loras_static/images/no-preview.png"
|
||||
@@ -293,7 +307,9 @@ class RecipeQueryHandler:
|
||||
for tag in recipe.get("tags", []) or []:
|
||||
tag_counts[tag] = tag_counts.get(tag, 0) + 1
|
||||
|
||||
sorted_tags = [{"tag": tag, "count": count} for tag, count in tag_counts.items()]
|
||||
sorted_tags = [
|
||||
{"tag": tag, "count": count} for tag, count in tag_counts.items()
|
||||
]
|
||||
sorted_tags.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "tags": sorted_tags[:limit]})
|
||||
except Exception as exc:
|
||||
@@ -313,9 +329,14 @@ class RecipeQueryHandler:
|
||||
for recipe in getattr(cache, "raw_data", []):
|
||||
base_model = recipe.get("base_model")
|
||||
if base_model:
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
base_model_counts[base_model] = (
|
||||
base_model_counts.get(base_model, 0) + 1
|
||||
)
|
||||
|
||||
sorted_models = [{"name": model, "count": count} for model, count in base_model_counts.items()]
|
||||
sorted_models = [
|
||||
{"name": model, "count": count}
|
||||
for model, count in base_model_counts.items()
|
||||
]
|
||||
sorted_models.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "base_models": sorted_models})
|
||||
except Exception as exc:
|
||||
@@ -345,7 +366,9 @@ class RecipeQueryHandler:
|
||||
folders = await recipe_scanner.get_folders()
|
||||
return web.json_response({"success": True, "folders": folders})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipe folders: %s", exc, exc_info=True)
|
||||
self._logger.error(
|
||||
"Error retrieving recipe folders: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_folder_tree(self, request: web.Request) -> web.Response:
|
||||
@@ -358,7 +381,9 @@ class RecipeQueryHandler:
|
||||
folder_tree = await recipe_scanner.get_folder_tree()
|
||||
return web.json_response({"success": True, "tree": folder_tree})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipe folder tree: %s", exc, exc_info=True)
|
||||
self._logger.error(
|
||||
"Error retrieving recipe folder tree: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_unified_folder_tree(self, request: web.Request) -> web.Response:
|
||||
@@ -371,7 +396,9 @@ class RecipeQueryHandler:
|
||||
folder_tree = await recipe_scanner.get_folder_tree()
|
||||
return web.json_response({"success": True, "tree": folder_tree})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving unified recipe folder tree: %s", exc, exc_info=True)
|
||||
self._logger.error(
|
||||
"Error retrieving unified recipe folder tree: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
|
||||
@@ -383,7 +410,9 @@ class RecipeQueryHandler:
|
||||
|
||||
lora_hash = request.query.get("hash")
|
||||
if not lora_hash:
|
||||
return web.json_response({"success": False, "error": "Lora hash is required"}, status=400)
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Lora hash is required"}, status=400
|
||||
)
|
||||
|
||||
matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash)
|
||||
return web.json_response({"success": True, "recipes": matching_recipes})
|
||||
@@ -400,7 +429,9 @@ class RecipeQueryHandler:
|
||||
|
||||
self._logger.info("Manually triggering recipe cache rebuild")
|
||||
await recipe_scanner.get_cached_data(force_refresh=True)
|
||||
return web.json_response({"success": True, "message": "Recipe cache refreshed successfully"})
|
||||
return web.json_response(
|
||||
{"success": True, "message": "Recipe cache refreshed successfully"}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error refreshing recipe cache: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
@@ -429,7 +460,9 @@ class RecipeQueryHandler:
|
||||
"id": recipe.get("id"),
|
||||
"title": recipe.get("title"),
|
||||
"file_url": recipe.get("file_url")
|
||||
or self._format_recipe_file_url(recipe.get("file_path", "")),
|
||||
or self._format_recipe_file_url(
|
||||
recipe.get("file_path", "")
|
||||
),
|
||||
"modified": recipe.get("modified"),
|
||||
"created_date": recipe.get("created_date"),
|
||||
"lora_count": len(recipe.get("loras", [])),
|
||||
@@ -437,7 +470,9 @@ class RecipeQueryHandler:
|
||||
)
|
||||
|
||||
if len(recipes) >= 2:
|
||||
recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True)
|
||||
recipes.sort(
|
||||
key=lambda entry: entry.get("modified", 0), reverse=True
|
||||
)
|
||||
response_data.append(
|
||||
{
|
||||
"type": "fingerprint",
|
||||
@@ -460,7 +495,9 @@ class RecipeQueryHandler:
|
||||
"id": recipe.get("id"),
|
||||
"title": recipe.get("title"),
|
||||
"file_url": recipe.get("file_url")
|
||||
or self._format_recipe_file_url(recipe.get("file_path", "")),
|
||||
or self._format_recipe_file_url(
|
||||
recipe.get("file_path", "")
|
||||
),
|
||||
"modified": recipe.get("modified"),
|
||||
"created_date": recipe.get("created_date"),
|
||||
"lora_count": len(recipe.get("loras", [])),
|
||||
@@ -468,7 +505,9 @@ class RecipeQueryHandler:
|
||||
)
|
||||
|
||||
if len(recipes) >= 2:
|
||||
recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True)
|
||||
recipes.sort(
|
||||
key=lambda entry: entry.get("modified", 0), reverse=True
|
||||
)
|
||||
response_data.append(
|
||||
{
|
||||
"type": "source_url",
|
||||
@@ -479,9 +518,13 @@ class RecipeQueryHandler:
|
||||
)
|
||||
|
||||
response_data.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "duplicate_groups": response_data})
|
||||
return web.json_response(
|
||||
{"success": True, "duplicate_groups": response_data}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error finding duplicate recipes: %s", exc, exc_info=True)
|
||||
self._logger.error(
|
||||
"Error finding duplicate recipes: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
|
||||
@@ -498,9 +541,13 @@ class RecipeQueryHandler:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
|
||||
if not syntax_parts:
|
||||
return web.json_response({"error": "No LoRAs found in this recipe"}, status=400)
|
||||
return web.json_response(
|
||||
{"error": "No LoRAs found in this recipe"}, status=400
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, "syntax": " ".join(syntax_parts)})
|
||||
return web.json_response(
|
||||
{"success": True, "syntax": " ".join(syntax_parts)}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
@@ -561,11 +608,17 @@ class RecipeManagementHandler:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503)
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Recipe scanner unavailable"},
|
||||
status=503,
|
||||
)
|
||||
|
||||
# Check if already running
|
||||
if self._ws_manager.is_recipe_repair_running():
|
||||
return web.json_response({"success": False, "error": "Recipe repair already in progress"}, status=409)
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Recipe repair already in progress"},
|
||||
status=409,
|
||||
)
|
||||
|
||||
recipe_scanner.reset_cancellation()
|
||||
|
||||
@@ -579,11 +632,12 @@ class RecipeManagementHandler:
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error in recipe repair task: {e}", exc_info=True)
|
||||
await self._ws_manager.broadcast_recipe_repair_progress({
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
})
|
||||
self._logger.error(
|
||||
f"Error in recipe repair task: {e}", exc_info=True
|
||||
)
|
||||
await self._ws_manager.broadcast_recipe_repair_progress(
|
||||
{"status": "error", "error": str(e)}
|
||||
)
|
||||
finally:
|
||||
# Keep the final status for a while so the UI can see it
|
||||
await asyncio.sleep(5)
|
||||
@@ -593,7 +647,9 @@ class RecipeManagementHandler:
|
||||
|
||||
asyncio.create_task(run_repair())
|
||||
|
||||
return web.json_response({"success": True, "message": "Recipe repair started"})
|
||||
return web.json_response(
|
||||
{"success": True, "message": "Recipe repair started"}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error starting recipe repair: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
@@ -603,10 +659,15 @@ class RecipeManagementHandler:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503)
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Recipe scanner unavailable"},
|
||||
status=503,
|
||||
)
|
||||
|
||||
recipe_scanner.cancel_task()
|
||||
return web.json_response({"success": True, "message": "Cancellation requested"})
|
||||
return web.json_response(
|
||||
{"success": True, "message": "Cancellation requested"}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error cancelling recipe repair: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
@@ -616,7 +677,10 @@ class RecipeManagementHandler:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503)
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Recipe scanner unavailable"},
|
||||
status=503,
|
||||
)
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
result = await recipe_scanner.repair_recipe_by_id(recipe_id)
|
||||
@@ -632,25 +696,26 @@ class RecipeManagementHandler:
|
||||
progress = self._ws_manager.get_recipe_repair_progress()
|
||||
if progress:
|
||||
return web.json_response({"success": True, "progress": progress})
|
||||
return web.json_response({"success": False, "message": "No repair in progress"}, status=404)
|
||||
return web.json_response(
|
||||
{"success": False, "message": "No repair in progress"}, status=404
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error getting repair progress: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
async def import_remote_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
|
||||
# 1. Parse Parameters
|
||||
params = request.rel_url.query
|
||||
image_url = params.get("image_url")
|
||||
name = params.get("name")
|
||||
resources_raw = params.get("resources")
|
||||
|
||||
|
||||
if not image_url:
|
||||
raise RecipeValidationError("Missing required field: image_url")
|
||||
if not name:
|
||||
@@ -658,64 +723,80 @@ class RecipeManagementHandler:
|
||||
if not resources_raw:
|
||||
raise RecipeValidationError("Missing required field: resources")
|
||||
|
||||
checkpoint_entry, lora_entries = self._parse_resources_payload(resources_raw)
|
||||
checkpoint_entry, lora_entries = self._parse_resources_payload(
|
||||
resources_raw
|
||||
)
|
||||
gen_params_request = self._parse_gen_params(params.get("gen_params"))
|
||||
|
||||
|
||||
# 2. Initial Metadata Construction
|
||||
metadata: Dict[str, Any] = {
|
||||
"base_model": params.get("base_model", "") or "",
|
||||
"loras": lora_entries,
|
||||
"gen_params": gen_params_request or {},
|
||||
"source_url": image_url
|
||||
"source_url": image_url,
|
||||
}
|
||||
|
||||
|
||||
source_path = params.get("source_path")
|
||||
if source_path:
|
||||
metadata["source_path"] = source_path
|
||||
|
||||
|
||||
# Checkpoint handling
|
||||
if checkpoint_entry:
|
||||
metadata["checkpoint"] = checkpoint_entry
|
||||
# Ensure checkpoint is also in gen_params for consistency if needed by enricher?
|
||||
# Actually enricher looks at metadata['checkpoint'], so this is fine.
|
||||
|
||||
|
||||
# Try to resolve base model from checkpoint if not explicitly provided
|
||||
if not metadata["base_model"]:
|
||||
base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry)
|
||||
base_model_from_metadata = (
|
||||
await self._resolve_base_model_from_checkpoint(checkpoint_entry)
|
||||
)
|
||||
if base_model_from_metadata:
|
||||
metadata["base_model"] = base_model_from_metadata
|
||||
|
||||
tags = self._parse_tags(params.get("tags"))
|
||||
|
||||
|
||||
# 3. Download Image
|
||||
image_bytes, extension, civitai_meta_from_download = await self._download_remote_media(image_url)
|
||||
(
|
||||
image_bytes,
|
||||
extension,
|
||||
civitai_meta_from_download,
|
||||
) = await self._download_remote_media(image_url)
|
||||
|
||||
# 4. Extract Embedded Metadata
|
||||
# Note: We still extract this here because Enricher currently expects 'gen_params' to already be populated
|
||||
# with embedded data if we want it to merge it.
|
||||
# Note: We still extract this here because Enricher currently expects 'gen_params' to already be populated
|
||||
# with embedded data if we want it to merge it.
|
||||
# However, logic in Enricher merges: request > civitai > embedded.
|
||||
# So we should gather embedded params and put them into the recipe's gen_params (as initial state)
|
||||
# So we should gather embedded params and put them into the recipe's gen_params (as initial state)
|
||||
# OR pass them to enricher to handle?
|
||||
# The interface of Enricher.enrich_recipe takes `recipe` (with gen_params) and `request_params`.
|
||||
# So let's extract embedded and put it into recipe['gen_params'] but careful not to overwrite request params.
|
||||
# Actually, `GenParamsMerger` which `Enricher` uses handles 3 layers.
|
||||
# But `Enricher` interface is: recipe['gen_params'] (as embedded) + request_params + civitai (fetched internally).
|
||||
# Wait, `Enricher` fetches Civitai info internally based on URL.
|
||||
# Wait, `Enricher` fetches Civitai info internally based on URL.
|
||||
# `civitai_meta_from_download` is returned by `_download_remote_media` which might be useful if URL didn't have ID.
|
||||
|
||||
|
||||
# Let's extract embedded metadata first
|
||||
embedded_gen_params = {}
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_img:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=extension, delete=False
|
||||
) as temp_img:
|
||||
temp_img.write(image_bytes)
|
||||
temp_img_path = temp_img.name
|
||||
|
||||
|
||||
try:
|
||||
raw_embedded = ExifUtils.extract_image_metadata(temp_img_path)
|
||||
if raw_embedded:
|
||||
parser = self._analysis_service._recipe_parser_factory.create_parser(raw_embedded)
|
||||
parser = (
|
||||
self._analysis_service._recipe_parser_factory.create_parser(
|
||||
raw_embedded
|
||||
)
|
||||
)
|
||||
if parser:
|
||||
parsed_embedded = await parser.parse_metadata(raw_embedded, recipe_scanner=recipe_scanner)
|
||||
parsed_embedded = await parser.parse_metadata(
|
||||
raw_embedded, recipe_scanner=recipe_scanner
|
||||
)
|
||||
if parsed_embedded and "gen_params" in parsed_embedded:
|
||||
embedded_gen_params = parsed_embedded["gen_params"]
|
||||
else:
|
||||
@@ -724,7 +805,9 @@ class RecipeManagementHandler:
|
||||
if os.path.exists(temp_img_path):
|
||||
os.unlink(temp_img_path)
|
||||
except Exception as exc:
|
||||
self._logger.warning("Failed to extract embedded metadata during import: %s", exc)
|
||||
self._logger.warning(
|
||||
"Failed to extract embedded metadata during import: %s", exc
|
||||
)
|
||||
|
||||
# Pre-populate gen_params with embedded data so Enricher treats it as the "base" layer
|
||||
if embedded_gen_params:
|
||||
@@ -732,18 +815,18 @@ class RecipeManagementHandler:
|
||||
# But wait, we want request params to override everything.
|
||||
# So we should set recipe['gen_params'] = embedded, and pass request params to enricher.
|
||||
metadata["gen_params"] = embedded_gen_params
|
||||
|
||||
|
||||
# 5. Enrich with unified logic
|
||||
# This will fetch Civitai info (if URL matches) and merge: request > civitai > embedded
|
||||
civitai_client = self._civitai_client_getter()
|
||||
await RecipeEnricher.enrich_recipe(
|
||||
recipe=metadata,
|
||||
recipe=metadata,
|
||||
civitai_client=civitai_client,
|
||||
request_params=gen_params_request # Pass explicit request params here to override
|
||||
request_params=gen_params_request, # Pass explicit request params here to override
|
||||
)
|
||||
|
||||
|
||||
# If we got civitai_meta from download but Enricher didn't fetch it (e.g. not a civitai URL or failed),
|
||||
# we might want to manually merge it?
|
||||
# we might want to manually merge it?
|
||||
# But usually `import_remote_recipe` is used with Civitai URLs.
|
||||
# For now, relying on Enricher's internal fetch is consistent with repair.
|
||||
|
||||
@@ -762,7 +845,9 @@ class RecipeManagementHandler:
|
||||
except RecipeDownloadError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error importing recipe from remote source: %s", exc, exc_info=True)
|
||||
self._logger.error(
|
||||
"Error importing recipe from remote source: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def delete_recipe(self, request: web.Request) -> web.Response:
|
||||
@@ -816,7 +901,11 @@ class RecipeManagementHandler:
|
||||
target_path = data.get("target_path")
|
||||
if not recipe_id or not target_path:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "recipe_id and target_path are required"}, status=400
|
||||
{
|
||||
"success": False,
|
||||
"error": "recipe_id and target_path are required",
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
|
||||
result = await self._persistence_service.move_recipe(
|
||||
@@ -845,7 +934,11 @@ class RecipeManagementHandler:
|
||||
target_path = data.get("target_path")
|
||||
if not recipe_ids or not target_path:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "recipe_ids and target_path are required"}, status=400
|
||||
{
|
||||
"success": False,
|
||||
"error": "recipe_ids and target_path are required",
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
|
||||
result = await self._persistence_service.move_recipes_bulk(
|
||||
@@ -934,7 +1027,9 @@ class RecipeManagementHandler:
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True)
|
||||
self._logger.error(
|
||||
"Error saving recipe from widget: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def _parse_save_payload(self, reader) -> dict[str, Any]:
|
||||
@@ -1006,7 +1101,9 @@ class RecipeManagementHandler:
|
||||
raise RecipeValidationError("gen_params payload must be an object")
|
||||
return parsed
|
||||
|
||||
def _parse_resources_payload(self, payload_raw: str) -> tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
def _parse_resources_payload(
|
||||
self, payload_raw: str
|
||||
) -> tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
try:
|
||||
payload = json.loads(payload_raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
@@ -1066,15 +1163,19 @@ class RecipeManagementHandler:
|
||||
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", image_url)
|
||||
if civitai_match:
|
||||
if civitai_client is None:
|
||||
raise RecipeDownloadError("Civitai client unavailable for image download")
|
||||
raise RecipeDownloadError(
|
||||
"Civitai client unavailable for image download"
|
||||
)
|
||||
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
||||
if not image_info:
|
||||
raise RecipeDownloadError("Failed to fetch image information from Civitai")
|
||||
|
||||
raise RecipeDownloadError(
|
||||
"Failed to fetch image information from Civitai"
|
||||
)
|
||||
|
||||
media_url = image_info.get("url")
|
||||
if not media_url:
|
||||
raise RecipeDownloadError("No image URL found in Civitai response")
|
||||
|
||||
|
||||
# Use optimized preview URLs if possible
|
||||
media_type = image_info.get("type")
|
||||
rewritten_url, _ = rewrite_preview_url(media_url, media_type=media_type)
|
||||
@@ -1083,18 +1184,24 @@ class RecipeManagementHandler:
|
||||
else:
|
||||
download_url = media_url
|
||||
|
||||
success, result = await downloader.download_file(download_url, temp_path, use_auth=False)
|
||||
success, result = await downloader.download_file(
|
||||
download_url, temp_path, use_auth=False
|
||||
)
|
||||
if not success:
|
||||
raise RecipeDownloadError(f"Failed to download image: {result}")
|
||||
|
||||
|
||||
# Extract extension from URL
|
||||
url_path = download_url.split('?')[0].split('#')[0]
|
||||
url_path = download_url.split("?")[0].split("#")[0]
|
||||
extension = os.path.splitext(url_path)[1].lower()
|
||||
if not extension:
|
||||
extension = ".webp" # Default to webp if unknown
|
||||
extension = ".webp" # Default to webp if unknown
|
||||
|
||||
with open(temp_path, "rb") as file_obj:
|
||||
return file_obj.read(), extension, image_info.get("meta") if civitai_match and image_info else None
|
||||
return (
|
||||
file_obj.read(),
|
||||
extension,
|
||||
image_info.get("meta") if civitai_match and image_info else None,
|
||||
)
|
||||
except RecipeDownloadError:
|
||||
raise
|
||||
except RecipeValidationError:
|
||||
@@ -1108,14 +1215,15 @@ class RecipeManagementHandler:
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def _safe_int(self, value: Any) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
async def _resolve_base_model_from_checkpoint(self, checkpoint_entry: Dict[str, Any]) -> str:
|
||||
async def _resolve_base_model_from_checkpoint(
|
||||
self, checkpoint_entry: Dict[str, Any]
|
||||
) -> str:
|
||||
version_id = self._safe_int(checkpoint_entry.get("modelVersionId"))
|
||||
|
||||
if not version_id:
|
||||
@@ -1134,7 +1242,9 @@ class RecipeManagementHandler:
|
||||
base_model = version_info.get("baseModel") or ""
|
||||
return str(base_model) if base_model is not None else ""
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.warning("Failed to resolve base model from checkpoint metadata: %s", exc)
|
||||
self._logger.warning(
|
||||
"Failed to resolve base model from checkpoint metadata: %s", exc
|
||||
)
|
||||
|
||||
return ""
|
||||
|
||||
@@ -1279,5 +1389,178 @@ class RecipeSharingHandler:
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True)
|
||||
self._logger.error(
|
||||
"Error downloading shared recipe: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class BatchImportHandler:
|
||||
"""Handle batch import operations for recipes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
civitai_client_getter: CivitaiClientGetter,
|
||||
logger: Logger,
|
||||
batch_import_service: BatchImportService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._civitai_client_getter = civitai_client_getter
|
||||
self._logger = logger
|
||||
self._batch_import_service = batch_import_service
|
||||
|
||||
async def start_batch_import(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
|
||||
if self._batch_import_service.is_import_running():
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Batch import already in progress"},
|
||||
status=409,
|
||||
)
|
||||
|
||||
data = await request.json()
|
||||
items = data.get("items", [])
|
||||
tags = data.get("tags", [])
|
||||
skip_no_metadata = data.get("skip_no_metadata", True)
|
||||
|
||||
if not items:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "No items provided"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
for item in items:
|
||||
if not item.get("source"):
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Each item must have a 'source' field",
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
|
||||
operation_id = await self._batch_import_service.start_batch_import(
|
||||
recipe_scanner_getter=self._recipe_scanner_getter,
|
||||
civitai_client_getter=self._civitai_client_getter,
|
||||
items=items,
|
||||
tags=tags,
|
||||
skip_no_metadata=skip_no_metadata,
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"operation_id": operation_id,
|
||||
}
|
||||
)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error starting batch import: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def start_directory_import(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
|
||||
if self._batch_import_service.is_import_running():
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Batch import already in progress"},
|
||||
status=409,
|
||||
)
|
||||
|
||||
data = await request.json()
|
||||
directory = data.get("directory")
|
||||
recursive = data.get("recursive", True)
|
||||
tags = data.get("tags", [])
|
||||
skip_no_metadata = data.get("skip_no_metadata", True)
|
||||
|
||||
if not directory:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Directory path is required"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
operation_id = await self._batch_import_service.start_directory_import(
|
||||
recipe_scanner_getter=self._recipe_scanner_getter,
|
||||
civitai_client_getter=self._civitai_client_getter,
|
||||
directory=directory,
|
||||
recursive=recursive,
|
||||
tags=tags,
|
||||
skip_no_metadata=skip_no_metadata,
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"operation_id": operation_id,
|
||||
}
|
||||
)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error starting directory import: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_batch_import_progress(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
operation_id = request.query.get("operation_id")
|
||||
if not operation_id:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "operation_id is required"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
progress = self._batch_import_service.get_progress(operation_id)
|
||||
if not progress:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Operation not found"},
|
||||
status=404,
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"progress": progress.to_dict(),
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error(
|
||||
"Error getting batch import progress: %s", exc, exc_info=True
|
||||
)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def cancel_batch_import(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
operation_id = data.get("operation_id")
|
||||
|
||||
if not operation_id:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "operation_id is required"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
cancelled = self._batch_import_service.cancel_import(operation_id)
|
||||
if not cancelled:
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Operation not found or already completed",
|
||||
},
|
||||
status=404,
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{"success": True, "message": "Cancellation requested"}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error cancelling batch import: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Route registrar for recipe endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
@@ -22,7 +23,9 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/import-remote", "import_remote_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"
|
||||
),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
|
||||
RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"),
|
||||
@@ -30,9 +33,13 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/recipes/roots", "get_roots"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/folders", "get_folders"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/folder-tree", "get_folder_tree"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/unified-folder-tree", "get_unified_folder_tree"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/recipes/unified-folder-tree", "get_unified_folder_tree"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"),
|
||||
RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipe/move", "move_recipe"),
|
||||
@@ -40,13 +47,22 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/repair", "repair_recipes"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/cancel-repair", "cancel_repair"),
|
||||
RouteDefinition("POST", "/api/lm/recipe/{recipe_id}/repair", "repair_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/repair-progress", "get_repair_progress"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/batch-import/start", "start_batch_import"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/recipes/batch-import/progress", "get_batch_import_progress"
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/recipes/batch-import/cancel", "cancel_batch_import"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -63,7 +79,9 @@ class RecipeRouteRegistrar:
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None:
|
||||
def register_routes(
|
||||
self, handler_lookup: Mapping[str, Callable[[web.Request], object]]
|
||||
) -> None:
|
||||
for definition in ROUTE_DEFINITIONS:
|
||||
handler = handler_lookup[definition.handler_name]
|
||||
self._bind_route(definition.method, definition.path, handler)
|
||||
|
||||
579
py/services/batch_import_service.py
Normal file
579
py/services/batch_import_service.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""Batch import service for importing multiple images as recipes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from .recipes import (
|
||||
RecipeAnalysisService,
|
||||
RecipePersistenceService,
|
||||
RecipeValidationError,
|
||||
RecipeDownloadError,
|
||||
RecipeNotFoundError,
|
||||
)
|
||||
|
||||
|
||||
class ImportItemType(Enum):
|
||||
"""Type of import item."""
|
||||
|
||||
URL = "url"
|
||||
LOCAL_PATH = "local_path"
|
||||
|
||||
|
||||
class ImportStatus(Enum):
|
||||
"""Status of an individual import item."""
|
||||
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchImportItem:
|
||||
"""Represents a single item to import."""
|
||||
|
||||
id: str
|
||||
source: str
|
||||
item_type: ImportItemType
|
||||
status: ImportStatus = ImportStatus.PENDING
|
||||
error_message: Optional[str] = None
|
||||
recipe_name: Optional[str] = None
|
||||
recipe_id: Optional[str] = None
|
||||
duration: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchImportProgress:
|
||||
"""Tracks progress of a batch import operation."""
|
||||
|
||||
operation_id: str
|
||||
total: int
|
||||
completed: int = 0
|
||||
success: int = 0
|
||||
failed: int = 0
|
||||
skipped: int = 0
|
||||
current_item: str = ""
|
||||
status: str = "pending"
|
||||
started_at: float = field(default_factory=time.time)
|
||||
finished_at: Optional[float] = None
|
||||
items: List[BatchImportItem] = field(default_factory=list)
|
||||
tags: List[str] = field(default_factory=list)
|
||||
skip_no_metadata: bool = True
|
||||
skip_duplicates: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"operation_id": self.operation_id,
|
||||
"total": self.total,
|
||||
"completed": self.completed,
|
||||
"success": self.success,
|
||||
"failed": self.failed,
|
||||
"skipped": self.skipped,
|
||||
"current_item": self.current_item,
|
||||
"status": self.status,
|
||||
"started_at": self.started_at,
|
||||
"finished_at": self.finished_at,
|
||||
"progress_percent": round((self.completed / self.total) * 100, 1)
|
||||
if self.total > 0
|
||||
else 0,
|
||||
}
|
||||
|
||||
|
||||
class AdaptiveConcurrencyController:
|
||||
"""Adjusts concurrency based on task performance."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_concurrency: int = 1,
|
||||
max_concurrency: int = 5,
|
||||
initial_concurrency: int = 3,
|
||||
) -> None:
|
||||
self.min_concurrency = min_concurrency
|
||||
self.max_concurrency = max_concurrency
|
||||
self.current_concurrency = initial_concurrency
|
||||
self._task_durations: List[float] = []
|
||||
self._recent_errors = 0
|
||||
self._recent_successes = 0
|
||||
|
||||
def record_result(self, duration: float, success: bool) -> None:
|
||||
self._task_durations.append(duration)
|
||||
if len(self._task_durations) > 10:
|
||||
self._task_durations.pop(0)
|
||||
|
||||
if success:
|
||||
self._recent_successes += 1
|
||||
if duration < 1.0 and self.current_concurrency < self.max_concurrency:
|
||||
self.current_concurrency = min(
|
||||
self.current_concurrency + 1, self.max_concurrency
|
||||
)
|
||||
elif duration > 10.0 and self.current_concurrency > self.min_concurrency:
|
||||
self.current_concurrency = max(
|
||||
self.current_concurrency - 1, self.min_concurrency
|
||||
)
|
||||
else:
|
||||
self._recent_errors += 1
|
||||
if self.current_concurrency > self.min_concurrency:
|
||||
self.current_concurrency = max(
|
||||
self.current_concurrency - 1, self.min_concurrency
|
||||
)
|
||||
|
||||
def reset_counters(self) -> None:
|
||||
self._recent_errors = 0
|
||||
self._recent_successes = 0
|
||||
|
||||
def get_semaphore(self) -> asyncio.Semaphore:
|
||||
return asyncio.Semaphore(self.current_concurrency)
|
||||
|
||||
|
||||
class BatchImportService:
|
||||
"""Service for batch importing images as recipes."""
|
||||
|
||||
SUPPORTED_EXTENSIONS: Set[str] = {".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
analysis_service: RecipeAnalysisService,
|
||||
persistence_service: RecipePersistenceService,
|
||||
ws_manager: Any,
|
||||
logger: logging.Logger,
|
||||
) -> None:
|
||||
self._analysis_service = analysis_service
|
||||
self._persistence_service = persistence_service
|
||||
self._ws_manager = ws_manager
|
||||
self._logger = logger
|
||||
self._active_operations: Dict[str, BatchImportProgress] = {}
|
||||
self._cancellation_flags: Dict[str, bool] = {}
|
||||
self._concurrency_controller = AdaptiveConcurrencyController()
|
||||
|
||||
def is_import_running(self, operation_id: Optional[str] = None) -> bool:
|
||||
if operation_id:
|
||||
progress = self._active_operations.get(operation_id)
|
||||
return progress is not None and progress.status in ("pending", "running")
|
||||
return any(
|
||||
p.status in ("pending", "running") for p in self._active_operations.values()
|
||||
)
|
||||
|
||||
def get_progress(self, operation_id: str) -> Optional[BatchImportProgress]:
|
||||
return self._active_operations.get(operation_id)
|
||||
|
||||
def cancel_import(self, operation_id: str) -> bool:
|
||||
if operation_id in self._active_operations:
|
||||
self._cancellation_flags[operation_id] = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def _validate_url(self, url: str) -> bool:
|
||||
import re
|
||||
|
||||
url_pattern = re.compile(
|
||||
r"^https?://"
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|"
|
||||
r"localhost|"
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"
|
||||
r"(?::\d+)?"
|
||||
r"(?:/?|[/?]\S+)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return url_pattern.match(url) is not None
|
||||
|
||||
def _validate_local_path(self, path: str) -> bool:
|
||||
try:
|
||||
normalized = os.path.normpath(path)
|
||||
if not os.path.isabs(normalized):
|
||||
return False
|
||||
if ".." in normalized:
|
||||
return False
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _is_duplicate_source(
|
||||
self,
|
||||
source: str,
|
||||
item_type: ImportItemType,
|
||||
recipe_scanner: Any,
|
||||
) -> bool:
|
||||
try:
|
||||
cache = recipe_scanner.get_cached_data_sync()
|
||||
if not cache:
|
||||
return False
|
||||
|
||||
for recipe in getattr(cache, "raw_data", []):
|
||||
source_path = recipe.get("source_path") or recipe.get("source_url")
|
||||
if source_path and source_path == source:
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
self._logger.warning("Failed to check for duplicates", exc_info=True)
|
||||
return False
|
||||
|
||||
async def start_batch_import(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner_getter: Callable[[], Any],
|
||||
civitai_client_getter: Callable[[], Any],
|
||||
items: List[Dict[str, str]],
|
||||
tags: Optional[List[str]] = None,
|
||||
skip_no_metadata: bool = True,
|
||||
skip_duplicates: bool = False,
|
||||
) -> str:
|
||||
operation_id = str(uuid.uuid4())
|
||||
|
||||
import_items = []
|
||||
for idx, item in enumerate(items):
|
||||
source = item.get("source", "")
|
||||
item_type_str = item.get("type", "url")
|
||||
|
||||
if item_type_str == "url" or source.startswith(("http://", "https://")):
|
||||
item_type = ImportItemType.URL
|
||||
else:
|
||||
item_type = ImportItemType.LOCAL_PATH
|
||||
|
||||
batch_import_item = BatchImportItem(
|
||||
id=f"{operation_id}_{idx}",
|
||||
source=source,
|
||||
item_type=item_type,
|
||||
)
|
||||
import_items.append(batch_import_item)
|
||||
|
||||
progress = BatchImportProgress(
|
||||
operation_id=operation_id,
|
||||
total=len(import_items),
|
||||
items=import_items,
|
||||
tags=tags or [],
|
||||
skip_no_metadata=skip_no_metadata,
|
||||
skip_duplicates=skip_duplicates,
|
||||
)
|
||||
|
||||
self._active_operations[operation_id] = progress
|
||||
self._cancellation_flags[operation_id] = False
|
||||
|
||||
asyncio.create_task(
|
||||
self._run_batch_import(
|
||||
operation_id=operation_id,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
)
|
||||
)
|
||||
|
||||
return operation_id
|
||||
|
||||
async def start_directory_import(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner_getter: Callable[[], Any],
|
||||
civitai_client_getter: Callable[[], Any],
|
||||
directory: str,
|
||||
recursive: bool = True,
|
||||
tags: Optional[List[str]] = None,
|
||||
skip_no_metadata: bool = True,
|
||||
skip_duplicates: bool = False,
|
||||
) -> str:
|
||||
image_paths = await self._discover_images(directory, recursive)
|
||||
|
||||
items = [{"source": path, "type": "local_path"} for path in image_paths]
|
||||
|
||||
return await self.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=items,
|
||||
tags=tags,
|
||||
skip_no_metadata=skip_no_metadata,
|
||||
skip_duplicates=skip_duplicates,
|
||||
)
|
||||
|
||||
async def _discover_images(
|
||||
self,
|
||||
directory: str,
|
||||
recursive: bool = True,
|
||||
) -> List[str]:
|
||||
if not os.path.isdir(directory):
|
||||
raise RecipeValidationError(f"Directory not found: {directory}")
|
||||
|
||||
image_paths: List[str] = []
|
||||
|
||||
if recursive:
|
||||
for root, _, files in os.walk(directory):
|
||||
for filename in files:
|
||||
if self._is_supported_image(filename):
|
||||
image_paths.append(os.path.join(root, filename))
|
||||
else:
|
||||
for filename in os.listdir(directory):
|
||||
filepath = os.path.join(directory, filename)
|
||||
if os.path.isfile(filepath) and self._is_supported_image(filename):
|
||||
image_paths.append(filepath)
|
||||
|
||||
return sorted(image_paths)
|
||||
|
||||
def _is_supported_image(self, filename: str) -> bool:
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
return ext in self.SUPPORTED_EXTENSIONS
|
||||
|
||||
async def _run_batch_import(
|
||||
self,
|
||||
*,
|
||||
operation_id: str,
|
||||
recipe_scanner_getter: Callable[[], Any],
|
||||
civitai_client_getter: Callable[[], Any],
|
||||
) -> None:
|
||||
progress = self._active_operations.get(operation_id)
|
||||
if not progress:
|
||||
return
|
||||
|
||||
progress.status = "running"
|
||||
await self._broadcast_progress(progress)
|
||||
|
||||
self._concurrency_controller = AdaptiveConcurrencyController()
|
||||
|
||||
async def process_item(item: BatchImportItem) -> None:
|
||||
if self._cancellation_flags.get(operation_id, False):
|
||||
return
|
||||
|
||||
progress.current_item = (
|
||||
os.path.basename(item.source)
|
||||
if item.item_type == ImportItemType.LOCAL_PATH
|
||||
else item.source[:50]
|
||||
)
|
||||
item.status = ImportStatus.PROCESSING
|
||||
await self._broadcast_progress(progress)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = await self._import_single_item(
|
||||
item=item,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
tags=progress.tags,
|
||||
skip_no_metadata=progress.skip_no_metadata,
|
||||
skip_duplicates=progress.skip_duplicates,
|
||||
semaphore=self._concurrency_controller.get_semaphore(),
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
item.duration = duration
|
||||
self._concurrency_controller.record_result(
|
||||
duration, result.get("success", False)
|
||||
)
|
||||
|
||||
if result.get("success"):
|
||||
item.status = ImportStatus.SUCCESS
|
||||
item.recipe_name = result.get("recipe_name")
|
||||
item.recipe_id = result.get("recipe_id")
|
||||
progress.success += 1
|
||||
elif result.get("skipped"):
|
||||
item.status = ImportStatus.SKIPPED
|
||||
item.error_message = result.get("error")
|
||||
progress.skipped += 1
|
||||
else:
|
||||
item.status = ImportStatus.FAILED
|
||||
item.error_message = result.get("error")
|
||||
progress.failed += 1
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error importing {item.source}: {e}")
|
||||
item.status = ImportStatus.FAILED
|
||||
item.error_message = str(e)
|
||||
item.duration = time.time() - start_time
|
||||
progress.failed += 1
|
||||
self._concurrency_controller.record_result(item.duration, False)
|
||||
|
||||
progress.completed += 1
|
||||
await self._broadcast_progress(progress)
|
||||
|
||||
tasks = [process_item(item) for item in progress.items]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
if self._cancellation_flags.get(operation_id, False):
|
||||
progress.status = "cancelled"
|
||||
else:
|
||||
progress.status = "completed"
|
||||
|
||||
progress.finished_at = time.time()
|
||||
progress.current_item = ""
|
||||
await self._broadcast_progress(progress)
|
||||
|
||||
await asyncio.sleep(5)
|
||||
self._cleanup_operation(operation_id)
|
||||
|
||||
async def _import_single_item(
|
||||
self,
|
||||
*,
|
||||
item: BatchImportItem,
|
||||
recipe_scanner_getter: Callable[[], Any],
|
||||
civitai_client_getter: Callable[[], Any],
|
||||
tags: List[str],
|
||||
skip_no_metadata: bool,
|
||||
skip_duplicates: bool,
|
||||
semaphore: asyncio.Semaphore,
|
||||
) -> Dict[str, Any]:
|
||||
async with semaphore:
|
||||
recipe_scanner = recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
return {"success": False, "error": "Recipe scanner unavailable"}
|
||||
|
||||
try:
|
||||
if item.item_type == ImportItemType.URL:
|
||||
if not self._validate_url(item.source):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid URL format: {item.source}",
|
||||
}
|
||||
|
||||
if skip_duplicates:
|
||||
if self._is_duplicate_source(
|
||||
item.source, item.item_type, recipe_scanner
|
||||
):
|
||||
return {
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"error": "Duplicate source URL",
|
||||
}
|
||||
|
||||
civitai_client = civitai_client_getter()
|
||||
analysis_result = await self._analysis_service.analyze_remote_image(
|
||||
url=item.source,
|
||||
recipe_scanner=recipe_scanner,
|
||||
civitai_client=civitai_client,
|
||||
)
|
||||
else:
|
||||
if not self._validate_local_path(item.source):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid or unsafe path: {item.source}",
|
||||
}
|
||||
|
||||
if not os.path.exists(item.source):
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"File not found: {item.source}",
|
||||
}
|
||||
|
||||
if skip_duplicates:
|
||||
if self._is_duplicate_source(
|
||||
item.source, item.item_type, recipe_scanner
|
||||
):
|
||||
return {
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"error": "Duplicate source path",
|
||||
}
|
||||
|
||||
analysis_result = await self._analysis_service.analyze_local_image(
|
||||
file_path=item.source,
|
||||
recipe_scanner=recipe_scanner,
|
||||
)
|
||||
|
||||
payload = analysis_result.payload
|
||||
|
||||
if payload.get("error"):
|
||||
if skip_no_metadata and "No metadata" in payload.get("error", ""):
|
||||
return {
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"error": payload["error"],
|
||||
}
|
||||
return {"success": False, "error": payload["error"]}
|
||||
|
||||
loras = payload.get("loras", [])
|
||||
if not loras:
|
||||
if skip_no_metadata:
|
||||
return {
|
||||
"success": False,
|
||||
"skipped": True,
|
||||
"error": "No LoRAs found in image",
|
||||
}
|
||||
return {"success": False, "error": "No LoRAs found in image"}
|
||||
|
||||
recipe_name = self._generate_recipe_name(item, payload)
|
||||
all_tags = list(set(tags + (payload.get("tags", []) or [])))
|
||||
|
||||
metadata = {
|
||||
"base_model": payload.get("base_model", ""),
|
||||
"loras": loras,
|
||||
"gen_params": payload.get("gen_params", {}),
|
||||
"source_path": item.source,
|
||||
}
|
||||
|
||||
if payload.get("checkpoint"):
|
||||
metadata["checkpoint"] = payload["checkpoint"]
|
||||
|
||||
image_bytes = None
|
||||
image_base64 = payload.get("image_base64")
|
||||
|
||||
if item.item_type == ImportItemType.LOCAL_PATH:
|
||||
with open(item.source, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
image_base64 = None
|
||||
|
||||
save_result = await self._persistence_service.save_recipe(
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_bytes=image_bytes,
|
||||
image_base64=image_base64,
|
||||
name=recipe_name,
|
||||
tags=all_tags,
|
||||
metadata=metadata,
|
||||
extension=payload.get("extension"),
|
||||
)
|
||||
|
||||
if save_result.status == 200:
|
||||
return {
|
||||
"success": True,
|
||||
"recipe_name": recipe_name,
|
||||
"recipe_id": save_result.payload.get("id"),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": save_result.payload.get(
|
||||
"error", "Failed to save recipe"
|
||||
),
|
||||
}
|
||||
|
||||
except RecipeValidationError as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
except RecipeDownloadError as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
except RecipeNotFoundError as e:
|
||||
return {"success": False, "skipped": True, "error": str(e)}
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"Unexpected error importing {item.source}: {e}", exc_info=True
|
||||
)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def _generate_recipe_name(
|
||||
self, item: BatchImportItem, payload: Dict[str, Any]
|
||||
) -> str:
|
||||
if item.item_type == ImportItemType.LOCAL_PATH:
|
||||
base_name = os.path.splitext(os.path.basename(item.source))[0]
|
||||
return base_name[:100]
|
||||
else:
|
||||
loras = payload.get("loras", [])
|
||||
if loras:
|
||||
first_lora = loras[0].get("name", "Recipe")
|
||||
return f"Import - {first_lora}"[:100]
|
||||
return f"Imported Recipe {item.id[:8]}"
|
||||
|
||||
async def _broadcast_progress(self, progress: BatchImportProgress) -> None:
|
||||
await self._ws_manager.broadcast(
|
||||
{
|
||||
"type": "batch_import_progress",
|
||||
**progress.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
def _cleanup_operation(self, operation_id: str) -> None:
|
||||
if operation_id in self._cancellation_flags:
|
||||
del self._cancellation_flags[operation_id]
|
||||
Reference in New Issue
Block a user