Compare commits

...

4 Commits

Author SHA1 Message Date
Will Miao
2ac0eb0f9d fix(wanvideo): resolve lora path resolution and name truncation for extra folder paths
- Use get_lora_info_absolute to obtain correct absolute paths for loras
  in LM extra folder paths, instead of folder_paths.get_full_path which
  only searches ComfyUI's standard loras directories (returned None)
- Fix name field truncation: str.split('.')[0] stopped at the first dot,
  replaced with os.path.splitext to only strip the file extension
- Add _relpath_within_loras helper to preserve subdirectory info in the
  name field, matching WanVideoWrapper's os.path.splitext(lora)[0] format
2026-05-02 14:55:12 +08:00
Will Miao
f028625ce9 feat(check-models-exist): add batch endpoint for checking multiple model IDs
New endpoint: GET /api/lm/check-models-exist?modelIds=1,2,3,...

Accepts comma-separated modelIds, returns a results array with one
entry per modelId. Uses a single scanner lookup batch - three
service-registry calls total, regardless of model count. Skips
history checks entirely (same rationale as the singleton endpoint:
when models exist locally, history is redundant).

Expected: reduces 231 HTTP round-trips to 1 for the browser
extension's model-card indicator flow. Combined with the prior
SQLite-connection and history-skip fixes, total wall-clock time
for a 175K-lora user's page load drops from ~9.4s to <10ms.
2026-05-02 13:43:53 +08:00
Will Miao
06acc7f576 fix(trigger-word-toggle): default group children to active regardless of default_active 2026-05-02 13:33:42 +08:00
Will Miao
d324b57274 perf(check-model-exists): eliminate SQLite connection-per-query overhead and skip redundant history checks
Root cause: 231 concurrent /check-model-exists requests on 175K-lora library
caused ~9.4s wall clock time. The bottleneck was two-fold:

1. DownloadedVersionHistoryService opened a new sqlite3.connect() for every
   query under asyncio.Lock. With a large WAL from 175K entries, each
   connect() took ~8ms. Serialized by the lock across 231 requests, the
   230th request waited ~1848ms just for lock acquisition.

2. check_model_exists always queried download history even when the model
   was found locally. The history result (hasBeenDownloaded /
   downloadedVersionIds) is only used by the UI when the model is NOT
   found locally; when found, the 'in library' indicator takes priority.

Changes:
- downloaded_version_history_service.py: added persistent _get_conn() that
  creates the SQLite connection once and reuses it across all queries
- misc_handlers.py: early-return from check_model_exists when the model
  exists locally, bypassing the history service entirely (lock skipped)

Expected: per-request wait time drops from ~1912ms to <3ms, wall clock
from ~9.4s to <0.3s for the 175K-lora user's 231-card page.
2026-05-02 13:31:20 +08:00
6 changed files with 276 additions and 153 deletions

View File

@@ -1,10 +1,22 @@
import folder_paths # type: ignore
from ..utils.utils import get_lora_info
import os
from ..utils.utils import get_lora_info_absolute
from ..config import config
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
import logging
logger = logging.getLogger(__name__)
def _relpath_within_loras(abs_path):
"""Return abs_path relative to the first matching lora root, or basename as fallback."""
all_roots = list(config.loras_roots or []) + list(config.extra_loras_roots or [])
for root in all_roots:
try:
return os.path.relpath(abs_path, root)
except ValueError:
continue
return os.path.basename(abs_path)
class WanVideoLoraSelectLM:
NAME = "WanVideo Lora Select (LoraManager)"
CATEGORY = "Lora Manager/stackers"
@@ -56,13 +68,13 @@ class WanVideoLoraSelectLM:
clip_strength = float(lora.get('clipStrength', model_strength))
# Get lora path and trigger words
lora_path, trigger_words = get_lora_info(lora_name)
lora_path, trigger_words = get_lora_info_absolute(lora_name)
# Create lora item for WanVideo format
lora_item = {
"path": folder_paths.get_full_path("loras", lora_path),
"path": lora_path,
"strength": model_strength,
"name": lora_path.split(".")[0],
"name": os.path.splitext(_relpath_within_loras(lora_path))[0],
"blocks": selected_blocks,
"layer_filter": layer_filter,
"low_mem_load": low_mem_load,

View File

@@ -1,11 +1,23 @@
import folder_paths # type: ignore
from ..utils.utils import get_lora_info
import os
from ..utils.utils import get_lora_info_absolute
from ..config import config
from .utils import any_type
import logging
# 初始化日志记录器
logger = logging.getLogger(__name__)
def _relpath_within_loras(abs_path):
"""Return abs_path relative to the first matching lora root, or basename as fallback."""
all_roots = list(config.loras_roots or []) + list(config.extra_loras_roots or [])
for root in all_roots:
try:
return os.path.relpath(abs_path, root)
except ValueError:
continue
return os.path.basename(abs_path)
# 定义新节点的类
class WanVideoLoraTextSelectLM:
# 节点在UI中显示的名称
@@ -87,12 +99,12 @@ class WanVideoLoraTextSelectLM:
else:
continue
lora_path, trigger_words = get_lora_info(lora_name_raw)
lora_path, trigger_words = get_lora_info_absolute(lora_name_raw)
lora_item = {
"path": folder_paths.get_full_path("loras", lora_path),
"path": lora_path,
"strength": model_strength,
"name": lora_path.split(".")[0],
"name": os.path.splitext(_relpath_within_loras(lora_path))[0],
"blocks": selected_blocks,
"layer_filter": layer_filter,
"low_mem_load": low_mem_load,

View File

@@ -1791,15 +1791,19 @@ class ModelLibraryHandler:
exists = True
model_type = "embedding"
if exists:
return web.json_response(
{
"success": True,
"exists": True,
"modelType": model_type,
"hasBeenDownloaded": False,
}
)
history_service = await self._get_download_history_service()
has_been_downloaded = False
history_type = model_type
if history_type:
has_been_downloaded = await history_service.has_been_downloaded(
history_type,
model_version_id,
)
else:
history_type = None
for candidate_type in ("lora", "checkpoint", "embedding"):
if await history_service.has_been_downloaded(
candidate_type,
@@ -1812,8 +1816,8 @@ class ModelLibraryHandler:
return web.json_response(
{
"success": True,
"exists": exists,
"modelType": model_type if exists else history_type,
"exists": False,
"modelType": history_type,
"hasBeenDownloaded": has_been_downloaded,
}
)
@@ -1833,29 +1837,35 @@ class ModelLibraryHandler:
model_type = None
versions = []
downloaded_version_ids = []
history_service = await self._get_download_history_service()
if lora_versions:
model_type = "lora"
versions = self._with_downloaded_flag(lora_versions)
downloaded_version_ids = await history_service.get_downloaded_version_ids(
model_type,
model_id,
return web.json_response(
{
"success": True,
"modelType": "lora",
"versions": self._with_downloaded_flag(lora_versions),
"downloadedVersionIds": [],
}
)
elif checkpoint_versions:
model_type = "checkpoint"
versions = self._with_downloaded_flag(checkpoint_versions)
downloaded_version_ids = await history_service.get_downloaded_version_ids(
model_type,
model_id,
if checkpoint_versions:
return web.json_response(
{
"success": True,
"modelType": "checkpoint",
"versions": self._with_downloaded_flag(checkpoint_versions),
"downloadedVersionIds": [],
}
)
elif embedding_versions:
model_type = "embedding"
versions = self._with_downloaded_flag(embedding_versions)
downloaded_version_ids = await history_service.get_downloaded_version_ids(
model_type,
model_id,
if embedding_versions:
return web.json_response(
{
"success": True,
"modelType": "embedding",
"versions": self._with_downloaded_flag(embedding_versions),
"downloadedVersionIds": [],
}
)
else:
history_service = await self._get_download_history_service()
for candidate_type in ("lora", "checkpoint", "embedding"):
candidate_downloaded_version_ids = (
await history_service.get_downloaded_version_ids(
@@ -1880,6 +1890,86 @@ class ModelLibraryHandler:
logger.error("Failed to check model existence: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def check_models_exist(self, request: web.Request) -> web.Response:
try:
model_ids_raw = request.query.get("modelIds", "")
if not model_ids_raw:
return web.json_response(
{"success": True, "results": []}
)
raw_ids = model_ids_raw.split(",")
seen: set[int] = set()
model_ids: list[int] = []
for raw in raw_ids:
stripped = raw.strip()
if not stripped:
continue
try:
mid = int(stripped)
except ValueError:
continue
if mid not in seen:
seen.add(mid)
model_ids.append(mid)
if not model_ids:
return web.json_response(
{"success": True, "results": []}
)
lora_scanner = await self._service_registry.get_lora_scanner()
checkpoint_scanner = await self._service_registry.get_checkpoint_scanner()
embedding_scanner = await self._service_registry.get_embedding_scanner()
results: list[dict] = []
for model_id in model_ids:
lora_versions = await lora_scanner.get_model_versions_by_id(model_id)
if lora_versions:
results.append({
"modelId": model_id,
"modelType": "lora",
"versions": self._with_downloaded_flag(lora_versions),
"downloadedVersionIds": [],
})
continue
if checkpoint_scanner:
checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id)
if checkpoint_versions:
results.append({
"modelId": model_id,
"modelType": "checkpoint",
"versions": self._with_downloaded_flag(checkpoint_versions),
"downloadedVersionIds": [],
})
continue
if embedding_scanner:
embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id)
if embedding_versions:
results.append({
"modelId": model_id,
"modelType": "embedding",
"versions": self._with_downloaded_flag(embedding_versions),
"downloadedVersionIds": [],
})
continue
results.append({
"modelId": model_id,
"modelType": None,
"versions": [],
"downloadedVersionIds": [],
})
return web.json_response(
{"success": True, "results": results}
)
except Exception as exc:
logger.error("Failed to check models existence: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_version_download_status(
self, request: web.Request
) -> web.Response:
@@ -3025,6 +3115,7 @@ class MiscHandlerSet:
"update_node_widget": self.node_registry.update_node_widget,
"get_registry": self.node_registry.get_registry,
"check_model_exists": self.model_library.check_model_exists,
"check_models_exist": self.model_library.check_models_exist,
"get_model_version_download_status": self.model_library.get_model_version_download_status,
"set_model_version_download_status": self.model_library.set_model_version_download_status,
"get_civitai_user_models": self.model_library.get_civitai_user_models,

View File

@@ -43,6 +43,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
RouteDefinition("GET", "/api/lm/check-models-exist", "check_models_exist"),
RouteDefinition(
"GET",
"/api/lm/model-version-download-status",

View File

@@ -64,6 +64,7 @@ class DownloadedVersionHistoryService:
self._db_path = db_path or _resolve_database_path()
self._settings = settings_manager or get_settings_manager()
self._lock = asyncio.Lock()
self._conn: sqlite3.Connection | None = None
self._schema_initialized = False
self._ensure_directory()
self._initialize_schema()
@@ -78,6 +79,12 @@ class DownloadedVersionHistoryService:
conn.row_factory = sqlite3.Row
return conn
def _get_conn(self) -> sqlite3.Connection:
if self._conn is None:
self._conn = sqlite3.connect(self._db_path, check_same_thread=False)
self._conn.row_factory = sqlite3.Row
return self._conn
def _initialize_schema(self) -> None:
if self._schema_initialized:
return
@@ -116,7 +123,7 @@ class DownloadedVersionHistoryService:
timestamp = time.time()
async with self._lock:
with self._connect() as conn:
conn = self._get_conn()
conn.execute(
"""
INSERT INTO downloaded_model_versions (
@@ -180,7 +187,7 @@ class DownloadedVersionHistoryService:
return
async with self._lock:
with self._connect() as conn:
conn = self._get_conn()
conn.executemany(
"""
INSERT INTO downloaded_model_versions (
@@ -208,7 +215,7 @@ class DownloadedVersionHistoryService:
timestamp = time.time()
async with self._lock:
with self._connect() as conn:
conn = self._get_conn()
conn.execute(
"""
INSERT INTO downloaded_model_versions (
@@ -238,7 +245,7 @@ class DownloadedVersionHistoryService:
return False
async with self._lock:
with self._connect() as conn:
conn = self._get_conn()
row = conn.execute(
"""
SELECT is_deleted_override
@@ -258,7 +265,7 @@ class DownloadedVersionHistoryService:
return []
async with self._lock:
with self._connect() as conn:
conn = self._get_conn()
rows = conn.execute(
"""
SELECT version_id
@@ -291,7 +298,7 @@ class DownloadedVersionHistoryService:
params: list[object] = [normalized_type, *normalized_model_ids]
async with self._lock:
with self._connect() as conn:
conn = self._get_conn()
rows = conn.execute(
f"""
SELECT model_id, version_id

View File

@@ -413,7 +413,7 @@ app.registerExtension({
const savedItem = consumeQueuedState(itemState, itemText);
return {
text: itemText,
active: savedItem ? savedItem.active : defaultActive,
active: savedItem ? savedItem.active : true,
highlighted: false,
strength: null,
};