From 50004789912b9dbb08ddc18900b4dfd96a13cf78 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sun, 7 Dec 2025 09:00:47 +0800 Subject: [PATCH] feat(download): support multiple model file extensions in archive extraction - Add `_get_supported_extensions_for_type` method to return allowed extensions per model type - Rename `_extract_safetensors_from_archive` to `_extract_model_files_from_archive` and extend to filter by allowed extensions - Update error message to list supported extensions when archive contains no valid files - Add test for extracting .pt embedding files from zip archives --- py/services/download_manager.py | 30 ++++++++++-- tests/services/test_download_manager.py | 63 +++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 5 deletions(-) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index f497c0b6..f4b10b89 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -7,7 +7,7 @@ import shutil import zipfile from collections import OrderedDict import uuid -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple from urllib.parse import urlparse from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES @@ -774,9 +774,16 @@ class DownloadManager: # 4. Handle archive extraction and prepare per-file metadata actual_file_paths = [save_path] if zipfile.is_zipfile(save_path): - extracted_paths = await self._extract_safetensors_from_archive(save_path) + supported_extensions = self._get_supported_extensions_for_type(model_type) + extracted_paths = await self._extract_model_files_from_archive( + save_path, supported_extensions + ) if not extracted_paths: - return {'success': False, 'error': 'Zip archive does not contain any safetensors files'} + supported_text = ", ".join(sorted(supported_extensions)) + return { + 'success': False, + 'error': f'Zip archive does not contain any supported model files ({supported_text})', + } actual_file_paths = extracted_paths try: os.remove(save_path) @@ -877,11 +884,23 @@ class DownloadManager: return {'success': False, 'error': str(e)} - async def _extract_safetensors_from_archive(self, archive_path: str) -> List[str]: + def _get_supported_extensions_for_type(self, model_type: str) -> Set[str]: + if model_type == "checkpoint": + return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'} + if model_type == "embedding": + return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} + return {'.safetensors'} + + async def _extract_model_files_from_archive( + self, + archive_path: str, + allowed_extensions: Optional[Set[str]] = None, + ) -> List[str]: if not zipfile.is_zipfile(archive_path): return [] target_dir = os.path.dirname(archive_path) + normalized_extensions = {ext.lower() for ext in allowed_extensions or {'.safetensors'}} def _extract_sync() -> List[str]: extracted_files: List[str] = [] @@ -889,7 +908,8 @@ class DownloadManager: for info in archive.infolist(): if info.is_dir(): continue - if not info.filename.lower().endswith(".safetensors"): + extension = os.path.splitext(info.filename)[1].lower() + if extension not in normalized_extensions: continue file_name = os.path.basename(info.filename) if not file_name: diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py index b21632d9..669f8bf8 100644 --- a/tests/services/test_download_manager.py +++ b/tests/services/test_download_manager.py @@ -660,6 +660,69 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa assert metadata_calls[1].args[1].sha256 == "hash-two" +async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + zip_path = save_dir / "bundle.zip" + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + + def generate_unique_filename(self, *_args, **_kwargs): + return os.path.basename(self.file_path) + + def update_file_info(self, updated_path): + self.file_path = str(updated_path) + self.file_name = Path(updated_path).stem + + def to_dict(self): + return {"file_path": self.file_path} + + metadata = DummyMetadata(zip_path) + version_info = {"images": []} + download_urls = ["https://example.invalid/model.zip"] + + class DummyDownloader: + async def download_file(self, *_args, **_kwargs): + with zipfile.ZipFile(str(zip_path), "w") as archive: + archive.writestr("inner/embedding.pt", b"embedding") + archive.writestr("docs/readme.txt", b"ignore") + return True, "ok" + + monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())) + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) + monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)) + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + hash_calculator = AsyncMock(return_value="hash-pt") + monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator) + + result = await manager._execute_download( + download_urls=download_urls, + save_dir=str(save_dir), + metadata=metadata, + version_info=version_info, + relative_path="", + progress_callback=None, + model_type="embedding", + download_id=None, + ) + + assert result == {"success": True} + assert not zip_path.exists() + extracted = save_dir / "embedding.pt" + assert extracted.exists() + assert hash_calculator.await_args.args[0] == str(extracted) + saved_call = MetadataManager.save_metadata.await_args + assert saved_call.args[0] == str(extracted) + assert saved_call.args[1].sha256 == "hash-pt" + assert dummy_scanner.add_model_to_cache.await_count == 1 + + def test_distribute_preview_to_entries_moves_and_copies(tmp_path): manager = DownloadManager() preview_file = tmp_path / "bundle.webp"