diff --git a/py/config.py b/py/config.py index 9c9b0570..2b6911da 100644 --- a/py/config.py +++ b/py/config.py @@ -224,6 +224,20 @@ class Config: logger.error(f"Error checking link status for {path}: {e}") return False + def _entry_is_symlink(self, entry: os.DirEntry) -> bool: + """Check if a directory entry is a symlink, including Windows junctions.""" + if entry.is_symlink(): + return True + if platform.system() == 'Windows': + try: + import ctypes + FILE_ATTRIBUTE_REPARSE_POINT = 0x400 + attrs = ctypes.windll.kernel32.GetFileAttributesW(entry.path) + return attrs != -1 and (attrs & FILE_ATTRIBUTE_REPARSE_POINT) + except Exception: + pass + return False + def _normalize_path(self, path: str) -> str: return os.path.normpath(path).replace(os.sep, '/') @@ -241,8 +255,32 @@ class Config: def _build_symlink_fingerprint(self) -> Dict[str, object]: roots = [self._normalize_path(path) for path in self._symlink_roots() if path] unique_roots = sorted(set(roots)) - # Fingerprint now only contains the root paths to avoid sensitivity to folder content changes. - return {"roots": unique_roots} + + # Include first-level symlinks in fingerprint for change detection. + # This ensures new symlinks under roots trigger a cache invalidation. + # Use lists (not tuples) for JSON serialization compatibility. + direct_symlinks: List[List[str]] = [] + for root in unique_roots: + try: + if os.path.isdir(root): + with os.scandir(root) as it: + for entry in it: + if self._entry_is_symlink(entry): + try: + target = os.path.realpath(entry.path) + direct_symlinks.append([ + self._normalize_path(entry.path), + self._normalize_path(target) + ]) + except OSError: + pass + except (OSError, PermissionError): + pass + + return { + "roots": unique_roots, + "direct_symlinks": sorted(direct_symlinks) + } def _initialize_symlink_mappings(self) -> None: start = time.perf_counter() @@ -255,15 +293,19 @@ class Config: ) self._rebuild_preview_roots() - # Only rescan if target roots have changed. - # This is stable across file additions/deletions. current_fingerprint = self._build_symlink_fingerprint() cached_fingerprint = self._cached_fingerprint - - if cached_fingerprint and current_fingerprint == cached_fingerprint: + + # Check 1: First-level symlinks unchanged (catches new symlinks at root) + fingerprint_valid = cached_fingerprint and current_fingerprint == cached_fingerprint + + # Check 2: All cached mappings still valid (catches changes at any depth) + mappings_valid = self._validate_cached_mappings() if fingerprint_valid else False + + if fingerprint_valid and mappings_valid: return - logger.info("Symlink root paths changed; rescanning symbolic links") + logger.info("Symlink configuration changed; rescanning symbolic links") self.rebuild_symlink_cache() logger.info( @@ -354,6 +396,36 @@ class Config: return True + def _validate_cached_mappings(self) -> bool: + """Verify all cached symlink mappings are still valid. + + Returns True if all mappings are valid, False if rescan is needed. + This catches removed or retargeted symlinks at ANY depth. + """ + for target, link in self._path_mappings.items(): + # Convert normalized paths back to OS paths + link_path = link.replace('/', os.sep) + + # Check if symlink still exists + if not self._is_link(link_path): + logger.debug("Cached symlink no longer exists: %s", link_path) + return False + + # Check if target is still the same + try: + actual_target = self._normalize_path(os.path.realpath(link_path)) + if actual_target != target: + logger.debug( + "Symlink target changed: %s -> %s (cached: %s)", + link_path, actual_target, target + ) + return False + except OSError: + logger.debug("Cannot resolve symlink: %s", link_path) + return False + + return True + def _save_symlink_cache(self) -> None: cache_path = self._get_symlink_cache_path() payload = { @@ -406,10 +478,9 @@ class Config: with os.scandir(current_display) as it: for entry in it: try: - # 1. High speed detection using dirent data (is_symlink) - is_link = entry.is_symlink() - - # On Windows, is_symlink handles reparse points + # 1. Detect symlinks including Windows junctions + is_link = self._entry_is_symlink(entry) + if is_link: # Only resolve realpath when we actually find a link target_path = os.path.realpath(entry.path) diff --git a/py/routes/handlers/preview_handlers.py b/py/routes/handlers/preview_handlers.py index a8c0eed8..e3bee61a 100644 --- a/py/routes/handlers/preview_handlers.py +++ b/py/routes/handlers/preview_handlers.py @@ -41,10 +41,8 @@ class PreviewHandler: raise web.HTTPBadRequest(text="Unable to resolve preview path") from exc resolved_str = str(resolved) - # TODO: Temporarily disabled path validation due to issues #772 and #774 - # Re-enable after fixing preview root path handling - # if not self._config.is_preview_path_allowed(resolved_str): - # raise web.HTTPForbidden(text="Preview path is not within an allowed directory") + if not self._config.is_preview_path_allowed(resolved_str): + raise web.HTTPForbidden(text="Preview path is not within an allowed directory") if not resolved.is_file(): logger.debug("Preview file not found at %s", resolved_str) diff --git a/tests/config/test_symlink_cache.py b/tests/config/test_symlink_cache.py index 91fa738b..9f140349 100644 --- a/tests/config/test_symlink_cache.py +++ b/tests/config/test_symlink_cache.py @@ -118,7 +118,8 @@ def test_symlink_cache_survives_noise_mtime(monkeypatch: pytest.MonkeyPatch, tmp assert second_cfg.map_path_to_link(str(target_dir)) == _normalize(str(dir_link)) -def test_manual_rescan_refreshes_cache(monkeypatch: pytest.MonkeyPatch, tmp_path): +def test_retargeted_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path): + """Changing a symlink's target should trigger automatic cache invalidation.""" loras_dir, _ = _setup_paths(monkeypatch, tmp_path) target_dir = loras_dir / "target" @@ -128,22 +129,16 @@ def test_manual_rescan_refreshes_cache(monkeypatch: pytest.MonkeyPatch, tmp_path # Build initial cache pointing at the first target first_cfg = config_module.Config() - old_real = _normalize(os.path.realpath(target_dir)) assert first_cfg.map_path_to_link(str(target_dir)) == _normalize(str(dir_link)) - # Retarget the symlink to a new directory without touching the cache file + # Retarget the symlink to a new directory new_target = loras_dir / "target_v2" new_target.mkdir() dir_link.unlink() dir_link.symlink_to(new_target, target_is_directory=True) + # Second config should automatically detect the change and rescan second_cfg = config_module.Config() - - # Cache still point at the old real path immediately after load - assert second_cfg.map_path_to_link(str(new_target)) == _normalize(str(new_target)) - - # Manual rescan should refresh the mapping to the new target - second_cfg.rebuild_symlink_cache() new_real = _normalize(os.path.realpath(new_target)) assert second_cfg._path_mappings.get(new_real) == _normalize(str(dir_link)) assert second_cfg.map_path_to_link(str(new_target)) == _normalize(str(dir_link)) @@ -190,6 +185,103 @@ def test_symlink_roots_are_preserved(monkeypatch: pytest.MonkeyPatch, tmp_path): assert payload["path_mappings"][normalized_real] == normalized_link +def test_symlink_subfolder_to_external_location(monkeypatch: pytest.MonkeyPatch, tmp_path): + """Symlink under root pointing outside root should be detected and allowed.""" + loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path) + + # Create external directory (outside loras_dir) + external_dir = tmp_path / "external_models" + external_dir.mkdir() + preview_file = external_dir / "model.preview.png" + preview_file.write_bytes(b"preview") + + # Create symlink under loras_dir pointing to external location + symlink = loras_dir / "characters" + symlink.symlink_to(external_dir, target_is_directory=True) + + cfg = config_module.Config() + + # Verify symlink was detected + normalized_external = _normalize(str(external_dir)) + normalized_link = _normalize(str(symlink)) + assert cfg._path_mappings[normalized_external] == normalized_link + + # Verify preview path is allowed + assert cfg.is_preview_path_allowed(str(preview_file)) + + +def test_new_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path): + """Adding a new symlink should trigger cache invalidation.""" + loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path) + + # Initial scan with no symlinks + first_cfg = config_module.Config() + assert len(first_cfg._path_mappings) == 0 + + # Create a symlink after initial cache + external_dir = tmp_path / "external" + external_dir.mkdir() + symlink = loras_dir / "new_link" + symlink.symlink_to(external_dir, target_is_directory=True) + + # Second config should detect the change and rescan + second_cfg = config_module.Config() + normalized_external = _normalize(str(external_dir)) + assert normalized_external in second_cfg._path_mappings + + +def test_removed_deep_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path): + """Removing a deep symlink should trigger cache invalidation.""" + loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path) + + # Create nested structure with deep symlink + subdir = loras_dir / "anime" + subdir.mkdir() + external_dir = tmp_path / "external" + external_dir.mkdir() + deep_symlink = subdir / "styles" + deep_symlink.symlink_to(external_dir, target_is_directory=True) + + # Initial scan finds the deep symlink + first_cfg = config_module.Config() + normalized_external = _normalize(str(external_dir)) + assert normalized_external in first_cfg._path_mappings + + # Remove the deep symlink + deep_symlink.unlink() + + # Second config should detect invalid cached mapping and rescan + second_cfg = config_module.Config() + assert normalized_external not in second_cfg._path_mappings + + +def test_retargeted_deep_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path): + """Changing a deep symlink's target should trigger cache invalidation.""" + loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path) + + # Create nested structure + subdir = loras_dir / "anime" + subdir.mkdir() + target_v1 = tmp_path / "external_v1" + target_v1.mkdir() + target_v2 = tmp_path / "external_v2" + target_v2.mkdir() + + deep_symlink = subdir / "styles" + deep_symlink.symlink_to(target_v1, target_is_directory=True) + + # Initial scan + first_cfg = config_module.Config() + assert _normalize(str(target_v1)) in first_cfg._path_mappings + + # Retarget the symlink + deep_symlink.unlink() + deep_symlink.symlink_to(target_v2, target_is_directory=True) + + # Second config should detect changed target and rescan + second_cfg = config_module.Config() + assert _normalize(str(target_v2)) in second_cfg._path_mappings + assert _normalize(str(target_v1)) not in second_cfg._path_mappings def test_legacy_symlink_cache_automatic_cleanup(monkeypatch: pytest.MonkeyPatch, tmp_path): """Test that legacy symlink cache is automatically cleaned up after migration.""" settings_dir = tmp_path / "settings" diff --git a/tests/routes/test_preview_routes.py b/tests/routes/test_preview_routes.py index 65e091c0..909f7004 100644 --- a/tests/routes/test_preview_routes.py +++ b/tests/routes/test_preview_routes.py @@ -39,33 +39,32 @@ async def test_preview_handler_serves_preview_from_active_library(tmp_path): assert response.status == 200 assert Path(response._path) == preview_file -# TODO: disable temporarily. Enable this once the symlink scan bug fixed -# async def test_preview_handler_forbids_paths_outside_active_library(tmp_path): -# allowed_root = tmp_path / "allowed" -# allowed_root.mkdir() -# forbidden_root = tmp_path / "forbidden" -# forbidden_root.mkdir() -# forbidden_file = forbidden_root / "sneaky.webp" -# forbidden_file.write_bytes(b"x") +async def test_preview_handler_forbids_paths_outside_active_library(tmp_path): + allowed_root = tmp_path / "allowed" + allowed_root.mkdir() + forbidden_root = tmp_path / "forbidden" + forbidden_root.mkdir() + forbidden_file = forbidden_root / "sneaky.webp" + forbidden_file.write_bytes(b"x") -# config = Config() -# config.apply_library_settings( -# { -# "folder_paths": { -# "loras": [str(allowed_root)], -# "checkpoints": [], -# "unet": [], -# "embeddings": [], -# } -# } -# ) + config = Config() + config.apply_library_settings( + { + "folder_paths": { + "loras": [str(allowed_root)], + "checkpoints": [], + "unet": [], + "embeddings": [], + } + } + ) -# handler = PreviewHandler(config=config) -# encoded_path = urllib.parse.quote(str(forbidden_file), safe="") -# request = make_mocked_request("GET", f"/api/lm/previews?path={encoded_path}") + handler = PreviewHandler(config=config) + encoded_path = urllib.parse.quote(str(forbidden_file), safe="") + request = make_mocked_request("GET", f"/api/lm/previews?path={encoded_path}") -# with pytest.raises(web.HTTPForbidden): -# await handler.serve_preview(request) + with pytest.raises(web.HTTPForbidden): + await handler.serve_preview(request) async def test_config_updates_preview_roots_after_switch(tmp_path):