mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 22:22:11 -03:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76d3aa2b5b | ||
|
|
c9a65c7347 | ||
|
|
f542ade628 | ||
|
|
d2c2bfbe6a | ||
|
|
2b6910bd55 | ||
|
|
b1dd733493 | ||
|
|
5dcf0a1e48 | ||
|
|
cf357b57fc | ||
|
|
4e1773833f | ||
|
|
8cf762ffd3 | ||
|
|
d997eaa429 | ||
|
|
8e51f0f19f | ||
|
|
f0e246b4ac | ||
|
|
a232997a79 | ||
|
|
08a449db99 | ||
|
|
0c023c9888 | ||
|
|
0ad92d00b3 | ||
|
|
a726cbea1e | ||
|
|
c53fa8692b | ||
|
|
3118f3b43c |
@@ -529,12 +529,15 @@
|
|||||||
"title": "Embedding-Modelle"
|
"title": "Embedding-Modelle"
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
"modelRoot": "Modell-Stammverzeichnis",
|
"modelRoot": "Stammverzeichnis",
|
||||||
"collapseAll": "Alle Ordner einklappen",
|
"collapseAll": "Alle Ordner einklappen",
|
||||||
"pinSidebar": "Sidebar anheften",
|
"pinSidebar": "Sidebar anheften",
|
||||||
"unpinSidebar": "Sidebar lösen",
|
"unpinSidebar": "Sidebar lösen",
|
||||||
"switchToListView": "Zur Listenansicht wechseln",
|
"switchToListView": "Zur Listenansicht wechseln",
|
||||||
"switchToTreeView": "Zur Baumansicht wechseln",
|
"switchToTreeView": "Zur Baumansicht wechseln",
|
||||||
|
"recursiveOn": "Unterordner durchsuchen",
|
||||||
|
"recursiveOff": "Nur aktuellen Ordner durchsuchen",
|
||||||
|
"recursiveUnavailable": "Rekursive Suche ist nur in der Baumansicht verfügbar",
|
||||||
"collapseAllDisabled": "Im Listenmodus nicht verfügbar"
|
"collapseAllDisabled": "Im Listenmodus nicht verfügbar"
|
||||||
},
|
},
|
||||||
"statistics": {
|
"statistics": {
|
||||||
|
|||||||
@@ -529,12 +529,15 @@
|
|||||||
"title": "Embedding Models"
|
"title": "Embedding Models"
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
"modelRoot": "Model Root",
|
"modelRoot": "Root",
|
||||||
"collapseAll": "Collapse All Folders",
|
"collapseAll": "Collapse All Folders",
|
||||||
"pinSidebar": "Pin Sidebar",
|
"pinSidebar": "Pin Sidebar",
|
||||||
"unpinSidebar": "Unpin Sidebar",
|
"unpinSidebar": "Unpin Sidebar",
|
||||||
"switchToListView": "Switch to List View",
|
"switchToListView": "Switch to List View",
|
||||||
"switchToTreeView": "Switch to Tree View",
|
"switchToTreeView": "Switch to Tree View",
|
||||||
|
"recursiveOn": "Search subfolders",
|
||||||
|
"recursiveOff": "Search current folder only",
|
||||||
|
"recursiveUnavailable": "Recursive search is available in tree view only",
|
||||||
"collapseAllDisabled": "Not available in list view"
|
"collapseAllDisabled": "Not available in list view"
|
||||||
},
|
},
|
||||||
"statistics": {
|
"statistics": {
|
||||||
|
|||||||
@@ -529,12 +529,15 @@
|
|||||||
"title": "Modelos embedding"
|
"title": "Modelos embedding"
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
"modelRoot": "Raíz del modelo",
|
"modelRoot": "Raíz",
|
||||||
"collapseAll": "Colapsar todas las carpetas",
|
"collapseAll": "Colapsar todas las carpetas",
|
||||||
"pinSidebar": "Fijar barra lateral",
|
"pinSidebar": "Fijar barra lateral",
|
||||||
"unpinSidebar": "Desfijar barra lateral",
|
"unpinSidebar": "Desfijar barra lateral",
|
||||||
"switchToListView": "Cambiar a vista de lista",
|
"switchToListView": "Cambiar a vista de lista",
|
||||||
"switchToTreeView": "Cambiar a vista de árbol",
|
"switchToTreeView": "Cambiar a vista de árbol",
|
||||||
|
"recursiveOn": "Buscar en subcarpetas",
|
||||||
|
"recursiveOff": "Buscar solo en la carpeta actual",
|
||||||
|
"recursiveUnavailable": "La búsqueda recursiva solo está disponible en la vista en árbol",
|
||||||
"collapseAllDisabled": "No disponible en vista de lista"
|
"collapseAllDisabled": "No disponible en vista de lista"
|
||||||
},
|
},
|
||||||
"statistics": {
|
"statistics": {
|
||||||
|
|||||||
@@ -529,12 +529,15 @@
|
|||||||
"title": "Modèles Embedding"
|
"title": "Modèles Embedding"
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
"modelRoot": "Racine du modèle",
|
"modelRoot": "Racine",
|
||||||
"collapseAll": "Réduire tous les dossiers",
|
"collapseAll": "Réduire tous les dossiers",
|
||||||
"pinSidebar": "Épingler la barre latérale",
|
"pinSidebar": "Épingler la barre latérale",
|
||||||
"unpinSidebar": "Désépingler la barre latérale",
|
"unpinSidebar": "Désépingler la barre latérale",
|
||||||
"switchToListView": "Passer en vue liste",
|
"switchToListView": "Passer en vue liste",
|
||||||
"switchToTreeView": "Passer en vue arborescence",
|
"switchToTreeView": "Passer en vue arborescence",
|
||||||
|
"recursiveOn": "Rechercher dans les sous-dossiers",
|
||||||
|
"recursiveOff": "Rechercher uniquement dans le dossier actuel",
|
||||||
|
"recursiveUnavailable": "La recherche récursive n'est disponible qu'en vue arborescente",
|
||||||
"collapseAllDisabled": "Non disponible en vue liste"
|
"collapseAllDisabled": "Non disponible en vue liste"
|
||||||
},
|
},
|
||||||
"statistics": {
|
"statistics": {
|
||||||
|
|||||||
@@ -529,12 +529,15 @@
|
|||||||
"title": "מודלי Embedding"
|
"title": "מודלי Embedding"
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
"modelRoot": "שורש המודלים",
|
"modelRoot": "שורש",
|
||||||
"collapseAll": "כווץ את כל התיקיות",
|
"collapseAll": "כווץ את כל התיקיות",
|
||||||
"pinSidebar": "נעל סרגל צד",
|
"pinSidebar": "נעל סרגל צד",
|
||||||
"unpinSidebar": "שחרר סרגל צד",
|
"unpinSidebar": "שחרר סרגל צד",
|
||||||
"switchToListView": "עבור לתצוגת רשימה",
|
"switchToListView": "עבור לתצוגת רשימה",
|
||||||
"switchToTreeView": "עבור לתצוגת עץ",
|
"switchToTreeView": "תצוגת עץ",
|
||||||
|
"recursiveOn": "חיפוש בתיקיות משנה",
|
||||||
|
"recursiveOff": "חיפוש רק בתיקייה הנוכחית",
|
||||||
|
"recursiveUnavailable": "חיפוש רקורסיבי זמין רק בתצוגת עץ",
|
||||||
"collapseAllDisabled": "לא זמין בתצוגת רשימה"
|
"collapseAllDisabled": "לא זמין בתצוגת רשימה"
|
||||||
},
|
},
|
||||||
"statistics": {
|
"statistics": {
|
||||||
@@ -1264,4 +1267,4 @@
|
|||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -529,12 +529,15 @@
|
|||||||
"title": "Embeddingモデル"
|
"title": "Embeddingモデル"
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
"modelRoot": "モデルルート",
|
"modelRoot": "ルート",
|
||||||
"collapseAll": "すべてのフォルダを折りたたむ",
|
"collapseAll": "すべてのフォルダを折りたたむ",
|
||||||
"pinSidebar": "サイドバーを固定",
|
"pinSidebar": "サイドバーを固定",
|
||||||
"unpinSidebar": "サイドバーの固定を解除",
|
"unpinSidebar": "サイドバーの固定を解除",
|
||||||
"switchToListView": "リストビューに切り替え",
|
"switchToListView": "リストビューに切り替え",
|
||||||
"switchToTreeView": "ツリービューに切り替え",
|
"switchToTreeView": "ツリー表示に切り替え",
|
||||||
|
"recursiveOn": "サブフォルダーを検索",
|
||||||
|
"recursiveOff": "現在のフォルダーのみを検索",
|
||||||
|
"recursiveUnavailable": "再帰検索はツリービューでのみ利用できます",
|
||||||
"collapseAllDisabled": "リストビューでは利用できません"
|
"collapseAllDisabled": "リストビューでは利用できません"
|
||||||
},
|
},
|
||||||
"statistics": {
|
"statistics": {
|
||||||
@@ -1264,4 +1267,4 @@
|
|||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -529,12 +529,15 @@
|
|||||||
"title": "Embedding 모델"
|
"title": "Embedding 모델"
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
"modelRoot": "모델 루트",
|
"modelRoot": "루트",
|
||||||
"collapseAll": "모든 폴더 접기",
|
"collapseAll": "모든 폴더 접기",
|
||||||
"pinSidebar": "사이드바 고정",
|
"pinSidebar": "사이드바 고정",
|
||||||
"unpinSidebar": "사이드바 고정 해제",
|
"unpinSidebar": "사이드바 고정 해제",
|
||||||
"switchToListView": "목록 보기로 전환",
|
"switchToListView": "목록 보기로 전환",
|
||||||
"switchToTreeView": "트리 보기로 전환",
|
"switchToTreeView": "트리 보기로 전환",
|
||||||
|
"recursiveOn": "하위 폴더 검색",
|
||||||
|
"recursiveOff": "현재 폴더만 검색",
|
||||||
|
"recursiveUnavailable": "재귀 검색은 트리 보기에서만 사용할 수 있습니다",
|
||||||
"collapseAllDisabled": "목록 보기에서는 사용할 수 없습니다"
|
"collapseAllDisabled": "목록 보기에서는 사용할 수 없습니다"
|
||||||
},
|
},
|
||||||
"statistics": {
|
"statistics": {
|
||||||
@@ -1264,4 +1267,4 @@
|
|||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -529,12 +529,15 @@
|
|||||||
"title": "Модели Embedding"
|
"title": "Модели Embedding"
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
"modelRoot": "Корень моделей",
|
"modelRoot": "Корень",
|
||||||
"collapseAll": "Свернуть все папки",
|
"collapseAll": "Свернуть все папки",
|
||||||
"pinSidebar": "Закрепить боковую панель",
|
"pinSidebar": "Закрепить боковую панель",
|
||||||
"unpinSidebar": "Открепить боковую панель",
|
"unpinSidebar": "Открепить боковую панель",
|
||||||
"switchToListView": "Переключить на вид списка",
|
"switchToListView": "Переключить на вид списка",
|
||||||
"switchToTreeView": "Переключить на древовидный вид",
|
"switchToTreeView": "Переключить на древовидный вид",
|
||||||
|
"recursiveOn": "Искать во вложенных папках",
|
||||||
|
"recursiveOff": "Искать только в текущей папке",
|
||||||
|
"recursiveUnavailable": "Рекурсивный поиск доступен только в режиме дерева",
|
||||||
"collapseAllDisabled": "Недоступно в виде списка"
|
"collapseAllDisabled": "Недоступно в виде списка"
|
||||||
},
|
},
|
||||||
"statistics": {
|
"statistics": {
|
||||||
@@ -1264,4 +1267,4 @@
|
|||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -535,12 +535,15 @@
|
|||||||
"title": "Embedding 模型"
|
"title": "Embedding 模型"
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
"modelRoot": "模型根目录",
|
"modelRoot": "根目录",
|
||||||
"collapseAll": "折叠所有文件夹",
|
"collapseAll": "折叠所有文件夹",
|
||||||
"pinSidebar": "固定侧边栏",
|
"pinSidebar": "固定侧边栏",
|
||||||
"unpinSidebar": "取消固定侧边栏",
|
"unpinSidebar": "取消固定侧边栏",
|
||||||
"switchToListView": "切换到列表视图",
|
"switchToListView": "切换到列表视图",
|
||||||
"switchToTreeView": "切换到树状视图",
|
"switchToTreeView": "切换到树状视图",
|
||||||
|
"recursiveOn": "搜索子文件夹",
|
||||||
|
"recursiveOff": "仅搜索当前文件夹",
|
||||||
|
"recursiveUnavailable": "仅在树形视图中可使用递归搜索",
|
||||||
"collapseAllDisabled": "列表视图下不可用"
|
"collapseAllDisabled": "列表视图下不可用"
|
||||||
},
|
},
|
||||||
"statistics": {
|
"statistics": {
|
||||||
@@ -1270,4 +1273,4 @@
|
|||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -529,12 +529,15 @@
|
|||||||
"title": "Embedding 模型"
|
"title": "Embedding 模型"
|
||||||
},
|
},
|
||||||
"sidebar": {
|
"sidebar": {
|
||||||
"modelRoot": "模型根目錄",
|
"modelRoot": "根目錄",
|
||||||
"collapseAll": "全部摺疊資料夾",
|
"collapseAll": "全部摺疊資料夾",
|
||||||
"pinSidebar": "固定側邊欄",
|
"pinSidebar": "固定側邊欄",
|
||||||
"unpinSidebar": "取消固定側邊欄",
|
"unpinSidebar": "取消固定側邊欄",
|
||||||
"switchToListView": "切換至列表檢視",
|
"switchToListView": "切換至列表檢視",
|
||||||
"switchToTreeView": "切換至樹狀檢視",
|
"switchToTreeView": "切換到樹狀檢視",
|
||||||
|
"recursiveOn": "搜尋子資料夾",
|
||||||
|
"recursiveOff": "僅搜尋目前資料夾",
|
||||||
|
"recursiveUnavailable": "遞迴搜尋僅能在樹狀檢視中使用",
|
||||||
"collapseAllDisabled": "列表檢視下不可用"
|
"collapseAllDisabled": "列表檢視下不可用"
|
||||||
},
|
},
|
||||||
"statistics": {
|
"statistics": {
|
||||||
@@ -1264,4 +1267,4 @@
|
|||||||
"learnMore": "LM Civitai Extension Tutorial"
|
"learnMore": "LM Civitai Extension Tutorial"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -74,8 +74,9 @@ class Config:
|
|||||||
"""Persist ComfyUI-derived folder paths to the multi-library settings."""
|
"""Persist ComfyUI-derived folder paths to the multi-library settings."""
|
||||||
try:
|
try:
|
||||||
ensure_settings_file(logger)
|
ensure_settings_file(logger)
|
||||||
from .services.settings_manager import settings as settings_service
|
from .services.settings_manager import get_settings_manager
|
||||||
|
|
||||||
|
settings_service = get_settings_manager()
|
||||||
libraries = settings_service.get_libraries()
|
libraries = settings_service.get_libraries()
|
||||||
comfy_library = libraries.get("comfyui", {})
|
comfy_library = libraries.get("comfyui", {})
|
||||||
default_library = libraries.get("default", {})
|
default_library = libraries.get("default", {})
|
||||||
@@ -442,8 +443,9 @@ class Config:
|
|||||||
"""Return the current library registry and active library name."""
|
"""Return the current library registry and active library name."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .services.settings_manager import settings as settings_service
|
from .services.settings_manager import get_settings_manager
|
||||||
|
|
||||||
|
settings_service = get_settings_manager()
|
||||||
libraries = settings_service.get_libraries()
|
libraries = settings_service.get_libraries()
|
||||||
active_library = settings_service.get_active_library_name()
|
active_library = settings_service.get_active_library_name()
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from .routes.misc_routes import MiscRoutes
|
|||||||
from .routes.preview_routes import PreviewRoutes
|
from .routes.preview_routes import PreviewRoutes
|
||||||
from .routes.example_images_routes import ExampleImagesRoutes
|
from .routes.example_images_routes import ExampleImagesRoutes
|
||||||
from .services.service_registry import ServiceRegistry
|
from .services.service_registry import ServiceRegistry
|
||||||
from .services.settings_manager import settings
|
from .services.settings_manager import get_settings_manager
|
||||||
from .utils.example_images_migration import ExampleImagesMigration
|
from .utils.example_images_migration import ExampleImagesMigration
|
||||||
from .services.websocket_manager import ws_manager
|
from .services.websocket_manager import ws_manager
|
||||||
from .services.example_images_cleanup_service import ExampleImagesCleanupService
|
from .services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||||
@@ -23,6 +23,25 @@ logger = logging.getLogger(__name__)
|
|||||||
# Check if we're in standalone mode
|
# Check if we're in standalone mode
|
||||||
STANDALONE_MODE = 'nodes' not in sys.modules
|
STANDALONE_MODE = 'nodes' not in sys.modules
|
||||||
|
|
||||||
|
|
||||||
|
class _SettingsProxy:
|
||||||
|
def __init__(self):
|
||||||
|
self._manager = None
|
||||||
|
|
||||||
|
def _resolve(self):
|
||||||
|
if self._manager is None:
|
||||||
|
self._manager = get_settings_manager()
|
||||||
|
return self._manager
|
||||||
|
|
||||||
|
def get(self, *args, **kwargs):
|
||||||
|
return self._resolve().get(*args, **kwargs)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
return getattr(self._resolve(), item)
|
||||||
|
|
||||||
|
|
||||||
|
settings = _SettingsProxy()
|
||||||
|
|
||||||
class LoraManager:
|
class LoraManager:
|
||||||
"""Main entry point for LoRA Manager plugin"""
|
"""Main entry point for LoRA Manager plugin"""
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from ..services.model_lifecycle_service import ModelLifecycleService
|
|||||||
from ..services.preview_asset_service import PreviewAssetService
|
from ..services.preview_asset_service import PreviewAssetService
|
||||||
from ..services.server_i18n import server_i18n as default_server_i18n
|
from ..services.server_i18n import server_i18n as default_server_i18n
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.settings_manager import settings as default_settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
from ..services.tag_update_service import TagUpdateService
|
from ..services.tag_update_service import TagUpdateService
|
||||||
from ..services.websocket_manager import ws_manager as default_ws_manager
|
from ..services.websocket_manager import ws_manager as default_ws_manager
|
||||||
from ..services.use_cases import (
|
from ..services.use_cases import (
|
||||||
@@ -56,14 +56,14 @@ class BaseModelRoutes(ABC):
|
|||||||
self,
|
self,
|
||||||
service=None,
|
service=None,
|
||||||
*,
|
*,
|
||||||
settings_service=default_settings,
|
settings_service=None,
|
||||||
ws_manager=default_ws_manager,
|
ws_manager=default_ws_manager,
|
||||||
server_i18n=default_server_i18n,
|
server_i18n=default_server_i18n,
|
||||||
metadata_provider_factory=get_default_metadata_provider,
|
metadata_provider_factory=get_default_metadata_provider,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.service = None
|
self.service = None
|
||||||
self.model_type = ""
|
self.model_type = ""
|
||||||
self._settings = settings_service
|
self._settings = settings_service or get_settings_manager()
|
||||||
self._ws_manager = ws_manager
|
self._ws_manager = ws_manager
|
||||||
self._server_i18n = server_i18n
|
self._server_i18n = server_i18n
|
||||||
self._metadata_provider_factory = metadata_provider_factory
|
self._metadata_provider_factory = metadata_provider_factory
|
||||||
@@ -90,7 +90,7 @@ class BaseModelRoutes(ABC):
|
|||||||
self._metadata_sync_service = MetadataSyncService(
|
self._metadata_sync_service = MetadataSyncService(
|
||||||
metadata_manager=MetadataManager,
|
metadata_manager=MetadataManager,
|
||||||
preview_service=self._preview_service,
|
preview_service=self._preview_service,
|
||||||
settings=settings_service,
|
settings=self._settings,
|
||||||
default_metadata_provider_factory=metadata_provider_factory,
|
default_metadata_provider_factory=metadata_provider_factory,
|
||||||
metadata_provider_selector=get_metadata_provider,
|
metadata_provider_selector=get_metadata_provider,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from ..services.recipes import (
|
|||||||
)
|
)
|
||||||
from ..services.server_i18n import server_i18n
|
from ..services.server_i18n import server_i18n
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
from .handlers.recipe_handlers import (
|
from .handlers.recipe_handlers import (
|
||||||
@@ -48,7 +48,7 @@ class BaseRecipeRoutes:
|
|||||||
self.recipe_scanner = None
|
self.recipe_scanner = None
|
||||||
self.lora_scanner = None
|
self.lora_scanner = None
|
||||||
self.civitai_client = None
|
self.civitai_client = None
|
||||||
self.settings = settings
|
self.settings = get_settings_manager()
|
||||||
self.server_i18n = server_i18n
|
self.server_i18n = server_i18n
|
||||||
self.template_env = jinja2.Environment(
|
self.template_env = jinja2.Environment(
|
||||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||||
|
|||||||
@@ -24,10 +24,17 @@ from ...services.metadata_service import (
|
|||||||
update_metadata_providers,
|
update_metadata_providers,
|
||||||
)
|
)
|
||||||
from ...services.service_registry import ServiceRegistry
|
from ...services.service_registry import ServiceRegistry
|
||||||
from ...services.settings_manager import settings as default_settings
|
from ...services.settings_manager import get_settings_manager
|
||||||
from ...services.websocket_manager import ws_manager
|
from ...services.websocket_manager import ws_manager
|
||||||
from ...services.downloader import get_downloader
|
from ...services.downloader import get_downloader
|
||||||
from ...utils.constants import DEFAULT_NODE_COLOR, NODE_TYPES, SUPPORTED_MEDIA_EXTENSIONS
|
from ...utils.constants import (
|
||||||
|
CIVITAI_USER_MODEL_TYPES,
|
||||||
|
DEFAULT_NODE_COLOR,
|
||||||
|
NODE_TYPES,
|
||||||
|
SUPPORTED_MEDIA_EXTENSIONS,
|
||||||
|
VALID_LORA_TYPES,
|
||||||
|
)
|
||||||
|
from ...utils.civitai_utils import rewrite_preview_url
|
||||||
from ...utils.example_images_paths import is_valid_example_images_root
|
from ...utils.example_images_paths import is_valid_example_images_root
|
||||||
from ...utils.lora_metadata import extract_trained_words
|
from ...utils.lora_metadata import extract_trained_words
|
||||||
from ...utils.usage_stats import UsageStats
|
from ...utils.usage_stats import UsageStats
|
||||||
@@ -80,7 +87,7 @@ class NodeRegistry:
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._nodes: Dict[int, dict] = {}
|
self._nodes: Dict[str, dict] = {}
|
||||||
self._registry_updated = asyncio.Event()
|
self._registry_updated = asyncio.Event()
|
||||||
|
|
||||||
async def register_nodes(self, nodes: list[dict]) -> None:
|
async def register_nodes(self, nodes: list[dict]) -> None:
|
||||||
@@ -88,11 +95,16 @@ class NodeRegistry:
|
|||||||
self._nodes.clear()
|
self._nodes.clear()
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
node_id = node["node_id"]
|
node_id = node["node_id"]
|
||||||
|
graph_id = str(node["graph_id"])
|
||||||
|
unique_id = f"{graph_id}:{node_id}"
|
||||||
node_type = node.get("type", "")
|
node_type = node.get("type", "")
|
||||||
type_id = NODE_TYPES.get(node_type, 0)
|
type_id = NODE_TYPES.get(node_type, 0)
|
||||||
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
|
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
|
||||||
self._nodes[node_id] = {
|
self._nodes[unique_id] = {
|
||||||
"id": node_id,
|
"id": node_id,
|
||||||
|
"graph_id": graph_id,
|
||||||
|
"graph_name": node.get("graph_name"),
|
||||||
|
"unique_id": unique_id,
|
||||||
"bgcolor": bgcolor,
|
"bgcolor": bgcolor,
|
||||||
"title": node.get("title"),
|
"title": node.get("title"),
|
||||||
"type": type_id,
|
"type": type_id,
|
||||||
@@ -157,11 +169,11 @@ class SettingsHandler:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
settings_service=default_settings,
|
settings_service=None,
|
||||||
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
|
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
|
||||||
downloader_factory: Callable[[], Awaitable[DownloaderProtocol]] = get_downloader,
|
downloader_factory: Callable[[], Awaitable[DownloaderProtocol]] = get_downloader,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._settings = settings_service
|
self._settings = settings_service or get_settings_manager()
|
||||||
self._metadata_provider_updater = metadata_provider_updater
|
self._metadata_provider_updater = metadata_provider_updater
|
||||||
self._downloader_factory = downloader_factory
|
self._downloader_factory = downloader_factory
|
||||||
|
|
||||||
@@ -330,16 +342,65 @@ class LoraCodeHandler:
|
|||||||
logger.error("Error broadcasting lora code: %s", exc)
|
logger.error("Error broadcasting lora code: %s", exc)
|
||||||
results.append({"node_id": "broadcast", "success": False, "error": str(exc)})
|
results.append({"node_id": "broadcast", "success": False, "error": str(exc)})
|
||||||
else:
|
else:
|
||||||
for node_id in node_ids:
|
for entry in node_ids:
|
||||||
|
node_identifier = entry
|
||||||
|
graph_identifier = None
|
||||||
|
if isinstance(entry, dict):
|
||||||
|
node_identifier = entry.get("node_id")
|
||||||
|
graph_identifier = entry.get("graph_id")
|
||||||
|
|
||||||
|
if node_identifier is None:
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"node_id": node_identifier,
|
||||||
|
"graph_id": graph_identifier,
|
||||||
|
"success": False,
|
||||||
|
"error": "Missing node_id parameter",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_node_id = int(node_identifier)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
parsed_node_id = node_identifier
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"id": parsed_node_id,
|
||||||
|
"lora_code": lora_code,
|
||||||
|
"mode": mode,
|
||||||
|
}
|
||||||
|
|
||||||
|
if graph_identifier is not None:
|
||||||
|
payload["graph_id"] = str(graph_identifier)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._prompt_server.instance.send_sync(
|
self._prompt_server.instance.send_sync(
|
||||||
"lora_code_update",
|
"lora_code_update",
|
||||||
{"id": node_id, "lora_code": lora_code, "mode": mode},
|
payload,
|
||||||
|
)
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"node_id": parsed_node_id,
|
||||||
|
"graph_id": payload.get("graph_id"),
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
results.append({"node_id": node_id, "success": True})
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
logger.error("Error sending lora code to node %s: %s", node_id, exc)
|
logger.error(
|
||||||
results.append({"node_id": node_id, "success": False, "error": str(exc)})
|
"Error sending lora code to node %s (graph %s): %s",
|
||||||
|
parsed_node_id,
|
||||||
|
graph_identifier,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"node_id": parsed_node_id,
|
||||||
|
"graph_id": payload.get("graph_id"),
|
||||||
|
"success": False,
|
||||||
|
"error": str(exc),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return web.json_response({"success": True, "results": results})
|
return web.json_response({"success": True, "results": results})
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
@@ -557,17 +618,118 @@ class ModelLibraryHandler:
|
|||||||
logger.error("Failed to get model versions status: %s", exc, exc_info=True)
|
logger.error("Failed to get model versions status: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def get_civitai_user_models(self, request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
username = request.query.get("username")
|
||||||
|
if not username:
|
||||||
|
return web.json_response({"success": False, "error": "Missing required parameter: username"}, status=400)
|
||||||
|
|
||||||
|
metadata_provider = await self._metadata_provider_factory()
|
||||||
|
if not metadata_provider:
|
||||||
|
return web.json_response({"success": False, "error": "Metadata provider not available"}, status=503)
|
||||||
|
|
||||||
|
try:
|
||||||
|
models = await metadata_provider.get_user_models(username)
|
||||||
|
except NotImplementedError:
|
||||||
|
return web.json_response({"success": False, "error": "Metadata provider does not support user model queries"}, status=501)
|
||||||
|
|
||||||
|
if models is None:
|
||||||
|
return web.json_response({"success": False, "error": "Failed to fetch user models"}, status=502)
|
||||||
|
|
||||||
|
if not isinstance(models, list):
|
||||||
|
models = []
|
||||||
|
|
||||||
|
lora_scanner = await self._service_registry.get_lora_scanner()
|
||||||
|
checkpoint_scanner = await self._service_registry.get_checkpoint_scanner()
|
||||||
|
embedding_scanner = await self._service_registry.get_embedding_scanner()
|
||||||
|
|
||||||
|
normalized_allowed_types = {model_type.lower() for model_type in CIVITAI_USER_MODEL_TYPES}
|
||||||
|
lora_type_aliases = {model_type.lower() for model_type in VALID_LORA_TYPES}
|
||||||
|
|
||||||
|
type_scanner_map: Dict[str, object | None] = {
|
||||||
|
**{alias: lora_scanner for alias in lora_type_aliases},
|
||||||
|
"checkpoint": checkpoint_scanner,
|
||||||
|
"textualinversion": embedding_scanner,
|
||||||
|
}
|
||||||
|
|
||||||
|
versions: list[dict] = []
|
||||||
|
for model in models:
|
||||||
|
if not isinstance(model, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_type = str(model.get("type", "")).lower()
|
||||||
|
if model_type not in normalized_allowed_types:
|
||||||
|
continue
|
||||||
|
|
||||||
|
scanner = type_scanner_map.get(model_type)
|
||||||
|
if scanner is None:
|
||||||
|
return web.json_response({"success": False, "error": f'Scanner for type "{model_type}" is not available'}, status=503)
|
||||||
|
|
||||||
|
tags_value = model.get("tags")
|
||||||
|
tags = tags_value if isinstance(tags_value, list) else []
|
||||||
|
model_id = model.get("id")
|
||||||
|
try:
|
||||||
|
model_id_int = int(model_id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
model_name = model.get("name", "")
|
||||||
|
|
||||||
|
versions_data = model.get("modelVersions")
|
||||||
|
if not isinstance(versions_data, list):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for version in versions_data:
|
||||||
|
if not isinstance(version, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
version_id = version.get("id")
|
||||||
|
try:
|
||||||
|
version_id_int = int(version_id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
images = version.get("images") or []
|
||||||
|
thumbnail_url = None
|
||||||
|
if images and isinstance(images, list):
|
||||||
|
first_image = images[0]
|
||||||
|
if isinstance(first_image, dict):
|
||||||
|
raw_url = first_image.get("url")
|
||||||
|
media_type = first_image.get("type")
|
||||||
|
rewritten_url, _ = rewrite_preview_url(raw_url, media_type)
|
||||||
|
thumbnail_url = rewritten_url
|
||||||
|
|
||||||
|
in_library = await scanner.check_model_version_exists(version_id_int)
|
||||||
|
|
||||||
|
versions.append(
|
||||||
|
{
|
||||||
|
"modelId": model_id_int,
|
||||||
|
"versionId": version_id_int,
|
||||||
|
"modelName": model_name,
|
||||||
|
"versionName": version.get("name", ""),
|
||||||
|
"type": model.get("type"),
|
||||||
|
"tags": tags,
|
||||||
|
"baseModel": version.get("baseModel"),
|
||||||
|
"thumbnailUrl": thumbnail_url,
|
||||||
|
"inLibrary": in_library,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response({"success": True, "username": username, "versions": versions})
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error("Failed to get Civitai user models: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
|
||||||
class MetadataArchiveHandler:
|
class MetadataArchiveHandler:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
metadata_archive_manager_factory: Callable[[], Awaitable[MetadataArchiveManagerProtocol]] = get_metadata_archive_manager,
|
metadata_archive_manager_factory: Callable[[], Awaitable[MetadataArchiveManagerProtocol]] = get_metadata_archive_manager,
|
||||||
settings_service=default_settings,
|
settings_service=None,
|
||||||
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
|
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._metadata_archive_manager_factory = metadata_archive_manager_factory
|
self._metadata_archive_manager_factory = metadata_archive_manager_factory
|
||||||
self._settings = settings_service
|
self._settings = settings_service or get_settings_manager()
|
||||||
self._metadata_provider_updater = metadata_provider_updater
|
self._metadata_provider_updater = metadata_provider_updater
|
||||||
|
|
||||||
async def download_metadata_archive(self, request: web.Request) -> web.Response:
|
async def download_metadata_archive(self, request: web.Request) -> web.Response:
|
||||||
@@ -679,10 +841,21 @@ class NodeRegistryHandler:
|
|||||||
node_id = node.get("node_id")
|
node_id = node.get("node_id")
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400)
|
return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400)
|
||||||
|
graph_id = node.get("graph_id")
|
||||||
|
if graph_id is None:
|
||||||
|
return web.json_response({"success": False, "error": f"Node {index} missing graph_id parameter"}, status=400)
|
||||||
|
graph_name = node.get("graph_name")
|
||||||
try:
|
try:
|
||||||
node["node_id"] = int(node_id)
|
node["node_id"] = int(node_id)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400)
|
return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400)
|
||||||
|
node["graph_id"] = str(graph_id)
|
||||||
|
if graph_name is None:
|
||||||
|
node["graph_name"] = None
|
||||||
|
elif isinstance(graph_name, str):
|
||||||
|
node["graph_name"] = graph_name
|
||||||
|
else:
|
||||||
|
node["graph_name"] = str(graph_name)
|
||||||
|
|
||||||
await self._node_registry.register_nodes(nodes)
|
await self._node_registry.register_nodes(nodes)
|
||||||
return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"})
|
return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"})
|
||||||
@@ -779,6 +952,7 @@ class MiscHandlerSet:
|
|||||||
"register_nodes": self.node_registry.register_nodes,
|
"register_nodes": self.node_registry.register_nodes,
|
||||||
"get_registry": self.node_registry.get_registry,
|
"get_registry": self.node_registry.get_registry,
|
||||||
"check_model_exists": self.model_library.check_model_exists,
|
"check_model_exists": self.model_library.check_model_exists,
|
||||||
|
"get_civitai_user_models": self.model_library.get_civitai_user_models,
|
||||||
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
||||||
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
||||||
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,
|
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from ...services.use_cases import (
|
|||||||
from ...services.websocket_manager import WebSocketManager
|
from ...services.websocket_manager import WebSocketManager
|
||||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||||
from ...utils.file_utils import calculate_sha256
|
from ...utils.file_utils import calculate_sha256
|
||||||
|
from ...utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
|
|
||||||
class ModelPageView:
|
class ModelPageView:
|
||||||
@@ -244,6 +245,8 @@ class ModelManagementHandler:
|
|||||||
if not model_data.get("sha256"):
|
if not model_data.get("sha256"):
|
||||||
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
|
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
|
||||||
|
|
||||||
|
await MetadataManager.hydrate_model_data(model_data)
|
||||||
|
|
||||||
success, error = await self._metadata_sync.fetch_and_update_model(
|
success, error = await self._metadata_sync.fetch_and_update_model(
|
||||||
sha256=model_data["sha256"],
|
sha256=model_data["sha256"],
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
@@ -825,18 +828,30 @@ class ModelCivitaiHandler:
|
|||||||
status=400,
|
status=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cache = await self._service.scanner.get_cached_data()
|
||||||
|
version_index = cache.version_index
|
||||||
|
|
||||||
for version in versions:
|
for version in versions:
|
||||||
model_file = self._find_model_file(version.get("files", [])) if isinstance(version.get("files"), Iterable) else None
|
version_id = None
|
||||||
if model_file:
|
version_id_raw = version.get("id")
|
||||||
hashes = model_file.get("hashes", {}) if isinstance(model_file, Mapping) else {}
|
if version_id_raw is not None:
|
||||||
sha256 = hashes.get("SHA256") if isinstance(hashes, Mapping) else None
|
try:
|
||||||
if sha256:
|
version_id = int(str(version_id_raw))
|
||||||
version["existsLocally"] = self._service.has_hash(sha256)
|
except (TypeError, ValueError):
|
||||||
if version["existsLocally"]:
|
version_id = None
|
||||||
version["localPath"] = self._service.get_path_by_hash(sha256)
|
|
||||||
version["modelSizeKB"] = model_file.get("sizeKB") if isinstance(model_file, Mapping) else None
|
cache_entry = version_index.get(version_id) if (version_id is not None and version_index) else None
|
||||||
|
version["existsLocally"] = cache_entry is not None
|
||||||
|
if cache_entry and isinstance(cache_entry, Mapping):
|
||||||
|
local_path = cache_entry.get("file_path")
|
||||||
|
if local_path:
|
||||||
|
version["localPath"] = local_path
|
||||||
else:
|
else:
|
||||||
version["existsLocally"] = False
|
version.pop("localPath", None)
|
||||||
|
|
||||||
|
model_file = self._find_model_file(version.get("files", [])) if isinstance(version.get("files"), Iterable) else None
|
||||||
|
if model_file and isinstance(model_file, Mapping):
|
||||||
|
version["modelSizeKB"] = model_file.get("sizeKB")
|
||||||
return web.json_response(versions)
|
return web.json_response(versions)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._logger.error("Error fetching %s model versions: %s", self._service.model_type, exc)
|
self._logger.error("Error fetching %s model versions: %s", self._service.model_type, exc)
|
||||||
|
|||||||
@@ -229,11 +229,27 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
|
||||||
# Send update to all connected trigger word toggle nodes
|
# Send update to all connected trigger word toggle nodes
|
||||||
for node_id in node_ids:
|
for entry in node_ids:
|
||||||
PromptServer.instance.send_sync("trigger_word_update", {
|
node_identifier = entry
|
||||||
"id": node_id,
|
graph_identifier = None
|
||||||
|
if isinstance(entry, dict):
|
||||||
|
node_identifier = entry.get("node_id")
|
||||||
|
graph_identifier = entry.get("graph_id")
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_node_id = int(node_identifier)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
parsed_node_id = node_identifier
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"id": parsed_node_id,
|
||||||
"message": trigger_words_text
|
"message": trigger_words_text
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if graph_identifier is not None:
|
||||||
|
payload["graph_id"] = str(graph_identifier)
|
||||||
|
|
||||||
|
PromptServer.instance.send_sync("trigger_word_update", payload)
|
||||||
|
|
||||||
return web.json_response({"success": True})
|
return web.json_response({"success": True})
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
|
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
|
||||||
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
||||||
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
||||||
|
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
|
||||||
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
|
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
|
||||||
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
|
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
|
||||||
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from ..services.metadata_service import (
|
|||||||
get_metadata_provider,
|
get_metadata_provider,
|
||||||
update_metadata_providers,
|
update_metadata_providers,
|
||||||
)
|
)
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
from ..services.downloader import get_downloader
|
from ..services.downloader import get_downloader
|
||||||
from ..utils.usage_stats import UsageStats
|
from ..utils.usage_stats import UsageStats
|
||||||
from .handlers.misc_handlers import (
|
from .handlers.misc_handlers import (
|
||||||
@@ -47,7 +47,7 @@ class MiscRoutes:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
settings_service=settings,
|
settings_service=None,
|
||||||
usage_stats_factory: Callable[[], UsageStats] = UsageStats,
|
usage_stats_factory: Callable[[], UsageStats] = UsageStats,
|
||||||
prompt_server: type[PromptServer] = PromptServer,
|
prompt_server: type[PromptServer] = PromptServer,
|
||||||
service_registry_adapter=build_service_registry_adapter(),
|
service_registry_adapter=build_service_registry_adapter(),
|
||||||
@@ -60,7 +60,7 @@ class MiscRoutes:
|
|||||||
node_registry: NodeRegistry | None = None,
|
node_registry: NodeRegistry | None = None,
|
||||||
standalone_mode_flag: bool = standalone_mode,
|
standalone_mode_flag: bool = standalone_mode,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._settings = settings_service
|
self._settings = settings_service or get_settings_manager()
|
||||||
self._usage_stats_factory = usage_stats_factory
|
self._usage_stats_factory = usage_stats_factory
|
||||||
self._prompt_server = prompt_server
|
self._prompt_server = prompt_server
|
||||||
self._service_registry_adapter = service_registry_adapter
|
self._service_registry_adapter = service_registry_adapter
|
||||||
|
|||||||
@@ -8,13 +8,32 @@ from collections import defaultdict, Counter
|
|||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
from ..services.server_i18n import server_i18n
|
from ..services.server_i18n import server_i18n
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..utils.usage_stats import UsageStats
|
from ..utils.usage_stats import UsageStats
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _SettingsProxy:
|
||||||
|
def __init__(self):
|
||||||
|
self._manager = None
|
||||||
|
|
||||||
|
def _resolve(self):
|
||||||
|
if self._manager is None:
|
||||||
|
self._manager = get_settings_manager()
|
||||||
|
return self._manager
|
||||||
|
|
||||||
|
def get(self, *args, **kwargs):
|
||||||
|
return self._resolve().get(*args, **kwargs)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
return getattr(self._resolve(), item)
|
||||||
|
|
||||||
|
|
||||||
|
settings = _SettingsProxy()
|
||||||
|
|
||||||
class StatsRoutes:
|
class StatsRoutes:
|
||||||
"""Route handlers for Statistics page and API endpoints"""
|
"""Route handlers for Statistics page and API endpoints"""
|
||||||
|
|
||||||
@@ -66,7 +85,9 @@ class StatsRoutes:
|
|||||||
is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing
|
is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing
|
||||||
|
|
||||||
# 获取用户语言设置
|
# 获取用户语言设置
|
||||||
user_language = settings.get('language', 'en')
|
settings_object = settings
|
||||||
|
user_language = settings_object.get('language', 'en')
|
||||||
|
settings_manager = settings_object if not isinstance(settings_object, _SettingsProxy) else settings_object._resolve()
|
||||||
|
|
||||||
# 设置服务端i18n语言
|
# 设置服务端i18n语言
|
||||||
server_i18n.set_locale(user_language)
|
server_i18n.set_locale(user_language)
|
||||||
@@ -79,7 +100,7 @@ class StatsRoutes:
|
|||||||
template = self.template_env.get_template('statistics.html')
|
template = self.template_env.get_template('statistics.html')
|
||||||
rendered = template.render(
|
rendered = template.render(
|
||||||
is_initializing=is_initializing,
|
is_initializing=is_initializing,
|
||||||
settings=settings,
|
settings=settings_manager,
|
||||||
request=request,
|
request=request,
|
||||||
t=server_i18n.get_translation,
|
t=server_i18n.get_translation,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import os
|
|||||||
from ..utils.models import BaseModelMetadata
|
from ..utils.models import BaseModelMetadata
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
||||||
from .settings_manager import settings as default_settings
|
from .settings_manager import get_settings_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ class BaseModelService(ABC):
|
|||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.scanner = scanner
|
self.scanner = scanner
|
||||||
self.metadata_class = metadata_class
|
self.metadata_class = metadata_class
|
||||||
self.settings = settings_provider or default_settings
|
self.settings = settings_provider or get_settings_manager()
|
||||||
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
||||||
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
||||||
self.search_strategy = search_strategy or SearchStrategy()
|
self.search_strategy = search_strategy or SearchStrategy()
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import os
|
||||||
from typing import Optional, Dict, Tuple, List
|
from typing import Optional, Dict, Tuple, List
|
||||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
@@ -157,141 +157,160 @@ class CivitaiClient:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||||
"""Get specific model version with additional metadata
|
"""Get specific model version with additional metadata."""
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: The Civitai model ID (optional if version_id is provided)
|
|
||||||
version_id: Optional specific version ID to retrieve
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Dict]: The model version data with additional fields or None if not found
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
downloader = await get_downloader()
|
downloader = await get_downloader()
|
||||||
|
|
||||||
# Case 1: Only version_id is provided
|
|
||||||
if model_id is None and version_id is not None:
|
if model_id is None and version_id is not None:
|
||||||
# First get the version info to extract model_id
|
return await self._get_version_by_id_only(downloader, version_id)
|
||||||
success, version = await downloader.make_request(
|
|
||||||
'GET',
|
|
||||||
f"{self.base_url}/model-versions/{version_id}",
|
|
||||||
use_auth=True
|
|
||||||
)
|
|
||||||
if not success:
|
|
||||||
return None
|
|
||||||
|
|
||||||
model_id = version.get('modelId')
|
|
||||||
if not model_id:
|
|
||||||
logger.error(f"No modelId found in version {version_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Now get the model data for additional metadata
|
|
||||||
success, model_data = await downloader.make_request(
|
|
||||||
'GET',
|
|
||||||
f"{self.base_url}/models/{model_id}",
|
|
||||||
use_auth=True
|
|
||||||
)
|
|
||||||
if success:
|
|
||||||
# Enrich version with model data
|
|
||||||
version['model']['description'] = model_data.get("description")
|
|
||||||
version['model']['tags'] = model_data.get("tags", [])
|
|
||||||
version['creator'] = model_data.get("creator")
|
|
||||||
|
|
||||||
self._remove_comfy_metadata(version)
|
if model_id is not None:
|
||||||
return version
|
return await self._get_version_with_model_id(downloader, model_id, version_id)
|
||||||
|
|
||||||
# Case 2: model_id is provided (with or without version_id)
|
|
||||||
elif model_id is not None:
|
|
||||||
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
|
||||||
success, data = await downloader.make_request(
|
|
||||||
'GET',
|
|
||||||
f"{self.base_url}/models/{model_id}",
|
|
||||||
use_auth=True
|
|
||||||
)
|
|
||||||
if not success:
|
|
||||||
return None
|
|
||||||
|
|
||||||
model_versions = data.get('modelVersions', [])
|
logger.error("Either model_id or version_id must be provided")
|
||||||
if not model_versions:
|
return None
|
||||||
logger.warning(f"No model versions found for model {model_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Step 2: Determine the target version entry to use
|
|
||||||
target_version = None
|
|
||||||
if version_id is not None:
|
|
||||||
target_version = next(
|
|
||||||
(item for item in model_versions if item.get('id') == version_id),
|
|
||||||
None
|
|
||||||
)
|
|
||||||
if target_version is None:
|
|
||||||
logger.warning(
|
|
||||||
f"Version {version_id} not found for model {model_id}, defaulting to first version"
|
|
||||||
)
|
|
||||||
if target_version is None:
|
|
||||||
target_version = model_versions[0]
|
|
||||||
|
|
||||||
target_version_id = target_version.get('id')
|
|
||||||
|
|
||||||
# Step 3: Get detailed version info using the SHA256 hash
|
|
||||||
model_hash = None
|
|
||||||
for file_info in target_version.get('files', []):
|
|
||||||
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
|
||||||
model_hash = file_info.get('hashes', {}).get('SHA256')
|
|
||||||
if model_hash:
|
|
||||||
break
|
|
||||||
|
|
||||||
version = None
|
|
||||||
if model_hash:
|
|
||||||
success, version = await downloader.make_request(
|
|
||||||
'GET',
|
|
||||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
|
||||||
use_auth=True
|
|
||||||
)
|
|
||||||
if not success:
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}"
|
|
||||||
)
|
|
||||||
version = None
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"No primary model hash found for model {model_id} version {target_version_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if version is None:
|
|
||||||
version = copy.deepcopy(target_version)
|
|
||||||
version.pop('index', None)
|
|
||||||
version['modelId'] = model_id
|
|
||||||
version['model'] = {
|
|
||||||
'name': data.get('name'),
|
|
||||||
'type': data.get('type'),
|
|
||||||
'nsfw': data.get('nsfw'),
|
|
||||||
'poi': data.get('poi')
|
|
||||||
}
|
|
||||||
|
|
||||||
# Step 4: Enrich version_info with model data
|
|
||||||
# Add description and tags from model data
|
|
||||||
model_info = version.get('model')
|
|
||||||
if not isinstance(model_info, dict):
|
|
||||||
model_info = {}
|
|
||||||
version['model'] = model_info
|
|
||||||
model_info['description'] = data.get("description")
|
|
||||||
model_info['tags'] = data.get("tags", [])
|
|
||||||
|
|
||||||
# Add creator from model data
|
|
||||||
version['creator'] = data.get("creator")
|
|
||||||
|
|
||||||
self._remove_comfy_metadata(version)
|
|
||||||
return version
|
|
||||||
|
|
||||||
# Case 3: Neither model_id nor version_id provided
|
|
||||||
else:
|
|
||||||
logger.error("Either model_id or version_id must be provided")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching model version: {e}")
|
logger.error(f"Error fetching model version: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _get_version_by_id_only(self, downloader, version_id: int) -> Optional[Dict]:
|
||||||
|
version = await self._fetch_version_by_id(downloader, version_id)
|
||||||
|
if version is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
model_id = version.get('modelId')
|
||||||
|
if not model_id:
|
||||||
|
logger.error(f"No modelId found in version {version_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
model_data = await self._fetch_model_data(downloader, model_id)
|
||||||
|
if model_data:
|
||||||
|
self._enrich_version_with_model_data(version, model_data)
|
||||||
|
|
||||||
|
self._remove_comfy_metadata(version)
|
||||||
|
return version
|
||||||
|
|
||||||
|
async def _get_version_with_model_id(self, downloader, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
||||||
|
model_data = await self._fetch_model_data(downloader, model_id)
|
||||||
|
if not model_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
target_version = self._select_target_version(model_data, model_id, version_id)
|
||||||
|
if target_version is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
target_version_id = target_version.get('id')
|
||||||
|
version = await self._fetch_version_by_id(downloader, target_version_id) if target_version_id else None
|
||||||
|
|
||||||
|
if version is None:
|
||||||
|
model_hash = self._extract_primary_model_hash(target_version)
|
||||||
|
if model_hash:
|
||||||
|
version = await self._fetch_version_by_hash(downloader, model_hash)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"No primary model hash found for model {model_id} version {target_version_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if version is None:
|
||||||
|
version = self._build_version_from_model_data(target_version, model_id, model_data)
|
||||||
|
|
||||||
|
self._enrich_version_with_model_data(version, model_data)
|
||||||
|
self._remove_comfy_metadata(version)
|
||||||
|
return version
|
||||||
|
|
||||||
|
async def _fetch_model_data(self, downloader, model_id: int) -> Optional[Dict]:
|
||||||
|
success, data = await downloader.make_request(
|
||||||
|
'GET',
|
||||||
|
f"{self.base_url}/models/{model_id}",
|
||||||
|
use_auth=True
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
return data
|
||||||
|
logger.warning(f"Failed to fetch model data for model {model_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _fetch_version_by_id(self, downloader, version_id: Optional[int]) -> Optional[Dict]:
|
||||||
|
if version_id is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
success, version = await downloader.make_request(
|
||||||
|
'GET',
|
||||||
|
f"{self.base_url}/model-versions/{version_id}",
|
||||||
|
use_auth=True
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
return version
|
||||||
|
|
||||||
|
logger.warning(f"Failed to fetch version by id {version_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _fetch_version_by_hash(self, downloader, model_hash: Optional[str]) -> Optional[Dict]:
|
||||||
|
if not model_hash:
|
||||||
|
return None
|
||||||
|
|
||||||
|
success, version = await downloader.make_request(
|
||||||
|
'GET',
|
||||||
|
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||||
|
use_auth=True
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
return version
|
||||||
|
|
||||||
|
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
||||||
|
model_versions = model_data.get('modelVersions', [])
|
||||||
|
if not model_versions:
|
||||||
|
logger.warning(f"No model versions found for model {model_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if version_id is not None:
|
||||||
|
target_version = next(
|
||||||
|
(item for item in model_versions if item.get('id') == version_id),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
if target_version is None:
|
||||||
|
logger.warning(
|
||||||
|
f"Version {version_id} not found for model {model_id}, defaulting to first version"
|
||||||
|
)
|
||||||
|
return model_versions[0]
|
||||||
|
return target_version
|
||||||
|
|
||||||
|
return model_versions[0]
|
||||||
|
|
||||||
|
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
|
||||||
|
for file_info in version_entry.get('files', []):
|
||||||
|
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
||||||
|
hashes = file_info.get('hashes', {})
|
||||||
|
model_hash = hashes.get('SHA256')
|
||||||
|
if model_hash:
|
||||||
|
return model_hash
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict:
|
||||||
|
version = copy.deepcopy(version_entry)
|
||||||
|
version.pop('index', None)
|
||||||
|
version['modelId'] = model_id
|
||||||
|
version['model'] = {
|
||||||
|
'name': model_data.get('name'),
|
||||||
|
'type': model_data.get('type'),
|
||||||
|
'nsfw': model_data.get('nsfw'),
|
||||||
|
'poi': model_data.get('poi')
|
||||||
|
}
|
||||||
|
return version
|
||||||
|
|
||||||
|
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
|
||||||
|
model_info = version.get('model')
|
||||||
|
if not isinstance(model_info, dict):
|
||||||
|
model_info = {}
|
||||||
|
version['model'] = model_info
|
||||||
|
|
||||||
|
model_info['description'] = model_data.get("description")
|
||||||
|
model_info['tags'] = model_data.get("tags", [])
|
||||||
|
version['creator'] = model_data.get("creator")
|
||||||
|
|
||||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
"""Fetch model version metadata from Civitai
|
"""Fetch model version metadata from Civitai
|
||||||
|
|
||||||
@@ -335,7 +354,7 @@ class CivitaiClient:
|
|||||||
|
|
||||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||||
"""Fetch image information from Civitai API
|
"""Fetch image information from Civitai API
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_id: The Civitai image ID
|
image_id: The Civitai image ID
|
||||||
|
|
||||||
@@ -366,3 +385,37 @@ class CivitaiClient:
|
|||||||
error_msg = f"Error fetching image info: {e}"
|
error_msg = f"Error fetching image info: {e}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
"""Fetch all models for a specific Civitai user."""
|
||||||
|
if not username:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
downloader = await get_downloader()
|
||||||
|
url = f"{self.base_url}/models?username={username}"
|
||||||
|
success, result = await downloader.make_request(
|
||||||
|
'GET',
|
||||||
|
url,
|
||||||
|
use_auth=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
logger.error("Failed to fetch models for %s: %s", username, result)
|
||||||
|
return None
|
||||||
|
|
||||||
|
items = result.get("items") if isinstance(result, dict) else None
|
||||||
|
if not isinstance(items, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
for model in items:
|
||||||
|
versions = model.get("modelVersions")
|
||||||
|
if not isinstance(versions, list):
|
||||||
|
continue
|
||||||
|
for version in versions:
|
||||||
|
self._remove_comfy_metadata(version)
|
||||||
|
|
||||||
|
return items
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error("Error fetching models for %s: %s", username, exc)
|
||||||
|
return None
|
||||||
|
|||||||
@@ -4,12 +4,14 @@ import asyncio
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
from urllib.parse import urlparse
|
||||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
||||||
|
from ..utils.civitai_utils import rewrite_preview_url
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .settings_manager import settings
|
from .settings_manager import get_settings_manager
|
||||||
from .metadata_service import get_default_metadata_provider
|
from .metadata_service import get_default_metadata_provider
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
|
|
||||||
@@ -241,23 +243,24 @@ class DownloadManager:
|
|||||||
|
|
||||||
# Handle use_default_paths
|
# Handle use_default_paths
|
||||||
if use_default_paths:
|
if use_default_paths:
|
||||||
|
settings_manager = get_settings_manager()
|
||||||
# Set save_dir based on model type
|
# Set save_dir based on model type
|
||||||
if model_type == 'checkpoint':
|
if model_type == 'checkpoint':
|
||||||
default_path = settings.get('default_checkpoint_root')
|
default_path = settings_manager.get('default_checkpoint_root')
|
||||||
if not default_path:
|
if not default_path:
|
||||||
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
|
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
|
||||||
save_dir = default_path
|
save_dir = default_path
|
||||||
elif model_type == 'lora':
|
elif model_type == 'lora':
|
||||||
default_path = settings.get('default_lora_root')
|
default_path = settings_manager.get('default_lora_root')
|
||||||
if not default_path:
|
if not default_path:
|
||||||
return {'success': False, 'error': 'Default lora root path not set in settings'}
|
return {'success': False, 'error': 'Default lora root path not set in settings'}
|
||||||
save_dir = default_path
|
save_dir = default_path
|
||||||
elif model_type == 'embedding':
|
elif model_type == 'embedding':
|
||||||
default_path = settings.get('default_embedding_root')
|
default_path = settings_manager.get('default_embedding_root')
|
||||||
if not default_path:
|
if not default_path:
|
||||||
return {'success': False, 'error': 'Default embedding root path not set in settings'}
|
return {'success': False, 'error': 'Default embedding root path not set in settings'}
|
||||||
save_dir = default_path
|
save_dir = default_path
|
||||||
|
|
||||||
# Calculate relative path using template
|
# Calculate relative path using template
|
||||||
relative_path = self._calculate_relative_path(version_info, model_type)
|
relative_path = self._calculate_relative_path(version_info, model_type)
|
||||||
|
|
||||||
@@ -360,7 +363,8 @@ class DownloadManager:
|
|||||||
Relative path string
|
Relative path string
|
||||||
"""
|
"""
|
||||||
# Get path template from settings for specific model type
|
# Get path template from settings for specific model type
|
||||||
path_template = settings.get_download_path_template(model_type)
|
settings_manager = get_settings_manager()
|
||||||
|
path_template = settings_manager.get_download_path_template(model_type)
|
||||||
|
|
||||||
# If template is empty, return empty path (flat structure)
|
# If template is empty, return empty path (flat structure)
|
||||||
if not path_template:
|
if not path_template:
|
||||||
@@ -377,7 +381,7 @@ class DownloadManager:
|
|||||||
author = 'Anonymous'
|
author = 'Anonymous'
|
||||||
|
|
||||||
# Apply mapping if available
|
# Apply mapping if available
|
||||||
base_model_mappings = settings.get('base_model_path_mappings', {})
|
base_model_mappings = settings_manager.get('base_model_path_mappings', {})
|
||||||
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
||||||
|
|
||||||
# Get model tags
|
# Get model tags
|
||||||
@@ -448,70 +452,103 @@ class DownloadManager:
|
|||||||
# Download preview image if available
|
# Download preview image if available
|
||||||
images = version_info.get('images', [])
|
images = version_info.get('images', [])
|
||||||
if images:
|
if images:
|
||||||
# Report preview download progress
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(1) # 1% progress for starting preview download
|
await progress_callback(1) # 1% progress for starting preview download
|
||||||
|
|
||||||
# Check if it's a video or an image
|
first_image = images[0] if isinstance(images[0], dict) else None
|
||||||
is_video = images[0].get('type') == 'video'
|
preview_url = first_image.get('url') if first_image else None
|
||||||
|
media_type = (first_image.get('type') or '').lower() if first_image else ''
|
||||||
if (is_video):
|
nsfw_level = first_image.get('nsfwLevel', 0) if first_image else 0
|
||||||
# For videos, use .mp4 extension
|
|
||||||
preview_ext = '.mp4'
|
def _extension_from_url(url: str, fallback: str) -> str:
|
||||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
# Download video directly using downloader
|
except ValueError:
|
||||||
downloader = await get_downloader()
|
return fallback
|
||||||
success, result = await downloader.download_file(
|
ext = os.path.splitext(parsed.path)[1]
|
||||||
images[0]['url'],
|
return ext or fallback
|
||||||
preview_path,
|
|
||||||
use_auth=False # Preview images typically don't need auth
|
preview_downloaded = False
|
||||||
)
|
preview_path = None
|
||||||
if success:
|
|
||||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
if preview_url:
|
||||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
downloader = await get_downloader()
|
||||||
else:
|
|
||||||
# For images, use WebP format for better performance
|
if media_type == 'video':
|
||||||
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
preview_ext = _extension_from_url(preview_url, '.mp4')
|
||||||
temp_path = temp_file.name
|
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||||
|
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='video')
|
||||||
# Download the original image to temp path using downloader
|
attempt_urls: List[str] = []
|
||||||
downloader = await get_downloader()
|
if rewritten:
|
||||||
success, content, headers = await downloader.download_to_memory(
|
attempt_urls.append(rewritten_url)
|
||||||
images[0]['url'],
|
attempt_urls.append(preview_url)
|
||||||
use_auth=False
|
|
||||||
)
|
seen_attempts = set()
|
||||||
if success:
|
for attempt in attempt_urls:
|
||||||
# Save to temp file
|
if not attempt or attempt in seen_attempts:
|
||||||
with open(temp_path, 'wb') as f:
|
continue
|
||||||
f.write(content)
|
seen_attempts.add(attempt)
|
||||||
# Optimize and convert to WebP
|
success, _ = await downloader.download_file(
|
||||||
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
attempt,
|
||||||
|
preview_path,
|
||||||
# Use ExifUtils to optimize and convert the image
|
use_auth=False
|
||||||
optimized_data, _ = ExifUtils.optimize_image(
|
)
|
||||||
image_data=temp_path,
|
if success:
|
||||||
target_width=CARD_PREVIEW_WIDTH,
|
preview_downloaded = True
|
||||||
format='webp',
|
break
|
||||||
quality=85,
|
else:
|
||||||
preserve_metadata=False
|
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='image')
|
||||||
)
|
if rewritten:
|
||||||
|
preview_ext = _extension_from_url(preview_url, '.png')
|
||||||
# Save the optimized image
|
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||||
with open(preview_path, 'wb') as f:
|
success, _ = await downloader.download_file(
|
||||||
f.write(optimized_data)
|
rewritten_url,
|
||||||
|
preview_path,
|
||||||
# Update metadata
|
use_auth=False
|
||||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
)
|
||||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
if success:
|
||||||
|
preview_downloaded = True
|
||||||
# Remove temporary file
|
|
||||||
try:
|
if not preview_downloaded:
|
||||||
os.unlink(temp_path)
|
temp_path: str | None = None
|
||||||
except Exception as e:
|
try:
|
||||||
logger.warning(f"Failed to delete temp file: {e}")
|
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
||||||
|
temp_path = temp_file.name
|
||||||
|
|
||||||
|
success, content, _ = await downloader.download_to_memory(
|
||||||
|
preview_url,
|
||||||
|
use_auth=False
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
with open(temp_path, 'wb') as temp_file_handle:
|
||||||
|
temp_file_handle.write(content)
|
||||||
|
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
||||||
|
|
||||||
|
optimized_data, _ = ExifUtils.optimize_image(
|
||||||
|
image_data=temp_path,
|
||||||
|
target_width=CARD_PREVIEW_WIDTH,
|
||||||
|
format='webp',
|
||||||
|
quality=85,
|
||||||
|
preserve_metadata=False
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(preview_path, 'wb') as preview_file:
|
||||||
|
preview_file.write(optimized_data)
|
||||||
|
|
||||||
|
preview_downloaded = True
|
||||||
|
finally:
|
||||||
|
if temp_path and os.path.exists(temp_path):
|
||||||
|
try:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete temp file: {e}")
|
||||||
|
|
||||||
|
if preview_downloaded and preview_path:
|
||||||
|
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||||
|
metadata.preview_nsfw_level = nsfw_level
|
||||||
|
if download_id and download_id in self._active_downloads:
|
||||||
|
self._active_downloads[download_id]['preview_path'] = preview_path
|
||||||
|
|
||||||
# Report preview download completion
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(3) # 3% progress after preview download
|
await progress_callback(3) # 3% progress after preview download
|
||||||
|
|
||||||
@@ -675,7 +712,15 @@ class DownloadManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting metadata file: {e}")
|
logger.error(f"Error deleting metadata file: {e}")
|
||||||
|
|
||||||
# Delete preview file if exists (.webp or .mp4)
|
preview_path_value = download_info.get('preview_path')
|
||||||
|
if preview_path_value and os.path.exists(preview_path_value):
|
||||||
|
try:
|
||||||
|
os.unlink(preview_path_value)
|
||||||
|
logger.debug(f"Deleted preview file: {preview_path_value}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting preview file: {e}")
|
||||||
|
|
||||||
|
# Delete preview file if exists (.webp or .mp4) for legacy paths
|
||||||
for preview_ext in ['.webp', '.mp4']:
|
for preview_ext in ['.webp', '.mp4']:
|
||||||
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
||||||
if os.path.exists(preview_path):
|
if os.path.exists(preview_path):
|
||||||
@@ -708,4 +753,4 @@ class DownloadManager:
|
|||||||
}
|
}
|
||||||
for task_id, info in self._active_downloads.items()
|
for task_id, info in self._active_downloads.items()
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import asyncio
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Dict, Tuple, Callable, Union
|
from typing import Optional, Dict, Tuple, Callable, Union
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -94,12 +94,13 @@ class Downloader:
|
|||||||
|
|
||||||
# Check for app-level proxy settings
|
# Check for app-level proxy settings
|
||||||
proxy_url = None
|
proxy_url = None
|
||||||
if settings.get('proxy_enabled', False):
|
settings_manager = get_settings_manager()
|
||||||
proxy_host = settings.get('proxy_host', '').strip()
|
if settings_manager.get('proxy_enabled', False):
|
||||||
proxy_port = settings.get('proxy_port', '').strip()
|
proxy_host = settings_manager.get('proxy_host', '').strip()
|
||||||
proxy_type = settings.get('proxy_type', 'http').lower()
|
proxy_port = settings_manager.get('proxy_port', '').strip()
|
||||||
proxy_username = settings.get('proxy_username', '').strip()
|
proxy_type = settings_manager.get('proxy_type', 'http').lower()
|
||||||
proxy_password = settings.get('proxy_password', '').strip()
|
proxy_username = settings_manager.get('proxy_username', '').strip()
|
||||||
|
proxy_password = settings_manager.get('proxy_password', '').strip()
|
||||||
|
|
||||||
if proxy_host and proxy_port:
|
if proxy_host and proxy_port:
|
||||||
# Build proxy URL
|
# Build proxy URL
|
||||||
@@ -146,7 +147,8 @@ class Downloader:
|
|||||||
|
|
||||||
if use_auth:
|
if use_auth:
|
||||||
# Add CivitAI API key if available
|
# Add CivitAI API key if available
|
||||||
api_key = settings.get('civitai_api_key')
|
settings_manager = get_settings_manager()
|
||||||
|
api_key = settings_manager.get('civitai_api_key')
|
||||||
if api_key:
|
if api_key:
|
||||||
headers['Authorization'] = f'Bearer {api_key}'
|
headers['Authorization'] = f'Bearer {api_key}'
|
||||||
headers['Content-Type'] = 'application/json'
|
headers['Content-Type'] = 'application/json'
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from pathlib import Path
|
|||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .settings_manager import settings
|
from .settings_manager import get_settings_manager
|
||||||
from ..utils.example_images_paths import iter_library_roots
|
from ..utils.example_images_paths import iter_library_roots
|
||||||
|
|
||||||
|
|
||||||
@@ -62,7 +62,8 @@ class ExampleImagesCleanupService:
|
|||||||
async def cleanup_example_image_folders(self) -> Dict[str, object]:
|
async def cleanup_example_image_folders(self) -> Dict[str, object]:
|
||||||
"""Clean empty or orphaned example image folders by moving them under a deleted bucket."""
|
"""Clean empty or orphaned example image folders by moving them under a deleted bucket."""
|
||||||
|
|
||||||
example_images_path = settings.get("example_images_path")
|
settings_manager = get_settings_manager()
|
||||||
|
example_images_path = settings_manager.get("example_images_path")
|
||||||
if not example_images_path:
|
if not example_images_path:
|
||||||
logger.debug("Cleanup skipped: example images path not configured")
|
logger.debug("Cleanup skipped: example images path not configured")
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from .model_metadata_provider import (
|
|||||||
CivitaiModelMetadataProvider,
|
CivitaiModelMetadataProvider,
|
||||||
FallbackMetadataProvider
|
FallbackMetadataProvider
|
||||||
)
|
)
|
||||||
from .settings_manager import settings
|
from .settings_manager import get_settings_manager
|
||||||
from .metadata_archive_manager import MetadataArchiveManager
|
from .metadata_archive_manager import MetadataArchiveManager
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
|
|
||||||
@@ -21,7 +21,8 @@ async def initialize_metadata_providers():
|
|||||||
provider_manager.default_provider = None
|
provider_manager.default_provider = None
|
||||||
|
|
||||||
# Get settings
|
# Get settings
|
||||||
enable_archive_db = settings.get('enable_metadata_archive_db', False)
|
settings_manager = get_settings_manager()
|
||||||
|
enable_archive_db = settings_manager.get('enable_metadata_archive_db', False)
|
||||||
|
|
||||||
providers = []
|
providers = []
|
||||||
|
|
||||||
@@ -87,7 +88,8 @@ async def update_metadata_providers():
|
|||||||
"""Update metadata providers based on current settings"""
|
"""Update metadata providers based on current settings"""
|
||||||
try:
|
try:
|
||||||
# Get current settings
|
# Get current settings
|
||||||
enable_archive_db = settings.get('enable_metadata_archive_db', False)
|
settings_manager = get_settings_manager()
|
||||||
|
enable_archive_db = settings_manager.get('enable_metadata_archive_db', False)
|
||||||
|
|
||||||
# Reinitialize all providers with new settings
|
# Reinitialize all providers with new settings
|
||||||
provider_manager = await initialize_metadata_providers()
|
provider_manager = await initialize_metadata_providers()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Dict, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from natsort import natsorted
|
from natsort import natsorted
|
||||||
|
|
||||||
@@ -17,10 +17,12 @@ SUPPORTED_SORT_MODES = [
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelCache:
|
class ModelCache:
|
||||||
"""Cache structure for model data with extensible sorting"""
|
"""Cache structure for model data with extensible sorting."""
|
||||||
|
|
||||||
raw_data: List[Dict]
|
raw_data: List[Dict]
|
||||||
folders: List[str]
|
folders: List[str]
|
||||||
|
version_index: Dict[int, Dict] = field(default_factory=dict)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
# Cache for last sort: (sort_key, order) -> sorted list
|
# Cache for last sort: (sort_key, order) -> sorted list
|
||||||
@@ -28,6 +30,58 @@ class ModelCache:
|
|||||||
self._last_sorted_data: List[Dict] = []
|
self._last_sorted_data: List[Dict] = []
|
||||||
# Default sort on init
|
# Default sort on init
|
||||||
asyncio.create_task(self.resort())
|
asyncio.create_task(self.resort())
|
||||||
|
self.rebuild_version_index()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_version_id(value: Any) -> Optional[int]:
|
||||||
|
"""Normalize a potential version identifier into an integer."""
|
||||||
|
|
||||||
|
if isinstance(value, int):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def rebuild_version_index(self) -> None:
|
||||||
|
"""Rebuild the version index from the current raw data."""
|
||||||
|
|
||||||
|
self.version_index = {}
|
||||||
|
for item in self.raw_data:
|
||||||
|
self.add_to_version_index(item)
|
||||||
|
|
||||||
|
def add_to_version_index(self, item: Dict) -> None:
|
||||||
|
"""Register a cache item in the version index if possible."""
|
||||||
|
|
||||||
|
civitai_data = item.get('civitai') if isinstance(item, dict) else None
|
||||||
|
if not isinstance(civitai_data, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
version_id = self._normalize_version_id(civitai_data.get('id'))
|
||||||
|
if version_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.version_index[version_id] = item
|
||||||
|
|
||||||
|
def remove_from_version_index(self, item: Dict) -> None:
|
||||||
|
"""Remove a cache item from the version index if present."""
|
||||||
|
|
||||||
|
civitai_data = item.get('civitai') if isinstance(item, dict) else None
|
||||||
|
if not isinstance(civitai_data, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
version_id = self._normalize_version_id(civitai_data.get('id'))
|
||||||
|
if version_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
existing = self.version_index.get(version_id)
|
||||||
|
if existing is item or (
|
||||||
|
isinstance(existing, dict)
|
||||||
|
and existing.get('file_path') == item.get('file_path')
|
||||||
|
):
|
||||||
|
self.version_index.pop(version_id, None)
|
||||||
|
|
||||||
async def resort(self):
|
async def resort(self):
|
||||||
"""Resort cached data according to last sort mode if set"""
|
"""Resort cached data according to last sort mode if set"""
|
||||||
@@ -41,6 +95,7 @@ class ModelCache:
|
|||||||
|
|
||||||
all_folders = set(l['folder'] for l in self.raw_data)
|
all_folders = set(l['folder'] for l in self.raw_data)
|
||||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||||
|
self.rebuild_version_index()
|
||||||
|
|
||||||
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
||||||
"""Sort data by sort_key and order"""
|
"""Sort data by sort_key and order"""
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
from ..utils.utils import calculate_relative_path_for_model, remove_empty_dirs
|
from ..utils.utils import calculate_relative_path_for_model, remove_empty_dirs
|
||||||
from ..utils.constants import AUTO_ORGANIZE_BATCH_SIZE
|
from ..utils.constants import AUTO_ORGANIZE_BATCH_SIZE
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -114,7 +114,8 @@ class ModelFileService:
|
|||||||
raise ValueError('No model roots configured')
|
raise ValueError('No model roots configured')
|
||||||
|
|
||||||
# Check if flat structure is configured for this model type
|
# Check if flat structure is configured for this model type
|
||||||
path_template = settings.get_download_path_template(self.model_type)
|
settings_manager = get_settings_manager()
|
||||||
|
path_template = settings_manager.get_download_path_template(self.model_type)
|
||||||
result.is_flat_structure = not path_template
|
result.is_flat_structure = not path_template
|
||||||
|
|
||||||
# Initialize tracking
|
# Initialize tracking
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Dict, Tuple, Any
|
from typing import Optional, Dict, Tuple, Any, List
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -61,6 +61,11 @@ class ModelMetadataProvider(ABC):
|
|||||||
"""Fetch model version metadata"""
|
"""Fetch model version metadata"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
"""Fetch models owned by the specified user"""
|
||||||
|
pass
|
||||||
|
|
||||||
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||||
"""Provider that uses Civitai API for metadata"""
|
"""Provider that uses Civitai API for metadata"""
|
||||||
|
|
||||||
@@ -79,6 +84,9 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
|||||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
return await self.client.get_model_version_info(version_id)
|
return await self.client.get_model_version_info(version_id)
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
return await self.client.get_user_models(username)
|
||||||
|
|
||||||
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||||
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
||||||
|
|
||||||
@@ -197,6 +205,10 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
|||||||
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
||||||
return None, "CivArchive provider requires both model_id and version_id"
|
return None, "CivArchive provider requires both model_id and version_id"
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
"""Not supported by CivArchive provider"""
|
||||||
|
return None
|
||||||
|
|
||||||
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||||
"""Provider that uses SQLite database for metadata"""
|
"""Provider that uses SQLite database for metadata"""
|
||||||
|
|
||||||
@@ -329,20 +341,24 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
"""Fetch model version metadata from SQLite database"""
|
"""Fetch model version metadata from SQLite database"""
|
||||||
async with self._aiosqlite.connect(self.db_path) as db:
|
async with self._aiosqlite.connect(self.db_path) as db:
|
||||||
db.row_factory = self._aiosqlite.Row
|
db.row_factory = self._aiosqlite.Row
|
||||||
|
|
||||||
# Get version details
|
# Get version details
|
||||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||||
cursor = await db.execute(version_query, (version_id,))
|
cursor = await db.execute(version_query, (version_id,))
|
||||||
version_row = await cursor.fetchone()
|
version_row = await cursor.fetchone()
|
||||||
|
|
||||||
if not version_row:
|
if not version_row:
|
||||||
return None, "Model version not found"
|
return None, "Model version not found"
|
||||||
|
|
||||||
model_id = version_row['model_id']
|
model_id = version_row['model_id']
|
||||||
|
|
||||||
# Build complete version data with model info
|
# Build complete version data with model info
|
||||||
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
||||||
return version_data, None
|
return version_data, None
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
"""Listing models by username is not supported for archive database"""
|
||||||
|
return None
|
||||||
|
|
||||||
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
||||||
"""Helper to build version data with model information"""
|
"""Helper to build version data with model information"""
|
||||||
@@ -481,6 +497,17 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
|||||||
continue
|
continue
|
||||||
return None, "No provider could retrieve the data"
|
return None, "No provider could retrieve the data"
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||||
|
for provider in self.providers:
|
||||||
|
try:
|
||||||
|
result = await provider.get_user_models(username)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Provider failed for get_user_models: {e}")
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
class ModelMetadataProviderManager:
|
class ModelMetadataProviderManager:
|
||||||
"""Manager for selecting and using model metadata providers"""
|
"""Manager for selecting and using model metadata providers"""
|
||||||
|
|
||||||
@@ -522,6 +549,11 @@ class ModelMetadataProviderManager:
|
|||||||
"""Fetch model version info using specified or default provider"""
|
"""Fetch model version info using specified or default provider"""
|
||||||
provider = self._get_provider(provider_name)
|
provider = self._get_provider(provider_name)
|
||||||
return await provider.get_model_version_info(version_id)
|
return await provider.get_model_version_info(version_id)
|
||||||
|
|
||||||
|
async def get_user_models(self, username: str, provider_name: str = None) -> Optional[List[Dict]]:
|
||||||
|
"""Fetch models owned by the specified user"""
|
||||||
|
provider = self._get_provider(provider_name)
|
||||||
|
return await provider.get_user_models(username)
|
||||||
|
|
||||||
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
||||||
"""Get provider by name or default provider"""
|
"""Get provider by name or default provider"""
|
||||||
|
|||||||
@@ -634,7 +634,8 @@ class ModelScanner:
|
|||||||
if model_data:
|
if model_data:
|
||||||
# Add to cache
|
# Add to cache
|
||||||
self._cache.raw_data.append(model_data)
|
self._cache.raw_data.append(model_data)
|
||||||
|
self._cache.add_to_version_index(model_data)
|
||||||
|
|
||||||
# Update hash index if available
|
# Update hash index if available
|
||||||
if 'sha256' in model_data and 'file_path' in model_data:
|
if 'sha256' in model_data and 'file_path' in model_data:
|
||||||
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
||||||
@@ -661,7 +662,9 @@ class ModelScanner:
|
|||||||
for path in missing_files:
|
for path in missing_files:
|
||||||
try:
|
try:
|
||||||
model_to_remove = path_to_item[path]
|
model_to_remove = path_to_item[path]
|
||||||
|
|
||||||
|
self._cache.remove_from_version_index(model_to_remove)
|
||||||
|
|
||||||
# Update tags count
|
# Update tags count
|
||||||
for tag in model_to_remove.get('tags', []):
|
for tag in model_to_remove.get('tags', []):
|
||||||
if tag in self._tags_count:
|
if tag in self._tags_count:
|
||||||
@@ -684,6 +687,8 @@ class ModelScanner:
|
|||||||
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
|
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
|
||||||
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||||
|
|
||||||
|
self._cache.rebuild_version_index()
|
||||||
|
|
||||||
# Resort cache
|
# Resort cache
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
@@ -829,6 +834,8 @@ class ModelScanner:
|
|||||||
else:
|
else:
|
||||||
self._cache.raw_data = list(scan_result.raw_data)
|
self._cache.raw_data = list(scan_result.raw_data)
|
||||||
|
|
||||||
|
self._cache.rebuild_version_index()
|
||||||
|
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
async def _gather_model_data(
|
async def _gather_model_data(
|
||||||
@@ -934,7 +941,8 @@ class ModelScanner:
|
|||||||
|
|
||||||
# Add to cache
|
# Add to cache
|
||||||
self._cache.raw_data.append(metadata_dict)
|
self._cache.raw_data.append(metadata_dict)
|
||||||
|
self._cache.add_to_version_index(metadata_dict)
|
||||||
|
|
||||||
# Resort cache data
|
# Resort cache data
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
@@ -1076,6 +1084,9 @@ class ModelScanner:
|
|||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
|
|
||||||
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
|
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
|
||||||
|
if existing_item:
|
||||||
|
cache.remove_from_version_index(existing_item)
|
||||||
|
|
||||||
if existing_item and 'tags' in existing_item:
|
if existing_item and 'tags' in existing_item:
|
||||||
for tag in existing_item.get('tags', []):
|
for tag in existing_item.get('tags', []):
|
||||||
if tag in self._tags_count:
|
if tag in self._tags_count:
|
||||||
@@ -1106,6 +1117,7 @@ class ModelScanner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
cache.raw_data.append(cache_entry)
|
cache.raw_data.append(cache_entry)
|
||||||
|
cache.add_to_version_index(cache_entry)
|
||||||
|
|
||||||
sha_value = cache_entry.get('sha256')
|
sha_value = cache_entry.get('sha256')
|
||||||
if sha_value:
|
if sha_value:
|
||||||
@@ -1117,6 +1129,8 @@ class ModelScanner:
|
|||||||
for tag in cache_entry.get('tags', []):
|
for tag in cache_entry.get('tags', []):
|
||||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||||
|
|
||||||
|
cache.rebuild_version_index()
|
||||||
|
|
||||||
await cache.resort()
|
await cache.resort()
|
||||||
|
|
||||||
if cache_modified:
|
if cache_modified:
|
||||||
@@ -1339,11 +1353,12 @@ class ModelScanner:
|
|||||||
# Update hash index
|
# Update hash index
|
||||||
for model in models_to_remove:
|
for model in models_to_remove:
|
||||||
file_path = model['file_path']
|
file_path = model['file_path']
|
||||||
|
self._cache.remove_from_version_index(model)
|
||||||
if hasattr(self, '_hash_index') and self._hash_index:
|
if hasattr(self, '_hash_index') and self._hash_index:
|
||||||
# Get the hash and filename before removal for duplicate checking
|
# Get the hash and filename before removal for duplicate checking
|
||||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
hash_val = model.get('sha256', '').lower()
|
hash_val = model.get('sha256', '').lower()
|
||||||
|
|
||||||
# Remove from hash index
|
# Remove from hash index
|
||||||
self._hash_index.remove_by_path(file_path, hash_val)
|
self._hash_index.remove_by_path(file_path, hash_val)
|
||||||
|
|
||||||
@@ -1352,8 +1367,9 @@ class ModelScanner:
|
|||||||
|
|
||||||
# Update cache data
|
# Update cache data
|
||||||
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths]
|
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths]
|
||||||
|
|
||||||
# Resort cache
|
# Resort cache
|
||||||
|
self._cache.rebuild_version_index()
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
await self._persist_current_cache()
|
await self._persist_current_cache()
|
||||||
@@ -1393,16 +1409,17 @@ class ModelScanner:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if the model version exists, False otherwise
|
bool: True if the model version exists, False otherwise
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
normalized_id = int(model_version_id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
if not cache or not cache.raw_data:
|
if not cache:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for item in cache.raw_data:
|
return normalized_id in cache.version_index
|
||||||
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking model version existence: {e}")
|
logger.error(f"Error checking model version existence: {e}")
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -351,7 +351,7 @@ class PersistentModelCache:
|
|||||||
|
|
||||||
|
|
||||||
def get_persistent_cache() -> PersistentModelCache:
|
def get_persistent_cache() -> PersistentModelCache:
|
||||||
from .settings_manager import settings as settings_service # Local import to avoid cycles
|
from .settings_manager import get_settings_manager # Local import to avoid cycles
|
||||||
|
|
||||||
library_name = settings_service.get_active_library_name()
|
library_name = get_settings_manager().get_active_library_name()
|
||||||
return PersistentModelCache.get_default(library_name)
|
return PersistentModelCache.get_default(library_name)
|
||||||
|
|||||||
@@ -5,8 +5,10 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Awaitable, Callable, Dict, Optional, Sequence
|
from typing import Awaitable, Callable, Dict, Optional, Sequence
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
|
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
|
||||||
|
from ..utils.civitai_utils import rewrite_preview_url
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -45,23 +47,59 @@ class PreviewAssetService:
|
|||||||
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||||
preview_dir = os.path.dirname(metadata_path)
|
preview_dir = os.path.dirname(metadata_path)
|
||||||
is_video = first_preview.get("type") == "video"
|
is_video = first_preview.get("type") == "video"
|
||||||
|
preview_url = first_preview.get("url")
|
||||||
|
|
||||||
|
if not preview_url:
|
||||||
|
return
|
||||||
|
|
||||||
|
def extension_from_url(url: str, fallback: str) -> str:
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
except ValueError:
|
||||||
|
return fallback
|
||||||
|
ext = os.path.splitext(parsed.path)[1]
|
||||||
|
return ext or fallback
|
||||||
|
|
||||||
|
downloader = await self._downloader_factory()
|
||||||
|
|
||||||
if is_video:
|
if is_video:
|
||||||
extension = ".mp4"
|
extension = extension_from_url(preview_url, ".mp4")
|
||||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||||
downloader = await self._downloader_factory()
|
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="video")
|
||||||
success, result = await downloader.download_file(
|
|
||||||
first_preview["url"], preview_path, use_auth=False
|
attempt_urls = []
|
||||||
)
|
if rewritten:
|
||||||
if success:
|
attempt_urls.append(rewritten_url)
|
||||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
attempt_urls.append(preview_url)
|
||||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
|
||||||
|
seen: set[str] = set()
|
||||||
|
for candidate in attempt_urls:
|
||||||
|
if not candidate or candidate in seen:
|
||||||
|
continue
|
||||||
|
seen.add(candidate)
|
||||||
|
|
||||||
|
success, _ = await downloader.download_file(candidate, preview_path, use_auth=False)
|
||||||
|
if success:
|
||||||
|
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||||
|
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
|
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="image")
|
||||||
|
if rewritten:
|
||||||
|
extension = extension_from_url(preview_url, ".png")
|
||||||
|
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||||
|
success, _ = await downloader.download_file(
|
||||||
|
rewritten_url, preview_path, use_auth=False
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||||
|
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||||
|
return
|
||||||
|
|
||||||
extension = ".webp"
|
extension = ".webp"
|
||||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||||
downloader = await self._downloader_factory()
|
|
||||||
success, content, _headers = await downloader.download_to_memory(
|
success, content, _headers = await downloader.download_to_memory(
|
||||||
first_preview["url"], use_auth=False
|
preview_url, use_auth=False
|
||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from threading import Lock
|
||||||
from typing import Any, Dict, Iterable, List, Mapping, Optional
|
from typing import Any, Dict, Iterable, List, Mapping, Optional
|
||||||
|
|
||||||
from ..utils.settings_paths import ensure_settings_file
|
from ..utils.settings_paths import ensure_settings_file
|
||||||
@@ -688,4 +689,38 @@ class SettingsManager:
|
|||||||
|
|
||||||
return templates.get(model_type, '{base_model}/{first_tag}')
|
return templates.get(model_type, '{base_model}/{first_tag}')
|
||||||
|
|
||||||
settings = SettingsManager()
|
|
||||||
|
_SETTINGS_MANAGER: Optional["SettingsManager"] = None
|
||||||
|
_SETTINGS_MANAGER_LOCK = Lock()
|
||||||
|
# Legacy module-level alias for backwards compatibility with callers that
|
||||||
|
# monkeypatch ``py.services.settings_manager.settings`` during tests.
|
||||||
|
settings: Optional["SettingsManager"] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings_manager() -> "SettingsManager":
|
||||||
|
"""Return the lazily initialised global :class:`SettingsManager`."""
|
||||||
|
|
||||||
|
global _SETTINGS_MANAGER, settings
|
||||||
|
if settings is not None:
|
||||||
|
return settings
|
||||||
|
|
||||||
|
if _SETTINGS_MANAGER is None:
|
||||||
|
with _SETTINGS_MANAGER_LOCK:
|
||||||
|
if _SETTINGS_MANAGER is None:
|
||||||
|
_SETTINGS_MANAGER = SettingsManager()
|
||||||
|
|
||||||
|
settings = _SETTINGS_MANAGER
|
||||||
|
return _SETTINGS_MANAGER
|
||||||
|
|
||||||
|
|
||||||
|
def reset_settings_manager() -> None:
|
||||||
|
"""Reset the cached settings manager instance.
|
||||||
|
|
||||||
|
Primarily intended for tests so they can configure the settings
|
||||||
|
directory before the manager touches the filesystem.
|
||||||
|
"""
|
||||||
|
|
||||||
|
global _SETTINGS_MANAGER, settings
|
||||||
|
with _SETTINGS_MANAGER_LOCK:
|
||||||
|
_SETTINGS_MANAGER = None
|
||||||
|
settings = None
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import logging
|
|||||||
from typing import Any, Dict, Optional, Protocol, Sequence
|
from typing import Any, Dict, Optional, Protocol, Sequence
|
||||||
|
|
||||||
from ..metadata_sync_service import MetadataSyncService
|
from ..metadata_sync_service import MetadataSyncService
|
||||||
|
from ...utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
|
|
||||||
class MetadataRefreshProgressReporter(Protocol):
|
class MetadataRefreshProgressReporter(Protocol):
|
||||||
@@ -70,6 +71,7 @@ class BulkMetadataRefreshUseCase:
|
|||||||
for model in to_process:
|
for model in to_process:
|
||||||
try:
|
try:
|
||||||
original_name = model.get("model_name")
|
original_name = model.get("model_name")
|
||||||
|
await MetadataManager.hydrate_model_data(model)
|
||||||
result, _ = await self._metadata_sync.fetch_and_update_model(
|
result, _ = await self._metadata_sync.fetch_and_update_model(
|
||||||
sha256=model["sha256"],
|
sha256=model["sha256"],
|
||||||
file_path=model["file_path"],
|
file_path=model["file_path"],
|
||||||
|
|||||||
48
py/utils/civitai_utils.py
Normal file
48
py/utils/civitai_utils.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Utilities for working with Civitai assets."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
|
||||||
|
|
||||||
|
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
|
||||||
|
"""Rewrite Civitai preview URLs to use optimized renditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_url: Original preview URL from the Civitai API.
|
||||||
|
media_type: Optional media type hint (e.g. ``"image"`` or ``"video"``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of the potentially rewritten URL and a flag indicating whether the
|
||||||
|
replacement occurred. When the URL is not rewritten, the original value is
|
||||||
|
returned with ``False``.
|
||||||
|
"""
|
||||||
|
if not source_url:
|
||||||
|
return source_url, False
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = urlparse(source_url)
|
||||||
|
except ValueError:
|
||||||
|
return source_url, False
|
||||||
|
|
||||||
|
if parsed.netloc.lower() != "image.civitai.com":
|
||||||
|
return source_url, False
|
||||||
|
|
||||||
|
replacement = "/width=450,optimized=true"
|
||||||
|
if (media_type or "").lower() == "video":
|
||||||
|
replacement = "/transcode=true,width=450,optimized=true"
|
||||||
|
|
||||||
|
if "/original=true" not in parsed.path:
|
||||||
|
return source_url, False
|
||||||
|
|
||||||
|
updated_path = parsed.path.replace("/original=true", replacement, 1)
|
||||||
|
if updated_path == parsed.path:
|
||||||
|
return source_url, False
|
||||||
|
|
||||||
|
rewritten = urlunparse(parsed._replace(path=updated_path))
|
||||||
|
print(rewritten)
|
||||||
|
return rewritten, True
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["rewrite_preview_url"]
|
||||||
|
|
||||||
@@ -48,6 +48,13 @@ SUPPORTED_MEDIA_EXTENSIONS = {
|
|||||||
# Valid Lora types
|
# Valid Lora types
|
||||||
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
|
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
|
||||||
|
|
||||||
|
# Supported Civitai model types for user model queries (case-insensitive)
|
||||||
|
CIVITAI_USER_MODEL_TYPES = [
|
||||||
|
*VALID_LORA_TYPES,
|
||||||
|
'textualinversion',
|
||||||
|
'checkpoint',
|
||||||
|
]
|
||||||
|
|
||||||
# Auto-organize settings
|
# Auto-organize settings
|
||||||
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system
|
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from ..utils.metadata_manager import MetadataManager
|
|||||||
from .example_images_processor import ExampleImagesProcessor
|
from .example_images_processor import ExampleImagesProcessor
|
||||||
from .example_images_metadata import MetadataUpdater
|
from .example_images_metadata import MetadataUpdater
|
||||||
from ..services.downloader import get_downloader
|
from ..services.downloader import get_downloader
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
|
|
||||||
|
|
||||||
class ExampleImagesDownloadError(RuntimeError):
|
class ExampleImagesDownloadError(RuntimeError):
|
||||||
@@ -107,7 +107,7 @@ class DownloadManager:
|
|||||||
self._state_lock = state_lock or asyncio.Lock()
|
self._state_lock = state_lock or asyncio.Lock()
|
||||||
|
|
||||||
def _resolve_output_dir(self, library_name: str | None = None) -> str:
|
def _resolve_output_dir(self, library_name: str | None = None) -> str:
|
||||||
base_path = settings.get('example_images_path')
|
base_path = get_settings_manager().get('example_images_path')
|
||||||
if not base_path:
|
if not base_path:
|
||||||
return ''
|
return ''
|
||||||
return ensure_library_root_exists(library_name)
|
return ensure_library_root_exists(library_name)
|
||||||
@@ -126,7 +126,8 @@ class DownloadManager:
|
|||||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||||
delay = float(data.get('delay', 0.2))
|
delay = float(data.get('delay', 0.2))
|
||||||
|
|
||||||
base_path = settings.get('example_images_path')
|
settings_manager = get_settings_manager()
|
||||||
|
base_path = settings_manager.get('example_images_path')
|
||||||
|
|
||||||
if not base_path:
|
if not base_path:
|
||||||
error_msg = 'Example images path not configured in settings'
|
error_msg = 'Example images path not configured in settings'
|
||||||
@@ -138,7 +139,7 @@ class DownloadManager:
|
|||||||
}
|
}
|
||||||
raise DownloadConfigurationError(error_msg)
|
raise DownloadConfigurationError(error_msg)
|
||||||
|
|
||||||
active_library = settings.get_active_library_name()
|
active_library = get_settings_manager().get_active_library_name()
|
||||||
output_dir = self._resolve_output_dir(active_library)
|
output_dir = self._resolve_output_dir(active_library)
|
||||||
if not output_dir:
|
if not output_dir:
|
||||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||||
@@ -151,7 +152,7 @@ class DownloadManager:
|
|||||||
progress_file = os.path.join(output_dir, '.download_progress.json')
|
progress_file = os.path.join(output_dir, '.download_progress.json')
|
||||||
progress_source = progress_file
|
progress_source = progress_file
|
||||||
if uses_library_scoped_folders():
|
if uses_library_scoped_folders():
|
||||||
legacy_root = settings.get('example_images_path') or ''
|
legacy_root = get_settings_manager().get('example_images_path') or ''
|
||||||
legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else ''
|
legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else ''
|
||||||
if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file):
|
if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file):
|
||||||
try:
|
try:
|
||||||
@@ -555,11 +556,12 @@ class DownloadManager:
|
|||||||
if not model_hashes:
|
if not model_hashes:
|
||||||
raise DownloadConfigurationError('Missing model_hashes parameter')
|
raise DownloadConfigurationError('Missing model_hashes parameter')
|
||||||
|
|
||||||
base_path = settings.get('example_images_path')
|
settings_manager = get_settings_manager()
|
||||||
|
base_path = settings_manager.get('example_images_path')
|
||||||
|
|
||||||
if not base_path:
|
if not base_path:
|
||||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||||
active_library = settings.get_active_library_name()
|
active_library = settings_manager.get_active_library_name()
|
||||||
output_dir = self._resolve_output_dir(active_library)
|
output_dir = self._resolve_output_dir(active_library)
|
||||||
if not output_dir:
|
if not output_dir:
|
||||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import subprocess
|
import subprocess
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
from ..utils.example_images_paths import (
|
from ..utils.example_images_paths import (
|
||||||
get_model_folder,
|
get_model_folder,
|
||||||
get_model_relative_path,
|
get_model_relative_path,
|
||||||
@@ -37,7 +37,8 @@ class ExampleImagesFileManager:
|
|||||||
}, status=400)
|
}, status=400)
|
||||||
|
|
||||||
# Get example images path from settings
|
# Get example images path from settings
|
||||||
example_images_path = settings.get('example_images_path')
|
settings_manager = get_settings_manager()
|
||||||
|
example_images_path = settings_manager.get('example_images_path')
|
||||||
if not example_images_path:
|
if not example_images_path:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': False,
|
'success': False,
|
||||||
@@ -109,7 +110,8 @@ class ExampleImagesFileManager:
|
|||||||
}, status=400)
|
}, status=400)
|
||||||
|
|
||||||
# Get example images path from settings
|
# Get example images path from settings
|
||||||
example_images_path = settings.get('example_images_path')
|
settings_manager = get_settings_manager()
|
||||||
|
example_images_path = settings_manager.get('example_images_path')
|
||||||
if not example_images_path:
|
if not example_images_path:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': False,
|
'success': False,
|
||||||
@@ -183,7 +185,8 @@ class ExampleImagesFileManager:
|
|||||||
}, status=400)
|
}, status=400)
|
||||||
|
|
||||||
# Get example images path from settings
|
# Get example images path from settings
|
||||||
example_images_path = settings.get('example_images_path')
|
settings_manager = get_settings_manager()
|
||||||
|
example_images_path = settings_manager.get('example_images_path')
|
||||||
if not example_images_path:
|
if not example_images_path:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'has_images': False
|
'has_images': False
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
from ..recipes.constants import GEN_PARAM_KEYS
|
from ..recipes.constants import GEN_PARAM_KEYS
|
||||||
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||||
from ..services.metadata_sync_service import MetadataSyncService
|
from ..services.metadata_sync_service import MetadataSyncService
|
||||||
from ..services.preview_asset_service import PreviewAssetService
|
from ..services.preview_asset_service import PreviewAssetService
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
from ..services.downloader import get_downloader
|
from ..services.downloader import get_downloader
|
||||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
@@ -20,13 +21,46 @@ _preview_service = PreviewAssetService(
|
|||||||
exif_utils=ExifUtils,
|
exif_utils=ExifUtils,
|
||||||
)
|
)
|
||||||
|
|
||||||
_metadata_sync_service = MetadataSyncService(
|
_metadata_sync_service: MetadataSyncService | None = None
|
||||||
metadata_manager=MetadataManager,
|
_metadata_sync_service_settings: Optional["SettingsManager"] = None
|
||||||
preview_service=_preview_service,
|
|
||||||
settings=settings,
|
if TYPE_CHECKING: # pragma: no cover - import for type checkers only
|
||||||
default_metadata_provider_factory=get_default_metadata_provider,
|
from ..services.settings_manager import SettingsManager
|
||||||
metadata_provider_selector=get_metadata_provider,
|
|
||||||
)
|
|
||||||
|
def _build_metadata_sync_service(settings_manager: "SettingsManager") -> MetadataSyncService:
|
||||||
|
"""Construct a metadata sync service bound to the provided settings."""
|
||||||
|
|
||||||
|
return MetadataSyncService(
|
||||||
|
metadata_manager=MetadataManager,
|
||||||
|
preview_service=_preview_service,
|
||||||
|
settings=settings_manager,
|
||||||
|
default_metadata_provider_factory=get_default_metadata_provider,
|
||||||
|
metadata_provider_selector=get_metadata_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_metadata_sync_service() -> MetadataSyncService:
|
||||||
|
"""Return the shared metadata sync service, initialising it lazily."""
|
||||||
|
|
||||||
|
global _metadata_sync_service, _metadata_sync_service_settings
|
||||||
|
|
||||||
|
settings_manager = get_settings_manager()
|
||||||
|
|
||||||
|
if isinstance(_metadata_sync_service, MetadataSyncService):
|
||||||
|
if _metadata_sync_service_settings is not settings_manager:
|
||||||
|
_metadata_sync_service = _build_metadata_sync_service(settings_manager)
|
||||||
|
_metadata_sync_service_settings = settings_manager
|
||||||
|
elif _metadata_sync_service is None:
|
||||||
|
_metadata_sync_service = _build_metadata_sync_service(settings_manager)
|
||||||
|
_metadata_sync_service_settings = settings_manager
|
||||||
|
else:
|
||||||
|
# Tests may inject stand-ins that do not match the sync service type. Preserve
|
||||||
|
# those injections while still updating our cached settings reference so the
|
||||||
|
# next real service instantiation uses the current configuration.
|
||||||
|
_metadata_sync_service_settings = settings_manager
|
||||||
|
|
||||||
|
return _metadata_sync_service
|
||||||
|
|
||||||
|
|
||||||
class MetadataUpdater:
|
class MetadataUpdater:
|
||||||
@@ -71,7 +105,8 @@ class MetadataUpdater:
|
|||||||
async def update_cache_func(old_path, new_path, metadata):
|
async def update_cache_func(old_path, new_path, metadata):
|
||||||
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
||||||
|
|
||||||
success, error = await _metadata_sync_service.fetch_and_update_model(
|
await MetadataManager.hydrate_model_data(model_data)
|
||||||
|
success, error = await _get_metadata_sync_service().fetch_and_update_model(
|
||||||
sha256=model_hash,
|
sha256=model_hash,
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
model_data=model_data,
|
model_data=model_data,
|
||||||
@@ -151,16 +186,16 @@ class MetadataUpdater:
|
|||||||
if is_supported:
|
if is_supported:
|
||||||
local_images_paths.append(file_path)
|
local_images_paths.append(file_path)
|
||||||
|
|
||||||
|
await MetadataManager.hydrate_model_data(model)
|
||||||
|
civitai_data = model.setdefault('civitai', {})
|
||||||
|
|
||||||
# Check if metadata update is needed (no civitai field or empty images)
|
# Check if metadata update is needed (no civitai field or empty images)
|
||||||
needs_update = not model.get('civitai') or not model.get('civitai', {}).get('images')
|
needs_update = not civitai_data or not civitai_data.get('images')
|
||||||
|
|
||||||
if needs_update and local_images_paths:
|
if needs_update and local_images_paths:
|
||||||
logger.debug(f"Found {len(local_images_paths)} local example images for {model.get('model_name')}, updating metadata")
|
logger.debug(f"Found {len(local_images_paths)} local example images for {model.get('model_name')}, updating metadata")
|
||||||
|
|
||||||
# Create or get civitai field
|
# Create or get civitai field
|
||||||
if not model.get('civitai'):
|
|
||||||
model['civitai'] = {}
|
|
||||||
|
|
||||||
# Create images array
|
# Create images array
|
||||||
images = []
|
images = []
|
||||||
|
|
||||||
@@ -195,16 +230,13 @@ class MetadataUpdater:
|
|||||||
images.append(image_entry)
|
images.append(image_entry)
|
||||||
|
|
||||||
# Update the model's civitai.images field
|
# Update the model's civitai.images field
|
||||||
model['civitai']['images'] = images
|
civitai_data['images'] = images
|
||||||
|
|
||||||
# Save metadata to .metadata.json file
|
# Save metadata to .metadata.json file
|
||||||
file_path = model.get('file_path')
|
file_path = model.get('file_path')
|
||||||
try:
|
try:
|
||||||
# Create a copy of model data without 'folder' field
|
|
||||||
model_copy = model.copy()
|
model_copy = model.copy()
|
||||||
model_copy.pop('folder', None)
|
model_copy.pop('folder', None)
|
||||||
|
|
||||||
# Write metadata to file
|
|
||||||
await MetadataManager.save_metadata(file_path, model_copy)
|
await MetadataManager.save_metadata(file_path, model_copy)
|
||||||
logger.info(f"Saved metadata for {model.get('model_name')}")
|
logger.info(f"Saved metadata for {model.get('model_name')}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -237,16 +269,13 @@ class MetadataUpdater:
|
|||||||
tuple: (regular_images, custom_images) - Both image arrays
|
tuple: (regular_images, custom_images) - Both image arrays
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Ensure civitai field exists in model_data
|
await MetadataManager.hydrate_model_data(model_data)
|
||||||
if not model_data.get('civitai'):
|
civitai_data = model_data.setdefault('civitai', {})
|
||||||
model_data['civitai'] = {}
|
custom_images = civitai_data.get('customImages')
|
||||||
|
|
||||||
# Ensure customImages array exists
|
if not isinstance(custom_images, list):
|
||||||
if not model_data['civitai'].get('customImages'):
|
custom_images = []
|
||||||
model_data['civitai']['customImages'] = []
|
civitai_data['customImages'] = custom_images
|
||||||
|
|
||||||
# Get current customImages array
|
|
||||||
custom_images = model_data['civitai']['customImages']
|
|
||||||
|
|
||||||
# Add new image entry for each imported file
|
# Add new image entry for each imported file
|
||||||
for path_tuple in newly_imported_paths:
|
for path_tuple in newly_imported_paths:
|
||||||
@@ -304,11 +333,8 @@ class MetadataUpdater:
|
|||||||
file_path = model_data.get('file_path')
|
file_path = model_data.get('file_path')
|
||||||
if file_path:
|
if file_path:
|
||||||
try:
|
try:
|
||||||
# Create a copy of model data without 'folder' field
|
|
||||||
model_copy = model_data.copy()
|
model_copy = model_data.copy()
|
||||||
model_copy.pop('folder', None)
|
model_copy.pop('folder', None)
|
||||||
|
|
||||||
# Write metadata to file
|
|
||||||
await MetadataManager.save_metadata(file_path, model_copy)
|
await MetadataManager.save_metadata(file_path, model_copy)
|
||||||
logger.info(f"Saved metadata for {model_data.get('model_name')}")
|
logger.info(f"Saved metadata for {model_data.get('model_name')}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -319,7 +345,7 @@ class MetadataUpdater:
|
|||||||
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
||||||
|
|
||||||
# Get regular images array (might be None)
|
# Get regular images array (might be None)
|
||||||
regular_images = model_data['civitai'].get('images', [])
|
regular_images = civitai_data.get('images', [])
|
||||||
|
|
||||||
# Return both image arrays
|
# Return both image arrays
|
||||||
return regular_images, custom_images
|
return regular_images, custom_images
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..utils.example_images_paths import iter_library_roots
|
from ..utils.example_images_paths import iter_library_roots
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
@@ -14,6 +14,25 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
CURRENT_NAMING_VERSION = 2 # Increment this when naming conventions change
|
CURRENT_NAMING_VERSION = 2 # Increment this when naming conventions change
|
||||||
|
|
||||||
|
|
||||||
|
class _SettingsProxy:
|
||||||
|
def __init__(self):
|
||||||
|
self._manager = None
|
||||||
|
|
||||||
|
def _resolve(self):
|
||||||
|
if self._manager is None:
|
||||||
|
self._manager = get_settings_manager()
|
||||||
|
return self._manager
|
||||||
|
|
||||||
|
def get(self, *args, **kwargs):
|
||||||
|
return self._resolve().get(*args, **kwargs)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
return getattr(self._resolve(), item)
|
||||||
|
|
||||||
|
|
||||||
|
settings = _SettingsProxy()
|
||||||
|
|
||||||
class ExampleImagesMigration:
|
class ExampleImagesMigration:
|
||||||
"""Handles migrations for example images naming conventions"""
|
"""Handles migrations for example images naming conventions"""
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
|
|
||||||
_HEX_PATTERN = re.compile(r"[a-fA-F0-9]{64}")
|
_HEX_PATTERN = re.compile(r"[a-fA-F0-9]{64}")
|
||||||
|
|
||||||
@@ -18,7 +18,8 @@ logger = logging.getLogger(__name__)
|
|||||||
def _get_configured_libraries() -> List[str]:
|
def _get_configured_libraries() -> List[str]:
|
||||||
"""Return configured library names if multi-library support is enabled."""
|
"""Return configured library names if multi-library support is enabled."""
|
||||||
|
|
||||||
libraries = settings.get("libraries")
|
settings_manager = get_settings_manager()
|
||||||
|
libraries = settings_manager.get("libraries")
|
||||||
if isinstance(libraries, dict) and libraries:
|
if isinstance(libraries, dict) and libraries:
|
||||||
return list(libraries.keys())
|
return list(libraries.keys())
|
||||||
return []
|
return []
|
||||||
@@ -27,7 +28,8 @@ def _get_configured_libraries() -> List[str]:
|
|||||||
def get_example_images_root() -> str:
|
def get_example_images_root() -> str:
|
||||||
"""Return the root directory configured for example images."""
|
"""Return the root directory configured for example images."""
|
||||||
|
|
||||||
root = settings.get("example_images_path") or ""
|
settings_manager = get_settings_manager()
|
||||||
|
root = settings_manager.get("example_images_path") or ""
|
||||||
return os.path.abspath(root) if root else ""
|
return os.path.abspath(root) if root else ""
|
||||||
|
|
||||||
|
|
||||||
@@ -41,7 +43,8 @@ def uses_library_scoped_folders() -> bool:
|
|||||||
def sanitize_library_name(library_name: Optional[str]) -> str:
|
def sanitize_library_name(library_name: Optional[str]) -> str:
|
||||||
"""Return a filesystem safe library name."""
|
"""Return a filesystem safe library name."""
|
||||||
|
|
||||||
name = library_name or settings.get_active_library_name() or "default"
|
settings_manager = get_settings_manager()
|
||||||
|
name = library_name or settings_manager.get_active_library_name() or "default"
|
||||||
safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", name)
|
safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", name)
|
||||||
return safe_name or "default"
|
return safe_name or "default"
|
||||||
|
|
||||||
@@ -161,11 +164,13 @@ def iter_library_roots() -> Iterable[Tuple[str, str]]:
|
|||||||
results.append((library, get_library_root(library)))
|
results.append((library, get_library_root(library)))
|
||||||
else:
|
else:
|
||||||
# Fall back to the active library to avoid skipping migrations/cleanup
|
# Fall back to the active library to avoid skipping migrations/cleanup
|
||||||
active = settings.get_active_library_name() or "default"
|
settings_manager = get_settings_manager()
|
||||||
|
active = settings_manager.get_active_library_name() or "default"
|
||||||
results.append((active, get_library_root(active)))
|
results.append((active, get_library_root(active)))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
active = settings.get_active_library_name() or "default"
|
settings_manager = get_settings_manager()
|
||||||
|
active = settings_manager.get_active_library_name() or "default"
|
||||||
return [(active, root)]
|
return [(active, root)]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import string
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
from ..utils.example_images_paths import get_model_folder, get_model_relative_path
|
from ..utils.example_images_paths import get_model_folder, get_model_relative_path
|
||||||
from .example_images_metadata import MetadataUpdater
|
from .example_images_metadata import MetadataUpdater
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
@@ -318,7 +318,7 @@ class ExampleImagesProcessor:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Get example images path
|
# Get example images path
|
||||||
example_images_path = settings.get('example_images_path')
|
example_images_path = get_settings_manager().get('example_images_path')
|
||||||
if not example_images_path:
|
if not example_images_path:
|
||||||
raise ExampleImagesValidationError('No example images path configured')
|
raise ExampleImagesValidationError('No example images path configured')
|
||||||
|
|
||||||
@@ -442,7 +442,7 @@ class ExampleImagesProcessor:
|
|||||||
}, status=400)
|
}, status=400)
|
||||||
|
|
||||||
# Get example images path
|
# Get example images path
|
||||||
example_images_path = settings.get('example_images_path')
|
example_images_path = get_settings_manager().get('example_images_path')
|
||||||
if not example_images_path:
|
if not example_images_path:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': False,
|
'success': False,
|
||||||
@@ -475,15 +475,17 @@ class ExampleImagesProcessor:
|
|||||||
'error': f"Model with hash {model_hash} not found in cache"
|
'error': f"Model with hash {model_hash} not found in cache"
|
||||||
}, status=404)
|
}, status=404)
|
||||||
|
|
||||||
# Check if model has custom images
|
await MetadataManager.hydrate_model_data(model_data)
|
||||||
if not model_data.get('civitai', {}).get('customImages'):
|
civitai_data = model_data.setdefault('civitai', {})
|
||||||
|
custom_images = civitai_data.get('customImages')
|
||||||
|
|
||||||
|
if not isinstance(custom_images, list) or not custom_images:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': False,
|
'success': False,
|
||||||
'error': f"Model has no custom images"
|
'error': f"Model has no custom images"
|
||||||
}, status=404)
|
}, status=404)
|
||||||
|
|
||||||
# Find the custom image with matching short_id
|
# Find the custom image with matching short_id
|
||||||
custom_images = model_data['civitai']['customImages']
|
|
||||||
matching_image = None
|
matching_image = None
|
||||||
new_custom_images = []
|
new_custom_images = []
|
||||||
|
|
||||||
@@ -527,17 +529,15 @@ class ExampleImagesProcessor:
|
|||||||
logger.warning(f"File for custom example with id {short_id} not found, but metadata will still be updated")
|
logger.warning(f"File for custom example with id {short_id} not found, but metadata will still be updated")
|
||||||
|
|
||||||
# Update metadata
|
# Update metadata
|
||||||
model_data['civitai']['customImages'] = new_custom_images
|
civitai_data['customImages'] = new_custom_images
|
||||||
|
model_data.setdefault('civitai', {})['customImages'] = new_custom_images
|
||||||
|
|
||||||
# Save updated metadata to file
|
# Save updated metadata to file
|
||||||
file_path = model_data.get('file_path')
|
file_path = model_data.get('file_path')
|
||||||
if file_path:
|
if file_path:
|
||||||
try:
|
try:
|
||||||
# Create a copy of model data without 'folder' field
|
|
||||||
model_copy = model_data.copy()
|
model_copy = model_data.copy()
|
||||||
model_copy.pop('folder', None)
|
model_copy.pop('folder', None)
|
||||||
|
|
||||||
# Write metadata to file
|
|
||||||
await MetadataManager.save_metadata(file_path, model_copy)
|
await MetadataManager.save_metadata(file_path, model_copy)
|
||||||
logger.debug(f"Saved updated metadata for {model_data.get('model_name')}")
|
logger.debug(f"Saved updated metadata for {model_data.get('model_name')}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -551,7 +551,7 @@ class ExampleImagesProcessor:
|
|||||||
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
||||||
|
|
||||||
# Get regular images array (might be None)
|
# Get regular images array (might be None)
|
||||||
regular_images = model_data['civitai'].get('images', [])
|
regular_images = civitai_data.get('images', [])
|
||||||
|
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': True,
|
'success': True,
|
||||||
@@ -568,4 +568,4 @@ class ExampleImagesProcessor:
|
|||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Optional, Type, Union
|
from typing import Any, Dict, Optional, Type, Union
|
||||||
|
|
||||||
from .models import BaseModelMetadata, LoraMetadata
|
from .models import BaseModelMetadata, LoraMetadata
|
||||||
from .file_utils import normalize_path, find_preview_file, calculate_sha256
|
from .file_utils import normalize_path, find_preview_file, calculate_sha256
|
||||||
@@ -53,6 +53,70 @@ class MetadataManager:
|
|||||||
error_type = "Invalid JSON" if isinstance(e, json.JSONDecodeError) else "Parse error"
|
error_type = "Invalid JSON" if isinstance(e, json.JSONDecodeError) else "Parse error"
|
||||||
logger.error(f"{error_type} in metadata file: {metadata_path}. Error: {str(e)}. Skipping model to preserve existing data.")
|
logger.error(f"{error_type} in metadata file: {metadata_path}. Error: {str(e)}. Skipping model to preserve existing data.")
|
||||||
return None, True # should_skip = True
|
return None, True # should_skip = True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def load_metadata_payload(file_path: str) -> Dict:
|
||||||
|
"""
|
||||||
|
Load metadata and return it as a dictionary, including any unknown fields.
|
||||||
|
Falls back to reading the raw JSON file if parsing into a model class fails.
|
||||||
|
"""
|
||||||
|
|
||||||
|
payload: Dict = {}
|
||||||
|
metadata_obj, should_skip = await MetadataManager.load_metadata(file_path)
|
||||||
|
|
||||||
|
if metadata_obj:
|
||||||
|
payload = metadata_obj.to_dict()
|
||||||
|
unknown_fields = getattr(metadata_obj, "_unknown_fields", None)
|
||||||
|
if isinstance(unknown_fields, dict):
|
||||||
|
payload.update(unknown_fields)
|
||||||
|
else:
|
||||||
|
if not should_skip:
|
||||||
|
metadata_path = (
|
||||||
|
file_path
|
||||||
|
if file_path.endswith(".metadata.json")
|
||||||
|
else f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||||
|
)
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
|
try:
|
||||||
|
with open(metadata_path, "r", encoding="utf-8") as handle:
|
||||||
|
raw = json.load(handle)
|
||||||
|
if isinstance(raw, dict):
|
||||||
|
payload = raw
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to parse metadata file %s while loading payload",
|
||||||
|
metadata_path,
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.warning("Failed to read metadata file %s: %s", metadata_path, exc)
|
||||||
|
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
payload = {}
|
||||||
|
|
||||||
|
if file_path:
|
||||||
|
payload.setdefault("file_path", normalize_path(file_path))
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def hydrate_model_data(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Replace the provided model data with the authoritative payload from disk.
|
||||||
|
Preserves the cached folder entry if present.
|
||||||
|
"""
|
||||||
|
|
||||||
|
file_path = model_data.get("file_path")
|
||||||
|
if not file_path:
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
folder = model_data.get("folder")
|
||||||
|
payload = await MetadataManager.load_metadata_payload(file_path)
|
||||||
|
if folder is not None:
|
||||||
|
payload["folder"] = folder
|
||||||
|
|
||||||
|
model_data.clear()
|
||||||
|
model_data.update(payload)
|
||||||
|
return model_data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def save_metadata(path: str, metadata: Union[BaseModelMetadata, Dict]) -> bool:
|
async def save_metadata(path: str, metadata: Union[BaseModelMetadata, Dict]) -> bool:
|
||||||
|
|||||||
@@ -65,6 +65,12 @@ def ensure_settings_file(logger: Optional[logging.Logger] = None) -> str:
|
|||||||
|
|
||||||
logger = logger or _LOGGER
|
logger = logger or _LOGGER
|
||||||
target_path = get_settings_file_path(create_dir=True)
|
target_path = get_settings_file_path(create_dir=True)
|
||||||
|
preferred_dir = user_config_dir(APP_NAME, appauthor=False)
|
||||||
|
preferred_path = os.path.join(preferred_dir, "settings.json")
|
||||||
|
|
||||||
|
if os.path.abspath(target_path) != os.path.abspath(preferred_path):
|
||||||
|
os.makedirs(preferred_dir, exist_ok=True)
|
||||||
|
target_path = preferred_path
|
||||||
legacy_path = get_legacy_settings_path()
|
legacy_path = get_legacy_settings_path()
|
||||||
|
|
||||||
if os.path.exists(legacy_path) and not os.path.exists(target_path):
|
if os.path.exists(legacy_path) and not os.path.exists(target_path):
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import get_settings_manager
|
||||||
from .constants import CIVITAI_MODEL_TAGS
|
from .constants import CIVITAI_MODEL_TAGS
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
@@ -143,7 +143,8 @@ def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora'
|
|||||||
Relative path string (empty string for flat structure)
|
Relative path string (empty string for flat structure)
|
||||||
"""
|
"""
|
||||||
# Get path template from settings for specific model type
|
# Get path template from settings for specific model type
|
||||||
path_template = settings.get_download_path_template(model_type)
|
settings_manager = get_settings_manager()
|
||||||
|
path_template = settings_manager.get_download_path_template(model_type)
|
||||||
|
|
||||||
# If template is empty, return empty path (flat structure)
|
# If template is empty, return empty path (flat structure)
|
||||||
if not path_template:
|
if not path_template:
|
||||||
@@ -166,7 +167,7 @@ def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora'
|
|||||||
model_tags = model_data.get('tags', [])
|
model_tags = model_data.get('tags', [])
|
||||||
|
|
||||||
# Apply mapping if available
|
# Apply mapping if available
|
||||||
base_model_mappings = settings.get('base_model_path_mappings', {})
|
base_model_mappings = settings_manager.get('base_model_path_mappings', {})
|
||||||
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
||||||
|
|
||||||
# Find the first Civitai model tag that exists in model_tags
|
# Find the first Civitai model tag that exists in model_tags
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui-lora-manager"
|
name = "comfyui-lora-manager"
|
||||||
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
||||||
version = "0.9.6"
|
version = "0.9.7"
|
||||||
license = {file = "LICENSE"}
|
license = {file = "LICENSE"}
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ export class SidebarManager {
|
|||||||
this.isInitialized = false;
|
this.isInitialized = false;
|
||||||
this.displayMode = 'tree'; // 'tree' or 'list'
|
this.displayMode = 'tree'; // 'tree' or 'list'
|
||||||
this.foldersList = [];
|
this.foldersList = [];
|
||||||
|
this.recursiveSearchEnabled = true;
|
||||||
|
|
||||||
// Bind methods
|
// Bind methods
|
||||||
this.handleTreeClick = this.handleTreeClick.bind(this);
|
this.handleTreeClick = this.handleTreeClick.bind(this);
|
||||||
@@ -36,6 +37,7 @@ export class SidebarManager {
|
|||||||
this.updateContainerMargin = this.updateContainerMargin.bind(this);
|
this.updateContainerMargin = this.updateContainerMargin.bind(this);
|
||||||
this.handleDisplayModeToggle = this.handleDisplayModeToggle.bind(this);
|
this.handleDisplayModeToggle = this.handleDisplayModeToggle.bind(this);
|
||||||
this.handleFolderListClick = this.handleFolderListClick.bind(this);
|
this.handleFolderListClick = this.handleFolderListClick.bind(this);
|
||||||
|
this.handleRecursiveToggle = this.handleRecursiveToggle.bind(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
async initialize(pageControls) {
|
async initialize(pageControls) {
|
||||||
@@ -89,6 +91,7 @@ export class SidebarManager {
|
|||||||
this.isHovering = false;
|
this.isHovering = false;
|
||||||
this.apiClient = null;
|
this.apiClient = null;
|
||||||
this.isInitialized = false;
|
this.isInitialized = false;
|
||||||
|
this.recursiveSearchEnabled = true;
|
||||||
|
|
||||||
// Reset container margin
|
// Reset container margin
|
||||||
const container = document.querySelector('.container');
|
const container = document.querySelector('.container');
|
||||||
@@ -111,6 +114,7 @@ export class SidebarManager {
|
|||||||
const sidebar = document.getElementById('folderSidebar');
|
const sidebar = document.getElementById('folderSidebar');
|
||||||
const hoverArea = document.getElementById('sidebarHoverArea');
|
const hoverArea = document.getElementById('sidebarHoverArea');
|
||||||
const displayModeToggleBtn = document.getElementById('sidebarDisplayModeToggle');
|
const displayModeToggleBtn = document.getElementById('sidebarDisplayModeToggle');
|
||||||
|
const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
|
||||||
|
|
||||||
if (pinToggleBtn) {
|
if (pinToggleBtn) {
|
||||||
pinToggleBtn.removeEventListener('click', this.handlePinToggle);
|
pinToggleBtn.removeEventListener('click', this.handlePinToggle);
|
||||||
@@ -145,6 +149,9 @@ export class SidebarManager {
|
|||||||
if (displayModeToggleBtn) {
|
if (displayModeToggleBtn) {
|
||||||
displayModeToggleBtn.removeEventListener('click', this.handleDisplayModeToggle);
|
displayModeToggleBtn.removeEventListener('click', this.handleDisplayModeToggle);
|
||||||
}
|
}
|
||||||
|
if (recursiveToggleBtn) {
|
||||||
|
recursiveToggleBtn.removeEventListener('click', this.handleRecursiveToggle);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async init() {
|
async init() {
|
||||||
@@ -197,7 +204,7 @@ export class SidebarManager {
|
|||||||
updateSidebarTitle() {
|
updateSidebarTitle() {
|
||||||
const sidebarTitle = document.getElementById('sidebarTitle');
|
const sidebarTitle = document.getElementById('sidebarTitle');
|
||||||
if (sidebarTitle) {
|
if (sidebarTitle) {
|
||||||
sidebarTitle.textContent = `${this.apiClient.apiConfig.config.displayName} Root`;
|
sidebarTitle.textContent = translate('sidebar.modelRoot');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,6 +227,12 @@ export class SidebarManager {
|
|||||||
collapseAllBtn.addEventListener('click', this.handleCollapseAll);
|
collapseAllBtn.addEventListener('click', this.handleCollapseAll);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Recursive toggle button
|
||||||
|
const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
|
||||||
|
if (recursiveToggleBtn) {
|
||||||
|
recursiveToggleBtn.addEventListener('click', this.handleRecursiveToggle);
|
||||||
|
}
|
||||||
|
|
||||||
// Tree click handler
|
// Tree click handler
|
||||||
const folderTree = document.getElementById('sidebarFolderTree');
|
const folderTree = document.getElementById('sidebarFolderTree');
|
||||||
if (folderTree) {
|
if (folderTree) {
|
||||||
@@ -645,11 +658,33 @@ export class SidebarManager {
|
|||||||
this.displayMode = this.displayMode === 'tree' ? 'list' : 'tree';
|
this.displayMode = this.displayMode === 'tree' ? 'list' : 'tree';
|
||||||
this.updateDisplayModeButton();
|
this.updateDisplayModeButton();
|
||||||
this.updateCollapseAllButton();
|
this.updateCollapseAllButton();
|
||||||
|
this.updateRecursiveToggleButton();
|
||||||
this.updateSearchRecursiveOption();
|
this.updateSearchRecursiveOption();
|
||||||
this.saveDisplayMode();
|
this.saveDisplayMode();
|
||||||
this.loadFolderTree(); // Reload with new display mode
|
this.loadFolderTree(); // Reload with new display mode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async handleRecursiveToggle(event) {
|
||||||
|
event.stopPropagation();
|
||||||
|
|
||||||
|
if (this.displayMode !== 'tree') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.recursiveSearchEnabled = !this.recursiveSearchEnabled;
|
||||||
|
setStorageItem(`${this.pageType}_recursiveSearch`, this.recursiveSearchEnabled);
|
||||||
|
this.updateSearchRecursiveOption();
|
||||||
|
this.updateRecursiveToggleButton();
|
||||||
|
|
||||||
|
if (this.pageControls && typeof this.pageControls.resetAndReload === 'function') {
|
||||||
|
try {
|
||||||
|
await this.pageControls.resetAndReload(true);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to reload models after toggling recursive search:', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
updateDisplayModeButton() {
|
updateDisplayModeButton() {
|
||||||
const displayModeBtn = document.getElementById('sidebarDisplayModeToggle');
|
const displayModeBtn = document.getElementById('sidebarDisplayModeToggle');
|
||||||
if (displayModeBtn) {
|
if (displayModeBtn) {
|
||||||
@@ -679,8 +714,35 @@ export class SidebarManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
updateRecursiveToggleButton() {
|
||||||
|
const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
|
||||||
|
if (!recursiveToggleBtn) return;
|
||||||
|
|
||||||
|
const icon = recursiveToggleBtn.querySelector('i');
|
||||||
|
const isTreeMode = this.displayMode === 'tree';
|
||||||
|
const isActive = isTreeMode && this.recursiveSearchEnabled;
|
||||||
|
|
||||||
|
recursiveToggleBtn.classList.toggle('active', isActive);
|
||||||
|
recursiveToggleBtn.classList.toggle('disabled', !isTreeMode);
|
||||||
|
recursiveToggleBtn.setAttribute('aria-pressed', isActive ? 'true' : 'false');
|
||||||
|
recursiveToggleBtn.setAttribute('aria-disabled', isTreeMode ? 'false' : 'true');
|
||||||
|
|
||||||
|
if (icon) {
|
||||||
|
icon.className = 'fas fa-code-branch';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isTreeMode) {
|
||||||
|
recursiveToggleBtn.title = translate('sidebar.recursiveUnavailable');
|
||||||
|
} else if (this.recursiveSearchEnabled) {
|
||||||
|
recursiveToggleBtn.title = translate('sidebar.recursiveOn');
|
||||||
|
} else {
|
||||||
|
recursiveToggleBtn.title = translate('sidebar.recursiveOff');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
updateSearchRecursiveOption() {
|
updateSearchRecursiveOption() {
|
||||||
this.pageControls.pageState.searchOptions.recursive = this.displayMode === 'tree';
|
const isRecursive = this.displayMode === 'tree' && this.recursiveSearchEnabled;
|
||||||
|
this.pageControls.pageState.searchOptions.recursive = isRecursive;
|
||||||
}
|
}
|
||||||
|
|
||||||
updateTreeSelection() {
|
updateTreeSelection() {
|
||||||
@@ -925,15 +987,18 @@ export class SidebarManager {
|
|||||||
const isPinned = getStorageItem(`${this.pageType}_sidebarPinned`, true);
|
const isPinned = getStorageItem(`${this.pageType}_sidebarPinned`, true);
|
||||||
const expandedPaths = getStorageItem(`${this.pageType}_expandedNodes`, []);
|
const expandedPaths = getStorageItem(`${this.pageType}_expandedNodes`, []);
|
||||||
const displayMode = getStorageItem(`${this.pageType}_displayMode`, 'tree'); // 'tree' or 'list', default to 'tree'
|
const displayMode = getStorageItem(`${this.pageType}_displayMode`, 'tree'); // 'tree' or 'list', default to 'tree'
|
||||||
|
const recursiveSearchEnabled = getStorageItem(`${this.pageType}_recursiveSearch`, true);
|
||||||
|
|
||||||
this.isPinned = isPinned;
|
this.isPinned = isPinned;
|
||||||
this.expandedNodes = new Set(expandedPaths);
|
this.expandedNodes = new Set(expandedPaths);
|
||||||
this.displayMode = displayMode;
|
this.displayMode = displayMode;
|
||||||
|
this.recursiveSearchEnabled = recursiveSearchEnabled;
|
||||||
|
|
||||||
this.updatePinButton();
|
this.updatePinButton();
|
||||||
this.updateDisplayModeButton();
|
this.updateDisplayModeButton();
|
||||||
this.updateCollapseAllButton();
|
this.updateCollapseAllButton();
|
||||||
this.updateSearchRecursiveOption();
|
this.updateSearchRecursiveOption();
|
||||||
|
this.updateRecursiveToggleButton();
|
||||||
}
|
}
|
||||||
|
|
||||||
restoreSelectedFolder() {
|
restoreSelectedFolder() {
|
||||||
@@ -974,4 +1039,4 @@ export class SidebarManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create and export global instance
|
// Create and export global instance
|
||||||
export const sidebarManager = new SidebarManager();
|
export const sidebarManager = new SidebarManager();
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ export const state = {
|
|||||||
modelname: true,
|
modelname: true,
|
||||||
tags: false,
|
tags: false,
|
||||||
creator: false,
|
creator: false,
|
||||||
recursive: true,
|
recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true),
|
||||||
},
|
},
|
||||||
filters: {
|
filters: {
|
||||||
baseModel: [],
|
baseModel: [],
|
||||||
@@ -116,7 +116,7 @@ export const state = {
|
|||||||
filename: true,
|
filename: true,
|
||||||
modelname: true,
|
modelname: true,
|
||||||
creator: false,
|
creator: false,
|
||||||
recursive: true,
|
recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true),
|
||||||
},
|
},
|
||||||
filters: {
|
filters: {
|
||||||
baseModel: [],
|
baseModel: [],
|
||||||
@@ -144,7 +144,7 @@ export const state = {
|
|||||||
modelname: true,
|
modelname: true,
|
||||||
tags: false,
|
tags: false,
|
||||||
creator: false,
|
creator: false,
|
||||||
recursive: true,
|
recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true),
|
||||||
},
|
},
|
||||||
filters: {
|
filters: {
|
||||||
baseModel: [],
|
baseModel: [],
|
||||||
@@ -261,4 +261,4 @@ export function initPageState(pageType) {
|
|||||||
return getCurrentPageState();
|
return getCurrentPageState();
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -435,8 +435,9 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
|
|||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
// Single node - send directly
|
// Single node - send directly
|
||||||
const nodeId = Object.keys(registryData.data.nodes)[0];
|
const nodes = registryData.data.nodes;
|
||||||
return await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType);
|
const nodeId = Object.keys(nodes)[0];
|
||||||
|
return await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to get registry:', error);
|
console.error('Failed to get registry:', error);
|
||||||
@@ -452,19 +453,65 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
|
|||||||
* @param {boolean} replaceMode - Whether to replace existing LoRAs
|
* @param {boolean} replaceMode - Whether to replace existing LoRAs
|
||||||
* @param {string} syntaxType - The type of syntax ('lora' or 'recipe')
|
* @param {string} syntaxType - The type of syntax ('lora' or 'recipe')
|
||||||
*/
|
*/
|
||||||
async function sendToSpecificNode(nodeIds, loraSyntax, replaceMode, syntaxType) {
|
function resolveNodeReference(nodeKey, nodesMap) {
|
||||||
|
if (!nodeKey) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const directMatch = nodesMap?.[nodeKey];
|
||||||
|
if (directMatch) {
|
||||||
|
return {
|
||||||
|
node_id: directMatch.id,
|
||||||
|
graph_id: directMatch.graph_id ?? null,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof nodeKey === 'string' && nodeKey.includes(':')) {
|
||||||
|
const [graphId, ...rest] = nodeKey.split(':');
|
||||||
|
const nodeIdPart = rest.join(':');
|
||||||
|
const numericNodeId = Number(nodeIdPart);
|
||||||
|
return {
|
||||||
|
node_id: Number.isNaN(numericNodeId) ? nodeIdPart : numericNodeId,
|
||||||
|
graph_id: graphId || null,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const numericId = Number(nodeKey);
|
||||||
|
return {
|
||||||
|
node_id: Number.isNaN(numericId) ? nodeKey : numericId,
|
||||||
|
graph_id: null,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async function sendToSpecificNode(nodeIds, nodesMap, loraSyntax, replaceMode, syntaxType) {
|
||||||
try {
|
try {
|
||||||
// Call the backend API to update the lora code
|
// Call the backend API to update the lora code
|
||||||
|
const requestBody = {
|
||||||
|
lora_code: loraSyntax,
|
||||||
|
mode: replaceMode ? 'replace' : 'append'
|
||||||
|
};
|
||||||
|
|
||||||
|
if (Array.isArray(nodeIds)) {
|
||||||
|
const references = nodeIds
|
||||||
|
.map((nodeKey) => resolveNodeReference(nodeKey, nodesMap))
|
||||||
|
.filter((reference) => reference && reference.node_id !== undefined);
|
||||||
|
|
||||||
|
if (references.length > 0) {
|
||||||
|
requestBody.node_ids = references;
|
||||||
|
}
|
||||||
|
} else if (nodeIds) {
|
||||||
|
const reference = resolveNodeReference(nodeIds, nodesMap);
|
||||||
|
if (reference) {
|
||||||
|
requestBody.node_ids = [reference];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const response = await fetch('/api/lm/update-lora-code', {
|
const response = await fetch('/api/lm/update-lora-code', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify(requestBody)
|
||||||
node_ids: nodeIds,
|
|
||||||
lora_code: loraSyntax,
|
|
||||||
mode: replaceMode ? 'replace' : 'append'
|
|
||||||
})
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await response.json();
|
const result = await response.json();
|
||||||
@@ -522,16 +569,17 @@ function showNodeSelector(nodes, loraSyntax, replaceMode, syntaxType) {
|
|||||||
hideNodeSelector();
|
hideNodeSelector();
|
||||||
|
|
||||||
// Generate node list HTML with icons and proper colors
|
// Generate node list HTML with icons and proper colors
|
||||||
const nodeItems = Object.values(nodes).map(node => {
|
const nodeItems = Object.entries(nodes).map(([nodeKey, node]) => {
|
||||||
const iconClass = NODE_TYPE_ICONS[node.type] || 'fas fa-question-circle';
|
const iconClass = NODE_TYPE_ICONS[node.type] || 'fas fa-question-circle';
|
||||||
const bgColor = node.bgcolor || DEFAULT_NODE_COLOR;
|
const bgColor = node.bgcolor || DEFAULT_NODE_COLOR;
|
||||||
|
const graphLabel = node.graph_name ? ` (${node.graph_name})` : '';
|
||||||
|
|
||||||
return `
|
return `
|
||||||
<div class="node-item" data-node-id="${node.id}">
|
<div class="node-item" data-node-id="${nodeKey}">
|
||||||
<div class="node-icon-indicator" style="background-color: ${bgColor}">
|
<div class="node-icon-indicator" style="background-color: ${bgColor}">
|
||||||
<i class="${iconClass}"></i>
|
<i class="${iconClass}"></i>
|
||||||
</div>
|
</div>
|
||||||
<span>#${node.id} ${node.title}</span>
|
<span>#${node.id}${graphLabel} ${node.title}</span>
|
||||||
</div>
|
</div>
|
||||||
`;
|
`;
|
||||||
}).join('');
|
}).join('');
|
||||||
@@ -610,10 +658,10 @@ function setupNodeSelectorEvents(selector, nodes, loraSyntax, replaceMode, synta
|
|||||||
if (action === 'send-all') {
|
if (action === 'send-all') {
|
||||||
// Send to all nodes
|
// Send to all nodes
|
||||||
const allNodeIds = Object.keys(nodes);
|
const allNodeIds = Object.keys(nodes);
|
||||||
await sendToSpecificNode(allNodeIds, loraSyntax, replaceMode, syntaxType);
|
await sendToSpecificNode(allNodeIds, nodes, loraSyntax, replaceMode, syntaxType);
|
||||||
} else if (nodeId) {
|
} else if (nodeId) {
|
||||||
// Send to specific node
|
// Send to specific node
|
||||||
await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType);
|
await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType);
|
||||||
}
|
}
|
||||||
|
|
||||||
hideNodeSelector();
|
hideNodeSelector();
|
||||||
|
|||||||
@@ -9,6 +9,9 @@
|
|||||||
<button class="sidebar-action-btn" id="sidebarDisplayModeToggle" title="{{ t('sidebar.switchToListView') }}">
|
<button class="sidebar-action-btn" id="sidebarDisplayModeToggle" title="{{ t('sidebar.switchToListView') }}">
|
||||||
<i class="fas fa-sitemap"></i>
|
<i class="fas fa-sitemap"></i>
|
||||||
</button>
|
</button>
|
||||||
|
<button class="sidebar-action-btn active" id="sidebarRecursiveToggle" title="{{ t('sidebar.recursiveOn') }}" aria-pressed="true">
|
||||||
|
<i class="fas fa-code-branch"></i>
|
||||||
|
</button>
|
||||||
<button class="sidebar-action-btn" id="sidebarCollapseAll" title="{{ t('sidebar.collapseAll') }}">
|
<button class="sidebar-action-btn" id="sidebarCollapseAll" title="{{ t('sidebar.collapseAll') }}">
|
||||||
<i class="fas fa-compress-alt"></i>
|
<i class="fas fa-compress-alt"></i>
|
||||||
</button>
|
</button>
|
||||||
|
|||||||
@@ -73,6 +73,30 @@ nodes_mock.NODE_CLASS_MAPPINGS = {}
|
|||||||
sys.modules['nodes'] = nodes_mock
|
sys.modules['nodes'] = nodes_mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _isolate_settings_dir(tmp_path_factory, monkeypatch):
|
||||||
|
"""Redirect settings.json into a temporary directory for each test."""
|
||||||
|
|
||||||
|
settings_dir = tmp_path_factory.mktemp("settings_dir")
|
||||||
|
|
||||||
|
def fake_get_settings_dir(create: bool = True) -> str:
|
||||||
|
if create:
|
||||||
|
settings_dir.mkdir(exist_ok=True)
|
||||||
|
return str(settings_dir)
|
||||||
|
|
||||||
|
monkeypatch.setattr("py.utils.settings_paths.get_settings_dir", fake_get_settings_dir)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"py.utils.settings_paths.user_config_dir",
|
||||||
|
lambda *_args, **_kwargs: str(settings_dir),
|
||||||
|
)
|
||||||
|
|
||||||
|
from py.services import settings_manager as settings_manager_module
|
||||||
|
|
||||||
|
settings_manager_module.reset_settings_manager()
|
||||||
|
yield
|
||||||
|
settings_manager_module.reset_settings_manager()
|
||||||
|
|
||||||
|
|
||||||
def pytest_pyfunc_call(pyfuncitem):
|
def pytest_pyfunc_call(pyfuncitem):
|
||||||
"""Allow bare async tests to run without pytest.mark.asyncio."""
|
"""Allow bare async tests to run without pytest.mark.asyncio."""
|
||||||
test_function = pyfuncitem.function
|
test_function = pyfuncitem.function
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ vi.mock(EVENT_MANAGER_MODULE, () => ({
|
|||||||
off: vi.fn(),
|
off: vi.fn(),
|
||||||
addHandler: vi.fn(),
|
addHandler: vi.fn(),
|
||||||
removeHandler: vi.fn(),
|
removeHandler: vi.fn(),
|
||||||
|
setState: vi.fn(),
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
@@ -62,6 +63,7 @@ describe('UI helper DOM utilities', () => {
|
|||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
vi.useRealTimers();
|
vi.useRealTimers();
|
||||||
|
delete global.fetch;
|
||||||
});
|
});
|
||||||
|
|
||||||
it('creates toast elements and cleans them up after timeout', async () => {
|
it('creates toast elements and cleans them up after timeout', async () => {
|
||||||
@@ -105,4 +107,53 @@ describe('UI helper DOM utilities', () => {
|
|||||||
expect(document.body.dataset.theme).toBe('dark');
|
expect(document.body.dataset.theme).toBe('dark');
|
||||||
expect(document.querySelector('.theme-toggle').classList.contains('theme-dark')).toBe(true);
|
expect(document.querySelector('.theme-toggle').classList.contains('theme-dark')).toBe(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('renders subgraph names in the node selector list', async () => {
|
||||||
|
const registryResponse = {
|
||||||
|
success: true,
|
||||||
|
data: {
|
||||||
|
node_count: 2,
|
||||||
|
nodes: {
|
||||||
|
'root:1': {
|
||||||
|
id: 1,
|
||||||
|
graph_id: 'root',
|
||||||
|
graph_name: null,
|
||||||
|
title: 'Root Loader',
|
||||||
|
type: 1,
|
||||||
|
bgcolor: '#123456',
|
||||||
|
},
|
||||||
|
'subgraph-uuid:2': {
|
||||||
|
id: 2,
|
||||||
|
graph_id: 'subgraph-uuid',
|
||||||
|
graph_name: 'Character Subgraph',
|
||||||
|
title: 'Nested Loader',
|
||||||
|
type: 1,
|
||||||
|
bgcolor: '#654321',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
global.fetch = vi.fn().mockResolvedValue({
|
||||||
|
json: async () => registryResponse,
|
||||||
|
});
|
||||||
|
|
||||||
|
document.body.innerHTML = '<div id="nodeSelector"></div>';
|
||||||
|
|
||||||
|
const { sendLoraToWorkflow } = await import(UI_HELPERS_MODULE);
|
||||||
|
|
||||||
|
const result = await sendLoraToWorkflow('<lora:test:1>');
|
||||||
|
|
||||||
|
expect(result).toBe(true);
|
||||||
|
expect(global.fetch).toHaveBeenCalledWith('/api/lm/get-registry');
|
||||||
|
|
||||||
|
const nodeLabels = Array.from(
|
||||||
|
document.querySelectorAll('#nodeSelector .node-item[data-node-id] span')
|
||||||
|
).map((span) => span.textContent.trim());
|
||||||
|
|
||||||
|
expect(nodeLabels).toEqual([
|
||||||
|
'#1 Root Loader',
|
||||||
|
'#2 (Character Subgraph) Nested Loader',
|
||||||
|
]);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -16,10 +16,12 @@ from aiohttp.test_utils import TestClient, TestServer
|
|||||||
from py.config import config
|
from py.config import config
|
||||||
from py.routes.base_model_routes import BaseModelRoutes
|
from py.routes.base_model_routes import BaseModelRoutes
|
||||||
from py.services import model_file_service
|
from py.services import model_file_service
|
||||||
|
from py.services.metadata_sync_service import MetadataSyncService
|
||||||
from py.services.model_file_service import AutoOrganizeResult
|
from py.services.model_file_service import AutoOrganizeResult
|
||||||
from py.services.service_registry import ServiceRegistry
|
from py.services.service_registry import ServiceRegistry
|
||||||
from py.services.websocket_manager import ws_manager
|
from py.services.websocket_manager import ws_manager
|
||||||
from py.utils.exif_utils import ExifUtils
|
from py.utils.exif_utils import ExifUtils
|
||||||
|
from py.utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
|
|
||||||
class DummyRoutes(BaseModelRoutes):
|
class DummyRoutes(BaseModelRoutes):
|
||||||
@@ -197,6 +199,116 @@ def test_replace_preview_writes_file_and_updates_cache(
|
|||||||
asyncio.run(scenario())
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_civitai_hydrates_metadata_before_sync(
|
||||||
|
mock_service,
|
||||||
|
mock_scanner,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
model_path = tmp_path / "hydrate.safetensors"
|
||||||
|
model_path.write_bytes(b"model")
|
||||||
|
metadata_path = tmp_path / "hydrate.metadata.json"
|
||||||
|
|
||||||
|
existing_metadata = {
|
||||||
|
"file_path": str(model_path),
|
||||||
|
"sha256": "abc123",
|
||||||
|
"model_name": "Hydrated",
|
||||||
|
"preview_url": "keep/me.png",
|
||||||
|
"civitai": {
|
||||||
|
"id": 99,
|
||||||
|
"modelId": 42,
|
||||||
|
"images": [{"url": "https://example.com/existing.png", "type": "image"}],
|
||||||
|
"customImages": [{"id": "old-id", "url": "", "type": "image"}],
|
||||||
|
"trainedWords": ["keep"],
|
||||||
|
},
|
||||||
|
"custom_field": "preserve",
|
||||||
|
}
|
||||||
|
metadata_path.write_text(json.dumps(existing_metadata), encoding="utf-8")
|
||||||
|
|
||||||
|
minimal_cache_entry = {
|
||||||
|
"file_path": str(model_path),
|
||||||
|
"sha256": "abc123",
|
||||||
|
"folder": "some/folder",
|
||||||
|
"civitai": {"id": 99, "modelId": 42},
|
||||||
|
}
|
||||||
|
mock_scanner._cache.raw_data = [minimal_cache_entry]
|
||||||
|
|
||||||
|
class FakeMetadata:
|
||||||
|
def __init__(self, payload: dict) -> None:
|
||||||
|
self._payload = payload
|
||||||
|
self._unknown_fields = {"legacy_field": "legacy"}
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return self._payload.copy()
|
||||||
|
|
||||||
|
async def fake_load_metadata(path: str, *_args, **_kwargs):
|
||||||
|
assert path == str(model_path)
|
||||||
|
return FakeMetadata(existing_metadata), False
|
||||||
|
|
||||||
|
async def fake_save_metadata(path: str, metadata: dict) -> bool:
|
||||||
|
save_calls.append((path, json.loads(json.dumps(metadata))))
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def fake_fetch_and_update_model(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sha256: str,
|
||||||
|
file_path: str,
|
||||||
|
model_data: dict,
|
||||||
|
update_cache_func,
|
||||||
|
):
|
||||||
|
captured["model_data"] = json.loads(json.dumps(model_data))
|
||||||
|
to_save = model_data.copy()
|
||||||
|
to_save.pop("folder", None)
|
||||||
|
await MetadataManager.save_metadata(
|
||||||
|
os.path.splitext(file_path)[0] + ".metadata.json",
|
||||||
|
to_save,
|
||||||
|
)
|
||||||
|
await update_cache_func(file_path, file_path, model_data)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
save_calls: list[tuple[str, dict]] = []
|
||||||
|
captured: dict[str, dict] = {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load_metadata))
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save_metadata))
|
||||||
|
monkeypatch.setattr(MetadataSyncService, "fetch_and_update_model", fake_fetch_and_update_model)
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
client = await create_test_client(mock_service)
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/lm/test-models/fetch-civitai",
|
||||||
|
json={"file_path": str(model_path)},
|
||||||
|
)
|
||||||
|
payload = await response.json()
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert payload["success"] is True
|
||||||
|
assert captured["model_data"]["custom_field"] == "preserve"
|
||||||
|
assert captured["model_data"]["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
|
||||||
|
assert captured["model_data"]["civitai"]["trainedWords"] == ["keep"]
|
||||||
|
assert captured["model_data"]["civitai"]["id"] == 99
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
assert save_calls, "Metadata save should be invoked"
|
||||||
|
saved_path, saved_payload = save_calls[0]
|
||||||
|
assert saved_path == str(metadata_path)
|
||||||
|
assert saved_payload["custom_field"] == "preserve"
|
||||||
|
assert saved_payload["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
|
||||||
|
assert saved_payload["civitai"]["trainedWords"] == ["keep"]
|
||||||
|
assert saved_payload["civitai"]["id"] == 99
|
||||||
|
assert saved_payload["legacy_field"] == "legacy"
|
||||||
|
|
||||||
|
assert mock_scanner.updated_models
|
||||||
|
updated_metadata = mock_scanner.updated_models[-1]["metadata"]
|
||||||
|
assert updated_metadata["custom_field"] == "preserve"
|
||||||
|
assert updated_metadata["civitai"]["customImages"][0]["id"] == "old-id"
|
||||||
|
|
||||||
|
|
||||||
def test_download_model_invokes_download_manager(
|
def test_download_model_invokes_download_manager(
|
||||||
mock_service,
|
mock_service,
|
||||||
download_manager_stub,
|
download_manager_stub,
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
|
|||||||
|
|
||||||
monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"]))
|
monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"]))
|
||||||
|
|
||||||
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": ["node"]})
|
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": [{"node_id": "node", "graph_id": "graph-1"}]})
|
||||||
|
|
||||||
response = await routes.get_trigger_words(request)
|
response = await routes.get_trigger_words(request)
|
||||||
payload = json.loads(response.text)
|
payload = json.loads(response.text)
|
||||||
@@ -196,7 +196,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
|
|||||||
assert payload == {"success": True}
|
assert payload == {"success": True}
|
||||||
send_mock.assert_called_once_with(
|
send_mock.assert_called_once_with(
|
||||||
"trigger_word_update",
|
"trigger_word_update",
|
||||||
{"id": "node", "message": "trigger-one"},
|
{"id": "node", "graph_id": "graph-1", "message": "trigger-one"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,14 @@ from types import SimpleNamespace
|
|||||||
import pytest
|
import pytest
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from py.routes.handlers.misc_handlers import SettingsHandler, ServiceRegistryAdapter
|
from py.routes.handlers.misc_handlers import (
|
||||||
|
LoraCodeHandler,
|
||||||
|
ModelLibraryHandler,
|
||||||
|
NodeRegistry,
|
||||||
|
NodeRegistryHandler,
|
||||||
|
ServiceRegistryAdapter,
|
||||||
|
SettingsHandler,
|
||||||
|
)
|
||||||
from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar
|
from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar
|
||||||
from py.routes.misc_routes import MiscRoutes
|
from py.routes.misc_routes import MiscRoutes
|
||||||
|
|
||||||
@@ -126,6 +133,128 @@ class FakePromptServer:
|
|||||||
instance = Instance()
|
instance = Instance()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_nodes_requires_graph_id():
|
||||||
|
node_registry = NodeRegistry()
|
||||||
|
handler = NodeRegistryHandler(
|
||||||
|
node_registry=node_registry,
|
||||||
|
prompt_server=FakePromptServer,
|
||||||
|
standalone_mode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
request = FakeRequest(json_data={"nodes": [{"node_id": 1}]})
|
||||||
|
response = await handler.register_nodes(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert payload["success"] is False
|
||||||
|
assert "graph_id" in payload["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_nodes_stores_graph_identifier():
|
||||||
|
node_registry = NodeRegistry()
|
||||||
|
handler = NodeRegistryHandler(
|
||||||
|
node_registry=node_registry,
|
||||||
|
prompt_server=FakePromptServer,
|
||||||
|
standalone_mode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
request = FakeRequest(
|
||||||
|
json_data={
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"node_id": 7,
|
||||||
|
"graph_id": "graph-123",
|
||||||
|
"graph_name": "Character Subgraph",
|
||||||
|
"type": "Lora Loader (LoraManager)",
|
||||||
|
"title": "Loader",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.register_nodes(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload["success"] is True
|
||||||
|
|
||||||
|
registry = await node_registry.get_registry()
|
||||||
|
assert registry["node_count"] == 1
|
||||||
|
stored_node = next(iter(registry["nodes"].values()))
|
||||||
|
assert stored_node["graph_id"] == "graph-123"
|
||||||
|
assert stored_node["unique_id"] == "graph-123:7"
|
||||||
|
assert stored_node["graph_name"] == "Character Subgraph"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_nodes_defaults_graph_name_to_none():
|
||||||
|
node_registry = NodeRegistry()
|
||||||
|
handler = NodeRegistryHandler(
|
||||||
|
node_registry=node_registry,
|
||||||
|
prompt_server=FakePromptServer,
|
||||||
|
standalone_mode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
request = FakeRequest(
|
||||||
|
json_data={
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"node_id": 8,
|
||||||
|
"graph_id": "root",
|
||||||
|
"type": "Lora Loader (LoraManager)",
|
||||||
|
"title": "Root Loader",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.register_nodes(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload["success"] is True
|
||||||
|
|
||||||
|
registry = await node_registry.get_registry()
|
||||||
|
stored_node = next(iter(registry["nodes"].values()))
|
||||||
|
assert stored_node["graph_name"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_lora_code_includes_graph_identifier():
|
||||||
|
send_calls: list[tuple[str, dict]] = []
|
||||||
|
|
||||||
|
class RecordingPromptServer:
|
||||||
|
class Instance:
|
||||||
|
def send_sync(self, event, payload):
|
||||||
|
send_calls.append((event, payload))
|
||||||
|
|
||||||
|
instance = Instance()
|
||||||
|
|
||||||
|
handler = LoraCodeHandler(RecordingPromptServer)
|
||||||
|
|
||||||
|
request = FakeRequest(
|
||||||
|
json_data={
|
||||||
|
"node_ids": [{"node_id": 3, "graph_id": "graph-A"}],
|
||||||
|
"lora_code": "<lora>",
|
||||||
|
"mode": "replace",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.update_lora_code(request)
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload["success"] is True
|
||||||
|
assert payload["results"] == [
|
||||||
|
{"node_id": 3, "graph_id": "graph-A", "success": True}
|
||||||
|
]
|
||||||
|
assert send_calls == [
|
||||||
|
(
|
||||||
|
"lora_code_update",
|
||||||
|
{"id": 3, "graph_id": "graph-A", "lora_code": "<lora>", "mode": "replace"},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class FakeScanner:
|
class FakeScanner:
|
||||||
async def check_model_version_exists(self, _version_id):
|
async def check_model_version_exists(self, _version_id):
|
||||||
return False
|
return False
|
||||||
@@ -138,10 +267,34 @@ async def fake_scanner_factory():
|
|||||||
return FakeScanner()
|
return FakeScanner()
|
||||||
|
|
||||||
|
|
||||||
|
class FakeExistenceScanner:
|
||||||
|
def __init__(self, existing=None):
|
||||||
|
self._existing = set(existing or [])
|
||||||
|
|
||||||
|
async def check_model_version_exists(self, version_id):
|
||||||
|
return version_id in self._existing
|
||||||
|
|
||||||
|
async def get_model_versions_by_id(self, _model_id):
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class FakeMetadataProvider:
|
class FakeMetadataProvider:
|
||||||
async def get_model_versions(self, _model_id):
|
async def get_model_versions(self, _model_id):
|
||||||
return {"modelVersions": [], "name": "", "type": "lora"}
|
return {"modelVersions": [], "name": "", "type": "lora"}
|
||||||
|
|
||||||
|
async def get_user_models(self, _username):
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class FakeUserModelsProvider(FakeMetadataProvider):
|
||||||
|
def __init__(self, models):
|
||||||
|
self.models = models
|
||||||
|
self.received_usernames: list[str] = []
|
||||||
|
|
||||||
|
async def get_user_models(self, username):
|
||||||
|
self.received_usernames.append(username)
|
||||||
|
return self.models
|
||||||
|
|
||||||
|
|
||||||
async def fake_metadata_provider_factory():
|
async def fake_metadata_provider_factory():
|
||||||
return FakeMetadataProvider()
|
return FakeMetadataProvider()
|
||||||
@@ -211,6 +364,250 @@ async def test_misc_routes_bind_produces_expected_handlers():
|
|||||||
assert set(mapping.keys()) == expected_names
|
assert set(mapping.keys()) == expected_names
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_civitai_user_models_marks_library_versions():
|
||||||
|
models = [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"name": "Model A",
|
||||||
|
"type": "LORA",
|
||||||
|
"tags": ["style"],
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 100,
|
||||||
|
"name": "v1",
|
||||||
|
"baseModel": "Flux.1",
|
||||||
|
"images": [{"url": "http://example.com/a1.jpg"}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 101,
|
||||||
|
"name": "v2",
|
||||||
|
"baseModel": "Flux.1",
|
||||||
|
"images": [{"url": "http://example.com/a2.jpg"}],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"name": "Embedding",
|
||||||
|
"type": "TextualInversion",
|
||||||
|
"tags": ["embedding"],
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 200,
|
||||||
|
"name": "v1",
|
||||||
|
"baseModel": None,
|
||||||
|
"images": [{"url": "http://example.com/e1.jpg"}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 202,
|
||||||
|
"name": "v2",
|
||||||
|
"baseModel": None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"name": "Checkpoint",
|
||||||
|
"type": "Checkpoint",
|
||||||
|
"tags": ["checkpoint"],
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 300,
|
||||||
|
"name": "v1",
|
||||||
|
"baseModel": "SDXL",
|
||||||
|
"images": [],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4,
|
||||||
|
"name": "Unsupported",
|
||||||
|
"type": "Other",
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 400,
|
||||||
|
"name": "v1",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
provider = FakeUserModelsProvider(models)
|
||||||
|
|
||||||
|
async def provider_factory():
|
||||||
|
return provider
|
||||||
|
|
||||||
|
lora_scanner = FakeExistenceScanner({101})
|
||||||
|
checkpoint_scanner = FakeExistenceScanner()
|
||||||
|
embedding_scanner = FakeExistenceScanner({202})
|
||||||
|
|
||||||
|
async def lora_factory():
|
||||||
|
return lora_scanner
|
||||||
|
|
||||||
|
async def checkpoint_factory():
|
||||||
|
return checkpoint_scanner
|
||||||
|
|
||||||
|
async def embedding_factory():
|
||||||
|
return embedding_scanner
|
||||||
|
|
||||||
|
handler = ModelLibraryHandler(
|
||||||
|
ServiceRegistryAdapter(
|
||||||
|
get_lora_scanner=lora_factory,
|
||||||
|
get_checkpoint_scanner=checkpoint_factory,
|
||||||
|
get_embedding_scanner=embedding_factory,
|
||||||
|
),
|
||||||
|
metadata_provider_factory=provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload["success"] is True
|
||||||
|
assert payload["username"] == "pixel"
|
||||||
|
assert payload["versions"] == [
|
||||||
|
{
|
||||||
|
"modelId": 1,
|
||||||
|
"versionId": 100,
|
||||||
|
"modelName": "Model A",
|
||||||
|
"versionName": "v1",
|
||||||
|
"type": "LORA",
|
||||||
|
"tags": ["style"],
|
||||||
|
"baseModel": "Flux.1",
|
||||||
|
"thumbnailUrl": "http://example.com/a1.jpg",
|
||||||
|
"inLibrary": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"modelId": 1,
|
||||||
|
"versionId": 101,
|
||||||
|
"modelName": "Model A",
|
||||||
|
"versionName": "v2",
|
||||||
|
"type": "LORA",
|
||||||
|
"tags": ["style"],
|
||||||
|
"baseModel": "Flux.1",
|
||||||
|
"thumbnailUrl": "http://example.com/a2.jpg",
|
||||||
|
"inLibrary": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"modelId": 2,
|
||||||
|
"versionId": 200,
|
||||||
|
"modelName": "Embedding",
|
||||||
|
"versionName": "v1",
|
||||||
|
"type": "TextualInversion",
|
||||||
|
"tags": ["embedding"],
|
||||||
|
"baseModel": None,
|
||||||
|
"thumbnailUrl": "http://example.com/e1.jpg",
|
||||||
|
"inLibrary": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"modelId": 2,
|
||||||
|
"versionId": 202,
|
||||||
|
"modelName": "Embedding",
|
||||||
|
"versionName": "v2",
|
||||||
|
"type": "TextualInversion",
|
||||||
|
"tags": ["embedding"],
|
||||||
|
"baseModel": None,
|
||||||
|
"thumbnailUrl": None,
|
||||||
|
"inLibrary": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"modelId": 3,
|
||||||
|
"versionId": 300,
|
||||||
|
"modelName": "Checkpoint",
|
||||||
|
"versionName": "v1",
|
||||||
|
"type": "Checkpoint",
|
||||||
|
"tags": ["checkpoint"],
|
||||||
|
"baseModel": "SDXL",
|
||||||
|
"thumbnailUrl": None,
|
||||||
|
"inLibrary": False,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
assert provider.received_usernames == ["pixel"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_civitai_user_models_rewrites_civitai_previews():
|
||||||
|
image_url = "https://image.civitai.com/container/example/original=true/sample.jpeg"
|
||||||
|
video_url = "https://image.civitai.com/container/example/original=true/sample.mp4"
|
||||||
|
|
||||||
|
models = [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"name": "Model A",
|
||||||
|
"type": "LORA",
|
||||||
|
"tags": ["style"],
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 100,
|
||||||
|
"name": "preview-image",
|
||||||
|
"baseModel": "Flux.1",
|
||||||
|
"images": [
|
||||||
|
{"url": image_url, "type": "image"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 101,
|
||||||
|
"name": "preview-video",
|
||||||
|
"baseModel": "Flux.1",
|
||||||
|
"images": [
|
||||||
|
{"url": video_url, "type": "video"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
provider = FakeUserModelsProvider(models)
|
||||||
|
|
||||||
|
async def provider_factory():
|
||||||
|
return provider
|
||||||
|
|
||||||
|
handler = ModelLibraryHandler(
|
||||||
|
ServiceRegistryAdapter(
|
||||||
|
get_lora_scanner=fake_scanner_factory,
|
||||||
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
),
|
||||||
|
metadata_provider_factory=provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload["success"] is True
|
||||||
|
previews_by_version = {item["versionId"]: item["thumbnailUrl"] for item in payload["versions"]}
|
||||||
|
assert previews_by_version[100] == "https://image.civitai.com/container/example/width=450,optimized=true/sample.jpeg"
|
||||||
|
assert (
|
||||||
|
previews_by_version[101]
|
||||||
|
== "https://image.civitai.com/container/example/transcode=true,width=450,optimized=true/sample.mp4"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_civitai_user_models_requires_username():
|
||||||
|
provider = FakeUserModelsProvider([])
|
||||||
|
|
||||||
|
async def provider_factory():
|
||||||
|
return provider
|
||||||
|
|
||||||
|
handler = ModelLibraryHandler(
|
||||||
|
ServiceRegistryAdapter(
|
||||||
|
get_lora_scanner=fake_scanner_factory,
|
||||||
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
),
|
||||||
|
metadata_provider_factory=provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.get_civitai_user_models(FakeRequest())
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert response.status == 400
|
||||||
|
assert payload["success"] is False
|
||||||
|
assert "username" in payload["error"].lower()
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_handler_mapping_caches_result():
|
def test_ensure_handler_mapping_caches_result():
|
||||||
call_records = []
|
call_records = []
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import copy
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -169,6 +170,158 @@ async def test_get_model_version_by_version_id(monkeypatch, downloader):
|
|||||||
assert result["images"][0]["meta"]["other"] == "keep"
|
assert result["images"][0]["meta"]["other"] == "keep"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypatch, downloader):
|
||||||
|
requests = []
|
||||||
|
|
||||||
|
model_payload = {
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 7,
|
||||||
|
"files": [
|
||||||
|
{
|
||||||
|
"type": "Model",
|
||||||
|
"primary": True,
|
||||||
|
"hashes": {"SHA256": "hash7"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "desc",
|
||||||
|
"tags": ["tag"],
|
||||||
|
"creator": {"username": "user"},
|
||||||
|
"name": "Model",
|
||||||
|
"type": "LORA",
|
||||||
|
"nsfw": False,
|
||||||
|
"poi": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
version_payload = {
|
||||||
|
"id": 7,
|
||||||
|
"modelId": 99,
|
||||||
|
"model": {},
|
||||||
|
"files": [],
|
||||||
|
"images": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def fake_make_request(method, url, use_auth=True):
|
||||||
|
requests.append(url)
|
||||||
|
if url.endswith("/models/99"):
|
||||||
|
return True, copy.deepcopy(model_payload)
|
||||||
|
if url.endswith("/model-versions/7"):
|
||||||
|
return True, copy.deepcopy(version_payload)
|
||||||
|
return False, "unexpected"
|
||||||
|
|
||||||
|
downloader.make_request = fake_make_request
|
||||||
|
|
||||||
|
client = await CivitaiClient.get_instance()
|
||||||
|
|
||||||
|
result = await client.get_model_version(model_id=99, version_id=7)
|
||||||
|
|
||||||
|
assert result["id"] == 7
|
||||||
|
assert result["model"]["description"] == "desc"
|
||||||
|
assert result["model"]["tags"] == ["tag"]
|
||||||
|
assert result["creator"] == {"username": "user"}
|
||||||
|
assert requests[0].endswith("/models/99")
|
||||||
|
assert requests[1].endswith("/model-versions/7")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, downloader):
|
||||||
|
requests = []
|
||||||
|
|
||||||
|
model_payload = {
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 7,
|
||||||
|
"files": [
|
||||||
|
{
|
||||||
|
"type": "Model",
|
||||||
|
"primary": True,
|
||||||
|
"hashes": {"SHA256": "hash7"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "desc",
|
||||||
|
"tags": ["tag"],
|
||||||
|
"creator": {"username": "user"},
|
||||||
|
"name": "Model",
|
||||||
|
"type": "LORA",
|
||||||
|
"nsfw": False,
|
||||||
|
"poi": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
version_payload = {
|
||||||
|
"id": 7,
|
||||||
|
"modelId": 99,
|
||||||
|
"files": [],
|
||||||
|
"images": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def fake_make_request(method, url, use_auth=True):
|
||||||
|
requests.append(url)
|
||||||
|
if url.endswith("/models/99"):
|
||||||
|
return True, copy.deepcopy(model_payload)
|
||||||
|
if url.endswith("/model-versions/7"):
|
||||||
|
return False, "boom"
|
||||||
|
if url.endswith("/model-versions/by-hash/hash7"):
|
||||||
|
return True, copy.deepcopy(version_payload)
|
||||||
|
return False, "unexpected"
|
||||||
|
|
||||||
|
downloader.make_request = fake_make_request
|
||||||
|
|
||||||
|
client = await CivitaiClient.get_instance()
|
||||||
|
|
||||||
|
result = await client.get_model_version(model_id=99, version_id=7)
|
||||||
|
|
||||||
|
assert result["id"] == 7
|
||||||
|
assert result["model"]["description"] == "desc"
|
||||||
|
assert result["model"]["tags"] == ["tag"]
|
||||||
|
assert result["creator"] == {"username": "user"}
|
||||||
|
assert requests[1].endswith("/model-versions/7")
|
||||||
|
assert requests[2].endswith("/model-versions/by-hash/hash7")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatch, downloader):
|
||||||
|
model_payload = {
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 7,
|
||||||
|
"files": [],
|
||||||
|
"name": "v1",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "desc",
|
||||||
|
"tags": ["tag"],
|
||||||
|
"creator": {"username": "user"},
|
||||||
|
"name": "Model",
|
||||||
|
"type": "LORA",
|
||||||
|
"nsfw": False,
|
||||||
|
"poi": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def fake_make_request(method, url, use_auth=True):
|
||||||
|
if url.endswith("/models/99"):
|
||||||
|
return True, copy.deepcopy(model_payload)
|
||||||
|
if url.endswith("/model-versions/7"):
|
||||||
|
return False, "boom"
|
||||||
|
if "/model-versions/by-hash/" in url:
|
||||||
|
return False, "boom"
|
||||||
|
return False, "unexpected"
|
||||||
|
|
||||||
|
downloader.make_request = fake_make_request
|
||||||
|
|
||||||
|
client = await CivitaiClient.get_instance()
|
||||||
|
|
||||||
|
result = await client.get_model_version(model_id=99, version_id=7)
|
||||||
|
|
||||||
|
assert result["modelId"] == 99
|
||||||
|
assert result["model"]["name"] == "Model"
|
||||||
|
assert result["model"]["type"] == "LORA"
|
||||||
|
assert result["model"]["description"] == "desc"
|
||||||
|
assert result["model"]["tags"] == ["tag"]
|
||||||
|
assert result["creator"] == {"username": "user"}
|
||||||
|
|
||||||
|
|
||||||
async def test_get_model_version_requires_identifier(monkeypatch, downloader):
|
async def test_get_model_version_requires_identifier(monkeypatch, downloader):
|
||||||
client = await CivitaiClient.get_instance()
|
client = await CivitaiClient.get_instance()
|
||||||
result = await client.get_model_version()
|
result = await client.get_model_version()
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import pytest
|
|||||||
from py.services.download_manager import DownloadManager
|
from py.services.download_manager import DownloadManager
|
||||||
from py.services import download_manager
|
from py.services import download_manager
|
||||||
from py.services.service_registry import ServiceRegistry
|
from py.services.service_registry import ServiceRegistry
|
||||||
from py.services.settings_manager import settings
|
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||||
from py.utils.metadata_manager import MetadataManager
|
from py.utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
|
|
||||||
@@ -23,7 +23,8 @@ def reset_download_manager():
|
|||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def isolate_settings(monkeypatch, tmp_path):
|
def isolate_settings(monkeypatch, tmp_path):
|
||||||
"""Point settings writes at a temporary directory to avoid touching real files."""
|
"""Point settings writes at a temporary directory to avoid touching real files."""
|
||||||
default_settings = settings._get_default_settings()
|
manager = get_settings_manager()
|
||||||
|
default_settings = manager._get_default_settings()
|
||||||
default_settings.update(
|
default_settings.update(
|
||||||
{
|
{
|
||||||
"default_lora_root": str(tmp_path),
|
"default_lora_root": str(tmp_path),
|
||||||
@@ -37,8 +38,8 @@ def isolate_settings(monkeypatch, tmp_path):
|
|||||||
"base_model_path_mappings": {"BaseModel": "MappedModel"},
|
"base_model_path_mappings": {"BaseModel": "MappedModel"},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(settings, "settings", default_settings)
|
monkeypatch.setattr(manager, "settings", default_settings)
|
||||||
monkeypatch.setattr(type(settings), "_save_settings", lambda self: None)
|
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -187,7 +188,7 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata
|
|||||||
assert manager._active_downloads[result["download_id"]]["status"] == "completed"
|
assert manager._active_downloads[result["download_id"]]["status"] == "completed"
|
||||||
|
|
||||||
assert captured["relative_path"] == "MappedModel/fantasy"
|
assert captured["relative_path"] == "MappedModel/fantasy"
|
||||||
expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy"
|
expected_dir = Path(get_settings_manager().get("default_lora_root")) / "MappedModel" / "fantasy"
|
||||||
assert captured["save_dir"] == expected_dir
|
assert captured["save_dir"] == expected_dir
|
||||||
assert captured["model_type"] == "lora"
|
assert captured["model_type"] == "lora"
|
||||||
assert captured["download_urls"] == [
|
assert captured["download_urls"] == [
|
||||||
@@ -393,3 +394,98 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
|||||||
assert result == {"success": True}
|
assert result == {"success": True}
|
||||||
assert [url for url, *_ in dummy_downloader.calls] == download_urls
|
assert [url for url, *_ in dummy_downloader.calls] == download_urls
|
||||||
assert dummy_scanner.calls # ensure cache updated
|
assert dummy_scanner.calls # ensure cache updated
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
target_path = save_dir / "file.safetensors"
|
||||||
|
|
||||||
|
manager._active_downloads["dl"] = {}
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
self.preview_nsfw_level = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
return os.path.basename(self.file_path)
|
||||||
|
|
||||||
|
def update_file_info(self, _path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {"file_path": self.file_path}
|
||||||
|
|
||||||
|
metadata = DummyMetadata(target_path)
|
||||||
|
version_info = {
|
||||||
|
"images": [
|
||||||
|
{
|
||||||
|
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
|
||||||
|
"type": "image",
|
||||||
|
"nsfwLevel": 2,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
download_urls = ["https://example.invalid/file.safetensors"]
|
||||||
|
|
||||||
|
class DummyDownloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.file_calls: list[tuple[str, str]] = []
|
||||||
|
self.memory_calls = 0
|
||||||
|
|
||||||
|
async def download_file(self, url, path, progress_callback=None, use_auth=None):
|
||||||
|
self.file_calls.append((url, path))
|
||||||
|
if url.endswith(".jpeg"):
|
||||||
|
Path(path).write_bytes(b"preview")
|
||||||
|
return True, None
|
||||||
|
if url.endswith(".safetensors"):
|
||||||
|
Path(path).write_bytes(b"model")
|
||||||
|
return True, None
|
||||||
|
return False, "unexpected url"
|
||||||
|
|
||||||
|
async def download_to_memory(self, *_args, **_kwargs):
|
||||||
|
self.memory_calls += 1
|
||||||
|
return False, b"", {}
|
||||||
|
|
||||||
|
dummy_downloader = DummyDownloader()
|
||||||
|
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader))
|
||||||
|
|
||||||
|
optimize_called = {"value": False}
|
||||||
|
|
||||||
|
def fake_optimize_image(**_kwargs):
|
||||||
|
optimize_called["value"] = True
|
||||||
|
return b"", {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(fake_optimize_image))
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||||
|
|
||||||
|
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||||
|
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
|
||||||
|
|
||||||
|
result = await manager._execute_download(
|
||||||
|
download_urls=download_urls,
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=metadata,
|
||||||
|
version_info=version_info,
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="lora",
|
||||||
|
download_id="dl",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
preview_urls = [url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg")]
|
||||||
|
assert any("width=450,optimized=true" in url for url in preview_urls)
|
||||||
|
assert dummy_downloader.memory_calls == 0
|
||||||
|
assert optimize_called["value"] is False
|
||||||
|
assert metadata.preview_url.endswith(".jpeg")
|
||||||
|
assert metadata.preview_nsfw_level == 2
|
||||||
|
stored_preview = manager._active_downloads["dl"]["preview_path"]
|
||||||
|
assert stored_preview.endswith(".jpeg")
|
||||||
|
assert Path(stored_preview).exists()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import pytest
|
|||||||
|
|
||||||
from py.services.example_images_cleanup_service import ExampleImagesCleanupService
|
from py.services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||||
from py.services.service_registry import ServiceRegistry
|
from py.services.service_registry import ServiceRegistry
|
||||||
from py.services.settings_manager import settings
|
from py.services.settings_manager import get_settings_manager
|
||||||
|
|
||||||
|
|
||||||
class StubScanner:
|
class StubScanner:
|
||||||
@@ -21,8 +21,9 @@ class StubScanner:
|
|||||||
async def test_cleanup_moves_empty_and_orphaned(tmp_path, monkeypatch):
|
async def test_cleanup_moves_empty_and_orphaned(tmp_path, monkeypatch):
|
||||||
service = ExampleImagesCleanupService()
|
service = ExampleImagesCleanupService()
|
||||||
|
|
||||||
previous_path = settings.get('example_images_path')
|
settings_manager = get_settings_manager()
|
||||||
settings.settings['example_images_path'] = str(tmp_path)
|
previous_path = settings_manager.get('example_images_path')
|
||||||
|
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
empty_folder = tmp_path / 'empty_folder'
|
empty_folder = tmp_path / 'empty_folder'
|
||||||
@@ -64,23 +65,24 @@ async def test_cleanup_moves_empty_and_orphaned(tmp_path, monkeypatch):
|
|||||||
|
|
||||||
finally:
|
finally:
|
||||||
if previous_path is None:
|
if previous_path is None:
|
||||||
settings.settings.pop('example_images_path', None)
|
settings_manager.settings.pop('example_images_path', None)
|
||||||
else:
|
else:
|
||||||
settings.settings['example_images_path'] = previous_path
|
settings_manager.settings['example_images_path'] = previous_path
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cleanup_handles_missing_path(monkeypatch):
|
async def test_cleanup_handles_missing_path(monkeypatch):
|
||||||
service = ExampleImagesCleanupService()
|
service = ExampleImagesCleanupService()
|
||||||
|
|
||||||
previous_path = settings.get('example_images_path')
|
settings_manager = get_settings_manager()
|
||||||
settings.settings.pop('example_images_path', None)
|
previous_path = settings_manager.get('example_images_path')
|
||||||
|
settings_manager.settings.pop('example_images_path', None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await service.cleanup_example_image_folders()
|
result = await service.cleanup_example_image_folders()
|
||||||
finally:
|
finally:
|
||||||
if previous_path is not None:
|
if previous_path is not None:
|
||||||
settings.settings['example_images_path'] = previous_path
|
settings_manager.settings['example_images_path'] = previous_path
|
||||||
|
|
||||||
assert result['success'] is False
|
assert result['success'] is False
|
||||||
assert result['error_code'] == 'path_not_configured'
|
assert result['error_code'] == 'path_not_configured'
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from types import SimpleNamespace
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.settings_manager import settings
|
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||||
from py.utils import example_images_download_manager as download_module
|
from py.utils import example_images_download_manager as download_module
|
||||||
|
|
||||||
|
|
||||||
@@ -43,11 +43,15 @@ def _patch_scanner(monkeypatch: pytest.MonkeyPatch, scanner: StubScanner) -> Non
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("tmp_path")
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
async def test_start_download_rejects_parallel_runs(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
ws_manager = RecordingWebSocketManager()
|
ws_manager = RecordingWebSocketManager()
|
||||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
model = {
|
model = {
|
||||||
"sha256": "abc123",
|
"sha256": "abc123",
|
||||||
@@ -106,11 +110,15 @@ async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPa
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("tmp_path")
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
async def test_pause_resume_blocks_processing(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
ws_manager = RecordingWebSocketManager()
|
ws_manager = RecordingWebSocketManager()
|
||||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
models = [
|
models = [
|
||||||
{
|
{
|
||||||
@@ -231,13 +239,17 @@ async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, t
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("tmp_path")
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
async def test_legacy_folder_migrated_and_skipped(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
ws_manager = RecordingWebSocketManager()
|
ws_manager = RecordingWebSocketManager()
|
||||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
|
monkeypatch.setitem(settings_manager.settings, "libraries", {"default": {}, "extra": {}})
|
||||||
monkeypatch.setitem(settings.settings, "active_library", "extra")
|
monkeypatch.setitem(settings_manager.settings, "active_library", "extra")
|
||||||
|
|
||||||
model_hash = "d" * 64
|
model_hash = "d" * 64
|
||||||
model_path = tmp_path / "model.safetensors"
|
model_path = tmp_path / "model.safetensors"
|
||||||
@@ -310,13 +322,17 @@ async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatc
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("tmp_path")
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
async def test_legacy_progress_file_migrates(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
ws_manager = RecordingWebSocketManager()
|
ws_manager = RecordingWebSocketManager()
|
||||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
|
monkeypatch.setitem(settings_manager.settings, "libraries", {"default": {}, "extra": {}})
|
||||||
monkeypatch.setitem(settings.settings, "active_library", "extra")
|
monkeypatch.setitem(settings_manager.settings, "active_library", "extra")
|
||||||
|
|
||||||
model_hash = "e" * 64
|
model_hash = "e" * 64
|
||||||
model_path = tmp_path / "model-two.safetensors"
|
model_path = tmp_path / "model-two.safetensors"
|
||||||
@@ -380,20 +396,24 @@ async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tm
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("tmp_path")
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
async def test_download_remains_in_initial_library(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
async def test_download_remains_in_initial_library(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
ws_manager = RecordingWebSocketManager()
|
ws_manager = RecordingWebSocketManager()
|
||||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
monkeypatch.setitem(settings.settings, "libraries", {"LibraryA": {}, "LibraryB": {}})
|
monkeypatch.setitem(settings_manager.settings, "libraries", {"LibraryA": {}, "LibraryB": {}})
|
||||||
monkeypatch.setitem(settings.settings, "active_library", "LibraryA")
|
monkeypatch.setitem(settings_manager.settings, "active_library", "LibraryA")
|
||||||
|
|
||||||
state = {"active": "LibraryA"}
|
state = {"active": "LibraryA"}
|
||||||
|
|
||||||
def fake_get_active_library_name(self):
|
def fake_get_active_library_name(self):
|
||||||
return state["active"]
|
return state["active"]
|
||||||
|
|
||||||
monkeypatch.setattr(type(settings), "get_active_library_name", fake_get_active_library_name)
|
monkeypatch.setattr(SettingsManager, "get_active_library_name", fake_get_active_library_name)
|
||||||
|
|
||||||
model_hash = "f" * 64
|
model_hash = "f" * 64
|
||||||
model_path = tmp_path / "example-model.safetensors"
|
model_path = tmp_path / "example-model.safetensors"
|
||||||
@@ -454,3 +474,7 @@ async def test_download_remains_in_initial_library(monkeypatch: pytest.MonkeyPat
|
|||||||
assert (model_dir / "local.txt").exists()
|
assert (model_dir / "local.txt").exists()
|
||||||
assert not (library_b_root / ".download_progress.json").exists()
|
assert not (library_b_root / ".download_progress.json").exists()
|
||||||
assert not (library_b_root / model_hash).exists()
|
assert not (library_b_root / model_hash).exists()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def settings_manager():
|
||||||
|
return get_settings_manager()
|
||||||
|
|||||||
@@ -243,6 +243,7 @@ async def test_initialize_in_background_uses_persisted_cache_without_full_scan(t
|
|||||||
cache = await scanner.get_cached_data()
|
cache = await scanner.get_cached_data()
|
||||||
assert len(cache.raw_data) == 1
|
assert len(cache.raw_data) == 1
|
||||||
assert cache.raw_data[0]['file_path'] == normalized
|
assert cache.raw_data[0]['file_path'] == normalized
|
||||||
|
assert cache.version_index[11]['file_path'] == normalized
|
||||||
|
|
||||||
assert scanner._hash_index.get_path('hash-one') == normalized
|
assert scanner._hash_index.get_path('hash-one') == normalized
|
||||||
|
|
||||||
@@ -301,6 +302,7 @@ async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch)
|
|||||||
assert entry['file_path'] == normalized
|
assert entry['file_path'] == normalized
|
||||||
assert entry['tags'] == ['alpha']
|
assert entry['tags'] == ['alpha']
|
||||||
assert entry['civitai']['trainedWords'] == ['abc']
|
assert entry['civitai']['trainedWords'] == ['abc']
|
||||||
|
assert cache.version_index[11]['file_path'] == normalized
|
||||||
assert scanner._hash_index.get_path('hash-one') == normalized
|
assert scanner._hash_index.get_path('hash-one') == normalized
|
||||||
assert scanner._tags_count == {'alpha': 1}
|
assert scanner._tags_count == {'alpha': 1}
|
||||||
assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
|
assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
|
||||||
@@ -381,6 +383,66 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch):
|
|||||||
assert remaining == 0
|
assert remaining == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_version_index_tracks_version_ids(tmp_path: Path):
|
||||||
|
scanner = DummyScanner(tmp_path)
|
||||||
|
|
||||||
|
first_path = _normalize_path(tmp_path / 'alpha.txt')
|
||||||
|
second_path = _normalize_path(tmp_path / 'beta.txt')
|
||||||
|
|
||||||
|
first_entry = {
|
||||||
|
'file_path': first_path,
|
||||||
|
'file_name': 'alpha',
|
||||||
|
'model_name': 'alpha',
|
||||||
|
'folder': '',
|
||||||
|
'size': 1,
|
||||||
|
'modified': 1.0,
|
||||||
|
'sha256': 'hash-alpha',
|
||||||
|
'tags': [],
|
||||||
|
'civitai': {'id': 101, 'modelId': 1, 'name': 'alpha'},
|
||||||
|
}
|
||||||
|
|
||||||
|
second_entry = {
|
||||||
|
'file_path': second_path,
|
||||||
|
'file_name': 'beta',
|
||||||
|
'model_name': 'beta',
|
||||||
|
'folder': '',
|
||||||
|
'size': 1,
|
||||||
|
'modified': 1.0,
|
||||||
|
'sha256': 'hash-beta',
|
||||||
|
'tags': [],
|
||||||
|
'civitai': {'id': 202, 'modelId': 2, 'name': 'beta'},
|
||||||
|
}
|
||||||
|
|
||||||
|
hash_index = ModelHashIndex()
|
||||||
|
hash_index.add_entry('hash-alpha', first_path)
|
||||||
|
hash_index.add_entry('hash-beta', second_path)
|
||||||
|
|
||||||
|
scan_result = CacheBuildResult(
|
||||||
|
raw_data=[first_entry, second_entry],
|
||||||
|
hash_index=hash_index,
|
||||||
|
tags_count={},
|
||||||
|
excluded_models=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
await scanner._apply_scan_result(scan_result)
|
||||||
|
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
assert cache.version_index[101]['file_path'] == first_path
|
||||||
|
assert cache.version_index[202]['file_path'] == second_path
|
||||||
|
|
||||||
|
assert await scanner.check_model_version_exists(101) is True
|
||||||
|
assert await scanner.check_model_version_exists('202') is True
|
||||||
|
assert await scanner.check_model_version_exists(999) is False
|
||||||
|
|
||||||
|
removed = await scanner._batch_update_cache_for_deleted_models([first_path])
|
||||||
|
assert removed is True
|
||||||
|
|
||||||
|
cache_after = await scanner.get_cached_data()
|
||||||
|
assert 101 not in cache_after.version_index
|
||||||
|
assert await scanner.check_model_version_exists(101) is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path):
|
async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path):
|
||||||
first, _, _ = _create_files(tmp_path)
|
first, _, _ = _create_files(tmp_path)
|
||||||
|
|||||||
182
tests/services/test_preview_asset_service.py
Normal file
182
tests/services/test_preview_asset_service.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.preview_asset_service import PreviewAssetService
|
||||||
|
|
||||||
|
|
||||||
|
class StubMetadataManager:
|
||||||
|
async def save_metadata(self, *_args: Any, **_kwargs: Any) -> bool: # pragma: no cover - helper
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingExifUtils:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.called = False
|
||||||
|
|
||||||
|
def optimize_image(self, **kwargs):
|
||||||
|
self.called = True
|
||||||
|
return kwargs["image_data"], {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ensure_preview_prefers_rewritten_civitai_image(tmp_path):
|
||||||
|
metadata_path = tmp_path / "model.metadata.json"
|
||||||
|
metadata_path.write_text("{}")
|
||||||
|
local_metadata: dict[str, Any] = {}
|
||||||
|
|
||||||
|
class Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.file_calls: list[tuple[str, str]] = []
|
||||||
|
self.memory_calls = 0
|
||||||
|
|
||||||
|
async def download_file(self, url, path, use_auth=False):
|
||||||
|
self.file_calls.append((url, path))
|
||||||
|
if "width=450,optimized=true" in url:
|
||||||
|
Path(path).write_bytes(b"image-data")
|
||||||
|
return True, None
|
||||||
|
return False, "fail"
|
||||||
|
|
||||||
|
async def download_to_memory(self, *_args, **_kwargs):
|
||||||
|
self.memory_calls += 1
|
||||||
|
return False, b"", {}
|
||||||
|
|
||||||
|
downloader = Downloader()
|
||||||
|
|
||||||
|
async def downloader_factory():
|
||||||
|
return downloader
|
||||||
|
|
||||||
|
exif_utils = RecordingExifUtils()
|
||||||
|
service = PreviewAssetService(
|
||||||
|
metadata_manager=StubMetadataManager(),
|
||||||
|
downloader_factory=downloader_factory,
|
||||||
|
exif_utils=exif_utils,
|
||||||
|
)
|
||||||
|
|
||||||
|
images = [
|
||||||
|
{
|
||||||
|
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
|
||||||
|
"type": "image",
|
||||||
|
"nsfwLevel": 3,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||||
|
|
||||||
|
assert downloader.memory_calls == 0
|
||||||
|
assert exif_utils.called is False
|
||||||
|
assert len(downloader.file_calls) == 1
|
||||||
|
assert "width=450,optimized=true" in downloader.file_calls[0][0]
|
||||||
|
preview_path = Path(local_metadata["preview_url"])
|
||||||
|
assert preview_path.exists()
|
||||||
|
assert preview_path.suffix == ".jpeg"
|
||||||
|
assert local_metadata["preview_nsfw_level"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ensure_preview_falls_back_to_webp_when_rewrite_fails(tmp_path):
|
||||||
|
metadata_path = tmp_path / "model.metadata.json"
|
||||||
|
metadata_path.write_text("{}")
|
||||||
|
local_metadata: dict[str, Any] = {}
|
||||||
|
|
||||||
|
class Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.file_calls: list[tuple[str, str]] = []
|
||||||
|
self.memory_calls = 0
|
||||||
|
|
||||||
|
async def download_file(self, url, path, use_auth=False):
|
||||||
|
self.file_calls.append((url, path))
|
||||||
|
return False, "fail"
|
||||||
|
|
||||||
|
async def download_to_memory(self, *_args, **_kwargs):
|
||||||
|
self.memory_calls += 1
|
||||||
|
return True, b"raw-image", {}
|
||||||
|
|
||||||
|
downloader = Downloader()
|
||||||
|
|
||||||
|
async def downloader_factory():
|
||||||
|
return downloader
|
||||||
|
|
||||||
|
class ExifUtils:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
def optimize_image(self, **kwargs):
|
||||||
|
self.calls += 1
|
||||||
|
return b"webp-data", {}
|
||||||
|
|
||||||
|
exif_utils = ExifUtils()
|
||||||
|
|
||||||
|
service = PreviewAssetService(
|
||||||
|
metadata_manager=StubMetadataManager(),
|
||||||
|
downloader_factory=downloader_factory,
|
||||||
|
exif_utils=exif_utils,
|
||||||
|
)
|
||||||
|
|
||||||
|
images = [
|
||||||
|
{
|
||||||
|
"url": "https://image.civitai.com/container/example/original=true/sample.png",
|
||||||
|
"type": "image",
|
||||||
|
"nsfwLevel": 1,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||||
|
|
||||||
|
assert downloader.memory_calls == 1
|
||||||
|
assert exif_utils.calls == 1
|
||||||
|
preview_path = Path(local_metadata["preview_url"])
|
||||||
|
assert preview_path.exists()
|
||||||
|
assert preview_path.suffix == ".webp"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ensure_preview_rewrites_civitai_video(tmp_path):
|
||||||
|
metadata_path = tmp_path / "model.metadata.json"
|
||||||
|
metadata_path.write_text("{}")
|
||||||
|
local_metadata: dict[str, Any] = {}
|
||||||
|
|
||||||
|
class Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.file_calls: list[tuple[str, str]] = []
|
||||||
|
|
||||||
|
async def download_file(self, url, path, use_auth=False):
|
||||||
|
self.file_calls.append((url, path))
|
||||||
|
if "transcode=true,width=450,optimized=true" in url:
|
||||||
|
Path(path).write_bytes(b"video-data")
|
||||||
|
return True, None
|
||||||
|
if url.endswith(".mp4"):
|
||||||
|
return False, "fail"
|
||||||
|
return False, "unexpected"
|
||||||
|
|
||||||
|
async def download_to_memory(self, *_args, **_kwargs):
|
||||||
|
pytest.fail("download_to_memory should not be used for video previews")
|
||||||
|
|
||||||
|
downloader = Downloader()
|
||||||
|
|
||||||
|
async def downloader_factory():
|
||||||
|
return downloader
|
||||||
|
|
||||||
|
service = PreviewAssetService(
|
||||||
|
metadata_manager=StubMetadataManager(),
|
||||||
|
downloader_factory=downloader_factory,
|
||||||
|
exif_utils=RecordingExifUtils(),
|
||||||
|
)
|
||||||
|
|
||||||
|
images = [
|
||||||
|
{
|
||||||
|
"url": "https://image.civitai.com/container/example/original=true/sample.mp4",
|
||||||
|
"type": "video",
|
||||||
|
"nsfwLevel": 2,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||||
|
|
||||||
|
assert len(downloader.file_calls) >= 1
|
||||||
|
assert any("transcode=true,width=450,optimized=true" in url for url, _ in downloader.file_calls)
|
||||||
|
preview_path = Path(local_metadata["preview_url"])
|
||||||
|
assert preview_path.exists()
|
||||||
|
assert preview_path.suffix == ".mp4"
|
||||||
|
assert local_metadata["preview_nsfw_level"] == 2
|
||||||
@@ -28,6 +28,7 @@ from py.utils.example_images_processor import (
|
|||||||
ExampleImagesImportError,
|
ExampleImagesImportError,
|
||||||
ExampleImagesValidationError,
|
ExampleImagesValidationError,
|
||||||
)
|
)
|
||||||
|
from py.utils.metadata_manager import MetadataManager
|
||||||
from tests.conftest import MockModelService, MockScanner
|
from tests.conftest import MockModelService, MockScanner
|
||||||
|
|
||||||
|
|
||||||
@@ -155,7 +156,9 @@ async def test_auto_organize_use_case_rejects_when_running() -> None:
|
|||||||
await use_case.execute(file_paths=None, progress_callback=None)
|
await use_case.execute(file_paths=None, progress_callback=None)
|
||||||
|
|
||||||
|
|
||||||
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
scanner = MockScanner()
|
scanner = MockScanner()
|
||||||
scanner._cache.raw_data = [
|
scanner._cache.raw_data = [
|
||||||
{
|
{
|
||||||
@@ -170,6 +173,25 @@ async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
|||||||
settings = StubSettings()
|
settings = StubSettings()
|
||||||
progress = ProgressCollector()
|
progress = ProgressCollector()
|
||||||
|
|
||||||
|
hydration_calls: list[str] = []
|
||||||
|
|
||||||
|
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
hydration_calls.append(model_data.get("file_path", ""))
|
||||||
|
model_data.clear()
|
||||||
|
model_data.update(
|
||||||
|
{
|
||||||
|
"file_path": "model1.safetensors",
|
||||||
|
"sha256": "hash",
|
||||||
|
"from_civitai": True,
|
||||||
|
"model_name": "Demo",
|
||||||
|
"extra": "value",
|
||||||
|
"civitai": {"images": [{"url": "existing.png", "type": "image"}]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
|
||||||
|
|
||||||
use_case = BulkMetadataRefreshUseCase(
|
use_case = BulkMetadataRefreshUseCase(
|
||||||
service=service,
|
service=service,
|
||||||
metadata_sync=metadata_sync,
|
metadata_sync=metadata_sync,
|
||||||
@@ -183,6 +205,9 @@ async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
|||||||
assert progress.events[0]["status"] == "started"
|
assert progress.events[0]["status"] == "started"
|
||||||
assert progress.events[-1]["status"] == "completed"
|
assert progress.events[-1]["status"] == "completed"
|
||||||
assert metadata_sync.calls
|
assert metadata_sync.calls
|
||||||
|
assert metadata_sync.calls[0]["model_data"]["extra"] == "value"
|
||||||
|
assert scanner._cache.raw_data[0]["extra"] == "value"
|
||||||
|
assert hydration_calls == ["model1.safetensors"]
|
||||||
assert scanner._cache.resort_calls == 1
|
assert scanner._cache.resort_calls == 1
|
||||||
|
|
||||||
|
|
||||||
@@ -314,4 +339,4 @@ async def test_import_example_images_use_case_propagates_generic_error() -> None
|
|||||||
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
|
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
|
||||||
|
|
||||||
with pytest.raises(ExampleImagesImportError):
|
with pytest.raises(ExampleImagesImportError):
|
||||||
await use_case.execute(request)
|
await use_case.execute(request)
|
||||||
@@ -5,7 +5,7 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.settings_manager import settings
|
from py.services.settings_manager import get_settings_manager
|
||||||
from py.utils import example_images_download_manager as download_module
|
from py.utils import example_images_download_manager as download_module
|
||||||
|
|
||||||
|
|
||||||
@@ -19,19 +19,21 @@ class RecordingWebSocketManager:
|
|||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def restore_settings() -> None:
|
def restore_settings() -> None:
|
||||||
original = settings.settings.copy()
|
manager = get_settings_manager()
|
||||||
|
original = manager.settings.copy()
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
settings.settings.clear()
|
manager.settings.clear()
|
||||||
settings.settings.update(original)
|
manager.settings.update(original)
|
||||||
|
|
||||||
|
|
||||||
async def test_start_download_requires_configured_path(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_start_download_requires_configured_path(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())
|
manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())
|
||||||
|
|
||||||
# Ensure example_images_path is not configured
|
# Ensure example_images_path is not configured
|
||||||
settings.settings.pop('example_images_path', None)
|
settings_manager = get_settings_manager()
|
||||||
|
settings_manager.settings.pop('example_images_path', None)
|
||||||
|
|
||||||
with pytest.raises(download_module.DownloadConfigurationError) as exc_info:
|
with pytest.raises(download_module.DownloadConfigurationError) as exc_info:
|
||||||
await manager.start_download({})
|
await manager.start_download({})
|
||||||
@@ -44,9 +46,10 @@ async def test_start_download_requires_configured_path(monkeypatch: pytest.Monke
|
|||||||
|
|
||||||
|
|
||||||
async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||||
settings.settings["example_images_path"] = str(tmp_path)
|
settings_manager = get_settings_manager()
|
||||||
settings.settings["libraries"] = {"default": {}}
|
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||||
settings.settings["active_library"] = "default"
|
settings_manager.settings["libraries"] = {"default": {}}
|
||||||
|
settings_manager.settings["active_library"] = "default"
|
||||||
|
|
||||||
ws_manager = RecordingWebSocketManager()
|
ws_manager = RecordingWebSocketManager()
|
||||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
@@ -84,9 +87,10 @@ async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.M
|
|||||||
|
|
||||||
|
|
||||||
async def test_pause_and_resume_flow(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
async def test_pause_and_resume_flow(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||||
settings.settings["example_images_path"] = str(tmp_path)
|
settings_manager = get_settings_manager()
|
||||||
settings.settings["libraries"] = {"default": {}}
|
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||||
settings.settings["active_library"] = "default"
|
settings_manager.settings["libraries"] = {"default": {}}
|
||||||
|
settings_manager.settings["active_library"] = "default"
|
||||||
|
|
||||||
ws_manager = RecordingWebSocketManager()
|
ws_manager = RecordingWebSocketManager()
|
||||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.settings_manager import settings
|
from py.services.settings_manager import get_settings_manager
|
||||||
from py.utils.example_images_file_manager import ExampleImagesFileManager
|
from py.utils.example_images_file_manager import ExampleImagesFileManager
|
||||||
|
|
||||||
|
|
||||||
@@ -22,16 +22,18 @@ class JsonRequest:
|
|||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def restore_settings() -> None:
|
def restore_settings() -> None:
|
||||||
original = settings.settings.copy()
|
manager = get_settings_manager()
|
||||||
|
original = manager.settings.copy()
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
settings.settings.clear()
|
manager.settings.clear()
|
||||||
settings.settings.update(original)
|
manager.settings.update(original)
|
||||||
|
|
||||||
|
|
||||||
async def test_open_folder_requires_existing_model_directory(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
async def test_open_folder_requires_existing_model_directory(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||||
settings.settings["example_images_path"] = str(tmp_path)
|
settings_manager = get_settings_manager()
|
||||||
|
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||||
model_hash = "a" * 64
|
model_hash = "a" * 64
|
||||||
model_folder = tmp_path / model_hash
|
model_folder = tmp_path / model_hash
|
||||||
model_folder.mkdir()
|
model_folder.mkdir()
|
||||||
@@ -65,7 +67,8 @@ async def test_open_folder_requires_existing_model_directory(monkeypatch: pytest
|
|||||||
|
|
||||||
|
|
||||||
async def test_open_folder_rejects_invalid_paths(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
async def test_open_folder_rejects_invalid_paths(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||||
settings.settings["example_images_path"] = str(tmp_path)
|
settings_manager = get_settings_manager()
|
||||||
|
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||||
|
|
||||||
def fake_get_model_folder(_hash):
|
def fake_get_model_folder(_hash):
|
||||||
return str(tmp_path.parent / "outside")
|
return str(tmp_path.parent / "outside")
|
||||||
@@ -81,7 +84,8 @@ async def test_open_folder_rejects_invalid_paths(monkeypatch: pytest.MonkeyPatch
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_files_lists_supported_media(tmp_path) -> None:
|
async def test_get_files_lists_supported_media(tmp_path) -> None:
|
||||||
settings.settings["example_images_path"] = str(tmp_path)
|
settings_manager = get_settings_manager()
|
||||||
|
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||||
model_hash = "b" * 64
|
model_hash = "b" * 64
|
||||||
model_folder = tmp_path / model_hash
|
model_folder = tmp_path / model_hash
|
||||||
model_folder.mkdir()
|
model_folder.mkdir()
|
||||||
@@ -99,7 +103,8 @@ async def test_get_files_lists_supported_media(tmp_path) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_has_images_reports_presence(tmp_path) -> None:
|
async def test_has_images_reports_presence(tmp_path) -> None:
|
||||||
settings.settings["example_images_path"] = str(tmp_path)
|
settings_manager = get_settings_manager()
|
||||||
|
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||||
model_hash = "c" * 64
|
model_hash = "c" * 64
|
||||||
model_folder = tmp_path / model_hash
|
model_folder = tmp_path / model_hash
|
||||||
model_folder.mkdir()
|
model_folder.mkdir()
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
@@ -30,7 +32,23 @@ def patch_metadata_manager(monkeypatch: pytest.MonkeyPatch):
|
|||||||
saved.append((path, metadata.copy()))
|
saved.append((path, metadata.copy()))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
class SimpleMetadata:
|
||||||
|
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||||
|
self._payload = payload
|
||||||
|
self._unknown_fields: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return self._payload.copy()
|
||||||
|
|
||||||
|
async def fake_load(path: str, *_args: Any, **_kwargs: Any):
|
||||||
|
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
|
data = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
|
||||||
|
return SimpleMetadata(data), False
|
||||||
|
return None, False
|
||||||
|
|
||||||
monkeypatch.setattr(metadata_module.MetadataManager, "save_metadata", staticmethod(fake_save))
|
monkeypatch.setattr(metadata_module.MetadataManager, "save_metadata", staticmethod(fake_save))
|
||||||
|
monkeypatch.setattr(metadata_module.MetadataManager, "load_metadata", staticmethod(fake_load))
|
||||||
return saved
|
return saved
|
||||||
|
|
||||||
|
|
||||||
@@ -64,10 +82,80 @@ async def test_update_metadata_after_import_enriches_entries(monkeypatch: pytest
|
|||||||
assert custom[0]["hasMeta"] is True
|
assert custom[0]["hasMeta"] is True
|
||||||
assert custom[0]["type"] == "image"
|
assert custom[0]["type"] == "image"
|
||||||
|
|
||||||
assert patch_metadata_manager[0][0] == str(model_file)
|
assert Path(patch_metadata_manager[0][0]) == model_file
|
||||||
assert scanner.updates
|
assert scanner.updates
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_metadata_after_import_preserves_existing_metadata(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
patch_metadata_manager,
|
||||||
|
):
|
||||||
|
model_hash = "b" * 64
|
||||||
|
model_file = tmp_path / "preserve.safetensors"
|
||||||
|
model_file.write_text("content", encoding="utf-8")
|
||||||
|
metadata_path = tmp_path / "preserve.metadata.json"
|
||||||
|
|
||||||
|
existing_payload = {
|
||||||
|
"model_name": "Example",
|
||||||
|
"file_path": str(model_file),
|
||||||
|
"civitai": {
|
||||||
|
"id": 42,
|
||||||
|
"modelId": 88,
|
||||||
|
"name": "Example",
|
||||||
|
"trainedWords": ["foo"],
|
||||||
|
"images": [{"url": "https://example.com/default.png", "type": "image"}],
|
||||||
|
"customImages": [
|
||||||
|
{"id": "existing-id", "type": "image", "url": "", "nsfwLevel": 0}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"extraField": "keep-me",
|
||||||
|
}
|
||||||
|
metadata_path.write_text(json.dumps(existing_payload), encoding="utf-8")
|
||||||
|
|
||||||
|
model_data = {
|
||||||
|
"sha256": model_hash,
|
||||||
|
"model_name": "Example",
|
||||||
|
"file_path": str(model_file),
|
||||||
|
"civitai": {
|
||||||
|
"id": 42,
|
||||||
|
"modelId": 88,
|
||||||
|
"name": "Example",
|
||||||
|
"trainedWords": ["foo"],
|
||||||
|
"customImages": [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
scanner = StubScanner([model_data])
|
||||||
|
|
||||||
|
image_path = tmp_path / "new.png"
|
||||||
|
image_path.write_bytes(b"fakepng")
|
||||||
|
|
||||||
|
monkeypatch.setattr(metadata_module.ExifUtils, "extract_image_metadata", staticmethod(lambda _path: None))
|
||||||
|
monkeypatch.setattr(metadata_module.MetadataUpdater, "_parse_image_metadata", staticmethod(lambda payload: None))
|
||||||
|
|
||||||
|
regular, custom = await metadata_module.MetadataUpdater.update_metadata_after_import(
|
||||||
|
model_hash,
|
||||||
|
model_data,
|
||||||
|
scanner,
|
||||||
|
[(str(image_path), "new-id")],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert regular == existing_payload["civitai"]["images"]
|
||||||
|
assert any(entry["id"] == "new-id" for entry in custom)
|
||||||
|
|
||||||
|
saved_path, saved_payload = patch_metadata_manager[-1]
|
||||||
|
assert Path(saved_path) == model_file
|
||||||
|
assert saved_payload["extraField"] == "keep-me"
|
||||||
|
assert saved_payload["civitai"]["images"] == existing_payload["civitai"]["images"]
|
||||||
|
assert saved_payload["civitai"]["trainedWords"] == ["foo"]
|
||||||
|
assert {entry["id"] for entry in saved_payload["civitai"]["customImages"]} == {"existing-id", "new-id"}
|
||||||
|
|
||||||
|
assert scanner.updates
|
||||||
|
updated_metadata = scanner.updates[-1][2]
|
||||||
|
assert updated_metadata["civitai"]["images"] == existing_payload["civitai"]["images"]
|
||||||
|
assert {entry["id"] for entry in updated_metadata["civitai"]["customImages"]} == {"existing-id", "new-id"}
|
||||||
|
|
||||||
async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
model_hash = "b" * 64
|
model_hash = "b" * 64
|
||||||
model_file = tmp_path / "model.safetensors"
|
model_file = tmp_path / "model.safetensors"
|
||||||
@@ -79,6 +167,16 @@ async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.Monke
|
|||||||
async def fetch_and_update_model(self, **_kwargs):
|
async def fetch_and_update_model(self, **_kwargs):
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
|
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
model_data["hydrated"] = True
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
metadata_module.MetadataManager,
|
||||||
|
"hydrate_model_data",
|
||||||
|
staticmethod(fake_hydrate),
|
||||||
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(metadata_module, "_metadata_sync_service", StubMetadataSync())
|
monkeypatch.setattr(metadata_module, "_metadata_sync_service", StubMetadataSync())
|
||||||
|
|
||||||
result = await metadata_module.MetadataUpdater.refresh_model_metadata(
|
result = await metadata_module.MetadataUpdater.refresh_model_metadata(
|
||||||
@@ -89,6 +187,7 @@ async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.Monke
|
|||||||
{"refreshed_models": set(), "errors": [], "last_error": None},
|
{"refreshed_models": set(), "errors": [], "last_error": None},
|
||||||
)
|
)
|
||||||
assert result is True
|
assert result is True
|
||||||
|
assert cache_item["hydrated"] is True
|
||||||
|
|
||||||
|
|
||||||
async def test_update_metadata_from_local_examples_generates_entries(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
async def test_update_metadata_from_local_examples_generates_entries(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
@@ -112,4 +211,4 @@ async def test_update_metadata_from_local_examples_generates_entries(monkeypatch
|
|||||||
str(model_dir),
|
str(model_dir),
|
||||||
)
|
)
|
||||||
assert success is True
|
assert success is True
|
||||||
assert model_data["civitai"]["images"]
|
assert model_data["civitai"]["images"]
|
||||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.settings_manager import settings
|
from py.services.settings_manager import get_settings_manager
|
||||||
from py.utils.example_images_paths import (
|
from py.utils.example_images_paths import (
|
||||||
ensure_library_root_exists,
|
ensure_library_root_exists,
|
||||||
get_model_folder,
|
get_model_folder,
|
||||||
@@ -18,18 +18,24 @@ from py.utils.example_images_paths import (
|
|||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def restore_settings():
|
def restore_settings():
|
||||||
original = copy.deepcopy(settings.settings)
|
manager = get_settings_manager()
|
||||||
|
original = copy.deepcopy(manager.settings)
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
settings.settings.clear()
|
manager.settings.clear()
|
||||||
settings.settings.update(original)
|
manager.settings.update(original)
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_folder_single_library(tmp_path):
|
@pytest.fixture
|
||||||
settings.settings['example_images_path'] = str(tmp_path)
|
def settings_manager():
|
||||||
settings.settings['libraries'] = {'default': {}}
|
return get_settings_manager()
|
||||||
settings.settings['active_library'] = 'default'
|
|
||||||
|
|
||||||
|
def test_get_model_folder_single_library(tmp_path, settings_manager):
|
||||||
|
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||||
|
settings_manager.settings['libraries'] = {'default': {}}
|
||||||
|
settings_manager.settings['active_library'] = 'default'
|
||||||
|
|
||||||
model_hash = 'a' * 64
|
model_hash = 'a' * 64
|
||||||
folder = get_model_folder(model_hash)
|
folder = get_model_folder(model_hash)
|
||||||
@@ -39,13 +45,13 @@ def test_get_model_folder_single_library(tmp_path):
|
|||||||
assert relative == model_hash
|
assert relative == model_hash
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_folder_multi_library(tmp_path):
|
def test_get_model_folder_multi_library(tmp_path, settings_manager):
|
||||||
settings.settings['example_images_path'] = str(tmp_path)
|
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||||
settings.settings['libraries'] = {
|
settings_manager.settings['libraries'] = {
|
||||||
'default': {},
|
'default': {},
|
||||||
'Alt Library': {},
|
'Alt Library': {},
|
||||||
}
|
}
|
||||||
settings.settings['active_library'] = 'Alt Library'
|
settings_manager.settings['active_library'] = 'Alt Library'
|
||||||
|
|
||||||
model_hash = 'b' * 64
|
model_hash = 'b' * 64
|
||||||
expected_folder = tmp_path / 'Alt_Library' / model_hash
|
expected_folder = tmp_path / 'Alt_Library' / model_hash
|
||||||
@@ -57,13 +63,13 @@ def test_get_model_folder_multi_library(tmp_path):
|
|||||||
assert relative == os.path.join('Alt_Library', model_hash).replace('\\', '/')
|
assert relative == os.path.join('Alt_Library', model_hash).replace('\\', '/')
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_folder_migrates_legacy_structure(tmp_path):
|
def test_get_model_folder_migrates_legacy_structure(tmp_path, settings_manager):
|
||||||
settings.settings['example_images_path'] = str(tmp_path)
|
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||||
settings.settings['libraries'] = {
|
settings_manager.settings['libraries'] = {
|
||||||
'default': {},
|
'default': {},
|
||||||
'extra': {},
|
'extra': {},
|
||||||
}
|
}
|
||||||
settings.settings['active_library'] = 'extra'
|
settings_manager.settings['active_library'] = 'extra'
|
||||||
|
|
||||||
model_hash = 'c' * 64
|
model_hash = 'c' * 64
|
||||||
legacy_folder = tmp_path / model_hash
|
legacy_folder = tmp_path / model_hash
|
||||||
@@ -82,31 +88,31 @@ def test_get_model_folder_migrates_legacy_structure(tmp_path):
|
|||||||
assert (expected_folder / 'image.png').exists()
|
assert (expected_folder / 'image.png').exists()
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_library_root_exists_creates_directories(tmp_path):
|
def test_ensure_library_root_exists_creates_directories(tmp_path, settings_manager):
|
||||||
settings.settings['example_images_path'] = str(tmp_path)
|
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||||
settings.settings['libraries'] = {'default': {}, 'secondary': {}}
|
settings_manager.settings['libraries'] = {'default': {}, 'secondary': {}}
|
||||||
settings.settings['active_library'] = 'secondary'
|
settings_manager.settings['active_library'] = 'secondary'
|
||||||
|
|
||||||
resolved = ensure_library_root_exists('secondary')
|
resolved = ensure_library_root_exists('secondary')
|
||||||
assert Path(resolved) == tmp_path / 'secondary'
|
assert Path(resolved) == tmp_path / 'secondary'
|
||||||
assert (tmp_path / 'secondary').is_dir()
|
assert (tmp_path / 'secondary').is_dir()
|
||||||
|
|
||||||
|
|
||||||
def test_iter_library_roots_returns_all_configured(tmp_path):
|
def test_iter_library_roots_returns_all_configured(tmp_path, settings_manager):
|
||||||
settings.settings['example_images_path'] = str(tmp_path)
|
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||||
settings.settings['libraries'] = {'default': {}, 'alt': {}}
|
settings_manager.settings['libraries'] = {'default': {}, 'alt': {}}
|
||||||
settings.settings['active_library'] = 'alt'
|
settings_manager.settings['active_library'] = 'alt'
|
||||||
|
|
||||||
roots = dict(iter_library_roots())
|
roots = dict(iter_library_roots())
|
||||||
assert roots['default'] == str(tmp_path / 'default')
|
assert roots['default'] == str(tmp_path / 'default')
|
||||||
assert roots['alt'] == str(tmp_path / 'alt')
|
assert roots['alt'] == str(tmp_path / 'alt')
|
||||||
|
|
||||||
|
|
||||||
def test_is_valid_example_images_root_accepts_hash_directories(tmp_path):
|
def test_is_valid_example_images_root_accepts_hash_directories(tmp_path, settings_manager):
|
||||||
settings.settings['example_images_path'] = str(tmp_path)
|
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||||
# Ensure single-library mode (not multi-library mode)
|
# Ensure single-library mode (not multi-library mode)
|
||||||
settings.settings['libraries'] = {'default': {}}
|
settings_manager.settings['libraries'] = {'default': {}}
|
||||||
settings.settings['active_library'] = 'default'
|
settings_manager.settings['active_library'] = 'default'
|
||||||
|
|
||||||
hash_folder = tmp_path / ('d' * 64)
|
hash_folder = tmp_path / ('d' * 64)
|
||||||
hash_folder.mkdir()
|
hash_folder.mkdir()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
@@ -7,18 +8,42 @@ from typing import Any, Dict, Tuple
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.settings_manager import settings
|
from py.services.settings_manager import get_settings_manager
|
||||||
|
from py.utils import example_images_metadata as metadata_module
|
||||||
from py.utils import example_images_processor as processor_module
|
from py.utils import example_images_processor as processor_module
|
||||||
|
from py.utils.example_images_paths import get_model_folder
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def restore_settings() -> None:
|
def restore_settings() -> None:
|
||||||
original = settings.settings.copy()
|
manager = get_settings_manager()
|
||||||
|
original = manager.settings.copy()
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
settings.settings.clear()
|
manager.settings.clear()
|
||||||
settings.settings.update(original)
|
manager.settings.update(original)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_metadata_loader(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
class SimpleMetadata:
|
||||||
|
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||||
|
self._payload = payload
|
||||||
|
self._unknown_fields: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return self._payload.copy()
|
||||||
|
|
||||||
|
async def fake_load(path: str, *_args: Any, **_kwargs: Any):
|
||||||
|
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
|
data = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
|
||||||
|
return SimpleMetadata(data), False
|
||||||
|
return None, False
|
||||||
|
|
||||||
|
monkeypatch.setattr(processor_module.MetadataManager, "load_metadata", staticmethod(fake_load))
|
||||||
|
monkeypatch.setattr(metadata_module.MetadataManager, "load_metadata", staticmethod(fake_load))
|
||||||
|
|
||||||
|
|
||||||
def test_get_file_extension_from_magic_bytes() -> None:
|
def test_get_file_extension_from_magic_bytes() -> None:
|
||||||
@@ -90,9 +115,10 @@ def stub_scanners(monkeypatch: pytest.MonkeyPatch, tmp_path) -> StubScanner:
|
|||||||
|
|
||||||
|
|
||||||
async def test_import_images_creates_hash_directory(monkeypatch: pytest.MonkeyPatch, tmp_path, stub_scanners: StubScanner) -> None:
|
async def test_import_images_creates_hash_directory(monkeypatch: pytest.MonkeyPatch, tmp_path, stub_scanners: StubScanner) -> None:
|
||||||
settings.settings["example_images_path"] = str(tmp_path / "examples")
|
settings_manager = get_settings_manager()
|
||||||
settings.settings["libraries"] = {"default": {}}
|
settings_manager.settings["example_images_path"] = str(tmp_path / "examples")
|
||||||
settings.settings["active_library"] = "default"
|
settings_manager.settings["libraries"] = {"default": {}}
|
||||||
|
settings_manager.settings["active_library"] = "default"
|
||||||
|
|
||||||
source_file = tmp_path / "upload.png"
|
source_file = tmp_path / "upload.png"
|
||||||
source_file.write_bytes(b"PNG data")
|
source_file.write_bytes(b"PNG data")
|
||||||
@@ -112,7 +138,7 @@ async def test_import_images_creates_hash_directory(monkeypatch: pytest.MonkeyPa
|
|||||||
assert result["success"] is True
|
assert result["success"] is True
|
||||||
assert result["files"][0]["name"].startswith("custom_short")
|
assert result["files"][0]["name"].startswith("custom_short")
|
||||||
|
|
||||||
model_folder = Path(settings.settings["example_images_path"]) / ("a" * 64)
|
model_folder = Path(settings_manager.settings["example_images_path"]) / ("a" * 64)
|
||||||
assert model_folder.exists()
|
assert model_folder.exists()
|
||||||
created_files = list(model_folder.glob("custom_short*.png"))
|
created_files = list(model_folder.glob("custom_short*.png"))
|
||||||
assert len(created_files) == 1
|
assert len(created_files) == 1
|
||||||
@@ -132,7 +158,8 @@ async def test_import_images_rejects_missing_parameters(monkeypatch: pytest.Monk
|
|||||||
|
|
||||||
|
|
||||||
async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||||
settings.settings["example_images_path"] = str(tmp_path)
|
settings_manager = get_settings_manager()
|
||||||
|
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||||
|
|
||||||
async def _empty_scanner(cls=None):
|
async def _empty_scanner(cls=None):
|
||||||
return StubScanner([])
|
return StubScanner([])
|
||||||
@@ -143,3 +170,88 @@ async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.Mon
|
|||||||
|
|
||||||
with pytest.raises(processor_module.ExampleImagesImportError):
|
with pytest.raises(processor_module.ExampleImagesImportError):
|
||||||
await processor_module.ExampleImagesProcessor.import_images("a" * 64, [str(tmp_path / "missing.png")])
|
await processor_module.ExampleImagesProcessor.import_images("a" * 64, [str(tmp_path / "missing.png")])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_custom_image_preserves_existing_metadata(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||||
|
settings_manager = get_settings_manager()
|
||||||
|
settings_manager.settings["example_images_path"] = str(tmp_path / "examples")
|
||||||
|
|
||||||
|
model_hash = "c" * 64
|
||||||
|
model_file = tmp_path / "keep.safetensors"
|
||||||
|
model_file.write_text("content", encoding="utf-8")
|
||||||
|
metadata_path = tmp_path / "keep.metadata.json"
|
||||||
|
|
||||||
|
existing_metadata = {
|
||||||
|
"model_name": "Keep",
|
||||||
|
"file_path": str(model_file),
|
||||||
|
"civitai": {
|
||||||
|
"images": [{"url": "https://example.com/default.png", "type": "image"}],
|
||||||
|
"customImages": [{"id": "existing-id", "url": "", "type": "image"}],
|
||||||
|
"trainedWords": ["foo"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
metadata_path.write_text(json.dumps(existing_metadata), encoding="utf-8")
|
||||||
|
|
||||||
|
model_data = {
|
||||||
|
"sha256": model_hash,
|
||||||
|
"model_name": "Keep",
|
||||||
|
"file_path": str(model_file),
|
||||||
|
"civitai": {
|
||||||
|
"customImages": [{"id": "existing-id", "url": "", "type": "image"}],
|
||||||
|
"trainedWords": ["foo"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
class Scanner(StubScanner):
|
||||||
|
def has_hash(self, hash_value: str) -> bool:
|
||||||
|
return hash_value == model_hash
|
||||||
|
|
||||||
|
scanner = Scanner([model_data])
|
||||||
|
|
||||||
|
async def _return_scanner(cls=None):
|
||||||
|
return scanner
|
||||||
|
|
||||||
|
monkeypatch.setattr(processor_module.ServiceRegistry, "get_lora_scanner", classmethod(_return_scanner))
|
||||||
|
monkeypatch.setattr(processor_module.ServiceRegistry, "get_checkpoint_scanner", classmethod(_return_scanner))
|
||||||
|
monkeypatch.setattr(processor_module.ServiceRegistry, "get_embedding_scanner", classmethod(_return_scanner))
|
||||||
|
|
||||||
|
model_folder = get_model_folder(model_hash)
|
||||||
|
os.makedirs(model_folder, exist_ok=True)
|
||||||
|
(Path(model_folder) / "custom_existing-id.png").write_bytes(b"data")
|
||||||
|
|
||||||
|
saved: list[tuple[str, Dict[str, Any]]] = []
|
||||||
|
|
||||||
|
async def fake_save(path: str, payload: Dict[str, Any]) -> bool:
|
||||||
|
saved.append((path, payload.copy()))
|
||||||
|
return True
|
||||||
|
|
||||||
|
monkeypatch.setattr(processor_module.MetadataManager, "save_metadata", staticmethod(fake_save))
|
||||||
|
|
||||||
|
class StubRequest:
|
||||||
|
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||||
|
self._payload = payload
|
||||||
|
|
||||||
|
async def json(self) -> Dict[str, Any]:
|
||||||
|
return self._payload
|
||||||
|
|
||||||
|
response = await processor_module.ExampleImagesProcessor.delete_custom_image(
|
||||||
|
StubRequest({"model_hash": model_hash, "short_id": "existing-id"})
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
body = json.loads(response.text)
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["custom_images"] == []
|
||||||
|
assert not (Path(model_folder) / "custom_existing-id.png").exists()
|
||||||
|
|
||||||
|
saved_path, saved_payload = saved[-1]
|
||||||
|
assert saved_path == str(model_file)
|
||||||
|
assert saved_payload["civitai"]["images"] == existing_metadata["civitai"]["images"]
|
||||||
|
assert saved_payload["civitai"]["trainedWords"] == ["foo"]
|
||||||
|
assert saved_payload["civitai"]["customImages"] == []
|
||||||
|
|
||||||
|
assert scanner.updated
|
||||||
|
_, _, updated_metadata = scanner.updated[-1]
|
||||||
|
assert updated_metadata["civitai"]["images"] == existing_metadata["civitai"]["images"]
|
||||||
|
assert updated_metadata["civitai"]["customImages"] == []
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.settings_manager import settings
|
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||||
from py.utils.utils import (
|
from py.utils.utils import (
|
||||||
calculate_recipe_fingerprint,
|
calculate_recipe_fingerprint,
|
||||||
calculate_relative_path_for_model,
|
calculate_relative_path_for_model,
|
||||||
@@ -9,7 +9,8 @@ from py.utils.utils import (
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def isolated_settings(monkeypatch):
|
def isolated_settings(monkeypatch):
|
||||||
default_settings = settings._get_default_settings()
|
manager = get_settings_manager()
|
||||||
|
default_settings = manager._get_default_settings()
|
||||||
default_settings.update(
|
default_settings.update(
|
||||||
{
|
{
|
||||||
"download_path_templates": {
|
"download_path_templates": {
|
||||||
@@ -20,8 +21,8 @@ def isolated_settings(monkeypatch):
|
|||||||
"base_model_path_mappings": {},
|
"base_model_path_mappings": {},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(settings, "settings", default_settings)
|
monkeypatch.setattr(manager, "settings", default_settings)
|
||||||
monkeypatch.setattr(type(settings), "_save_settings", lambda self: None)
|
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
||||||
return default_settings
|
return default_settings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
import { api } from "../../scripts/api.js";
|
import { api } from "../../scripts/api.js";
|
||||||
import { addJsonDisplayWidget } from "./json_display_widget.js";
|
import { addJsonDisplayWidget } from "./json_display_widget.js";
|
||||||
|
import { getNodeFromGraph } from "./utils.js";
|
||||||
|
|
||||||
app.registerExtension({
|
app.registerExtension({
|
||||||
name: "LoraManager.DebugMetadata",
|
name: "LoraManager.DebugMetadata",
|
||||||
@@ -8,8 +9,8 @@ app.registerExtension({
|
|||||||
setup() {
|
setup() {
|
||||||
// Add message handler to listen for metadata updates from Python
|
// Add message handler to listen for metadata updates from Python
|
||||||
api.addEventListener("metadata_update", (event) => {
|
api.addEventListener("metadata_update", (event) => {
|
||||||
const { id, metadata } = event.detail;
|
const { id, graph_id: graphId, metadata } = event.detail;
|
||||||
this.handleMetadataUpdate(id, metadata);
|
this.handleMetadataUpdate(id, graphId, metadata);
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -37,8 +38,8 @@ app.registerExtension({
|
|||||||
},
|
},
|
||||||
|
|
||||||
// Handle metadata updates from Python
|
// Handle metadata updates from Python
|
||||||
handleMetadataUpdate(id, metadata) {
|
handleMetadataUpdate(id, graphId, metadata) {
|
||||||
const node = app.graph.getNodeById(+id);
|
const node = getNodeFromGraph(graphId, id);
|
||||||
if (!node || node.comfyClass !== "Debug Metadata (LoraManager)") {
|
if (!node || node.comfyClass !== "Debug Metadata (LoraManager)") {
|
||||||
console.warn("Node not found or not a DebugMetadata node:", id);
|
console.warn("Node not found or not a DebugMetadata node:", id);
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import {
|
|||||||
chainCallback,
|
chainCallback,
|
||||||
mergeLoras,
|
mergeLoras,
|
||||||
setupInputWidgetWithAutocomplete,
|
setupInputWidgetWithAutocomplete,
|
||||||
|
getAllGraphNodes,
|
||||||
|
getNodeFromGraph,
|
||||||
} from "./utils.js";
|
} from "./utils.js";
|
||||||
import { addLorasWidget } from "./loras_widget.js";
|
import { addLorasWidget } from "./loras_widget.js";
|
||||||
|
|
||||||
@@ -16,23 +18,26 @@ app.registerExtension({
|
|||||||
setup() {
|
setup() {
|
||||||
// Add message handler to listen for messages from Python
|
// Add message handler to listen for messages from Python
|
||||||
api.addEventListener("lora_code_update", (event) => {
|
api.addEventListener("lora_code_update", (event) => {
|
||||||
const { id, lora_code, mode } = event.detail;
|
this.handleLoraCodeUpdate(event.detail || {});
|
||||||
this.handleLoraCodeUpdate(id, lora_code, mode);
|
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
|
||||||
// Handle lora code updates from Python
|
// Handle lora code updates from Python
|
||||||
handleLoraCodeUpdate(id, loraCode, mode) {
|
handleLoraCodeUpdate(message) {
|
||||||
|
const nodeId = message?.node_id ?? message?.id;
|
||||||
|
const graphId = message?.graph_id;
|
||||||
|
const loraCode = message?.lora_code ?? "";
|
||||||
|
const mode = message?.mode ?? "append";
|
||||||
|
|
||||||
|
const numericNodeId =
|
||||||
|
typeof nodeId === "string" ? Number(nodeId) : nodeId;
|
||||||
|
|
||||||
// Handle broadcast mode (for Desktop/non-browser support)
|
// Handle broadcast mode (for Desktop/non-browser support)
|
||||||
if (id === -1) {
|
if (numericNodeId === -1) {
|
||||||
// Find all Lora Loader nodes in the current graph
|
// Find all Lora Loader nodes in the current graph
|
||||||
const loraLoaderNodes = [];
|
const loraLoaderNodes = getAllGraphNodes(app.graph)
|
||||||
for (const nodeId in app.graph._nodes_by_id) {
|
.map(({ node }) => node)
|
||||||
const node = app.graph._nodes_by_id[nodeId];
|
.filter((node) => node?.comfyClass === "Lora Loader (LoraManager)");
|
||||||
if (node.comfyClass === "Lora Loader (LoraManager)") {
|
|
||||||
loraLoaderNodes.push(node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update each Lora Loader node found
|
// Update each Lora Loader node found
|
||||||
if (loraLoaderNodes.length > 0) {
|
if (loraLoaderNodes.length > 0) {
|
||||||
@@ -52,14 +57,18 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Standard mode - update a specific node
|
// Standard mode - update a specific node
|
||||||
const node = app.graph.getNodeById(+id);
|
const node = getNodeFromGraph(graphId, numericNodeId);
|
||||||
if (
|
if (
|
||||||
!node ||
|
!node ||
|
||||||
(node.comfyClass !== "Lora Loader (LoraManager)" &&
|
(node.comfyClass !== "Lora Loader (LoraManager)" &&
|
||||||
node.comfyClass !== "Lora Stacker (LoraManager)" &&
|
node.comfyClass !== "Lora Stacker (LoraManager)" &&
|
||||||
node.comfyClass !== "WanVideo Lora Select (LoraManager)")
|
node.comfyClass !== "WanVideo Lora Select (LoraManager)")
|
||||||
) {
|
) {
|
||||||
console.warn("Node not found or not a LoraLoader:", id);
|
console.warn(
|
||||||
|
"Node not found or not a LoraLoader:",
|
||||||
|
graphId ?? "root",
|
||||||
|
nodeId
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import {
|
|||||||
chainCallback,
|
chainCallback,
|
||||||
mergeLoras,
|
mergeLoras,
|
||||||
setupInputWidgetWithAutocomplete,
|
setupInputWidgetWithAutocomplete,
|
||||||
|
getLinkFromGraph,
|
||||||
|
getNodeKey,
|
||||||
} from "./utils.js";
|
} from "./utils.js";
|
||||||
import { addLorasWidget } from "./loras_widget.js";
|
import { addLorasWidget } from "./loras_widget.js";
|
||||||
|
|
||||||
@@ -124,17 +126,18 @@ app.registerExtension({
|
|||||||
|
|
||||||
// Helper function to find and update downstream Lora Loader nodes
|
// Helper function to find and update downstream Lora Loader nodes
|
||||||
function updateDownstreamLoaders(startNode, visited = new Set()) {
|
function updateDownstreamLoaders(startNode, visited = new Set()) {
|
||||||
if (visited.has(startNode.id)) return;
|
const nodeKey = getNodeKey(startNode);
|
||||||
visited.add(startNode.id);
|
if (!nodeKey || visited.has(nodeKey)) return;
|
||||||
|
visited.add(nodeKey);
|
||||||
|
|
||||||
// Check each output link
|
// Check each output link
|
||||||
if (startNode.outputs) {
|
if (startNode.outputs) {
|
||||||
for (const output of startNode.outputs) {
|
for (const output of startNode.outputs) {
|
||||||
if (output.links) {
|
if (output.links) {
|
||||||
for (const linkId of output.links) {
|
for (const linkId of output.links) {
|
||||||
const link = app.graph.links[linkId];
|
const link = getLinkFromGraph(startNode.graph, linkId);
|
||||||
if (link) {
|
if (link) {
|
||||||
const targetNode = app.graph.getNodeById(link.target_id);
|
const targetNode = startNode.graph?.getNodeById?.(link.target_id);
|
||||||
|
|
||||||
// If target is a Lora Loader, collect all active loras in the chain and update
|
// If target is a Lora Loader, collect all active loras in the chain and update
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
import { api } from "../../scripts/api.js";
|
import { api } from "../../scripts/api.js";
|
||||||
import { CONVERTED_TYPE } from "./utils.js";
|
import { CONVERTED_TYPE, getNodeFromGraph } from "./utils.js";
|
||||||
import { addTagsWidget } from "./tags_widget.js";
|
import { addTagsWidget } from "./tags_widget.js";
|
||||||
|
|
||||||
// TriggerWordToggle extension for ComfyUI
|
// TriggerWordToggle extension for ComfyUI
|
||||||
@@ -10,8 +10,8 @@ app.registerExtension({
|
|||||||
setup() {
|
setup() {
|
||||||
// Add message handler to listen for messages from Python
|
// Add message handler to listen for messages from Python
|
||||||
api.addEventListener("trigger_word_update", (event) => {
|
api.addEventListener("trigger_word_update", (event) => {
|
||||||
const { id, message } = event.detail;
|
const { id, graph_id: graphId, message } = event.detail;
|
||||||
this.handleTriggerWordUpdate(id, message);
|
this.handleTriggerWordUpdate(id, graphId, message);
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -76,8 +76,8 @@ app.registerExtension({
|
|||||||
},
|
},
|
||||||
|
|
||||||
// Handle trigger word updates from Python
|
// Handle trigger word updates from Python
|
||||||
handleTriggerWordUpdate(id, message) {
|
handleTriggerWordUpdate(id, graphId, message) {
|
||||||
const node = app.graph.getNodeById(+id);
|
const node = getNodeFromGraph(graphId, id);
|
||||||
if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") {
|
if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") {
|
||||||
console.warn("Node not found or not a TriggerWordToggle:", id);
|
console.warn("Node not found or not a TriggerWordToggle:", id);
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
// ComfyUI extension to track model usage statistics
|
// ComfyUI extension to track model usage statistics
|
||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
import { api } from "../../scripts/api.js";
|
import { api } from "../../scripts/api.js";
|
||||||
import { showToast } from "./utils.js";
|
import { getAllGraphNodes, getNodeReference, showToast } from "./utils.js";
|
||||||
|
|
||||||
// Define target nodes and their widget configurations
|
// Define target nodes and their widget configurations
|
||||||
const PATH_CORRECTION_TARGETS = [
|
const PATH_CORRECTION_TARGETS = [
|
||||||
@@ -56,25 +56,35 @@ app.registerExtension({
|
|||||||
|
|
||||||
async refreshRegistry() {
|
async refreshRegistry() {
|
||||||
try {
|
try {
|
||||||
// Get current workflow nodes
|
|
||||||
const prompt = await app.graphToPrompt();
|
|
||||||
const workflow = prompt.workflow;
|
|
||||||
if (!workflow || !workflow.nodes) {
|
|
||||||
console.warn("No workflow nodes found for registry refresh");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find all Lora nodes
|
|
||||||
const loraNodes = [];
|
const loraNodes = [];
|
||||||
for (const node of workflow.nodes.values()) {
|
const nodeEntries = getAllGraphNodes(app.graph);
|
||||||
if (node.type === "Lora Loader (LoraManager)" ||
|
|
||||||
node.type === "Lora Stacker (LoraManager)" ||
|
for (const { graph, node } of nodeEntries) {
|
||||||
node.type === "WanVideo Lora Select (LoraManager)") {
|
if (!node || !node.comfyClass) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
node.comfyClass === "Lora Loader (LoraManager)" ||
|
||||||
|
node.comfyClass === "Lora Stacker (LoraManager)" ||
|
||||||
|
node.comfyClass === "WanVideo Lora Select (LoraManager)"
|
||||||
|
) {
|
||||||
|
const reference = getNodeReference(node);
|
||||||
|
if (!reference) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const graphName = typeof graph?.name === "string" && graph.name.trim()
|
||||||
|
? graph.name
|
||||||
|
: null;
|
||||||
|
|
||||||
loraNodes.push({
|
loraNodes.push({
|
||||||
node_id: node.id,
|
node_id: reference.node_id,
|
||||||
bgcolor: node.bgcolor || null,
|
graph_id: reference.graph_id,
|
||||||
title: node.title || node.type,
|
graph_name: graphName,
|
||||||
type: node.type
|
bgcolor: node.bgcolor ?? node.color ?? null,
|
||||||
|
title: node.title || node.comfyClass,
|
||||||
|
type: node.comfyClass,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,120 @@ export const CONVERTED_TYPE = 'converted-widget';
|
|||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
import { AutoComplete } from "./autocomplete.js";
|
import { AutoComplete } from "./autocomplete.js";
|
||||||
|
|
||||||
|
const ROOT_GRAPH_ID = "root";
|
||||||
|
|
||||||
|
function isMapLike(collection) {
|
||||||
|
return collection && typeof collection.entries === "function" && typeof collection.values === "function";
|
||||||
|
}
|
||||||
|
|
||||||
|
function getChildGraphs(graph) {
|
||||||
|
if (!graph || !graph._subgraphs) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const rawSubgraphs = isMapLike(graph._subgraphs)
|
||||||
|
? Array.from(graph._subgraphs.values())
|
||||||
|
: Object.values(graph._subgraphs);
|
||||||
|
|
||||||
|
return rawSubgraphs
|
||||||
|
.map((subgraph) => subgraph?.graph || subgraph?._graph || subgraph)
|
||||||
|
.filter((subgraph) => subgraph && subgraph !== graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
function traverseGraphs(rootGraph, visitor, visited = new Set()) {
|
||||||
|
const graph = rootGraph || app.graph;
|
||||||
|
if (!graph) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const graphId = getGraphId(graph);
|
||||||
|
if (visited.has(graphId)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
visited.add(graphId);
|
||||||
|
visitor(graph);
|
||||||
|
|
||||||
|
for (const subgraph of getChildGraphs(graph)) {
|
||||||
|
traverseGraphs(subgraph, visitor, visited);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getGraphId(graph) {
|
||||||
|
return graph?.id ?? ROOT_GRAPH_ID;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getNodeGraphId(node) {
|
||||||
|
if (!node) {
|
||||||
|
return ROOT_GRAPH_ID;
|
||||||
|
}
|
||||||
|
return getGraphId(node.graph || app.graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getGraphById(graphId, rootGraph = app.graph) {
|
||||||
|
if (!graphId) {
|
||||||
|
return rootGraph;
|
||||||
|
}
|
||||||
|
|
||||||
|
let foundGraph = null;
|
||||||
|
traverseGraphs(rootGraph, (graph) => {
|
||||||
|
if (!foundGraph && getGraphId(graph) === graphId) {
|
||||||
|
foundGraph = graph;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return foundGraph;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getNodeFromGraph(graphId, nodeId) {
|
||||||
|
const graph = getGraphById(graphId) || app.graph;
|
||||||
|
if (!graph || typeof graph.getNodeById !== "function") {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const numericId = typeof nodeId === "string" ? Number(nodeId) : nodeId;
|
||||||
|
return graph.getNodeById(Number.isNaN(numericId) ? nodeId : numericId) || null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getAllGraphNodes(rootGraph = app.graph) {
|
||||||
|
const nodes = [];
|
||||||
|
traverseGraphs(rootGraph, (graph) => {
|
||||||
|
if (Array.isArray(graph._nodes)) {
|
||||||
|
for (const node of graph._nodes) {
|
||||||
|
nodes.push({ graph, node });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getNodeReference(node) {
|
||||||
|
if (!node) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
node_id: node.id,
|
||||||
|
graph_id: getNodeGraphId(node),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getNodeKey(node) {
|
||||||
|
if (!node) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return `${getNodeGraphId(node)}:${node.id}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getLinkFromGraph(graph, linkId) {
|
||||||
|
if (!graph || graph.links == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isMapLike(graph.links)) {
|
||||||
|
return graph.links.get(linkId) || null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return graph.links[linkId] || null;
|
||||||
|
}
|
||||||
|
|
||||||
export function chainCallback(object, property, callback) {
|
export function chainCallback(object, property, callback) {
|
||||||
if (object == undefined) {
|
if (object == undefined) {
|
||||||
//This should not happen.
|
//This should not happen.
|
||||||
@@ -103,42 +217,56 @@ export const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)(?::([-\d\.]+))?>/g;
|
|||||||
// Get connected Lora Stacker nodes that feed into the current node
|
// Get connected Lora Stacker nodes that feed into the current node
|
||||||
export function getConnectedInputStackers(node) {
|
export function getConnectedInputStackers(node) {
|
||||||
const connectedStackers = [];
|
const connectedStackers = [];
|
||||||
|
|
||||||
if (node.inputs) {
|
if (!node?.inputs) {
|
||||||
for (const input of node.inputs) {
|
return connectedStackers;
|
||||||
if (input.name === "lora_stack" && input.link) {
|
}
|
||||||
const link = app.graph.links[input.link];
|
|
||||||
if (link) {
|
for (const input of node.inputs) {
|
||||||
const sourceNode = app.graph.getNodeById(link.origin_id);
|
if (input.name !== "lora_stack" || !input.link) {
|
||||||
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
|
continue;
|
||||||
connectedStackers.push(sourceNode);
|
}
|
||||||
}
|
|
||||||
}
|
const link = getLinkFromGraph(node.graph, input.link);
|
||||||
}
|
if (!link) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
|
||||||
|
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
|
||||||
|
connectedStackers.push(sourceNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return connectedStackers;
|
return connectedStackers;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connected TriggerWord Toggle nodes that receive output from the current node
|
// Get connected TriggerWord Toggle nodes that receive output from the current node
|
||||||
export function getConnectedTriggerToggleNodes(node) {
|
export function getConnectedTriggerToggleNodes(node) {
|
||||||
const connectedNodes = [];
|
const connectedNodes = [];
|
||||||
|
|
||||||
if (node.outputs && node.outputs.length > 0) {
|
if (!node?.outputs) {
|
||||||
for (const output of node.outputs) {
|
return connectedNodes;
|
||||||
if (output.links && output.links.length > 0) {
|
}
|
||||||
for (const linkId of output.links) {
|
|
||||||
const link = app.graph.links[linkId];
|
for (const output of node.outputs) {
|
||||||
if (link) {
|
if (!output?.links?.length) {
|
||||||
const targetNode = app.graph.getNodeById(link.target_id);
|
continue;
|
||||||
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
}
|
||||||
connectedNodes.push(targetNode.id);
|
|
||||||
}
|
for (const linkId of output.links) {
|
||||||
}
|
const link = getLinkFromGraph(node.graph, linkId);
|
||||||
}
|
if (!link) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const targetNode = node.graph?.getNodeById?.(link.target_id);
|
||||||
|
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
||||||
|
connectedNodes.push(targetNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return connectedNodes;
|
return connectedNodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,11 +289,15 @@ export function getActiveLorasFromNode(node) {
|
|||||||
// Recursively collect all active loras from a node and its input chain
|
// Recursively collect all active loras from a node and its input chain
|
||||||
export function collectActiveLorasFromChain(node, visited = new Set()) {
|
export function collectActiveLorasFromChain(node, visited = new Set()) {
|
||||||
// Prevent infinite loops from circular references
|
// Prevent infinite loops from circular references
|
||||||
if (visited.has(node.id)) {
|
const nodeKey = getNodeKey(node);
|
||||||
|
if (!nodeKey) {
|
||||||
return new Set();
|
return new Set();
|
||||||
}
|
}
|
||||||
visited.add(node.id);
|
if (visited.has(nodeKey)) {
|
||||||
|
return new Set();
|
||||||
|
}
|
||||||
|
visited.add(nodeKey);
|
||||||
|
|
||||||
// Get active loras from current node
|
// Get active loras from current node
|
||||||
const allActiveLoraNames = getActiveLorasFromNode(node);
|
const allActiveLoraNames = getActiveLorasFromNode(node);
|
||||||
|
|
||||||
@@ -181,14 +313,22 @@ export function collectActiveLorasFromChain(node, visited = new Set()) {
|
|||||||
|
|
||||||
// Update trigger words for connected toggle nodes
|
// Update trigger words for connected toggle nodes
|
||||||
export function updateConnectedTriggerWords(node, loraNames) {
|
export function updateConnectedTriggerWords(node, loraNames) {
|
||||||
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
|
const connectedNodes = getConnectedTriggerToggleNodes(node);
|
||||||
if (connectedNodeIds.length > 0) {
|
if (connectedNodes.length > 0) {
|
||||||
|
const nodeIds = connectedNodes
|
||||||
|
.map((connectedNode) => getNodeReference(connectedNode))
|
||||||
|
.filter((reference) => reference !== null);
|
||||||
|
|
||||||
|
if (nodeIds.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
fetch("/api/lm/loras/get_trigger_words", {
|
fetch("/api/lm/loras/get_trigger_words", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: { "Content-Type": "application/json" },
|
headers: { "Content-Type": "application/json" },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
lora_names: Array.from(loraNames),
|
lora_names: Array.from(loraNames),
|
||||||
node_ids: connectedNodeIds
|
node_ids: nodeIds
|
||||||
})
|
})
|
||||||
}).catch(err => console.error("Error fetching trigger words:", err));
|
}).catch(err => console.error("Error fetching trigger words:", err));
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user