mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76d3aa2b5b | ||
|
|
c9a65c7347 | ||
|
|
f542ade628 | ||
|
|
d2c2bfbe6a | ||
|
|
2b6910bd55 | ||
|
|
b1dd733493 | ||
|
|
5dcf0a1e48 | ||
|
|
cf357b57fc | ||
|
|
4e1773833f | ||
|
|
8cf762ffd3 | ||
|
|
d997eaa429 | ||
|
|
8e51f0f19f | ||
|
|
f0e246b4ac | ||
|
|
a232997a79 | ||
|
|
08a449db99 | ||
|
|
0c023c9888 | ||
|
|
0ad92d00b3 | ||
|
|
a726cbea1e | ||
|
|
c53fa8692b | ||
|
|
3118f3b43c |
@@ -529,12 +529,15 @@
|
||||
"title": "Embedding-Modelle"
|
||||
},
|
||||
"sidebar": {
|
||||
"modelRoot": "Modell-Stammverzeichnis",
|
||||
"modelRoot": "Stammverzeichnis",
|
||||
"collapseAll": "Alle Ordner einklappen",
|
||||
"pinSidebar": "Sidebar anheften",
|
||||
"unpinSidebar": "Sidebar lösen",
|
||||
"switchToListView": "Zur Listenansicht wechseln",
|
||||
"switchToTreeView": "Zur Baumansicht wechseln",
|
||||
"recursiveOn": "Unterordner durchsuchen",
|
||||
"recursiveOff": "Nur aktuellen Ordner durchsuchen",
|
||||
"recursiveUnavailable": "Rekursive Suche ist nur in der Baumansicht verfügbar",
|
||||
"collapseAllDisabled": "Im Listenmodus nicht verfügbar"
|
||||
},
|
||||
"statistics": {
|
||||
|
||||
@@ -529,12 +529,15 @@
|
||||
"title": "Embedding Models"
|
||||
},
|
||||
"sidebar": {
|
||||
"modelRoot": "Model Root",
|
||||
"modelRoot": "Root",
|
||||
"collapseAll": "Collapse All Folders",
|
||||
"pinSidebar": "Pin Sidebar",
|
||||
"unpinSidebar": "Unpin Sidebar",
|
||||
"switchToListView": "Switch to List View",
|
||||
"switchToTreeView": "Switch to Tree View",
|
||||
"recursiveOn": "Search subfolders",
|
||||
"recursiveOff": "Search current folder only",
|
||||
"recursiveUnavailable": "Recursive search is available in tree view only",
|
||||
"collapseAllDisabled": "Not available in list view"
|
||||
},
|
||||
"statistics": {
|
||||
|
||||
@@ -529,12 +529,15 @@
|
||||
"title": "Modelos embedding"
|
||||
},
|
||||
"sidebar": {
|
||||
"modelRoot": "Raíz del modelo",
|
||||
"modelRoot": "Raíz",
|
||||
"collapseAll": "Colapsar todas las carpetas",
|
||||
"pinSidebar": "Fijar barra lateral",
|
||||
"unpinSidebar": "Desfijar barra lateral",
|
||||
"switchToListView": "Cambiar a vista de lista",
|
||||
"switchToTreeView": "Cambiar a vista de árbol",
|
||||
"recursiveOn": "Buscar en subcarpetas",
|
||||
"recursiveOff": "Buscar solo en la carpeta actual",
|
||||
"recursiveUnavailable": "La búsqueda recursiva solo está disponible en la vista en árbol",
|
||||
"collapseAllDisabled": "No disponible en vista de lista"
|
||||
},
|
||||
"statistics": {
|
||||
|
||||
@@ -529,12 +529,15 @@
|
||||
"title": "Modèles Embedding"
|
||||
},
|
||||
"sidebar": {
|
||||
"modelRoot": "Racine du modèle",
|
||||
"modelRoot": "Racine",
|
||||
"collapseAll": "Réduire tous les dossiers",
|
||||
"pinSidebar": "Épingler la barre latérale",
|
||||
"unpinSidebar": "Désépingler la barre latérale",
|
||||
"switchToListView": "Passer en vue liste",
|
||||
"switchToTreeView": "Passer en vue arborescence",
|
||||
"recursiveOn": "Rechercher dans les sous-dossiers",
|
||||
"recursiveOff": "Rechercher uniquement dans le dossier actuel",
|
||||
"recursiveUnavailable": "La recherche récursive n'est disponible qu'en vue arborescente",
|
||||
"collapseAllDisabled": "Non disponible en vue liste"
|
||||
},
|
||||
"statistics": {
|
||||
|
||||
@@ -529,12 +529,15 @@
|
||||
"title": "מודלי Embedding"
|
||||
},
|
||||
"sidebar": {
|
||||
"modelRoot": "שורש המודלים",
|
||||
"modelRoot": "שורש",
|
||||
"collapseAll": "כווץ את כל התיקיות",
|
||||
"pinSidebar": "נעל סרגל צד",
|
||||
"unpinSidebar": "שחרר סרגל צד",
|
||||
"switchToListView": "עבור לתצוגת רשימה",
|
||||
"switchToTreeView": "עבור לתצוגת עץ",
|
||||
"switchToTreeView": "תצוגת עץ",
|
||||
"recursiveOn": "חיפוש בתיקיות משנה",
|
||||
"recursiveOff": "חיפוש רק בתיקייה הנוכחית",
|
||||
"recursiveUnavailable": "חיפוש רקורסיבי זמין רק בתצוגת עץ",
|
||||
"collapseAllDisabled": "לא זמין בתצוגת רשימה"
|
||||
},
|
||||
"statistics": {
|
||||
@@ -1264,4 +1267,4 @@
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -529,12 +529,15 @@
|
||||
"title": "Embeddingモデル"
|
||||
},
|
||||
"sidebar": {
|
||||
"modelRoot": "モデルルート",
|
||||
"modelRoot": "ルート",
|
||||
"collapseAll": "すべてのフォルダを折りたたむ",
|
||||
"pinSidebar": "サイドバーを固定",
|
||||
"unpinSidebar": "サイドバーの固定を解除",
|
||||
"switchToListView": "リストビューに切り替え",
|
||||
"switchToTreeView": "ツリービューに切り替え",
|
||||
"switchToTreeView": "ツリー表示に切り替え",
|
||||
"recursiveOn": "サブフォルダーを検索",
|
||||
"recursiveOff": "現在のフォルダーのみを検索",
|
||||
"recursiveUnavailable": "再帰検索はツリービューでのみ利用できます",
|
||||
"collapseAllDisabled": "リストビューでは利用できません"
|
||||
},
|
||||
"statistics": {
|
||||
@@ -1264,4 +1267,4 @@
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -529,12 +529,15 @@
|
||||
"title": "Embedding 모델"
|
||||
},
|
||||
"sidebar": {
|
||||
"modelRoot": "모델 루트",
|
||||
"modelRoot": "루트",
|
||||
"collapseAll": "모든 폴더 접기",
|
||||
"pinSidebar": "사이드바 고정",
|
||||
"unpinSidebar": "사이드바 고정 해제",
|
||||
"switchToListView": "목록 보기로 전환",
|
||||
"switchToTreeView": "트리 보기로 전환",
|
||||
"recursiveOn": "하위 폴더 검색",
|
||||
"recursiveOff": "현재 폴더만 검색",
|
||||
"recursiveUnavailable": "재귀 검색은 트리 보기에서만 사용할 수 있습니다",
|
||||
"collapseAllDisabled": "목록 보기에서는 사용할 수 없습니다"
|
||||
},
|
||||
"statistics": {
|
||||
@@ -1264,4 +1267,4 @@
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -529,12 +529,15 @@
|
||||
"title": "Модели Embedding"
|
||||
},
|
||||
"sidebar": {
|
||||
"modelRoot": "Корень моделей",
|
||||
"modelRoot": "Корень",
|
||||
"collapseAll": "Свернуть все папки",
|
||||
"pinSidebar": "Закрепить боковую панель",
|
||||
"unpinSidebar": "Открепить боковую панель",
|
||||
"switchToListView": "Переключить на вид списка",
|
||||
"switchToTreeView": "Переключить на древовидный вид",
|
||||
"recursiveOn": "Искать во вложенных папках",
|
||||
"recursiveOff": "Искать только в текущей папке",
|
||||
"recursiveUnavailable": "Рекурсивный поиск доступен только в режиме дерева",
|
||||
"collapseAllDisabled": "Недоступно в виде списка"
|
||||
},
|
||||
"statistics": {
|
||||
@@ -1264,4 +1267,4 @@
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -535,12 +535,15 @@
|
||||
"title": "Embedding 模型"
|
||||
},
|
||||
"sidebar": {
|
||||
"modelRoot": "模型根目录",
|
||||
"modelRoot": "根目录",
|
||||
"collapseAll": "折叠所有文件夹",
|
||||
"pinSidebar": "固定侧边栏",
|
||||
"unpinSidebar": "取消固定侧边栏",
|
||||
"switchToListView": "切换到列表视图",
|
||||
"switchToTreeView": "切换到树状视图",
|
||||
"recursiveOn": "搜索子文件夹",
|
||||
"recursiveOff": "仅搜索当前文件夹",
|
||||
"recursiveUnavailable": "仅在树形视图中可使用递归搜索",
|
||||
"collapseAllDisabled": "列表视图下不可用"
|
||||
},
|
||||
"statistics": {
|
||||
@@ -1270,4 +1273,4 @@
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -529,12 +529,15 @@
|
||||
"title": "Embedding 模型"
|
||||
},
|
||||
"sidebar": {
|
||||
"modelRoot": "模型根目錄",
|
||||
"modelRoot": "根目錄",
|
||||
"collapseAll": "全部摺疊資料夾",
|
||||
"pinSidebar": "固定側邊欄",
|
||||
"unpinSidebar": "取消固定側邊欄",
|
||||
"switchToListView": "切換至列表檢視",
|
||||
"switchToTreeView": "切換至樹狀檢視",
|
||||
"switchToTreeView": "切換到樹狀檢視",
|
||||
"recursiveOn": "搜尋子資料夾",
|
||||
"recursiveOff": "僅搜尋目前資料夾",
|
||||
"recursiveUnavailable": "遞迴搜尋僅能在樹狀檢視中使用",
|
||||
"collapseAllDisabled": "列表檢視下不可用"
|
||||
},
|
||||
"statistics": {
|
||||
@@ -1264,4 +1267,4 @@
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -74,8 +74,9 @@ class Config:
|
||||
"""Persist ComfyUI-derived folder paths to the multi-library settings."""
|
||||
try:
|
||||
ensure_settings_file(logger)
|
||||
from .services.settings_manager import settings as settings_service
|
||||
from .services.settings_manager import get_settings_manager
|
||||
|
||||
settings_service = get_settings_manager()
|
||||
libraries = settings_service.get_libraries()
|
||||
comfy_library = libraries.get("comfyui", {})
|
||||
default_library = libraries.get("default", {})
|
||||
@@ -442,8 +443,9 @@ class Config:
|
||||
"""Return the current library registry and active library name."""
|
||||
|
||||
try:
|
||||
from .services.settings_manager import settings as settings_service
|
||||
from .services.settings_manager import get_settings_manager
|
||||
|
||||
settings_service = get_settings_manager()
|
||||
libraries = settings_service.get_libraries()
|
||||
active_library = settings_service.get_active_library_name()
|
||||
return {
|
||||
|
||||
@@ -13,7 +13,7 @@ from .routes.misc_routes import MiscRoutes
|
||||
from .routes.preview_routes import PreviewRoutes
|
||||
from .routes.example_images_routes import ExampleImagesRoutes
|
||||
from .services.service_registry import ServiceRegistry
|
||||
from .services.settings_manager import settings
|
||||
from .services.settings_manager import get_settings_manager
|
||||
from .utils.example_images_migration import ExampleImagesMigration
|
||||
from .services.websocket_manager import ws_manager
|
||||
from .services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||
@@ -23,6 +23,25 @@ logger = logging.getLogger(__name__)
|
||||
# Check if we're in standalone mode
|
||||
STANDALONE_MODE = 'nodes' not in sys.modules
|
||||
|
||||
|
||||
class _SettingsProxy:
|
||||
def __init__(self):
|
||||
self._manager = None
|
||||
|
||||
def _resolve(self):
|
||||
if self._manager is None:
|
||||
self._manager = get_settings_manager()
|
||||
return self._manager
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
return self._resolve().get(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._resolve(), item)
|
||||
|
||||
|
||||
settings = _SettingsProxy()
|
||||
|
||||
class LoraManager:
|
||||
"""Main entry point for LoRA Manager plugin"""
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from ..services.model_lifecycle_service import ModelLifecycleService
|
||||
from ..services.preview_asset_service import PreviewAssetService
|
||||
from ..services.server_i18n import server_i18n as default_server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import settings as default_settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..services.tag_update_service import TagUpdateService
|
||||
from ..services.websocket_manager import ws_manager as default_ws_manager
|
||||
from ..services.use_cases import (
|
||||
@@ -56,14 +56,14 @@ class BaseModelRoutes(ABC):
|
||||
self,
|
||||
service=None,
|
||||
*,
|
||||
settings_service=default_settings,
|
||||
settings_service=None,
|
||||
ws_manager=default_ws_manager,
|
||||
server_i18n=default_server_i18n,
|
||||
metadata_provider_factory=get_default_metadata_provider,
|
||||
) -> None:
|
||||
self.service = None
|
||||
self.model_type = ""
|
||||
self._settings = settings_service
|
||||
self._settings = settings_service or get_settings_manager()
|
||||
self._ws_manager = ws_manager
|
||||
self._server_i18n = server_i18n
|
||||
self._metadata_provider_factory = metadata_provider_factory
|
||||
@@ -90,7 +90,7 @@ class BaseModelRoutes(ABC):
|
||||
self._metadata_sync_service = MetadataSyncService(
|
||||
metadata_manager=MetadataManager,
|
||||
preview_service=self._preview_service,
|
||||
settings=settings_service,
|
||||
settings=self._settings,
|
||||
default_metadata_provider_factory=metadata_provider_factory,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..services.recipes import (
|
||||
)
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from .handlers.recipe_handlers import (
|
||||
@@ -48,7 +48,7 @@ class BaseRecipeRoutes:
|
||||
self.recipe_scanner = None
|
||||
self.lora_scanner = None
|
||||
self.civitai_client = None
|
||||
self.settings = settings
|
||||
self.settings = get_settings_manager()
|
||||
self.server_i18n = server_i18n
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
|
||||
@@ -24,10 +24,17 @@ from ...services.metadata_service import (
|
||||
update_metadata_providers,
|
||||
)
|
||||
from ...services.service_registry import ServiceRegistry
|
||||
from ...services.settings_manager import settings as default_settings
|
||||
from ...services.settings_manager import get_settings_manager
|
||||
from ...services.websocket_manager import ws_manager
|
||||
from ...services.downloader import get_downloader
|
||||
from ...utils.constants import DEFAULT_NODE_COLOR, NODE_TYPES, SUPPORTED_MEDIA_EXTENSIONS
|
||||
from ...utils.constants import (
|
||||
CIVITAI_USER_MODEL_TYPES,
|
||||
DEFAULT_NODE_COLOR,
|
||||
NODE_TYPES,
|
||||
SUPPORTED_MEDIA_EXTENSIONS,
|
||||
VALID_LORA_TYPES,
|
||||
)
|
||||
from ...utils.civitai_utils import rewrite_preview_url
|
||||
from ...utils.example_images_paths import is_valid_example_images_root
|
||||
from ...utils.lora_metadata import extract_trained_words
|
||||
from ...utils.usage_stats import UsageStats
|
||||
@@ -80,7 +87,7 @@ class NodeRegistry:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = asyncio.Lock()
|
||||
self._nodes: Dict[int, dict] = {}
|
||||
self._nodes: Dict[str, dict] = {}
|
||||
self._registry_updated = asyncio.Event()
|
||||
|
||||
async def register_nodes(self, nodes: list[dict]) -> None:
|
||||
@@ -88,11 +95,16 @@ class NodeRegistry:
|
||||
self._nodes.clear()
|
||||
for node in nodes:
|
||||
node_id = node["node_id"]
|
||||
graph_id = str(node["graph_id"])
|
||||
unique_id = f"{graph_id}:{node_id}"
|
||||
node_type = node.get("type", "")
|
||||
type_id = NODE_TYPES.get(node_type, 0)
|
||||
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
|
||||
self._nodes[node_id] = {
|
||||
self._nodes[unique_id] = {
|
||||
"id": node_id,
|
||||
"graph_id": graph_id,
|
||||
"graph_name": node.get("graph_name"),
|
||||
"unique_id": unique_id,
|
||||
"bgcolor": bgcolor,
|
||||
"title": node.get("title"),
|
||||
"type": type_id,
|
||||
@@ -157,11 +169,11 @@ class SettingsHandler:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
settings_service=default_settings,
|
||||
settings_service=None,
|
||||
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
|
||||
downloader_factory: Callable[[], Awaitable[DownloaderProtocol]] = get_downloader,
|
||||
) -> None:
|
||||
self._settings = settings_service
|
||||
self._settings = settings_service or get_settings_manager()
|
||||
self._metadata_provider_updater = metadata_provider_updater
|
||||
self._downloader_factory = downloader_factory
|
||||
|
||||
@@ -330,16 +342,65 @@ class LoraCodeHandler:
|
||||
logger.error("Error broadcasting lora code: %s", exc)
|
||||
results.append({"node_id": "broadcast", "success": False, "error": str(exc)})
|
||||
else:
|
||||
for node_id in node_ids:
|
||||
for entry in node_ids:
|
||||
node_identifier = entry
|
||||
graph_identifier = None
|
||||
if isinstance(entry, dict):
|
||||
node_identifier = entry.get("node_id")
|
||||
graph_identifier = entry.get("graph_id")
|
||||
|
||||
if node_identifier is None:
|
||||
results.append(
|
||||
{
|
||||
"node_id": node_identifier,
|
||||
"graph_id": graph_identifier,
|
||||
"success": False,
|
||||
"error": "Missing node_id parameter",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
parsed_node_id = int(node_identifier)
|
||||
except (TypeError, ValueError):
|
||||
parsed_node_id = node_identifier
|
||||
|
||||
payload = {
|
||||
"id": parsed_node_id,
|
||||
"lora_code": lora_code,
|
||||
"mode": mode,
|
||||
}
|
||||
|
||||
if graph_identifier is not None:
|
||||
payload["graph_id"] = str(graph_identifier)
|
||||
|
||||
try:
|
||||
self._prompt_server.instance.send_sync(
|
||||
"lora_code_update",
|
||||
{"id": node_id, "lora_code": lora_code, "mode": mode},
|
||||
payload,
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"node_id": parsed_node_id,
|
||||
"graph_id": payload.get("graph_id"),
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
results.append({"node_id": node_id, "success": True})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error sending lora code to node %s: %s", node_id, exc)
|
||||
results.append({"node_id": node_id, "success": False, "error": str(exc)})
|
||||
logger.error(
|
||||
"Error sending lora code to node %s (graph %s): %s",
|
||||
parsed_node_id,
|
||||
graph_identifier,
|
||||
exc,
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"node_id": parsed_node_id,
|
||||
"graph_id": payload.get("graph_id"),
|
||||
"success": False,
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, "results": results})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
@@ -557,17 +618,118 @@ class ModelLibraryHandler:
|
||||
logger.error("Failed to get model versions status: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_civitai_user_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
username = request.query.get("username")
|
||||
if not username:
|
||||
return web.json_response({"success": False, "error": "Missing required parameter: username"}, status=400)
|
||||
|
||||
metadata_provider = await self._metadata_provider_factory()
|
||||
if not metadata_provider:
|
||||
return web.json_response({"success": False, "error": "Metadata provider not available"}, status=503)
|
||||
|
||||
try:
|
||||
models = await metadata_provider.get_user_models(username)
|
||||
except NotImplementedError:
|
||||
return web.json_response({"success": False, "error": "Metadata provider does not support user model queries"}, status=501)
|
||||
|
||||
if models is None:
|
||||
return web.json_response({"success": False, "error": "Failed to fetch user models"}, status=502)
|
||||
|
||||
if not isinstance(models, list):
|
||||
models = []
|
||||
|
||||
lora_scanner = await self._service_registry.get_lora_scanner()
|
||||
checkpoint_scanner = await self._service_registry.get_checkpoint_scanner()
|
||||
embedding_scanner = await self._service_registry.get_embedding_scanner()
|
||||
|
||||
normalized_allowed_types = {model_type.lower() for model_type in CIVITAI_USER_MODEL_TYPES}
|
||||
lora_type_aliases = {model_type.lower() for model_type in VALID_LORA_TYPES}
|
||||
|
||||
type_scanner_map: Dict[str, object | None] = {
|
||||
**{alias: lora_scanner for alias in lora_type_aliases},
|
||||
"checkpoint": checkpoint_scanner,
|
||||
"textualinversion": embedding_scanner,
|
||||
}
|
||||
|
||||
versions: list[dict] = []
|
||||
for model in models:
|
||||
if not isinstance(model, dict):
|
||||
continue
|
||||
|
||||
model_type = str(model.get("type", "")).lower()
|
||||
if model_type not in normalized_allowed_types:
|
||||
continue
|
||||
|
||||
scanner = type_scanner_map.get(model_type)
|
||||
if scanner is None:
|
||||
return web.json_response({"success": False, "error": f'Scanner for type "{model_type}" is not available'}, status=503)
|
||||
|
||||
tags_value = model.get("tags")
|
||||
tags = tags_value if isinstance(tags_value, list) else []
|
||||
model_id = model.get("id")
|
||||
try:
|
||||
model_id_int = int(model_id)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
model_name = model.get("name", "")
|
||||
|
||||
versions_data = model.get("modelVersions")
|
||||
if not isinstance(versions_data, list):
|
||||
continue
|
||||
|
||||
for version in versions_data:
|
||||
if not isinstance(version, dict):
|
||||
continue
|
||||
|
||||
version_id = version.get("id")
|
||||
try:
|
||||
version_id_int = int(version_id)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
|
||||
images = version.get("images") or []
|
||||
thumbnail_url = None
|
||||
if images and isinstance(images, list):
|
||||
first_image = images[0]
|
||||
if isinstance(first_image, dict):
|
||||
raw_url = first_image.get("url")
|
||||
media_type = first_image.get("type")
|
||||
rewritten_url, _ = rewrite_preview_url(raw_url, media_type)
|
||||
thumbnail_url = rewritten_url
|
||||
|
||||
in_library = await scanner.check_model_version_exists(version_id_int)
|
||||
|
||||
versions.append(
|
||||
{
|
||||
"modelId": model_id_int,
|
||||
"versionId": version_id_int,
|
||||
"modelName": model_name,
|
||||
"versionName": version.get("name", ""),
|
||||
"type": model.get("type"),
|
||||
"tags": tags,
|
||||
"baseModel": version.get("baseModel"),
|
||||
"thumbnailUrl": thumbnail_url,
|
||||
"inLibrary": in_library,
|
||||
}
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, "username": username, "versions": versions})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to get Civitai user models: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class MetadataArchiveHandler:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metadata_archive_manager_factory: Callable[[], Awaitable[MetadataArchiveManagerProtocol]] = get_metadata_archive_manager,
|
||||
settings_service=default_settings,
|
||||
settings_service=None,
|
||||
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
|
||||
) -> None:
|
||||
self._metadata_archive_manager_factory = metadata_archive_manager_factory
|
||||
self._settings = settings_service
|
||||
self._settings = settings_service or get_settings_manager()
|
||||
self._metadata_provider_updater = metadata_provider_updater
|
||||
|
||||
async def download_metadata_archive(self, request: web.Request) -> web.Response:
|
||||
@@ -679,10 +841,21 @@ class NodeRegistryHandler:
|
||||
node_id = node.get("node_id")
|
||||
if node_id is None:
|
||||
return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400)
|
||||
graph_id = node.get("graph_id")
|
||||
if graph_id is None:
|
||||
return web.json_response({"success": False, "error": f"Node {index} missing graph_id parameter"}, status=400)
|
||||
graph_name = node.get("graph_name")
|
||||
try:
|
||||
node["node_id"] = int(node_id)
|
||||
except (TypeError, ValueError):
|
||||
return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400)
|
||||
node["graph_id"] = str(graph_id)
|
||||
if graph_name is None:
|
||||
node["graph_name"] = None
|
||||
elif isinstance(graph_name, str):
|
||||
node["graph_name"] = graph_name
|
||||
else:
|
||||
node["graph_name"] = str(graph_name)
|
||||
|
||||
await self._node_registry.register_nodes(nodes)
|
||||
return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"})
|
||||
@@ -779,6 +952,7 @@ class MiscHandlerSet:
|
||||
"register_nodes": self.node_registry.register_nodes,
|
||||
"get_registry": self.node_registry.get_registry,
|
||||
"check_model_exists": self.model_library.check_model_exists,
|
||||
"get_civitai_user_models": self.model_library.get_civitai_user_models,
|
||||
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
||||
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
||||
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,
|
||||
|
||||
@@ -30,6 +30,7 @@ from ...services.use_cases import (
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...utils.file_utils import calculate_sha256
|
||||
from ...utils.metadata_manager import MetadataManager
|
||||
|
||||
|
||||
class ModelPageView:
|
||||
@@ -244,6 +245,8 @@ class ModelManagementHandler:
|
||||
if not model_data.get("sha256"):
|
||||
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
|
||||
|
||||
await MetadataManager.hydrate_model_data(model_data)
|
||||
|
||||
success, error = await self._metadata_sync.fetch_and_update_model(
|
||||
sha256=model_data["sha256"],
|
||||
file_path=file_path,
|
||||
@@ -825,18 +828,30 @@ class ModelCivitaiHandler:
|
||||
status=400,
|
||||
)
|
||||
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
version_index = cache.version_index
|
||||
|
||||
for version in versions:
|
||||
model_file = self._find_model_file(version.get("files", [])) if isinstance(version.get("files"), Iterable) else None
|
||||
if model_file:
|
||||
hashes = model_file.get("hashes", {}) if isinstance(model_file, Mapping) else {}
|
||||
sha256 = hashes.get("SHA256") if isinstance(hashes, Mapping) else None
|
||||
if sha256:
|
||||
version["existsLocally"] = self._service.has_hash(sha256)
|
||||
if version["existsLocally"]:
|
||||
version["localPath"] = self._service.get_path_by_hash(sha256)
|
||||
version["modelSizeKB"] = model_file.get("sizeKB") if isinstance(model_file, Mapping) else None
|
||||
version_id = None
|
||||
version_id_raw = version.get("id")
|
||||
if version_id_raw is not None:
|
||||
try:
|
||||
version_id = int(str(version_id_raw))
|
||||
except (TypeError, ValueError):
|
||||
version_id = None
|
||||
|
||||
cache_entry = version_index.get(version_id) if (version_id is not None and version_index) else None
|
||||
version["existsLocally"] = cache_entry is not None
|
||||
if cache_entry and isinstance(cache_entry, Mapping):
|
||||
local_path = cache_entry.get("file_path")
|
||||
if local_path:
|
||||
version["localPath"] = local_path
|
||||
else:
|
||||
version["existsLocally"] = False
|
||||
version.pop("localPath", None)
|
||||
|
||||
model_file = self._find_model_file(version.get("files", [])) if isinstance(version.get("files"), Iterable) else None
|
||||
if model_file and isinstance(model_file, Mapping):
|
||||
version["modelSizeKB"] = model_file.get("sizeKB")
|
||||
return web.json_response(versions)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error fetching %s model versions: %s", self._service.model_type, exc)
|
||||
|
||||
@@ -229,11 +229,27 @@ class LoraRoutes(BaseModelRoutes):
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Send update to all connected trigger word toggle nodes
|
||||
for node_id in node_ids:
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": node_id,
|
||||
for entry in node_ids:
|
||||
node_identifier = entry
|
||||
graph_identifier = None
|
||||
if isinstance(entry, dict):
|
||||
node_identifier = entry.get("node_id")
|
||||
graph_identifier = entry.get("graph_id")
|
||||
|
||||
try:
|
||||
parsed_node_id = int(node_identifier)
|
||||
except (TypeError, ValueError):
|
||||
parsed_node_id = node_identifier
|
||||
|
||||
payload = {
|
||||
"id": parsed_node_id,
|
||||
"message": trigger_words_text
|
||||
})
|
||||
}
|
||||
|
||||
if graph_identifier is not None:
|
||||
payload["graph_id"] = str(graph_identifier)
|
||||
|
||||
PromptServer.instance.send_sync("trigger_word_update", payload)
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
|
||||
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
||||
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
||||
RouteDefinition("GET", "/api/lm/civitai/user-models", "get_civitai_user_models"),
|
||||
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
|
||||
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
|
||||
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
||||
|
||||
@@ -14,7 +14,7 @@ from ..services.metadata_service import (
|
||||
get_metadata_provider,
|
||||
update_metadata_providers,
|
||||
)
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..services.downloader import get_downloader
|
||||
from ..utils.usage_stats import UsageStats
|
||||
from .handlers.misc_handlers import (
|
||||
@@ -47,7 +47,7 @@ class MiscRoutes:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
settings_service=settings,
|
||||
settings_service=None,
|
||||
usage_stats_factory: Callable[[], UsageStats] = UsageStats,
|
||||
prompt_server: type[PromptServer] = PromptServer,
|
||||
service_registry_adapter=build_service_registry_adapter(),
|
||||
@@ -60,7 +60,7 @@ class MiscRoutes:
|
||||
node_registry: NodeRegistry | None = None,
|
||||
standalone_mode_flag: bool = standalone_mode,
|
||||
) -> None:
|
||||
self._settings = settings_service
|
||||
self._settings = settings_service or get_settings_manager()
|
||||
self._usage_stats_factory = usage_stats_factory
|
||||
self._prompt_server = prompt_server
|
||||
self._service_registry_adapter = service_registry_adapter
|
||||
|
||||
@@ -8,13 +8,32 @@ from collections import defaultdict, Counter
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.usage_stats import UsageStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _SettingsProxy:
|
||||
def __init__(self):
|
||||
self._manager = None
|
||||
|
||||
def _resolve(self):
|
||||
if self._manager is None:
|
||||
self._manager = get_settings_manager()
|
||||
return self._manager
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
return self._resolve().get(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._resolve(), item)
|
||||
|
||||
|
||||
settings = _SettingsProxy()
|
||||
|
||||
class StatsRoutes:
|
||||
"""Route handlers for Statistics page and API endpoints"""
|
||||
|
||||
@@ -66,7 +85,9 @@ class StatsRoutes:
|
||||
is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing
|
||||
|
||||
# 获取用户语言设置
|
||||
user_language = settings.get('language', 'en')
|
||||
settings_object = settings
|
||||
user_language = settings_object.get('language', 'en')
|
||||
settings_manager = settings_object if not isinstance(settings_object, _SettingsProxy) else settings_object._resolve()
|
||||
|
||||
# 设置服务端i18n语言
|
||||
server_i18n.set_locale(user_language)
|
||||
@@ -79,7 +100,7 @@ class StatsRoutes:
|
||||
template = self.template_env.get_template('statistics.html')
|
||||
rendered = template.render(
|
||||
is_initializing=is_initializing,
|
||||
settings=settings,
|
||||
settings=settings_manager,
|
||||
request=request,
|
||||
t=server_i18n.get_translation,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
||||
from .settings_manager import settings as default_settings
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,7 +38,7 @@ class BaseModelService(ABC):
|
||||
self.model_type = model_type
|
||||
self.scanner = scanner
|
||||
self.metadata_class = metadata_class
|
||||
self.settings = settings_provider or default_settings
|
||||
self.settings = settings_provider or get_settings_manager()
|
||||
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
||||
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
||||
self.search_strategy = search_strategy or SearchStrategy()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||
from .downloader import get_downloader
|
||||
@@ -157,141 +157,160 @@ class CivitaiClient:
|
||||
return None
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata
|
||||
|
||||
Args:
|
||||
model_id: The Civitai model ID (optional if version_id is provided)
|
||||
version_id: Optional specific version ID to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The model version data with additional fields or None if not found
|
||||
"""
|
||||
"""Get specific model version with additional metadata."""
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
|
||||
# Case 1: Only version_id is provided
|
||||
|
||||
if model_id is None and version_id is not None:
|
||||
# First get the version info to extract model_id
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/{version_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
return None
|
||||
|
||||
model_id = version.get('modelId')
|
||||
if not model_id:
|
||||
logger.error(f"No modelId found in version {version_id}")
|
||||
return None
|
||||
|
||||
# Now get the model data for additional metadata
|
||||
success, model_data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
# Enrich version with model data
|
||||
version['model']['description'] = model_data.get("description")
|
||||
version['model']['tags'] = model_data.get("tags", [])
|
||||
version['creator'] = model_data.get("creator")
|
||||
return await self._get_version_by_id_only(downloader, version_id)
|
||||
|
||||
self._remove_comfy_metadata(version)
|
||||
return version
|
||||
|
||||
# Case 2: model_id is provided (with or without version_id)
|
||||
elif model_id is not None:
|
||||
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
||||
success, data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
return None
|
||||
if model_id is not None:
|
||||
return await self._get_version_with_model_id(downloader, model_id, version_id)
|
||||
|
||||
model_versions = data.get('modelVersions', [])
|
||||
if not model_versions:
|
||||
logger.warning(f"No model versions found for model {model_id}")
|
||||
return None
|
||||
logger.error("Either model_id or version_id must be provided")
|
||||
return None
|
||||
|
||||
# Step 2: Determine the target version entry to use
|
||||
target_version = None
|
||||
if version_id is not None:
|
||||
target_version = next(
|
||||
(item for item in model_versions if item.get('id') == version_id),
|
||||
None
|
||||
)
|
||||
if target_version is None:
|
||||
logger.warning(
|
||||
f"Version {version_id} not found for model {model_id}, defaulting to first version"
|
||||
)
|
||||
if target_version is None:
|
||||
target_version = model_versions[0]
|
||||
|
||||
target_version_id = target_version.get('id')
|
||||
|
||||
# Step 3: Get detailed version info using the SHA256 hash
|
||||
model_hash = None
|
||||
for file_info in target_version.get('files', []):
|
||||
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
||||
model_hash = file_info.get('hashes', {}).get('SHA256')
|
||||
if model_hash:
|
||||
break
|
||||
|
||||
version = None
|
||||
if model_hash:
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if not success:
|
||||
logger.warning(
|
||||
f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}"
|
||||
)
|
||||
version = None
|
||||
else:
|
||||
logger.warning(
|
||||
f"No primary model hash found for model {model_id} version {target_version_id}"
|
||||
)
|
||||
|
||||
if version is None:
|
||||
version = copy.deepcopy(target_version)
|
||||
version.pop('index', None)
|
||||
version['modelId'] = model_id
|
||||
version['model'] = {
|
||||
'name': data.get('name'),
|
||||
'type': data.get('type'),
|
||||
'nsfw': data.get('nsfw'),
|
||||
'poi': data.get('poi')
|
||||
}
|
||||
|
||||
# Step 4: Enrich version_info with model data
|
||||
# Add description and tags from model data
|
||||
model_info = version.get('model')
|
||||
if not isinstance(model_info, dict):
|
||||
model_info = {}
|
||||
version['model'] = model_info
|
||||
model_info['description'] = data.get("description")
|
||||
model_info['tags'] = data.get("tags", [])
|
||||
|
||||
# Add creator from model data
|
||||
version['creator'] = data.get("creator")
|
||||
|
||||
self._remove_comfy_metadata(version)
|
||||
return version
|
||||
|
||||
# Case 3: Neither model_id nor version_id provided
|
||||
else:
|
||||
logger.error("Either model_id or version_id must be provided")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model version: {e}")
|
||||
return None
|
||||
|
||||
async def _get_version_by_id_only(self, downloader, version_id: int) -> Optional[Dict]:
|
||||
version = await self._fetch_version_by_id(downloader, version_id)
|
||||
if version is None:
|
||||
return None
|
||||
|
||||
model_id = version.get('modelId')
|
||||
if not model_id:
|
||||
logger.error(f"No modelId found in version {version_id}")
|
||||
return None
|
||||
|
||||
model_data = await self._fetch_model_data(downloader, model_id)
|
||||
if model_data:
|
||||
self._enrich_version_with_model_data(version, model_data)
|
||||
|
||||
self._remove_comfy_metadata(version)
|
||||
return version
|
||||
|
||||
async def _get_version_with_model_id(self, downloader, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
||||
model_data = await self._fetch_model_data(downloader, model_id)
|
||||
if not model_data:
|
||||
return None
|
||||
|
||||
target_version = self._select_target_version(model_data, model_id, version_id)
|
||||
if target_version is None:
|
||||
return None
|
||||
|
||||
target_version_id = target_version.get('id')
|
||||
version = await self._fetch_version_by_id(downloader, target_version_id) if target_version_id else None
|
||||
|
||||
if version is None:
|
||||
model_hash = self._extract_primary_model_hash(target_version)
|
||||
if model_hash:
|
||||
version = await self._fetch_version_by_hash(downloader, model_hash)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No primary model hash found for model {model_id} version {target_version_id}"
|
||||
)
|
||||
|
||||
if version is None:
|
||||
version = self._build_version_from_model_data(target_version, model_id, model_data)
|
||||
|
||||
self._enrich_version_with_model_data(version, model_data)
|
||||
self._remove_comfy_metadata(version)
|
||||
return version
|
||||
|
||||
async def _fetch_model_data(self, downloader, model_id: int) -> Optional[Dict]:
|
||||
success, data = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/models/{model_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
return data
|
||||
logger.warning(f"Failed to fetch model data for model {model_id}")
|
||||
return None
|
||||
|
||||
async def _fetch_version_by_id(self, downloader, version_id: Optional[int]) -> Optional[Dict]:
|
||||
if version_id is None:
|
||||
return None
|
||||
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/{version_id}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
|
||||
logger.warning(f"Failed to fetch version by id {version_id}")
|
||||
return None
|
||||
|
||||
async def _fetch_version_by_hash(self, downloader, model_hash: Optional[str]) -> Optional[Dict]:
|
||||
if not model_hash:
|
||||
return None
|
||||
|
||||
success, version = await downloader.make_request(
|
||||
'GET',
|
||||
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||
use_auth=True
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
|
||||
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
||||
return None
|
||||
|
||||
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
||||
model_versions = model_data.get('modelVersions', [])
|
||||
if not model_versions:
|
||||
logger.warning(f"No model versions found for model {model_id}")
|
||||
return None
|
||||
|
||||
if version_id is not None:
|
||||
target_version = next(
|
||||
(item for item in model_versions if item.get('id') == version_id),
|
||||
None
|
||||
)
|
||||
if target_version is None:
|
||||
logger.warning(
|
||||
f"Version {version_id} not found for model {model_id}, defaulting to first version"
|
||||
)
|
||||
return model_versions[0]
|
||||
return target_version
|
||||
|
||||
return model_versions[0]
|
||||
|
||||
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
|
||||
for file_info in version_entry.get('files', []):
|
||||
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
||||
hashes = file_info.get('hashes', {})
|
||||
model_hash = hashes.get('SHA256')
|
||||
if model_hash:
|
||||
return model_hash
|
||||
return None
|
||||
|
||||
def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict:
|
||||
version = copy.deepcopy(version_entry)
|
||||
version.pop('index', None)
|
||||
version['modelId'] = model_id
|
||||
version['model'] = {
|
||||
'name': model_data.get('name'),
|
||||
'type': model_data.get('type'),
|
||||
'nsfw': model_data.get('nsfw'),
|
||||
'poi': model_data.get('poi')
|
||||
}
|
||||
return version
|
||||
|
||||
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
|
||||
model_info = version.get('model')
|
||||
if not isinstance(model_info, dict):
|
||||
model_info = {}
|
||||
version['model'] = model_info
|
||||
|
||||
model_info['description'] = model_data.get("description")
|
||||
model_info['tags'] = model_data.get("tags", [])
|
||||
version['creator'] = model_data.get("creator")
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata from Civitai
|
||||
|
||||
@@ -335,7 +354,7 @@ class CivitaiClient:
|
||||
|
||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||
"""Fetch image information from Civitai API
|
||||
|
||||
|
||||
Args:
|
||||
image_id: The Civitai image ID
|
||||
|
||||
@@ -366,3 +385,37 @@ class CivitaiClient:
|
||||
error_msg = f"Error fetching image info: {e}"
|
||||
logger.error(error_msg)
|
||||
return None
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
"""Fetch all models for a specific Civitai user."""
|
||||
if not username:
|
||||
return None
|
||||
|
||||
try:
|
||||
downloader = await get_downloader()
|
||||
url = f"{self.base_url}/models?username={username}"
|
||||
success, result = await downloader.make_request(
|
||||
'GET',
|
||||
url,
|
||||
use_auth=True
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error("Failed to fetch models for %s: %s", username, result)
|
||||
return None
|
||||
|
||||
items = result.get("items") if isinstance(result, dict) else None
|
||||
if not isinstance(items, list):
|
||||
return []
|
||||
|
||||
for model in items:
|
||||
versions = model.get("modelVersions")
|
||||
if not isinstance(versions, list):
|
||||
continue
|
||||
for version in versions:
|
||||
self._remove_comfy_metadata(version)
|
||||
|
||||
return items
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error fetching models for %s: %s", username, exc)
|
||||
return None
|
||||
|
||||
@@ -4,12 +4,14 @@ import asyncio
|
||||
from collections import OrderedDict
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
from urllib.parse import urlparse
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
||||
from ..utils.civitai_utils import rewrite_preview_url
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import settings
|
||||
from .settings_manager import get_settings_manager
|
||||
from .metadata_service import get_default_metadata_provider
|
||||
from .downloader import get_downloader
|
||||
|
||||
@@ -241,23 +243,24 @@ class DownloadManager:
|
||||
|
||||
# Handle use_default_paths
|
||||
if use_default_paths:
|
||||
settings_manager = get_settings_manager()
|
||||
# Set save_dir based on model type
|
||||
if model_type == 'checkpoint':
|
||||
default_path = settings.get('default_checkpoint_root')
|
||||
default_path = settings_manager.get('default_checkpoint_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
|
||||
save_dir = default_path
|
||||
elif model_type == 'lora':
|
||||
default_path = settings.get('default_lora_root')
|
||||
default_path = settings_manager.get('default_lora_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default lora root path not set in settings'}
|
||||
save_dir = default_path
|
||||
elif model_type == 'embedding':
|
||||
default_path = settings.get('default_embedding_root')
|
||||
default_path = settings_manager.get('default_embedding_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default embedding root path not set in settings'}
|
||||
save_dir = default_path
|
||||
|
||||
|
||||
# Calculate relative path using template
|
||||
relative_path = self._calculate_relative_path(version_info, model_type)
|
||||
|
||||
@@ -360,7 +363,8 @@ class DownloadManager:
|
||||
Relative path string
|
||||
"""
|
||||
# Get path template from settings for specific model type
|
||||
path_template = settings.get_download_path_template(model_type)
|
||||
settings_manager = get_settings_manager()
|
||||
path_template = settings_manager.get_download_path_template(model_type)
|
||||
|
||||
# If template is empty, return empty path (flat structure)
|
||||
if not path_template:
|
||||
@@ -377,7 +381,7 @@ class DownloadManager:
|
||||
author = 'Anonymous'
|
||||
|
||||
# Apply mapping if available
|
||||
base_model_mappings = settings.get('base_model_path_mappings', {})
|
||||
base_model_mappings = settings_manager.get('base_model_path_mappings', {})
|
||||
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
||||
|
||||
# Get model tags
|
||||
@@ -448,70 +452,103 @@ class DownloadManager:
|
||||
# Download preview image if available
|
||||
images = version_info.get('images', [])
|
||||
if images:
|
||||
# Report preview download progress
|
||||
if progress_callback:
|
||||
await progress_callback(1) # 1% progress for starting preview download
|
||||
|
||||
# Check if it's a video or an image
|
||||
is_video = images[0].get('type') == 'video'
|
||||
|
||||
if (is_video):
|
||||
# For videos, use .mp4 extension
|
||||
preview_ext = '.mp4'
|
||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||
|
||||
# Download video directly using downloader
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.download_file(
|
||||
images[0]['url'],
|
||||
preview_path,
|
||||
use_auth=False # Preview images typically don't need auth
|
||||
)
|
||||
if success:
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
||||
else:
|
||||
# For images, use WebP format for better performance
|
||||
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
# Download the original image to temp path using downloader
|
||||
downloader = await get_downloader()
|
||||
success, content, headers = await downloader.download_to_memory(
|
||||
images[0]['url'],
|
||||
use_auth=False
|
||||
)
|
||||
if success:
|
||||
# Save to temp file
|
||||
with open(temp_path, 'wb') as f:
|
||||
f.write(content)
|
||||
# Optimize and convert to WebP
|
||||
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
||||
|
||||
# Use ExifUtils to optimize and convert the image
|
||||
optimized_data, _ = ExifUtils.optimize_image(
|
||||
image_data=temp_path,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format='webp',
|
||||
quality=85,
|
||||
preserve_metadata=False
|
||||
)
|
||||
|
||||
# Save the optimized image
|
||||
with open(preview_path, 'wb') as f:
|
||||
f.write(optimized_data)
|
||||
|
||||
# Update metadata
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
||||
|
||||
# Remove temporary file
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete temp file: {e}")
|
||||
first_image = images[0] if isinstance(images[0], dict) else None
|
||||
preview_url = first_image.get('url') if first_image else None
|
||||
media_type = (first_image.get('type') or '').lower() if first_image else ''
|
||||
nsfw_level = first_image.get('nsfwLevel', 0) if first_image else 0
|
||||
|
||||
def _extension_from_url(url: str, fallback: str) -> str:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return fallback
|
||||
ext = os.path.splitext(parsed.path)[1]
|
||||
return ext or fallback
|
||||
|
||||
preview_downloaded = False
|
||||
preview_path = None
|
||||
|
||||
if preview_url:
|
||||
downloader = await get_downloader()
|
||||
|
||||
if media_type == 'video':
|
||||
preview_ext = _extension_from_url(preview_url, '.mp4')
|
||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='video')
|
||||
attempt_urls: List[str] = []
|
||||
if rewritten:
|
||||
attempt_urls.append(rewritten_url)
|
||||
attempt_urls.append(preview_url)
|
||||
|
||||
seen_attempts = set()
|
||||
for attempt in attempt_urls:
|
||||
if not attempt or attempt in seen_attempts:
|
||||
continue
|
||||
seen_attempts.add(attempt)
|
||||
success, _ = await downloader.download_file(
|
||||
attempt,
|
||||
preview_path,
|
||||
use_auth=False
|
||||
)
|
||||
if success:
|
||||
preview_downloaded = True
|
||||
break
|
||||
else:
|
||||
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='image')
|
||||
if rewritten:
|
||||
preview_ext = _extension_from_url(preview_url, '.png')
|
||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||
success, _ = await downloader.download_file(
|
||||
rewritten_url,
|
||||
preview_path,
|
||||
use_auth=False
|
||||
)
|
||||
if success:
|
||||
preview_downloaded = True
|
||||
|
||||
if not preview_downloaded:
|
||||
temp_path: str | None = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
success, content, _ = await downloader.download_to_memory(
|
||||
preview_url,
|
||||
use_auth=False
|
||||
)
|
||||
if success:
|
||||
with open(temp_path, 'wb') as temp_file_handle:
|
||||
temp_file_handle.write(content)
|
||||
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
||||
|
||||
optimized_data, _ = ExifUtils.optimize_image(
|
||||
image_data=temp_path,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format='webp',
|
||||
quality=85,
|
||||
preserve_metadata=False
|
||||
)
|
||||
|
||||
with open(preview_path, 'wb') as preview_file:
|
||||
preview_file.write(optimized_data)
|
||||
|
||||
preview_downloaded = True
|
||||
finally:
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete temp file: {e}")
|
||||
|
||||
if preview_downloaded and preview_path:
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = nsfw_level
|
||||
if download_id and download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['preview_path'] = preview_path
|
||||
|
||||
# Report preview download completion
|
||||
if progress_callback:
|
||||
await progress_callback(3) # 3% progress after preview download
|
||||
|
||||
@@ -675,7 +712,15 @@ class DownloadManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting metadata file: {e}")
|
||||
|
||||
# Delete preview file if exists (.webp or .mp4)
|
||||
preview_path_value = download_info.get('preview_path')
|
||||
if preview_path_value and os.path.exists(preview_path_value):
|
||||
try:
|
||||
os.unlink(preview_path_value)
|
||||
logger.debug(f"Deleted preview file: {preview_path_value}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting preview file: {e}")
|
||||
|
||||
# Delete preview file if exists (.webp or .mp4) for legacy paths
|
||||
for preview_ext in ['.webp', '.mp4']:
|
||||
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
||||
if os.path.exists(preview_path):
|
||||
@@ -708,4 +753,4 @@ class DownloadManager:
|
||||
}
|
||||
for task_id, info in self._active_downloads.items()
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import asyncio
|
||||
import aiohttp
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Tuple, Callable, Union
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -94,12 +94,13 @@ class Downloader:
|
||||
|
||||
# Check for app-level proxy settings
|
||||
proxy_url = None
|
||||
if settings.get('proxy_enabled', False):
|
||||
proxy_host = settings.get('proxy_host', '').strip()
|
||||
proxy_port = settings.get('proxy_port', '').strip()
|
||||
proxy_type = settings.get('proxy_type', 'http').lower()
|
||||
proxy_username = settings.get('proxy_username', '').strip()
|
||||
proxy_password = settings.get('proxy_password', '').strip()
|
||||
settings_manager = get_settings_manager()
|
||||
if settings_manager.get('proxy_enabled', False):
|
||||
proxy_host = settings_manager.get('proxy_host', '').strip()
|
||||
proxy_port = settings_manager.get('proxy_port', '').strip()
|
||||
proxy_type = settings_manager.get('proxy_type', 'http').lower()
|
||||
proxy_username = settings_manager.get('proxy_username', '').strip()
|
||||
proxy_password = settings_manager.get('proxy_password', '').strip()
|
||||
|
||||
if proxy_host and proxy_port:
|
||||
# Build proxy URL
|
||||
@@ -146,7 +147,8 @@ class Downloader:
|
||||
|
||||
if use_auth:
|
||||
# Add CivitAI API key if available
|
||||
api_key = settings.get('civitai_api_key')
|
||||
settings_manager = get_settings_manager()
|
||||
api_key = settings_manager.get('civitai_api_key')
|
||||
if api_key:
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
headers['Content-Type'] = 'application/json'
|
||||
|
||||
@@ -11,7 +11,7 @@ from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import settings
|
||||
from .settings_manager import get_settings_manager
|
||||
from ..utils.example_images_paths import iter_library_roots
|
||||
|
||||
|
||||
@@ -62,7 +62,8 @@ class ExampleImagesCleanupService:
|
||||
async def cleanup_example_image_folders(self) -> Dict[str, object]:
|
||||
"""Clean empty or orphaned example image folders by moving them under a deleted bucket."""
|
||||
|
||||
example_images_path = settings.get("example_images_path")
|
||||
settings_manager = get_settings_manager()
|
||||
example_images_path = settings_manager.get("example_images_path")
|
||||
if not example_images_path:
|
||||
logger.debug("Cleanup skipped: example images path not configured")
|
||||
return {
|
||||
|
||||
@@ -6,7 +6,7 @@ from .model_metadata_provider import (
|
||||
CivitaiModelMetadataProvider,
|
||||
FallbackMetadataProvider
|
||||
)
|
||||
from .settings_manager import settings
|
||||
from .settings_manager import get_settings_manager
|
||||
from .metadata_archive_manager import MetadataArchiveManager
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
@@ -21,7 +21,8 @@ async def initialize_metadata_providers():
|
||||
provider_manager.default_provider = None
|
||||
|
||||
# Get settings
|
||||
enable_archive_db = settings.get('enable_metadata_archive_db', False)
|
||||
settings_manager = get_settings_manager()
|
||||
enable_archive_db = settings_manager.get('enable_metadata_archive_db', False)
|
||||
|
||||
providers = []
|
||||
|
||||
@@ -87,7 +88,8 @@ async def update_metadata_providers():
|
||||
"""Update metadata providers based on current settings"""
|
||||
try:
|
||||
# Get current settings
|
||||
enable_archive_db = settings.get('enable_metadata_archive_db', False)
|
||||
settings_manager = get_settings_manager()
|
||||
enable_archive_db = settings_manager.get('enable_metadata_archive_db', False)
|
||||
|
||||
# Reinitialize all providers with new settings
|
||||
provider_manager = await initialize_metadata_providers()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from typing import List, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
|
||||
@@ -17,10 +17,12 @@ SUPPORTED_SORT_MODES = [
|
||||
|
||||
@dataclass
|
||||
class ModelCache:
|
||||
"""Cache structure for model data with extensible sorting"""
|
||||
"""Cache structure for model data with extensible sorting."""
|
||||
|
||||
raw_data: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
version_index: Dict[int, Dict] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
# Cache for last sort: (sort_key, order) -> sorted list
|
||||
@@ -28,6 +30,58 @@ class ModelCache:
|
||||
self._last_sorted_data: List[Dict] = []
|
||||
# Default sort on init
|
||||
asyncio.create_task(self.resort())
|
||||
self.rebuild_version_index()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_version_id(value: Any) -> Optional[int]:
|
||||
"""Normalize a potential version identifier into an integer."""
|
||||
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def rebuild_version_index(self) -> None:
|
||||
"""Rebuild the version index from the current raw data."""
|
||||
|
||||
self.version_index = {}
|
||||
for item in self.raw_data:
|
||||
self.add_to_version_index(item)
|
||||
|
||||
def add_to_version_index(self, item: Dict) -> None:
|
||||
"""Register a cache item in the version index if possible."""
|
||||
|
||||
civitai_data = item.get('civitai') if isinstance(item, dict) else None
|
||||
if not isinstance(civitai_data, dict):
|
||||
return
|
||||
|
||||
version_id = self._normalize_version_id(civitai_data.get('id'))
|
||||
if version_id is None:
|
||||
return
|
||||
|
||||
self.version_index[version_id] = item
|
||||
|
||||
def remove_from_version_index(self, item: Dict) -> None:
|
||||
"""Remove a cache item from the version index if present."""
|
||||
|
||||
civitai_data = item.get('civitai') if isinstance(item, dict) else None
|
||||
if not isinstance(civitai_data, dict):
|
||||
return
|
||||
|
||||
version_id = self._normalize_version_id(civitai_data.get('id'))
|
||||
if version_id is None:
|
||||
return
|
||||
|
||||
existing = self.version_index.get(version_id)
|
||||
if existing is item or (
|
||||
isinstance(existing, dict)
|
||||
and existing.get('file_path') == item.get('file_path')
|
||||
):
|
||||
self.version_index.pop(version_id, None)
|
||||
|
||||
async def resort(self):
|
||||
"""Resort cached data according to last sort mode if set"""
|
||||
@@ -41,6 +95,7 @@ class ModelCache:
|
||||
|
||||
all_folders = set(l['folder'] for l in self.raw_data)
|
||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
self.rebuild_version_index()
|
||||
|
||||
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
||||
"""Sort data by sort_key and order"""
|
||||
|
||||
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
||||
|
||||
from ..utils.utils import calculate_relative_path_for_model, remove_empty_dirs
|
||||
from ..utils.constants import AUTO_ORGANIZE_BATCH_SIZE
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -114,7 +114,8 @@ class ModelFileService:
|
||||
raise ValueError('No model roots configured')
|
||||
|
||||
# Check if flat structure is configured for this model type
|
||||
path_template = settings.get_download_path_template(self.model_type)
|
||||
settings_manager = get_settings_manager()
|
||||
path_template = settings_manager.get_download_path_template(self.model_type)
|
||||
result.is_flat_structure = not path_template
|
||||
|
||||
# Initialize tracking
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
from typing import Optional, Dict, Tuple, Any, List
|
||||
from .downloader import get_downloader
|
||||
|
||||
try:
|
||||
@@ -61,6 +61,11 @@ class ModelMetadataProvider(ABC):
|
||||
"""Fetch model version metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
"""Fetch models owned by the specified user"""
|
||||
pass
|
||||
|
||||
class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses Civitai API for metadata"""
|
||||
|
||||
@@ -79,6 +84,9 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
return await self.client.get_model_version_info(version_id)
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
return await self.client.get_user_models(username)
|
||||
|
||||
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
||||
|
||||
@@ -197,6 +205,10 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
||||
return None, "CivArchive provider requires both model_id and version_id"
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None
|
||||
|
||||
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses SQLite database for metadata"""
|
||||
|
||||
@@ -329,20 +341,24 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Fetch model version metadata from SQLite database"""
|
||||
async with self._aiosqlite.connect(self.db_path) as db:
|
||||
db.row_factory = self._aiosqlite.Row
|
||||
|
||||
|
||||
# Get version details
|
||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||
cursor = await db.execute(version_query, (version_id,))
|
||||
version_row = await cursor.fetchone()
|
||||
|
||||
|
||||
if not version_row:
|
||||
return None, "Model version not found"
|
||||
|
||||
|
||||
model_id = version_row['model_id']
|
||||
|
||||
|
||||
# Build complete version data with model info
|
||||
version_data = await self._get_version_with_model_data(db, model_id, version_id)
|
||||
return version_data, None
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
"""Listing models by username is not supported for archive database"""
|
||||
return None
|
||||
|
||||
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
|
||||
"""Helper to build version data with model information"""
|
||||
@@ -481,6 +497,17 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
||||
continue
|
||||
return None, "No provider could retrieve the data"
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
for provider in self.providers:
|
||||
try:
|
||||
result = await provider.get_user_models(username)
|
||||
if result is not None:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.debug(f"Provider failed for get_user_models: {e}")
|
||||
continue
|
||||
return None
|
||||
|
||||
class ModelMetadataProviderManager:
|
||||
"""Manager for selecting and using model metadata providers"""
|
||||
|
||||
@@ -522,6 +549,11 @@ class ModelMetadataProviderManager:
|
||||
"""Fetch model version info using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_version_info(version_id)
|
||||
|
||||
async def get_user_models(self, username: str, provider_name: str = None) -> Optional[List[Dict]]:
|
||||
"""Fetch models owned by the specified user"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_user_models(username)
|
||||
|
||||
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
|
||||
"""Get provider by name or default provider"""
|
||||
|
||||
@@ -634,7 +634,8 @@ class ModelScanner:
|
||||
if model_data:
|
||||
# Add to cache
|
||||
self._cache.raw_data.append(model_data)
|
||||
|
||||
self._cache.add_to_version_index(model_data)
|
||||
|
||||
# Update hash index if available
|
||||
if 'sha256' in model_data and 'file_path' in model_data:
|
||||
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
||||
@@ -661,7 +662,9 @@ class ModelScanner:
|
||||
for path in missing_files:
|
||||
try:
|
||||
model_to_remove = path_to_item[path]
|
||||
|
||||
|
||||
self._cache.remove_from_version_index(model_to_remove)
|
||||
|
||||
# Update tags count
|
||||
for tag in model_to_remove.get('tags', []):
|
||||
if tag in self._tags_count:
|
||||
@@ -684,6 +687,8 @@ class ModelScanner:
|
||||
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
|
||||
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
self._cache.rebuild_version_index()
|
||||
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
|
||||
@@ -829,6 +834,8 @@ class ModelScanner:
|
||||
else:
|
||||
self._cache.raw_data = list(scan_result.raw_data)
|
||||
|
||||
self._cache.rebuild_version_index()
|
||||
|
||||
await self._cache.resort()
|
||||
|
||||
async def _gather_model_data(
|
||||
@@ -934,7 +941,8 @@ class ModelScanner:
|
||||
|
||||
# Add to cache
|
||||
self._cache.raw_data.append(metadata_dict)
|
||||
|
||||
self._cache.add_to_version_index(metadata_dict)
|
||||
|
||||
# Resort cache data
|
||||
await self._cache.resort()
|
||||
|
||||
@@ -1076,6 +1084,9 @@ class ModelScanner:
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
|
||||
if existing_item:
|
||||
cache.remove_from_version_index(existing_item)
|
||||
|
||||
if existing_item and 'tags' in existing_item:
|
||||
for tag in existing_item.get('tags', []):
|
||||
if tag in self._tags_count:
|
||||
@@ -1106,6 +1117,7 @@ class ModelScanner:
|
||||
)
|
||||
|
||||
cache.raw_data.append(cache_entry)
|
||||
cache.add_to_version_index(cache_entry)
|
||||
|
||||
sha_value = cache_entry.get('sha256')
|
||||
if sha_value:
|
||||
@@ -1117,6 +1129,8 @@ class ModelScanner:
|
||||
for tag in cache_entry.get('tags', []):
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
cache.rebuild_version_index()
|
||||
|
||||
await cache.resort()
|
||||
|
||||
if cache_modified:
|
||||
@@ -1339,11 +1353,12 @@ class ModelScanner:
|
||||
# Update hash index
|
||||
for model in models_to_remove:
|
||||
file_path = model['file_path']
|
||||
self._cache.remove_from_version_index(model)
|
||||
if hasattr(self, '_hash_index') and self._hash_index:
|
||||
# Get the hash and filename before removal for duplicate checking
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
hash_val = model.get('sha256', '').lower()
|
||||
|
||||
|
||||
# Remove from hash index
|
||||
self._hash_index.remove_by_path(file_path, hash_val)
|
||||
|
||||
@@ -1352,8 +1367,9 @@ class ModelScanner:
|
||||
|
||||
# Update cache data
|
||||
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths]
|
||||
|
||||
|
||||
# Resort cache
|
||||
self._cache.rebuild_version_index()
|
||||
await self._cache.resort()
|
||||
|
||||
await self._persist_current_cache()
|
||||
@@ -1393,16 +1409,17 @@ class ModelScanner:
|
||||
Returns:
|
||||
bool: True if the model version exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
normalized_id = int(model_version_id)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
try:
|
||||
cache = await self.get_cached_data()
|
||||
if not cache or not cache.raw_data:
|
||||
if not cache:
|
||||
return False
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
return normalized_id in cache.version_index
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking model version existence: {e}")
|
||||
return False
|
||||
|
||||
@@ -351,7 +351,7 @@ class PersistentModelCache:
|
||||
|
||||
|
||||
def get_persistent_cache() -> PersistentModelCache:
|
||||
from .settings_manager import settings as settings_service # Local import to avoid cycles
|
||||
from .settings_manager import get_settings_manager # Local import to avoid cycles
|
||||
|
||||
library_name = settings_service.get_active_library_name()
|
||||
library_name = get_settings_manager().get_active_library_name()
|
||||
return PersistentModelCache.get_default(library_name)
|
||||
|
||||
@@ -5,8 +5,10 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
from typing import Awaitable, Callable, Dict, Optional, Sequence
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
|
||||
from ..utils.civitai_utils import rewrite_preview_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,23 +47,59 @@ class PreviewAssetService:
|
||||
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||
preview_dir = os.path.dirname(metadata_path)
|
||||
is_video = first_preview.get("type") == "video"
|
||||
preview_url = first_preview.get("url")
|
||||
|
||||
if not preview_url:
|
||||
return
|
||||
|
||||
def extension_from_url(url: str, fallback: str) -> str:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return fallback
|
||||
ext = os.path.splitext(parsed.path)[1]
|
||||
return ext or fallback
|
||||
|
||||
downloader = await self._downloader_factory()
|
||||
|
||||
if is_video:
|
||||
extension = ".mp4"
|
||||
extension = extension_from_url(preview_url, ".mp4")
|
||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||
downloader = await self._downloader_factory()
|
||||
success, result = await downloader.download_file(
|
||||
first_preview["url"], preview_path, use_auth=False
|
||||
)
|
||||
if success:
|
||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="video")
|
||||
|
||||
attempt_urls = []
|
||||
if rewritten:
|
||||
attempt_urls.append(rewritten_url)
|
||||
attempt_urls.append(preview_url)
|
||||
|
||||
seen: set[str] = set()
|
||||
for candidate in attempt_urls:
|
||||
if not candidate or candidate in seen:
|
||||
continue
|
||||
seen.add(candidate)
|
||||
|
||||
success, _ = await downloader.download_file(candidate, preview_path, use_auth=False)
|
||||
if success:
|
||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||
return
|
||||
else:
|
||||
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="image")
|
||||
if rewritten:
|
||||
extension = extension_from_url(preview_url, ".png")
|
||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||
success, _ = await downloader.download_file(
|
||||
rewritten_url, preview_path, use_auth=False
|
||||
)
|
||||
if success:
|
||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||
return
|
||||
|
||||
extension = ".webp"
|
||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||
downloader = await self._downloader_factory()
|
||||
success, content, _headers = await downloader.download_to_memory(
|
||||
first_preview["url"], use_auth=False
|
||||
preview_url, use_auth=False
|
||||
)
|
||||
if not success:
|
||||
return
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional
|
||||
|
||||
from ..utils.settings_paths import ensure_settings_file
|
||||
@@ -688,4 +689,38 @@ class SettingsManager:
|
||||
|
||||
return templates.get(model_type, '{base_model}/{first_tag}')
|
||||
|
||||
settings = SettingsManager()
|
||||
|
||||
_SETTINGS_MANAGER: Optional["SettingsManager"] = None
|
||||
_SETTINGS_MANAGER_LOCK = Lock()
|
||||
# Legacy module-level alias for backwards compatibility with callers that
|
||||
# monkeypatch ``py.services.settings_manager.settings`` during tests.
|
||||
settings: Optional["SettingsManager"] = None
|
||||
|
||||
|
||||
def get_settings_manager() -> "SettingsManager":
|
||||
"""Return the lazily initialised global :class:`SettingsManager`."""
|
||||
|
||||
global _SETTINGS_MANAGER, settings
|
||||
if settings is not None:
|
||||
return settings
|
||||
|
||||
if _SETTINGS_MANAGER is None:
|
||||
with _SETTINGS_MANAGER_LOCK:
|
||||
if _SETTINGS_MANAGER is None:
|
||||
_SETTINGS_MANAGER = SettingsManager()
|
||||
|
||||
settings = _SETTINGS_MANAGER
|
||||
return _SETTINGS_MANAGER
|
||||
|
||||
|
||||
def reset_settings_manager() -> None:
|
||||
"""Reset the cached settings manager instance.
|
||||
|
||||
Primarily intended for tests so they can configure the settings
|
||||
directory before the manager touches the filesystem.
|
||||
"""
|
||||
|
||||
global _SETTINGS_MANAGER, settings
|
||||
with _SETTINGS_MANAGER_LOCK:
|
||||
_SETTINGS_MANAGER = None
|
||||
settings = None
|
||||
|
||||
@@ -6,6 +6,7 @@ import logging
|
||||
from typing import Any, Dict, Optional, Protocol, Sequence
|
||||
|
||||
from ..metadata_sync_service import MetadataSyncService
|
||||
from ...utils.metadata_manager import MetadataManager
|
||||
|
||||
|
||||
class MetadataRefreshProgressReporter(Protocol):
|
||||
@@ -70,6 +71,7 @@ class BulkMetadataRefreshUseCase:
|
||||
for model in to_process:
|
||||
try:
|
||||
original_name = model.get("model_name")
|
||||
await MetadataManager.hydrate_model_data(model)
|
||||
result, _ = await self._metadata_sync.fetch_and_update_model(
|
||||
sha256=model["sha256"],
|
||||
file_path=model["file_path"],
|
||||
|
||||
48
py/utils/civitai_utils.py
Normal file
48
py/utils/civitai_utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Utilities for working with Civitai assets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
|
||||
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
|
||||
"""Rewrite Civitai preview URLs to use optimized renditions.
|
||||
|
||||
Args:
|
||||
source_url: Original preview URL from the Civitai API.
|
||||
media_type: Optional media type hint (e.g. ``"image"`` or ``"video"``).
|
||||
|
||||
Returns:
|
||||
A tuple of the potentially rewritten URL and a flag indicating whether the
|
||||
replacement occurred. When the URL is not rewritten, the original value is
|
||||
returned with ``False``.
|
||||
"""
|
||||
if not source_url:
|
||||
return source_url, False
|
||||
|
||||
try:
|
||||
parsed = urlparse(source_url)
|
||||
except ValueError:
|
||||
return source_url, False
|
||||
|
||||
if parsed.netloc.lower() != "image.civitai.com":
|
||||
return source_url, False
|
||||
|
||||
replacement = "/width=450,optimized=true"
|
||||
if (media_type or "").lower() == "video":
|
||||
replacement = "/transcode=true,width=450,optimized=true"
|
||||
|
||||
if "/original=true" not in parsed.path:
|
||||
return source_url, False
|
||||
|
||||
updated_path = parsed.path.replace("/original=true", replacement, 1)
|
||||
if updated_path == parsed.path:
|
||||
return source_url, False
|
||||
|
||||
rewritten = urlunparse(parsed._replace(path=updated_path))
|
||||
print(rewritten)
|
||||
return rewritten, True
|
||||
|
||||
|
||||
__all__ = ["rewrite_preview_url"]
|
||||
|
||||
@@ -48,6 +48,13 @@ SUPPORTED_MEDIA_EXTENSIONS = {
|
||||
# Valid Lora types
|
||||
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
|
||||
|
||||
# Supported Civitai model types for user model queries (case-insensitive)
|
||||
CIVITAI_USER_MODEL_TYPES = [
|
||||
*VALID_LORA_TYPES,
|
||||
'textualinversion',
|
||||
'checkpoint',
|
||||
]
|
||||
|
||||
# Auto-organize settings
|
||||
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..utils.metadata_manager import MetadataManager
|
||||
from .example_images_processor import ExampleImagesProcessor
|
||||
from .example_images_metadata import MetadataUpdater
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
|
||||
|
||||
class ExampleImagesDownloadError(RuntimeError):
|
||||
@@ -107,7 +107,7 @@ class DownloadManager:
|
||||
self._state_lock = state_lock or asyncio.Lock()
|
||||
|
||||
def _resolve_output_dir(self, library_name: str | None = None) -> str:
|
||||
base_path = settings.get('example_images_path')
|
||||
base_path = get_settings_manager().get('example_images_path')
|
||||
if not base_path:
|
||||
return ''
|
||||
return ensure_library_root_exists(library_name)
|
||||
@@ -126,7 +126,8 @@ class DownloadManager:
|
||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||
delay = float(data.get('delay', 0.2))
|
||||
|
||||
base_path = settings.get('example_images_path')
|
||||
settings_manager = get_settings_manager()
|
||||
base_path = settings_manager.get('example_images_path')
|
||||
|
||||
if not base_path:
|
||||
error_msg = 'Example images path not configured in settings'
|
||||
@@ -138,7 +139,7 @@ class DownloadManager:
|
||||
}
|
||||
raise DownloadConfigurationError(error_msg)
|
||||
|
||||
active_library = settings.get_active_library_name()
|
||||
active_library = get_settings_manager().get_active_library_name()
|
||||
output_dir = self._resolve_output_dir(active_library)
|
||||
if not output_dir:
|
||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||
@@ -151,7 +152,7 @@ class DownloadManager:
|
||||
progress_file = os.path.join(output_dir, '.download_progress.json')
|
||||
progress_source = progress_file
|
||||
if uses_library_scoped_folders():
|
||||
legacy_root = settings.get('example_images_path') or ''
|
||||
legacy_root = get_settings_manager().get('example_images_path') or ''
|
||||
legacy_progress = os.path.join(legacy_root, '.download_progress.json') if legacy_root else ''
|
||||
if legacy_progress and os.path.exists(legacy_progress) and not os.path.exists(progress_file):
|
||||
try:
|
||||
@@ -555,11 +556,12 @@ class DownloadManager:
|
||||
if not model_hashes:
|
||||
raise DownloadConfigurationError('Missing model_hashes parameter')
|
||||
|
||||
base_path = settings.get('example_images_path')
|
||||
settings_manager = get_settings_manager()
|
||||
base_path = settings_manager.get('example_images_path')
|
||||
|
||||
if not base_path:
|
||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||
active_library = settings.get_active_library_name()
|
||||
active_library = settings_manager.get_active_library_name()
|
||||
output_dir = self._resolve_output_dir(active_library)
|
||||
if not output_dir:
|
||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
from aiohttp import web
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..utils.example_images_paths import (
|
||||
get_model_folder,
|
||||
get_model_relative_path,
|
||||
@@ -37,7 +37,8 @@ class ExampleImagesFileManager:
|
||||
}, status=400)
|
||||
|
||||
# Get example images path from settings
|
||||
example_images_path = settings.get('example_images_path')
|
||||
settings_manager = get_settings_manager()
|
||||
example_images_path = settings_manager.get('example_images_path')
|
||||
if not example_images_path:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
@@ -109,7 +110,8 @@ class ExampleImagesFileManager:
|
||||
}, status=400)
|
||||
|
||||
# Get example images path from settings
|
||||
example_images_path = settings.get('example_images_path')
|
||||
settings_manager = get_settings_manager()
|
||||
example_images_path = settings_manager.get('example_images_path')
|
||||
if not example_images_path:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
@@ -183,7 +185,8 @@ class ExampleImagesFileManager:
|
||||
}, status=400)
|
||||
|
||||
# Get example images path from settings
|
||||
example_images_path = settings.get('example_images_path')
|
||||
settings_manager = get_settings_manager()
|
||||
example_images_path = settings_manager.get('example_images_path')
|
||||
if not example_images_path:
|
||||
return web.json_response({
|
||||
'has_images': False
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from ..recipes.constants import GEN_PARAM_KEYS
|
||||
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||
from ..services.metadata_sync_service import MetadataSyncService
|
||||
from ..services.preview_asset_service import PreviewAssetService
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..services.downloader import get_downloader
|
||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
@@ -20,13 +21,46 @@ _preview_service = PreviewAssetService(
|
||||
exif_utils=ExifUtils,
|
||||
)
|
||||
|
||||
_metadata_sync_service = MetadataSyncService(
|
||||
metadata_manager=MetadataManager,
|
||||
preview_service=_preview_service,
|
||||
settings=settings,
|
||||
default_metadata_provider_factory=get_default_metadata_provider,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
)
|
||||
_metadata_sync_service: MetadataSyncService | None = None
|
||||
_metadata_sync_service_settings: Optional["SettingsManager"] = None
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - import for type checkers only
|
||||
from ..services.settings_manager import SettingsManager
|
||||
|
||||
|
||||
def _build_metadata_sync_service(settings_manager: "SettingsManager") -> MetadataSyncService:
|
||||
"""Construct a metadata sync service bound to the provided settings."""
|
||||
|
||||
return MetadataSyncService(
|
||||
metadata_manager=MetadataManager,
|
||||
preview_service=_preview_service,
|
||||
settings=settings_manager,
|
||||
default_metadata_provider_factory=get_default_metadata_provider,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
)
|
||||
|
||||
|
||||
def _get_metadata_sync_service() -> MetadataSyncService:
|
||||
"""Return the shared metadata sync service, initialising it lazily."""
|
||||
|
||||
global _metadata_sync_service, _metadata_sync_service_settings
|
||||
|
||||
settings_manager = get_settings_manager()
|
||||
|
||||
if isinstance(_metadata_sync_service, MetadataSyncService):
|
||||
if _metadata_sync_service_settings is not settings_manager:
|
||||
_metadata_sync_service = _build_metadata_sync_service(settings_manager)
|
||||
_metadata_sync_service_settings = settings_manager
|
||||
elif _metadata_sync_service is None:
|
||||
_metadata_sync_service = _build_metadata_sync_service(settings_manager)
|
||||
_metadata_sync_service_settings = settings_manager
|
||||
else:
|
||||
# Tests may inject stand-ins that do not match the sync service type. Preserve
|
||||
# those injections while still updating our cached settings reference so the
|
||||
# next real service instantiation uses the current configuration.
|
||||
_metadata_sync_service_settings = settings_manager
|
||||
|
||||
return _metadata_sync_service
|
||||
|
||||
|
||||
class MetadataUpdater:
|
||||
@@ -71,7 +105,8 @@ class MetadataUpdater:
|
||||
async def update_cache_func(old_path, new_path, metadata):
|
||||
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
||||
|
||||
success, error = await _metadata_sync_service.fetch_and_update_model(
|
||||
await MetadataManager.hydrate_model_data(model_data)
|
||||
success, error = await _get_metadata_sync_service().fetch_and_update_model(
|
||||
sha256=model_hash,
|
||||
file_path=file_path,
|
||||
model_data=model_data,
|
||||
@@ -151,16 +186,16 @@ class MetadataUpdater:
|
||||
if is_supported:
|
||||
local_images_paths.append(file_path)
|
||||
|
||||
await MetadataManager.hydrate_model_data(model)
|
||||
civitai_data = model.setdefault('civitai', {})
|
||||
|
||||
# Check if metadata update is needed (no civitai field or empty images)
|
||||
needs_update = not model.get('civitai') or not model.get('civitai', {}).get('images')
|
||||
needs_update = not civitai_data or not civitai_data.get('images')
|
||||
|
||||
if needs_update and local_images_paths:
|
||||
logger.debug(f"Found {len(local_images_paths)} local example images for {model.get('model_name')}, updating metadata")
|
||||
|
||||
# Create or get civitai field
|
||||
if not model.get('civitai'):
|
||||
model['civitai'] = {}
|
||||
|
||||
# Create images array
|
||||
images = []
|
||||
|
||||
@@ -195,16 +230,13 @@ class MetadataUpdater:
|
||||
images.append(image_entry)
|
||||
|
||||
# Update the model's civitai.images field
|
||||
model['civitai']['images'] = images
|
||||
civitai_data['images'] = images
|
||||
|
||||
# Save metadata to .metadata.json file
|
||||
file_path = model.get('file_path')
|
||||
try:
|
||||
# Create a copy of model data without 'folder' field
|
||||
model_copy = model.copy()
|
||||
model_copy.pop('folder', None)
|
||||
|
||||
# Write metadata to file
|
||||
await MetadataManager.save_metadata(file_path, model_copy)
|
||||
logger.info(f"Saved metadata for {model.get('model_name')}")
|
||||
except Exception as e:
|
||||
@@ -237,16 +269,13 @@ class MetadataUpdater:
|
||||
tuple: (regular_images, custom_images) - Both image arrays
|
||||
"""
|
||||
try:
|
||||
# Ensure civitai field exists in model_data
|
||||
if not model_data.get('civitai'):
|
||||
model_data['civitai'] = {}
|
||||
|
||||
# Ensure customImages array exists
|
||||
if not model_data['civitai'].get('customImages'):
|
||||
model_data['civitai']['customImages'] = []
|
||||
|
||||
# Get current customImages array
|
||||
custom_images = model_data['civitai']['customImages']
|
||||
await MetadataManager.hydrate_model_data(model_data)
|
||||
civitai_data = model_data.setdefault('civitai', {})
|
||||
custom_images = civitai_data.get('customImages')
|
||||
|
||||
if not isinstance(custom_images, list):
|
||||
custom_images = []
|
||||
civitai_data['customImages'] = custom_images
|
||||
|
||||
# Add new image entry for each imported file
|
||||
for path_tuple in newly_imported_paths:
|
||||
@@ -304,11 +333,8 @@ class MetadataUpdater:
|
||||
file_path = model_data.get('file_path')
|
||||
if file_path:
|
||||
try:
|
||||
# Create a copy of model data without 'folder' field
|
||||
model_copy = model_data.copy()
|
||||
model_copy.pop('folder', None)
|
||||
|
||||
# Write metadata to file
|
||||
await MetadataManager.save_metadata(file_path, model_copy)
|
||||
logger.info(f"Saved metadata for {model_data.get('model_name')}")
|
||||
except Exception as e:
|
||||
@@ -319,7 +345,7 @@ class MetadataUpdater:
|
||||
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
||||
|
||||
# Get regular images array (might be None)
|
||||
regular_images = model_data['civitai'].get('images', [])
|
||||
regular_images = civitai_data.get('images', [])
|
||||
|
||||
# Return both image arrays
|
||||
return regular_images, custom_images
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.example_images_paths import iter_library_roots
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
@@ -14,6 +14,25 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
CURRENT_NAMING_VERSION = 2 # Increment this when naming conventions change
|
||||
|
||||
|
||||
class _SettingsProxy:
|
||||
def __init__(self):
|
||||
self._manager = None
|
||||
|
||||
def _resolve(self):
|
||||
if self._manager is None:
|
||||
self._manager = get_settings_manager()
|
||||
return self._manager
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
return self._resolve().get(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._resolve(), item)
|
||||
|
||||
|
||||
settings = _SettingsProxy()
|
||||
|
||||
class ExampleImagesMigration:
|
||||
"""Handles migrations for example images naming conventions"""
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import re
|
||||
import shutil
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
|
||||
_HEX_PATTERN = re.compile(r"[a-fA-F0-9]{64}")
|
||||
|
||||
@@ -18,7 +18,8 @@ logger = logging.getLogger(__name__)
|
||||
def _get_configured_libraries() -> List[str]:
|
||||
"""Return configured library names if multi-library support is enabled."""
|
||||
|
||||
libraries = settings.get("libraries")
|
||||
settings_manager = get_settings_manager()
|
||||
libraries = settings_manager.get("libraries")
|
||||
if isinstance(libraries, dict) and libraries:
|
||||
return list(libraries.keys())
|
||||
return []
|
||||
@@ -27,7 +28,8 @@ def _get_configured_libraries() -> List[str]:
|
||||
def get_example_images_root() -> str:
|
||||
"""Return the root directory configured for example images."""
|
||||
|
||||
root = settings.get("example_images_path") or ""
|
||||
settings_manager = get_settings_manager()
|
||||
root = settings_manager.get("example_images_path") or ""
|
||||
return os.path.abspath(root) if root else ""
|
||||
|
||||
|
||||
@@ -41,7 +43,8 @@ def uses_library_scoped_folders() -> bool:
|
||||
def sanitize_library_name(library_name: Optional[str]) -> str:
|
||||
"""Return a filesystem safe library name."""
|
||||
|
||||
name = library_name or settings.get_active_library_name() or "default"
|
||||
settings_manager = get_settings_manager()
|
||||
name = library_name or settings_manager.get_active_library_name() or "default"
|
||||
safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", name)
|
||||
return safe_name or "default"
|
||||
|
||||
@@ -161,11 +164,13 @@ def iter_library_roots() -> Iterable[Tuple[str, str]]:
|
||||
results.append((library, get_library_root(library)))
|
||||
else:
|
||||
# Fall back to the active library to avoid skipping migrations/cleanup
|
||||
active = settings.get_active_library_name() or "default"
|
||||
settings_manager = get_settings_manager()
|
||||
active = settings_manager.get_active_library_name() or "default"
|
||||
results.append((active, get_library_root(active)))
|
||||
return results
|
||||
|
||||
active = settings.get_active_library_name() or "default"
|
||||
settings_manager = get_settings_manager()
|
||||
active = settings_manager.get_active_library_name() or "default"
|
||||
return [(active, root)]
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import string
|
||||
from aiohttp import web
|
||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from ..utils.example_images_paths import get_model_folder, get_model_relative_path
|
||||
from .example_images_metadata import MetadataUpdater
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
@@ -318,7 +318,7 @@ class ExampleImagesProcessor:
|
||||
|
||||
try:
|
||||
# Get example images path
|
||||
example_images_path = settings.get('example_images_path')
|
||||
example_images_path = get_settings_manager().get('example_images_path')
|
||||
if not example_images_path:
|
||||
raise ExampleImagesValidationError('No example images path configured')
|
||||
|
||||
@@ -442,7 +442,7 @@ class ExampleImagesProcessor:
|
||||
}, status=400)
|
||||
|
||||
# Get example images path
|
||||
example_images_path = settings.get('example_images_path')
|
||||
example_images_path = get_settings_manager().get('example_images_path')
|
||||
if not example_images_path:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
@@ -475,15 +475,17 @@ class ExampleImagesProcessor:
|
||||
'error': f"Model with hash {model_hash} not found in cache"
|
||||
}, status=404)
|
||||
|
||||
# Check if model has custom images
|
||||
if not model_data.get('civitai', {}).get('customImages'):
|
||||
await MetadataManager.hydrate_model_data(model_data)
|
||||
civitai_data = model_data.setdefault('civitai', {})
|
||||
custom_images = civitai_data.get('customImages')
|
||||
|
||||
if not isinstance(custom_images, list) or not custom_images:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Model has no custom images"
|
||||
}, status=404)
|
||||
|
||||
# Find the custom image with matching short_id
|
||||
custom_images = model_data['civitai']['customImages']
|
||||
matching_image = None
|
||||
new_custom_images = []
|
||||
|
||||
@@ -527,17 +529,15 @@ class ExampleImagesProcessor:
|
||||
logger.warning(f"File for custom example with id {short_id} not found, but metadata will still be updated")
|
||||
|
||||
# Update metadata
|
||||
model_data['civitai']['customImages'] = new_custom_images
|
||||
civitai_data['customImages'] = new_custom_images
|
||||
model_data.setdefault('civitai', {})['customImages'] = new_custom_images
|
||||
|
||||
# Save updated metadata to file
|
||||
file_path = model_data.get('file_path')
|
||||
if file_path:
|
||||
try:
|
||||
# Create a copy of model data without 'folder' field
|
||||
model_copy = model_data.copy()
|
||||
model_copy.pop('folder', None)
|
||||
|
||||
# Write metadata to file
|
||||
await MetadataManager.save_metadata(file_path, model_copy)
|
||||
logger.debug(f"Saved updated metadata for {model_data.get('model_name')}")
|
||||
except Exception as e:
|
||||
@@ -551,7 +551,7 @@ class ExampleImagesProcessor:
|
||||
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
||||
|
||||
# Get regular images array (might be None)
|
||||
regular_images = model_data['civitai'].get('images', [])
|
||||
regular_images = civitai_data.get('images', [])
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
@@ -568,4 +568,4 @@ class ExampleImagesProcessor:
|
||||
}, status=500)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Optional, Type, Union
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
|
||||
from .models import BaseModelMetadata, LoraMetadata
|
||||
from .file_utils import normalize_path, find_preview_file, calculate_sha256
|
||||
@@ -53,6 +53,70 @@ class MetadataManager:
|
||||
error_type = "Invalid JSON" if isinstance(e, json.JSONDecodeError) else "Parse error"
|
||||
logger.error(f"{error_type} in metadata file: {metadata_path}. Error: {str(e)}. Skipping model to preserve existing data.")
|
||||
return None, True # should_skip = True
|
||||
|
||||
@staticmethod
|
||||
async def load_metadata_payload(file_path: str) -> Dict:
|
||||
"""
|
||||
Load metadata and return it as a dictionary, including any unknown fields.
|
||||
Falls back to reading the raw JSON file if parsing into a model class fails.
|
||||
"""
|
||||
|
||||
payload: Dict = {}
|
||||
metadata_obj, should_skip = await MetadataManager.load_metadata(file_path)
|
||||
|
||||
if metadata_obj:
|
||||
payload = metadata_obj.to_dict()
|
||||
unknown_fields = getattr(metadata_obj, "_unknown_fields", None)
|
||||
if isinstance(unknown_fields, dict):
|
||||
payload.update(unknown_fields)
|
||||
else:
|
||||
if not should_skip:
|
||||
metadata_path = (
|
||||
file_path
|
||||
if file_path.endswith(".metadata.json")
|
||||
else f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||
)
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
with open(metadata_path, "r", encoding="utf-8") as handle:
|
||||
raw = json.load(handle)
|
||||
if isinstance(raw, dict):
|
||||
payload = raw
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Failed to parse metadata file %s while loading payload",
|
||||
metadata_path,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.warning("Failed to read metadata file %s: %s", metadata_path, exc)
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
payload = {}
|
||||
|
||||
if file_path:
|
||||
payload.setdefault("file_path", normalize_path(file_path))
|
||||
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
async def hydrate_model_data(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Replace the provided model data with the authoritative payload from disk.
|
||||
Preserves the cached folder entry if present.
|
||||
"""
|
||||
|
||||
file_path = model_data.get("file_path")
|
||||
if not file_path:
|
||||
return model_data
|
||||
|
||||
folder = model_data.get("folder")
|
||||
payload = await MetadataManager.load_metadata_payload(file_path)
|
||||
if folder is not None:
|
||||
payload["folder"] = folder
|
||||
|
||||
model_data.clear()
|
||||
model_data.update(payload)
|
||||
return model_data
|
||||
|
||||
@staticmethod
|
||||
async def save_metadata(path: str, metadata: Union[BaseModelMetadata, Dict]) -> bool:
|
||||
|
||||
@@ -65,6 +65,12 @@ def ensure_settings_file(logger: Optional[logging.Logger] = None) -> str:
|
||||
|
||||
logger = logger or _LOGGER
|
||||
target_path = get_settings_file_path(create_dir=True)
|
||||
preferred_dir = user_config_dir(APP_NAME, appauthor=False)
|
||||
preferred_path = os.path.join(preferred_dir, "settings.json")
|
||||
|
||||
if os.path.abspath(target_path) != os.path.abspath(preferred_path):
|
||||
os.makedirs(preferred_dir, exist_ok=True)
|
||||
target_path = preferred_path
|
||||
legacy_path = get_legacy_settings_path()
|
||||
|
||||
if os.path.exists(legacy_path) and not os.path.exists(target_path):
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
from typing import Dict
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from .constants import CIVITAI_MODEL_TAGS
|
||||
import asyncio
|
||||
|
||||
@@ -143,7 +143,8 @@ def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora'
|
||||
Relative path string (empty string for flat structure)
|
||||
"""
|
||||
# Get path template from settings for specific model type
|
||||
path_template = settings.get_download_path_template(model_type)
|
||||
settings_manager = get_settings_manager()
|
||||
path_template = settings_manager.get_download_path_template(model_type)
|
||||
|
||||
# If template is empty, return empty path (flat structure)
|
||||
if not path_template:
|
||||
@@ -166,7 +167,7 @@ def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora'
|
||||
model_tags = model_data.get('tags', [])
|
||||
|
||||
# Apply mapping if available
|
||||
base_model_mappings = settings.get('base_model_path_mappings', {})
|
||||
base_model_mappings = settings_manager.get('base_model_path_mappings', {})
|
||||
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
||||
|
||||
# Find the first Civitai model tag that exists in model_tags
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "comfyui-lora-manager"
|
||||
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
||||
version = "0.9.6"
|
||||
version = "0.9.7"
|
||||
license = {file = "LICENSE"}
|
||||
dependencies = [
|
||||
"aiohttp",
|
||||
|
||||
@@ -21,6 +21,7 @@ export class SidebarManager {
|
||||
this.isInitialized = false;
|
||||
this.displayMode = 'tree'; // 'tree' or 'list'
|
||||
this.foldersList = [];
|
||||
this.recursiveSearchEnabled = true;
|
||||
|
||||
// Bind methods
|
||||
this.handleTreeClick = this.handleTreeClick.bind(this);
|
||||
@@ -36,6 +37,7 @@ export class SidebarManager {
|
||||
this.updateContainerMargin = this.updateContainerMargin.bind(this);
|
||||
this.handleDisplayModeToggle = this.handleDisplayModeToggle.bind(this);
|
||||
this.handleFolderListClick = this.handleFolderListClick.bind(this);
|
||||
this.handleRecursiveToggle = this.handleRecursiveToggle.bind(this);
|
||||
}
|
||||
|
||||
async initialize(pageControls) {
|
||||
@@ -89,6 +91,7 @@ export class SidebarManager {
|
||||
this.isHovering = false;
|
||||
this.apiClient = null;
|
||||
this.isInitialized = false;
|
||||
this.recursiveSearchEnabled = true;
|
||||
|
||||
// Reset container margin
|
||||
const container = document.querySelector('.container');
|
||||
@@ -111,6 +114,7 @@ export class SidebarManager {
|
||||
const sidebar = document.getElementById('folderSidebar');
|
||||
const hoverArea = document.getElementById('sidebarHoverArea');
|
||||
const displayModeToggleBtn = document.getElementById('sidebarDisplayModeToggle');
|
||||
const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
|
||||
|
||||
if (pinToggleBtn) {
|
||||
pinToggleBtn.removeEventListener('click', this.handlePinToggle);
|
||||
@@ -145,6 +149,9 @@ export class SidebarManager {
|
||||
if (displayModeToggleBtn) {
|
||||
displayModeToggleBtn.removeEventListener('click', this.handleDisplayModeToggle);
|
||||
}
|
||||
if (recursiveToggleBtn) {
|
||||
recursiveToggleBtn.removeEventListener('click', this.handleRecursiveToggle);
|
||||
}
|
||||
}
|
||||
|
||||
async init() {
|
||||
@@ -197,7 +204,7 @@ export class SidebarManager {
|
||||
updateSidebarTitle() {
|
||||
const sidebarTitle = document.getElementById('sidebarTitle');
|
||||
if (sidebarTitle) {
|
||||
sidebarTitle.textContent = `${this.apiClient.apiConfig.config.displayName} Root`;
|
||||
sidebarTitle.textContent = translate('sidebar.modelRoot');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,6 +227,12 @@ export class SidebarManager {
|
||||
collapseAllBtn.addEventListener('click', this.handleCollapseAll);
|
||||
}
|
||||
|
||||
// Recursive toggle button
|
||||
const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
|
||||
if (recursiveToggleBtn) {
|
||||
recursiveToggleBtn.addEventListener('click', this.handleRecursiveToggle);
|
||||
}
|
||||
|
||||
// Tree click handler
|
||||
const folderTree = document.getElementById('sidebarFolderTree');
|
||||
if (folderTree) {
|
||||
@@ -645,11 +658,33 @@ export class SidebarManager {
|
||||
this.displayMode = this.displayMode === 'tree' ? 'list' : 'tree';
|
||||
this.updateDisplayModeButton();
|
||||
this.updateCollapseAllButton();
|
||||
this.updateRecursiveToggleButton();
|
||||
this.updateSearchRecursiveOption();
|
||||
this.saveDisplayMode();
|
||||
this.loadFolderTree(); // Reload with new display mode
|
||||
}
|
||||
|
||||
async handleRecursiveToggle(event) {
|
||||
event.stopPropagation();
|
||||
|
||||
if (this.displayMode !== 'tree') {
|
||||
return;
|
||||
}
|
||||
|
||||
this.recursiveSearchEnabled = !this.recursiveSearchEnabled;
|
||||
setStorageItem(`${this.pageType}_recursiveSearch`, this.recursiveSearchEnabled);
|
||||
this.updateSearchRecursiveOption();
|
||||
this.updateRecursiveToggleButton();
|
||||
|
||||
if (this.pageControls && typeof this.pageControls.resetAndReload === 'function') {
|
||||
try {
|
||||
await this.pageControls.resetAndReload(true);
|
||||
} catch (error) {
|
||||
console.error('Failed to reload models after toggling recursive search:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
updateDisplayModeButton() {
|
||||
const displayModeBtn = document.getElementById('sidebarDisplayModeToggle');
|
||||
if (displayModeBtn) {
|
||||
@@ -679,8 +714,35 @@ export class SidebarManager {
|
||||
}
|
||||
}
|
||||
|
||||
updateRecursiveToggleButton() {
|
||||
const recursiveToggleBtn = document.getElementById('sidebarRecursiveToggle');
|
||||
if (!recursiveToggleBtn) return;
|
||||
|
||||
const icon = recursiveToggleBtn.querySelector('i');
|
||||
const isTreeMode = this.displayMode === 'tree';
|
||||
const isActive = isTreeMode && this.recursiveSearchEnabled;
|
||||
|
||||
recursiveToggleBtn.classList.toggle('active', isActive);
|
||||
recursiveToggleBtn.classList.toggle('disabled', !isTreeMode);
|
||||
recursiveToggleBtn.setAttribute('aria-pressed', isActive ? 'true' : 'false');
|
||||
recursiveToggleBtn.setAttribute('aria-disabled', isTreeMode ? 'false' : 'true');
|
||||
|
||||
if (icon) {
|
||||
icon.className = 'fas fa-code-branch';
|
||||
}
|
||||
|
||||
if (!isTreeMode) {
|
||||
recursiveToggleBtn.title = translate('sidebar.recursiveUnavailable');
|
||||
} else if (this.recursiveSearchEnabled) {
|
||||
recursiveToggleBtn.title = translate('sidebar.recursiveOn');
|
||||
} else {
|
||||
recursiveToggleBtn.title = translate('sidebar.recursiveOff');
|
||||
}
|
||||
}
|
||||
|
||||
updateSearchRecursiveOption() {
|
||||
this.pageControls.pageState.searchOptions.recursive = this.displayMode === 'tree';
|
||||
const isRecursive = this.displayMode === 'tree' && this.recursiveSearchEnabled;
|
||||
this.pageControls.pageState.searchOptions.recursive = isRecursive;
|
||||
}
|
||||
|
||||
updateTreeSelection() {
|
||||
@@ -925,15 +987,18 @@ export class SidebarManager {
|
||||
const isPinned = getStorageItem(`${this.pageType}_sidebarPinned`, true);
|
||||
const expandedPaths = getStorageItem(`${this.pageType}_expandedNodes`, []);
|
||||
const displayMode = getStorageItem(`${this.pageType}_displayMode`, 'tree'); // 'tree' or 'list', default to 'tree'
|
||||
const recursiveSearchEnabled = getStorageItem(`${this.pageType}_recursiveSearch`, true);
|
||||
|
||||
this.isPinned = isPinned;
|
||||
this.expandedNodes = new Set(expandedPaths);
|
||||
this.displayMode = displayMode;
|
||||
this.recursiveSearchEnabled = recursiveSearchEnabled;
|
||||
|
||||
this.updatePinButton();
|
||||
this.updateDisplayModeButton();
|
||||
this.updateCollapseAllButton();
|
||||
this.updateSearchRecursiveOption();
|
||||
this.updateRecursiveToggleButton();
|
||||
}
|
||||
|
||||
restoreSelectedFolder() {
|
||||
@@ -974,4 +1039,4 @@ export class SidebarManager {
|
||||
}
|
||||
|
||||
// Create and export global instance
|
||||
export const sidebarManager = new SidebarManager();
|
||||
export const sidebarManager = new SidebarManager();
|
||||
|
||||
@@ -67,7 +67,7 @@ export const state = {
|
||||
modelname: true,
|
||||
tags: false,
|
||||
creator: false,
|
||||
recursive: true,
|
||||
recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true),
|
||||
},
|
||||
filters: {
|
||||
baseModel: [],
|
||||
@@ -116,7 +116,7 @@ export const state = {
|
||||
filename: true,
|
||||
modelname: true,
|
||||
creator: false,
|
||||
recursive: true,
|
||||
recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true),
|
||||
},
|
||||
filters: {
|
||||
baseModel: [],
|
||||
@@ -144,7 +144,7 @@ export const state = {
|
||||
modelname: true,
|
||||
tags: false,
|
||||
creator: false,
|
||||
recursive: true,
|
||||
recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true),
|
||||
},
|
||||
filters: {
|
||||
baseModel: [],
|
||||
@@ -261,4 +261,4 @@ export function initPageState(pageType) {
|
||||
return getCurrentPageState();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -435,8 +435,9 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
|
||||
return true;
|
||||
} else {
|
||||
// Single node - send directly
|
||||
const nodeId = Object.keys(registryData.data.nodes)[0];
|
||||
return await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType);
|
||||
const nodes = registryData.data.nodes;
|
||||
const nodeId = Object.keys(nodes)[0];
|
||||
return await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to get registry:', error);
|
||||
@@ -452,19 +453,65 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
|
||||
* @param {boolean} replaceMode - Whether to replace existing LoRAs
|
||||
* @param {string} syntaxType - The type of syntax ('lora' or 'recipe')
|
||||
*/
|
||||
async function sendToSpecificNode(nodeIds, loraSyntax, replaceMode, syntaxType) {
|
||||
function resolveNodeReference(nodeKey, nodesMap) {
|
||||
if (!nodeKey) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const directMatch = nodesMap?.[nodeKey];
|
||||
if (directMatch) {
|
||||
return {
|
||||
node_id: directMatch.id,
|
||||
graph_id: directMatch.graph_id ?? null,
|
||||
};
|
||||
}
|
||||
|
||||
if (typeof nodeKey === 'string' && nodeKey.includes(':')) {
|
||||
const [graphId, ...rest] = nodeKey.split(':');
|
||||
const nodeIdPart = rest.join(':');
|
||||
const numericNodeId = Number(nodeIdPart);
|
||||
return {
|
||||
node_id: Number.isNaN(numericNodeId) ? nodeIdPart : numericNodeId,
|
||||
graph_id: graphId || null,
|
||||
};
|
||||
}
|
||||
|
||||
const numericId = Number(nodeKey);
|
||||
return {
|
||||
node_id: Number.isNaN(numericId) ? nodeKey : numericId,
|
||||
graph_id: null,
|
||||
};
|
||||
}
|
||||
|
||||
async function sendToSpecificNode(nodeIds, nodesMap, loraSyntax, replaceMode, syntaxType) {
|
||||
try {
|
||||
// Call the backend API to update the lora code
|
||||
const requestBody = {
|
||||
lora_code: loraSyntax,
|
||||
mode: replaceMode ? 'replace' : 'append'
|
||||
};
|
||||
|
||||
if (Array.isArray(nodeIds)) {
|
||||
const references = nodeIds
|
||||
.map((nodeKey) => resolveNodeReference(nodeKey, nodesMap))
|
||||
.filter((reference) => reference && reference.node_id !== undefined);
|
||||
|
||||
if (references.length > 0) {
|
||||
requestBody.node_ids = references;
|
||||
}
|
||||
} else if (nodeIds) {
|
||||
const reference = resolveNodeReference(nodeIds, nodesMap);
|
||||
if (reference) {
|
||||
requestBody.node_ids = [reference];
|
||||
}
|
||||
}
|
||||
|
||||
const response = await fetch('/api/lm/update-lora-code', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
node_ids: nodeIds,
|
||||
lora_code: loraSyntax,
|
||||
mode: replaceMode ? 'replace' : 'append'
|
||||
})
|
||||
body: JSON.stringify(requestBody)
|
||||
});
|
||||
|
||||
const result = await response.json();
|
||||
@@ -522,16 +569,17 @@ function showNodeSelector(nodes, loraSyntax, replaceMode, syntaxType) {
|
||||
hideNodeSelector();
|
||||
|
||||
// Generate node list HTML with icons and proper colors
|
||||
const nodeItems = Object.values(nodes).map(node => {
|
||||
const nodeItems = Object.entries(nodes).map(([nodeKey, node]) => {
|
||||
const iconClass = NODE_TYPE_ICONS[node.type] || 'fas fa-question-circle';
|
||||
const bgColor = node.bgcolor || DEFAULT_NODE_COLOR;
|
||||
|
||||
const graphLabel = node.graph_name ? ` (${node.graph_name})` : '';
|
||||
|
||||
return `
|
||||
<div class="node-item" data-node-id="${node.id}">
|
||||
<div class="node-item" data-node-id="${nodeKey}">
|
||||
<div class="node-icon-indicator" style="background-color: ${bgColor}">
|
||||
<i class="${iconClass}"></i>
|
||||
</div>
|
||||
<span>#${node.id} ${node.title}</span>
|
||||
<span>#${node.id}${graphLabel} ${node.title}</span>
|
||||
</div>
|
||||
`;
|
||||
}).join('');
|
||||
@@ -610,10 +658,10 @@ function setupNodeSelectorEvents(selector, nodes, loraSyntax, replaceMode, synta
|
||||
if (action === 'send-all') {
|
||||
// Send to all nodes
|
||||
const allNodeIds = Object.keys(nodes);
|
||||
await sendToSpecificNode(allNodeIds, loraSyntax, replaceMode, syntaxType);
|
||||
await sendToSpecificNode(allNodeIds, nodes, loraSyntax, replaceMode, syntaxType);
|
||||
} else if (nodeId) {
|
||||
// Send to specific node
|
||||
await sendToSpecificNode([nodeId], loraSyntax, replaceMode, syntaxType);
|
||||
await sendToSpecificNode([nodeId], nodes, loraSyntax, replaceMode, syntaxType);
|
||||
}
|
||||
|
||||
hideNodeSelector();
|
||||
|
||||
@@ -9,6 +9,9 @@
|
||||
<button class="sidebar-action-btn" id="sidebarDisplayModeToggle" title="{{ t('sidebar.switchToListView') }}">
|
||||
<i class="fas fa-sitemap"></i>
|
||||
</button>
|
||||
<button class="sidebar-action-btn active" id="sidebarRecursiveToggle" title="{{ t('sidebar.recursiveOn') }}" aria-pressed="true">
|
||||
<i class="fas fa-code-branch"></i>
|
||||
</button>
|
||||
<button class="sidebar-action-btn" id="sidebarCollapseAll" title="{{ t('sidebar.collapseAll') }}">
|
||||
<i class="fas fa-compress-alt"></i>
|
||||
</button>
|
||||
|
||||
@@ -73,6 +73,30 @@ nodes_mock.NODE_CLASS_MAPPINGS = {}
|
||||
sys.modules['nodes'] = nodes_mock
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_settings_dir(tmp_path_factory, monkeypatch):
|
||||
"""Redirect settings.json into a temporary directory for each test."""
|
||||
|
||||
settings_dir = tmp_path_factory.mktemp("settings_dir")
|
||||
|
||||
def fake_get_settings_dir(create: bool = True) -> str:
|
||||
if create:
|
||||
settings_dir.mkdir(exist_ok=True)
|
||||
return str(settings_dir)
|
||||
|
||||
monkeypatch.setattr("py.utils.settings_paths.get_settings_dir", fake_get_settings_dir)
|
||||
monkeypatch.setattr(
|
||||
"py.utils.settings_paths.user_config_dir",
|
||||
lambda *_args, **_kwargs: str(settings_dir),
|
||||
)
|
||||
|
||||
from py.services import settings_manager as settings_manager_module
|
||||
|
||||
settings_manager_module.reset_settings_manager()
|
||||
yield
|
||||
settings_manager_module.reset_settings_manager()
|
||||
|
||||
|
||||
def pytest_pyfunc_call(pyfuncitem):
|
||||
"""Allow bare async tests to run without pytest.mark.asyncio."""
|
||||
test_function = pyfuncitem.function
|
||||
|
||||
@@ -46,6 +46,7 @@ vi.mock(EVENT_MANAGER_MODULE, () => ({
|
||||
off: vi.fn(),
|
||||
addHandler: vi.fn(),
|
||||
removeHandler: vi.fn(),
|
||||
setState: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -62,6 +63,7 @@ describe('UI helper DOM utilities', () => {
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
delete global.fetch;
|
||||
});
|
||||
|
||||
it('creates toast elements and cleans them up after timeout', async () => {
|
||||
@@ -105,4 +107,53 @@ describe('UI helper DOM utilities', () => {
|
||||
expect(document.body.dataset.theme).toBe('dark');
|
||||
expect(document.querySelector('.theme-toggle').classList.contains('theme-dark')).toBe(true);
|
||||
});
|
||||
|
||||
it('renders subgraph names in the node selector list', async () => {
|
||||
const registryResponse = {
|
||||
success: true,
|
||||
data: {
|
||||
node_count: 2,
|
||||
nodes: {
|
||||
'root:1': {
|
||||
id: 1,
|
||||
graph_id: 'root',
|
||||
graph_name: null,
|
||||
title: 'Root Loader',
|
||||
type: 1,
|
||||
bgcolor: '#123456',
|
||||
},
|
||||
'subgraph-uuid:2': {
|
||||
id: 2,
|
||||
graph_id: 'subgraph-uuid',
|
||||
graph_name: 'Character Subgraph',
|
||||
title: 'Nested Loader',
|
||||
type: 1,
|
||||
bgcolor: '#654321',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
global.fetch = vi.fn().mockResolvedValue({
|
||||
json: async () => registryResponse,
|
||||
});
|
||||
|
||||
document.body.innerHTML = '<div id="nodeSelector"></div>';
|
||||
|
||||
const { sendLoraToWorkflow } = await import(UI_HELPERS_MODULE);
|
||||
|
||||
const result = await sendLoraToWorkflow('<lora:test:1>');
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(global.fetch).toHaveBeenCalledWith('/api/lm/get-registry');
|
||||
|
||||
const nodeLabels = Array.from(
|
||||
document.querySelectorAll('#nodeSelector .node-item[data-node-id] span')
|
||||
).map((span) => span.textContent.trim());
|
||||
|
||||
expect(nodeLabels).toEqual([
|
||||
'#1 Root Loader',
|
||||
'#2 (Character Subgraph) Nested Loader',
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -16,10 +16,12 @@ from aiohttp.test_utils import TestClient, TestServer
|
||||
from py.config import config
|
||||
from py.routes.base_model_routes import BaseModelRoutes
|
||||
from py.services import model_file_service
|
||||
from py.services.metadata_sync_service import MetadataSyncService
|
||||
from py.services.model_file_service import AutoOrganizeResult
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
from py.services.websocket_manager import ws_manager
|
||||
from py.utils.exif_utils import ExifUtils
|
||||
from py.utils.metadata_manager import MetadataManager
|
||||
|
||||
|
||||
class DummyRoutes(BaseModelRoutes):
|
||||
@@ -197,6 +199,116 @@ def test_replace_preview_writes_file_and_updates_cache(
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_fetch_civitai_hydrates_metadata_before_sync(
|
||||
mock_service,
|
||||
mock_scanner,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
):
|
||||
model_path = tmp_path / "hydrate.safetensors"
|
||||
model_path.write_bytes(b"model")
|
||||
metadata_path = tmp_path / "hydrate.metadata.json"
|
||||
|
||||
existing_metadata = {
|
||||
"file_path": str(model_path),
|
||||
"sha256": "abc123",
|
||||
"model_name": "Hydrated",
|
||||
"preview_url": "keep/me.png",
|
||||
"civitai": {
|
||||
"id": 99,
|
||||
"modelId": 42,
|
||||
"images": [{"url": "https://example.com/existing.png", "type": "image"}],
|
||||
"customImages": [{"id": "old-id", "url": "", "type": "image"}],
|
||||
"trainedWords": ["keep"],
|
||||
},
|
||||
"custom_field": "preserve",
|
||||
}
|
||||
metadata_path.write_text(json.dumps(existing_metadata), encoding="utf-8")
|
||||
|
||||
minimal_cache_entry = {
|
||||
"file_path": str(model_path),
|
||||
"sha256": "abc123",
|
||||
"folder": "some/folder",
|
||||
"civitai": {"id": 99, "modelId": 42},
|
||||
}
|
||||
mock_scanner._cache.raw_data = [minimal_cache_entry]
|
||||
|
||||
class FakeMetadata:
|
||||
def __init__(self, payload: dict) -> None:
|
||||
self._payload = payload
|
||||
self._unknown_fields = {"legacy_field": "legacy"}
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return self._payload.copy()
|
||||
|
||||
async def fake_load_metadata(path: str, *_args, **_kwargs):
|
||||
assert path == str(model_path)
|
||||
return FakeMetadata(existing_metadata), False
|
||||
|
||||
async def fake_save_metadata(path: str, metadata: dict) -> bool:
|
||||
save_calls.append((path, json.loads(json.dumps(metadata))))
|
||||
return True
|
||||
|
||||
async def fake_fetch_and_update_model(
|
||||
self,
|
||||
*,
|
||||
sha256: str,
|
||||
file_path: str,
|
||||
model_data: dict,
|
||||
update_cache_func,
|
||||
):
|
||||
captured["model_data"] = json.loads(json.dumps(model_data))
|
||||
to_save = model_data.copy()
|
||||
to_save.pop("folder", None)
|
||||
await MetadataManager.save_metadata(
|
||||
os.path.splitext(file_path)[0] + ".metadata.json",
|
||||
to_save,
|
||||
)
|
||||
await update_cache_func(file_path, file_path, model_data)
|
||||
return True, None
|
||||
|
||||
save_calls: list[tuple[str, dict]] = []
|
||||
captured: dict[str, dict] = {}
|
||||
|
||||
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load_metadata))
|
||||
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save_metadata))
|
||||
monkeypatch.setattr(MetadataSyncService, "fetch_and_update_model", fake_fetch_and_update_model)
|
||||
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/lm/test-models/fetch-civitai",
|
||||
json={"file_path": str(model_path)},
|
||||
)
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert payload["success"] is True
|
||||
assert captured["model_data"]["custom_field"] == "preserve"
|
||||
assert captured["model_data"]["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
|
||||
assert captured["model_data"]["civitai"]["trainedWords"] == ["keep"]
|
||||
assert captured["model_data"]["civitai"]["id"] == 99
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
assert save_calls, "Metadata save should be invoked"
|
||||
saved_path, saved_payload = save_calls[0]
|
||||
assert saved_path == str(metadata_path)
|
||||
assert saved_payload["custom_field"] == "preserve"
|
||||
assert saved_payload["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
|
||||
assert saved_payload["civitai"]["trainedWords"] == ["keep"]
|
||||
assert saved_payload["civitai"]["id"] == 99
|
||||
assert saved_payload["legacy_field"] == "legacy"
|
||||
|
||||
assert mock_scanner.updated_models
|
||||
updated_metadata = mock_scanner.updated_models[-1]["metadata"]
|
||||
assert updated_metadata["custom_field"] == "preserve"
|
||||
assert updated_metadata["civitai"]["customImages"][0]["id"] == "old-id"
|
||||
|
||||
|
||||
def test_download_model_invokes_download_manager(
|
||||
mock_service,
|
||||
download_manager_stub,
|
||||
|
||||
@@ -188,7 +188,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
|
||||
|
||||
monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"]))
|
||||
|
||||
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": ["node"]})
|
||||
request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": [{"node_id": "node", "graph_id": "graph-1"}]})
|
||||
|
||||
response = await routes.get_trigger_words(request)
|
||||
payload = json.loads(response.text)
|
||||
@@ -196,7 +196,7 @@ async def test_get_trigger_words_broadcasts(monkeypatch, routes):
|
||||
assert payload == {"success": True}
|
||||
send_mock.assert_called_once_with(
|
||||
"trigger_word_update",
|
||||
{"id": "node", "message": "trigger-one"},
|
||||
{"id": "node", "graph_id": "graph-1", "message": "trigger-one"},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,14 @@ from types import SimpleNamespace
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from py.routes.handlers.misc_handlers import SettingsHandler, ServiceRegistryAdapter
|
||||
from py.routes.handlers.misc_handlers import (
|
||||
LoraCodeHandler,
|
||||
ModelLibraryHandler,
|
||||
NodeRegistry,
|
||||
NodeRegistryHandler,
|
||||
ServiceRegistryAdapter,
|
||||
SettingsHandler,
|
||||
)
|
||||
from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar
|
||||
from py.routes.misc_routes import MiscRoutes
|
||||
|
||||
@@ -126,6 +133,128 @@ class FakePromptServer:
|
||||
instance = Instance()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_nodes_requires_graph_id():
|
||||
node_registry = NodeRegistry()
|
||||
handler = NodeRegistryHandler(
|
||||
node_registry=node_registry,
|
||||
prompt_server=FakePromptServer,
|
||||
standalone_mode=False,
|
||||
)
|
||||
|
||||
request = FakeRequest(json_data={"nodes": [{"node_id": 1}]})
|
||||
response = await handler.register_nodes(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
assert "graph_id" in payload["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_nodes_stores_graph_identifier():
|
||||
node_registry = NodeRegistry()
|
||||
handler = NodeRegistryHandler(
|
||||
node_registry=node_registry,
|
||||
prompt_server=FakePromptServer,
|
||||
standalone_mode=False,
|
||||
)
|
||||
|
||||
request = FakeRequest(
|
||||
json_data={
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": 7,
|
||||
"graph_id": "graph-123",
|
||||
"graph_name": "Character Subgraph",
|
||||
"type": "Lora Loader (LoraManager)",
|
||||
"title": "Loader",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
response = await handler.register_nodes(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
|
||||
registry = await node_registry.get_registry()
|
||||
assert registry["node_count"] == 1
|
||||
stored_node = next(iter(registry["nodes"].values()))
|
||||
assert stored_node["graph_id"] == "graph-123"
|
||||
assert stored_node["unique_id"] == "graph-123:7"
|
||||
assert stored_node["graph_name"] == "Character Subgraph"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_nodes_defaults_graph_name_to_none():
|
||||
node_registry = NodeRegistry()
|
||||
handler = NodeRegistryHandler(
|
||||
node_registry=node_registry,
|
||||
prompt_server=FakePromptServer,
|
||||
standalone_mode=False,
|
||||
)
|
||||
|
||||
request = FakeRequest(
|
||||
json_data={
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": 8,
|
||||
"graph_id": "root",
|
||||
"type": "Lora Loader (LoraManager)",
|
||||
"title": "Root Loader",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
response = await handler.register_nodes(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
|
||||
registry = await node_registry.get_registry()
|
||||
stored_node = next(iter(registry["nodes"].values()))
|
||||
assert stored_node["graph_name"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_lora_code_includes_graph_identifier():
|
||||
send_calls: list[tuple[str, dict]] = []
|
||||
|
||||
class RecordingPromptServer:
|
||||
class Instance:
|
||||
def send_sync(self, event, payload):
|
||||
send_calls.append((event, payload))
|
||||
|
||||
instance = Instance()
|
||||
|
||||
handler = LoraCodeHandler(RecordingPromptServer)
|
||||
|
||||
request = FakeRequest(
|
||||
json_data={
|
||||
"node_ids": [{"node_id": 3, "graph_id": "graph-A"}],
|
||||
"lora_code": "<lora>",
|
||||
"mode": "replace",
|
||||
}
|
||||
)
|
||||
|
||||
response = await handler.update_lora_code(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert payload["results"] == [
|
||||
{"node_id": 3, "graph_id": "graph-A", "success": True}
|
||||
]
|
||||
assert send_calls == [
|
||||
(
|
||||
"lora_code_update",
|
||||
{"id": 3, "graph_id": "graph-A", "lora_code": "<lora>", "mode": "replace"},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class FakeScanner:
|
||||
async def check_model_version_exists(self, _version_id):
|
||||
return False
|
||||
@@ -138,10 +267,34 @@ async def fake_scanner_factory():
|
||||
return FakeScanner()
|
||||
|
||||
|
||||
class FakeExistenceScanner:
|
||||
def __init__(self, existing=None):
|
||||
self._existing = set(existing or [])
|
||||
|
||||
async def check_model_version_exists(self, version_id):
|
||||
return version_id in self._existing
|
||||
|
||||
async def get_model_versions_by_id(self, _model_id):
|
||||
return []
|
||||
|
||||
|
||||
class FakeMetadataProvider:
|
||||
async def get_model_versions(self, _model_id):
|
||||
return {"modelVersions": [], "name": "", "type": "lora"}
|
||||
|
||||
async def get_user_models(self, _username):
|
||||
return []
|
||||
|
||||
|
||||
class FakeUserModelsProvider(FakeMetadataProvider):
|
||||
def __init__(self, models):
|
||||
self.models = models
|
||||
self.received_usernames: list[str] = []
|
||||
|
||||
async def get_user_models(self, username):
|
||||
self.received_usernames.append(username)
|
||||
return self.models
|
||||
|
||||
|
||||
async def fake_metadata_provider_factory():
|
||||
return FakeMetadataProvider()
|
||||
@@ -211,6 +364,250 @@ async def test_misc_routes_bind_produces_expected_handlers():
|
||||
assert set(mapping.keys()) == expected_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_civitai_user_models_marks_library_versions():
|
||||
models = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "Model A",
|
||||
"type": "LORA",
|
||||
"tags": ["style"],
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 100,
|
||||
"name": "v1",
|
||||
"baseModel": "Flux.1",
|
||||
"images": [{"url": "http://example.com/a1.jpg"}],
|
||||
},
|
||||
{
|
||||
"id": 101,
|
||||
"name": "v2",
|
||||
"baseModel": "Flux.1",
|
||||
"images": [{"url": "http://example.com/a2.jpg"}],
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "Embedding",
|
||||
"type": "TextualInversion",
|
||||
"tags": ["embedding"],
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 200,
|
||||
"name": "v1",
|
||||
"baseModel": None,
|
||||
"images": [{"url": "http://example.com/e1.jpg"}],
|
||||
},
|
||||
{
|
||||
"id": 202,
|
||||
"name": "v2",
|
||||
"baseModel": None,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"name": "Checkpoint",
|
||||
"type": "Checkpoint",
|
||||
"tags": ["checkpoint"],
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 300,
|
||||
"name": "v1",
|
||||
"baseModel": "SDXL",
|
||||
"images": [],
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"name": "Unsupported",
|
||||
"type": "Other",
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 400,
|
||||
"name": "v1",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
provider = FakeUserModelsProvider(models)
|
||||
|
||||
async def provider_factory():
|
||||
return provider
|
||||
|
||||
lora_scanner = FakeExistenceScanner({101})
|
||||
checkpoint_scanner = FakeExistenceScanner()
|
||||
embedding_scanner = FakeExistenceScanner({202})
|
||||
|
||||
async def lora_factory():
|
||||
return lora_scanner
|
||||
|
||||
async def checkpoint_factory():
|
||||
return checkpoint_scanner
|
||||
|
||||
async def embedding_factory():
|
||||
return embedding_scanner
|
||||
|
||||
handler = ModelLibraryHandler(
|
||||
ServiceRegistryAdapter(
|
||||
get_lora_scanner=lora_factory,
|
||||
get_checkpoint_scanner=checkpoint_factory,
|
||||
get_embedding_scanner=embedding_factory,
|
||||
),
|
||||
metadata_provider_factory=provider_factory,
|
||||
)
|
||||
|
||||
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert payload["username"] == "pixel"
|
||||
assert payload["versions"] == [
|
||||
{
|
||||
"modelId": 1,
|
||||
"versionId": 100,
|
||||
"modelName": "Model A",
|
||||
"versionName": "v1",
|
||||
"type": "LORA",
|
||||
"tags": ["style"],
|
||||
"baseModel": "Flux.1",
|
||||
"thumbnailUrl": "http://example.com/a1.jpg",
|
||||
"inLibrary": False,
|
||||
},
|
||||
{
|
||||
"modelId": 1,
|
||||
"versionId": 101,
|
||||
"modelName": "Model A",
|
||||
"versionName": "v2",
|
||||
"type": "LORA",
|
||||
"tags": ["style"],
|
||||
"baseModel": "Flux.1",
|
||||
"thumbnailUrl": "http://example.com/a2.jpg",
|
||||
"inLibrary": True,
|
||||
},
|
||||
{
|
||||
"modelId": 2,
|
||||
"versionId": 200,
|
||||
"modelName": "Embedding",
|
||||
"versionName": "v1",
|
||||
"type": "TextualInversion",
|
||||
"tags": ["embedding"],
|
||||
"baseModel": None,
|
||||
"thumbnailUrl": "http://example.com/e1.jpg",
|
||||
"inLibrary": False,
|
||||
},
|
||||
{
|
||||
"modelId": 2,
|
||||
"versionId": 202,
|
||||
"modelName": "Embedding",
|
||||
"versionName": "v2",
|
||||
"type": "TextualInversion",
|
||||
"tags": ["embedding"],
|
||||
"baseModel": None,
|
||||
"thumbnailUrl": None,
|
||||
"inLibrary": True,
|
||||
},
|
||||
{
|
||||
"modelId": 3,
|
||||
"versionId": 300,
|
||||
"modelName": "Checkpoint",
|
||||
"versionName": "v1",
|
||||
"type": "Checkpoint",
|
||||
"tags": ["checkpoint"],
|
||||
"baseModel": "SDXL",
|
||||
"thumbnailUrl": None,
|
||||
"inLibrary": False,
|
||||
},
|
||||
]
|
||||
|
||||
assert provider.received_usernames == ["pixel"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_civitai_user_models_rewrites_civitai_previews():
|
||||
image_url = "https://image.civitai.com/container/example/original=true/sample.jpeg"
|
||||
video_url = "https://image.civitai.com/container/example/original=true/sample.mp4"
|
||||
|
||||
models = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "Model A",
|
||||
"type": "LORA",
|
||||
"tags": ["style"],
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 100,
|
||||
"name": "preview-image",
|
||||
"baseModel": "Flux.1",
|
||||
"images": [
|
||||
{"url": image_url, "type": "image"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 101,
|
||||
"name": "preview-video",
|
||||
"baseModel": "Flux.1",
|
||||
"images": [
|
||||
{"url": video_url, "type": "video"},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
provider = FakeUserModelsProvider(models)
|
||||
|
||||
async def provider_factory():
|
||||
return provider
|
||||
|
||||
handler = ModelLibraryHandler(
|
||||
ServiceRegistryAdapter(
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
),
|
||||
metadata_provider_factory=provider_factory,
|
||||
)
|
||||
|
||||
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
previews_by_version = {item["versionId"]: item["thumbnailUrl"] for item in payload["versions"]}
|
||||
assert previews_by_version[100] == "https://image.civitai.com/container/example/width=450,optimized=true/sample.jpeg"
|
||||
assert (
|
||||
previews_by_version[101]
|
||||
== "https://image.civitai.com/container/example/transcode=true,width=450,optimized=true/sample.mp4"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_civitai_user_models_requires_username():
|
||||
provider = FakeUserModelsProvider([])
|
||||
|
||||
async def provider_factory():
|
||||
return provider
|
||||
|
||||
handler = ModelLibraryHandler(
|
||||
ServiceRegistryAdapter(
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
),
|
||||
metadata_provider_factory=provider_factory,
|
||||
)
|
||||
|
||||
response = await handler.get_civitai_user_models(FakeRequest())
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
assert "username" in payload["error"].lower()
|
||||
|
||||
|
||||
def test_ensure_handler_mapping_caches_result():
|
||||
call_records = []
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
@@ -169,6 +170,158 @@ async def test_get_model_version_by_version_id(monkeypatch, downloader):
|
||||
assert result["images"][0]["meta"]["other"] == "keep"
|
||||
|
||||
|
||||
async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypatch, downloader):
|
||||
requests = []
|
||||
|
||||
model_payload = {
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 7,
|
||||
"files": [
|
||||
{
|
||||
"type": "Model",
|
||||
"primary": True,
|
||||
"hashes": {"SHA256": "hash7"},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
"description": "desc",
|
||||
"tags": ["tag"],
|
||||
"creator": {"username": "user"},
|
||||
"name": "Model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"poi": False,
|
||||
}
|
||||
|
||||
version_payload = {
|
||||
"id": 7,
|
||||
"modelId": 99,
|
||||
"model": {},
|
||||
"files": [],
|
||||
"images": [],
|
||||
}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
requests.append(url)
|
||||
if url.endswith("/models/99"):
|
||||
return True, copy.deepcopy(model_payload)
|
||||
if url.endswith("/model-versions/7"):
|
||||
return True, copy.deepcopy(version_payload)
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_model_version(model_id=99, version_id=7)
|
||||
|
||||
assert result["id"] == 7
|
||||
assert result["model"]["description"] == "desc"
|
||||
assert result["model"]["tags"] == ["tag"]
|
||||
assert result["creator"] == {"username": "user"}
|
||||
assert requests[0].endswith("/models/99")
|
||||
assert requests[1].endswith("/model-versions/7")
|
||||
|
||||
|
||||
async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, downloader):
|
||||
requests = []
|
||||
|
||||
model_payload = {
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 7,
|
||||
"files": [
|
||||
{
|
||||
"type": "Model",
|
||||
"primary": True,
|
||||
"hashes": {"SHA256": "hash7"},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
"description": "desc",
|
||||
"tags": ["tag"],
|
||||
"creator": {"username": "user"},
|
||||
"name": "Model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"poi": False,
|
||||
}
|
||||
|
||||
version_payload = {
|
||||
"id": 7,
|
||||
"modelId": 99,
|
||||
"files": [],
|
||||
"images": [],
|
||||
}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
requests.append(url)
|
||||
if url.endswith("/models/99"):
|
||||
return True, copy.deepcopy(model_payload)
|
||||
if url.endswith("/model-versions/7"):
|
||||
return False, "boom"
|
||||
if url.endswith("/model-versions/by-hash/hash7"):
|
||||
return True, copy.deepcopy(version_payload)
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_model_version(model_id=99, version_id=7)
|
||||
|
||||
assert result["id"] == 7
|
||||
assert result["model"]["description"] == "desc"
|
||||
assert result["model"]["tags"] == ["tag"]
|
||||
assert result["creator"] == {"username": "user"}
|
||||
assert requests[1].endswith("/model-versions/7")
|
||||
assert requests[2].endswith("/model-versions/by-hash/hash7")
|
||||
|
||||
|
||||
async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatch, downloader):
|
||||
model_payload = {
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 7,
|
||||
"files": [],
|
||||
"name": "v1",
|
||||
}
|
||||
],
|
||||
"description": "desc",
|
||||
"tags": ["tag"],
|
||||
"creator": {"username": "user"},
|
||||
"name": "Model",
|
||||
"type": "LORA",
|
||||
"nsfw": False,
|
||||
"poi": False,
|
||||
}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=True):
|
||||
if url.endswith("/models/99"):
|
||||
return True, copy.deepcopy(model_payload)
|
||||
if url.endswith("/model-versions/7"):
|
||||
return False, "boom"
|
||||
if "/model-versions/by-hash/" in url:
|
||||
return False, "boom"
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_model_version(model_id=99, version_id=7)
|
||||
|
||||
assert result["modelId"] == 99
|
||||
assert result["model"]["name"] == "Model"
|
||||
assert result["model"]["type"] == "LORA"
|
||||
assert result["model"]["description"] == "desc"
|
||||
assert result["model"]["tags"] == ["tag"]
|
||||
assert result["creator"] == {"username": "user"}
|
||||
|
||||
|
||||
async def test_get_model_version_requires_identifier(monkeypatch, downloader):
|
||||
client = await CivitaiClient.get_instance()
|
||||
result = await client.get_model_version()
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
from py.services.download_manager import DownloadManager
|
||||
from py.services import download_manager
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
from py.services.settings_manager import settings
|
||||
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||
from py.utils.metadata_manager import MetadataManager
|
||||
|
||||
|
||||
@@ -23,7 +23,8 @@ def reset_download_manager():
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolate_settings(monkeypatch, tmp_path):
|
||||
"""Point settings writes at a temporary directory to avoid touching real files."""
|
||||
default_settings = settings._get_default_settings()
|
||||
manager = get_settings_manager()
|
||||
default_settings = manager._get_default_settings()
|
||||
default_settings.update(
|
||||
{
|
||||
"default_lora_root": str(tmp_path),
|
||||
@@ -37,8 +38,8 @@ def isolate_settings(monkeypatch, tmp_path):
|
||||
"base_model_path_mappings": {"BaseModel": "MappedModel"},
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(settings, "settings", default_settings)
|
||||
monkeypatch.setattr(type(settings), "_save_settings", lambda self: None)
|
||||
monkeypatch.setattr(manager, "settings", default_settings)
|
||||
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -187,7 +188,7 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata
|
||||
assert manager._active_downloads[result["download_id"]]["status"] == "completed"
|
||||
|
||||
assert captured["relative_path"] == "MappedModel/fantasy"
|
||||
expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy"
|
||||
expected_dir = Path(get_settings_manager().get("default_lora_root")) / "MappedModel" / "fantasy"
|
||||
assert captured["save_dir"] == expected_dir
|
||||
assert captured["model_type"] == "lora"
|
||||
assert captured["download_urls"] == [
|
||||
@@ -393,3 +394,98 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
||||
assert result == {"success": True}
|
||||
assert [url for url, *_ in dummy_downloader.calls] == download_urls
|
||||
assert dummy_scanner.calls # ensure cache updated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
target_path = save_dir / "file.safetensors"
|
||||
|
||||
manager._active_downloads["dl"] = {}
|
||||
|
||||
class DummyMetadata:
|
||||
def __init__(self, path: Path):
|
||||
self.file_path = str(path)
|
||||
self.sha256 = "sha256"
|
||||
self.file_name = path.stem
|
||||
self.preview_url = None
|
||||
self.preview_nsfw_level = None
|
||||
|
||||
def generate_unique_filename(self, *_args, **_kwargs):
|
||||
return os.path.basename(self.file_path)
|
||||
|
||||
def update_file_info(self, _path):
|
||||
return None
|
||||
|
||||
def to_dict(self):
|
||||
return {"file_path": self.file_path}
|
||||
|
||||
metadata = DummyMetadata(target_path)
|
||||
version_info = {
|
||||
"images": [
|
||||
{
|
||||
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
|
||||
"type": "image",
|
||||
"nsfwLevel": 2,
|
||||
}
|
||||
]
|
||||
}
|
||||
download_urls = ["https://example.invalid/file.safetensors"]
|
||||
|
||||
class DummyDownloader:
|
||||
def __init__(self):
|
||||
self.file_calls: list[tuple[str, str]] = []
|
||||
self.memory_calls = 0
|
||||
|
||||
async def download_file(self, url, path, progress_callback=None, use_auth=None):
|
||||
self.file_calls.append((url, path))
|
||||
if url.endswith(".jpeg"):
|
||||
Path(path).write_bytes(b"preview")
|
||||
return True, None
|
||||
if url.endswith(".safetensors"):
|
||||
Path(path).write_bytes(b"model")
|
||||
return True, None
|
||||
return False, "unexpected url"
|
||||
|
||||
async def download_to_memory(self, *_args, **_kwargs):
|
||||
self.memory_calls += 1
|
||||
return False, b"", {}
|
||||
|
||||
dummy_downloader = DummyDownloader()
|
||||
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader))
|
||||
|
||||
optimize_called = {"value": False}
|
||||
|
||||
def fake_optimize_image(**_kwargs):
|
||||
optimize_called["value"] = True
|
||||
return b"", {}
|
||||
|
||||
monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(fake_optimize_image))
|
||||
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||
|
||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
|
||||
|
||||
result = await manager._execute_download(
|
||||
download_urls=download_urls,
|
||||
save_dir=str(save_dir),
|
||||
metadata=metadata,
|
||||
version_info=version_info,
|
||||
relative_path="",
|
||||
progress_callback=None,
|
||||
model_type="lora",
|
||||
download_id="dl",
|
||||
)
|
||||
|
||||
assert result == {"success": True}
|
||||
preview_urls = [url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg")]
|
||||
assert any("width=450,optimized=true" in url for url in preview_urls)
|
||||
assert dummy_downloader.memory_calls == 0
|
||||
assert optimize_called["value"] is False
|
||||
assert metadata.preview_url.endswith(".jpeg")
|
||||
assert metadata.preview_nsfw_level == 2
|
||||
stored_preview = manager._active_downloads["dl"]["preview_path"]
|
||||
assert stored_preview.endswith(".jpeg")
|
||||
assert Path(stored_preview).exists()
|
||||
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
|
||||
from py.services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
from py.services.settings_manager import settings
|
||||
from py.services.settings_manager import get_settings_manager
|
||||
|
||||
|
||||
class StubScanner:
|
||||
@@ -21,8 +21,9 @@ class StubScanner:
|
||||
async def test_cleanup_moves_empty_and_orphaned(tmp_path, monkeypatch):
|
||||
service = ExampleImagesCleanupService()
|
||||
|
||||
previous_path = settings.get('example_images_path')
|
||||
settings.settings['example_images_path'] = str(tmp_path)
|
||||
settings_manager = get_settings_manager()
|
||||
previous_path = settings_manager.get('example_images_path')
|
||||
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||
|
||||
try:
|
||||
empty_folder = tmp_path / 'empty_folder'
|
||||
@@ -64,23 +65,24 @@ async def test_cleanup_moves_empty_and_orphaned(tmp_path, monkeypatch):
|
||||
|
||||
finally:
|
||||
if previous_path is None:
|
||||
settings.settings.pop('example_images_path', None)
|
||||
settings_manager.settings.pop('example_images_path', None)
|
||||
else:
|
||||
settings.settings['example_images_path'] = previous_path
|
||||
settings_manager.settings['example_images_path'] = previous_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_handles_missing_path(monkeypatch):
|
||||
service = ExampleImagesCleanupService()
|
||||
|
||||
previous_path = settings.get('example_images_path')
|
||||
settings.settings.pop('example_images_path', None)
|
||||
settings_manager = get_settings_manager()
|
||||
previous_path = settings_manager.get('example_images_path')
|
||||
settings_manager.settings.pop('example_images_path', None)
|
||||
|
||||
try:
|
||||
result = await service.cleanup_example_image_folders()
|
||||
finally:
|
||||
if previous_path is not None:
|
||||
settings.settings['example_images_path'] = previous_path
|
||||
settings_manager.settings['example_images_path'] = previous_path
|
||||
|
||||
assert result['success'] is False
|
||||
assert result['error_code'] == 'path_not_configured'
|
||||
|
||||
@@ -7,7 +7,7 @@ from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.settings_manager import settings
|
||||
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||
from py.utils import example_images_download_manager as download_module
|
||||
|
||||
|
||||
@@ -43,11 +43,15 @@ def _patch_scanner(monkeypatch: pytest.MonkeyPatch, scanner: StubScanner) -> Non
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
async def test_start_download_rejects_parallel_runs(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
model = {
|
||||
"sha256": "abc123",
|
||||
@@ -106,11 +110,15 @@ async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPa
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
async def test_pause_resume_blocks_processing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
models = [
|
||||
{
|
||||
@@ -231,13 +239,17 @@ async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, t
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
async def test_legacy_folder_migrated_and_skipped(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
|
||||
monkeypatch.setitem(settings.settings, "active_library", "extra")
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings_manager.settings, "libraries", {"default": {}, "extra": {}})
|
||||
monkeypatch.setitem(settings_manager.settings, "active_library", "extra")
|
||||
|
||||
model_hash = "d" * 64
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
@@ -310,13 +322,17 @@ async def test_legacy_folder_migrated_and_skipped(monkeypatch: pytest.MonkeyPatc
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
async def test_legacy_progress_file_migrates(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings.settings, "libraries", {"default": {}, "extra": {}})
|
||||
monkeypatch.setitem(settings.settings, "active_library", "extra")
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings_manager.settings, "libraries", {"default": {}, "extra": {}})
|
||||
monkeypatch.setitem(settings_manager.settings, "active_library", "extra")
|
||||
|
||||
model_hash = "e" * 64
|
||||
model_path = tmp_path / "model-two.safetensors"
|
||||
@@ -380,20 +396,24 @@ async def test_legacy_progress_file_migrates(monkeypatch: pytest.MonkeyPatch, tm
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_download_remains_in_initial_library(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
async def test_download_remains_in_initial_library(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings.settings, "libraries", {"LibraryA": {}, "LibraryB": {}})
|
||||
monkeypatch.setitem(settings.settings, "active_library", "LibraryA")
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
monkeypatch.setitem(settings_manager.settings, "libraries", {"LibraryA": {}, "LibraryB": {}})
|
||||
monkeypatch.setitem(settings_manager.settings, "active_library", "LibraryA")
|
||||
|
||||
state = {"active": "LibraryA"}
|
||||
|
||||
def fake_get_active_library_name(self):
|
||||
return state["active"]
|
||||
|
||||
monkeypatch.setattr(type(settings), "get_active_library_name", fake_get_active_library_name)
|
||||
monkeypatch.setattr(SettingsManager, "get_active_library_name", fake_get_active_library_name)
|
||||
|
||||
model_hash = "f" * 64
|
||||
model_path = tmp_path / "example-model.safetensors"
|
||||
@@ -454,3 +474,7 @@ async def test_download_remains_in_initial_library(monkeypatch: pytest.MonkeyPat
|
||||
assert (model_dir / "local.txt").exists()
|
||||
assert not (library_b_root / ".download_progress.json").exists()
|
||||
assert not (library_b_root / model_hash).exists()
|
||||
|
||||
@pytest.fixture
|
||||
def settings_manager():
|
||||
return get_settings_manager()
|
||||
|
||||
@@ -243,6 +243,7 @@ async def test_initialize_in_background_uses_persisted_cache_without_full_scan(t
|
||||
cache = await scanner.get_cached_data()
|
||||
assert len(cache.raw_data) == 1
|
||||
assert cache.raw_data[0]['file_path'] == normalized
|
||||
assert cache.version_index[11]['file_path'] == normalized
|
||||
|
||||
assert scanner._hash_index.get_path('hash-one') == normalized
|
||||
|
||||
@@ -301,6 +302,7 @@ async def test_load_persisted_cache_populates_cache(tmp_path: Path, monkeypatch)
|
||||
assert entry['file_path'] == normalized
|
||||
assert entry['tags'] == ['alpha']
|
||||
assert entry['civitai']['trainedWords'] == ['abc']
|
||||
assert cache.version_index[11]['file_path'] == normalized
|
||||
assert scanner._hash_index.get_path('hash-one') == normalized
|
||||
assert scanner._tags_count == {'alpha': 1}
|
||||
assert ws_stub.payloads[-1]['stage'] == 'loading_cache'
|
||||
@@ -381,6 +383,66 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch):
|
||||
assert remaining == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_version_index_tracks_version_ids(tmp_path: Path):
|
||||
scanner = DummyScanner(tmp_path)
|
||||
|
||||
first_path = _normalize_path(tmp_path / 'alpha.txt')
|
||||
second_path = _normalize_path(tmp_path / 'beta.txt')
|
||||
|
||||
first_entry = {
|
||||
'file_path': first_path,
|
||||
'file_name': 'alpha',
|
||||
'model_name': 'alpha',
|
||||
'folder': '',
|
||||
'size': 1,
|
||||
'modified': 1.0,
|
||||
'sha256': 'hash-alpha',
|
||||
'tags': [],
|
||||
'civitai': {'id': 101, 'modelId': 1, 'name': 'alpha'},
|
||||
}
|
||||
|
||||
second_entry = {
|
||||
'file_path': second_path,
|
||||
'file_name': 'beta',
|
||||
'model_name': 'beta',
|
||||
'folder': '',
|
||||
'size': 1,
|
||||
'modified': 1.0,
|
||||
'sha256': 'hash-beta',
|
||||
'tags': [],
|
||||
'civitai': {'id': 202, 'modelId': 2, 'name': 'beta'},
|
||||
}
|
||||
|
||||
hash_index = ModelHashIndex()
|
||||
hash_index.add_entry('hash-alpha', first_path)
|
||||
hash_index.add_entry('hash-beta', second_path)
|
||||
|
||||
scan_result = CacheBuildResult(
|
||||
raw_data=[first_entry, second_entry],
|
||||
hash_index=hash_index,
|
||||
tags_count={},
|
||||
excluded_models=[],
|
||||
)
|
||||
|
||||
await scanner._apply_scan_result(scan_result)
|
||||
|
||||
cache = await scanner.get_cached_data()
|
||||
assert cache.version_index[101]['file_path'] == first_path
|
||||
assert cache.version_index[202]['file_path'] == second_path
|
||||
|
||||
assert await scanner.check_model_version_exists(101) is True
|
||||
assert await scanner.check_model_version_exists('202') is True
|
||||
assert await scanner.check_model_version_exists(999) is False
|
||||
|
||||
removed = await scanner._batch_update_cache_for_deleted_models([first_path])
|
||||
assert removed is True
|
||||
|
||||
cache_after = await scanner.get_cached_data()
|
||||
assert 101 not in cache_after.version_index
|
||||
assert await scanner.check_model_version_exists(101) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path):
|
||||
first, _, _ = _create_files(tmp_path)
|
||||
|
||||
182
tests/services/test_preview_asset_service.py
Normal file
182
tests/services/test_preview_asset_service.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.preview_asset_service import PreviewAssetService
|
||||
|
||||
|
||||
class StubMetadataManager:
|
||||
async def save_metadata(self, *_args: Any, **_kwargs: Any) -> bool: # pragma: no cover - helper
|
||||
return True
|
||||
|
||||
|
||||
class RecordingExifUtils:
|
||||
def __init__(self) -> None:
|
||||
self.called = False
|
||||
|
||||
def optimize_image(self, **kwargs):
|
||||
self.called = True
|
||||
return kwargs["image_data"], {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_preview_prefers_rewritten_civitai_image(tmp_path):
|
||||
metadata_path = tmp_path / "model.metadata.json"
|
||||
metadata_path.write_text("{}")
|
||||
local_metadata: dict[str, Any] = {}
|
||||
|
||||
class Downloader:
|
||||
def __init__(self):
|
||||
self.file_calls: list[tuple[str, str]] = []
|
||||
self.memory_calls = 0
|
||||
|
||||
async def download_file(self, url, path, use_auth=False):
|
||||
self.file_calls.append((url, path))
|
||||
if "width=450,optimized=true" in url:
|
||||
Path(path).write_bytes(b"image-data")
|
||||
return True, None
|
||||
return False, "fail"
|
||||
|
||||
async def download_to_memory(self, *_args, **_kwargs):
|
||||
self.memory_calls += 1
|
||||
return False, b"", {}
|
||||
|
||||
downloader = Downloader()
|
||||
|
||||
async def downloader_factory():
|
||||
return downloader
|
||||
|
||||
exif_utils = RecordingExifUtils()
|
||||
service = PreviewAssetService(
|
||||
metadata_manager=StubMetadataManager(),
|
||||
downloader_factory=downloader_factory,
|
||||
exif_utils=exif_utils,
|
||||
)
|
||||
|
||||
images = [
|
||||
{
|
||||
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
|
||||
"type": "image",
|
||||
"nsfwLevel": 3,
|
||||
}
|
||||
]
|
||||
|
||||
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||
|
||||
assert downloader.memory_calls == 0
|
||||
assert exif_utils.called is False
|
||||
assert len(downloader.file_calls) == 1
|
||||
assert "width=450,optimized=true" in downloader.file_calls[0][0]
|
||||
preview_path = Path(local_metadata["preview_url"])
|
||||
assert preview_path.exists()
|
||||
assert preview_path.suffix == ".jpeg"
|
||||
assert local_metadata["preview_nsfw_level"] == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_preview_falls_back_to_webp_when_rewrite_fails(tmp_path):
|
||||
metadata_path = tmp_path / "model.metadata.json"
|
||||
metadata_path.write_text("{}")
|
||||
local_metadata: dict[str, Any] = {}
|
||||
|
||||
class Downloader:
|
||||
def __init__(self):
|
||||
self.file_calls: list[tuple[str, str]] = []
|
||||
self.memory_calls = 0
|
||||
|
||||
async def download_file(self, url, path, use_auth=False):
|
||||
self.file_calls.append((url, path))
|
||||
return False, "fail"
|
||||
|
||||
async def download_to_memory(self, *_args, **_kwargs):
|
||||
self.memory_calls += 1
|
||||
return True, b"raw-image", {}
|
||||
|
||||
downloader = Downloader()
|
||||
|
||||
async def downloader_factory():
|
||||
return downloader
|
||||
|
||||
class ExifUtils:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
def optimize_image(self, **kwargs):
|
||||
self.calls += 1
|
||||
return b"webp-data", {}
|
||||
|
||||
exif_utils = ExifUtils()
|
||||
|
||||
service = PreviewAssetService(
|
||||
metadata_manager=StubMetadataManager(),
|
||||
downloader_factory=downloader_factory,
|
||||
exif_utils=exif_utils,
|
||||
)
|
||||
|
||||
images = [
|
||||
{
|
||||
"url": "https://image.civitai.com/container/example/original=true/sample.png",
|
||||
"type": "image",
|
||||
"nsfwLevel": 1,
|
||||
}
|
||||
]
|
||||
|
||||
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||
|
||||
assert downloader.memory_calls == 1
|
||||
assert exif_utils.calls == 1
|
||||
preview_path = Path(local_metadata["preview_url"])
|
||||
assert preview_path.exists()
|
||||
assert preview_path.suffix == ".webp"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_preview_rewrites_civitai_video(tmp_path):
|
||||
metadata_path = tmp_path / "model.metadata.json"
|
||||
metadata_path.write_text("{}")
|
||||
local_metadata: dict[str, Any] = {}
|
||||
|
||||
class Downloader:
|
||||
def __init__(self):
|
||||
self.file_calls: list[tuple[str, str]] = []
|
||||
|
||||
async def download_file(self, url, path, use_auth=False):
|
||||
self.file_calls.append((url, path))
|
||||
if "transcode=true,width=450,optimized=true" in url:
|
||||
Path(path).write_bytes(b"video-data")
|
||||
return True, None
|
||||
if url.endswith(".mp4"):
|
||||
return False, "fail"
|
||||
return False, "unexpected"
|
||||
|
||||
async def download_to_memory(self, *_args, **_kwargs):
|
||||
pytest.fail("download_to_memory should not be used for video previews")
|
||||
|
||||
downloader = Downloader()
|
||||
|
||||
async def downloader_factory():
|
||||
return downloader
|
||||
|
||||
service = PreviewAssetService(
|
||||
metadata_manager=StubMetadataManager(),
|
||||
downloader_factory=downloader_factory,
|
||||
exif_utils=RecordingExifUtils(),
|
||||
)
|
||||
|
||||
images = [
|
||||
{
|
||||
"url": "https://image.civitai.com/container/example/original=true/sample.mp4",
|
||||
"type": "video",
|
||||
"nsfwLevel": 2,
|
||||
}
|
||||
]
|
||||
|
||||
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||
|
||||
assert len(downloader.file_calls) >= 1
|
||||
assert any("transcode=true,width=450,optimized=true" in url for url, _ in downloader.file_calls)
|
||||
preview_path = Path(local_metadata["preview_url"])
|
||||
assert preview_path.exists()
|
||||
assert preview_path.suffix == ".mp4"
|
||||
assert local_metadata["preview_nsfw_level"] == 2
|
||||
@@ -28,6 +28,7 @@ from py.utils.example_images_processor import (
|
||||
ExampleImagesImportError,
|
||||
ExampleImagesValidationError,
|
||||
)
|
||||
from py.utils.metadata_manager import MetadataManager
|
||||
from tests.conftest import MockModelService, MockScanner
|
||||
|
||||
|
||||
@@ -155,7 +156,9 @@ async def test_auto_organize_use_case_rejects_when_running() -> None:
|
||||
await use_case.execute(file_paths=None, progress_callback=None)
|
||||
|
||||
|
||||
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
||||
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
scanner = MockScanner()
|
||||
scanner._cache.raw_data = [
|
||||
{
|
||||
@@ -170,6 +173,25 @@ async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
||||
settings = StubSettings()
|
||||
progress = ProgressCollector()
|
||||
|
||||
hydration_calls: list[str] = []
|
||||
|
||||
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
hydration_calls.append(model_data.get("file_path", ""))
|
||||
model_data.clear()
|
||||
model_data.update(
|
||||
{
|
||||
"file_path": "model1.safetensors",
|
||||
"sha256": "hash",
|
||||
"from_civitai": True,
|
||||
"model_name": "Demo",
|
||||
"extra": "value",
|
||||
"civitai": {"images": [{"url": "existing.png", "type": "image"}]},
|
||||
}
|
||||
)
|
||||
return model_data
|
||||
|
||||
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
|
||||
|
||||
use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=metadata_sync,
|
||||
@@ -183,6 +205,9 @@ async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
||||
assert progress.events[0]["status"] == "started"
|
||||
assert progress.events[-1]["status"] == "completed"
|
||||
assert metadata_sync.calls
|
||||
assert metadata_sync.calls[0]["model_data"]["extra"] == "value"
|
||||
assert scanner._cache.raw_data[0]["extra"] == "value"
|
||||
assert hydration_calls == ["model1.safetensors"]
|
||||
assert scanner._cache.resort_calls == 1
|
||||
|
||||
|
||||
@@ -314,4 +339,4 @@ async def test_import_example_images_use_case_propagates_generic_error() -> None
|
||||
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
|
||||
|
||||
with pytest.raises(ExampleImagesImportError):
|
||||
await use_case.execute(request)
|
||||
await use_case.execute(request)
|
||||
@@ -5,7 +5,7 @@ from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.settings_manager import settings
|
||||
from py.services.settings_manager import get_settings_manager
|
||||
from py.utils import example_images_download_manager as download_module
|
||||
|
||||
|
||||
@@ -19,19 +19,21 @@ class RecordingWebSocketManager:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def restore_settings() -> None:
|
||||
original = settings.settings.copy()
|
||||
manager = get_settings_manager()
|
||||
original = manager.settings.copy()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
settings.settings.clear()
|
||||
settings.settings.update(original)
|
||||
manager.settings.clear()
|
||||
manager.settings.update(original)
|
||||
|
||||
|
||||
async def test_start_download_requires_configured_path(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
manager = download_module.DownloadManager(ws_manager=RecordingWebSocketManager())
|
||||
|
||||
# Ensure example_images_path is not configured
|
||||
settings.settings.pop('example_images_path', None)
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings.pop('example_images_path', None)
|
||||
|
||||
with pytest.raises(download_module.DownloadConfigurationError) as exc_info:
|
||||
await manager.start_download({})
|
||||
@@ -44,9 +46,10 @@ async def test_start_download_requires_configured_path(monkeypatch: pytest.Monke
|
||||
|
||||
|
||||
async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||
settings.settings["example_images_path"] = str(tmp_path)
|
||||
settings.settings["libraries"] = {"default": {}}
|
||||
settings.settings["active_library"] = "default"
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||
settings_manager.settings["libraries"] = {"default": {}}
|
||||
settings_manager.settings["active_library"] = "default"
|
||||
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
@@ -84,9 +87,10 @@ async def test_start_download_bootstraps_progress_and_task(monkeypatch: pytest.M
|
||||
|
||||
|
||||
async def test_pause_and_resume_flow(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||
settings.settings["example_images_path"] = str(tmp_path)
|
||||
settings.settings["libraries"] = {"default": {}}
|
||||
settings.settings["active_library"] = "default"
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||
settings_manager.settings["libraries"] = {"default": {}}
|
||||
settings_manager.settings["active_library"] = "default"
|
||||
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.settings_manager import settings
|
||||
from py.services.settings_manager import get_settings_manager
|
||||
from py.utils.example_images_file_manager import ExampleImagesFileManager
|
||||
|
||||
|
||||
@@ -22,16 +22,18 @@ class JsonRequest:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def restore_settings() -> None:
|
||||
original = settings.settings.copy()
|
||||
manager = get_settings_manager()
|
||||
original = manager.settings.copy()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
settings.settings.clear()
|
||||
settings.settings.update(original)
|
||||
manager.settings.clear()
|
||||
manager.settings.update(original)
|
||||
|
||||
|
||||
async def test_open_folder_requires_existing_model_directory(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||
settings.settings["example_images_path"] = str(tmp_path)
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||
model_hash = "a" * 64
|
||||
model_folder = tmp_path / model_hash
|
||||
model_folder.mkdir()
|
||||
@@ -65,7 +67,8 @@ async def test_open_folder_requires_existing_model_directory(monkeypatch: pytest
|
||||
|
||||
|
||||
async def test_open_folder_rejects_invalid_paths(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||
settings.settings["example_images_path"] = str(tmp_path)
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||
|
||||
def fake_get_model_folder(_hash):
|
||||
return str(tmp_path.parent / "outside")
|
||||
@@ -81,7 +84,8 @@ async def test_open_folder_rejects_invalid_paths(monkeypatch: pytest.MonkeyPatch
|
||||
|
||||
|
||||
async def test_get_files_lists_supported_media(tmp_path) -> None:
|
||||
settings.settings["example_images_path"] = str(tmp_path)
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||
model_hash = "b" * 64
|
||||
model_folder = tmp_path / model_hash
|
||||
model_folder.mkdir()
|
||||
@@ -99,7 +103,8 @@ async def test_get_files_lists_supported_media(tmp_path) -> None:
|
||||
|
||||
|
||||
async def test_has_images_reports_presence(tmp_path) -> None:
|
||||
settings.settings["example_images_path"] = str(tmp_path)
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||
model_hash = "c" * 64
|
||||
model_folder = tmp_path / model_hash
|
||||
model_folder.mkdir()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
@@ -30,7 +32,23 @@ def patch_metadata_manager(monkeypatch: pytest.MonkeyPatch):
|
||||
saved.append((path, metadata.copy()))
|
||||
return True
|
||||
|
||||
class SimpleMetadata:
|
||||
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||
self._payload = payload
|
||||
self._unknown_fields: Dict[str, Any] = {}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return self._payload.copy()
|
||||
|
||||
async def fake_load(path: str, *_args: Any, **_kwargs: Any):
|
||||
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
|
||||
if os.path.exists(metadata_path):
|
||||
data = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
|
||||
return SimpleMetadata(data), False
|
||||
return None, False
|
||||
|
||||
monkeypatch.setattr(metadata_module.MetadataManager, "save_metadata", staticmethod(fake_save))
|
||||
monkeypatch.setattr(metadata_module.MetadataManager, "load_metadata", staticmethod(fake_load))
|
||||
return saved
|
||||
|
||||
|
||||
@@ -64,10 +82,80 @@ async def test_update_metadata_after_import_enriches_entries(monkeypatch: pytest
|
||||
assert custom[0]["hasMeta"] is True
|
||||
assert custom[0]["type"] == "image"
|
||||
|
||||
assert patch_metadata_manager[0][0] == str(model_file)
|
||||
assert Path(patch_metadata_manager[0][0]) == model_file
|
||||
assert scanner.updates
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_metadata_after_import_preserves_existing_metadata(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
patch_metadata_manager,
|
||||
):
|
||||
model_hash = "b" * 64
|
||||
model_file = tmp_path / "preserve.safetensors"
|
||||
model_file.write_text("content", encoding="utf-8")
|
||||
metadata_path = tmp_path / "preserve.metadata.json"
|
||||
|
||||
existing_payload = {
|
||||
"model_name": "Example",
|
||||
"file_path": str(model_file),
|
||||
"civitai": {
|
||||
"id": 42,
|
||||
"modelId": 88,
|
||||
"name": "Example",
|
||||
"trainedWords": ["foo"],
|
||||
"images": [{"url": "https://example.com/default.png", "type": "image"}],
|
||||
"customImages": [
|
||||
{"id": "existing-id", "type": "image", "url": "", "nsfwLevel": 0}
|
||||
],
|
||||
},
|
||||
"extraField": "keep-me",
|
||||
}
|
||||
metadata_path.write_text(json.dumps(existing_payload), encoding="utf-8")
|
||||
|
||||
model_data = {
|
||||
"sha256": model_hash,
|
||||
"model_name": "Example",
|
||||
"file_path": str(model_file),
|
||||
"civitai": {
|
||||
"id": 42,
|
||||
"modelId": 88,
|
||||
"name": "Example",
|
||||
"trainedWords": ["foo"],
|
||||
"customImages": [],
|
||||
},
|
||||
}
|
||||
scanner = StubScanner([model_data])
|
||||
|
||||
image_path = tmp_path / "new.png"
|
||||
image_path.write_bytes(b"fakepng")
|
||||
|
||||
monkeypatch.setattr(metadata_module.ExifUtils, "extract_image_metadata", staticmethod(lambda _path: None))
|
||||
monkeypatch.setattr(metadata_module.MetadataUpdater, "_parse_image_metadata", staticmethod(lambda payload: None))
|
||||
|
||||
regular, custom = await metadata_module.MetadataUpdater.update_metadata_after_import(
|
||||
model_hash,
|
||||
model_data,
|
||||
scanner,
|
||||
[(str(image_path), "new-id")],
|
||||
)
|
||||
|
||||
assert regular == existing_payload["civitai"]["images"]
|
||||
assert any(entry["id"] == "new-id" for entry in custom)
|
||||
|
||||
saved_path, saved_payload = patch_metadata_manager[-1]
|
||||
assert Path(saved_path) == model_file
|
||||
assert saved_payload["extraField"] == "keep-me"
|
||||
assert saved_payload["civitai"]["images"] == existing_payload["civitai"]["images"]
|
||||
assert saved_payload["civitai"]["trainedWords"] == ["foo"]
|
||||
assert {entry["id"] for entry in saved_payload["civitai"]["customImages"]} == {"existing-id", "new-id"}
|
||||
|
||||
assert scanner.updates
|
||||
updated_metadata = scanner.updates[-1][2]
|
||||
assert updated_metadata["civitai"]["images"] == existing_payload["civitai"]["images"]
|
||||
assert {entry["id"] for entry in updated_metadata["civitai"]["customImages"]} == {"existing-id", "new-id"}
|
||||
|
||||
async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
model_hash = "b" * 64
|
||||
model_file = tmp_path / "model.safetensors"
|
||||
@@ -79,6 +167,16 @@ async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.Monke
|
||||
async def fetch_and_update_model(self, **_kwargs):
|
||||
return True, None
|
||||
|
||||
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
model_data["hydrated"] = True
|
||||
return model_data
|
||||
|
||||
monkeypatch.setattr(
|
||||
metadata_module.MetadataManager,
|
||||
"hydrate_model_data",
|
||||
staticmethod(fake_hydrate),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(metadata_module, "_metadata_sync_service", StubMetadataSync())
|
||||
|
||||
result = await metadata_module.MetadataUpdater.refresh_model_metadata(
|
||||
@@ -89,6 +187,7 @@ async def test_refresh_model_metadata_records_failures(monkeypatch: pytest.Monke
|
||||
{"refreshed_models": set(), "errors": [], "last_error": None},
|
||||
)
|
||||
assert result is True
|
||||
assert cache_item["hydrated"] is True
|
||||
|
||||
|
||||
async def test_update_metadata_from_local_examples_generates_entries(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
@@ -112,4 +211,4 @@ async def test_update_metadata_from_local_examples_generates_entries(monkeypatch
|
||||
str(model_dir),
|
||||
)
|
||||
assert success is True
|
||||
assert model_data["civitai"]["images"]
|
||||
assert model_data["civitai"]["images"]
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.settings_manager import settings
|
||||
from py.services.settings_manager import get_settings_manager
|
||||
from py.utils.example_images_paths import (
|
||||
ensure_library_root_exists,
|
||||
get_model_folder,
|
||||
@@ -18,18 +18,24 @@ from py.utils.example_images_paths import (
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def restore_settings():
|
||||
original = copy.deepcopy(settings.settings)
|
||||
manager = get_settings_manager()
|
||||
original = copy.deepcopy(manager.settings)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
settings.settings.clear()
|
||||
settings.settings.update(original)
|
||||
manager.settings.clear()
|
||||
manager.settings.update(original)
|
||||
|
||||
|
||||
def test_get_model_folder_single_library(tmp_path):
|
||||
settings.settings['example_images_path'] = str(tmp_path)
|
||||
settings.settings['libraries'] = {'default': {}}
|
||||
settings.settings['active_library'] = 'default'
|
||||
@pytest.fixture
|
||||
def settings_manager():
|
||||
return get_settings_manager()
|
||||
|
||||
|
||||
def test_get_model_folder_single_library(tmp_path, settings_manager):
|
||||
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||
settings_manager.settings['libraries'] = {'default': {}}
|
||||
settings_manager.settings['active_library'] = 'default'
|
||||
|
||||
model_hash = 'a' * 64
|
||||
folder = get_model_folder(model_hash)
|
||||
@@ -39,13 +45,13 @@ def test_get_model_folder_single_library(tmp_path):
|
||||
assert relative == model_hash
|
||||
|
||||
|
||||
def test_get_model_folder_multi_library(tmp_path):
|
||||
settings.settings['example_images_path'] = str(tmp_path)
|
||||
settings.settings['libraries'] = {
|
||||
def test_get_model_folder_multi_library(tmp_path, settings_manager):
|
||||
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||
settings_manager.settings['libraries'] = {
|
||||
'default': {},
|
||||
'Alt Library': {},
|
||||
}
|
||||
settings.settings['active_library'] = 'Alt Library'
|
||||
settings_manager.settings['active_library'] = 'Alt Library'
|
||||
|
||||
model_hash = 'b' * 64
|
||||
expected_folder = tmp_path / 'Alt_Library' / model_hash
|
||||
@@ -57,13 +63,13 @@ def test_get_model_folder_multi_library(tmp_path):
|
||||
assert relative == os.path.join('Alt_Library', model_hash).replace('\\', '/')
|
||||
|
||||
|
||||
def test_get_model_folder_migrates_legacy_structure(tmp_path):
|
||||
settings.settings['example_images_path'] = str(tmp_path)
|
||||
settings.settings['libraries'] = {
|
||||
def test_get_model_folder_migrates_legacy_structure(tmp_path, settings_manager):
|
||||
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||
settings_manager.settings['libraries'] = {
|
||||
'default': {},
|
||||
'extra': {},
|
||||
}
|
||||
settings.settings['active_library'] = 'extra'
|
||||
settings_manager.settings['active_library'] = 'extra'
|
||||
|
||||
model_hash = 'c' * 64
|
||||
legacy_folder = tmp_path / model_hash
|
||||
@@ -82,31 +88,31 @@ def test_get_model_folder_migrates_legacy_structure(tmp_path):
|
||||
assert (expected_folder / 'image.png').exists()
|
||||
|
||||
|
||||
def test_ensure_library_root_exists_creates_directories(tmp_path):
|
||||
settings.settings['example_images_path'] = str(tmp_path)
|
||||
settings.settings['libraries'] = {'default': {}, 'secondary': {}}
|
||||
settings.settings['active_library'] = 'secondary'
|
||||
def test_ensure_library_root_exists_creates_directories(tmp_path, settings_manager):
|
||||
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||
settings_manager.settings['libraries'] = {'default': {}, 'secondary': {}}
|
||||
settings_manager.settings['active_library'] = 'secondary'
|
||||
|
||||
resolved = ensure_library_root_exists('secondary')
|
||||
assert Path(resolved) == tmp_path / 'secondary'
|
||||
assert (tmp_path / 'secondary').is_dir()
|
||||
|
||||
|
||||
def test_iter_library_roots_returns_all_configured(tmp_path):
|
||||
settings.settings['example_images_path'] = str(tmp_path)
|
||||
settings.settings['libraries'] = {'default': {}, 'alt': {}}
|
||||
settings.settings['active_library'] = 'alt'
|
||||
def test_iter_library_roots_returns_all_configured(tmp_path, settings_manager):
|
||||
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||
settings_manager.settings['libraries'] = {'default': {}, 'alt': {}}
|
||||
settings_manager.settings['active_library'] = 'alt'
|
||||
|
||||
roots = dict(iter_library_roots())
|
||||
assert roots['default'] == str(tmp_path / 'default')
|
||||
assert roots['alt'] == str(tmp_path / 'alt')
|
||||
|
||||
|
||||
def test_is_valid_example_images_root_accepts_hash_directories(tmp_path):
|
||||
settings.settings['example_images_path'] = str(tmp_path)
|
||||
def test_is_valid_example_images_root_accepts_hash_directories(tmp_path, settings_manager):
|
||||
settings_manager.settings['example_images_path'] = str(tmp_path)
|
||||
# Ensure single-library mode (not multi-library mode)
|
||||
settings.settings['libraries'] = {'default': {}}
|
||||
settings.settings['active_library'] = 'default'
|
||||
settings_manager.settings['libraries'] = {'default': {}}
|
||||
settings_manager.settings['active_library'] = 'default'
|
||||
|
||||
hash_folder = tmp_path / ('d' * 64)
|
||||
hash_folder.mkdir()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
@@ -7,18 +8,42 @@ from typing import Any, Dict, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.settings_manager import settings
|
||||
from py.services.settings_manager import get_settings_manager
|
||||
from py.utils import example_images_metadata as metadata_module
|
||||
from py.utils import example_images_processor as processor_module
|
||||
from py.utils.example_images_paths import get_model_folder
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def restore_settings() -> None:
|
||||
original = settings.settings.copy()
|
||||
manager = get_settings_manager()
|
||||
original = manager.settings.copy()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
settings.settings.clear()
|
||||
settings.settings.update(original)
|
||||
manager.settings.clear()
|
||||
manager.settings.update(original)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_metadata_loader(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class SimpleMetadata:
|
||||
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||
self._payload = payload
|
||||
self._unknown_fields: Dict[str, Any] = {}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return self._payload.copy()
|
||||
|
||||
async def fake_load(path: str, *_args: Any, **_kwargs: Any):
|
||||
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
|
||||
if os.path.exists(metadata_path):
|
||||
data = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
|
||||
return SimpleMetadata(data), False
|
||||
return None, False
|
||||
|
||||
monkeypatch.setattr(processor_module.MetadataManager, "load_metadata", staticmethod(fake_load))
|
||||
monkeypatch.setattr(metadata_module.MetadataManager, "load_metadata", staticmethod(fake_load))
|
||||
|
||||
|
||||
def test_get_file_extension_from_magic_bytes() -> None:
|
||||
@@ -90,9 +115,10 @@ def stub_scanners(monkeypatch: pytest.MonkeyPatch, tmp_path) -> StubScanner:
|
||||
|
||||
|
||||
async def test_import_images_creates_hash_directory(monkeypatch: pytest.MonkeyPatch, tmp_path, stub_scanners: StubScanner) -> None:
|
||||
settings.settings["example_images_path"] = str(tmp_path / "examples")
|
||||
settings.settings["libraries"] = {"default": {}}
|
||||
settings.settings["active_library"] = "default"
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path / "examples")
|
||||
settings_manager.settings["libraries"] = {"default": {}}
|
||||
settings_manager.settings["active_library"] = "default"
|
||||
|
||||
source_file = tmp_path / "upload.png"
|
||||
source_file.write_bytes(b"PNG data")
|
||||
@@ -112,7 +138,7 @@ async def test_import_images_creates_hash_directory(monkeypatch: pytest.MonkeyPa
|
||||
assert result["success"] is True
|
||||
assert result["files"][0]["name"].startswith("custom_short")
|
||||
|
||||
model_folder = Path(settings.settings["example_images_path"]) / ("a" * 64)
|
||||
model_folder = Path(settings_manager.settings["example_images_path"]) / ("a" * 64)
|
||||
assert model_folder.exists()
|
||||
created_files = list(model_folder.glob("custom_short*.png"))
|
||||
assert len(created_files) == 1
|
||||
@@ -132,7 +158,8 @@ async def test_import_images_rejects_missing_parameters(monkeypatch: pytest.Monk
|
||||
|
||||
|
||||
async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||
settings.settings["example_images_path"] = str(tmp_path)
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||
|
||||
async def _empty_scanner(cls=None):
|
||||
return StubScanner([])
|
||||
@@ -143,3 +170,88 @@ async def test_import_images_raises_when_model_not_found(monkeypatch: pytest.Mon
|
||||
|
||||
with pytest.raises(processor_module.ExampleImagesImportError):
|
||||
await processor_module.ExampleImagesProcessor.import_images("a" * 64, [str(tmp_path / "missing.png")])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_custom_image_preserves_existing_metadata(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path / "examples")
|
||||
|
||||
model_hash = "c" * 64
|
||||
model_file = tmp_path / "keep.safetensors"
|
||||
model_file.write_text("content", encoding="utf-8")
|
||||
metadata_path = tmp_path / "keep.metadata.json"
|
||||
|
||||
existing_metadata = {
|
||||
"model_name": "Keep",
|
||||
"file_path": str(model_file),
|
||||
"civitai": {
|
||||
"images": [{"url": "https://example.com/default.png", "type": "image"}],
|
||||
"customImages": [{"id": "existing-id", "url": "", "type": "image"}],
|
||||
"trainedWords": ["foo"],
|
||||
},
|
||||
}
|
||||
metadata_path.write_text(json.dumps(existing_metadata), encoding="utf-8")
|
||||
|
||||
model_data = {
|
||||
"sha256": model_hash,
|
||||
"model_name": "Keep",
|
||||
"file_path": str(model_file),
|
||||
"civitai": {
|
||||
"customImages": [{"id": "existing-id", "url": "", "type": "image"}],
|
||||
"trainedWords": ["foo"],
|
||||
},
|
||||
}
|
||||
|
||||
class Scanner(StubScanner):
|
||||
def has_hash(self, hash_value: str) -> bool:
|
||||
return hash_value == model_hash
|
||||
|
||||
scanner = Scanner([model_data])
|
||||
|
||||
async def _return_scanner(cls=None):
|
||||
return scanner
|
||||
|
||||
monkeypatch.setattr(processor_module.ServiceRegistry, "get_lora_scanner", classmethod(_return_scanner))
|
||||
monkeypatch.setattr(processor_module.ServiceRegistry, "get_checkpoint_scanner", classmethod(_return_scanner))
|
||||
monkeypatch.setattr(processor_module.ServiceRegistry, "get_embedding_scanner", classmethod(_return_scanner))
|
||||
|
||||
model_folder = get_model_folder(model_hash)
|
||||
os.makedirs(model_folder, exist_ok=True)
|
||||
(Path(model_folder) / "custom_existing-id.png").write_bytes(b"data")
|
||||
|
||||
saved: list[tuple[str, Dict[str, Any]]] = []
|
||||
|
||||
async def fake_save(path: str, payload: Dict[str, Any]) -> bool:
|
||||
saved.append((path, payload.copy()))
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(processor_module.MetadataManager, "save_metadata", staticmethod(fake_save))
|
||||
|
||||
class StubRequest:
|
||||
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||
self._payload = payload
|
||||
|
||||
async def json(self) -> Dict[str, Any]:
|
||||
return self._payload
|
||||
|
||||
response = await processor_module.ExampleImagesProcessor.delete_custom_image(
|
||||
StubRequest({"model_hash": model_hash, "short_id": "existing-id"})
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
body = json.loads(response.text)
|
||||
assert body["success"] is True
|
||||
assert body["custom_images"] == []
|
||||
assert not (Path(model_folder) / "custom_existing-id.png").exists()
|
||||
|
||||
saved_path, saved_payload = saved[-1]
|
||||
assert saved_path == str(model_file)
|
||||
assert saved_payload["civitai"]["images"] == existing_metadata["civitai"]["images"]
|
||||
assert saved_payload["civitai"]["trainedWords"] == ["foo"]
|
||||
assert saved_payload["civitai"]["customImages"] == []
|
||||
|
||||
assert scanner.updated
|
||||
_, _, updated_metadata = scanner.updated[-1]
|
||||
assert updated_metadata["civitai"]["images"] == existing_metadata["civitai"]["images"]
|
||||
assert updated_metadata["civitai"]["customImages"] == []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from py.services.settings_manager import settings
|
||||
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||
from py.utils.utils import (
|
||||
calculate_recipe_fingerprint,
|
||||
calculate_relative_path_for_model,
|
||||
@@ -9,7 +9,8 @@ from py.utils.utils import (
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_settings(monkeypatch):
|
||||
default_settings = settings._get_default_settings()
|
||||
manager = get_settings_manager()
|
||||
default_settings = manager._get_default_settings()
|
||||
default_settings.update(
|
||||
{
|
||||
"download_path_templates": {
|
||||
@@ -20,8 +21,8 @@ def isolated_settings(monkeypatch):
|
||||
"base_model_path_mappings": {},
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(settings, "settings", default_settings)
|
||||
monkeypatch.setattr(type(settings), "_save_settings", lambda self: None)
|
||||
monkeypatch.setattr(manager, "settings", default_settings)
|
||||
monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None)
|
||||
return default_settings
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
import { addJsonDisplayWidget } from "./json_display_widget.js";
|
||||
import { getNodeFromGraph } from "./utils.js";
|
||||
|
||||
app.registerExtension({
|
||||
name: "LoraManager.DebugMetadata",
|
||||
@@ -8,8 +9,8 @@ app.registerExtension({
|
||||
setup() {
|
||||
// Add message handler to listen for metadata updates from Python
|
||||
api.addEventListener("metadata_update", (event) => {
|
||||
const { id, metadata } = event.detail;
|
||||
this.handleMetadataUpdate(id, metadata);
|
||||
const { id, graph_id: graphId, metadata } = event.detail;
|
||||
this.handleMetadataUpdate(id, graphId, metadata);
|
||||
});
|
||||
},
|
||||
|
||||
@@ -37,8 +38,8 @@ app.registerExtension({
|
||||
},
|
||||
|
||||
// Handle metadata updates from Python
|
||||
handleMetadataUpdate(id, metadata) {
|
||||
const node = app.graph.getNodeById(+id);
|
||||
handleMetadataUpdate(id, graphId, metadata) {
|
||||
const node = getNodeFromGraph(graphId, id);
|
||||
if (!node || node.comfyClass !== "Debug Metadata (LoraManager)") {
|
||||
console.warn("Node not found or not a DebugMetadata node:", id);
|
||||
return;
|
||||
|
||||
@@ -7,6 +7,8 @@ import {
|
||||
chainCallback,
|
||||
mergeLoras,
|
||||
setupInputWidgetWithAutocomplete,
|
||||
getAllGraphNodes,
|
||||
getNodeFromGraph,
|
||||
} from "./utils.js";
|
||||
import { addLorasWidget } from "./loras_widget.js";
|
||||
|
||||
@@ -16,23 +18,26 @@ app.registerExtension({
|
||||
setup() {
|
||||
// Add message handler to listen for messages from Python
|
||||
api.addEventListener("lora_code_update", (event) => {
|
||||
const { id, lora_code, mode } = event.detail;
|
||||
this.handleLoraCodeUpdate(id, lora_code, mode);
|
||||
this.handleLoraCodeUpdate(event.detail || {});
|
||||
});
|
||||
},
|
||||
|
||||
// Handle lora code updates from Python
|
||||
handleLoraCodeUpdate(id, loraCode, mode) {
|
||||
handleLoraCodeUpdate(message) {
|
||||
const nodeId = message?.node_id ?? message?.id;
|
||||
const graphId = message?.graph_id;
|
||||
const loraCode = message?.lora_code ?? "";
|
||||
const mode = message?.mode ?? "append";
|
||||
|
||||
const numericNodeId =
|
||||
typeof nodeId === "string" ? Number(nodeId) : nodeId;
|
||||
|
||||
// Handle broadcast mode (for Desktop/non-browser support)
|
||||
if (id === -1) {
|
||||
if (numericNodeId === -1) {
|
||||
// Find all Lora Loader nodes in the current graph
|
||||
const loraLoaderNodes = [];
|
||||
for (const nodeId in app.graph._nodes_by_id) {
|
||||
const node = app.graph._nodes_by_id[nodeId];
|
||||
if (node.comfyClass === "Lora Loader (LoraManager)") {
|
||||
loraLoaderNodes.push(node);
|
||||
}
|
||||
}
|
||||
const loraLoaderNodes = getAllGraphNodes(app.graph)
|
||||
.map(({ node }) => node)
|
||||
.filter((node) => node?.comfyClass === "Lora Loader (LoraManager)");
|
||||
|
||||
// Update each Lora Loader node found
|
||||
if (loraLoaderNodes.length > 0) {
|
||||
@@ -52,14 +57,18 @@ app.registerExtension({
|
||||
}
|
||||
|
||||
// Standard mode - update a specific node
|
||||
const node = app.graph.getNodeById(+id);
|
||||
const node = getNodeFromGraph(graphId, numericNodeId);
|
||||
if (
|
||||
!node ||
|
||||
(node.comfyClass !== "Lora Loader (LoraManager)" &&
|
||||
node.comfyClass !== "Lora Stacker (LoraManager)" &&
|
||||
node.comfyClass !== "WanVideo Lora Select (LoraManager)")
|
||||
) {
|
||||
console.warn("Node not found or not a LoraLoader:", id);
|
||||
console.warn(
|
||||
"Node not found or not a LoraLoader:",
|
||||
graphId ?? "root",
|
||||
nodeId
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ import {
|
||||
chainCallback,
|
||||
mergeLoras,
|
||||
setupInputWidgetWithAutocomplete,
|
||||
getLinkFromGraph,
|
||||
getNodeKey,
|
||||
} from "./utils.js";
|
||||
import { addLorasWidget } from "./loras_widget.js";
|
||||
|
||||
@@ -124,17 +126,18 @@ app.registerExtension({
|
||||
|
||||
// Helper function to find and update downstream Lora Loader nodes
|
||||
function updateDownstreamLoaders(startNode, visited = new Set()) {
|
||||
if (visited.has(startNode.id)) return;
|
||||
visited.add(startNode.id);
|
||||
const nodeKey = getNodeKey(startNode);
|
||||
if (!nodeKey || visited.has(nodeKey)) return;
|
||||
visited.add(nodeKey);
|
||||
|
||||
// Check each output link
|
||||
if (startNode.outputs) {
|
||||
for (const output of startNode.outputs) {
|
||||
if (output.links) {
|
||||
for (const linkId of output.links) {
|
||||
const link = app.graph.links[linkId];
|
||||
const link = getLinkFromGraph(startNode.graph, linkId);
|
||||
if (link) {
|
||||
const targetNode = app.graph.getNodeById(link.target_id);
|
||||
const targetNode = startNode.graph?.getNodeById?.(link.target_id);
|
||||
|
||||
// If target is a Lora Loader, collect all active loras in the chain and update
|
||||
if (
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
import { CONVERTED_TYPE } from "./utils.js";
|
||||
import { CONVERTED_TYPE, getNodeFromGraph } from "./utils.js";
|
||||
import { addTagsWidget } from "./tags_widget.js";
|
||||
|
||||
// TriggerWordToggle extension for ComfyUI
|
||||
@@ -10,8 +10,8 @@ app.registerExtension({
|
||||
setup() {
|
||||
// Add message handler to listen for messages from Python
|
||||
api.addEventListener("trigger_word_update", (event) => {
|
||||
const { id, message } = event.detail;
|
||||
this.handleTriggerWordUpdate(id, message);
|
||||
const { id, graph_id: graphId, message } = event.detail;
|
||||
this.handleTriggerWordUpdate(id, graphId, message);
|
||||
});
|
||||
},
|
||||
|
||||
@@ -76,8 +76,8 @@ app.registerExtension({
|
||||
},
|
||||
|
||||
// Handle trigger word updates from Python
|
||||
handleTriggerWordUpdate(id, message) {
|
||||
const node = app.graph.getNodeById(+id);
|
||||
handleTriggerWordUpdate(id, graphId, message) {
|
||||
const node = getNodeFromGraph(graphId, id);
|
||||
if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") {
|
||||
console.warn("Node not found or not a TriggerWordToggle:", id);
|
||||
return;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// ComfyUI extension to track model usage statistics
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
import { showToast } from "./utils.js";
|
||||
import { getAllGraphNodes, getNodeReference, showToast } from "./utils.js";
|
||||
|
||||
// Define target nodes and their widget configurations
|
||||
const PATH_CORRECTION_TARGETS = [
|
||||
@@ -56,25 +56,35 @@ app.registerExtension({
|
||||
|
||||
async refreshRegistry() {
|
||||
try {
|
||||
// Get current workflow nodes
|
||||
const prompt = await app.graphToPrompt();
|
||||
const workflow = prompt.workflow;
|
||||
if (!workflow || !workflow.nodes) {
|
||||
console.warn("No workflow nodes found for registry refresh");
|
||||
return;
|
||||
}
|
||||
|
||||
// Find all Lora nodes
|
||||
const loraNodes = [];
|
||||
for (const node of workflow.nodes.values()) {
|
||||
if (node.type === "Lora Loader (LoraManager)" ||
|
||||
node.type === "Lora Stacker (LoraManager)" ||
|
||||
node.type === "WanVideo Lora Select (LoraManager)") {
|
||||
const nodeEntries = getAllGraphNodes(app.graph);
|
||||
|
||||
for (const { graph, node } of nodeEntries) {
|
||||
if (!node || !node.comfyClass) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (
|
||||
node.comfyClass === "Lora Loader (LoraManager)" ||
|
||||
node.comfyClass === "Lora Stacker (LoraManager)" ||
|
||||
node.comfyClass === "WanVideo Lora Select (LoraManager)"
|
||||
) {
|
||||
const reference = getNodeReference(node);
|
||||
if (!reference) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const graphName = typeof graph?.name === "string" && graph.name.trim()
|
||||
? graph.name
|
||||
: null;
|
||||
|
||||
loraNodes.push({
|
||||
node_id: node.id,
|
||||
bgcolor: node.bgcolor || null,
|
||||
title: node.title || node.type,
|
||||
type: node.type
|
||||
node_id: reference.node_id,
|
||||
graph_id: reference.graph_id,
|
||||
graph_name: graphName,
|
||||
bgcolor: node.bgcolor ?? node.color ?? null,
|
||||
title: node.title || node.comfyClass,
|
||||
type: node.comfyClass,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,120 @@ export const CONVERTED_TYPE = 'converted-widget';
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { AutoComplete } from "./autocomplete.js";
|
||||
|
||||
const ROOT_GRAPH_ID = "root";
|
||||
|
||||
function isMapLike(collection) {
|
||||
return collection && typeof collection.entries === "function" && typeof collection.values === "function";
|
||||
}
|
||||
|
||||
function getChildGraphs(graph) {
|
||||
if (!graph || !graph._subgraphs) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const rawSubgraphs = isMapLike(graph._subgraphs)
|
||||
? Array.from(graph._subgraphs.values())
|
||||
: Object.values(graph._subgraphs);
|
||||
|
||||
return rawSubgraphs
|
||||
.map((subgraph) => subgraph?.graph || subgraph?._graph || subgraph)
|
||||
.filter((subgraph) => subgraph && subgraph !== graph);
|
||||
}
|
||||
|
||||
function traverseGraphs(rootGraph, visitor, visited = new Set()) {
|
||||
const graph = rootGraph || app.graph;
|
||||
if (!graph) {
|
||||
return;
|
||||
}
|
||||
|
||||
const graphId = getGraphId(graph);
|
||||
if (visited.has(graphId)) {
|
||||
return;
|
||||
}
|
||||
visited.add(graphId);
|
||||
visitor(graph);
|
||||
|
||||
for (const subgraph of getChildGraphs(graph)) {
|
||||
traverseGraphs(subgraph, visitor, visited);
|
||||
}
|
||||
}
|
||||
|
||||
export function getGraphId(graph) {
|
||||
return graph?.id ?? ROOT_GRAPH_ID;
|
||||
}
|
||||
|
||||
export function getNodeGraphId(node) {
|
||||
if (!node) {
|
||||
return ROOT_GRAPH_ID;
|
||||
}
|
||||
return getGraphId(node.graph || app.graph);
|
||||
}
|
||||
|
||||
export function getGraphById(graphId, rootGraph = app.graph) {
|
||||
if (!graphId) {
|
||||
return rootGraph;
|
||||
}
|
||||
|
||||
let foundGraph = null;
|
||||
traverseGraphs(rootGraph, (graph) => {
|
||||
if (!foundGraph && getGraphId(graph) === graphId) {
|
||||
foundGraph = graph;
|
||||
}
|
||||
});
|
||||
return foundGraph;
|
||||
}
|
||||
|
||||
export function getNodeFromGraph(graphId, nodeId) {
|
||||
const graph = getGraphById(graphId) || app.graph;
|
||||
if (!graph || typeof graph.getNodeById !== "function") {
|
||||
return null;
|
||||
}
|
||||
|
||||
const numericId = typeof nodeId === "string" ? Number(nodeId) : nodeId;
|
||||
return graph.getNodeById(Number.isNaN(numericId) ? nodeId : numericId) || null;
|
||||
}
|
||||
|
||||
export function getAllGraphNodes(rootGraph = app.graph) {
|
||||
const nodes = [];
|
||||
traverseGraphs(rootGraph, (graph) => {
|
||||
if (Array.isArray(graph._nodes)) {
|
||||
for (const node of graph._nodes) {
|
||||
nodes.push({ graph, node });
|
||||
}
|
||||
}
|
||||
});
|
||||
return nodes;
|
||||
}
|
||||
|
||||
export function getNodeReference(node) {
|
||||
if (!node) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
node_id: node.id,
|
||||
graph_id: getNodeGraphId(node),
|
||||
};
|
||||
}
|
||||
|
||||
export function getNodeKey(node) {
|
||||
if (!node) {
|
||||
return null;
|
||||
}
|
||||
return `${getNodeGraphId(node)}:${node.id}`;
|
||||
}
|
||||
|
||||
export function getLinkFromGraph(graph, linkId) {
|
||||
if (!graph || graph.links == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (isMapLike(graph.links)) {
|
||||
return graph.links.get(linkId) || null;
|
||||
}
|
||||
|
||||
return graph.links[linkId] || null;
|
||||
}
|
||||
|
||||
export function chainCallback(object, property, callback) {
|
||||
if (object == undefined) {
|
||||
//This should not happen.
|
||||
@@ -103,42 +217,56 @@ export const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)(?::([-\d\.]+))?>/g;
|
||||
// Get connected Lora Stacker nodes that feed into the current node
|
||||
export function getConnectedInputStackers(node) {
|
||||
const connectedStackers = [];
|
||||
|
||||
if (node.inputs) {
|
||||
for (const input of node.inputs) {
|
||||
if (input.name === "lora_stack" && input.link) {
|
||||
const link = app.graph.links[input.link];
|
||||
if (link) {
|
||||
const sourceNode = app.graph.getNodeById(link.origin_id);
|
||||
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
|
||||
connectedStackers.push(sourceNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!node?.inputs) {
|
||||
return connectedStackers;
|
||||
}
|
||||
|
||||
for (const input of node.inputs) {
|
||||
if (input.name !== "lora_stack" || !input.link) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const link = getLinkFromGraph(node.graph, input.link);
|
||||
if (!link) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
|
||||
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
|
||||
connectedStackers.push(sourceNode);
|
||||
}
|
||||
}
|
||||
|
||||
return connectedStackers;
|
||||
}
|
||||
|
||||
// Get connected TriggerWord Toggle nodes that receive output from the current node
|
||||
export function getConnectedTriggerToggleNodes(node) {
|
||||
const connectedNodes = [];
|
||||
|
||||
if (node.outputs && node.outputs.length > 0) {
|
||||
for (const output of node.outputs) {
|
||||
if (output.links && output.links.length > 0) {
|
||||
for (const linkId of output.links) {
|
||||
const link = app.graph.links[linkId];
|
||||
if (link) {
|
||||
const targetNode = app.graph.getNodeById(link.target_id);
|
||||
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
||||
connectedNodes.push(targetNode.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!node?.outputs) {
|
||||
return connectedNodes;
|
||||
}
|
||||
|
||||
for (const output of node.outputs) {
|
||||
if (!output?.links?.length) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const linkId of output.links) {
|
||||
const link = getLinkFromGraph(node.graph, linkId);
|
||||
if (!link) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const targetNode = node.graph?.getNodeById?.(link.target_id);
|
||||
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
||||
connectedNodes.push(targetNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return connectedNodes;
|
||||
}
|
||||
|
||||
@@ -161,11 +289,15 @@ export function getActiveLorasFromNode(node) {
|
||||
// Recursively collect all active loras from a node and its input chain
|
||||
export function collectActiveLorasFromChain(node, visited = new Set()) {
|
||||
// Prevent infinite loops from circular references
|
||||
if (visited.has(node.id)) {
|
||||
const nodeKey = getNodeKey(node);
|
||||
if (!nodeKey) {
|
||||
return new Set();
|
||||
}
|
||||
visited.add(node.id);
|
||||
|
||||
if (visited.has(nodeKey)) {
|
||||
return new Set();
|
||||
}
|
||||
visited.add(nodeKey);
|
||||
|
||||
// Get active loras from current node
|
||||
const allActiveLoraNames = getActiveLorasFromNode(node);
|
||||
|
||||
@@ -181,14 +313,22 @@ export function collectActiveLorasFromChain(node, visited = new Set()) {
|
||||
|
||||
// Update trigger words for connected toggle nodes
|
||||
export function updateConnectedTriggerWords(node, loraNames) {
|
||||
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
|
||||
if (connectedNodeIds.length > 0) {
|
||||
const connectedNodes = getConnectedTriggerToggleNodes(node);
|
||||
if (connectedNodes.length > 0) {
|
||||
const nodeIds = connectedNodes
|
||||
.map((connectedNode) => getNodeReference(connectedNode))
|
||||
.filter((reference) => reference !== null);
|
||||
|
||||
if (nodeIds.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
fetch("/api/lm/loras/get_trigger_words", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
lora_names: Array.from(loraNames),
|
||||
node_ids: connectedNodeIds
|
||||
node_ids: nodeIds
|
||||
})
|
||||
}).catch(err => console.error("Error fetching trigger words:", err));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user