diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index 4f27115e..b8edabac 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -4,6 +4,7 @@ from aiohttp import web from .base_model_routes import BaseModelRoutes from ..services.checkpoint_service import CheckpointService from ..services.service_registry import ServiceRegistry +from ..config import config logger = logging.getLogger(__name__) @@ -41,6 +42,10 @@ class CheckpointRoutes(BaseModelRoutes): # Checkpoint info by name app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info) + + # Checkpoint roots and Unet roots + app.router.add_get(f'/api/{prefix}/checkpoints_roots', self.get_checkpoints_roots) + app.router.add_get(f'/api/{prefix}/unet_roots', self.get_unet_roots) async def get_checkpoint_info(self, request: web.Request) -> web.Response: """Get detailed information for a specific checkpoint by name""" @@ -102,4 +107,12 @@ class CheckpointRoutes(BaseModelRoutes): return web.json_response(versions) except Exception as e: logger.error(f"Error fetching checkpoint model versions: {e}") - return web.Response(status=500, text=str(e)) \ No newline at end of file + return web.Response(status=500, text=str(e)) + + async def get_checkpoints_roots(self, request: web.Request) -> web.Response: + """Return the list of checkpoint roots from config""" + return web.json_response({"checkpoints_roots": config.checkpoints_roots}) + + async def get_unet_roots(self, request: web.Request) -> web.Response: + """Return the list of unet roots from config""" + return web.json_response({"unet_roots": config.unet_roots}) \ No newline at end of file