Compare commits

...

20 Commits

Author SHA1 Message Date
Will Miao
76d3aa2b5b feat(version): bump version to 0.9.7 in pyproject.toml 2025-10-09 22:15:50 +08:00
Will Miao
c9a65c7347 feat(metadata): implement model data hydration and enhance metadata handling across services, fixes #547 2025-10-09 22:15:07 +08:00
Will Miao
f542ade628 feat(civitai): implement URL rewriting for Civitai previews and enhance download handling, fixes #499 2025-10-09 17:54:37 +08:00
Will Miao
d2c2bfbe6a feat(sidebar): add recursive search functionality and toggle button 2025-10-09 17:07:10 +08:00
Will Miao
2b6910bd55 feat(misc): mark model versions in library for Civitai user models 2025-10-09 15:23:42 +08:00
Will Miao
b1dd733493 feat(civitai): enhance model version handling with cache lookup 2025-10-09 14:10:00 +08:00
pixelpaws
5dcf0a1e48 Merge pull request #545 from willmiao/codex/evaluate-sqlite-cache-indexing-necessity
feat: index cached models by version id
2025-10-09 13:54:46 +08:00
pixelpaws
cf357b57fc feat(scanner): index cached models by version id 2025-10-09 13:50:44 +08:00
pixelpaws
4e1773833f Merge pull request #544 from willmiao/codex/add-endpoint-to-fetch-civitai-user-models
Add endpoint to fetch Civitai user models
2025-10-09 11:56:57 +08:00
pixelpaws
8cf762ffd3 feat(misc): add civitai user model lookup 2025-10-09 11:49:41 +08:00
pixelpaws
d997eaa429 Merge pull request #543 from willmiao/codex/refactor-get_model_version-logic-and-add-tests, fixes #540
fix: improve Civitai model version retrieval
2025-10-09 11:07:33 +08:00
pixelpaws
8e51f0f19f fix(civitai): improve model version retrieval 2025-10-09 10:56:25 +08:00
pixelpaws
f0e246b4ac Merge pull request #542 from willmiao/codex/investigate-backend-tests-modifying-settings.json
Refactor settings manager to lazy singleton
2025-10-08 16:02:11 +08:00
pixelpaws
a232997a79 fix(utils): respect metadata sync overrides 2025-10-08 15:52:15 +08:00
pixelpaws
08a449db99 fix(metadata): refresh metadata sync settings 2025-10-08 10:38:05 +08:00
pixelpaws
0c023c9888 fix(settings): lazily resolve module aliases 2025-10-08 10:10:23 +08:00
pixelpaws
0ad92d00b3 fix(settings): restore legacy settings aliases 2025-10-08 09:54:36 +08:00
pixelpaws
a726cbea1e fix(routes): pass resolved settings to metadata sync 2025-10-08 09:32:57 +08:00
pixelpaws
c53fa8692b refactor(settings): lazily initialize manager 2025-10-08 08:56:57 +08:00
Will Miao
3118f3b43c feat(graph): enhance node handling with graph identifiers and improve metadata updates, see #408, #538 2025-10-07 23:22:38 +08:00
74 changed files with 2944 additions and 568 deletions

View File

@@ -529,12 +529,15 @@
"title": "Embedding-Modelle"
},
"sidebar": {
"modelRoot": "Modell-Stammverzeichnis",
"modelRoot": "Stammverzeichnis",
"collapseAll": "Alle Ordner einklappen",
"pinSidebar": "Sidebar anheften",
"unpinSidebar": "Sidebar lösen",
"switchToListView": "Zur Listenansicht wechseln",
"switchToTreeView": "Zur Baumansicht wechseln",
"recursiveOn": "Unterordner durchsuchen",
"recursiveOff": "Nur aktuellen Ordner durchsuchen",
"recursiveUnavailable": "Rekursive Suche ist nur in der Baumansicht verfügbar",
"collapseAllDisabled": "Im Listenmodus nicht verfügbar"
},
"statistics": {

View File

@@ -529,12 +529,15 @@
"title": "Embedding Models"
},
"sidebar": {
"modelRoot": "Model Root",
"modelRoot": "Root",
"collapseAll": "Collapse All Folders",
"pinSidebar": "Pin Sidebar",
"unpinSidebar": "Unpin Sidebar",
"switchToListView": "Switch to List View",
"switchToTreeView": "Switch to Tree View",
"recursiveOn": "Search subfolders",
"recursiveOff": "Search current folder only",
"recursiveUnavailable": "Recursive search is available in tree view only",
"collapseAllDisabled": "Not available in list view"
},
"statistics": {

View File

@@ -529,12 +529,15 @@
"title": "Modelos embedding"
},
"sidebar": {
"modelRoot": "Raíz del modelo",
"modelRoot": "Raíz",
"collapseAll": "Colapsar todas las carpetas",
"pinSidebar": "Fijar barra lateral",
"unpinSidebar": "Desfijar barra lateral",
"switchToListView": "Cambiar a vista de lista",
"switchToTreeView": "Cambiar a vista de árbol",
"recursiveOn": "Buscar en subcarpetas",
"recursiveOff": "Buscar solo en la carpeta actual",
"recursiveUnavailable": "La búsqueda recursiva solo está disponible en la vista en árbol",
"collapseAllDisabled": "No disponible en vista de lista"
},
"statistics": {

View File

@@ -529,12 +529,15 @@
"title": "Modèles Embedding"
},
"sidebar": {
"modelRoot": "Racine du modèle",
"modelRoot": "Racine",
"collapseAll": "Réduire tous les dossiers",
"pinSidebar": "Épingler la barre latérale",
"unpinSidebar": "Désépingler la barre latérale",
"switchToListView": "Passer en vue liste",
"switchToTreeView": "Passer en vue arborescence",
"recursiveOn": "Rechercher dans les sous-dossiers",
"recursiveOff": "Rechercher uniquement dans le dossier actuel",
"recursiveUnavailable": "La recherche récursive n'est disponible qu'en vue arborescente",
"collapseAllDisabled": "Non disponible en vue liste"
},
"statistics": {

View File

@@ -529,12 +529,15 @@
"title": "מודלי Embedding"
},
"sidebar": {
"modelRoot": "שורש המודלים",
"modelRoot": "שורש",
"collapseAll": "כווץ את כל התיקיות",
"pinSidebar": "נעל סרגל צד",
"unpinSidebar": "שחרר סרגל צד",
"switchToListView": "עבור לתצוגת רשימה",
"switchToTreeView": "עבור לתצוגת עץ",
"switchToTreeView": "תצוגת עץ",
"recursiveOn": "חיפוש בתיקיות משנה",
"recursiveOff": "חיפוש רק בתיקייה הנוכחית",
"recursiveUnavailable": "חיפוש רקורסיבי זמין רק בתצוגת עץ",
"collapseAllDisabled": "לא זמין בתצוגת רשימה"
},
"statistics": {
@@ -1264,4 +1267,4 @@
"learnMore": "LM Civitai Extension Tutorial"
}
}
}
}

View File

@@ -529,12 +529,15 @@
"title": "Embeddingモデル"
},
"sidebar": {
"modelRoot": "モデルルート",
"modelRoot": "ルート",
"collapseAll": "すべてのフォルダを折りたたむ",
"pinSidebar": "サイドバーを固定",
"unpinSidebar": "サイドバーの固定を解除",
"switchToListView": "リストビューに切り替え",
"switchToTreeView": "ツリービューに切り替え",
"switchToTreeView": "ツリー表示に切り替え",
"recursiveOn": "サブフォルダーを検索",
"recursiveOff": "現在のフォルダーのみを検索",
"recursiveUnavailable": "再帰検索はツリービューでのみ利用できます",
"collapseAllDisabled": "リストビューでは利用できません"
},
"statistics": {
@@ -1264,4 +1267,4 @@
"learnMore": "LM Civitai Extension Tutorial"
}
}
}
}

View File

@@ -529,12 +529,15 @@
"title": "Embedding 모델"
},
"sidebar": {
"modelRoot": "모델 루트",
"modelRoot": "루트",
"collapseAll": "모든 폴더 접기",
"pinSidebar": "사이드바 고정",
"unpinSidebar": "사이드바 고정 해제",
"switchToListView": "목록 보기로 전환",
"switchToTreeView": "트리 보기로 전환",
"recursiveOn": "하위 폴더 검색",
"recursiveOff": "현재 폴더만 검색",
"recursiveUnavailable": "재귀 검색은 트리 보기에서만 사용할 수 있습니다",
"collapseAllDisabled": "목록 보기에서는 사용할 수 없습니다"
},
"statistics": {
@@ -1264,4 +1267,4 @@
"learnMore": "LM Civitai Extension Tutorial"
}
}
}
}

View File

@@ -529,12 +529,15 @@
"title": "Модели Embedding"
},
"sidebar": {
"modelRoot": "Корень моделей",
"modelRoot": "Корень",
"collapseAll": "Свернуть все папки",
"pinSidebar": "Закрепить боковую панель",
"unpinSidebar": "Открепить боковую панель",
"switchToListView": "Переключить на вид списка",
"switchToTreeView": "Переключить на древовидный вид",
"recursiveOn": "Искать во вложенных папках",
"recursiveOff": "Искать только в текущей папке",
"recursiveUnavailable": "Рекурсивный поиск доступен только в режиме дерева",
"collapseAllDisabled": "Недоступно в виде списка"
},
"statistics": {
@@ -1264,4 +1267,4 @@
"learnMore": "LM Civitai Extension Tutorial"
}
}
}
}

View File

@@ -535,12 +535,15 @@
"title": "Embedding 模型"
},
"sidebar": {
"modelRoot": "模型根目录",
"modelRoot": "根目录",
"collapseAll": "折叠所有文件夹",
"pinSidebar": "固定侧边栏",
"unpinSidebar": "取消固定侧边栏",
"switchToListView": "切换到列表视图",
"switchToTreeView": "切换到树状视图",
"recursiveOn": "搜索子文件夹",
"recursiveOff": "仅搜索当前文件夹",
"recursiveUnavailable": "仅在树形视图中可使用递归搜索",
"collapseAllDisabled": "列表视图下不可用"
},
"statistics": {
@@ -1270,4 +1273,4 @@
"learnMore": "LM Civitai Extension Tutorial"
}
}
}
}

View File

@@ -529,12 +529,15 @@
"title": "Embedding 模型"
},
"sidebar": {
"modelRoot": "模型根目錄",
"modelRoot": "根目錄",
"collapseAll": "全部摺疊資料夾",
"pinSidebar": "固定側邊欄",
"unpinSidebar": "取消固定側邊欄",
"switchToListView": "切換至列表檢視",
"switchToTreeView": "切換樹狀檢視",
"switchToTreeView": "切換樹狀檢視",
"recursiveOn": "搜尋子資料夾",
"recursiveOff": "僅搜尋目前資料夾",
"recursiveUnavailable": "遞迴搜尋僅能在樹狀檢視中使用",
"collapseAllDisabled": "列表檢視下不可用"
},
"statistics": {
@@ -1264,4 +1267,4 @@
"learnMore": "LM Civitai Extension Tutorial"
}
}
}
}

View File

@@ -74,8 +74,9 @@ class Config:
"""Persist ComfyUI-derived folder paths to the multi-library settings."""
try:
ensure_settings_file(logger)
from .services.settings_manager import settings as settings_service
from .services.settings_manager import get_settings_manager
settings_service = get_settings_manager()
libraries = settings_service.get_libraries()
comfy_library = libraries.get("comfyui", {})
default_library = libraries.get("default", {})
@@ -442,8 +443,9 @@ class Config:
"""Return the current library registry and active library name."""
try:
from .services.settings_manager import settings as settings_service
from .services.settings_manager import get_settings_manager
settings_service = get_settings_manager()
libraries = settings_service.get_libraries()
active_library = settings_service.get_active_library_name()
return {

View File

@@ -13,7 +13,7 @@ from .routes.misc_routes import MiscRoutes
from .routes.preview_routes import PreviewRoutes
from .routes.example_images_routes import ExampleImagesRoutes
from .services.service_registry import ServiceRegistry
from .services.settings_manager import settings
from .services.settings_manager import get_settings_manager
from .utils.example_images_migration import ExampleImagesMigration
from .services.websocket_manager import ws_manager
from .services.example_images_cleanup_service import ExampleImagesCleanupService
@@ -23,6 +23,25 @@ logger = logging.getLogger(__name__)
# Check if we're in standalone mode
STANDALONE_MODE = 'nodes' not in sys.modules
class _SettingsProxy:
def __init__(self):
self._manager = None
def _resolve(self):
if self._manager is None:
self._manager = get_settings_manager()
return self._manager
def get(self, *args, **kwargs):
return self._resolve().get(*args, **kwargs)
def __getattr__(self, item):
return getattr(self._resolve(), item)
settings = _SettingsProxy()
class LoraManager:
"""Main entry point for LoRA Manager plugin"""

View File

@@ -17,7 +17,7 @@ from ..services.model_lifecycle_service import ModelLifecycleService
from ..services.preview_asset_service import PreviewAssetService
from ..services.server_i18n import server_i18n as default_server_i18n
from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings as default_settings
from ..services.settings_manager import get_settings_manager
from ..services.tag_update_service import TagUpdateService
from ..services.websocket_manager import ws_manager as default_ws_manager
from ..services.use_cases import (
@@ -56,14 +56,14 @@ class BaseModelRoutes(ABC):
self,
service=None,
*,
settings_service=default_settings,
settings_service=None,
ws_manager=default_ws_manager,
server_i18n=default_server_i18n,
metadata_provider_factory=get_default_metadata_provider,
) -> None:
self.service = None
self.model_type = ""
self._settings = settings_service
self._settings = settings_service or get_settings_manager()
self._ws_manager = ws_manager
self._server_i18n = server_i18n
self._metadata_provider_factory = metadata_provider_factory
@@ -90,7 +90,7 @@ class BaseModelRoutes(ABC):
self._metadata_sync_service = MetadataSyncService(
metadata_manager=MetadataManager,
preview_service=self._preview_service,
settings=settings_service,
settings=self._settings,
default_metadata_provider_factory=metadata_provider_factory,
metadata_provider_selector=get_metadata_provider,
)

View File

@@ -18,7 +18,7 @@ from ..services.recipes import (
)
from ..services.server_i18n import server_i18n
from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from .handlers.recipe_handlers import (
@@ -48,7 +48,7 @@ class BaseRecipeRoutes:
self.recipe_scanner = None
self.lora_scanner = None
self.civitai_client = None
self.settings = settings
self.settings = get_settings_manager()
self.server_i18n = server_i18n
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),

View File

@@ -24,10 +24,17 @@ from ...services.metadata_service import (
update_metadata_providers,
)
from ...services.service_registry import ServiceRegistry
from ...services.settings_manager import settings as default_settings
from ...services.settings_manager import get_settings_manager
from ...services.websocket_manager import ws_manager
from ...services.downloader import get_downloader
from ...utils.constants import DEFAULT_NODE_COLOR, NODE_TYPES, SUPPORTED_MEDIA_EXTENSIONS
from ...utils.constants import (
CIVITAI_USER_MODEL_TYPES,
DEFAULT_NODE_COLOR,
NODE_TYPES,
SUPPORTED_MEDIA_EXTENSIONS,
VALID_LORA_TYPES,
)
from ...utils.civitai_utils import rewrite_preview_url
from ...utils.example_images_paths import is_valid_example_images_root
from ...utils.lora_metadata import extract_trained_words
from ...utils.usage_stats import UsageStats
@@ -80,7 +87,7 @@ class NodeRegistry:
def __init__(self) -> None:
self._lock = asyncio.Lock()
self._nodes: Dict[int, dict] = {}
self._nodes: Dict[str, dict] = {}
self._registry_updated = asyncio.Event()
async def register_nodes(self, nodes: list[dict]) -> None:
@@ -88,11 +95,16 @@ class NodeRegistry:
self._nodes.clear()
for node in nodes:
node_id = node["node_id"]
graph_id = str(node["graph_id"])
unique_id = f"{graph_id}:{node_id}"
node_type = node.get("type", "")
type_id = NODE_TYPES.get(node_type, 0)
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
self._nodes[node_id] = {
self._nodes[unique_id] = {
"id": node_id,
"graph_id": graph_id,
"graph_name": node.get("graph_name"),
"unique_id": unique_id,
"bgcolor": bgcolor,
"title": node.get("title"),
"type": type_id,
@@ -157,11 +169,11 @@ class SettingsHandler:
def __init__(
self,
*,
settings_service=default_settings,
settings_service=None,
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
downloader_factory: Callable[[], Awaitable[DownloaderProtocol]] = get_downloader,
) -> None:
self._settings = settings_service
self._settings = settings_service or get_settings_manager()
self._metadata_provider_updater = metadata_provider_updater
self._downloader_factory = downloader_factory
@@ -330,16 +342,65 @@ class LoraCodeHandler:
logger.error("Error broadcasting lora code: %s", exc)
results.append({"node_id": "broadcast", "success": False, "error": str(exc)})
else:
for node_id in node_ids:
for entry in node_ids:
node_identifier = entry
graph_identifier = None
if isinstance(entry, dict):
node_identifier = entry.get("node_id")
graph_identifier = entry.get("graph_id")
if node_identifier is None:
results.append(
{
"node_id": node_identifier,
"graph_id": graph_identifier,
"success": False,
"error": "Missing node_id parameter",
}
)
continue
try:
parsed_node_id = int(node_identifier)
except (TypeError, ValueError):
parsed_node_id = node_identifier
payload = {
"id": parsed_node_id,
"lora_code": lora_code,
"mode": mode,
}
if graph_identifier is not None:
payload["graph_id"] = str(graph_identifier)
try:
self._prompt_server.instance.send_sync(
"lora_code_update",
{"id": node_id, "lora_code": lora_code, "mode": mode},
payload,
)
results.append(
{
"node_id": parsed_node_id,
"graph_id": payload.get("graph_id"),
"success": True,
}
)
results.append({"node_id": node_id, "success": True})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error sending lora code to node %s: %s", node_id, exc)
results.append({"node_id": node_id, "success": False, "error": str(exc)})
logger.error(
"Error sending lora code to node %s (graph %s): %s",
parsed_node_id,
graph_identifier,
exc,
)
results.append(
{
"node_id": parsed_node_id,
"graph_id": payload.get("graph_id"),
"success": False,
"error": str(exc),
}
)
return web.json_response({"success": True, "results": results})
except Exception as exc: # pragma: no cover - defensive logging
@@ -557,17 +618,118 @@ class ModelLibraryHandler:
logger.error("Failed to get model versions status: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_civitai_user_models(self, request: web.Request) -> web.Response:
try:
username = request.query.get("username")
if not username:
return web.json_response({"success": False, "error": "Missing required parameter: username"}, status=400)
metadata_provider = await self._metadata_provider_factory()
if not metadata_provider:
return web.json_response({"success": False, "error": "Metadata provider not available"}, status=503)
try:
models = await metadata_provider.get_user_models(username)
except NotImplementedError:
return web.json_response({"success": False, "error": "Metadata provider does not support user model queries"}, status=501)
if models is None:
return web.json_response({"success": False, "error": "Failed to fetch user models"}, status=502)
if not isinstance(models, list):
models = []
lora_scanner = await self._service_registry.get_lora_scanner()
checkpoint_scanner = await self._service_registry.get_checkpoint_scanner()
embedding_scanner = await self._service_registry.get_embedding_scanner()
normalized_allowed_types = {model_type.lower() for model_type in CIVITAI_USER_MODEL_TYPES}
lora_type_aliases = {model_type.lower() for model_type in VALID_LORA_TYPES}
type_scanner_map: Dict[str, object | None] = {
**{alias: lora_scanner for alias in lora_type_aliases},
"checkpoint": checkpoint_scanner,
"textualinversion": embedding_scanner,
}
versions: list[dict] = []
for model in models:
if not isinstance(model, dict):
continue
model_type = str(model.get("type", "")).lower()
if model_type not in normalized_allowed_types:
continue
scanner = type_scanner_map.get(model_type)
if scanner is None:
return web.json_response({"success": False, "error": f'Scanner for type "{model_type}" is not available'}, status=503)
tags_value = model.get("tags")
tags = tags_value if isinstance(tags_value, list) else []
model_id = model.get("id")
try:
model_id_int = int(model_id)
except (TypeError, ValueError):
continue
model_name = model.get("name", "")
versions_data = model.get("modelVersions")
if not isinstance(versions_data, list):
continue
for version in versions_data:
if not isinstance(version, dict):
continue
version_id = version.get("id")
try:
version_id_int = int(version_id)
except (TypeError, ValueError):
continue
images = version.get("images") or []
thumbnail_url = None
if images and isinstance(images, list):
first_image = images[0]
if isinstance(first_image, dict):
raw_url = first_image.get("url")
media_type = first_image.get("type")
rewritten_url, _ = rewrite_preview_url(raw_url, media_type)
thumbnail_url = rewritten_url
in_library = await scanner.check_model_version_exists(version_id_int)
versions.append(
{
"modelId": model_id_int,
"versionId": version_id_int,
"modelName": model_name,
"versionName": version.get("name", ""),
"type": model.get("type"),
"tags": tags,
"baseModel": version.get("baseModel"),
"thumbnailUrl": thumbnail_url,
"inLibrary": in_library,
}
)
return web.json_response({"success": True, "username": username, "versions": versions})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to get Civitai user models: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class MetadataArchiveHandler:
def __init__(
self,
*,
metadata_archive_manager_factory: Callable[[], Awaitable[MetadataArchiveManagerProtocol]] = get_metadata_archive_manager,
settings_service=default_settings,
settings_service=None,
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
) -> None:
self._metadata_archive_manager_factory = metadata_archive_manager_factory
self._settings = settings_service
self._settings = settings_service or get_settings_manager()
self._metadata_provider_updater = metadata_provider_updater
async def download_metadata_archive(self, request: web.Request) -> web.Response:
@@ -679,10 +841,21 @@ class NodeRegistryHandler:
node_id = node.get("node_id")
if node_id is None:
return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400)
graph_id = node.get("graph_id")
if graph_id is None:
return web.json_response({"success": False, "error": f"Node {index} missing graph_id parameter"}, status=400)
graph_name = node.get("graph_name")
try:
node["node_id"] = int(node_id)
except (TypeError, ValueError):
return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400)
node["graph_id"] = str(graph_id)
if graph_name is None:
node["graph_name"] = None
elif isinstance(graph_name, str):
node["graph_name"] = graph_name
else:
node["graph_name"] = str(graph_name)
await self._node_registry.register_nodes(nodes)
return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"})
@@ -779,6 +952,7 @@ class MiscHandlerSet:
"register_nodes": self.node_registry.register_nodes,
"get_registry": self.node_registry.get_registry,
"check_model_exists": self.model_library.check_model_exists,
"get_civitai_user_models": self.model_library.get_civitai_user_models,
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,

View File

@@ -30,6 +30,7 @@ from ...services.use_cases import (
from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...utils.file_utils import calculate_sha256
from ...utils.metadata_manager import MetadataManager
class ModelPageView:
@@ -244,6 +245,8 @@ class ModelManagementHandler:
if not model_data.get("sha256"):
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
await MetadataManager.hydrate_model_data(model_data)
success, error = await self._metadata_sync.fetch_and_update_model(
sha256=model_data["sha256"],
file_path=file_path,
@@ -825,18 +828,30 @@ class ModelCivitaiHandler:
status=400,
)
cache = await self._service.scanner.get_cached_data()
version_index = cache.version_index
for version in versions:
model_file = self._find_model_file(version.get("files", [])) if isinstance(version.get("files"), Iterable) else None
if model_file:
hashes = model_file.get("hashes", {}) if isinstance(model_file, Mapping) else {}
sha256 = hashes.get("SHA256") if isinstance(hashes, Mapping) else None
if sha256:
version["existsLocally"] = self._service.has_hash(sha256)
if version["existsLocally"]:
version["localPath"] = self._service.get_path_by_hash(sha256)
version["modelSizeKB"] = model_file.get("sizeKB") if isinstance(model_file, Mapping) else None
version_id = None
version_id_raw = version.get("id")
if version_id_raw is not None:
try:
version_id = int(str(version_id_raw))
except (TypeError, ValueError):
version_id = None
cache_entry = version_index.get(version_id) if (version_id is not None and version_index) else None
version["existsLocally"] = cache_entry is not None
if cache_entry and isinstance(cache_entry, Mapping):
local_path = cache_entry.get("file_path")
if local_path:
version["localPath"] = local_path
else:
version["existsLocally"] = False
version.pop("localPath", None)
model_file = self._find_model_file(version.get("files", [])) if isinstance(version.get("files"), Iterable) else None
if model_file and isinstance(model_file, Mapping):
version["modelSizeKB"] = model_file.get("sizeKB")
return web.json_response(versions)
except Exception as exc:
self._logger.error("Error fetching %s model versions: %s", self._service.model_type, exc)

View File

@@ -229,11 +229,27 @@ class LoraRoutes(BaseModelRoutes):
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Send update to all connected trigger word toggle nodes
for node_id in node_ids:
PromptServer.instance.send_sync("trigger_word_update", {
"id": node_id,
for entry in node_ids:
node_identifier = entry
graph_identifier = None
if isinstance(entry, dict):
node_identifier = entry.get("node_id")
graph_identifier = entry.get("graph_id")
try:
parsed_node_id = int(node_identifier)
except (TypeError, ValueError):
parsed_node_id = node_identifier
payload = {
"id": parsed_node_id,
"message": trigger_words_text
})
}
if graph_identifier is not None:
payload["graph_id"] = str(graph_identifier)
PromptServer.instance.send_sync("trigger_word_update", payload)
return web.json_response({"success": True})

View File

@@ -34,6 +34,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),

View File

@@ -14,7 +14,7 @@ from ..services.metadata_service import (
get_metadata_provider,
update_metadata_providers,
)
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
from ..services.downloader import get_downloader
from ..utils.usage_stats import UsageStats
from .handlers.misc_handlers import (
@@ -47,7 +47,7 @@ class MiscRoutes:
def __init__(
self,
*,
settings_service=settings,
settings_service=None,
usage_stats_factory: Callable[[], UsageStats] = UsageStats,
prompt_server: type[PromptServer] = PromptServer,
service_registry_adapter=build_service_registry_adapter(),
@@ -60,7 +60,7 @@ class MiscRoutes:
node_registry: NodeRegistry | None = None,
standalone_mode_flag: bool = standalone_mode,
) -> None:
self._settings = settings_service
self._settings = settings_service or get_settings_manager()
self._usage_stats_factory = usage_stats_factory
self._prompt_server = prompt_server
self._service_registry_adapter = service_registry_adapter

View File

@@ -8,13 +8,32 @@ from collections import defaultdict, Counter
from typing import Dict, List, Any
from ..config import config
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
from ..services.server_i18n import server_i18n
from ..services.service_registry import ServiceRegistry
from ..utils.usage_stats import UsageStats
logger = logging.getLogger(__name__)
class _SettingsProxy:
def __init__(self):
self._manager = None
def _resolve(self):
if self._manager is None:
self._manager = get_settings_manager()
return self._manager
def get(self, *args, **kwargs):
return self._resolve().get(*args, **kwargs)
def __getattr__(self, item):
return getattr(self._resolve(), item)
settings = _SettingsProxy()
class StatsRoutes:
"""Route handlers for Statistics page and API endpoints"""
@@ -66,7 +85,9 @@ class StatsRoutes:
is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing
# 获取用户语言设置
user_language = settings.get('language', 'en')
settings_object = settings
user_language = settings_object.get('language', 'en')
settings_manager = settings_object if not isinstance(settings_object, _SettingsProxy) else settings_object._resolve()
# 设置服务端i18n语言
server_i18n.set_locale(user_language)
@@ -79,7 +100,7 @@ class StatsRoutes:
template = self.template_env.get_template('statistics.html')
rendered = template.render(
is_initializing=is_initializing,
settings=settings,
settings=settings_manager,
request=request,
t=server_i18n.get_translation,
)

View File

@@ -6,7 +6,7 @@ import os
from ..utils.models import BaseModelMetadata
from ..utils.metadata_manager import MetadataManager
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
from .settings_manager import settings as default_settings
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
@@ -38,7 +38,7 @@ class BaseModelService(ABC):
self.model_type = model_type
self.scanner = scanner
self.metadata_class = metadata_class
self.settings = settings_provider or default_settings
self.settings = settings_provider or get_settings_manager()
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
self.filter_set = filter_set or ModelFilterSet(self.settings)
self.search_strategy = search_strategy or SearchStrategy()

View File

@@ -1,7 +1,7 @@
import os
import asyncio
import copy
import logging
import asyncio
import os
from typing import Optional, Dict, Tuple, List
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
from .downloader import get_downloader
@@ -157,141 +157,160 @@ class CivitaiClient:
return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata
Args:
model_id: The Civitai model ID (optional if version_id is provided)
version_id: Optional specific version ID to retrieve
Returns:
Optional[Dict]: The model version data with additional fields or None if not found
"""
"""Get specific model version with additional metadata."""
try:
downloader = await get_downloader()
# Case 1: Only version_id is provided
if model_id is None and version_id is not None:
# First get the version info to extract model_id
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/{version_id}",
use_auth=True
)
if not success:
return None
model_id = version.get('modelId')
if not model_id:
logger.error(f"No modelId found in version {version_id}")
return None
# Now get the model data for additional metadata
success, model_data = await downloader.make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
)
if success:
# Enrich version with model data
version['model']['description'] = model_data.get("description")
version['model']['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
return await self._get_version_by_id_only(downloader, version_id)
self._remove_comfy_metadata(version)
return version
# Case 2: model_id is provided (with or without version_id)
elif model_id is not None:
# Step 1: Get model data to find version_id if not provided and get additional metadata
success, data = await downloader.make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
)
if not success:
return None
if model_id is not None:
return await self._get_version_with_model_id(downloader, model_id, version_id)
model_versions = data.get('modelVersions', [])
if not model_versions:
logger.warning(f"No model versions found for model {model_id}")
return None
logger.error("Either model_id or version_id must be provided")
return None
# Step 2: Determine the target version entry to use
target_version = None
if version_id is not None:
target_version = next(
(item for item in model_versions if item.get('id') == version_id),
None
)
if target_version is None:
logger.warning(
f"Version {version_id} not found for model {model_id}, defaulting to first version"
)
if target_version is None:
target_version = model_versions[0]
target_version_id = target_version.get('id')
# Step 3: Get detailed version info using the SHA256 hash
model_hash = None
for file_info in target_version.get('files', []):
if file_info.get('type') == 'Model' and file_info.get('primary'):
model_hash = file_info.get('hashes', {}).get('SHA256')
if model_hash:
break
version = None
if model_hash:
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
)
if not success:
logger.warning(
f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}"
)
version = None
else:
logger.warning(
f"No primary model hash found for model {model_id} version {target_version_id}"
)
if version is None:
version = copy.deepcopy(target_version)
version.pop('index', None)
version['modelId'] = model_id
version['model'] = {
'name': data.get('name'),
'type': data.get('type'),
'nsfw': data.get('nsfw'),
'poi': data.get('poi')
}
# Step 4: Enrich version_info with model data
# Add description and tags from model data
model_info = version.get('model')
if not isinstance(model_info, dict):
model_info = {}
version['model'] = model_info
model_info['description'] = data.get("description")
model_info['tags'] = data.get("tags", [])
# Add creator from model data
version['creator'] = data.get("creator")
self._remove_comfy_metadata(version)
return version
# Case 3: Neither model_id nor version_id provided
else:
logger.error("Either model_id or version_id must be provided")
return None
except Exception as e:
logger.error(f"Error fetching model version: {e}")
return None
async def _get_version_by_id_only(self, downloader, version_id: int) -> Optional[Dict]:
version = await self._fetch_version_by_id(downloader, version_id)
if version is None:
return None
model_id = version.get('modelId')
if not model_id:
logger.error(f"No modelId found in version {version_id}")
return None
model_data = await self._fetch_model_data(downloader, model_id)
if model_data:
self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version)
return version
async def _get_version_with_model_id(self, downloader, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
model_data = await self._fetch_model_data(downloader, model_id)
if not model_data:
return None
target_version = self._select_target_version(model_data, model_id, version_id)
if target_version is None:
return None
target_version_id = target_version.get('id')
version = await self._fetch_version_by_id(downloader, target_version_id) if target_version_id else None
if version is None:
model_hash = self._extract_primary_model_hash(target_version)
if model_hash:
version = await self._fetch_version_by_hash(downloader, model_hash)
else:
logger.warning(
f"No primary model hash found for model {model_id} version {target_version_id}"
)
if version is None:
version = self._build_version_from_model_data(target_version, model_id, model_data)
self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version)
return version
async def _fetch_model_data(self, downloader, model_id: int) -> Optional[Dict]:
success, data = await downloader.make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
)
if success:
return data
logger.warning(f"Failed to fetch model data for model {model_id}")
return None
async def _fetch_version_by_id(self, downloader, version_id: Optional[int]) -> Optional[Dict]:
if version_id is None:
return None
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/{version_id}",
use_auth=True
)
if success:
return version
logger.warning(f"Failed to fetch version by id {version_id}")
return None
async def _fetch_version_by_hash(self, downloader, model_hash: Optional[str]) -> Optional[Dict]:
if not model_hash:
return None
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
)
if success:
return version
logger.warning(f"Failed to fetch version by hash {model_hash}")
return None
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
model_versions = model_data.get('modelVersions', [])
if not model_versions:
logger.warning(f"No model versions found for model {model_id}")
return None
if version_id is not None:
target_version = next(
(item for item in model_versions if item.get('id') == version_id),
None
)
if target_version is None:
logger.warning(
f"Version {version_id} not found for model {model_id}, defaulting to first version"
)
return model_versions[0]
return target_version
return model_versions[0]
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
for file_info in version_entry.get('files', []):
if file_info.get('type') == 'Model' and file_info.get('primary'):
hashes = file_info.get('hashes', {})
model_hash = hashes.get('SHA256')
if model_hash:
return model_hash
return None
def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict:
version = copy.deepcopy(version_entry)
version.pop('index', None)
version['modelId'] = model_id
version['model'] = {
'name': model_data.get('name'),
'type': model_data.get('type'),
'nsfw': model_data.get('nsfw'),
'poi': model_data.get('poi')
}
return version
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
model_info = version.get('model')
if not isinstance(model_info, dict):
model_info = {}
version['model'] = model_info
model_info['description'] = model_data.get("description")
model_info['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai
@@ -335,7 +354,7 @@ class CivitaiClient:
async def get_image_info(self, image_id: str) -> Optional[Dict]:
"""Fetch image information from Civitai API
Args:
image_id: The Civitai image ID
@@ -366,3 +385,37 @@ class CivitaiClient:
error_msg = f"Error fetching image info: {e}"
logger.error(error_msg)
return None
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
    """Fetch all models published by a specific Civitai user.

    Returns the list of model dicts (with ComfyUI metadata stripped from each
    listed version), an empty list when the response carries no items, or
    None when the username is empty or the request fails.
    """
    if not username:
        return None
    try:
        downloader = await get_downloader()
        success, payload = await downloader.make_request(
            'GET',
            f"{self.base_url}/models?username={username}",
            use_auth=True
        )
        if not success:
            logger.error("Failed to fetch models for %s: %s", username, payload)
            return None
        models = payload.get("items") if isinstance(payload, dict) else None
        if not isinstance(models, list):
            return []
        for model in models:
            model_versions = model.get("modelVersions")
            if isinstance(model_versions, list):
                for model_version in model_versions:
                    self._remove_comfy_metadata(model_version)
        return models
    except Exception as exc:  # pragma: no cover - defensive logging
        logger.error("Error fetching models for %s: %s", username, exc)
        return None

View File

@@ -4,12 +4,14 @@ import asyncio
from collections import OrderedDict
import uuid
from typing import Dict, List
from urllib.parse import urlparse
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
from ..utils.civitai_utils import rewrite_preview_url
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry
from .settings_manager import settings
from .settings_manager import get_settings_manager
from .metadata_service import get_default_metadata_provider
from .downloader import get_downloader
@@ -241,23 +243,24 @@ class DownloadManager:
# Handle use_default_paths
if use_default_paths:
settings_manager = get_settings_manager()
# Set save_dir based on model type
if model_type == 'checkpoint':
default_path = settings.get('default_checkpoint_root')
default_path = settings_manager.get('default_checkpoint_root')
if not default_path:
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
save_dir = default_path
elif model_type == 'lora':
default_path = settings.get('default_lora_root')
default_path = settings_manager.get('default_lora_root')
if not default_path:
return {'success': False, 'error': 'Default lora root path not set in settings'}
save_dir = default_path
elif model_type == 'embedding':
default_path = settings.get('default_embedding_root')
default_path = settings_manager.get('default_embedding_root')
if not default_path:
return {'success': False, 'error': 'Default embedding root path not set in settings'}
save_dir = default_path
# Calculate relative path using template
relative_path = self._calculate_relative_path(version_info, model_type)
@@ -360,7 +363,8 @@ class DownloadManager:
Relative path string
"""
# Get path template from settings for specific model type
path_template = settings.get_download_path_template(model_type)
settings_manager = get_settings_manager()
path_template = settings_manager.get_download_path_template(model_type)
# If template is empty, return empty path (flat structure)
if not path_template:
@@ -377,7 +381,7 @@ class DownloadManager:
author = 'Anonymous'
# Apply mapping if available
base_model_mappings = settings.get('base_model_path_mappings', {})
base_model_mappings = settings_manager.get('base_model_path_mappings', {})
mapped_base_model = base_model_mappings.get(base_model, base_model)
# Get model tags
@@ -448,70 +452,103 @@ class DownloadManager:
# Download preview image if available
images = version_info.get('images', [])
if images:
# Report preview download progress
if progress_callback:
await progress_callback(1) # 1% progress for starting preview download
# Check if it's a video or an image
is_video = images[0].get('type') == 'video'
if (is_video):
# For videos, use .mp4 extension
preview_ext = '.mp4'
preview_path = os.path.splitext(save_path)[0] + preview_ext
# Download video directly using downloader
downloader = await get_downloader()
success, result = await downloader.download_file(
images[0]['url'],
preview_path,
use_auth=False # Preview images typically don't need auth
)
if success:
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
else:
# For images, use WebP format for better performance
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
temp_path = temp_file.name
# Download the original image to temp path using downloader
downloader = await get_downloader()
success, content, headers = await downloader.download_to_memory(
images[0]['url'],
use_auth=False
)
if success:
# Save to temp file
with open(temp_path, 'wb') as f:
f.write(content)
# Optimize and convert to WebP
preview_path = os.path.splitext(save_path)[0] + '.webp'
# Use ExifUtils to optimize and convert the image
optimized_data, _ = ExifUtils.optimize_image(
image_data=temp_path,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=False
)
# Save the optimized image
with open(preview_path, 'wb') as f:
f.write(optimized_data)
# Update metadata
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
# Remove temporary file
try:
os.unlink(temp_path)
except Exception as e:
logger.warning(f"Failed to delete temp file: {e}")
first_image = images[0] if isinstance(images[0], dict) else None
preview_url = first_image.get('url') if first_image else None
media_type = (first_image.get('type') or '').lower() if first_image else ''
nsfw_level = first_image.get('nsfwLevel', 0) if first_image else 0
def _extension_from_url(url: str, fallback: str) -> str:
try:
parsed = urlparse(url)
except ValueError:
return fallback
ext = os.path.splitext(parsed.path)[1]
return ext or fallback
preview_downloaded = False
preview_path = None
if preview_url:
downloader = await get_downloader()
if media_type == 'video':
preview_ext = _extension_from_url(preview_url, '.mp4')
preview_path = os.path.splitext(save_path)[0] + preview_ext
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='video')
attempt_urls: List[str] = []
if rewritten:
attempt_urls.append(rewritten_url)
attempt_urls.append(preview_url)
seen_attempts = set()
for attempt in attempt_urls:
if not attempt or attempt in seen_attempts:
continue
seen_attempts.add(attempt)
success, _ = await downloader.download_file(
attempt,
preview_path,
use_auth=False
)
if success:
preview_downloaded = True
break
else:
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='image')
if rewritten:
preview_ext = _extension_from_url(preview_url, '.png')
preview_path = os.path.splitext(save_path)[0] + preview_ext
success, _ = await downloader.download_file(
rewritten_url,
preview_path,
use_auth=False
)
if success:
preview_downloaded = True
if not preview_downloaded:
temp_path: str | None = None
try:
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
temp_path = temp_file.name
success, content, _ = await downloader.download_to_memory(
preview_url,
use_auth=False
)
if success:
with open(temp_path, 'wb') as temp_file_handle:
temp_file_handle.write(content)
preview_path = os.path.splitext(save_path)[0] + '.webp'
optimized_data, _ = ExifUtils.optimize_image(
image_data=temp_path,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=False
)
with open(preview_path, 'wb') as preview_file:
preview_file.write(optimized_data)
preview_downloaded = True
finally:
if temp_path and os.path.exists(temp_path):
try:
os.unlink(temp_path)
except Exception as e:
logger.warning(f"Failed to delete temp file: {e}")
if preview_downloaded and preview_path:
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = nsfw_level
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['preview_path'] = preview_path
# Report preview download completion
if progress_callback:
await progress_callback(3) # 3% progress after preview download
@@ -675,7 +712,15 @@ class DownloadManager:
except Exception as e:
logger.error(f"Error deleting metadata file: {e}")
# Delete preview file if exists (.webp or .mp4)
preview_path_value = download_info.get('preview_path')
if preview_path_value and os.path.exists(preview_path_value):
try:
os.unlink(preview_path_value)
logger.debug(f"Deleted preview file: {preview_path_value}")
except Exception as e:
logger.error(f"Error deleting preview file: {e}")
# Delete preview file if exists (.webp or .mp4) for legacy paths
for preview_ext in ['.webp', '.mp4']:
preview_path = os.path.splitext(file_path)[0] + preview_ext
if os.path.exists(preview_path):
@@ -708,4 +753,4 @@ class DownloadManager:
}
for task_id, info in self._active_downloads.items()
]
}
}

View File

@@ -16,7 +16,7 @@ import asyncio
import aiohttp
from datetime import datetime
from typing import Optional, Dict, Tuple, Callable, Union
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
@@ -94,12 +94,13 @@ class Downloader:
# Check for app-level proxy settings
proxy_url = None
if settings.get('proxy_enabled', False):
proxy_host = settings.get('proxy_host', '').strip()
proxy_port = settings.get('proxy_port', '').strip()
proxy_type = settings.get('proxy_type', 'http').lower()
proxy_username = settings.get('proxy_username', '').strip()
proxy_password = settings.get('proxy_password', '').strip()
settings_manager = get_settings_manager()
if settings_manager.get('proxy_enabled', False):
proxy_host = settings_manager.get('proxy_host', '').strip()
proxy_port = settings_manager.get('proxy_port', '').strip()
proxy_type = settings_manager.get('proxy_type', 'http').lower()
proxy_username = settings_manager.get('proxy_username', '').strip()
proxy_password = settings_manager.get('proxy_password', '').strip()
if proxy_host and proxy_port:
# Build proxy URL
@@ -146,7 +147,8 @@ class Downloader:
if use_auth:
# Add CivitAI API key if available
api_key = settings.get('civitai_api_key')
settings_manager = get_settings_manager()
api_key = settings_manager.get('civitai_api_key')
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
headers['Content-Type'] = 'application/json'

View File

@@ -11,7 +11,7 @@ from pathlib import Path
from typing import Dict, List, Tuple
from .service_registry import ServiceRegistry
from .settings_manager import settings
from .settings_manager import get_settings_manager
from ..utils.example_images_paths import iter_library_roots
@@ -62,7 +62,8 @@ class ExampleImagesCleanupService:
async def cleanup_example_image_folders(self) -> Dict[str, object]:
"""Clean empty or orphaned example image folders by moving them under a deleted bucket."""
example_images_path = settings.get("example_images_path")
settings_manager = get_settings_manager()
example_images_path = settings_manager.get("example_images_path")
if not example_images_path:
logger.debug("Cleanup skipped: example images path not configured")
return {

View File

@@ -6,7 +6,7 @@ from .model_metadata_provider import (
CivitaiModelMetadataProvider,
FallbackMetadataProvider
)
from .settings_manager import settings
from .settings_manager import get_settings_manager
from .metadata_archive_manager import MetadataArchiveManager
from .service_registry import ServiceRegistry
@@ -21,7 +21,8 @@ async def initialize_metadata_providers():
provider_manager.default_provider = None
# Get settings
enable_archive_db = settings.get('enable_metadata_archive_db', False)
settings_manager = get_settings_manager()
enable_archive_db = settings_manager.get('enable_metadata_archive_db', False)
providers = []
@@ -87,7 +88,8 @@ async def update_metadata_providers():
"""Update metadata providers based on current settings"""
try:
# Get current settings
enable_archive_db = settings.get('enable_metadata_archive_db', False)
settings_manager = get_settings_manager()
enable_archive_db = settings_manager.get('enable_metadata_archive_db', False)
# Reinitialize all providers with new settings
provider_manager = await initialize_metadata_providers()

View File

@@ -1,6 +1,6 @@
import asyncio
from typing import List, Dict, Tuple
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from operator import itemgetter
from natsort import natsorted
@@ -17,10 +17,12 @@ SUPPORTED_SORT_MODES = [
@dataclass
class ModelCache:
"""Cache structure for model data with extensible sorting"""
"""Cache structure for model data with extensible sorting."""
raw_data: List[Dict]
folders: List[str]
version_index: Dict[int, Dict] = field(default_factory=dict)
def __post_init__(self):
self._lock = asyncio.Lock()
# Cache for last sort: (sort_key, order) -> sorted list
@@ -28,6 +30,58 @@ class ModelCache:
self._last_sorted_data: List[Dict] = []
# Default sort on init
asyncio.create_task(self.resort())
self.rebuild_version_index()
@staticmethod
def _normalize_version_id(value: Any) -> Optional[int]:
"""Normalize a potential version identifier into an integer."""
if isinstance(value, int):
return value
if isinstance(value, str):
try:
return int(value)
except ValueError:
return None
return None
def rebuild_version_index(self) -> None:
    """Discard the version index and repopulate it from the current raw data."""
    self.version_index = {}
    for cached_item in self.raw_data:
        self.add_to_version_index(cached_item)
def add_to_version_index(self, item: Dict) -> None:
    """Register a cache item under its Civitai version id, when one exists.

    Items without a dict 'civitai' section, or without a normalizable id,
    are silently ignored.
    """
    if not isinstance(item, dict):
        return
    civitai_section = item.get('civitai')
    if not isinstance(civitai_section, dict):
        return
    normalized = self._normalize_version_id(civitai_section.get('id'))
    if normalized is not None:
        self.version_index[normalized] = item
def remove_from_version_index(self, item: Dict) -> None:
    """Drop a cache item from the version index, if it is the indexed one.

    The entry is removed only when the indexed object is the very same item,
    or a dict describing the same file_path — an unrelated item that happens
    to share a version id is left untouched.
    """
    civitai_section = item.get('civitai') if isinstance(item, dict) else None
    if not isinstance(civitai_section, dict):
        return
    normalized = self._normalize_version_id(civitai_section.get('id'))
    if normalized is None:
        return
    indexed = self.version_index.get(normalized)
    same_object = indexed is item
    same_file = (
        isinstance(indexed, dict)
        and indexed.get('file_path') == item.get('file_path')
    )
    if same_object or same_file:
        self.version_index.pop(normalized, None)
async def resort(self):
"""Resort cached data according to last sort mode if set"""
@@ -41,6 +95,7 @@ class ModelCache:
all_folders = set(l['folder'] for l in self.raw_data)
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
self.rebuild_version_index()
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
"""Sort data by sort_key and order"""

View File

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from ..utils.utils import calculate_relative_path_for_model, remove_empty_dirs
from ..utils.constants import AUTO_ORGANIZE_BATCH_SIZE
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
@@ -114,7 +114,8 @@ class ModelFileService:
raise ValueError('No model roots configured')
# Check if flat structure is configured for this model type
path_template = settings.get_download_path_template(self.model_type)
settings_manager = get_settings_manager()
path_template = settings_manager.get_download_path_template(self.model_type)
result.is_flat_structure = not path_template
# Initialize tracking

View File

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
import json
import logging
from typing import Optional, Dict, Tuple, Any
from typing import Optional, Dict, Tuple, Any, List
from .downloader import get_downloader
try:
@@ -61,6 +61,11 @@ class ModelMetadataProvider(ABC):
"""Fetch model version metadata"""
pass
@abstractmethod
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Fetch models owned by the specified user"""
pass
class CivitaiModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses Civitai API for metadata"""
@@ -79,6 +84,9 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
return await self.client.get_model_version_info(version_id)
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
return await self.client.get_user_models(username)
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses CivArchive HTML page parsing for metadata"""
@@ -197,6 +205,10 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Not supported by CivArchive provider - requires both model_id and version_id"""
return None, "CivArchive provider requires both model_id and version_id"
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Not supported by CivArchive provider"""
return None
class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses SQLite database for metadata"""
@@ -329,20 +341,24 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Fetch model version metadata from SQLite database"""
async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = self._aiosqlite.Row
# Get version details
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
cursor = await db.execute(version_query, (version_id,))
version_row = await cursor.fetchone()
if not version_row:
return None, "Model version not found"
model_id = version_row['model_id']
# Build complete version data with model info
version_data = await self._get_version_with_model_data(db, model_id, version_id)
return version_data, None
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Listing models by username is not supported for archive database"""
return None
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
"""Helper to build version data with model information"""
@@ -481,6 +497,17 @@ class FallbackMetadataProvider(ModelMetadataProvider):
continue
return None, "No provider could retrieve the data"
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
for provider in self.providers:
try:
result = await provider.get_user_models(username)
if result is not None:
return result
except Exception as e:
logger.debug(f"Provider failed for get_user_models: {e}")
continue
return None
class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers"""
@@ -522,6 +549,11 @@ class ModelMetadataProviderManager:
"""Fetch model version info using specified or default provider"""
provider = self._get_provider(provider_name)
return await provider.get_model_version_info(version_id)
async def get_user_models(self, username: str, provider_name: str = None) -> Optional[List[Dict]]:
"""Fetch models owned by the specified user"""
provider = self._get_provider(provider_name)
return await provider.get_user_models(username)
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
"""Get provider by name or default provider"""

View File

@@ -634,7 +634,8 @@ class ModelScanner:
if model_data:
# Add to cache
self._cache.raw_data.append(model_data)
self._cache.add_to_version_index(model_data)
# Update hash index if available
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
@@ -661,7 +662,9 @@ class ModelScanner:
for path in missing_files:
try:
model_to_remove = path_to_item[path]
self._cache.remove_from_version_index(model_to_remove)
# Update tags count
for tag in model_to_remove.get('tags', []):
if tag in self._tags_count:
@@ -684,6 +687,8 @@ class ModelScanner:
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
self._cache.rebuild_version_index()
# Resort cache
await self._cache.resort()
@@ -829,6 +834,8 @@ class ModelScanner:
else:
self._cache.raw_data = list(scan_result.raw_data)
self._cache.rebuild_version_index()
await self._cache.resort()
async def _gather_model_data(
@@ -934,7 +941,8 @@ class ModelScanner:
# Add to cache
self._cache.raw_data.append(metadata_dict)
self._cache.add_to_version_index(metadata_dict)
# Resort cache data
await self._cache.resort()
@@ -1076,6 +1084,9 @@ class ModelScanner:
cache = await self.get_cached_data()
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item:
cache.remove_from_version_index(existing_item)
if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []):
if tag in self._tags_count:
@@ -1106,6 +1117,7 @@ class ModelScanner:
)
cache.raw_data.append(cache_entry)
cache.add_to_version_index(cache_entry)
sha_value = cache_entry.get('sha256')
if sha_value:
@@ -1117,6 +1129,8 @@ class ModelScanner:
for tag in cache_entry.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
cache.rebuild_version_index()
await cache.resort()
if cache_modified:
@@ -1339,11 +1353,12 @@ class ModelScanner:
# Update hash index
for model in models_to_remove:
file_path = model['file_path']
self._cache.remove_from_version_index(model)
if hasattr(self, '_hash_index') and self._hash_index:
# Get the hash and filename before removal for duplicate checking
file_name = os.path.splitext(os.path.basename(file_path))[0]
hash_val = model.get('sha256', '').lower()
# Remove from hash index
self._hash_index.remove_by_path(file_path, hash_val)
@@ -1352,8 +1367,9 @@ class ModelScanner:
# Update cache data
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths]
# Resort cache
self._cache.rebuild_version_index()
await self._cache.resort()
await self._persist_current_cache()
@@ -1393,16 +1409,17 @@ class ModelScanner:
Returns:
bool: True if the model version exists, False otherwise
"""
try:
normalized_id = int(model_version_id)
except (TypeError, ValueError):
return False
try:
cache = await self.get_cached_data()
if not cache or not cache.raw_data:
if not cache:
return False
for item in cache.raw_data:
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
return True
return False
return normalized_id in cache.version_index
except Exception as e:
logger.error(f"Error checking model version existence: {e}")
return False

View File

@@ -351,7 +351,7 @@ class PersistentModelCache:
def get_persistent_cache() -> PersistentModelCache:
from .settings_manager import settings as settings_service # Local import to avoid cycles
from .settings_manager import get_settings_manager # Local import to avoid cycles
library_name = settings_service.get_active_library_name()
library_name = get_settings_manager().get_active_library_name()
return PersistentModelCache.get_default(library_name)

View File

@@ -5,8 +5,10 @@ from __future__ import annotations
import logging
import os
from typing import Awaitable, Callable, Dict, Optional, Sequence
from urllib.parse import urlparse
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
from ..utils.civitai_utils import rewrite_preview_url
logger = logging.getLogger(__name__)
@@ -45,23 +47,59 @@ class PreviewAssetService:
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_dir = os.path.dirname(metadata_path)
is_video = first_preview.get("type") == "video"
preview_url = first_preview.get("url")
if not preview_url:
return
def extension_from_url(url: str, fallback: str) -> str:
try:
parsed = urlparse(url)
except ValueError:
return fallback
ext = os.path.splitext(parsed.path)[1]
return ext or fallback
downloader = await self._downloader_factory()
if is_video:
extension = ".mp4"
extension = extension_from_url(preview_url, ".mp4")
preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory()
success, result = await downloader.download_file(
first_preview["url"], preview_path, use_auth=False
)
if success:
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="video")
attempt_urls = []
if rewritten:
attempt_urls.append(rewritten_url)
attempt_urls.append(preview_url)
seen: set[str] = set()
for candidate in attempt_urls:
if not candidate or candidate in seen:
continue
seen.add(candidate)
success, _ = await downloader.download_file(candidate, preview_path, use_auth=False)
if success:
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
return
else:
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="image")
if rewritten:
extension = extension_from_url(preview_url, ".png")
preview_path = os.path.join(preview_dir, base_name + extension)
success, _ = await downloader.download_file(
rewritten_url, preview_path, use_auth=False
)
if success:
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
return
extension = ".webp"
preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory()
success, content, _headers = await downloader.download_to_memory(
first_preview["url"], use_auth=False
preview_url, use_auth=False
)
if not success:
return

View File

@@ -3,6 +3,7 @@ import json
import os
import logging
from datetime import datetime, timezone
from threading import Lock
from typing import Any, Dict, Iterable, List, Mapping, Optional
from ..utils.settings_paths import ensure_settings_file
@@ -688,4 +689,38 @@ class SettingsManager:
return templates.get(model_type, '{base_model}/{first_tag}')
settings = SettingsManager()
_SETTINGS_MANAGER: Optional["SettingsManager"] = None
_SETTINGS_MANAGER_LOCK = Lock()
# Legacy module-level alias for backwards compatibility with callers that
# monkeypatch ``py.services.settings_manager.settings`` during tests.
settings: Optional["SettingsManager"] = None
def get_settings_manager() -> "SettingsManager":
    """Return the lazily initialised global :class:`SettingsManager`.

    Checks the legacy ``settings`` module-level alias first so callers/tests
    that monkeypatch ``py.services.settings_manager.settings`` keep working.
    """
    global _SETTINGS_MANAGER, settings
    # A monkeypatched (or previously initialised) alias wins outright.
    if settings is not None:
        return settings
    if _SETTINGS_MANAGER is None:
        # Double-checked locking: re-test under the lock so only one thread
        # ever constructs the manager.
        with _SETTINGS_MANAGER_LOCK:
            if _SETTINGS_MANAGER is None:
                _SETTINGS_MANAGER = SettingsManager()
                settings = _SETTINGS_MANAGER
    return _SETTINGS_MANAGER
def reset_settings_manager() -> None:
    """Reset the cached settings manager instance.

    Primarily intended for tests so they can configure the settings
    directory before the manager touches the filesystem.
    """
    global _SETTINGS_MANAGER, settings
    # Clear both the singleton and the legacy alias under the lock so a
    # concurrent get_settings_manager() cannot observe a half-reset state.
    with _SETTINGS_MANAGER_LOCK:
        _SETTINGS_MANAGER = None
        settings = None

View File

@@ -6,6 +6,7 @@ import logging
from typing import Any, Dict, Optional, Protocol, Sequence
from ..metadata_sync_service import MetadataSyncService
from ...utils.metadata_manager import MetadataManager
class MetadataRefreshProgressReporter(Protocol):
@@ -70,6 +71,7 @@ class BulkMetadataRefreshUseCase:
for model in to_process:
try:
original_name = model.get("model_name")
await MetadataManager.hydrate_model_data(model)
result, _ = await self._metadata_sync.fetch_and_update_model(
sha256=model["sha256"],
file_path=model["file_path"],

48
py/utils/civitai_utils.py Normal file
View File

@@ -0,0 +1,48 @@
"""Utilities for working with Civitai assets."""
from __future__ import annotations
from urllib.parse import urlparse, urlunparse
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
"""Rewrite Civitai preview URLs to use optimized renditions.
Args:
source_url: Original preview URL from the Civitai API.
media_type: Optional media type hint (e.g. ``"image"`` or ``"video"``).
Returns:
A tuple of the potentially rewritten URL and a flag indicating whether the
replacement occurred. When the URL is not rewritten, the original value is
returned with ``False``.
"""
if not source_url:
return source_url, False
try:
parsed = urlparse(source_url)
except ValueError:
return source_url, False
if parsed.netloc.lower() != "image.civitai.com":
return source_url, False
replacement = "/width=450,optimized=true"
if (media_type or "").lower() == "video":
replacement = "/transcode=true,width=450,optimized=true"
if "/original=true" not in parsed.path:
return source_url, False
updated_path = parsed.path.replace("/original=true", replacement, 1)
if updated_path == parsed.path:
return source_url, False
rewritten = urlunparse(parsed._replace(path=updated_path))
print(rewritten)
return rewritten, True
__all__ = ["rewrite_preview_url"]

View File

@@ -48,6 +48,13 @@ SUPPORTED_MEDIA_EXTENSIONS = {
# Valid Lora types
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
# Supported Civitai model types for user model queries (case-insensitive)
CIVITAI_USER_MODEL_TYPES = [
*VALID_LORA_TYPES,
'textualinversion',
'checkpoint',
]
# Auto-organize settings
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system

View File

@@ -18,7 +18,7 @@ from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater
from ..services.downloader import get_downloader
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
class ExampleImagesDownloadError(RuntimeError):
@@ -107,7 +107,7 @@ class DownloadManager:
self._state_lock = state_lock or asyncio.Lock()
def _resolve_output_dir(self, library_name: str | None = None) -> str:
base_path = settings.get('example_images_path')
base_path = get_settings_manager().get('example_images_path')
if not base_path:
return ''
return ensure_library_root_exists(library_name)
@@ -126,7 +126,8 @@ class DownloadManager:
model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2))
base_path = settings.get('example_images_path')
settings_manager = get_settings_manager()
base_path = settings_manager.get('example_images_path')
if not base_path:
error_msg = 'Example images path not configured in settings'
@@ -138,7 +139,7 @@ class DownloadManager:
}
raise DownloadConfigurationError(error_msg)
active_library = settings.get_active_library_name()
active_library = get_settings_manager().get_active_library_name()
output_dir = self._resolve_output_dir(active_library)
if not output_dir:
raise DownloadConfigurationError('Example images path not configured in settings')
@@ -151,7 +152,7 @@ class DownloadManager:
progress_file = os.path.join(output_dir, '.download_progress.json')
progress_source = progress_file
if uses_library_scoped_folders():
legacy_root = settings.get('example_images_path') or ''
legacy_root = get_settings_manager().get('example_images_path') or ''
legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else ''
if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file):
try:
@@ -555,11 +556,12 @@ class DownloadManager:
if not model_hashes:
raise DownloadConfigurationError('Missing model_hashes parameter')
base_path = settings.get('example_images_path')
settings_manager = get_settings_manager()
base_path = settings_manager.get('example_images_path')
if not base_path:
raise DownloadConfigurationError('Example images path not configured in settings')
active_library = settings.get_active_library_name()
active_library = settings_manager.get_active_library_name()
output_dir = self._resolve_output_dir(active_library)
if not output_dir:
raise DownloadConfigurationError('Example images path not configured in settings')

View File

@@ -3,7 +3,7 @@ import os
import sys
import subprocess
from aiohttp import web
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
from ..utils.example_images_paths import (
get_model_folder,
get_model_relative_path,
@@ -37,7 +37,8 @@ class ExampleImagesFileManager:
}, status=400)
# Get example images path from settings
example_images_path = settings.get('example_images_path')
settings_manager = get_settings_manager()
example_images_path = settings_manager.get('example_images_path')
if not example_images_path:
return web.json_response({
'success': False,
@@ -109,7 +110,8 @@ class ExampleImagesFileManager:
}, status=400)
# Get example images path from settings
example_images_path = settings.get('example_images_path')
settings_manager = get_settings_manager()
example_images_path = settings_manager.get('example_images_path')
if not example_images_path:
return web.json_response({
'success': False,
@@ -183,7 +185,8 @@ class ExampleImagesFileManager:
}, status=400)
# Get example images path from settings
example_images_path = settings.get('example_images_path')
settings_manager = get_settings_manager()
example_images_path = settings_manager.get('example_images_path')
if not example_images_path:
return web.json_response({
'has_images': False

View File

@@ -1,12 +1,13 @@
import logging
import os
import re
from typing import TYPE_CHECKING, Any, Dict, Optional
from ..recipes.constants import GEN_PARAM_KEYS
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
from ..services.metadata_sync_service import MetadataSyncService
from ..services.preview_asset_service import PreviewAssetService
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
from ..services.downloader import get_downloader
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
from ..utils.exif_utils import ExifUtils
@@ -20,13 +21,46 @@ _preview_service = PreviewAssetService(
exif_utils=ExifUtils,
)
_metadata_sync_service = MetadataSyncService(
metadata_manager=MetadataManager,
preview_service=_preview_service,
settings=settings,
default_metadata_provider_factory=get_default_metadata_provider,
metadata_provider_selector=get_metadata_provider,
)
_metadata_sync_service: MetadataSyncService | None = None
_metadata_sync_service_settings: Optional["SettingsManager"] = None
if TYPE_CHECKING: # pragma: no cover - import for type checkers only
from ..services.settings_manager import SettingsManager
def _build_metadata_sync_service(settings_manager: "SettingsManager") -> MetadataSyncService:
"""Construct a metadata sync service bound to the provided settings."""
return MetadataSyncService(
metadata_manager=MetadataManager,
preview_service=_preview_service,
settings=settings_manager,
default_metadata_provider_factory=get_default_metadata_provider,
metadata_provider_selector=get_metadata_provider,
)
def _get_metadata_sync_service() -> MetadataSyncService:
"""Return the shared metadata sync service, initialising it lazily."""
global _metadata_sync_service, _metadata_sync_service_settings
settings_manager = get_settings_manager()
if isinstance(_metadata_sync_service, MetadataSyncService):
if _metadata_sync_service_settings is not settings_manager:
_metadata_sync_service = _build_metadata_sync_service(settings_manager)
_metadata_sync_service_settings = settings_manager
elif _metadata_sync_service is None:
_metadata_sync_service = _build_metadata_sync_service(settings_manager)
_metadata_sync_service_settings = settings_manager
else:
# Tests may inject stand-ins that do not match the sync service type. Preserve
# those injections while still updating our cached settings reference so the
# next real service instantiation uses the current configuration.
_metadata_sync_service_settings = settings_manager
return _metadata_sync_service
class MetadataUpdater:
@@ -71,7 +105,8 @@ class MetadataUpdater:
async def update_cache_func(old_path, new_path, metadata):
return await scanner.update_single_model_cache(old_path, new_path, metadata)
success, error = await _metadata_sync_service.fetch_and_update_model(
await MetadataManager.hydrate_model_data(model_data)
success, error = await _get_metadata_sync_service().fetch_and_update_model(
sha256=model_hash,
file_path=file_path,
model_data=model_data,
@@ -151,16 +186,16 @@ class MetadataUpdater:
if is_supported:
local_images_paths.append(file_path)
await MetadataManager.hydrate_model_data(model)
civitai_data = model.setdefault('civitai', {})
# Check if metadata update is needed (no civitai field or empty images)
needs_update = not model.get('civitai') or not model.get('civitai', {}).get('images')
needs_update = not civitai_data or not civitai_data.get('images')
if needs_update and local_images_paths:
logger.debug(f"Found {len(local_images_paths)} local example images for {model.get('model_name')}, updating metadata")
# Create or get civitai field
if not model.get('civitai'):
model['civitai'] = {}
# Create images array
images = []
@@ -195,16 +230,13 @@ class MetadataUpdater:
images.append(image_entry)
# Update the model's civitai.images field
model['civitai']['images'] = images
civitai_data['images'] = images
# Save metadata to .metadata.json file
file_path = model.get('file_path')
try:
# Create a copy of model data without 'folder' field
model_copy = model.copy()
model_copy.pop('folder', None)
# Write metadata to file
await MetadataManager.save_metadata(file_path, model_copy)
logger.info(f"Saved metadata for {model.get('model_name')}")
except Exception as e:
@@ -237,16 +269,13 @@ class MetadataUpdater:
tuple: (regular_images, custom_images) - Both image arrays
"""
try:
# Ensure civitai field exists in model_data
if not model_data.get('civitai'):
model_data['civitai'] = {}
# Ensure customImages array exists
if not model_data['civitai'].get('customImages'):
model_data['civitai']['customImages'] = []
# Get current customImages array
custom_images = model_data['civitai']['customImages']
await MetadataManager.hydrate_model_data(model_data)
civitai_data = model_data.setdefault('civitai', {})
custom_images = civitai_data.get('customImages')
if not isinstance(custom_images, list):
custom_images = []
civitai_data['customImages'] = custom_images
# Add new image entry for each imported file
for path_tuple in newly_imported_paths:
@@ -304,11 +333,8 @@ class MetadataUpdater:
file_path = model_data.get('file_path')
if file_path:
try:
# Create a copy of model data without 'folder' field
model_copy = model_data.copy()
model_copy.pop('folder', None)
# Write metadata to file
await MetadataManager.save_metadata(file_path, model_copy)
logger.info(f"Saved metadata for {model_data.get('model_name')}")
except Exception as e:
@@ -319,7 +345,7 @@ class MetadataUpdater:
await scanner.update_single_model_cache(file_path, file_path, model_data)
# Get regular images array (might be None)
regular_images = model_data['civitai'].get('images', [])
regular_images = civitai_data.get('images', [])
# Return both image arrays
return regular_images, custom_images

View File

@@ -3,7 +3,7 @@ import logging
import os
import re
import json
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
from ..services.service_registry import ServiceRegistry
from ..utils.example_images_paths import iter_library_roots
from ..utils.metadata_manager import MetadataManager
@@ -14,6 +14,25 @@ logger = logging.getLogger(__name__)
CURRENT_NAMING_VERSION = 2 # Increment this when naming conventions change
class _SettingsProxy:
    """Lazy stand-in for the legacy module-level ``settings`` object.

    Defers resolution of the global :class:`SettingsManager` until first use
    so importing this module does not touch the filesystem.
    """

    def __init__(self):
        # Kept for backwards compatibility with code poking at the attribute;
        # the proxy intentionally no longer caches the resolved manager.
        self._manager = None

    def _resolve(self):
        # Always fetch the current global manager instead of caching it here.
        # ``get_settings_manager`` already memoises the instance, so this stays
        # cheap while avoiding a stale reference after
        # ``reset_settings_manager()`` swaps the singleton (e.g. in tests).
        return get_settings_manager()

    def get(self, *args, **kwargs):
        # Most call sites only ever use ``settings.get(...)``.
        return self._resolve().get(*args, **kwargs)

    def __getattr__(self, item):
        # Delegate every other attribute/method access to the real manager.
        return getattr(self._resolve(), item)


settings = _SettingsProxy()
class ExampleImagesMigration:
"""Handles migrations for example images naming conventions"""

View File

@@ -8,7 +8,7 @@ import re
import shutil
from typing import Iterable, List, Optional, Tuple
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
_HEX_PATTERN = re.compile(r"[a-fA-F0-9]{64}")
@@ -18,7 +18,8 @@ logger = logging.getLogger(__name__)
def _get_configured_libraries() -> List[str]:
"""Return configured library names if multi-library support is enabled."""
libraries = settings.get("libraries")
settings_manager = get_settings_manager()
libraries = settings_manager.get("libraries")
if isinstance(libraries, dict) and libraries:
return list(libraries.keys())
return []
@@ -27,7 +28,8 @@ def _get_configured_libraries() -> List[str]:
def get_example_images_root() -> str:
"""Return the root directory configured for example images."""
root = settings.get("example_images_path") or ""
settings_manager = get_settings_manager()
root = settings_manager.get("example_images_path") or ""
return os.path.abspath(root) if root else ""
@@ -41,7 +43,8 @@ def uses_library_scoped_folders() -> bool:
def sanitize_library_name(library_name: Optional[str]) -> str:
"""Return a filesystem safe library name."""
name = library_name or settings.get_active_library_name() or "default"
settings_manager = get_settings_manager()
name = library_name or settings_manager.get_active_library_name() or "default"
safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", name)
return safe_name or "default"
@@ -161,11 +164,13 @@ def iter_library_roots() -> Iterable[Tuple[str, str]]:
results.append((library, get_library_root(library)))
else:
# Fall back to the active library to avoid skipping migrations/cleanup
active = settings.get_active_library_name() or "default"
settings_manager = get_settings_manager()
active = settings_manager.get_active_library_name() or "default"
results.append((active, get_library_root(active)))
return results
active = settings.get_active_library_name() or "default"
settings_manager = get_settings_manager()
active = settings_manager.get_active_library_name() or "default"
return [(active, root)]

View File

@@ -6,7 +6,7 @@ import string
from aiohttp import web
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
from ..utils.example_images_paths import get_model_folder, get_model_relative_path
from .example_images_metadata import MetadataUpdater
from ..utils.metadata_manager import MetadataManager
@@ -318,7 +318,7 @@ class ExampleImagesProcessor:
try:
# Get example images path
example_images_path = settings.get('example_images_path')
example_images_path = get_settings_manager().get('example_images_path')
if not example_images_path:
raise ExampleImagesValidationError('No example images path configured')
@@ -442,7 +442,7 @@ class ExampleImagesProcessor:
}, status=400)
# Get example images path
example_images_path = settings.get('example_images_path')
example_images_path = get_settings_manager().get('example_images_path')
if not example_images_path:
return web.json_response({
'success': False,
@@ -475,15 +475,17 @@ class ExampleImagesProcessor:
'error': f"Model with hash {model_hash} not found in cache"
}, status=404)
# Check if model has custom images
if not model_data.get('civitai', {}).get('customImages'):
await MetadataManager.hydrate_model_data(model_data)
civitai_data = model_data.setdefault('civitai', {})
custom_images = civitai_data.get('customImages')
if not isinstance(custom_images, list) or not custom_images:
return web.json_response({
'success': False,
'error': f"Model has no custom images"
}, status=404)
# Find the custom image with matching short_id
custom_images = model_data['civitai']['customImages']
matching_image = None
new_custom_images = []
@@ -527,17 +529,15 @@ class ExampleImagesProcessor:
logger.warning(f"File for custom example with id {short_id} not found, but metadata will still be updated")
# Update metadata
model_data['civitai']['customImages'] = new_custom_images
civitai_data['customImages'] = new_custom_images
model_data.setdefault('civitai', {})['customImages'] = new_custom_images
# Save updated metadata to file
file_path = model_data.get('file_path')
if file_path:
try:
# Create a copy of model data without 'folder' field
model_copy = model_data.copy()
model_copy.pop('folder', None)
# Write metadata to file
await MetadataManager.save_metadata(file_path, model_copy)
logger.debug(f"Saved updated metadata for {model_data.get('model_name')}")
except Exception as e:
@@ -551,7 +551,7 @@ class ExampleImagesProcessor:
await scanner.update_single_model_cache(file_path, file_path, model_data)
# Get regular images array (might be None)
regular_images = model_data['civitai'].get('images', [])
regular_images = civitai_data.get('images', [])
return web.json_response({
'success': True,
@@ -568,4 +568,4 @@ class ExampleImagesProcessor:
}, status=500)

View File

@@ -2,7 +2,7 @@ from datetime import datetime
import os
import json
import logging
from typing import Dict, Optional, Type, Union
from typing import Any, Dict, Optional, Type, Union
from .models import BaseModelMetadata, LoraMetadata
from .file_utils import normalize_path, find_preview_file, calculate_sha256
@@ -53,6 +53,70 @@ class MetadataManager:
error_type = "Invalid JSON" if isinstance(e, json.JSONDecodeError) else "Parse error"
logger.error(f"{error_type} in metadata file: {metadata_path}. Error: {str(e)}. Skipping model to preserve existing data.")
return None, True # should_skip = True
    @staticmethod
    async def load_metadata_payload(file_path: str) -> Dict:
        """
        Load metadata and return it as a dictionary, including any unknown fields.

        Falls back to reading the raw JSON file if parsing into a model class fails.
        """
        payload: Dict = {}
        metadata_obj, should_skip = await MetadataManager.load_metadata(file_path)
        if metadata_obj:
            # Parsed into a model class: serialise it, then re-attach any
            # fields the model class did not recognise so nothing is dropped.
            payload = metadata_obj.to_dict()
            unknown_fields = getattr(metadata_obj, "_unknown_fields", None)
            if isinstance(unknown_fields, dict):
                payload.update(unknown_fields)
        else:
            if not should_skip:
                # Fall back to the raw sidecar JSON. ``file_path`` may be the
                # metadata file itself or the model file it sits beside.
                metadata_path = (
                    file_path
                    if file_path.endswith(".metadata.json")
                    else f"{os.path.splitext(file_path)[0]}.metadata.json"
                )
                if os.path.exists(metadata_path):
                    try:
                        with open(metadata_path, "r", encoding="utf-8") as handle:
                            raw = json.load(handle)
                            # Only accept a JSON object; arrays/scalars are ignored.
                            if isinstance(raw, dict):
                                payload = raw
                    except json.JSONDecodeError:
                        logger.warning(
                            "Failed to parse metadata file %s while loading payload",
                            metadata_path,
                        )
                    except Exception as exc:  # pragma: no cover - defensive logging
                        logger.warning("Failed to read metadata file %s: %s", metadata_path, exc)
        # Guard against a non-dict payload sneaking through the fallbacks.
        if not isinstance(payload, dict):
            payload = {}
        if file_path:
            # Ensure the payload always records where it came from.
            payload.setdefault("file_path", normalize_path(file_path))
        return payload
    @staticmethod
    async def hydrate_model_data(model_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Replace the provided model data with the authoritative payload from disk.

        Preserves the cached folder entry if present.
        """
        file_path = model_data.get("file_path")
        if not file_path:
            # Without a path there is nothing to hydrate from; leave as-is.
            return model_data
        folder = model_data.get("folder")
        payload = await MetadataManager.load_metadata_payload(file_path)
        if folder is not None:
            # ``folder`` lives only in the in-memory cache, never on disk,
            # so carry it over onto the fresh payload.
            payload["folder"] = folder
        # Mutate the dict in place so existing references see the refresh.
        model_data.clear()
        model_data.update(payload)
        return model_data
@staticmethod
async def save_metadata(path: str, metadata: Union[BaseModelMetadata, Dict]) -> bool:

View File

@@ -65,6 +65,12 @@ def ensure_settings_file(logger: Optional[logging.Logger] = None) -> str:
logger = logger or _LOGGER
target_path = get_settings_file_path(create_dir=True)
preferred_dir = user_config_dir(APP_NAME, appauthor=False)
preferred_path = os.path.join(preferred_dir, "settings.json")
if os.path.abspath(target_path) != os.path.abspath(preferred_path):
os.makedirs(preferred_dir, exist_ok=True)
target_path = preferred_path
legacy_path = get_legacy_settings_path()
if os.path.exists(legacy_path) and not os.path.exists(target_path):

View File

@@ -3,7 +3,7 @@ import os
from typing import Dict
from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
from ..services.settings_manager import get_settings_manager
from .constants import CIVITAI_MODEL_TAGS
import asyncio
@@ -143,7 +143,8 @@ def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora'
Relative path string (empty string for flat structure)
"""
# Get path template from settings for specific model type
path_template = settings.get_download_path_template(model_type)
settings_manager = get_settings_manager()
path_template = settings_manager.get_download_path_template(model_type)
# If template is empty, return empty path (flat structure)
if not path_template:
@@ -166,7 +167,7 @@ def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora'
model_tags = model_data.get('tags', [])
# Apply mapping if available
base_model_mappings = settings.get('base_model_path_mappings', {})
base_model_mappings = settings_manager.get('base_model_path_mappings', {})
mapped_base_model = base_model_mappings.get(base_model, base_model)
# Find the first Civitai model tag that exists in model_tags

View File

@@ -1,7 +1,7 @@
[project]
name = "comfyui-lora-manager"
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
version = "0.9.6"
version = "0.9.7"
license = {file = "LICENSE"}
dependencies = [
"aiohttp",

View File

@@ -21,6 +21,7 @@ export class SidebarManager {
this.isInitialized = false;
this.displayMode = 'tree'; // 'tree' or 'list'
this.foldersList = [];
this.recursiveSearchEnabled = true;
// Bind methods
this.handleTreeClick = this.handleTreeClick.bind(this);
@@ -36,6 +37,7 @@ export class SidebarManager {
this.updateContainerMargin = this.updateContainerMargin.bind(this);
this.handleDisplayModeToggle = this.handleDisplayModeToggle.bind(this);
this.handleFolderListClick = this.handleFolderListClick.bind(this);
this.handleRecursiveToggle = this.handleRecursiveToggle.bind(this);
}
async initialize(pageControls) {
@@ -89,6 +91,7 @@ export class SidebarManager {
this.isHovering = false;
this.apiClient = null;
this.isInitialized = false;
this.recursiveSearchEnabled = true;
// Reset container margin
const container = document.querySelector('.container');
@@ -111,6 +114,7 @@ export class SidebarManager {
const sidebar = document.getElementById('folderSidebar');
const hoverArea = document.getElementById('sidebarHoverArea');
const displayModeToggleBtn = document.getElementById('sidebarDisplayModeToggle');
const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
if (pinToggleBtn) {
pinToggleBtn.removeEventListener('click', this.handlePinToggle);
@@ -145,6 +149,9 @@ export class SidebarManager {
if (displayModeToggleBtn) {
displayModeToggleBtn.removeEventListener('click', this.handleDisplayModeToggle);
}
if (recursiveToggleBtn) {
recursiveToggleBtn.removeEventListener('click', this.handleRecursiveToggle);
}
}
async init() {
@@ -197,7 +204,7 @@ export class SidebarManager {
updateSidebarTitle() {
const sidebarTitle = document.getElementById('sidebarTitle');
if (sidebarTitle) {
sidebarTitle.textContent = `${this.apiClient.apiConfig.config.displayName} Root`;
sidebarTitle.textContent = translate('sidebar.modelRoot');
}
}
@@ -220,6 +227,12 @@ export class SidebarManager {
collapseAllBtn.addEventListener('click', this.handleCollapseAll);
}
// Recursive toggle button
const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
if (recursiveToggleBtn) {
recursiveToggleBtn.addEventListener('click', this.handleRecursiveToggle);
}
// Tree click handler
const folderTree = document.getElementById('sidebarFolderTree');
if (folderTree) {
@@ -645,11 +658,33 @@ export class SidebarManager {
this.displayMode = this.displayMode === 'tree' ? 'list' : 'tree';
this.updateDisplayModeButton();
this.updateCollapseAllButton();
this.updateRecursiveToggleButton();
this.updateSearchRecursiveOption();
this.saveDisplayMode();
this.loadFolderTree(); // Reload with new display mode
}
    /**
     * Toggle recursive folder search on or off (tree mode only).
     *
     * Persists the preference to storage, syncs it into the page search
     * options, refreshes the toggle button UI, and reloads the model list so
     * the new setting takes effect.
     *
     * @param {Event} event - Click event from the sidebar toggle button.
     */
    async handleRecursiveToggle(event) {
        event.stopPropagation();

        // Recursive search only applies to the tree view; list view always
        // shows a single folder.
        if (this.displayMode !== 'tree') {
            return;
        }

        this.recursiveSearchEnabled = !this.recursiveSearchEnabled;
        setStorageItem(`${this.pageType}_recursiveSearch`, this.recursiveSearchEnabled);
        this.updateSearchRecursiveOption();
        this.updateRecursiveToggleButton();

        // Reload the visible models so the recursion change is reflected.
        if (this.pageControls && typeof this.pageControls.resetAndReload === 'function') {
            try {
                await this.pageControls.resetAndReload(true);
            } catch (error) {
                console.error('Failed to reload models after toggling recursive search:', error);
            }
        }
    }
updateDisplayModeButton() {
const displayModeBtn = document.getElementById('sidebarDisplayModeToggle');
if (displayModeBtn) {
@@ -679,8 +714,35 @@ export class SidebarManager {
}
}
    /**
     * Sync the recursive-search toggle button with the current state.
     *
     * The button is "active" only in tree mode with recursion enabled, and is
     * marked disabled (visually and via ARIA attributes) in list mode. Also
     * resets the icon and updates the localized tooltip text.
     */
    updateRecursiveToggleButton() {
        const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
        if (!recursiveToggleBtn) return;

        const icon = recursiveToggleBtn.querySelector('i');
        const isTreeMode = this.displayMode === 'tree';
        const isActive = isTreeMode && this.recursiveSearchEnabled;

        recursiveToggleBtn.classList.toggle('active', isActive);
        recursiveToggleBtn.classList.toggle('disabled', !isTreeMode);
        recursiveToggleBtn.setAttribute('aria-pressed', isActive ? 'true' : 'false');
        recursiveToggleBtn.setAttribute('aria-disabled', isTreeMode ? 'false' : 'true');

        if (icon) {
            // Reset to the branch icon in case another state changed it.
            icon.className = 'fas fa-code-branch';
        }

        if (!isTreeMode) {
            recursiveToggleBtn.title = translate('sidebar.recursiveUnavailable');
        } else if (this.recursiveSearchEnabled) {
            recursiveToggleBtn.title = translate('sidebar.recursiveOn');
        } else {
            recursiveToggleBtn.title = translate('sidebar.recursiveOff');
        }
    }
updateSearchRecursiveOption() {
this.pageControls.pageState.searchOptions.recursive = this.displayMode === 'tree';
const isRecursive = this.displayMode === 'tree' && this.recursiveSearchEnabled;
this.pageControls.pageState.searchOptions.recursive = isRecursive;
}
updateTreeSelection() {
@@ -925,15 +987,18 @@ export class SidebarManager {
const isPinned = getStorageItem(`${this.pageType}_sidebarPinned`, true);
const expandedPaths = getStorageItem(`${this.pageType}_expandedNodes`, []);
const displayMode = getStorageItem(`${this.pageType}_displayMode`, 'tree'); // 'tree' or 'list', default to 'tree'
const recursiveSearchEnabled = getStorageItem(`${this.pageType}_recursiveSearch`, true);
this.isPinned = isPinned;
this.expandedNodes = new Set(expandedPaths);
this.displayMode = displayMode;
this.recursiveSearchEnabled = recursiveSearchEnabled;
this.updatePinButton();
this.updateDisplayModeButton();
this.updateCollapseAllButton();
this.updateSearchRecursiveOption();
this.updateRecursiveToggleButton();
}
restoreSelectedFolder() {
@@ -974,4 +1039,4 @@ export class SidebarManager {
}
// Create and export global instance
export const sidebarManager = new SidebarManager();
export const sidebarManager = new SidebarManager();

View File

@@ -67,7 +67,7 @@ export const state = {
modelname: true,
tags: false,
creator: false,
recursive: true,
recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true),
},
filters: {
baseModel: [],
@@ -116,7 +116,7 @@ export const state = {
filename: true,
modelname: true,
creator: false,
recursive: true,
recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true),
},
filters: {
baseModel: [],
@@ -144,7 +144,7 @@ export const state = {
modelname: true,
tags: false,
creator: false,
recursive: true,
recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true),
},
filters: {
baseModel: [],
@@ -261,4 +261,4 @@ export function initPageState(pageType) {
return getCurrentPageState();
}
return null;
}
}

View File

@@ -435,8 +435,9 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
return true;
} else {
// Single node - send directly
const nodeId = Object.keys(registryData.data.nodes)[0];
return await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType);
const nodes = registryData.data.nodes;
const nodeId = Object.keys(nodes)[0];
return await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType);
}
} catch (error) {
console.error('Failed to get registry:', error);
@@ -452,19 +453,65 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
* @param {boolean} replaceMode - Whether to replace existing LoRAs
* @param {string} syntaxType - The type of syntax ('lora' or 'recipe')
*/
async function sendToSpecificNode(nodeIds, loraSyntax, replaceMode, syntaxType) {
/**
 * Translate a registry node key into a `{ node_id, graph_id }` reference.
 *
 * Prefers the descriptor stored in `nodesMap`; otherwise parses composite
 * `"graphId:nodeId"` keys (only the first segment is the graph id), falling
 * back to treating the whole key as a possibly-numeric node id with no graph.
 *
 * @param {string|number|null|undefined} nodeKey - Key from the node registry.
 * @param {Object|null|undefined} nodesMap - Map of key -> node descriptor.
 * @returns {{node_id: (string|number), graph_id: (string|null)}|null} The
 *   resolved reference, or null when no key was provided.
 */
function resolveNodeReference(nodeKey, nodesMap) {
    if (!nodeKey) {
        return null;
    }

    const entry = nodesMap?.[nodeKey];
    if (entry) {
        return {
            node_id: entry.id,
            graph_id: entry.graph_id ?? null,
        };
    }

    if (typeof nodeKey === 'string' && nodeKey.includes(':')) {
        // Composite key: everything after the first colon belongs to the
        // node id (which may itself contain colons).
        const separatorIndex = nodeKey.indexOf(':');
        const graphPart = nodeKey.slice(0, separatorIndex);
        const idPart = nodeKey.slice(separatorIndex + 1);
        const idAsNumber = Number(idPart);
        return {
            node_id: Number.isNaN(idAsNumber) ? idPart : idAsNumber,
            graph_id: graphPart || null,
        };
    }

    const keyAsNumber = Number(nodeKey);
    return {
        node_id: Number.isNaN(keyAsNumber) ? nodeKey : keyAsNumber,
        graph_id: null,
    };
}
async function sendToSpecificNode(nodeIds, nodesMap, loraSyntax, replaceMode, syntaxType) {
try {
// Call the backend API to update the lora code
const requestBody = {
lora_code: loraSyntax,
mode: replaceMode ? 'replace' : 'append'
};
if (Array.isArray(nodeIds)) {
const references = nodeIds
.map((nodeKey) => resolveNodeReference(nodeKey, nodesMap))
.filter((reference) => reference && reference.node_id !== undefined);
if (references.length > 0) {
requestBody.node_ids = references;
}
} else if (nodeIds) {
const reference = resolveNodeReference(nodeIds, nodesMap);
if (reference) {
requestBody.node_ids = [reference];
}
}
const response = await fetch('/api/lm/update-lora-code', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
node_ids: nodeIds,
lora_code: loraSyntax,
mode: replaceMode ? 'replace' : 'append'
})
body: JSON.stringify(requestBody)
});
const result = await response.json();
@@ -522,16 +569,17 @@ function showNodeSelector(nodes, loraSyntax, replaceMode, syntaxType) {
hideNodeSelector();
// Generate node list HTML with icons and proper colors
const nodeItems = Object.values(nodes).map(node => {
const nodeItems = Object.entries(nodes).map(([nodeKey, node]) => {
const iconClass = NODE_TYPE_ICONS[node.type] || 'fas fa-question-circle';
const bgColor = node.bgcolor || DEFAULT_NODE_COLOR;
const graphLabel = node.graph_name ? ` (${node.graph_name})` : '';
return `
<div class="node-item" data-node-id="${node.id}">
<div class="node-item" data-node-id="${nodeKey}">
<div class="node-icon-indicator" style="background-color: ${bgColor}">
<i class="${iconClass}"></i>
</div>
<span>#${node.id} ${node.title}</span>
<span>#${node.id}${graphLabel} ${node.title}</span>
</div>
`;
}).join('');
@@ -610,10 +658,10 @@ function setupNodeSelectorEvents(selector, nodes, loraSyntax, replaceMode, synta
if (action === 'send-all') {
// Send to all nodes
const allNodeIds = Object.keys(nodes);
await sendToSpecificNode(allNodeIds, loraSyntax, replaceMode, syntaxType);
await sendToSpecificNode(allNodeIds, nodes, loraSyntax, replaceMode, syntaxType);
} else if (nodeId) {
// Send to specific node
await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType);
await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType);
}
hideNodeSelector();

View File

@@ -9,6 +9,9 @@
<button class="sidebar-action-btn" id="sidebarDisplayModeToggle" title="{{ t('sidebar.switchToListView') }}">
<i class="fas fa-sitemap"></i>
</button>
<button class="sidebar-action-btn active" id="sidebarRecursiveToggle" title="{{ t('sidebar.recursiveOn') }}" aria-pressed="true">
<i class="fas fa-code-branch"></i>
</button>
<button class="sidebar-action-btn" id="sidebarCollapseAll" title="{{ t('sidebar.collapseAll') }}">
<i class="fas fa-compress-alt"></i>
</button>

View File

@@ -73,6 +73,30 @@ nodes_mock.NODE_CLASS_MAPPINGS = {}
sys.modules['nodes'] = nodes_mock
@pytest.fixture(autouse=True)
def _isolate_settings_dir(tmp_path_factory, monkeypatch):
"""Redirect settings.json into a temporary directory for each test."""
settings_dir = tmp_path_factory.mktemp("settings_dir")
def fake_get_settings_dir(create: bool = True) -> str:
if create:
settings_dir.mkdir(exist_ok=True)
return str(settings_dir)
monkeypatch.setattr("py.utils.settings_paths.get_settings_dir", fake_get_settings_dir)
monkeypatch.setattr(
"py.utils.settings_paths.user_config_dir",
lambda *_args, **_kwargs: str(settings_dir),
)
from py.services import settings_manager as settings_manager_module
settings_manager_module.reset_settings_manager()
yield
settings_manager_module.reset_settings_manager()
def pytest_pyfunc_call(pyfuncitem):
"""Allow bare async tests to run without pytest.mark.asyncio."""
test_function = pyfuncitem.function

View File

@@ -46,6 +46,7 @@ vi.mock(EVENT_MANAGER_MODULE, () => ({
off: vi.fn(),
addHandler: vi.fn(),
removeHandler: vi.fn(),
setState: vi.fn(),
},
}));
@@ -62,6 +63,7 @@ describe('UI helper DOM utilities', () => {
afterEach(() => {
vi.useRealTimers();
delete global.fetch;
});
it('creates toast elements and cleans them up after timeout', async () => {
@@ -105,4 +107,53 @@ describe('UI helper DOM utilities', () => {
expect(document.body.dataset.theme).toBe('dark');
expect(document.querySelector('.theme-toggle').classList.contains('theme-dark')).toBe(true);
});
it('renders subgraph names in the node selector list', async () => {
const registryResponse = {
success: true,
data: {
node_count: 2,
nodes: {
'root:1': {
id: 1,
graph_id: 'root',
graph_name: null,
title: 'Root Loader',
type: 1,
bgcolor: '#123456',
},
'subgraph-uuid:2': {
id: 2,
graph_id: 'subgraph-uuid',
graph_name: 'Character Subgraph',
title: 'Nested Loader',
type: 1,
bgcolor: '#654321',
},
},
},
};
global.fetch = vi.fn().mockResolvedValue({
json: async () => registryResponse,
});
document.body.innerHTML = '<div id="nodeSelector"></div>';
const { sendLoraToWorkflow } = await import(UI_HELPERS_MODULE);
const result = await sendLoraToWorkflow('<lora:test:1>');
expect(result).toBe(true);
expect(global.fetch).toHaveBeenCalledWith('/api/lm/get-registry');
const nodeLabels = Array.from(
document.querySelectorAll('#nodeSelector .node-item[data-node-id] span')
).map((span) => span.textContent.trim());
expect(nodeLabels).toEqual([
'#1 Root Loader',
'#2 (Character Subgraph) Nested Loader',
]);
});
});

View File

@@ -16,10 +16,12 @@ from aiohttp.test_utils import TestClient, TestServer
from py.config import config
from py.routes.base_model_routes import BaseModelRoutes
from py.services import model_file_service
from py.services.metadata_sync_service import MetadataSyncService
from py.services.model_file_service import AutoOrganizeResult
from py.services.service_registry import ServiceRegistry
from py.services.websocket_manager import ws_manager
from py.utils.exif_utils import ExifUtils
from py.utils.metadata_manager import MetadataManager
class DummyRoutes(BaseModelRoutes):
@@ -197,6 +199,116 @@ def test_replace_preview_writes_file_and_updates_cache(
asyncio.run(scenario())
def test_fetch_civitai_hydrates_metadata_before_sync(
mock_service,
mock_scanner,
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
):
model_path = tmp_path / "hydrate.safetensors"
model_path.write_bytes(b"model")
metadata_path = tmp_path / "hydrate.metadata.json"
existing_metadata = {
"file_path": str(model_path),
"sha256": "abc123",
"model_name": "Hydrated",
"preview_url": "keep/me.png",
"civitai": {
"id": 99,
"modelId": 42,
"images": [{"url": "https://example.com/existing.png", "type": "image"}],
"customImages": [{"id": "old-id", "url": "", "type": "image"}],
"trainedWords": ["keep"],
},
"custom_field": "preserve",
}
metadata_path.write_text(json.dumps(existing_metadata), encoding="utf-8")
minimal_cache_entry = {
"file_path": str(model_path),
"sha256": "abc123",
"folder": "some/folder",
"civitai": {"id": 99, "modelId": 42},
}
mock_scanner._cache.raw_data = [minimal_cache_entry]
class FakeMetadata:
def __init__(self, payload: dict) -> None:
self._payload = payload
self._unknown_fields = {"legacy_field": "legacy"}
def to_dict(self) -> dict:
return self._payload.copy()
async def fake_load_metadata(path: str, *_args, **_kwargs):
assert path == str(model_path)
return FakeMetadata(existing_metadata), False
async def fake_save_metadata(path: str, metadata: dict) -> bool:
save_calls.append((path, json.loads(json.dumps(metadata))))
return True
async def fake_fetch_and_update_model(
self,
*,
sha256: str,
file_path: str,
model_data: dict,
update_cache_func,
):
captured["model_data"] = json.loads(json.dumps(model_data))
to_save = model_data.copy()
to_save.pop("folder", None)
await MetadataManager.save_metadata(
os.path.splitext(file_path)[0] + ".metadata.json",
to_save,
)
await update_cache_func(file_path, file_path, model_data)
return True, None
save_calls: list[tuple[str, dict]] = []
captured: dict[str, dict] = {}
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load_metadata))
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save_metadata))
monkeypatch.setattr(MetadataSyncService, "fetch_and_update_model", fake_fetch_and_update_model)
async def scenario():
client = await create_test_client(mock_service)
try:
response = await client.post(
"/api/lm/test-models/fetch-civitai",
json={"file_path": str(model_path)},
)
payload = await response.json()
assert response.status == 200
assert payload["success"] is True
assert captured["model_data"]["custom_field"] == "preserve"
assert captured["model_data"]["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
assert captured["model_data"]["civitai"]["trainedWords"] == ["keep"]
assert captured["model_data"]["civitai"]["id"] == 99
finally:
await client.close()
asyncio.run(scenario())
assert save_calls, "Metadata save should be invoked"
saved_path, saved_payload = save_calls[0]
assert saved_path == str(metadata_path)
assert saved_payload["custom_field"] == "preserve"
assert saved_payload["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
assert saved_payload["civitai"]["trainedWords"] == ["keep"]
assert saved_payload["civitai"]["id"] == 99
assert saved_payload["legacy_field"] == "legacy"
assert mock_scanner.updated_models
updated_metadata = mock_scanner.updated_models[-1]["metadata"]
assert updated_metadata["custom_field"] == "preserve"
assert updated_metadata["civitai"]["customImages"][0]["id"] == "old-id"
def test_download_model_invokes_download_manager(
mock_service,
download_manager_stub,

View File

@@ -188,7 +188,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"]))
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": ["node"]})
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": [{"node_id": "node", "graph_id": "graph-1"}]})
response = await routes.get_trigger_words(request)
payload = json.loads(response.text)
@@ -196,7 +196,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
assert payload == {"success": True}
send_mock.assert_called_once_with(
"trigger_word_update",
{"id": "node", "message": "trigger-one"},
{"id": "node", "graph_id": "graph-1", "message": "trigger-one"},
)

View File

@@ -4,7 +4,14 @@ from types import SimpleNamespace
import pytest
from aiohttp import web
from py.routes.handlers.misc_handlers import SettingsHandler, ServiceRegistryAdapter
from py.routes.handlers.misc_handlers import (
LoraCodeHandler,
ModelLibraryHandler,
NodeRegistry,
NodeRegistryHandler,
ServiceRegistryAdapter,
SettingsHandler,
)
from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar
from py.routes.misc_routes import MiscRoutes
@@ -126,6 +133,128 @@ class FakePromptServer:
instance = Instance()
@pytest.mark.asyncio
async def test_register_nodes_requires_graph_id():
node_registry = NodeRegistry()
handler = NodeRegistryHandler(
node_registry=node_registry,
prompt_server=FakePromptServer,
standalone_mode=False,
)
request = FakeRequest(json_data={"nodes": [{"node_id": 1}]})
response = await handler.register_nodes(request)
payload = json.loads(response.text)
assert response.status == 400
assert payload["success"] is False
assert "graph_id" in payload["error"]
@pytest.mark.asyncio
async def test_register_nodes_stores_graph_identifier():
node_registry = NodeRegistry()
handler = NodeRegistryHandler(
node_registry=node_registry,
prompt_server=FakePromptServer,
standalone_mode=False,
)
request = FakeRequest(
json_data={
"nodes": [
{
"node_id": 7,
"graph_id": "graph-123",
"graph_name": "Character Subgraph",
"type": "Lora Loader (LoraManager)",
"title": "Loader",
}
]
}
)
response = await handler.register_nodes(request)
payload = json.loads(response.text)
assert payload["success"] is True
registry = await node_registry.get_registry()
assert registry["node_count"] == 1
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["graph_id"] == "graph-123"
assert stored_node["unique_id"] == "graph-123:7"
assert stored_node["graph_name"] == "Character Subgraph"
@pytest.mark.asyncio
async def test_register_nodes_defaults_graph_name_to_none():
node_registry = NodeRegistry()
handler = NodeRegistryHandler(
node_registry=node_registry,
prompt_server=FakePromptServer,
standalone_mode=False,
)
request = FakeRequest(
json_data={
"nodes": [
{
"node_id": 8,
"graph_id": "root",
"type": "Lora Loader (LoraManager)",
"title": "Root Loader",
}
]
}
)
response = await handler.register_nodes(request)
payload = json.loads(response.text)
assert payload["success"] is True
registry = await node_registry.get_registry()
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["graph_name"] is None
@pytest.mark.asyncio
async def test_update_lora_code_includes_graph_identifier():
send_calls: list[tuple[str, dict]] = []
class RecordingPromptServer:
class Instance:
def send_sync(self, event, payload):
send_calls.append((event, payload))
instance = Instance()
handler = LoraCodeHandler(RecordingPromptServer)
request = FakeRequest(
json_data={
"node_ids": [{"node_id": 3, "graph_id": "graph-A"}],
"lora_code": "<lora>",
"mode": "replace",
}
)
response = await handler.update_lora_code(request)
payload = json.loads(response.text)
assert payload["success"] is True
assert payload["results"] == [
{"node_id": 3, "graph_id": "graph-A", "success": True}
]
assert send_calls == [
(
"lora_code_update",
{"id": 3, "graph_id": "graph-A", "lora_code": "<lora>", "mode": "replace"},
)
]
class FakeScanner:
async def check_model_version_exists(self, _version_id):
return False
@@ -138,10 +267,34 @@ async def fake_scanner_factory():
return FakeScanner()
class FakeExistenceScanner:
def __init__(self, existing=None):
self._existing = set(existing or [])
async def check_model_version_exists(self, version_id):
return version_id in self._existing
async def get_model_versions_by_id(self, _model_id):
return []
class FakeMetadataProvider:
async def get_model_versions(self, _model_id):
return {"modelVersions": [], "name": "", "type": "lora"}
async def get_user_models(self, _username):
return []
class FakeUserModelsProvider(FakeMetadataProvider):
def __init__(self, models):
self.models = models
self.received_usernames: list[str] = []
async def get_user_models(self, username):
self.received_usernames.append(username)
return self.models
async def fake_metadata_provider_factory():
return FakeMetadataProvider()
@@ -211,6 +364,250 @@ async def test_misc_routes_bind_produces_expected_handlers():
assert set(mapping.keys()) == expected_names
@pytest.mark.asyncio
async def test_get_civitai_user_models_marks_library_versions():
models = [
{
"id": 1,
"name": "Model A",
"type": "LORA",
"tags": ["style"],
"modelVersions": [
{
"id": 100,
"name": "v1",
"baseModel": "Flux.1",
"images": [{"url": "http://example.com/a1.jpg"}],
},
{
"id": 101,
"name": "v2",
"baseModel": "Flux.1",
"images": [{"url": "http://example.com/a2.jpg"}],
},
],
},
{
"id": 2,
"name": "Embedding",
"type": "TextualInversion",
"tags": ["embedding"],
"modelVersions": [
{
"id": 200,
"name": "v1",
"baseModel": None,
"images": [{"url": "http://example.com/e1.jpg"}],
},
{
"id": 202,
"name": "v2",
"baseModel": None,
},
],
},
{
"id": 3,
"name": "Checkpoint",
"type": "Checkpoint",
"tags": ["checkpoint"],
"modelVersions": [
{
"id": 300,
"name": "v1",
"baseModel": "SDXL",
"images": [],
}
],
},
{
"id": 4,
"name": "Unsupported",
"type": "Other",
"modelVersions": [
{
"id": 400,
"name": "v1",
}
],
},
]
provider = FakeUserModelsProvider(models)
async def provider_factory():
return provider
lora_scanner = FakeExistenceScanner({101})
checkpoint_scanner = FakeExistenceScanner()
embedding_scanner = FakeExistenceScanner({202})
async def lora_factory():
return lora_scanner
async def checkpoint_factory():
return checkpoint_scanner
async def embedding_factory():
return embedding_scanner
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=lora_factory,
get_checkpoint_scanner=checkpoint_factory,
get_embedding_scanner=embedding_factory,
),
metadata_provider_factory=provider_factory,
)
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
payload = json.loads(response.text)
assert payload["success"] is True
assert payload["username"] == "pixel"
assert payload["versions"] == [
{
"modelId": 1,
"versionId": 100,
"modelName": "Model A",
"versionName": "v1",
"type": "LORA",
"tags": ["style"],
"baseModel": "Flux.1",
"thumbnailUrl": "http://example.com/a1.jpg",
"inLibrary": False,
},
{
"modelId": 1,
"versionId": 101,
"modelName": "Model A",
"versionName": "v2",
"type": "LORA",
"tags": ["style"],
"baseModel": "Flux.1",
"thumbnailUrl": "http://example.com/a2.jpg",
"inLibrary": True,
},
{
"modelId": 2,
"versionId": 200,
"modelName": "Embedding",
"versionName": "v1",
"type": "TextualInversion",
"tags": ["embedding"],
"baseModel": None,
"thumbnailUrl": "http://example.com/e1.jpg",
"inLibrary": False,
},
{
"modelId": 2,
"versionId": 202,
"modelName": "Embedding",
"versionName": "v2",
"type": "TextualInversion",
"tags": ["embedding"],
"baseModel": None,
"thumbnailUrl": None,
"inLibrary": True,
},
{
"modelId": 3,
"versionId": 300,
"modelName": "Checkpoint",
"versionName": "v1",
"type": "Checkpoint",
"tags": ["checkpoint"],
"baseModel": "SDXL",
"thumbnailUrl": None,
"inLibrary": False,
},
]
assert provider.received_usernames == ["pixel"]
@pytest.mark.asyncio
async def test_get_civitai_user_models_rewrites_civitai_previews():
image_url = "https://image.civitai.com/container/example/original=true/sample.jpeg"
video_url = "https://image.civitai.com/container/example/original=true/sample.mp4"
models = [
{
"id": 1,
"name": "Model A",
"type": "LORA",
"tags": ["style"],
"modelVersions": [
{
"id": 100,
"name": "preview-image",
"baseModel": "Flux.1",
"images": [
{"url": image_url, "type": "image"},
],
},
{
"id": 101,
"name": "preview-video",
"baseModel": "Flux.1",
"images": [
{"url": video_url, "type": "video"},
],
},
],
},
]
provider = FakeUserModelsProvider(models)
async def provider_factory():
return provider
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
),
metadata_provider_factory=provider_factory,
)
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
payload = json.loads(response.text)
assert payload["success"] is True
previews_by_version = {item["versionId"]: item["thumbnailUrl"] for item in payload["versions"]}
assert previews_by_version[100] == "https://image.civitai.com/container/example/width=450,optimized=true/sample.jpeg"
assert (
previews_by_version[101]
== "https://image.civitai.com/container/example/transcode=true,width=450,optimized=true/sample.mp4"
)
@pytest.mark.asyncio
async def test_get_civitai_user_models_requires_username():
provider = FakeUserModelsProvider([])
async def provider_factory():
return provider
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
),
metadata_provider_factory=provider_factory,
)
response = await handler.get_civitai_user_models(FakeRequest())
payload = json.loads(response.text)
assert response.status == 400
assert payload["success"] is False
assert "username" in payload["error"].lower()
def test_ensure_handler_mapping_caches_result():
call_records = []

View File

@@ -1,3 +1,4 @@
import copy
from unittest.mock import AsyncMock
import pytest
@@ -169,6 +170,158 @@ async def test_get_model_version_by_version_id(monkeypatch, downloader):
assert result["images"][0]["meta"]["other"] == "keep"
async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypatch, downloader):
requests = []
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [
{
"type": "Model",
"primary": True,
"hashes": {"SHA256": "hash7"},
}
],
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
version_payload = {
"id": 7,
"modelId": 99,
"model": {},
"files": [],
"images": [],
}
async def fake_make_request(method, url, use_auth=True):
requests.append(url)
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return True, copy.deepcopy(version_payload)
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["id"] == 7
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
assert requests[0].endswith("/models/99")
assert requests[1].endswith("/model-versions/7")
async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, downloader):
requests = []
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [
{
"type": "Model",
"primary": True,
"hashes": {"SHA256": "hash7"},
}
],
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
version_payload = {
"id": 7,
"modelId": 99,
"files": [],
"images": [],
}
async def fake_make_request(method, url, use_auth=True):
requests.append(url)
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return False, "boom"
if url.endswith("/model-versions/by-hash/hash7"):
return True, copy.deepcopy(version_payload)
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["id"] == 7
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
assert requests[1].endswith("/model-versions/7")
assert requests[2].endswith("/model-versions/by-hash/hash7")
async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatch, downloader):
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [],
"name": "v1",
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
async def fake_make_request(method, url, use_auth=True):
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return False, "boom"
if "/model-versions/by-hash/" in url:
return False, "boom"
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["modelId"] == 99
assert result["model"]["name"] == "Model"
assert result["model"]["type"] == "LORA"
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
async def test_get_model_version_requires_identifier(monkeypatch, downloader):
client = await CivitaiClient.get_instance()
result = await client.get_model_version()

View File

@@ -8,7 +8,7 @@ import pytest
from py.services.download_manager import DownloadManager
from py.services import download_manager
from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import settings
from py.services.settings_manager import SettingsManager, get_settings_manager
from py.utils.metadata_manager import MetadataManager
@@ -23,7 +23,8 @@ def reset_download_manager():
@pytest.fixture(autouse=True)
def isolate_settings(monkeypatch, tmp_path):
"""Point settings writes at a temporary directory to avoid touching real files."""
default_settings = settings._get_default_settings()
manager = get_settings_manager()
default_settings = manager._get_default_settings()
default_settings.update(
{
"default_lora_root": str(tmp_path),
@@ -37,8 +38,8 @@ def isolate_settings(monkeypatch, tmp_path):
"base_model_path_mappings": {"BaseModel": "MappedModel"},
}
)
monkeypatch.setattr(settings, "settings", default_settings)
monkeypatch.setattr(type(settings), "_save_settings", lambda self: None)
monkeypatch.setattr(manager, "settings", default_settings)
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
@pytest.fixture(autouse=True)
@@ -187,7 +188,7 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata
assert manager._active_downloads[result["download_id"]]["status"] == "completed"
assert captured["relative_path"] == "MappedModel/fantasy"
expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy"
expected_dir = Path(get_settings_manager().get("default_lora_root")) / "MappedModel" / "fantasy"
assert captured["save_dir"] == expected_dir
assert captured["model_type"] == "lora"
assert captured["download_urls"] == [
@@ -393,3 +394,98 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
assert result == {"success": True}
assert [url for url, *_ in dummy_downloader.calls] == download_urls
assert dummy_scanner.calls # ensure cache updated
@pytest.mark.asyncio
async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
manager._active_downloads["dl"] = {}
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
self.preview_nsfw_level = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
metadata = DummyMetadata(target_path)
version_info = {
"images": [
{
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
"type": "image",
"nsfwLevel": 2,
}
]
}
download_urls = ["https://example.invalid/file.safetensors"]
class DummyDownloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
self.memory_calls = 0
async def download_file(self, url, path, progress_callback=None, use_auth=None):
self.file_calls.append((url, path))
if url.endswith(".jpeg"):
Path(path).write_bytes(b"preview")
return True, None
if url.endswith(".safetensors"):
Path(path).write_bytes(b"model")
return True, None
return False, "unexpected url"
async def download_to_memory(self, *_args, **_kwargs):
self.memory_calls += 1
return False, b"", {}
dummy_downloader = DummyDownloader()
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader))
optimize_called = {"value": False}
def fake_optimize_image(**_kwargs):
optimize_called["value"] = True
return b"", {}
monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(fake_optimize_image))
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
result = await manager._execute_download(
download_urls=download_urls,
save_dir=str(save_dir),
metadata=metadata,
version_info=version_info,
relative_path="",
progress_callback=None,
model_type="lora",
download_id="dl",
)
assert result == {"success": True}
preview_urls = [url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg")]
assert any("width=450,optimized=true" in url for url in preview_urls)
assert dummy_downloader.memory_calls == 0
assert optimize_called["value"] is False
assert metadata.preview_url.endswith(".jpeg")
assert metadata.preview_nsfw_level == 2
stored_preview = manager._active_downloads["dl"]["preview_path"]
assert stored_preview.endswith(".jpeg")
assert Path(stored_preview).exists()

View File

@@ -6,7 +6,7 @@ import pytest
from py.services.example_images_cleanup_service import ExampleImagesCleanupService
from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import settings
from py.services.settings_manager import get_settings_manager
class StubScanner:
@@ -21,8 +21,9 @@ class StubScanner:
async def test_cleanup_moves_empty_and_orphaned(tmp_path, monkeypatch):
service = ExampleImagesCleanupService()
previous_path = settings.get('example_images_path')
settings.settings['example_images_path'] = str(tmp_path)
settings_manager = get_settings_manager()
previous_path = settings_manager.get('example_images_path')
settings_manager.settings['example_images_path'] = str(tmp_path)
try:
empty_folder = tmp_path / 'empty_folder'
@@ -64,23 +65,24 @@ async def test_cleanup_moves_empty_and_orphaned(tmp_path, monkeypatch):
finally:
if previous_path is None:
settings.settings.pop('example_images_path', None)
settings_manager.settings.pop('example_images_path', None)
else:
settings.settings['example_images_path'] = previous_path
settings_manager.settings['example_images_path'] = previous_path
@pytest.mark.asyncio
async def test_cleanup_handles_missing_path(monkeypatch):
service = ExampleImagesCleanupService()
previous_path = settings.get('example_images_path')
settings.settings.pop('example_images_path', None)
settings_manager = get_settings_manager()
previous_path = settings_manager.get('example_images_path')
settings_manager.settings.pop('example_images_path', None)
try:
result = await service.cleanup_example_image_folders()
finally:
if previous_path is not None:
settings.settings['example_images_path'] = previous_path
settings_manager.settings['example_images_path'] = previous_path
assert result['success'] is False
assert result['error_code'] == 'path_not_configured'

View File

@@ -7,7 +7,7 @@ from types import SimpleNamespace
import pytest
from py.services.settings_manager import settings
from py.services.settings_manager import SettingsManager, get_settings_manager
from py.utils import example_images_download_manager as download_module
@@ -43,11 +43,15 @@ def _patch_scanner(monkeypatch: pytest.MonkeyPatch, scanner: StubScanner) -> Non
@pytest.mark.usefixtures("tmp_path")
async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPatch, tmp_path):
async def test_start_download_rejects_parallel_runs(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
model = {
"sha256": "abc123",
@@ -106,11 +110,15 @@ async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPa
@pytest.mark.usefixtures("tmp_path")
async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, tmp_path):
async def test_pause_resume_blocks_processing(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
models = [
{
@@ -231,13 +239,17 @@ async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, t
@pytest.mark.usefixtures("tmp_path")
async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatch, tmp_path):
async def test_legacy_folder_migrated_and_skipped(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
monkeypatch.setitem(settings.settings, "active_library", "extra")
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings_manager.settings, "libraries", {"default": {}, "extra": {}})
monkeypatch.setitem(settings_manager.settings, "active_library", "extra")
model_hash = "d" * 64
model_path = tmp_path / "model.safetensors"
@@ -310,13 +322,17 @@ async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatc
@pytest.mark.usefixtures("tmp_path")
async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tmp_path):
async def test_legacy_progress_file_migrates(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
monkeypatch.setitem(settings.settings, "active_library", "extra")
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings_manager.settings, "libraries", {"default": {}, "extra": {}})
monkeypatch.setitem(settings_manager.settings, "active_library", "extra")
model_hash = "e" * 64
model_path = tmp_path / "model-two.safetensors"
@@ -380,20 +396,24 @@ async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tm
@pytest.mark.usefixtures("tmp_path")
async def test_download_remains_in_initial_library(monkeypatch: pytest.MonkeyPatch, tmp_path):
async def test_download_remains_in_initial_library(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings.settings, "libraries", {"LibraryA": {}, "LibraryB": {}})
monkeypatch.setitem(settings.settings, "active_library", "LibraryA")
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings_manager.settings, "libraries", {"LibraryA": {}, "LibraryB": {}})
monkeypatch.setitem(settings_manager.settings, "active_library", "LibraryA")
state = {"active": "LibraryA"}
def fake_get_active_library_name(self):
return state["active"]
monkeypatch.setattr(type(settings), "get_active_library_name", fake_get_active_library_name)
monkeypatch.setattr(SettingsManager, "get_active_library_name", fake_get_active_library_name)
model_hash = "f" * 64
model_path = tmp_path / "example-model.safetensors"
@@ -454,3 +474,7 @@ async def test_download_remains_in_initial_library(monkeypatch: pytest.MonkeyPat
assert (model_dir / "local.txt").exists()
assert not (library_b_root / ".download_progress.json").exists()
assert not (library_b_root / model_hash).exists()
@pytest.fixture
def settings_manager():
return get_settings_manager()

View File

@@ -243,6 +243,7 @@ async def test_initialize_in_background_uses_persisted_cache_without_full_scan(t
cache = await scanner.get_cached_data()
assert len(cache.raw_data) == 1
assert cache.raw_data[0]['file_path'] == normalized
assert cache.version_index[11]['file_path'] == normalized
assert scanner._hash_index.get_path('hash-one') == normalized
@@ -301,6 +302,7 @@ async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch)
assert entry['file_path'] == normalized
assert entry['tags'] == ['alpha']
assert entry['civitai']['trainedWords'] == ['abc']
assert cache.version_index[11]['file_path'] == normalized
assert scanner._hash_index.get_path('hash-one') == normalized
assert scanner._tags_count == {'alpha': 1}
assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
@@ -381,6 +383,66 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch):
assert remaining == 0
@pytest.mark.asyncio
async def test_version_index_tracks_version_ids(tmp_path: Path):
scanner = DummyScanner(tmp_path)
first_path = _normalize_path(tmp_path / 'alpha.txt')
second_path = _normalize_path(tmp_path / 'beta.txt')
first_entry = {
'file_path': first_path,
'file_name': 'alpha',
'model_name': 'alpha',
'folder': '',
'size': 1,
'modified': 1.0,
'sha256': 'hash-alpha',
'tags': [],
'civitai': {'id': 101, 'modelId': 1, 'name': 'alpha'},
}
second_entry = {
'file_path': second_path,
'file_name': 'beta',
'model_name': 'beta',
'folder': '',
'size': 1,
'modified': 1.0,
'sha256': 'hash-beta',
'tags': [],
'civitai': {'id': 202, 'modelId': 2, 'name': 'beta'},
}
hash_index = ModelHashIndex()
hash_index.add_entry('hash-alpha', first_path)
hash_index.add_entry('hash-beta', second_path)
scan_result = CacheBuildResult(
raw_data=[first_entry, second_entry],
hash_index=hash_index,
tags_count={},
excluded_models=[],
)
await scanner._apply_scan_result(scan_result)
cache = await scanner.get_cached_data()
assert cache.version_index[101]['file_path'] == first_path
assert cache.version_index[202]['file_path'] == second_path
assert await scanner.check_model_version_exists(101) is True
assert await scanner.check_model_version_exists('202') is True
assert await scanner.check_model_version_exists(999) is False
removed = await scanner._batch_update_cache_for_deleted_models([first_path])
assert removed is True
cache_after = await scanner.get_cached_data()
assert 101 not in cache_after.version_index
assert await scanner.check_model_version_exists(101) is False
@pytest.mark.asyncio
async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path):
first, _, _ = _create_files(tmp_path)

View File

@@ -0,0 +1,182 @@
from pathlib import Path
from typing import Any
import pytest
from py.services.preview_asset_service import PreviewAssetService
class StubMetadataManager:
async def save_metadata(self, *_args: Any, **_kwargs: Any) -> bool: # pragma: no cover - helper
return True
class RecordingExifUtils:
def __init__(self) -> None:
self.called = False
def optimize_image(self, **kwargs):
self.called = True
return kwargs["image_data"], {}
@pytest.mark.asyncio
async def test_ensure_preview_prefers_rewritten_civitai_image(tmp_path):
metadata_path = tmp_path / "model.metadata.json"
metadata_path.write_text("{}")
local_metadata: dict[str, Any] = {}
class Downloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
self.memory_calls = 0
async def download_file(self, url, path, use_auth=False):
self.file_calls.append((url, path))
if "width=450,optimized=true" in url:
Path(path).write_bytes(b"image-data")
return True, None
return False, "fail"
async def download_to_memory(self, *_args, **_kwargs):
self.memory_calls += 1
return False, b"", {}
downloader = Downloader()
async def downloader_factory():
return downloader
exif_utils = RecordingExifUtils()
service = PreviewAssetService(
metadata_manager=StubMetadataManager(),
downloader_factory=downloader_factory,
exif_utils=exif_utils,
)
images = [
{
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
"type": "image",
"nsfwLevel": 3,
}
]
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
assert downloader.memory_calls == 0
assert exif_utils.called is False
assert len(downloader.file_calls) == 1
assert "width=450,optimized=true" in downloader.file_calls[0][0]
preview_path = Path(local_metadata["preview_url"])
assert preview_path.exists()
assert preview_path.suffix == ".jpeg"
assert local_metadata["preview_nsfw_level"] == 3
@pytest.mark.asyncio
async def test_ensure_preview_falls_back_to_webp_when_rewrite_fails(tmp_path):
metadata_path = tmp_path / "model.metadata.json"
metadata_path.write_text("{}")
local_metadata: dict[str, Any] = {}
class Downloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
self.memory_calls = 0
async def download_file(self, url, path, use_auth=False):
self.file_calls.append((url, path))
return False, "fail"
async def download_to_memory(self, *_args, **_kwargs):
self.memory_calls += 1
return True, b"raw-image", {}
downloader = Downloader()
async def downloader_factory():
return downloader
class ExifUtils:
def __init__(self):
self.calls = 0
def optimize_image(self, **kwargs):
self.calls += 1
return b"webp-data", {}
exif_utils = ExifUtils()
service = PreviewAssetService(
metadata_manager=StubMetadataManager(),
downloader_factory=downloader_factory,
exif_utils=exif_utils,
)
images = [
{
"url": "https://image.civitai.com/container/example/original=true/sample.png",
"type": "image",
"nsfwLevel": 1,
}
]
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
assert downloader.memory_calls == 1
assert exif_utils.calls == 1
preview_path = Path(local_metadata["preview_url"])
assert preview_path.exists()
assert preview_path.suffix == ".webp"
@pytest.mark.asyncio
async def test_ensure_preview_rewrites_civitai_video(tmp_path):
metadata_path = tmp_path / "model.metadata.json"
metadata_path.write_text("{}")
local_metadata: dict[str, Any] = {}
class Downloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
async def download_file(self, url, path, use_auth=False):
self.file_calls.append((url, path))
if "transcode=true,width=450,optimized=true" in url:
Path(path).write_bytes(b"video-data")
return True, None
if url.endswith(".mp4"):
return False, "fail"
return False, "unexpected"
async def download_to_memory(self, *_args, **_kwargs):
pytest.fail("download_to_memory should not be used for video previews")
downloader = Downloader()
async def downloader_factory():
return downloader
service = PreviewAssetService(
metadata_manager=StubMetadataManager(),
downloader_factory=downloader_factory,
exif_utils=RecordingExifUtils(),
)
images = [
{
"url": "https://image.civitai.com/container/example/original=true/sample.mp4",
"type": "video",
"nsfwLevel": 2,
}
]
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
assert len(downloader.file_calls) >= 1
assert any("transcode=true,width=450,optimized=true" in url for url, _ in downloader.file_calls)
preview_path = Path(local_metadata["preview_url"])
assert preview_path.exists()
assert preview_path.suffix == ".mp4"
assert local_metadata["preview_nsfw_level"] == 2

View File

@@ -28,6 +28,7 @@ from py.utils.example_images_processor import (
ExampleImagesImportError,
ExampleImagesValidationError,
)
from py.utils.metadata_manager import MetadataManager
from tests.conftest import MockModelService, MockScanner
@@ -155,7 +156,9 @@ async def test_auto_organize_use_case_rejects_when_running() -> None:
await use_case.execute(file_paths=None, progress_callback=None)
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache(
monkeypatch: pytest.MonkeyPatch,
) -> None:
scanner = MockScanner()
scanner._cache.raw_data = [
{
@@ -170,6 +173,25 @@ async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
settings = StubSettings()
progress = ProgressCollector()
hydration_calls: list[str] = []
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
hydration_calls.append(model_data.get("file_path", ""))
model_data.clear()
model_data.update(
{
"file_path": "model1.safetensors",
"sha256": "hash",
"from_civitai": True,
"model_name": "Demo",
"extra": "value",
"civitai": {"images": [{"url": "existing.png", "type": "image"}]},
}
)
return model_data
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
use_case = BulkMetadataRefreshUseCase(
service=service,
metadata_sync=metadata_sync,
@@ -183,6 +205,9 @@ async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
assert progress.events[0]["status"] == "started"
assert progress.events[-1]["status"] == "completed"
assert metadata_sync.calls
assert metadata_sync.calls[0]["model_data"]["extra"] == "value"
assert scanner._cache.raw_data[0]["extra"] == "value"
assert hydration_calls == ["model1.safetensors"]
assert scanner._cache.resort_calls == 1
@@ -314,4 +339,4 @@ async def test_import_example_images_use_case_propagates_generic_error() -> None
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
with pytest.raises(ExampleImagesImportError):
await use_case.execute(request)
await use_case.execute(request)

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict
import pytest
from py.services.settings_manager import settings
from py.services.settings_manager import get_settings_manager
from py.utils import example_images_download_manager as download_module
@@ -19,19 +19,21 @@ class RecordingWebSocketManager:
@pytest.fixture(autouse=True)
def restore_settings() -> None:
original = settings.settings.copy()
manager = get_settings_manager()
original = manager.settings.copy()
try:
yield
finally:
settings.settings.clear()
settings.settings.update(original)
manager.settings.clear()
manager.settings.update(original)
async def test_start_download_requires_configured_path(monkeypatch: pytest.MonkeyPatch) -> None:
manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())
# Ensure example_images_path is not configured
settings.settings.pop('example_images_path', None)
settings_manager = get_settings_manager()
settings_manager.settings.pop('example_images_path', None)
with pytest.raises(download_module.DownloadConfigurationError) as exc_info:
await manager.start_download({})
@@ -44,9 +46,10 @@ async def test_start_download_requires_configured_path(monkeypatch: pytest.Monke
async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
settings.settings["libraries"] = {"default": {}}
settings.settings["active_library"] = "default"
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
settings_manager.settings["libraries"] = {"default": {}}
settings_manager.settings["active_library"] = "default"
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
@@ -84,9 +87,10 @@ async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.M
async def test_pause_and_resume_flow(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
settings.settings["libraries"] = {"default": {}}
settings.settings["active_library"] = "default"
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
settings_manager.settings["libraries"] = {"default": {}}
settings_manager.settings["active_library"] = "default"
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)

View File

@@ -7,7 +7,7 @@ from typing import Any, Dict
import pytest
from py.services.settings_manager import settings
from py.services.settings_manager import get_settings_manager
from py.utils.example_images_file_manager import ExampleImagesFileManager
@@ -22,16 +22,18 @@ class JsonRequest:
@pytest.fixture(autouse=True)
def restore_settings() -> None:
original = settings.settings.copy()
manager = get_settings_manager()
original = manager.settings.copy()
try:
yield
finally:
settings.settings.clear()
settings.settings.update(original)
manager.settings.clear()
manager.settings.update(original)
async def test_open_folder_requires_existing_model_directory(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
model_hash = "a" * 64
model_folder = tmp_path / model_hash
model_folder.mkdir()
@@ -65,7 +67,8 @@ async def test_open_folder_requires_existing_model_directory(monkeypatch: pytest
async def test_open_folder_rejects_invalid_paths(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
def fake_get_model_folder(_hash):
return str(tmp_path.parent / "outside")
@@ -81,7 +84,8 @@ async def test_open_folder_rejects_invalid_paths(monkeypatch: pytest.MonkeyPatch
async def test_get_files_lists_supported_media(tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
model_hash = "b" * 64
model_folder = tmp_path / model_hash
model_folder.mkdir()
@@ -99,7 +103,8 @@ async def test_get_files_lists_supported_media(tmp_path) -> None:
async def test_has_images_reports_presence(tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
model_hash = "c" * 64
model_folder = tmp_path / model_hash
model_folder.mkdir()

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import json
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List, Tuple
@@ -30,7 +32,23 @@ def patch_metadata_manager(monkeypatch: pytest.MonkeyPatch):
saved.append((path, metadata.copy()))
return True
class SimpleMetadata:
def __init__(self, payload: Dict[str, Any]) -> None:
self._payload = payload
self._unknown_fields: Dict[str, Any] = {}
def to_dict(self) -> Dict[str, Any]:
return self._payload.copy()
async def fake_load(path: str, *_args: Any, **_kwargs: Any):
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
if os.path.exists(metadata_path):
data = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
return SimpleMetadata(data), False
return None, False
monkeypatch.setattr(metadata_module.MetadataManager, "save_metadata", staticmethod(fake_save))
monkeypatch.setattr(metadata_module.MetadataManager, "load_metadata", staticmethod(fake_load))
return saved
@@ -64,10 +82,80 @@ async def test_update_metadata_after_import_enriches_entries(monkeypatch: pytest
assert custom[0]["hasMeta"] is True
assert custom[0]["type"] == "image"
assert patch_metadata_manager[0][0] == str(model_file)
assert Path(patch_metadata_manager[0][0]) == model_file
assert scanner.updates
@pytest.mark.asyncio
async def test_update_metadata_after_import_preserves_existing_metadata(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
patch_metadata_manager,
):
model_hash = "b" * 64
model_file = tmp_path / "preserve.safetensors"
model_file.write_text("content", encoding="utf-8")
metadata_path = tmp_path / "preserve.metadata.json"
existing_payload = {
"model_name": "Example",
"file_path": str(model_file),
"civitai": {
"id": 42,
"modelId": 88,
"name": "Example",
"trainedWords": ["foo"],
"images": [{"url": "https://example.com/default.png", "type": "image"}],
"customImages": [
{"id": "existing-id", "type": "image", "url": "", "nsfwLevel": 0}
],
},
"extraField": "keep-me",
}
metadata_path.write_text(json.dumps(existing_payload), encoding="utf-8")
model_data = {
"sha256": model_hash,
"model_name": "Example",
"file_path": str(model_file),
"civitai": {
"id": 42,
"modelId": 88,
"name": "Example",
"trainedWords": ["foo"],
"customImages": [],
},
}
scanner = StubScanner([model_data])
image_path = tmp_path / "new.png"
image_path.write_bytes(b"fakepng")
monkeypatch.setattr(metadata_module.ExifUtils, "extract_image_metadata", staticmethod(lambda _path: None))
monkeypatch.setattr(metadata_module.MetadataUpdater, "_parse_image_metadata", staticmethod(lambda payload: None))
regular, custom = await metadata_module.MetadataUpdater.update_metadata_after_import(
model_hash,
model_data,
scanner,
[(str(image_path), "new-id")],
)
assert regular == existing_payload["civitai"]["images"]
assert any(entry["id"] == "new-id" for entry in custom)
saved_path, saved_payload = patch_metadata_manager[-1]
assert Path(saved_path) == model_file
assert saved_payload["extraField"] == "keep-me"
assert saved_payload["civitai"]["images"] == existing_payload["civitai"]["images"]
assert saved_payload["civitai"]["trainedWords"] == ["foo"]
assert {entry["id"] for entry in saved_payload["civitai"]["customImages"]} == {"existing-id", "new-id"}
assert scanner.updates
updated_metadata = scanner.updates[-1][2]
assert updated_metadata["civitai"]["images"] == existing_payload["civitai"]["images"]
assert {entry["id"] for entry in updated_metadata["civitai"]["customImages"]} == {"existing-id", "new-id"}
async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.MonkeyPatch, tmp_path):
model_hash = "b" * 64
model_file = tmp_path / "model.safetensors"
@@ -79,6 +167,16 @@ async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.Monke
async def fetch_and_update_model(self, **_kwargs):
return True, None
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
model_data["hydrated"] = True
return model_data
monkeypatch.setattr(
metadata_module.MetadataManager,
"hydrate_model_data",
staticmethod(fake_hydrate),
)
monkeypatch.setattr(metadata_module, "_metadata_sync_service", StubMetadataSync())
result = await metadata_module.MetadataUpdater.refresh_model_metadata(
@@ -89,6 +187,7 @@ async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.Monke
{"refreshed_models": set(), "errors": [], "last_error": None},
)
assert result is True
assert cache_item["hydrated"] is True
async def test_update_metadata_from_local_examples_generates_entries(monkeypatch: pytest.MonkeyPatch, tmp_path):
@@ -112,4 +211,4 @@ async def test_update_metadata_from_local_examples_generates_entries(monkeypatch
str(model_dir),
)
assert success is True
assert model_data["civitai"]["images"]
assert model_data["civitai"]["images"]

View File

@@ -6,7 +6,7 @@ from pathlib import Path
import pytest
from py.services.settings_manager import settings
from py.services.settings_manager import get_settings_manager
from py.utils.example_images_paths import (
ensure_library_root_exists,
get_model_folder,
@@ -18,18 +18,24 @@ from py.utils.example_images_paths import (
@pytest.fixture(autouse=True)
def restore_settings():
original = copy.deepcopy(settings.settings)
manager = get_settings_manager()
original = copy.deepcopy(manager.settings)
try:
yield
finally:
settings.settings.clear()
settings.settings.update(original)
manager.settings.clear()
manager.settings.update(original)
def test_get_model_folder_single_library(tmp_path):
settings.settings['example_images_path'] = str(tmp_path)
settings.settings['libraries'] = {'default': {}}
settings.settings['active_library'] = 'default'
@pytest.fixture
def settings_manager():
return get_settings_manager()
def test_get_model_folder_single_library(tmp_path, settings_manager):
settings_manager.settings['example_images_path'] = str(tmp_path)
settings_manager.settings['libraries'] = {'default': {}}
settings_manager.settings['active_library'] = 'default'
model_hash = 'a' * 64
folder = get_model_folder(model_hash)
@@ -39,13 +45,13 @@ def test_get_model_folder_single_library(tmp_path):
assert relative == model_hash
def test_get_model_folder_multi_library(tmp_path):
settings.settings['example_images_path'] = str(tmp_path)
settings.settings['libraries'] = {
def test_get_model_folder_multi_library(tmp_path, settings_manager):
settings_manager.settings['example_images_path'] = str(tmp_path)
settings_manager.settings['libraries'] = {
'default': {},
'Alt Library': {},
}
settings.settings['active_library'] = 'Alt Library'
settings_manager.settings['active_library'] = 'Alt Library'
model_hash = 'b' * 64
expected_folder = tmp_path / 'Alt_Library' / model_hash
@@ -57,13 +63,13 @@ def test_get_model_folder_multi_library(tmp_path):
assert relative == os.path.join('Alt_Library', model_hash).replace('\\', '/')
def test_get_model_folder_migrates_legacy_structure(tmp_path):
settings.settings['example_images_path'] = str(tmp_path)
settings.settings['libraries'] = {
def test_get_model_folder_migrates_legacy_structure(tmp_path, settings_manager):
settings_manager.settings['example_images_path'] = str(tmp_path)
settings_manager.settings['libraries'] = {
'default': {},
'extra': {},
}
settings.settings['active_library'] = 'extra'
settings_manager.settings['active_library'] = 'extra'
model_hash = 'c' * 64
legacy_folder = tmp_path / model_hash
@@ -82,31 +88,31 @@ def test_get_model_folder_migrates_legacy_structure(tmp_path):
assert (expected_folder / 'image.png').exists()
def test_ensure_library_root_exists_creates_directories(tmp_path):
settings.settings['example_images_path'] = str(tmp_path)
settings.settings['libraries'] = {'default': {}, 'secondary': {}}
settings.settings['active_library'] = 'secondary'
def test_ensure_library_root_exists_creates_directories(tmp_path, settings_manager):
settings_manager.settings['example_images_path'] = str(tmp_path)
settings_manager.settings['libraries'] = {'default': {}, 'secondary': {}}
settings_manager.settings['active_library'] = 'secondary'
resolved = ensure_library_root_exists('secondary')
assert Path(resolved) == tmp_path / 'secondary'
assert (tmp_path / 'secondary').is_dir()
def test_iter_library_roots_returns_all_configured(tmp_path):
settings.settings['example_images_path'] = str(tmp_path)
settings.settings['libraries'] = {'default': {}, 'alt': {}}
settings.settings['active_library'] = 'alt'
def test_iter_library_roots_returns_all_configured(tmp_path, settings_manager):
settings_manager.settings['example_images_path'] = str(tmp_path)
settings_manager.settings['libraries'] = {'default': {}, 'alt': {}}
settings_manager.settings['active_library'] = 'alt'
roots = dict(iter_library_roots())
assert roots['default'] == str(tmp_path / 'default')
assert roots['alt'] == str(tmp_path / 'alt')
def test_is_valid_example_images_root_accepts_hash_directories(tmp_path):
settings.settings['example_images_path'] = str(tmp_path)
def test_is_valid_example_images_root_accepts_hash_directories(tmp_path, settings_manager):
settings_manager.settings['example_images_path'] = str(tmp_path)
# Ensure single-library mode (not multi-library mode)
settings.settings['libraries'] = {'default': {}}
settings.settings['active_library'] = 'default'
settings_manager.settings['libraries'] = {'default': {}}
settings_manager.settings['active_library'] = 'default'
hash_folder = tmp_path / ('d' * 64)
hash_folder.mkdir()

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import json
import os
from pathlib import Path
from types import SimpleNamespace
@@ -7,18 +8,42 @@ from typing import Any, Dict, Tuple
import pytest
from py.services.settings_manager import settings
from py.services.settings_manager import get_settings_manager
from py.utils import example_images_metadata as metadata_module
from py.utils import example_images_processor as processor_module
from py.utils.example_images_paths import get_model_folder
@pytest.fixture(autouse=True)
def restore_settings() -> None:
original = settings.settings.copy()
manager = get_settings_manager()
original = manager.settings.copy()
try:
yield
finally:
settings.settings.clear()
settings.settings.update(original)
manager.settings.clear()
manager.settings.update(original)
@pytest.fixture(autouse=True)
def patch_metadata_loader(monkeypatch: pytest.MonkeyPatch) -> None:
class SimpleMetadata:
def __init__(self, payload: Dict[str, Any]) -> None:
self._payload = payload
self._unknown_fields: Dict[str, Any] = {}
def to_dict(self) -> Dict[str, Any]:
return self._payload.copy()
async def fake_load(path: str, *_args: Any, **_kwargs: Any):
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
if os.path.exists(metadata_path):
data = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
return SimpleMetadata(data), False
return None, False
monkeypatch.setattr(processor_module.MetadataManager, "load_metadata", staticmethod(fake_load))
monkeypatch.setattr(metadata_module.MetadataManager, "load_metadata", staticmethod(fake_load))
def test_get_file_extension_from_magic_bytes() -> None:
@@ -90,9 +115,10 @@ def stub_scanners(monkeypatch: pytest.MonkeyPatch, tmp_path) -> StubScanner:
async def test_import_images_creates_hash_directory(monkeypatch: pytest.MonkeyPatch, tmp_path, stub_scanners: StubScanner) -> None:
settings.settings["example_images_path"] = str(tmp_path / "examples")
settings.settings["libraries"] = {"default": {}}
settings.settings["active_library"] = "default"
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path / "examples")
settings_manager.settings["libraries"] = {"default": {}}
settings_manager.settings["active_library"] = "default"
source_file = tmp_path / "upload.png"
source_file.write_bytes(b"PNG data")
@@ -112,7 +138,7 @@ async def test_import_images_creates_hash_directory(monkeypatch: pytest.MonkeyPa
assert result["success"] is True
assert result["files"][0]["name"].startswith("custom_short")
model_folder = Path(settings.settings["example_images_path"]) / ("a" * 64)
model_folder = Path(settings_manager.settings["example_images_path"]) / ("a" * 64)
assert model_folder.exists()
created_files = list(model_folder.glob("custom_short*.png"))
assert len(created_files) == 1
@@ -132,7 +158,8 @@ async def test_import_images_rejects_missing_parameters(monkeypatch: pytest.Monk
async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path)
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
async def _empty_scanner(cls=None):
return StubScanner([])
@@ -143,3 +170,88 @@ async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.Mon
with pytest.raises(processor_module.ExampleImagesImportError):
await processor_module.ExampleImagesProcessor.import_images("a" * 64, [str(tmp_path / "missing.png")])
@pytest.mark.asyncio
async def test_delete_custom_image_preserves_existing_metadata(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path / "examples")
model_hash = "c" * 64
model_file = tmp_path / "keep.safetensors"
model_file.write_text("content", encoding="utf-8")
metadata_path = tmp_path / "keep.metadata.json"
existing_metadata = {
"model_name": "Keep",
"file_path": str(model_file),
"civitai": {
"images": [{"url": "https://example.com/default.png", "type": "image"}],
"customImages": [{"id": "existing-id", "url": "", "type": "image"}],
"trainedWords": ["foo"],
},
}
metadata_path.write_text(json.dumps(existing_metadata), encoding="utf-8")
model_data = {
"sha256": model_hash,
"model_name": "Keep",
"file_path": str(model_file),
"civitai": {
"customImages": [{"id": "existing-id", "url": "", "type": "image"}],
"trainedWords": ["foo"],
},
}
class Scanner(StubScanner):
def has_hash(self, hash_value: str) -> bool:
return hash_value == model_hash
scanner = Scanner([model_data])
async def _return_scanner(cls=None):
return scanner
monkeypatch.setattr(processor_module.ServiceRegistry, "get_lora_scanner", classmethod(_return_scanner))
monkeypatch.setattr(processor_module.ServiceRegistry, "get_checkpoint_scanner", classmethod(_return_scanner))
monkeypatch.setattr(processor_module.ServiceRegistry, "get_embedding_scanner", classmethod(_return_scanner))
model_folder = get_model_folder(model_hash)
os.makedirs(model_folder, exist_ok=True)
(Path(model_folder) / "custom_existing-id.png").write_bytes(b"data")
saved: list[tuple[str, Dict[str, Any]]] = []
async def fake_save(path: str, payload: Dict[str, Any]) -> bool:
saved.append((path, payload.copy()))
return True
monkeypatch.setattr(processor_module.MetadataManager, "save_metadata", staticmethod(fake_save))
class StubRequest:
def __init__(self, payload: Dict[str, Any]) -> None:
self._payload = payload
async def json(self) -> Dict[str, Any]:
return self._payload
response = await processor_module.ExampleImagesProcessor.delete_custom_image(
StubRequest({"model_hash": model_hash, "short_id": "existing-id"})
)
assert response.status == 200
body = json.loads(response.text)
assert body["success"] is True
assert body["custom_images"] == []
assert not (Path(model_folder) / "custom_existing-id.png").exists()
saved_path, saved_payload = saved[-1]
assert saved_path == str(model_file)
assert saved_payload["civitai"]["images"] == existing_metadata["civitai"]["images"]
assert saved_payload["civitai"]["trainedWords"] == ["foo"]
assert saved_payload["civitai"]["customImages"] == []
assert scanner.updated
_, _, updated_metadata = scanner.updated[-1]
assert updated_metadata["civitai"]["images"] == existing_metadata["civitai"]["images"]
assert updated_metadata["civitai"]["customImages"] == []

View File

@@ -1,6 +1,6 @@
import pytest
from py.services.settings_manager import settings
from py.services.settings_manager import SettingsManager, get_settings_manager
from py.utils.utils import (
calculate_recipe_fingerprint,
calculate_relative_path_for_model,
@@ -9,7 +9,8 @@ from py.utils.utils import (
@pytest.fixture
def isolated_settings(monkeypatch):
default_settings = settings._get_default_settings()
manager = get_settings_manager()
default_settings = manager._get_default_settings()
default_settings.update(
{
"download_path_templates": {
@@ -20,8 +21,8 @@ def isolated_settings(monkeypatch):
"base_model_path_mappings": {},
}
)
monkeypatch.setattr(settings, "settings", default_settings)
monkeypatch.setattr(type(settings), "_save_settings", lambda self: None)
monkeypatch.setattr(manager, "settings", default_settings)
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
return default_settings

View File

@@ -1,6 +1,7 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import { addJsonDisplayWidget } from "./json_display_widget.js";
import { getNodeFromGraph } from "./utils.js";
app.registerExtension({
name: "LoraManager.DebugMetadata",
@@ -8,8 +9,8 @@ app.registerExtension({
setup() {
// Add message handler to listen for metadata updates from Python
api.addEventListener("metadata_update", (event) => {
const { id, metadata } = event.detail;
this.handleMetadataUpdate(id, metadata);
const { id, graph_id: graphId, metadata } = event.detail;
this.handleMetadataUpdate(id, graphId, metadata);
});
},
@@ -37,8 +38,8 @@ app.registerExtension({
},
// Handle metadata updates from Python
handleMetadataUpdate(id, metadata) {
const node = app.graph.getNodeById(+id);
handleMetadataUpdate(id, graphId, metadata) {
const node = getNodeFromGraph(graphId, id);
if (!node || node.comfyClass !== "Debug Metadata (LoraManager)") {
console.warn("Node not found or not a DebugMetadata node:", id);
return;

View File

@@ -7,6 +7,8 @@ import {
chainCallback,
mergeLoras,
setupInputWidgetWithAutocomplete,
getAllGraphNodes,
getNodeFromGraph,
} from "./utils.js";
import { addLorasWidget } from "./loras_widget.js";
@@ -16,23 +18,26 @@ app.registerExtension({
setup() {
// Add message handler to listen for messages from Python
api.addEventListener("lora_code_update", (event) => {
const { id, lora_code, mode } = event.detail;
this.handleLoraCodeUpdate(id, lora_code, mode);
this.handleLoraCodeUpdate(event.detail || {});
});
},
// Handle lora code updates from Python
handleLoraCodeUpdate(id, loraCode, mode) {
handleLoraCodeUpdate(message) {
const nodeId = message?.node_id ?? message?.id;
const graphId = message?.graph_id;
const loraCode = message?.lora_code ?? "";
const mode = message?.mode ?? "append";
const numericNodeId =
typeof nodeId === "string" ? Number(nodeId) : nodeId;
// Handle broadcast mode (for Desktop/non-browser support)
if (id === -1) {
if (numericNodeId === -1) {
// Find all Lora Loader nodes in the current graph
const loraLoaderNodes = [];
for (const nodeId in app.graph._nodes_by_id) {
const node = app.graph._nodes_by_id[nodeId];
if (node.comfyClass === "Lora Loader (LoraManager)") {
loraLoaderNodes.push(node);
}
}
const loraLoaderNodes = getAllGraphNodes(app.graph)
.map(({ node }) => node)
.filter((node) => node?.comfyClass === "Lora Loader (LoraManager)");
// Update each Lora Loader node found
if (loraLoaderNodes.length > 0) {
@@ -52,14 +57,18 @@ app.registerExtension({
}
// Standard mode - update a specific node
const node = app.graph.getNodeById(+id);
const node = getNodeFromGraph(graphId, numericNodeId);
if (
!node ||
(node.comfyClass !== "Lora Loader (LoraManager)" &&
node.comfyClass !== "Lora Stacker (LoraManager)" &&
node.comfyClass !== "WanVideo Lora Select (LoraManager)")
) {
console.warn("Node not found or not a LoraLoader:", id);
console.warn(
"Node not found or not a LoraLoader:",
graphId ?? "root",
nodeId
);
return;
}

View File

@@ -7,6 +7,8 @@ import {
chainCallback,
mergeLoras,
setupInputWidgetWithAutocomplete,
getLinkFromGraph,
getNodeKey,
} from "./utils.js";
import { addLorasWidget } from "./loras_widget.js";
@@ -124,17 +126,18 @@ app.registerExtension({
// Helper function to find and update downstream Lora Loader nodes
function updateDownstreamLoaders(startNode, visited = new Set()) {
if (visited.has(startNode.id)) return;
visited.add(startNode.id);
const nodeKey = getNodeKey(startNode);
if (!nodeKey || visited.has(nodeKey)) return;
visited.add(nodeKey);
// Check each output link
if (startNode.outputs) {
for (const output of startNode.outputs) {
if (output.links) {
for (const linkId of output.links) {
const link = app.graph.links[linkId];
const link = getLinkFromGraph(startNode.graph, linkId);
if (link) {
const targetNode = app.graph.getNodeById(link.target_id);
const targetNode = startNode.graph?.getNodeById?.(link.target_id);
// If target is a Lora Loader, collect all active loras in the chain and update
if (

View File

@@ -1,6 +1,6 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import { CONVERTED_TYPE } from "./utils.js";
import { CONVERTED_TYPE, getNodeFromGraph } from "./utils.js";
import { addTagsWidget } from "./tags_widget.js";
// TriggerWordToggle extension for ComfyUI
@@ -10,8 +10,8 @@ app.registerExtension({
setup() {
// Add message handler to listen for messages from Python
api.addEventListener("trigger_word_update", (event) => {
const { id, message } = event.detail;
this.handleTriggerWordUpdate(id, message);
const { id, graph_id: graphId, message } = event.detail;
this.handleTriggerWordUpdate(id, graphId, message);
});
},
@@ -76,8 +76,8 @@ app.registerExtension({
},
// Handle trigger word updates from Python
handleTriggerWordUpdate(id, message) {
const node = app.graph.getNodeById(+id);
handleTriggerWordUpdate(id, graphId, message) {
const node = getNodeFromGraph(graphId, id);
if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") {
console.warn("Node not found or not a TriggerWordToggle:", id);
return;

View File

@@ -1,7 +1,7 @@
// ComfyUI extension to track model usage statistics
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import { showToast } from "./utils.js";
import { getAllGraphNodes, getNodeReference, showToast } from "./utils.js";
// Define target nodes and their widget configurations
const PATH_CORRECTION_TARGETS = [
@@ -56,25 +56,35 @@ app.registerExtension({
async refreshRegistry() {
try {
// Get current workflow nodes
const prompt = await app.graphToPrompt();
const workflow = prompt.workflow;
if (!workflow || !workflow.nodes) {
console.warn("No workflow nodes found for registry refresh");
return;
}
// Find all Lora nodes
const loraNodes = [];
for (const node of workflow.nodes.values()) {
if (node.type === "Lora Loader (LoraManager)" ||
node.type === "Lora Stacker (LoraManager)" ||
node.type === "WanVideo Lora Select (LoraManager)") {
const nodeEntries = getAllGraphNodes(app.graph);
for (const { graph, node } of nodeEntries) {
if (!node || !node.comfyClass) {
continue;
}
if (
node.comfyClass === "Lora Loader (LoraManager)" ||
node.comfyClass === "Lora Stacker (LoraManager)" ||
node.comfyClass === "WanVideo Lora Select (LoraManager)"
) {
const reference = getNodeReference(node);
if (!reference) {
continue;
}
const graphName = typeof graph?.name === "string" && graph.name.trim()
? graph.name
: null;
loraNodes.push({
node_id: node.id,
bgcolor: node.bgcolor || null,
title: node.title || node.type,
type: node.type
node_id: reference.node_id,
graph_id: reference.graph_id,
graph_name: graphName,
bgcolor: node.bgcolor ?? node.color ?? null,
title: node.title || node.comfyClass,
type: node.comfyClass,
});
}
}

View File

@@ -2,6 +2,120 @@ export const CONVERTED_TYPE = 'converted-widget';
import { app } from "../../scripts/app.js";
import { AutoComplete } from "./autocomplete.js";
const ROOT_GRAPH_ID = "root";
// Returns true when `collection` exposes Map-style `entries()`/`values()`
// iterators (Map, Set, and arrays all qualify). Used to decide how to read
// graph collections that may be either Map-like or plain objects.
// Fix: previously returned the raw falsy operand (e.g. `null`/`undefined`)
// instead of a boolean; coerce with Boolean() so callers always get true/false.
function isMapLike(collection) {
  return Boolean(
    collection &&
      typeof collection.entries === "function" &&
      typeof collection.values === "function"
  );
}
// Collects the child graphs nested under `graph` via its `_subgraphs`
// collection, which may be Map-like or a plain object. Each entry may wrap
// its graph in `.graph` or `._graph`; entries resolving to the parent graph
// itself are dropped to avoid self-loops.
function getChildGraphs(graph) {
  const subgraphs = graph?._subgraphs;
  if (!subgraphs) {
    return [];
  }
  const entries = isMapLike(subgraphs)
    ? Array.from(subgraphs.values())
    : Object.values(subgraphs);
  const resolved = [];
  for (const entry of entries) {
    const candidate = entry?.graph || entry?._graph || entry;
    if (candidate && candidate !== graph) {
      resolved.push(candidate);
    }
  }
  return resolved;
}
// Depth-first walk over a graph and all of its nested subgraphs, invoking
// `visitor(graph)` once per graph. Falls back to the app's root graph when
// no starting graph is given. Cycles are guarded via a visited set keyed by
// graph id; note that graphs lacking an `id` all share the root id, so at
// most one id-less graph is visited per traversal.
function traverseGraphs(rootGraph, visitor, visited = new Set()) {
  const current = rootGraph || app.graph;
  if (!current) {
    return;
  }
  const currentId = getGraphId(current);
  if (visited.has(currentId)) {
    return;
  }
  visited.add(currentId);
  visitor(current);
  getChildGraphs(current).forEach((child) =>
    traverseGraphs(child, visitor, visited)
  );
}
// Stable identifier for a graph: its own `id` when present, otherwise the
// shared root-graph sentinel.
export function getGraphId(graph) {
  const id = graph?.id;
  return id ?? ROOT_GRAPH_ID;
}
// Identifier of the graph that owns `node`. Nodes without a graph reference
// (or a missing node) are attributed to the app's root graph.
export function getNodeGraphId(node) {
  if (!node) {
    return ROOT_GRAPH_ID;
  }
  const owningGraph = node.graph || app.graph;
  return getGraphId(owningGraph);
}
// Locates the graph with the given id by walking `rootGraph` and its
// subgraphs. A falsy id short-circuits to the root graph; an unknown id
// yields null. The first match wins.
export function getGraphById(graphId, rootGraph = app.graph) {
  if (!graphId) {
    return rootGraph;
  }
  let match = null;
  traverseGraphs(rootGraph, (candidate) => {
    if (match === null && getGraphId(candidate) === graphId) {
      match = candidate;
    }
  });
  return match;
}
// Resolves a node by id inside the graph identified by `graphId` (falling
// back to the root graph). String node ids are converted to numbers when
// they parse cleanly; otherwise the raw value is used as-is. Returns null
// when the graph cannot look up nodes or the node is absent.
export function getNodeFromGraph(graphId, nodeId) {
  const graph = getGraphById(graphId) || app.graph;
  if (typeof graph?.getNodeById !== "function") {
    return null;
  }
  let lookupId = nodeId;
  if (typeof nodeId === "string") {
    const parsed = Number(nodeId);
    if (!Number.isNaN(parsed)) {
      lookupId = parsed;
    }
  }
  return graph.getNodeById(lookupId) || null;
}
// Flattens every node from `rootGraph` and all nested subgraphs into a
// single list of `{ graph, node }` pairs so callers can act on nodes across
// subgraph boundaries.
export function getAllGraphNodes(rootGraph = app.graph) {
  const collected = [];
  traverseGraphs(rootGraph, (graph) => {
    const graphNodes = graph._nodes;
    if (!Array.isArray(graphNodes)) {
      return;
    }
    for (const node of graphNodes) {
      collected.push({ graph, node });
    }
  });
  return collected;
}
// Serializable reference for a node — its id plus the id of its owning
// graph — matching the `{node_id, graph_id}` shape the Python side expects.
// Returns null for a missing node.
export function getNodeReference(node) {
  return node
    ? { node_id: node.id, graph_id: getNodeGraphId(node) }
    : null;
}
// Graph-qualified key ("graphId:nodeId") so nodes with identical ids in
// different subgraphs remain distinguishable (e.g. in visited sets).
// Returns null for a missing node.
export function getNodeKey(node) {
  if (!node) {
    return null;
  }
  const graphId = getNodeGraphId(node);
  return `${graphId}:${node.id}`;
}
// Fetches a link record from a graph's `links` collection, which may be
// Map-like or a plain object keyed by link id. Returns null when the graph,
// its links collection, or the specific link is absent.
export function getLinkFromGraph(graph, linkId) {
  const links = graph?.links;
  if (links == null) {
    return null;
  }
  const link = isMapLike(links) ? links.get(linkId) : links[linkId];
  return link || null;
}
export function chainCallback(object, property, callback) {
if (object == undefined) {
//This should not happen.
@@ -103,42 +217,56 @@ export const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)(?::([-\d\.]+))?>/g;
// Get connected Lora Stacker nodes that feed into the current node
export function getConnectedInputStackers(node) {
const connectedStackers = [];
if (node.inputs) {
for (const input of node.inputs) {
if (input.name === "lora_stack" && input.link) {
const link = app.graph.links[input.link];
if (link) {
const sourceNode = app.graph.getNodeById(link.origin_id);
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
connectedStackers.push(sourceNode);
}
}
}
if (!node?.inputs) {
return connectedStackers;
}
for (const input of node.inputs) {
if (input.name !== "lora_stack" || !input.link) {
continue;
}
const link = getLinkFromGraph(node.graph, input.link);
if (!link) {
continue;
}
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
connectedStackers.push(sourceNode);
}
}
return connectedStackers;
}
// Get connected TriggerWord Toggle nodes that receive output from the current node
export function getConnectedTriggerToggleNodes(node) {
const connectedNodes = [];
if (node.outputs && node.outputs.length > 0) {
for (const output of node.outputs) {
if (output.links && output.links.length > 0) {
for (const linkId of output.links) {
const link = app.graph.links[linkId];
if (link) {
const targetNode = app.graph.getNodeById(link.target_id);
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
connectedNodes.push(targetNode.id);
}
}
}
if (!node?.outputs) {
return connectedNodes;
}
for (const output of node.outputs) {
if (!output?.links?.length) {
continue;
}
for (const linkId of output.links) {
const link = getLinkFromGraph(node.graph, linkId);
if (!link) {
continue;
}
const targetNode = node.graph?.getNodeById?.(link.target_id);
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
connectedNodes.push(targetNode);
}
}
}
return connectedNodes;
}
@@ -161,11 +289,15 @@ export function getActiveLorasFromNode(node) {
// Recursively collect all active loras from a node and its input chain
export function collectActiveLorasFromChain(node, visited = new Set()) {
// Prevent infinite loops from circular references
if (visited.has(node.id)) {
const nodeKey = getNodeKey(node);
if (!nodeKey) {
return new Set();
}
visited.add(node.id);
if (visited.has(nodeKey)) {
return new Set();
}
visited.add(nodeKey);
// Get active loras from current node
const allActiveLoraNames = getActiveLorasFromNode(node);
@@ -181,14 +313,22 @@ export function collectActiveLorasFromChain(node, visited = new Set()) {
// Update trigger words for connected toggle nodes
export function updateConnectedTriggerWords(node, loraNames) {
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
if (connectedNodeIds.length > 0) {
const connectedNodes = getConnectedTriggerToggleNodes(node);
if (connectedNodes.length > 0) {
const nodeIds = connectedNodes
.map((connectedNode) => getNodeReference(connectedNode))
.filter((reference) => reference !== null);
if (nodeIds.length === 0) {
return;
}
fetch("/api/lm/loras/get_trigger_words", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
lora_names: Array.from(loraNames),
node_ids: connectedNodeIds
node_ids: nodeIds
})
}).catch(err => console.error("Error fetching trigger words:", err));
}