diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index c6af3de2..ca2f3d6a 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -251,7 +251,7 @@ class BaseModelRoutes(ABC): def _find_model_file(self, files): """Find the appropriate model file from the files list - can be overridden by subclasses.""" - return next((file for file in files if file.get("type") == "Model" and file.get("primary") is True), None) + return next((file for file in files if file.get("type") in ("Model", "Diffusion Model") and file.get("primary") is True), None) def get_handler(self, name: str) -> Callable[[web.Request], web.StreamResponse]: """Expose handlers for subclasses or tests.""" diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 38c21e6f..cf4225ab 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -1364,7 +1364,7 @@ class DownloadManager: f for f in files if f.get("primary") - and f.get("type") in ("Model", "Negative") + and f.get("type") in ("Model", "Negative", "Diffusion Model") ), None, ) @@ -1395,7 +1395,7 @@ class DownloadManager: ( f for f in files - if f.get("primary") and f.get("type") in ("Model", "Negative") + if f.get("primary") and f.get("type") in ("Model", "Negative", "Diffusion Model") ), None, )