feat: add endpoints for retrieving checkpoint and unet roots from config

This commit is contained in:
Will Miao
2025-08-04 17:40:19 +08:00
parent 31f6edf8f0
commit 9387470c69

View File

@@ -4,6 +4,7 @@ from aiohttp import web
from .base_model_routes import BaseModelRoutes
from ..services.checkpoint_service import CheckpointService
from ..services.service_registry import ServiceRegistry
from ..config import config
logger = logging.getLogger(__name__)
@@ -41,6 +42,10 @@ class CheckpointRoutes(BaseModelRoutes):
# Checkpoint info by name
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
# Checkpoint roots and Unet roots
app.router.add_get(f'/api/{prefix}/checkpoints_roots', self.get_checkpoints_roots)
app.router.add_get(f'/api/{prefix}/unet_roots', self.get_unet_roots)
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific checkpoint by name"""
@@ -102,4 +107,12 @@ class CheckpointRoutes(BaseModelRoutes):
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching checkpoint model versions: {e}")
return web.Response(status=500, text=str(e))
return web.Response(status=500, text=str(e))
async def get_checkpoints_roots(self, request: web.Request) -> web.Response:
"""Return the list of checkpoint roots from config"""
return web.json_response({"checkpoints_roots": config.checkpoints_roots})
async def get_unet_roots(self, request: web.Request) -> web.Response:
"""Return the list of unet roots from config"""
return web.json_response({"unet_roots": config.unet_roots})