From d9a6db33596eb69912c740e5d56cd1b37ff8dced Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sat, 25 Oct 2025 15:31:33 +0800 Subject: [PATCH] feat: optimize model update checking with bulk operations - Refactor update filter logic to use bulk update checks when available - Add annotation method to attach update flags to response items - Improve performance by reducing API calls for update status checks - Maintain backward compatibility with fallback to individual checks - Handle edge cases and logging for failed update status resolutions --- py/services/base_model_service.py | 190 +++++++++++++++----- py/services/model_update_service.py | 99 +++++++--- tests/services/test_base_model_service.py | 78 +++++++- tests/services/test_model_update_service.py | 22 +++ 4 files changed, 314 insertions(+), 75 deletions(-) diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index e17a4dab..b7f4f40b 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -72,31 +72,35 @@ class BaseModelService(ABC): if hash_filters: filtered_data = await self._apply_hash_filters(sorted_data, hash_filters) - return self._paginate(filtered_data, page, page_size) - - filtered_data = await self._apply_common_filters( - sorted_data, - folder=folder, - base_models=base_models, - tags=tags, - favorites_only=favorites_only, - search_options=search_options, - ) - - if search: - filtered_data = await self._apply_search_filters( - filtered_data, - search, - fuzzy_search, - search_options, + else: + filtered_data = await self._apply_common_filters( + sorted_data, + folder=folder, + base_models=base_models, + tags=tags, + favorites_only=favorites_only, + search_options=search_options, ) - filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) + if search: + filtered_data = await self._apply_search_filters( + filtered_data, + search, + fuzzy_search, + search_options, + ) - if has_update: - filtered_data = await self._apply_update_filter(filtered_data) + filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) - return self._paginate(filtered_data, page, page_size) + if has_update: + filtered_data = await self._apply_update_filter(filtered_data) + + paginated = self._paginate(filtered_data, page, page_size) + paginated['items'] = await self._annotate_update_flags( + paginated['items'], + assume_true=has_update, + ) + return paginated async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]: @@ -167,35 +171,141 @@ class BaseModelService(ABC): ) return [] - candidates: List[tuple[Dict, int]] = [] + id_to_items: Dict[int, List[Dict]] = {} + ordered_ids: List[int] = [] for item in data: model_id = self._extract_model_id(item) - if model_id is not None: - candidates.append((item, model_id)) + if model_id is None: + continue + if model_id not in id_to_items: + id_to_items[model_id] = [] + ordered_ids.append(model_id) + id_to_items[model_id].append(item) - if not candidates: + if not ordered_ids: return [] - tasks = [ - self.update_service.has_update(self.model_type, model_id) - for _, model_id in candidates - ] - results = await asyncio.gather(*tasks, return_exceptions=True) + resolved: Optional[Dict[int, bool]] = None + bulk_method = getattr(self.update_service, "has_updates_bulk", None) + if callable(bulk_method): + try: + resolved = await bulk_method(self.model_type, ordered_ids) + except Exception as exc: + logger.error( + "Failed to resolve update status in bulk for %s models (%s): %s", + self.model_type, + ordered_ids, + exc, + exc_info=True, + ) + resolved = None + + if resolved is None: + tasks = [ + self.update_service.has_update(self.model_type, model_id) + for model_id in ordered_ids + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + resolved = {} + for model_id, result in zip(ordered_ids, results): + if isinstance(result, Exception): + logger.error( + "Failed to resolve update status for model %s (%s): %s", + model_id, + self.model_type, + result, + ) + continue + resolved[model_id] = bool(result) filtered: List[Dict] = [] - for (item, model_id), result in zip(candidates, results): - if isinstance(result, Exception): - logger.error( - "Failed to resolve update status for model %s (%s): %s", - model_id, - self.model_type, - result, - ) - continue - if result: + for item in data: + model_id = self._extract_model_id(item) + if model_id is not None and resolved.get(model_id, False): filtered.append(item) return filtered + async def _annotate_update_flags( + self, + items: List[Dict], + *, + assume_true: bool = False, + ) -> List[Dict]: + """Attach an update_available flag to each response item. + + Items without a civitai model id default to False. When the caller already + filtered for updates we can skip the lookup and mark everything True. + """ + if not items: + return [] + + annotated = [dict(item) for item in items] + + if assume_true: + for item in annotated: + item['update_available'] = True + return annotated + + if self.update_service is None: + for item in annotated: + item['update_available'] = False + return annotated + + id_to_items: Dict[int, List[Dict]] = {} + ordered_ids: List[int] = [] + for item in annotated: + model_id = self._extract_model_id(item) + if model_id is None: + item['update_available'] = False + continue + if model_id not in id_to_items: + id_to_items[model_id] = [] + ordered_ids.append(model_id) + id_to_items[model_id].append(item) + + if not ordered_ids: + return annotated + + resolved: Optional[Dict[int, bool]] = None + bulk_method = getattr(self.update_service, "has_updates_bulk", None) + if callable(bulk_method): + try: + resolved = await bulk_method(self.model_type, ordered_ids) + except Exception as exc: + logger.error( + "Failed to resolve update status in bulk for %s models (%s): %s", + self.model_type, + ordered_ids, + exc, + exc_info=True, + ) + resolved = None + + if resolved is None: + tasks = [ + self.update_service.has_update(self.model_type, model_id) + for model_id in ordered_ids + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + resolved = {} + for model_id, result in zip(ordered_ids, results): + if isinstance(result, Exception): + logger.error( + "Failed to resolve update status for model %s (%s): %s", + model_id, + self.model_type, + result, + ) + continue + resolved[model_id] = bool(result) + + for model_id, items_for_id in id_to_items.items(): + flag = bool(resolved.get(model_id, False)) + for item in items_for_id: + item['update_available'] = flag + + return annotated + @staticmethod def _extract_model_id(item: Dict) -> Optional[int]: civitai = item.get('civitai') if isinstance(item, dict) else None diff --git a/py/services/model_update_service.py b/py/services/model_update_service.py index 6392ff32..ad5494b7 100644 --- a/py/services/model_update_service.py +++ b/py/services/model_update_service.py @@ -307,6 +307,25 @@ class ModelUpdateService: record = await self.get_record(model_type, model_id) return record.has_update() if record else False + async def has_updates_bulk( + self, + model_type: str, + model_ids: Sequence[int], + ) -> Dict[int, bool]: + """Return update availability for each model id in a single database pass.""" + + normalized_ids = self._normalize_sequence(model_ids) + if not normalized_ids: + return {} + + async with self._lock: + records = self._get_records_bulk(model_type, normalized_ids) + + return { + model_id: records.get(model_id).has_update() if records.get(model_id) else False + for model_id in normalized_ids + } + async def _refresh_single_model( self, model_type: str, @@ -680,36 +699,47 @@ class ModelUpdateService: return rewritten or url def _get_record(self, model_type: str, model_id: int) -> Optional[ModelUpdateRecord]: + records = self._get_records_bulk(model_type, [model_id]) + return records.get(model_id) + + def _get_records_bulk( + self, + model_type: str, + model_ids: Sequence[int], + ) -> Dict[int, ModelUpdateRecord]: + if not model_ids: + return {} + + params = tuple(model_ids) + placeholders = ",".join("?" for _ in params) + with self._connect() as conn: - status_row = conn.execute( - """ + status_rows = conn.execute( + f""" SELECT model_id, model_type, last_checked_at, should_ignore_model FROM model_update_status - WHERE model_id = ? + WHERE model_id IN ({placeholders}) """, - (model_id,), - ).fetchone() - if not status_row: - return None - stored_type = status_row["model_type"] - if stored_type and stored_type != model_type: - logger.debug( - "Model id %s requested as %s but stored as %s", model_id, model_type, stored_type - ) + params, + ).fetchall() + if not status_rows: + return {} + version_rows = conn.execute( - """ - SELECT version_id, sort_index, name, base_model, released_at, size_bytes, - preview_url, is_in_library, should_ignore + f""" + SELECT model_id, version_id, sort_index, name, base_model, released_at, + size_bytes, preview_url, is_in_library, should_ignore FROM model_update_versions - WHERE model_id = ? - ORDER BY sort_index ASC, version_id ASC + WHERE model_id IN ({placeholders}) + ORDER BY model_id ASC, sort_index ASC, version_id ASC """, - (model_id,), + params, ).fetchall() - versions: List[ModelVersionRecord] = [] + versions_by_model: Dict[int, List[ModelVersionRecord]] = {} for row in version_rows: - versions.append( + model_id = int(row["model_id"]) + versions_by_model.setdefault(model_id, []).append( ModelVersionRecord( version_id=int(row["version_id"]), name=row["name"], @@ -723,13 +753,28 @@ class ModelUpdateService: ) ) - return ModelUpdateRecord( - model_type=stored_type or model_type, - model_id=int(status_row["model_id"]), - versions=self._sorted_versions(versions), - last_checked_at=status_row["last_checked_at"], - should_ignore_model=bool(status_row["should_ignore_model"]), - ) + records: Dict[int, ModelUpdateRecord] = {} + for status in status_rows: + model_id = int(status["model_id"]) + stored_type = status["model_type"] + if stored_type and stored_type != model_type: + logger.debug( + "Model id %s requested as %s but stored as %s", + model_id, + model_type, + stored_type, + ) + + record = ModelUpdateRecord( + model_type=stored_type or model_type, + model_id=model_id, + versions=self._sorted_versions(versions_by_model.get(model_id, [])), + last_checked_at=status["last_checked_at"], + should_ignore_model=bool(status["should_ignore_model"]), + ) + records[model_id] = record + + return records def _upsert_record(self, record: ModelUpdateRecord) -> None: payload = ( diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py index 23b0b1d7..0595963a 100644 --- a/tests/services/test_base_model_service.py +++ b/tests/services/test_base_model_service.py @@ -68,9 +68,23 @@ class StubSearchStrategy: class StubUpdateService: - def __init__(self, decisions): + def __init__(self, decisions, *, bulk_error: bool = False): self.decisions = dict(decisions) self.calls = [] + self.bulk_calls = [] + self.bulk_error = bulk_error + + async def has_updates_bulk(self, model_type, model_ids): + self.bulk_calls.append((model_type, list(model_ids))) + if self.bulk_error: + raise RuntimeError("bulk failure") + results = {} + for model_id in model_ids: + result = self.decisions.get(model_id, False) + if isinstance(result, Exception): + raise result + results[model_id] = result + return results async def has_update(self, model_type, model_id): self.calls.append((model_type, model_id)) @@ -131,7 +145,11 @@ async def test_get_paginated_data_uses_injected_collaborators(): assert search_strategy.normalize_calls == [{"recursive": False}, {"recursive": False}] assert search_strategy.apply_calls == [([{"model_name": "Filtered"}], "query", {"recursive": False}, True)] - assert response["items"] == search_strategy.search_result + assert [item["model_name"] for item in response["items"]] == [ + entry["model_name"] for entry in search_strategy.search_result + ] + assert all("update_available" in item for item in response["items"]) + assert all(item["update_available"] is False for item in response["items"]) assert response["total"] == len(search_strategy.search_result) assert response["page"] == 1 assert response["page_size"] == 5 @@ -218,7 +236,9 @@ async def test_get_paginated_data_filters_and_searches_combination(): favorites_only=True, ) - assert response["items"] == [items[2]] + assert len(response["items"]) == 1 + assert response["items"][0]["model_name"] == items[2]["model_name"] + assert response["items"][0]["update_available"] is False assert response["total"] == 1 assert response["page"] == 1 assert response["page_size"] == 1 @@ -280,7 +300,10 @@ async def test_get_paginated_data_paginates_without_search(): assert len(repository.fetch_sorted_calls) == 1 assert filter_set.calls and filter_set.calls[0].favorites_only is False assert search_strategy.apply_called is False - assert response["items"] == items[2:4] + assert [item["model_name"] for item in response["items"]] == [ + entry["model_name"] for entry in items[2:4] + ] + assert all(item["update_available"] is False for item in response["items"]) assert response["total"] == len(items) assert response["page"] == 2 assert response["page_size"] == 2 @@ -318,8 +341,10 @@ async def test_get_paginated_data_filters_by_update_status(): has_update=True, ) - assert update_service.calls == [("stub", 1), ("stub", 2), ("stub", 3)] - assert response["items"] == [items[0], items[2]] + assert update_service.bulk_calls == [("stub", [1, 2, 3])] + assert update_service.calls == [] + assert [item["model_name"] for item in response["items"]] == ["A", "C"] + assert all(item["update_available"] is True for item in response["items"]) assert response["total"] == 2 assert response["page"] == 1 assert response["page_size"] == 5 @@ -389,7 +414,44 @@ async def test_get_paginated_data_skips_items_when_update_check_fails(): has_update=True, ) + assert update_service.bulk_calls == [("stub", [1, 2])] assert update_service.calls == [("stub", 1), ("stub", 2)] - assert response["items"] == [items[0]] - assert response["total"] == 1 + assert [item["model_name"] for item in response["items"]] == ["A"] + assert response["items"][0]["update_available"] is True + + +@pytest.mark.asyncio +async def test_get_paginated_data_annotates_update_flags_with_bulk_dedup(): + items = [ + {"model_name": "Alpha", "civitai": {"modelId": 7}}, + {"model_name": "Beta", "civitai": {"modelId": 7}}, + {"model_name": "Gamma", "civitai": {"modelId": 8}}, + ] + repository = StubRepository(items) + filter_set = PassThroughFilterSet() + search_strategy = NoSearchStrategy() + update_service = StubUpdateService({7: True, 8: False}) + settings = StubSettings({}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + update_service=update_service, + ) + + response = await service.get_paginated_data( + page=1, + page_size=10, + sort_by="name:asc", + ) + + assert update_service.bulk_calls == [("stub", [7, 8])] + assert update_service.calls == [] + assert [item["update_available"] for item in response["items"]] == [True, True, False] + assert response["total"] == 3 assert response["total_pages"] == 1 diff --git a/tests/services/test_model_update_service.py b/tests/services/test_model_update_service.py index 09e869c8..37eb1dff 100644 --- a/tests/services/test_model_update_service.py +++ b/tests/services/test_model_update_service.py @@ -200,6 +200,28 @@ async def test_version_ignore_blocks_update_flag(tmp_path): assert record.has_update() is False +@pytest.mark.asyncio +async def test_has_updates_bulk_returns_mapping(tmp_path): + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path), ttl_seconds=3600) + raw_data = [{"civitai": {"modelId": 9, "id": 91}}] + scanner = DummyScanner(raw_data) + provider = DummyProvider( + { + "modelVersions": [ + {"id": 91, "files": [], "images": []}, + {"id": 92, "files": [], "images": []}, + ] + } + ) + + await service.refresh_for_model_type("lora", scanner, provider) + mapping = await service.has_updates_bulk("lora", [9, 9, 42]) + + assert mapping == {9: True, 42: False} + assert await service.has_update("lora", 9) is True + + @pytest.mark.asyncio async def test_refresh_rewrites_remote_preview_urls(tmp_path): db_path = tmp_path / "updates.sqlite"