mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
feat(downloads): support safetensors zips and previews
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@@ -525,6 +526,177 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
||||
assert cached_entry["model_type"] == "diffusion_model"
|
||||
|
||||
|
||||
async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
zip_path = save_dir / "bundle.zip"
|
||||
|
||||
class DummyMetadata:
|
||||
def __init__(self, path: Path):
|
||||
self.file_path = str(path)
|
||||
self.sha256 = "sha256"
|
||||
self.file_name = path.stem
|
||||
self.preview_url = None
|
||||
|
||||
def generate_unique_filename(self, *_args, **_kwargs):
|
||||
return os.path.basename(self.file_path)
|
||||
|
||||
def update_file_info(self, updated_path):
|
||||
self.file_path = str(updated_path)
|
||||
self.file_name = Path(updated_path).stem
|
||||
|
||||
def to_dict(self):
|
||||
return {"file_path": self.file_path}
|
||||
|
||||
metadata = DummyMetadata(zip_path)
|
||||
version_info = {"images": []}
|
||||
download_urls = ["https://example.invalid/model.zip"]
|
||||
|
||||
class DummyDownloader:
|
||||
async def download_file(self, *_args, **_kwargs):
|
||||
with zipfile.ZipFile(str(zip_path), "w") as archive:
|
||||
archive.writestr("inner/model.safetensors", b"model")
|
||||
archive.writestr("docs/readme.txt", b"ignore")
|
||||
return True, "ok"
|
||||
|
||||
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
|
||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
|
||||
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||
hash_calculator = AsyncMock(return_value="hash-single")
|
||||
monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
|
||||
|
||||
result = await manager._execute_download(
|
||||
download_urls=download_urls,
|
||||
save_dir=str(save_dir),
|
||||
metadata=metadata,
|
||||
version_info=version_info,
|
||||
relative_path="",
|
||||
progress_callback=None,
|
||||
model_type="lora",
|
||||
download_id=None,
|
||||
)
|
||||
|
||||
assert result == {"success": True}
|
||||
assert not zip_path.exists()
|
||||
extracted = save_dir / "model.safetensors"
|
||||
assert extracted.exists()
|
||||
assert hash_calculator.await_args.args[0] == str(extracted)
|
||||
saved_call = MetadataManager.save_metadata.await_args
|
||||
assert saved_call.args[0] == str(extracted)
|
||||
assert saved_call.args[1].sha256 == "hash-single"
|
||||
assert dummy_scanner.add_model_to_cache.await_count == 1
|
||||
|
||||
|
||||
async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
zip_path = save_dir / "bundle.zip"
|
||||
|
||||
class DummyMetadata:
|
||||
def __init__(self, path: Path):
|
||||
self.file_path = str(path)
|
||||
self.sha256 = "sha256"
|
||||
self.file_name = path.stem
|
||||
self.preview_url = None
|
||||
|
||||
def generate_unique_filename(self, *_args, **_kwargs):
|
||||
return os.path.basename(self.file_path)
|
||||
|
||||
def update_file_info(self, updated_path):
|
||||
self.file_path = str(updated_path)
|
||||
self.file_name = Path(updated_path).stem
|
||||
|
||||
def to_dict(self):
|
||||
return {"file_path": self.file_path}
|
||||
|
||||
metadata = DummyMetadata(zip_path)
|
||||
version_info = {"images": []}
|
||||
download_urls = ["https://example.invalid/model.zip"]
|
||||
|
||||
class DummyDownloader:
|
||||
async def download_file(self, *_args, **_kwargs):
|
||||
with zipfile.ZipFile(str(zip_path), "w") as archive:
|
||||
archive.writestr("first/model-one.safetensors", b"one")
|
||||
archive.writestr("second/model-two.safetensors", b"two")
|
||||
archive.writestr("readme.md", b"ignore")
|
||||
return True, "ok"
|
||||
|
||||
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
|
||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
|
||||
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||
hash_calculator = AsyncMock(side_effect=["hash-one", "hash-two"])
|
||||
monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
|
||||
|
||||
result = await manager._execute_download(
|
||||
download_urls=download_urls,
|
||||
save_dir=str(save_dir),
|
||||
metadata=metadata,
|
||||
version_info=version_info,
|
||||
relative_path="",
|
||||
progress_callback=None,
|
||||
model_type="lora",
|
||||
download_id=None,
|
||||
)
|
||||
|
||||
assert result == {"success": True}
|
||||
assert not zip_path.exists()
|
||||
extracted_one = save_dir / "model-one.safetensors"
|
||||
extracted_two = save_dir / "model-two.safetensors"
|
||||
assert extracted_one.exists()
|
||||
assert extracted_two.exists()
|
||||
|
||||
assert hash_calculator.await_count == 2
|
||||
assert MetadataManager.save_metadata.await_count == 2
|
||||
assert dummy_scanner.add_model_to_cache.await_count == 2
|
||||
|
||||
metadata_calls = MetadataManager.save_metadata.await_args_list
|
||||
assert metadata_calls[0].args[0] == str(extracted_one)
|
||||
assert metadata_calls[0].args[1].sha256 == "hash-one"
|
||||
assert metadata_calls[1].args[0] == str(extracted_two)
|
||||
assert metadata_calls[1].args[1].sha256 == "hash-two"
|
||||
|
||||
|
||||
def test_distribute_preview_to_entries_moves_and_copies(tmp_path):
|
||||
manager = DownloadManager()
|
||||
preview_file = tmp_path / "bundle.webp"
|
||||
preview_file.write_bytes(b"image-data")
|
||||
|
||||
entries = [
|
||||
SimpleNamespace(file_path=str(tmp_path / "model-one.safetensors")),
|
||||
SimpleNamespace(file_path=str(tmp_path / "model-two.safetensors")),
|
||||
]
|
||||
|
||||
targets = manager._distribute_preview_to_entries(str(preview_file), entries)
|
||||
|
||||
assert targets == [
|
||||
str(tmp_path / "model-one.webp"),
|
||||
str(tmp_path / "model-two.webp"),
|
||||
]
|
||||
assert not preview_file.exists()
|
||||
assert Path(targets[0]).read_bytes() == b"image-data"
|
||||
assert Path(targets[1]).read_bytes() == b"image-data"
|
||||
|
||||
|
||||
def test_distribute_preview_to_entries_keeps_existing_file(tmp_path):
|
||||
manager = DownloadManager()
|
||||
existing_preview = tmp_path / "model-one.webp"
|
||||
existing_preview.write_bytes(b"preview")
|
||||
|
||||
entries = [
|
||||
SimpleNamespace(file_path=str(tmp_path / "model-one.safetensors")),
|
||||
SimpleNamespace(file_path=str(tmp_path / "model-two.safetensors")),
|
||||
]
|
||||
|
||||
targets = manager._distribute_preview_to_entries(str(existing_preview), entries)
|
||||
|
||||
assert targets[0] == str(existing_preview)
|
||||
assert Path(targets[1]).read_bytes() == b"preview"
|
||||
|
||||
|
||||
async def test_pause_download_updates_state():
|
||||
manager = DownloadManager()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user