mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
refactor(routes): modularize misc route handling
This commit is contained in:
750
py/routes/handlers/misc_handlers.py
Normal file
750
py/routes/handlers/misc_handlers.py
Normal file
@@ -0,0 +1,750 @@
|
||||
"""Handlers for miscellaneous routes.
|
||||
|
||||
The legacy :mod:`py.routes.misc_routes` module bundled HTTP wiring and
|
||||
business logic in a single class. This module mirrors the model route
|
||||
architecture by splitting the responsibilities into dedicated handler
|
||||
objects that can be composed by the route controller.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Dict, Mapping, Protocol
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...config import config
|
||||
from ...services.metadata_service import (
|
||||
get_metadata_archive_manager,
|
||||
get_metadata_provider,
|
||||
update_metadata_providers,
|
||||
)
|
||||
from ...services.service_registry import ServiceRegistry
|
||||
from ...services.settings_manager import settings as default_settings
|
||||
from ...services.websocket_manager import ws_manager
|
||||
from ...services.downloader import get_downloader
|
||||
from ...utils.constants import DEFAULT_NODE_COLOR, NODE_TYPES, SUPPORTED_MEDIA_EXTENSIONS
|
||||
from ...utils.lora_metadata import extract_trained_words
|
||||
from ...utils.usage_stats import UsageStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptServerProtocol(Protocol):
|
||||
"""Subset of PromptServer used by the handlers."""
|
||||
|
||||
instance: "PromptServerProtocol"
|
||||
|
||||
def send_sync(self, event: str, payload: dict) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class DownloaderProtocol(Protocol):
|
||||
async def refresh_session(self) -> None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class UsageStatsFactory(Protocol):
|
||||
def __call__(self) -> UsageStats: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class MetadataProviderProtocol(Protocol):
|
||||
async def get_model_versions(self, model_id: int) -> dict | None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class MetadataArchiveManagerProtocol(Protocol):
|
||||
async def download_and_extract_database(
|
||||
self, progress_callback: Callable[[str, str], None]
|
||||
) -> bool: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
async def remove_database(self) -> bool: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def is_database_available(self) -> bool: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
def get_database_path(self) -> str | None: # pragma: no cover - protocol
|
||||
...
|
||||
|
||||
|
||||
class NodeRegistry:
|
||||
"""Thread-safe registry for tracking LoRA nodes in active workflows."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = asyncio.Lock()
|
||||
self._nodes: Dict[int, dict] = {}
|
||||
self._registry_updated = asyncio.Event()
|
||||
|
||||
async def register_nodes(self, nodes: list[dict]) -> None:
|
||||
async with self._lock:
|
||||
self._nodes.clear()
|
||||
for node in nodes:
|
||||
node_id = node["node_id"]
|
||||
node_type = node.get("type", "")
|
||||
type_id = NODE_TYPES.get(node_type, 0)
|
||||
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
|
||||
self._nodes[node_id] = {
|
||||
"id": node_id,
|
||||
"bgcolor": bgcolor,
|
||||
"title": node.get("title"),
|
||||
"type": type_id,
|
||||
"type_name": node_type,
|
||||
}
|
||||
logger.debug("Registered %s nodes in registry", len(nodes))
|
||||
self._registry_updated.set()
|
||||
|
||||
async def get_registry(self) -> dict:
|
||||
async with self._lock:
|
||||
return {
|
||||
"nodes": dict(self._nodes),
|
||||
"node_count": len(self._nodes),
|
||||
}
|
||||
|
||||
async def wait_for_update(self, timeout: float = 1.0) -> bool:
|
||||
self._registry_updated.clear()
|
||||
try:
|
||||
await asyncio.wait_for(self._registry_updated.wait(), timeout=timeout)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
|
||||
|
||||
class HealthCheckHandler:
|
||||
async def health_check(self, request: web.Request) -> web.Response:
|
||||
return web.json_response({"status": "ok"})
|
||||
|
||||
|
||||
class SettingsHandler:
|
||||
"""Sync settings between backend and frontend."""
|
||||
|
||||
_SYNC_KEYS = (
|
||||
"civitai_api_key",
|
||||
"default_lora_root",
|
||||
"default_checkpoint_root",
|
||||
"default_embedding_root",
|
||||
"base_model_path_mappings",
|
||||
"download_path_templates",
|
||||
"enable_metadata_archive_db",
|
||||
"language",
|
||||
"proxy_enabled",
|
||||
"proxy_type",
|
||||
"proxy_host",
|
||||
"proxy_port",
|
||||
"proxy_username",
|
||||
"proxy_password",
|
||||
"example_images_path",
|
||||
"optimize_example_images",
|
||||
"auto_download_example_images",
|
||||
"blur_mature_content",
|
||||
"autoplay_on_hover",
|
||||
"display_density",
|
||||
"card_info_display",
|
||||
"include_trigger_words",
|
||||
"show_only_sfw",
|
||||
"compact_mode",
|
||||
)
|
||||
|
||||
_PROXY_KEYS = {"proxy_enabled", "proxy_host", "proxy_port", "proxy_username", "proxy_password", "proxy_type"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
settings_service=default_settings,
|
||||
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
|
||||
downloader_factory: Callable[[], Awaitable[DownloaderProtocol]] = get_downloader,
|
||||
) -> None:
|
||||
self._settings = settings_service
|
||||
self._metadata_provider_updater = metadata_provider_updater
|
||||
self._downloader_factory = downloader_factory
|
||||
|
||||
async def get_settings(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
response_data = {}
|
||||
for key in self._SYNC_KEYS:
|
||||
value = self._settings.get(key)
|
||||
if value is not None:
|
||||
response_data[key] = value
|
||||
return web.json_response({"success": True, "settings": response_data})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error getting settings: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def update_settings(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
proxy_changed = False
|
||||
|
||||
for key, value in data.items():
|
||||
if value == self._settings.get(key):
|
||||
continue
|
||||
|
||||
if key == "example_images_path" and value:
|
||||
validation_error = self._validate_example_images_path(value)
|
||||
if validation_error:
|
||||
return web.json_response({"success": False, "error": validation_error})
|
||||
|
||||
if value == "__DELETE__" and key in ("proxy_username", "proxy_password"):
|
||||
self._settings.delete(key)
|
||||
else:
|
||||
self._settings.set(key, value)
|
||||
|
||||
if key == "enable_metadata_archive_db":
|
||||
await self._metadata_provider_updater()
|
||||
|
||||
if key in self._PROXY_KEYS:
|
||||
proxy_changed = True
|
||||
|
||||
if proxy_changed:
|
||||
downloader = await self._downloader_factory()
|
||||
await downloader.refresh_session()
|
||||
|
||||
return web.json_response({"success": True})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error updating settings: %s", exc, exc_info=True)
|
||||
return web.Response(status=500, text=str(exc))
|
||||
|
||||
def _validate_example_images_path(self, folder_path: str) -> str | None:
|
||||
if not os.path.exists(folder_path):
|
||||
return f"Path does not exist: {folder_path}"
|
||||
if not os.path.isdir(folder_path):
|
||||
return "Please set a dedicated folder for example images."
|
||||
if not self._is_dedicated_example_images_folder(folder_path):
|
||||
return "Please set a dedicated folder for example images."
|
||||
return None
|
||||
|
||||
def _is_dedicated_example_images_folder(self, folder_path: str) -> bool:
|
||||
try:
|
||||
items = os.listdir(folder_path)
|
||||
if not items:
|
||||
return True
|
||||
for item in items:
|
||||
item_path = os.path.join(folder_path, item)
|
||||
if item == ".download_progress.json" and os.path.isfile(item_path):
|
||||
continue
|
||||
if os.path.isdir(item_path) and re.fullmatch(r"[a-fA-F0-9]{64}", item):
|
||||
continue
|
||||
return False
|
||||
return True
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error checking if folder is dedicated: %s", exc)
|
||||
return False
|
||||
|
||||
|
||||
class UsageStatsHandler:
|
||||
def __init__(self, usage_stats_factory: UsageStatsFactory = UsageStats) -> None:
|
||||
self._usage_stats_factory = usage_stats_factory
|
||||
|
||||
async def update_usage_stats(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
prompt_id = data.get("prompt_id")
|
||||
if not prompt_id:
|
||||
return web.json_response({"success": False, "error": "Missing prompt_id"}, status=400)
|
||||
usage_stats = self._usage_stats_factory()
|
||||
await usage_stats.process_execution(prompt_id)
|
||||
return web.json_response({"success": True})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to update usage stats: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_usage_stats(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
usage_stats = self._usage_stats_factory()
|
||||
stats = await usage_stats.get_stats()
|
||||
stats_response = {"success": True, "data": stats, "format_version": 2}
|
||||
return web.json_response(stats_response)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to get usage stats: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class LoraCodeHandler:
|
||||
def __init__(self, prompt_server: type[PromptServerProtocol]) -> None:
|
||||
self._prompt_server = prompt_server
|
||||
|
||||
async def update_lora_code(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
node_ids = data.get("node_ids")
|
||||
lora_code = data.get("lora_code", "")
|
||||
mode = data.get("mode", "append")
|
||||
|
||||
if not lora_code:
|
||||
return web.json_response({"success": False, "error": "Missing lora_code parameter"}, status=400)
|
||||
|
||||
results = []
|
||||
if node_ids is None:
|
||||
try:
|
||||
self._prompt_server.instance.send_sync(
|
||||
"lora_code_update", {"id": -1, "lora_code": lora_code, "mode": mode}
|
||||
)
|
||||
results.append({"node_id": "broadcast", "success": True})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error broadcasting lora code: %s", exc)
|
||||
results.append({"node_id": "broadcast", "success": False, "error": str(exc)})
|
||||
else:
|
||||
for node_id in node_ids:
|
||||
try:
|
||||
self._prompt_server.instance.send_sync(
|
||||
"lora_code_update",
|
||||
{"id": node_id, "lora_code": lora_code, "mode": mode},
|
||||
)
|
||||
results.append({"node_id": node_id, "success": True})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error sending lora code to node %s: %s", node_id, exc)
|
||||
results.append({"node_id": node_id, "success": False, "error": str(exc)})
|
||||
|
||||
return web.json_response({"success": True, "results": results})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to update lora code: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class TrainedWordsHandler:
|
||||
async def get_trained_words(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
file_path = request.query.get("file_path")
|
||||
if not file_path:
|
||||
return web.json_response({"success": False, "error": "Missing file_path parameter"}, status=400)
|
||||
if not os.path.exists(file_path):
|
||||
return web.json_response({"success": False, "error": "File not found"}, status=404)
|
||||
if not file_path.endswith(".safetensors"):
|
||||
return web.json_response({"success": False, "error": "File must be a safetensors file"}, status=400)
|
||||
|
||||
trained_words = extract_trained_words(file_path)
|
||||
sorted_words = sorted(trained_words, key=lambda w: w.get("count", 0), reverse=True)
|
||||
return web.json_response({"success": True, "trained_words": sorted_words})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to get trained words: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class ModelExampleFilesHandler:
|
||||
async def get_model_example_files(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
model_path = request.query.get("model_path")
|
||||
if not model_path:
|
||||
return web.json_response({"success": False, "error": "Missing model_path parameter"}, status=400)
|
||||
model_dir = os.path.dirname(model_path)
|
||||
if not os.path.exists(model_dir):
|
||||
return web.json_response({"success": False, "error": "Model directory not found"}, status=404)
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||
files = []
|
||||
pattern = f"{base_name}.example."
|
||||
for file in os.listdir(model_dir):
|
||||
if not file.startswith(pattern):
|
||||
continue
|
||||
file_full_path = os.path.join(model_dir, file)
|
||||
if not os.path.isfile(file_full_path):
|
||||
continue
|
||||
file_ext = os.path.splitext(file)[1].lower()
|
||||
if file_ext not in SUPPORTED_MEDIA_EXTENSIONS["images"] and file_ext not in SUPPORTED_MEDIA_EXTENSIONS["videos"]:
|
||||
continue
|
||||
try:
|
||||
index = int(file[len(pattern) :].split(".")[0])
|
||||
except (ValueError, IndexError):
|
||||
index = float("inf")
|
||||
static_url = config.get_preview_static_url(file_full_path)
|
||||
files.append(
|
||||
{
|
||||
"name": file,
|
||||
"path": static_url,
|
||||
"extension": file_ext,
|
||||
"is_video": file_ext in SUPPORTED_MEDIA_EXTENSIONS["videos"],
|
||||
"index": index,
|
||||
}
|
||||
)
|
||||
|
||||
files.sort(key=lambda item: item["index"])
|
||||
for file in files:
|
||||
file.pop("index", None)
|
||||
|
||||
return web.json_response({"success": True, "files": files})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to get model example files: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceRegistryAdapter:
|
||||
get_lora_scanner: Callable[[], Awaitable]
|
||||
get_checkpoint_scanner: Callable[[], Awaitable]
|
||||
get_embedding_scanner: Callable[[], Awaitable]
|
||||
|
||||
|
||||
class ModelLibraryHandler:
|
||||
def __init__(self, service_registry: ServiceRegistryAdapter, metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol | None]]) -> None:
|
||||
self._service_registry = service_registry
|
||||
self._metadata_provider_factory = metadata_provider_factory
|
||||
|
||||
async def check_model_exists(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
model_id_str = request.query.get("modelId")
|
||||
model_version_id_str = request.query.get("modelVersionId")
|
||||
if not model_id_str:
|
||||
return web.json_response({"success": False, "error": "Missing required parameter: modelId"}, status=400)
|
||||
try:
|
||||
model_id = int(model_id_str)
|
||||
except ValueError:
|
||||
return web.json_response({"success": False, "error": "Parameter modelId must be an integer"}, status=400)
|
||||
|
||||
lora_scanner = await self._service_registry.get_lora_scanner()
|
||||
checkpoint_scanner = await self._service_registry.get_checkpoint_scanner()
|
||||
embedding_scanner = await self._service_registry.get_embedding_scanner()
|
||||
|
||||
if model_version_id_str:
|
||||
try:
|
||||
model_version_id = int(model_version_id_str)
|
||||
except ValueError:
|
||||
return web.json_response({"success": False, "error": "Parameter modelVersionId must be an integer"}, status=400)
|
||||
|
||||
exists = False
|
||||
model_type = None
|
||||
if await lora_scanner.check_model_version_exists(model_version_id):
|
||||
exists = True
|
||||
model_type = "lora"
|
||||
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_version_id):
|
||||
exists = True
|
||||
model_type = "checkpoint"
|
||||
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_version_id):
|
||||
exists = True
|
||||
model_type = "embedding"
|
||||
|
||||
return web.json_response({"success": True, "exists": exists, "modelType": model_type if exists else None})
|
||||
|
||||
lora_versions = await lora_scanner.get_model_versions_by_id(model_id)
|
||||
checkpoint_versions = []
|
||||
embedding_versions = []
|
||||
if not lora_versions and checkpoint_scanner:
|
||||
checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id)
|
||||
if not lora_versions and not checkpoint_versions and embedding_scanner:
|
||||
embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id)
|
||||
|
||||
model_type = None
|
||||
version_ids: list[int] = []
|
||||
if lora_versions:
|
||||
model_type = "lora"
|
||||
version_ids = [version["versionId"] for version in lora_versions]
|
||||
elif checkpoint_versions:
|
||||
model_type = "checkpoint"
|
||||
version_ids = [version["versionId"] for version in checkpoint_versions]
|
||||
elif embedding_versions:
|
||||
model_type = "embedding"
|
||||
version_ids = [version["versionId"] for version in embedding_versions]
|
||||
|
||||
return web.json_response({"success": True, "modelType": model_type, "modelVersionIds": version_ids})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to check model existence: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_model_versions_status(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
model_id_str = request.query.get("modelId")
|
||||
if not model_id_str:
|
||||
return web.json_response({"success": False, "error": "Missing required parameter: modelId"}, status=400)
|
||||
try:
|
||||
model_id = int(model_id_str)
|
||||
except ValueError:
|
||||
return web.json_response({"success": False, "error": "Parameter modelId must be an integer"}, status=400)
|
||||
|
||||
metadata_provider = await self._metadata_provider_factory()
|
||||
if not metadata_provider:
|
||||
return web.json_response({"success": False, "error": "Metadata provider not available"}, status=503)
|
||||
|
||||
response = await metadata_provider.get_model_versions(model_id)
|
||||
if not response or not response.get("modelVersions"):
|
||||
return web.json_response({"success": False, "error": "Model not found"}, status=404)
|
||||
|
||||
versions = response.get("modelVersions", [])
|
||||
model_name = response.get("name", "")
|
||||
model_type = response.get("type", "").lower()
|
||||
|
||||
scanner = None
|
||||
normalized_type = None
|
||||
if model_type in {"lora", "locon", "dora"}:
|
||||
scanner = await self._service_registry.get_lora_scanner()
|
||||
normalized_type = "lora"
|
||||
elif model_type == "checkpoint":
|
||||
scanner = await self._service_registry.get_checkpoint_scanner()
|
||||
normalized_type = "checkpoint"
|
||||
elif model_type == "textualinversion":
|
||||
scanner = await self._service_registry.get_embedding_scanner()
|
||||
normalized_type = "embedding"
|
||||
else:
|
||||
return web.json_response({"success": False, "error": f'Model type "{model_type}" is not supported'}, status=400)
|
||||
|
||||
if not scanner:
|
||||
return web.json_response({"success": False, "error": f'Scanner for type "{normalized_type}" is not available'}, status=503)
|
||||
|
||||
local_versions = await scanner.get_model_versions_by_id(model_id)
|
||||
local_version_ids = {version["versionId"] for version in local_versions}
|
||||
|
||||
enriched_versions = []
|
||||
for version in versions:
|
||||
version_id = version.get("id")
|
||||
enriched_versions.append(
|
||||
{
|
||||
"id": version_id,
|
||||
"name": version.get("name", ""),
|
||||
"thumbnailUrl": version.get("images")[0]["url"] if version.get("images") else None,
|
||||
"inLibrary": version_id in local_version_ids,
|
||||
}
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"modelId": model_id,
|
||||
"modelName": model_name,
|
||||
"modelType": model_type,
|
||||
"versions": enriched_versions,
|
||||
}
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to get model versions status: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class MetadataArchiveHandler:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metadata_archive_manager_factory: Callable[[], Awaitable[MetadataArchiveManagerProtocol]] = get_metadata_archive_manager,
|
||||
settings_service=default_settings,
|
||||
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
|
||||
) -> None:
|
||||
self._metadata_archive_manager_factory = metadata_archive_manager_factory
|
||||
self._settings = settings_service
|
||||
self._metadata_provider_updater = metadata_provider_updater
|
||||
|
||||
async def download_metadata_archive(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
archive_manager = await self._metadata_archive_manager_factory()
|
||||
download_id = request.query.get("download_id")
|
||||
|
||||
def progress_callback(stage: str, message: str) -> None:
|
||||
data = {"stage": stage, "message": message, "type": "metadata_archive_download"}
|
||||
if download_id:
|
||||
asyncio.create_task(ws_manager.broadcast_download_progress(download_id, data))
|
||||
else:
|
||||
asyncio.create_task(ws_manager.broadcast(data))
|
||||
|
||||
success = await archive_manager.download_and_extract_database(progress_callback)
|
||||
if success:
|
||||
self._settings.set("enable_metadata_archive_db", True)
|
||||
await self._metadata_provider_updater()
|
||||
return web.json_response({"success": True, "message": "Metadata archive database downloaded and extracted successfully"})
|
||||
return web.json_response({"success": False, "error": "Failed to download and extract metadata archive database"}, status=500)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error downloading metadata archive: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def remove_metadata_archive(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
archive_manager = await self._metadata_archive_manager_factory()
|
||||
success = await archive_manager.remove_database()
|
||||
if success:
|
||||
self._settings.set("enable_metadata_archive_db", False)
|
||||
await self._metadata_provider_updater()
|
||||
return web.json_response({"success": True, "message": "Metadata archive database removed successfully"})
|
||||
return web.json_response({"success": False, "error": "Failed to remove metadata archive database"}, status=500)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error removing metadata archive: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_metadata_archive_status(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
archive_manager = await self._metadata_archive_manager_factory()
|
||||
is_available = archive_manager.is_database_available()
|
||||
is_enabled = self._settings.get("enable_metadata_archive_db", False)
|
||||
db_size = 0
|
||||
if is_available:
|
||||
db_path = archive_manager.get_database_path()
|
||||
if db_path and os.path.exists(db_path):
|
||||
db_size = os.path.getsize(db_path)
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"isAvailable": is_available,
|
||||
"isEnabled": is_enabled,
|
||||
"databaseSize": db_size,
|
||||
"databasePath": archive_manager.get_database_path() if is_available else None,
|
||||
}
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error getting metadata archive status: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class FileSystemHandler:
|
||||
async def open_file_location(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
if not file_path:
|
||||
return web.json_response({"success": False, "error": "Missing file_path parameter"}, status=400)
|
||||
file_path = os.path.abspath(file_path)
|
||||
if not os.path.isfile(file_path):
|
||||
return web.json_response({"success": False, "error": "File does not exist"}, status=404)
|
||||
|
||||
if os.name == "nt":
|
||||
subprocess.Popen(["explorer", "/select,", file_path])
|
||||
elif os.name == "posix":
|
||||
if sys.platform == "darwin":
|
||||
subprocess.Popen(["open", "-R", file_path])
|
||||
else:
|
||||
folder = os.path.dirname(file_path)
|
||||
subprocess.Popen(["xdg-open", folder])
|
||||
|
||||
return web.json_response({"success": True, "message": f"Opened folder and selected file: {file_path}"})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to open file location: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class NodeRegistryHandler:
|
||||
def __init__(
|
||||
self,
|
||||
node_registry: NodeRegistry,
|
||||
prompt_server: type[PromptServerProtocol],
|
||||
*,
|
||||
standalone_mode: bool,
|
||||
) -> None:
|
||||
self._node_registry = node_registry
|
||||
self._prompt_server = prompt_server
|
||||
self._standalone_mode = standalone_mode
|
||||
|
||||
async def register_nodes(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
nodes = data.get("nodes", [])
|
||||
if not isinstance(nodes, list):
|
||||
return web.json_response({"success": False, "error": "nodes must be a list"}, status=400)
|
||||
for index, node in enumerate(nodes):
|
||||
if not isinstance(node, dict):
|
||||
return web.json_response({"success": False, "error": f"Node {index} must be an object"}, status=400)
|
||||
node_id = node.get("node_id")
|
||||
if node_id is None:
|
||||
return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400)
|
||||
try:
|
||||
node["node_id"] = int(node_id)
|
||||
except (TypeError, ValueError):
|
||||
return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400)
|
||||
|
||||
await self._node_registry.register_nodes(nodes)
|
||||
return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to register nodes: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_registry(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
if self._standalone_mode:
|
||||
logger.warning("Registry refresh not available in standalone mode")
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Standalone Mode Active",
|
||||
"message": "Cannot interact with ComfyUI in standalone mode.",
|
||||
},
|
||||
status=503,
|
||||
)
|
||||
|
||||
try:
|
||||
self._prompt_server.instance.send_sync("lora_registry_refresh", {})
|
||||
logger.debug("Sent registry refresh request to frontend")
|
||||
except Exception as exc:
|
||||
logger.error("Failed to send registry refresh message: %s", exc)
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Communication Error",
|
||||
"message": f"Failed to communicate with ComfyUI frontend: {exc}",
|
||||
},
|
||||
status=500,
|
||||
)
|
||||
|
||||
registry_updated = await self._node_registry.wait_for_update(timeout=1.0)
|
||||
if not registry_updated:
|
||||
logger.warning("Registry refresh timeout after 1 second")
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Timeout Error",
|
||||
"message": "Registry refresh timeout - ComfyUI frontend may not be responsive",
|
||||
},
|
||||
status=408,
|
||||
)
|
||||
|
||||
registry_info = await self._node_registry.get_registry()
|
||||
return web.json_response({"success": True, "data": registry_info})
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to get registry: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": "Internal Error", "message": str(exc)}, status=500)
|
||||
|
||||
|
||||
class MiscHandlerSet:
|
||||
"""Aggregate handlers into a lookup compatible with the registrar."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
health: HealthCheckHandler,
|
||||
settings: SettingsHandler,
|
||||
usage_stats: UsageStatsHandler,
|
||||
lora_code: LoraCodeHandler,
|
||||
trained_words: TrainedWordsHandler,
|
||||
model_examples: ModelExampleFilesHandler,
|
||||
node_registry: NodeRegistryHandler,
|
||||
model_library: ModelLibraryHandler,
|
||||
metadata_archive: MetadataArchiveHandler,
|
||||
filesystem: FileSystemHandler,
|
||||
) -> None:
|
||||
self.health = health
|
||||
self.settings = settings
|
||||
self.usage_stats = usage_stats
|
||||
self.lora_code = lora_code
|
||||
self.trained_words = trained_words
|
||||
self.model_examples = model_examples
|
||||
self.node_registry = node_registry
|
||||
self.model_library = model_library
|
||||
self.metadata_archive = metadata_archive
|
||||
self.filesystem = filesystem
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||
return {
|
||||
"health_check": self.health.health_check,
|
||||
"get_settings": self.settings.get_settings,
|
||||
"update_settings": self.settings.update_settings,
|
||||
"update_usage_stats": self.usage_stats.update_usage_stats,
|
||||
"get_usage_stats": self.usage_stats.get_usage_stats,
|
||||
"update_lora_code": self.lora_code.update_lora_code,
|
||||
"get_trained_words": self.trained_words.get_trained_words,
|
||||
"get_model_example_files": self.model_examples.get_model_example_files,
|
||||
"register_nodes": self.node_registry.register_nodes,
|
||||
"get_registry": self.node_registry.get_registry,
|
||||
"check_model_exists": self.model_library.check_model_exists,
|
||||
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
||||
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
||||
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,
|
||||
"get_model_versions_status": self.model_library.get_model_versions_status,
|
||||
"open_file_location": self.filesystem.open_file_location,
|
||||
}
|
||||
|
||||
|
||||
def build_service_registry_adapter() -> ServiceRegistryAdapter:
|
||||
return ServiceRegistryAdapter(
|
||||
get_lora_scanner=ServiceRegistry.get_lora_scanner,
|
||||
get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner,
|
||||
get_embedding_scanner=ServiceRegistry.get_embedding_scanner,
|
||||
)
|
||||
67
py/routes/misc_route_registrar.py
Normal file
67
py/routes/misc_route_registrar.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Route registrar for miscellaneous endpoints.
|
||||
|
||||
This module mirrors the model route registrar architecture so that
|
||||
miscellaneous endpoints share a consistent registration flow.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative definition for a HTTP route."""
|
||||
|
||||
method: str
|
||||
path: str
|
||||
handler_name: str
|
||||
|
||||
|
||||
MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/settings", "get_settings"),
|
||||
RouteDefinition("POST", "/api/lm/settings", "update_settings"),
|
||||
RouteDefinition("GET", "/api/lm/health-check", "health_check"),
|
||||
RouteDefinition("POST", "/api/lm/open-file-location", "open_file_location"),
|
||||
RouteDefinition("POST", "/api/lm/update-usage-stats", "update_usage_stats"),
|
||||
RouteDefinition("GET", "/api/lm/get-usage-stats", "get_usage_stats"),
|
||||
RouteDefinition("POST", "/api/lm/update-lora-code", "update_lora_code"),
|
||||
RouteDefinition("GET", "/api/lm/trained-words", "get_trained_words"),
|
||||
RouteDefinition("GET", "/api/lm/model-example-files", "get_model_example_files"),
|
||||
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
|
||||
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
|
||||
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
|
||||
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
|
||||
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
|
||||
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
||||
RouteDefinition("GET", "/api/lm/model-versions-status", "get_model_versions_status"),
|
||||
)
|
||||
|
||||
|
||||
class MiscRouteRegistrar:
|
||||
"""Bind miscellaneous route definitions to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(
|
||||
self,
|
||||
handler_lookup: Mapping[str, Callable[[web.Request], object]],
|
||||
*,
|
||||
definitions: Iterable[RouteDefinition] = MISC_ROUTE_DEFINITIONS,
|
||||
) -> None:
|
||||
for definition in definitions:
|
||||
self._bind(definition.method, definition.path, handler_lookup[definition.handler_name])
|
||||
|
||||
def _bind(self, method: str, path: str, handler: Callable) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
File diff suppressed because it is too large
Load Diff
211
tests/routes/test_misc_routes.py
Normal file
211
tests/routes/test_misc_routes.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from py.routes.handlers.misc_handlers import SettingsHandler, ServiceRegistryAdapter
|
||||
from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar
|
||||
from py.routes.misc_routes import MiscRoutes
|
||||
|
||||
|
||||
class FakeRequest:
|
||||
def __init__(self, *, json_data=None, query=None):
|
||||
self._json_data = json_data or {}
|
||||
self.query = query or {}
|
||||
|
||||
async def json(self):
|
||||
return self._json_data
|
||||
|
||||
|
||||
class DummySettings:
|
||||
def __init__(self, data=None):
|
||||
self.data = data or {}
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.data.get(key, default)
|
||||
|
||||
def set(self, key, value):
|
||||
self.data[key] = value
|
||||
|
||||
def delete(self, key):
|
||||
self.data.pop(key, None)
|
||||
|
||||
|
||||
class DummyDownloader:
|
||||
def __init__(self):
|
||||
self.refreshed = False
|
||||
|
||||
async def refresh_session(self):
|
||||
self.refreshed = True
|
||||
|
||||
|
||||
async def noop_async(*_args, **_kwargs):
|
||||
return None
|
||||
|
||||
|
||||
async def dummy_downloader_factory():
|
||||
return DummyDownloader()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_settings_filters_sync_keys():
|
||||
settings_service = DummySettings({"civitai_api_key": "abc", "extraneous": "value"})
|
||||
handler = SettingsHandler(
|
||||
settings_service=settings_service,
|
||||
metadata_provider_updater=noop_async,
|
||||
downloader_factory=dummy_downloader_factory,
|
||||
)
|
||||
|
||||
response = await handler.get_settings(FakeRequest())
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
assert payload["settings"] == {"civitai_api_key": "abc"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_rejects_missing_example_path(tmp_path):
|
||||
settings_service = DummySettings()
|
||||
handler = SettingsHandler(
|
||||
settings_service=settings_service,
|
||||
metadata_provider_updater=noop_async,
|
||||
downloader_factory=dummy_downloader_factory,
|
||||
)
|
||||
|
||||
missing_path = tmp_path / "does-not-exist"
|
||||
request = FakeRequest(json_data={"example_images_path": str(missing_path)})
|
||||
|
||||
response = await handler.update_settings(request)
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is False
|
||||
assert "Path does not exist" in payload["error"]
|
||||
|
||||
|
||||
class RecordingRouter:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def add_get(self, path, handler):
|
||||
self.calls.append(("GET", path, handler))
|
||||
|
||||
def add_post(self, path, handler):
|
||||
self.calls.append(("POST", path, handler))
|
||||
|
||||
def add_put(self, path, handler):
|
||||
self.calls.append(("PUT", path, handler))
|
||||
|
||||
def add_delete(self, path, handler):
|
||||
self.calls.append(("DELETE", path, handler))
|
||||
|
||||
|
||||
def test_misc_route_registrar_registers_all_routes():
|
||||
app = SimpleNamespace(router=RecordingRouter())
|
||||
registrar = MiscRouteRegistrar(app) # type: ignore[arg-type]
|
||||
|
||||
async def dummy_handler(_request):
|
||||
return web.Response()
|
||||
|
||||
handler_mapping = {definition.handler_name: dummy_handler for definition in MISC_ROUTE_DEFINITIONS}
|
||||
registrar.register_routes(handler_mapping)
|
||||
|
||||
registered = {(method, path) for method, path, _ in app.router.calls}
|
||||
expected = {(definition.method, definition.path) for definition in MISC_ROUTE_DEFINITIONS}
|
||||
|
||||
assert registered == expected
|
||||
|
||||
|
||||
class FakePromptServer:
|
||||
sent = []
|
||||
|
||||
class Instance:
|
||||
def send_sync(self, event, payload):
|
||||
FakePromptServer.sent.append((event, payload))
|
||||
|
||||
instance = Instance()
|
||||
|
||||
|
||||
class FakeScanner:
|
||||
async def check_model_version_exists(self, _version_id):
|
||||
return False
|
||||
|
||||
async def get_model_versions_by_id(self, _model_id):
|
||||
return []
|
||||
|
||||
|
||||
async def fake_scanner_factory():
|
||||
return FakeScanner()
|
||||
|
||||
|
||||
class FakeMetadataProvider:
|
||||
async def get_model_versions(self, _model_id):
|
||||
return {"modelVersions": [], "name": "", "type": "lora"}
|
||||
|
||||
|
||||
async def fake_metadata_provider_factory():
|
||||
return FakeMetadataProvider()
|
||||
|
||||
|
||||
class FakeMetadataArchiveManager:
|
||||
async def download_and_extract_database(self, _callback):
|
||||
return True
|
||||
|
||||
async def remove_database(self):
|
||||
return True
|
||||
|
||||
def is_database_available(self):
|
||||
return False
|
||||
|
||||
def get_database_path(self):
|
||||
return None
|
||||
|
||||
|
||||
async def fake_metadata_archive_manager_factory():
|
||||
return FakeMetadataArchiveManager()
|
||||
|
||||
|
||||
class RecordingRegistrar:
|
||||
def __init__(self, _app):
|
||||
self.registered_mapping = None
|
||||
|
||||
def register_routes(self, mapping):
|
||||
self.registered_mapping = mapping
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_misc_routes_bind_produces_expected_handlers():
|
||||
service_registry_adapter = ServiceRegistryAdapter(
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
)
|
||||
|
||||
recorded_registrars = []
|
||||
|
||||
def registrar_factory(app):
|
||||
registrar = RecordingRegistrar(app)
|
||||
recorded_registrars.append(registrar)
|
||||
return registrar
|
||||
|
||||
controller = MiscRoutes(
|
||||
settings_service=DummySettings(),
|
||||
usage_stats_factory=lambda: SimpleNamespace(process_execution=noop_async, get_stats=noop_async),
|
||||
prompt_server=FakePromptServer,
|
||||
service_registry_adapter=service_registry_adapter,
|
||||
metadata_provider_factory=fake_metadata_provider_factory,
|
||||
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
|
||||
metadata_provider_updater=noop_async,
|
||||
downloader_factory=dummy_downloader_factory,
|
||||
registrar_factory=registrar_factory,
|
||||
)
|
||||
|
||||
app = SimpleNamespace(router=RecordingRouter())
|
||||
controller.bind(app) # type: ignore[arg-type]
|
||||
|
||||
assert recorded_registrars, "Expected registrar to be created"
|
||||
mapping = recorded_registrars[0].registered_mapping
|
||||
assert mapping is not None
|
||||
|
||||
expected_names = {definition.handler_name for definition in MISC_ROUTE_DEFINITIONS}
|
||||
assert set(mapping.keys()) == expected_names
|
||||
Reference in New Issue
Block a user