From c2e00b240e138d2a06f18f1393db7b6c6473376d Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 15:30:39 +0800 Subject: [PATCH] feat: Enhance model routes with generic page handling and template integration --- py/routes/base_model_routes.py | 63 +++++++++++++++++++++++++++++++ py/routes/checkpoint_routes.py | 69 +--------------------------------- py/routes/lora_routes.py | 69 +--------------------------------- 3 files changed, 67 insertions(+), 134 deletions(-) diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 7a6e4aac..8f5bc3f6 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -4,8 +4,12 @@ import logging from aiohttp import web from typing import Dict +import jinja2 + from ..utils.routes_common import ModelRouteUtils from ..services.websocket_manager import ws_manager +from ..services.settings_manager import settings +from ..config import config logger = logging.getLogger(__name__) @@ -20,6 +24,10 @@ class BaseModelRoutes(ABC): """ self.service = service self.model_type = service.model_type + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True + ) def setup_routes(self, app: web.Application, prefix: str): """Setup common routes for the model type @@ -52,6 +60,9 @@ class BaseModelRoutes(ABC): app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai) app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions) + # Add generic page route + app.router.add_get(f'/{prefix}', self.handle_models_page) + # Setup model-specific routes self.setup_specific_routes(app, prefix) @@ -60,6 +71,58 @@ class BaseModelRoutes(ABC): """Setup model-specific routes - to be implemented by subclasses""" pass + async def handle_models_page(self, request: web.Request) -> web.Response: + """ + Generic handler for model pages (e.g., /loras, /checkpoints). + Subclasses should set self.template_env and template_name. + """ + try: + # Check if the scanner is initializing + is_initializing = ( + self.service.scanner._cache is None or + (hasattr(self.service.scanner, 'is_initializing') and callable(self.service.scanner.is_initializing) and self.service.scanner.is_initializing()) or + (hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing) + ) + + template_name = getattr(self, "template_name", None) + if not self.template_env or not template_name: + return web.Response(text="Template environment or template name not set", status=500) + + if is_initializing: + rendered = self.template_env.get_template(template_name).render( + folders=[], + is_initializing=True, + settings=settings, + request=request + ) + else: + try: + cache = await self.service.scanner.get_cached_data(force_refresh=False) + rendered = self.template_env.get_template(template_name).render( + folders=getattr(cache, "folders", []), + is_initializing=False, + settings=settings, + request=request + ) + except Exception as cache_error: + logger.error(f"Error loading cache data: {cache_error}") + rendered = self.template_env.get_template(template_name).render( + folders=[], + is_initializing=True, + settings=settings, + request=request + ) + return web.Response( + text=rendered, + content_type='text/html' + ) + except Exception as e: + logger.error(f"Error handling models page: {e}", exc_info=True) + return web.Response( + text="Error loading models page", + status=500 + ) + async def get_models(self, request: web.Request) -> web.Response: """Get paginated model data""" try: diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index 5fdb670e..6ba550a6 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -1,12 +1,9 @@ -import jinja2 import logging from aiohttp import web from .base_model_routes import BaseModelRoutes from ..services.checkpoint_service import CheckpointService from ..services.service_registry import ServiceRegistry -from ..config import config -from ..services.settings_manager import settings logger = logging.getLogger(__name__) @@ -18,10 +15,7 @@ class CheckpointRoutes(BaseModelRoutes): # Service will be initialized later via setup_routes self.service = None self.civitai_client = None - self.template_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(config.templates_path), - autoescape=True - ) + self.template_name = "checkpoints.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" @@ -37,76 +31,17 @@ class CheckpointRoutes(BaseModelRoutes): # Schedule service initialization on app startup app.on_startup.append(lambda _: self.initialize_services()) - # Setup common routes with 'checkpoints' prefix + # Setup common routes with 'checkpoints' prefix (includes page route) super().setup_routes(app, 'checkpoints') def setup_specific_routes(self, app: web.Application, prefix: str): """Setup Checkpoint-specific routes""" - # Checkpoint page route - app.router.add_get('/checkpoints', self.handle_checkpoints_page) - # Checkpoint-specific CivitAI integration app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint) # Checkpoint info by name app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info) - async def handle_checkpoints_page(self, request: web.Request) -> web.Response: - """Handle GET /checkpoints request""" - try: - # Check if the CheckpointScanner is initializing - # It's initializing if the cache object doesn't exist yet, - # OR if the scanner explicitly says it's initializing (background task running). - is_initializing = ( - self.service.scanner._cache is None or - (hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing) - ) - - if is_initializing: - # If still initializing, return loading page - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - folders=[], # Empty folder list - is_initializing=True, # New flag - settings=settings, # Pass settings to template - request=request # Pass the request object to the template - ) - - logger.info("Checkpoints page is initializing, returning loading page") - else: - # Normal flow - get initialized cache data - try: - cache = await self.service.scanner.get_cached_data(force_refresh=False) - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - folders=cache.folders, - is_initializing=False, - settings=settings, # Pass settings to template - request=request # Pass the request object to the template - ) - except Exception as cache_error: - logger.error(f"Error loading checkpoints cache data: {cache_error}") - # If getting cache fails, also show initialization page - template = self.template_env.get_template('checkpoints.html') - rendered = template.render( - folders=[], - is_initializing=True, - settings=settings, - request=request - ) - logger.info("Checkpoints cache error, returning initialization page") - - return web.Response( - text=rendered, - content_type='text/html' - ) - except Exception as e: - logger.error(f"Error handling checkpoints request: {e}", exc_info=True) - return web.Response( - text="Error loading checkpoints page", - status=500 - ) - async def get_checkpoint_info(self, request: web.Request) -> web.Response: """Get detailed information for a specific checkpoint by name""" try: diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 5b0674ba..0a9d7dff 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -1,7 +1,5 @@ -import jinja2 import asyncio import logging -import os from aiohttp import web from typing import Dict from server import PromptServer # type: ignore @@ -9,8 +7,6 @@ from server import PromptServer # type: ignore from .base_model_routes import BaseModelRoutes from ..services.lora_service import LoraService from ..services.service_registry import ServiceRegistry -from ..services.settings_manager import settings -from ..config import config from ..utils.routes_common import ModelRouteUtils from ..utils.utils import get_lora_info @@ -26,10 +22,7 @@ class LoraRoutes(BaseModelRoutes): self.civitai_client = None self.download_manager = None self._download_lock = asyncio.Lock() - self.template_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(config.templates_path), - autoescape=True - ) + self.template_name = "loras.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" @@ -46,14 +39,11 @@ class LoraRoutes(BaseModelRoutes): # Schedule service initialization on app startup app.on_startup.append(lambda _: self.initialize_services()) - # Setup common routes with 'loras' prefix + # Setup common routes with 'loras' prefix (includes page route) super().setup_routes(app, 'loras') def setup_specific_routes(self, app: web.Application, prefix: str): """Setup LoRA-specific routes""" - # Lora page route - app.router.add_get('/loras', self.handle_loras_page) - # LoRA-specific query routes app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts) app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes) @@ -113,61 +103,6 @@ class LoraRoutes(BaseModelRoutes): return params - async def handle_loras_page(self, request: web.Request) -> web.Response: - """Handle GET /loras request""" - try: - # Check if the LoraScanner is initializing - # It's initializing if the cache object doesn't exist yet, - # OR if the scanner explicitly says it's initializing (background task running). - is_initializing = ( - self.service.scanner._cache is None or self.service.scanner.is_initializing() - ) - - if is_initializing: - # If still initializing, return loading page - template = self.template_env.get_template('loras.html') - rendered = template.render( - folders=[], - is_initializing=True, - settings=settings, - request=request - ) - - logger.info("Loras page is initializing, returning loading page") - else: - # Normal flow - get data from initialized cache - try: - cache = await self.service.scanner.get_cached_data(force_refresh=False) - template = self.template_env.get_template('loras.html') - rendered = template.render( - folders=cache.folders, - is_initializing=False, - settings=settings, - request=request - ) - except Exception as cache_error: - logger.error(f"Error loading cache data: {cache_error}") - template = self.template_env.get_template('loras.html') - rendered = template.render( - folders=[], - is_initializing=True, - settings=settings, - request=request - ) - logger.info("Cache error, returning initialization page") - - return web.Response( - text=rendered, - content_type='text/html' - ) - - except Exception as e: - logger.error(f"Error handling loras request: {e}", exc_info=True) - return web.Response( - text="Error loading loras page", - status=500 - ) - # LoRA-specific route handlers async def get_letter_counts(self, request: web.Request) -> web.Response: """Get count of LoRAs for each letter of the alphabet"""