diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py index acfd1d15..dbc35924 100644 --- a/py/routes/handlers/misc_handlers.py +++ b/py/routes/handlers/misc_handlers.py @@ -1791,29 +1791,33 @@ class ModelLibraryHandler: exists = True model_type = "embedding" + if exists: + return web.json_response( + { + "success": True, + "exists": True, + "modelType": model_type, + "hasBeenDownloaded": False, + } + ) + history_service = await self._get_download_history_service() has_been_downloaded = False - history_type = model_type - if history_type: - has_been_downloaded = await history_service.has_been_downloaded( - history_type, + history_type = None + for candidate_type in ("lora", "checkpoint", "embedding"): + if await history_service.has_been_downloaded( + candidate_type, model_version_id, - ) - else: - for candidate_type in ("lora", "checkpoint", "embedding"): - if await history_service.has_been_downloaded( - candidate_type, - model_version_id, - ): - has_been_downloaded = True - history_type = candidate_type - break + ): + has_been_downloaded = True + history_type = candidate_type + break return web.json_response( { "success": True, - "exists": exists, - "modelType": model_type if exists else history_type, + "exists": False, + "modelType": history_type, "hasBeenDownloaded": has_been_downloaded, } ) @@ -1833,40 +1837,46 @@ class ModelLibraryHandler: model_type = None versions = [] downloaded_version_ids = [] - history_service = await self._get_download_history_service() if lora_versions: - model_type = "lora" - versions = self._with_downloaded_flag(lora_versions) - downloaded_version_ids = await history_service.get_downloaded_version_ids( - model_type, - model_id, + return web.json_response( + { + "success": True, + "modelType": "lora", + "versions": self._with_downloaded_flag(lora_versions), + "downloadedVersionIds": [], + } ) - elif checkpoint_versions: - model_type = "checkpoint" - versions = self._with_downloaded_flag(checkpoint_versions) - downloaded_version_ids = await history_service.get_downloaded_version_ids( - model_type, - model_id, + if checkpoint_versions: + return web.json_response( + { + "success": True, + "modelType": "checkpoint", + "versions": self._with_downloaded_flag(checkpoint_versions), + "downloadedVersionIds": [], + } ) - elif embedding_versions: - model_type = "embedding" - versions = self._with_downloaded_flag(embedding_versions) - downloaded_version_ids = await history_service.get_downloaded_version_ids( - model_type, - model_id, + if embedding_versions: + return web.json_response( + { + "success": True, + "modelType": "embedding", + "versions": self._with_downloaded_flag(embedding_versions), + "downloadedVersionIds": [], + } ) - else: - for candidate_type in ("lora", "checkpoint", "embedding"): - candidate_downloaded_version_ids = ( - await history_service.get_downloaded_version_ids( - candidate_type, - model_id, - ) + + history_service = await self._get_download_history_service() + for candidate_type in ("lora", "checkpoint", "embedding"): + candidate_downloaded_version_ids = ( + await history_service.get_downloaded_version_ids( + candidate_type, + model_id, ) - if candidate_downloaded_version_ids: - model_type = candidate_type - downloaded_version_ids = candidate_downloaded_version_ids - break + ) + if candidate_downloaded_version_ids: + model_type = candidate_type + downloaded_version_ids = candidate_downloaded_version_ids + break return web.json_response( { diff --git a/py/services/downloaded_version_history_service.py b/py/services/downloaded_version_history_service.py index dc1d3cff..032a494e 100644 --- a/py/services/downloaded_version_history_service.py +++ b/py/services/downloaded_version_history_service.py @@ -64,6 +64,7 @@ class DownloadedVersionHistoryService: self._db_path = db_path or _resolve_database_path() self._settings = settings_manager or get_settings_manager() self._lock = asyncio.Lock() + self._conn: sqlite3.Connection | None = None self._schema_initialized = False self._ensure_directory() self._initialize_schema() @@ -78,6 +79,12 @@ class DownloadedVersionHistoryService: conn.row_factory = sqlite3.Row return conn + def _get_conn(self) -> sqlite3.Connection: + if self._conn is None: + self._conn = sqlite3.connect(self._db_path, check_same_thread=False) + self._conn.row_factory = sqlite3.Row + return self._conn + def _initialize_schema(self) -> None: if self._schema_initialized: return @@ -116,33 +123,33 @@ class DownloadedVersionHistoryService: timestamp = time.time() async with self._lock: - with self._connect() as conn: - conn.execute( - """ - INSERT INTO downloaded_model_versions ( - model_type, version_id, model_id, first_seen_at, last_seen_at, - source, last_file_path, last_library_name, is_deleted_override - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0) - ON CONFLICT(model_type, version_id) DO UPDATE SET - model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id), - last_seen_at = excluded.last_seen_at, - source = excluded.source, - last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path), - last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name), - is_deleted_override = 0 - """, - ( - normalized_type, - normalized_version_id, - normalized_model_id, - timestamp, - timestamp, - source, - file_path, - active_library_name, - ), - ) - conn.commit() + conn = self._get_conn() + conn.execute( + """ + INSERT INTO downloaded_model_versions ( + model_type, version_id, model_id, first_seen_at, last_seen_at, + source, last_file_path, last_library_name, is_deleted_override + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0) + ON CONFLICT(model_type, version_id) DO UPDATE SET + model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id), + last_seen_at = excluded.last_seen_at, + source = excluded.source, + last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path), + last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name), + is_deleted_override = 0 + """, + ( + normalized_type, + normalized_version_id, + normalized_model_id, + timestamp, + timestamp, + source, + file_path, + active_library_name, + ), + ) + conn.commit() async def mark_downloaded_bulk( self, @@ -180,24 +187,24 @@ class DownloadedVersionHistoryService: return async with self._lock: - with self._connect() as conn: - conn.executemany( - """ - INSERT INTO downloaded_model_versions ( - model_type, version_id, model_id, first_seen_at, last_seen_at, - source, last_file_path, last_library_name, is_deleted_override - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0) - ON CONFLICT(model_type, version_id) DO UPDATE SET - model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id), - last_seen_at = excluded.last_seen_at, - source = excluded.source, - last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path), - last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name), - is_deleted_override = 0 - """, - payload, - ) - conn.commit() + conn = self._get_conn() + conn.executemany( + """ + INSERT INTO downloaded_model_versions ( + model_type, version_id, model_id, first_seen_at, last_seen_at, + source, last_file_path, last_library_name, is_deleted_override + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0) + ON CONFLICT(model_type, version_id) DO UPDATE SET + model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id), + last_seen_at = excluded.last_seen_at, + source = excluded.source, + last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path), + last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name), + is_deleted_override = 0 + """, + payload, + ) + conn.commit() async def mark_not_downloaded(self, model_type: str, version_id: int) -> None: normalized_type = _normalize_model_type(model_type) @@ -208,28 +215,28 @@ class DownloadedVersionHistoryService: timestamp = time.time() async with self._lock: - with self._connect() as conn: - conn.execute( - """ - INSERT INTO downloaded_model_versions ( - model_type, version_id, model_id, first_seen_at, last_seen_at, - source, last_file_path, last_library_name, is_deleted_override - ) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1) - ON CONFLICT(model_type, version_id) DO UPDATE SET - last_seen_at = excluded.last_seen_at, - source = excluded.source, - last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name), - is_deleted_override = 1 - """, - ( - normalized_type, - normalized_version_id, - timestamp, - timestamp, - self._get_active_library_name(), - ), - ) - conn.commit() + conn = self._get_conn() + conn.execute( + """ + INSERT INTO downloaded_model_versions ( + model_type, version_id, model_id, first_seen_at, last_seen_at, + source, last_file_path, last_library_name, is_deleted_override + ) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1) + ON CONFLICT(model_type, version_id) DO UPDATE SET + last_seen_at = excluded.last_seen_at, + source = excluded.source, + last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name), + is_deleted_override = 1 + """, + ( + normalized_type, + normalized_version_id, + timestamp, + timestamp, + self._get_active_library_name(), + ), + ) + conn.commit() async def has_been_downloaded(self, model_type: str, version_id: int) -> bool: normalized_type = _normalize_model_type(model_type) @@ -238,15 +245,15 @@ class DownloadedVersionHistoryService: return False async with self._lock: - with self._connect() as conn: - row = conn.execute( - """ - SELECT is_deleted_override - FROM downloaded_model_versions - WHERE model_type = ? AND version_id = ? - """, - (normalized_type, normalized_version_id), - ).fetchone() + conn = self._get_conn() + row = conn.execute( + """ + SELECT is_deleted_override + FROM downloaded_model_versions + WHERE model_type = ? AND version_id = ? + """, + (normalized_type, normalized_version_id), + ).fetchone() return bool(row) and not bool(row["is_deleted_override"]) async def get_downloaded_version_ids( @@ -258,16 +265,16 @@ class DownloadedVersionHistoryService: return [] async with self._lock: - with self._connect() as conn: - rows = conn.execute( - """ - SELECT version_id - FROM downloaded_model_versions - WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0 - ORDER BY version_id ASC - """, - (normalized_type, normalized_model_id), - ).fetchall() + conn = self._get_conn() + rows = conn.execute( + """ + SELECT version_id + FROM downloaded_model_versions + WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0 + ORDER BY version_id ASC + """, + (normalized_type, normalized_model_id), + ).fetchall() return [int(row["version_id"]) for row in rows] async def get_downloaded_version_ids_bulk( @@ -291,17 +298,17 @@ class DownloadedVersionHistoryService: params: list[object] = [normalized_type, *normalized_model_ids] async with self._lock: - with self._connect() as conn: - rows = conn.execute( - f""" - SELECT model_id, version_id - FROM downloaded_model_versions - WHERE model_type = ? - AND model_id IN ({placeholders}) - AND is_deleted_override = 0 - """, - params, - ).fetchall() + conn = self._get_conn() + rows = conn.execute( + f""" + SELECT model_id, version_id + FROM downloaded_model_versions + WHERE model_type = ? + AND model_id IN ({placeholders}) + AND is_deleted_override = 0 + """, + params, + ).fetchall() result: dict[int, set[int]] = {} for row in rows: