mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
@@ -1,71 +0,0 @@
|
||||
# Path mappings configuration for ComfyUI-Lora-Manager
|
||||
# This file allows you to customize how base models and Civitai tags map to directories when downloading models
|
||||
|
||||
# Base model mappings
|
||||
# Format: "Original Base Model": "Custom Directory Name"
|
||||
#
|
||||
# Example: If you change "Flux.1 D": "flux"
|
||||
# Then models with base model "Flux.1 D" will be stored in a directory named "flux"
|
||||
# So the final path would be: <model_directory>/flux/<tag>/model_file.safetensors
|
||||
base_models:
|
||||
"SD 1.4": "SD 1.4"
|
||||
"SD 1.5": "SD 1.5"
|
||||
"SD 1.5 LCM": "SD 1.5 LCM"
|
||||
"SD 1.5 Hyper": "SD 1.5 Hyper"
|
||||
"SD 2.0": "SD 2.0"
|
||||
"SD 2.1": "SD 2.1"
|
||||
"SDXL 1.0": "SDXL 1.0"
|
||||
"SD 3": "SD 3"
|
||||
"SD 3.5": "SD 3.5"
|
||||
"SD 3.5 Medium": "SD 3.5 Medium"
|
||||
"SD 3.5 Large": "SD 3.5 Large"
|
||||
"SD 3.5 Large Turbo": "SD 3.5 Large Turbo"
|
||||
"Pony": "Pony"
|
||||
"Flux.1 S": "Flux.1 S"
|
||||
"Flux.1 D": "Flux.1 D"
|
||||
"Flux.1 Kontext": "Flux.1 Kontext"
|
||||
"AuraFlow": "AuraFlow"
|
||||
"SDXL Lightning": "SDXL Lightning"
|
||||
"SDXL Hyper": "SDXL Hyper"
|
||||
"Stable Cascade": "Stable Cascade"
|
||||
"SVD": "SVD"
|
||||
"PixArt a": "PixArt a"
|
||||
"PixArt E": "PixArt E"
|
||||
"Hunyuan 1": "Hunyuan 1"
|
||||
"Hunyuan Video": "Hunyuan Video"
|
||||
"Lumina": "Lumina"
|
||||
"Kolors": "Kolors"
|
||||
"Illustrious": "Illustrious"
|
||||
"Mochi": "Mochi"
|
||||
"LTXV": "LTXV"
|
||||
"CogVideoX": "CogVideoX"
|
||||
"NoobAI": "NoobAI"
|
||||
"Wan Video": "Wan Video"
|
||||
"Wan Video 1.3B t2v": "Wan Video 1.3B t2v"
|
||||
"Wan Video 14B t2v": "Wan Video 14B t2v"
|
||||
"Wan Video 14B i2v 480p": "Wan Video 14B i2v 480p"
|
||||
"Wan Video 14B i2v 720p": "Wan Video 14B i2v 720p"
|
||||
"HiDream": "HiDream"
|
||||
"Other": "Other"
|
||||
|
||||
# Civitai model tag mappings
|
||||
# Format: "Original Tag": "Custom Directory Name"
|
||||
#
|
||||
# Example: If you change "character": "characters"
|
||||
# Then models with tag "character" will be stored in a directory named "characters"
|
||||
# So the final path would be: <model_directory>/<base_model>/characters/model_file.safetensors
|
||||
model_tags:
|
||||
"character": "character"
|
||||
"style": "style"
|
||||
"concept": "concept"
|
||||
"clothing": "clothing"
|
||||
"base model": "base model"
|
||||
"poses": "poses"
|
||||
"background": "background"
|
||||
"tool": "tool"
|
||||
"vehicle": "vehicle"
|
||||
"buildings": "buildings"
|
||||
"objects": "objects"
|
||||
"assets": "assets"
|
||||
"animal": "animal"
|
||||
"action": "action"
|
||||
@@ -6,10 +6,8 @@ from pathlib import Path
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .config import config
|
||||
from .routes.lora_routes import LoraRoutes
|
||||
from .routes.api_routes import ApiRoutes
|
||||
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||
from .routes.recipe_routes import RecipeRoutes
|
||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||
from .routes.stats_routes import StatsRoutes
|
||||
from .routes.update_routes import UpdateRoutes
|
||||
from .routes.misc_routes import MiscRoutes
|
||||
@@ -17,6 +15,7 @@ from .routes.example_images_routes import ExampleImagesRoutes
|
||||
from .services.service_registry import ServiceRegistry
|
||||
from .services.settings_manager import settings
|
||||
from .utils.example_images_migration import ExampleImagesMigration
|
||||
from .services.websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,12 +27,28 @@ class LoraManager:
|
||||
|
||||
@classmethod
|
||||
def add_routes(cls):
|
||||
"""Initialize and register all routes"""
|
||||
"""Initialize and register all routes using the new refactored architecture"""
|
||||
app = PromptServer.instance.app
|
||||
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Add specific suppression for connection reset errors
|
||||
class ConnectionResetFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
# Filter out connection reset errors that are not critical
|
||||
if "ConnectionResetError" in str(record.getMessage()):
|
||||
return False
|
||||
if "_call_connection_lost" in str(record.getMessage()):
|
||||
return False
|
||||
if "WinError 10054" in str(record.getMessage()):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Apply the filter to asyncio logger
|
||||
asyncio_logger = logging.getLogger("asyncio")
|
||||
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||
|
||||
added_targets = set() # Track already added target paths
|
||||
|
||||
# Add static route for example images if the path exists in settings
|
||||
@@ -110,35 +125,37 @@ class LoraManager:
|
||||
# Add static route for plugin assets
|
||||
app.router.add_static('/loras_static', config.static_path)
|
||||
|
||||
# Setup feature routes
|
||||
lora_routes = LoraRoutes()
|
||||
checkpoints_routes = CheckpointsRoutes()
|
||||
stats_routes = StatsRoutes()
|
||||
# Register default model types with the factory
|
||||
register_default_model_types()
|
||||
|
||||
# Initialize routes
|
||||
lora_routes.setup_routes(app)
|
||||
checkpoints_routes.setup_routes(app)
|
||||
stats_routes.setup_routes(app) # Add statistics routes
|
||||
ApiRoutes.setup_routes(app)
|
||||
# Setup all model routes using the factory
|
||||
ModelServiceFactory.setup_all_routes(app)
|
||||
|
||||
# Setup non-model-specific routes
|
||||
stats_routes = StatsRoutes()
|
||||
stats_routes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app) # Register miscellaneous routes
|
||||
ExampleImagesRoutes.setup_routes(app) # Register example images routes
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
# Add cleanup
|
||||
app.on_shutdown.append(cls._cleanup)
|
||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||
|
||||
logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}")
|
||||
|
||||
@classmethod
|
||||
async def _initialize_services(cls):
|
||||
"""Initialize all services using the ServiceRegistry"""
|
||||
try:
|
||||
# Ensure aiohttp access logger is configured with reduced verbosity
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||
await ServiceRegistry.get_civitai_client()
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
619
py/routes/base_model_routes.py
Normal file
619
py/routes/base_model_routes.py
Normal file
@@ -0,0 +1,619 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from typing import Dict
|
||||
|
||||
import jinja2
|
||||
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.settings_manager import settings
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseModelRoutes(ABC):
|
||||
"""Base route controller for all model types"""
|
||||
|
||||
def __init__(self, service):
|
||||
"""Initialize the route controller
|
||||
|
||||
Args:
|
||||
service: Model service instance (LoraService, CheckpointService, etc.)
|
||||
"""
|
||||
self.service = service
|
||||
self.model_type = service.model_type
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
def setup_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup common routes for the model type
|
||||
|
||||
Args:
|
||||
app: aiohttp application
|
||||
prefix: URL prefix (e.g., 'loras', 'checkpoints')
|
||||
"""
|
||||
# Common model management routes
|
||||
app.router.add_get(f'/api/{prefix}', self.get_models)
|
||||
app.router.add_post(f'/api/{prefix}/delete', self.delete_model)
|
||||
app.router.add_post(f'/api/{prefix}/exclude', self.exclude_model)
|
||||
app.router.add_post(f'/api/{prefix}/fetch-civitai', self.fetch_civitai)
|
||||
app.router.add_post(f'/api/{prefix}/relink-civitai', self.relink_civitai)
|
||||
app.router.add_post(f'/api/{prefix}/replace-preview', self.replace_preview)
|
||||
app.router.add_post(f'/api/{prefix}/save-metadata', self.save_metadata)
|
||||
app.router.add_post(f'/api/{prefix}/rename', self.rename_model)
|
||||
app.router.add_post(f'/api/{prefix}/bulk-delete', self.bulk_delete_models)
|
||||
app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates)
|
||||
|
||||
# Common query routes
|
||||
app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags)
|
||||
app.router.add_get(f'/api/{prefix}/base-models', self.get_base_models)
|
||||
app.router.add_get(f'/api/{prefix}/scan', self.scan_models)
|
||||
app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots)
|
||||
app.router.add_get(f'/api/{prefix}/folders', self.get_folders)
|
||||
app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models)
|
||||
app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts)
|
||||
|
||||
# Common Download management
|
||||
app.router.add_post(f'/api/download-model', self.download_model)
|
||||
app.router.add_get(f'/api/download-model-get', self.download_model_get)
|
||||
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
|
||||
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
|
||||
|
||||
# CivitAI integration routes
|
||||
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
|
||||
# app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
|
||||
|
||||
# Add generic page route
|
||||
app.router.add_get(f'/{prefix}', self.handle_models_page)
|
||||
|
||||
# Setup model-specific routes
|
||||
self.setup_specific_routes(app, prefix)
|
||||
|
||||
@abstractmethod
|
||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup model-specific routes - to be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
async def handle_models_page(self, request: web.Request) -> web.Response:
|
||||
"""
|
||||
Generic handler for model pages (e.g., /loras, /checkpoints).
|
||||
Subclasses should set self.template_env and template_name.
|
||||
"""
|
||||
try:
|
||||
# Check if the scanner is initializing
|
||||
is_initializing = (
|
||||
self.service.scanner._cache is None or
|
||||
(hasattr(self.service.scanner, 'is_initializing') and callable(self.service.scanner.is_initializing) and self.service.scanner.is_initializing()) or
|
||||
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
|
||||
)
|
||||
|
||||
template_name = getattr(self, "template_name", None)
|
||||
if not self.template_env or not template_name:
|
||||
return web.Response(text="Template environment or template name not set", status=500)
|
||||
|
||||
if is_initializing:
|
||||
rendered = self.template_env.get_template(template_name).render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
else:
|
||||
try:
|
||||
cache = await self.service.scanner.get_cached_data(force_refresh=False)
|
||||
rendered = self.template_env.get_template(template_name).render(
|
||||
folders=getattr(cache, "folders", []),
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading cache data: {cache_error}")
|
||||
rendered = self.template_env.get_template(template_name).render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling models page: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading models page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def get_models(self, request: web.Request) -> web.Response:
|
||||
"""Get paginated model data"""
|
||||
try:
|
||||
# Parse common query parameters
|
||||
params = self._parse_common_params(request)
|
||||
|
||||
# Get data from service
|
||||
result = await self.service.get_paginated_data(**params)
|
||||
|
||||
# Format response items
|
||||
formatted_result = {
|
||||
'items': [await self.service.format_response(item) for item in result['items']],
|
||||
'total': result['total'],
|
||||
'page': result['page'],
|
||||
'page_size': result['page_size'],
|
||||
'total_pages': result['total_pages']
|
||||
}
|
||||
|
||||
return web.json_response(formatted_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_{self.model_type}s: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
def _parse_common_params(self, request: web.Request) -> Dict:
|
||||
"""Parse common query parameters"""
|
||||
# Parse basic pagination and sorting
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = min(int(request.query.get('page_size', '20')), 100)
|
||||
sort_by = request.query.get('sort_by', 'name')
|
||||
folder = request.query.get('folder', None)
|
||||
search = request.query.get('search', None)
|
||||
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
||||
|
||||
# Parse filter arrays
|
||||
base_models = request.query.getall('base_model', [])
|
||||
tags = request.query.getall('tag', [])
|
||||
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true'
|
||||
|
||||
# Parse search options
|
||||
search_options = {
|
||||
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
||||
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
||||
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
||||
'recursive': request.query.get('recursive', 'false').lower() == 'true',
|
||||
}
|
||||
|
||||
# Parse hash filters if provided
|
||||
hash_filters = {}
|
||||
if 'hash' in request.query:
|
||||
hash_filters['single_hash'] = request.query['hash']
|
||||
elif 'hashes' in request.query:
|
||||
try:
|
||||
hash_list = json.loads(request.query['hashes'])
|
||||
if isinstance(hash_list, list):
|
||||
hash_filters['multiple_hashes'] = hash_list
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
return {
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'sort_by': sort_by,
|
||||
'folder': folder,
|
||||
'search': search,
|
||||
'fuzzy_search': fuzzy_search,
|
||||
'base_models': base_models,
|
||||
'tags': tags,
|
||||
'search_options': search_options,
|
||||
'hash_filters': hash_filters,
|
||||
'favorites_only': favorites_only,
|
||||
# Add model-specific parameters
|
||||
**self._parse_specific_params(request)
|
||||
}
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse model-specific parameters - to be overridden by subclasses"""
|
||||
return {}
|
||||
|
||||
# Common route handlers
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model deletion request"""
|
||||
return await ModelRouteUtils.handle_delete_model(request, self.service.scanner)
|
||||
|
||||
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model exclusion request"""
|
||||
return await ModelRouteUtils.handle_exclude_model(request, self.service.scanner)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request"""
|
||||
response = await ModelRouteUtils.handle_fetch_civitai(request, self.service.scanner)
|
||||
|
||||
# If successful, format the metadata before returning
|
||||
if response.status == 200:
|
||||
data = json.loads(response.body.decode('utf-8'))
|
||||
if data.get("success") and data.get("metadata"):
|
||||
formatted_metadata = await self.service.format_response(data["metadata"])
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"metadata": formatted_metadata
|
||||
})
|
||||
|
||||
return response
|
||||
|
||||
async def relink_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata re-linking request"""
|
||||
return await ModelRouteUtils.handle_relink_civitai(request, self.service.scanner)
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement"""
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self.service.scanner)
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
"""Handle saving metadata updates"""
|
||||
return await ModelRouteUtils.handle_save_metadata(request, self.service.scanner)
|
||||
|
||||
async def rename_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle renaming a model file and its associated files"""
|
||||
return await ModelRouteUtils.handle_rename_model(request, self.service.scanner)
|
||||
|
||||
async def bulk_delete_models(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk deletion of models"""
|
||||
return await ModelRouteUtils.handle_bulk_delete_models(request, self.service.scanner)
|
||||
|
||||
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
||||
"""Handle verification of duplicate model hashes"""
|
||||
return await ModelRouteUtils.handle_verify_duplicates(request, self.service.scanner)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20
|
||||
|
||||
top_tags = await self.service.get_top_tags(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': top_tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Internal server error'
|
||||
}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in models"""
|
||||
try:
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20
|
||||
|
||||
base_models = await self.service.get_base_models(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'base_models': base_models
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving base models: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def scan_models(self, request: web.Request) -> web.Response:
|
||||
"""Force a rescan of model files"""
|
||||
try:
|
||||
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
|
||||
|
||||
await self.service.scan_models(force_refresh=True, rebuild_cache=full_rebuild)
|
||||
return web.json_response({
|
||||
"status": "success",
|
||||
"message": f"{self.model_type.capitalize()} scan completed"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in scan_{self.model_type}s: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_model_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the model root directories"""
|
||||
try:
|
||||
roots = self.service.get_model_roots()
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting {self.model_type} roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_folders(self, request: web.Request) -> web.Response:
|
||||
"""Get all folders in the cache"""
|
||||
try:
|
||||
cache = await self.service.scanner.get_cached_data()
|
||||
return web.json_response({
|
||||
'folders': cache.folders
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting folders: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def find_duplicate_models(self, request: web.Request) -> web.Response:
|
||||
"""Find models with duplicate SHA256 hashes"""
|
||||
try:
|
||||
# Get duplicate hashes from service
|
||||
duplicates = self.service.find_duplicate_hashes()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.service.scanner.get_cached_data()
|
||||
|
||||
for sha256, paths in duplicates.items():
|
||||
group = {
|
||||
"hash": sha256,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(await self.service.format_response(model))
|
||||
|
||||
# Add the primary model too
|
||||
primary_path = self.service.get_path_by_hash(sha256)
|
||||
if primary_path and primary_path not in paths:
|
||||
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
|
||||
if primary_model:
|
||||
group["models"].insert(0, await self.service.format_response(primary_model))
|
||||
|
||||
if len(group["models"]) > 1: # Only include if we found multiple models
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"duplicates": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding duplicate {self.model_type}s: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
|
||||
"""Find models with conflicting filenames"""
|
||||
try:
|
||||
# Get duplicate filenames from service
|
||||
duplicates = self.service.find_duplicate_filenames()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.service.scanner.get_cached_data()
|
||||
|
||||
for filename, paths in duplicates.items():
|
||||
group = {
|
||||
"filename": filename,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(await self.service.format_response(model))
|
||||
|
||||
# Find the model from the main index too
|
||||
hash_val = self.service.scanner._hash_index.get_hash_by_filename(filename)
|
||||
if hash_val:
|
||||
main_path = self.service.get_path_by_hash(hash_val)
|
||||
if main_path and main_path not in paths:
|
||||
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
|
||||
if main_model:
|
||||
group["models"].insert(0, await self.service.format_response(main_model))
|
||||
|
||||
if group["models"]:
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"conflicts": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding filename conflicts for {self.model_type}s: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
# Download management methods
|
||||
async def download_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model download request"""
|
||||
return await ModelRouteUtils.handle_download_model(request)
|
||||
|
||||
async def download_model_get(self, request: web.Request) -> web.Response:
|
||||
"""Handle model download request via GET method"""
|
||||
try:
|
||||
# Extract query parameters
|
||||
model_id = request.query.get('model_id')
|
||||
if not model_id:
|
||||
return web.Response(
|
||||
status=400,
|
||||
text="Missing required parameter: Please provide 'model_id'"
|
||||
)
|
||||
|
||||
# Get optional parameters
|
||||
model_version_id = request.query.get('model_version_id')
|
||||
download_id = request.query.get('download_id')
|
||||
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
|
||||
|
||||
# Create a data dictionary that mimics what would be received from a POST request
|
||||
data = {
|
||||
'model_id': model_id
|
||||
}
|
||||
|
||||
# Add optional parameters only if they are provided
|
||||
if model_version_id:
|
||||
data['model_version_id'] = model_version_id
|
||||
|
||||
if download_id:
|
||||
data['download_id'] = download_id
|
||||
|
||||
data['use_default_paths'] = use_default_paths
|
||||
|
||||
# Create a mock request object with the data
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
future.set_result(data)
|
||||
|
||||
mock_request = type('MockRequest', (), {
|
||||
'json': lambda self=None: future
|
||||
})()
|
||||
|
||||
# Call the existing download handler
|
||||
return await ModelRouteUtils.handle_download_model(mock_request)
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
async def cancel_download_get(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET request for cancelling a download by download_id"""
|
||||
try:
|
||||
download_id = request.query.get('download_id')
|
||||
if not download_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download ID is required'
|
||||
}, status=400)
|
||||
|
||||
# Create a mock request with match_info for compatibility
|
||||
mock_request = type('MockRequest', (), {
|
||||
'match_info': {'download_id': download_id}
|
||||
})()
|
||||
return await ModelRouteUtils.handle_cancel_download(mock_request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling download via GET: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_download_progress(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for download progress by download_id"""
|
||||
try:
|
||||
# Get download_id from URL path
|
||||
download_id = request.match_info.get('download_id')
|
||||
if not download_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download ID is required'
|
||||
}, status=400)
|
||||
|
||||
progress_data = ws_manager.get_download_progress(download_id)
|
||||
|
||||
if progress_data is None:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download ID not found'
|
||||
}, status=404)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'progress': progress_data.get('progress', 0)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting download progress: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all models in the background"""
|
||||
try:
|
||||
cache = await self.service.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
# Prepare models to process
|
||||
to_process = [
|
||||
model for model in cache.raw_data
|
||||
if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) and model.get('from_civitai', True)
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
# Send initial progress
|
||||
await ws_manager.broadcast({
|
||||
'status': 'started',
|
||||
'total': total_to_process,
|
||||
'processed': 0,
|
||||
'success': 0
|
||||
})
|
||||
|
||||
# Process each model
|
||||
for model in to_process:
|
||||
try:
|
||||
original_name = model.get('model_name')
|
||||
if await ModelRouteUtils.fetch_and_update_model(
|
||||
sha256=model['sha256'],
|
||||
file_path=model['file_path'],
|
||||
model_data=model,
|
||||
update_cache_func=self.service.scanner.update_single_model_cache
|
||||
):
|
||||
success += 1
|
||||
if original_name != model.get('model_name'):
|
||||
needs_resort = True
|
||||
|
||||
processed += 1
|
||||
|
||||
# Send progress update
|
||||
await ws_manager.broadcast({
|
||||
'status': 'processing',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success,
|
||||
'current_name': model.get('model_name', 'Unknown')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for {model['file_path']}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
# Send completion message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'completed',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed {self.model_type}s (total: {total})"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Send error message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
logger.error(f"Error in fetch_all_civitai for {self.model_type}s: {e}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai model with local availability info"""
|
||||
# This will be implemented by subclasses as they need CivitAI client access
|
||||
return web.json_response({
|
||||
"error": "Not implemented in base class"
|
||||
}, status=501)
|
||||
105
py/routes/checkpoint_routes.py
Normal file
105
py/routes/checkpoint_routes.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointRoutes(BaseModelRoutes):
|
||||
"""Checkpoint-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Checkpoint routes with Checkpoint service"""
|
||||
# Service will be initialized later via setup_routes
|
||||
self.service = None
|
||||
self.civitai_client = None
|
||||
self.template_name = "checkpoints.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.service = CheckpointService(checkpoint_scanner)
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Initialize parent with the service
|
||||
super().__init__(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Checkpoint routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'checkpoints' prefix (includes page route)
|
||||
super().setup_routes(app, 'checkpoints')
|
||||
|
||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup Checkpoint-specific routes"""
|
||||
# Checkpoint-specific CivitAI integration
|
||||
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
|
||||
|
||||
# Checkpoint info by name
|
||||
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
|
||||
|
||||
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific checkpoint by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
checkpoint_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if checkpoint_info:
|
||||
return web.json_response(checkpoint_info)
|
||||
else:
|
||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||
try:
|
||||
model_id = request.match_info['model_id']
|
||||
response = await self.civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be Checkpoint
|
||||
if model_type.lower() != 'checkpoint':
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# If no primary file found, try to find any model file
|
||||
if not model_file:
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.service.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching checkpoint model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
@@ -1,771 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointsRoutes:
|
||||
"""API routes for checkpoint management"""
|
||||
|
||||
def __init__(self):
|
||||
self.scanner = None # Will be initialized in setup_routes
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
self.download_manager = None # Will be initialized in setup_routes
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
def setup_routes(self, app):
|
||||
"""Register routes with the aiohttp app"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
||||
app.router.add_get('/api/checkpoints', self.get_checkpoints)
|
||||
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
|
||||
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
|
||||
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
|
||||
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
|
||||
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
|
||||
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
|
||||
app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route
|
||||
|
||||
# Add new routes for model management similar to LoRA routes
|
||||
app.router.add_post('/api/checkpoints/delete', self.delete_model)
|
||||
app.router.add_post('/api/checkpoints/exclude', self.exclude_model) # Add new exclude endpoint
|
||||
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
|
||||
app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint
|
||||
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
|
||||
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
|
||||
app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint
|
||||
|
||||
# Add new routes for finding duplicates and filename conflicts
|
||||
app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints)
|
||||
app.router.add_get('/api/checkpoints/find-filename-conflicts', self.find_filename_conflicts)
|
||||
|
||||
# Add new endpoint for bulk deleting checkpoints
|
||||
app.router.add_post('/api/checkpoints/bulk-delete', self.bulk_delete_checkpoints)
|
||||
|
||||
# Add new endpoint for verifying duplicates
|
||||
app.router.add_post('/api/checkpoints/verify-duplicates', self.verify_duplicates)
|
||||
|
||||
async def get_checkpoints(self, request):
|
||||
"""Get paginated checkpoint data"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = min(int(request.query.get('page_size', '20')), 100)
|
||||
sort_by = request.query.get('sort_by', 'name')
|
||||
folder = request.query.get('folder', None)
|
||||
search = request.query.get('search', None)
|
||||
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
||||
base_models = request.query.getall('base_model', [])
|
||||
tags = request.query.getall('tag', [])
|
||||
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # Add favorites_only parameter
|
||||
|
||||
# Process search options
|
||||
search_options = {
|
||||
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
||||
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
||||
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
||||
'recursive': request.query.get('recursive', 'false').lower() == 'true',
|
||||
}
|
||||
|
||||
# Process hash filters if provided
|
||||
hash_filters = {}
|
||||
if 'hash' in request.query:
|
||||
hash_filters['single_hash'] = request.query['hash']
|
||||
elif 'hashes' in request.query:
|
||||
try:
|
||||
hash_list = json.loads(request.query['hashes'])
|
||||
if isinstance(hash_list, list):
|
||||
hash_filters['multiple_hashes'] = hash_list
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Get data from scanner
|
||||
result = await self.get_paginated_data(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
folder=folder,
|
||||
search=search,
|
||||
fuzzy_search=fuzzy_search,
|
||||
base_models=base_models,
|
||||
tags=tags,
|
||||
search_options=search_options,
|
||||
hash_filters=hash_filters,
|
||||
favorites_only=favorites_only # Pass favorites_only parameter
|
||||
)
|
||||
|
||||
# Format response items
|
||||
formatted_result = {
|
||||
'items': [self._format_checkpoint_response(cp) for cp in result['items']],
|
||||
'total': result['total'],
|
||||
'page': result['page'],
|
||||
'page_size': result['page_size'],
|
||||
'total_pages': result['total_pages']
|
||||
}
|
||||
|
||||
# Return as JSON
|
||||
return web.json_response(formatted_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_paginated_data(self, page, page_size, sort_by='name',
|
||||
folder=None, search=None, fuzzy_search=False,
|
||||
base_models=None, tags=None,
|
||||
search_options=None, hash_filters=None,
|
||||
favorites_only=False): # Add favorites_only parameter with default False
|
||||
"""Get paginated and filtered checkpoint data"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
# Jump to pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# Apply SFW filtering if enabled in settings
|
||||
if settings.get('show_only_sfw', False):
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if any(tag in cp.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
search_results = []
|
||||
|
||||
for cp in filtered_data:
|
||||
# Search by file name
|
||||
if search_options.get('filename', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(cp.get('file_name', ''), search):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
elif search.lower() in cp.get('file_name', '').lower():
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_options.get('modelname', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(cp.get('model_name', ''), search):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
elif search.lower() in cp.get('model_name', '').lower():
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_options.get('tags', False) and 'tags' in cp:
|
||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
filtered_data = search_results
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _format_checkpoint_response(self, checkpoint):
|
||||
"""Format checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint["model_name"],
|
||||
"file_name": checkpoint["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint.get("base_model", ""),
|
||||
"folder": checkpoint["folder"],
|
||||
"sha256": checkpoint.get("sha256", ""),
|
||||
"file_path": checkpoint["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint.get("size", 0),
|
||||
"modified": checkpoint.get("modified", ""),
|
||||
"tags": checkpoint.get("tags", []),
|
||||
"modelDescription": checkpoint.get("modelDescription", ""),
|
||||
"from_civitai": checkpoint.get("from_civitai", True),
|
||||
"notes": checkpoint.get("notes", ""),
|
||||
"model_type": checkpoint.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
|
||||
}
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all checkpoints in the background"""
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
# Prepare checkpoints to process
|
||||
to_process = [
|
||||
cp for cp in cache.raw_data
|
||||
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
# Send initial progress
|
||||
await ws_manager.broadcast({
|
||||
'status': 'started',
|
||||
'total': total_to_process,
|
||||
'processed': 0,
|
||||
'success': 0
|
||||
})
|
||||
|
||||
# Process each checkpoint
|
||||
for cp in to_process:
|
||||
try:
|
||||
original_name = cp.get('model_name')
|
||||
if await ModelRouteUtils.fetch_and_update_model(
|
||||
sha256=cp['sha256'],
|
||||
file_path=cp['file_path'],
|
||||
model_data=cp,
|
||||
update_cache_func=self.scanner.update_single_model_cache
|
||||
):
|
||||
success += 1
|
||||
if original_name != cp.get('model_name'):
|
||||
needs_resort = True
|
||||
|
||||
processed += 1
|
||||
|
||||
# Send progress update
|
||||
await ws_manager.broadcast({
|
||||
'status': 'processing',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success,
|
||||
'current_name': cp.get('model_name', 'Unknown')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
# Send completion message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'completed',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Send error message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get top tags
|
||||
top_tags = await self.scanner.get_top_tags(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': top_tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Internal server error'
|
||||
}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in loras"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get base models
|
||||
base_models = await self.scanner.get_base_models(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'base_models': base_models
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving base models: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def scan_checkpoints(self, request):
|
||||
"""Force a rescan of checkpoint files"""
|
||||
try:
|
||||
# Get the full_rebuild parameter and convert to bool, default to False
|
||||
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
|
||||
|
||||
await self.scanner.get_cached_data(force_refresh=True, rebuild_cache=full_rebuild)
|
||||
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_checkpoint_info(self, request):
|
||||
"""Get detailed information for a specific checkpoint by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
checkpoint_info = await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
if checkpoint_info:
|
||||
return web.json_response(checkpoint_info)
|
||||
else:
|
||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /checkpoints request"""
|
||||
try:
|
||||
# Check if the CheckpointScanner is initializing
|
||||
# It's initializing if the cache object doesn't exist yet,
|
||||
# OR if the scanner explicitly says it's initializing (background task running).
|
||||
is_initializing = (
|
||||
self.scanner._cache is None or
|
||||
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# If still initializing, return loading page
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=[], # 空文件夹列表
|
||||
is_initializing=True, # 新增标志
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
)
|
||||
|
||||
logger.info("Checkpoints page is initializing, returning loading page")
|
||||
else:
|
||||
# 正常流程 - 获取已经初始化好的缓存数据
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading checkpoints cache data: {cache_error}")
|
||||
# 如果获取缓存失败,也显示初始化页面
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Checkpoints cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading checkpoints page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint model deletion request"""
|
||||
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
||||
|
||||
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint model exclusion request"""
|
||||
return await ModelRouteUtils.handle_exclude_model(request, self.scanner)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request for checkpoints"""
|
||||
response = await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
|
||||
|
||||
# If successful, format the metadata before returning
|
||||
if response.status == 200:
|
||||
data = json.loads(response.body.decode('utf-8'))
|
||||
if data.get("success") and data.get("metadata"):
|
||||
formatted_metadata = self._format_checkpoint_response(data["metadata"])
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"metadata": formatted_metadata
|
||||
})
|
||||
|
||||
# Otherwise, return the original response
|
||||
return response
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement for checkpoints"""
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
||||
|
||||
async def get_checkpoint_roots(self, request):
|
||||
"""Return the checkpoint root directories"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
roots = self.scanner.get_model_roots()
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
"""Handle saving metadata updates for checkpoints"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='File path is required', status=400)
|
||||
|
||||
# Remove file path from data to avoid saving it
|
||||
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||
|
||||
# Get metadata file path
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load existing metadata
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
# Update metadata
|
||||
metadata.update(metadata_updates)
|
||||
|
||||
# Save updated metadata
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
# Update cache
|
||||
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||
|
||||
# If model_name was updated, resort the cache
|
||||
if 'model_name' in metadata_updates:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
return web.json_response({'success': True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get the civitai client from service registry
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_id = request.match_info['model_id']
|
||||
response = await civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be Checkpoint
|
||||
if (model_type.lower() != 'checkpoint'):
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# If no primary file found, try to find any model file
|
||||
if not model_file:
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.scanner.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching checkpoint model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def find_duplicate_checkpoints(self, request: web.Request) -> web.Response:
|
||||
"""Find checkpoints with duplicate SHA256 hashes"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get duplicate hashes from hash index
|
||||
duplicates = self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for sha256, paths in duplicates.items():
|
||||
group = {
|
||||
"hash": sha256,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(self._format_checkpoint_response(model))
|
||||
|
||||
# Add the primary model too
|
||||
primary_path = self.scanner._hash_index.get_path(sha256)
|
||||
if primary_path and primary_path not in paths:
|
||||
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
|
||||
if primary_model:
|
||||
group["models"].insert(0, self._format_checkpoint_response(primary_model))
|
||||
|
||||
if len(group["models"]) > 1: # Only include if we found multiple models
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"duplicates": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding duplicate checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
|
||||
"""Find checkpoints with conflicting filenames"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get duplicate filenames from hash index
|
||||
duplicates = self.scanner._hash_index.get_duplicate_filenames()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for filename, paths in duplicates.items():
|
||||
group = {
|
||||
"filename": filename,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(self._format_checkpoint_response(model))
|
||||
|
||||
# Find the model from the main index too
|
||||
hash_val = self.scanner._hash_index.get_hash_by_filename(filename)
|
||||
if hash_val:
|
||||
main_path = self.scanner._hash_index.get_path(hash_val)
|
||||
if main_path and main_path not in paths:
|
||||
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
|
||||
if main_model:
|
||||
group["models"].insert(0, self._format_checkpoint_response(main_model))
|
||||
|
||||
if group["models"]:
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"conflicts": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding filename conflicts: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def bulk_delete_checkpoints(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk deletion of checkpoint models"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
return await ModelRouteUtils.handle_bulk_delete_models(request, self.scanner)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in bulk delete checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def relink_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata re-linking request by model version ID for checkpoints"""
|
||||
return await ModelRouteUtils.handle_relink_civitai(request, self.scanner)
|
||||
|
||||
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
||||
"""Handle verification of duplicate checkpoint hashes"""
|
||||
return await ModelRouteUtils.handle_verify_duplicates(request, self.scanner)
|
||||
|
||||
async def rename_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Handle renaming a checkpoint file and its associated files"""
|
||||
return await ModelRouteUtils.handle_rename_model(request, self.scanner)
|
||||
@@ -1,188 +1,512 @@
|
||||
import os
|
||||
from aiohttp import web
|
||||
import jinja2
|
||||
from typing import Dict
|
||||
import asyncio
|
||||
import logging
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
||||
from aiohttp import web
|
||||
from typing import Dict
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.utils import get_lora_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
|
||||
|
||||
class LoraRoutes:
|
||||
"""Route handlers for LoRA management endpoints"""
|
||||
class LoraRoutes(BaseModelRoutes):
|
||||
"""LoRA-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
# Initialize service references as None, will be set during async init
|
||||
self.scanner = None
|
||||
self.recipe_scanner = None
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
async def init_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
"""Initialize LoRA routes with LoRA service"""
|
||||
# Service will be initialized later via setup_routes
|
||||
self.service = None
|
||||
self.civitai_client = None
|
||||
self.template_name = "loras.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.service = LoraService(lora_scanner)
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Initialize parent with the service
|
||||
super().__init__(self.service)
|
||||
|
||||
def format_lora_data(self, lora: Dict) -> Dict:
|
||||
"""Format LoRA data for template rendering"""
|
||||
return {
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
||||
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"size": lora["size"],
|
||||
"tags": lora["tags"],
|
||||
"modelDescription": lora["modelDescription"],
|
||||
"usage_tips": lora["usage_tips"],
|
||||
"notes": lora["notes"],
|
||||
"modified": lora["modified"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
|
||||
}
|
||||
|
||||
def _filter_civitai_data(self, data: Dict) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Check if the LoraScanner is initializing
|
||||
# It's initializing if the cache object doesn't exist yet,
|
||||
# OR if the scanner explicitly says it's initializing (background task running).
|
||||
is_initializing = (
|
||||
self.scanner._cache is None or self.scanner.is_initializing()
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# If still initializing, return loading page
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
|
||||
logger.info("Loras page is initializing, returning loading page")
|
||||
else:
|
||||
# Normal flow - get data from initialized cache
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading cache data: {cache_error}")
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling loras request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading loras page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras/recipes request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Skip initialization check and directly try to get cached data
|
||||
try:
|
||||
# Recipe scanner will initialize cache if needed
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
recipes=[], # Frontend will load recipes via API
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading recipe cache data: {cache_error}")
|
||||
# Still keep error handling - show initializing page on error
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Recipe cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading recipes page",
|
||||
status=500
|
||||
)
|
||||
|
||||
def _format_recipe_file_url(self, file_path: str) -> str:
|
||||
"""Format file path for recipe image as a URL - same as in recipe_routes"""
|
||||
try:
|
||||
# Return the file URL directly for the first lora root's preview
|
||||
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
|
||||
if file_path.replace(os.sep, '/').startswith(recipes_dir):
|
||||
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
|
||||
return f"/loras_static/root1/preview/{relative_path}"
|
||||
|
||||
# If not in recipes dir, try to create a valid URL from the file path
|
||||
file_name = os.path.basename(file_path)
|
||||
return f"/loras_static/root1/preview/recipes/{file_name}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting recipe file URL: {e}", exc_info=True)
|
||||
return '/loras_static/images/no-preview.png' # Return default image on error
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
# Add an app startup handler to initialize services
|
||||
app.on_startup.append(self._on_startup)
|
||||
"""Setup LoRA routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Register routes
|
||||
app.router.add_get('/loras', self.handle_loras_page)
|
||||
app.router.add_get('/loras/recipes', self.handle_recipes_page)
|
||||
# Setup common routes with 'loras' prefix (includes page route)
|
||||
super().setup_routes(app, 'loras')
|
||||
|
||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup LoRA-specific routes"""
|
||||
# LoRA-specific query routes
|
||||
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
|
||||
app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
|
||||
app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
|
||||
app.router.add_get(f'/api/lora-preview-url', self.get_lora_preview_url)
|
||||
app.router.add_get(f'/api/lora-civitai-url', self.get_lora_civitai_url)
|
||||
app.router.add_get(f'/api/lora-model-description', self.get_lora_model_description)
|
||||
|
||||
async def _on_startup(self, app):
|
||||
"""Initialize services when the app starts"""
|
||||
await self.init_services()
|
||||
# LoRA-specific management routes
|
||||
app.router.add_post(f'/api/move_model', self.move_model)
|
||||
app.router.add_post(f'/api/move_models_bulk', self.move_models_bulk)
|
||||
|
||||
# CivitAI integration with LoRA-specific validation
|
||||
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora)
|
||||
app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
|
||||
app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
|
||||
|
||||
# ComfyUI integration
|
||||
app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words)
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse LoRA-specific parameters"""
|
||||
params = {}
|
||||
|
||||
# LoRA-specific parameters
|
||||
if 'first_letter' in request.query:
|
||||
params['first_letter'] = request.query.get('first_letter')
|
||||
|
||||
# Handle fuzzy search parameter name variation
|
||||
if request.query.get('fuzzy') == 'true':
|
||||
params['fuzzy_search'] = True
|
||||
|
||||
# Handle additional filter parameters for LoRAs
|
||||
if 'lora_hash' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
|
||||
elif 'lora_hashes' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
|
||||
|
||||
return params
|
||||
|
||||
# LoRA-specific route handlers
|
||||
async def get_letter_counts(self, request: web.Request) -> web.Response:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
try:
|
||||
letter_counts = await self.service.get_letter_counts()
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'letter_counts': letter_counts
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting letter counts: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_notes(self, request: web.Request) -> web.Response:
|
||||
"""Get notes for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
notes = await self.service.get_lora_notes(lora_name)
|
||||
if notes is not None:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'notes': notes
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'LoRA not found in cache'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora notes: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
trigger_words = await self.service.get_lora_trigger_words(lora_name)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'trigger_words': trigger_words
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
preview_url = await self.service.get_lora_preview_url(lora_name)
|
||||
if preview_url:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'preview_url': preview_url
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No preview URL found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
result = await self.service.get_lora_civitai_url(lora_name)
|
||||
if result['civitai_url']:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
**result
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No Civitai data found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
# Override get_models to add LoRA-specific response data
|
||||
async def get_models(self, request: web.Request) -> web.Response:
|
||||
"""Get paginated LoRA data with LoRA-specific fields"""
|
||||
try:
|
||||
# Parse common query parameters
|
||||
params = self._parse_common_params(request)
|
||||
|
||||
# Get data from service
|
||||
result = await self.service.get_paginated_data(**params)
|
||||
|
||||
# Get all available folders from cache for LoRA-specific response
|
||||
cache = await self.service.scanner.get_cached_data()
|
||||
|
||||
# Format response items with LoRA-specific structure
|
||||
formatted_result = {
|
||||
'items': [await self.service.format_response(item) for item in result['items']],
|
||||
'folders': cache.folders, # LoRA-specific: include folders in response
|
||||
'total': result['total'],
|
||||
'page': result['page'],
|
||||
'page_size': result['page_size'],
|
||||
'total_pages': result['total_pages']
|
||||
}
|
||||
|
||||
return web.json_response(formatted_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_loras: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
# CivitAI integration methods
|
||||
async def get_civitai_versions_lora(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai LoRA model with local availability info"""
|
||||
try:
|
||||
model_id = request.match_info['model_id']
|
||||
response = await self.civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be LORA, LoCon, or DORA
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
if model_type.lower() not in VALID_LORA_TYPES:
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the model file (type="Model") in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.service.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching LoRA model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
|
||||
"""Get CivitAI model details by model version ID"""
|
||||
try:
|
||||
model_version_id = request.match_info.get('modelVersionId')
|
||||
|
||||
# Get model details from Civitai API
|
||||
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
|
||||
|
||||
if not model:
|
||||
# Log warning for failed model retrieval
|
||||
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
|
||||
|
||||
# Determine status code based on error message
|
||||
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
|
||||
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": error_msg or "Failed to fetch model information"
|
||||
}, status=status_code)
|
||||
|
||||
return web.json_response(model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model details: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
|
||||
"""Get CivitAI model details by hash"""
|
||||
try:
|
||||
hash = request.match_info.get('hash')
|
||||
model = await self.civitai_client.get_model_by_hash(hash)
|
||||
return web.json_response(model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model details by hash: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
# Model management methods
|
||||
async def move_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model move request"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path') # full path of the model file
|
||||
target_path = data.get('target_path') # folder path to move the model to
|
||||
|
||||
if not file_path or not target_path:
|
||||
return web.Response(text='File path and target path are required', status=400)
|
||||
|
||||
# Check if source and destination are the same
|
||||
import os
|
||||
source_dir = os.path.dirname(file_path)
|
||||
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||
logger.info(f"Source and target directories are the same: {source_dir}")
|
||||
return web.json_response({'success': True, 'message': 'Source and target directories are the same'})
|
||||
|
||||
# Check if target file already exists
|
||||
file_name = os.path.basename(file_path)
|
||||
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
|
||||
|
||||
if os.path.exists(target_file_path):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Target file already exists: {target_file_path}"
|
||||
}, status=409) # 409 Conflict
|
||||
|
||||
# Call scanner to handle the move operation
|
||||
success = await self.service.scanner.move_model(file_path, target_path)
|
||||
|
||||
if success:
|
||||
return web.json_response({'success': True})
|
||||
else:
|
||||
return web.Response(text='Failed to move model', status=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving model: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def move_models_bulk(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk model move request"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_paths = data.get('file_paths', []) # list of full paths of the model files
|
||||
target_path = data.get('target_path') # folder path to move the models to
|
||||
|
||||
if not file_paths or not target_path:
|
||||
return web.Response(text='File paths and target path are required', status=400)
|
||||
|
||||
results = []
|
||||
import os
|
||||
for file_path in file_paths:
|
||||
# Check if source and destination are the same
|
||||
source_dir = os.path.dirname(file_path)
|
||||
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||
results.append({
|
||||
"path": file_path,
|
||||
"success": True,
|
||||
"message": "Source and target directories are the same"
|
||||
})
|
||||
continue
|
||||
|
||||
# Check if target file already exists
|
||||
file_name = os.path.basename(file_path)
|
||||
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
|
||||
|
||||
if os.path.exists(target_file_path):
|
||||
results.append({
|
||||
"path": file_path,
|
||||
"success": False,
|
||||
"message": f"Target file already exists: {target_file_path}"
|
||||
})
|
||||
continue
|
||||
|
||||
# Try to move the model
|
||||
success = await self.service.scanner.move_model(file_path, target_path)
|
||||
results.append({
|
||||
"path": file_path,
|
||||
"success": success,
|
||||
"message": "Success" if success else "Failed to move model"
|
||||
})
|
||||
|
||||
# Count successes and failures
|
||||
success_count = sum(1 for r in results if r["success"])
|
||||
failure_count = len(results) - success_count
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Moved {success_count} of {len(file_paths)} models',
|
||||
'results': results,
|
||||
'success_count': success_count,
|
||||
'failure_count': failure_count
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_lora_model_description(self, request: web.Request) -> web.Response:
|
||||
"""Get model description for a Lora model"""
|
||||
try:
|
||||
# Get parameters
|
||||
model_id = request.query.get('model_id')
|
||||
file_path = request.query.get('file_path')
|
||||
|
||||
if not model_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Model ID is required'
|
||||
}, status=400)
|
||||
|
||||
# Check if we already have the description stored in metadata
|
||||
description = None
|
||||
tags = []
|
||||
creator = {}
|
||||
if file_path:
|
||||
import os
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
description = metadata.get('modelDescription')
|
||||
tags = metadata.get('tags', [])
|
||||
creator = metadata.get('creator', {})
|
||||
|
||||
# If description is not in metadata, fetch from CivitAI
|
||||
if not description:
|
||||
logger.info(f"Fetching model metadata for model ID: {model_id}")
|
||||
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
|
||||
|
||||
if model_metadata:
|
||||
description = model_metadata.get('description')
|
||||
tags = model_metadata.get('tags', [])
|
||||
creator = model_metadata.get('creator', {})
|
||||
|
||||
# Save the metadata to file if we have a file path and got metadata
|
||||
if file_path:
|
||||
try:
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
metadata['modelDescription'] = description
|
||||
metadata['tags'] = tags
|
||||
# Ensure the civitai dict exists
|
||||
if 'civitai' not in metadata:
|
||||
metadata['civitai'] = {}
|
||||
# Store creator in the civitai nested structure
|
||||
metadata['civitai']['creator'] = creator
|
||||
|
||||
await MetadataManager.save_metadata(file_path, metadata, True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model metadata: {e}")
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'description': description or "<p>No model description available.</p>",
|
||||
'tags': tags,
|
||||
'creator': creator
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model metadata: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for specified LoRA models"""
|
||||
try:
|
||||
json_data = await request.json()
|
||||
lora_names = json_data.get("lora_names", [])
|
||||
node_ids = json_data.get("node_ids", [])
|
||||
|
||||
all_trigger_words = []
|
||||
for lora_name in lora_names:
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format the trigger words
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Send update to all connected trigger word toggle nodes
|
||||
for node_id in node_ids:
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": node_id,
|
||||
"message": trigger_words_text
|
||||
})
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting trigger words: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
@@ -633,9 +633,8 @@ class MiscRoutes:
|
||||
}, status=400)
|
||||
|
||||
# Get both lora and checkpoint scanners
|
||||
registry = ServiceRegistry.get_instance()
|
||||
lora_scanner = await registry.get_lora_scanner()
|
||||
checkpoint_scanner = await registry.get_checkpoint_scanner()
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# If modelVersionId is provided, check for specific version
|
||||
if model_version_id_str:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
import jinja2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import io
|
||||
@@ -15,6 +16,7 @@ from ..utils.exif_utils import ExifUtils
|
||||
from ..recipes import RecipeParserFactory
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
|
||||
from ..services.settings_manager import settings
|
||||
from ..config import config
|
||||
|
||||
# Check if running in standalone mode
|
||||
@@ -39,7 +41,10 @@ class RecipeRoutes:
|
||||
# Initialize service references as None, will be set during async init
|
||||
self.recipe_scanner = None
|
||||
self.civitai_client = None
|
||||
# Remove WorkflowParser instance
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
# Pre-warm the cache
|
||||
self._init_cache_task = None
|
||||
@@ -53,6 +58,8 @@ class RecipeRoutes:
|
||||
def setup_routes(cls, app: web.Application):
|
||||
"""Register API routes"""
|
||||
routes = cls()
|
||||
app.router.add_get('/loras/recipes', routes.handle_recipes_page)
|
||||
|
||||
app.router.add_get('/api/recipes', routes.get_recipes)
|
||||
app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail)
|
||||
app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image)
|
||||
@@ -114,6 +121,46 @@ class RecipeRoutes:
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True)
|
||||
|
||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras/recipes request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Skip initialization check and directly try to get cached data
|
||||
try:
|
||||
# Recipe scanner will initialize cache if needed
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
recipes=[], # Frontend will load recipes via API
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading recipe cache data: {cache_error}")
|
||||
# Still keep error handling - show initializing page on error
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Recipe cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading recipes page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def get_recipes(self, request: web.Request) -> web.Response:
|
||||
"""API endpoint for getting paginated recipes"""
|
||||
@@ -1101,7 +1148,7 @@ class RecipeRoutes:
|
||||
for lora_name, lora_strength in lora_matches:
|
||||
try:
|
||||
# Get lora info from scanner
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name)
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name)
|
||||
|
||||
# Create lora entry
|
||||
lora_entry = {
|
||||
@@ -1120,7 +1167,7 @@ class RecipeRoutes:
|
||||
# Get base model from lora scanner for the available loras
|
||||
base_model_counts = {}
|
||||
for lora in loras_data:
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", ""))
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", ""))
|
||||
if lora_info and "base_model" in lora_info:
|
||||
base_model = lora_info["base_model"]
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
@@ -1210,7 +1257,7 @@ class RecipeRoutes:
|
||||
if lora.get("isDeleted", False):
|
||||
continue
|
||||
|
||||
if not self.recipe_scanner._lora_scanner.has_lora_hash(lora.get("hash", "")):
|
||||
if not self.recipe_scanner._lora_scanner.has_hash(lora.get("hash", "")):
|
||||
continue
|
||||
|
||||
# Get the strength
|
||||
@@ -1318,7 +1365,7 @@ class RecipeRoutes:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
|
||||
# Find target LoRA by name
|
||||
target_lora = await lora_scanner.get_lora_info_by_name(target_name)
|
||||
target_lora = await lora_scanner.get_model_info_by_name(target_name)
|
||||
if not target_lora:
|
||||
return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)
|
||||
|
||||
@@ -1430,9 +1477,9 @@ class RecipeRoutes:
|
||||
if 'loras' in recipe:
|
||||
for lora in recipe['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_lora_hash(lora['hash'].lower())
|
||||
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_hash(lora['hash'].lower())
|
||||
lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self.recipe_scanner._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self.recipe_scanner._lora_scanner.get_path_by_hash(lora['hash'].lower())
|
||||
|
||||
# Ensure file_url is set (needed by frontend)
|
||||
if 'file_path' in recipe:
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
import subprocess
|
||||
import aiohttp
|
||||
import logging
|
||||
import toml
|
||||
import subprocess
|
||||
import git
|
||||
from datetime import datetime
|
||||
from aiohttp import web
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,6 +19,7 @@ class UpdateRoutes:
|
||||
"""Register update check routes"""
|
||||
app.router.add_get('/api/check-updates', UpdateRoutes.check_updates)
|
||||
app.router.add_get('/api/version-info', UpdateRoutes.get_version_info)
|
||||
app.router.add_post('/api/perform-update', UpdateRoutes.perform_update)
|
||||
|
||||
@staticmethod
|
||||
async def check_updates(request):
|
||||
@@ -25,6 +28,8 @@ class UpdateRoutes:
|
||||
Returns update status and version information
|
||||
"""
|
||||
try:
|
||||
nightly = request.query.get('nightly', 'false').lower() == 'true'
|
||||
|
||||
# Read local version from pyproject.toml
|
||||
local_version = UpdateRoutes._get_local_version()
|
||||
|
||||
@@ -32,13 +37,21 @@ class UpdateRoutes:
|
||||
git_info = UpdateRoutes._get_git_info()
|
||||
|
||||
# Fetch remote version from GitHub
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
if nightly:
|
||||
remote_version, changelog = await UpdateRoutes._get_nightly_version()
|
||||
else:
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
|
||||
# Compare versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
if nightly:
|
||||
# For nightly, compare commit hashes
|
||||
update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version)
|
||||
else:
|
||||
# For stable, compare semantic versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
@@ -46,7 +59,8 @@ class UpdateRoutes:
|
||||
'latest_version': remote_version,
|
||||
'update_available': update_available,
|
||||
'changelog': changelog,
|
||||
'git_info': git_info
|
||||
'git_info': git_info,
|
||||
'nightly': nightly
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@@ -55,7 +69,7 @@ class UpdateRoutes:
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def get_version_info(request):
|
||||
"""
|
||||
@@ -84,6 +98,168 @@ class UpdateRoutes:
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def perform_update(request):
|
||||
"""
|
||||
Perform Git-based update to latest release tag or main branch
|
||||
"""
|
||||
try:
|
||||
# Parse request body
|
||||
body = await request.json() if request.has_body else {}
|
||||
nightly = body.get('nightly', False)
|
||||
|
||||
# Get current plugin directory
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
|
||||
# Backup settings.json if it exists
|
||||
settings_path = os.path.join(plugin_root, 'settings.json')
|
||||
settings_backup = None
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings_backup = f.read()
|
||||
logger.info("Backed up settings.json")
|
||||
|
||||
# Perform Git update
|
||||
success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
|
||||
|
||||
# Restore settings.json if we backed it up
|
||||
if settings_backup and success:
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
f.write(settings_backup)
|
||||
logger.info("Restored settings.json")
|
||||
|
||||
if success:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Successfully updated to {new_version}',
|
||||
'new_version': new_version
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Failed to complete Git update'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to perform update: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def _get_nightly_version() -> tuple[str, List[str]]:
|
||||
"""
|
||||
Fetch latest commit from main branch
|
||||
"""
|
||||
repo_owner = "willmiao"
|
||||
repo_name = "ComfyUI-Lora-Manager"
|
||||
|
||||
# Use GitHub API to fetch the latest commit from main branch
|
||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"Failed to fetch GitHub commit: {response.status}")
|
||||
return "main", []
|
||||
|
||||
data = await response.json()
|
||||
commit_sha = data.get('sha', '')[:7] # Short hash
|
||||
commit_message = data.get('commit', {}).get('message', '')
|
||||
|
||||
# Format as "main-{short_hash}"
|
||||
version = f"main-{commit_sha}"
|
||||
|
||||
# Use commit message as changelog
|
||||
changelog = [commit_message] if commit_message else []
|
||||
|
||||
return version, changelog
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
|
||||
return "main", []
|
||||
|
||||
@staticmethod
|
||||
def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool:
|
||||
"""
|
||||
Compare local commit hash with remote main branch
|
||||
"""
|
||||
try:
|
||||
local_hash = local_git_info.get('short_hash', 'unknown')
|
||||
if local_hash == 'unknown':
|
||||
return True # Assume update available if we can't get local hash
|
||||
|
||||
# Extract remote hash from version string (format: "main-{hash}")
|
||||
if '-' in remote_version:
|
||||
remote_hash = remote_version.split('-')[-1]
|
||||
return local_hash != remote_hash
|
||||
|
||||
return True # Default to update available
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error comparing nightly versions: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]:
|
||||
"""
|
||||
Perform Git-based update using GitPython
|
||||
|
||||
Args:
|
||||
plugin_root: Path to the plugin root directory
|
||||
nightly: Whether to update to main branch or latest release
|
||||
|
||||
Returns:
|
||||
tuple: (success, new_version)
|
||||
"""
|
||||
try:
|
||||
# Open the Git repository
|
||||
repo = git.Repo(plugin_root)
|
||||
|
||||
# Fetch latest changes
|
||||
origin = repo.remotes.origin
|
||||
origin.fetch()
|
||||
|
||||
if nightly:
|
||||
# Switch to main branch and pull latest
|
||||
main_branch = 'main'
|
||||
if main_branch not in [branch.name for branch in repo.branches]:
|
||||
# Create local main branch if it doesn't exist
|
||||
repo.create_head(main_branch, origin.refs.main)
|
||||
|
||||
repo.heads[main_branch].checkout()
|
||||
origin.pull(main_branch)
|
||||
|
||||
# Get new commit hash
|
||||
new_version = f"main-{repo.head.commit.hexsha[:7]}"
|
||||
|
||||
else:
|
||||
# Get latest release tag
|
||||
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
|
||||
if not tags:
|
||||
logger.error("No tags found in repository")
|
||||
return False, ""
|
||||
|
||||
latest_tag = tags[0]
|
||||
|
||||
# Checkout to latest tag
|
||||
repo.git.checkout(latest_tag.name)
|
||||
|
||||
new_version = latest_tag.name
|
||||
|
||||
logger.info(f"Successfully updated to {new_version}")
|
||||
return True, new_version
|
||||
|
||||
except git.exc.GitError as e:
|
||||
logger.error(f"Git error during update: {e}")
|
||||
return False, ""
|
||||
except Exception as e:
|
||||
logger.error(f"Error during Git update: {e}")
|
||||
return False, ""
|
||||
|
||||
@staticmethod
|
||||
def _get_local_version() -> str:
|
||||
"""Get local plugin version from pyproject.toml"""
|
||||
|
||||
259
py/services/base_model_service.py
Normal file
259
py/services/base_model_service.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Type
|
||||
import logging
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from .settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseModelService(ABC):
|
||||
"""Base service class for all model types"""
|
||||
|
||||
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
model_type: Type of model (lora, checkpoint, etc.)
|
||||
scanner: Model scanner instance
|
||||
metadata_class: Metadata class for this model type
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.scanner = scanner
|
||||
self.metadata_class = metadata_class
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||
base_models: list = None, tags: list = None,
|
||||
search_options: dict = None, hash_filters: dict = None,
|
||||
favorites_only: bool = False, **kwargs) -> Dict:
|
||||
"""Get paginated and filtered model data
|
||||
|
||||
Args:
|
||||
page: Page number (1-based)
|
||||
page_size: Number of items per page
|
||||
sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc'
|
||||
folder: Folder filter
|
||||
search: Search term
|
||||
fuzzy_search: Whether to use fuzzy search
|
||||
base_models: List of base models to filter by
|
||||
tags: List of tags to filter by
|
||||
search_options: Search options dict
|
||||
hash_filters: Hash filtering options
|
||||
favorites_only: Filter for favorites only
|
||||
**kwargs: Additional model-specific filters
|
||||
|
||||
Returns:
|
||||
Dict containing paginated results
|
||||
"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Parse sort_by into sort_key and order
|
||||
if ':' in sort_by:
|
||||
sort_key, order = sort_by.split(':', 1)
|
||||
sort_key = sort_key.strip()
|
||||
order = order.strip().lower()
|
||||
if order not in ('asc', 'desc'):
|
||||
order = 'asc'
|
||||
else:
|
||||
sort_key = sort_by.strip()
|
||||
order = 'asc'
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set using new sort logic
|
||||
filtered_data = await cache.get_sorted_data(sort_key, order)
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
|
||||
|
||||
# Jump to pagination for hash filters
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
# Apply common filters
|
||||
filtered_data = await self._apply_common_filters(
|
||||
filtered_data, folder, base_models, tags, favorites_only, search_options
|
||||
)
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
filtered_data = await self._apply_search_filters(
|
||||
filtered_data, search, fuzzy_search, search_options
|
||||
)
|
||||
|
||||
# Apply model-specific filters
|
||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||
"""Apply hash-based filtering"""
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower()
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes)
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
|
||||
base_models: list = None, tags: list = None,
|
||||
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
|
||||
"""Apply common filters that work across all model types"""
|
||||
# Apply SFW filtering if enabled in settings
|
||||
if settings.get('show_only_sfw', False):
|
||||
data = [
|
||||
item for item in data
|
||||
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
data = [
|
||||
item for item in data
|
||||
if item.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options and search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
data = [
|
||||
item for item in data
|
||||
if item['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
data = [
|
||||
item for item in data
|
||||
if item['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
data = [
|
||||
item for item in data
|
||||
if item.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
data = [
|
||||
item for item in data
|
||||
if any(tag in item.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_search_filters(self, data: List[Dict], search: str,
|
||||
fuzzy_search: bool, search_options: dict) -> List[Dict]:
|
||||
"""Apply search filtering"""
|
||||
search_results = []
|
||||
|
||||
for item in data:
|
||||
# Search by file name
|
||||
if search_options.get('filename', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(item.get('file_name', ''), search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in item.get('file_name', '').lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_options.get('modelname', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(item.get('model_name', ''), search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in item.get('model_name', '').lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_options.get('tags', False) and 'tags' in item:
|
||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
|
||||
for tag in item['tags']):
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
return search_results
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||
return data
|
||||
|
||||
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||
"""Apply pagination to filtered data"""
|
||||
total_items = len(data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
return {
|
||||
'items': data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def format_response(self, model_data: Dict) -> Dict:
|
||||
"""Format model data for API response - must be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
# Common service methods that delegate to scanner
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get top tags sorted by frequency"""
|
||||
return await self.scanner.get_top_tags(limit)
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get base models sorted by frequency"""
|
||||
return await self.scanner.get_base_models(limit)
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if a model with given hash exists"""
|
||||
return self.scanner.has_hash(sha256)
|
||||
|
||||
def get_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a model by its hash"""
|
||||
return self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a model by its file path"""
|
||||
return self.scanner.get_hash_by_path(file_path)
|
||||
|
||||
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
|
||||
"""Trigger model scanning"""
|
||||
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
|
||||
|
||||
async def get_model_info_by_name(self, name: str):
|
||||
"""Get model information by name"""
|
||||
return await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get model root directories"""
|
||||
return self.scanner.get_model_roots()
|
||||
@@ -1,112 +1,26 @@
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional, Set
|
||||
import folder_paths # type: ignore
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointScanner(ModelScanner):
|
||||
"""Service for scanning and managing checkpoint files"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, '_initialized'):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
self._initialized = True
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get checkpoint root directories"""
|
||||
return config.base_models_roots
|
||||
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all checkpoint directories and return metadata"""
|
||||
all_checkpoints = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for root in self.get_model_roots():
|
||||
task = asyncio.create_task(self._scan_directory(root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
checkpoints = await task
|
||||
all_checkpoints.extend(checkpoints)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning checkpoint directory: {e}")
|
||||
|
||||
return all_checkpoints
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a directory for checkpoint files"""
|
||||
checkpoints = []
|
||||
original_root = root_path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True):
|
||||
# Check if file has supported extension
|
||||
ext = os.path.splitext(entry.name)[1].lower()
|
||||
if ext in self.file_extensions:
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, checkpoints)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# For directories, continue scanning with original path
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return checkpoints
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
|
||||
"""Process a single checkpoint file and add to results"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
checkpoints.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
return config.base_models_roots
|
||||
51
py/services/checkpoint_service.py
Normal file
51
py/services/checkpoint_service.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointService(BaseModelService):
|
||||
"""Checkpoint-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize Checkpoint service
|
||||
|
||||
Args:
|
||||
scanner: Checkpoint scanner instance
|
||||
"""
|
||||
super().__init__("checkpoint", scanner, CheckpointMetadata)
|
||||
|
||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||
"""Format Checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint_data["model_name"],
|
||||
"file_name": checkpoint_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint_data.get("base_model", ""),
|
||||
"folder": checkpoint_data["folder"],
|
||||
"sha256": checkpoint_data.get("sha256", ""),
|
||||
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint_data.get("size", 0),
|
||||
"modified": checkpoint_data.get("modified", ""),
|
||||
"tags": checkpoint_data.get("tags", []),
|
||||
"modelDescription": checkpoint_data.get("modelDescription", ""),
|
||||
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||
"notes": checkpoint_data.get("notes", ""),
|
||||
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}))
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Checkpoints with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Checkpoints with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
@@ -1,15 +1,10 @@
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
||||
from .settings_manager import settings
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,404 +12,21 @@ logger = logging.getLogger(__name__)
|
||||
class LoraScanner(ModelScanner):
|
||||
"""Service for scanning and managing LoRA files"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# Ensure initialization happens only once
|
||||
if not hasattr(self, '_initialized'):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors'}
|
||||
|
||||
# Initialize parent class with ModelHashIndex
|
||||
super().__init__(
|
||||
model_type="lora",
|
||||
model_class=LoraMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||
)
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors'}
|
||||
|
||||
# Initialize parent class with ModelHashIndex
|
||||
super().__init__(
|
||||
model_type="lora",
|
||||
model_class=LoraMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||
)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get lora root directories"""
|
||||
return config.loras_roots
|
||||
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all LoRA directories and return metadata"""
|
||||
all_loras = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for lora_root in self.get_model_roots():
|
||||
task = asyncio.create_task(self._scan_directory(lora_root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
loras = await task
|
||||
all_loras.extend(loras)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory: {e}")
|
||||
|
||||
return all_loras
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a single directory for LoRA files"""
|
||||
loras = []
|
||||
original_root = root_path # Save original root path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
"""Recursively scan directory, avoiding circular symlinks"""
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
||||
# Use original path instead of real path
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, loras)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# For directories, continue scanning with original path
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return loras
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
||||
"""Process a single file and add to results list"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
loras.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||
base_models: list = None, tags: list = None,
|
||||
search_options: dict = None, hash_filters: dict = None,
|
||||
favorites_only: bool = False, first_letter: str = None) -> Dict:
|
||||
"""Get paginated and filtered lora data
|
||||
|
||||
Args:
|
||||
page: Current page number (1-based)
|
||||
page_size: Number of items per page
|
||||
sort_by: Sort method ('name' or 'date')
|
||||
folder: Filter by folder path
|
||||
search: Search term
|
||||
fuzzy_search: Use fuzzy matching for search
|
||||
base_models: List of base models to filter by
|
||||
tags: List of tags to filter by
|
||||
search_options: Dictionary with search options (filename, modelname, tags, recursive)
|
||||
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
|
||||
favorites_only: Filter for favorite models only
|
||||
first_letter: Filter by first letter of model name
|
||||
"""
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
|
||||
# Jump to pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# Apply SFW filtering if enabled
|
||||
if settings.get('show_only_sfw', False):
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply first letter filtering
|
||||
if first_letter:
|
||||
filtered_data = self._filter_by_first_letter(filtered_data, first_letter)
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if any(tag in lora.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
search_results = []
|
||||
search_opts = search_options or {}
|
||||
|
||||
for lora in filtered_data:
|
||||
# Search by file name
|
||||
if search_opts.get('filename', True):
|
||||
if fuzzy_match(lora.get('file_name', ''), search):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_opts.get('modelname', True):
|
||||
if fuzzy_match(lora.get('model_name', ''), search):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_opts.get('tags', False) and 'tags' in lora:
|
||||
if any(fuzzy_match(tag, search) for tag in lora['tags']):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
filtered_data = search_results
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _filter_by_first_letter(self, data, letter):
|
||||
"""Filter data by first letter of model name
|
||||
|
||||
Special handling:
|
||||
- '#': Numbers (0-9)
|
||||
- '@': Special characters (not alphanumeric)
|
||||
- '漢': CJK characters
|
||||
"""
|
||||
filtered_data = []
|
||||
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if letter == '#' and first_char.isdigit():
|
||||
filtered_data.append(lora)
|
||||
elif letter == '@' and not first_char.isalnum():
|
||||
# Special characters (not alphanumeric)
|
||||
filtered_data.append(lora)
|
||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||
# CJK characters
|
||||
filtered_data.append(lora)
|
||||
elif letter.upper() == first_char:
|
||||
# Regular alphabet matching
|
||||
filtered_data.append(lora)
|
||||
|
||||
return filtered_data
|
||||
|
||||
def _is_cjk_character(self, char):
|
||||
"""Check if character is a CJK character"""
|
||||
# Define Unicode ranges for CJK characters
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
(0x3300, 0x33FF), # CJK Compatibility
|
||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||
(0x3100, 0x312F), # Bopomofo
|
||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
(0x1100, 0x11FF), # Hangul Jamo
|
||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||
]
|
||||
|
||||
code_point = ord(char)
|
||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||
|
||||
async def get_letter_counts(self):
|
||||
"""Get count of models for each letter of the alphabet"""
|
||||
cache = await self.get_cached_data()
|
||||
data = cache.sorted_by_name
|
||||
|
||||
# Define letter categories
|
||||
letters = {
|
||||
'#': 0, # Numbers
|
||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||
'Y': 0, 'Z': 0,
|
||||
'@': 0, # Special characters
|
||||
'漢': 0 # CJK characters
|
||||
}
|
||||
|
||||
# Count models for each letter
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if first_char.isdigit():
|
||||
letters['#'] += 1
|
||||
elif first_char in letters:
|
||||
letters[first_char] += 1
|
||||
elif self._is_cjk_character(first_char):
|
||||
letters['漢'] += 1
|
||||
elif not first_char.isalnum():
|
||||
letters['@'] += 1
|
||||
|
||||
return letters
|
||||
|
||||
# Lora-specific hash index functionality
|
||||
def has_lora_hash(self, sha256: str) -> bool:
|
||||
"""Check if a LoRA with given hash exists"""
|
||||
return self.has_hash(sha256)
|
||||
|
||||
def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a LoRA by its hash"""
|
||||
return self.get_path_by_hash(sha256)
|
||||
|
||||
def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a LoRA by its file path"""
|
||||
return self.get_hash_by_path(file_path)
|
||||
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get top tags sorted by count"""
|
||||
# Make sure cache is initialized
|
||||
await self.get_cached_data()
|
||||
|
||||
# Sort tags by count in descending order
|
||||
sorted_tags = sorted(
|
||||
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
|
||||
key=lambda x: x['count'],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Return limited number
|
||||
return sorted_tags[:limit]
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get base models used in loras sorted by frequency"""
|
||||
# Make sure cache is initialized
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Count base model occurrences
|
||||
base_model_counts = {}
|
||||
for lora in cache.raw_data:
|
||||
if 'base_model' in lora and lora['base_model']:
|
||||
base_model = lora['base_model']
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
# Sort base models by count
|
||||
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
|
||||
sorted_models.sort(key=lambda x: x['count'], reverse=True)
|
||||
|
||||
# Return limited number
|
||||
return sorted_models[:limit]
|
||||
|
||||
async def diagnose_hash_index(self):
|
||||
"""Diagnostic method to verify hash index functionality"""
|
||||
@@ -451,19 +63,3 @@ class LoraScanner(ModelScanner):
|
||||
test_hash_result = self._hash_index.get_hash(test_path)
|
||||
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
|
||||
|
||||
async def get_lora_info_by_name(self, name):
|
||||
"""Get LoRA information by name"""
|
||||
try:
|
||||
# Get cached data
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Find the LoRA by name
|
||||
for lora in cache.raw_data:
|
||||
if lora.get("file_name") == name:
|
||||
return lora
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting LoRA info by name: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
212
py/services/lora_service.py
Normal file
212
py/services/lora_service.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraService(BaseModelService):
|
||||
"""LoRA-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize LoRA service
|
||||
|
||||
Args:
|
||||
scanner: LoRA scanner instance
|
||||
"""
|
||||
super().__init__("lora", scanner, LoraMetadata)
|
||||
|
||||
async def format_response(self, lora_data: Dict) -> Dict:
|
||||
"""Format LoRA data for API response"""
|
||||
return {
|
||||
"model_name": lora_data["model_name"],
|
||||
"file_name": lora_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||
"base_model": lora_data.get("base_model", ""),
|
||||
"folder": lora_data["folder"],
|
||||
"sha256": lora_data.get("sha256", ""),
|
||||
"file_path": lora_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": lora_data.get("size", 0),
|
||||
"modified": lora_data.get("modified", ""),
|
||||
"tags": lora_data.get("tags", []),
|
||||
"modelDescription": lora_data.get("modelDescription", ""),
|
||||
"from_civitai": lora_data.get("from_civitai", True),
|
||||
"usage_tips": lora_data.get("usage_tips", ""),
|
||||
"notes": lora_data.get("notes", ""),
|
||||
"favorite": lora_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}))
|
||||
}
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply LoRA-specific filters"""
|
||||
# Handle first_letter filter for LoRAs
|
||||
first_letter = kwargs.get('first_letter')
|
||||
if first_letter:
|
||||
data = self._filter_by_first_letter(data, first_letter)
|
||||
|
||||
return data
|
||||
|
||||
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||
"""Filter data by first letter of model name
|
||||
|
||||
Special handling:
|
||||
- '#': Numbers (0-9)
|
||||
- '@': Special characters (not alphanumeric)
|
||||
- '漢': CJK characters
|
||||
"""
|
||||
filtered_data = []
|
||||
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if letter == '#' and first_char.isdigit():
|
||||
filtered_data.append(lora)
|
||||
elif letter == '@' and not first_char.isalnum():
|
||||
# Special characters (not alphanumeric)
|
||||
filtered_data.append(lora)
|
||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||
# CJK characters
|
||||
filtered_data.append(lora)
|
||||
elif letter.upper() == first_char:
|
||||
# Regular alphabet matching
|
||||
filtered_data.append(lora)
|
||||
|
||||
return filtered_data
|
||||
|
||||
def _is_cjk_character(self, char: str) -> bool:
|
||||
"""Check if character is a CJK character"""
|
||||
# Define Unicode ranges for CJK characters
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
(0x3300, 0x33FF), # CJK Compatibility
|
||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||
(0x3100, 0x312F), # Bopomofo
|
||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
(0x1100, 0x11FF), # Hangul Jamo
|
||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||
]
|
||||
|
||||
code_point = ord(char)
|
||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||
|
||||
# LoRA-specific methods
|
||||
async def get_letter_counts(self) -> Dict[str, int]:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
data = cache.raw_data
|
||||
|
||||
# Define letter categories
|
||||
letters = {
|
||||
'#': 0, # Numbers
|
||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||
'Y': 0, 'Z': 0,
|
||||
'@': 0, # Special characters
|
||||
'漢': 0 # CJK characters
|
||||
}
|
||||
|
||||
# Count models for each letter
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if first_char.isdigit():
|
||||
letters['#'] += 1
|
||||
elif first_char in letters:
|
||||
letters[first_char] += 1
|
||||
elif self._is_cjk_character(first_char):
|
||||
letters['漢'] += 1
|
||||
elif not first_char.isalnum():
|
||||
letters['@'] += 1
|
||||
|
||||
return letters
|
||||
|
||||
async def get_lora_notes(self, lora_name: str) -> Optional[str]:
|
||||
"""Get notes for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
return lora.get('notes', '')
|
||||
|
||||
return None
|
||||
|
||||
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
return civitai_data.get('trainedWords', [])
|
||||
|
||||
return []
|
||||
|
||||
async def get_lora_preview_url(self, lora_name: str) -> Optional[str]:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
preview_url = lora.get('preview_url')
|
||||
if preview_url:
|
||||
return config.get_preview_static_url(preview_url)
|
||||
|
||||
return None
|
||||
|
||||
async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
model_id = civitai_data.get('modelId')
|
||||
version_id = civitai_data.get('id')
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
|
||||
return {
|
||||
'civitai_url': civitai_url,
|
||||
'model_id': str(model_id),
|
||||
'version_id': str(version_id) if version_id else None
|
||||
}
|
||||
|
||||
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find LoRAs with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
@@ -1,37 +1,85 @@
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
|
||||
# Supported sort modes: (sort_key, order)
|
||||
# order: 'asc' for ascending, 'desc' for descending
|
||||
SUPPORTED_SORT_MODES = [
|
||||
('name', 'asc'),
|
||||
('name', 'desc'),
|
||||
('date', 'asc'),
|
||||
('date', 'desc'),
|
||||
('size', 'asc'),
|
||||
('size', 'desc'),
|
||||
]
|
||||
|
||||
@dataclass
|
||||
class ModelCache:
|
||||
"""Cache structure for model data"""
|
||||
"""Cache structure for model data with extensible sorting"""
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
# Cache for last sort: (sort_key, order) -> sorted list
|
||||
self._last_sort: Tuple[str, str] = (None, None)
|
||||
self._last_sorted_data: List[Dict] = []
|
||||
# Default sort on init
|
||||
asyncio.create_task(self.resort())
|
||||
|
||||
async def resort(self, name_only: bool = False):
|
||||
"""Resort all cached data views"""
|
||||
async def resort(self):
|
||||
"""Resort cached data according to last sort mode if set"""
|
||||
async with self._lock:
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=True
|
||||
)
|
||||
# Update folder list
|
||||
if self._last_sort != (None, None):
|
||||
sort_key, order = self._last_sort
|
||||
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||
self._last_sorted_data = sorted_data
|
||||
# Update folder list
|
||||
# else: do nothing
|
||||
|
||||
all_folders = set(l['folder'] for l in self.raw_data)
|
||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
||||
"""Sort data by sort_key and order"""
|
||||
reverse = (order == 'desc')
|
||||
if sort_key == 'name':
|
||||
# Natural sort by model_name, case-insensitive
|
||||
return natsorted(
|
||||
data,
|
||||
key=lambda x: x['model_name'].lower(),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'date':
|
||||
# Sort by modified timestamp
|
||||
return sorted(
|
||||
data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'size':
|
||||
# Sort by file size
|
||||
return sorted(
|
||||
data,
|
||||
key=itemgetter('size'),
|
||||
reverse=reverse
|
||||
)
|
||||
else:
|
||||
# Fallback: no sort
|
||||
return list(data)
|
||||
|
||||
async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]:
|
||||
"""Get sorted data by sort_key and order, using cache if possible"""
|
||||
async with self._lock:
|
||||
if (sort_key, order) == self._last_sort:
|
||||
return self._last_sorted_data
|
||||
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||
self._last_sort = (sort_key, order)
|
||||
self._last_sorted_data = sorted_data
|
||||
return sorted_data
|
||||
|
||||
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
|
||||
"""Update preview_url for a specific model in all cached data
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import asyncio
|
||||
import time
|
||||
import shutil
|
||||
from typing import List, Dict, Optional, Type, Set
|
||||
import msgpack # Add MessagePack import for efficient serialization
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..config import config
|
||||
@@ -19,17 +18,33 @@ from .websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define cache version to handle future format changes
|
||||
# Version history:
|
||||
# 1 - Initial version
|
||||
# 2 - Added duplicate_filenames and duplicate_hashes tracking
|
||||
# 3 - Added _excluded_models list to cache
|
||||
CACHE_VERSION = 3
|
||||
|
||||
class ModelScanner:
|
||||
"""Base service for scanning and managing model files"""
|
||||
|
||||
_lock = asyncio.Lock()
|
||||
_instances = {} # Dictionary to store instances by class
|
||||
_locks = {} # Dictionary to store locks by class
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""Implement singleton pattern for each subclass"""
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super().__new__(cls)
|
||||
return cls._instances[cls]
|
||||
|
||||
@classmethod
|
||||
def _get_lock(cls):
|
||||
"""Get or create a lock for this class"""
|
||||
if cls not in cls._locks:
|
||||
cls._locks[cls] = asyncio.Lock()
|
||||
return cls._locks[cls]
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
lock = cls._get_lock()
|
||||
async with lock:
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = cls()
|
||||
return cls._instances[cls]
|
||||
|
||||
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
|
||||
"""Initialize the scanner
|
||||
@@ -40,6 +55,10 @@ class ModelScanner:
|
||||
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
|
||||
hash_index: Hash index instance (optional)
|
||||
"""
|
||||
# Ensure initialization happens only once per instance
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
|
||||
self.model_type = model_type
|
||||
self.model_class = model_class
|
||||
self.file_extensions = file_extensions
|
||||
@@ -48,202 +67,15 @@ class ModelScanner:
|
||||
self._tags_count = {} # Dictionary to store tag counts
|
||||
self._is_initializing = False # Flag to track initialization state
|
||||
self._excluded_models = [] # List to track excluded models
|
||||
self._dirs_last_modified = {} # Track directory modification times
|
||||
self._use_cache_files = False # Flag to control cache file usage, default to disabled
|
||||
|
||||
# Clear cache files if disabled
|
||||
if not self._use_cache_files:
|
||||
self._clear_cache_files()
|
||||
self._initialized = True
|
||||
|
||||
# Register this service
|
||||
asyncio.create_task(self._register_service())
|
||||
|
||||
def _clear_cache_files(self):
|
||||
"""Clear existing cache files if they exist"""
|
||||
try:
|
||||
cache_path = self._get_cache_file_path()
|
||||
if cache_path and os.path.exists(cache_path):
|
||||
os.remove(cache_path)
|
||||
logger.info(f"Cleared {self.model_type} cache file: {cache_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing {self.model_type} cache file: {e}")
|
||||
|
||||
async def _register_service(self):
|
||||
"""Register this instance with the ServiceRegistry"""
|
||||
service_name = f"{self.model_type}_scanner"
|
||||
await ServiceRegistry.register_service(service_name, self)
|
||||
|
||||
def _get_cache_file_path(self) -> Optional[str]:
|
||||
"""Get the path to the cache file"""
|
||||
# Get the directory where this module is located
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
||||
|
||||
# Create a cache directory within the project if it doesn't exist
|
||||
cache_dir = os.path.join(current_dir, "cache")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Create filename based on model type
|
||||
cache_filename = f"lm_{self.model_type}_cache.msgpack"
|
||||
return os.path.join(cache_dir, cache_filename)
|
||||
|
||||
def _prepare_for_msgpack(self, data):
|
||||
"""Preprocess data to accommodate MessagePack serialization limitations
|
||||
|
||||
Converts integers exceeding safe range to strings
|
||||
|
||||
Args:
|
||||
data: Any type of data structure
|
||||
|
||||
Returns:
|
||||
Preprocessed data structure with large integers converted to strings
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
return {k: self._prepare_for_msgpack(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [self._prepare_for_msgpack(item) for item in data]
|
||||
elif isinstance(data, int) and (data > 9007199254740991 or data < -9007199254740991):
|
||||
# Convert integers exceeding JavaScript's safe integer range (2^53-1) to strings
|
||||
return str(data)
|
||||
else:
|
||||
return data
|
||||
|
||||
async def _save_cache_to_disk(self) -> bool:
|
||||
"""Save cache data to disk using MessagePack"""
|
||||
if not self._use_cache_files:
|
||||
logger.debug(f"Cache files disabled for {self.model_type}, skipping save")
|
||||
return False
|
||||
|
||||
if self._cache is None or not self._cache.raw_data:
|
||||
logger.debug(f"No {self.model_type} cache data to save")
|
||||
return False
|
||||
|
||||
cache_path = self._get_cache_file_path()
|
||||
if not cache_path:
|
||||
logger.warning(f"Cannot determine {self.model_type} cache file location")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Create cache data structure
|
||||
cache_data = {
|
||||
"version": CACHE_VERSION,
|
||||
"timestamp": time.time(),
|
||||
"model_type": self.model_type,
|
||||
"raw_data": self._cache.raw_data,
|
||||
"hash_index": {
|
||||
"hash_to_path": self._hash_index._hash_to_path,
|
||||
"filename_to_hash": self._hash_index._filename_to_hash, # Fix: changed from path_to_hash to filename_to_hash
|
||||
"duplicate_hashes": self._hash_index._duplicate_hashes,
|
||||
"duplicate_filenames": self._hash_index._duplicate_filenames
|
||||
},
|
||||
"tags_count": self._tags_count,
|
||||
"dirs_last_modified": self._get_dirs_last_modified(),
|
||||
"excluded_models": self._excluded_models # Add excluded_models to cache data
|
||||
}
|
||||
|
||||
# Preprocess data to handle large integers
|
||||
processed_cache_data = self._prepare_for_msgpack(cache_data)
|
||||
|
||||
# Write to temporary file first (atomic operation)
|
||||
temp_path = f"{cache_path}.tmp"
|
||||
with open(temp_path, 'wb') as f:
|
||||
msgpack.pack(processed_cache_data, f)
|
||||
|
||||
# Replace the old file with the new one
|
||||
if os.path.exists(cache_path):
|
||||
os.replace(temp_path, cache_path)
|
||||
else:
|
||||
os.rename(temp_path, cache_path)
|
||||
|
||||
logger.info(f"Saved {self.model_type} cache with {len(self._cache.raw_data)} models to {cache_path}")
|
||||
logger.debug(f"Hash index stats - hash_to_path: {len(self._hash_index._hash_to_path)}, filename_to_hash: {len(self._hash_index._filename_to_hash)}, duplicate_hashes: {len(self._hash_index._duplicate_hashes)}, duplicate_filenames: {len(self._hash_index._duplicate_filenames)}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving {self.model_type} cache to disk: {e}")
|
||||
# Try to clean up temp file if it exists
|
||||
if 'temp_path' in locals() and os.path.exists(temp_path):
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _get_dirs_last_modified(self) -> Dict[str, float]:
|
||||
"""Get last modified time for all model directories"""
|
||||
dirs_info = {}
|
||||
for root in self.get_model_roots():
|
||||
if os.path.exists(root):
|
||||
dirs_info[root] = os.path.getmtime(root)
|
||||
# Also check immediate subdirectories for changes
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
if entry.is_dir(follow_symlinks=True):
|
||||
dirs_info[entry.path] = entry.stat().st_mtime
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting directory info for {root}: {e}")
|
||||
return dirs_info
|
||||
|
||||
def _is_cache_valid(self, cache_data: Dict) -> bool:
|
||||
"""Validate if the loaded cache is still valid"""
|
||||
if not cache_data or cache_data.get("version") != CACHE_VERSION:
|
||||
logger.info(f"Cache invalid - version mismatch. Got: {cache_data.get('version')}, Expected: {CACHE_VERSION}")
|
||||
return False
|
||||
|
||||
if cache_data.get("model_type") != self.model_type:
|
||||
logger.info(f"Cache invalid - model type mismatch. Got: {cache_data.get('model_type')}, Expected: {self.model_type}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _load_cache_from_disk(self) -> bool:
|
||||
"""Load cache data from disk using MessagePack"""
|
||||
if not self._use_cache_files:
|
||||
logger.info(f"Cache files disabled for {self.model_type}, skipping load")
|
||||
return False
|
||||
|
||||
start_time = time.time()
|
||||
cache_path = self._get_cache_file_path()
|
||||
if not cache_path or not os.path.exists(cache_path):
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(cache_path, 'rb') as f:
|
||||
cache_data = msgpack.unpack(f)
|
||||
|
||||
# Validate cache data
|
||||
if not self._is_cache_valid(cache_data):
|
||||
logger.info(f"{self.model_type.capitalize()} cache file found but invalid or outdated")
|
||||
return False
|
||||
|
||||
# Load data into memory
|
||||
self._cache = ModelCache(
|
||||
raw_data=cache_data["raw_data"],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
# Load hash index
|
||||
hash_index_data = cache_data.get("hash_index", {})
|
||||
self._hash_index._hash_to_path = hash_index_data.get("hash_to_path", {})
|
||||
self._hash_index._filename_to_hash = hash_index_data.get("filename_to_hash", {}) # Fix: changed from path_to_hash to filename_to_hash
|
||||
self._hash_index._duplicate_hashes = hash_index_data.get("duplicate_hashes", {})
|
||||
self._hash_index._duplicate_filenames = hash_index_data.get("duplicate_filenames", {})
|
||||
|
||||
# Load tags count
|
||||
self._tags_count = cache_data.get("tags_count", {})
|
||||
|
||||
# Load excluded models
|
||||
self._excluded_models = cache_data.get("excluded_models", [])
|
||||
|
||||
# Resort the cache
|
||||
await self._cache.resort()
|
||||
|
||||
logger.info(f"Loaded {self.model_type} cache from disk with {len(self._cache.raw_data)} models in {time.time() - start_time:.2f} seconds")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {self.model_type} cache from disk: {e}")
|
||||
return False
|
||||
|
||||
async def initialize_in_background(self) -> None:
|
||||
"""Initialize cache in background using thread pool"""
|
||||
@@ -252,8 +84,6 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
self._cache = ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
@@ -271,21 +101,6 @@ class ModelScanner:
|
||||
'scanner_type': self.model_type,
|
||||
'pageType': page_type
|
||||
})
|
||||
|
||||
cache_loaded = await self._load_cache_from_disk()
|
||||
|
||||
if cache_loaded:
|
||||
# Cache loaded successfully, broadcast complete message
|
||||
await ws_manager.broadcast_init_progress({
|
||||
'stage': 'finalizing',
|
||||
'progress': 100,
|
||||
'status': 'complete',
|
||||
'details': f"Loaded {len(self._cache.raw_data)} {self.model_type} files from cache.",
|
||||
'scanner_type': self.model_type,
|
||||
'pageType': page_type
|
||||
})
|
||||
self._is_initializing = False
|
||||
return
|
||||
|
||||
# If cache loading failed, proceed with full scan
|
||||
await ws_manager.broadcast_init_progress({
|
||||
@@ -332,9 +147,6 @@ class ModelScanner:
|
||||
|
||||
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
|
||||
|
||||
# Save the cache to disk after initialization
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
# Send completion message
|
||||
await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent
|
||||
await ws_manager.broadcast_init_progress({
|
||||
@@ -509,40 +321,21 @@ class ModelScanner:
|
||||
|
||||
Args:
|
||||
force_refresh: Whether to refresh the cache
|
||||
rebuild_cache: Whether to completely rebuild the cache by reloading from disk first
|
||||
rebuild_cache: Whether to completely rebuild the cache
|
||||
"""
|
||||
# If cache is not initialized, return an empty cache
|
||||
# Actual initialization should be done via initialize_in_background
|
||||
if self._cache is None and not force_refresh:
|
||||
return ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
# If force refresh is requested, initialize the cache directly
|
||||
if force_refresh:
|
||||
# If rebuild_cache is True, try to reload from disk before reconciliation
|
||||
if rebuild_cache:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Attempting to rebuild cache from disk...")
|
||||
cache_loaded = await self._load_cache_from_disk()
|
||||
if cache_loaded:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Successfully reloaded cache from disk")
|
||||
else:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Could not reload cache from disk, proceeding with complete rebuild")
|
||||
# If loading from disk failed, do a complete rebuild and save to disk
|
||||
await self._initialize_cache()
|
||||
await self._save_cache_to_disk()
|
||||
return self._cache
|
||||
|
||||
if self._cache is None:
|
||||
# For initial creation, do a full initialization
|
||||
await self._initialize_cache()
|
||||
# Save the newly built cache
|
||||
await self._save_cache_to_disk()
|
||||
else:
|
||||
# For subsequent refreshes, use fast reconciliation
|
||||
await self._reconcile_cache()
|
||||
|
||||
return self._cache
|
||||
@@ -577,8 +370,6 @@ class ModelScanner:
|
||||
# Update cache
|
||||
self._cache = ModelCache(
|
||||
raw_data=raw_data,
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
@@ -592,8 +383,6 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
self._cache = ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
finally:
|
||||
@@ -735,19 +524,74 @@ class ModelScanner:
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
|
||||
finally:
|
||||
self._is_initializing = False # Unset flag
|
||||
|
||||
# These methods should be implemented in child classes
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all model directories and return metadata"""
|
||||
raise NotImplementedError("Subclasses must implement scan_all_models")
|
||||
all_models = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for model_root in self.get_model_roots():
|
||||
task = asyncio.create_task(self._scan_directory(model_root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
models = await task
|
||||
all_models.extend(models)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory: {e}")
|
||||
|
||||
return all_models
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a single directory for model files"""
|
||||
models = []
|
||||
original_root = root_path # Save original root path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
"""Recursively scan directory, avoiding circular symlinks"""
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
||||
# Use original path instead of real path
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, models)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# For directories, continue scanning with original path
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return models
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, models: list):
|
||||
"""Process a single file and add to results list"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
models.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
def is_initializing(self) -> bool:
|
||||
"""Check if the scanner is currently initializing"""
|
||||
@@ -931,7 +775,7 @@ class ModelScanner:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
|
||||
"""Add a model to the cache and save to disk
|
||||
"""Add a model to the cache
|
||||
|
||||
Args:
|
||||
metadata_dict: The model metadata dictionary
|
||||
@@ -960,9 +804,6 @@ class ModelScanner:
|
||||
|
||||
# Update the hash index
|
||||
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
|
||||
# Save to disk
|
||||
await self._save_cache_to_disk()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding model to cache: {e}")
|
||||
@@ -1102,9 +943,6 @@ class ModelScanner:
|
||||
|
||||
await cache.resort()
|
||||
|
||||
# Save the updated cache
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
return True
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
@@ -1198,11 +1036,7 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
return False
|
||||
|
||||
updated = await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
|
||||
if updated:
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
return updated
|
||||
return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
|
||||
|
||||
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
|
||||
"""Delete multiple models and update cache in a batch operation
|
||||
@@ -1334,9 +1168,6 @@ class ModelScanner:
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
|
||||
137
py/services/model_service_factory.py
Normal file
137
py/services/model_service_factory.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from typing import Dict, Type, Any
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelServiceFactory:
|
||||
"""Factory for managing model services and routes"""
|
||||
|
||||
_services: Dict[str, Type] = {}
|
||||
_routes: Dict[str, Type] = {}
|
||||
_initialized_services: Dict[str, Any] = {}
|
||||
_initialized_routes: Dict[str, Any] = {}
|
||||
|
||||
@classmethod
|
||||
def register_model_type(cls, model_type: str, service_class: Type, route_class: Type):
|
||||
"""Register a new model type with its service and route classes
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier (e.g., 'lora', 'checkpoint')
|
||||
service_class: The service class for this model type
|
||||
route_class: The route class for this model type
|
||||
"""
|
||||
cls._services[model_type] = service_class
|
||||
cls._routes[model_type] = route_class
|
||||
logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}")
|
||||
|
||||
@classmethod
|
||||
def get_service_class(cls, model_type: str) -> Type:
|
||||
"""Get service class for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The service class for the model type
|
||||
|
||||
Raises:
|
||||
ValueError: If model type is not registered
|
||||
"""
|
||||
if model_type not in cls._services:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
return cls._services[model_type]
|
||||
|
||||
@classmethod
|
||||
def get_route_class(cls, model_type: str) -> Type:
|
||||
"""Get route class for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The route class for the model type
|
||||
|
||||
Raises:
|
||||
ValueError: If model type is not registered
|
||||
"""
|
||||
if model_type not in cls._routes:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
return cls._routes[model_type]
|
||||
|
||||
@classmethod
|
||||
def get_route_instance(cls, model_type: str):
|
||||
"""Get or create route instance for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The route instance for the model type
|
||||
"""
|
||||
if model_type not in cls._initialized_routes:
|
||||
route_class = cls.get_route_class(model_type)
|
||||
cls._initialized_routes[model_type] = route_class()
|
||||
return cls._initialized_routes[model_type]
|
||||
|
||||
@classmethod
|
||||
def setup_all_routes(cls, app):
|
||||
"""Setup routes for all registered model types
|
||||
|
||||
Args:
|
||||
app: The aiohttp application instance
|
||||
"""
|
||||
logger.info(f"Setting up routes for {len(cls._services)} registered model types")
|
||||
|
||||
for model_type in cls._services.keys():
|
||||
try:
|
||||
routes_instance = cls.get_route_instance(model_type)
|
||||
routes_instance.setup_routes(app)
|
||||
logger.info(f"Successfully set up routes for {model_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
def get_registered_types(cls) -> list:
|
||||
"""Get list of all registered model types
|
||||
|
||||
Returns:
|
||||
List of registered model type identifiers
|
||||
"""
|
||||
return list(cls._services.keys())
|
||||
|
||||
@classmethod
|
||||
def is_registered(cls, model_type: str) -> bool:
|
||||
"""Check if a model type is registered
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
True if the model type is registered, False otherwise
|
||||
"""
|
||||
return model_type in cls._services
|
||||
|
||||
@classmethod
|
||||
def clear_registrations(cls):
|
||||
"""Clear all registrations - mainly for testing purposes"""
|
||||
cls._services.clear()
|
||||
cls._routes.clear()
|
||||
cls._initialized_services.clear()
|
||||
cls._initialized_routes.clear()
|
||||
logger.info("Cleared all model type registrations")
|
||||
|
||||
|
||||
def register_default_model_types():
|
||||
"""Register the default model types (LoRA and Checkpoint)"""
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..routes.lora_routes import LoraRoutes
|
||||
from ..routes.checkpoint_routes import CheckpointRoutes
|
||||
|
||||
# Register LoRA model type
|
||||
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
|
||||
|
||||
# Register Checkpoint model type
|
||||
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
|
||||
|
||||
logger.info("Registered default model types: lora, checkpoint")
|
||||
@@ -393,8 +393,8 @@ class RecipeScanner:
|
||||
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):
|
||||
hash_value = lora['hash']
|
||||
|
||||
if self._lora_scanner.has_lora_hash(hash_value):
|
||||
lora_path = self._lora_scanner.get_lora_path_by_hash(hash_value)
|
||||
if self._lora_scanner.has_hash(hash_value):
|
||||
lora_path = self._lora_scanner.get_path_by_hash(hash_value)
|
||||
if lora_path:
|
||||
file_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
lora['file_name'] = file_name
|
||||
@@ -465,7 +465,7 @@ class RecipeScanner:
|
||||
# Count occurrences of each base model
|
||||
for lora in loras:
|
||||
if 'hash' in lora:
|
||||
lora_path = self._lora_scanner.get_lora_path_by_hash(lora['hash'])
|
||||
lora_path = self._lora_scanner.get_path_by_hash(lora['hash'])
|
||||
if lora_path:
|
||||
base_model = await self._get_base_model_for_lora(lora_path)
|
||||
if base_model:
|
||||
@@ -603,9 +603,9 @@ class RecipeScanner:
|
||||
if 'loras' in item:
|
||||
for lora in item['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora['hash'].lower())
|
||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower())
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower())
|
||||
|
||||
result = {
|
||||
'items': paginated_items,
|
||||
@@ -655,9 +655,9 @@ class RecipeScanner:
|
||||
for lora in formatted_recipe['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora_hash = lora['hash'].lower()
|
||||
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora_hash)
|
||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash)
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
|
||||
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora_hash)
|
||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash)
|
||||
|
||||
return formatted_recipe
|
||||
|
||||
|
||||
@@ -7,106 +7,176 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar('T') # Define a type variable for service types
|
||||
|
||||
class ServiceRegistry:
|
||||
"""Centralized registry for service singletons"""
|
||||
"""Central registry for managing singleton services"""
|
||||
|
||||
_instance = None
|
||||
_services: Dict[str, Any] = {}
|
||||
_lock = asyncio.Lock()
|
||||
_locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""Get singleton instance of the registry"""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
async def register_service(cls, name: str, service: Any) -> None:
|
||||
"""Register a service instance with the registry
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
service: Service instance to register
|
||||
"""
|
||||
cls._services[name] = service
|
||||
logger.debug(f"Registered service: {name}")
|
||||
|
||||
@classmethod
|
||||
async def register_service(cls, service_name: str, service_instance: Any) -> None:
|
||||
"""Register a service instance with the registry"""
|
||||
registry = cls.get_instance()
|
||||
async with cls._lock:
|
||||
registry._services[service_name] = service_instance
|
||||
logger.debug(f"Registered service: {service_name}")
|
||||
async def get_service(cls, name: str) -> Optional[Any]:
|
||||
"""Get a service instance by name
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
|
||||
Returns:
|
||||
Service instance or None if not found
|
||||
"""
|
||||
return cls._services.get(name)
|
||||
|
||||
@classmethod
|
||||
async def get_service(cls, service_name: str) -> Any:
|
||||
"""Get a service instance by name"""
|
||||
registry = cls.get_instance()
|
||||
async with cls._lock:
|
||||
if service_name not in registry._services:
|
||||
logger.debug(f"Service {service_name} not found in registry")
|
||||
return None
|
||||
return registry._services[service_name]
|
||||
def _get_lock(cls, name: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a service
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
|
||||
Returns:
|
||||
AsyncIO lock for the service
|
||||
"""
|
||||
if name not in cls._locks:
|
||||
cls._locks[name] = asyncio.Lock()
|
||||
return cls._locks[name]
|
||||
|
||||
@classmethod
|
||||
def get_service_sync(cls, service_name: str) -> Any:
|
||||
"""Get a service instance by name (synchronous version)"""
|
||||
registry = cls.get_instance()
|
||||
if service_name not in registry._services:
|
||||
logger.debug(f"Service {service_name} not found in registry")
|
||||
return None
|
||||
return registry._services[service_name]
|
||||
|
||||
# Convenience methods for common services
|
||||
@classmethod
|
||||
async def get_lora_scanner(cls):
|
||||
"""Get the LoraScanner instance"""
|
||||
from .lora_scanner import LoraScanner
|
||||
scanner = await cls.get_service("lora_scanner")
|
||||
if scanner is None:
|
||||
scanner = await LoraScanner.get_instance()
|
||||
await cls.register_service("lora_scanner", scanner)
|
||||
return scanner
|
||||
"""Get or create LoRA scanner instance"""
|
||||
service_name = "lora_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .lora_scanner import LoraScanner
|
||||
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_checkpoint_scanner(cls):
|
||||
"""Get the CheckpointScanner instance"""
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
scanner = await cls.get_service("checkpoint_scanner")
|
||||
if scanner is None:
|
||||
"""Get or create Checkpoint scanner instance"""
|
||||
service_name = "checkpoint_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
await cls.register_service("checkpoint_scanner", scanner)
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_civitai_client(cls):
|
||||
"""Get the CivitaiClient instance"""
|
||||
from .civitai_client import CivitaiClient
|
||||
client = await cls.get_service("civitai_client")
|
||||
if client is None:
|
||||
client = await CivitaiClient.get_instance()
|
||||
await cls.register_service("civitai_client", client)
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def get_download_manager(cls):
|
||||
"""Get the DownloadManager instance"""
|
||||
from .download_manager import DownloadManager
|
||||
manager = await cls.get_service("download_manager")
|
||||
if manager is None:
|
||||
manager = await DownloadManager.get_instance()
|
||||
await cls.register_service("download_manager", manager)
|
||||
return manager
|
||||
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_recipe_scanner(cls):
|
||||
"""Get the RecipeScanner instance"""
|
||||
from .recipe_scanner import RecipeScanner
|
||||
scanner = await cls.get_service("recipe_scanner")
|
||||
if scanner is None:
|
||||
lora_scanner = await cls.get_lora_scanner()
|
||||
scanner = RecipeScanner(lora_scanner)
|
||||
await cls.register_service("recipe_scanner", scanner)
|
||||
return scanner
|
||||
|
||||
"""Get or create Recipe scanner instance"""
|
||||
service_name = "recipe_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .recipe_scanner import RecipeScanner
|
||||
|
||||
scanner = await RecipeScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_civitai_client(cls):
|
||||
"""Get or create CivitAI client instance"""
|
||||
service_name = "civitai_client"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .civitai_client import CivitaiClient
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
cls._services[service_name] = client
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def get_download_manager(cls):
|
||||
"""Get or create Download manager instance"""
|
||||
service_name = "download_manager"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .download_manager import DownloadManager
|
||||
|
||||
manager = DownloadManager()
|
||||
cls._services[service_name] = manager
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return manager
|
||||
|
||||
@classmethod
|
||||
async def get_websocket_manager(cls):
|
||||
"""Get the WebSocketManager instance"""
|
||||
from .websocket_manager import ws_manager
|
||||
manager = await cls.get_service("websocket_manager")
|
||||
if manager is None:
|
||||
# ws_manager is already a global instance in websocket_manager.py
|
||||
"""Get or create WebSocket manager instance"""
|
||||
service_name = "websocket_manager"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .websocket_manager import ws_manager
|
||||
await cls.register_service("websocket_manager", ws_manager)
|
||||
manager = ws_manager
|
||||
return manager
|
||||
|
||||
cls._services[service_name] = ws_manager
|
||||
logger.debug(f"Registered {service_name}")
|
||||
return ws_manager
|
||||
|
||||
@classmethod
|
||||
def clear_services(cls):
|
||||
"""Clear all registered services - mainly for testing"""
|
||||
cls._services.clear()
|
||||
cls._locks.clear()
|
||||
logger.info("Cleared all registered services")
|
||||
@@ -566,9 +566,10 @@ class ModelRouteUtils:
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
||||
async def handle_download_model(request: web.Request) -> web.Response:
|
||||
"""Handle model download request"""
|
||||
try:
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
data = await request.json()
|
||||
|
||||
# Get or generate a download ID
|
||||
@@ -663,17 +664,17 @@ class ModelRouteUtils:
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_cancel_download(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
||||
async def handle_cancel_download(request: web.Request) -> web.Response:
|
||||
"""Handle cancellation of a download task
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
download_manager: The download manager instance
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
download_id = request.match_info.get('download_id')
|
||||
if not download_id:
|
||||
return web.json_response({
|
||||
@@ -701,17 +702,17 @@ class ModelRouteUtils:
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_list_downloads(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
||||
async def handle_list_downloads(request: web.Request) -> web.Response:
|
||||
"""Get list of active downloads
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
download_manager: The download manager instance
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response with list of downloads
|
||||
"""
|
||||
try:
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
result = await download_manager.get_active_downloads()
|
||||
return web.json_response(result)
|
||||
except Exception as e:
|
||||
@@ -1047,3 +1048,56 @@ class ModelRouteUtils:
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_save_metadata(request: web.Request, scanner) -> web.Response:
|
||||
"""Handle saving metadata updates
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
scanner: The model scanner instance
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='File path is required', status=400)
|
||||
|
||||
# Remove file path from data to avoid saving it
|
||||
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||
|
||||
# Get metadata file path
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load existing metadata
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
# Handle nested updates (for civitai.trainedWords)
|
||||
for key, value in metadata_updates.items():
|
||||
if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict):
|
||||
# Deep update for nested dictionaries
|
||||
for nested_key, nested_value in value.items():
|
||||
metadata[key][nested_key] = nested_value
|
||||
else:
|
||||
# Regular update for top-level keys
|
||||
metadata[key] = value
|
||||
|
||||
# Save updated metadata
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
# Update cache
|
||||
await scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||
|
||||
# If model_name was updated, resort the cache
|
||||
if 'model_name' in metadata_updates:
|
||||
cache = await scanner.get_cached_data()
|
||||
await cache.resort()
|
||||
|
||||
return web.json_response({'success': True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving metadata: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@@ -14,7 +14,7 @@ dependencies = [
|
||||
"requests",
|
||||
"toml",
|
||||
"natsort",
|
||||
"msgpack"
|
||||
"GitPython"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -9,5 +9,5 @@ requests
|
||||
toml
|
||||
numpy
|
||||
natsort
|
||||
msgpack
|
||||
pyyaml
|
||||
GitPython
|
||||
|
||||
@@ -106,6 +106,22 @@ logger = logging.getLogger("lora-manager-standalone")
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Add specific suppression for connection reset errors
|
||||
class ConnectionResetFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
# Filter out connection reset errors that are not critical
|
||||
if "ConnectionResetError" in str(record.getMessage()):
|
||||
return False
|
||||
if "_call_connection_lost" in str(record.getMessage()):
|
||||
return False
|
||||
if "WinError 10054" in str(record.getMessage()):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Apply the filter to asyncio logger
|
||||
asyncio_logger = logging.getLogger("asyncio")
|
||||
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||
|
||||
# Now we can import the global config from our local modules
|
||||
from py.config import config
|
||||
|
||||
@@ -118,17 +134,6 @@ class StandaloneServer:
|
||||
|
||||
# Ensure the app's access logger is configured to reduce verbosity
|
||||
self.app._subapps = [] # Ensure this exists to avoid AttributeError
|
||||
|
||||
# Configure access logging for the app
|
||||
self.app.on_startup.append(self._configure_access_logger)
|
||||
|
||||
async def _configure_access_logger(self, app):
|
||||
"""Configure access logger to reduce verbosity"""
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# If using aiohttp>=3.8.0, configure access logger through app directly
|
||||
if hasattr(app, 'access_logger'):
|
||||
app.access_logger.setLevel(logging.WARNING)
|
||||
|
||||
async def setup(self):
|
||||
"""Set up the standalone server"""
|
||||
@@ -218,9 +223,6 @@ class StandaloneLoraManager(LoraManager):
|
||||
|
||||
# Store app in a global-like location for compatibility
|
||||
sys.modules['server'].PromptServer.instance = server_instance
|
||||
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
added_targets = set() # Track already added target paths
|
||||
|
||||
@@ -314,35 +316,39 @@ class StandaloneLoraManager(LoraManager):
|
||||
app.router.add_static('/loras_static', config.static_path)
|
||||
|
||||
# Setup feature routes
|
||||
from py.routes.lora_routes import LoraRoutes
|
||||
from py.routes.api_routes import ApiRoutes
|
||||
from py.services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||
from py.routes.recipe_routes import RecipeRoutes
|
||||
from py.routes.checkpoints_routes import CheckpointsRoutes
|
||||
from py.routes.update_routes import UpdateRoutes
|
||||
from py.routes.misc_routes import MiscRoutes
|
||||
from py.routes.example_images_routes import ExampleImagesRoutes
|
||||
from py.routes.stats_routes import StatsRoutes
|
||||
from py.services.websocket_manager import ws_manager
|
||||
|
||||
lora_routes = LoraRoutes()
|
||||
checkpoints_routes = CheckpointsRoutes()
|
||||
|
||||
register_default_model_types()
|
||||
|
||||
# Setup all model routes using the factory
|
||||
ModelServiceFactory.setup_all_routes(app)
|
||||
|
||||
stats_routes = StatsRoutes()
|
||||
|
||||
# Initialize routes
|
||||
lora_routes.setup_routes(app)
|
||||
checkpoints_routes.setup_routes(app)
|
||||
stats_routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
# Add cleanup
|
||||
app.on_shutdown.append(cls._cleanup)
|
||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
@@ -367,9 +373,6 @@ async def main():
|
||||
# Set log level
|
||||
logging.getLogger().setLevel(getattr(logging, args.log_level))
|
||||
|
||||
# Explicitly configure aiohttp access logger regardless of selected log level
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Create the server instance
|
||||
server = StandaloneServer()
|
||||
|
||||
|
||||
@@ -50,8 +50,8 @@ html, body {
|
||||
--lora-border: oklch(90% 0.02 256 / 0.15);
|
||||
--lora-text: oklch(95% 0.02 256);
|
||||
--lora-error: oklch(75% 0.32 29);
|
||||
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h)); /* Modified to be used with oklch() */
|
||||
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h)); /* New green success color */
|
||||
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
|
||||
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h));
|
||||
|
||||
/* Spacing Scale */
|
||||
--space-1: calc(8px * 1);
|
||||
|
||||
@@ -223,11 +223,6 @@
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.update-badge.hidden,
|
||||
.update-badge:not(.visible) {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
/* Mobile adjustments */
|
||||
@media (max-width: 768px) {
|
||||
.app-title {
|
||||
|
||||
@@ -172,6 +172,91 @@ body.modal-open {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Update Modal specific styles */
|
||||
.update-actions {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-2);
|
||||
align-items: stretch;
|
||||
flex-wrap: nowrap;
|
||||
}
|
||||
|
||||
.update-link {
|
||||
color: var(--lora-accent);
|
||||
text-decoration: none;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 0.95em;
|
||||
}
|
||||
|
||||
.update-link:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
/* Update progress styles */
|
||||
.update-progress {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border: 1px solid var(--lora-border);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-2);
|
||||
margin: var(--space-2) 0;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .update-progress {
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
}
|
||||
|
||||
.progress-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-1);
|
||||
}
|
||||
|
||||
.progress-text {
|
||||
font-size: 0.9em;
|
||||
color: var(--text-color);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.progress-bar {
|
||||
width: 100%;
|
||||
height: 8px;
|
||||
background-color: rgba(0, 0, 0, 0.1);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .progress-bar {
|
||||
background-color: rgba(255, 255, 255, 0.1);
|
||||
}
|
||||
|
||||
.progress-fill {
|
||||
height: 100%;
|
||||
background-color: var(--lora-accent);
|
||||
width: 0%;
|
||||
transition: width 0.3s ease;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
/* Update button states */
|
||||
#updateBtn {
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
#updateBtn.updating {
|
||||
background-color: var(--lora-warning);
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
#updateBtn.success {
|
||||
background-color: var(--lora-success);
|
||||
}
|
||||
|
||||
#updateBtn.error {
|
||||
background-color: var(--lora-error);
|
||||
}
|
||||
|
||||
/* Settings styles */
|
||||
.settings-toggle {
|
||||
width: 36px;
|
||||
|
||||
@@ -182,6 +182,31 @@
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
/* Style for optgroups */
|
||||
.control-group select optgroup {
|
||||
font-weight: 600;
|
||||
font-style: normal;
|
||||
color: var(--text-color);
|
||||
background-color: var(--card-bg);
|
||||
}
|
||||
|
||||
.control-group select option {
|
||||
padding: 4px 8px;
|
||||
background-color: var(--card-bg);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
/* Dark theme optgroup styling */
|
||||
[data-theme="dark"] .control-group select optgroup {
|
||||
background-color: var(--card-bg);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .control-group select option {
|
||||
background-color: var(--card-bg);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.control-group select:hover {
|
||||
border-color: var(--lora-accent);
|
||||
background-color: var(--bg-color);
|
||||
|
||||
@@ -54,25 +54,16 @@ export async function fetchModelsPage(options = {}) {
|
||||
if (pageState.filters) {
|
||||
// Handle tags filters
|
||||
if (pageState.filters.tags && pageState.filters.tags.length > 0) {
|
||||
// Checkpoints API expects individual 'tag' parameters, Loras API expects comma-separated 'tags'
|
||||
if (modelType === 'checkpoint') {
|
||||
pageState.filters.tags.forEach(tag => {
|
||||
params.append('tag', tag);
|
||||
});
|
||||
} else {
|
||||
params.append('tags', pageState.filters.tags.join(','));
|
||||
}
|
||||
pageState.filters.tags.forEach(tag => {
|
||||
params.append('tag', tag);
|
||||
});
|
||||
}
|
||||
|
||||
// Handle base model filters
|
||||
if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) {
|
||||
if (modelType === 'checkpoint') {
|
||||
pageState.filters.baseModel.forEach(model => {
|
||||
params.append('base_model', model);
|
||||
});
|
||||
} else {
|
||||
params.append('base_models', pageState.filters.baseModel.join(','));
|
||||
}
|
||||
pageState.filters.baseModel.forEach(model => {
|
||||
params.append('base_model', model);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -277,7 +268,7 @@ export async function deleteModel(filePath, modelType = 'lora') {
|
||||
|
||||
const endpoint = modelType === 'checkpoint'
|
||||
? '/api/checkpoints/delete'
|
||||
: '/api/delete_model';
|
||||
: '/api/loras/delete';
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'POST',
|
||||
@@ -454,7 +445,7 @@ export async function refreshSingleModelMetadata(filePath, modelType = 'lora') {
|
||||
|
||||
const endpoint = modelType === 'checkpoint'
|
||||
? '/api/checkpoints/fetch-civitai'
|
||||
: '/api/fetch-civitai';
|
||||
: '/api/loras/fetch-civitai';
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'POST',
|
||||
@@ -557,7 +548,7 @@ export async function uploadPreview(filePath, file, modelType = 'lora', nsfwLeve
|
||||
// Set endpoint based on model type
|
||||
const endpoint = modelType === 'checkpoint'
|
||||
? '/api/checkpoints/replace-preview'
|
||||
: '/api/replace_preview';
|
||||
: '/api/loras/replace_preview';
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'POST',
|
||||
|
||||
@@ -89,7 +89,7 @@ export async function fetchLorasPage(page = 1, pageSize = 100) {
|
||||
export async function fetchCivitai() {
|
||||
return fetchCivitaiMetadata({
|
||||
modelType: 'lora',
|
||||
fetchEndpoint: '/api/fetch-all-civitai',
|
||||
fetchEndpoint: '/api/loras/fetch-all-civitai',
|
||||
resetAndReloadFunction: resetAndReload
|
||||
});
|
||||
}
|
||||
|
||||
@@ -125,7 +125,7 @@ export const ModelContextMenuMixin = {
|
||||
|
||||
const endpoint = this.modelType === 'checkpoint' ?
|
||||
'/api/checkpoints/relink-civitai' :
|
||||
'/api/relink-civitai';
|
||||
'/api/loras/relink-civitai';
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'POST',
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// PageControls.js - Manages controls for both LoRAs and Checkpoints pages
|
||||
import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js';
|
||||
import { getCurrentPageState, setCurrentPageType } from '../../state/index.js';
|
||||
import { getStorageItem, setStorageItem, getSessionItem, setSessionItem } from '../../utils/storageHelpers.js';
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
|
||||
@@ -41,6 +41,9 @@ export class PageControls {
|
||||
this.pageState.isLoading = false;
|
||||
this.pageState.hasMore = true;
|
||||
|
||||
// Set default sort based on page type
|
||||
this.pageState.sortBy = this.pageType === 'loras' ? 'name:asc' : 'name:asc';
|
||||
|
||||
// Load sort preference
|
||||
this.loadSortPreference();
|
||||
}
|
||||
@@ -326,14 +329,36 @@ export class PageControls {
|
||||
loadSortPreference() {
|
||||
const savedSort = getStorageItem(`${this.pageType}_sort`);
|
||||
if (savedSort) {
|
||||
this.pageState.sortBy = savedSort;
|
||||
// Handle legacy format conversion
|
||||
const convertedSort = this.convertLegacySortFormat(savedSort);
|
||||
this.pageState.sortBy = convertedSort;
|
||||
const sortSelect = document.getElementById('sortSelect');
|
||||
if (sortSelect) {
|
||||
sortSelect.value = savedSort;
|
||||
sortSelect.value = convertedSort;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert legacy sort format to new format
|
||||
* @param {string} sortValue - The sort value to convert
|
||||
* @returns {string} - Converted sort value
|
||||
*/
|
||||
convertLegacySortFormat(sortValue) {
|
||||
// Convert old format to new format with direction
|
||||
switch (sortValue) {
|
||||
case 'name':
|
||||
return 'name:asc';
|
||||
case 'date':
|
||||
return 'date:desc'; // Newest first is more intuitive default
|
||||
case 'size':
|
||||
return 'size:desc'; // Largest first is more intuitive default
|
||||
default:
|
||||
// If it's already in new format or unknown, return as is
|
||||
return sortValue.includes(':') ? sortValue : 'name:asc';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Save sort preference to storage
|
||||
* @param {string} sortValue - The sort value to save
|
||||
|
||||
@@ -87,7 +87,7 @@ export class DownloadManager {
|
||||
throw new Error('Invalid Civitai URL format');
|
||||
}
|
||||
|
||||
const response = await fetch(`/api/civitai/versions/${this.modelId}`);
|
||||
const response = await fetch(`/api/loras/civitai/versions/${this.modelId}`);
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
||||
@@ -254,7 +254,7 @@ export class DownloadManager {
|
||||
|
||||
try {
|
||||
// Fetch LoRA roots
|
||||
const rootsResponse = await fetch('/api/lora-roots');
|
||||
const rootsResponse = await fetch('/api/loras/roots');
|
||||
if (!rootsResponse.ok) {
|
||||
throw new Error('Failed to fetch LoRA roots');
|
||||
}
|
||||
@@ -272,7 +272,7 @@ export class DownloadManager {
|
||||
}
|
||||
|
||||
// Fetch folders dynamically
|
||||
const foldersResponse = await fetch('/api/folders');
|
||||
const foldersResponse = await fetch('/api/loras/folders');
|
||||
if (!foldersResponse.ok) {
|
||||
throw new Error('Failed to fetch folders');
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ class MoveManager {
|
||||
|
||||
try {
|
||||
// Fetch LoRA roots
|
||||
const rootsResponse = await fetch('/api/lora-roots');
|
||||
const rootsResponse = await fetch('/api/loras/roots');
|
||||
if (!rootsResponse.ok) {
|
||||
throw new Error('Failed to fetch LoRA roots');
|
||||
}
|
||||
@@ -96,7 +96,7 @@ class MoveManager {
|
||||
}
|
||||
|
||||
// Fetch folders dynamically
|
||||
const foldersResponse = await fetch('/api/folders');
|
||||
const foldersResponse = await fetch('/api/loras/folders');
|
||||
if (!foldersResponse.ok) {
|
||||
throw new Error('Failed to fetch folders');
|
||||
}
|
||||
@@ -190,7 +190,7 @@ class MoveManager {
|
||||
|
||||
// Refresh folder tags after successful move
|
||||
try {
|
||||
const foldersResponse = await fetch('/api/folders');
|
||||
const foldersResponse = await fetch('/api/loras/folders');
|
||||
if (foldersResponse.ok) {
|
||||
const foldersData = await foldersResponse.json();
|
||||
updateFolderTags(foldersData.folders);
|
||||
|
||||
@@ -161,7 +161,7 @@ export class SettingsManager {
|
||||
if (!defaultLoraRootSelect) return;
|
||||
|
||||
// Fetch lora roots
|
||||
const response = await fetch('/api/lora-roots');
|
||||
const response = await fetch('/api/loras/roots');
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to fetch LoRA roots');
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import { getStorageItem, setStorageItem } from '../utils/storageHelpers.js';
|
||||
|
||||
export class UpdateService {
|
||||
constructor() {
|
||||
this.updateCheckInterval = 24 * 60 * 60 * 1000; // 24 hours
|
||||
this.updateCheckInterval = 60 * 60 * 1000; // 1 hour
|
||||
this.currentVersion = "v0.0.0"; // Initialize with default values
|
||||
this.latestVersion = "v0.0.0"; // Initialize with default values
|
||||
this.updateInfo = null;
|
||||
@@ -13,8 +13,10 @@ export class UpdateService {
|
||||
branch: "unknown",
|
||||
commit_date: "unknown"
|
||||
};
|
||||
this.updateNotificationsEnabled = getStorageItem('show_update_notifications');
|
||||
this.updateNotificationsEnabled = getStorageItem('show_update_notifications', true);
|
||||
this.lastCheckTime = parseInt(getStorageItem('last_update_check') || '0');
|
||||
this.isUpdating = false;
|
||||
this.nightlyMode = getStorageItem('nightly_updates', false);
|
||||
}
|
||||
|
||||
initialize() {
|
||||
@@ -28,23 +30,44 @@ export class UpdateService {
|
||||
this.updateBadgeVisibility();
|
||||
});
|
||||
}
|
||||
|
||||
const updateBtn = document.getElementById('updateBtn');
|
||||
if (updateBtn) {
|
||||
updateBtn.addEventListener('click', () => this.performUpdate());
|
||||
}
|
||||
|
||||
// Register event listener for nightly update toggle
|
||||
const nightlyCheckbox = document.getElementById('nightlyUpdateToggle');
|
||||
if (nightlyCheckbox) {
|
||||
nightlyCheckbox.checked = this.nightlyMode;
|
||||
nightlyCheckbox.addEventListener('change', (e) => {
|
||||
this.nightlyMode = e.target.checked;
|
||||
setStorageItem('nightly_updates', e.target.checked);
|
||||
this.updateNightlyWarning();
|
||||
this.updateModalContent();
|
||||
// Re-check for updates when switching channels
|
||||
this.manualCheckForUpdates();
|
||||
});
|
||||
this.updateNightlyWarning();
|
||||
}
|
||||
|
||||
// Perform update check if needed
|
||||
this.checkForUpdates().then(() => {
|
||||
// Ensure badges are updated after checking
|
||||
this.updateBadgeVisibility();
|
||||
});
|
||||
|
||||
// Set up event listener for update button
|
||||
// const updateToggle = document.getElementById('updateToggleBtn');
|
||||
// if (updateToggle) {
|
||||
// updateToggle.addEventListener('click', () => this.toggleUpdateModal());
|
||||
// }
|
||||
|
||||
// Immediately update modal content with current values (even if from default)
|
||||
this.updateModalContent();
|
||||
}
|
||||
|
||||
updateNightlyWarning() {
|
||||
const warning = document.getElementById('nightlyWarning');
|
||||
if (warning) {
|
||||
warning.style.display = this.nightlyMode ? 'flex' : 'none';
|
||||
}
|
||||
}
|
||||
|
||||
async checkForUpdates() {
|
||||
// Check if we should perform an update check
|
||||
const now = Date.now();
|
||||
@@ -59,8 +82,8 @@ export class UpdateService {
|
||||
}
|
||||
|
||||
try {
|
||||
// Call backend API to check for updates
|
||||
const response = await fetch('/api/check-updates');
|
||||
// Call backend API to check for updates with nightly flag
|
||||
const response = await fetch(`/api/check-updates?nightly=${this.nightlyMode}`);
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {
|
||||
@@ -137,8 +160,8 @@ export class UpdateService {
|
||||
const shouldShow = this.updateNotificationsEnabled && this.updateAvailable;
|
||||
|
||||
if (updateBadge) {
|
||||
updateBadge.classList.toggle('hidden', !shouldShow);
|
||||
console.log("Update badge visibility:", !shouldShow ? "hidden" : "visible");
|
||||
updateBadge.classList.toggle('visible', shouldShow);
|
||||
console.log("Update badge visibility:", shouldShow ? "visible" : "hidden");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,7 +180,17 @@ export class UpdateService {
|
||||
const newVersionEl = modal.querySelector('.new-version .version-number');
|
||||
|
||||
if (currentVersionEl) currentVersionEl.textContent = this.currentVersion;
|
||||
if (newVersionEl) newVersionEl.textContent = this.latestVersion;
|
||||
|
||||
if (newVersionEl) {
|
||||
newVersionEl.textContent = this.latestVersion;
|
||||
}
|
||||
|
||||
// Update update button state
|
||||
const updateBtn = modal.querySelector('#updateBtn');
|
||||
if (updateBtn) {
|
||||
updateBtn.classList.toggle('disabled', !this.updateAvailable || this.isUpdating);
|
||||
updateBtn.disabled = !this.updateAvailable || this.isUpdating;
|
||||
}
|
||||
|
||||
// Update git info
|
||||
const gitInfoEl = modal.querySelector('.git-info');
|
||||
@@ -218,6 +251,131 @@ export class UpdateService {
|
||||
}
|
||||
}
|
||||
|
||||
async performUpdate() {
|
||||
if (!this.updateAvailable || this.isUpdating) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
this.isUpdating = true;
|
||||
this.updateUpdateUI('updating', 'Updating...');
|
||||
this.showUpdateProgress(true);
|
||||
|
||||
// Update progress
|
||||
this.updateProgress(10, 'Preparing update...');
|
||||
|
||||
const response = await fetch('/api/perform-update', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
nightly: this.nightlyMode
|
||||
})
|
||||
});
|
||||
|
||||
this.updateProgress(50, 'Installing update...');
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {
|
||||
this.updateProgress(100, 'Update completed successfully!');
|
||||
this.updateUpdateUI('success', 'Updated!');
|
||||
|
||||
// Show success message and suggest restart
|
||||
setTimeout(() => {
|
||||
this.showUpdateCompleteMessage(data.new_version);
|
||||
}, 1000);
|
||||
|
||||
} else {
|
||||
throw new Error(data.error || 'Update failed');
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('Update failed:', error);
|
||||
this.updateUpdateUI('error', 'Update Failed');
|
||||
this.updateProgress(0, `Update failed: ${error.message}`);
|
||||
|
||||
// Hide progress after error
|
||||
setTimeout(() => {
|
||||
this.showUpdateProgress(false);
|
||||
}, 3000);
|
||||
} finally {
|
||||
this.isUpdating = false;
|
||||
}
|
||||
}
|
||||
|
||||
updateUpdateUI(state, text) {
|
||||
const updateBtn = document.getElementById('updateBtn');
|
||||
const updateBtnText = document.getElementById('updateBtnText');
|
||||
|
||||
if (updateBtn && updateBtnText) {
|
||||
// Remove existing state classes
|
||||
updateBtn.classList.remove('updating', 'success', 'error', 'disabled');
|
||||
|
||||
// Add new state class
|
||||
if (state !== 'normal') {
|
||||
updateBtn.classList.add(state);
|
||||
}
|
||||
|
||||
// Update button text
|
||||
updateBtnText.textContent = text;
|
||||
|
||||
// Update disabled state
|
||||
updateBtn.disabled = (state === 'updating' || state === 'disabled');
|
||||
}
|
||||
}
|
||||
|
||||
showUpdateProgress(show) {
|
||||
const progressContainer = document.getElementById('updateProgress');
|
||||
if (progressContainer) {
|
||||
progressContainer.style.display = show ? 'block' : 'none';
|
||||
}
|
||||
}
|
||||
|
||||
updateProgress(percentage, text) {
|
||||
const progressFill = document.getElementById('updateProgressFill');
|
||||
const progressText = document.getElementById('updateProgressText');
|
||||
|
||||
if (progressFill) {
|
||||
progressFill.style.width = `${percentage}%`;
|
||||
}
|
||||
|
||||
if (progressText) {
|
||||
progressText.textContent = text;
|
||||
}
|
||||
}
|
||||
|
||||
showUpdateCompleteMessage(newVersion) {
|
||||
const modal = document.getElementById('updateModal');
|
||||
if (!modal) return;
|
||||
|
||||
// Update the modal content to show completion
|
||||
const progressText = document.getElementById('updateProgressText');
|
||||
if (progressText) {
|
||||
progressText.innerHTML = `
|
||||
<div style="text-align: center; color: var(--lora-success);">
|
||||
<i class="fas fa-check-circle" style="margin-right: 8px;"></i>
|
||||
Successfully updated to ${newVersion}!
|
||||
<br><br>
|
||||
<small style="opacity: 0.8;">
|
||||
Please restart ComfyUI to complete the update process.
|
||||
</small>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
// Update current version display
|
||||
this.currentVersion = newVersion;
|
||||
this.updateAvailable = false;
|
||||
|
||||
// Refresh the modal content
|
||||
setTimeout(() => {
|
||||
this.updateModalContent();
|
||||
this.showUpdateProgress(false);
|
||||
}, 2000);
|
||||
}
|
||||
|
||||
// Simple markdown parser for changelog items
|
||||
parseMarkdown(text) {
|
||||
if (!text) return '';
|
||||
|
||||
@@ -99,7 +99,7 @@ export class FolderBrowser {
|
||||
}
|
||||
|
||||
// Fetch LoRA roots
|
||||
const rootsResponse = await fetch('/api/lora-roots');
|
||||
const rootsResponse = await fetch('/api/loras/roots');
|
||||
if (!rootsResponse.ok) {
|
||||
throw new Error(`Failed to fetch LoRA roots: ${rootsResponse.status}`);
|
||||
}
|
||||
@@ -119,7 +119,7 @@ export class FolderBrowser {
|
||||
}
|
||||
|
||||
// Fetch folders
|
||||
const foldersResponse = await fetch('/api/folders');
|
||||
const foldersResponse = await fetch('/api/loras/folders');
|
||||
if (!foldersResponse.ok) {
|
||||
throw new Error(`Failed to fetch folders: ${foldersResponse.status}`);
|
||||
}
|
||||
|
||||
@@ -11,8 +11,18 @@
|
||||
<div class="action-buttons">
|
||||
<div title="Sort models by..." class="control-group">
|
||||
<select id="sortSelect">
|
||||
<option value="name">Name</option>
|
||||
<option value="date">Date</option>
|
||||
<optgroup label="Name">
|
||||
<option value="name:asc">A - Z</option>
|
||||
<option value="name:desc">Z - A</option>
|
||||
</optgroup>
|
||||
<optgroup label="Date Added">
|
||||
<option value="date:desc">Newest</option>
|
||||
<option value="date:asc">Oldest</option>
|
||||
</optgroup>
|
||||
<optgroup label="File Size">
|
||||
<option value="size:desc">Largest</option>
|
||||
<option value="size:asc">Smallest</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
</div>
|
||||
<div title="Refresh model list" class="control-group dropdown-group">
|
||||
|
||||
@@ -53,7 +53,7 @@
|
||||
</div>
|
||||
<div class="update-toggle" id="updateToggleBtn" title="Check Updates">
|
||||
<i class="fas fa-bell"></i>
|
||||
<span class="update-badge hidden"></span>
|
||||
<span class="update-badge"></span>
|
||||
</div>
|
||||
<div class="support-toggle" id="supportToggleBtn" title="Support">
|
||||
<i class="fas fa-heart"></i>
|
||||
|
||||
@@ -476,9 +476,26 @@
|
||||
<span class="version-number">v0.0.0</span>
|
||||
</div>
|
||||
</div>
|
||||
<a href="https://github.com/willmiao/ComfyUI-Lora-Manager" target="_blank" class="update-link">
|
||||
<i class="fas fa-external-link-alt"></i> View on GitHub
|
||||
</a>
|
||||
|
||||
<div class="update-actions">
|
||||
<a href="https://github.com/willmiao/ComfyUI-Lora-Manager" target="_blank" class="update-link">
|
||||
<i class="fas fa-external-link-alt"></i> View on GitHub
|
||||
</a>
|
||||
<button id="updateBtn" class="primary-btn disabled">
|
||||
<i class="fas fa-download"></i>
|
||||
<span id="updateBtnText">Update Now</span>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Update Progress Section -->
|
||||
<div class="update-progress" id="updateProgress" style="display: none;">
|
||||
<div class="progress-info">
|
||||
<div class="progress-text" id="updateProgressText">Preparing update...</div>
|
||||
<div class="progress-bar">
|
||||
<div class="progress-fill" id="updateProgressFill"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="changelog-section">
|
||||
|
||||
Reference in New Issue
Block a user