mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
@@ -1,71 +0,0 @@
|
|||||||
# Path mappings configuration for ComfyUI-Lora-Manager
|
|
||||||
# This file allows you to customize how base models and Civitai tags map to directories when downloading models
|
|
||||||
|
|
||||||
# Base model mappings
|
|
||||||
# Format: "Original Base Model": "Custom Directory Name"
|
|
||||||
#
|
|
||||||
# Example: If you change "Flux.1 D": "flux"
|
|
||||||
# Then models with base model "Flux.1 D" will be stored in a directory named "flux"
|
|
||||||
# So the final path would be: <model_directory>/flux/<tag>/model_file.safetensors
|
|
||||||
base_models:
|
|
||||||
"SD 1.4": "SD 1.4"
|
|
||||||
"SD 1.5": "SD 1.5"
|
|
||||||
"SD 1.5 LCM": "SD 1.5 LCM"
|
|
||||||
"SD 1.5 Hyper": "SD 1.5 Hyper"
|
|
||||||
"SD 2.0": "SD 2.0"
|
|
||||||
"SD 2.1": "SD 2.1"
|
|
||||||
"SDXL 1.0": "SDXL 1.0"
|
|
||||||
"SD 3": "SD 3"
|
|
||||||
"SD 3.5": "SD 3.5"
|
|
||||||
"SD 3.5 Medium": "SD 3.5 Medium"
|
|
||||||
"SD 3.5 Large": "SD 3.5 Large"
|
|
||||||
"SD 3.5 Large Turbo": "SD 3.5 Large Turbo"
|
|
||||||
"Pony": "Pony"
|
|
||||||
"Flux.1 S": "Flux.1 S"
|
|
||||||
"Flux.1 D": "Flux.1 D"
|
|
||||||
"Flux.1 Kontext": "Flux.1 Kontext"
|
|
||||||
"AuraFlow": "AuraFlow"
|
|
||||||
"SDXL Lightning": "SDXL Lightning"
|
|
||||||
"SDXL Hyper": "SDXL Hyper"
|
|
||||||
"Stable Cascade": "Stable Cascade"
|
|
||||||
"SVD": "SVD"
|
|
||||||
"PixArt a": "PixArt a"
|
|
||||||
"PixArt E": "PixArt E"
|
|
||||||
"Hunyuan 1": "Hunyuan 1"
|
|
||||||
"Hunyuan Video": "Hunyuan Video"
|
|
||||||
"Lumina": "Lumina"
|
|
||||||
"Kolors": "Kolors"
|
|
||||||
"Illustrious": "Illustrious"
|
|
||||||
"Mochi": "Mochi"
|
|
||||||
"LTXV": "LTXV"
|
|
||||||
"CogVideoX": "CogVideoX"
|
|
||||||
"NoobAI": "NoobAI"
|
|
||||||
"Wan Video": "Wan Video"
|
|
||||||
"Wan Video 1.3B t2v": "Wan Video 1.3B t2v"
|
|
||||||
"Wan Video 14B t2v": "Wan Video 14B t2v"
|
|
||||||
"Wan Video 14B i2v 480p": "Wan Video 14B i2v 480p"
|
|
||||||
"Wan Video 14B i2v 720p": "Wan Video 14B i2v 720p"
|
|
||||||
"HiDream": "HiDream"
|
|
||||||
"Other": "Other"
|
|
||||||
|
|
||||||
# Civitai model tag mappings
|
|
||||||
# Format: "Original Tag": "Custom Directory Name"
|
|
||||||
#
|
|
||||||
# Example: If you change "character": "characters"
|
|
||||||
# Then models with tag "character" will be stored in a directory named "characters"
|
|
||||||
# So the final path would be: <model_directory>/<base_model>/characters/model_file.safetensors
|
|
||||||
model_tags:
|
|
||||||
"character": "character"
|
|
||||||
"style": "style"
|
|
||||||
"concept": "concept"
|
|
||||||
"clothing": "clothing"
|
|
||||||
"base model": "base model"
|
|
||||||
"poses": "poses"
|
|
||||||
"background": "background"
|
|
||||||
"tool": "tool"
|
|
||||||
"vehicle": "vehicle"
|
|
||||||
"buildings": "buildings"
|
|
||||||
"objects": "objects"
|
|
||||||
"assets": "assets"
|
|
||||||
"animal": "animal"
|
|
||||||
"action": "action"
|
|
||||||
@@ -6,10 +6,8 @@ from pathlib import Path
|
|||||||
from server import PromptServer # type: ignore
|
from server import PromptServer # type: ignore
|
||||||
|
|
||||||
from .config import config
|
from .config import config
|
||||||
from .routes.lora_routes import LoraRoutes
|
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||||
from .routes.api_routes import ApiRoutes
|
|
||||||
from .routes.recipe_routes import RecipeRoutes
|
from .routes.recipe_routes import RecipeRoutes
|
||||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
|
||||||
from .routes.stats_routes import StatsRoutes
|
from .routes.stats_routes import StatsRoutes
|
||||||
from .routes.update_routes import UpdateRoutes
|
from .routes.update_routes import UpdateRoutes
|
||||||
from .routes.misc_routes import MiscRoutes
|
from .routes.misc_routes import MiscRoutes
|
||||||
@@ -17,6 +15,7 @@ from .routes.example_images_routes import ExampleImagesRoutes
|
|||||||
from .services.service_registry import ServiceRegistry
|
from .services.service_registry import ServiceRegistry
|
||||||
from .services.settings_manager import settings
|
from .services.settings_manager import settings
|
||||||
from .utils.example_images_migration import ExampleImagesMigration
|
from .utils.example_images_migration import ExampleImagesMigration
|
||||||
|
from .services.websocket_manager import ws_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -28,12 +27,28 @@ class LoraManager:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_routes(cls):
|
def add_routes(cls):
|
||||||
"""Initialize and register all routes"""
|
"""Initialize and register all routes using the new refactored architecture"""
|
||||||
app = PromptServer.instance.app
|
app = PromptServer.instance.app
|
||||||
|
|
||||||
# Configure aiohttp access logger to be less verbose
|
# Configure aiohttp access logger to be less verbose
|
||||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
# Add specific suppression for connection reset errors
|
||||||
|
class ConnectionResetFilter(logging.Filter):
|
||||||
|
def filter(self, record):
|
||||||
|
# Filter out connection reset errors that are not critical
|
||||||
|
if "ConnectionResetError" in str(record.getMessage()):
|
||||||
|
return False
|
||||||
|
if "_call_connection_lost" in str(record.getMessage()):
|
||||||
|
return False
|
||||||
|
if "WinError 10054" in str(record.getMessage()):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Apply the filter to asyncio logger
|
||||||
|
asyncio_logger = logging.getLogger("asyncio")
|
||||||
|
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||||
|
|
||||||
added_targets = set() # Track already added target paths
|
added_targets = set() # Track already added target paths
|
||||||
|
|
||||||
# Add static route for example images if the path exists in settings
|
# Add static route for example images if the path exists in settings
|
||||||
@@ -110,35 +125,37 @@ class LoraManager:
|
|||||||
# Add static route for plugin assets
|
# Add static route for plugin assets
|
||||||
app.router.add_static('/loras_static', config.static_path)
|
app.router.add_static('/loras_static', config.static_path)
|
||||||
|
|
||||||
# Setup feature routes
|
# Register default model types with the factory
|
||||||
lora_routes = LoraRoutes()
|
register_default_model_types()
|
||||||
checkpoints_routes = CheckpointsRoutes()
|
|
||||||
stats_routes = StatsRoutes()
|
|
||||||
|
|
||||||
# Initialize routes
|
# Setup all model routes using the factory
|
||||||
lora_routes.setup_routes(app)
|
ModelServiceFactory.setup_all_routes(app)
|
||||||
checkpoints_routes.setup_routes(app)
|
|
||||||
stats_routes.setup_routes(app) # Add statistics routes
|
# Setup non-model-specific routes
|
||||||
ApiRoutes.setup_routes(app)
|
stats_routes = StatsRoutes()
|
||||||
|
stats_routes.setup_routes(app)
|
||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
UpdateRoutes.setup_routes(app)
|
UpdateRoutes.setup_routes(app)
|
||||||
MiscRoutes.setup_routes(app) # Register miscellaneous routes
|
MiscRoutes.setup_routes(app)
|
||||||
ExampleImagesRoutes.setup_routes(app) # Register example images routes
|
ExampleImagesRoutes.setup_routes(app)
|
||||||
|
|
||||||
|
# Setup WebSocket routes that are shared across all model types
|
||||||
|
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||||
|
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||||
|
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||||
|
|
||||||
# Schedule service initialization
|
# Schedule service initialization
|
||||||
app.on_startup.append(lambda app: cls._initialize_services())
|
app.on_startup.append(lambda app: cls._initialize_services())
|
||||||
|
|
||||||
# Add cleanup
|
# Add cleanup
|
||||||
app.on_shutdown.append(cls._cleanup)
|
app.on_shutdown.append(cls._cleanup)
|
||||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
|
||||||
|
logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _initialize_services(cls):
|
async def _initialize_services(cls):
|
||||||
"""Initialize all services using the ServiceRegistry"""
|
"""Initialize all services using the ServiceRegistry"""
|
||||||
try:
|
try:
|
||||||
# Ensure aiohttp access logger is configured with reduced verbosity
|
|
||||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||||
await ServiceRegistry.get_civitai_client()
|
await ServiceRegistry.get_civitai_client()
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
619
py/routes/base_model_routes.py
Normal file
619
py/routes/base_model_routes.py
Normal file
@@ -0,0 +1,619 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from aiohttp import web
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
|
||||||
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
|
from ..services.websocket_manager import ws_manager
|
||||||
|
from ..services.settings_manager import settings
|
||||||
|
from ..config import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class BaseModelRoutes(ABC):
|
||||||
|
"""Base route controller for all model types"""
|
||||||
|
|
||||||
|
def __init__(self, service):
|
||||||
|
"""Initialize the route controller
|
||||||
|
|
||||||
|
Args:
|
||||||
|
service: Model service instance (LoraService, CheckpointService, etc.)
|
||||||
|
"""
|
||||||
|
self.service = service
|
||||||
|
self.model_type = service.model_type
|
||||||
|
self.template_env = jinja2.Environment(
|
||||||
|
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||||
|
autoescape=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def setup_routes(self, app: web.Application, prefix: str):
|
||||||
|
"""Setup common routes for the model type
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: aiohttp application
|
||||||
|
prefix: URL prefix (e.g., 'loras', 'checkpoints')
|
||||||
|
"""
|
||||||
|
# Common model management routes
|
||||||
|
app.router.add_get(f'/api/{prefix}', self.get_models)
|
||||||
|
app.router.add_post(f'/api/{prefix}/delete', self.delete_model)
|
||||||
|
app.router.add_post(f'/api/{prefix}/exclude', self.exclude_model)
|
||||||
|
app.router.add_post(f'/api/{prefix}/fetch-civitai', self.fetch_civitai)
|
||||||
|
app.router.add_post(f'/api/{prefix}/relink-civitai', self.relink_civitai)
|
||||||
|
app.router.add_post(f'/api/{prefix}/replace-preview', self.replace_preview)
|
||||||
|
app.router.add_post(f'/api/{prefix}/save-metadata', self.save_metadata)
|
||||||
|
app.router.add_post(f'/api/{prefix}/rename', self.rename_model)
|
||||||
|
app.router.add_post(f'/api/{prefix}/bulk-delete', self.bulk_delete_models)
|
||||||
|
app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates)
|
||||||
|
|
||||||
|
# Common query routes
|
||||||
|
app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags)
|
||||||
|
app.router.add_get(f'/api/{prefix}/base-models', self.get_base_models)
|
||||||
|
app.router.add_get(f'/api/{prefix}/scan', self.scan_models)
|
||||||
|
app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots)
|
||||||
|
app.router.add_get(f'/api/{prefix}/folders', self.get_folders)
|
||||||
|
app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models)
|
||||||
|
app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts)
|
||||||
|
|
||||||
|
# Common Download management
|
||||||
|
app.router.add_post(f'/api/download-model', self.download_model)
|
||||||
|
app.router.add_get(f'/api/download-model-get', self.download_model_get)
|
||||||
|
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
|
||||||
|
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
|
||||||
|
|
||||||
|
# CivitAI integration routes
|
||||||
|
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
|
||||||
|
# app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
|
||||||
|
|
||||||
|
# Add generic page route
|
||||||
|
app.router.add_get(f'/{prefix}', self.handle_models_page)
|
||||||
|
|
||||||
|
# Setup model-specific routes
|
||||||
|
self.setup_specific_routes(app, prefix)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||||
|
"""Setup model-specific routes - to be implemented by subclasses"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def handle_models_page(self, request: web.Request) -> web.Response:
|
||||||
|
"""
|
||||||
|
Generic handler for model pages (e.g., /loras, /checkpoints).
|
||||||
|
Subclasses should set self.template_env and template_name.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check if the scanner is initializing
|
||||||
|
is_initializing = (
|
||||||
|
self.service.scanner._cache is None or
|
||||||
|
(hasattr(self.service.scanner, 'is_initializing') and callable(self.service.scanner.is_initializing) and self.service.scanner.is_initializing()) or
|
||||||
|
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
|
||||||
|
)
|
||||||
|
|
||||||
|
template_name = getattr(self, "template_name", None)
|
||||||
|
if not self.template_env or not template_name:
|
||||||
|
return web.Response(text="Template environment or template name not set", status=500)
|
||||||
|
|
||||||
|
if is_initializing:
|
||||||
|
rendered = self.template_env.get_template(template_name).render(
|
||||||
|
folders=[],
|
||||||
|
is_initializing=True,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
cache = await self.service.scanner.get_cached_data(force_refresh=False)
|
||||||
|
rendered = self.template_env.get_template(template_name).render(
|
||||||
|
folders=getattr(cache, "folders", []),
|
||||||
|
is_initializing=False,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
except Exception as cache_error:
|
||||||
|
logger.error(f"Error loading cache data: {cache_error}")
|
||||||
|
rendered = self.template_env.get_template(template_name).render(
|
||||||
|
folders=[],
|
||||||
|
is_initializing=True,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
return web.Response(
|
||||||
|
text=rendered,
|
||||||
|
content_type='text/html'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling models page: {e}", exc_info=True)
|
||||||
|
return web.Response(
|
||||||
|
text="Error loading models page",
|
||||||
|
status=500
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get paginated model data"""
|
||||||
|
try:
|
||||||
|
# Parse common query parameters
|
||||||
|
params = self._parse_common_params(request)
|
||||||
|
|
||||||
|
# Get data from service
|
||||||
|
result = await self.service.get_paginated_data(**params)
|
||||||
|
|
||||||
|
# Format response items
|
||||||
|
formatted_result = {
|
||||||
|
'items': [await self.service.format_response(item) for item in result['items']],
|
||||||
|
'total': result['total'],
|
||||||
|
'page': result['page'],
|
||||||
|
'page_size': result['page_size'],
|
||||||
|
'total_pages': result['total_pages']
|
||||||
|
}
|
||||||
|
|
||||||
|
return web.json_response(formatted_result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_{self.model_type}s: {e}", exc_info=True)
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
def _parse_common_params(self, request: web.Request) -> Dict:
|
||||||
|
"""Parse common query parameters"""
|
||||||
|
# Parse basic pagination and sorting
|
||||||
|
page = int(request.query.get('page', '1'))
|
||||||
|
page_size = min(int(request.query.get('page_size', '20')), 100)
|
||||||
|
sort_by = request.query.get('sort_by', 'name')
|
||||||
|
folder = request.query.get('folder', None)
|
||||||
|
search = request.query.get('search', None)
|
||||||
|
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
||||||
|
|
||||||
|
# Parse filter arrays
|
||||||
|
base_models = request.query.getall('base_model', [])
|
||||||
|
tags = request.query.getall('tag', [])
|
||||||
|
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true'
|
||||||
|
|
||||||
|
# Parse search options
|
||||||
|
search_options = {
|
||||||
|
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
||||||
|
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
||||||
|
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
||||||
|
'recursive': request.query.get('recursive', 'false').lower() == 'true',
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse hash filters if provided
|
||||||
|
hash_filters = {}
|
||||||
|
if 'hash' in request.query:
|
||||||
|
hash_filters['single_hash'] = request.query['hash']
|
||||||
|
elif 'hashes' in request.query:
|
||||||
|
try:
|
||||||
|
hash_list = json.loads(request.query['hashes'])
|
||||||
|
if isinstance(hash_list, list):
|
||||||
|
hash_filters['multiple_hashes'] = hash_list
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {
|
||||||
|
'page': page,
|
||||||
|
'page_size': page_size,
|
||||||
|
'sort_by': sort_by,
|
||||||
|
'folder': folder,
|
||||||
|
'search': search,
|
||||||
|
'fuzzy_search': fuzzy_search,
|
||||||
|
'base_models': base_models,
|
||||||
|
'tags': tags,
|
||||||
|
'search_options': search_options,
|
||||||
|
'hash_filters': hash_filters,
|
||||||
|
'favorites_only': favorites_only,
|
||||||
|
# Add model-specific parameters
|
||||||
|
**self._parse_specific_params(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||||
|
"""Parse model-specific parameters - to be overridden by subclasses"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Common route handlers
|
||||||
|
async def delete_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model deletion request"""
|
||||||
|
return await ModelRouteUtils.handle_delete_model(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model exclusion request"""
|
||||||
|
return await ModelRouteUtils.handle_exclude_model(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle CivitAI metadata fetch request"""
|
||||||
|
response = await ModelRouteUtils.handle_fetch_civitai(request, self.service.scanner)
|
||||||
|
|
||||||
|
# If successful, format the metadata before returning
|
||||||
|
if response.status == 200:
|
||||||
|
data = json.loads(response.body.decode('utf-8'))
|
||||||
|
if data.get("success") and data.get("metadata"):
|
||||||
|
formatted_metadata = await self.service.format_response(data["metadata"])
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"metadata": formatted_metadata
|
||||||
|
})
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def relink_civitai(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle CivitAI metadata re-linking request"""
|
||||||
|
return await ModelRouteUtils.handle_relink_civitai(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle preview image replacement"""
|
||||||
|
return await ModelRouteUtils.handle_replace_preview(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle saving metadata updates"""
|
||||||
|
return await ModelRouteUtils.handle_save_metadata(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def rename_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle renaming a model file and its associated files"""
|
||||||
|
return await ModelRouteUtils.handle_rename_model(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def bulk_delete_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle bulk deletion of models"""
|
||||||
|
return await ModelRouteUtils.handle_bulk_delete_models(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle verification of duplicate model hashes"""
|
||||||
|
return await ModelRouteUtils.handle_verify_duplicates(request, self.service.scanner)
|
||||||
|
|
||||||
|
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle request for top tags sorted by frequency"""
|
||||||
|
try:
|
||||||
|
limit = int(request.query.get('limit', '20'))
|
||||||
|
if limit < 1 or limit > 100:
|
||||||
|
limit = 20
|
||||||
|
|
||||||
|
top_tags = await self.service.get_top_tags(limit)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'tags': top_tags
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Internal server error'
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get base models used in models"""
|
||||||
|
try:
|
||||||
|
limit = int(request.query.get('limit', '20'))
|
||||||
|
if limit < 1 or limit > 100:
|
||||||
|
limit = 20
|
||||||
|
|
||||||
|
base_models = await self.service.get_base_models(limit)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'base_models': base_models
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error retrieving base models: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def scan_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Force a rescan of model files"""
|
||||||
|
try:
|
||||||
|
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
|
||||||
|
|
||||||
|
await self.service.scan_models(force_refresh=True, rebuild_cache=full_rebuild)
|
||||||
|
return web.json_response({
|
||||||
|
"status": "success",
|
||||||
|
"message": f"{self.model_type.capitalize()} scan completed"
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in scan_{self.model_type}s: {e}", exc_info=True)
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def get_model_roots(self, request: web.Request) -> web.Response:
|
||||||
|
"""Return the model root directories"""
|
||||||
|
try:
|
||||||
|
roots = self.service.get_model_roots()
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"roots": roots
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting {self.model_type} roots: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_folders(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get all folders in the cache"""
|
||||||
|
try:
|
||||||
|
cache = await self.service.scanner.get_cached_data()
|
||||||
|
return web.json_response({
|
||||||
|
'folders': cache.folders
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting folders: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def find_duplicate_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Find models with duplicate SHA256 hashes"""
|
||||||
|
try:
|
||||||
|
# Get duplicate hashes from service
|
||||||
|
duplicates = self.service.find_duplicate_hashes()
|
||||||
|
|
||||||
|
# Format the response
|
||||||
|
result = []
|
||||||
|
cache = await self.service.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for sha256, paths in duplicates.items():
|
||||||
|
group = {
|
||||||
|
"hash": sha256,
|
||||||
|
"models": []
|
||||||
|
}
|
||||||
|
# Find matching models for each path
|
||||||
|
for path in paths:
|
||||||
|
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||||
|
if model:
|
||||||
|
group["models"].append(await self.service.format_response(model))
|
||||||
|
|
||||||
|
# Add the primary model too
|
||||||
|
primary_path = self.service.get_path_by_hash(sha256)
|
||||||
|
if primary_path and primary_path not in paths:
|
||||||
|
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
|
||||||
|
if primary_model:
|
||||||
|
group["models"].insert(0, await self.service.format_response(primary_model))
|
||||||
|
|
||||||
|
if len(group["models"]) > 1: # Only include if we found multiple models
|
||||||
|
result.append(group)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"duplicates": result,
|
||||||
|
"count": len(result)
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error finding duplicate {self.model_type}s: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
|
||||||
|
"""Find models with conflicting filenames"""
|
||||||
|
try:
|
||||||
|
# Get duplicate filenames from service
|
||||||
|
duplicates = self.service.find_duplicate_filenames()
|
||||||
|
|
||||||
|
# Format the response
|
||||||
|
result = []
|
||||||
|
cache = await self.service.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for filename, paths in duplicates.items():
|
||||||
|
group = {
|
||||||
|
"filename": filename,
|
||||||
|
"models": []
|
||||||
|
}
|
||||||
|
# Find matching models for each path
|
||||||
|
for path in paths:
|
||||||
|
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||||
|
if model:
|
||||||
|
group["models"].append(await self.service.format_response(model))
|
||||||
|
|
||||||
|
# Find the model from the main index too
|
||||||
|
hash_val = self.service.scanner._hash_index.get_hash_by_filename(filename)
|
||||||
|
if hash_val:
|
||||||
|
main_path = self.service.get_path_by_hash(hash_val)
|
||||||
|
if main_path and main_path not in paths:
|
||||||
|
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
|
||||||
|
if main_model:
|
||||||
|
group["models"].insert(0, await self.service.format_response(main_model))
|
||||||
|
|
||||||
|
if group["models"]:
|
||||||
|
result.append(group)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"conflicts": result,
|
||||||
|
"count": len(result)
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error finding filename conflicts for {self.model_type}s: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
# Download management methods
|
||||||
|
async def download_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model download request"""
|
||||||
|
return await ModelRouteUtils.handle_download_model(request)
|
||||||
|
|
||||||
|
async def download_model_get(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model download request via GET method"""
|
||||||
|
try:
|
||||||
|
# Extract query parameters
|
||||||
|
model_id = request.query.get('model_id')
|
||||||
|
if not model_id:
|
||||||
|
return web.Response(
|
||||||
|
status=400,
|
||||||
|
text="Missing required parameter: Please provide 'model_id'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get optional parameters
|
||||||
|
model_version_id = request.query.get('model_version_id')
|
||||||
|
download_id = request.query.get('download_id')
|
||||||
|
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
|
||||||
|
|
||||||
|
# Create a data dictionary that mimics what would be received from a POST request
|
||||||
|
data = {
|
||||||
|
'model_id': model_id
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional parameters only if they are provided
|
||||||
|
if model_version_id:
|
||||||
|
data['model_version_id'] = model_version_id
|
||||||
|
|
||||||
|
if download_id:
|
||||||
|
data['download_id'] = download_id
|
||||||
|
|
||||||
|
data['use_default_paths'] = use_default_paths
|
||||||
|
|
||||||
|
# Create a mock request object with the data
|
||||||
|
future = asyncio.get_event_loop().create_future()
|
||||||
|
future.set_result(data)
|
||||||
|
|
||||||
|
mock_request = type('MockRequest', (), {
|
||||||
|
'json': lambda self=None: future
|
||||||
|
})()
|
||||||
|
|
||||||
|
# Call the existing download handler
|
||||||
|
return await ModelRouteUtils.handle_download_model(mock_request)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_message = str(e)
|
||||||
|
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
|
||||||
|
return web.Response(status=500, text=error_message)
|
||||||
|
|
||||||
|
async def cancel_download_get(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle GET request for cancelling a download by download_id"""
|
||||||
|
try:
|
||||||
|
download_id = request.query.get('download_id')
|
||||||
|
if not download_id:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Download ID is required'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Create a mock request with match_info for compatibility
|
||||||
|
mock_request = type('MockRequest', (), {
|
||||||
|
'match_info': {'download_id': download_id}
|
||||||
|
})()
|
||||||
|
return await ModelRouteUtils.handle_cancel_download(mock_request)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error cancelling download via GET: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_download_progress(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle request for download progress by download_id"""
|
||||||
|
try:
|
||||||
|
# Get download_id from URL path
|
||||||
|
download_id = request.match_info.get('download_id')
|
||||||
|
if not download_id:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Download ID is required'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
progress_data = ws_manager.get_download_progress(download_id)
|
||||||
|
|
||||||
|
if progress_data is None:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Download ID not found'
|
||||||
|
}, status=404)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'progress': progress_data.get('progress', 0)
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting download progress: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||||
|
"""Fetch CivitAI metadata for all models in the background"""
|
||||||
|
try:
|
||||||
|
cache = await self.service.scanner.get_cached_data()
|
||||||
|
total = len(cache.raw_data)
|
||||||
|
processed = 0
|
||||||
|
success = 0
|
||||||
|
needs_resort = False
|
||||||
|
|
||||||
|
# Prepare models to process
|
||||||
|
to_process = [
|
||||||
|
model for model in cache.raw_data
|
||||||
|
if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) and model.get('from_civitai', True)
|
||||||
|
]
|
||||||
|
total_to_process = len(to_process)
|
||||||
|
|
||||||
|
# Send initial progress
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'started',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': 0,
|
||||||
|
'success': 0
|
||||||
|
})
|
||||||
|
|
||||||
|
# Process each model
|
||||||
|
for model in to_process:
|
||||||
|
try:
|
||||||
|
original_name = model.get('model_name')
|
||||||
|
if await ModelRouteUtils.fetch_and_update_model(
|
||||||
|
sha256=model['sha256'],
|
||||||
|
file_path=model['file_path'],
|
||||||
|
model_data=model,
|
||||||
|
update_cache_func=self.service.scanner.update_single_model_cache
|
||||||
|
):
|
||||||
|
success += 1
|
||||||
|
if original_name != model.get('model_name'):
|
||||||
|
needs_resort = True
|
||||||
|
|
||||||
|
processed += 1
|
||||||
|
|
||||||
|
# Send progress update
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'processing',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': processed,
|
||||||
|
'success': success,
|
||||||
|
'current_name': model.get('model_name', 'Unknown')
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching CivitAI data for {model['file_path']}: {e}")
|
||||||
|
|
||||||
|
if needs_resort:
|
||||||
|
await cache.resort()
|
||||||
|
|
||||||
|
# Send completion message
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'completed',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': processed,
|
||||||
|
'success': success
|
||||||
|
})
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"message": f"Successfully updated {success} of {processed} processed {self.model_type}s (total: {total})"
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Send error message
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'error',
|
||||||
|
'error': str(e)
|
||||||
|
})
|
||||||
|
logger.error(f"Error in fetch_all_civitai for {self.model_type}s: {e}")
|
||||||
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
|
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get available versions for a Civitai model with local availability info"""
|
||||||
|
# This will be implemented by subclasses as they need CivitAI client access
|
||||||
|
return web.json_response({
|
||||||
|
"error": "Not implemented in base class"
|
||||||
|
}, status=501)
|
||||||
105
py/routes/checkpoint_routes.py
Normal file
105
py/routes/checkpoint_routes.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
import logging
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from .base_model_routes import BaseModelRoutes
|
||||||
|
from ..services.checkpoint_service import CheckpointService
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class CheckpointRoutes(BaseModelRoutes):
|
||||||
|
"""Checkpoint-specific route controller"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize Checkpoint routes with Checkpoint service"""
|
||||||
|
# Service will be initialized later via setup_routes
|
||||||
|
self.service = None
|
||||||
|
self.civitai_client = None
|
||||||
|
self.template_name = "checkpoints.html"
|
||||||
|
|
||||||
|
async def initialize_services(self):
|
||||||
|
"""Initialize services from ServiceRegistry"""
|
||||||
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
self.service = CheckpointService(checkpoint_scanner)
|
||||||
|
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||||
|
|
||||||
|
# Initialize parent with the service
|
||||||
|
super().__init__(self.service)
|
||||||
|
|
||||||
|
def setup_routes(self, app: web.Application):
|
||||||
|
"""Setup Checkpoint routes"""
|
||||||
|
# Schedule service initialization on app startup
|
||||||
|
app.on_startup.append(lambda _: self.initialize_services())
|
||||||
|
|
||||||
|
# Setup common routes with 'checkpoints' prefix (includes page route)
|
||||||
|
super().setup_routes(app, 'checkpoints')
|
||||||
|
|
||||||
|
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||||
|
"""Setup Checkpoint-specific routes"""
|
||||||
|
# Checkpoint-specific CivitAI integration
|
||||||
|
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
|
||||||
|
|
||||||
|
# Checkpoint info by name
|
||||||
|
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
|
||||||
|
|
||||||
|
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get detailed information for a specific checkpoint by name"""
|
||||||
|
try:
|
||||||
|
name = request.match_info.get('name', '')
|
||||||
|
checkpoint_info = await self.service.get_model_info_by_name(name)
|
||||||
|
|
||||||
|
if checkpoint_info:
|
||||||
|
return web.json_response(checkpoint_info)
|
||||||
|
else:
|
||||||
|
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||||
|
try:
|
||||||
|
model_id = request.match_info['model_id']
|
||||||
|
response = await self.civitai_client.get_model_versions(model_id)
|
||||||
|
if not response or not response.get('modelVersions'):
|
||||||
|
return web.Response(status=404, text="Model not found")
|
||||||
|
|
||||||
|
versions = response.get('modelVersions', [])
|
||||||
|
model_type = response.get('type', '')
|
||||||
|
|
||||||
|
# Check model type - should be Checkpoint
|
||||||
|
if model_type.lower() != 'checkpoint':
|
||||||
|
return web.json_response({
|
||||||
|
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Check local availability for each version
|
||||||
|
for version in versions:
|
||||||
|
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||||
|
model_file = next((file for file in version.get('files', [])
|
||||||
|
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||||
|
|
||||||
|
# If no primary file found, try to find any model file
|
||||||
|
if not model_file:
|
||||||
|
model_file = next((file for file in version.get('files', [])
|
||||||
|
if file.get('type') == 'Model'), None)
|
||||||
|
|
||||||
|
if model_file:
|
||||||
|
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||||
|
if sha256:
|
||||||
|
# Set existsLocally and localPath at the version level
|
||||||
|
version['existsLocally'] = self.service.has_hash(sha256)
|
||||||
|
if version['existsLocally']:
|
||||||
|
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||||
|
|
||||||
|
# Also set the model file size at the version level for easier access
|
||||||
|
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||||
|
else:
|
||||||
|
# No model file found in this version
|
||||||
|
version['existsLocally'] = False
|
||||||
|
|
||||||
|
return web.json_response(versions)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching checkpoint model versions: {e}")
|
||||||
|
return web.Response(status=500, text=str(e))
|
||||||
@@ -1,771 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
import jinja2
|
|
||||||
from aiohttp import web
|
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
|
||||||
from ..utils.constants import NSFW_LEVELS
|
|
||||||
from ..utils.metadata_manager import MetadataManager
|
|
||||||
from ..services.websocket_manager import ws_manager
|
|
||||||
from ..services.service_registry import ServiceRegistry
|
|
||||||
from ..config import config
|
|
||||||
from ..services.settings_manager import settings
|
|
||||||
from ..utils.utils import fuzzy_match
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class CheckpointsRoutes:
|
|
||||||
"""API routes for checkpoint management"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.scanner = None # Will be initialized in setup_routes
|
|
||||||
self.template_env = jinja2.Environment(
|
|
||||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
|
||||||
autoescape=True
|
|
||||||
)
|
|
||||||
self.download_manager = None # Will be initialized in setup_routes
|
|
||||||
self._download_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
async def initialize_services(self):
|
|
||||||
"""Initialize services from ServiceRegistry"""
|
|
||||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
|
||||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
|
||||||
|
|
||||||
def setup_routes(self, app):
|
|
||||||
"""Register routes with the aiohttp app"""
|
|
||||||
# Schedule service initialization on app startup
|
|
||||||
app.on_startup.append(lambda _: self.initialize_services())
|
|
||||||
|
|
||||||
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
|
||||||
app.router.add_get('/api/checkpoints', self.get_checkpoints)
|
|
||||||
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
|
|
||||||
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
|
|
||||||
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
|
|
||||||
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
|
|
||||||
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
|
|
||||||
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
|
|
||||||
app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route
|
|
||||||
|
|
||||||
# Add new routes for model management similar to LoRA routes
|
|
||||||
app.router.add_post('/api/checkpoints/delete', self.delete_model)
|
|
||||||
app.router.add_post('/api/checkpoints/exclude', self.exclude_model) # Add new exclude endpoint
|
|
||||||
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
|
|
||||||
app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint
|
|
||||||
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
|
|
||||||
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
|
|
||||||
app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint
|
|
||||||
|
|
||||||
# Add new routes for finding duplicates and filename conflicts
|
|
||||||
app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints)
|
|
||||||
app.router.add_get('/api/checkpoints/find-filename-conflicts', self.find_filename_conflicts)
|
|
||||||
|
|
||||||
# Add new endpoint for bulk deleting checkpoints
|
|
||||||
app.router.add_post('/api/checkpoints/bulk-delete', self.bulk_delete_checkpoints)
|
|
||||||
|
|
||||||
# Add new endpoint for verifying duplicates
|
|
||||||
app.router.add_post('/api/checkpoints/verify-duplicates', self.verify_duplicates)
|
|
||||||
|
|
||||||
async def get_checkpoints(self, request):
|
|
||||||
"""Get paginated checkpoint data"""
|
|
||||||
try:
|
|
||||||
# Parse query parameters
|
|
||||||
page = int(request.query.get('page', '1'))
|
|
||||||
page_size = min(int(request.query.get('page_size', '20')), 100)
|
|
||||||
sort_by = request.query.get('sort_by', 'name')
|
|
||||||
folder = request.query.get('folder', None)
|
|
||||||
search = request.query.get('search', None)
|
|
||||||
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
|
||||||
base_models = request.query.getall('base_model', [])
|
|
||||||
tags = request.query.getall('tag', [])
|
|
||||||
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # Add favorites_only parameter
|
|
||||||
|
|
||||||
# Process search options
|
|
||||||
search_options = {
|
|
||||||
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
|
||||||
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
|
||||||
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
|
||||||
'recursive': request.query.get('recursive', 'false').lower() == 'true',
|
|
||||||
}
|
|
||||||
|
|
||||||
# Process hash filters if provided
|
|
||||||
hash_filters = {}
|
|
||||||
if 'hash' in request.query:
|
|
||||||
hash_filters['single_hash'] = request.query['hash']
|
|
||||||
elif 'hashes' in request.query:
|
|
||||||
try:
|
|
||||||
hash_list = json.loads(request.query['hashes'])
|
|
||||||
if isinstance(hash_list, list):
|
|
||||||
hash_filters['multiple_hashes'] = hash_list
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Get data from scanner
|
|
||||||
result = await self.get_paginated_data(
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
sort_by=sort_by,
|
|
||||||
folder=folder,
|
|
||||||
search=search,
|
|
||||||
fuzzy_search=fuzzy_search,
|
|
||||||
base_models=base_models,
|
|
||||||
tags=tags,
|
|
||||||
search_options=search_options,
|
|
||||||
hash_filters=hash_filters,
|
|
||||||
favorites_only=favorites_only # Pass favorites_only parameter
|
|
||||||
)
|
|
||||||
|
|
||||||
# Format response items
|
|
||||||
formatted_result = {
|
|
||||||
'items': [self._format_checkpoint_response(cp) for cp in result['items']],
|
|
||||||
'total': result['total'],
|
|
||||||
'page': result['page'],
|
|
||||||
'page_size': result['page_size'],
|
|
||||||
'total_pages': result['total_pages']
|
|
||||||
}
|
|
||||||
|
|
||||||
# Return as JSON
|
|
||||||
return web.json_response(formatted_result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
|
|
||||||
return web.json_response({"error": str(e)}, status=500)
|
|
||||||
|
|
||||||
async def get_paginated_data(self, page, page_size, sort_by='name',
|
|
||||||
folder=None, search=None, fuzzy_search=False,
|
|
||||||
base_models=None, tags=None,
|
|
||||||
search_options=None, hash_filters=None,
|
|
||||||
favorites_only=False): # Add favorites_only parameter with default False
|
|
||||||
"""Get paginated and filtered checkpoint data"""
|
|
||||||
cache = await self.scanner.get_cached_data()
|
|
||||||
|
|
||||||
# Get default search options if not provided
|
|
||||||
if search_options is None:
|
|
||||||
search_options = {
|
|
||||||
'filename': True,
|
|
||||||
'modelname': True,
|
|
||||||
'tags': False,
|
|
||||||
'recursive': False,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get the base data set
|
|
||||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
|
||||||
|
|
||||||
# Apply hash filtering if provided (highest priority)
|
|
||||||
if hash_filters:
|
|
||||||
single_hash = hash_filters.get('single_hash')
|
|
||||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
|
||||||
|
|
||||||
if single_hash:
|
|
||||||
# Filter by single hash
|
|
||||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
|
||||||
filtered_data = [
|
|
||||||
cp for cp in filtered_data
|
|
||||||
if cp.get('sha256', '').lower() == single_hash
|
|
||||||
]
|
|
||||||
elif multiple_hashes:
|
|
||||||
# Filter by multiple hashes
|
|
||||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
|
||||||
filtered_data = [
|
|
||||||
cp for cp in filtered_data
|
|
||||||
if cp.get('sha256', '').lower() in hash_set
|
|
||||||
]
|
|
||||||
|
|
||||||
# Jump to pagination
|
|
||||||
total_items = len(filtered_data)
|
|
||||||
start_idx = (page - 1) * page_size
|
|
||||||
end_idx = min(start_idx + page_size, total_items)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
'items': filtered_data[start_idx:end_idx],
|
|
||||||
'total': total_items,
|
|
||||||
'page': page,
|
|
||||||
'page_size': page_size,
|
|
||||||
'total_pages': (total_items + page_size - 1) // page_size
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
# Apply SFW filtering if enabled in settings
|
|
||||||
if settings.get('show_only_sfw', False):
|
|
||||||
filtered_data = [
|
|
||||||
cp for cp in filtered_data
|
|
||||||
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply favorites filtering if enabled
|
|
||||||
if favorites_only:
|
|
||||||
filtered_data = [
|
|
||||||
cp for cp in filtered_data
|
|
||||||
if cp.get('favorite', False) is True
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply folder filtering
|
|
||||||
if folder is not None:
|
|
||||||
if search_options.get('recursive', False):
|
|
||||||
# Recursive folder filtering - include all subfolders
|
|
||||||
filtered_data = [
|
|
||||||
cp for cp in filtered_data
|
|
||||||
if cp['folder'].startswith(folder)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
# Exact folder filtering
|
|
||||||
filtered_data = [
|
|
||||||
cp for cp in filtered_data
|
|
||||||
if cp['folder'] == folder
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply base model filtering
|
|
||||||
if base_models and len(base_models) > 0:
|
|
||||||
filtered_data = [
|
|
||||||
cp for cp in filtered_data
|
|
||||||
if cp.get('base_model') in base_models
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply tag filtering
|
|
||||||
if tags and len(tags) > 0:
|
|
||||||
filtered_data = [
|
|
||||||
cp for cp in filtered_data
|
|
||||||
if any(tag in cp.get('tags', []) for tag in tags)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply search filtering
|
|
||||||
if search:
|
|
||||||
search_results = []
|
|
||||||
|
|
||||||
for cp in filtered_data:
|
|
||||||
# Search by file name
|
|
||||||
if search_options.get('filename', True):
|
|
||||||
if fuzzy_search:
|
|
||||||
if fuzzy_match(cp.get('file_name', ''), search):
|
|
||||||
search_results.append(cp)
|
|
||||||
continue
|
|
||||||
elif search.lower() in cp.get('file_name', '').lower():
|
|
||||||
search_results.append(cp)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by model name
|
|
||||||
if search_options.get('modelname', True):
|
|
||||||
if fuzzy_search:
|
|
||||||
if fuzzy_match(cp.get('model_name', ''), search):
|
|
||||||
search_results.append(cp)
|
|
||||||
continue
|
|
||||||
elif search.lower() in cp.get('model_name', '').lower():
|
|
||||||
search_results.append(cp)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by tags
|
|
||||||
if search_options.get('tags', False) and 'tags' in cp:
|
|
||||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']):
|
|
||||||
search_results.append(cp)
|
|
||||||
continue
|
|
||||||
|
|
||||||
filtered_data = search_results
|
|
||||||
|
|
||||||
# Calculate pagination
|
|
||||||
total_items = len(filtered_data)
|
|
||||||
start_idx = (page - 1) * page_size
|
|
||||||
end_idx = min(start_idx + page_size, total_items)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
'items': filtered_data[start_idx:end_idx],
|
|
||||||
'total': total_items,
|
|
||||||
'page': page,
|
|
||||||
'page_size': page_size,
|
|
||||||
'total_pages': (total_items + page_size - 1) // page_size
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _format_checkpoint_response(self, checkpoint):
|
|
||||||
"""Format checkpoint data for API response"""
|
|
||||||
return {
|
|
||||||
"model_name": checkpoint["model_name"],
|
|
||||||
"file_name": checkpoint["file_name"],
|
|
||||||
"preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")),
|
|
||||||
"preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0),
|
|
||||||
"base_model": checkpoint.get("base_model", ""),
|
|
||||||
"folder": checkpoint["folder"],
|
|
||||||
"sha256": checkpoint.get("sha256", ""),
|
|
||||||
"file_path": checkpoint["file_path"].replace(os.sep, "/"),
|
|
||||||
"file_size": checkpoint.get("size", 0),
|
|
||||||
"modified": checkpoint.get("modified", ""),
|
|
||||||
"tags": checkpoint.get("tags", []),
|
|
||||||
"modelDescription": checkpoint.get("modelDescription", ""),
|
|
||||||
"from_civitai": checkpoint.get("from_civitai", True),
|
|
||||||
"notes": checkpoint.get("notes", ""),
|
|
||||||
"model_type": checkpoint.get("model_type", "checkpoint"),
|
|
||||||
"favorite": checkpoint.get("favorite", False),
|
|
||||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
|
|
||||||
}
|
|
||||||
|
|
||||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
|
||||||
"""Fetch CivitAI metadata for all checkpoints in the background"""
|
|
||||||
try:
|
|
||||||
cache = await self.scanner.get_cached_data()
|
|
||||||
total = len(cache.raw_data)
|
|
||||||
processed = 0
|
|
||||||
success = 0
|
|
||||||
needs_resort = False
|
|
||||||
|
|
||||||
# Prepare checkpoints to process
|
|
||||||
to_process = [
|
|
||||||
cp for cp in cache.raw_data
|
|
||||||
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
|
|
||||||
]
|
|
||||||
total_to_process = len(to_process)
|
|
||||||
|
|
||||||
# Send initial progress
|
|
||||||
await ws_manager.broadcast({
|
|
||||||
'status': 'started',
|
|
||||||
'total': total_to_process,
|
|
||||||
'processed': 0,
|
|
||||||
'success': 0
|
|
||||||
})
|
|
||||||
|
|
||||||
# Process each checkpoint
|
|
||||||
for cp in to_process:
|
|
||||||
try:
|
|
||||||
original_name = cp.get('model_name')
|
|
||||||
if await ModelRouteUtils.fetch_and_update_model(
|
|
||||||
sha256=cp['sha256'],
|
|
||||||
file_path=cp['file_path'],
|
|
||||||
model_data=cp,
|
|
||||||
update_cache_func=self.scanner.update_single_model_cache
|
|
||||||
):
|
|
||||||
success += 1
|
|
||||||
if original_name != cp.get('model_name'):
|
|
||||||
needs_resort = True
|
|
||||||
|
|
||||||
processed += 1
|
|
||||||
|
|
||||||
# Send progress update
|
|
||||||
await ws_manager.broadcast({
|
|
||||||
'status': 'processing',
|
|
||||||
'total': total_to_process,
|
|
||||||
'processed': processed,
|
|
||||||
'success': success,
|
|
||||||
'current_name': cp.get('model_name', 'Unknown')
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
|
|
||||||
|
|
||||||
if needs_resort:
|
|
||||||
await cache.resort(name_only=True)
|
|
||||||
|
|
||||||
# Send completion message
|
|
||||||
await ws_manager.broadcast({
|
|
||||||
'status': 'completed',
|
|
||||||
'total': total_to_process,
|
|
||||||
'processed': processed,
|
|
||||||
'success': success
|
|
||||||
})
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
"success": True,
|
|
||||||
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Send error message
|
|
||||||
await ws_manager.broadcast({
|
|
||||||
'status': 'error',
|
|
||||||
'error': str(e)
|
|
||||||
})
|
|
||||||
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
|
|
||||||
return web.Response(text=str(e), status=500)
|
|
||||||
|
|
||||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle request for top tags sorted by frequency"""
|
|
||||||
try:
|
|
||||||
# Parse query parameters
|
|
||||||
limit = int(request.query.get('limit', '20'))
|
|
||||||
|
|
||||||
# Validate limit
|
|
||||||
if limit < 1 or limit > 100:
|
|
||||||
limit = 20 # Default to a reasonable limit
|
|
||||||
|
|
||||||
# Get top tags
|
|
||||||
top_tags = await self.scanner.get_top_tags(limit)
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
'success': True,
|
|
||||||
'tags': top_tags
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': 'Internal server error'
|
|
||||||
}, status=500)
|
|
||||||
|
|
||||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
|
||||||
"""Get base models used in loras"""
|
|
||||||
try:
|
|
||||||
# Parse query parameters
|
|
||||||
limit = int(request.query.get('limit', '20'))
|
|
||||||
|
|
||||||
# Validate limit
|
|
||||||
if limit < 1 or limit > 100:
|
|
||||||
limit = 20 # Default to a reasonable limit
|
|
||||||
|
|
||||||
# Get base models
|
|
||||||
base_models = await self.scanner.get_base_models(limit)
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
'success': True,
|
|
||||||
'base_models': base_models
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving base models: {e}")
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': str(e)
|
|
||||||
}, status=500)
|
|
||||||
|
|
||||||
async def scan_checkpoints(self, request):
|
|
||||||
"""Force a rescan of checkpoint files"""
|
|
||||||
try:
|
|
||||||
# Get the full_rebuild parameter and convert to bool, default to False
|
|
||||||
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
|
|
||||||
|
|
||||||
await self.scanner.get_cached_data(force_refresh=True, rebuild_cache=full_rebuild)
|
|
||||||
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
|
|
||||||
return web.json_response({"error": str(e)}, status=500)
|
|
||||||
|
|
||||||
async def get_checkpoint_info(self, request):
|
|
||||||
"""Get detailed information for a specific checkpoint by name"""
|
|
||||||
try:
|
|
||||||
name = request.match_info.get('name', '')
|
|
||||||
checkpoint_info = await self.scanner.get_model_info_by_name(name)
|
|
||||||
|
|
||||||
if checkpoint_info:
|
|
||||||
return web.json_response(checkpoint_info)
|
|
||||||
else:
|
|
||||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
|
||||||
return web.json_response({"error": str(e)}, status=500)
|
|
||||||
|
|
||||||
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle GET /checkpoints request"""
|
|
||||||
try:
|
|
||||||
# Check if the CheckpointScanner is initializing
|
|
||||||
# It's initializing if the cache object doesn't exist yet,
|
|
||||||
# OR if the scanner explicitly says it's initializing (background task running).
|
|
||||||
is_initializing = (
|
|
||||||
self.scanner._cache is None or
|
|
||||||
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_initializing:
|
|
||||||
# If still initializing, return loading page
|
|
||||||
template = self.template_env.get_template('checkpoints.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=[], # 空文件夹列表
|
|
||||||
is_initializing=True, # 新增标志
|
|
||||||
settings=settings, # Pass settings to template
|
|
||||||
request=request # Pass the request object to the template
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Checkpoints page is initializing, returning loading page")
|
|
||||||
else:
|
|
||||||
# 正常流程 - 获取已经初始化好的缓存数据
|
|
||||||
try:
|
|
||||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
|
||||||
template = self.template_env.get_template('checkpoints.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=cache.folders,
|
|
||||||
is_initializing=False,
|
|
||||||
settings=settings, # Pass settings to template
|
|
||||||
request=request # Pass the request object to the template
|
|
||||||
)
|
|
||||||
except Exception as cache_error:
|
|
||||||
logger.error(f"Error loading checkpoints cache data: {cache_error}")
|
|
||||||
# 如果获取缓存失败,也显示初始化页面
|
|
||||||
template = self.template_env.get_template('checkpoints.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=[],
|
|
||||||
is_initializing=True,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
logger.info("Checkpoints cache error, returning initialization page")
|
|
||||||
|
|
||||||
return web.Response(
|
|
||||||
text=rendered,
|
|
||||||
content_type='text/html'
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
|
|
||||||
return web.Response(
|
|
||||||
text="Error loading checkpoints page",
|
|
||||||
status=500
|
|
||||||
)
|
|
||||||
|
|
||||||
async def delete_model(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle checkpoint model deletion request"""
|
|
||||||
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
|
||||||
|
|
||||||
async def exclude_model(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle checkpoint model exclusion request"""
|
|
||||||
return await ModelRouteUtils.handle_exclude_model(request, self.scanner)
|
|
||||||
|
|
||||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle CivitAI metadata fetch request for checkpoints"""
|
|
||||||
response = await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
|
|
||||||
|
|
||||||
# If successful, format the metadata before returning
|
|
||||||
if response.status == 200:
|
|
||||||
data = json.loads(response.body.decode('utf-8'))
|
|
||||||
if data.get("success") and data.get("metadata"):
|
|
||||||
formatted_metadata = self._format_checkpoint_response(data["metadata"])
|
|
||||||
return web.json_response({
|
|
||||||
"success": True,
|
|
||||||
"metadata": formatted_metadata
|
|
||||||
})
|
|
||||||
|
|
||||||
# Otherwise, return the original response
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle preview image replacement for checkpoints"""
|
|
||||||
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
|
||||||
|
|
||||||
async def get_checkpoint_roots(self, request):
|
|
||||||
"""Return the checkpoint root directories"""
|
|
||||||
try:
|
|
||||||
if self.scanner is None:
|
|
||||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
|
||||||
|
|
||||||
roots = self.scanner.get_model_roots()
|
|
||||||
return web.json_response({
|
|
||||||
"success": True,
|
|
||||||
"roots": roots
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
|
|
||||||
return web.json_response({
|
|
||||||
"success": False,
|
|
||||||
"error": str(e)
|
|
||||||
}, status=500)
|
|
||||||
|
|
||||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle saving metadata updates for checkpoints"""
|
|
||||||
try:
|
|
||||||
if self.scanner is None:
|
|
||||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
|
||||||
|
|
||||||
data = await request.json()
|
|
||||||
file_path = data.get('file_path')
|
|
||||||
if not file_path:
|
|
||||||
return web.Response(text='File path is required', status=400)
|
|
||||||
|
|
||||||
# Remove file path from data to avoid saving it
|
|
||||||
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
|
||||||
|
|
||||||
# Get metadata file path
|
|
||||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
|
||||||
|
|
||||||
# Load existing metadata
|
|
||||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
|
||||||
|
|
||||||
# Update metadata
|
|
||||||
metadata.update(metadata_updates)
|
|
||||||
|
|
||||||
# Save updated metadata
|
|
||||||
await MetadataManager.save_metadata(file_path, metadata)
|
|
||||||
|
|
||||||
# Update cache
|
|
||||||
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
|
|
||||||
|
|
||||||
# If model_name was updated, resort the cache
|
|
||||||
if 'model_name' in metadata_updates:
|
|
||||||
cache = await self.scanner.get_cached_data()
|
|
||||||
await cache.resort(name_only=True)
|
|
||||||
|
|
||||||
return web.json_response({'success': True})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True)
|
|
||||||
return web.Response(text=str(e), status=500)
|
|
||||||
|
|
||||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
|
||||||
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
|
||||||
try:
|
|
||||||
if self.scanner is None:
|
|
||||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
|
||||||
|
|
||||||
# Get the civitai client from service registry
|
|
||||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
|
||||||
|
|
||||||
model_id = request.match_info['model_id']
|
|
||||||
response = await civitai_client.get_model_versions(model_id)
|
|
||||||
if not response or not response.get('modelVersions'):
|
|
||||||
return web.Response(status=404, text="Model not found")
|
|
||||||
|
|
||||||
versions = response.get('modelVersions', [])
|
|
||||||
model_type = response.get('type', '')
|
|
||||||
|
|
||||||
# Check model type - should be Checkpoint
|
|
||||||
if (model_type.lower() != 'checkpoint'):
|
|
||||||
return web.json_response({
|
|
||||||
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
|
||||||
}, status=400)
|
|
||||||
|
|
||||||
# Check local availability for each version
|
|
||||||
for version in versions:
|
|
||||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
|
||||||
model_file = next((file for file in version.get('files', [])
|
|
||||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
|
||||||
|
|
||||||
# If no primary file found, try to find any model file
|
|
||||||
if not model_file:
|
|
||||||
model_file = next((file for file in version.get('files', [])
|
|
||||||
if file.get('type') == 'Model'), None)
|
|
||||||
|
|
||||||
if model_file:
|
|
||||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
|
||||||
if sha256:
|
|
||||||
# Set existsLocally and localPath at the version level
|
|
||||||
version['existsLocally'] = self.scanner.has_hash(sha256)
|
|
||||||
if version['existsLocally']:
|
|
||||||
version['localPath'] = self.scanner.get_path_by_hash(sha256)
|
|
||||||
|
|
||||||
# Also set the model file size at the version level for easier access
|
|
||||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
|
||||||
else:
|
|
||||||
# No model file found in this version
|
|
||||||
version['existsLocally'] = False
|
|
||||||
|
|
||||||
return web.json_response(versions)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching checkpoint model versions: {e}")
|
|
||||||
return web.Response(status=500, text=str(e))
|
|
||||||
|
|
||||||
async def find_duplicate_checkpoints(self, request: web.Request) -> web.Response:
|
|
||||||
"""Find checkpoints with duplicate SHA256 hashes"""
|
|
||||||
try:
|
|
||||||
if self.scanner is None:
|
|
||||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
|
||||||
|
|
||||||
# Get duplicate hashes from hash index
|
|
||||||
duplicates = self.scanner._hash_index.get_duplicate_hashes()
|
|
||||||
|
|
||||||
# Format the response
|
|
||||||
result = []
|
|
||||||
cache = await self.scanner.get_cached_data()
|
|
||||||
|
|
||||||
for sha256, paths in duplicates.items():
|
|
||||||
group = {
|
|
||||||
"hash": sha256,
|
|
||||||
"models": []
|
|
||||||
}
|
|
||||||
# Find matching models for each path
|
|
||||||
for path in paths:
|
|
||||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
|
||||||
if model:
|
|
||||||
group["models"].append(self._format_checkpoint_response(model))
|
|
||||||
|
|
||||||
# Add the primary model too
|
|
||||||
primary_path = self.scanner._hash_index.get_path(sha256)
|
|
||||||
if primary_path and primary_path not in paths:
|
|
||||||
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
|
|
||||||
if primary_model:
|
|
||||||
group["models"].insert(0, self._format_checkpoint_response(primary_model))
|
|
||||||
|
|
||||||
if len(group["models"]) > 1: # Only include if we found multiple models
|
|
||||||
result.append(group)
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
"success": True,
|
|
||||||
"duplicates": result,
|
|
||||||
"count": len(result)
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error finding duplicate checkpoints: {e}", exc_info=True)
|
|
||||||
return web.json_response({
|
|
||||||
"success": False,
|
|
||||||
"error": str(e)
|
|
||||||
}, status=500)
|
|
||||||
|
|
||||||
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
|
|
||||||
"""Find checkpoints with conflicting filenames"""
|
|
||||||
try:
|
|
||||||
if self.scanner is None:
|
|
||||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
|
||||||
|
|
||||||
# Get duplicate filenames from hash index
|
|
||||||
duplicates = self.scanner._hash_index.get_duplicate_filenames()
|
|
||||||
|
|
||||||
# Format the response
|
|
||||||
result = []
|
|
||||||
cache = await self.scanner.get_cached_data()
|
|
||||||
|
|
||||||
for filename, paths in duplicates.items():
|
|
||||||
group = {
|
|
||||||
"filename": filename,
|
|
||||||
"models": []
|
|
||||||
}
|
|
||||||
# Find matching models for each path
|
|
||||||
for path in paths:
|
|
||||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
|
||||||
if model:
|
|
||||||
group["models"].append(self._format_checkpoint_response(model))
|
|
||||||
|
|
||||||
# Find the model from the main index too
|
|
||||||
hash_val = self.scanner._hash_index.get_hash_by_filename(filename)
|
|
||||||
if hash_val:
|
|
||||||
main_path = self.scanner._hash_index.get_path(hash_val)
|
|
||||||
if main_path and main_path not in paths:
|
|
||||||
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
|
|
||||||
if main_model:
|
|
||||||
group["models"].insert(0, self._format_checkpoint_response(main_model))
|
|
||||||
|
|
||||||
if group["models"]:
|
|
||||||
result.append(group)
|
|
||||||
|
|
||||||
return web.json_response({
|
|
||||||
"success": True,
|
|
||||||
"conflicts": result,
|
|
||||||
"count": len(result)
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error finding filename conflicts: {e}", exc_info=True)
|
|
||||||
return web.json_response({
|
|
||||||
"success": False,
|
|
||||||
"error": str(e)
|
|
||||||
}, status=500)
|
|
||||||
|
|
||||||
async def bulk_delete_checkpoints(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle bulk deletion of checkpoint models"""
|
|
||||||
try:
|
|
||||||
if self.scanner is None:
|
|
||||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
|
||||||
|
|
||||||
return await ModelRouteUtils.handle_bulk_delete_models(request, self.scanner)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in bulk delete checkpoints: {e}", exc_info=True)
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': str(e)
|
|
||||||
}, status=500)
|
|
||||||
|
|
||||||
async def relink_civitai(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle CivitAI metadata re-linking request by model version ID for checkpoints"""
|
|
||||||
return await ModelRouteUtils.handle_relink_civitai(request, self.scanner)
|
|
||||||
|
|
||||||
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle verification of duplicate checkpoint hashes"""
|
|
||||||
return await ModelRouteUtils.handle_verify_duplicates(request, self.scanner)
|
|
||||||
|
|
||||||
async def rename_checkpoint(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle renaming a checkpoint file and its associated files"""
|
|
||||||
return await ModelRouteUtils.handle_rename_model(request, self.scanner)
|
|
||||||
@@ -1,188 +1,512 @@
|
|||||||
import os
|
import asyncio
|
||||||
from aiohttp import web
|
|
||||||
import jinja2
|
|
||||||
from typing import Dict
|
|
||||||
import logging
|
import logging
|
||||||
from ..config import config
|
from aiohttp import web
|
||||||
from ..services.settings_manager import settings
|
from typing import Dict
|
||||||
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
from server import PromptServer # type: ignore
|
||||||
|
|
||||||
|
from .base_model_routes import BaseModelRoutes
|
||||||
|
from ..services.lora_service import LoraService
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
|
from ..utils.utils import get_lora_info
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
|
|
||||||
|
|
||||||
class LoraRoutes:
|
class LoraRoutes(BaseModelRoutes):
|
||||||
"""Route handlers for LoRA management endpoints"""
|
"""LoRA-specific route controller"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Initialize service references as None, will be set during async init
|
"""Initialize LoRA routes with LoRA service"""
|
||||||
self.scanner = None
|
# Service will be initialized later via setup_routes
|
||||||
self.recipe_scanner = None
|
self.service = None
|
||||||
self.template_env = jinja2.Environment(
|
self.civitai_client = None
|
||||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
self.template_name = "loras.html"
|
||||||
autoescape=True
|
|
||||||
)
|
|
||||||
|
|
||||||
async def init_services(self):
|
async def initialize_services(self):
|
||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
self.service = LoraService(lora_scanner)
|
||||||
|
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||||
|
|
||||||
def format_lora_data(self, lora: Dict) -> Dict:
|
# Initialize parent with the service
|
||||||
"""Format LoRA data for template rendering"""
|
super().__init__(self.service)
|
||||||
return {
|
|
||||||
"model_name": lora["model_name"],
|
|
||||||
"file_name": lora["file_name"],
|
|
||||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
|
||||||
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
|
|
||||||
"base_model": lora["base_model"],
|
|
||||||
"folder": lora["folder"],
|
|
||||||
"sha256": lora["sha256"],
|
|
||||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
|
||||||
"size": lora["size"],
|
|
||||||
"tags": lora["tags"],
|
|
||||||
"modelDescription": lora["modelDescription"],
|
|
||||||
"usage_tips": lora["usage_tips"],
|
|
||||||
"notes": lora["notes"],
|
|
||||||
"modified": lora["modified"],
|
|
||||||
"from_civitai": lora.get("from_civitai", True),
|
|
||||||
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
|
|
||||||
}
|
|
||||||
|
|
||||||
def _filter_civitai_data(self, data: Dict) -> Dict:
|
|
||||||
"""Filter relevant fields from CivitAI data"""
|
|
||||||
if not data:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
fields = [
|
|
||||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
|
||||||
"publishedAt", "trainedWords", "baseModel", "description",
|
|
||||||
"model", "images"
|
|
||||||
]
|
|
||||||
return {k: data[k] for k in fields if k in data}
|
|
||||||
|
|
||||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle GET /loras request"""
|
|
||||||
try:
|
|
||||||
# Ensure services are initialized
|
|
||||||
await self.init_services()
|
|
||||||
|
|
||||||
# Check if the LoraScanner is initializing
|
|
||||||
# It's initializing if the cache object doesn't exist yet,
|
|
||||||
# OR if the scanner explicitly says it's initializing (background task running).
|
|
||||||
is_initializing = (
|
|
||||||
self.scanner._cache is None or self.scanner.is_initializing()
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_initializing:
|
|
||||||
# If still initializing, return loading page
|
|
||||||
template = self.template_env.get_template('loras.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=[],
|
|
||||||
is_initializing=True,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Loras page is initializing, returning loading page")
|
|
||||||
else:
|
|
||||||
# Normal flow - get data from initialized cache
|
|
||||||
try:
|
|
||||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
|
||||||
template = self.template_env.get_template('loras.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=cache.folders,
|
|
||||||
is_initializing=False,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
except Exception as cache_error:
|
|
||||||
logger.error(f"Error loading cache data: {cache_error}")
|
|
||||||
template = self.template_env.get_template('loras.html')
|
|
||||||
rendered = template.render(
|
|
||||||
folders=[],
|
|
||||||
is_initializing=True,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
logger.info("Cache error, returning initialization page")
|
|
||||||
|
|
||||||
return web.Response(
|
|
||||||
text=rendered,
|
|
||||||
content_type='text/html'
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error handling loras request: {e}", exc_info=True)
|
|
||||||
return web.Response(
|
|
||||||
text="Error loading loras page",
|
|
||||||
status=500
|
|
||||||
)
|
|
||||||
|
|
||||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle GET /loras/recipes request"""
|
|
||||||
try:
|
|
||||||
# Ensure services are initialized
|
|
||||||
await self.init_services()
|
|
||||||
|
|
||||||
# Skip initialization check and directly try to get cached data
|
|
||||||
try:
|
|
||||||
# Recipe scanner will initialize cache if needed
|
|
||||||
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
|
||||||
template = self.template_env.get_template('recipes.html')
|
|
||||||
rendered = template.render(
|
|
||||||
recipes=[], # Frontend will load recipes via API
|
|
||||||
is_initializing=False,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
except Exception as cache_error:
|
|
||||||
logger.error(f"Error loading recipe cache data: {cache_error}")
|
|
||||||
# Still keep error handling - show initializing page on error
|
|
||||||
template = self.template_env.get_template('recipes.html')
|
|
||||||
rendered = template.render(
|
|
||||||
is_initializing=True,
|
|
||||||
settings=settings,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
logger.info("Recipe cache error, returning initialization page")
|
|
||||||
|
|
||||||
return web.Response(
|
|
||||||
text=rendered,
|
|
||||||
content_type='text/html'
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
|
||||||
return web.Response(
|
|
||||||
text="Error loading recipes page",
|
|
||||||
status=500
|
|
||||||
)
|
|
||||||
|
|
||||||
def _format_recipe_file_url(self, file_path: str) -> str:
|
|
||||||
"""Format file path for recipe image as a URL - same as in recipe_routes"""
|
|
||||||
try:
|
|
||||||
# Return the file URL directly for the first lora root's preview
|
|
||||||
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
|
|
||||||
if file_path.replace(os.sep, '/').startswith(recipes_dir):
|
|
||||||
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
|
|
||||||
return f"/loras_static/root1/preview/{relative_path}"
|
|
||||||
|
|
||||||
# If not in recipes dir, try to create a valid URL from the file path
|
|
||||||
file_name = os.path.basename(file_path)
|
|
||||||
return f"/loras_static/root1/preview/recipes/{file_name}"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error formatting recipe file URL: {e}", exc_info=True)
|
|
||||||
return '/loras_static/images/no-preview.png' # Return default image on error
|
|
||||||
|
|
||||||
def setup_routes(self, app: web.Application):
|
def setup_routes(self, app: web.Application):
|
||||||
"""Register routes with the application"""
|
"""Setup LoRA routes"""
|
||||||
# Add an app startup handler to initialize services
|
# Schedule service initialization on app startup
|
||||||
app.on_startup.append(self._on_startup)
|
app.on_startup.append(lambda _: self.initialize_services())
|
||||||
|
|
||||||
# Register routes
|
# Setup common routes with 'loras' prefix (includes page route)
|
||||||
app.router.add_get('/loras', self.handle_loras_page)
|
super().setup_routes(app, 'loras')
|
||||||
app.router.add_get('/loras/recipes', self.handle_recipes_page)
|
|
||||||
|
|
||||||
async def _on_startup(self, app):
|
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||||
"""Initialize services when the app starts"""
|
"""Setup LoRA-specific routes"""
|
||||||
await self.init_services()
|
# LoRA-specific query routes
|
||||||
|
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
|
||||||
|
app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
|
||||||
|
app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
|
||||||
|
app.router.add_get(f'/api/lora-preview-url', self.get_lora_preview_url)
|
||||||
|
app.router.add_get(f'/api/lora-civitai-url', self.get_lora_civitai_url)
|
||||||
|
app.router.add_get(f'/api/lora-model-description', self.get_lora_model_description)
|
||||||
|
|
||||||
|
# LoRA-specific management routes
|
||||||
|
app.router.add_post(f'/api/move_model', self.move_model)
|
||||||
|
app.router.add_post(f'/api/move_models_bulk', self.move_models_bulk)
|
||||||
|
|
||||||
|
# CivitAI integration with LoRA-specific validation
|
||||||
|
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora)
|
||||||
|
app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
|
||||||
|
app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
|
||||||
|
|
||||||
|
# ComfyUI integration
|
||||||
|
app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words)
|
||||||
|
|
||||||
|
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||||
|
"""Parse LoRA-specific parameters"""
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
# LoRA-specific parameters
|
||||||
|
if 'first_letter' in request.query:
|
||||||
|
params['first_letter'] = request.query.get('first_letter')
|
||||||
|
|
||||||
|
# Handle fuzzy search parameter name variation
|
||||||
|
if request.query.get('fuzzy') == 'true':
|
||||||
|
params['fuzzy_search'] = True
|
||||||
|
|
||||||
|
# Handle additional filter parameters for LoRAs
|
||||||
|
if 'lora_hash' in request.query:
|
||||||
|
if not params.get('hash_filters'):
|
||||||
|
params['hash_filters'] = {}
|
||||||
|
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
|
||||||
|
elif 'lora_hashes' in request.query:
|
||||||
|
if not params.get('hash_filters'):
|
||||||
|
params['hash_filters'] = {}
|
||||||
|
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
# LoRA-specific route handlers
|
||||||
|
async def get_letter_counts(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get count of LoRAs for each letter of the alphabet"""
|
||||||
|
try:
|
||||||
|
letter_counts = await self.service.get_letter_counts()
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'letter_counts': letter_counts
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting letter counts: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_lora_notes(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get notes for a specific LoRA file"""
|
||||||
|
try:
|
||||||
|
lora_name = request.query.get('name')
|
||||||
|
if not lora_name:
|
||||||
|
return web.Response(text='Lora file name is required', status=400)
|
||||||
|
|
||||||
|
notes = await self.service.get_lora_notes(lora_name)
|
||||||
|
if notes is not None:
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'notes': notes
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'LoRA not found in cache'
|
||||||
|
}, status=404)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting lora notes: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get trigger words for a specific LoRA file"""
|
||||||
|
try:
|
||||||
|
lora_name = request.query.get('name')
|
||||||
|
if not lora_name:
|
||||||
|
return web.Response(text='Lora file name is required', status=400)
|
||||||
|
|
||||||
|
trigger_words = await self.service.get_lora_trigger_words(lora_name)
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'trigger_words': trigger_words
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get the static preview URL for a LoRA file"""
|
||||||
|
try:
|
||||||
|
lora_name = request.query.get('name')
|
||||||
|
if not lora_name:
|
||||||
|
return web.Response(text='Lora file name is required', status=400)
|
||||||
|
|
||||||
|
preview_url = await self.service.get_lora_preview_url(lora_name)
|
||||||
|
if preview_url:
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'preview_url': preview_url
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'No preview URL found for the specified lora'
|
||||||
|
}, status=404)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get the Civitai URL for a LoRA file"""
|
||||||
|
try:
|
||||||
|
lora_name = request.query.get('name')
|
||||||
|
if not lora_name:
|
||||||
|
return web.Response(text='Lora file name is required', status=400)
|
||||||
|
|
||||||
|
result = await self.service.get_lora_civitai_url(lora_name)
|
||||||
|
if result['civitai_url']:
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
**result
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'No Civitai data found for the specified lora'
|
||||||
|
}, status=404)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
# Override get_models to add LoRA-specific response data
|
||||||
|
async def get_models(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get paginated LoRA data with LoRA-specific fields"""
|
||||||
|
try:
|
||||||
|
# Parse common query parameters
|
||||||
|
params = self._parse_common_params(request)
|
||||||
|
|
||||||
|
# Get data from service
|
||||||
|
result = await self.service.get_paginated_data(**params)
|
||||||
|
|
||||||
|
# Get all available folders from cache for LoRA-specific response
|
||||||
|
cache = await self.service.scanner.get_cached_data()
|
||||||
|
|
||||||
|
# Format response items with LoRA-specific structure
|
||||||
|
formatted_result = {
|
||||||
|
'items': [await self.service.format_response(item) for item in result['items']],
|
||||||
|
'folders': cache.folders, # LoRA-specific: include folders in response
|
||||||
|
'total': result['total'],
|
||||||
|
'page': result['page'],
|
||||||
|
'page_size': result['page_size'],
|
||||||
|
'total_pages': result['total_pages']
|
||||||
|
}
|
||||||
|
|
||||||
|
return web.json_response(formatted_result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_loras: {e}", exc_info=True)
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
# CivitAI integration methods
|
||||||
|
async def get_civitai_versions_lora(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get available versions for a Civitai LoRA model with local availability info"""
|
||||||
|
try:
|
||||||
|
model_id = request.match_info['model_id']
|
||||||
|
response = await self.civitai_client.get_model_versions(model_id)
|
||||||
|
if not response or not response.get('modelVersions'):
|
||||||
|
return web.Response(status=404, text="Model not found")
|
||||||
|
|
||||||
|
versions = response.get('modelVersions', [])
|
||||||
|
model_type = response.get('type', '')
|
||||||
|
|
||||||
|
# Check model type - should be LORA, LoCon, or DORA
|
||||||
|
from ..utils.constants import VALID_LORA_TYPES
|
||||||
|
if model_type.lower() not in VALID_LORA_TYPES:
|
||||||
|
return web.json_response({
|
||||||
|
'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}"
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Check local availability for each version
|
||||||
|
for version in versions:
|
||||||
|
# Find the model file (type="Model") in the files list
|
||||||
|
model_file = next((file for file in version.get('files', [])
|
||||||
|
if file.get('type') == 'Model'), None)
|
||||||
|
|
||||||
|
if model_file:
|
||||||
|
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||||
|
if sha256:
|
||||||
|
# Set existsLocally and localPath at the version level
|
||||||
|
version['existsLocally'] = self.service.has_hash(sha256)
|
||||||
|
if version['existsLocally']:
|
||||||
|
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||||
|
|
||||||
|
# Also set the model file size at the version level for easier access
|
||||||
|
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||||
|
else:
|
||||||
|
# No model file found in this version
|
||||||
|
version['existsLocally'] = False
|
||||||
|
|
||||||
|
return web.json_response(versions)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching LoRA model versions: {e}")
|
||||||
|
return web.Response(status=500, text=str(e))
|
||||||
|
|
||||||
|
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get CivitAI model details by model version ID"""
|
||||||
|
try:
|
||||||
|
model_version_id = request.match_info.get('modelVersionId')
|
||||||
|
|
||||||
|
# Get model details from Civitai API
|
||||||
|
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
# Log warning for failed model retrieval
|
||||||
|
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
|
||||||
|
|
||||||
|
# Determine status code based on error message
|
||||||
|
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": error_msg or "Failed to fetch model information"
|
||||||
|
}, status=status_code)
|
||||||
|
|
||||||
|
return web.json_response(model)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching model details: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get CivitAI model details by hash"""
|
||||||
|
try:
|
||||||
|
hash = request.match_info.get('hash')
|
||||||
|
model = await self.civitai_client.get_model_by_hash(hash)
|
||||||
|
return web.json_response(model)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching model details by hash: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
# Model management methods
|
||||||
|
async def move_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model move request"""
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_path = data.get('file_path') # full path of the model file
|
||||||
|
target_path = data.get('target_path') # folder path to move the model to
|
||||||
|
|
||||||
|
if not file_path or not target_path:
|
||||||
|
return web.Response(text='File path and target path are required', status=400)
|
||||||
|
|
||||||
|
# Check if source and destination are the same
|
||||||
|
import os
|
||||||
|
source_dir = os.path.dirname(file_path)
|
||||||
|
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||||
|
logger.info(f"Source and target directories are the same: {source_dir}")
|
||||||
|
return web.json_response({'success': True, 'message': 'Source and target directories are the same'})
|
||||||
|
|
||||||
|
# Check if target file already exists
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
|
||||||
|
|
||||||
|
if os.path.exists(target_file_path):
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': f"Target file already exists: {target_file_path}"
|
||||||
|
}, status=409) # 409 Conflict
|
||||||
|
|
||||||
|
# Call scanner to handle the move operation
|
||||||
|
success = await self.service.scanner.move_model(file_path, target_path)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return web.json_response({'success': True})
|
||||||
|
else:
|
||||||
|
return web.Response(text='Failed to move model', status=500)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error moving model: {e}", exc_info=True)
|
||||||
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
|
async def move_models_bulk(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle bulk model move request"""
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_paths = data.get('file_paths', []) # list of full paths of the model files
|
||||||
|
target_path = data.get('target_path') # folder path to move the models to
|
||||||
|
|
||||||
|
if not file_paths or not target_path:
|
||||||
|
return web.Response(text='File paths and target path are required', status=400)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
import os
|
||||||
|
for file_path in file_paths:
|
||||||
|
# Check if source and destination are the same
|
||||||
|
source_dir = os.path.dirname(file_path)
|
||||||
|
if os.path.normpath(source_dir) == os.path.normpath(target_path):
|
||||||
|
results.append({
|
||||||
|
"path": file_path,
|
||||||
|
"success": True,
|
||||||
|
"message": "Source and target directories are the same"
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if target file already exists
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
|
||||||
|
|
||||||
|
if os.path.exists(target_file_path):
|
||||||
|
results.append({
|
||||||
|
"path": file_path,
|
||||||
|
"success": False,
|
||||||
|
"message": f"Target file already exists: {target_file_path}"
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try to move the model
|
||||||
|
success = await self.service.scanner.move_model(file_path, target_path)
|
||||||
|
results.append({
|
||||||
|
"path": file_path,
|
||||||
|
"success": success,
|
||||||
|
"message": "Success" if success else "Failed to move model"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Count successes and failures
|
||||||
|
success_count = sum(1 for r in results if r["success"])
|
||||||
|
failure_count = len(results) - success_count
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'message': f'Moved {success_count} of {len(file_paths)} models',
|
||||||
|
'results': results,
|
||||||
|
'success_count': success_count,
|
||||||
|
'failure_count': failure_count
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
|
||||||
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
|
async def get_lora_model_description(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get model description for a Lora model"""
|
||||||
|
try:
|
||||||
|
# Get parameters
|
||||||
|
model_id = request.query.get('model_id')
|
||||||
|
file_path = request.query.get('file_path')
|
||||||
|
|
||||||
|
if not model_id:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Model ID is required'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Check if we already have the description stored in metadata
|
||||||
|
description = None
|
||||||
|
tags = []
|
||||||
|
creator = {}
|
||||||
|
if file_path:
|
||||||
|
import os
|
||||||
|
from ..utils.metadata_manager import MetadataManager
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||||
|
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||||
|
description = metadata.get('modelDescription')
|
||||||
|
tags = metadata.get('tags', [])
|
||||||
|
creator = metadata.get('creator', {})
|
||||||
|
|
||||||
|
# If description is not in metadata, fetch from CivitAI
|
||||||
|
if not description:
|
||||||
|
logger.info(f"Fetching model metadata for model ID: {model_id}")
|
||||||
|
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
|
||||||
|
|
||||||
|
if model_metadata:
|
||||||
|
description = model_metadata.get('description')
|
||||||
|
tags = model_metadata.get('tags', [])
|
||||||
|
creator = model_metadata.get('creator', {})
|
||||||
|
|
||||||
|
# Save the metadata to file if we have a file path and got metadata
|
||||||
|
if file_path:
|
||||||
|
try:
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||||
|
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||||
|
|
||||||
|
metadata['modelDescription'] = description
|
||||||
|
metadata['tags'] = tags
|
||||||
|
# Ensure the civitai dict exists
|
||||||
|
if 'civitai' not in metadata:
|
||||||
|
metadata['civitai'] = {}
|
||||||
|
# Store creator in the civitai nested structure
|
||||||
|
metadata['civitai']['creator'] = creator
|
||||||
|
|
||||||
|
await MetadataManager.save_metadata(file_path, metadata, True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving model metadata: {e}")
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'description': description or "<p>No model description available.</p>",
|
||||||
|
'tags': tags,
|
||||||
|
'creator': creator
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting model metadata: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get trigger words for specified LoRA models"""
|
||||||
|
try:
|
||||||
|
json_data = await request.json()
|
||||||
|
lora_names = json_data.get("lora_names", [])
|
||||||
|
node_ids = json_data.get("node_ids", [])
|
||||||
|
|
||||||
|
all_trigger_words = []
|
||||||
|
for lora_name in lora_names:
|
||||||
|
_, trigger_words = get_lora_info(lora_name)
|
||||||
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
|
# Format the trigger words
|
||||||
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
|
||||||
|
# Send update to all connected trigger word toggle nodes
|
||||||
|
for node_id in node_ids:
|
||||||
|
PromptServer.instance.send_sync("trigger_word_update", {
|
||||||
|
"id": node_id,
|
||||||
|
"message": trigger_words_text
|
||||||
|
})
|
||||||
|
|
||||||
|
return web.json_response({"success": True})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting trigger words: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
|
|||||||
@@ -633,9 +633,8 @@ class MiscRoutes:
|
|||||||
}, status=400)
|
}, status=400)
|
||||||
|
|
||||||
# Get both lora and checkpoint scanners
|
# Get both lora and checkpoint scanners
|
||||||
registry = ServiceRegistry.get_instance()
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
lora_scanner = await registry.get_lora_scanner()
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
checkpoint_scanner = await registry.get_checkpoint_scanner()
|
|
||||||
|
|
||||||
# If modelVersionId is provided, check for specific version
|
# If modelVersionId is provided, check for specific version
|
||||||
if model_version_id_str:
|
if model_version_id_str:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import base64
|
import base64
|
||||||
|
import jinja2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
@@ -15,6 +16,7 @@ from ..utils.exif_utils import ExifUtils
|
|||||||
from ..recipes import RecipeParserFactory
|
from ..recipes import RecipeParserFactory
|
||||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||||
|
|
||||||
|
from ..services.settings_manager import settings
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
|
||||||
# Check if running in standalone mode
|
# Check if running in standalone mode
|
||||||
@@ -39,7 +41,10 @@ class RecipeRoutes:
|
|||||||
# Initialize service references as None, will be set during async init
|
# Initialize service references as None, will be set during async init
|
||||||
self.recipe_scanner = None
|
self.recipe_scanner = None
|
||||||
self.civitai_client = None
|
self.civitai_client = None
|
||||||
# Remove WorkflowParser instance
|
self.template_env = jinja2.Environment(
|
||||||
|
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||||
|
autoescape=True
|
||||||
|
)
|
||||||
|
|
||||||
# Pre-warm the cache
|
# Pre-warm the cache
|
||||||
self._init_cache_task = None
|
self._init_cache_task = None
|
||||||
@@ -53,6 +58,8 @@ class RecipeRoutes:
|
|||||||
def setup_routes(cls, app: web.Application):
|
def setup_routes(cls, app: web.Application):
|
||||||
"""Register API routes"""
|
"""Register API routes"""
|
||||||
routes = cls()
|
routes = cls()
|
||||||
|
app.router.add_get('/loras/recipes', routes.handle_recipes_page)
|
||||||
|
|
||||||
app.router.add_get('/api/recipes', routes.get_recipes)
|
app.router.add_get('/api/recipes', routes.get_recipes)
|
||||||
app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail)
|
app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail)
|
||||||
app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image)
|
app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image)
|
||||||
@@ -115,6 +122,46 @@ class RecipeRoutes:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True)
|
logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle GET /loras/recipes request"""
|
||||||
|
try:
|
||||||
|
# Ensure services are initialized
|
||||||
|
await self.init_services()
|
||||||
|
|
||||||
|
# Skip initialization check and directly try to get cached data
|
||||||
|
try:
|
||||||
|
# Recipe scanner will initialize cache if needed
|
||||||
|
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
||||||
|
template = self.template_env.get_template('recipes.html')
|
||||||
|
rendered = template.render(
|
||||||
|
recipes=[], # Frontend will load recipes via API
|
||||||
|
is_initializing=False,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
except Exception as cache_error:
|
||||||
|
logger.error(f"Error loading recipe cache data: {cache_error}")
|
||||||
|
# Still keep error handling - show initializing page on error
|
||||||
|
template = self.template_env.get_template('recipes.html')
|
||||||
|
rendered = template.render(
|
||||||
|
is_initializing=True,
|
||||||
|
settings=settings,
|
||||||
|
request=request
|
||||||
|
)
|
||||||
|
logger.info("Recipe cache error, returning initialization page")
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
text=rendered,
|
||||||
|
content_type='text/html'
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
||||||
|
return web.Response(
|
||||||
|
text="Error loading recipes page",
|
||||||
|
status=500
|
||||||
|
)
|
||||||
|
|
||||||
async def get_recipes(self, request: web.Request) -> web.Response:
|
async def get_recipes(self, request: web.Request) -> web.Response:
|
||||||
"""API endpoint for getting paginated recipes"""
|
"""API endpoint for getting paginated recipes"""
|
||||||
try:
|
try:
|
||||||
@@ -1101,7 +1148,7 @@ class RecipeRoutes:
|
|||||||
for lora_name, lora_strength in lora_matches:
|
for lora_name, lora_strength in lora_matches:
|
||||||
try:
|
try:
|
||||||
# Get lora info from scanner
|
# Get lora info from scanner
|
||||||
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name)
|
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name)
|
||||||
|
|
||||||
# Create lora entry
|
# Create lora entry
|
||||||
lora_entry = {
|
lora_entry = {
|
||||||
@@ -1120,7 +1167,7 @@ class RecipeRoutes:
|
|||||||
# Get base model from lora scanner for the available loras
|
# Get base model from lora scanner for the available loras
|
||||||
base_model_counts = {}
|
base_model_counts = {}
|
||||||
for lora in loras_data:
|
for lora in loras_data:
|
||||||
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", ""))
|
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", ""))
|
||||||
if lora_info and "base_model" in lora_info:
|
if lora_info and "base_model" in lora_info:
|
||||||
base_model = lora_info["base_model"]
|
base_model = lora_info["base_model"]
|
||||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||||
@@ -1210,7 +1257,7 @@ class RecipeRoutes:
|
|||||||
if lora.get("isDeleted", False):
|
if lora.get("isDeleted", False):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not self.recipe_scanner._lora_scanner.has_lora_hash(lora.get("hash", "")):
|
if not self.recipe_scanner._lora_scanner.has_hash(lora.get("hash", "")):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get the strength
|
# Get the strength
|
||||||
@@ -1318,7 +1365,7 @@ class RecipeRoutes:
|
|||||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||||
|
|
||||||
# Find target LoRA by name
|
# Find target LoRA by name
|
||||||
target_lora = await lora_scanner.get_lora_info_by_name(target_name)
|
target_lora = await lora_scanner.get_model_info_by_name(target_name)
|
||||||
if not target_lora:
|
if not target_lora:
|
||||||
return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)
|
return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)
|
||||||
|
|
||||||
@@ -1430,9 +1477,9 @@ class RecipeRoutes:
|
|||||||
if 'loras' in recipe:
|
if 'loras' in recipe:
|
||||||
for lora in recipe['loras']:
|
for lora in recipe['loras']:
|
||||||
if 'hash' in lora and lora['hash']:
|
if 'hash' in lora and lora['hash']:
|
||||||
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_lora_hash(lora['hash'].lower())
|
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_hash(lora['hash'].lower())
|
||||||
lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
||||||
lora['localPath'] = self.recipe_scanner._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
|
lora['localPath'] = self.recipe_scanner._lora_scanner.get_path_by_hash(lora['hash'].lower())
|
||||||
|
|
||||||
# Ensure file_url is set (needed by frontend)
|
# Ensure file_url is set (needed by frontend)
|
||||||
if 'file_path' in recipe:
|
if 'file_path' in recipe:
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import logging
|
import logging
|
||||||
import toml
|
import toml
|
||||||
import subprocess
|
import git
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -17,6 +19,7 @@ class UpdateRoutes:
|
|||||||
"""Register update check routes"""
|
"""Register update check routes"""
|
||||||
app.router.add_get('/api/check-updates', UpdateRoutes.check_updates)
|
app.router.add_get('/api/check-updates', UpdateRoutes.check_updates)
|
||||||
app.router.add_get('/api/version-info', UpdateRoutes.get_version_info)
|
app.router.add_get('/api/version-info', UpdateRoutes.get_version_info)
|
||||||
|
app.router.add_post('/api/perform-update', UpdateRoutes.perform_update)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def check_updates(request):
|
async def check_updates(request):
|
||||||
@@ -25,6 +28,8 @@ class UpdateRoutes:
|
|||||||
Returns update status and version information
|
Returns update status and version information
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
nightly = request.query.get('nightly', 'false').lower() == 'true'
|
||||||
|
|
||||||
# Read local version from pyproject.toml
|
# Read local version from pyproject.toml
|
||||||
local_version = UpdateRoutes._get_local_version()
|
local_version = UpdateRoutes._get_local_version()
|
||||||
|
|
||||||
@@ -32,13 +37,21 @@ class UpdateRoutes:
|
|||||||
git_info = UpdateRoutes._get_git_info()
|
git_info = UpdateRoutes._get_git_info()
|
||||||
|
|
||||||
# Fetch remote version from GitHub
|
# Fetch remote version from GitHub
|
||||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
if nightly:
|
||||||
|
remote_version, changelog = await UpdateRoutes._get_nightly_version()
|
||||||
|
else:
|
||||||
|
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||||
|
|
||||||
# Compare versions
|
# Compare versions
|
||||||
update_available = UpdateRoutes._compare_versions(
|
if nightly:
|
||||||
local_version.replace('v', ''),
|
# For nightly, compare commit hashes
|
||||||
remote_version.replace('v', '')
|
update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version)
|
||||||
)
|
else:
|
||||||
|
# For stable, compare semantic versions
|
||||||
|
update_available = UpdateRoutes._compare_versions(
|
||||||
|
local_version.replace('v', ''),
|
||||||
|
remote_version.replace('v', '')
|
||||||
|
)
|
||||||
|
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': True,
|
'success': True,
|
||||||
@@ -46,7 +59,8 @@ class UpdateRoutes:
|
|||||||
'latest_version': remote_version,
|
'latest_version': remote_version,
|
||||||
'update_available': update_available,
|
'update_available': update_available,
|
||||||
'changelog': changelog,
|
'changelog': changelog,
|
||||||
'git_info': git_info
|
'git_info': git_info,
|
||||||
|
'nightly': nightly
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -84,6 +98,168 @@ class UpdateRoutes:
|
|||||||
'error': str(e)
|
'error': str(e)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def perform_update(request):
|
||||||
|
"""
|
||||||
|
Perform Git-based update to latest release tag or main branch
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Parse request body
|
||||||
|
body = await request.json() if request.has_body else {}
|
||||||
|
nightly = body.get('nightly', False)
|
||||||
|
|
||||||
|
# Get current plugin directory
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
||||||
|
|
||||||
|
# Backup settings.json if it exists
|
||||||
|
settings_path = os.path.join(plugin_root, 'settings.json')
|
||||||
|
settings_backup = None
|
||||||
|
if os.path.exists(settings_path):
|
||||||
|
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||||
|
settings_backup = f.read()
|
||||||
|
logger.info("Backed up settings.json")
|
||||||
|
|
||||||
|
# Perform Git update
|
||||||
|
success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
|
||||||
|
|
||||||
|
# Restore settings.json if we backed it up
|
||||||
|
if settings_backup and success:
|
||||||
|
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(settings_backup)
|
||||||
|
logger.info("Restored settings.json")
|
||||||
|
|
||||||
|
if success:
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'message': f'Successfully updated to {new_version}',
|
||||||
|
'new_version': new_version
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Failed to complete Git update'
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to perform update: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _get_nightly_version() -> tuple[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Fetch latest commit from main branch
|
||||||
|
"""
|
||||||
|
repo_owner = "willmiao"
|
||||||
|
repo_name = "ComfyUI-Lora-Manager"
|
||||||
|
|
||||||
|
# Use GitHub API to fetch the latest commit from main branch
|
||||||
|
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
logger.warning(f"Failed to fetch GitHub commit: {response.status}")
|
||||||
|
return "main", []
|
||||||
|
|
||||||
|
data = await response.json()
|
||||||
|
commit_sha = data.get('sha', '')[:7] # Short hash
|
||||||
|
commit_message = data.get('commit', {}).get('message', '')
|
||||||
|
|
||||||
|
# Format as "main-{short_hash}"
|
||||||
|
version = f"main-{commit_sha}"
|
||||||
|
|
||||||
|
# Use commit message as changelog
|
||||||
|
changelog = [commit_message] if commit_message else []
|
||||||
|
|
||||||
|
return version, changelog
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
|
||||||
|
return "main", []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool:
|
||||||
|
"""
|
||||||
|
Compare local commit hash with remote main branch
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
local_hash = local_git_info.get('short_hash', 'unknown')
|
||||||
|
if local_hash == 'unknown':
|
||||||
|
return True # Assume update available if we can't get local hash
|
||||||
|
|
||||||
|
# Extract remote hash from version string (format: "main-{hash}")
|
||||||
|
if '-' in remote_version:
|
||||||
|
remote_hash = remote_version.split('-')[-1]
|
||||||
|
return local_hash != remote_hash
|
||||||
|
|
||||||
|
return True # Default to update available
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error comparing nightly versions: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Perform Git-based update using GitPython
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_root: Path to the plugin root directory
|
||||||
|
nightly: Whether to update to main branch or latest release
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (success, new_version)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Open the Git repository
|
||||||
|
repo = git.Repo(plugin_root)
|
||||||
|
|
||||||
|
# Fetch latest changes
|
||||||
|
origin = repo.remotes.origin
|
||||||
|
origin.fetch()
|
||||||
|
|
||||||
|
if nightly:
|
||||||
|
# Switch to main branch and pull latest
|
||||||
|
main_branch = 'main'
|
||||||
|
if main_branch not in [branch.name for branch in repo.branches]:
|
||||||
|
# Create local main branch if it doesn't exist
|
||||||
|
repo.create_head(main_branch, origin.refs.main)
|
||||||
|
|
||||||
|
repo.heads[main_branch].checkout()
|
||||||
|
origin.pull(main_branch)
|
||||||
|
|
||||||
|
# Get new commit hash
|
||||||
|
new_version = f"main-{repo.head.commit.hexsha[:7]}"
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Get latest release tag
|
||||||
|
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
|
||||||
|
if not tags:
|
||||||
|
logger.error("No tags found in repository")
|
||||||
|
return False, ""
|
||||||
|
|
||||||
|
latest_tag = tags[0]
|
||||||
|
|
||||||
|
# Checkout to latest tag
|
||||||
|
repo.git.checkout(latest_tag.name)
|
||||||
|
|
||||||
|
new_version = latest_tag.name
|
||||||
|
|
||||||
|
logger.info(f"Successfully updated to {new_version}")
|
||||||
|
return True, new_version
|
||||||
|
|
||||||
|
except git.exc.GitError as e:
|
||||||
|
logger.error(f"Git error during update: {e}")
|
||||||
|
return False, ""
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during Git update: {e}")
|
||||||
|
return False, ""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_local_version() -> str:
|
def _get_local_version() -> str:
|
||||||
"""Get local plugin version from pyproject.toml"""
|
"""Get local plugin version from pyproject.toml"""
|
||||||
|
|||||||
259
py/services/base_model_service.py
Normal file
259
py/services/base_model_service.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List, Optional, Type
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from ..utils.models import BaseModelMetadata
|
||||||
|
from ..utils.constants import NSFW_LEVELS
|
||||||
|
from .settings_manager import settings
|
||||||
|
from ..utils.utils import fuzzy_match
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class BaseModelService(ABC):
|
||||||
|
"""Base service class for all model types"""
|
||||||
|
|
||||||
|
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
|
||||||
|
"""Initialize the service
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: Type of model (lora, checkpoint, etc.)
|
||||||
|
scanner: Model scanner instance
|
||||||
|
metadata_class: Metadata class for this model type
|
||||||
|
"""
|
||||||
|
self.model_type = model_type
|
||||||
|
self.scanner = scanner
|
||||||
|
self.metadata_class = metadata_class
|
||||||
|
|
||||||
|
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||||
|
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||||
|
base_models: list = None, tags: list = None,
|
||||||
|
search_options: dict = None, hash_filters: dict = None,
|
||||||
|
favorites_only: bool = False, **kwargs) -> Dict:
|
||||||
|
"""Get paginated and filtered model data
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: Page number (1-based)
|
||||||
|
page_size: Number of items per page
|
||||||
|
sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc'
|
||||||
|
folder: Folder filter
|
||||||
|
search: Search term
|
||||||
|
fuzzy_search: Whether to use fuzzy search
|
||||||
|
base_models: List of base models to filter by
|
||||||
|
tags: List of tags to filter by
|
||||||
|
search_options: Search options dict
|
||||||
|
hash_filters: Hash filtering options
|
||||||
|
favorites_only: Filter for favorites only
|
||||||
|
**kwargs: Additional model-specific filters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict containing paginated results
|
||||||
|
"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
|
# Parse sort_by into sort_key and order
|
||||||
|
if ':' in sort_by:
|
||||||
|
sort_key, order = sort_by.split(':', 1)
|
||||||
|
sort_key = sort_key.strip()
|
||||||
|
order = order.strip().lower()
|
||||||
|
if order not in ('asc', 'desc'):
|
||||||
|
order = 'asc'
|
||||||
|
else:
|
||||||
|
sort_key = sort_by.strip()
|
||||||
|
order = 'asc'
|
||||||
|
|
||||||
|
# Get default search options if not provided
|
||||||
|
if search_options is None:
|
||||||
|
search_options = {
|
||||||
|
'filename': True,
|
||||||
|
'modelname': True,
|
||||||
|
'tags': False,
|
||||||
|
'recursive': False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get the base data set using new sort logic
|
||||||
|
filtered_data = await cache.get_sorted_data(sort_key, order)
|
||||||
|
|
||||||
|
# Apply hash filtering if provided (highest priority)
|
||||||
|
if hash_filters:
|
||||||
|
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
|
||||||
|
|
||||||
|
# Jump to pagination for hash filters
|
||||||
|
return self._paginate(filtered_data, page, page_size)
|
||||||
|
|
||||||
|
# Apply common filters
|
||||||
|
filtered_data = await self._apply_common_filters(
|
||||||
|
filtered_data, folder, base_models, tags, favorites_only, search_options
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply search filtering
|
||||||
|
if search:
|
||||||
|
filtered_data = await self._apply_search_filters(
|
||||||
|
filtered_data, search, fuzzy_search, search_options
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply model-specific filters
|
||||||
|
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||||
|
|
||||||
|
return self._paginate(filtered_data, page, page_size)
|
||||||
|
|
||||||
|
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||||
|
"""Apply hash-based filtering"""
|
||||||
|
single_hash = hash_filters.get('single_hash')
|
||||||
|
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||||
|
|
||||||
|
if single_hash:
|
||||||
|
# Filter by single hash
|
||||||
|
single_hash = single_hash.lower()
|
||||||
|
return [
|
||||||
|
item for item in data
|
||||||
|
if item.get('sha256', '').lower() == single_hash
|
||||||
|
]
|
||||||
|
elif multiple_hashes:
|
||||||
|
# Filter by multiple hashes
|
||||||
|
hash_set = set(hash.lower() for hash in multiple_hashes)
|
||||||
|
return [
|
||||||
|
item for item in data
|
||||||
|
if item.get('sha256', '').lower() in hash_set
|
||||||
|
]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
|
||||||
|
base_models: list = None, tags: list = None,
|
||||||
|
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
|
||||||
|
"""Apply common filters that work across all model types"""
|
||||||
|
# Apply SFW filtering if enabled in settings
|
||||||
|
if settings.get('show_only_sfw', False):
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply favorites filtering if enabled
|
||||||
|
if favorites_only:
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if item.get('favorite', False) is True
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply folder filtering
|
||||||
|
if folder is not None:
|
||||||
|
if search_options and search_options.get('recursive', False):
|
||||||
|
# Recursive folder filtering - include all subfolders
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if item['folder'].startswith(folder)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Exact folder filtering
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if item['folder'] == folder
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply base model filtering
|
||||||
|
if base_models and len(base_models) > 0:
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if item.get('base_model') in base_models
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply tag filtering
|
||||||
|
if tags and len(tags) > 0:
|
||||||
|
data = [
|
||||||
|
item for item in data
|
||||||
|
if any(tag in item.get('tags', []) for tag in tags)
|
||||||
|
]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def _apply_search_filters(self, data: List[Dict], search: str,
|
||||||
|
fuzzy_search: bool, search_options: dict) -> List[Dict]:
|
||||||
|
"""Apply search filtering"""
|
||||||
|
search_results = []
|
||||||
|
|
||||||
|
for item in data:
|
||||||
|
# Search by file name
|
||||||
|
if search_options.get('filename', True):
|
||||||
|
if fuzzy_search:
|
||||||
|
if fuzzy_match(item.get('file_name', ''), search):
|
||||||
|
search_results.append(item)
|
||||||
|
continue
|
||||||
|
elif search.lower() in item.get('file_name', '').lower():
|
||||||
|
search_results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Search by model name
|
||||||
|
if search_options.get('modelname', True):
|
||||||
|
if fuzzy_search:
|
||||||
|
if fuzzy_match(item.get('model_name', ''), search):
|
||||||
|
search_results.append(item)
|
||||||
|
continue
|
||||||
|
elif search.lower() in item.get('model_name', '').lower():
|
||||||
|
search_results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Search by tags
|
||||||
|
if search_options.get('tags', False) and 'tags' in item:
|
||||||
|
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
|
||||||
|
for tag in item['tags']):
|
||||||
|
search_results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||||
|
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||||
|
"""Apply pagination to filtered data"""
|
||||||
|
total_items = len(data)
|
||||||
|
start_idx = (page - 1) * page_size
|
||||||
|
end_idx = min(start_idx + page_size, total_items)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'items': data[start_idx:end_idx],
|
||||||
|
'total': total_items,
|
||||||
|
'page': page,
|
||||||
|
'page_size': page_size,
|
||||||
|
'total_pages': (total_items + page_size - 1) // page_size
|
||||||
|
}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def format_response(self, model_data: Dict) -> Dict:
|
||||||
|
"""Format model data for API response - must be implemented by subclasses"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Common service methods that delegate to scanner
|
||||||
|
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
|
||||||
|
"""Get top tags sorted by frequency"""
|
||||||
|
return await self.scanner.get_top_tags(limit)
|
||||||
|
|
||||||
|
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
||||||
|
"""Get base models sorted by frequency"""
|
||||||
|
return await self.scanner.get_base_models(limit)
|
||||||
|
|
||||||
|
def has_hash(self, sha256: str) -> bool:
|
||||||
|
"""Check if a model with given hash exists"""
|
||||||
|
return self.scanner.has_hash(sha256)
|
||||||
|
|
||||||
|
def get_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||||
|
"""Get file path for a model by its hash"""
|
||||||
|
return self.scanner.get_path_by_hash(sha256)
|
||||||
|
|
||||||
|
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||||
|
"""Get hash for a model by its file path"""
|
||||||
|
return self.scanner.get_hash_by_path(file_path)
|
||||||
|
|
||||||
|
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
|
||||||
|
"""Trigger model scanning"""
|
||||||
|
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
|
||||||
|
|
||||||
|
async def get_model_info_by_name(self, name: str):
|
||||||
|
"""Get model information by name"""
|
||||||
|
return await self.scanner.get_model_info_by_name(name)
|
||||||
|
|
||||||
|
def get_model_roots(self) -> List[str]:
|
||||||
|
"""Get model root directories"""
|
||||||
|
return self.scanner.get_model_roots()
|
||||||
@@ -1,112 +1,26 @@
|
|||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
from typing import List
|
||||||
from typing import List, Dict, Optional, Set
|
|
||||||
import folder_paths # type: ignore
|
|
||||||
|
|
||||||
from ..utils.models import CheckpointMetadata
|
from ..utils.models import CheckpointMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from .model_scanner import ModelScanner
|
from .model_scanner import ModelScanner
|
||||||
from .model_hash_index import ModelHashIndex
|
from .model_hash_index import ModelHashIndex
|
||||||
from .service_registry import ServiceRegistry
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class CheckpointScanner(ModelScanner):
|
class CheckpointScanner(ModelScanner):
|
||||||
"""Service for scanning and managing checkpoint files"""
|
"""Service for scanning and managing checkpoint files"""
|
||||||
|
|
||||||
_instance = None
|
|
||||||
_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
def __new__(cls):
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not hasattr(self, '_initialized'):
|
# Define supported file extensions
|
||||||
# Define supported file extensions
|
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||||
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
super().__init__(
|
||||||
super().__init__(
|
model_type="checkpoint",
|
||||||
model_type="checkpoint",
|
model_class=CheckpointMetadata,
|
||||||
model_class=CheckpointMetadata,
|
file_extensions=file_extensions,
|
||||||
file_extensions=file_extensions,
|
hash_index=ModelHashIndex()
|
||||||
hash_index=ModelHashIndex()
|
)
|
||||||
)
|
|
||||||
self._initialized = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_instance(cls):
|
|
||||||
"""Get singleton instance with async support"""
|
|
||||||
async with cls._lock:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = cls()
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def get_model_roots(self) -> List[str]:
|
def get_model_roots(self) -> List[str]:
|
||||||
"""Get checkpoint root directories"""
|
"""Get checkpoint root directories"""
|
||||||
return config.base_models_roots
|
return config.base_models_roots
|
||||||
|
|
||||||
async def scan_all_models(self) -> List[Dict]:
|
|
||||||
"""Scan all checkpoint directories and return metadata"""
|
|
||||||
all_checkpoints = []
|
|
||||||
|
|
||||||
# Create scan tasks for each directory
|
|
||||||
scan_tasks = []
|
|
||||||
for root in self.get_model_roots():
|
|
||||||
task = asyncio.create_task(self._scan_directory(root))
|
|
||||||
scan_tasks.append(task)
|
|
||||||
|
|
||||||
# Wait for all tasks to complete
|
|
||||||
for task in scan_tasks:
|
|
||||||
try:
|
|
||||||
checkpoints = await task
|
|
||||||
all_checkpoints.extend(checkpoints)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning checkpoint directory: {e}")
|
|
||||||
|
|
||||||
return all_checkpoints
|
|
||||||
|
|
||||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
|
||||||
"""Scan a directory for checkpoint files"""
|
|
||||||
checkpoints = []
|
|
||||||
original_root = root_path
|
|
||||||
|
|
||||||
async def scan_recursive(path: str, visited_paths: set):
|
|
||||||
try:
|
|
||||||
real_path = os.path.realpath(path)
|
|
||||||
if real_path in visited_paths:
|
|
||||||
logger.debug(f"Skipping already visited path: {path}")
|
|
||||||
return
|
|
||||||
visited_paths.add(real_path)
|
|
||||||
|
|
||||||
with os.scandir(path) as it:
|
|
||||||
entries = list(it)
|
|
||||||
for entry in entries:
|
|
||||||
try:
|
|
||||||
if entry.is_file(follow_symlinks=True):
|
|
||||||
# Check if file has supported extension
|
|
||||||
ext = os.path.splitext(entry.name)[1].lower()
|
|
||||||
if ext in self.file_extensions:
|
|
||||||
file_path = entry.path.replace(os.sep, "/")
|
|
||||||
await self._process_single_file(file_path, original_root, checkpoints)
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
elif entry.is_dir(follow_symlinks=True):
|
|
||||||
# For directories, continue scanning with original path
|
|
||||||
await scan_recursive(entry.path, visited_paths)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning {path}: {e}")
|
|
||||||
|
|
||||||
await scan_recursive(root_path, set())
|
|
||||||
return checkpoints
|
|
||||||
|
|
||||||
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
|
|
||||||
"""Process a single checkpoint file and add to results"""
|
|
||||||
try:
|
|
||||||
result = await self._process_model_file(file_path, root_path)
|
|
||||||
if result:
|
|
||||||
checkpoints.append(result)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing {file_path}: {e}")
|
|
||||||
51
py/services/checkpoint_service.py
Normal file
51
py/services/checkpoint_service.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from .base_model_service import BaseModelService
|
||||||
|
from ..utils.models import CheckpointMetadata
|
||||||
|
from ..config import config
|
||||||
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class CheckpointService(BaseModelService):
|
||||||
|
"""Checkpoint-specific service implementation"""
|
||||||
|
|
||||||
|
def __init__(self, scanner):
|
||||||
|
"""Initialize Checkpoint service
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scanner: Checkpoint scanner instance
|
||||||
|
"""
|
||||||
|
super().__init__("checkpoint", scanner, CheckpointMetadata)
|
||||||
|
|
||||||
|
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||||
|
"""Format Checkpoint data for API response"""
|
||||||
|
return {
|
||||||
|
"model_name": checkpoint_data["model_name"],
|
||||||
|
"file_name": checkpoint_data["file_name"],
|
||||||
|
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||||
|
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||||
|
"base_model": checkpoint_data.get("base_model", ""),
|
||||||
|
"folder": checkpoint_data["folder"],
|
||||||
|
"sha256": checkpoint_data.get("sha256", ""),
|
||||||
|
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
||||||
|
"file_size": checkpoint_data.get("size", 0),
|
||||||
|
"modified": checkpoint_data.get("modified", ""),
|
||||||
|
"tags": checkpoint_data.get("tags", []),
|
||||||
|
"modelDescription": checkpoint_data.get("modelDescription", ""),
|
||||||
|
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||||
|
"notes": checkpoint_data.get("notes", ""),
|
||||||
|
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||||
|
"favorite": checkpoint_data.get("favorite", False),
|
||||||
|
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}))
|
||||||
|
}
|
||||||
|
|
||||||
|
def find_duplicate_hashes(self) -> Dict:
|
||||||
|
"""Find Checkpoints with duplicate SHA256 hashes"""
|
||||||
|
return self.scanner._hash_index.get_duplicate_hashes()
|
||||||
|
|
||||||
|
def find_duplicate_filenames(self) -> Dict:
|
||||||
|
"""Find Checkpoints with conflicting filenames"""
|
||||||
|
return self.scanner._hash_index.get_duplicate_filenames()
|
||||||
@@ -1,15 +1,10 @@
|
|||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
from typing import List
|
||||||
from typing import List, Dict, Optional
|
|
||||||
|
|
||||||
from ..utils.models import LoraMetadata
|
from ..utils.models import LoraMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from .model_scanner import ModelScanner
|
from .model_scanner import ModelScanner
|
||||||
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
||||||
from .settings_manager import settings
|
|
||||||
from ..utils.constants import NSFW_LEVELS
|
|
||||||
from ..utils.utils import fuzzy_match
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -17,405 +12,22 @@ logger = logging.getLogger(__name__)
|
|||||||
class LoraScanner(ModelScanner):
|
class LoraScanner(ModelScanner):
|
||||||
"""Service for scanning and managing LoRA files"""
|
"""Service for scanning and managing LoRA files"""
|
||||||
|
|
||||||
_instance = None
|
|
||||||
_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
def __new__(cls):
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Ensure initialization happens only once
|
# Define supported file extensions
|
||||||
if not hasattr(self, '_initialized'):
|
file_extensions = {'.safetensors'}
|
||||||
# Define supported file extensions
|
|
||||||
file_extensions = {'.safetensors'}
|
|
||||||
|
|
||||||
# Initialize parent class with ModelHashIndex
|
# Initialize parent class with ModelHashIndex
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_type="lora",
|
model_type="lora",
|
||||||
model_class=LoraMetadata,
|
model_class=LoraMetadata,
|
||||||
file_extensions=file_extensions,
|
file_extensions=file_extensions,
|
||||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||||
)
|
)
|
||||||
self._initialized = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_instance(cls):
|
|
||||||
"""Get singleton instance with async support"""
|
|
||||||
async with cls._lock:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = cls()
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def get_model_roots(self) -> List[str]:
|
def get_model_roots(self) -> List[str]:
|
||||||
"""Get lora root directories"""
|
"""Get lora root directories"""
|
||||||
return config.loras_roots
|
return config.loras_roots
|
||||||
|
|
||||||
async def scan_all_models(self) -> List[Dict]:
|
|
||||||
"""Scan all LoRA directories and return metadata"""
|
|
||||||
all_loras = []
|
|
||||||
|
|
||||||
# Create scan tasks for each directory
|
|
||||||
scan_tasks = []
|
|
||||||
for lora_root in self.get_model_roots():
|
|
||||||
task = asyncio.create_task(self._scan_directory(lora_root))
|
|
||||||
scan_tasks.append(task)
|
|
||||||
|
|
||||||
# Wait for all tasks to complete
|
|
||||||
for task in scan_tasks:
|
|
||||||
try:
|
|
||||||
loras = await task
|
|
||||||
all_loras.extend(loras)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning directory: {e}")
|
|
||||||
|
|
||||||
return all_loras
|
|
||||||
|
|
||||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
|
||||||
"""Scan a single directory for LoRA files"""
|
|
||||||
loras = []
|
|
||||||
original_root = root_path # Save original root path
|
|
||||||
|
|
||||||
async def scan_recursive(path: str, visited_paths: set):
|
|
||||||
"""Recursively scan directory, avoiding circular symlinks"""
|
|
||||||
try:
|
|
||||||
real_path = os.path.realpath(path)
|
|
||||||
if real_path in visited_paths:
|
|
||||||
logger.debug(f"Skipping already visited path: {path}")
|
|
||||||
return
|
|
||||||
visited_paths.add(real_path)
|
|
||||||
|
|
||||||
with os.scandir(path) as it:
|
|
||||||
entries = list(it)
|
|
||||||
for entry in entries:
|
|
||||||
try:
|
|
||||||
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
|
||||||
# Use original path instead of real path
|
|
||||||
file_path = entry.path.replace(os.sep, "/")
|
|
||||||
await self._process_single_file(file_path, original_root, loras)
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
elif entry.is_dir(follow_symlinks=True):
|
|
||||||
# For directories, continue scanning with original path
|
|
||||||
await scan_recursive(entry.path, visited_paths)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning {path}: {e}")
|
|
||||||
|
|
||||||
await scan_recursive(root_path, set())
|
|
||||||
return loras
|
|
||||||
|
|
||||||
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
|
||||||
"""Process a single file and add to results list"""
|
|
||||||
try:
|
|
||||||
result = await self._process_model_file(file_path, root_path)
|
|
||||||
if result:
|
|
||||||
loras.append(result)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing {file_path}: {e}")
|
|
||||||
|
|
||||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
|
||||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
|
||||||
base_models: list = None, tags: list = None,
|
|
||||||
search_options: dict = None, hash_filters: dict = None,
|
|
||||||
favorites_only: bool = False, first_letter: str = None) -> Dict:
|
|
||||||
"""Get paginated and filtered lora data
|
|
||||||
|
|
||||||
Args:
|
|
||||||
page: Current page number (1-based)
|
|
||||||
page_size: Number of items per page
|
|
||||||
sort_by: Sort method ('name' or 'date')
|
|
||||||
folder: Filter by folder path
|
|
||||||
search: Search term
|
|
||||||
fuzzy_search: Use fuzzy matching for search
|
|
||||||
base_models: List of base models to filter by
|
|
||||||
tags: List of tags to filter by
|
|
||||||
search_options: Dictionary with search options (filename, modelname, tags, recursive)
|
|
||||||
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
|
|
||||||
favorites_only: Filter for favorite models only
|
|
||||||
first_letter: Filter by first letter of model name
|
|
||||||
"""
|
|
||||||
cache = await self.get_cached_data()
|
|
||||||
|
|
||||||
# Get default search options if not provided
|
|
||||||
if search_options is None:
|
|
||||||
search_options = {
|
|
||||||
'filename': True,
|
|
||||||
'modelname': True,
|
|
||||||
'tags': False,
|
|
||||||
'recursive': False,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get the base data set
|
|
||||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
|
||||||
|
|
||||||
# Apply hash filtering if provided (highest priority)
|
|
||||||
if hash_filters:
|
|
||||||
single_hash = hash_filters.get('single_hash')
|
|
||||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
|
||||||
|
|
||||||
if single_hash:
|
|
||||||
# Filter by single hash
|
|
||||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora.get('sha256', '').lower() == single_hash
|
|
||||||
]
|
|
||||||
elif multiple_hashes:
|
|
||||||
# Filter by multiple hashes
|
|
||||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora.get('sha256', '').lower() in hash_set
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Jump to pagination
|
|
||||||
total_items = len(filtered_data)
|
|
||||||
start_idx = (page - 1) * page_size
|
|
||||||
end_idx = min(start_idx + page_size, total_items)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
'items': filtered_data[start_idx:end_idx],
|
|
||||||
'total': total_items,
|
|
||||||
'page': page,
|
|
||||||
'page_size': page_size,
|
|
||||||
'total_pages': (total_items + page_size - 1) // page_size
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
# Apply SFW filtering if enabled
|
|
||||||
if settings.get('show_only_sfw', False):
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply favorites filtering if enabled
|
|
||||||
if favorites_only:
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora.get('favorite', False) is True
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply first letter filtering
|
|
||||||
if first_letter:
|
|
||||||
filtered_data = self._filter_by_first_letter(filtered_data, first_letter)
|
|
||||||
|
|
||||||
# Apply folder filtering
|
|
||||||
if folder is not None:
|
|
||||||
if search_options.get('recursive', False):
|
|
||||||
# Recursive folder filtering - include all subfolders
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora['folder'].startswith(folder)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
# Exact folder filtering
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora['folder'] == folder
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply base model filtering
|
|
||||||
if base_models and len(base_models) > 0:
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora.get('base_model') in base_models
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply tag filtering
|
|
||||||
if tags and len(tags) > 0:
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if any(tag in lora.get('tags', []) for tag in tags)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply search filtering
|
|
||||||
if search:
|
|
||||||
search_results = []
|
|
||||||
search_opts = search_options or {}
|
|
||||||
|
|
||||||
for lora in filtered_data:
|
|
||||||
# Search by file name
|
|
||||||
if search_opts.get('filename', True):
|
|
||||||
if fuzzy_match(lora.get('file_name', ''), search):
|
|
||||||
search_results.append(lora)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by model name
|
|
||||||
if search_opts.get('modelname', True):
|
|
||||||
if fuzzy_match(lora.get('model_name', ''), search):
|
|
||||||
search_results.append(lora)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by tags
|
|
||||||
if search_opts.get('tags', False) and 'tags' in lora:
|
|
||||||
if any(fuzzy_match(tag, search) for tag in lora['tags']):
|
|
||||||
search_results.append(lora)
|
|
||||||
continue
|
|
||||||
|
|
||||||
filtered_data = search_results
|
|
||||||
|
|
||||||
# Calculate pagination
|
|
||||||
total_items = len(filtered_data)
|
|
||||||
start_idx = (page - 1) * page_size
|
|
||||||
end_idx = min(start_idx + page_size, total_items)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
'items': filtered_data[start_idx:end_idx],
|
|
||||||
'total': total_items,
|
|
||||||
'page': page,
|
|
||||||
'page_size': page_size,
|
|
||||||
'total_pages': (total_items + page_size - 1) // page_size
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _filter_by_first_letter(self, data, letter):
|
|
||||||
"""Filter data by first letter of model name
|
|
||||||
|
|
||||||
Special handling:
|
|
||||||
- '#': Numbers (0-9)
|
|
||||||
- '@': Special characters (not alphanumeric)
|
|
||||||
- '漢': CJK characters
|
|
||||||
"""
|
|
||||||
filtered_data = []
|
|
||||||
|
|
||||||
for lora in data:
|
|
||||||
model_name = lora.get('model_name', '')
|
|
||||||
if not model_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
first_char = model_name[0].upper()
|
|
||||||
|
|
||||||
if letter == '#' and first_char.isdigit():
|
|
||||||
filtered_data.append(lora)
|
|
||||||
elif letter == '@' and not first_char.isalnum():
|
|
||||||
# Special characters (not alphanumeric)
|
|
||||||
filtered_data.append(lora)
|
|
||||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
|
||||||
# CJK characters
|
|
||||||
filtered_data.append(lora)
|
|
||||||
elif letter.upper() == first_char:
|
|
||||||
# Regular alphabet matching
|
|
||||||
filtered_data.append(lora)
|
|
||||||
|
|
||||||
return filtered_data
|
|
||||||
|
|
||||||
def _is_cjk_character(self, char):
|
|
||||||
"""Check if character is a CJK character"""
|
|
||||||
# Define Unicode ranges for CJK characters
|
|
||||||
cjk_ranges = [
|
|
||||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
|
||||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
|
||||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
|
||||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
|
||||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
|
||||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
|
||||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
|
||||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
|
||||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
|
||||||
(0x3300, 0x33FF), # CJK Compatibility
|
|
||||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
|
||||||
(0x3100, 0x312F), # Bopomofo
|
|
||||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
|
||||||
(0x3040, 0x309F), # Hiragana
|
|
||||||
(0x30A0, 0x30FF), # Katakana
|
|
||||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
|
||||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
|
||||||
(0x1100, 0x11FF), # Hangul Jamo
|
|
||||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
|
||||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
|
||||||
]
|
|
||||||
|
|
||||||
code_point = ord(char)
|
|
||||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
|
||||||
|
|
||||||
async def get_letter_counts(self):
|
|
||||||
"""Get count of models for each letter of the alphabet"""
|
|
||||||
cache = await self.get_cached_data()
|
|
||||||
data = cache.sorted_by_name
|
|
||||||
|
|
||||||
# Define letter categories
|
|
||||||
letters = {
|
|
||||||
'#': 0, # Numbers
|
|
||||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
|
||||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
|
||||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
|
||||||
'Y': 0, 'Z': 0,
|
|
||||||
'@': 0, # Special characters
|
|
||||||
'漢': 0 # CJK characters
|
|
||||||
}
|
|
||||||
|
|
||||||
# Count models for each letter
|
|
||||||
for lora in data:
|
|
||||||
model_name = lora.get('model_name', '')
|
|
||||||
if not model_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
first_char = model_name[0].upper()
|
|
||||||
|
|
||||||
if first_char.isdigit():
|
|
||||||
letters['#'] += 1
|
|
||||||
elif first_char in letters:
|
|
||||||
letters[first_char] += 1
|
|
||||||
elif self._is_cjk_character(first_char):
|
|
||||||
letters['漢'] += 1
|
|
||||||
elif not first_char.isalnum():
|
|
||||||
letters['@'] += 1
|
|
||||||
|
|
||||||
return letters
|
|
||||||
|
|
||||||
# Lora-specific hash index functionality
|
|
||||||
def has_lora_hash(self, sha256: str) -> bool:
|
|
||||||
"""Check if a LoRA with given hash exists"""
|
|
||||||
return self.has_hash(sha256)
|
|
||||||
|
|
||||||
def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
|
|
||||||
"""Get file path for a LoRA by its hash"""
|
|
||||||
return self.get_path_by_hash(sha256)
|
|
||||||
|
|
||||||
def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
|
|
||||||
"""Get hash for a LoRA by its file path"""
|
|
||||||
return self.get_hash_by_path(file_path)
|
|
||||||
|
|
||||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
|
||||||
"""Get top tags sorted by count"""
|
|
||||||
# Make sure cache is initialized
|
|
||||||
await self.get_cached_data()
|
|
||||||
|
|
||||||
# Sort tags by count in descending order
|
|
||||||
sorted_tags = sorted(
|
|
||||||
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
|
|
||||||
key=lambda x: x['count'],
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return limited number
|
|
||||||
return sorted_tags[:limit]
|
|
||||||
|
|
||||||
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
|
||||||
"""Get base models used in loras sorted by frequency"""
|
|
||||||
# Make sure cache is initialized
|
|
||||||
cache = await self.get_cached_data()
|
|
||||||
|
|
||||||
# Count base model occurrences
|
|
||||||
base_model_counts = {}
|
|
||||||
for lora in cache.raw_data:
|
|
||||||
if 'base_model' in lora and lora['base_model']:
|
|
||||||
base_model = lora['base_model']
|
|
||||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
|
||||||
|
|
||||||
# Sort base models by count
|
|
||||||
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
|
|
||||||
sorted_models.sort(key=lambda x: x['count'], reverse=True)
|
|
||||||
|
|
||||||
# Return limited number
|
|
||||||
return sorted_models[:limit]
|
|
||||||
|
|
||||||
async def diagnose_hash_index(self):
|
async def diagnose_hash_index(self):
|
||||||
"""Diagnostic method to verify hash index functionality"""
|
"""Diagnostic method to verify hash index functionality"""
|
||||||
print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr)
|
print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr)
|
||||||
@@ -451,19 +63,3 @@ class LoraScanner(ModelScanner):
|
|||||||
test_hash_result = self._hash_index.get_hash(test_path)
|
test_hash_result = self._hash_index.get_hash(test_path)
|
||||||
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
|
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
|
||||||
|
|
||||||
async def get_lora_info_by_name(self, name):
|
|
||||||
"""Get LoRA information by name"""
|
|
||||||
try:
|
|
||||||
# Get cached data
|
|
||||||
cache = await self.get_cached_data()
|
|
||||||
|
|
||||||
# Find the LoRA by name
|
|
||||||
for lora in cache.raw_data:
|
|
||||||
if lora.get("file_name") == name:
|
|
||||||
return lora
|
|
||||||
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting LoRA info by name: {e}", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|||||||
212
py/services/lora_service.py
Normal file
212
py/services/lora_service.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from .base_model_service import BaseModelService
|
||||||
|
from ..utils.models import LoraMetadata
|
||||||
|
from ..config import config
|
||||||
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class LoraService(BaseModelService):
|
||||||
|
"""LoRA-specific service implementation"""
|
||||||
|
|
||||||
|
def __init__(self, scanner):
|
||||||
|
"""Initialize LoRA service
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scanner: LoRA scanner instance
|
||||||
|
"""
|
||||||
|
super().__init__("lora", scanner, LoraMetadata)
|
||||||
|
|
||||||
|
async def format_response(self, lora_data: Dict) -> Dict:
|
||||||
|
"""Format LoRA data for API response"""
|
||||||
|
return {
|
||||||
|
"model_name": lora_data["model_name"],
|
||||||
|
"file_name": lora_data["file_name"],
|
||||||
|
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
|
||||||
|
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||||
|
"base_model": lora_data.get("base_model", ""),
|
||||||
|
"folder": lora_data["folder"],
|
||||||
|
"sha256": lora_data.get("sha256", ""),
|
||||||
|
"file_path": lora_data["file_path"].replace(os.sep, "/"),
|
||||||
|
"file_size": lora_data.get("size", 0),
|
||||||
|
"modified": lora_data.get("modified", ""),
|
||||||
|
"tags": lora_data.get("tags", []),
|
||||||
|
"modelDescription": lora_data.get("modelDescription", ""),
|
||||||
|
"from_civitai": lora_data.get("from_civitai", True),
|
||||||
|
"usage_tips": lora_data.get("usage_tips", ""),
|
||||||
|
"notes": lora_data.get("notes", ""),
|
||||||
|
"favorite": lora_data.get("favorite", False),
|
||||||
|
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||||
|
"""Apply LoRA-specific filters"""
|
||||||
|
# Handle first_letter filter for LoRAs
|
||||||
|
first_letter = kwargs.get('first_letter')
|
||||||
|
if first_letter:
|
||||||
|
data = self._filter_by_first_letter(data, first_letter)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||||
|
"""Filter data by first letter of model name
|
||||||
|
|
||||||
|
Special handling:
|
||||||
|
- '#': Numbers (0-9)
|
||||||
|
- '@': Special characters (not alphanumeric)
|
||||||
|
- '漢': CJK characters
|
||||||
|
"""
|
||||||
|
filtered_data = []
|
||||||
|
|
||||||
|
for lora in data:
|
||||||
|
model_name = lora.get('model_name', '')
|
||||||
|
if not model_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
first_char = model_name[0].upper()
|
||||||
|
|
||||||
|
if letter == '#' and first_char.isdigit():
|
||||||
|
filtered_data.append(lora)
|
||||||
|
elif letter == '@' and not first_char.isalnum():
|
||||||
|
# Special characters (not alphanumeric)
|
||||||
|
filtered_data.append(lora)
|
||||||
|
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||||
|
# CJK characters
|
||||||
|
filtered_data.append(lora)
|
||||||
|
elif letter.upper() == first_char:
|
||||||
|
# Regular alphabet matching
|
||||||
|
filtered_data.append(lora)
|
||||||
|
|
||||||
|
return filtered_data
|
||||||
|
|
||||||
|
def _is_cjk_character(self, char: str) -> bool:
|
||||||
|
"""Check if character is a CJK character"""
|
||||||
|
# Define Unicode ranges for CJK characters
|
||||||
|
cjk_ranges = [
|
||||||
|
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||||
|
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||||
|
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||||
|
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||||
|
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||||
|
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||||
|
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||||
|
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||||
|
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||||
|
(0x3300, 0x33FF), # CJK Compatibility
|
||||||
|
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||||
|
(0x3100, 0x312F), # Bopomofo
|
||||||
|
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||||
|
(0x3040, 0x309F), # Hiragana
|
||||||
|
(0x30A0, 0x30FF), # Katakana
|
||||||
|
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||||
|
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||||
|
(0x1100, 0x11FF), # Hangul Jamo
|
||||||
|
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||||
|
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||||
|
]
|
||||||
|
|
||||||
|
code_point = ord(char)
|
||||||
|
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||||
|
|
||||||
|
# LoRA-specific methods
|
||||||
|
async def get_letter_counts(self) -> Dict[str, int]:
|
||||||
|
"""Get count of LoRAs for each letter of the alphabet"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
data = cache.raw_data
|
||||||
|
|
||||||
|
# Define letter categories
|
||||||
|
letters = {
|
||||||
|
'#': 0, # Numbers
|
||||||
|
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||||
|
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||||
|
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||||
|
'Y': 0, 'Z': 0,
|
||||||
|
'@': 0, # Special characters
|
||||||
|
'漢': 0 # CJK characters
|
||||||
|
}
|
||||||
|
|
||||||
|
# Count models for each letter
|
||||||
|
for lora in data:
|
||||||
|
model_name = lora.get('model_name', '')
|
||||||
|
if not model_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
first_char = model_name[0].upper()
|
||||||
|
|
||||||
|
if first_char.isdigit():
|
||||||
|
letters['#'] += 1
|
||||||
|
elif first_char in letters:
|
||||||
|
letters[first_char] += 1
|
||||||
|
elif self._is_cjk_character(first_char):
|
||||||
|
letters['漢'] += 1
|
||||||
|
elif not first_char.isalnum():
|
||||||
|
letters['@'] += 1
|
||||||
|
|
||||||
|
return letters
|
||||||
|
|
||||||
|
async def get_lora_notes(self, lora_name: str) -> Optional[str]:
|
||||||
|
"""Get notes for a specific LoRA file"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for lora in cache.raw_data:
|
||||||
|
if lora['file_name'] == lora_name:
|
||||||
|
return lora.get('notes', '')
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
||||||
|
"""Get trigger words for a specific LoRA file"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for lora in cache.raw_data:
|
||||||
|
if lora['file_name'] == lora_name:
|
||||||
|
civitai_data = lora.get('civitai', {})
|
||||||
|
return civitai_data.get('trainedWords', [])
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_lora_preview_url(self, lora_name: str) -> Optional[str]:
|
||||||
|
"""Get the static preview URL for a LoRA file"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for lora in cache.raw_data:
|
||||||
|
if lora['file_name'] == lora_name:
|
||||||
|
preview_url = lora.get('preview_url')
|
||||||
|
if preview_url:
|
||||||
|
return config.get_preview_static_url(preview_url)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]:
|
||||||
|
"""Get the Civitai URL for a LoRA file"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
|
for lora in cache.raw_data:
|
||||||
|
if lora['file_name'] == lora_name:
|
||||||
|
civitai_data = lora.get('civitai', {})
|
||||||
|
model_id = civitai_data.get('modelId')
|
||||||
|
version_id = civitai_data.get('id')
|
||||||
|
|
||||||
|
if model_id:
|
||||||
|
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||||
|
if version_id:
|
||||||
|
civitai_url += f"?modelVersionId={version_id}"
|
||||||
|
|
||||||
|
return {
|
||||||
|
'civitai_url': civitai_url,
|
||||||
|
'model_id': str(model_id),
|
||||||
|
'version_id': str(version_id) if version_id else None
|
||||||
|
}
|
||||||
|
|
||||||
|
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||||
|
|
||||||
|
def find_duplicate_hashes(self) -> Dict:
|
||||||
|
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||||
|
return self.scanner._hash_index.get_duplicate_hashes()
|
||||||
|
|
||||||
|
def find_duplicate_filenames(self) -> Dict:
|
||||||
|
"""Find LoRAs with conflicting filenames"""
|
||||||
|
return self.scanner._hash_index.get_duplicate_filenames()
|
||||||
@@ -1,37 +1,85 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Dict
|
from typing import List, Dict, Tuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from natsort import natsorted
|
from natsort import natsorted
|
||||||
|
|
||||||
|
# Supported sort modes: (sort_key, order)
|
||||||
|
# order: 'asc' for ascending, 'desc' for descending
|
||||||
|
SUPPORTED_SORT_MODES = [
|
||||||
|
('name', 'asc'),
|
||||||
|
('name', 'desc'),
|
||||||
|
('date', 'asc'),
|
||||||
|
('date', 'desc'),
|
||||||
|
('size', 'asc'),
|
||||||
|
('size', 'desc'),
|
||||||
|
]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelCache:
|
class ModelCache:
|
||||||
"""Cache structure for model data"""
|
"""Cache structure for model data with extensible sorting"""
|
||||||
raw_data: List[Dict]
|
raw_data: List[Dict]
|
||||||
sorted_by_name: List[Dict]
|
|
||||||
sorted_by_date: List[Dict]
|
|
||||||
folders: List[str]
|
folders: List[str]
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
# Cache for last sort: (sort_key, order) -> sorted list
|
||||||
|
self._last_sort: Tuple[str, str] = (None, None)
|
||||||
|
self._last_sorted_data: List[Dict] = []
|
||||||
|
# Default sort on init
|
||||||
|
asyncio.create_task(self.resort())
|
||||||
|
|
||||||
async def resort(self, name_only: bool = False):
|
async def resort(self):
|
||||||
"""Resort all cached data views"""
|
"""Resort cached data according to last sort mode if set"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self.sorted_by_name = natsorted(
|
if self._last_sort != (None, None):
|
||||||
self.raw_data,
|
sort_key, order = self._last_sort
|
||||||
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||||
)
|
self._last_sorted_data = sorted_data
|
||||||
if not name_only:
|
# Update folder list
|
||||||
self.sorted_by_date = sorted(
|
# else: do nothing
|
||||||
self.raw_data,
|
|
||||||
key=itemgetter('modified'),
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
# Update folder list
|
|
||||||
all_folders = set(l['folder'] for l in self.raw_data)
|
all_folders = set(l['folder'] for l in self.raw_data)
|
||||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||||
|
|
||||||
|
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
||||||
|
"""Sort data by sort_key and order"""
|
||||||
|
reverse = (order == 'desc')
|
||||||
|
if sort_key == 'name':
|
||||||
|
# Natural sort by model_name, case-insensitive
|
||||||
|
return natsorted(
|
||||||
|
data,
|
||||||
|
key=lambda x: x['model_name'].lower(),
|
||||||
|
reverse=reverse
|
||||||
|
)
|
||||||
|
elif sort_key == 'date':
|
||||||
|
# Sort by modified timestamp
|
||||||
|
return sorted(
|
||||||
|
data,
|
||||||
|
key=itemgetter('modified'),
|
||||||
|
reverse=reverse
|
||||||
|
)
|
||||||
|
elif sort_key == 'size':
|
||||||
|
# Sort by file size
|
||||||
|
return sorted(
|
||||||
|
data,
|
||||||
|
key=itemgetter('size'),
|
||||||
|
reverse=reverse
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback: no sort
|
||||||
|
return list(data)
|
||||||
|
|
||||||
|
async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]:
|
||||||
|
"""Get sorted data by sort_key and order, using cache if possible"""
|
||||||
|
async with self._lock:
|
||||||
|
if (sort_key, order) == self._last_sort:
|
||||||
|
return self._last_sorted_data
|
||||||
|
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||||
|
self._last_sort = (sort_key, order)
|
||||||
|
self._last_sorted_data = sorted_data
|
||||||
|
return sorted_data
|
||||||
|
|
||||||
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
|
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
|
||||||
"""Update preview_url for a specific model in all cached data
|
"""Update preview_url for a specific model in all cached data
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
import shutil
|
import shutil
|
||||||
from typing import List, Dict, Optional, Type, Set
|
from typing import List, Dict, Optional, Type, Set
|
||||||
import msgpack # Add MessagePack import for efficient serialization
|
|
||||||
|
|
||||||
from ..utils.models import BaseModelMetadata
|
from ..utils.models import BaseModelMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
@@ -19,17 +18,33 @@ from .websocket_manager import ws_manager
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Define cache version to handle future format changes
|
|
||||||
# Version history:
|
|
||||||
# 1 - Initial version
|
|
||||||
# 2 - Added duplicate_filenames and duplicate_hashes tracking
|
|
||||||
# 3 - Added _excluded_models list to cache
|
|
||||||
CACHE_VERSION = 3
|
|
||||||
|
|
||||||
class ModelScanner:
|
class ModelScanner:
|
||||||
"""Base service for scanning and managing model files"""
|
"""Base service for scanning and managing model files"""
|
||||||
|
|
||||||
_lock = asyncio.Lock()
|
_instances = {} # Dictionary to store instances by class
|
||||||
|
_locks = {} # Dictionary to store locks by class
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
"""Implement singleton pattern for each subclass"""
|
||||||
|
if cls not in cls._instances:
|
||||||
|
cls._instances[cls] = super().__new__(cls)
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_lock(cls):
|
||||||
|
"""Get or create a lock for this class"""
|
||||||
|
if cls not in cls._locks:
|
||||||
|
cls._locks[cls] = asyncio.Lock()
|
||||||
|
return cls._locks[cls]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_instance(cls):
|
||||||
|
"""Get singleton instance with async support"""
|
||||||
|
lock = cls._get_lock()
|
||||||
|
async with lock:
|
||||||
|
if cls not in cls._instances:
|
||||||
|
cls._instances[cls] = cls()
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
|
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
|
||||||
"""Initialize the scanner
|
"""Initialize the scanner
|
||||||
@@ -40,6 +55,10 @@ class ModelScanner:
|
|||||||
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
|
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
|
||||||
hash_index: Hash index instance (optional)
|
hash_index: Hash index instance (optional)
|
||||||
"""
|
"""
|
||||||
|
# Ensure initialization happens only once per instance
|
||||||
|
if hasattr(self, '_initialized'):
|
||||||
|
return
|
||||||
|
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.model_class = model_class
|
self.model_class = model_class
|
||||||
self.file_extensions = file_extensions
|
self.file_extensions = file_extensions
|
||||||
@@ -48,203 +67,16 @@ class ModelScanner:
|
|||||||
self._tags_count = {} # Dictionary to store tag counts
|
self._tags_count = {} # Dictionary to store tag counts
|
||||||
self._is_initializing = False # Flag to track initialization state
|
self._is_initializing = False # Flag to track initialization state
|
||||||
self._excluded_models = [] # List to track excluded models
|
self._excluded_models = [] # List to track excluded models
|
||||||
self._dirs_last_modified = {} # Track directory modification times
|
self._initialized = True
|
||||||
self._use_cache_files = False # Flag to control cache file usage, default to disabled
|
|
||||||
|
|
||||||
# Clear cache files if disabled
|
|
||||||
if not self._use_cache_files:
|
|
||||||
self._clear_cache_files()
|
|
||||||
|
|
||||||
# Register this service
|
# Register this service
|
||||||
asyncio.create_task(self._register_service())
|
asyncio.create_task(self._register_service())
|
||||||
|
|
||||||
def _clear_cache_files(self):
|
|
||||||
"""Clear existing cache files if they exist"""
|
|
||||||
try:
|
|
||||||
cache_path = self._get_cache_file_path()
|
|
||||||
if cache_path and os.path.exists(cache_path):
|
|
||||||
os.remove(cache_path)
|
|
||||||
logger.info(f"Cleared {self.model_type} cache file: {cache_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error clearing {self.model_type} cache file: {e}")
|
|
||||||
|
|
||||||
async def _register_service(self):
|
async def _register_service(self):
|
||||||
"""Register this instance with the ServiceRegistry"""
|
"""Register this instance with the ServiceRegistry"""
|
||||||
service_name = f"{self.model_type}_scanner"
|
service_name = f"{self.model_type}_scanner"
|
||||||
await ServiceRegistry.register_service(service_name, self)
|
await ServiceRegistry.register_service(service_name, self)
|
||||||
|
|
||||||
def _get_cache_file_path(self) -> Optional[str]:
|
|
||||||
"""Get the path to the cache file"""
|
|
||||||
# Get the directory where this module is located
|
|
||||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
|
||||||
|
|
||||||
# Create a cache directory within the project if it doesn't exist
|
|
||||||
cache_dir = os.path.join(current_dir, "cache")
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Create filename based on model type
|
|
||||||
cache_filename = f"lm_{self.model_type}_cache.msgpack"
|
|
||||||
return os.path.join(cache_dir, cache_filename)
|
|
||||||
|
|
||||||
def _prepare_for_msgpack(self, data):
|
|
||||||
"""Preprocess data to accommodate MessagePack serialization limitations
|
|
||||||
|
|
||||||
Converts integers exceeding safe range to strings
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: Any type of data structure
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Preprocessed data structure with large integers converted to strings
|
|
||||||
"""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
return {k: self._prepare_for_msgpack(v) for k, v in data.items()}
|
|
||||||
elif isinstance(data, list):
|
|
||||||
return [self._prepare_for_msgpack(item) for item in data]
|
|
||||||
elif isinstance(data, int) and (data > 9007199254740991 or data < -9007199254740991):
|
|
||||||
# Convert integers exceeding JavaScript's safe integer range (2^53-1) to strings
|
|
||||||
return str(data)
|
|
||||||
else:
|
|
||||||
return data
|
|
||||||
|
|
||||||
async def _save_cache_to_disk(self) -> bool:
|
|
||||||
"""Save cache data to disk using MessagePack"""
|
|
||||||
if not self._use_cache_files:
|
|
||||||
logger.debug(f"Cache files disabled for {self.model_type}, skipping save")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self._cache is None or not self._cache.raw_data:
|
|
||||||
logger.debug(f"No {self.model_type} cache data to save")
|
|
||||||
return False
|
|
||||||
|
|
||||||
cache_path = self._get_cache_file_path()
|
|
||||||
if not cache_path:
|
|
||||||
logger.warning(f"Cannot determine {self.model_type} cache file location")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create cache data structure
|
|
||||||
cache_data = {
|
|
||||||
"version": CACHE_VERSION,
|
|
||||||
"timestamp": time.time(),
|
|
||||||
"model_type": self.model_type,
|
|
||||||
"raw_data": self._cache.raw_data,
|
|
||||||
"hash_index": {
|
|
||||||
"hash_to_path": self._hash_index._hash_to_path,
|
|
||||||
"filename_to_hash": self._hash_index._filename_to_hash, # Fix: changed from path_to_hash to filename_to_hash
|
|
||||||
"duplicate_hashes": self._hash_index._duplicate_hashes,
|
|
||||||
"duplicate_filenames": self._hash_index._duplicate_filenames
|
|
||||||
},
|
|
||||||
"tags_count": self._tags_count,
|
|
||||||
"dirs_last_modified": self._get_dirs_last_modified(),
|
|
||||||
"excluded_models": self._excluded_models # Add excluded_models to cache data
|
|
||||||
}
|
|
||||||
|
|
||||||
# Preprocess data to handle large integers
|
|
||||||
processed_cache_data = self._prepare_for_msgpack(cache_data)
|
|
||||||
|
|
||||||
# Write to temporary file first (atomic operation)
|
|
||||||
temp_path = f"{cache_path}.tmp"
|
|
||||||
with open(temp_path, 'wb') as f:
|
|
||||||
msgpack.pack(processed_cache_data, f)
|
|
||||||
|
|
||||||
# Replace the old file with the new one
|
|
||||||
if os.path.exists(cache_path):
|
|
||||||
os.replace(temp_path, cache_path)
|
|
||||||
else:
|
|
||||||
os.rename(temp_path, cache_path)
|
|
||||||
|
|
||||||
logger.info(f"Saved {self.model_type} cache with {len(self._cache.raw_data)} models to {cache_path}")
|
|
||||||
logger.debug(f"Hash index stats - hash_to_path: {len(self._hash_index._hash_to_path)}, filename_to_hash: {len(self._hash_index._filename_to_hash)}, duplicate_hashes: {len(self._hash_index._duplicate_hashes)}, duplicate_filenames: {len(self._hash_index._duplicate_filenames)}")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error saving {self.model_type} cache to disk: {e}")
|
|
||||||
# Try to clean up temp file if it exists
|
|
||||||
if 'temp_path' in locals() and os.path.exists(temp_path):
|
|
||||||
try:
|
|
||||||
os.remove(temp_path)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _get_dirs_last_modified(self) -> Dict[str, float]:
|
|
||||||
"""Get last modified time for all model directories"""
|
|
||||||
dirs_info = {}
|
|
||||||
for root in self.get_model_roots():
|
|
||||||
if os.path.exists(root):
|
|
||||||
dirs_info[root] = os.path.getmtime(root)
|
|
||||||
# Also check immediate subdirectories for changes
|
|
||||||
try:
|
|
||||||
with os.scandir(root) as it:
|
|
||||||
for entry in it:
|
|
||||||
if entry.is_dir(follow_symlinks=True):
|
|
||||||
dirs_info[entry.path] = entry.stat().st_mtime
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting directory info for {root}: {e}")
|
|
||||||
return dirs_info
|
|
||||||
|
|
||||||
def _is_cache_valid(self, cache_data: Dict) -> bool:
|
|
||||||
"""Validate if the loaded cache is still valid"""
|
|
||||||
if not cache_data or cache_data.get("version") != CACHE_VERSION:
|
|
||||||
logger.info(f"Cache invalid - version mismatch. Got: {cache_data.get('version')}, Expected: {CACHE_VERSION}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if cache_data.get("model_type") != self.model_type:
|
|
||||||
logger.info(f"Cache invalid - model type mismatch. Got: {cache_data.get('model_type')}, Expected: {self.model_type}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _load_cache_from_disk(self) -> bool:
|
|
||||||
"""Load cache data from disk using MessagePack"""
|
|
||||||
if not self._use_cache_files:
|
|
||||||
logger.info(f"Cache files disabled for {self.model_type}, skipping load")
|
|
||||||
return False
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
cache_path = self._get_cache_file_path()
|
|
||||||
if not cache_path or not os.path.exists(cache_path):
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(cache_path, 'rb') as f:
|
|
||||||
cache_data = msgpack.unpack(f)
|
|
||||||
|
|
||||||
# Validate cache data
|
|
||||||
if not self._is_cache_valid(cache_data):
|
|
||||||
logger.info(f"{self.model_type.capitalize()} cache file found but invalid or outdated")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Load data into memory
|
|
||||||
self._cache = ModelCache(
|
|
||||||
raw_data=cache_data["raw_data"],
|
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[],
|
|
||||||
folders=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load hash index
|
|
||||||
hash_index_data = cache_data.get("hash_index", {})
|
|
||||||
self._hash_index._hash_to_path = hash_index_data.get("hash_to_path", {})
|
|
||||||
self._hash_index._filename_to_hash = hash_index_data.get("filename_to_hash", {}) # Fix: changed from path_to_hash to filename_to_hash
|
|
||||||
self._hash_index._duplicate_hashes = hash_index_data.get("duplicate_hashes", {})
|
|
||||||
self._hash_index._duplicate_filenames = hash_index_data.get("duplicate_filenames", {})
|
|
||||||
|
|
||||||
# Load tags count
|
|
||||||
self._tags_count = cache_data.get("tags_count", {})
|
|
||||||
|
|
||||||
# Load excluded models
|
|
||||||
self._excluded_models = cache_data.get("excluded_models", [])
|
|
||||||
|
|
||||||
# Resort the cache
|
|
||||||
await self._cache.resort()
|
|
||||||
|
|
||||||
logger.info(f"Loaded {self.model_type} cache from disk with {len(self._cache.raw_data)} models in {time.time() - start_time:.2f} seconds")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading {self.model_type} cache from disk: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def initialize_in_background(self) -> None:
|
async def initialize_in_background(self) -> None:
|
||||||
"""Initialize cache in background using thread pool"""
|
"""Initialize cache in background using thread pool"""
|
||||||
try:
|
try:
|
||||||
@@ -252,8 +84,6 @@ class ModelScanner:
|
|||||||
if self._cache is None:
|
if self._cache is None:
|
||||||
self._cache = ModelCache(
|
self._cache = ModelCache(
|
||||||
raw_data=[],
|
raw_data=[],
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[],
|
|
||||||
folders=[]
|
folders=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -272,21 +102,6 @@ class ModelScanner:
|
|||||||
'pageType': page_type
|
'pageType': page_type
|
||||||
})
|
})
|
||||||
|
|
||||||
cache_loaded = await self._load_cache_from_disk()
|
|
||||||
|
|
||||||
if cache_loaded:
|
|
||||||
# Cache loaded successfully, broadcast complete message
|
|
||||||
await ws_manager.broadcast_init_progress({
|
|
||||||
'stage': 'finalizing',
|
|
||||||
'progress': 100,
|
|
||||||
'status': 'complete',
|
|
||||||
'details': f"Loaded {len(self._cache.raw_data)} {self.model_type} files from cache.",
|
|
||||||
'scanner_type': self.model_type,
|
|
||||||
'pageType': page_type
|
|
||||||
})
|
|
||||||
self._is_initializing = False
|
|
||||||
return
|
|
||||||
|
|
||||||
# If cache loading failed, proceed with full scan
|
# If cache loading failed, proceed with full scan
|
||||||
await ws_manager.broadcast_init_progress({
|
await ws_manager.broadcast_init_progress({
|
||||||
'stage': 'scan_folders',
|
'stage': 'scan_folders',
|
||||||
@@ -332,9 +147,6 @@ class ModelScanner:
|
|||||||
|
|
||||||
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
|
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
|
||||||
|
|
||||||
# Save the cache to disk after initialization
|
|
||||||
await self._save_cache_to_disk()
|
|
||||||
|
|
||||||
# Send completion message
|
# Send completion message
|
||||||
await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent
|
await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent
|
||||||
await ws_manager.broadcast_init_progress({
|
await ws_manager.broadcast_init_progress({
|
||||||
@@ -509,40 +321,21 @@ class ModelScanner:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
force_refresh: Whether to refresh the cache
|
force_refresh: Whether to refresh the cache
|
||||||
rebuild_cache: Whether to completely rebuild the cache by reloading from disk first
|
rebuild_cache: Whether to completely rebuild the cache
|
||||||
"""
|
"""
|
||||||
# If cache is not initialized, return an empty cache
|
# If cache is not initialized, return an empty cache
|
||||||
# Actual initialization should be done via initialize_in_background
|
# Actual initialization should be done via initialize_in_background
|
||||||
if self._cache is None and not force_refresh:
|
if self._cache is None and not force_refresh:
|
||||||
return ModelCache(
|
return ModelCache(
|
||||||
raw_data=[],
|
raw_data=[],
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[],
|
|
||||||
folders=[]
|
folders=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
# If force refresh is requested, initialize the cache directly
|
# If force refresh is requested, initialize the cache directly
|
||||||
if force_refresh:
|
if force_refresh:
|
||||||
# If rebuild_cache is True, try to reload from disk before reconciliation
|
|
||||||
if rebuild_cache:
|
if rebuild_cache:
|
||||||
logger.info(f"{self.model_type.capitalize()} Scanner: Attempting to rebuild cache from disk...")
|
|
||||||
cache_loaded = await self._load_cache_from_disk()
|
|
||||||
if cache_loaded:
|
|
||||||
logger.info(f"{self.model_type.capitalize()} Scanner: Successfully reloaded cache from disk")
|
|
||||||
else:
|
|
||||||
logger.info(f"{self.model_type.capitalize()} Scanner: Could not reload cache from disk, proceeding with complete rebuild")
|
|
||||||
# If loading from disk failed, do a complete rebuild and save to disk
|
|
||||||
await self._initialize_cache()
|
|
||||||
await self._save_cache_to_disk()
|
|
||||||
return self._cache
|
|
||||||
|
|
||||||
if self._cache is None:
|
|
||||||
# For initial creation, do a full initialization
|
|
||||||
await self._initialize_cache()
|
await self._initialize_cache()
|
||||||
# Save the newly built cache
|
|
||||||
await self._save_cache_to_disk()
|
|
||||||
else:
|
else:
|
||||||
# For subsequent refreshes, use fast reconciliation
|
|
||||||
await self._reconcile_cache()
|
await self._reconcile_cache()
|
||||||
|
|
||||||
return self._cache
|
return self._cache
|
||||||
@@ -577,8 +370,6 @@ class ModelScanner:
|
|||||||
# Update cache
|
# Update cache
|
||||||
self._cache = ModelCache(
|
self._cache = ModelCache(
|
||||||
raw_data=raw_data,
|
raw_data=raw_data,
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[],
|
|
||||||
folders=[]
|
folders=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -592,8 +383,6 @@ class ModelScanner:
|
|||||||
if self._cache is None:
|
if self._cache is None:
|
||||||
self._cache = ModelCache(
|
self._cache = ModelCache(
|
||||||
raw_data=[],
|
raw_data=[],
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[],
|
|
||||||
folders=[]
|
folders=[]
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
@@ -735,19 +524,74 @@ class ModelScanner:
|
|||||||
# Resort cache
|
# Resort cache
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
# Save updated cache to disk
|
|
||||||
await self._save_cache_to_disk()
|
|
||||||
|
|
||||||
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
|
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
|
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
self._is_initializing = False # Unset flag
|
self._is_initializing = False # Unset flag
|
||||||
|
|
||||||
# These methods should be implemented in child classes
|
|
||||||
async def scan_all_models(self) -> List[Dict]:
|
async def scan_all_models(self) -> List[Dict]:
|
||||||
"""Scan all model directories and return metadata"""
|
"""Scan all model directories and return metadata"""
|
||||||
raise NotImplementedError("Subclasses must implement scan_all_models")
|
all_models = []
|
||||||
|
|
||||||
|
# Create scan tasks for each directory
|
||||||
|
scan_tasks = []
|
||||||
|
for model_root in self.get_model_roots():
|
||||||
|
task = asyncio.create_task(self._scan_directory(model_root))
|
||||||
|
scan_tasks.append(task)
|
||||||
|
|
||||||
|
# Wait for all tasks to complete
|
||||||
|
for task in scan_tasks:
|
||||||
|
try:
|
||||||
|
models = await task
|
||||||
|
all_models.extend(models)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning directory: {e}")
|
||||||
|
|
||||||
|
return all_models
|
||||||
|
|
||||||
|
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||||
|
"""Scan a single directory for model files"""
|
||||||
|
models = []
|
||||||
|
original_root = root_path # Save original root path
|
||||||
|
|
||||||
|
async def scan_recursive(path: str, visited_paths: set):
|
||||||
|
"""Recursively scan directory, avoiding circular symlinks"""
|
||||||
|
try:
|
||||||
|
real_path = os.path.realpath(path)
|
||||||
|
if real_path in visited_paths:
|
||||||
|
logger.debug(f"Skipping already visited path: {path}")
|
||||||
|
return
|
||||||
|
visited_paths.add(real_path)
|
||||||
|
|
||||||
|
with os.scandir(path) as it:
|
||||||
|
entries = list(it)
|
||||||
|
for entry in entries:
|
||||||
|
try:
|
||||||
|
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
||||||
|
# Use original path instead of real path
|
||||||
|
file_path = entry.path.replace(os.sep, "/")
|
||||||
|
await self._process_single_file(file_path, original_root, models)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
elif entry.is_dir(follow_symlinks=True):
|
||||||
|
# For directories, continue scanning with original path
|
||||||
|
await scan_recursive(entry.path, visited_paths)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning {path}: {e}")
|
||||||
|
|
||||||
|
await scan_recursive(root_path, set())
|
||||||
|
return models
|
||||||
|
|
||||||
|
async def _process_single_file(self, file_path: str, root_path: str, models: list):
|
||||||
|
"""Process a single file and add to results list"""
|
||||||
|
try:
|
||||||
|
result = await self._process_model_file(file_path, root_path)
|
||||||
|
if result:
|
||||||
|
models.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing {file_path}: {e}")
|
||||||
|
|
||||||
def is_initializing(self) -> bool:
|
def is_initializing(self) -> bool:
|
||||||
"""Check if the scanner is currently initializing"""
|
"""Check if the scanner is currently initializing"""
|
||||||
@@ -931,7 +775,7 @@ class ModelScanner:
|
|||||||
logger.error(f"Error processing {file_path}: {e}")
|
logger.error(f"Error processing {file_path}: {e}")
|
||||||
|
|
||||||
async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
|
async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
|
||||||
"""Add a model to the cache and save to disk
|
"""Add a model to the cache
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
metadata_dict: The model metadata dictionary
|
metadata_dict: The model metadata dictionary
|
||||||
@@ -960,9 +804,6 @@ class ModelScanner:
|
|||||||
|
|
||||||
# Update the hash index
|
# Update the hash index
|
||||||
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||||
|
|
||||||
# Save to disk
|
|
||||||
await self._save_cache_to_disk()
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding model to cache: {e}")
|
logger.error(f"Error adding model to cache: {e}")
|
||||||
@@ -1102,9 +943,6 @@ class ModelScanner:
|
|||||||
|
|
||||||
await cache.resort()
|
await cache.resort()
|
||||||
|
|
||||||
# Save the updated cache
|
|
||||||
await self._save_cache_to_disk()
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def has_hash(self, sha256: str) -> bool:
|
def has_hash(self, sha256: str) -> bool:
|
||||||
@@ -1198,11 +1036,7 @@ class ModelScanner:
|
|||||||
if self._cache is None:
|
if self._cache is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
updated = await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
|
return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
|
||||||
if updated:
|
|
||||||
# Save updated cache to disk
|
|
||||||
await self._save_cache_to_disk()
|
|
||||||
return updated
|
|
||||||
|
|
||||||
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
|
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
|
||||||
"""Delete multiple models and update cache in a batch operation
|
"""Delete multiple models and update cache in a batch operation
|
||||||
@@ -1334,9 +1168,6 @@ class ModelScanner:
|
|||||||
# Resort cache
|
# Resort cache
|
||||||
await self._cache.resort()
|
await self._cache.resort()
|
||||||
|
|
||||||
# Save updated cache to disk
|
|
||||||
await self._save_cache_to_disk()
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
137
py/services/model_service_factory.py
Normal file
137
py/services/model_service_factory.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
from typing import Dict, Type, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ModelServiceFactory:
|
||||||
|
"""Factory for managing model services and routes"""
|
||||||
|
|
||||||
|
_services: Dict[str, Type] = {}
|
||||||
|
_routes: Dict[str, Type] = {}
|
||||||
|
_initialized_services: Dict[str, Any] = {}
|
||||||
|
_initialized_routes: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_model_type(cls, model_type: str, service_class: Type, route_class: Type):
|
||||||
|
"""Register a new model type with its service and route classes
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: The model type identifier (e.g., 'lora', 'checkpoint')
|
||||||
|
service_class: The service class for this model type
|
||||||
|
route_class: The route class for this model type
|
||||||
|
"""
|
||||||
|
cls._services[model_type] = service_class
|
||||||
|
cls._routes[model_type] = route_class
|
||||||
|
logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_service_class(cls, model_type: str) -> Type:
|
||||||
|
"""Get service class for a model type
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: The model type identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The service class for the model type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If model type is not registered
|
||||||
|
"""
|
||||||
|
if model_type not in cls._services:
|
||||||
|
raise ValueError(f"Unknown model type: {model_type}")
|
||||||
|
return cls._services[model_type]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_route_class(cls, model_type: str) -> Type:
|
||||||
|
"""Get route class for a model type
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: The model type identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The route class for the model type
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If model type is not registered
|
||||||
|
"""
|
||||||
|
if model_type not in cls._routes:
|
||||||
|
raise ValueError(f"Unknown model type: {model_type}")
|
||||||
|
return cls._routes[model_type]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_route_instance(cls, model_type: str):
|
||||||
|
"""Get or create route instance for a model type
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: The model type identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The route instance for the model type
|
||||||
|
"""
|
||||||
|
if model_type not in cls._initialized_routes:
|
||||||
|
route_class = cls.get_route_class(model_type)
|
||||||
|
cls._initialized_routes[model_type] = route_class()
|
||||||
|
return cls._initialized_routes[model_type]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_all_routes(cls, app):
|
||||||
|
"""Setup routes for all registered model types
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: The aiohttp application instance
|
||||||
|
"""
|
||||||
|
logger.info(f"Setting up routes for {len(cls._services)} registered model types")
|
||||||
|
|
||||||
|
for model_type in cls._services.keys():
|
||||||
|
try:
|
||||||
|
routes_instance = cls.get_route_instance(model_type)
|
||||||
|
routes_instance.setup_routes(app)
|
||||||
|
logger.info(f"Successfully set up routes for {model_type}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_registered_types(cls) -> list:
|
||||||
|
"""Get list of all registered model types
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of registered model type identifiers
|
||||||
|
"""
|
||||||
|
return list(cls._services.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_registered(cls, model_type: str) -> bool:
|
||||||
|
"""Check if a model type is registered
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: The model type identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the model type is registered, False otherwise
|
||||||
|
"""
|
||||||
|
return model_type in cls._services
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def clear_registrations(cls):
|
||||||
|
"""Clear all registrations - mainly for testing purposes"""
|
||||||
|
cls._services.clear()
|
||||||
|
cls._routes.clear()
|
||||||
|
cls._initialized_services.clear()
|
||||||
|
cls._initialized_routes.clear()
|
||||||
|
logger.info("Cleared all model type registrations")
|
||||||
|
|
||||||
|
|
||||||
|
def register_default_model_types():
|
||||||
|
"""Register the default model types (LoRA and Checkpoint)"""
|
||||||
|
from ..services.lora_service import LoraService
|
||||||
|
from ..services.checkpoint_service import CheckpointService
|
||||||
|
from ..routes.lora_routes import LoraRoutes
|
||||||
|
from ..routes.checkpoint_routes import CheckpointRoutes
|
||||||
|
|
||||||
|
# Register LoRA model type
|
||||||
|
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
|
||||||
|
|
||||||
|
# Register Checkpoint model type
|
||||||
|
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
|
||||||
|
|
||||||
|
logger.info("Registered default model types: lora, checkpoint")
|
||||||
@@ -393,8 +393,8 @@ class RecipeScanner:
|
|||||||
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):
|
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):
|
||||||
hash_value = lora['hash']
|
hash_value = lora['hash']
|
||||||
|
|
||||||
if self._lora_scanner.has_lora_hash(hash_value):
|
if self._lora_scanner.has_hash(hash_value):
|
||||||
lora_path = self._lora_scanner.get_lora_path_by_hash(hash_value)
|
lora_path = self._lora_scanner.get_path_by_hash(hash_value)
|
||||||
if lora_path:
|
if lora_path:
|
||||||
file_name = os.path.splitext(os.path.basename(lora_path))[0]
|
file_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||||
lora['file_name'] = file_name
|
lora['file_name'] = file_name
|
||||||
@@ -465,7 +465,7 @@ class RecipeScanner:
|
|||||||
# Count occurrences of each base model
|
# Count occurrences of each base model
|
||||||
for lora in loras:
|
for lora in loras:
|
||||||
if 'hash' in lora:
|
if 'hash' in lora:
|
||||||
lora_path = self._lora_scanner.get_lora_path_by_hash(lora['hash'])
|
lora_path = self._lora_scanner.get_path_by_hash(lora['hash'])
|
||||||
if lora_path:
|
if lora_path:
|
||||||
base_model = await self._get_base_model_for_lora(lora_path)
|
base_model = await self._get_base_model_for_lora(lora_path)
|
||||||
if base_model:
|
if base_model:
|
||||||
@@ -603,9 +603,9 @@ class RecipeScanner:
|
|||||||
if 'loras' in item:
|
if 'loras' in item:
|
||||||
for lora in item['loras']:
|
for lora in item['loras']:
|
||||||
if 'hash' in lora and lora['hash']:
|
if 'hash' in lora and lora['hash']:
|
||||||
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora['hash'].lower())
|
lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower())
|
||||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
||||||
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
|
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower())
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
'items': paginated_items,
|
'items': paginated_items,
|
||||||
@@ -655,9 +655,9 @@ class RecipeScanner:
|
|||||||
for lora in formatted_recipe['loras']:
|
for lora in formatted_recipe['loras']:
|
||||||
if 'hash' in lora and lora['hash']:
|
if 'hash' in lora and lora['hash']:
|
||||||
lora_hash = lora['hash'].lower()
|
lora_hash = lora['hash'].lower()
|
||||||
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora_hash)
|
lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash)
|
||||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
|
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
|
||||||
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora_hash)
|
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash)
|
||||||
|
|
||||||
return formatted_recipe
|
return formatted_recipe
|
||||||
|
|
||||||
|
|||||||
@@ -7,106 +7,176 @@ logger = logging.getLogger(__name__)
|
|||||||
T = TypeVar('T') # Define a type variable for service types
|
T = TypeVar('T') # Define a type variable for service types
|
||||||
|
|
||||||
class ServiceRegistry:
|
class ServiceRegistry:
|
||||||
"""Centralized registry for service singletons"""
|
"""Central registry for managing singleton services"""
|
||||||
|
|
||||||
_instance = None
|
|
||||||
_services: Dict[str, Any] = {}
|
_services: Dict[str, Any] = {}
|
||||||
_lock = asyncio.Lock()
|
_locks: Dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls):
|
async def register_service(cls, name: str, service: Any) -> None:
|
||||||
"""Get singleton instance of the registry"""
|
"""Register a service instance with the registry
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = cls()
|
Args:
|
||||||
return cls._instance
|
name: Service name identifier
|
||||||
|
service: Service instance to register
|
||||||
|
"""
|
||||||
|
cls._services[name] = service
|
||||||
|
logger.debug(f"Registered service: {name}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def register_service(cls, service_name: str, service_instance: Any) -> None:
|
async def get_service(cls, name: str) -> Optional[Any]:
|
||||||
"""Register a service instance with the registry"""
|
"""Get a service instance by name
|
||||||
registry = cls.get_instance()
|
|
||||||
async with cls._lock:
|
Args:
|
||||||
registry._services[service_name] = service_instance
|
name: Service name identifier
|
||||||
logger.debug(f"Registered service: {service_name}")
|
|
||||||
|
Returns:
|
||||||
|
Service instance or None if not found
|
||||||
|
"""
|
||||||
|
return cls._services.get(name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_service(cls, service_name: str) -> Any:
|
def _get_lock(cls, name: str) -> asyncio.Lock:
|
||||||
"""Get a service instance by name"""
|
"""Get or create a lock for a service
|
||||||
registry = cls.get_instance()
|
|
||||||
async with cls._lock:
|
|
||||||
if service_name not in registry._services:
|
|
||||||
logger.debug(f"Service {service_name} not found in registry")
|
|
||||||
return None
|
|
||||||
return registry._services[service_name]
|
|
||||||
|
|
||||||
@classmethod
|
Args:
|
||||||
def get_service_sync(cls, service_name: str) -> Any:
|
name: Service name identifier
|
||||||
"""Get a service instance by name (synchronous version)"""
|
|
||||||
registry = cls.get_instance()
|
Returns:
|
||||||
if service_name not in registry._services:
|
AsyncIO lock for the service
|
||||||
logger.debug(f"Service {service_name} not found in registry")
|
"""
|
||||||
return None
|
if name not in cls._locks:
|
||||||
return registry._services[service_name]
|
cls._locks[name] = asyncio.Lock()
|
||||||
|
return cls._locks[name]
|
||||||
|
|
||||||
# Convenience methods for common services
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_lora_scanner(cls):
|
async def get_lora_scanner(cls):
|
||||||
"""Get the LoraScanner instance"""
|
"""Get or create LoRA scanner instance"""
|
||||||
from .lora_scanner import LoraScanner
|
service_name = "lora_scanner"
|
||||||
scanner = await cls.get_service("lora_scanner")
|
|
||||||
if scanner is None:
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from .lora_scanner import LoraScanner
|
||||||
|
|
||||||
scanner = await LoraScanner.get_instance()
|
scanner = await LoraScanner.get_instance()
|
||||||
await cls.register_service("lora_scanner", scanner)
|
cls._services[service_name] = scanner
|
||||||
return scanner
|
logger.debug(f"Created and registered {service_name}")
|
||||||
|
return scanner
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_checkpoint_scanner(cls):
|
async def get_checkpoint_scanner(cls):
|
||||||
"""Get the CheckpointScanner instance"""
|
"""Get or create Checkpoint scanner instance"""
|
||||||
from .checkpoint_scanner import CheckpointScanner
|
service_name = "checkpoint_scanner"
|
||||||
scanner = await cls.get_service("checkpoint_scanner")
|
|
||||||
if scanner is None:
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from .checkpoint_scanner import CheckpointScanner
|
||||||
|
|
||||||
scanner = await CheckpointScanner.get_instance()
|
scanner = await CheckpointScanner.get_instance()
|
||||||
await cls.register_service("checkpoint_scanner", scanner)
|
cls._services[service_name] = scanner
|
||||||
return scanner
|
logger.debug(f"Created and registered {service_name}")
|
||||||
|
return scanner
|
||||||
@classmethod
|
|
||||||
async def get_civitai_client(cls):
|
|
||||||
"""Get the CivitaiClient instance"""
|
|
||||||
from .civitai_client import CivitaiClient
|
|
||||||
client = await cls.get_service("civitai_client")
|
|
||||||
if client is None:
|
|
||||||
client = await CivitaiClient.get_instance()
|
|
||||||
await cls.register_service("civitai_client", client)
|
|
||||||
return client
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_download_manager(cls):
|
|
||||||
"""Get the DownloadManager instance"""
|
|
||||||
from .download_manager import DownloadManager
|
|
||||||
manager = await cls.get_service("download_manager")
|
|
||||||
if manager is None:
|
|
||||||
manager = await DownloadManager.get_instance()
|
|
||||||
await cls.register_service("download_manager", manager)
|
|
||||||
return manager
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_recipe_scanner(cls):
|
async def get_recipe_scanner(cls):
|
||||||
"""Get the RecipeScanner instance"""
|
"""Get or create Recipe scanner instance"""
|
||||||
from .recipe_scanner import RecipeScanner
|
service_name = "recipe_scanner"
|
||||||
scanner = await cls.get_service("recipe_scanner")
|
|
||||||
if scanner is None:
|
if service_name in cls._services:
|
||||||
lora_scanner = await cls.get_lora_scanner()
|
return cls._services[service_name]
|
||||||
scanner = RecipeScanner(lora_scanner)
|
|
||||||
await cls.register_service("recipe_scanner", scanner)
|
async with cls._get_lock(service_name):
|
||||||
return scanner
|
# Double-check after acquiring lock
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from .recipe_scanner import RecipeScanner
|
||||||
|
|
||||||
|
scanner = await RecipeScanner.get_instance()
|
||||||
|
cls._services[service_name] = scanner
|
||||||
|
logger.debug(f"Created and registered {service_name}")
|
||||||
|
return scanner
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_civitai_client(cls):
|
||||||
|
"""Get or create CivitAI client instance"""
|
||||||
|
service_name = "civitai_client"
|
||||||
|
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from .civitai_client import CivitaiClient
|
||||||
|
|
||||||
|
client = await CivitaiClient.get_instance()
|
||||||
|
cls._services[service_name] = client
|
||||||
|
logger.debug(f"Created and registered {service_name}")
|
||||||
|
return client
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_download_manager(cls):
|
||||||
|
"""Get or create Download manager instance"""
|
||||||
|
service_name = "download_manager"
|
||||||
|
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from .download_manager import DownloadManager
|
||||||
|
|
||||||
|
manager = DownloadManager()
|
||||||
|
cls._services[service_name] = manager
|
||||||
|
logger.debug(f"Created and registered {service_name}")
|
||||||
|
return manager
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_websocket_manager(cls):
|
async def get_websocket_manager(cls):
|
||||||
"""Get the WebSocketManager instance"""
|
"""Get or create WebSocket manager instance"""
|
||||||
from .websocket_manager import ws_manager
|
service_name = "websocket_manager"
|
||||||
manager = await cls.get_service("websocket_manager")
|
|
||||||
if manager is None:
|
if service_name in cls._services:
|
||||||
# ws_manager is already a global instance in websocket_manager.py
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
async with cls._get_lock(service_name):
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
if service_name in cls._services:
|
||||||
|
return cls._services[service_name]
|
||||||
|
|
||||||
|
# Import here to avoid circular imports
|
||||||
from .websocket_manager import ws_manager
|
from .websocket_manager import ws_manager
|
||||||
await cls.register_service("websocket_manager", ws_manager)
|
|
||||||
manager = ws_manager
|
cls._services[service_name] = ws_manager
|
||||||
return manager
|
logger.debug(f"Registered {service_name}")
|
||||||
|
return ws_manager
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def clear_services(cls):
|
||||||
|
"""Clear all registered services - mainly for testing"""
|
||||||
|
cls._services.clear()
|
||||||
|
cls._locks.clear()
|
||||||
|
logger.info("Cleared all registered services")
|
||||||
@@ -566,9 +566,10 @@ class ModelRouteUtils:
|
|||||||
return web.Response(text=str(e), status=500)
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
async def handle_download_model(request: web.Request) -> web.Response:
|
||||||
"""Handle model download request"""
|
"""Handle model download request"""
|
||||||
try:
|
try:
|
||||||
|
download_manager = await ServiceRegistry.get_download_manager()
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
|
|
||||||
# Get or generate a download ID
|
# Get or generate a download ID
|
||||||
@@ -663,17 +664,17 @@ class ModelRouteUtils:
|
|||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_cancel_download(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
async def handle_cancel_download(request: web.Request) -> web.Response:
|
||||||
"""Handle cancellation of a download task
|
"""Handle cancellation of a download task
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The aiohttp request
|
request: The aiohttp request
|
||||||
download_manager: The download manager instance
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
web.Response: The HTTP response
|
web.Response: The HTTP response
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
download_manager = await ServiceRegistry.get_download_manager()
|
||||||
download_id = request.match_info.get('download_id')
|
download_id = request.match_info.get('download_id')
|
||||||
if not download_id:
|
if not download_id:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
@@ -701,17 +702,17 @@ class ModelRouteUtils:
|
|||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_list_downloads(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
async def handle_list_downloads(request: web.Request) -> web.Response:
|
||||||
"""Get list of active downloads
|
"""Get list of active downloads
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The aiohttp request
|
request: The aiohttp request
|
||||||
download_manager: The download manager instance
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
web.Response: The HTTP response with list of downloads
|
web.Response: The HTTP response with list of downloads
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
download_manager = await ServiceRegistry.get_download_manager()
|
||||||
result = await download_manager.get_active_downloads()
|
result = await download_manager.get_active_downloads()
|
||||||
return web.json_response(result)
|
return web.json_response(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1047,3 +1048,56 @@ class ModelRouteUtils:
|
|||||||
'success': False,
|
'success': False,
|
||||||
'error': str(e)
|
'error': str(e)
|
||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def handle_save_metadata(request: web.Request, scanner) -> web.Response:
|
||||||
|
"""Handle saving metadata updates
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The aiohttp request
|
||||||
|
scanner: The model scanner instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
web.Response: The HTTP response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_path = data.get('file_path')
|
||||||
|
if not file_path:
|
||||||
|
return web.Response(text='File path is required', status=400)
|
||||||
|
|
||||||
|
# Remove file path from data to avoid saving it
|
||||||
|
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||||
|
|
||||||
|
# Get metadata file path
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||||
|
|
||||||
|
# Load existing metadata
|
||||||
|
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||||
|
|
||||||
|
# Handle nested updates (for civitai.trainedWords)
|
||||||
|
for key, value in metadata_updates.items():
|
||||||
|
if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict):
|
||||||
|
# Deep update for nested dictionaries
|
||||||
|
for nested_key, nested_value in value.items():
|
||||||
|
metadata[key][nested_key] = nested_value
|
||||||
|
else:
|
||||||
|
# Regular update for top-level keys
|
||||||
|
metadata[key] = value
|
||||||
|
|
||||||
|
# Save updated metadata
|
||||||
|
await MetadataManager.save_metadata(file_path, metadata)
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
await scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||||
|
|
||||||
|
# If model_name was updated, resort the cache
|
||||||
|
if 'model_name' in metadata_updates:
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
await cache.resort()
|
||||||
|
|
||||||
|
return web.json_response({'success': True})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving metadata: {e}", exc_info=True)
|
||||||
|
return web.Response(text=str(e), status=500)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ dependencies = [
|
|||||||
"requests",
|
"requests",
|
||||||
"toml",
|
"toml",
|
||||||
"natsort",
|
"natsort",
|
||||||
"msgpack"
|
"GitPython"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
|||||||
@@ -9,5 +9,5 @@ requests
|
|||||||
toml
|
toml
|
||||||
numpy
|
numpy
|
||||||
natsort
|
natsort
|
||||||
msgpack
|
|
||||||
pyyaml
|
pyyaml
|
||||||
|
GitPython
|
||||||
|
|||||||
@@ -106,6 +106,22 @@ logger = logging.getLogger("lora-manager-standalone")
|
|||||||
# Configure aiohttp access logger to be less verbose
|
# Configure aiohttp access logger to be less verbose
|
||||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
# Add specific suppression for connection reset errors
|
||||||
|
class ConnectionResetFilter(logging.Filter):
|
||||||
|
def filter(self, record):
|
||||||
|
# Filter out connection reset errors that are not critical
|
||||||
|
if "ConnectionResetError" in str(record.getMessage()):
|
||||||
|
return False
|
||||||
|
if "_call_connection_lost" in str(record.getMessage()):
|
||||||
|
return False
|
||||||
|
if "WinError 10054" in str(record.getMessage()):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Apply the filter to asyncio logger
|
||||||
|
asyncio_logger = logging.getLogger("asyncio")
|
||||||
|
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||||
|
|
||||||
# Now we can import the global config from our local modules
|
# Now we can import the global config from our local modules
|
||||||
from py.config import config
|
from py.config import config
|
||||||
|
|
||||||
@@ -119,17 +135,6 @@ class StandaloneServer:
|
|||||||
# Ensure the app's access logger is configured to reduce verbosity
|
# Ensure the app's access logger is configured to reduce verbosity
|
||||||
self.app._subapps = [] # Ensure this exists to avoid AttributeError
|
self.app._subapps = [] # Ensure this exists to avoid AttributeError
|
||||||
|
|
||||||
# Configure access logging for the app
|
|
||||||
self.app.on_startup.append(self._configure_access_logger)
|
|
||||||
|
|
||||||
async def _configure_access_logger(self, app):
|
|
||||||
"""Configure access logger to reduce verbosity"""
|
|
||||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
# If using aiohttp>=3.8.0, configure access logger through app directly
|
|
||||||
if hasattr(app, 'access_logger'):
|
|
||||||
app.access_logger.setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
"""Set up the standalone server"""
|
"""Set up the standalone server"""
|
||||||
# Create placeholders for compatibility with ComfyUI's implementation
|
# Create placeholders for compatibility with ComfyUI's implementation
|
||||||
@@ -219,9 +224,6 @@ class StandaloneLoraManager(LoraManager):
|
|||||||
# Store app in a global-like location for compatibility
|
# Store app in a global-like location for compatibility
|
||||||
sys.modules['server'].PromptServer.instance = server_instance
|
sys.modules['server'].PromptServer.instance = server_instance
|
||||||
|
|
||||||
# Configure aiohttp access logger to be less verbose
|
|
||||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
added_targets = set() # Track already added target paths
|
added_targets = set() # Track already added target paths
|
||||||
|
|
||||||
# Add static routes for each lora root
|
# Add static routes for each lora root
|
||||||
@@ -314,35 +316,39 @@ class StandaloneLoraManager(LoraManager):
|
|||||||
app.router.add_static('/loras_static', config.static_path)
|
app.router.add_static('/loras_static', config.static_path)
|
||||||
|
|
||||||
# Setup feature routes
|
# Setup feature routes
|
||||||
from py.routes.lora_routes import LoraRoutes
|
from py.services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||||
from py.routes.api_routes import ApiRoutes
|
|
||||||
from py.routes.recipe_routes import RecipeRoutes
|
from py.routes.recipe_routes import RecipeRoutes
|
||||||
from py.routes.checkpoints_routes import CheckpointsRoutes
|
|
||||||
from py.routes.update_routes import UpdateRoutes
|
from py.routes.update_routes import UpdateRoutes
|
||||||
from py.routes.misc_routes import MiscRoutes
|
from py.routes.misc_routes import MiscRoutes
|
||||||
from py.routes.example_images_routes import ExampleImagesRoutes
|
from py.routes.example_images_routes import ExampleImagesRoutes
|
||||||
from py.routes.stats_routes import StatsRoutes
|
from py.routes.stats_routes import StatsRoutes
|
||||||
|
from py.services.websocket_manager import ws_manager
|
||||||
|
|
||||||
|
|
||||||
|
register_default_model_types()
|
||||||
|
|
||||||
|
# Setup all model routes using the factory
|
||||||
|
ModelServiceFactory.setup_all_routes(app)
|
||||||
|
|
||||||
lora_routes = LoraRoutes()
|
|
||||||
checkpoints_routes = CheckpointsRoutes()
|
|
||||||
stats_routes = StatsRoutes()
|
stats_routes = StatsRoutes()
|
||||||
|
|
||||||
# Initialize routes
|
# Initialize routes
|
||||||
lora_routes.setup_routes(app)
|
|
||||||
checkpoints_routes.setup_routes(app)
|
|
||||||
stats_routes.setup_routes(app)
|
stats_routes.setup_routes(app)
|
||||||
ApiRoutes.setup_routes(app)
|
|
||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
UpdateRoutes.setup_routes(app)
|
UpdateRoutes.setup_routes(app)
|
||||||
MiscRoutes.setup_routes(app)
|
MiscRoutes.setup_routes(app)
|
||||||
ExampleImagesRoutes.setup_routes(app)
|
ExampleImagesRoutes.setup_routes(app)
|
||||||
|
|
||||||
|
# Setup WebSocket routes that are shared across all model types
|
||||||
|
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||||
|
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||||
|
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||||
|
|
||||||
# Schedule service initialization
|
# Schedule service initialization
|
||||||
app.on_startup.append(lambda app: cls._initialize_services())
|
app.on_startup.append(lambda app: cls._initialize_services())
|
||||||
|
|
||||||
# Add cleanup
|
# Add cleanup
|
||||||
app.on_shutdown.append(cls._cleanup)
|
app.on_shutdown.append(cls._cleanup)
|
||||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
"""Parse command line arguments"""
|
"""Parse command line arguments"""
|
||||||
@@ -367,9 +373,6 @@ async def main():
|
|||||||
# Set log level
|
# Set log level
|
||||||
logging.getLogger().setLevel(getattr(logging, args.log_level))
|
logging.getLogger().setLevel(getattr(logging, args.log_level))
|
||||||
|
|
||||||
# Explicitly configure aiohttp access logger regardless of selected log level
|
|
||||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
# Create the server instance
|
# Create the server instance
|
||||||
server = StandaloneServer()
|
server = StandaloneServer()
|
||||||
|
|
||||||
|
|||||||
@@ -50,8 +50,8 @@ html, body {
|
|||||||
--lora-border: oklch(90% 0.02 256 / 0.15);
|
--lora-border: oklch(90% 0.02 256 / 0.15);
|
||||||
--lora-text: oklch(95% 0.02 256);
|
--lora-text: oklch(95% 0.02 256);
|
||||||
--lora-error: oklch(75% 0.32 29);
|
--lora-error: oklch(75% 0.32 29);
|
||||||
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h)); /* Modified to be used with oklch() */
|
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
|
||||||
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h)); /* New green success color */
|
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h));
|
||||||
|
|
||||||
/* Spacing Scale */
|
/* Spacing Scale */
|
||||||
--space-1: calc(8px * 1);
|
--space-1: calc(8px * 1);
|
||||||
|
|||||||
@@ -223,11 +223,6 @@
|
|||||||
opacity: 1;
|
opacity: 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
.update-badge.hidden,
|
|
||||||
.update-badge:not(.visible) {
|
|
||||||
opacity: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Mobile adjustments */
|
/* Mobile adjustments */
|
||||||
@media (max-width: 768px) {
|
@media (max-width: 768px) {
|
||||||
.app-title {
|
.app-title {
|
||||||
|
|||||||
@@ -172,6 +172,91 @@ body.modal-open {
|
|||||||
opacity: 1;
|
opacity: 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Update Modal specific styles */
|
||||||
|
.update-actions {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: var(--space-2);
|
||||||
|
align-items: stretch;
|
||||||
|
flex-wrap: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.update-link {
|
||||||
|
color: var(--lora-accent);
|
||||||
|
text-decoration: none;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 8px;
|
||||||
|
font-size: 0.95em;
|
||||||
|
}
|
||||||
|
|
||||||
|
.update-link:hover {
|
||||||
|
text-decoration: underline;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Update progress styles */
|
||||||
|
.update-progress {
|
||||||
|
background: rgba(0, 0, 0, 0.03);
|
||||||
|
border: 1px solid var(--lora-border);
|
||||||
|
border-radius: var(--border-radius-sm);
|
||||||
|
padding: var(--space-2);
|
||||||
|
margin: var(--space-2) 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-theme="dark"] .update-progress {
|
||||||
|
background: rgba(255, 255, 255, 0.03);
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-info {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: var(--space-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-text {
|
||||||
|
font-size: 0.9em;
|
||||||
|
color: var(--text-color);
|
||||||
|
opacity: 0.8;
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-bar {
|
||||||
|
width: 100%;
|
||||||
|
height: 8px;
|
||||||
|
background-color: rgba(0, 0, 0, 0.1);
|
||||||
|
border-radius: 4px;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-theme="dark"] .progress-bar {
|
||||||
|
background-color: rgba(255, 255, 255, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.progress-fill {
|
||||||
|
height: 100%;
|
||||||
|
background-color: var(--lora-accent);
|
||||||
|
width: 0%;
|
||||||
|
transition: width 0.3s ease;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Update button states */
|
||||||
|
#updateBtn {
|
||||||
|
min-width: 120px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#updateBtn.updating {
|
||||||
|
background-color: var(--lora-warning);
|
||||||
|
cursor: not-allowed;
|
||||||
|
}
|
||||||
|
|
||||||
|
#updateBtn.success {
|
||||||
|
background-color: var(--lora-success);
|
||||||
|
}
|
||||||
|
|
||||||
|
#updateBtn.error {
|
||||||
|
background-color: var(--lora-error);
|
||||||
|
}
|
||||||
|
|
||||||
/* Settings styles */
|
/* Settings styles */
|
||||||
.settings-toggle {
|
.settings-toggle {
|
||||||
width: 36px;
|
width: 36px;
|
||||||
|
|||||||
@@ -182,6 +182,31 @@
|
|||||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
|
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Style for optgroups */
|
||||||
|
.control-group select optgroup {
|
||||||
|
font-weight: 600;
|
||||||
|
font-style: normal;
|
||||||
|
color: var(--text-color);
|
||||||
|
background-color: var(--card-bg);
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-group select option {
|
||||||
|
padding: 4px 8px;
|
||||||
|
background-color: var(--card-bg);
|
||||||
|
color: var(--text-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Dark theme optgroup styling */
|
||||||
|
[data-theme="dark"] .control-group select optgroup {
|
||||||
|
background-color: var(--card-bg);
|
||||||
|
color: var(--text-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
[data-theme="dark"] .control-group select option {
|
||||||
|
background-color: var(--card-bg);
|
||||||
|
color: var(--text-color);
|
||||||
|
}
|
||||||
|
|
||||||
.control-group select:hover {
|
.control-group select:hover {
|
||||||
border-color: var(--lora-accent);
|
border-color: var(--lora-accent);
|
||||||
background-color: var(--bg-color);
|
background-color: var(--bg-color);
|
||||||
|
|||||||
@@ -54,25 +54,16 @@ export async function fetchModelsPage(options = {}) {
|
|||||||
if (pageState.filters) {
|
if (pageState.filters) {
|
||||||
// Handle tags filters
|
// Handle tags filters
|
||||||
if (pageState.filters.tags && pageState.filters.tags.length > 0) {
|
if (pageState.filters.tags && pageState.filters.tags.length > 0) {
|
||||||
// Checkpoints API expects individual 'tag' parameters, Loras API expects comma-separated 'tags'
|
pageState.filters.tags.forEach(tag => {
|
||||||
if (modelType === 'checkpoint') {
|
params.append('tag', tag);
|
||||||
pageState.filters.tags.forEach(tag => {
|
});
|
||||||
params.append('tag', tag);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
params.append('tags', pageState.filters.tags.join(','));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle base model filters
|
// Handle base model filters
|
||||||
if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) {
|
if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) {
|
||||||
if (modelType === 'checkpoint') {
|
pageState.filters.baseModel.forEach(model => {
|
||||||
pageState.filters.baseModel.forEach(model => {
|
params.append('base_model', model);
|
||||||
params.append('base_model', model);
|
});
|
||||||
});
|
|
||||||
} else {
|
|
||||||
params.append('base_models', pageState.filters.baseModel.join(','));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,7 +268,7 @@ export async function deleteModel(filePath, modelType = 'lora') {
|
|||||||
|
|
||||||
const endpoint = modelType === 'checkpoint'
|
const endpoint = modelType === 'checkpoint'
|
||||||
? '/api/checkpoints/delete'
|
? '/api/checkpoints/delete'
|
||||||
: '/api/delete_model';
|
: '/api/loras/delete';
|
||||||
|
|
||||||
const response = await fetch(endpoint, {
|
const response = await fetch(endpoint, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
@@ -454,7 +445,7 @@ export async function refreshSingleModelMetadata(filePath, modelType = 'lora') {
|
|||||||
|
|
||||||
const endpoint = modelType === 'checkpoint'
|
const endpoint = modelType === 'checkpoint'
|
||||||
? '/api/checkpoints/fetch-civitai'
|
? '/api/checkpoints/fetch-civitai'
|
||||||
: '/api/fetch-civitai';
|
: '/api/loras/fetch-civitai';
|
||||||
|
|
||||||
const response = await fetch(endpoint, {
|
const response = await fetch(endpoint, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
@@ -557,7 +548,7 @@ export async function uploadPreview(filePath, file, modelType = 'lora', nsfwLeve
|
|||||||
// Set endpoint based on model type
|
// Set endpoint based on model type
|
||||||
const endpoint = modelType === 'checkpoint'
|
const endpoint = modelType === 'checkpoint'
|
||||||
? '/api/checkpoints/replace-preview'
|
? '/api/checkpoints/replace-preview'
|
||||||
: '/api/replace_preview';
|
: '/api/loras/replace_preview';
|
||||||
|
|
||||||
const response = await fetch(endpoint, {
|
const response = await fetch(endpoint, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ export async function fetchLorasPage(page = 1, pageSize = 100) {
|
|||||||
export async function fetchCivitai() {
|
export async function fetchCivitai() {
|
||||||
return fetchCivitaiMetadata({
|
return fetchCivitaiMetadata({
|
||||||
modelType: 'lora',
|
modelType: 'lora',
|
||||||
fetchEndpoint: '/api/fetch-all-civitai',
|
fetchEndpoint: '/api/loras/fetch-all-civitai',
|
||||||
resetAndReloadFunction: resetAndReload
|
resetAndReloadFunction: resetAndReload
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ export const ModelContextMenuMixin = {
|
|||||||
|
|
||||||
const endpoint = this.modelType === 'checkpoint' ?
|
const endpoint = this.modelType === 'checkpoint' ?
|
||||||
'/api/checkpoints/relink-civitai' :
|
'/api/checkpoints/relink-civitai' :
|
||||||
'/api/relink-civitai';
|
'/api/loras/relink-civitai';
|
||||||
|
|
||||||
const response = await fetch(endpoint, {
|
const response = await fetch(endpoint, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
// PageControls.js - Manages controls for both LoRAs and Checkpoints pages
|
// PageControls.js - Manages controls for both LoRAs and Checkpoints pages
|
||||||
import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js';
|
import { getCurrentPageState, setCurrentPageType } from '../../state/index.js';
|
||||||
import { getStorageItem, setStorageItem, getSessionItem, setSessionItem } from '../../utils/storageHelpers.js';
|
import { getStorageItem, setStorageItem, getSessionItem, setSessionItem } from '../../utils/storageHelpers.js';
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast } from '../../utils/uiHelpers.js';
|
||||||
|
|
||||||
@@ -41,6 +41,9 @@ export class PageControls {
|
|||||||
this.pageState.isLoading = false;
|
this.pageState.isLoading = false;
|
||||||
this.pageState.hasMore = true;
|
this.pageState.hasMore = true;
|
||||||
|
|
||||||
|
// Set default sort based on page type
|
||||||
|
this.pageState.sortBy = this.pageType === 'loras' ? 'name:asc' : 'name:asc';
|
||||||
|
|
||||||
// Load sort preference
|
// Load sort preference
|
||||||
this.loadSortPreference();
|
this.loadSortPreference();
|
||||||
}
|
}
|
||||||
@@ -326,14 +329,36 @@ export class PageControls {
|
|||||||
loadSortPreference() {
|
loadSortPreference() {
|
||||||
const savedSort = getStorageItem(`${this.pageType}_sort`);
|
const savedSort = getStorageItem(`${this.pageType}_sort`);
|
||||||
if (savedSort) {
|
if (savedSort) {
|
||||||
this.pageState.sortBy = savedSort;
|
// Handle legacy format conversion
|
||||||
|
const convertedSort = this.convertLegacySortFormat(savedSort);
|
||||||
|
this.pageState.sortBy = convertedSort;
|
||||||
const sortSelect = document.getElementById('sortSelect');
|
const sortSelect = document.getElementById('sortSelect');
|
||||||
if (sortSelect) {
|
if (sortSelect) {
|
||||||
sortSelect.value = savedSort;
|
sortSelect.value = convertedSort;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert legacy sort format to new format
|
||||||
|
* @param {string} sortValue - The sort value to convert
|
||||||
|
* @returns {string} - Converted sort value
|
||||||
|
*/
|
||||||
|
convertLegacySortFormat(sortValue) {
|
||||||
|
// Convert old format to new format with direction
|
||||||
|
switch (sortValue) {
|
||||||
|
case 'name':
|
||||||
|
return 'name:asc';
|
||||||
|
case 'date':
|
||||||
|
return 'date:desc'; // Newest first is more intuitive default
|
||||||
|
case 'size':
|
||||||
|
return 'size:desc'; // Largest first is more intuitive default
|
||||||
|
default:
|
||||||
|
// If it's already in new format or unknown, return as is
|
||||||
|
return sortValue.includes(':') ? sortValue : 'name:asc';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Save sort preference to storage
|
* Save sort preference to storage
|
||||||
* @param {string} sortValue - The sort value to save
|
* @param {string} sortValue - The sort value to save
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ export class DownloadManager {
|
|||||||
throw new Error('Invalid Civitai URL format');
|
throw new Error('Invalid Civitai URL format');
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await fetch(`/api/civitai/versions/${this.modelId}`);
|
const response = await fetch(`/api/loras/civitai/versions/${this.modelId}`);
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
const errorData = await response.json().catch(() => ({}));
|
const errorData = await response.json().catch(() => ({}));
|
||||||
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
||||||
@@ -254,7 +254,7 @@ export class DownloadManager {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
// Fetch LoRA roots
|
// Fetch LoRA roots
|
||||||
const rootsResponse = await fetch('/api/lora-roots');
|
const rootsResponse = await fetch('/api/loras/roots');
|
||||||
if (!rootsResponse.ok) {
|
if (!rootsResponse.ok) {
|
||||||
throw new Error('Failed to fetch LoRA roots');
|
throw new Error('Failed to fetch LoRA roots');
|
||||||
}
|
}
|
||||||
@@ -272,7 +272,7 @@ export class DownloadManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fetch folders dynamically
|
// Fetch folders dynamically
|
||||||
const foldersResponse = await fetch('/api/folders');
|
const foldersResponse = await fetch('/api/loras/folders');
|
||||||
if (!foldersResponse.ok) {
|
if (!foldersResponse.ok) {
|
||||||
throw new Error('Failed to fetch folders');
|
throw new Error('Failed to fetch folders');
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class MoveManager {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
// Fetch LoRA roots
|
// Fetch LoRA roots
|
||||||
const rootsResponse = await fetch('/api/lora-roots');
|
const rootsResponse = await fetch('/api/loras/roots');
|
||||||
if (!rootsResponse.ok) {
|
if (!rootsResponse.ok) {
|
||||||
throw new Error('Failed to fetch LoRA roots');
|
throw new Error('Failed to fetch LoRA roots');
|
||||||
}
|
}
|
||||||
@@ -96,7 +96,7 @@ class MoveManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fetch folders dynamically
|
// Fetch folders dynamically
|
||||||
const foldersResponse = await fetch('/api/folders');
|
const foldersResponse = await fetch('/api/loras/folders');
|
||||||
if (!foldersResponse.ok) {
|
if (!foldersResponse.ok) {
|
||||||
throw new Error('Failed to fetch folders');
|
throw new Error('Failed to fetch folders');
|
||||||
}
|
}
|
||||||
@@ -190,7 +190,7 @@ class MoveManager {
|
|||||||
|
|
||||||
// Refresh folder tags after successful move
|
// Refresh folder tags after successful move
|
||||||
try {
|
try {
|
||||||
const foldersResponse = await fetch('/api/folders');
|
const foldersResponse = await fetch('/api/loras/folders');
|
||||||
if (foldersResponse.ok) {
|
if (foldersResponse.ok) {
|
||||||
const foldersData = await foldersResponse.json();
|
const foldersData = await foldersResponse.json();
|
||||||
updateFolderTags(foldersData.folders);
|
updateFolderTags(foldersData.folders);
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ export class SettingsManager {
|
|||||||
if (!defaultLoraRootSelect) return;
|
if (!defaultLoraRootSelect) return;
|
||||||
|
|
||||||
// Fetch lora roots
|
// Fetch lora roots
|
||||||
const response = await fetch('/api/lora-roots');
|
const response = await fetch('/api/loras/roots');
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error('Failed to fetch LoRA roots');
|
throw new Error('Failed to fetch LoRA roots');
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import { getStorageItem, setStorageItem } from '../utils/storageHelpers.js';
|
|||||||
|
|
||||||
export class UpdateService {
|
export class UpdateService {
|
||||||
constructor() {
|
constructor() {
|
||||||
this.updateCheckInterval = 24 * 60 * 60 * 1000; // 24 hours
|
this.updateCheckInterval = 60 * 60 * 1000; // 1 hour
|
||||||
this.currentVersion = "v0.0.0"; // Initialize with default values
|
this.currentVersion = "v0.0.0"; // Initialize with default values
|
||||||
this.latestVersion = "v0.0.0"; // Initialize with default values
|
this.latestVersion = "v0.0.0"; // Initialize with default values
|
||||||
this.updateInfo = null;
|
this.updateInfo = null;
|
||||||
@@ -13,8 +13,10 @@ export class UpdateService {
|
|||||||
branch: "unknown",
|
branch: "unknown",
|
||||||
commit_date: "unknown"
|
commit_date: "unknown"
|
||||||
};
|
};
|
||||||
this.updateNotificationsEnabled = getStorageItem('show_update_notifications');
|
this.updateNotificationsEnabled = getStorageItem('show_update_notifications', true);
|
||||||
this.lastCheckTime = parseInt(getStorageItem('last_update_check') || '0');
|
this.lastCheckTime = parseInt(getStorageItem('last_update_check') || '0');
|
||||||
|
this.isUpdating = false;
|
||||||
|
this.nightlyMode = getStorageItem('nightly_updates', false);
|
||||||
}
|
}
|
||||||
|
|
||||||
initialize() {
|
initialize() {
|
||||||
@@ -29,22 +31,43 @@ export class UpdateService {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const updateBtn = document.getElementById('updateBtn');
|
||||||
|
if (updateBtn) {
|
||||||
|
updateBtn.addEventListener('click', () => this.performUpdate());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register event listener for nightly update toggle
|
||||||
|
const nightlyCheckbox = document.getElementById('nightlyUpdateToggle');
|
||||||
|
if (nightlyCheckbox) {
|
||||||
|
nightlyCheckbox.checked = this.nightlyMode;
|
||||||
|
nightlyCheckbox.addEventListener('change', (e) => {
|
||||||
|
this.nightlyMode = e.target.checked;
|
||||||
|
setStorageItem('nightly_updates', e.target.checked);
|
||||||
|
this.updateNightlyWarning();
|
||||||
|
this.updateModalContent();
|
||||||
|
// Re-check for updates when switching channels
|
||||||
|
this.manualCheckForUpdates();
|
||||||
|
});
|
||||||
|
this.updateNightlyWarning();
|
||||||
|
}
|
||||||
|
|
||||||
// Perform update check if needed
|
// Perform update check if needed
|
||||||
this.checkForUpdates().then(() => {
|
this.checkForUpdates().then(() => {
|
||||||
// Ensure badges are updated after checking
|
// Ensure badges are updated after checking
|
||||||
this.updateBadgeVisibility();
|
this.updateBadgeVisibility();
|
||||||
});
|
});
|
||||||
|
|
||||||
// Set up event listener for update button
|
|
||||||
// const updateToggle = document.getElementById('updateToggleBtn');
|
|
||||||
// if (updateToggle) {
|
|
||||||
// updateToggle.addEventListener('click', () => this.toggleUpdateModal());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Immediately update modal content with current values (even if from default)
|
// Immediately update modal content with current values (even if from default)
|
||||||
this.updateModalContent();
|
this.updateModalContent();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
updateNightlyWarning() {
|
||||||
|
const warning = document.getElementById('nightlyWarning');
|
||||||
|
if (warning) {
|
||||||
|
warning.style.display = this.nightlyMode ? 'flex' : 'none';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async checkForUpdates() {
|
async checkForUpdates() {
|
||||||
// Check if we should perform an update check
|
// Check if we should perform an update check
|
||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
@@ -59,8 +82,8 @@ export class UpdateService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Call backend API to check for updates
|
// Call backend API to check for updates with nightly flag
|
||||||
const response = await fetch('/api/check-updates');
|
const response = await fetch(`/api/check-updates?nightly=${this.nightlyMode}`);
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (data.success) {
|
if (data.success) {
|
||||||
@@ -137,8 +160,8 @@ export class UpdateService {
|
|||||||
const shouldShow = this.updateNotificationsEnabled && this.updateAvailable;
|
const shouldShow = this.updateNotificationsEnabled && this.updateAvailable;
|
||||||
|
|
||||||
if (updateBadge) {
|
if (updateBadge) {
|
||||||
updateBadge.classList.toggle('hidden', !shouldShow);
|
updateBadge.classList.toggle('visible', shouldShow);
|
||||||
console.log("Update badge visibility:", !shouldShow ? "hidden" : "visible");
|
console.log("Update badge visibility:", shouldShow ? "visible" : "hidden");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +180,17 @@ export class UpdateService {
|
|||||||
const newVersionEl = modal.querySelector('.new-version .version-number');
|
const newVersionEl = modal.querySelector('.new-version .version-number');
|
||||||
|
|
||||||
if (currentVersionEl) currentVersionEl.textContent = this.currentVersion;
|
if (currentVersionEl) currentVersionEl.textContent = this.currentVersion;
|
||||||
if (newVersionEl) newVersionEl.textContent = this.latestVersion;
|
|
||||||
|
if (newVersionEl) {
|
||||||
|
newVersionEl.textContent = this.latestVersion;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update update button state
|
||||||
|
const updateBtn = modal.querySelector('#updateBtn');
|
||||||
|
if (updateBtn) {
|
||||||
|
updateBtn.classList.toggle('disabled', !this.updateAvailable || this.isUpdating);
|
||||||
|
updateBtn.disabled = !this.updateAvailable || this.isUpdating;
|
||||||
|
}
|
||||||
|
|
||||||
// Update git info
|
// Update git info
|
||||||
const gitInfoEl = modal.querySelector('.git-info');
|
const gitInfoEl = modal.querySelector('.git-info');
|
||||||
@@ -218,6 +251,131 @@ export class UpdateService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async performUpdate() {
|
||||||
|
if (!this.updateAvailable || this.isUpdating) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.isUpdating = true;
|
||||||
|
this.updateUpdateUI('updating', 'Updating...');
|
||||||
|
this.showUpdateProgress(true);
|
||||||
|
|
||||||
|
// Update progress
|
||||||
|
this.updateProgress(10, 'Preparing update...');
|
||||||
|
|
||||||
|
const response = await fetch('/api/perform-update', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
nightly: this.nightlyMode
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
this.updateProgress(50, 'Installing update...');
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
this.updateProgress(100, 'Update completed successfully!');
|
||||||
|
this.updateUpdateUI('success', 'Updated!');
|
||||||
|
|
||||||
|
// Show success message and suggest restart
|
||||||
|
setTimeout(() => {
|
||||||
|
this.showUpdateCompleteMessage(data.new_version);
|
||||||
|
}, 1000);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
throw new Error(data.error || 'Update failed');
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Update failed:', error);
|
||||||
|
this.updateUpdateUI('error', 'Update Failed');
|
||||||
|
this.updateProgress(0, `Update failed: ${error.message}`);
|
||||||
|
|
||||||
|
// Hide progress after error
|
||||||
|
setTimeout(() => {
|
||||||
|
this.showUpdateProgress(false);
|
||||||
|
}, 3000);
|
||||||
|
} finally {
|
||||||
|
this.isUpdating = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updateUpdateUI(state, text) {
|
||||||
|
const updateBtn = document.getElementById('updateBtn');
|
||||||
|
const updateBtnText = document.getElementById('updateBtnText');
|
||||||
|
|
||||||
|
if (updateBtn && updateBtnText) {
|
||||||
|
// Remove existing state classes
|
||||||
|
updateBtn.classList.remove('updating', 'success', 'error', 'disabled');
|
||||||
|
|
||||||
|
// Add new state class
|
||||||
|
if (state !== 'normal') {
|
||||||
|
updateBtn.classList.add(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update button text
|
||||||
|
updateBtnText.textContent = text;
|
||||||
|
|
||||||
|
// Update disabled state
|
||||||
|
updateBtn.disabled = (state === 'updating' || state === 'disabled');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
showUpdateProgress(show) {
|
||||||
|
const progressContainer = document.getElementById('updateProgress');
|
||||||
|
if (progressContainer) {
|
||||||
|
progressContainer.style.display = show ? 'block' : 'none';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updateProgress(percentage, text) {
|
||||||
|
const progressFill = document.getElementById('updateProgressFill');
|
||||||
|
const progressText = document.getElementById('updateProgressText');
|
||||||
|
|
||||||
|
if (progressFill) {
|
||||||
|
progressFill.style.width = `${percentage}%`;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (progressText) {
|
||||||
|
progressText.textContent = text;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
showUpdateCompleteMessage(newVersion) {
|
||||||
|
const modal = document.getElementById('updateModal');
|
||||||
|
if (!modal) return;
|
||||||
|
|
||||||
|
// Update the modal content to show completion
|
||||||
|
const progressText = document.getElementById('updateProgressText');
|
||||||
|
if (progressText) {
|
||||||
|
progressText.innerHTML = `
|
||||||
|
<div style="text-align: center; color: var(--lora-success);">
|
||||||
|
<i class="fas fa-check-circle" style="margin-right: 8px;"></i>
|
||||||
|
Successfully updated to ${newVersion}!
|
||||||
|
<br><br>
|
||||||
|
<small style="opacity: 0.8;">
|
||||||
|
Please restart ComfyUI to complete the update process.
|
||||||
|
</small>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update current version display
|
||||||
|
this.currentVersion = newVersion;
|
||||||
|
this.updateAvailable = false;
|
||||||
|
|
||||||
|
// Refresh the modal content
|
||||||
|
setTimeout(() => {
|
||||||
|
this.updateModalContent();
|
||||||
|
this.showUpdateProgress(false);
|
||||||
|
}, 2000);
|
||||||
|
}
|
||||||
|
|
||||||
// Simple markdown parser for changelog items
|
// Simple markdown parser for changelog items
|
||||||
parseMarkdown(text) {
|
parseMarkdown(text) {
|
||||||
if (!text) return '';
|
if (!text) return '';
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ export class FolderBrowser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fetch LoRA roots
|
// Fetch LoRA roots
|
||||||
const rootsResponse = await fetch('/api/lora-roots');
|
const rootsResponse = await fetch('/api/loras/roots');
|
||||||
if (!rootsResponse.ok) {
|
if (!rootsResponse.ok) {
|
||||||
throw new Error(`Failed to fetch LoRA roots: ${rootsResponse.status}`);
|
throw new Error(`Failed to fetch LoRA roots: ${rootsResponse.status}`);
|
||||||
}
|
}
|
||||||
@@ -119,7 +119,7 @@ export class FolderBrowser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fetch folders
|
// Fetch folders
|
||||||
const foldersResponse = await fetch('/api/folders');
|
const foldersResponse = await fetch('/api/loras/folders');
|
||||||
if (!foldersResponse.ok) {
|
if (!foldersResponse.ok) {
|
||||||
throw new Error(`Failed to fetch folders: ${foldersResponse.status}`);
|
throw new Error(`Failed to fetch folders: ${foldersResponse.status}`);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,8 +11,18 @@
|
|||||||
<div class="action-buttons">
|
<div class="action-buttons">
|
||||||
<div title="Sort models by..." class="control-group">
|
<div title="Sort models by..." class="control-group">
|
||||||
<select id="sortSelect">
|
<select id="sortSelect">
|
||||||
<option value="name">Name</option>
|
<optgroup label="Name">
|
||||||
<option value="date">Date</option>
|
<option value="name:asc">A - Z</option>
|
||||||
|
<option value="name:desc">Z - A</option>
|
||||||
|
</optgroup>
|
||||||
|
<optgroup label="Date Added">
|
||||||
|
<option value="date:desc">Newest</option>
|
||||||
|
<option value="date:asc">Oldest</option>
|
||||||
|
</optgroup>
|
||||||
|
<optgroup label="File Size">
|
||||||
|
<option value="size:desc">Largest</option>
|
||||||
|
<option value="size:asc">Smallest</option>
|
||||||
|
</optgroup>
|
||||||
</select>
|
</select>
|
||||||
</div>
|
</div>
|
||||||
<div title="Refresh model list" class="control-group dropdown-group">
|
<div title="Refresh model list" class="control-group dropdown-group">
|
||||||
|
|||||||
@@ -53,7 +53,7 @@
|
|||||||
</div>
|
</div>
|
||||||
<div class="update-toggle" id="updateToggleBtn" title="Check Updates">
|
<div class="update-toggle" id="updateToggleBtn" title="Check Updates">
|
||||||
<i class="fas fa-bell"></i>
|
<i class="fas fa-bell"></i>
|
||||||
<span class="update-badge hidden"></span>
|
<span class="update-badge"></span>
|
||||||
</div>
|
</div>
|
||||||
<div class="support-toggle" id="supportToggleBtn" title="Support">
|
<div class="support-toggle" id="supportToggleBtn" title="Support">
|
||||||
<i class="fas fa-heart"></i>
|
<i class="fas fa-heart"></i>
|
||||||
|
|||||||
@@ -476,9 +476,26 @@
|
|||||||
<span class="version-number">v0.0.0</span>
|
<span class="version-number">v0.0.0</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<a href="https://github.com/willmiao/ComfyUI-Lora-Manager" target="_blank" class="update-link">
|
|
||||||
<i class="fas fa-external-link-alt"></i> View on GitHub
|
<div class="update-actions">
|
||||||
</a>
|
<a href="https://github.com/willmiao/ComfyUI-Lora-Manager" target="_blank" class="update-link">
|
||||||
|
<i class="fas fa-external-link-alt"></i> View on GitHub
|
||||||
|
</a>
|
||||||
|
<button id="updateBtn" class="primary-btn disabled">
|
||||||
|
<i class="fas fa-download"></i>
|
||||||
|
<span id="updateBtnText">Update Now</span>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Update Progress Section -->
|
||||||
|
<div class="update-progress" id="updateProgress" style="display: none;">
|
||||||
|
<div class="progress-info">
|
||||||
|
<div class="progress-text" id="updateProgressText">Preparing update...</div>
|
||||||
|
<div class="progress-bar">
|
||||||
|
<div class="progress-fill" id="updateProgressFill"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="changelog-section">
|
<div class="changelog-section">
|
||||||
|
|||||||
Reference in New Issue
Block a user