Merge pull request #691 from willmiao/feat/zip-preview

feat(downloads): support safetensors zips and previews
pixelpaws · 2025-11-20 19:56:31 +08:00 · committed by GitHub
2 changed files with 364 additions and 64 deletions
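Summary: when a completed download turns out to be a zip archive, the download manager now extracts every .safetensors member (flattening internal folders), deletes the archive, builds one metadata entry per extracted model with its own SHA-256, and fans the downloaded preview image out to every entry; cancellation cleanup learns about the extra files too. As a reading aid, here is a minimal standalone sketch of the member-filtering policy the diff applies — illustrative only (list_safetensors_members is not part of the patch), using only stdlib calls:

import os
import zipfile

def list_safetensors_members(archive_path: str) -> list[str]:
    """Return flattened names of .safetensors members in a zip.

    Mirrors the filtering in _extract_safetensors_from_archive below:
    directories, non-safetensors members, and entries with an empty
    basename are skipped; internal folder structure is discarded.
    """
    names: list[str] = []
    with zipfile.ZipFile(archive_path, "r") as archive:
        for info in archive.infolist():
            if info.is_dir():
                continue
            if not info.filename.lower().endswith(".safetensors"):
                continue
            base = os.path.basename(info.filename)  # flatten "inner/model.safetensors"
            if base:
                names.append(base)
    return names

# Hypothetical usage: a bundle containing "inner/model.safetensors"
# and "docs/readme.txt" yields ["model.safetensors"].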

View File

@@ -1,7 +1,10 @@
+import copy
 import logging
 import os
 import asyncio
 import inspect
+import shutil
+import zipfile
 from collections import OrderedDict
 import uuid
 from typing import Dict, List, Optional, Tuple
@@ -12,6 +15,7 @@ from ..utils.civitai_utils import rewrite_preview_url
 from ..utils.preview_selection import select_preview_media
 from ..utils.utils import sanitize_folder_name
 from ..utils.exif_utils import ExifUtils
+from ..utils.file_utils import calculate_sha256
 from ..utils.metadata_manager import MetadataManager
 from .service_registry import ServiceRegistry
 from .settings_manager import get_settings_manager
@@ -556,6 +560,13 @@ class DownloadManager:
         download_id: str = None,
     ) -> Dict:
         """Execute the actual download process including preview images and model files"""
+        metadata_entries: List = []
+        metadata_files_for_cleanup: List[str] = []
+        extracted_paths: List[str] = []
+        metadata_path = ""
+        preview_targets: List[str] = []
+        preview_path: str | None = None
+        preview_nsfw_level = 0
         try:
             # Extract original filename details
             original_filename = os.path.basename(metadata.file_path)
@@ -699,10 +710,9 @@
                     logger.warning(f"Failed to delete temp file: {e}")

             if preview_downloaded and preview_path:
+                preview_nsfw_level = nsfw_level
                 metadata.preview_url = preview_path.replace(os.sep, '/')
                 metadata.preview_nsfw_level = nsfw_level
-                if download_id and download_id in self._active_downloads:
-                    self._active_downloads[download_id]['preview_path'] = preview_path

             if progress_callback:
                 await progress_callback(3)  # 3% progress after preview download
@@ -761,77 +771,189 @@ class DownloadManager:
                 return {'success': False, 'error': last_error or 'Failed to download file'}

-            # 4. Update file information (size and modified time)
-            metadata.update_file_info(save_path)
+            # 4. Handle archive extraction and prepare per-file metadata
+            actual_file_paths = [save_path]
+            if zipfile.is_zipfile(save_path):
+                extracted_paths = await self._extract_safetensors_from_archive(save_path)
+                if not extracted_paths:
+                    return {'success': False, 'error': 'Zip archive does not contain any safetensors files'}
+                actual_file_paths = extracted_paths
+                try:
+                    os.remove(save_path)
+                except OSError as exc:
+                    logger.warning(f"Unable to delete temporary archive {save_path}: {exc}")
+                if download_id and download_id in self._active_downloads:
+                    self._active_downloads[download_id]['file_path'] = extracted_paths[0]
+                    self._active_downloads[download_id]['extracted_paths'] = extracted_paths
+
+            metadata_entries = await self._build_metadata_entries(metadata, actual_file_paths)
+
+            if preview_path:
+                preview_targets = self._distribute_preview_to_entries(preview_path, metadata_entries)
+                for entry, target in zip(metadata_entries, preview_targets):
+                    entry.preview_url = target.replace(os.sep, "/")
+                    entry.preview_nsfw_level = preview_nsfw_level
+                if download_id and download_id in self._active_downloads and preview_targets:
+                    self._active_downloads[download_id]["preview_path"] = preview_targets[0]

             scanner = None
-            adjust_root: Optional[str] = None
-            # 5. Determine scanner and adjust metadata for cache consistency
             if model_type == "checkpoint":
                 scanner = await self._get_checkpoint_scanner()
-                logger.info(f"Updating checkpoint cache for {save_path}")
+                logger.info(f"Updating checkpoint cache for {actual_file_paths[0]}")
             elif model_type == "lora":
                 scanner = await self._get_lora_scanner()
-                logger.info(f"Updating lora cache for {save_path}")
+                logger.info(f"Updating lora cache for {actual_file_paths[0]}")
             elif model_type == "embedding":
                 scanner = await ServiceRegistry.get_embedding_scanner()
-                logger.info(f"Updating embedding cache for {save_path}")
+                logger.info(f"Updating embedding cache for {actual_file_paths[0]}")

-            if scanner is not None:
-                file_path_for_adjust = getattr(metadata, "file_path", save_path)
-                if isinstance(file_path_for_adjust, str):
-                    normalized_file_path = file_path_for_adjust.replace(os.sep, "/")
-                else:
-                    normalized_file_path = str(file_path_for_adjust)
-
-                find_root = getattr(scanner, "_find_root_for_file", None)
-                if callable(find_root):
-                    try:
-                        adjust_root = find_root(normalized_file_path)
-                    except TypeError:
-                        adjust_root = None
-
-                adjust_metadata = getattr(scanner, "adjust_metadata", None)
-                if callable(adjust_metadata):
-                    metadata = adjust_metadata(metadata, normalized_file_path, adjust_root)
-
-            # 6. Persist metadata with any adjustments
-            await MetadataManager.save_metadata(save_path, metadata)
-
-            # Convert metadata to dictionary
-            metadata_dict = metadata.to_dict()
-            adjust_cached_entry = getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None
-            if callable(adjust_cached_entry):
-                metadata_dict = adjust_cached_entry(metadata_dict)
-
-            # Add model to cache and save to disk in a single operation
-            await scanner.add_model_to_cache(metadata_dict, relative_path)
+            adjust_cached_entry = (
+                getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None
+            )
+
+            for index, entry in enumerate(metadata_entries):
+                file_path_for_adjust = getattr(entry, "file_path", actual_file_paths[index])
+                normalized_file_path = (
+                    file_path_for_adjust.replace(os.sep, "/")
+                    if isinstance(file_path_for_adjust, str)
+                    else str(file_path_for_adjust)
+                )
+
+                if scanner is not None:
+                    find_root = getattr(scanner, "_find_root_for_file", None)
+                    adjust_root = None
+                    if callable(find_root):
+                        try:
+                            adjust_root = find_root(normalized_file_path)
+                        except TypeError:
+                            adjust_root = None
+
+                    adjust_metadata = getattr(scanner, "adjust_metadata", None)
+                    if callable(adjust_metadata):
+                        adjusted_entry = adjust_metadata(entry, normalized_file_path, adjust_root)
+                        if adjusted_entry is not None:
+                            entry = adjusted_entry
+                            metadata_entries[index] = entry
+
+                metadata_file_path = os.path.splitext(entry.file_path)[0] + '.metadata.json'
+                metadata_files_for_cleanup.append(metadata_file_path)
+
+                await MetadataManager.save_metadata(entry.file_path, entry)
+
+                metadata_dict = entry.to_dict()
+                if callable(adjust_cached_entry):
+                    metadata_dict = adjust_cached_entry(metadata_dict)
+                if scanner is not None:
+                    await scanner.add_model_to_cache(metadata_dict, relative_path)

             # Report 100% completion
             if progress_callback:
                 await progress_callback(100)

-            return {
-                'success': True
-            }
+            return {'success': True}

         except Exception as e:
             logger.error(f"Error in _execute_download: {e}", exc_info=True)
-            # Clean up partial downloads except .part file
-            cleanup_files = [metadata_path]
-            if hasattr(metadata, 'preview_url') and metadata.preview_url and os.path.exists(metadata.preview_url):
-                cleanup_files.append(metadata.preview_url)
-
-            for path in cleanup_files:
+            cleanup_targets = {
+                path
+                for path in [save_path, metadata_path, *metadata_files_for_cleanup, *extracted_paths]
+                if path
+            }
+            preview_candidate = (
+                metadata_entries[0].preview_url
+                if metadata_entries
+                else getattr(metadata, "preview_url", None)
+            )
+            if preview_candidate:
+                cleanup_targets.add(preview_candidate)
+            cleanup_targets.update(preview_targets)
+
+            for path in cleanup_targets:
                 if path and os.path.exists(path):
                     try:
                         os.remove(path)
-                    except Exception as e:
-                        logger.warning(f"Failed to cleanup file {path}: {e}")
+                    except Exception as exc:
+                        logger.warning(f"Failed to cleanup file {path}: {exc}")
             return {'success': False, 'error': str(e)}

+    async def _extract_safetensors_from_archive(self, archive_path: str) -> List[str]:
+        if not zipfile.is_zipfile(archive_path):
+            return []
+
+        target_dir = os.path.dirname(archive_path)
+
+        def _extract_sync() -> List[str]:
+            extracted_files: List[str] = []
+            with zipfile.ZipFile(archive_path, "r") as archive:
+                for info in archive.infolist():
+                    if info.is_dir():
+                        continue
+                    if not info.filename.lower().endswith(".safetensors"):
+                        continue
+                    file_name = os.path.basename(info.filename)
+                    if not file_name:
+                        continue
+                    dest_path = self._resolve_extracted_destination(target_dir, file_name)
+                    with archive.open(info) as source, open(dest_path, "wb") as target:
+                        shutil.copyfileobj(source, target)
+                    extracted_files.append(dest_path)
+            return extracted_files
+
+        return await asyncio.to_thread(_extract_sync)
+
+    async def _build_metadata_entries(self, base_metadata, file_paths: List[str]) -> List:
+        if not file_paths:
+            return []
+
+        entries: List = []
+        for index, file_path in enumerate(file_paths):
+            entry = base_metadata if index == 0 else copy.deepcopy(base_metadata)
+            entry.update_file_info(file_path)
+            entry.sha256 = await calculate_sha256(file_path)
+            entries.append(entry)
+        return entries
+
+    def _resolve_extracted_destination(self, target_dir: str, filename: str) -> str:
+        base_name, extension = os.path.splitext(filename)
+        candidate = filename
+        destination = os.path.join(target_dir, candidate)
+        counter = 1
+        while os.path.exists(destination):
+            candidate = f"{base_name}-{counter}{extension}"
+            destination = os.path.join(target_dir, candidate)
+            counter += 1
+        return destination
+
+    def _distribute_preview_to_entries(self, preview_path: str, entries: List) -> List[str]:
+        if not preview_path or not entries:
+            return []
+        if not os.path.exists(preview_path):
+            return []
+
+        extension = os.path.splitext(preview_path)[1] or ".webp"
+        targets = [
+            os.path.splitext(entry.file_path)[0] + extension for entry in entries
+        ]
+        if not targets:
+            return []
+
+        first_target = targets[0]
+        if preview_path != first_target:
+            os.replace(preview_path, first_target)
+        source_path = first_target
+
+        for target in targets[1:]:
+            shutil.copyfile(source_path, target)
+
+        return targets
+
     async def _handle_download_progress(
         self,
         progress_update,
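Two behaviors of the new helpers are worth spelling out: _resolve_extracted_destination avoids overwriting by suffixing -1, -2, ... before the extension, and _distribute_preview_to_entries moves the preview once (os.replace) and then copies it for the remaining entries (shutil.copyfile). The demo below re-implements the naming rule purely for illustration; resolve_destination and all paths are invented for the example, not part of the patch:

import os
import shutil
import tempfile

def resolve_destination(target_dir: str, filename: str) -> str:
    # Same rule as _resolve_extracted_destination: append -1, -2, ...
    # before the extension until the path is free.
    base, ext = os.path.splitext(filename)
    candidate, counter = filename, 1
    while os.path.exists(os.path.join(target_dir, candidate)):
        candidate = f"{base}-{counter}{ext}"
        counter += 1
    return os.path.join(target_dir, candidate)

with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "model.safetensors"), "wb").close()
    print(resolve_destination(d, "model.safetensors"))  # .../model-1.safetensors

    # Preview fan-out: move once, then copy for the remaining entries.
    preview = os.path.join(d, "bundle.webp")
    with open(preview, "wb") as f:
        f.write(b"image-data")
    targets = [os.path.join(d, n) for n in ("model-one.webp", "model-two.webp")]
    os.replace(preview, targets[0])          # first entry takes ownership
    for extra in targets[1:]:
        shutil.copyfile(targets[0], extra)   # the rest get copies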
@@ -895,16 +1017,23 @@ class DownloadManager:
                 # Clean up ALL files including .part when user cancels
                 download_info = self._active_downloads.get(download_id)
                 if download_info:
-                    # Delete the main file
-                    if 'file_path' in download_info:
-                        file_path = download_info['file_path']
+                    target_files = set()
+                    primary_path = download_info.get('file_path')
+                    if primary_path:
+                        target_files.add(primary_path)
+                    for extra_path in download_info.get('extracted_paths', []):
+                        if extra_path:
+                            target_files.add(extra_path)
+
+                    for file_path in target_files:
                         if os.path.exists(file_path):
                             try:
                                 os.unlink(file_path)
                                 logger.debug(f"Deleted cancelled download: {file_path}")
                             except Exception as e:
                                 logger.error(f"Error deleting file: {e}")

                     # Delete the .part file (only on user cancellation)
                     if 'part_path' in download_info:
                         part_path = download_info['part_path']
@@ -914,10 +1043,9 @@
                             logger.debug(f"Deleted partial download: {part_path}")
                         except Exception as e:
                             logger.error(f"Error deleting part file: {e}")

-                    # Delete metadata file if exists
-                    if 'file_path' in download_info:
-                        file_path = download_info['file_path']
+                    # Delete metadata files for each resolved path
+                    for file_path in target_files:
                         metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
                         if os.path.exists(metadata_path):
                             try:
@@ -925,15 +1053,16 @@
                         except Exception as e:
                             logger.error(f"Error deleting metadata file: {e}")

                     preview_path_value = download_info.get('preview_path')
                     if preview_path_value and os.path.exists(preview_path_value):
                         try:
                             os.unlink(preview_path_value)
                             logger.debug(f"Deleted preview file: {preview_path_value}")
                         except Exception as e:
-                            logger.error(f"Error deleting preview file: {e}")
+                            logger.error(f"Error deleting preview file: {preview_path_value}")

-                    # Delete preview file if exists (.webp or .mp4)
-                    for preview_ext in ['.webp', '.mp4']:
-                        preview_path = os.path.splitext(file_path)[0] + preview_ext
-                        if os.path.exists(preview_path):
+                    # Delete preview file if exists (.webp or .mp4) for legacy paths
+                    for file_path in target_files:
+                        for preview_ext in ['.webp', '.mp4']:
+                            preview_path = os.path.splitext(file_path)[0] + preview_ext
+                            if os.path.exists(preview_path):
@@ -941,8 +1070,7 @@
-                            os.unlink(preview_path)
-                            logger.debug(f"Deleted preview file: {preview_path}")
-                        except Exception as e:
-                            logger.error(f"Error deleting preview file: {e}")
+                                os.unlink(preview_path)
+                                logger.debug(f"Deleted preview file: {preview_path}")
+                            except Exception as e:
+                                logger.error(f"Error deleting preview file: {preview_path}")

             return {'success': True, 'message': 'Download cancelled successfully'}

         except Exception as e:
             logger.error(f"Error cancelling download: {e}", exc_info=True)

View File

@@ -1,5 +1,6 @@
 import asyncio
 import os
+import zipfile
 from datetime import datetime
 from pathlib import Path
 from typing import Optional
@@ -525,6 +526,177 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p
     assert cached_entry["model_type"] == "diffusion_model"


+async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path):
+    manager = DownloadManager()
+    save_dir = tmp_path / "downloads"
+    save_dir.mkdir()
+    zip_path = save_dir / "bundle.zip"
+
+    class DummyMetadata:
+        def __init__(self, path: Path):
+            self.file_path = str(path)
+            self.sha256 = "sha256"
+            self.file_name = path.stem
+            self.preview_url = None
+
+        def generate_unique_filename(self, *_args, **_kwargs):
+            return os.path.basename(self.file_path)
+
+        def update_file_info(self, updated_path):
+            self.file_path = str(updated_path)
+            self.file_name = Path(updated_path).stem
+
+        def to_dict(self):
+            return {"file_path": self.file_path}
+
+    metadata = DummyMetadata(zip_path)
+    version_info = {"images": []}
+    download_urls = ["https://example.invalid/model.zip"]
+
+    class DummyDownloader:
+        async def download_file(self, *_args, **_kwargs):
+            with zipfile.ZipFile(str(zip_path), "w") as archive:
+                archive.writestr("inner/model.safetensors", b"model")
+                archive.writestr("docs/readme.txt", b"ignore")
+            return True, "ok"
+
+    monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
+    dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
+    monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
+    monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
+    hash_calculator = AsyncMock(return_value="hash-single")
+    monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
+
+    result = await manager._execute_download(
+        download_urls=download_urls,
+        save_dir=str(save_dir),
+        metadata=metadata,
+        version_info=version_info,
+        relative_path="",
+        progress_callback=None,
+        model_type="lora",
+        download_id=None,
+    )
+
+    assert result == {"success": True}
+    assert not zip_path.exists()
+    extracted = save_dir / "model.safetensors"
+    assert extracted.exists()
+    assert hash_calculator.await_args.args[0] == str(extracted)
+    saved_call = MetadataManager.save_metadata.await_args
+    assert saved_call.args[0] == str(extracted)
+    assert saved_call.args[1].sha256 == "hash-single"
+    assert dummy_scanner.add_model_to_cache.await_count == 1
+
+
+async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_path):
+    manager = DownloadManager()
+    save_dir = tmp_path / "downloads"
+    save_dir.mkdir()
+    zip_path = save_dir / "bundle.zip"
+
+    class DummyMetadata:
+        def __init__(self, path: Path):
+            self.file_path = str(path)
+            self.sha256 = "sha256"
+            self.file_name = path.stem
+            self.preview_url = None
+
+        def generate_unique_filename(self, *_args, **_kwargs):
+            return os.path.basename(self.file_path)
+
+        def update_file_info(self, updated_path):
+            self.file_path = str(updated_path)
+            self.file_name = Path(updated_path).stem
+
+        def to_dict(self):
+            return {"file_path": self.file_path}
+
+    metadata = DummyMetadata(zip_path)
+    version_info = {"images": []}
+    download_urls = ["https://example.invalid/model.zip"]
+
+    class DummyDownloader:
+        async def download_file(self, *_args, **_kwargs):
+            with zipfile.ZipFile(str(zip_path), "w") as archive:
+                archive.writestr("first/model-one.safetensors", b"one")
+                archive.writestr("second/model-two.safetensors", b"two")
+                archive.writestr("readme.md", b"ignore")
+            return True, "ok"
+
+    monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()))
+    dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
+    monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
+    monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
+    hash_calculator = AsyncMock(side_effect=["hash-one", "hash-two"])
+    monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator)
+
+    result = await manager._execute_download(
+        download_urls=download_urls,
+        save_dir=str(save_dir),
+        metadata=metadata,
+        version_info=version_info,
+        relative_path="",
+        progress_callback=None,
+        model_type="lora",
+        download_id=None,
+    )
+
+    assert result == {"success": True}
+    assert not zip_path.exists()
+    extracted_one = save_dir / "model-one.safetensors"
+    extracted_two = save_dir / "model-two.safetensors"
+    assert extracted_one.exists()
+    assert extracted_two.exists()
+    assert hash_calculator.await_count == 2
+    assert MetadataManager.save_metadata.await_count == 2
+    assert dummy_scanner.add_model_to_cache.await_count == 2
+    metadata_calls = MetadataManager.save_metadata.await_args_list
+    assert metadata_calls[0].args[0] == str(extracted_one)
+    assert metadata_calls[0].args[1].sha256 == "hash-one"
+    assert metadata_calls[1].args[0] == str(extracted_two)
+    assert metadata_calls[1].args[1].sha256 == "hash-two"
+
+
+def test_distribute_preview_to_entries_moves_and_copies(tmp_path):
+    manager = DownloadManager()
+    preview_file = tmp_path / "bundle.webp"
+    preview_file.write_bytes(b"image-data")
+    entries = [
+        SimpleNamespace(file_path=str(tmp_path / "model-one.safetensors")),
+        SimpleNamespace(file_path=str(tmp_path / "model-two.safetensors")),
+    ]
+
+    targets = manager._distribute_preview_to_entries(str(preview_file), entries)
+
+    assert targets == [
+        str(tmp_path / "model-one.webp"),
+        str(tmp_path / "model-two.webp"),
+    ]
+    assert not preview_file.exists()
+    assert Path(targets[0]).read_bytes() == b"image-data"
+    assert Path(targets[1]).read_bytes() == b"image-data"
+
+
+def test_distribute_preview_to_entries_keeps_existing_file(tmp_path):
+    manager = DownloadManager()
+    existing_preview = tmp_path / "model-one.webp"
+    existing_preview.write_bytes(b"preview")
+    entries = [
+        SimpleNamespace(file_path=str(tmp_path / "model-one.safetensors")),
+        SimpleNamespace(file_path=str(tmp_path / "model-two.safetensors")),
+    ]
+
+    targets = manager._distribute_preview_to_entries(str(existing_preview), entries)
+
+    assert targets[0] == str(existing_preview)
+    assert Path(targets[1]).read_bytes() == b"preview"
+
+
 async def test_pause_download_updates_state():
     manager = DownloadManager()
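One gap worth noting: the new tests cover extraction and preview fan-out end to end, but the collision-suffix rule in _resolve_extracted_destination is only exercised indirectly. A direct test in the same style could look like this (a hypothetical sketch, not part of the patch):

def test_resolve_extracted_destination_appends_counter(tmp_path):
    manager = DownloadManager()
    (tmp_path / "model.safetensors").write_bytes(b"existing")

    # The first taken name should be suffixed with -1 before the extension.
    destination = manager._resolve_extracted_destination(str(tmp_path), "model.safetensors")

    assert destination == str(tmp_path / "model-1.safetensors")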