mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat: Enhance model routes with generic page handling and template integration
This commit is contained in:
@@ -4,8 +4,12 @@ import logging
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
from ..services.websocket_manager import ws_manager
|
from ..services.websocket_manager import ws_manager
|
||||||
|
from ..services.settings_manager import settings
|
||||||
|
from ..config import config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -20,6 +24,10 @@ class BaseModelRoutes(ABC):
|
|||||||
"""
|
"""
|
||||||
self.service = service
|
self.service = service
|
||||||
self.model_type = service.model_type
|
self.model_type = service.model_type
|
||||||
|
self.template_env = jinja2.Environment(
|
||||||
|
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||||
|
autoescape=True
|
||||||
|
)
|
||||||
|
|
||||||
def setup_routes(self, app: web.Application, prefix: str):
|
def setup_routes(self, app: web.Application, prefix: str):
|
||||||
"""Setup common routes for the model type
|
"""Setup common routes for the model type
|
||||||
@@ -52,6 +60,9 @@ class BaseModelRoutes(ABC):
|
|||||||
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
|
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
|
||||||
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
|
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
|
||||||
|
|
||||||
|
# Add generic page route
|
||||||
|
app.router.add_get(f'/{prefix}', self.handle_models_page)
|
||||||
|
|
||||||
# Setup model-specific routes
|
# Setup model-specific routes
|
||||||
self.setup_specific_routes(app, prefix)
|
self.setup_specific_routes(app, prefix)
|
||||||
|
|
||||||
@@ -60,6 +71,58 @@ class BaseModelRoutes(ABC):
|
|||||||
"""Setup model-specific routes - to be implemented by subclasses"""
|
"""Setup model-specific routes - to be implemented by subclasses"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def handle_models_page(self, request: web.Request) -> web.Response:
|
||||||
|
"""
|
||||||
|
Generic handler for model pages (e.g., /loras, /checkpoints).
|
||||||
|
Subclasses should set self.template_env and template_name.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check if the scanner is initializing
|
||||||
|
is_initializing = (
|
||||||
|
self.service.scanner._cache is None or
|
||||||
|
(hasattr(self.service.scanner, 'is_initializing') and callable(self.service.scanner.is_initializing) and self.service.scanner.is_initializing()) or
|
||||||
|
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
|
||||||
|
)
|
||||||
|
|
||||||
|
template_name = getattr(self, "template_name", None)
|
||||||
|
if not self.template_env or not template_name:
|
||||||
|
return web.Response(text="Template environment or template name not set", status=500)
|
||||||
|
|
||||||
|
if is_initializing:
|
||||||
|
rendered = self.template_env.get_template(template_name).render(
|
||||||
|
folders=[],
|
||||||
|
is_initializing=True,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
cache = await self.service.scanner.get_cached_data(force_refresh=False)
|
||||||
|
rendered = self.template_env.get_template(template_name).render(
|
||||||
|
folders=getattr(cache, "folders", []),
|
||||||
|
is_initializing=False,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
except Exception as cache_error:
|
||||||
|
logger.error(f"Error loading cache data: {cache_error}")
|
||||||
|
rendered = self.template_env.get_template(template_name).render(
|
||||||
|
folders=[],
|
||||||
|
is_initializing=True,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
return web.Response(
|
||||||
|
text=rendered,
|
||||||
|
content_type='text/html'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling models page: {e}", exc_info=True)
|
||||||
|
return web.Response(
|
||||||
|
text="Error loading models page",
|
||||||
|
status=500
|
||||||
|
)
|
||||||
|
|
||||||
async def get_models(self, request: web.Request) -> web.Response:
|
async def get_models(self, request: web.Request) -> web.Response:
|
||||||
"""Get paginated model data"""
|
"""Get paginated model data"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
import jinja2
|
|
||||||
import logging
|
import logging
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from .base_model_routes import BaseModelRoutes
|
from .base_model_routes import BaseModelRoutes
|
||||||
from ..services.checkpoint_service import CheckpointService
|
from ..services.checkpoint_service import CheckpointService
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..config import config
|
|
||||||
from ..services.settings_manager import settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -18,10 +15,7 @@ class CheckpointRoutes(BaseModelRoutes):
|
|||||||
# Service will be initialized later via setup_routes
|
# Service will be initialized later via setup_routes
|
||||||
self.service = None
|
self.service = None
|
||||||
self.civitai_client = None
|
self.civitai_client = None
|
||||||
self.template_env = jinja2.Environment(
|
self.template_name = "checkpoints.html"
|
||||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
|
||||||
autoescape=True
|
|
||||||
)
|
|
||||||
|
|
||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
@@ -37,76 +31,17 @@ class CheckpointRoutes(BaseModelRoutes):
|
|||||||
# Schedule service initialization on app startup
|
# Schedule service initialization on app startup
|
||||||
app.on_startup.append(lambda _: self.initialize_services())
|
app.on_startup.append(lambda _: self.initialize_services())
|
||||||
|
|
||||||
# Setup common routes with 'checkpoints' prefix
|
# Setup common routes with 'checkpoints' prefix (includes page route)
|
||||||
super().setup_routes(app, 'checkpoints')
|
super().setup_routes(app, 'checkpoints')
|
||||||
|
|
||||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||||
"""Setup Checkpoint-specific routes"""
|
"""Setup Checkpoint-specific routes"""
|
||||||
# Checkpoint page route
|
|
||||||
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
|
||||||
|
|
||||||
# Checkpoint-specific CivitAI integration
|
# Checkpoint-specific CivitAI integration
|
||||||
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
|
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
|
||||||
|
|
||||||
# Checkpoint info by name
|
# Checkpoint info by name
|
||||||
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
|
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
|
||||||
|
|
||||||
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle GET /checkpoints request"""
|
|
||||||
try:
|
|
||||||
# Check if the CheckpointScanner is initializing
|
|
||||||
# It's initializing if the cache object doesn't exist yet,
|
|
||||||
# OR if the scanner explicitly says it's initializing (background task running).
|
|
||||||
is_initializing = (
|
|
||||||
self.service.scanner._cache is None or
|
|
||||||
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_initializing:
|
|
||||||
# If still initializing, return loading page
|
|
||||||
template = self.template_env.get_template('checkpoints.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=[], # Empty folder list
|
|
||||||
is_initializing=True, # New flag
|
|
||||||
settings=settings, # Pass settings to template
|
|
||||||
request=request # Pass the request object to the template
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Checkpoints page is initializing, returning loading page")
|
|
||||||
else:
|
|
||||||
# Normal flow - get initialized cache data
|
|
||||||
try:
|
|
||||||
cache = await self.service.scanner.get_cached_data(force_refresh=False)
|
|
||||||
template = self.template_env.get_template('checkpoints.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=cache.folders,
|
|
||||||
is_initializing=False,
|
|
||||||
settings=settings, # Pass settings to template
|
|
||||||
request=request # Pass the request object to the template
|
|
||||||
)
|
|
||||||
except Exception as cache_error:
|
|
||||||
logger.error(f"Error loading checkpoints cache data: {cache_error}")
|
|
||||||
# If getting cache fails, also show initialization page
|
|
||||||
template = self.template_env.get_template('checkpoints.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=[],
|
|
||||||
is_initializing=True,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
logger.info("Checkpoints cache error, returning initialization page")
|
|
||||||
|
|
||||||
return web.Response(
|
|
||||||
text=rendered,
|
|
||||||
content_type='text/html'
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
|
|
||||||
return web.Response(
|
|
||||||
text="Error loading checkpoints page",
|
|
||||||
status=500
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
|
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
|
||||||
"""Get detailed information for a specific checkpoint by name"""
|
"""Get detailed information for a specific checkpoint by name"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import jinja2
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from server import PromptServer # type: ignore
|
from server import PromptServer # type: ignore
|
||||||
@@ -9,8 +7,6 @@ from server import PromptServer # type: ignore
|
|||||||
from .base_model_routes import BaseModelRoutes
|
from .base_model_routes import BaseModelRoutes
|
||||||
from ..services.lora_service import LoraService
|
from ..services.lora_service import LoraService
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.settings_manager import settings
|
|
||||||
from ..config import config
|
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
from ..utils.utils import get_lora_info
|
from ..utils.utils import get_lora_info
|
||||||
|
|
||||||
@@ -26,10 +22,7 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
self.civitai_client = None
|
self.civitai_client = None
|
||||||
self.download_manager = None
|
self.download_manager = None
|
||||||
self._download_lock = asyncio.Lock()
|
self._download_lock = asyncio.Lock()
|
||||||
self.template_env = jinja2.Environment(
|
self.template_name = "loras.html"
|
||||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
|
||||||
autoescape=True
|
|
||||||
)
|
|
||||||
|
|
||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
@@ -46,14 +39,11 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
# Schedule service initialization on app startup
|
# Schedule service initialization on app startup
|
||||||
app.on_startup.append(lambda _: self.initialize_services())
|
app.on_startup.append(lambda _: self.initialize_services())
|
||||||
|
|
||||||
# Setup common routes with 'loras' prefix
|
# Setup common routes with 'loras' prefix (includes page route)
|
||||||
super().setup_routes(app, 'loras')
|
super().setup_routes(app, 'loras')
|
||||||
|
|
||||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||||
"""Setup LoRA-specific routes"""
|
"""Setup LoRA-specific routes"""
|
||||||
# Lora page route
|
|
||||||
app.router.add_get('/loras', self.handle_loras_page)
|
|
||||||
|
|
||||||
# LoRA-specific query routes
|
# LoRA-specific query routes
|
||||||
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
|
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
|
||||||
app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
|
app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
|
||||||
@@ -113,61 +103,6 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle GET /loras request"""
|
|
||||||
try:
|
|
||||||
# Check if the LoraScanner is initializing
|
|
||||||
# It's initializing if the cache object doesn't exist yet,
|
|
||||||
# OR if the scanner explicitly says it's initializing (background task running).
|
|
||||||
is_initializing = (
|
|
||||||
self.service.scanner._cache is None or self.service.scanner.is_initializing()
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_initializing:
|
|
||||||
# If still initializing, return loading page
|
|
||||||
template = self.template_env.get_template('loras.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=[],
|
|
||||||
is_initializing=True,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Loras page is initializing, returning loading page")
|
|
||||||
else:
|
|
||||||
# Normal flow - get data from initialized cache
|
|
||||||
try:
|
|
||||||
cache = await self.service.scanner.get_cached_data(force_refresh=False)
|
|
||||||
template = self.template_env.get_template('loras.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=cache.folders,
|
|
||||||
is_initializing=False,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
except Exception as cache_error:
|
|
||||||
logger.error(f"Error loading cache data: {cache_error}")
|
|
||||||
template = self.template_env.get_template('loras.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=[],
|
|
||||||
is_initializing=True,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
logger.info("Cache error, returning initialization page")
|
|
||||||
|
|
||||||
return web.Response(
|
|
||||||
text=rendered,
|
|
||||||
content_type='text/html'
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error handling loras request: {e}", exc_info=True)
|
|
||||||
return web.Response(
|
|
||||||
text="Error loading loras page",
|
|
||||||
status=500
|
|
||||||
)
|
|
||||||
|
|
||||||
# LoRA-specific route handlers
|
# LoRA-specific route handlers
|
||||||
async def get_letter_counts(self, request: web.Request) -> web.Response:
|
async def get_letter_counts(self, request: web.Request) -> web.Response:
|
||||||
"""Get count of LoRAs for each letter of the alphabet"""
|
"""Get count of LoRAs for each letter of the alphabet"""
|
||||||
|
|||||||
Reference in New Issue
Block a user