From 9952721e7630eed22d2904940045322f4c5f2827 Mon Sep 17 00:00:00 2001
From: Will Miao
Date: Thu, 20 Nov 2025 19:41:31 +0800
Subject: [PATCH] feat(downloads): support safetensors zips and previews

Zip downloads that contain .safetensors files are now unpacked next to the
archive: each extracted model receives its own metadata entry, SHA-256 hash,
and preview copy, and cancellation cleans up every extracted file. Archive
member names are flattened with os.path.basename, and name collisions are
resolved with a numeric suffix.
---
 py/services/download_manager.py         | 256 ++++++++++++++++++------
 tests/services/test_download_manager.py | 172 ++++++++++++++++
 2 files changed, 364 insertions(+), 64 deletions(-)

diff --git a/py/services/download_manager.py b/py/services/download_manager.py
index 1ece024d..f497c0b6 100644
--- a/py/services/download_manager.py
+++ b/py/services/download_manager.py
@@ -1,7 +1,10 @@
+import copy
 import logging
 import os
 import asyncio
 import inspect
+import shutil
+import zipfile
 from collections import OrderedDict
 import uuid
 from typing import Dict, List, Optional, Tuple
@@ -12,6 +15,7 @@ from ..utils.civitai_utils import rewrite_preview_url
 from ..utils.preview_selection import select_preview_media
 from ..utils.utils import sanitize_folder_name
 from ..utils.exif_utils import ExifUtils
+from ..utils.file_utils import calculate_sha256
 from ..utils.metadata_manager import MetadataManager
 from .service_registry import ServiceRegistry
 from .settings_manager import get_settings_manager
@@ -556,6 +560,14 @@ class DownloadManager:
         download_id: str = None,
     ) -> Dict:
         """Execute the actual download process including preview images and model files"""
+        metadata_entries: List = []
+        metadata_files_for_cleanup: List[str] = []
+        extracted_paths: List[str] = []
+        metadata_path = ""
+        save_path = ""
+        preview_targets: List[str] = []
+        preview_path: Optional[str] = None
+        preview_nsfw_level = 0
         try:
             # Extract original filename details
             original_filename = os.path.basename(metadata.file_path)
@@ -699,10 +710,9 @@
                     logger.warning(f"Failed to delete temp file: {e}")
 
             if preview_downloaded and preview_path:
+                preview_nsfw_level = nsfw_level
                 metadata.preview_url = preview_path.replace(os.sep, '/')
                 metadata.preview_nsfw_level = nsfw_level
-                if download_id and download_id in self._active_downloads:
-                    self._active_downloads[download_id]['preview_path'] = preview_path
 
             if progress_callback:
                 await progress_callback(3)  # 3% progress after preview download
@@ -761,77 +771,189 @@
                 return {'success': False, 'error': last_error or 'Failed to download file'}
 
-            # 4. Update file information (size and modified time)
-            metadata.update_file_info(save_path)
+            # 4. Handle archive extraction and prepare per-file metadata
+            actual_file_paths = [save_path]
+            if zipfile.is_zipfile(save_path):
+                extracted_paths = await self._extract_safetensors_from_archive(save_path)
+                if not extracted_paths:
+                    return {'success': False, 'error': 'Zip archive does not contain any safetensors files'}
+                actual_file_paths = extracted_paths
+                try:
+                    os.remove(save_path)
+                except OSError as exc:
+                    logger.warning(f"Unable to delete temporary archive {save_path}: {exc}")
+                if download_id and download_id in self._active_downloads:
+                    self._active_downloads[download_id]['file_path'] = extracted_paths[0]
+                    self._active_downloads[download_id]['extracted_paths'] = extracted_paths
+
+            metadata_entries = await self._build_metadata_entries(metadata, actual_file_paths)
+            if preview_path:
+                preview_targets = self._distribute_preview_to_entries(preview_path, metadata_entries)
+                for entry, target in zip(metadata_entries, preview_targets):
+                    entry.preview_url = target.replace(os.sep, "/")
+                    entry.preview_nsfw_level = preview_nsfw_level
+                if download_id and download_id in self._active_downloads and preview_targets:
+                    self._active_downloads[download_id]["preview_path"] = preview_targets[0]
 
             scanner = None
-            adjust_root: Optional[str] = None
-
-            # 5. Determine scanner and adjust metadata for cache consistency
             if model_type == "checkpoint":
                 scanner = await self._get_checkpoint_scanner()
-                logger.info(f"Updating checkpoint cache for {save_path}")
+                logger.info(f"Updating checkpoint cache for {actual_file_paths[0]}")
             elif model_type == "lora":
                 scanner = await self._get_lora_scanner()
-                logger.info(f"Updating lora cache for {save_path}")
+                logger.info(f"Updating lora cache for {actual_file_paths[0]}")
             elif model_type == "embedding":
                 scanner = await ServiceRegistry.get_embedding_scanner()
-                logger.info(f"Updating embedding cache for {save_path}")
+                logger.info(f"Updating embedding cache for {actual_file_paths[0]}")
 
-            if scanner is not None:
-                file_path_for_adjust = getattr(metadata, "file_path", save_path)
-                if isinstance(file_path_for_adjust, str):
-                    normalized_file_path = file_path_for_adjust.replace(os.sep, "/")
-                else:
-                    normalized_file_path = str(file_path_for_adjust)
+            adjust_cached_entry = (
+                getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None
+            )
 
-                find_root = getattr(scanner, "_find_root_for_file", None)
-                if callable(find_root):
-                    try:
-                        adjust_root = find_root(normalized_file_path)
-                    except TypeError:
-                        adjust_root = None
+            for index, entry in enumerate(metadata_entries):
+                file_path_for_adjust = getattr(entry, "file_path", actual_file_paths[index])
+                normalized_file_path = (
+                    file_path_for_adjust.replace(os.sep, "/")
+                    if isinstance(file_path_for_adjust, str)
+                    else str(file_path_for_adjust)
+                )
 
-                adjust_metadata = getattr(scanner, "adjust_metadata", None)
-                if callable(adjust_metadata):
-                    metadata = adjust_metadata(metadata, normalized_file_path, adjust_root)
+                if scanner is not None:
+                    find_root = getattr(scanner, "_find_root_for_file", None)
+                    adjust_root = None
+                    if callable(find_root):
+                        try:
+                            adjust_root = find_root(normalized_file_path)
+                        except TypeError:
+                            adjust_root = None
-            # 6. Persist metadata with any adjustments
-            await MetadataManager.save_metadata(save_path, metadata)
+                    adjust_metadata = getattr(scanner, "adjust_metadata", None)
+                    if callable(adjust_metadata):
+                        adjusted_entry = adjust_metadata(entry, normalized_file_path, adjust_root)
+                        if adjusted_entry is not None:
+                            entry = adjusted_entry
+                            metadata_entries[index] = entry
 
-            # Convert metadata to dictionary
-            metadata_dict = metadata.to_dict()
-            adjust_cached_entry = getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None
-            if callable(adjust_cached_entry):
-                metadata_dict = adjust_cached_entry(metadata_dict)
+                metadata_file_path = os.path.splitext(entry.file_path)[0] + '.metadata.json'
+                metadata_files_for_cleanup.append(metadata_file_path)
 
-            # Add model to cache and save to disk in a single operation
-            await scanner.add_model_to_cache(metadata_dict, relative_path)
+                await MetadataManager.save_metadata(entry.file_path, entry)
+
+                metadata_dict = entry.to_dict()
+                if callable(adjust_cached_entry):
+                    metadata_dict = adjust_cached_entry(metadata_dict)
+
+                if scanner is not None:
+                    await scanner.add_model_to_cache(metadata_dict, relative_path)
 
             # Report 100% completion
             if progress_callback:
                 await progress_callback(100)
 
-            return {
-                'success': True
-            }
+            return {'success': True}
         except Exception as e:
             logger.error(f"Error in _execute_download: {e}", exc_info=True)
-            # Clean up partial downloads except .part file
-            cleanup_files = [metadata_path]
-            if hasattr(metadata, 'preview_url') and metadata.preview_url and os.path.exists(metadata.preview_url):
-                cleanup_files.append(metadata.preview_url)
-
-            for path in cleanup_files:
+            cleanup_targets = {
+                path
+                for path in [save_path, metadata_path, *metadata_files_for_cleanup, *extracted_paths]
+                if path
+            }
+            preview_candidate = (
+                metadata_entries[0].preview_url
+                if metadata_entries
+                else getattr(metadata, "preview_url", None)
+            )
+            if preview_candidate:
+                cleanup_targets.add(preview_candidate)
+
+            cleanup_targets.update(preview_targets)
+            for path in cleanup_targets:
                 if path and os.path.exists(path):
                     try:
                         os.remove(path)
-                    except Exception as e:
-                        logger.warning(f"Failed to cleanup file {path}: {e}")
-
+                    except Exception as exc:
+                        logger.warning(f"Failed to cleanup file {path}: {exc}")
+
             return {'success': False, 'error': str(e)}
 
+    async def _extract_safetensors_from_archive(self, archive_path: str) -> List[str]:
+        if not zipfile.is_zipfile(archive_path):
+            return []
+
+        target_dir = os.path.dirname(archive_path)
+
+        def _extract_sync() -> List[str]:
+            extracted_files: List[str] = []
+            with zipfile.ZipFile(archive_path, "r") as archive:
+                for info in archive.infolist():
+                    if info.is_dir():
+                        continue
+                    if not info.filename.lower().endswith(".safetensors"):
+                        continue
+                    # Flattening to the basename also neutralises zip-slip style
+                    # member names that would otherwise escape target_dir.
+                    file_name = os.path.basename(info.filename)
+                    if not file_name:
+                        continue
+                    dest_path = self._resolve_extracted_destination(target_dir, file_name)
+                    with archive.open(info) as source, open(dest_path, "wb") as target:
+                        shutil.copyfileobj(source, target)
+                    extracted_files.append(dest_path)
+            return extracted_files
+
+        return await asyncio.to_thread(_extract_sync)
+
+    async def _build_metadata_entries(self, base_metadata, file_paths: List[str]) -> List:
+        if not file_paths:
+            return []
+
+        entries: List = []
+        for index, file_path in enumerate(file_paths):
+            # The first entry reuses the original metadata object; later files
+            # get deep copies so per-file fields do not alias each other.
+            entry = base_metadata if index == 0 else copy.deepcopy(base_metadata)
+            entry.update_file_info(file_path)
+            entry.sha256 = await calculate_sha256(file_path)
+            entries.append(entry)
+
+        return entries
+
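+    # For illustration: the helper below keeps extracted filenames unique by
+    # appending a numeric suffix, so a second "model.safetensors" landing in
+    # the same directory becomes "model-1.safetensors", a third "model-2.safetensors".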
+    def _resolve_extracted_destination(self, target_dir: str, filename: str) -> str:
+        base_name, extension = os.path.splitext(filename)
+        candidate = filename
+        destination = os.path.join(target_dir, candidate)
+        counter = 1
+
+        while os.path.exists(destination):
+            candidate = f"{base_name}-{counter}{extension}"
+            destination = os.path.join(target_dir, candidate)
+            counter += 1
+
+        return destination
+
+    def _distribute_preview_to_entries(self, preview_path: str, entries: List) -> List[str]:
+        if not preview_path or not entries:
+            return []
+
+        if not os.path.exists(preview_path):
+            return []
+
+        extension = os.path.splitext(preview_path)[1] or ".webp"
+
+        targets = [
+            os.path.splitext(entry.file_path)[0] + extension for entry in entries
+        ]
+
+        if not targets:
+            return []
+
+        first_target = targets[0]
+        if preview_path != first_target:
+            os.replace(preview_path, first_target)
+        # Assigned unconditionally: when the preview already sits at the first
+        # target path, it still serves as the copy source for the remaining entries.
+        source_path = first_target
+
+        for target in targets[1:]:
+            shutil.copyfile(source_path, target)
+
+        return targets
+
     async def _handle_download_progress(
         self,
         progress_update,
@@ -895,16 +1017,23 @@
             # Clean up ALL files including .part when user cancels
             download_info = self._active_downloads.get(download_id)
             if download_info:
-                # Delete the main file
-                if 'file_path' in download_info:
-                    file_path = download_info['file_path']
+                target_files = set()
+                primary_path = download_info.get('file_path')
+                if primary_path:
+                    target_files.add(primary_path)
+
+                for extra_path in download_info.get('extracted_paths', []):
+                    if extra_path:
+                        target_files.add(extra_path)
+
+                for file_path in target_files:
                     if os.path.exists(file_path):
                         try:
                             os.unlink(file_path)
                             logger.debug(f"Deleted cancelled download: {file_path}")
                         except Exception as e:
                             logger.error(f"Error deleting file: {e}")
-
+
                 # Delete the .part file (only on user cancellation)
                 if 'part_path' in download_info:
                     part_path = download_info['part_path']
@@ -914,10 +1043,9 @@
                             logger.debug(f"Deleted partial download: {part_path}")
                         except Exception as e:
                             logger.error(f"Error deleting part file: {e}")
-
-                # Delete metadata file if exists
-                if 'file_path' in download_info:
-                    file_path = download_info['file_path']
+
+                # Delete metadata files for each resolved path
+                for file_path in target_files:
                     metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
                     if os.path.exists(metadata_path):
                         try:
                             os.unlink(metadata_path)
                         except Exception as e:
                             logger.error(f"Error deleting metadata file: {e}")
 
-                preview_path_value = download_info.get('preview_path')
-                if preview_path_value and os.path.exists(preview_path_value):
-                    try:
-                        os.unlink(preview_path_value)
-                        logger.debug(f"Deleted preview file: {preview_path_value}")
-                    except Exception as e:
-                        logger.error(f"Error deleting preview file: {e}")
+                preview_path_value = download_info.get('preview_path')
+                if preview_path_value and os.path.exists(preview_path_value):
+                    try:
+                        os.unlink(preview_path_value)
+                        logger.debug(f"Deleted preview file: {preview_path_value}")
+                    except Exception as e:
+                        logger.error(f"Error deleting preview file {preview_path_value}: {e}")
 
-                # Delete preview file if exists (.webp or .mp4) for legacy paths
-                for preview_ext in ['.webp', '.mp4']:
-                    preview_path = os.path.splitext(file_path)[0] + preview_ext
-                    if os.path.exists(preview_path):
-                        try:
-                            os.unlink(preview_path)
-                            logger.debug(f"Deleted preview file: {preview_path}")
-                        except Exception as e:
-                            logger.error(f"Error deleting preview file: {e}")
-
+                # Delete preview file if exists (.webp or .mp4) for legacy paths
+                for file_path in target_files:
+                    for preview_ext in ['.webp', '.mp4']:
+                        preview_path = os.path.splitext(file_path)[0] + preview_ext
+                        if os.path.exists(preview_path):
+                            try:
+                                os.unlink(preview_path)
+                                logger.debug(f"Deleted preview file: {preview_path}")
+                            except Exception as e:
+                                logger.error(f"Error deleting preview file {preview_path}: {e}")
+
             return {'success': True, 'message': 'Download cancelled successfully'}
         except Exception as e:
             logger.error(f"Error cancelling download: {e}", exc_info=True)
diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py
index c8bbeed8..b21632d9 100644
--- a/tests/services/test_download_manager.py
+++ b/tests/services/test_download_manager.py
@@ -1,5 +1,6 @@
 import asyncio
 import os
+import zipfile
 from datetime import datetime
 from pathlib import Path
 from typing import Optional
@@ -525,6 +526,177 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
     assert cached_entry["model_type"] == "diffusion_model"
 
 
+async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):
+    manager = DownloadManager()
+    save_dir = tmp_path / "downloads"
+    save_dir.mkdir()
+    zip_path = save_dir / "bundle.zip"
+
+    class DummyMetadata:
+        def __init__(self, path: Path):
+            self.file_path = str(path)
+            self.sha256 = "sha256"
+            self.file_name = path.stem
+            self.preview_url = None
+
+        def generate_unique_filename(self, *_args, **_kwargs):
+            return os.path.basename(self.file_path)
+
+        def update_file_info(self, updated_path):
+            self.file_path = str(updated_path)
+            self.file_name = Path(updated_path).stem
+
+        def to_dict(self):
+            return {"file_path": self.file_path}
+
+    metadata = DummyMetadata(zip_path)
+    version_info = {"images": []}
+    download_urls = ["https://example.invalid/model.zip"]
+
+    class DummyDownloader:
+        async def download_file(self, *_args, **_kwargs):
+            with zipfile.ZipFile(str(zip_path), "w") as archive:
+                archive.writestr("inner/model.safetensors", b"model")
+                archive.writestr("docs/readme.txt", b"ignore")
+            return True, "ok"
+
+    monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
+    dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
+    monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
+    monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
+    hash_calculator = AsyncMock(return_value="hash-single")
+    monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
+
+    result = await manager._execute_download(
+        download_urls=download_urls,
+        save_dir=str(save_dir),
+        metadata=metadata,
+        version_info=version_info,
+        relative_path="",
+        progress_callback=None,
+        model_type="lora",
+        download_id=None,
+    )
+
+    assert result == {"success": True}
+    assert not zip_path.exists()
+    extracted = save_dir / "model.safetensors"
+    assert extracted.exists()
+    assert hash_calculator.await_args.args[0] == str(extracted)
+    saved_call = MetadataManager.save_metadata.await_args
+    assert saved_call.args[0] == str(extracted)
+    assert saved_call.args[1].sha256 == "hash-single"
+    assert dummy_scanner.add_model_to_cache.await_count == 1
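+
+
+# Suggested companion test (sketch): pins the numeric-suffix collision
+# behaviour of _resolve_extracted_destination using only filesystem state.
+def test_resolve_extracted_destination_appends_numeric_suffix(tmp_path):
+    manager = DownloadManager()
+    (tmp_path / "model.safetensors").write_bytes(b"occupied")
+    (tmp_path / "model-1.safetensors").write_bytes(b"occupied")
+
+    destination = manager._resolve_extracted_destination(str(tmp_path), "model.safetensors")
+
+    assert destination == str(tmp_path / "model-2.safetensors")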
+
+
+async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_path):
+    manager = DownloadManager()
+    save_dir = tmp_path / "downloads"
+    save_dir.mkdir()
+    zip_path = save_dir / "bundle.zip"
+
+    class DummyMetadata:
+        def __init__(self, path: Path):
+            self.file_path = str(path)
+            self.sha256 = "sha256"
+            self.file_name = path.stem
+            self.preview_url = None
+
+        def generate_unique_filename(self, *_args, **_kwargs):
+            return os.path.basename(self.file_path)
+
+        def update_file_info(self, updated_path):
+            self.file_path = str(updated_path)
+            self.file_name = Path(updated_path).stem
+
+        def to_dict(self):
+            return {"file_path": self.file_path}
+
+    metadata = DummyMetadata(zip_path)
+    version_info = {"images": []}
+    download_urls = ["https://example.invalid/model.zip"]
+
+    class DummyDownloader:
+        async def download_file(self, *_args, **_kwargs):
+            with zipfile.ZipFile(str(zip_path), "w") as archive:
+                archive.writestr("first/model-one.safetensors", b"one")
+                archive.writestr("second/model-two.safetensors", b"two")
+                archive.writestr("readme.md", b"ignore")
+            return True, "ok"
+
+    monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
+    dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
+    monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
+    monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
+    hash_calculator = AsyncMock(side_effect=["hash-one", "hash-two"])
+    monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
+
+    result = await manager._execute_download(
+        download_urls=download_urls,
+        save_dir=str(save_dir),
+        metadata=metadata,
+        version_info=version_info,
+        relative_path="",
+        progress_callback=None,
+        model_type="lora",
+        download_id=None,
+    )
+
+    assert result == {"success": True}
+    assert not zip_path.exists()
+    extracted_one = save_dir / "model-one.safetensors"
+    extracted_two = save_dir / "model-two.safetensors"
+    assert extracted_one.exists()
+    assert extracted_two.exists()
+
+    assert hash_calculator.await_count == 2
+    assert MetadataManager.save_metadata.await_count == 2
+    assert dummy_scanner.add_model_to_cache.await_count == 2
+
+    metadata_calls = MetadataManager.save_metadata.await_args_list
+    assert metadata_calls[0].args[0] == str(extracted_one)
+    assert metadata_calls[0].args[1].sha256 == "hash-one"
+    assert metadata_calls[1].args[0] == str(extracted_two)
+    assert metadata_calls[1].args[1].sha256 == "hash-two"
+
+
+def test_distribute_preview_to_entries_moves_and_copies(tmp_path):
+    manager = DownloadManager()
+    preview_file = tmp_path / "bundle.webp"
+    preview_file.write_bytes(b"image-data")
+
+    entries = [
+        SimpleNamespace(file_path=str(tmp_path / "model-one.safetensors")),
+        SimpleNamespace(file_path=str(tmp_path / "model-two.safetensors")),
+    ]
+
+    targets = manager._distribute_preview_to_entries(str(preview_file), entries)
+
+    assert targets == [
+        str(tmp_path / "model-one.webp"),
+        str(tmp_path / "model-two.webp"),
+    ]
+    assert not preview_file.exists()
+    assert Path(targets[0]).read_bytes() == b"image-data"
+    assert Path(targets[1]).read_bytes() == b"image-data"
+
+
+def test_distribute_preview_to_entries_keeps_existing_file(tmp_path):
+    manager = DownloadManager()
+    existing_preview = tmp_path / "model-one.webp"
+    existing_preview.write_bytes(b"preview")
+
+    entries = [
+        SimpleNamespace(file_path=str(tmp_path / "model-one.safetensors")),
+        SimpleNamespace(file_path=str(tmp_path / "model-two.safetensors")),
+    ]
+
+    targets = manager._distribute_preview_to_entries(str(existing_preview), entries)
+
+    assert targets[0] == str(existing_preview)
+    assert Path(targets[1]).read_bytes() == b"preview"
+
+
 async def test_pause_download_updates_state():
     manager = DownloadManager()