Files
ComfyUI-Lora-Manager/tests/services/test_route_support_services.py
Will Miao 3d6bb432c4 feat: normalize tags to lowercase for Windows compatibility, see #637
Convert all tags to lowercase in tag processing logic to prevent case sensitivity issues on Windows filesystems. This ensures consistent tag matching and prevents duplicate tags with different cases from being created.

Changes include:
- TagUpdateService now converts tags to lowercase before comparison
- Utils function converts model tags to lowercase before priority resolution
- Test cases updated to reflect lowercase tag expectations
2025-11-04 12:54:09 +08:00

270 lines
9.1 KiB
Python

import asyncio
import json
import os
from pathlib import Path
from typing import Any, Dict, List
import pytest
from py.services.download_coordinator import DownloadCoordinator
from py.services.downloader import DownloadProgress
from py.services.metadata_sync_service import MetadataSyncService
from py.services.preview_asset_service import PreviewAssetService
from py.services.tag_update_service import TagUpdateService
class DummySettings:
def __init__(self, values: Dict[str, Any] | None = None) -> None:
self._values = values or {}
def get(self, key: str, default: Any = None) -> Any:
return self._values.get(key, default)
class RecordingMetadataManager:
    """Metadata-manager stub that records saves and persists them to disk."""

    def __init__(self) -> None:
        # Each entry is (target path, deep-copied metadata snapshot).
        self.saved: List[tuple[str, Dict[str, Any]]] = []

    async def save_metadata(self, path: str, metadata: Dict[str, Any]) -> bool:
        """Record *metadata* and write it next to *path* as a JSON sidecar."""
        # Round-trip through JSON so later mutations of the caller's dict
        # cannot corrupt the recorded snapshot.
        snapshot = json.loads(json.dumps(metadata))
        self.saved.append((path, snapshot))
        if path.endswith(".metadata.json"):
            metadata_path = path
        else:
            stem, _ext = os.path.splitext(path)
            metadata_path = f"{stem}.metadata.json"
        Path(metadata_path).write_text(json.dumps(metadata))
        return True

    async def hydrate_model_data(self, model_data: Dict[str, Any]) -> Dict[str, Any]:
        """Pass-through: tests supply already-complete model payloads."""
        return model_data
class RecordingPreviewService:
    """Preview-service stub: logs requests and stamps fixed preview fields."""

    def __init__(self) -> None:
        # Each call is recorded as (metadata path, copied image list).
        self.calls: List[tuple[str, List[Dict[str, Any]]]] = []

    async def ensure_preview_for_metadata(
        self, metadata_path: str, local_metadata: Dict[str, Any], images
    ) -> None:
        """Record the request and mutate *local_metadata* with canned values."""
        recorded_images = [] if not images else list(images)
        self.calls.append((metadata_path, recorded_images))
        local_metadata["preview_url"] = "preview.webp"
        local_metadata["preview_nsfw_level"] = 1
class DummyProvider:
def __init__(self, payload: Dict[str, Any]) -> None:
self.payload = payload
async def get_model_by_hash(self, sha256: str):
return self.payload, None
async def get_model_version(self, model_id: int, model_version_id: int | None):
return self.payload
class FakeExifUtils:
    """Image-optimizer stub: hands the input bytes back untouched."""

    @staticmethod
    def optimize_image(**kwargs):
        # Mirrors the real API shape: (optimized_bytes, extracted_metadata),
        # but performs no processing at all.
        return kwargs["image_data"], {}
def test_metadata_sync_merges_remote_fields(tmp_path: Path) -> None:
    """Remote provider fields are merged into local metadata and persisted."""
    metadata_manager = RecordingMetadataManager()
    preview_service = RecordingPreviewService()
    remote_payload = {
        "baseModel": "SD15",
        "model": {
            "name": "Merged",
            "description": "desc",
            "tags": ["tag"],
            "creator": {"username": "user"},
        },
        "trainedWords": ["word"],
        "images": [{"url": "http://example", "nsfwLevel": 2, "type": "image"}],
    }
    provider = DummyProvider(remote_payload)
    sync_service = MetadataSyncService(
        metadata_manager=metadata_manager,
        preview_service=preview_service,
        settings=DummySettings(),
        # The service awaits these factories, so hand back an awaitable
        # that resolves to the stub provider.
        default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
        metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
    )
    metadata_file = str(tmp_path / "model.metadata.json")
    local_metadata = {"civitai": {"trainedWords": ["existing"]}}

    merged = asyncio.run(
        sync_service.update_model_metadata(metadata_file, local_metadata, provider.payload)
    )

    assert merged["model_name"] == "Merged"
    assert merged["modelDescription"] == "desc"
    # Locally-known trained words must survive the merge alongside remote ones.
    assert set(merged["civitai"]["trainedWords"]) == {"existing", "word"}
    assert metadata_manager.saved
    assert preview_service.calls
def test_metadata_sync_fetch_and_update_updates_cache(tmp_path: Path) -> None:
    """fetch_and_update_model reports success and pushes data into the cache."""
    metadata_manager = RecordingMetadataManager()
    preview_service = RecordingPreviewService()
    provider = DummyProvider(
        {"baseModel": "SDXL", "model": {"name": "Updated"}, "images": []}
    )
    cache_calls: List[Dict[str, Any]] = []

    async def record_cache_update(original: str, new: str, metadata: Dict[str, Any]) -> bool:
        cache_calls.append({"original": original, "metadata": metadata})
        return True

    sync_service = MetadataSyncService(
        metadata_manager=metadata_manager,
        preview_service=preview_service,
        settings=DummySettings(),
        default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
        metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
    )
    model_file = str(tmp_path / "model.safetensors")
    model_data = {"sha256": "abc", "file_path": model_file}
    # Hydration is a no-op in the stub; run it for parity with production flow.
    asyncio.run(metadata_manager.hydrate_model_data(model_data))

    success, error = asyncio.run(
        sync_service.fetch_and_update_model(
            sha256="abc",
            file_path=model_file,
            model_data=model_data,
            update_cache_func=record_cache_update,
        )
    )

    assert success is True
    assert error is None
    assert cache_calls
    assert metadata_manager.saved
def test_preview_asset_service_replace_preview(tmp_path: Path) -> None:
    """replace_preview stores the preview, updating metadata and the cache."""
    metadata_file = tmp_path / "sample.metadata.json"
    metadata_file.write_text(json.dumps({}))

    async def load_metadata(path: str) -> Dict[str, Any]:
        return json.loads(Path(path).read_text())

    metadata_manager = RecordingMetadataManager()
    preview_service = PreviewAssetService(
        metadata_manager=metadata_manager,
        # No real downloader is needed; the awaited factory resolves to None.
        downloader_factory=lambda: asyncio.sleep(0, result=None),
        exif_utils=FakeExifUtils(),
    )
    cache_updates: List[Dict[str, Any]] = []

    async def record_preview_update(model_path: str, preview_path: str, nsfw: int) -> bool:
        cache_updates.append(
            {"model_path": model_path, "preview_path": preview_path, "nsfw": nsfw}
        )
        return True

    model_file = str(tmp_path / "sample.safetensors")
    Path(model_file).write_bytes(b"model")

    result = asyncio.run(
        preview_service.replace_preview(
            model_path=model_file,
            preview_data=b"image-bytes",
            content_type="image/png",
            original_filename="preview.png",
            nsfw_level=2,
            update_preview_in_cache=record_preview_update,
            metadata_loader=load_metadata,
        )
    )

    assert result["preview_nsfw_level"] == 2
    assert cache_updates
    # The NSFW level must also land in the persisted metadata sidecar.
    persisted = json.loads(metadata_file.read_text())
    assert persisted["preview_nsfw_level"] == 2
def test_download_coordinator_emits_progress() -> None:
    """Scheduling a download relays progress snapshots over the websocket."""

    class WSStub:
        """Websocket manager stub that collects broadcast progress events."""

        def __init__(self) -> None:
            self.progress_events: List[Dict[str, Any]] = []
            self.counter = 0

        def generate_download_id(self) -> str:
            self.counter += 1
            return f"dl-{self.counter}"

        async def broadcast_download_progress(
            self, download_id: str, payload: Dict[str, Any]
        ) -> None:
            self.progress_events.append({"id": download_id, **payload})

    class DownloadManagerStub:
        """Download manager stub that emits one fixed progress snapshot."""

        def __init__(self) -> None:
            self.calls: List[Dict[str, Any]] = []
            self.snapshot = DownloadProgress(
                percent_complete=25.0,
                bytes_downloaded=256,
                total_bytes=1024,
                bytes_per_second=128.0,
                timestamp=0.0,
            )

        async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
            self.calls.append(kwargs)
            # Drive the coordinator's callback once so a broadcast happens.
            await kwargs["progress_callback"](self.snapshot)
            return {"success": True}

        async def cancel_download(self, download_id: str) -> Dict[str, Any]:
            return {"success": True, "download_id": download_id}

        async def get_active_downloads(self) -> Dict[str, Any]:
            return {"active": []}

    websocket = WSStub()
    download_manager = DownloadManagerStub()
    coordinator = DownloadCoordinator(
        ws_manager=websocket,
        download_manager_factory=lambda: asyncio.sleep(0, result=download_manager),
    )

    schedule_result = asyncio.run(coordinator.schedule_download({"model_id": 1}))
    assert schedule_result["success"] is True
    assert download_manager.calls
    assert websocket.progress_events

    snapshot = download_manager.snapshot
    first_event = websocket.progress_events[0]
    # Percent is rounded to an int before broadcasting; byte counters pass through.
    assert first_event["progress"] == round(snapshot.percent_complete)
    assert first_event["bytes_downloaded"] == snapshot.bytes_downloaded
    assert first_event["total_bytes"] == snapshot.total_bytes
    assert first_event["bytes_per_second"] == snapshot.bytes_per_second

    cancel_result = asyncio.run(coordinator.cancel_download(schedule_result["download_id"]))
    assert cancel_result["success"] is True
    assert asyncio.run(coordinator.list_active_downloads()) == {"active": []}
def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
    """add_tags de-duplicates against existing tags and persists the result."""
    metadata_file = tmp_path / "model.metadata.json"
    metadata_file.write_text(json.dumps({"tags": ["existing"]}))

    async def load_metadata(path: str) -> Dict[str, Any]:
        return json.loads(Path(path).read_text())

    metadata_manager = RecordingMetadataManager()
    tag_service = TagUpdateService(metadata_manager=metadata_manager)
    cache_updates: List[Dict[str, Any]] = []

    async def record_cache_update(original: str, new: str, metadata: Dict[str, Any]) -> bool:
        cache_updates.append(metadata)
        return True

    tags = asyncio.run(
        tag_service.add_tags(
            file_path=str(tmp_path / "model.safetensors"),
            # "existing" is already present and must not be duplicated.
            new_tags=["new", "existing"],
            metadata_loader=load_metadata,
            update_cache=record_cache_update,
        )
    )

    assert tags == ["existing", "new"]
    assert metadata_manager.saved
    assert cache_updates