mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat: add model version management endpoints
- Add set_version_update_ignore endpoint to toggle ignore status for specific versions - Add get_model_versions endpoint to retrieve version details with optional refresh - Update serialization to include version-specific data and preview overrides - Modify database schema to support version-level ignore tracking - Improve error handling for rate limiting and missing models These changes enable granular control over version updates and provide better visibility into model version status.
This commit is contained in:
@@ -21,6 +21,7 @@ from py.services import model_file_service
|
||||
from py.services.downloader import DownloadProgress
|
||||
from py.services.metadata_sync_service import MetadataSyncService
|
||||
from py.services.model_file_service import AutoOrganizeResult
|
||||
from py.services.model_update_service import ModelVersionRecord
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
from py.services.websocket_manager import ws_manager
|
||||
from py.utils.exif_utils import ExifUtils
|
||||
@@ -42,11 +43,23 @@ class DummyRoutes(BaseModelRoutes):
|
||||
class NullUpdateRecord:
|
||||
model_type: str
|
||||
model_id: int
|
||||
largest_version_id: int | None = None
|
||||
version_ids: list[int] = field(default_factory=list)
|
||||
in_library_version_ids: list[int] = field(default_factory=list)
|
||||
versions: list[ModelVersionRecord] = field(default_factory=list)
|
||||
last_checked_at: float | None = None
|
||||
should_ignore: bool = False
|
||||
should_ignore_model: bool = False
|
||||
|
||||
@property
|
||||
def largest_version_id(self) -> int | None:
|
||||
if not self.versions:
|
||||
return None
|
||||
return max(version.version_id for version in self.versions)
|
||||
|
||||
@property
|
||||
def version_ids(self) -> list[int]:
|
||||
return [version.version_id for version in self.versions]
|
||||
|
||||
@property
|
||||
def in_library_version_ids(self) -> list[int]:
|
||||
return [version.version_id for version in self.versions if version.is_in_library]
|
||||
|
||||
def has_update(self) -> bool:
|
||||
return False
|
||||
@@ -60,10 +73,30 @@ class NullModelUpdateService:
|
||||
return None
|
||||
|
||||
async def update_in_library_versions(self, model_type, model_id, version_ids):
|
||||
return NullUpdateRecord(model_type=model_type, model_id=model_id, in_library_version_ids=list(version_ids))
|
||||
versions = [
|
||||
ModelVersionRecord(
|
||||
version_id=version_id,
|
||||
name=None,
|
||||
base_model=None,
|
||||
released_at=None,
|
||||
size_bytes=None,
|
||||
preview_url=None,
|
||||
is_in_library=True,
|
||||
should_ignore=False,
|
||||
)
|
||||
for version_id in version_ids
|
||||
]
|
||||
return NullUpdateRecord(model_type=model_type, model_id=model_id, versions=versions)
|
||||
|
||||
async def set_should_ignore(self, model_type, model_id, should_ignore):
|
||||
return NullUpdateRecord(model_type=model_type, model_id=model_id, should_ignore=should_ignore)
|
||||
return NullUpdateRecord(
|
||||
model_type=model_type,
|
||||
model_id=model_id,
|
||||
should_ignore_model=should_ignore,
|
||||
)
|
||||
|
||||
async def set_version_should_ignore(self, model_type, model_id, version_id, should_ignore):
|
||||
return await self.set_should_ignore(model_type, model_id, should_ignore)
|
||||
|
||||
async def get_record(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
57
tests/routes/test_model_update_handler.py
Normal file
57
tests/routes/test_model_update_handler.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.config import config
|
||||
from py.routes.handlers.model_handlers import ModelUpdateHandler
|
||||
from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord
|
||||
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, cache):
|
||||
self._cache = cache
|
||||
|
||||
async def get_cached_data(self):
|
||||
return self._cache
|
||||
|
||||
|
||||
class DummyService:
|
||||
def __init__(self, cache):
|
||||
self.model_type = "lora"
|
||||
self.scanner = DummyScanner(cache)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_preview_overrides_uses_static_urls():
|
||||
cache = SimpleNamespace(version_index={123: {"preview_url": "/tmp/previews/example.png"}})
|
||||
service = DummyService(cache)
|
||||
handler = ModelUpdateHandler(
|
||||
service=service,
|
||||
update_service=SimpleNamespace(),
|
||||
metadata_provider_selector=lambda *_: None,
|
||||
logger=logging.getLogger(__name__),
|
||||
)
|
||||
|
||||
record = ModelUpdateRecord(
|
||||
model_type="lora",
|
||||
model_id=42,
|
||||
versions=[
|
||||
ModelVersionRecord(
|
||||
version_id=123,
|
||||
name=None,
|
||||
base_model=None,
|
||||
released_at=None,
|
||||
size_bytes=None,
|
||||
preview_url=None,
|
||||
is_in_library=True,
|
||||
should_ignore=False,
|
||||
)
|
||||
],
|
||||
last_checked_at=None,
|
||||
should_ignore_model=False,
|
||||
)
|
||||
|
||||
overrides = await handler._build_preview_overrides(record)
|
||||
expected = config.get_preview_static_url("/tmp/previews/example.png")
|
||||
assert overrides == {123: expected}
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
@@ -8,7 +7,7 @@ from py.services.model_update_service import ModelUpdateService
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, raw_data):
|
||||
self._cache = SimpleNamespace(raw_data=raw_data)
|
||||
self._cache = SimpleNamespace(raw_data=raw_data, version_index={})
|
||||
|
||||
async def get_cached_data(self, *args, **kwargs):
|
||||
return self._cache
|
||||
@@ -41,7 +40,28 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path):
|
||||
{"civitai": {"modelId": 1, "id": 15}},
|
||||
]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider({"modelVersions": [{"id": 11}, {"id": 15}]})
|
||||
provider = DummyProvider(
|
||||
{
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 11,
|
||||
"name": "v1",
|
||||
"baseModel": "SD15",
|
||||
"publishedAt": "2024-01-01T00:00:00Z",
|
||||
"files": [{"sizeKB": 1024}],
|
||||
"images": [{"url": "https://example.com/1.png"}],
|
||||
},
|
||||
{
|
||||
"id": 15,
|
||||
"name": "v1.5",
|
||||
"baseModel": "SD15",
|
||||
"publishedAt": "2024-02-01T00:00:00Z",
|
||||
"files": [{"sizeKB": 512}],
|
||||
"images": [{"url": "https://example.com/2.png"}],
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
record = await service.get_record("lora", 1)
|
||||
@@ -51,6 +71,8 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path):
|
||||
assert record is not None
|
||||
assert record.version_ids == [11, 15]
|
||||
assert record.in_library_version_ids == [11, 15]
|
||||
assert [version.name for version in record.versions] == ["v1", "v1.5"]
|
||||
assert record.should_ignore_model is False
|
||||
assert record.has_update() is False
|
||||
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
@@ -64,7 +86,14 @@ async def test_refresh_respects_ignore_flag(tmp_path):
|
||||
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
|
||||
raw_data = [{"civitai": {"modelId": 2, "id": 21}}]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider({"modelVersions": [{"id": 21}, {"id": 22}]})
|
||||
provider = DummyProvider(
|
||||
{
|
||||
"modelVersions": [
|
||||
{"id": 21, "files": [], "images": []},
|
||||
{"id": 22, "files": [], "images": []},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
await service.set_should_ignore("lora", 2, True)
|
||||
@@ -74,6 +103,9 @@ async def test_refresh_respects_ignore_flag(tmp_path):
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
assert provider.calls == 0
|
||||
assert provider.bulk_calls == []
|
||||
record = await service.get_record("lora", 2)
|
||||
assert record is not None
|
||||
assert record.should_ignore_model is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -82,7 +114,10 @@ async def test_refresh_falls_back_when_bulk_not_supported(tmp_path):
|
||||
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
|
||||
raw_data = [{"civitai": {"modelId": 4, "id": 41}}]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider({"modelVersions": [{"id": 41}]}, support_bulk=False)
|
||||
provider = DummyProvider(
|
||||
{"modelVersions": [{"id": 41, "files": [], "images": []}]},
|
||||
support_bulk=False,
|
||||
)
|
||||
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
record = await service.get_record("lora", 4)
|
||||
@@ -117,7 +152,14 @@ async def test_update_in_library_versions_changes_update_state(tmp_path):
|
||||
service = ModelUpdateService(str(db_path), ttl_seconds=1)
|
||||
raw_data = [{"civitai": {"modelId": 3, "id": 31}}]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider({"modelVersions": [{"id": 31}, {"id": 35}]})
|
||||
provider = DummyProvider(
|
||||
{
|
||||
"modelVersions": [
|
||||
{"id": 31, "files": [], "images": []},
|
||||
{"id": 35, "files": [], "images": []},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
await service.update_in_library_versions("lora", 3, [31])
|
||||
@@ -130,3 +172,70 @@ async def test_update_in_library_versions_changes_update_state(tmp_path):
|
||||
record = await service.get_record("lora", 3)
|
||||
|
||||
assert record.has_update() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_version_ignore_blocks_update_flag(tmp_path):
|
||||
db_path = tmp_path / "updates.sqlite"
|
||||
service = ModelUpdateService(str(db_path), ttl_seconds=1)
|
||||
raw_data = [{"civitai": {"modelId": 5, "id": 51}}]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider(
|
||||
{
|
||||
"modelVersions": [
|
||||
{"id": 51, "files": [], "images": []},
|
||||
{"id": 55, "files": [], "images": []},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
record = await service.get_record("lora", 5)
|
||||
assert record is not None
|
||||
assert record.has_update() is True
|
||||
|
||||
await service.set_version_should_ignore("lora", 5, 55, True)
|
||||
record = await service.get_record("lora", 5)
|
||||
assert record is not None
|
||||
assert record.has_update() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_rewrites_remote_preview_urls(tmp_path):
|
||||
db_path = tmp_path / "updates.sqlite"
|
||||
service = ModelUpdateService(str(db_path), ttl_seconds=1)
|
||||
raw_data = [{"civitai": {"modelId": 7, "id": 71}}]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider(
|
||||
{
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 71,
|
||||
"files": [],
|
||||
"images": [
|
||||
{
|
||||
"url": "https://image.civitai.com/high/original=true/sample.png",
|
||||
"nsfwLevel": 6,
|
||||
"type": "image",
|
||||
},
|
||||
{
|
||||
"url": "https://image.civitai.com/safe/original=true/preview.png",
|
||||
"nsfwLevel": 1,
|
||||
"type": "image",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
await service.refresh_for_model_type("lora", scanner, provider)
|
||||
record = await service.get_record("lora", 7)
|
||||
|
||||
assert record is not None
|
||||
assert record.versions
|
||||
preview_url = record.versions[0].preview_url
|
||||
assert (
|
||||
preview_url
|
||||
== "https://image.civitai.com/safe/width=450,optimized=true/preview.png"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user