mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
fix(model-updates): support per-model version ids
This commit is contained in:
@@ -98,8 +98,8 @@ class ModelUpdateService:
|
||||
should_ignore_model INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS model_update_versions (
|
||||
version_id INTEGER PRIMARY KEY,
|
||||
model_id INTEGER NOT NULL,
|
||||
version_id INTEGER NOT NULL,
|
||||
sort_index INTEGER NOT NULL DEFAULT 0,
|
||||
name TEXT,
|
||||
base_model TEXT,
|
||||
@@ -108,6 +108,7 @@ class ModelUpdateService:
|
||||
preview_url TEXT,
|
||||
is_in_library INTEGER NOT NULL DEFAULT 0,
|
||||
should_ignore INTEGER NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (model_id, version_id),
|
||||
FOREIGN KEY(model_id) REFERENCES model_update_status(model_id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_update_versions_model_id
|
||||
@@ -197,6 +198,15 @@ class ModelUpdateService:
|
||||
if column not in version_columns:
|
||||
conn.execute(statement)
|
||||
|
||||
# Refresh column metadata after applying additive migrations.
|
||||
version_columns = self._get_table_columns(conn, "model_update_versions")
|
||||
|
||||
if self._requires_model_update_versions_pk_migration(conn):
|
||||
self._migrate_model_update_versions_primary_key(
|
||||
conn, version_columns
|
||||
)
|
||||
version_columns = self._get_table_columns(conn, "model_update_versions")
|
||||
|
||||
if not self._has_unique_constraint(conn, "model_update_status", "model_id"):
|
||||
self._deduplicate_model_update_status(conn)
|
||||
conn.execute(
|
||||
@@ -204,6 +214,12 @@ class ModelUpdateService:
|
||||
"uq_model_update_status_model_id ON model_update_status(model_id)"
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_model_update_versions_model_id "
|
||||
"ON model_update_versions(model_id)"
|
||||
)
|
||||
|
||||
|
||||
def _get_table_columns(self, conn: sqlite3.Connection, table: str) -> set[str]:
|
||||
"""Return the set of existing columns for a table."""
|
||||
|
||||
@@ -236,6 +252,100 @@ class ModelUpdateService:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _requires_model_update_versions_pk_migration(
|
||||
self, conn: sqlite3.Connection
|
||||
) -> bool:
|
||||
"""Detect legacy schemas where version_id is the sole primary key."""
|
||||
|
||||
info = conn.execute("PRAGMA table_info(model_update_versions)").fetchall()
|
||||
pk_columns = [row for row in info if row["pk"]]
|
||||
if not pk_columns:
|
||||
return True
|
||||
|
||||
if len(pk_columns) == 1:
|
||||
return pk_columns[0]["name"] == "version_id"
|
||||
|
||||
ordered = sorted(pk_columns, key=lambda row: row["pk"])
|
||||
expected = ["model_id", "version_id"]
|
||||
return [row["name"] for row in ordered] != expected
|
||||
|
||||
def _migrate_model_update_versions_primary_key(
|
||||
self, conn: sqlite3.Connection, legacy_columns: set[str]
|
||||
) -> None:
|
||||
"""Upgrade the versions table to use a composite primary key."""
|
||||
|
||||
logger.info("Migrating model_update_versions table to composite primary key")
|
||||
conn.execute(
|
||||
"ALTER TABLE model_update_versions RENAME TO model_update_versions_legacy"
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE model_update_versions_new (
|
||||
model_id INTEGER NOT NULL,
|
||||
version_id INTEGER NOT NULL,
|
||||
sort_index INTEGER NOT NULL DEFAULT 0,
|
||||
name TEXT,
|
||||
base_model TEXT,
|
||||
released_at TEXT,
|
||||
size_bytes INTEGER,
|
||||
preview_url TEXT,
|
||||
is_in_library INTEGER NOT NULL DEFAULT 0,
|
||||
should_ignore INTEGER NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (model_id, version_id),
|
||||
FOREIGN KEY(model_id) REFERENCES model_update_status(model_id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
target_columns = [
|
||||
"model_id",
|
||||
"version_id",
|
||||
"sort_index",
|
||||
"name",
|
||||
"base_model",
|
||||
"released_at",
|
||||
"size_bytes",
|
||||
"preview_url",
|
||||
"is_in_library",
|
||||
"should_ignore",
|
||||
]
|
||||
defaults = {
|
||||
"sort_index": "0",
|
||||
"name": "NULL",
|
||||
"base_model": "NULL",
|
||||
"released_at": "NULL",
|
||||
"size_bytes": "NULL",
|
||||
"preview_url": "NULL",
|
||||
"is_in_library": "0",
|
||||
"should_ignore": "0",
|
||||
}
|
||||
|
||||
select_parts = []
|
||||
for column in target_columns:
|
||||
if column in legacy_columns:
|
||||
if column in {"sort_index", "is_in_library", "should_ignore"}:
|
||||
select_parts.append(f"COALESCE({column}, {defaults[column]})")
|
||||
else:
|
||||
select_parts.append(column)
|
||||
else:
|
||||
select_parts.append(defaults.get(column, "NULL"))
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO model_update_versions_new ({columns})
|
||||
SELECT {select_clause}
|
||||
FROM model_update_versions_legacy
|
||||
""".format(
|
||||
columns=", ".join(target_columns),
|
||||
select_clause=", ".join(select_parts),
|
||||
)
|
||||
)
|
||||
|
||||
conn.execute("DROP TABLE model_update_versions_legacy")
|
||||
conn.execute(
|
||||
"ALTER TABLE model_update_versions_new RENAME TO model_update_versions"
|
||||
)
|
||||
|
||||
def _deduplicate_model_update_status(self, conn: sqlite3.Connection) -> None:
|
||||
"""Remove duplicate status rows before applying uniqueness constraints."""
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import sqlite3
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
@@ -415,6 +416,44 @@ async def test_has_updates_bulk_returns_mapping(tmp_path):
|
||||
assert await service.has_update("lora", 9) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_allows_duplicate_version_ids_across_models(tmp_path):
|
||||
db_path = tmp_path / "updates.sqlite"
|
||||
service = ModelUpdateService(str(db_path), ttl_seconds=0)
|
||||
raw_data = [
|
||||
{"civitai": {"modelId": 1, "id": 42}},
|
||||
{"civitai": {"modelId": 2, "id": 42}},
|
||||
]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider(
|
||||
{
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 42,
|
||||
"name": "shared",
|
||||
"baseModel": "SD15",
|
||||
"publishedAt": "2024-03-01T00:00:00Z",
|
||||
"files": [{"sizeKB": 256}],
|
||||
"images": [],
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
results = await service.refresh_for_model_type("lora", scanner, provider)
|
||||
|
||||
assert set(results.keys()) == {1, 2}
|
||||
assert results[1].version_ids == [42]
|
||||
assert results[2].version_ids == [42]
|
||||
|
||||
with sqlite3.connect(str(db_path)) as conn:
|
||||
count = conn.execute(
|
||||
"SELECT COUNT(*) FROM model_update_versions WHERE version_id = 42"
|
||||
).fetchone()[0]
|
||||
|
||||
assert count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_rewrites_remote_preview_urls(tmp_path):
|
||||
db_path = tmp_path / "updates.sqlite"
|
||||
|
||||
Reference in New Issue
Block a user