diff --git a/py/lora_manager.py b/py/lora_manager.py index e57fc4a9..1f66ddb4 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -189,6 +189,10 @@ class LoraManager: # Register DownloadManager with ServiceRegistry await ServiceRegistry.get_download_manager() + + # Initialize DownloadQueueService for persistent queue/history + await ServiceRegistry.get_download_queue_service() + await ServiceRegistry.get_backup_service() from .services.metadata_service import initialize_metadata_providers diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 545da08b..1340ee07 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -37,6 +37,7 @@ from ...services.use_cases import ( ) from ...services.websocket_manager import WebSocketManager from ...services.websocket_progress_callback import WebSocketProgressCallback +from ...services.download_queue_service import DownloadQueueService from ...services.errors import RateLimitError, ResourceNotFoundError from ...utils.civitai_utils import resolve_license_payload from ...utils.file_utils import calculate_sha256 @@ -1567,6 +1568,255 @@ class ModelDownloadHandler: ) return web.json_response({"success": False, "error": str(exc)}, status=500) + # ------------------------------------------------------------------ + # Download queue / history handlers + # ------------------------------------------------------------------ + + async def get_download_queue(self, request: web.Request) -> web.Response: + try: + service = await DownloadQueueService.get_instance() + queue = await service.get_queue() + stats = await service.get_stats() + return web.json_response({"success": True, "queue": queue, "stats": stats}) + except Exception as exc: + self._logger.error( + "Error getting download queue: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def add_to_download_queue(self, request: web.Request) -> web.Response: + try: + import uuid + + download_id = request.query.get("download_id") or str(uuid.uuid4()) + model_id_str = request.query.get("model_id") + model_version_id_str = request.query.get("model_version_id") + model_name = request.query.get("model_name", "") + version_name = request.query.get("version_name", "") + thumbnail_url = request.query.get("thumbnail_url", "") + source = request.query.get("source") + file_params_json = request.query.get("file_params") + + model_id = int(model_id_str) if model_id_str else None + model_version_id = int(model_version_id_str) if model_version_id_str else None + file_params = json.loads(file_params_json) if file_params_json else None + + service = await DownloadQueueService.get_instance() + item = await service.add_to_queue( + download_id=download_id, + model_id=model_id, + model_version_id=model_version_id, + model_name=model_name, + version_name=version_name, + thumbnail_url=thumbnail_url, + source=source, + file_params=file_params, + ) + return web.json_response({"success": True, "item": item}) + except Exception as exc: + self._logger.error( + "Error adding to download queue: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def remove_from_download_queue(self, request: web.Request) -> web.Response: + try: + download_id = request.query.get("download_id") + if not download_id: + return web.json_response( + {"success": False, "error": "download_id is required"}, status=400 + ) + + service = await DownloadQueueService.get_instance() + removed = await service.remove_from_queue(download_id) + return web.json_response({"success": removed}) + except Exception as exc: + self._logger.error( + "Error removing from download queue: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def move_queue_item_to_top(self, request: web.Request) -> web.Response: + try: + download_id = request.query.get("download_id") + if not download_id: + return web.json_response( + {"success": False, "error": "download_id is required"}, status=400 + ) + + service = await DownloadQueueService.get_instance() + moved = await service.move_to_top(download_id) + return web.json_response({"success": moved}) + except Exception as exc: + self._logger.error( + "Error moving queue item to top: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def move_queue_item_to_end(self, request: web.Request) -> web.Response: + try: + download_id = request.query.get("download_id") + if not download_id: + return web.json_response( + {"success": False, "error": "download_id is required"}, status=400 + ) + + service = await DownloadQueueService.get_instance() + moved = await service.move_to_end(download_id) + return web.json_response({"success": moved}) + except Exception as exc: + self._logger.error( + "Error moving queue item to end: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def clear_download_queue(self, request: web.Request) -> web.Response: + try: + status_filter = request.query.get("status") or None + service = await DownloadQueueService.get_instance() + cleared = await service.clear_queue(status_filter=status_filter) + return web.json_response({"success": True, "cleared": cleared}) + except Exception as exc: + self._logger.error( + "Error clearing download queue: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_download_history(self, request: web.Request) -> web.Response: + try: + limit = min(int(request.query.get("limit", "50")), 500) + offset = int(request.query.get("offset", "0")) + status_filter = request.query.get("status") or None + service = await DownloadQueueService.get_instance() + result = await service.get_history( + limit=limit, offset=offset, status_filter=status_filter + ) + return web.json_response( + { + "success": True, + "items": result["items"], + "total": result["total"], + "limit": result["limit"], + "offset": result["offset"], + } + ) + except Exception as exc: + self._logger.error( + "Error getting download history: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def clear_download_history(self, request: web.Request) -> web.Response: + try: + status_filter = request.query.get("status") or None + service = await DownloadQueueService.get_instance() + cleared = await service.clear_history(status_filter=status_filter) + return web.json_response({"success": True, "cleared": cleared}) + except Exception as exc: + self._logger.error( + "Error clearing download history: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def delete_download_history_item(self, request: web.Request) -> web.Response: + try: + item_id = int(request.query.get("id", "0")) + if not item_id: + return web.json_response( + {"success": False, "error": "id is required"}, status=400 + ) + + service = await DownloadQueueService.get_instance() + deleted = await service.delete_history_item(item_id) + return web.json_response({"success": deleted}) + except Exception as exc: + self._logger.error( + "Error deleting download history item: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def retry_download_from_history(self, request: web.Request) -> web.Response: + try: + item_id = int(request.query.get("id", "0")) + if not item_id: + return web.json_response( + {"success": False, "error": "id is required"}, status=400 + ) + + service = await DownloadQueueService.get_instance() + item = await service.retry_from_history(item_id) + if item is None: + return web.json_response( + {"success": False, "error": "History item not found or not retryable"}, + status=404, + ) + return web.json_response({"success": True, "item": item}) + except Exception as exc: + self._logger.error( + "Error retrying download from history: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def retry_all_failed_downloads(self, request: web.Request) -> web.Response: + try: + service = await DownloadQueueService.get_instance() + retry_count = await service.retry_all_failed() + return web.json_response({"success": True, "retry_count": retry_count}) + except Exception as exc: + self._logger.error( + "Error retrying all failed downloads: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def complete_download_in_queue(self, request: web.Request) -> web.Response: + """Atomically move a download from queue to history with terminal status.""" + try: + download_id = request.query.get("download_id") + if not download_id: + return web.json_response( + {"success": False, "error": "download_id is required"}, status=400 + ) + status = request.query.get("status", "completed") + error = request.query.get("error") + file_path = request.query.get("file_path") + try: + bytes_downloaded = int(request.query.get("bytes_downloaded", "0")) + except (TypeError, ValueError): + bytes_downloaded = 0 + total_bytes_raw = request.query.get("total_bytes") + total_bytes = int(total_bytes_raw) if total_bytes_raw else None + + service = await DownloadQueueService.get_instance() + item = await service.complete_download( + download_id=download_id, + status=status, + error=error, + file_path=file_path, + bytes_downloaded=bytes_downloaded, + total_bytes=total_bytes, + ) + if item is None: + return web.json_response( + {"success": False, "error": "Download not found in queue"}, status=404 + ) + return web.json_response({"success": True, "item": item}) + except Exception as exc: + self._logger.error( + "Error completing download: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_download_stats(self, request: web.Request) -> web.Response: + try: + service = await DownloadQueueService.get_instance() + stats = await service.get_stats() + return web.json_response({"success": True, "stats": stats}) + except Exception as exc: + self._logger.error( + "Error getting download stats: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + class ModelCivitaiHandler: """CivitAI integration endpoints.""" @@ -2596,6 +2846,19 @@ class ModelHandlerSet: "pause_download_get": self.download.pause_download_get, "resume_download_get": self.download.resume_download_get, "get_download_progress": self.download.get_download_progress, + "get_download_queue": self.download.get_download_queue, + "add_to_download_queue": self.download.add_to_download_queue, + "remove_from_download_queue": self.download.remove_from_download_queue, + "move_queue_item_to_top": self.download.move_queue_item_to_top, + "move_queue_item_to_end": self.download.move_queue_item_to_end, + "clear_download_queue": self.download.clear_download_queue, + "get_download_history": self.download.get_download_history, + "clear_download_history": self.download.clear_download_history, + "delete_download_history_item": self.download.delete_download_history_item, + "retry_download_from_history": self.download.retry_download_from_history, + "retry_all_failed_downloads": self.download.retry_all_failed_downloads, + "complete_download_in_queue": self.download.complete_download_in_queue, + "get_download_stats": self.download.get_download_stats, "get_civitai_versions": self.civitai.get_civitai_versions, "get_civitai_model_by_version": self.civitai.get_civitai_model_by_version, "get_civitai_model_by_hash": self.civitai.get_civitai_model_by_hash, diff --git a/py/routes/model_route_registrar.py b/py/routes/model_route_registrar.py index 527e1f75..c7c98f46 100644 --- a/py/routes/model_route_registrar.py +++ b/py/routes/model_route_registrar.py @@ -107,6 +107,37 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition( "GET", "/api/lm/download-progress/{download_id}", "get_download_progress" ), + RouteDefinition("GET", "/api/lm/downloads/queue", "get_download_queue"), + RouteDefinition("GET", "/api/lm/downloads/queue/add", "add_to_download_queue"), + RouteDefinition( + "GET", "/api/lm/downloads/queue/remove", "remove_from_download_queue" + ), + RouteDefinition( + "GET", "/api/lm/downloads/queue/move-to-top", "move_queue_item_to_top" + ), + RouteDefinition( + "GET", "/api/lm/downloads/queue/move-to-end", "move_queue_item_to_end" + ), + RouteDefinition( + "GET", "/api/lm/downloads/queue/clear", "clear_download_queue" + ), + RouteDefinition("GET", "/api/lm/downloads/history", "get_download_history"), + RouteDefinition( + "GET", "/api/lm/downloads/history/clear", "clear_download_history" + ), + RouteDefinition( + "GET", "/api/lm/downloads/history/delete", "delete_download_history_item" + ), + RouteDefinition( + "GET", "/api/lm/downloads/history/retry", "retry_download_from_history" + ), + RouteDefinition( + "GET", "/api/lm/downloads/history/retry-all", "retry_all_failed_downloads" + ), + RouteDefinition("GET", "/api/lm/downloads/stats", "get_download_stats"), + RouteDefinition( + "GET", "/api/lm/downloads/queue/complete", "complete_download_in_queue" + ), RouteDefinition("POST", "/api/lm/{prefix}/cancel-task", "cancel_task"), RouteDefinition("GET", "/{prefix}", "handle_models_page"), ) diff --git a/py/services/download_queue_service.py b/py/services/download_queue_service.py new file mode 100644 index 00000000..243bf295 --- /dev/null +++ b/py/services/download_queue_service.py @@ -0,0 +1,730 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import os +import sqlite3 +import time +from typing import Any, Optional + +from ..utils.cache_paths import get_cache_base_dir + +logger = logging.getLogger(__name__) + + +def _resolve_database_path() -> str: + base_dir = get_cache_base_dir(create=True) + history_dir = os.path.join(base_dir, "download_history") + os.makedirs(history_dir, exist_ok=True) + return os.path.join(history_dir, "download_queue.sqlite") + + +class DownloadQueueService: + """Persistent download queue and history manager backed by SQLite. + + Provides a singleton interface for managing a download queue and + corresponding history table, both stored in a single SQLite database + under the cache directory. + """ + + _instance: Optional[DownloadQueueService] = None + _class_lock: asyncio.Lock = asyncio.Lock() + + _SCHEMA = """ + CREATE TABLE IF NOT EXISTS download_queue ( + download_id TEXT PRIMARY KEY, + model_id INTEGER, + model_version_id INTEGER, + model_name TEXT NOT NULL DEFAULT '', + version_name TEXT DEFAULT '', + thumbnail_url TEXT DEFAULT '', + source TEXT, + file_params TEXT, + status TEXT NOT NULL DEFAULT 'queued', + priority INTEGER DEFAULT 0, + progress INTEGER DEFAULT 0, + bytes_downloaded INTEGER DEFAULT 0, + total_bytes INTEGER, + bytes_per_second REAL DEFAULT 0.0, + error TEXT, + file_path TEXT, + added_at REAL NOT NULL, + started_at REAL, + completed_at REAL + ); + CREATE INDEX IF NOT EXISTS idx_dq_status ON download_queue(status); + CREATE INDEX IF NOT EXISTS idx_dq_added ON download_queue(added_at); + + CREATE TABLE IF NOT EXISTS download_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + download_id TEXT, + model_id INTEGER, + model_version_id INTEGER, + model_name TEXT NOT NULL DEFAULT '', + version_name TEXT DEFAULT '', + thumbnail_url TEXT DEFAULT '', + status TEXT NOT NULL, + error TEXT, + file_path TEXT, + bytes_downloaded INTEGER DEFAULT 0, + total_bytes INTEGER, + completed_at REAL NOT NULL, + is_already_exists INTEGER DEFAULT 0 + ); + CREATE INDEX IF NOT EXISTS idx_dh_completed ON download_history(completed_at DESC); + CREATE INDEX IF NOT EXISTS idx_dh_status ON download_history(status); + """ + + @classmethod + async def get_instance(cls) -> DownloadQueueService: + """Return the singleton instance, creating it if necessary.""" + async with cls._class_lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self, db_path: Optional[str] = None) -> None: + self._db_path = db_path or _resolve_database_path() + self._lock = asyncio.Lock() + self._conn: Optional[sqlite3.Connection] = None + self._schema_initialized = False + self._ensure_directory() + self._initialize_schema() + + def _ensure_directory(self) -> None: + directory = os.path.dirname(self._db_path) + if directory: + os.makedirs(directory, exist_ok=True) + + def _connect(self) -> sqlite3.Connection: + conn = sqlite3.connect(self._db_path, check_same_thread=False) + conn.row_factory = sqlite3.Row + return conn + + def _get_conn(self) -> sqlite3.Connection: + if self._conn is None: + self._conn = sqlite3.connect(self._db_path, check_same_thread=False) + self._conn.row_factory = sqlite3.Row + return self._conn + + def _initialize_schema(self) -> None: + if self._schema_initialized: + return + with self._connect() as conn: + conn.executescript(self._SCHEMA) + conn.commit() + self._schema_initialized = True + + def get_database_path(self) -> str: + """Return the resolved database file path.""" + return self._db_path + + def close(self) -> None: + """Close the persistent SQLite connection, if open. + + This is called before plugin update operations to release the + database file lock on Windows, allowing ``shutil.rmtree()`` to + succeed when the cache resides inside the plugin directory. + """ + if self._conn is not None: + try: + self._conn.close() + except Exception: + pass + finally: + self._conn = None + + # ------------------------------------------------------------------ + # Queue methods + # ------------------------------------------------------------------ + + async def add_to_queue( + self, + download_id: str, + model_id: Optional[int] = None, + model_version_id: Optional[int] = None, + model_name: str = "", + version_name: str = "", + thumbnail_url: str = "", + source: Optional[str] = None, + file_params: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + """Insert a new download into the queue. + + Returns the inserted row as a dict (or an empty dict if the + download_id already exists). + """ + now = time.time() + file_params_json = json.dumps(file_params) if file_params is not None else None + + async with self._lock: + conn = self._get_conn() + conn.execute( + """ + INSERT OR IGNORE INTO download_queue ( + download_id, model_id, model_version_id, model_name, + version_name, thumbnail_url, source, file_params, + status, priority, added_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'queued', 0, ?) + """, + ( + download_id, + model_id, + model_version_id, + model_name, + version_name, + thumbnail_url, + source, + file_params_json, + now, + ), + ) + conn.commit() + row = conn.execute( + "SELECT * FROM download_queue WHERE download_id = ?", + (download_id,), + ).fetchone() + + return dict(row) if row else {} + + async def get_queue(self) -> list[dict[str, Any]]: + """Return all items in the queue ordered by priority then added time.""" + async with self._lock: + conn = self._get_conn() + rows = conn.execute( + "SELECT * FROM download_queue ORDER BY priority DESC, added_at ASC" + ).fetchall() + return [dict(row) for row in rows] + + async def get_queued_count(self) -> int: + """Return the number of items with status ``'queued'``.""" + async with self._lock: + conn = self._get_conn() + row = conn.execute( + "SELECT COUNT(*) AS cnt FROM download_queue WHERE status = 'queued'" + ).fetchone() + return row["cnt"] if row else 0 + + async def update_status( + self, + download_id: str, + status: str, + **extra: Any, + ) -> bool: + """Update the status and/or extra fields of a queue item. + + Accepted extra keyword arguments: + ``progress``, ``error``, ``file_path``, ``bytes_downloaded``, + ``total_bytes``, ``bytes_per_second``. + + Returns ``True`` if a row was updated. + """ + allowed_extra = { + "progress", + "error", + "file_path", + "bytes_downloaded", + "total_bytes", + "bytes_per_second", + } + + set_clauses: list[str] = ["status = ?"] + params: list[Any] = [status] + now = time.time() + + if status in ("downloading",): + set_clauses.append("started_at = COALESCE(started_at, ?)") + params.append(now) + if status in ("completed", "failed", "canceled"): + set_clauses.append("completed_at = ?") + params.append(now) + + for key, value in extra.items(): + if key in allowed_extra: + set_clauses.append(f"{key} = ?") + params.append(value) + + params.append(download_id) + + async with self._lock: + conn = self._get_conn() + cursor = conn.execute( + f"UPDATE download_queue SET {', '.join(set_clauses)} " + "WHERE download_id = ?", + params, + ) + conn.commit() + return cursor.rowcount > 0 + + async def remove_from_queue(self, download_id: str) -> bool: + """Remove a single item from the queue by download_id. + + Returns ``True`` if a row was deleted. + """ + async with self._lock: + conn = self._get_conn() + cursor = conn.execute( + "DELETE FROM download_queue WHERE download_id = ?", + (download_id,), + ) + conn.commit() + return cursor.rowcount > 0 + + async def move_to_top(self, download_id: str) -> bool: + """Move an item to the front of the queue (highest priority). + + Returns ``True`` if the item was found and updated. + """ + async with self._lock: + conn = self._get_conn() + row = conn.execute( + "SELECT priority FROM download_queue WHERE download_id = ?", + (download_id,), + ).fetchone() + if row is None: + return False + + max_row = conn.execute( + "SELECT MAX(priority) AS mx FROM download_queue" + ).fetchone() + max_priority: int = max_row["mx"] if max_row["mx"] is not None else 0 + + conn.execute( + "UPDATE download_queue SET priority = ? WHERE download_id = ?", + (max_priority + 1, download_id), + ) + conn.commit() + return True + + async def move_to_end(self, download_id: str) -> bool: + """Move an item to the end of the queue (lowest priority). + + Returns ``True`` if the item was found and updated. + """ + async with self._lock: + conn = self._get_conn() + row = conn.execute( + "SELECT priority FROM download_queue WHERE download_id = ?", + (download_id,), + ).fetchone() + if row is None: + return False + + min_row = conn.execute( + "SELECT MIN(priority) AS mn FROM download_queue" + ).fetchone() + min_priority: int = min_row["mn"] if min_row["mn"] is not None else 0 + + conn.execute( + "UPDATE download_queue SET priority = ? WHERE download_id = ?", + (min_priority - 1, download_id), + ) + conn.commit() + return True + + async def clear_queue(self, status_filter: Optional[str] = None) -> int: + """Remove items from the queue. + + When *status_filter* is provided only items with that status are + deleted. Returns the number of deleted rows. + """ + async with self._lock: + conn = self._get_conn() + if status_filter is not None: + cursor = conn.execute( + "DELETE FROM download_queue WHERE status = ?", + (status_filter,), + ) + else: + cursor = conn.execute("DELETE FROM download_queue") + conn.commit() + return cursor.rowcount + + async def complete_download( + self, + download_id: str, + status: str = "completed", + error: Optional[str] = None, + file_path: Optional[str] = None, + bytes_downloaded: int = 0, + total_bytes: Optional[int] = None, + ) -> Optional[dict[str, Any]]: + """Atomically move a download from the queue into the history table. + + Looks up the queue record by ``download_id``, deletes it from the + queue, and inserts a corresponding history entry with the given + terminal status (``completed``, ``failed``, or ``canceled``). + + Returns the original queue record (before deletion) on success, + or ``None`` if the download was not found in the queue. + """ + async with self._lock: + conn = self._get_conn() + row = conn.execute( + "SELECT * FROM download_queue WHERE download_id = ?", + (download_id,), + ).fetchone() + if row is None: + return None + + now = time.time() + conn.execute( + "DELETE FROM download_queue WHERE download_id = ?", + (download_id,), + ) + conn.execute( + """ + INSERT INTO download_history ( + download_id, model_id, model_version_id, model_name, + version_name, thumbnail_url, status, error, file_path, + bytes_downloaded, total_bytes, completed_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + row["download_id"], + row["model_id"], + row["model_version_id"], + row["model_name"], + row["version_name"], + row["thumbnail_url"], + status, + error, + file_path, + bytes_downloaded, + total_bytes, + now, + ), + ) + conn.commit() + return dict(row) + + async def pop_next_download(self) -> Optional[dict[str, Any]]: + """Atomically fetch and mark the next queued item as ``downloading``. + + The item with the highest priority (and earliest ``added_at`` + among ties) whose status is ``'queued'`` is selected, set to + ``'downloading'``, and returned as a dict. Returns ``None`` if + the queue is empty. + """ + async with self._lock: + conn = self._get_conn() + row = conn.execute( + """ + SELECT * FROM download_queue + WHERE status = 'queued' + ORDER BY priority DESC, added_at ASC + LIMIT 1 + """ + ).fetchone() + if row is None: + return None + + download_id = row["download_id"] + now = time.time() + conn.execute( + "UPDATE download_queue SET status = 'downloading', " + "started_at = COALESCE(started_at, ?) " + "WHERE download_id = ?", + (now, download_id), + ) + conn.commit() + updated = conn.execute( + "SELECT * FROM download_queue WHERE download_id = ?", + (download_id,), + ).fetchone() + + return dict(updated) if updated else None + + # ------------------------------------------------------------------ + # History methods + # ------------------------------------------------------------------ + + async def add_to_history( + self, + download_id: Optional[str] = None, + model_id: Optional[int] = None, + model_version_id: Optional[int] = None, + model_name: str = "", + version_name: str = "", + thumbnail_url: str = "", + status: str = "completed", + error: Optional[str] = None, + file_path: Optional[str] = None, + bytes_downloaded: int = 0, + total_bytes: Optional[int] = None, + is_already_exists: int = 0, + ) -> int: + """Insert a record into the download history. + + Returns the ``id`` (AUTOINCREMENT primary key) of the newly + inserted row. + """ + now = time.time() + + async with self._lock: + conn = self._get_conn() + cursor = conn.execute( + """ + INSERT INTO download_history ( + download_id, model_id, model_version_id, model_name, + version_name, thumbnail_url, status, error, file_path, + bytes_downloaded, total_bytes, completed_at, is_already_exists + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + download_id, + model_id, + model_version_id, + model_name, + version_name, + thumbnail_url, + status, + error, + file_path, + bytes_downloaded, + total_bytes, + now, + is_already_exists, + ), + ) + conn.commit() + return cursor.lastrowid or 0 + + async def get_history( + self, + limit: int = 50, + offset: int = 0, + status_filter: Optional[str] = None, + ) -> dict[str, Any]: + """Return a page of download history entries. + + Returns a dict with keys ``items``, ``total``, ``limit``, and + ``offset``. + """ + async with self._lock: + conn = self._get_conn() + + if status_filter is not None: + count_row = conn.execute( + "SELECT COUNT(*) AS cnt FROM download_history WHERE status = ?", + (status_filter,), + ).fetchone() + rows = conn.execute( + "SELECT * FROM download_history WHERE status = ? " + "ORDER BY completed_at DESC LIMIT ? OFFSET ?", + (status_filter, limit, offset), + ).fetchall() + else: + count_row = conn.execute( + "SELECT COUNT(*) AS cnt FROM download_history" + ).fetchone() + rows = conn.execute( + "SELECT * FROM download_history " + "ORDER BY completed_at DESC LIMIT ? OFFSET ?", + (limit, offset), + ).fetchall() + + return { + "items": [dict(row) for row in rows], + "total": count_row["cnt"] if count_row else 0, + "limit": limit, + "offset": offset, + } + + async def delete_history_item(self, id: int) -> bool: + """Delete a single history entry by its *id*. + + Returns ``True`` if a row was deleted. + """ + async with self._lock: + conn = self._get_conn() + cursor = conn.execute( + "DELETE FROM download_history WHERE id = ?", + (id,), + ) + conn.commit() + return cursor.rowcount > 0 + + async def clear_history( + self, + status_filter: Optional[str] = None, + before_timestamp: Optional[float] = None, + ) -> int: + """Remove history entries matching the optional filters. + + Both ``status_filter`` and ``before_timestamp`` can be combined + (AND logic). Returns the number of deleted rows. + """ + async with self._lock: + conn = self._get_conn() + + clauses: list[str] = [] + params: list[Any] = [] + + if status_filter is not None: + clauses.append("status = ?") + params.append(status_filter) + if before_timestamp is not None: + clauses.append("completed_at < ?") + params.append(before_timestamp) + + where = "" + if clauses: + where = " WHERE " + " AND ".join(clauses) + + cursor = conn.execute( + f"DELETE FROM download_history{where}", + params, + ) + conn.commit() + return cursor.rowcount + + async def get_history_count(self, status_filter: Optional[str] = None) -> int: + """Return the number of history entries, optionally filtered by status.""" + async with self._lock: + conn = self._get_conn() + if status_filter is not None: + row = conn.execute( + "SELECT COUNT(*) AS cnt FROM download_history WHERE status = ?", + (status_filter,), + ).fetchone() + else: + row = conn.execute( + "SELECT COUNT(*) AS cnt FROM download_history" + ).fetchone() + return row["cnt"] if row else 0 + + # ------------------------------------------------------------------ + # Retry + # ------------------------------------------------------------------ + + async def retry_from_history(self, item_id: int) -> Optional[dict[str, Any]]: + """Re-queue a failed or canceled download from history. + + Looks up the history record by its primary key. If the status is + ``failed`` or ``canceled`` a new queue entry is created with the + same model metadata and a fresh download id. + """ + async with self._lock: + conn = self._get_conn() + row = conn.execute( + "SELECT * FROM download_history WHERE id = ?", + (item_id,), + ).fetchone() + if row is None: + return None + status = str(row["status"]) + if status not in ("failed", "canceled"): + return None + + import uuid + + new_id = str(uuid.uuid4()) + now = time.time() + conn.execute( + """ + INSERT INTO download_queue ( + download_id, model_id, model_version_id, model_name, + version_name, thumbnail_url, source, file_params, + status, priority, added_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, NULL, 'queued', 0, ?) + """, + ( + new_id, + row["model_id"], + row["model_version_id"], + row["model_name"], + row["version_name"], + row["thumbnail_url"], + "retry", + now, + ), + ) + conn.commit() + queued = conn.execute( + "SELECT * FROM download_queue WHERE download_id = ?", + (new_id,), + ).fetchone() + + return dict(queued) if queued else None + + async def retry_all_failed(self) -> int: + """Re-queue all failed and canceled downloads from history. + + Returns the number of items that were re-queued. + """ + async with self._lock: + conn = self._get_conn() + rows = conn.execute( + "SELECT * FROM download_history WHERE status IN ('failed', 'canceled')" + ).fetchall() + if not rows: + return 0 + + import uuid + + now = time.time() + count = 0 + for row in rows: + new_id = str(uuid.uuid4()) + conn.execute( + """ + INSERT INTO download_queue ( + download_id, model_id, model_version_id, model_name, + version_name, thumbnail_url, source, file_params, + status, priority, added_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, NULL, 'queued', 0, ?) + """, + ( + new_id, + row["model_id"], + row["model_version_id"], + row["model_name"], + row["version_name"], + row["thumbnail_url"], + "retry", + now, + ), + ) + count += 1 + conn.commit() + + return count + + # ------------------------------------------------------------------ + # Stats + # ------------------------------------------------------------------ + + async def get_stats(self) -> dict[str, int]: + """Return aggregate counts across both tables. + + Returns a dict with keys ``queued``, ``downloading``, ``paused`` + (all from the queue table) and ``completed``, ``failed``, + ``canceled`` (all from the history table). + """ + async with self._lock: + conn = self._get_conn() + + queue_rows = conn.execute( + "SELECT status, COUNT(*) AS cnt FROM download_queue GROUP BY status" + ).fetchall() + queue_stats: dict[str, int] = {} + for row in queue_rows: + queue_stats[str(row["status"])] = row["cnt"] + + history_rows = conn.execute( + "SELECT status, COUNT(*) AS cnt FROM download_history GROUP BY status" + ).fetchall() + history_stats: dict[str, int] = {} + for row in history_rows: + history_stats[str(row["status"])] = row["cnt"] + + return { + "queued": queue_stats.get("queued", 0), + "downloading": queue_stats.get("downloading", 0), + "paused": queue_stats.get("paused", 0), + "completed": history_stats.get("completed", 0), + "failed": history_stats.get("failed", 0), + "canceled": history_stats.get("canceled", 0), + } diff --git a/py/services/service_registry.py b/py/services/service_registry.py index 7d16c905..162579c7 100644 --- a/py/services/service_registry.py +++ b/py/services/service_registry.py @@ -188,6 +188,25 @@ class ServiceRegistry: logger.debug(f"Created and registered {service_name}") return service + @classmethod + async def get_download_queue_service(cls): + """Get or create the download queue service.""" + service_name = "download_queue_service" + + if service_name in cls._services: + return cls._services[service_name] + + async with cls._get_lock(service_name): + if service_name in cls._services: + return cls._services[service_name] + + from .download_queue_service import DownloadQueueService + + service = await DownloadQueueService.get_instance() + cls._services[service_name] = service + logger.debug(f"Created and registered {service_name}") + return service + @classmethod async def get_backup_service(cls): """Get or create the backup service."""