mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #691 from willmiao/feat/zip-preview
feat(downloads): support safetensors zips and previews
This commit is contained in:
@@ -1,7 +1,10 @@
|
|||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
import shutil
|
||||||
|
import zipfile
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
@@ -12,6 +15,7 @@ from ..utils.civitai_utils import rewrite_preview_url
|
|||||||
from ..utils.preview_selection import select_preview_media
|
from ..utils.preview_selection import select_preview_media
|
||||||
from ..utils.utils import sanitize_folder_name
|
from ..utils.utils import sanitize_folder_name
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
|
from ..utils.file_utils import calculate_sha256
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .settings_manager import get_settings_manager
|
from .settings_manager import get_settings_manager
|
||||||
@@ -556,6 +560,13 @@ class DownloadManager:
|
|||||||
download_id: str = None,
|
download_id: str = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Execute the actual download process including preview images and model files"""
|
"""Execute the actual download process including preview images and model files"""
|
||||||
|
metadata_entries: List = []
|
||||||
|
metadata_files_for_cleanup: List[str] = []
|
||||||
|
extracted_paths: List[str] = []
|
||||||
|
metadata_path = ""
|
||||||
|
preview_targets: List[str] = []
|
||||||
|
preview_path: str | None = None
|
||||||
|
preview_nsfw_level = 0
|
||||||
try:
|
try:
|
||||||
# Extract original filename details
|
# Extract original filename details
|
||||||
original_filename = os.path.basename(metadata.file_path)
|
original_filename = os.path.basename(metadata.file_path)
|
||||||
@@ -699,10 +710,9 @@ class DownloadManager:
|
|||||||
logger.warning(f"Failed to delete temp file: {e}")
|
logger.warning(f"Failed to delete temp file: {e}")
|
||||||
|
|
||||||
if preview_downloaded and preview_path:
|
if preview_downloaded and preview_path:
|
||||||
|
preview_nsfw_level = nsfw_level
|
||||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||||
metadata.preview_nsfw_level = nsfw_level
|
metadata.preview_nsfw_level = nsfw_level
|
||||||
if download_id and download_id in self._active_downloads:
|
|
||||||
self._active_downloads[download_id]['preview_path'] = preview_path
|
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(3) # 3% progress after preview download
|
await progress_callback(3) # 3% progress after preview download
|
||||||
@@ -761,77 +771,189 @@ class DownloadManager:
|
|||||||
|
|
||||||
return {'success': False, 'error': last_error or 'Failed to download file'}
|
return {'success': False, 'error': last_error or 'Failed to download file'}
|
||||||
|
|
||||||
# 4. Update file information (size and modified time)
|
# 4. Handle archive extraction and prepare per-file metadata
|
||||||
metadata.update_file_info(save_path)
|
actual_file_paths = [save_path]
|
||||||
|
if zipfile.is_zipfile(save_path):
|
||||||
|
extracted_paths = await self._extract_safetensors_from_archive(save_path)
|
||||||
|
if not extracted_paths:
|
||||||
|
return {'success': False, 'error': 'Zip archive does not contain any safetensors files'}
|
||||||
|
actual_file_paths = extracted_paths
|
||||||
|
try:
|
||||||
|
os.remove(save_path)
|
||||||
|
except OSError as exc:
|
||||||
|
logger.warning(f"Unable to delete temporary archive {save_path}: {exc}")
|
||||||
|
if download_id and download_id in self._active_downloads:
|
||||||
|
self._active_downloads[download_id]['file_path'] = extracted_paths[0]
|
||||||
|
self._active_downloads[download_id]['extracted_paths'] = extracted_paths
|
||||||
|
|
||||||
|
metadata_entries = await self._build_metadata_entries(metadata, actual_file_paths)
|
||||||
|
if preview_path:
|
||||||
|
preview_targets = self._distribute_preview_to_entries(preview_path, metadata_entries)
|
||||||
|
for entry, target in zip(metadata_entries, preview_targets):
|
||||||
|
entry.preview_url = target.replace(os.sep, "/")
|
||||||
|
entry.preview_nsfw_level = preview_nsfw_level
|
||||||
|
if download_id and download_id in self._active_downloads and preview_targets:
|
||||||
|
self._active_downloads[download_id]["preview_path"] = preview_targets[0]
|
||||||
|
|
||||||
scanner = None
|
scanner = None
|
||||||
adjust_root: Optional[str] = None
|
|
||||||
|
|
||||||
# 5. Determine scanner and adjust metadata for cache consistency
|
|
||||||
if model_type == "checkpoint":
|
if model_type == "checkpoint":
|
||||||
scanner = await self._get_checkpoint_scanner()
|
scanner = await self._get_checkpoint_scanner()
|
||||||
logger.info(f"Updating checkpoint cache for {save_path}")
|
logger.info(f"Updating checkpoint cache for {actual_file_paths[0]}")
|
||||||
elif model_type == "lora":
|
elif model_type == "lora":
|
||||||
scanner = await self._get_lora_scanner()
|
scanner = await self._get_lora_scanner()
|
||||||
logger.info(f"Updating lora cache for {save_path}")
|
logger.info(f"Updating lora cache for {actual_file_paths[0]}")
|
||||||
elif model_type == "embedding":
|
elif model_type == "embedding":
|
||||||
scanner = await ServiceRegistry.get_embedding_scanner()
|
scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
logger.info(f"Updating embedding cache for {save_path}")
|
logger.info(f"Updating embedding cache for {actual_file_paths[0]}")
|
||||||
|
|
||||||
if scanner is not None:
|
adjust_cached_entry = (
|
||||||
file_path_for_adjust = getattr(metadata, "file_path", save_path)
|
getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None
|
||||||
if isinstance(file_path_for_adjust, str):
|
)
|
||||||
normalized_file_path = file_path_for_adjust.replace(os.sep, "/")
|
|
||||||
else:
|
|
||||||
normalized_file_path = str(file_path_for_adjust)
|
|
||||||
|
|
||||||
find_root = getattr(scanner, "_find_root_for_file", None)
|
for index, entry in enumerate(metadata_entries):
|
||||||
if callable(find_root):
|
file_path_for_adjust = getattr(entry, "file_path", actual_file_paths[index])
|
||||||
try:
|
normalized_file_path = (
|
||||||
adjust_root = find_root(normalized_file_path)
|
file_path_for_adjust.replace(os.sep, "/")
|
||||||
except TypeError:
|
if isinstance(file_path_for_adjust, str)
|
||||||
adjust_root = None
|
else str(file_path_for_adjust)
|
||||||
|
)
|
||||||
|
|
||||||
adjust_metadata = getattr(scanner, "adjust_metadata", None)
|
if scanner is not None:
|
||||||
if callable(adjust_metadata):
|
find_root = getattr(scanner, "_find_root_for_file", None)
|
||||||
metadata = adjust_metadata(metadata, normalized_file_path, adjust_root)
|
adjust_root = None
|
||||||
|
if callable(find_root):
|
||||||
|
try:
|
||||||
|
adjust_root = find_root(normalized_file_path)
|
||||||
|
except TypeError:
|
||||||
|
adjust_root = None
|
||||||
|
|
||||||
# 6. Persist metadata with any adjustments
|
adjust_metadata = getattr(scanner, "adjust_metadata", None)
|
||||||
await MetadataManager.save_metadata(save_path, metadata)
|
if callable(adjust_metadata):
|
||||||
|
adjusted_entry = adjust_metadata(entry, normalized_file_path, adjust_root)
|
||||||
|
if adjusted_entry is not None:
|
||||||
|
entry = adjusted_entry
|
||||||
|
metadata_entries[index] = entry
|
||||||
|
|
||||||
# Convert metadata to dictionary
|
metadata_file_path = os.path.splitext(entry.file_path)[0] + '.metadata.json'
|
||||||
metadata_dict = metadata.to_dict()
|
metadata_files_for_cleanup.append(metadata_file_path)
|
||||||
adjust_cached_entry = getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None
|
|
||||||
if callable(adjust_cached_entry):
|
|
||||||
metadata_dict = adjust_cached_entry(metadata_dict)
|
|
||||||
|
|
||||||
# Add model to cache and save to disk in a single operation
|
await MetadataManager.save_metadata(entry.file_path, entry)
|
||||||
await scanner.add_model_to_cache(metadata_dict, relative_path)
|
|
||||||
|
metadata_dict = entry.to_dict()
|
||||||
|
if callable(adjust_cached_entry):
|
||||||
|
metadata_dict = adjust_cached_entry(metadata_dict)
|
||||||
|
|
||||||
|
if scanner is not None:
|
||||||
|
await scanner.add_model_to_cache(metadata_dict, relative_path)
|
||||||
|
|
||||||
# Report 100% completion
|
# Report 100% completion
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(100)
|
await progress_callback(100)
|
||||||
|
|
||||||
return {
|
return {'success': True}
|
||||||
'success': True
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in _execute_download: {e}", exc_info=True)
|
logger.error(f"Error in _execute_download: {e}", exc_info=True)
|
||||||
# Clean up partial downloads except .part file
|
cleanup_targets = {
|
||||||
cleanup_files = [metadata_path]
|
path
|
||||||
if hasattr(metadata, 'preview_url') and metadata.preview_url and os.path.exists(metadata.preview_url):
|
for path in [save_path, metadata_path, *metadata_files_for_cleanup, *extracted_paths]
|
||||||
cleanup_files.append(metadata.preview_url)
|
if path
|
||||||
|
}
|
||||||
for path in cleanup_files:
|
preview_candidate = (
|
||||||
|
metadata_entries[0].preview_url
|
||||||
|
if metadata_entries
|
||||||
|
else getattr(metadata, "preview_url", None)
|
||||||
|
)
|
||||||
|
if preview_candidate:
|
||||||
|
cleanup_targets.add(preview_candidate)
|
||||||
|
|
||||||
|
cleanup_targets.update(preview_targets)
|
||||||
|
for path in cleanup_targets:
|
||||||
if path and os.path.exists(path):
|
if path and os.path.exists(path):
|
||||||
try:
|
try:
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
except Exception as e:
|
except Exception as exc:
|
||||||
logger.warning(f"Failed to cleanup file {path}: {e}")
|
logger.warning(f"Failed to cleanup file {path}: {exc}")
|
||||||
|
|
||||||
return {'success': False, 'error': str(e)}
|
return {'success': False, 'error': str(e)}
|
||||||
|
|
||||||
|
async def _extract_safetensors_from_archive(self, archive_path: str) -> List[str]:
|
||||||
|
if not zipfile.is_zipfile(archive_path):
|
||||||
|
return []
|
||||||
|
|
||||||
|
target_dir = os.path.dirname(archive_path)
|
||||||
|
|
||||||
|
def _extract_sync() -> List[str]:
|
||||||
|
extracted_files: List[str] = []
|
||||||
|
with zipfile.ZipFile(archive_path, "r") as archive:
|
||||||
|
for info in archive.infolist():
|
||||||
|
if info.is_dir():
|
||||||
|
continue
|
||||||
|
if not info.filename.lower().endswith(".safetensors"):
|
||||||
|
continue
|
||||||
|
file_name = os.path.basename(info.filename)
|
||||||
|
if not file_name:
|
||||||
|
continue
|
||||||
|
dest_path = self._resolve_extracted_destination(target_dir, file_name)
|
||||||
|
with archive.open(info) as source, open(dest_path, "wb") as target:
|
||||||
|
shutil.copyfileobj(source, target)
|
||||||
|
extracted_files.append(dest_path)
|
||||||
|
return extracted_files
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_extract_sync)
|
||||||
|
|
||||||
|
async def _build_metadata_entries(self, base_metadata, file_paths: List[str]) -> List:
|
||||||
|
if not file_paths:
|
||||||
|
return []
|
||||||
|
|
||||||
|
entries: List = []
|
||||||
|
for index, file_path in enumerate(file_paths):
|
||||||
|
entry = base_metadata if index == 0 else copy.deepcopy(base_metadata)
|
||||||
|
entry.update_file_info(file_path)
|
||||||
|
entry.sha256 = await calculate_sha256(file_path)
|
||||||
|
entries.append(entry)
|
||||||
|
|
||||||
|
return entries
|
||||||
|
|
||||||
|
def _resolve_extracted_destination(self, target_dir: str, filename: str) -> str:
|
||||||
|
base_name, extension = os.path.splitext(filename)
|
||||||
|
candidate = filename
|
||||||
|
destination = os.path.join(target_dir, candidate)
|
||||||
|
counter = 1
|
||||||
|
|
||||||
|
while os.path.exists(destination):
|
||||||
|
candidate = f"{base_name}-{counter}{extension}"
|
||||||
|
destination = os.path.join(target_dir, candidate)
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
return destination
|
||||||
|
|
||||||
|
def _distribute_preview_to_entries(self, preview_path: str, entries: List) -> List[str]:
|
||||||
|
if not preview_path or not entries:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not os.path.exists(preview_path):
|
||||||
|
return []
|
||||||
|
|
||||||
|
extension = os.path.splitext(preview_path)[1] or ".webp"
|
||||||
|
|
||||||
|
targets = [
|
||||||
|
os.path.splitext(entry.file_path)[0] + extension for entry in entries
|
||||||
|
]
|
||||||
|
|
||||||
|
if not targets:
|
||||||
|
return []
|
||||||
|
|
||||||
|
first_target = targets[0]
|
||||||
|
if preview_path != first_target:
|
||||||
|
os.replace(preview_path, first_target)
|
||||||
|
source_path = first_target
|
||||||
|
|
||||||
|
for target in targets[1:]:
|
||||||
|
shutil.copyfile(source_path, target)
|
||||||
|
|
||||||
|
return targets
|
||||||
|
|
||||||
async def _handle_download_progress(
|
async def _handle_download_progress(
|
||||||
self,
|
self,
|
||||||
progress_update,
|
progress_update,
|
||||||
@@ -895,16 +1017,23 @@ class DownloadManager:
|
|||||||
# Clean up ALL files including .part when user cancels
|
# Clean up ALL files including .part when user cancels
|
||||||
download_info = self._active_downloads.get(download_id)
|
download_info = self._active_downloads.get(download_id)
|
||||||
if download_info:
|
if download_info:
|
||||||
# Delete the main file
|
target_files = set()
|
||||||
if 'file_path' in download_info:
|
primary_path = download_info.get('file_path')
|
||||||
file_path = download_info['file_path']
|
if primary_path:
|
||||||
|
target_files.add(primary_path)
|
||||||
|
|
||||||
|
for extra_path in download_info.get('extracted_paths', []):
|
||||||
|
if extra_path:
|
||||||
|
target_files.add(extra_path)
|
||||||
|
|
||||||
|
for file_path in target_files:
|
||||||
if os.path.exists(file_path):
|
if os.path.exists(file_path):
|
||||||
try:
|
try:
|
||||||
os.unlink(file_path)
|
os.unlink(file_path)
|
||||||
logger.debug(f"Deleted cancelled download: {file_path}")
|
logger.debug(f"Deleted cancelled download: {file_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting file: {e}")
|
logger.error(f"Error deleting file: {e}")
|
||||||
|
|
||||||
# Delete the .part file (only on user cancellation)
|
# Delete the .part file (only on user cancellation)
|
||||||
if 'part_path' in download_info:
|
if 'part_path' in download_info:
|
||||||
part_path = download_info['part_path']
|
part_path = download_info['part_path']
|
||||||
@@ -914,10 +1043,9 @@ class DownloadManager:
|
|||||||
logger.debug(f"Deleted partial download: {part_path}")
|
logger.debug(f"Deleted partial download: {part_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting part file: {e}")
|
logger.error(f"Error deleting part file: {e}")
|
||||||
|
|
||||||
# Delete metadata file if exists
|
# Delete metadata files for each resolved path
|
||||||
if 'file_path' in download_info:
|
for file_path in target_files:
|
||||||
file_path = download_info['file_path']
|
|
||||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||||
if os.path.exists(metadata_path):
|
if os.path.exists(metadata_path):
|
||||||
try:
|
try:
|
||||||
@@ -925,15 +1053,16 @@ class DownloadManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting metadata file: {e}")
|
logger.error(f"Error deleting metadata file: {e}")
|
||||||
|
|
||||||
preview_path_value = download_info.get('preview_path')
|
preview_path_value = download_info.get('preview_path')
|
||||||
if preview_path_value and os.path.exists(preview_path_value):
|
if preview_path_value and os.path.exists(preview_path_value):
|
||||||
try:
|
try:
|
||||||
os.unlink(preview_path_value)
|
os.unlink(preview_path_value)
|
||||||
logger.debug(f"Deleted preview file: {preview_path_value}")
|
logger.debug(f"Deleted preview file: {preview_path_value}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting preview file: {e}")
|
logger.error(f"Error deleting preview file: {preview_path_value}")
|
||||||
|
|
||||||
# Delete preview file if exists (.webp or .mp4) for legacy paths
|
# Delete preview file if exists (.webp or .mp4) for legacy paths
|
||||||
|
for file_path in target_files:
|
||||||
for preview_ext in ['.webp', '.mp4']:
|
for preview_ext in ['.webp', '.mp4']:
|
||||||
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
||||||
if os.path.exists(preview_path):
|
if os.path.exists(preview_path):
|
||||||
@@ -941,8 +1070,7 @@ class DownloadManager:
|
|||||||
os.unlink(preview_path)
|
os.unlink(preview_path)
|
||||||
logger.debug(f"Deleted preview file: {preview_path}")
|
logger.debug(f"Deleted preview file: {preview_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting preview file: {e}")
|
logger.error(f"Error deleting preview file: {preview_path}")
|
||||||
|
|
||||||
return {'success': True, 'message': 'Download cancelled successfully'}
|
return {'success': True, 'message': 'Download cancelled successfully'}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import zipfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -525,6 +526,177 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
|
|||||||
assert cached_entry["model_type"] == "diffusion_model"
|
assert cached_entry["model_type"] == "diffusion_model"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
zip_path = save_dir / "bundle.zip"
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
return os.path.basename(self.file_path)
|
||||||
|
|
||||||
|
def update_file_info(self, updated_path):
|
||||||
|
self.file_path = str(updated_path)
|
||||||
|
self.file_name = Path(updated_path).stem
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {"file_path": self.file_path}
|
||||||
|
|
||||||
|
metadata = DummyMetadata(zip_path)
|
||||||
|
version_info = {"images": []}
|
||||||
|
download_urls = ["https://example.invalid/model.zip"]
|
||||||
|
|
||||||
|
class DummyDownloader:
|
||||||
|
async def download_file(self, *_args, **_kwargs):
|
||||||
|
with zipfile.ZipFile(str(zip_path), "w") as archive:
|
||||||
|
archive.writestr("inner/model.safetensors", b"model")
|
||||||
|
archive.writestr("docs/readme.txt", b"ignore")
|
||||||
|
return True, "ok"
|
||||||
|
|
||||||
|
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
|
||||||
|
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||||
|
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||||
|
hash_calculator = AsyncMock(return_value="hash-single")
|
||||||
|
monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
|
||||||
|
|
||||||
|
result = await manager._execute_download(
|
||||||
|
download_urls=download_urls,
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=metadata,
|
||||||
|
version_info=version_info,
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="lora",
|
||||||
|
download_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
assert not zip_path.exists()
|
||||||
|
extracted = save_dir / "model.safetensors"
|
||||||
|
assert extracted.exists()
|
||||||
|
assert hash_calculator.await_args.args[0] == str(extracted)
|
||||||
|
saved_call = MetadataManager.save_metadata.await_args
|
||||||
|
assert saved_call.args[0] == str(extracted)
|
||||||
|
assert saved_call.args[1].sha256 == "hash-single"
|
||||||
|
assert dummy_scanner.add_model_to_cache.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
zip_path = save_dir / "bundle.zip"
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
return os.path.basename(self.file_path)
|
||||||
|
|
||||||
|
def update_file_info(self, updated_path):
|
||||||
|
self.file_path = str(updated_path)
|
||||||
|
self.file_name = Path(updated_path).stem
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {"file_path": self.file_path}
|
||||||
|
|
||||||
|
metadata = DummyMetadata(zip_path)
|
||||||
|
version_info = {"images": []}
|
||||||
|
download_urls = ["https://example.invalid/model.zip"]
|
||||||
|
|
||||||
|
class DummyDownloader:
|
||||||
|
async def download_file(self, *_args, **_kwargs):
|
||||||
|
with zipfile.ZipFile(str(zip_path), "w") as archive:
|
||||||
|
archive.writestr("first/model-one.safetensors", b"one")
|
||||||
|
archive.writestr("second/model-two.safetensors", b"two")
|
||||||
|
archive.writestr("readme.md", b"ignore")
|
||||||
|
return True, "ok"
|
||||||
|
|
||||||
|
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
|
||||||
|
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||||
|
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||||
|
hash_calculator = AsyncMock(side_effect=["hash-one", "hash-two"])
|
||||||
|
monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
|
||||||
|
|
||||||
|
result = await manager._execute_download(
|
||||||
|
download_urls=download_urls,
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=metadata,
|
||||||
|
version_info=version_info,
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="lora",
|
||||||
|
download_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
assert not zip_path.exists()
|
||||||
|
extracted_one = save_dir / "model-one.safetensors"
|
||||||
|
extracted_two = save_dir / "model-two.safetensors"
|
||||||
|
assert extracted_one.exists()
|
||||||
|
assert extracted_two.exists()
|
||||||
|
|
||||||
|
assert hash_calculator.await_count == 2
|
||||||
|
assert MetadataManager.save_metadata.await_count == 2
|
||||||
|
assert dummy_scanner.add_model_to_cache.await_count == 2
|
||||||
|
|
||||||
|
metadata_calls = MetadataManager.save_metadata.await_args_list
|
||||||
|
assert metadata_calls[0].args[0] == str(extracted_one)
|
||||||
|
assert metadata_calls[0].args[1].sha256 == "hash-one"
|
||||||
|
assert metadata_calls[1].args[0] == str(extracted_two)
|
||||||
|
assert metadata_calls[1].args[1].sha256 == "hash-two"
|
||||||
|
|
||||||
|
|
||||||
|
def test_distribute_preview_to_entries_moves_and_copies(tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
preview_file = tmp_path / "bundle.webp"
|
||||||
|
preview_file.write_bytes(b"image-data")
|
||||||
|
|
||||||
|
entries = [
|
||||||
|
SimpleNamespace(file_path=str(tmp_path / "model-one.safetensors")),
|
||||||
|
SimpleNamespace(file_path=str(tmp_path / "model-two.safetensors")),
|
||||||
|
]
|
||||||
|
|
||||||
|
targets = manager._distribute_preview_to_entries(str(preview_file), entries)
|
||||||
|
|
||||||
|
assert targets == [
|
||||||
|
str(tmp_path / "model-one.webp"),
|
||||||
|
str(tmp_path / "model-two.webp"),
|
||||||
|
]
|
||||||
|
assert not preview_file.exists()
|
||||||
|
assert Path(targets[0]).read_bytes() == b"image-data"
|
||||||
|
assert Path(targets[1]).read_bytes() == b"image-data"
|
||||||
|
|
||||||
|
|
||||||
|
def test_distribute_preview_to_entries_keeps_existing_file(tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
existing_preview = tmp_path / "model-one.webp"
|
||||||
|
existing_preview.write_bytes(b"preview")
|
||||||
|
|
||||||
|
entries = [
|
||||||
|
SimpleNamespace(file_path=str(tmp_path / "model-one.safetensors")),
|
||||||
|
SimpleNamespace(file_path=str(tmp_path / "model-two.safetensors")),
|
||||||
|
]
|
||||||
|
|
||||||
|
targets = manager._distribute_preview_to_entries(str(existing_preview), entries)
|
||||||
|
|
||||||
|
assert targets[0] == str(existing_preview)
|
||||||
|
assert Path(targets[1]).read_bytes() == b"preview"
|
||||||
|
|
||||||
|
|
||||||
async def test_pause_download_updates_state():
|
async def test_pause_download_updates_state():
|
||||||
manager = DownloadManager()
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user