mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
Convert all tags to lowercase in tag processing logic to prevent case sensitivity issues on Windows filesystems. This ensures consistent tag matching and prevents duplicate tags with different cases from being created. Changes include: - TagUpdateService now converts tags to lowercase before comparison - Utils function converts model tags to lowercase before priority resolution - Test cases updated to reflect lowercase tag expectations
270 lines
9.1 KiB
Python
270 lines
9.1 KiB
Python
import asyncio
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List
|
|
|
|
import pytest
|
|
|
|
from py.services.download_coordinator import DownloadCoordinator
|
|
from py.services.downloader import DownloadProgress
|
|
from py.services.metadata_sync_service import MetadataSyncService
|
|
from py.services.preview_asset_service import PreviewAssetService
|
|
from py.services.tag_update_service import TagUpdateService
|
|
|
|
|
|
class DummySettings:
|
|
def __init__(self, values: Dict[str, Any] | None = None) -> None:
|
|
self._values = values or {}
|
|
|
|
def get(self, key: str, default: Any = None) -> Any:
|
|
return self._values.get(key, default)
|
|
|
|
|
|
class RecordingMetadataManager:
|
|
def __init__(self) -> None:
|
|
self.saved: List[tuple[str, Dict[str, Any]]] = []
|
|
|
|
async def save_metadata(self, path: str, metadata: Dict[str, Any]) -> bool:
|
|
self.saved.append((path, json.loads(json.dumps(metadata))))
|
|
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
|
|
Path(metadata_path).write_text(json.dumps(metadata))
|
|
return True
|
|
|
|
async def hydrate_model_data(self, model_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
return model_data
|
|
|
|
|
|
class RecordingPreviewService:
|
|
def __init__(self) -> None:
|
|
self.calls: List[tuple[str, List[Dict[str, Any]]]] = []
|
|
|
|
async def ensure_preview_for_metadata(
|
|
self, metadata_path: str, local_metadata: Dict[str, Any], images
|
|
) -> None:
|
|
self.calls.append((metadata_path, list(images or [])))
|
|
local_metadata["preview_url"] = "preview.webp"
|
|
local_metadata["preview_nsfw_level"] = 1
|
|
|
|
|
|
class DummyProvider:
|
|
def __init__(self, payload: Dict[str, Any]) -> None:
|
|
self.payload = payload
|
|
|
|
async def get_model_by_hash(self, sha256: str):
|
|
return self.payload, None
|
|
|
|
async def get_model_version(self, model_id: int, model_version_id: int | None):
|
|
return self.payload
|
|
|
|
|
|
class FakeExifUtils:
|
|
@staticmethod
|
|
def optimize_image(**kwargs):
|
|
return kwargs["image_data"], {}
|
|
|
|
|
|
def test_metadata_sync_merges_remote_fields(tmp_path: Path) -> None:
|
|
manager = RecordingMetadataManager()
|
|
preview = RecordingPreviewService()
|
|
provider = DummyProvider({
|
|
"baseModel": "SD15",
|
|
"model": {"name": "Merged", "description": "desc", "tags": ["tag"], "creator": {"username": "user"}},
|
|
"trainedWords": ["word"],
|
|
"images": [{"url": "http://example", "nsfwLevel": 2, "type": "image"}],
|
|
})
|
|
|
|
service = MetadataSyncService(
|
|
metadata_manager=manager,
|
|
preview_service=preview,
|
|
settings=DummySettings(),
|
|
default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
|
|
metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
|
|
)
|
|
|
|
metadata_path = str(tmp_path / "model.metadata.json")
|
|
local_metadata = {"civitai": {"trainedWords": ["existing"]}}
|
|
|
|
updated = asyncio.run(service.update_model_metadata(metadata_path, local_metadata, provider.payload))
|
|
|
|
assert updated["model_name"] == "Merged"
|
|
assert updated["modelDescription"] == "desc"
|
|
assert set(updated["civitai"]["trainedWords"]) == {"existing", "word"}
|
|
assert manager.saved
|
|
assert preview.calls
|
|
|
|
|
|
def test_metadata_sync_fetch_and_update_updates_cache(tmp_path: Path) -> None:
|
|
manager = RecordingMetadataManager()
|
|
preview = RecordingPreviewService()
|
|
provider = DummyProvider({
|
|
"baseModel": "SDXL",
|
|
"model": {"name": "Updated"},
|
|
"images": [],
|
|
})
|
|
|
|
update_cache_calls: List[Dict[str, Any]] = []
|
|
|
|
async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
|
|
update_cache_calls.append({"original": original, "metadata": metadata})
|
|
return True
|
|
|
|
service = MetadataSyncService(
|
|
metadata_manager=manager,
|
|
preview_service=preview,
|
|
settings=DummySettings(),
|
|
default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
|
|
metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
|
|
)
|
|
|
|
model_data = {"sha256": "abc", "file_path": str(tmp_path / "model.safetensors")}
|
|
asyncio.run(manager.hydrate_model_data(model_data))
|
|
success, error = asyncio.run(
|
|
service.fetch_and_update_model(
|
|
sha256="abc",
|
|
file_path=str(tmp_path / "model.safetensors"),
|
|
model_data=model_data,
|
|
update_cache_func=update_cache,
|
|
)
|
|
)
|
|
|
|
assert success is True
|
|
assert error is None
|
|
assert update_cache_calls
|
|
assert manager.saved
|
|
|
|
|
|
def test_preview_asset_service_replace_preview(tmp_path: Path) -> None:
|
|
metadata_path = tmp_path / "sample.metadata.json"
|
|
metadata_path.write_text(json.dumps({}))
|
|
|
|
async def metadata_loader(path: str) -> Dict[str, Any]:
|
|
return json.loads(Path(path).read_text())
|
|
|
|
manager = RecordingMetadataManager()
|
|
|
|
service = PreviewAssetService(
|
|
metadata_manager=manager,
|
|
downloader_factory=lambda: asyncio.sleep(0, result=None),
|
|
exif_utils=FakeExifUtils(),
|
|
)
|
|
|
|
preview_calls: List[Dict[str, Any]] = []
|
|
|
|
async def update_preview(model_path: str, preview_path: str, nsfw: int) -> bool:
|
|
preview_calls.append({"model_path": model_path, "preview_path": preview_path, "nsfw": nsfw})
|
|
return True
|
|
|
|
model_path = str(tmp_path / "sample.safetensors")
|
|
Path(model_path).write_bytes(b"model")
|
|
|
|
result = asyncio.run(
|
|
service.replace_preview(
|
|
model_path=model_path,
|
|
preview_data=b"image-bytes",
|
|
content_type="image/png",
|
|
original_filename="preview.png",
|
|
nsfw_level=2,
|
|
update_preview_in_cache=update_preview,
|
|
metadata_loader=metadata_loader,
|
|
)
|
|
)
|
|
|
|
assert result["preview_nsfw_level"] == 2
|
|
assert preview_calls
|
|
saved_metadata = json.loads(metadata_path.read_text())
|
|
assert saved_metadata["preview_nsfw_level"] == 2
|
|
|
|
|
|
def test_download_coordinator_emits_progress() -> None:
|
|
class WSStub:
|
|
def __init__(self) -> None:
|
|
self.progress_events: List[Dict[str, Any]] = []
|
|
self.counter = 0
|
|
|
|
def generate_download_id(self) -> str:
|
|
self.counter += 1
|
|
return f"dl-{self.counter}"
|
|
|
|
async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None:
|
|
self.progress_events.append({"id": download_id, **payload})
|
|
|
|
class DownloadManagerStub:
|
|
def __init__(self) -> None:
|
|
self.calls: List[Dict[str, Any]] = []
|
|
self.snapshot = DownloadProgress(
|
|
percent_complete=25.0,
|
|
bytes_downloaded=256,
|
|
total_bytes=1024,
|
|
bytes_per_second=128.0,
|
|
timestamp=0.0,
|
|
)
|
|
|
|
async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
|
|
self.calls.append(kwargs)
|
|
await kwargs["progress_callback"](self.snapshot)
|
|
return {"success": True}
|
|
|
|
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
|
return {"success": True, "download_id": download_id}
|
|
|
|
async def get_active_downloads(self) -> Dict[str, Any]:
|
|
return {"active": []}
|
|
|
|
ws_stub = WSStub()
|
|
manager_stub = DownloadManagerStub()
|
|
|
|
coordinator = DownloadCoordinator(
|
|
ws_manager=ws_stub,
|
|
download_manager_factory=lambda: asyncio.sleep(0, result=manager_stub),
|
|
)
|
|
|
|
result = asyncio.run(coordinator.schedule_download({"model_id": 1}))
|
|
|
|
assert result["success"] is True
|
|
assert manager_stub.calls
|
|
assert ws_stub.progress_events
|
|
expected_progress = round(manager_stub.snapshot.percent_complete)
|
|
first_event = ws_stub.progress_events[0]
|
|
assert first_event["progress"] == expected_progress
|
|
assert first_event["bytes_downloaded"] == manager_stub.snapshot.bytes_downloaded
|
|
assert first_event["total_bytes"] == manager_stub.snapshot.total_bytes
|
|
assert first_event["bytes_per_second"] == manager_stub.snapshot.bytes_per_second
|
|
|
|
cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"]))
|
|
assert cancel_result["success"] is True
|
|
|
|
active = asyncio.run(coordinator.list_active_downloads())
|
|
assert active == {"active": []}
|
|
|
|
|
|
def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
|
|
metadata_path = tmp_path / "model.metadata.json"
|
|
metadata_path.write_text(json.dumps({"tags": ["existing"]}))
|
|
|
|
async def loader(path: str) -> Dict[str, Any]:
|
|
return json.loads(Path(path).read_text())
|
|
|
|
manager = RecordingMetadataManager()
|
|
|
|
service = TagUpdateService(metadata_manager=manager)
|
|
|
|
cache_updates: List[Dict[str, Any]] = []
|
|
|
|
async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
|
|
cache_updates.append(metadata)
|
|
return True
|
|
|
|
tags = asyncio.run(
|
|
service.add_tags(
|
|
file_path=str(tmp_path / "model.safetensors"),
|
|
new_tags=["new", "existing"],
|
|
metadata_loader=loader,
|
|
update_cache=update_cache,
|
|
)
|
|
)
|
|
|
|
assert tags == ["existing", "new"]
|
|
assert manager.saved
|
|
assert cache_updates
|