mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-04-08 05:16:41 -03:00
314 lines
12 KiB
Python
314 lines
12 KiB
Python
from __future__ import annotations

import asyncio
import logging
import os
import sqlite3
import time
from contextlib import contextmanager
from typing import Iterable, Iterator, Mapping, Optional, Sequence

from ..utils.cache_paths import get_cache_base_dir
from .settings_manager import get_settings_manager
|
|
|
|
# Module-level logger, created with the standard ``getLogger(__name__)`` pattern.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _normalize_model_type(model_type: str | None) -> Optional[str]:
|
|
if not isinstance(model_type, str):
|
|
return None
|
|
normalized = model_type.strip().lower()
|
|
if normalized in {"lora", "locon", "dora"}:
|
|
return "lora"
|
|
if normalized == "checkpoint":
|
|
return "checkpoint"
|
|
if normalized in {"embedding", "textualinversion"}:
|
|
return "embedding"
|
|
return None
|
|
|
|
|
|
def _normalize_int(value) -> Optional[int]:
    """Coerce *value* to ``int``, returning ``None`` when that is impossible.

    ``None`` maps to ``None``. Anything ``int()`` rejects also maps to ``None``
    instead of raising: wrong types (``TypeError``), non-numeric strings and
    NaN (``ValueError``), and infinities (``OverflowError``). Floats are
    truncated toward zero, exactly as ``int()`` does.
    """
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers float("inf"); previously it escaped and crashed
        # callers that rely on this helper to filter out unusable ids.
        return None
|
|
|
|
|
|
def _resolve_database_path() -> str:
    """Return the default history database path, creating its folder first.

    The file is ``downloaded_versions.sqlite`` inside a ``download_history``
    sub-directory of the cache base directory from ``get_cache_base_dir``.
    """
    folder = os.path.join(get_cache_base_dir(create=True), "download_history")
    os.makedirs(folder, exist_ok=True)
    return os.path.join(folder, "downloaded_versions.sqlite")
|
|
|
|
|
|
class DownloadedVersionHistoryService:
    """Track which model versions have been downloaded, backed by SQLite.

    Rows are keyed by ``(model_type, version_id)``; ``model_type`` is one of
    the canonical buckets produced by ``_normalize_model_type``. A row whose
    ``is_deleted_override`` is 1 was explicitly marked as not downloaded and is
    ignored by every lookup until the version is marked downloaded again.

    Every public coroutine serialises its database work behind one
    ``asyncio.Lock``. The SQLite calls inside them are synchronous.
    """

    _SCHEMA = """
        CREATE TABLE IF NOT EXISTS downloaded_model_versions (
            model_type TEXT NOT NULL,
            version_id INTEGER NOT NULL,
            model_id INTEGER,
            first_seen_at REAL NOT NULL,
            last_seen_at REAL NOT NULL,
            source TEXT NOT NULL,
            last_file_path TEXT,
            last_library_name TEXT,
            is_deleted_override INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (model_type, version_id)
        );

        CREATE INDEX IF NOT EXISTS idx_downloaded_model_versions_model
            ON downloaded_model_versions(model_type, model_id);
    """

    # Upsert shared by mark_downloaded and mark_downloaded_bulk. On conflict the
    # existing row is refreshed: first_seen_at is kept, NULL optional values
    # never overwrite known ones (COALESCE), and any "not downloaded" override
    # is cleared.
    _UPSERT_DOWNLOADED_SQL = """
        INSERT INTO downloaded_model_versions (
            model_type, version_id, model_id, first_seen_at, last_seen_at,
            source, last_file_path, last_library_name, is_deleted_override
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
        ON CONFLICT(model_type, version_id) DO UPDATE SET
            model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
            last_seen_at = excluded.last_seen_at,
            source = excluded.source,
            last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
            last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
            is_deleted_override = 0
    """

    # SQLite builds older than 3.32.0 allow at most 999 bound parameters per
    # statement, so large ``IN (...)`` lookups are split into chunks this size.
    _MAX_IDS_PER_QUERY = 500

    def __init__(self, db_path: str | None = None, *, settings_manager=None) -> None:
        """Open (creating if needed) the history database.

        Args:
            db_path: Explicit database file; when falsy the path from
                ``_resolve_database_path()`` is used.
            settings_manager: Object exposing ``get_active_library_name()``;
                when falsy ``get_settings_manager()`` is used.
        """
        self._db_path = db_path or _resolve_database_path()
        self._settings = settings_manager or get_settings_manager()
        self._lock = asyncio.Lock()
        self._schema_initialized = False
        self._ensure_directory()
        self._initialize_schema()

    def _ensure_directory(self) -> None:
        """Create the parent directory of the database file if it has one."""
        directory = os.path.dirname(self._db_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection whose rows are returned as ``sqlite3.Row``."""
        conn = sqlite3.connect(self._db_path, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a fresh connection, then commit (or roll back) and close it.

        Using ``sqlite3.Connection`` itself as a context manager only ends the
        transaction and does NOT close the connection. The explicit
        ``close()`` here keeps each call from leaking a connection.
        """
        conn = self._connect()
        try:
            with conn:  # commit on success, rollback if the body raises
                yield conn
        finally:
            conn.close()

    def _initialize_schema(self) -> None:
        """Create the table and index once per instance (idempotent SQL)."""
        if self._schema_initialized:
            return
        with self._connection() as conn:
            conn.executescript(self._SCHEMA)
        self._schema_initialized = True

    def get_database_path(self) -> str:
        """Return the filesystem path of the SQLite database."""
        return self._db_path

    def _get_active_library_name(self) -> str | None:
        """Return the active library name from settings, or ``None``.

        Best effort: a failure in the settings manager must not block history
        bookkeeping, so it is logged at debug level and treated as "unknown".
        """
        try:
            value = self._settings.get_active_library_name()
        except Exception:
            logger.debug("Could not resolve the active library name", exc_info=True)
            return None
        return value or None

    async def mark_downloaded(
        self,
        model_type: str,
        version_id: int,
        *,
        model_id: int | None = None,
        source: str = "manual",
        file_path: str | None = None,
        library_name: str | None = None,
    ) -> None:
        """Record that a model version is downloaded.

        Calls whose model type or version id cannot be normalised are ignored.
        ``library_name`` falls back to the active library from settings.
        """
        normalized_type = _normalize_model_type(model_type)
        normalized_version_id = _normalize_int(version_id)
        normalized_model_id = _normalize_int(model_id)
        if normalized_type is None or normalized_version_id is None:
            return

        active_library_name = library_name or self._get_active_library_name()
        timestamp = time.time()

        async with self._lock:
            with self._connection() as conn:
                conn.execute(
                    self._UPSERT_DOWNLOADED_SQL,
                    (
                        normalized_type,
                        normalized_version_id,
                        normalized_model_id,
                        timestamp,
                        timestamp,
                        source,
                        file_path,
                        active_library_name,
                    ),
                )

    async def mark_downloaded_bulk(
        self,
        model_type: str,
        records: Sequence[Mapping[str, object]],
        *,
        source: str = "scan",
        library_name: str | None = None,
    ) -> None:
        """Record many downloaded versions of one model type in one transaction.

        Each record may carry ``version_id`` (required; records without a
        usable one are skipped), ``model_id`` and ``file_path``.
        """
        normalized_type = _normalize_model_type(model_type)
        if normalized_type is None or not records:
            return

        timestamp = time.time()
        active_library_name = library_name or self._get_active_library_name()
        payload: list[tuple[object, ...]] = []
        for record in records:
            version_id = _normalize_int(record.get("version_id"))
            if version_id is None:
                continue
            payload.append(
                (
                    normalized_type,
                    version_id,
                    _normalize_int(record.get("model_id")),
                    timestamp,
                    timestamp,
                    source,
                    record.get("file_path"),
                    active_library_name,
                )
            )

        if not payload:
            return

        async with self._lock:
            with self._connection() as conn:
                conn.executemany(self._UPSERT_DOWNLOADED_SQL, payload)

    async def mark_not_downloaded(self, model_type: str, version_id: int) -> None:
        """Flag a version as explicitly not downloaded (``is_deleted_override``).

        Creates a placeholder row when the version was never recorded, so the
        override also applies to versions first seen later.
        """
        normalized_type = _normalize_model_type(model_type)
        normalized_version_id = _normalize_int(version_id)
        if normalized_type is None or normalized_version_id is None:
            return

        timestamp = time.time()
        active_library_name = self._get_active_library_name()

        async with self._lock:
            with self._connection() as conn:
                conn.execute(
                    """
                    INSERT INTO downloaded_model_versions (
                        model_type, version_id, model_id, first_seen_at, last_seen_at,
                        source, last_file_path, last_library_name, is_deleted_override
                    ) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1)
                    ON CONFLICT(model_type, version_id) DO UPDATE SET
                        last_seen_at = excluded.last_seen_at,
                        source = excluded.source,
                        last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
                        is_deleted_override = 1
                    """,
                    (
                        normalized_type,
                        normalized_version_id,
                        timestamp,
                        timestamp,
                        active_library_name,
                    ),
                )

    async def has_been_downloaded(self, model_type: str, version_id: int) -> bool:
        """Return True if the version is recorded and not overridden as deleted."""
        normalized_type = _normalize_model_type(model_type)
        normalized_version_id = _normalize_int(version_id)
        if normalized_type is None or normalized_version_id is None:
            return False

        async with self._lock:
            with self._connection() as conn:
                row = conn.execute(
                    """
                    SELECT is_deleted_override
                    FROM downloaded_model_versions
                    WHERE model_type = ? AND version_id = ?
                    """,
                    (normalized_type, normalized_version_id),
                ).fetchone()
        return row is not None and not row["is_deleted_override"]

    async def get_downloaded_version_ids(
        self, model_type: str, model_id: int
    ) -> list[int]:
        """Return the downloaded version ids of one model, ascending."""
        normalized_type = _normalize_model_type(model_type)
        normalized_model_id = _normalize_int(model_id)
        if normalized_type is None or normalized_model_id is None:
            return []

        async with self._lock:
            with self._connection() as conn:
                rows = conn.execute(
                    """
                    SELECT version_id
                    FROM downloaded_model_versions
                    WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0
                    ORDER BY version_id ASC
                    """,
                    (normalized_type, normalized_model_id),
                ).fetchall()
        return [int(row["version_id"]) for row in rows]

    def _select_version_rows(
        self,
        conn: sqlite3.Connection,
        model_type: str,
        model_ids: Sequence[int],
    ) -> list[sqlite3.Row]:
        """Fetch non-overridden ``(model_id, version_id)`` rows for *model_ids*.

        Ids are queried in chunks of ``_MAX_IDS_PER_QUERY`` so the generated
        ``IN (...)`` list never exceeds SQLite's bound-parameter limit. Only
        ``?`` placeholders are interpolated into the SQL text.
        """
        rows: list[sqlite3.Row] = []
        step = self._MAX_IDS_PER_QUERY
        for start in range(0, len(model_ids), step):
            chunk = model_ids[start:start + step]
            placeholders = ", ".join(["?"] * len(chunk))
            rows.extend(
                conn.execute(
                    f"""
                    SELECT model_id, version_id
                    FROM downloaded_model_versions
                    WHERE model_type = ?
                      AND model_id IN ({placeholders})
                      AND is_deleted_override = 0
                    """,
                    [model_type, *chunk],
                ).fetchall()
            )
        return rows

    async def get_downloaded_version_ids_bulk(
        self, model_type: str, model_ids: Iterable[int]
    ) -> dict[int, set[int]]:
        """Map each requested model id to its set of downloaded version ids.

        Unusable ids are dropped. Models with no downloaded versions are absent
        from the result.
        """
        normalized_type = _normalize_model_type(model_type)
        if normalized_type is None:
            return {}

        normalized_model_ids = sorted(
            {value for value in map(_normalize_int, model_ids) if value is not None}
        )
        if not normalized_model_ids:
            return {}

        async with self._lock:
            with self._connection() as conn:
                rows = self._select_version_rows(
                    conn, normalized_type, normalized_model_ids
                )

        result: dict[int, set[int]] = {}
        for row in rows:
            model_id = _normalize_int(row["model_id"])
            version_id = _normalize_int(row["version_id"])
            if model_id is None or version_id is None:
                continue
            result.setdefault(model_id, set()).add(version_id)
        return result