Merge pull request #691 from willmiao/feat/zip-preview

feat(downloads): support safetensors zips and previews
pixelpaws
2025-11-20 19:56:31 +08:00
committed by GitHub
2 changed files with 364 additions and 64 deletions
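For orientation before the diff: the extraction below is plain blocking zipfile work handed to a worker thread via asyncio.to_thread, so the download loop stays responsive while an archive is scanned. A minimal self-contained sketch of that pattern (file and function names here are illustrative, not from this commit):

import asyncio
import zipfile
from typing import List

def list_safetensors(archive_path: str) -> List[str]:
    # Blocking zipfile scan; safe to run on a worker thread.
    with zipfile.ZipFile(archive_path, "r") as archive:
        return [
            info.filename
            for info in archive.infolist()
            if not info.is_dir() and info.filename.lower().endswith(".safetensors")
        ]

async def main() -> None:
    # asyncio.to_thread keeps the event loop free while the archive is read.
    names = await asyncio.to_thread(list_safetensors, "bundle.zip")
    print(names)

# asyncio.run(main())  # needs a real bundle.zip next to the script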


@@ -1,7 +1,10 @@
import copy
import logging
import os
import asyncio
import inspect
import shutil
import zipfile
from collections import OrderedDict
import uuid
from typing import Dict, List, Optional, Tuple
@@ -12,6 +15,7 @@ from ..utils.civitai_utils import rewrite_preview_url
from ..utils.preview_selection import select_preview_media
from ..utils.utils import sanitize_folder_name
from ..utils.exif_utils import ExifUtils
from ..utils.file_utils import calculate_sha256
from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry
from .settings_manager import get_settings_manager
@@ -556,6 +560,13 @@ class DownloadManager:
download_id: str = None,
) -> Dict:
"""Execute the actual download process including preview images and model files"""
metadata_entries: List = []
metadata_files_for_cleanup: List[str] = []
extracted_paths: List[str] = []
metadata_path = ""
preview_targets: List[str] = []
preview_path: str | None = None
preview_nsfw_level = 0
try:
# Extract original filename details
original_filename = os.path.basename(metadata.file_path)
@@ -699,10 +710,9 @@ class DownloadManager:
logger.warning(f"Failed to delete temp file: {e}")
if preview_downloaded and preview_path:
preview_nsfw_level = nsfw_level
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = nsfw_level
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['preview_path'] = preview_path
if progress_callback:
await progress_callback(3) # 3% progress after preview download
@@ -761,77 +771,189 @@ class DownloadManager:
return {'success': False, 'error': last_error or 'Failed to download file'}
# 4. Update file information (size and modified time)
metadata.update_file_info(save_path)
# 5. Handle archive extraction and prepare per-file metadata
actual_file_paths = [save_path]
if zipfile.is_zipfile(save_path):
extracted_paths = await self._extract_safetensors_from_archive(save_path)
if not extracted_paths:
return {'success': False, 'error': 'Zip archive does not contain any safetensors files'}
actual_file_paths = extracted_paths
try:
os.remove(save_path)
except OSError as exc:
logger.warning(f"Unable to delete temporary archive {save_path}: {exc}")
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['file_path'] = extracted_paths[0]
self._active_downloads[download_id]['extracted_paths'] = extracted_paths
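# Note: from this point actual_file_paths drives hashing, metadata, and cache
# updates, so a plain .safetensors download and a multi-model zip follow the
# same code path; only the number of entries differs.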
metadata_entries = await self._build_metadata_entries(metadata, actual_file_paths)
if preview_path:
preview_targets = self._distribute_preview_to_entries(preview_path, metadata_entries)
for entry, target in zip(metadata_entries, preview_targets):
entry.preview_url = target.replace(os.sep, "/")
entry.preview_nsfw_level = preview_nsfw_level
if download_id and download_id in self._active_downloads and preview_targets:
self._active_downloads[download_id]["preview_path"] = preview_targets[0]
scanner = None
adjust_root: Optional[str] = None
# 6. Determine scanner and adjust metadata for cache consistency
if model_type == "checkpoint":
scanner = await self._get_checkpoint_scanner()
logger.info(f"Updating checkpoint cache for {save_path}")
logger.info(f"Updating checkpoint cache for {actual_file_paths[0]}")
elif model_type == "lora":
scanner = await self._get_lora_scanner()
logger.info(f"Updating lora cache for {save_path}")
logger.info(f"Updating lora cache for {actual_file_paths[0]}")
elif model_type == "embedding":
scanner = await ServiceRegistry.get_embedding_scanner()
logger.info(f"Updating embedding cache for {save_path}")
logger.info(f"Updating embedding cache for {actual_file_paths[0]}")
adjust_cached_entry = (
getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None
)
for index, entry in enumerate(metadata_entries):
file_path_for_adjust = getattr(entry, "file_path", actual_file_paths[index])
normalized_file_path = (
file_path_for_adjust.replace(os.sep, "/")
if isinstance(file_path_for_adjust, str)
else str(file_path_for_adjust)
)
if scanner is not None:
find_root = getattr(scanner, "_find_root_for_file", None)
adjust_root = None
if callable(find_root):
try:
adjust_root = find_root(normalized_file_path)
except TypeError:
adjust_root = None
adjust_metadata = getattr(scanner, "adjust_metadata", None)
if callable(adjust_metadata):
adjusted_entry = adjust_metadata(entry, normalized_file_path, adjust_root)
if adjusted_entry is not None:
entry = adjusted_entry
metadata_entries[index] = entry
metadata_file_path = os.path.splitext(entry.file_path)[0] + '.metadata.json'
metadata_files_for_cleanup.append(metadata_file_path)
await MetadataManager.save_metadata(entry.file_path, entry)
metadata_dict = entry.to_dict()
if callable(adjust_cached_entry):
metadata_dict = adjust_cached_entry(metadata_dict)
if scanner is not None:
await scanner.add_model_to_cache(metadata_dict, relative_path)
# Report 100% completion
if progress_callback:
await progress_callback(100)
return {'success': True}
except Exception as e:
logger.error(f"Error in _execute_download: {e}", exc_info=True)
# Clean up partial downloads except .part file
cleanup_targets = {
path
for path in [save_path, metadata_path, *metadata_files_for_cleanup, *extracted_paths]
if path
}
preview_candidate = (
metadata_entries[0].preview_url
if metadata_entries
else getattr(metadata, "preview_url", None)
)
if preview_candidate:
cleanup_targets.add(preview_candidate)
cleanup_targets.update(preview_targets)
for path in cleanup_targets:
if path and os.path.exists(path):
try:
os.remove(path)
except Exception as exc:
logger.warning(f"Failed to cleanup file {path}: {exc}")
return {'success': False, 'error': str(e)}
async def _extract_safetensors_from_archive(self, archive_path: str) -> List[str]:
if not zipfile.is_zipfile(archive_path):
return []
target_dir = os.path.dirname(archive_path)
def _extract_sync() -> List[str]:
extracted_files: List[str] = []
with zipfile.ZipFile(archive_path, "r") as archive:
for info in archive.infolist():
if info.is_dir():
continue
if not info.filename.lower().endswith(".safetensors"):
continue
file_name = os.path.basename(info.filename)
if not file_name:
continue
dest_path = self._resolve_extracted_destination(target_dir, file_name)
with archive.open(info) as source, open(dest_path, "wb") as target:
shutil.copyfileobj(source, target)
extracted_files.append(dest_path)
return extracted_files
return await asyncio.to_thread(_extract_sync)
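# Design note: only .safetensors members are extracted, and their directory
# structure inside the archive is flattened via os.path.basename, so models
# land directly in the download directory regardless of the zip's layout.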
async def _build_metadata_entries(self, base_metadata, file_paths: List[str]) -> List:
if not file_paths:
return []
entries: List = []
for index, file_path in enumerate(file_paths):
entry = base_metadata if index == 0 else copy.deepcopy(base_metadata)
entry.update_file_info(file_path)
entry.sha256 = await calculate_sha256(file_path)
entries.append(entry)
return entries
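# Design note: the first entry reuses base_metadata in place while subsequent
# entries get deep copies, so single-file downloads keep their original
# metadata object and multi-model archives get independent per-file records.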
def _resolve_extracted_destination(self, target_dir: str, filename: str) -> str:
base_name, extension = os.path.splitext(filename)
candidate = filename
destination = os.path.join(target_dir, candidate)
counter = 1
while os.path.exists(destination):
candidate = f"{base_name}-{counter}{extension}"
destination = os.path.join(target_dir, candidate)
counter += 1
return destination
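# Worked example (hypothetical names): if model.safetensors already exists in
# the target directory, the resolver returns model-1.safetensors, then
# model-2.safetensors, and so on until an unused name is found.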
def _distribute_preview_to_entries(self, preview_path: str, entries: List) -> List[str]:
if not preview_path or not entries:
return []
if not os.path.exists(preview_path):
return []
extension = os.path.splitext(preview_path)[1] or ".webp"
targets = [
os.path.splitext(entry.file_path)[0] + extension for entry in entries
]
if not targets:
return []
first_target = targets[0]
if preview_path != first_target:
os.replace(preview_path, first_target)
source_path = first_target
for target in targets[1:]:
shutil.copyfile(source_path, target)
return targets
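# Design note: the shared preview is moved (os.replace) onto the first entry's
# path and then copied to the remaining ones, so no orphan preview is left
# behind under the original bundle-derived name.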
async def _handle_download_progress(
self,
progress_update,
@@ -895,16 +1017,23 @@ class DownloadManager:
# Clean up ALL files including .part when user cancels
download_info = self._active_downloads.get(download_id)
if download_info:
target_files = set()
primary_path = download_info.get('file_path')
if primary_path:
target_files.add(primary_path)
for extra_path in download_info.get('extracted_paths', []):
if extra_path:
target_files.add(extra_path)
for file_path in target_files:
if os.path.exists(file_path):
try:
os.unlink(file_path)
logger.debug(f"Deleted cancelled download: {file_path}")
except Exception as e:
logger.error(f"Error deleting file: {e}")
# Delete the .part file (only on user cancellation)
if 'part_path' in download_info:
part_path = download_info['part_path']
@@ -914,10 +1043,9 @@ class DownloadManager:
logger.debug(f"Deleted partial download: {part_path}")
except Exception as e:
logger.error(f"Error deleting part file: {e}")
# Delete metadata files for each resolved path
for file_path in target_files:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
try:
@@ -925,15 +1053,16 @@ class DownloadManager:
except Exception as e:
logger.error(f"Error deleting metadata file: {e}")
preview_path_value = download_info.get('preview_path')
if preview_path_value and os.path.exists(preview_path_value):
try:
os.unlink(preview_path_value)
logger.debug(f"Deleted preview file: {preview_path_value}")
except Exception as e:
logger.error(f"Error deleting preview file: {e}")
# Delete preview file if exists (.webp or .mp4) for legacy paths
for file_path in target_files:
for preview_ext in ['.webp', '.mp4']:
preview_path = os.path.splitext(file_path)[0] + preview_ext
if os.path.exists(preview_path):
@@ -941,8 +1070,7 @@ class DownloadManager:
os.unlink(preview_path)
logger.debug(f"Deleted preview file: {preview_path}")
except Exception as e:
logger.error(f"Error deleting preview file: {e}")
logger.error(f"Error deleting preview file: {preview_path}")
return {'success': True, 'message': 'Download cancelled successfully'}
except Exception as e:
logger.error(f"Error cancelling download: {e}", exc_info=True)


@@ -1,5 +1,6 @@
import asyncio
import os
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Optional
@@ -525,6 +526,177 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
assert cached_entry["model_type"] == "diffusion_model"
async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
zip_path = save_dir / "bundle.zip"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, updated_path):
self.file_path = str(updated_path)
self.file_name = Path(updated_path).stem
def to_dict(self):
return {"file_path": self.file_path}
metadata = DummyMetadata(zip_path)
version_info = {"images": []}
download_urls = ["https://example.invalid/model.zip"]
class DummyDownloader:
async def download_file(self, *_args, **_kwargs):
with zipfile.ZipFile(str(zip_path), "w") as archive:
archive.writestr("inner/model.safetensors", b"model")
archive.writestr("docs/readme.txt", b"ignore")
return True, "ok"
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
hash_calculator = AsyncMock(return_value="hash-single")
monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
result = await manager._execute_download(
download_urls=download_urls,
save_dir=str(save_dir),
metadata=metadata,
version_info=version_info,
relative_path="",
progress_callback=None,
model_type="lora",
download_id=None,
)
assert result == {"success": True}
assert not zip_path.exists()
extracted = save_dir / "model.safetensors"
assert extracted.exists()
assert hash_calculator.await_args.args[0] == str(extracted)
saved_call = MetadataManager.save_metadata.await_args
assert saved_call.args[0] == str(extracted)
assert saved_call.args[1].sha256 == "hash-single"
assert dummy_scanner.add_model_to_cache.await_count == 1
async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
zip_path = save_dir / "bundle.zip"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, updated_path):
self.file_path = str(updated_path)
self.file_name = Path(updated_path).stem
def to_dict(self):
return {"file_path": self.file_path}
metadata = DummyMetadata(zip_path)
version_info = {"images": []}
download_urls = ["https://example.invalid/model.zip"]
class DummyDownloader:
async def download_file(self, *_args, **_kwargs):
with zipfile.ZipFile(str(zip_path), "w") as archive:
archive.writestr("first/model-one.safetensors", b"one")
archive.writestr("second/model-two.safetensors", b"two")
archive.writestr("readme.md", b"ignore")
return True, "ok"
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
hash_calculator = AsyncMock(side_effect=["hash-one", "hash-two"])
monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
result = await manager._execute_download(
download_urls=download_urls,
save_dir=str(save_dir),
metadata=metadata,
version_info=version_info,
relative_path="",
progress_callback=None,
model_type="lora",
download_id=None,
)
assert result == {"success": True}
assert not zip_path.exists()
extracted_one = save_dir / "model-one.safetensors"
extracted_two = save_dir / "model-two.safetensors"
assert extracted_one.exists()
assert extracted_two.exists()
assert hash_calculator.await_count == 2
assert MetadataManager.save_metadata.await_count == 2
assert dummy_scanner.add_model_to_cache.await_count == 2
metadata_calls = MetadataManager.save_metadata.await_args_list
assert metadata_calls[0].args[0] == str(extracted_one)
assert metadata_calls[0].args[1].sha256 == "hash-one"
assert metadata_calls[1].args[0] == str(extracted_two)
assert metadata_calls[1].args[1].sha256 == "hash-two"
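# Together these two tests pin down the zip path end to end: the archive is
# removed, members are flattened into save_dir, and each extracted model gets
# its own hash, metadata save, and cache insert.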
def test_distribute_preview_to_entries_moves_and_copies(tmp_path):
manager = DownloadManager()
preview_file = tmp_path / "bundle.webp"
preview_file.write_bytes(b"image-data")
entries = [
SimpleNamespace(file_path=str(tmp_path / "model-one.safetensors")),
SimpleNamespace(file_path=str(tmp_path / "model-two.safetensors")),
]
targets = manager._distribute_preview_to_entries(str(preview_file), entries)
assert targets == [
str(tmp_path / "model-one.webp"),
str(tmp_path / "model-two.webp"),
]
assert not preview_file.exists()
assert Path(targets[0]).read_bytes() == b"image-data"
assert Path(targets[1]).read_bytes() == b"image-data"
def test_distribute_preview_to_entries_keeps_existing_file(tmp_path):
manager = DownloadManager()
existing_preview = tmp_path / "model-one.webp"
existing_preview.write_bytes(b"preview")
entries = [
SimpleNamespace(file_path=str(tmp_path / "model-one.safetensors")),
SimpleNamespace(file_path=str(tmp_path / "model-two.safetensors")),
]
targets = manager._distribute_preview_to_entries(str(existing_preview), entries)
assert targets[0] == str(existing_preview)
assert Path(targets[1]).read_bytes() == b"preview"
async def test_pause_download_updates_state():
manager = DownloadManager()