refactor: Implement base model routes and services for LoRA and Checkpoint

- Added BaseModelRoutes class to handle common routes and logic for model types.
- Created CheckpointRoutes class inheriting from BaseModelRoutes for checkpoint-specific routes.
- Implemented CheckpointService class for handling checkpoint-related data and operations.
- Developed LoraService class for managing LoRA-specific functionalities.
- Introduced ModelServiceFactory to manage service and route registrations for different model types.
- Established methods for fetching, filtering, and formatting model data across services.
- Integrated CivitAI metadata handling within model routes and services.
- Added pagination and filtering capabilities for model data retrieval.
This commit is contained in:
Will Miao
2025-07-23 14:39:02 +08:00
parent ee609e8eac
commit a2b81ea099
15 changed files with 2185 additions and 1385 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,431 @@
from abc import ABC, abstractmethod
import json
import logging
from aiohttp import web
from typing import Dict
from ..utils.routes_common import ModelRouteUtils
from ..services.websocket_manager import ws_manager
logger = logging.getLogger(__name__)
class BaseModelRoutes(ABC):
"""Base route controller for all model types"""
def __init__(self, service):
"""Initialize the route controller
Args:
service: Model service instance (LoraService, CheckpointService, etc.)
"""
self.service = service
self.model_type = service.model_type
def setup_routes(self, app: web.Application, prefix: str):
"""Setup common routes for the model type
Args:
app: aiohttp application
prefix: URL prefix (e.g., 'loras', 'checkpoints')
"""
# Common model management routes
app.router.add_get(f'/api/{prefix}', self.get_models)
app.router.add_post(f'/api/{prefix}/delete', self.delete_model)
app.router.add_post(f'/api/{prefix}/exclude', self.exclude_model)
app.router.add_post(f'/api/{prefix}/fetch-civitai', self.fetch_civitai)
app.router.add_post(f'/api/{prefix}/relink-civitai', self.relink_civitai)
app.router.add_post(f'/api/{prefix}/replace-preview', self.replace_preview)
app.router.add_post(f'/api/{prefix}/save-metadata', self.save_metadata)
app.router.add_post(f'/api/{prefix}/rename', self.rename_model)
app.router.add_post(f'/api/{prefix}/bulk-delete', self.bulk_delete_models)
app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates)
# Common query routes
app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags)
app.router.add_get(f'/api/{prefix}/base-models', self.get_base_models)
app.router.add_get(f'/api/{prefix}/scan', self.scan_models)
app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots)
app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models)
app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts)
# CivitAI integration routes
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
# Setup model-specific routes
self.setup_specific_routes(app, prefix)
@abstractmethod
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup model-specific routes - to be implemented by subclasses"""
pass
async def get_models(self, request: web.Request) -> web.Response:
"""Get paginated model data"""
try:
# Parse common query parameters
params = self._parse_common_params(request)
# Get data from service
result = await self.service.get_paginated_data(**params)
# Format response items
formatted_result = {
'items': [await self.service.format_response(item) for item in result['items']],
'total': result['total'],
'page': result['page'],
'page_size': result['page_size'],
'total_pages': result['total_pages']
}
return web.json_response(formatted_result)
except Exception as e:
logger.error(f"Error in get_{self.model_type}s: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
def _parse_common_params(self, request: web.Request) -> Dict:
"""Parse common query parameters"""
# Parse basic pagination and sorting
page = int(request.query.get('page', '1'))
page_size = min(int(request.query.get('page_size', '20')), 100)
sort_by = request.query.get('sort_by', 'name')
folder = request.query.get('folder', None)
search = request.query.get('search', None)
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
# Parse filter arrays
base_models = request.query.getall('base_model', [])
tags = request.query.getall('tag', [])
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true'
# Parse search options
search_options = {
'filename': request.query.get('search_filename', 'true').lower() == 'true',
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
'tags': request.query.get('search_tags', 'false').lower() == 'true',
'recursive': request.query.get('recursive', 'false').lower() == 'true',
}
# Parse hash filters if provided
hash_filters = {}
if 'hash' in request.query:
hash_filters['single_hash'] = request.query['hash']
elif 'hashes' in request.query:
try:
hash_list = json.loads(request.query['hashes'])
if isinstance(hash_list, list):
hash_filters['multiple_hashes'] = hash_list
except (json.JSONDecodeError, TypeError):
pass
return {
'page': page,
'page_size': page_size,
'sort_by': sort_by,
'folder': folder,
'search': search,
'fuzzy_search': fuzzy_search,
'base_models': base_models,
'tags': tags,
'search_options': search_options,
'hash_filters': hash_filters,
'favorites_only': favorites_only,
# Add model-specific parameters
**self._parse_specific_params(request)
}
def _parse_specific_params(self, request: web.Request) -> Dict:
"""Parse model-specific parameters - to be overridden by subclasses"""
return {}
# Common route handlers
async def delete_model(self, request: web.Request) -> web.Response:
"""Handle model deletion request"""
return await ModelRouteUtils.handle_delete_model(request, self.service.scanner)
async def exclude_model(self, request: web.Request) -> web.Response:
"""Handle model exclusion request"""
return await ModelRouteUtils.handle_exclude_model(request, self.service.scanner)
async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata fetch request"""
response = await ModelRouteUtils.handle_fetch_civitai(request, self.service.scanner)
# If successful, format the metadata before returning
if response.status == 200:
data = json.loads(response.body.decode('utf-8'))
if data.get("success") and data.get("metadata"):
formatted_metadata = await self.service.format_response(data["metadata"])
return web.json_response({
"success": True,
"metadata": formatted_metadata
})
return response
async def relink_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata re-linking request"""
return await ModelRouteUtils.handle_relink_civitai(request, self.service.scanner)
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement"""
return await ModelRouteUtils.handle_replace_preview(request, self.service.scanner)
async def save_metadata(self, request: web.Request) -> web.Response:
"""Handle saving metadata updates"""
return await ModelRouteUtils.handle_save_metadata(request, self.service.scanner)
async def rename_model(self, request: web.Request) -> web.Response:
"""Handle renaming a model file and its associated files"""
return await ModelRouteUtils.handle_rename_model(request, self.service.scanner)
async def bulk_delete_models(self, request: web.Request) -> web.Response:
"""Handle bulk deletion of models"""
return await ModelRouteUtils.handle_bulk_delete_models(request, self.service.scanner)
async def verify_duplicates(self, request: web.Request) -> web.Response:
"""Handle verification of duplicate model hashes"""
return await ModelRouteUtils.handle_verify_duplicates(request, self.service.scanner)
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
limit = int(request.query.get('limit', '20'))
if limit < 1 or limit > 100:
limit = 20
top_tags = await self.service.get_top_tags(limit)
return web.json_response({
'success': True,
'tags': top_tags
})
except Exception as e:
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
return web.json_response({
'success': False,
'error': 'Internal server error'
}, status=500)
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in models"""
try:
limit = int(request.query.get('limit', '20'))
if limit < 1 or limit > 100:
limit = 20
base_models = await self.service.get_base_models(limit)
return web.json_response({
'success': True,
'base_models': base_models
})
except Exception as e:
logger.error(f"Error retrieving base models: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def scan_models(self, request: web.Request) -> web.Response:
"""Force a rescan of model files"""
try:
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
await self.service.scan_models(force_refresh=True, rebuild_cache=full_rebuild)
return web.json_response({
"status": "success",
"message": f"{self.model_type.capitalize()} scan completed"
})
except Exception as e:
logger.error(f"Error in scan_{self.model_type}s: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_model_roots(self, request: web.Request) -> web.Response:
"""Return the model root directories"""
try:
roots = self.service.get_model_roots()
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting {self.model_type} roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def find_duplicate_models(self, request: web.Request) -> web.Response:
"""Find models with duplicate SHA256 hashes"""
try:
# Get duplicate hashes from service
duplicates = self.service.find_duplicate_hashes()
# Format the response
result = []
cache = await self.service.scanner.get_cached_data()
for sha256, paths in duplicates.items():
group = {
"hash": sha256,
"models": []
}
# Find matching models for each path
for path in paths:
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
if model:
group["models"].append(await self.service.format_response(model))
# Add the primary model too
primary_path = self.service.get_path_by_hash(sha256)
if primary_path and primary_path not in paths:
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
if primary_model:
group["models"].insert(0, await self.service.format_response(primary_model))
if len(group["models"]) > 1: # Only include if we found multiple models
result.append(group)
return web.json_response({
"success": True,
"duplicates": result,
"count": len(result)
})
except Exception as e:
logger.error(f"Error finding duplicate {self.model_type}s: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
"""Find models with conflicting filenames"""
try:
# Get duplicate filenames from service
duplicates = self.service.find_duplicate_filenames()
# Format the response
result = []
cache = await self.service.scanner.get_cached_data()
for filename, paths in duplicates.items():
group = {
"filename": filename,
"models": []
}
# Find matching models for each path
for path in paths:
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
if model:
group["models"].append(await self.service.format_response(model))
# Find the model from the main index too
hash_val = self.service.scanner._hash_index.get_hash_by_filename(filename)
if hash_val:
main_path = self.service.get_path_by_hash(hash_val)
if main_path and main_path not in paths:
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
if main_model:
group["models"].insert(0, await self.service.format_response(main_model))
if group["models"]:
result.append(group)
return web.json_response({
"success": True,
"conflicts": result,
"count": len(result)
})
except Exception as e:
logger.error(f"Error finding filename conflicts for {self.model_type}s: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all models in the background"""
try:
cache = await self.service.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
# Prepare models to process
to_process = [
model for model in cache.raw_data
if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) and model.get('from_civitai', True)
]
total_to_process = len(to_process)
# Send initial progress
await ws_manager.broadcast({
'status': 'started',
'total': total_to_process,
'processed': 0,
'success': 0
})
# Process each model
for model in to_process:
try:
original_name = model.get('model_name')
if await ModelRouteUtils.fetch_and_update_model(
sha256=model['sha256'],
file_path=model['file_path'],
model_data=model,
update_cache_func=self.service.scanner.update_single_model_cache
):
success += 1
if original_name != model.get('model_name'):
needs_resort = True
processed += 1
# Send progress update
await ws_manager.broadcast({
'status': 'processing',
'total': total_to_process,
'processed': processed,
'success': success,
'current_name': model.get('model_name', 'Unknown')
})
except Exception as e:
logger.error(f"Error fetching CivitAI data for {model['file_path']}: {e}")
if needs_resort:
await cache.resort(name_only=True)
# Send completion message
await ws_manager.broadcast({
'status': 'completed',
'total': total_to_process,
'processed': processed,
'success': success
})
return web.json_response({
"success": True,
"message": f"Successfully updated {success} of {processed} processed {self.model_type}s (total: {total})"
})
except Exception as e:
# Send error message
await ws_manager.broadcast({
'status': 'error',
'error': str(e)
})
logger.error(f"Error in fetch_all_civitai for {self.model_type}s: {e}")
return web.Response(text=str(e), status=500)
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai model with local availability info"""
# This will be implemented by subclasses as they need CivitAI client access
return web.json_response({
"error": "Not implemented in base class"
}, status=501)

View File

@@ -0,0 +1,170 @@
import jinja2
import logging
from aiohttp import web
from .base_model_routes import BaseModelRoutes
from ..services.checkpoint_service import CheckpointService
from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
logger = logging.getLogger(__name__)
class CheckpointRoutes(BaseModelRoutes):
"""Checkpoint-specific route controller"""
def __init__(self):
"""Initialize Checkpoint routes with Checkpoint service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
self.service = CheckpointService(checkpoint_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Initialize parent with the service
super().__init__(self.service)
def setup_routes(self, app: web.Application):
"""Setup Checkpoint routes"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
# Setup common routes with 'checkpoints' prefix
super().setup_routes(app, 'checkpoints')
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup Checkpoint-specific routes"""
# Checkpoint page route
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
# Checkpoint-specific CivitAI integration
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
# Checkpoint info by name
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
"""Handle GET /checkpoints request"""
try:
# Check if the CheckpointScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = (
self.service.scanner._cache is None or
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
)
if is_initializing:
# If still initializing, return loading page
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[], # Empty folder list
is_initializing=True, # New flag
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
logger.info("Checkpoints page is initializing, returning loading page")
else:
# Normal flow - get initialized cache data
try:
cache = await self.service.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
except Exception as cache_error:
logger.error(f"Error loading checkpoints cache data: {cache_error}")
# If getting cache fails, also show initialization page
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Checkpoints cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
return web.Response(
text="Error loading checkpoints page",
status=500
)
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific checkpoint by name"""
try:
name = request.match_info.get('name', '')
checkpoint_info = await self.service.get_model_info_by_name(name)
if checkpoint_info:
return web.json_response(checkpoint_info)
else:
return web.json_response({"error": "Checkpoint not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai checkpoint model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be Checkpoint
if model_type.lower() != 'checkpoint':
return web.json_response({
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching checkpoint model versions: {e}")
return web.Response(status=500, text=str(e))

View File

@@ -1,76 +1,126 @@
import jinja2
import asyncio
import logging
import os
from aiohttp import web
import jinja2
from typing import Dict
import logging
from ..config import config
from server import PromptServer # type: ignore
from .base_model_routes import BaseModelRoutes
from ..services.lora_service import LoraService
from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
from ..config import config
from ..utils.routes_common import ModelRouteUtils
from ..utils.utils import get_lora_info
logger = logging.getLogger(__name__)
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
class LoraRoutes:
"""Route handlers for LoRA management endpoints"""
class LoraRoutes(BaseModelRoutes):
"""LoRA-specific route controller"""
def __init__(self):
# Initialize service references as None, will be set during async init
self.scanner = None
self.recipe_scanner = None
"""Initialize LoRA routes with LoRA service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.download_manager = None
self._download_lock = asyncio.Lock()
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
async def init_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_lora_scanner()
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
def format_lora_data(self, lora: Dict) -> Dict:
"""Format LoRA data for template rendering"""
return {
"model_name": lora["model_name"],
"file_name": lora["file_name"],
"preview_url": config.get_preview_static_url(lora["preview_url"]),
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
"base_model": lora["base_model"],
"folder": lora["folder"],
"sha256": lora["sha256"],
"file_path": lora["file_path"].replace(os.sep, "/"),
"size": lora["size"],
"tags": lora["tags"],
"modelDescription": lora["modelDescription"],
"usage_tips": lora["usage_tips"],
"notes": lora["notes"],
"modified": lora["modified"],
"from_civitai": lora.get("from_civitai", True),
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
}
def _filter_civitai_data(self, data: Dict) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
lora_scanner = await ServiceRegistry.get_lora_scanner()
self.service = LoraService(lora_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
self.download_manager = await ServiceRegistry.get_download_manager()
# Initialize parent with the service
super().__init__(self.service)
def setup_routes(self, app: web.Application):
"""Setup LoRA routes"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
# Setup common routes with 'loras' prefix
super().setup_routes(app, 'loras')
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup LoRA-specific routes"""
# Lora page route
app.router.add_get('/loras', self.handle_loras_page)
# LoRA-specific query routes
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
app.router.add_get(f'/api/lora-preview-url', self.get_lora_preview_url)
app.router.add_get(f'/api/lora-civitai-url', self.get_lora_civitai_url)
app.router.add_get(f'/api/lora-model-description', self.get_lora_model_description)
app.router.add_get(f'/api/folders', self.get_folders)
app.router.add_get(f'/api/lora-roots', self.get_lora_roots)
# LoRA-specific management routes
app.router.add_post(f'/api/move_model', self.move_model)
app.router.add_post(f'/api/move_models_bulk', self.move_models_bulk)
# CivitAI integration with LoRA-specific validation
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_lora)
app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
# Download management
app.router.add_post(f'/api/download-model', self.download_model)
app.router.add_get(f'/api/download-model-get', self.download_model_get)
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
# ComfyUI integration
app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words)
# Legacy API compatibility
app.router.add_post(f'/api/delete_model', self.delete_model)
app.router.add_post(f'/api/fetch-civitai', self.fetch_civitai)
app.router.add_post(f'/api/relink-civitai', self.relink_civitai)
app.router.add_post(f'/api/replace_preview', self.replace_preview)
app.router.add_post(f'/api/fetch-all-civitai', self.fetch_all_civitai)
def _parse_specific_params(self, request: web.Request) -> Dict:
"""Parse LoRA-specific parameters"""
params = {}
# LoRA-specific parameters
if 'first_letter' in request.query:
params['first_letter'] = request.query.get('first_letter')
# Handle fuzzy search parameter name variation
if request.query.get('fuzzy') == 'true':
params['fuzzy_search'] = True
# Handle additional filter parameters for LoRAs
if 'lora_hash' in request.query:
if not params.get('hash_filters'):
params['hash_filters'] = {}
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
elif 'lora_hashes' in request.query:
if not params.get('hash_filters'):
params['hash_filters'] = {}
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
return params
async def handle_loras_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras request"""
try:
# Ensure services are initialized
await self.init_services()
# Check if the LoraScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = (
self.scanner._cache is None or self.scanner.is_initializing()
self.service.scanner._cache is None or self.service.scanner.is_initializing()
)
if is_initializing:
@@ -87,7 +137,7 @@ class LoraRoutes:
else:
# Normal flow - get data from initialized cache
try:
cache = await self.scanner.get_cached_data(force_refresh=False)
cache = await self.service.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=cache.folders,
@@ -117,72 +167,561 @@ class LoraRoutes:
text="Error loading loras page",
status=500
)
async def handle_recipes_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras/recipes request"""
# LoRA-specific route handlers
async def get_letter_counts(self, request: web.Request) -> web.Response:
"""Get count of LoRAs for each letter of the alphabet"""
try:
# Ensure services are initialized
await self.init_services()
letter_counts = await self.service.get_letter_counts()
return web.json_response({
'success': True,
'letter_counts': letter_counts
})
except Exception as e:
logger.error(f"Error getting letter counts: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_notes(self, request: web.Request) -> web.Response:
"""Get notes for a specific LoRA file"""
try:
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
# Skip initialization check and directly try to get cached data
try:
# Recipe scanner will initialize cache if needed
await self.recipe_scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('recipes.html')
rendered = template.render(
recipes=[], # Frontend will load recipes via API
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading recipe cache data: {cache_error}")
# Still keep error handling - show initializing page on error
template = self.template_env.get_template('recipes.html')
rendered = template.render(
is_initializing=True,
settings=settings,
request=request
)
logger.info("Recipe cache error, returning initialization page")
notes = await self.service.get_lora_notes(lora_name)
if notes is not None:
return web.json_response({
'success': True,
'notes': notes
})
else:
return web.json_response({
'success': False,
'error': 'LoRA not found in cache'
}, status=404)
except Exception as e:
logger.error(f"Error getting lora notes: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for a specific LoRA file"""
try:
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
return web.Response(
text=rendered,
content_type='text/html'
)
trigger_words = await self.service.get_lora_trigger_words(lora_name)
return web.json_response({
'success': True,
'trigger_words': trigger_words
})
except Exception as e:
logger.error(f"Error handling recipes request: {e}", exc_info=True)
return web.Response(
text="Error loading recipes page",
status=500
)
def _format_recipe_file_url(self, file_path: str) -> str:
"""Format file path for recipe image as a URL - same as in recipe_routes"""
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
"""Get the static preview URL for a LoRA file"""
try:
# Return the file URL directly for the first lora root's preview
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
if file_path.replace(os.sep, '/').startswith(recipes_dir):
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
return f"/loras_static/root1/preview/{relative_path}"
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
# If not in recipes dir, try to create a valid URL from the file path
preview_url = await self.service.get_lora_preview_url(lora_name)
if preview_url:
return web.json_response({
'success': True,
'preview_url': preview_url
})
else:
return web.json_response({
'success': False,
'error': 'No preview URL found for the specified lora'
}, status=404)
except Exception as e:
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
"""Get the Civitai URL for a LoRA file"""
try:
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
result = await self.service.get_lora_civitai_url(lora_name)
if result['civitai_url']:
return web.json_response({
'success': True,
**result
})
else:
return web.json_response({
'success': False,
'error': 'No Civitai data found for the specified lora'
}, status=404)
except Exception as e:
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_folders(self, request: web.Request) -> web.Response:
"""Get all folders in the cache"""
try:
cache = await self.service.scanner.get_cached_data()
return web.json_response({
'folders': cache.folders
})
except Exception as e:
logger.error(f"Error getting folders: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_roots(self, request: web.Request) -> web.Response:
"""Get all configured LoRA root directories"""
try:
return web.json_response({
'roots': self.service.get_model_roots()
})
except Exception as e:
logger.error(f"Error getting LoRA roots: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
# Override get_models to add LoRA-specific response data
async def get_models(self, request: web.Request) -> web.Response:
"""Get paginated LoRA data with LoRA-specific fields"""
try:
# Parse common query parameters
params = self._parse_common_params(request)
# Get data from service
result = await self.service.get_paginated_data(**params)
# Get all available folders from cache for LoRA-specific response
cache = await self.service.scanner.get_cached_data()
# Format response items with LoRA-specific structure
formatted_result = {
'items': [await self.service.format_response(item) for item in result['items']],
'folders': cache.folders, # LoRA-specific: include folders in response
'total': result['total'],
'page': result['page'],
'page_size': result['page_size'],
'total_pages': result['total_pages']
}
return web.json_response(formatted_result)
except Exception as e:
logger.error(f"Error in get_loras: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
# CivitAI integration methods
async def get_civitai_versions_lora(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai LoRA model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be LORA, LoCon, or DORA
from ..utils.constants import VALID_LORA_TYPES
if model_type.lower() not in VALID_LORA_TYPES:
return web.json_response({
'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the model file (type="Model") in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching LoRA model versions: {e}")
return web.Response(status=500, text=str(e))
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID"""
try:
model_version_id = request.match_info.get('modelVersionId')
# Get model details from Civitai API
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
if not model:
# Log warning for failed model retrieval
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
# Determine status code based on error message
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
return web.json_response({
"success": False,
"error": error_msg or "Failed to fetch model information"
}, status=status_code)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by hash"""
try:
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details by hash: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
# Download management methods
async def download_model(self, request: web.Request) -> web.Response:
"""Handle model download request"""
return await ModelRouteUtils.handle_download_model(request, self.download_manager)
async def download_model_get(self, request: web.Request) -> web.Response:
"""Handle model download request via GET method"""
try:
# Extract query parameters
model_id = request.query.get('model_id')
if not model_id:
return web.Response(
status=400,
text="Missing required parameter: Please provide 'model_id'"
)
# Get optional parameters
model_version_id = request.query.get('model_version_id')
download_id = request.query.get('download_id')
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
# Create a data dictionary that mimics what would be received from a POST request
data = {
'model_id': model_id
}
# Add optional parameters only if they are provided
if model_version_id:
data['model_version_id'] = model_version_id
if download_id:
data['download_id'] = download_id
data['use_default_paths'] = use_default_paths
# Create a mock request object with the data
future = asyncio.get_event_loop().create_future()
future.set_result(data)
mock_request = type('MockRequest', (), {
'json': lambda self=None: future
})()
# Call the existing download handler
return await ModelRouteUtils.handle_download_model(mock_request, self.download_manager)
except Exception as e:
error_message = str(e)
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
return web.Response(status=500, text=error_message)
async def cancel_download_get(self, request: web.Request) -> web.Response:
"""Handle GET request for cancelling a download by download_id"""
try:
download_id = request.query.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
# Create a mock request with match_info for compatibility
mock_request = type('MockRequest', (), {
'match_info': {'download_id': download_id}
})()
return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager)
except Exception as e:
logger.error(f"Error cancelling download via GET: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_download_progress(self, request: web.Request) -> web.Response:
"""Handle request for download progress by download_id"""
try:
# Get download_id from URL path
download_id = request.match_info.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
# Get progress information from websocket manager
from ..services.websocket_manager import ws_manager
progress_data = ws_manager.get_download_progress(download_id)
if progress_data is None:
return web.json_response({
'success': False,
'error': 'Download ID not found'
}, status=404)
return web.json_response({
'success': True,
'progress': progress_data.get('progress', 0)
})
except Exception as e:
logger.error(f"Error getting download progress: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
# Model management methods
async def move_model(self, request: web.Request) -> web.Response:
"""Handle model move request"""
try:
data = await request.json()
file_path = data.get('file_path') # full path of the model file
target_path = data.get('target_path') # folder path to move the model to
if not file_path or not target_path:
return web.Response(text='File path and target path are required', status=400)
# Check if source and destination are the same
import os
source_dir = os.path.dirname(file_path)
if os.path.normpath(source_dir) == os.path.normpath(target_path):
logger.info(f"Source and target directories are the same: {source_dir}")
return web.json_response({'success': True, 'message': 'Source and target directories are the same'})
# Check if target file already exists
file_name = os.path.basename(file_path)
return f"/loras_static/root1/preview/recipes/{file_name}"
except Exception as e:
logger.error(f"Error formatting recipe file URL: {e}", exc_info=True)
return '/loras_static/images/no-preview.png' # Return default image on error
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
if os.path.exists(target_file_path):
return web.json_response({
'success': False,
'error': f"Target file already exists: {target_file_path}"
}, status=409) # 409 Conflict
def setup_routes(self, app: web.Application):
"""Register routes with the application"""
# Add an app startup handler to initialize services
app.on_startup.append(self._on_startup)
# Register routes
app.router.add_get('/loras', self.handle_loras_page)
app.router.add_get('/loras/recipes', self.handle_recipes_page)
async def _on_startup(self, app):
"""Initialize services when the app starts"""
await self.init_services()
# Call scanner to handle the move operation
success = await self.service.scanner.move_model(file_path, target_path)
if success:
return web.json_response({'success': True})
else:
return web.Response(text='Failed to move model', status=500)
except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
async def move_models_bulk(self, request: web.Request) -> web.Response:
"""Handle bulk model move request"""
try:
data = await request.json()
file_paths = data.get('file_paths', []) # list of full paths of the model files
target_path = data.get('target_path') # folder path to move the models to
if not file_paths or not target_path:
return web.Response(text='File paths and target path are required', status=400)
results = []
import os
for file_path in file_paths:
# Check if source and destination are the same
source_dir = os.path.dirname(file_path)
if os.path.normpath(source_dir) == os.path.normpath(target_path):
results.append({
"path": file_path,
"success": True,
"message": "Source and target directories are the same"
})
continue
# Check if target file already exists
file_name = os.path.basename(file_path)
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
if os.path.exists(target_file_path):
results.append({
"path": file_path,
"success": False,
"message": f"Target file already exists: {target_file_path}"
})
continue
# Try to move the model
success = await self.service.scanner.move_model(file_path, target_path)
results.append({
"path": file_path,
"success": success,
"message": "Success" if success else "Failed to move model"
})
# Count successes and failures
success_count = sum(1 for r in results if r["success"])
failure_count = len(results) - success_count
return web.json_response({
'success': True,
'message': f'Moved {success_count} of {len(file_paths)} models',
'results': results,
'success_count': success_count,
'failure_count': failure_count
})
except Exception as e:
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
async def get_lora_model_description(self, request: web.Request) -> web.Response:
"""Get model description for a Lora model"""
try:
# Get parameters
model_id = request.query.get('model_id')
file_path = request.query.get('file_path')
if not model_id:
return web.json_response({
'success': False,
'error': 'Model ID is required'
}, status=400)
# Check if we already have the description stored in metadata
description = None
tags = []
creator = {}
if file_path:
import os
from ..utils.metadata_manager import MetadataManager
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
description = metadata.get('modelDescription')
tags = metadata.get('tags', [])
creator = metadata.get('creator', {})
# If description is not in metadata, fetch from CivitAI
if not description:
logger.info(f"Fetching model metadata for model ID: {model_id}")
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
if model_metadata:
description = model_metadata.get('description')
tags = model_metadata.get('tags', [])
creator = model_metadata.get('creator', {})
# Save the metadata to file if we have a file path and got metadata
if file_path:
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
metadata['modelDescription'] = description
metadata['tags'] = tags
# Ensure the civitai dict exists
if 'civitai' not in metadata:
metadata['civitai'] = {}
# Store creator in the civitai nested structure
metadata['civitai']['creator'] = creator
await MetadataManager.save_metadata(file_path, metadata, True)
except Exception as e:
logger.error(f"Error saving model metadata: {e}")
return web.json_response({
'success': True,
'description': description or "<p>No model description available.</p>",
'tags': tags,
'creator': creator
})
except Exception as e:
logger.error(f"Error getting model metadata: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for specified LoRA models"""
try:
json_data = await request.json()
lora_names = json_data.get("lora_names", [])
node_ids = json_data.get("node_ids", [])
all_trigger_words = []
for lora_name in lora_names:
_, trigger_words = get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Format the trigger words
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Send update to all connected trigger word toggle nodes
for node_id in node_ids:
PromptServer.instance.send_sync("trigger_word_update", {
"id": node_id,
"message": trigger_words_text
})
return web.json_response({"success": True})
except Exception as e:
logger.error(f"Error getting trigger words: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)

View File

@@ -633,9 +633,8 @@ class MiscRoutes:
}, status=400)
# Get both lora and checkpoint scanners
registry = ServiceRegistry.get_instance()
lora_scanner = await registry.get_lora_scanner()
checkpoint_scanner = await registry.get_checkpoint_scanner()
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
# If modelVersionId is provided, check for specific version
if model_version_id_str:

View File

@@ -1,6 +1,7 @@
import os
import time
import base64
import jinja2
import numpy as np
from PIL import Image
import io
@@ -15,6 +16,7 @@ from ..utils.exif_utils import ExifUtils
from ..recipes import RecipeParserFactory
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..services.settings_manager import settings
from ..config import config
# Check if running in standalone mode
@@ -39,7 +41,10 @@ class RecipeRoutes:
# Initialize service references as None, will be set during async init
self.recipe_scanner = None
self.civitai_client = None
# Remove WorkflowParser instance
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
# Pre-warm the cache
self._init_cache_task = None
@@ -53,6 +58,8 @@ class RecipeRoutes:
def setup_routes(cls, app: web.Application):
"""Register API routes"""
routes = cls()
app.router.add_get('/loras/recipes', routes.handle_recipes_page)
app.router.add_get('/api/recipes', routes.get_recipes)
app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail)
app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image)
@@ -114,6 +121,46 @@ class RecipeRoutes:
await self.recipe_scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True)
async def handle_recipes_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras/recipes request"""
try:
# Ensure services are initialized
await self.init_services()
# Skip initialization check and directly try to get cached data
try:
# Recipe scanner will initialize cache if needed
await self.recipe_scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('recipes.html')
rendered = template.render(
recipes=[], # Frontend will load recipes via API
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading recipe cache data: {cache_error}")
# Still keep error handling - show initializing page on error
template = self.template_env.get_template('recipes.html')
rendered = template.render(
is_initializing=True,
settings=settings,
request=request
)
logger.info("Recipe cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling recipes request: {e}", exc_info=True)
return web.Response(
text="Error loading recipes page",
status=500
)
async def get_recipes(self, request: web.Request) -> web.Response:
"""API endpoint for getting paginated recipes"""