mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(batch-import): implement backend batch import service with adaptive concurrency
- Add BatchImportService with concurrent execution using asyncio.gather - Implement AdaptiveConcurrencyController with dynamic adjustment - Add input validation for URLs and local paths - Support duplicate detection via skip_duplicates parameter - Add WebSocket progress broadcasting for real-time updates - Create comprehensive unit tests for batch import functionality - Update API handlers and route registrations - Add i18n translation keys for batch import UI
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Integration smoke tests for the recipe route stack."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
@@ -94,19 +95,25 @@ class StubAnalysisService:
|
||||
self._recipe_parser_factory = None
|
||||
StubAnalysisService.instances.append(self)
|
||||
|
||||
async def analyze_uploaded_image(self, *, image_bytes: bytes | None, recipe_scanner) -> SimpleNamespace: # noqa: D401 - mirrors real signature
|
||||
async def analyze_uploaded_image(
|
||||
self, *, image_bytes: bytes | None, recipe_scanner
|
||||
) -> SimpleNamespace: # noqa: D401 - mirrors real signature
|
||||
if self.raise_for_uploaded:
|
||||
raise self.raise_for_uploaded
|
||||
self.upload_calls.append(image_bytes or b"")
|
||||
return self.result
|
||||
|
||||
async def analyze_remote_image(self, *, url: Optional[str], recipe_scanner, civitai_client) -> SimpleNamespace: # noqa: D401
|
||||
async def analyze_remote_image(
|
||||
self, *, url: Optional[str], recipe_scanner, civitai_client
|
||||
) -> SimpleNamespace: # noqa: D401
|
||||
if self.raise_for_remote:
|
||||
raise self.raise_for_remote
|
||||
self.remote_calls.append(url)
|
||||
return self.result
|
||||
|
||||
async def analyze_local_image(self, *, file_path: Optional[str], recipe_scanner) -> SimpleNamespace: # noqa: D401
|
||||
async def analyze_local_image(
|
||||
self, *, file_path: Optional[str], recipe_scanner
|
||||
) -> SimpleNamespace: # noqa: D401
|
||||
if self.raise_for_local:
|
||||
raise self.raise_for_local
|
||||
self.local_calls.append(file_path)
|
||||
@@ -125,11 +132,23 @@ class StubPersistenceService:
|
||||
self.save_calls: List[Dict[str, Any]] = []
|
||||
self.delete_calls: List[str] = []
|
||||
self.move_calls: List[Dict[str, str]] = []
|
||||
self.save_result = SimpleNamespace(payload={"success": True, "recipe_id": "stub-id"}, status=200)
|
||||
self.save_result = SimpleNamespace(
|
||||
payload={"success": True, "recipe_id": "stub-id"}, status=200
|
||||
)
|
||||
self.delete_result = SimpleNamespace(payload={"success": True}, status=200)
|
||||
StubPersistenceService.instances.append(self)
|
||||
|
||||
async def save_recipe(self, *, recipe_scanner, image_bytes, image_base64, name, tags, metadata, extension=None) -> SimpleNamespace: # noqa: D401
|
||||
async def save_recipe(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
image_bytes,
|
||||
image_base64,
|
||||
name,
|
||||
tags,
|
||||
metadata,
|
||||
extension=None,
|
||||
) -> SimpleNamespace: # noqa: D401
|
||||
self.save_calls.append(
|
||||
{
|
||||
"recipe_scanner": recipe_scanner,
|
||||
@@ -148,22 +167,42 @@ class StubPersistenceService:
|
||||
await recipe_scanner.remove_recipe(recipe_id)
|
||||
return self.delete_result
|
||||
|
||||
async def move_recipe(self, *, recipe_scanner, recipe_id: str, target_path: str) -> SimpleNamespace: # noqa: D401
|
||||
async def move_recipe(
|
||||
self, *, recipe_scanner, recipe_id: str, target_path: str
|
||||
) -> SimpleNamespace: # noqa: D401
|
||||
self.move_calls.append({"recipe_id": recipe_id, "target_path": target_path})
|
||||
return SimpleNamespace(
|
||||
payload={"success": True, "recipe_id": recipe_id, "new_file_path": target_path}, status=200
|
||||
payload={
|
||||
"success": True,
|
||||
"recipe_id": recipe_id,
|
||||
"new_file_path": target_path,
|
||||
},
|
||||
status=200,
|
||||
)
|
||||
|
||||
async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: Dict[str, Any]) -> SimpleNamespace: # pragma: no cover - unused by smoke tests
|
||||
return SimpleNamespace(payload={"success": True, "recipe_id": recipe_id, "updates": updates}, status=200)
|
||||
async def update_recipe(
|
||||
self, *, recipe_scanner, recipe_id: str, updates: Dict[str, Any]
|
||||
) -> SimpleNamespace: # pragma: no cover - unused by smoke tests
|
||||
return SimpleNamespace(
|
||||
payload={"success": True, "recipe_id": recipe_id, "updates": updates},
|
||||
status=200,
|
||||
)
|
||||
|
||||
async def reconnect_lora(self, *, recipe_scanner, recipe_id: str, lora_index: int, target_name: str) -> SimpleNamespace: # pragma: no cover
|
||||
async def reconnect_lora(
|
||||
self, *, recipe_scanner, recipe_id: str, lora_index: int, target_name: str
|
||||
) -> SimpleNamespace: # pragma: no cover
|
||||
return SimpleNamespace(payload={"success": True}, status=200)
|
||||
|
||||
async def bulk_delete(self, *, recipe_scanner, recipe_ids: List[str]) -> SimpleNamespace: # pragma: no cover
|
||||
return SimpleNamespace(payload={"success": True, "deleted": recipe_ids}, status=200)
|
||||
async def bulk_delete(
|
||||
self, *, recipe_scanner, recipe_ids: List[str]
|
||||
) -> SimpleNamespace: # pragma: no cover
|
||||
return SimpleNamespace(
|
||||
payload={"success": True, "deleted": recipe_ids}, status=200
|
||||
)
|
||||
|
||||
async def save_recipe_from_widget(self, *, recipe_scanner, metadata: Dict[str, Any], image_bytes: bytes) -> SimpleNamespace: # pragma: no cover
|
||||
async def save_recipe_from_widget(
|
||||
self, *, recipe_scanner, metadata: Dict[str, Any], image_bytes: bytes
|
||||
) -> SimpleNamespace: # pragma: no cover
|
||||
return SimpleNamespace(payload={"success": True}, status=200)
|
||||
|
||||
|
||||
@@ -176,7 +215,11 @@ class StubSharingService:
|
||||
self.share_calls: List[str] = []
|
||||
self.download_calls: List[str] = []
|
||||
self.share_result = SimpleNamespace(
|
||||
payload={"success": True, "download_url": "/share/stub", "filename": "recipe.png"},
|
||||
payload={
|
||||
"success": True,
|
||||
"download_url": "/share/stub",
|
||||
"filename": "recipe.png",
|
||||
},
|
||||
status=200,
|
||||
)
|
||||
self.download_info = SimpleNamespace(file_path="", download_filename="")
|
||||
@@ -186,7 +229,9 @@ class StubSharingService:
|
||||
self.share_calls.append(recipe_id)
|
||||
return self.share_result
|
||||
|
||||
async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace:
|
||||
async def prepare_download(
|
||||
self, *, recipe_scanner, recipe_id: str
|
||||
) -> SimpleNamespace:
|
||||
self.download_calls.append(recipe_id)
|
||||
return self.download_info
|
||||
|
||||
@@ -214,7 +259,9 @@ class StubCivitaiClient:
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRouteHarness]:
|
||||
async def recipe_harness(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> AsyncIterator[RecipeRouteHarness]:
|
||||
"""Context manager that yields a fully wired recipe route harness."""
|
||||
|
||||
StubAnalysisService.instances.clear()
|
||||
@@ -237,8 +284,12 @@ async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRou
|
||||
|
||||
monkeypatch.setattr(ServiceRegistry, "get_recipe_scanner", fake_get_recipe_scanner)
|
||||
monkeypatch.setattr(ServiceRegistry, "get_civitai_client", fake_get_civitai_client)
|
||||
monkeypatch.setattr(base_recipe_routes, "RecipeAnalysisService", StubAnalysisService)
|
||||
monkeypatch.setattr(base_recipe_routes, "RecipePersistenceService", StubPersistenceService)
|
||||
monkeypatch.setattr(
|
||||
base_recipe_routes, "RecipeAnalysisService", StubAnalysisService
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
base_recipe_routes, "RecipePersistenceService", StubPersistenceService
|
||||
)
|
||||
monkeypatch.setattr(base_recipe_routes, "RecipeSharingService", StubSharingService)
|
||||
monkeypatch.setattr(base_recipe_routes, "get_downloader", fake_get_downloader)
|
||||
monkeypatch.setattr(config, "loras_roots", [str(tmp_path)], raising=False)
|
||||
@@ -294,7 +345,9 @@ async def test_list_recipes_provides_file_urls(monkeypatch, tmp_path: Path) -> N
|
||||
async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
form = FormData()
|
||||
form.add_field("image", b"stub", filename="sample.png", content_type="image/png")
|
||||
form.add_field(
|
||||
"image", b"stub", filename="sample.png", content_type="image/png"
|
||||
)
|
||||
form.add_field("name", "Test Recipe")
|
||||
form.add_field("tags", json.dumps(["tag-a"]))
|
||||
form.add_field("metadata", json.dumps({"loras": []}))
|
||||
@@ -312,7 +365,9 @@ async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) ->
|
||||
assert save_payload["recipe_id"] == "saved-id"
|
||||
assert harness.persistence.save_calls[-1]["name"] == "Test Recipe"
|
||||
|
||||
harness.persistence.delete_result = SimpleNamespace(payload={"success": True}, status=200)
|
||||
harness.persistence.delete_result = SimpleNamespace(
|
||||
payload={"success": True}, status=200
|
||||
)
|
||||
|
||||
delete_response = await harness.client.delete("/api/lm/recipe/saved-id")
|
||||
delete_payload = await delete_response.json()
|
||||
@@ -326,14 +381,20 @@ async def test_move_recipe_invokes_persistence(monkeypatch, tmp_path: Path) -> N
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.post(
|
||||
"/api/lm/recipe/move",
|
||||
json={"recipe_id": "move-me", "target_path": str(tmp_path / "recipes" / "subdir")},
|
||||
json={
|
||||
"recipe_id": "move-me",
|
||||
"target_path": str(tmp_path / "recipes" / "subdir"),
|
||||
},
|
||||
)
|
||||
|
||||
payload = await response.json()
|
||||
assert response.status == 200
|
||||
assert payload["recipe_id"] == "move-me"
|
||||
assert harness.persistence.move_calls == [
|
||||
{"recipe_id": "move-me", "target_path": str(tmp_path / "recipes" / "subdir")}
|
||||
{
|
||||
"recipe_id": "move-me",
|
||||
"target_path": str(tmp_path / "recipes" / "subdir"),
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@@ -348,7 +409,10 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
async def fake_get_default_metadata_provider():
|
||||
return Provider()
|
||||
|
||||
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
monkeypatch.setattr(
|
||||
"py.recipes.enrichment.get_default_metadata_provider",
|
||||
fake_get_default_metadata_provider,
|
||||
)
|
||||
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
resources = [
|
||||
@@ -397,7 +461,9 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
assert harness.downloader.urls == ["https://example.com/images/1"]
|
||||
|
||||
|
||||
async def test_import_remote_recipe_falls_back_to_request_base_model(monkeypatch, tmp_path: Path) -> None:
|
||||
async def test_import_remote_recipe_falls_back_to_request_base_model(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
provider_calls: list[str | int] = []
|
||||
|
||||
class Provider:
|
||||
@@ -408,7 +474,10 @@ async def test_import_remote_recipe_falls_back_to_request_base_model(monkeypatch
|
||||
async def fake_get_default_metadata_provider():
|
||||
return Provider()
|
||||
|
||||
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
monkeypatch.setattr(
|
||||
"py.recipes.enrichment.get_default_metadata_provider",
|
||||
fake_get_default_metadata_provider,
|
||||
)
|
||||
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
resources = [
|
||||
@@ -444,13 +513,16 @@ async def test_import_remote_video_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
async def fake_get_default_metadata_provider():
|
||||
return SimpleNamespace(get_model_version_info=lambda id: ({}, None))
|
||||
|
||||
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
monkeypatch.setattr(
|
||||
"py.recipes.enrichment.get_default_metadata_provider",
|
||||
fake_get_default_metadata_provider,
|
||||
)
|
||||
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
harness.civitai.image_info["12345"] = {
|
||||
"id": 12345,
|
||||
"url": "https://image.civitai.com/x/y/original=true/video.mp4",
|
||||
"type": "video"
|
||||
"type": "video",
|
||||
}
|
||||
|
||||
response = await harness.client.get(
|
||||
@@ -469,7 +541,7 @@ async def test_import_remote_video_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
|
||||
# Verify downloader was called with rewritten URL
|
||||
assert "transcode=true" in harness.downloader.urls[0]
|
||||
|
||||
|
||||
# Verify persistence was called with correct extension
|
||||
call = harness.persistence.save_calls[-1]
|
||||
assert call["extension"] == ".mp4"
|
||||
@@ -477,7 +549,9 @@ async def test_import_remote_video_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
|
||||
async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided")
|
||||
harness.analysis.raise_for_uploaded = RecipeValidationError(
|
||||
"No image data provided"
|
||||
)
|
||||
|
||||
form = FormData()
|
||||
form.add_field("image", b"", filename="empty.png", content_type="image/png")
|
||||
@@ -504,7 +578,11 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
}
|
||||
|
||||
harness.sharing.share_result = SimpleNamespace(
|
||||
payload={"success": True, "download_url": "/api/share", "filename": "share.png"},
|
||||
payload={
|
||||
"success": True,
|
||||
"download_url": "/api/share",
|
||||
"filename": "share.png",
|
||||
},
|
||||
status=200,
|
||||
)
|
||||
harness.sharing.download_info = SimpleNamespace(
|
||||
@@ -519,15 +597,24 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
assert share_payload["filename"] == "share.png"
|
||||
assert harness.sharing.share_calls == [recipe_id]
|
||||
|
||||
download_response = await harness.client.get(f"/api/lm/recipe/{recipe_id}/share/download")
|
||||
download_response = await harness.client.get(
|
||||
f"/api/lm/recipe/{recipe_id}/share/download"
|
||||
)
|
||||
body = await download_response.read()
|
||||
|
||||
assert download_response.status == 200
|
||||
assert download_response.headers["Content-Disposition"] == 'attachment; filename="share.png"'
|
||||
assert (
|
||||
download_response.headers["Content-Disposition"]
|
||||
== 'attachment; filename="share.png"'
|
||||
)
|
||||
assert body == b"stub"
|
||||
|
||||
download_path.unlink(missing_ok=True)
|
||||
async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path) -> None:
|
||||
|
||||
|
||||
async def test_import_remote_recipe_merges_metadata(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
# 1. Mock Metadata Provider
|
||||
class Provider:
|
||||
async def get_model_version_info(self, model_version_id):
|
||||
@@ -536,22 +623,25 @@ async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path)
|
||||
async def fake_get_default_metadata_provider():
|
||||
return Provider()
|
||||
|
||||
monkeypatch.setattr("py.recipes.enrichment.get_default_metadata_provider", fake_get_default_metadata_provider)
|
||||
monkeypatch.setattr(
|
||||
"py.recipes.enrichment.get_default_metadata_provider",
|
||||
fake_get_default_metadata_provider,
|
||||
)
|
||||
|
||||
# 2. Mock ExifUtils to return some embedded metadata
|
||||
class MockExifUtils:
|
||||
@staticmethod
|
||||
def extract_image_metadata(path):
|
||||
return "Recipe metadata: " + json.dumps({
|
||||
"gen_params": {"prompt": "from embedded", "seed": 123}
|
||||
})
|
||||
return "Recipe metadata: " + json.dumps(
|
||||
{"gen_params": {"prompt": "from embedded", "seed": 123}}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(recipe_handlers, "ExifUtils", MockExifUtils)
|
||||
|
||||
# 3. Mock Parser Factory for StubAnalysisService
|
||||
class MockParser:
|
||||
async def parse_metadata(self, raw, recipe_scanner=None):
|
||||
return json.loads(raw[len("Recipe metadata: "):])
|
||||
return json.loads(raw[len("Recipe metadata: ") :])
|
||||
|
||||
class MockFactory:
|
||||
def create_parser(self, raw):
|
||||
@@ -562,12 +652,12 @@ async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path)
|
||||
# 4. Setup Harness and run test
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
harness.analysis._recipe_parser_factory = MockFactory()
|
||||
|
||||
|
||||
# Civitai meta via image_info
|
||||
harness.civitai.image_info["1"] = {
|
||||
"id": 1,
|
||||
"url": "https://example.com/images/1.jpg",
|
||||
"meta": {"prompt": "from civitai", "cfg": 7.0}
|
||||
"meta": {"prompt": "from civitai", "cfg": 7.0},
|
||||
}
|
||||
|
||||
resources = []
|
||||
@@ -583,11 +673,11 @@ async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path)
|
||||
|
||||
payload = await response.json()
|
||||
assert response.status == 200
|
||||
|
||||
|
||||
call = harness.persistence.save_calls[-1]
|
||||
metadata = call["metadata"]
|
||||
gen_params = metadata["gen_params"]
|
||||
|
||||
|
||||
assert gen_params["seed"] == 123
|
||||
|
||||
|
||||
@@ -619,3 +709,142 @@ async def test_get_recipe_syntax(monkeypatch, tmp_path: Path) -> None:
|
||||
response_404 = await harness.client.get("/api/lm/recipe/non-existent/syntax")
|
||||
assert response_404.status == 404
|
||||
|
||||
|
||||
async def test_batch_import_start_success(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.post(
|
||||
"/api/lm/recipes/batch-import/start",
|
||||
json={
|
||||
"items": [
|
||||
{"source": "https://example.com/image1.png"},
|
||||
{"source": "https://example.com/image2.png"},
|
||||
],
|
||||
"tags": ["batch", "import"],
|
||||
"skip_no_metadata": True,
|
||||
},
|
||||
)
|
||||
payload = await response.json()
|
||||
assert response.status == 200
|
||||
assert payload["success"] is True
|
||||
assert "operation_id" in payload
|
||||
|
||||
|
||||
async def test_batch_import_start_empty_items(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.post(
|
||||
"/api/lm/recipes/batch-import/start",
|
||||
json={"items": [], "tags": []},
|
||||
)
|
||||
payload = await response.json()
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
assert "No items provided" in payload["error"]
|
||||
|
||||
|
||||
async def test_batch_import_start_missing_source(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.post(
|
||||
"/api/lm/recipes/batch-import/start",
|
||||
json={"items": [{"source": ""}]},
|
||||
)
|
||||
payload = await response.json()
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
assert "source" in payload["error"].lower()
|
||||
|
||||
|
||||
async def test_batch_import_start_already_running(monkeypatch, tmp_path: Path) -> None:
|
||||
import asyncio
|
||||
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
original_analyze = harness.analysis.analyze_remote_image
|
||||
|
||||
async def slow_analyze(*, url, recipe_scanner, civitai_client):
|
||||
await asyncio.sleep(0.5)
|
||||
return await original_analyze(
|
||||
url=url, recipe_scanner=recipe_scanner, civitai_client=civitai_client
|
||||
)
|
||||
|
||||
harness.analysis.analyze_remote_image = slow_analyze
|
||||
|
||||
items = [{"source": f"https://example.com/image{i}.png"} for i in range(10)]
|
||||
|
||||
response1 = await harness.client.post(
|
||||
"/api/lm/recipes/batch-import/start",
|
||||
json={"items": items},
|
||||
)
|
||||
assert response1.status == 200
|
||||
|
||||
payload1 = await response1.json()
|
||||
assert payload1["success"] is True
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
response2 = await harness.client.post(
|
||||
"/api/lm/recipes/batch-import/start",
|
||||
json={"items": [{"source": "https://example.com/other.png"}]},
|
||||
)
|
||||
payload2 = await response2.json()
|
||||
assert response2.status == 409
|
||||
assert "already in progress" in payload2["error"].lower()
|
||||
|
||||
|
||||
async def test_batch_import_get_progress_not_found(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.get(
|
||||
"/api/lm/recipes/batch-import/progress",
|
||||
params={"operation_id": "nonexistent-id"},
|
||||
)
|
||||
payload = await response.json()
|
||||
assert response.status == 404
|
||||
assert payload["success"] is False
|
||||
|
||||
|
||||
async def test_batch_import_get_progress_missing_id(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.get("/api/lm/recipes/batch-import/progress")
|
||||
payload = await response.json()
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
|
||||
|
||||
async def test_batch_import_cancel_success(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
start_response = await harness.client.post(
|
||||
"/api/lm/recipes/batch-import/start",
|
||||
json={"items": [{"source": "https://example.com/image.png"}]},
|
||||
)
|
||||
start_payload = await start_response.json()
|
||||
operation_id = start_payload["operation_id"]
|
||||
|
||||
cancel_response = await harness.client.post(
|
||||
"/api/lm/recipes/batch-import/cancel",
|
||||
json={"operation_id": operation_id},
|
||||
)
|
||||
cancel_payload = await cancel_response.json()
|
||||
assert cancel_response.status == 200
|
||||
assert cancel_payload["success"] is True
|
||||
|
||||
|
||||
async def test_batch_import_cancel_not_found(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.post(
|
||||
"/api/lm/recipes/batch-import/cancel",
|
||||
json={"operation_id": "nonexistent-id"},
|
||||
)
|
||||
payload = await response.json()
|
||||
assert response.status == 404
|
||||
assert payload["success"] is False
|
||||
|
||||
|
||||
async def test_batch_import_cancel_missing_id(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
response = await harness.client.post(
|
||||
"/api/lm/recipes/batch-import/cancel",
|
||||
json={},
|
||||
)
|
||||
payload = await response.json()
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
|
||||
597
tests/services/test_batch_import_service.py
Normal file
597
tests/services/test_batch_import_service.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""Unit tests for BatchImportService."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.batch_import_service import (
|
||||
AdaptiveConcurrencyController,
|
||||
BatchImportItem,
|
||||
BatchImportProgress,
|
||||
BatchImportService,
|
||||
ImportItemType,
|
||||
ImportStatus,
|
||||
)
|
||||
|
||||
|
||||
class MockWebSocketManager:
|
||||
def __init__(self):
|
||||
self.broadcasts: List[Dict[str, Any]] = []
|
||||
|
||||
async def broadcast(self, data: Dict[str, Any]):
|
||||
self.broadcasts.append(data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockAnalysisResult:
|
||||
payload: Dict[str, Any]
|
||||
status: int = 200
|
||||
|
||||
|
||||
class MockAnalysisService:
|
||||
def __init__(self, results: Optional[Dict[str, MockAnalysisResult]] = None):
|
||||
self.results = results or {}
|
||||
self.call_count = 0
|
||||
self.last_url = None
|
||||
self.last_path = None
|
||||
|
||||
async def analyze_remote_image(self, *, url: str, recipe_scanner, civitai_client):
|
||||
self.call_count += 1
|
||||
self.last_url = url
|
||||
if url in self.results:
|
||||
return self.results[url]
|
||||
return MockAnalysisResult({"error": "No metadata found", "loras": []})
|
||||
|
||||
async def analyze_local_image(self, *, file_path: str, recipe_scanner):
|
||||
self.call_count += 1
|
||||
self.last_path = file_path
|
||||
if file_path in self.results:
|
||||
return self.results[file_path]
|
||||
return MockAnalysisResult({"error": "No metadata found", "loras": []})
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockSaveResult:
|
||||
payload: Dict[str, Any]
|
||||
status: int = 200
|
||||
|
||||
|
||||
class MockPersistenceService:
|
||||
def __init__(self, should_succeed: bool = True):
|
||||
self.should_succeed = should_succeed
|
||||
self.saved_recipes: List[Dict[str, Any]] = []
|
||||
self.call_count = 0
|
||||
|
||||
async def save_recipe(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
image_bytes: Optional[bytes] = None,
|
||||
image_base64: Optional[str] = None,
|
||||
name: str,
|
||||
tags: List[str],
|
||||
metadata: Dict[str, Any],
|
||||
extension: Optional[str] = None,
|
||||
):
|
||||
self.call_count += 1
|
||||
self.saved_recipes.append(
|
||||
{
|
||||
"name": name,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
if self.should_succeed:
|
||||
return MockSaveResult({"success": True, "id": f"recipe_{self.call_count}"})
|
||||
return MockSaveResult({"success": False, "error": "Save failed"}, status=400)
|
||||
|
||||
|
||||
class TestAdaptiveConcurrencyController:
|
||||
def test_initial_values(self):
|
||||
controller = AdaptiveConcurrencyController()
|
||||
assert controller.current_concurrency == 3
|
||||
assert controller.min_concurrency == 1
|
||||
assert controller.max_concurrency == 5
|
||||
|
||||
def test_custom_initial_values(self):
|
||||
controller = AdaptiveConcurrencyController(
|
||||
min_concurrency=2,
|
||||
max_concurrency=10,
|
||||
initial_concurrency=5,
|
||||
)
|
||||
assert controller.current_concurrency == 5
|
||||
assert controller.min_concurrency == 2
|
||||
assert controller.max_concurrency == 10
|
||||
|
||||
def test_increase_concurrency_on_success(self):
|
||||
controller = AdaptiveConcurrencyController(initial_concurrency=3)
|
||||
controller.record_result(duration=0.5, success=True)
|
||||
assert controller.current_concurrency == 4
|
||||
|
||||
def test_do_not_exceed_max(self):
|
||||
controller = AdaptiveConcurrencyController(
|
||||
max_concurrency=5,
|
||||
initial_concurrency=5,
|
||||
)
|
||||
controller.record_result(duration=0.5, success=True)
|
||||
assert controller.current_concurrency == 5
|
||||
|
||||
def test_decrease_concurrency_on_failure(self):
|
||||
controller = AdaptiveConcurrencyController(initial_concurrency=3)
|
||||
controller.record_result(duration=1.0, success=False)
|
||||
assert controller.current_concurrency == 2
|
||||
|
||||
def test_do_not_go_below_min(self):
|
||||
controller = AdaptiveConcurrencyController(
|
||||
min_concurrency=1,
|
||||
initial_concurrency=1,
|
||||
)
|
||||
controller.record_result(duration=1.0, success=False)
|
||||
assert controller.current_concurrency == 1
|
||||
|
||||
def test_slow_task_decreases_concurrency(self):
|
||||
controller = AdaptiveConcurrencyController(initial_concurrency=3)
|
||||
controller.record_result(duration=11.0, success=True)
|
||||
assert controller.current_concurrency == 2
|
||||
|
||||
def test_fast_task_increases_concurrency(self):
|
||||
controller = AdaptiveConcurrencyController(initial_concurrency=3)
|
||||
controller.record_result(duration=0.5, success=True)
|
||||
assert controller.current_concurrency == 4
|
||||
|
||||
def test_moderate_task_no_change(self):
|
||||
controller = AdaptiveConcurrencyController(initial_concurrency=3)
|
||||
controller.record_result(duration=5.0, success=True)
|
||||
assert controller.current_concurrency == 3
|
||||
|
||||
|
||||
class TestBatchImportProgress:
|
||||
def test_to_dict(self):
|
||||
progress = BatchImportProgress(
|
||||
operation_id="test-123",
|
||||
total=10,
|
||||
completed=5,
|
||||
success=3,
|
||||
failed=2,
|
||||
skipped=0,
|
||||
current_item="image.png",
|
||||
status="running",
|
||||
)
|
||||
result = progress.to_dict()
|
||||
assert result["operation_id"] == "test-123"
|
||||
assert result["total"] == 10
|
||||
assert result["completed"] == 5
|
||||
assert result["success"] == 3
|
||||
assert result["failed"] == 2
|
||||
assert result["progress_percent"] == 50.0
|
||||
|
||||
def test_progress_percent_zero_total(self):
|
||||
progress = BatchImportProgress(
|
||||
operation_id="test-123",
|
||||
total=0,
|
||||
)
|
||||
assert progress.to_dict()["progress_percent"] == 0
|
||||
|
||||
|
||||
class TestBatchImportItem:
|
||||
def test_defaults(self):
|
||||
item = BatchImportItem(
|
||||
id="item-1",
|
||||
source="https://example.com/image.png",
|
||||
item_type=ImportItemType.URL,
|
||||
)
|
||||
assert item.status == ImportStatus.PENDING
|
||||
assert item.error_message is None
|
||||
assert item.recipe_name is None
|
||||
|
||||
|
||||
class TestBatchImportService:
|
||||
@pytest.fixture
|
||||
def mock_services(self):
|
||||
ws_manager = MockWebSocketManager()
|
||||
analysis_service = MockAnalysisService()
|
||||
persistence_service = MockPersistenceService()
|
||||
logger = logging.getLogger("test")
|
||||
return ws_manager, analysis_service, persistence_service, logger
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_services):
|
||||
ws_manager, analysis_service, persistence_service, logger = mock_services
|
||||
return BatchImportService(
|
||||
analysis_service=analysis_service,
|
||||
persistence_service=persistence_service,
|
||||
ws_manager=ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
def test_is_import_running_no_operations(self, service):
|
||||
assert not service.is_import_running()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_batch_import_creates_operation(self, service):
|
||||
recipe_scanner_getter = lambda: SimpleNamespace()
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
operation_id = await service.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=[{"source": "https://example.com/image.png"}],
|
||||
)
|
||||
|
||||
assert operation_id is not None
|
||||
assert service.is_import_running(operation_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_progress(self, service):
|
||||
recipe_scanner_getter = lambda: SimpleNamespace()
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
operation_id = await service.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=[
|
||||
{"source": "https://example.com/1.png"},
|
||||
{"source": "https://example.com/2.png"},
|
||||
],
|
||||
)
|
||||
|
||||
progress = service.get_progress(operation_id)
|
||||
assert progress is not None
|
||||
assert progress.total == 2
|
||||
assert progress.status in ("pending", "running")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_import(self, service):
|
||||
recipe_scanner_getter = lambda: SimpleNamespace()
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
operation_id = await service.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=[{"source": "https://example.com/image.png"}],
|
||||
)
|
||||
|
||||
assert service.cancel_import(operation_id) is True
|
||||
assert service.cancel_import("nonexistent") is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_images_non_recursive(self, service, tmp_path):
|
||||
for i in range(3):
|
||||
(tmp_path / f"image{i}.png").write_bytes(b"fake-image")
|
||||
|
||||
(tmp_path / "subdir").mkdir()
|
||||
(tmp_path / "subdir" / "hidden.png").write_bytes(b"fake-image")
|
||||
|
||||
images = await service._discover_images(str(tmp_path), recursive=False)
|
||||
assert len(images) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_images_recursive(self, service, tmp_path):
|
||||
for i in range(2):
|
||||
(tmp_path / f"image{i}.png").write_bytes(b"fake-image")
|
||||
|
||||
subdir = tmp_path / "subdir"
|
||||
subdir.mkdir()
|
||||
for i in range(2):
|
||||
(subdir / f"nested{i}.jpg").write_bytes(b"fake-image")
|
||||
|
||||
images = await service._discover_images(str(tmp_path), recursive=True)
|
||||
assert len(images) == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_images_filters_by_extension(self, service, tmp_path):
|
||||
(tmp_path / "image.png").write_bytes(b"fake-image")
|
||||
(tmp_path / "image.jpg").write_bytes(b"fake-image")
|
||||
(tmp_path / "image.webp").write_bytes(b"fake-image")
|
||||
(tmp_path / "document.pdf").write_bytes(b"fake-doc")
|
||||
(tmp_path / "script.py").write_bytes(b"print('hello')")
|
||||
|
||||
images = await service._discover_images(str(tmp_path), recursive=False)
|
||||
assert len(images) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_images_invalid_directory(self, service):
|
||||
from py.services.recipes.errors import RecipeValidationError
|
||||
|
||||
with pytest.raises(RecipeValidationError):
|
||||
await service._discover_images("/nonexistent/path", recursive=False)
|
||||
|
||||
def test_is_supported_image(self, service):
|
||||
assert service._is_supported_image("test.png") is True
|
||||
assert service._is_supported_image("test.jpg") is True
|
||||
assert service._is_supported_image("test.jpeg") is True
|
||||
assert service._is_supported_image("test.webp") is True
|
||||
assert service._is_supported_image("test.gif") is True
|
||||
assert service._is_supported_image("test.bmp") is True
|
||||
assert service._is_supported_image("test.pdf") is False
|
||||
assert service._is_supported_image("test.txt") is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_import_processes_items(self, mock_services, tmp_path):
|
||||
ws_manager, _, persistence_service, logger = mock_services
|
||||
|
||||
analysis_service = MockAnalysisService(
|
||||
{
|
||||
"https://example.com/valid.png": MockAnalysisResult(
|
||||
{
|
||||
"loras": [{"name": "test-lora", "weight": 1.0}],
|
||||
"base_model": "SD1.5",
|
||||
"gen_params": {"steps": 20},
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
service = BatchImportService(
|
||||
analysis_service=analysis_service,
|
||||
persistence_service=persistence_service,
|
||||
ws_manager=ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
recipe_scanner_getter = lambda: SimpleNamespace(
|
||||
find_recipes_by_fingerprint=lambda x: [],
|
||||
add_recipe=lambda x: None,
|
||||
)
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
operation_id = await service.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=[
|
||||
{"source": "https://example.com/valid.png"},
|
||||
{"source": "https://example.com/no-meta.png"},
|
||||
],
|
||||
skip_no_metadata=True,
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
progress = service.get_progress(operation_id)
|
||||
assert progress is not None or persistence_service.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_directory_import(self, service, tmp_path):
|
||||
for i in range(5):
|
||||
(tmp_path / f"image{i}.png").write_bytes(b"fake-image")
|
||||
|
||||
recipe_scanner_getter = lambda: SimpleNamespace()
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
operation_id = await service.start_directory_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
directory=str(tmp_path),
|
||||
recursive=False,
|
||||
)
|
||||
|
||||
progress = service.get_progress(operation_id)
|
||||
assert progress is not None
|
||||
assert progress.total == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_broadcasts_progress(self, mock_services):
|
||||
ws_manager, analysis_service, persistence_service, logger = mock_services
|
||||
|
||||
service = BatchImportService(
|
||||
analysis_service=analysis_service,
|
||||
persistence_service=persistence_service,
|
||||
ws_manager=ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
recipe_scanner_getter = lambda: SimpleNamespace()
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
operation_id = await service.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=[{"source": "https://example.com/test.png"}],
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
assert len(ws_manager.broadcasts) > 0
|
||||
assert any(
|
||||
b.get("type") == "batch_import_progress" for b in ws_manager.broadcasts
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancellation_stops_processing(self, mock_services):
|
||||
ws_manager, analysis_service, persistence_service, logger = mock_services
|
||||
|
||||
service = BatchImportService(
|
||||
analysis_service=analysis_service,
|
||||
persistence_service=persistence_service,
|
||||
ws_manager=ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
recipe_scanner_getter = lambda: SimpleNamespace()
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
items = [{"source": f"https://example.com/{i}.png"} for i in range(10)]
|
||||
|
||||
operation_id = await service.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=items,
|
||||
)
|
||||
|
||||
service.cancel_import(operation_id)
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
progress = service.get_progress(operation_id)
|
||||
if progress:
|
||||
assert progress.status == "cancelled"
|
||||
|
||||
|
||||
class TestBatchImportServiceEdgeCases:
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
ws_manager = MockWebSocketManager()
|
||||
analysis_service = MockAnalysisService()
|
||||
persistence_service = MockPersistenceService()
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
return BatchImportService(
|
||||
analysis_service=analysis_service,
|
||||
persistence_service=persistence_service,
|
||||
ws_manager=ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_items_list(self, service):
|
||||
recipe_scanner_getter = lambda: SimpleNamespace()
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
operation_id = await service.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=[],
|
||||
)
|
||||
|
||||
progress = service.get_progress(operation_id)
|
||||
assert progress is not None
|
||||
assert progress.total == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_url_and_path_items(self, service, tmp_path):
|
||||
(tmp_path / "local.png").write_bytes(b"fake-image")
|
||||
|
||||
recipe_scanner_getter = lambda: SimpleNamespace()
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
operation_id = await service.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=[
|
||||
{"source": "https://example.com/remote.png", "type": "url"},
|
||||
{"source": str(tmp_path / "local.png"), "type": "local_path"},
|
||||
],
|
||||
)
|
||||
|
||||
progress = service.get_progress(operation_id)
|
||||
assert progress is not None
|
||||
assert progress.total == 2
|
||||
assert progress.items[0].item_type == ImportItemType.URL
|
||||
assert progress.items[1].item_type == ImportItemType.LOCAL_PATH
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_are_passed_to_persistence(self, tmp_path):
|
||||
ws_manager = MockWebSocketManager()
|
||||
analysis_service = MockAnalysisService(
|
||||
{
|
||||
str(tmp_path / "test.png"): MockAnalysisResult(
|
||||
{
|
||||
"loras": [{"name": "test-lora"}],
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
persistence_service = MockPersistenceService()
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
(tmp_path / "test.png").write_bytes(b"fake-image")
|
||||
|
||||
service = BatchImportService(
|
||||
analysis_service=analysis_service,
|
||||
persistence_service=persistence_service,
|
||||
ws_manager=ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
recipe_scanner_getter = lambda: SimpleNamespace(
|
||||
find_recipes_by_fingerprint=lambda x: [],
|
||||
)
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
operation_id = await service.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=[{"source": str(tmp_path / "test.png")}],
|
||||
tags=["batch-import", "test"],
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
if persistence_service.saved_recipes:
|
||||
assert "batch-import" in persistence_service.saved_recipes[0]["tags"]
|
||||
assert "test" in persistence_service.saved_recipes[0]["tags"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_duplicates_parameter(self, service):
|
||||
recipe_scanner_getter = lambda: SimpleNamespace()
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
operation_id = await service.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=[{"source": "https://example.com/test.png"}],
|
||||
skip_duplicates=True,
|
||||
)
|
||||
|
||||
progress = service.get_progress(operation_id)
|
||||
assert progress is not None
|
||||
assert progress.skip_duplicates is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_duplicates_false_by_default(self, service):
|
||||
recipe_scanner_getter = lambda: SimpleNamespace()
|
||||
civitai_client_getter = lambda: SimpleNamespace()
|
||||
|
||||
operation_id = await service.start_batch_import(
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
items=[{"source": "https://example.com/test.png"}],
|
||||
)
|
||||
|
||||
progress = service.get_progress(operation_id)
|
||||
assert progress is not None
|
||||
assert progress.skip_duplicates is False
|
||||
|
||||
|
||||
class TestInputValidation:
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
ws_manager = MockWebSocketManager()
|
||||
analysis_service = MockAnalysisService()
|
||||
persistence_service = MockPersistenceService()
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
return BatchImportService(
|
||||
analysis_service=analysis_service,
|
||||
persistence_service=persistence_service,
|
||||
ws_manager=ws_manager,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
def test_validate_valid_url(self, service):
|
||||
assert service._validate_url("https://example.com/image.png") is True
|
||||
assert service._validate_url("http://example.com/image.png") is True
|
||||
assert service._validate_url("https://civitai.com/images/123") is True
|
||||
|
||||
def test_validate_invalid_url(self, service):
|
||||
assert service._validate_url("not-a-url") is False
|
||||
assert service._validate_url("ftp://example.com/file") is False
|
||||
assert service._validate_url("") is False
|
||||
|
||||
def test_validate_valid_local_path(self, service, tmp_path):
|
||||
valid_path = str(tmp_path / "image.png")
|
||||
assert service._validate_local_path(valid_path) is True
|
||||
|
||||
def test_validate_invalid_local_path(self, service):
|
||||
assert service._validate_local_path("../etc/passwd") is False
|
||||
assert service._validate_local_path("relative/path.png") is False
|
||||
assert service._validate_local_path("") is False
|
||||
Reference in New Issue
Block a user