mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 13:42:12 -03:00
270 lines
9.1 KiB
Python
270 lines
9.1 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List
|
|
|
|
import pytest
|
|
|
|
from py.services.download_coordinator import DownloadCoordinator
|
|
from py.services.downloader import DownloadProgress
|
|
from py.services.metadata_sync_service import MetadataSyncService
|
|
from py.services.preview_asset_service import PreviewAssetService
|
|
from py.services.tag_update_service import TagUpdateService
|
|
|
|
|
|
class DummySettings:
|
|
def __init__(self, values: Dict[str, Any] | None = None) -> None:
|
|
self._values = values or {}
|
|
|
|
def get(self, key: str, default: Any = None) -> Any:
|
|
return self._values.get(key, default)
|
|
|
|
|
|
class RecordingMetadataManager:
    """Metadata manager double that records every save and writes it to disk."""

    def __init__(self) -> None:
        # One entry per save_metadata call: (path as passed in, JSON round-tripped copy).
        self.saved: List[tuple[str, Dict[str, Any]]] = []

    async def save_metadata(self, path: str, metadata: Dict[str, Any]) -> bool:
        """Record the call and write ``metadata`` as JSON to the metadata file.

        A ``path`` that already ends in ``.metadata.json`` is used as the
        target; otherwise the extension of ``path`` is replaced with
        ``.metadata.json``. Always returns ``True``.
        """
        # Round-trip through JSON so later mutation of ``metadata`` by the
        # caller does not change what was recorded.
        snapshot = json.loads(json.dumps(metadata))
        self.saved.append((path, snapshot))

        if path.endswith(".metadata.json"):
            target = path
        else:
            stem, _extension = os.path.splitext(path)
            target = f"{stem}.metadata.json"

        Path(target).write_text(json.dumps(metadata))
        return True

    async def hydrate_model_data(self, model_data: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``model_data`` unchanged (same object)."""
        return model_data
|
|
|
|
|
|
class RecordingPreviewService:
    """Preview service double that logs calls and stamps fixed preview fields."""

    def __init__(self) -> None:
        # One entry per call: (metadata_path, list copy of the images argument).
        self.calls: List[tuple[str, List[Dict[str, Any]]]] = []

    async def ensure_preview_for_metadata(
        self,
        metadata_path: str,
        local_metadata: Dict[str, Any],
        images,
    ) -> None:
        """Record the call and write fixed preview values into ``local_metadata``.

        ``images`` is copied into a new list (a falsy value becomes an empty
        list). ``local_metadata`` is mutated in place: ``preview_url`` is set
        to ``"preview.webp"`` and ``preview_nsfw_level`` to ``1``.
        """
        captured_images = list(images) if images else []
        self.calls.append((metadata_path, captured_images))

        fixed_fields = (
            ("preview_url", "preview.webp"),
            ("preview_nsfw_level", 1),
        )
        for field, value in fixed_fields:
            local_metadata[field] = value
|
|
|
|
|
|
class DummyProvider:
|
|
def __init__(self, payload: Dict[str, Any]) -> None:
|
|
self.payload = payload
|
|
|
|
async def get_model_by_hash(self, sha256: str):
|
|
return self.payload, None
|
|
|
|
async def get_model_version(self, model_id: int, model_version_id: int | None):
|
|
return self.payload
|
|
|
|
|
|
class FakeExifUtils:
    """EXIF utility double whose optimisation step hands the image data straight back."""

    @staticmethod
    def optimize_image(**kwargs):
        """Return ``(kwargs["image_data"], {})``.

        All other keyword arguments are ignored. Raises ``KeyError`` when
        ``image_data`` is not supplied.
        """
        image_bytes = kwargs["image_data"]
        return image_bytes, {}
|
|
|
|
|
|
def test_metadata_sync_merges_remote_fields(tmp_path: Path) -> None:
    """Check that ``update_model_metadata`` merges the remote payload into local metadata.

    Asserts that the returned metadata carries the payload's model name and
    description, that ``civitai.trainedWords`` contains both the local and
    the remote word, and that the metadata manager and preview service
    doubles were each invoked.
    """
    remote_payload = {
        "baseModel": "SD15",
        "model": {
            "name": "Merged",
            "description": "desc",
            "tags": ["tag"],
            "creator": {"username": "user"},
        },
        "trainedWords": ["word"],
        "images": [{"url": "http://example", "nsfwLevel": 2, "type": "image"}],
    }
    manager = RecordingMetadataManager()
    preview = RecordingPreviewService()
    provider = DummyProvider(remote_payload)

    # Both hooks return an awaitable that resolves to the dummy provider.
    def default_provider_factory():
        return asyncio.sleep(0, result=provider)

    def provider_selector(_name=None):
        return asyncio.sleep(0, result=provider)

    service = MetadataSyncService(
        metadata_manager=manager,
        preview_service=preview,
        settings=DummySettings(),
        default_metadata_provider_factory=default_provider_factory,
        metadata_provider_selector=provider_selector,
    )

    metadata_path = str(tmp_path / "model.metadata.json")
    local_metadata = {"civitai": {"trainedWords": ["existing"]}}

    updated = asyncio.run(
        service.update_model_metadata(metadata_path, local_metadata, provider.payload)
    )

    assert updated["model_name"] == "Merged"
    assert updated["modelDescription"] == "desc"
    merged_words = set(updated["civitai"]["trainedWords"])
    assert merged_words == {"existing", "word"}
    assert manager.saved
    assert preview.calls
|
|
|
|
|
|
def test_metadata_sync_fetch_and_update_updates_cache(tmp_path: Path) -> None:
    """Check that ``fetch_and_update_model`` succeeds, saves metadata and updates the cache.

    Asserts a ``(True, None)`` result, at least one call to the supplied
    cache-update callback, and at least one recorded metadata save.
    """
    manager = RecordingMetadataManager()
    preview = RecordingPreviewService()
    provider = DummyProvider(
        {
            "baseModel": "SDXL",
            "model": {"name": "Updated"},
            "images": [],
        }
    )

    update_cache_calls: List[Dict[str, Any]] = []

    async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
        # Only the original path and the metadata are recorded; ``new`` is ignored.
        record = {"original": original, "metadata": metadata}
        update_cache_calls.append(record)
        return True

    # Both hooks return an awaitable that resolves to the dummy provider.
    def default_provider_factory():
        return asyncio.sleep(0, result=provider)

    def provider_selector(_name=None):
        return asyncio.sleep(0, result=provider)

    service = MetadataSyncService(
        metadata_manager=manager,
        preview_service=preview,
        settings=DummySettings(),
        default_metadata_provider_factory=default_provider_factory,
        metadata_provider_selector=provider_selector,
    )

    model_file = str(tmp_path / "model.safetensors")
    model_data = {"sha256": "abc", "file_path": model_file}
    asyncio.run(manager.hydrate_model_data(model_data))

    outcome = asyncio.run(
        service.fetch_and_update_model(
            sha256="abc",
            file_path=model_file,
            model_data=model_data,
            update_cache_func=update_cache,
        )
    )
    success, error = outcome

    assert success is True
    assert error is None
    assert update_cache_calls
    assert manager.saved
|
|
|
|
|
|
def test_preview_asset_service_replace_preview(tmp_path: Path) -> None:
    """Check that ``replace_preview`` reports and persists the requested NSFW level.

    Asserts that the returned mapping and the metadata file on disk both
    carry ``preview_nsfw_level == 2`` and that the cache-update callback
    was invoked.
    """
    metadata_path = tmp_path / "sample.metadata.json"
    metadata_path.write_text(json.dumps({}))

    async def metadata_loader(path: str) -> Dict[str, Any]:
        raw_text = Path(path).read_text()
        return json.loads(raw_text)

    manager = RecordingMetadataManager()

    # The downloader factory returns an awaitable that resolves to None.
    def downloader_factory():
        return asyncio.sleep(0, result=None)

    service = PreviewAssetService(
        metadata_manager=manager,
        downloader_factory=downloader_factory,
        exif_utils=FakeExifUtils(),
    )

    preview_calls: List[Dict[str, Any]] = []

    async def update_preview(model_path: str, preview_path: str, nsfw: int) -> bool:
        call_record = {
            "model_path": model_path,
            "preview_path": preview_path,
            "nsfw": nsfw,
        }
        preview_calls.append(call_record)
        return True

    model_path = str(tmp_path / "sample.safetensors")
    Path(model_path).write_bytes(b"model")

    replace_call = service.replace_preview(
        model_path=model_path,
        preview_data=b"image-bytes",
        content_type="image/png",
        original_filename="preview.png",
        nsfw_level=2,
        update_preview_in_cache=update_preview,
        metadata_loader=metadata_loader,
    )
    result = asyncio.run(replace_call)

    assert result["preview_nsfw_level"] == 2
    assert preview_calls

    # Re-read the file from disk to confirm the level was persisted.
    saved_metadata = json.loads(metadata_path.read_text())
    assert saved_metadata["preview_nsfw_level"] == 2
|
|
|
|
|
|
def test_download_coordinator_emits_progress() -> None:
    """Check that the coordinator forwards download progress and delegates cancel/list.

    Asserts that ``schedule_download`` succeeds, that the download manager
    stub was called, and that the first broadcast event mirrors the stub's
    progress snapshot (rounded percent plus the byte counters). Also
    asserts that ``cancel_download`` succeeds and that
    ``list_active_downloads`` returns the stub's response.
    """

    class WSStub:
        """Websocket manager double: issues sequential ids and stores broadcasts."""

        def __init__(self) -> None:
            self.progress_events: List[Dict[str, Any]] = []
            self.counter = 0

        def generate_download_id(self) -> str:
            next_index = self.counter + 1
            self.counter = next_index
            return f"dl-{next_index}"

        async def broadcast_download_progress(
            self, download_id: str, payload: Dict[str, Any]
        ) -> None:
            # Payload keys are merged after "id", so a payload "id" would win.
            event = {"id": download_id, **payload}
            self.progress_events.append(event)

    class DownloadManagerStub:
        """Download manager double that reports one fixed progress snapshot."""

        def __init__(self) -> None:
            self.calls: List[Dict[str, Any]] = []
            self.snapshot = DownloadProgress(
                percent_complete=25.0,
                bytes_downloaded=256,
                total_bytes=1024,
                bytes_per_second=128.0,
                timestamp=0.0,
            )

        async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
            self.calls.append(kwargs)
            report_progress = kwargs["progress_callback"]
            await report_progress(self.snapshot)
            return {"success": True}

        async def cancel_download(self, download_id: str) -> Dict[str, Any]:
            return {"success": True, "download_id": download_id}

        async def get_active_downloads(self) -> Dict[str, Any]:
            return {"active": []}

    ws_stub = WSStub()
    manager_stub = DownloadManagerStub()

    # The factory returns an awaitable that resolves to the manager stub.
    def download_manager_factory():
        return asyncio.sleep(0, result=manager_stub)

    coordinator = DownloadCoordinator(
        ws_manager=ws_stub,
        download_manager_factory=download_manager_factory,
    )

    result = asyncio.run(coordinator.schedule_download({"model_id": 1}))

    assert result["success"] is True
    assert manager_stub.calls
    assert ws_stub.progress_events

    expected_progress = round(manager_stub.snapshot.percent_complete)
    first_event = ws_stub.progress_events[0]
    assert first_event["progress"] == expected_progress
    for field in ("bytes_downloaded", "total_bytes", "bytes_per_second"):
        assert first_event[field] == getattr(manager_stub.snapshot, field)

    cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"]))
    assert cancel_result["success"] is True

    active = asyncio.run(coordinator.list_active_downloads())
    assert active == {"active": []}
|
|
|
|
|
|
def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
    """Check that ``add_tags`` returns ``["Existing", "New"]`` for this input.

    The metadata file starts with the tag ``"Existing"`` and the request
    adds ``"New"`` and ``"existing"``. Also asserts that the metadata
    manager recorded a save and that the cache-update callback was invoked.
    """
    metadata_path = tmp_path / "model.metadata.json"
    initial_metadata = {"tags": ["Existing"]}
    metadata_path.write_text(json.dumps(initial_metadata))

    async def loader(path: str) -> Dict[str, Any]:
        raw_text = Path(path).read_text()
        return json.loads(raw_text)

    manager = RecordingMetadataManager()
    service = TagUpdateService(metadata_manager=manager)

    cache_updates: List[Dict[str, Any]] = []

    async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
        # Only the metadata argument is recorded.
        cache_updates.append(metadata)
        return True

    add_call = service.add_tags(
        file_path=str(tmp_path / "model.safetensors"),
        new_tags=["New", "existing"],
        metadata_loader=loader,
        update_cache=update_cache,
    )
    tags = asyncio.run(add_call)

    assert tags == ["Existing", "New"]
    assert manager.saved
    assert cache_updates
|