mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
refactor(routes): extract route utilities into services
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import types
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
import asyncio
|
||||
import inspect
|
||||
from unittest import mock
|
||||
import sys
|
||||
|
||||
@@ -39,6 +41,13 @@ nodes_mock.NODE_CLASS_MAPPINGS = {}
|
||||
sys.modules['nodes'] = nodes_mock
|
||||
|
||||
|
||||
def pytest_pyfunc_call(pyfuncitem):
    """Run coroutine test functions with ``asyncio.run``.

    Returns True when the call was handled here so pytest skips its default
    call protocol; returns None to let pytest invoke synchronous tests.
    """
    if not inspect.iscoroutinefunction(pyfuncitem.function):
        return None
    asyncio.run(pyfuncitem.obj(**pyfuncitem.funcargs))
    return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockHashIndex:
|
||||
"""Minimal hash index stub mirroring the scanner contract."""
|
||||
|
||||
@@ -1,8 +1,35 @@
|
||||
import pytest
|
||||
|
||||
from py.services.base_model_service import BaseModelService
|
||||
from py.services.model_query import ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams
|
||||
from py.utils.models import BaseModelMetadata
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
|
||||
def import_from(module_name: str):
    """Import *module_name* after binding ``py`` to this repository's package.

    Replaces any previously imported (stale or shadowing) ``py`` module with
    the one under ``ROOT`` so subsequent submodule imports resolve against the
    repository sources.
    """
    existing = sys.modules.get("py")
    if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"):
        sys.modules.pop("py", None)
        spec = importlib.util.spec_from_file_location(
            "py",
            ROOT / "py/__init__.py",
            # Declare the package search path up front instead of patching
            # __path__ after exec, so imports that happen while __init__ runs
            # can already resolve submodules.
            submodule_search_locations=[str(ROOT / "py")],
        )
        assert spec and spec.loader
        module = importlib.util.module_from_spec(spec)
        # Register before exec_module (the canonical importlib recipe):
        # required if py/__init__.py imports its own submodules.
        sys.modules["py"] = module
        spec.loader.exec_module(module)  # type: ignore[union-attr]
    return importlib.import_module(module_name)
|
||||
|
||||
|
||||
# Bind the service symbols through import_from() so they come from the
# repository checkout rather than any previously imported "py" package.
BaseModelService = import_from("py.services.base_model_service").BaseModelService
model_query_module = import_from("py.services.model_query")
ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams = (
    model_query_module.ModelCacheRepository,
    model_query_module.ModelFilterSet,
    model_query_module.SearchStrategy,
    model_query_module.SortParams,
)
BaseModelMetadata = import_from("py.utils.models").BaseModelMetadata
|
||||
|
||||
|
||||
class StubSettings:
|
||||
|
||||
273
tests/services/test_route_support_services.py
Normal file
273
tests/services/test_route_support_services.py
Normal file
@@ -0,0 +1,273 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def import_from(module_name: str):
    """Import *module_name* after binding ``py`` to this repository's package.

    Replaces any previously imported (stale or shadowing) ``py`` module with
    the one under ``ROOT`` so subsequent submodule imports resolve against the
    repository sources.
    """
    existing = sys.modules.get("py")
    if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"):
        sys.modules.pop("py", None)
        spec = importlib.util.spec_from_file_location(
            "py",
            ROOT / "py/__init__.py",
            # Declare the package search path up front instead of patching
            # __path__ after exec, so imports that happen while __init__ runs
            # can already resolve submodules.
            submodule_search_locations=[str(ROOT / "py")],
        )
        assert spec and spec.loader
        module = importlib.util.module_from_spec(spec)
        # Register before exec_module (the canonical importlib recipe):
        # required if py/__init__.py imports its own submodules.
        sys.modules["py"] = module
        spec.loader.exec_module(module)  # type: ignore[union-attr]
    return importlib.import_module(module_name)
|
||||
|
||||
|
||||
# Resolve the services under test through the repo-local "py" package.
_download_mod = import_from("py.services.download_coordinator")
_metadata_mod = import_from("py.services.metadata_sync_service")
_preview_mod = import_from("py.services.preview_asset_service")
_tags_mod = import_from("py.services.tag_update_service")
DownloadCoordinator = _download_mod.DownloadCoordinator
MetadataSyncService = _metadata_mod.MetadataSyncService
PreviewAssetService = _preview_mod.PreviewAssetService
TagUpdateService = _tags_mod.TagUpdateService
|
||||
|
||||
|
||||
class DummySettings:
|
||||
def __init__(self, values: Dict[str, Any] | None = None) -> None:
|
||||
self._values = values or {}
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self._values.get(key, default)
|
||||
|
||||
|
||||
class RecordingMetadataManager:
    """Metadata manager double: records each save and persists JSON to disk."""

    def __init__(self) -> None:
        # (path, deep-copied metadata) per save_metadata call, in order.
        self.saved: List[tuple[str, Dict[str, Any]]] = []

    async def save_metadata(self, path: str, metadata: Dict[str, Any]) -> bool:
        # JSON round-trip produces a deep copy detached from the caller.
        snapshot = json.loads(json.dumps(metadata))
        self.saved.append((path, snapshot))
        if path.endswith(".metadata.json"):
            target = path
        else:
            stem, _ext = os.path.splitext(path)
            target = f"{stem}.metadata.json"
        Path(target).write_text(json.dumps(metadata))
        return True
|
||||
|
||||
|
||||
class RecordingPreviewService:
    """Preview service double: records requests and stamps a fixed preview."""

    def __init__(self) -> None:
        # (metadata_path, images list) per ensure_preview_for_metadata call.
        self.calls: List[tuple[str, List[Dict[str, Any]]]] = []

    async def ensure_preview_for_metadata(
        self, metadata_path: str, local_metadata: Dict[str, Any], images
    ) -> None:
        recorded = list(images) if images else []
        self.calls.append((metadata_path, recorded))
        # Mutate the caller's metadata in place, like the real service does.
        local_metadata["preview_url"] = "preview.webp"
        local_metadata["preview_nsfw_level"] = 1
|
||||
|
||||
|
||||
class DummyProvider:
|
||||
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||
self.payload = payload
|
||||
|
||||
async def get_model_by_hash(self, sha256: str):
|
||||
return self.payload, None
|
||||
|
||||
async def get_model_version(self, model_id: int, model_version_id: int | None):
|
||||
return self.payload
|
||||
|
||||
|
||||
class FakeExifUtils:
    """EXIF helper double: image optimization is a pass-through."""

    @staticmethod
    def optimize_image(**kwargs):
        # Hand back the input bytes untouched, with no extracted metadata.
        original = kwargs["image_data"]
        return original, {}
|
||||
|
||||
|
||||
def test_metadata_sync_merges_remote_fields(tmp_path: Path) -> None:
    """Remote model fields are merged into local metadata and persisted."""
    recorder = RecordingMetadataManager()
    previews = RecordingPreviewService()
    remote = DummyProvider(
        {
            "baseModel": "SD15",
            "model": {
                "name": "Merged",
                "description": "desc",
                "tags": ["tag"],
                "creator": {"username": "user"},
            },
            "trainedWords": ["word"],
            "images": [{"url": "http://example", "nsfwLevel": 2, "type": "image"}],
        }
    )

    service = MetadataSyncService(
        metadata_manager=recorder,
        preview_service=previews,
        settings=DummySettings(),
        default_metadata_provider_factory=lambda: asyncio.sleep(0, result=remote),
        metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=remote),
    )

    meta_file = str(tmp_path / "model.metadata.json")
    local = {"civitai": {"trainedWords": ["existing"]}}

    merged = asyncio.run(service.update_model_metadata(meta_file, local, remote.payload))

    # Remote name/description win; trainedWords are unioned with local ones.
    assert merged["model_name"] == "Merged"
    assert merged["modelDescription"] == "desc"
    assert set(merged["civitai"]["trainedWords"]) == {"existing", "word"}
    assert recorder.saved
    assert previews.calls
|
||||
|
||||
|
||||
def test_metadata_sync_fetch_and_update_updates_cache(tmp_path: Path) -> None:
    """fetch_and_update_model succeeds and pushes metadata into the cache."""
    recorder = RecordingMetadataManager()
    previews = RecordingPreviewService()
    remote = DummyProvider(
        {
            "baseModel": "SDXL",
            "model": {"name": "Updated"},
            "images": [],
        }
    )

    cache_calls: List[Dict[str, Any]] = []

    async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
        cache_calls.append({"original": original, "metadata": metadata})
        return True

    service = MetadataSyncService(
        metadata_manager=recorder,
        preview_service=previews,
        settings=DummySettings(),
        default_metadata_provider_factory=lambda: asyncio.sleep(0, result=remote),
        metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=remote),
    )

    target = str(tmp_path / "model.safetensors")
    model_data = {"sha256": "abc", "file_path": target}
    success, error = asyncio.run(
        service.fetch_and_update_model(
            sha256="abc",
            file_path=target,
            model_data=model_data,
            update_cache_func=update_cache,
        )
    )

    assert success is True
    assert error is None
    assert cache_calls
    assert recorder.saved
|
||||
|
||||
|
||||
def test_preview_asset_service_replace_preview(tmp_path: Path) -> None:
    """replace_preview persists the nsfw level and refreshes the cache entry."""
    meta_file = tmp_path / "sample.metadata.json"
    meta_file.write_text(json.dumps({}))

    async def metadata_loader(path: str) -> Dict[str, Any]:
        # Mimic the real loader by reading metadata straight from disk.
        return json.loads(Path(path).read_text())

    recorder = RecordingMetadataManager()
    service = PreviewAssetService(
        metadata_manager=recorder,
        downloader_factory=lambda: asyncio.sleep(0, result=None),
        exif_utils=FakeExifUtils(),
    )

    cache_updates: List[Dict[str, Any]] = []

    async def update_preview(model_path: str, preview_path: str, nsfw: int) -> bool:
        cache_updates.append(
            {"model_path": model_path, "preview_path": preview_path, "nsfw": nsfw}
        )
        return True

    target_model = str(tmp_path / "sample.safetensors")
    Path(target_model).write_bytes(b"model")

    result = asyncio.run(
        service.replace_preview(
            model_path=target_model,
            preview_data=b"image-bytes",
            content_type="image/png",
            original_filename="preview.png",
            nsfw_level=2,
            update_preview_in_cache=update_preview,
            metadata_loader=metadata_loader,
        )
    )

    assert result["preview_nsfw_level"] == 2
    assert cache_updates
    persisted = json.loads(meta_file.read_text())
    assert persisted["preview_nsfw_level"] == 2
|
||||
|
||||
|
||||
def test_download_coordinator_emits_progress() -> None:
    """schedule_download wires progress callbacks through the websocket stub."""

    class WSStub:
        # Websocket manager double: issues ids and records progress events.
        def __init__(self) -> None:
            self.progress_events: List[Dict[str, Any]] = []
            self.counter = 0

        def generate_download_id(self) -> str:
            self.counter += 1
            return f"dl-{self.counter}"

        async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None:
            event = {"id": download_id}
            event.update(payload)
            self.progress_events.append(event)

    class DownloadManagerStub:
        # Download manager double: records calls and fires one progress tick.
        def __init__(self) -> None:
            self.calls: List[Dict[str, Any]] = []

        async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
            self.calls.append(kwargs)
            await kwargs["progress_callback"](10)
            return {"success": True}

        async def cancel_download(self, download_id: str) -> Dict[str, Any]:
            return {"success": True, "download_id": download_id}

        async def get_active_downloads(self) -> Dict[str, Any]:
            return {"active": []}

    sockets = WSStub()
    downloads = DownloadManagerStub()

    coordinator = DownloadCoordinator(
        ws_manager=sockets,
        download_manager_factory=lambda: asyncio.sleep(0, result=downloads),
    )

    result = asyncio.run(coordinator.schedule_download({"model_id": 1}))

    assert result["success"] is True
    assert downloads.calls
    assert sockets.progress_events

    cancelled = asyncio.run(coordinator.cancel_download(result["download_id"]))
    assert cancelled["success"] is True

    assert asyncio.run(coordinator.list_active_downloads()) == {"active": []}
|
||||
|
||||
|
||||
def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
    """add_tags appends only case-insensitively new tags and syncs the cache."""
    meta_file = tmp_path / "model.metadata.json"
    meta_file.write_text(json.dumps({"tags": ["Existing"]}))

    async def loader(path: str) -> Dict[str, Any]:
        return json.loads(Path(path).read_text())

    recorder = RecordingMetadataManager()
    service = TagUpdateService(metadata_manager=recorder)

    synced_metadata: List[Dict[str, Any]] = []

    async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
        synced_metadata.append(metadata)
        return True

    tags = asyncio.run(
        service.add_tags(
            file_path=str(tmp_path / "model.safetensors"),
            new_tags=["New", "existing"],
            metadata_loader=loader,
            update_cache=update_cache,
        )
    )

    # "existing" duplicates "Existing" (case-insensitive) and is dropped.
    assert tags == ["Existing", "New"]
    assert recorder.saved
    assert synced_metadata
|
||||
Reference in New Issue
Block a user