Files
ComfyUI-Lora-Manager/tests/routes/test_model_update_handler.py
Will Miao d77b6d78b7 feat(model-updates): filter records without updates in refresh response
Add logic to only include model update records that have actual updates in the refresh response. This improves API efficiency by reducing payload size and only returning relevant data to clients.

The change:
- Adds filtering in ModelUpdateHandler.refresh_model_updates that checks each record's has_update method
- Only serializes records that have updates available
- Updates corresponding test to verify filtering behavior

This prevents returning unnecessary data for models that don't have updates available.
2025-10-25 21:31:36 +08:00

155 lines
4.3 KiB
Python

import json
import logging
from types import SimpleNamespace
import pytest
from py.config import config
from py.routes.handlers.model_handlers import ModelUpdateHandler
from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord
class DummyScanner:
    """Scanner stub that hands back a pre-built cache object."""

    def __init__(self, cache):
        # Stored verbatim; returned unchanged by ``get_cached_data``.
        self._cache = cache

    async def get_cached_data(self):
        """Return the cache object supplied at construction time."""
        cached = self._cache
        return cached
class DummyService:
    """Model-service stub exposing ``scanner`` and a fixed ``model_type``."""

    def __init__(self, cache):
        # Wrap the cache in a scanner stub so it is reachable through
        # ``service.scanner.get_cached_data()``.
        self.scanner = DummyScanner(cache)
        # Hard-coded model type for this stub.
        self.model_type = "lora"
class DummyUpdateService:
    """Update-service stub that returns canned records and logs every refresh call."""

    def __init__(self, records):
        self.records = records
        # One dict per ``refresh_for_model_type`` invocation, in call order.
        self.calls = []

    async def refresh_for_model_type(self, model_type, scanner, provider, *, force_refresh=False):
        """Log the call's arguments, then return the canned ``records`` unchanged."""
        invocation = dict(
            model_type=model_type,
            scanner=scanner,
            provider=provider,
            force_refresh=force_refresh,
        )
        self.calls.append(invocation)
        return self.records
@pytest.mark.asyncio
async def test_build_preview_overrides_uses_static_urls():
    """A cached preview path is mapped to the config's static preview URL."""
    preview_path = "/tmp/previews/example.png"
    version_id = 123
    cache = SimpleNamespace(version_index={version_id: {"preview_url": preview_path}})

    handler = ModelUpdateHandler(
        service=DummyService(cache),
        update_service=SimpleNamespace(),
        metadata_provider_selector=lambda *_: None,
        logger=logging.getLogger(__name__),
    )

    # The version carries no preview_url of its own; the expected override
    # below is derived from the cache entry instead.
    version = ModelVersionRecord(
        version_id=version_id,
        name=None,
        base_model=None,
        released_at=None,
        size_bytes=None,
        preview_url=None,
        is_in_library=True,
        should_ignore=False,
    )
    record = ModelUpdateRecord(
        model_type="lora",
        model_id=42,
        versions=[version],
        last_checked_at=None,
        should_ignore_model=False,
    )

    overrides = await handler._build_preview_overrides(record)

    assert overrides == {version_id: config.get_preview_static_url(preview_path)}
@pytest.mark.asyncio
async def test_refresh_model_updates_filters_records_without_updates():
    """Only records reporting an available update appear in the refresh response."""

    def make_record(model_id, version_id, name, in_library):
        # Single-version lora record. This test expects the version that is
        # not in the library to be the one reported as having an update.
        return ModelUpdateRecord(
            model_type="lora",
            model_id=model_id,
            versions=[
                ModelVersionRecord(
                    version_id=version_id,
                    name=name,
                    base_model=None,
                    released_at=None,
                    size_bytes=None,
                    preview_url=None,
                    is_in_library=in_library,
                    should_ignore=False,
                )
            ],
            last_checked_at=None,
            should_ignore_model=False,
        )

    service = DummyService(SimpleNamespace(version_index={}))
    update_service = DummyUpdateService(
        {
            1: make_record(1, 10, "v1", in_library=False),
            2: make_record(2, 20, "v2", in_library=True),
        }
    )

    async def metadata_selector(name):
        assert name == "civitai_api"
        return object()

    handler = ModelUpdateHandler(
        service=service,
        update_service=update_service,
        metadata_provider_selector=metadata_selector,
        logger=logging.getLogger(__name__),
    )

    class DummyRequest:
        # Request stand-in: readable, empty JSON body and no query parameters.
        can_read_body = True
        query = {}

        async def json(self):
            return {}

    response = await handler.refresh_model_updates(DummyRequest())

    assert response.status == 200
    payload = json.loads(response.text)
    assert payload["success"] is True

    # Model 2 has every version in the library, so only model 1 remains.
    serialized = payload["records"]
    assert len(serialized) == 1
    first = serialized[0]
    assert first["modelId"] == 1
    assert first["hasUpdate"] is True

    # The handler must delegate exactly once to the update service.
    assert len(update_service.calls) == 1
    call = update_service.calls[0]
    assert call["model_type"] == "lora"
    assert call["scanner"] is service.scanner
    assert call["force_refresh"] is False
    assert call["provider"] is not None