mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(download): support multiple model file extensions in archive extraction
- Add `_get_supported_extensions_for_type` method to return allowed extensions per model type - Rename `_extract_safetensors_from_archive` to `_extract_model_files_from_archive` and extend to filter by allowed extensions - Update error message to list supported extensions when archive contains no valid files - Add test for extracting .pt embedding files from zip archives
This commit is contained in:
@@ -7,7 +7,7 @@ import shutil
|
||||
import zipfile
|
||||
from collections import OrderedDict
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from urllib.parse import urlparse
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
|
||||
@@ -774,9 +774,16 @@ class DownloadManager:
|
||||
# 4. Handle archive extraction and prepare per-file metadata
|
||||
actual_file_paths = [save_path]
|
||||
if zipfile.is_zipfile(save_path):
|
||||
extracted_paths = await self._extract_safetensors_from_archive(save_path)
|
||||
supported_extensions = self._get_supported_extensions_for_type(model_type)
|
||||
extracted_paths = await self._extract_model_files_from_archive(
|
||||
save_path, supported_extensions
|
||||
)
|
||||
if not extracted_paths:
|
||||
return {'success': False, 'error': 'Zip archive does not contain any safetensors files'}
|
||||
supported_text = ", ".join(sorted(supported_extensions))
|
||||
return {
|
||||
'success': False,
|
||||
'error': f'Zip archive does not contain any supported model files ({supported_text})',
|
||||
}
|
||||
actual_file_paths = extracted_paths
|
||||
try:
|
||||
os.remove(save_path)
|
||||
@@ -877,11 +884,23 @@ class DownloadManager:
|
||||
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def _extract_safetensors_from_archive(self, archive_path: str) -> List[str]:
|
||||
def _get_supported_extensions_for_type(self, model_type: str) -> Set[str]:
|
||||
if model_type == "checkpoint":
|
||||
return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
|
||||
if model_type == "embedding":
|
||||
return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||
return {'.safetensors'}
|
||||
|
||||
async def _extract_model_files_from_archive(
|
||||
self,
|
||||
archive_path: str,
|
||||
allowed_extensions: Optional[Set[str]] = None,
|
||||
) -> List[str]:
|
||||
if not zipfile.is_zipfile(archive_path):
|
||||
return []
|
||||
|
||||
target_dir = os.path.dirname(archive_path)
|
||||
normalized_extensions = {ext.lower() for ext in allowed_extensions or {'.safetensors'}}
|
||||
|
||||
def _extract_sync() -> List[str]:
|
||||
extracted_files: List[str] = []
|
||||
@@ -889,7 +908,8 @@ class DownloadManager:
|
||||
for info in archive.infolist():
|
||||
if info.is_dir():
|
||||
continue
|
||||
if not info.filename.lower().endswith(".safetensors"):
|
||||
extension = os.path.splitext(info.filename)[1].lower()
|
||||
if extension not in normalized_extensions:
|
||||
continue
|
||||
file_name = os.path.basename(info.filename)
|
||||
if not file_name:
|
||||
|
||||
Reference in New Issue
Block a user