mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
Compare commits
6 Commits
30db8c3d1d
...
v1.0.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b6dd6938b0 | ||
|
|
727d0ef043 | ||
|
|
9344d86332 | ||
|
|
d36b16c213 | ||
|
|
33a7f07558 | ||
|
|
4f599aeced |
@@ -56,6 +56,13 @@ Insomnia Art Designs, megakirbs, Brennok, 2018cfh, W+K+White, wackop, Takkan, Ca
|
|||||||
|
|
||||||
## Release Notes
|
## Release Notes
|
||||||
|
|
||||||
|
### v1.0.2
|
||||||
|
|
||||||
|
* **Model Download History Tracking** - LoRA Manager now keeps a history of downloaded model versions, allowing it to recognize whether a version has been downloaded before, even if it is no longer currently present in your library.
|
||||||
|
* **Skip Previously Downloaded Model Versions** - Added a new setting, `Skip previously downloaded model versions`, to help avoid downloading model versions you have already downloaded in the past.
|
||||||
|
* **LoRA Stack Combiner Trigger Words Fix** - Fixed an issue where trigger word updates from `LORA_STACK` inputs were not propagated correctly through the LoRA Stack Combiner node.
|
||||||
|
* **CivitAI Example Image Compatibility** - Improved support for CivitAI CDN subdomains so example images load more reliably.
|
||||||
|
|
||||||
### v1.0.1
|
### v1.0.1
|
||||||
|
|
||||||
* **Batch Recipe Import** - Import recipes from multiple URLs or directories simultaneously with optimized concurrency.
|
* **Batch Recipe Import** - Import recipes from multiple URLs or directories simultaneously with optimized concurrency.
|
||||||
|
|||||||
@@ -341,6 +341,10 @@
|
|||||||
"saveFailed": "Ausgeschlossene Basismodelle konnten nicht gespeichert werden: {message}"
|
"saveFailed": "Ausgeschlossene Basismodelle konnten nicht gespeichert werden: {message}"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"skipPreviouslyDownloadedModelVersions": {
|
||||||
|
"label": "Bereits heruntergeladene Modellversionen überspringen",
|
||||||
|
"help": "Wenn aktiviert, überspringt LoRA Manager den Download einer Modellversion, wenn der Download-Verlaufsdienst diese spezifische Version als bereits heruntergeladen erfasst hat. Gilt für alle Download-Abläufe."
|
||||||
|
},
|
||||||
"layoutSettings": {
|
"layoutSettings": {
|
||||||
"displayDensity": "Anzeige-Dichte",
|
"displayDensity": "Anzeige-Dichte",
|
||||||
"displayDensityOptions": {
|
"displayDensityOptions": {
|
||||||
@@ -827,7 +831,7 @@
|
|||||||
},
|
},
|
||||||
"contextMenu": {
|
"contextMenu": {
|
||||||
"moveToOtherTypeFolder": "In {otherType}-Ordner verschieben",
|
"moveToOtherTypeFolder": "In {otherType}-Ordner verschieben",
|
||||||
"sendToWorkflow": "[TODO: Translate] Send to Workflow"
|
"sendToWorkflow": "An Workflow senden"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embeddings": {
|
"embeddings": {
|
||||||
|
|||||||
@@ -325,7 +325,7 @@
|
|||||||
},
|
},
|
||||||
"downloadSkipBaseModels": {
|
"downloadSkipBaseModels": {
|
||||||
"label": "Skip downloads for base models",
|
"label": "Skip downloads for base models",
|
||||||
"help": "When a model version uses one of these base models, LoRA Manager will skip the download before any file transfer starts. Applies to all download flows. Only supported base models can be selected here.",
|
"help": "When enabled, versions using the selected base models will be skipped.",
|
||||||
"searchPlaceholder": "Filter base models...",
|
"searchPlaceholder": "Filter base models...",
|
||||||
"empty": "No base models match the current search.",
|
"empty": "No base models match the current search.",
|
||||||
"summary": {
|
"summary": {
|
||||||
@@ -341,6 +341,10 @@
|
|||||||
"saveFailed": "Unable to save excluded base models: {message}"
|
"saveFailed": "Unable to save excluded base models: {message}"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"skipPreviouslyDownloadedModelVersions": {
|
||||||
|
"label": "Skip previously downloaded model versions",
|
||||||
|
"help": "When enabled, versions downloaded before will be skipped."
|
||||||
|
},
|
||||||
"layoutSettings": {
|
"layoutSettings": {
|
||||||
"displayDensity": "Display Density",
|
"displayDensity": "Display Density",
|
||||||
"displayDensityOptions": {
|
"displayDensityOptions": {
|
||||||
|
|||||||
@@ -341,6 +341,10 @@
|
|||||||
"saveFailed": "No se pudieron guardar los modelos base excluidos: {message}"
|
"saveFailed": "No se pudieron guardar los modelos base excluidos: {message}"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"skipPreviouslyDownloadedModelVersions": {
|
||||||
|
"label": "Omitir versiones de modelos previamente descargadas",
|
||||||
|
"help": "Cuando está habilitado, LoRA Manager omitirá la descarga de una versión de modelo si el servicio de historial de descargas registra esa versión exacta como ya descargada. Aplica a todos los flujos de descarga."
|
||||||
|
},
|
||||||
"layoutSettings": {
|
"layoutSettings": {
|
||||||
"displayDensity": "Densidad de visualización",
|
"displayDensity": "Densidad de visualización",
|
||||||
"displayDensityOptions": {
|
"displayDensityOptions": {
|
||||||
@@ -827,7 +831,7 @@
|
|||||||
},
|
},
|
||||||
"contextMenu": {
|
"contextMenu": {
|
||||||
"moveToOtherTypeFolder": "Mover a la carpeta {otherType}",
|
"moveToOtherTypeFolder": "Mover a la carpeta {otherType}",
|
||||||
"sendToWorkflow": "[TODO: Translate] Send to Workflow"
|
"sendToWorkflow": "Enviar al flujo de trabajo"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embeddings": {
|
"embeddings": {
|
||||||
|
|||||||
@@ -341,6 +341,10 @@
|
|||||||
"saveFailed": "Impossible d’enregistrer les modèles de base exclus : {message}"
|
"saveFailed": "Impossible d’enregistrer les modèles de base exclus : {message}"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"skipPreviouslyDownloadedModelVersions": {
|
||||||
|
"label": "Ignorer les versions de modèles précédemment téléchargées",
|
||||||
|
"help": "Lorsque activé, LoRA Manager ignorera le téléchargement d'une version de modèle si le service d'historique des téléchargements enregistre cette version exacte comme déjà téléchargée. S'applique à tous les flux de téléchargement."
|
||||||
|
},
|
||||||
"layoutSettings": {
|
"layoutSettings": {
|
||||||
"displayDensity": "Densité d'affichage",
|
"displayDensity": "Densité d'affichage",
|
||||||
"displayDensityOptions": {
|
"displayDensityOptions": {
|
||||||
@@ -827,7 +831,7 @@
|
|||||||
},
|
},
|
||||||
"contextMenu": {
|
"contextMenu": {
|
||||||
"moveToOtherTypeFolder": "Déplacer vers le dossier {otherType}",
|
"moveToOtherTypeFolder": "Déplacer vers le dossier {otherType}",
|
||||||
"sendToWorkflow": "[TODO: Translate] Send to Workflow"
|
"sendToWorkflow": "Envoyer vers le workflow"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embeddings": {
|
"embeddings": {
|
||||||
|
|||||||
@@ -341,6 +341,10 @@
|
|||||||
"saveFailed": "לא ניתן לשמור את מודלי הבסיס המוחרגים: {message}"
|
"saveFailed": "לא ניתן לשמור את מודלי הבסיס המוחרגים: {message}"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"skipPreviouslyDownloadedModelVersions": {
|
||||||
|
"label": "דלג על גרסאות מודלים שהורדו בעבר",
|
||||||
|
"help": "כאשר מופעל, LoRA Manager ידלג על הורדת גרסת מודל אם שירות היסטוריית ההורדות רושם את הגרסה המדויקת הזו ככבר שהורדה. חל על כל תהליכי ההורדה."
|
||||||
|
},
|
||||||
"layoutSettings": {
|
"layoutSettings": {
|
||||||
"displayDensity": "צפיפות תצוגה",
|
"displayDensity": "צפיפות תצוגה",
|
||||||
"displayDensityOptions": {
|
"displayDensityOptions": {
|
||||||
@@ -827,7 +831,7 @@
|
|||||||
},
|
},
|
||||||
"contextMenu": {
|
"contextMenu": {
|
||||||
"moveToOtherTypeFolder": "העבר לתיקיית {otherType}",
|
"moveToOtherTypeFolder": "העבר לתיקיית {otherType}",
|
||||||
"sendToWorkflow": "[TODO: Translate] Send to Workflow"
|
"sendToWorkflow": "שלח ל-workflow"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embeddings": {
|
"embeddings": {
|
||||||
|
|||||||
@@ -341,6 +341,10 @@
|
|||||||
"saveFailed": "除外するベースモデルを保存できませんでした: {message}"
|
"saveFailed": "除外するベースモデルを保存できませんでした: {message}"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"skipPreviouslyDownloadedModelVersions": {
|
||||||
|
"label": "以前にダウンロードしたモデルバージョンをスキップ",
|
||||||
|
"help": "有効にすると、ダウンロード履歴サービスがそのバージョンが既にダウンロード済みと記録している場合、LoRA Managerはそのモデルバージョンのダウンロードをスキップします。すべてのダウンロードフローに適用されます。"
|
||||||
|
},
|
||||||
"layoutSettings": {
|
"layoutSettings": {
|
||||||
"displayDensity": "表示密度",
|
"displayDensity": "表示密度",
|
||||||
"displayDensityOptions": {
|
"displayDensityOptions": {
|
||||||
@@ -827,7 +831,7 @@
|
|||||||
},
|
},
|
||||||
"contextMenu": {
|
"contextMenu": {
|
||||||
"moveToOtherTypeFolder": "{otherType} フォルダに移動",
|
"moveToOtherTypeFolder": "{otherType} フォルダに移動",
|
||||||
"sendToWorkflow": "[TODO: Translate] Send to Workflow"
|
"sendToWorkflow": "ワークフローに送信"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embeddings": {
|
"embeddings": {
|
||||||
|
|||||||
@@ -341,6 +341,10 @@
|
|||||||
"saveFailed": "제외된 기본 모델을 저장할 수 없습니다: {message}"
|
"saveFailed": "제외된 기본 모델을 저장할 수 없습니다: {message}"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"skipPreviouslyDownloadedModelVersions": {
|
||||||
|
"label": "이전에 다운로드한 모델 버전 건너뛰기",
|
||||||
|
"help": "활성화하면 다운로드 기록 서비스가 해당 버전이 이미 다운로드되었음을 기록한 경우 LoRA Manager는 해당 모델 버전 다운로드를 건너뜁니다. 모든 다운로드 플로우에 적용됩니다."
|
||||||
|
},
|
||||||
"layoutSettings": {
|
"layoutSettings": {
|
||||||
"displayDensity": "표시 밀도",
|
"displayDensity": "표시 밀도",
|
||||||
"displayDensityOptions": {
|
"displayDensityOptions": {
|
||||||
@@ -827,7 +831,7 @@
|
|||||||
},
|
},
|
||||||
"contextMenu": {
|
"contextMenu": {
|
||||||
"moveToOtherTypeFolder": "{otherType} 폴더로 이동",
|
"moveToOtherTypeFolder": "{otherType} 폴더로 이동",
|
||||||
"sendToWorkflow": "[TODO: Translate] Send to Workflow"
|
"sendToWorkflow": "워크플로우로 전송"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embeddings": {
|
"embeddings": {
|
||||||
|
|||||||
@@ -341,6 +341,10 @@
|
|||||||
"saveFailed": "Не удалось сохранить исключённые базовые модели: {message}"
|
"saveFailed": "Не удалось сохранить исключённые базовые модели: {message}"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"skipPreviouslyDownloadedModelVersions": {
|
||||||
|
"label": "Пропускать ранее загруженные версии моделей",
|
||||||
|
"help": "Если включено, LoRA Manager будет пропускать загрузку версии модели, если сервис истории загрузок записал, что эта конкретная версия уже загружена. Применяется ко всем потокам загрузки."
|
||||||
|
},
|
||||||
"layoutSettings": {
|
"layoutSettings": {
|
||||||
"displayDensity": "Плотность отображения",
|
"displayDensity": "Плотность отображения",
|
||||||
"displayDensityOptions": {
|
"displayDensityOptions": {
|
||||||
@@ -827,7 +831,7 @@
|
|||||||
},
|
},
|
||||||
"contextMenu": {
|
"contextMenu": {
|
||||||
"moveToOtherTypeFolder": "Переместить в папку {otherType}",
|
"moveToOtherTypeFolder": "Переместить в папку {otherType}",
|
||||||
"sendToWorkflow": "[TODO: Translate] Send to Workflow"
|
"sendToWorkflow": "Отправить в workflow"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embeddings": {
|
"embeddings": {
|
||||||
|
|||||||
@@ -341,6 +341,10 @@
|
|||||||
"saveFailed": "无法保存已排除的基础模型:{message}"
|
"saveFailed": "无法保存已排除的基础模型:{message}"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"skipPreviouslyDownloadedModelVersions": {
|
||||||
|
"label": "跳过已下载的模型版本",
|
||||||
|
"help": "启用后,如果下载历史服务记录显示该版本已下载,LoRA Manager 将跳过下载该模型版本。适用于所有下载流程。"
|
||||||
|
},
|
||||||
"layoutSettings": {
|
"layoutSettings": {
|
||||||
"displayDensity": "显示密度",
|
"displayDensity": "显示密度",
|
||||||
"displayDensityOptions": {
|
"displayDensityOptions": {
|
||||||
@@ -827,7 +831,7 @@
|
|||||||
},
|
},
|
||||||
"contextMenu": {
|
"contextMenu": {
|
||||||
"moveToOtherTypeFolder": "移动到 {otherType} 文件夹",
|
"moveToOtherTypeFolder": "移动到 {otherType} 文件夹",
|
||||||
"sendToWorkflow": "[TODO: Translate] Send to Workflow"
|
"sendToWorkflow": "发送到工作流"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embeddings": {
|
"embeddings": {
|
||||||
|
|||||||
@@ -341,6 +341,10 @@
|
|||||||
"saveFailed": "無法儲存已排除的基礎模型:{message}"
|
"saveFailed": "無法儲存已排除的基礎模型:{message}"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"skipPreviouslyDownloadedModelVersions": {
|
||||||
|
"label": "跳過已下載的模型版本",
|
||||||
|
"help": "啟用後,如果下載歷史服務記錄顯示該版本已下載,LoRA Manager 將跳過下載該模型版本。適用於所有下載流程。"
|
||||||
|
},
|
||||||
"layoutSettings": {
|
"layoutSettings": {
|
||||||
"displayDensity": "顯示密度",
|
"displayDensity": "顯示密度",
|
||||||
"displayDensityOptions": {
|
"displayDensityOptions": {
|
||||||
@@ -827,7 +831,7 @@
|
|||||||
},
|
},
|
||||||
"contextMenu": {
|
"contextMenu": {
|
||||||
"moveToOtherTypeFolder": "移動到 {otherType} 資料夾",
|
"moveToOtherTypeFolder": "移動到 {otherType} 資料夾",
|
||||||
"sendToWorkflow": "[TODO: Translate] Send to Workflow"
|
"sendToWorkflow": "傳送到工作流"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embeddings": {
|
"embeddings": {
|
||||||
|
|||||||
@@ -751,6 +751,7 @@ class ServiceRegistryAdapter:
|
|||||||
get_lora_scanner: Callable[[], Awaitable]
|
get_lora_scanner: Callable[[], Awaitable]
|
||||||
get_checkpoint_scanner: Callable[[], Awaitable]
|
get_checkpoint_scanner: Callable[[], Awaitable]
|
||||||
get_embedding_scanner: Callable[[], Awaitable]
|
get_embedding_scanner: Callable[[], Awaitable]
|
||||||
|
get_downloaded_version_history_service: Callable[[], Awaitable]
|
||||||
|
|
||||||
|
|
||||||
class ModelLibraryHandler:
|
class ModelLibraryHandler:
|
||||||
@@ -764,6 +765,41 @@ class ModelLibraryHandler:
|
|||||||
self._service_registry = service_registry
|
self._service_registry = service_registry
|
||||||
self._metadata_provider_factory = metadata_provider_factory
|
self._metadata_provider_factory = metadata_provider_factory
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_model_type(model_type: str | None) -> str | None:
|
||||||
|
if not isinstance(model_type, str):
|
||||||
|
return None
|
||||||
|
normalized = model_type.strip().lower()
|
||||||
|
if normalized in {"lora", "locon", "dora"}:
|
||||||
|
return "lora"
|
||||||
|
if normalized == "checkpoint":
|
||||||
|
return "checkpoint"
|
||||||
|
if normalized in {"embedding", "textualinversion"}:
|
||||||
|
return "embedding"
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_scanner_for_type(self, model_type: str | None):
|
||||||
|
normalized_type = self._normalize_model_type(model_type)
|
||||||
|
if normalized_type == "lora":
|
||||||
|
return normalized_type, await self._service_registry.get_lora_scanner()
|
||||||
|
if normalized_type == "checkpoint":
|
||||||
|
return normalized_type, await self._service_registry.get_checkpoint_scanner()
|
||||||
|
if normalized_type == "embedding":
|
||||||
|
return normalized_type, await self._service_registry.get_embedding_scanner()
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
async def _get_download_history_service(self):
|
||||||
|
return await self._service_registry.get_downloaded_version_history_service()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _with_downloaded_flag(versions: list[dict]) -> list[dict]:
|
||||||
|
enriched: list[dict] = []
|
||||||
|
for version in versions:
|
||||||
|
entry = dict(version)
|
||||||
|
entry.setdefault("hasBeenDownloaded", True)
|
||||||
|
enriched.append(entry)
|
||||||
|
return enriched
|
||||||
|
|
||||||
async def check_model_exists(self, request: web.Request) -> web.Response:
|
async def check_model_exists(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
model_id_str = request.query.get("modelId")
|
model_id_str = request.query.get("modelId")
|
||||||
@@ -819,11 +855,30 @@ class ModelLibraryHandler:
|
|||||||
exists = True
|
exists = True
|
||||||
model_type = "embedding"
|
model_type = "embedding"
|
||||||
|
|
||||||
|
history_service = await self._get_download_history_service()
|
||||||
|
has_been_downloaded = False
|
||||||
|
history_type = model_type
|
||||||
|
if history_type:
|
||||||
|
has_been_downloaded = await history_service.has_been_downloaded(
|
||||||
|
history_type,
|
||||||
|
model_version_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for candidate_type in ("lora", "checkpoint", "embedding"):
|
||||||
|
if await history_service.has_been_downloaded(
|
||||||
|
candidate_type,
|
||||||
|
model_version_id,
|
||||||
|
):
|
||||||
|
has_been_downloaded = True
|
||||||
|
history_type = candidate_type
|
||||||
|
break
|
||||||
|
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{
|
{
|
||||||
"success": True,
|
"success": True,
|
||||||
"exists": exists,
|
"exists": exists,
|
||||||
"modelType": model_type if exists else None,
|
"modelType": model_type if exists else history_type,
|
||||||
|
"hasBeenDownloaded": has_been_downloaded,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -841,23 +896,166 @@ class ModelLibraryHandler:
|
|||||||
|
|
||||||
model_type = None
|
model_type = None
|
||||||
versions = []
|
versions = []
|
||||||
|
downloaded_version_ids = []
|
||||||
|
history_service = await self._get_download_history_service()
|
||||||
if lora_versions:
|
if lora_versions:
|
||||||
model_type = "lora"
|
model_type = "lora"
|
||||||
versions = lora_versions
|
versions = self._with_downloaded_flag(lora_versions)
|
||||||
|
downloaded_version_ids = await history_service.get_downloaded_version_ids(
|
||||||
|
model_type,
|
||||||
|
model_id,
|
||||||
|
)
|
||||||
elif checkpoint_versions:
|
elif checkpoint_versions:
|
||||||
model_type = "checkpoint"
|
model_type = "checkpoint"
|
||||||
versions = checkpoint_versions
|
versions = self._with_downloaded_flag(checkpoint_versions)
|
||||||
|
downloaded_version_ids = await history_service.get_downloaded_version_ids(
|
||||||
|
model_type,
|
||||||
|
model_id,
|
||||||
|
)
|
||||||
elif embedding_versions:
|
elif embedding_versions:
|
||||||
model_type = "embedding"
|
model_type = "embedding"
|
||||||
versions = embedding_versions
|
versions = self._with_downloaded_flag(embedding_versions)
|
||||||
|
downloaded_version_ids = await history_service.get_downloaded_version_ids(
|
||||||
|
model_type,
|
||||||
|
model_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for candidate_type in ("lora", "checkpoint", "embedding"):
|
||||||
|
candidate_downloaded_version_ids = (
|
||||||
|
await history_service.get_downloaded_version_ids(
|
||||||
|
candidate_type,
|
||||||
|
model_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if candidate_downloaded_version_ids:
|
||||||
|
model_type = candidate_type
|
||||||
|
downloaded_version_ids = candidate_downloaded_version_ids
|
||||||
|
break
|
||||||
|
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{"success": True, "modelType": model_type, "versions": versions}
|
{
|
||||||
|
"success": True,
|
||||||
|
"modelType": model_type,
|
||||||
|
"versions": versions,
|
||||||
|
"downloadedVersionIds": downloaded_version_ids,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
logger.error("Failed to check model existence: %s", exc, exc_info=True)
|
logger.error("Failed to check model existence: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def get_model_version_download_status(
|
||||||
|
self, request: web.Request
|
||||||
|
) -> web.Response:
|
||||||
|
try:
|
||||||
|
model_type, _ = await self._get_scanner_for_type(request.query.get("modelType"))
|
||||||
|
if not model_type:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Parameter modelType is required"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_version_id_str = request.query.get("modelVersionId")
|
||||||
|
if not model_version_id_str:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Missing required parameter: modelVersionId"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
model_version_id = int(model_version_id_str)
|
||||||
|
except ValueError:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Parameter modelVersionId must be an integer"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
history_service = await self._get_download_history_service()
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"modelType": model_type,
|
||||||
|
"modelVersionId": model_version_id,
|
||||||
|
"hasBeenDownloaded": await history_service.has_been_downloaded(
|
||||||
|
model_type,
|
||||||
|
model_version_id,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error(
|
||||||
|
"Failed to get model version download status: %s",
|
||||||
|
exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def set_model_version_download_status(
|
||||||
|
self, request: web.Request
|
||||||
|
) -> web.Response:
|
||||||
|
try:
|
||||||
|
if request.method == "GET":
|
||||||
|
data = request.query
|
||||||
|
else:
|
||||||
|
data = await request.json()
|
||||||
|
model_type, _ = await self._get_scanner_for_type(data.get("modelType"))
|
||||||
|
if not model_type:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Parameter modelType is required"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_version_id = int(data.get("modelVersionId"))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Parameter modelVersionId must be an integer"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
downloaded = data.get("downloaded")
|
||||||
|
if isinstance(downloaded, str):
|
||||||
|
normalized_downloaded = downloaded.strip().lower()
|
||||||
|
if normalized_downloaded in {"true", "1"}:
|
||||||
|
downloaded = True
|
||||||
|
elif normalized_downloaded in {"false", "0"}:
|
||||||
|
downloaded = False
|
||||||
|
|
||||||
|
if not isinstance(downloaded, bool):
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Parameter downloaded must be a boolean"},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
history_service = await self._get_download_history_service()
|
||||||
|
if downloaded:
|
||||||
|
model_id = data.get("modelId")
|
||||||
|
file_path = data.get("filePath")
|
||||||
|
await history_service.mark_downloaded(
|
||||||
|
model_type,
|
||||||
|
model_version_id,
|
||||||
|
model_id=model_id,
|
||||||
|
source="manual",
|
||||||
|
file_path=file_path if isinstance(file_path, str) else None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await history_service.mark_not_downloaded(model_type, model_version_id)
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"modelType": model_type,
|
||||||
|
"modelVersionId": model_version_id,
|
||||||
|
"hasBeenDownloaded": downloaded,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error(
|
||||||
|
"Failed to set model version download status: %s",
|
||||||
|
exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
async def get_model_versions_status(self, request: web.Request) -> web.Response:
|
async def get_model_versions_status(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
model_id_str = request.query.get("modelId")
|
model_id_str = request.query.get("modelId")
|
||||||
@@ -896,18 +1094,8 @@ class ModelLibraryHandler:
|
|||||||
model_name = response.get("name", "")
|
model_name = response.get("name", "")
|
||||||
model_type = response.get("type", "").lower()
|
model_type = response.get("type", "").lower()
|
||||||
|
|
||||||
scanner = None
|
normalized_type, scanner = await self._get_scanner_for_type(model_type)
|
||||||
normalized_type = None
|
if not normalized_type:
|
||||||
if model_type in {"lora", "locon", "dora"}:
|
|
||||||
scanner = await self._service_registry.get_lora_scanner()
|
|
||||||
normalized_type = "lora"
|
|
||||||
elif model_type == "checkpoint":
|
|
||||||
scanner = await self._service_registry.get_checkpoint_scanner()
|
|
||||||
normalized_type = "checkpoint"
|
|
||||||
elif model_type == "textualinversion":
|
|
||||||
scanner = await self._service_registry.get_embedding_scanner()
|
|
||||||
normalized_type = "embedding"
|
|
||||||
else:
|
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{
|
{
|
||||||
"success": False,
|
"success": False,
|
||||||
@@ -925,8 +1113,14 @@ class ModelLibraryHandler:
|
|||||||
status=503,
|
status=503,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
history_service = await self._get_download_history_service()
|
||||||
local_versions = await scanner.get_model_versions_by_id(model_id)
|
local_versions = await scanner.get_model_versions_by_id(model_id)
|
||||||
local_version_ids = {version["versionId"] for version in local_versions}
|
local_version_ids = {version["versionId"] for version in local_versions}
|
||||||
|
downloaded_version_ids = await history_service.get_downloaded_version_ids(
|
||||||
|
normalized_type,
|
||||||
|
model_id,
|
||||||
|
)
|
||||||
|
downloaded_version_id_set = set(downloaded_version_ids)
|
||||||
|
|
||||||
enriched_versions = []
|
enriched_versions = []
|
||||||
for version in versions:
|
for version in versions:
|
||||||
@@ -939,6 +1133,7 @@ class ModelLibraryHandler:
|
|||||||
if version.get("images")
|
if version.get("images")
|
||||||
else None,
|
else None,
|
||||||
"inLibrary": version_id in local_version_ids,
|
"inLibrary": version_id in local_version_ids,
|
||||||
|
"hasBeenDownloaded": version_id in downloaded_version_id_set,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1007,6 +1202,33 @@ class ModelLibraryHandler:
|
|||||||
}
|
}
|
||||||
|
|
||||||
versions: list[dict] = []
|
versions: list[dict] = []
|
||||||
|
history_service = await self._get_download_history_service()
|
||||||
|
model_ids: list[int] = []
|
||||||
|
for model in models:
|
||||||
|
try:
|
||||||
|
model_ids.append(int(model.get("id")))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
lora_downloaded = await history_service.get_downloaded_version_ids_bulk(
|
||||||
|
"lora",
|
||||||
|
model_ids,
|
||||||
|
)
|
||||||
|
checkpoint_downloaded = await history_service.get_downloaded_version_ids_bulk(
|
||||||
|
"checkpoint",
|
||||||
|
model_ids,
|
||||||
|
)
|
||||||
|
embedding_downloaded = await history_service.get_downloaded_version_ids_bulk(
|
||||||
|
"embedding",
|
||||||
|
model_ids,
|
||||||
|
)
|
||||||
|
downloaded_version_map: Dict[str, Dict[int, set[int]]] = {
|
||||||
|
"lora": lora_downloaded,
|
||||||
|
"locon": lora_downloaded,
|
||||||
|
"dora": lora_downloaded,
|
||||||
|
"checkpoint": checkpoint_downloaded,
|
||||||
|
"textualinversion": embedding_downloaded,
|
||||||
|
}
|
||||||
for model in models:
|
for model in models:
|
||||||
if not isinstance(model, dict):
|
if not isinstance(model, dict):
|
||||||
continue
|
continue
|
||||||
@@ -1061,6 +1283,8 @@ class ModelLibraryHandler:
|
|||||||
in_library = await scanner.check_model_version_exists(
|
in_library = await scanner.check_model_version_exists(
|
||||||
version_id_int
|
version_id_int
|
||||||
)
|
)
|
||||||
|
downloaded_versions = downloaded_version_map.get(model_type, {})
|
||||||
|
downloaded_version_ids = downloaded_versions.get(model_id_int, set())
|
||||||
|
|
||||||
versions.append(
|
versions.append(
|
||||||
{
|
{
|
||||||
@@ -1073,6 +1297,7 @@ class ModelLibraryHandler:
|
|||||||
"baseModel": version.get("baseModel"),
|
"baseModel": version.get("baseModel"),
|
||||||
"thumbnailUrl": thumbnail_url,
|
"thumbnailUrl": thumbnail_url,
|
||||||
"inLibrary": in_library,
|
"inLibrary": in_library,
|
||||||
|
"hasBeenDownloaded": version_id_int in downloaded_version_ids,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1655,6 +1880,8 @@ class MiscHandlerSet:
|
|||||||
"update_node_widget": self.node_registry.update_node_widget,
|
"update_node_widget": self.node_registry.update_node_widget,
|
||||||
"get_registry": self.node_registry.get_registry,
|
"get_registry": self.node_registry.get_registry,
|
||||||
"check_model_exists": self.model_library.check_model_exists,
|
"check_model_exists": self.model_library.check_model_exists,
|
||||||
|
"get_model_version_download_status": self.model_library.get_model_version_download_status,
|
||||||
|
"set_model_version_download_status": self.model_library.set_model_version_download_status,
|
||||||
"get_civitai_user_models": self.model_library.get_civitai_user_models,
|
"get_civitai_user_models": self.model_library.get_civitai_user_models,
|
||||||
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
||||||
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
||||||
@@ -1679,4 +1906,5 @@ def build_service_registry_adapter() -> ServiceRegistryAdapter:
|
|||||||
get_lora_scanner=ServiceRegistry.get_lora_scanner,
|
get_lora_scanner=ServiceRegistry.get_lora_scanner,
|
||||||
get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner,
|
get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner,
|
||||||
get_embedding_scanner=ServiceRegistry.get_embedding_scanner,
|
get_embedding_scanner=ServiceRegistry.get_embedding_scanner,
|
||||||
|
get_downloaded_version_history_service=ServiceRegistry.get_downloaded_version_history_service,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -37,6 +37,21 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
|
RouteDefinition("POST", "/api/lm/update-node-widget", "update_node_widget"),
|
||||||
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
||||||
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
||||||
|
RouteDefinition(
|
||||||
|
"GET",
|
||||||
|
"/api/lm/model-version-download-status",
|
||||||
|
"get_model_version_download_status",
|
||||||
|
),
|
||||||
|
RouteDefinition(
|
||||||
|
"POST",
|
||||||
|
"/api/lm/model-version-download-status",
|
||||||
|
"set_model_version_download_status",
|
||||||
|
),
|
||||||
|
RouteDefinition(
|
||||||
|
"GET",
|
||||||
|
"/api/lm/set-model-version-download-status",
|
||||||
|
"set_model_version_download_status",
|
||||||
|
),
|
||||||
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
|
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
|
||||||
RouteDefinition(
|
RouteDefinition(
|
||||||
"POST", "/api/lm/download-metadata-archive", "download_metadata_archive"
|
"POST", "/api/lm/download-metadata-archive", "download_metadata_archive"
|
||||||
|
|||||||
@@ -64,6 +64,19 @@ class DownloadManager:
|
|||||||
"""Get the checkpoint scanner from registry"""
|
"""Get the checkpoint scanner from registry"""
|
||||||
return await ServiceRegistry.get_checkpoint_scanner()
|
return await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
|
||||||
|
async def _has_been_downloaded(self, model_type: str, model_version_id: int) -> bool:
|
||||||
|
try:
|
||||||
|
history_service = await ServiceRegistry.get_downloaded_version_history_service()
|
||||||
|
return await history_service.has_been_downloaded(model_type, model_version_id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"Failed to read download history for %s version %s: %s",
|
||||||
|
model_type,
|
||||||
|
model_version_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
async def download_from_civitai(
|
async def download_from_civitai(
|
||||||
self,
|
self,
|
||||||
model_id: int = None,
|
model_id: int = None,
|
||||||
@@ -355,6 +368,57 @@ class DownloadManager:
|
|||||||
"error": f'Model type "{model_type_from_info}" is not supported for download',
|
"error": f'Model type "{model_type_from_info}" is not supported for download',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resolved_version_id = model_version_id
|
||||||
|
raw_version_id = version_info.get("id")
|
||||||
|
if resolved_version_id is None and raw_version_id is not None:
|
||||||
|
try:
|
||||||
|
resolved_version_id = int(raw_version_id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
resolved_version_id = None
|
||||||
|
|
||||||
|
if (
|
||||||
|
get_settings_manager().get_skip_previously_downloaded_model_versions()
|
||||||
|
and resolved_version_id is not None
|
||||||
|
and await self._has_been_downloaded(model_type, resolved_version_id)
|
||||||
|
):
|
||||||
|
file_name = ""
|
||||||
|
files = version_info.get("files")
|
||||||
|
if isinstance(files, list):
|
||||||
|
primary_file = next(
|
||||||
|
(
|
||||||
|
file_info
|
||||||
|
for file_info in files
|
||||||
|
if isinstance(file_info, dict) and file_info.get("primary")
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
selected_file = primary_file
|
||||||
|
if selected_file is None:
|
||||||
|
selected_file = next(
|
||||||
|
(file_info for file_info in files if isinstance(file_info, dict)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if isinstance(selected_file, dict):
|
||||||
|
raw_file_name = selected_file.get("name", "")
|
||||||
|
if isinstance(raw_file_name, str):
|
||||||
|
file_name = raw_file_name.strip()
|
||||||
|
|
||||||
|
message = (
|
||||||
|
f"Skipped download for '{file_name or version_info.get('name') or f'model_version:{resolved_version_id}'}' "
|
||||||
|
f"because version {resolved_version_id} was already downloaded before"
|
||||||
|
)
|
||||||
|
logger.info(message)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"skipped": True,
|
||||||
|
"status": "skipped",
|
||||||
|
"reason": "previously_downloaded_version",
|
||||||
|
"message": message,
|
||||||
|
"model_version_id": resolved_version_id,
|
||||||
|
"file_name": file_name,
|
||||||
|
"download_id": download_id,
|
||||||
|
}
|
||||||
|
|
||||||
excluded_base_models = get_settings_manager().get_download_skip_base_models()
|
excluded_base_models = get_settings_manager().get_download_skip_base_models()
|
||||||
base_model_value = version_info.get("baseModel", "")
|
base_model_value = version_info.get("baseModel", "")
|
||||||
if (
|
if (
|
||||||
@@ -640,6 +704,13 @@ class DownloadManager:
|
|||||||
or version_info.get("modelId")
|
or version_info.get("modelId")
|
||||||
or (version_info.get("model") or {}).get("id")
|
or (version_info.get("model") or {}).get("id")
|
||||||
)
|
)
|
||||||
|
await self._record_downloaded_version_history(
|
||||||
|
model_type,
|
||||||
|
resolved_model_id,
|
||||||
|
version_info,
|
||||||
|
model_version_id,
|
||||||
|
save_path,
|
||||||
|
)
|
||||||
await self._sync_downloaded_version(
|
await self._sync_downloaded_version(
|
||||||
model_type,
|
model_type,
|
||||||
resolved_model_id,
|
resolved_model_id,
|
||||||
@@ -669,6 +740,55 @@ class DownloadManager:
|
|||||||
}
|
}
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
async def _record_downloaded_version_history(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
model_id_value,
|
||||||
|
version_info: Dict,
|
||||||
|
fallback_version_id=None,
|
||||||
|
file_path: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
history_service = await ServiceRegistry.get_downloaded_version_history_service()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"Skipping download history sync; failed to acquire history service: %s",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if history_service is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
resolved_model_id = model_id_value
|
||||||
|
if resolved_model_id is None:
|
||||||
|
resolved_model_id = version_info.get("modelId")
|
||||||
|
if resolved_model_id is None:
|
||||||
|
model_info = version_info.get("model")
|
||||||
|
if isinstance(model_info, dict):
|
||||||
|
resolved_model_id = model_info.get("id")
|
||||||
|
|
||||||
|
version_id = version_info.get("id")
|
||||||
|
if version_id is None:
|
||||||
|
version_id = fallback_version_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
await history_service.mark_downloaded(
|
||||||
|
model_type,
|
||||||
|
int(version_id),
|
||||||
|
model_id=int(resolved_model_id) if resolved_model_id is not None else None,
|
||||||
|
source="download",
|
||||||
|
file_path=file_path,
|
||||||
|
)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
logger.debug(
|
||||||
|
"Skipping download history sync; invalid identifiers model=%s version=%s",
|
||||||
|
resolved_model_id,
|
||||||
|
version_id,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Failed to sync download history for %s: %s", model_type, exc)
|
||||||
|
|
||||||
async def _sync_downloaded_version(
|
async def _sync_downloaded_version(
|
||||||
self,
|
self,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
|
|||||||
313
py/services/downloaded_version_history_service.py
Normal file
313
py/services/downloaded_version_history_service.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
from typing import Iterable, Mapping, Optional, Sequence
|
||||||
|
|
||||||
|
from ..utils.cache_paths import get_cache_base_dir
|
||||||
|
from .settings_manager import get_settings_manager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_model_type(model_type: str | None) -> Optional[str]:
|
||||||
|
if not isinstance(model_type, str):
|
||||||
|
return None
|
||||||
|
normalized = model_type.strip().lower()
|
||||||
|
if normalized in {"lora", "locon", "dora"}:
|
||||||
|
return "lora"
|
||||||
|
if normalized == "checkpoint":
|
||||||
|
return "checkpoint"
|
||||||
|
if normalized in {"embedding", "textualinversion"}:
|
||||||
|
return "embedding"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_int(value) -> Optional[int]:
|
||||||
|
try:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_database_path() -> str:
|
||||||
|
base_dir = get_cache_base_dir(create=True)
|
||||||
|
history_dir = os.path.join(base_dir, "download_history")
|
||||||
|
os.makedirs(history_dir, exist_ok=True)
|
||||||
|
return os.path.join(history_dir, "downloaded_versions.sqlite")
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadedVersionHistoryService:
|
||||||
|
_SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS downloaded_model_versions (
|
||||||
|
model_type TEXT NOT NULL,
|
||||||
|
version_id INTEGER NOT NULL,
|
||||||
|
model_id INTEGER,
|
||||||
|
first_seen_at REAL NOT NULL,
|
||||||
|
last_seen_at REAL NOT NULL,
|
||||||
|
source TEXT NOT NULL,
|
||||||
|
last_file_path TEXT,
|
||||||
|
last_library_name TEXT,
|
||||||
|
is_deleted_override INTEGER NOT NULL DEFAULT 0,
|
||||||
|
PRIMARY KEY (model_type, version_id)
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_downloaded_model_versions_model
|
||||||
|
ON downloaded_model_versions(model_type, model_id);
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | None = None, *, settings_manager=None) -> None:
|
||||||
|
self._db_path = db_path or _resolve_database_path()
|
||||||
|
self._settings = settings_manager or get_settings_manager()
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._schema_initialized = False
|
||||||
|
self._ensure_directory()
|
||||||
|
self._initialize_schema()
|
||||||
|
|
||||||
|
def _ensure_directory(self) -> None:
|
||||||
|
directory = os.path.dirname(self._db_path)
|
||||||
|
if directory:
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
|
||||||
|
def _connect(self) -> sqlite3.Connection:
|
||||||
|
conn = sqlite3.connect(self._db_path, check_same_thread=False)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _initialize_schema(self) -> None:
|
||||||
|
if self._schema_initialized:
|
||||||
|
return
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.executescript(self._SCHEMA)
|
||||||
|
conn.commit()
|
||||||
|
self._schema_initialized = True
|
||||||
|
|
||||||
|
def get_database_path(self) -> str:
|
||||||
|
return self._db_path
|
||||||
|
|
||||||
|
def _get_active_library_name(self) -> str | None:
|
||||||
|
try:
|
||||||
|
value = self._settings.get_active_library_name()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return value or None
|
||||||
|
|
||||||
|
async def mark_downloaded(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
version_id: int,
|
||||||
|
*,
|
||||||
|
model_id: int | None = None,
|
||||||
|
source: str = "manual",
|
||||||
|
file_path: str | None = None,
|
||||||
|
library_name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
normalized_version_id = _normalize_int(version_id)
|
||||||
|
normalized_model_id = _normalize_int(model_id)
|
||||||
|
if normalized_type is None or normalized_version_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
active_library_name = library_name or self._get_active_library_name()
|
||||||
|
timestamp = time.time()
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO downloaded_model_versions (
|
||||||
|
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||||
|
source, last_file_path, last_library_name, is_deleted_override
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
|
||||||
|
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||||
|
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
|
||||||
|
last_seen_at = excluded.last_seen_at,
|
||||||
|
source = excluded.source,
|
||||||
|
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
|
||||||
|
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||||
|
is_deleted_override = 0
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
normalized_type,
|
||||||
|
normalized_version_id,
|
||||||
|
normalized_model_id,
|
||||||
|
timestamp,
|
||||||
|
timestamp,
|
||||||
|
source,
|
||||||
|
file_path,
|
||||||
|
active_library_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
async def mark_downloaded_bulk(
|
||||||
|
self,
|
||||||
|
model_type: str,
|
||||||
|
records: Sequence[Mapping[str, object]],
|
||||||
|
*,
|
||||||
|
source: str = "scan",
|
||||||
|
library_name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
if normalized_type is None or not records:
|
||||||
|
return
|
||||||
|
|
||||||
|
timestamp = time.time()
|
||||||
|
active_library_name = library_name or self._get_active_library_name()
|
||||||
|
payload: list[tuple[object, ...]] = []
|
||||||
|
for record in records:
|
||||||
|
version_id = _normalize_int(record.get("version_id"))
|
||||||
|
if version_id is None:
|
||||||
|
continue
|
||||||
|
payload.append(
|
||||||
|
(
|
||||||
|
normalized_type,
|
||||||
|
version_id,
|
||||||
|
_normalize_int(record.get("model_id")),
|
||||||
|
timestamp,
|
||||||
|
timestamp,
|
||||||
|
source,
|
||||||
|
record.get("file_path"),
|
||||||
|
active_library_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not payload:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.executemany(
|
||||||
|
"""
|
||||||
|
INSERT INTO downloaded_model_versions (
|
||||||
|
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||||
|
source, last_file_path, last_library_name, is_deleted_override
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0)
|
||||||
|
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||||
|
model_id = COALESCE(excluded.model_id, downloaded_model_versions.model_id),
|
||||||
|
last_seen_at = excluded.last_seen_at,
|
||||||
|
source = excluded.source,
|
||||||
|
last_file_path = COALESCE(excluded.last_file_path, downloaded_model_versions.last_file_path),
|
||||||
|
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||||
|
is_deleted_override = 0
|
||||||
|
""",
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
async def mark_not_downloaded(self, model_type: str, version_id: int) -> None:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
normalized_version_id = _normalize_int(version_id)
|
||||||
|
if normalized_type is None or normalized_version_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
timestamp = time.time()
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO downloaded_model_versions (
|
||||||
|
model_type, version_id, model_id, first_seen_at, last_seen_at,
|
||||||
|
source, last_file_path, last_library_name, is_deleted_override
|
||||||
|
) VALUES (?, ?, NULL, ?, ?, 'manual', NULL, ?, 1)
|
||||||
|
ON CONFLICT(model_type, version_id) DO UPDATE SET
|
||||||
|
last_seen_at = excluded.last_seen_at,
|
||||||
|
source = excluded.source,
|
||||||
|
last_library_name = COALESCE(excluded.last_library_name, downloaded_model_versions.last_library_name),
|
||||||
|
is_deleted_override = 1
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
normalized_type,
|
||||||
|
normalized_version_id,
|
||||||
|
timestamp,
|
||||||
|
timestamp,
|
||||||
|
self._get_active_library_name(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
async def has_been_downloaded(self, model_type: str, version_id: int) -> bool:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
normalized_version_id = _normalize_int(version_id)
|
||||||
|
if normalized_type is None or normalized_version_id is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT is_deleted_override
|
||||||
|
FROM downloaded_model_versions
|
||||||
|
WHERE model_type = ? AND version_id = ?
|
||||||
|
""",
|
||||||
|
(normalized_type, normalized_version_id),
|
||||||
|
).fetchone()
|
||||||
|
return bool(row) and not bool(row["is_deleted_override"])
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids(
|
||||||
|
self, model_type: str, model_id: int
|
||||||
|
) -> list[int]:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
normalized_model_id = _normalize_int(model_id)
|
||||||
|
if normalized_type is None or normalized_model_id is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT version_id
|
||||||
|
FROM downloaded_model_versions
|
||||||
|
WHERE model_type = ? AND model_id = ? AND is_deleted_override = 0
|
||||||
|
ORDER BY version_id ASC
|
||||||
|
""",
|
||||||
|
(normalized_type, normalized_model_id),
|
||||||
|
).fetchall()
|
||||||
|
return [int(row["version_id"]) for row in rows]
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids_bulk(
|
||||||
|
self, model_type: str, model_ids: Iterable[int]
|
||||||
|
) -> dict[int, set[int]]:
|
||||||
|
normalized_type = _normalize_model_type(model_type)
|
||||||
|
if normalized_type is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
normalized_model_ids = sorted(
|
||||||
|
{
|
||||||
|
value
|
||||||
|
for value in (_normalize_int(model_id) for model_id in model_ids)
|
||||||
|
if value is not None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if not normalized_model_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
placeholders = ", ".join(["?"] * len(normalized_model_ids))
|
||||||
|
params: list[object] = [normalized_type, *normalized_model_ids]
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
f"""
|
||||||
|
SELECT model_id, version_id
|
||||||
|
FROM downloaded_model_versions
|
||||||
|
WHERE model_type = ?
|
||||||
|
AND model_id IN ({placeholders})
|
||||||
|
AND is_deleted_override = 0
|
||||||
|
""",
|
||||||
|
params,
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
result: dict[int, set[int]] = {}
|
||||||
|
for row in rows:
|
||||||
|
model_id = _normalize_int(row["model_id"])
|
||||||
|
version_id = _normalize_int(row["version_id"])
|
||||||
|
if model_id is None or version_id is None:
|
||||||
|
continue
|
||||||
|
result.setdefault(model_id, set()).add(version_id)
|
||||||
|
return result
|
||||||
@@ -411,6 +411,7 @@ class ModelScanner:
|
|||||||
if scan_result:
|
if scan_result:
|
||||||
await self._apply_scan_result(scan_result)
|
await self._apply_scan_result(scan_result)
|
||||||
await self._save_persistent_cache(scan_result)
|
await self._save_persistent_cache(scan_result)
|
||||||
|
await self._sync_download_history(scan_result.raw_data, source='scan')
|
||||||
|
|
||||||
# Send final progress update
|
# Send final progress update
|
||||||
await ws_manager.broadcast_init_progress({
|
await ws_manager.broadcast_init_progress({
|
||||||
@@ -516,6 +517,7 @@ class ModelScanner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self._apply_scan_result(scan_result)
|
await self._apply_scan_result(scan_result)
|
||||||
|
await self._sync_download_history(adjusted_raw_data, source='scan')
|
||||||
|
|
||||||
await ws_manager.broadcast_init_progress({
|
await ws_manager.broadcast_init_progress({
|
||||||
'stage': 'loading_cache',
|
'stage': 'loading_cache',
|
||||||
@@ -576,6 +578,7 @@ class ModelScanner:
|
|||||||
excluded_models=list(self._excluded_models)
|
excluded_models=list(self._excluded_models)
|
||||||
)
|
)
|
||||||
await self._save_persistent_cache(snapshot)
|
await self._save_persistent_cache(snapshot)
|
||||||
|
await self._sync_download_history(snapshot.raw_data, source='scan')
|
||||||
def _count_model_files(self) -> int:
|
def _count_model_files(self) -> int:
|
||||||
"""Count all model files with supported extensions in all roots
|
"""Count all model files with supported extensions in all roots
|
||||||
|
|
||||||
@@ -704,6 +707,7 @@ class ModelScanner:
|
|||||||
scan_result = await self._gather_model_data()
|
scan_result = await self._gather_model_data()
|
||||||
await self._apply_scan_result(scan_result)
|
await self._apply_scan_result(scan_result)
|
||||||
await self._save_persistent_cache(scan_result)
|
await self._save_persistent_cache(scan_result)
|
||||||
|
await self._sync_download_history(scan_result.raw_data, source='scan')
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
||||||
@@ -1101,6 +1105,49 @@ class ModelScanner:
|
|||||||
|
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
|
async def _sync_download_history(
|
||||||
|
self,
|
||||||
|
raw_data: List[Mapping[str, Any]],
|
||||||
|
*,
|
||||||
|
source: str,
|
||||||
|
) -> None:
|
||||||
|
records: List[Dict[str, Any]] = []
|
||||||
|
for item in raw_data or []:
|
||||||
|
if not isinstance(item, Mapping):
|
||||||
|
continue
|
||||||
|
civitai = item.get('civitai')
|
||||||
|
if not isinstance(civitai, Mapping):
|
||||||
|
continue
|
||||||
|
|
||||||
|
version_id = civitai.get('id')
|
||||||
|
if version_id in (None, ''):
|
||||||
|
continue
|
||||||
|
|
||||||
|
records.append(
|
||||||
|
{
|
||||||
|
'version_id': version_id,
|
||||||
|
'model_id': civitai.get('modelId'),
|
||||||
|
'file_path': item.get('file_path'),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
history_service = await ServiceRegistry.get_downloaded_version_history_service()
|
||||||
|
await history_service.mark_downloaded_bulk(
|
||||||
|
self.model_type,
|
||||||
|
records,
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"%s Scanner: Failed to sync download history: %s",
|
||||||
|
self.model_type.capitalize(),
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
async def _gather_model_data(
|
async def _gather_model_data(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|||||||
@@ -167,6 +167,28 @@ class ServiceRegistry:
|
|||||||
logger.debug(f"Created and registered {service_name}")
|
logger.debug(f"Created and registered {service_name}")
|
||||||
return service
|
return service
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_downloaded_version_history_service(cls):
|
||||||
|
"""Get or create the downloaded-version history service."""
|
||||||
|
|
||||||
|
service_name = "downloaded_version_history_service"
|
||||||
|
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
from .downloaded_version_history_service import (
|
||||||
|
DownloadedVersionHistoryService,
|
||||||
|
)
|
||||||
|
|
||||||
|
service = DownloadedVersionHistoryService()
|
||||||
|
cls._services[service_name] = service
|
||||||
|
logger.debug(f"Created and registered {service_name}")
|
||||||
|
return service
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_civarchive_client(cls):
|
async def get_civarchive_client(cls):
|
||||||
"""Get or create CivArchive client instance"""
|
"""Get or create CivArchive client instance"""
|
||||||
@@ -255,4 +277,4 @@ class ServiceRegistry:
|
|||||||
"""Clear all registered services - mainly for testing"""
|
"""Clear all registered services - mainly for testing"""
|
||||||
cls._services.clear()
|
cls._services.clear()
|
||||||
cls._locks.clear()
|
cls._locks.clear()
|
||||||
logger.info("Cleared all registered services")
|
logger.info("Cleared all registered services")
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ DEFAULT_SETTINGS: Dict[str, Any] = {
|
|||||||
"update_flag_strategy": "same_base",
|
"update_flag_strategy": "same_base",
|
||||||
"auto_organize_exclusions": [],
|
"auto_organize_exclusions": [],
|
||||||
"metadata_refresh_skip_paths": [],
|
"metadata_refresh_skip_paths": [],
|
||||||
|
"skip_previously_downloaded_model_versions": False,
|
||||||
"download_skip_base_models": [],
|
"download_skip_base_models": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,6 +315,10 @@ class SettingsManager:
|
|||||||
self.settings["download_skip_base_models"] = []
|
self.settings["download_skip_base_models"] = []
|
||||||
inserted_defaults = True
|
inserted_defaults = True
|
||||||
|
|
||||||
|
if "skip_previously_downloaded_model_versions" not in self.settings:
|
||||||
|
self.settings["skip_previously_downloaded_model_versions"] = False
|
||||||
|
inserted_defaults = True
|
||||||
|
|
||||||
had_mature_level = "mature_blur_level" in self.settings
|
had_mature_level = "mature_blur_level" in self.settings
|
||||||
raw_mature_level = self.settings.get("mature_blur_level")
|
raw_mature_level = self.settings.get("mature_blur_level")
|
||||||
normalized_mature_level = self.normalize_mature_blur_level(raw_mature_level)
|
normalized_mature_level = self.normalize_mature_blur_level(raw_mature_level)
|
||||||
@@ -1090,6 +1095,17 @@ class SettingsManager:
|
|||||||
self._save_settings()
|
self._save_settings()
|
||||||
return base_models
|
return base_models
|
||||||
|
|
||||||
|
def get_skip_previously_downloaded_model_versions(self) -> bool:
|
||||||
|
value = self.settings.get("skip_previously_downloaded_model_versions", False)
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
normalized = False
|
||||||
|
if isinstance(value, str):
|
||||||
|
normalized = value.strip().lower() in {"1", "true", "yes", "on"}
|
||||||
|
self.settings["skip_previously_downloaded_model_versions"] = normalized
|
||||||
|
self._save_settings()
|
||||||
|
return normalized
|
||||||
|
|
||||||
def get_extra_folder_paths(self) -> Dict[str, List[str]]:
|
def get_extra_folder_paths(self) -> Dict[str, List[str]]:
|
||||||
"""Get extra folder paths for the active library.
|
"""Get extra folder paths for the active library.
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui-lora-manager"
|
name = "comfyui-lora-manager"
|
||||||
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
||||||
version = "1.0.1"
|
version = "1.0.2"
|
||||||
license = {file = "LICENSE"}
|
license = {file = "LICENSE"}
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
|
|||||||
@@ -146,6 +146,10 @@ export class SettingsManager {
|
|||||||
backendSettings?.metadata_refresh_skip_paths ?? defaults.metadata_refresh_skip_paths
|
backendSettings?.metadata_refresh_skip_paths ?? defaults.metadata_refresh_skip_paths
|
||||||
);
|
);
|
||||||
|
|
||||||
|
merged.skip_previously_downloaded_model_versions =
|
||||||
|
backendSettings?.skip_previously_downloaded_model_versions
|
||||||
|
?? defaults.skip_previously_downloaded_model_versions;
|
||||||
|
|
||||||
merged.download_skip_base_models = this.normalizeDownloadSkipBaseModels(
|
merged.download_skip_base_models = this.normalizeDownloadSkipBaseModels(
|
||||||
backendSettings?.download_skip_base_models ?? defaults.download_skip_base_models
|
backendSettings?.download_skip_base_models ?? defaults.download_skip_base_models
|
||||||
);
|
);
|
||||||
@@ -836,6 +840,12 @@ export class SettingsManager {
|
|||||||
hideEarlyAccessUpdatesCheckbox.checked = state.global.settings.hide_early_access_updates || false;
|
hideEarlyAccessUpdatesCheckbox.checked = state.global.settings.hide_early_access_updates || false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const skipPreviouslyDownloadedModelVersionsCheckbox = document.getElementById('skipPreviouslyDownloadedModelVersions');
|
||||||
|
if (skipPreviouslyDownloadedModelVersionsCheckbox) {
|
||||||
|
skipPreviouslyDownloadedModelVersionsCheckbox.checked =
|
||||||
|
state.global.settings.skip_previously_downloaded_model_versions || false;
|
||||||
|
}
|
||||||
|
|
||||||
// Set optimize example images setting
|
// Set optimize example images setting
|
||||||
const optimizeExampleImagesCheckbox = document.getElementById('optimizeExampleImages');
|
const optimizeExampleImagesCheckbox = document.getElementById('optimizeExampleImages');
|
||||||
if (optimizeExampleImagesCheckbox) {
|
if (optimizeExampleImagesCheckbox) {
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ const DEFAULT_SETTINGS_BASE = Object.freeze({
|
|||||||
hide_early_access_updates: false,
|
hide_early_access_updates: false,
|
||||||
auto_organize_exclusions: [],
|
auto_organize_exclusions: [],
|
||||||
metadata_refresh_skip_paths: [],
|
metadata_refresh_skip_paths: [],
|
||||||
|
skip_previously_downloaded_model_versions: false,
|
||||||
download_skip_base_models: [],
|
download_skip_base_models: [],
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -735,6 +735,24 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div class="setting-item">
|
||||||
|
<div class="setting-row">
|
||||||
|
<div class="setting-info">
|
||||||
|
<label for="skipPreviouslyDownloadedModelVersions">
|
||||||
|
{{ t('settings.skipPreviouslyDownloadedModelVersions.label') }}
|
||||||
|
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.skipPreviouslyDownloadedModelVersions.help') }}"></i>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<div class="setting-control">
|
||||||
|
<label class="toggle-switch">
|
||||||
|
<input type="checkbox" id="skipPreviouslyDownloadedModelVersions"
|
||||||
|
onchange="settingsManager.saveToggleSetting('skipPreviouslyDownloadedModelVersions', 'skip_previously_downloaded_model_versions')">
|
||||||
|
<span class="toggle-slider"></span>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div class="setting-item">
|
<div class="setting-item">
|
||||||
<div class="setting-row">
|
<div class="setting-row">
|
||||||
<div class="setting-info">
|
<div class="setting-info">
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ vi.mock('../../../static/js/state/index.js', () => {
|
|||||||
},
|
},
|
||||||
createDefaultSettings: () => ({
|
createDefaultSettings: () => ({
|
||||||
language: 'en',
|
language: 'en',
|
||||||
|
skip_previously_downloaded_model_versions: false,
|
||||||
download_skip_base_models: [],
|
download_skip_base_models: [],
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
@@ -117,6 +118,7 @@ describe('SettingsManager download skip base models UI', () => {
|
|||||||
document.body.innerHTML = '';
|
document.body.innerHTML = '';
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
state.global.settings = {
|
state.global.settings = {
|
||||||
|
skip_previously_downloaded_model_versions: false,
|
||||||
download_skip_base_models: [],
|
download_skip_base_models: [],
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
@@ -150,4 +152,31 @@ describe('SettingsManager download skip base models UI', () => {
|
|||||||
expect(document.querySelectorAll('#downloadSkipBaseModelsContainer input')).toHaveLength(0);
|
expect(document.querySelectorAll('#downloadSkipBaseModelsContainer input')).toHaveLength(0);
|
||||||
expect(document.getElementById('downloadSkipBaseModelsEmpty').hidden).toBe(false);
|
expect(document.getElementById('downloadSkipBaseModelsEmpty').hidden).toBe(false);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('initializes the previously-downloaded-version toggle from settings', () => {
|
||||||
|
document.body.innerHTML = '<input id="skipPreviouslyDownloadedModelVersions" type="checkbox" />';
|
||||||
|
state.global.settings.skip_previously_downloaded_model_versions = true;
|
||||||
|
const manager = createManager();
|
||||||
|
|
||||||
|
manager.loadSettingsToUI();
|
||||||
|
|
||||||
|
expect(document.getElementById('skipPreviouslyDownloadedModelVersions').checked).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('saves the previously-downloaded-version toggle with the expected setting key', async () => {
|
||||||
|
document.body.innerHTML = '<input id="skipPreviouslyDownloadedModelVersions" type="checkbox" checked />';
|
||||||
|
const manager = createManager();
|
||||||
|
manager.saveSetting = vi.fn().mockResolvedValue();
|
||||||
|
manager.applyFrontendSettings = vi.fn();
|
||||||
|
|
||||||
|
await manager.saveToggleSetting(
|
||||||
|
'skipPreviouslyDownloadedModelVersions',
|
||||||
|
'skip_previously_downloaded_model_versions',
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(manager.saveSetting).toHaveBeenCalledWith(
|
||||||
|
'skip_previously_downloaded_model_versions',
|
||||||
|
true,
|
||||||
|
);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
151
tests/frontend/utils/loraChainTraversal.test.js
Normal file
151
tests/frontend/utils/loraChainTraversal.test.js
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||||
|
|
||||||
|
const { APP_MODULE, UTILS_MODULE } = vi.hoisted(() => ({
|
||||||
|
APP_MODULE: new URL("../../../scripts/app.js", import.meta.url).pathname,
|
||||||
|
UTILS_MODULE: new URL("../../../web/comfyui/utils.js", import.meta.url).pathname,
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock(APP_MODULE, () => ({
|
||||||
|
app: {
|
||||||
|
graph: null,
|
||||||
|
registerExtension: vi.fn(),
|
||||||
|
ui: {
|
||||||
|
settings: {
|
||||||
|
getSettingValue: vi.fn(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
describe("LoRA chain traversal", () => {
|
||||||
|
let collectActiveLorasFromChain;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
vi.resetModules();
|
||||||
|
({ collectActiveLorasFromChain } = await import(UTILS_MODULE));
|
||||||
|
});
|
||||||
|
|
||||||
|
function createGraph(nodes, links) {
|
||||||
|
const graph = {
|
||||||
|
_nodes: nodes,
|
||||||
|
links,
|
||||||
|
getNodeById(id) {
|
||||||
|
return nodes.find((node) => node.id === id) ?? null;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
nodes.forEach((node) => {
|
||||||
|
node.graph = graph;
|
||||||
|
});
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
it("aggregates active LoRAs through a combiner with multiple LORA_STACK inputs", () => {
|
||||||
|
const randomizerA = {
|
||||||
|
id: 1,
|
||||||
|
comfyClass: "Lora Randomizer (LoraManager)",
|
||||||
|
mode: 0,
|
||||||
|
widgets: [
|
||||||
|
{
|
||||||
|
name: "loras",
|
||||||
|
value: [
|
||||||
|
{ name: "Alpha", active: true },
|
||||||
|
{ name: "Ignored", active: false },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
inputs: [],
|
||||||
|
outputs: [],
|
||||||
|
};
|
||||||
|
const randomizerB = {
|
||||||
|
id: 2,
|
||||||
|
comfyClass: "Lora Randomizer (LoraManager)",
|
||||||
|
mode: 0,
|
||||||
|
widgets: [
|
||||||
|
{
|
||||||
|
name: "loras",
|
||||||
|
value: [{ name: "Beta", active: true }],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
inputs: [],
|
||||||
|
outputs: [],
|
||||||
|
};
|
||||||
|
const combiner = {
|
||||||
|
id: 3,
|
||||||
|
comfyClass: "Lora Stack Combiner (LoraManager)",
|
||||||
|
mode: 0,
|
||||||
|
widgets: [],
|
||||||
|
inputs: [
|
||||||
|
{ name: "lora_stack_a", type: "LORA_STACK", link: 11 },
|
||||||
|
{ name: "lora_stack_b", type: "LORA_STACK", link: 12 },
|
||||||
|
],
|
||||||
|
outputs: [],
|
||||||
|
};
|
||||||
|
const loader = {
|
||||||
|
id: 4,
|
||||||
|
comfyClass: "Lora Loader (LoraManager)",
|
||||||
|
mode: 0,
|
||||||
|
widgets: [],
|
||||||
|
inputs: [{ name: "lora_stack", type: "LORA_STACK", link: 13 }],
|
||||||
|
outputs: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
createGraph(
|
||||||
|
[randomizerA, randomizerB, combiner, loader],
|
||||||
|
{
|
||||||
|
11: { origin_id: 1, target_id: 3 },
|
||||||
|
12: { origin_id: 2, target_id: 3 },
|
||||||
|
13: { origin_id: 3, target_id: 4 },
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const result = collectActiveLorasFromChain(loader);
|
||||||
|
|
||||||
|
expect([...result]).toEqual(["Alpha", "Beta"]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("stops propagation when the combiner is inactive", () => {
|
||||||
|
const randomizer = {
|
||||||
|
id: 1,
|
||||||
|
comfyClass: "Lora Randomizer (LoraManager)",
|
||||||
|
mode: 0,
|
||||||
|
widgets: [
|
||||||
|
{
|
||||||
|
name: "loras",
|
||||||
|
value: [{ name: "Alpha", active: true }],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
inputs: [],
|
||||||
|
outputs: [],
|
||||||
|
};
|
||||||
|
const combiner = {
|
||||||
|
id: 2,
|
||||||
|
comfyClass: "Lora Stack Combiner (LoraManager)",
|
||||||
|
mode: 2,
|
||||||
|
widgets: [],
|
||||||
|
inputs: [{ name: "lora_stack_a", type: "LORA_STACK", link: 21 }],
|
||||||
|
outputs: [],
|
||||||
|
};
|
||||||
|
const loader = {
|
||||||
|
id: 3,
|
||||||
|
comfyClass: "Lora Loader (LoraManager)",
|
||||||
|
mode: 0,
|
||||||
|
widgets: [],
|
||||||
|
inputs: [{ name: "lora_stack", type: "LORA_STACK", link: 22 }],
|
||||||
|
outputs: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
createGraph(
|
||||||
|
[randomizer, combiner, loader],
|
||||||
|
{
|
||||||
|
21: { origin_id: 1, target_id: 2 },
|
||||||
|
22: { origin_id: 2, target_id: 3 },
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const result = collectActiveLorasFromChain(loader);
|
||||||
|
|
||||||
|
expect(result.size).toBe(0);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: TestModelLibraryHandlerSnapshots.test_check_model_exists_empty_response
|
# name: TestModelLibraryHandlerSnapshots.test_check_model_exists_empty_response
|
||||||
dict({
|
dict({
|
||||||
|
'downloadedVersionIds': list([
|
||||||
|
]),
|
||||||
'modelType': None,
|
'modelType': None,
|
||||||
'success': True,
|
'success': True,
|
||||||
'versions': list([
|
'versions': list([
|
||||||
|
|||||||
@@ -66,6 +66,27 @@ class FakePromptServer:
|
|||||||
instance = Instance()
|
instance = Instance()
|
||||||
|
|
||||||
|
|
||||||
|
class FakeDownloadHistoryService:
|
||||||
|
async def has_been_downloaded(self, _model_type, _version_id):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids(self, _model_type, _model_id):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids_bulk(self, _model_type, _model_ids):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mark_downloaded(self, *_args, **_kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def mark_not_downloaded(self, *_args, **_kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def fake_download_history_service_factory():
|
||||||
|
return FakeDownloadHistoryService()
|
||||||
|
|
||||||
|
|
||||||
class TestSettingsHandlerSnapshots:
|
class TestSettingsHandlerSnapshots:
|
||||||
"""Snapshot tests for SettingsHandler responses."""
|
"""Snapshot tests for SettingsHandler responses."""
|
||||||
|
|
||||||
@@ -223,6 +244,7 @@ class TestModelLibraryHandlerSnapshots:
|
|||||||
get_lora_scanner=scanner_factory,
|
get_lora_scanner=scanner_factory,
|
||||||
get_checkpoint_scanner=scanner_factory,
|
get_checkpoint_scanner=scanner_factory,
|
||||||
get_embedding_scanner=scanner_factory,
|
get_embedding_scanner=scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=lambda: None,
|
metadata_provider_factory=lambda: None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,9 +23,10 @@ from py.routes.misc_routes import MiscRoutes
|
|||||||
|
|
||||||
|
|
||||||
class FakeRequest:
|
class FakeRequest:
|
||||||
def __init__(self, *, json_data=None, query=None):
|
def __init__(self, *, json_data=None, query=None, method="POST"):
|
||||||
self._json_data = json_data or {}
|
self._json_data = json_data or {}
|
||||||
self.query = query or {}
|
self.query = query or {}
|
||||||
|
self.method = method
|
||||||
|
|
||||||
async def json(self):
|
async def json(self):
|
||||||
return self._json_data
|
return self._json_data
|
||||||
@@ -438,6 +439,46 @@ async def fake_metadata_archive_manager_factory():
|
|||||||
return FakeMetadataArchiveManager()
|
return FakeMetadataArchiveManager()
|
||||||
|
|
||||||
|
|
||||||
|
class FakeDownloadHistoryService:
|
||||||
|
def __init__(self, downloaded_by_type=None):
|
||||||
|
self.downloaded_by_type = downloaded_by_type or {}
|
||||||
|
self.marked_downloaded: list[tuple] = []
|
||||||
|
self.marked_not_downloaded: list[tuple] = []
|
||||||
|
|
||||||
|
async def has_been_downloaded(self, model_type, version_id):
|
||||||
|
return version_id in self.downloaded_by_type.get(model_type, set())
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids(self, model_type, model_id):
|
||||||
|
entries = self.downloaded_by_type.get(model_type, {})
|
||||||
|
if isinstance(entries, dict):
|
||||||
|
return sorted(entries.get(model_id, set()))
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_downloaded_version_ids_bulk(self, model_type, model_ids):
|
||||||
|
entries = self.downloaded_by_type.get(model_type, {})
|
||||||
|
if not isinstance(entries, dict):
|
||||||
|
return {}
|
||||||
|
return {
|
||||||
|
model_id: set(entries.get(model_id, set()))
|
||||||
|
for model_id in model_ids
|
||||||
|
if model_id in entries
|
||||||
|
}
|
||||||
|
|
||||||
|
async def mark_downloaded(
|
||||||
|
self, model_type, version_id, *, model_id=None, source="manual", file_path=None
|
||||||
|
):
|
||||||
|
self.marked_downloaded.append(
|
||||||
|
(model_type, version_id, model_id, source, file_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mark_not_downloaded(self, model_type, version_id):
|
||||||
|
self.marked_not_downloaded.append((model_type, version_id))
|
||||||
|
|
||||||
|
|
||||||
|
async def fake_download_history_service_factory():
|
||||||
|
return FakeDownloadHistoryService()
|
||||||
|
|
||||||
|
|
||||||
class RecordingRegistrar:
|
class RecordingRegistrar:
|
||||||
def __init__(self, _app):
|
def __init__(self, _app):
|
||||||
self.registered_mapping = None
|
self.registered_mapping = None
|
||||||
@@ -452,6 +493,7 @@ async def test_misc_routes_bind_produces_expected_handlers():
|
|||||||
get_lora_scanner=fake_scanner_factory,
|
get_lora_scanner=fake_scanner_factory,
|
||||||
get_checkpoint_scanner=fake_scanner_factory,
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
get_embedding_scanner=fake_scanner_factory,
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
recorded_registrars = []
|
recorded_registrars = []
|
||||||
@@ -578,6 +620,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
get_lora_scanner=lora_factory,
|
get_lora_scanner=lora_factory,
|
||||||
get_checkpoint_scanner=checkpoint_factory,
|
get_checkpoint_scanner=checkpoint_factory,
|
||||||
get_embedding_scanner=embedding_factory,
|
get_embedding_scanner=embedding_factory,
|
||||||
|
get_downloaded_version_history_service=lambda: fake_download_history_service_factory(),
|
||||||
),
|
),
|
||||||
metadata_provider_factory=provider_factory,
|
metadata_provider_factory=provider_factory,
|
||||||
)
|
)
|
||||||
@@ -600,6 +643,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
"baseModel": "Flux.1",
|
"baseModel": "Flux.1",
|
||||||
"thumbnailUrl": "http://example.com/a1.jpg",
|
"thumbnailUrl": "http://example.com/a1.jpg",
|
||||||
"inLibrary": False,
|
"inLibrary": False,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"modelId": 1,
|
"modelId": 1,
|
||||||
@@ -611,6 +655,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
"baseModel": "Flux.1",
|
"baseModel": "Flux.1",
|
||||||
"thumbnailUrl": "http://example.com/a2.jpg",
|
"thumbnailUrl": "http://example.com/a2.jpg",
|
||||||
"inLibrary": True,
|
"inLibrary": True,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"modelId": 2,
|
"modelId": 2,
|
||||||
@@ -622,6 +667,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
"baseModel": None,
|
"baseModel": None,
|
||||||
"thumbnailUrl": "http://example.com/e1.jpg",
|
"thumbnailUrl": "http://example.com/e1.jpg",
|
||||||
"inLibrary": False,
|
"inLibrary": False,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"modelId": 2,
|
"modelId": 2,
|
||||||
@@ -633,6 +679,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
"baseModel": None,
|
"baseModel": None,
|
||||||
"thumbnailUrl": None,
|
"thumbnailUrl": None,
|
||||||
"inLibrary": True,
|
"inLibrary": True,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"modelId": 3,
|
"modelId": 3,
|
||||||
@@ -644,6 +691,7 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
"baseModel": "SDXL",
|
"baseModel": "SDXL",
|
||||||
"thumbnailUrl": None,
|
"thumbnailUrl": None,
|
||||||
"inLibrary": False,
|
"inLibrary": False,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -692,6 +740,7 @@ async def test_get_civitai_user_models_rewrites_civitai_previews():
|
|||||||
get_lora_scanner=fake_scanner_factory,
|
get_lora_scanner=fake_scanner_factory,
|
||||||
get_checkpoint_scanner=fake_scanner_factory,
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
get_embedding_scanner=fake_scanner_factory,
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=provider_factory,
|
metadata_provider_factory=provider_factory,
|
||||||
)
|
)
|
||||||
@@ -727,6 +776,7 @@ async def test_get_civitai_user_models_requires_username():
|
|||||||
get_lora_scanner=fake_scanner_factory,
|
get_lora_scanner=fake_scanner_factory,
|
||||||
get_checkpoint_scanner=fake_scanner_factory,
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
get_embedding_scanner=fake_scanner_factory,
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=provider_factory,
|
metadata_provider_factory=provider_factory,
|
||||||
)
|
)
|
||||||
@@ -760,6 +810,7 @@ def test_ensure_handler_mapping_caches_result():
|
|||||||
get_lora_scanner=fake_scanner_factory,
|
get_lora_scanner=fake_scanner_factory,
|
||||||
get_checkpoint_scanner=fake_scanner_factory,
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
get_embedding_scanner=fake_scanner_factory,
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=fake_metadata_provider_factory,
|
metadata_provider_factory=fake_metadata_provider_factory,
|
||||||
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
|
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
|
||||||
@@ -802,6 +853,7 @@ async def test_check_model_exists_returns_local_versions():
|
|||||||
get_lora_scanner=lora_factory,
|
get_lora_scanner=lora_factory,
|
||||||
get_checkpoint_scanner=checkpoint_factory,
|
get_checkpoint_scanner=checkpoint_factory,
|
||||||
get_embedding_scanner=embedding_factory,
|
get_embedding_scanner=embedding_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=fake_metadata_provider_factory,
|
metadata_provider_factory=fake_metadata_provider_factory,
|
||||||
)
|
)
|
||||||
@@ -811,10 +863,139 @@ async def test_check_model_exists_returns_local_versions():
|
|||||||
|
|
||||||
assert payload["success"] is True
|
assert payload["success"] is True
|
||||||
assert payload["modelType"] == "lora"
|
assert payload["modelType"] == "lora"
|
||||||
assert payload["versions"] == versions
|
assert payload["versions"] == [
|
||||||
|
{"versionId": 11, "name": "v1", "fileName": "model-one", "hasBeenDownloaded": True},
|
||||||
|
{"versionId": 12, "name": "v2", "fileName": "model-two", "hasBeenDownloaded": True},
|
||||||
|
]
|
||||||
assert lora_scanner.version_calls == [5]
|
assert lora_scanner.version_calls == [5]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_model_exists_model_id_only_does_not_call_metadata_provider():
|
||||||
|
async def metadata_provider_factory():
|
||||||
|
raise AssertionError("metadata provider should not be called for modelId-only checks")
|
||||||
|
|
||||||
|
handler = ModelLibraryHandler(
|
||||||
|
ServiceRegistryAdapter(
|
||||||
|
get_lora_scanner=fake_scanner_factory,
|
||||||
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
|
),
|
||||||
|
metadata_provider_factory=metadata_provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.check_model_exists(FakeRequest(query={"modelId": "5"}))
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload == {
|
||||||
|
"success": True,
|
||||||
|
"modelType": None,
|
||||||
|
"versions": [],
|
||||||
|
"downloadedVersionIds": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_model_exists_returns_download_history_when_file_missing():
|
||||||
|
history_service = FakeDownloadHistoryService({"checkpoint": {999}})
|
||||||
|
|
||||||
|
async def history_factory():
|
||||||
|
return history_service
|
||||||
|
|
||||||
|
handler = ModelLibraryHandler(
|
||||||
|
ServiceRegistryAdapter(
|
||||||
|
get_lora_scanner=fake_scanner_factory,
|
||||||
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=history_factory,
|
||||||
|
),
|
||||||
|
metadata_provider_factory=fake_metadata_provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.check_model_exists(
|
||||||
|
FakeRequest(query={"modelId": "5", "modelVersionId": "999"})
|
||||||
|
)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload == {
|
||||||
|
"success": True,
|
||||||
|
"exists": False,
|
||||||
|
"modelType": "checkpoint",
|
||||||
|
"hasBeenDownloaded": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_version_download_status_endpoints():
|
||||||
|
history_service = FakeDownloadHistoryService({"lora": {123}})
|
||||||
|
|
||||||
|
async def history_factory():
|
||||||
|
return history_service
|
||||||
|
|
||||||
|
handler = ModelLibraryHandler(
|
||||||
|
ServiceRegistryAdapter(
|
||||||
|
get_lora_scanner=fake_scanner_factory,
|
||||||
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=history_factory,
|
||||||
|
),
|
||||||
|
metadata_provider_factory=fake_metadata_provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
get_response = await handler.get_model_version_download_status(
|
||||||
|
FakeRequest(query={"modelType": "lora", "modelVersionId": "123"})
|
||||||
|
)
|
||||||
|
get_payload = json.loads(get_response.text)
|
||||||
|
assert get_payload == {
|
||||||
|
"success": True,
|
||||||
|
"modelType": "lora",
|
||||||
|
"modelVersionId": 123,
|
||||||
|
"hasBeenDownloaded": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
set_response = await handler.set_model_version_download_status(
|
||||||
|
FakeRequest(
|
||||||
|
json_data={
|
||||||
|
"modelType": "checkpoint",
|
||||||
|
"modelVersionId": 456,
|
||||||
|
"modelId": 78,
|
||||||
|
"downloaded": True,
|
||||||
|
"filePath": "/tmp/model.safetensors",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
set_payload = json.loads(set_response.text)
|
||||||
|
assert set_payload == {
|
||||||
|
"success": True,
|
||||||
|
"modelType": "checkpoint",
|
||||||
|
"modelVersionId": 456,
|
||||||
|
"hasBeenDownloaded": True,
|
||||||
|
}
|
||||||
|
assert history_service.marked_downloaded == [
|
||||||
|
("checkpoint", 456, 78, "manual", "/tmp/model.safetensors")
|
||||||
|
]
|
||||||
|
|
||||||
|
set_get_response = await handler.set_model_version_download_status(
|
||||||
|
FakeRequest(
|
||||||
|
method="GET",
|
||||||
|
query={
|
||||||
|
"modelType": "embedding",
|
||||||
|
"modelVersionId": "789",
|
||||||
|
"modelId": "12",
|
||||||
|
"downloaded": "false",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
set_get_payload = json.loads(set_get_response.text)
|
||||||
|
assert set_get_payload == {
|
||||||
|
"success": True,
|
||||||
|
"modelType": "embedding",
|
||||||
|
"modelVersionId": 789,
|
||||||
|
"hasBeenDownloaded": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_create_handler_set_uses_provided_dependencies():
|
def test_create_handler_set_uses_provided_dependencies():
|
||||||
recorded_handlers: list[dict] = []
|
recorded_handlers: list[dict] = []
|
||||||
|
|
||||||
@@ -845,6 +1026,7 @@ def test_create_handler_set_uses_provided_dependencies():
|
|||||||
get_lora_scanner=fake_scanner_factory,
|
get_lora_scanner=fake_scanner_factory,
|
||||||
get_checkpoint_scanner=fake_scanner_factory,
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
get_embedding_scanner=fake_scanner_factory,
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
get_downloaded_version_history_service=fake_download_history_service_factory,
|
||||||
),
|
),
|
||||||
metadata_provider_factory=fake_metadata_provider_factory,
|
metadata_provider_factory=fake_metadata_provider_factory,
|
||||||
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
|
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ def isolate_settings(monkeypatch, tmp_path):
|
|||||||
"embedding": "{base_model}/{first_tag}",
|
"embedding": "{base_model}/{first_tag}",
|
||||||
},
|
},
|
||||||
"base_model_path_mappings": {"BaseModel": "MappedModel"},
|
"base_model_path_mappings": {"BaseModel": "MappedModel"},
|
||||||
|
"skip_previously_downloaded_model_versions": False,
|
||||||
"download_skip_base_models": [],
|
"download_skip_base_models": [],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -454,7 +455,7 @@ async def test_download_skips_excluded_base_model(monkeypatch, scanners, metadat
|
|||||||
|
|
||||||
metadata_provider.get_model_version = AsyncMock(
|
metadata_provider.get_model_version = AsyncMock(
|
||||||
return_value={
|
return_value={
|
||||||
"id": 42,
|
"id": 99,
|
||||||
"model": {"type": "LoRA", "tags": ["fantasy"]},
|
"model": {"type": "LoRA", "tags": ["fantasy"]},
|
||||||
"baseModel": "SDXL 1.0",
|
"baseModel": "SDXL 1.0",
|
||||||
"creator": {"username": "Author"},
|
"creator": {"username": "Author"},
|
||||||
@@ -490,3 +491,104 @@ async def test_download_skips_excluded_base_model(monkeypatch, scanners, metadat
|
|||||||
assert "file.safetensors" in result["message"]
|
assert "file.safetensors" in result["message"]
|
||||||
execute_download.assert_not_called()
|
execute_download.assert_not_called()
|
||||||
assert manager._active_downloads[result["download_id"]]["status"] == "skipped"
|
assert manager._active_downloads[result["download_id"]]["status"] == "skipped"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_skips_previously_downloaded_version(monkeypatch, scanners, metadata_provider):
|
||||||
|
manager = DownloadManager()
|
||||||
|
get_settings_manager().settings["skip_previously_downloaded_model_versions"] = True
|
||||||
|
|
||||||
|
metadata_provider.get_model_version = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": 42,
|
||||||
|
"model": {"type": "LoRA", "tags": ["fantasy"]},
|
||||||
|
"baseModel": "SDXL 1.0",
|
||||||
|
"creator": {"username": "Author"},
|
||||||
|
"files": [
|
||||||
|
{
|
||||||
|
"type": "Model",
|
||||||
|
"primary": True,
|
||||||
|
"downloadUrl": "https://example.invalid/file.safetensors",
|
||||||
|
"name": "file.safetensors",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
history_service = AsyncMock()
|
||||||
|
history_service.has_been_downloaded = AsyncMock(return_value=True)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ServiceRegistry,
|
||||||
|
"get_downloaded_version_history_service",
|
||||||
|
AsyncMock(return_value=history_service),
|
||||||
|
)
|
||||||
|
|
||||||
|
execute_download = AsyncMock()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager, "_execute_download", execute_download, raising=False
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.download_from_civitai(
|
||||||
|
model_version_id=99,
|
||||||
|
use_default_paths=True,
|
||||||
|
progress_callback=None,
|
||||||
|
source=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["skipped"] is True
|
||||||
|
assert result["status"] == "skipped"
|
||||||
|
assert result["reason"] == "previously_downloaded_version"
|
||||||
|
assert result["model_version_id"] == 99
|
||||||
|
assert result["file_name"] == "file.safetensors"
|
||||||
|
history_service.has_been_downloaded.assert_awaited_once_with("lora", 99)
|
||||||
|
execute_download.assert_not_called()
|
||||||
|
assert manager._active_downloads[result["download_id"]]["status"] == "skipped"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_proceeds_when_history_skip_disabled(monkeypatch, scanners, metadata_provider):
|
||||||
|
manager = DownloadManager()
|
||||||
|
get_settings_manager().settings["skip_previously_downloaded_model_versions"] = False
|
||||||
|
|
||||||
|
metadata_provider.get_model_version = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"id": 42,
|
||||||
|
"model": {"type": "LoRA", "tags": ["fantasy"]},
|
||||||
|
"baseModel": "SDXL 1.0",
|
||||||
|
"creator": {"username": "Author"},
|
||||||
|
"files": [
|
||||||
|
{
|
||||||
|
"type": "Model",
|
||||||
|
"primary": True,
|
||||||
|
"downloadUrl": "https://example.invalid/file.safetensors",
|
||||||
|
"name": "file.safetensors",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
history_service = AsyncMock()
|
||||||
|
history_service.has_been_downloaded = AsyncMock(return_value=True)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
ServiceRegistry,
|
||||||
|
"get_downloaded_version_history_service",
|
||||||
|
AsyncMock(return_value=history_service),
|
||||||
|
)
|
||||||
|
|
||||||
|
execute_download = AsyncMock(return_value={"success": True, "download_id": "done"})
|
||||||
|
monkeypatch.setattr(
|
||||||
|
DownloadManager, "_execute_download", execute_download, raising=False
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.download_from_civitai(
|
||||||
|
model_version_id=99,
|
||||||
|
use_default_paths=True,
|
||||||
|
progress_callback=None,
|
||||||
|
source=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result.get("skipped") is not True
|
||||||
|
history_service.has_been_downloaded.assert_not_called()
|
||||||
|
execute_download.assert_awaited_once()
|
||||||
|
|||||||
70
tests/services/test_downloaded_version_history_service.py
Normal file
70
tests/services/test_downloaded_version_history_service.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.downloaded_version_history_service import (
|
||||||
|
DownloadedVersionHistoryService,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummySettings:
|
||||||
|
def get_active_library_name(self) -> str:
|
||||||
|
return "alpha"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_history_roundtrip_and_manual_override(tmp_path: Path) -> None:
|
||||||
|
db_path = tmp_path / "download-history.sqlite"
|
||||||
|
service = DownloadedVersionHistoryService(
|
||||||
|
str(db_path),
|
||||||
|
settings_manager=DummySettings(),
|
||||||
|
)
|
||||||
|
|
||||||
|
await service.mark_downloaded(
|
||||||
|
"lora",
|
||||||
|
101,
|
||||||
|
model_id=11,
|
||||||
|
source="scan",
|
||||||
|
file_path="/models/a.safetensors",
|
||||||
|
)
|
||||||
|
assert await service.has_been_downloaded("lora", 101) is True
|
||||||
|
assert await service.get_downloaded_version_ids("lora", 11) == [101]
|
||||||
|
|
||||||
|
await service.mark_not_downloaded("lora", 101)
|
||||||
|
assert await service.has_been_downloaded("lora", 101) is False
|
||||||
|
assert await service.get_downloaded_version_ids("lora", 11) == []
|
||||||
|
|
||||||
|
await service.mark_downloaded(
|
||||||
|
"lora",
|
||||||
|
101,
|
||||||
|
model_id=11,
|
||||||
|
source="download",
|
||||||
|
file_path="/models/a.safetensors",
|
||||||
|
)
|
||||||
|
assert await service.has_been_downloaded("lora", 101) is True
|
||||||
|
assert await service.get_downloaded_version_ids("lora", 11) == [101]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_history_bulk_lookup(tmp_path: Path) -> None:
|
||||||
|
db_path = tmp_path / "download-history.sqlite"
|
||||||
|
service = DownloadedVersionHistoryService(
|
||||||
|
str(db_path),
|
||||||
|
settings_manager=DummySettings(),
|
||||||
|
)
|
||||||
|
|
||||||
|
await service.mark_downloaded_bulk(
|
||||||
|
"checkpoint",
|
||||||
|
[
|
||||||
|
{"model_id": 5, "version_id": 501, "file_path": "/m/one.safetensors"},
|
||||||
|
{"model_id": 5, "version_id": 502, "file_path": "/m/two.safetensors"},
|
||||||
|
{"model_id": 6, "version_id": 601, "file_path": "/m/three.safetensors"},
|
||||||
|
],
|
||||||
|
source="scan",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await service.get_downloaded_version_ids("checkpoint", 5) == [501, 502]
|
||||||
|
assert await service.get_downloaded_version_ids_bulk("checkpoint", [5, 6, 7]) == {
|
||||||
|
5: {501, 502},
|
||||||
|
6: {601},
|
||||||
|
}
|
||||||
@@ -829,3 +829,14 @@ def test_setting_download_skip_base_models_normalizes_string_input(manager):
|
|||||||
manager.set("download_skip_base_models", "SDXL 1.0, Pony; Invalid\nSDXL 1.0")
|
manager.set("download_skip_base_models", "SDXL 1.0, Pony; Invalid\nSDXL 1.0")
|
||||||
|
|
||||||
assert manager.get("download_skip_base_models") == ["SDXL 1.0", "Pony"]
|
assert manager.get("download_skip_base_models") == ["SDXL 1.0", "Pony"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_skip_previously_downloaded_model_versions_defaults_false(manager):
|
||||||
|
assert manager.get_skip_previously_downloaded_model_versions() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_skip_previously_downloaded_model_versions_coerces_string_input(manager):
|
||||||
|
manager.settings["skip_previously_downloaded_model_versions"] = "true"
|
||||||
|
|
||||||
|
assert manager.get_skip_previously_downloaded_model_versions() is True
|
||||||
|
assert manager.settings["skip_previously_downloaded_model_versions"] is True
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import type { LoraPoolConfig, RandomizerConfig, CyclerConfig } from './composabl
|
|||||||
import {
|
import {
|
||||||
setupModeChangeHandler,
|
setupModeChangeHandler,
|
||||||
createModeChangeCallback,
|
createModeChangeCallback,
|
||||||
LORA_PROVIDER_NODE_TYPES
|
LORA_CHAIN_NODE_TYPES
|
||||||
} from './mode-change-handler'
|
} from './mode-change-handler'
|
||||||
|
|
||||||
const LORA_POOL_WIDGET_MIN_WIDTH = 500
|
const LORA_POOL_WIDGET_MIN_WIDTH = 500
|
||||||
@@ -755,8 +755,8 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register mode change handlers for LoRA provider nodes
|
// Register mode change handlers for LORA_STACK chain nodes
|
||||||
if (LORA_PROVIDER_NODE_TYPES.includes(comfyClass)) {
|
if (LORA_CHAIN_NODE_TYPES.includes(comfyClass)) {
|
||||||
const originalOnNodeCreated = nodeType.prototype.onNodeCreated
|
const originalOnNodeCreated = nodeType.prototype.onNodeCreated
|
||||||
|
|
||||||
nodeType.prototype.onNodeCreated = function () {
|
nodeType.prototype.onNodeCreated = function () {
|
||||||
|
|||||||
@@ -18,7 +18,22 @@ export const LORA_PROVIDER_NODE_TYPES = [
|
|||||||
"Lora Cycler (LoraManager)",
|
"Lora Cycler (LoraManager)",
|
||||||
] as const;
|
] as const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Nodes that do not own LoRA state themselves, but merge or forward LORA_STACK
|
||||||
|
* inputs so downstream trigger-word updates must traverse through them.
|
||||||
|
*/
|
||||||
|
export const LORA_STACK_AGGREGATOR_NODE_TYPES = [
|
||||||
|
"Lora Stack Combiner (LoraManager)",
|
||||||
|
] as const;
|
||||||
|
|
||||||
|
export const LORA_CHAIN_NODE_TYPES = [
|
||||||
|
...LORA_PROVIDER_NODE_TYPES,
|
||||||
|
...LORA_STACK_AGGREGATOR_NODE_TYPES,
|
||||||
|
] as const;
|
||||||
|
|
||||||
export type LoraProviderNodeType = typeof LORA_PROVIDER_NODE_TYPES[number];
|
export type LoraProviderNodeType = typeof LORA_PROVIDER_NODE_TYPES[number];
|
||||||
|
export type LoraStackAggregatorNodeType = typeof LORA_STACK_AGGREGATOR_NODE_TYPES[number];
|
||||||
|
export type LoraChainNodeType = typeof LORA_CHAIN_NODE_TYPES[number];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if a node class is a LoRA provider node.
|
* Check if a node class is a LoRA provider node.
|
||||||
@@ -27,6 +42,16 @@ export function isLoraProviderNode(comfyClass: string): comfyClass is LoraProvid
|
|||||||
return LORA_PROVIDER_NODE_TYPES.includes(comfyClass as LoraProviderNodeType);
|
return LORA_PROVIDER_NODE_TYPES.includes(comfyClass as LoraProviderNodeType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function isLoraStackAggregatorNode(
|
||||||
|
comfyClass: string
|
||||||
|
): comfyClass is LoraStackAggregatorNodeType {
|
||||||
|
return LORA_STACK_AGGREGATOR_NODE_TYPES.includes(comfyClass as LoraStackAggregatorNodeType);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isLoraChainNode(comfyClass: string): comfyClass is LoraChainNodeType {
|
||||||
|
return LORA_CHAIN_NODE_TYPES.includes(comfyClass as LoraChainNodeType);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Extract active LoRA filenames from a node based on its type.
|
* Extract active LoRA filenames from a node based on its type.
|
||||||
*
|
*
|
||||||
@@ -40,6 +65,10 @@ export function getActiveLorasFromNodeByType(node: any): Set<string> {
|
|||||||
return extractFromCyclerConfig(node);
|
return extractFromCyclerConfig(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isLoraStackAggregatorNode(comfyClass)) {
|
||||||
|
return new Set<string>();
|
||||||
|
}
|
||||||
|
|
||||||
// Default: use lorasWidget (works for Stacker and Randomizer)
|
// Default: use lorasWidget (works for Stacker and Randomizer)
|
||||||
return extractFromLorasWidget(node);
|
return extractFromLorasWidget(node);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,10 +10,27 @@ export const LORA_PROVIDER_NODE_TYPES = [
|
|||||||
"Lora Cycler (LoraManager)",
|
"Lora Cycler (LoraManager)",
|
||||||
];
|
];
|
||||||
|
|
||||||
|
export const LORA_STACK_AGGREGATOR_NODE_TYPES = [
|
||||||
|
"Lora Stack Combiner (LoraManager)",
|
||||||
|
];
|
||||||
|
|
||||||
|
export const LORA_CHAIN_NODE_TYPES = [
|
||||||
|
...LORA_PROVIDER_NODE_TYPES,
|
||||||
|
...LORA_STACK_AGGREGATOR_NODE_TYPES,
|
||||||
|
];
|
||||||
|
|
||||||
export function isLoraProviderNode(comfyClass) {
|
export function isLoraProviderNode(comfyClass) {
|
||||||
return LORA_PROVIDER_NODE_TYPES.includes(comfyClass);
|
return LORA_PROVIDER_NODE_TYPES.includes(comfyClass);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function isLoraStackAggregatorNode(comfyClass) {
|
||||||
|
return LORA_STACK_AGGREGATOR_NODE_TYPES.includes(comfyClass);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isLoraChainNode(comfyClass) {
|
||||||
|
return LORA_CHAIN_NODE_TYPES.includes(comfyClass);
|
||||||
|
}
|
||||||
|
|
||||||
function isMapLike(collection) {
|
function isMapLike(collection) {
|
||||||
return collection && typeof collection.entries === "function" && typeof collection.values === "function";
|
return collection && typeof collection.entries === "function" && typeof collection.values === "function";
|
||||||
}
|
}
|
||||||
@@ -245,16 +262,20 @@ export function hideWidgetForGood(node, widget, suffix = "") {
|
|||||||
// Update pattern to match both formats: <lora:name:model_strength> or <lora:name:model_strength:clip_strength>
|
// Update pattern to match both formats: <lora:name:model_strength> or <lora:name:model_strength:clip_strength>
|
||||||
export const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)(?::([-\d\.]+))?>/g;
|
export const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)(?::([-\d\.]+))?>/g;
|
||||||
|
|
||||||
// Get connected Lora Stacker nodes that feed into the current node
|
function isLoraStackInput(input) {
|
||||||
export function getConnectedInputStackers(node) {
|
return input?.type === "LORA_STACK";
|
||||||
const connectedStackers = [];
|
}
|
||||||
|
|
||||||
|
// Get connected LORA_STACK chain nodes that feed into the current node
|
||||||
|
export function getConnectedInputLoraChainNodes(node) {
|
||||||
|
const connectedNodes = [];
|
||||||
|
|
||||||
if (!node?.inputs) {
|
if (!node?.inputs) {
|
||||||
return connectedStackers;
|
return connectedNodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const input of node.inputs) {
|
for (const input of node.inputs) {
|
||||||
if (input.name !== "lora_stack" || !input.link) {
|
if (!isLoraStackInput(input) || !input.link) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,12 +285,12 @@ export function getConnectedInputStackers(node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
|
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
|
||||||
if (sourceNode && isLoraProviderNode(sourceNode.comfyClass)) {
|
if (sourceNode && isLoraChainNode(sourceNode.comfyClass)) {
|
||||||
connectedStackers.push(sourceNode);
|
connectedNodes.push(sourceNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return connectedStackers;
|
return connectedNodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connected TriggerWord Toggle nodes that receive output from the current node
|
// Get connected TriggerWord Toggle nodes that receive output from the current node
|
||||||
@@ -314,6 +335,11 @@ export function getActiveLorasFromNode(node) {
|
|||||||
return activeLoraNames;
|
return activeLoraNames;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Aggregator nodes do not own LoRA state directly; they only forward upstream stacks.
|
||||||
|
if (isLoraStackAggregatorNode(node.comfyClass)) {
|
||||||
|
return activeLoraNames;
|
||||||
|
}
|
||||||
|
|
||||||
// Handle Lora Stacker and Lora Randomizer (lorasWidget)
|
// Handle Lora Stacker and Lora Randomizer (lorasWidget)
|
||||||
let lorasWidget = node.lorasWidget;
|
let lorasWidget = node.lorasWidget;
|
||||||
if (!lorasWidget && node.widgets) {
|
if (!lorasWidget && node.widgets) {
|
||||||
@@ -348,14 +374,18 @@ export function collectActiveLorasFromChain(node, visited = new Set()) {
|
|||||||
// Mode 2 is Never, Mode 4 is Bypass
|
// Mode 2 is Never, Mode 4 is Bypass
|
||||||
const isNodeActive = node.mode === undefined || node.mode === 0 || node.mode === 3;
|
const isNodeActive = node.mode === undefined || node.mode === 0 || node.mode === 3;
|
||||||
|
|
||||||
|
if (!isNodeActive) {
|
||||||
|
return new Set();
|
||||||
|
}
|
||||||
|
|
||||||
// Get active loras from current node only if node is active
|
// Get active loras from current node only if node is active
|
||||||
const allActiveLoraNames = isNodeActive ? getActiveLorasFromNode(node) : new Set();
|
const allActiveLoraNames = getActiveLorasFromNode(node);
|
||||||
|
|
||||||
// Get connected input stackers and collect their active loras
|
// Get connected input LORA_STACK chain nodes and collect their active loras
|
||||||
const inputStackers = getConnectedInputStackers(node);
|
const inputChainNodes = getConnectedInputLoraChainNodes(node);
|
||||||
for (const stacker of inputStackers) {
|
for (const chainNode of inputChainNodes) {
|
||||||
const stackerLoras = collectActiveLorasFromChain(stacker, visited);
|
const upstreamLoras = collectActiveLorasFromChain(chainNode, visited);
|
||||||
stackerLoras.forEach(name => allActiveLoraNames.add(name));
|
upstreamLoras.forEach(name => allActiveLoraNames.add(name));
|
||||||
}
|
}
|
||||||
|
|
||||||
return allActiveLoraNames;
|
return allActiveLoraNames;
|
||||||
@@ -819,8 +849,8 @@ export function updateDownstreamLoaders(startNode, visited = new Set()) {
|
|||||||
collectActiveLorasFromChain(targetNode);
|
collectActiveLorasFromChain(targetNode);
|
||||||
updateConnectedTriggerWords(targetNode, allActiveLoraNames);
|
updateConnectedTriggerWords(targetNode, allActiveLoraNames);
|
||||||
}
|
}
|
||||||
// If target is another LoRA provider node, recursively check its outputs
|
// If target is another LORA_STACK chain node, recursively check its outputs
|
||||||
else if (targetNode && isLoraProviderNode(targetNode.comfyClass)) {
|
else if (targetNode && isLoraChainNode(targetNode.comfyClass)) {
|
||||||
updateDownstreamLoaders(targetNode, visited);
|
updateDownstreamLoaders(targetNode, visited);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14938,11 +14938,24 @@ const LORA_PROVIDER_NODE_TYPES$1 = [
|
|||||||
"Lora Randomizer (LoraManager)",
|
"Lora Randomizer (LoraManager)",
|
||||||
"Lora Cycler (LoraManager)"
|
"Lora Cycler (LoraManager)"
|
||||||
];
|
];
|
||||||
|
const LORA_STACK_AGGREGATOR_NODE_TYPES$1 = [
|
||||||
|
"Lora Stack Combiner (LoraManager)"
|
||||||
|
];
|
||||||
|
const LORA_CHAIN_NODE_TYPES$1 = [
|
||||||
|
...LORA_PROVIDER_NODE_TYPES$1,
|
||||||
|
...LORA_STACK_AGGREGATOR_NODE_TYPES$1
|
||||||
|
];
|
||||||
|
function isLoraStackAggregatorNode$1(comfyClass) {
|
||||||
|
return LORA_STACK_AGGREGATOR_NODE_TYPES$1.includes(comfyClass);
|
||||||
|
}
|
||||||
function getActiveLorasFromNodeByType(node) {
|
function getActiveLorasFromNodeByType(node) {
|
||||||
const comfyClass = node == null ? void 0 : node.comfyClass;
|
const comfyClass = node == null ? void 0 : node.comfyClass;
|
||||||
if (comfyClass === "Lora Cycler (LoraManager)") {
|
if (comfyClass === "Lora Cycler (LoraManager)") {
|
||||||
return extractFromCyclerConfig(node);
|
return extractFromCyclerConfig(node);
|
||||||
}
|
}
|
||||||
|
if (isLoraStackAggregatorNode$1(comfyClass)) {
|
||||||
|
return /* @__PURE__ */ new Set();
|
||||||
|
}
|
||||||
return extractFromLorasWidget(node);
|
return extractFromLorasWidget(node);
|
||||||
}
|
}
|
||||||
function extractFromLorasWidget(node) {
|
function extractFromLorasWidget(node) {
|
||||||
@@ -15002,8 +15015,18 @@ const LORA_PROVIDER_NODE_TYPES = [
|
|||||||
"Lora Randomizer (LoraManager)",
|
"Lora Randomizer (LoraManager)",
|
||||||
"Lora Cycler (LoraManager)"
|
"Lora Cycler (LoraManager)"
|
||||||
];
|
];
|
||||||
function isLoraProviderNode(comfyClass) {
|
const LORA_STACK_AGGREGATOR_NODE_TYPES = [
|
||||||
return LORA_PROVIDER_NODE_TYPES.includes(comfyClass);
|
"Lora Stack Combiner (LoraManager)"
|
||||||
|
];
|
||||||
|
const LORA_CHAIN_NODE_TYPES = [
|
||||||
|
...LORA_PROVIDER_NODE_TYPES,
|
||||||
|
...LORA_STACK_AGGREGATOR_NODE_TYPES
|
||||||
|
];
|
||||||
|
function isLoraStackAggregatorNode(comfyClass) {
|
||||||
|
return LORA_STACK_AGGREGATOR_NODE_TYPES.includes(comfyClass);
|
||||||
|
}
|
||||||
|
function isLoraChainNode(comfyClass) {
|
||||||
|
return LORA_CHAIN_NODE_TYPES.includes(comfyClass);
|
||||||
}
|
}
|
||||||
function isMapLike(collection) {
|
function isMapLike(collection) {
|
||||||
return collection && typeof collection.entries === "function" && typeof collection.values === "function";
|
return collection && typeof collection.entries === "function" && typeof collection.values === "function";
|
||||||
@@ -15041,14 +15064,17 @@ function getLinkFromGraph(graph, linkId) {
|
|||||||
}
|
}
|
||||||
return graph.links[linkId] || null;
|
return graph.links[linkId] || null;
|
||||||
}
|
}
|
||||||
function getConnectedInputStackers(node) {
|
function isLoraStackInput(input) {
|
||||||
|
return (input == null ? void 0 : input.type) === "LORA_STACK";
|
||||||
|
}
|
||||||
|
function getConnectedInputLoraChainNodes(node) {
|
||||||
var _a2, _b;
|
var _a2, _b;
|
||||||
const connectedStackers = [];
|
const connectedNodes = [];
|
||||||
if (!(node == null ? void 0 : node.inputs)) {
|
if (!(node == null ? void 0 : node.inputs)) {
|
||||||
return connectedStackers;
|
return connectedNodes;
|
||||||
}
|
}
|
||||||
for (const input of node.inputs) {
|
for (const input of node.inputs) {
|
||||||
if (input.name !== "lora_stack" || !input.link) {
|
if (!isLoraStackInput(input) || !input.link) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const link = getLinkFromGraph(node.graph, input.link);
|
const link = getLinkFromGraph(node.graph, input.link);
|
||||||
@@ -15056,11 +15082,11 @@ function getConnectedInputStackers(node) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const sourceNode = (_b = (_a2 = node.graph) == null ? void 0 : _a2.getNodeById) == null ? void 0 : _b.call(_a2, link.origin_id);
|
const sourceNode = (_b = (_a2 = node.graph) == null ? void 0 : _a2.getNodeById) == null ? void 0 : _b.call(_a2, link.origin_id);
|
||||||
if (sourceNode && isLoraProviderNode(sourceNode.comfyClass)) {
|
if (sourceNode && isLoraChainNode(sourceNode.comfyClass)) {
|
||||||
connectedStackers.push(sourceNode);
|
connectedNodes.push(sourceNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return connectedStackers;
|
return connectedNodes;
|
||||||
}
|
}
|
||||||
function getConnectedTriggerToggleNodes(node) {
|
function getConnectedTriggerToggleNodes(node) {
|
||||||
var _a2, _b, _c;
|
var _a2, _b, _c;
|
||||||
@@ -15095,6 +15121,9 @@ function getActiveLorasFromNode(node) {
|
|||||||
}
|
}
|
||||||
return activeLoraNames;
|
return activeLoraNames;
|
||||||
}
|
}
|
||||||
|
if (isLoraStackAggregatorNode(node.comfyClass)) {
|
||||||
|
return activeLoraNames;
|
||||||
|
}
|
||||||
let lorasWidget = node.lorasWidget;
|
let lorasWidget = node.lorasWidget;
|
||||||
if (!lorasWidget && node.widgets) {
|
if (!lorasWidget && node.widgets) {
|
||||||
lorasWidget = node.widgets.find((w2) => w2.name === "loras");
|
lorasWidget = node.widgets.find((w2) => w2.name === "loras");
|
||||||
@@ -15118,11 +15147,14 @@ function collectActiveLorasFromChain(node, visited = /* @__PURE__ */ new Set())
|
|||||||
}
|
}
|
||||||
visited.add(nodeKey);
|
visited.add(nodeKey);
|
||||||
const isNodeActive2 = node.mode === void 0 || node.mode === 0 || node.mode === 3;
|
const isNodeActive2 = node.mode === void 0 || node.mode === 0 || node.mode === 3;
|
||||||
const allActiveLoraNames = isNodeActive2 ? getActiveLorasFromNode(node) : /* @__PURE__ */ new Set();
|
if (!isNodeActive2) {
|
||||||
const inputStackers = getConnectedInputStackers(node);
|
return /* @__PURE__ */ new Set();
|
||||||
for (const stacker of inputStackers) {
|
}
|
||||||
const stackerLoras = collectActiveLorasFromChain(stacker, visited);
|
const allActiveLoraNames = getActiveLorasFromNode(node);
|
||||||
stackerLoras.forEach((name) => allActiveLoraNames.add(name));
|
const inputChainNodes = getConnectedInputLoraChainNodes(node);
|
||||||
|
for (const chainNode of inputChainNodes) {
|
||||||
|
const upstreamLoras = collectActiveLorasFromChain(chainNode, visited);
|
||||||
|
upstreamLoras.forEach((name) => allActiveLoraNames.add(name));
|
||||||
}
|
}
|
||||||
return allActiveLoraNames;
|
return allActiveLoraNames;
|
||||||
}
|
}
|
||||||
@@ -15191,7 +15223,7 @@ function updateDownstreamLoaders(startNode, visited = /* @__PURE__ */ new Set())
|
|||||||
if (targetNode && targetNode.comfyClass === "Lora Loader (LoraManager)") {
|
if (targetNode && targetNode.comfyClass === "Lora Loader (LoraManager)") {
|
||||||
const allActiveLoraNames = collectActiveLorasFromChain(targetNode);
|
const allActiveLoraNames = collectActiveLorasFromChain(targetNode);
|
||||||
updateConnectedTriggerWords(targetNode, allActiveLoraNames);
|
updateConnectedTriggerWords(targetNode, allActiveLoraNames);
|
||||||
} else if (targetNode && isLoraProviderNode(targetNode.comfyClass)) {
|
} else if (targetNode && isLoraChainNode(targetNode.comfyClass)) {
|
||||||
updateDownstreamLoaders(targetNode, visited);
|
updateDownstreamLoaders(targetNode, visited);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -15784,7 +15816,7 @@ app$1.registerExtension({
|
|||||||
return originalConfigure == null ? void 0 : originalConfigure.apply(this, arguments);
|
return originalConfigure == null ? void 0 : originalConfigure.apply(this, arguments);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
if (LORA_PROVIDER_NODE_TYPES$1.includes(comfyClass)) {
|
if (LORA_CHAIN_NODE_TYPES$1.includes(comfyClass)) {
|
||||||
const originalOnNodeCreated = nodeType.prototype.onNodeCreated;
|
const originalOnNodeCreated = nodeType.prototype.onNodeCreated;
|
||||||
nodeType.prototype.onNodeCreated = function() {
|
nodeType.prototype.onNodeCreated = function() {
|
||||||
originalOnNodeCreated == null ? void 0 : originalOnNodeCreated.apply(this, arguments);
|
originalOnNodeCreated == null ? void 0 : originalOnNodeCreated.apply(this, arguments);
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user