feat(download-history): track downloaded model versions

This commit is contained in:
Will Miao
2026-04-03 16:13:14 +08:00
parent 4f599aeced
commit 33a7f07558
9 changed files with 881 additions and 18 deletions

View File

@@ -640,6 +640,13 @@ class DownloadManager:
or version_info.get("modelId")
or (version_info.get("model") or {}).get("id")
)
await self._record_downloaded_version_history(
model_type,
resolved_model_id,
version_info,
model_version_id,
save_path,
)
await self._sync_downloaded_version(
model_type,
resolved_model_id,
@@ -669,6 +676,55 @@ class DownloadManager:
}
return {"success": False, "error": str(e)}
async def _record_downloaded_version_history(
    self,
    model_type: str,
    model_id_value,
    version_info: Dict,
    fallback_version_id=None,
    file_path: str | None = None,
) -> None:
    """Record a finished download in the downloaded-version history.

    This is best-effort: every failure is logged at debug level and
    swallowed, so nothing is raised for history problems.

    Args:
        model_type: Model type forwarded unchanged to the history service.
        model_id_value: Model id already resolved by the caller. When it is
            ``None`` the id is looked up in ``version_info["modelId"]`` and
            then in ``version_info["model"]["id"]``.
        version_info: Version metadata; its ``"id"`` entry is the version id.
        fallback_version_id: Version id used when ``version_info`` has no
            ``"id"`` entry.
        file_path: Path of the downloaded file, stored with the record.
    """
    try:
        history_service = await ServiceRegistry.get_downloaded_version_history_service()
    except Exception as exc:
        logger.debug(
            "Skipping download history sync; failed to acquire history service: %s",
            exc,
        )
        return
    if history_service is None:
        return

    # Resolve the model id lazily: explicit argument first, then the flat
    # "modelId" field, then the nested "model" mapping.
    model_ref = model_id_value
    if model_ref is None:
        model_ref = version_info.get("modelId")
        if model_ref is None:
            nested_model = version_info.get("model")
            if isinstance(nested_model, dict):
                model_ref = nested_model.get("id")

    version_ref = version_info.get("id")
    if version_ref is None:
        version_ref = fallback_version_id

    try:
        record = history_service.mark_downloaded
        # int(None) raises TypeError, which is how a missing version id
        # ends up in the "invalid identifiers" branch below.
        version_number = int(version_ref)
        model_number = None if model_ref is None else int(model_ref)
        await record(
            model_type,
            version_number,
            model_id=model_number,
            source="download",
            file_path=file_path,
        )
    except (TypeError, ValueError):
        logger.debug(
            "Skipping download history sync; invalid identifiers model=%s version=%s",
            model_ref,
            version_ref,
        )
    except Exception as exc:
        logger.debug("Failed to sync download history for %s: %s", model_type, exc)
async def _sync_downloaded_version(
self,
model_type: str,

View File

@@ -0,0 +1,313 @@
from __future__ import annotations
import asyncio
import logging
import os
import sqlite3
import time
from typing import Iterable, Mapping, Optional, Sequence
from ..utils.cache_paths import get_cache_base_dir
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
def _normalize_model_type(model_type: str | None) -> Optional[str]:
    """Map a model-type label onto its canonical history key.

    The label is stripped and lower-cased before lookup. ``lora``, ``locon``
    and ``dora`` collapse to ``"lora"``; ``embedding`` and
    ``textualinversion`` collapse to ``"embedding"``; ``checkpoint`` maps to
    itself. Non-string input and any other label yield ``None``.
    """
    if not isinstance(model_type, str):
        return None
    canonical_by_alias = {
        "lora": "lora",
        "locon": "lora",
        "dora": "lora",
        "checkpoint": "checkpoint",
        "embedding": "embedding",
        "textualinversion": "embedding",
    }
    return canonical_by_alias.get(model_type.strip().lower())
def _normalize_int(value) -> Optional[int]:
    """Coerce *value* to ``int``, returning ``None`` when that is impossible.

    ``None`` is passed through as ``None``. Any value that ``int()`` rejects
    also yields ``None`` instead of raising.
    """
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError is what int() raises for an infinite float; without
        # it such a value would escape as an exception instead of None.
        return None
def _resolve_database_path() -> str:
    """Return the default path of the download-history SQLite file.

    The file lives in a ``download_history`` directory under the cache base
    directory; that directory is created if it does not exist yet.
    """
    target_dir = os.path.join(get_cache_base_dir(create=True), "download_history")
    os.makedirs(target_dir, exist_ok=True)
    database_file = os.path.join(target_dir, "downloaded_versions.sqlite")
    return database_file
class DownloadedVersionHistoryService:
    """Persist which model versions have been downloaded, backed by SQLite.

    Rows are keyed by ``(model_type, version_id)``. ``mark_not_downloaded``
    does not delete a row; it sets ``is_deleted_override`` to 1, and the
    query methods only report rows where that flag is 0.

    Every operation opens its own connection and closes it before returning.
    The coroutine methods serialise database access through one
    ``asyncio.Lock``; the SQLite calls themselves are synchronous.
    """

    _SCHEMA = """
        CREATE TABLE IF NOT EXISTS downloaded_model_versions (
            model_type TEXT NOT NULL,
            version_id INTEGER NOT NULL,
            model_id INTEGER,
            first_seen_at REAL NOT NULL,
            last_seen_at REAL NOT NULL,
            source TEXT NOT NULL,
            last_file_path TEXT,
            last_library_name TEXT,
            is_deleted_override INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (model_type, version_id)
        );
        CREATE INDEX IF NOT EXISTS idx_downloaded_model_versions_model
            ON downloaded_model_versions(model_type, model_id);
    """

    # Insert a "downloaded" row, or refresh the existing one. On conflict
    # first_seen_at is left untouched, NULL inputs never overwrite stored
    # model_id / file path / library name, and the deleted override is reset.
    _UPSERT_DOWNLOADED_SQL = """
        INSERT INTO downloaded_model_versions (
            model_type, version_id, model_id, first_seen_at, last_seen_at,
            source, last_file_path, last_library_name, is_deleted_override
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
        ON CONFLICT(model_type, version_id) DO UPDATE SET
            model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
            last_seen_at = excluded.last_seen_at,
            source = excluded.source,
            last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
            last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
            is_deleted_override = 0
    """

    # Insert or update a row so that it carries the deleted override.
    _UPSERT_NOT_DOWNLOADED_SQL = """
        INSERT INTO downloaded_model_versions (
            model_type, version_id, model_id, first_seen_at, last_seen_at,
            source, last_file_path, last_library_name, is_deleted_override
        ) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1)
        ON CONFLICT(model_type, version_id) DO UPDATE SET
            last_seen_at = excluded.last_seen_at,
            source = excluded.source,
            last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
            is_deleted_override = 1
    """

    # SQLite caps the number of bound parameters per statement (999 on older
    # builds). One parameter is taken by model_type, so stay well below that.
    _MAX_IDS_PER_QUERY = 900

    def __init__(self, db_path: str | None = None, *, settings_manager=None) -> None:
        """Open (creating if needed) the history database.

        Args:
            db_path: Database file location; defaults to the path returned by
                ``_resolve_database_path()``.
            settings_manager: Object providing ``get_active_library_name()``;
                defaults to ``get_settings_manager()``.
        """
        self._db_path = db_path or _resolve_database_path()
        self._settings = settings_manager or get_settings_manager()
        self._lock = asyncio.Lock()
        self._schema_initialized = False
        self._ensure_directory()
        self._initialize_schema()

    def _ensure_directory(self) -> None:
        """Create the directory that holds the database file, if any."""
        directory = os.path.dirname(self._db_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection whose rows support access by column name."""
        conn = sqlite3.connect(self._db_path, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        return conn

    def _execute_write(self, sql: str, rows: Sequence[Sequence[object]]) -> None:
        """Run *sql* once per parameter tuple in *rows* in one transaction."""
        conn = self._connect()
        try:
            # The connection context manager commits or rolls back, but it
            # does not close the connection; that is done in ``finally``.
            with conn:
                conn.executemany(sql, rows)
        finally:
            conn.close()

    def _fetch_all(self, sql: str, params: Sequence[object]) -> list[sqlite3.Row]:
        """Run a read-only query and return every row it produces."""
        conn = self._connect()
        try:
            return conn.execute(sql, params).fetchall()
        finally:
            conn.close()

    def _initialize_schema(self) -> None:
        """Create the table and index once per service instance."""
        if self._schema_initialized:
            return
        conn = self._connect()
        try:
            with conn:
                conn.executescript(self._SCHEMA)
        finally:
            conn.close()
        self._schema_initialized = True

    def get_database_path(self) -> str:
        """Return the path of the SQLite file used by this service."""
        return self._db_path

    def _get_active_library_name(self) -> str | None:
        """Return the active library name, or ``None`` if unset or unavailable."""
        try:
            value = self._settings.get_active_library_name()
        except Exception:
            # A settings failure must not prevent history from being recorded.
            return None
        return value or None

    async def mark_downloaded(
        self,
        model_type: str,
        version_id: int,
        *,
        model_id: int | None = None,
        source: str = "manual",
        file_path: str | None = None,
        library_name: str | None = None,
    ) -> None:
        """Record one version as downloaded.

        The call is a no-op when the model type is not recognised or the
        version id cannot be converted to an integer. ``library_name``
        defaults to the active library reported by the settings manager.
        """
        normalized_type = _normalize_model_type(model_type)
        normalized_version_id = _normalize_int(version_id)
        if normalized_type is None or normalized_version_id is None:
            return

        active_library_name = library_name or self._get_active_library_name()
        timestamp = time.time()
        row = (
            normalized_type,
            normalized_version_id,
            _normalize_int(model_id),
            timestamp,
            timestamp,
            source,
            file_path,
            active_library_name,
        )
        async with self._lock:
            self._execute_write(self._UPSERT_DOWNLOADED_SQL, [row])

    async def mark_downloaded_bulk(
        self,
        model_type: str,
        records: Sequence[Mapping[str, object]],
        *,
        source: str = "scan",
        library_name: str | None = None,
    ) -> None:
        """Record many versions of one model type as downloaded.

        Each record is read through its ``version_id``, ``model_id`` and
        ``file_path`` keys. Records without a usable ``version_id`` are
        skipped. All remaining rows are written in a single transaction.
        """
        normalized_type = _normalize_model_type(model_type)
        if normalized_type is None or not records:
            return

        timestamp = time.time()
        active_library_name = library_name or self._get_active_library_name()
        payload: list[tuple[object, ...]] = []
        for record in records:
            version_id = _normalize_int(record.get("version_id"))
            if version_id is None:
                continue
            payload.append(
                (
                    normalized_type,
                    version_id,
                    _normalize_int(record.get("model_id")),
                    timestamp,
                    timestamp,
                    source,
                    record.get("file_path"),
                    active_library_name,
                )
            )
        if not payload:
            return

        async with self._lock:
            self._execute_write(self._UPSERT_DOWNLOADED_SQL, payload)

    async def mark_not_downloaded(self, model_type: str, version_id: int) -> None:
        """Flag a version as not downloaded by setting its deleted override.

        A row is inserted when none exists yet. Invalid identifiers make the
        call a no-op.
        """
        normalized_type = _normalize_model_type(model_type)
        normalized_version_id = _normalize_int(version_id)
        if normalized_type is None or normalized_version_id is None:
            return

        timestamp = time.time()
        row = (
            normalized_type,
            normalized_version_id,
            timestamp,
            timestamp,
            self._get_active_library_name(),
        )
        async with self._lock:
            self._execute_write(self._UPSERT_NOT_DOWNLOADED_SQL, [row])

    async def has_been_downloaded(self, model_type: str, version_id: int) -> bool:
        """Return ``True`` if the version has a row without the deleted override."""
        normalized_type = _normalize_model_type(model_type)
        normalized_version_id = _normalize_int(version_id)
        if normalized_type is None or normalized_version_id is None:
            return False

        async with self._lock:
            rows = self._fetch_all(
                """
                SELECT is_deleted_override
                FROM downloaded_model_versions
                WHERE model_type = ? AND version_id = ?
                """,
                (normalized_type, normalized_version_id),
            )
        # (model_type, version_id) is the primary key: at most one row.
        return bool(rows) and not bool(rows[0]["is_deleted_override"])

    async def get_downloaded_version_ids(
        self, model_type: str, model_id: int
    ) -> list[int]:
        """Return the downloaded version ids of one model, in ascending order."""
        normalized_type = _normalize_model_type(model_type)
        normalized_model_id = _normalize_int(model_id)
        if normalized_type is None or normalized_model_id is None:
            return []

        async with self._lock:
            rows = self._fetch_all(
                """
                SELECT version_id
                FROM downloaded_model_versions
                WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0
                ORDER BY version_id ASC
                """,
                (normalized_type, normalized_model_id),
            )
        return [int(row["version_id"]) for row in rows]

    async def get_downloaded_version_ids_bulk(
        self, model_type: str, model_ids: Iterable[int]
    ) -> dict[int, set[int]]:
        """Return downloaded version ids grouped by model id.

        Model ids with no downloaded version are absent from the result.
        Ids that cannot be converted to integers are ignored.
        """
        normalized_type = _normalize_model_type(model_type)
        if normalized_type is None:
            return {}

        normalized_model_ids = sorted(
            {
                value
                for value in (_normalize_int(model_id) for model_id in model_ids)
                if value is not None
            }
        )
        if not normalized_model_ids:
            return {}

        result: dict[int, set[int]] = {}
        step = self._MAX_IDS_PER_QUERY
        async with self._lock:
            # Query in chunks so an arbitrarily large id list never exceeds
            # SQLite's limit on bound parameters in a single statement.
            for start in range(0, len(normalized_model_ids), step):
                chunk = normalized_model_ids[start:start + step]
                placeholders = ", ".join(["?"] * len(chunk))
                rows = self._fetch_all(
                    f"""
                    SELECT model_id, version_id
                    FROM downloaded_model_versions
                    WHERE model_type = ?
                    AND model_id IN ({placeholders})
                    AND is_deleted_override = 0
                    """,
                    [normalized_type, *chunk],
                )
                for row in rows:
                    model_id = _normalize_int(row["model_id"])
                    version_id = _normalize_int(row["version_id"])
                    if model_id is None or version_id is None:
                        continue
                    result.setdefault(model_id, set()).add(version_id)
        return result

View File

@@ -411,6 +411,7 @@ class ModelScanner:
if scan_result:
await self._apply_scan_result(scan_result)
await self._save_persistent_cache(scan_result)
await self._sync_download_history(scan_result.raw_data, source='scan')
# Send final progress update
await ws_manager.broadcast_init_progress({
@@ -516,6 +517,7 @@ class ModelScanner:
)
await self._apply_scan_result(scan_result)
await self._sync_download_history(adjusted_raw_data, source='scan')
await ws_manager.broadcast_init_progress({
'stage': 'loading_cache',
@@ -576,6 +578,7 @@ class ModelScanner:
excluded_models=list(self._excluded_models)
)
await self._save_persistent_cache(snapshot)
await self._sync_download_history(snapshot.raw_data, source='scan')
def _count_model_files(self) -> int:
"""Count all model files with supported extensions in all roots
@@ -704,6 +707,7 @@ class ModelScanner:
scan_result = await self._gather_model_data()
await self._apply_scan_result(scan_result)
await self._save_persistent_cache(scan_result)
await self._sync_download_history(scan_result.raw_data, source='scan')
logger.info(
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
@@ -1101,6 +1105,49 @@ class ModelScanner:
await self._cache.resort()
async def _sync_download_history(
    self,
    raw_data: List[Mapping[str, Any]],
    *,
    source: str,
) -> None:
    """Push the Civitai versions found in *raw_data* to the download history.

    Only mapping items whose ``'civitai'`` entry is itself a mapping with a
    non-empty ``'id'`` are forwarded. Nothing happens when no item
    qualifies. Failures while reaching or calling the history service are
    logged at debug level and not re-raised.

    Args:
        raw_data: Scanned model entries; ``None`` is treated as empty.
        source: Label stored with each history record.
    """
    records: List[Dict[str, Any]] = []
    for entry in raw_data or []:
        civitai_meta = entry.get('civitai') if isinstance(entry, Mapping) else None
        if not isinstance(civitai_meta, Mapping):
            continue
        civitai_version = civitai_meta.get('id')
        if civitai_version in (None, ''):
            continue
        record = {
            'version_id': civitai_version,
            'model_id': civitai_meta.get('modelId'),
            'file_path': entry.get('file_path'),
        }
        records.append(record)

    if not records:
        return

    try:
        history = await ServiceRegistry.get_downloaded_version_history_service()
        await history.mark_downloaded_bulk(self.model_type, records, source=source)
    except Exception as exc:
        logger.debug(
            "%s Scanner: Failed to sync download history: %s",
            self.model_type.capitalize(),
            exc,
        )
async def _gather_model_data(
self,
*,

View File

@@ -167,6 +167,28 @@ class ServiceRegistry:
logger.debug(f"Created and registered {service_name}")
return service
@classmethod
async def get_downloaded_version_history_service(cls):
    """Return the registered downloaded-version history service.

    The service is created and registered on first use. The registry is
    checked once without the lock and again after acquiring the per-service
    lock, so an already registered instance is returned instead of a new one
    being built.
    """
    service_name = "downloaded_version_history_service"

    if service_name not in cls._services:
        async with cls._get_lock(service_name):
            # Another task may have registered the service while this one
            # was waiting for the lock.
            if service_name not in cls._services:
                from .downloaded_version_history_service import (
                    DownloadedVersionHistoryService,
                )

                cls._services[service_name] = DownloadedVersionHistoryService()
                logger.debug(f"Created and registered {service_name}")

    return cls._services[service_name]
@classmethod
async def get_civarchive_client(cls):
"""Get or create CivArchive client instance"""
@@ -255,4 +277,4 @@ class ServiceRegistry:
"""Clear all registered services - mainly for testing"""
cls._services.clear()
cls._locks.clear()
logger.info("Cleared all registered services")
logger.info("Cleared all registered services")