diff --git a/py/routes/handlers/hf_handlers.py b/py/routes/handlers/hf_handlers.py index aaebbb22..915037c3 100644 --- a/py/routes/handlers/hf_handlers.py +++ b/py/routes/handlers/hf_handlers.py @@ -10,6 +10,7 @@ from __future__ import annotations import json import logging import os +import re from typing import Any import aiohttp @@ -235,14 +236,55 @@ class HfHandler: {"error": "Missing required fields: 'repo' and 'filename'"}, status=400 ) - # Determine target directory - if os.path.isabs(model_root): - base_dir = model_root + # Validate repo format — must be user/repo_name + if repo.count("/") != 1 or not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", repo): + return web.json_response({"error": f"Invalid repo format: {repo}"}, status=400) + author, repo_name = repo.split("/", 1) + if ".." in (author, repo_name) or "." in (author, repo_name): + return web.json_response({"error": f"Invalid repo format: {repo}"}, status=400) + + # Validate filename — must not contain path separators or .. + if "/" in filename or "\\" in filename or ".." in filename: + return web.json_response({"error": "Invalid filename"}, status=400) + + # Validate relative_path — must not be absolute or escape base directory + if relative_path: + if os.path.isabs(relative_path): + return web.json_response({"error": "relative_path must not be absolute"}, status=400) + if ".." in relative_path.split("/") or "\\" in relative_path: + return web.json_response({"error": "Invalid relative_path"}, status=400) + + # Validate model_root — must not contain path traversal + if not os.path.isabs(model_root): + # For relative model_root, check it doesn't escape + resolved_model_root = os.path.realpath( + os.path.join(os.getcwd(), "models", model_root) + ) else: - base_dir = os.path.join(os.getcwd(), "models", model_root) + resolved_model_root = os.path.realpath(model_root) + + # Verify model_root is within a configured scanner root + allowed_roots = set() + for root_list in ( + config.loras_roots or [], + config.extra_loras_roots or [], + config.checkpoints_roots or [], + config.extra_checkpoints_roots or [], + config.unet_roots or [], + config.extra_unet_roots or [], + config.embeddings_roots or [], + config.extra_embeddings_roots or [], + ): + for r in root_list: + allowed_roots.add(os.path.realpath(r)) + + if not any(resolved_model_root == root or resolved_model_root.startswith(root + os.sep) for root in allowed_roots): + logger.warning("Invalid model_root rejected: %s", model_root) + return web.json_response({"error": f"Invalid model_root: {model_root}"}, status=400) + + base_dir = resolved_model_root if use_default_paths: - author, repo_name = repo.split("/", 1) target_dir = os.path.join(base_dir, "huggingface", author, repo_name) elif relative_path: target_dir = os.path.join(base_dir, relative_path) @@ -252,6 +294,13 @@ class HfHandler: os.makedirs(target_dir, exist_ok=True) dest_path = os.path.join(target_dir, filename) + # Resolve symlinks and check for path traversal escape + real_dest = os.path.realpath(dest_path) + real_base = os.path.realpath(target_dir) + if not real_dest.startswith(real_base + os.sep): + logger.warning("Path traversal blocked: %s -> %s", dest_path, real_dest) + return web.json_response({"error": "Path traversal detected"}, status=400) + # Check if already exists (simple skip) if os.path.exists(dest_path) and os.path.getsize(dest_path) > 0: logger.info("download_hf_model: file already exists, skipping — %s", dest_path) diff --git a/static/js/managers/DownloadManager.js b/static/js/managers/DownloadManager.js index 7f2e942e..942b9ed8 100644 --- a/static/js/managers/DownloadManager.js +++ b/static/js/managers/DownloadManager.js @@ -481,8 +481,18 @@ export class DownloadManager { } // Hugging Face repo URL (huggingface.co/user/repo or bare user/repo path) - const hfRepoMatch = trimmed.match(/(?:https?:\/\/huggingface\.co\/)?([a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+)(?:\/?$|$)/); + // Require huggingface.co prefix for full URLs; bare user/repo only without :// + const hfRepoMatch = trimmed.match( + trimmed.includes('://') + ? /^https?:\/\/huggingface\.co\/([a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+)(?:\/?$|$)/ + : /^([a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+)$/ + ); if (hfRepoMatch) { + // Reject path-traversal patterns like "../.." or "user/.." + const parts = hfRepoMatch[1].split('/'); + if (parts.some(p => p === '.' || p === '..')) { + return null; + } return { type: 'hf-repo', repo: hfRepoMatch[1], @@ -987,42 +997,44 @@ export class DownloadManager { const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`); - await new Promise((resolve, reject) => { - ws.onopen = resolve; - ws.onerror = reject; - }); + try { + await new Promise((resolve, reject) => { + ws.onopen = resolve; + ws.onerror = reject; + }); - // Capture completed count at WS creation time so progress - // updates arriving after completedDownloads increments still - // show the correct "N / total" position. - const snapshotCompleted = completedDownloads; - ws.onmessage = (event) => { - const data = JSON.parse(event.data); - if (data.status === 'progress') { - const metrics = { - bytesDownloaded: data.bytes_downloaded, - totalBytes: data.total_bytes, - bytesPerSecond: data.bytes_per_second, - }; - updateProgress(data.progress, snapshotCompleted, filename, metrics); + // Capture completed count at WS creation time so progress + // updates arriving after completedDownloads increments still + // show the correct "N / total" position. + const snapshotCompleted = completedDownloads; + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + if (data.status === 'progress') { + const metrics = { + bytesDownloaded: data.bytes_downloaded, + totalBytes: data.total_bytes, + bytesPerSecond: data.bytes_per_second, + }; + updateProgress(data.progress, snapshotCompleted, filename, metrics); + } + }; + + const response = await this.apiClient.downloadHfModel({ + repo: this.hfRepoId, + filename, + revision: 'main', + modelRoot, + relativePath: targetFolder, + useDefaultPaths, + download_id: downloadId, + }); + + if (response?.success) { + completedDownloads++; + updateProgress(100, completedDownloads, filename); } - }; - - const response = await this.apiClient.downloadHfModel({ - repo: this.hfRepoId, - filename, - revision: 'main', - modelRoot, - relativePath: targetFolder, - useDefaultPaths, - download_id: downloadId, - }); - - ws.close(); - - if (response?.success) { - completedDownloads++; - updateProgress(100, completedDownloads, filename); + } finally { + ws.close(); } } @@ -1401,33 +1413,36 @@ export class DownloadManager { // Per-file WebSocket for real-time progress const downloadId = Date.now().toString() + '_hf_' + i; const wsHf = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`); - await new Promise((resolve, reject) => { - wsHf.onopen = resolve; - wsHf.onerror = reject; - }); - const snapshotCompleted = completedDownloads; - wsHf.onmessage = (event) => { - const data = JSON.parse(event.data); - if (data.status === 'progress') { - const metrics = { - bytesDownloaded: data.bytes_downloaded, - totalBytes: data.total_bytes, - bytesPerSecond: data.bytes_per_second, - }; - updateProgress(data.progress, snapshotCompleted, name, metrics); - } - }; + try { + await new Promise((resolve, reject) => { + wsHf.onopen = resolve; + wsHf.onerror = reject; + }); + const snapshotCompleted = completedDownloads; + wsHf.onmessage = (event) => { + const data = JSON.parse(event.data); + if (data.status === 'progress') { + const metrics = { + bytesDownloaded: data.bytes_downloaded, + totalBytes: data.total_bytes, + bytesPerSecond: data.bytes_per_second, + }; + updateProgress(data.progress, snapshotCompleted, name, metrics); + } + }; - response = await this.apiClient.downloadHfModel({ - repo: item.repo, - filename: item.filename, - revision: item.revision || 'main', - modelRoot, - relativePath: targetFolder, - useDefaultPaths, - download_id: downloadId, - }); - wsHf.close(); + response = await this.apiClient.downloadHfModel({ + repo: item.repo, + filename: item.filename, + revision: item.revision || 'main', + modelRoot, + relativePath: targetFolder, + useDefaultPaths, + download_id: downloadId, + }); + } finally { + wsHf.close(); + } } else { response = await this.apiClient.downloadModel( item.modelId, diff --git a/tests/frontend/utils/hfUrlDetection.test.js b/tests/frontend/utils/hfUrlDetection.test.js new file mode 100644 index 00000000..df39a286 --- /dev/null +++ b/tests/frontend/utils/hfUrlDetection.test.js @@ -0,0 +1,103 @@ +import { describe, it, expect } from 'vitest'; +import { DownloadManager } from '../../../static/js/managers/DownloadManager.js'; + +describe('DownloadManager.detectUrlType — HF URL detection', () => { + + it('detects HF resolve URL with file', () => { + const result = DownloadManager.detectUrlType( + 'https://huggingface.co/dx8152/Flux2-Klein-9B-Consistency/resolve/main/Flux2-Klein-9B-consistency-V2.safetensors' + ); + expect(result).toEqual({ + type: 'hf-resolve', + repo: 'dx8152/Flux2-Klein-9B-Consistency', + revision: 'main', + filename: 'Flux2-Klein-9B-consistency-V2.safetensors', + }); + }); + + it('detects HF resolve URL with subdirectory file', () => { + const result = DownloadManager.detectUrlType( + 'https://huggingface.co/user/repo/resolve/main/subdir/model.safetensors' + ); + expect(result).toEqual({ + type: 'hf-resolve', + repo: 'user/repo', + revision: 'main', + filename: 'subdir/model.safetensors', + }); + }); + + it('detects HF repo URL (full URL)', () => { + const result = DownloadManager.detectUrlType( + 'https://huggingface.co/dx8152/Flux2-Klein-9B-Consistency' + ); + expect(result).toEqual({ + type: 'hf-repo', + repo: 'dx8152/Flux2-Klein-9B-Consistency', + }); + }); + + it('detects HF repo URL (bare user/repo)', () => { + const result = DownloadManager.detectUrlType('dx8152/Flux2-Klein-9B-Consistency'); + expect(result).toEqual({ + type: 'hf-repo', + repo: 'dx8152/Flux2-Klein-9B-Consistency', + }); + }); + + it('detects HF repo URL with trailing slash', () => { + const result = DownloadManager.detectUrlType( + 'https://huggingface.co/user/repo/' + ); + expect(result).toEqual({ + type: 'hf-repo', + repo: 'user/repo', + }); + }); + + it('detects CivitAI URL', () => { + const result = DownloadManager.detectUrlType( + 'https://civitai.com/models/123/some-model' + ); + expect(result).toEqual({ type: 'civitai' }); + }); + + it('detects CivArchive URL', () => { + const result = DownloadManager.detectUrlType( + 'https://civarchive.com/models/456' + ); + expect(result).toEqual({ type: 'civitai' }); + }); + + it('detects direct HTTP URL', () => { + const result = DownloadManager.detectUrlType( + 'https://example.com/file.zip' + ); + expect(result).toEqual({ type: 'direct-http' }); + }); + + it('returns null for invalid input', () => { + expect(DownloadManager.detectUrlType('')).toBeNull(); + expect(DownloadManager.detectUrlType(' ')).toBeNull(); + }); + + it('returns null for unrecognized path', () => { + expect(DownloadManager.detectUrlType('justrandomtext')).toBeNull(); + }); + + it('prefers HF resolve over repo when both match', () => { + const result = DownloadManager.detectUrlType( + 'https://huggingface.co/user/repo/resolve/main/file.safetensors' + ); + expect(result?.type).toBe('hf-resolve'); + }); + + it('prefers CivitAI over HF when both match', () => { + // CivitAI check comes first in detectUrlType + // This URL should be detected as CivitAI, not HF + const result = DownloadManager.detectUrlType( + 'https://civitai.com/models/123?huggingface.co/test/repo' + ); + expect(result?.type).toBe('civitai'); + }); +});