feat: Enhance model routes with generic page handling and template integration

This commit is contained in:
Will Miao
2025-07-23 15:30:39 +08:00
parent a2b81ea099
commit c2e00b240e
3 changed files with 67 additions and 134 deletions

View File

@@ -4,8 +4,12 @@ import logging
from aiohttp import web from aiohttp import web
from typing import Dict from typing import Dict
import jinja2
from ..utils.routes_common import ModelRouteUtils from ..utils.routes_common import ModelRouteUtils
from ..services.websocket_manager import ws_manager from ..services.websocket_manager import ws_manager
from ..services.settings_manager import settings
from ..config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,6 +24,10 @@ class BaseModelRoutes(ABC):
""" """
self.service = service self.service = service
self.model_type = service.model_type self.model_type = service.model_type
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
def setup_routes(self, app: web.Application, prefix: str): def setup_routes(self, app: web.Application, prefix: str):
"""Setup common routes for the model type """Setup common routes for the model type
@@ -52,6 +60,9 @@ class BaseModelRoutes(ABC):
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai) app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions) app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
# Add generic page route
app.router.add_get(f'/{prefix}', self.handle_models_page)
# Setup model-specific routes # Setup model-specific routes
self.setup_specific_routes(app, prefix) self.setup_specific_routes(app, prefix)
@@ -60,6 +71,58 @@ class BaseModelRoutes(ABC):
"""Setup model-specific routes - to be implemented by subclasses""" """Setup model-specific routes - to be implemented by subclasses"""
pass pass
async def handle_models_page(self, request: web.Request) -> web.Response:
"""
Generic handler for model pages (e.g., /loras, /checkpoints).
Subclasses should set self.template_env and template_name.
"""
try:
# Check if the scanner is initializing
is_initializing = (
self.service.scanner._cache is None or
(hasattr(self.service.scanner, 'is_initializing') and callable(self.service.scanner.is_initializing) and self.service.scanner.is_initializing()) or
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
)
template_name = getattr(self, "template_name", None)
if not self.template_env or not template_name:
return web.Response(text="Template environment or template name not set", status=500)
if is_initializing:
rendered = self.template_env.get_template(template_name).render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
else:
try:
cache = await self.service.scanner.get_cached_data(force_refresh=False)
rendered = self.template_env.get_template(template_name).render(
folders=getattr(cache, "folders", []),
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading cache data: {cache_error}")
rendered = self.template_env.get_template(template_name).render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling models page: {e}", exc_info=True)
return web.Response(
text="Error loading models page",
status=500
)
async def get_models(self, request: web.Request) -> web.Response: async def get_models(self, request: web.Request) -> web.Response:
"""Get paginated model data""" """Get paginated model data"""
try: try:

View File

@@ -1,12 +1,9 @@
import jinja2
import logging import logging
from aiohttp import web from aiohttp import web
from .base_model_routes import BaseModelRoutes from .base_model_routes import BaseModelRoutes
from ..services.checkpoint_service import CheckpointService from ..services.checkpoint_service import CheckpointService
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -18,10 +15,7 @@ class CheckpointRoutes(BaseModelRoutes):
# Service will be initialized later via setup_routes # Service will be initialized later via setup_routes
self.service = None self.service = None
self.civitai_client = None self.civitai_client = None
self.template_env = jinja2.Environment( self.template_name = "checkpoints.html"
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
async def initialize_services(self): async def initialize_services(self):
"""Initialize services from ServiceRegistry""" """Initialize services from ServiceRegistry"""
@@ -37,76 +31,17 @@ class CheckpointRoutes(BaseModelRoutes):
# Schedule service initialization on app startup # Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services()) app.on_startup.append(lambda _: self.initialize_services())
# Setup common routes with 'checkpoints' prefix # Setup common routes with 'checkpoints' prefix (includes page route)
super().setup_routes(app, 'checkpoints') super().setup_routes(app, 'checkpoints')
def setup_specific_routes(self, app: web.Application, prefix: str): def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup Checkpoint-specific routes""" """Setup Checkpoint-specific routes"""
# Checkpoint page route
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
# Checkpoint-specific CivitAI integration # Checkpoint-specific CivitAI integration
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint) app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
# Checkpoint info by name # Checkpoint info by name
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info) app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
"""Handle GET /checkpoints request"""
try:
# Check if the CheckpointScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = (
self.service.scanner._cache is None or
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
)
if is_initializing:
# If still initializing, return loading page
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[], # Empty folder list
is_initializing=True, # New flag
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
logger.info("Checkpoints page is initializing, returning loading page")
else:
# Normal flow - get initialized cache data
try:
cache = await self.service.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
except Exception as cache_error:
logger.error(f"Error loading checkpoints cache data: {cache_error}")
# If getting cache fails, also show initialization page
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Checkpoints cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
return web.Response(
text="Error loading checkpoints page",
status=500
)
async def get_checkpoint_info(self, request: web.Request) -> web.Response: async def get_checkpoint_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific checkpoint by name""" """Get detailed information for a specific checkpoint by name"""
try: try:

View File

@@ -1,7 +1,5 @@
import jinja2
import asyncio import asyncio
import logging import logging
import os
from aiohttp import web from aiohttp import web
from typing import Dict from typing import Dict
from server import PromptServer # type: ignore from server import PromptServer # type: ignore
@@ -9,8 +7,6 @@ from server import PromptServer # type: ignore
from .base_model_routes import BaseModelRoutes from .base_model_routes import BaseModelRoutes
from ..services.lora_service import LoraService from ..services.lora_service import LoraService
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings
from ..config import config
from ..utils.routes_common import ModelRouteUtils from ..utils.routes_common import ModelRouteUtils
from ..utils.utils import get_lora_info from ..utils.utils import get_lora_info
@@ -26,10 +22,7 @@ class LoraRoutes(BaseModelRoutes):
self.civitai_client = None self.civitai_client = None
self.download_manager = None self.download_manager = None
self._download_lock = asyncio.Lock() self._download_lock = asyncio.Lock()
self.template_env = jinja2.Environment( self.template_name = "loras.html"
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
async def initialize_services(self): async def initialize_services(self):
"""Initialize services from ServiceRegistry""" """Initialize services from ServiceRegistry"""
@@ -46,14 +39,11 @@ class LoraRoutes(BaseModelRoutes):
# Schedule service initialization on app startup # Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services()) app.on_startup.append(lambda _: self.initialize_services())
# Setup common routes with 'loras' prefix # Setup common routes with 'loras' prefix (includes page route)
super().setup_routes(app, 'loras') super().setup_routes(app, 'loras')
def setup_specific_routes(self, app: web.Application, prefix: str): def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup LoRA-specific routes""" """Setup LoRA-specific routes"""
# Lora page route
app.router.add_get('/loras', self.handle_loras_page)
# LoRA-specific query routes # LoRA-specific query routes
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts) app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes) app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
@@ -113,61 +103,6 @@ class LoraRoutes(BaseModelRoutes):
return params return params
async def handle_loras_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras request"""
try:
# Check if the LoraScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = (
self.service.scanner._cache is None or self.service.scanner.is_initializing()
)
if is_initializing:
# If still initializing, return loading page
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Loras page is initializing, returning loading page")
else:
# Normal flow - get data from initialized cache
try:
cache = await self.service.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading cache data: {cache_error}")
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling loras request: {e}", exc_info=True)
return web.Response(
text="Error loading loras page",
status=500
)
# LoRA-specific route handlers # LoRA-specific route handlers
async def get_letter_counts(self, request: web.Request) -> web.Response: async def get_letter_counts(self, request: web.Request) -> web.Response:
"""Get count of LoRAs for each letter of the alphabet""" """Get count of LoRAs for each letter of the alphabet"""