Merge pull request #511 from willmiao/codex/refactor-misc_routes.py-and-add-tests

refactor: modularize misc route controller
This commit is contained in:
pixelpaws
2025-10-03 22:47:22 +08:00
committed by GitHub
4 changed files with 1148 additions and 1038 deletions

View File

@@ -0,0 +1,755 @@
"""Handlers for miscellaneous routes.
The legacy :mod:`py.routes.misc_routes` module bundled HTTP wiring and
business logic in a single class. This module mirrors the model route
architecture by splitting the responsibilities into dedicated handler
objects that can be composed by the route controller.
"""
from __future__ import annotations
import asyncio
import logging
import os
import re
import subprocess
import sys
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, Mapping, Protocol
from aiohttp import web
from ...config import config
from ...services.metadata_service import (
get_metadata_archive_manager,
get_metadata_provider,
update_metadata_providers,
)
from ...services.service_registry import ServiceRegistry
from ...services.settings_manager import settings as default_settings
from ...services.websocket_manager import ws_manager
from ...services.downloader import get_downloader
from ...utils.constants import DEFAULT_NODE_COLOR, NODE_TYPES, SUPPORTED_MEDIA_EXTENSIONS
from ...utils.lora_metadata import extract_trained_words
from ...utils.usage_stats import UsageStats
logger = logging.getLogger(__name__)
class PromptServerProtocol(Protocol):
"""Subset of PromptServer used by the handlers."""
instance: "PromptServerProtocol"
def send_sync(self, event: str, payload: dict) -> None: # pragma: no cover - protocol
...
class DownloaderProtocol(Protocol):
async def refresh_session(self) -> None: # pragma: no cover - protocol
...
class UsageStatsFactory(Protocol):
def __call__(self) -> UsageStats: # pragma: no cover - protocol
...
class MetadataProviderProtocol(Protocol):
async def get_model_versions(self, model_id: int) -> dict | None: # pragma: no cover - protocol
...
class MetadataArchiveManagerProtocol(Protocol):
async def download_and_extract_database(
self, progress_callback: Callable[[str, str], None]
) -> bool: # pragma: no cover - protocol
...
async def remove_database(self) -> bool: # pragma: no cover - protocol
...
def is_database_available(self) -> bool: # pragma: no cover - protocol
...
def get_database_path(self) -> str | None: # pragma: no cover - protocol
...
class NodeRegistry:
"""Thread-safe registry for tracking LoRA nodes in active workflows."""
def __init__(self) -> None:
self._lock = asyncio.Lock()
self._nodes: Dict[int, dict] = {}
self._registry_updated = asyncio.Event()
async def register_nodes(self, nodes: list[dict]) -> None:
async with self._lock:
self._nodes.clear()
for node in nodes:
node_id = node["node_id"]
node_type = node.get("type", "")
type_id = NODE_TYPES.get(node_type, 0)
bgcolor = node.get("bgcolor") or DEFAULT_NODE_COLOR
self._nodes[node_id] = {
"id": node_id,
"bgcolor": bgcolor,
"title": node.get("title"),
"type": type_id,
"type_name": node_type,
}
logger.debug("Registered %s nodes in registry", len(nodes))
self._registry_updated.set()
async def get_registry(self) -> dict:
async with self._lock:
return {
"nodes": dict(self._nodes),
"node_count": len(self._nodes),
}
async def wait_for_update(self, timeout: float = 1.0) -> bool:
self._registry_updated.clear()
try:
await asyncio.wait_for(self._registry_updated.wait(), timeout=timeout)
return True
except asyncio.TimeoutError:
return False
class HealthCheckHandler:
async def health_check(self, request: web.Request) -> web.Response:
return web.json_response({"status": "ok"})
class SettingsHandler:
"""Sync settings between backend and frontend."""
_SYNC_KEYS = (
"civitai_api_key",
"default_lora_root",
"default_checkpoint_root",
"default_embedding_root",
"base_model_path_mappings",
"download_path_templates",
"enable_metadata_archive_db",
"language",
"proxy_enabled",
"proxy_type",
"proxy_host",
"proxy_port",
"proxy_username",
"proxy_password",
"example_images_path",
"optimize_example_images",
"auto_download_example_images",
"blur_mature_content",
"autoplay_on_hover",
"display_density",
"card_info_display",
"include_trigger_words",
"show_only_sfw",
"compact_mode",
)
_PROXY_KEYS = {"proxy_enabled", "proxy_host", "proxy_port", "proxy_username", "proxy_password", "proxy_type"}
def __init__(
self,
*,
settings_service=default_settings,
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
downloader_factory: Callable[[], Awaitable[DownloaderProtocol]] = get_downloader,
) -> None:
self._settings = settings_service
self._metadata_provider_updater = metadata_provider_updater
self._downloader_factory = downloader_factory
async def get_settings(self, request: web.Request) -> web.Response:
try:
response_data = {}
for key in self._SYNC_KEYS:
value = self._settings.get(key)
if value is not None:
response_data[key] = value
return web.json_response({"success": True, "settings": response_data})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error getting settings: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def update_settings(self, request: web.Request) -> web.Response:
try:
data = await request.json()
proxy_changed = False
for key, value in data.items():
if value == self._settings.get(key):
continue
if key == "example_images_path" and value:
validation_error = self._validate_example_images_path(value)
if validation_error:
return web.json_response({"success": False, "error": validation_error})
if value == "__DELETE__" and key in ("proxy_username", "proxy_password"):
self._settings.delete(key)
else:
self._settings.set(key, value)
if key == "enable_metadata_archive_db":
await self._metadata_provider_updater()
if key in self._PROXY_KEYS:
proxy_changed = True
if proxy_changed:
downloader = await self._downloader_factory()
await downloader.refresh_session()
return web.json_response({"success": True})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error updating settings: %s", exc, exc_info=True)
return web.Response(status=500, text=str(exc))
def _validate_example_images_path(self, folder_path: str) -> str | None:
if not os.path.exists(folder_path):
return f"Path does not exist: {folder_path}"
if not os.path.isdir(folder_path):
return "Please set a dedicated folder for example images."
if not self._is_dedicated_example_images_folder(folder_path):
return "Please set a dedicated folder for example images."
return None
def _is_dedicated_example_images_folder(self, folder_path: str) -> bool:
try:
items = os.listdir(folder_path)
if not items:
return True
for item in items:
item_path = os.path.join(folder_path, item)
if item == ".download_progress.json" and os.path.isfile(item_path):
continue
if os.path.isdir(item_path) and re.fullmatch(r"[a-fA-F0-9]{64}", item):
continue
return False
return True
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error checking if folder is dedicated: %s", exc)
return False
class UsageStatsHandler:
def __init__(self, usage_stats_factory: UsageStatsFactory = UsageStats) -> None:
self._usage_stats_factory = usage_stats_factory
async def update_usage_stats(self, request: web.Request) -> web.Response:
try:
data = await request.json()
prompt_id = data.get("prompt_id")
if not prompt_id:
return web.json_response({"success": False, "error": "Missing prompt_id"}, status=400)
usage_stats = self._usage_stats_factory()
await usage_stats.process_execution(prompt_id)
return web.json_response({"success": True})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to update usage stats: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_usage_stats(self, request: web.Request) -> web.Response:
try:
usage_stats = self._usage_stats_factory()
stats = await usage_stats.get_stats()
stats_response = {"success": True, "data": stats, "format_version": 2}
return web.json_response(stats_response)
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to get usage stats: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class LoraCodeHandler:
def __init__(self, prompt_server: type[PromptServerProtocol]) -> None:
self._prompt_server = prompt_server
async def update_lora_code(self, request: web.Request) -> web.Response:
try:
data = await request.json()
node_ids = data.get("node_ids")
lora_code = data.get("lora_code", "")
mode = data.get("mode", "append")
if not lora_code:
return web.json_response({"success": False, "error": "Missing lora_code parameter"}, status=400)
results = []
if node_ids is None:
try:
self._prompt_server.instance.send_sync(
"lora_code_update", {"id": -1, "lora_code": lora_code, "mode": mode}
)
results.append({"node_id": "broadcast", "success": True})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error broadcasting lora code: %s", exc)
results.append({"node_id": "broadcast", "success": False, "error": str(exc)})
else:
for node_id in node_ids:
try:
self._prompt_server.instance.send_sync(
"lora_code_update",
{"id": node_id, "lora_code": lora_code, "mode": mode},
)
results.append({"node_id": node_id, "success": True})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error sending lora code to node %s: %s", node_id, exc)
results.append({"node_id": node_id, "success": False, "error": str(exc)})
return web.json_response({"success": True, "results": results})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to update lora code: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class TrainedWordsHandler:
async def get_trained_words(self, request: web.Request) -> web.Response:
try:
file_path = request.query.get("file_path")
if not file_path:
return web.json_response({"success": False, "error": "Missing file_path parameter"}, status=400)
if not os.path.exists(file_path):
return web.json_response({"success": False, "error": "File not found"}, status=404)
if not file_path.endswith(".safetensors"):
return web.json_response({"success": False, "error": "File must be a safetensors file"}, status=400)
trained_words, class_tokens = await extract_trained_words(file_path)
return web.json_response(
{
"success": True,
"trained_words": trained_words,
"class_tokens": class_tokens,
}
)
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to get trained words: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class ModelExampleFilesHandler:
async def get_model_example_files(self, request: web.Request) -> web.Response:
try:
model_path = request.query.get("model_path")
if not model_path:
return web.json_response({"success": False, "error": "Missing model_path parameter"}, status=400)
model_dir = os.path.dirname(model_path)
if not os.path.exists(model_dir):
return web.json_response({"success": False, "error": "Model directory not found"}, status=404)
base_name = os.path.splitext(os.path.basename(model_path))[0]
files = []
pattern = f"{base_name}.example."
for file in os.listdir(model_dir):
if not file.startswith(pattern):
continue
file_full_path = os.path.join(model_dir, file)
if not os.path.isfile(file_full_path):
continue
file_ext = os.path.splitext(file)[1].lower()
if file_ext not in SUPPORTED_MEDIA_EXTENSIONS["images"] and file_ext not in SUPPORTED_MEDIA_EXTENSIONS["videos"]:
continue
try:
index = int(file[len(pattern) :].split(".")[0])
except (ValueError, IndexError):
index = float("inf")
static_url = config.get_preview_static_url(file_full_path)
files.append(
{
"name": file,
"path": static_url,
"extension": file_ext,
"is_video": file_ext in SUPPORTED_MEDIA_EXTENSIONS["videos"],
"index": index,
}
)
files.sort(key=lambda item: item["index"])
for file in files:
file.pop("index", None)
return web.json_response({"success": True, "files": files})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to get model example files: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
@dataclass
class ServiceRegistryAdapter:
get_lora_scanner: Callable[[], Awaitable]
get_checkpoint_scanner: Callable[[], Awaitable]
get_embedding_scanner: Callable[[], Awaitable]
class ModelLibraryHandler:
def __init__(self, service_registry: ServiceRegistryAdapter, metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol | None]]) -> None:
self._service_registry = service_registry
self._metadata_provider_factory = metadata_provider_factory
async def check_model_exists(self, request: web.Request) -> web.Response:
try:
model_id_str = request.query.get("modelId")
model_version_id_str = request.query.get("modelVersionId")
if not model_id_str:
return web.json_response({"success": False, "error": "Missing required parameter: modelId"}, status=400)
try:
model_id = int(model_id_str)
except ValueError:
return web.json_response({"success": False, "error": "Parameter modelId must be an integer"}, status=400)
lora_scanner = await self._service_registry.get_lora_scanner()
checkpoint_scanner = await self._service_registry.get_checkpoint_scanner()
embedding_scanner = await self._service_registry.get_embedding_scanner()
if model_version_id_str:
try:
model_version_id = int(model_version_id_str)
except ValueError:
return web.json_response({"success": False, "error": "Parameter modelVersionId must be an integer"}, status=400)
exists = False
model_type = None
if await lora_scanner.check_model_version_exists(model_version_id):
exists = True
model_type = "lora"
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_version_id):
exists = True
model_type = "checkpoint"
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_version_id):
exists = True
model_type = "embedding"
return web.json_response({"success": True, "exists": exists, "modelType": model_type if exists else None})
lora_versions = await lora_scanner.get_model_versions_by_id(model_id)
checkpoint_versions = []
embedding_versions = []
if not lora_versions and checkpoint_scanner:
checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id)
if not lora_versions and not checkpoint_versions and embedding_scanner:
embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id)
model_type = None
version_ids: list[int] = []
if lora_versions:
model_type = "lora"
version_ids = [version["versionId"] for version in lora_versions]
elif checkpoint_versions:
model_type = "checkpoint"
version_ids = [version["versionId"] for version in checkpoint_versions]
elif embedding_versions:
model_type = "embedding"
version_ids = [version["versionId"] for version in embedding_versions]
return web.json_response({"success": True, "modelType": model_type, "modelVersionIds": version_ids})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to check model existence: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_versions_status(self, request: web.Request) -> web.Response:
try:
model_id_str = request.query.get("modelId")
if not model_id_str:
return web.json_response({"success": False, "error": "Missing required parameter: modelId"}, status=400)
try:
model_id = int(model_id_str)
except ValueError:
return web.json_response({"success": False, "error": "Parameter modelId must be an integer"}, status=400)
metadata_provider = await self._metadata_provider_factory()
if not metadata_provider:
return web.json_response({"success": False, "error": "Metadata provider not available"}, status=503)
response = await metadata_provider.get_model_versions(model_id)
if not response or not response.get("modelVersions"):
return web.json_response({"success": False, "error": "Model not found"}, status=404)
versions = response.get("modelVersions", [])
model_name = response.get("name", "")
model_type = response.get("type", "").lower()
scanner = None
normalized_type = None
if model_type in {"lora", "locon", "dora"}:
scanner = await self._service_registry.get_lora_scanner()
normalized_type = "lora"
elif model_type == "checkpoint":
scanner = await self._service_registry.get_checkpoint_scanner()
normalized_type = "checkpoint"
elif model_type == "textualinversion":
scanner = await self._service_registry.get_embedding_scanner()
normalized_type = "embedding"
else:
return web.json_response({"success": False, "error": f'Model type "{model_type}" is not supported'}, status=400)
if not scanner:
return web.json_response({"success": False, "error": f'Scanner for type "{normalized_type}" is not available'}, status=503)
local_versions = await scanner.get_model_versions_by_id(model_id)
local_version_ids = {version["versionId"] for version in local_versions}
enriched_versions = []
for version in versions:
version_id = version.get("id")
enriched_versions.append(
{
"id": version_id,
"name": version.get("name", ""),
"thumbnailUrl": version.get("images")[0]["url"] if version.get("images") else None,
"inLibrary": version_id in local_version_ids,
}
)
return web.json_response(
{
"success": True,
"modelId": model_id,
"modelName": model_name,
"modelType": model_type,
"versions": enriched_versions,
}
)
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to get model versions status: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class MetadataArchiveHandler:
def __init__(
self,
*,
metadata_archive_manager_factory: Callable[[], Awaitable[MetadataArchiveManagerProtocol]] = get_metadata_archive_manager,
settings_service=default_settings,
metadata_provider_updater: Callable[[], Awaitable[None]] = update_metadata_providers,
) -> None:
self._metadata_archive_manager_factory = metadata_archive_manager_factory
self._settings = settings_service
self._metadata_provider_updater = metadata_provider_updater
async def download_metadata_archive(self, request: web.Request) -> web.Response:
try:
archive_manager = await self._metadata_archive_manager_factory()
download_id = request.query.get("download_id")
def progress_callback(stage: str, message: str) -> None:
data = {"stage": stage, "message": message, "type": "metadata_archive_download"}
if download_id:
asyncio.create_task(ws_manager.broadcast_download_progress(download_id, data))
else:
asyncio.create_task(ws_manager.broadcast(data))
success = await archive_manager.download_and_extract_database(progress_callback)
if success:
self._settings.set("enable_metadata_archive_db", True)
await self._metadata_provider_updater()
return web.json_response({"success": True, "message": "Metadata archive database downloaded and extracted successfully"})
return web.json_response({"success": False, "error": "Failed to download and extract metadata archive database"}, status=500)
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error downloading metadata archive: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def remove_metadata_archive(self, request: web.Request) -> web.Response:
try:
archive_manager = await self._metadata_archive_manager_factory()
success = await archive_manager.remove_database()
if success:
self._settings.set("enable_metadata_archive_db", False)
await self._metadata_provider_updater()
return web.json_response({"success": True, "message": "Metadata archive database removed successfully"})
return web.json_response({"success": False, "error": "Failed to remove metadata archive database"}, status=500)
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error removing metadata archive: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_metadata_archive_status(self, request: web.Request) -> web.Response:
try:
archive_manager = await self._metadata_archive_manager_factory()
is_available = archive_manager.is_database_available()
is_enabled = self._settings.get("enable_metadata_archive_db", False)
db_size = 0
if is_available:
db_path = archive_manager.get_database_path()
if db_path and os.path.exists(db_path):
db_size = os.path.getsize(db_path)
return web.json_response(
{
"success": True,
"isAvailable": is_available,
"isEnabled": is_enabled,
"databaseSize": db_size,
"databasePath": archive_manager.get_database_path() if is_available else None,
}
)
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error getting metadata archive status: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class FileSystemHandler:
async def open_file_location(self, request: web.Request) -> web.Response:
try:
data = await request.json()
file_path = data.get("file_path")
if not file_path:
return web.json_response({"success": False, "error": "Missing file_path parameter"}, status=400)
file_path = os.path.abspath(file_path)
if not os.path.isfile(file_path):
return web.json_response({"success": False, "error": "File does not exist"}, status=404)
if os.name == "nt":
subprocess.Popen(["explorer", "/select,", file_path])
elif os.name == "posix":
if sys.platform == "darwin":
subprocess.Popen(["open", "-R", file_path])
else:
folder = os.path.dirname(file_path)
subprocess.Popen(["xdg-open", folder])
return web.json_response({"success": True, "message": f"Opened folder and selected file: {file_path}"})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to open file location: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
class NodeRegistryHandler:
def __init__(
self,
node_registry: NodeRegistry,
prompt_server: type[PromptServerProtocol],
*,
standalone_mode: bool,
) -> None:
self._node_registry = node_registry
self._prompt_server = prompt_server
self._standalone_mode = standalone_mode
async def register_nodes(self, request: web.Request) -> web.Response:
try:
data = await request.json()
nodes = data.get("nodes", [])
if not isinstance(nodes, list):
return web.json_response({"success": False, "error": "nodes must be a list"}, status=400)
for index, node in enumerate(nodes):
if not isinstance(node, dict):
return web.json_response({"success": False, "error": f"Node {index} must be an object"}, status=400)
node_id = node.get("node_id")
if node_id is None:
return web.json_response({"success": False, "error": f"Node {index} missing node_id parameter"}, status=400)
try:
node["node_id"] = int(node_id)
except (TypeError, ValueError):
return web.json_response({"success": False, "error": f"Node {index} node_id must be an integer"}, status=400)
await self._node_registry.register_nodes(nodes)
return web.json_response({"success": True, "message": f"{len(nodes)} nodes registered successfully"})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to register nodes: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_registry(self, request: web.Request) -> web.Response:
try:
if self._standalone_mode:
logger.warning("Registry refresh not available in standalone mode")
return web.json_response(
{
"success": False,
"error": "Standalone Mode Active",
"message": "Cannot interact with ComfyUI in standalone mode.",
},
status=503,
)
try:
self._prompt_server.instance.send_sync("lora_registry_refresh", {})
logger.debug("Sent registry refresh request to frontend")
except Exception as exc:
logger.error("Failed to send registry refresh message: %s", exc)
return web.json_response(
{
"success": False,
"error": "Communication Error",
"message": f"Failed to communicate with ComfyUI frontend: {exc}",
},
status=500,
)
registry_updated = await self._node_registry.wait_for_update(timeout=1.0)
if not registry_updated:
logger.warning("Registry refresh timeout after 1 second")
return web.json_response(
{
"success": False,
"error": "Timeout Error",
"message": "Registry refresh timeout - ComfyUI frontend may not be responsive",
},
status=408,
)
registry_info = await self._node_registry.get_registry()
return web.json_response({"success": True, "data": registry_info})
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Failed to get registry: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": "Internal Error", "message": str(exc)}, status=500)
class MiscHandlerSet:
"""Aggregate handlers into a lookup compatible with the registrar."""
def __init__(
self,
*,
health: HealthCheckHandler,
settings: SettingsHandler,
usage_stats: UsageStatsHandler,
lora_code: LoraCodeHandler,
trained_words: TrainedWordsHandler,
model_examples: ModelExampleFilesHandler,
node_registry: NodeRegistryHandler,
model_library: ModelLibraryHandler,
metadata_archive: MetadataArchiveHandler,
filesystem: FileSystemHandler,
) -> None:
self.health = health
self.settings = settings
self.usage_stats = usage_stats
self.lora_code = lora_code
self.trained_words = trained_words
self.model_examples = model_examples
self.node_registry = node_registry
self.model_library = model_library
self.metadata_archive = metadata_archive
self.filesystem = filesystem
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
return {
"health_check": self.health.health_check,
"get_settings": self.settings.get_settings,
"update_settings": self.settings.update_settings,
"update_usage_stats": self.usage_stats.update_usage_stats,
"get_usage_stats": self.usage_stats.get_usage_stats,
"update_lora_code": self.lora_code.update_lora_code,
"get_trained_words": self.trained_words.get_trained_words,
"get_model_example_files": self.model_examples.get_model_example_files,
"register_nodes": self.node_registry.register_nodes,
"get_registry": self.node_registry.get_registry,
"check_model_exists": self.model_library.check_model_exists,
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,
"get_model_versions_status": self.model_library.get_model_versions_status,
"open_file_location": self.filesystem.open_file_location,
}
def build_service_registry_adapter() -> ServiceRegistryAdapter:
return ServiceRegistryAdapter(
get_lora_scanner=ServiceRegistry.get_lora_scanner,
get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner,
get_embedding_scanner=ServiceRegistry.get_embedding_scanner,
)

View File

@@ -0,0 +1,67 @@
"""Route registrar for miscellaneous endpoints.
This module mirrors the model route registrar architecture so that
miscellaneous endpoints share a consistent registration flow.
"""
from dataclasses import dataclass
from typing import Callable, Iterable, Mapping
from aiohttp import web
@dataclass(frozen=True)
class RouteDefinition:
"""Declarative definition for a HTTP route."""
method: str
path: str
handler_name: str
MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/settings", "get_settings"),
RouteDefinition("POST", "/api/lm/settings", "update_settings"),
RouteDefinition("GET", "/api/lm/health-check", "health_check"),
RouteDefinition("POST", "/api/lm/open-file-location", "open_file_location"),
RouteDefinition("POST", "/api/lm/update-usage-stats", "update_usage_stats"),
RouteDefinition("GET", "/api/lm/get-usage-stats", "get_usage_stats"),
RouteDefinition("POST", "/api/lm/update-lora-code", "update_lora_code"),
RouteDefinition("GET", "/api/lm/trained-words", "get_trained_words"),
RouteDefinition("GET", "/api/lm/model-example-files", "get_model_example_files"),
RouteDefinition("POST", "/api/lm/register-nodes", "register_nodes"),
RouteDefinition("GET", "/api/lm/get-registry", "get_registry"),
RouteDefinition("GET", "/api/lm/check-model-exists", "check_model_exists"),
RouteDefinition("POST", "/api/lm/download-metadata-archive", "download_metadata_archive"),
RouteDefinition("POST", "/api/lm/remove-metadata-archive", "remove_metadata_archive"),
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
RouteDefinition("GET", "/api/lm/model-versions-status", "get_model_versions_status"),
)
class MiscRouteRegistrar:
"""Bind miscellaneous route definitions to an aiohttp router."""
_METHOD_MAP = {
"GET": "add_get",
"POST": "add_post",
"PUT": "add_put",
"DELETE": "add_delete",
}
def __init__(self, app: web.Application) -> None:
self._app = app
def register_routes(
self,
handler_lookup: Mapping[str, Callable[[web.Request], object]],
*,
definitions: Iterable[RouteDefinition] = MISC_ROUTE_DEFINITIONS,
) -> None:
for definition in definitions:
self._bind(definition.method, definition.path, handler_lookup[definition.handler_name])
def _bind(self, method: str, path: str, handler: Callable) -> None:
add_method_name = self._METHOD_MAP[method.upper()]
add_method = getattr(self._app.router, add_method_name)
add_method(path, handler)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,211 @@
import json
from types import SimpleNamespace
import pytest
from aiohttp import web
from py.routes.handlers.misc_handlers import SettingsHandler, ServiceRegistryAdapter
from py.routes.misc_route_registrar import MISC_ROUTE_DEFINITIONS, MiscRouteRegistrar
from py.routes.misc_routes import MiscRoutes
class FakeRequest:
def __init__(self, *, json_data=None, query=None):
self._json_data = json_data or {}
self.query = query or {}
async def json(self):
return self._json_data
class DummySettings:
def __init__(self, data=None):
self.data = data or {}
def get(self, key, default=None):
return self.data.get(key, default)
def set(self, key, value):
self.data[key] = value
def delete(self, key):
self.data.pop(key, None)
class DummyDownloader:
def __init__(self):
self.refreshed = False
async def refresh_session(self):
self.refreshed = True
async def noop_async(*_args, **_kwargs):
return None
async def dummy_downloader_factory():
return DummyDownloader()
@pytest.mark.asyncio
async def test_get_settings_filters_sync_keys():
settings_service = DummySettings({"civitai_api_key": "abc", "extraneous": "value"})
handler = SettingsHandler(
settings_service=settings_service,
metadata_provider_updater=noop_async,
downloader_factory=dummy_downloader_factory,
)
response = await handler.get_settings(FakeRequest())
payload = json.loads(response.text)
assert payload["success"] is True
assert payload["settings"] == {"civitai_api_key": "abc"}
@pytest.mark.asyncio
async def test_update_settings_rejects_missing_example_path(tmp_path):
settings_service = DummySettings()
handler = SettingsHandler(
settings_service=settings_service,
metadata_provider_updater=noop_async,
downloader_factory=dummy_downloader_factory,
)
missing_path = tmp_path / "does-not-exist"
request = FakeRequest(json_data={"example_images_path": str(missing_path)})
response = await handler.update_settings(request)
payload = json.loads(response.text)
assert payload["success"] is False
assert "Path does not exist" in payload["error"]
class RecordingRouter:
def __init__(self):
self.calls = []
def add_get(self, path, handler):
self.calls.append(("GET", path, handler))
def add_post(self, path, handler):
self.calls.append(("POST", path, handler))
def add_put(self, path, handler):
self.calls.append(("PUT", path, handler))
def add_delete(self, path, handler):
self.calls.append(("DELETE", path, handler))
def test_misc_route_registrar_registers_all_routes():
app = SimpleNamespace(router=RecordingRouter())
registrar = MiscRouteRegistrar(app) # type: ignore[arg-type]
async def dummy_handler(_request):
return web.Response()
handler_mapping = {definition.handler_name: dummy_handler for definition in MISC_ROUTE_DEFINITIONS}
registrar.register_routes(handler_mapping)
registered = {(method, path) for method, path, _ in app.router.calls}
expected = {(definition.method, definition.path) for definition in MISC_ROUTE_DEFINITIONS}
assert registered == expected
class FakePromptServer:
sent = []
class Instance:
def send_sync(self, event, payload):
FakePromptServer.sent.append((event, payload))
instance = Instance()
class FakeScanner:
async def check_model_version_exists(self, _version_id):
return False
async def get_model_versions_by_id(self, _model_id):
return []
async def fake_scanner_factory():
return FakeScanner()
class FakeMetadataProvider:
async def get_model_versions(self, _model_id):
return {"modelVersions": [], "name": "", "type": "lora"}
async def fake_metadata_provider_factory():
return FakeMetadataProvider()
class FakeMetadataArchiveManager:
async def download_and_extract_database(self, _callback):
return True
async def remove_database(self):
return True
def is_database_available(self):
return False
def get_database_path(self):
return None
async def fake_metadata_archive_manager_factory():
return FakeMetadataArchiveManager()
class RecordingRegistrar:
def __init__(self, _app):
self.registered_mapping = None
def register_routes(self, mapping):
self.registered_mapping = mapping
@pytest.mark.asyncio
async def test_misc_routes_bind_produces_expected_handlers():
service_registry_adapter = ServiceRegistryAdapter(
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
)
recorded_registrars = []
def registrar_factory(app):
registrar = RecordingRegistrar(app)
recorded_registrars.append(registrar)
return registrar
controller = MiscRoutes(
settings_service=DummySettings(),
usage_stats_factory=lambda: SimpleNamespace(process_execution=noop_async, get_stats=noop_async),
prompt_server=FakePromptServer,
service_registry_adapter=service_registry_adapter,
metadata_provider_factory=fake_metadata_provider_factory,
metadata_archive_manager_factory=fake_metadata_archive_manager_factory,
metadata_provider_updater=noop_async,
downloader_factory=dummy_downloader_factory,
registrar_factory=registrar_factory,
)
app = SimpleNamespace(router=RecordingRouter())
controller.bind(app) # type: ignore[arg-type]
assert recorded_registrars, "Expected registrar to be created"
mapping = recorded_registrars[0].registered_mapping
assert mapping is not None
expected_names = {definition.handler_name for definition in MISC_ROUTE_DEFINITIONS}
assert set(mapping.keys()) == expected_names