feat(download): add Hugging Face model download to standalone UI wizard (#965, #977)

Integrate HF model downloading into the existing CivitAI-style wizard flow:
- URL type detection (civitai / hf-resolve / hf-repo / direct-http)
- Repo file explorer with checkbox-based file selection
- Batch/queue download with per-file WebSocket progress
- Aria2 backend support (respects download_backend setting)
- Scanner cache integration via create_default_metadata + add_model_to_cache
- i18n updates for all 10 locales
This commit is contained in:
Will Miao
2026-06-30 19:36:12 +08:00
parent 16f5222efd
commit 09ca91fc0e
20 changed files with 20207 additions and 19207 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1134,7 +1134,9 @@
"titleWithType": "Download {type} from URL", "titleWithType": "Download {type} from URL",
"civitaiUrl": "Civitai URL(s):", "civitaiUrl": "Civitai URL(s):",
"placeholder": "https://civitai.com/models/...", "placeholder": "https://civitai.com/models/...",
"urlHint": "Enter one CivitAI or CivArchive URL per line. Supports multiple URLs for batch download.", "urlHint": "Enter one CivitAI, CivArchive, or Hugging Face URL per line. Supports multiple URLs for batch download.",
"selectHfFiles": "Select file(s) to download from this repository:",
"fetchingRepoFiles": "Fetching repository files...",
"locationPreview": "Download Location Preview", "locationPreview": "Download Location Preview",
"useDefaultPath": "Use Default Path", "useDefaultPath": "Use Default Path",
"useDefaultPathTooltip": "When enabled, files are automatically organized using configured path templates", "useDefaultPathTooltip": "When enabled, files are automatically organized using configured path templates",
@@ -1163,7 +1165,9 @@
}, },
"errors": { "errors": {
"invalidUrl": "Invalid Civitai URL format", "invalidUrl": "Invalid Civitai URL format",
"noVersions": "No versions available for this model" "noVersions": "No versions available for this model",
"mixedSources": "Cannot mix CivitAI and Hugging Face URLs in the same batch.",
"noModelFiles": "No model files found in this repository."
}, },
"status": { "status": {
"preparing": "Preparing download...", "preparing": "Preparing download...",

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,360 @@
"""Handlers for Hugging Face model listing and download.
Minimal MVP implementation — uses direct HTTP to the HF API for file
listing and the project's existing aiohttp-based Downloader for
downloading. No huggingface_hub dependency required.
"""
from __future__ import annotations
import json
import logging
import os
from typing import Any
import aiohttp
from aiohttp import web
from ...config import config
from ...services.downloader import (
DownloadProgress,
get_downloader,
)
from ...services.aria2_downloader import Aria2Downloader
from ...services.settings_manager import get_settings_manager
from ...services.service_registry import ServiceRegistry
from ...services.websocket_manager import ws_manager
from ...utils.constants import MODEL_FILE_EXTENSIONS
from ...utils.metadata_manager import MetadataManager
from ...utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
logger = logging.getLogger(__name__)
_DEFAULT_MODEL_CLASS = LoraMetadata
_DEFAULT_SCANNER_GETTER = "get_lora_scanner"
# Shared aiohttp session for HF API calls (created on first use)
_hf_api_session: aiohttp.ClientSession | None = None
async def _get_hf_api_session() -> aiohttp.ClientSession:
"""Get or create the shared aiohttp session for HF API calls."""
global _hf_api_session # needed because we reassign the module-level name
if _hf_api_session is None or _hf_api_session.closed:
_hf_api_session = aiohttp.ClientSession(
headers={"User-Agent": "ComfyUI-LoRA-Manager/1.0"},
timeout=aiohttp.ClientTimeout(total=30),
)
return _hf_api_session
def _infer_model_type(model_root: str) -> tuple[Any, str]:
"""Determine model class and scanner by matching ``model_root`` against the
configured root paths for each model type (from ``Config``).
The ``model_root`` value comes from the frontend's model-root dropdown,
which is populated from the current page's scanner roots. By checking
which scanner's root list it belongs to, we avoid fragile heuristics
like substring-matching path names.
"""
norm = os.path.normpath(model_root).replace(os.sep, "/")
# LoRA roots
for p in (config.loras_roots or []) + (config.extra_loras_roots or []):
if os.path.normpath(p).replace(os.sep, "/") == norm:
return LoraMetadata, "get_lora_scanner"
# Checkpoint / UNet roots
for p in (
(config.checkpoints_roots or [])
+ (config.extra_checkpoints_roots or [])
+ (config.unet_roots or [])
+ (config.extra_unet_roots or [])
):
if os.path.normpath(p).replace(os.sep, "/") == norm:
return CheckpointMetadata, "get_checkpoint_scanner"
# Embedding roots
for p in (config.embeddings_roots or []) + (config.extra_embeddings_roots or []):
if os.path.normpath(p).replace(os.sep, "/") == norm:
return EmbeddingMetadata, "get_embedding_scanner"
# Fallback — should not happen in normal use
logger.warning(
"Could not determine model type for root '%s'; defaulting to LoRA",
model_root,
)
return _DEFAULT_MODEL_CLASS, _DEFAULT_SCANNER_GETTER
async def _save_hf_metadata(dest_path: str, repo: str, model_root: str) -> None:
"""Create a proper .metadata.json and add the model to the scanner cache.
Uses ``MetadataManager.create_default_metadata()`` which computes the
SHA256 hash, extracts safetensors header metadata (base_model), and
produces a fully-populated ``LoraMetadata`` (or ``CheckpointMetadata`` /
``EmbeddingMetadata``) object. We then overlay HF-specific fields and
register the model in the in-memory scanner cache so it appears
immediately without a full filesystem walk.
"""
try:
hf_url = f"https://huggingface.co/{repo}"
model_class, scanner_getter_name = _infer_model_type(model_root)
# 1. Create proper metadata (computes SHA256, reads safetensors headers)
metadata = await MetadataManager.create_default_metadata(
dest_path, model_class=model_class
)
if metadata is None:
logger.warning("create_default_metadata returned None for %s", dest_path)
return
# 2. Overlay HF-specific fields
metadata._unknown_fields["hf_url"] = hf_url
metadata.from_civitai = True # leave default, don't interfere with CivitAI fetch
# 3. Save metadata atomically
await MetadataManager.save_metadata(dest_path, metadata)
logger.info("Saved HF metadata (with hf_url) for %s", dest_path)
# 4. Determine relative folder path for cache
# model_root is an absolute path; dest_path is under it
folder = ""
if os.path.isabs(model_root) and dest_path.startswith(model_root):
rel = os.path.relpath(os.path.dirname(dest_path), model_root)
folder = rel.replace(os.sep, "/") if rel != "." else ""
# 5. Add to scanner cache (same as CivitAI's _execute_download does)
scanner_getter = getattr(ServiceRegistry, scanner_getter_name, None)
if scanner_getter is not None:
scanner = await scanner_getter()
if scanner is not None:
metadata_dict = metadata.to_dict()
metadata_dict["hf_url"] = hf_url
await scanner.add_model_to_cache(metadata_dict, folder)
logger.info("Added %s to scanner cache (folder=%s)", dest_path, folder)
except Exception as exc:
logger.warning("Failed to save HF metadata for %s: %s", dest_path, exc)
class HfHandler:
"""Handle Hugging Face model browsing and download."""
async def get_hf_repo_files(self, request: web.Request) -> web.Response:
"""List model-weight files from a HF repo with real file sizes.
Uses the HF tree API endpoint which returns accurate file sizes
(including LFS-tracked files), unlike the model info endpoint.
"""
repo = request.query.get("repo", "").strip()
if not repo or "/" not in repo:
return web.json_response(
{"error": "Missing or invalid 'repo' parameter (expected user/repo)"},
status=400,
)
url = f"https://huggingface.co/api/models/{repo}/tree/main"
try:
session = await _get_hf_api_session()
async with session.get(url) as resp:
if resp.status == 404:
return web.json_response(
{"error": f"Repo '{repo}' not found"}, status=404
)
if resp.status != 200:
text = await resp.text()
return web.json_response(
{"error": f"HF API error {resp.status}: {text[:200]}"},
status=resp.status,
)
tree: list[dict[str, Any]] = await resp.json()
except Exception as exc:
logger.error("Failed to fetch HF repo files: %s", exc)
return web.json_response({"error": str(exc)}, status=502)
files: list[dict[str, Any]] = []
for entry in tree:
path: str = entry.get("path", "")
ext = os.path.splitext(path)[1].lower()
if ext not in MODEL_FILE_EXTENSIONS:
continue
size = entry.get("size", 0) or 0
if size == 0 and "lfs" in entry:
size = entry["lfs"].get("size", 0) or 0
files.append({
"filename": path,
"size": size,
})
files.sort(key=lambda f: f["size"], reverse=True)
return web.json_response(files)
async def download_hf_model(self, request: web.Request) -> web.Response:
"""Download a single file from Hugging Face into the model directory.
POST JSON body::
{
"repo": "dx8152/Flux2-Klein-9B-Consistency",
"filename": "Flux2-Klein-9B-consistency-V2.safetensors",
"revision": "main",
"model_root": "loras",
"relative_path": "",
"use_default_paths": false,
"download_id": "optional-batch-id"
}
If ``download_id`` is provided, real-time progress (bytes, speed,
percentage) is broadcast via the WebSocket progress system, matching
the CivitAI download experience.
Respects the ``download_backend`` setting (``aria2`` or ``default``).
"""
try:
payload: dict[str, Any] = await request.json()
except json.JSONDecodeError:
return web.json_response({"error": "Invalid JSON"}, status=400)
repo = (payload.get("repo") or "").strip()
filename = (payload.get("filename") or "").strip()
revision = (payload.get("revision") or "main").strip()
model_root = (payload.get("model_root") or "").strip()
relative_path = (payload.get("relative_path") or "").strip()
use_default_paths = bool(payload.get("use_default_paths", False))
download_id: str | None = payload.get("download_id")
logger.info(
"download_hf_model: repo=%s file=%s root=%s download_id=%s",
repo, filename, model_root, download_id,
)
if not repo or not filename:
return web.json_response(
{"error": "Missing required fields: 'repo' and 'filename'"}, status=400
)
# Determine target directory
if os.path.isabs(model_root):
base_dir = model_root
else:
base_dir = os.path.join(os.getcwd(), "models", model_root)
if use_default_paths:
author, repo_name = repo.split("/", 1)
target_dir = os.path.join(base_dir, "huggingface", author, repo_name)
elif relative_path:
target_dir = os.path.join(base_dir, relative_path)
else:
target_dir = base_dir
os.makedirs(target_dir, exist_ok=True)
dest_path = os.path.join(target_dir, filename)
# Check if already exists (simple skip)
if os.path.exists(dest_path) and os.path.getsize(dest_path) > 0:
logger.info("download_hf_model: file already exists, skipping — %s", dest_path)
return web.json_response({
"success": True,
"message": f"File already exists: {dest_path}",
"path": dest_path,
})
# Build HF resolve URL
resolve_url = (
f"https://huggingface.co/{repo}/resolve/{revision}/{filename}"
)
# Set up progress callback if download_id is provided
progress_callback = None
if download_id:
async def _progress_callback(
progress: float | DownloadProgress,
snapshot: DownloadProgress | None = None,
) -> None:
percent = 0.0
metrics = snapshot if isinstance(snapshot, DownloadProgress) else None
if isinstance(progress, DownloadProgress):
percent = progress.percent_complete
metrics = progress
elif isinstance(snapshot, DownloadProgress):
percent = snapshot.percent_complete
else:
percent = float(progress)
broadcast: dict[str, Any] = {
"status": "progress",
"progress": round(percent),
}
if metrics:
broadcast["bytes_downloaded"] = metrics.bytes_downloaded
broadcast["total_bytes"] = metrics.total_bytes
broadcast["bytes_per_second"] = metrics.bytes_per_second
await ws_manager.broadcast_download_progress(download_id, broadcast)
progress_callback = _progress_callback
# Respect download backend setting (aria2 vs default)
download_backend = (
get_settings_manager().get("download_backend", "default")
)
if download_backend == "aria2":
aria2 = await Aria2Downloader.get_instance()
aid = download_id or f"hf_{repo}_{filename}"
try:
hf_success, hf_result = await aria2.download_file(
url=resolve_url,
save_path=dest_path,
download_id=aid,
progress_callback=progress_callback,
)
if hf_success:
await _save_hf_metadata(dest_path, repo, model_root)
return web.json_response({
"success": True,
"message": f"Downloaded to {dest_path}",
"path": dest_path,
})
else:
return web.json_response(
{"success": False, "error": hf_result or "aria2 download failed"},
status=500,
)
except Exception as exc:
logger.error("HF download (aria2) failed: %s", exc)
return web.json_response(
{"success": False, "error": str(exc)}, status=500
)
# Default: use built-in aiohttp Downloader
downloader = await get_downloader()
try:
success, result = await downloader.download_file(
url=resolve_url,
save_path=dest_path,
use_auth=False,
allow_resume=True,
progress_callback=progress_callback,
)
if success:
await _save_hf_metadata(dest_path, repo, model_root)
return web.json_response({
"success": True,
"message": f"Downloaded to {result}",
"path": result,
})
else:
return web.json_response(
{"success": False, "error": result or "Download failed"},
status=500,
)
except Exception as exc:
logger.error("HF download failed: %s", exc)
return web.json_response(
{"success": False, "error": str(exc)}, status=500
)

View File

@@ -48,6 +48,7 @@ from ...utils.constants import (
SUPPORTED_MEDIA_EXTENSIONS, SUPPORTED_MEDIA_EXTENSIONS,
VALID_LORA_TYPES, VALID_LORA_TYPES,
) )
from .hf_handlers import HfHandler
from ...utils.civitai_utils import rewrite_preview_url from ...utils.civitai_utils import rewrite_preview_url
from ...utils.example_images_paths import ( from ...utils.example_images_paths import (
find_non_compliant_items_in_example_images_root, find_non_compliant_items_in_example_images_root,
@@ -3315,6 +3316,7 @@ class MiscHandlerSet:
doctor: DoctorHandler, doctor: DoctorHandler,
example_workflows: ExampleWorkflowsHandler, example_workflows: ExampleWorkflowsHandler,
base_model: BaseModelHandlerSet, base_model: BaseModelHandlerSet,
hf_handler: HfHandler | None = None,
) -> None: ) -> None:
self.health = health self.health = health
self.settings = settings self.settings = settings
@@ -3333,6 +3335,7 @@ class MiscHandlerSet:
self.doctor = doctor self.doctor = doctor
self.example_workflows = example_workflows self.example_workflows = example_workflows
self.base_model = base_model self.base_model = base_model
self.hf_handler = hf_handler
def to_route_mapping( def to_route_mapping(
self, self,
@@ -3378,6 +3381,9 @@ class MiscHandlerSet:
"get_supporters": self.supporters.get_supporters, "get_supporters": self.supporters.get_supporters,
"get_example_workflows": self.example_workflows.get_example_workflows, "get_example_workflows": self.example_workflows.get_example_workflows,
"get_example_workflow": self.example_workflows.get_example_workflow, "get_example_workflow": self.example_workflows.get_example_workflow,
# Hugging Face handlers
"get_hf_repo_files": self.hf_handler.get_hf_repo_files,
"download_hf_model": self.hf_handler.download_hf_model,
# Base model handlers # Base model handlers
"get_base_models": self.base_model.get_base_models, "get_base_models": self.base_model.get_base_models,
"refresh_base_models": self.base_model.refresh_base_models, "refresh_base_models": self.base_model.refresh_base_models,

View File

@@ -94,6 +94,13 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition( RouteDefinition(
"GET", "/api/lm/delete-model-version", "delete_model_version" "GET", "/api/lm/delete-model-version", "delete_model_version"
), ),
# Hugging Face model endpoints
RouteDefinition(
"GET", "/api/lm/hf-repo-files", "get_hf_repo_files"
),
RouteDefinition(
"POST", "/api/lm/download-hf-model", "download_hf_model"
),
) )

View File

@@ -39,6 +39,7 @@ from .handlers.misc_handlers import (
build_service_registry_adapter, build_service_registry_adapter,
) )
from .handlers.base_model_handlers import BaseModelHandlerSet from .handlers.base_model_handlers import BaseModelHandlerSet
from .handlers.hf_handlers import HfHandler
from .misc_route_registrar import MiscRouteRegistrar from .misc_route_registrar import MiscRouteRegistrar
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -136,6 +137,7 @@ class MiscRoutes:
doctor = DoctorHandler(settings_service=self._settings) doctor = DoctorHandler(settings_service=self._settings)
example_workflows = ExampleWorkflowsHandler() example_workflows = ExampleWorkflowsHandler()
base_model = BaseModelHandlerSet() base_model = BaseModelHandlerSet()
hf_handler = HfHandler()
return self._handler_set_factory( return self._handler_set_factory(
health=health, health=health,
@@ -155,6 +157,7 @@ class MiscRoutes:
doctor=doctor, doctor=doctor,
example_workflows=example_workflows, example_workflows=example_workflows,
base_model=base_model, base_model=base_model,
hf_handler=hf_handler,
) )

View File

@@ -47,6 +47,20 @@ SUPPORTED_MEDIA_EXTENSIONS = {
"videos": [".mp4", ".webm"], "videos": [".mp4", ".webm"],
} }
# Model weight file extensions recognised by scanners.
# This is the union of all scanner extensions (lora, checkpoint, embedding).
MODEL_FILE_EXTENSIONS = {
".safetensors",
".ckpt",
".pt",
".pt2",
".bin",
".pth",
".pkl",
".sft",
".gguf",
}
# Valid sub-types for each scanner type # Valid sub-types for each scanner type
VALID_LORA_SUB_TYPES = ["lora", "locon", "dora"] VALID_LORA_SUB_TYPES = ["lora", "locon", "dora"]
VALID_CHECKPOINT_SUB_TYPES = ["checkpoint", "diffusion_model"] VALID_CHECKPOINT_SUB_TYPES = ["checkpoint", "diffusion_model"]

View File

@@ -822,3 +822,111 @@
[data-theme="dark"] .batch-preview-item { [data-theme="dark"] .batch-preview-item {
background: var(--lora-surface); background: var(--lora-surface);
} }
/* HF Repo File Explorer Step */
.hf-repo-header {
margin-bottom: var(--space-2);
font-size: 0.95em;
color: var(--text-color);
opacity: 0.8;
}
.repo-file-list {
max-height: 360px;
overflow-y: auto;
margin: var(--space-2) 0;
display: flex;
flex-direction: column;
gap: 6px;
}
.repo-file-item {
display: flex;
align-items: center;
gap: 10px;
padding: 10px 12px;
border: 1px solid var(--border-color);
border-radius: var(--border-radius-sm);
cursor: pointer;
transition: var(--transition-base);
background: var(--bg-color);
}
.repo-file-item:hover {
border-color: var(--lora-accent);
box-shadow: var(--shadow-sm);
}
.repo-file-item.selected {
border: 2px solid var(--lora-accent);
background: oklch(var(--lora-accent) / 0.05);
}
.repo-file-item .repo-file-checkbox {
width: 18px;
height: 18px;
cursor: pointer;
accent-color: var(--lora-accent);
flex-shrink: 0;
padding: 0;
border: none;
}
.repo-file-icon {
font-size: 1.2em;
color: var(--text-color);
opacity: 0.6;
width: 24px;
text-align: center;
flex-shrink: 0;
}
.repo-file-name {
flex: 1;
font-weight: 500;
font-size: 0.95em;
word-break: keep-all;
overflow-wrap: anywhere;
min-width: 0;
}
.repo-file-meta {
display: flex;
align-items: center;
gap: 8px;
font-size: 0.85em;
color: var(--text-color);
opacity: 0.6;
white-space: nowrap;
}
.repo-file-size {
font-variant-numeric: tabular-nums;
}
.hf-badge {
display: inline-block;
padding: 1px 6px;
border-radius: 8px;
background: oklch(0.55 0.12 250 / 0.15);
color: oklch(0.7 0.12 250);
font-size: 0.75em;
font-weight: 600;
margin-left: 4px;
}
[data-theme="dark"] .repo-file-item {
background: var(--lora-surface);
}
/* Checkbox inside HF batch preview items */
.batch-preview-checkbox {
width: 18px;
height: 18px;
cursor: pointer;
accent-color: var(--lora-accent);
flex-shrink: 0;
padding: 0;
border: none;
margin: 0;
}

View File

@@ -190,6 +190,12 @@ export const DOWNLOAD_ENDPOINTS = {
exampleImages: '/api/lm/force-download-example-images' // New endpoint for downloading example images exampleImages: '/api/lm/force-download-example-images' // New endpoint for downloading example images
}; };
// Hugging Face API endpoints
export const HF_ENDPOINTS = {
repoFiles: '/api/lm/hf-repo-files',
download: '/api/lm/download-hf-model',
};
// WebSocket endpoints // WebSocket endpoints
export const WS_ENDPOINTS = { export const WS_ENDPOINTS = {
fetchProgress: '/ws/fetch-progress' fetchProgress: '/ws/fetch-progress'

View File

@@ -7,6 +7,7 @@ import {
getCurrentModelType, getCurrentModelType,
isValidModelType, isValidModelType,
DOWNLOAD_ENDPOINTS, DOWNLOAD_ENDPOINTS,
HF_ENDPOINTS,
WS_ENDPOINTS WS_ENDPOINTS
} from './apiConfig.js'; } from './apiConfig.js';
import { resetAndReload } from './modelApiFactory.js'; import { resetAndReload } from './modelApiFactory.js';
@@ -1243,6 +1244,48 @@ export class BaseModelApiClient {
} }
} }
async fetchHfRepoFiles(repo, revision = 'main') {
try {
const params = new URLSearchParams({ repo, revision });
const response = await fetch(`${HF_ENDPOINTS.repoFiles}?${params}`);
if (!response.ok) {
const err = await response.json().catch(() => ({}));
throw new Error(err.error || 'Failed to fetch HF repo files');
}
return await response.json();
} catch (error) {
console.error('Error fetching HF repo files:', error);
throw error;
}
}
async downloadHfModel({ repo, filename, revision, modelRoot, relativePath, useDefaultPaths, download_id }) {
try {
const response = await fetch(HF_ENDPOINTS.download, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
repo,
filename,
revision: revision || 'main',
model_root: modelRoot,
relative_path: relativePath || '',
use_default_paths: useDefaultPaths || false,
...(download_id ? { download_id } : {}),
})
});
if (!response.ok) {
throw new Error(await response.text());
}
return await response.json();
} catch (error) {
console.error('Error downloading HF model:', error);
throw error;
}
}
_buildQueryParams(baseParams, pageState) { _buildQueryParams(baseParams, pageState) {
const params = new URLSearchParams(baseParams); const params = new URLSearchParams(baseParams);
const isExcludedView = pageState.viewMode === 'excluded'; const isExcludedView = pageState.viewMode === 'excluded';

View File

@@ -27,6 +27,11 @@ export class DownloadManager {
this.isBatchMode = false; this.isBatchMode = false;
this.editingBatchIndex = -1; this.editingBatchIndex = -1;
// HF download state
this.hfRepoId = null;
this.hfRepoFiles = [];
this.hfSelectedFiles = [];
this.loadingManager = new LoadingManager(); this.loadingManager = new LoadingManager();
this.folderTreeManager = new FolderTreeManager(); this.folderTreeManager = new FolderTreeManager();
this.folderClickHandler = null; this.folderClickHandler = null;
@@ -44,6 +49,10 @@ export class DownloadManager {
this.handleToggleDefaultPath = this.toggleDefaultPath.bind(this); this.handleToggleDefaultPath = this.toggleDefaultPath.bind(this);
this.handleBackToUrlFromBatch = this.backToUrlFromBatch.bind(this); this.handleBackToUrlFromBatch = this.backToUrlFromBatch.bind(this);
this.handleNextFromBatch = this.nextFromBatch.bind(this); this.handleNextFromBatch = this.nextFromBatch.bind(this);
// HF handlers
this.handleBackToUrlFromHf = this.backToUrlFromHf.bind(this);
this.handleNextFromHfFiles = this.nextFromHfFiles.bind(this);
} }
showDownloadModal() { showDownloadModal() {
@@ -99,6 +108,12 @@ export class DownloadManager {
// Default path toggle handler // Default path toggle handler
document.getElementById('useDefaultPath').addEventListener('change', this.handleToggleDefaultPath); document.getElementById('useDefaultPath').addEventListener('change', this.handleToggleDefaultPath);
// HF step buttons
const backToUrlFromHfBtn = document.getElementById('backToUrlFromHfBtn');
if (backToUrlFromHfBtn) backToUrlFromHfBtn.addEventListener('click', this.handleBackToUrlFromHf);
const nextFromHfFiles = document.getElementById('nextFromHfFiles');
if (nextFromHfFiles) nextFromHfFiles.addEventListener('click', this.handleNextFromHfFiles);
} }
updateModalLabels() { updateModalLabels() {
@@ -160,6 +175,11 @@ export class DownloadManager {
// Reset default path toggle // Reset default path toggle
this.loadDefaultPathSetting(); this.loadDefaultPathSetting();
// Reset HF state
this.hfRepoId = null;
this.hfRepoFiles = [];
this.hfSelectedFiles = [];
} }
async retrieveVersionsForModel(modelId, source = null) { async retrieveVersionsForModel(modelId, source = null) {
@@ -180,6 +200,29 @@ export class DownloadManager {
return; return;
} }
// Detect URL types — all URLs must share the same source type
const urlTypes = urls.map(u => DownloadManager.detectUrlType(u));
const isHf = urlTypes.every(t => t && (t.type === 'hf-resolve' || t.type === 'hf-repo'));
const isCivitai = urlTypes.every(t => t && t.type === 'civitai');
if (!isHf && !isCivitai) {
const allValid = urlTypes.every(t => t !== null);
if (!allValid) {
errorElement.textContent = translate('modals.download.errors.invalidUrl');
return;
}
// Mixed sources not supported in one batch
if (urls.length > 1) {
errorElement.textContent = translate('modals.download.errors.mixedSources');
return;
}
}
if (isHf) {
return this._validateAndFetchHf(urls, errorElement);
}
// --- Original CivitAI flow below ---
if (urls.length === 1) { if (urls.length === 1) {
this.isBatchMode = false; this.isBatchMode = false;
try { try {
@@ -271,6 +314,142 @@ export class DownloadManager {
this.showBatchPreviewStep(); this.showBatchPreviewStep();
} }
// ---- Hugging Face download flow ----
async _validateAndFetchHf(urls, errorElement) {
if (urls.length === 1) {
const info = DownloadManager.detectUrlType(urls[0]);
// Direct file resolve URL → skip file selection, go to location
if (info.type === 'hf-resolve') {
this.isBatchMode = false;
this.hfRepoId = info.repo;
this.hfSelectedFiles = [info.filename];
this.hfRepoFiles = [];
this.source = 'huggingface';
this.proceedToLocation();
return;
}
// Repo URL → fetch file list
try {
this.loadingManager.showSimpleLoading(translate('modals.download.fetchingRepoFiles'));
const files = await this.apiClient.fetchHfRepoFiles(info.repo);
if (!files || files.length === 0) {
throw new Error(translate('modals.download.errors.noModelFiles'));
}
this.hfRepoId = info.repo;
this.hfRepoFiles = files;
this.hfSelectedFiles = [];
this.isBatchMode = false;
this.source = 'huggingface';
this.showRepoFileStep(info.repo);
} catch (err) {
errorElement.textContent = err.message;
} finally {
this.loadingManager.hide();
}
return;
}
// Multiple HF URLs → batch mode: flatten all files from all repos
this.isBatchMode = true;
this.batchModels = [];
this.source = 'huggingface';
this.loadingManager.showSimpleLoading(translate('modals.download.fetchingRepoFiles'));
for (const url of urls) {
const info = DownloadManager.detectUrlType(url);
if (!info) {
this.batchModels.push({ url, error: 'Invalid URL', versions: [], selectedVersion: null });
continue;
}
if (info.type === 'hf-resolve') {
this.batchModels.push({
url,
source: 'huggingface',
repo: info.repo,
filename: info.filename,
revision: info.revision || 'main',
displayName: info.filename,
selectedVersion: true,
versions: [],
checked: true,
error: null,
});
} else if (info.type === 'hf-repo') {
try {
const files = await this.apiClient.fetchHfRepoFiles(info.repo);
if (!files || files.length === 0) {
this.batchModels.push({ url, error: 'No model files found', versions: [], selectedVersion: null });
continue;
}
// Flatten: create one batch item per file, all checked by default
for (const file of files) {
this.batchModels.push({
url,
source: 'huggingface',
repo: info.repo,
filename: file.filename,
revision: 'main',
displayName: file.filename,
fileSizeBytes: file.size,
selectedVersion: true,
versions: [],
checked: true,
error: null,
});
}
} catch (err) {
this.batchModels.push({ url, error: err.message, versions: [], selectedVersion: null });
}
}
}
this.loadingManager.hide();
this.showBatchPreviewStep();
}
showRepoFileStep(repoId) {
document.querySelectorAll('.download-step').forEach(s => s.style.display = 'none');
document.getElementById('repoFileStep').style.display = 'block';
document.getElementById('hfRepoLabel').textContent = repoId;
const list = document.getElementById('repoFileList');
list.innerHTML = this.hfRepoFiles.map((f, i) => {
const sizeMb = f.size > 0 ? (f.size / (1024 * 1024)).toFixed(1) : '?';
return `
<div class="repo-file-item" data-index="${i}">
<input type="checkbox" class="repo-file-checkbox" data-index="${i}" />
<span class="repo-file-icon"><i class="fas fa-file"></i></span>
<span class="repo-file-name">${f.filename}</span>
<span class="repo-file-meta">
<span class="repo-file-size">${sizeMb} MB</span>
</span>
</div>
`;
}).join('');
}
backToUrlFromHf() {
this.hfRepoId = null;
this.hfRepoFiles = [];
this.hfSelectedFiles = [];
document.getElementById('repoFileStep').style.display = 'none';
document.getElementById('urlStep').style.display = 'block';
}
nextFromHfFiles() {
// Read checked state directly from DOM — more reliable than event-tracking
const checked = document.querySelectorAll('.repo-file-checkbox:checked');
this.hfSelectedFiles = Array.from(checked).map(cb => {
const idx = parseInt(cb.dataset.index);
return this.hfRepoFiles[idx].filename;
});
if (!this.hfSelectedFiles.length) {
return;
}
this.proceedToLocation();
}
async fetchVersionsForCurrentModel() { async fetchVersionsForCurrentModel() {
const errorElement = document.getElementById('urlError'); const errorElement = document.getElementById('urlError');
if (errorElement) { if (errorElement) {
@@ -311,6 +490,50 @@ export class DownloadManager {
return { modelId: null, modelVersionId: null, source: null }; return { modelId: null, modelVersionId: null, source: null };
} }
/**
* Detect the source type of a download URL.
* @param {string} url
* @returns {{ type: string, repo?: string, filename?: string, revision?: string } | null}
* type: 'civitai' | 'civarchive' | 'hf-resolve' | 'hf-repo' | 'direct-http'
*/
static detectUrlType(url) {
const trimmed = url.trim();
if (!trimmed) return null;
// CivitAI
if (/civitai\.com\/models\//i.test(trimmed) || /civitaiarchive|civarchive/i.test(trimmed)) {
// Will be parsed by existing CivitAI logic
return { type: 'civitai' };
}
// Hugging Face resolve URL → direct file
const hfResolveMatch = trimmed.match(/huggingface\.co\/([^/\s]+\/[^/\s]+)\/resolve\/([^/\s]+)\/(.+)/i);
if (hfResolveMatch) {
return {
type: 'hf-resolve',
repo: hfResolveMatch[1],
revision: hfResolveMatch[2],
filename: hfResolveMatch[3],
};
}
// Hugging Face repo URL (huggingface.co/user/repo or bare user/repo path)
const hfRepoMatch = trimmed.match(/(?:https?:\/\/huggingface\.co\/)?([a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+)(?:\/?$|$)/);
if (hfRepoMatch) {
return {
type: 'hf-repo',
repo: hfRepoMatch[1],
};
}
// Direct HTTP(S) URL (non-HF)
if (/^https?:\/\//i.test(trimmed)) {
return { type: 'direct-http' };
}
return null;
}
extractModelId(url) { extractModelId(url) {
const result = DownloadManager.parseModelUrl(url); const result = DownloadManager.parseModelUrl(url);
this.modelVersionId = result.modelVersionId; this.modelVersionId = result.modelVersionId;
@@ -559,8 +782,8 @@ export class DownloadManager {
return; return;
} }
// In single-URL mode, validate version selection // In single-URL mode, validate version selection (skip for HF)
if (!this.isBatchMode) { if (!this.isBatchMode && this.source !== 'huggingface') {
if (!this.currentVersion) { if (!this.currentVersion) {
showToast('toast.loras.pleaseSelectVersion', {}, 'error'); showToast('toast.loras.pleaseSelectVersion', {}, 'error');
return; return;
@@ -784,6 +1007,75 @@ export class DownloadManager {
} }
} }
async _downloadHfSingle({ modelRoot, targetFolder, useDefaultPaths }) {
modalManager.closeModal('downloadModal');
this.loadingManager.restoreProgressBar();
const totalFiles = this.hfSelectedFiles.length;
const updateProgress = this.loadingManager.showDownloadProgress(totalFiles);
try {
let completedDownloads = 0;
for (let i = 0; i < totalFiles; i++) {
const filename = this.hfSelectedFiles[i];
updateProgress(0, completedDownloads, filename);
this.loadingManager.setStatus(`Downloading ${filename}...`);
const downloadId = Date.now().toString() + '_' + i;
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
await new Promise((resolve, reject) => {
ws.onopen = resolve;
ws.onerror = reject;
});
// Capture completed count at WS creation time so progress
// updates arriving after completedDownloads increments still
// show the correct "N / total" position.
const snapshotCompleted = completedDownloads;
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.status === 'progress') {
const metrics = {
bytesDownloaded: data.bytes_downloaded,
totalBytes: data.total_bytes,
bytesPerSecond: data.bytes_per_second,
};
updateProgress(data.progress, snapshotCompleted, filename, metrics);
}
};
const response = await this.apiClient.downloadHfModel({
repo: this.hfRepoId,
filename,
revision: 'main',
modelRoot,
relativePath: targetFolder,
useDefaultPaths,
download_id: downloadId,
});
ws.close();
if (response?.success) {
completedDownloads++;
updateProgress(100, completedDownloads, filename);
}
}
showToast('toast.loras.downloadCompleted', {}, 'success');
// Reload page data — model is already in scanner cache via backend
await resetAndReload(true);
return true;
} catch (error) {
console.error('Failed to download HF model:', error);
showToast('toast.downloads.downloadError', { message: error?.message }, 'error');
return false;
} finally {
this.loadingManager.hide();
}
}
updatePathSelectionUI() { updatePathSelectionUI() {
const manualSelection = document.getElementById('manualPathSelection'); const manualSelection = document.getElementById('manualPathSelection');
@@ -812,7 +1104,11 @@ export class DownloadManager {
document.querySelectorAll('.download-step').forEach(step => step.style.display = 'none'); document.querySelectorAll('.download-step').forEach(step => step.style.display = 'none');
document.getElementById('batchPreviewStep').style.display = 'block'; document.getElementById('batchPreviewStep').style.display = 'block';
const validCount = this.batchModels.filter(m => !m.error && m.selectedVersion).length; const validCount = this.batchModels.filter(m => {
if (m.error) return false;
if (m.source === 'huggingface') return m.checked !== false;
return m.selectedVersion;
}).length;
document.getElementById('downloadModalTitle').textContent = document.getElementById('downloadModalTitle').textContent =
translate('modals.download.titleWithType', { type: this.apiClient.apiConfig.config.displayName }) + translate('modals.download.titleWithType', { type: this.apiClient.apiConfig.config.displayName }) +
` (${validCount})`; ` (${validCount})`;
@@ -837,6 +1133,33 @@ export class DownloadManager {
} }
const ver = item.selectedVersion; const ver = item.selectedVersion;
// HF batch item rendering with checkbox
if (item.source === 'huggingface') {
const hfSize = item.fileSizeBytes
? (item.fileSizeBytes / (1024 * 1024)).toFixed(1)
: '?';
return `
<div class="batch-preview-item" data-index="${index}">
<input type="checkbox" class="batch-preview-checkbox"
data-index="${index}" ${item.checked !== false ? 'checked' : ''} />
<div class="batch-preview-icon" style="color: var(--lora-accent);">
<i class="fas fa-cloud"></i>
</div>
<div class="batch-preview-info">
<div class="batch-preview-name">${item.displayName || item.filename || `HF #${index}`} <span class="hf-badge">HF</span></div>
<div class="batch-preview-meta">
<span>${hfSize} MB</span>
<span>${item.repo || ''}</span>
</div>
</div>
<button class="batch-preview-remove" data-index="${index}" title="${translate('common.actions.remove', {}, 'Remove')}">
<i class="fas fa-times"></i>
</button>
</div>
`;
}
const firstImage = ver?.images?.find(img => !img.url.endsWith('.mp4')); const firstImage = ver?.images?.find(img => !img.url.endsWith('.mp4'));
const thumbnailUrl = firstImage ? firstImage.url : '/loras_static/images/no-preview.png'; const thumbnailUrl = firstImage ? firstImage.url : '/loras_static/images/no-preview.png';
const fileSize = ver?.modelSizeKB const fileSize = ver?.modelSizeKB
@@ -881,6 +1204,24 @@ export class DownloadManager {
} }
}; };
// Checkbox handler for HF batch items
const checkboxes = list.querySelectorAll('.batch-preview-checkbox');
checkboxes.forEach(cb => {
cb.addEventListener('change', (e) => {
const idx = parseInt(e.target.dataset.index);
if (this.batchModels[idx]) {
this.batchModels[idx].checked = e.target.checked;
}
// Update valid count in title
const checkedCount = this.batchModels.filter(
m => !m.error && m.checked !== false
).length;
document.getElementById('downloadModalTitle').textContent =
translate('modals.download.titleWithType', { type: this.apiClient.apiConfig.config.displayName }) +
` (${checkedCount})`;
});
});
const nextBtn = document.getElementById('nextFromBatchBtn'); const nextBtn = document.getElementById('nextFromBatchBtn');
nextBtn.disabled = validCount === 0; nextBtn.disabled = validCount === 0;
nextBtn.classList.toggle('disabled', validCount === 0); nextBtn.classList.toggle('disabled', validCount === 0);
@@ -903,7 +1244,12 @@ export class DownloadManager {
} }
nextFromBatch() { nextFromBatch() {
const validModels = this.batchModels.filter(m => !m.error && m.selectedVersion); // For HF items, respect the checked flag; for CivitAI items, use selectedVersion
const validModels = this.batchModels.filter(m => {
if (m.error) return false;
if (m.source === 'huggingface') return m.checked !== false;
return m.selectedVersion;
});
if (validModels.length === 0) return; if (validModels.length === 0) return;
this.proceedToLocation(); this.proceedToLocation();
} }
@@ -953,6 +1299,15 @@ export class DownloadManager {
targetFolder = this.folderTreeManager.getSelectedPath(); targetFolder = this.folderTreeManager.getSelectedPath();
} }
if (!this.isBatchMode) { if (!this.isBatchMode) {
// Single-item download
if (this.source === 'huggingface') {
return this._downloadHfSingle({
modelRoot,
targetFolder,
useDefaultPaths,
});
}
const fileParams = this.selectedFile ? { const fileParams = this.selectedFile ? {
type: this.selectedFile.type || 'Model', type: this.selectedFile.type || 'Model',
format: this.selectedFile.metadata?.format || 'SafeTensor', format: this.selectedFile.metadata?.format || 'SafeTensor',
@@ -974,7 +1329,13 @@ export class DownloadManager {
} }
// Batch download mode // Batch download mode
const downloadItems = this.batchModels.filter(m => !m.error && m.selectedVersion && !m.selectedVersion.existsLocally); const downloadItems = this.batchModels.filter(m => {
if (m.error) return false;
if (!m.selectedVersion) return false;
// HF items have selectedVersion as a boolean marker + checked flag
if (m.source === 'huggingface') return m.checked !== false;
return !m.selectedVersion.existsLocally;
});
if (downloadItems.length === 0) { if (downloadItems.length === 0) {
showToast('toast.loras.downloadCompleted', {}, 'info'); showToast('toast.loras.downloadCompleted', {}, 'info');
modalManager.closeModal('downloadModal'); modalManager.closeModal('downloadModal');
@@ -1016,22 +1377,56 @@ export class DownloadManager {
for (let i = 0; i < downloadItems.length; i++) { for (let i = 0; i < downloadItems.length; i++) {
const item = downloadItems[i]; const item = downloadItems[i];
const ver = item.selectedVersion; const name = item.displayName || item.filename || (item.selectedVersion?.name || `Model #${item.modelId}`);
const name = ver?.name || `Model #${item.modelId}`; const isHf = item.source === 'huggingface';
updateProgress(0, completedDownloads, name); updateProgress(0, completedDownloads, name);
loadingManager.setStatus(`${i + 1}/${downloadItems.length}: ${name}`); loadingManager.setStatus(`${i + 1}/${downloadItems.length}: ${name}`);
try { try {
const response = await this.apiClient.downloadModel( let response;
item.modelId, if (isHf) {
ver.id, // Per-file WebSocket for real-time progress
modelRoot, const downloadId = Date.now().toString() + '_hf_' + i;
targetFolder, const wsHf = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
useDefaultPaths, await new Promise((resolve, reject) => {
batchDownloadId, wsHf.onopen = resolve;
item.source wsHf.onerror = reject;
); });
const snapshotCompleted = completedDownloads;
wsHf.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.status === 'progress') {
const metrics = {
bytesDownloaded: data.bytes_downloaded,
totalBytes: data.total_bytes,
bytesPerSecond: data.bytes_per_second,
};
updateProgress(data.progress, snapshotCompleted, name, metrics);
}
};
response = await this.apiClient.downloadHfModel({
repo: item.repo,
filename: item.filename,
revision: item.revision || 'main',
modelRoot,
relativePath: targetFolder,
useDefaultPaths,
download_id: downloadId,
});
wsHf.close();
} else {
response = await this.apiClient.downloadModel(
item.modelId,
item.selectedVersion.id,
modelRoot,
targetFolder,
useDefaultPaths,
batchDownloadId,
item.source
);
}
if (!response.success) { if (!response.success) {
failedDownloads++; failedDownloads++;

View File

@@ -14,7 +14,7 @@
<div class="error-message" id="urlError"></div> <div class="error-message" id="urlError"></div>
<div class="input-hint"> <div class="input-hint">
<i class="fas fa-info-circle"></i> <i class="fas fa-info-circle"></i>
<span>{{ t('modals.download.urlHint') }}</span> <span id="urlHint">{{ t('modals.download.urlHint') }}</span>
</div> </div>
</div> </div>
<div class="modal-actions"> <div class="modal-actions">
@@ -22,6 +22,24 @@
</div> </div>
</div> </div>
<!-- Step 1b: HF Repo File Explorer (shown when HF repo URL is detected) -->
<div class="download-step" id="repoFileStep" style="display: none;">
<div class="input-group">
<label>{{ t('modals.download.selectHfFiles') }}</label>
<div class="hf-repo-header">
<span id="hfRepoLabel" class="hf-repo-label"></span>
</div>
<div class="repo-file-list" id="repoFileList">
<!-- Files will be inserted here dynamically -->
</div>
<div class="error-message" id="repoFileError"></div>
</div>
<div class="modal-actions">
<button class="secondary-btn" id="backToUrlFromHfBtn">{{ t('common.actions.back') }}</button>
<button class="primary-btn" id="nextFromHfFiles">{{ t('common.actions.next') }}</button>
</div>
</div>
<!-- Step 2: Batch Preview (multi-URL mode) --> <!-- Step 2: Batch Preview (multi-URL mode) -->
<div class="download-step" id="batchPreviewStep" style="display: none;"> <div class="download-step" id="batchPreviewStep" style="display: none;">
<div class="batch-preview-list" id="batchPreviewList"> <div class="batch-preview-list" id="batchPreviewList">