diff --git a/py/services/download_coordinator.py b/py/services/download_coordinator.py index 8a008412..035f288c 100644 --- a/py/services/download_coordinator.py +++ b/py/services/download_coordinator.py @@ -86,6 +86,7 @@ class DownloadCoordinator: progress_callback=progress_callback, download_id=download_id, source=payload.get("source"), + file_params=payload.get("file_params"), ) result["download_id"] = download_id diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 83108371..75401531 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -70,6 +70,7 @@ class DownloadManager: use_default_paths: bool = False, download_id: str = None, source: str = None, + file_params: Dict = None, ) -> Dict: """Download model from Civitai with task tracking and concurrency control @@ -82,6 +83,7 @@ class DownloadManager: use_default_paths: Flag to use default paths download_id: Unique identifier for this download task source: Optional source parameter to specify metadata provider + file_params: Optional dict with file selection params (type, format, size, fp, isPrimary) Returns: Dict with download result @@ -122,6 +124,7 @@ class DownloadManager: progress_callback, use_default_paths, source, + file_params, ) ) @@ -155,6 +158,7 @@ class DownloadManager: progress_callback=None, use_default_paths: bool = False, source: str = None, + file_params: Dict = None, ): """Execute download with semaphore to limit concurrency""" # Update status to waiting @@ -215,6 +219,7 @@ class DownloadManager: use_default_paths, task_id, source, + file_params, ) # Update status based on result @@ -266,6 +271,7 @@ class DownloadManager: use_default_paths, download_id=None, source=None, + file_params=None, ): """Wrapper for original download_from_civitai implementation""" try: @@ -456,16 +462,57 @@ class DownloadManager: await progress_callback(0) # 2. Get file information - file_info = next( - ( - f - for f in version_info.get("files", []) - if f.get("primary") and f.get("type") in ("Model", "Negative") - ), - None, - ) + files = version_info.get("files", []) + file_info = None + + # If file_params is provided, try to find matching file + if file_params and model_version_id: + target_type = file_params.get("type", "Model") + target_format = file_params.get("format", "SafeTensor") + target_size = file_params.get("size", "full") + target_fp = file_params.get("fp") + is_primary = file_params.get("isPrimary", False) + + if is_primary: + # Find primary file + file_info = next( + (f for f in files if f.get("primary") and f.get("type") in ("Model", "Negative")), + None + ) + else: + # Match by metadata + for f in files: + f_type = f.get("type", "") + f_meta = f.get("metadata", {}) + + # Check type match + if f_type != target_type: + continue + + # Check metadata match + if f_meta.get("format") != target_format: + continue + if f_meta.get("size") != target_size: + continue + if target_fp and f_meta.get("fp") != target_fp: + continue + + file_info = f + break + + # Fallback to primary file if no match found if not file_info: - return {"success": False, "error": "No primary file found in metadata"} + file_info = next( + ( + f + for f in files + if f.get("primary") and f.get("type") in ("Model", "Negative") + ), + None, + ) + + if not file_info: + return {"success": False, "error": "No suitable file found in metadata"} mirrors = file_info.get("mirrors") or [] download_urls = [] if mirrors: