diff --git a/py/services/civarchive_client.py b/py/services/civarchive_client.py index 2deaa51f..4445b5ee 100644 --- a/py/services/civarchive_client.py +++ b/py/services/civarchive_client.py @@ -186,6 +186,22 @@ class CivArchiveClient: if "metadata" in file_data: transformed["metadata"] = file_data["metadata"] + # Infer metadata.format from filename extension + name = transformed.get("name") + if name and isinstance(name, str): + lower_name = name.lower() + if lower_name.endswith(".safetensors"): + inferred_format = "SafeTensor" + elif lower_name.endswith(".ckpt"): + inferred_format = "PickleTensor" + else: + inferred_format = None + if inferred_format: + if "metadata" not in transformed: + transformed["metadata"] = {} + if isinstance(transformed["metadata"], dict): + transformed["metadata"].setdefault("format", inferred_format) + if file_data.get("modelVersionId") is not None: transformed["modelVersionId"] = file_data.get("modelVersionId") elif file_data.get("model_version_id") is not None: @@ -213,6 +229,20 @@ class CivArchiveClient: for file_data in candidates: if isinstance(file_data, dict): transformed_files.append(self._transform_file_entry(file_data)) + + # Sort: .safetensors first, .ckpt second, others last + # so the backend fallback (no file_params) prefers safetensors + def _sort_key(f: Dict) -> int: + fname = f.get("name") or "" + if isinstance(fname, str): + lower = fname.lower() + if lower.endswith(".safetensors"): + return 0 + elif lower.endswith(".ckpt"): + return 1 + return 2 + + transformed_files.sort(key=_sort_key) return transformed_files def _transform_version(