fix(model-updates): support per-model version ids

This commit is contained in:
pixelpaws
2025-10-30 23:15:23 +08:00
parent a92883509a
commit 1be3235564
2 changed files with 150 additions and 1 deletions

View File

@@ -98,8 +98,8 @@ class ModelUpdateService:
should_ignore_model INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS model_update_versions (
version_id INTEGER PRIMARY KEY,
model_id INTEGER NOT NULL,
version_id INTEGER NOT NULL,
sort_index INTEGER NOT NULL DEFAULT 0,
name TEXT,
base_model TEXT,
@@ -108,6 +108,7 @@ class ModelUpdateService:
preview_url TEXT,
is_in_library INTEGER NOT NULL DEFAULT 0,
should_ignore INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (model_id, version_id),
FOREIGN KEY(model_id) REFERENCES model_update_status(model_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_model_update_versions_model_id
@@ -197,6 +198,15 @@ class ModelUpdateService:
if column not in version_columns:
conn.execute(statement)
# Refresh column metadata after applying additive migrations.
version_columns = self._get_table_columns(conn, "model_update_versions")
if self._requires_model_update_versions_pk_migration(conn):
self._migrate_model_update_versions_primary_key(
conn, version_columns
)
version_columns = self._get_table_columns(conn, "model_update_versions")
if not self._has_unique_constraint(conn, "model_update_status", "model_id"):
self._deduplicate_model_update_status(conn)
conn.execute(
@@ -204,6 +214,12 @@ class ModelUpdateService:
"uq_model_update_status_model_id ON model_update_status(model_id)"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_model_update_versions_model_id "
"ON model_update_versions(model_id)"
)
def _get_table_columns(self, conn: sqlite3.Connection, table: str) -> set[str]:
"""Return the set of existing columns for a table."""
@@ -236,6 +252,100 @@ class ModelUpdateService:
return True
return False
def _requires_model_update_versions_pk_migration(
self, conn: sqlite3.Connection
) -> bool:
"""Detect legacy schemas where version_id is the sole primary key."""
info = conn.execute("PRAGMA table_info(model_update_versions)").fetchall()
pk_columns = [row for row in info if row["pk"]]
if not pk_columns:
return True
if len(pk_columns) == 1:
return pk_columns[0]["name"] == "version_id"
ordered = sorted(pk_columns, key=lambda row: row["pk"])
expected = ["model_id", "version_id"]
return [row["name"] for row in ordered] != expected
def _migrate_model_update_versions_primary_key(
self, conn: sqlite3.Connection, legacy_columns: set[str]
) -> None:
"""Upgrade the versions table to use a composite primary key."""
logger.info("Migrating model_update_versions table to composite primary key")
conn.execute(
"ALTER TABLE model_update_versions RENAME TO model_update_versions_legacy"
)
conn.execute(
"""
CREATE TABLE model_update_versions_new (
model_id INTEGER NOT NULL,
version_id INTEGER NOT NULL,
sort_index INTEGER NOT NULL DEFAULT 0,
name TEXT,
base_model TEXT,
released_at TEXT,
size_bytes INTEGER,
preview_url TEXT,
is_in_library INTEGER NOT NULL DEFAULT 0,
should_ignore INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (model_id, version_id),
FOREIGN KEY(model_id) REFERENCES model_update_status(model_id) ON DELETE CASCADE
)
"""
)
target_columns = [
"model_id",
"version_id",
"sort_index",
"name",
"base_model",
"released_at",
"size_bytes",
"preview_url",
"is_in_library",
"should_ignore",
]
defaults = {
"sort_index": "0",
"name": "NULL",
"base_model": "NULL",
"released_at": "NULL",
"size_bytes": "NULL",
"preview_url": "NULL",
"is_in_library": "0",
"should_ignore": "0",
}
select_parts = []
for column in target_columns:
if column in legacy_columns:
if column in {"sort_index", "is_in_library", "should_ignore"}:
select_parts.append(f"COALESCE({column}, {defaults[column]})")
else:
select_parts.append(column)
else:
select_parts.append(defaults.get(column, "NULL"))
conn.execute(
"""
INSERT INTO model_update_versions_new ({columns})
SELECT {select_clause}
FROM model_update_versions_legacy
""".format(
columns=", ".join(target_columns),
select_clause=", ".join(select_parts),
)
)
conn.execute("DROP TABLE model_update_versions_legacy")
conn.execute(
"ALTER TABLE model_update_versions_new RENAME TO model_update_versions"
)
def _deduplicate_model_update_status(self, conn: sqlite3.Connection) -> None:
"""Remove duplicate status rows before applying uniqueness constraints."""

View File

@@ -1,4 +1,5 @@
import logging
import sqlite3
from types import SimpleNamespace
import pytest
@@ -415,6 +416,44 @@ async def test_has_updates_bulk_returns_mapping(tmp_path):
assert await service.has_update("lora", 9) is True
@pytest.mark.asyncio
async def test_refresh_allows_duplicate_version_ids_across_models(tmp_path):
db_path = tmp_path / "updates.sqlite"
service = ModelUpdateService(str(db_path), ttl_seconds=0)
raw_data = [
{"civitai": {"modelId": 1, "id": 42}},
{"civitai": {"modelId": 2, "id": 42}},
]
scanner = DummyScanner(raw_data)
provider = DummyProvider(
{
"modelVersions": [
{
"id": 42,
"name": "shared",
"baseModel": "SD15",
"publishedAt": "2024-03-01T00:00:00Z",
"files": [{"sizeKB": 256}],
"images": [],
}
]
}
)
results = await service.refresh_for_model_type("lora", scanner, provider)
assert set(results.keys()) == {1, 2}
assert results[1].version_ids == [42]
assert results[2].version_ids == [42]
with sqlite3.connect(str(db_path)) as conn:
count = conn.execute(
"SELECT COUNT(*) FROM model_update_versions WHERE version_id = 42"
).fetchone()[0]
assert count == 2
@pytest.mark.asyncio
async def test_refresh_rewrites_remote_preview_urls(tmp_path):
db_path = tmp_path / "updates.sqlite"