feat(download): support multiple model file extensions in archive extraction

- Add `_get_supported_extensions_for_type` method to return allowed extensions per model type
- Rename `_extract_safetensors_from_archive` to `_extract_model_files_from_archive` and extend to filter by allowed extensions
- Update error message to list supported extensions when archive contains no valid files
- Add test for extracting .pt embedding files from zip archives
This commit is contained in:
Will Miao
2025-12-07 09:00:47 +08:00
parent 40cd2e23ac
commit 5000478991
2 changed files with 88 additions and 5 deletions

View File

@@ -7,7 +7,7 @@ import shutil
import zipfile
from collections import OrderedDict
import uuid
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
@@ -774,9 +774,16 @@ class DownloadManager:
# 4. Handle archive extraction and prepare per-file metadata
actual_file_paths = [save_path]
if zipfile.is_zipfile(save_path):
extracted_paths = await self._extract_safetensors_from_archive(save_path)
supported_extensions = self._get_supported_extensions_for_type(model_type)
extracted_paths = await self._extract_model_files_from_archive(
save_path, supported_extensions
)
if not extracted_paths:
return {'success': False, 'error': 'Zip archive does not contain any safetensors files'}
supported_text = ", ".join(sorted(supported_extensions))
return {
'success': False,
'error': f'Zip archive does not contain any supported model files ({supported_text})',
}
actual_file_paths = extracted_paths
try:
os.remove(save_path)
@@ -877,11 +884,23 @@ class DownloadManager:
return {'success': False, 'error': str(e)}
async def _extract_safetensors_from_archive(self, archive_path: str) -> List[str]:
def _get_supported_extensions_for_type(self, model_type: str) -> Set[str]:
if model_type == "checkpoint":
return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
if model_type == "embedding":
return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
return {'.safetensors'}
async def _extract_model_files_from_archive(
self,
archive_path: str,
allowed_extensions: Optional[Set[str]] = None,
) -> List[str]:
if not zipfile.is_zipfile(archive_path):
return []
target_dir = os.path.dirname(archive_path)
normalized_extensions = {ext.lower() for ext in allowed_extensions or {'.safetensors'}}
def _extract_sync() -> List[str]:
extracted_files: List[str] = []
@@ -889,7 +908,8 @@ class DownloadManager:
for info in archive.infolist():
if info.is_dir():
continue
if not info.filename.lower().endswith(".safetensors"):
extension = os.path.splitext(info.filename)[1].lower()
if extension not in normalized_extensions:
continue
file_name = os.path.basename(info.filename)
if not file_name:

View File

@@ -660,6 +660,69 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa
assert metadata_calls[1].args[1].sha256 == "hash-two"
async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
zip_path = save_dir / "bundle.zip"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, updated_path):
self.file_path = str(updated_path)
self.file_name = Path(updated_path).stem
def to_dict(self):
return {"file_path": self.file_path}
metadata = DummyMetadata(zip_path)
version_info = {"images": []}
download_urls = ["https://example.invalid/model.zip"]
class DummyDownloader:
async def download_file(self, *_args, **_kwargs):
with zipfile.ZipFile(str(zip_path), "w") as archive:
archive.writestr("inner/embedding.pt", b"embedding")
archive.writestr("docs/readme.txt", b"ignore")
return True, "ok"
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner))
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
hash_calculator = AsyncMock(return_value="hash-pt")
monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
result = await manager._execute_download(
download_urls=download_urls,
save_dir=str(save_dir),
metadata=metadata,
version_info=version_info,
relative_path="",
progress_callback=None,
model_type="embedding",
download_id=None,
)
assert result == {"success": True}
assert not zip_path.exists()
extracted = save_dir / "embedding.pt"
assert extracted.exists()
assert hash_calculator.await_args.args[0] == str(extracted)
saved_call = MetadataManager.save_metadata.await_args
assert saved_call.args[0] == str(extracted)
assert saved_call.args[1].sha256 == "hash-pt"
assert dummy_scanner.add_model_to_cache.await_count == 1
def test_distribute_preview_to_entries_moves_and_copies(tmp_path):
manager = DownloadManager()
preview_file = tmp_path / "bundle.webp"