Implement paginated LoRA data fetching with caching and infinite scroll

This commit is contained in:
Will Miao
2025-02-02 22:57:22 +08:00
parent 118c4521a2
commit 4b247995d1
5 changed files with 429 additions and 127 deletions

View File

@@ -6,12 +6,16 @@ from typing import Dict, List
from ..services.civitai_client import CivitaiClient
from ..utils.file_utils import update_civitai_metadata, load_metadata
from ..config import config
from ..services.lora_scanner import LoraScanner
logger = logging.getLogger(__name__)
class ApiRoutes:
"""API route handlers for LoRA management"""
def __init__(self):
self.scanner = LoraScanner()
@classmethod
def setup_routes(cls, app: web.Application):
"""Register API routes"""
@@ -19,6 +23,7 @@ class ApiRoutes:
app.router.add_post('/api/delete_model', routes.delete_model)
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
app.router.add_post('/api/replace_preview', routes.replace_preview)
app.router.add_get('/api/loras', routes.get_loras)
async def delete_model(self, request: web.Request) -> web.Response:
"""Handle model deletion request"""
@@ -88,6 +93,86 @@ class ApiRoutes:
logger.error(f"Error replacing preview: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
async def get_loras(self, request: web.Request) -> web.Response:
"""Handle paginated LoRA data request"""
try:
# Parse query parameters
page = int(request.query.get('page', '1'))
page_size = int(request.query.get('page_size', '20'))
sort_by = request.query.get('sort_by', 'name')
folder = request.query.get('folder')
# Validate parameters
if page < 1 or page_size < 1 or page_size > 100:
return web.json_response({
'error': 'Invalid pagination parameters'
}, status=400)
if sort_by not in ['date', 'name']:
return web.json_response({
'error': 'Invalid sort parameter'
}, status=400)
# Get paginated data
result = await self.scanner.get_paginated_data(
page=page,
page_size=page_size,
sort_by=sort_by,
folder=folder
)
# Format the response data
formatted_items = [
self._format_lora_response(item)
for item in result['items']
]
return web.json_response({
'items': formatted_items,
'total': result['total'],
'page': result['page'],
'page_size': result['page_size'],
'total_pages': result['total_pages']
})
except ValueError as e:
return web.json_response({
'error': 'Invalid parameters'
}, status=400)
except Exception as e:
logger.error(f"Error fetching loras: {e}", exc_info=True)
return web.json_response({
'error': 'Internal server error'
}, status=500)
def _format_lora_response(self, lora: Dict) -> Dict:
"""Format LoRA data for API response"""
return {
"model_name": lora["model_name"],
"file_name": lora["file_name"],
"preview_url": config.get_preview_static_url(lora["preview_url"]),
"base_model": lora["base_model"],
"folder": lora["folder"],
"sha256": lora["sha256"],
"file_path": lora["file_path"].replace(os.sep, "/"),
"modified": lora["modified"],
"from_civitai": lora.get("from_civitai", True),
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
}
def _filter_civitai_data(self, data: Dict) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
# Private helper methods
async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]:
"""Delete model and associated files"""

View File

@@ -48,18 +48,27 @@ class LoraRoutes:
async def handle_loras_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras request"""
try:
# Scan for loras
loras = await self.scanner.scan_all_loras()
# Get cached data
cache = await self.scanner.get_cached_data()
# Format data for template
formatted_loras = [self.format_lora_data(l) for l in loras]
folders = sorted(list(set(l['folder'] for l in loras)))
# Format initial data (first page only)
initial_data = await self.scanner.get_paginated_data(
page=1,
page_size=20,
sort_by='name'
)
formatted_loras = [
self.format_lora_data(l)
for l in initial_data['items']
]
# Render template
template = self.template_env.get_template('loras.html')
rendered = template.render(
loras=formatted_loras,
folders=folders
folders=cache.folders,
total_items=initial_data['total']
)
return web.Response(