fix(download): harden HF download path validation, fix WebSocket leak, add URL detection tests (#965, #977)

Security hardening:
- Validate repo format with strict regex (reject .. traversal)
- Validate filename: reject path separators and ..
- Validate relative_path: reject absolute paths and .. segments
- Verify model_root is within configured scanner roots using
  realpath + os.sep guard to prevent prefix-match bypass
- Add realpath-based escape detection for final dest_path

Bug fixes:
- Fix WebSocket leak in _downloadHfSingle: move ws.close() into a
  finally block so it runs even if downloadHfModel() throws
- Same fix for batch HF download per-file WebSocket loop

Frontend hardening:
- Tighten HF repo regex: require huggingface.co for full URLs,
  reject bare .. patterns
- Add 12 unit tests for detectUrlType() covering HF resolve,
  HF repo, CivitAI, CivArchive, direct HTTP, edge cases
This commit is contained in:
Will Miao
2026-07-01 05:51:58 +08:00
parent 7cf785b72f
commit 8348a0cef8
3 changed files with 233 additions and 66 deletions

View File

@@ -10,6 +10,7 @@ from __future__ import annotations
import json
import logging
import os
import re
from typing import Any
import aiohttp
@@ -235,14 +236,55 @@ class HfHandler:
{"error": "Missing required fields: 'repo' and 'filename'"}, status=400
)
# Determine target directory
if os.path.isabs(model_root):
base_dir = model_root
# Validate repo format — must be user/repo_name
if repo.count("/") != 1 or not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", repo):
return web.json_response({"error": f"Invalid repo format: {repo}"}, status=400)
author, repo_name = repo.split("/", 1)
if ".." in (author, repo_name) or "." in (author, repo_name):
return web.json_response({"error": f"Invalid repo format: {repo}"}, status=400)
# Validate filename — must not contain path separators or ..
if "/" in filename or "\\" in filename or ".." in filename:
return web.json_response({"error": "Invalid filename"}, status=400)
# Validate relative_path — must not be absolute or escape base directory
if relative_path:
if os.path.isabs(relative_path):
return web.json_response({"error": "relative_path must not be absolute"}, status=400)
if ".." in relative_path.split("/") or "\\" in relative_path:
return web.json_response({"error": "Invalid relative_path"}, status=400)
# Validate model_root — must not contain path traversal
if not os.path.isabs(model_root):
# For relative model_root, check it doesn't escape
resolved_model_root = os.path.realpath(
os.path.join(os.getcwd(), "models", model_root)
)
else:
base_dir = os.path.join(os.getcwd(), "models", model_root)
resolved_model_root = os.path.realpath(model_root)
# Verify model_root is within a configured scanner root
allowed_roots = set()
for root_list in (
config.loras_roots or [],
config.extra_loras_roots or [],
config.checkpoints_roots or [],
config.extra_checkpoints_roots or [],
config.unet_roots or [],
config.extra_unet_roots or [],
config.embeddings_roots or [],
config.extra_embeddings_roots or [],
):
for r in root_list:
allowed_roots.add(os.path.realpath(r))
if not any(resolved_model_root == root or resolved_model_root.startswith(root + os.sep) for root in allowed_roots):
logger.warning("Invalid model_root rejected: %s", model_root)
return web.json_response({"error": f"Invalid model_root: {model_root}"}, status=400)
base_dir = resolved_model_root
if use_default_paths:
author, repo_name = repo.split("/", 1)
target_dir = os.path.join(base_dir, "huggingface", author, repo_name)
elif relative_path:
target_dir = os.path.join(base_dir, relative_path)
@@ -252,6 +294,13 @@ class HfHandler:
os.makedirs(target_dir, exist_ok=True)
dest_path = os.path.join(target_dir, filename)
# Resolve symlinks and check for path traversal escape
real_dest = os.path.realpath(dest_path)
real_base = os.path.realpath(target_dir)
if not real_dest.startswith(real_base + os.sep):
logger.warning("Path traversal blocked: %s -> %s", dest_path, real_dest)
return web.json_response({"error": "Path traversal detected"}, status=400)
# Check if already exists (simple skip)
if os.path.exists(dest_path) and os.path.getsize(dest_path) > 0:
logger.info("download_hf_model: file already exists, skipping — %s", dest_path)