# Files
# ComfyUI-Lora-Manager/tests/services/test_route_support_services.py
#
# 256 lines
# 8.4 KiB
# Python

import asyncio
import json
import os
from pathlib import Path
from typing import Any, Dict, List
import pytest
from py.services.download_coordinator import DownloadCoordinator
from py.services.metadata_sync_service import MetadataSyncService
from py.services.preview_asset_service import PreviewAssetService
from py.services.tag_update_service import TagUpdateService
class DummySettings:
def __init__(self, values: Dict[str, Any] | None = None) -> None:
self._values = values or {}
def get(self, key: str, default: Any = None) -> Any:
return self._values.get(key, default)
class RecordingMetadataManager:
    """Metadata manager double that records saves and writes them to disk."""

    def __init__(self) -> None:
        """Start with an empty log of ``(path, metadata)`` saves."""
        self.saved: List[tuple[str, Dict[str, Any]]] = []

    async def save_metadata(self, path: str, metadata: Dict[str, Any]) -> bool:
        """Record a JSON snapshot of ``metadata`` and write it as JSON.

        The target file is ``path`` itself when it already ends with
        ``.metadata.json``; otherwise the extension of ``path`` is replaced
        by ``.metadata.json``. Always returns ``True``.
        """
        serialized = json.dumps(metadata)
        # Round-trip through JSON so later mutation of ``metadata`` by the
        # caller does not change the recorded entry.
        self.saved.append((path, json.loads(serialized)))

        if path.endswith(".metadata.json"):
            target = path
        else:
            stem = os.path.splitext(path)[0]
            target = f"{stem}.metadata.json"

        Path(target).write_text(serialized)
        return True

    async def hydrate_model_data(self, model_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``model_data`` unchanged."""
        return model_data
class RecordingPreviewService:
    """Preview service double that logs calls and stamps fixed preview fields."""

    def __init__(self) -> None:
        """Start with an empty log of ``(metadata_path, images)`` calls."""
        self.calls: List[tuple[str, List[Dict[str, Any]]]] = []

    async def ensure_preview_for_metadata(
        self, metadata_path: str, local_metadata: Dict[str, Any], images
    ) -> None:
        """Log the call and set fixed preview values on ``local_metadata``.

        ``images`` is copied into a new list; a falsy value is logged as an
        empty list.
        """
        image_snapshot = list(images) if images else []
        self.calls.append((metadata_path, image_snapshot))
        local_metadata.update(
            {"preview_url": "preview.webp", "preview_nsfw_level": 1}
        )
class DummyProvider:
def __init__(self, payload: Dict[str, Any]) -> None:
self.payload = payload
async def get_model_by_hash(self, sha256: str):
return self.payload, None
async def get_model_version(self, model_id: int, model_version_id: int | None):
return self.payload
class FakeExifUtils:
    """EXIF utility double whose optimisation step is a pass-through."""

    @staticmethod
    def optimize_image(**kwargs):
        """Return the ``image_data`` keyword argument with an empty dict."""
        untouched = kwargs["image_data"]
        return (untouched, {})
def test_metadata_sync_merges_remote_fields(tmp_path: Path) -> None:
    """Check the result of ``update_model_metadata`` against a remote payload.

    Asserts the returned metadata carries the remote model name and
    description, that trained words contain both the local and the remote
    entries, and that the recording doubles saw a save and a preview call.
    """
    remote_payload = {
        "baseModel": "SD15",
        "model": {
            "name": "Merged",
            "description": "desc",
            "tags": ["tag"],
            "creator": {"username": "user"},
        },
        "trainedWords": ["word"],
        "images": [{"url": "http://example", "nsfwLevel": 2, "type": "image"}],
    }
    metadata_recorder = RecordingMetadataManager()
    preview_recorder = RecordingPreviewService()
    remote = DummyProvider(remote_payload)

    def provide_default():
        # Awaitable that resolves to the stub provider.
        return asyncio.sleep(0, result=remote)

    def select_provider(_name=None):
        # Same awaitable, whatever provider name is requested.
        return asyncio.sleep(0, result=remote)

    sync_service = MetadataSyncService(
        metadata_manager=metadata_recorder,
        preview_service=preview_recorder,
        settings=DummySettings(),
        default_metadata_provider_factory=provide_default,
        metadata_provider_selector=select_provider,
    )

    target_path = str(tmp_path / "model.metadata.json")
    existing = {"civitai": {"trainedWords": ["existing"]}}

    merged = asyncio.run(
        sync_service.update_model_metadata(target_path, existing, remote.payload)
    )

    assert merged["model_name"] == "Merged"
    assert merged["modelDescription"] == "desc"
    assert set(merged["civitai"]["trainedWords"]) == {"existing", "word"}
    assert metadata_recorder.saved
    assert preview_recorder.calls
def test_metadata_sync_fetch_and_update_updates_cache(tmp_path: Path) -> None:
    """Check that ``fetch_and_update_model`` reports success and side effects.

    Asserts the call returns ``(True, None)``, that the supplied cache
    updater was invoked, and that the recording manager saw a save.
    """
    metadata_recorder = RecordingMetadataManager()
    preview_recorder = RecordingPreviewService()
    remote = DummyProvider(
        {
            "baseModel": "SDXL",
            "model": {"name": "Updated"},
            "images": [],
        }
    )

    update_cache_calls: List[Dict[str, Any]] = []

    async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
        # Log each cache update request made through this callback.
        entry = {"original": original, "metadata": metadata}
        update_cache_calls.append(entry)
        return True

    def provide_default():
        return asyncio.sleep(0, result=remote)

    def select_provider(_name=None):
        return asyncio.sleep(0, result=remote)

    sync_service = MetadataSyncService(
        metadata_manager=metadata_recorder,
        preview_service=preview_recorder,
        settings=DummySettings(),
        default_metadata_provider_factory=provide_default,
        metadata_provider_selector=select_provider,
    )

    model_file = str(tmp_path / "model.safetensors")
    model_data = {"sha256": "abc", "file_path": model_file}
    asyncio.run(metadata_recorder.hydrate_model_data(model_data))

    outcome = asyncio.run(
        sync_service.fetch_and_update_model(
            sha256="abc",
            file_path=model_file,
            model_data=model_data,
            update_cache_func=update_cache,
        )
    )
    success, error = outcome

    assert success is True
    assert error is None
    assert update_cache_calls
    assert metadata_recorder.saved
def test_preview_asset_service_replace_preview(tmp_path: Path) -> None:
    """Check that ``replace_preview`` stores the requested NSFW level.

    Asserts the returned mapping and the metadata file on disk both carry
    ``preview_nsfw_level == 2`` and that the cache callback was invoked.
    """
    sidecar = tmp_path / "sample.metadata.json"
    sidecar.write_text(json.dumps({}))

    async def metadata_loader(path: str) -> Dict[str, Any]:
        # Load metadata straight from the JSON file at ``path``.
        raw = Path(path).read_text()
        return json.loads(raw)

    metadata_recorder = RecordingMetadataManager()

    def no_downloader():
        # Awaitable resolving to ``None``.
        return asyncio.sleep(0, result=None)

    asset_service = PreviewAssetService(
        metadata_manager=metadata_recorder,
        downloader_factory=no_downloader,
        exif_utils=FakeExifUtils(),
    )

    preview_calls: List[Dict[str, Any]] = []

    async def update_preview(model_path: str, preview_path: str, nsfw: int) -> bool:
        # Log each preview update request made through this callback.
        preview_calls.append(
            {"model_path": model_path, "preview_path": preview_path, "nsfw": nsfw}
        )
        return True

    model_file = tmp_path / "sample.safetensors"
    model_path = str(model_file)
    Path(model_path).write_bytes(b"model")

    result = asyncio.run(
        asset_service.replace_preview(
            model_path=model_path,
            preview_data=b"image-bytes",
            content_type="image/png",
            original_filename="preview.png",
            nsfw_level=2,
            update_preview_in_cache=update_preview,
            metadata_loader=metadata_loader,
        )
    )

    assert result["preview_nsfw_level"] == 2
    assert preview_calls

    on_disk = json.loads(sidecar.read_text())
    assert on_disk["preview_nsfw_level"] == 2
def test_download_coordinator_emits_progress() -> None:
    """Check scheduling, cancelling and listing through ``DownloadCoordinator``.

    Asserts a scheduled download succeeds, reaches the download manager stub
    and produces at least one broadcast progress event; then that cancelling
    and listing active downloads return the stub manager's canned answers.
    """

    class WSStub:
        """Websocket manager double that issues ids and logs broadcasts."""

        def __init__(self) -> None:
            self.progress_events: List[Dict[str, Any]] = []
            self.counter = 0

        def generate_download_id(self) -> str:
            """Return the next sequential ``dl-<n>`` identifier."""
            self.counter += 1
            return f"dl-{self.counter}"

        async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None:
            """Log the payload together with its download id."""
            event = {"id": download_id}
            event.update(payload)
            self.progress_events.append(event)

    class DownloadManagerStub:
        """Download manager double with canned results."""

        def __init__(self) -> None:
            self.calls: List[Dict[str, Any]] = []

        async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
            """Log the keyword arguments and report one progress value."""
            self.calls.append(kwargs)
            # Report a single progress value through the supplied callback.
            report_progress = kwargs["progress_callback"]
            await report_progress(10)
            return {"success": True}

        async def cancel_download(self, download_id: str) -> Dict[str, Any]:
            """Return a success result echoing ``download_id``."""
            return {"success": True, "download_id": download_id}

        async def get_active_downloads(self) -> Dict[str, Any]:
            """Return an empty active-download listing."""
            return {"active": []}

    ws_stub = WSStub()
    manager_stub = DownloadManagerStub()

    def provide_manager():
        # Awaitable resolving to the stub download manager.
        return asyncio.sleep(0, result=manager_stub)

    coordinator = DownloadCoordinator(
        ws_manager=ws_stub,
        download_manager_factory=provide_manager,
    )

    result = asyncio.run(coordinator.schedule_download({"model_id": 1}))
    assert result["success"] is True
    assert manager_stub.calls
    assert ws_stub.progress_events

    scheduled_id = result["download_id"]
    cancel_result = asyncio.run(coordinator.cancel_download(scheduled_id))
    assert cancel_result["success"] is True

    active = asyncio.run(coordinator.list_active_downloads())
    assert active == {"active": []}
def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
    """Check that ``add_tags`` returns the expected tag list.

    Starting from a metadata file with ``["Existing"]`` and adding
    ``["New", "existing"]``, asserts the result is ``["Existing", "New"]``
    and that both the recording manager and the cache callback were used.
    """
    sidecar = tmp_path / "model.metadata.json"
    initial = {"tags": ["Existing"]}
    sidecar.write_text(json.dumps(initial))

    async def loader(path: str) -> Dict[str, Any]:
        # Load metadata straight from the JSON file at ``path``.
        raw = Path(path).read_text()
        return json.loads(raw)

    metadata_recorder = RecordingMetadataManager()
    tag_service = TagUpdateService(metadata_manager=metadata_recorder)

    cache_updates: List[Dict[str, Any]] = []

    async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
        # Log the metadata handed to the cache callback.
        cache_updates.append(metadata)
        return True

    model_file = str(tmp_path / "model.safetensors")
    tags = asyncio.run(
        tag_service.add_tags(
            file_path=model_file,
            new_tags=["New", "existing"],
            metadata_loader=loader,
            update_cache=update_cache,
        )
    )

    assert tags == ["Existing", "New"]
    assert metadata_recorder.saved
    assert cache_updates