mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-10 10:16:45 -03:00
Compare commits
4 Commits
502b7eab31
...
2ac0eb0f9d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ac0eb0f9d | ||
|
|
f028625ce9 | ||
|
|
06acc7f576 | ||
|
|
d324b57274 |
@@ -1,10 +1,22 @@
|
|||||||
import folder_paths # type: ignore
|
import os
|
||||||
from ..utils.utils import get_lora_info
|
from ..utils.utils import get_lora_info_absolute
|
||||||
|
from ..config import config
|
||||||
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
|
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _relpath_within_loras(abs_path):
|
||||||
|
"""Return abs_path relative to the first matching lora root, or basename as fallback."""
|
||||||
|
all_roots = list(config.loras_roots or []) + list(config.extra_loras_roots or [])
|
||||||
|
for root in all_roots:
|
||||||
|
try:
|
||||||
|
return os.path.relpath(abs_path, root)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
return os.path.basename(abs_path)
|
||||||
|
|
||||||
class WanVideoLoraSelectLM:
|
class WanVideoLoraSelectLM:
|
||||||
NAME = "WanVideo Lora Select (LoraManager)"
|
NAME = "WanVideo Lora Select (LoraManager)"
|
||||||
CATEGORY = "Lora Manager/stackers"
|
CATEGORY = "Lora Manager/stackers"
|
||||||
@@ -56,13 +68,13 @@ class WanVideoLoraSelectLM:
|
|||||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||||
|
|
||||||
# Get lora path and trigger words
|
# Get lora path and trigger words
|
||||||
lora_path, trigger_words = get_lora_info(lora_name)
|
lora_path, trigger_words = get_lora_info_absolute(lora_name)
|
||||||
|
|
||||||
# Create lora item for WanVideo format
|
# Create lora item for WanVideo format
|
||||||
lora_item = {
|
lora_item = {
|
||||||
"path": folder_paths.get_full_path("loras", lora_path),
|
"path": lora_path,
|
||||||
"strength": model_strength,
|
"strength": model_strength,
|
||||||
"name": lora_path.split(".")[0],
|
"name": os.path.splitext(_relpath_within_loras(lora_path))[0],
|
||||||
"blocks": selected_blocks,
|
"blocks": selected_blocks,
|
||||||
"layer_filter": layer_filter,
|
"layer_filter": layer_filter,
|
||||||
"low_mem_load": low_mem_load,
|
"low_mem_load": low_mem_load,
|
||||||
|
|||||||
@@ -1,11 +1,23 @@
|
|||||||
import folder_paths # type: ignore
|
import os
|
||||||
from ..utils.utils import get_lora_info
|
from ..utils.utils import get_lora_info_absolute
|
||||||
|
from ..config import config
|
||||||
from .utils import any_type
|
from .utils import any_type
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
# 初始化日志记录器
|
# 初始化日志记录器
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _relpath_within_loras(abs_path):
|
||||||
|
"""Return abs_path relative to the first matching lora root, or basename as fallback."""
|
||||||
|
all_roots = list(config.loras_roots or []) + list(config.extra_loras_roots or [])
|
||||||
|
for root in all_roots:
|
||||||
|
try:
|
||||||
|
return os.path.relpath(abs_path, root)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
return os.path.basename(abs_path)
|
||||||
|
|
||||||
# 定义新节点的类
|
# 定义新节点的类
|
||||||
class WanVideoLoraTextSelectLM:
|
class WanVideoLoraTextSelectLM:
|
||||||
# 节点在UI中显示的名称
|
# 节点在UI中显示的名称
|
||||||
@@ -87,12 +99,12 @@ class WanVideoLoraTextSelectLM:
|
|||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lora_path, trigger_words = get_lora_info(lora_name_raw)
|
lora_path, trigger_words = get_lora_info_absolute(lora_name_raw)
|
||||||
|
|
||||||
lora_item = {
|
lora_item = {
|
||||||
"path": folder_paths.get_full_path("loras", lora_path),
|
"path": lora_path,
|
||||||
"strength": model_strength,
|
"strength": model_strength,
|
||||||
"name": lora_path.split(".")[0],
|
"name": os.path.splitext(_relpath_within_loras(lora_path))[0],
|
||||||
"blocks": selected_blocks,
|
"blocks": selected_blocks,
|
||||||
"layer_filter": layer_filter,
|
"layer_filter": layer_filter,
|
||||||
"low_mem_load": low_mem_load,
|
"low_mem_load": low_mem_load,
|
||||||
|
|||||||
@@ -1791,29 +1791,33 @@ class ModelLibraryHandler:
|
|||||||
exists = True
|
exists = True
|
||||||
model_type = "embedding"
|
model_type = "embedding"
|
||||||
|
|
||||||
|
if exists:
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"exists": True,
|
||||||
|
"modelType": model_type,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
history_service = await self._get_download_history_service()
|
history_service = await self._get_download_history_service()
|
||||||
has_been_downloaded = False
|
has_been_downloaded = False
|
||||||
history_type = model_type
|
history_type = None
|
||||||
if history_type:
|
for candidate_type in ("lora", "checkpoint", "embedding"):
|
||||||
has_been_downloaded = await history_service.has_been_downloaded(
|
if await history_service.has_been_downloaded(
|
||||||
history_type,
|
candidate_type,
|
||||||
model_version_id,
|
model_version_id,
|
||||||
)
|
):
|
||||||
else:
|
has_been_downloaded = True
|
||||||
for candidate_type in ("lora", "checkpoint", "embedding"):
|
history_type = candidate_type
|
||||||
if await history_service.has_been_downloaded(
|
break
|
||||||
candidate_type,
|
|
||||||
model_version_id,
|
|
||||||
):
|
|
||||||
has_been_downloaded = True
|
|
||||||
history_type = candidate_type
|
|
||||||
break
|
|
||||||
|
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{
|
{
|
||||||
"success": True,
|
"success": True,
|
||||||
"exists": exists,
|
"exists": False,
|
||||||
"modelType": model_type if exists else history_type,
|
"modelType": history_type,
|
||||||
"hasBeenDownloaded": has_been_downloaded,
|
"hasBeenDownloaded": has_been_downloaded,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -1833,40 +1837,46 @@ class ModelLibraryHandler:
|
|||||||
model_type = None
|
model_type = None
|
||||||
versions = []
|
versions = []
|
||||||
downloaded_version_ids = []
|
downloaded_version_ids = []
|
||||||
history_service = await self._get_download_history_service()
|
|
||||||
if lora_versions:
|
if lora_versions:
|
||||||
model_type = "lora"
|
return web.json_response(
|
||||||
versions = self._with_downloaded_flag(lora_versions)
|
{
|
||||||
downloaded_version_ids = await history_service.get_downloaded_version_ids(
|
"success": True,
|
||||||
model_type,
|
"modelType": "lora",
|
||||||
model_id,
|
"versions": self._with_downloaded_flag(lora_versions),
|
||||||
|
"downloadedVersionIds": [],
|
||||||
|
}
|
||||||
)
|
)
|
||||||
elif checkpoint_versions:
|
if checkpoint_versions:
|
||||||
model_type = "checkpoint"
|
return web.json_response(
|
||||||
versions = self._with_downloaded_flag(checkpoint_versions)
|
{
|
||||||
downloaded_version_ids = await history_service.get_downloaded_version_ids(
|
"success": True,
|
||||||
model_type,
|
"modelType": "checkpoint",
|
||||||
model_id,
|
"versions": self._with_downloaded_flag(checkpoint_versions),
|
||||||
|
"downloadedVersionIds": [],
|
||||||
|
}
|
||||||
)
|
)
|
||||||
elif embedding_versions:
|
if embedding_versions:
|
||||||
model_type = "embedding"
|
return web.json_response(
|
||||||
versions = self._with_downloaded_flag(embedding_versions)
|
{
|
||||||
downloaded_version_ids = await history_service.get_downloaded_version_ids(
|
"success": True,
|
||||||
model_type,
|
"modelType": "embedding",
|
||||||
model_id,
|
"versions": self._with_downloaded_flag(embedding_versions),
|
||||||
|
"downloadedVersionIds": [],
|
||||||
|
}
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
for candidate_type in ("lora", "checkpoint", "embedding"):
|
history_service = await self._get_download_history_service()
|
||||||
candidate_downloaded_version_ids = (
|
for candidate_type in ("lora", "checkpoint", "embedding"):
|
||||||
await history_service.get_downloaded_version_ids(
|
candidate_downloaded_version_ids = (
|
||||||
candidate_type,
|
await history_service.get_downloaded_version_ids(
|
||||||
model_id,
|
candidate_type,
|
||||||
)
|
model_id,
|
||||||
)
|
)
|
||||||
if candidate_downloaded_version_ids:
|
)
|
||||||
model_type = candidate_type
|
if candidate_downloaded_version_ids:
|
||||||
downloaded_version_ids = candidate_downloaded_version_ids
|
model_type = candidate_type
|
||||||
break
|
downloaded_version_ids = candidate_downloaded_version_ids
|
||||||
|
break
|
||||||
|
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{
|
{
|
||||||
@@ -1880,6 +1890,86 @@ class ModelLibraryHandler:
|
|||||||
logger.error("Failed to check model existence: %s", exc, exc_info=True)
|
logger.error("Failed to check model existence: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def check_models_exist(self, request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
model_ids_raw = request.query.get("modelIds", "")
|
||||||
|
if not model_ids_raw:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": True, "results": []}
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_ids = model_ids_raw.split(",")
|
||||||
|
seen: set[int] = set()
|
||||||
|
model_ids: list[int] = []
|
||||||
|
for raw in raw_ids:
|
||||||
|
stripped = raw.strip()
|
||||||
|
if not stripped:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
mid = int(stripped)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
if mid not in seen:
|
||||||
|
seen.add(mid)
|
||||||
|
model_ids.append(mid)
|
||||||
|
|
||||||
|
if not model_ids:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": True, "results": []}
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_scanner = await self._service_registry.get_lora_scanner()
|
||||||
|
checkpoint_scanner = await self._service_registry.get_checkpoint_scanner()
|
||||||
|
embedding_scanner = await self._service_registry.get_embedding_scanner()
|
||||||
|
|
||||||
|
results: list[dict] = []
|
||||||
|
for model_id in model_ids:
|
||||||
|
lora_versions = await lora_scanner.get_model_versions_by_id(model_id)
|
||||||
|
if lora_versions:
|
||||||
|
results.append({
|
||||||
|
"modelId": model_id,
|
||||||
|
"modelType": "lora",
|
||||||
|
"versions": self._with_downloaded_flag(lora_versions),
|
||||||
|
"downloadedVersionIds": [],
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if checkpoint_scanner:
|
||||||
|
checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id)
|
||||||
|
if checkpoint_versions:
|
||||||
|
results.append({
|
||||||
|
"modelId": model_id,
|
||||||
|
"modelType": "checkpoint",
|
||||||
|
"versions": self._with_downloaded_flag(checkpoint_versions),
|
||||||
|
"downloadedVersionIds": [],
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if embedding_scanner:
|
||||||
|
embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id)
|
||||||
|
if embedding_versions:
|
||||||
|
results.append({
|
||||||
|
"modelId": model_id,
|
||||||
|
"modelType": "embedding",
|
||||||
|
"versions": self._with_downloaded_flag(embedding_versions),
|
||||||
|
"downloadedVersionIds": [],
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"modelId": model_id,
|
||||||
|
"modelType": None,
|
||||||
|
"versions": [],
|
||||||
|
"downloadedVersionIds": [],
|
||||||
|
})
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{"success": True, "results": results}
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to check models existence: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
async def get_model_version_download_status(
|
async def get_model_version_download_status(
|
||||||
self, request: web.Request
|
self, request: web.Request
|
||||||
) -> web.Response:
|
) -> web.Response:
|
||||||
@@ -3025,6 +3115,7 @@ class MiscHandlerSet:
|
|||||||
"update_node_widget": self.node_registry.update_node_widget,
|
"update_node_widget": self.node_registry.update_node_widget,
|
||||||
"get_registry": self.node_registry.get_registry,
|
"get_registry": self.node_registry.get_registry,
|
||||||
"check_model_exists": self.model_library.check_model_exists,
|
"check_model_exists": self.model_library.check_model_exists,
|
||||||
|
"check_models_exist": self.model_library.check_models_exist,
|
||||||
"get_model_version_download_status": self.model_library.get_model_version_download_status,
|
"get_model_version_download_status": self.model_library.get_model_version_download_status,
|
||||||
"set_model_version_download_status": self.model_library.set_model_version_download_status,
|
"set_model_version_download_status": self.model_library.set_model_version_download_status,
|
||||||
"get_civitai_user_models": self.model_library.get_civitai_user_models,
|
"get_civitai_user_models": self.model_library.get_civitai_user_models,
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
|
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
|
||||||
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
||||||
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
||||||
|
RouteDefinition("GET", "/api/lm/check-models-exist", "check_models_exist"),
|
||||||
RouteDefinition(
|
RouteDefinition(
|
||||||
"GET",
|
"GET",
|
||||||
"/api/lm/model-version-download-status",
|
"/api/lm/model-version-download-status",
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ class DownloadedVersionHistoryService:
|
|||||||
self._db_path = db_path or _resolve_database_path()
|
self._db_path = db_path or _resolve_database_path()
|
||||||
self._settings = settings_manager or get_settings_manager()
|
self._settings = settings_manager or get_settings_manager()
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
self._conn: sqlite3.Connection | None = None
|
||||||
self._schema_initialized = False
|
self._schema_initialized = False
|
||||||
self._ensure_directory()
|
self._ensure_directory()
|
||||||
self._initialize_schema()
|
self._initialize_schema()
|
||||||
@@ -78,6 +79,12 @@ class DownloadedVersionHistoryService:
|
|||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
def _get_conn(self) -> sqlite3.Connection:
|
||||||
|
if self._conn is None:
|
||||||
|
self._conn = sqlite3.connect(self._db_path, check_same_thread=False)
|
||||||
|
self._conn.row_factory = sqlite3.Row
|
||||||
|
return self._conn
|
||||||
|
|
||||||
def _initialize_schema(self) -> None:
|
def _initialize_schema(self) -> None:
|
||||||
if self._schema_initialized:
|
if self._schema_initialized:
|
||||||
return
|
return
|
||||||
@@ -116,33 +123,33 @@ class DownloadedVersionHistoryService:
|
|||||||
timestamp = time.time()
|
timestamp = time.time()
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
with self._connect() as conn:
|
conn = self._get_conn()
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO downloaded_model_versions (
|
INSERT INTO downloaded_model_versions (
|
||||||
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||||
source, last_file_path, last_library_name, is_deleted_override
|
source, last_file_path, last_library_name, is_deleted_override
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
|
||||||
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||||
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
|
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
|
||||||
last_seen_at = excluded.last_seen_at,
|
last_seen_at = excluded.last_seen_at,
|
||||||
source = excluded.source,
|
source = excluded.source,
|
||||||
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
|
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
|
||||||
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||||
is_deleted_override = 0
|
is_deleted_override = 0
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
normalized_type,
|
normalized_type,
|
||||||
normalized_version_id,
|
normalized_version_id,
|
||||||
normalized_model_id,
|
normalized_model_id,
|
||||||
timestamp,
|
timestamp,
|
||||||
timestamp,
|
timestamp,
|
||||||
source,
|
source,
|
||||||
file_path,
|
file_path,
|
||||||
active_library_name,
|
active_library_name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
async def mark_downloaded_bulk(
|
async def mark_downloaded_bulk(
|
||||||
self,
|
self,
|
||||||
@@ -180,24 +187,24 @@ class DownloadedVersionHistoryService:
|
|||||||
return
|
return
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
with self._connect() as conn:
|
conn = self._get_conn()
|
||||||
conn.executemany(
|
conn.executemany(
|
||||||
"""
|
"""
|
||||||
INSERT INTO downloaded_model_versions (
|
INSERT INTO downloaded_model_versions (
|
||||||
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||||
source, last_file_path, last_library_name, is_deleted_override
|
source, last_file_path, last_library_name, is_deleted_override
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
|
||||||
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||||
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
|
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
|
||||||
last_seen_at = excluded.last_seen_at,
|
last_seen_at = excluded.last_seen_at,
|
||||||
source = excluded.source,
|
source = excluded.source,
|
||||||
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
|
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
|
||||||
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||||
is_deleted_override = 0
|
is_deleted_override = 0
|
||||||
""",
|
""",
|
||||||
payload,
|
payload,
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
async def mark_not_downloaded(self, model_type: str, version_id: int) -> None:
|
async def mark_not_downloaded(self, model_type: str, version_id: int) -> None:
|
||||||
normalized_type = _normalize_model_type(model_type)
|
normalized_type = _normalize_model_type(model_type)
|
||||||
@@ -208,28 +215,28 @@ class DownloadedVersionHistoryService:
|
|||||||
timestamp = time.time()
|
timestamp = time.time()
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
with self._connect() as conn:
|
conn = self._get_conn()
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO downloaded_model_versions (
|
INSERT INTO downloaded_model_versions (
|
||||||
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||||
source, last_file_path, last_library_name, is_deleted_override
|
source, last_file_path, last_library_name, is_deleted_override
|
||||||
) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1)
|
) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1)
|
||||||
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||||
last_seen_at = excluded.last_seen_at,
|
last_seen_at = excluded.last_seen_at,
|
||||||
source = excluded.source,
|
source = excluded.source,
|
||||||
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||||
is_deleted_override = 1
|
is_deleted_override = 1
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
normalized_type,
|
normalized_type,
|
||||||
normalized_version_id,
|
normalized_version_id,
|
||||||
timestamp,
|
timestamp,
|
||||||
timestamp,
|
timestamp,
|
||||||
self._get_active_library_name(),
|
self._get_active_library_name(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
async def has_been_downloaded(self, model_type: str, version_id: int) -> bool:
|
async def has_been_downloaded(self, model_type: str, version_id: int) -> bool:
|
||||||
normalized_type = _normalize_model_type(model_type)
|
normalized_type = _normalize_model_type(model_type)
|
||||||
@@ -238,15 +245,15 @@ class DownloadedVersionHistoryService:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
with self._connect() as conn:
|
conn = self._get_conn()
|
||||||
row = conn.execute(
|
row = conn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT is_deleted_override
|
SELECT is_deleted_override
|
||||||
FROM downloaded_model_versions
|
FROM downloaded_model_versions
|
||||||
WHERE model_type = ? AND version_id = ?
|
WHERE model_type = ? AND version_id = ?
|
||||||
""",
|
""",
|
||||||
(normalized_type, normalized_version_id),
|
(normalized_type, normalized_version_id),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
return bool(row) and not bool(row["is_deleted_override"])
|
return bool(row) and not bool(row["is_deleted_override"])
|
||||||
|
|
||||||
async def get_downloaded_version_ids(
|
async def get_downloaded_version_ids(
|
||||||
@@ -258,16 +265,16 @@ class DownloadedVersionHistoryService:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
with self._connect() as conn:
|
conn = self._get_conn()
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT version_id
|
SELECT version_id
|
||||||
FROM downloaded_model_versions
|
FROM downloaded_model_versions
|
||||||
WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0
|
WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0
|
||||||
ORDER BY version_id ASC
|
ORDER BY version_id ASC
|
||||||
""",
|
""",
|
||||||
(normalized_type, normalized_model_id),
|
(normalized_type, normalized_model_id),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
return [int(row["version_id"]) for row in rows]
|
return [int(row["version_id"]) for row in rows]
|
||||||
|
|
||||||
async def get_downloaded_version_ids_bulk(
|
async def get_downloaded_version_ids_bulk(
|
||||||
@@ -291,17 +298,17 @@ class DownloadedVersionHistoryService:
|
|||||||
params: list[object] = [normalized_type, *normalized_model_ids]
|
params: list[object] = [normalized_type, *normalized_model_ids]
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
with self._connect() as conn:
|
conn = self._get_conn()
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
f"""
|
f"""
|
||||||
SELECT model_id, version_id
|
SELECT model_id, version_id
|
||||||
FROM downloaded_model_versions
|
FROM downloaded_model_versions
|
||||||
WHERE model_type = ?
|
WHERE model_type = ?
|
||||||
AND model_id IN ({placeholders})
|
AND model_id IN ({placeholders})
|
||||||
AND is_deleted_override = 0
|
AND is_deleted_override = 0
|
||||||
""",
|
""",
|
||||||
params,
|
params,
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
result: dict[int, set[int]] = {}
|
result: dict[int, set[int]] = {}
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
|||||||
@@ -413,7 +413,7 @@ app.registerExtension({
|
|||||||
const savedItem = consumeQueuedState(itemState, itemText);
|
const savedItem = consumeQueuedState(itemState, itemText);
|
||||||
return {
|
return {
|
||||||
text: itemText,
|
text: itemText,
|
||||||
active: savedItem ? savedItem.active : defaultActive,
|
active: savedItem ? savedItem.active : true,
|
||||||
highlighted: false,
|
highlighted: false,
|
||||||
strength: null,
|
strength: null,
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user