Compare commits

...

20 Commits

Author SHA1 Message Date
Will Miao
76d3aa2b5b feat(version): bump version to 0.9.7 in pyproject.toml 2025-10-09 22:15:50 +08:00
Will Miao
c9a65c7347 feat(metadata): implement model data hydration and enhance metadata handling across services, fixes #547 2025-10-09 22:15:07 +08:00
Will Miao
f542ade628 feat(civitai): implement URL rewriting for Civitai previews and enhance download handling, fixes #499 2025-10-09 17:54:37 +08:00
Will Miao
d2c2bfbe6a feat(sidebar): add recursive search functionality and toggle button 2025-10-09 17:07:10 +08:00
Will Miao
2b6910bd55 feat(misc): mark model versions in library for Civitai user models 2025-10-09 15:23:42 +08:00
Will Miao
b1dd733493 feat(civitai): enhance model version handling with cache lookup 2025-10-09 14:10:00 +08:00
pixelpaws
5dcf0a1e48 Merge pull request #545 from willmiao/codex/evaluate-sqlite-cache-indexing-necessity
feat: index cached models by version id
2025-10-09 13:54:46 +08:00
pixelpaws
cf357b57fc feat(scanner): index cached models by version id 2025-10-09 13:50:44 +08:00
pixelpaws
4e1773833f Merge pull request #544 from willmiao/codex/add-endpoint-to-fetch-civitai-user-models
Add endpoint to fetch Civitai user models
2025-10-09 11:56:57 +08:00
pixelpaws
8cf762ffd3 feat(misc): add civitai user model lookup 2025-10-09 11:49:41 +08:00
pixelpaws
d997eaa429 Merge pull request #543 from willmiao/codex/refactor-get_model_version-logic-and-add-tests, fixes #540
fix: improve Civitai model version retrieval
2025-10-09 11:07:33 +08:00
pixelpaws
8e51f0f19f fix(civitai): improve model version retrieval 2025-10-09 10:56:25 +08:00
pixelpaws
f0e246b4ac Merge pull request #542 from willmiao/codex/investigate-backend-tests-modifying-settings.json
Refactor settings manager to lazy singleton
2025-10-08 16:02:11 +08:00
pixelpaws
a232997a79 fix(utils): respect metadata sync overrides 2025-10-08 15:52:15 +08:00
pixelpaws
08a449db99 fix(metadata): refresh metadata sync settings 2025-10-08 10:38:05 +08:00
pixelpaws
0c023c9888 fix(settings): lazily resolve module aliases 2025-10-08 10:10:23 +08:00
pixelpaws
0ad92d00b3 fix(settings): restore legacy settings aliases 2025-10-08 09:54:36 +08:00
pixelpaws
a726cbea1e fix(routes): pass resolved settings to metadata sync 2025-10-08 09:32:57 +08:00
pixelpaws
c53fa8692b refactor(settings): lazily initialize manager 2025-10-08 08:56:57 +08:00
Will Miao
3118f3b43c feat(graph): enhance node handling with graph identifiers and improve metadata updates, see #408, #538 2025-10-07 23:22:38 +08:00
74 changed files with 2944 additions and 568 deletions

View File

@@ -529,12 +529,15 @@
"title": "Embedding-Modelle" "title": "Embedding-Modelle"
}, },
"sidebar": { "sidebar": {
"modelRoot": "Modell-Stammverzeichnis", "modelRoot": "Stammverzeichnis",
"collapseAll": "Alle Ordner einklappen", "collapseAll": "Alle Ordner einklappen",
"pinSidebar": "Sidebar anheften", "pinSidebar": "Sidebar anheften",
"unpinSidebar": "Sidebar lösen", "unpinSidebar": "Sidebar lösen",
"switchToListView": "Zur Listenansicht wechseln", "switchToListView": "Zur Listenansicht wechseln",
"switchToTreeView": "Zur Baumansicht wechseln", "switchToTreeView": "Zur Baumansicht wechseln",
"recursiveOn": "Unterordner durchsuchen",
"recursiveOff": "Nur aktuellen Ordner durchsuchen",
"recursiveUnavailable": "Rekursive Suche ist nur in der Baumansicht verfügbar",
"collapseAllDisabled": "Im Listenmodus nicht verfügbar" "collapseAllDisabled": "Im Listenmodus nicht verfügbar"
}, },
"statistics": { "statistics": {

View File

@@ -529,12 +529,15 @@
"title": "Embedding Models" "title": "Embedding Models"
}, },
"sidebar": { "sidebar": {
"modelRoot": "Model Root", "modelRoot": "Root",
"collapseAll": "Collapse All Folders", "collapseAll": "Collapse All Folders",
"pinSidebar": "Pin Sidebar", "pinSidebar": "Pin Sidebar",
"unpinSidebar": "Unpin Sidebar", "unpinSidebar": "Unpin Sidebar",
"switchToListView": "Switch to List View", "switchToListView": "Switch to List View",
"switchToTreeView": "Switch to Tree View", "switchToTreeView": "Switch to Tree View",
"recursiveOn": "Search subfolders",
"recursiveOff": "Search current folder only",
"recursiveUnavailable": "Recursive search is available in tree view only",
"collapseAllDisabled": "Not available in list view" "collapseAllDisabled": "Not available in list view"
}, },
"statistics": { "statistics": {

View File

@@ -529,12 +529,15 @@
"title": "Modelos embedding" "title": "Modelos embedding"
}, },
"sidebar": { "sidebar": {
"modelRoot": "Raíz del modelo", "modelRoot": "Raíz",
"collapseAll": "Colapsar todas las carpetas", "collapseAll": "Colapsar todas las carpetas",
"pinSidebar": "Fijar barra lateral", "pinSidebar": "Fijar barra lateral",
"unpinSidebar": "Desfijar barra lateral", "unpinSidebar": "Desfijar barra lateral",
"switchToListView": "Cambiar a vista de lista", "switchToListView": "Cambiar a vista de lista",
"switchToTreeView": "Cambiar a vista de árbol", "switchToTreeView": "Cambiar a vista de árbol",
"recursiveOn": "Buscar en subcarpetas",
"recursiveOff": "Buscar solo en la carpeta actual",
"recursiveUnavailable": "La búsqueda recursiva solo está disponible en la vista en árbol",
"collapseAllDisabled": "No disponible en vista de lista" "collapseAllDisabled": "No disponible en vista de lista"
}, },
"statistics": { "statistics": {

View File

@@ -529,12 +529,15 @@
"title": "Modèles Embedding" "title": "Modèles Embedding"
}, },
"sidebar": { "sidebar": {
"modelRoot": "Racine du modèle", "modelRoot": "Racine",
"collapseAll": "Réduire tous les dossiers", "collapseAll": "Réduire tous les dossiers",
"pinSidebar": "Épingler la barre latérale", "pinSidebar": "Épingler la barre latérale",
"unpinSidebar": "Désépingler la barre latérale", "unpinSidebar": "Désépingler la barre latérale",
"switchToListView": "Passer en vue liste", "switchToListView": "Passer en vue liste",
"switchToTreeView": "Passer en vue arborescence", "switchToTreeView": "Passer en vue arborescence",
"recursiveOn": "Rechercher dans les sous-dossiers",
"recursiveOff": "Rechercher uniquement dans le dossier actuel",
"recursiveUnavailable": "La recherche récursive n'est disponible qu'en vue arborescente",
"collapseAllDisabled": "Non disponible en vue liste" "collapseAllDisabled": "Non disponible en vue liste"
}, },
"statistics": { "statistics": {

View File

@@ -529,12 +529,15 @@
"title": "מודלי Embedding" "title": "מודלי Embedding"
}, },
"sidebar": { "sidebar": {
"modelRoot": "שורש המודלים", "modelRoot": "שורש",
"collapseAll": "כווץ את כל התיקיות", "collapseAll": "כווץ את כל התיקיות",
"pinSidebar": "נעל סרגל צד", "pinSidebar": "נעל סרגל צד",
"unpinSidebar": "שחרר סרגל צד", "unpinSidebar": "שחרר סרגל צד",
"switchToListView": "עבור לתצוגת רשימה", "switchToListView": "עבור לתצוגת רשימה",
"switchToTreeView": "עבור לתצוגת עץ", "switchToTreeView": "תצוגת עץ",
"recursiveOn": "חיפוש בתיקיות משנה",
"recursiveOff": "חיפוש רק בתיקייה הנוכחית",
"recursiveUnavailable": "חיפוש רקורסיבי זמין רק בתצוגת עץ",
"collapseAllDisabled": "לא זמין בתצוגת רשימה" "collapseAllDisabled": "לא זמין בתצוגת רשימה"
}, },
"statistics": { "statistics": {

View File

@@ -529,12 +529,15 @@
"title": "Embeddingモデル" "title": "Embeddingモデル"
}, },
"sidebar": { "sidebar": {
"modelRoot": "モデルルート", "modelRoot": "ルート",
"collapseAll": "すべてのフォルダを折りたたむ", "collapseAll": "すべてのフォルダを折りたたむ",
"pinSidebar": "サイドバーを固定", "pinSidebar": "サイドバーを固定",
"unpinSidebar": "サイドバーの固定を解除", "unpinSidebar": "サイドバーの固定を解除",
"switchToListView": "リストビューに切り替え", "switchToListView": "リストビューに切り替え",
"switchToTreeView": "ツリービューに切り替え", "switchToTreeView": "ツリー表示に切り替え",
"recursiveOn": "サブフォルダーを検索",
"recursiveOff": "現在のフォルダーのみを検索",
"recursiveUnavailable": "再帰検索はツリービューでのみ利用できます",
"collapseAllDisabled": "リストビューでは利用できません" "collapseAllDisabled": "リストビューでは利用できません"
}, },
"statistics": { "statistics": {

View File

@@ -529,12 +529,15 @@
"title": "Embedding 모델" "title": "Embedding 모델"
}, },
"sidebar": { "sidebar": {
"modelRoot": "모델 루트", "modelRoot": "루트",
"collapseAll": "모든 폴더 접기", "collapseAll": "모든 폴더 접기",
"pinSidebar": "사이드바 고정", "pinSidebar": "사이드바 고정",
"unpinSidebar": "사이드바 고정 해제", "unpinSidebar": "사이드바 고정 해제",
"switchToListView": "목록 보기로 전환", "switchToListView": "목록 보기로 전환",
"switchToTreeView": "트리 보기로 전환", "switchToTreeView": "트리 보기로 전환",
"recursiveOn": "하위 폴더 검색",
"recursiveOff": "현재 폴더만 검색",
"recursiveUnavailable": "재귀 검색은 트리 보기에서만 사용할 수 있습니다",
"collapseAllDisabled": "목록 보기에서는 사용할 수 없습니다" "collapseAllDisabled": "목록 보기에서는 사용할 수 없습니다"
}, },
"statistics": { "statistics": {

View File

@@ -529,12 +529,15 @@
"title": "Модели Embedding" "title": "Модели Embedding"
}, },
"sidebar": { "sidebar": {
"modelRoot": "Корень моделей", "modelRoot": "Корень",
"collapseAll": "Свернуть все папки", "collapseAll": "Свернуть все папки",
"pinSidebar": "Закрепить боковую панель", "pinSidebar": "Закрепить боковую панель",
"unpinSidebar": "Открепить боковую панель", "unpinSidebar": "Открепить боковую панель",
"switchToListView": "Переключить на вид списка", "switchToListView": "Переключить на вид списка",
"switchToTreeView": "Переключить на древовидный вид", "switchToTreeView": "Переключить на древовидный вид",
"recursiveOn": "Искать во вложенных папках",
"recursiveOff": "Искать только в текущей папке",
"recursiveUnavailable": "Рекурсивный поиск доступен только в режиме дерева",
"collapseAllDisabled": "Недоступно в виде списка" "collapseAllDisabled": "Недоступно в виде списка"
}, },
"statistics": { "statistics": {

View File

@@ -535,12 +535,15 @@
"title": "Embedding 模型" "title": "Embedding 模型"
}, },
"sidebar": { "sidebar": {
"modelRoot": "模型根目录", "modelRoot": "根目录",
"collapseAll": "折叠所有文件夹", "collapseAll": "折叠所有文件夹",
"pinSidebar": "固定侧边栏", "pinSidebar": "固定侧边栏",
"unpinSidebar": "取消固定侧边栏", "unpinSidebar": "取消固定侧边栏",
"switchToListView": "切换到列表视图", "switchToListView": "切换到列表视图",
"switchToTreeView": "切换到树状视图", "switchToTreeView": "切换到树状视图",
"recursiveOn": "搜索子文件夹",
"recursiveOff": "仅搜索当前文件夹",
"recursiveUnavailable": "仅在树形视图中可使用递归搜索",
"collapseAllDisabled": "列表视图下不可用" "collapseAllDisabled": "列表视图下不可用"
}, },
"statistics": { "statistics": {

View File

@@ -529,12 +529,15 @@
"title": "Embedding 模型" "title": "Embedding 模型"
}, },
"sidebar": { "sidebar": {
"modelRoot": "模型根目錄", "modelRoot": "根目錄",
"collapseAll": "全部摺疊資料夾", "collapseAll": "全部摺疊資料夾",
"pinSidebar": "固定側邊欄", "pinSidebar": "固定側邊欄",
"unpinSidebar": "取消固定側邊欄", "unpinSidebar": "取消固定側邊欄",
"switchToListView": "切換至列表檢視", "switchToListView": "切換至列表檢視",
"switchToTreeView": "切換樹狀檢視", "switchToTreeView": "切換樹狀檢視",
"recursiveOn": "搜尋子資料夾",
"recursiveOff": "僅搜尋目前資料夾",
"recursiveUnavailable": "遞迴搜尋僅能在樹狀檢視中使用",
"collapseAllDisabled": "列表檢視下不可用" "collapseAllDisabled": "列表檢視下不可用"
}, },
"statistics": { "statistics": {

View File

@@ -74,8 +74,9 @@ class Config:
"""Persist ComfyUI-derived folder paths to the multi-library settings.""" """Persist ComfyUI-derived folder paths to the multi-library settings."""
try: try:
ensure_settings_file(logger) ensure_settings_file(logger)
from .services.settings_manager import settings as settings_service from .services.settings_manager import get_settings_manager
settings_service = get_settings_manager()
libraries = settings_service.get_libraries() libraries = settings_service.get_libraries()
comfy_library = libraries.get("comfyui", {}) comfy_library = libraries.get("comfyui", {})
default_library = libraries.get("default", {}) default_library = libraries.get("default", {})
@@ -442,8 +443,9 @@ class Config:
"""Return the current library registry and active library name.""" """Return the current library registry and active library name."""
try: try:
from .services.settings_manager import settings as settings_service from .services.settings_manager import get_settings_manager
settings_service = get_settings_manager()
libraries = settings_service.get_libraries() libraries = settings_service.get_libraries()
active_library = settings_service.get_active_library_name() active_library = settings_service.get_active_library_name()
return { return {

View File

@@ -13,7 +13,7 @@ from .routes.misc_routes import MiscRoutes
from .routes.preview_routes import PreviewRoutes from .routes.preview_routes import PreviewRoutes
from .routes.example_images_routes import ExampleImagesRoutes from .routes.example_images_routes import ExampleImagesRoutes
from .services.service_registry import ServiceRegistry from .services.service_registry import ServiceRegistry
from .services.settings_manager import settings from .services.settings_manager import get_settings_manager
from .utils.example_images_migration import ExampleImagesMigration from .utils.example_images_migration import ExampleImagesMigration
from .services.websocket_manager import ws_manager from .services.websocket_manager import ws_manager
from .services.example_images_cleanup_service import ExampleImagesCleanupService from .services.example_images_cleanup_service import ExampleImagesCleanupService
@@ -23,6 +23,25 @@ logger = logging.getLogger(__name__)
# Check if we're in standalone mode # Check if we're in standalone mode
STANDALONE_MODE = 'nodes' not in sys.modules STANDALONE_MODE = 'nodes' not in sys.modules
class _SettingsProxy:
def __init__(self):
self._manager = None
def _resolve(self):
if self._manager is None:
self._manager = get_settings_manager()
return self._manager
def get(self, *args, **kwargs):
return self._resolve().get(*args, **kwargs)
def __getattr__(self, item):
return getattr(self._resolve(), item)
settings = _SettingsProxy()
class LoraManager: class LoraManager:
"""Main entry point for LoRA Manager plugin""" """Main entry point for LoRA Manager plugin"""

View File

@@ -17,7 +17,7 @@ from ..services.model_lifecycle_service import ModelLifecycleService
from ..services.preview_asset_service import PreviewAssetService from ..services.preview_asset_service import PreviewAssetService
from ..services.server_i18n import server_i18n as default_server_i18n from ..services.server_i18n import server_i18n as default_server_i18n
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings as default_settings from ..services.settings_manager import get_settings_manager
from ..services.tag_update_service import TagUpdateService from ..services.tag_update_service import TagUpdateService
from ..services.websocket_manager import ws_manager as default_ws_manager from ..services.websocket_manager import ws_manager as default_ws_manager
from ..services.use_cases import ( from ..services.use_cases import (
@@ -56,14 +56,14 @@ class BaseModelRoutes(ABC):
self, self,
service=None, service=None,
*, *,
settings_service=default_settings, settings_service=None,
ws_manager=default_ws_manager, ws_manager=default_ws_manager,
server_i18n=default_server_i18n, server_i18n=default_server_i18n,
metadata_provider_factory=get_default_metadata_provider, metadata_provider_factory=get_default_metadata_provider,
) -> None: ) -> None:
self.service = None self.service = None
self.model_type = "" self.model_type = ""
self._settings = settings_service self._settings = settings_service or get_settings_manager()
self._ws_manager = ws_manager self._ws_manager = ws_manager
self._server_i18n = server_i18n self._server_i18n = server_i18n
self._metadata_provider_factory = metadata_provider_factory self._metadata_provider_factory = metadata_provider_factory
@@ -90,7 +90,7 @@ class BaseModelRoutes(ABC):
self._metadata_sync_service = MetadataSyncService( self._metadata_sync_service = MetadataSyncService(
metadata_manager=MetadataManager, metadata_manager=MetadataManager,
preview_service=self._preview_service, preview_service=self._preview_service,
settings=settings_service, settings=self._settings,
default_metadata_provider_factory=metadata_provider_factory, default_metadata_provider_factory=metadata_provider_factory,
metadata_provider_selector=get_metadata_provider, metadata_provider_selector=get_metadata_provider,
) )

View File

@@ -18,7 +18,7 @@ from ..services.recipes import (
) )
from ..services.server_i18n import server_i18n from ..services.server_i18n import server_i18n
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
from ..utils.constants import CARD_PREVIEW_WIDTH from ..utils.constants import CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from .handlers.recipe_handlers import ( from .handlers.recipe_handlers import (
@@ -48,7 +48,7 @@ class BaseRecipeRoutes:
self.recipe_scanner = None self.recipe_scanner = None
self.lora_scanner = None self.lora_scanner = None
self.civitai_client = None self.civitai_client = None
self.settings = settings self.settings = get_settings_manager()
self.server_i18n = server_i18n self.server_i18n = server_i18n
self.template_env = jinja2.Environment( self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path), loader=jinja2.FileSystemLoader(config.templates_path),

View File

@@ -24,10 +24,17 @@ from ...services.metadata_service import (
update_metadata_providers, update_metadata_providers,
) )
from ...services.service_registry import ServiceRegistry from ...services.service_registry import ServiceRegistry
from ...services.settings_manager import settings as default_settings from ...services.settings_manager import get_settings_manager
from ...services.websocket_manager import ws_manager from ...services.websocket_manager import ws_manager
from ...services.downloader import get_downloader from ...services.downloader import get_downloader
from ...utils.constants import DEFAULT_NODE_COLOR, NODE_TYPES, SUPPORTED_MEDIA_EXTENSIONS from ...utils.constants import (
CIVITAI_USER_MODEL_TYPES,
DEFAULT_NODE_COLOR,
NODE_TYPES,
SUPPORTED_MEDIA_EXTENSIONS,
VALID_LORA_TYPES,
)
from ...utils.civitai_utils import rewrite_preview_url
from ...utils.example_images_paths import is_valid_example_images_root from ...utils.example_images_paths import is_valid_example_images_root
from ...utils.lora_metadata import extract_trained_words from ...utils.lora_metadata import extract_trained_words
from ...utils.usage_stats import UsageStats from ...utils.usage_stats import UsageStats
@@ -80,7 +87,7 @@ class NodeRegistry:
def __init__(self) -> None: def __init__(self) -> None:
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._nodes: Dict[int, dict] = {} self._nodes: Dict[str, dict] = {}
self._registry_updated = asyncio.Event() self._registry_updated = asyncio.Event()
async def register_nodes(self, nodes: list[dict]) -> None: async def register_nodes(self, nodes: list[dict]) -> None:
@@ -88,11 +95,16 @@ class NodeRegistry:
self._nodes.clear() self._nodes.clear()
for node in nodes: for node in nodes:
node_id = node["node_id"] node_id = node["node_id"]
graph_id = str(node["graph_id"])
unique_id = f"{graph_id}:{node_id}"
node_type = node.get("type", "") node_type = node.get("type", "")
type_id = NODE_TYPES.get(node_type, 0) type_id = NODE_TYPES.get(node_type, 0)
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
self._nodes[node_id] = { self._nodes[unique_id] = {
"id": node_id, "id": node_id,
"graph_id": graph_id,
"graph_name": node.get("graph_name"),
"unique_id": unique_id,
"bgcolor": bgcolor, "bgcolor": bgcolor,
"title": node.get("title"), "title": node.get("title"),
"type": type_id, "type": type_id,
@@ -157,11 +169,11 @@ class SettingsHandler:
def __init__( def __init__(
self, self,
*, *,
settings_service=default_settings, settings_service=None,
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers, metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
downloader_factory: Callable[[], Awaitable[DownloaderProtocol]] = get_downloader, downloader_factory: Callable[[], Awaitable[DownloaderProtocol]] = get_downloader,
) -> None: ) -> None:
self._settings = settings_service self._settings = settings_service or get_settings_manager()
self._metadata_provider_updater = metadata_provider_updater self._metadata_provider_updater = metadata_provider_updater
self._downloader_factory = downloader_factory self._downloader_factory = downloader_factory
@@ -330,16 +342,65 @@ class LoraCodeHandler:
logger.error("Error broadcasting lora code: %s", exc) logger.error("Error broadcasting lora code: %s", exc)
results.append({"node_id": "broadcast", "success": False, "error": str(exc)}) results.append({"node_id": "broadcast", "success": False, "error": str(exc)})
else: else:
for node_id in node_ids: for entry in node_ids:
node_identifier = entry
graph_identifier = None
if isinstance(entry, dict):
node_identifier = entry.get("node_id")
graph_identifier = entry.get("graph_id")
if node_identifier is None:
results.append(
{
"node_id": node_identifier,
"graph_id": graph_identifier,
"success": False,
"error": "Missing node_id parameter",
}
)
continue
try:
parsed_node_id = int(node_identifier)
except (TypeError, ValueError):
parsed_node_id = node_identifier
payload = {
"id": parsed_node_id,
"lora_code": lora_code,
"mode": mode,
}
if graph_identifier is not None:
payload["graph_id"] = str(graph_identifier)
try: try:
self._prompt_server.instance.send_sync( self._prompt_server.instance.send_sync(
"lora_code_update", "lora_code_update",
{"id": node_id, "lora_code": lora_code, "mode": mode}, payload,
)
results.append(
{
"node_id": parsed_node_id,
"graph_id": payload.get("graph_id"),
"success": True,
}
) )
results.append({"node_id": node_id, "success": True})
except Exception as exc: # pragma: no cover - defensive logging except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error sending lora code to node %s: %s", node_id, exc) logger.error(
results.append({"node_id": node_id, "success": False, "error": str(exc)}) "Error sending lora code to node %s (graph %s): %s",
parsed_node_id,
graph_identifier,
exc,
)
results.append(
{
"node_id": parsed_node_id,
"graph_id": payload.get("graph_id"),
"success": False,
"error": str(exc),
}
)
return web.json_response({"success": True, "results": results}) return web.json_response({"success": True, "results": results})
except Exception as exc: # pragma: no cover - defensive logging except Exception as exc: # pragma: no cover - defensive logging
@@ -557,17 +618,118 @@ class ModelLibraryHandler:
logger.error("Failed to get model versions status: %s", exc, exc_info=True) logger.error("Failed to get model versions status: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500) return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_civitai_user_models(self, request: web.Request) -> web.Response:
try:
username = request.query.get("username")
if not username:
return web.json_response({"success": False, "error": "Missing required parameter: username"}, status=400)
metadata_provider = await self._metadata_provider_factory()
if not metadata_provider:
return web.json_response({"success": False, "error": "Metadata provider not available"}, status=503)
try:
models = await metadata_provider.get_user_models(username)
except NotImplementedError:
return web.json_response({"success": False, "error": "Metadata provider does not support user model queries"}, status=501)
if models is None:
return web.json_response({"success": False, "error": "Failed to fetch user models"}, status=502)
if not isinstance(models, list):
models = []
lora_scanner = await self._service_registry.get_lora_scanner()
checkpoint_scanner = await self._service_registry.get_checkpoint_scanner()
embedding_scanner = await self._service_registry.get_embedding_scanner()
normalized_allowed_types = {model_type.lower() for model_type in CIVITAI_USER_MODEL_TYPES}
lora_type_aliases = {model_type.lower() for model_type in VALID_LORA_TYPES}
type_scanner_map: Dict[str, object | None] = {
**{alias: lora_scanner for alias in lora_type_aliases},
"checkpoint": checkpoint_scanner,
"textualinversion": embedding_scanner,
}
versions: list[dict] = []
for model in models:
if not isinstance(model, dict):
continue
model_type = str(model.get("type", "")).lower()
if model_type not in normalized_allowed_types:
continue
scanner = type_scanner_map.get(model_type)
if scanner is None:
return web.json_response({"success": False, "error": f'Scanner for type "{model_type}" is not available'}, status=503)
tags_value = model.get("tags")
tags = tags_value if isinstance(tags_value, list) else []
model_id = model.get("id")
try:
model_id_int = int(model_id)
except (TypeError, ValueError):
continue
model_name = model.get("name", "")
versions_data = model.get("modelVersions")
if not isinstance(versions_data, list):
continue
for version in versions_data:
if not isinstance(version, dict):
continue
version_id = version.get("id")
try:
version_id_int = int(version_id)
except (TypeError, ValueError):
continue
images = version.get("images") or []
thumbnail_url = None
if images and isinstance(images, list):
first_image = images[0]
if isinstance(first_image, dict):
raw_url = first_image.get("url")
media_type = first_image.get("type")
rewritten_url, _ = rewrite_preview_url(raw_url, media_type)
thumbnail_url = rewritten_url
in_library = await scanner.check_model_version_exists(version_id_int)
versions.append(
{
"modelId": model_id_int,
"versionId": version_id_int,
"modelName": model_name,
"versionName": version.get("name", ""),
"type": model.get("type"),
"tags": tags,
"baseModel": version.get("baseModel"),
"thumbnailUrl": thumbnail_url,
"inLibrary": in_library,
}
)
return web.json_response({"success": True, "username": username, "versions": versions})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to get Civitai user models: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class MetadataArchiveHandler: class MetadataArchiveHandler:
def __init__( def __init__(
self, self,
*, *,
metadata_archive_manager_factory: Callable[[], Awaitable[MetadataArchiveManagerProtocol]] = get_metadata_archive_manager, metadata_archive_manager_factory: Callable[[], Awaitable[MetadataArchiveManagerProtocol]] = get_metadata_archive_manager,
settings_service=default_settings, settings_service=None,
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers, metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
) -> None: ) -> None:
self._metadata_archive_manager_factory = metadata_archive_manager_factory self._metadata_archive_manager_factory = metadata_archive_manager_factory
self._settings = settings_service self._settings = settings_service or get_settings_manager()
self._metadata_provider_updater = metadata_provider_updater self._metadata_provider_updater = metadata_provider_updater
async def download_metadata_archive(self, request: web.Request) -> web.Response: async def download_metadata_archive(self, request: web.Request) -> web.Response:
@@ -679,10 +841,21 @@ class NodeRegistryHandler:
node_id = node.get("node_id") node_id = node.get("node_id")
if node_id is None: if node_id is None:
return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400) return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400)
graph_id = node.get("graph_id")
if graph_id is None:
return web.json_response({"success": False, "error": f"Node {index} missing graph_id parameter"}, status=400)
graph_name = node.get("graph_name")
try: try:
node["node_id"] = int(node_id) node["node_id"] = int(node_id)
except (TypeError, ValueError): except (TypeError, ValueError):
return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400) return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400)
node["graph_id"] = str(graph_id)
if graph_name is None:
node["graph_name"] = None
elif isinstance(graph_name, str):
node["graph_name"] = graph_name
else:
node["graph_name"] = str(graph_name)
await self._node_registry.register_nodes(nodes) await self._node_registry.register_nodes(nodes)
return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"}) return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"})
@@ -779,6 +952,7 @@ class MiscHandlerSet:
"register_nodes": self.node_registry.register_nodes, "register_nodes": self.node_registry.register_nodes,
"get_registry": self.node_registry.get_registry, "get_registry": self.node_registry.get_registry,
"check_model_exists": self.model_library.check_model_exists, "check_model_exists": self.model_library.check_model_exists,
"get_civitai_user_models": self.model_library.get_civitai_user_models,
"download_metadata_archive": self.metadata_archive.download_metadata_archive, "download_metadata_archive": self.metadata_archive.download_metadata_archive,
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive, "remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status, "get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,

View File

@@ -30,6 +30,7 @@ from ...services.use_cases import (
from ...services.websocket_manager import WebSocketManager from ...services.websocket_manager import WebSocketManager
from ...services.websocket_progress_callback import WebSocketProgressCallback from ...services.websocket_progress_callback import WebSocketProgressCallback
from ...utils.file_utils import calculate_sha256 from ...utils.file_utils import calculate_sha256
from ...utils.metadata_manager import MetadataManager
class ModelPageView: class ModelPageView:
@@ -244,6 +245,8 @@ class ModelManagementHandler:
if not model_data.get("sha256"): if not model_data.get("sha256"):
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400) return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
await MetadataManager.hydrate_model_data(model_data)
success, error = await self._metadata_sync.fetch_and_update_model( success, error = await self._metadata_sync.fetch_and_update_model(
sha256=model_data["sha256"], sha256=model_data["sha256"],
file_path=file_path, file_path=file_path,
@@ -825,18 +828,30 @@ class ModelCivitaiHandler:
status=400, status=400,
) )
cache = await self._service.scanner.get_cached_data()
version_index = cache.version_index
for version in versions: for version in versions:
model_file = self._find_model_file(version.get("files", [])) if isinstance(version.get("files"), Iterable) else None version_id = None
if model_file: version_id_raw = version.get("id")
hashes = model_file.get("hashes", {}) if isinstance(model_file, Mapping) else {} if version_id_raw is not None:
sha256 = hashes.get("SHA256") if isinstance(hashes, Mapping) else None try:
if sha256: version_id = int(str(version_id_raw))
version["existsLocally"] = self._service.has_hash(sha256) except (TypeError, ValueError):
if version["existsLocally"]: version_id = None
version["localPath"] = self._service.get_path_by_hash(sha256)
version["modelSizeKB"] = model_file.get("sizeKB") if isinstance(model_file, Mapping) else None cache_entry = version_index.get(version_id) if (version_id is not None and version_index) else None
version["existsLocally"] = cache_entry is not None
if cache_entry and isinstance(cache_entry, Mapping):
local_path = cache_entry.get("file_path")
if local_path:
version["localPath"] = local_path
else: else:
version["existsLocally"] = False version.pop("localPath", None)
model_file = self._find_model_file(version.get("files", [])) if isinstance(version.get("files"), Iterable) else None
if model_file and isinstance(model_file, Mapping):
version["modelSizeKB"] = model_file.get("sizeKB")
return web.json_response(versions) return web.json_response(versions)
except Exception as exc: except Exception as exc:
self._logger.error("Error fetching %s model versions: %s", self._service.model_type, exc) self._logger.error("Error fetching %s model versions: %s", self._service.model_type, exc)

View File

@@ -229,11 +229,27 @@ class LoraRoutes(BaseModelRoutes):
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Send update to all connected trigger word toggle nodes # Send update to all connected trigger word toggle nodes
for node_id in node_ids: for entry in node_ids:
PromptServer.instance.send_sync("trigger_word_update", { node_identifier = entry
"id": node_id, graph_identifier = None
if isinstance(entry, dict):
node_identifier = entry.get("node_id")
graph_identifier = entry.get("graph_id")
try:
parsed_node_id = int(node_identifier)
except (TypeError, ValueError):
parsed_node_id = node_identifier
payload = {
"id": parsed_node_id,
"message": trigger_words_text "message": trigger_words_text
}) }
if graph_identifier is not None:
payload["graph_id"] = str(graph_identifier)
PromptServer.instance.send_sync("trigger_word_update", payload)
return web.json_response({"success": True}) return web.json_response({"success": True})

View File

@@ -34,6 +34,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"), RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"), RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"), RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"), RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"), RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"), RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),

View File

@@ -14,7 +14,7 @@ from ..services.metadata_service import (
get_metadata_provider, get_metadata_provider,
update_metadata_providers, update_metadata_providers,
) )
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
from ..services.downloader import get_downloader from ..services.downloader import get_downloader
from ..utils.usage_stats import UsageStats from ..utils.usage_stats import UsageStats
from .handlers.misc_handlers import ( from .handlers.misc_handlers import (
@@ -47,7 +47,7 @@ class MiscRoutes:
def __init__( def __init__(
self, self,
*, *,
settings_service=settings, settings_service=None,
usage_stats_factory: Callable[[], UsageStats] = UsageStats, usage_stats_factory: Callable[[], UsageStats] = UsageStats,
prompt_server: type[PromptServer] = PromptServer, prompt_server: type[PromptServer] = PromptServer,
service_registry_adapter=build_service_registry_adapter(), service_registry_adapter=build_service_registry_adapter(),
@@ -60,7 +60,7 @@ class MiscRoutes:
node_registry: NodeRegistry | None = None, node_registry: NodeRegistry | None = None,
standalone_mode_flag: bool = standalone_mode, standalone_mode_flag: bool = standalone_mode,
) -> None: ) -> None:
self._settings = settings_service self._settings = settings_service or get_settings_manager()
self._usage_stats_factory = usage_stats_factory self._usage_stats_factory = usage_stats_factory
self._prompt_server = prompt_server self._prompt_server = prompt_server
self._service_registry_adapter = service_registry_adapter self._service_registry_adapter = service_registry_adapter

View File

@@ -8,13 +8,32 @@ from collections import defaultdict, Counter
from typing import Dict, List, Any from typing import Dict, List, Any
from ..config import config from ..config import config
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
from ..services.server_i18n import server_i18n from ..services.server_i18n import server_i18n
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..utils.usage_stats import UsageStats from ..utils.usage_stats import UsageStats
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _SettingsProxy:
def __init__(self):
self._manager = None
def _resolve(self):
if self._manager is None:
self._manager = get_settings_manager()
return self._manager
def get(self, *args, **kwargs):
return self._resolve().get(*args, **kwargs)
def __getattr__(self, item):
return getattr(self._resolve(), item)
settings = _SettingsProxy()
class StatsRoutes: class StatsRoutes:
"""Route handlers for Statistics page and API endpoints""" """Route handlers for Statistics page and API endpoints"""
@@ -66,7 +85,9 @@ class StatsRoutes:
is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing
# 获取用户语言设置 # 获取用户语言设置
user_language = settings.get('language', 'en') settings_object = settings
user_language = settings_object.get('language', 'en')
settings_manager = settings_object if not isinstance(settings_object, _SettingsProxy) else settings_object._resolve()
# 设置服务端i18n语言 # 设置服务端i18n语言
server_i18n.set_locale(user_language) server_i18n.set_locale(user_language)
@@ -79,7 +100,7 @@ class StatsRoutes:
template = self.template_env.get_template('statistics.html') template = self.template_env.get_template('statistics.html')
rendered = template.render( rendered = template.render(
is_initializing=is_initializing, is_initializing=is_initializing,
settings=settings, settings=settings_manager,
request=request, request=request,
t=server_i18n.get_translation, t=server_i18n.get_translation,
) )

View File

@@ -6,7 +6,7 @@ import os
from ..utils.models import BaseModelMetadata from ..utils.models import BaseModelMetadata
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
from .settings_manager import settings as default_settings from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -38,7 +38,7 @@ class BaseModelService(ABC):
self.model_type = model_type self.model_type = model_type
self.scanner = scanner self.scanner = scanner
self.metadata_class = metadata_class self.metadata_class = metadata_class
self.settings = settings_provider or default_settings self.settings = settings_provider or get_settings_manager()
self.cache_repository = cache_repository or ModelCacheRepository(scanner) self.cache_repository = cache_repository or ModelCacheRepository(scanner)
self.filter_set = filter_set or ModelFilterSet(self.settings) self.filter_set = filter_set or ModelFilterSet(self.settings)
self.search_strategy = search_strategy or SearchStrategy() self.search_strategy = search_strategy or SearchStrategy()

View File

@@ -1,7 +1,7 @@
import os import asyncio
import copy import copy
import logging import logging
import asyncio import os
from typing import Optional, Dict, Tuple, List from typing import Optional, Dict, Tuple, List
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
from .downloader import get_downloader from .downloader import get_downloader
@@ -157,141 +157,160 @@ class CivitaiClient:
return None return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata """Get specific model version with additional metadata."""
Args:
model_id: The Civitai model ID (optional if version_id is provided)
version_id: Optional specific version ID to retrieve
Returns:
Optional[Dict]: The model version data with additional fields or None if not found
"""
try: try:
downloader = await get_downloader() downloader = await get_downloader()
# Case 1: Only version_id is provided
if model_id is None and version_id is not None: if model_id is None and version_id is not None:
# First get the version info to extract model_id return await self._get_version_by_id_only(downloader, version_id)
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/{version_id}",
use_auth=True
)
if not success:
return None
model_id = version.get('modelId') if model_id is not None:
if not model_id: return await self._get_version_with_model_id(downloader, model_id, version_id)
logger.error(f"No modelId found in version {version_id}")
return None
# Now get the model data for additional metadata logger.error("Either model_id or version_id must be provided")
success, model_data = await downloader.make_request( return None
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
)
if success:
# Enrich version with model data
version['model']['description'] = model_data.get("description")
version['model']['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
self._remove_comfy_metadata(version)
return version
# Case 2: model_id is provided (with or without version_id)
elif model_id is not None:
# Step 1: Get model data to find version_id if not provided and get additional metadata
success, data = await downloader.make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
)
if not success:
return None
model_versions = data.get('modelVersions', [])
if not model_versions:
logger.warning(f"No model versions found for model {model_id}")
return None
# Step 2: Determine the target version entry to use
target_version = None
if version_id is not None:
target_version = next(
(item for item in model_versions if item.get('id') == version_id),
None
)
if target_version is None:
logger.warning(
f"Version {version_id} not found for model {model_id}, defaulting to first version"
)
if target_version is None:
target_version = model_versions[0]
target_version_id = target_version.get('id')
# Step 3: Get detailed version info using the SHA256 hash
model_hash = None
for file_info in target_version.get('files', []):
if file_info.get('type') == 'Model' and file_info.get('primary'):
model_hash = file_info.get('hashes', {}).get('SHA256')
if model_hash:
break
version = None
if model_hash:
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
)
if not success:
logger.warning(
f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}"
)
version = None
else:
logger.warning(
f"No primary model hash found for model {model_id} version {target_version_id}"
)
if version is None:
version = copy.deepcopy(target_version)
version.pop('index', None)
version['modelId'] = model_id
version['model'] = {
'name': data.get('name'),
'type': data.get('type'),
'nsfw': data.get('nsfw'),
'poi': data.get('poi')
}
# Step 4: Enrich version_info with model data
# Add description and tags from model data
model_info = version.get('model')
if not isinstance(model_info, dict):
model_info = {}
version['model'] = model_info
model_info['description'] = data.get("description")
model_info['tags'] = data.get("tags", [])
# Add creator from model data
version['creator'] = data.get("creator")
self._remove_comfy_metadata(version)
return version
# Case 3: Neither model_id nor version_id provided
else:
logger.error("Either model_id or version_id must be provided")
return None
except Exception as e: except Exception as e:
logger.error(f"Error fetching model version: {e}") logger.error(f"Error fetching model version: {e}")
return None return None
async def _get_version_by_id_only(self, downloader, version_id: int) -> Optional[Dict]:
version = await self._fetch_version_by_id(downloader, version_id)
if version is None:
return None
model_id = version.get('modelId')
if not model_id:
logger.error(f"No modelId found in version {version_id}")
return None
model_data = await self._fetch_model_data(downloader, model_id)
if model_data:
self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version)
return version
async def _get_version_with_model_id(self, downloader, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
model_data = await self._fetch_model_data(downloader, model_id)
if not model_data:
return None
target_version = self._select_target_version(model_data, model_id, version_id)
if target_version is None:
return None
target_version_id = target_version.get('id')
version = await self._fetch_version_by_id(downloader, target_version_id) if target_version_id else None
if version is None:
model_hash = self._extract_primary_model_hash(target_version)
if model_hash:
version = await self._fetch_version_by_hash(downloader, model_hash)
else:
logger.warning(
f"No primary model hash found for model {model_id} version {target_version_id}"
)
if version is None:
version = self._build_version_from_model_data(target_version, model_id, model_data)
self._enrich_version_with_model_data(version, model_data)
self._remove_comfy_metadata(version)
return version
async def _fetch_model_data(self, downloader, model_id: int) -> Optional[Dict]:
success, data = await downloader.make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
)
if success:
return data
logger.warning(f"Failed to fetch model data for model {model_id}")
return None
async def _fetch_version_by_id(self, downloader, version_id: Optional[int]) -> Optional[Dict]:
if version_id is None:
return None
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/{version_id}",
use_auth=True
)
if success:
return version
logger.warning(f"Failed to fetch version by id {version_id}")
return None
async def _fetch_version_by_hash(self, downloader, model_hash: Optional[str]) -> Optional[Dict]:
if not model_hash:
return None
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
)
if success:
return version
logger.warning(f"Failed to fetch version by hash {model_hash}")
return None
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
model_versions = model_data.get('modelVersions', [])
if not model_versions:
logger.warning(f"No model versions found for model {model_id}")
return None
if version_id is not None:
target_version = next(
(item for item in model_versions if item.get('id') == version_id),
None
)
if target_version is None:
logger.warning(
f"Version {version_id} not found for model {model_id}, defaulting to first version"
)
return model_versions[0]
return target_version
return model_versions[0]
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
for file_info in version_entry.get('files', []):
if file_info.get('type') == 'Model' and file_info.get('primary'):
hashes = file_info.get('hashes', {})
model_hash = hashes.get('SHA256')
if model_hash:
return model_hash
return None
def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict:
version = copy.deepcopy(version_entry)
version.pop('index', None)
version['modelId'] = model_id
version['model'] = {
'name': model_data.get('name'),
'type': model_data.get('type'),
'nsfw': model_data.get('nsfw'),
'poi': model_data.get('poi')
}
return version
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
model_info = version.get('model')
if not isinstance(model_info, dict):
model_info = {}
version['model'] = model_info
model_info['description'] = model_data.get("description")
model_info['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai """Fetch model version metadata from Civitai
@@ -366,3 +385,37 @@ class CivitaiClient:
error_msg = f"Error fetching image info: {e}" error_msg = f"Error fetching image info: {e}"
logger.error(error_msg) logger.error(error_msg)
return None return None
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Fetch all models for a specific Civitai user."""
if not username:
return None
try:
downloader = await get_downloader()
url = f"{self.base_url}/models?username={username}"
success, result = await downloader.make_request(
'GET',
url,
use_auth=True
)
if not success:
logger.error("Failed to fetch models for %s: %s", username, result)
return None
items = result.get("items") if isinstance(result, dict) else None
if not isinstance(items, list):
return []
for model in items:
versions = model.get("modelVersions")
if not isinstance(versions, list):
continue
for version in versions:
self._remove_comfy_metadata(version)
return items
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error fetching models for %s: %s", username, exc)
return None

View File

@@ -4,12 +4,14 @@ import asyncio
from collections import OrderedDict from collections import OrderedDict
import uuid import uuid
from typing import Dict, List from typing import Dict, List
from urllib.parse import urlparse
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
from ..utils.civitai_utils import rewrite_preview_url
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
from .settings_manager import settings from .settings_manager import get_settings_manager
from .metadata_service import get_default_metadata_provider from .metadata_service import get_default_metadata_provider
from .downloader import get_downloader from .downloader import get_downloader
@@ -241,19 +243,20 @@ class DownloadManager:
# Handle use_default_paths # Handle use_default_paths
if use_default_paths: if use_default_paths:
settings_manager = get_settings_manager()
# Set save_dir based on model type # Set save_dir based on model type
if model_type == 'checkpoint': if model_type == 'checkpoint':
default_path = settings.get('default_checkpoint_root') default_path = settings_manager.get('default_checkpoint_root')
if not default_path: if not default_path:
return {'success': False, 'error': 'Default checkpoint root path not set in settings'} return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
save_dir = default_path save_dir = default_path
elif model_type == 'lora': elif model_type == 'lora':
default_path = settings.get('default_lora_root') default_path = settings_manager.get('default_lora_root')
if not default_path: if not default_path:
return {'success': False, 'error': 'Default lora root path not set in settings'} return {'success': False, 'error': 'Default lora root path not set in settings'}
save_dir = default_path save_dir = default_path
elif model_type == 'embedding': elif model_type == 'embedding':
default_path = settings.get('default_embedding_root') default_path = settings_manager.get('default_embedding_root')
if not default_path: if not default_path:
return {'success': False, 'error': 'Default embedding root path not set in settings'} return {'success': False, 'error': 'Default embedding root path not set in settings'}
save_dir = default_path save_dir = default_path
@@ -360,7 +363,8 @@ class DownloadManager:
Relative path string Relative path string
""" """
# Get path template from settings for specific model type # Get path template from settings for specific model type
path_template = settings.get_download_path_template(model_type) settings_manager = get_settings_manager()
path_template = settings_manager.get_download_path_template(model_type)
# If template is empty, return empty path (flat structure) # If template is empty, return empty path (flat structure)
if not path_template: if not path_template:
@@ -377,7 +381,7 @@ class DownloadManager:
author = 'Anonymous' author = 'Anonymous'
# Apply mapping if available # Apply mapping if available
base_model_mappings = settings.get('base_model_path_mappings', {}) base_model_mappings = settings_manager.get('base_model_path_mappings', {})
mapped_base_model = base_model_mappings.get(base_model, base_model) mapped_base_model = base_model_mappings.get(base_model, base_model)
# Get model tags # Get model tags
@@ -448,70 +452,103 @@ class DownloadManager:
# Download preview image if available # Download preview image if available
images = version_info.get('images', []) images = version_info.get('images', [])
if images: if images:
# Report preview download progress
if progress_callback: if progress_callback:
await progress_callback(1) # 1% progress for starting preview download await progress_callback(1) # 1% progress for starting preview download
# Check if it's a video or an image first_image = images[0] if isinstance(images[0], dict) else None
is_video = images[0].get('type') == 'video' preview_url = first_image.get('url') if first_image else None
media_type = (first_image.get('type') or '').lower() if first_image else ''
nsfw_level = first_image.get('nsfwLevel', 0) if first_image else 0
if (is_video): def _extension_from_url(url: str, fallback: str) -> str:
# For videos, use .mp4 extension try:
preview_ext = '.mp4' parsed = urlparse(url)
preview_path = os.path.splitext(save_path)[0] + preview_ext except ValueError:
return fallback
ext = os.path.splitext(parsed.path)[1]
return ext or fallback
# Download video directly using downloader preview_downloaded = False
preview_path = None
if preview_url:
downloader = await get_downloader() downloader = await get_downloader()
success, result = await downloader.download_file(
images[0]['url'],
preview_path,
use_auth=False # Preview images typically don't need auth
)
if success:
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
else:
# For images, use WebP format for better performance
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
temp_path = temp_file.name
# Download the original image to temp path using downloader if media_type == 'video':
downloader = await get_downloader() preview_ext = _extension_from_url(preview_url, '.mp4')
success, content, headers = await downloader.download_to_memory( preview_path = os.path.splitext(save_path)[0] + preview_ext
images[0]['url'], rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='video')
use_auth=False attempt_urls: List[str] = []
) if rewritten:
if success: attempt_urls.append(rewritten_url)
# Save to temp file attempt_urls.append(preview_url)
with open(temp_path, 'wb') as f:
f.write(content)
# Optimize and convert to WebP
preview_path = os.path.splitext(save_path)[0] + '.webp'
# Use ExifUtils to optimize and convert the image seen_attempts = set()
optimized_data, _ = ExifUtils.optimize_image( for attempt in attempt_urls:
image_data=temp_path, if not attempt or attempt in seen_attempts:
target_width=CARD_PREVIEW_WIDTH, continue
format='webp', seen_attempts.add(attempt)
quality=85, success, _ = await downloader.download_file(
preserve_metadata=False attempt,
) preview_path,
use_auth=False
)
if success:
preview_downloaded = True
break
else:
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='image')
if rewritten:
preview_ext = _extension_from_url(preview_url, '.png')
preview_path = os.path.splitext(save_path)[0] + preview_ext
success, _ = await downloader.download_file(
rewritten_url,
preview_path,
use_auth=False
)
if success:
preview_downloaded = True
# Save the optimized image if not preview_downloaded:
with open(preview_path, 'wb') as f: temp_path: str | None = None
f.write(optimized_data) try:
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
temp_path = temp_file.name
# Update metadata success, content, _ = await downloader.download_to_memory(
metadata.preview_url = preview_path.replace(os.sep, '/') preview_url,
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) use_auth=False
)
if success:
with open(temp_path, 'wb') as temp_file_handle:
temp_file_handle.write(content)
preview_path = os.path.splitext(save_path)[0] + '.webp'
# Remove temporary file optimized_data, _ = ExifUtils.optimize_image(
try: image_data=temp_path,
os.unlink(temp_path) target_width=CARD_PREVIEW_WIDTH,
except Exception as e: format='webp',
logger.warning(f"Failed to delete temp file: {e}") quality=85,
preserve_metadata=False
)
with open(preview_path, 'wb') as preview_file:
preview_file.write(optimized_data)
preview_downloaded = True
finally:
if temp_path and os.path.exists(temp_path):
try:
os.unlink(temp_path)
except Exception as e:
logger.warning(f"Failed to delete temp file: {e}")
if preview_downloaded and preview_path:
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = nsfw_level
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['preview_path'] = preview_path
# Report preview download completion
if progress_callback: if progress_callback:
await progress_callback(3) # 3% progress after preview download await progress_callback(3) # 3% progress after preview download
@@ -675,7 +712,15 @@ class DownloadManager:
except Exception as e: except Exception as e:
logger.error(f"Error deleting metadata file: {e}") logger.error(f"Error deleting metadata file: {e}")
# Delete preview file if exists (.webp or .mp4) preview_path_value = download_info.get('preview_path')
if preview_path_value and os.path.exists(preview_path_value):
try:
os.unlink(preview_path_value)
logger.debug(f"Deleted preview file: {preview_path_value}")
except Exception as e:
logger.error(f"Error deleting preview file: {e}")
# Delete preview file if exists (.webp or .mp4) for legacy paths
for preview_ext in ['.webp', '.mp4']: for preview_ext in ['.webp', '.mp4']:
preview_path = os.path.splitext(file_path)[0] + preview_ext preview_path = os.path.splitext(file_path)[0] + preview_ext
if os.path.exists(preview_path): if os.path.exists(preview_path):

View File

@@ -16,7 +16,7 @@ import asyncio
import aiohttp import aiohttp
from datetime import datetime from datetime import datetime
from typing import Optional, Dict, Tuple, Callable, Union from typing import Optional, Dict, Tuple, Callable, Union
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -94,12 +94,13 @@ class Downloader:
# Check for app-level proxy settings # Check for app-level proxy settings
proxy_url = None proxy_url = None
if settings.get('proxy_enabled', False): settings_manager = get_settings_manager()
proxy_host = settings.get('proxy_host', '').strip() if settings_manager.get('proxy_enabled', False):
proxy_port = settings.get('proxy_port', '').strip() proxy_host = settings_manager.get('proxy_host', '').strip()
proxy_type = settings.get('proxy_type', 'http').lower() proxy_port = settings_manager.get('proxy_port', '').strip()
proxy_username = settings.get('proxy_username', '').strip() proxy_type = settings_manager.get('proxy_type', 'http').lower()
proxy_password = settings.get('proxy_password', '').strip() proxy_username = settings_manager.get('proxy_username', '').strip()
proxy_password = settings_manager.get('proxy_password', '').strip()
if proxy_host and proxy_port: if proxy_host and proxy_port:
# Build proxy URL # Build proxy URL
@@ -146,7 +147,8 @@ class Downloader:
if use_auth: if use_auth:
# Add CivitAI API key if available # Add CivitAI API key if available
api_key = settings.get('civitai_api_key') settings_manager = get_settings_manager()
api_key = settings_manager.get('civitai_api_key')
if api_key: if api_key:
headers['Authorization'] = f'Bearer {api_key}' headers['Authorization'] = f'Bearer {api_key}'
headers['Content-Type'] = 'application/json' headers['Content-Type'] = 'application/json'

View File

@@ -11,7 +11,7 @@ from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
from .settings_manager import settings from .settings_manager import get_settings_manager
from ..utils.example_images_paths import iter_library_roots from ..utils.example_images_paths import iter_library_roots
@@ -62,7 +62,8 @@ class ExampleImagesCleanupService:
async def cleanup_example_image_folders(self) -> Dict[str, object]: async def cleanup_example_image_folders(self) -> Dict[str, object]:
"""Clean empty or orphaned example image folders by moving them under a deleted bucket.""" """Clean empty or orphaned example image folders by moving them under a deleted bucket."""
example_images_path = settings.get("example_images_path") settings_manager = get_settings_manager()
example_images_path = settings_manager.get("example_images_path")
if not example_images_path: if not example_images_path:
logger.debug("Cleanup skipped: example images path not configured") logger.debug("Cleanup skipped: example images path not configured")
return { return {

View File

@@ -6,7 +6,7 @@ from .model_metadata_provider import (
CivitaiModelMetadataProvider, CivitaiModelMetadataProvider,
FallbackMetadataProvider FallbackMetadataProvider
) )
from .settings_manager import settings from .settings_manager import get_settings_manager
from .metadata_archive_manager import MetadataArchiveManager from .metadata_archive_manager import MetadataArchiveManager
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
@@ -21,7 +21,8 @@ async def initialize_metadata_providers():
provider_manager.default_provider = None provider_manager.default_provider = None
# Get settings # Get settings
enable_archive_db = settings.get('enable_metadata_archive_db', False) settings_manager = get_settings_manager()
enable_archive_db = settings_manager.get('enable_metadata_archive_db', False)
providers = [] providers = []
@@ -87,7 +88,8 @@ async def update_metadata_providers():
"""Update metadata providers based on current settings""" """Update metadata providers based on current settings"""
try: try:
# Get current settings # Get current settings
enable_archive_db = settings.get('enable_metadata_archive_db', False) settings_manager = get_settings_manager()
enable_archive_db = settings_manager.get('enable_metadata_archive_db', False)
# Reinitialize all providers with new settings # Reinitialize all providers with new settings
provider_manager = await initialize_metadata_providers() provider_manager = await initialize_metadata_providers()

View File

@@ -1,6 +1,6 @@
import asyncio import asyncio
from typing import List, Dict, Tuple from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass from dataclasses import dataclass, field
from operator import itemgetter from operator import itemgetter
from natsort import natsorted from natsort import natsorted
@@ -17,9 +17,11 @@ SUPPORTED_SORT_MODES = [
@dataclass @dataclass
class ModelCache: class ModelCache:
"""Cache structure for model data with extensible sorting""" """Cache structure for model data with extensible sorting."""
raw_data: List[Dict] raw_data: List[Dict]
folders: List[str] folders: List[str]
version_index: Dict[int, Dict] = field(default_factory=dict)
def __post_init__(self): def __post_init__(self):
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
@@ -28,6 +30,58 @@ class ModelCache:
self._last_sorted_data: List[Dict] = [] self._last_sorted_data: List[Dict] = []
# Default sort on init # Default sort on init
asyncio.create_task(self.resort()) asyncio.create_task(self.resort())
self.rebuild_version_index()
@staticmethod
def _normalize_version_id(value: Any) -> Optional[int]:
"""Normalize a potential version identifier into an integer."""
if isinstance(value, int):
return value
if isinstance(value, str):
try:
return int(value)
except ValueError:
return None
return None
def rebuild_version_index(self) -> None:
"""Rebuild the version index from the current raw data."""
self.version_index = {}
for item in self.raw_data:
self.add_to_version_index(item)
def add_to_version_index(self, item: Dict) -> None:
"""Register a cache item in the version index if possible."""
civitai_data = item.get('civitai') if isinstance(item, dict) else None
if not isinstance(civitai_data, dict):
return
version_id = self._normalize_version_id(civitai_data.get('id'))
if version_id is None:
return
self.version_index[version_id] = item
def remove_from_version_index(self, item: Dict) -> None:
"""Remove a cache item from the version index if present."""
civitai_data = item.get('civitai') if isinstance(item, dict) else None
if not isinstance(civitai_data, dict):
return
version_id = self._normalize_version_id(civitai_data.get('id'))
if version_id is None:
return
existing = self.version_index.get(version_id)
if existing is item or (
isinstance(existing, dict)
and existing.get('file_path') == item.get('file_path')
):
self.version_index.pop(version_id, None)
async def resort(self): async def resort(self):
"""Resort cached data according to last sort mode if set""" """Resort cached data according to last sort mode if set"""
@@ -41,6 +95,7 @@ class ModelCache:
all_folders = set(l['folder'] for l in self.raw_data) all_folders = set(l['folder'] for l in self.raw_data)
self.folders = sorted(list(all_folders), key=lambda x: x.lower()) self.folders = sorted(list(all_folders), key=lambda x: x.lower())
self.rebuild_version_index()
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]: def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
"""Sort data by sort_key and order""" """Sort data by sort_key and order"""

View File

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from ..utils.utils import calculate_relative_path_for_model, remove_empty_dirs from ..utils.utils import calculate_relative_path_for_model, remove_empty_dirs
from ..utils.constants import AUTO_ORGANIZE_BATCH_SIZE from ..utils.constants import AUTO_ORGANIZE_BATCH_SIZE
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -114,7 +114,8 @@ class ModelFileService:
raise ValueError('No model roots configured') raise ValueError('No model roots configured')
# Check if flat structure is configured for this model type # Check if flat structure is configured for this model type
path_template = settings.get_download_path_template(self.model_type) settings_manager = get_settings_manager()
path_template = settings_manager.get_download_path_template(self.model_type)
result.is_flat_structure = not path_template result.is_flat_structure = not path_template
# Initialize tracking # Initialize tracking

View File

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import json import json
import logging import logging
from typing import Optional, Dict, Tuple, Any from typing import Optional, Dict, Tuple, Any, List
from .downloader import get_downloader from .downloader import get_downloader
try: try:
@@ -61,6 +61,11 @@ class ModelMetadataProvider(ABC):
"""Fetch model version metadata""" """Fetch model version metadata"""
pass pass
@abstractmethod
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Fetch models owned by the specified user"""
pass
class CivitaiModelMetadataProvider(ModelMetadataProvider): class CivitaiModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses Civitai API for metadata""" """Provider that uses Civitai API for metadata"""
@@ -79,6 +84,9 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
return await self.client.get_model_version_info(version_id) return await self.client.get_model_version_info(version_id)
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
return await self.client.get_user_models(username)
class CivArchiveModelMetadataProvider(ModelMetadataProvider): class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses CivArchive HTML page parsing for metadata""" """Provider that uses CivArchive HTML page parsing for metadata"""
@@ -197,6 +205,10 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Not supported by CivArchive provider - requires both model_id and version_id""" """Not supported by CivArchive provider - requires both model_id and version_id"""
return None, "CivArchive provider requires both model_id and version_id" return None, "CivArchive provider requires both model_id and version_id"
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Not supported by CivArchive provider"""
return None
class SQLiteModelMetadataProvider(ModelMetadataProvider): class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses SQLite database for metadata""" """Provider that uses SQLite database for metadata"""
@@ -344,6 +356,10 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
version_data = await self._get_version_with_model_data(db, model_id, version_id) version_data = await self._get_version_with_model_data(db, model_id, version_id)
return version_data, None return version_data, None
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
"""Listing models by username is not supported for archive database"""
return None
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]: async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
"""Helper to build version data with model information""" """Helper to build version data with model information"""
# Get version details # Get version details
@@ -481,6 +497,17 @@ class FallbackMetadataProvider(ModelMetadataProvider):
continue continue
return None, "No provider could retrieve the data" return None, "No provider could retrieve the data"
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
for provider in self.providers:
try:
result = await provider.get_user_models(username)
if result is not None:
return result
except Exception as e:
logger.debug(f"Provider failed for get_user_models: {e}")
continue
return None
class ModelMetadataProviderManager: class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers""" """Manager for selecting and using model metadata providers"""
@@ -523,6 +550,11 @@ class ModelMetadataProviderManager:
provider = self._get_provider(provider_name) provider = self._get_provider(provider_name)
return await provider.get_model_version_info(version_id) return await provider.get_model_version_info(version_id)
async def get_user_models(self, username: str, provider_name: str = None) -> Optional[List[Dict]]:
"""Fetch models owned by the specified user"""
provider = self._get_provider(provider_name)
return await provider.get_user_models(username)
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider: def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
"""Get provider by name or default provider""" """Get provider by name or default provider"""
if provider_name and provider_name in self.providers: if provider_name and provider_name in self.providers:

View File

@@ -634,6 +634,7 @@ class ModelScanner:
if model_data: if model_data:
# Add to cache # Add to cache
self._cache.raw_data.append(model_data) self._cache.raw_data.append(model_data)
self._cache.add_to_version_index(model_data)
# Update hash index if available # Update hash index if available
if 'sha256' in model_data and 'file_path' in model_data: if 'sha256' in model_data and 'file_path' in model_data:
@@ -662,6 +663,8 @@ class ModelScanner:
try: try:
model_to_remove = path_to_item[path] model_to_remove = path_to_item[path]
self._cache.remove_from_version_index(model_to_remove)
# Update tags count # Update tags count
for tag in model_to_remove.get('tags', []): for tag in model_to_remove.get('tags', []):
if tag in self._tags_count: if tag in self._tags_count:
@@ -684,6 +687,8 @@ class ModelScanner:
all_folders = set(item.get('folder', '') for item in self._cache.raw_data) all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
self._cache.rebuild_version_index()
# Resort cache # Resort cache
await self._cache.resort() await self._cache.resort()
@@ -829,6 +834,8 @@ class ModelScanner:
else: else:
self._cache.raw_data = list(scan_result.raw_data) self._cache.raw_data = list(scan_result.raw_data)
self._cache.rebuild_version_index()
await self._cache.resort() await self._cache.resort()
async def _gather_model_data( async def _gather_model_data(
@@ -934,6 +941,7 @@ class ModelScanner:
# Add to cache # Add to cache
self._cache.raw_data.append(metadata_dict) self._cache.raw_data.append(metadata_dict)
self._cache.add_to_version_index(metadata_dict)
# Resort cache data # Resort cache data
await self._cache.resort() await self._cache.resort()
@@ -1076,6 +1084,9 @@ class ModelScanner:
cache = await self.get_cached_data() cache = await self.get_cached_data()
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None) existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item:
cache.remove_from_version_index(existing_item)
if existing_item and 'tags' in existing_item: if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []): for tag in existing_item.get('tags', []):
if tag in self._tags_count: if tag in self._tags_count:
@@ -1106,6 +1117,7 @@ class ModelScanner:
) )
cache.raw_data.append(cache_entry) cache.raw_data.append(cache_entry)
cache.add_to_version_index(cache_entry)
sha_value = cache_entry.get('sha256') sha_value = cache_entry.get('sha256')
if sha_value: if sha_value:
@@ -1117,6 +1129,8 @@ class ModelScanner:
for tag in cache_entry.get('tags', []): for tag in cache_entry.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
cache.rebuild_version_index()
await cache.resort() await cache.resort()
if cache_modified: if cache_modified:
@@ -1339,6 +1353,7 @@ class ModelScanner:
# Update hash index # Update hash index
for model in models_to_remove: for model in models_to_remove:
file_path = model['file_path'] file_path = model['file_path']
self._cache.remove_from_version_index(model)
if hasattr(self, '_hash_index') and self._hash_index: if hasattr(self, '_hash_index') and self._hash_index:
# Get the hash and filename before removal for duplicate checking # Get the hash and filename before removal for duplicate checking
file_name = os.path.splitext(os.path.basename(file_path))[0] file_name = os.path.splitext(os.path.basename(file_path))[0]
@@ -1354,6 +1369,7 @@ class ModelScanner:
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths] self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths]
# Resort cache # Resort cache
self._cache.rebuild_version_index()
await self._cache.resort() await self._cache.resort()
await self._persist_current_cache() await self._persist_current_cache()
@@ -1393,16 +1409,17 @@ class ModelScanner:
Returns: Returns:
bool: True if the model version exists, False otherwise bool: True if the model version exists, False otherwise
""" """
try:
normalized_id = int(model_version_id)
except (TypeError, ValueError):
return False
try: try:
cache = await self.get_cached_data() cache = await self.get_cached_data()
if not cache or not cache.raw_data: if not cache:
return False return False
for item in cache.raw_data: return normalized_id in cache.version_index
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
return True
return False
except Exception as e: except Exception as e:
logger.error(f"Error checking model version existence: {e}") logger.error(f"Error checking model version existence: {e}")
return False return False

View File

@@ -351,7 +351,7 @@ class PersistentModelCache:
def get_persistent_cache() -> PersistentModelCache: def get_persistent_cache() -> PersistentModelCache:
from .settings_manager import settings as settings_service # Local import to avoid cycles from .settings_manager import get_settings_manager # Local import to avoid cycles
library_name = settings_service.get_active_library_name() library_name = get_settings_manager().get_active_library_name()
return PersistentModelCache.get_default(library_name) return PersistentModelCache.get_default(library_name)

View File

@@ -5,8 +5,10 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import Awaitable, Callable, Dict, Optional, Sequence from typing import Awaitable, Callable, Dict, Optional, Sequence
from urllib.parse import urlparse
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
from ..utils.civitai_utils import rewrite_preview_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -45,23 +47,59 @@ class PreviewAssetService:
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_dir = os.path.dirname(metadata_path) preview_dir = os.path.dirname(metadata_path)
is_video = first_preview.get("type") == "video" is_video = first_preview.get("type") == "video"
preview_url = first_preview.get("url")
if not preview_url:
return
def extension_from_url(url: str, fallback: str) -> str:
try:
parsed = urlparse(url)
except ValueError:
return fallback
ext = os.path.splitext(parsed.path)[1]
return ext or fallback
downloader = await self._downloader_factory()
if is_video: if is_video:
extension = ".mp4" extension = extension_from_url(preview_url, ".mp4")
preview_path = os.path.join(preview_dir, base_name + extension) preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory() rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="video")
success, result = await downloader.download_file(
first_preview["url"], preview_path, use_auth=False attempt_urls = []
) if rewritten:
if success: attempt_urls.append(rewritten_url)
local_metadata["preview_url"] = preview_path.replace(os.sep, "/") attempt_urls.append(preview_url)
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
seen: set[str] = set()
for candidate in attempt_urls:
if not candidate or candidate in seen:
continue
seen.add(candidate)
success, _ = await downloader.download_file(candidate, preview_path, use_auth=False)
if success:
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
return
else: else:
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="image")
if rewritten:
extension = extension_from_url(preview_url, ".png")
preview_path = os.path.join(preview_dir, base_name + extension)
success, _ = await downloader.download_file(
rewritten_url, preview_path, use_auth=False
)
if success:
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
return
extension = ".webp" extension = ".webp"
preview_path = os.path.join(preview_dir, base_name + extension) preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory()
success, content, _headers = await downloader.download_to_memory( success, content, _headers = await downloader.download_to_memory(
first_preview["url"], use_auth=False preview_url, use_auth=False
) )
if not success: if not success:
return return

View File

@@ -3,6 +3,7 @@ import json
import os import os
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from threading import Lock
from typing import Any, Dict, Iterable, List, Mapping, Optional from typing import Any, Dict, Iterable, List, Mapping, Optional
from ..utils.settings_paths import ensure_settings_file from ..utils.settings_paths import ensure_settings_file
@@ -688,4 +689,38 @@ class SettingsManager:
return templates.get(model_type, '{base_model}/{first_tag}') return templates.get(model_type, '{base_model}/{first_tag}')
settings = SettingsManager()
_SETTINGS_MANAGER: Optional["SettingsManager"] = None
_SETTINGS_MANAGER_LOCK = Lock()
# Legacy module-level alias for backwards compatibility with callers that
# monkeypatch ``py.services.settings_manager.settings`` during tests.
settings: Optional["SettingsManager"] = None
def get_settings_manager() -> "SettingsManager":
"""Return the lazily initialised global :class:`SettingsManager`."""
global _SETTINGS_MANAGER, settings
if settings is not None:
return settings
if _SETTINGS_MANAGER is None:
with _SETTINGS_MANAGER_LOCK:
if _SETTINGS_MANAGER is None:
_SETTINGS_MANAGER = SettingsManager()
settings = _SETTINGS_MANAGER
return _SETTINGS_MANAGER
def reset_settings_manager() -> None:
"""Reset the cached settings manager instance.
Primarily intended for tests so they can configure the settings
directory before the manager touches the filesystem.
"""
global _SETTINGS_MANAGER, settings
with _SETTINGS_MANAGER_LOCK:
_SETTINGS_MANAGER = None
settings = None

View File

@@ -6,6 +6,7 @@ import logging
from typing import Any, Dict, Optional, Protocol, Sequence from typing import Any, Dict, Optional, Protocol, Sequence
from ..metadata_sync_service import MetadataSyncService from ..metadata_sync_service import MetadataSyncService
from ...utils.metadata_manager import MetadataManager
class MetadataRefreshProgressReporter(Protocol): class MetadataRefreshProgressReporter(Protocol):
@@ -70,6 +71,7 @@ class BulkMetadataRefreshUseCase:
for model in to_process: for model in to_process:
try: try:
original_name = model.get("model_name") original_name = model.get("model_name")
await MetadataManager.hydrate_model_data(model)
result, _ = await self._metadata_sync.fetch_and_update_model( result, _ = await self._metadata_sync.fetch_and_update_model(
sha256=model["sha256"], sha256=model["sha256"],
file_path=model["file_path"], file_path=model["file_path"],

48
py/utils/civitai_utils.py Normal file
View File

@@ -0,0 +1,48 @@
"""Utilities for working with Civitai assets."""
from __future__ import annotations
from urllib.parse import urlparse, urlunparse
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
"""Rewrite Civitai preview URLs to use optimized renditions.
Args:
source_url: Original preview URL from the Civitai API.
media_type: Optional media type hint (e.g. ``"image"`` or ``"video"``).
Returns:
A tuple of the potentially rewritten URL and a flag indicating whether the
replacement occurred. When the URL is not rewritten, the original value is
returned with ``False``.
"""
if not source_url:
return source_url, False
try:
parsed = urlparse(source_url)
except ValueError:
return source_url, False
if parsed.netloc.lower() != "image.civitai.com":
return source_url, False
replacement = "/width=450,optimized=true"
if (media_type or "").lower() == "video":
replacement = "/transcode=true,width=450,optimized=true"
if "/original=true" not in parsed.path:
return source_url, False
updated_path = parsed.path.replace("/original=true", replacement, 1)
if updated_path == parsed.path:
return source_url, False
rewritten = urlunparse(parsed._replace(path=updated_path))
print(rewritten)
return rewritten, True
__all__ = ["rewrite_preview_url"]

View File

@@ -48,6 +48,13 @@ SUPPORTED_MEDIA_EXTENSIONS = {
# Valid Lora types # Valid Lora types
VALID_LORA_TYPES = ['lora', 'locon', 'dora'] VALID_LORA_TYPES = ['lora', 'locon', 'dora']
# Supported Civitai model types for user model queries (case-insensitive)
CIVITAI_USER_MODEL_TYPES = [
*VALID_LORA_TYPES,
'textualinversion',
'checkpoint',
]
# Auto-organize settings # Auto-organize settings
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system

View File

@@ -18,7 +18,7 @@ from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater from .example_images_metadata import MetadataUpdater
from ..services.downloader import get_downloader from ..services.downloader import get_downloader
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
class ExampleImagesDownloadError(RuntimeError): class ExampleImagesDownloadError(RuntimeError):
@@ -107,7 +107,7 @@ class DownloadManager:
self._state_lock = state_lock or asyncio.Lock() self._state_lock = state_lock or asyncio.Lock()
def _resolve_output_dir(self, library_name: str | None = None) -> str: def _resolve_output_dir(self, library_name: str | None = None) -> str:
base_path = settings.get('example_images_path') base_path = get_settings_manager().get('example_images_path')
if not base_path: if not base_path:
return '' return ''
return ensure_library_root_exists(library_name) return ensure_library_root_exists(library_name)
@@ -126,7 +126,8 @@ class DownloadManager:
model_types = data.get('model_types', ['lora', 'checkpoint']) model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2)) delay = float(data.get('delay', 0.2))
base_path = settings.get('example_images_path') settings_manager = get_settings_manager()
base_path = settings_manager.get('example_images_path')
if not base_path: if not base_path:
error_msg = 'Example images path not configured in settings' error_msg = 'Example images path not configured in settings'
@@ -138,7 +139,7 @@ class DownloadManager:
} }
raise DownloadConfigurationError(error_msg) raise DownloadConfigurationError(error_msg)
active_library = settings.get_active_library_name() active_library = get_settings_manager().get_active_library_name()
output_dir = self._resolve_output_dir(active_library) output_dir = self._resolve_output_dir(active_library)
if not output_dir: if not output_dir:
raise DownloadConfigurationError('Example images path not configured in settings') raise DownloadConfigurationError('Example images path not configured in settings')
@@ -151,7 +152,7 @@ class DownloadManager:
progress_file = os.path.join(output_dir, '.download_progress.json') progress_file = os.path.join(output_dir, '.download_progress.json')
progress_source = progress_file progress_source = progress_file
if uses_library_scoped_folders(): if uses_library_scoped_folders():
legacy_root = settings.get('example_images_path') or '' legacy_root = get_settings_manager().get('example_images_path') or ''
legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else '' legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else ''
if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file): if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file):
try: try:
@@ -555,11 +556,12 @@ class DownloadManager:
if not model_hashes: if not model_hashes:
raise DownloadConfigurationError('Missing model_hashes parameter') raise DownloadConfigurationError('Missing model_hashes parameter')
base_path = settings.get('example_images_path') settings_manager = get_settings_manager()
base_path = settings_manager.get('example_images_path')
if not base_path: if not base_path:
raise DownloadConfigurationError('Example images path not configured in settings') raise DownloadConfigurationError('Example images path not configured in settings')
active_library = settings.get_active_library_name() active_library = settings_manager.get_active_library_name()
output_dir = self._resolve_output_dir(active_library) output_dir = self._resolve_output_dir(active_library)
if not output_dir: if not output_dir:
raise DownloadConfigurationError('Example images path not configured in settings') raise DownloadConfigurationError('Example images path not configured in settings')

View File

@@ -3,7 +3,7 @@ import os
import sys import sys
import subprocess import subprocess
from aiohttp import web from aiohttp import web
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
from ..utils.example_images_paths import ( from ..utils.example_images_paths import (
get_model_folder, get_model_folder,
get_model_relative_path, get_model_relative_path,
@@ -37,7 +37,8 @@ class ExampleImagesFileManager:
}, status=400) }, status=400)
# Get example images path from settings # Get example images path from settings
example_images_path = settings.get('example_images_path') settings_manager = get_settings_manager()
example_images_path = settings_manager.get('example_images_path')
if not example_images_path: if not example_images_path:
return web.json_response({ return web.json_response({
'success': False, 'success': False,
@@ -109,7 +110,8 @@ class ExampleImagesFileManager:
}, status=400) }, status=400)
# Get example images path from settings # Get example images path from settings
example_images_path = settings.get('example_images_path') settings_manager = get_settings_manager()
example_images_path = settings_manager.get('example_images_path')
if not example_images_path: if not example_images_path:
return web.json_response({ return web.json_response({
'success': False, 'success': False,
@@ -183,7 +185,8 @@ class ExampleImagesFileManager:
}, status=400) }, status=400)
# Get example images path from settings # Get example images path from settings
example_images_path = settings.get('example_images_path') settings_manager = get_settings_manager()
example_images_path = settings_manager.get('example_images_path')
if not example_images_path: if not example_images_path:
return web.json_response({ return web.json_response({
'has_images': False 'has_images': False

View File

@@ -1,12 +1,13 @@
import logging import logging
import os import os
import re import re
from typing import TYPE_CHECKING, Any, Dict, Optional
from ..recipes.constants import GEN_PARAM_KEYS from ..recipes.constants import GEN_PARAM_KEYS
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
from ..services.metadata_sync_service import MetadataSyncService from ..services.metadata_sync_service import MetadataSyncService
from ..services.preview_asset_service import PreviewAssetService from ..services.preview_asset_service import PreviewAssetService
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
from ..services.downloader import get_downloader from ..services.downloader import get_downloader
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
@@ -20,13 +21,46 @@ _preview_service = PreviewAssetService(
exif_utils=ExifUtils, exif_utils=ExifUtils,
) )
_metadata_sync_service = MetadataSyncService( _metadata_sync_service: MetadataSyncService | None = None
metadata_manager=MetadataManager, _metadata_sync_service_settings: Optional["SettingsManager"] = None
preview_service=_preview_service,
settings=settings, if TYPE_CHECKING: # pragma: no cover - import for type checkers only
default_metadata_provider_factory=get_default_metadata_provider, from ..services.settings_manager import SettingsManager
metadata_provider_selector=get_metadata_provider,
)
def _build_metadata_sync_service(settings_manager: "SettingsManager") -> MetadataSyncService:
"""Construct a metadata sync service bound to the provided settings."""
return MetadataSyncService(
metadata_manager=MetadataManager,
preview_service=_preview_service,
settings=settings_manager,
default_metadata_provider_factory=get_default_metadata_provider,
metadata_provider_selector=get_metadata_provider,
)
def _get_metadata_sync_service() -> MetadataSyncService:
"""Return the shared metadata sync service, initialising it lazily."""
global _metadata_sync_service, _metadata_sync_service_settings
settings_manager = get_settings_manager()
if isinstance(_metadata_sync_service, MetadataSyncService):
if _metadata_sync_service_settings is not settings_manager:
_metadata_sync_service = _build_metadata_sync_service(settings_manager)
_metadata_sync_service_settings = settings_manager
elif _metadata_sync_service is None:
_metadata_sync_service = _build_metadata_sync_service(settings_manager)
_metadata_sync_service_settings = settings_manager
else:
# Tests may inject stand-ins that do not match the sync service type. Preserve
# those injections while still updating our cached settings reference so the
# next real service instantiation uses the current configuration.
_metadata_sync_service_settings = settings_manager
return _metadata_sync_service
class MetadataUpdater: class MetadataUpdater:
@@ -71,7 +105,8 @@ class MetadataUpdater:
async def update_cache_func(old_path, new_path, metadata): async def update_cache_func(old_path, new_path, metadata):
return await scanner.update_single_model_cache(old_path, new_path, metadata) return await scanner.update_single_model_cache(old_path, new_path, metadata)
success, error = await _metadata_sync_service.fetch_and_update_model( await MetadataManager.hydrate_model_data(model_data)
success, error = await _get_metadata_sync_service().fetch_and_update_model(
sha256=model_hash, sha256=model_hash,
file_path=file_path, file_path=file_path,
model_data=model_data, model_data=model_data,
@@ -151,16 +186,16 @@ class MetadataUpdater:
if is_supported: if is_supported:
local_images_paths.append(file_path) local_images_paths.append(file_path)
await MetadataManager.hydrate_model_data(model)
civitai_data = model.setdefault('civitai', {})
# Check if metadata update is needed (no civitai field or empty images) # Check if metadata update is needed (no civitai field or empty images)
needs_update = not model.get('civitai') or not model.get('civitai', {}).get('images') needs_update = not civitai_data or not civitai_data.get('images')
if needs_update and local_images_paths: if needs_update and local_images_paths:
logger.debug(f"Found {len(local_images_paths)} local example images for {model.get('model_name')}, updating metadata") logger.debug(f"Found {len(local_images_paths)} local example images for {model.get('model_name')}, updating metadata")
# Create or get civitai field # Create or get civitai field
if not model.get('civitai'):
model['civitai'] = {}
# Create images array # Create images array
images = [] images = []
@@ -195,16 +230,13 @@ class MetadataUpdater:
images.append(image_entry) images.append(image_entry)
# Update the model's civitai.images field # Update the model's civitai.images field
model['civitai']['images'] = images civitai_data['images'] = images
# Save metadata to .metadata.json file # Save metadata to .metadata.json file
file_path = model.get('file_path') file_path = model.get('file_path')
try: try:
# Create a copy of model data without 'folder' field
model_copy = model.copy() model_copy = model.copy()
model_copy.pop('folder', None) model_copy.pop('folder', None)
# Write metadata to file
await MetadataManager.save_metadata(file_path, model_copy) await MetadataManager.save_metadata(file_path, model_copy)
logger.info(f"Saved metadata for {model.get('model_name')}") logger.info(f"Saved metadata for {model.get('model_name')}")
except Exception as e: except Exception as e:
@@ -237,16 +269,13 @@ class MetadataUpdater:
tuple: (regular_images, custom_images) - Both image arrays tuple: (regular_images, custom_images) - Both image arrays
""" """
try: try:
# Ensure civitai field exists in model_data await MetadataManager.hydrate_model_data(model_data)
if not model_data.get('civitai'): civitai_data = model_data.setdefault('civitai', {})
model_data['civitai'] = {} custom_images = civitai_data.get('customImages')
# Ensure customImages array exists if not isinstance(custom_images, list):
if not model_data['civitai'].get('customImages'): custom_images = []
model_data['civitai']['customImages'] = [] civitai_data['customImages'] = custom_images
# Get current customImages array
custom_images = model_data['civitai']['customImages']
# Add new image entry for each imported file # Add new image entry for each imported file
for path_tuple in newly_imported_paths: for path_tuple in newly_imported_paths:
@@ -304,11 +333,8 @@ class MetadataUpdater:
file_path = model_data.get('file_path') file_path = model_data.get('file_path')
if file_path: if file_path:
try: try:
# Create a copy of model data without 'folder' field
model_copy = model_data.copy() model_copy = model_data.copy()
model_copy.pop('folder', None) model_copy.pop('folder', None)
# Write metadata to file
await MetadataManager.save_metadata(file_path, model_copy) await MetadataManager.save_metadata(file_path, model_copy)
logger.info(f"Saved metadata for {model_data.get('model_name')}") logger.info(f"Saved metadata for {model_data.get('model_name')}")
except Exception as e: except Exception as e:
@@ -319,7 +345,7 @@ class MetadataUpdater:
await scanner.update_single_model_cache(file_path, file_path, model_data) await scanner.update_single_model_cache(file_path, file_path, model_data)
# Get regular images array (might be None) # Get regular images array (might be None)
regular_images = model_data['civitai'].get('images', []) regular_images = civitai_data.get('images', [])
# Return both image arrays # Return both image arrays
return regular_images, custom_images return regular_images, custom_images

View File

@@ -3,7 +3,7 @@ import logging
import os import os
import re import re
import json import json
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..utils.example_images_paths import iter_library_roots from ..utils.example_images_paths import iter_library_roots
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
@@ -14,6 +14,25 @@ logger = logging.getLogger(__name__)
CURRENT_NAMING_VERSION = 2 # Increment this when naming conventions change CURRENT_NAMING_VERSION = 2 # Increment this when naming conventions change
class _SettingsProxy:
def __init__(self):
self._manager = None
def _resolve(self):
if self._manager is None:
self._manager = get_settings_manager()
return self._manager
def get(self, *args, **kwargs):
return self._resolve().get(*args, **kwargs)
def __getattr__(self, item):
return getattr(self._resolve(), item)
settings = _SettingsProxy()
class ExampleImagesMigration: class ExampleImagesMigration:
"""Handles migrations for example images naming conventions""" """Handles migrations for example images naming conventions"""

View File

@@ -8,7 +8,7 @@ import re
import shutil import shutil
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
_HEX_PATTERN = re.compile(r"[a-fA-F0-9]{64}") _HEX_PATTERN = re.compile(r"[a-fA-F0-9]{64}")
@@ -18,7 +18,8 @@ logger = logging.getLogger(__name__)
def _get_configured_libraries() -> List[str]: def _get_configured_libraries() -> List[str]:
"""Return configured library names if multi-library support is enabled.""" """Return configured library names if multi-library support is enabled."""
libraries = settings.get("libraries") settings_manager = get_settings_manager()
libraries = settings_manager.get("libraries")
if isinstance(libraries, dict) and libraries: if isinstance(libraries, dict) and libraries:
return list(libraries.keys()) return list(libraries.keys())
return [] return []
@@ -27,7 +28,8 @@ def _get_configured_libraries() -> List[str]:
def get_example_images_root() -> str: def get_example_images_root() -> str:
"""Return the root directory configured for example images.""" """Return the root directory configured for example images."""
root = settings.get("example_images_path") or "" settings_manager = get_settings_manager()
root = settings_manager.get("example_images_path") or ""
return os.path.abspath(root) if root else "" return os.path.abspath(root) if root else ""
@@ -41,7 +43,8 @@ def uses_library_scoped_folders() -> bool:
def sanitize_library_name(library_name: Optional[str]) -> str: def sanitize_library_name(library_name: Optional[str]) -> str:
"""Return a filesystem safe library name.""" """Return a filesystem safe library name."""
name = library_name or settings.get_active_library_name() or "default" settings_manager = get_settings_manager()
name = library_name or settings_manager.get_active_library_name() or "default"
safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", name) safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", name)
return safe_name or "default" return safe_name or "default"
@@ -161,11 +164,13 @@ def iter_library_roots() -> Iterable[Tuple[str, str]]:
results.append((library, get_library_root(library))) results.append((library, get_library_root(library)))
else: else:
# Fall back to the active library to avoid skipping migrations/cleanup # Fall back to the active library to avoid skipping migrations/cleanup
active = settings.get_active_library_name() or "default" settings_manager = get_settings_manager()
active = settings_manager.get_active_library_name() or "default"
results.append((active, get_library_root(active))) results.append((active, get_library_root(active)))
return results return results
active = settings.get_active_library_name() or "default" settings_manager = get_settings_manager()
active = settings_manager.get_active_library_name() or "default"
return [(active, root)] return [(active, root)]

View File

@@ -6,7 +6,7 @@ import string
from aiohttp import web from aiohttp import web
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
from ..utils.example_images_paths import get_model_folder, get_model_relative_path from ..utils.example_images_paths import get_model_folder, get_model_relative_path
from .example_images_metadata import MetadataUpdater from .example_images_metadata import MetadataUpdater
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
@@ -318,7 +318,7 @@ class ExampleImagesProcessor:
try: try:
# Get example images path # Get example images path
example_images_path = settings.get('example_images_path') example_images_path = get_settings_manager().get('example_images_path')
if not example_images_path: if not example_images_path:
raise ExampleImagesValidationError('No example images path configured') raise ExampleImagesValidationError('No example images path configured')
@@ -442,7 +442,7 @@ class ExampleImagesProcessor:
}, status=400) }, status=400)
# Get example images path # Get example images path
example_images_path = settings.get('example_images_path') example_images_path = get_settings_manager().get('example_images_path')
if not example_images_path: if not example_images_path:
return web.json_response({ return web.json_response({
'success': False, 'success': False,
@@ -475,15 +475,17 @@ class ExampleImagesProcessor:
'error': f"Model with hash {model_hash} not found in cache" 'error': f"Model with hash {model_hash} not found in cache"
}, status=404) }, status=404)
# Check if model has custom images await MetadataManager.hydrate_model_data(model_data)
if not model_data.get('civitai', {}).get('customImages'): civitai_data = model_data.setdefault('civitai', {})
custom_images = civitai_data.get('customImages')
if not isinstance(custom_images, list) or not custom_images:
return web.json_response({ return web.json_response({
'success': False, 'success': False,
'error': f"Model has no custom images" 'error': f"Model has no custom images"
}, status=404) }, status=404)
# Find the custom image with matching short_id # Find the custom image with matching short_id
custom_images = model_data['civitai']['customImages']
matching_image = None matching_image = None
new_custom_images = [] new_custom_images = []
@@ -527,17 +529,15 @@ class ExampleImagesProcessor:
logger.warning(f"File for custom example with id {short_id} not found, but metadata will still be updated") logger.warning(f"File for custom example with id {short_id} not found, but metadata will still be updated")
# Update metadata # Update metadata
model_data['civitai']['customImages'] = new_custom_images civitai_data['customImages'] = new_custom_images
model_data.setdefault('civitai', {})['customImages'] = new_custom_images
# Save updated metadata to file # Save updated metadata to file
file_path = model_data.get('file_path') file_path = model_data.get('file_path')
if file_path: if file_path:
try: try:
# Create a copy of model data without 'folder' field
model_copy = model_data.copy() model_copy = model_data.copy()
model_copy.pop('folder', None) model_copy.pop('folder', None)
# Write metadata to file
await MetadataManager.save_metadata(file_path, model_copy) await MetadataManager.save_metadata(file_path, model_copy)
logger.debug(f"Saved updated metadata for {model_data.get('model_name')}") logger.debug(f"Saved updated metadata for {model_data.get('model_name')}")
except Exception as e: except Exception as e:
@@ -551,7 +551,7 @@ class ExampleImagesProcessor:
await scanner.update_single_model_cache(file_path, file_path, model_data) await scanner.update_single_model_cache(file_path, file_path, model_data)
# Get regular images array (might be None) # Get regular images array (might be None)
regular_images = model_data['civitai'].get('images', []) regular_images = civitai_data.get('images', [])
return web.json_response({ return web.json_response({
'success': True, 'success': True,

View File

@@ -2,7 +2,7 @@ from datetime import datetime
import os import os
import json import json
import logging import logging
from typing import Dict, Optional, Type, Union from typing import Any, Dict, Optional, Type, Union
from .models import BaseModelMetadata, LoraMetadata from .models import BaseModelMetadata, LoraMetadata
from .file_utils import normalize_path, find_preview_file, calculate_sha256 from .file_utils import normalize_path, find_preview_file, calculate_sha256
@@ -54,6 +54,70 @@ class MetadataManager:
logger.error(f"{error_type} in metadata file: {metadata_path}. Error: {str(e)}. Skipping model to preserve existing data.") logger.error(f"{error_type} in metadata file: {metadata_path}. Error: {str(e)}. Skipping model to preserve existing data.")
return None, True # should_skip = True return None, True # should_skip = True
@staticmethod
async def load_metadata_payload(file_path: str) -> Dict:
"""
Load metadata and return it as a dictionary, including any unknown fields.
Falls back to reading the raw JSON file if parsing into a model class fails.
"""
payload: Dict = {}
metadata_obj, should_skip = await MetadataManager.load_metadata(file_path)
if metadata_obj:
payload = metadata_obj.to_dict()
unknown_fields = getattr(metadata_obj, "_unknown_fields", None)
if isinstance(unknown_fields, dict):
payload.update(unknown_fields)
else:
if not should_skip:
metadata_path = (
file_path
if file_path.endswith(".metadata.json")
else f"{os.path.splitext(file_path)[0]}.metadata.json"
)
if os.path.exists(metadata_path):
try:
with open(metadata_path, "r", encoding="utf-8") as handle:
raw = json.load(handle)
if isinstance(raw, dict):
payload = raw
except json.JSONDecodeError:
logger.warning(
"Failed to parse metadata file %s while loading payload",
metadata_path,
)
except Exception as exc: # pragma: no cover - defensive logging
logger.warning("Failed to read metadata file %s: %s", metadata_path, exc)
if not isinstance(payload, dict):
payload = {}
if file_path:
payload.setdefault("file_path", normalize_path(file_path))
return payload
@staticmethod
async def hydrate_model_data(model_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Replace the provided model data with the authoritative payload from disk.
Preserves the cached folder entry if present.
"""
file_path = model_data.get("file_path")
if not file_path:
return model_data
folder = model_data.get("folder")
payload = await MetadataManager.load_metadata_payload(file_path)
if folder is not None:
payload["folder"] = folder
model_data.clear()
model_data.update(payload)
return model_data
@staticmethod @staticmethod
async def save_metadata(path: str, metadata: Union[BaseModelMetadata, Dict]) -> bool: async def save_metadata(path: str, metadata: Union[BaseModelMetadata, Dict]) -> bool:
""" """

View File

@@ -65,6 +65,12 @@ def ensure_settings_file(logger: Optional[logging.Logger] = None) -> str:
logger = logger or _LOGGER logger = logger or _LOGGER
target_path = get_settings_file_path(create_dir=True) target_path = get_settings_file_path(create_dir=True)
preferred_dir = user_config_dir(APP_NAME, appauthor=False)
preferred_path = os.path.join(preferred_dir, "settings.json")
if os.path.abspath(target_path) != os.path.abspath(preferred_path):
os.makedirs(preferred_dir, exist_ok=True)
target_path = preferred_path
legacy_path = get_legacy_settings_path() legacy_path = get_legacy_settings_path()
if os.path.exists(legacy_path) and not os.path.exists(target_path): if os.path.exists(legacy_path) and not os.path.exists(target_path):

View File

@@ -3,7 +3,7 @@ import os
from typing import Dict from typing import Dict
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..config import config from ..config import config
from ..services.settings_manager import settings from ..services.settings_manager import get_settings_manager
from .constants import CIVITAI_MODEL_TAGS from .constants import CIVITAI_MODEL_TAGS
import asyncio import asyncio
@@ -143,7 +143,8 @@ def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora'
Relative path string (empty string for flat structure) Relative path string (empty string for flat structure)
""" """
# Get path template from settings for specific model type # Get path template from settings for specific model type
path_template = settings.get_download_path_template(model_type) settings_manager = get_settings_manager()
path_template = settings_manager.get_download_path_template(model_type)
# If template is empty, return empty path (flat structure) # If template is empty, return empty path (flat structure)
if not path_template: if not path_template:
@@ -166,7 +167,7 @@ def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora'
model_tags = model_data.get('tags', []) model_tags = model_data.get('tags', [])
# Apply mapping if available # Apply mapping if available
base_model_mappings = settings.get('base_model_path_mappings', {}) base_model_mappings = settings_manager.get('base_model_path_mappings', {})
mapped_base_model = base_model_mappings.get(base_model, base_model) mapped_base_model = base_model_mappings.get(base_model, base_model)
# Find the first Civitai model tag that exists in model_tags # Find the first Civitai model tag that exists in model_tags

View File

@@ -1,7 +1,7 @@
[project] [project]
name = "comfyui-lora-manager" name = "comfyui-lora-manager"
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!" description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
version = "0.9.6" version = "0.9.7"
license = {file = "LICENSE"} license = {file = "LICENSE"}
dependencies = [ dependencies = [
"aiohttp", "aiohttp",

View File

@@ -21,6 +21,7 @@ export class SidebarManager {
this.isInitialized = false; this.isInitialized = false;
this.displayMode = 'tree'; // 'tree' or 'list' this.displayMode = 'tree'; // 'tree' or 'list'
this.foldersList = []; this.foldersList = [];
this.recursiveSearchEnabled = true;
// Bind methods // Bind methods
this.handleTreeClick = this.handleTreeClick.bind(this); this.handleTreeClick = this.handleTreeClick.bind(this);
@@ -36,6 +37,7 @@ export class SidebarManager {
this.updateContainerMargin = this.updateContainerMargin.bind(this); this.updateContainerMargin = this.updateContainerMargin.bind(this);
this.handleDisplayModeToggle = this.handleDisplayModeToggle.bind(this); this.handleDisplayModeToggle = this.handleDisplayModeToggle.bind(this);
this.handleFolderListClick = this.handleFolderListClick.bind(this); this.handleFolderListClick = this.handleFolderListClick.bind(this);
this.handleRecursiveToggle = this.handleRecursiveToggle.bind(this);
} }
async initialize(pageControls) { async initialize(pageControls) {
@@ -89,6 +91,7 @@ export class SidebarManager {
this.isHovering = false; this.isHovering = false;
this.apiClient = null; this.apiClient = null;
this.isInitialized = false; this.isInitialized = false;
this.recursiveSearchEnabled = true;
// Reset container margin // Reset container margin
const container = document.querySelector('.container'); const container = document.querySelector('.container');
@@ -111,6 +114,7 @@ export class SidebarManager {
const sidebar = document.getElementById('folderSidebar'); const sidebar = document.getElementById('folderSidebar');
const hoverArea = document.getElementById('sidebarHoverArea'); const hoverArea = document.getElementById('sidebarHoverArea');
const displayModeToggleBtn = document.getElementById('sidebarDisplayModeToggle'); const displayModeToggleBtn = document.getElementById('sidebarDisplayModeToggle');
const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
if (pinToggleBtn) { if (pinToggleBtn) {
pinToggleBtn.removeEventListener('click', this.handlePinToggle); pinToggleBtn.removeEventListener('click', this.handlePinToggle);
@@ -145,6 +149,9 @@ export class SidebarManager {
if (displayModeToggleBtn) { if (displayModeToggleBtn) {
displayModeToggleBtn.removeEventListener('click', this.handleDisplayModeToggle); displayModeToggleBtn.removeEventListener('click', this.handleDisplayModeToggle);
} }
if (recursiveToggleBtn) {
recursiveToggleBtn.removeEventListener('click', this.handleRecursiveToggle);
}
} }
async init() { async init() {
@@ -197,7 +204,7 @@ export class SidebarManager {
updateSidebarTitle() { updateSidebarTitle() {
const sidebarTitle = document.getElementById('sidebarTitle'); const sidebarTitle = document.getElementById('sidebarTitle');
if (sidebarTitle) { if (sidebarTitle) {
sidebarTitle.textContent = `${this.apiClient.apiConfig.config.displayName} Root`; sidebarTitle.textContent = translate('sidebar.modelRoot');
} }
} }
@@ -220,6 +227,12 @@ export class SidebarManager {
collapseAllBtn.addEventListener('click', this.handleCollapseAll); collapseAllBtn.addEventListener('click', this.handleCollapseAll);
} }
// Recursive toggle button
const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
if (recursiveToggleBtn) {
recursiveToggleBtn.addEventListener('click', this.handleRecursiveToggle);
}
// Tree click handler // Tree click handler
const folderTree = document.getElementById('sidebarFolderTree'); const folderTree = document.getElementById('sidebarFolderTree');
if (folderTree) { if (folderTree) {
@@ -645,11 +658,33 @@ export class SidebarManager {
this.displayMode = this.displayMode === 'tree' ? 'list' : 'tree'; this.displayMode = this.displayMode === 'tree' ? 'list' : 'tree';
this.updateDisplayModeButton(); this.updateDisplayModeButton();
this.updateCollapseAllButton(); this.updateCollapseAllButton();
this.updateRecursiveToggleButton();
this.updateSearchRecursiveOption(); this.updateSearchRecursiveOption();
this.saveDisplayMode(); this.saveDisplayMode();
this.loadFolderTree(); // Reload with new display mode this.loadFolderTree(); // Reload with new display mode
} }
async handleRecursiveToggle(event) {
event.stopPropagation();
if (this.displayMode !== 'tree') {
return;
}
this.recursiveSearchEnabled = !this.recursiveSearchEnabled;
setStorageItem(`${this.pageType}_recursiveSearch`, this.recursiveSearchEnabled);
this.updateSearchRecursiveOption();
this.updateRecursiveToggleButton();
if (this.pageControls && typeof this.pageControls.resetAndReload === 'function') {
try {
await this.pageControls.resetAndReload(true);
} catch (error) {
console.error('Failed to reload models after toggling recursive search:', error);
}
}
}
updateDisplayModeButton() { updateDisplayModeButton() {
const displayModeBtn = document.getElementById('sidebarDisplayModeToggle'); const displayModeBtn = document.getElementById('sidebarDisplayModeToggle');
if (displayModeBtn) { if (displayModeBtn) {
@@ -679,8 +714,35 @@ export class SidebarManager {
} }
} }
updateRecursiveToggleButton() {
const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
if (!recursiveToggleBtn) return;
const icon = recursiveToggleBtn.querySelector('i');
const isTreeMode = this.displayMode === 'tree';
const isActive = isTreeMode && this.recursiveSearchEnabled;
recursiveToggleBtn.classList.toggle('active', isActive);
recursiveToggleBtn.classList.toggle('disabled', !isTreeMode);
recursiveToggleBtn.setAttribute('aria-pressed', isActive ? 'true' : 'false');
recursiveToggleBtn.setAttribute('aria-disabled', isTreeMode ? 'false' : 'true');
if (icon) {
icon.className = 'fas fa-code-branch';
}
if (!isTreeMode) {
recursiveToggleBtn.title = translate('sidebar.recursiveUnavailable');
} else if (this.recursiveSearchEnabled) {
recursiveToggleBtn.title = translate('sidebar.recursiveOn');
} else {
recursiveToggleBtn.title = translate('sidebar.recursiveOff');
}
}
updateSearchRecursiveOption() { updateSearchRecursiveOption() {
this.pageControls.pageState.searchOptions.recursive = this.displayMode === 'tree'; const isRecursive = this.displayMode === 'tree' && this.recursiveSearchEnabled;
this.pageControls.pageState.searchOptions.recursive = isRecursive;
} }
updateTreeSelection() { updateTreeSelection() {
@@ -925,15 +987,18 @@ export class SidebarManager {
const isPinned = getStorageItem(`${this.pageType}_sidebarPinned`, true); const isPinned = getStorageItem(`${this.pageType}_sidebarPinned`, true);
const expandedPaths = getStorageItem(`${this.pageType}_expandedNodes`, []); const expandedPaths = getStorageItem(`${this.pageType}_expandedNodes`, []);
const displayMode = getStorageItem(`${this.pageType}_displayMode`, 'tree'); // 'tree' or 'list', default to 'tree' const displayMode = getStorageItem(`${this.pageType}_displayMode`, 'tree'); // 'tree' or 'list', default to 'tree'
const recursiveSearchEnabled = getStorageItem(`${this.pageType}_recursiveSearch`, true);
this.isPinned = isPinned; this.isPinned = isPinned;
this.expandedNodes = new Set(expandedPaths); this.expandedNodes = new Set(expandedPaths);
this.displayMode = displayMode; this.displayMode = displayMode;
this.recursiveSearchEnabled = recursiveSearchEnabled;
this.updatePinButton(); this.updatePinButton();
this.updateDisplayModeButton(); this.updateDisplayModeButton();
this.updateCollapseAllButton(); this.updateCollapseAllButton();
this.updateSearchRecursiveOption(); this.updateSearchRecursiveOption();
this.updateRecursiveToggleButton();
} }
restoreSelectedFolder() { restoreSelectedFolder() {

View File

@@ -67,7 +67,7 @@ export const state = {
modelname: true, modelname: true,
tags: false, tags: false,
creator: false, creator: false,
recursive: true, recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true),
}, },
filters: { filters: {
baseModel: [], baseModel: [],
@@ -116,7 +116,7 @@ export const state = {
filename: true, filename: true,
modelname: true, modelname: true,
creator: false, creator: false,
recursive: true, recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true),
}, },
filters: { filters: {
baseModel: [], baseModel: [],
@@ -144,7 +144,7 @@ export const state = {
modelname: true, modelname: true,
tags: false, tags: false,
creator: false, creator: false,
recursive: true, recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true),
}, },
filters: { filters: {
baseModel: [], baseModel: [],

View File

@@ -435,8 +435,9 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
return true; return true;
} else { } else {
// Single node - send directly // Single node - send directly
const nodeId = Object.keys(registryData.data.nodes)[0]; const nodes = registryData.data.nodes;
return await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType); const nodeId = Object.keys(nodes)[0];
return await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType);
} }
} catch (error) { } catch (error) {
console.error('Failed to get registry:', error); console.error('Failed to get registry:', error);
@@ -452,19 +453,65 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
* @param {boolean} replaceMode - Whether to replace existing LoRAs * @param {boolean} replaceMode - Whether to replace existing LoRAs
* @param {string} syntaxType - The type of syntax ('lora' or 'recipe') * @param {string} syntaxType - The type of syntax ('lora' or 'recipe')
*/ */
async function sendToSpecificNode(nodeIds, loraSyntax, replaceMode, syntaxType) { function resolveNodeReference(nodeKey, nodesMap) {
if (!nodeKey) {
return null;
}
const directMatch = nodesMap?.[nodeKey];
if (directMatch) {
return {
node_id: directMatch.id,
graph_id: directMatch.graph_id ?? null,
};
}
if (typeof nodeKey === 'string' && nodeKey.includes(':')) {
const [graphId, ...rest] = nodeKey.split(':');
const nodeIdPart = rest.join(':');
const numericNodeId = Number(nodeIdPart);
return {
node_id: Number.isNaN(numericNodeId) ? nodeIdPart : numericNodeId,
graph_id: graphId || null,
};
}
const numericId = Number(nodeKey);
return {
node_id: Number.isNaN(numericId) ? nodeKey : numericId,
graph_id: null,
};
}
async function sendToSpecificNode(nodeIds, nodesMap, loraSyntax, replaceMode, syntaxType) {
try { try {
// Call the backend API to update the lora code // Call the backend API to update the lora code
const requestBody = {
lora_code: loraSyntax,
mode: replaceMode ? 'replace' : 'append'
};
if (Array.isArray(nodeIds)) {
const references = nodeIds
.map((nodeKey) => resolveNodeReference(nodeKey, nodesMap))
.filter((reference) => reference && reference.node_id !== undefined);
if (references.length > 0) {
requestBody.node_ids = references;
}
} else if (nodeIds) {
const reference = resolveNodeReference(nodeIds, nodesMap);
if (reference) {
requestBody.node_ids = [reference];
}
}
const response = await fetch('/api/lm/update-lora-code', { const response = await fetch('/api/lm/update-lora-code', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
}, },
body: JSON.stringify({ body: JSON.stringify(requestBody)
node_ids: nodeIds,
lora_code: loraSyntax,
mode: replaceMode ? 'replace' : 'append'
})
}); });
const result = await response.json(); const result = await response.json();
@@ -522,16 +569,17 @@ function showNodeSelector(nodes, loraSyntax, replaceMode, syntaxType) {
hideNodeSelector(); hideNodeSelector();
// Generate node list HTML with icons and proper colors // Generate node list HTML with icons and proper colors
const nodeItems = Object.values(nodes).map(node => { const nodeItems = Object.entries(nodes).map(([nodeKey, node]) => {
const iconClass = NODE_TYPE_ICONS[node.type] || 'fas fa-question-circle'; const iconClass = NODE_TYPE_ICONS[node.type] || 'fas fa-question-circle';
const bgColor = node.bgcolor || DEFAULT_NODE_COLOR; const bgColor = node.bgcolor || DEFAULT_NODE_COLOR;
const graphLabel = node.graph_name ? ` (${node.graph_name})` : '';
return ` return `
<div class="node-item" data-node-id="${node.id}"> <div class="node-item" data-node-id="${nodeKey}">
<div class="node-icon-indicator" style="background-color: ${bgColor}"> <div class="node-icon-indicator" style="background-color: ${bgColor}">
<i class="${iconClass}"></i> <i class="${iconClass}"></i>
</div> </div>
<span>#${node.id} ${node.title}</span> <span>#${node.id}${graphLabel} ${node.title}</span>
</div> </div>
`; `;
}).join(''); }).join('');
@@ -610,10 +658,10 @@ function setupNodeSelectorEvents(selector, nodes, loraSyntax, replaceMode, synta
if (action === 'send-all') { if (action === 'send-all') {
// Send to all nodes // Send to all nodes
const allNodeIds = Object.keys(nodes); const allNodeIds = Object.keys(nodes);
await sendToSpecificNode(allNodeIds, loraSyntax, replaceMode, syntaxType); await sendToSpecificNode(allNodeIds, nodes, loraSyntax, replaceMode, syntaxType);
} else if (nodeId) { } else if (nodeId) {
// Send to specific node // Send to specific node
await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType); await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType);
} }
hideNodeSelector(); hideNodeSelector();

View File

@@ -9,6 +9,9 @@
<button class="sidebar-action-btn" id="sidebarDisplayModeToggle" title="{{ t('sidebar.switchToListView') }}"> <button class="sidebar-action-btn" id="sidebarDisplayModeToggle" title="{{ t('sidebar.switchToListView') }}">
<i class="fas fa-sitemap"></i> <i class="fas fa-sitemap"></i>
</button> </button>
<button class="sidebar-action-btn active" id="sidebarRecursiveToggle" title="{{ t('sidebar.recursiveOn') }}" aria-pressed="true">
<i class="fas fa-code-branch"></i>
</button>
<button class="sidebar-action-btn" id="sidebarCollapseAll" title="{{ t('sidebar.collapseAll') }}"> <button class="sidebar-action-btn" id="sidebarCollapseAll" title="{{ t('sidebar.collapseAll') }}">
<i class="fas fa-compress-alt"></i> <i class="fas fa-compress-alt"></i>
</button> </button>

View File

@@ -73,6 +73,30 @@ nodes_mock.NODE_CLASS_MAPPINGS = {}
sys.modules['nodes'] = nodes_mock sys.modules['nodes'] = nodes_mock
@pytest.fixture(autouse=True)
def _isolate_settings_dir(tmp_path_factory, monkeypatch):
"""Redirect settings.json into a temporary directory for each test."""
settings_dir = tmp_path_factory.mktemp("settings_dir")
def fake_get_settings_dir(create: bool = True) -> str:
if create:
settings_dir.mkdir(exist_ok=True)
return str(settings_dir)
monkeypatch.setattr("py.utils.settings_paths.get_settings_dir", fake_get_settings_dir)
monkeypatch.setattr(
"py.utils.settings_paths.user_config_dir",
lambda *_args, **_kwargs: str(settings_dir),
)
from py.services import settings_manager as settings_manager_module
settings_manager_module.reset_settings_manager()
yield
settings_manager_module.reset_settings_manager()
def pytest_pyfunc_call(pyfuncitem): def pytest_pyfunc_call(pyfuncitem):
"""Allow bare async tests to run without pytest.mark.asyncio.""" """Allow bare async tests to run without pytest.mark.asyncio."""
test_function = pyfuncitem.function test_function = pyfuncitem.function

View File

@@ -46,6 +46,7 @@ vi.mock(EVENT_MANAGER_MODULE, () => ({
off: vi.fn(), off: vi.fn(),
addHandler: vi.fn(), addHandler: vi.fn(),
removeHandler: vi.fn(), removeHandler: vi.fn(),
setState: vi.fn(),
}, },
})); }));
@@ -62,6 +63,7 @@ describe('UI helper DOM utilities', () => {
afterEach(() => { afterEach(() => {
vi.useRealTimers(); vi.useRealTimers();
delete global.fetch;
}); });
it('creates toast elements and cleans them up after timeout', async () => { it('creates toast elements and cleans them up after timeout', async () => {
@@ -105,4 +107,53 @@ describe('UI helper DOM utilities', () => {
expect(document.body.dataset.theme).toBe('dark'); expect(document.body.dataset.theme).toBe('dark');
expect(document.querySelector('.theme-toggle').classList.contains('theme-dark')).toBe(true); expect(document.querySelector('.theme-toggle').classList.contains('theme-dark')).toBe(true);
}); });
it('renders subgraph names in the node selector list', async () => {
const registryResponse = {
success: true,
data: {
node_count: 2,
nodes: {
'root:1': {
id: 1,
graph_id: 'root',
graph_name: null,
title: 'Root Loader',
type: 1,
bgcolor: '#123456',
},
'subgraph-uuid:2': {
id: 2,
graph_id: 'subgraph-uuid',
graph_name: 'Character Subgraph',
title: 'Nested Loader',
type: 1,
bgcolor: '#654321',
},
},
},
};
global.fetch = vi.fn().mockResolvedValue({
json: async () => registryResponse,
});
document.body.innerHTML = '<div id="nodeSelector"></div>';
const { sendLoraToWorkflow } = await import(UI_HELPERS_MODULE);
const result = await sendLoraToWorkflow('<lora:test:1>');
expect(result).toBe(true);
expect(global.fetch).toHaveBeenCalledWith('/api/lm/get-registry');
const nodeLabels = Array.from(
document.querySelectorAll('#nodeSelector .node-item[data-node-id] span')
).map((span) => span.textContent.trim());
expect(nodeLabels).toEqual([
'#1 Root Loader',
'#2 (Character Subgraph) Nested Loader',
]);
});
}); });

View File

@@ -16,10 +16,12 @@ from aiohttp.test_utils import TestClient, TestServer
from py.config import config from py.config import config
from py.routes.base_model_routes import BaseModelRoutes from py.routes.base_model_routes import BaseModelRoutes
from py.services import model_file_service from py.services import model_file_service
from py.services.metadata_sync_service import MetadataSyncService
from py.services.model_file_service import AutoOrganizeResult from py.services.model_file_service import AutoOrganizeResult
from py.services.service_registry import ServiceRegistry from py.services.service_registry import ServiceRegistry
from py.services.websocket_manager import ws_manager from py.services.websocket_manager import ws_manager
from py.utils.exif_utils import ExifUtils from py.utils.exif_utils import ExifUtils
from py.utils.metadata_manager import MetadataManager
class DummyRoutes(BaseModelRoutes): class DummyRoutes(BaseModelRoutes):
@@ -197,6 +199,116 @@ def test_replace_preview_writes_file_and_updates_cache(
asyncio.run(scenario()) asyncio.run(scenario())
def test_fetch_civitai_hydrates_metadata_before_sync(
mock_service,
mock_scanner,
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
):
model_path = tmp_path / "hydrate.safetensors"
model_path.write_bytes(b"model")
metadata_path = tmp_path / "hydrate.metadata.json"
existing_metadata = {
"file_path": str(model_path),
"sha256": "abc123",
"model_name": "Hydrated",
"preview_url": "keep/me.png",
"civitai": {
"id": 99,
"modelId": 42,
"images": [{"url": "https://example.com/existing.png", "type": "image"}],
"customImages": [{"id": "old-id", "url": "", "type": "image"}],
"trainedWords": ["keep"],
},
"custom_field": "preserve",
}
metadata_path.write_text(json.dumps(existing_metadata), encoding="utf-8")
minimal_cache_entry = {
"file_path": str(model_path),
"sha256": "abc123",
"folder": "some/folder",
"civitai": {"id": 99, "modelId": 42},
}
mock_scanner._cache.raw_data = [minimal_cache_entry]
class FakeMetadata:
def __init__(self, payload: dict) -> None:
self._payload = payload
self._unknown_fields = {"legacy_field": "legacy"}
def to_dict(self) -> dict:
return self._payload.copy()
async def fake_load_metadata(path: str, *_args, **_kwargs):
assert path == str(model_path)
return FakeMetadata(existing_metadata), False
async def fake_save_metadata(path: str, metadata: dict) -> bool:
save_calls.append((path, json.loads(json.dumps(metadata))))
return True
async def fake_fetch_and_update_model(
self,
*,
sha256: str,
file_path: str,
model_data: dict,
update_cache_func,
):
captured["model_data"] = json.loads(json.dumps(model_data))
to_save = model_data.copy()
to_save.pop("folder", None)
await MetadataManager.save_metadata(
os.path.splitext(file_path)[0] + ".metadata.json",
to_save,
)
await update_cache_func(file_path, file_path, model_data)
return True, None
save_calls: list[tuple[str, dict]] = []
captured: dict[str, dict] = {}
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load_metadata))
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save_metadata))
monkeypatch.setattr(MetadataSyncService, "fetch_and_update_model", fake_fetch_and_update_model)
async def scenario():
client = await create_test_client(mock_service)
try:
response = await client.post(
"/api/lm/test-models/fetch-civitai",
json={"file_path": str(model_path)},
)
payload = await response.json()
assert response.status == 200
assert payload["success"] is True
assert captured["model_data"]["custom_field"] == "preserve"
assert captured["model_data"]["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
assert captured["model_data"]["civitai"]["trainedWords"] == ["keep"]
assert captured["model_data"]["civitai"]["id"] == 99
finally:
await client.close()
asyncio.run(scenario())
assert save_calls, "Metadata save should be invoked"
saved_path, saved_payload = save_calls[0]
assert saved_path == str(metadata_path)
assert saved_payload["custom_field"] == "preserve"
assert saved_payload["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
assert saved_payload["civitai"]["trainedWords"] == ["keep"]
assert saved_payload["civitai"]["id"] == 99
assert saved_payload["legacy_field"] == "legacy"
assert mock_scanner.updated_models
updated_metadata = mock_scanner.updated_models[-1]["metadata"]
assert updated_metadata["custom_field"] == "preserve"
assert updated_metadata["civitai"]["customImages"][0]["id"] == "old-id"
def test_download_model_invokes_download_manager( def test_download_model_invokes_download_manager(
mock_service, mock_service,
download_manager_stub, download_manager_stub,

View File

@@ -188,7 +188,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"])) monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"]))
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": ["node"]}) request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": [{"node_id": "node", "graph_id": "graph-1"}]})
response = await routes.get_trigger_words(request) response = await routes.get_trigger_words(request)
payload = json.loads(response.text) payload = json.loads(response.text)
@@ -196,7 +196,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
assert payload == {"success": True} assert payload == {"success": True}
send_mock.assert_called_once_with( send_mock.assert_called_once_with(
"trigger_word_update", "trigger_word_update",
{"id": "node", "message": "trigger-one"}, {"id": "node", "graph_id": "graph-1", "message": "trigger-one"},
) )

View File

@@ -4,7 +4,14 @@ from types import SimpleNamespace
import pytest import pytest
from aiohttp import web from aiohttp import web
from py.routes.handlers.misc_handlers import SettingsHandler, ServiceRegistryAdapter from py.routes.handlers.misc_handlers import (
LoraCodeHandler,
ModelLibraryHandler,
NodeRegistry,
NodeRegistryHandler,
ServiceRegistryAdapter,
SettingsHandler,
)
from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar
from py.routes.misc_routes import MiscRoutes from py.routes.misc_routes import MiscRoutes
@@ -126,6 +133,128 @@ class FakePromptServer:
instance = Instance() instance = Instance()
@pytest.mark.asyncio
async def test_register_nodes_requires_graph_id():
node_registry = NodeRegistry()
handler = NodeRegistryHandler(
node_registry=node_registry,
prompt_server=FakePromptServer,
standalone_mode=False,
)
request = FakeRequest(json_data={"nodes": [{"node_id": 1}]})
response = await handler.register_nodes(request)
payload = json.loads(response.text)
assert response.status == 400
assert payload["success"] is False
assert "graph_id" in payload["error"]
@pytest.mark.asyncio
async def test_register_nodes_stores_graph_identifier():
node_registry = NodeRegistry()
handler = NodeRegistryHandler(
node_registry=node_registry,
prompt_server=FakePromptServer,
standalone_mode=False,
)
request = FakeRequest(
json_data={
"nodes": [
{
"node_id": 7,
"graph_id": "graph-123",
"graph_name": "Character Subgraph",
"type": "Lora Loader (LoraManager)",
"title": "Loader",
}
]
}
)
response = await handler.register_nodes(request)
payload = json.loads(response.text)
assert payload["success"] is True
registry = await node_registry.get_registry()
assert registry["node_count"] == 1
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["graph_id"] == "graph-123"
assert stored_node["unique_id"] == "graph-123:7"
assert stored_node["graph_name"] == "Character Subgraph"
@pytest.mark.asyncio
async def test_register_nodes_defaults_graph_name_to_none():
node_registry = NodeRegistry()
handler = NodeRegistryHandler(
node_registry=node_registry,
prompt_server=FakePromptServer,
standalone_mode=False,
)
request = FakeRequest(
json_data={
"nodes": [
{
"node_id": 8,
"graph_id": "root",
"type": "Lora Loader (LoraManager)",
"title": "Root Loader",
}
]
}
)
response = await handler.register_nodes(request)
payload = json.loads(response.text)
assert payload["success"] is True
registry = await node_registry.get_registry()
stored_node = next(iter(registry["nodes"].values()))
assert stored_node["graph_name"] is None
@pytest.mark.asyncio
async def test_update_lora_code_includes_graph_identifier():
send_calls: list[tuple[str, dict]] = []
class RecordingPromptServer:
class Instance:
def send_sync(self, event, payload):
send_calls.append((event, payload))
instance = Instance()
handler = LoraCodeHandler(RecordingPromptServer)
request = FakeRequest(
json_data={
"node_ids": [{"node_id": 3, "graph_id": "graph-A"}],
"lora_code": "<lora>",
"mode": "replace",
}
)
response = await handler.update_lora_code(request)
payload = json.loads(response.text)
assert payload["success"] is True
assert payload["results"] == [
{"node_id": 3, "graph_id": "graph-A", "success": True}
]
assert send_calls == [
(
"lora_code_update",
{"id": 3, "graph_id": "graph-A", "lora_code": "<lora>", "mode": "replace"},
)
]
class FakeScanner: class FakeScanner:
async def check_model_version_exists(self, _version_id): async def check_model_version_exists(self, _version_id):
return False return False
@@ -138,10 +267,34 @@ async def fake_scanner_factory():
return FakeScanner() return FakeScanner()
class FakeExistenceScanner:
def __init__(self, existing=None):
self._existing = set(existing or [])
async def check_model_version_exists(self, version_id):
return version_id in self._existing
async def get_model_versions_by_id(self, _model_id):
return []
class FakeMetadataProvider: class FakeMetadataProvider:
async def get_model_versions(self, _model_id): async def get_model_versions(self, _model_id):
return {"modelVersions": [], "name": "", "type": "lora"} return {"modelVersions": [], "name": "", "type": "lora"}
async def get_user_models(self, _username):
return []
class FakeUserModelsProvider(FakeMetadataProvider):
def __init__(self, models):
self.models = models
self.received_usernames: list[str] = []
async def get_user_models(self, username):
self.received_usernames.append(username)
return self.models
async def fake_metadata_provider_factory(): async def fake_metadata_provider_factory():
return FakeMetadataProvider() return FakeMetadataProvider()
@@ -211,6 +364,250 @@ async def test_misc_routes_bind_produces_expected_handlers():
assert set(mapping.keys()) == expected_names assert set(mapping.keys()) == expected_names
@pytest.mark.asyncio
async def test_get_civitai_user_models_marks_library_versions():
models = [
{
"id": 1,
"name": "Model A",
"type": "LORA",
"tags": ["style"],
"modelVersions": [
{
"id": 100,
"name": "v1",
"baseModel": "Flux.1",
"images": [{"url": "http://example.com/a1.jpg"}],
},
{
"id": 101,
"name": "v2",
"baseModel": "Flux.1",
"images": [{"url": "http://example.com/a2.jpg"}],
},
],
},
{
"id": 2,
"name": "Embedding",
"type": "TextualInversion",
"tags": ["embedding"],
"modelVersions": [
{
"id": 200,
"name": "v1",
"baseModel": None,
"images": [{"url": "http://example.com/e1.jpg"}],
},
{
"id": 202,
"name": "v2",
"baseModel": None,
},
],
},
{
"id": 3,
"name": "Checkpoint",
"type": "Checkpoint",
"tags": ["checkpoint"],
"modelVersions": [
{
"id": 300,
"name": "v1",
"baseModel": "SDXL",
"images": [],
}
],
},
{
"id": 4,
"name": "Unsupported",
"type": "Other",
"modelVersions": [
{
"id": 400,
"name": "v1",
}
],
},
]
provider = FakeUserModelsProvider(models)
async def provider_factory():
return provider
lora_scanner = FakeExistenceScanner({101})
checkpoint_scanner = FakeExistenceScanner()
embedding_scanner = FakeExistenceScanner({202})
async def lora_factory():
return lora_scanner
async def checkpoint_factory():
return checkpoint_scanner
async def embedding_factory():
return embedding_scanner
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=lora_factory,
get_checkpoint_scanner=checkpoint_factory,
get_embedding_scanner=embedding_factory,
),
metadata_provider_factory=provider_factory,
)
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
payload = json.loads(response.text)
assert payload["success"] is True
assert payload["username"] == "pixel"
assert payload["versions"] == [
{
"modelId": 1,
"versionId": 100,
"modelName": "Model A",
"versionName": "v1",
"type": "LORA",
"tags": ["style"],
"baseModel": "Flux.1",
"thumbnailUrl": "http://example.com/a1.jpg",
"inLibrary": False,
},
{
"modelId": 1,
"versionId": 101,
"modelName": "Model A",
"versionName": "v2",
"type": "LORA",
"tags": ["style"],
"baseModel": "Flux.1",
"thumbnailUrl": "http://example.com/a2.jpg",
"inLibrary": True,
},
{
"modelId": 2,
"versionId": 200,
"modelName": "Embedding",
"versionName": "v1",
"type": "TextualInversion",
"tags": ["embedding"],
"baseModel": None,
"thumbnailUrl": "http://example.com/e1.jpg",
"inLibrary": False,
},
{
"modelId": 2,
"versionId": 202,
"modelName": "Embedding",
"versionName": "v2",
"type": "TextualInversion",
"tags": ["embedding"],
"baseModel": None,
"thumbnailUrl": None,
"inLibrary": True,
},
{
"modelId": 3,
"versionId": 300,
"modelName": "Checkpoint",
"versionName": "v1",
"type": "Checkpoint",
"tags": ["checkpoint"],
"baseModel": "SDXL",
"thumbnailUrl": None,
"inLibrary": False,
},
]
assert provider.received_usernames == ["pixel"]
@pytest.mark.asyncio
async def test_get_civitai_user_models_rewrites_civitai_previews():
image_url = "https://image.civitai.com/container/example/original=true/sample.jpeg"
video_url = "https://image.civitai.com/container/example/original=true/sample.mp4"
models = [
{
"id": 1,
"name": "Model A",
"type": "LORA",
"tags": ["style"],
"modelVersions": [
{
"id": 100,
"name": "preview-image",
"baseModel": "Flux.1",
"images": [
{"url": image_url, "type": "image"},
],
},
{
"id": 101,
"name": "preview-video",
"baseModel": "Flux.1",
"images": [
{"url": video_url, "type": "video"},
],
},
],
},
]
provider = FakeUserModelsProvider(models)
async def provider_factory():
return provider
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
),
metadata_provider_factory=provider_factory,
)
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
payload = json.loads(response.text)
assert payload["success"] is True
previews_by_version = {item["versionId"]: item["thumbnailUrl"] for item in payload["versions"]}
assert previews_by_version[100] == "https://image.civitai.com/container/example/width=450,optimized=true/sample.jpeg"
assert (
previews_by_version[101]
== "https://image.civitai.com/container/example/transcode=true,width=450,optimized=true/sample.mp4"
)
@pytest.mark.asyncio
async def test_get_civitai_user_models_requires_username():
provider = FakeUserModelsProvider([])
async def provider_factory():
return provider
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
),
metadata_provider_factory=provider_factory,
)
response = await handler.get_civitai_user_models(FakeRequest())
payload = json.loads(response.text)
assert response.status == 400
assert payload["success"] is False
assert "username" in payload["error"].lower()
def test_ensure_handler_mapping_caches_result(): def test_ensure_handler_mapping_caches_result():
call_records = [] call_records = []

View File

@@ -1,3 +1,4 @@
import copy
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import pytest import pytest
@@ -169,6 +170,158 @@ async def test_get_model_version_by_version_id(monkeypatch, downloader):
assert result["images"][0]["meta"]["other"] == "keep" assert result["images"][0]["meta"]["other"] == "keep"
async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypatch, downloader):
requests = []
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [
{
"type": "Model",
"primary": True,
"hashes": {"SHA256": "hash7"},
}
],
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
version_payload = {
"id": 7,
"modelId": 99,
"model": {},
"files": [],
"images": [],
}
async def fake_make_request(method, url, use_auth=True):
requests.append(url)
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return True, copy.deepcopy(version_payload)
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["id"] == 7
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
assert requests[0].endswith("/models/99")
assert requests[1].endswith("/model-versions/7")
async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, downloader):
requests = []
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [
{
"type": "Model",
"primary": True,
"hashes": {"SHA256": "hash7"},
}
],
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
version_payload = {
"id": 7,
"modelId": 99,
"files": [],
"images": [],
}
async def fake_make_request(method, url, use_auth=True):
requests.append(url)
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return False, "boom"
if url.endswith("/model-versions/by-hash/hash7"):
return True, copy.deepcopy(version_payload)
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["id"] == 7
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
assert requests[1].endswith("/model-versions/7")
assert requests[2].endswith("/model-versions/by-hash/hash7")
async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatch, downloader):
model_payload = {
"modelVersions": [
{
"id": 7,
"files": [],
"name": "v1",
}
],
"description": "desc",
"tags": ["tag"],
"creator": {"username": "user"},
"name": "Model",
"type": "LORA",
"nsfw": False,
"poi": False,
}
async def fake_make_request(method, url, use_auth=True):
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
return False, "boom"
if "/model-versions/by-hash/" in url:
return False, "boom"
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_version(model_id=99, version_id=7)
assert result["modelId"] == 99
assert result["model"]["name"] == "Model"
assert result["model"]["type"] == "LORA"
assert result["model"]["description"] == "desc"
assert result["model"]["tags"] == ["tag"]
assert result["creator"] == {"username": "user"}
async def test_get_model_version_requires_identifier(monkeypatch, downloader): async def test_get_model_version_requires_identifier(monkeypatch, downloader):
client = await CivitaiClient.get_instance() client = await CivitaiClient.get_instance()
result = await client.get_model_version() result = await client.get_model_version()

View File

@@ -8,7 +8,7 @@ import pytest
from py.services.download_manager import DownloadManager from py.services.download_manager import DownloadManager
from py.services import download_manager from py.services import download_manager
from py.services.service_registry import ServiceRegistry from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import settings from py.services.settings_manager import SettingsManager, get_settings_manager
from py.utils.metadata_manager import MetadataManager from py.utils.metadata_manager import MetadataManager
@@ -23,7 +23,8 @@ def reset_download_manager():
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def isolate_settings(monkeypatch, tmp_path): def isolate_settings(monkeypatch, tmp_path):
"""Point settings writes at a temporary directory to avoid touching real files.""" """Point settings writes at a temporary directory to avoid touching real files."""
default_settings = settings._get_default_settings() manager = get_settings_manager()
default_settings = manager._get_default_settings()
default_settings.update( default_settings.update(
{ {
"default_lora_root": str(tmp_path), "default_lora_root": str(tmp_path),
@@ -37,8 +38,8 @@ def isolate_settings(monkeypatch, tmp_path):
"base_model_path_mappings": {"BaseModel": "MappedModel"}, "base_model_path_mappings": {"BaseModel": "MappedModel"},
} }
) )
monkeypatch.setattr(settings, "settings", default_settings) monkeypatch.setattr(manager, "settings", default_settings)
monkeypatch.setattr(type(settings), "_save_settings", lambda self: None) monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@@ -187,7 +188,7 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata
assert manager._active_downloads[result["download_id"]]["status"] == "completed" assert manager._active_downloads[result["download_id"]]["status"] == "completed"
assert captured["relative_path"] == "MappedModel/fantasy" assert captured["relative_path"] == "MappedModel/fantasy"
expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy" expected_dir = Path(get_settings_manager().get("default_lora_root")) / "MappedModel" / "fantasy"
assert captured["save_dir"] == expected_dir assert captured["save_dir"] == expected_dir
assert captured["model_type"] == "lora" assert captured["model_type"] == "lora"
assert captured["download_urls"] == [ assert captured["download_urls"] == [
@@ -393,3 +394,98 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
assert result == {"success": True} assert result == {"success": True}
assert [url for url, *_ in dummy_downloader.calls] == download_urls assert [url for url, *_ in dummy_downloader.calls] == download_urls
assert dummy_scanner.calls # ensure cache updated assert dummy_scanner.calls # ensure cache updated
@pytest.mark.asyncio
async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
manager._active_downloads["dl"] = {}
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
self.preview_nsfw_level = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
metadata = DummyMetadata(target_path)
version_info = {
"images": [
{
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
"type": "image",
"nsfwLevel": 2,
}
]
}
download_urls = ["https://example.invalid/file.safetensors"]
class DummyDownloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
self.memory_calls = 0
async def download_file(self, url, path, progress_callback=None, use_auth=None):
self.file_calls.append((url, path))
if url.endswith(".jpeg"):
Path(path).write_bytes(b"preview")
return True, None
if url.endswith(".safetensors"):
Path(path).write_bytes(b"model")
return True, None
return False, "unexpected url"
async def download_to_memory(self, *_args, **_kwargs):
self.memory_calls += 1
return False, b"", {}
dummy_downloader = DummyDownloader()
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader))
optimize_called = {"value": False}
def fake_optimize_image(**_kwargs):
optimize_called["value"] = True
return b"", {}
monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(fake_optimize_image))
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
result = await manager._execute_download(
download_urls=download_urls,
save_dir=str(save_dir),
metadata=metadata,
version_info=version_info,
relative_path="",
progress_callback=None,
model_type="lora",
download_id="dl",
)
assert result == {"success": True}
preview_urls = [url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg")]
assert any("width=450,optimized=true" in url for url in preview_urls)
assert dummy_downloader.memory_calls == 0
assert optimize_called["value"] is False
assert metadata.preview_url.endswith(".jpeg")
assert metadata.preview_nsfw_level == 2
stored_preview = manager._active_downloads["dl"]["preview_path"]
assert stored_preview.endswith(".jpeg")
assert Path(stored_preview).exists()

View File

@@ -6,7 +6,7 @@ import pytest
from py.services.example_images_cleanup_service import ExampleImagesCleanupService from py.services.example_images_cleanup_service import ExampleImagesCleanupService
from py.services.service_registry import ServiceRegistry from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import settings from py.services.settings_manager import get_settings_manager
class StubScanner: class StubScanner:
@@ -21,8 +21,9 @@ class StubScanner:
async def test_cleanup_moves_empty_and_orphaned(tmp_path, monkeypatch): async def test_cleanup_moves_empty_and_orphaned(tmp_path, monkeypatch):
service = ExampleImagesCleanupService() service = ExampleImagesCleanupService()
previous_path = settings.get('example_images_path') settings_manager = get_settings_manager()
settings.settings['example_images_path'] = str(tmp_path) previous_path = settings_manager.get('example_images_path')
settings_manager.settings['example_images_path'] = str(tmp_path)
try: try:
empty_folder = tmp_path / 'empty_folder' empty_folder = tmp_path / 'empty_folder'
@@ -64,23 +65,24 @@ async def test_cleanup_moves_empty_and_orphaned(tmp_path, monkeypatch):
finally: finally:
if previous_path is None: if previous_path is None:
settings.settings.pop('example_images_path', None) settings_manager.settings.pop('example_images_path', None)
else: else:
settings.settings['example_images_path'] = previous_path settings_manager.settings['example_images_path'] = previous_path
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cleanup_handles_missing_path(monkeypatch): async def test_cleanup_handles_missing_path(monkeypatch):
service = ExampleImagesCleanupService() service = ExampleImagesCleanupService()
previous_path = settings.get('example_images_path') settings_manager = get_settings_manager()
settings.settings.pop('example_images_path', None) previous_path = settings_manager.get('example_images_path')
settings_manager.settings.pop('example_images_path', None)
try: try:
result = await service.cleanup_example_image_folders() result = await service.cleanup_example_image_folders()
finally: finally:
if previous_path is not None: if previous_path is not None:
settings.settings['example_images_path'] = previous_path settings_manager.settings['example_images_path'] = previous_path
assert result['success'] is False assert result['success'] is False
assert result['error_code'] == 'path_not_configured' assert result['error_code'] == 'path_not_configured'

View File

@@ -7,7 +7,7 @@ from types import SimpleNamespace
import pytest import pytest
from py.services.settings_manager import settings from py.services.settings_manager import SettingsManager, get_settings_manager
from py.utils import example_images_download_manager as download_module from py.utils import example_images_download_manager as download_module
@@ -43,11 +43,15 @@ def _patch_scanner(monkeypatch: pytest.MonkeyPatch, scanner: StubScanner) -> Non
@pytest.mark.usefixtures("tmp_path") @pytest.mark.usefixtures("tmp_path")
async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPatch, tmp_path): async def test_start_download_rejects_parallel_runs(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
ws_manager = RecordingWebSocketManager() ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager) manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
model = { model = {
"sha256": "abc123", "sha256": "abc123",
@@ -106,11 +110,15 @@ async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPa
@pytest.mark.usefixtures("tmp_path") @pytest.mark.usefixtures("tmp_path")
async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, tmp_path): async def test_pause_resume_blocks_processing(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
ws_manager = RecordingWebSocketManager() ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager) manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
models = [ models = [
{ {
@@ -231,13 +239,17 @@ async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, t
@pytest.mark.usefixtures("tmp_path") @pytest.mark.usefixtures("tmp_path")
async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatch, tmp_path): async def test_legacy_folder_migrated_and_skipped(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
ws_manager = RecordingWebSocketManager() ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager) manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}}) monkeypatch.setitem(settings_manager.settings, "libraries", {"default": {}, "extra": {}})
monkeypatch.setitem(settings.settings, "active_library", "extra") monkeypatch.setitem(settings_manager.settings, "active_library", "extra")
model_hash = "d" * 64 model_hash = "d" * 64
model_path = tmp_path / "model.safetensors" model_path = tmp_path / "model.safetensors"
@@ -310,13 +322,17 @@ async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatc
@pytest.mark.usefixtures("tmp_path") @pytest.mark.usefixtures("tmp_path")
async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tmp_path): async def test_legacy_progress_file_migrates(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
ws_manager = RecordingWebSocketManager() ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager) manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}}) monkeypatch.setitem(settings_manager.settings, "libraries", {"default": {}, "extra": {}})
monkeypatch.setitem(settings.settings, "active_library", "extra") monkeypatch.setitem(settings_manager.settings, "active_library", "extra")
model_hash = "e" * 64 model_hash = "e" * 64
model_path = tmp_path / "model-two.safetensors" model_path = tmp_path / "model-two.safetensors"
@@ -380,20 +396,24 @@ async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tm
@pytest.mark.usefixtures("tmp_path") @pytest.mark.usefixtures("tmp_path")
async def test_download_remains_in_initial_library(monkeypatch: pytest.MonkeyPatch, tmp_path): async def test_download_remains_in_initial_library(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
ws_manager = RecordingWebSocketManager() ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager) manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
monkeypatch.setitem(settings.settings, "libraries", {"LibraryA": {}, "LibraryB": {}}) monkeypatch.setitem(settings_manager.settings, "libraries", {"LibraryA": {}, "LibraryB": {}})
monkeypatch.setitem(settings.settings, "active_library", "LibraryA") monkeypatch.setitem(settings_manager.settings, "active_library", "LibraryA")
state = {"active": "LibraryA"} state = {"active": "LibraryA"}
def fake_get_active_library_name(self): def fake_get_active_library_name(self):
return state["active"] return state["active"]
monkeypatch.setattr(type(settings), "get_active_library_name", fake_get_active_library_name) monkeypatch.setattr(SettingsManager, "get_active_library_name", fake_get_active_library_name)
model_hash = "f" * 64 model_hash = "f" * 64
model_path = tmp_path / "example-model.safetensors" model_path = tmp_path / "example-model.safetensors"
@@ -454,3 +474,7 @@ async def test_download_remains_in_initial_library(monkeypatch: pytest.MonkeyPat
assert (model_dir / "local.txt").exists() assert (model_dir / "local.txt").exists()
assert not (library_b_root / ".download_progress.json").exists() assert not (library_b_root / ".download_progress.json").exists()
assert not (library_b_root / model_hash).exists() assert not (library_b_root / model_hash).exists()
@pytest.fixture
def settings_manager():
return get_settings_manager()

View File

@@ -243,6 +243,7 @@ async def test_initialize_in_background_uses_persisted_cache_without_full_scan(t
cache = await scanner.get_cached_data() cache = await scanner.get_cached_data()
assert len(cache.raw_data) == 1 assert len(cache.raw_data) == 1
assert cache.raw_data[0]['file_path'] == normalized assert cache.raw_data[0]['file_path'] == normalized
assert cache.version_index[11]['file_path'] == normalized
assert scanner._hash_index.get_path('hash-one') == normalized assert scanner._hash_index.get_path('hash-one') == normalized
@@ -301,6 +302,7 @@ async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch)
assert entry['file_path'] == normalized assert entry['file_path'] == normalized
assert entry['tags'] == ['alpha'] assert entry['tags'] == ['alpha']
assert entry['civitai']['trainedWords'] == ['abc'] assert entry['civitai']['trainedWords'] == ['abc']
assert cache.version_index[11]['file_path'] == normalized
assert scanner._hash_index.get_path('hash-one') == normalized assert scanner._hash_index.get_path('hash-one') == normalized
assert scanner._tags_count == {'alpha': 1} assert scanner._tags_count == {'alpha': 1}
assert ws_stub.payloads[-1]['stage'] == 'loading_cache' assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
@@ -381,6 +383,66 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch):
assert remaining == 0 assert remaining == 0
@pytest.mark.asyncio
async def test_version_index_tracks_version_ids(tmp_path: Path):
scanner = DummyScanner(tmp_path)
first_path = _normalize_path(tmp_path / 'alpha.txt')
second_path = _normalize_path(tmp_path / 'beta.txt')
first_entry = {
'file_path': first_path,
'file_name': 'alpha',
'model_name': 'alpha',
'folder': '',
'size': 1,
'modified': 1.0,
'sha256': 'hash-alpha',
'tags': [],
'civitai': {'id': 101, 'modelId': 1, 'name': 'alpha'},
}
second_entry = {
'file_path': second_path,
'file_name': 'beta',
'model_name': 'beta',
'folder': '',
'size': 1,
'modified': 1.0,
'sha256': 'hash-beta',
'tags': [],
'civitai': {'id': 202, 'modelId': 2, 'name': 'beta'},
}
hash_index = ModelHashIndex()
hash_index.add_entry('hash-alpha', first_path)
hash_index.add_entry('hash-beta', second_path)
scan_result = CacheBuildResult(
raw_data=[first_entry, second_entry],
hash_index=hash_index,
tags_count={},
excluded_models=[],
)
await scanner._apply_scan_result(scan_result)
cache = await scanner.get_cached_data()
assert cache.version_index[101]['file_path'] == first_path
assert cache.version_index[202]['file_path'] == second_path
assert await scanner.check_model_version_exists(101) is True
assert await scanner.check_model_version_exists('202') is True
assert await scanner.check_model_version_exists(999) is False
removed = await scanner._batch_update_cache_for_deleted_models([first_path])
assert removed is True
cache_after = await scanner.get_cached_data()
assert 101 not in cache_after.version_index
assert await scanner.check_model_version_exists(101) is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path): async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path):
first, _, _ = _create_files(tmp_path) first, _, _ = _create_files(tmp_path)

View File

@@ -0,0 +1,182 @@
from pathlib import Path
from typing import Any
import pytest
from py.services.preview_asset_service import PreviewAssetService
class StubMetadataManager:
async def save_metadata(self, *_args: Any, **_kwargs: Any) -> bool: # pragma: no cover - helper
return True
class RecordingExifUtils:
def __init__(self) -> None:
self.called = False
def optimize_image(self, **kwargs):
self.called = True
return kwargs["image_data"], {}
@pytest.mark.asyncio
async def test_ensure_preview_prefers_rewritten_civitai_image(tmp_path):
metadata_path = tmp_path / "model.metadata.json"
metadata_path.write_text("{}")
local_metadata: dict[str, Any] = {}
class Downloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
self.memory_calls = 0
async def download_file(self, url, path, use_auth=False):
self.file_calls.append((url, path))
if "width=450,optimized=true" in url:
Path(path).write_bytes(b"image-data")
return True, None
return False, "fail"
async def download_to_memory(self, *_args, **_kwargs):
self.memory_calls += 1
return False, b"", {}
downloader = Downloader()
async def downloader_factory():
return downloader
exif_utils = RecordingExifUtils()
service = PreviewAssetService(
metadata_manager=StubMetadataManager(),
downloader_factory=downloader_factory,
exif_utils=exif_utils,
)
images = [
{
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
"type": "image",
"nsfwLevel": 3,
}
]
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
assert downloader.memory_calls == 0
assert exif_utils.called is False
assert len(downloader.file_calls) == 1
assert "width=450,optimized=true" in downloader.file_calls[0][0]
preview_path = Path(local_metadata["preview_url"])
assert preview_path.exists()
assert preview_path.suffix == ".jpeg"
assert local_metadata["preview_nsfw_level"] == 3
@pytest.mark.asyncio
async def test_ensure_preview_falls_back_to_webp_when_rewrite_fails(tmp_path):
metadata_path = tmp_path / "model.metadata.json"
metadata_path.write_text("{}")
local_metadata: dict[str, Any] = {}
class Downloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
self.memory_calls = 0
async def download_file(self, url, path, use_auth=False):
self.file_calls.append((url, path))
return False, "fail"
async def download_to_memory(self, *_args, **_kwargs):
self.memory_calls += 1
return True, b"raw-image", {}
downloader = Downloader()
async def downloader_factory():
return downloader
class ExifUtils:
def __init__(self):
self.calls = 0
def optimize_image(self, **kwargs):
self.calls += 1
return b"webp-data", {}
exif_utils = ExifUtils()
service = PreviewAssetService(
metadata_manager=StubMetadataManager(),
downloader_factory=downloader_factory,
exif_utils=exif_utils,
)
images = [
{
"url": "https://image.civitai.com/container/example/original=true/sample.png",
"type": "image",
"nsfwLevel": 1,
}
]
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
assert downloader.memory_calls == 1
assert exif_utils.calls == 1
preview_path = Path(local_metadata["preview_url"])
assert preview_path.exists()
assert preview_path.suffix == ".webp"
@pytest.mark.asyncio
async def test_ensure_preview_rewrites_civitai_video(tmp_path):
metadata_path = tmp_path / "model.metadata.json"
metadata_path.write_text("{}")
local_metadata: dict[str, Any] = {}
class Downloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
async def download_file(self, url, path, use_auth=False):
self.file_calls.append((url, path))
if "transcode=true,width=450,optimized=true" in url:
Path(path).write_bytes(b"video-data")
return True, None
if url.endswith(".mp4"):
return False, "fail"
return False, "unexpected"
async def download_to_memory(self, *_args, **_kwargs):
pytest.fail("download_to_memory should not be used for video previews")
downloader = Downloader()
async def downloader_factory():
return downloader
service = PreviewAssetService(
metadata_manager=StubMetadataManager(),
downloader_factory=downloader_factory,
exif_utils=RecordingExifUtils(),
)
images = [
{
"url": "https://image.civitai.com/container/example/original=true/sample.mp4",
"type": "video",
"nsfwLevel": 2,
}
]
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
assert len(downloader.file_calls) >= 1
assert any("transcode=true,width=450,optimized=true" in url for url, _ in downloader.file_calls)
preview_path = Path(local_metadata["preview_url"])
assert preview_path.exists()
assert preview_path.suffix == ".mp4"
assert local_metadata["preview_nsfw_level"] == 2

View File

@@ -28,6 +28,7 @@ from py.utils.example_images_processor import (
ExampleImagesImportError, ExampleImagesImportError,
ExampleImagesValidationError, ExampleImagesValidationError,
) )
from py.utils.metadata_manager import MetadataManager
from tests.conftest import MockModelService, MockScanner from tests.conftest import MockModelService, MockScanner
@@ -155,7 +156,9 @@ async def test_auto_organize_use_case_rejects_when_running() -> None:
await use_case.execute(file_paths=None, progress_callback=None) await use_case.execute(file_paths=None, progress_callback=None)
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None: async def test_bulk_metadata_refresh_emits_progress_and_updates_cache(
monkeypatch: pytest.MonkeyPatch,
) -> None:
scanner = MockScanner() scanner = MockScanner()
scanner._cache.raw_data = [ scanner._cache.raw_data = [
{ {
@@ -170,6 +173,25 @@ async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
settings = StubSettings() settings = StubSettings()
progress = ProgressCollector() progress = ProgressCollector()
hydration_calls: list[str] = []
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
hydration_calls.append(model_data.get("file_path", ""))
model_data.clear()
model_data.update(
{
"file_path": "model1.safetensors",
"sha256": "hash",
"from_civitai": True,
"model_name": "Demo",
"extra": "value",
"civitai": {"images": [{"url": "existing.png", "type": "image"}]},
}
)
return model_data
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
use_case = BulkMetadataRefreshUseCase( use_case = BulkMetadataRefreshUseCase(
service=service, service=service,
metadata_sync=metadata_sync, metadata_sync=metadata_sync,
@@ -183,6 +205,9 @@ async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
assert progress.events[0]["status"] == "started" assert progress.events[0]["status"] == "started"
assert progress.events[-1]["status"] == "completed" assert progress.events[-1]["status"] == "completed"
assert metadata_sync.calls assert metadata_sync.calls
assert metadata_sync.calls[0]["model_data"]["extra"] == "value"
assert scanner._cache.raw_data[0]["extra"] == "value"
assert hydration_calls == ["model1.safetensors"]
assert scanner._cache.resort_calls == 1 assert scanner._cache.resort_calls == 1

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict
import pytest import pytest
from py.services.settings_manager import settings from py.services.settings_manager import get_settings_manager
from py.utils import example_images_download_manager as download_module from py.utils import example_images_download_manager as download_module
@@ -19,19 +19,21 @@ class RecordingWebSocketManager:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def restore_settings() -> None: def restore_settings() -> None:
original = settings.settings.copy() manager = get_settings_manager()
original = manager.settings.copy()
try: try:
yield yield
finally: finally:
settings.settings.clear() manager.settings.clear()
settings.settings.update(original) manager.settings.update(original)
async def test_start_download_requires_configured_path(monkeypatch: pytest.MonkeyPatch) -> None: async def test_start_download_requires_configured_path(monkeypatch: pytest.MonkeyPatch) -> None:
manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager()) manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())
# Ensure example_images_path is not configured # Ensure example_images_path is not configured
settings.settings.pop('example_images_path', None) settings_manager = get_settings_manager()
settings_manager.settings.pop('example_images_path', None)
with pytest.raises(download_module.DownloadConfigurationError) as exc_info: with pytest.raises(download_module.DownloadConfigurationError) as exc_info:
await manager.start_download({}) await manager.start_download({})
@@ -44,9 +46,10 @@ async def test_start_download_requires_configured_path(monkeypatch: pytest.Monke
async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path) settings_manager = get_settings_manager()
settings.settings["libraries"] = {"default": {}} settings_manager.settings["example_images_path"] = str(tmp_path)
settings.settings["active_library"] = "default" settings_manager.settings["libraries"] = {"default": {}}
settings_manager.settings["active_library"] = "default"
ws_manager = RecordingWebSocketManager() ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager) manager = download_module.DownloadManager(ws_manager=ws_manager)
@@ -84,9 +87,10 @@ async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.M
async def test_pause_and_resume_flow(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: async def test_pause_and_resume_flow(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path) settings_manager = get_settings_manager()
settings.settings["libraries"] = {"default": {}} settings_manager.settings["example_images_path"] = str(tmp_path)
settings.settings["active_library"] = "default" settings_manager.settings["libraries"] = {"default": {}}
settings_manager.settings["active_library"] = "default"
ws_manager = RecordingWebSocketManager() ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager) manager = download_module.DownloadManager(ws_manager=ws_manager)

View File

@@ -7,7 +7,7 @@ from typing import Any, Dict
import pytest import pytest
from py.services.settings_manager import settings from py.services.settings_manager import get_settings_manager
from py.utils.example_images_file_manager import ExampleImagesFileManager from py.utils.example_images_file_manager import ExampleImagesFileManager
@@ -22,16 +22,18 @@ class JsonRequest:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def restore_settings() -> None: def restore_settings() -> None:
original = settings.settings.copy() manager = get_settings_manager()
original = manager.settings.copy()
try: try:
yield yield
finally: finally:
settings.settings.clear() manager.settings.clear()
settings.settings.update(original) manager.settings.update(original)
async def test_open_folder_requires_existing_model_directory(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: async def test_open_folder_requires_existing_model_directory(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path) settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
model_hash = "a" * 64 model_hash = "a" * 64
model_folder = tmp_path / model_hash model_folder = tmp_path / model_hash
model_folder.mkdir() model_folder.mkdir()
@@ -65,7 +67,8 @@ async def test_open_folder_requires_existing_model_directory(monkeypatch: pytest
async def test_open_folder_rejects_invalid_paths(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: async def test_open_folder_rejects_invalid_paths(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path) settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
def fake_get_model_folder(_hash): def fake_get_model_folder(_hash):
return str(tmp_path.parent / "outside") return str(tmp_path.parent / "outside")
@@ -81,7 +84,8 @@ async def test_open_folder_rejects_invalid_paths(monkeypatch: pytest.MonkeyPatch
async def test_get_files_lists_supported_media(tmp_path) -> None: async def test_get_files_lists_supported_media(tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path) settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
model_hash = "b" * 64 model_hash = "b" * 64
model_folder = tmp_path / model_hash model_folder = tmp_path / model_hash
model_folder.mkdir() model_folder.mkdir()
@@ -99,7 +103,8 @@ async def test_get_files_lists_supported_media(tmp_path) -> None:
async def test_has_images_reports_presence(tmp_path) -> None: async def test_has_images_reports_presence(tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path) settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
model_hash = "c" * 64 model_hash = "c" * 64
model_folder = tmp_path / model_hash model_folder = tmp_path / model_hash
model_folder.mkdir() model_folder.mkdir()

View File

@@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import json import json
import os
from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
@@ -30,7 +32,23 @@ def patch_metadata_manager(monkeypatch: pytest.MonkeyPatch):
saved.append((path, metadata.copy())) saved.append((path, metadata.copy()))
return True return True
class SimpleMetadata:
def __init__(self, payload: Dict[str, Any]) -> None:
self._payload = payload
self._unknown_fields: Dict[str, Any] = {}
def to_dict(self) -> Dict[str, Any]:
return self._payload.copy()
async def fake_load(path: str, *_args: Any, **_kwargs: Any):
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
if os.path.exists(metadata_path):
data = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
return SimpleMetadata(data), False
return None, False
monkeypatch.setattr(metadata_module.MetadataManager, "save_metadata", staticmethod(fake_save)) monkeypatch.setattr(metadata_module.MetadataManager, "save_metadata", staticmethod(fake_save))
monkeypatch.setattr(metadata_module.MetadataManager, "load_metadata", staticmethod(fake_load))
return saved return saved
@@ -64,10 +82,80 @@ async def test_update_metadata_after_import_enriches_entries(monkeypatch: pytest
assert custom[0]["hasMeta"] is True assert custom[0]["hasMeta"] is True
assert custom[0]["type"] == "image" assert custom[0]["type"] == "image"
assert patch_metadata_manager[0][0] == str(model_file) assert Path(patch_metadata_manager[0][0]) == model_file
assert scanner.updates assert scanner.updates
@pytest.mark.asyncio
async def test_update_metadata_after_import_preserves_existing_metadata(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
patch_metadata_manager,
):
model_hash = "b" * 64
model_file = tmp_path / "preserve.safetensors"
model_file.write_text("content", encoding="utf-8")
metadata_path = tmp_path / "preserve.metadata.json"
existing_payload = {
"model_name": "Example",
"file_path": str(model_file),
"civitai": {
"id": 42,
"modelId": 88,
"name": "Example",
"trainedWords": ["foo"],
"images": [{"url": "https://example.com/default.png", "type": "image"}],
"customImages": [
{"id": "existing-id", "type": "image", "url": "", "nsfwLevel": 0}
],
},
"extraField": "keep-me",
}
metadata_path.write_text(json.dumps(existing_payload), encoding="utf-8")
model_data = {
"sha256": model_hash,
"model_name": "Example",
"file_path": str(model_file),
"civitai": {
"id": 42,
"modelId": 88,
"name": "Example",
"trainedWords": ["foo"],
"customImages": [],
},
}
scanner = StubScanner([model_data])
image_path = tmp_path / "new.png"
image_path.write_bytes(b"fakepng")
monkeypatch.setattr(metadata_module.ExifUtils, "extract_image_metadata", staticmethod(lambda _path: None))
monkeypatch.setattr(metadata_module.MetadataUpdater, "_parse_image_metadata", staticmethod(lambda payload: None))
regular, custom = await metadata_module.MetadataUpdater.update_metadata_after_import(
model_hash,
model_data,
scanner,
[(str(image_path), "new-id")],
)
assert regular == existing_payload["civitai"]["images"]
assert any(entry["id"] == "new-id" for entry in custom)
saved_path, saved_payload = patch_metadata_manager[-1]
assert Path(saved_path) == model_file
assert saved_payload["extraField"] == "keep-me"
assert saved_payload["civitai"]["images"] == existing_payload["civitai"]["images"]
assert saved_payload["civitai"]["trainedWords"] == ["foo"]
assert {entry["id"] for entry in saved_payload["civitai"]["customImages"]} == {"existing-id", "new-id"}
assert scanner.updates
updated_metadata = scanner.updates[-1][2]
assert updated_metadata["civitai"]["images"] == existing_payload["civitai"]["images"]
assert {entry["id"] for entry in updated_metadata["civitai"]["customImages"]} == {"existing-id", "new-id"}
async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.MonkeyPatch, tmp_path): async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.MonkeyPatch, tmp_path):
model_hash = "b" * 64 model_hash = "b" * 64
model_file = tmp_path / "model.safetensors" model_file = tmp_path / "model.safetensors"
@@ -79,6 +167,16 @@ async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.Monke
async def fetch_and_update_model(self, **_kwargs): async def fetch_and_update_model(self, **_kwargs):
return True, None return True, None
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
model_data["hydrated"] = True
return model_data
monkeypatch.setattr(
metadata_module.MetadataManager,
"hydrate_model_data",
staticmethod(fake_hydrate),
)
monkeypatch.setattr(metadata_module, "_metadata_sync_service", StubMetadataSync()) monkeypatch.setattr(metadata_module, "_metadata_sync_service", StubMetadataSync())
result = await metadata_module.MetadataUpdater.refresh_model_metadata( result = await metadata_module.MetadataUpdater.refresh_model_metadata(
@@ -89,6 +187,7 @@ async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.Monke
{"refreshed_models": set(), "errors": [], "last_error": None}, {"refreshed_models": set(), "errors": [], "last_error": None},
) )
assert result is True assert result is True
assert cache_item["hydrated"] is True
async def test_update_metadata_from_local_examples_generates_entries(monkeypatch: pytest.MonkeyPatch, tmp_path): async def test_update_metadata_from_local_examples_generates_entries(monkeypatch: pytest.MonkeyPatch, tmp_path):

View File

@@ -6,7 +6,7 @@ from pathlib import Path
import pytest import pytest
from py.services.settings_manager import settings from py.services.settings_manager import get_settings_manager
from py.utils.example_images_paths import ( from py.utils.example_images_paths import (
ensure_library_root_exists, ensure_library_root_exists,
get_model_folder, get_model_folder,
@@ -18,18 +18,24 @@ from py.utils.example_images_paths import (
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def restore_settings(): def restore_settings():
original = copy.deepcopy(settings.settings) manager = get_settings_manager()
original = copy.deepcopy(manager.settings)
try: try:
yield yield
finally: finally:
settings.settings.clear() manager.settings.clear()
settings.settings.update(original) manager.settings.update(original)
def test_get_model_folder_single_library(tmp_path): @pytest.fixture
settings.settings['example_images_path'] = str(tmp_path) def settings_manager():
settings.settings['libraries'] = {'default': {}} return get_settings_manager()
settings.settings['active_library'] = 'default'
def test_get_model_folder_single_library(tmp_path, settings_manager):
settings_manager.settings['example_images_path'] = str(tmp_path)
settings_manager.settings['libraries'] = {'default': {}}
settings_manager.settings['active_library'] = 'default'
model_hash = 'a' * 64 model_hash = 'a' * 64
folder = get_model_folder(model_hash) folder = get_model_folder(model_hash)
@@ -39,13 +45,13 @@ def test_get_model_folder_single_library(tmp_path):
assert relative == model_hash assert relative == model_hash
def test_get_model_folder_multi_library(tmp_path): def test_get_model_folder_multi_library(tmp_path, settings_manager):
settings.settings['example_images_path'] = str(tmp_path) settings_manager.settings['example_images_path'] = str(tmp_path)
settings.settings['libraries'] = { settings_manager.settings['libraries'] = {
'default': {}, 'default': {},
'Alt Library': {}, 'Alt Library': {},
} }
settings.settings['active_library'] = 'Alt Library' settings_manager.settings['active_library'] = 'Alt Library'
model_hash = 'b' * 64 model_hash = 'b' * 64
expected_folder = tmp_path / 'Alt_Library' / model_hash expected_folder = tmp_path / 'Alt_Library' / model_hash
@@ -57,13 +63,13 @@ def test_get_model_folder_multi_library(tmp_path):
assert relative == os.path.join('Alt_Library', model_hash).replace('\\', '/') assert relative == os.path.join('Alt_Library', model_hash).replace('\\', '/')
def test_get_model_folder_migrates_legacy_structure(tmp_path): def test_get_model_folder_migrates_legacy_structure(tmp_path, settings_manager):
settings.settings['example_images_path'] = str(tmp_path) settings_manager.settings['example_images_path'] = str(tmp_path)
settings.settings['libraries'] = { settings_manager.settings['libraries'] = {
'default': {}, 'default': {},
'extra': {}, 'extra': {},
} }
settings.settings['active_library'] = 'extra' settings_manager.settings['active_library'] = 'extra'
model_hash = 'c' * 64 model_hash = 'c' * 64
legacy_folder = tmp_path / model_hash legacy_folder = tmp_path / model_hash
@@ -82,31 +88,31 @@ def test_get_model_folder_migrates_legacy_structure(tmp_path):
assert (expected_folder / 'image.png').exists() assert (expected_folder / 'image.png').exists()
def test_ensure_library_root_exists_creates_directories(tmp_path): def test_ensure_library_root_exists_creates_directories(tmp_path, settings_manager):
settings.settings['example_images_path'] = str(tmp_path) settings_manager.settings['example_images_path'] = str(tmp_path)
settings.settings['libraries'] = {'default': {}, 'secondary': {}} settings_manager.settings['libraries'] = {'default': {}, 'secondary': {}}
settings.settings['active_library'] = 'secondary' settings_manager.settings['active_library'] = 'secondary'
resolved = ensure_library_root_exists('secondary') resolved = ensure_library_root_exists('secondary')
assert Path(resolved) == tmp_path / 'secondary' assert Path(resolved) == tmp_path / 'secondary'
assert (tmp_path / 'secondary').is_dir() assert (tmp_path / 'secondary').is_dir()
def test_iter_library_roots_returns_all_configured(tmp_path): def test_iter_library_roots_returns_all_configured(tmp_path, settings_manager):
settings.settings['example_images_path'] = str(tmp_path) settings_manager.settings['example_images_path'] = str(tmp_path)
settings.settings['libraries'] = {'default': {}, 'alt': {}} settings_manager.settings['libraries'] = {'default': {}, 'alt': {}}
settings.settings['active_library'] = 'alt' settings_manager.settings['active_library'] = 'alt'
roots = dict(iter_library_roots()) roots = dict(iter_library_roots())
assert roots['default'] == str(tmp_path / 'default') assert roots['default'] == str(tmp_path / 'default')
assert roots['alt'] == str(tmp_path / 'alt') assert roots['alt'] == str(tmp_path / 'alt')
def test_is_valid_example_images_root_accepts_hash_directories(tmp_path): def test_is_valid_example_images_root_accepts_hash_directories(tmp_path, settings_manager):
settings.settings['example_images_path'] = str(tmp_path) settings_manager.settings['example_images_path'] = str(tmp_path)
# Ensure single-library mode (not multi-library mode) # Ensure single-library mode (not multi-library mode)
settings.settings['libraries'] = {'default': {}} settings_manager.settings['libraries'] = {'default': {}}
settings.settings['active_library'] = 'default' settings_manager.settings['active_library'] = 'default'
hash_folder = tmp_path / ('d' * 64) hash_folder = tmp_path / ('d' * 64)
hash_folder.mkdir() hash_folder.mkdir()

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import json
import os import os
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
@@ -7,18 +8,42 @@ from typing import Any, Dict, Tuple
import pytest import pytest
from py.services.settings_manager import settings from py.services.settings_manager import get_settings_manager
from py.utils import example_images_metadata as metadata_module
from py.utils import example_images_processor as processor_module from py.utils import example_images_processor as processor_module
from py.utils.example_images_paths import get_model_folder
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def restore_settings() -> None: def restore_settings() -> None:
original = settings.settings.copy() manager = get_settings_manager()
original = manager.settings.copy()
try: try:
yield yield
finally: finally:
settings.settings.clear() manager.settings.clear()
settings.settings.update(original) manager.settings.update(original)
@pytest.fixture(autouse=True)
def patch_metadata_loader(monkeypatch: pytest.MonkeyPatch) -> None:
class SimpleMetadata:
def __init__(self, payload: Dict[str, Any]) -> None:
self._payload = payload
self._unknown_fields: Dict[str, Any] = {}
def to_dict(self) -> Dict[str, Any]:
return self._payload.copy()
async def fake_load(path: str, *_args: Any, **_kwargs: Any):
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
if os.path.exists(metadata_path):
data = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
return SimpleMetadata(data), False
return None, False
monkeypatch.setattr(processor_module.MetadataManager, "load_metadata", staticmethod(fake_load))
monkeypatch.setattr(metadata_module.MetadataManager, "load_metadata", staticmethod(fake_load))
def test_get_file_extension_from_magic_bytes() -> None: def test_get_file_extension_from_magic_bytes() -> None:
@@ -90,9 +115,10 @@ def stub_scanners(monkeypatch: pytest.MonkeyPatch, tmp_path) -> StubScanner:
async def test_import_images_creates_hash_directory(monkeypatch: pytest.MonkeyPatch, tmp_path, stub_scanners: StubScanner) -> None: async def test_import_images_creates_hash_directory(monkeypatch: pytest.MonkeyPatch, tmp_path, stub_scanners: StubScanner) -> None:
settings.settings["example_images_path"] = str(tmp_path / "examples") settings_manager = get_settings_manager()
settings.settings["libraries"] = {"default": {}} settings_manager.settings["example_images_path"] = str(tmp_path / "examples")
settings.settings["active_library"] = "default" settings_manager.settings["libraries"] = {"default": {}}
settings_manager.settings["active_library"] = "default"
source_file = tmp_path / "upload.png" source_file = tmp_path / "upload.png"
source_file.write_bytes(b"PNG data") source_file.write_bytes(b"PNG data")
@@ -112,7 +138,7 @@ async def test_import_images_creates_hash_directory(monkeypatch: pytest.MonkeyPa
assert result["success"] is True assert result["success"] is True
assert result["files"][0]["name"].startswith("custom_short") assert result["files"][0]["name"].startswith("custom_short")
model_folder = Path(settings.settings["example_images_path"]) / ("a" * 64) model_folder = Path(settings_manager.settings["example_images_path"]) / ("a" * 64)
assert model_folder.exists() assert model_folder.exists()
created_files = list(model_folder.glob("custom_short*.png")) created_files = list(model_folder.glob("custom_short*.png"))
assert len(created_files) == 1 assert len(created_files) == 1
@@ -132,7 +158,8 @@ async def test_import_images_rejects_missing_parameters(monkeypatch: pytest.Monk
async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings.settings["example_images_path"] = str(tmp_path) settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
async def _empty_scanner(cls=None): async def _empty_scanner(cls=None):
return StubScanner([]) return StubScanner([])
@@ -143,3 +170,88 @@ async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.Mon
with pytest.raises(processor_module.ExampleImagesImportError): with pytest.raises(processor_module.ExampleImagesImportError):
await processor_module.ExampleImagesProcessor.import_images("a" * 64, [str(tmp_path / "missing.png")]) await processor_module.ExampleImagesProcessor.import_images("a" * 64, [str(tmp_path / "missing.png")])
@pytest.mark.asyncio
async def test_delete_custom_image_preserves_existing_metadata(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path / "examples")
model_hash = "c" * 64
model_file = tmp_path / "keep.safetensors"
model_file.write_text("content", encoding="utf-8")
metadata_path = tmp_path / "keep.metadata.json"
existing_metadata = {
"model_name": "Keep",
"file_path": str(model_file),
"civitai": {
"images": [{"url": "https://example.com/default.png", "type": "image"}],
"customImages": [{"id": "existing-id", "url": "", "type": "image"}],
"trainedWords": ["foo"],
},
}
metadata_path.write_text(json.dumps(existing_metadata), encoding="utf-8")
model_data = {
"sha256": model_hash,
"model_name": "Keep",
"file_path": str(model_file),
"civitai": {
"customImages": [{"id": "existing-id", "url": "", "type": "image"}],
"trainedWords": ["foo"],
},
}
class Scanner(StubScanner):
def has_hash(self, hash_value: str) -> bool:
return hash_value == model_hash
scanner = Scanner([model_data])
async def _return_scanner(cls=None):
return scanner
monkeypatch.setattr(processor_module.ServiceRegistry, "get_lora_scanner", classmethod(_return_scanner))
monkeypatch.setattr(processor_module.ServiceRegistry, "get_checkpoint_scanner", classmethod(_return_scanner))
monkeypatch.setattr(processor_module.ServiceRegistry, "get_embedding_scanner", classmethod(_return_scanner))
model_folder = get_model_folder(model_hash)
os.makedirs(model_folder, exist_ok=True)
(Path(model_folder) / "custom_existing-id.png").write_bytes(b"data")
saved: list[tuple[str, Dict[str, Any]]] = []
async def fake_save(path: str, payload: Dict[str, Any]) -> bool:
saved.append((path, payload.copy()))
return True
monkeypatch.setattr(processor_module.MetadataManager, "save_metadata", staticmethod(fake_save))
class StubRequest:
def __init__(self, payload: Dict[str, Any]) -> None:
self._payload = payload
async def json(self) -> Dict[str, Any]:
return self._payload
response = await processor_module.ExampleImagesProcessor.delete_custom_image(
StubRequest({"model_hash": model_hash, "short_id": "existing-id"})
)
assert response.status == 200
body = json.loads(response.text)
assert body["success"] is True
assert body["custom_images"] == []
assert not (Path(model_folder) / "custom_existing-id.png").exists()
saved_path, saved_payload = saved[-1]
assert saved_path == str(model_file)
assert saved_payload["civitai"]["images"] == existing_metadata["civitai"]["images"]
assert saved_payload["civitai"]["trainedWords"] == ["foo"]
assert saved_payload["civitai"]["customImages"] == []
assert scanner.updated
_, _, updated_metadata = scanner.updated[-1]
assert updated_metadata["civitai"]["images"] == existing_metadata["civitai"]["images"]
assert updated_metadata["civitai"]["customImages"] == []

View File

@@ -1,6 +1,6 @@
import pytest import pytest
from py.services.settings_manager import settings from py.services.settings_manager import SettingsManager, get_settings_manager
from py.utils.utils import ( from py.utils.utils import (
calculate_recipe_fingerprint, calculate_recipe_fingerprint,
calculate_relative_path_for_model, calculate_relative_path_for_model,
@@ -9,7 +9,8 @@ from py.utils.utils import (
@pytest.fixture @pytest.fixture
def isolated_settings(monkeypatch): def isolated_settings(monkeypatch):
default_settings = settings._get_default_settings() manager = get_settings_manager()
default_settings = manager._get_default_settings()
default_settings.update( default_settings.update(
{ {
"download_path_templates": { "download_path_templates": {
@@ -20,8 +21,8 @@ def isolated_settings(monkeypatch):
"base_model_path_mappings": {}, "base_model_path_mappings": {},
} }
) )
monkeypatch.setattr(settings, "settings", default_settings) monkeypatch.setattr(manager, "settings", default_settings)
monkeypatch.setattr(type(settings), "_save_settings", lambda self: None) monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
return default_settings return default_settings

View File

@@ -1,6 +1,7 @@
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js"; import { api } from "../../scripts/api.js";
import { addJsonDisplayWidget } from "./json_display_widget.js"; import { addJsonDisplayWidget } from "./json_display_widget.js";
import { getNodeFromGraph } from "./utils.js";
app.registerExtension({ app.registerExtension({
name: "LoraManager.DebugMetadata", name: "LoraManager.DebugMetadata",
@@ -8,8 +9,8 @@ app.registerExtension({
setup() { setup() {
// Add message handler to listen for metadata updates from Python // Add message handler to listen for metadata updates from Python
api.addEventListener("metadata_update", (event) => { api.addEventListener("metadata_update", (event) => {
const { id, metadata } = event.detail; const { id, graph_id: graphId, metadata } = event.detail;
this.handleMetadataUpdate(id, metadata); this.handleMetadataUpdate(id, graphId, metadata);
}); });
}, },
@@ -37,8 +38,8 @@ app.registerExtension({
}, },
// Handle metadata updates from Python // Handle metadata updates from Python
handleMetadataUpdate(id, metadata) { handleMetadataUpdate(id, graphId, metadata) {
const node = app.graph.getNodeById(+id); const node = getNodeFromGraph(graphId, id);
if (!node || node.comfyClass !== "Debug Metadata (LoraManager)") { if (!node || node.comfyClass !== "Debug Metadata (LoraManager)") {
console.warn("Node not found or not a DebugMetadata node:", id); console.warn("Node not found or not a DebugMetadata node:", id);
return; return;

View File

@@ -7,6 +7,8 @@ import {
chainCallback, chainCallback,
mergeLoras, mergeLoras,
setupInputWidgetWithAutocomplete, setupInputWidgetWithAutocomplete,
getAllGraphNodes,
getNodeFromGraph,
} from "./utils.js"; } from "./utils.js";
import { addLorasWidget } from "./loras_widget.js"; import { addLorasWidget } from "./loras_widget.js";
@@ -16,23 +18,26 @@ app.registerExtension({
setup() { setup() {
// Add message handler to listen for messages from Python // Add message handler to listen for messages from Python
api.addEventListener("lora_code_update", (event) => { api.addEventListener("lora_code_update", (event) => {
const { id, lora_code, mode } = event.detail; this.handleLoraCodeUpdate(event.detail || {});
this.handleLoraCodeUpdate(id, lora_code, mode);
}); });
}, },
// Handle lora code updates from Python // Handle lora code updates from Python
handleLoraCodeUpdate(id, loraCode, mode) { handleLoraCodeUpdate(message) {
const nodeId = message?.node_id ?? message?.id;
const graphId = message?.graph_id;
const loraCode = message?.lora_code ?? "";
const mode = message?.mode ?? "append";
const numericNodeId =
typeof nodeId === "string" ? Number(nodeId) : nodeId;
// Handle broadcast mode (for Desktop/non-browser support) // Handle broadcast mode (for Desktop/non-browser support)
if (id === -1) { if (numericNodeId === -1) {
// Find all Lora Loader nodes in the current graph // Find all Lora Loader nodes in the current graph
const loraLoaderNodes = []; const loraLoaderNodes = getAllGraphNodes(app.graph)
for (const nodeId in app.graph._nodes_by_id) { .map(({ node }) => node)
const node = app.graph._nodes_by_id[nodeId]; .filter((node) => node?.comfyClass === "Lora Loader (LoraManager)");
if (node.comfyClass === "Lora Loader (LoraManager)") {
loraLoaderNodes.push(node);
}
}
// Update each Lora Loader node found // Update each Lora Loader node found
if (loraLoaderNodes.length > 0) { if (loraLoaderNodes.length > 0) {
@@ -52,14 +57,18 @@ app.registerExtension({
} }
// Standard mode - update a specific node // Standard mode - update a specific node
const node = app.graph.getNodeById(+id); const node = getNodeFromGraph(graphId, numericNodeId);
if ( if (
!node || !node ||
(node.comfyClass !== "Lora Loader (LoraManager)" && (node.comfyClass !== "Lora Loader (LoraManager)" &&
node.comfyClass !== "Lora Stacker (LoraManager)" && node.comfyClass !== "Lora Stacker (LoraManager)" &&
node.comfyClass !== "WanVideo Lora Select (LoraManager)") node.comfyClass !== "WanVideo Lora Select (LoraManager)")
) { ) {
console.warn("Node not found or not a LoraLoader:", id); console.warn(
"Node not found or not a LoraLoader:",
graphId ?? "root",
nodeId
);
return; return;
} }

View File

@@ -7,6 +7,8 @@ import {
chainCallback, chainCallback,
mergeLoras, mergeLoras,
setupInputWidgetWithAutocomplete, setupInputWidgetWithAutocomplete,
getLinkFromGraph,
getNodeKey,
} from "./utils.js"; } from "./utils.js";
import { addLorasWidget } from "./loras_widget.js"; import { addLorasWidget } from "./loras_widget.js";
@@ -124,17 +126,18 @@ app.registerExtension({
// Helper function to find and update downstream Lora Loader nodes // Helper function to find and update downstream Lora Loader nodes
function updateDownstreamLoaders(startNode, visited = new Set()) { function updateDownstreamLoaders(startNode, visited = new Set()) {
if (visited.has(startNode.id)) return; const nodeKey = getNodeKey(startNode);
visited.add(startNode.id); if (!nodeKey || visited.has(nodeKey)) return;
visited.add(nodeKey);
// Check each output link // Check each output link
if (startNode.outputs) { if (startNode.outputs) {
for (const output of startNode.outputs) { for (const output of startNode.outputs) {
if (output.links) { if (output.links) {
for (const linkId of output.links) { for (const linkId of output.links) {
const link = app.graph.links[linkId]; const link = getLinkFromGraph(startNode.graph, linkId);
if (link) { if (link) {
const targetNode = app.graph.getNodeById(link.target_id); const targetNode = startNode.graph?.getNodeById?.(link.target_id);
// If target is a Lora Loader, collect all active loras in the chain and update // If target is a Lora Loader, collect all active loras in the chain and update
if ( if (

View File

@@ -1,6 +1,6 @@
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js"; import { api } from "../../scripts/api.js";
import { CONVERTED_TYPE } from "./utils.js"; import { CONVERTED_TYPE, getNodeFromGraph } from "./utils.js";
import { addTagsWidget } from "./tags_widget.js"; import { addTagsWidget } from "./tags_widget.js";
// TriggerWordToggle extension for ComfyUI // TriggerWordToggle extension for ComfyUI
@@ -10,8 +10,8 @@ app.registerExtension({
setup() { setup() {
// Add message handler to listen for messages from Python // Add message handler to listen for messages from Python
api.addEventListener("trigger_word_update", (event) => { api.addEventListener("trigger_word_update", (event) => {
const { id, message } = event.detail; const { id, graph_id: graphId, message } = event.detail;
this.handleTriggerWordUpdate(id, message); this.handleTriggerWordUpdate(id, graphId, message);
}); });
}, },
@@ -76,8 +76,8 @@ app.registerExtension({
}, },
// Handle trigger word updates from Python // Handle trigger word updates from Python
handleTriggerWordUpdate(id, message) { handleTriggerWordUpdate(id, graphId, message) {
const node = app.graph.getNodeById(+id); const node = getNodeFromGraph(graphId, id);
if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") { if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") {
console.warn("Node not found or not a TriggerWordToggle:", id); console.warn("Node not found or not a TriggerWordToggle:", id);
return; return;

View File

@@ -1,7 +1,7 @@
// ComfyUI extension to track model usage statistics // ComfyUI extension to track model usage statistics
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js"; import { api } from "../../scripts/api.js";
import { showToast } from "./utils.js"; import { getAllGraphNodes, getNodeReference, showToast } from "./utils.js";
// Define target nodes and their widget configurations // Define target nodes and their widget configurations
const PATH_CORRECTION_TARGETS = [ const PATH_CORRECTION_TARGETS = [
@@ -56,25 +56,35 @@ app.registerExtension({
async refreshRegistry() { async refreshRegistry() {
try { try {
// Get current workflow nodes
const prompt = await app.graphToPrompt();
const workflow = prompt.workflow;
if (!workflow || !workflow.nodes) {
console.warn("No workflow nodes found for registry refresh");
return;
}
// Find all Lora nodes
const loraNodes = []; const loraNodes = [];
for (const node of workflow.nodes.values()) { const nodeEntries = getAllGraphNodes(app.graph);
if (node.type === "Lora Loader (LoraManager)" ||
node.type === "Lora Stacker (LoraManager)" || for (const { graph, node } of nodeEntries) {
node.type === "WanVideo Lora Select (LoraManager)") { if (!node || !node.comfyClass) {
continue;
}
if (
node.comfyClass === "Lora Loader (LoraManager)" ||
node.comfyClass === "Lora Stacker (LoraManager)" ||
node.comfyClass === "WanVideo Lora Select (LoraManager)"
) {
const reference = getNodeReference(node);
if (!reference) {
continue;
}
const graphName = typeof graph?.name === "string" && graph.name.trim()
? graph.name
: null;
loraNodes.push({ loraNodes.push({
node_id: node.id, node_id: reference.node_id,
bgcolor: node.bgcolor || null, graph_id: reference.graph_id,
title: node.title || node.type, graph_name: graphName,
type: node.type bgcolor: node.bgcolor ?? node.color ?? null,
title: node.title || node.comfyClass,
type: node.comfyClass,
}); });
} }
} }

View File

@@ -2,6 +2,120 @@ export const CONVERTED_TYPE = 'converted-widget';
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { AutoComplete } from "./autocomplete.js"; import { AutoComplete } from "./autocomplete.js";
const ROOT_GRAPH_ID = "root";
function isMapLike(collection) {
return collection && typeof collection.entries === "function" && typeof collection.values === "function";
}
function getChildGraphs(graph) {
if (!graph || !graph._subgraphs) {
return [];
}
const rawSubgraphs = isMapLike(graph._subgraphs)
? Array.from(graph._subgraphs.values())
: Object.values(graph._subgraphs);
return rawSubgraphs
.map((subgraph) => subgraph?.graph || subgraph?._graph || subgraph)
.filter((subgraph) => subgraph && subgraph !== graph);
}
function traverseGraphs(rootGraph, visitor, visited = new Set()) {
const graph = rootGraph || app.graph;
if (!graph) {
return;
}
const graphId = getGraphId(graph);
if (visited.has(graphId)) {
return;
}
visited.add(graphId);
visitor(graph);
for (const subgraph of getChildGraphs(graph)) {
traverseGraphs(subgraph, visitor, visited);
}
}
export function getGraphId(graph) {
return graph?.id ?? ROOT_GRAPH_ID;
}
export function getNodeGraphId(node) {
if (!node) {
return ROOT_GRAPH_ID;
}
return getGraphId(node.graph || app.graph);
}
export function getGraphById(graphId, rootGraph = app.graph) {
if (!graphId) {
return rootGraph;
}
let foundGraph = null;
traverseGraphs(rootGraph, (graph) => {
if (!foundGraph && getGraphId(graph) === graphId) {
foundGraph = graph;
}
});
return foundGraph;
}
export function getNodeFromGraph(graphId, nodeId) {
const graph = getGraphById(graphId) || app.graph;
if (!graph || typeof graph.getNodeById !== "function") {
return null;
}
const numericId = typeof nodeId === "string" ? Number(nodeId) : nodeId;
return graph.getNodeById(Number.isNaN(numericId) ? nodeId : numericId) || null;
}
export function getAllGraphNodes(rootGraph = app.graph) {
const nodes = [];
traverseGraphs(rootGraph, (graph) => {
if (Array.isArray(graph._nodes)) {
for (const node of graph._nodes) {
nodes.push({ graph, node });
}
}
});
return nodes;
}
export function getNodeReference(node) {
if (!node) {
return null;
}
return {
node_id: node.id,
graph_id: getNodeGraphId(node),
};
}
export function getNodeKey(node) {
if (!node) {
return null;
}
return `${getNodeGraphId(node)}:${node.id}`;
}
export function getLinkFromGraph(graph, linkId) {
if (!graph || graph.links == null) {
return null;
}
if (isMapLike(graph.links)) {
return graph.links.get(linkId) || null;
}
return graph.links[linkId] || null;
}
export function chainCallback(object, property, callback) { export function chainCallback(object, property, callback) {
if (object == undefined) { if (object == undefined) {
//This should not happen. //This should not happen.
@@ -104,19 +218,26 @@ export const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)(?::([-\d\.]+))?>/g;
export function getConnectedInputStackers(node) { export function getConnectedInputStackers(node) {
const connectedStackers = []; const connectedStackers = [];
if (node.inputs) { if (!node?.inputs) {
for (const input of node.inputs) { return connectedStackers;
if (input.name === "lora_stack" && input.link) { }
const link = app.graph.links[input.link];
if (link) { for (const input of node.inputs) {
const sourceNode = app.graph.getNodeById(link.origin_id); if (input.name !== "lora_stack" || !input.link) {
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") { continue;
connectedStackers.push(sourceNode); }
}
} const link = getLinkFromGraph(node.graph, input.link);
} if (!link) {
continue;
}
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
connectedStackers.push(sourceNode);
} }
} }
return connectedStackers; return connectedStackers;
} }
@@ -124,21 +245,28 @@ export function getConnectedInputStackers(node) {
export function getConnectedTriggerToggleNodes(node) { export function getConnectedTriggerToggleNodes(node) {
const connectedNodes = []; const connectedNodes = [];
if (node.outputs && node.outputs.length > 0) { if (!node?.outputs) {
for (const output of node.outputs) { return connectedNodes;
if (output.links && output.links.length > 0) { }
for (const linkId of output.links) {
const link = app.graph.links[linkId]; for (const output of node.outputs) {
if (link) { if (!output?.links?.length) {
const targetNode = app.graph.getNodeById(link.target_id); continue;
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") { }
connectedNodes.push(targetNode.id);
} for (const linkId of output.links) {
} const link = getLinkFromGraph(node.graph, linkId);
} if (!link) {
continue;
}
const targetNode = node.graph?.getNodeById?.(link.target_id);
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
connectedNodes.push(targetNode);
} }
} }
} }
return connectedNodes; return connectedNodes;
} }
@@ -161,10 +289,14 @@ export function getActiveLorasFromNode(node) {
// Recursively collect all active loras from a node and its input chain // Recursively collect all active loras from a node and its input chain
export function collectActiveLorasFromChain(node, visited = new Set()) { export function collectActiveLorasFromChain(node, visited = new Set()) {
// Prevent infinite loops from circular references // Prevent infinite loops from circular references
if (visited.has(node.id)) { const nodeKey = getNodeKey(node);
if (!nodeKey) {
return new Set(); return new Set();
} }
visited.add(node.id); if (visited.has(nodeKey)) {
return new Set();
}
visited.add(nodeKey);
// Get active loras from current node // Get active loras from current node
const allActiveLoraNames = getActiveLorasFromNode(node); const allActiveLoraNames = getActiveLorasFromNode(node);
@@ -181,14 +313,22 @@ export function collectActiveLorasFromChain(node, visited = new Set()) {
// Update trigger words for connected toggle nodes // Update trigger words for connected toggle nodes
export function updateConnectedTriggerWords(node, loraNames) { export function updateConnectedTriggerWords(node, loraNames) {
const connectedNodeIds = getConnectedTriggerToggleNodes(node); const connectedNodes = getConnectedTriggerToggleNodes(node);
if (connectedNodeIds.length > 0) { if (connectedNodes.length > 0) {
const nodeIds = connectedNodes
.map((connectedNode) => getNodeReference(connectedNode))
.filter((reference) => reference !== null);
if (nodeIds.length === 0) {
return;
}
fetch("/api/lm/loras/get_trigger_words", { fetch("/api/lm/loras/get_trigger_words", {
method: "POST", method: "POST",
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
body: JSON.stringify({ body: JSON.stringify({
lora_names: Array.from(loraNames), lora_names: Array.from(loraNames),
node_ids: connectedNodeIds node_ids: nodeIds
}) })
}).catch(err => console.error("Error fetching trigger words:", err)); }).catch(err => console.error("Error fetching trigger words:", err));
} }