mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #516 from willmiao/codex/investigate-library-switching-functionality-issue
feat: serve dynamic preview assets after library switch
This commit is contained in:
28
docs/library-switching.md
Normal file
28
docs/library-switching.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# Library Switching and Preview Routes
|
||||
|
||||
Library switching no longer requires restarting the backend. The preview
|
||||
thumbnails shown in the UI are now served through a dynamic endpoint that
|
||||
resolves files against the folders registered for the active library at request
|
||||
time. This allows the multi-library flow to update model roots without touching
|
||||
the aiohttp router, so previews remain available immediately after a switch.
|
||||
|
||||
## How the dynamic preview endpoint works
|
||||
|
||||
* `config.get_preview_static_url()` now returns `/api/lm/previews?path=<encoded>`
|
||||
for any preview path. The raw filesystem location is URL encoded so that it
|
||||
can be passed through the query string without leaking directory structure in
|
||||
the route itself.【F:py/config.py†L398-L404】
|
||||
* `PreviewRoutes` exposes the `/api/lm/previews` handler which validates the
|
||||
decoded path against the directories registered for the current library. The
|
||||
request is rejected if it falls outside those roots or if the file does not
|
||||
exist.【F:py/routes/preview_routes.py†L5-L21】【F:py/routes/handlers/preview_handlers.py†L9-L48】
|
||||
* `Config` keeps an up-to-date cache of allowed preview roots. Every time a
|
||||
library is applied the cache is rebuilt using the declared LoRA, checkpoint
|
||||
and embedding directories (including symlink targets). The validation logic
|
||||
checks preview requests against this cache.【F:py/config.py†L51-L68】【F:py/config.py†L180-L248】【F:py/config.py†L332-L346】
|
||||
|
||||
Both the ComfyUI runtime (`LoraManager.add_routes`) and the standalone launcher
|
||||
(`StandaloneLoraManager.add_routes`) register the new preview routes instead of
|
||||
mounting a static directory per root. Switching libraries therefore works
|
||||
without restarting the application, and preview URLs generated before or after a
|
||||
switch continue to resolve correctly.【F:py/lora_manager.py†L21-L82】【F:standalone.py†L302-L315】
|
||||
112
py/config.py
112
py/config.py
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
import folder_paths # type: ignore
|
||||
from typing import Dict, Iterable, List, Mapping, Set
|
||||
import logging
|
||||
@@ -52,9 +53,9 @@ class Config:
|
||||
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static')
|
||||
self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales')
|
||||
# Path mapping dictionary, target to link mapping
|
||||
self._path_mappings = {}
|
||||
# Static route mapping dictionary, target to route mapping
|
||||
self._route_mappings = {}
|
||||
self._path_mappings: Dict[str, str] = {}
|
||||
# Normalized preview root directories used to validate preview access
|
||||
self._preview_root_paths: Set[Path] = set()
|
||||
self.loras_roots = self._init_lora_paths()
|
||||
self.checkpoints_roots = None
|
||||
self.unet_roots = None
|
||||
@@ -63,6 +64,7 @@ class Config:
|
||||
self.embeddings_roots = self._init_embedding_paths()
|
||||
# Scan symbolic links during initialization
|
||||
self._scan_symbolic_links()
|
||||
self._rebuild_preview_roots()
|
||||
|
||||
if not standalone_mode:
|
||||
# Save the paths to settings.json when running in ComfyUI mode
|
||||
@@ -185,12 +187,65 @@ class Config:
|
||||
# Keep the original mapping: target path -> link path
|
||||
self._path_mappings[normalized_target] = normalized_link
|
||||
logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}")
|
||||
self._preview_root_paths.update(self._expand_preview_root(normalized_target))
|
||||
self._preview_root_paths.update(self._expand_preview_root(normalized_link))
|
||||
|
||||
def add_route_mapping(self, path: str, route: str):
|
||||
"""Add a static route mapping"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
self._route_mappings[normalized_path] = route
|
||||
# logger.info(f"Added route mapping: {normalized_path} -> {route}")
|
||||
def _expand_preview_root(self, path: str) -> Set[Path]:
|
||||
"""Return normalized ``Path`` objects representing a preview root."""
|
||||
|
||||
roots: Set[Path] = set()
|
||||
if not path:
|
||||
return roots
|
||||
|
||||
try:
|
||||
raw_path = Path(path).expanduser()
|
||||
except Exception:
|
||||
return roots
|
||||
|
||||
if raw_path.is_absolute():
|
||||
roots.add(raw_path)
|
||||
|
||||
try:
|
||||
resolved = raw_path.resolve(strict=False)
|
||||
except RuntimeError:
|
||||
resolved = raw_path.absolute()
|
||||
roots.add(resolved)
|
||||
|
||||
try:
|
||||
real_path = raw_path.resolve()
|
||||
except (FileNotFoundError, RuntimeError):
|
||||
real_path = resolved
|
||||
roots.add(real_path)
|
||||
|
||||
normalized: Set[Path] = set()
|
||||
for candidate in roots:
|
||||
if candidate.is_absolute():
|
||||
normalized.add(candidate)
|
||||
else:
|
||||
try:
|
||||
normalized.add(candidate.resolve(strict=False))
|
||||
except RuntimeError:
|
||||
normalized.add(candidate.absolute())
|
||||
|
||||
return normalized
|
||||
|
||||
def _rebuild_preview_roots(self) -> None:
|
||||
"""Recompute the cache of directories permitted for previews."""
|
||||
|
||||
preview_roots: Set[Path] = set()
|
||||
|
||||
for root in self.loras_roots or []:
|
||||
preview_roots.update(self._expand_preview_root(root))
|
||||
for root in self.base_models_roots or []:
|
||||
preview_roots.update(self._expand_preview_root(root))
|
||||
for root in self.embeddings_roots or []:
|
||||
preview_roots.update(self._expand_preview_root(root))
|
||||
|
||||
for target, link in self._path_mappings.items():
|
||||
preview_roots.update(self._expand_preview_root(target))
|
||||
preview_roots.update(self._expand_preview_root(link))
|
||||
|
||||
self._preview_root_paths = {path for path in preview_roots if path.is_absolute()}
|
||||
|
||||
def map_path_to_link(self, path: str) -> str:
|
||||
"""Map a target path back to its symbolic link path"""
|
||||
@@ -276,6 +331,7 @@ class Config:
|
||||
|
||||
def _apply_library_paths(self, folder_paths: Mapping[str, Iterable[str]]) -> None:
|
||||
self._path_mappings.clear()
|
||||
self._preview_root_paths = set()
|
||||
|
||||
lora_paths = folder_paths.get('loras', []) or []
|
||||
checkpoint_paths = folder_paths.get('checkpoints', []) or []
|
||||
@@ -287,6 +343,7 @@ class Config:
|
||||
self.embeddings_roots = self._prepare_embedding_paths(embedding_paths)
|
||||
|
||||
self._scan_symbolic_links()
|
||||
self._rebuild_preview_roots()
|
||||
|
||||
def _init_lora_paths(self) -> List[str]:
|
||||
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
||||
@@ -342,24 +399,29 @@ class Config:
|
||||
if not preview_path:
|
||||
return ""
|
||||
|
||||
real_path = os.path.realpath(preview_path).replace(os.sep, '/')
|
||||
|
||||
# Find longest matching path (most specific match)
|
||||
best_match = ""
|
||||
best_route = ""
|
||||
|
||||
for path, route in self._route_mappings.items():
|
||||
if real_path.startswith(path) and len(path) > len(best_match):
|
||||
best_match = path
|
||||
best_route = route
|
||||
|
||||
if best_match:
|
||||
relative_path = os.path.relpath(real_path, best_match).replace(os.sep, '/')
|
||||
safe_parts = [urllib.parse.quote(part) for part in relative_path.split('/')]
|
||||
safe_path = '/'.join(safe_parts)
|
||||
return f'{best_route}/{safe_path}'
|
||||
normalized = os.path.normpath(preview_path).replace(os.sep, '/')
|
||||
encoded_path = urllib.parse.quote(normalized, safe='')
|
||||
return f'/api/lm/previews?path={encoded_path}'
|
||||
|
||||
return ""
|
||||
def is_preview_path_allowed(self, preview_path: str) -> bool:
|
||||
"""Return ``True`` if ``preview_path`` is within an allowed directory."""
|
||||
|
||||
if not preview_path:
|
||||
return False
|
||||
|
||||
try:
|
||||
candidate = Path(preview_path).expanduser().resolve(strict=False)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
for root in self._preview_root_paths:
|
||||
try:
|
||||
candidate.relative_to(root)
|
||||
return True
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def apply_library_settings(self, library_config: Mapping[str, object]) -> None:
|
||||
"""Update runtime paths to match the provided library configuration."""
|
||||
|
||||
@@ -2,7 +2,6 @@ import asyncio
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .config import config
|
||||
@@ -11,6 +10,7 @@ from .routes.recipe_routes import RecipeRoutes
|
||||
from .routes.stats_routes import StatsRoutes
|
||||
from .routes.update_routes import UpdateRoutes
|
||||
from .routes.misc_routes import MiscRoutes
|
||||
from .routes.preview_routes import PreviewRoutes
|
||||
from .routes.example_images_routes import ExampleImagesRoutes
|
||||
from .services.service_registry import ServiceRegistry
|
||||
from .services.settings_manager import settings
|
||||
@@ -50,102 +50,12 @@ class LoraManager:
|
||||
asyncio_logger = logging.getLogger("asyncio")
|
||||
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||
|
||||
added_targets = set() # Track already added target paths
|
||||
|
||||
# Add static route for example images if the path exists in settings
|
||||
example_images_path = settings.get('example_images_path')
|
||||
logger.info(f"Example images path: {example_images_path}")
|
||||
if example_images_path and os.path.exists(example_images_path):
|
||||
app.router.add_static('/example_images_static', example_images_path)
|
||||
logger.info(f"Added static route for example images: /example_images_static -> {example_images_path}")
|
||||
|
||||
# Add static routes for each lora root
|
||||
for idx, root in enumerate(config.loras_roots, start=1):
|
||||
preview_path = f'/loras_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
if root in config._path_mappings.values():
|
||||
for target, link in config._path_mappings.items():
|
||||
if link == root:
|
||||
real_root = target
|
||||
break
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for each checkpoint root
|
||||
for idx, root in enumerate(config.base_models_roots, start=1):
|
||||
preview_path = f'/checkpoints_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
if root in config._path_mappings.values():
|
||||
for target, link in config._path_mappings.items():
|
||||
if link == root:
|
||||
real_root = target
|
||||
break
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for each embedding root
|
||||
for idx, root in enumerate(config.embeddings_roots, start=1):
|
||||
preview_path = f'/embeddings_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
if root in config._path_mappings.values():
|
||||
for target, link in config._path_mappings.items():
|
||||
if link == root:
|
||||
real_root = target
|
||||
break
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for symlink target paths
|
||||
link_idx = {
|
||||
'lora': 1,
|
||||
'checkpoint': 1,
|
||||
'embedding': 1
|
||||
}
|
||||
|
||||
for target_path, link_path in config._path_mappings.items():
|
||||
if target_path not in added_targets:
|
||||
# Determine if this is a checkpoint, lora, or embedding link based on path
|
||||
is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots)
|
||||
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots)
|
||||
is_embedding = any(emb_root in link_path for emb_root in config.embeddings_roots)
|
||||
is_embedding = is_embedding or any(emb_root in target_path for emb_root in config.embeddings_roots)
|
||||
|
||||
if is_checkpoint:
|
||||
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
||||
link_idx["checkpoint"] += 1
|
||||
elif is_embedding:
|
||||
route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview'
|
||||
link_idx["embedding"] += 1
|
||||
else:
|
||||
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
|
||||
link_idx["lora"] += 1
|
||||
|
||||
try:
|
||||
app.router.add_static(route_path, Path(target_path).resolve(strict=False))
|
||||
logger.info(f"Added static route for link target {route_path} -> {target_path}")
|
||||
config.add_route_mapping(target_path, route_path)
|
||||
added_targets.add(target_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add static route on initialization for {target_path}: {e}")
|
||||
continue
|
||||
|
||||
# Add static route for locales JSON files
|
||||
if os.path.exists(config.i18n_path):
|
||||
@@ -168,6 +78,7 @@ class LoraManager:
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
|
||||
PreviewRoutes.setup_routes(app)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
|
||||
56
py/routes/handlers/preview_handlers.py
Normal file
56
py/routes/handlers/preview_handlers.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Handlers responsible for serving preview assets dynamically."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...config import config as global_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreviewHandler:
|
||||
"""Serve preview assets for the active library at request time."""
|
||||
|
||||
def __init__(self, *, config=global_config) -> None:
|
||||
self._config = config
|
||||
|
||||
async def serve_preview(self, request: web.Request) -> web.StreamResponse:
|
||||
"""Return the preview file referenced by the encoded ``path`` query."""
|
||||
|
||||
raw_path = request.query.get("path", "")
|
||||
if not raw_path:
|
||||
raise web.HTTPBadRequest(text="Missing 'path' query parameter")
|
||||
|
||||
try:
|
||||
decoded_path = urllib.parse.unquote(raw_path)
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("Failed to decode preview path %s: %s", raw_path, exc)
|
||||
raise web.HTTPBadRequest(text="Invalid preview path encoding") from exc
|
||||
|
||||
normalized = decoded_path.replace("\\", "/")
|
||||
candidate = Path(normalized)
|
||||
try:
|
||||
resolved = candidate.expanduser().resolve(strict=False)
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to resolve preview path %s: %s", normalized, exc)
|
||||
raise web.HTTPBadRequest(text="Unable to resolve preview path") from exc
|
||||
|
||||
resolved_str = str(resolved)
|
||||
if not self._config.is_preview_path_allowed(resolved_str):
|
||||
logger.debug("Rejected preview outside allowed roots: %s", resolved_str)
|
||||
raise web.HTTPForbidden(text="Preview path is not within an allowed directory")
|
||||
|
||||
if not resolved.is_file():
|
||||
logger.debug("Preview file not found at %s", resolved_str)
|
||||
raise web.HTTPNotFound(text="Preview file not found")
|
||||
|
||||
# aiohttp's FileResponse handles range requests and content headers for us.
|
||||
return web.FileResponse(path=resolved, chunk_size=256 * 1024)
|
||||
|
||||
|
||||
__all__ = ["PreviewHandler"]
|
||||
25
py/routes/preview_routes.py
Normal file
25
py/routes/preview_routes.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Route controller for preview asset delivery."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from .handlers.preview_handlers import PreviewHandler
|
||||
|
||||
|
||||
class PreviewRoutes:
|
||||
"""Register routes that expose preview assets."""
|
||||
|
||||
def __init__(self, *, handler: PreviewHandler | None = None) -> None:
|
||||
self._handler = handler or PreviewHandler()
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application) -> None:
|
||||
controller = cls()
|
||||
controller.register(app)
|
||||
|
||||
def register(self, app: web.Application) -> None:
|
||||
app.router.add_get('/api/lm/previews', self._handler.serve_preview)
|
||||
|
||||
|
||||
__all__ = ["PreviewRoutes"]
|
||||
117
standalone.py
117
standalone.py
@@ -1,4 +1,3 @@
|
||||
from pathlib import Path
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
@@ -280,121 +279,7 @@ class StandaloneLoraManager(LoraManager):
|
||||
# Store app in a global-like location for compatibility
|
||||
sys.modules['server'].PromptServer.instance = server_instance
|
||||
|
||||
added_targets = set() # Track already added target paths
|
||||
|
||||
# Add static routes for each lora root
|
||||
for idx, root in enumerate(config.loras_roots, start=1):
|
||||
if not os.path.exists(root):
|
||||
logger.warning(f"Lora root path does not exist: {root}")
|
||||
continue
|
||||
|
||||
preview_path = f'/loras_static/root{idx}/preview'
|
||||
|
||||
# Check if this root is a link path in the mappings
|
||||
real_root = root
|
||||
for target, link in config._path_mappings.items():
|
||||
if os.path.normpath(link) == os.path.normpath(root):
|
||||
# If so, route should point to the target (real path)
|
||||
real_root = target
|
||||
break
|
||||
|
||||
# Normalize and standardize path display for consistency
|
||||
display_root = real_root.replace('\\', '/')
|
||||
|
||||
# Add static route for original path - use the normalized path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {display_root}")
|
||||
|
||||
# Record route mapping with normalized path
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(os.path.normpath(real_root))
|
||||
|
||||
# Add static routes for each checkpoint root
|
||||
for idx, root in enumerate(config.base_models_roots, start=1):
|
||||
if not os.path.exists(root):
|
||||
logger.warning(f"Checkpoint root path does not exist: {root}")
|
||||
continue
|
||||
|
||||
preview_path = f'/checkpoints_static/root{idx}/preview'
|
||||
|
||||
# Check if this root is a link path in the mappings
|
||||
real_root = root
|
||||
for target, link in config._path_mappings.items():
|
||||
if os.path.normpath(link) == os.path.normpath(root):
|
||||
# If so, route should point to the target (real path)
|
||||
real_root = target
|
||||
break
|
||||
|
||||
# Normalize and standardize path display for consistency
|
||||
display_root = real_root.replace('\\', '/')
|
||||
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {display_root}")
|
||||
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(os.path.normpath(real_root))
|
||||
|
||||
# Add static routes for each embedding root
|
||||
for idx, root in enumerate(getattr(config, "embeddings_roots", []), start=1):
|
||||
if not os.path.exists(root):
|
||||
logger.warning(f"Embedding root path does not exist: {root}")
|
||||
continue
|
||||
|
||||
preview_path = f'/embeddings_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
for target, link in config._path_mappings.items():
|
||||
if os.path.normpath(link) == os.path.normpath(root):
|
||||
real_root = target
|
||||
break
|
||||
|
||||
display_root = real_root.replace('\\', '/')
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {display_root}")
|
||||
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(os.path.normpath(real_root))
|
||||
|
||||
# Add static routes for symlink target paths that aren't already covered
|
||||
link_idx = {
|
||||
'lora': 1,
|
||||
'checkpoint': 1,
|
||||
'embedding': 1
|
||||
}
|
||||
|
||||
for target_path, link_path in config._path_mappings.items():
|
||||
norm_target = os.path.normpath(target_path)
|
||||
if norm_target not in added_targets:
|
||||
# Determine if this is a checkpoint, lora, or embedding link based on path
|
||||
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.base_models_roots)
|
||||
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.base_models_roots)
|
||||
is_embedding = any(os.path.normpath(emb_root) in os.path.normpath(link_path) for emb_root in getattr(config, "embeddings_roots", []))
|
||||
is_embedding = is_embedding or any(os.path.normpath(emb_root) in norm_target for emb_root in getattr(config, "embeddings_roots", []))
|
||||
|
||||
if is_checkpoint:
|
||||
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
||||
link_idx["checkpoint"] += 1
|
||||
elif is_embedding:
|
||||
route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview'
|
||||
link_idx["embedding"] += 1
|
||||
else:
|
||||
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
|
||||
link_idx["lora"] += 1
|
||||
|
||||
# Display path with forward slashes for consistency
|
||||
display_target = target_path.replace('\\', '/')
|
||||
|
||||
try:
|
||||
app.router.add_static(route_path, Path(target_path).resolve(strict=False))
|
||||
logger.info(f"Added static route for link target {route_path} -> {display_target}")
|
||||
config.add_route_mapping(target_path, route_path)
|
||||
added_targets.add(norm_target)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add static route on initialization for {target_path}: {e}")
|
||||
continue
|
||||
|
||||
# Add static route for locales JSON files
|
||||
if os.path.exists(config.i18n_path):
|
||||
app.router.add_static('/locales', config.i18n_path)
|
||||
@@ -409,6 +294,7 @@ class StandaloneLoraManager(LoraManager):
|
||||
from py.routes.update_routes import UpdateRoutes
|
||||
from py.routes.misc_routes import MiscRoutes
|
||||
from py.routes.example_images_routes import ExampleImagesRoutes
|
||||
from py.routes.preview_routes import PreviewRoutes
|
||||
from py.routes.stats_routes import StatsRoutes
|
||||
from py.services.websocket_manager import ws_manager
|
||||
|
||||
@@ -426,6 +312,7 @@ class StandaloneLoraManager(LoraManager):
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
|
||||
PreviewRoutes.setup_routes(app)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
|
||||
112
tests/routes/test_preview_routes.py
Normal file
112
tests/routes/test_preview_routes.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import make_mocked_request
|
||||
|
||||
from py.config import Config
|
||||
from py.routes.handlers.preview_handlers import PreviewHandler
|
||||
|
||||
|
||||
async def test_preview_handler_serves_preview_from_active_library(tmp_path):
|
||||
library_root = tmp_path / "library"
|
||||
library_root.mkdir()
|
||||
preview_file = library_root / "model.webp"
|
||||
preview_file.write_bytes(b"preview")
|
||||
|
||||
config = Config()
|
||||
config.apply_library_settings(
|
||||
{
|
||||
"folder_paths": {
|
||||
"loras": [str(library_root)],
|
||||
"checkpoints": [],
|
||||
"unet": [],
|
||||
"embeddings": [],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
handler = PreviewHandler(config=config)
|
||||
encoded_path = urllib.parse.quote(str(preview_file), safe="")
|
||||
request = make_mocked_request("GET", f"/api/lm/previews?path={encoded_path}")
|
||||
|
||||
response = await handler.serve_preview(request)
|
||||
|
||||
assert isinstance(response, web.FileResponse)
|
||||
assert response.status == 200
|
||||
assert Path(response._path) == preview_file
|
||||
|
||||
|
||||
async def test_preview_handler_forbids_paths_outside_active_library(tmp_path):
|
||||
allowed_root = tmp_path / "allowed"
|
||||
allowed_root.mkdir()
|
||||
forbidden_root = tmp_path / "forbidden"
|
||||
forbidden_root.mkdir()
|
||||
forbidden_file = forbidden_root / "sneaky.webp"
|
||||
forbidden_file.write_bytes(b"x")
|
||||
|
||||
config = Config()
|
||||
config.apply_library_settings(
|
||||
{
|
||||
"folder_paths": {
|
||||
"loras": [str(allowed_root)],
|
||||
"checkpoints": [],
|
||||
"unet": [],
|
||||
"embeddings": [],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
handler = PreviewHandler(config=config)
|
||||
encoded_path = urllib.parse.quote(str(forbidden_file), safe="")
|
||||
request = make_mocked_request("GET", f"/api/lm/previews?path={encoded_path}")
|
||||
|
||||
with pytest.raises(web.HTTPForbidden):
|
||||
await handler.serve_preview(request)
|
||||
|
||||
|
||||
async def test_config_updates_preview_roots_after_switch(tmp_path):
|
||||
first_root = tmp_path / "first"
|
||||
first_root.mkdir()
|
||||
second_root = tmp_path / "second"
|
||||
second_root.mkdir()
|
||||
|
||||
first_preview = first_root / "model.webp"
|
||||
first_preview.write_bytes(b"a")
|
||||
second_preview = second_root / "model.webp"
|
||||
second_preview.write_bytes(b"b")
|
||||
|
||||
config = Config()
|
||||
config.apply_library_settings(
|
||||
{
|
||||
"folder_paths": {
|
||||
"loras": [str(first_root)],
|
||||
"checkpoints": [],
|
||||
"unet": [],
|
||||
"embeddings": [],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
assert config.is_preview_path_allowed(str(first_preview))
|
||||
assert not config.is_preview_path_allowed(str(second_preview))
|
||||
|
||||
config.apply_library_settings(
|
||||
{
|
||||
"folder_paths": {
|
||||
"loras": [str(second_root)],
|
||||
"checkpoints": [],
|
||||
"unet": [],
|
||||
"embeddings": [],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
assert config.is_preview_path_allowed(str(second_preview))
|
||||
assert not config.is_preview_path_allowed(str(first_preview))
|
||||
|
||||
preview_url = config.get_preview_static_url(str(second_preview))
|
||||
assert preview_url.startswith("/api/lm/previews?path=")
|
||||
decoded = urllib.parse.unquote(preview_url.split("path=", 1)[1])
|
||||
assert decoded.replace("\\", "/").endswith("model.webp")
|
||||
Reference in New Issue
Block a user