From 0e73db06692e4ef3f432c1c3cb4a14a73523790f Mon Sep 17 00:00:00 2001 From: Will Miao Date: Mon, 17 Nov 2025 19:26:41 +0800 Subject: [PATCH] feat: implement same_base update strategy for model annotations Add support for configurable update flag strategy with new "same_base" mode that considers base model versions when determining update availability. The strategy is controlled by the "update_flag_strategy" setting. When strategy is set to "same_base": - Uses get_records_bulk instead of has_updates_bulk - Compares model versions against highest local versions per base model - Provides more granular update detection based on base model relationships Fallback to existing bulk or individual update checks when: - Strategy is not "same_base" - Bulk operations fail - Records are unavailable This enables more precise update flagging for models sharing common bases. --- py/services/base_model_service.py | 139 +++++++++++-- py/services/model_update_service.py | 131 +++++++++--- py/services/settings_manager.py | 1 + tests/services/test_base_model_service.py | 212 ++++++++++++++++++++ tests/services/test_model_update_service.py | 23 ++- 5 files changed, 458 insertions(+), 48 deletions(-) diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index d28b8f72..a9d1fa09 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -267,20 +267,49 @@ class BaseModelService(ABC): if not ordered_ids: return annotated + strategy_value = self.settings.get("update_flag_strategy") + if isinstance(strategy_value, str) and strategy_value.strip(): + strategy = strategy_value.strip().lower() + else: + strategy = "any" + same_base_mode = strategy == "same_base" + + records = None resolved: Optional[Dict[int, bool]] = None - bulk_method = getattr(self.update_service, "has_updates_bulk", None) - if callable(bulk_method): - try: - resolved = await bulk_method(self.model_type, ordered_ids) - except Exception as exc: - logger.error( - "Failed to resolve update status in bulk for %s models (%s): %s", - self.model_type, - ordered_ids, - exc, - exc_info=True, - ) - resolved = None + if same_base_mode: + record_method = getattr(self.update_service, "get_records_bulk", None) + if callable(record_method): + try: + records = await record_method(self.model_type, ordered_ids) + resolved = { + model_id: record.has_update() + for model_id, record in records.items() + } + except Exception as exc: + logger.error( + "Failed to resolve update records in bulk for %s models (%s): %s", + self.model_type, + ordered_ids, + exc, + exc_info=True, + ) + records = None + resolved = None + + if resolved is None: + bulk_method = getattr(self.update_service, "has_updates_bulk", None) + if callable(bulk_method): + try: + resolved = await bulk_method(self.model_type, ordered_ids) + except Exception as exc: + logger.error( + "Failed to resolve update status in bulk for %s models (%s): %s", + self.model_type, + ordered_ids, + exc, + exc_info=True, + ) + resolved = None if resolved is None: tasks = [ @@ -301,8 +330,24 @@ class BaseModelService(ABC): resolved[model_id] = bool(result) for model_id, items_for_id in id_to_items.items(): - flag = bool(resolved.get(model_id, False)) + default_flag = bool(resolved.get(model_id, False)) if resolved else False + record = records.get(model_id) if records else None + base_highest_versions = ( + self._build_highest_local_versions_by_base(record) if same_base_mode and record else {} + ) for item in items_for_id: + if same_base_mode and record is not None: + base_model = self._extract_base_model(item) + normalized_base = self._normalize_base_model_name(base_model) + threshold_version = base_highest_versions.get(normalized_base) if normalized_base else None + if threshold_version is None: + threshold_version = self._extract_version_id(item) + flag = record.has_update_for_base( + threshold_version, + base_model, + ) + else: + flag = default_flag item['update_available'] = flag return annotated @@ -319,7 +364,71 @@ class BaseModelService(ABC): return int(value) except (TypeError, ValueError): return None - + + @staticmethod + def _extract_version_id(item: Dict) -> Optional[int]: + civitai = item.get('civitai') if isinstance(item, dict) else None + if not isinstance(civitai, dict): + return None + value = civitai.get('id') + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + @staticmethod + def _extract_base_model(item: Dict) -> Optional[str]: + value = item.get('base_model') + if value is None: + return None + if isinstance(value, str): + candidate = value.strip() + else: + try: + candidate = str(value).strip() + except Exception: + return None + return candidate if candidate else None + + @staticmethod + def _normalize_base_model_name(value: Optional[str]) -> Optional[str]: + """Return a lowercased, trimmed base model name for comparison.""" + + if value is None: + return None + if isinstance(value, str): + candidate = value.strip() + else: + try: + candidate = str(value).strip() + except Exception: + return None + return candidate.lower() if candidate else None + + def _build_highest_local_versions_by_base(self, record) -> Dict[str, int]: + """Return the highest local version id known for each normalized base model.""" + + if record is None: + return {} + + highest_by_base: Dict[str, int] = {} + for version in getattr(record, "versions", []): + if not getattr(version, "is_in_library", False): + continue + normalized_base = self._normalize_base_model_name(getattr(version, "base_model", None)) + if normalized_base is None: + continue + version_id = getattr(version, "version_id", None) + if version_id is None: + continue + current_max = highest_by_base.get(normalized_base) + if current_max is None or version_id > current_max: + highest_by_base[normalized_base] = version_id + + return highest_by_base + def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict: """Apply pagination to filtered data""" total_items = len(data) diff --git a/py/services/model_update_service.py b/py/services/model_update_service.py index d31ad28f..d75e1cfd 100644 --- a/py/services/model_update_service.py +++ b/py/services/model_update_service.py @@ -17,6 +17,41 @@ from ..utils.preview_selection import select_preview_media logger = logging.getLogger(__name__) +def _normalize_int(value) -> Optional[int]: + """Safely convert a value to an integer.""" + + try: + if value is None: + return None + return int(value) + except (TypeError, ValueError): + return None + + +def _normalize_string(value) -> Optional[str]: + """Return a stripped string or None if the value is empty.""" + + if value is None: + return None + if isinstance(value, str): + stripped = value.strip() + return stripped or None + try: + normalized = str(value).strip() + return normalized or None + except Exception: + return None + + +def _normalize_base_model(value) -> Optional[str]: + """Normalize base-model names for case-insensitive comparison.""" + + normalized = _normalize_string(value) + if normalized is None: + return None + return normalized.lower() + + @dataclass class ModelVersionRecord: """Persisted metadata for a single model version.""" @@ -85,6 +120,47 @@ class ModelUpdateRecord: return True return False + def has_update_for_base( + self, + local_version_id: Optional[int], + local_base_model: Optional[str], + ) -> bool: + """Return True when a newer remote version with the same base model exists.""" + + if self.should_ignore_model: + return False + + normalized_base = _normalize_base_model(local_base_model) + if normalized_base is None: + return False + + threshold = _normalize_int(local_version_id) + if threshold is None: + highest_local = None + for version in self.versions: + if not version.is_in_library: + continue + version_base = _normalize_base_model(version.base_model) + if version_base != normalized_base: + continue + if highest_local is None or version.version_id > highest_local: + highest_local = version.version_id + threshold = highest_local + + if threshold is None: + return False + + for version in self.versions: + if version.is_in_library or version.should_ignore: + continue + version_base = _normalize_base_model(version.base_model) + if version_base != normalized_base: + continue + if version.version_id > threshold: + return True + + return False + class ModelUpdateService: """Persist and query remote model version metadata.""" @@ -628,6 +704,20 @@ class ModelUpdateService: for model_id in normalized_ids } + async def get_records_bulk( + self, + model_type: str, + model_ids: Sequence[int], + ) -> Dict[int, ModelUpdateRecord]: + """Return cached update records for the requested models.""" + + normalized_ids = self._normalize_sequence(model_ids) + if not normalized_ids: + return {} + + async with self._lock: + return self._get_records_bulk(model_type, normalized_ids) + async def _refresh_single_model( self, model_type: str, @@ -799,7 +889,7 @@ class ModelUpdateService: ) continue for key, value in response.items(): - normalized_key = self._normalize_int(key) + normalized_key = _normalize_int(key) if normalized_key is None: continue if isinstance(value, Mapping): @@ -832,8 +922,8 @@ class ModelUpdateService: civitai = item.get("civitai") if isinstance(item, dict) else None if not isinstance(civitai, dict): continue - model_id = self._normalize_int(civitai.get("modelId")) - version_id = self._normalize_int(civitai.get("id")) + model_id = _normalize_int(civitai.get("modelId")) + version_id = _normalize_int(civitai.get("id")) if model_id is None or version_id is None: continue if target_set is not None and model_id not in target_set: @@ -973,35 +1063,14 @@ class ModelUpdateService: return True return (now - record.last_checked_at) >= self._ttl_seconds - @staticmethod - def _normalize_int(value) -> Optional[int]: - try: - if value is None: - return None - return int(value) - except (TypeError, ValueError): - return None - def _normalize_sequence(self, values: Sequence[int]) -> List[int]: normalized = [ item - for item in (self._normalize_int(value) for value in values) + for item in (_normalize_int(value) for value in values) if item is not None ] return sorted(dict.fromkeys(normalized)) - @staticmethod - def _normalize_string(value) -> Optional[str]: - if value is None: - return None - if isinstance(value, str): - stripped = value.strip() - return stripped or None - try: - return str(value) - except Exception: # pragma: no cover - defensive conversion - return None - def _extract_versions(self, response) -> Optional[List[ModelVersionRecord]]: if not isinstance(response, Mapping): return None @@ -1014,12 +1083,12 @@ class ModelUpdateService: for index, entry in enumerate(versions): if not isinstance(entry, Mapping): continue - version_id = self._normalize_int(entry.get("id")) + version_id = _normalize_int(entry.get("id")) if version_id is None: continue - name = self._normalize_string(entry.get("name")) - base_model = self._normalize_string(entry.get("baseModel")) - released_at = self._normalize_string(entry.get("publishedAt") or entry.get("createdAt")) + name = _normalize_string(entry.get("name")) + base_model = _normalize_string(entry.get("baseModel")) + released_at = _normalize_string(entry.get("publishedAt") or entry.get("createdAt")) size_bytes = self._extract_size_bytes(entry.get("files")) preview_url = self._extract_preview_url(entry.get("images")) extracted.append( @@ -1152,11 +1221,11 @@ class ModelUpdateService: name=row["name"], base_model=row["base_model"], released_at=row["released_at"], - size_bytes=self._normalize_int(row["size_bytes"]), + size_bytes=_normalize_int(row["size_bytes"]), preview_url=row["preview_url"], is_in_library=bool(row["is_in_library"]), should_ignore=bool(row["should_ignore"]), - sort_index=self._normalize_int(row["sort_index"]) or 0, + sort_index=_normalize_int(row["sort_index"]) or 0, ) ) diff --git a/py/services/settings_manager.py b/py/services/settings_manager.py index 37476f68..6ff96c4b 100644 --- a/py/services/settings_manager.py +++ b/py/services/settings_manager.py @@ -61,6 +61,7 @@ DEFAULT_SETTINGS: Dict[str, Any] = { "priority_tags": DEFAULT_PRIORITY_TAG_CONFIG.copy(), "model_name_display": "model_name", "model_card_footer_action": "example_images", + "update_flag_strategy": "any", } diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py index 77a2a1b4..8c412aa1 100644 --- a/tests/services/test_base_model_service.py +++ b/tests/services/test_base_model_service.py @@ -11,6 +11,7 @@ from py.services.model_query import ( SearchStrategy, SortParams, ) +from py.services.model_update_service import ModelUpdateRecord, ModelVersionRecord from py.utils.models import BaseModelMetadata @@ -98,6 +99,25 @@ class StubUpdateService: return result +class StubUpdateServiceWithRecords(StubUpdateService): + def __init__(self, records, *, bulk_error: bool = False): + decisions = { + model_id: record.has_update() + for model_id, record in records.items() + } + super().__init__(decisions, bulk_error=bulk_error) + self.records = dict(records) + self.records_bulk_calls = [] + + async def get_records_bulk(self, model_type, model_ids): + self.records_bulk_calls.append((model_type, list(model_ids))) + return { + model_id: self.records[model_id] + for model_id in model_ids + if model_id in self.records + } + + @pytest.mark.asyncio async def test_get_paginated_data_uses_injected_collaborators(): data = [ @@ -461,6 +481,198 @@ async def test_get_paginated_data_annotates_update_flags_with_bulk_dedup(): assert response["total_pages"] == 1 +@pytest.mark.asyncio +async def test_update_flag_strategy_same_base_prefers_matching_base(): + items = [ + { + "model_name": "Pony Version", + "civitai": {"modelId": 1, "id": 10, "baseModel": "Pony"}, + "base_model": "Pony", + }, + { + "model_name": "Flux Version", + "civitai": {"modelId": 1, "id": 20, "baseModel": "Flux 1.D"}, + "base_model": "Flux 1.D", + }, + ] + repository = StubRepository(items) + filter_set = PassThroughFilterSet() + search_strategy = NoSearchStrategy() + record = ModelUpdateRecord( + model_type="stub", + model_id=1, + versions=[ + ModelVersionRecord( + version_id=10, + name="Pony Local", + base_model="Pony", + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=True, + should_ignore=False, + sort_index=0, + ), + ModelVersionRecord( + version_id=20, + name="Flux Local", + base_model="Flux 1.D", + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=True, + should_ignore=False, + sort_index=1, + ), + ModelVersionRecord( + version_id=30, + name="Pony Remote", + base_model="Pony", + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=False, + should_ignore=False, + sort_index=2, + ), + ModelVersionRecord( + version_id=40, + name="SDXL Remote", + base_model="SDXL", + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=False, + should_ignore=False, + sort_index=3, + ), + ], + last_checked_at=None, + should_ignore_model=False, + ) + update_service = StubUpdateServiceWithRecords({1: record}) + settings = StubSettings({"update_flag_strategy": "same_base"}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + update_service=update_service, + ) + + response = await service.get_paginated_data( + page=1, + page_size=10, + sort_by="name:asc", + ) + + assert update_service.records_bulk_calls == [("stub", [1])] + assert update_service.bulk_calls == [] + assert len(response["items"]) == 2 + flags = {item["model_name"]: item["update_available"] for item in response["items"]} + assert flags["Pony Version"] is True + assert flags["Flux Version"] is False + + +@pytest.mark.asyncio +async def test_update_flag_strategy_same_base_honors_latest_local_version(): + items = [ + { + "model_name": "Pony v0.1", + "civitai": {"modelId": 1, "id": 101, "baseModel": "Pony"}, + "base_model": "Pony", + }, + { + "model_name": "Pony v0.3", + "civitai": {"modelId": 1, "id": 103, "baseModel": "Pony"}, + "base_model": "Pony", + }, + ] + repository = StubRepository(items) + filter_set = PassThroughFilterSet() + search_strategy = NoSearchStrategy() + record = ModelUpdateRecord( + model_type="stub", + model_id=1, + versions=[ + ModelVersionRecord( + version_id=101, + name="Old Pony", + base_model="Pony", + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=True, + should_ignore=False, + sort_index=0, + ), + ModelVersionRecord( + version_id=102, + name="Pony Remote", + base_model="Pony", + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=False, + should_ignore=False, + sort_index=1, + ), + ModelVersionRecord( + version_id=103, + name="Middle Pony", + base_model="Pony", + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=True, + should_ignore=False, + sort_index=2, + ), + ModelVersionRecord( + version_id=104, + name="Latest Pony", + base_model="Pony", + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=True, + should_ignore=False, + sort_index=3, + ), + ], + last_checked_at=None, + should_ignore_model=False, + ) + update_service = StubUpdateServiceWithRecords({1: record}) + settings = StubSettings({"update_flag_strategy": "same_base"}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + update_service=update_service, + ) + + response = await service.get_paginated_data( + page=1, + page_size=10, + sort_by="name:asc", + ) + + assert update_service.records_bulk_calls == [("stub", [1])] + flags = {item["model_name"]: item["update_available"] for item in response["items"]} + assert flags["Pony v0.1"] is False + assert flags["Pony v0.3"] is False + + @pytest.mark.asyncio async def test_get_paginated_data_filters_update_available_only(): items = [ diff --git a/tests/services/test_model_update_service.py b/tests/services/test_model_update_service.py index 23c7d003..5e4534e8 100644 --- a/tests/services/test_model_update_service.py +++ b/tests/services/test_model_update_service.py @@ -52,11 +52,11 @@ class NotFoundProvider: return {} -def make_version(version_id, *, in_library, should_ignore=False): +def make_version(version_id, *, in_library, base_model=None, should_ignore=False): return ModelVersionRecord( version_id=version_id, name=None, - base_model=None, + base_model=base_model, released_at=None, size_bytes=None, preview_url=None, @@ -147,6 +147,25 @@ def test_has_update_detects_newer_remote_version(): assert record.has_update() is True +def test_has_update_for_base_matches_same_base_model(): + record = make_record( + make_version(5, in_library=True, base_model="Pony"), + make_version(6, in_library=False, base_model="Pony"), + make_version(7, in_library=False, base_model="Flux.1"), + ) + + assert record.has_update_for_base(5, "Pony") is True + + +def test_has_update_for_base_rejects_other_base_models(): + record = make_record( + make_version(10, in_library=True, base_model="Flux"), + make_version(20, in_library=False, base_model="SDXL"), + ) + + assert record.has_update_for_base(10, "Flux") is False + + @pytest.mark.asyncio async def test_refresh_persists_versions_and_uses_cache(tmp_path): db_path = tmp_path / "updates.sqlite"