diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py new file mode 100644 index 00000000..79b581e9 --- /dev/null +++ b/py/routes/handlers/misc_handlers.py @@ -0,0 +1,755 @@ +"""Handlers for miscellaneous routes. + +The legacy :mod:`py.routes.misc_routes` module bundled HTTP wiring and +business logic in a single class. This module mirrors the model route +architecture by splitting the responsibilities into dedicated handler +objects that can be composed by the route controller. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import re +import subprocess +import sys +from dataclasses import dataclass +from typing import Awaitable, Callable, Dict, Mapping, Protocol + +from aiohttp import web + +from ...config import config +from ...services.metadata_service import ( + get_metadata_archive_manager, + get_metadata_provider, + update_metadata_providers, +) +from ...services.service_registry import ServiceRegistry +from ...services.settings_manager import settings as default_settings +from ...services.websocket_manager import ws_manager +from ...services.downloader import get_downloader +from ...utils.constants import DEFAULT_NODE_COLOR, NODE_TYPES, SUPPORTED_MEDIA_EXTENSIONS +from ...utils.lora_metadata import extract_trained_words +from ...utils.usage_stats import UsageStats + +logger = logging.getLogger(__name__) + + +class PromptServerProtocol(Protocol): + """Subset of PromptServer used by the handlers.""" + + instance: "PromptServerProtocol" + + def send_sync(self, event: str, payload: dict) -> None: # pragma: no cover - protocol + ... + + +class DownloaderProtocol(Protocol): + async def refresh_session(self) -> None: # pragma: no cover - protocol + ... + + +class UsageStatsFactory(Protocol): + def __call__(self) -> UsageStats: # pragma: no cover - protocol + ... + + +class MetadataProviderProtocol(Protocol): + async def get_model_versions(self, model_id: int) -> dict | None: # pragma: no cover - protocol + ... + + +class MetadataArchiveManagerProtocol(Protocol): + async def download_and_extract_database( + self, progress_callback: Callable[[str, str], None] + ) -> bool: # pragma: no cover - protocol + ... + + async def remove_database(self) -> bool: # pragma: no cover - protocol + ... + + def is_database_available(self) -> bool: # pragma: no cover - protocol + ... + + def get_database_path(self) -> str | None: # pragma: no cover - protocol + ... + + +class NodeRegistry: + """Thread-safe registry for tracking LoRA nodes in active workflows.""" + + def __init__(self) -> None: + self._lock = asyncio.Lock() + self._nodes: Dict[int, dict] = {} + self._registry_updated = asyncio.Event() + + async def register_nodes(self, nodes: list[dict]) -> None: + async with self._lock: + self._nodes.clear() + for node in nodes: + node_id = node["node_id"] + node_type = node.get("type", "") + type_id = NODE_TYPES.get(node_type, 0) + bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR + self._nodes[node_id] = { + "id": node_id, + "bgcolor": bgcolor, + "title": node.get("title"), + "type": type_id, + "type_name": node_type, + } + logger.debug("Registered %s nodes in registry", len(nodes)) + self._registry_updated.set() + + async def get_registry(self) -> dict: + async with self._lock: + return { + "nodes": dict(self._nodes), + "node_count": len(self._nodes), + } + + async def wait_for_update(self, timeout: float = 1.0) -> bool: + self._registry_updated.clear() + try: + await asyncio.wait_for(self._registry_updated.wait(), timeout=timeout) + return True + except asyncio.TimeoutError: + return False + + +class HealthCheckHandler: + async def health_check(self, request: web.Request) -> web.Response: + return web.json_response({"status": "ok"}) + + +class SettingsHandler: + """Sync settings between backend and frontend.""" + + _SYNC_KEYS = ( + "civitai_api_key", + "default_lora_root", + "default_checkpoint_root", + "default_embedding_root", + "base_model_path_mappings", + "download_path_templates", + "enable_metadata_archive_db", + "language", + "proxy_enabled", + "proxy_type", + "proxy_host", + "proxy_port", + "proxy_username", + "proxy_password", + "example_images_path", + "optimize_example_images", + "auto_download_example_images", + "blur_mature_content", + "autoplay_on_hover", + "display_density", + "card_info_display", + "include_trigger_words", + "show_only_sfw", + "compact_mode", + ) + + _PROXY_KEYS = {"proxy_enabled", "proxy_host", "proxy_port", "proxy_username", "proxy_password", "proxy_type"} + + def __init__( + self, + *, + settings_service=default_settings, + metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers, + downloader_factory: Callable[[], Awaitable[DownloaderProtocol]] = get_downloader, + ) -> None: + self._settings = settings_service + self._metadata_provider_updater = metadata_provider_updater + self._downloader_factory = downloader_factory + + async def get_settings(self, request: web.Request) -> web.Response: + try: + response_data = {} + for key in self._SYNC_KEYS: + value = self._settings.get(key) + if value is not None: + response_data[key] = value + return web.json_response({"success": True, "settings": response_data}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error getting settings: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def update_settings(self, request: web.Request) -> web.Response: + try: + data = await request.json() + proxy_changed = False + + for key, value in data.items(): + if value == self._settings.get(key): + continue + + if key == "example_images_path" and value: + validation_error = self._validate_example_images_path(value) + if validation_error: + return web.json_response({"success": False, "error": validation_error}) + + if value == "__DELETE__" and key in ("proxy_username", "proxy_password"): + self._settings.delete(key) + else: + self._settings.set(key, value) + + if key == "enable_metadata_archive_db": + await self._metadata_provider_updater() + + if key in self._PROXY_KEYS: + proxy_changed = True + + if proxy_changed: + downloader = await self._downloader_factory() + await downloader.refresh_session() + + return web.json_response({"success": True}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error updating settings: %s", exc, exc_info=True) + return web.Response(status=500, text=str(exc)) + + def _validate_example_images_path(self, folder_path: str) -> str | None: + if not os.path.exists(folder_path): + return f"Path does not exist: {folder_path}" + if not os.path.isdir(folder_path): + return "Please set a dedicated folder for example images." + if not self._is_dedicated_example_images_folder(folder_path): + return "Please set a dedicated folder for example images." + return None + + def _is_dedicated_example_images_folder(self, folder_path: str) -> bool: + try: + items = os.listdir(folder_path) + if not items: + return True + for item in items: + item_path = os.path.join(folder_path, item) + if item == ".download_progress.json" and os.path.isfile(item_path): + continue + if os.path.isdir(item_path) and re.fullmatch(r"[a-fA-F0-9]{64}", item): + continue + return False + return True + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error checking if folder is dedicated: %s", exc) + return False + + +class UsageStatsHandler: + def __init__(self, usage_stats_factory: UsageStatsFactory = UsageStats) -> None: + self._usage_stats_factory = usage_stats_factory + + async def update_usage_stats(self, request: web.Request) -> web.Response: + try: + data = await request.json() + prompt_id = data.get("prompt_id") + if not prompt_id: + return web.json_response({"success": False, "error": "Missing prompt_id"}, status=400) + usage_stats = self._usage_stats_factory() + await usage_stats.process_execution(prompt_id) + return web.json_response({"success": True}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to update usage stats: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_usage_stats(self, request: web.Request) -> web.Response: + try: + usage_stats = self._usage_stats_factory() + stats = await usage_stats.get_stats() + stats_response = {"success": True, "data": stats, "format_version": 2} + return web.json_response(stats_response) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to get usage stats: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + +class LoraCodeHandler: + def __init__(self, prompt_server: type[PromptServerProtocol]) -> None: + self._prompt_server = prompt_server + + async def update_lora_code(self, request: web.Request) -> web.Response: + try: + data = await request.json() + node_ids = data.get("node_ids") + lora_code = data.get("lora_code", "") + mode = data.get("mode", "append") + + if not lora_code: + return web.json_response({"success": False, "error": "Missing lora_code parameter"}, status=400) + + results = [] + if node_ids is None: + try: + self._prompt_server.instance.send_sync( + "lora_code_update", {"id": -1, "lora_code": lora_code, "mode": mode} + ) + results.append({"node_id": "broadcast", "success": True}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error broadcasting lora code: %s", exc) + results.append({"node_id": "broadcast", "success": False, "error": str(exc)}) + else: + for node_id in node_ids: + try: + self._prompt_server.instance.send_sync( + "lora_code_update", + {"id": node_id, "lora_code": lora_code, "mode": mode}, + ) + results.append({"node_id": node_id, "success": True}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error sending lora code to node %s: %s", node_id, exc) + results.append({"node_id": node_id, "success": False, "error": str(exc)}) + + return web.json_response({"success": True, "results": results}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to update lora code: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + +class TrainedWordsHandler: + async def get_trained_words(self, request: web.Request) -> web.Response: + try: + file_path = request.query.get("file_path") + if not file_path: + return web.json_response({"success": False, "error": "Missing file_path parameter"}, status=400) + if not os.path.exists(file_path): + return web.json_response({"success": False, "error": "File not found"}, status=404) + if not file_path.endswith(".safetensors"): + return web.json_response({"success": False, "error": "File must be a safetensors file"}, status=400) + + trained_words, class_tokens = await extract_trained_words(file_path) + return web.json_response( + { + "success": True, + "trained_words": trained_words, + "class_tokens": class_tokens, + } + ) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to get trained words: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + +class ModelExampleFilesHandler: + async def get_model_example_files(self, request: web.Request) -> web.Response: + try: + model_path = request.query.get("model_path") + if not model_path: + return web.json_response({"success": False, "error": "Missing model_path parameter"}, status=400) + model_dir = os.path.dirname(model_path) + if not os.path.exists(model_dir): + return web.json_response({"success": False, "error": "Model directory not found"}, status=404) + + base_name = os.path.splitext(os.path.basename(model_path))[0] + files = [] + pattern = f"{base_name}.example." + for file in os.listdir(model_dir): + if not file.startswith(pattern): + continue + file_full_path = os.path.join(model_dir, file) + if not os.path.isfile(file_full_path): + continue + file_ext = os.path.splitext(file)[1].lower() + if file_ext not in SUPPORTED_MEDIA_EXTENSIONS["images"] and file_ext not in SUPPORTED_MEDIA_EXTENSIONS["videos"]: + continue + try: + index = int(file[len(pattern) :].split(".")[0]) + except (ValueError, IndexError): + index = float("inf") + static_url = config.get_preview_static_url(file_full_path) + files.append( + { + "name": file, + "path": static_url, + "extension": file_ext, + "is_video": file_ext in SUPPORTED_MEDIA_EXTENSIONS["videos"], + "index": index, + } + ) + + files.sort(key=lambda item: item["index"]) + for file in files: + file.pop("index", None) + + return web.json_response({"success": True, "files": files}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to get model example files: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + +@dataclass +class ServiceRegistryAdapter: + get_lora_scanner: Callable[[], Awaitable] + get_checkpoint_scanner: Callable[[], Awaitable] + get_embedding_scanner: Callable[[], Awaitable] + + +class ModelLibraryHandler: + def __init__(self, service_registry: ServiceRegistryAdapter, metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol | None]]) -> None: + self._service_registry = service_registry + self._metadata_provider_factory = metadata_provider_factory + + async def check_model_exists(self, request: web.Request) -> web.Response: + try: + model_id_str = request.query.get("modelId") + model_version_id_str = request.query.get("modelVersionId") + if not model_id_str: + return web.json_response({"success": False, "error": "Missing required parameter: modelId"}, status=400) + try: + model_id = int(model_id_str) + except ValueError: + return web.json_response({"success": False, "error": "Parameter modelId must be an integer"}, status=400) + + lora_scanner = await self._service_registry.get_lora_scanner() + checkpoint_scanner = await self._service_registry.get_checkpoint_scanner() + embedding_scanner = await self._service_registry.get_embedding_scanner() + + if model_version_id_str: + try: + model_version_id = int(model_version_id_str) + except ValueError: + return web.json_response({"success": False, "error": "Parameter modelVersionId must be an integer"}, status=400) + + exists = False + model_type = None + if await lora_scanner.check_model_version_exists(model_version_id): + exists = True + model_type = "lora" + elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_version_id): + exists = True + model_type = "checkpoint" + elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_version_id): + exists = True + model_type = "embedding" + + return web.json_response({"success": True, "exists": exists, "modelType": model_type if exists else None}) + + lora_versions = await lora_scanner.get_model_versions_by_id(model_id) + checkpoint_versions = [] + embedding_versions = [] + if not lora_versions and checkpoint_scanner: + checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id) + if not lora_versions and not checkpoint_versions and embedding_scanner: + embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id) + + model_type = None + version_ids: list[int] = [] + if lora_versions: + model_type = "lora" + version_ids = [version["versionId"] for version in lora_versions] + elif checkpoint_versions: + model_type = "checkpoint" + version_ids = [version["versionId"] for version in checkpoint_versions] + elif embedding_versions: + model_type = "embedding" + version_ids = [version["versionId"] for version in embedding_versions] + + return web.json_response({"success": True, "modelType": model_type, "modelVersionIds": version_ids}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to check model existence: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_model_versions_status(self, request: web.Request) -> web.Response: + try: + model_id_str = request.query.get("modelId") + if not model_id_str: + return web.json_response({"success": False, "error": "Missing required parameter: modelId"}, status=400) + try: + model_id = int(model_id_str) + except ValueError: + return web.json_response({"success": False, "error": "Parameter modelId must be an integer"}, status=400) + + metadata_provider = await self._metadata_provider_factory() + if not metadata_provider: + return web.json_response({"success": False, "error": "Metadata provider not available"}, status=503) + + response = await metadata_provider.get_model_versions(model_id) + if not response or not response.get("modelVersions"): + return web.json_response({"success": False, "error": "Model not found"}, status=404) + + versions = response.get("modelVersions", []) + model_name = response.get("name", "") + model_type = response.get("type", "").lower() + + scanner = None + normalized_type = None + if model_type in {"lora", "locon", "dora"}: + scanner = await self._service_registry.get_lora_scanner() + normalized_type = "lora" + elif model_type == "checkpoint": + scanner = await self._service_registry.get_checkpoint_scanner() + normalized_type = "checkpoint" + elif model_type == "textualinversion": + scanner = await self._service_registry.get_embedding_scanner() + normalized_type = "embedding" + else: + return web.json_response({"success": False, "error": f'Model type "{model_type}" is not supported'}, status=400) + + if not scanner: + return web.json_response({"success": False, "error": f'Scanner for type "{normalized_type}" is not available'}, status=503) + + local_versions = await scanner.get_model_versions_by_id(model_id) + local_version_ids = {version["versionId"] for version in local_versions} + + enriched_versions = [] + for version in versions: + version_id = version.get("id") + enriched_versions.append( + { + "id": version_id, + "name": version.get("name", ""), + "thumbnailUrl": version.get("images")[0]["url"] if version.get("images") else None, + "inLibrary": version_id in local_version_ids, + } + ) + + return web.json_response( + { + "success": True, + "modelId": model_id, + "modelName": model_name, + "modelType": model_type, + "versions": enriched_versions, + } + ) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to get model versions status: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + +class MetadataArchiveHandler: + def __init__( + self, + *, + metadata_archive_manager_factory: Callable[[], Awaitable[MetadataArchiveManagerProtocol]] = get_metadata_archive_manager, + settings_service=default_settings, + metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers, + ) -> None: + self._metadata_archive_manager_factory = metadata_archive_manager_factory + self._settings = settings_service + self._metadata_provider_updater = metadata_provider_updater + + async def download_metadata_archive(self, request: web.Request) -> web.Response: + try: + archive_manager = await self._metadata_archive_manager_factory() + download_id = request.query.get("download_id") + + def progress_callback(stage: str, message: str) -> None: + data = {"stage": stage, "message": message, "type": "metadata_archive_download"} + if download_id: + asyncio.create_task(ws_manager.broadcast_download_progress(download_id, data)) + else: + asyncio.create_task(ws_manager.broadcast(data)) + + success = await archive_manager.download_and_extract_database(progress_callback) + if success: + self._settings.set("enable_metadata_archive_db", True) + await self._metadata_provider_updater() + return web.json_response({"success": True, "message": "Metadata archive database downloaded and extracted successfully"}) + return web.json_response({"success": False, "error": "Failed to download and extract metadata archive database"}, status=500) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error downloading metadata archive: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def remove_metadata_archive(self, request: web.Request) -> web.Response: + try: + archive_manager = await self._metadata_archive_manager_factory() + success = await archive_manager.remove_database() + if success: + self._settings.set("enable_metadata_archive_db", False) + await self._metadata_provider_updater() + return web.json_response({"success": True, "message": "Metadata archive database removed successfully"}) + return web.json_response({"success": False, "error": "Failed to remove metadata archive database"}, status=500) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error removing metadata archive: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_metadata_archive_status(self, request: web.Request) -> web.Response: + try: + archive_manager = await self._metadata_archive_manager_factory() + is_available = archive_manager.is_database_available() + is_enabled = self._settings.get("enable_metadata_archive_db", False) + db_size = 0 + if is_available: + db_path = archive_manager.get_database_path() + if db_path and os.path.exists(db_path): + db_size = os.path.getsize(db_path) + return web.json_response( + { + "success": True, + "isAvailable": is_available, + "isEnabled": is_enabled, + "databaseSize": db_size, + "databasePath": archive_manager.get_database_path() if is_available else None, + } + ) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error getting metadata archive status: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + +class FileSystemHandler: + async def open_file_location(self, request: web.Request) -> web.Response: + try: + data = await request.json() + file_path = data.get("file_path") + if not file_path: + return web.json_response({"success": False, "error": "Missing file_path parameter"}, status=400) + file_path = os.path.abspath(file_path) + if not os.path.isfile(file_path): + return web.json_response({"success": False, "error": "File does not exist"}, status=404) + + if os.name == "nt": + subprocess.Popen(["explorer", "/select,", file_path]) + elif os.name == "posix": + if sys.platform == "darwin": + subprocess.Popen(["open", "-R", file_path]) + else: + folder = os.path.dirname(file_path) + subprocess.Popen(["xdg-open", folder]) + + return web.json_response({"success": True, "message": f"Opened folder and selected file: {file_path}"}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to open file location: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + +class NodeRegistryHandler: + def __init__( + self, + node_registry: NodeRegistry, + prompt_server: type[PromptServerProtocol], + *, + standalone_mode: bool, + ) -> None: + self._node_registry = node_registry + self._prompt_server = prompt_server + self._standalone_mode = standalone_mode + + async def register_nodes(self, request: web.Request) -> web.Response: + try: + data = await request.json() + nodes = data.get("nodes", []) + if not isinstance(nodes, list): + return web.json_response({"success": False, "error": "nodes must be a list"}, status=400) + for index, node in enumerate(nodes): + if not isinstance(node, dict): + return web.json_response({"success": False, "error": f"Node {index} must be an object"}, status=400) + node_id = node.get("node_id") + if node_id is None: + return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400) + try: + node["node_id"] = int(node_id) + except (TypeError, ValueError): + return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400) + + await self._node_registry.register_nodes(nodes) + return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to register nodes: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_registry(self, request: web.Request) -> web.Response: + try: + if self._standalone_mode: + logger.warning("Registry refresh not available in standalone mode") + return web.json_response( + { + "success": False, + "error": "Standalone Mode Active", + "message": "Cannot interact with ComfyUI in standalone mode.", + }, + status=503, + ) + + try: + self._prompt_server.instance.send_sync("lora_registry_refresh", {}) + logger.debug("Sent registry refresh request to frontend") + except Exception as exc: + logger.error("Failed to send registry refresh message: %s", exc) + return web.json_response( + { + "success": False, + "error": "Communication Error", + "message": f"Failed to communicate with ComfyUI frontend: {exc}", + }, + status=500, + ) + + registry_updated = await self._node_registry.wait_for_update(timeout=1.0) + if not registry_updated: + logger.warning("Registry refresh timeout after 1 second") + return web.json_response( + { + "success": False, + "error": "Timeout Error", + "message": "Registry refresh timeout - ComfyUI frontend may not be responsive", + }, + status=408, + ) + + registry_info = await self._node_registry.get_registry() + return web.json_response({"success": True, "data": registry_info}) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Failed to get registry: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": "Internal Error", "message": str(exc)}, status=500) + + +class MiscHandlerSet: + """Aggregate handlers into a lookup compatible with the registrar.""" + + def __init__( + self, + *, + health: HealthCheckHandler, + settings: SettingsHandler, + usage_stats: UsageStatsHandler, + lora_code: LoraCodeHandler, + trained_words: TrainedWordsHandler, + model_examples: ModelExampleFilesHandler, + node_registry: NodeRegistryHandler, + model_library: ModelLibraryHandler, + metadata_archive: MetadataArchiveHandler, + filesystem: FileSystemHandler, + ) -> None: + self.health = health + self.settings = settings + self.usage_stats = usage_stats + self.lora_code = lora_code + self.trained_words = trained_words + self.model_examples = model_examples + self.node_registry = node_registry + self.model_library = model_library + self.metadata_archive = metadata_archive + self.filesystem = filesystem + + def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]: + return { + "health_check": self.health.health_check, + "get_settings": self.settings.get_settings, + "update_settings": self.settings.update_settings, + "update_usage_stats": self.usage_stats.update_usage_stats, + "get_usage_stats": self.usage_stats.get_usage_stats, + "update_lora_code": self.lora_code.update_lora_code, + "get_trained_words": self.trained_words.get_trained_words, + "get_model_example_files": self.model_examples.get_model_example_files, + "register_nodes": self.node_registry.register_nodes, + "get_registry": self.node_registry.get_registry, + "check_model_exists": self.model_library.check_model_exists, + "download_metadata_archive": self.metadata_archive.download_metadata_archive, + "remove_metadata_archive": self.metadata_archive.remove_metadata_archive, + "get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status, + "get_model_versions_status": self.model_library.get_model_versions_status, + "open_file_location": self.filesystem.open_file_location, + } + + +def build_service_registry_adapter() -> ServiceRegistryAdapter: + return ServiceRegistryAdapter( + get_lora_scanner=ServiceRegistry.get_lora_scanner, + get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner, + get_embedding_scanner=ServiceRegistry.get_embedding_scanner, + ) diff --git a/py/routes/misc_route_registrar.py b/py/routes/misc_route_registrar.py new file mode 100644 index 00000000..22c2f5bc --- /dev/null +++ b/py/routes/misc_route_registrar.py @@ -0,0 +1,67 @@ +"""Route registrar for miscellaneous endpoints. + +This module mirrors the model route registrar architecture so that +miscellaneous endpoints share a consistent registration flow. +""" + +from dataclasses import dataclass +from typing import Callable, Iterable, Mapping + +from aiohttp import web + + +@dataclass(frozen=True) +class RouteDefinition: + """Declarative definition for a HTTP route.""" + + method: str + path: str + handler_name: str + + +MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( + RouteDefinition("GET", "/api/lm/settings", "get_settings"), + RouteDefinition("POST", "/api/lm/settings", "update_settings"), + RouteDefinition("GET", "/api/lm/health-check", "health_check"), + RouteDefinition("POST", "/api/lm/open-file-location", "open_file_location"), + RouteDefinition("POST", "/api/lm/update-usage-stats", "update_usage_stats"), + RouteDefinition("GET", "/api/lm/get-usage-stats", "get_usage_stats"), + RouteDefinition("POST", "/api/lm/update-lora-code", "update_lora_code"), + RouteDefinition("GET", "/api/lm/trained-words", "get_trained_words"), + RouteDefinition("GET", "/api/lm/model-example-files", "get_model_example_files"), + RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"), + RouteDefinition("GET", "/api/lm/get-registry", "get_registry"), + RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"), + RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"), + RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"), + RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"), + RouteDefinition("GET", "/api/lm/model-versions-status", "get_model_versions_status"), +) + + +class MiscRouteRegistrar: + """Bind miscellaneous route definitions to an aiohttp router.""" + + _METHOD_MAP = { + "GET": "add_get", + "POST": "add_post", + "PUT": "add_put", + "DELETE": "add_delete", + } + + def __init__(self, app: web.Application) -> None: + self._app = app + + def register_routes( + self, + handler_lookup: Mapping[str, Callable[[web.Request], object]], + *, + definitions: Iterable[RouteDefinition] = MISC_ROUTE_DEFINITIONS, + ) -> None: + for definition in definitions: + self._bind(definition.method, definition.path, handler_lookup[definition.handler_name]) + + def _bind(self, method: str, path: str, handler: Callable) -> None: + add_method_name = self._METHOD_MAP[method.upper()] + add_method = getattr(self._app.router, add_method_name) + add_method(path, handler) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 9d1c5782..de40087c 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -1,1058 +1,135 @@ +"""Route controller for miscellaneous endpoints.""" + +from __future__ import annotations + import logging import os -import sys -import threading -import asyncio -import subprocess -import re -from server import PromptServer # type: ignore +from typing import Awaitable, Callable, Mapping + from aiohttp import web +from server import PromptServer # type: ignore + +from ..services.metadata_service import ( + get_metadata_archive_manager, + get_metadata_provider, + update_metadata_providers, +) from ..services.settings_manager import settings -from ..utils.usage_stats import UsageStats -from ..utils.lora_metadata import extract_trained_words -from ..config import config -from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR -from ..services.service_registry import ServiceRegistry -from ..services.metadata_service import get_metadata_archive_manager, update_metadata_providers, get_metadata_provider -from ..services.websocket_manager import ws_manager from ..services.downloader import get_downloader +from ..utils.usage_stats import UsageStats +from .handlers.misc_handlers import ( + FileSystemHandler, + HealthCheckHandler, + LoraCodeHandler, + MetadataArchiveHandler, + MiscHandlerSet, + ModelExampleFilesHandler, + ModelLibraryHandler, + NodeRegistry, + NodeRegistryHandler, + SettingsHandler, + TrainedWordsHandler, + UsageStatsHandler, + build_service_registry_adapter, +) +from .misc_route_registrar import MiscRouteRegistrar + logger = logging.getLogger(__name__) -standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" +standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get( + "HF_HUB_DISABLE_TELEMETRY", "0" +) == "0" -# Node registry for tracking active workflow nodes -class NodeRegistry: - """Thread-safe registry for tracking Lora nodes in active workflows""" - - def __init__(self): - self._lock = threading.RLock() - self._nodes = {} # node_id -> node_info - self._registry_updated = threading.Event() - - def register_nodes(self, nodes): - """Register multiple nodes at once, replacing existing registry""" - with self._lock: - # Clear existing registry - self._nodes.clear() - - # Register all new nodes - for node in nodes: - node_id = node['node_id'] - node_type = node.get('type', '') - - # Convert node type name to integer - type_id = NODE_TYPES.get(node_type, 0) # 0 for unknown types - - # Handle null bgcolor with default color - bgcolor = node.get('bgcolor') - if bgcolor is None: - bgcolor = DEFAULT_NODE_COLOR - - self._nodes[node_id] = { - 'id': node_id, - 'bgcolor': bgcolor, - 'title': node.get('title'), - 'type': type_id, - 'type_name': node_type - } - - logger.debug(f"Registered {len(nodes)} nodes in registry") - - # Signal that registry has been updated - self._registry_updated.set() - - def get_registry(self): - """Get current registry information""" - with self._lock: - return { - 'nodes': dict(self._nodes), # Return a copy - 'node_count': len(self._nodes) - } - - def clear_registry(self): - """Clear the entire registry""" - with self._lock: - self._nodes.clear() - logger.info("Node registry cleared") - - def wait_for_update(self, timeout=1.0): - """Wait for registry update with timeout""" - self._registry_updated.clear() - return self._registry_updated.wait(timeout) - -# Global registry instance -node_registry = NodeRegistry() class MiscRoutes: - """Miscellaneous routes for various utility functions""" - - @staticmethod - def is_dedicated_example_images_folder(folder_path): - """ - Check if a folder is a dedicated example images folder. - - A dedicated folder should either be: - 1. Empty - 2. Only contain .download_progress.json file and/or folders with valid SHA256 hash names (64 hex characters) - - Args: - folder_path (str): Path to the folder to check - - Returns: - bool: True if the folder is dedicated, False otherwise - """ - try: - if not os.path.exists(folder_path) or not os.path.isdir(folder_path): - return False - - items = os.listdir(folder_path) - - # Empty folder is considered dedicated - if not items: - return True - - # Check each item in the folder - for item in items: - item_path = os.path.join(folder_path, item) - - # Allow .download_progress.json file - if item == '.download_progress.json' and os.path.isfile(item_path): - continue - - # Allow folders with valid SHA256 hash names (64 hex characters) - if os.path.isdir(item_path): - # Check if the folder name is a valid SHA256 hash - if re.match(r'^[a-fA-F0-9]{64}$', item): - continue - - # If we encounter anything else, it's not a dedicated folder - return False - - return True - - except Exception as e: - logger.error(f"Error checking if folder is dedicated: {e}") - return False - - @staticmethod - def setup_routes(app): - """Register miscellaneous routes""" - app.router.add_get('/api/lm/settings', MiscRoutes.get_settings) - app.router.add_post('/api/lm/settings', MiscRoutes.update_settings) + """Route controller that mirrors the model route architecture.""" - app.router.add_get('/api/lm/health-check', lambda request: web.json_response({'status': 'ok'})) + def __init__( + self, + *, + settings_service=settings, + usage_stats_factory: Callable[[], UsageStats] = UsageStats, + prompt_server: type[PromptServer] = PromptServer, + service_registry_adapter=build_service_registry_adapter(), + metadata_provider_factory=get_metadata_provider, + metadata_archive_manager_factory=get_metadata_archive_manager, + metadata_provider_updater=update_metadata_providers, + downloader_factory=get_downloader, + registrar_factory=MiscRouteRegistrar, + handler_set_factory=MiscHandlerSet, + node_registry: NodeRegistry | None = None, + standalone_mode_flag: bool = standalone_mode, + ) -> None: + self._settings = settings_service + self._usage_stats_factory = usage_stats_factory + self._prompt_server = prompt_server + self._service_registry_adapter = service_registry_adapter + self._metadata_provider_factory = metadata_provider_factory + self._metadata_archive_manager_factory = metadata_archive_manager_factory + self._metadata_provider_updater = metadata_provider_updater + self._downloader_factory = downloader_factory + self._registrar_factory = registrar_factory + self._handler_set_factory = handler_set_factory + self._node_registry = node_registry or NodeRegistry() + self._standalone_mode = standalone_mode_flag - app.router.add_post('/api/lm/open-file-location', MiscRoutes.open_file_location) - - # Usage stats routes - app.router.add_post('/api/lm/update-usage-stats', MiscRoutes.update_usage_stats) - app.router.add_get('/api/lm/get-usage-stats', MiscRoutes.get_usage_stats) - - # Lora code update endpoint - app.router.add_post('/api/lm/update-lora-code', MiscRoutes.update_lora_code) - - # Add new route for getting trained words - app.router.add_get('/api/lm/trained-words', MiscRoutes.get_trained_words) - - # Add new route for getting model example files - app.router.add_get('/api/lm/model-example-files', MiscRoutes.get_model_example_files) - - # Node registry endpoints - app.router.add_post('/api/lm/register-nodes', MiscRoutes.register_nodes) - app.router.add_get('/api/lm/get-registry', MiscRoutes.get_registry) - - # Add new route for checking if a model exists in the library - app.router.add_get('/api/lm/check-model-exists', MiscRoutes.check_model_exists) - - # Add routes for metadata archive database management - app.router.add_post('/api/lm/download-metadata-archive', MiscRoutes.download_metadata_archive) - app.router.add_post('/api/lm/remove-metadata-archive', MiscRoutes.remove_metadata_archive) - app.router.add_get('/api/lm/metadata-archive-status', MiscRoutes.get_metadata_archive_status) - - # Add route for checking model versions in library - app.router.add_get('/api/lm/model-versions-status', MiscRoutes.get_model_versions_status) + self._handler_mapping: Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]] | None = None @staticmethod - async def get_settings(request): - """Get application settings that should be synced to frontend""" - try: - # Define keys that should be synced from backend to frontend - sync_keys = [ - 'civitai_api_key', - 'default_lora_root', - 'default_checkpoint_root', - 'default_embedding_root', - 'base_model_path_mappings', - 'download_path_templates', - 'enable_metadata_archive_db', - 'language', - 'proxy_enabled', - 'proxy_type', - 'proxy_host', - 'proxy_port', - 'proxy_username', - 'proxy_password', - 'example_images_path', - 'optimize_example_images', - 'auto_download_example_images', - 'blur_mature_content', - 'autoplay_on_hover', - 'display_density', - 'card_info_display', - 'include_trigger_words', - 'show_only_sfw', - 'compact_mode' - ] - - # Build response with only the keys that should be synced - response_data = {} - for key in sync_keys: - value = settings.get(key) - if value is not None: - response_data[key] = value - - return web.json_response({ - 'success': True, - 'settings': response_data - }) - - except Exception as e: - logger.error(f"Error getting settings: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) + def setup_routes(app: web.Application) -> None: + """Entry point used by the application bootstrap.""" + controller = MiscRoutes() + controller.bind(app) - @staticmethod - async def update_settings(request): - """Update application settings""" - try: - data = await request.json() - proxy_keys = {'proxy_enabled', 'proxy_host', 'proxy_port', 'proxy_username', 'proxy_password', 'proxy_type'} - proxy_changed = False - - # Validate and update settings - for key, value in data.items(): - if value == settings.get(key): - # No change, skip - continue - # Special handling for example_images_path - verify path exists and is dedicated - if key == 'example_images_path' and value: - if not os.path.exists(value): - return web.json_response({ - 'success': False, - 'error': f"Path does not exist: {value}" - }) - - # Check if folder is dedicated for example images - if not MiscRoutes.is_dedicated_example_images_folder(value): - return web.json_response({ - 'success': False, - 'error': "Please set a dedicated folder for example images." - }) - - # Path changed - server restart required for new path to take effect - old_path = settings.get('example_images_path') - if old_path != value: - logger.info(f"Example images path changed to {value} - server restart required") + def bind(self, app: web.Application) -> None: + registrar = self._registrar_factory(app) + registrar.register_routes(self._ensure_handler_mapping()) - # Handle deletion for proxy credentials - if value == '__DELETE__' and key in ('proxy_username', 'proxy_password'): - settings.delete(key) - else: - # Save to settings - settings.set(key, value) - - if key == 'enable_metadata_archive_db': - await update_metadata_providers() - - if key in proxy_keys: - proxy_changed = True + def _ensure_handler_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]: + if self._handler_mapping is None: + handler_set = self._create_handler_set() + self._handler_mapping = handler_set.to_route_mapping() + return self._handler_mapping - if proxy_changed: - downloader = await get_downloader() - await downloader.refresh_session() + def _create_handler_set(self) -> MiscHandlerSet: + health = HealthCheckHandler() + settings_handler = SettingsHandler( + settings_service=self._settings, + metadata_provider_updater=self._metadata_provider_updater, + downloader_factory=self._downloader_factory, + ) + usage_stats = UsageStatsHandler(usage_stats_factory=self._usage_stats_factory) + lora_code = LoraCodeHandler(prompt_server=self._prompt_server) + trained_words = TrainedWordsHandler() + model_examples = ModelExampleFilesHandler() + metadata_archive = MetadataArchiveHandler( + metadata_archive_manager_factory=self._metadata_archive_manager_factory, + settings_service=self._settings, + metadata_provider_updater=self._metadata_provider_updater, + ) + filesystem = FileSystemHandler() + node_registry_handler = NodeRegistryHandler( + node_registry=self._node_registry, + prompt_server=self._prompt_server, + standalone_mode=self._standalone_mode, + ) + model_library = ModelLibraryHandler( + service_registry=self._service_registry_adapter, + metadata_provider_factory=self._metadata_provider_factory, + ) - return web.json_response({'success': True}) - except Exception as e: - logger.error(f"Error updating settings: {e}", exc_info=True) - return web.Response(status=500, text=str(e)) - - @staticmethod - async def update_usage_stats(request): - """ - Update usage statistics based on a prompt_id - - Expects a JSON body with: - { - "prompt_id": "string" - } - """ - try: - # Parse the request body - data = await request.json() - prompt_id = data.get('prompt_id') - - if not prompt_id: - return web.json_response({ - 'success': False, - 'error': 'Missing prompt_id' - }, status=400) - - # Call the UsageStats to process this prompt_id synchronously - usage_stats = UsageStats() - await usage_stats.process_execution(prompt_id) - - return web.json_response({ - 'success': True - }) - - except Exception as e: - logger.error(f"Failed to update usage stats: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - @staticmethod - async def get_usage_stats(request): - """Get current usage statistics""" - try: - usage_stats = UsageStats() - stats = await usage_stats.get_stats() - - # Add version information to help clients handle format changes - stats_response = { - 'success': True, - 'data': stats, - 'format_version': 2 # Indicate this is the new format with history - } - - return web.json_response(stats_response) - - except Exception as e: - logger.error(f"Failed to get usage stats: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - @staticmethod - async def update_lora_code(request): - """ - Update Lora code in ComfyUI nodes - - Expects a JSON body with: - { - "node_ids": [123, 456], # Optional - List of node IDs to update (for browser mode) - "lora_code": "", # The Lora code to send - "mode": "append" # or "replace" - whether to append or replace existing code - } - """ - try: - # Parse the request body - data = await request.json() - node_ids = data.get('node_ids') - lora_code = data.get('lora_code', '') - mode = data.get('mode', 'append') - - if not lora_code: - return web.json_response({ - 'success': False, - 'error': 'Missing lora_code parameter' - }, status=400) - - results = [] - - # Desktop mode: no specific node_ids provided - if node_ids is None: - try: - # Send broadcast message with id=-1 to all Lora Loader nodes - PromptServer.instance.send_sync("lora_code_update", { - "id": -1, - "lora_code": lora_code, - "mode": mode - }) - results.append({ - 'node_id': 'broadcast', - 'success': True - }) - except Exception as e: - logger.error(f"Error broadcasting lora code: {e}") - results.append({ - 'node_id': 'broadcast', - 'success': False, - 'error': str(e) - }) - else: - # Browser mode: send to specific nodes - for node_id in node_ids: - try: - # Send the message to the frontend - PromptServer.instance.send_sync("lora_code_update", { - "id": node_id, - "lora_code": lora_code, - "mode": mode - }) - results.append({ - 'node_id': node_id, - 'success': True - }) - except Exception as e: - logger.error(f"Error sending lora code to node {node_id}: {e}") - results.append({ - 'node_id': node_id, - 'success': False, - 'error': str(e) - }) - - return web.json_response({ - 'success': True, - 'results': results - }) - - except Exception as e: - logger.error(f"Failed to update lora code: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) + return self._handler_set_factory( + health=health, + settings=settings_handler, + usage_stats=usage_stats, + lora_code=lora_code, + trained_words=trained_words, + model_examples=model_examples, + node_registry=node_registry_handler, + model_library=model_library, + metadata_archive=metadata_archive, + filesystem=filesystem, + ) - @staticmethod - async def get_trained_words(request): - """ - Get trained words from a safetensors file, sorted by frequency - - Expects a query parameter: - file_path: Path to the safetensors file - """ - try: - # Get file path from query parameters - file_path = request.query.get('file_path') - - if not file_path: - return web.json_response({ - 'success': False, - 'error': 'Missing file_path parameter' - }, status=400) - - # Check if file exists and is a safetensors file - if not os.path.exists(file_path): - return web.json_response({ - 'success': False, - 'error': f"File not found: {file_path}" - }, status=404) - - if not file_path.lower().endswith('.safetensors'): - return web.json_response({ - 'success': False, - 'error': 'File is not a safetensors file' - }, status=400) - - # Extract trained words and class_tokens - trained_words, class_tokens = await extract_trained_words(file_path) - - # Return result with both trained words and class tokens - return web.json_response({ - 'success': True, - 'trained_words': trained_words, - 'class_tokens': class_tokens - }) - - except Exception as e: - logger.error(f"Failed to get trained words: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - @staticmethod - async def get_model_example_files(request): - """ - Get list of example image files for a specific model based on file path - - Expects: - - file_path in query parameters - - Returns: - - List of image files with their paths as static URLs - """ - try: - # Get the model file path from query parameters - file_path = request.query.get('file_path') - - if not file_path: - return web.json_response({ - 'success': False, - 'error': 'Missing file_path parameter' - }, status=400) - - # Extract directory and base filename - model_dir = os.path.dirname(file_path) - model_filename = os.path.basename(file_path) - model_name = os.path.splitext(model_filename)[0] - - # Check if the directory exists - if not os.path.exists(model_dir): - return web.json_response({ - 'success': False, - 'error': 'Model directory not found', - 'files': [] - }, status=404) - - # Look for files matching the pattern modelname.example.. - files = [] - pattern = f"{model_name}.example." - - for file in os.listdir(model_dir): - file_lower = file.lower() - if file_lower.startswith(pattern.lower()): - file_full_path = os.path.join(model_dir, file) - if os.path.isfile(file_full_path): - # Check if the file is a supported media file - file_ext = os.path.splitext(file)[1].lower() - if (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or - file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']): - - # Extract the index from the filename - try: - # Extract the part after '.example.' and before file extension - index_part = file[len(pattern):].split('.')[0] - # Try to parse it as an integer - index = int(index_part) - except (ValueError, IndexError): - # If we can't parse the index, use infinity to sort at the end - index = float('inf') - - # Convert file path to static URL - static_url = config.get_preview_static_url(file_full_path) - - files.append({ - 'name': file, - 'path': static_url, - 'extension': file_ext, - 'is_video': file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'], - 'index': index - }) - - # Sort files by their index for consistent ordering - files.sort(key=lambda x: x['index']) - # Remove the index field as it's only used for sorting - for file in files: - file.pop('index', None) - - return web.json_response({ - 'success': True, - 'files': files - }) - - except Exception as e: - logger.error(f"Failed to get model example files: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - @staticmethod - async def register_nodes(request): - """ - Register multiple Lora nodes at once - - Expects a JSON body with: - { - "nodes": [ - { - "node_id": 123, - "bgcolor": "#535", - "title": "Lora Loader (LoraManager)" - }, - ... - ] - } - """ - try: - data = await request.json() - - # Validate required fields - nodes = data.get('nodes', []) - - if not isinstance(nodes, list): - return web.json_response({ - 'success': False, - 'error': 'nodes must be a list' - }, status=400) - - # Validate each node - for i, node in enumerate(nodes): - if not isinstance(node, dict): - return web.json_response({ - 'success': False, - 'error': f'Node {i} must be an object' - }, status=400) - - node_id = node.get('node_id') - if node_id is None: - return web.json_response({ - 'success': False, - 'error': f'Node {i} missing node_id parameter' - }, status=400) - - # Validate node_id is an integer - try: - node['node_id'] = int(node_id) - except (ValueError, TypeError): - return web.json_response({ - 'success': False, - 'error': f'Node {i} node_id must be an integer' - }, status=400) - - # Register all nodes - node_registry.register_nodes(nodes) - - return web.json_response({ - 'success': True, - 'message': f'{len(nodes)} nodes registered successfully' - }) - - except Exception as e: - logger.error(f"Failed to register nodes: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - @staticmethod - async def get_registry(request): - """Get current node registry information by refreshing from frontend""" - try: - # Check if running in standalone mode - if standalone_mode: - logger.warning("Registry refresh not available in standalone mode") - return web.json_response({ - 'success': False, - 'error': 'Standalone Mode Active', - 'message': 'Cannot interact with ComfyUI in standalone mode.' - }, status=503) - - # Send message to frontend to refresh registry - try: - PromptServer.instance.send_sync("lora_registry_refresh", {}) - logger.debug("Sent registry refresh request to frontend") - except Exception as e: - logger.error(f"Failed to send registry refresh message: {e}") - return web.json_response({ - 'success': False, - 'error': 'Communication Error', - 'message': f'Failed to communicate with ComfyUI frontend: {str(e)}' - }, status=500) - - # Wait for registry update with timeout - def wait_for_registry(): - return node_registry.wait_for_update(timeout=1.0) - - # Run the wait in a thread to avoid blocking the event loop - loop = asyncio.get_event_loop() - registry_updated = await loop.run_in_executor(None, wait_for_registry) - - if not registry_updated: - logger.warning("Registry refresh timeout after 1 second") - return web.json_response({ - 'success': False, - 'error': 'Timeout Error', - 'message': 'Registry refresh timeout - ComfyUI frontend may not be responsive' - }, status=408) - - # Get updated registry - registry_info = node_registry.get_registry() - - return web.json_response({ - 'success': True, - 'data': registry_info - }) - - except Exception as e: - logger.error(f"Failed to get registry: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': 'Internal Error', - 'message': str(e) - }, status=500) - - @staticmethod - async def check_model_exists(request): - """ - Check if a model with specified modelId and optionally modelVersionId exists in the library - - Expects query parameters: - - modelId: int - Civitai model ID (required) - - modelVersionId: int - Civitai model version ID (optional) - - Returns: - - If modelVersionId is provided: JSON with a boolean 'exists' field - - If modelVersionId is not provided: JSON with a list of modelVersionIds that exist in the library - """ - try: - # Get the modelId and modelVersionId from query parameters - model_id_str = request.query.get('modelId') - model_version_id_str = request.query.get('modelVersionId') - - # Validate modelId parameter (required) - if not model_id_str: - return web.json_response({ - 'success': False, - 'error': 'Missing required parameter: modelId' - }, status=400) - - try: - # Convert modelId to integer - model_id = int(model_id_str) - except ValueError: - return web.json_response({ - 'success': False, - 'error': 'Parameter modelId must be an integer' - }, status=400) - - # Get all scanners - lora_scanner = await ServiceRegistry.get_lora_scanner() - checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() - embedding_scanner = await ServiceRegistry.get_embedding_scanner() - - # If modelVersionId is provided, check for specific version - if model_version_id_str: - try: - model_version_id = int(model_version_id_str) - except ValueError: - return web.json_response({ - 'success': False, - 'error': 'Parameter modelVersionId must be an integer' - }, status=400) - - # Check lora scanner first - exists = False - model_type = None - - if await lora_scanner.check_model_version_exists(model_version_id): - exists = True - model_type = 'lora' - elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_version_id): - exists = True - model_type = 'checkpoint' - elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_version_id): - exists = True - model_type = 'embedding' - - return web.json_response({ - 'success': True, - 'exists': exists, - 'modelType': model_type if exists else None - }) - - # If modelVersionId is not provided, return all version IDs for the model - else: - lora_versions = await lora_scanner.get_model_versions_by_id(model_id) - checkpoint_versions = [] - embedding_versions = [] - - # 优先lora,其次checkpoint,最后embedding - if not lora_versions: - checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id) - if not lora_versions and not checkpoint_versions: - embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id) - - model_type = None - versions = [] - - if lora_versions: - model_type = 'lora' - versions = lora_versions - elif checkpoint_versions: - model_type = 'checkpoint' - versions = checkpoint_versions - elif embedding_versions: - model_type = 'embedding' - versions = embedding_versions - - return web.json_response({ - 'success': True, - 'modelId': model_id, - 'modelType': model_type, - 'versions': versions - }) - - except Exception as e: - logger.error(f"Failed to check model existence: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - @staticmethod - async def download_metadata_archive(request): - """Download and extract the metadata archive database""" - try: - archive_manager = await get_metadata_archive_manager() - - # Get the download_id from query parameters if provided - download_id = request.query.get('download_id') - - # Progress callback to send updates via WebSocket - def progress_callback(stage, message): - data = { - 'stage': stage, - 'message': message, - 'type': 'metadata_archive_download' - } - - if download_id: - # Send to specific download WebSocket if download_id is provided - asyncio.create_task(ws_manager.broadcast_download_progress(download_id, data)) - else: - # Fallback to general broadcast - asyncio.create_task(ws_manager.broadcast(data)) - - # Download and extract in background - success = await archive_manager.download_and_extract_database(progress_callback) - - if success: - # Update settings to enable metadata archive - settings.set('enable_metadata_archive_db', True) - - # Update metadata providers - await update_metadata_providers() - - return web.json_response({ - 'success': True, - 'message': 'Metadata archive database downloaded and extracted successfully' - }) - else: - return web.json_response({ - 'success': False, - 'error': 'Failed to download and extract metadata archive database' - }, status=500) - - except Exception as e: - logger.error(f"Error downloading metadata archive: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - @staticmethod - async def remove_metadata_archive(request): - """Remove the metadata archive database""" - try: - archive_manager = await get_metadata_archive_manager() - - success = await archive_manager.remove_database() - - if success: - # Update settings to disable metadata archive - settings.set('enable_metadata_archive_db', False) - - # Update metadata providers - await update_metadata_providers() - - return web.json_response({ - 'success': True, - 'message': 'Metadata archive database removed successfully' - }) - else: - return web.json_response({ - 'success': False, - 'error': 'Failed to remove metadata archive database' - }, status=500) - - except Exception as e: - logger.error(f"Error removing metadata archive: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - @staticmethod - async def get_metadata_archive_status(request): - """Get the status of metadata archive database""" - try: - archive_manager = await get_metadata_archive_manager() - - is_available = archive_manager.is_database_available() - is_enabled = settings.get('enable_metadata_archive_db', False) - - db_size = 0 - if is_available: - db_path = archive_manager.get_database_path() - if db_path and os.path.exists(db_path): - db_size = os.path.getsize(db_path) - - return web.json_response({ - 'success': True, - 'isAvailable': is_available, - 'isEnabled': is_enabled, - 'databaseSize': db_size, - 'databasePath': archive_manager.get_database_path() if is_available else None - }) - - except Exception as e: - logger.error(f"Error getting metadata archive status: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - @staticmethod - async def get_model_versions_status(request): - """ - Get all versions of a model from metadata provider and check their library status - - Expects query parameters: - - modelId: int - Civitai model ID (required) - - Returns: - - JSON with model type and versions list, each version includes 'inLibrary' flag - """ - try: - # Get the modelId from query parameters - model_id_str = request.query.get('modelId') - - # Validate modelId parameter (required) - if not model_id_str: - return web.json_response({ - 'success': False, - 'error': 'Missing required parameter: modelId' - }, status=400) - - try: - # Convert modelId to integer - model_id = int(model_id_str) - except ValueError: - return web.json_response({ - 'success': False, - 'error': 'Parameter modelId must be an integer' - }, status=400) - - # Get metadata provider - metadata_provider = await get_metadata_provider() - if not metadata_provider: - return web.json_response({ - 'success': False, - 'error': 'Metadata provider not available' - }, status=503) - - # Get model versions from metadata provider - response = await metadata_provider.get_model_versions(model_id) - if not response or not response.get('modelVersions'): - return web.json_response({ - 'success': False, - 'error': 'Model not found' - }, status=404) - - versions = response.get('modelVersions', []) - model_name = response.get('name', '') - model_type = response.get('type', '').lower() - - # Determine scanner based on model type - scanner = None - normalized_type = None - - if model_type in ['lora', 'locon', 'dora']: - scanner = await ServiceRegistry.get_lora_scanner() - normalized_type = 'lora' - elif model_type == 'checkpoint': - scanner = await ServiceRegistry.get_checkpoint_scanner() - normalized_type = 'checkpoint' - elif model_type == 'textualinversion': - scanner = await ServiceRegistry.get_embedding_scanner() - normalized_type = 'embedding' - else: - return web.json_response({ - 'success': False, - 'error': f'Model type "{model_type}" is not supported' - }, status=400) - - if not scanner: - return web.json_response({ - 'success': False, - 'error': f'Scanner for type "{normalized_type}" is not available' - }, status=503) - - # Get local versions from scanner - local_versions = await scanner.get_model_versions_by_id(model_id) - local_version_ids = set(version['versionId'] for version in local_versions) - - # Add inLibrary flag to each version - enriched_versions = [] - for version in versions: - version_id = version.get('id') - enriched_version = { - 'id': version_id, - 'name': version.get('name', ''), - 'thumbnailUrl': version.get('images')[0]['url'] if version.get('images') else None, - 'inLibrary': version_id in local_version_ids - } - enriched_versions.append(enriched_version) - - return web.json_response({ - 'success': True, - 'modelId': model_id, - 'modelName': model_name, - 'modelType': model_type, - 'versions': enriched_versions - }) - - except Exception as e: - logger.error(f"Failed to get model versions status: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - @staticmethod - async def open_file_location(request): - """ - Open the folder containing the specified file and select the file in the file explorer. - - Expects a JSON request body with: - { - "file_path": "absolute/path/to/file" - } - """ - try: - data = await request.json() - file_path = data.get('file_path') - - if not file_path: - return web.json_response({ - 'success': False, - 'error': 'Missing file_path parameter' - }, status=400) - - file_path = os.path.abspath(file_path) - - if not os.path.isfile(file_path): - return web.json_response({ - 'success': False, - 'error': 'File does not exist' - }, status=404) - - # Open the folder and select the file - if os.name == 'nt': # Windows - # explorer /select,"C:\path\to\file" - subprocess.Popen(['explorer', '/select,', file_path]) - elif os.name == 'posix': - if sys.platform == 'darwin': # macOS - subprocess.Popen(['open', '-R', file_path]) - else: # Linux (selecting file is not standard, just open folder) - folder = os.path.dirname(file_path) - subprocess.Popen(['xdg-open', folder]) - - return web.json_response({ - 'success': True, - 'message': f'Opened folder and selected file: {file_path}' - }) - - except Exception as e: - logger.error(f"Failed to open file location: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) +__all__ = ["MiscRoutes"] diff --git a/tests/routes/test_misc_routes.py b/tests/routes/test_misc_routes.py new file mode 100644 index 00000000..87da2f5f --- /dev/null +++ b/tests/routes/test_misc_routes.py @@ -0,0 +1,211 @@ +import json +from types import SimpleNamespace + +import pytest +from aiohttp import web + +from py.routes.handlers.misc_handlers import SettingsHandler, ServiceRegistryAdapter +from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar +from py.routes.misc_routes import MiscRoutes + + +class FakeRequest: + def __init__(self, *, json_data=None, query=None): + self._json_data = json_data or {} + self.query = query or {} + + async def json(self): + return self._json_data + + +class DummySettings: + def __init__(self, data=None): + self.data = data or {} + + def get(self, key, default=None): + return self.data.get(key, default) + + def set(self, key, value): + self.data[key] = value + + def delete(self, key): + self.data.pop(key, None) + + +class DummyDownloader: + def __init__(self): + self.refreshed = False + + async def refresh_session(self): + self.refreshed = True + + +async def noop_async(*_args, **_kwargs): + return None + + +async def dummy_downloader_factory(): + return DummyDownloader() + + +@pytest.mark.asyncio +async def test_get_settings_filters_sync_keys(): + settings_service = DummySettings({"civitai_api_key": "abc", "extraneous": "value"}) + handler = SettingsHandler( + settings_service=settings_service, + metadata_provider_updater=noop_async, + downloader_factory=dummy_downloader_factory, + ) + + response = await handler.get_settings(FakeRequest()) + payload = json.loads(response.text) + + assert payload["success"] is True + assert payload["settings"] == {"civitai_api_key": "abc"} + + +@pytest.mark.asyncio +async def test_update_settings_rejects_missing_example_path(tmp_path): + settings_service = DummySettings() + handler = SettingsHandler( + settings_service=settings_service, + metadata_provider_updater=noop_async, + downloader_factory=dummy_downloader_factory, + ) + + missing_path = tmp_path / "does-not-exist" + request = FakeRequest(json_data={"example_images_path": str(missing_path)}) + + response = await handler.update_settings(request) + payload = json.loads(response.text) + + assert payload["success"] is False + assert "Path does not exist" in payload["error"] + + +class RecordingRouter: + def __init__(self): + self.calls = [] + + def add_get(self, path, handler): + self.calls.append(("GET", path, handler)) + + def add_post(self, path, handler): + self.calls.append(("POST", path, handler)) + + def add_put(self, path, handler): + self.calls.append(("PUT", path, handler)) + + def add_delete(self, path, handler): + self.calls.append(("DELETE", path, handler)) + + +def test_misc_route_registrar_registers_all_routes(): + app = SimpleNamespace(router=RecordingRouter()) + registrar = MiscRouteRegistrar(app) # type: ignore[arg-type] + + async def dummy_handler(_request): + return web.Response() + + handler_mapping = {definition.handler_name: dummy_handler for definition in MISC_ROUTE_DEFINITIONS} + registrar.register_routes(handler_mapping) + + registered = {(method, path) for method, path, _ in app.router.calls} + expected = {(definition.method, definition.path) for definition in MISC_ROUTE_DEFINITIONS} + + assert registered == expected + + +class FakePromptServer: + sent = [] + + class Instance: + def send_sync(self, event, payload): + FakePromptServer.sent.append((event, payload)) + + instance = Instance() + + +class FakeScanner: + async def check_model_version_exists(self, _version_id): + return False + + async def get_model_versions_by_id(self, _model_id): + return [] + + +async def fake_scanner_factory(): + return FakeScanner() + + +class FakeMetadataProvider: + async def get_model_versions(self, _model_id): + return {"modelVersions": [], "name": "", "type": "lora"} + + +async def fake_metadata_provider_factory(): + return FakeMetadataProvider() + + +class FakeMetadataArchiveManager: + async def download_and_extract_database(self, _callback): + return True + + async def remove_database(self): + return True + + def is_database_available(self): + return False + + def get_database_path(self): + return None + + +async def fake_metadata_archive_manager_factory(): + return FakeMetadataArchiveManager() + + +class RecordingRegistrar: + def __init__(self, _app): + self.registered_mapping = None + + def register_routes(self, mapping): + self.registered_mapping = mapping + + +@pytest.mark.asyncio +async def test_misc_routes_bind_produces_expected_handlers(): + service_registry_adapter = ServiceRegistryAdapter( + get_lora_scanner=fake_scanner_factory, + get_checkpoint_scanner=fake_scanner_factory, + get_embedding_scanner=fake_scanner_factory, + ) + + recorded_registrars = [] + + def registrar_factory(app): + registrar = RecordingRegistrar(app) + recorded_registrars.append(registrar) + return registrar + + controller = MiscRoutes( + settings_service=DummySettings(), + usage_stats_factory=lambda: SimpleNamespace(process_execution=noop_async, get_stats=noop_async), + prompt_server=FakePromptServer, + service_registry_adapter=service_registry_adapter, + metadata_provider_factory=fake_metadata_provider_factory, + metadata_archive_manager_factory=fake_metadata_archive_manager_factory, + metadata_provider_updater=noop_async, + downloader_factory=dummy_downloader_factory, + registrar_factory=registrar_factory, + ) + + app = SimpleNamespace(router=RecordingRouter()) + controller.bind(app) # type: ignore[arg-type] + + assert recorded_registrars, "Expected registrar to be created" + mapping = recorded_registrars[0].registered_mapping + assert mapping is not None + + expected_names = {definition.handler_name for definition in MISC_ROUTE_DEFINITIONS} + assert set(mapping.keys()) == expected_names