diff --git a/locales/en.json b/locales/en.json index bf085afd..2780fd34 100644 --- a/locales/en.json +++ b/locales/en.json @@ -729,6 +729,42 @@ "failed": "Failed to repair recipe: {message}", "missingId": "Cannot repair recipe: Missing recipe ID" } + }, + "batchImport": { + "title": "Batch Import Recipes", + "action": "Batch Import", + "urlsMode": "URLs / Paths", + "directoryMode": "Directory", + "urlsDescription": "Enter image URLs or local file paths (one per line). Each will be imported as a recipe.", + "urlsLabel": "Image URLs or Local Paths", + "urlsPlaceholder": "https://civitai.com/images/...\nhttps://civitai.com/images/...\nC:/path/to/image.png\n...", + "directoryDescription": "Enter a directory path to import all images from that folder.", + "directoryLabel": "Directory Path", + "directoryPlaceholder": "/path/to/images/folder", + "tagsOptional": "Tags (optional, applied to all recipes)", + "addTagPlaceholder": "Add a tag", + "addTag": "Add", + "noTags": "No tags added", + "skipNoMetadata": "Skip images without metadata", + "skipNoMetadataHelp": "Images without LoRA metadata will be skipped automatically.", + "startImport": "Start Import", + "importing": "Importing Recipes...", + "progress": "Progress", + "currentItem": "Current", + "preparing": "Preparing...", + "cancelImport": "Cancel", + "cancelled": "Batch import cancelled", + "completed": "Import Completed", + "completedSuccess": "Successfully imported {count} recipe(s)", + "successCount": "Successful", + "failedCount": "Failed", + "skippedCount": "Skipped", + "totalProcessed": "Total processed", + "errors": { + "enterUrls": "Please enter at least one URL or path", + "enterDirectory": "Please enter a directory path", + "startFailed": "Failed to start import: {message}" + } } }, "checkpoints": { diff --git a/locales/zh-CN.json b/locales/zh-CN.json index fb2722e3..3ea49e1c 100644 --- a/locales/zh-CN.json +++ b/locales/zh-CN.json @@ -722,13 +722,49 @@ "getInfoFailed": "获取缺失 LoRA 信息失败", "prepareError": 
"准备下载 LoRA 时出错:{message}" }, - "repair": { +"repair": { "starting": "正在修复配方元数据...", "success": "配方元数据修复成功", "skipped": "配方已是最新版本,无需修复", "failed": "修复配方失败:{message}", "missingId": "无法修复配方:缺少配方 ID" } + }, + "batchImport": { + "title": "批量导入配方", + "action": "批量导入", + "urlsMode": "URL / 路径", + "directoryMode": "目录", + "urlsDescription": "输入图片 URL 或本地文件路径(每行一个)。每个将作为配方导入。", + "urlsLabel": "图片 URL 或本地路径", + "urlsPlaceholder": "https://civitai.com/images/...\nhttps://civitai.com/images/...\nC:/path/to/image.png\n...", + "directoryDescription": "输入目录路径以导入该文件夹中的所有图片。", + "directoryLabel": "目录路径", + "directoryPlaceholder": "/图片/文件夹/路径", + "tagsOptional": "标签(可选,应用于所有配方)", + "addTagPlaceholder": "添加标签", + "addTag": "添加", + "noTags": "未添加标签", + "skipNoMetadata": "跳过无元数据的图片", + "skipNoMetadataHelp": "没有 LoRA 元数据的图片将自动跳过。", + "startImport": "开始导入", + "importing": "正在导入配方...", + "progress": "进度", + "currentItem": "当前", + "preparing": "准备中...", + "cancelImport": "取消", + "cancelled": "批量导入已取消", + "completed": "导入完成", + "completedSuccess": "成功导入 {count} 个配方", + "successCount": "成功", + "failedCount": "失败", + "skippedCount": "跳过", + "totalProcessed": "总计处理", + "errors": { + "enterUrls": "请至少输入一个 URL 或路径", + "enterDirectory": "请输入目录路径", + "startFailed": "启动导入失败:{message}" + } } }, "checkpoints": { diff --git a/py/routes/base_recipe_routes.py b/py/routes/base_recipe_routes.py index 162f3491..1b7eaf12 100644 --- a/py/routes/base_recipe_routes.py +++ b/py/routes/base_recipe_routes.py @@ -1,4 +1,5 @@ """Base infrastructure shared across recipe routes.""" + from __future__ import annotations import logging @@ -16,12 +17,14 @@ from ..services.recipes import ( RecipePersistenceService, RecipeSharingService, ) +from ..services.batch_import_service import BatchImportService from ..services.server_i18n import server_i18n from ..services.service_registry import ServiceRegistry from ..services.settings_manager import get_settings_manager from ..utils.constants import CARD_PREVIEW_WIDTH from 
..utils.exif_utils import ExifUtils from .handlers.recipe_handlers import ( + BatchImportHandler, RecipeAnalysisHandler, RecipeHandlerSet, RecipeListingHandler, @@ -116,7 +119,10 @@ class BaseRecipeRoutes: recipe_scanner_getter = lambda: self.recipe_scanner civitai_client_getter = lambda: self.civitai_client - standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" + standalone_mode = ( + os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" + or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" + ) if not standalone_mode: from ..metadata_collector import get_metadata # type: ignore[import-not-found] from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found] @@ -190,6 +196,22 @@ class BaseRecipeRoutes: sharing_service=sharing_service, ) + from ..services.websocket_manager import ws_manager + + batch_import_service = BatchImportService( + analysis_service=analysis_service, + persistence_service=persistence_service, + ws_manager=ws_manager, + logger=logger, + ) + batch_import = BatchImportHandler( + ensure_dependencies_ready=self.ensure_dependencies_ready, + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + logger=logger, + batch_import_service=batch_import_service, + ) + return RecipeHandlerSet( page_view=page_view, listing=listing, @@ -197,4 +219,5 @@ class BaseRecipeRoutes: management=management, analysis=analysis, sharing=sharing, + batch_import=batch_import, ) diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index 0ffe211c..021b9fc3 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -1,4 +1,5 @@ """Dedicated handler objects for recipe-related routes.""" + from __future__ import annotations import json @@ -29,6 +30,7 @@ from ...utils.exif_utils import ExifUtils from ...recipes.merger import GenParamsMerger from 
...recipes.enrichment import RecipeEnricher from ...services.websocket_manager import ws_manager as default_ws_manager +from ...services.batch_import_service import BatchImportService Logger = logging.Logger EnsureDependenciesCallable = Callable[[], Awaitable[None]] @@ -46,8 +48,11 @@ class RecipeHandlerSet: management: "RecipeManagementHandler" analysis: "RecipeAnalysisHandler" sharing: "RecipeSharingHandler" + batch_import: "BatchImportHandler" - def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]: + def to_route_mapping( + self, + ) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]: """Expose handler coroutines keyed by registrar handler names.""" return { @@ -81,6 +86,9 @@ class RecipeHandlerSet: "cancel_repair": self.management.cancel_repair, "repair_recipe": self.management.repair_recipe, "get_repair_progress": self.management.get_repair_progress, + "start_batch_import": self.batch_import.start_batch_import, + "get_batch_import_progress": self.batch_import.get_batch_import_progress, + "cancel_batch_import": self.batch_import.cancel_batch_import, } @@ -170,8 +178,10 @@ class RecipeListingHandler: search_options = { "title": request.query.get("search_title", "true").lower() == "true", "tags": request.query.get("search_tags", "true").lower() == "true", - "lora_name": request.query.get("search_lora_name", "true").lower() == "true", - "lora_model": request.query.get("search_lora_model", "true").lower() == "true", + "lora_name": request.query.get("search_lora_name", "true").lower() + == "true", + "lora_model": request.query.get("search_lora_model", "true").lower() + == "true", "prompt": request.query.get("search_prompt", "true").lower() == "true", } @@ -246,7 +256,9 @@ class RecipeListingHandler: return web.json_response({"error": "Recipe not found"}, status=404) return web.json_response(recipe) except Exception as exc: - self._logger.error("Error retrieving recipe details: %s", exc, 
exc_info=True) + self._logger.error( + "Error retrieving recipe details: %s", exc, exc_info=True + ) return web.json_response({"error": str(exc)}, status=500) def format_recipe_file_url(self, file_path: str) -> str: @@ -256,7 +268,9 @@ class RecipeListingHandler: if static_url: return static_url except Exception as exc: # pragma: no cover - logging path - self._logger.error("Error formatting recipe file URL: %s", exc, exc_info=True) + self._logger.error( + "Error formatting recipe file URL: %s", exc, exc_info=True + ) return "/loras_static/images/no-preview.png" return "/loras_static/images/no-preview.png" @@ -293,7 +307,9 @@ class RecipeQueryHandler: for tag in recipe.get("tags", []) or []: tag_counts[tag] = tag_counts.get(tag, 0) + 1 - sorted_tags = [{"tag": tag, "count": count} for tag, count in tag_counts.items()] + sorted_tags = [ + {"tag": tag, "count": count} for tag, count in tag_counts.items() + ] sorted_tags.sort(key=lambda entry: entry["count"], reverse=True) return web.json_response({"success": True, "tags": sorted_tags[:limit]}) except Exception as exc: @@ -313,9 +329,14 @@ class RecipeQueryHandler: for recipe in getattr(cache, "raw_data", []): base_model = recipe.get("base_model") if base_model: - base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 + base_model_counts[base_model] = ( + base_model_counts.get(base_model, 0) + 1 + ) - sorted_models = [{"name": model, "count": count} for model, count in base_model_counts.items()] + sorted_models = [ + {"name": model, "count": count} + for model, count in base_model_counts.items() + ] sorted_models.sort(key=lambda entry: entry["count"], reverse=True) return web.json_response({"success": True, "base_models": sorted_models}) except Exception as exc: @@ -345,7 +366,9 @@ class RecipeQueryHandler: folders = await recipe_scanner.get_folders() return web.json_response({"success": True, "folders": folders}) except Exception as exc: - self._logger.error("Error retrieving recipe folders: %s", 
exc, exc_info=True) + self._logger.error( + "Error retrieving recipe folders: %s", exc, exc_info=True + ) return web.json_response({"success": False, "error": str(exc)}, status=500) async def get_folder_tree(self, request: web.Request) -> web.Response: @@ -358,7 +381,9 @@ class RecipeQueryHandler: folder_tree = await recipe_scanner.get_folder_tree() return web.json_response({"success": True, "tree": folder_tree}) except Exception as exc: - self._logger.error("Error retrieving recipe folder tree: %s", exc, exc_info=True) + self._logger.error( + "Error retrieving recipe folder tree: %s", exc, exc_info=True + ) return web.json_response({"success": False, "error": str(exc)}, status=500) async def get_unified_folder_tree(self, request: web.Request) -> web.Response: @@ -371,7 +396,9 @@ class RecipeQueryHandler: folder_tree = await recipe_scanner.get_folder_tree() return web.json_response({"success": True, "tree": folder_tree}) except Exception as exc: - self._logger.error("Error retrieving unified recipe folder tree: %s", exc, exc_info=True) + self._logger.error( + "Error retrieving unified recipe folder tree: %s", exc, exc_info=True + ) return web.json_response({"success": False, "error": str(exc)}, status=500) async def get_recipes_for_lora(self, request: web.Request) -> web.Response: @@ -383,7 +410,9 @@ class RecipeQueryHandler: lora_hash = request.query.get("hash") if not lora_hash: - return web.json_response({"success": False, "error": "Lora hash is required"}, status=400) + return web.json_response( + {"success": False, "error": "Lora hash is required"}, status=400 + ) matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash) return web.json_response({"success": True, "recipes": matching_recipes}) @@ -400,7 +429,9 @@ class RecipeQueryHandler: self._logger.info("Manually triggering recipe cache rebuild") await recipe_scanner.get_cached_data(force_refresh=True) - return web.json_response({"success": True, "message": "Recipe cache refreshed 
successfully"}) + return web.json_response( + {"success": True, "message": "Recipe cache refreshed successfully"} + ) except Exception as exc: self._logger.error("Error refreshing recipe cache: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) @@ -429,7 +460,9 @@ class RecipeQueryHandler: "id": recipe.get("id"), "title": recipe.get("title"), "file_url": recipe.get("file_url") - or self._format_recipe_file_url(recipe.get("file_path", "")), + or self._format_recipe_file_url( + recipe.get("file_path", "") + ), "modified": recipe.get("modified"), "created_date": recipe.get("created_date"), "lora_count": len(recipe.get("loras", [])), @@ -437,7 +470,9 @@ class RecipeQueryHandler: ) if len(recipes) >= 2: - recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True) + recipes.sort( + key=lambda entry: entry.get("modified", 0), reverse=True + ) response_data.append( { "type": "fingerprint", @@ -460,7 +495,9 @@ class RecipeQueryHandler: "id": recipe.get("id"), "title": recipe.get("title"), "file_url": recipe.get("file_url") - or self._format_recipe_file_url(recipe.get("file_path", "")), + or self._format_recipe_file_url( + recipe.get("file_path", "") + ), "modified": recipe.get("modified"), "created_date": recipe.get("created_date"), "lora_count": len(recipe.get("loras", [])), @@ -468,7 +505,9 @@ class RecipeQueryHandler: ) if len(recipes) >= 2: - recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True) + recipes.sort( + key=lambda entry: entry.get("modified", 0), reverse=True + ) response_data.append( { "type": "source_url", @@ -479,9 +518,13 @@ class RecipeQueryHandler: ) response_data.sort(key=lambda entry: entry["count"], reverse=True) - return web.json_response({"success": True, "duplicate_groups": response_data}) + return web.json_response( + {"success": True, "duplicate_groups": response_data} + ) except Exception as exc: - self._logger.error("Error finding duplicate recipes: %s", exc, 
exc_info=True) + self._logger.error( + "Error finding duplicate recipes: %s", exc, exc_info=True + ) return web.json_response({"success": False, "error": str(exc)}, status=500) async def get_recipe_syntax(self, request: web.Request) -> web.Response: @@ -498,9 +541,13 @@ class RecipeQueryHandler: return web.json_response({"error": "Recipe not found"}, status=404) if not syntax_parts: - return web.json_response({"error": "No LoRAs found in this recipe"}, status=400) + return web.json_response( + {"error": "No LoRAs found in this recipe"}, status=400 + ) - return web.json_response({"success": True, "syntax": " ".join(syntax_parts)}) + return web.json_response( + {"success": True, "syntax": " ".join(syntax_parts)} + ) except Exception as exc: self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) @@ -561,11 +608,17 @@ class RecipeManagementHandler: await self._ensure_dependencies_ready() recipe_scanner = self._recipe_scanner_getter() if recipe_scanner is None: - return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503) + return web.json_response( + {"success": False, "error": "Recipe scanner unavailable"}, + status=503, + ) # Check if already running if self._ws_manager.is_recipe_repair_running(): - return web.json_response({"success": False, "error": "Recipe repair already in progress"}, status=409) + return web.json_response( + {"success": False, "error": "Recipe repair already in progress"}, + status=409, + ) recipe_scanner.reset_cancellation() @@ -579,11 +632,12 @@ class RecipeManagementHandler: progress_callback=progress_callback ) except Exception as e: - self._logger.error(f"Error in recipe repair task: {e}", exc_info=True) - await self._ws_manager.broadcast_recipe_repair_progress({ - "status": "error", - "error": str(e) - }) + self._logger.error( + f"Error in recipe repair task: {e}", exc_info=True + ) + await 
self._ws_manager.broadcast_recipe_repair_progress( + {"status": "error", "error": str(e)} + ) finally: # Keep the final status for a while so the UI can see it await asyncio.sleep(5) @@ -593,7 +647,9 @@ class RecipeManagementHandler: asyncio.create_task(run_repair()) - return web.json_response({"success": True, "message": "Recipe repair started"}) + return web.json_response( + {"success": True, "message": "Recipe repair started"} + ) except Exception as exc: self._logger.error("Error starting recipe repair: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) @@ -603,10 +659,15 @@ class RecipeManagementHandler: await self._ensure_dependencies_ready() recipe_scanner = self._recipe_scanner_getter() if recipe_scanner is None: - return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503) + return web.json_response( + {"success": False, "error": "Recipe scanner unavailable"}, + status=503, + ) recipe_scanner.cancel_task() - return web.json_response({"success": True, "message": "Cancellation requested"}) + return web.json_response( + {"success": True, "message": "Cancellation requested"} + ) except Exception as exc: self._logger.error("Error cancelling recipe repair: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) @@ -616,7 +677,10 @@ class RecipeManagementHandler: await self._ensure_dependencies_ready() recipe_scanner = self._recipe_scanner_getter() if recipe_scanner is None: - return web.json_response({"success": False, "error": "Recipe scanner unavailable"}, status=503) + return web.json_response( + {"success": False, "error": "Recipe scanner unavailable"}, + status=503, + ) recipe_id = request.match_info["recipe_id"] result = await recipe_scanner.repair_recipe_by_id(recipe_id) @@ -632,25 +696,26 @@ class RecipeManagementHandler: progress = self._ws_manager.get_recipe_repair_progress() if progress: return 
web.json_response({"success": True, "progress": progress}) - return web.json_response({"success": False, "message": "No repair in progress"}, status=404) + return web.json_response( + {"success": False, "message": "No repair in progress"}, status=404 + ) except Exception as exc: self._logger.error("Error getting repair progress: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) - async def import_remote_recipe(self, request: web.Request) -> web.Response: try: await self._ensure_dependencies_ready() recipe_scanner = self._recipe_scanner_getter() if recipe_scanner is None: raise RuntimeError("Recipe scanner unavailable") - + # 1. Parse Parameters params = request.rel_url.query image_url = params.get("image_url") name = params.get("name") resources_raw = params.get("resources") - + if not image_url: raise RecipeValidationError("Missing required field: image_url") if not name: @@ -658,64 +723,80 @@ class RecipeManagementHandler: if not resources_raw: raise RecipeValidationError("Missing required field: resources") - checkpoint_entry, lora_entries = self._parse_resources_payload(resources_raw) + checkpoint_entry, lora_entries = self._parse_resources_payload( + resources_raw + ) gen_params_request = self._parse_gen_params(params.get("gen_params")) - + # 2. Initial Metadata Construction metadata: Dict[str, Any] = { "base_model": params.get("base_model", "") or "", "loras": lora_entries, "gen_params": gen_params_request or {}, - "source_url": image_url + "source_url": image_url, } - + source_path = params.get("source_path") if source_path: metadata["source_path"] = source_path - + # Checkpoint handling if checkpoint_entry: metadata["checkpoint"] = checkpoint_entry # Ensure checkpoint is also in gen_params for consistency if needed by enricher? # Actually enricher looks at metadata['checkpoint'], so this is fine. 
- + # Try to resolve base model from checkpoint if not explicitly provided if not metadata["base_model"]: - base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry) + base_model_from_metadata = ( + await self._resolve_base_model_from_checkpoint(checkpoint_entry) + ) if base_model_from_metadata: metadata["base_model"] = base_model_from_metadata tags = self._parse_tags(params.get("tags")) - + # 3. Download Image - image_bytes, extension, civitai_meta_from_download = await self._download_remote_media(image_url) + ( + image_bytes, + extension, + civitai_meta_from_download, + ) = await self._download_remote_media(image_url) # 4. Extract Embedded Metadata - # Note: We still extract this here because Enricher currently expects 'gen_params' to already be populated - # with embedded data if we want it to merge it. + # Note: We still extract this here because Enricher currently expects 'gen_params' to already be populated + # with embedded data if we want it to merge it. # However, logic in Enricher merges: request > civitai > embedded. - # So we should gather embedded params and put them into the recipe's gen_params (as initial state) + # So we should gather embedded params and put them into the recipe's gen_params (as initial state) # OR pass them to enricher to handle? # The interface of Enricher.enrich_recipe takes `recipe` (with gen_params) and `request_params`. # So let's extract embedded and put it into recipe['gen_params'] but careful not to overwrite request params. # Actually, `GenParamsMerger` which `Enricher` uses handles 3 layers. # But `Enricher` interface is: recipe['gen_params'] (as embedded) + request_params + civitai (fetched internally). - # Wait, `Enricher` fetches Civitai info internally based on URL. + # Wait, `Enricher` fetches Civitai info internally based on URL. # `civitai_meta_from_download` is returned by `_download_remote_media` which might be useful if URL didn't have ID. 
- + # Let's extract embedded metadata first embedded_gen_params = {} try: - with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_img: + with tempfile.NamedTemporaryFile( + suffix=extension, delete=False + ) as temp_img: temp_img.write(image_bytes) temp_img_path = temp_img.name - + try: raw_embedded = ExifUtils.extract_image_metadata(temp_img_path) if raw_embedded: - parser = self._analysis_service._recipe_parser_factory.create_parser(raw_embedded) + parser = ( + self._analysis_service._recipe_parser_factory.create_parser( + raw_embedded + ) + ) if parser: - parsed_embedded = await parser.parse_metadata(raw_embedded, recipe_scanner=recipe_scanner) + parsed_embedded = await parser.parse_metadata( + raw_embedded, recipe_scanner=recipe_scanner + ) if parsed_embedded and "gen_params" in parsed_embedded: embedded_gen_params = parsed_embedded["gen_params"] else: @@ -724,7 +805,9 @@ class RecipeManagementHandler: if os.path.exists(temp_img_path): os.unlink(temp_img_path) except Exception as exc: - self._logger.warning("Failed to extract embedded metadata during import: %s", exc) + self._logger.warning( + "Failed to extract embedded metadata during import: %s", exc + ) # Pre-populate gen_params with embedded data so Enricher treats it as the "base" layer if embedded_gen_params: @@ -732,18 +815,18 @@ class RecipeManagementHandler: # But wait, we want request params to override everything. # So we should set recipe['gen_params'] = embedded, and pass request params to enricher. metadata["gen_params"] = embedded_gen_params - + # 5. 
Enrich with unified logic # This will fetch Civitai info (if URL matches) and merge: request > civitai > embedded civitai_client = self._civitai_client_getter() await RecipeEnricher.enrich_recipe( - recipe=metadata, + recipe=metadata, civitai_client=civitai_client, - request_params=gen_params_request # Pass explicit request params here to override + request_params=gen_params_request, # Pass explicit request params here to override ) - + # If we got civitai_meta from download but Enricher didn't fetch it (e.g. not a civitai URL or failed), - # we might want to manually merge it? + # we might want to manually merge it? # But usually `import_remote_recipe` is used with Civitai URLs. # For now, relying on Enricher's internal fetch is consistent with repair. @@ -762,7 +845,9 @@ class RecipeManagementHandler: except RecipeDownloadError as exc: return web.json_response({"error": str(exc)}, status=400) except Exception as exc: - self._logger.error("Error importing recipe from remote source: %s", exc, exc_info=True) + self._logger.error( + "Error importing recipe from remote source: %s", exc, exc_info=True + ) return web.json_response({"error": str(exc)}, status=500) async def delete_recipe(self, request: web.Request) -> web.Response: @@ -816,7 +901,11 @@ class RecipeManagementHandler: target_path = data.get("target_path") if not recipe_id or not target_path: return web.json_response( - {"success": False, "error": "recipe_id and target_path are required"}, status=400 + { + "success": False, + "error": "recipe_id and target_path are required", + }, + status=400, ) result = await self._persistence_service.move_recipe( @@ -845,7 +934,11 @@ class RecipeManagementHandler: target_path = data.get("target_path") if not recipe_ids or not target_path: return web.json_response( - {"success": False, "error": "recipe_ids and target_path are required"}, status=400 + { + "success": False, + "error": "recipe_ids and target_path are required", + }, + status=400, ) result = await 
self._persistence_service.move_recipes_bulk( @@ -934,7 +1027,9 @@ class RecipeManagementHandler: except RecipeValidationError as exc: return web.json_response({"error": str(exc)}, status=400) except Exception as exc: - self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True) + self._logger.error( + "Error saving recipe from widget: %s", exc, exc_info=True + ) return web.json_response({"error": str(exc)}, status=500) async def _parse_save_payload(self, reader) -> dict[str, Any]: @@ -1006,7 +1101,9 @@ class RecipeManagementHandler: raise RecipeValidationError("gen_params payload must be an object") return parsed - def _parse_resources_payload(self, payload_raw: str) -> tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]: + def _parse_resources_payload( + self, payload_raw: str + ) -> tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]: try: payload = json.loads(payload_raw) except json.JSONDecodeError as exc: @@ -1066,15 +1163,19 @@ class RecipeManagementHandler: civitai_match = re.match(r"https://civitai\.com/images/(\d+)", image_url) if civitai_match: if civitai_client is None: - raise RecipeDownloadError("Civitai client unavailable for image download") + raise RecipeDownloadError( + "Civitai client unavailable for image download" + ) image_info = await civitai_client.get_image_info(civitai_match.group(1)) if not image_info: - raise RecipeDownloadError("Failed to fetch image information from Civitai") - + raise RecipeDownloadError( + "Failed to fetch image information from Civitai" + ) + media_url = image_info.get("url") if not media_url: raise RecipeDownloadError("No image URL found in Civitai response") - + # Use optimized preview URLs if possible media_type = image_info.get("type") rewritten_url, _ = rewrite_preview_url(media_url, media_type=media_type) @@ -1083,18 +1184,24 @@ class RecipeManagementHandler: else: download_url = media_url - success, result = await downloader.download_file(download_url, temp_path, use_auth=False) + 
success, result = await downloader.download_file( + download_url, temp_path, use_auth=False + ) if not success: raise RecipeDownloadError(f"Failed to download image: {result}") - + # Extract extension from URL - url_path = download_url.split('?')[0].split('#')[0] + url_path = download_url.split("?")[0].split("#")[0] extension = os.path.splitext(url_path)[1].lower() if not extension: - extension = ".webp" # Default to webp if unknown + extension = ".webp" # Default to webp if unknown with open(temp_path, "rb") as file_obj: - return file_obj.read(), extension, image_info.get("meta") if civitai_match and image_info else None + return ( + file_obj.read(), + extension, + image_info.get("meta") if civitai_match and image_info else None, + ) except RecipeDownloadError: raise except RecipeValidationError: @@ -1108,14 +1215,15 @@ class RecipeManagementHandler: except FileNotFoundError: pass - def _safe_int(self, value: Any) -> int: try: return int(value) except (TypeError, ValueError): return 0 - async def _resolve_base_model_from_checkpoint(self, checkpoint_entry: Dict[str, Any]) -> str: + async def _resolve_base_model_from_checkpoint( + self, checkpoint_entry: Dict[str, Any] + ) -> str: version_id = self._safe_int(checkpoint_entry.get("modelVersionId")) if not version_id: @@ -1134,7 +1242,9 @@ class RecipeManagementHandler: base_model = version_info.get("baseModel") or "" return str(base_model) if base_model is not None else "" except Exception as exc: # pragma: no cover - defensive logging - self._logger.warning("Failed to resolve base model from checkpoint metadata: %s", exc) + self._logger.warning( + "Failed to resolve base model from checkpoint metadata: %s", exc + ) return "" @@ -1279,5 +1389,178 @@ class RecipeSharingHandler: except RecipeNotFoundError as exc: return web.json_response({"error": str(exc)}, status=404) except Exception as exc: - self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True) + self._logger.error( + "Error downloading 
shared recipe: %s", exc, exc_info=True + ) return web.json_response({"error": str(exc)}, status=500) + + +class BatchImportHandler: + """Handle batch import operations for recipes.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + recipe_scanner_getter: RecipeScannerGetter, + civitai_client_getter: CivitaiClientGetter, + logger: Logger, + batch_import_service: BatchImportService, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._recipe_scanner_getter = recipe_scanner_getter + self._civitai_client_getter = civitai_client_getter + self._logger = logger + self._batch_import_service = batch_import_service + + async def start_batch_import(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + + if self._batch_import_service.is_import_running(): + return web.json_response( + {"success": False, "error": "Batch import already in progress"}, + status=409, + ) + + data = await request.json() + items = data.get("items", []) + tags = data.get("tags", []) + skip_no_metadata = data.get("skip_no_metadata", True) + + if not items: + return web.json_response( + {"success": False, "error": "No items provided"}, + status=400, + ) + + for item in items: + if not item.get("source"): + return web.json_response( + { + "success": False, + "error": "Each item must have a 'source' field", + }, + status=400, + ) + + operation_id = await self._batch_import_service.start_batch_import( + recipe_scanner_getter=self._recipe_scanner_getter, + civitai_client_getter=self._civitai_client_getter, + items=items, + tags=tags, + skip_no_metadata=skip_no_metadata, + ) + + return web.json_response( + { + "success": True, + "operation_id": operation_id, + } + ) + except RecipeValidationError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error starting batch import: %s", exc, exc_info=True) + return 
web.json_response({"success": False, "error": str(exc)}, status=500) + + async def start_directory_import(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + + if self._batch_import_service.is_import_running(): + return web.json_response( + {"success": False, "error": "Batch import already in progress"}, + status=409, + ) + + data = await request.json() + directory = data.get("directory") + recursive = data.get("recursive", True) + tags = data.get("tags", []) + skip_no_metadata = data.get("skip_no_metadata", True) + + if not directory: + return web.json_response( + {"success": False, "error": "Directory path is required"}, + status=400, + ) + + operation_id = await self._batch_import_service.start_directory_import( + recipe_scanner_getter=self._recipe_scanner_getter, + civitai_client_getter=self._civitai_client_getter, + directory=directory, + recursive=recursive, + tags=tags, + skip_no_metadata=skip_no_metadata, + ) + + return web.json_response( + { + "success": True, + "operation_id": operation_id, + } + ) + except RecipeValidationError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + self._logger.error( + "Error starting directory import: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_batch_import_progress(self, request: web.Request) -> web.Response: + try: + operation_id = request.query.get("operation_id") + if not operation_id: + return web.json_response( + {"success": False, "error": "operation_id is required"}, + status=400, + ) + + progress = self._batch_import_service.get_progress(operation_id) + if not progress: + return web.json_response( + {"success": False, "error": "Operation not found"}, + status=404, + ) + + return web.json_response( + { + "success": True, + "progress": progress.to_dict(), + } + ) + except Exception as exc: + self._logger.error( + "Error getting 
batch import progress: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def cancel_batch_import(self, request: web.Request) -> web.Response: + try: + data = await request.json() + operation_id = data.get("operation_id") + + if not operation_id: + return web.json_response( + {"success": False, "error": "operation_id is required"}, + status=400, + ) + + cancelled = self._batch_import_service.cancel_import(operation_id) + if not cancelled: + return web.json_response( + { + "success": False, + "error": "Operation not found or already completed", + }, + status=404, + ) + + return web.json_response( + {"success": True, "message": "Cancellation requested"} + ) + except Exception as exc: + self._logger.error("Error cancelling batch import: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) diff --git a/py/routes/recipe_route_registrar.py b/py/routes/recipe_route_registrar.py index 819b4331..aa098687 100644 --- a/py/routes/recipe_route_registrar.py +++ b/py/routes/recipe_route_registrar.py @@ -1,4 +1,5 @@ """Route registrar for recipe endpoints.""" + from __future__ import annotations from dataclasses import dataclass @@ -22,7 +23,9 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] 
= ( RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"), RouteDefinition("GET", "/api/lm/recipes/import-remote", "import_remote_recipe"), RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"), - RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"), + RouteDefinition( + "POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image" + ), RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"), RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"), RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"), @@ -30,9 +33,13 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("GET", "/api/lm/recipes/roots", "get_roots"), RouteDefinition("GET", "/api/lm/recipes/folders", "get_folders"), RouteDefinition("GET", "/api/lm/recipes/folder-tree", "get_folder_tree"), - RouteDefinition("GET", "/api/lm/recipes/unified-folder-tree", "get_unified_folder_tree"), + RouteDefinition( + "GET", "/api/lm/recipes/unified-folder-tree", "get_unified_folder_tree" + ), RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"), - RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"), + RouteDefinition( + "GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe" + ), RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"), RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"), RouteDefinition("POST", "/api/lm/recipe/move", "move_recipe"), @@ -40,13 +47,22 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] 
= ( RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"), RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"), RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"), - RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"), + RouteDefinition( + "POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget" + ), RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"), RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"), RouteDefinition("POST", "/api/lm/recipes/repair", "repair_recipes"), RouteDefinition("POST", "/api/lm/recipes/cancel-repair", "cancel_repair"), RouteDefinition("POST", "/api/lm/recipe/{recipe_id}/repair", "repair_recipe"), RouteDefinition("GET", "/api/lm/recipes/repair-progress", "get_repair_progress"), + RouteDefinition("POST", "/api/lm/recipes/batch-import/start", "start_batch_import"), + RouteDefinition( + "GET", "/api/lm/recipes/batch-import/progress", "get_batch_import_progress" + ), + RouteDefinition( + "POST", "/api/lm/recipes/batch-import/cancel", "cancel_batch_import" + ), ) @@ -63,7 +79,9 @@ class RecipeRouteRegistrar: def __init__(self, app: web.Application) -> None: self._app = app - def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None: + def register_routes( + self, handler_lookup: Mapping[str, Callable[[web.Request], object]] + ) -> None: for definition in ROUTE_DEFINITIONS: handler = handler_lookup[definition.handler_name] self._bind_route(definition.method, definition.path, handler) diff --git a/py/services/batch_import_service.py b/py/services/batch_import_service.py new file mode 100644 index 00000000..62737ad9 --- /dev/null +++ b/py/services/batch_import_service.py @@ -0,0 +1,579 @@ +"""Batch import service for importing multiple images as recipes.""" + +from __future__ import annotations + +import asyncio +import logging +import os +import time +import uuid 
import re

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set

from .recipes import (
    RecipeAnalysisService,
    RecipePersistenceService,
    RecipeValidationError,
    RecipeDownloadError,
    RecipeNotFoundError,
)


class ImportItemType(Enum):
    """Kind of source a batch import item refers to."""

    URL = "url"
    LOCAL_PATH = "local_path"


class ImportStatus(Enum):
    """Lifecycle state of an individual import item."""

    PENDING = "pending"
    PROCESSING = "processing"
    SUCCESS = "success"
    FAILED = "failed"
    SKIPPED = "skipped"


@dataclass
class BatchImportItem:
    """Represents a single item to import."""

    id: str
    source: str
    item_type: ImportItemType
    status: ImportStatus = ImportStatus.PENDING
    error_message: Optional[str] = None
    recipe_name: Optional[str] = None
    recipe_id: Optional[str] = None
    duration: float = 0.0


@dataclass
class BatchImportProgress:
    """Tracks progress of a batch import operation."""

    operation_id: str
    total: int
    completed: int = 0
    success: int = 0
    failed: int = 0
    skipped: int = 0
    current_item: str = ""
    status: str = "pending"  # pending | running | completed | cancelled
    started_at: float = field(default_factory=time.time)
    finished_at: Optional[float] = None
    items: List[BatchImportItem] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)
    skip_no_metadata: bool = True
    skip_duplicates: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the aggregate counters (per-item details are omitted)."""
        return {
            "operation_id": self.operation_id,
            "total": self.total,
            "completed": self.completed,
            "success": self.success,
            "failed": self.failed,
            "skipped": self.skipped,
            "current_item": self.current_item,
            "status": self.status,
            "started_at": self.started_at,
            "finished_at": self.finished_at,
            "progress_percent": round((self.completed / self.total) * 100, 1)
            if self.total > 0
            else 0,
        }


class AdaptiveConcurrencyController:
    """Adjusts concurrency based on task performance.

    Fast successful tasks (< 1s) raise the limit, slow ones (> 10s) and
    failures lower it, always clamped to [min_concurrency, max_concurrency].
    """

    def __init__(
        self,
        min_concurrency: int = 1,
        max_concurrency: int = 5,
        initial_concurrency: int = 3,
    ) -> None:
        self.min_concurrency = min_concurrency
        self.max_concurrency = max_concurrency
        self.current_concurrency = initial_concurrency
        self._task_durations: List[float] = []
        self._recent_errors = 0
        self._recent_successes = 0

    def record_result(self, duration: float, success: bool) -> None:
        """Feed one task outcome into the controller and adapt the limit."""
        # Keep a rolling window of the last 10 durations.
        self._task_durations.append(duration)
        if len(self._task_durations) > 10:
            self._task_durations.pop(0)

        if success:
            self._recent_successes += 1
            if duration < 1.0 and self.current_concurrency < self.max_concurrency:
                self.current_concurrency = min(
                    self.current_concurrency + 1, self.max_concurrency
                )
            elif duration > 10.0 and self.current_concurrency > self.min_concurrency:
                self.current_concurrency = max(
                    self.current_concurrency - 1, self.min_concurrency
                )
        else:
            self._recent_errors += 1
            if self.current_concurrency > self.min_concurrency:
                self.current_concurrency = max(
                    self.current_concurrency - 1, self.min_concurrency
                )

    def reset_counters(self) -> None:
        """Zero the success/error counters (the duration window is kept)."""
        self._recent_errors = 0
        self._recent_successes = 0

    def get_semaphore(self) -> asyncio.Semaphore:
        """Create a semaphore sized to the current concurrency limit.

        NOTE: the returned semaphore is a snapshot — callers must create it
        once per batch and share it across all worker tasks; a fresh
        semaphore per task imposes no limit at all.
        """
        return asyncio.Semaphore(self.current_concurrency)


class BatchImportService:
    """Service for batch importing images as recipes."""

    SUPPORTED_EXTENSIONS: Set[str] = {".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp"}

    def __init__(
        self,
        *,
        analysis_service: RecipeAnalysisService,
        persistence_service: RecipePersistenceService,
        ws_manager: Any,
        logger: logging.Logger,
    ) -> None:
        self._analysis_service = analysis_service
        self._persistence_service = persistence_service
        # ws_manager must expose an async broadcast(dict) (see _broadcast_progress).
        self._ws_manager = ws_manager
        self._logger = logger
        self._active_operations: Dict[str, BatchImportProgress] = {}
        self._cancellation_flags: Dict[str, bool] = {}
        self._concurrency_controller = AdaptiveConcurrencyController()

    def is_import_running(self, operation_id: Optional[str] = None) -> bool:
        """Return True when the given (or any) operation is pending/running."""
        if operation_id:
            progress = self._active_operations.get(operation_id)
            return progress is not None and progress.status in ("pending", "running")
        return any(
            p.status in ("pending", "running") for p in self._active_operations.values()
        )

    def get_progress(self, operation_id: str) -> Optional[BatchImportProgress]:
        """Return the live progress record, or None if unknown/expired."""
        return self._active_operations.get(operation_id)

    def cancel_import(self, operation_id: str) -> bool:
        """Request cooperative cancellation; True if the operation exists."""
        if operation_id in self._active_operations:
            self._cancellation_flags[operation_id] = True
            return True
        return False

    def _validate_url(self, url: str) -> bool:
        """Loose syntactic check that *url* looks like an http(s) URL."""
        url_pattern = re.compile(
            r"^https?://"
            r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|"
            r"localhost|"
            r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"
            r"(?::\d+)?"
            r"(?:/?|[/?]\S+)$",
            re.IGNORECASE,
        )
        return url_pattern.match(url) is not None

    def _validate_local_path(self, path: str) -> bool:
        """Accept only absolute paths without traversal components."""
        try:
            normalized = os.path.normpath(path)
            if not os.path.isabs(normalized):
                return False
            if ".." in normalized:
                return False
            return True
        except Exception:
            return False

    def _is_duplicate_source(
        self,
        source: str,
        item_type: ImportItemType,
        recipe_scanner: Any,
    ) -> bool:
        """Check the scanner cache for a recipe already imported from *source*.

        Best effort: any cache failure is logged and treated as "not a
        duplicate" so imports are never blocked by cache issues.
        """
        try:
            cache = recipe_scanner.get_cached_data_sync()
            if not cache:
                return False

            for recipe in getattr(cache, "raw_data", []):
                source_path = recipe.get("source_path") or recipe.get("source_url")
                if source_path and source_path == source:
                    return True
            return False
        except Exception:
            self._logger.warning("Failed to check for duplicates", exc_info=True)
            return False

    async def start_batch_import(
        self,
        *,
        recipe_scanner_getter: Callable[[], Any],
        civitai_client_getter: Callable[[], Any],
        items: List[Dict[str, str]],
        tags: Optional[List[str]] = None,
        skip_no_metadata: bool = True,
        skip_duplicates: bool = False,
    ) -> str:
        """Register a batch import and launch it in the background.

        Returns:
            The operation id to use for progress polling / cancellation.
        """
        operation_id = str(uuid.uuid4())

        import_items = []
        for idx, item in enumerate(items):
            source = item.get("source", "")
            item_type_str = item.get("type", "url")

            # Infer the type from the source when the client did not say "url".
            if item_type_str == "url" or source.startswith(("http://", "https://")):
                item_type = ImportItemType.URL
            else:
                item_type = ImportItemType.LOCAL_PATH

            batch_import_item = BatchImportItem(
                id=f"{operation_id}_{idx}",
                source=source,
                item_type=item_type,
            )
            import_items.append(batch_import_item)

        progress = BatchImportProgress(
            operation_id=operation_id,
            total=len(import_items),
            items=import_items,
            tags=tags or [],
            skip_no_metadata=skip_no_metadata,
            skip_duplicates=skip_duplicates,
        )

        self._active_operations[operation_id] = progress
        self._cancellation_flags[operation_id] = False

        # Fire-and-forget worker; progress is reported via websocket/polling.
        asyncio.create_task(
            self._run_batch_import(
                operation_id=operation_id,
                recipe_scanner_getter=recipe_scanner_getter,
                civitai_client_getter=civitai_client_getter,
            )
        )

        return operation_id

    async def start_directory_import(
        self,
        *,
        recipe_scanner_getter: Callable[[], Any],
        civitai_client_getter: Callable[[], Any],
        directory: str,
        recursive: bool = True,
        tags: Optional[List[str]] = None,
        skip_no_metadata: bool = True,
        skip_duplicates: bool = False,
    ) -> str:
        """Discover supported images under *directory* and batch-import them.

        Raises:
            RecipeValidationError: if *directory* is not an existing directory.
        """
        image_paths = await self._discover_images(directory, recursive)

        items = [{"source": path, "type": "local_path"} for path in image_paths]

        return await self.start_batch_import(
            recipe_scanner_getter=recipe_scanner_getter,
            civitai_client_getter=civitai_client_getter,
            items=items,
            tags=tags,
            skip_no_metadata=skip_no_metadata,
            skip_duplicates=skip_duplicates,
        )

    async def _discover_images(
        self,
        directory: str,
        recursive: bool = True,
    ) -> List[str]:
        """Return a sorted list of supported image paths under *directory*."""
        if not os.path.isdir(directory):
            raise RecipeValidationError(f"Directory not found: {directory}")

        image_paths: List[str] = []

        if recursive:
            for root, _, files in os.walk(directory):
                for filename in files:
                    if self._is_supported_image(filename):
                        image_paths.append(os.path.join(root, filename))
        else:
            for filename in os.listdir(directory):
                filepath = os.path.join(directory, filename)
                if os.path.isfile(filepath) and self._is_supported_image(filename):
                    image_paths.append(filepath)

        return sorted(image_paths)

    def _is_supported_image(self, filename: str) -> bool:
        """True when the file extension is one of SUPPORTED_EXTENSIONS."""
        ext = os.path.splitext(filename)[1].lower()
        return ext in self.SUPPORTED_EXTENSIONS

    async def _run_batch_import(
        self,
        *,
        operation_id: str,
        recipe_scanner_getter: Callable[[], Any],
        civitai_client_getter: Callable[[], Any],
    ) -> None:
        """Background worker: process every item, then clean up the operation."""
        progress = self._active_operations.get(operation_id)
        if not progress:
            return

        progress.status = "running"
        await self._broadcast_progress(progress)

        self._concurrency_controller = AdaptiveConcurrencyController()
        # BUGFIX: create ONE semaphore and share it across all worker tasks.
        # Previously a fresh semaphore was created per item (inside the loop
        # below), so every task held its own counter and nothing was actually
        # throttled.
        semaphore = self._concurrency_controller.get_semaphore()

        async def process_item(item: BatchImportItem) -> None:
            if self._cancellation_flags.get(operation_id, False):
                return

            # NOTE(review): current_item is shared mutable state across
            # concurrent tasks, so it reflects the most recently started item.
            progress.current_item = (
                os.path.basename(item.source)
                if item.item_type == ImportItemType.LOCAL_PATH
                else item.source[:50]
            )
            item.status = ImportStatus.PROCESSING
            await self._broadcast_progress(progress)

            start_time = time.time()
            try:
                result = await self._import_single_item(
                    item=item,
                    recipe_scanner_getter=recipe_scanner_getter,
                    civitai_client_getter=civitai_client_getter,
                    tags=progress.tags,
                    skip_no_metadata=progress.skip_no_metadata,
                    skip_duplicates=progress.skip_duplicates,
                    semaphore=semaphore,
                )

                duration = time.time() - start_time
                item.duration = duration
                self._concurrency_controller.record_result(
                    duration, result.get("success", False)
                )

                if result.get("success"):
                    item.status = ImportStatus.SUCCESS
                    item.recipe_name = result.get("recipe_name")
                    item.recipe_id = result.get("recipe_id")
                    progress.success += 1
                elif result.get("skipped"):
                    item.status = ImportStatus.SKIPPED
                    item.error_message = result.get("error")
                    progress.skipped += 1
                else:
                    item.status = ImportStatus.FAILED
                    item.error_message = result.get("error")
                    progress.failed += 1

            except Exception as e:
                self._logger.error(f"Error importing {item.source}: {e}")
                item.status = ImportStatus.FAILED
                item.error_message = str(e)
                item.duration = time.time() - start_time
                progress.failed += 1
                self._concurrency_controller.record_result(item.duration, False)

            progress.completed += 1
            await self._broadcast_progress(progress)

        tasks = [process_item(item) for item in progress.items]
        await asyncio.gather(*tasks, return_exceptions=True)

        if self._cancellation_flags.get(operation_id, False):
            progress.status = "cancelled"
        else:
            progress.status = "completed"

        progress.finished_at = time.time()
        progress.current_item = ""
        await self._broadcast_progress(progress)

        # Grace period so late pollers can still read the final state.
        await asyncio.sleep(5)
        self._cleanup_operation(operation_id)

    async def _import_single_item(
        self,
        *,
        item: BatchImportItem,
        recipe_scanner_getter: Callable[[], Any],
        civitai_client_getter: Callable[[], Any],
        tags: List[str],
        skip_no_metadata: bool,
        skip_duplicates: bool,
        semaphore: asyncio.Semaphore,
    ) -> Dict[str, Any]:
        """Analyze one source and persist it as a recipe.

        Returns a result dict with keys: ``success`` (bool), optional
        ``skipped`` (bool), ``error`` (str), ``recipe_name``/``recipe_id``.
        Concurrency is bounded by *semaphore*, which must be shared by every
        concurrent call within one batch.
        """
        async with semaphore:
            recipe_scanner = recipe_scanner_getter()
            if recipe_scanner is None:
                return {"success": False, "error": "Recipe scanner unavailable"}

            try:
                if item.item_type == ImportItemType.URL:
                    if not self._validate_url(item.source):
                        return {
                            "success": False,
                            "error": f"Invalid URL format: {item.source}",
                        }

                    if skip_duplicates:
                        if self._is_duplicate_source(
                            item.source, item.item_type, recipe_scanner
                        ):
                            return {
                                "success": False,
                                "skipped": True,
                                "error": "Duplicate source URL",
                            }

                    civitai_client = civitai_client_getter()
                    analysis_result = await self._analysis_service.analyze_remote_image(
                        url=item.source,
                        recipe_scanner=recipe_scanner,
                        civitai_client=civitai_client,
                    )
                else:
                    if not self._validate_local_path(item.source):
                        return {
                            "success": False,
                            "error": f"Invalid or unsafe path: {item.source}",
                        }

                    if not os.path.exists(item.source):
                        return {
                            "success": False,
                            "error": f"File not found: {item.source}",
                        }

                    if skip_duplicates:
                        if self._is_duplicate_source(
                            item.source, item.item_type, recipe_scanner
                        ):
                            return {
                                "success": False,
                                "skipped": True,
                                "error": "Duplicate source path",
                            }

                    analysis_result = await self._analysis_service.analyze_local_image(
                        file_path=item.source,
                        recipe_scanner=recipe_scanner,
                    )

                payload = analysis_result.payload

                if payload.get("error"):
                    # Treat "No metadata" analysis failures as skips when requested.
                    if skip_no_metadata and "No metadata" in payload.get("error", ""):
                        return {
                            "success": False,
                            "skipped": True,
                            "error": payload["error"],
                        }
                    return {"success": False, "error": payload["error"]}

                loras = payload.get("loras", [])
                if not loras:
                    if skip_no_metadata:
                        return {
                            "success": False,
                            "skipped": True,
                            "error": "No LoRAs found in image",
                        }
                    return {"success": False, "error": "No LoRAs found in image"}

                recipe_name = self._generate_recipe_name(item, payload)
                # Merge user tags with any tags detected during analysis.
                all_tags = list(set(tags + (payload.get("tags", []) or [])))

                metadata = {
                    "base_model": payload.get("base_model", ""),
                    "loras": loras,
                    "gen_params": payload.get("gen_params", {}),
                    "source_path": item.source,
                }

                if payload.get("checkpoint"):
                    metadata["checkpoint"] = payload["checkpoint"]

                image_bytes = None
                image_base64 = payload.get("image_base64")

                # Local files are re-read from disk; remote images come back
                # base64-encoded from the analysis payload.
                if item.item_type == ImportItemType.LOCAL_PATH:
                    with open(item.source, "rb") as f:
                        image_bytes = f.read()
                    image_base64 = None

                save_result = await self._persistence_service.save_recipe(
                    recipe_scanner=recipe_scanner,
                    image_bytes=image_bytes,
                    image_base64=image_base64,
                    name=recipe_name,
                    tags=all_tags,
                    metadata=metadata,
                    extension=payload.get("extension"),
                )

                if save_result.status == 200:
                    return {
                        "success": True,
                        "recipe_name": recipe_name,
                        "recipe_id": save_result.payload.get("id"),
                    }
                else:
                    return {
                        "success": False,
                        "error": save_result.payload.get(
                            "error", "Failed to save recipe"
                        ),
                    }

            except RecipeValidationError as e:
                return {"success": False, "error": str(e)}
            except RecipeDownloadError as e:
                return {"success": False, "error": str(e)}
            except RecipeNotFoundError as e:
                return {"success": False, "skipped": True, "error": str(e)}
            except Exception as e:
                self._logger.error(
                    f"Unexpected error importing {item.source}: {e}", exc_info=True
                )
                return {"success": False, "error": str(e)}

    def _generate_recipe_name(
        self, item: BatchImportItem, payload: Dict[str, Any]
    ) -> str:
        """Derive a recipe name (max 100 chars) from the file or first LoRA."""
        if item.item_type == ImportItemType.LOCAL_PATH:
            base_name = os.path.splitext(os.path.basename(item.source))[0]
            return base_name[:100]
        else:
            loras = payload.get("loras", [])
            if loras:
                first_lora = loras[0].get("name", "Recipe")
                return f"Import - {first_lora}"[:100]
            return f"Imported Recipe {item.id[:8]}"

    async def _broadcast_progress(self, progress: BatchImportProgress) -> None:
        """Push the current progress snapshot over the websocket manager."""
        await self._ws_manager.broadcast(
            {
                "type": "batch_import_progress",
                **progress.to_dict(),
            }
        )

    def _cleanup_operation(self, operation_id: str) -> None:
        """Forget a finished operation.

        BUGFIX: previously only the cancellation flag was removed, so
        ``_active_operations`` grew without bound across imports. The
        progress record remains queryable during the 5-second grace window
        before this runs.
        """
        self._active_operations.pop(operation_id, None)
        self._cancellation_flags.pop(operation_id, None)
SimpleNamespace(payload={"success": True, "recipe_id": "stub-id"}, status=200) + self.save_result = SimpleNamespace( + payload={"success": True, "recipe_id": "stub-id"}, status=200 + ) self.delete_result = SimpleNamespace(payload={"success": True}, status=200) StubPersistenceService.instances.append(self) - async def save_recipe(self, *, recipe_scanner, image_bytes, image_base64, name, tags, metadata, extension=None) -> SimpleNamespace: # noqa: D401 + async def save_recipe( + self, + *, + recipe_scanner, + image_bytes, + image_base64, + name, + tags, + metadata, + extension=None, + ) -> SimpleNamespace: # noqa: D401 self.save_calls.append( { "recipe_scanner": recipe_scanner, @@ -148,22 +167,42 @@ class StubPersistenceService: await recipe_scanner.remove_recipe(recipe_id) return self.delete_result - async def move_recipe(self, *, recipe_scanner, recipe_id: str, target_path: str) -> SimpleNamespace: # noqa: D401 + async def move_recipe( + self, *, recipe_scanner, recipe_id: str, target_path: str + ) -> SimpleNamespace: # noqa: D401 self.move_calls.append({"recipe_id": recipe_id, "target_path": target_path}) return SimpleNamespace( - payload={"success": True, "recipe_id": recipe_id, "new_file_path": target_path}, status=200 + payload={ + "success": True, + "recipe_id": recipe_id, + "new_file_path": target_path, + }, + status=200, ) - async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: Dict[str, Any]) -> SimpleNamespace: # pragma: no cover - unused by smoke tests - return SimpleNamespace(payload={"success": True, "recipe_id": recipe_id, "updates": updates}, status=200) + async def update_recipe( + self, *, recipe_scanner, recipe_id: str, updates: Dict[str, Any] + ) -> SimpleNamespace: # pragma: no cover - unused by smoke tests + return SimpleNamespace( + payload={"success": True, "recipe_id": recipe_id, "updates": updates}, + status=200, + ) - async def reconnect_lora(self, *, recipe_scanner, recipe_id: str, lora_index: int, target_name: str) -> 
SimpleNamespace: # pragma: no cover + async def reconnect_lora( + self, *, recipe_scanner, recipe_id: str, lora_index: int, target_name: str + ) -> SimpleNamespace: # pragma: no cover return SimpleNamespace(payload={"success": True}, status=200) - async def bulk_delete(self, *, recipe_scanner, recipe_ids: List[str]) -> SimpleNamespace: # pragma: no cover - return SimpleNamespace(payload={"success": True, "deleted": recipe_ids}, status=200) + async def bulk_delete( + self, *, recipe_scanner, recipe_ids: List[str] + ) -> SimpleNamespace: # pragma: no cover + return SimpleNamespace( + payload={"success": True, "deleted": recipe_ids}, status=200 + ) - async def save_recipe_from_widget(self, *, recipe_scanner, metadata: Dict[str, Any], image_bytes: bytes) -> SimpleNamespace: # pragma: no cover + async def save_recipe_from_widget( + self, *, recipe_scanner, metadata: Dict[str, Any], image_bytes: bytes + ) -> SimpleNamespace: # pragma: no cover return SimpleNamespace(payload={"success": True}, status=200) @@ -176,7 +215,11 @@ class StubSharingService: self.share_calls: List[str] = [] self.download_calls: List[str] = [] self.share_result = SimpleNamespace( - payload={"success": True, "download_url": "/share/stub", "filename": "recipe.png"}, + payload={ + "success": True, + "download_url": "/share/stub", + "filename": "recipe.png", + }, status=200, ) self.download_info = SimpleNamespace(file_path="", download_filename="") @@ -186,7 +229,9 @@ class StubSharingService: self.share_calls.append(recipe_id) return self.share_result - async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace: + async def prepare_download( + self, *, recipe_scanner, recipe_id: str + ) -> SimpleNamespace: self.download_calls.append(recipe_id) return self.download_info @@ -214,7 +259,9 @@ class StubCivitaiClient: @asynccontextmanager -async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRouteHarness]: +async def recipe_harness( + monkeypatch, 
tmp_path: Path +) -> AsyncIterator[RecipeRouteHarness]: """Context manager that yields a fully wired recipe route harness.""" StubAnalysisService.instances.clear() @@ -237,8 +284,12 @@ async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRou monkeypatch.setattr(ServiceRegistry, "get_recipe_scanner", fake_get_recipe_scanner) monkeypatch.setattr(ServiceRegistry, "get_civitai_client", fake_get_civitai_client) - monkeypatch.setattr(base_recipe_routes, "RecipeAnalysisService", StubAnalysisService) - monkeypatch.setattr(base_recipe_routes, "RecipePersistenceService", StubPersistenceService) + monkeypatch.setattr( + base_recipe_routes, "RecipeAnalysisService", StubAnalysisService + ) + monkeypatch.setattr( + base_recipe_routes, "RecipePersistenceService", StubPersistenceService + ) monkeypatch.setattr(base_recipe_routes, "RecipeSharingService", StubSharingService) monkeypatch.setattr(base_recipe_routes, "get_downloader", fake_get_downloader) monkeypatch.setattr(config, "loras_roots", [str(tmp_path)], raising=False) @@ -294,7 +345,9 @@ async def test_list_recipes_provides_file_urls(monkeypatch, tmp_path: Path) -> N async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None: async with recipe_harness(monkeypatch, tmp_path) as harness: form = FormData() - form.add_field("image", b"stub", filename="sample.png", content_type="image/png") + form.add_field( + "image", b"stub", filename="sample.png", content_type="image/png" + ) form.add_field("name", "Test Recipe") form.add_field("tags", json.dumps(["tag-a"])) form.add_field("metadata", json.dumps({"loras": []})) @@ -312,7 +365,9 @@ async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> assert save_payload["recipe_id"] == "saved-id" assert harness.persistence.save_calls[-1]["name"] == "Test Recipe" - harness.persistence.delete_result = SimpleNamespace(payload={"success": True}, status=200) + harness.persistence.delete_result = SimpleNamespace( + 
payload={"success": True}, status=200 + ) delete_response = await harness.client.delete("/api/lm/recipe/saved-id") delete_payload = await delete_response.json() @@ -326,14 +381,20 @@ async def test_move_recipe_invokes_persistence(monkeypatch, tmp_path: Path) -> N async with recipe_harness(monkeypatch, tmp_path) as harness: response = await harness.client.post( "/api/lm/recipe/move", - json={"recipe_id": "move-me", "target_path": str(tmp_path / "recipes" / "subdir")}, + json={ + "recipe_id": "move-me", + "target_path": str(tmp_path / "recipes" / "subdir"), + }, ) payload = await response.json() assert response.status == 200 assert payload["recipe_id"] == "move-me" assert harness.persistence.move_calls == [ - {"recipe_id": "move-me", "target_path": str(tmp_path / "recipes" / "subdir")} + { + "recipe_id": "move-me", + "target_path": str(tmp_path / "recipes" / "subdir"), + } ] @@ -348,7 +409,10 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None: async def fake_get_default_metadata_provider(): return Provider() - monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider) + monkeypatch.setattr( + "py.recipes.enrichment.get_default_metadata_provider", + fake_get_default_metadata_provider, + ) async with recipe_harness(monkeypatch, tmp_path) as harness: resources = [ @@ -397,7 +461,9 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None: assert harness.downloader.urls == ["https://example.com/images/1"] -async def test_import_remote_recipe_falls_back_to_request_base_model(monkeypatch, tmp_path: Path) -> None: +async def test_import_remote_recipe_falls_back_to_request_base_model( + monkeypatch, tmp_path: Path +) -> None: provider_calls: list[str | int] = [] class Provider: @@ -408,7 +474,10 @@ async def test_import_remote_recipe_falls_back_to_request_base_model(monkeypatch async def fake_get_default_metadata_provider(): return Provider() - 
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider) + monkeypatch.setattr( + "py.recipes.enrichment.get_default_metadata_provider", + fake_get_default_metadata_provider, + ) async with recipe_harness(monkeypatch, tmp_path) as harness: resources = [ @@ -444,13 +513,16 @@ async def test_import_remote_video_recipe(monkeypatch, tmp_path: Path) -> None: async def fake_get_default_metadata_provider(): return SimpleNamespace(get_model_version_info=lambda id: ({}, None)) - monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider) + monkeypatch.setattr( + "py.recipes.enrichment.get_default_metadata_provider", + fake_get_default_metadata_provider, + ) async with recipe_harness(monkeypatch, tmp_path) as harness: harness.civitai.image_info["12345"] = { "id": 12345, "url": "https://image.civitai.com/x/y/original=true/video.mp4", - "type": "video" + "type": "video", } response = await harness.client.get( @@ -469,7 +541,7 @@ async def test_import_remote_video_recipe(monkeypatch, tmp_path: Path) -> None: # Verify downloader was called with rewritten URL assert "transcode=true" in harness.downloader.urls[0] - + # Verify persistence was called with correct extension call = harness.persistence.save_calls[-1] assert call["extension"] == ".mp4" @@ -477,7 +549,9 @@ async def test_import_remote_video_recipe(monkeypatch, tmp_path: Path) -> None: async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None: async with recipe_harness(monkeypatch, tmp_path) as harness: - harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided") + harness.analysis.raise_for_uploaded = RecipeValidationError( + "No image data provided" + ) form = FormData() form.add_field("image", b"", filename="empty.png", content_type="image/png") @@ -504,7 +578,11 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None: } 
harness.sharing.share_result = SimpleNamespace( - payload={"success": True, "download_url": "/api/share", "filename": "share.png"}, + payload={ + "success": True, + "download_url": "/api/share", + "filename": "share.png", + }, status=200, ) harness.sharing.download_info = SimpleNamespace( @@ -519,15 +597,24 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None: assert share_payload["filename"] == "share.png" assert harness.sharing.share_calls == [recipe_id] - download_response = await harness.client.get(f"/api/lm/recipe/{recipe_id}/share/download") + download_response = await harness.client.get( + f"/api/lm/recipe/{recipe_id}/share/download" + ) body = await download_response.read() assert download_response.status == 200 - assert download_response.headers["Content-Disposition"] == 'attachment; filename="share.png"' + assert ( + download_response.headers["Content-Disposition"] + == 'attachment; filename="share.png"' + ) assert body == b"stub" download_path.unlink(missing_ok=True) -async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path) -> None: + + +async def test_import_remote_recipe_merges_metadata( + monkeypatch, tmp_path: Path +) -> None: # 1. Mock Metadata Provider class Provider: async def get_model_version_info(self, model_version_id): @@ -536,22 +623,25 @@ async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path) async def fake_get_default_metadata_provider(): return Provider() - monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider) + monkeypatch.setattr( + "py.recipes.enrichment.get_default_metadata_provider", + fake_get_default_metadata_provider, + ) # 2. 
Mock ExifUtils to return some embedded metadata class MockExifUtils: @staticmethod def extract_image_metadata(path): - return "Recipe metadata: " + json.dumps({ - "gen_params": {"prompt": "from embedded", "seed": 123} - }) + return "Recipe metadata: " + json.dumps( + {"gen_params": {"prompt": "from embedded", "seed": 123}} + ) monkeypatch.setattr(recipe_handlers, "ExifUtils", MockExifUtils) # 3. Mock Parser Factory for StubAnalysisService class MockParser: async def parse_metadata(self, raw, recipe_scanner=None): - return json.loads(raw[len("Recipe metadata: "):]) + return json.loads(raw[len("Recipe metadata: ") :]) class MockFactory: def create_parser(self, raw): @@ -562,12 +652,12 @@ async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path) # 4. Setup Harness and run test async with recipe_harness(monkeypatch, tmp_path) as harness: harness.analysis._recipe_parser_factory = MockFactory() - + # Civitai meta via image_info harness.civitai.image_info["1"] = { "id": 1, "url": "https://example.com/images/1.jpg", - "meta": {"prompt": "from civitai", "cfg": 7.0} + "meta": {"prompt": "from civitai", "cfg": 7.0}, } resources = [] @@ -583,11 +673,11 @@ async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path) payload = await response.json() assert response.status == 200 - + call = harness.persistence.save_calls[-1] metadata = call["metadata"] gen_params = metadata["gen_params"] - + assert gen_params["seed"] == 123 @@ -619,3 +709,142 @@ async def test_get_recipe_syntax(monkeypatch, tmp_path: Path) -> None: response_404 = await harness.client.get("/api/lm/recipe/non-existent/syntax") assert response_404.status == 404 + +async def test_batch_import_start_success(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + response = await harness.client.post( + "/api/lm/recipes/batch-import/start", + json={ + "items": [ + {"source": "https://example.com/image1.png"}, + {"source": 
"https://example.com/image2.png"}, + ], + "tags": ["batch", "import"], + "skip_no_metadata": True, + }, + ) + payload = await response.json() + assert response.status == 200 + assert payload["success"] is True + assert "operation_id" in payload + + +async def test_batch_import_start_empty_items(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + response = await harness.client.post( + "/api/lm/recipes/batch-import/start", + json={"items": [], "tags": []}, + ) + payload = await response.json() + assert response.status == 400 + assert payload["success"] is False + assert "No items provided" in payload["error"] + + +async def test_batch_import_start_missing_source(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + response = await harness.client.post( + "/api/lm/recipes/batch-import/start", + json={"items": [{"source": ""}]}, + ) + payload = await response.json() + assert response.status == 400 + assert payload["success"] is False + assert "source" in payload["error"].lower() + + +async def test_batch_import_start_already_running(monkeypatch, tmp_path: Path) -> None: + import asyncio + + async with recipe_harness(monkeypatch, tmp_path) as harness: + original_analyze = harness.analysis.analyze_remote_image + + async def slow_analyze(*, url, recipe_scanner, civitai_client): + await asyncio.sleep(0.5) + return await original_analyze( + url=url, recipe_scanner=recipe_scanner, civitai_client=civitai_client + ) + + harness.analysis.analyze_remote_image = slow_analyze + + items = [{"source": f"https://example.com/image{i}.png"} for i in range(10)] + + response1 = await harness.client.post( + "/api/lm/recipes/batch-import/start", + json={"items": items}, + ) + assert response1.status == 200 + + payload1 = await response1.json() + assert payload1["success"] is True + + await asyncio.sleep(0.1) + + response2 = await harness.client.post( + "/api/lm/recipes/batch-import/start", 
+ json={"items": [{"source": "https://example.com/other.png"}]}, + ) + payload2 = await response2.json() + assert response2.status == 409 + assert "already in progress" in payload2["error"].lower() + + +async def test_batch_import_get_progress_not_found(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + response = await harness.client.get( + "/api/lm/recipes/batch-import/progress", + params={"operation_id": "nonexistent-id"}, + ) + payload = await response.json() + assert response.status == 404 + assert payload["success"] is False + + +async def test_batch_import_get_progress_missing_id( + monkeypatch, tmp_path: Path +) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + response = await harness.client.get("/api/lm/recipes/batch-import/progress") + payload = await response.json() + assert response.status == 400 + assert payload["success"] is False + + +async def test_batch_import_cancel_success(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + start_response = await harness.client.post( + "/api/lm/recipes/batch-import/start", + json={"items": [{"source": "https://example.com/image.png"}]}, + ) + start_payload = await start_response.json() + operation_id = start_payload["operation_id"] + + cancel_response = await harness.client.post( + "/api/lm/recipes/batch-import/cancel", + json={"operation_id": operation_id}, + ) + cancel_payload = await cancel_response.json() + assert cancel_response.status == 200 + assert cancel_payload["success"] is True + + +async def test_batch_import_cancel_not_found(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + response = await harness.client.post( + "/api/lm/recipes/batch-import/cancel", + json={"operation_id": "nonexistent-id"}, + ) + payload = await response.json() + assert response.status == 404 + assert payload["success"] is False + + +async def 
test_batch_import_cancel_missing_id(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + response = await harness.client.post( + "/api/lm/recipes/batch-import/cancel", + json={}, + ) + payload = await response.json() + assert response.status == 400 + assert payload["success"] is False diff --git a/tests/services/test_batch_import_service.py b/tests/services/test_batch_import_service.py new file mode 100644 index 00000000..003a0c12 --- /dev/null +++ b/tests/services/test_batch_import_service.py @@ -0,0 +1,597 @@ +"""Unit tests for BatchImportService.""" + +from __future__ import annotations + +import asyncio +import logging +import os +import tempfile +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict, List, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from py.services.batch_import_service import ( + AdaptiveConcurrencyController, + BatchImportItem, + BatchImportProgress, + BatchImportService, + ImportItemType, + ImportStatus, +) + + +class MockWebSocketManager: + def __init__(self): + self.broadcasts: List[Dict[str, Any]] = [] + + async def broadcast(self, data: Dict[str, Any]): + self.broadcasts.append(data) + + +@dataclass +class MockAnalysisResult: + payload: Dict[str, Any] + status: int = 200 + + +class MockAnalysisService: + def __init__(self, results: Optional[Dict[str, MockAnalysisResult]] = None): + self.results = results or {} + self.call_count = 0 + self.last_url = None + self.last_path = None + + async def analyze_remote_image(self, *, url: str, recipe_scanner, civitai_client): + self.call_count += 1 + self.last_url = url + if url in self.results: + return self.results[url] + return MockAnalysisResult({"error": "No metadata found", "loras": []}) + + async def analyze_local_image(self, *, file_path: str, recipe_scanner): + self.call_count += 1 + self.last_path = file_path + if file_path in 
self.results: + return self.results[file_path] + return MockAnalysisResult({"error": "No metadata found", "loras": []}) + + +@dataclass +class MockSaveResult: + payload: Dict[str, Any] + status: int = 200 + + +class MockPersistenceService: + def __init__(self, should_succeed: bool = True): + self.should_succeed = should_succeed + self.saved_recipes: List[Dict[str, Any]] = [] + self.call_count = 0 + + async def save_recipe( + self, + *, + recipe_scanner, + image_bytes: Optional[bytes] = None, + image_base64: Optional[str] = None, + name: str, + tags: List[str], + metadata: Dict[str, Any], + extension: Optional[str] = None, + ): + self.call_count += 1 + self.saved_recipes.append( + { + "name": name, + "tags": tags, + "metadata": metadata, + } + ) + if self.should_succeed: + return MockSaveResult({"success": True, "id": f"recipe_{self.call_count}"}) + return MockSaveResult({"success": False, "error": "Save failed"}, status=400) + + +class TestAdaptiveConcurrencyController: + def test_initial_values(self): + controller = AdaptiveConcurrencyController() + assert controller.current_concurrency == 3 + assert controller.min_concurrency == 1 + assert controller.max_concurrency == 5 + + def test_custom_initial_values(self): + controller = AdaptiveConcurrencyController( + min_concurrency=2, + max_concurrency=10, + initial_concurrency=5, + ) + assert controller.current_concurrency == 5 + assert controller.min_concurrency == 2 + assert controller.max_concurrency == 10 + + def test_increase_concurrency_on_success(self): + controller = AdaptiveConcurrencyController(initial_concurrency=3) + controller.record_result(duration=0.5, success=True) + assert controller.current_concurrency == 4 + + def test_do_not_exceed_max(self): + controller = AdaptiveConcurrencyController( + max_concurrency=5, + initial_concurrency=5, + ) + controller.record_result(duration=0.5, success=True) + assert controller.current_concurrency == 5 + + def test_decrease_concurrency_on_failure(self): + controller 
= AdaptiveConcurrencyController(initial_concurrency=3) + controller.record_result(duration=1.0, success=False) + assert controller.current_concurrency == 2 + + def test_do_not_go_below_min(self): + controller = AdaptiveConcurrencyController( + min_concurrency=1, + initial_concurrency=1, + ) + controller.record_result(duration=1.0, success=False) + assert controller.current_concurrency == 1 + + def test_slow_task_decreases_concurrency(self): + controller = AdaptiveConcurrencyController(initial_concurrency=3) + controller.record_result(duration=11.0, success=True) + assert controller.current_concurrency == 2 + + def test_fast_task_increases_concurrency(self): + controller = AdaptiveConcurrencyController(initial_concurrency=3) + controller.record_result(duration=0.5, success=True) + assert controller.current_concurrency == 4 + + def test_moderate_task_no_change(self): + controller = AdaptiveConcurrencyController(initial_concurrency=3) + controller.record_result(duration=5.0, success=True) + assert controller.current_concurrency == 3 + + +class TestBatchImportProgress: + def test_to_dict(self): + progress = BatchImportProgress( + operation_id="test-123", + total=10, + completed=5, + success=3, + failed=2, + skipped=0, + current_item="image.png", + status="running", + ) + result = progress.to_dict() + assert result["operation_id"] == "test-123" + assert result["total"] == 10 + assert result["completed"] == 5 + assert result["success"] == 3 + assert result["failed"] == 2 + assert result["progress_percent"] == 50.0 + + def test_progress_percent_zero_total(self): + progress = BatchImportProgress( + operation_id="test-123", + total=0, + ) + assert progress.to_dict()["progress_percent"] == 0 + + +class TestBatchImportItem: + def test_defaults(self): + item = BatchImportItem( + id="item-1", + source="https://example.com/image.png", + item_type=ImportItemType.URL, + ) + assert item.status == ImportStatus.PENDING + assert item.error_message is None + assert item.recipe_name is 
None + + +class TestBatchImportService: + @pytest.fixture + def mock_services(self): + ws_manager = MockWebSocketManager() + analysis_service = MockAnalysisService() + persistence_service = MockPersistenceService() + logger = logging.getLogger("test") + return ws_manager, analysis_service, persistence_service, logger + + @pytest.fixture + def service(self, mock_services): + ws_manager, analysis_service, persistence_service, logger = mock_services + return BatchImportService( + analysis_service=analysis_service, + persistence_service=persistence_service, + ws_manager=ws_manager, + logger=logger, + ) + + def test_is_import_running_no_operations(self, service): + assert not service.is_import_running() + + @pytest.mark.asyncio + async def test_start_batch_import_creates_operation(self, service): + recipe_scanner_getter = lambda: SimpleNamespace() + civitai_client_getter = lambda: SimpleNamespace() + + operation_id = await service.start_batch_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + items=[{"source": "https://example.com/image.png"}], + ) + + assert operation_id is not None + assert service.is_import_running(operation_id) + + @pytest.mark.asyncio + async def test_get_progress(self, service): + recipe_scanner_getter = lambda: SimpleNamespace() + civitai_client_getter = lambda: SimpleNamespace() + + operation_id = await service.start_batch_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + items=[ + {"source": "https://example.com/1.png"}, + {"source": "https://example.com/2.png"}, + ], + ) + + progress = service.get_progress(operation_id) + assert progress is not None + assert progress.total == 2 + assert progress.status in ("pending", "running") + + @pytest.mark.asyncio + async def test_cancel_import(self, service): + recipe_scanner_getter = lambda: SimpleNamespace() + civitai_client_getter = lambda: SimpleNamespace() + + operation_id = await 
service.start_batch_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + items=[{"source": "https://example.com/image.png"}], + ) + + assert service.cancel_import(operation_id) is True + assert service.cancel_import("nonexistent") is False + + @pytest.mark.asyncio + async def test_discover_images_non_recursive(self, service, tmp_path): + for i in range(3): + (tmp_path / f"image{i}.png").write_bytes(b"fake-image") + + (tmp_path / "subdir").mkdir() + (tmp_path / "subdir" / "hidden.png").write_bytes(b"fake-image") + + images = await service._discover_images(str(tmp_path), recursive=False) + assert len(images) == 3 + + @pytest.mark.asyncio + async def test_discover_images_recursive(self, service, tmp_path): + for i in range(2): + (tmp_path / f"image{i}.png").write_bytes(b"fake-image") + + subdir = tmp_path / "subdir" + subdir.mkdir() + for i in range(2): + (subdir / f"nested{i}.jpg").write_bytes(b"fake-image") + + images = await service._discover_images(str(tmp_path), recursive=True) + assert len(images) == 4 + + @pytest.mark.asyncio + async def test_discover_images_filters_by_extension(self, service, tmp_path): + (tmp_path / "image.png").write_bytes(b"fake-image") + (tmp_path / "image.jpg").write_bytes(b"fake-image") + (tmp_path / "image.webp").write_bytes(b"fake-image") + (tmp_path / "document.pdf").write_bytes(b"fake-doc") + (tmp_path / "script.py").write_bytes(b"print('hello')") + + images = await service._discover_images(str(tmp_path), recursive=False) + assert len(images) == 3 + + @pytest.mark.asyncio + async def test_discover_images_invalid_directory(self, service): + from py.services.recipes.errors import RecipeValidationError + + with pytest.raises(RecipeValidationError): + await service._discover_images("/nonexistent/path", recursive=False) + + def test_is_supported_image(self, service): + assert service._is_supported_image("test.png") is True + assert service._is_supported_image("test.jpg") is True + assert 
service._is_supported_image("test.jpeg") is True + assert service._is_supported_image("test.webp") is True + assert service._is_supported_image("test.gif") is True + assert service._is_supported_image("test.bmp") is True + assert service._is_supported_image("test.pdf") is False + assert service._is_supported_image("test.txt") is False + + @pytest.mark.asyncio + async def test_batch_import_processes_items(self, mock_services, tmp_path): + ws_manager, _, persistence_service, logger = mock_services + + analysis_service = MockAnalysisService( + { + "https://example.com/valid.png": MockAnalysisResult( + { + "loras": [{"name": "test-lora", "weight": 1.0}], + "base_model": "SD1.5", + "gen_params": {"steps": 20}, + } + ), + } + ) + + service = BatchImportService( + analysis_service=analysis_service, + persistence_service=persistence_service, + ws_manager=ws_manager, + logger=logger, + ) + + recipe_scanner_getter = lambda: SimpleNamespace( + find_recipes_by_fingerprint=lambda x: [], + add_recipe=lambda x: None, + ) + civitai_client_getter = lambda: SimpleNamespace() + + operation_id = await service.start_batch_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + items=[ + {"source": "https://example.com/valid.png"}, + {"source": "https://example.com/no-meta.png"}, + ], + skip_no_metadata=True, + ) + + await asyncio.sleep(0.5) + + progress = service.get_progress(operation_id) + assert progress is not None or persistence_service.call_count == 1 + + @pytest.mark.asyncio + async def test_start_directory_import(self, service, tmp_path): + for i in range(5): + (tmp_path / f"image{i}.png").write_bytes(b"fake-image") + + recipe_scanner_getter = lambda: SimpleNamespace() + civitai_client_getter = lambda: SimpleNamespace() + + operation_id = await service.start_directory_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + directory=str(tmp_path), + recursive=False, + ) + + progress 
= service.get_progress(operation_id) + assert progress is not None + assert progress.total == 5 + + @pytest.mark.asyncio + async def test_websocket_broadcasts_progress(self, mock_services): + ws_manager, analysis_service, persistence_service, logger = mock_services + + service = BatchImportService( + analysis_service=analysis_service, + persistence_service=persistence_service, + ws_manager=ws_manager, + logger=logger, + ) + + recipe_scanner_getter = lambda: SimpleNamespace() + civitai_client_getter = lambda: SimpleNamespace() + + operation_id = await service.start_batch_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + items=[{"source": "https://example.com/test.png"}], + ) + + await asyncio.sleep(0.3) + + assert len(ws_manager.broadcasts) > 0 + assert any( + b.get("type") == "batch_import_progress" for b in ws_manager.broadcasts + ) + + @pytest.mark.asyncio + async def test_cancellation_stops_processing(self, mock_services): + ws_manager, analysis_service, persistence_service, logger = mock_services + + service = BatchImportService( + analysis_service=analysis_service, + persistence_service=persistence_service, + ws_manager=ws_manager, + logger=logger, + ) + + recipe_scanner_getter = lambda: SimpleNamespace() + civitai_client_getter = lambda: SimpleNamespace() + + items = [{"source": f"https://example.com/{i}.png"} for i in range(10)] + + operation_id = await service.start_batch_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + items=items, + ) + + service.cancel_import(operation_id) + await asyncio.sleep(0.3) + + progress = service.get_progress(operation_id) + if progress: + assert progress.status == "cancelled" + + +class TestBatchImportServiceEdgeCases: + @pytest.fixture + def service(self): + ws_manager = MockWebSocketManager() + analysis_service = MockAnalysisService() + persistence_service = MockPersistenceService() + logger = logging.getLogger("test") 
+ + return BatchImportService( + analysis_service=analysis_service, + persistence_service=persistence_service, + ws_manager=ws_manager, + logger=logger, + ) + + @pytest.mark.asyncio + async def test_empty_items_list(self, service): + recipe_scanner_getter = lambda: SimpleNamespace() + civitai_client_getter = lambda: SimpleNamespace() + + operation_id = await service.start_batch_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + items=[], + ) + + progress = service.get_progress(operation_id) + assert progress is not None + assert progress.total == 0 + + @pytest.mark.asyncio + async def test_mixed_url_and_path_items(self, service, tmp_path): + (tmp_path / "local.png").write_bytes(b"fake-image") + + recipe_scanner_getter = lambda: SimpleNamespace() + civitai_client_getter = lambda: SimpleNamespace() + + operation_id = await service.start_batch_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + items=[ + {"source": "https://example.com/remote.png", "type": "url"}, + {"source": str(tmp_path / "local.png"), "type": "local_path"}, + ], + ) + + progress = service.get_progress(operation_id) + assert progress is not None + assert progress.total == 2 + assert progress.items[0].item_type == ImportItemType.URL + assert progress.items[1].item_type == ImportItemType.LOCAL_PATH + + @pytest.mark.asyncio + async def test_tags_are_passed_to_persistence(self, tmp_path): + ws_manager = MockWebSocketManager() + analysis_service = MockAnalysisService( + { + str(tmp_path / "test.png"): MockAnalysisResult( + { + "loras": [{"name": "test-lora"}], + } + ), + } + ) + persistence_service = MockPersistenceService() + logger = logging.getLogger("test") + + (tmp_path / "test.png").write_bytes(b"fake-image") + + service = BatchImportService( + analysis_service=analysis_service, + persistence_service=persistence_service, + ws_manager=ws_manager, + logger=logger, + ) + + recipe_scanner_getter = 
lambda: SimpleNamespace( + find_recipes_by_fingerprint=lambda x: [], + ) + civitai_client_getter = lambda: SimpleNamespace() + + operation_id = await service.start_batch_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + items=[{"source": str(tmp_path / "test.png")}], + tags=["batch-import", "test"], + ) + + await asyncio.sleep(0.3) + + if persistence_service.saved_recipes: + assert "batch-import" in persistence_service.saved_recipes[0]["tags"] + assert "test" in persistence_service.saved_recipes[0]["tags"] + + @pytest.mark.asyncio + async def test_skip_duplicates_parameter(self, service): + recipe_scanner_getter = lambda: SimpleNamespace() + civitai_client_getter = lambda: SimpleNamespace() + + operation_id = await service.start_batch_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + items=[{"source": "https://example.com/test.png"}], + skip_duplicates=True, + ) + + progress = service.get_progress(operation_id) + assert progress is not None + assert progress.skip_duplicates is True + + @pytest.mark.asyncio + async def test_skip_duplicates_false_by_default(self, service): + recipe_scanner_getter = lambda: SimpleNamespace() + civitai_client_getter = lambda: SimpleNamespace() + + operation_id = await service.start_batch_import( + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + items=[{"source": "https://example.com/test.png"}], + ) + + progress = service.get_progress(operation_id) + assert progress is not None + assert progress.skip_duplicates is False + + +class TestInputValidation: + @pytest.fixture + def service(self): + ws_manager = MockWebSocketManager() + analysis_service = MockAnalysisService() + persistence_service = MockPersistenceService() + logger = logging.getLogger("test") + + return BatchImportService( + analysis_service=analysis_service, + persistence_service=persistence_service, + 
ws_manager=ws_manager, + logger=logger, + ) + + def test_validate_valid_url(self, service): + assert service._validate_url("https://example.com/image.png") is True + assert service._validate_url("http://example.com/image.png") is True + assert service._validate_url("https://civitai.com/images/123") is True + + def test_validate_invalid_url(self, service): + assert service._validate_url("not-a-url") is False + assert service._validate_url("ftp://example.com/file") is False + assert service._validate_url("") is False + + def test_validate_valid_local_path(self, service, tmp_path): + valid_path = str(tmp_path / "image.png") + assert service._validate_local_path(valid_path) is True + + def test_validate_invalid_local_path(self, service): + assert service._validate_local_path("../etc/passwd") is False + assert service._validate_local_path("relative/path.png") is False + assert service._validate_local_path("") is False