feat(download): support multiple model file extensions in archive extraction

- Add `_get_supported_extensions_for_type` method to return allowed extensions per model type
- Rename `_extract_safetensors_from_archive` to `_extract_model_files_from_archive` and extend to filter by allowed extensions
- Update error message to list supported extensions when archive contains no valid files
- Add test for extracting .pt embedding files from zip archives
This commit is contained in:
Will Miao
2025-12-07 09:00:47 +08:00
parent 40cd2e23ac
commit 5000478991
2 changed files with 88 additions and 5 deletions

View File

@@ -7,7 +7,7 @@ import shutil
import zipfile
from collections import OrderedDict
import uuid
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
@@ -774,9 +774,16 @@ class DownloadManager:
# 4. Handle archive extraction and prepare per-file metadata
actual_file_paths = [save_path]
if zipfile.is_zipfile(save_path):
extracted_paths = await self._extract_safetensors_from_archive(save_path)
supported_extensions = self._get_supported_extensions_for_type(model_type)
extracted_paths = await self._extract_model_files_from_archive(
save_path, supported_extensions
)
if not extracted_paths:
return {'success': False, 'error': 'Zip archive does not contain any safetensors files'}
supported_text = ", ".join(sorted(supported_extensions))
return {
'success': False,
'error': f'Zip archive does not contain any supported model files ({supported_text})',
}
actual_file_paths = extracted_paths
try:
os.remove(save_path)
@@ -877,11 +884,23 @@ class DownloadManager:
return {'success': False, 'error': str(e)}
async def _extract_safetensors_from_archive(self, archive_path: str) -> List[str]:
def _get_supported_extensions_for_type(self, model_type: str) -> Set[str]:
if model_type == "checkpoint":
return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
if model_type == "embedding":
return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
return {'.safetensors'}
async def _extract_model_files_from_archive(
self,
archive_path: str,
allowed_extensions: Optional[Set[str]] = None,
) -> List[str]:
if not zipfile.is_zipfile(archive_path):
return []
target_dir = os.path.dirname(archive_path)
normalized_extensions = {ext.lower() for ext in allowed_extensions or {'.safetensors'}}
def _extract_sync() -> List[str]:
extracted_files: List[str] = []
@@ -889,7 +908,8 @@ class DownloadManager:
for info in archive.infolist():
if info.is_dir():
continue
if not info.filename.lower().endswith(".safetensors"):
extension = os.path.splitext(info.filename)[1].lower()
if extension not in normalized_extensions:
continue
file_name = os.path.basename(info.filename)
if not file_name: