mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat(download): support multiple model file extensions in archive extraction
- Add `_get_supported_extensions_for_type` method to return allowed extensions per model type - Rename `_extract_safetensors_from_archive` to `_extract_model_files_from_archive` and extend to filter by allowed extensions - Update error message to list supported extensions when archive contains no valid files - Add test for extracting .pt embedding files from zip archives
This commit is contained in:
@@ -7,7 +7,7 @@ import shutil
|
|||||||
import zipfile
|
import zipfile
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
|
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
|
||||||
@@ -774,9 +774,16 @@ class DownloadManager:
|
|||||||
# 4. Handle archive extraction and prepare per-file metadata
|
# 4. Handle archive extraction and prepare per-file metadata
|
||||||
actual_file_paths = [save_path]
|
actual_file_paths = [save_path]
|
||||||
if zipfile.is_zipfile(save_path):
|
if zipfile.is_zipfile(save_path):
|
||||||
extracted_paths = await self._extract_safetensors_from_archive(save_path)
|
supported_extensions = self._get_supported_extensions_for_type(model_type)
|
||||||
|
extracted_paths = await self._extract_model_files_from_archive(
|
||||||
|
save_path, supported_extensions
|
||||||
|
)
|
||||||
if not extracted_paths:
|
if not extracted_paths:
|
||||||
return {'success': False, 'error': 'Zip archive does not contain any safetensors files'}
|
supported_text = ", ".join(sorted(supported_extensions))
|
||||||
|
return {
|
||||||
|
'success': False,
|
||||||
|
'error': f'Zip archive does not contain any supported model files ({supported_text})',
|
||||||
|
}
|
||||||
actual_file_paths = extracted_paths
|
actual_file_paths = extracted_paths
|
||||||
try:
|
try:
|
||||||
os.remove(save_path)
|
os.remove(save_path)
|
||||||
@@ -877,11 +884,23 @@ class DownloadManager:
|
|||||||
|
|
||||||
return {'success': False, 'error': str(e)}
|
return {'success': False, 'error': str(e)}
|
||||||
|
|
||||||
async def _extract_safetensors_from_archive(self, archive_path: str) -> List[str]:
|
def _get_supported_extensions_for_type(self, model_type: str) -> Set[str]:
|
||||||
|
if model_type == "checkpoint":
|
||||||
|
return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
|
||||||
|
if model_type == "embedding":
|
||||||
|
return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||||
|
return {'.safetensors'}
|
||||||
|
|
||||||
|
async def _extract_model_files_from_archive(
|
||||||
|
self,
|
||||||
|
archive_path: str,
|
||||||
|
allowed_extensions: Optional[Set[str]] = None,
|
||||||
|
) -> List[str]:
|
||||||
if not zipfile.is_zipfile(archive_path):
|
if not zipfile.is_zipfile(archive_path):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
target_dir = os.path.dirname(archive_path)
|
target_dir = os.path.dirname(archive_path)
|
||||||
|
normalized_extensions = {ext.lower() for ext in allowed_extensions or {'.safetensors'}}
|
||||||
|
|
||||||
def _extract_sync() -> List[str]:
|
def _extract_sync() -> List[str]:
|
||||||
extracted_files: List[str] = []
|
extracted_files: List[str] = []
|
||||||
@@ -889,7 +908,8 @@ class DownloadManager:
|
|||||||
for info in archive.infolist():
|
for info in archive.infolist():
|
||||||
if info.is_dir():
|
if info.is_dir():
|
||||||
continue
|
continue
|
||||||
if not info.filename.lower().endswith(".safetensors"):
|
extension = os.path.splitext(info.filename)[1].lower()
|
||||||
|
if extension not in normalized_extensions:
|
||||||
continue
|
continue
|
||||||
file_name = os.path.basename(info.filename)
|
file_name = os.path.basename(info.filename)
|
||||||
if not file_name:
|
if not file_name:
|
||||||
|
|||||||
@@ -660,6 +660,69 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa
|
|||||||
assert metadata_calls[1].args[1].sha256 == "hash-two"
|
assert metadata_calls[1].args[1].sha256 == "hash-two"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
zip_path = save_dir / "bundle.zip"
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
return os.path.basename(self.file_path)
|
||||||
|
|
||||||
|
def update_file_info(self, updated_path):
|
||||||
|
self.file_path = str(updated_path)
|
||||||
|
self.file_name = Path(updated_path).stem
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {"file_path": self.file_path}
|
||||||
|
|
||||||
|
metadata = DummyMetadata(zip_path)
|
||||||
|
version_info = {"images": []}
|
||||||
|
download_urls = ["https://example.invalid/model.zip"]
|
||||||
|
|
||||||
|
class DummyDownloader:
|
||||||
|
async def download_file(self, *_args, **_kwargs):
|
||||||
|
with zipfile.ZipFile(str(zip_path), "w") as archive:
|
||||||
|
archive.writestr("inner/embedding.pt", b"embedding")
|
||||||
|
archive.writestr("docs/readme.txt", b"ignore")
|
||||||
|
return True, "ok"
|
||||||
|
|
||||||
|
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
|
||||||
|
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||||
|
monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner))
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||||
|
hash_calculator = AsyncMock(return_value="hash-pt")
|
||||||
|
monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
|
||||||
|
|
||||||
|
result = await manager._execute_download(
|
||||||
|
download_urls=download_urls,
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=metadata,
|
||||||
|
version_info=version_info,
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="embedding",
|
||||||
|
download_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
assert not zip_path.exists()
|
||||||
|
extracted = save_dir / "embedding.pt"
|
||||||
|
assert extracted.exists()
|
||||||
|
assert hash_calculator.await_args.args[0] == str(extracted)
|
||||||
|
saved_call = MetadataManager.save_metadata.await_args
|
||||||
|
assert saved_call.args[0] == str(extracted)
|
||||||
|
assert saved_call.args[1].sha256 == "hash-pt"
|
||||||
|
assert dummy_scanner.add_model_to_cache.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
def test_distribute_preview_to_entries_moves_and_copies(tmp_path):
|
def test_distribute_preview_to_entries_moves_and_copies(tmp_path):
|
||||||
manager = DownloadManager()
|
manager = DownloadManager()
|
||||||
preview_file = tmp_path / "bundle.webp"
|
preview_file = tmp_path / "bundle.webp"
|
||||||
|
|||||||
Reference in New Issue
Block a user