Files
ComfyUI-Lora-Manager/py/services/downloaded_version_history_service.py

314 lines
12 KiB
Python

from __future__ import annotations
import asyncio
import logging
import os
import sqlite3
import time
from typing import Iterable, Mapping, Optional, Sequence
from ..utils.cache_paths import get_cache_base_dir
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
def _normalize_model_type(model_type: str | None) -> Optional[str]:
if not isinstance(model_type, str):
return None
normalized = model_type.strip().lower()
if normalized in {"lora", "locon", "dora"}:
return "lora"
if normalized == "checkpoint":
return "checkpoint"
if normalized in {"embedding", "textualinversion"}:
return "embedding"
return None
def _normalize_int(value) -> Optional[int]:
try:
if value is None:
return None
return int(value)
except (TypeError, ValueError):
return None
def _resolve_database_path() -> str:
base_dir = get_cache_base_dir(create=True)
history_dir = os.path.join(base_dir, "download_history")
os.makedirs(history_dir, exist_ok=True)
return os.path.join(history_dir, "downloaded_versions.sqlite")
class DownloadedVersionHistoryService:
_SCHEMA = """
CREATE TABLE IF NOT EXISTS downloaded_model_versions (
model_type TEXT NOT NULL,
version_id INTEGER NOT NULL,
model_id INTEGER,
first_seen_at REAL NOT NULL,
last_seen_at REAL NOT NULL,
source TEXT NOT NULL,
last_file_path TEXT,
last_library_name TEXT,
is_deleted_override INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (model_type, version_id)
);
CREATE INDEX IF NOT EXISTS idx_downloaded_model_versions_model
ON downloaded_model_versions(model_type, model_id);
"""
def __init__(self, db_path: str | None = None, *, settings_manager=None) -> None:
self._db_path = db_path or _resolve_database_path()
self._settings = settings_manager or get_settings_manager()
self._lock = asyncio.Lock()
self._schema_initialized = False
self._ensure_directory()
self._initialize_schema()
def _ensure_directory(self) -> None:
directory = os.path.dirname(self._db_path)
if directory:
os.makedirs(directory, exist_ok=True)
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self._db_path, check_same_thread=False)
conn.row_factory = sqlite3.Row
return conn
def _initialize_schema(self) -> None:
if self._schema_initialized:
return
with self._connect() as conn:
conn.executescript(self._SCHEMA)
conn.commit()
self._schema_initialized = True
def get_database_path(self) -> str:
return self._db_path
def _get_active_library_name(self) -> str | None:
try:
value = self._settings.get_active_library_name()
except Exception:
return None
return value or None
async def mark_downloaded(
self,
model_type: str,
version_id: int,
*,
model_id: int | None = None,
source: str = "manual",
file_path: str | None = None,
library_name: str | None = None,
) -> None:
normalized_type = _normalize_model_type(model_type)
normalized_version_id = _normalize_int(version_id)
normalized_model_id = _normalize_int(model_id)
if normalized_type is None or normalized_version_id is None:
return
active_library_name = library_name or self._get_active_library_name()
timestamp = time.time()
async with self._lock:
with self._connect() as conn:
conn.execute(
"""
INSERT INTO downloaded_model_versions (
model_type, version_id, model_id, first_seen_at, last_seen_at,
source, last_file_path, last_library_name, is_deleted_override
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
ON CONFLICT(model_type, version_id) DO UPDATE SET
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
last_seen_at = excluded.last_seen_at,
source = excluded.source,
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
is_deleted_override = 0
""",
(
normalized_type,
normalized_version_id,
normalized_model_id,
timestamp,
timestamp,
source,
file_path,
active_library_name,
),
)
conn.commit()
async def mark_downloaded_bulk(
self,
model_type: str,
records: Sequence[Mapping[str, object]],
*,
source: str = "scan",
library_name: str | None = None,
) -> None:
normalized_type = _normalize_model_type(model_type)
if normalized_type is None or not records:
return
timestamp = time.time()
active_library_name = library_name or self._get_active_library_name()
payload: list[tuple[object, ...]] = []
for record in records:
version_id = _normalize_int(record.get("version_id"))
if version_id is None:
continue
payload.append(
(
normalized_type,
version_id,
_normalize_int(record.get("model_id")),
timestamp,
timestamp,
source,
record.get("file_path"),
active_library_name,
)
)
if not payload:
return
async with self._lock:
with self._connect() as conn:
conn.executemany(
"""
INSERT INTO downloaded_model_versions (
model_type, version_id, model_id, first_seen_at, last_seen_at,
source, last_file_path, last_library_name, is_deleted_override
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
ON CONFLICT(model_type, version_id) DO UPDATE SET
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
last_seen_at = excluded.last_seen_at,
source = excluded.source,
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
is_deleted_override = 0
""",
payload,
)
conn.commit()
async def mark_not_downloaded(self, model_type: str, version_id: int) -> None:
normalized_type = _normalize_model_type(model_type)
normalized_version_id = _normalize_int(version_id)
if normalized_type is None or normalized_version_id is None:
return
timestamp = time.time()
async with self._lock:
with self._connect() as conn:
conn.execute(
"""
INSERT INTO downloaded_model_versions (
model_type, version_id, model_id, first_seen_at, last_seen_at,
source, last_file_path, last_library_name, is_deleted_override
) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1)
ON CONFLICT(model_type, version_id) DO UPDATE SET
last_seen_at = excluded.last_seen_at,
source = excluded.source,
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
is_deleted_override = 1
""",
(
normalized_type,
normalized_version_id,
timestamp,
timestamp,
self._get_active_library_name(),
),
)
conn.commit()
async def has_been_downloaded(self, model_type: str, version_id: int) -> bool:
normalized_type = _normalize_model_type(model_type)
normalized_version_id = _normalize_int(version_id)
if normalized_type is None or normalized_version_id is None:
return False
async with self._lock:
with self._connect() as conn:
row = conn.execute(
"""
SELECT is_deleted_override
FROM downloaded_model_versions
WHERE model_type = ? AND version_id = ?
""",
(normalized_type, normalized_version_id),
).fetchone()
return bool(row) and not bool(row["is_deleted_override"])
async def get_downloaded_version_ids(
self, model_type: str, model_id: int
) -> list[int]:
normalized_type = _normalize_model_type(model_type)
normalized_model_id = _normalize_int(model_id)
if normalized_type is None or normalized_model_id is None:
return []
async with self._lock:
with self._connect() as conn:
rows = conn.execute(
"""
SELECT version_id
FROM downloaded_model_versions
WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0
ORDER BY version_id ASC
""",
(normalized_type, normalized_model_id),
).fetchall()
return [int(row["version_id"]) for row in rows]
async def get_downloaded_version_ids_bulk(
self, model_type: str, model_ids: Iterable[int]
) -> dict[int, set[int]]:
normalized_type = _normalize_model_type(model_type)
if normalized_type is None:
return {}
normalized_model_ids = sorted(
{
value
for value in (_normalize_int(model_id) for model_id in model_ids)
if value is not None
}
)
if not normalized_model_ids:
return {}
placeholders = ", ".join(["?"] * len(normalized_model_ids))
params: list[object] = [normalized_type, *normalized_model_ids]
async with self._lock:
with self._connect() as conn:
rows = conn.execute(
f"""
SELECT model_id, version_id
FROM downloaded_model_versions
WHERE model_type = ?
AND model_id IN ({placeholders})
AND is_deleted_override = 0
""",
params,
).fetchall()
result: dict[int, set[int]] = {}
for row in rows:
model_id = _normalize_int(row["model_id"])
version_id = _normalize_int(row["version_id"])
if model_id is None or version_id is None:
continue
result.setdefault(model_id, set()).add(version_id)
return result