diff --git a/py/lora_manager.py b/py/lora_manager.py index a8d0d8f9..3a8c47c2 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -17,6 +17,7 @@ from .services.settings_manager import get_settings_manager from .utils.example_images_migration import ExampleImagesMigration from .services.websocket_manager import ws_manager from .services.example_images_cleanup_service import ExampleImagesCleanupService +from .middleware.csp_middleware import relax_csp_for_remote_media logger = logging.getLogger(__name__) @@ -62,6 +63,23 @@ class LoraManager: """Initialize and register all routes using the new refactored architecture""" app = PromptServer.instance.app + if relax_csp_for_remote_media not in app.middlewares: + # Ensure CSP relaxer executes after ComfyUI's block_external_middleware so it can + # see and extend the restrictive header instead of being overwritten by it. + block_middleware_index = next( + ( + idx + for idx, middleware in enumerate(app.middlewares) + if getattr(middleware, "__name__", "") == "block_external_middleware" + ), + None, + ) + + if block_middleware_index is None: + app.middlewares.append(relax_csp_for_remote_media) + else: + app.middlewares.insert(block_middleware_index, relax_csp_for_remote_media) + # Increase allowed header sizes so browsers with large localhost cookie # jars (multiple UIs on 127.0.0.1) don't trip aiohttp's 8KB default # limits. Cookies for unrelated apps are still sent to the plugin and diff --git a/py/middleware/csp_middleware.py b/py/middleware/csp_middleware.py new file mode 100644 index 00000000..cad6b99d --- /dev/null +++ b/py/middleware/csp_middleware.py @@ -0,0 +1,65 @@ +"""Middleware helpers for adjusting Content Security Policy headers.""" + +from typing import Awaitable, Callable, Dict, List + +from aiohttp import web + +REMOTE_MEDIA_SOURCES = ( + "https://image.civitai.com", + "https://img.genur.art", +) + + +@web.middleware +async def relax_csp_for_remote_media( + request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]] +) -> web.StreamResponse: + """Allow LoRA Manager media previews to load from trusted remote domains. + + When ComfyUI is started with ``--disable-api-nodes`` it injects a restrictive + ``Content-Security-Policy`` header that blocks remote images and videos. The + LoRA Manager UI legitimately needs to fetch previews from Civitai and Genur, + so this middleware augments the existing CSP to whitelist those hosts while + preserving all other directives. + """ + + response: web.StreamResponse = await handler(request) + header_value = response.headers.get("Content-Security-Policy") + + if not header_value: + return response + + directive_order: List[str] = [] + directives: Dict[str, List[str]] = {} + + for raw_directive in header_value.split(";"): + directive = raw_directive.strip() + if not directive: + continue + + parts = directive.split() + name, values = parts[0], parts[1:] + if name not in directive_order: + directive_order.append(name) + directives[name] = values + + def merge_sources(name: str, sources: List[str], defaults: List[str] | None = None) -> None: + existing = directives.get(name, list(defaults or [])) + + for source in sources: + if source not in existing: + existing.append(source) + + directives[name] = existing + if name not in directive_order: + directive_order.append(name) + + merge_sources("img-src", list(REMOTE_MEDIA_SOURCES)) + merge_sources("media-src", ["'self'", *REMOTE_MEDIA_SOURCES], defaults=["'self'"]) + + updated_header = "; ".join( + f"{name} {' '.join(directives[name])}".rstrip() for name in directive_order + ) + + response.headers["Content-Security-Policy"] = f"{updated_header};" + return response diff --git a/tests/middleware/test_csp_middleware.py b/tests/middleware/test_csp_middleware.py new file mode 100644 index 00000000..bc07d58f --- /dev/null +++ b/tests/middleware/test_csp_middleware.py @@ -0,0 +1,69 @@ +import pytest +from aiohttp import web +from aiohttp.test_utils import make_mocked_request + +from py.middleware.csp_middleware import REMOTE_MEDIA_SOURCES, relax_csp_for_remote_media + +DEFAULT_CSP = ( + "default-src 'self'; " + "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; " + "style-src 'self' 'unsafe-inline'; " + "img-src 'self' data: blob:; " + "font-src 'self'; " + "connect-src 'self'; " + "frame-src 'self'; " + "object-src 'self';" +) + + +def _parse_directives(header: str) -> dict[str, list[str]]: + directives: dict[str, list[str]] = {} + for raw_directive in header.split(";"): + directive = raw_directive.strip() + if not directive: + continue + name, *values = directive.split() + directives[name] = values + return directives + + +async def _invoke_middleware( + path: str, response: web.Response, csp_header: str | None = DEFAULT_CSP +) -> web.Response: + async def handler(_request: web.Request) -> web.Response: + if csp_header is not None: + response.headers["Content-Security-Policy"] = csp_header + return response + + request = make_mocked_request("GET", path) + return await relax_csp_for_remote_media(request, handler) + + +@pytest.mark.asyncio +async def test_relax_csp_appends_remote_sources_and_preserves_existing_directives() -> None: + response = await _invoke_middleware("/some-path", web.Response()) + header_value = response.headers.get("Content-Security-Policy") + assert header_value is not None + + directives = _parse_directives(header_value) + + # Existing directives remain intact + assert directives["script-src"] == ["'self'", "'unsafe-inline'", "'unsafe-eval'", "blob:"] + assert directives["img-src"][:3] == ["'self'", "data:", "blob:"] + + # Remote media hosts are added once to the relevant directives + for source in REMOTE_MEDIA_SOURCES: + assert source in directives["img-src"] + + assert "media-src" in directives + assert directives["media-src"][0] == "'self'" + for source in REMOTE_MEDIA_SOURCES: + assert source in directives["media-src"] + + +@pytest.mark.asyncio +async def test_relax_csp_no_header_left_untouched() -> None: + response = await _invoke_middleware("/no-csp", web.Response(), csp_header=None) + + assert "Content-Security-Policy" not in response.headers + assert response.headers.get("X-Test") is None