From d01666f4e284e0b2e91108d8af1057a6354c65fd Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Sat, 4 Oct 2025 10:38:06 +0800 Subject: [PATCH] feat(previews): serve dynamic library previews --- docs/library-switching.md | 28 ++++++ py/config.py | 112 +++++++++++++++++------ py/lora_manager.py | 93 +------------------- py/routes/handlers/preview_handlers.py | 56 ++++++++++++ py/routes/preview_routes.py | 25 ++++++ standalone.py | 117 +------------------------ tests/routes/test_preview_routes.py | 112 +++++++++++++++++++++++ 7 files changed, 312 insertions(+), 231 deletions(-) create mode 100644 docs/library-switching.md create mode 100644 py/routes/handlers/preview_handlers.py create mode 100644 py/routes/preview_routes.py create mode 100644 tests/routes/test_preview_routes.py diff --git a/docs/library-switching.md b/docs/library-switching.md new file mode 100644 index 00000000..b75db4f7 --- /dev/null +++ b/docs/library-switching.md @@ -0,0 +1,28 @@ +# Library Switching and Preview Routes + +Library switching no longer requires restarting the backend. The preview +thumbnails shown in the UI are now served through a dynamic endpoint that +resolves files against the folders registered for the active library at request +time. This allows the multi-library flow to update model roots without touching +the aiohttp router, so previews remain available immediately after a switch. + +## How the dynamic preview endpoint works + +* `config.get_preview_static_url()` now returns `/api/lm/previews?path=` + for any preview path. The raw filesystem location is URL encoded so that it + can be passed through the query string without leaking directory structure in + the route itself.【F:py/config.py†L398-L404】 +* `PreviewRoutes` exposes the `/api/lm/previews` handler which validates the + decoded path against the directories registered for the current library. The + request is rejected if it falls outside those roots or if the file does not + exist.【F:py/routes/preview_routes.py†L5-L21】【F:py/routes/handlers/preview_handlers.py†L9-L48】 +* `Config` keeps an up-to-date cache of allowed preview roots. Every time a + library is applied the cache is rebuilt using the declared LoRA, checkpoint + and embedding directories (including symlink targets). The validation logic + checks preview requests against this cache.【F:py/config.py†L51-L68】【F:py/config.py†L180-L248】【F:py/config.py†L332-L346】 + +Both the ComfyUI runtime (`LoraManager.add_routes`) and the standalone launcher +(`StandaloneLoraManager.add_routes`) register the new preview routes instead of +mounting a static directory per root. Switching libraries therefore works +without restarting the application, and preview URLs generated before or after a +switch continue to resolve correctly.【F:py/lora_manager.py†L21-L82】【F:standalone.py†L302-L315】 diff --git a/py/config.py b/py/config.py index 85d5e58a..a9d0e736 100644 --- a/py/config.py +++ b/py/config.py @@ -1,5 +1,6 @@ import os import platform +from pathlib import Path import folder_paths # type: ignore from typing import Dict, Iterable, List, Mapping, Set import logging @@ -52,9 +53,9 @@ class Config: self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static') self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales') # Path mapping dictionary, target to link mapping - self._path_mappings = {} - # Static route mapping dictionary, target to route mapping - self._route_mappings = {} + self._path_mappings: Dict[str, str] = {} + # Normalized preview root directories used to validate preview access + self._preview_root_paths: Set[Path] = set() self.loras_roots = self._init_lora_paths() self.checkpoints_roots = None self.unet_roots = None @@ -63,6 +64,7 @@ class Config: self.embeddings_roots = self._init_embedding_paths() # Scan symbolic links during initialization self._scan_symbolic_links() + self._rebuild_preview_roots() if not standalone_mode: # Save the paths to settings.json when running in ComfyUI mode @@ -185,12 +187,65 @@ class Config: # Keep the original mapping: target path -> link path self._path_mappings[normalized_target] = normalized_link logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}") + self._preview_root_paths.update(self._expand_preview_root(normalized_target)) + self._preview_root_paths.update(self._expand_preview_root(normalized_link)) - def add_route_mapping(self, path: str, route: str): - """Add a static route mapping""" - normalized_path = os.path.normpath(path).replace(os.sep, '/') - self._route_mappings[normalized_path] = route - # logger.info(f"Added route mapping: {normalized_path} -> {route}") + def _expand_preview_root(self, path: str) -> Set[Path]: + """Return normalized ``Path`` objects representing a preview root.""" + + roots: Set[Path] = set() + if not path: + return roots + + try: + raw_path = Path(path).expanduser() + except Exception: + return roots + + if raw_path.is_absolute(): + roots.add(raw_path) + + try: + resolved = raw_path.resolve(strict=False) + except RuntimeError: + resolved = raw_path.absolute() + roots.add(resolved) + + try: + real_path = raw_path.resolve() + except (FileNotFoundError, RuntimeError): + real_path = resolved + roots.add(real_path) + + normalized: Set[Path] = set() + for candidate in roots: + if candidate.is_absolute(): + normalized.add(candidate) + else: + try: + normalized.add(candidate.resolve(strict=False)) + except RuntimeError: + normalized.add(candidate.absolute()) + + return normalized + + def _rebuild_preview_roots(self) -> None: + """Recompute the cache of directories permitted for previews.""" + + preview_roots: Set[Path] = set() + + for root in self.loras_roots or []: + preview_roots.update(self._expand_preview_root(root)) + for root in self.base_models_roots or []: + preview_roots.update(self._expand_preview_root(root)) + for root in self.embeddings_roots or []: + preview_roots.update(self._expand_preview_root(root)) + + for target, link in self._path_mappings.items(): + preview_roots.update(self._expand_preview_root(target)) + preview_roots.update(self._expand_preview_root(link)) + + self._preview_root_paths = {path for path in preview_roots if path.is_absolute()} def map_path_to_link(self, path: str) -> str: """Map a target path back to its symbolic link path""" @@ -276,6 +331,7 @@ class Config: def _apply_library_paths(self, folder_paths: Mapping[str, Iterable[str]]) -> None: self._path_mappings.clear() + self._preview_root_paths = set() lora_paths = folder_paths.get('loras', []) or [] checkpoint_paths = folder_paths.get('checkpoints', []) or [] @@ -287,6 +343,7 @@ class Config: self.embeddings_roots = self._prepare_embedding_paths(embedding_paths) self._scan_symbolic_links() + self._rebuild_preview_roots() def _init_lora_paths(self) -> List[str]: """Initialize and validate LoRA paths from ComfyUI settings""" @@ -342,24 +399,29 @@ class Config: if not preview_path: return "" - real_path = os.path.realpath(preview_path).replace(os.sep, '/') - - # Find longest matching path (most specific match) - best_match = "" - best_route = "" - - for path, route in self._route_mappings.items(): - if real_path.startswith(path) and len(path) > len(best_match): - best_match = path - best_route = route - - if best_match: - relative_path = os.path.relpath(real_path, best_match).replace(os.sep, '/') - safe_parts = [urllib.parse.quote(part) for part in relative_path.split('/')] - safe_path = '/'.join(safe_parts) - return f'{best_route}/{safe_path}' + normalized = os.path.normpath(preview_path).replace(os.sep, '/') + encoded_path = urllib.parse.quote(normalized, safe='') + return f'/api/lm/previews?path={encoded_path}' - return "" + def is_preview_path_allowed(self, preview_path: str) -> bool: + """Return ``True`` if ``preview_path`` is within an allowed directory.""" + + if not preview_path: + return False + + try: + candidate = Path(preview_path).expanduser().resolve(strict=False) + except Exception: + return False + + for root in self._preview_root_paths: + try: + candidate.relative_to(root) + return True + except ValueError: + continue + + return False def apply_library_settings(self, library_config: Mapping[str, object]) -> None: """Update runtime paths to match the provided library configuration.""" diff --git a/py/lora_manager.py b/py/lora_manager.py index 98f1a00e..a1255bd4 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -2,7 +2,6 @@ import asyncio import sys import os import logging -from pathlib import Path from server import PromptServer # type: ignore from .config import config @@ -11,6 +10,7 @@ from .routes.recipe_routes import RecipeRoutes from .routes.stats_routes import StatsRoutes from .routes.update_routes import UpdateRoutes from .routes.misc_routes import MiscRoutes +from .routes.preview_routes import PreviewRoutes from .routes.example_images_routes import ExampleImagesRoutes from .services.service_registry import ServiceRegistry from .services.settings_manager import settings @@ -50,102 +50,12 @@ class LoraManager: asyncio_logger = logging.getLogger("asyncio") asyncio_logger.addFilter(ConnectionResetFilter()) - added_targets = set() # Track already added target paths - # Add static route for example images if the path exists in settings example_images_path = settings.get('example_images_path') logger.info(f"Example images path: {example_images_path}") if example_images_path and os.path.exists(example_images_path): app.router.add_static('/example_images_static', example_images_path) logger.info(f"Added static route for example images: /example_images_static -> {example_images_path}") - - # Add static routes for each lora root - for idx, root in enumerate(config.loras_roots, start=1): - preview_path = f'/loras_static/root{idx}/preview' - - real_root = root - if root in config._path_mappings.values(): - for target, link in config._path_mappings.items(): - if link == root: - real_root = target - break - # Add static route for original path - app.router.add_static(preview_path, real_root) - logger.info(f"Added static route {preview_path} -> {real_root}") - - # Record route mapping - config.add_route_mapping(real_root, preview_path) - added_targets.add(real_root) - - # Add static routes for each checkpoint root - for idx, root in enumerate(config.base_models_roots, start=1): - preview_path = f'/checkpoints_static/root{idx}/preview' - - real_root = root - if root in config._path_mappings.values(): - for target, link in config._path_mappings.items(): - if link == root: - real_root = target - break - # Add static route for original path - app.router.add_static(preview_path, real_root) - logger.info(f"Added static route {preview_path} -> {real_root}") - - # Record route mapping - config.add_route_mapping(real_root, preview_path) - added_targets.add(real_root) - - # Add static routes for each embedding root - for idx, root in enumerate(config.embeddings_roots, start=1): - preview_path = f'/embeddings_static/root{idx}/preview' - - real_root = root - if root in config._path_mappings.values(): - for target, link in config._path_mappings.items(): - if link == root: - real_root = target - break - # Add static route for original path - app.router.add_static(preview_path, real_root) - logger.info(f"Added static route {preview_path} -> {real_root}") - - # Record route mapping - config.add_route_mapping(real_root, preview_path) - added_targets.add(real_root) - - # Add static routes for symlink target paths - link_idx = { - 'lora': 1, - 'checkpoint': 1, - 'embedding': 1 - } - - for target_path, link_path in config._path_mappings.items(): - if target_path not in added_targets: - # Determine if this is a checkpoint, lora, or embedding link based on path - is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots) - is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots) - is_embedding = any(emb_root in link_path for emb_root in config.embeddings_roots) - is_embedding = is_embedding or any(emb_root in target_path for emb_root in config.embeddings_roots) - - if is_checkpoint: - route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview' - link_idx["checkpoint"] += 1 - elif is_embedding: - route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview' - link_idx["embedding"] += 1 - else: - route_path = f'/loras_static/link_{link_idx["lora"]}/preview' - link_idx["lora"] += 1 - - try: - app.router.add_static(route_path, Path(target_path).resolve(strict=False)) - logger.info(f"Added static route for link target {route_path} -> {target_path}") - config.add_route_mapping(target_path, route_path) - added_targets.add(target_path) - except Exception as e: - logger.warning(f"Failed to add static route on initialization for {target_path}: {e}") - continue # Add static route for locales JSON files if os.path.exists(config.i18n_path): @@ -168,6 +78,7 @@ class LoraManager: UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager) + PreviewRoutes.setup_routes(app) # Setup WebSocket routes that are shared across all model types app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) diff --git a/py/routes/handlers/preview_handlers.py b/py/routes/handlers/preview_handlers.py new file mode 100644 index 00000000..660dbe83 --- /dev/null +++ b/py/routes/handlers/preview_handlers.py @@ -0,0 +1,56 @@ +"""Handlers responsible for serving preview assets dynamically.""" + +from __future__ import annotations + +import logging +import urllib.parse +from pathlib import Path + +from aiohttp import web + +from ...config import config as global_config + +logger = logging.getLogger(__name__) + + +class PreviewHandler: + """Serve preview assets for the active library at request time.""" + + def __init__(self, *, config=global_config) -> None: + self._config = config + + async def serve_preview(self, request: web.Request) -> web.StreamResponse: + """Return the preview file referenced by the encoded ``path`` query.""" + + raw_path = request.query.get("path", "") + if not raw_path: + raise web.HTTPBadRequest(text="Missing 'path' query parameter") + + try: + decoded_path = urllib.parse.unquote(raw_path) + except Exception as exc: # pragma: no cover - defensive guard + logger.debug("Failed to decode preview path %s: %s", raw_path, exc) + raise web.HTTPBadRequest(text="Invalid preview path encoding") from exc + + normalized = decoded_path.replace("\\", "/") + candidate = Path(normalized) + try: + resolved = candidate.expanduser().resolve(strict=False) + except Exception as exc: + logger.debug("Failed to resolve preview path %s: %s", normalized, exc) + raise web.HTTPBadRequest(text="Unable to resolve preview path") from exc + + resolved_str = str(resolved) + if not self._config.is_preview_path_allowed(resolved_str): + logger.debug("Rejected preview outside allowed roots: %s", resolved_str) + raise web.HTTPForbidden(text="Preview path is not within an allowed directory") + + if not resolved.is_file(): + logger.debug("Preview file not found at %s", resolved_str) + raise web.HTTPNotFound(text="Preview file not found") + + # aiohttp's FileResponse handles range requests and content headers for us. + return web.FileResponse(path=resolved, chunk_size=256 * 1024) + + +__all__ = ["PreviewHandler"] diff --git a/py/routes/preview_routes.py b/py/routes/preview_routes.py new file mode 100644 index 00000000..416a4e0f --- /dev/null +++ b/py/routes/preview_routes.py @@ -0,0 +1,25 @@ +"""Route controller for preview asset delivery.""" + +from __future__ import annotations + +from aiohttp import web + +from .handlers.preview_handlers import PreviewHandler + + +class PreviewRoutes: + """Register routes that expose preview assets.""" + + def __init__(self, *, handler: PreviewHandler | None = None) -> None: + self._handler = handler or PreviewHandler() + + @classmethod + def setup_routes(cls, app: web.Application) -> None: + controller = cls() + controller.register(app) + + def register(self, app: web.Application) -> None: + app.router.add_get('/api/lm/previews', self._handler.serve_preview) + + +__all__ = ["PreviewRoutes"] diff --git a/standalone.py b/standalone.py index d9d84aee..48f1a0b4 100644 --- a/standalone.py +++ b/standalone.py @@ -1,4 +1,3 @@ -from pathlib import Path import os import sys import json @@ -280,121 +279,7 @@ class StandaloneLoraManager(LoraManager): # Store app in a global-like location for compatibility sys.modules['server'].PromptServer.instance = server_instance - added_targets = set() # Track already added target paths - - # Add static routes for each lora root - for idx, root in enumerate(config.loras_roots, start=1): - if not os.path.exists(root): - logger.warning(f"Lora root path does not exist: {root}") - continue - - preview_path = f'/loras_static/root{idx}/preview' - - # Check if this root is a link path in the mappings - real_root = root - for target, link in config._path_mappings.items(): - if os.path.normpath(link) == os.path.normpath(root): - # If so, route should point to the target (real path) - real_root = target - break - - # Normalize and standardize path display for consistency - display_root = real_root.replace('\\', '/') - - # Add static route for original path - use the normalized path - app.router.add_static(preview_path, real_root) - logger.info(f"Added static route {preview_path} -> {display_root}") - - # Record route mapping with normalized path - config.add_route_mapping(real_root, preview_path) - added_targets.add(os.path.normpath(real_root)) - - # Add static routes for each checkpoint root - for idx, root in enumerate(config.base_models_roots, start=1): - if not os.path.exists(root): - logger.warning(f"Checkpoint root path does not exist: {root}") - continue - - preview_path = f'/checkpoints_static/root{idx}/preview' - - # Check if this root is a link path in the mappings - real_root = root - for target, link in config._path_mappings.items(): - if os.path.normpath(link) == os.path.normpath(root): - # If so, route should point to the target (real path) - real_root = target - break - - # Normalize and standardize path display for consistency - display_root = real_root.replace('\\', '/') - - # Add static route for original path - app.router.add_static(preview_path, real_root) - logger.info(f"Added static route {preview_path} -> {display_root}") - - # Record route mapping - config.add_route_mapping(real_root, preview_path) - added_targets.add(os.path.normpath(real_root)) - # Add static routes for each embedding root - for idx, root in enumerate(getattr(config, "embeddings_roots", []), start=1): - if not os.path.exists(root): - logger.warning(f"Embedding root path does not exist: {root}") - continue - - preview_path = f'/embeddings_static/root{idx}/preview' - - real_root = root - for target, link in config._path_mappings.items(): - if os.path.normpath(link) == os.path.normpath(root): - real_root = target - break - - display_root = real_root.replace('\\', '/') - app.router.add_static(preview_path, real_root) - logger.info(f"Added static route {preview_path} -> {display_root}") - - config.add_route_mapping(real_root, preview_path) - added_targets.add(os.path.normpath(real_root)) - - # Add static routes for symlink target paths that aren't already covered - link_idx = { - 'lora': 1, - 'checkpoint': 1, - 'embedding': 1 - } - - for target_path, link_path in config._path_mappings.items(): - norm_target = os.path.normpath(target_path) - if norm_target not in added_targets: - # Determine if this is a checkpoint, lora, or embedding link based on path - is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.base_models_roots) - is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.base_models_roots) - is_embedding = any(os.path.normpath(emb_root) in os.path.normpath(link_path) for emb_root in getattr(config, "embeddings_roots", [])) - is_embedding = is_embedding or any(os.path.normpath(emb_root) in norm_target for emb_root in getattr(config, "embeddings_roots", [])) - - if is_checkpoint: - route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview' - link_idx["checkpoint"] += 1 - elif is_embedding: - route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview' - link_idx["embedding"] += 1 - else: - route_path = f'/loras_static/link_{link_idx["lora"]}/preview' - link_idx["lora"] += 1 - - # Display path with forward slashes for consistency - display_target = target_path.replace('\\', '/') - - try: - app.router.add_static(route_path, Path(target_path).resolve(strict=False)) - logger.info(f"Added static route for link target {route_path} -> {display_target}") - config.add_route_mapping(target_path, route_path) - added_targets.add(norm_target) - except Exception as e: - logger.warning(f"Failed to add static route on initialization for {target_path}: {e}") - continue - # Add static route for locales JSON files if os.path.exists(config.i18n_path): app.router.add_static('/locales', config.i18n_path) @@ -409,6 +294,7 @@ class StandaloneLoraManager(LoraManager): from py.routes.update_routes import UpdateRoutes from py.routes.misc_routes import MiscRoutes from py.routes.example_images_routes import ExampleImagesRoutes + from py.routes.preview_routes import PreviewRoutes from py.routes.stats_routes import StatsRoutes from py.services.websocket_manager import ws_manager @@ -426,6 +312,7 @@ class StandaloneLoraManager(LoraManager): UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager) + PreviewRoutes.setup_routes(app) # Setup WebSocket routes that are shared across all model types app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) diff --git a/tests/routes/test_preview_routes.py b/tests/routes/test_preview_routes.py new file mode 100644 index 00000000..9ec798ce --- /dev/null +++ b/tests/routes/test_preview_routes.py @@ -0,0 +1,112 @@ +import urllib.parse +from pathlib import Path + +import pytest +from aiohttp import web +from aiohttp.test_utils import make_mocked_request + +from py.config import Config +from py.routes.handlers.preview_handlers import PreviewHandler + + +async def test_preview_handler_serves_preview_from_active_library(tmp_path): + library_root = tmp_path / "library" + library_root.mkdir() + preview_file = library_root / "model.webp" + preview_file.write_bytes(b"preview") + + config = Config() + config.apply_library_settings( + { + "folder_paths": { + "loras": [str(library_root)], + "checkpoints": [], + "unet": [], + "embeddings": [], + } + } + ) + + handler = PreviewHandler(config=config) + encoded_path = urllib.parse.quote(str(preview_file), safe="") + request = make_mocked_request("GET", f"/api/lm/previews?path={encoded_path}") + + response = await handler.serve_preview(request) + + assert isinstance(response, web.FileResponse) + assert response.status == 200 + assert Path(response._path) == preview_file + + +async def test_preview_handler_forbids_paths_outside_active_library(tmp_path): + allowed_root = tmp_path / "allowed" + allowed_root.mkdir() + forbidden_root = tmp_path / "forbidden" + forbidden_root.mkdir() + forbidden_file = forbidden_root / "sneaky.webp" + forbidden_file.write_bytes(b"x") + + config = Config() + config.apply_library_settings( + { + "folder_paths": { + "loras": [str(allowed_root)], + "checkpoints": [], + "unet": [], + "embeddings": [], + } + } + ) + + handler = PreviewHandler(config=config) + encoded_path = urllib.parse.quote(str(forbidden_file), safe="") + request = make_mocked_request("GET", f"/api/lm/previews?path={encoded_path}") + + with pytest.raises(web.HTTPForbidden): + await handler.serve_preview(request) + + +async def test_config_updates_preview_roots_after_switch(tmp_path): + first_root = tmp_path / "first" + first_root.mkdir() + second_root = tmp_path / "second" + second_root.mkdir() + + first_preview = first_root / "model.webp" + first_preview.write_bytes(b"a") + second_preview = second_root / "model.webp" + second_preview.write_bytes(b"b") + + config = Config() + config.apply_library_settings( + { + "folder_paths": { + "loras": [str(first_root)], + "checkpoints": [], + "unet": [], + "embeddings": [], + } + } + ) + + assert config.is_preview_path_allowed(str(first_preview)) + assert not config.is_preview_path_allowed(str(second_preview)) + + config.apply_library_settings( + { + "folder_paths": { + "loras": [str(second_root)], + "checkpoints": [], + "unet": [], + "embeddings": [], + } + } + ) + + assert config.is_preview_path_allowed(str(second_preview)) + assert not config.is_preview_path_allowed(str(first_preview)) + + preview_url = config.get_preview_static_url(str(second_preview)) + assert preview_url.startswith("/api/lm/previews?path=") + decoded = urllib.parse.unquote(preview_url.split("path=", 1)[1]) + assert decoded.replace("\\", "/").endswith("model.webp")