Merge pull request #313 from willmiao/refactor

Refactor
pixelpaws authored on 2025-07-24 19:37:17 +08:00; committed by GitHub
40 changed files with 2954 additions and 3212 deletions


@@ -1,71 +0,0 @@
# Path mappings configuration for ComfyUI-Lora-Manager
# This file allows you to customize how base models and Civitai tags map to directories when downloading models
# Base model mappings
# Format: "Original Base Model": "Custom Directory Name"
#
# Example: If you change "Flux.1 D": "flux"
# Then models with base model "Flux.1 D" will be stored in a directory named "flux"
# So the final path would be: <model_directory>/flux/<tag>/model_file.safetensors
base_models:
"SD 1.4": "SD 1.4"
"SD 1.5": "SD 1.5"
"SD 1.5 LCM": "SD 1.5 LCM"
"SD 1.5 Hyper": "SD 1.5 Hyper"
"SD 2.0": "SD 2.0"
"SD 2.1": "SD 2.1"
"SDXL 1.0": "SDXL 1.0"
"SD 3": "SD 3"
"SD 3.5": "SD 3.5"
"SD 3.5 Medium": "SD 3.5 Medium"
"SD 3.5 Large": "SD 3.5 Large"
"SD 3.5 Large Turbo": "SD 3.5 Large Turbo"
"Pony": "Pony"
"Flux.1 S": "Flux.1 S"
"Flux.1 D": "Flux.1 D"
"Flux.1 Kontext": "Flux.1 Kontext"
"AuraFlow": "AuraFlow"
"SDXL Lightning": "SDXL Lightning"
"SDXL Hyper": "SDXL Hyper"
"Stable Cascade": "Stable Cascade"
"SVD": "SVD"
"PixArt a": "PixArt a"
"PixArt E": "PixArt E"
"Hunyuan 1": "Hunyuan 1"
"Hunyuan Video": "Hunyuan Video"
"Lumina": "Lumina"
"Kolors": "Kolors"
"Illustrious": "Illustrious"
"Mochi": "Mochi"
"LTXV": "LTXV"
"CogVideoX": "CogVideoX"
"NoobAI": "NoobAI"
"Wan Video": "Wan Video"
"Wan Video 1.3B t2v": "Wan Video 1.3B t2v"
"Wan Video 14B t2v": "Wan Video 14B t2v"
"Wan Video 14B i2v 480p": "Wan Video 14B i2v 480p"
"Wan Video 14B i2v 720p": "Wan Video 14B i2v 720p"
"HiDream": "HiDream"
"Other": "Other"
# Civitai model tag mappings
# Format: "Original Tag": "Custom Directory Name"
#
# Example: If you change "character": "characters"
# Then models with tag "character" will be stored in a directory named "characters"
# So the final path would be: <model_directory>/<base_model>/characters/model_file.safetensors
model_tags:
"character": "character"
"style": "style"
"concept": "concept"
"clothing": "clothing"
"base model": "base model"
"poses": "poses"
"background": "background"
"tool": "tool"
"vehicle": "vehicle"
"buildings": "buildings"
"objects": "objects"
"assets": "assets"
"animal": "animal"
"action": "action"


@@ -6,10 +6,8 @@ from pathlib import Path
from server import PromptServer # type: ignore
from .config import config
from .routes.lora_routes import LoraRoutes
from .routes.api_routes import ApiRoutes
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
from .routes.recipe_routes import RecipeRoutes
from .routes.checkpoints_routes import CheckpointsRoutes
from .routes.stats_routes import StatsRoutes
from .routes.update_routes import UpdateRoutes
from .routes.misc_routes import MiscRoutes
@@ -17,6 +15,7 @@ from .routes.example_images_routes import ExampleImagesRoutes
from .services.service_registry import ServiceRegistry
from .services.settings_manager import settings
from .utils.example_images_migration import ExampleImagesMigration
from .services.websocket_manager import ws_manager
logger = logging.getLogger(__name__)
@@ -28,12 +27,28 @@ class LoraManager:
@classmethod
def add_routes(cls):
"""Initialize and register all routes"""
"""Initialize and register all routes using the new refactored architecture"""
app = PromptServer.instance.app
# Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# Add specific suppression for connection reset errors
class ConnectionResetFilter(logging.Filter):
def filter(self, record):
# Filter out connection reset errors that are not critical
if "ConnectionResetError" in str(record.getMessage()):
return False
if "_call_connection_lost" in str(record.getMessage()):
return False
if "WinError 10054" in str(record.getMessage()):
return False
return True
# Apply the filter to asyncio logger
asyncio_logger = logging.getLogger("asyncio")
asyncio_logger.addFilter(ConnectionResetFilter())
added_targets = set() # Track already added target paths
# Add static route for example images if the path exists in settings
@@ -110,35 +125,37 @@ class LoraManager:
# Add static route for plugin assets
app.router.add_static('/loras_static', config.static_path)
# Setup feature routes
lora_routes = LoraRoutes()
checkpoints_routes = CheckpointsRoutes()
stats_routes = StatsRoutes()
# Register default model types with the factory
register_default_model_types()
# Initialize routes
lora_routes.setup_routes(app)
checkpoints_routes.setup_routes(app)
stats_routes.setup_routes(app) # Add statistics routes
ApiRoutes.setup_routes(app)
# Setup all model routes using the factory
ModelServiceFactory.setup_all_routes(app)
# Setup non-model-specific routes
stats_routes = StatsRoutes()
stats_routes.setup_routes(app)
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app) # Register miscellaneous routes
ExampleImagesRoutes.setup_routes(app) # Register example images routes
MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app)
# Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
# Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services())
# Add cleanup
app.on_shutdown.append(cls._cleanup)
app.on_shutdown.append(ApiRoutes.cleanup)
logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}")
@classmethod
async def _initialize_services(cls):
"""Initialize all services using the ServiceRegistry"""
try:
# Ensure aiohttp access logger is configured with reduced verbosity
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# Initialize CivitaiClient first to ensure it's ready for other services
await ServiceRegistry.get_civitai_client()

File diff suppressed because it is too large.


@@ -0,0 +1,619 @@
from abc import ABC, abstractmethod
import asyncio
import json
import logging
from aiohttp import web
from typing import Dict
import jinja2
from ..utils.routes_common import ModelRouteUtils
from ..services.websocket_manager import ws_manager
from ..services.settings_manager import settings
from ..config import config
logger = logging.getLogger(__name__)
class BaseModelRoutes(ABC):
"""Base route controller for all model types"""
def __init__(self, service):
"""Initialize the route controller
Args:
service: Model service instance (LoraService, CheckpointService, etc.)
"""
self.service = service
self.model_type = service.model_type
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
def setup_routes(self, app: web.Application, prefix: str):
"""Setup common routes for the model type
Args:
app: aiohttp application
prefix: URL prefix (e.g., 'loras', 'checkpoints')
"""
# Common model management routes
app.router.add_get(f'/api/{prefix}', self.get_models)
app.router.add_post(f'/api/{prefix}/delete', self.delete_model)
app.router.add_post(f'/api/{prefix}/exclude', self.exclude_model)
app.router.add_post(f'/api/{prefix}/fetch-civitai', self.fetch_civitai)
app.router.add_post(f'/api/{prefix}/relink-civitai', self.relink_civitai)
app.router.add_post(f'/api/{prefix}/replace-preview', self.replace_preview)
app.router.add_post(f'/api/{prefix}/save-metadata', self.save_metadata)
app.router.add_post(f'/api/{prefix}/rename', self.rename_model)
app.router.add_post(f'/api/{prefix}/bulk-delete', self.bulk_delete_models)
app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates)
# Common query routes
app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags)
app.router.add_get(f'/api/{prefix}/base-models', self.get_base_models)
app.router.add_get(f'/api/{prefix}/scan', self.scan_models)
app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots)
app.router.add_get(f'/api/{prefix}/folders', self.get_folders)
app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models)
app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts)
# Common Download management
app.router.add_post(f'/api/download-model', self.download_model)
app.router.add_get(f'/api/download-model-get', self.download_model_get)
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
# CivitAI integration routes
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
# app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
# Add generic page route
app.router.add_get(f'/{prefix}', self.handle_models_page)
# Setup model-specific routes
self.setup_specific_routes(app, prefix)
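# With prefix 'loras', for example, the registrations above yield endpoints such as
# /api/loras, /api/loras/delete, and /api/loras/scan, plus the page route /loras;
# note that the shared download endpoints are registered without a model prefix.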
@abstractmethod
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup model-specific routes - to be implemented by subclasses"""
pass
async def handle_models_page(self, request: web.Request) -> web.Response:
"""
Generic handler for model pages (e.g., /loras, /checkpoints).
Subclasses should set self.template_env and template_name.
"""
try:
# Check if the scanner is initializing
is_initializing = (
self.service.scanner._cache is None or
(hasattr(self.service.scanner, 'is_initializing') and callable(self.service.scanner.is_initializing) and self.service.scanner.is_initializing()) or
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
)
template_name = getattr(self, "template_name", None)
if not self.template_env or not template_name:
return web.Response(text="Template environment or template name not set", status=500)
if is_initializing:
rendered = self.template_env.get_template(template_name).render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
else:
try:
cache = await self.service.scanner.get_cached_data(force_refresh=False)
rendered = self.template_env.get_template(template_name).render(
folders=getattr(cache, "folders", []),
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading cache data: {cache_error}")
rendered = self.template_env.get_template(template_name).render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling models page: {e}", exc_info=True)
return web.Response(
text="Error loading models page",
status=500
)
async def get_models(self, request: web.Request) -> web.Response:
"""Get paginated model data"""
try:
# Parse common query parameters
params = self._parse_common_params(request)
# Get data from service
result = await self.service.get_paginated_data(**params)
# Format response items
formatted_result = {
'items': [await self.service.format_response(item) for item in result['items']],
'total': result['total'],
'page': result['page'],
'page_size': result['page_size'],
'total_pages': result['total_pages']
}
return web.json_response(formatted_result)
except Exception as e:
logger.error(f"Error in get_{self.model_type}s: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
def _parse_common_params(self, request: web.Request) -> Dict:
"""Parse common query parameters"""
# Parse basic pagination and sorting
page = int(request.query.get('page', '1'))
page_size = min(int(request.query.get('page_size', '20')), 100)
sort_by = request.query.get('sort_by', 'name')
folder = request.query.get('folder', None)
search = request.query.get('search', None)
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
# Parse filter arrays
base_models = request.query.getall('base_model', [])
tags = request.query.getall('tag', [])
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true'
# Parse search options
search_options = {
'filename': request.query.get('search_filename', 'true').lower() == 'true',
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
'tags': request.query.get('search_tags', 'false').lower() == 'true',
'recursive': request.query.get('recursive', 'false').lower() == 'true',
}
# Parse hash filters if provided
hash_filters = {}
if 'hash' in request.query:
hash_filters['single_hash'] = request.query['hash']
elif 'hashes' in request.query:
try:
hash_list = json.loads(request.query['hashes'])
if isinstance(hash_list, list):
hash_filters['multiple_hashes'] = hash_list
except (json.JSONDecodeError, TypeError):
pass
return {
'page': page,
'page_size': page_size,
'sort_by': sort_by,
'folder': folder,
'search': search,
'fuzzy_search': fuzzy_search,
'base_models': base_models,
'tags': tags,
'search_options': search_options,
'hash_filters': hash_filters,
'favorites_only': favorites_only,
# Add model-specific parameters
**self._parse_specific_params(request)
}
def _parse_specific_params(self, request: web.Request) -> Dict:
"""Parse model-specific parameters - to be overridden by subclasses"""
return {}
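# As a concrete illustration (values hypothetical), a request like
#   GET /api/loras?page=2&page_size=50&sort_by=date&tag=style&tag=concept&fuzzy_search=true
# parses to: page=2, page_size=50, sort_by='date', tags=['style', 'concept'],
# fuzzy_search=True, with the remaining keys at their defaults.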
# Common route handlers
async def delete_model(self, request: web.Request) -> web.Response:
"""Handle model deletion request"""
return await ModelRouteUtils.handle_delete_model(request, self.service.scanner)
async def exclude_model(self, request: web.Request) -> web.Response:
"""Handle model exclusion request"""
return await ModelRouteUtils.handle_exclude_model(request, self.service.scanner)
async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata fetch request"""
response = await ModelRouteUtils.handle_fetch_civitai(request, self.service.scanner)
# If successful, format the metadata before returning
if response.status == 200:
data = json.loads(response.body.decode('utf-8'))
if data.get("success") and data.get("metadata"):
formatted_metadata = await self.service.format_response(data["metadata"])
return web.json_response({
"success": True,
"metadata": formatted_metadata
})
return response
async def relink_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata re-linking request"""
return await ModelRouteUtils.handle_relink_civitai(request, self.service.scanner)
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement"""
return await ModelRouteUtils.handle_replace_preview(request, self.service.scanner)
async def save_metadata(self, request: web.Request) -> web.Response:
"""Handle saving metadata updates"""
return await ModelRouteUtils.handle_save_metadata(request, self.service.scanner)
async def rename_model(self, request: web.Request) -> web.Response:
"""Handle renaming a model file and its associated files"""
return await ModelRouteUtils.handle_rename_model(request, self.service.scanner)
async def bulk_delete_models(self, request: web.Request) -> web.Response:
"""Handle bulk deletion of models"""
return await ModelRouteUtils.handle_bulk_delete_models(request, self.service.scanner)
async def verify_duplicates(self, request: web.Request) -> web.Response:
"""Handle verification of duplicate model hashes"""
return await ModelRouteUtils.handle_verify_duplicates(request, self.service.scanner)
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
limit = int(request.query.get('limit', '20'))
if limit < 1 or limit > 100:
limit = 20
top_tags = await self.service.get_top_tags(limit)
return web.json_response({
'success': True,
'tags': top_tags
})
except Exception as e:
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
return web.json_response({
'success': False,
'error': 'Internal server error'
}, status=500)
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in models"""
try:
limit = int(request.query.get('limit', '20'))
if limit < 1 or limit > 100:
limit = 20
base_models = await self.service.get_base_models(limit)
return web.json_response({
'success': True,
'base_models': base_models
})
except Exception as e:
logger.error(f"Error retrieving base models: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def scan_models(self, request: web.Request) -> web.Response:
"""Force a rescan of model files"""
try:
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
await self.service.scan_models(force_refresh=True, rebuild_cache=full_rebuild)
return web.json_response({
"status": "success",
"message": f"{self.model_type.capitalize()} scan completed"
})
except Exception as e:
logger.error(f"Error in scan_{self.model_type}s: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_model_roots(self, request: web.Request) -> web.Response:
"""Return the model root directories"""
try:
roots = self.service.get_model_roots()
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting {self.model_type} roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_folders(self, request: web.Request) -> web.Response:
"""Get all folders in the cache"""
try:
cache = await self.service.scanner.get_cached_data()
return web.json_response({
'folders': cache.folders
})
except Exception as e:
logger.error(f"Error getting folders: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def find_duplicate_models(self, request: web.Request) -> web.Response:
"""Find models with duplicate SHA256 hashes"""
try:
# Get duplicate hashes from service
duplicates = self.service.find_duplicate_hashes()
# Format the response
result = []
cache = await self.service.scanner.get_cached_data()
for sha256, paths in duplicates.items():
group = {
"hash": sha256,
"models": []
}
# Find matching models for each path
for path in paths:
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
if model:
group["models"].append(await self.service.format_response(model))
# Add the primary model too
primary_path = self.service.get_path_by_hash(sha256)
if primary_path and primary_path not in paths:
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
if primary_model:
group["models"].insert(0, await self.service.format_response(primary_model))
if len(group["models"]) > 1: # Only include if we found multiple models
result.append(group)
return web.json_response({
"success": True,
"duplicates": result,
"count": len(result)
})
except Exception as e:
logger.error(f"Error finding duplicate {self.model_type}s: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
"""Find models with conflicting filenames"""
try:
# Get duplicate filenames from service
duplicates = self.service.find_duplicate_filenames()
# Format the response
result = []
cache = await self.service.scanner.get_cached_data()
for filename, paths in duplicates.items():
group = {
"filename": filename,
"models": []
}
# Find matching models for each path
for path in paths:
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
if model:
group["models"].append(await self.service.format_response(model))
# Find the model from the main index too
hash_val = self.service.scanner._hash_index.get_hash_by_filename(filename)
if hash_val:
main_path = self.service.get_path_by_hash(hash_val)
if main_path and main_path not in paths:
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
if main_model:
group["models"].insert(0, await self.service.format_response(main_model))
if group["models"]:
result.append(group)
return web.json_response({
"success": True,
"conflicts": result,
"count": len(result)
})
except Exception as e:
logger.error(f"Error finding filename conflicts for {self.model_type}s: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
# Download management methods
async def download_model(self, request: web.Request) -> web.Response:
"""Handle model download request"""
return await ModelRouteUtils.handle_download_model(request)
async def download_model_get(self, request: web.Request) -> web.Response:
"""Handle model download request via GET method"""
try:
# Extract query parameters
model_id = request.query.get('model_id')
if not model_id:
return web.Response(
status=400,
text="Missing required parameter: Please provide 'model_id'"
)
# Get optional parameters
model_version_id = request.query.get('model_version_id')
download_id = request.query.get('download_id')
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
# Create a data dictionary that mimics what would be received from a POST request
data = {
'model_id': model_id
}
# Add optional parameters only if they are provided
if model_version_id:
data['model_version_id'] = model_version_id
if download_id:
data['download_id'] = download_id
data['use_default_paths'] = use_default_paths
# Create a mock request object with the data
future = asyncio.get_event_loop().create_future()
future.set_result(data)
mock_request = type('MockRequest', (), {
'json': lambda self=None: future
})()
# Call the existing download handler
return await ModelRouteUtils.handle_download_model(mock_request)
except Exception as e:
error_message = str(e)
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
return web.Response(status=500, text=error_message)
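# For illustration, a GET download request might look like (model id hypothetical):
#   GET /api/download-model-get?model_id=12345&use_default_paths=true
# The handler repackages these query parameters into the payload the POST endpoint expects.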
async def cancel_download_get(self, request: web.Request) -> web.Response:
"""Handle GET request for cancelling a download by download_id"""
try:
download_id = request.query.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
# Create a mock request with match_info for compatibility
mock_request = type('MockRequest', (), {
'match_info': {'download_id': download_id}
})()
return await ModelRouteUtils.handle_cancel_download(mock_request)
except Exception as e:
logger.error(f"Error cancelling download via GET: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_download_progress(self, request: web.Request) -> web.Response:
"""Handle request for download progress by download_id"""
try:
# Get download_id from URL path
download_id = request.match_info.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
progress_data = ws_manager.get_download_progress(download_id)
if progress_data is None:
return web.json_response({
'success': False,
'error': 'Download ID not found'
}, status=404)
return web.json_response({
'success': True,
'progress': progress_data.get('progress', 0)
})
except Exception as e:
logger.error(f"Error getting download progress: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all models in the background"""
try:
cache = await self.service.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
# Prepare models to process
to_process = [
model for model in cache.raw_data
if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) and model.get('from_civitai', True)
]
total_to_process = len(to_process)
# Send initial progress
await ws_manager.broadcast({
'status': 'started',
'total': total_to_process,
'processed': 0,
'success': 0
})
# Process each model
for model in to_process:
try:
original_name = model.get('model_name')
if await ModelRouteUtils.fetch_and_update_model(
sha256=model['sha256'],
file_path=model['file_path'],
model_data=model,
update_cache_func=self.service.scanner.update_single_model_cache
):
success += 1
if original_name != model.get('model_name'):
needs_resort = True
processed += 1
# Send progress update
await ws_manager.broadcast({
'status': 'processing',
'total': total_to_process,
'processed': processed,
'success': success,
'current_name': model.get('model_name', 'Unknown')
})
except Exception as e:
logger.error(f"Error fetching CivitAI data for {model['file_path']}: {e}")
if needs_resort:
await cache.resort()
# Send completion message
await ws_manager.broadcast({
'status': 'completed',
'total': total_to_process,
'processed': processed,
'success': success
})
return web.json_response({
"success": True,
"message": f"Successfully updated {success} of {processed} processed {self.model_type}s (total: {total})"
})
except Exception as e:
# Send error message
await ws_manager.broadcast({
'status': 'error',
'error': str(e)
})
logger.error(f"Error in fetch_all_civitai for {self.model_type}s: {e}")
return web.Response(text=str(e), status=500)
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai model with local availability info"""
# This will be implemented by subclasses as they need CivitAI client access
return web.json_response({
"error": "Not implemented in base class"
}, status=501)
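To make the progress messages above concrete, here is a minimal client sketch. It assumes the broadcasts from fetch_all_civitai are delivered on the /ws/fetch-progress endpoint registered at startup and that the server listens on ComfyUI's default port; both are assumptions for illustration.

import asyncio
import aiohttp

async def watch_fetch_progress(base_url="http://127.0.0.1:8188"):
    """Illustrative only: print progress updates until completion or error."""
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(f"{base_url}/ws/fetch-progress") as ws:
            async for msg in ws:
                if msg.type != aiohttp.WSMsgType.TEXT:
                    break
                data = msg.json()
                print(f"{data.get('status')}: {data.get('processed', 0)}/"
                      f"{data.get('total', 0)} ({data.get('success', 0)} succeeded)")
                if data.get('status') in ('completed', 'error'):
                    break

# asyncio.run(watch_fetch_progress())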


@@ -0,0 +1,105 @@
import logging
from aiohttp import web
from .base_model_routes import BaseModelRoutes
from ..services.checkpoint_service import CheckpointService
from ..services.service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class CheckpointRoutes(BaseModelRoutes):
"""Checkpoint-specific route controller"""
def __init__(self):
"""Initialize Checkpoint routes with Checkpoint service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.template_name = "checkpoints.html"
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
self.service = CheckpointService(checkpoint_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Initialize parent with the service
super().__init__(self.service)
def setup_routes(self, app: web.Application):
"""Setup Checkpoint routes"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
# Setup common routes with 'checkpoints' prefix (includes page route)
super().setup_routes(app, 'checkpoints')
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup Checkpoint-specific routes"""
# Checkpoint-specific CivitAI integration
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
# Checkpoint info by name
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific checkpoint by name"""
try:
name = request.match_info.get('name', '')
checkpoint_info = await self.service.get_model_info_by_name(name)
if checkpoint_info:
return web.json_response(checkpoint_info)
else:
return web.json_response({"error": "Checkpoint not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai checkpoint model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be Checkpoint
if model_type.lower() != 'checkpoint':
return web.json_response({
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching checkpoint model versions: {e}")
return web.Response(status=500, text=str(e))
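For reference, each entry in the list returned here is the CivitAI version object augmented with the local-availability fields set above. A trimmed, hypothetical example (all values illustrative):

example_version = {
    "id": 111111,  # CivitAI version id (hypothetical)
    "name": "v1.0",
    "existsLocally": True,  # set from the local hash index
    "localPath": "/models/checkpoints/example.safetensors",
    "modelSizeKB": 2048000,  # copied from the primary model file
    # ...all other CivitAI version fields pass through unchanged...
}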


@@ -1,771 +0,0 @@
import os
import json
import jinja2
from aiohttp import web
import logging
import asyncio
from ..utils.routes_common import ModelRouteUtils
from ..utils.constants import NSFW_LEVELS
from ..utils.metadata_manager import MetadataManager
from ..services.websocket_manager import ws_manager
from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
from ..utils.utils import fuzzy_match
logger = logging.getLogger(__name__)
class CheckpointsRoutes:
"""API routes for checkpoint management"""
def __init__(self):
self.scanner = None # Will be initialized in setup_routes
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
self.download_manager = None # Will be initialized in setup_routes
self._download_lock = asyncio.Lock()
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
self.download_manager = await ServiceRegistry.get_download_manager()
def setup_routes(self, app):
"""Register routes with the aiohttp app"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
app.router.add_get('/api/checkpoints', self.get_checkpoints)
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route
# Add new routes for model management similar to LoRA routes
app.router.add_post('/api/checkpoints/delete', self.delete_model)
app.router.add_post('/api/checkpoints/exclude', self.exclude_model) # Add new exclude endpoint
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint
# Add new routes for finding duplicates and filename conflicts
app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints)
app.router.add_get('/api/checkpoints/find-filename-conflicts', self.find_filename_conflicts)
# Add new endpoint for bulk deleting checkpoints
app.router.add_post('/api/checkpoints/bulk-delete', self.bulk_delete_checkpoints)
# Add new endpoint for verifying duplicates
app.router.add_post('/api/checkpoints/verify-duplicates', self.verify_duplicates)
async def get_checkpoints(self, request):
"""Get paginated checkpoint data"""
try:
# Parse query parameters
page = int(request.query.get('page', '1'))
page_size = min(int(request.query.get('page_size', '20')), 100)
sort_by = request.query.get('sort_by', 'name')
folder = request.query.get('folder', None)
search = request.query.get('search', None)
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
base_models = request.query.getall('base_model', [])
tags = request.query.getall('tag', [])
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # Add favorites_only parameter
# Process search options
search_options = {
'filename': request.query.get('search_filename', 'true').lower() == 'true',
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
'tags': request.query.get('search_tags', 'false').lower() == 'true',
'recursive': request.query.get('recursive', 'false').lower() == 'true',
}
# Process hash filters if provided
hash_filters = {}
if 'hash' in request.query:
hash_filters['single_hash'] = request.query['hash']
elif 'hashes' in request.query:
try:
hash_list = json.loads(request.query['hashes'])
if isinstance(hash_list, list):
hash_filters['multiple_hashes'] = hash_list
except (json.JSONDecodeError, TypeError):
pass
# Get data from scanner
result = await self.get_paginated_data(
page=page,
page_size=page_size,
sort_by=sort_by,
folder=folder,
search=search,
fuzzy_search=fuzzy_search,
base_models=base_models,
tags=tags,
search_options=search_options,
hash_filters=hash_filters,
favorites_only=favorites_only # Pass favorites_only parameter
)
# Format response items
formatted_result = {
'items': [self._format_checkpoint_response(cp) for cp in result['items']],
'total': result['total'],
'page': result['page'],
'page_size': result['page_size'],
'total_pages': result['total_pages']
}
# Return as JSON
return web.json_response(formatted_result)
except Exception as e:
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_paginated_data(self, page, page_size, sort_by='name',
folder=None, search=None, fuzzy_search=False,
base_models=None, tags=None,
search_options=None, hash_filters=None,
favorites_only=False): # Add favorites_only parameter with default False
"""Get paginated and filtered checkpoint data"""
cache = await self.scanner.get_cached_data()
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False,
'recursive': False,
}
# Get the base data set
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply hash filtering if provided (highest priority)
if hash_filters:
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower() # Ensure lowercase for matching
filtered_data = [
cp for cp in filtered_data
if cp.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
filtered_data = [
cp for cp in filtered_data
if cp.get('sha256', '').lower() in hash_set
]
# Jump to pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
# Apply SFW filtering if enabled in settings
if settings.get('show_only_sfw', False):
filtered_data = [
cp for cp in filtered_data
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply favorites filtering if enabled
if favorites_only:
filtered_data = [
cp for cp in filtered_data
if cp.get('favorite', False) is True
]
# Apply folder filtering
if folder is not None:
if search_options.get('recursive', False):
# Recursive folder filtering - include all subfolders
filtered_data = [
cp for cp in filtered_data
if cp['folder'].startswith(folder)
]
else:
# Exact folder filtering
filtered_data = [
cp for cp in filtered_data
if cp['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
filtered_data = [
cp for cp in filtered_data
if cp.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
filtered_data = [
cp for cp in filtered_data
if any(tag in cp.get('tags', []) for tag in tags)
]
# Apply search filtering
if search:
search_results = []
for cp in filtered_data:
# Search by file name
if search_options.get('filename', True):
if fuzzy_search:
if fuzzy_match(cp.get('file_name', ''), search):
search_results.append(cp)
continue
elif search.lower() in cp.get('file_name', '').lower():
search_results.append(cp)
continue
# Search by model name
if search_options.get('modelname', True):
if fuzzy_search:
if fuzzy_match(cp.get('model_name', ''), search):
search_results.append(cp)
continue
elif search.lower() in cp.get('model_name', '').lower():
search_results.append(cp)
continue
# Search by tags
if search_options.get('tags', False) and 'tags' in cp:
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']):
search_results.append(cp)
continue
filtered_data = search_results
# Calculate pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
def _format_checkpoint_response(self, checkpoint):
"""Format checkpoint data for API response"""
return {
"model_name": checkpoint["model_name"],
"file_name": checkpoint["file_name"],
"preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")),
"preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0),
"base_model": checkpoint.get("base_model", ""),
"folder": checkpoint["folder"],
"sha256": checkpoint.get("sha256", ""),
"file_path": checkpoint["file_path"].replace(os.sep, "/"),
"file_size": checkpoint.get("size", 0),
"modified": checkpoint.get("modified", ""),
"tags": checkpoint.get("tags", []),
"modelDescription": checkpoint.get("modelDescription", ""),
"from_civitai": checkpoint.get("from_civitai", True),
"notes": checkpoint.get("notes", ""),
"model_type": checkpoint.get("model_type", "checkpoint"),
"favorite": checkpoint.get("favorite", False),
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
}
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all checkpoints in the background"""
try:
cache = await self.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
# Prepare checkpoints to process
to_process = [
cp for cp in cache.raw_data
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
]
total_to_process = len(to_process)
# Send initial progress
await ws_manager.broadcast({
'status': 'started',
'total': total_to_process,
'processed': 0,
'success': 0
})
# Process each checkpoint
for cp in to_process:
try:
original_name = cp.get('model_name')
if await ModelRouteUtils.fetch_and_update_model(
sha256=cp['sha256'],
file_path=cp['file_path'],
model_data=cp,
update_cache_func=self.scanner.update_single_model_cache
):
success += 1
if original_name != cp.get('model_name'):
needs_resort = True
processed += 1
# Send progress update
await ws_manager.broadcast({
'status': 'processing',
'total': total_to_process,
'processed': processed,
'success': success,
'current_name': cp.get('model_name', 'Unknown')
})
except Exception as e:
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
if needs_resort:
await cache.resort(name_only=True)
# Send completion message
await ws_manager.broadcast({
'status': 'completed',
'total': total_to_process,
'processed': processed,
'success': success
})
return web.json_response({
"success": True,
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
})
except Exception as e:
# Send error message
await ws_manager.broadcast({
'status': 'error',
'error': str(e)
})
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
return web.Response(text=str(e), status=500)
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
# Parse query parameters
limit = int(request.query.get('limit', '20'))
# Validate limit
if limit < 1 or limit > 100:
limit = 20 # Default to a reasonable limit
# Get top tags
top_tags = await self.scanner.get_top_tags(limit)
return web.json_response({
'success': True,
'tags': top_tags
})
except Exception as e:
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
return web.json_response({
'success': False,
'error': 'Internal server error'
}, status=500)
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in loras"""
try:
# Parse query parameters
limit = int(request.query.get('limit', '20'))
# Validate limit
if limit < 1 or limit > 100:
limit = 20 # Default to a reasonable limit
# Get base models
base_models = await self.scanner.get_base_models(limit)
return web.json_response({
'success': True,
'base_models': base_models
})
except Exception as e:
logger.error(f"Error retrieving base models: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def scan_checkpoints(self, request):
"""Force a rescan of checkpoint files"""
try:
# Get the full_rebuild parameter and convert to bool, default to False
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
await self.scanner.get_cached_data(force_refresh=True, rebuild_cache=full_rebuild)
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
except Exception as e:
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_checkpoint_info(self, request):
"""Get detailed information for a specific checkpoint by name"""
try:
name = request.match_info.get('name', '')
checkpoint_info = await self.scanner.get_model_info_by_name(name)
if checkpoint_info:
return web.json_response(checkpoint_info)
else:
return web.json_response({"error": "Checkpoint not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
"""Handle GET /checkpoints request"""
try:
# Check if the CheckpointScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = (
self.scanner._cache is None or
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
)
if is_initializing:
# If still initializing, return loading page
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[], # Empty folder list
is_initializing=True, # Newly added flag
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
logger.info("Checkpoints page is initializing, returning loading page")
else:
# Normal flow - fetch the already-initialized cache data
try:
cache = await self.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
except Exception as cache_error:
logger.error(f"Error loading checkpoints cache data: {cache_error}")
# If loading the cache fails, also show the initialization page
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Checkpoints cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
return web.Response(
text="Error loading checkpoints page",
status=500
)
async def delete_model(self, request: web.Request) -> web.Response:
"""Handle checkpoint model deletion request"""
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
async def exclude_model(self, request: web.Request) -> web.Response:
"""Handle checkpoint model exclusion request"""
return await ModelRouteUtils.handle_exclude_model(request, self.scanner)
async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata fetch request for checkpoints"""
response = await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
# If successful, format the metadata before returning
if response.status == 200:
data = json.loads(response.body.decode('utf-8'))
if data.get("success") and data.get("metadata"):
formatted_metadata = self._format_checkpoint_response(data["metadata"])
return web.json_response({
"success": True,
"metadata": formatted_metadata
})
# Otherwise, return the original response
return response
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement for checkpoints"""
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
roots = self.scanner.get_model_roots()
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def save_metadata(self, request: web.Request) -> web.Response:
"""Handle saving metadata updates for checkpoints"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
data = await request.json()
file_path = data.get('file_path')
if not file_path:
return web.Response(text='File path is required', status=400)
# Remove file path from data to avoid saving it
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
# Get metadata file path
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load existing metadata
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Update metadata
metadata.update(metadata_updates)
# Save updated metadata
await MetadataManager.save_metadata(file_path, metadata)
# Update cache
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
# If model_name was updated, resort the cache
if 'model_name' in metadata_updates:
cache = await self.scanner.get_cached_data()
await cache.resort(name_only=True)
return web.json_response({'success': True})
except Exception as e:
logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai checkpoint model with local availability info"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
# Get the civitai client from service registry
civitai_client = await ServiceRegistry.get_civitai_client()
model_id = request.match_info['model_id']
response = await civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be Checkpoint
if (model_type.lower() != 'checkpoint'):
return web.json_response({
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.scanner.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.scanner.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching checkpoint model versions: {e}")
return web.Response(status=500, text=str(e))
async def find_duplicate_checkpoints(self, request: web.Request) -> web.Response:
"""Find checkpoints with duplicate SHA256 hashes"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
# Get duplicate hashes from hash index
duplicates = self.scanner._hash_index.get_duplicate_hashes()
# Format the response
result = []
cache = await self.scanner.get_cached_data()
for sha256, paths in duplicates.items():
group = {
"hash": sha256,
"models": []
}
# Find matching models for each path
for path in paths:
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
if model:
group["models"].append(self._format_checkpoint_response(model))
# Add the primary model too
primary_path = self.scanner._hash_index.get_path(sha256)
if primary_path and primary_path not in paths:
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
if primary_model:
group["models"].insert(0, self._format_checkpoint_response(primary_model))
if len(group["models"]) > 1: # Only include if we found multiple models
result.append(group)
return web.json_response({
"success": True,
"duplicates": result,
"count": len(result)
})
except Exception as e:
logger.error(f"Error finding duplicate checkpoints: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
"""Find checkpoints with conflicting filenames"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
# Get duplicate filenames from hash index
duplicates = self.scanner._hash_index.get_duplicate_filenames()
# Format the response
result = []
cache = await self.scanner.get_cached_data()
for filename, paths in duplicates.items():
group = {
"filename": filename,
"models": []
}
# Find matching models for each path
for path in paths:
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
if model:
group["models"].append(self._format_checkpoint_response(model))
# Find the model from the main index too
hash_val = self.scanner._hash_index.get_hash_by_filename(filename)
if hash_val:
main_path = self.scanner._hash_index.get_path(hash_val)
if main_path and main_path not in paths:
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
if main_model:
group["models"].insert(0, self._format_checkpoint_response(main_model))
if group["models"]:
result.append(group)
return web.json_response({
"success": True,
"conflicts": result,
"count": len(result)
})
except Exception as e:
logger.error(f"Error finding filename conflicts: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def bulk_delete_checkpoints(self, request: web.Request) -> web.Response:
"""Handle bulk deletion of checkpoint models"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
return await ModelRouteUtils.handle_bulk_delete_models(request, self.scanner)
except Exception as e:
logger.error(f"Error in bulk delete checkpoints: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def relink_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata re-linking request by model version ID for checkpoints"""
return await ModelRouteUtils.handle_relink_civitai(request, self.scanner)
async def verify_duplicates(self, request: web.Request) -> web.Response:
"""Handle verification of duplicate checkpoint hashes"""
return await ModelRouteUtils.handle_verify_duplicates(request, self.scanner)
async def rename_checkpoint(self, request: web.Request) -> web.Response:
"""Handle renaming a checkpoint file and its associated files"""
return await ModelRouteUtils.handle_rename_model(request, self.scanner)


@@ -1,188 +1,512 @@
import os
from aiohttp import web
import jinja2
from typing import Dict
import asyncio
import logging
from ..config import config
from ..services.settings_manager import settings
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
from aiohttp import web
from typing import Dict
from server import PromptServer # type: ignore
from .base_model_routes import BaseModelRoutes
from ..services.lora_service import LoraService
from ..services.service_registry import ServiceRegistry
from ..utils.routes_common import ModelRouteUtils
from ..utils.utils import get_lora_info
logger = logging.getLogger(__name__)
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
class LoraRoutes:
"""Route handlers for LoRA management endpoints"""
class LoraRoutes(BaseModelRoutes):
"""LoRA-specific route controller"""
def __init__(self):
# Initialize service references as None, will be set during async init
self.scanner = None
self.recipe_scanner = None
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
async def init_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_lora_scanner()
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
"""Initialize LoRA routes with LoRA service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.template_name = "loras.html"
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
lora_scanner = await ServiceRegistry.get_lora_scanner()
self.service = LoraService(lora_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Initialize parent with the service
super().__init__(self.service)
def format_lora_data(self, lora: Dict) -> Dict:
"""Format LoRA data for template rendering"""
return {
"model_name": lora["model_name"],
"file_name": lora["file_name"],
"preview_url": config.get_preview_static_url(lora["preview_url"]),
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
"base_model": lora["base_model"],
"folder": lora["folder"],
"sha256": lora["sha256"],
"file_path": lora["file_path"].replace(os.sep, "/"),
"size": lora["size"],
"tags": lora["tags"],
"modelDescription": lora["modelDescription"],
"usage_tips": lora["usage_tips"],
"notes": lora["notes"],
"modified": lora["modified"],
"from_civitai": lora.get("from_civitai", True),
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
}
def _filter_civitai_data(self, data: Dict) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
async def handle_loras_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras request"""
try:
# Ensure services are initialized
await self.init_services()
# Check if the LoraScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = (
self.scanner._cache is None or self.scanner.is_initializing()
)
if is_initializing:
# If still initializing, return loading page
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Loras page is initializing, returning loading page")
else:
# Normal flow - get data from initialized cache
try:
cache = await self.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading cache data: {cache_error}")
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling loras request: {e}", exc_info=True)
return web.Response(
text="Error loading loras page",
status=500
)
async def handle_recipes_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras/recipes request"""
try:
# Ensure services are initialized
await self.init_services()
# Skip initialization check and directly try to get cached data
try:
# Recipe scanner will initialize cache if needed
await self.recipe_scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('recipes.html')
rendered = template.render(
recipes=[], # Frontend will load recipes via API
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading recipe cache data: {cache_error}")
# Still keep error handling - show initializing page on error
template = self.template_env.get_template('recipes.html')
rendered = template.render(
is_initializing=True,
settings=settings,
request=request
)
logger.info("Recipe cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling recipes request: {e}", exc_info=True)
return web.Response(
text="Error loading recipes page",
status=500
)
def _format_recipe_file_url(self, file_path: str) -> str:
"""Format file path for recipe image as a URL - same as in recipe_routes"""
try:
# Return the file URL directly for the first lora root's preview
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
if file_path.replace(os.sep, '/').startswith(recipes_dir):
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
return f"/loras_static/root1/preview/{relative_path}"
# If not in recipes dir, try to create a valid URL from the file path
file_name = os.path.basename(file_path)
return f"/loras_static/root1/preview/recipes/{file_name}"
except Exception as e:
logger.error(f"Error formatting recipe file URL: {e}", exc_info=True)
return '/loras_static/images/no-preview.png' # Return default image on error
def setup_routes(self, app: web.Application):
"""Register routes with the application"""
# Add an app startup handler to initialize services
app.on_startup.append(self._on_startup)
"""Setup LoRA routes"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
# Register routes
app.router.add_get('/loras', self.handle_loras_page)
app.router.add_get('/loras/recipes', self.handle_recipes_page)
# Setup common routes with 'loras' prefix (includes page route)
super().setup_routes(app, 'loras')
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup LoRA-specific routes"""
# LoRA-specific query routes
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
app.router.add_get(f'/api/lora-preview-url', self.get_lora_preview_url)
app.router.add_get(f'/api/lora-civitai-url', self.get_lora_civitai_url)
app.router.add_get(f'/api/lora-model-description', self.get_lora_model_description)
async def _on_startup(self, app):
"""Initialize services when the app starts"""
await self.init_services()
# LoRA-specific management routes
app.router.add_post(f'/api/move_model', self.move_model)
app.router.add_post(f'/api/move_models_bulk', self.move_models_bulk)
# CivitAI integration with LoRA-specific validation
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora)
app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
# ComfyUI integration
app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words)
def _parse_specific_params(self, request: web.Request) -> Dict:
"""Parse LoRA-specific parameters"""
params = {}
# LoRA-specific parameters
if 'first_letter' in request.query:
params['first_letter'] = request.query.get('first_letter')
# Handle fuzzy search parameter name variation
if request.query.get('fuzzy') == 'true':
params['fuzzy_search'] = True
# Handle additional filter parameters for LoRAs
if 'lora_hash' in request.query:
if not params.get('hash_filters'):
params['hash_filters'] = {}
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
elif 'lora_hashes' in request.query:
if not params.get('hash_filters'):
params['hash_filters'] = {}
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
return params
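# Illustrative mapping (hypothetical hash values): a request such as
#   GET /api/loras?first_letter=a&fuzzy=true&lora_hashes=AAA111,BBB222
# is parsed by _parse_specific_params into:
#   {'first_letter': 'a', 'fuzzy_search': True,
#    'hash_filters': {'multiple_hashes': ['aaa111', 'bbb222']}}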
# LoRA-specific route handlers
async def get_letter_counts(self, request: web.Request) -> web.Response:
"""Get count of LoRAs for each letter of the alphabet"""
try:
letter_counts = await self.service.get_letter_counts()
return web.json_response({
'success': True,
'letter_counts': letter_counts
})
except Exception as e:
logger.error(f"Error getting letter counts: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_notes(self, request: web.Request) -> web.Response:
"""Get notes for a specific LoRA file"""
try:
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
notes = await self.service.get_lora_notes(lora_name)
if notes is not None:
return web.json_response({
'success': True,
'notes': notes
})
else:
return web.json_response({
'success': False,
'error': 'LoRA not found in cache'
}, status=404)
except Exception as e:
logger.error(f"Error getting lora notes: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for a specific LoRA file"""
try:
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
trigger_words = await self.service.get_lora_trigger_words(lora_name)
return web.json_response({
'success': True,
'trigger_words': trigger_words
})
except Exception as e:
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
"""Get the static preview URL for a LoRA file"""
try:
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
preview_url = await self.service.get_lora_preview_url(lora_name)
if preview_url:
return web.json_response({
'success': True,
'preview_url': preview_url
})
else:
return web.json_response({
'success': False,
'error': 'No preview URL found for the specified lora'
}, status=404)
except Exception as e:
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
"""Get the Civitai URL for a LoRA file"""
try:
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
result = await self.service.get_lora_civitai_url(lora_name)
if result['civitai_url']:
return web.json_response({
'success': True,
**result
})
else:
return web.json_response({
'success': False,
'error': 'No Civitai data found for the specified lora'
}, status=404)
except Exception as e:
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
# Override get_models to add LoRA-specific response data
async def get_models(self, request: web.Request) -> web.Response:
"""Get paginated LoRA data with LoRA-specific fields"""
try:
# Parse common query parameters
params = self._parse_common_params(request)
# Get data from service
result = await self.service.get_paginated_data(**params)
# Get all available folders from cache for LoRA-specific response
cache = await self.service.scanner.get_cached_data()
# Format response items with LoRA-specific structure
formatted_result = {
'items': [await self.service.format_response(item) for item in result['items']],
'folders': cache.folders, # LoRA-specific: include folders in response
'total': result['total'],
'page': result['page'],
'page_size': result['page_size'],
'total_pages': result['total_pages']
}
return web.json_response(formatted_result)
except Exception as e:
logger.error(f"Error in get_loras: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
# CivitAI integration methods
async def get_civitai_versions_lora(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai LoRA model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be LORA, LoCon, or DORA
from ..utils.constants import VALID_LORA_TYPES
if model_type.lower() not in VALID_LORA_TYPES:
return web.json_response({
'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the model file (type="Model") in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching LoRA model versions: {e}")
return web.Response(status=500, text=str(e))
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID"""
try:
model_version_id = request.match_info.get('modelVersionId')
# Get model details from Civitai API
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
if not model:
# Log warning for failed model retrieval
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
# Determine status code based on error message
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
return web.json_response({
"success": False,
"error": error_msg or "Failed to fetch model information"
}, status=status_code)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by hash"""
try:
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details by hash: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
# Model management methods
async def move_model(self, request: web.Request) -> web.Response:
"""Handle model move request"""
try:
data = await request.json()
file_path = data.get('file_path') # full path of the model file
target_path = data.get('target_path') # folder path to move the model to
if not file_path or not target_path:
return web.Response(text='File path and target path are required', status=400)
# Check if source and destination are the same
import os
source_dir = os.path.dirname(file_path)
if os.path.normpath(source_dir) == os.path.normpath(target_path):
logger.info(f"Source and target directories are the same: {source_dir}")
return web.json_response({'success': True, 'message': 'Source and target directories are the same'})
# Check if target file already exists
file_name = os.path.basename(file_path)
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
if os.path.exists(target_file_path):
return web.json_response({
'success': False,
'error': f"Target file already exists: {target_file_path}"
}, status=409) # 409 Conflict
# Call scanner to handle the move operation
success = await self.service.scanner.move_model(file_path, target_path)
if success:
return web.json_response({'success': True})
else:
return web.Response(text='Failed to move model', status=500)
except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
async def move_models_bulk(self, request: web.Request) -> web.Response:
"""Handle bulk model move request"""
try:
data = await request.json()
file_paths = data.get('file_paths', []) # list of full paths of the model files
target_path = data.get('target_path') # folder path to move the models to
if not file_paths or not target_path:
return web.Response(text='File paths and target path are required', status=400)
results = []
import os
for file_path in file_paths:
# Check if source and destination are the same
source_dir = os.path.dirname(file_path)
if os.path.normpath(source_dir) == os.path.normpath(target_path):
results.append({
"path": file_path,
"success": True,
"message": "Source and target directories are the same"
})
continue
# Check if target file already exists
file_name = os.path.basename(file_path)
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
if os.path.exists(target_file_path):
results.append({
"path": file_path,
"success": False,
"message": f"Target file already exists: {target_file_path}"
})
continue
# Try to move the model
success = await self.service.scanner.move_model(file_path, target_path)
results.append({
"path": file_path,
"success": success,
"message": "Success" if success else "Failed to move model"
})
# Count successes and failures
success_count = sum(1 for r in results if r["success"])
failure_count = len(results) - success_count
return web.json_response({
'success': True,
'message': f'Moved {success_count} of {len(file_paths)} models',
'results': results,
'success_count': success_count,
'failure_count': failure_count
})
except Exception as e:
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
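# Example request body for POST /api/move_models_bulk (paths are hypothetical):
#   {"file_paths": ["loras/a.safetensors", "loras/b.safetensors"],
#    "target_path": "loras/archive"}
# The JSON response reports per-file results plus success_count and failure_count.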
async def get_lora_model_description(self, request: web.Request) -> web.Response:
"""Get model description for a Lora model"""
try:
# Get parameters
model_id = request.query.get('model_id')
file_path = request.query.get('file_path')
if not model_id:
return web.json_response({
'success': False,
'error': 'Model ID is required'
}, status=400)
# Check if we already have the description stored in metadata
description = None
tags = []
creator = {}
if file_path:
import os
from ..utils.metadata_manager import MetadataManager
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
description = metadata.get('modelDescription')
tags = metadata.get('tags', [])
creator = metadata.get('creator', {})
# If description is not in metadata, fetch from CivitAI
if not description:
logger.info(f"Fetching model metadata for model ID: {model_id}")
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
if model_metadata:
description = model_metadata.get('description')
tags = model_metadata.get('tags', [])
creator = model_metadata.get('creator', {})
# Save the metadata to file if we have a file path and got metadata
if file_path:
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
metadata['modelDescription'] = description
metadata['tags'] = tags
# Ensure the civitai dict exists
if 'civitai' not in metadata:
metadata['civitai'] = {}
# Store creator in the civitai nested structure
metadata['civitai']['creator'] = creator
await MetadataManager.save_metadata(file_path, metadata, True)
except Exception as e:
logger.error(f"Error saving model metadata: {e}")
return web.json_response({
'success': True,
'description': description or "<p>No model description available.</p>",
'tags': tags,
'creator': creator
})
except Exception as e:
logger.error(f"Error getting model metadata: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for specified LoRA models"""
try:
json_data = await request.json()
lora_names = json_data.get("lora_names", [])
node_ids = json_data.get("node_ids", [])
all_trigger_words = []
for lora_name in lora_names:
_, trigger_words = get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Format the trigger words
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Send update to all connected trigger word toggle nodes
for node_id in node_ids:
PromptServer.instance.send_sync("trigger_word_update", {
"id": node_id,
"message": trigger_words_text
})
return web.json_response({"success": True})
except Exception as e:
logger.error(f"Error getting trigger words: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
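# Minimal client sketch for the route above (assumes ComfyUI listening on
# 127.0.0.1:8188; the lora name and node id are hypothetical):
#
#   import asyncio
#   import aiohttp
#
#   async def main():
#       payload = {"lora_names": ["example_lora"], "node_ids": [3]}
#       async with aiohttp.ClientSession() as session:
#           async with session.post(
#               "http://127.0.0.1:8188/loramanager/get_trigger_words",
#               json=payload,
#           ) as resp:
#               print(await resp.json())  # expected: {"success": true}
#
#   asyncio.run(main())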

View File

@@ -633,9 +633,8 @@ class MiscRoutes:
}, status=400)
# Get both lora and checkpoint scanners
registry = ServiceRegistry.get_instance()
lora_scanner = await registry.get_lora_scanner()
checkpoint_scanner = await registry.get_checkpoint_scanner()
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
# If modelVersionId is provided, check for specific version
if model_version_id_str:

View File

@@ -1,6 +1,7 @@
import os
import time
import base64
import jinja2
import numpy as np
from PIL import Image
import io
@@ -15,6 +16,7 @@ from ..utils.exif_utils import ExifUtils
from ..recipes import RecipeParserFactory
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..services.settings_manager import settings
from ..config import config
# Check if running in standalone mode
@@ -39,7 +41,10 @@ class RecipeRoutes:
# Initialize service references as None, will be set during async init
self.recipe_scanner = None
self.civitai_client = None
# Remove WorkflowParser instance
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
# Pre-warm the cache
self._init_cache_task = None
@@ -53,6 +58,8 @@ class RecipeRoutes:
def setup_routes(cls, app: web.Application):
"""Register API routes"""
routes = cls()
app.router.add_get('/loras/recipes', routes.handle_recipes_page)
app.router.add_get('/api/recipes', routes.get_recipes)
app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail)
app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image)
@@ -114,6 +121,46 @@ class RecipeRoutes:
await self.recipe_scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True)
async def handle_recipes_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras/recipes request"""
try:
# Ensure services are initialized
await self.init_services()
# Skip initialization check and directly try to get cached data
try:
# Recipe scanner will initialize cache if needed
await self.recipe_scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('recipes.html')
rendered = template.render(
recipes=[], # Frontend will load recipes via API
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading recipe cache data: {cache_error}")
# Still keep error handling - show initializing page on error
template = self.template_env.get_template('recipes.html')
rendered = template.render(
is_initializing=True,
settings=settings,
request=request
)
logger.info("Recipe cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling recipes request: {e}", exc_info=True)
return web.Response(
text="Error loading recipes page",
status=500
)
async def get_recipes(self, request: web.Request) -> web.Response:
"""API endpoint for getting paginated recipes"""
@@ -1101,7 +1148,7 @@ class RecipeRoutes:
for lora_name, lora_strength in lora_matches:
try:
# Get lora info from scanner
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name)
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name)
# Create lora entry
lora_entry = {
@@ -1120,7 +1167,7 @@ class RecipeRoutes:
# Get base model from lora scanner for the available loras
base_model_counts = {}
for lora in loras_data:
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", ""))
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", ""))
if lora_info and "base_model" in lora_info:
base_model = lora_info["base_model"]
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
@@ -1210,7 +1257,7 @@ class RecipeRoutes:
if lora.get("isDeleted", False):
continue
if not self.recipe_scanner._lora_scanner.has_lora_hash(lora.get("hash", "")):
if not self.recipe_scanner._lora_scanner.has_hash(lora.get("hash", "")):
continue
# Get the strength
@@ -1318,7 +1365,7 @@ class RecipeRoutes:
return web.json_response({"error": "Recipe not found"}, status=404)
# Find target LoRA by name
target_lora = await lora_scanner.get_lora_info_by_name(target_name)
target_lora = await lora_scanner.get_model_info_by_name(target_name)
if not target_lora:
return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)
@@ -1430,9 +1477,9 @@ class RecipeRoutes:
if 'loras' in recipe:
for lora in recipe['loras']:
if 'hash' in lora and lora['hash']:
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_lora_hash(lora['hash'].lower())
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_hash(lora['hash'].lower())
lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
lora['localPath'] = self.recipe_scanner._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
lora['localPath'] = self.recipe_scanner._lora_scanner.get_path_by_hash(lora['hash'].lower())
# Ensure file_url is set (needed by frontend)
if 'file_path' in recipe:

View File

@@ -1,11 +1,13 @@
import os
import subprocess
import aiohttp
import logging
import toml
import subprocess
import git
from datetime import datetime
from aiohttp import web
from typing import Dict, Any, List
from typing import Dict, List
logger = logging.getLogger(__name__)
@@ -17,6 +19,7 @@ class UpdateRoutes:
"""Register update check routes"""
app.router.add_get('/api/check-updates', UpdateRoutes.check_updates)
app.router.add_get('/api/version-info', UpdateRoutes.get_version_info)
app.router.add_post('/api/perform-update', UpdateRoutes.perform_update)
@staticmethod
async def check_updates(request):
@@ -25,6 +28,8 @@ class UpdateRoutes:
Returns update status and version information
"""
try:
nightly = request.query.get('nightly', 'false').lower() == 'true'
# Read local version from pyproject.toml
local_version = UpdateRoutes._get_local_version()
@@ -32,13 +37,21 @@ class UpdateRoutes:
git_info = UpdateRoutes._get_git_info()
# Fetch remote version from GitHub
remote_version, changelog = await UpdateRoutes._get_remote_version()
if nightly:
remote_version, changelog = await UpdateRoutes._get_nightly_version()
else:
remote_version, changelog = await UpdateRoutes._get_remote_version()
# Compare versions
update_available = UpdateRoutes._compare_versions(
local_version.replace('v', ''),
remote_version.replace('v', '')
)
if nightly:
# For nightly, compare commit hashes
update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version)
else:
# For stable, compare semantic versions
update_available = UpdateRoutes._compare_versions(
local_version.replace('v', ''),
remote_version.replace('v', '')
)
return web.json_response({
'success': True,
@@ -46,7 +59,8 @@ class UpdateRoutes:
'latest_version': remote_version,
'update_available': update_available,
'changelog': changelog,
'git_info': git_info
'git_info': git_info,
'nightly': nightly
})
except Exception as e:
@@ -55,7 +69,7 @@ class UpdateRoutes:
'success': False,
'error': str(e)
})
@staticmethod
async def get_version_info(request):
"""
@@ -84,6 +98,168 @@ class UpdateRoutes:
'error': str(e)
})
@staticmethod
async def perform_update(request):
"""
Perform Git-based update to latest release tag or main branch
"""
try:
# Parse request body
body = await request.json() if request.has_body else {}
nightly = body.get('nightly', False)
# Get current plugin directory
current_dir = os.path.dirname(os.path.abspath(__file__))
plugin_root = os.path.dirname(os.path.dirname(current_dir))
# Backup settings.json if it exists
settings_path = os.path.join(plugin_root, 'settings.json')
settings_backup = None
if os.path.exists(settings_path):
with open(settings_path, 'r', encoding='utf-8') as f:
settings_backup = f.read()
logger.info("Backed up settings.json")
# Perform Git update
success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
# Restore settings.json if we backed it up
if settings_backup and success:
with open(settings_path, 'w', encoding='utf-8') as f:
f.write(settings_backup)
logger.info("Restored settings.json")
if success:
return web.json_response({
'success': True,
'message': f'Successfully updated to {new_version}',
'new_version': new_version
})
else:
return web.json_response({
'success': False,
'error': 'Failed to complete Git update'
})
except Exception as e:
logger.error(f"Failed to perform update: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
})
@staticmethod
async def _get_nightly_version() -> tuple[str, List[str]]:
"""
Fetch latest commit from main branch
"""
repo_owner = "willmiao"
repo_name = "ComfyUI-Lora-Manager"
# Use GitHub API to fetch the latest commit from main branch
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
try:
async with aiohttp.ClientSession() as session:
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
if response.status != 200:
logger.warning(f"Failed to fetch GitHub commit: {response.status}")
return "main", []
data = await response.json()
commit_sha = data.get('sha', '')[:7] # Short hash
commit_message = data.get('commit', {}).get('message', '')
# Format as "main-{short_hash}"
version = f"main-{commit_sha}"
# Use commit message as changelog
changelog = [commit_message] if commit_message else []
return version, changelog
except Exception as e:
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
return "main", []
@staticmethod
def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool:
"""
Compare local commit hash with remote main branch
"""
try:
local_hash = local_git_info.get('short_hash', 'unknown')
if local_hash == 'unknown':
return True # Assume update available if we can't get local hash
# Extract remote hash from version string (format: "main-{hash}")
if '-' in remote_version:
remote_hash = remote_version.split('-')[-1]
return local_hash != remote_hash
return True # Default to update available
except Exception as e:
logger.error(f"Error comparing nightly versions: {e}")
return False
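# Worked example (hypothetical hashes): with git_info {'short_hash': 'a1b2c3d'}
# and remote_version 'main-e4f5a6b' the hashes differ, so True (update
# available); with remote_version 'main-a1b2c3d' the result is False.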
@staticmethod
async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]:
"""
Perform Git-based update using GitPython
Args:
plugin_root: Path to the plugin root directory
nightly: Whether to update to main branch or latest release
Returns:
tuple: (success, new_version)
"""
try:
# Open the Git repository
repo = git.Repo(plugin_root)
# Fetch latest changes
origin = repo.remotes.origin
origin.fetch()
if nightly:
# Switch to main branch and pull latest
main_branch = 'main'
if main_branch not in [branch.name for branch in repo.branches]:
# Create local main branch if it doesn't exist
repo.create_head(main_branch, origin.refs.main)
repo.heads[main_branch].checkout()
origin.pull(main_branch)
# Get new commit hash
new_version = f"main-{repo.head.commit.hexsha[:7]}"
else:
# Get latest release tag
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
if not tags:
logger.error("No tags found in repository")
return False, ""
latest_tag = tags[0]
# Checkout to latest tag
repo.git.checkout(latest_tag.name)
new_version = latest_tag.name
logger.info(f"Successfully updated to {new_version}")
return True, new_version
except git.exc.GitError as e:
logger.error(f"Git error during update: {e}")
return False, ""
except Exception as e:
logger.error(f"Error during Git update: {e}")
return False, ""
@staticmethod
def _get_local_version() -> str:
"""Get local plugin version from pyproject.toml"""

View File

@@ -0,0 +1,259 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Type
import logging
from ..utils.models import BaseModelMetadata
from ..utils.constants import NSFW_LEVELS
from .settings_manager import settings
from ..utils.utils import fuzzy_match
logger = logging.getLogger(__name__)
class BaseModelService(ABC):
"""Base service class for all model types"""
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
"""Initialize the service
Args:
model_type: Type of model (lora, checkpoint, etc.)
scanner: Model scanner instance
metadata_class: Metadata class for this model type
"""
self.model_type = model_type
self.scanner = scanner
self.metadata_class = metadata_class
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy_search: bool = False,
base_models: list = None, tags: list = None,
search_options: dict = None, hash_filters: dict = None,
favorites_only: bool = False, **kwargs) -> Dict:
"""Get paginated and filtered model data
Args:
page: Page number (1-based)
page_size: Number of items per page
sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc'
folder: Folder filter
search: Search term
fuzzy_search: Whether to use fuzzy search
base_models: List of base models to filter by
tags: List of tags to filter by
search_options: Search options dict
hash_filters: Hash filtering options
favorites_only: Filter for favorites only
**kwargs: Additional model-specific filters
Returns:
Dict containing paginated results
"""
cache = await self.scanner.get_cached_data()
# Parse sort_by into sort_key and order
if ':' in sort_by:
sort_key, order = sort_by.split(':', 1)
sort_key = sort_key.strip()
order = order.strip().lower()
if order not in ('asc', 'desc'):
order = 'asc'
else:
sort_key = sort_by.strip()
order = 'asc'
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False,
'recursive': False,
}
# Get the base data set using new sort logic
filtered_data = await cache.get_sorted_data(sort_key, order)
# Apply hash filtering if provided (highest priority)
if hash_filters:
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
# Jump to pagination for hash filters
return self._paginate(filtered_data, page, page_size)
# Apply common filters
filtered_data = await self._apply_common_filters(
filtered_data, folder, base_models, tags, favorites_only, search_options
)
# Apply search filtering
if search:
filtered_data = await self._apply_search_filters(
filtered_data, search, fuzzy_search, search_options
)
# Apply model-specific filters
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
return self._paginate(filtered_data, page, page_size)
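# Example call (a sketch; the filter values are hypothetical):
#   result = await service.get_paginated_data(
#       page=1, page_size=25, sort_by='date:desc',
#       base_models=['SDXL 1.0'], favorites_only=True)
# Note that hash_filters short-circuits the pipeline: when present, folder,
# search, tag, and model-specific filters are skipped entirely.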
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
"""Apply hash-based filtering"""
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower()
return [
item for item in data
if item.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(h.lower() for h in multiple_hashes)
return [
item for item in data
if item.get('sha256', '').lower() in hash_set
]
return data
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
base_models: list = None, tags: list = None,
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
"""Apply common filters that work across all model types"""
# Apply SFW filtering if enabled in settings
if settings.get('show_only_sfw', False):
data = [
item for item in data
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply favorites filtering if enabled
if favorites_only:
data = [
item for item in data
if item.get('favorite', False) is True
]
# Apply folder filtering
if folder is not None:
if search_options and search_options.get('recursive', False):
# Recursive folder filtering - include all subfolders
data = [
item for item in data
if item['folder'].startswith(folder)
]
else:
# Exact folder filtering
data = [
item for item in data
if item['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
data = [
item for item in data
if item.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
data = [
item for item in data
if any(tag in item.get('tags', []) for tag in tags)
]
return data
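# Example: with folder='characters' and search_options={'recursive': True},
# an item in folder 'characters/anime' passes the prefix match; without
# recursive=True only items whose folder is exactly 'characters' remain.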
async def _apply_search_filters(self, data: List[Dict], search: str,
fuzzy_search: bool, search_options: dict) -> List[Dict]:
"""Apply search filtering"""
search_results = []
for item in data:
# Search by file name
if search_options.get('filename', True):
if fuzzy_search:
if fuzzy_match(item.get('file_name', ''), search):
search_results.append(item)
continue
elif search.lower() in item.get('file_name', '').lower():
search_results.append(item)
continue
# Search by model name
if search_options.get('modelname', True):
if fuzzy_search:
if fuzzy_match(item.get('model_name', ''), search):
search_results.append(item)
continue
elif search.lower() in item.get('model_name', '').lower():
search_results.append(item)
continue
# Search by tags
if search_options.get('tags', False) and 'tags' in item:
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
for tag in item['tags']):
search_results.append(item)
continue
return search_results
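# Example: search='flux' with the default options matches an item whose
# file_name is 'Flux_Detailer_v2' via the case-insensitive substring test;
# with fuzzy_search=True the comparison is delegated to fuzzy_match instead.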
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply model-specific filters - to be overridden by subclasses if needed"""
return data
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
"""Apply pagination to filtered data"""
total_items = len(data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
return {
'items': data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
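# Worked example: 23 items with page_size=10 gives
# total_pages = (23 + 10 - 1) // 10 = 3; page=2 slices items[10:20] and
# page=3 slices items[20:23] (a short final page).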
@abstractmethod
async def format_response(self, model_data: Dict) -> Dict:
"""Format model data for API response - must be implemented by subclasses"""
pass
# Common service methods that delegate to scanner
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
"""Get top tags sorted by frequency"""
return await self.scanner.get_top_tags(limit)
async def get_base_models(self, limit: int = 20) -> List[Dict]:
"""Get base models sorted by frequency"""
return await self.scanner.get_base_models(limit)
def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists"""
return self.scanner.has_hash(sha256)
def get_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a model by its hash"""
return self.scanner.get_path_by_hash(sha256)
def get_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a model by its file path"""
return self.scanner.get_hash_by_path(file_path)
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
"""Trigger model scanning"""
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
async def get_model_info_by_name(self, name: str):
"""Get model information by name"""
return await self.scanner.get_model_info_by_name(name)
def get_model_roots(self) -> List[str]:
"""Get model root directories"""
return self.scanner.get_model_roots()

View File

@@ -1,112 +1,26 @@
import os
import logging
import asyncio
from typing import List, Dict, Optional, Set
import folder_paths # type: ignore
from typing import List
from ..utils.models import CheckpointMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class CheckpointScanner(ModelScanner):
"""Service for scanning and managing checkpoint files"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not hasattr(self, '_initialized'):
# Define supported file extensions
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
super().__init__(
model_type="checkpoint",
model_class=CheckpointMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
self._initialized = True
# Define supported file extensions
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
super().__init__(
model_type="checkpoint",
model_class=CheckpointMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def get_model_roots(self) -> List[str]:
"""Get checkpoint root directories"""
return config.base_models_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all checkpoint directories and return metadata"""
all_checkpoints = []
# Create scan tasks for each directory
scan_tasks = []
for root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
checkpoints = await task
all_checkpoints.extend(checkpoints)
except Exception as e:
logger.error(f"Error scanning checkpoint directory: {e}")
return all_checkpoints
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a directory for checkpoint files"""
checkpoints = []
original_root = root_path
async def scan_recursive(path: str, visited_paths: set):
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
# Check if file has supported extension
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, checkpoints)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return checkpoints
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
"""Process a single checkpoint file and add to results"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
checkpoints.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
return config.base_models_roots

View File

@@ -0,0 +1,51 @@
import os
import logging
from typing import Dict, List, Optional
from .base_model_service import BaseModelService
from ..utils.models import CheckpointMetadata
from ..config import config
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__)
class CheckpointService(BaseModelService):
"""Checkpoint-specific service implementation"""
def __init__(self, scanner):
"""Initialize Checkpoint service
Args:
scanner: Checkpoint scanner instance
"""
super().__init__("checkpoint", scanner, CheckpointMetadata)
async def format_response(self, checkpoint_data: Dict) -> Dict:
"""Format Checkpoint data for API response"""
return {
"model_name": checkpoint_data["model_name"],
"file_name": checkpoint_data["file_name"],
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
"base_model": checkpoint_data.get("base_model", ""),
"folder": checkpoint_data["folder"],
"sha256": checkpoint_data.get("sha256", ""),
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
"file_size": checkpoint_data.get("size", 0),
"modified": checkpoint_data.get("modified", ""),
"tags": checkpoint_data.get("tags", []),
"modelDescription": checkpoint_data.get("modelDescription", ""),
"from_civitai": checkpoint_data.get("from_civitai", True),
"notes": checkpoint_data.get("notes", ""),
"model_type": checkpoint_data.get("model_type", "checkpoint"),
"favorite": checkpoint_data.get("favorite", False),
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}))
}
def find_duplicate_hashes(self) -> Dict:
"""Find Checkpoints with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes()
def find_duplicate_filenames(self) -> Dict:
"""Find Checkpoints with conflicting filenames"""
return self.scanner._hash_index.get_duplicate_filenames()
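# Minimal usage sketch (assumes the scanner is obtained from ServiceRegistry
# inside an async context):
#
#   scanner = await ServiceRegistry.get_checkpoint_scanner()
#   service = CheckpointService(scanner)
#   page = await service.get_paginated_data(page=1, page_size=20, sort_by='name:asc')
#   items = [await service.format_response(item) for item in page['items']]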

View File

@@ -1,15 +1,10 @@
import os
import logging
import asyncio
from typing import List, Dict, Optional
from typing import List
from ..utils.models import LoraMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match
import sys
logger = logging.getLogger(__name__)
@@ -17,404 +12,21 @@ logger = logging.getLogger(__name__)
class LoraScanner(ModelScanner):
"""Service for scanning and managing LoRA files"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
# Ensure initialization happens only once
if not hasattr(self, '_initialized'):
# Define supported file extensions
file_extensions = {'.safetensors'}
# Initialize parent class with ModelHashIndex
super().__init__(
model_type="lora",
model_class=LoraMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
# Define supported file extensions
file_extensions = {'.safetensors'}
# Initialize parent class with ModelHashIndex
super().__init__(
model_type="lora",
model_class=LoraMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
)
def get_model_roots(self) -> List[str]:
"""Get lora root directories"""
return config.loras_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all LoRA directories and return metadata"""
all_loras = []
# Create scan tasks for each directory
scan_tasks = []
for lora_root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(lora_root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
loras = await task
all_loras.extend(loras)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_loras
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for LoRA files"""
loras = []
original_root = root_path # Save original root path
async def scan_recursive(path: str, visited_paths: set):
"""Recursively scan directory, avoiding circular symlinks"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
# Use original path instead of real path
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, loras)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return loras
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
"""Process a single file and add to results list"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
loras.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy_search: bool = False,
base_models: list = None, tags: list = None,
search_options: dict = None, hash_filters: dict = None,
favorites_only: bool = False, first_letter: str = None) -> Dict:
"""Get paginated and filtered lora data
Args:
page: Current page number (1-based)
page_size: Number of items per page
sort_by: Sort method ('name' or 'date')
folder: Filter by folder path
search: Search term
fuzzy_search: Use fuzzy matching for search
base_models: List of base models to filter by
tags: List of tags to filter by
search_options: Dictionary with search options (filename, modelname, tags, recursive)
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
favorites_only: Filter for favorite models only
first_letter: Filter by first letter of model name
"""
cache = await self.get_cached_data()
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False,
'recursive': False,
}
# Get the base data set
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply hash filtering if provided (highest priority)
if hash_filters:
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower() # Ensure lowercase for matching
filtered_data = [
lora for lora in filtered_data
if lora.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
filtered_data = [
lora for lora in filtered_data
if lora.get('sha256', '').lower() in hash_set
]
# Jump to pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
# Apply SFW filtering if enabled
if settings.get('show_only_sfw', False):
filtered_data = [
lora for lora in filtered_data
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply favorites filtering if enabled
if favorites_only:
filtered_data = [
lora for lora in filtered_data
if lora.get('favorite', False) is True
]
# Apply first letter filtering
if first_letter:
filtered_data = self._filter_by_first_letter(filtered_data, first_letter)
# Apply folder filtering
if folder is not None:
if search_options.get('recursive', False):
# Recursive folder filtering - include all subfolders
filtered_data = [
lora for lora in filtered_data
if lora['folder'].startswith(folder)
]
else:
# Exact folder filtering
filtered_data = [
lora for lora in filtered_data
if lora['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
filtered_data = [
lora for lora in filtered_data
if lora.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
filtered_data = [
lora for lora in filtered_data
if any(tag in lora.get('tags', []) for tag in tags)
]
# Apply search filtering
if search:
search_results = []
search_opts = search_options or {}
for lora in filtered_data:
# Search by file name
if search_opts.get('filename', True):
if fuzzy_match(lora.get('file_name', ''), search):
search_results.append(lora)
continue
# Search by model name
if search_opts.get('modelname', True):
if fuzzy_match(lora.get('model_name', ''), search):
search_results.append(lora)
continue
# Search by tags
if search_opts.get('tags', False) and 'tags' in lora:
if any(fuzzy_match(tag, search) for tag in lora['tags']):
search_results.append(lora)
continue
filtered_data = search_results
# Calculate pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
def _filter_by_first_letter(self, data, letter):
"""Filter data by first letter of model name
Special handling:
- '#': Numbers (0-9)
- '@': Special characters (not alphanumeric)
- '漢': CJK characters
"""
filtered_data = []
for lora in data:
model_name = lora.get('model_name', '')
if not model_name:
continue
first_char = model_name[0].upper()
if letter == '#' and first_char.isdigit():
filtered_data.append(lora)
elif letter == '@' and not first_char.isalnum():
# Special characters (not alphanumeric)
filtered_data.append(lora)
elif letter == '漢' and self._is_cjk_character(first_char):
# CJK characters
filtered_data.append(lora)
elif letter.upper() == first_char:
# Regular alphabet matching
filtered_data.append(lora)
return filtered_data
def _is_cjk_character(self, char):
"""Check if character is a CJK character"""
# Define Unicode ranges for CJK characters
cjk_ranges = [
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
(0x3300, 0x33FF), # CJK Compatibility
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
(0x3100, 0x312F), # Bopomofo
(0x31A0, 0x31BF), # Bopomofo Extended
(0x3040, 0x309F), # Hiragana
(0x30A0, 0x30FF), # Katakana
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
(0xAC00, 0xD7AF), # Hangul Syllables
(0x1100, 0x11FF), # Hangul Jamo
(0xA960, 0xA97F), # Hangul Jamo Extended-A
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
]
code_point = ord(char)
return any(start <= code_point <= end for start, end in cjk_ranges)
async def get_letter_counts(self):
"""Get count of models for each letter of the alphabet"""
cache = await self.get_cached_data()
data = cache.sorted_by_name
# Define letter categories
letters = {
'#': 0, # Numbers
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
'Y': 0, 'Z': 0,
'@': 0, # Special characters
'漢': 0 # CJK characters
}
# Count models for each letter
for lora in data:
model_name = lora.get('model_name', '')
if not model_name:
continue
first_char = model_name[0].upper()
if first_char.isdigit():
letters['#'] += 1
elif first_char in letters:
letters[first_char] += 1
elif self._is_cjk_character(first_char):
letters['漢'] += 1
elif not first_char.isalnum():
letters['@'] += 1
return letters
# Lora-specific hash index functionality
def has_lora_hash(self, sha256: str) -> bool:
"""Check if a LoRA with given hash exists"""
return self.has_hash(sha256)
def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a LoRA by its hash"""
return self.get_path_by_hash(sha256)
def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a LoRA by its file path"""
return self.get_hash_by_path(file_path)
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get top tags sorted by count"""
# Make sure cache is initialized
await self.get_cached_data()
# Sort tags by count in descending order
sorted_tags = sorted(
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
key=lambda x: x['count'],
reverse=True
)
# Return limited number
return sorted_tags[:limit]
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get base models used in loras sorted by frequency"""
# Make sure cache is initialized
cache = await self.get_cached_data()
# Count base model occurrences
base_model_counts = {}
for lora in cache.raw_data:
if 'base_model' in lora and lora['base_model']:
base_model = lora['base_model']
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
# Sort base models by count
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
sorted_models.sort(key=lambda x: x['count'], reverse=True)
# Return limited number
return sorted_models[:limit]
async def diagnose_hash_index(self):
"""Diagnostic method to verify hash index functionality"""
@@ -451,19 +63,3 @@ class LoraScanner(ModelScanner):
test_hash_result = self._hash_index.get_hash(test_path)
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
async def get_lora_info_by_name(self, name):
"""Get LoRA information by name"""
try:
# Get cached data
cache = await self.get_cached_data()
# Find the LoRA by name
for lora in cache.raw_data:
if lora.get("file_name") == name:
return lora
return None
except Exception as e:
logger.error(f"Error getting LoRA info by name: {e}", exc_info=True)
return None

py/services/lora_service.py Normal file
View File

@@ -0,0 +1,212 @@
import os
import logging
from typing import Dict, List, Optional
from .base_model_service import BaseModelService
from ..utils.models import LoraMetadata
from ..config import config
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__)
class LoraService(BaseModelService):
"""LoRA-specific service implementation"""
def __init__(self, scanner):
"""Initialize LoRA service
Args:
scanner: LoRA scanner instance
"""
super().__init__("lora", scanner, LoraMetadata)
async def format_response(self, lora_data: Dict) -> Dict:
"""Format LoRA data for API response"""
return {
"model_name": lora_data["model_name"],
"file_name": lora_data["file_name"],
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
"base_model": lora_data.get("base_model", ""),
"folder": lora_data["folder"],
"sha256": lora_data.get("sha256", ""),
"file_path": lora_data["file_path"].replace(os.sep, "/"),
"file_size": lora_data.get("size", 0),
"modified": lora_data.get("modified", ""),
"tags": lora_data.get("tags", []),
"modelDescription": lora_data.get("modelDescription", ""),
"from_civitai": lora_data.get("from_civitai", True),
"usage_tips": lora_data.get("usage_tips", ""),
"notes": lora_data.get("notes", ""),
"favorite": lora_data.get("favorite", False),
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}))
}
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply LoRA-specific filters"""
# Handle first_letter filter for LoRAs
first_letter = kwargs.get('first_letter')
if first_letter:
data = self._filter_by_first_letter(data, first_letter)
return data
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
"""Filter data by first letter of model name
Special handling:
- '#': Numbers (0-9)
- '@': Special characters (not alphanumeric)
- '漢': CJK characters
"""
filtered_data = []
for lora in data:
model_name = lora.get('model_name', '')
if not model_name:
continue
first_char = model_name[0].upper()
if letter == '#' and first_char.isdigit():
filtered_data.append(lora)
elif letter == '@' and not first_char.isalnum():
# Special characters (not alphanumeric)
filtered_data.append(lora)
elif letter == '漢' and self._is_cjk_character(first_char):
# CJK characters
filtered_data.append(lora)
elif letter.upper() == first_char:
# Regular alphabet matching
filtered_data.append(lora)
return filtered_data
def _is_cjk_character(self, char: str) -> bool:
"""Check if character is a CJK character"""
# Define Unicode ranges for CJK characters
cjk_ranges = [
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
(0x3300, 0x33FF), # CJK Compatibility
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
(0x3100, 0x312F), # Bopomofo
(0x31A0, 0x31BF), # Bopomofo Extended
(0x3040, 0x309F), # Hiragana
(0x30A0, 0x30FF), # Katakana
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
(0xAC00, 0xD7AF), # Hangul Syllables
(0x1100, 0x11FF), # Hangul Jamo
(0xA960, 0xA97F), # Hangul Jamo Extended-A
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
]
code_point = ord(char)
return any(start <= code_point <= end for start, end in cjk_ranges)
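A quick sanity check of the range logic, using characters whose code points are easy to verify by hand (a sketch covering three of the ranges above):

# ord('漢') == 0x6F22 -> CJK Unified Ideographs (0x4E00-0x9FFF)
# ord('あ') == 0x3042 -> Hiragana              (0x3040-0x309F)
# ord('한') == 0xD55C -> Hangul Syllables      (0xAC00-0xD7AF)
sample_ranges = [(0x4E00, 0x9FFF), (0x3040, 0x309F), (0xAC00, 0xD7AF)]
for ch, expected in [('漢', True), ('あ', True), ('한', True), ('A', False), ('7', False)]:
    hit = any(start <= ord(ch) <= end for start, end in sample_ranges)
    assert hit is expected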
# LoRA-specific methods
async def get_letter_counts(self) -> Dict[str, int]:
"""Get count of LoRAs for each letter of the alphabet"""
cache = await self.scanner.get_cached_data()
data = cache.raw_data
# Define letter categories
letters = {
'#': 0, # Numbers
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
'Y': 0, 'Z': 0,
'@': 0, # Special characters
'漢': 0  # CJK characters
}
# Count models for each letter
for lora in data:
model_name = lora.get('model_name', '')
if not model_name:
continue
first_char = model_name[0].upper()
if first_char.isdigit():
letters['#'] += 1
elif first_char in letters:
letters[first_char] += 1
elif self._is_cjk_character(first_char):
letters['漢'] += 1
elif not first_char.isalnum():
letters['@'] += 1
return letters
async def get_lora_notes(self, lora_name: str) -> Optional[str]:
"""Get notes for a specific LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
return lora.get('notes', '')
return None
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
"""Get trigger words for a specific LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
civitai_data = lora.get('civitai', {})
return civitai_data.get('trainedWords', [])
return []
async def get_lora_preview_url(self, lora_name: str) -> Optional[str]:
"""Get the static preview URL for a LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
preview_url = lora.get('preview_url')
if preview_url:
return config.get_preview_static_url(preview_url)
return None
async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]:
"""Get the Civitai URL for a LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
civitai_data = lora.get('civitai', {})
model_id = civitai_data.get('modelId')
version_id = civitai_data.get('id')
if model_id:
civitai_url = f"https://civitai.com/models/{model_id}"
if version_id:
civitai_url += f"?modelVersionId={version_id}"
return {
'civitai_url': civitai_url,
'model_id': str(model_id),
'version_id': str(version_id) if version_id else None
}
return {'civitai_url': None, 'model_id': None, 'version_id': None}
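The URL assembly is easy to check in isolation; a minimal sketch (the ids are made-up values, not taken from this PR):

def build_civitai_url(model_id, version_id=None):
    """Model page URL, optionally pinned to a specific version."""
    url = f"https://civitai.com/models/{model_id}"
    if version_id:
        url += f"?modelVersionId={version_id}"
    return url

assert build_civitai_url(12345) == "https://civitai.com/models/12345"
assert build_civitai_url(12345, 67890) == (
    "https://civitai.com/models/12345?modelVersionId=67890")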
def find_duplicate_hashes(self) -> Dict:
"""Find LoRAs with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes()
def find_duplicate_filenames(self) -> Dict:
"""Find LoRAs with conflicting filenames"""
return self.scanner._hash_index.get_duplicate_filenames()

View File

@@ -1,37 +1,85 @@
import asyncio
from typing import List, Dict
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from operator import itemgetter
from natsort import natsorted
# Supported sort modes: (sort_key, order)
# order: 'asc' for ascending, 'desc' for descending
SUPPORTED_SORT_MODES = [
('name', 'asc'),
('name', 'desc'),
('date', 'asc'),
('date', 'desc'),
('size', 'asc'),
('size', 'desc'),
]
@dataclass
class ModelCache:
"""Cache structure for model data"""
"""Cache structure for model data with extensible sorting"""
raw_data: List[Dict]
sorted_by_name: List[Dict]
sorted_by_date: List[Dict]
folders: List[str]
def __post_init__(self):
self._lock = asyncio.Lock()
# Cache for last sort: (sort_key, order) -> sorted list
self._last_sort: Tuple[Optional[str], Optional[str]] = (None, None)
self._last_sorted_data: List[Dict] = []
# Default sort on init
asyncio.create_task(self.resort())
async def resort(self, name_only: bool = False):
"""Resort all cached data views"""
async def resort(self):
"""Resort cached data according to last sort mode if set"""
async with self._lock:
self.sorted_by_name = natsorted(
self.raw_data,
key=lambda x: x['model_name'].lower() # Case-insensitive sort
)
if not name_only:
self.sorted_by_date = sorted(
self.raw_data,
key=itemgetter('modified'),
reverse=True
)
# Update folder list
if self._last_sort != (None, None):
sort_key, order = self._last_sort
sorted_data = self._sort_data(self.raw_data, sort_key, order)
self._last_sorted_data = sorted_data
# Update folder list
all_folders = set(l['folder'] for l in self.raw_data)
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
"""Sort data by sort_key and order"""
reverse = (order == 'desc')
if sort_key == 'name':
# Natural sort by model_name, case-insensitive
return natsorted(
data,
key=lambda x: x['model_name'].lower(),
reverse=reverse
)
elif sort_key == 'date':
# Sort by modified timestamp
return sorted(
data,
key=itemgetter('modified'),
reverse=reverse
)
elif sort_key == 'size':
# Sort by file size
return sorted(
data,
key=itemgetter('size'),
reverse=reverse
)
else:
# Fallback: no sort
return list(data)
async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]:
"""Get sorted data by sort_key and order, using cache if possible"""
async with self._lock:
if (sort_key, order) == self._last_sort:
return self._last_sorted_data
sorted_data = self._sort_data(self.raw_data, sort_key, order)
self._last_sort = (sort_key, order)
self._last_sorted_data = sorted_data
return sorted_data
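A sketch of driving the new API from a coroutine. It assumes the refactored dataclass takes only raw_data and folders (the sorted_by_* fields are gone from every constructor call in this PR) and that records carry model_name, modified, size, and folder keys:

import asyncio

async def demo():
    cache = ModelCache(raw_data=[
        {'model_name': 'b', 'modified': 2.0, 'size': 10, 'folder': ''},
        {'model_name': 'A', 'modified': 1.0, 'size': 30, 'folder': ''},
    ], folders=[])
    await asyncio.sleep(0)   # let the resort task scheduled in __post_init__ run
    newest_first = await cache.get_sorted_data('date', 'desc')
    assert [m['model_name'] for m in newest_first] == ['b', 'A']
    # Repeating the same (key, order) pair returns the cached list object:
    assert await cache.get_sorted_data('date', 'desc') is newest_first

asyncio.run(demo())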
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
"""Update preview_url for a specific model in all cached data

View File

@@ -5,7 +5,6 @@ import asyncio
import time
import shutil
from typing import List, Dict, Optional, Type, Set
import msgpack # Add MessagePack import for efficient serialization
from ..utils.models import BaseModelMetadata
from ..config import config
@@ -19,17 +18,33 @@ from .websocket_manager import ws_manager
logger = logging.getLogger(__name__)
# Define cache version to handle future format changes
# Version history:
# 1 - Initial version
# 2 - Added duplicate_filenames and duplicate_hashes tracking
# 3 - Added _excluded_models list to cache
CACHE_VERSION = 3
class ModelScanner:
"""Base service for scanning and managing model files"""
_lock = asyncio.Lock()
_instances = {} # Dictionary to store instances by class
_locks = {} # Dictionary to store locks by class
def __new__(cls, *args, **kwargs):
"""Implement singleton pattern for each subclass"""
if cls not in cls._instances:
cls._instances[cls] = super().__new__(cls)
return cls._instances[cls]
@classmethod
def _get_lock(cls):
"""Get or create a lock for this class"""
if cls not in cls._locks:
cls._locks[cls] = asyncio.Lock()
return cls._locks[cls]
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
lock = cls._get_lock()
async with lock:
if cls not in cls._instances:
cls._instances[cls] = cls()
return cls._instances[cls]
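The effect is one instance per subclass, with a per-class lock guarding first creation. A usage sketch (assuming LoraScanner and CheckpointScanner subclass ModelScanner, as elsewhere in this PR):

async def demo():
    a = await LoraScanner.get_instance()
    b = await LoraScanner.get_instance()
    assert a is b          # same singleton on every call
    c = await CheckpointScanner.get_instance()
    assert a is not c      # instances and locks are keyed by class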
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
"""Initialize the scanner
@@ -40,6 +55,10 @@ class ModelScanner:
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
hash_index: Hash index instance (optional)
"""
# Ensure initialization happens only once per instance
if hasattr(self, '_initialized'):
return
self.model_type = model_type
self.model_class = model_class
self.file_extensions = file_extensions
@@ -48,202 +67,15 @@ class ModelScanner:
self._tags_count = {} # Dictionary to store tag counts
self._is_initializing = False # Flag to track initialization state
self._excluded_models = [] # List to track excluded models
self._dirs_last_modified = {} # Track directory modification times
self._use_cache_files = False # Flag to control cache file usage, default to disabled
# Clear cache files if disabled
if not self._use_cache_files:
self._clear_cache_files()
self._initialized = True
# Register this service
asyncio.create_task(self._register_service())
def _clear_cache_files(self):
"""Clear existing cache files if they exist"""
try:
cache_path = self._get_cache_file_path()
if cache_path and os.path.exists(cache_path):
os.remove(cache_path)
logger.info(f"Cleared {self.model_type} cache file: {cache_path}")
except Exception as e:
logger.error(f"Error clearing {self.model_type} cache file: {e}")
async def _register_service(self):
"""Register this instance with the ServiceRegistry"""
service_name = f"{self.model_type}_scanner"
await ServiceRegistry.register_service(service_name, self)
def _get_cache_file_path(self) -> Optional[str]:
"""Get the path to the cache file"""
# Get the directory where this module is located
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
# Create a cache directory within the project if it doesn't exist
cache_dir = os.path.join(current_dir, "cache")
os.makedirs(cache_dir, exist_ok=True)
# Create filename based on model type
cache_filename = f"lm_{self.model_type}_cache.msgpack"
return os.path.join(cache_dir, cache_filename)
def _prepare_for_msgpack(self, data):
"""Preprocess data to accommodate MessagePack serialization limitations
Converts integers exceeding safe range to strings
Args:
data: Any type of data structure
Returns:
Preprocessed data structure with large integers converted to strings
"""
if isinstance(data, dict):
return {k: self._prepare_for_msgpack(v) for k, v in data.items()}
elif isinstance(data, list):
return [self._prepare_for_msgpack(item) for item in data]
elif isinstance(data, int) and (data > 9007199254740991 or data < -9007199254740991):
# Convert integers exceeding JavaScript's safe integer range (2^53-1) to strings
return str(data)
else:
return data
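The 9007199254740991 bound is JavaScript's Number.MAX_SAFE_INTEGER (2^53 - 1), so any integer that would lose precision in a JS consumer travels as a string instead. A condensed sketch of the boundary behaviour:

MAX_SAFE = 2**53 - 1  # 9007199254740991

def prepare(value):
    if isinstance(value, dict):
        return {k: prepare(v) for k, v in value.items()}
    if isinstance(value, list):
        return [prepare(v) for v in value]
    if isinstance(value, int) and not (-MAX_SAFE <= value <= MAX_SAFE):
        return str(value)        # stringify only out-of-range integers
    return value

assert prepare({'ok': 2**53 - 1}) == {'ok': 9007199254740991}    # kept as int
assert prepare({'big': 2**53}) == {'big': '9007199254740992'}    # stringified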
async def _save_cache_to_disk(self) -> bool:
"""Save cache data to disk using MessagePack"""
if not self._use_cache_files:
logger.debug(f"Cache files disabled for {self.model_type}, skipping save")
return False
if self._cache is None or not self._cache.raw_data:
logger.debug(f"No {self.model_type} cache data to save")
return False
cache_path = self._get_cache_file_path()
if not cache_path:
logger.warning(f"Cannot determine {self.model_type} cache file location")
return False
try:
# Create cache data structure
cache_data = {
"version": CACHE_VERSION,
"timestamp": time.time(),
"model_type": self.model_type,
"raw_data": self._cache.raw_data,
"hash_index": {
"hash_to_path": self._hash_index._hash_to_path,
"filename_to_hash": self._hash_index._filename_to_hash, # Fix: changed from path_to_hash to filename_to_hash
"duplicate_hashes": self._hash_index._duplicate_hashes,
"duplicate_filenames": self._hash_index._duplicate_filenames
},
"tags_count": self._tags_count,
"dirs_last_modified": self._get_dirs_last_modified(),
"excluded_models": self._excluded_models # Add excluded_models to cache data
}
# Preprocess data to handle large integers
processed_cache_data = self._prepare_for_msgpack(cache_data)
# Write to temporary file first (atomic operation)
temp_path = f"{cache_path}.tmp"
with open(temp_path, 'wb') as f:
msgpack.pack(processed_cache_data, f)
# Replace the old file with the new one
if os.path.exists(cache_path):
os.replace(temp_path, cache_path)
else:
os.rename(temp_path, cache_path)
logger.info(f"Saved {self.model_type} cache with {len(self._cache.raw_data)} models to {cache_path}")
logger.debug(f"Hash index stats - hash_to_path: {len(self._hash_index._hash_to_path)}, filename_to_hash: {len(self._hash_index._filename_to_hash)}, duplicate_hashes: {len(self._hash_index._duplicate_hashes)}, duplicate_filenames: {len(self._hash_index._duplicate_filenames)}")
return True
except Exception as e:
logger.error(f"Error saving {self.model_type} cache to disk: {e}")
# Try to clean up temp file if it exists
if 'temp_path' in locals() and os.path.exists(temp_path):
try:
os.remove(temp_path)
except:
pass
return False
def _get_dirs_last_modified(self) -> Dict[str, float]:
"""Get last modified time for all model directories"""
dirs_info = {}
for root in self.get_model_roots():
if os.path.exists(root):
dirs_info[root] = os.path.getmtime(root)
# Also check immediate subdirectories for changes
try:
with os.scandir(root) as it:
for entry in it:
if entry.is_dir(follow_symlinks=True):
dirs_info[entry.path] = entry.stat().st_mtime
except Exception as e:
logger.error(f"Error getting directory info for {root}: {e}")
return dirs_info
def _is_cache_valid(self, cache_data: Dict) -> bool:
"""Validate if the loaded cache is still valid"""
if not cache_data or cache_data.get("version") != CACHE_VERSION:
logger.info(f"Cache invalid - version mismatch. Got: {cache_data.get('version')}, Expected: {CACHE_VERSION}")
return False
if cache_data.get("model_type") != self.model_type:
logger.info(f"Cache invalid - model type mismatch. Got: {cache_data.get('model_type')}, Expected: {self.model_type}")
return False
return True
async def _load_cache_from_disk(self) -> bool:
"""Load cache data from disk using MessagePack"""
if not self._use_cache_files:
logger.info(f"Cache files disabled for {self.model_type}, skipping load")
return False
start_time = time.time()
cache_path = self._get_cache_file_path()
if not cache_path or not os.path.exists(cache_path):
return False
try:
with open(cache_path, 'rb') as f:
cache_data = msgpack.unpack(f)
# Validate cache data
if not self._is_cache_valid(cache_data):
logger.info(f"{self.model_type.capitalize()} cache file found but invalid or outdated")
return False
# Load data into memory
self._cache = ModelCache(
raw_data=cache_data["raw_data"],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Load hash index
hash_index_data = cache_data.get("hash_index", {})
self._hash_index._hash_to_path = hash_index_data.get("hash_to_path", {})
self._hash_index._filename_to_hash = hash_index_data.get("filename_to_hash", {}) # Fix: changed from path_to_hash to filename_to_hash
self._hash_index._duplicate_hashes = hash_index_data.get("duplicate_hashes", {})
self._hash_index._duplicate_filenames = hash_index_data.get("duplicate_filenames", {})
# Load tags count
self._tags_count = cache_data.get("tags_count", {})
# Load excluded models
self._excluded_models = cache_data.get("excluded_models", [])
# Resort the cache
await self._cache.resort()
logger.info(f"Loaded {self.model_type} cache from disk with {len(self._cache.raw_data)} models in {time.time() - start_time:.2f} seconds")
return True
except Exception as e:
logger.error(f"Error loading {self.model_type} cache from disk: {e}")
return False
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
@@ -252,8 +84,6 @@ class ModelScanner:
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
@@ -271,21 +101,6 @@ class ModelScanner:
'scanner_type': self.model_type,
'pageType': page_type
})
cache_loaded = await self._load_cache_from_disk()
if cache_loaded:
# Cache loaded successfully, broadcast complete message
await ws_manager.broadcast_init_progress({
'stage': 'finalizing',
'progress': 100,
'status': 'complete',
'details': f"Loaded {len(self._cache.raw_data)} {self.model_type} files from cache.",
'scanner_type': self.model_type,
'pageType': page_type
})
self._is_initializing = False
return
# If cache loading failed, proceed with full scan
await ws_manager.broadcast_init_progress({
@@ -332,9 +147,6 @@ class ModelScanner:
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
# Save the cache to disk after initialization
await self._save_cache_to_disk()
# Send completion message
await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent
await ws_manager.broadcast_init_progress({
@@ -509,40 +321,21 @@ class ModelScanner:
Args:
force_refresh: Whether to refresh the cache
rebuild_cache: Whether to completely rebuild the cache by reloading from disk first
rebuild_cache: Whether to completely rebuild the cache
"""
# If cache is not initialized, return an empty cache
# Actual initialization should be done via initialize_in_background
if self._cache is None and not force_refresh:
return ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# If force refresh is requested, initialize the cache directly
if force_refresh:
# If rebuild_cache is True, try to reload from disk before reconciliation
if rebuild_cache:
logger.info(f"{self.model_type.capitalize()} Scanner: Attempting to rebuild cache from disk...")
cache_loaded = await self._load_cache_from_disk()
if cache_loaded:
logger.info(f"{self.model_type.capitalize()} Scanner: Successfully reloaded cache from disk")
else:
logger.info(f"{self.model_type.capitalize()} Scanner: Could not reload cache from disk, proceeding with complete rebuild")
# If loading from disk failed, do a complete rebuild and save to disk
await self._initialize_cache()
await self._save_cache_to_disk()
return self._cache
if self._cache is None:
# For initial creation, do a full initialization
await self._initialize_cache()
# Save the newly built cache
await self._save_cache_to_disk()
else:
# For subsequent refreshes, use fast reconciliation
await self._reconcile_cache()
return self._cache
@@ -577,8 +370,6 @@ class ModelScanner:
# Update cache
self._cache = ModelCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
@@ -592,8 +383,6 @@ class ModelScanner:
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
finally:
@@ -735,19 +524,74 @@ class ModelScanner:
# Resort cache
await self._cache.resort()
# Save updated cache to disk
await self._save_cache_to_disk()
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
finally:
self._is_initializing = False # Unset flag
# These methods should be implemented in child classes
async def scan_all_models(self) -> List[Dict]:
"""Scan all model directories and return metadata"""
raise NotImplementedError("Subclasses must implement scan_all_models")
all_models = []
# Create scan tasks for each directory
scan_tasks = []
for model_root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(model_root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
models = await task
all_models.extend(models)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_models
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for model files"""
models = []
original_root = root_path # Save original root path
async def scan_recursive(path: str, visited_paths: set):
"""Recursively scan directory, avoiding circular symlinks"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
# Use original path instead of real path
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, models)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return models
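Cycle avoidance hinges on remembering the os.path.realpath of every directory already entered, so a symlink loop is cut on its second visit. A synchronous, self-contained sketch of the same guard (collecting paths rather than metadata):

import os

def scan(root: str, exts=('.safetensors',)) -> list:
    found, visited = [], set()
    def walk(path):
        real = os.path.realpath(path)
        if real in visited:      # a symlink cycle resolves to a seen real path
            return
        visited.add(real)
        with os.scandir(path) as it:
            for entry in it:
                if entry.is_file(follow_symlinks=True) and entry.name.endswith(exts):
                    found.append(entry.path.replace(os.sep, '/'))
                elif entry.is_dir(follow_symlinks=True):
                    walk(entry.path)
    walk(root)
    return found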
async def _process_single_file(self, file_path: str, root_path: str, models: list):
"""Process a single file and add to results list"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
models.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
def is_initializing(self) -> bool:
"""Check if the scanner is currently initializing"""
@@ -931,7 +775,7 @@ class ModelScanner:
logger.error(f"Error processing {file_path}: {e}")
async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
"""Add a model to the cache and save to disk
"""Add a model to the cache
Args:
metadata_dict: The model metadata dictionary
@@ -960,9 +804,6 @@ class ModelScanner:
# Update the hash index
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Save to disk
await self._save_cache_to_disk()
return True
except Exception as e:
logger.error(f"Error adding model to cache: {e}")
@@ -1102,9 +943,6 @@ class ModelScanner:
await cache.resort()
# Save the updated cache
await self._save_cache_to_disk()
return True
def has_hash(self, sha256: str) -> bool:
@@ -1198,11 +1036,7 @@ class ModelScanner:
if self._cache is None:
return False
updated = await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
if updated:
# Save updated cache to disk
await self._save_cache_to_disk()
return updated
return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
"""Delete multiple models and update cache in a batch operation
@@ -1334,9 +1168,6 @@ class ModelScanner:
# Resort cache
await self._cache.resort()
# Save updated cache to disk
await self._save_cache_to_disk()
return True
except Exception as e:

View File

@@ -0,0 +1,137 @@
from typing import Dict, Type, Any
import logging
logger = logging.getLogger(__name__)
class ModelServiceFactory:
"""Factory for managing model services and routes"""
_services: Dict[str, Type] = {}
_routes: Dict[str, Type] = {}
_initialized_services: Dict[str, Any] = {}
_initialized_routes: Dict[str, Any] = {}
@classmethod
def register_model_type(cls, model_type: str, service_class: Type, route_class: Type):
"""Register a new model type with its service and route classes
Args:
model_type: The model type identifier (e.g., 'lora', 'checkpoint')
service_class: The service class for this model type
route_class: The route class for this model type
"""
cls._services[model_type] = service_class
cls._routes[model_type] = route_class
logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}")
@classmethod
def get_service_class(cls, model_type: str) -> Type:
"""Get service class for a model type
Args:
model_type: The model type identifier
Returns:
The service class for the model type
Raises:
ValueError: If model type is not registered
"""
if model_type not in cls._services:
raise ValueError(f"Unknown model type: {model_type}")
return cls._services[model_type]
@classmethod
def get_route_class(cls, model_type: str) -> Type:
"""Get route class for a model type
Args:
model_type: The model type identifier
Returns:
The route class for the model type
Raises:
ValueError: If model type is not registered
"""
if model_type not in cls._routes:
raise ValueError(f"Unknown model type: {model_type}")
return cls._routes[model_type]
@classmethod
def get_route_instance(cls, model_type: str):
"""Get or create route instance for a model type
Args:
model_type: The model type identifier
Returns:
The route instance for the model type
"""
if model_type not in cls._initialized_routes:
route_class = cls.get_route_class(model_type)
cls._initialized_routes[model_type] = route_class()
return cls._initialized_routes[model_type]
@classmethod
def setup_all_routes(cls, app):
"""Setup routes for all registered model types
Args:
app: The aiohttp application instance
"""
logger.info(f"Setting up routes for {len(cls._services)} registered model types")
for model_type in cls._services.keys():
try:
routes_instance = cls.get_route_instance(model_type)
routes_instance.setup_routes(app)
logger.info(f"Successfully set up routes for {model_type}")
except Exception as e:
logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True)
@classmethod
def get_registered_types(cls) -> list:
"""Get list of all registered model types
Returns:
List of registered model type identifiers
"""
return list(cls._services.keys())
@classmethod
def is_registered(cls, model_type: str) -> bool:
"""Check if a model type is registered
Args:
model_type: The model type identifier
Returns:
True if the model type is registered, False otherwise
"""
return model_type in cls._services
@classmethod
def clear_registrations(cls):
"""Clear all registrations - mainly for testing purposes"""
cls._services.clear()
cls._routes.clear()
cls._initialized_services.clear()
cls._initialized_routes.clear()
logger.info("Cleared all model type registrations")
def register_default_model_types():
"""Register the default model types (LoRA and Checkpoint)"""
from ..services.lora_service import LoraService
from ..services.checkpoint_service import CheckpointService
from ..routes.lora_routes import LoraRoutes
from ..routes.checkpoint_routes import CheckpointRoutes
# Register LoRA model type
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
# Register Checkpoint model type
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
logger.info("Registered default model types: lora, checkpoint")

View File

@@ -393,8 +393,8 @@ class RecipeScanner:
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):
hash_value = lora['hash']
if self._lora_scanner.has_lora_hash(hash_value):
lora_path = self._lora_scanner.get_lora_path_by_hash(hash_value)
if self._lora_scanner.has_hash(hash_value):
lora_path = self._lora_scanner.get_path_by_hash(hash_value)
if lora_path:
file_name = os.path.splitext(os.path.basename(lora_path))[0]
lora['file_name'] = file_name
@@ -465,7 +465,7 @@ class RecipeScanner:
# Count occurrences of each base model
for lora in loras:
if 'hash' in lora:
lora_path = self._lora_scanner.get_lora_path_by_hash(lora['hash'])
lora_path = self._lora_scanner.get_path_by_hash(lora['hash'])
if lora_path:
base_model = await self._get_base_model_for_lora(lora_path)
if base_model:
@@ -603,9 +603,9 @@ class RecipeScanner:
if 'loras' in item:
for lora in item['loras']:
if 'hash' in lora and lora['hash']:
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora['hash'].lower())
lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower())
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower())
result = {
'items': paginated_items,
@@ -655,9 +655,9 @@ class RecipeScanner:
for lora in formatted_recipe['loras']:
if 'hash' in lora and lora['hash']:
lora_hash = lora['hash'].lower()
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora_hash)
lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash)
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora_hash)
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash)
return formatted_recipe

View File

@@ -7,106 +7,176 @@ logger = logging.getLogger(__name__)
T = TypeVar('T') # Define a type variable for service types
class ServiceRegistry:
"""Centralized registry for service singletons"""
"""Central registry for managing singleton services"""
_instance = None
_services: Dict[str, Any] = {}
_lock = asyncio.Lock()
_locks: Dict[str, asyncio.Lock] = {}
@classmethod
def get_instance(cls):
"""Get singleton instance of the registry"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def register_service(cls, name: str, service: Any) -> None:
"""Register a service instance with the registry
Args:
name: Service name identifier
service: Service instance to register
"""
cls._services[name] = service
logger.debug(f"Registered service: {name}")
@classmethod
async def register_service(cls, service_name: str, service_instance: Any) -> None:
"""Register a service instance with the registry"""
registry = cls.get_instance()
async with cls._lock:
registry._services[service_name] = service_instance
logger.debug(f"Registered service: {service_name}")
async def get_service(cls, name: str) -> Optional[Any]:
"""Get a service instance by name
Args:
name: Service name identifier
Returns:
Service instance or None if not found
"""
return cls._services.get(name)
@classmethod
async def get_service(cls, service_name: str) -> Any:
"""Get a service instance by name"""
registry = cls.get_instance()
async with cls._lock:
if service_name not in registry._services:
logger.debug(f"Service {service_name} not found in registry")
return None
return registry._services[service_name]
def _get_lock(cls, name: str) -> asyncio.Lock:
"""Get or create a lock for a service
Args:
name: Service name identifier
Returns:
AsyncIO lock for the service
"""
if name not in cls._locks:
cls._locks[name] = asyncio.Lock()
return cls._locks[name]
@classmethod
def get_service_sync(cls, service_name: str) -> Any:
"""Get a service instance by name (synchronous version)"""
registry = cls.get_instance()
if service_name not in registry._services:
logger.debug(f"Service {service_name} not found in registry")
return None
return registry._services[service_name]
# Convenience methods for common services
@classmethod
async def get_lora_scanner(cls):
"""Get the LoraScanner instance"""
from .lora_scanner import LoraScanner
scanner = await cls.get_service("lora_scanner")
if scanner is None:
scanner = await LoraScanner.get_instance()
await cls.register_service("lora_scanner", scanner)
return scanner
"""Get or create LoRA scanner instance"""
service_name = "lora_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .lora_scanner import LoraScanner
scanner = await LoraScanner.get_instance()
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
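Each getter repeats the same double-checked pattern: an unlocked fast path, then a per-service lock, then a second membership check so that two tasks racing on a cache miss construct the service only once. The pattern condensed to a helper (a sketch; make_service stands for any async factory):

import asyncio

_services, _locks = {}, {}

async def get_or_create(name, make_service):
    if name in _services:            # fast path, no lock taken
        return _services[name]
    lock = _locks.setdefault(name, asyncio.Lock())
    async with lock:
        if name in _services:        # double-check: another task won the race
            return _services[name]
        _services[name] = await make_service()
        return _services[name]

# e.g. scanner = await get_or_create("lora_scanner", LoraScanner.get_instance)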
@classmethod
async def get_checkpoint_scanner(cls):
"""Get the CheckpointScanner instance"""
from .checkpoint_scanner import CheckpointScanner
scanner = await cls.get_service("checkpoint_scanner")
if scanner is None:
"""Get or create Checkpoint scanner instance"""
service_name = "checkpoint_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .checkpoint_scanner import CheckpointScanner
scanner = await CheckpointScanner.get_instance()
await cls.register_service("checkpoint_scanner", scanner)
return scanner
@classmethod
async def get_civitai_client(cls):
"""Get the CivitaiClient instance"""
from .civitai_client import CivitaiClient
client = await cls.get_service("civitai_client")
if client is None:
client = await CivitaiClient.get_instance()
await cls.register_service("civitai_client", client)
return client
@classmethod
async def get_download_manager(cls):
"""Get the DownloadManager instance"""
from .download_manager import DownloadManager
manager = await cls.get_service("download_manager")
if manager is None:
manager = await DownloadManager.get_instance()
await cls.register_service("download_manager", manager)
return manager
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
@classmethod
async def get_recipe_scanner(cls):
"""Get the RecipeScanner instance"""
from .recipe_scanner import RecipeScanner
scanner = await cls.get_service("recipe_scanner")
if scanner is None:
lora_scanner = await cls.get_lora_scanner()
scanner = RecipeScanner(lora_scanner)
await cls.register_service("recipe_scanner", scanner)
return scanner
"""Get or create Recipe scanner instance"""
service_name = "recipe_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .recipe_scanner import RecipeScanner
scanner = await RecipeScanner.get_instance()
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
@classmethod
async def get_civitai_client(cls):
"""Get or create CivitAI client instance"""
service_name = "civitai_client"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .civitai_client import CivitaiClient
client = await CivitaiClient.get_instance()
cls._services[service_name] = client
logger.debug(f"Created and registered {service_name}")
return client
@classmethod
async def get_download_manager(cls):
"""Get or create Download manager instance"""
service_name = "download_manager"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .download_manager import DownloadManager
manager = DownloadManager()
cls._services[service_name] = manager
logger.debug(f"Created and registered {service_name}")
return manager
@classmethod
async def get_websocket_manager(cls):
"""Get the WebSocketManager instance"""
from .websocket_manager import ws_manager
manager = await cls.get_service("websocket_manager")
if manager is None:
# ws_manager is already a global instance in websocket_manager.py
"""Get or create WebSocket manager instance"""
service_name = "websocket_manager"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .websocket_manager import ws_manager
await cls.register_service("websocket_manager", ws_manager)
manager = ws_manager
return manager
cls._services[service_name] = ws_manager
logger.debug(f"Registered {service_name}")
return ws_manager
@classmethod
def clear_services(cls):
"""Clear all registered services - mainly for testing"""
cls._services.clear()
cls._locks.clear()
logger.info("Cleared all registered services")

View File

@@ -566,9 +566,10 @@ class ModelRouteUtils:
return web.Response(text=str(e), status=500)
@staticmethod
async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response:
async def handle_download_model(request: web.Request) -> web.Response:
"""Handle model download request"""
try:
download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
# Get or generate a download ID
@@ -663,17 +664,17 @@ class ModelRouteUtils:
}, status=500)
@staticmethod
async def handle_cancel_download(request: web.Request, download_manager: DownloadManager) -> web.Response:
async def handle_cancel_download(request: web.Request) -> web.Response:
"""Handle cancellation of a download task
Args:
request: The aiohttp request
download_manager: The download manager instance
Returns:
web.Response: The HTTP response
"""
try:
download_manager = await ServiceRegistry.get_download_manager()
download_id = request.match_info.get('download_id')
if not download_id:
return web.json_response({
@@ -701,17 +702,17 @@ class ModelRouteUtils:
}, status=500)
@staticmethod
async def handle_list_downloads(request: web.Request, download_manager: DownloadManager) -> web.Response:
async def handle_list_downloads(request: web.Request) -> web.Response:
"""Get list of active downloads
Args:
request: The aiohttp request
download_manager: The download manager instance
Returns:
web.Response: The HTTP response with list of downloads
"""
try:
download_manager = await ServiceRegistry.get_download_manager()
result = await download_manager.get_active_downloads()
return web.json_response(result)
except Exception as e:
@@ -1047,3 +1048,56 @@ class ModelRouteUtils:
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def handle_save_metadata(request: web.Request, scanner) -> web.Response:
"""Handle saving metadata updates
Args:
request: The aiohttp request
scanner: The model scanner instance
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
file_path = data.get('file_path')
if not file_path:
return web.Response(text='File path is required', status=400)
# Remove file path from data to avoid saving it
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
# Get metadata file path
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load existing metadata
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Handle nested updates (for civitai.trainedWords)
for key, value in metadata_updates.items():
if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict):
# Deep update for nested dictionaries
for nested_key, nested_value in value.items():
metadata[key][nested_key] = nested_value
else:
# Regular update for top-level keys
metadata[key] = value
# Save updated metadata
await MetadataManager.save_metadata(file_path, metadata)
# Update cache
await scanner.update_single_model_cache(file_path, file_path, metadata)
# If model_name was updated, resort the cache
if 'model_name' in metadata_updates:
cache = await scanner.get_cached_data()
await cache.resort()
return web.json_response({'success': True})
except Exception as e:
logger.error(f"Error saving metadata: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
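The nested branch merges exactly one level deep, which is what the civitai.trainedWords case needs: sibling keys inside civitai survive the update. A sketch of the merge behaviour:

def merge_updates(metadata: dict, updates: dict) -> dict:
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(metadata.get(key), dict):
            metadata[key].update(value)    # one-level deep merge
        else:
            metadata[key] = value          # plain top-level replacement
    return metadata

meta = {'model_name': 'old', 'civitai': {'id': 1, 'trainedWords': ['a']}}
merge_updates(meta, {'civitai': {'trainedWords': ['a', 'b']}, 'notes': 'hi'})
assert meta == {'model_name': 'old',
                'civitai': {'id': 1, 'trainedWords': ['a', 'b']},
                'notes': 'hi'}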

View File

@@ -14,7 +14,7 @@ dependencies = [
"requests",
"toml",
"natsort",
"msgpack"
"GitPython"
]
[project.urls]

View File

@@ -9,5 +9,5 @@ requests
toml
numpy
natsort
msgpack
pyyaml
GitPython

View File

@@ -106,6 +106,22 @@ logger = logging.getLogger("lora-manager-standalone")
# Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# Add specific suppression for connection reset errors
class ConnectionResetFilter(logging.Filter):
def filter(self, record):
# Filter out connection reset errors that are not critical
if "ConnectionResetError" in str(record.getMessage()):
return False
if "_call_connection_lost" in str(record.getMessage()):
return False
if "WinError 10054" in str(record.getMessage()):
return False
return True
# Apply the filter to asyncio logger
asyncio_logger = logging.getLogger("asyncio")
asyncio_logger.addFilter(ConnectionResetFilter())
# Now we can import the global config from our local modules
from py.config import config
@@ -118,17 +134,6 @@ class StandaloneServer:
# Ensure the app's access logger is configured to reduce verbosity
self.app._subapps = [] # Ensure this exists to avoid AttributeError
# Configure access logging for the app
self.app.on_startup.append(self._configure_access_logger)
async def _configure_access_logger(self, app):
"""Configure access logger to reduce verbosity"""
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# If using aiohttp>=3.8.0, configure access logger through app directly
if hasattr(app, 'access_logger'):
app.access_logger.setLevel(logging.WARNING)
async def setup(self):
"""Set up the standalone server"""
@@ -218,9 +223,6 @@ class StandaloneLoraManager(LoraManager):
# Store app in a global-like location for compatibility
sys.modules['server'].PromptServer.instance = server_instance
# Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
added_targets = set() # Track already added target paths
@@ -314,35 +316,39 @@ class StandaloneLoraManager(LoraManager):
app.router.add_static('/loras_static', config.static_path)
# Setup feature routes
from py.routes.lora_routes import LoraRoutes
from py.routes.api_routes import ApiRoutes
from py.services.model_service_factory import ModelServiceFactory, register_default_model_types
from py.routes.recipe_routes import RecipeRoutes
from py.routes.checkpoints_routes import CheckpointsRoutes
from py.routes.update_routes import UpdateRoutes
from py.routes.misc_routes import MiscRoutes
from py.routes.example_images_routes import ExampleImagesRoutes
from py.routes.stats_routes import StatsRoutes
from py.services.websocket_manager import ws_manager
lora_routes = LoraRoutes()
checkpoints_routes = CheckpointsRoutes()
register_default_model_types()
# Setup all model routes using the factory
ModelServiceFactory.setup_all_routes(app)
stats_routes = StatsRoutes()
# Initialize routes
lora_routes.setup_routes(app)
checkpoints_routes.setup_routes(app)
stats_routes.setup_routes(app)
ApiRoutes.setup_routes(app)
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app)
# Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
# Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services())
# Add cleanup
app.on_shutdown.append(cls._cleanup)
app.on_shutdown.append(ApiRoutes.cleanup)
def parse_args():
"""Parse command line arguments"""
@@ -367,9 +373,6 @@ async def main():
# Set log level
logging.getLogger().setLevel(getattr(logging, args.log_level))
# Explicitly configure aiohttp access logger regardless of selected log level
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# Create the server instance
server = StandaloneServer()

View File

@@ -50,8 +50,8 @@ html, body {
--lora-border: oklch(90% 0.02 256 / 0.15);
--lora-text: oklch(95% 0.02 256);
--lora-error: oklch(75% 0.32 29);
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h)); /* Modified to be used with oklch() */
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h)); /* New green success color */
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h));
/* Spacing Scale */
--space-1: calc(8px * 1);

View File

@@ -223,11 +223,6 @@
opacity: 1;
}
.update-badge.hidden,
.update-badge:not(.visible) {
opacity: 0;
}
/* Mobile adjustments */
@media (max-width: 768px) {
.app-title {

View File

@@ -172,6 +172,91 @@ body.modal-open {
opacity: 1;
}
/* Update Modal specific styles */
.update-actions {
display: flex;
flex-direction: column;
gap: var(--space-2);
align-items: stretch;
flex-wrap: nowrap;
}
.update-link {
color: var(--lora-accent);
text-decoration: none;
display: flex;
align-items: center;
gap: 8px;
font-size: 0.95em;
}
.update-link:hover {
text-decoration: underline;
}
/* Update progress styles */
.update-progress {
background: rgba(0, 0, 0, 0.03);
border: 1px solid var(--lora-border);
border-radius: var(--border-radius-sm);
padding: var(--space-2);
margin: var(--space-2) 0;
}
[data-theme="dark"] .update-progress {
background: rgba(255, 255, 255, 0.03);
}
.progress-info {
display: flex;
flex-direction: column;
gap: var(--space-1);
}
.progress-text {
font-size: 0.9em;
color: var(--text-color);
opacity: 0.8;
}
.progress-bar {
width: 100%;
height: 8px;
background-color: rgba(0, 0, 0, 0.1);
border-radius: 4px;
overflow: hidden;
}
[data-theme="dark"] .progress-bar {
background-color: rgba(255, 255, 255, 0.1);
}
.progress-fill {
height: 100%;
background-color: var(--lora-accent);
width: 0%;
transition: width 0.3s ease;
border-radius: 4px;
}
/* Update button states */
#updateBtn {
min-width: 120px;
}
#updateBtn.updating {
background-color: var(--lora-warning);
cursor: not-allowed;
}
#updateBtn.success {
background-color: var(--lora-success);
}
#updateBtn.error {
background-color: var(--lora-error);
}
/* Settings styles */
.settings-toggle {
width: 36px;

View File

@@ -182,6 +182,31 @@
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
}
/* Style for optgroups */
.control-group select optgroup {
font-weight: 600;
font-style: normal;
color: var(--text-color);
background-color: var(--card-bg);
}
.control-group select option {
padding: 4px 8px;
background-color: var(--card-bg);
color: var(--text-color);
}
/* Dark theme optgroup styling */
[data-theme="dark"] .control-group select optgroup {
background-color: var(--card-bg);
color: var(--text-color);
}
[data-theme="dark"] .control-group select option {
background-color: var(--card-bg);
color: var(--text-color);
}
.control-group select:hover {
border-color: var(--lora-accent);
background-color: var(--bg-color);

View File

@@ -54,25 +54,16 @@ export async function fetchModelsPage(options = {}) {
if (pageState.filters) {
// Handle tags filters
if (pageState.filters.tags && pageState.filters.tags.length > 0) {
// Checkpoints API expects individual 'tag' parameters, Loras API expects comma-separated 'tags'
if (modelType === 'checkpoint') {
pageState.filters.tags.forEach(tag => {
params.append('tag', tag);
});
} else {
params.append('tags', pageState.filters.tags.join(','));
}
pageState.filters.tags.forEach(tag => {
params.append('tag', tag);
});
}
// Handle base model filters
if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) {
if (modelType === 'checkpoint') {
pageState.filters.baseModel.forEach(model => {
params.append('base_model', model);
});
} else {
params.append('base_models', pageState.filters.baseModel.join(','));
}
pageState.filters.baseModel.forEach(model => {
params.append('base_model', model);
});
}
}
@@ -277,7 +268,7 @@ export async function deleteModel(filePath, modelType = 'lora') {
const endpoint = modelType === 'checkpoint'
? '/api/checkpoints/delete'
: '/api/delete_model';
: '/api/loras/delete';
const response = await fetch(endpoint, {
method: 'POST',
@@ -454,7 +445,7 @@ export async function refreshSingleModelMetadata(filePath, modelType = 'lora') {
const endpoint = modelType === 'checkpoint'
? '/api/checkpoints/fetch-civitai'
: '/api/fetch-civitai';
: '/api/loras/fetch-civitai';
const response = await fetch(endpoint, {
method: 'POST',
@@ -557,7 +548,7 @@ export async function uploadPreview(filePath, file, modelType = 'lora', nsfwLeve
// Set endpoint based on model type
const endpoint = modelType === 'checkpoint'
? '/api/checkpoints/replace-preview'
: '/api/replace_preview';
: '/api/loras/replace_preview';
const response = await fetch(endpoint, {
method: 'POST',

View File

@@ -89,7 +89,7 @@ export async function fetchLorasPage(page = 1, pageSize = 100) {
export async function fetchCivitai() {
return fetchCivitaiMetadata({
modelType: 'lora',
fetchEndpoint: '/api/fetch-all-civitai',
fetchEndpoint: '/api/loras/fetch-all-civitai',
resetAndReloadFunction: resetAndReload
});
}

View File

@@ -125,7 +125,7 @@ export const ModelContextMenuMixin = {
const endpoint = this.modelType === 'checkpoint' ?
'/api/checkpoints/relink-civitai' :
'/api/relink-civitai';
'/api/loras/relink-civitai';
const response = await fetch(endpoint, {
method: 'POST',

View File

@@ -1,5 +1,5 @@
// PageControls.js - Manages controls for both LoRAs and Checkpoints pages
import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js';
import { getCurrentPageState, setCurrentPageType } from '../../state/index.js';
import { getStorageItem, setStorageItem, getSessionItem, setSessionItem } from '../../utils/storageHelpers.js';
import { showToast } from '../../utils/uiHelpers.js';
@@ -41,6 +41,9 @@ export class PageControls {
this.pageState.isLoading = false;
this.pageState.hasMore = true;
// Set default sort based on page type
this.pageState.sortBy = this.pageType === 'loras' ? 'name:asc' : 'name:asc';
// Load sort preference
this.loadSortPreference();
}
@@ -326,14 +329,36 @@ export class PageControls {
loadSortPreference() {
const savedSort = getStorageItem(`${this.pageType}_sort`);
if (savedSort) {
this.pageState.sortBy = savedSort;
// Handle legacy format conversion
const convertedSort = this.convertLegacySortFormat(savedSort);
this.pageState.sortBy = convertedSort;
const sortSelect = document.getElementById('sortSelect');
if (sortSelect) {
sortSelect.value = savedSort;
sortSelect.value = convertedSort;
}
}
}
/**
* Convert legacy sort format to new format
* @param {string} sortValue - The sort value to convert
* @returns {string} - Converted sort value
*/
convertLegacySortFormat(sortValue) {
// Convert old format to new format with direction
switch (sortValue) {
case 'name':
return 'name:asc';
case 'date':
return 'date:desc'; // Newest first is more intuitive default
case 'size':
return 'size:desc'; // Largest first is more intuitive default
default:
// If it's already in new format or unknown, return as is
return sortValue.includes(':') ? sortValue : 'name:asc';
}
}
/**
* Save sort preference to storage
* @param {string} sortValue - The sort value to save

View File

@@ -87,7 +87,7 @@ export class DownloadManager {
throw new Error('Invalid Civitai URL format');
}
const response = await fetch(`/api/civitai/versions/${this.modelId}`);
const response = await fetch(`/api/loras/civitai/versions/${this.modelId}`);
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
@@ -254,7 +254,7 @@ export class DownloadManager {
try {
// Fetch LoRA roots
const rootsResponse = await fetch('/api/lora-roots');
const rootsResponse = await fetch('/api/loras/roots');
if (!rootsResponse.ok) {
throw new Error('Failed to fetch LoRA roots');
}
@@ -272,7 +272,7 @@ export class DownloadManager {
}
// Fetch folders dynamically
const foldersResponse = await fetch('/api/folders');
const foldersResponse = await fetch('/api/loras/folders');
if (!foldersResponse.ok) {
throw new Error('Failed to fetch folders');
}

View File

@@ -74,7 +74,7 @@ class MoveManager {
try {
// Fetch LoRA roots
const rootsResponse = await fetch('/api/lora-roots');
const rootsResponse = await fetch('/api/loras/roots');
if (!rootsResponse.ok) {
throw new Error('Failed to fetch LoRA roots');
}
@@ -96,7 +96,7 @@ class MoveManager {
}
// Fetch folders dynamically
const foldersResponse = await fetch('/api/folders');
const foldersResponse = await fetch('/api/loras/folders');
if (!foldersResponse.ok) {
throw new Error('Failed to fetch folders');
}
@@ -190,7 +190,7 @@ class MoveManager {
// Refresh folder tags after successful move
try {
const foldersResponse = await fetch('/api/folders');
const foldersResponse = await fetch('/api/loras/folders');
if (foldersResponse.ok) {
const foldersData = await foldersResponse.json();
updateFolderTags(foldersData.folders);

View File

@@ -161,7 +161,7 @@ export class SettingsManager {
if (!defaultLoraRootSelect) return;
// Fetch lora roots
const response = await fetch('/api/lora-roots');
const response = await fetch('/api/loras/roots');
if (!response.ok) {
throw new Error('Failed to fetch LoRA roots');
}

View File

@@ -3,7 +3,7 @@ import { getStorageItem, setStorageItem } from '../utils/storageHelpers.js';
export class UpdateService {
constructor() {
this.updateCheckInterval = 24 * 60 * 60 * 1000; // 24 hours
this.updateCheckInterval = 60 * 60 * 1000; // 1 hour
this.currentVersion = "v0.0.0"; // Initialize with default values
this.latestVersion = "v0.0.0"; // Initialize with default values
this.updateInfo = null;
@@ -13,8 +13,10 @@ export class UpdateService {
branch: "unknown",
commit_date: "unknown"
};
this.updateNotificationsEnabled = getStorageItem('show_update_notifications');
this.updateNotificationsEnabled = getStorageItem('show_update_notifications', true);
this.lastCheckTime = parseInt(getStorageItem('last_update_check') || '0');
this.isUpdating = false;
this.nightlyMode = getStorageItem('nightly_updates', false);
}
initialize() {
@@ -28,23 +30,44 @@ export class UpdateService {
this.updateBadgeVisibility();
});
}
const updateBtn = document.getElementById('updateBtn');
if (updateBtn) {
updateBtn.addEventListener('click', () => this.performUpdate());
}
// Register event listener for nightly update toggle
const nightlyCheckbox = document.getElementById('nightlyUpdateToggle');
if (nightlyCheckbox) {
nightlyCheckbox.checked = this.nightlyMode;
nightlyCheckbox.addEventListener('change', (e) => {
this.nightlyMode = e.target.checked;
setStorageItem('nightly_updates', e.target.checked);
this.updateNightlyWarning();
this.updateModalContent();
// Re-check for updates when switching channels
this.manualCheckForUpdates();
});
this.updateNightlyWarning();
}
// Perform update check if needed
this.checkForUpdates().then(() => {
// Ensure badges are updated after checking
this.updateBadgeVisibility();
});
// Set up event listener for update button
// const updateToggle = document.getElementById('updateToggleBtn');
// if (updateToggle) {
// updateToggle.addEventListener('click', () => this.toggleUpdateModal());
// }
// Immediately update modal content with current values (even if from default)
this.updateModalContent();
}
updateNightlyWarning() {
const warning = document.getElementById('nightlyWarning');
if (warning) {
warning.style.display = this.nightlyMode ? 'flex' : 'none';
}
}
async checkForUpdates() {
// Check if we should perform an update check
const now = Date.now();
@@ -59,8 +82,8 @@ export class UpdateService {
}
try {
// Call backend API to check for updates
const response = await fetch('/api/check-updates');
// Call backend API to check for updates with nightly flag
const response = await fetch(`/api/check-updates?nightly=${this.nightlyMode}`);
const data = await response.json();
if (data.success) {
@@ -137,8 +160,8 @@ export class UpdateService {
const shouldShow = this.updateNotificationsEnabled && this.updateAvailable;
if (updateBadge) {
updateBadge.classList.toggle('hidden', !shouldShow);
console.log("Update badge visibility:", !shouldShow ? "hidden" : "visible");
updateBadge.classList.toggle('visible', shouldShow);
console.log("Update badge visibility:", shouldShow ? "visible" : "hidden");
}
}
@@ -157,7 +180,17 @@ export class UpdateService {
const newVersionEl = modal.querySelector('.new-version .version-number');
if (currentVersionEl) currentVersionEl.textContent = this.currentVersion;
if (newVersionEl) newVersionEl.textContent = this.latestVersion;
if (newVersionEl) {
newVersionEl.textContent = this.latestVersion;
}
// Update the update button state
const updateBtn = modal.querySelector('#updateBtn');
if (updateBtn) {
updateBtn.classList.toggle('disabled', !this.updateAvailable || this.isUpdating);
updateBtn.disabled = !this.updateAvailable || this.isUpdating;
}
// Update git info
const gitInfoEl = modal.querySelector('.git-info');
@@ -218,6 +251,131 @@ export class UpdateService {
}
}
async performUpdate() {
if (!this.updateAvailable || this.isUpdating) {
return;
}
try {
this.isUpdating = true;
this.updateUpdateUI('updating', 'Updating...');
this.showUpdateProgress(true);
// Update progress
this.updateProgress(10, 'Preparing update...');
const response = await fetch('/api/perform-update', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
nightly: this.nightlyMode
})
});
this.updateProgress(50, 'Installing update...');
const data = await response.json();
if (data.success) {
this.updateProgress(100, 'Update completed successfully!');
this.updateUpdateUI('success', 'Updated!');
// Show success message and suggest restart
setTimeout(() => {
this.showUpdateCompleteMessage(data.new_version);
}, 1000);
} else {
throw new Error(data.error || 'Update failed');
}
} catch (error) {
console.error('Update failed:', error);
this.updateUpdateUI('error', 'Update Failed');
this.updateProgress(0, `Update failed: ${error.message}`);
// Hide progress after error
setTimeout(() => {
this.showUpdateProgress(false);
}, 3000);
} finally {
this.isUpdating = false;
}
}
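// performUpdate() implies a small contract for /api/perform-update: the
// request carries { nightly } and the response is { success, new_version }
// or { success: false, error }. A sketch of that exchange in isolation —
// the field names are read off the handler above, not a documented API:
async function requestUpdate(nightly) {
    const response = await fetch('/api/perform-update', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ nightly })
    });
    const data = await response.json();
    if (!data.success) throw new Error(data.error || 'Update failed');
    return data.new_version;
}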
updateUpdateUI(state, text) {
const updateBtn = document.getElementById('updateBtn');
const updateBtnText = document.getElementById('updateBtnText');
if (updateBtn && updateBtnText) {
// Remove existing state classes
updateBtn.classList.remove('updating', 'success', 'error', 'disabled');
// Add new state class
if (state !== 'normal') {
updateBtn.classList.add(state);
}
// Update button text
updateBtnText.textContent = text;
// Update disabled state
updateBtn.disabled = (state === 'updating' || state === 'disabled');
}
}
showUpdateProgress(show) {
const progressContainer = document.getElementById('updateProgress');
if (progressContainer) {
progressContainer.style.display = show ? 'block' : 'none';
}
}
updateProgress(percentage, text) {
const progressFill = document.getElementById('updateProgressFill');
const progressText = document.getElementById('updateProgressText');
if (progressFill) {
progressFill.style.width = `${percentage}%`;
}
if (progressText) {
progressText.textContent = text;
}
}
showUpdateCompleteMessage(newVersion) {
const modal = document.getElementById('updateModal');
if (!modal) return;
// Update the modal content to show completion
const progressText = document.getElementById('updateProgressText');
if (progressText) {
progressText.innerHTML = `
<div style="text-align: center; color: var(--lora-success);">
<i class="fas fa-check-circle" style="margin-right: 8px;"></i>
Successfully updated to ${newVersion}!
<br><br>
<small style="opacity: 0.8;">
Please restart ComfyUI to complete the update process.
</small>
</div>
`;
}
// Update current version display
this.currentVersion = newVersion;
this.updateAvailable = false;
// Refresh the modal content
setTimeout(() => {
this.updateModalContent();
this.showUpdateProgress(false);
}, 2000);
}
// Simple markdown parser for changelog items
parseMarkdown(text) {
if (!text) return '';
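// The parser body is cut off by the diff context here. A minimal sketch of
// a changelog-oriented parseMarkdown — escaping plus bold, inline code, and
// links is usually enough for release notes; the patterns below are
// illustrative, not the shipped implementation:
function parseMarkdown(text) {
    if (!text) return '';
    return text
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/\*\*(.+?)\*\*/g, '<strong>$1</strong>')
        .replace(/`([^`]+)`/g, '<code>$1</code>')
        .replace(/\[([^\]]+)\]\(([^)]+)\)/g, '<a href="$2" target="_blank">$1</a>');
}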

View File

@@ -99,7 +99,7 @@ export class FolderBrowser {
}
// Fetch LoRA roots
const rootsResponse = await fetch('/api/lora-roots');
const rootsResponse = await fetch('/api/loras/roots');
if (!rootsResponse.ok) {
throw new Error(`Failed to fetch LoRA roots: ${rootsResponse.status}`);
}
@@ -119,7 +119,7 @@ export class FolderBrowser {
}
// Fetch folders
const foldersResponse = await fetch('/api/folders');
const foldersResponse = await fetch('/api/loras/folders');
if (!foldersResponse.ok) {
throw new Error(`Failed to fetch folders: ${foldersResponse.status}`);
}
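// FolderBrowser repeats the fetch-then-check pattern, this time surfacing
// the HTTP status in the error. A small guard helper that both call sites
// could share — the helper name is illustrative:
async function fetchJsonOrThrow(url, label) {
    const response = await fetch(url);
    if (!response.ok) {
        throw new Error(`Failed to fetch ${label}: ${response.status}`);
    }
    return response.json();
}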

View File

@@ -11,8 +11,18 @@
<div class="action-buttons">
<div title="Sort models by..." class="control-group">
<select id="sortSelect">
<option value="name">Name</option>
<option value="date">Date</option>
<optgroup label="Name">
<option value="name:asc">A - Z</option>
<option value="name:desc">Z - A</option>
</optgroup>
<optgroup label="Date Added">
<option value="date:desc">Newest</option>
<option value="date:asc">Oldest</option>
</optgroup>
<optgroup label="File Size">
<option value="size:desc">Largest</option>
<option value="size:asc">Smallest</option>
</optgroup>
</select>
</div>
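<!-- The sort options are now compound "field:order" tokens. A sketch of a
     change handler that splits them before requesting a re-sort; the
     reloadModels helper and its parameter names are hypothetical: -->
<script>
document.getElementById('sortSelect').addEventListener('change', (e) => {
    const [sortBy, order] = e.target.value.split(':');
    reloadModels({ sort_by: sortBy, order }); // hypothetical reload helper
});
</script>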
<div title="Refresh model list" class="control-group dropdown-group">

View File

@@ -53,7 +53,7 @@
</div>
<div class="update-toggle" id="updateToggleBtn" title="Check Updates">
<i class="fas fa-bell"></i>
<span class="update-badge hidden"></span>
<span class="update-badge"></span>
</div>
<div class="support-toggle" id="supportToggleBtn" title="Support">
<i class="fas fa-heart"></i>

View File

@@ -476,9 +476,26 @@
<span class="version-number">v0.0.0</span>
</div>
</div>
<a href="https://github.com/willmiao/ComfyUI-Lora-Manager" target="_blank" class="update-link">
<i class="fas fa-external-link-alt"></i> View on GitHub
</a>
<div class="update-actions">
<a href="https://github.com/willmiao/ComfyUI-Lora-Manager" target="_blank" class="update-link">
<i class="fas fa-external-link-alt"></i> View on GitHub
</a>
<button id="updateBtn" class="primary-btn disabled">
<i class="fas fa-download"></i>
<span id="updateBtnText">Update Now</span>
</button>
</div>
</div>
<!-- Update Progress Section -->
<div class="update-progress" id="updateProgress" style="display: none;">
<div class="progress-info">
<div class="progress-text" id="updateProgressText">Preparing update...</div>
<div class="progress-bar">
<div class="progress-fill" id="updateProgressFill"></div>
</div>
</div>
</div>
<div class="changelog-section">