feat(batch-import): implement backend batch import service with adaptive concurrency

- Add BatchImportService with concurrent execution using asyncio.gather
- Implement AdaptiveConcurrencyController with dynamic adjustment
- Add input validation for URLs and local paths
- Support duplicate detection via skip_duplicates parameter
- Add WebSocket progress broadcasting for real-time updates
- Create comprehensive unit tests for batch import functionality
- Update API handlers and route registrations
- Add i18n translation keys for batch import UI
This commit is contained in:
Will Miao
2026-03-14 06:54:24 +08:00
parent c89d4dae85
commit f86651652c
8 changed files with 1928 additions and 127 deletions

View File

@@ -729,6 +729,42 @@
"failed": "Failed to repair recipe: {message}",
"missingId": "Cannot repair recipe: Missing recipe ID"
}
},
"batchImport": {
"title": "Batch Import Recipes",
"action": "Batch Import",
"urlsMode": "URLs / Paths",
"directoryMode": "Directory",
"urlsDescription": "Enter image URLs or local file paths (one per line). Each will be imported as a recipe.",
"urlsLabel": "Image URLs or Local Paths",
"urlsPlaceholder": "https://civitai.com/images/...\nhttps://civitai.com/images/...\nC:/path/to/image.png\n...",
"directoryDescription": "Enter a directory path to import all images from that folder.",
"directoryLabel": "Directory Path",
"directoryPlaceholder": "/path/to/images/folder",
"tagsOptional": "Tags (optional, applied to all recipes)",
"addTagPlaceholder": "Add a tag",
"addTag": "Add",
"noTags": "No tags added",
"skipNoMetadata": "Skip images without metadata",
"skipNoMetadataHelp": "Images without LoRA metadata will be skipped automatically.",
"startImport": "Start Import",
"importing": "Importing Recipes...",
"progress": "Progress",
"currentItem": "Current",
"preparing": "Preparing...",
"cancelImport": "Cancel",
"cancelled": "Batch import cancelled",
"completed": "Import Completed",
"completedSuccess": "Successfully imported {count} recipe(s)",
"successCount": "Successful",
"failedCount": "Failed",
"skippedCount": "Skipped",
"totalProcessed": "Total processed",
"errors": {
"enterUrls": "Please enter at least one URL or path",
"enterDirectory": "Please enter a directory path",
"startFailed": "Failed to start import: {message}"
}
}
},
"checkpoints": {

View File

@@ -722,13 +722,49 @@
"getInfoFailed": "获取缺失 LoRA 信息失败",
"prepareError": "准备下载 LoRA 时出错:{message}"
},
"repair": {
"repair": {
"starting": "正在修复配方元数据...",
"success": "配方元数据修复成功",
"skipped": "配方已是最新版本,无需修复",
"failed": "修复配方失败:{message}",
"missingId": "无法修复配方:缺少配方 ID"
}
},
"batchImport": {
"title": "批量导入配方",
"action": "批量导入",
"urlsMode": "URL / 路径",
"directoryMode": "目录",
"urlsDescription": "输入图片 URL 或本地文件路径(每行一个)。每个将作为配方导入。",
"urlsLabel": "图片 URL 或本地路径",
"urlsPlaceholder": "https://civitai.com/images/...\nhttps://civitai.com/images/...\nC:/path/to/image.png\n...",
"directoryDescription": "输入目录路径以导入该文件夹中的所有图片。",
"directoryLabel": "目录路径",
"directoryPlaceholder": "/图片/文件夹/路径",
"tagsOptional": "标签(可选,应用于所有配方)",
"addTagPlaceholder": "添加标签",
"addTag": "添加",
"noTags": "未添加标签",
"skipNoMetadata": "跳过无元数据的图片",
"skipNoMetadataHelp": "没有 LoRA 元数据的图片将自动跳过。",
"startImport": "开始导入",
"importing": "正在导入配方...",
"progress": "进度",
"currentItem": "当前",
"preparing": "准备中...",
"cancelImport": "取消",
"cancelled": "批量导入已取消",
"completed": "导入完成",
"completedSuccess": "成功导入 {count} 个配方",
"successCount": "成功",
"failedCount": "失败",
"skippedCount": "跳过",
"totalProcessed": "总计处理",
"errors": {
"enterUrls": "请至少输入一个 URL 或路径",
"enterDirectory": "请输入目录路径",
"startFailed": "启动导入失败:{message}"
}
}
},
"checkpoints": {

View File

@@ -1,4 +1,5 @@
"""Base infrastructure shared across recipe routes."""
from __future__ import annotations
import logging
@@ -16,12 +17,14 @@ from ..services.recipes import (
RecipePersistenceService,
RecipeSharingService,
)
from ..services.batch_import_service import BatchImportService
from ..services.server_i18n import server_i18n
from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import get_settings_manager
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from .handlers.recipe_handlers import (
BatchImportHandler,
RecipeAnalysisHandler,
RecipeHandlerSet,
RecipeListingHandler,
@@ -116,7 +119,10 @@ class BaseRecipeRoutes:
recipe_scanner_getter = lambda: self.recipe_scanner
civitai_client_getter = lambda: self.civitai_client
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
standalone_mode = (
os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
)
if not standalone_mode:
from ..metadata_collector import get_metadata # type: ignore[import-not-found]
from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found]
@@ -190,6 +196,22 @@ class BaseRecipeRoutes:
sharing_service=sharing_service,
)
from ..services.websocket_manager import ws_manager
batch_import_service = BatchImportService(
analysis_service=analysis_service,
persistence_service=persistence_service,
ws_manager=ws_manager,
logger=logger,
)
batch_import = BatchImportHandler(
ensure_dependencies_ready=self.ensure_dependencies_ready,
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
logger=logger,
batch_import_service=batch_import_service,
)
return RecipeHandlerSet(
page_view=page_view,
listing=listing,
@@ -197,4 +219,5 @@ class BaseRecipeRoutes:
management=management,
analysis=analysis,
sharing=sharing,
batch_import=batch_import,
)

View File

@@ -1,4 +1,5 @@
"""Dedicated handler objects for recipe-related routes."""
from __future__ import annotations
import json
@@ -29,6 +30,7 @@ from ...utils.exif_utils import ExifUtils
from ...recipes.merger import GenParamsMerger
from ...recipes.enrichment import RecipeEnricher
from ...services.websocket_manager import ws_manager as default_ws_manager
from ...services.batch_import_service import BatchImportService
Logger = logging.Logger
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
@@ -46,8 +48,11 @@ class RecipeHandlerSet:
management: "RecipeManagementHandler"
analysis: "RecipeAnalysisHandler"
sharing: "RecipeSharingHandler"
batch_import: "BatchImportHandler"
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
def to_route_mapping(
self,
) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
"""Expose handler coroutines keyed by registrar handler names."""
return {
@@ -81,6 +86,9 @@ class RecipeHandlerSet:
"cancel_repair": self.management.cancel_repair,
"repair_recipe": self.management.repair_recipe,
"get_repair_progress": self.management.get_repair_progress,
"start_batch_import": self.batch_import.start_batch_import,
"get_batch_import_progress": self.batch_import.get_batch_import_progress,
"cancel_batch_import": self.batch_import.cancel_batch_import,
}
@@ -170,8 +178,10 @@ class RecipeListingHandler:
search_options = {
"title": request.query.get("search_title", "true").lower() == "true",
"tags": request.query.get("search_tags", "true").lower() == "true",
"lora_name": request.query.get("search_lora_name", "true").lower() == "true",
"lora_model": request.query.get("search_lora_model", "true").lower() == "true",
"lora_name": request.query.get("search_lora_name", "true").lower()
== "true",
"lora_model": request.query.get("search_lora_model", "true").lower()
== "true",
"prompt": request.query.get("search_prompt", "true").lower() == "true",
}
@@ -246,7 +256,9 @@ class RecipeListingHandler:
return web.json_response({"error": "Recipe not found"}, status=404)
return web.json_response(recipe)
except Exception as exc:
self._logger.error("Error retrieving recipe details: %s", exc, exc_info=True)
self._logger.error(
"Error retrieving recipe details: %s", exc, exc_info=True
)
return web.json_response({"error": str(exc)}, status=500)
def format_recipe_file_url(self, file_path: str) -> str:
@@ -256,7 +268,9 @@ class RecipeListingHandler:
if static_url:
return static_url
except Exception as exc: # pragma: no cover - logging path
self._logger.error("Error formatting recipe file URL: %s", exc, exc_info=True)
self._logger.error(
"Error formatting recipe file URL: %s", exc, exc_info=True
)
return "/loras_static/images/no-preview.png"
return "/loras_static/images/no-preview.png"
@@ -293,7 +307,9 @@ class RecipeQueryHandler:
for tag in recipe.get("tags", []) or []:
tag_counts[tag] = tag_counts.get(tag, 0) + 1
sorted_tags = [{"tag": tag, "count": count} for tag, count in tag_counts.items()]
sorted_tags = [
{"tag": tag, "count": count} for tag, count in tag_counts.items()
]
sorted_tags.sort(key=lambda entry: entry["count"], reverse=True)
return web.json_response({"success": True, "tags": sorted_tags[:limit]})
except Exception as exc:
@@ -313,9 +329,14 @@ class RecipeQueryHandler:
for recipe in getattr(cache, "raw_data", []):
base_model = recipe.get("base_model")
if base_model:
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
base_model_counts[base_model] = (
base_model_counts.get(base_model, 0) + 1
)
sorted_models = [{"name": model, "count": count} for model, count in base_model_counts.items()]
sorted_models = [
{"name": model, "count": count}
for model, count in base_model_counts.items()
]
sorted_models.sort(key=lambda entry: entry["count"], reverse=True)
return web.json_response({"success": True, "base_models": sorted_models})
except Exception as exc:
@@ -345,7 +366,9 @@ class RecipeQueryHandler:
folders = await recipe_scanner.get_folders()
return web.json_response({"success": True, "folders": folders})
except Exception as exc:
self._logger.error("Error retrieving recipe folders: %s", exc, exc_info=True)
self._logger.error(
"Error retrieving recipe folders: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_folder_tree(self, request: web.Request) -> web.Response:
@@ -358,7 +381,9 @@ class RecipeQueryHandler:
folder_tree = await recipe_scanner.get_folder_tree()
return web.json_response({"success": True, "tree": folder_tree})
except Exception as exc:
self._logger.error("Error retrieving recipe folder tree: %s", exc, exc_info=True)
self._logger.error(
"Error retrieving recipe folder tree: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_unified_folder_tree(self, request: web.Request) -> web.Response:
@@ -371,7 +396,9 @@ class RecipeQueryHandler:
folder_tree = await recipe_scanner.get_folder_tree()
return web.json_response({"success": True, "tree": folder_tree})
except Exception as exc:
self._logger.error("Error retrieving unified recipe folder tree: %s", exc, exc_info=True)
self._logger.error(
"Error retrieving unified recipe folder tree: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
@@ -383,7 +410,9 @@ class RecipeQueryHandler:
lora_hash = request.query.get("hash")
if not lora_hash:
return web.json_response({"success": False, "error": "Lora hash is required"}, status=400)
return web.json_response(
{"success": False, "error": "Lora hash is required"}, status=400
)
matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash)
return web.json_response({"success": True, "recipes": matching_recipes})
@@ -400,7 +429,9 @@ class RecipeQueryHandler:
self._logger.info("Manually triggering recipe cache rebuild")
await recipe_scanner.get_cached_data(force_refresh=True)
return web.json_response({"success": True, "message": "Recipe cache refreshed successfully"})
return web.json_response(
{"success": True, "message": "Recipe cache refreshed successfully"}
)
except Exception as exc:
self._logger.error("Error refreshing recipe cache: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
@@ -429,7 +460,9 @@ class RecipeQueryHandler:
"id": recipe.get("id"),
"title": recipe.get("title"),
"file_url": recipe.get("file_url")
or self._format_recipe_file_url(recipe.get("file_path", "")),
or self._format_recipe_file_url(
recipe.get("file_path", "")
),
"modified": recipe.get("modified"),
"created_date": recipe.get("created_date"),
"lora_count": len(recipe.get("loras", [])),
@@ -437,7 +470,9 @@ class RecipeQueryHandler:
)
if len(recipes) >= 2:
recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True)
recipes.sort(
key=lambda entry: entry.get("modified", 0), reverse=True
)
response_data.append(
{
"type": "fingerprint",
@@ -460,7 +495,9 @@ class RecipeQueryHandler:
"id": recipe.get("id"),
"title": recipe.get("title"),
"file_url": recipe.get("file_url")
or self._format_recipe_file_url(recipe.get("file_path", "")),
or self._format_recipe_file_url(
recipe.get("file_path", "")
),
"modified": recipe.get("modified"),
"created_date": recipe.get("created_date"),
"lora_count": len(recipe.get("loras", [])),
@@ -468,7 +505,9 @@ class RecipeQueryHandler:
)
if len(recipes) >= 2:
recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True)
recipes.sort(
key=lambda entry: entry.get("modified", 0), reverse=True
)
response_data.append(
{
"type": "source_url",
@@ -479,9 +518,13 @@ class RecipeQueryHandler:
)
response_data.sort(key=lambda entry: entry["count"], reverse=True)
return web.json_response({"success": True, "duplicate_groups": response_data})
return web.json_response(
{"success": True, "duplicate_groups": response_data}
)
except Exception as exc:
self._logger.error("Error finding duplicate recipes: %s", exc, exc_info=True)
self._logger.error(
"Error finding duplicate recipes: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
@@ -498,9 +541,13 @@ class RecipeQueryHandler:
return web.json_response({"error": "Recipe not found"}, status=404)
if not syntax_parts:
return web.json_response({"error": "No LoRAs found in this recipe"}, status=400)
return web.json_response(
{"error": "No LoRAs found in this recipe"}, status=400
)
return web.json_response({"success": True, "syntax": " ".join(syntax_parts)})
return web.json_response(
{"success": True, "syntax": " ".join(syntax_parts)}
)
except Exception as exc:
self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
@@ -561,11 +608,17 @@ class RecipeManagementHandler:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503)
return web.json_response(
{"success": False, "error": "Recipe scanner unavailable"},
status=503,
)
# Check if already running
if self._ws_manager.is_recipe_repair_running():
return web.json_response({"success": False, "error": "Recipe repair already in progress"}, status=409)
return web.json_response(
{"success": False, "error": "Recipe repair already in progress"},
status=409,
)
recipe_scanner.reset_cancellation()
@@ -579,11 +632,12 @@ class RecipeManagementHandler:
progress_callback=progress_callback
)
except Exception as e:
self._logger.error(f"Error in recipe repair task: {e}", exc_info=True)
await self._ws_manager.broadcast_recipe_repair_progress({
"status": "error",
"error": str(e)
})
self._logger.error(
f"Error in recipe repair task: {e}", exc_info=True
)
await self._ws_manager.broadcast_recipe_repair_progress(
{"status": "error", "error": str(e)}
)
finally:
# Keep the final status for a while so the UI can see it
await asyncio.sleep(5)
@@ -593,7 +647,9 @@ class RecipeManagementHandler:
asyncio.create_task(run_repair())
return web.json_response({"success": True, "message": "Recipe repair started"})
return web.json_response(
{"success": True, "message": "Recipe repair started"}
)
except Exception as exc:
self._logger.error("Error starting recipe repair: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
@@ -603,10 +659,15 @@ class RecipeManagementHandler:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503)
return web.json_response(
{"success": False, "error": "Recipe scanner unavailable"},
status=503,
)
recipe_scanner.cancel_task()
return web.json_response({"success": True, "message": "Cancellation requested"})
return web.json_response(
{"success": True, "message": "Cancellation requested"}
)
except Exception as exc:
self._logger.error("Error cancelling recipe repair: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
@@ -616,7 +677,10 @@ class RecipeManagementHandler:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503)
return web.json_response(
{"success": False, "error": "Recipe scanner unavailable"},
status=503,
)
recipe_id = request.match_info["recipe_id"]
result = await recipe_scanner.repair_recipe_by_id(recipe_id)
@@ -632,25 +696,26 @@ class RecipeManagementHandler:
progress = self._ws_manager.get_recipe_repair_progress()
if progress:
return web.json_response({"success": True, "progress": progress})
return web.json_response({"success": False, "message": "No repair in progress"}, status=404)
return web.json_response(
{"success": False, "message": "No repair in progress"}, status=404
)
except Exception as exc:
self._logger.error("Error getting repair progress: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def import_remote_recipe(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
# 1. Parse Parameters
params = request.rel_url.query
image_url = params.get("image_url")
name = params.get("name")
resources_raw = params.get("resources")
if not image_url:
raise RecipeValidationError("Missing required field: image_url")
if not name:
@@ -658,64 +723,80 @@ class RecipeManagementHandler:
if not resources_raw:
raise RecipeValidationError("Missing required field: resources")
checkpoint_entry, lora_entries = self._parse_resources_payload(resources_raw)
checkpoint_entry, lora_entries = self._parse_resources_payload(
resources_raw
)
gen_params_request = self._parse_gen_params(params.get("gen_params"))
# 2. Initial Metadata Construction
metadata: Dict[str, Any] = {
"base_model": params.get("base_model", "") or "",
"loras": lora_entries,
"gen_params": gen_params_request or {},
"source_url": image_url
"source_url": image_url,
}
source_path = params.get("source_path")
if source_path:
metadata["source_path"] = source_path
# Checkpoint handling
if checkpoint_entry:
metadata["checkpoint"] = checkpoint_entry
# Ensure checkpoint is also in gen_params for consistency if needed by enricher?
# Actually enricher looks at metadata['checkpoint'], so this is fine.
# Try to resolve base model from checkpoint if not explicitly provided
if not metadata["base_model"]:
base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry)
base_model_from_metadata = (
await self._resolve_base_model_from_checkpoint(checkpoint_entry)
)
if base_model_from_metadata:
metadata["base_model"] = base_model_from_metadata
tags = self._parse_tags(params.get("tags"))
# 3. Download Image
image_bytes, extension, civitai_meta_from_download = await self._download_remote_media(image_url)
(
image_bytes,
extension,
civitai_meta_from_download,
) = await self._download_remote_media(image_url)
# 4. Extract Embedded Metadata
# Note: We still extract this here because Enricher currently expects 'gen_params' to already be populated
# with embedded data if we want it to merge it.
# Note: We still extract this here because Enricher currently expects 'gen_params' to already be populated
# with embedded data if we want it to merge it.
# However, logic in Enricher merges: request > civitai > embedded.
# So we should gather embedded params and put them into the recipe's gen_params (as initial state)
# So we should gather embedded params and put them into the recipe's gen_params (as initial state)
# OR pass them to enricher to handle?
# The interface of Enricher.enrich_recipe takes `recipe` (with gen_params) and `request_params`.
# So let's extract embedded and put it into recipe['gen_params'] but careful not to overwrite request params.
# Actually, `GenParamsMerger` which `Enricher` uses handles 3 layers.
# But `Enricher` interface is: recipe['gen_params'] (as embedded) + request_params + civitai (fetched internally).
# Wait, `Enricher` fetches Civitai info internally based on URL.
# Wait, `Enricher` fetches Civitai info internally based on URL.
# `civitai_meta_from_download` is returned by `_download_remote_media` which might be useful if URL didn't have ID.
# Let's extract embedded metadata first
embedded_gen_params = {}
try:
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_img:
with tempfile.NamedTemporaryFile(
suffix=extension, delete=False
) as temp_img:
temp_img.write(image_bytes)
temp_img_path = temp_img.name
try:
raw_embedded = ExifUtils.extract_image_metadata(temp_img_path)
if raw_embedded:
parser = self._analysis_service._recipe_parser_factory.create_parser(raw_embedded)
parser = (
self._analysis_service._recipe_parser_factory.create_parser(
raw_embedded
)
)
if parser:
parsed_embedded = await parser.parse_metadata(raw_embedded, recipe_scanner=recipe_scanner)
parsed_embedded = await parser.parse_metadata(
raw_embedded, recipe_scanner=recipe_scanner
)
if parsed_embedded and "gen_params" in parsed_embedded:
embedded_gen_params = parsed_embedded["gen_params"]
else:
@@ -724,7 +805,9 @@ class RecipeManagementHandler:
if os.path.exists(temp_img_path):
os.unlink(temp_img_path)
except Exception as exc:
self._logger.warning("Failed to extract embedded metadata during import: %s", exc)
self._logger.warning(
"Failed to extract embedded metadata during import: %s", exc
)
# Pre-populate gen_params with embedded data so Enricher treats it as the "base" layer
if embedded_gen_params:
@@ -732,18 +815,18 @@ class RecipeManagementHandler:
# But wait, we want request params to override everything.
# So we should set recipe['gen_params'] = embedded, and pass request params to enricher.
metadata["gen_params"] = embedded_gen_params
# 5. Enrich with unified logic
# This will fetch Civitai info (if URL matches) and merge: request > civitai > embedded
civitai_client = self._civitai_client_getter()
await RecipeEnricher.enrich_recipe(
recipe=metadata,
recipe=metadata,
civitai_client=civitai_client,
request_params=gen_params_request # Pass explicit request params here to override
request_params=gen_params_request, # Pass explicit request params here to override
)
# If we got civitai_meta from download but Enricher didn't fetch it (e.g. not a civitai URL or failed),
# we might want to manually merge it?
# we might want to manually merge it?
# But usually `import_remote_recipe` is used with Civitai URLs.
# For now, relying on Enricher's internal fetch is consistent with repair.
@@ -762,7 +845,9 @@ class RecipeManagementHandler:
except RecipeDownloadError as exc:
return web.json_response({"error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error importing recipe from remote source: %s", exc, exc_info=True)
self._logger.error(
"Error importing recipe from remote source: %s", exc, exc_info=True
)
return web.json_response({"error": str(exc)}, status=500)
async def delete_recipe(self, request: web.Request) -> web.Response:
@@ -816,7 +901,11 @@ class RecipeManagementHandler:
target_path = data.get("target_path")
if not recipe_id or not target_path:
return web.json_response(
{"success": False, "error": "recipe_id and target_path are required"}, status=400
{
"success": False,
"error": "recipe_id and target_path are required",
},
status=400,
)
result = await self._persistence_service.move_recipe(
@@ -845,7 +934,11 @@ class RecipeManagementHandler:
target_path = data.get("target_path")
if not recipe_ids or not target_path:
return web.json_response(
{"success": False, "error": "recipe_ids and target_path are required"}, status=400
{
"success": False,
"error": "recipe_ids and target_path are required",
},
status=400,
)
result = await self._persistence_service.move_recipes_bulk(
@@ -934,7 +1027,9 @@ class RecipeManagementHandler:
except RecipeValidationError as exc:
return web.json_response({"error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True)
self._logger.error(
"Error saving recipe from widget: %s", exc, exc_info=True
)
return web.json_response({"error": str(exc)}, status=500)
async def _parse_save_payload(self, reader) -> dict[str, Any]:
@@ -1006,7 +1101,9 @@ class RecipeManagementHandler:
raise RecipeValidationError("gen_params payload must be an object")
return parsed
def _parse_resources_payload(self, payload_raw: str) -> tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]:
def _parse_resources_payload(
self, payload_raw: str
) -> tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]:
try:
payload = json.loads(payload_raw)
except json.JSONDecodeError as exc:
@@ -1066,15 +1163,19 @@ class RecipeManagementHandler:
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", image_url)
if civitai_match:
if civitai_client is None:
raise RecipeDownloadError("Civitai client unavailable for image download")
raise RecipeDownloadError(
"Civitai client unavailable for image download"
)
image_info = await civitai_client.get_image_info(civitai_match.group(1))
if not image_info:
raise RecipeDownloadError("Failed to fetch image information from Civitai")
raise RecipeDownloadError(
"Failed to fetch image information from Civitai"
)
media_url = image_info.get("url")
if not media_url:
raise RecipeDownloadError("No image URL found in Civitai response")
# Use optimized preview URLs if possible
media_type = image_info.get("type")
rewritten_url, _ = rewrite_preview_url(media_url, media_type=media_type)
@@ -1083,18 +1184,24 @@ class RecipeManagementHandler:
else:
download_url = media_url
success, result = await downloader.download_file(download_url, temp_path, use_auth=False)
success, result = await downloader.download_file(
download_url, temp_path, use_auth=False
)
if not success:
raise RecipeDownloadError(f"Failed to download image: {result}")
# Extract extension from URL
url_path = download_url.split('?')[0].split('#')[0]
url_path = download_url.split("?")[0].split("#")[0]
extension = os.path.splitext(url_path)[1].lower()
if not extension:
extension = ".webp" # Default to webp if unknown
extension = ".webp" # Default to webp if unknown
with open(temp_path, "rb") as file_obj:
return file_obj.read(), extension, image_info.get("meta") if civitai_match and image_info else None
return (
file_obj.read(),
extension,
image_info.get("meta") if civitai_match and image_info else None,
)
except RecipeDownloadError:
raise
except RecipeValidationError:
@@ -1108,14 +1215,15 @@ class RecipeManagementHandler:
except FileNotFoundError:
pass
def _safe_int(self, value: Any) -> int:
try:
return int(value)
except (TypeError, ValueError):
return 0
async def _resolve_base_model_from_checkpoint(self, checkpoint_entry: Dict[str, Any]) -> str:
async def _resolve_base_model_from_checkpoint(
self, checkpoint_entry: Dict[str, Any]
) -> str:
version_id = self._safe_int(checkpoint_entry.get("modelVersionId"))
if not version_id:
@@ -1134,7 +1242,9 @@ class RecipeManagementHandler:
base_model = version_info.get("baseModel") or ""
return str(base_model) if base_model is not None else ""
except Exception as exc: # pragma: no cover - defensive logging
self._logger.warning("Failed to resolve base model from checkpoint metadata: %s", exc)
self._logger.warning(
"Failed to resolve base model from checkpoint metadata: %s", exc
)
return ""
@@ -1279,5 +1389,178 @@ class RecipeSharingHandler:
except RecipeNotFoundError as exc:
return web.json_response({"error": str(exc)}, status=404)
except Exception as exc:
self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True)
self._logger.error(
"Error downloading shared recipe: %s", exc, exc_info=True
)
return web.json_response({"error": str(exc)}, status=500)
class BatchImportHandler:
"""Handle batch import operations for recipes."""
def __init__(
self,
*,
ensure_dependencies_ready: EnsureDependenciesCallable,
recipe_scanner_getter: RecipeScannerGetter,
civitai_client_getter: CivitaiClientGetter,
logger: Logger,
batch_import_service: BatchImportService,
) -> None:
self._ensure_dependencies_ready = ensure_dependencies_ready
self._recipe_scanner_getter = recipe_scanner_getter
self._civitai_client_getter = civitai_client_getter
self._logger = logger
self._batch_import_service = batch_import_service
async def start_batch_import(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
if self._batch_import_service.is_import_running():
return web.json_response(
{"success": False, "error": "Batch import already in progress"},
status=409,
)
data = await request.json()
items = data.get("items", [])
tags = data.get("tags", [])
skip_no_metadata = data.get("skip_no_metadata", True)
if not items:
return web.json_response(
{"success": False, "error": "No items provided"},
status=400,
)
for item in items:
if not item.get("source"):
return web.json_response(
{
"success": False,
"error": "Each item must have a 'source' field",
},
status=400,
)
operation_id = await self._batch_import_service.start_batch_import(
recipe_scanner_getter=self._recipe_scanner_getter,
civitai_client_getter=self._civitai_client_getter,
items=items,
tags=tags,
skip_no_metadata=skip_no_metadata,
)
return web.json_response(
{
"success": True,
"operation_id": operation_id,
}
)
except RecipeValidationError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error starting batch import: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def start_directory_import(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
if self._batch_import_service.is_import_running():
return web.json_response(
{"success": False, "error": "Batch import already in progress"},
status=409,
)
data = await request.json()
directory = data.get("directory")
recursive = data.get("recursive", True)
tags = data.get("tags", [])
skip_no_metadata = data.get("skip_no_metadata", True)
if not directory:
return web.json_response(
{"success": False, "error": "Directory path is required"},
status=400,
)
operation_id = await self._batch_import_service.start_directory_import(
recipe_scanner_getter=self._recipe_scanner_getter,
civitai_client_getter=self._civitai_client_getter,
directory=directory,
recursive=recursive,
tags=tags,
skip_no_metadata=skip_no_metadata,
)
return web.json_response(
{
"success": True,
"operation_id": operation_id,
}
)
except RecipeValidationError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except Exception as exc:
self._logger.error(
"Error starting directory import: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_batch_import_progress(self, request: web.Request) -> web.Response:
try:
operation_id = request.query.get("operation_id")
if not operation_id:
return web.json_response(
{"success": False, "error": "operation_id is required"},
status=400,
)
progress = self._batch_import_service.get_progress(operation_id)
if not progress:
return web.json_response(
{"success": False, "error": "Operation not found"},
status=404,
)
return web.json_response(
{
"success": True,
"progress": progress.to_dict(),
}
)
except Exception as exc:
self._logger.error(
"Error getting batch import progress: %s", exc, exc_info=True
)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def cancel_batch_import(self, request: web.Request) -> web.Response:
try:
data = await request.json()
operation_id = data.get("operation_id")
if not operation_id:
return web.json_response(
{"success": False, "error": "operation_id is required"},
status=400,
)
cancelled = self._batch_import_service.cancel_import(operation_id)
if not cancelled:
return web.json_response(
{
"success": False,
"error": "Operation not found or already completed",
},
status=404,
)
return web.json_response(
{"success": True, "message": "Cancellation requested"}
)
except Exception as exc:
self._logger.error("Error cancelling batch import: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)

View File

@@ -1,4 +1,5 @@
"""Route registrar for recipe endpoints."""
from __future__ import annotations
from dataclasses import dataclass
@@ -22,7 +23,9 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
RouteDefinition("GET", "/api/lm/recipes/import-remote", "import_remote_recipe"),
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
RouteDefinition(
"POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"
),
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"),
RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"),
@@ -30,9 +33,13 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/recipes/roots", "get_roots"),
RouteDefinition("GET", "/api/lm/recipes/folders", "get_folders"),
RouteDefinition("GET", "/api/lm/recipes/folder-tree", "get_folder_tree"),
RouteDefinition("GET", "/api/lm/recipes/unified-folder-tree", "get_unified_folder_tree"),
RouteDefinition(
"GET", "/api/lm/recipes/unified-folder-tree", "get_unified_folder_tree"
),
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"),
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"),
RouteDefinition(
"GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"
),
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"),
RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"),
RouteDefinition("POST", "/api/lm/recipe/move", "move_recipe"),
@@ -40,13 +47,22 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"),
RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"),
RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"),
RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"),
RouteDefinition(
"POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"
),
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
RouteDefinition("POST", "/api/lm/recipes/repair", "repair_recipes"),
RouteDefinition("POST", "/api/lm/recipes/cancel-repair", "cancel_repair"),
RouteDefinition("POST", "/api/lm/recipe/{recipe_id}/repair", "repair_recipe"),
RouteDefinition("GET", "/api/lm/recipes/repair-progress", "get_repair_progress"),
RouteDefinition("POST", "/api/lm/recipes/batch-import/start", "start_batch_import"),
RouteDefinition(
"GET", "/api/lm/recipes/batch-import/progress", "get_batch_import_progress"
),
RouteDefinition(
"POST", "/api/lm/recipes/batch-import/cancel", "cancel_batch_import"
),
)
@@ -63,7 +79,9 @@ class RecipeRouteRegistrar:
def __init__(self, app: web.Application) -> None:
self._app = app
def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None:
def register_routes(
self, handler_lookup: Mapping[str, Callable[[web.Request], object]]
) -> None:
for definition in ROUTE_DEFINITIONS:
handler = handler_lookup[definition.handler_name]
self._bind_route(definition.method, definition.path, handler)

View File

@@ -0,0 +1,579 @@
"""Batch import service for importing multiple images as recipes."""
from __future__ import annotations
import asyncio
import logging
import os
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set
from aiohttp import web
from .recipes import (
RecipeAnalysisService,
RecipePersistenceService,
RecipeValidationError,
RecipeDownloadError,
RecipeNotFoundError,
)
class ImportItemType(Enum):
"""Type of import item."""
URL = "url"
LOCAL_PATH = "local_path"
class ImportStatus(Enum):
"""Status of an individual import item."""
PENDING = "pending"
PROCESSING = "processing"
SUCCESS = "success"
FAILED = "failed"
SKIPPED = "skipped"
@dataclass
class BatchImportItem:
"""Represents a single item to import."""
id: str
source: str
item_type: ImportItemType
status: ImportStatus = ImportStatus.PENDING
error_message: Optional[str] = None
recipe_name: Optional[str] = None
recipe_id: Optional[str] = None
duration: float = 0.0
@dataclass
class BatchImportProgress:
"""Tracks progress of a batch import operation."""
operation_id: str
total: int
completed: int = 0
success: int = 0
failed: int = 0
skipped: int = 0
current_item: str = ""
status: str = "pending"
started_at: float = field(default_factory=time.time)
finished_at: Optional[float] = None
items: List[BatchImportItem] = field(default_factory=list)
tags: List[str] = field(default_factory=list)
skip_no_metadata: bool = True
skip_duplicates: bool = False
def to_dict(self) -> Dict[str, Any]:
return {
"operation_id": self.operation_id,
"total": self.total,
"completed": self.completed,
"success": self.success,
"failed": self.failed,
"skipped": self.skipped,
"current_item": self.current_item,
"status": self.status,
"started_at": self.started_at,
"finished_at": self.finished_at,
"progress_percent": round((self.completed / self.total) * 100, 1)
if self.total > 0
else 0,
}
class AdaptiveConcurrencyController:
"""Adjusts concurrency based on task performance."""
def __init__(
self,
min_concurrency: int = 1,
max_concurrency: int = 5,
initial_concurrency: int = 3,
) -> None:
self.min_concurrency = min_concurrency
self.max_concurrency = max_concurrency
self.current_concurrency = initial_concurrency
self._task_durations: List[float] = []
self._recent_errors = 0
self._recent_successes = 0
def record_result(self, duration: float, success: bool) -> None:
self._task_durations.append(duration)
if len(self._task_durations) > 10:
self._task_durations.pop(0)
if success:
self._recent_successes += 1
if duration < 1.0 and self.current_concurrency < self.max_concurrency:
self.current_concurrency = min(
self.current_concurrency + 1, self.max_concurrency
)
elif duration > 10.0 and self.current_concurrency > self.min_concurrency:
self.current_concurrency = max(
self.current_concurrency - 1, self.min_concurrency
)
else:
self._recent_errors += 1
if self.current_concurrency > self.min_concurrency:
self.current_concurrency = max(
self.current_concurrency - 1, self.min_concurrency
)
def reset_counters(self) -> None:
self._recent_errors = 0
self._recent_successes = 0
def get_semaphore(self) -> asyncio.Semaphore:
return asyncio.Semaphore(self.current_concurrency)
class BatchImportService:
"""Service for batch importing images as recipes."""
SUPPORTED_EXTENSIONS: Set[str] = {".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp"}
def __init__(
self,
*,
analysis_service: RecipeAnalysisService,
persistence_service: RecipePersistenceService,
ws_manager: Any,
logger: logging.Logger,
) -> None:
self._analysis_service = analysis_service
self._persistence_service = persistence_service
self._ws_manager = ws_manager
self._logger = logger
self._active_operations: Dict[str, BatchImportProgress] = {}
self._cancellation_flags: Dict[str, bool] = {}
self._concurrency_controller = AdaptiveConcurrencyController()
def is_import_running(self, operation_id: Optional[str] = None) -> bool:
if operation_id:
progress = self._active_operations.get(operation_id)
return progress is not None and progress.status in ("pending", "running")
return any(
p.status in ("pending", "running") for p in self._active_operations.values()
)
def get_progress(self, operation_id: str) -> Optional[BatchImportProgress]:
return self._active_operations.get(operation_id)
def cancel_import(self, operation_id: str) -> bool:
if operation_id in self._active_operations:
self._cancellation_flags[operation_id] = True
return True
return False
def _validate_url(self, url: str) -> bool:
import re
url_pattern = re.compile(
r"^https?://"
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|"
r"localhost|"
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"
r"(?::\d+)?"
r"(?:/?|[/?]\S+)$",
re.IGNORECASE,
)
return url_pattern.match(url) is not None
def _validate_local_path(self, path: str) -> bool:
try:
normalized = os.path.normpath(path)
if not os.path.isabs(normalized):
return False
if ".." in normalized:
return False
return True
except Exception:
return False
def _is_duplicate_source(
self,
source: str,
item_type: ImportItemType,
recipe_scanner: Any,
) -> bool:
try:
cache = recipe_scanner.get_cached_data_sync()
if not cache:
return False
for recipe in getattr(cache, "raw_data", []):
source_path = recipe.get("source_path") or recipe.get("source_url")
if source_path and source_path == source:
return True
return False
except Exception:
self._logger.warning("Failed to check for duplicates", exc_info=True)
return False
async def start_batch_import(
self,
*,
recipe_scanner_getter: Callable[[], Any],
civitai_client_getter: Callable[[], Any],
items: List[Dict[str, str]],
tags: Optional[List[str]] = None,
skip_no_metadata: bool = True,
skip_duplicates: bool = False,
) -> str:
operation_id = str(uuid.uuid4())
import_items = []
for idx, item in enumerate(items):
source = item.get("source", "")
item_type_str = item.get("type", "url")
if item_type_str == "url" or source.startswith(("http://", "https://")):
item_type = ImportItemType.URL
else:
item_type = ImportItemType.LOCAL_PATH
batch_import_item = BatchImportItem(
id=f"{operation_id}_{idx}",
source=source,
item_type=item_type,
)
import_items.append(batch_import_item)
progress = BatchImportProgress(
operation_id=operation_id,
total=len(import_items),
items=import_items,
tags=tags or [],
skip_no_metadata=skip_no_metadata,
skip_duplicates=skip_duplicates,
)
self._active_operations[operation_id] = progress
self._cancellation_flags[operation_id] = False
asyncio.create_task(
self._run_batch_import(
operation_id=operation_id,
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
)
)
return operation_id
async def start_directory_import(
self,
*,
recipe_scanner_getter: Callable[[], Any],
civitai_client_getter: Callable[[], Any],
directory: str,
recursive: bool = True,
tags: Optional[List[str]] = None,
skip_no_metadata: bool = True,
skip_duplicates: bool = False,
) -> str:
image_paths = await self._discover_images(directory, recursive)
items = [{"source": path, "type": "local_path"} for path in image_paths]
return await self.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=items,
tags=tags,
skip_no_metadata=skip_no_metadata,
skip_duplicates=skip_duplicates,
)
async def _discover_images(
self,
directory: str,
recursive: bool = True,
) -> List[str]:
if not os.path.isdir(directory):
raise RecipeValidationError(f"Directory not found: {directory}")
image_paths: List[str] = []
if recursive:
for root, _, files in os.walk(directory):
for filename in files:
if self._is_supported_image(filename):
image_paths.append(os.path.join(root, filename))
else:
for filename in os.listdir(directory):
filepath = os.path.join(directory, filename)
if os.path.isfile(filepath) and self._is_supported_image(filename):
image_paths.append(filepath)
return sorted(image_paths)
def _is_supported_image(self, filename: str) -> bool:
ext = os.path.splitext(filename)[1].lower()
return ext in self.SUPPORTED_EXTENSIONS
async def _run_batch_import(
self,
*,
operation_id: str,
recipe_scanner_getter: Callable[[], Any],
civitai_client_getter: Callable[[], Any],
) -> None:
progress = self._active_operations.get(operation_id)
if not progress:
return
progress.status = "running"
await self._broadcast_progress(progress)
self._concurrency_controller = AdaptiveConcurrencyController()
async def process_item(item: BatchImportItem) -> None:
if self._cancellation_flags.get(operation_id, False):
return
progress.current_item = (
os.path.basename(item.source)
if item.item_type == ImportItemType.LOCAL_PATH
else item.source[:50]
)
item.status = ImportStatus.PROCESSING
await self._broadcast_progress(progress)
start_time = time.time()
try:
result = await self._import_single_item(
item=item,
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
tags=progress.tags,
skip_no_metadata=progress.skip_no_metadata,
skip_duplicates=progress.skip_duplicates,
semaphore=self._concurrency_controller.get_semaphore(),
)
duration = time.time() - start_time
item.duration = duration
self._concurrency_controller.record_result(
duration, result.get("success", False)
)
if result.get("success"):
item.status = ImportStatus.SUCCESS
item.recipe_name = result.get("recipe_name")
item.recipe_id = result.get("recipe_id")
progress.success += 1
elif result.get("skipped"):
item.status = ImportStatus.SKIPPED
item.error_message = result.get("error")
progress.skipped += 1
else:
item.status = ImportStatus.FAILED
item.error_message = result.get("error")
progress.failed += 1
except Exception as e:
self._logger.error(f"Error importing {item.source}: {e}")
item.status = ImportStatus.FAILED
item.error_message = str(e)
item.duration = time.time() - start_time
progress.failed += 1
self._concurrency_controller.record_result(item.duration, False)
progress.completed += 1
await self._broadcast_progress(progress)
tasks = [process_item(item) for item in progress.items]
await asyncio.gather(*tasks, return_exceptions=True)
if self._cancellation_flags.get(operation_id, False):
progress.status = "cancelled"
else:
progress.status = "completed"
progress.finished_at = time.time()
progress.current_item = ""
await self._broadcast_progress(progress)
await asyncio.sleep(5)
self._cleanup_operation(operation_id)
async def _import_single_item(
self,
*,
item: BatchImportItem,
recipe_scanner_getter: Callable[[], Any],
civitai_client_getter: Callable[[], Any],
tags: List[str],
skip_no_metadata: bool,
skip_duplicates: bool,
semaphore: asyncio.Semaphore,
) -> Dict[str, Any]:
async with semaphore:
recipe_scanner = recipe_scanner_getter()
if recipe_scanner is None:
return {"success": False, "error": "Recipe scanner unavailable"}
try:
if item.item_type == ImportItemType.URL:
if not self._validate_url(item.source):
return {
"success": False,
"error": f"Invalid URL format: {item.source}",
}
if skip_duplicates:
if self._is_duplicate_source(
item.source, item.item_type, recipe_scanner
):
return {
"success": False,
"skipped": True,
"error": "Duplicate source URL",
}
civitai_client = civitai_client_getter()
analysis_result = await self._analysis_service.analyze_remote_image(
url=item.source,
recipe_scanner=recipe_scanner,
civitai_client=civitai_client,
)
else:
if not self._validate_local_path(item.source):
return {
"success": False,
"error": f"Invalid or unsafe path: {item.source}",
}
if not os.path.exists(item.source):
return {
"success": False,
"error": f"File not found: {item.source}",
}
if skip_duplicates:
if self._is_duplicate_source(
item.source, item.item_type, recipe_scanner
):
return {
"success": False,
"skipped": True,
"error": "Duplicate source path",
}
analysis_result = await self._analysis_service.analyze_local_image(
file_path=item.source,
recipe_scanner=recipe_scanner,
)
payload = analysis_result.payload
if payload.get("error"):
if skip_no_metadata and "No metadata" in payload.get("error", ""):
return {
"success": False,
"skipped": True,
"error": payload["error"],
}
return {"success": False, "error": payload["error"]}
loras = payload.get("loras", [])
if not loras:
if skip_no_metadata:
return {
"success": False,
"skipped": True,
"error": "No LoRAs found in image",
}
return {"success": False, "error": "No LoRAs found in image"}
recipe_name = self._generate_recipe_name(item, payload)
all_tags = list(set(tags + (payload.get("tags", []) or [])))
metadata = {
"base_model": payload.get("base_model", ""),
"loras": loras,
"gen_params": payload.get("gen_params", {}),
"source_path": item.source,
}
if payload.get("checkpoint"):
metadata["checkpoint"] = payload["checkpoint"]
image_bytes = None
image_base64 = payload.get("image_base64")
if item.item_type == ImportItemType.LOCAL_PATH:
with open(item.source, "rb") as f:
image_bytes = f.read()
image_base64 = None
save_result = await self._persistence_service.save_recipe(
recipe_scanner=recipe_scanner,
image_bytes=image_bytes,
image_base64=image_base64,
name=recipe_name,
tags=all_tags,
metadata=metadata,
extension=payload.get("extension"),
)
if save_result.status == 200:
return {
"success": True,
"recipe_name": recipe_name,
"recipe_id": save_result.payload.get("id"),
}
else:
return {
"success": False,
"error": save_result.payload.get(
"error", "Failed to save recipe"
),
}
except RecipeValidationError as e:
return {"success": False, "error": str(e)}
except RecipeDownloadError as e:
return {"success": False, "error": str(e)}
except RecipeNotFoundError as e:
return {"success": False, "skipped": True, "error": str(e)}
except Exception as e:
self._logger.error(
f"Unexpected error importing {item.source}: {e}", exc_info=True
)
return {"success": False, "error": str(e)}
def _generate_recipe_name(
self, item: BatchImportItem, payload: Dict[str, Any]
) -> str:
if item.item_type == ImportItemType.LOCAL_PATH:
base_name = os.path.splitext(os.path.basename(item.source))[0]
return base_name[:100]
else:
loras = payload.get("loras", [])
if loras:
first_lora = loras[0].get("name", "Recipe")
return f"Import - {first_lora}"[:100]
return f"Imported Recipe {item.id[:8]}"
async def _broadcast_progress(self, progress: BatchImportProgress) -> None:
await self._ws_manager.broadcast(
{
"type": "batch_import_progress",
**progress.to_dict(),
}
)
def _cleanup_operation(self, operation_id: str) -> None:
if operation_id in self._cancellation_flags:
del self._cancellation_flags[operation_id]

View File

@@ -1,4 +1,5 @@
"""Integration smoke tests for the recipe route stack."""
from __future__ import annotations
import json
@@ -94,19 +95,25 @@ class StubAnalysisService:
self._recipe_parser_factory = None
StubAnalysisService.instances.append(self)
async def analyze_uploaded_image(self, *, image_bytes: bytes | None, recipe_scanner) -> SimpleNamespace: # noqa: D401 - mirrors real signature
async def analyze_uploaded_image(
self, *, image_bytes: bytes | None, recipe_scanner
) -> SimpleNamespace: # noqa: D401 - mirrors real signature
if self.raise_for_uploaded:
raise self.raise_for_uploaded
self.upload_calls.append(image_bytes or b"")
return self.result
async def analyze_remote_image(self, *, url: Optional[str], recipe_scanner, civitai_client) -> SimpleNamespace: # noqa: D401
async def analyze_remote_image(
self, *, url: Optional[str], recipe_scanner, civitai_client
) -> SimpleNamespace: # noqa: D401
if self.raise_for_remote:
raise self.raise_for_remote
self.remote_calls.append(url)
return self.result
async def analyze_local_image(self, *, file_path: Optional[str], recipe_scanner) -> SimpleNamespace: # noqa: D401
async def analyze_local_image(
self, *, file_path: Optional[str], recipe_scanner
) -> SimpleNamespace: # noqa: D401
if self.raise_for_local:
raise self.raise_for_local
self.local_calls.append(file_path)
@@ -125,11 +132,23 @@ class StubPersistenceService:
self.save_calls: List[Dict[str, Any]] = []
self.delete_calls: List[str] = []
self.move_calls: List[Dict[str, str]] = []
self.save_result = SimpleNamespace(payload={"success": True, "recipe_id": "stub-id"}, status=200)
self.save_result = SimpleNamespace(
payload={"success": True, "recipe_id": "stub-id"}, status=200
)
self.delete_result = SimpleNamespace(payload={"success": True}, status=200)
StubPersistenceService.instances.append(self)
async def save_recipe(self, *, recipe_scanner, image_bytes, image_base64, name, tags, metadata, extension=None) -> SimpleNamespace: # noqa: D401
async def save_recipe(
self,
*,
recipe_scanner,
image_bytes,
image_base64,
name,
tags,
metadata,
extension=None,
) -> SimpleNamespace: # noqa: D401
self.save_calls.append(
{
"recipe_scanner": recipe_scanner,
@@ -148,22 +167,42 @@ class StubPersistenceService:
await recipe_scanner.remove_recipe(recipe_id)
return self.delete_result
async def move_recipe(self, *, recipe_scanner, recipe_id: str, target_path: str) -> SimpleNamespace: # noqa: D401
async def move_recipe(
self, *, recipe_scanner, recipe_id: str, target_path: str
) -> SimpleNamespace: # noqa: D401
self.move_calls.append({"recipe_id": recipe_id, "target_path": target_path})
return SimpleNamespace(
payload={"success": True, "recipe_id": recipe_id, "new_file_path": target_path}, status=200
payload={
"success": True,
"recipe_id": recipe_id,
"new_file_path": target_path,
},
status=200,
)
async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: Dict[str, Any]) -> SimpleNamespace: # pragma: no cover - unused by smoke tests
return SimpleNamespace(payload={"success": True, "recipe_id": recipe_id, "updates": updates}, status=200)
async def update_recipe(
self, *, recipe_scanner, recipe_id: str, updates: Dict[str, Any]
) -> SimpleNamespace: # pragma: no cover - unused by smoke tests
return SimpleNamespace(
payload={"success": True, "recipe_id": recipe_id, "updates": updates},
status=200,
)
async def reconnect_lora(self, *, recipe_scanner, recipe_id: str, lora_index: int, target_name: str) -> SimpleNamespace: # pragma: no cover
async def reconnect_lora(
self, *, recipe_scanner, recipe_id: str, lora_index: int, target_name: str
) -> SimpleNamespace: # pragma: no cover
return SimpleNamespace(payload={"success": True}, status=200)
async def bulk_delete(self, *, recipe_scanner, recipe_ids: List[str]) -> SimpleNamespace: # pragma: no cover
return SimpleNamespace(payload={"success": True, "deleted": recipe_ids}, status=200)
async def bulk_delete(
self, *, recipe_scanner, recipe_ids: List[str]
) -> SimpleNamespace: # pragma: no cover
return SimpleNamespace(
payload={"success": True, "deleted": recipe_ids}, status=200
)
async def save_recipe_from_widget(self, *, recipe_scanner, metadata: Dict[str, Any], image_bytes: bytes) -> SimpleNamespace: # pragma: no cover
async def save_recipe_from_widget(
self, *, recipe_scanner, metadata: Dict[str, Any], image_bytes: bytes
) -> SimpleNamespace: # pragma: no cover
return SimpleNamespace(payload={"success": True}, status=200)
@@ -176,7 +215,11 @@ class StubSharingService:
self.share_calls: List[str] = []
self.download_calls: List[str] = []
self.share_result = SimpleNamespace(
payload={"success": True, "download_url": "/share/stub", "filename": "recipe.png"},
payload={
"success": True,
"download_url": "/share/stub",
"filename": "recipe.png",
},
status=200,
)
self.download_info = SimpleNamespace(file_path="", download_filename="")
@@ -186,7 +229,9 @@ class StubSharingService:
self.share_calls.append(recipe_id)
return self.share_result
async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace:
async def prepare_download(
self, *, recipe_scanner, recipe_id: str
) -> SimpleNamespace:
self.download_calls.append(recipe_id)
return self.download_info
@@ -214,7 +259,9 @@ class StubCivitaiClient:
@asynccontextmanager
async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRouteHarness]:
async def recipe_harness(
monkeypatch, tmp_path: Path
) -> AsyncIterator[RecipeRouteHarness]:
"""Context manager that yields a fully wired recipe route harness."""
StubAnalysisService.instances.clear()
@@ -237,8 +284,12 @@ async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRou
monkeypatch.setattr(ServiceRegistry, "get_recipe_scanner", fake_get_recipe_scanner)
monkeypatch.setattr(ServiceRegistry, "get_civitai_client", fake_get_civitai_client)
monkeypatch.setattr(base_recipe_routes, "RecipeAnalysisService", StubAnalysisService)
monkeypatch.setattr(base_recipe_routes, "RecipePersistenceService", StubPersistenceService)
monkeypatch.setattr(
base_recipe_routes, "RecipeAnalysisService", StubAnalysisService
)
monkeypatch.setattr(
base_recipe_routes, "RecipePersistenceService", StubPersistenceService
)
monkeypatch.setattr(base_recipe_routes, "RecipeSharingService", StubSharingService)
monkeypatch.setattr(base_recipe_routes, "get_downloader", fake_get_downloader)
monkeypatch.setattr(config, "loras_roots", [str(tmp_path)], raising=False)
@@ -294,7 +345,9 @@ async def test_list_recipes_provides_file_urls(monkeypatch, tmp_path: Path) -> N
async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
form = FormData()
form.add_field("image", b"stub", filename="sample.png", content_type="image/png")
form.add_field(
"image", b"stub", filename="sample.png", content_type="image/png"
)
form.add_field("name", "Test Recipe")
form.add_field("tags", json.dumps(["tag-a"]))
form.add_field("metadata", json.dumps({"loras": []}))
@@ -312,7 +365,9 @@ async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) ->
assert save_payload["recipe_id"] == "saved-id"
assert harness.persistence.save_calls[-1]["name"] == "Test Recipe"
harness.persistence.delete_result = SimpleNamespace(payload={"success": True}, status=200)
harness.persistence.delete_result = SimpleNamespace(
payload={"success": True}, status=200
)
delete_response = await harness.client.delete("/api/lm/recipe/saved-id")
delete_payload = await delete_response.json()
@@ -326,14 +381,20 @@ async def test_move_recipe_invokes_persistence(monkeypatch, tmp_path: Path) -> N
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.post(
"/api/lm/recipe/move",
json={"recipe_id": "move-me", "target_path": str(tmp_path / "recipes" / "subdir")},
json={
"recipe_id": "move-me",
"target_path": str(tmp_path / "recipes" / "subdir"),
},
)
payload = await response.json()
assert response.status == 200
assert payload["recipe_id"] == "move-me"
assert harness.persistence.move_calls == [
{"recipe_id": "move-me", "target_path": str(tmp_path / "recipes" / "subdir")}
{
"recipe_id": "move-me",
"target_path": str(tmp_path / "recipes" / "subdir"),
}
]
@@ -348,7 +409,10 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
async def fake_get_default_metadata_provider():
return Provider()
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
monkeypatch.setattr(
"py.recipes.enrichment.get_default_metadata_provider",
fake_get_default_metadata_provider,
)
async with recipe_harness(monkeypatch, tmp_path) as harness:
resources = [
@@ -397,7 +461,9 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
assert harness.downloader.urls == ["https://example.com/images/1"]
async def test_import_remote_recipe_falls_back_to_request_base_model(monkeypatch, tmp_path: Path) -> None:
async def test_import_remote_recipe_falls_back_to_request_base_model(
monkeypatch, tmp_path: Path
) -> None:
provider_calls: list[str | int] = []
class Provider:
@@ -408,7 +474,10 @@ async def test_import_remote_recipe_falls_back_to_request_base_model(monkeypatch
async def fake_get_default_metadata_provider():
return Provider()
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
monkeypatch.setattr(
"py.recipes.enrichment.get_default_metadata_provider",
fake_get_default_metadata_provider,
)
async with recipe_harness(monkeypatch, tmp_path) as harness:
resources = [
@@ -444,13 +513,16 @@ async def test_import_remote_video_recipe(monkeypatch, tmp_path: Path) -> None:
async def fake_get_default_metadata_provider():
return SimpleNamespace(get_model_version_info=lambda id: ({}, None))
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
monkeypatch.setattr(
"py.recipes.enrichment.get_default_metadata_provider",
fake_get_default_metadata_provider,
)
async with recipe_harness(monkeypatch, tmp_path) as harness:
harness.civitai.image_info["12345"] = {
"id": 12345,
"url": "https://image.civitai.com/x/y/original=true/video.mp4",
"type": "video"
"type": "video",
}
response = await harness.client.get(
@@ -469,7 +541,7 @@ async def test_import_remote_video_recipe(monkeypatch, tmp_path: Path) -> None:
# Verify downloader was called with rewritten URL
assert "transcode=true" in harness.downloader.urls[0]
# Verify persistence was called with correct extension
call = harness.persistence.save_calls[-1]
assert call["extension"] == ".mp4"
@@ -477,7 +549,9 @@ async def test_import_remote_video_recipe(monkeypatch, tmp_path: Path) -> None:
async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided")
harness.analysis.raise_for_uploaded = RecipeValidationError(
"No image data provided"
)
form = FormData()
form.add_field("image", b"", filename="empty.png", content_type="image/png")
@@ -504,7 +578,11 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None:
}
harness.sharing.share_result = SimpleNamespace(
payload={"success": True, "download_url": "/api/share", "filename": "share.png"},
payload={
"success": True,
"download_url": "/api/share",
"filename": "share.png",
},
status=200,
)
harness.sharing.download_info = SimpleNamespace(
@@ -519,15 +597,24 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None:
assert share_payload["filename"] == "share.png"
assert harness.sharing.share_calls == [recipe_id]
download_response = await harness.client.get(f"/api/lm/recipe/{recipe_id}/share/download")
download_response = await harness.client.get(
f"/api/lm/recipe/{recipe_id}/share/download"
)
body = await download_response.read()
assert download_response.status == 200
assert download_response.headers["Content-Disposition"] == 'attachment; filename="share.png"'
assert (
download_response.headers["Content-Disposition"]
== 'attachment; filename="share.png"'
)
assert body == b"stub"
download_path.unlink(missing_ok=True)
async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path) -> None:
async def test_import_remote_recipe_merges_metadata(
monkeypatch, tmp_path: Path
) -> None:
# 1. Mock Metadata Provider
class Provider:
async def get_model_version_info(self, model_version_id):
@@ -536,22 +623,25 @@ async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path)
async def fake_get_default_metadata_provider():
return Provider()
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
monkeypatch.setattr(
"py.recipes.enrichment.get_default_metadata_provider",
fake_get_default_metadata_provider,
)
# 2. Mock ExifUtils to return some embedded metadata
class MockExifUtils:
@staticmethod
def extract_image_metadata(path):
return "Recipe metadata: " + json.dumps({
"gen_params": {"prompt": "from embedded", "seed": 123}
})
return "Recipe metadata: " + json.dumps(
{"gen_params": {"prompt": "from embedded", "seed": 123}}
)
monkeypatch.setattr(recipe_handlers, "ExifUtils", MockExifUtils)
# 3. Mock Parser Factory for StubAnalysisService
class MockParser:
async def parse_metadata(self, raw, recipe_scanner=None):
return json.loads(raw[len("Recipe metadata: "):])
return json.loads(raw[len("Recipe metadata: ") :])
class MockFactory:
def create_parser(self, raw):
@@ -562,12 +652,12 @@ async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path)
# 4. Setup Harness and run test
async with recipe_harness(monkeypatch, tmp_path) as harness:
harness.analysis._recipe_parser_factory = MockFactory()
# Civitai meta via image_info
harness.civitai.image_info["1"] = {
"id": 1,
"url": "https://example.com/images/1.jpg",
"meta": {"prompt": "from civitai", "cfg": 7.0}
"meta": {"prompt": "from civitai", "cfg": 7.0},
}
resources = []
@@ -583,11 +673,11 @@ async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path)
payload = await response.json()
assert response.status == 200
call = harness.persistence.save_calls[-1]
metadata = call["metadata"]
gen_params = metadata["gen_params"]
assert gen_params["seed"] == 123
@@ -619,3 +709,142 @@ async def test_get_recipe_syntax(monkeypatch, tmp_path: Path) -> None:
response_404 = await harness.client.get("/api/lm/recipe/non-existent/syntax")
assert response_404.status == 404
async def test_batch_import_start_success(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.post(
"/api/lm/recipes/batch-import/start",
json={
"items": [
{"source": "https://example.com/image1.png"},
{"source": "https://example.com/image2.png"},
],
"tags": ["batch", "import"],
"skip_no_metadata": True,
},
)
payload = await response.json()
assert response.status == 200
assert payload["success"] is True
assert "operation_id" in payload
async def test_batch_import_start_empty_items(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.post(
"/api/lm/recipes/batch-import/start",
json={"items": [], "tags": []},
)
payload = await response.json()
assert response.status == 400
assert payload["success"] is False
assert "No items provided" in payload["error"]
async def test_batch_import_start_missing_source(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.post(
"/api/lm/recipes/batch-import/start",
json={"items": [{"source": ""}]},
)
payload = await response.json()
assert response.status == 400
assert payload["success"] is False
assert "source" in payload["error"].lower()
async def test_batch_import_start_already_running(monkeypatch, tmp_path: Path) -> None:
import asyncio
async with recipe_harness(monkeypatch, tmp_path) as harness:
original_analyze = harness.analysis.analyze_remote_image
async def slow_analyze(*, url, recipe_scanner, civitai_client):
await asyncio.sleep(0.5)
return await original_analyze(
url=url, recipe_scanner=recipe_scanner, civitai_client=civitai_client
)
harness.analysis.analyze_remote_image = slow_analyze
items = [{"source": f"https://example.com/image{i}.png"} for i in range(10)]
response1 = await harness.client.post(
"/api/lm/recipes/batch-import/start",
json={"items": items},
)
assert response1.status == 200
payload1 = await response1.json()
assert payload1["success"] is True
await asyncio.sleep(0.1)
response2 = await harness.client.post(
"/api/lm/recipes/batch-import/start",
json={"items": [{"source": "https://example.com/other.png"}]},
)
payload2 = await response2.json()
assert response2.status == 409
assert "already in progress" in payload2["error"].lower()
async def test_batch_import_get_progress_not_found(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.get(
"/api/lm/recipes/batch-import/progress",
params={"operation_id": "nonexistent-id"},
)
payload = await response.json()
assert response.status == 404
assert payload["success"] is False
async def test_batch_import_get_progress_missing_id(
monkeypatch, tmp_path: Path
) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.get("/api/lm/recipes/batch-import/progress")
payload = await response.json()
assert response.status == 400
assert payload["success"] is False
async def test_batch_import_cancel_success(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
start_response = await harness.client.post(
"/api/lm/recipes/batch-import/start",
json={"items": [{"source": "https://example.com/image.png"}]},
)
start_payload = await start_response.json()
operation_id = start_payload["operation_id"]
cancel_response = await harness.client.post(
"/api/lm/recipes/batch-import/cancel",
json={"operation_id": operation_id},
)
cancel_payload = await cancel_response.json()
assert cancel_response.status == 200
assert cancel_payload["success"] is True
async def test_batch_import_cancel_not_found(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.post(
"/api/lm/recipes/batch-import/cancel",
json={"operation_id": "nonexistent-id"},
)
payload = await response.json()
assert response.status == 404
assert payload["success"] is False
async def test_batch_import_cancel_missing_id(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
response = await harness.client.post(
"/api/lm/recipes/batch-import/cancel",
json={},
)
payload = await response.json()
assert response.status == 400
assert payload["success"] is False

View File

@@ -0,0 +1,597 @@
"""Unit tests for BatchImportService."""
from __future__ import annotations
import asyncio
import logging
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from py.services.batch_import_service import (
AdaptiveConcurrencyController,
BatchImportItem,
BatchImportProgress,
BatchImportService,
ImportItemType,
ImportStatus,
)
class MockWebSocketManager:
def __init__(self):
self.broadcasts: List[Dict[str, Any]] = []
async def broadcast(self, data: Dict[str, Any]):
self.broadcasts.append(data)
@dataclass
class MockAnalysisResult:
payload: Dict[str, Any]
status: int = 200
class MockAnalysisService:
def __init__(self, results: Optional[Dict[str, MockAnalysisResult]] = None):
self.results = results or {}
self.call_count = 0
self.last_url = None
self.last_path = None
async def analyze_remote_image(self, *, url: str, recipe_scanner, civitai_client):
self.call_count += 1
self.last_url = url
if url in self.results:
return self.results[url]
return MockAnalysisResult({"error": "No metadata found", "loras": []})
async def analyze_local_image(self, *, file_path: str, recipe_scanner):
self.call_count += 1
self.last_path = file_path
if file_path in self.results:
return self.results[file_path]
return MockAnalysisResult({"error": "No metadata found", "loras": []})
@dataclass
class MockSaveResult:
payload: Dict[str, Any]
status: int = 200
class MockPersistenceService:
def __init__(self, should_succeed: bool = True):
self.should_succeed = should_succeed
self.saved_recipes: List[Dict[str, Any]] = []
self.call_count = 0
async def save_recipe(
self,
*,
recipe_scanner,
image_bytes: Optional[bytes] = None,
image_base64: Optional[str] = None,
name: str,
tags: List[str],
metadata: Dict[str, Any],
extension: Optional[str] = None,
):
self.call_count += 1
self.saved_recipes.append(
{
"name": name,
"tags": tags,
"metadata": metadata,
}
)
if self.should_succeed:
return MockSaveResult({"success": True, "id": f"recipe_{self.call_count}"})
return MockSaveResult({"success": False, "error": "Save failed"}, status=400)
class TestAdaptiveConcurrencyController:
def test_initial_values(self):
controller = AdaptiveConcurrencyController()
assert controller.current_concurrency == 3
assert controller.min_concurrency == 1
assert controller.max_concurrency == 5
def test_custom_initial_values(self):
controller = AdaptiveConcurrencyController(
min_concurrency=2,
max_concurrency=10,
initial_concurrency=5,
)
assert controller.current_concurrency == 5
assert controller.min_concurrency == 2
assert controller.max_concurrency == 10
def test_increase_concurrency_on_success(self):
controller = AdaptiveConcurrencyController(initial_concurrency=3)
controller.record_result(duration=0.5, success=True)
assert controller.current_concurrency == 4
def test_do_not_exceed_max(self):
controller = AdaptiveConcurrencyController(
max_concurrency=5,
initial_concurrency=5,
)
controller.record_result(duration=0.5, success=True)
assert controller.current_concurrency == 5
def test_decrease_concurrency_on_failure(self):
controller = AdaptiveConcurrencyController(initial_concurrency=3)
controller.record_result(duration=1.0, success=False)
assert controller.current_concurrency == 2
def test_do_not_go_below_min(self):
controller = AdaptiveConcurrencyController(
min_concurrency=1,
initial_concurrency=1,
)
controller.record_result(duration=1.0, success=False)
assert controller.current_concurrency == 1
def test_slow_task_decreases_concurrency(self):
controller = AdaptiveConcurrencyController(initial_concurrency=3)
controller.record_result(duration=11.0, success=True)
assert controller.current_concurrency == 2
def test_fast_task_increases_concurrency(self):
controller = AdaptiveConcurrencyController(initial_concurrency=3)
controller.record_result(duration=0.5, success=True)
assert controller.current_concurrency == 4
def test_moderate_task_no_change(self):
controller = AdaptiveConcurrencyController(initial_concurrency=3)
controller.record_result(duration=5.0, success=True)
assert controller.current_concurrency == 3
class TestBatchImportProgress:
def test_to_dict(self):
progress = BatchImportProgress(
operation_id="test-123",
total=10,
completed=5,
success=3,
failed=2,
skipped=0,
current_item="image.png",
status="running",
)
result = progress.to_dict()
assert result["operation_id"] == "test-123"
assert result["total"] == 10
assert result["completed"] == 5
assert result["success"] == 3
assert result["failed"] == 2
assert result["progress_percent"] == 50.0
def test_progress_percent_zero_total(self):
progress = BatchImportProgress(
operation_id="test-123",
total=0,
)
assert progress.to_dict()["progress_percent"] == 0
class TestBatchImportItem:
def test_defaults(self):
item = BatchImportItem(
id="item-1",
source="https://example.com/image.png",
item_type=ImportItemType.URL,
)
assert item.status == ImportStatus.PENDING
assert item.error_message is None
assert item.recipe_name is None
class TestBatchImportService:
@pytest.fixture
def mock_services(self):
ws_manager = MockWebSocketManager()
analysis_service = MockAnalysisService()
persistence_service = MockPersistenceService()
logger = logging.getLogger("test")
return ws_manager, analysis_service, persistence_service, logger
@pytest.fixture
def service(self, mock_services):
ws_manager, analysis_service, persistence_service, logger = mock_services
return BatchImportService(
analysis_service=analysis_service,
persistence_service=persistence_service,
ws_manager=ws_manager,
logger=logger,
)
def test_is_import_running_no_operations(self, service):
assert not service.is_import_running()
@pytest.mark.asyncio
async def test_start_batch_import_creates_operation(self, service):
recipe_scanner_getter = lambda: SimpleNamespace()
civitai_client_getter = lambda: SimpleNamespace()
operation_id = await service.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=[{"source": "https://example.com/image.png"}],
)
assert operation_id is not None
assert service.is_import_running(operation_id)
@pytest.mark.asyncio
async def test_get_progress(self, service):
recipe_scanner_getter = lambda: SimpleNamespace()
civitai_client_getter = lambda: SimpleNamespace()
operation_id = await service.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=[
{"source": "https://example.com/1.png"},
{"source": "https://example.com/2.png"},
],
)
progress = service.get_progress(operation_id)
assert progress is not None
assert progress.total == 2
assert progress.status in ("pending", "running")
@pytest.mark.asyncio
async def test_cancel_import(self, service):
recipe_scanner_getter = lambda: SimpleNamespace()
civitai_client_getter = lambda: SimpleNamespace()
operation_id = await service.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=[{"source": "https://example.com/image.png"}],
)
assert service.cancel_import(operation_id) is True
assert service.cancel_import("nonexistent") is False
@pytest.mark.asyncio
async def test_discover_images_non_recursive(self, service, tmp_path):
for i in range(3):
(tmp_path / f"image{i}.png").write_bytes(b"fake-image")
(tmp_path / "subdir").mkdir()
(tmp_path / "subdir" / "hidden.png").write_bytes(b"fake-image")
images = await service._discover_images(str(tmp_path), recursive=False)
assert len(images) == 3
@pytest.mark.asyncio
async def test_discover_images_recursive(self, service, tmp_path):
for i in range(2):
(tmp_path / f"image{i}.png").write_bytes(b"fake-image")
subdir = tmp_path / "subdir"
subdir.mkdir()
for i in range(2):
(subdir / f"nested{i}.jpg").write_bytes(b"fake-image")
images = await service._discover_images(str(tmp_path), recursive=True)
assert len(images) == 4
@pytest.mark.asyncio
async def test_discover_images_filters_by_extension(self, service, tmp_path):
(tmp_path / "image.png").write_bytes(b"fake-image")
(tmp_path / "image.jpg").write_bytes(b"fake-image")
(tmp_path / "image.webp").write_bytes(b"fake-image")
(tmp_path / "document.pdf").write_bytes(b"fake-doc")
(tmp_path / "script.py").write_bytes(b"print('hello')")
images = await service._discover_images(str(tmp_path), recursive=False)
assert len(images) == 3
@pytest.mark.asyncio
async def test_discover_images_invalid_directory(self, service):
from py.services.recipes.errors import RecipeValidationError
with pytest.raises(RecipeValidationError):
await service._discover_images("/nonexistent/path", recursive=False)
def test_is_supported_image(self, service):
assert service._is_supported_image("test.png") is True
assert service._is_supported_image("test.jpg") is True
assert service._is_supported_image("test.jpeg") is True
assert service._is_supported_image("test.webp") is True
assert service._is_supported_image("test.gif") is True
assert service._is_supported_image("test.bmp") is True
assert service._is_supported_image("test.pdf") is False
assert service._is_supported_image("test.txt") is False
@pytest.mark.asyncio
async def test_batch_import_processes_items(self, mock_services, tmp_path):
ws_manager, _, persistence_service, logger = mock_services
analysis_service = MockAnalysisService(
{
"https://example.com/valid.png": MockAnalysisResult(
{
"loras": [{"name": "test-lora", "weight": 1.0}],
"base_model": "SD1.5",
"gen_params": {"steps": 20},
}
),
}
)
service = BatchImportService(
analysis_service=analysis_service,
persistence_service=persistence_service,
ws_manager=ws_manager,
logger=logger,
)
recipe_scanner_getter = lambda: SimpleNamespace(
find_recipes_by_fingerprint=lambda x: [],
add_recipe=lambda x: None,
)
civitai_client_getter = lambda: SimpleNamespace()
operation_id = await service.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=[
{"source": "https://example.com/valid.png"},
{"source": "https://example.com/no-meta.png"},
],
skip_no_metadata=True,
)
await asyncio.sleep(0.5)
progress = service.get_progress(operation_id)
assert progress is not None or persistence_service.call_count == 1
@pytest.mark.asyncio
async def test_start_directory_import(self, service, tmp_path):
for i in range(5):
(tmp_path / f"image{i}.png").write_bytes(b"fake-image")
recipe_scanner_getter = lambda: SimpleNamespace()
civitai_client_getter = lambda: SimpleNamespace()
operation_id = await service.start_directory_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
directory=str(tmp_path),
recursive=False,
)
progress = service.get_progress(operation_id)
assert progress is not None
assert progress.total == 5
@pytest.mark.asyncio
async def test_websocket_broadcasts_progress(self, mock_services):
ws_manager, analysis_service, persistence_service, logger = mock_services
service = BatchImportService(
analysis_service=analysis_service,
persistence_service=persistence_service,
ws_manager=ws_manager,
logger=logger,
)
recipe_scanner_getter = lambda: SimpleNamespace()
civitai_client_getter = lambda: SimpleNamespace()
operation_id = await service.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=[{"source": "https://example.com/test.png"}],
)
await asyncio.sleep(0.3)
assert len(ws_manager.broadcasts) > 0
assert any(
b.get("type") == "batch_import_progress" for b in ws_manager.broadcasts
)
@pytest.mark.asyncio
async def test_cancellation_stops_processing(self, mock_services):
ws_manager, analysis_service, persistence_service, logger = mock_services
service = BatchImportService(
analysis_service=analysis_service,
persistence_service=persistence_service,
ws_manager=ws_manager,
logger=logger,
)
recipe_scanner_getter = lambda: SimpleNamespace()
civitai_client_getter = lambda: SimpleNamespace()
items = [{"source": f"https://example.com/{i}.png"} for i in range(10)]
operation_id = await service.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=items,
)
service.cancel_import(operation_id)
await asyncio.sleep(0.3)
progress = service.get_progress(operation_id)
if progress:
assert progress.status == "cancelled"
class TestBatchImportServiceEdgeCases:
@pytest.fixture
def service(self):
ws_manager = MockWebSocketManager()
analysis_service = MockAnalysisService()
persistence_service = MockPersistenceService()
logger = logging.getLogger("test")
return BatchImportService(
analysis_service=analysis_service,
persistence_service=persistence_service,
ws_manager=ws_manager,
logger=logger,
)
@pytest.mark.asyncio
async def test_empty_items_list(self, service):
recipe_scanner_getter = lambda: SimpleNamespace()
civitai_client_getter = lambda: SimpleNamespace()
operation_id = await service.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=[],
)
progress = service.get_progress(operation_id)
assert progress is not None
assert progress.total == 0
@pytest.mark.asyncio
async def test_mixed_url_and_path_items(self, service, tmp_path):
(tmp_path / "local.png").write_bytes(b"fake-image")
recipe_scanner_getter = lambda: SimpleNamespace()
civitai_client_getter = lambda: SimpleNamespace()
operation_id = await service.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=[
{"source": "https://example.com/remote.png", "type": "url"},
{"source": str(tmp_path / "local.png"), "type": "local_path"},
],
)
progress = service.get_progress(operation_id)
assert progress is not None
assert progress.total == 2
assert progress.items[0].item_type == ImportItemType.URL
assert progress.items[1].item_type == ImportItemType.LOCAL_PATH
@pytest.mark.asyncio
async def test_tags_are_passed_to_persistence(self, tmp_path):
ws_manager = MockWebSocketManager()
analysis_service = MockAnalysisService(
{
str(tmp_path / "test.png"): MockAnalysisResult(
{
"loras": [{"name": "test-lora"}],
}
),
}
)
persistence_service = MockPersistenceService()
logger = logging.getLogger("test")
(tmp_path / "test.png").write_bytes(b"fake-image")
service = BatchImportService(
analysis_service=analysis_service,
persistence_service=persistence_service,
ws_manager=ws_manager,
logger=logger,
)
recipe_scanner_getter = lambda: SimpleNamespace(
find_recipes_by_fingerprint=lambda x: [],
)
civitai_client_getter = lambda: SimpleNamespace()
operation_id = await service.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=[{"source": str(tmp_path / "test.png")}],
tags=["batch-import", "test"],
)
await asyncio.sleep(0.3)
if persistence_service.saved_recipes:
assert "batch-import" in persistence_service.saved_recipes[0]["tags"]
assert "test" in persistence_service.saved_recipes[0]["tags"]
@pytest.mark.asyncio
async def test_skip_duplicates_parameter(self, service):
recipe_scanner_getter = lambda: SimpleNamespace()
civitai_client_getter = lambda: SimpleNamespace()
operation_id = await service.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=[{"source": "https://example.com/test.png"}],
skip_duplicates=True,
)
progress = service.get_progress(operation_id)
assert progress is not None
assert progress.skip_duplicates is True
@pytest.mark.asyncio
async def test_skip_duplicates_false_by_default(self, service):
recipe_scanner_getter = lambda: SimpleNamespace()
civitai_client_getter = lambda: SimpleNamespace()
operation_id = await service.start_batch_import(
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
items=[{"source": "https://example.com/test.png"}],
)
progress = service.get_progress(operation_id)
assert progress is not None
assert progress.skip_duplicates is False
class TestInputValidation:
@pytest.fixture
def service(self):
ws_manager = MockWebSocketManager()
analysis_service = MockAnalysisService()
persistence_service = MockPersistenceService()
logger = logging.getLogger("test")
return BatchImportService(
analysis_service=analysis_service,
persistence_service=persistence_service,
ws_manager=ws_manager,
logger=logger,
)
def test_validate_valid_url(self, service):
assert service._validate_url("https://example.com/image.png") is True
assert service._validate_url("http://example.com/image.png") is True
assert service._validate_url("https://civitai.com/images/123") is True
def test_validate_invalid_url(self, service):
assert service._validate_url("not-a-url") is False
assert service._validate_url("ftp://example.com/file") is False
assert service._validate_url("") is False
def test_validate_valid_local_path(self, service, tmp_path):
valid_path = str(tmp_path / "image.png")
assert service._validate_local_path(valid_path) is True
def test_validate_invalid_local_path(self, service):
assert service._validate_local_path("../etc/passwd") is False
assert service._validate_local_path("relative/path.png") is False
assert service._validate_local_path("") is False