mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
refactor(routes): limit update endpoints to essentials
This commit is contained in:
@@ -5,6 +5,7 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
import types
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: [])
|
||||
@@ -32,6 +33,41 @@ class DummyRoutes(BaseModelRoutes):
|
||||
def setup_specific_routes(self, registrar, prefix: str) -> None: # pragma: no cover - no extra routes in smoke tests
|
||||
return None
|
||||
|
||||
def __init__(self, service=None):
|
||||
super().__init__(service)
|
||||
self.set_model_update_service(NullModelUpdateService())
|
||||
|
||||
|
||||
@dataclass
|
||||
class NullUpdateRecord:
|
||||
model_type: str
|
||||
model_id: int
|
||||
largest_version_id: int | None = None
|
||||
version_ids: list[int] = field(default_factory=list)
|
||||
in_library_version_ids: list[int] = field(default_factory=list)
|
||||
last_checked_at: float | None = None
|
||||
should_ignore: bool = False
|
||||
|
||||
def has_update(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class NullModelUpdateService:
|
||||
async def refresh_for_model_type(self, *args, **kwargs):
|
||||
return {}
|
||||
|
||||
async def refresh_single_model(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
async def update_in_library_versions(self, model_type, model_id, version_ids):
|
||||
return NullUpdateRecord(model_type=model_type, model_id=model_id, in_library_version_ids=list(version_ids))
|
||||
|
||||
async def set_should_ignore(self, model_type, model_id, should_ignore):
|
||||
return NullUpdateRecord(model_type=model_type, model_id=model_id, should_ignore=should_ignore)
|
||||
|
||||
async def get_record(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
async def create_test_client(service) -> TestClient:
|
||||
routes = DummyRoutes(service)
|
||||
|
||||
85
tests/services/test_model_update_service.py
Normal file
85
tests/services/test_model_update_service.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.model_update_service import ModelUpdateService
|
||||
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, raw_data):
|
||||
self._cache = SimpleNamespace(raw_data=raw_data)
|
||||
|
||||
async def get_cached_data(self, *args, **kwargs):
|
||||
return self._cache
|
||||
|
||||
|
||||
class DummyProvider:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
self.calls: int = 0
|
||||
|
||||
async def get_model_versions(self, model_id):
|
||||
self.calls += 1
|
||||
return self.response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
|
||||
db_path = tmp_path / "updates.sqlite"
|
||||
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
|
||||
raw_data = [
|
||||
{"civitai": {"modelId": 1, "id": 11}},
|
||||
{"civitai": {"modelId": 1, "id": 15}},
|
||||
]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider({"modelVersions": [{"id": 11}, {"id": 15}]})
|
||||
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
record = await service.get_record("lora", 1)
|
||||
|
||||
assert provider.calls == 1
|
||||
assert record is not None
|
||||
assert record.version_ids == [11, 15]
|
||||
assert record.in_library_version_ids == [11, 15]
|
||||
assert record.has_update() is False
|
||||
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
assert provider.calls == 1, "provider should not be called again within TTL"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_respects_ignore_flag(tmp_path):
|
||||
db_path = tmp_path / "updates.sqlite"
|
||||
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
|
||||
raw_data = [{"civitai": {"modelId": 2, "id": 21}}]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider({"modelVersions": [{"id": 21}, {"id": 22}]})
|
||||
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
await service.set_should_ignore("lora", 2, True)
|
||||
|
||||
provider.calls = 0
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
assert provider.calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_in_library_versions_changes_update_state(tmp_path):
|
||||
db_path = tmp_path / "updates.sqlite"
|
||||
service = ModelUpdateService(str(db_path), ttl_seconds=1)
|
||||
raw_data = [{"civitai": {"modelId": 3, "id": 31}}]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider({"modelVersions": [{"id": 31}, {"id": 35}]})
|
||||
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
await service.update_in_library_versions("lora", 3, [31])
|
||||
record = await service.get_record("lora", 3)
|
||||
|
||||
assert record is not None
|
||||
assert record.has_update() is True
|
||||
|
||||
await service.update_in_library_versions("lora", 3, [31, 35])
|
||||
record = await service.get_record("lora", 3)
|
||||
|
||||
assert record.has_update() is False
|
||||
Reference in New Issue
Block a user