mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
refactor: Implement base model routes and services for LoRA and Checkpoint
- Added BaseModelRoutes class to handle common routes and logic for model types. - Created CheckpointRoutes class inheriting from BaseModelRoutes for checkpoint-specific routes. - Implemented CheckpointService class for handling checkpoint-related data and operations. - Developed LoraService class for managing LoRA-specific functionalities. - Introduced ModelServiceFactory to manage service and route registrations for different model types. - Established methods for fetching, filtering, and formatting model data across services. - Integrated CivitAI metadata handling within model routes and services. - Added pagination and filtering capabilities for model data retrieval.
This commit is contained in:
@@ -6,10 +6,8 @@ from pathlib import Path
|
|||||||
from server import PromptServer # type: ignore
|
from server import PromptServer # type: ignore
|
||||||
|
|
||||||
from .config import config
|
from .config import config
|
||||||
from .routes.lora_routes import LoraRoutes
|
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||||
from .routes.api_routes import ApiRoutes
|
|
||||||
from .routes.recipe_routes import RecipeRoutes
|
from .routes.recipe_routes import RecipeRoutes
|
||||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
|
||||||
from .routes.stats_routes import StatsRoutes
|
from .routes.stats_routes import StatsRoutes
|
||||||
from .routes.update_routes import UpdateRoutes
|
from .routes.update_routes import UpdateRoutes
|
||||||
from .routes.misc_routes import MiscRoutes
|
from .routes.misc_routes import MiscRoutes
|
||||||
@@ -17,6 +15,7 @@ from .routes.example_images_routes import ExampleImagesRoutes
|
|||||||
from .services.service_registry import ServiceRegistry
|
from .services.service_registry import ServiceRegistry
|
||||||
from .services.settings_manager import settings
|
from .services.settings_manager import settings
|
||||||
from .utils.example_images_migration import ExampleImagesMigration
|
from .utils.example_images_migration import ExampleImagesMigration
|
||||||
|
from .services.websocket_manager import ws_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -28,7 +27,7 @@ class LoraManager:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_routes(cls):
|
def add_routes(cls):
|
||||||
"""Initialize and register all routes"""
|
"""Initialize and register all routes using the new refactored architecture"""
|
||||||
app = PromptServer.instance.app
|
app = PromptServer.instance.app
|
||||||
|
|
||||||
# Configure aiohttp access logger to be less verbose
|
# Configure aiohttp access logger to be less verbose
|
||||||
@@ -110,27 +109,32 @@ class LoraManager:
|
|||||||
# Add static route for plugin assets
|
# Add static route for plugin assets
|
||||||
app.router.add_static('/loras_static', config.static_path)
|
app.router.add_static('/loras_static', config.static_path)
|
||||||
|
|
||||||
# Setup feature routes
|
# Register default model types with the factory
|
||||||
lora_routes = LoraRoutes()
|
register_default_model_types()
|
||||||
checkpoints_routes = CheckpointsRoutes()
|
|
||||||
stats_routes = StatsRoutes()
|
|
||||||
|
|
||||||
# Initialize routes
|
# Setup all model routes using the factory
|
||||||
lora_routes.setup_routes(app)
|
ModelServiceFactory.setup_all_routes(app)
|
||||||
checkpoints_routes.setup_routes(app)
|
|
||||||
stats_routes.setup_routes(app) # Add statistics routes
|
# Setup non-model-specific routes
|
||||||
ApiRoutes.setup_routes(app)
|
stats_routes = StatsRoutes()
|
||||||
|
stats_routes.setup_routes(app)
|
||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
UpdateRoutes.setup_routes(app)
|
UpdateRoutes.setup_routes(app)
|
||||||
MiscRoutes.setup_routes(app) # Register miscellaneous routes
|
MiscRoutes.setup_routes(app)
|
||||||
ExampleImagesRoutes.setup_routes(app) # Register example images routes
|
ExampleImagesRoutes.setup_routes(app)
|
||||||
|
|
||||||
|
# Setup WebSocket routes that are shared across all model types
|
||||||
|
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||||
|
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||||
|
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||||
|
|
||||||
# Schedule service initialization
|
# Schedule service initialization
|
||||||
app.on_startup.append(lambda app: cls._initialize_services())
|
app.on_startup.append(lambda app: cls._initialize_services())
|
||||||
|
|
||||||
# Add cleanup
|
# Add cleanup
|
||||||
app.on_shutdown.append(cls._cleanup)
|
app.on_shutdown.append(cls._cleanup)
|
||||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
|
||||||
|
logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _initialize_services(cls):
|
async def _initialize_services(cls):
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
431
py/routes/base_model_routes.py
Normal file
431
py/routes/base_model_routes.py
Normal file
@@ -0,0 +1,431 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from aiohttp import web
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
|
from ..services.websocket_manager import ws_manager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class BaseModelRoutes(ABC):
|
||||||
|
"""Base route controller for all model types"""
|
||||||
|
|
||||||
|
def __init__(self, service):
|
||||||
|
"""Initialize the route controller
|
||||||
|
|
||||||
|
Args:
|
||||||
|
service: Model service instance (LoraService, CheckpointService, etc.)
|
||||||
|
"""
|
||||||
|
self.service = service
|
||||||
|
self.model_type = service.model_type
|
||||||
|
|
||||||
|
def setup_routes(self, app: web.Application, prefix: str):
|
||||||
|
"""Setup common routes for the model type
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: aiohttp application
|
||||||
|
prefix: URL prefix (e.g., 'loras', 'checkpoints')
|
||||||
|
"""
|
||||||
|
# Common model management routes
|
||||||
|
app.router.add_get(f'/api/{prefix}', self.get_models)
|
||||||
|
app.router.add_post(f'/api/{prefix}/delete', self.delete_model)
|
||||||
|
app.router.add_post(f'/api/{prefix}/exclude', self.exclude_model)
|
||||||
|
app.router.add_post(f'/api/{prefix}/fetch-civitai', self.fetch_civitai)
|
||||||
|
app.router.add_post(f'/api/{prefix}/relink-civitai', self.relink_civitai)
|
||||||
|
app.router.add_post(f'/api/{prefix}/replace-preview', self.replace_preview)
|
||||||
|
app.router.add_post(f'/api/{prefix}/save-metadata', self.save_metadata)
|
||||||
|
app.router.add_post(f'/api/{prefix}/rename', self.rename_model)
|
||||||
|
app.router.add_post(f'/api/{prefix}/bulk-delete', self.bulk_delete_models)
|
||||||
|
app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates)
|
||||||
|
|
||||||
|
# Common query routes
|
||||||
|
app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags)
|
||||||
|
app.router.add_get(f'/api/{prefix}/base-models', self.get_base_models)
|
||||||
|
app.router.add_get(f'/api/{prefix}/scan', self.scan_models)
|
||||||
|
app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots)
|
||||||
|
app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models)
|
||||||
|
app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts)
|
||||||
|
|
||||||
|
# CivitAI integration routes
|
||||||
|
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
|
||||||
|
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
|
||||||
|
|
||||||
|
# Setup model-specific routes
|
||||||
|
self.setup_specific_routes(app, prefix)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||||
|
"""Setup model-specific routes - to be implemented by subclasses"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def get_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get paginated model data"""
|
||||||
|
try:
|
||||||
|
# Parse common query parameters
|
||||||
|
params = self._parse_common_params(request)
|
||||||
|
|
||||||
|
# Get data from service
|
||||||
|
result = await self.service.get_paginated_data(**params)
|
||||||
|
|
||||||
|
# Format response items
|
||||||
|
formatted_result = {
|
||||||
|
'items': [await self.service.format_response(item) for item in result['items']],
|
||||||
|
'total': result['total'],
|
||||||
|
'page': result['page'],
|
||||||
|
'page_size': result['page_size'],
|
||||||
|
'total_pages': result['total_pages']
|
||||||
|
}
|
||||||
|
|
||||||
|
return web.json_response(formatted_result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_{self.model_type}s: {e}", exc_info=True)
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
def _parse_common_params(self, request: web.Request) -> Dict:
|
||||||
|
"""Parse common query parameters"""
|
||||||
|
# Parse basic pagination and sorting
|
||||||
|
page = int(request.query.get('page', '1'))
|
||||||
|
page_size = min(int(request.query.get('page_size', '20')), 100)
|
||||||
|
sort_by = request.query.get('sort_by', 'name')
|
||||||
|
folder = request.query.get('folder', None)
|
||||||
|
search = request.query.get('search', None)
|
||||||
|
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
||||||
|
|
||||||
|
# Parse filter arrays
|
||||||
|
base_models = request.query.getall('base_model', [])
|
||||||
|
tags = request.query.getall('tag', [])
|
||||||
|
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true'
|
||||||
|
|
||||||
|
# Parse search options
|
||||||
|
search_options = {
|
||||||
|
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
||||||
|
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
||||||
|
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
||||||
|
'recursive': request.query.get('recursive', 'false').lower() == 'true',
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse hash filters if provided
|
||||||
|
hash_filters = {}
|
||||||
|
if 'hash' in request.query:
|
||||||
|
hash_filters['single_hash'] = request.query['hash']
|
||||||
|
elif 'hashes' in request.query:
|
||||||
|
try:
|
||||||
|
hash_list = json.loads(request.query['hashes'])
|
||||||
|
if isinstance(hash_list, list):
|
||||||
|
hash_filters['multiple_hashes'] = hash_list
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {
|
||||||
|
'page': page,
|
||||||
|
'page_size': page_size,
|
||||||
|
'sort_by': sort_by,
|
||||||
|
'folder': folder,
|
||||||
|
'search': search,
|
||||||
|
'fuzzy_search': fuzzy_search,
|
||||||
|
'base_models': base_models,
|
||||||
|
'tags': tags,
|
||||||
|
'search_options': search_options,
|
||||||
|
'hash_filters': hash_filters,
|
||||||
|
'favorites_only': favorites_only,
|
||||||
|
# Add model-specific parameters
|
||||||
|
**self._parse_specific_params(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||||
|
"""Parse model-specific parameters - to be overridden by subclasses"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Common route handlers
|
||||||
|
async def delete_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model deletion request"""
|
||||||
|
return await ModelRouteUtils.handle_delete_model(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model exclusion request"""
|
||||||
|
return await ModelRouteUtils.handle_exclude_model(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle CivitAI metadata fetch request"""
|
||||||
|
response = await ModelRouteUtils.handle_fetch_civitai(request, self.service.scanner)
|
||||||
|
|
||||||
|
# If successful, format the metadata before returning
|
||||||
|
if response.status == 200:
|
||||||
|
data = json.loads(response.body.decode('utf-8'))
|
||||||
|
if data.get("success") and data.get("metadata"):
|
||||||
|
formatted_metadata = await self.service.format_response(data["metadata"])
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"metadata": formatted_metadata
|
||||||
|
})
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def relink_civitai(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle CivitAI metadata re-linking request"""
|
||||||
|
return await ModelRouteUtils.handle_relink_civitai(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle preview image replacement"""
|
||||||
|
return await ModelRouteUtils.handle_replace_preview(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle saving metadata updates"""
|
||||||
|
return await ModelRouteUtils.handle_save_metadata(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def rename_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle renaming a model file and its associated files"""
|
||||||
|
return await ModelRouteUtils.handle_rename_model(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def bulk_delete_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle bulk deletion of models"""
|
||||||
|
return await ModelRouteUtils.handle_bulk_delete_models(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle verification of duplicate model hashes"""
|
||||||
|
return await ModelRouteUtils.handle_verify_duplicates(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle request for top tags sorted by frequency"""
|
||||||
|
try:
|
||||||
|
limit = int(request.query.get('limit', '20'))
|
||||||
|
if limit < 1 or limit > 100:
|
||||||
|
limit = 20
|
||||||
|
|
||||||
|
top_tags = await self.service.get_top_tags(limit)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'tags': top_tags
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Internal server error'
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get base models used in models"""
|
||||||
|
try:
|
||||||
|
limit = int(request.query.get('limit', '20'))
|
||||||
|
if limit < 1 or limit > 100:
|
||||||
|
limit = 20
|
||||||
|
|
||||||
|
base_models = await self.service.get_base_models(limit)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'base_models': base_models
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error retrieving base models: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def scan_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Force a rescan of model files"""
|
||||||
|
try:
|
||||||
|
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
|
||||||
|
|
||||||
|
await self.service.scan_models(force_refresh=True, rebuild_cache=full_rebuild)
|
||||||
|
return web.json_response({
|
||||||
|
"status": "success",
|
||||||
|
"message": f"{self.model_type.capitalize()} scan completed"
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in scan_{self.model_type}s: {e}", exc_info=True)
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def get_model_roots(self, request: web.Request) -> web.Response:
|
||||||
|
"""Return the model root directories"""
|
||||||
|
try:
|
||||||
|
roots = self.service.get_model_roots()
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"roots": roots
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting {self.model_type} roots: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def find_duplicate_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Find models with duplicate SHA256 hashes"""
|
||||||
|
try:
|
||||||
|
# Get duplicate hashes from service
|
||||||
|
duplicates = self.service.find_duplicate_hashes()
|
||||||
|
|
||||||
|
# Format the response
|
||||||
|
result = []
|
||||||
|
cache = await self.service.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for sha256, paths in duplicates.items():
|
||||||
|
group = {
|
||||||
|
"hash": sha256,
|
||||||
|
"models": []
|
||||||
|
}
|
||||||
|
# Find matching models for each path
|
||||||
|
for path in paths:
|
||||||
|
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||||
|
if model:
|
||||||
|
group["models"].append(await self.service.format_response(model))
|
||||||
|
|
||||||
|
# Add the primary model too
|
||||||
|
primary_path = self.service.get_path_by_hash(sha256)
|
||||||
|
if primary_path and primary_path not in paths:
|
||||||
|
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
|
||||||
|
if primary_model:
|
||||||
|
group["models"].insert(0, await self.service.format_response(primary_model))
|
||||||
|
|
||||||
|
if len(group["models"]) > 1: # Only include if we found multiple models
|
||||||
|
result.append(group)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"duplicates": result,
|
||||||
|
"count": len(result)
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error finding duplicate {self.model_type}s: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
|
||||||
|
"""Find models with conflicting filenames"""
|
||||||
|
try:
|
||||||
|
# Get duplicate filenames from service
|
||||||
|
duplicates = self.service.find_duplicate_filenames()
|
||||||
|
|
||||||
|
# Format the response
|
||||||
|
result = []
|
||||||
|
cache = await self.service.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for filename, paths in duplicates.items():
|
||||||
|
group = {
|
||||||
|
"filename": filename,
|
||||||
|
"models": []
|
||||||
|
}
|
||||||
|
# Find matching models for each path
|
||||||
|
for path in paths:
|
||||||
|
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||||
|
if model:
|
||||||
|
group["models"].append(await self.service.format_response(model))
|
||||||
|
|
||||||
|
# Find the model from the main index too
|
||||||
|
hash_val = self.service.scanner._hash_index.get_hash_by_filename(filename)
|
||||||
|
if hash_val:
|
||||||
|
main_path = self.service.get_path_by_hash(hash_val)
|
||||||
|
if main_path and main_path not in paths:
|
||||||
|
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
|
||||||
|
if main_model:
|
||||||
|
group["models"].insert(0, await self.service.format_response(main_model))
|
||||||
|
|
||||||
|
if group["models"]:
|
||||||
|
result.append(group)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"conflicts": result,
|
||||||
|
"count": len(result)
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error finding filename conflicts for {self.model_type}s: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||||
|
"""Fetch CivitAI metadata for all models in the background"""
|
||||||
|
try:
|
||||||
|
cache = await self.service.scanner.get_cached_data()
|
||||||
|
total = len(cache.raw_data)
|
||||||
|
processed = 0
|
||||||
|
success = 0
|
||||||
|
needs_resort = False
|
||||||
|
|
||||||
|
# Prepare models to process
|
||||||
|
to_process = [
|
||||||
|
model for model in cache.raw_data
|
||||||
|
if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) and model.get('from_civitai', True)
|
||||||
|
]
|
||||||
|
total_to_process = len(to_process)
|
||||||
|
|
||||||
|
# Send initial progress
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'started',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': 0,
|
||||||
|
'success': 0
|
||||||
|
})
|
||||||
|
|
||||||
|
# Process each model
|
||||||
|
for model in to_process:
|
||||||
|
try:
|
||||||
|
original_name = model.get('model_name')
|
||||||
|
if await ModelRouteUtils.fetch_and_update_model(
|
||||||
|
sha256=model['sha256'],
|
||||||
|
file_path=model['file_path'],
|
||||||
|
model_data=model,
|
||||||
|
update_cache_func=self.service.scanner.update_single_model_cache
|
||||||
|
):
|
||||||
|
success += 1
|
||||||
|
if original_name != model.get('model_name'):
|
||||||
|
needs_resort = True
|
||||||
|
|
||||||
|
processed += 1
|
||||||
|
|
||||||
|
# Send progress update
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'processing',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': processed,
|
||||||
|
'success': success,
|
||||||
|
'current_name': model.get('model_name', 'Unknown')
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching CivitAI data for {model['file_path']}: {e}")
|
||||||
|
|
||||||
|
if needs_resort:
|
||||||
|
await cache.resort(name_only=True)
|
||||||
|
|
||||||
|
# Send completion message
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'completed',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': processed,
|
||||||
|
'success': success
|
||||||
|
})
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"message": f"Successfully updated {success} of {processed} processed {self.model_type}s (total: {total})"
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Send error message
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'error',
|
||||||
|
'error': str(e)
|
||||||
|
})
|
||||||
|
logger.error(f"Error in fetch_all_civitai for {self.model_type}s: {e}")
|
||||||
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
|
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get available versions for a Civitai model with local availability info"""
|
||||||
|
# This will be implemented by subclasses as they need CivitAI client access
|
||||||
|
return web.json_response({
|
||||||
|
"error": "Not implemented in base class"
|
||||||
|
}, status=501)
|
||||||
170
py/routes/checkpoint_routes.py
Normal file
170
py/routes/checkpoint_routes.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import jinja2
|
||||||
|
import logging
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from .base_model_routes import BaseModelRoutes
|
||||||
|
from ..services.checkpoint_service import CheckpointService
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
from ..config import config
|
||||||
|
from ..services.settings_manager import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class CheckpointRoutes(BaseModelRoutes):
|
||||||
|
"""Checkpoint-specific route controller"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize Checkpoint routes with Checkpoint service"""
|
||||||
|
# Service will be initialized later via setup_routes
|
||||||
|
self.service = None
|
||||||
|
self.civitai_client = None
|
||||||
|
self.template_env = jinja2.Environment(
|
||||||
|
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||||
|
autoescape=True
|
||||||
|
)
|
||||||
|
|
||||||
|
async def initialize_services(self):
|
||||||
|
"""Initialize services from ServiceRegistry"""
|
||||||
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
self.service = CheckpointService(checkpoint_scanner)
|
||||||
|
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||||
|
|
||||||
|
# Initialize parent with the service
|
||||||
|
super().__init__(self.service)
|
||||||
|
|
||||||
|
def setup_routes(self, app: web.Application):
|
||||||
|
"""Setup Checkpoint routes"""
|
||||||
|
# Schedule service initialization on app startup
|
||||||
|
app.on_startup.append(lambda _: self.initialize_services())
|
||||||
|
|
||||||
|
# Setup common routes with 'checkpoints' prefix
|
||||||
|
super().setup_routes(app, 'checkpoints')
|
||||||
|
|
||||||
|
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||||
|
"""Setup Checkpoint-specific routes"""
|
||||||
|
# Checkpoint page route
|
||||||
|
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
||||||
|
|
||||||
|
# Checkpoint-specific CivitAI integration
|
||||||
|
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
|
||||||
|
|
||||||
|
# Checkpoint info by name
|
||||||
|
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
|
||||||
|
|
||||||
|
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle GET /checkpoints request"""
|
||||||
|
try:
|
||||||
|
# Check if the CheckpointScanner is initializing
|
||||||
|
# It's initializing if the cache object doesn't exist yet,
|
||||||
|
# OR if the scanner explicitly says it's initializing (background task running).
|
||||||
|
is_initializing = (
|
||||||
|
self.service.scanner._cache is None or
|
||||||
|
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_initializing:
|
||||||
|
# If still initializing, return loading page
|
||||||
|
template = self.template_env.get_template('checkpoints.html')
|
||||||
|
rendered = template.render(
|
||||||
|
folders=[], # Empty folder list
|
||||||
|
is_initializing=True, # New flag
|
||||||
|
settings=settings, # Pass settings to template
|
||||||
|
request=request # Pass the request object to the template
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Checkpoints page is initializing, returning loading page")
|
||||||
|
else:
|
||||||
|
# Normal flow - get initialized cache data
|
||||||
|
try:
|
||||||
|
cache = await self.service.scanner.get_cached_data(force_refresh=False)
|
||||||
|
template = self.template_env.get_template('checkpoints.html')
|
||||||
|
rendered = template.render(
|
||||||
|
folders=cache.folders,
|
||||||
|
is_initializing=False,
|
||||||
|
settings=settings, # Pass settings to template
|
||||||
|
request=request # Pass the request object to the template
|
||||||
|
)
|
||||||
|
except Exception as cache_error:
|
||||||
|
logger.error(f"Error loading checkpoints cache data: {cache_error}")
|
||||||
|
# If getting cache fails, also show initialization page
|
||||||
|
template = self.template_env.get_template('checkpoints.html')
|
||||||
|
rendered = template.render(
|
||||||
|
folders=[],
|
||||||
|
is_initializing=True,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
logger.info("Checkpoints cache error, returning initialization page")
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
text=rendered,
|
||||||
|
content_type='text/html'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
|
||||||
|
return web.Response(
|
||||||
|
text="Error loading checkpoints page",
|
||||||
|
status=500
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get detailed information for a specific checkpoint by name"""
|
||||||
|
try:
|
||||||
|
name = request.match_info.get('name', '')
|
||||||
|
checkpoint_info = await self.service.get_model_info_by_name(name)
|
||||||
|
|
||||||
|
if checkpoint_info:
|
||||||
|
return web.json_response(checkpoint_info)
|
||||||
|
else:
|
||||||
|
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||||
|
try:
|
||||||
|
model_id = request.match_info['model_id']
|
||||||
|
response = await self.civitai_client.get_model_versions(model_id)
|
||||||
|
if not response or not response.get('modelVersions'):
|
||||||
|
return web.Response(status=404, text="Model not found")
|
||||||
|
|
||||||
|
versions = response.get('modelVersions', [])
|
||||||
|
model_type = response.get('type', '')
|
||||||
|
|
||||||
|
# Check model type - should be Checkpoint
|
||||||
|
if model_type.lower() != 'checkpoint':
|
||||||
|
return web.json_response({
|
||||||
|
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Check local availability for each version
|
||||||
|
for version in versions:
|
||||||
|
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||||
|
model_file = next((file for file in version.get('files', [])
|
||||||
|
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||||
|
|
||||||
|
# If no primary file found, try to find any model file
|
||||||
|
if not model_file:
|
||||||
|
model_file = next((file for file in version.get('files', [])
|
||||||
|
if file.get('type') == 'Model'), None)
|
||||||
|
|
||||||
|
if model_file:
|
||||||
|
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||||
|
if sha256:
|
||||||
|
# Set existsLocally and localPath at the version level
|
||||||
|
version['existsLocally'] = self.service.has_hash(sha256)
|
||||||
|
if version['existsLocally']:
|
||||||
|
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||||
|
|
||||||
|
# Also set the model file size at the version level for easier access
|
||||||
|
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||||
|
else:
|
||||||
|
# No model file found in this version
|
||||||
|
version['existsLocally'] = False
|
||||||
|
|
||||||
|
return web.json_response(versions)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching checkpoint model versions: {e}")
|
||||||
|
return web.Response(status=500, text=str(e))
|
||||||
@@ -1,76 +1,126 @@
|
|||||||
|
import jinja2
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import jinja2
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
import logging
|
from server import PromptServer # type: ignore
|
||||||
from ..config import config
|
|
||||||
|
from .base_model_routes import BaseModelRoutes
|
||||||
|
from ..services.lora_service import LoraService
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import settings
|
||||||
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
from ..config import config
|
||||||
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
|
from ..utils.utils import get_lora_info
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
|
|
||||||
|
|
||||||
class LoraRoutes:
|
class LoraRoutes(BaseModelRoutes):
|
||||||
"""Route handlers for LoRA management endpoints"""
|
"""LoRA-specific route controller"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Initialize service references as None, will be set during async init
|
"""Initialize LoRA routes with LoRA service"""
|
||||||
self.scanner = None
|
# Service will be initialized later via setup_routes
|
||||||
self.recipe_scanner = None
|
self.service = None
|
||||||
|
self.civitai_client = None
|
||||||
|
self.download_manager = None
|
||||||
|
self._download_lock = asyncio.Lock()
|
||||||
self.template_env = jinja2.Environment(
|
self.template_env = jinja2.Environment(
|
||||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||||
autoescape=True
|
autoescape=True
|
||||||
)
|
)
|
||||||
|
|
||||||
async def init_services(self):
|
async def initialize_services(self):
|
||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
self.service = LoraService(lora_scanner)
|
||||||
|
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||||
|
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||||
|
|
||||||
def format_lora_data(self, lora: Dict) -> Dict:
|
# Initialize parent with the service
|
||||||
"""Format LoRA data for template rendering"""
|
super().__init__(self.service)
|
||||||
return {
|
|
||||||
"model_name": lora["model_name"],
|
|
||||||
"file_name": lora["file_name"],
|
|
||||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
|
||||||
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
|
|
||||||
"base_model": lora["base_model"],
|
|
||||||
"folder": lora["folder"],
|
|
||||||
"sha256": lora["sha256"],
|
|
||||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
|
||||||
"size": lora["size"],
|
|
||||||
"tags": lora["tags"],
|
|
||||||
"modelDescription": lora["modelDescription"],
|
|
||||||
"usage_tips": lora["usage_tips"],
|
|
||||||
"notes": lora["notes"],
|
|
||||||
"modified": lora["modified"],
|
|
||||||
"from_civitai": lora.get("from_civitai", True),
|
|
||||||
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
|
|
||||||
}
|
|
||||||
|
|
||||||
def _filter_civitai_data(self, data: Dict) -> Dict:
|
def setup_routes(self, app: web.Application):
|
||||||
"""Filter relevant fields from CivitAI data"""
|
"""Setup LoRA routes"""
|
||||||
if not data:
|
# Schedule service initialization on app startup
|
||||||
return {}
|
app.on_startup.append(lambda _: self.initialize_services())
|
||||||
|
|
||||||
fields = [
|
# Setup common routes with 'loras' prefix
|
||||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
super().setup_routes(app, 'loras')
|
||||||
"publishedAt", "trainedWords", "baseModel", "description",
|
|
||||||
"model", "images"
|
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||||
]
|
"""Setup LoRA-specific routes"""
|
||||||
return {k: data[k] for k in fields if k in data}
|
# Lora page route
|
||||||
|
app.router.add_get('/loras', self.handle_loras_page)
|
||||||
|
|
||||||
|
# LoRA-specific query routes
|
||||||
|
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
|
||||||
|
app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
|
||||||
|
app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
|
||||||
|
app.router.add_get(f'/api/lora-preview-url', self.get_lora_preview_url)
|
||||||
|
app.router.add_get(f'/api/lora-civitai-url', self.get_lora_civitai_url)
|
||||||
|
app.router.add_get(f'/api/lora-model-description', self.get_lora_model_description)
|
||||||
|
app.router.add_get(f'/api/folders', self.get_folders)
|
||||||
|
app.router.add_get(f'/api/lora-roots', self.get_lora_roots)
|
||||||
|
|
||||||
|
# LoRA-specific management routes
|
||||||
|
app.router.add_post(f'/api/move_model', self.move_model)
|
||||||
|
app.router.add_post(f'/api/move_models_bulk', self.move_models_bulk)
|
||||||
|
|
||||||
|
# CivitAI integration with LoRA-specific validation
|
||||||
|
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_lora)
|
||||||
|
app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
|
||||||
|
app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
|
||||||
|
|
||||||
|
# Download management
|
||||||
|
app.router.add_post(f'/api/download-model', self.download_model)
|
||||||
|
app.router.add_get(f'/api/download-model-get', self.download_model_get)
|
||||||
|
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
|
||||||
|
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
|
||||||
|
|
||||||
|
# ComfyUI integration
|
||||||
|
app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words)
|
||||||
|
|
||||||
|
# Legacy API compatibility
|
||||||
|
app.router.add_post(f'/api/delete_model', self.delete_model)
|
||||||
|
app.router.add_post(f'/api/fetch-civitai', self.fetch_civitai)
|
||||||
|
app.router.add_post(f'/api/relink-civitai', self.relink_civitai)
|
||||||
|
app.router.add_post(f'/api/replace_preview', self.replace_preview)
|
||||||
|
app.router.add_post(f'/api/fetch-all-civitai', self.fetch_all_civitai)
|
||||||
|
|
||||||
|
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||||
|
"""Parse LoRA-specific parameters"""
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
# LoRA-specific parameters
|
||||||
|
if 'first_letter' in request.query:
|
||||||
|
params['first_letter'] = request.query.get('first_letter')
|
||||||
|
|
||||||
|
# Handle fuzzy search parameter name variation
|
||||||
|
if request.query.get('fuzzy') == 'true':
|
||||||
|
params['fuzzy_search'] = True
|
||||||
|
|
||||||
|
# Handle additional filter parameters for LoRAs
|
||||||
|
if 'lora_hash' in request.query:
|
||||||
|
if not params.get('hash_filters'):
|
||||||
|
params['hash_filters'] = {}
|
||||||
|
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
|
||||||
|
elif 'lora_hashes' in request.query:
|
||||||
|
if not params.get('hash_filters'):
|
||||||
|
params['hash_filters'] = {}
|
||||||
|
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
||||||
"""Handle GET /loras request"""
|
"""Handle GET /loras request"""
|
||||||
try:
|
try:
|
||||||
# Ensure services are initialized
|
|
||||||
await self.init_services()
|
|
||||||
|
|
||||||
# Check if the LoraScanner is initializing
|
# Check if the LoraScanner is initializing
|
||||||
# It's initializing if the cache object doesn't exist yet,
|
# It's initializing if the cache object doesn't exist yet,
|
||||||
# OR if the scanner explicitly says it's initializing (background task running).
|
# OR if the scanner explicitly says it's initializing (background task running).
|
||||||
is_initializing = (
|
is_initializing = (
|
||||||
self.scanner._cache is None or self.scanner.is_initializing()
|
self.service.scanner._cache is None or self.service.scanner.is_initializing()
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_initializing:
|
if is_initializing:
|
||||||
@@ -87,7 +137,7 @@ class LoraRoutes:
|
|||||||
else:
|
else:
|
||||||
# Normal flow - get data from initialized cache
|
# Normal flow - get data from initialized cache
|
||||||
try:
|
try:
|
||||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
cache = await self.service.scanner.get_cached_data(force_refresh=False)
|
||||||
template = self.template_env.get_template('loras.html')
|
template = self.template_env.get_template('loras.html')
|
||||||
rendered = template.render(
|
rendered = template.render(
|
||||||
folders=cache.folders,
|
folders=cache.folders,
|
||||||
@@ -118,71 +168,560 @@ class LoraRoutes:
|
|||||||
status=500
|
status=500
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
# LoRA-specific route handlers
|
||||||
"""Handle GET /loras/recipes request"""
|
async def get_letter_counts(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get count of LoRAs for each letter of the alphabet"""
|
||||||
try:
|
try:
|
||||||
# Ensure services are initialized
|
letter_counts = await self.service.get_letter_counts()
|
||||||
await self.init_services()
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'letter_counts': letter_counts
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting letter counts: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
# Skip initialization check and directly try to get cached data
|
async def get_lora_notes(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get notes for a specific LoRA file"""
|
||||||
try:
|
try:
|
||||||
# Recipe scanner will initialize cache if needed
|
lora_name = request.query.get('name')
|
||||||
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
if not lora_name:
|
||||||
template = self.template_env.get_template('recipes.html')
|
return web.Response(text='Lora file name is required', status=400)
|
||||||
rendered = template.render(
|
|
||||||
recipes=[], # Frontend will load recipes via API
|
|
||||||
is_initializing=False,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
except Exception as cache_error:
|
|
||||||
logger.error(f"Error loading recipe cache data: {cache_error}")
|
|
||||||
# Still keep error handling - show initializing page on error
|
|
||||||
template = self.template_env.get_template('recipes.html')
|
|
||||||
rendered = template.render(
|
|
||||||
is_initializing=True,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
logger.info("Recipe cache error, returning initialization page")
|
|
||||||
|
|
||||||
return web.Response(
|
notes = await self.service.get_lora_notes(lora_name)
|
||||||
text=rendered,
|
if notes is not None:
|
||||||
content_type='text/html'
|
return web.json_response({
|
||||||
)
|
'success': True,
|
||||||
|
'notes': notes
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'LoRA not found in cache'
|
||||||
|
}, status=404)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
logger.error(f"Error getting lora notes: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get trigger words for a specific LoRA file"""
|
||||||
|
try:
|
||||||
|
lora_name = request.query.get('name')
|
||||||
|
if not lora_name:
|
||||||
|
return web.Response(text='Lora file name is required', status=400)
|
||||||
|
|
||||||
|
trigger_words = await self.service.get_lora_trigger_words(lora_name)
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'trigger_words': trigger_words
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get the static preview URL for a LoRA file"""
|
||||||
|
try:
|
||||||
|
lora_name = request.query.get('name')
|
||||||
|
if not lora_name:
|
||||||
|
return web.Response(text='Lora file name is required', status=400)
|
||||||
|
|
||||||
|
preview_url = await self.service.get_lora_preview_url(lora_name)
|
||||||
|
if preview_url:
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'preview_url': preview_url
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'No preview URL found for the specified lora'
|
||||||
|
}, status=404)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get the Civitai URL for a LoRA file"""
|
||||||
|
try:
|
||||||
|
lora_name = request.query.get('name')
|
||||||
|
if not lora_name:
|
||||||
|
return web.Response(text='Lora file name is required', status=400)
|
||||||
|
|
||||||
|
result = await self.service.get_lora_civitai_url(lora_name)
|
||||||
|
if result['civitai_url']:
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
**result
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'No Civitai data found for the specified lora'
|
||||||
|
}, status=404)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_folders(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get all folders in the cache"""
|
||||||
|
try:
|
||||||
|
cache = await self.service.scanner.get_cached_data()
|
||||||
|
return web.json_response({
|
||||||
|
'folders': cache.folders
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting folders: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_lora_roots(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get all configured LoRA root directories"""
|
||||||
|
try:
|
||||||
|
return web.json_response({
|
||||||
|
'roots': self.service.get_model_roots()
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting LoRA roots: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
# Override get_models to add LoRA-specific response data
|
||||||
|
async def get_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get paginated LoRA data with LoRA-specific fields"""
|
||||||
|
try:
|
||||||
|
# Parse common query parameters
|
||||||
|
params = self._parse_common_params(request)
|
||||||
|
|
||||||
|
# Get data from service
|
||||||
|
result = await self.service.get_paginated_data(**params)
|
||||||
|
|
||||||
|
# Get all available folders from cache for LoRA-specific response
|
||||||
|
cache = await self.service.scanner.get_cached_data()
|
||||||
|
|
||||||
|
# Format response items with LoRA-specific structure
|
||||||
|
formatted_result = {
|
||||||
|
'items': [await self.service.format_response(item) for item in result['items']],
|
||||||
|
'folders': cache.folders, # LoRA-specific: include folders in response
|
||||||
|
'total': result['total'],
|
||||||
|
'page': result['page'],
|
||||||
|
'page_size': result['page_size'],
|
||||||
|
'total_pages': result['total_pages']
|
||||||
|
}
|
||||||
|
|
||||||
|
return web.json_response(formatted_result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_loras: {e}", exc_info=True)
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
# CivitAI integration methods
|
||||||
|
async def get_civitai_versions_lora(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get available versions for a Civitai LoRA model with local availability info"""
|
||||||
|
try:
|
||||||
|
model_id = request.match_info['model_id']
|
||||||
|
response = await self.civitai_client.get_model_versions(model_id)
|
||||||
|
if not response or not response.get('modelVersions'):
|
||||||
|
return web.Response(status=404, text="Model not found")
|
||||||
|
|
||||||
|
versions = response.get('modelVersions', [])
|
||||||
|
model_type = response.get('type', '')
|
||||||
|
|
||||||
|
# Check model type - should be LORA, LoCon, or DORA
|
||||||
|
from ..utils.constants import VALID_LORA_TYPES
|
||||||
|
if model_type.lower() not in VALID_LORA_TYPES:
|
||||||
|
return web.json_response({
|
||||||
|
'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}"
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Check local availability for each version
|
||||||
|
for version in versions:
|
||||||
|
# Find the model file (type="Model") in the files list
|
||||||
|
model_file = next((file for file in version.get('files', [])
|
||||||
|
if file.get('type') == 'Model'), None)
|
||||||
|
|
||||||
|
if model_file:
|
||||||
|
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||||
|
if sha256:
|
||||||
|
# Set existsLocally and localPath at the version level
|
||||||
|
version['existsLocally'] = self.service.has_hash(sha256)
|
||||||
|
if version['existsLocally']:
|
||||||
|
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||||
|
|
||||||
|
# Also set the model file size at the version level for easier access
|
||||||
|
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||||
|
else:
|
||||||
|
# No model file found in this version
|
||||||
|
version['existsLocally'] = False
|
||||||
|
|
||||||
|
return web.json_response(versions)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching LoRA model versions: {e}")
|
||||||
|
return web.Response(status=500, text=str(e))
|
||||||
|
|
||||||
|
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get CivitAI model details by model version ID"""
|
||||||
|
try:
|
||||||
|
model_version_id = request.match_info.get('modelVersionId')
|
||||||
|
|
||||||
|
# Get model details from Civitai API
|
||||||
|
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
# Log warning for failed model retrieval
|
||||||
|
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
|
||||||
|
|
||||||
|
# Determine status code based on error message
|
||||||
|
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": error_msg or "Failed to fetch model information"
|
||||||
|
}, status=status_code)
|
||||||
|
|
||||||
|
return web.json_response(model)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching model details: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get CivitAI model details by hash"""
|
||||||
|
try:
|
||||||
|
hash = request.match_info.get('hash')
|
||||||
|
model = await self.civitai_client.get_model_by_hash(hash)
|
||||||
|
return web.json_response(model)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching model details by hash: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
# Download management methods
|
||||||
|
async def download_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model download request"""
|
||||||
|
return await ModelRouteUtils.handle_download_model(request, self.download_manager)
|
||||||
|
|
||||||
|
async def download_model_get(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model download request via GET method"""
|
||||||
|
try:
|
||||||
|
# Extract query parameters
|
||||||
|
model_id = request.query.get('model_id')
|
||||||
|
if not model_id:
|
||||||
return web.Response(
|
return web.Response(
|
||||||
text="Error loading recipes page",
|
status=400,
|
||||||
status=500
|
text="Missing required parameter: Please provide 'model_id'"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _format_recipe_file_url(self, file_path: str) -> str:
|
# Get optional parameters
|
||||||
"""Format file path for recipe image as a URL - same as in recipe_routes"""
|
model_version_id = request.query.get('model_version_id')
|
||||||
try:
|
download_id = request.query.get('download_id')
|
||||||
# Return the file URL directly for the first lora root's preview
|
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
|
||||||
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
|
|
||||||
if file_path.replace(os.sep, '/').startswith(recipes_dir):
|
|
||||||
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
|
|
||||||
return f"/loras_static/root1/preview/{relative_path}"
|
|
||||||
|
|
||||||
# If not in recipes dir, try to create a valid URL from the file path
|
# Create a data dictionary that mimics what would be received from a POST request
|
||||||
|
data = {
|
||||||
|
'model_id': model_id
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional parameters only if they are provided
|
||||||
|
if model_version_id:
|
||||||
|
data['model_version_id'] = model_version_id
|
||||||
|
|
||||||
|
if download_id:
|
||||||
|
data['download_id'] = download_id
|
||||||
|
|
||||||
|
data['use_default_paths'] = use_default_paths
|
||||||
|
|
||||||
|
# Create a mock request object with the data
|
||||||
|
future = asyncio.get_event_loop().create_future()
|
||||||
|
future.set_result(data)
|
||||||
|
|
||||||
|
mock_request = type('MockRequest', (), {
|
||||||
|
'json': lambda self=None: future
|
||||||
|
})()
|
||||||
|
|
||||||
|
# Call the existing download handler
|
||||||
|
return await ModelRouteUtils.handle_download_model(mock_request, self.download_manager)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_message = str(e)
|
||||||
|
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
|
||||||
|
return web.Response(status=500, text=error_message)
|
||||||
|
|
||||||
|
async def cancel_download_get(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle GET request for cancelling a download by download_id"""
|
||||||
|
try:
|
||||||
|
download_id = request.query.get('download_id')
|
||||||
|
if not download_id:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Download ID is required'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Create a mock request with match_info for compatibility
|
||||||
|
mock_request = type('MockRequest', (), {
|
||||||
|
'match_info': {'download_id': download_id}
|
||||||
|
})()
|
||||||
|
return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error cancelling download via GET: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_download_progress(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle request for download progress by download_id"""
|
||||||
|
try:
|
||||||
|
# Get download_id from URL path
|
||||||
|
download_id = request.match_info.get('download_id')
|
||||||
|
if not download_id:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Download ID is required'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Get progress information from websocket manager
|
||||||
|
from ..services.websocket_manager import ws_manager
|
||||||
|
progress_data = ws_manager.get_download_progress(download_id)
|
||||||
|
|
||||||
|
if progress_data is None:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Download ID not found'
|
||||||
|
}, status=404)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'progress': progress_data.get('progress', 0)
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting download progress: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
# Model management methods
|
||||||
|
async def move_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model move request"""
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_path = data.get('file_path') # full path of the model file
|
||||||
|
target_path = data.get('target_path') # folder path to move the model to
|
||||||
|
|
||||||
|
if not file_path or not target_path:
|
||||||
|
return web.Response(text='File path and target path are required', status=400)
|
||||||
|
|
||||||
|
# Check if source and destination are the same
|
||||||
|
import os
|
||||||
|
source_dir = os.path.dirname(file_path)
|
||||||
|
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||||
|
logger.info(f"Source and target directories are the same: {source_dir}")
|
||||||
|
return web.json_response({'success': True, 'message': 'Source and target directories are the same'})
|
||||||
|
|
||||||
|
# Check if target file already exists
|
||||||
file_name = os.path.basename(file_path)
|
file_name = os.path.basename(file_path)
|
||||||
return f"/loras_static/root1/preview/recipes/{file_name}"
|
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
|
||||||
|
|
||||||
|
if os.path.exists(target_file_path):
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': f"Target file already exists: {target_file_path}"
|
||||||
|
}, status=409) # 409 Conflict
|
||||||
|
|
||||||
|
# Call scanner to handle the move operation
|
||||||
|
success = await self.service.scanner.move_model(file_path, target_path)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return web.json_response({'success': True})
|
||||||
|
else:
|
||||||
|
return web.Response(text='Failed to move model', status=500)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error formatting recipe file URL: {e}", exc_info=True)
|
logger.error(f"Error moving model: {e}", exc_info=True)
|
||||||
return '/loras_static/images/no-preview.png' # Return default image on error
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
def setup_routes(self, app: web.Application):
|
async def move_models_bulk(self, request: web.Request) -> web.Response:
|
||||||
"""Register routes with the application"""
|
"""Handle bulk model move request"""
|
||||||
# Add an app startup handler to initialize services
|
try:
|
||||||
app.on_startup.append(self._on_startup)
|
data = await request.json()
|
||||||
|
file_paths = data.get('file_paths', []) # list of full paths of the model files
|
||||||
|
target_path = data.get('target_path') # folder path to move the models to
|
||||||
|
|
||||||
# Register routes
|
if not file_paths or not target_path:
|
||||||
app.router.add_get('/loras', self.handle_loras_page)
|
return web.Response(text='File paths and target path are required', status=400)
|
||||||
app.router.add_get('/loras/recipes', self.handle_recipes_page)
|
|
||||||
|
|
||||||
async def _on_startup(self, app):
|
results = []
|
||||||
"""Initialize services when the app starts"""
|
import os
|
||||||
await self.init_services()
|
for file_path in file_paths:
|
||||||
|
# Check if source and destination are the same
|
||||||
|
source_dir = os.path.dirname(file_path)
|
||||||
|
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||||
|
results.append({
|
||||||
|
"path": file_path,
|
||||||
|
"success": True,
|
||||||
|
"message": "Source and target directories are the same"
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if target file already exists
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
|
||||||
|
|
||||||
|
if os.path.exists(target_file_path):
|
||||||
|
results.append({
|
||||||
|
"path": file_path,
|
||||||
|
"success": False,
|
||||||
|
"message": f"Target file already exists: {target_file_path}"
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try to move the model
|
||||||
|
success = await self.service.scanner.move_model(file_path, target_path)
|
||||||
|
results.append({
|
||||||
|
"path": file_path,
|
||||||
|
"success": success,
|
||||||
|
"message": "Success" if success else "Failed to move model"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Count successes and failures
|
||||||
|
success_count = sum(1 for r in results if r["success"])
|
||||||
|
failure_count = len(results) - success_count
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'message': f'Moved {success_count} of {len(file_paths)} models',
|
||||||
|
'results': results,
|
||||||
|
'success_count': success_count,
|
||||||
|
'failure_count': failure_count
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
|
||||||
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
|
async def get_lora_model_description(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get model description for a Lora model"""
|
||||||
|
try:
|
||||||
|
# Get parameters
|
||||||
|
model_id = request.query.get('model_id')
|
||||||
|
file_path = request.query.get('file_path')
|
||||||
|
|
||||||
|
if not model_id:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Model ID is required'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Check if we already have the description stored in metadata
|
||||||
|
description = None
|
||||||
|
tags = []
|
||||||
|
creator = {}
|
||||||
|
if file_path:
|
||||||
|
import os
|
||||||
|
from ..utils.metadata_manager import MetadataManager
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||||
|
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||||
|
description = metadata.get('modelDescription')
|
||||||
|
tags = metadata.get('tags', [])
|
||||||
|
creator = metadata.get('creator', {})
|
||||||
|
|
||||||
|
# If description is not in metadata, fetch from CivitAI
|
||||||
|
if not description:
|
||||||
|
logger.info(f"Fetching model metadata for model ID: {model_id}")
|
||||||
|
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
|
||||||
|
|
||||||
|
if model_metadata:
|
||||||
|
description = model_metadata.get('description')
|
||||||
|
tags = model_metadata.get('tags', [])
|
||||||
|
creator = model_metadata.get('creator', {})
|
||||||
|
|
||||||
|
# Save the metadata to file if we have a file path and got metadata
|
||||||
|
if file_path:
|
||||||
|
try:
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||||
|
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||||
|
|
||||||
|
metadata['modelDescription'] = description
|
||||||
|
metadata['tags'] = tags
|
||||||
|
# Ensure the civitai dict exists
|
||||||
|
if 'civitai' not in metadata:
|
||||||
|
metadata['civitai'] = {}
|
||||||
|
# Store creator in the civitai nested structure
|
||||||
|
metadata['civitai']['creator'] = creator
|
||||||
|
|
||||||
|
await MetadataManager.save_metadata(file_path, metadata, True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving model metadata: {e}")
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'description': description or "<p>No model description available.</p>",
|
||||||
|
'tags': tags,
|
||||||
|
'creator': creator
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting model metadata: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get trigger words for specified LoRA models"""
|
||||||
|
try:
|
||||||
|
json_data = await request.json()
|
||||||
|
lora_names = json_data.get("lora_names", [])
|
||||||
|
node_ids = json_data.get("node_ids", [])
|
||||||
|
|
||||||
|
all_trigger_words = []
|
||||||
|
for lora_name in lora_names:
|
||||||
|
_, trigger_words = get_lora_info(lora_name)
|
||||||
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
|
# Format the trigger words
|
||||||
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
|
||||||
|
# Send update to all connected trigger word toggle nodes
|
||||||
|
for node_id in node_ids:
|
||||||
|
PromptServer.instance.send_sync("trigger_word_update", {
|
||||||
|
"id": node_id,
|
||||||
|
"message": trigger_words_text
|
||||||
|
})
|
||||||
|
|
||||||
|
return web.json_response({"success": True})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting trigger words: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|||||||
@@ -633,9 +633,8 @@ class MiscRoutes:
|
|||||||
}, status=400)
|
}, status=400)
|
||||||
|
|
||||||
# Get both lora and checkpoint scanners
|
# Get both lora and checkpoint scanners
|
||||||
registry = ServiceRegistry.get_instance()
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
lora_scanner = await registry.get_lora_scanner()
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
checkpoint_scanner = await registry.get_checkpoint_scanner()
|
|
||||||
|
|
||||||
# If modelVersionId is provided, check for specific version
|
# If modelVersionId is provided, check for specific version
|
||||||
if model_version_id_str:
|
if model_version_id_str:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import base64
|
import base64
|
||||||
|
import jinja2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
@@ -15,6 +16,7 @@ from ..utils.exif_utils import ExifUtils
|
|||||||
from ..recipes import RecipeParserFactory
|
from ..recipes import RecipeParserFactory
|
||||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||||
|
|
||||||
|
from ..services.settings_manager import settings
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
|
||||||
# Check if running in standalone mode
|
# Check if running in standalone mode
|
||||||
@@ -39,7 +41,10 @@ class RecipeRoutes:
|
|||||||
# Initialize service references as None, will be set during async init
|
# Initialize service references as None, will be set during async init
|
||||||
self.recipe_scanner = None
|
self.recipe_scanner = None
|
||||||
self.civitai_client = None
|
self.civitai_client = None
|
||||||
# Remove WorkflowParser instance
|
self.template_env = jinja2.Environment(
|
||||||
|
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||||
|
autoescape=True
|
||||||
|
)
|
||||||
|
|
||||||
# Pre-warm the cache
|
# Pre-warm the cache
|
||||||
self._init_cache_task = None
|
self._init_cache_task = None
|
||||||
@@ -53,6 +58,8 @@ class RecipeRoutes:
|
|||||||
def setup_routes(cls, app: web.Application):
|
def setup_routes(cls, app: web.Application):
|
||||||
"""Register API routes"""
|
"""Register API routes"""
|
||||||
routes = cls()
|
routes = cls()
|
||||||
|
app.router.add_get('/loras/recipes', routes.handle_recipes_page)
|
||||||
|
|
||||||
app.router.add_get('/api/recipes', routes.get_recipes)
|
app.router.add_get('/api/recipes', routes.get_recipes)
|
||||||
app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail)
|
app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail)
|
||||||
app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image)
|
app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image)
|
||||||
@@ -115,6 +122,46 @@ class RecipeRoutes:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True)
|
logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle GET /loras/recipes request"""
|
||||||
|
try:
|
||||||
|
# Ensure services are initialized
|
||||||
|
await self.init_services()
|
||||||
|
|
||||||
|
# Skip initialization check and directly try to get cached data
|
||||||
|
try:
|
||||||
|
# Recipe scanner will initialize cache if needed
|
||||||
|
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
||||||
|
template = self.template_env.get_template('recipes.html')
|
||||||
|
rendered = template.render(
|
||||||
|
recipes=[], # Frontend will load recipes via API
|
||||||
|
is_initializing=False,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
except Exception as cache_error:
|
||||||
|
logger.error(f"Error loading recipe cache data: {cache_error}")
|
||||||
|
# Still keep error handling - show initializing page on error
|
||||||
|
template = self.template_env.get_template('recipes.html')
|
||||||
|
rendered = template.render(
|
||||||
|
is_initializing=True,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
logger.info("Recipe cache error, returning initialization page")
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
text=rendered,
|
||||||
|
content_type='text/html'
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
||||||
|
return web.Response(
|
||||||
|
text="Error loading recipes page",
|
||||||
|
status=500
|
||||||
|
)
|
||||||
|
|
||||||
async def get_recipes(self, request: web.Request) -> web.Response:
|
async def get_recipes(self, request: web.Request) -> web.Response:
|
||||||
"""API endpoint for getting paginated recipes"""
|
"""API endpoint for getting paginated recipes"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
248
py/services/base_model_service.py
Normal file
248
py/services/base_model_service.py
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List, Optional, Type, Set
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from ..utils.models import BaseModelMetadata
|
||||||
|
from ..utils.constants import NSFW_LEVELS
|
||||||
|
from .settings_manager import settings
|
||||||
|
from ..utils.utils import fuzzy_match
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class BaseModelService(ABC):
|
||||||
|
"""Base service class for all model types"""
|
||||||
|
|
||||||
|
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
|
||||||
|
"""Initialize the service
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: Type of model (lora, checkpoint, etc.)
|
||||||
|
scanner: Model scanner instance
|
||||||
|
metadata_class: Metadata class for this model type
|
||||||
|
"""
|
||||||
|
self.model_type = model_type
|
||||||
|
self.scanner = scanner
|
||||||
|
self.metadata_class = metadata_class
|
||||||
|
|
||||||
|
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||||
|
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||||
|
base_models: list = None, tags: list = None,
|
||||||
|
search_options: dict = None, hash_filters: dict = None,
|
||||||
|
favorites_only: bool = False, **kwargs) -> Dict:
|
||||||
|
"""Get paginated and filtered model data
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: Page number (1-based)
|
||||||
|
page_size: Number of items per page
|
||||||
|
sort_by: Sort criteria ('name' or 'date')
|
||||||
|
folder: Folder filter
|
||||||
|
search: Search term
|
||||||
|
fuzzy_search: Whether to use fuzzy search
|
||||||
|
base_models: List of base models to filter by
|
||||||
|
tags: List of tags to filter by
|
||||||
|
search_options: Search options dict
|
||||||
|
hash_filters: Hash filtering options
|
||||||
|
favorites_only: Filter for favorites only
|
||||||
|
**kwargs: Additional model-specific filters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict containing paginated results
|
||||||
|
"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
|
# Get default search options if not provided
|
||||||
|
if search_options is None:
|
||||||
|
search_options = {
|
||||||
|
'filename': True,
|
||||||
|
'modelname': True,
|
||||||
|
'tags': False,
|
||||||
|
'recursive': False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get the base data set
|
||||||
|
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||||
|
|
||||||
|
# Apply hash filtering if provided (highest priority)
|
||||||
|
if hash_filters:
|
||||||
|
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
|
||||||
|
|
||||||
|
# Jump to pagination for hash filters
|
||||||
|
return self._paginate(filtered_data, page, page_size)
|
||||||
|
|
||||||
|
# Apply common filters
|
||||||
|
filtered_data = await self._apply_common_filters(
|
||||||
|
filtered_data, folder, base_models, tags, favorites_only, search_options
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply search filtering
|
||||||
|
if search:
|
||||||
|
filtered_data = await self._apply_search_filters(
|
||||||
|
filtered_data, search, fuzzy_search, search_options
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply model-specific filters
|
||||||
|
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||||
|
|
||||||
|
return self._paginate(filtered_data, page, page_size)
|
||||||
|
|
||||||
|
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||||
|
"""Apply hash-based filtering"""
|
||||||
|
single_hash = hash_filters.get('single_hash')
|
||||||
|
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||||
|
|
||||||
|
if single_hash:
|
||||||
|
# Filter by single hash
|
||||||
|
single_hash = single_hash.lower()
|
||||||
|
return [
|
||||||
|
item for item in data
|
||||||
|
if item.get('sha256', '').lower() == single_hash
|
||||||
|
]
|
||||||
|
elif multiple_hashes:
|
||||||
|
# Filter by multiple hashes
|
||||||
|
hash_set = set(hash.lower() for hash in multiple_hashes)
|
||||||
|
return [
|
||||||
|
item for item in data
|
||||||
|
if item.get('sha256', '').lower() in hash_set
|
||||||
|
]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
|
||||||
|
base_models: list = None, tags: list = None,
|
||||||
|
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
|
||||||
|
"""Apply common filters that work across all model types"""
|
||||||
|
# Apply SFW filtering if enabled in settings
|
||||||
|
if settings.get('show_only_sfw', False):
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply favorites filtering if enabled
|
||||||
|
if favorites_only:
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if item.get('favorite', False) is True
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply folder filtering
|
||||||
|
if folder is not None:
|
||||||
|
if search_options and search_options.get('recursive', False):
|
||||||
|
# Recursive folder filtering - include all subfolders
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if item['folder'].startswith(folder)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Exact folder filtering
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if item['folder'] == folder
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply base model filtering
|
||||||
|
if base_models and len(base_models) > 0:
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if item.get('base_model') in base_models
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply tag filtering
|
||||||
|
if tags and len(tags) > 0:
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if any(tag in item.get('tags', []) for tag in tags)
|
||||||
|
]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def _apply_search_filters(self, data: List[Dict], search: str,
|
||||||
|
fuzzy_search: bool, search_options: dict) -> List[Dict]:
|
||||||
|
"""Apply search filtering"""
|
||||||
|
search_results = []
|
||||||
|
|
||||||
|
for item in data:
|
||||||
|
# Search by file name
|
||||||
|
if search_options.get('filename', True):
|
||||||
|
if fuzzy_search:
|
||||||
|
if fuzzy_match(item.get('file_name', ''), search):
|
||||||
|
search_results.append(item)
|
||||||
|
continue
|
||||||
|
elif search.lower() in item.get('file_name', '').lower():
|
||||||
|
search_results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Search by model name
|
||||||
|
if search_options.get('modelname', True):
|
||||||
|
if fuzzy_search:
|
||||||
|
if fuzzy_match(item.get('model_name', ''), search):
|
||||||
|
search_results.append(item)
|
||||||
|
continue
|
||||||
|
elif search.lower() in item.get('model_name', '').lower():
|
||||||
|
search_results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Search by tags
|
||||||
|
if search_options.get('tags', False) and 'tags' in item:
|
||||||
|
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
|
||||||
|
for tag in item['tags']):
|
||||||
|
search_results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||||
|
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||||
|
"""Apply pagination to filtered data"""
|
||||||
|
total_items = len(data)
|
||||||
|
start_idx = (page - 1) * page_size
|
||||||
|
end_idx = min(start_idx + page_size, total_items)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'items': data[start_idx:end_idx],
|
||||||
|
'total': total_items,
|
||||||
|
'page': page,
|
||||||
|
'page_size': page_size,
|
||||||
|
'total_pages': (total_items + page_size - 1) // page_size
|
||||||
|
}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def format_response(self, model_data: Dict) -> Dict:
|
||||||
|
"""Format model data for API response - must be implemented by subclasses"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Common service methods that delegate to scanner
|
||||||
|
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
|
||||||
|
"""Get top tags sorted by frequency"""
|
||||||
|
return await self.scanner.get_top_tags(limit)
|
||||||
|
|
||||||
|
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
||||||
|
"""Get base models sorted by frequency"""
|
||||||
|
return await self.scanner.get_base_models(limit)
|
||||||
|
|
||||||
|
def has_hash(self, sha256: str) -> bool:
|
||||||
|
"""Check if a model with given hash exists"""
|
||||||
|
return self.scanner.has_hash(sha256)
|
||||||
|
|
||||||
|
def get_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||||
|
"""Get file path for a model by its hash"""
|
||||||
|
return self.scanner.get_path_by_hash(sha256)
|
||||||
|
|
||||||
|
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||||
|
"""Get hash for a model by its file path"""
|
||||||
|
return self.scanner.get_hash_by_path(file_path)
|
||||||
|
|
||||||
|
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
|
||||||
|
"""Trigger model scanning"""
|
||||||
|
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
|
||||||
|
|
||||||
|
async def get_model_info_by_name(self, name: str):
|
||||||
|
"""Get model information by name"""
|
||||||
|
return await self.scanner.get_model_info_by_name(name)
|
||||||
|
|
||||||
|
def get_model_roots(self) -> List[str]:
|
||||||
|
"""Get model root directories"""
|
||||||
|
return self.scanner.get_model_roots()
|
||||||
@@ -1,14 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Dict, Optional, Set
|
from typing import List, Dict
|
||||||
import folder_paths # type: ignore
|
|
||||||
|
|
||||||
from ..utils.models import CheckpointMetadata
|
from ..utils.models import CheckpointMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from .model_scanner import ModelScanner
|
from .model_scanner import ModelScanner
|
||||||
from .model_hash_index import ModelHashIndex
|
from .model_hash_index import ModelHashIndex
|
||||||
from .service_registry import ServiceRegistry
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -110,3 +108,30 @@ class CheckpointScanner(ModelScanner):
|
|||||||
checkpoints.append(result)
|
checkpoints.append(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing {file_path}: {e}")
|
logger.error(f"Error processing {file_path}: {e}")
|
||||||
|
|
||||||
|
# Checkpoint-specific hash index functionality
|
||||||
|
def has_checkpoint_hash(self, sha256: str) -> bool:
|
||||||
|
"""Check if a checkpoint with given hash exists"""
|
||||||
|
return self.has_hash(sha256)
|
||||||
|
|
||||||
|
def get_checkpoint_path_by_hash(self, sha256: str) -> str:
|
||||||
|
"""Get file path for a checkpoint by its hash"""
|
||||||
|
return self.get_path_by_hash(sha256)
|
||||||
|
|
||||||
|
def get_checkpoint_hash_by_path(self, file_path: str) -> str:
|
||||||
|
"""Get hash for a checkpoint by its file path"""
|
||||||
|
return self.get_hash_by_path(file_path)
|
||||||
|
|
||||||
|
async def get_checkpoint_info_by_name(self, name):
|
||||||
|
"""Get checkpoint information by name"""
|
||||||
|
try:
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
|
||||||
|
for checkpoint in cache.raw_data:
|
||||||
|
if checkpoint.get("file_name") == name:
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting checkpoint info by name: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
51
py/services/checkpoint_service.py
Normal file
51
py/services/checkpoint_service.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from .base_model_service import BaseModelService
|
||||||
|
from ..utils.models import CheckpointMetadata
|
||||||
|
from ..config import config
|
||||||
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class CheckpointService(BaseModelService):
|
||||||
|
"""Checkpoint-specific service implementation"""
|
||||||
|
|
||||||
|
def __init__(self, scanner):
|
||||||
|
"""Initialize Checkpoint service
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scanner: Checkpoint scanner instance
|
||||||
|
"""
|
||||||
|
super().__init__("checkpoint", scanner, CheckpointMetadata)
|
||||||
|
|
||||||
|
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||||
|
"""Format Checkpoint data for API response"""
|
||||||
|
return {
|
||||||
|
"model_name": checkpoint_data["model_name"],
|
||||||
|
"file_name": checkpoint_data["file_name"],
|
||||||
|
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||||
|
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||||
|
"base_model": checkpoint_data.get("base_model", ""),
|
||||||
|
"folder": checkpoint_data["folder"],
|
||||||
|
"sha256": checkpoint_data.get("sha256", ""),
|
||||||
|
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
||||||
|
"file_size": checkpoint_data.get("size", 0),
|
||||||
|
"modified": checkpoint_data.get("modified", ""),
|
||||||
|
"tags": checkpoint_data.get("tags", []),
|
||||||
|
"modelDescription": checkpoint_data.get("modelDescription", ""),
|
||||||
|
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||||
|
"notes": checkpoint_data.get("notes", ""),
|
||||||
|
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||||
|
"favorite": checkpoint_data.get("favorite", False),
|
||||||
|
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}))
|
||||||
|
}
|
||||||
|
|
||||||
|
def find_duplicate_hashes(self) -> Dict:
|
||||||
|
"""Find Checkpoints with duplicate SHA256 hashes"""
|
||||||
|
return self.scanner._hash_index.get_duplicate_hashes()
|
||||||
|
|
||||||
|
def find_duplicate_filenames(self) -> Dict:
|
||||||
|
"""Find Checkpoints with conflicting filenames"""
|
||||||
|
return self.scanner._hash_index.get_duplicate_filenames()
|
||||||
172
py/services/lora_service.py
Normal file
172
py/services/lora_service.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from .base_model_service import BaseModelService
|
||||||
|
from ..utils.models import LoraMetadata
|
||||||
|
from ..config import config
|
||||||
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class LoraService(BaseModelService):
|
||||||
|
"""LoRA-specific service implementation"""
|
||||||
|
|
||||||
|
def __init__(self, scanner):
|
||||||
|
"""Initialize LoRA service
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scanner: LoRA scanner instance
|
||||||
|
"""
|
||||||
|
super().__init__("lora", scanner, LoraMetadata)
|
||||||
|
|
||||||
|
async def format_response(self, lora_data: Dict) -> Dict:
|
||||||
|
"""Format LoRA data for API response"""
|
||||||
|
return {
|
||||||
|
"model_name": lora_data["model_name"],
|
||||||
|
"file_name": lora_data["file_name"],
|
||||||
|
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
|
||||||
|
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||||
|
"base_model": lora_data.get("base_model", ""),
|
||||||
|
"folder": lora_data["folder"],
|
||||||
|
"sha256": lora_data.get("sha256", ""),
|
||||||
|
"file_path": lora_data["file_path"].replace(os.sep, "/"),
|
||||||
|
"file_size": lora_data.get("size", 0),
|
||||||
|
"modified": lora_data.get("modified", ""),
|
||||||
|
"tags": lora_data.get("tags", []),
|
||||||
|
"modelDescription": lora_data.get("modelDescription", ""),
|
||||||
|
"from_civitai": lora_data.get("from_civitai", True),
|
||||||
|
"usage_tips": lora_data.get("usage_tips", ""),
|
||||||
|
"notes": lora_data.get("notes", ""),
|
||||||
|
"favorite": lora_data.get("favorite", False),
|
||||||
|
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||||
|
"""Apply LoRA-specific filters"""
|
||||||
|
# Handle first_letter filter for LoRAs
|
||||||
|
first_letter = kwargs.get('first_letter')
|
||||||
|
if first_letter:
|
||||||
|
data = self._filter_by_first_letter(data, first_letter)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||||
|
"""Filter LoRAs by first letter"""
|
||||||
|
if letter == '#':
|
||||||
|
# Filter for non-alphabetic characters
|
||||||
|
return [
|
||||||
|
item for item in data
|
||||||
|
if not item.get('model_name', '')[0].isalpha()
|
||||||
|
]
|
||||||
|
elif letter == 'CJK':
|
||||||
|
# Filter for CJK characters
|
||||||
|
return [
|
||||||
|
item for item in data
|
||||||
|
if item.get('model_name', '') and self._is_cjk_character(item['model_name'][0])
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Filter for specific letter
|
||||||
|
return [
|
||||||
|
item for item in data
|
||||||
|
if item.get('model_name', '').lower().startswith(letter.lower())
|
||||||
|
]
|
||||||
|
|
||||||
|
def _is_cjk_character(self, char: str) -> bool:
|
||||||
|
"""Check if character is CJK (Chinese, Japanese, Korean)"""
|
||||||
|
cjk_ranges = [
|
||||||
|
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||||
|
(0x3400, 0x4DBF), # CJK Extension A
|
||||||
|
(0x20000, 0x2A6DF), # CJK Extension B
|
||||||
|
(0x2A700, 0x2B73F), # CJK Extension C
|
||||||
|
(0x2B740, 0x2B81F), # CJK Extension D
|
||||||
|
(0x3040, 0x309F), # Hiragana
|
||||||
|
(0x30A0, 0x30FF), # Katakana
|
||||||
|
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||||
|
]
|
||||||
|
|
||||||
|
char_code = ord(char)
|
||||||
|
return any(start <= char_code <= end for start, end in cjk_ranges)
|
||||||
|
|
||||||
|
# LoRA-specific methods
|
||||||
|
async def get_letter_counts(self) -> Dict[str, int]:
|
||||||
|
"""Get count of LoRAs for each letter of the alphabet"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
letter_counts = {}
|
||||||
|
|
||||||
|
for lora in cache.raw_data:
|
||||||
|
model_name = lora.get('model_name', '')
|
||||||
|
if model_name:
|
||||||
|
first_char = model_name[0].upper()
|
||||||
|
if first_char.isalpha():
|
||||||
|
letter_counts[first_char] = letter_counts.get(first_char, 0) + 1
|
||||||
|
elif self._is_cjk_character(first_char):
|
||||||
|
letter_counts['CJK'] = letter_counts.get('CJK', 0) + 1
|
||||||
|
else:
|
||||||
|
letter_counts['#'] = letter_counts.get('#', 0) + 1
|
||||||
|
|
||||||
|
return letter_counts
|
||||||
|
|
||||||
|
async def get_lora_notes(self, lora_name: str) -> Optional[str]:
|
||||||
|
"""Get notes for a specific LoRA file"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for lora in cache.raw_data:
|
||||||
|
if lora['file_name'] == lora_name:
|
||||||
|
return lora.get('notes', '')
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
||||||
|
"""Get trigger words for a specific LoRA file"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for lora in cache.raw_data:
|
||||||
|
if lora['file_name'] == lora_name:
|
||||||
|
civitai_data = lora.get('civitai', {})
|
||||||
|
return civitai_data.get('trainedWords', [])
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_lora_preview_url(self, lora_name: str) -> Optional[str]:
|
||||||
|
"""Get the static preview URL for a LoRA file"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for lora in cache.raw_data:
|
||||||
|
if lora['file_name'] == lora_name:
|
||||||
|
preview_url = lora.get('preview_url')
|
||||||
|
if preview_url:
|
||||||
|
return config.get_preview_static_url(preview_url)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]:
|
||||||
|
"""Get the Civitai URL for a LoRA file"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for lora in cache.raw_data:
|
||||||
|
if lora['file_name'] == lora_name:
|
||||||
|
civitai_data = lora.get('civitai', {})
|
||||||
|
model_id = civitai_data.get('modelId')
|
||||||
|
version_id = civitai_data.get('id')
|
||||||
|
|
||||||
|
if model_id:
|
||||||
|
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||||
|
if version_id:
|
||||||
|
civitai_url += f"?modelVersionId={version_id}"
|
||||||
|
|
||||||
|
return {
|
||||||
|
'civitai_url': civitai_url,
|
||||||
|
'model_id': str(model_id),
|
||||||
|
'version_id': str(version_id) if version_id else None
|
||||||
|
}
|
||||||
|
|
||||||
|
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||||
|
|
||||||
|
def find_duplicate_hashes(self) -> Dict:
|
||||||
|
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||||
|
return self.scanner._hash_index.get_duplicate_hashes()
|
||||||
|
|
||||||
|
def find_duplicate_filenames(self) -> Dict:
|
||||||
|
"""Find LoRAs with conflicting filenames"""
|
||||||
|
return self.scanner._hash_index.get_duplicate_filenames()
|
||||||
137
py/services/model_service_factory.py
Normal file
137
py/services/model_service_factory.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
from typing import Dict, Type, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ModelServiceFactory:
|
||||||
|
"""Factory for managing model services and routes"""
|
||||||
|
|
||||||
|
_services: Dict[str, Type] = {}
|
||||||
|
_routes: Dict[str, Type] = {}
|
||||||
|
_initialized_services: Dict[str, Any] = {}
|
||||||
|
_initialized_routes: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_model_type(cls, model_type: str, service_class: Type, route_class: Type):
|
||||||
|
"""Register a new model type with its service and route classes
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: The model type identifier (e.g., 'lora', 'checkpoint')
|
||||||
|
service_class: The service class for this model type
|
||||||
|
route_class: The route class for this model type
|
||||||
|
"""
|
||||||
|
cls._services[model_type] = service_class
|
||||||
|
cls._routes[model_type] = route_class
|
||||||
|
logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_service_class(cls, model_type: str) -> Type:
|
||||||
|
"""Get service class for a model type
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: The model type identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The service class for the model type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If model type is not registered
|
||||||
|
"""
|
||||||
|
if model_type not in cls._services:
|
||||||
|
raise ValueError(f"Unknown model type: {model_type}")
|
||||||
|
return cls._services[model_type]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_route_class(cls, model_type: str) -> Type:
|
||||||
|
"""Get route class for a model type
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: The model type identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The route class for the model type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If model type is not registered
|
||||||
|
"""
|
||||||
|
if model_type not in cls._routes:
|
||||||
|
raise ValueError(f"Unknown model type: {model_type}")
|
||||||
|
return cls._routes[model_type]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_route_instance(cls, model_type: str):
|
||||||
|
"""Get or create route instance for a model type
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: The model type identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The route instance for the model type
|
||||||
|
"""
|
||||||
|
if model_type not in cls._initialized_routes:
|
||||||
|
route_class = cls.get_route_class(model_type)
|
||||||
|
cls._initialized_routes[model_type] = route_class()
|
||||||
|
return cls._initialized_routes[model_type]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_all_routes(cls, app):
|
||||||
|
"""Setup routes for all registered model types
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: The aiohttp application instance
|
||||||
|
"""
|
||||||
|
logger.info(f"Setting up routes for {len(cls._services)} registered model types")
|
||||||
|
|
||||||
|
for model_type in cls._services.keys():
|
||||||
|
try:
|
||||||
|
routes_instance = cls.get_route_instance(model_type)
|
||||||
|
routes_instance.setup_routes(app)
|
||||||
|
logger.info(f"Successfully set up routes for {model_type}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_registered_types(cls) -> list:
|
||||||
|
"""Get list of all registered model types
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of registered model type identifiers
|
||||||
|
"""
|
||||||
|
return list(cls._services.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_registered(cls, model_type: str) -> bool:
|
||||||
|
"""Check if a model type is registered
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: The model type identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the model type is registered, False otherwise
|
||||||
|
"""
|
||||||
|
return model_type in cls._services
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def clear_registrations(cls):
|
||||||
|
"""Clear all registrations - mainly for testing purposes"""
|
||||||
|
cls._services.clear()
|
||||||
|
cls._routes.clear()
|
||||||
|
cls._initialized_services.clear()
|
||||||
|
cls._initialized_routes.clear()
|
||||||
|
logger.info("Cleared all model type registrations")
|
||||||
|
|
||||||
|
|
||||||
|
def register_default_model_types():
|
||||||
|
"""Register the default model types (LoRA and Checkpoint)"""
|
||||||
|
from ..services.lora_service import LoraService
|
||||||
|
from ..services.checkpoint_service import CheckpointService
|
||||||
|
from ..routes.lora_routes import LoraRoutes
|
||||||
|
from ..routes.checkpoint_routes import CheckpointRoutes
|
||||||
|
|
||||||
|
# Register LoRA model type
|
||||||
|
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
|
||||||
|
|
||||||
|
# Register Checkpoint model type
|
||||||
|
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
|
||||||
|
|
||||||
|
logger.info("Registered default model types: lora, checkpoint")
|
||||||
@@ -7,106 +7,176 @@ logger = logging.getLogger(__name__)
|
|||||||
T = TypeVar('T') # Define a type variable for service types
|
T = TypeVar('T') # Define a type variable for service types
|
||||||
|
|
||||||
class ServiceRegistry:
|
class ServiceRegistry:
|
||||||
"""Centralized registry for service singletons"""
|
"""Central registry for managing singleton services"""
|
||||||
|
|
||||||
_instance = None
|
|
||||||
_services: Dict[str, Any] = {}
|
_services: Dict[str, Any] = {}
|
||||||
_lock = asyncio.Lock()
|
_locks: Dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls):
|
async def register_service(cls, name: str, service: Any) -> None:
|
||||||
"""Get singleton instance of the registry"""
|
"""Register a service instance with the registry
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = cls()
|
Args:
|
||||||
return cls._instance
|
name: Service name identifier
|
||||||
|
service: Service instance to register
|
||||||
|
"""
|
||||||
|
cls._services[name] = service
|
||||||
|
logger.debug(f"Registered service: {name}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def register_service(cls, service_name: str, service_instance: Any) -> None:
|
async def get_service(cls, name: str) -> Optional[Any]:
|
||||||
"""Register a service instance with the registry"""
|
"""Get a service instance by name
|
||||||
registry = cls.get_instance()
|
|
||||||
async with cls._lock:
|
Args:
|
||||||
registry._services[service_name] = service_instance
|
name: Service name identifier
|
||||||
logger.debug(f"Registered service: {service_name}")
|
|
||||||
|
Returns:
|
||||||
|
Service instance or None if not found
|
||||||
|
"""
|
||||||
|
return cls._services.get(name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_service(cls, service_name: str) -> Any:
|
def _get_lock(cls, name: str) -> asyncio.Lock:
|
||||||
"""Get a service instance by name"""
|
"""Get or create a lock for a service
|
||||||
registry = cls.get_instance()
|
|
||||||
async with cls._lock:
|
|
||||||
if service_name not in registry._services:
|
|
||||||
logger.debug(f"Service {service_name} not found in registry")
|
|
||||||
return None
|
|
||||||
return registry._services[service_name]
|
|
||||||
|
|
||||||
@classmethod
|
Args:
|
||||||
def get_service_sync(cls, service_name: str) -> Any:
|
name: Service name identifier
|
||||||
"""Get a service instance by name (synchronous version)"""
|
|
||||||
registry = cls.get_instance()
|
Returns:
|
||||||
if service_name not in registry._services:
|
AsyncIO lock for the service
|
||||||
logger.debug(f"Service {service_name} not found in registry")
|
"""
|
||||||
return None
|
if name not in cls._locks:
|
||||||
return registry._services[service_name]
|
cls._locks[name] = asyncio.Lock()
|
||||||
|
return cls._locks[name]
|
||||||
|
|
||||||
# Convenience methods for common services
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_lora_scanner(cls):
|
async def get_lora_scanner(cls):
|
||||||
"""Get the LoraScanner instance"""
|
"""Get or create LoRA scanner instance"""
|
||||||
|
service_name = "lora_scanner"
|
||||||
|
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
from .lora_scanner import LoraScanner
|
from .lora_scanner import LoraScanner
|
||||||
scanner = await cls.get_service("lora_scanner")
|
|
||||||
if scanner is None:
|
|
||||||
scanner = await LoraScanner.get_instance()
|
scanner = await LoraScanner.get_instance()
|
||||||
await cls.register_service("lora_scanner", scanner)
|
cls._services[service_name] = scanner
|
||||||
|
logger.info(f"Created and registered {service_name}")
|
||||||
return scanner
|
return scanner
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_checkpoint_scanner(cls):
|
async def get_checkpoint_scanner(cls):
|
||||||
"""Get the CheckpointScanner instance"""
|
"""Get or create Checkpoint scanner instance"""
|
||||||
|
service_name = "checkpoint_scanner"
|
||||||
|
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
from .checkpoint_scanner import CheckpointScanner
|
from .checkpoint_scanner import CheckpointScanner
|
||||||
scanner = await cls.get_service("checkpoint_scanner")
|
|
||||||
if scanner is None:
|
|
||||||
scanner = await CheckpointScanner.get_instance()
|
scanner = await CheckpointScanner.get_instance()
|
||||||
await cls.register_service("checkpoint_scanner", scanner)
|
cls._services[service_name] = scanner
|
||||||
|
logger.info(f"Created and registered {service_name}")
|
||||||
|
return scanner
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_recipe_scanner(cls):
|
||||||
|
"""Get or create Recipe scanner instance"""
|
||||||
|
service_name = "recipe_scanner"
|
||||||
|
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from .recipe_scanner import RecipeScanner
|
||||||
|
|
||||||
|
scanner = await RecipeScanner.get_instance()
|
||||||
|
cls._services[service_name] = scanner
|
||||||
|
logger.info(f"Created and registered {service_name}")
|
||||||
return scanner
|
return scanner
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_civitai_client(cls):
|
async def get_civitai_client(cls):
|
||||||
"""Get the CivitaiClient instance"""
|
"""Get or create CivitAI client instance"""
|
||||||
|
service_name = "civitai_client"
|
||||||
|
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
from .civitai_client import CivitaiClient
|
from .civitai_client import CivitaiClient
|
||||||
client = await cls.get_service("civitai_client")
|
|
||||||
if client is None:
|
|
||||||
client = await CivitaiClient.get_instance()
|
client = await CivitaiClient.get_instance()
|
||||||
await cls.register_service("civitai_client", client)
|
cls._services[service_name] = client
|
||||||
|
logger.info(f"Created and registered {service_name}")
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_download_manager(cls):
|
async def get_download_manager(cls):
|
||||||
"""Get the DownloadManager instance"""
|
"""Get or create Download manager instance"""
|
||||||
from .download_manager import DownloadManager
|
service_name = "download_manager"
|
||||||
manager = await cls.get_service("download_manager")
|
|
||||||
if manager is None:
|
|
||||||
manager = await DownloadManager.get_instance()
|
|
||||||
await cls.register_service("download_manager", manager)
|
|
||||||
return manager
|
|
||||||
|
|
||||||
@classmethod
|
if service_name in cls._services:
|
||||||
async def get_recipe_scanner(cls):
|
return cls._services[service_name]
|
||||||
"""Get the RecipeScanner instance"""
|
|
||||||
from .recipe_scanner import RecipeScanner
|
async with cls._get_lock(service_name):
|
||||||
scanner = await cls.get_service("recipe_scanner")
|
# Double-check after acquiring lock
|
||||||
if scanner is None:
|
if service_name in cls._services:
|
||||||
lora_scanner = await cls.get_lora_scanner()
|
return cls._services[service_name]
|
||||||
scanner = RecipeScanner(lora_scanner)
|
|
||||||
await cls.register_service("recipe_scanner", scanner)
|
# Import here to avoid circular imports
|
||||||
return scanner
|
from .download_manager import DownloadManager
|
||||||
|
|
||||||
|
manager = DownloadManager()
|
||||||
|
cls._services[service_name] = manager
|
||||||
|
logger.info(f"Created and registered {service_name}")
|
||||||
|
return manager
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_websocket_manager(cls):
|
async def get_websocket_manager(cls):
|
||||||
"""Get the WebSocketManager instance"""
|
"""Get or create WebSocket manager instance"""
|
||||||
|
service_name = "websocket_manager"
|
||||||
|
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
from .websocket_manager import ws_manager
|
from .websocket_manager import ws_manager
|
||||||
manager = await cls.get_service("websocket_manager")
|
|
||||||
if manager is None:
|
cls._services[service_name] = ws_manager
|
||||||
# ws_manager is already a global instance in websocket_manager.py
|
logger.info(f"Registered {service_name}")
|
||||||
from .websocket_manager import ws_manager
|
return ws_manager
|
||||||
await cls.register_service("websocket_manager", ws_manager)
|
|
||||||
manager = ws_manager
|
@classmethod
|
||||||
return manager
|
def clear_services(cls):
|
||||||
|
"""Clear all registered services - mainly for testing"""
|
||||||
|
cls._services.clear()
|
||||||
|
cls._locks.clear()
|
||||||
|
logger.info("Cleared all registered services")
|
||||||
@@ -1047,3 +1047,56 @@ class ModelRouteUtils:
|
|||||||
'success': False,
|
'success': False,
|
||||||
'error': str(e)
|
'error': str(e)
|
||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def handle_save_metadata(request: web.Request, scanner) -> web.Response:
|
||||||
|
"""Handle saving metadata updates
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The aiohttp request
|
||||||
|
scanner: The model scanner instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
web.Response: The HTTP response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_path = data.get('file_path')
|
||||||
|
if not file_path:
|
||||||
|
return web.Response(text='File path is required', status=400)
|
||||||
|
|
||||||
|
# Remove file path from data to avoid saving it
|
||||||
|
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||||
|
|
||||||
|
# Get metadata file path
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||||
|
|
||||||
|
# Load existing metadata
|
||||||
|
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||||
|
|
||||||
|
# Handle nested updates (for civitai.trainedWords)
|
||||||
|
for key, value in metadata_updates.items():
|
||||||
|
if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict):
|
||||||
|
# Deep update for nested dictionaries
|
||||||
|
for nested_key, nested_value in value.items():
|
||||||
|
metadata[key][nested_key] = nested_value
|
||||||
|
else:
|
||||||
|
# Regular update for top-level keys
|
||||||
|
metadata[key] = value
|
||||||
|
|
||||||
|
# Save updated metadata
|
||||||
|
await MetadataManager.save_metadata(file_path, metadata)
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
await scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||||
|
|
||||||
|
# If model_name was updated, resort the cache
|
||||||
|
if 'model_name' in metadata_updates:
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
await cache.resort(name_only=True)
|
||||||
|
|
||||||
|
return web.json_response({'success': True})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving metadata: {e}", exc_info=True)
|
||||||
|
return web.Response(text=str(e), status=500)
|
||||||
|
|||||||
@@ -314,22 +314,23 @@ class StandaloneLoraManager(LoraManager):
|
|||||||
app.router.add_static('/loras_static', config.static_path)
|
app.router.add_static('/loras_static', config.static_path)
|
||||||
|
|
||||||
# Setup feature routes
|
# Setup feature routes
|
||||||
from py.routes.lora_routes import LoraRoutes
|
from py.services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||||
from py.routes.api_routes import ApiRoutes
|
from py.routes.api_routes import ApiRoutes
|
||||||
from py.routes.recipe_routes import RecipeRoutes
|
from py.routes.recipe_routes import RecipeRoutes
|
||||||
from py.routes.checkpoints_routes import CheckpointsRoutes
|
|
||||||
from py.routes.update_routes import UpdateRoutes
|
from py.routes.update_routes import UpdateRoutes
|
||||||
from py.routes.misc_routes import MiscRoutes
|
from py.routes.misc_routes import MiscRoutes
|
||||||
from py.routes.example_images_routes import ExampleImagesRoutes
|
from py.routes.example_images_routes import ExampleImagesRoutes
|
||||||
from py.routes.stats_routes import StatsRoutes
|
from py.routes.stats_routes import StatsRoutes
|
||||||
|
|
||||||
lora_routes = LoraRoutes()
|
|
||||||
checkpoints_routes = CheckpointsRoutes()
|
register_default_model_types()
|
||||||
|
|
||||||
|
# Setup all model routes using the factory
|
||||||
|
ModelServiceFactory.setup_all_routes(app)
|
||||||
|
|
||||||
stats_routes = StatsRoutes()
|
stats_routes = StatsRoutes()
|
||||||
|
|
||||||
# Initialize routes
|
# Initialize routes
|
||||||
lora_routes.setup_routes(app)
|
|
||||||
checkpoints_routes.setup_routes(app)
|
|
||||||
stats_routes.setup_routes(app)
|
stats_routes.setup_routes(app)
|
||||||
ApiRoutes.setup_routes(app)
|
ApiRoutes.setup_routes(app)
|
||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
|
|||||||
Reference in New Issue
Block a user