diff --git a/py/services/model_update_service.py b/py/services/model_update_service.py index a69c4286..0a15a7bd 100644 --- a/py/services/model_update_service.py +++ b/py/services/model_update_service.py @@ -209,6 +209,8 @@ class ModelUpdateRecord: class ModelUpdateService: """Persist and query remote model version metadata.""" + _SQLITE_MAX_VARIABLES = 500 + _SCHEMA = """ PRAGMA foreign_keys = ON; CREATE TABLE IF NOT EXISTS model_update_status ( @@ -1439,33 +1441,41 @@ class ModelUpdateService: if not model_ids: return {} - params = tuple(model_ids) - placeholders = ",".join("?" for _ in params) + ids = list(model_ids) + status_rows: list = [] + version_rows: list = [] with self._connect() as conn: - status_rows = conn.execute( - f""" - SELECT model_id, model_type, last_checked_at, should_ignore_model - FROM model_update_status - WHERE model_id IN ({placeholders}) - """, - params, - ).fetchall() + for start in range(0, len(ids), self._SQLITE_MAX_VARIABLES): + chunk = tuple(ids[start : start + self._SQLITE_MAX_VARIABLES]) + placeholders = ",".join("?" for _ in chunk) + + chunk_status = conn.execute( + f""" + SELECT model_id, model_type, last_checked_at, should_ignore_model + FROM model_update_status + WHERE model_id IN ({placeholders}) + """, + chunk, + ).fetchall() + status_rows.extend(chunk_status) + + chunk_versions = conn.execute( + f""" + SELECT model_id, version_id, sort_index, name, base_model, released_at, + size_bytes, preview_url, is_in_library, should_ignore, early_access_ends_at, + is_early_access + FROM model_update_versions + WHERE model_id IN ({placeholders}) + ORDER BY model_id ASC, sort_index ASC, version_id ASC + """, + chunk, + ).fetchall() + version_rows.extend(chunk_versions) + if not status_rows: return {} - version_rows = conn.execute( - f""" - SELECT model_id, version_id, sort_index, name, base_model, released_at, - size_bytes, preview_url, is_in_library, should_ignore, early_access_ends_at, - is_early_access - FROM model_update_versions - WHERE model_id IN ({placeholders}) - ORDER BY model_id ASC, sort_index ASC, version_id ASC - """, - params, - ).fetchall() - versions_by_model: Dict[int, List[ModelVersionRecord]] = {} for row in version_rows: model_id = int(row["model_id"]) diff --git a/tests/services/test_model_update_service.py b/tests/services/test_model_update_service.py index 6538ac17..143771a7 100644 --- a/tests/services/test_model_update_service.py +++ b/tests/services/test_model_update_service.py @@ -442,6 +442,42 @@ async def test_has_updates_bulk_returns_mapping(tmp_path): assert await service.has_update("lora", 9) is True +@pytest.mark.asyncio +async def test_has_updates_bulk_handles_more_than_sqlite_max_variables(tmp_path): + """Bulk query with >999 model IDs must not raise 'too many SQL variables'.""" + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path), ttl_seconds=3600) + + model_ids = list(range(1, 1201)) + with sqlite3.connect(str(db_path)) as conn: + conn.execute("INSERT INTO model_update_status (model_id, model_type) VALUES (?, ?)", (1, "lora")) + conn.execute("INSERT INTO model_update_versions (model_id, version_id, sort_index, name) VALUES (?, ?, ?, ?)", (1, 10, 0, "v1")) + + mapping = await service.has_updates_bulk("lora", model_ids) + + assert mapping[1] is True + assert len(mapping) == len(model_ids) + assert all(v is False for k, v in mapping.items() if k != 1) + + +@pytest.mark.asyncio +async def test_get_records_bulk_handles_more_than_sqlite_max_variables(tmp_path): + """Bulk record fetch with >999 model IDs must not raise 'too many SQL variables'.""" + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path), ttl_seconds=3600) + + model_ids = list(range(1, 1201)) + with sqlite3.connect(str(db_path)) as conn: + conn.execute("INSERT INTO model_update_status (model_id, model_type) VALUES (?, ?)", (1, "lora")) + conn.execute("INSERT INTO model_update_versions (model_id, version_id, sort_index, name) VALUES (?, ?, ?, ?)", (1, 10, 0, "v1")) + + records = await service.get_records_bulk("lora", model_ids) + + assert 1 in records + assert records[1].model_id == 1 + assert len(records) == 1 + + @pytest.mark.asyncio async def test_refresh_allows_duplicate_version_ids_across_models(tmp_path): db_path = tmp_path / "updates.sqlite"