mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-04-10 04:42:14 -03:00
feat(download-history): track downloaded model versions
This commit is contained in:
313
py/services/downloaded_version_history_service.py
Normal file
313
py/services/downloaded_version_history_service.py
Normal file
@@ -0,0 +1,313 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import Iterable, Mapping, Optional, Sequence
|
||||
|
||||
from ..utils.cache_paths import get_cache_base_dir
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_model_type(model_type: str | None) -> Optional[str]:
|
||||
if not isinstance(model_type, str):
|
||||
return None
|
||||
normalized = model_type.strip().lower()
|
||||
if normalized in {"lora", "locon", "dora"}:
|
||||
return "lora"
|
||||
if normalized == "checkpoint":
|
||||
return "checkpoint"
|
||||
if normalized in {"embedding", "textualinversion"}:
|
||||
return "embedding"
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_int(value) -> Optional[int]:
|
||||
try:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_database_path() -> str:
|
||||
base_dir = get_cache_base_dir(create=True)
|
||||
history_dir = os.path.join(base_dir, "download_history")
|
||||
os.makedirs(history_dir, exist_ok=True)
|
||||
return os.path.join(history_dir, "downloaded_versions.sqlite")
|
||||
|
||||
|
||||
class DownloadedVersionHistoryService:
|
||||
_SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS downloaded_model_versions (
|
||||
model_type TEXT NOT NULL,
|
||||
version_id INTEGER NOT NULL,
|
||||
model_id INTEGER,
|
||||
first_seen_at REAL NOT NULL,
|
||||
last_seen_at REAL NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
last_file_path TEXT,
|
||||
last_library_name TEXT,
|
||||
is_deleted_override INTEGER NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (model_type, version_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_downloaded_model_versions_model
|
||||
ON downloaded_model_versions(model_type, model_id);
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str | None = None, *, settings_manager=None) -> None:
|
||||
self._db_path = db_path or _resolve_database_path()
|
||||
self._settings = settings_manager or get_settings_manager()
|
||||
self._lock = asyncio.Lock()
|
||||
self._schema_initialized = False
|
||||
self._ensure_directory()
|
||||
self._initialize_schema()
|
||||
|
||||
def _ensure_directory(self) -> None:
|
||||
directory = os.path.dirname(self._db_path)
|
||||
if directory:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self._db_path, check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def _initialize_schema(self) -> None:
|
||||
if self._schema_initialized:
|
||||
return
|
||||
with self._connect() as conn:
|
||||
conn.executescript(self._SCHEMA)
|
||||
conn.commit()
|
||||
self._schema_initialized = True
|
||||
|
||||
def get_database_path(self) -> str:
|
||||
return self._db_path
|
||||
|
||||
def _get_active_library_name(self) -> str | None:
|
||||
try:
|
||||
value = self._settings.get_active_library_name()
|
||||
except Exception:
|
||||
return None
|
||||
return value or None
|
||||
|
||||
async def mark_downloaded(
|
||||
self,
|
||||
model_type: str,
|
||||
version_id: int,
|
||||
*,
|
||||
model_id: int | None = None,
|
||||
source: str = "manual",
|
||||
file_path: str | None = None,
|
||||
library_name: str | None = None,
|
||||
) -> None:
|
||||
normalized_type = _normalize_model_type(model_type)
|
||||
normalized_version_id = _normalize_int(version_id)
|
||||
normalized_model_id = _normalize_int(model_id)
|
||||
if normalized_type is None or normalized_version_id is None:
|
||||
return
|
||||
|
||||
active_library_name = library_name or self._get_active_library_name()
|
||||
timestamp = time.time()
|
||||
|
||||
async with self._lock:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO downloaded_model_versions (
|
||||
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||
source, last_file_path, last_library_name, is_deleted_override
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
|
||||
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
|
||||
last_seen_at = excluded.last_seen_at,
|
||||
source = excluded.source,
|
||||
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
|
||||
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||
is_deleted_override = 0
|
||||
""",
|
||||
(
|
||||
normalized_type,
|
||||
normalized_version_id,
|
||||
normalized_model_id,
|
||||
timestamp,
|
||||
timestamp,
|
||||
source,
|
||||
file_path,
|
||||
active_library_name,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
async def mark_downloaded_bulk(
|
||||
self,
|
||||
model_type: str,
|
||||
records: Sequence[Mapping[str, object]],
|
||||
*,
|
||||
source: str = "scan",
|
||||
library_name: str | None = None,
|
||||
) -> None:
|
||||
normalized_type = _normalize_model_type(model_type)
|
||||
if normalized_type is None or not records:
|
||||
return
|
||||
|
||||
timestamp = time.time()
|
||||
active_library_name = library_name or self._get_active_library_name()
|
||||
payload: list[tuple[object, ...]] = []
|
||||
for record in records:
|
||||
version_id = _normalize_int(record.get("version_id"))
|
||||
if version_id is None:
|
||||
continue
|
||||
payload.append(
|
||||
(
|
||||
normalized_type,
|
||||
version_id,
|
||||
_normalize_int(record.get("model_id")),
|
||||
timestamp,
|
||||
timestamp,
|
||||
source,
|
||||
record.get("file_path"),
|
||||
active_library_name,
|
||||
)
|
||||
)
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
async with self._lock:
|
||||
with self._connect() as conn:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT INTO downloaded_model_versions (
|
||||
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||
source, last_file_path, last_library_name, is_deleted_override
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
|
||||
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
|
||||
last_seen_at = excluded.last_seen_at,
|
||||
source = excluded.source,
|
||||
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
|
||||
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||
is_deleted_override = 0
|
||||
""",
|
||||
payload,
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
async def mark_not_downloaded(self, model_type: str, version_id: int) -> None:
|
||||
normalized_type = _normalize_model_type(model_type)
|
||||
normalized_version_id = _normalize_int(version_id)
|
||||
if normalized_type is None or normalized_version_id is None:
|
||||
return
|
||||
|
||||
timestamp = time.time()
|
||||
|
||||
async with self._lock:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO downloaded_model_versions (
|
||||
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||
source, last_file_path, last_library_name, is_deleted_override
|
||||
) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1)
|
||||
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||
last_seen_at = excluded.last_seen_at,
|
||||
source = excluded.source,
|
||||
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||
is_deleted_override = 1
|
||||
""",
|
||||
(
|
||||
normalized_type,
|
||||
normalized_version_id,
|
||||
timestamp,
|
||||
timestamp,
|
||||
self._get_active_library_name(),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
async def has_been_downloaded(self, model_type: str, version_id: int) -> bool:
|
||||
normalized_type = _normalize_model_type(model_type)
|
||||
normalized_version_id = _normalize_int(version_id)
|
||||
if normalized_type is None or normalized_version_id is None:
|
||||
return False
|
||||
|
||||
async with self._lock:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT is_deleted_override
|
||||
FROM downloaded_model_versions
|
||||
WHERE model_type = ? AND version_id = ?
|
||||
""",
|
||||
(normalized_type, normalized_version_id),
|
||||
).fetchone()
|
||||
return bool(row) and not bool(row["is_deleted_override"])
|
||||
|
||||
async def get_downloaded_version_ids(
|
||||
self, model_type: str, model_id: int
|
||||
) -> list[int]:
|
||||
normalized_type = _normalize_model_type(model_type)
|
||||
normalized_model_id = _normalize_int(model_id)
|
||||
if normalized_type is None or normalized_model_id is None:
|
||||
return []
|
||||
|
||||
async with self._lock:
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT version_id
|
||||
FROM downloaded_model_versions
|
||||
WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0
|
||||
ORDER BY version_id ASC
|
||||
""",
|
||||
(normalized_type, normalized_model_id),
|
||||
).fetchall()
|
||||
return [int(row["version_id"]) for row in rows]
|
||||
|
||||
async def get_downloaded_version_ids_bulk(
|
||||
self, model_type: str, model_ids: Iterable[int]
|
||||
) -> dict[int, set[int]]:
|
||||
normalized_type = _normalize_model_type(model_type)
|
||||
if normalized_type is None:
|
||||
return {}
|
||||
|
||||
normalized_model_ids = sorted(
|
||||
{
|
||||
value
|
||||
for value in (_normalize_int(model_id) for model_id in model_ids)
|
||||
if value is not None
|
||||
}
|
||||
)
|
||||
if not normalized_model_ids:
|
||||
return {}
|
||||
|
||||
placeholders = ", ".join(["?"] * len(normalized_model_ids))
|
||||
params: list[object] = [normalized_type, *normalized_model_ids]
|
||||
|
||||
async with self._lock:
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
f"""
|
||||
SELECT model_id, version_id
|
||||
FROM downloaded_model_versions
|
||||
WHERE model_type = ?
|
||||
AND model_id IN ({placeholders})
|
||||
AND is_deleted_override = 0
|
||||
""",
|
||||
params,
|
||||
).fetchall()
|
||||
|
||||
result: dict[int, set[int]] = {}
|
||||
for row in rows:
|
||||
model_id = _normalize_int(row["model_id"])
|
||||
version_id = _normalize_int(row["version_id"])
|
||||
if model_id is None or version_id is None:
|
||||
continue
|
||||
result.setdefault(model_id, set()).add(version_id)
|
||||
return result
|
||||
Reference in New Issue
Block a user