mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
Merge pull request #626 from willmiao/codex/update-model_update_versions-primary-key
fix: support per-model version ids in update service
This commit is contained in:
@@ -98,8 +98,8 @@ class ModelUpdateService:
|
|||||||
should_ignore_model INTEGER NOT NULL DEFAULT 0
|
should_ignore_model INTEGER NOT NULL DEFAULT 0
|
||||||
);
|
);
|
||||||
CREATE TABLE IF NOT EXISTS model_update_versions (
|
CREATE TABLE IF NOT EXISTS model_update_versions (
|
||||||
version_id INTEGER PRIMARY KEY,
|
|
||||||
model_id INTEGER NOT NULL,
|
model_id INTEGER NOT NULL,
|
||||||
|
version_id INTEGER NOT NULL,
|
||||||
sort_index INTEGER NOT NULL DEFAULT 0,
|
sort_index INTEGER NOT NULL DEFAULT 0,
|
||||||
name TEXT,
|
name TEXT,
|
||||||
base_model TEXT,
|
base_model TEXT,
|
||||||
@@ -108,6 +108,7 @@ class ModelUpdateService:
|
|||||||
preview_url TEXT,
|
preview_url TEXT,
|
||||||
is_in_library INTEGER NOT NULL DEFAULT 0,
|
is_in_library INTEGER NOT NULL DEFAULT 0,
|
||||||
should_ignore INTEGER NOT NULL DEFAULT 0,
|
should_ignore INTEGER NOT NULL DEFAULT 0,
|
||||||
|
PRIMARY KEY (model_id, version_id),
|
||||||
FOREIGN KEY(model_id) REFERENCES model_update_status(model_id) ON DELETE CASCADE
|
FOREIGN KEY(model_id) REFERENCES model_update_status(model_id) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
CREATE INDEX IF NOT EXISTS idx_model_update_versions_model_id
|
CREATE INDEX IF NOT EXISTS idx_model_update_versions_model_id
|
||||||
@@ -197,6 +198,15 @@ class ModelUpdateService:
|
|||||||
if column not in version_columns:
|
if column not in version_columns:
|
||||||
conn.execute(statement)
|
conn.execute(statement)
|
||||||
|
|
||||||
|
# Refresh column metadata after applying additive migrations.
|
||||||
|
version_columns = self._get_table_columns(conn, "model_update_versions")
|
||||||
|
|
||||||
|
if self._requires_model_update_versions_pk_migration(conn):
|
||||||
|
self._migrate_model_update_versions_primary_key(
|
||||||
|
conn, version_columns
|
||||||
|
)
|
||||||
|
version_columns = self._get_table_columns(conn, "model_update_versions")
|
||||||
|
|
||||||
if not self._has_unique_constraint(conn, "model_update_status", "model_id"):
|
if not self._has_unique_constraint(conn, "model_update_status", "model_id"):
|
||||||
self._deduplicate_model_update_status(conn)
|
self._deduplicate_model_update_status(conn)
|
||||||
conn.execute(
|
conn.execute(
|
||||||
@@ -204,6 +214,12 @@ class ModelUpdateService:
|
|||||||
"uq_model_update_status_model_id ON model_update_status(model_id)"
|
"uq_model_update_status_model_id ON model_update_status(model_id)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
"CREATE INDEX IF NOT EXISTS idx_model_update_versions_model_id "
|
||||||
|
"ON model_update_versions(model_id)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_table_columns(self, conn: sqlite3.Connection, table: str) -> set[str]:
|
def _get_table_columns(self, conn: sqlite3.Connection, table: str) -> set[str]:
|
||||||
"""Return the set of existing columns for a table."""
|
"""Return the set of existing columns for a table."""
|
||||||
|
|
||||||
@@ -236,6 +252,100 @@ class ModelUpdateService:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _requires_model_update_versions_pk_migration(
|
||||||
|
self, conn: sqlite3.Connection
|
||||||
|
) -> bool:
|
||||||
|
"""Detect legacy schemas where version_id is the sole primary key."""
|
||||||
|
|
||||||
|
info = conn.execute("PRAGMA table_info(model_update_versions)").fetchall()
|
||||||
|
pk_columns = [row for row in info if row["pk"]]
|
||||||
|
if not pk_columns:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if len(pk_columns) == 1:
|
||||||
|
return pk_columns[0]["name"] == "version_id"
|
||||||
|
|
||||||
|
ordered = sorted(pk_columns, key=lambda row: row["pk"])
|
||||||
|
expected = ["model_id", "version_id"]
|
||||||
|
return [row["name"] for row in ordered] != expected
|
||||||
|
|
||||||
|
def _migrate_model_update_versions_primary_key(
|
||||||
|
self, conn: sqlite3.Connection, legacy_columns: set[str]
|
||||||
|
) -> None:
|
||||||
|
"""Upgrade the versions table to use a composite primary key."""
|
||||||
|
|
||||||
|
logger.info("Migrating model_update_versions table to composite primary key")
|
||||||
|
conn.execute(
|
||||||
|
"ALTER TABLE model_update_versions RENAME TO model_update_versions_legacy"
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE model_update_versions_new (
|
||||||
|
model_id INTEGER NOT NULL,
|
||||||
|
version_id INTEGER NOT NULL,
|
||||||
|
sort_index INTEGER NOT NULL DEFAULT 0,
|
||||||
|
name TEXT,
|
||||||
|
base_model TEXT,
|
||||||
|
released_at TEXT,
|
||||||
|
size_bytes INTEGER,
|
||||||
|
preview_url TEXT,
|
||||||
|
is_in_library INTEGER NOT NULL DEFAULT 0,
|
||||||
|
should_ignore INTEGER NOT NULL DEFAULT 0,
|
||||||
|
PRIMARY KEY (model_id, version_id),
|
||||||
|
FOREIGN KEY(model_id) REFERENCES model_update_status(model_id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
target_columns = [
|
||||||
|
"model_id",
|
||||||
|
"version_id",
|
||||||
|
"sort_index",
|
||||||
|
"name",
|
||||||
|
"base_model",
|
||||||
|
"released_at",
|
||||||
|
"size_bytes",
|
||||||
|
"preview_url",
|
||||||
|
"is_in_library",
|
||||||
|
"should_ignore",
|
||||||
|
]
|
||||||
|
defaults = {
|
||||||
|
"sort_index": "0",
|
||||||
|
"name": "NULL",
|
||||||
|
"base_model": "NULL",
|
||||||
|
"released_at": "NULL",
|
||||||
|
"size_bytes": "NULL",
|
||||||
|
"preview_url": "NULL",
|
||||||
|
"is_in_library": "0",
|
||||||
|
"should_ignore": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
select_parts = []
|
||||||
|
for column in target_columns:
|
||||||
|
if column in legacy_columns:
|
||||||
|
if column in {"sort_index", "is_in_library", "should_ignore"}:
|
||||||
|
select_parts.append(f"COALESCE({column}, {defaults[column]})")
|
||||||
|
else:
|
||||||
|
select_parts.append(column)
|
||||||
|
else:
|
||||||
|
select_parts.append(defaults.get(column, "NULL"))
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO model_update_versions_new ({columns})
|
||||||
|
SELECT {select_clause}
|
||||||
|
FROM model_update_versions_legacy
|
||||||
|
""".format(
|
||||||
|
columns=", ".join(target_columns),
|
||||||
|
select_clause=", ".join(select_parts),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.execute("DROP TABLE model_update_versions_legacy")
|
||||||
|
conn.execute(
|
||||||
|
"ALTER TABLE model_update_versions_new RENAME TO model_update_versions"
|
||||||
|
)
|
||||||
|
|
||||||
def _deduplicate_model_update_status(self, conn: sqlite3.Connection) -> None:
|
def _deduplicate_model_update_status(self, conn: sqlite3.Connection) -> None:
|
||||||
"""Remove duplicate status rows before applying uniqueness constraints."""
|
"""Remove duplicate status rows before applying uniqueness constraints."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import sqlite3
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -415,6 +416,44 @@ async def test_has_updates_bulk_returns_mapping(tmp_path):
|
|||||||
assert await service.has_update("lora", 9) is True
|
assert await service.has_update("lora", 9) is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_allows_duplicate_version_ids_across_models(tmp_path):
|
||||||
|
db_path = tmp_path / "updates.sqlite"
|
||||||
|
service = ModelUpdateService(str(db_path), ttl_seconds=0)
|
||||||
|
raw_data = [
|
||||||
|
{"civitai": {"modelId": 1, "id": 42}},
|
||||||
|
{"civitai": {"modelId": 2, "id": 42}},
|
||||||
|
]
|
||||||
|
scanner = DummyScanner(raw_data)
|
||||||
|
provider = DummyProvider(
|
||||||
|
{
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 42,
|
||||||
|
"name": "shared",
|
||||||
|
"baseModel": "SD15",
|
||||||
|
"publishedAt": "2024-03-01T00:00:00Z",
|
||||||
|
"files": [{"sizeKB": 256}],
|
||||||
|
"images": [],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await service.refresh_for_model_type("lora", scanner, provider)
|
||||||
|
|
||||||
|
assert set(results.keys()) == {1, 2}
|
||||||
|
assert results[1].version_ids == [42]
|
||||||
|
assert results[2].version_ids == [42]
|
||||||
|
|
||||||
|
with sqlite3.connect(str(db_path)) as conn:
|
||||||
|
count = conn.execute(
|
||||||
|
"SELECT COUNT(*) FROM model_update_versions WHERE version_id = 42"
|
||||||
|
).fetchone()[0]
|
||||||
|
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_refresh_rewrites_remote_preview_urls(tmp_path):
|
async def test_refresh_rewrites_remote_preview_urls(tmp_path):
|
||||||
db_path = tmp_path / "updates.sqlite"
|
db_path = tmp_path / "updates.sqlite"
|
||||||
|
|||||||
Reference in New Issue
Block a user