mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
- Rename `preview_overrides` to `version_context` to better reflect expanded purpose - Add file_path and file_name fields to version serialization - Update method names and parameter signatures for consistency - Include file metadata from cache in version context building - Maintain backward compatibility with existing preview URL functionality The changes provide more comprehensive version information including file details while maintaining existing preview override behavior.
155 lines
4.3 KiB
Python
155 lines
4.3 KiB
Python
import json
|
|
import logging
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from py.config import config
|
|
from py.routes.handlers.model_handlers import ModelUpdateHandler
|
|
from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord
|
|
|
|
|
|
class DummyScanner:
|
|
def __init__(self, cache):
|
|
self._cache = cache
|
|
|
|
async def get_cached_data(self):
|
|
return self._cache
|
|
|
|
|
|
class DummyService:
|
|
def __init__(self, cache):
|
|
self.model_type = "lora"
|
|
self.scanner = DummyScanner(cache)
|
|
|
|
|
|
class DummyUpdateService:
|
|
def __init__(self, records):
|
|
self.records = records
|
|
self.calls = []
|
|
|
|
async def refresh_for_model_type(self, model_type, scanner, provider, *, force_refresh=False):
|
|
self.calls.append(
|
|
{
|
|
"model_type": model_type,
|
|
"scanner": scanner,
|
|
"provider": provider,
|
|
"force_refresh": force_refresh,
|
|
}
|
|
)
|
|
return self.records
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_build_version_context_includes_static_urls():
|
|
cache = SimpleNamespace(version_index={123: {"preview_url": "/tmp/previews/example.png"}})
|
|
service = DummyService(cache)
|
|
handler = ModelUpdateHandler(
|
|
service=service,
|
|
update_service=SimpleNamespace(),
|
|
metadata_provider_selector=lambda *_: None,
|
|
logger=logging.getLogger(__name__),
|
|
)
|
|
|
|
record = ModelUpdateRecord(
|
|
model_type="lora",
|
|
model_id=42,
|
|
versions=[
|
|
ModelVersionRecord(
|
|
version_id=123,
|
|
name=None,
|
|
base_model=None,
|
|
released_at=None,
|
|
size_bytes=None,
|
|
preview_url=None,
|
|
is_in_library=True,
|
|
should_ignore=False,
|
|
)
|
|
],
|
|
last_checked_at=None,
|
|
should_ignore_model=False,
|
|
)
|
|
|
|
overrides = await handler._build_version_context(record)
|
|
expected = config.get_preview_static_url("/tmp/previews/example.png")
|
|
assert overrides == {123: {"file_path": None, "file_name": None, "preview_override": expected}}
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_refresh_model_updates_filters_records_without_updates():
|
|
cache = SimpleNamespace(version_index={})
|
|
service = DummyService(cache)
|
|
|
|
record_with_update = ModelUpdateRecord(
|
|
model_type="lora",
|
|
model_id=1,
|
|
versions=[
|
|
ModelVersionRecord(
|
|
version_id=10,
|
|
name="v1",
|
|
base_model=None,
|
|
released_at=None,
|
|
size_bytes=None,
|
|
preview_url=None,
|
|
is_in_library=False,
|
|
should_ignore=False,
|
|
)
|
|
],
|
|
last_checked_at=None,
|
|
should_ignore_model=False,
|
|
)
|
|
record_without_update = ModelUpdateRecord(
|
|
model_type="lora",
|
|
model_id=2,
|
|
versions=[
|
|
ModelVersionRecord(
|
|
version_id=20,
|
|
name="v2",
|
|
base_model=None,
|
|
released_at=None,
|
|
size_bytes=None,
|
|
preview_url=None,
|
|
is_in_library=True,
|
|
should_ignore=False,
|
|
)
|
|
],
|
|
last_checked_at=None,
|
|
should_ignore_model=False,
|
|
)
|
|
|
|
update_service = DummyUpdateService({1: record_with_update, 2: record_without_update})
|
|
|
|
async def metadata_selector(name):
|
|
assert name == "civitai_api"
|
|
return object()
|
|
|
|
handler = ModelUpdateHandler(
|
|
service=service,
|
|
update_service=update_service,
|
|
metadata_provider_selector=metadata_selector,
|
|
logger=logging.getLogger(__name__),
|
|
)
|
|
|
|
class DummyRequest:
|
|
can_read_body = True
|
|
query = {}
|
|
|
|
async def json(self):
|
|
return {}
|
|
|
|
response = await handler.refresh_model_updates(DummyRequest())
|
|
assert response.status == 200
|
|
|
|
payload = json.loads(response.text)
|
|
assert payload["success"] is True
|
|
assert len(payload["records"]) == 1
|
|
assert payload["records"][0]["modelId"] == 1
|
|
assert payload["records"][0]["hasUpdate"] is True
|
|
|
|
assert len(update_service.calls) == 1
|
|
call = update_service.calls[0]
|
|
assert call["model_type"] == "lora"
|
|
assert call["scanner"] is service.scanner
|
|
assert call["force_refresh"] is False
|
|
assert call["provider"] is not None
|