feat: implement same_base update strategy for model annotations

Add support for configurable update flag strategy with new "same_base" mode that considers base model versions when determining update availability. The strategy is controlled by the "update_flag_strategy" setting.

When strategy is set to "same_base":
- Uses get_records_bulk instead of has_updates_bulk
- Compares model versions against highest local versions per base model
- Provides more granular update detection based on base model relationships

Fallback to existing bulk or individual update checks when:
- Strategy is not "same_base"
- Bulk operations fail
- Records are unavailable

This enables more precise update flagging for models sharing common bases.
This commit is contained in:
Will Miao
2025-11-17 19:26:41 +08:00
parent 8158441a92
commit 0e73db0669
5 changed files with 458 additions and 48 deletions

View File

@@ -11,6 +11,7 @@ from py.services.model_query import (
SearchStrategy,
SortParams,
)
from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord
from py.utils.models import BaseModelMetadata
@@ -98,6 +99,25 @@ class StubUpdateService:
return result
class StubUpdateServiceWithRecords(StubUpdateService):
def __init__(self, records, *, bulk_error: bool = False):
decisions = {
model_id: record.has_update()
for model_id, record in records.items()
}
super().__init__(decisions, bulk_error=bulk_error)
self.records = dict(records)
self.records_bulk_calls = []
async def get_records_bulk(self, model_type, model_ids):
self.records_bulk_calls.append((model_type, list(model_ids)))
return {
model_id: self.records[model_id]
for model_id in model_ids
if model_id in self.records
}
@pytest.mark.asyncio
async def test_get_paginated_data_uses_injected_collaborators():
data = [
@@ -461,6 +481,198 @@ async def test_get_paginated_data_annotates_update_flags_with_bulk_dedup():
assert response["total_pages"] == 1
@pytest.mark.asyncio
async def test_update_flag_strategy_same_base_prefers_matching_base():
items = [
{
"model_name": "Pony Version",
"civitai": {"modelId": 1, "id": 10, "baseModel": "Pony"},
"base_model": "Pony",
},
{
"model_name": "Flux Version",
"civitai": {"modelId": 1, "id": 20, "baseModel": "Flux 1.D"},
"base_model": "Flux 1.D",
},
]
repository = StubRepository(items)
filter_set = PassThroughFilterSet()
search_strategy = NoSearchStrategy()
record = ModelUpdateRecord(
model_type="stub",
model_id=1,
versions=[
ModelVersionRecord(
version_id=10,
name="Pony Local",
base_model="Pony",
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=True,
should_ignore=False,
sort_index=0,
),
ModelVersionRecord(
version_id=20,
name="Flux Local",
base_model="Flux 1.D",
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=True,
should_ignore=False,
sort_index=1,
),
ModelVersionRecord(
version_id=30,
name="Pony Remote",
base_model="Pony",
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=False,
should_ignore=False,
sort_index=2,
),
ModelVersionRecord(
version_id=40,
name="SDXL Remote",
base_model="SDXL",
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=False,
should_ignore=False,
sort_index=3,
),
],
last_checked_at=None,
should_ignore_model=False,
)
update_service = StubUpdateServiceWithRecords({1: record})
settings = StubSettings({"update_flag_strategy": "same_base"})
service = DummyService(
model_type="stub",
scanner=object(),
metadata_class=BaseModelMetadata,
cache_repository=repository,
filter_set=filter_set,
search_strategy=search_strategy,
settings_provider=settings,
update_service=update_service,
)
response = await service.get_paginated_data(
page=1,
page_size=10,
sort_by="name:asc",
)
assert update_service.records_bulk_calls == [("stub", [1])]
assert update_service.bulk_calls == []
assert len(response["items"]) == 2
flags = {item["model_name"]: item["update_available"] for item in response["items"]}
assert flags["Pony Version"] is True
assert flags["Flux Version"] is False
@pytest.mark.asyncio
async def test_update_flag_strategy_same_base_honors_latest_local_version():
items = [
{
"model_name": "Pony v0.1",
"civitai": {"modelId": 1, "id": 101, "baseModel": "Pony"},
"base_model": "Pony",
},
{
"model_name": "Pony v0.3",
"civitai": {"modelId": 1, "id": 103, "baseModel": "Pony"},
"base_model": "Pony",
},
]
repository = StubRepository(items)
filter_set = PassThroughFilterSet()
search_strategy = NoSearchStrategy()
record = ModelUpdateRecord(
model_type="stub",
model_id=1,
versions=[
ModelVersionRecord(
version_id=101,
name="Old Pony",
base_model="Pony",
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=True,
should_ignore=False,
sort_index=0,
),
ModelVersionRecord(
version_id=102,
name="Pony Remote",
base_model="Pony",
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=False,
should_ignore=False,
sort_index=1,
),
ModelVersionRecord(
version_id=103,
name="Middle Pony",
base_model="Pony",
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=True,
should_ignore=False,
sort_index=2,
),
ModelVersionRecord(
version_id=104,
name="Latest Pony",
base_model="Pony",
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=True,
should_ignore=False,
sort_index=3,
),
],
last_checked_at=None,
should_ignore_model=False,
)
update_service = StubUpdateServiceWithRecords({1: record})
settings = StubSettings({"update_flag_strategy": "same_base"})
service = DummyService(
model_type="stub",
scanner=object(),
metadata_class=BaseModelMetadata,
cache_repository=repository,
filter_set=filter_set,
search_strategy=search_strategy,
settings_provider=settings,
update_service=update_service,
)
response = await service.get_paginated_data(
page=1,
page_size=10,
sort_by="name:asc",
)
assert update_service.records_bulk_calls == [("stub", [1])]
flags = {item["model_name"]: item["update_available"] for item in response["items"]}
assert flags["Pony v0.1"] is False
assert flags["Pony v0.3"] is False
@pytest.mark.asyncio
async def test_get_paginated_data_filters_update_available_only():
items = [

View File

@@ -52,11 +52,11 @@ class NotFoundProvider:
return {}
def make_version(version_id, *, in_library, should_ignore=False):
def make_version(version_id, *, in_library, base_model=None, should_ignore=False):
return ModelVersionRecord(
version_id=version_id,
name=None,
base_model=None,
base_model=base_model,
released_at=None,
size_bytes=None,
preview_url=None,
@@ -147,6 +147,25 @@ def test_has_update_detects_newer_remote_version():
assert record.has_update() is True
def test_has_update_for_base_matches_same_base_model():
record = make_record(
make_version(5, in_library=True, base_model="Pony"),
make_version(6, in_library=False, base_model="Pony"),
make_version(7, in_library=False, base_model="Flux.1"),
)
assert record.has_update_for_base(5, "Pony") is True
def test_has_update_for_base_rejects_other_base_models():
record = make_record(
make_version(10, in_library=True, base_model="Flux"),
make_version(20, in_library=False, base_model="SDXL"),
)
assert record.has_update_for_base(10, "Flux") is False
@pytest.mark.asyncio
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
db_path = tmp_path / "updates.sqlite"