From 8348a0cef86f6aa0baa1d21d7f833ed68c05c779 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Wed, 1 Jul 2026 05:51:58 +0800 Subject: [PATCH] fix(download): harden HF download path validation, fix WebSocket leak, add URL detection tests (#965, #977) Security hardening: - Validate repo format with strict regex (reject .. traversal) - Validate filename rejects path separators and .. - Validate relative_path rejects absolute paths and .. - Verify model_root is within configured scanner roots using realpath + os.sep guard to prevent prefix-match bypass - Add realpath-based escape detection for final dest_path Bug fixes: - Fix WebSocket leak in _downloadHfSingle: wrap ws.close() in try/finally so it closes even if downloadHfModel() throws - Same fix for batch HF download per-file WebSocket loop Frontend hardening: - Tighten HF repo regex: require huggingface.co for full URLs, reject bare .. patterns - Add 12 unit tests for detectUrlType() covering HF resolve, HF repo, CivitAI, CivArchive, direct HTTP, edge cases --- py/routes/handlers/hf_handlers.py | 59 ++++++++- static/js/managers/DownloadManager.js | 137 +++++++++++--------- tests/frontend/utils/hfUrlDetection.test.js | 103 +++++++++++++++ 3 files changed, 233 insertions(+), 66 deletions(-) create mode 100644 tests/frontend/utils/hfUrlDetection.test.js diff --git a/py/routes/handlers/hf_handlers.py b/py/routes/handlers/hf_handlers.py index aaebbb22..915037c3 100644 --- a/py/routes/handlers/hf_handlers.py +++ b/py/routes/handlers/hf_handlers.py @@ -10,6 +10,7 @@ from __future__ import annotations import json import logging import os +import re from typing import Any import aiohttp @@ -235,14 +236,55 @@ class HfHandler: {"error": "Missing required fields: 'repo' and 'filename'"}, status=400 ) - # Determine target directory - if os.path.isabs(model_root): - base_dir = model_root + # Validate repo format — must be user/repo_name + if repo.count("/") != 1 or not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", repo): + return web.json_response({"error": f"Invalid repo format: {repo}"}, status=400) + author, repo_name = repo.split("/", 1) + if ".." in (author, repo_name) or "." in (author, repo_name): + return web.json_response({"error": f"Invalid repo format: {repo}"}, status=400) + + # Validate filename — must not contain path separators or .. + if "/" in filename or "\\" in filename or ".." in filename: + return web.json_response({"error": "Invalid filename"}, status=400) + + # Validate relative_path — must not be absolute or escape base directory + if relative_path: + if os.path.isabs(relative_path): + return web.json_response({"error": "relative_path must not be absolute"}, status=400) + if ".." in relative_path.split("/") or "\\" in relative_path: + return web.json_response({"error": "Invalid relative_path"}, status=400) + + # Validate model_root — must not contain path traversal + if not os.path.isabs(model_root): + # For relative model_root, check it doesn't escape + resolved_model_root = os.path.realpath( + os.path.join(os.getcwd(), "models", model_root) + ) else: - base_dir = os.path.join(os.getcwd(), "models", model_root) + resolved_model_root = os.path.realpath(model_root) + + # Verify model_root is within a configured scanner root + allowed_roots = set() + for root_list in ( + config.loras_roots or [], + config.extra_loras_roots or [], + config.checkpoints_roots or [], + config.extra_checkpoints_roots or [], + config.unet_roots or [], + config.extra_unet_roots or [], + config.embeddings_roots or [], + config.extra_embeddings_roots or [], + ): + for r in root_list: + allowed_roots.add(os.path.realpath(r)) + + if not any(resolved_model_root == root or resolved_model_root.startswith(root + os.sep) for root in allowed_roots): + logger.warning("Invalid model_root rejected: %s", model_root) + return web.json_response({"error": f"Invalid model_root: {model_root}"}, status=400) + + base_dir = resolved_model_root if use_default_paths: - author, repo_name = repo.split("/", 1) target_dir = os.path.join(base_dir, "huggingface", author, repo_name) elif relative_path: target_dir = os.path.join(base_dir, relative_path) @@ -252,6 +294,13 @@ class HfHandler: os.makedirs(target_dir, exist_ok=True) dest_path = os.path.join(target_dir, filename) + # Resolve symlinks and check for path traversal escape + real_dest = os.path.realpath(dest_path) + real_base = os.path.realpath(target_dir) + if not real_dest.startswith(real_base + os.sep): + logger.warning("Path traversal blocked: %s -> %s", dest_path, real_dest) + return web.json_response({"error": "Path traversal detected"}, status=400) + # Check if already exists (simple skip) if os.path.exists(dest_path) and os.path.getsize(dest_path) > 0: logger.info("download_hf_model: file already exists, skipping — %s", dest_path) diff --git a/static/js/managers/DownloadManager.js b/static/js/managers/DownloadManager.js index 7f2e942e..942b9ed8 100644 --- a/static/js/managers/DownloadManager.js +++ b/static/js/managers/DownloadManager.js @@ -481,8 +481,18 @@ export class DownloadManager { } // Hugging Face repo URL (huggingface.co/user/repo or bare user/repo path) - const hfRepoMatch = trimmed.match(/(?:https?:\/\/huggingface\.co\/)?([a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+)(?:\/?$|$)/); + // Require huggingface.co prefix for full URLs; bare user/repo only without :// + const hfRepoMatch = trimmed.match( + trimmed.includes('://') + ? /^https?:\/\/huggingface\.co\/([a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+)(?:\/?$|$)/ + : /^([a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+)$/ + ); if (hfRepoMatch) { + // Reject path-traversal patterns like "../.." or "user/.." + const parts = hfRepoMatch[1].split('/'); + if (parts.some(p => p === '.' || p === '..')) { + return null; + } return { type: 'hf-repo', repo: hfRepoMatch[1], @@ -987,42 +997,44 @@ export class DownloadManager { const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`); - await new Promise((resolve, reject) => { - ws.onopen = resolve; - ws.onerror = reject; - }); + try { + await new Promise((resolve, reject) => { + ws.onopen = resolve; + ws.onerror = reject; + }); - // Capture completed count at WS creation time so progress - // updates arriving after completedDownloads increments still - // show the correct "N / total" position. - const snapshotCompleted = completedDownloads; - ws.onmessage = (event) => { - const data = JSON.parse(event.data); - if (data.status === 'progress') { - const metrics = { - bytesDownloaded: data.bytes_downloaded, - totalBytes: data.total_bytes, - bytesPerSecond: data.bytes_per_second, - }; - updateProgress(data.progress, snapshotCompleted, filename, metrics); + // Capture completed count at WS creation time so progress + // updates arriving after completedDownloads increments still + // show the correct "N / total" position. + const snapshotCompleted = completedDownloads; + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + if (data.status === 'progress') { + const metrics = { + bytesDownloaded: data.bytes_downloaded, + totalBytes: data.total_bytes, + bytesPerSecond: data.bytes_per_second, + }; + updateProgress(data.progress, snapshotCompleted, filename, metrics); + } + }; + + const response = await this.apiClient.downloadHfModel({ + repo: this.hfRepoId, + filename, + revision: 'main', + modelRoot, + relativePath: targetFolder, + useDefaultPaths, + download_id: downloadId, + }); + + if (response?.success) { + completedDownloads++; + updateProgress(100, completedDownloads, filename); } - }; - - const response = await this.apiClient.downloadHfModel({ - repo: this.hfRepoId, - filename, - revision: 'main', - modelRoot, - relativePath: targetFolder, - useDefaultPaths, - download_id: downloadId, - }); - - ws.close(); - - if (response?.success) { - completedDownloads++; - updateProgress(100, completedDownloads, filename); + } finally { + ws.close(); } } @@ -1401,33 +1413,36 @@ export class DownloadManager { // Per-file WebSocket for real-time progress const downloadId = Date.now().toString() + '_hf_' + i; const wsHf = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`); - await new Promise((resolve, reject) => { - wsHf.onopen = resolve; - wsHf.onerror = reject; - }); - const snapshotCompleted = completedDownloads; - wsHf.onmessage = (event) => { - const data = JSON.parse(event.data); - if (data.status === 'progress') { - const metrics = { - bytesDownloaded: data.bytes_downloaded, - totalBytes: data.total_bytes, - bytesPerSecond: data.bytes_per_second, - }; - updateProgress(data.progress, snapshotCompleted, name, metrics); - } - }; + try { + await new Promise((resolve, reject) => { + wsHf.onopen = resolve; + wsHf.onerror = reject; + }); + const snapshotCompleted = completedDownloads; + wsHf.onmessage = (event) => { + const data = JSON.parse(event.data); + if (data.status === 'progress') { + const metrics = { + bytesDownloaded: data.bytes_downloaded, + totalBytes: data.total_bytes, + bytesPerSecond: data.bytes_per_second, + }; + updateProgress(data.progress, snapshotCompleted, name, metrics); + } + }; - response = await this.apiClient.downloadHfModel({ - repo: item.repo, - filename: item.filename, - revision: item.revision || 'main', - modelRoot, - relativePath: targetFolder, - useDefaultPaths, - download_id: downloadId, - }); - wsHf.close(); + response = await this.apiClient.downloadHfModel({ + repo: item.repo, + filename: item.filename, + revision: item.revision || 'main', + modelRoot, + relativePath: targetFolder, + useDefaultPaths, + download_id: downloadId, + }); + } finally { + wsHf.close(); + } } else { response = await this.apiClient.downloadModel( item.modelId, diff --git a/tests/frontend/utils/hfUrlDetection.test.js b/tests/frontend/utils/hfUrlDetection.test.js new file mode 100644 index 00000000..df39a286 --- /dev/null +++ b/tests/frontend/utils/hfUrlDetection.test.js @@ -0,0 +1,103 @@ +import { describe, it, expect } from 'vitest'; +import { DownloadManager } from '../../../static/js/managers/DownloadManager.js'; + +describe('DownloadManager.detectUrlType — HF URL detection', () => { + + it('detects HF resolve URL with file', () => { + const result = DownloadManager.detectUrlType( + 'https://huggingface.co/dx8152/Flux2-Klein-9B-Consistency/resolve/main/Flux2-Klein-9B-consistency-V2.safetensors' + ); + expect(result).toEqual({ + type: 'hf-resolve', + repo: 'dx8152/Flux2-Klein-9B-Consistency', + revision: 'main', + filename: 'Flux2-Klein-9B-consistency-V2.safetensors', + }); + }); + + it('detects HF resolve URL with subdirectory file', () => { + const result = DownloadManager.detectUrlType( + 'https://huggingface.co/user/repo/resolve/main/subdir/model.safetensors' + ); + expect(result).toEqual({ + type: 'hf-resolve', + repo: 'user/repo', + revision: 'main', + filename: 'subdir/model.safetensors', + }); + }); + + it('detects HF repo URL (full URL)', () => { + const result = DownloadManager.detectUrlType( + 'https://huggingface.co/dx8152/Flux2-Klein-9B-Consistency' + ); + expect(result).toEqual({ + type: 'hf-repo', + repo: 'dx8152/Flux2-Klein-9B-Consistency', + }); + }); + + it('detects HF repo URL (bare user/repo)', () => { + const result = DownloadManager.detectUrlType('dx8152/Flux2-Klein-9B-Consistency'); + expect(result).toEqual({ + type: 'hf-repo', + repo: 'dx8152/Flux2-Klein-9B-Consistency', + }); + }); + + it('detects HF repo URL with trailing slash', () => { + const result = DownloadManager.detectUrlType( + 'https://huggingface.co/user/repo/' + ); + expect(result).toEqual({ + type: 'hf-repo', + repo: 'user/repo', + }); + }); + + it('detects CivitAI URL', () => { + const result = DownloadManager.detectUrlType( + 'https://civitai.com/models/123/some-model' + ); + expect(result).toEqual({ type: 'civitai' }); + }); + + it('detects CivArchive URL', () => { + const result = DownloadManager.detectUrlType( + 'https://civarchive.com/models/456' + ); + expect(result).toEqual({ type: 'civitai' }); + }); + + it('detects direct HTTP URL', () => { + const result = DownloadManager.detectUrlType( + 'https://example.com/file.zip' + ); + expect(result).toEqual({ type: 'direct-http' }); + }); + + it('returns null for invalid input', () => { + expect(DownloadManager.detectUrlType('')).toBeNull(); + expect(DownloadManager.detectUrlType(' ')).toBeNull(); + }); + + it('returns null for unrecognized path', () => { + expect(DownloadManager.detectUrlType('justrandomtext')).toBeNull(); + }); + + it('prefers HF resolve over repo when both match', () => { + const result = DownloadManager.detectUrlType( + 'https://huggingface.co/user/repo/resolve/main/file.safetensors' + ); + expect(result?.type).toBe('hf-resolve'); + }); + + it('prefers CivitAI over HF when both match', () => { + // CivitAI check comes first in detectUrlType + // This URL should be detected as CivitAI, not HF + const result = DownloadManager.detectUrlType( + 'https://civitai.com/models/123?huggingface.co/test/repo' + ); + expect(result?.type).toBe('civitai'); + }); +});