Compare commits

...

64 Commits

Author SHA1 Message Date
Will Miao
cede387783 Bump version to 0.8.6 in pyproject.toml 2025-04-14 08:42:00 +08:00
Will Miao
b206427d50 feat: Update README to include enhanced checkpoint management features and improved initial loading details 2025-04-14 08:40:42 +08:00
Will Miao
47d96e2037 feat: Simplify recipe page initialization and enhance error handling for recipe cache loading 2025-04-14 07:03:34 +08:00
Will Miao
e51f7cc1a7 feat: Enhance checkpoint download manager to save active folder preference and update UI accordingly 2025-04-13 22:12:18 +08:00
Will Miao
40381d4b11 feat: Optimize session management and enhance download functionality with resumable support 2025-04-13 21:51:21 +08:00
Will Miao
76fc9e5a3d feat: Add WebSocket support for checkpoint download progress and update related components 2025-04-13 21:31:01 +08:00
Will Miao
9822f2c614 feat: Add Civitai model version retrieval for Checkpoints and update error handling in download managers 2025-04-13 20:36:19 +08:00
Will Miao
8854334ab5 Add tip images 2025-04-13 18:46:44 +08:00
Will Miao
53080844d2 feat: Refactor progress bar classes for initialization component to improve clarity and avoid conflicts 2025-04-13 18:42:36 +08:00
Will Miao
76fd722e33 feat: Improve card layout by adding overflow hidden and fixing flexbox sizing issues 2025-04-13 18:20:15 +08:00
Will Miao
fa27513f76 feat: Enhance infinite scroll functionality with improved observer settings and scroll event handling 2025-04-13 17:58:14 +08:00
Will Miao
72c6f91130 feat: Update initialization component with loading progress and tips carousel 2025-04-13 14:03:02 +08:00
Will Miao
5918f35b8b feat: Add keyboard shortcuts for search input focus and selection 2025-04-13 13:12:32 +08:00
Will Miao
0b11e6e6d0 feat: Enhance initialization component with progress tracking and UI improvements 2025-04-13 12:58:38 +08:00
Will Miao
a043b487bd feat: Add initialization progress WebSocket and UI components
- Implement WebSocket route for initialization progress updates
- Create initialization component with progress bar and stages
- Add styles for initialization UI
- Update base template to include initialization component
- Enhance model scanner to broadcast progress during initialization
2025-04-13 10:41:27 +08:00
pixelpaws
3982489e67 Merge pull request #97 from willmiao/dev
feat: Enhance checkpoint handling by initializing paths and adding st…
2025-04-12 19:10:13 +08:00
Will Miao
5f3c515323 feat: Enhance checkpoint handling by initializing paths and adding static routes 2025-04-12 19:06:17 +08:00
pixelpaws
6e1297d734 Merge pull request #96 from willmiao/dev
Dev
2025-04-12 17:01:07 +08:00
Will Miao
8f3cbdd257 fix: Simplify session item retrieval in loadMoreModels function 2025-04-12 16:54:27 +08:00
Will Miao
2fc06ae64e Refactor file name update in Lora card
- Updated the setupFileNameEditing function to pass the new file name in the updates object when calling updateLoraCard.
- Removed the page reload after file name change to improve user experience.
- Enhanced the updateLoraCard function to handle the 'file_name' update, ensuring the dataset reflects the new file name correctly.
2025-04-12 16:35:35 +08:00
Will Miao
515aa1d2bd fix: Improve error logging and update lora monitor path handling 2025-04-12 16:24:29 +08:00
Will Miao
ff7a36394a refactor: Optimize event handling for folder tags using delegation 2025-04-12 16:15:29 +08:00
Will Miao
5261ab249a Fix checkpoints sort_by 2025-04-12 13:39:32 +08:00
Will Miao
c3192351da feat: Add support for reading SHA256 from .sha256 file in get_file_info function 2025-04-12 11:59:40 +08:00
Will Miao
ce30d067a6 feat: Import and expose loadMoreLoras function in LoraPageManager 2025-04-12 11:46:26 +08:00
Will Miao
e84a8a72c5 feat: Add save metadata route and update checkpoint card functionality 2025-04-12 11:18:21 +08:00
Will Miao
10a4fe04d1 refactor: Update API endpoint for saving model metadata to use consistent route structure 2025-04-12 09:03:34 +08:00
Will Miao
d5ce6441e3 refactor: Simplify service initialization in LoraRoutes and RecipeRoutes, and adjust logging level in ServiceRegistry 2025-04-12 09:01:09 +08:00
Will Miao
a8d21fb1d6 refactor: Remove unused service imports and add new route for scanning LoRA files 2025-04-12 07:49:11 +08:00
Will Miao
9277d8d8f8 refactor: Disable file monitoring functionality with ENABLE_FILE_MONITORING flag 2025-04-12 06:47:47 +08:00
Will Miao
0618541527 checkpoint 2025-04-11 20:22:12 +08:00
Will Miao
1db49a4dd4 refactor: Enhance checkpoint download functionality with new modal and manager integration 2025-04-11 18:25:37 +08:00
Will Miao
3df96034a1 refactor: Consolidate model handling functions into baseModelApi for better code reuse and organization 2025-04-11 14:35:56 +08:00
Will Miao
e991dc061d refactor: Implement common endpoint handlers for model management in ModelRouteUtils and update routes in CheckpointsRoutes 2025-04-11 12:06:05 +08:00
Will Miao
56670066c7 refactor: Optimize preview image handling by converting to webp format and improving error logging 2025-04-11 11:17:49 +08:00
Will Miao
31d27ff3fa refactor: Extract model-related utility functions into ModelRouteUtils for better code organization 2025-04-11 10:54:19 +08:00
Will Miao
297ff0dd25 refactor: Improve download handling for previews and optimize image conversion in DownloadManager 2025-04-11 09:00:58 +08:00
Will Miao
b0a5b48fb2 refactor: Enhance preview file handling and add update_preview_in_cache method for ModelScanner 2025-04-11 08:43:21 +08:00
Will Miao
ac244e6ad9 refactor: Replace hardcoded image width with CARD_PREVIEW_WIDTH constant for consistency 2025-04-11 08:19:19 +08:00
Will Miao
7393e92b21 refactor: Consolidate preview file extensions into constants for improved maintainability 2025-04-11 06:19:15 +08:00
Will Miao
86810d9f03 refactor: Remove move_model method from LoraScanner class to streamline code 2025-04-11 06:05:19 +08:00
Will Miao
18aa8d11ad refactor: Remove showToast call from clearCustomFilter method in LorasControls 2025-04-11 05:59:32 +08:00
Will Miao
fafec56f09 refactor: Rename update_single_lora_cache to update_single_model_cache for consistency 2025-04-11 05:52:56 +08:00
Will Miao
129ca9da81 feat: Implement checkpoint modal functionality with metadata editing, showcase display, and utility functions
- Added ModelMetadata.js for handling model metadata editing, including model name, base model, and file name.
- Introduced ShowcaseView.js to manage the display of images and videos in the checkpoint modal, including NSFW filtering and lazy loading.
- Created index.js as the main entry point for the checkpoint modal, integrating various components and functionalities.
- Developed utils.js for utility functions related to file size formatting and tag rendering.
- Enhanced user experience with editable fields, toast notifications, and improved showcase scrolling.
2025-04-10 22:59:09 +08:00
Will Miao
cbfb9ac87c Enhance CheckpointModal: Implement detailed checkpoint display, editable fields, and showcase functionality 2025-04-10 22:25:40 +08:00
Will Miao
42309edef4 Refactor visibility toggle: Remove toggleApiKeyVisibility function and update related button in modals 2025-04-10 21:43:56 +08:00
Will Miao
559e57ca46 Enhance CheckpointCard: Implement NSFW content handling, toggle blur functionality, and improve video autoplay behavior 2025-04-10 21:28:34 +08:00
Will Miao
311bf1f157 Add support for '.gguf' file extension in CheckpointScanner 2025-04-10 21:15:12 +08:00
Will Miao
131c3cc324 Add Civitai metadata fetching functionality for checkpoints
- Implement fetchCivitai API method to retrieve metadata from Civitai.
- Enhance CheckpointsControls to include fetch from Civitai functionality.
- Update PageControls to register fetch from Civitai event listener for both LoRAs and Checkpoints.
2025-04-10 21:07:17 +08:00
Will Miao
152ec0da0d Refactor Checkpoints functionality: Integrate loadMoreCheckpoints API, remove CheckpointSearchManager, and enhance FilterManager for improved checkpoint loading and filtering. 2025-04-10 19:57:04 +08:00
Will Miao
ee04df40c3 Refactor controls and pagination for Checkpoints and LoRAs: Implement unified PageControls, enhance API integration, and improve event handling for better user experience. 2025-04-10 19:41:02 +08:00
Will Miao
252e90a633 Enhance Checkpoints Manager: Implement API integration for checkpoints, add filtering and sorting options, and improve UI components for better user experience 2025-04-10 16:04:08 +08:00
Will Miao
048d486fa6 Refactor cache initialization in LoraManager and RecipeScanner for improved background processing and error handling 2025-04-10 11:34:19 +08:00
Will Miao
8fdfb68741 checkpoint 2025-04-10 09:08:51 +08:00
Will Miao
64c9e4aeca Update version to 0.8.5 and add release notes for enhanced features and improvements 2025-04-09 11:41:38 +08:00
Will Miao
08b90e8767 Update toast messages to clarify settings update notifications 2025-04-09 11:29:02 +08:00
Will Miao
0206613f9e Update NSFW level filter to include 'R' rating for improved content moderation 2025-04-09 11:25:52 +08:00
Will Miao
ae0629628e Enhance settings modal with video autoplay on hover option and improve layout. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/92 2025-04-09 11:18:30 +08:00
Will Miao
785b2e7287 style: Add padding to recipe list to prevent item cutoff on hover 2025-04-08 13:51:00 +08:00
Will Miao
43e3d0552e style: Update filter indicator and button styles for improved UI consistency
feat: Add pulse animation to filter indicators in Lora and recipe management
refactor: Change filter-active button to a div for better semantic structure
2025-04-08 13:45:15 +08:00
Will Miao
801aa2e876 Enhance Lora and recipe integration with improved filtering and UI updates
- Added support for filtering LoRAs by hash in both API and UI components.
- Implemented session storage management for custom filter states when navigating between recipes and LoRAs.
- Introduced a new button in the recipe modal to view associated LoRAs, enhancing user navigation.
- Updated CSS styles for new UI elements, including a custom filter indicator and LoRA view button.
- Refactored existing JavaScript components to streamline the handling of filter parameters and improve maintainability.
2025-04-08 12:23:51 +08:00
Will Miao
bddc7a438d feat: Add Lora recipes retrieval and filtering functionality
- Implemented a new API endpoint to fetch recipes associated with a specific Lora by its hash.
- Enhanced the recipe scanning logic to support filtering by Lora hash and bypassing other filters.
- Added a new method to retrieve a recipe by its ID with formatted metadata.
- Created a new RecipeTab component to display recipes in the Lora modal.
- Introduced session storage utilities for managing custom filter states.
- Updated the UI to include a custom filter indicator and loading/error states for recipes.
- Refactored existing recipe management logic to accommodate new features and improve maintainability.
2025-04-07 21:53:39 +08:00
Will Miao
b8c78a68e7 refactor: remove unused recipe card CSS styles 2025-04-07 20:36:58 +08:00
Will Miao
49219f4447 feat: Refactor LoraModal into modular components
- Added ShowcaseView.js for rendering LoRA model showcase content with NSFW filtering and lazy loading.
- Introduced TriggerWords.js to manage trigger words, including editing, adding, and saving functionality.
- Created index.js as the main entry point for the LoraModal, integrating all components and functionalities.
- Implemented utils.js for utility functions such as file size formatting and tag rendering.
- Enhanced user experience with editable fields, tooltips, and improved event handling for trigger words and presets.
2025-04-07 15:36:13 +08:00
90 changed files with 12073 additions and 4107 deletions

View File

@@ -6,7 +6,7 @@
[![Release](https://img.shields.io/github/v/release/willmiao/ComfyUI-Lora-Manager?include_prereleases&color=blue&logo=github)](https://github.com/willmiao/ComfyUI-Lora-Manager/releases)
[![Release Date](https://img.shields.io/github/release-date/willmiao/ComfyUI-Lora-Manager?color=green&logo=github)](https://github.com/willmiao/ComfyUI-Lora-Manager/releases)
A comprehensive toolset that streamlines organizing, downloading, and applying LoRA models in ComfyUI. With powerful features like recipe management and one-click workflow integration, working with LoRAs becomes faster, smoother, and significantly easier. Access the interface at: `http://localhost:8188/loras`
A comprehensive toolset that streamlines organizing, downloading, and applying LoRA models in ComfyUI. With powerful features like recipe management, checkpoint organization, and one-click workflow integration, working with models becomes faster, smoother, and significantly easier. Access the interface at: `http://localhost:8188/loras`
![Interface Preview](https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/static/images/screenshot.png)
@@ -20,6 +20,17 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
## Release Notes
### v0.8.6 Major Update
* **Checkpoint Management** - Added comprehensive management for model checkpoints including scanning, searching, filtering, and deletion
* **Enhanced Metadata Support** - New capabilities for retrieving and managing checkpoint metadata with improved operations
* **Improved Initial Loading** - Optimized cache initialization with visual progress indicators for better user experience
### v0.8.5
* **Enhanced LoRA & Recipe Connectivity** - Added Recipes tab in LoRA details to see all recipes using a specific LoRA
* **Improved Navigation** - New shortcuts to jump between related LoRAs and Recipes with one-click navigation
* **Video Preview Controls** - Added "Autoplay Videos on Hover" setting to optimize performance and reduce resource usage
* **UI Experience Refinements** - Smoother transitions between related content pages
### v0.8.4
* **Node Layout Improvements** - Fixed layout issues with LoRA Loader and Trigger Words Toggle nodes in newer ComfyUI frontend versions
* **Recipe LoRA Reconnection** - Added ability to reconnect deleted LoRAs in recipes by clicking the "deleted" badge in recipe details
@@ -86,6 +97,12 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
- Trigger words at a glance
- One-click workflow integration with preset values
- 🔄 **Checkpoint Management**
- Scan and organize checkpoint models
- Filter and search your collection
- View and edit metadata
- Clean up and manage disk space
- 🧩 **LoRA Recipes**
- Save and share favorite LoRA combinations
- Preserve generation parameters for future reference
@@ -97,6 +114,7 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
- Context menu for quick actions
- Custom notes and usage tips
- Multi-folder support
- Visual progress indicators during initialization
---

View File

@@ -17,6 +17,7 @@ class Config:
# 静态路由映射字典, target to route mapping
self._route_mappings = {}
self.loras_roots = self._init_lora_paths()
self.checkpoints_roots = self._init_checkpoint_paths()
self.temp_directory = folder_paths.get_temp_directory()
# 在初始化时扫描符号链接
self._scan_symbolic_links()
@@ -39,9 +40,12 @@ class Config:
return False
def _scan_symbolic_links(self):
"""扫描所有 LoRA 根目录中的符号链接"""
"""扫描所有 LoRA 和 Checkpoint 根目录中的符号链接"""
for root in self.loras_roots:
self._scan_directory_links(root)
for root in self.checkpoints_roots:
self._scan_directory_links(root)
def _scan_directory_links(self, root: str):
"""递归扫描目录中的符号链接"""
@@ -73,7 +77,7 @@ class Config:
"""添加静态路由映射"""
normalized_path = os.path.normpath(path).replace(os.sep, '/')
self._route_mappings[normalized_path] = route
logger.info(f"Added route mapping: {normalized_path} -> {route}")
# logger.info(f"Added route mapping: {normalized_path} -> {route}")
def map_path_to_link(self, path: str) -> str:
"""将目标路径映射回符号链接路径"""
@@ -115,6 +119,35 @@ class Config:
return paths
def _init_checkpoint_paths(self) -> List[str]:
"""Initialize and validate checkpoint paths from ComfyUI settings"""
# Get checkpoint paths from folder_paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusers")
unet_paths = folder_paths.get_folder_paths("unet")
# Combine all checkpoint-related paths
all_paths = checkpoint_paths + diffusion_paths + unet_paths
# Filter and normalize paths
paths = sorted(set(path.replace(os.sep, "/")
for path in all_paths
if os.path.exists(path)), key=lambda p: p.lower())
print("Found checkpoint roots:", paths)
if not paths:
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
return []
# 初始化路径映射,与 LoRA 路径处理方式相同
for path in paths:
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
if real_path != path:
self.add_path_mapping(path, real_path)
return paths
def get_preview_static_url(self, preview_path: str) -> str:
"""Convert local preview path to static URL"""
if not preview_path:

View File

@@ -1,16 +1,11 @@
import asyncio
import os
from server import PromptServer # type: ignore
from .config import config
from .routes.lora_routes import LoraRoutes
from .routes.api_routes import ApiRoutes
from .routes.recipe_routes import RecipeRoutes
from .routes.checkpoints_routes import CheckpointsRoutes
from .services.lora_scanner import LoraScanner
from .services.recipe_scanner import RecipeScanner
from .services.file_monitor import LoraFileMonitor
from .services.lora_cache import LoraCache
from .services.recipe_cache import RecipeCache
from .services.service_registry import ServiceRegistry
import logging
logger = logging.getLogger(__name__)
@@ -23,7 +18,7 @@ class LoraManager:
"""Initialize and register all routes"""
app = PromptServer.instance.app
added_targets = set() # 用于跟踪已添加的目标路径
added_targets = set() # Track already added target paths
# Add static routes for each lora root
for idx, root in enumerate(config.loras_roots, start=1):
@@ -35,102 +30,143 @@ class LoraManager:
if link == root:
real_root = target
break
# 为原始路径添加静态路由
# Add static route for original path
app.router.add_static(preview_path, real_root)
logger.info(f"Added static route {preview_path} -> {real_root}")
# 记录路由映射
# Record route mapping
config.add_route_mapping(real_root, preview_path)
added_targets.add(real_root)
# 为符号链接的目标路径添加额外的静态路由
link_idx = 1
# Add static routes for each checkpoint root
for idx, root in enumerate(config.checkpoints_roots, start=1):
preview_path = f'/checkpoints_static/root{idx}/preview'
real_root = root
if root in config._path_mappings.values():
for target, link in config._path_mappings.items():
if link == root:
real_root = target
break
# Add static route for original path
app.router.add_static(preview_path, real_root)
logger.info(f"Added static route {preview_path} -> {real_root}")
# Record route mapping
config.add_route_mapping(real_root, preview_path)
added_targets.add(real_root)
# Add static routes for symlink target paths
link_idx = {
'lora': 1,
'checkpoint': 1
}
for target_path, link_path in config._path_mappings.items():
if target_path not in added_targets:
route_path = f'/loras_static/link_{link_idx}/preview'
# Determine if this is a checkpoint or lora link based on path
is_checkpoint = any(cp_root in link_path for cp_root in config.checkpoints_roots)
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.checkpoints_roots)
if is_checkpoint:
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
link_idx["checkpoint"] += 1
else:
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
link_idx["lora"] += 1
app.router.add_static(route_path, target_path)
logger.info(f"Added static route for link target {route_path} -> {target_path}")
config.add_route_mapping(target_path, route_path)
added_targets.add(target_path)
link_idx += 1
# Add static route for plugin assets
app.router.add_static('/loras_static', config.static_path)
# Setup feature routes
routes = LoraRoutes()
lora_routes = LoraRoutes()
checkpoints_routes = CheckpointsRoutes()
# Setup file monitoring
monitor = LoraFileMonitor(routes.scanner, config.loras_roots)
monitor.start()
routes.setup_routes(app)
# Initialize routes
lora_routes.setup_routes(app)
checkpoints_routes.setup_routes(app)
ApiRoutes.setup_routes(app, monitor)
ApiRoutes.setup_routes(app)
RecipeRoutes.setup_routes(app)
# Store monitor in app for cleanup
app['lora_monitor'] = monitor
# Schedule cache initialization using the application's startup handler
app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner, routes.recipe_scanner))
# Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services())
# Add cleanup
app.on_shutdown.append(cls._cleanup)
app.on_shutdown.append(ApiRoutes.cleanup)
@classmethod
async def _schedule_cache_init(cls, scanner: LoraScanner, recipe_scanner: RecipeScanner):
"""Schedule cache initialization in the running event loop"""
async def _initialize_services(cls):
"""Initialize all services using the ServiceRegistry"""
try:
# 创建低优先级的初始化任务
lora_task = asyncio.create_task(cls._initialize_lora_cache(scanner), name='lora_cache_init')
logger.info("LoRA Manager: Initializing services via ServiceRegistry")
# Schedule recipe cache initialization with a delay to let lora scanner initialize first
recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner, delay=2), name='recipe_cache_init')
# Initialize CivitaiClient first to ensure it's ready for other services
civitai_client = await ServiceRegistry.get_civitai_client()
# Get file monitors through ServiceRegistry
lora_monitor = await ServiceRegistry.get_lora_monitor()
checkpoint_monitor = await ServiceRegistry.get_checkpoint_monitor()
# Start monitors
lora_monitor.start()
logger.info("Lora monitor started")
# Make sure checkpoint monitor has paths before starting
await checkpoint_monitor.initialize_paths()
checkpoint_monitor.start()
logger.info("Checkpoint monitor started")
# Register DownloadManager with ServiceRegistry
download_manager = await ServiceRegistry.get_download_manager()
# Initialize WebSocket manager
ws_manager = await ServiceRegistry.get_websocket_manager()
# Initialize scanners in background
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
# Initialize recipe scanner if needed
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
# Create low-priority initialization tasks
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
logger.info("LoRA Manager: All services initialized and background tasks scheduled")
except Exception as e:
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}")
@classmethod
async def _initialize_lora_cache(cls, scanner: LoraScanner):
"""Initialize lora cache in background"""
try:
# 设置初始缓存占位
scanner._cache = LoraCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# 分阶段加载缓存
await scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing lora cache: {e}")
@classmethod
async def _initialize_recipe_cache(cls, scanner: RecipeScanner, delay: float = 2.0):
"""Initialize recipe cache in background with a delay"""
try:
# Wait for the specified delay to let lora scanner initialize first
await asyncio.sleep(delay)
# Set initial empty cache
scanner._cache = RecipeCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[]
)
# Force refresh to load the actual data
await scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing recipe cache: {e}")
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
@classmethod
async def _cleanup(cls, app):
"""Cleanup resources"""
if 'lora_monitor' in app:
app['lora_monitor'].stop()
"""Cleanup resources using ServiceRegistry"""
try:
logger.info("LoRA Manager: Cleaning up services")
# Get monitors from ServiceRegistry
lora_monitor = await ServiceRegistry.get_service("lora_monitor")
if lora_monitor:
lora_monitor.stop()
logger.info("Stopped LoRA monitor")
checkpoint_monitor = await ServiceRegistry.get_service("checkpoint_monitor")
if checkpoint_monitor:
checkpoint_monitor.stop()
logger.info("Stopped checkpoint monitor")
# Close CivitaiClient gracefully
civitai_client = await ServiceRegistry.get_service("civitai_client")
if civitai_client:
await civitai_client.close()
logger.info("Closed CivitaiClient connection")
except Exception as e:
logger.error(f"Error during cleanup: {e}", exc_info=True)

View File

@@ -2,43 +2,51 @@ import os
import json
import logging
from aiohttp import web
from typing import Dict, List
from typing import Dict
from ..utils.model_utils import determine_base_model
from ..utils.routes_common import ModelRouteUtils
from ..services.file_monitor import LoraFileMonitor
from ..services.download_manager import DownloadManager
from ..services.civitai_client import CivitaiClient
from ..config import config
from ..services.lora_scanner import LoraScanner
from operator import itemgetter
from ..services.websocket_manager import ws_manager
from ..services.settings_manager import settings
import asyncio
from .update_routes import UpdateRoutes
from ..services.recipe_scanner import RecipeScanner
from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from ..services.service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class ApiRoutes:
"""API route handlers for LoRA management"""
def __init__(self, file_monitor: LoraFileMonitor):
self.scanner = LoraScanner()
self.civitai_client = CivitaiClient()
self.download_manager = DownloadManager(file_monitor)
def __init__(self):
self.scanner = None # Will be initialized in setup_routes
self.civitai_client = None # Will be initialized in setup_routes
self.download_manager = None # Will be initialized in setup_routes
self._download_lock = asyncio.Lock()
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_lora_scanner()
self.civitai_client = await ServiceRegistry.get_civitai_client()
self.download_manager = await ServiceRegistry.get_download_manager()
@classmethod
def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
def setup_routes(cls, app: web.Application):
"""Register API routes"""
routes = cls(monitor)
routes = cls()
# Schedule service initialization on app startup
app.on_startup.append(lambda _: routes.initialize_services())
app.router.add_post('/api/delete_model', routes.delete_model)
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
app.router.add_post('/api/replace_preview', routes.replace_preview)
app.router.add_get('/api/loras', routes.get_loras)
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
app.router.add_get('/api/folders', routes.get_folders)
app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
@@ -48,160 +56,121 @@ class ApiRoutes:
app.router.add_post('/api/settings', routes.update_settings)
app.router.add_post('/api/move_model', routes.move_model)
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
app.router.add_post('/loras/api/save-metadata', routes.save_metadata)
app.router.add_post('/api/loras/save-metadata', routes.save_metadata)
app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route
app.router.add_post('/api/move_models_bulk', routes.move_models_bulk)
app.router.add_get('/api/loras/top-tags', routes.get_top_tags) # Add new route for top tags
app.router.add_get('/api/loras/base-models', routes.get_base_models) # Add new route for base models
app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL
app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files
app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files
# Add update check routes
UpdateRoutes.setup_routes(app)
async def delete_model(self, request: web.Request) -> web.Response:
"""Handle model deletion request"""
try:
data = await request.json()
file_path = data.get('file_path')
if not file_path:
return web.Response(text='Model path is required', status=400)
target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0]
deleted_files = await self._delete_model_files(target_dir, file_name)
return web.json_response({
'success': True,
'deleted_files': deleted_files
})
except Exception as e:
logger.error(f"Error deleting model: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata fetch request"""
try:
data = await request.json()
metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
# Check if model is from CivitAI
local_metadata = await self._load_local_metadata(metadata_path)
# Fetch and update metadata
civitai_metadata = await self.civitai_client.get_model_by_hash(local_metadata["sha256"])
if not civitai_metadata:
return await self._handle_not_found_on_civitai(metadata_path, local_metadata)
await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client)
return web.json_response({"success": True})
except Exception as e:
logger.error(f"Error fetching from CivitAI: {e}", exc_info=True)
return web.json_response({"success": False, "error": str(e)}, status=500)
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement request"""
try:
reader = await request.multipart()
preview_data, content_type = await self._read_preview_file(reader)
model_path = await self._read_model_path(reader)
preview_path = await self._save_preview_file(model_path, preview_data, content_type)
await self._update_preview_metadata(model_path, preview_path)
# Update preview URL in scanner cache
await self.scanner.update_preview_in_cache(model_path, preview_path)
return web.json_response({
"success": True,
"preview_url": config.get_preview_static_url(preview_path)
})
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def scan_loras(self, request: web.Request) -> web.Response:
"""Force a rescan of LoRA files"""
try:
await self.scanner.get_cached_data(force_refresh=True)
return web.json_response({"status": "success", "message": "LoRA scan completed"})
except Exception as e:
logger.error(f"Error replacing preview: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
logger.error(f"Error in scan_loras: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_loras(self, request: web.Request) -> web.Response:
"""Handle paginated LoRA data request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
page = int(request.query.get('page', '1'))
page_size = int(request.query.get('page_size', '20'))
sort_by = request.query.get('sort_by', 'name')
folder = request.query.get('folder')
search = request.query.get('search', '').lower()
fuzzy = request.query.get('fuzzy', 'false').lower() == 'true'
# Parse base models filter parameter
base_models = request.query.get('base_models', '').split(',')
base_models = [model.strip() for model in base_models if model.strip()]
folder = request.query.get('folder', None)
search = request.query.get('search', None)
fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true'
# Parse search options
search_filename = request.query.get('search_filename', 'true').lower() == 'true'
search_modelname = request.query.get('search_modelname', 'true').lower() == 'true'
search_tags = request.query.get('search_tags', 'false').lower() == 'true'
recursive = request.query.get('recursive', 'false').lower() == 'true'
search_options = {
'filename': request.query.get('search_filename', 'true').lower() == 'true',
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
'tags': request.query.get('search_tags', 'false').lower() == 'true',
'recursive': request.query.get('recursive', 'false').lower() == 'true'
}
# Validate parameters
if page < 1 or page_size < 1 or page_size > 100:
return web.json_response({
'error': 'Invalid pagination parameters'
}, status=400)
# Get filter parameters
base_models = request.query.get('base_models', None)
tags = request.query.get('tags', None)
if sort_by not in ['date', 'name']:
return web.json_response({
'error': 'Invalid sort parameter'
}, status=400)
# New parameters for recipe filtering
lora_hash = request.query.get('lora_hash', None)
lora_hashes = request.query.get('lora_hashes', None)
# Parse tags filter parameter
tags = request.query.get('tags', '').split(',')
tags = [tag.strip() for tag in tags if tag.strip()]
# Parse filter parameters
filters = {}
if base_models:
filters['base_model'] = base_models.split(',')
if tags:
filters['tags'] = tags.split(',')
# Get paginated data with search and filters
result = await self.scanner.get_paginated_data(
page=page,
page_size=page_size,
sort_by=sort_by,
# Add lora hash filtering options
hash_filters = {}
if lora_hash:
hash_filters['single_hash'] = lora_hash.lower()
elif lora_hashes:
hash_filters['multiple_hashes'] = [h.lower() for h in lora_hashes.split(',')]
# Get file data
data = await self.scanner.get_paginated_data(
page,
page_size,
sort_by=sort_by,
folder=folder,
search=search,
fuzzy=fuzzy,
base_models=base_models, # Pass base models filter
tags=tags, # Add tags parameter
search_options={
'filename': search_filename,
'modelname': search_modelname,
'tags': search_tags,
'recursive': recursive
}
fuzzy_search=fuzzy_search,
base_models=filters.get('base_model', None),
tags=filters.get('tags', None),
search_options=search_options,
hash_filters=hash_filters
)
# Format the response data
formatted_items = [
self._format_lora_response(item)
for item in result['items']
]
# Get all available folders from cache
cache = await self.scanner.get_cached_data()
return web.json_response({
'items': formatted_items,
'total': result['total'],
'page': result['page'],
'page_size': result['page_size'],
'total_pages': result['total_pages'],
'folders': cache.folders
})
# Convert output to match expected format
result = {
'items': [self._format_lora_response(lora) for lora in data['items']],
'folders': cache.folders,
'total': data['total'],
'page': data['page'],
'page_size': data['page_size'],
'total_pages': data['total_pages']
}
return web.json_response(result)
except Exception as e:
logger.error(f"Error in get_loras: {str(e)}", exc_info=True)
return web.json_response({
'error': 'Internal server error'
}, status=500)
logger.error(f"Error retrieving loras: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
def _format_lora_response(self, lora: Dict) -> Dict:
"""Format LoRA data for API response"""
@@ -221,73 +190,10 @@ class ApiRoutes:
"from_civitai": lora.get("from_civitai", True),
"usage_tips": lora.get("usage_tips", ""),
"notes": lora.get("notes", ""),
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
"civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {}))
}
def _filter_civitai_data(self, data: Dict) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
# Private helper methods
async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]:
"""Delete model and associated files"""
patterns = [
f"{file_name}.safetensors", # Required
f"{file_name}.metadata.json",
f"{file_name}.preview.png",
f"{file_name}.preview.jpg",
f"{file_name}.preview.jpeg",
f"{file_name}.preview.webp",
f"{file_name}.preview.mp4",
f"{file_name}.png",
f"{file_name}.jpg",
f"{file_name}.jpeg",
f"{file_name}.webp",
f"{file_name}.mp4"
]
deleted = []
main_file = patterns[0]
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
if os.path.exists(main_path):
# Notify file monitor to ignore delete event
self.download_manager.file_monitor.handler.add_ignore_path(main_path, 0)
# Delete file
os.remove(main_path)
deleted.append(main_path)
else:
logger.warning(f"Model file not found: {main_file}")
# Remove from cache
cache = await self.scanner.get_cached_data()
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != main_path]
await cache.resort()
# update hash index
self.scanner._hash_index.remove_by_path(main_path)
# Delete optional files
for pattern in patterns[1:]:
path = os.path.join(target_dir, pattern)
if os.path.exists(path):
try:
os.remove(path)
deleted.append(pattern)
except Exception as e:
logger.warning(f"Failed to delete {pattern}: {e}")
return deleted
async def _read_preview_file(self, reader) -> tuple[bytes, str]:
"""Read preview file and content type from multipart request"""
field = await reader.next()
@@ -305,18 +211,29 @@ class ApiRoutes:
async def _save_preview_file(self, model_path: str, preview_data: bytes, content_type: str) -> str:
"""Save preview file and return its path"""
# Determine file extension based on content type
if content_type.startswith('video/'):
extension = '.preview.mp4'
else:
extension = '.preview.png'
base_name = os.path.splitext(os.path.basename(model_path))[0]
folder = os.path.dirname(model_path)
# Determine if content is video or image
if content_type.startswith('video/'):
# For videos, keep original format and use .mp4 extension
extension = '.mp4'
optimized_data = preview_data
else:
# For images, optimize and convert to WebP
optimized_data, _ = ExifUtils.optimize_image(
image_data=preview_data,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
)
extension = '.webp' # Use .webp without .preview part
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/')
with open(preview_path, 'wb') as f:
f.write(preview_data)
f.write(optimized_data)
return preview_path
@@ -336,83 +253,26 @@ class ApiRoutes:
except Exception as e:
logger.error(f"Error updating metadata: {e}")
async def _load_local_metadata(self, metadata_path: str) -> Dict:
"""Load local metadata file"""
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
return {}
async def _handle_not_found_on_civitai(self, metadata_path: str, local_metadata: Dict) -> web.Response:
"""Handle case when model is not found on CivitAI"""
local_metadata['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return web.json_response(
{"success": False, "error": "Not found on CivitAI"},
status=404
)
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
civitai_metadata: Dict, client: CivitaiClient) -> None:
"""Update local metadata with CivitAI data"""
local_metadata['civitai'] = civitai_metadata
# Update model name if available
if 'model' in civitai_metadata:
if civitai_metadata.get('model', {}).get('name'):
local_metadata['model_name'] = civitai_metadata['model']['name']
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata, _ = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Update base model
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
# Update preview if needed
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
if first_preview:
preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_filename = base_name + preview_ext
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
if await client.download_preview_image(first_preview['url'], preview_path):
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
await self.scanner.update_single_lora_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all loras in the background"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
cache = await self.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
# 准备要处理的 loras
# Prepare loras to process
to_process = [
lora for lora in cache.raw_data
if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai') # TODO: for lora not from CivitAI but added traineWords
if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai', True) # TODO: for lora not from CivitAI but added traineWords
]
total_to_process = len(to_process)
# 发送初始进度
# Send initial progress
await ws_manager.broadcast({
'status': 'started',
'total': total_to_process,
@@ -423,10 +283,11 @@ class ApiRoutes:
for lora in to_process:
try:
original_name = lora.get('model_name')
if await self._fetch_and_update_single_lora(
if await ModelRouteUtils.fetch_and_update_model(
sha256=lora['sha256'],
file_path=lora['file_path'],
lora=lora
model_data=lora,
update_cache_func=self.scanner.update_single_model_cache
):
success += 1
if original_name != lora.get('model_name'):
@@ -434,7 +295,7 @@ class ApiRoutes:
processed += 1
# 每处理一个就发送进度更新
# Send progress update
await ws_manager.broadcast({
'status': 'processing',
'total': total_to_process,
@@ -449,7 +310,7 @@ class ApiRoutes:
if needs_resort:
await cache.resort(name_only=True)
# 发送完成消息
# Send completion message
await ws_manager.broadcast({
'status': 'completed',
'total': total_to_process,
@@ -463,7 +324,7 @@ class ApiRoutes:
})
except Exception as e:
# 发送错误消息
# Send error message
await ws_manager.broadcast({
'status': 'error',
'error': str(e)
@@ -471,58 +332,6 @@ class ApiRoutes:
logger.error(f"Error in fetch_all_civitai: {e}")
return web.Response(text=str(e), status=500)
async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool:
"""Fetch and update metadata for a single lora without sorting
Args:
sha256: SHA256 hash of the lora file
file_path: Path to the lora file
lora: The lora object in cache to update
Returns:
bool: True if successful, False otherwise
"""
client = CivitaiClient()
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Check if model is from CivitAI
local_metadata = await self._load_local_metadata(metadata_path)
# Fetch metadata
civitai_metadata = await client.get_model_by_hash(sha256)
if not civitai_metadata:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
lora['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return False
# Update metadata
await self._update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
client
)
# Update cache object directly
lora.update({
'model_name': local_metadata.get('model_name'),
'preview_url': local_metadata.get('preview_url'),
'from_civitai': True,
'civitai': civitai_metadata
})
return True
except Exception as e:
logger.error(f"Error fetching CivitAI data: {e}")
return False
finally:
await client.close()
async def get_lora_roots(self, request: web.Request) -> web.Response:
"""Get all configured LoRA root directories"""
return web.json_response({
@@ -531,6 +340,9 @@ class ApiRoutes:
async def get_folders(self, request: web.Request) -> web.Response:
"""Get all folders in the cache"""
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
cache = await self.scanner.get_cached_data()
return web.json_response({
'folders': cache.folders
@@ -539,11 +351,26 @@ class ApiRoutes:
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai model with local availability info"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
model_id = request.match_info['model_id']
versions = await self.civitai_client.get_model_versions(model_id)
if not versions:
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be LORA
if model_type.lower() != 'lora':
return web.json_response({
'error': f"Model type mismatch. Expected LORA, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the model file (type="Model") in the files list
@@ -554,9 +381,9 @@ class ApiRoutes:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.scanner.has_lora_hash(sha256)
version['existsLocally'] = self.scanner.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.scanner.get_lora_path_by_hash(sha256)
version['localPath'] = self.scanner.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
@@ -572,9 +399,12 @@ class ApiRoutes:
async def get_civitai_model(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID or hash"""
try:
model_version_id = request.match_info['modelVersionId']
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
model_version_id = request.match_info.get('modelVersionId')
if not model_version_id:
hash = request.match_info['hash']
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
@@ -589,6 +419,9 @@ class ApiRoutes:
async def download_lora(self, request: web.Request) -> web.Response:
async with self._download_lock:
try:
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
# Create progress callback
@@ -660,12 +493,15 @@ class ApiRoutes:
return web.json_response({'success': True})
except Exception as e:
logger.error(f"Error updating settings: {e}", exc_info=True) # 添加 exc_info=True 以获取完整堆栈
logger.error(f"Error updating settings: {e}", exc_info=True)
return web.Response(status=500, text=str(e))
async def move_model(self, request: web.Request) -> web.Response:
"""Handle model move request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors
target_path = data.get('target_path') # folder path to move the model to, e.g. /path/to/target_folder
@@ -704,12 +540,17 @@ class ApiRoutes:
@classmethod
async def cleanup(cls):
"""Add cleanup method for application shutdown"""
if hasattr(cls, '_instance'):
await cls._instance.civitai_client.close()
# Now we don't need to store an instance, as services are managed by ServiceRegistry
civitai_client = await ServiceRegistry.get_civitai_client()
if civitai_client:
await civitai_client.close()
async def save_metadata(self, request: web.Request) -> web.Response:
"""Handle saving metadata updates"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_path = data.get('file_path')
if not file_path:
@@ -722,11 +563,7 @@ class ApiRoutes:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load existing metadata
if os.path.exists(metadata_path):
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
else:
metadata = {}
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Handle nested updates (for civitai.trainedWords)
for key, value in metadata_updates.items():
@@ -743,7 +580,7 @@ class ApiRoutes:
json.dump(metadata, f, indent=2, ensure_ascii=False)
# Update cache
await self.scanner.update_single_lora_cache(file_path, file_path, metadata)
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
# If model_name was updated, resort the cache
if 'model_name' in metadata_updates:
@@ -759,6 +596,9 @@ class ApiRoutes:
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
"""Get the static preview URL for a LoRA file"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Get lora file name from query parameters
lora_name = request.query.get('name')
if not lora_name:
@@ -789,11 +629,17 @@ class ApiRoutes:
except Exception as e:
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
"""Get the Civitai URL for a LoRA file"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Get lora file name from query parameters
lora_name = request.query.get('name')
if not lora_name:
@@ -831,11 +677,17 @@ class ApiRoutes:
except Exception as e:
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def move_models_bulk(self, request: web.Request) -> web.Response:
"""Handle bulk model move request"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
data = await request.json()
file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"]
target_path = data.get('target_path') # folder path to move the models to, e.g. "/path/to/target_folder"
@@ -894,6 +746,9 @@ class ApiRoutes:
async def get_lora_model_description(self, request: web.Request) -> web.Response:
"""Get model description for a Lora model"""
try:
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Get parameters
model_id = request.query.get('model_id')
file_path = request.query.get('file_path')
@@ -909,14 +764,9 @@ class ApiRoutes:
tags = []
if file_path:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
description = metadata.get('modelDescription')
tags = metadata.get('tags', [])
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
description = metadata.get('modelDescription')
tags = metadata.get('tags', [])
# If description is not in metadata, fetch from CivitAI
if not description:
@@ -931,16 +781,14 @@ class ApiRoutes:
if file_path:
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
metadata['modelDescription'] = description
metadata['tags'] = tags
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
logger.info(f"Saved model metadata to file for {file_path}")
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
metadata['modelDescription'] = description
metadata['tags'] = tags
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
logger.info(f"Saved model metadata to file for {file_path}")
except Exception as e:
logger.error(f"Error saving model metadata: {e}")
@@ -951,7 +799,7 @@ class ApiRoutes:
})
except Exception as e:
logger.error(f"Error getting model metadata: {e}", exc_info=True)
logger.error(f"Error getting model metadata: {e}")
return web.json_response({
'success': False,
'error': str(e)
@@ -960,6 +808,9 @@ class ApiRoutes:
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
limit = int(request.query.get('limit', '20'))
@@ -985,6 +836,9 @@ class ApiRoutes:
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in loras"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
# Parse query parameters
limit = int(request.query.get('limit', '20'))
@@ -1006,15 +860,15 @@ class ApiRoutes:
'error': str(e)
}, status=500)
def get_multipart_ext(self, filename):
parts = filename.split(".")
if len(parts) > 2: # 如果包含多级扩展名
return "." + ".".join(parts[-2:]) # 取最后两部分,如 ".metadata.json"
return os.path.splitext(filename)[1] # 否则取普通扩展名,如 ".safetensors"
async def rename_lora(self, request: web.Request) -> web.Response:
"""Handle renaming a LoRA file and its associated files"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_lora_scanner()
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
file_path = data.get('file_path')
new_file_name = data.get('new_file_name')
@@ -1049,18 +903,12 @@ class ApiRoutes:
patterns = [
f"{old_file_name}.safetensors", # Required
f"{old_file_name}.metadata.json",
f"{old_file_name}.preview.png",
f"{old_file_name}.preview.jpg",
f"{old_file_name}.preview.jpeg",
f"{old_file_name}.preview.webp",
f"{old_file_name}.preview.mp4",
f"{old_file_name}.png",
f"{old_file_name}.jpg",
f"{old_file_name}.jpeg",
f"{old_file_name}.webp",
f"{old_file_name}.mp4"
]
# Add all preview file extensions
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{old_file_name}{ext}")
# Find all matching files
existing_files = []
for pattern in patterns:
@@ -1074,12 +922,8 @@ class ApiRoutes:
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
hash_value = metadata.get('sha256')
except Exception as e:
logger.error(f"Error loading metadata for rename: {e}")
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
hash_value = metadata.get('sha256')
# Rename all files
renamed_files = []
@@ -1087,15 +931,18 @@ class ApiRoutes:
# Notify file monitor to ignore these events
main_file_path = os.path.join(target_dir, f"{old_file_name}.safetensors")
if os.path.exists(main_file_path) and self.download_manager.file_monitor:
# Add old and new paths to ignore list
file_size = os.path.getsize(main_file_path)
self.download_manager.file_monitor.handler.add_ignore_path(main_file_path, file_size)
self.download_manager.file_monitor.handler.add_ignore_path(new_file_path, file_size)
if os.path.exists(main_file_path):
# Get lora monitor through ServiceRegistry instead of download_manager
lora_monitor = await ServiceRegistry.get_lora_monitor()
if lora_monitor:
# Add old and new paths to ignore list
file_size = os.path.getsize(main_file_path)
lora_monitor.handler.add_ignore_path(main_file_path, file_size)
lora_monitor.handler.add_ignore_path(new_file_path, file_size)
for old_path, pattern in existing_files:
# Get the file extension like .safetensors or .metadata.json
ext = self.get_multipart_ext(pattern)
ext = ModelRouteUtils.get_multipart_ext(pattern)
# Create the new path
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
@@ -1117,7 +964,7 @@ class ApiRoutes:
# Update preview_url if it exists
if 'preview_url' in metadata and metadata['preview_url']:
old_preview = metadata['preview_url']
ext = self.get_multipart_ext(old_preview)
ext = ModelRouteUtils.get_multipart_ext(old_preview)
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
metadata['preview_url'] = new_preview
@@ -1127,11 +974,11 @@ class ApiRoutes:
# Update the scanner cache
if metadata:
await self.scanner.update_single_lora_cache(file_path, new_file_path, metadata)
await self.scanner.update_single_model_cache(file_path, new_file_path, metadata)
# Update recipe files and cache if hash is available
if hash_value:
recipe_scanner = RecipeScanner(self.scanner)
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name)
logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed LoRA")

View File

@@ -1,37 +1,483 @@
import os
from aiohttp import web
import json
import jinja2
from aiohttp import web
import logging
import asyncio
from ..utils.routes_common import ModelRouteUtils
from ..utils.constants import NSFW_LEVELS
from ..services.websocket_manager import ws_manager
from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
from ..utils.utils import fuzzy_match
logger = logging.getLogger(__name__)
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
class CheckpointsRoutes:
"""Route handlers for Checkpoints management endpoints"""
"""API routes for checkpoint management"""
def __init__(self):
self.scanner = None # Will be initialized in setup_routes
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
self.download_manager = None # Will be initialized in setup_routes
self._download_lock = asyncio.Lock()
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
self.download_manager = await ServiceRegistry.get_download_manager()
def setup_routes(self, app):
"""Register routes with the aiohttp app"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
app.router.add_get('/api/checkpoints', self.get_checkpoints)
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route
# Add new routes for model management similar to LoRA routes
app.router.add_post('/api/checkpoints/delete', self.delete_model)
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
# Add new WebSocket endpoint for checkpoint progress
app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)
async def get_checkpoints(self, request):
"""Get paginated checkpoint data"""
try:
# Parse query parameters
page = int(request.query.get('page', '1'))
page_size = min(int(request.query.get('page_size', '20')), 100)
sort_by = request.query.get('sort_by', 'name')
folder = request.query.get('folder', None)
search = request.query.get('search', None)
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
base_models = request.query.getall('base_model', [])
tags = request.query.getall('tag', [])
# Process search options
search_options = {
'filename': request.query.get('search_filename', 'true').lower() == 'true',
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
'tags': request.query.get('search_tags', 'false').lower() == 'true',
'recursive': request.query.get('recursive', 'false').lower() == 'true',
}
# Process hash filters if provided
hash_filters = {}
if 'hash' in request.query:
hash_filters['single_hash'] = request.query['hash']
elif 'hashes' in request.query:
try:
hash_list = json.loads(request.query['hashes'])
if isinstance(hash_list, list):
hash_filters['multiple_hashes'] = hash_list
except (json.JSONDecodeError, TypeError):
pass
# Get data from scanner
result = await self.get_paginated_data(
page=page,
page_size=page_size,
sort_by=sort_by,
folder=folder,
search=search,
fuzzy_search=fuzzy_search,
base_models=base_models,
tags=tags,
search_options=search_options,
hash_filters=hash_filters
)
# Format response items
formatted_result = {
'items': [self._format_checkpoint_response(cp) for cp in result['items']],
'total': result['total'],
'page': result['page'],
'page_size': result['page_size'],
'total_pages': result['total_pages']
}
# Return as JSON
return web.json_response(formatted_result)
except Exception as e:
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_paginated_data(self, page, page_size, sort_by='name',
folder=None, search=None, fuzzy_search=False,
base_models=None, tags=None,
search_options=None, hash_filters=None):
"""Get paginated and filtered checkpoint data"""
cache = await self.scanner.get_cached_data()
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False,
'recursive': False,
}
# Get the base data set
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply hash filtering if provided (highest priority)
if hash_filters:
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower() # Ensure lowercase for matching
filtered_data = [
cp for cp in filtered_data
if cp.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
filtered_data = [
cp for cp in filtered_data
if cp.get('sha256', '').lower() in hash_set
]
# Jump to pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
# Apply SFW filtering if enabled in settings
if settings.get('show_only_sfw', False):
filtered_data = [
cp for cp in filtered_data
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply folder filtering
if folder is not None:
if search_options.get('recursive', False):
# Recursive folder filtering - include all subfolders
filtered_data = [
cp for cp in filtered_data
if cp['folder'].startswith(folder)
]
else:
# Exact folder filtering
filtered_data = [
cp for cp in filtered_data
if cp['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
filtered_data = [
cp for cp in filtered_data
if cp.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
filtered_data = [
cp for cp in filtered_data
if any(tag in cp.get('tags', []) for tag in tags)
]
# Apply search filtering
if search:
search_results = []
for cp in filtered_data:
# Search by file name
if search_options.get('filename', True):
if fuzzy_search:
if fuzzy_match(cp.get('file_name', ''), search):
search_results.append(cp)
continue
elif search.lower() in cp.get('file_name', '').lower():
search_results.append(cp)
continue
# Search by model name
if search_options.get('modelname', True):
if fuzzy_search:
if fuzzy_match(cp.get('model_name', ''), search):
search_results.append(cp)
continue
elif search.lower() in cp.get('model_name', '').lower():
search_results.append(cp)
continue
# Search by tags
if search_options.get('tags', False) and 'tags' in cp:
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']):
search_results.append(cp)
continue
filtered_data = search_results
# Calculate pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
def _format_checkpoint_response(self, checkpoint):
"""Format checkpoint data for API response"""
return {
"model_name": checkpoint["model_name"],
"file_name": checkpoint["file_name"],
"preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")),
"preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0),
"base_model": checkpoint.get("base_model", ""),
"folder": checkpoint["folder"],
"sha256": checkpoint.get("sha256", ""),
"file_path": checkpoint["file_path"].replace(os.sep, "/"),
"file_size": checkpoint.get("size", 0),
"modified": checkpoint.get("modified", ""),
"tags": checkpoint.get("tags", []),
"modelDescription": checkpoint.get("modelDescription", ""),
"from_civitai": checkpoint.get("from_civitai", True),
"notes": checkpoint.get("notes", ""),
"model_type": checkpoint.get("model_type", "checkpoint"),
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
}
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all checkpoints in the background"""
try:
cache = await self.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
# Prepare checkpoints to process
to_process = [
cp for cp in cache.raw_data
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
]
total_to_process = len(to_process)
# Send initial progress
await ws_manager.broadcast({
'status': 'started',
'total': total_to_process,
'processed': 0,
'success': 0
})
# Process each checkpoint
for cp in to_process:
try:
original_name = cp.get('model_name')
if await ModelRouteUtils.fetch_and_update_model(
sha256=cp['sha256'],
file_path=cp['file_path'],
model_data=cp,
update_cache_func=self.scanner.update_single_model_cache
):
success += 1
if original_name != cp.get('model_name'):
needs_resort = True
processed += 1
# Send progress update
await ws_manager.broadcast({
'status': 'processing',
'total': total_to_process,
'processed': processed,
'success': success,
'current_name': cp.get('model_name', 'Unknown')
})
except Exception as e:
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
if needs_resort:
await cache.resort(name_only=True)
# Send completion message
await ws_manager.broadcast({
'status': 'completed',
'total': total_to_process,
'processed': processed,
'success': success
})
return web.json_response({
"success": True,
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
})
except Exception as e:
# Send error message
await ws_manager.broadcast({
'status': 'error',
'error': str(e)
})
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
return web.Response(text=str(e), status=500)
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
# Parse query parameters
limit = int(request.query.get('limit', '20'))
# Validate limit
if limit < 1 or limit > 100:
limit = 20 # Default to a reasonable limit
# Get top tags
top_tags = await self.scanner.get_top_tags(limit)
return web.json_response({
'success': True,
'tags': top_tags
})
except Exception as e:
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
return web.json_response({
'success': False,
'error': 'Internal server error'
}, status=500)
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in loras"""
try:
# Parse query parameters
limit = int(request.query.get('limit', '20'))
# Validate limit
if limit < 1 or limit > 100:
limit = 20 # Default to a reasonable limit
# Get base models
base_models = await self.scanner.get_base_models(limit)
return web.json_response({
'success': True,
'base_models': base_models
})
except Exception as e:
logger.error(f"Error retrieving base models: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def scan_checkpoints(self, request):
"""Force a rescan of checkpoint files"""
try:
await self.scanner.get_cached_data(force_refresh=True)
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
except Exception as e:
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_checkpoint_info(self, request):
"""Get detailed information for a specific checkpoint by name"""
try:
name = request.match_info.get('name', '')
checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name)
if checkpoint_info:
return web.json_response(checkpoint_info)
else:
return web.json_response({"error": "Checkpoint not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
"""Handle GET /checkpoints request"""
try:
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
is_initializing=False,
settings=settings,
request=request
# Check if the CheckpointScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = (
self.scanner._cache is None or
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
)
if is_initializing:
# If still initializing, return loading page
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[], # 空文件夹列表
is_initializing=True, # 新增标志
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
logger.info("Checkpoints page is initializing, returning loading page")
else:
# 正常流程 - 获取已经初始化好的缓存数据
try:
cache = await self.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
except Exception as cache_error:
logger.error(f"Error loading checkpoints cache data: {cache_error}")
# 如果获取缓存失败,也显示初始化页面
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Checkpoints cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
return web.Response(
@@ -39,6 +485,194 @@ class CheckpointsRoutes:
status=500
)
def setup_routes(self, app: web.Application):
"""Register routes with the application"""
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
async def delete_model(self, request: web.Request) -> web.Response:
"""Handle checkpoint model deletion request"""
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata fetch request for checkpoints"""
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement for checkpoints"""
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def download_checkpoint(self, request: web.Request) -> web.Response:
"""Handle checkpoint download request"""
async with self._download_lock:
# Get the download manager from service registry if not already initialized
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
try:
data = await request.json()
# Create progress callback that uses checkpoint-specific WebSocket
async def progress_callback(progress):
await ws_manager.broadcast_checkpoint_progress({
'status': 'progress',
'progress': progress
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_version_id = data.get('model_version_id')
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
)
result = await self.download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_version_id=model_version_id,
save_dir=data.get('checkpoint_root'),
relative_path=data.get('relative_path', ''),
progress_callback=progress_callback,
model_type="checkpoint"
)
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')
# Return 401 for early access errors
if 'early access' in error_message.lower():
logger.warning(f"Early access download failed: {error_message}")
return web.Response(
status=401,
text=f"Early Access Restriction: {error_message}"
)
return web.Response(status=500, text=error_message)
return web.json_response(result)
except Exception as e:
error_message = str(e)
# Check if this might be an early access error
if '401' in error_message:
logger.warning(f"Early access error (401): {error_message}")
return web.Response(
status=401,
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
)
logger.error(f"Error downloading checkpoint: {error_message}")
return web.Response(status=500, text=error_message)
async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
roots = self.scanner.get_model_roots()
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def save_metadata(self, request: web.Request) -> web.Response:
"""Handle saving metadata updates for checkpoints"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
data = await request.json()
file_path = data.get('file_path')
if not file_path:
return web.Response(text='File path is required', status=400)
# Remove file path from data to avoid saving it
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
# Get metadata file path
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load existing metadata
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Update metadata
metadata.update(metadata_updates)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
# Update cache
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
# If model_name was updated, resort the cache
if 'model_name' in metadata_updates:
cache = await self.scanner.get_cached_data()
await cache.resort(name_only=True)
return web.json_response({'success': True})
except Exception as e:
logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai checkpoint model with local availability info"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
# Get the civitai client from service registry
civitai_client = await ServiceRegistry.get_civitai_client()
model_id = request.match_info['model_id']
response = await civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be Checkpoint
if model_type.lower() != 'checkpoint':
return web.json_response({
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.scanner.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.scanner.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching checkpoint model versions: {e}")
return web.Response(status=500, text=str(e))

View File

@@ -1,12 +1,11 @@
import os
from aiohttp import web
import jinja2
from typing import Dict, List
from typing import Dict
import logging
from ..services.lora_scanner import LoraScanner
from ..services.recipe_scanner import RecipeScanner
from ..config import config
from ..services.settings_manager import settings # Add this import
from ..services.settings_manager import settings
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
logger = logging.getLogger(__name__)
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
@@ -15,13 +14,19 @@ class LoraRoutes:
"""Route handlers for LoRA management endpoints"""
def __init__(self):
self.scanner = LoraScanner()
self.recipe_scanner = RecipeScanner(self.scanner)
# Initialize service references as None, will be set during async init
self.scanner = None
self.recipe_scanner = None
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
async def init_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_lora_scanner()
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
def format_lora_data(self, lora: Dict) -> Dict:
"""Format LoRA data for template rendering"""
return {
@@ -58,41 +63,41 @@ class LoraRoutes:
async def handle_loras_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras request"""
try:
# 检查缓存初始化状态,增强判断条件
# Ensure services are initialized
await self.init_services()
# Check if the LoraScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = (
self.scanner._cache is None or
(self.scanner._initialization_task is not None and
not self.scanner._initialization_task.done()) or
(self.scanner._cache is not None and len(self.scanner._cache.raw_data) == 0 and
self.scanner._initialization_task is not None)
self.scanner._cache is None or
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
)
if is_initializing:
# 如果正在初始化,返回一个只包含加载提示的页面
# If still initializing, return loading page
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[], # 空文件夹列表
is_initializing=True, # 新增标志
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Loras page is initializing, returning loading page")
else:
# 正常流程 - 但不要等待缓存刷新
# Normal flow - get data from initialized cache
try:
cache = await self.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
settings=settings,
request=request
)
logger.debug(f"Loras page loaded successfully with {len(cache.raw_data)} items")
except Exception as cache_error:
logger.error(f"Error loading cache data: {cache_error}")
# 如果获取缓存失败,也显示初始化页面
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[],
@@ -117,32 +122,30 @@ class LoraRoutes:
async def handle_recipes_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras/recipes request"""
try:
# Check cache initialization status
is_initializing = (
self.recipe_scanner._cache is None and
(self.recipe_scanner._initialization_task is not None and
not self.recipe_scanner._initialization_task.done())
)
if is_initializing:
# If initializing, return a loading page
# Ensure services are initialized
await self.init_services()
# Skip initialization check and directly try to get cached data
try:
# Recipe scanner will initialize cache if needed
await self.recipe_scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('recipes.html')
rendered = template.render(
recipes=[], # Frontend will load recipes via API
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading recipe cache data: {cache_error}")
# Still keep error handling - show initializing page on error
template = self.template_env.get_template('recipes.html')
rendered = template.render(
is_initializing=True,
settings=settings,
request=request # Pass the request object to the template
)
else:
# return empty recipes
recipes_data = []
template = self.template_env.get_template('recipes.html')
rendered = template.render(
recipes=recipes_data,
is_initializing=False,
settings=settings,
request=request # Pass the request object to the template
request=request
)
logger.info("Recipe cache error, returning initialization page")
return web.Response(
text=rendered,
@@ -174,5 +177,13 @@ class LoraRoutes:
def setup_routes(self, app: web.Application):
"""Register routes with the application"""
# Add an app startup handler to initialize services
app.on_startup.append(self._on_startup)
# Register routes
app.router.add_get('/loras', self.handle_loras_page)
app.router.add_get('/loras/recipes', self.handle_recipes_page)
async def _on_startup(self, app):
"""Initialize services when the app starts"""
await self.init_services()

View File

@@ -8,13 +8,12 @@ import json
import asyncio
from ..utils.exif_utils import ExifUtils
from ..utils.recipe_parsers import RecipeParserFactory
from ..services.civitai_client import CivitaiClient
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..services.recipe_scanner import RecipeScanner
from ..services.lora_scanner import LoraScanner
from ..config import config
from ..workflow.parser import WorkflowParser
from ..utils.utils import download_civitai_image
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
logger = logging.getLogger(__name__)
@@ -22,13 +21,19 @@ class RecipeRoutes:
"""API route handlers for Recipe management"""
def __init__(self):
self.recipe_scanner = RecipeScanner(LoraScanner())
self.civitai_client = CivitaiClient()
# Initialize service references as None, will be set during async init
self.recipe_scanner = None
self.civitai_client = None
self.parser = WorkflowParser()
# Pre-warm the cache
self._init_cache_task = None
async def init_services(self):
"""Initialize services from ServiceRegistry"""
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
self.civitai_client = await ServiceRegistry.get_civitai_client()
@classmethod
def setup_routes(cls, app: web.Application):
"""Register API routes"""
@@ -60,11 +65,17 @@ class RecipeRoutes:
app.on_startup.append(routes._init_cache)
app.router.add_post('/api/recipes/save-from-widget', routes.save_recipe_from_widget)
# Add route to get recipes for a specific Lora
app.router.add_get('/api/recipes/for-lora', routes.get_recipes_for_lora)
async def _init_cache(self, app):
"""Initialize cache on startup"""
try:
# First, ensure the lora scanner is fully initialized
# Initialize services first
await self.init_services()
# Now that services are initialized, get the lora scanner
lora_scanner = self.recipe_scanner._lora_scanner
# Get lora cache to ensure it's initialized
@@ -82,6 +93,9 @@ class RecipeRoutes:
async def get_recipes(self, request: web.Request) -> web.Response:
"""API endpoint for getting paginated recipes"""
try:
# Ensure services are initialized
await self.init_services()
# Get query parameters with defaults
page = int(request.query.get('page', '1'))
page_size = int(request.query.get('page_size', '20'))
@@ -98,6 +112,9 @@ class RecipeRoutes:
base_models = request.query.get('base_models', None)
tags = request.query.get('tags', None)
# New parameter: get LoRA hash filter
lora_hash = request.query.get('lora_hash', None)
# Parse filter parameters
filters = {}
if base_models:
@@ -113,14 +130,15 @@ class RecipeRoutes:
'lora_model': search_lora_model
}
# Get paginated data
# Get paginated data with the new lora_hash parameter
result = await self.recipe_scanner.get_paginated_data(
page=page,
page_size=page_size,
sort_by=sort_by,
search=search,
filters=filters,
search_options=search_options
search_options=search_options,
lora_hash=lora_hash
)
# Format the response data with static URLs for file paths
@@ -147,21 +165,18 @@ class RecipeRoutes:
async def get_recipe_detail(self, request: web.Request) -> web.Response:
"""Get detailed information about a specific recipe"""
try:
recipe_id = request.match_info['recipe_id']
# Get all recipes from cache
cache = await self.recipe_scanner.get_cached_data()
# Ensure services are initialized
await self.init_services()
# Find the specific recipe
recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None)
recipe_id = request.match_info['recipe_id']
# Use the new get_recipe_by_id method from recipe_scanner
recipe = await self.recipe_scanner.get_recipe_by_id(recipe_id)
if not recipe:
return web.json_response({"error": "Recipe not found"}, status=404)
# Format recipe data
formatted_recipe = self._format_recipe_data(recipe)
return web.json_response(formatted_recipe)
return web.json_response(recipe)
except Exception as e:
logger.error(f"Error retrieving recipe details: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
@@ -206,6 +221,9 @@ class RecipeRoutes:
"""Analyze an uploaded image or URL for recipe metadata"""
temp_path = None
try:
# Ensure services are initialized
await self.init_services()
# Check if request contains multipart data (image) or JSON data (url)
content_type = request.headers.get('Content-Type', '')
@@ -324,6 +342,9 @@ class RecipeRoutes:
async def save_recipe(self, request: web.Request) -> web.Response:
"""Save a recipe to the recipes folder"""
try:
# Ensure services are initialized
await self.init_services()
reader = await request.multipart()
# Process form data
@@ -423,7 +444,7 @@ class RecipeRoutes:
# Optimize the image (resize and convert to WebP)
optimized_image, extension = ExifUtils.optimize_image(
image_data=image,
target_width=480,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
@@ -525,6 +546,9 @@ class RecipeRoutes:
async def delete_recipe(self, request: web.Request) -> web.Response:
"""Delete a recipe by ID"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Get recipes directory
@@ -572,6 +596,9 @@ class RecipeRoutes:
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Get top tags used in recipes"""
try:
# Ensure services are initialized
await self.init_services()
# Get limit parameter with default
limit = int(request.query.get('limit', '20'))
@@ -604,6 +631,9 @@ class RecipeRoutes:
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in recipes"""
try:
# Ensure services are initialized
await self.init_services()
# Get all recipes from cache
cache = await self.recipe_scanner.get_cached_data()
@@ -632,6 +662,9 @@ class RecipeRoutes:
async def share_recipe(self, request: web.Request) -> web.Response:
"""Process a recipe image for sharing by adding metadata to EXIF"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Get all recipes from cache
@@ -691,6 +724,9 @@ class RecipeRoutes:
async def download_shared_recipe(self, request: web.Request) -> web.Response:
"""Serve a processed recipe image for download"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Check if we have this shared recipe
@@ -747,6 +783,9 @@ class RecipeRoutes:
async def save_recipe_from_widget(self, request: web.Request) -> web.Response:
"""Save a recipe from the LoRAs widget"""
try:
# Ensure services are initialized
await self.init_services()
reader = await request.multipart()
# Process form data
@@ -827,7 +866,7 @@ class RecipeRoutes:
# Optimize the image (resize and convert to WebP)
optimized_image, extension = ExifUtils.optimize_image(
image_data=image,
target_width=480,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
@@ -921,6 +960,9 @@ class RecipeRoutes:
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
"""Generate recipe syntax for LoRAs in the recipe, looking up proper file names using hash_index"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Get all recipes from cache
@@ -1001,6 +1043,9 @@ class RecipeRoutes:
async def update_recipe(self, request: web.Request) -> web.Response:
"""Update recipe metadata (name and tags)"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
data = await request.json()
@@ -1028,6 +1073,9 @@ class RecipeRoutes:
async def reconnect_lora(self, request: web.Request) -> web.Response:
"""Reconnect a deleted LoRA in a recipe to a local LoRA file"""
try:
# Ensure services are initialized
await self.init_services()
# Parse request data
data = await request.json()
@@ -1134,3 +1182,52 @@ class RecipeRoutes:
except Exception as e:
logger.error(f"Error reconnecting LoRA: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
"""Get recipes that use a specific Lora"""
try:
# Ensure services are initialized
await self.init_services()
lora_hash = request.query.get('hash')
# Hash is required
if not lora_hash:
return web.json_response({'success': False, 'error': 'Lora hash is required'}, status=400)
# Log the search parameters
logger.debug(f"Getting recipes for Lora by hash: {lora_hash}")
# Get all recipes from cache
cache = await self.recipe_scanner.get_cached_data()
# Filter recipes that use this Lora by hash
matching_recipes = []
for recipe in cache.raw_data:
# Check if any of the recipe's loras match this hash
loras = recipe.get('loras', [])
for lora in loras:
if lora.get('hash', '').lower() == lora_hash.lower():
matching_recipes.append(recipe)
break # No need to check other loras in this recipe
# Process the recipes similar to get_paginated_data to ensure all needed data is available
for recipe in matching_recipes:
# Add inLibrary information for each lora
if 'loras' in recipe:
for lora in recipe['loras']:
if 'hash' in lora and lora['hash']:
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_lora_hash(lora['hash'].lower())
lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
lora['localPath'] = self.recipe_scanner._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
# Ensure file_url is set (needed by frontend)
if 'file_path' in recipe:
recipe['file_url'] = self._format_recipe_file_url(recipe['file_path'])
else:
recipe['file_url'] = '/loras_static/images/no-preview.png'
return web.json_response({'success': True, 'recipes': matching_recipes})
except Exception as e:
logger.error(f"Error getting recipes for Lora: {str(e)}")
return web.json_response({'success': False, 'error': str(e)}, status=500)

View File

@@ -0,0 +1,131 @@
import os
import logging
import asyncio
from typing import List, Dict, Optional, Set
import folder_paths # type: ignore
from ..utils.models import CheckpointMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class CheckpointScanner(ModelScanner):
"""Service for scanning and managing checkpoint files"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not hasattr(self, '_initialized'):
# Define supported file extensions
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
super().__init__(
model_type="checkpoint",
model_class=CheckpointMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
self._checkpoint_roots = self._init_checkpoint_roots()
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def _init_checkpoint_roots(self) -> List[str]:
"""Initialize checkpoint roots from ComfyUI settings"""
# Get both checkpoint and diffusion_models paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
# Combine, normalize and deduplicate paths
all_paths = set()
for path in checkpoint_paths + diffusion_paths:
if os.path.exists(path):
norm_path = path.replace(os.sep, "/")
all_paths.add(norm_path)
# Sort for consistent order
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
return sorted_paths
def get_model_roots(self) -> List[str]:
"""Get checkpoint root directories"""
return self._checkpoint_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all checkpoint directories and return metadata"""
all_checkpoints = []
# Create scan tasks for each directory
scan_tasks = []
for root in self._checkpoint_roots:
task = asyncio.create_task(self._scan_directory(root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
checkpoints = await task
all_checkpoints.extend(checkpoints)
except Exception as e:
logger.error(f"Error scanning checkpoint directory: {e}")
return all_checkpoints
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a directory for checkpoint files"""
checkpoints = []
original_root = root_path
async def scan_recursive(path: str, visited_paths: set):
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
# Check if file has supported extension
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, checkpoints)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return checkpoints
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
"""Process a single checkpoint file and add to results"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
checkpoints.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")

View File

@@ -3,6 +3,7 @@ import aiohttp
import os
import json
import logging
import asyncio
from email.parser import Parser
from typing import Optional, Dict, Tuple, List
from urllib.parse import unquote
@@ -11,20 +12,51 @@ from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
class CivitaiClient:
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls):
"""Get singleton instance of CivitaiClient"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
# Check if already initialized for singleton pattern
if hasattr(self, '_initialized'):
return
self._initialized = True
self.base_url = "https://civitai.com/api/v1"
self.headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
}
self._session = None
# Set default buffer size to 1MB for higher throughput
self.chunk_size = 1024 * 1024
@property
async def session(self) -> aiohttp.ClientSession:
"""Lazy initialize the session"""
if self._session is None:
connector = aiohttp.TCPConnector(ssl=True)
trust_env = True # 允许使用系统环境变量中的代理设置
self._session = aiohttp.ClientSession(connector=connector, trust_env=trust_env)
# Optimize TCP connection parameters
connector = aiohttp.TCPConnector(
ssl=True,
limit=10, # Increase parallel connections
ttl_dns_cache=300, # DNS cache time
force_close=False, # Keep connections for reuse
enable_cleanup_closed=True
)
trust_env = True # Allow using system environment proxy settings
# Configure timeout parameters
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
self._session = aiohttp.ClientSession(
connector=connector,
trust_env=trust_env,
timeout=timeout
)
return self._session
def _parse_content_disposition(self, header: str) -> str:
@@ -74,6 +106,10 @@ class CivitaiClient:
session = await self.session
try:
headers = self._get_request_headers()
# Add Range header to allow resumable downloads
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
async with session.get(url, headers=headers, allow_redirects=True) as response:
if response.status != 200:
# Handle 401 unauthorized responses
@@ -101,16 +137,23 @@ class CivitaiClient:
# Get total file size for progress calculation
total_size = int(response.headers.get('content-length', 0))
current_size = 0
last_progress_report_time = datetime.now()
# Stream download to file with progress updates
# Stream download to file with progress updates using larger buffer
with open(save_path, 'wb') as f:
async for chunk in response.content.iter_chunked(8192):
async for chunk in response.content.iter_chunked(self.chunk_size):
if chunk:
f.write(chunk)
current_size += len(chunk)
if progress_callback and total_size:
# Limit progress update frequency to reduce overhead
now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and total_size and time_diff >= 0.5:
progress = (current_size / total_size) * 100
await progress_callback(progress)
last_progress_report_time = now
# Ensure 100% progress is reported
if progress_callback:
@@ -118,6 +161,9 @@ class CivitaiClient:
return True, save_path
except aiohttp.ClientError as e:
logger.error(f"Network error during download: {e}")
return False, f"Network error: {str(e)}"
except Exception as e:
logger.error(f"Download error: {e}")
return False, str(e)
@@ -155,7 +201,11 @@ class CivitaiClient:
if response.status != 200:
return None
data = await response.json()
return data.get('modelVersions', [])
# Also return model type along with versions
return {
'modelVersions': data.get('modelVersions', []),
'type': data.get('type', '')
}
except Exception as e:
logger.error(f"Error fetching model versions: {e}")
return None

View File

@@ -1,21 +1,79 @@
import logging
import os
import json
from typing import Optional, Dict
import asyncio
from typing import Optional, Dict, Any
from .civitai_client import CivitaiClient
from .file_monitor import LoraFileMonitor
from ..utils.models import LoraMetadata
from ..utils.models import LoraMetadata, CheckpointMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from .service_registry import ServiceRegistry
# Download to temporary file first
import tempfile
logger = logging.getLogger(__name__)
class DownloadManager:
def __init__(self, file_monitor: Optional[LoraFileMonitor] = None):
self.civitai_client = CivitaiClient()
self.file_monitor = file_monitor
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls):
"""Get singleton instance of DownloadManager"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
# Check if already initialized for singleton pattern
if hasattr(self, '_initialized'):
return
self._initialized = True
self._civitai_client = None # Will be lazily initialized
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def _get_lora_monitor(self):
"""Get the lora file monitor from registry"""
return await ServiceRegistry.get_lora_monitor()
async def _get_checkpoint_monitor(self):
"""Get the checkpoint file monitor from registry"""
return await ServiceRegistry.get_checkpoint_monitor()
async def _get_lora_scanner(self):
"""Get the lora scanner from registry"""
return await ServiceRegistry.get_lora_scanner()
async def _get_checkpoint_scanner(self):
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
model_version_id: str = None, save_dir: str = None,
relative_path: str = '', progress_callback=None) -> Dict:
relative_path: str = '', progress_callback=None,
model_type: str = "lora") -> Dict:
"""Download model from Civitai
Args:
download_url: Direct download URL for the model
model_hash: SHA256 hash of the model
model_version_id: Civitai model version ID
save_dir: Directory to save the model to
relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates
model_type: Type of model ('lora' or 'checkpoint')
Returns:
Dict with download result
"""
try:
# Update save directory with relative path if provided
if relative_path:
@@ -23,25 +81,28 @@ class DownloadManager:
# Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Get civitai client
civitai_client = await self._get_civitai_client()
# Get version info based on the provided identifier
version_info = None
if download_url:
# Extract version ID from download URL
version_id = download_url.split('/')[-1]
version_info = await self.civitai_client.get_model_version_info(version_id)
version_info = await civitai_client.get_model_version_info(version_id)
elif model_version_id:
# Use model version ID directly
version_info = await self.civitai_client.get_model_version_info(model_version_id)
version_info = await civitai_client.get_model_version_info(model_version_id)
elif model_hash:
# Get model by hash
version_info = await self.civitai_client.get_model_by_hash(model_hash)
version_info = await civitai_client.get_model_by_hash(model_hash)
if not version_info:
return {'success': False, 'error': 'Failed to fetch model metadata'}
# Check if this is an early access LoRA
# Check if this is an early access model
if version_info.get('earlyAccessEndsAt'):
early_access_date = version_info.get('earlyAccessEndsAt', '')
# Convert to a readable date if possible
@@ -49,12 +110,12 @@ class DownloadManager:
from datetime import datetime
date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00'))
formatted_date = date_obj.strftime('%Y-%m-%d')
early_access_msg = f"This LoRA requires early access payment (until {formatted_date}). "
early_access_msg = f"This model requires early access payment (until {formatted_date}). "
except:
early_access_msg = "This LoRA requires early access payment. "
early_access_msg = "This model requires early access payment. "
early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai."
logger.warning(f"Early access LoRA detected: {version_info.get('name', 'Unknown')}")
logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}")
# We'll still try to download, but log a warning and prepare for potential failure
if progress_callback:
@@ -64,43 +125,51 @@ class DownloadManager:
if progress_callback:
await progress_callback(0)
# 2. 获取文件信息
# 2. Get file information
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if not file_info:
return {'success': False, 'error': 'No primary file found in metadata'}
# 3. 准备下载
# 3. Prepare download
file_name = file_info['name']
save_path = os.path.join(save_dir, file_name)
file_size = file_info.get('sizeKB', 0) * 1024
# 4. 通知文件监控系统 - 使用规范化路径和文件大小
self.file_monitor.handler.add_ignore_path(
save_path.replace(os.sep, '/'),
file_size
)
# 4. Notify file monitor - use normalized path and file size
file_monitor = await self._get_lora_monitor() if model_type == "lora" else await self._get_checkpoint_monitor()
if file_monitor and file_monitor.handler:
file_monitor.handler.add_ignore_path(
save_path.replace(os.sep, '/'),
file_size
)
# 5. 准备元数据
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
# 5. Prepare metadata based on model type
if model_type == "checkpoint":
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating CheckpointMetadata for {file_name}")
else:
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating LoraMetadata for {file_name}")
# 5.1 获取并更新模型标签和描述信息
# 5.1 Get and update model tags and description
model_id = version_info.get('modelId')
if model_id:
model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id))
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
if model_metadata:
if model_metadata.get("tags"):
metadata.tags = model_metadata.get("tags", [])
if model_metadata.get("description"):
metadata.modelDescription = model_metadata.get("description", "")
# 6. 开始下载流程
# 6. Start download process
result = await self._execute_download(
download_url=file_info.get('downloadUrl', ''),
save_dir=save_dir,
metadata=metadata,
version_info=version_info,
relative_path=relative_path,
progress_callback=progress_callback
progress_callback=progress_callback,
model_type=model_type
)
return result
@@ -114,10 +183,12 @@ class DownloadManager:
return {'success': False, 'error': str(e)}
async def _execute_download(self, download_url: str, save_dir: str,
metadata: LoraMetadata, version_info: Dict,
relative_path: str, progress_callback=None) -> Dict:
metadata, version_info: Dict,
relative_path: str, progress_callback=None,
model_type: str = "lora") -> Dict:
"""Execute the actual download process including preview images and model files"""
try:
civitai_client = await self._get_civitai_client()
save_path = metadata.file_path
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
@@ -128,20 +199,61 @@ class DownloadManager:
if progress_callback:
await progress_callback(1) # 1% progress for starting preview download
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# Check if it's a video or an image
is_video = images[0].get('type') == 'video'
if is_video:
# For videos, use .mp4 extension
preview_ext = '.mp4'
preview_path = os.path.splitext(save_path)[0] + preview_ext
# Download video directly
if await civitai_client.download_preview_image(images[0]['url'], preview_path):
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
else:
# For images, use WebP format for better performance
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
temp_path = temp_file.name
# Download the original image to temp path
if await civitai_client.download_preview_image(images[0]['url'], temp_path):
# Optimize and convert to WebP
preview_path = os.path.splitext(save_path)[0] + '.webp'
# Use ExifUtils to optimize and convert the image
optimized_data, _ = ExifUtils.optimize_image(
image_data=temp_path,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
)
# Save the optimized image
with open(preview_path, 'wb') as f:
f.write(optimized_data)
# Update metadata
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# Remove temporary file
try:
os.unlink(temp_path)
except Exception as e:
logger.warning(f"Failed to delete temp file: {e}")
# Report preview download completion
if progress_callback:
await progress_callback(3) # 3% progress after preview download
# Download model file with progress tracking
success, result = await self.civitai_client._download_file(
success, result = await civitai_client._download_file(
download_url,
save_dir,
os.path.basename(save_path),
@@ -155,15 +267,22 @@ class DownloadManager:
os.remove(path)
return {'success': False, 'error': result}
# 4. 更新文件信息(大小和修改时间)
# 4. Update file information (size and modified time)
metadata.update_file_info(save_path)
# 5. 最终更新元数据
# 5. Final metadata update
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# 6. update lora cache
cache = await self.file_monitor.scanner.get_cached_data()
# 6. Update cache based on model type
if model_type == "checkpoint":
scanner = await self._get_checkpoint_scanner()
logger.info(f"Updating checkpoint cache for {save_path}")
else:
scanner = await self._get_lora_scanner()
logger.info(f"Updating lora cache for {save_path}")
cache = await scanner.get_cached_data()
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = relative_path
cache.raw_data.append(metadata_dict)
@@ -172,11 +291,8 @@ class DownloadManager:
all_folders.add(relative_path)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update the hash index with the new LoRA entry
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Update the hash index with the new LoRA entry
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Update the hash index with the new model entry
scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Report 100% completion
if progress_callback:

View File

@@ -1,37 +1,42 @@
from operator import itemgetter
import os
import logging
import asyncio
import time
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from typing import List, Dict, Set
from typing import List, Dict, Set, Optional
from threading import Lock
from .lora_scanner import LoraScanner
from ..config import config
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class LoraFileHandler(FileSystemEventHandler):
"""Handler for LoRA file system events"""
# Configuration constant to control file monitoring functionality
ENABLE_FILE_MONITORING = False
class BaseFileHandler(FileSystemEventHandler):
"""Base handler for file system events"""
def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop):
self.scanner = scanner
self.loop = loop # 存储事件循环引用
self.pending_changes = set() # 待处理的变更
self.lock = Lock() # 线程安全锁
self.update_task = None # 异步更新任务
self._ignore_paths = set() # Add ignore paths set
self._min_ignore_timeout = 5 # minimum timeout in seconds
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
def __init__(self, loop: asyncio.AbstractEventLoop):
self.loop = loop # Store event loop reference
self.pending_changes = set() # Pending changes
self.lock = Lock() # Thread-safe lock
self.update_task = None # Async update task
self._ignore_paths = set() # Paths to ignore
self._min_ignore_timeout = 5 # Minimum timeout in seconds
self._download_speed = 1024 * 1024 # Assume 1MB/s as base speed
# Track modified files with timestamps for debouncing
self.modified_files: Dict[str, float] = {}
self.debounce_timer = None
self.debounce_delay = 3.0 # seconds to wait after last modification
self.debounce_delay = 3.0 # Seconds to wait after last modification
# Track files that are already scheduled for processing
# Track files already scheduled for processing
self.scheduled_files: Set[str] = set()
# File extensions to monitor - should be overridden by subclasses
self.file_extensions = set()
def _should_ignore(self, path: str) -> bool:
"""Check if path should be ignored"""
@@ -56,35 +61,33 @@ class LoraFileHandler(FileSystemEventHandler):
if event.is_directory:
return
# Handle safetensors files directly
if event.src_path.endswith('.safetensors'):
# Handle appropriate files based on extensions
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
# We'll process this file directly and ignore subsequent modifications
# to prevent duplicate processing
# Process this file directly and ignore subsequent modifications
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
if normalized_path not in self.scheduled_files:
logger.info(f"LoRA file created: {event.src_path}")
logger.info(f"File created: {event.src_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.src_path)
# Ignore modifications for a short period after creation
# This helps avoid duplicate processing
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
normalized_path
)
# For browser downloads, we'll catch them when they're renamed to .safetensors
def on_modified(self, event):
if event.is_directory:
return
# Only process safetensors files
if event.src_path.endswith('.safetensors'):
# Only process files with supported extensions
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
@@ -132,12 +135,17 @@ class LoraFileHandler(FileSystemEventHandler):
# Process stable files
for file_path in files_to_process:
logger.info(f"Processing modified LoRA file: {file_path}")
logger.info(f"Processing modified file: {file_path}")
self._schedule_update('add', file_path)
def on_deleted(self, event):
if event.is_directory or not event.src_path.endswith('.safetensors'):
if event.is_directory:
return
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext not in self.file_extensions:
return
if self._should_ignore(event.src_path):
return
@@ -145,14 +153,17 @@ class LoraFileHandler(FileSystemEventHandler):
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
self.scheduled_files.discard(normalized_path)
logger.info(f"LoRA file deleted: {event.src_path}")
logger.info(f"File deleted: {event.src_path}")
self._schedule_update('remove', event.src_path)
def on_moved(self, event):
"""Handle file move/rename events"""
# If destination is a safetensors file, treat it as a new file
if event.dest_path.endswith('.safetensors'):
src_ext = os.path.splitext(event.src_path)[1].lower()
dest_ext = os.path.splitext(event.dest_path)[1].lower()
# If destination has supported extension, treat as new file
if dest_ext in self.file_extensions:
if self._should_ignore(event.dest_path):
return
@@ -160,7 +171,7 @@ class LoraFileHandler(FileSystemEventHandler):
# Only process if not already scheduled
if normalized_path not in self.scheduled_files:
logger.info(f"LoRA file renamed/moved to: {event.dest_path}")
logger.info(f"File renamed/moved to: {event.dest_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.dest_path)
@@ -171,21 +182,21 @@ class LoraFileHandler(FileSystemEventHandler):
normalized_path
)
# If source was a safetensors file, treat it as deleted
if event.src_path.endswith('.safetensors'):
# If source was a supported file, treat it as deleted
if src_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
self.scheduled_files.discard(normalized_path)
logger.info(f"LoRA file moved/renamed from: {event.src_path}")
logger.info(f"File moved/renamed from: {event.src_path}")
self._schedule_update('remove', event.src_path)
def _schedule_update(self, action: str, file_path: str): #file_path is a real path
def _schedule_update(self, action: str, file_path: str):
"""Schedule a cache update"""
with self.lock:
# 使用 config 中的方法映射路径
# Use config method to map path
mapped_path = config.map_path_to_link(file_path)
normalized_path = mapped_path.replace(os.sep, '/')
self.pending_changes.add((action, normalized_path))
@@ -196,7 +207,20 @@ class LoraFileHandler(FileSystemEventHandler):
"""Create update task in the event loop"""
if self.update_task is None or self.update_task.done():
self.update_task = asyncio.create_task(self._process_changes())
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing - should be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement _process_changes")
class LoraFileHandler(BaseFileHandler):
"""Handler for LoRA file system events"""
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
# Set supported file extensions for LoRAs
self.file_extensions = {'.safetensors'}
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing"""
await asyncio.sleep(delay)
@@ -209,9 +233,11 @@ class LoraFileHandler(FileSystemEventHandler):
if not changes:
return
logger.info(f"Processing {len(changes)} file changes")
logger.info(f"Processing {len(changes)} LoRA file changes")
cache = await self.scanner.get_cached_data()
# Get scanner through ServiceRegistry
scanner = await ServiceRegistry.get_lora_scanner()
cache = await scanner.get_cached_data()
needs_resort = False
new_folders = set()
@@ -225,36 +251,36 @@ class LoraFileHandler(FileSystemEventHandler):
continue
# Scan new file
lora_data = await self.scanner.scan_single_lora(file_path)
if lora_data:
model_data = await scanner.scan_single_model(file_path)
if model_data:
# Update tags count
for tag in lora_data.get('tags', []):
self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1
for tag in model_data.get('tags', []):
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(lora_data)
new_folders.add(lora_data['folder'])
cache.raw_data.append(model_data)
new_folders.add(model_data['folder'])
# Update hash index
if 'sha256' in lora_data:
self.scanner._hash_index.add_entry(
lora_data['sha256'],
lora_data['file_path']
if 'sha256' in model_data:
scanner._hash_index.add_entry(
model_data['sha256'],
model_data['file_path']
)
needs_resort = True
elif action == 'remove':
# Find the lora to remove so we can update tags count
lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if lora_to_remove:
# Find the model to remove so we can update tags count
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if model_to_remove:
# Update tags count by reducing counts
for tag in lora_to_remove.get('tags', []):
if tag in self.scanner._tags_count:
self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1)
if self.scanner._tags_count[tag] == 0:
del self.scanner._tags_count[tag]
for tag in model_to_remove.get('tags', []):
if tag in scanner._tags_count:
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
if scanner._tags_count[tag] == 0:
del scanner._tags_count[tag]
# Remove from cache and hash index
logger.info(f"Removing {file_path} from cache")
self.scanner._hash_index.remove_by_path(file_path)
scanner._hash_index.remove_by_path(file_path)
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
@@ -272,62 +298,245 @@ class LoraFileHandler(FileSystemEventHandler):
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
except Exception as e:
logger.error(f"Error in process_changes: {e}")
logger.error(f"Error in process_changes for LoRA: {e}")
class LoraFileMonitor:
"""Monitor for LoRA file changes"""
class CheckpointFileHandler(BaseFileHandler):
"""Handler for checkpoint file system events"""
def __init__(self, scanner: LoraScanner, roots: List[str]):
self.scanner = scanner
scanner.set_file_monitor(self)
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
# Set supported file extensions for checkpoints
self.file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing for checkpoint files"""
await asyncio.sleep(delay)
try:
with self.lock:
changes = self.pending_changes.copy()
self.pending_changes.clear()
if not changes:
return
logger.info(f"Processing {len(changes)} checkpoint file changes")
# Get scanner through ServiceRegistry
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
needs_resort = False
new_folders = set()
for action, file_path in changes:
try:
if action == 'add':
# Check if file already exists in cache
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if existing:
logger.info(f"File {file_path} already in cache, skipping")
continue
# Scan new file
model_data = await scanner.scan_single_model(file_path)
if model_data:
# Update tags count if applicable
for tag in model_data.get('tags', []):
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(model_data)
new_folders.add(model_data['folder'])
# Update hash index
if 'sha256' in model_data:
scanner._hash_index.add_entry(
model_data['sha256'],
model_data['file_path']
)
needs_resort = True
elif action == 'remove':
# Find the model to remove so we can update tags count
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if model_to_remove:
# Update tags count by reducing counts
for tag in model_to_remove.get('tags', []):
if tag in scanner._tags_count:
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
if scanner._tags_count[tag] == 0:
del scanner._tags_count[tag]
# Remove from cache and hash index
logger.info(f"Removing {file_path} from checkpoint cache")
scanner._hash_index.remove_by_path(file_path)
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
]
needs_resort = True
except Exception as e:
logger.error(f"Error processing checkpoint {action} for {file_path}: {e}")
if needs_resort:
await cache.resort()
# Update folder list
all_folders = set(cache.folders) | new_folders
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
except Exception as e:
logger.error(f"Error in process_changes for checkpoint: {e}")
class BaseFileMonitor:
"""Base class for file monitoring"""
def __init__(self, monitor_paths: List[str]):
self.observer = Observer()
self.loop = asyncio.get_event_loop()
self.handler = LoraFileHandler(scanner, self.loop)
# 使用已存在的路径映射
self.monitor_paths = set()
for root in roots:
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
# Process monitor paths
for path in monitor_paths:
self.monitor_paths.add(os.path.realpath(path).replace(os.sep, '/'))
# 添加所有已映射的目标路径
# Add mapped paths from config
for target_path in config._path_mappings.keys():
self.monitor_paths.add(target_path)
def start(self):
"""Start monitoring"""
for path_info in self.monitor_paths:
"""Start file monitoring"""
if not ENABLE_FILE_MONITORING:
logger.info("File monitoring is disabled via ENABLE_FILE_MONITORING setting")
return
for path in self.monitor_paths:
try:
if isinstance(path_info, tuple):
# 对于链接,监控目标路径
_, target_path = path_info
self.observer.schedule(self.handler, target_path, recursive=True)
logger.info(f"Started monitoring target path: {target_path}")
else:
# 对于普通路径,直接监控
self.observer.schedule(self.handler, path_info, recursive=True)
logger.info(f"Started monitoring: {path_info}")
self.observer.schedule(self.handler, path, recursive=True)
logger.info(f"Started monitoring: {path}")
except Exception as e:
logger.error(f"Error monitoring {path_info}: {e}")
logger.error(f"Error monitoring {path}: {e}")
self.observer.start()
def stop(self):
"""Stop monitoring"""
"""Stop file monitoring"""
if not ENABLE_FILE_MONITORING:
return
self.observer.stop()
self.observer.join()
def rescan_links(self):
"""重新扫描链接(当添加新的链接时调用)"""
"""Rescan links when new ones are added"""
if not ENABLE_FILE_MONITORING:
return
# Find new paths not yet being monitored
new_paths = set()
for path in self.monitor_paths.copy():
self._add_link_targets(path)
for path in config._path_mappings.keys():
real_path = os.path.realpath(path).replace(os.sep, '/')
if real_path not in self.monitor_paths:
new_paths.add(real_path)
self.monitor_paths.add(real_path)
# 添加新发现的路径到监控
new_paths = self.monitor_paths - set(self.observer.watches.keys())
# Add new paths to monitoring
for path in new_paths:
try:
self.observer.schedule(self.handler, path, recursive=True)
logger.info(f"Added new monitoring path: {path}")
except Exception as e:
logger.error(f"Error adding new monitor for {path}: {e}")
logger.error(f"Error adding new monitor for {path}: {e}")
class LoraFileMonitor(BaseFileMonitor):
"""Monitor for LoRA file changes"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls, monitor_paths=None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, monitor_paths=None):
if not hasattr(self, '_initialized'):
if monitor_paths is None:
from ..config import config
monitor_paths = config.loras_roots
super().__init__(monitor_paths)
self.handler = LoraFileHandler(self.loop)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
from ..config import config
cls._instance = cls(config.loras_roots)
return cls._instance
class CheckpointFileMonitor(BaseFileMonitor):
"""Monitor for checkpoint file changes"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls, monitor_paths=None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, monitor_paths=None):
if not hasattr(self, '_initialized'):
if monitor_paths is None:
# Get checkpoint roots from scanner
monitor_paths = []
# We'll initialize monitor paths later when scanner is available
super().__init__(monitor_paths or [])
self.handler = CheckpointFileHandler(self.loop)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls([])
# Now get checkpoint roots from scanner
from .checkpoint_scanner import CheckpointScanner
scanner = await CheckpointScanner.get_instance()
monitor_paths = scanner.get_model_roots()
# Update monitor paths - but don't actually monitor them
for path in monitor_paths:
real_path = os.path.realpath(path).replace(os.sep, '/')
cls._instance.monitor_paths.add(real_path)
return cls._instance
def start(self):
"""Override start to check global enable flag"""
if not ENABLE_FILE_MONITORING:
logger.info("Checkpoint file monitoring is disabled via ENABLE_FILE_MONITORING setting")
return
logger.info("Checkpoint file monitoring is temporarily disabled")
# Skip the actual monitoring setup
pass
async def initialize_paths(self):
"""Initialize monitor paths from scanner - currently disabled"""
if not ENABLE_FILE_MONITORING:
logger.info("Checkpoint path initialization skipped (monitoring disabled)")
return
logger.info("Checkpoint file path initialization skipped (monitoring disabled)")
pass

View File

@@ -4,22 +4,21 @@ import logging
import asyncio
import shutil
import time
from typing import List, Dict, Optional
from typing import List, Dict, Optional, Set
from ..utils.models import LoraMetadata
from ..config import config
from ..utils.file_utils import load_metadata, get_file_info, normalize_path, find_preview_file, save_metadata
from ..utils.lora_metadata import extract_lora_metadata
from .lora_cache import LoraCache
from .model_scanner import ModelScanner
from .lora_hash_index import LoraHashIndex
from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match
from .service_registry import ServiceRegistry
import sys
logger = logging.getLogger(__name__)
class LoraScanner:
class LoraScanner(ModelScanner):
"""Service for scanning and managing LoRA files"""
_instance = None
@@ -31,20 +30,20 @@ class LoraScanner:
return cls._instance
def __init__(self):
# 确保初始化只执行一次
# Ensure initialization happens only once
if not hasattr(self, '_initialized'):
self._cache: Optional[LoraCache] = None
self._hash_index = LoraHashIndex()
self._initialization_lock = asyncio.Lock()
self._initialization_task: Optional[asyncio.Task] = None
# Define supported file extensions
file_extensions = {'.safetensors'}
# Initialize parent class
super().__init__(
model_type="lora",
model_class=LoraMetadata,
file_extensions=file_extensions,
hash_index=LoraHashIndex()
)
self._initialized = True
self.file_monitor = None # Add this line
self._tags_count = {} # Add a dictionary to store tag counts
def set_file_monitor(self, monitor):
"""Set file monitor instance"""
self.file_monitor = monitor
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
@@ -52,93 +51,78 @@ class LoraScanner:
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
"""Get cached LoRA data, refresh if needed"""
async with self._initialization_lock:
def get_model_roots(self) -> List[str]:
"""Get lora root directories"""
return config.loras_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all LoRA directories and return metadata"""
all_loras = []
# Create scan tasks for each directory
scan_tasks = []
for lora_root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(lora_root))
scan_tasks.append(task)
# 如果缓存未初始化但需要响应请求,返回空缓存
if self._cache is None and not force_refresh:
return LoraCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# 如果正在初始化,等待完成
if self._initialization_task and not self._initialization_task.done():
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
self._initialization_task = None
if (self._cache is None or force_refresh):
# Wait for all tasks to complete
for task in scan_tasks:
try:
loras = await task
all_loras.extend(loras)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
# 创建新的初始化任务
if not self._initialization_task or self._initialization_task.done():
self._initialization_task = asyncio.create_task(self._initialize_cache())
return all_loras
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for LoRA files"""
loras = []
original_root = root_path # Save original root path
async def scan_recursive(path: str, visited_paths: set):
"""Recursively scan directory, avoiding circular symlinks"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
# 如果缓存已存在,继续使用旧缓存
if self._cache is None:
raise # 如果没有缓存,则抛出异常
return self._cache
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
# Use original path instead of real path
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, loras)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache"""
await scan_recursive(root_path, set())
return loras
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
"""Process a single file and add to results list"""
try:
start_time = time.time()
# Clear existing hash index
self._hash_index.clear()
# Clear existing tags count
self._tags_count = {}
# Scan for new data
raw_data = await self.scan_all_loras()
# Build hash index and tags count
for lora_data in raw_data:
if 'sha256' in lora_data and 'file_path' in lora_data:
self._hash_index.add_entry(lora_data['sha256'].lower(), lora_data['file_path'])
# Count tags
if 'tags' in lora_data and lora_data['tags']:
for tag in lora_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Update cache
self._cache = LoraCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Call resort_cache to create sorted views
await self._cache.resort()
self._initialization_task = None
logger.info(f"LoRA Manager: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} loras")
result = await self._process_model_file(file_path, root_path)
if result:
loras.append(result)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing cache: {e}")
self._cache = LoraCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
logger.error(f"Error processing {file_path}: {e}")
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy: bool = False,
folder: str = None, search: str = None, fuzzy_search: bool = False,
base_models: list = None, tags: list = None,
search_options: dict = None) -> Dict:
search_options: dict = None, hash_filters: dict = None) -> Dict:
"""Get paginated and filtered lora data
Args:
@@ -147,10 +131,11 @@ class LoraScanner:
sort_by: Sort method ('name' or 'date')
folder: Filter by folder path
search: Search term
fuzzy: Use fuzzy matching for search
fuzzy_search: Use fuzzy matching for search
base_models: List of base models to filter by
tags: List of tags to filter by
search_options: Dictionary with search options (filename, modelname, tags, recursive)
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
"""
cache = await self.get_cached_data()
@@ -160,90 +145,108 @@ class LoraScanner:
'filename': True,
'modelname': True,
'tags': False,
'recursive': False
'recursive': False,
}
# Get the base data set
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply hash filtering if provided (highest priority)
if hash_filters:
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower() # Ensure lowercase for matching
filtered_data = [
lora for lora in filtered_data
if lora.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
filtered_data = [
lora for lora in filtered_data
if lora.get('sha256', '').lower() in hash_set
]
# Jump to pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
# Apply SFW filtering if enabled
if settings.get('show_only_sfw', False):
filtered_data = [
item for item in filtered_data
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
lora for lora in filtered_data
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply folder filtering
if folder is not None:
if search_options.get('recursive', False):
# Recursive mode: match all paths starting with this folder
# Recursive folder filtering - include all subfolders
filtered_data = [
item for item in filtered_data
if item['folder'].startswith(folder + '/') or item['folder'] == folder
lora for lora in filtered_data
if lora['folder'].startswith(folder)
]
else:
# Non-recursive mode: match exact folder
# Exact folder filtering
filtered_data = [
item for item in filtered_data
if item['folder'] == folder
lora for lora in filtered_data
if lora['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
filtered_data = [
item for item in filtered_data
if item.get('base_model') in base_models
lora for lora in filtered_data
if lora.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
filtered_data = [
item for item in filtered_data
if any(tag in item.get('tags', []) for tag in tags)
lora for lora in filtered_data
if any(tag in lora.get('tags', []) for tag in tags)
]
# Apply search filtering
if search:
search_results = []
for item in filtered_data:
# Check filename if enabled
if search_options.get('filename', True):
if fuzzy:
if fuzzy_match(item.get('file_name', ''), search):
search_results.append(item)
continue
else:
if search.lower() in item.get('file_name', '').lower():
search_results.append(item)
continue
# Check model name if enabled
if search_options.get('modelname', True):
if fuzzy:
if fuzzy_match(item.get('model_name', ''), search):
search_results.append(item)
continue
else:
if search.lower() in item.get('model_name', '').lower():
search_results.append(item)
continue
# Check tags if enabled
if search_options.get('tags', False) and item.get('tags'):
found_tag = False
for tag in item['tags']:
if fuzzy:
if fuzzy_match(tag, search):
found_tag = True
break
else:
if search.lower() in tag.lower():
found_tag = True
break
if found_tag:
search_results.append(item)
search_opts = search_options or {}
for lora in filtered_data:
# Search by file name
if search_opts.get('filename', True):
if fuzzy_match(lora.get('file_name', ''), search):
search_results.append(lora)
continue
# Search by model name
if search_opts.get('modelname', True):
if fuzzy_match(lora.get('model_name', ''), search):
search_results.append(lora)
continue
# Search by tags
if search_opts.get('tags', False) and 'tags' in lora:
if any(fuzzy_match(tag, search) for tag in lora['tags']):
search_results.append(lora)
continue
filtered_data = search_results
# Calculate pagination
@@ -261,348 +264,6 @@ class LoraScanner:
return result
def invalidate_cache(self):
"""Invalidate the current cache"""
self._cache = None
async def scan_all_loras(self) -> List[Dict]:
"""Scan all LoRA directories and return metadata"""
all_loras = []
# 分目录异步扫描
scan_tasks = []
for loras_root in config.loras_roots:
task = asyncio.create_task(self._scan_directory(loras_root))
scan_tasks.append(task)
for task in scan_tasks:
try:
loras = await task
all_loras.extend(loras)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_loras
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for LoRA files"""
loras = []
original_root = root_path # 保存原始根路径
async def scan_recursive(path: str, visited_paths: set):
"""递归扫描目录,避免循环链接"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'):
# 使用原始路径而不是真实路径
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, loras)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# 对于目录,使用原始路径继续扫描
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return loras
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
"""处理单个文件并添加到结果列表"""
try:
result = await self._process_lora_file(file_path, root_path)
if result:
loras.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
async def _process_lora_file(self, file_path: str, root_path: str) -> Dict:
"""Process a single LoRA file and return its metadata"""
# Try loading existing metadata
metadata = await load_metadata(file_path)
if metadata is None:
# Try to find and use .civitai.info file first
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
if os.path.exists(civitai_info_path):
try:
with open(civitai_info_path, 'r', encoding='utf-8') as f:
version_info = json.load(f)
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if file_info:
# Create a minimal file_info with the required fields
file_name = os.path.splitext(os.path.basename(file_path))[0]
file_info['name'] = file_name
# Use from_civitai_info to create metadata
metadata = LoraMetadata.from_civitai_info(version_info, file_info, file_path)
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
await save_metadata(file_path, metadata)
logger.debug(f"Created metadata from .civitai.info for {file_path}")
except Exception as e:
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
# If still no metadata, create new metadata using get_file_info
if metadata is None:
metadata = await get_file_info(file_path)
# Convert to dict and add folder info
lora_data = metadata.to_dict()
# Try to fetch missing metadata from Civitai if needed
await self._fetch_missing_metadata(file_path, lora_data)
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
lora_data['folder'] = folder.replace(os.path.sep, '/')
return lora_data
async def _fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None:
"""Fetch missing description and tags from Civitai if needed
Args:
file_path: Path to the lora file
lora_data: Lora metadata dictionary to update
"""
try:
# Skip if already marked as deleted on Civitai
if lora_data.get('civitai_deleted', False):
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
return
# Check if we need to fetch additional metadata from Civitai
needs_metadata_update = False
model_id = None
# Check if we have Civitai model ID but missing metadata
if lora_data.get('civitai'):
# Try to get model ID directly from the correct location
model_id = lora_data['civitai'].get('modelId')
if model_id:
model_id = str(model_id)
# Check if tags are missing or empty
tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0
# Check if description is missing or empty
desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")
needs_metadata_update = tags_missing or desc_missing
# Fetch missing metadata if needed
if needs_metadata_update and model_id:
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
from ..services.civitai_client import CivitaiClient
client = CivitaiClient()
# Get metadata and status code
model_metadata, status_code = await client.get_model_metadata(model_id)
await client.close()
# Handle 404 status (model deleted from Civitai)
if status_code == 404:
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
# Mark as deleted to avoid future API calls
lora_data['civitai_deleted'] = True
# Save the updated metadata back to file
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(lora_data, f, indent=2, ensure_ascii=False)
# Process valid metadata if available
elif model_metadata:
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
# Update tags if they were missing
if model_metadata.get('tags') and (not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0):
lora_data['tags'] = model_metadata['tags']
# Update description if it was missing
if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")):
lora_data['modelDescription'] = model_metadata['description']
# Save the updated metadata back to file
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(lora_data, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
"""Update preview URL in cache for a specific lora
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
"""
if self._cache is None:
return False
return await self._cache.update_preview_url(file_path, preview_url)
async def scan_single_lora(self, file_path: str) -> Optional[Dict]:
"""Scan a single LoRA file and return its metadata"""
try:
if not os.path.exists(os.path.realpath(file_path)):
return None
# 获取基本文件信息
metadata = await get_file_info(file_path)
if not metadata:
return None
folder = self._calculate_folder(file_path)
# 确保 folder 字段存在
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = folder or ''
return metadata_dict
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None
def _calculate_folder(self, file_path: str) -> str:
"""Calculate the folder path for a LoRA file"""
# 使用原始路径计算相对路径
for root in config.loras_roots:
if file_path.startswith(root):
rel_path = os.path.relpath(file_path, root)
return os.path.dirname(rel_path).replace(os.path.sep, '/')
return ''
async def move_model(self, source_path: str, target_path: str) -> bool:
"""Move a model and its associated files to a new location"""
try:
# 保持原始路径格式
source_path = source_path.replace(os.sep, '/')
target_path = target_path.replace(os.sep, '/')
# 其余代码保持不变
base_name = os.path.splitext(os.path.basename(source_path))[0]
source_dir = os.path.dirname(source_path)
os.makedirs(target_path, exist_ok=True)
target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/')
# 使用真实路径进行文件操作
real_source = os.path.realpath(source_path)
real_target = os.path.realpath(target_lora)
file_size = os.path.getsize(real_source)
if self.file_monitor:
self.file_monitor.handler.add_ignore_path(
real_source,
file_size
)
self.file_monitor.handler.add_ignore_path(
real_target,
file_size
)
# 使用真实路径进行文件操作
shutil.move(real_source, real_target)
# Move associated files
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
if os.path.exists(source_metadata):
target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
shutil.move(source_metadata, target_metadata)
metadata = await self._update_metadata_paths(target_metadata, target_lora)
# Move preview file if exists
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions:
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
if os.path.exists(source_preview):
target_preview = os.path.join(target_path, f"{base_name}{ext}")
shutil.move(source_preview, target_preview)
break
# Update cache
await self.update_single_lora_cache(source_path, target_lora, metadata)
return True
except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True)
return False
async def update_single_lora_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
cache = await self.get_cached_data()
# Find the existing item to remove its tags from count
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []):
if tag in self._tags_count:
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
if self._tags_count[tag] == 0:
del self._tags_count[tag]
# Remove old path from hash index if exists
self._hash_index.remove_by_path(original_path)
# Remove the old entry from raw_data
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != original_path
]
if metadata:
# If this is an update to an existing path (not a move), ensure folder is preserved
if original_path == new_path:
# Find the folder from existing entries or calculate it
existing_folder = next((item['folder'] for item in cache.raw_data
if item['file_path'] == original_path), None)
if existing_folder:
metadata['folder'] = existing_folder
else:
metadata['folder'] = self._calculate_folder(new_path)
else:
# For moved files, recalculate the folder
metadata['folder'] = self._calculate_folder(new_path)
# Add the updated metadata to raw_data
cache.raw_data.append(metadata)
# Update hash index with new path
if 'sha256' in metadata:
self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
# Update folders list
all_folders = set(item['folder'] for item in cache.raw_data)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update tags count with the new/updated tags
if 'tags' in metadata:
for tag in metadata.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Resort cache
await cache.resort()
return True
async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict:
"""Update file paths in metadata file"""
try:
@@ -629,49 +290,21 @@ class LoraScanner:
except Exception as e:
logger.error(f"Error updating metadata paths: {e}", exc_info=True)
# Add new methods for hash index functionality
# Lora-specific hash index functionality
def has_lora_hash(self, sha256: str) -> bool:
"""Check if a LoRA with given hash exists"""
return self._hash_index.has_hash(sha256.lower())
return self.has_hash(sha256)
def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a LoRA by its hash"""
return self._hash_index.get_path(sha256.lower())
return self.get_path_by_hash(sha256)
def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a LoRA by its file path"""
return self._hash_index.get_hash(file_path)
return self.get_hash_by_path(file_path)
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
"""Get preview static URL for a LoRA by its hash"""
# Get the file path first
file_path = self._hash_index.get_path(sha256.lower())
if not file_path:
return None
# Determine the preview file path (typically same name with different extension)
base_name = os.path.splitext(file_path)[0]
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions:
preview_path = f"{base_name}{ext}"
if os.path.exists(preview_path):
# Convert to static URL using config
return config.get_preview_static_url(preview_path)
return None
# Add new method to get top tags
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get top tags sorted by count
Args:
limit: Maximum number of tags to return
Returns:
List of dictionaries with tag name and count, sorted by count
"""
"""Get top tags sorted by count"""
# Make sure cache is initialized
await self.get_cached_data()
@@ -686,14 +319,7 @@ class LoraScanner:
return sorted_tags[:limit]
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get base models used in loras sorted by frequency
Args:
limit: Maximum number of base models to return
Returns:
List of dictionaries with base model name and count, sorted by count
"""
"""Get base models used in loras sorted by frequency"""
# Make sure cache is initialized
cache = await self.get_cached_data()

View File

@@ -0,0 +1,64 @@
import asyncio
from typing import List, Dict
from dataclasses import dataclass
from operator import itemgetter
@dataclass
class ModelCache:
"""Cache structure for model data"""
raw_data: List[Dict]
sorted_by_name: List[Dict]
sorted_by_date: List[Dict]
folders: List[str]
def __post_init__(self):
self._lock = asyncio.Lock()
async def resort(self, name_only: bool = False):
"""Resort all cached data views"""
async with self._lock:
self.sorted_by_name = sorted(
self.raw_data,
key=lambda x: x['model_name'].lower() # Case-insensitive sort
)
if not name_only:
self.sorted_by_date = sorted(
self.raw_data,
key=itemgetter('modified'),
reverse=True
)
# Update folder list
all_folders = set(l['folder'] for l in self.raw_data)
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
"""Update preview_url for a specific model in all cached data
Args:
file_path: The file path of the model to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if the model wasn't found
"""
async with self._lock:
# Update in raw_data
for item in self.raw_data:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
else:
return False # Model not found
# Update in sorted lists (references to the same dict objects)
for item in self.sorted_by_name:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
for item in self.sorted_by_date:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
return True

View File

@@ -0,0 +1,78 @@
from typing import Dict, Optional, Set
class ModelHashIndex:
"""Index for looking up models by hash or path"""
def __init__(self):
self._hash_to_path: Dict[str, str] = {}
self._path_to_hash: Dict[str, str] = {}
def add_entry(self, sha256: str, file_path: str) -> None:
"""Add or update hash index entry"""
if not sha256 or not file_path:
return
# Ensure hash is lowercase for consistency
sha256 = sha256.lower()
# Remove old path mapping if hash exists
if sha256 in self._hash_to_path:
old_path = self._hash_to_path[sha256]
if old_path in self._path_to_hash:
del self._path_to_hash[old_path]
# Remove old hash mapping if path exists
if file_path in self._path_to_hash:
old_hash = self._path_to_hash[file_path]
if old_hash in self._hash_to_path:
del self._hash_to_path[old_hash]
# Add new mappings
self._hash_to_path[sha256] = file_path
self._path_to_hash[file_path] = sha256
def remove_by_path(self, file_path: str) -> None:
"""Remove entry by file path"""
if file_path in self._path_to_hash:
hash_val = self._path_to_hash[file_path]
if hash_val in self._hash_to_path:
del self._hash_to_path[hash_val]
del self._path_to_hash[file_path]
def remove_by_hash(self, sha256: str) -> None:
"""Remove entry by hash"""
sha256 = sha256.lower()
if sha256 in self._hash_to_path:
path = self._hash_to_path[sha256]
if path in self._path_to_hash:
del self._path_to_hash[path]
del self._hash_to_path[sha256]
def has_hash(self, sha256: str) -> bool:
"""Check if hash exists in index"""
return sha256.lower() in self._hash_to_path
def get_path(self, sha256: str) -> Optional[str]:
"""Get file path for a hash"""
return self._hash_to_path.get(sha256.lower())
def get_hash(self, file_path: str) -> Optional[str]:
"""Get hash for a file path"""
return self._path_to_hash.get(file_path)
def clear(self) -> None:
"""Clear all entries"""
self._hash_to_path.clear()
self._path_to_hash.clear()
def get_all_hashes(self) -> Set[str]:
"""Get all hashes in the index"""
return set(self._hash_to_path.keys())
def get_all_paths(self) -> Set[str]:
"""Get all file paths in the index"""
return set(self._path_to_hash.keys())
def __len__(self) -> int:
"""Get number of entries"""
return len(self._hash_to_path)

View File

@@ -0,0 +1,879 @@
import json
import os
import logging
import asyncio
import time
import shutil
from typing import List, Dict, Optional, Type, Set
from ..utils.models import BaseModelMetadata
from ..config import config
from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, save_metadata
from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS
from .service_registry import ServiceRegistry
from .websocket_manager import ws_manager
logger = logging.getLogger(__name__)
class ModelScanner:
"""Base service for scanning and managing model files"""
_lock = asyncio.Lock()
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
"""Initialize the scanner
Args:
model_type: Type of model (lora, checkpoint, etc.)
model_class: Class used to create metadata instances
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
hash_index: Hash index instance (optional)
"""
self.model_type = model_type
self.model_class = model_class
self.file_extensions = file_extensions
self._cache = None
self._hash_index = hash_index or ModelHashIndex()
self._tags_count = {} # Dictionary to store tag counts
self._is_initializing = False # Flag to track initialization state
# Register this service
asyncio.create_task(self._register_service())
async def _register_service(self):
"""Register this instance with the ServiceRegistry"""
service_name = f"{self.model_type}_scanner"
await ServiceRegistry.register_service(service_name, self)
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
# Set initial empty cache to avoid None reference errors
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Set initializing flag to true
self._is_initializing = True
# Determine the page type based on model type
page_type = 'loras' if self.model_type == 'lora' else 'checkpoints'
# First, count all model files to track progress
await ws_manager.broadcast_init_progress({
'stage': 'scan_folders',
'progress': 0,
'details': f"Scanning {self.model_type} folders...",
'scanner_type': self.model_type,
'pageType': page_type
})
# Count files in a separate thread to avoid blocking
loop = asyncio.get_event_loop()
total_files = await loop.run_in_executor(
None, # Use default thread pool
self._count_model_files # Run file counting in thread
)
await ws_manager.broadcast_init_progress({
'stage': 'count_models',
'progress': 1, # Changed from 10 to 1
'details': f"Found {total_files} {self.model_type} files",
'scanner_type': self.model_type,
'pageType': page_type
})
start_time = time.time()
# Use thread pool to execute CPU-intensive operations with progress reporting
await loop.run_in_executor(
None, # Use default thread pool
self._initialize_cache_sync, # Run synchronous version in thread
total_files, # Pass the total file count for progress reporting
page_type # Pass the page type for progress reporting
)
# Send final progress update
await ws_manager.broadcast_init_progress({
'stage': 'finalizing',
'progress': 99, # Changed from 95 to 99
'details': f"Finalizing {self.model_type} cache...",
'scanner_type': self.model_type,
'pageType': page_type
})
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
# Send completion message
await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent
await ws_manager.broadcast_init_progress({
'stage': 'finalizing',
'progress': 100,
'status': 'complete',
'details': f"Completed! Found {len(self._cache.raw_data)} {self.model_type} files.",
'scanner_type': self.model_type,
'pageType': page_type
})
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache in background: {e}")
finally:
# Always clear the initializing flag when done
self._is_initializing = False
def _count_model_files(self) -> int:
"""Count all model files with supported extensions in all roots
Returns:
int: Total number of model files found
"""
total_files = 0
visited_real_paths = set()
for root_path in self.get_model_roots():
if not os.path.exists(root_path):
continue
def count_recursive(path):
nonlocal total_files
try:
real_path = os.path.realpath(path)
if real_path in visited_real_paths:
return
visited_real_paths.add(real_path)
with os.scandir(path) as it:
for entry in it:
try:
if entry.is_file(follow_symlinks=True):
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
total_files += 1
elif entry.is_dir(follow_symlinks=True):
count_recursive(entry.path)
except Exception as e:
logger.error(f"Error counting files in entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error counting files in {path}: {e}")
count_recursive(root_path)
return total_files
def _initialize_cache_sync(self, total_files=0, page_type='loras'):
"""Synchronous version of cache initialization for thread pool execution"""
try:
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Create a synchronous method to bypass the async lock
def sync_initialize_cache():
# Track progress
processed_files = 0
last_progress_time = time.time()
last_progress_percent = 0
# We need a wrapper around scan_all_models to track progress
# This is a local function that will run in our thread's event loop
async def scan_with_progress():
nonlocal processed_files, last_progress_time, last_progress_percent
# For storing raw model data
all_models = []
# Process each model root
for root_path in self.get_model_roots():
if not os.path.exists(root_path):
continue
# Track visited paths to avoid symlink loops
visited_paths = set()
# Recursively process directory
async def scan_dir_with_progress(path):
nonlocal processed_files, last_progress_time, last_progress_percent
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
result = await self._process_model_file(file_path, root_path)
if result:
all_models.append(result)
# Update progress counter
processed_files += 1
# Update progress periodically (not every file to avoid excessive updates)
current_time = time.time()
if total_files > 0 and (current_time - last_progress_time > 0.5 or processed_files == total_files):
# Adjusted progress calculation
progress_percent = min(99, int(1 + (processed_files / total_files) * 98))
if progress_percent > last_progress_percent:
last_progress_percent = progress_percent
last_progress_time = current_time
# Send progress update through websocket
await ws_manager.broadcast_init_progress({
'stage': 'process_models',
'progress': progress_percent,
'details': f"Processing {self.model_type} files: {processed_files}/{total_files}",
'scanner_type': self.model_type,
'pageType': page_type
})
elif entry.is_dir(follow_symlinks=True):
await scan_dir_with_progress(entry.path)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
# Process the root path
await scan_dir_with_progress(root_path)
return all_models
# Run the progress-tracking scan function
raw_data = loop.run_until_complete(scan_with_progress())
# Update hash index and tags count
for model_data in raw_data:
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Count tags
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Update cache
self._cache.raw_data = raw_data
loop.run_until_complete(self._cache.resort())
return self._cache
# Run our sync initialization that avoids lock conflicts
return sync_initialize_cache()
except Exception as e:
logger.error(f"Error in thread-based {self.model_type} cache initialization: {e}")
finally:
# Clean up the event loop
loop.close()
async def get_cached_data(self, force_refresh: bool = False) -> ModelCache:
"""Get cached model data, refresh if needed"""
# If cache is not initialized, return an empty cache
# Actual initialization should be done via initialize_in_background
if self._cache is None and not force_refresh:
return ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# If force refresh is requested, initialize the cache directly
if force_refresh:
if self._cache is None:
# For initial creation, do a full initialization
await self._initialize_cache()
else:
# For subsequent refreshes, use fast reconciliation
await self._reconcile_cache()
return self._cache
async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache"""
self._is_initializing = True # Set flag
try:
start_time = time.time()
# Clear existing hash index
self._hash_index.clear()
# Clear existing tags count
self._tags_count = {}
# Determine the page type based on model type
page_type = 'loras' if self.model_type == 'lora' else 'checkpoints'
# Scan for new data
raw_data = await self.scan_all_models()
# Build hash index and tags count
for model_data in raw_data:
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Count tags
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Update cache
self._cache = ModelCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Resort cache
await self._cache.resort()
logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
# Ensure cache is at least an empty structure on error
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
finally:
self._is_initializing = False # Unset flag
async def _reconcile_cache(self) -> None:
"""Fast cache reconciliation - only process differences between cache and filesystem"""
self._is_initializing = True # Set flag for reconciliation duration
try:
start_time = time.time()
logger.info(f"{self.model_type.capitalize()} Scanner: Starting fast cache reconciliation...")
# Get current cached file paths
cached_paths = {item['file_path'] for item in self._cache.raw_data}
path_to_item = {item['file_path']: item for item in self._cache.raw_data}
# Track found files and new files
found_paths = set()
new_files = []
# Scan all model roots
for root_path in self.get_model_roots():
if not os.path.exists(root_path):
continue
# Track visited real paths to avoid symlink loops
visited_real_paths = set()
# Recursively scan directory
for root, _, files in os.walk(root_path, followlinks=True):
real_root = os.path.realpath(root)
if real_root in visited_real_paths:
continue
visited_real_paths.add(real_root)
for file in files:
ext = os.path.splitext(file)[1].lower()
if ext in self.file_extensions:
# Construct paths exactly as they would be in cache
file_path = os.path.join(root, file).replace(os.sep, '/')
# Check if this file is already in cache
if file_path in cached_paths:
found_paths.add(file_path)
continue
# Try case-insensitive match on Windows
if os.name == 'nt':
lower_path = file_path.lower()
matched = False
for cached_path in cached_paths:
if cached_path.lower() == lower_path:
found_paths.add(cached_path)
matched = True
break
if matched:
continue
# This is a new file to process
new_files.append(file_path)
# Yield control periodically
await asyncio.sleep(0)
# Process new files in batches
total_added = 0
if new_files:
logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(new_files)} new files to process")
batch_size = 50
for i in range(0, len(new_files), batch_size):
batch = new_files[i:i+batch_size]
for path in batch:
try:
model_data = await self.scan_single_model(path)
if model_data:
# Add to cache
self._cache.raw_data.append(model_data)
# Update hash index if available
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Update tags count
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
total_added += 1
except Exception as e:
logger.error(f"Error adding {path} to cache: {e}")
# Yield control after each batch
await asyncio.sleep(0)
# Find missing files (in cache but not in filesystem)
missing_files = cached_paths - found_paths
total_removed = 0
if missing_files:
logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(missing_files)} files to remove from cache")
# Process files to remove
for path in missing_files:
try:
model_to_remove = path_to_item[path]
# Update tags count
for tag in model_to_remove.get('tags', []):
if tag in self._tags_count:
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
if self._tags_count[tag] == 0:
del self._tags_count[tag]
# Remove from hash index
self._hash_index.remove_by_path(path)
total_removed += 1
except Exception as e:
logger.error(f"Error removing {path} from cache: {e}")
# Update cache data
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in missing_files]
# Resort cache if changes were made
if total_added > 0 or total_removed > 0:
# Update folders list
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Resort cache
await self._cache.resort()
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
finally:
self._is_initializing = False # Unset flag
# These methods should be implemented in child classes
async def scan_all_models(self) -> List[Dict]:
"""Scan all model directories and return metadata"""
raise NotImplementedError("Subclasses must implement scan_all_models")
def get_model_roots(self) -> List[str]:
"""Get model root directories"""
raise NotImplementedError("Subclasses must implement get_model_roots")
async def scan_single_model(self, file_path: str) -> Optional[Dict]:
"""Scan a single model file and return its metadata"""
try:
if not os.path.exists(os.path.realpath(file_path)):
return None
# Get basic file info
metadata = await self._get_file_info(file_path)
if not metadata:
return None
folder = self._calculate_folder(file_path)
# Ensure folder field exists
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = folder or ''
return metadata_dict
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None
async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
"""Get model file info and metadata (extensible for different model types)"""
return await get_file_info(file_path, self.model_class)
def _calculate_folder(self, file_path: str) -> str:
"""Calculate the folder path for a model file"""
for root in self.get_model_roots():
if file_path.startswith(root):
rel_path = os.path.relpath(file_path, root)
return os.path.dirname(rel_path).replace(os.path.sep, '/')
return ''
# Common methods shared between scanners
async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
"""Process a single model file and return its metadata"""
metadata = await load_metadata(file_path, self.model_class)
if metadata is None:
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
if os.path.exists(civitai_info_path):
try:
with open(civitai_info_path, 'r', encoding='utf-8') as f:
version_info = json.load(f)
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if file_info:
file_name = os.path.splitext(os.path.basename(file_path))[0]
file_info['name'] = file_name
metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
await save_metadata(file_path, metadata)
logger.debug(f"Created metadata from .civitai.info for {file_path}")
except Exception as e:
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
if metadata is None:
metadata = await self._get_file_info(file_path)
model_data = metadata.to_dict()
await self._fetch_missing_metadata(file_path, model_data)
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
model_data['folder'] = folder.replace(os.path.sep, '/')
return model_data
async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None:
"""Fetch missing description and tags from Civitai if needed"""
try:
if model_data.get('civitai_deleted', False):
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
return
needs_metadata_update = False
model_id = None
if model_data.get('civitai'):
model_id = model_data['civitai'].get('modelId')
if model_id:
model_id = str(model_id)
tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0
desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")
needs_metadata_update = tags_missing or desc_missing
if needs_metadata_update and model_id:
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
from ..services.civitai_client import CivitaiClient
client = CivitaiClient()
model_metadata, status_code = await client.get_model_metadata(model_id)
await client.close()
if status_code == 404:
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
model_data['civitai_deleted'] = True
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(model_data, f, indent=2, ensure_ascii=False)
elif model_metadata:
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0):
model_data['tags'] = model_metadata['tags']
if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")):
model_data['modelDescription'] = model_metadata['description']
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(model_data, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Base implementation for directory scanning"""
models = []
original_root = root_path
async def scan_recursive(path: str, visited_paths: set):
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, models)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return models
async def _process_single_file(self, file_path: str, root_path: str, models_list: list):
"""Process a single file and add to results list"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
models_list.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
async def move_model(self, source_path: str, target_path: str) -> bool:
"""Move a model and its associated files to a new location"""
try:
source_path = source_path.replace(os.sep, '/')
target_path = target_path.replace(os.sep, '/')
file_ext = os.path.splitext(source_path)[1]
if not file_ext or file_ext.lower() not in self.file_extensions:
logger.error(f"Invalid file extension for model: {file_ext}")
return False
base_name = os.path.splitext(os.path.basename(source_path))[0]
source_dir = os.path.dirname(source_path)
os.makedirs(target_path, exist_ok=True)
target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/')
real_source = os.path.realpath(source_path)
real_target = os.path.realpath(target_file)
file_size = os.path.getsize(real_source)
# Get the appropriate file monitor through ServiceRegistry
if self.model_type == "lora":
monitor = await ServiceRegistry.get_lora_monitor()
elif self.model_type == "checkpoint":
monitor = await ServiceRegistry.get_checkpoint_monitor()
else:
monitor = None
if monitor:
monitor.handler.add_ignore_path(
real_source,
file_size
)
monitor.handler.add_ignore_path(
real_target,
file_size
)
shutil.move(real_source, real_target)
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
metadata = None
if os.path.exists(source_metadata):
target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
shutil.move(source_metadata, target_metadata)
metadata = await self._update_metadata_paths(target_metadata, target_file)
for ext in PREVIEW_EXTENSIONS:
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
if os.path.exists(source_preview):
target_preview = os.path.join(target_path, f"{base_name}{ext}")
shutil.move(source_preview, target_preview)
break
await self.update_single_model_cache(source_path, target_file, metadata)
return True
except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True)
return False
async def _update_metadata_paths(self, metadata_path: str, model_path: str) -> Dict:
"""Update file paths in metadata file"""
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
metadata['file_path'] = model_path.replace(os.sep, '/')
if 'preview_url' in metadata:
preview_dir = os.path.dirname(model_path)
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
preview_ext = os.path.splitext(metadata['preview_url'])[1]
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
return metadata
except Exception as e:
logger.error(f"Error updating metadata paths: {e}", exc_info=True)
return None
async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
"""Update cache after a model has been moved or modified"""
cache = await self.get_cached_data()
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []):
if tag in self._tags_count:
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
if self._tags_count[tag] == 0:
del self._tags_count[tag]
self._hash_index.remove_by_path(original_path)
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != original_path
]
if metadata:
if original_path == new_path:
existing_folder = next((item['folder'] for item in cache.raw_data
if item['file_path'] == original_path), None)
if existing_folder:
metadata['folder'] = existing_folder
else:
metadata['folder'] = self._calculate_folder(new_path)
else:
metadata['folder'] = self._calculate_folder(new_path)
cache.raw_data.append(metadata)
if 'sha256' in metadata:
self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
all_folders = set(item['folder'] for item in cache.raw_data)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
if 'tags' in metadata:
for tag in metadata.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
await cache.resort()
return True
def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists"""
return self._hash_index.has_hash(sha256.lower())
def get_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a model by its hash"""
return self._hash_index.get_path(sha256.lower())
def get_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a model by its file path"""
return self._hash_index.get_hash(file_path)
# TODO: Adjust this method to use metadata instead of finding the file
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
"""Get preview static URL for a model by its hash"""
file_path = self._hash_index.get_path(sha256.lower())
if not file_path:
return None
base_name = os.path.splitext(file_path)[0]
for ext in PREVIEW_EXTENSIONS:
preview_path = f"{base_name}{ext}"
if os.path.exists(preview_path):
return config.get_preview_static_url(preview_path)
return None
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get top tags sorted by count"""
await self.get_cached_data()
sorted_tags = sorted(
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
key=lambda x: x['count'],
reverse=True
)
return sorted_tags[:limit]
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get base models sorted by frequency"""
cache = await self.get_cached_data()
base_model_counts = {}
for model in cache.raw_data:
if 'base_model' in model and model['base_model']:
base_model = model['base_model']
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
sorted_models.sort(key=lambda x: x['count'], reverse=True)
return sorted_models[:limit]
async def get_model_info_by_name(self, name):
"""Get model information by name"""
try:
cache = await self.get_cached_data()
for model in cache.raw_data:
if model.get("file_name") == name:
return model
return None
except Exception as e:
logger.error(f"Error getting model info by name: {e}", exc_info=True)
return None
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
"""Update preview URL in cache for a specific lora
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
"""
if self._cache is None:
return False
return await self._cache.update_preview_url(file_path, preview_url)

View File

@@ -2,11 +2,12 @@ import os
import logging
import asyncio
import json
import time
from typing import List, Dict, Optional, Any, Tuple
from ..config import config
from .recipe_cache import RecipeCache
from .service_registry import ServiceRegistry
from .lora_scanner import LoraScanner
from .civitai_client import CivitaiClient
from ..utils.utils import fuzzy_match
import sys
@@ -18,11 +19,22 @@ class RecipeScanner:
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls, lora_scanner: Optional[LoraScanner] = None):
"""Get singleton instance of RecipeScanner"""
async with cls._lock:
if cls._instance is None:
if not lora_scanner:
# Get lora scanner from service registry if not provided
lora_scanner = await ServiceRegistry.get_lora_scanner()
cls._instance = cls(lora_scanner)
return cls._instance
def __new__(cls, lora_scanner: Optional[LoraScanner] = None):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._lora_scanner = lora_scanner
cls._instance._civitai_client = CivitaiClient()
cls._instance._civitai_client = None # Will be lazily initialized
return cls._instance
def __init__(self, lora_scanner: Optional[LoraScanner] = None):
@@ -35,9 +47,148 @@ class RecipeScanner:
if lora_scanner:
self._lora_scanner = lora_scanner
self._initialized = True
# Initialization will be scheduled by LoraManager
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
# Set initial empty cache to avoid None reference errors
if self._cache is None:
self._cache = RecipeCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[]
)
# Mark as initializing to prevent concurrent initializations
self._is_initializing = True
try:
# Start timer
start_time = time.time()
# Use thread pool to execute CPU-intensive operations
loop = asyncio.get_event_loop()
cache = await loop.run_in_executor(
None, # Use default thread pool
self._initialize_recipe_cache_sync # Run synchronous version in thread
)
# Calculate elapsed time and log it
elapsed_time = time.time() - start_time
recipe_count = len(cache.raw_data) if cache and hasattr(cache, 'raw_data') else 0
logger.info(f"Recipe cache initialized in {elapsed_time:.2f} seconds. Found {recipe_count} recipes")
finally:
# Mark initialization as complete regardless of outcome
self._is_initializing = False
except Exception as e:
logger.error(f"Recipe Scanner: Error initializing cache in background: {e}")
def _initialize_recipe_cache_sync(self):
"""Synchronous version of recipe cache initialization for thread pool execution"""
try:
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Create a synchronous method to bypass the async lock
def sync_initialize_cache():
# We need to implement scan_all_recipes logic synchronously here
# instead of calling the async method to avoid event loop issues
recipes = []
recipes_dir = self.recipes_dir
if not recipes_dir or not os.path.exists(recipes_dir):
logger.warning(f"Recipes directory not found: {recipes_dir}")
return recipes
# Get all recipe JSON files in the recipes directory
recipe_files = []
for root, _, files in os.walk(recipes_dir):
recipe_count = sum(1 for f in files if f.lower().endswith('.recipe.json'))
if recipe_count > 0:
for file in files:
if file.lower().endswith('.recipe.json'):
recipe_files.append(os.path.join(root, file))
# Process each recipe file
for recipe_path in recipe_files:
try:
with open(recipe_path, 'r', encoding='utf-8') as f:
recipe_data = json.load(f)
# Validate recipe data
if not recipe_data or not isinstance(recipe_data, dict):
logger.warning(f"Invalid recipe data in {recipe_path}")
continue
# Ensure required fields exist
required_fields = ['id', 'file_path', 'title']
if not all(field in recipe_data for field in required_fields):
logger.warning(f"Missing required fields in {recipe_path}")
continue
# Ensure the image file exists
image_path = recipe_data.get('file_path')
if not os.path.exists(image_path):
recipe_dir = os.path.dirname(recipe_path)
image_filename = os.path.basename(image_path)
alternative_path = os.path.join(recipe_dir, image_filename)
if os.path.exists(alternative_path):
recipe_data['file_path'] = alternative_path
# Ensure loras array exists
if 'loras' not in recipe_data:
recipe_data['loras'] = []
# Ensure gen_params exists
if 'gen_params' not in recipe_data:
recipe_data['gen_params'] = {}
# Add to list without async operations
recipes.append(recipe_data)
except Exception as e:
logger.error(f"Error loading recipe file {recipe_path}: {e}")
import traceback
traceback.print_exc(file=sys.stderr)
# Update cache with the collected data
self._cache.raw_data = recipes
# Create a simplified resort function that doesn't use await
if hasattr(self._cache, "resort"):
try:
# Sort by name
self._cache.sorted_by_name = sorted(
self._cache.raw_data,
key=lambda x: x.get('title', '').lower()
)
# Sort by date (modified or created)
self._cache.sorted_by_date = sorted(
self._cache.raw_data,
key=lambda x: x.get('modified', x.get('created_date', 0)),
reverse=True
)
except Exception as e:
logger.error(f"Error sorting recipe cache: {e}")
return self._cache
# Run our sync initialization that avoids lock conflicts
return sync_initialize_cache()
except Exception as e:
logger.error(f"Error in thread-based recipe cache initialization: {e}")
return self._cache if hasattr(self, '_cache') else None
finally:
# Clean up the event loop
loop.close()
@property
def recipes_dir(self) -> str:
"""Get path to recipes directory"""
@@ -60,49 +211,48 @@ class RecipeScanner:
if self._is_initializing and not force_refresh:
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
# Try to acquire the lock with a timeout to prevent deadlocks
try:
async with self._initialization_lock:
# Check again after acquiring the lock
if self._cache is not None and not force_refresh:
return self._cache
# Mark as initializing to prevent concurrent initializations
self._is_initializing = True
try:
# Remove dependency on lora scanner initialization
# Scan for recipe data directly
raw_data = await self.scan_all_recipes()
# If force refresh is requested, initialize the cache directly
if force_refresh:
# Try to acquire the lock with a timeout to prevent deadlocks
try:
async with self._initialization_lock:
# Mark as initializing to prevent concurrent initializations
self._is_initializing = True
# Update cache
self._cache = RecipeCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[]
)
try:
# Scan for recipe data directly
raw_data = await self.scan_all_recipes()
# Update cache
self._cache = RecipeCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[]
)
# Resort cache
await self._cache.resort()
return self._cache
# Resort cache
await self._cache.resort()
return self._cache
except Exception as e:
logger.error(f"Recipe Manager: Error initializing cache: {e}", exc_info=True)
# Create empty cache on error
self._cache = RecipeCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[]
)
return self._cache
finally:
# Mark initialization as complete
self._is_initializing = False
except Exception as e:
logger.error(f"Recipe Manager: Error initializing cache: {e}", exc_info=True)
# Create empty cache on error
self._cache = RecipeCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[]
)
return self._cache
finally:
# Mark initialization as complete
self._is_initializing = False
except Exception as e:
logger.error(f"Unexpected error in get_cached_data: {e}")
except Exception as e:
logger.error(f"Unexpected error in get_cached_data: {e}")
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
# Return the cache (may be empty or partially initialized)
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
async def scan_all_recipes(self) -> List[Dict]:
"""Scan all recipe JSON files and return metadata"""
@@ -255,10 +405,13 @@ class RecipeScanner:
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
"""Get hash from Civitai API"""
try:
if not self._civitai_client:
# Get CivitaiClient from ServiceRegistry
civitai_client = await self._get_civitai_client()
if not civitai_client:
logger.error("Failed to get CivitaiClient from ServiceRegistry")
return None
version_info = await self._civitai_client.get_model_version_info(model_version_id)
version_info = await civitai_client.get_model_version_info(model_version_id)
if not version_info or not version_info.get('files'):
logger.debug(f"No files found in version info for ID: {model_version_id}")
@@ -278,10 +431,12 @@ class RecipeScanner:
async def _get_model_version_name(self, model_version_id: str) -> Optional[str]:
"""Get model version name from Civitai API"""
try:
if not self._civitai_client:
# Get CivitaiClient from ServiceRegistry
civitai_client = await self._get_civitai_client()
if not civitai_client:
return None
version_info = await self._civitai_client.get_model_version_info(model_version_id)
version_info = await civitai_client.get_model_version_info(model_version_id)
if version_info and 'name' in version_info:
return version_info['name']
@@ -330,7 +485,7 @@ class RecipeScanner:
logger.error(f"Error getting base model for lora: {e}")
return None
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None):
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True):
"""Get paginated and filtered recipe data
Args:
@@ -340,69 +495,89 @@ class RecipeScanner:
search: Search term
filters: Dictionary of filters to apply
search_options: Dictionary of search options to apply
lora_hash: Optional SHA256 hash of a LoRA to filter recipes by
bypass_filters: If True, ignore other filters when a lora_hash is provided
"""
cache = await self.get_cached_data()
# Get base dataset
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply search filter
if search:
# Default search options if none provided
if not search_options:
search_options = {
'title': True,
'tags': True,
'lora_name': True,
'lora_model': True
}
# Special case: Filter by LoRA hash (takes precedence if bypass_filters is True)
if lora_hash:
# Filter recipes that contain this LoRA hash
filtered_data = [
item for item in filtered_data
if 'loras' in item and any(
lora.get('hash', '').lower() == lora_hash.lower()
for lora in item['loras']
)
]
# Build the search predicate based on search options
def matches_search(item):
# Search in title if enabled
if search_options.get('title', True):
if fuzzy_match(str(item.get('title', '')), search):
return True
# Search in tags if enabled
if search_options.get('tags', True) and 'tags' in item:
for tag in item['tags']:
if fuzzy_match(tag, search):
return True
# Search in lora file names if enabled
if search_options.get('lora_name', True) and 'loras' in item:
for lora in item['loras']:
if fuzzy_match(str(lora.get('file_name', '')), search):
return True
# Search in lora model names if enabled
if search_options.get('lora_model', True) and 'loras' in item:
for lora in item['loras']:
if fuzzy_match(str(lora.get('modelName', '')), search):
return True
# No match found
return False
# Filter the data using the search predicate
filtered_data = [item for item in filtered_data if matches_search(item)]
if bypass_filters:
# Skip other filters if bypass_filters is True
pass
# Otherwise continue with normal filtering after applying LoRA hash filter
# Apply additional filters
if filters:
# Filter by base model
if 'base_model' in filters and filters['base_model']:
filtered_data = [
item for item in filtered_data
if item.get('base_model', '') in filters['base_model']
]
# Skip further filtering if we're only filtering by LoRA hash with bypass enabled
if not (lora_hash and bypass_filters):
# Apply search filter
if search:
# Default search options if none provided
if not search_options:
search_options = {
'title': True,
'tags': True,
'lora_name': True,
'lora_model': True
}
# Build the search predicate based on search options
def matches_search(item):
# Search in title if enabled
if search_options.get('title', True):
if fuzzy_match(str(item.get('title', '')), search):
return True
# Search in tags if enabled
if search_options.get('tags', True) and 'tags' in item:
for tag in item['tags']:
if fuzzy_match(tag, search):
return True
# Search in lora file names if enabled
if search_options.get('lora_name', True) and 'loras' in item:
for lora in item['loras']:
if fuzzy_match(str(lora.get('file_name', '')), search):
return True
# Search in lora model names if enabled
if search_options.get('lora_model', True) and 'loras' in item:
for lora in item['loras']:
if fuzzy_match(str(lora.get('modelName', '')), search):
return True
# No match found
return False
# Filter the data using the search predicate
filtered_data = [item for item in filtered_data if matches_search(item)]
# Filter by tags
if 'tags' in filters and filters['tags']:
filtered_data = [
item for item in filtered_data
if any(tag in item.get('tags', []) for tag in filters['tags'])
]
# Apply additional filters
if filters:
# Filter by base model
if 'base_model' in filters and filters['base_model']:
filtered_data = [
item for item in filtered_data
if item.get('base_model', '') in filters['base_model']
]
# Filter by tags
if 'tags' in filters and filters['tags']:
filtered_data = [
item for item in filtered_data
if any(tag in item.get('tags', []) for tag in filters['tags'])
]
# Calculate pagination
total_items = len(filtered_data)
@@ -430,6 +605,74 @@ class RecipeScanner:
}
return result
async def get_recipe_by_id(self, recipe_id: str) -> dict:
"""Get a single recipe by ID with all metadata and formatted URLs
Args:
recipe_id: The ID of the recipe to retrieve
Returns:
Dict containing the recipe data or None if not found
"""
if not recipe_id:
return None
# Get all recipes from cache
cache = await self.get_cached_data()
# Find the recipe with the specified ID
recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None)
if not recipe:
return None
# Format the recipe with all needed information
formatted_recipe = {**recipe} # Copy all fields
# Format file path to URL
if 'file_path' in formatted_recipe:
formatted_recipe['file_url'] = self._format_file_url(formatted_recipe['file_path'])
# Format dates for display
for date_field in ['created_date', 'modified']:
if date_field in formatted_recipe:
formatted_recipe[f"{date_field}_formatted"] = self._format_timestamp(formatted_recipe[date_field])
# Add lora metadata
if 'loras' in formatted_recipe:
for lora in formatted_recipe['loras']:
if 'hash' in lora and lora['hash']:
lora_hash = lora['hash'].lower()
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora_hash)
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora_hash)
return formatted_recipe
def _format_file_url(self, file_path: str) -> str:
"""Format file path as URL for serving in web UI"""
if not file_path:
return '/loras_static/images/no-preview.png'
try:
# Format file path as a URL that will work with static file serving
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
if file_path.replace(os.sep, '/').startswith(recipes_dir):
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
return f"/loras_static/root1/preview/{relative_path}"
# If not in recipes dir, try to create a valid URL from the file name
file_name = os.path.basename(file_path)
return f"/loras_static/root1/preview/recipes/{file_name}"
except Exception as e:
logger.error(f"Error formatting file URL: {e}")
return '/loras_static/images/no-preview.png'
def _format_timestamp(self, timestamp: float) -> str:
"""Format timestamp for display"""
from datetime import datetime
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
async def update_recipe_metadata(self, recipe_id: str, metadata: dict) -> bool:
"""Update recipe metadata (like title and tags) in both file system and cache

View File

@@ -0,0 +1,124 @@
import asyncio
import logging
from typing import Optional, Dict, Any, TypeVar, Type
logger = logging.getLogger(__name__)
T = TypeVar('T') # Define a type variable for service types
class ServiceRegistry:
"""Centralized registry for service singletons"""
_instance = None
_services: Dict[str, Any] = {}
_lock = asyncio.Lock()
@classmethod
def get_instance(cls):
"""Get singleton instance of the registry"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
async def register_service(cls, service_name: str, service_instance: Any) -> None:
"""Register a service instance with the registry"""
registry = cls.get_instance()
async with cls._lock:
registry._services[service_name] = service_instance
logger.debug(f"Registered service: {service_name}")
@classmethod
async def get_service(cls, service_name: str) -> Any:
"""Get a service instance by name"""
registry = cls.get_instance()
async with cls._lock:
if service_name not in registry._services:
logger.debug(f"Service {service_name} not found in registry")
return None
return registry._services[service_name]
# Convenience methods for common services
@classmethod
async def get_lora_scanner(cls):
"""Get the LoraScanner instance"""
from .lora_scanner import LoraScanner
scanner = await cls.get_service("lora_scanner")
if scanner is None:
scanner = await LoraScanner.get_instance()
await cls.register_service("lora_scanner", scanner)
return scanner
@classmethod
async def get_checkpoint_scanner(cls):
"""Get the CheckpointScanner instance"""
from .checkpoint_scanner import CheckpointScanner
scanner = await cls.get_service("checkpoint_scanner")
if scanner is None:
scanner = await CheckpointScanner.get_instance()
await cls.register_service("checkpoint_scanner", scanner)
return scanner
@classmethod
async def get_lora_monitor(cls):
"""Get the LoraFileMonitor instance"""
from .file_monitor import LoraFileMonitor
monitor = await cls.get_service("lora_monitor")
if monitor is None:
monitor = await LoraFileMonitor.get_instance()
await cls.register_service("lora_monitor", monitor)
return monitor
@classmethod
async def get_checkpoint_monitor(cls):
"""Get the CheckpointFileMonitor instance"""
from .file_monitor import CheckpointFileMonitor
monitor = await cls.get_service("checkpoint_monitor")
if monitor is None:
monitor = await CheckpointFileMonitor.get_instance()
await cls.register_service("checkpoint_monitor", monitor)
return monitor
@classmethod
async def get_civitai_client(cls):
"""Get the CivitaiClient instance"""
from .civitai_client import CivitaiClient
client = await cls.get_service("civitai_client")
if client is None:
client = await CivitaiClient.get_instance()
await cls.register_service("civitai_client", client)
return client
@classmethod
async def get_download_manager(cls):
"""Get the DownloadManager instance"""
from .download_manager import DownloadManager
manager = await cls.get_service("download_manager")
if manager is None:
# We'll let DownloadManager.get_instance handle file_monitor parameter
manager = await DownloadManager.get_instance()
await cls.register_service("download_manager", manager)
return manager
@classmethod
async def get_recipe_scanner(cls):
"""Get the RecipeScanner instance"""
from .recipe_scanner import RecipeScanner
scanner = await cls.get_service("recipe_scanner")
if scanner is None:
lora_scanner = await cls.get_lora_scanner()
scanner = RecipeScanner(lora_scanner)
await cls.register_service("recipe_scanner", scanner)
return scanner
@classmethod
async def get_websocket_manager(cls):
"""Get the WebSocketManager instance"""
from .websocket_manager import ws_manager
manager = await cls.get_service("websocket_manager")
if manager is None:
# ws_manager is already a global instance in websocket_manager.py
from .websocket_manager import ws_manager
await cls.register_service("websocket_manager", ws_manager)
manager = ws_manager
return manager

View File

@@ -9,6 +9,8 @@ class WebSocketManager:
def __init__(self):
self._websockets: Set[web.WebSocketResponse] = set()
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection"""
@@ -23,6 +25,34 @@ class WebSocketManager:
finally:
self._websockets.discard(ws)
return ws
async def handle_init_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection for initialization progress"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self._init_websockets.add(ws)
try:
async for msg in ws:
if msg.type == web.WSMsgType.ERROR:
logger.error(f'Init WebSocket error: {ws.exception()}')
finally:
self._init_websockets.discard(ws)
return ws
async def handle_checkpoint_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection for checkpoint download progress"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self._checkpoint_websockets.add(ws)
try:
async for msg in ws:
if msg.type == web.WSMsgType.ERROR:
logger.error(f'Checkpoint WebSocket error: {ws.exception()}')
finally:
self._checkpoint_websockets.discard(ws)
return ws
async def broadcast(self, data: Dict):
"""Broadcast message to all connected clients"""
@@ -34,10 +64,48 @@ class WebSocketManager:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending progress: {e}")
async def broadcast_init_progress(self, data: Dict):
"""Broadcast initialization progress to connected clients"""
if not self._init_websockets:
return
# Ensure data has all required fields
if 'stage' not in data:
data['stage'] = 'processing'
if 'progress' not in data:
data['progress'] = 0
if 'details' not in data:
data['details'] = 'Processing...'
for ws in self._init_websockets:
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending initialization progress: {e}")
async def broadcast_checkpoint_progress(self, data: Dict):
"""Broadcast checkpoint download progress to connected clients"""
if not self._checkpoint_websockets:
return
for ws in self._checkpoint_websockets:
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending checkpoint progress: {e}")
def get_connected_clients_count(self) -> int:
"""Get number of connected clients"""
return len(self._websockets)
def get_init_clients_count(self) -> int:
"""Get number of initialization progress clients"""
return len(self._init_websockets)
def get_checkpoint_clients_count(self) -> int:
"""Get number of checkpoint progress clients"""
return len(self._checkpoint_websockets)
# Global instance
ws_manager = WebSocketManager()
ws_manager = WebSocketManager()

View File

@@ -5,4 +5,21 @@ NSFW_LEVELS = {
"X": 8,
"XXX": 16,
"Blocked": 32, # Probably not actually visible through the API without being logged in on model owner account?
}
}
# preview extensions
PREVIEW_EXTENSIONS = [
'.webp',
'.preview.webp',
'.preview.png',
'.preview.jpeg',
'.preview.jpg',
'.preview.mp4',
'.png',
'.jpeg',
'.jpg',
'.mp4'
]
# Card preview image width
CARD_PREVIEW_WIDTH = 480

View File

@@ -2,12 +2,14 @@ import logging
import os
import hashlib
import json
from typing import Dict, Optional
import time
from typing import Dict, Optional, Type
from .model_utils import determine_base_model
from .lora_metadata import extract_lora_metadata
from .models import LoraMetadata
from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata
from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from .exif_utils import ExifUtils
logger = logging.getLogger(__name__)
@@ -15,35 +17,56 @@ async def calculate_sha256(file_path: str) -> str:
"""Calculate SHA256 hash of a file"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(4096), b""):
for byte_block in iter(lambda: f.read(128 * 1024), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
def find_preview_file(base_name: str, dir_path: str) -> str:
"""Find preview file for given base name in directory"""
preview_patterns = [
f"{base_name}.preview.png",
f"{base_name}.preview.jpg",
f"{base_name}.preview.jpeg",
f"{base_name}.preview.mp4",
f"{base_name}.png",
f"{base_name}.jpg",
f"{base_name}.jpeg",
f"{base_name}.mp4"
]
for pattern in preview_patterns:
full_pattern = os.path.join(dir_path, pattern)
for ext in PREVIEW_EXTENSIONS:
full_pattern = os.path.join(dir_path, f"{base_name}{ext}")
if os.path.exists(full_pattern):
# Check if this is an image and not already webp
if ext.lower().endswith(('.jpg', '.jpeg', '.png')) and not ext.lower().endswith('.webp'):
try:
# Optimize the image to webp format
webp_path = os.path.join(dir_path, f"{base_name}.webp")
# Use ExifUtils to optimize the image
with open(full_pattern, 'rb') as f:
image_data = f.read()
optimized_data, _ = ExifUtils.optimize_image(
image_data=image_data,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
)
# Save the optimized webp file
with open(webp_path, 'wb') as f:
f.write(optimized_data)
logger.debug(f"Optimized preview image from {full_pattern} to {webp_path}")
return webp_path.replace(os.sep, "/")
except Exception as e:
logger.error(f"Error optimizing preview image {full_pattern}: {e}")
# Fall back to original file if optimization fails
return full_pattern.replace(os.sep, "/")
# Return the original path for webp images or non-image files
return full_pattern.replace(os.sep, "/")
return ""
def normalize_path(path: str) -> str:
"""Normalize file path to use forward slashes"""
return path.replace(os.sep, "/") if path else path
async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
"""Get basic file information as LoraMetadata object"""
async def get_file_info(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Get basic file information as a model metadata object"""
# First check if file actually exists and resolve symlinks
try:
real_path = os.path.realpath(file_path)
@@ -70,31 +93,67 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
logger.debug(f"Using SHA256 from .json file for {file_path}")
except Exception as e:
logger.error(f"Error reading .json file for {file_path}: {e}")
# If SHA256 is still not found, check for a .sha256 file
if sha256 is None:
sha256_file = f"{os.path.splitext(file_path)[0]}.sha256"
if os.path.exists(sha256_file):
try:
with open(sha256_file, 'r', encoding='utf-8') as f:
sha256 = f.read().strip().lower()
logger.debug(f"Using SHA256 from .sha256 file for {file_path}")
except Exception as e:
logger.error(f"Error reading .sha256 file for {file_path}: {e}")
try:
# If we didn't get SHA256 from the .json file, calculate it
if not sha256:
start_time = time.time()
sha256 = await calculate_sha256(real_path)
logger.debug(f"Calculated SHA256 for {file_path} in {time.time() - start_time:.2f} seconds")
# Create default metadata based on model class
if model_class == CheckpointMetadata:
metadata = CheckpointMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
preview_url=normalize_path(preview_url),
tags=[],
modelDescription="",
model_type="checkpoint"
)
metadata = LoraMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
usage_tips="",
notes="",
from_civitai=True,
preview_url=normalize_path(preview_url),
tags=[],
modelDescription=""
)
# Extract checkpoint-specific metadata
# model_info = await extract_checkpoint_metadata(real_path)
# metadata.base_model = model_info['base_model']
# if 'model_type' in model_info:
# metadata.model_type = model_info['model_type']
else: # Default to LoraMetadata
metadata = LoraMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
usage_tips="{}",
preview_url=normalize_path(preview_url),
tags=[],
modelDescription=""
)
# Extract lora-specific metadata
model_info = await extract_lora_metadata(real_path)
metadata.base_model = model_info['base_model']
# create metadata file
base_model_info = await extract_lora_metadata(real_path)
metadata.base_model = base_model_info['base_model']
# Save metadata to file
await save_metadata(file_path, metadata)
return metadata
@@ -102,7 +161,7 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
logger.error(f"Error getting file info for {file_path}: {e}")
return None
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
async def save_metadata(file_path: str, metadata: BaseModelMetadata) -> None:
"""Save metadata to .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try:
@@ -115,7 +174,7 @@ async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
except Exception as e:
print(f"Error saving metadata to {metadata_path}: {str(e)}")
async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Load metadata from .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try:
@@ -138,6 +197,7 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
data['file_path'] = normalize_path(file_path)
needs_update = True
# TODO: optimize preview image to webp format if not already done
preview_url = data.get('preview_url', '')
if not preview_url or not os.path.exists(preview_url):
base_name = os.path.splitext(os.path.basename(file_path))[0]
@@ -162,12 +222,22 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
if 'modelDescription' not in data:
data['modelDescription'] = ""
needs_update = True
# For checkpoint metadata
if model_class == CheckpointMetadata and 'model_type' not in data:
data['model_type'] = "checkpoint"
needs_update = True
# For lora metadata
if model_class == LoraMetadata and 'usage_tips' not in data:
data['usage_tips'] = "{}"
needs_update = True
if needs_update:
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False)
return LoraMetadata.from_dict(data)
return model_class.from_dict(data)
except Exception as e:
print(f"Error loading metadata from {metadata_path}: {str(e)}")

View File

@@ -1,6 +1,7 @@
from safetensors import safe_open
from typing import Dict
from .model_utils import determine_base_model
import os
async def extract_lora_metadata(file_path: str) -> Dict:
"""Extract essential metadata from safetensors file"""
@@ -13,4 +14,67 @@ async def extract_lora_metadata(file_path: str) -> Dict:
return {"base_model": base_model}
except Exception as e:
print(f"Error reading metadata from {file_path}: {str(e)}")
return {"base_model": "Unknown"}
return {"base_model": "Unknown"}
async def extract_checkpoint_metadata(file_path: str) -> dict:
"""Extract metadata from a checkpoint file to determine model type and base model"""
try:
# Analyze filename for clues about the model
filename = os.path.basename(file_path).lower()
model_info = {
'base_model': 'Unknown',
'model_type': 'checkpoint'
}
# Detect base model from filename
if 'xl' in filename or 'sdxl' in filename:
model_info['base_model'] = 'SDXL'
elif 'sd3' in filename:
model_info['base_model'] = 'SD3'
elif 'sd2' in filename or 'v2' in filename:
model_info['base_model'] = 'SD2.x'
elif 'sd1' in filename or 'v1' in filename:
model_info['base_model'] = 'SD1.5'
# Detect model type from filename
if 'inpaint' in filename:
model_info['model_type'] = 'inpainting'
elif 'anime' in filename:
model_info['model_type'] = 'anime'
elif 'realistic' in filename:
model_info['model_type'] = 'realistic'
# Try to peek at the safetensors file structure if available
if file_path.endswith('.safetensors'):
import json
import struct
with open(file_path, 'rb') as f:
header_size = struct.unpack('<Q', f.read(8))[0]
header_json = f.read(header_size)
header = json.loads(header_json)
# Look for specific keys to identify model type
metadata = header.get('__metadata__', {})
if metadata:
# Try to determine if it's SDXL
if any(key.startswith('conditioner.embedders.1') for key in header):
model_info['base_model'] = 'SDXL'
# Look for model type info
if metadata.get('modelspec.architecture') == 'SD-XL':
model_info['base_model'] = 'SDXL'
elif metadata.get('modelspec.architecture') == 'SD-3':
model_info['base_model'] = 'SD3'
# Check for specific use case
if metadata.get('modelspec.purpose') == 'inpainting':
model_info['model_type'] = 'inpainting'
return model_info
except Exception as e:
logger.error(f"Error extracting checkpoint metadata for {file_path}: {e}")
# Return default values
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}

View File

@@ -5,20 +5,19 @@ import os
from .model_utils import determine_base_model
@dataclass
class LoraMetadata:
"""Represents the metadata structure for a Lora model"""
file_name: str # The filename without extension of the lora
model_name: str # The lora's name defined by the creator, initially same as file_name
file_path: str # Full path to the safetensors file
class BaseModelMetadata:
"""Base class for all model metadata structures"""
file_name: str # The filename without extension
model_name: str # The model's name defined by the creator
file_path: str # Full path to the model file
size: int # File size in bytes
modified: float # Last modified timestamp
sha256: str # SHA256 hash of the file
base_model: str # Base model (SD1.5/SD2.1/SDXL/etc.)
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
preview_url: str # Preview image URL
preview_nsfw_level: int = 0 # NSFW level of the preview image
usage_tips: str = "{}" # Usage tips for the model, json string
notes: str = "" # Additional notes
from_civitai: bool = True # Whether the lora is from Civitai
from_civitai: bool = True # Whether from Civitai
civitai: Optional[Dict] = None # Civitai API data if available
tags: List[str] = None # Model tags
modelDescription: str = "" # Full model description
@@ -29,32 +28,11 @@ class LoraMetadata:
self.tags = []
@classmethod
def from_dict(cls, data: Dict) -> 'LoraMetadata':
"""Create LoraMetadata instance from dictionary"""
# Create a copy of the data to avoid modifying the input
def from_dict(cls, data: Dict) -> 'BaseModelMetadata':
"""Create instance from dictionary"""
data_copy = data.copy()
return cls(**data_copy)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
"""Create LoraMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0, # Will be updated after preview download, it is decided by the nsfw level of the preview image
from_civitai=True,
civitai=version_info
)
def to_dict(self) -> Dict:
"""Convert to dictionary for JSON serialization"""
return asdict(self)
@@ -76,30 +54,54 @@ class LoraMetadata:
self.file_path = file_path.replace(os.sep, '/')
@dataclass
class CheckpointMetadata:
"""Represents the metadata structure for a Checkpoint model"""
file_name: str # The filename without extension
model_name: str # The checkpoint's name defined by the creator
file_path: str # Full path to the model file
size: int # File size in bytes
modified: float # Last modified timestamp
sha256: str # SHA256 hash of the file
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
preview_url: str # Preview image URL
preview_nsfw_level: int = 0 # NSFW level of the preview image
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
notes: str = "" # Additional notes
from_civitai: bool = True # Whether from Civitai
civitai: Optional[Dict] = None # Civitai API data if available
tags: List[str] = None # Model tags
modelDescription: str = "" # Full model description
# Additional checkpoint-specific fields
resolution: Optional[str] = None # Native resolution (e.g., 512x512, 1024x1024)
vae_included: bool = False # Whether VAE is included in the checkpoint
architecture: str = "" # Model architecture (if known)
def __post_init__(self):
if self.tags is None:
self.tags = []
class LoraMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Lora model"""
usage_tips: str = "{}" # Usage tips for the model, json string
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
"""Create LoraMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0, # Will be updated after preview download
from_civitai=True,
civitai=version_info
)
@dataclass
class CheckpointMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Checkpoint model"""
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
"""Create CheckpointMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
model_type = version_info.get('type', 'checkpoint')
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0,
from_civitai=True,
civitai=version_info,
model_type=model_type
)

503
py/utils/routes_common.py Normal file
View File

@@ -0,0 +1,503 @@
import os
import json
import logging
from typing import Dict, List, Callable, Awaitable
from aiohttp import web
from .model_utils import determine_base_model
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..config import config
from ..services.civitai_client import CivitaiClient
from ..utils.exif_utils import ExifUtils
from ..services.download_manager import DownloadManager
logger = logging.getLogger(__name__)
class ModelRouteUtils:
"""Shared utilities for model routes (LoRAs, Checkpoints, etc.)"""
@staticmethod
async def load_local_metadata(metadata_path: str) -> Dict:
"""Load local metadata file"""
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
return {}
@staticmethod
async def handle_not_found_on_civitai(metadata_path: str, local_metadata: Dict) -> None:
"""Handle case when model is not found on CivitAI"""
local_metadata['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
@staticmethod
async def update_model_metadata(metadata_path: str, local_metadata: Dict,
civitai_metadata: Dict, client: CivitaiClient) -> None:
"""Update local metadata with CivitAI data"""
local_metadata['civitai'] = civitai_metadata
# Update model name if available
if 'model' in civitai_metadata:
if civitai_metadata.get('model', {}).get('name'):
local_metadata['model_name'] = civitai_metadata['model']['name']
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata, _ = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Update base model
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
# Update preview if needed
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
if first_preview:
# Determine if content is video or image
is_video = first_preview['type'] == 'video'
if is_video:
# For videos use .mp4 extension
preview_ext = '.mp4'
else:
# For images use .webp extension
preview_ext = '.webp'
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_filename = base_name + preview_ext
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
if is_video:
# Download video as is
if await client.download_preview_image(first_preview['url'], preview_path):
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
else:
# For images, download and then optimize to WebP
temp_path = preview_path + ".temp"
if await client.download_preview_image(first_preview['url'], temp_path):
try:
# Read the downloaded image
with open(temp_path, 'rb') as f:
image_data = f.read()
# Optimize and convert to WebP
optimized_data, _ = ExifUtils.optimize_image(
image_data=image_data,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
)
# Save the optimized WebP image
with open(preview_path, 'wb') as f:
f.write(optimized_data)
# Update metadata
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Remove the temporary file
if os.path.exists(temp_path):
os.remove(temp_path)
except Exception as e:
logger.error(f"Error optimizing preview image: {e}")
# If optimization fails, try to use the downloaded image directly
if os.path.exists(temp_path):
os.rename(temp_path, preview_path)
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
@staticmethod
async def fetch_and_update_model(
sha256: str,
file_path: str,
model_data: dict,
update_cache_func: Callable[[str, str, Dict], Awaitable[bool]]
) -> bool:
"""Fetch and update metadata for a single model
Args:
sha256: SHA256 hash of the model file
file_path: Path to the model file
model_data: The model object in cache to update
update_cache_func: Function to update the cache with new metadata
Returns:
bool: True if successful, False otherwise
"""
client = CivitaiClient()
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Check if model metadata exists
local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Fetch metadata from Civitai
civitai_metadata = await client.get_model_by_hash(sha256)
if not civitai_metadata:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
model_data['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return False
# Update metadata
await ModelRouteUtils.update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
client
)
# Update cache object directly
model_data.update({
'model_name': local_metadata.get('model_name'),
'preview_url': local_metadata.get('preview_url'),
'from_civitai': True,
'civitai': civitai_metadata
})
# Update cache using the provided function
await update_cache_func(file_path, file_path, local_metadata)
return True
except Exception as e:
logger.error(f"Error fetching CivitAI data: {e}")
return False
finally:
await client.close()
@staticmethod
def filter_civitai_data(data: Dict) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
@staticmethod
async def delete_model_files(target_dir: str, file_name: str, file_monitor=None) -> List[str]:
"""Delete model and associated files
Args:
target_dir: Directory containing the model files
file_name: Base name of the model file without extension
file_monitor: Optional file monitor to ignore delete events
Returns:
List of deleted file paths
"""
patterns = [
f"{file_name}.safetensors", # Required
f"{file_name}.metadata.json",
]
# Add all preview file extensions
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{file_name}{ext}")
deleted = []
main_file = patterns[0]
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
if os.path.exists(main_path):
# Notify file monitor to ignore delete event if available
if file_monitor:
file_monitor.handler.add_ignore_path(main_path, 0)
# Delete file
os.remove(main_path)
deleted.append(main_path)
else:
logger.warning(f"Model file not found: {main_file}")
# Delete optional files
for pattern in patterns[1:]:
path = os.path.join(target_dir, pattern)
if os.path.exists(path):
try:
os.remove(path)
deleted.append(pattern)
except Exception as e:
logger.warning(f"Failed to delete {pattern}: {e}")
return deleted
@staticmethod
def get_multipart_ext(filename):
"""Get extension that may have multiple parts like .metadata.json"""
parts = filename.split(".")
if len(parts) > 2: # If contains multi-part extension
return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json"
return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors"
# New common endpoint handlers
@staticmethod
async def handle_delete_model(request: web.Request, scanner) -> web.Response:
"""Handle model deletion request
Args:
request: The aiohttp request
scanner: The model scanner instance with cache management methods
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
file_path = data.get('file_path')
if not file_path:
return web.Response(text='Model path is required', status=400)
target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0]
# Get the file monitor from the scanner if available
file_monitor = getattr(scanner, 'file_monitor', None)
deleted_files = await ModelRouteUtils.delete_model_files(
target_dir,
file_name,
file_monitor
)
# Remove from cache
cache = await scanner.get_cached_data()
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path]
await cache.resort()
# Update hash index if available
if hasattr(scanner, '_hash_index') and scanner._hash_index:
scanner._hash_index.remove_by_path(file_path)
return web.json_response({
'success': True,
'deleted_files': deleted_files
})
except Exception as e:
logger.error(f"Error deleting model: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
@staticmethod
async def handle_fetch_civitai(request: web.Request, scanner) -> web.Response:
"""Handle CivitAI metadata fetch request
Args:
request: The aiohttp request
scanner: The model scanner instance with cache management methods
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
# Check if model metadata exists
local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
if not local_metadata or not local_metadata.get('sha256'):
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
# Create a client for fetching from Civitai
client = CivitaiClient()
try:
# Fetch and update metadata
civitai_metadata = await client.get_model_by_hash(local_metadata["sha256"])
if not civitai_metadata:
await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata)
return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404)
await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, client)
# Update the cache
await scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata)
return web.json_response({"success": True})
finally:
await client.close()
except Exception as e:
logger.error(f"Error fetching from CivitAI: {e}", exc_info=True)
return web.json_response({"success": False, "error": str(e)}, status=500)
@staticmethod
async def handle_replace_preview(request: web.Request, scanner) -> web.Response:
"""Handle preview image replacement request
Args:
request: The aiohttp request
scanner: The model scanner instance with methods to update cache
Returns:
web.Response: The HTTP response
"""
try:
reader = await request.multipart()
# Read preview file data
field = await reader.next()
if field.name != 'preview_file':
raise ValueError("Expected 'preview_file' field")
content_type = field.headers.get('Content-Type', 'image/png')
preview_data = await field.read()
# Read model path
field = await reader.next()
if field.name != 'model_path':
raise ValueError("Expected 'model_path' field")
model_path = (await field.read()).decode()
# Save preview file
base_name = os.path.splitext(os.path.basename(model_path))[0]
folder = os.path.dirname(model_path)
# Determine if content is video or image
if content_type.startswith('video/'):
# For videos, keep original format and use .mp4 extension
extension = '.mp4'
optimized_data = preview_data
else:
# For images, optimize and convert to WebP
optimized_data, _ = ExifUtils.optimize_image(
image_data=preview_data,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
)
extension = '.webp' # Use .webp without .preview part
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/')
with open(preview_path, 'wb') as f:
f.write(optimized_data)
# Update preview path in metadata
metadata_path = os.path.splitext(model_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
# Update preview_url directly in the metadata dict
metadata['preview_url'] = preview_path
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Error updating metadata: {e}")
# Update preview URL in scanner cache
if hasattr(scanner, 'update_preview_in_cache'):
await scanner.update_preview_in_cache(model_path, preview_path)
return web.json_response({
"success": True,
"preview_url": config.get_preview_static_url(preview_path)
})
except Exception as e:
logger.error(f"Error replacing preview: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
@staticmethod
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response:
"""Handle model download request
Args:
request: The aiohttp request
download_manager: Instance of DownloadManager
model_type: Type of model ('lora' or 'checkpoint')
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
# Create progress callback
async def progress_callback(progress):
from ..services.websocket_manager import ws_manager
await ws_manager.broadcast({
'status': 'progress',
'progress': progress
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_version_id = data.get('model_version_id')
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
)
# Use the correct root directory based on model type
root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root'
save_dir = data.get(root_key)
result = await download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_version_id=model_version_id,
save_dir=save_dir,
relative_path=data.get('relative_path', ''),
progress_callback=progress_callback,
model_type=model_type
)
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')
# Return 401 for early access errors
if 'early access' in error_message.lower():
logger.warning(f"Early access download failed: {error_message}")
return web.Response(
status=401, # Use 401 status code to match Civitai's response
text=f"Early Access Restriction: {error_message}"
)
return web.Response(status=500, text=error_message)
return web.json_response(result)
except Exception as e:
error_message = str(e)
# Check if this might be an early access error
if '401' in error_message:
logger.warning(f"Early access error (401): {error_message}")
return web.Response(
status=401,
text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
)
logger.error(f"Error downloading {model_type}: {error_message}")
return web.Response(status=500, text=error_message)

View File

@@ -1,7 +1,7 @@
[project]
name = "comfyui-lora-manager"
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
version = "0.8.4"
version = "0.8.6"
license = {file = "LICENSE"}
dependencies = [
"aiohttp",
@@ -11,6 +11,7 @@ dependencies = [
"beautifulsoup4",
"piexif",
"Pillow",
"olefile", # for getting rid of warning message
"requests"
]

View File

@@ -5,4 +5,5 @@ watchdog
beautifulsoup4
piexif
Pillow
olefile
requests

View File

@@ -23,6 +23,7 @@
cursor: pointer; /* Added from recipe-card */
display: flex; /* Added from recipe-card */
flex-direction: column; /* Added from recipe-card */
overflow: hidden; /* Add overflow hidden to contain children */
}
.lora-card:hover {
@@ -50,9 +51,11 @@
.card-preview {
position: relative;
width: 100%;
height: 100%;
height: 100%; /* This should work with aspect-ratio on parent */
border-radius: var(--border-radius-base);
overflow: hidden;
flex-shrink: 0; /* Prevent shrinking */
min-height: 0; /* Fix for potential flexbox sizing issue in Firefox */
}
.card-preview img,

View File

@@ -0,0 +1,84 @@
/* Filter indicator styles */
.control-group .filter-active {
display: flex;
align-items: center;
gap: 6px;
background: var(--lora-accent);
color: white;
border-radius: var(--border-radius-xs);
padding: 4px 10px;
transition: all 0.2s ease;
border: 1px solid var(--lora-accent);
cursor: pointer;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
font-size: 0.85em;
}
.control-group .filter-active:hover {
opacity: 0.92;
transform: translateY(-1px);
box-shadow: 0 3px 5px rgba(0, 0, 0, 0.15);
}
.control-group .filter-active:active {
transform: translateY(0);
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
}
.control-group .filter-active i.fa-filter {
font-size: 0.9em;
margin-right: 2px;
opacity: 0.9;
}
.control-group .filter-active i.clear-filter {
transition: transform 0.2s ease, background-color 0.2s ease;
cursor: pointer;
margin-left: 4px;
border-radius: 50%;
font-size: 0.85em;
width: 16px;
height: 16px;
display: flex;
align-items: center;
justify-content: center;
}
.control-group .filter-active i.clear-filter:hover {
transform: scale(1.2);
background-color: rgba(255, 255, 255, 0.2);
}
.control-group .filter-active .lora-name {
font-weight: 500;
max-width: 150px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
/* Animation for filter indicator */
@keyframes filterPulse {
0% { transform: scale(1); box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); }
50% { transform: scale(1.03); box-shadow: 0 3px 8px rgba(0, 0, 0, 0.15); }
100% { transform: scale(1); box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); }
}
.filter-active.animate {
animation: filterPulse 0.6s ease;
}
/* Make responsive */
@media (max-width: 576px) {
.control-group .filter-active {
padding: 6px 10px;
}
.control-group .filter-active .lora-name {
max-width: 100px;
}
.control-group .filter-active:hover {
transform: none; /* Disable hover effects on mobile */
}
}

View File

@@ -0,0 +1,359 @@
/* Initialization Component Styles */
.initialization-container {
width: 100%;
height: 100%;
padding: var(--space-3);
background: var(--lora-surface);
animation: fadeIn 0.3s ease-in-out;
display: flex;
align-items: center;
justify-content: center;
}
.initialization-content {
max-width: 800px;
width: 100%;
}
/* Override loading.css width for initialization component */
.initialization-container .loading-content {
width: 100%;
max-width: 100%;
background: transparent;
backdrop-filter: none;
border: none;
padding: 0;
}
.initialization-header {
text-align: center;
margin-bottom: var(--space-3);
}
.initialization-header h2 {
font-size: 1.8rem;
margin-bottom: var(--space-1);
color: var(--text-color);
}
.init-subtitle {
color: var(--text-color);
opacity: 0.8;
font-size: 1rem;
}
/* Progress Bar Styles specific to initialization */
.initialization-progress {
margin-bottom: var(--space-3);
}
/* Renamed container class */
.init-progress-container {
width: 100%; /* Use full width within its container */
height: 8px; /* Match height from previous .progress-bar-container */
background-color: var(--lora-border); /* Consistent background */
border-radius: 4px;
overflow: hidden;
margin: 0 auto var(--space-1); /* Center horizontally, add bottom margin */
}
/* Renamed progress bar class */
.init-progress-bar {
height: 100%;
/* Use a gradient consistent with the theme accent */
background: linear-gradient(90deg, var(--lora-accent) 0%, color-mix(in oklch, var(--lora-accent) 80%, transparent) 100%);
border-radius: 4px; /* Match container radius */
transition: width 0.3s ease;
width: 0%; /* Start at 0% */
}
/* Remove the old .progress-bar rule specific to initialization to avoid conflicts */
/* .progress-bar { ... } */
/* Progress Details */
.progress-details {
display: flex;
justify-content: space-between;
font-size: 0.9rem;
color: var(--text-color);
margin-top: var(--space-1);
padding: 0 2px;
}
#remainingTime {
font-style: italic;
color: var(--text-color);
opacity: 0.8;
}
/* Stages Styles */
.initialization-stages {
margin-bottom: var(--space-3);
}
.stage-item {
display: flex;
align-items: flex-start;
padding: var(--space-2);
border-radius: var(--border-radius-xs);
margin-bottom: var(--space-1);
transition: background-color 0.2s ease;
border: 1px solid transparent;
}
.stage-item.active {
background-color: rgba(var(--lora-accent), 0.1);
border-color: var(--lora-accent);
}
.stage-item.completed {
background-color: rgba(0, 150, 0, 0.05);
border-color: rgba(0, 150, 0, 0.2);
}
.stage-icon {
display: flex;
align-items: center;
justify-content: center;
width: 40px;
height: 40px;
background: var(--lora-border);
border-radius: 50%;
margin-right: var(--space-2);
}
.stage-item.active .stage-icon {
background: var(--lora-accent);
color: white;
}
.stage-item.completed .stage-icon {
background: rgb(0, 150, 0);
color: white;
}
.stage-content {
flex: 1;
}
.stage-content h4 {
margin: 0 0 5px 0;
font-size: 1rem;
color: var(--text-color);
}
.stage-details {
font-size: 0.85rem;
color: var(--text-color);
opacity: 0.8;
}
.stage-status {
display: flex;
align-items: center;
justify-content: center;
width: 24px;
height: 24px;
}
.stage-status.pending {
color: var(--text-color);
opacity: 0.5;
}
.stage-status.in-progress {
color: var(--lora-accent);
}
.stage-status.completed {
color: rgb(0, 150, 0);
}
/* Tips Container */
.tips-container {
margin-top: var(--space-3);
background: rgba(var(--lora-accent), 0.05);
border-radius: var(--border-radius-base);
padding: var(--space-2);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
}
.tips-header {
display: flex;
align-items: center;
margin-bottom: var(--space-2);
padding-bottom: var(--space-1);
border-bottom: 1px solid var(--lora-border);
}
.tips-header i {
margin-right: 10px;
color: var(--lora-accent);
font-size: 1.2rem;
}
.tips-header h3 {
font-size: 1.2rem;
margin: 0;
color: var(--text-color);
}
/* Tip Carousel with Images */
.tips-content {
position: relative;
}
.tip-carousel {
position: relative;
height: 160px;
overflow: hidden;
}
.tip-item {
position: absolute;
width: 100%;
height: 100%;
display: flex;
opacity: 0;
transition: opacity 0.5s ease;
padding: 0;
border-radius: var(--border-radius-sm);
overflow: hidden;
}
.tip-item.active {
opacity: 1;
}
.tip-image {
width: 40%;
overflow: hidden;
display: flex;
align-items: center;
justify-content: center;
background-color: var(--lora-border);
}
.tip-image img {
width: 100%;
height: 100%;
object-fit: cover;
}
.tip-text {
width: 60%;
padding: var(--space-2);
display: flex;
flex-direction: column;
justify-content: center;
}
.tip-text h4 {
margin: 0 0 var(--space-1) 0;
font-size: 1.1rem;
color: var(--text-color);
}
.tip-text p {
margin: 0;
line-height: 1.5;
font-size: 0.9rem;
color: var(--text-color);
}
.tip-navigation {
display: flex;
justify-content: center;
margin-top: var(--space-2);
}
.tip-dot {
width: 10px;
height: 10px;
border-radius: 50%;
background-color: var(--lora-border);
margin: 0 5px;
cursor: pointer;
transition: background-color 0.2s ease, transform 0.2s ease;
}
.tip-dot:hover {
transform: scale(1.2);
}
.tip-dot.active {
background-color: var(--lora-accent);
}
/* Animation */
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
/* Different stage status animations */
@keyframes pulse {
0% {
transform: scale(1);
}
50% {
transform: scale(1.2);
}
100% {
transform: scale(1);
}
}
.stage-item.active .stage-icon i {
animation: pulse 1s infinite;
}
/* Responsive Adjustments */
@media (max-width: 768px) {
.initialization-container {
padding: var(--space-2);
}
.stage-item {
padding: var(--space-1);
}
.stage-icon {
width: 32px;
height: 32px;
min-width: 32px;
}
.tip-item {
flex-direction: column;
height: 220px;
}
.tip-image, .tip-text {
width: 100%;
}
.tip-image {
height: 120px;
}
.tip-carousel {
height: 220px;
}
}
@media (prefers-reduced-motion: reduce) {
.initialization-container,
.tip-item,
.tip-dot {
transition: none;
animation: none;
}
}

View File

@@ -863,7 +863,7 @@
}
.model-description-content blockquote {
border-left: 3px solid var(--lora-accent);
border-left: 3px solid var (--lora-accent);
padding-left: 1em;
margin-left: 0;
margin-right: 0;
@@ -1280,4 +1280,47 @@
font-size: 1.1em;
color: var(--lora-accent);
opacity: 0.8;
}
.view-all-btn {
display: flex;
align-items: center;
gap: 5px;
padding: 6px 12px;
background-color: var(--lora-accent);
color: var(--lora-text);
border: none;
border-radius: var(--border-radius-sm);
cursor: pointer;
transition: background-color 0.2s;
font-size: 13px;
}
.view-all-btn:hover {
opacity: 0.9;
}
/* Loading, error and empty states */
.recipes-loading,
.recipes-error,
.recipes-empty {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 40px;
text-align: center;
min-height: 200px;
}
.recipes-loading i,
.recipes-error i,
.recipes-empty i {
font-size: 32px;
margin-bottom: 15px;
color: var(--lora-accent);
}
.recipes-error i {
color: var(--lora-error);
}

View File

@@ -196,7 +196,7 @@ body.modal-open {
}
.settings-modal {
max-width: 500px;
max-width: 650px; /* Further increased from 600px for more space */
}
/* Settings Links */
@@ -266,14 +266,22 @@ body.modal-open {
}
}
/* API key input specific styles */
.api-key-input {
width: 100%; /* Take full width of parent */
position: relative;
display: flex;
align-items: center;
}
.api-key-input input {
padding-right: 40px;
width: 100%;
padding: 6px 40px 6px 10px; /* Add left padding */
height: 32px;
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
background-color: var(--lora-surface);
color: var(--text-color);
}
.api-key-input .toggle-visibility {
@@ -294,8 +302,10 @@ body.modal-open {
.input-help {
font-size: 0.85em;
color: var(--text-color);
opacity: 0.8;
margin-top: 4px;
opacity: 0.7;
margin-top: 8px; /* Space between control and help */
line-height: 1.4;
width: 100%; /* Full width */
}
/* 统一各个 section 的样式 */
@@ -341,8 +351,8 @@ body.modal-open {
.setting-item {
display: flex;
flex-direction: column;
margin-bottom: var(--space-2);
flex-direction: column; /* Changed to column for help text placement */
margin-bottom: var(--space-3); /* Increased to provide more spacing between items */
padding: var(--space-1);
border-radius: var(--border-radius-xs);
}
@@ -355,35 +365,52 @@ body.modal-open {
background: rgba(255, 255, 255, 0.05);
}
.setting-info {
margin-bottom: var(--space-1);
/* Control row with label and input together */
.setting-row {
display: flex;
flex-direction: row;
justify-content: space-between;
align-items: center;
width: 100%;
}
.setting-info {
margin-bottom: 0;
width: 35%; /* Increased from 30% to prevent wrapping */
flex-shrink: 0; /* Prevent shrinking */
}
.setting-info label {
display: block;
margin-bottom: 4px;
font-weight: 500;
margin-bottom: 0;
white-space: nowrap; /* Prevent label wrapping */
}
.setting-control {
width: 100%;
margin-bottom: var(--space-1);
width: 60%; /* Decreased slightly from 65% */
margin-bottom: 0;
display: flex;
justify-content: flex-end; /* Right-align all controls */
}
/* Select Control Styles */
.select-control {
width: 100%;
display: flex;
justify-content: flex-end;
}
.select-control select {
width: 100%;
max-width: 100%; /* Increased from 200px */
padding: 6px 10px;
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
background-color: var(--lora-surface);
color: var(--text-color);
font-size: 0.95em;
height: 32px;
}
/* Fix dark theme select dropdown text color */
@@ -409,6 +436,7 @@ body.modal-open {
width: 50px;
height: 24px;
cursor: pointer;
margin-left: auto; /* Push to right side */
}
.toggle-switch input {
@@ -458,15 +486,6 @@ input:checked + .toggle-slider:before {
width: 22px;
}
/* Update input help styles */
.input-help {
font-size: 0.85em;
color: var(--text-color);
opacity: 0.7;
margin-top: 4px;
line-height: 1.4;
}
/* Blur effect for NSFW content */
.nsfw-blur {
filter: blur(12px);

View File

@@ -1,184 +0,0 @@
.recipe-tag-container {
display: flex;
flex-wrap: wrap;
gap: 0.5rem;
margin-bottom: 1rem;
}
.recipe-tag {
background: var(--lora-surface-hover);
color: var(--lora-text-secondary);
padding: 0.25rem 0.5rem;
border-radius: var(--border-radius-sm);
font-size: 0.8rem;
cursor: pointer;
transition: all 0.2s ease;
}
.recipe-tag:hover, .recipe-tag.active {
background: var(--lora-primary);
color: var(--lora-text-on-primary);
}
.recipe-card {
position: relative;
background: var(--lora-surface);
border-radius: var(--border-radius-base);
overflow: hidden;
box-shadow: var(--shadow-sm);
transition: all 0.2s ease;
aspect-ratio: 896/1152;
cursor: pointer;
display: flex;
flex-direction: column;
}
.recipe-card:hover {
transform: translateY(-3px);
box-shadow: var(--shadow-md);
}
.recipe-card:focus-visible {
outline: 2px solid var(--lora-accent);
outline-offset: 2px;
}
.recipe-indicator {
position: absolute;
top: 6px;
left: 8px;
width: 24px;
height: 24px;
background: var(--lora-primary);
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
color: white;
font-weight: bold;
z-index: 2;
}
.recipe-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
gap: 1.5rem;
margin-top: 1.5rem;
}
.placeholder-message {
grid-column: 1 / -1;
text-align: center;
padding: 2rem;
background: var(--lora-surface-alt);
border-radius: var(--border-radius-base);
}
.card-preview {
position: relative;
width: 100%;
height: 100%;
border-radius: var(--border-radius-base);
overflow: hidden;
}
.card-preview img {
width: 100%;
height: 100%;
object-fit: cover;
object-position: center top;
}
.card-header {
position: absolute;
top: 0;
left: 0;
right: 0;
background: linear-gradient(oklch(0% 0 0 / 0.75), transparent 85%);
backdrop-filter: blur(8px);
color: white;
padding: var(--space-1);
display: flex;
justify-content: space-between;
align-items: center;
z-index: 1;
min-height: 20px;
}
.base-model-wrapper {
display: flex;
align-items: center;
gap: 8px;
margin-left: 32px;
}
.card-actions {
display: flex;
gap: 8px;
}
.card-actions i {
cursor: pointer;
opacity: 0.8;
transition: opacity 0.2s ease;
}
.card-actions i:hover {
opacity: 1;
}
.card-footer {
position: absolute;
bottom: 0;
left: 0;
right: 0;
background: linear-gradient(transparent 15%, oklch(0% 0 0 / 0.75));
backdrop-filter: blur(8px);
color: white;
padding: var(--space-1);
display: flex;
justify-content: space-between;
align-items: flex-start;
min-height: 32px;
gap: var(--space-1);
}
.lora-count {
display: flex;
align-items: center;
gap: 4px;
background: rgba(255, 255, 255, 0.2);
padding: 2px 8px;
border-radius: var(--border-radius-xs);
font-size: 0.85em;
position: relative;
}
.lora-count.ready {
background: rgba(46, 204, 113, 0.3);
}
.lora-count.missing {
background: rgba(231, 76, 60, 0.3);
}
/* 响应式设计 */
@media (max-width: 1400px) {
.recipe-grid {
grid-template-columns: repeat(auto-fill, minmax(240px, 1fr));
}
.recipe-card {
max-width: 240px;
}
}
@media (max-width: 768px) {
.recipe-grid {
grid-template-columns: minmax(260px, 1fr);
}
.recipe-card {
max-width: 100%;
}
}

View File

@@ -400,6 +400,27 @@
gap: var(--space-1);
}
/* View LoRAs button */
.view-loras-btn {
background: none;
border: none;
color: var(--text-color);
opacity: 0.7;
cursor: pointer;
padding: 4px 8px;
border-radius: var(--border-radius-xs);
transition: all 0.2s;
display: flex;
align-items: center;
justify-content: center;
}
.view-loras-btn:hover {
opacity: 1;
background: var(--lora-surface);
color: var(--lora-accent);
}
#recipeLorasCount {
font-size: 0.9em;
color: var(--text-color);
@@ -420,6 +441,7 @@
gap: 10px;
overflow-y: auto;
flex: 1;
padding-top: 4px; /* Add padding to prevent first item from being cut off when hovered */
}
.recipe-lora-item {
@@ -433,6 +455,14 @@
will-change: transform;
/* Create a new containing block for absolutely positioned descendants */
transform: translateZ(0);
cursor: pointer; /* Make it clear the item is clickable */
transition: transform 0.2s ease, box-shadow 0.2s ease, border-color 0.2s ease;
}
.recipe-lora-item:hover {
transform: translateY(-1px);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
border-color: var(--lora-accent);
}
.recipe-lora-item.exists-locally {

View File

@@ -38,6 +38,90 @@
flex-wrap: nowrap;
}
/* Action button styling */
.control-group {
position: relative;
}
.control-group button {
min-width: 100px;
display: flex;
align-items: center;
justify-content: center;
gap: 4px;
border-radius: var(--border-radius-xs);
padding: 4px 10px;
border: 1px solid var(--border-color);
background: var(--card-bg);
color: var(--text-color);
font-size: 0.85em;
transition: all 0.2s ease;
cursor: pointer;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
}
.control-group button:hover {
border-color: var(--lora-accent);
background: var(--bg-color);
transform: translateY(-1px);
box-shadow: 0 3px 5px rgba(0, 0, 0, 0.08);
}
.control-group button:active {
transform: translateY(0);
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
}
.control-group button i {
opacity: 0.8;
transition: opacity 0.2s ease;
}
.control-group button:hover i {
opacity: 1;
}
/* Active state for buttons that can be toggled */
.control-group button.active {
background: var(--lora-accent);
color: white;
border-color: var(--lora-accent);
}
/* Select dropdown styling */
.control-group select {
min-width: 100px;
padding: 4px 26px 4px 10px;
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
background-color: var(--card-bg);
color: var(--text-color);
font-size: 0.85em;
appearance: none;
-webkit-appearance: none;
-moz-appearance: none;
background-image: url("data:image/svg+xml;charset=UTF-8,%3csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='none' stroke='currentColor' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3e%3cpolyline points='6 9 12 15 18 9'%3e%3c/polyline%3e%3c/svg%3e");
background-repeat: no-repeat;
background-position: right 6px center;
background-size: 14px;
cursor: pointer;
transition: all 0.2s ease;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
}
.control-group select:hover {
border-color: var(--lora-accent);
background-color: var(--bg-color);
transform: translateY(-1px);
box-shadow: 0 3px 5px rgba(0, 0, 0, 0.08);
}
.control-group select:focus {
outline: none;
border-color: var(--lora-accent);
box-shadow: 0 0 0 2px oklch(var(--lora-accent) / 0.15);
}
/* Ensure hidden class works properly */
.hidden {
display: none !important;
@@ -86,12 +170,14 @@
justify-content: center;
cursor: pointer;
transition: all 0.3s ease;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
}
.toggle-folders-btn:hover {
background: var(--lora-accent);
color: white;
transform: translateY(-2px);
box-shadow: 0 3px 6px rgba(0, 0, 0, 0.1);
}
.toggle-folders-btn i {
@@ -101,8 +187,9 @@
/* Icon-only button style */
.icon-only {
min-width: unset !important;
width: 36px !important;
width: 32px !important;
padding: 0 !important;
height: 32px !important;
}
/* Rotate icon when folders are collapsed */
@@ -133,16 +220,25 @@
cursor: pointer;
padding: 2px 8px;
margin: 2px;
border: 1px solid #ccc;
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
display: inline-block;
line-height: 1.2;
font-size: 14px;
background-color: var(--card-bg);
transition: all 0.2s ease;
}
.tag:hover {
border-color: var(--lora-accent);
background-color: oklch(var(--lora-accent) / 0.1);
transform: translateY(-1px);
}
.tag.active {
background-color: #007bff;
background-color: var(--lora-accent);
color: white;
border-color: var(--lora-accent);
}
/* Back to Top Button */
@@ -155,7 +251,7 @@
border-radius: 50%;
background: var(--card-bg);
border: 1px solid var(--border-color);
color: var (--text-color);
color: var(--text-color);
display: flex;
align-items: center;
justify-content: center;
@@ -165,6 +261,7 @@
transform: translateY(10px);
transition: all 0.3s ease;
z-index: var(--z-overlay);
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
}
.back-to-top.visible {
@@ -174,9 +271,10 @@
}
.back-to-top:hover {
background: var (--lora-accent);
background: var(--lora-accent);
color: white;
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15);
}
@media (max-width: 768px) {
@@ -203,19 +301,22 @@
}
.toggle-folders-btn:hover {
transform: none; /* 移动端下禁用hover效果 */
transform: none; /* Disable hover effects on mobile */
}
.control-group button:hover {
transform: none; /* Disable hover effects on mobile */
}
.control-group select:hover {
transform: none; /* Disable hover effects on mobile */
}
.tag:hover {
transform: none; /* Disable hover effects on mobile */
}
.back-to-top {
bottom: 60px; /* Give some extra space from bottom on mobile */
}
}
/* Standardize button widths in controls */
.control-group button {
min-width: 100px;
display: flex;
align-items: center;
justify-content: center;
gap: 6px;
}

View File

@@ -18,6 +18,8 @@
@import 'components/search-filter.css';
@import 'components/bulk.css';
@import 'components/shared.css';
@import 'components/filter-indicator.css';
@import 'components/initialization.css';
.initialization-notice {
display: flex;

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 213 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 338 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.8 MiB

View File

@@ -0,0 +1,499 @@
// filepath: d:\Workspace\ComfyUI\custom_nodes\ComfyUI-Lora-Manager\static\js\api\baseModelApi.js
import { state, getCurrentPageState } from '../state/index.js';
import { showToast } from '../utils/uiHelpers.js';
import { showDeleteModal, confirmDelete } from '../utils/modalUtils.js';
import { getSessionItem } from '../utils/storageHelpers.js';
/**
* Shared functionality for handling models (loras and checkpoints)
*/
// Generic function to load more models with pagination
export async function loadMoreModels(options = {}) {
const {
resetPage = false,
updateFolders = false,
modelType = 'lora', // 'lora' or 'checkpoint'
createCardFunction,
endpoint = '/api/loras'
} = options;
const pageState = getCurrentPageState();
if (pageState.isLoading || (!pageState.hasMore && !resetPage)) return;
pageState.isLoading = true;
document.body.classList.add('loading');
try {
// Reset to first page if requested
if (resetPage) {
pageState.currentPage = 1;
// Clear grid if resetting
const gridId = modelType === 'checkpoint' ? 'checkpointGrid' : 'loraGrid';
const grid = document.getElementById(gridId);
if (grid) grid.innerHTML = '';
}
const params = new URLSearchParams({
page: pageState.currentPage,
page_size: pageState.pageSize || 20,
sort_by: pageState.sortBy
});
if (pageState.activeFolder !== null) {
params.append('folder', pageState.activeFolder);
}
// Add search parameters if there's a search term
if (pageState.filters?.search) {
params.append('search', pageState.filters.search);
params.append('fuzzy', 'true');
// Add search option parameters if available
if (pageState.searchOptions) {
params.append('search_filename', pageState.searchOptions.filename.toString());
params.append('search_modelname', pageState.searchOptions.modelname.toString());
if (pageState.searchOptions.tags !== undefined) {
params.append('search_tags', pageState.searchOptions.tags.toString());
}
params.append('recursive', (pageState.searchOptions?.recursive ?? false).toString());
}
}
// Add filter parameters if active
if (pageState.filters) {
// Handle tags filters
if (pageState.filters.tags && pageState.filters.tags.length > 0) {
// Checkpoints API expects individual 'tag' parameters, Loras API expects comma-separated 'tags'
if (modelType === 'checkpoint') {
pageState.filters.tags.forEach(tag => {
params.append('tag', tag);
});
} else {
params.append('tags', pageState.filters.tags.join(','));
}
}
// Handle base model filters
if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) {
if (modelType === 'checkpoint') {
pageState.filters.baseModel.forEach(model => {
params.append('base_model', model);
});
} else {
params.append('base_models', pageState.filters.baseModel.join(','));
}
}
}
// Add model-specific parameters
if (modelType === 'lora') {
// Check for recipe-based filtering parameters from session storage
const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash');
const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes');
// Add hash filter parameter if present
if (filterLoraHash) {
params.append('lora_hash', filterLoraHash);
}
// Add multiple hashes filter if present
else if (filterLoraHashes) {
try {
if (Array.isArray(filterLoraHashes) && filterLoraHashes.length > 0) {
params.append('lora_hashes', filterLoraHashes.join(','));
}
} catch (error) {
console.error('Error parsing lora hashes from session storage:', error);
}
}
}
const response = await fetch(`${endpoint}?${params}`);
if (!response.ok) {
throw new Error(`Failed to fetch models: ${response.statusText}`);
}
const data = await response.json();
const gridId = modelType === 'checkpoint' ? 'checkpointGrid' : 'loraGrid';
const grid = document.getElementById(gridId);
if (data.items.length === 0 && pageState.currentPage === 1) {
grid.innerHTML = `<div class="no-results">No ${modelType}s found in this folder</div>`;
pageState.hasMore = false;
} else if (data.items.length > 0) {
pageState.hasMore = pageState.currentPage < data.total_pages;
// Append model cards using the provided card creation function
data.items.forEach(model => {
const card = createCardFunction(model);
grid.appendChild(card);
});
// Increment the page number AFTER successful loading
pageState.currentPage++;
} else {
pageState.hasMore = false;
}
if (updateFolders && data.folders) {
updateFolderTags(data.folders);
}
} catch (error) {
console.error(`Error loading ${modelType}s:`, error);
showToast(`Failed to load ${modelType}s: ${error.message}`, 'error');
} finally {
pageState.isLoading = false;
document.body.classList.remove('loading');
}
}
// Update folder tags in the UI
export function updateFolderTags(folders) {
const folderTagsContainer = document.querySelector('.folder-tags');
if (!folderTagsContainer) return;
// Keep track of currently selected folder
const pageState = getCurrentPageState();
const currentFolder = pageState.activeFolder;
// Create HTML for folder tags
const tagsHTML = folders.map(folder => {
const isActive = folder === currentFolder;
return `<div class="tag ${isActive ? 'active' : ''}" data-folder="${folder}">${folder}</div>`;
}).join('');
// Update the container
folderTagsContainer.innerHTML = tagsHTML;
// Reattach click handlers and ensure the active tag is visible
const tags = folderTagsContainer.querySelectorAll('.tag');
tags.forEach(tag => {
if (typeof toggleFolder === 'function') {
tag.addEventListener('click', toggleFolder);
}
if (tag.dataset.folder === currentFolder) {
tag.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}
});
}
// Generic function to replace a model preview
export function replaceModelPreview(filePath, modelType = 'lora') {
// Open file picker
const input = document.createElement('input');
input.type = 'file';
input.accept ='image/*,video/mp4';
input.onchange = async function() {
if (!input.files || !input.files[0]) return;
const file = input.files[0];
await uploadPreview(filePath, file, modelType);
};
input.click();
}
// Delete a model (generic)
export function deleteModel(filePath, modelType = 'lora') {
if (modelType === 'checkpoint') {
confirmDelete('Are you sure you want to delete this checkpoint?', () => {
performDelete(filePath, modelType);
});
} else {
showDeleteModal(filePath);
}
}
// Reset and reload models
export async function resetAndReload(options = {}) {
const {
updateFolders = false,
modelType = 'lora',
loadMoreFunction
} = options;
const pageState = getCurrentPageState();
// Reset pagination and load more models
if (typeof loadMoreFunction === 'function') {
await loadMoreFunction(true, updateFolders);
}
}
// Generic function to refresh models
export async function refreshModels(options = {}) {
const {
modelType = 'lora',
scanEndpoint = '/api/loras/scan',
resetAndReloadFunction
} = options;
try {
state.loadingManager.showSimpleLoading(`Refreshing ${modelType}s...`);
const response = await fetch(scanEndpoint);
if (!response.ok) {
throw new Error(`Failed to refresh ${modelType}s: ${response.status} ${response.statusText}`);
}
if (typeof resetAndReloadFunction === 'function') {
await resetAndReloadFunction();
}
showToast(`Refresh complete`, 'success');
} catch (error) {
console.error(`Refresh failed:`, error);
showToast(`Failed to refresh ${modelType}s`, 'error');
} finally {
state.loadingManager.hide();
state.loadingManager.restoreProgressBar();
}
}
// Generic fetch from Civitai
export async function fetchCivitaiMetadata(options = {}) {
const {
modelType = 'lora',
fetchEndpoint = '/api/fetch-all-civitai',
resetAndReloadFunction
} = options;
let ws = null;
await state.loadingManager.showWithProgress(async (loading) => {
try {
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
const operationComplete = new Promise((resolve, reject) => {
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
switch(data.status) {
case 'started':
loading.setStatus('Starting metadata fetch...');
break;
case 'processing':
const percent = ((data.processed / data.total) * 100).toFixed(1);
loading.setProgress(percent);
loading.setStatus(
`Processing (${data.processed}/${data.total}) ${data.current_name}`
);
break;
case 'completed':
loading.setProgress(100);
loading.setStatus(
`Completed: Updated ${data.success} of ${data.processed} ${modelType}s`
);
resolve();
break;
case 'error':
reject(new Error(data.error));
break;
}
};
ws.onerror = (error) => {
reject(new Error('WebSocket error: ' + error.message));
};
});
await new Promise((resolve, reject) => {
ws.onopen = resolve;
ws.onerror = reject;
});
const requestBody = modelType === 'checkpoint'
? JSON.stringify({ model_type: 'checkpoint' })
: JSON.stringify({});
const response = await fetch(fetchEndpoint, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: requestBody
});
if (!response.ok) {
throw new Error('Failed to fetch metadata');
}
await operationComplete;
if (typeof resetAndReloadFunction === 'function') {
await resetAndReloadFunction();
}
} catch (error) {
console.error('Error fetching metadata:', error);
showToast('Failed to fetch metadata: ' + error.message, 'error');
} finally {
if (ws) {
ws.close();
}
}
}, {
initialMessage: 'Connecting...',
completionMessage: 'Metadata update complete'
});
}
// Generic function to refresh single model metadata
export async function refreshSingleModelMetadata(filePath, modelType = 'lora') {
try {
state.loadingManager.showSimpleLoading('Refreshing metadata...');
const endpoint = modelType === 'checkpoint'
? '/api/checkpoints/fetch-civitai'
: '/api/fetch-civitai';
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ file_path: filePath })
});
if (!response.ok) {
throw new Error('Failed to refresh metadata');
}
const data = await response.json();
if (data.success) {
showToast('Metadata refreshed successfully', 'success');
return true;
} else {
throw new Error(data.error || 'Failed to refresh metadata');
}
} catch (error) {
console.error('Error refreshing metadata:', error);
showToast(error.message, 'error');
return false;
} finally {
state.loadingManager.hide();
state.loadingManager.restoreProgressBar();
}
}
// Private methods
// Upload a preview image
async function uploadPreview(filePath, file, modelType = 'lora') {
const loadingOverlay = document.getElementById('loading-overlay');
const loadingStatus = document.querySelector('.loading-status');
try {
if (loadingOverlay) loadingOverlay.style.display = 'flex';
if (loadingStatus) loadingStatus.textContent = 'Uploading preview...';
const formData = new FormData();
// Use appropriate parameter names and endpoint based on model type
// Prepare common form data
formData.append('preview_file', file);
formData.append('model_path', filePath);
// Set endpoint based on model type
const endpoint = modelType === 'checkpoint'
? '/api/checkpoints/replace-preview'
: '/api/replace_preview';
const response = await fetch(endpoint, {
method: 'POST',
body: formData
});
if (!response.ok) {
throw new Error('Upload failed');
}
const data = await response.json();
// Update the card preview in UI
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (card) {
const previewContainer = card.querySelector('.card-preview');
const oldPreview = previewContainer.querySelector('img, video');
// For LoRA models, use timestamp to prevent caching
if (modelType === 'lora') {
state.previewVersions?.set(filePath, Date.now());
}
const timestamp = Date.now();
const previewUrl = data.preview_url ?
`${data.preview_url}?t=${timestamp}` :
`/api/model/preview_image?path=${encodeURIComponent(filePath)}&t=${timestamp}`;
// Create appropriate element based on file type
if (file.type.startsWith('video/')) {
const video = document.createElement('video');
video.controls = true;
video.autoplay = true;
video.muted = true;
video.loop = true;
video.src = previewUrl;
oldPreview.replaceWith(video);
} else {
const img = document.createElement('img');
img.src = previewUrl;
oldPreview.replaceWith(img);
}
showToast('Preview updated successfully', 'success');
}
} catch (error) {
console.error('Error uploading preview:', error);
showToast('Failed to upload preview image', 'error');
} finally {
if (loadingOverlay) loadingOverlay.style.display = 'none';
}
}
// Private function to perform the delete operation
async function performDelete(filePath, modelType = 'lora') {
try {
showToast(`Deleting ${modelType}...`, 'info');
const response = await fetch('/api/model/delete', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
file_path: filePath,
model_type: modelType
})
});
if (!response.ok) {
throw new Error(`Failed to delete ${modelType}: ${response.status} ${response.statusText}`);
}
const data = await response.json();
if (data.success) {
// Remove the card from UI
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (card) {
card.remove();
}
showToast(`${modelType} deleted successfully`, 'success');
} else {
throw new Error(data.error || `Failed to delete ${modelType}`);
}
} catch (error) {
console.error(`Error deleting ${modelType}:`, error);
showToast(`Failed to delete ${modelType}: ${error.message}`, 'error');
}
}

View File

@@ -0,0 +1,57 @@
import { createCheckpointCard } from '../components/CheckpointCard.js';
import {
loadMoreModels,
resetAndReload as baseResetAndReload,
refreshModels as baseRefreshModels,
deleteModel as baseDeleteModel,
replaceModelPreview,
fetchCivitaiMetadata
} from './baseModelApi.js';
// Load more checkpoints with pagination
export async function loadMoreCheckpoints(resetPagination = true) {
return loadMoreModels({
resetPage: resetPagination,
updateFolders: true,
modelType: 'checkpoint',
createCardFunction: createCheckpointCard,
endpoint: '/api/checkpoints'
});
}
// Reset and reload checkpoints
export async function resetAndReload() {
return baseResetAndReload({
updateFolders: true,
modelType: 'checkpoint',
loadMoreFunction: loadMoreCheckpoints
});
}
// Refresh checkpoints
export async function refreshCheckpoints() {
return baseRefreshModels({
modelType: 'checkpoint',
scanEndpoint: '/api/checkpoints/scan',
resetAndReloadFunction: resetAndReload
});
}
// Delete a checkpoint
export function deleteCheckpoint(filePath) {
return baseDeleteModel(filePath, 'checkpoint');
}
// Replace checkpoint preview
export function replaceCheckpointPreview(filePath) {
return replaceModelPreview(filePath, 'checkpoint');
}
// Fetch metadata from Civitai for checkpoints
export async function fetchCivitai() {
return fetchCivitaiMetadata({
modelType: 'checkpoint',
fetchEndpoint: '/api/checkpoints/fetch-all-civitai',
resetAndReloadFunction: resetAndReload
});
}

View File

@@ -1,269 +1,38 @@
import { state, getCurrentPageState } from '../state/index.js';
import { showToast } from '../utils/uiHelpers.js';
import { createLoraCard } from '../components/LoraCard.js';
import { initializeInfiniteScroll } from '../utils/infiniteScroll.js';
import { showDeleteModal } from '../utils/modalUtils.js';
import { toggleFolder } from '../utils/uiHelpers.js';
import {
loadMoreModels,
resetAndReload as baseResetAndReload,
refreshModels as baseRefreshModels,
deleteModel as baseDeleteModel,
replaceModelPreview,
fetchCivitaiMetadata,
refreshSingleModelMetadata
} from './baseModelApi.js';
export async function loadMoreLoras(resetPage = false, updateFolders = false) {
const pageState = getCurrentPageState();
if (pageState.isLoading || (!pageState.hasMore && !resetPage)) return;
pageState.isLoading = true;
try {
// Reset to first page if requested
if (resetPage) {
pageState.currentPage = 1;
// Clear grid if resetting
const grid = document.getElementById('loraGrid');
if (grid) grid.innerHTML = '';
initializeInfiniteScroll();
}
const params = new URLSearchParams({
page: pageState.currentPage,
page_size: 20,
sort_by: pageState.sortBy
});
if (pageState.activeFolder !== null) {
params.append('folder', pageState.activeFolder);
}
// Add search parameters if there's a search term
if (pageState.filters?.search) {
params.append('search', pageState.filters.search);
params.append('fuzzy', 'true');
// Add search option parameters if available
if (pageState.searchOptions) {
params.append('search_filename', pageState.searchOptions.filename.toString());
params.append('search_modelname', pageState.searchOptions.modelname.toString());
params.append('search_tags', (pageState.searchOptions.tags || false).toString());
params.append('recursive', (pageState.searchOptions?.recursive ?? false).toString());
}
}
// Add filter parameters if active
if (pageState.filters) {
if (pageState.filters.tags && pageState.filters.tags.length > 0) {
// Convert the array of tags to a comma-separated string
params.append('tags', pageState.filters.tags.join(','));
}
if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) {
// Convert the array of base models to a comma-separated string
params.append('base_models', pageState.filters.baseModel.join(','));
}
}
const response = await fetch(`/api/loras?${params}`);
if (!response.ok) {
throw new Error(`Failed to fetch loras: ${response.statusText}`);
}
const data = await response.json();
if (data.items.length === 0 && pageState.currentPage === 1) {
const grid = document.getElementById('loraGrid');
grid.innerHTML = '<div class="no-results">No loras found in this folder</div>';
pageState.hasMore = false;
} else if (data.items.length > 0) {
pageState.hasMore = pageState.currentPage < data.total_pages;
pageState.currentPage++;
appendLoraCards(data.items);
const sentinel = document.getElementById('scroll-sentinel');
if (sentinel && state.observer) {
state.observer.observe(sentinel);
}
} else {
pageState.hasMore = false;
}
if (updateFolders && data.folders) {
updateFolderTags(data.folders);
}
} catch (error) {
console.error('Error loading loras:', error);
showToast('Failed to load loras: ' + error.message, 'error');
} finally {
pageState.isLoading = false;
}
}
function updateFolderTags(folders) {
const folderTagsContainer = document.querySelector('.folder-tags');
if (!folderTagsContainer) return;
// Keep track of currently selected folder
const pageState = getCurrentPageState();
const currentFolder = pageState.activeFolder;
// Create HTML for folder tags
const tagsHTML = folders.map(folder => {
const isActive = folder === currentFolder;
return `<div class="tag ${isActive ? 'active' : ''}" data-folder="${folder}">${folder}</div>`;
}).join('');
// Update the container
folderTagsContainer.innerHTML = tagsHTML;
// Reattach click handlers and ensure the active tag is visible
const tags = folderTagsContainer.querySelectorAll('.tag');
tags.forEach(tag => {
tag.addEventListener('click', toggleFolder);
if (tag.dataset.folder === currentFolder) {
tag.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}
return loadMoreModels({
resetPage,
updateFolders,
modelType: 'lora',
createCardFunction: createLoraCard,
endpoint: '/api/loras'
});
}
export async function fetchCivitai() {
let ws = null;
await state.loadingManager.showWithProgress(async (loading) => {
try {
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
const operationComplete = new Promise((resolve, reject) => {
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
switch(data.status) {
case 'started':
loading.setStatus('Starting metadata fetch...');
break;
case 'processing':
const percent = ((data.processed / data.total) * 100).toFixed(1);
loading.setProgress(percent);
loading.setStatus(
`Processing (${data.processed}/${data.total}) ${data.current_name}`
);
break;
case 'completed':
loading.setProgress(100);
loading.setStatus(
`Completed: Updated ${data.success} of ${data.processed} loras`
);
resolve();
break;
case 'error':
reject(new Error(data.error));
break;
}
};
ws.onerror = (error) => {
reject(new Error('WebSocket error: ' + error.message));
};
});
await new Promise((resolve, reject) => {
ws.onopen = resolve;
ws.onerror = reject;
});
const response = await fetch('/api/fetch-all-civitai', {
method: 'POST',
headers: { 'Content-Type': 'application/json' }
});
if (!response.ok) {
throw new Error('Failed to fetch metadata');
}
await operationComplete;
await resetAndReload();
} catch (error) {
console.error('Error fetching metadata:', error);
showToast('Failed to fetch metadata: ' + error.message, 'error');
} finally {
if (ws) {
ws.close();
}
}
}, {
initialMessage: 'Connecting...',
completionMessage: 'Metadata update complete'
return fetchCivitaiMetadata({
modelType: 'lora',
fetchEndpoint: '/api/fetch-all-civitai',
resetAndReloadFunction: resetAndReload
});
}
export async function deleteModel(filePath) {
showDeleteModal(filePath);
return baseDeleteModel(filePath, 'lora');
}
export async function replacePreview(filePath) {
const loadingOverlay = document.getElementById('loading-overlay');
const loadingStatus = document.querySelector('.loading-status');
const input = document.createElement('input');
input.type = 'file';
input.accept = 'image/*,video/mp4';
input.onchange = async function() {
if (!input.files || !input.files[0]) return;
const file = input.files[0];
const formData = new FormData();
formData.append('preview_file', file);
formData.append('model_path', filePath);
try {
loadingOverlay.style.display = 'flex';
loadingStatus.textContent = 'Uploading preview...';
const response = await fetch('/api/replace_preview', {
method: 'POST',
body: formData
});
if (!response.ok) {
throw new Error('Upload failed');
}
const data = await response.json();
// 更新预览版本
state.previewVersions.set(filePath, Date.now());
// 更新卡片显示
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
const previewContainer = card.querySelector('.card-preview');
const oldPreview = previewContainer.querySelector('img, video');
const previewUrl = `${data.preview_url}?t=${state.previewVersions.get(filePath)}`;
if (file.type.startsWith('video/')) {
const video = document.createElement('video');
video.controls = true;
video.autoplay = true;
video.muted = true;
video.loop = true;
video.src = previewUrl;
oldPreview.replaceWith(video);
} else {
const img = document.createElement('img');
img.src = previewUrl;
oldPreview.replaceWith(img);
}
} catch (error) {
console.error('Error uploading preview:', error);
alert('Failed to upload preview image');
} finally {
loadingOverlay.style.display = 'none';
}
};
input.click();
return replaceModelPreview(filePath, 'lora');
}
export function appendLoraCards(loras) {
@@ -277,60 +46,26 @@ export function appendLoraCards(loras) {
}
export async function resetAndReload(updateFolders = false) {
const pageState = getCurrentPageState();
console.log('Resetting with state:', { ...pageState });
// Initialize infinite scroll - will reset the observer
initializeInfiniteScroll();
// Load more loras with reset flag
await loadMoreLoras(true, updateFolders);
return baseResetAndReload({
updateFolders,
modelType: 'lora',
loadMoreFunction: loadMoreLoras
});
}
export async function refreshLoras() {
try {
state.loadingManager.showSimpleLoading('Refreshing loras...');
await resetAndReload();
showToast('Refresh complete', 'success');
} catch (error) {
console.error('Refresh failed:', error);
showToast('Failed to refresh loras', 'error');
} finally {
state.loadingManager.hide();
state.loadingManager.restoreProgressBar();
}
return baseRefreshModels({
modelType: 'lora',
scanEndpoint: '/api/loras/scan',
resetAndReloadFunction: resetAndReload
});
}
export async function refreshSingleLoraMetadata(filePath) {
try {
state.loadingManager.showSimpleLoading('Refreshing metadata...');
const response = await fetch('/api/fetch-civitai', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ file_path: filePath })
});
if (!response.ok) {
throw new Error('Failed to refresh metadata');
}
const data = await response.json();
if (data.success) {
showToast('Metadata refreshed successfully', 'success');
// Reload the current view to show updated data
await resetAndReload();
} else {
throw new Error(data.error || 'Failed to refresh metadata');
}
} catch (error) {
console.error('Error refreshing metadata:', error);
showToast(error.message, 'error');
} finally {
state.loadingManager.hide();
state.loadingManager.restoreProgressBar();
const success = await refreshSingleModelMetadata(filePath, 'lora');
if (success) {
// Reload the current view to show updated data
await resetAndReload();
}
}

View File

@@ -1,36 +1,55 @@
import { appCore } from './core.js';
import { state, initPageState } from './state/index.js';
import { initializeInfiniteScroll } from './utils/infiniteScroll.js';
import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js';
import { createPageControls } from './components/controls/index.js';
import { loadMoreCheckpoints } from './api/checkpointApi.js';
import { CheckpointDownloadManager } from './managers/CheckpointDownloadManager.js';
// Initialize the Checkpoints page
class CheckpointsPageManager {
constructor() {
// Initialize any necessary state
this.initialized = false;
// Initialize page controls
this.pageControls = createPageControls('checkpoints');
// Initialize checkpoint download manager
window.checkpointDownloadManager = new CheckpointDownloadManager();
// Expose only necessary functions to global scope
this._exposeRequiredGlobalFunctions();
}
_exposeRequiredGlobalFunctions() {
// Minimal set of functions that need to remain global
window.confirmDelete = confirmDelete;
window.closeDeleteModal = closeDeleteModal;
// Add loadCheckpoints function to window for FilterManager compatibility
window.checkpointManager = {
loadCheckpoints: (reset) => loadMoreCheckpoints(reset)
};
}
async initialize() {
if (this.initialized) return;
// Initialize page state
initPageState('checkpoints');
// Initialize core application
await appCore.initialize();
// Initialize page-specific components
this._initializeWorkInProgress();
this.pageControls.restoreFolderFilter();
this.pageControls.initFolderTagsVisibility();
this.initialized = true;
}
_initializeWorkInProgress() {
// Add any work-in-progress specific initialization here
console.log('Checkpoints Manager is under development');
// Initialize infinite scroll
initializeInfiniteScroll('checkpoints');
// Initialize common page features
appCore.initializePageFeatures();
console.log('Checkpoints Manager initialized');
}
}
// Initialize everything when DOM is ready
document.addEventListener('DOMContentLoaded', async () => {
// Initialize core application
await appCore.initialize();
// Initialize checkpoints page
const checkpointsPage = new CheckpointsPageManager();
await checkpointsPage.initialize();
});

View File

@@ -0,0 +1,313 @@
import { showToast } from '../utils/uiHelpers.js';
import { state } from '../state/index.js';
import { showCheckpointModal } from './checkpointModal/index.js';
import { NSFW_LEVELS } from '../utils/constants.js';
import { replaceCheckpointPreview as apiReplaceCheckpointPreview } from '../api/checkpointApi.js';
export function createCheckpointCard(checkpoint) {
const card = document.createElement('div');
card.className = 'lora-card'; // Reuse the same class for styling
card.dataset.sha256 = checkpoint.sha256;
card.dataset.filepath = checkpoint.file_path;
card.dataset.name = checkpoint.model_name;
card.dataset.file_name = checkpoint.file_name;
card.dataset.folder = checkpoint.folder;
card.dataset.modified = checkpoint.modified;
card.dataset.file_size = checkpoint.file_size;
card.dataset.from_civitai = checkpoint.from_civitai;
card.dataset.notes = checkpoint.notes || '';
card.dataset.base_model = checkpoint.base_model || 'Unknown';
// Store metadata if available
if (checkpoint.civitai) {
card.dataset.meta = JSON.stringify(checkpoint.civitai || {});
}
// Store tags if available
if (checkpoint.tags && Array.isArray(checkpoint.tags)) {
card.dataset.tags = JSON.stringify(checkpoint.tags);
}
if (checkpoint.modelDescription) {
card.dataset.modelDescription = checkpoint.modelDescription;
}
// Store NSFW level if available
const nsfwLevel = checkpoint.preview_nsfw_level !== undefined ? checkpoint.preview_nsfw_level : 0;
card.dataset.nsfwLevel = nsfwLevel;
// Determine if the preview should be blurred based on NSFW level and user settings
const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13;
if (shouldBlur) {
card.classList.add('nsfw-content');
}
// Determine preview URL
const previewUrl = checkpoint.preview_url || '/loras_static/images/no-preview.png';
const version = state.previewVersions ? state.previewVersions.get(checkpoint.file_path) : null;
const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl;
// Determine NSFW warning text based on level
let nsfwText = "Mature Content";
if (nsfwLevel >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (nsfwLevel >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (nsfwLevel >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
// Check if autoplayOnHover is enabled for video previews
const autoplayOnHover = state.global?.settings?.autoplayOnHover || false;
const isVideo = previewUrl.endsWith('.mp4');
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
card.innerHTML = `
<div class="card-preview ${shouldBlur ? 'blurred' : ''}">
${isVideo ?
`<video ${videoAttrs}>
<source src="${versionedPreviewUrl}" type="video/mp4">
</video>` :
`<img src="${versionedPreviewUrl}" alt="${checkpoint.model_name}">`
}
<div class="card-header">
${shouldBlur ?
`<button class="toggle-blur-btn" title="Toggle blur">
<i class="fas fa-eye"></i>
</button>` : ''}
<span class="base-model-label ${shouldBlur ? 'with-toggle' : ''}" title="${checkpoint.base_model}">
${checkpoint.base_model}
</span>
<div class="card-actions">
<i class="fas fa-globe"
title="${checkpoint.from_civitai ? 'View on Civitai' : 'Not available from Civitai'}"
${!checkpoint.from_civitai ? 'style="opacity: 0.5; cursor: not-allowed"' : ''}>
</i>
<i class="fas fa-copy"
title="Copy Checkpoint Name">
</i>
<i class="fas fa-trash"
title="Delete Model">
</i>
</div>
</div>
${shouldBlur ? `
<div class="nsfw-overlay">
<div class="nsfw-warning">
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
</div>
</div>
` : ''}
<div class="card-footer">
<div class="model-info">
<span class="model-name">${checkpoint.model_name}</span>
</div>
<div class="card-actions">
<i class="fas fa-image"
title="Replace Preview Image">
</i>
</div>
</div>
</div>
`;
// Main card click event
card.addEventListener('click', () => {
// Show checkpoint details modal
const checkpointMeta = {
sha256: card.dataset.sha256,
file_path: card.dataset.filepath,
model_name: card.dataset.name,
file_name: card.dataset.file_name,
folder: card.dataset.folder,
modified: card.dataset.modified,
file_size: parseInt(card.dataset.file_size || '0'),
from_civitai: card.dataset.from_civitai === 'true',
base_model: card.dataset.base_model,
notes: card.dataset.notes || '',
preview_url: versionedPreviewUrl,
// Parse civitai metadata from the card's dataset
civitai: (() => {
try {
return JSON.parse(card.dataset.meta || '{}');
} catch (e) {
console.error('Failed to parse civitai metadata:', e);
return {}; // Return empty object on error
}
})(),
tags: (() => {
try {
return JSON.parse(card.dataset.tags || '[]');
} catch (e) {
console.error('Failed to parse tags:', e);
return []; // Return empty array on error
}
})(),
modelDescription: card.dataset.modelDescription || ''
};
showCheckpointModal(checkpointMeta);
});
// Toggle blur button functionality
const toggleBlurBtn = card.querySelector('.toggle-blur-btn');
if (toggleBlurBtn) {
toggleBlurBtn.addEventListener('click', (e) => {
e.stopPropagation();
const preview = card.querySelector('.card-preview');
const isBlurred = preview.classList.toggle('blurred');
const icon = toggleBlurBtn.querySelector('i');
// Update the icon based on blur state
if (isBlurred) {
icon.className = 'fas fa-eye';
} else {
icon.className = 'fas fa-eye-slash';
}
// Toggle the overlay visibility
const overlay = card.querySelector('.nsfw-overlay');
if (overlay) {
overlay.style.display = isBlurred ? 'flex' : 'none';
}
});
}
// Show content button functionality
const showContentBtn = card.querySelector('.show-content-btn');
if (showContentBtn) {
showContentBtn.addEventListener('click', (e) => {
e.stopPropagation();
const preview = card.querySelector('.card-preview');
preview.classList.remove('blurred');
// Update the toggle button icon
const toggleBtn = card.querySelector('.toggle-blur-btn');
if (toggleBtn) {
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
}
// Hide the overlay
const overlay = card.querySelector('.nsfw-overlay');
if (overlay) {
overlay.style.display = 'none';
}
});
}
// Copy button click event
card.querySelector('.fa-copy')?.addEventListener('click', async e => {
e.stopPropagation();
const checkpointName = card.dataset.file_name;
try {
// Modern clipboard API
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(checkpointName);
} else {
// Fallback for older browsers
const textarea = document.createElement('textarea');
textarea.value = checkpointName;
textarea.style.position = 'absolute';
textarea.style.left = '-99999px';
document.body.appendChild(textarea);
textarea.select();
document.execCommand('copy');
document.body.removeChild(textarea);
}
showToast('Checkpoint name copied', 'success');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
});
// Civitai button click event
if (checkpoint.from_civitai) {
card.querySelector('.fa-globe')?.addEventListener('click', e => {
e.stopPropagation();
openCivitai(checkpoint.model_name);
});
}
// Delete button click event
card.querySelector('.fa-trash')?.addEventListener('click', e => {
e.stopPropagation();
deleteCheckpoint(checkpoint.file_path);
});
// Replace preview button click event
card.querySelector('.fa-image')?.addEventListener('click', e => {
e.stopPropagation();
replaceCheckpointPreview(checkpoint.file_path);
});
// Add autoplayOnHover handlers for video elements if needed
const videoElement = card.querySelector('video');
if (videoElement && autoplayOnHover) {
const cardPreview = card.querySelector('.card-preview');
// Remove autoplay attribute and pause initially
videoElement.removeAttribute('autoplay');
videoElement.pause();
// Add mouse events to trigger play/pause
cardPreview.addEventListener('mouseenter', () => {
videoElement.play();
});
cardPreview.addEventListener('mouseleave', () => {
videoElement.pause();
videoElement.currentTime = 0;
});
}
return card;
}
// These functions will be implemented in checkpointApi.js
function openCivitai(modelName) {
// Check if the global function exists (registered by PageControls)
if (window.openCivitai) {
window.openCivitai(modelName);
} else {
// Fallback implementation
const card = document.querySelector(`.lora-card[data-name="${modelName}"]`);
if (!card) return;
const metaData = JSON.parse(card.dataset.meta || '{}');
const civitaiId = metaData.modelId;
const versionId = metaData.id;
// Build URL
if (civitaiId) {
let url = `https://civitai.com/models/${civitaiId}`;
if (versionId) {
url += `?modelVersionId=${versionId}`;
}
window.open(url, '_blank');
} else {
// If no ID, try searching by name
window.open(`https://civitai.com/models?query=${encodeURIComponent(modelName)}`, '_blank');
}
}
}
function deleteCheckpoint(filePath) {
if (window.deleteCheckpoint) {
window.deleteCheckpoint(filePath);
} else {
// Use the modal delete functionality
import('../utils/modalUtils.js').then(({ showDeleteModal }) => {
showDeleteModal(filePath, 'checkpoint');
});
}
}
function replaceCheckpointPreview(filePath) {
if (window.replaceCheckpointPreview) {
window.replaceCheckpointPreview(filePath);
} else {
apiReplaceCheckpointPreview(filePath);
}
}

View File

@@ -130,7 +130,7 @@ export class LoraContextMenu {
}
async saveModelMetadata(filePath, data) {
const response = await fetch('/loras/api/save-metadata', {
const response = await fetch('/api/loras/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',

View File

@@ -1,8 +1,9 @@
import { showToast } from '../utils/uiHelpers.js';
import { showToast, openCivitai } from '../utils/uiHelpers.js';
import { state } from '../state/index.js';
import { showLoraModal } from './LoraModal.js';
import { showLoraModal } from './loraModal/index.js';
import { bulkManager } from '../managers/BulkManager.js';
import { NSFW_LEVELS } from '../utils/constants.js';
import { replacePreview, deleteModel } from '../api/loraApi.js'
export function createLoraCard(lora) {
const card = document.createElement('div');
@@ -57,10 +58,15 @@ export function createLoraCard(lora) {
nsfwText = "R-rated Content";
}
// Check if autoplayOnHover is enabled for video previews
const autoplayOnHover = state.global.settings.autoplayOnHover || false;
const isVideo = previewUrl.endsWith('.mp4');
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
card.innerHTML = `
<div class="card-preview ${shouldBlur ? 'blurred' : ''}">
${previewUrl.endsWith('.mp4') ?
`<video controls autoplay muted loop>
${isVideo ?
`<video ${videoAttrs}>
<source src="${versionedPreviewUrl}" type="video/mp4">
</video>` :
`<img src="${versionedPreviewUrl}" alt="${lora.model_name}">`
@@ -246,6 +252,26 @@ export function createLoraCard(lora) {
actionGroup.style.display = 'none';
});
}
// Add autoplayOnHover handlers for video elements if needed
const videoElement = card.querySelector('video');
if (videoElement && autoplayOnHover) {
const cardPreview = card.querySelector('.card-preview');
// Remove autoplay attribute and pause initially
videoElement.removeAttribute('autoplay');
videoElement.pause();
// Add mouse events to trigger play/pause
cardPreview.addEventListener('mouseenter', () => {
videoElement.play();
});
cardPreview.addEventListener('mouseleave', () => {
videoElement.pause();
videoElement.currentTime = 0;
});
}
return card;
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,7 @@
// Recipe Modal Component
import { showToast } from '../utils/uiHelpers.js';
import { state } from '../state/index.js';
import { setSessionItem, removeSessionItem } from '../utils/storageHelpers.js';
class RecipeModal {
constructor() {
@@ -294,7 +295,7 @@ class RecipeModal {
} else {
// No generation parameters available
if (promptElement) promptElement.textContent = 'No prompt information available';
if (negativePromptElement) negativePromptElement.textContent = 'No negative prompt information available';
if (negativePromptElement) promptElement.textContent = 'No negative prompt information available';
if (otherParamsElement) otherParamsElement.innerHTML = '<div class="no-params">No parameters available</div>';
}
@@ -342,8 +343,15 @@ class RecipeModal {
lorasCountElement.innerHTML = `<i class="fas fa-layer-group"></i> ${totalCount} LoRAs ${statusHTML}`;
// Add click handler for missing LoRAs status
// Add event listeners for buttons and status indicators
setTimeout(() => {
// Set up click handler for View LoRAs button
const viewRecipeLorasBtn = document.getElementById('viewRecipeLorasBtn');
if (viewRecipeLorasBtn) {
viewRecipeLorasBtn.addEventListener('click', () => this.navigateToLorasPage());
}
// Add click handler for missing LoRAs status
const missingStatus = document.querySelector('.recipe-status.missing');
if (missingStatus && missingLorasCount > 0) {
missingStatus.classList.add('clickable');
@@ -433,6 +441,7 @@ class RecipeModal {
// Add event listeners for reconnect functionality
setTimeout(() => {
this.setupReconnectButtons();
this.setupLoraItemsClickable();
}, 100);
// Generate recipe syntax for copy button (this is now a placeholder, actual syntax will be fetched from the API)
@@ -1007,6 +1016,65 @@ class RecipeModal {
state.loadingManager.hide();
}
}
// New method to navigate to the LoRAs page
navigateToLorasPage(specificLoraIndex = null) {
// Close the current modal
modalManager.closeModal('recipeModal');
// Clear any previous filters first
removeSessionItem('recipe_to_lora_filterLoraHash');
removeSessionItem('recipe_to_lora_filterLoraHashes');
removeSessionItem('filterRecipeName');
removeSessionItem('viewLoraDetail');
if (specificLoraIndex !== null) {
// If a specific LoRA index is provided, navigate to view just that one LoRA
const lora = this.currentRecipe.loras[specificLoraIndex];
if (lora && lora.hash) {
// Set session storage to open the LoRA modal directly
setSessionItem('recipe_to_lora_filterLoraHash', lora.hash.toLowerCase());
setSessionItem('viewLoraDetail', 'true');
setSessionItem('filterRecipeName', this.currentRecipe.title);
}
} else {
// If no specific LoRA index is provided, show all LoRAs from this recipe
// Collect all hashes from the recipe's LoRAs
const loraHashes = this.currentRecipe.loras
.filter(lora => lora.hash)
.map(lora => lora.hash.toLowerCase());
if (loraHashes.length > 0) {
// Store the LoRA hashes and recipe name in sessionStorage
setSessionItem('recipe_to_lora_filterLoraHashes', JSON.stringify(loraHashes));
setSessionItem('filterRecipeName', this.currentRecipe.title);
}
}
// Navigate to the LoRAs page
window.location.href = '/loras';
}
// New method to make LoRA items clickable
setupLoraItemsClickable() {
const loraItems = document.querySelectorAll('.recipe-lora-item');
loraItems.forEach(item => {
// Get the lora index from the data attribute
const loraIndex = parseInt(item.dataset.loraIndex);
item.addEventListener('click', (e) => {
// If the click is on the reconnect container or badge, don't navigate
if (e.target.closest('.lora-reconnect-container') ||
e.target.closest('.deleted-badge') ||
e.target.closest('.reconnect-tooltip')) {
return;
}
// Navigate to the LoRAs page with the specific LoRA index
this.navigateToLorasPage(loraIndex);
});
});
}
}
export { RecipeModal };
export { RecipeModal };

View File

@@ -0,0 +1,102 @@
/**
* ModelDescription.js
* Handles checkpoint model descriptions
*/
import { showToast } from '../../utils/uiHelpers.js';
/**
* Set up tab switching functionality
*/
export function setupTabSwitching() {
const tabButtons = document.querySelectorAll('.showcase-tabs .tab-btn');
tabButtons.forEach(button => {
button.addEventListener('click', () => {
// Remove active class from all tabs
document.querySelectorAll('.showcase-tabs .tab-btn').forEach(btn =>
btn.classList.remove('active')
);
document.querySelectorAll('.tab-content .tab-pane').forEach(tab =>
tab.classList.remove('active')
);
// Add active class to clicked tab
button.classList.add('active');
const tabId = `${button.dataset.tab}-tab`;
document.getElementById(tabId).classList.add('active');
// If switching to description tab, make sure content is properly loaded and displayed
if (button.dataset.tab === 'description') {
const descriptionContent = document.querySelector('.model-description-content');
if (descriptionContent) {
const hasContent = descriptionContent.innerHTML.trim() !== '';
document.querySelector('.model-description-loading')?.classList.add('hidden');
// If no content, show a message
if (!hasContent) {
descriptionContent.innerHTML = '<div class="no-description">No model description available</div>';
descriptionContent.classList.remove('hidden');
}
}
}
});
});
}
/**
* Load model description from API
* @param {string} modelId - The Civitai model ID
* @param {string} filePath - File path for the model
*/
export async function loadModelDescription(modelId, filePath) {
try {
const descriptionContainer = document.querySelector('.model-description-content');
const loadingElement = document.querySelector('.model-description-loading');
if (!descriptionContainer || !loadingElement) return;
// Show loading indicator
loadingElement.classList.remove('hidden');
descriptionContainer.classList.add('hidden');
// Try to get model description from API
const response = await fetch(`/api/checkpoint-model-description?model_id=${modelId}&file_path=${encodeURIComponent(filePath)}`);
if (!response.ok) {
throw new Error(`Failed to fetch model description: ${response.statusText}`);
}
const data = await response.json();
if (data.success && data.description) {
// Update the description content
descriptionContainer.innerHTML = data.description;
// Process any links in the description to open in new tab
const links = descriptionContainer.querySelectorAll('a');
links.forEach(link => {
link.setAttribute('target', '_blank');
link.setAttribute('rel', 'noopener noreferrer');
});
// Show the description and hide loading indicator
descriptionContainer.classList.remove('hidden');
loadingElement.classList.add('hidden');
} else {
throw new Error(data.error || 'No description available');
}
} catch (error) {
console.error('Error loading model description:', error);
const loadingElement = document.querySelector('.model-description-loading');
if (loadingElement) {
loadingElement.innerHTML = `<div class="error-message">Failed to load model description. ${error.message}</div>`;
}
// Show empty state message in the description container
const descriptionContainer = document.querySelector('.model-description-content');
if (descriptionContainer) {
descriptionContainer.innerHTML = '<div class="no-description">No model description available</div>';
descriptionContainer.classList.remove('hidden');
}
}
}

View File

@@ -0,0 +1,484 @@
/**
* ModelMetadata.js
* Handles checkpoint model metadata editing functionality
*/
import { showToast } from '../../utils/uiHelpers.js';
import { BASE_MODELS } from '../../utils/constants.js';
import { updateCheckpointCard } from '../../utils/cardUpdater.js';
/**
* Save model metadata to the server
* @param {string} filePath - Path to the model file
* @param {Object} data - Metadata to save
* @returns {Promise} - Promise that resolves with the server response
*/
export async function saveModelMetadata(filePath, data) {
const response = await fetch('/api/checkpoints/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return response.json();
}
/**
* Set up model name editing functionality
* @param {string} filePath - The full file path of the model.
*/
export function setupModelNameEditing(filePath) {
const modelNameContent = document.querySelector('.model-name-content');
const editBtn = document.querySelector('.edit-model-name-btn');
if (!modelNameContent || !editBtn) return;
// Show edit button on hover
const modelNameHeader = document.querySelector('.model-name-header');
modelNameHeader.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
modelNameHeader.addEventListener('mouseleave', () => {
if (!modelNameContent.getAttribute('data-editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
modelNameContent.setAttribute('data-editing', 'true');
modelNameContent.focus();
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
if (modelNameContent.childNodes.length > 0) {
range.setStart(modelNameContent.childNodes[0], modelNameContent.textContent.length);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
}
editBtn.classList.add('visible');
});
// Handle focus out
modelNameContent.addEventListener('blur', function() {
this.removeAttribute('data-editing');
editBtn.classList.remove('visible');
if (this.textContent.trim() === '') {
// Restore original model name if empty
// Use the passed filePath to find the card
const checkpointCard = document.querySelector(`.checkpoint-card[data-filepath="${filePath}"]`);
if (checkpointCard) {
this.textContent = checkpointCard.dataset.model_name;
}
}
});
// Handle enter key
modelNameContent.addEventListener('keydown', function(e) {
if (e.key === 'Enter') {
e.preventDefault();
// Use the passed filePath
saveModelName(filePath);
this.blur();
}
});
// Limit model name length
modelNameContent.addEventListener('input', function() {
if (this.textContent.length > 100) {
this.textContent = this.textContent.substring(0, 100);
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
range.setStart(this.childNodes[0], 100);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
showToast('Model name is limited to 100 characters', 'warning');
}
});
}
/**
* Save model name
* @param {string} filePath - File path
*/
async function saveModelName(filePath) {
const modelNameElement = document.querySelector('.model-name-content');
const newModelName = modelNameElement.textContent.trim();
// Validate model name
if (!newModelName) {
showToast('Model name cannot be empty', 'error');
return;
}
// Check if model name is too long
if (newModelName.length > 100) {
showToast('Model name is too long (maximum 100 characters)', 'error');
// Truncate the displayed text
modelNameElement.textContent = newModelName.substring(0, 100);
return;
}
try {
await saveModelMetadata(filePath, { model_name: newModelName });
// Update the card with the new model name
updateCheckpointCard(filePath, { name: newModelName });
showToast('Model name updated successfully', 'success');
// No need to reload the entire page
// setTimeout(() => {
// window.location.reload();
// }, 1500);
} catch (error) {
showToast('Failed to update model name', 'error');
}
}
/**
* Set up base model editing functionality
* @param {string} filePath - The full file path of the model.
*/
export function setupBaseModelEditing(filePath) {
const baseModelContent = document.querySelector('.base-model-content');
const editBtn = document.querySelector('.edit-base-model-btn');
if (!baseModelContent || !editBtn) return;
// Show edit button on hover
const baseModelDisplay = document.querySelector('.base-model-display');
baseModelDisplay.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
baseModelDisplay.addEventListener('mouseleave', () => {
if (!baseModelDisplay.classList.contains('editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
baseModelDisplay.classList.add('editing');
// Store the original value to check for changes later
const originalValue = baseModelContent.textContent.trim();
// Create dropdown selector to replace the base model content
const currentValue = originalValue;
const dropdown = document.createElement('select');
dropdown.className = 'base-model-selector';
// Flag to track if a change was made
let valueChanged = false;
// Add options from BASE_MODELS constants
const baseModelCategories = {
'Stable Diffusion 1.x': [BASE_MODELS.SD_1_4, BASE_MODELS.SD_1_5, BASE_MODELS.SD_1_5_LCM, BASE_MODELS.SD_1_5_HYPER],
'Stable Diffusion 2.x': [BASE_MODELS.SD_2_0, BASE_MODELS.SD_2_1],
'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
'Video Models': [BASE_MODELS.SVD, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO],
'Other Models': [
BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW,
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.UNKNOWN
]
};
// Create option groups for better organization
Object.entries(baseModelCategories).forEach(([category, models]) => {
const group = document.createElement('optgroup');
group.label = category;
models.forEach(model => {
const option = document.createElement('option');
option.value = model;
option.textContent = model;
option.selected = model === currentValue;
group.appendChild(option);
});
dropdown.appendChild(group);
});
// Replace content with dropdown
baseModelContent.style.display = 'none';
baseModelDisplay.insertBefore(dropdown, editBtn);
// Hide edit button during editing
editBtn.style.display = 'none';
// Focus the dropdown
dropdown.focus();
// Handle dropdown change
dropdown.addEventListener('change', function() {
const selectedModel = this.value;
baseModelContent.textContent = selectedModel;
// Mark that a change was made if the value differs from original
if (selectedModel !== originalValue) {
valueChanged = true;
} else {
valueChanged = false;
}
});
// Function to save changes and exit edit mode
const saveAndExit = function() {
// Check if dropdown still exists and remove it
if (dropdown && dropdown.parentNode === baseModelDisplay) {
baseModelDisplay.removeChild(dropdown);
}
// Show the content and edit button
baseModelContent.style.display = '';
editBtn.style.display = '';
// Remove editing class
baseModelDisplay.classList.remove('editing');
// Only save if the value has actually changed
if (valueChanged || baseModelContent.textContent.trim() !== originalValue) {
// Use the passed filePath for saving
saveBaseModel(filePath, originalValue);
}
// Remove this event listener
document.removeEventListener('click', outsideClickHandler);
};
// Handle outside clicks to save and exit
const outsideClickHandler = function(e) {
// If click is outside the dropdown and base model display
if (!baseModelDisplay.contains(e.target)) {
saveAndExit();
}
};
// Add delayed event listener for outside clicks
setTimeout(() => {
document.addEventListener('click', outsideClickHandler);
}, 0);
// Also handle dropdown blur event
dropdown.addEventListener('blur', function(e) {
// Only save if the related target is not the edit button or inside the baseModelDisplay
if (!baseModelDisplay.contains(e.relatedTarget)) {
saveAndExit();
}
});
});
}
/**
* Save base model
* @param {string} filePath - File path
* @param {string} originalValue - Original value (for comparison)
*/
async function saveBaseModel(filePath, originalValue) {
const baseModelElement = document.querySelector('.base-model-content');
const newBaseModel = baseModelElement.textContent.trim();
// Only save if the value has actually changed
if (newBaseModel === originalValue) {
return; // No change, no need to save
}
try {
await saveModelMetadata(filePath, { base_model: newBaseModel });
// Update the card with the new base model
updateCheckpointCard(filePath, { base_model: newBaseModel });
showToast('Base model updated successfully', 'success');
} catch (error) {
showToast('Failed to update base model', 'error');
}
}
/**
* Set up file name editing functionality
* @param {string} filePath - The full file path of the model.
*/
export function setupFileNameEditing(filePath) {
const fileNameContent = document.querySelector('.file-name-content');
const editBtn = document.querySelector('.edit-file-name-btn');
if (!fileNameContent || !editBtn) return;
// Show edit button on hover
const fileNameWrapper = document.querySelector('.file-name-wrapper');
fileNameWrapper.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
fileNameWrapper.addEventListener('mouseleave', () => {
if (!fileNameWrapper.classList.contains('editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
fileNameWrapper.classList.add('editing');
fileNameContent.setAttribute('contenteditable', 'true');
fileNameContent.focus();
// Store original value for comparison later
fileNameContent.dataset.originalValue = fileNameContent.textContent.trim();
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
range.selectNodeContents(fileNameContent);
range.collapse(false);
sel.removeAllRanges();
sel.addRange(range);
editBtn.classList.add('visible');
});
// Handle keyboard events in edit mode
fileNameContent.addEventListener('keydown', function(e) {
if (!this.getAttribute('contenteditable')) return;
if (e.key === 'Enter') {
e.preventDefault();
this.blur(); // Trigger save on Enter
} else if (e.key === 'Escape') {
e.preventDefault();
// Restore original value
this.textContent = this.dataset.originalValue;
exitEditMode();
}
});
// Handle input validation
fileNameContent.addEventListener('input', function() {
if (!this.getAttribute('contenteditable')) return;
// Replace invalid characters for filenames
const invalidChars = /[\\/:*?"<>|]/g;
if (invalidChars.test(this.textContent)) {
const cursorPos = window.getSelection().getRangeAt(0).startOffset;
this.textContent = this.textContent.replace(invalidChars, '');
// Restore cursor position
const range = document.createRange();
const sel = window.getSelection();
const newPos = Math.min(cursorPos, this.textContent.length);
if (this.firstChild) {
range.setStart(this.firstChild, newPos);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
}
showToast('Invalid characters removed from filename', 'warning');
}
});
// Handle focus out - save changes
fileNameContent.addEventListener('blur', async function() {
if (!this.getAttribute('contenteditable')) return;
const newFileName = this.textContent.trim();
const originalValue = this.dataset.originalValue;
// Basic validation
if (!newFileName) {
// Restore original value if empty
this.textContent = originalValue;
showToast('File name cannot be empty', 'error');
exitEditMode();
return;
}
if (newFileName === originalValue) {
// No changes, just exit edit mode
exitEditMode();
return;
}
try {
// Use the passed filePath (which includes the original filename)
// Call API to rename the file
const response = await fetch('/api/rename_checkpoint', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath, // Use the full original path
new_file_name: newFileName
})
});
const result = await response.json();
if (result.success) {
showToast('File name updated successfully', 'success');
// Get the new file path from the result
const pathParts = filePath.split(/[\\/]/);
pathParts.pop(); // Remove old filename
const newFilePath = [...pathParts, newFileName].join('/');
// Update the checkpoint card with new file path
updateCheckpointCard(filePath, {
filepath: newFilePath,
file_name: newFileName
});
// Update the file name display in the modal
document.querySelector('#file-name').textContent = newFileName;
// Update the modal's data-filepath attribute
const modalContent = document.querySelector('#checkpointModal .modal-content');
if (modalContent) {
modalContent.dataset.filepath = newFilePath;
}
// Reload the page after a short delay to reflect changes
setTimeout(() => {
window.location.reload();
}, 1500);
} else {
throw new Error(result.error || 'Unknown error');
}
} catch (error) {
console.error('Error renaming file:', error);
this.textContent = originalValue; // Restore original file name
showToast(`Failed to rename file: ${error.message}`, 'error');
} finally {
exitEditMode();
}
});
function exitEditMode() {
fileNameContent.removeAttribute('contenteditable');
fileNameWrapper.classList.remove('editing');
editBtn.classList.remove('visible');
}
}

View File

@@ -0,0 +1,489 @@
/**
* ShowcaseView.js
* Handles showcase content (images, videos) display for checkpoint modal
*/
import { showToast } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
/**
* Render showcase content
* @param {Array} images - Array of images/videos to show
* @returns {string} HTML content
*/
export function renderShowcaseContent(images) {
if (!images?.length) return '<div class="no-examples">No example images available</div>';
// Filter images based on SFW setting
const showOnlySFW = state.settings.show_only_sfw;
let filteredImages = images;
let hiddenCount = 0;
if (showOnlySFW) {
filteredImages = images.filter(img => {
const nsfwLevel = img.nsfwLevel !== undefined ? img.nsfwLevel : 0;
const isSfw = nsfwLevel < NSFW_LEVELS.R;
if (!isSfw) hiddenCount++;
return isSfw;
});
}
// Show message if no images are available after filtering
if (filteredImages.length === 0) {
return `
<div class="no-examples">
<p>All example images are filtered due to NSFW content settings</p>
<p class="nsfw-filter-info">Your settings are currently set to show only safe-for-work content</p>
<p>You can change this in Settings <i class="fas fa-cog"></i></p>
</div>
`;
}
// Show hidden content notification if applicable
const hiddenNotification = hiddenCount > 0 ?
`<div class="nsfw-filter-notification">
<i class="fas fa-eye-slash"></i> ${hiddenCount} ${hiddenCount === 1 ? 'image' : 'images'} hidden due to SFW-only setting
</div>` : '';
return `
<div class="scroll-indicator" onclick="toggleShowcase(this)">
<i class="fas fa-chevron-down"></i>
<span>Scroll or click to show ${filteredImages.length} examples</span>
</div>
<div class="carousel collapsed">
${hiddenNotification}
<div class="carousel-container">
${filteredImages.map(img => generateMediaWrapper(img)).join('')}
</div>
</div>
`;
}
/**
* Generate media wrapper HTML for an image or video
* @param {Object} media - Media object with image or video data
* @returns {string} HTML content
*/
function generateMediaWrapper(media) {
// Calculate appropriate aspect ratio:
// 1. Keep original aspect ratio
// 2. Limit maximum height to 60% of viewport height
// 3. Ensure minimum height is 40% of container width
const aspectRatio = (media.height / media.width) * 100;
const containerWidth = 800; // modal content maximum width
const minHeightPercent = 40;
const maxHeightPercent = (window.innerHeight * 0.6 / containerWidth) * 100;
const heightPercent = Math.max(
minHeightPercent,
Math.min(maxHeightPercent, aspectRatio)
);
// Check if media should be blurred
const nsfwLevel = media.nsfwLevel !== undefined ? media.nsfwLevel : 0;
const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13;
// Determine NSFW warning text based on level
let nsfwText = "Mature Content";
if (nsfwLevel >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (nsfwLevel >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (nsfwLevel >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
// Extract metadata from the media
const meta = media.meta || {};
const prompt = meta.prompt || '';
const negativePrompt = meta.negative_prompt || meta.negativePrompt || '';
const size = meta.Size || `${media.width}x${media.height}`;
const seed = meta.seed || '';
const model = meta.Model || '';
const steps = meta.steps || '';
const sampler = meta.sampler || '';
const cfgScale = meta.cfgScale || '';
const clipSkip = meta.clipSkip || '';
// Check if we have any meaningful generation parameters
const hasParams = seed || model || steps || sampler || cfgScale || clipSkip;
const hasPrompts = prompt || negativePrompt;
// Create metadata panel content
const metadataPanel = generateMetadataPanel(
hasParams, hasPrompts,
prompt, negativePrompt,
size, seed, model, steps, sampler, cfgScale, clipSkip
);
// Check if this is a video or image
if (media.type === 'video') {
return generateVideoWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel);
}
return generateImageWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel);
}
/**
* Generate metadata panel HTML
*/
function generateMetadataPanel(hasParams, hasPrompts, prompt, negativePrompt, size, seed, model, steps, sampler, cfgScale, clipSkip) {
// Create unique IDs for prompt copying
const promptIndex = Math.random().toString(36).substring(2, 15);
const negPromptIndex = Math.random().toString(36).substring(2, 15);
let content = '<div class="image-metadata-panel"><div class="metadata-content">';
if (hasParams) {
content += `
<div class="params-tags">
${size ? `<div class="param-tag"><span class="param-name">Size:</span><span class="param-value">${size}</span></div>` : ''}
${seed ? `<div class="param-tag"><span class="param-name">Seed:</span><span class="param-value">${seed}</span></div>` : ''}
${model ? `<div class="param-tag"><span class="param-name">Model:</span><span class="param-value">${model}</span></div>` : ''}
${steps ? `<div class="param-tag"><span class="param-name">Steps:</span><span class="param-value">${steps}</span></div>` : ''}
${sampler ? `<div class="param-tag"><span class="param-name">Sampler:</span><span class="param-value">${sampler}</span></div>` : ''}
${cfgScale ? `<div class="param-tag"><span class="param-name">CFG:</span><span class="param-value">${cfgScale}</span></div>` : ''}
${clipSkip ? `<div class="param-tag"><span class="param-name">Clip Skip:</span><span class="param-value">${clipSkip}</span></div>` : ''}
</div>
`;
}
if (!hasParams && !hasPrompts) {
content += `
<div class="no-metadata-message">
<i class="fas fa-info-circle"></i>
<span>No generation parameters available</span>
</div>
`;
}
if (prompt) {
content += `
<div class="metadata-row prompt-row">
<span class="metadata-label">Prompt:</span>
<div class="metadata-prompt-wrapper">
<div class="metadata-prompt">${prompt}</div>
<button class="copy-prompt-btn" data-prompt-index="${promptIndex}">
<i class="fas fa-copy"></i>
</button>
</div>
</div>
<div class="hidden-prompt" id="prompt-${promptIndex}" style="display:none;">${prompt}</div>
`;
}
if (negativePrompt) {
content += `
<div class="metadata-row prompt-row">
<span class="metadata-label">Negative Prompt:</span>
<div class="metadata-prompt-wrapper">
<div class="metadata-prompt">${negativePrompt}</div>
<button class="copy-prompt-btn" data-prompt-index="${negPromptIndex}">
<i class="fas fa-copy"></i>
</button>
</div>
</div>
<div class="hidden-prompt" id="prompt-${negPromptIndex}" style="display:none;">${negativePrompt}</div>
`;
}
content += '</div></div>';
return content;
}
/**
* Generate video wrapper HTML
*/
function generateVideoWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel) {
return `
<div class="media-wrapper ${shouldBlur ? 'nsfw-media-wrapper' : ''}" style="padding-bottom: ${heightPercent}%">
${shouldBlur ? `
<button class="toggle-blur-btn showcase-toggle-btn" title="Toggle blur">
<i class="fas fa-eye"></i>
</button>
` : ''}
<video controls autoplay muted loop crossorigin="anonymous"
referrerpolicy="no-referrer" data-src="${media.url}"
class="lazy ${shouldBlur ? 'blurred' : ''}">
<source data-src="${media.url}" type="video/mp4">
Your browser does not support video playback
</video>
${shouldBlur ? `
<div class="nsfw-overlay">
<div class="nsfw-warning">
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
</div>
</div>
` : ''}
${metadataPanel}
</div>
`;
}
/**
* Generate image wrapper HTML
*/
function generateImageWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel) {
return `
<div class="media-wrapper ${shouldBlur ? 'nsfw-media-wrapper' : ''}" style="padding-bottom: ${heightPercent}%">
${shouldBlur ? `
<button class="toggle-blur-btn showcase-toggle-btn" title="Toggle blur">
<i class="fas fa-eye"></i>
</button>
` : ''}
<img data-src="${media.url}"
alt="Preview"
crossorigin="anonymous"
referrerpolicy="no-referrer"
width="${media.width}"
height="${media.height}"
class="lazy ${shouldBlur ? 'blurred' : ''}">
${shouldBlur ? `
<div class="nsfw-overlay">
<div class="nsfw-warning">
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
</div>
</div>
` : ''}
${metadataPanel}
</div>
`;
}
/**
* Toggle showcase expansion
*/
export function toggleShowcase(element) {
const carousel = element.nextElementSibling;
const isCollapsed = carousel.classList.contains('collapsed');
const indicator = element.querySelector('span');
const icon = element.querySelector('i');
carousel.classList.toggle('collapsed');
if (isCollapsed) {
const count = carousel.querySelectorAll('.media-wrapper').length;
indicator.textContent = `Scroll or click to hide examples`;
icon.classList.replace('fa-chevron-down', 'fa-chevron-up');
initLazyLoading(carousel);
// Initialize NSFW content blur toggle handlers
initNsfwBlurHandlers(carousel);
// Initialize metadata panel interaction handlers
initMetadataPanelHandlers(carousel);
} else {
const count = carousel.querySelectorAll('.media-wrapper').length;
indicator.textContent = `Scroll or click to show ${count} examples`;
icon.classList.replace('fa-chevron-up', 'fa-chevron-down');
}
}
/**
* Initialize metadata panel interaction handlers
*/
function initMetadataPanelHandlers(container) {
const mediaWrappers = container.querySelectorAll('.media-wrapper');
mediaWrappers.forEach(wrapper => {
const metadataPanel = wrapper.querySelector('.image-metadata-panel');
if (!metadataPanel) return;
// Prevent events from bubbling
metadataPanel.addEventListener('click', (e) => {
e.stopPropagation();
});
// Handle copy prompt buttons
const copyBtns = metadataPanel.querySelectorAll('.copy-prompt-btn');
copyBtns.forEach(copyBtn => {
const promptIndex = copyBtn.dataset.promptIndex;
const promptElement = wrapper.querySelector(`#prompt-${promptIndex}`);
copyBtn.addEventListener('click', async (e) => {
e.stopPropagation();
if (!promptElement) return;
try {
await navigator.clipboard.writeText(promptElement.textContent);
showToast('Prompt copied to clipboard', 'success');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
});
});
// Prevent panel scroll from causing modal scroll
metadataPanel.addEventListener('wheel', (e) => {
e.stopPropagation();
});
});
}
/**
* Initialize blur toggle handlers
*/
function initNsfwBlurHandlers(container) {
// Handle toggle blur buttons
const toggleButtons = container.querySelectorAll('.toggle-blur-btn');
toggleButtons.forEach(btn => {
btn.addEventListener('click', (e) => {
e.stopPropagation();
const wrapper = btn.closest('.media-wrapper');
const media = wrapper.querySelector('img, video');
const isBlurred = media.classList.toggle('blurred');
const icon = btn.querySelector('i');
// Update the icon based on blur state
if (isBlurred) {
icon.className = 'fas fa-eye';
} else {
icon.className = 'fas fa-eye-slash';
}
// Toggle the overlay visibility
const overlay = wrapper.querySelector('.nsfw-overlay');
if (overlay) {
overlay.style.display = isBlurred ? 'flex' : 'none';
}
});
});
// Handle "Show" buttons in overlays
const showButtons = container.querySelectorAll('.show-content-btn');
showButtons.forEach(btn => {
btn.addEventListener('click', (e) => {
e.stopPropagation();
const wrapper = btn.closest('.media-wrapper');
const media = wrapper.querySelector('img, video');
media.classList.remove('blurred');
// Update the toggle button icon
const toggleBtn = wrapper.querySelector('.toggle-blur-btn');
if (toggleBtn) {
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
}
// Hide the overlay
const overlay = wrapper.querySelector('.nsfw-overlay');
if (overlay) {
overlay.style.display = 'none';
}
});
});
}
/**
* Initialize lazy loading for images and videos
*/
function initLazyLoading(container) {
const lazyElements = container.querySelectorAll('.lazy');
const lazyLoad = (element) => {
if (element.tagName.toLowerCase() === 'video') {
element.src = element.dataset.src;
element.querySelector('source').src = element.dataset.src;
element.load();
} else {
element.src = element.dataset.src;
}
element.classList.remove('lazy');
};
const observer = new IntersectionObserver((entries) => {
entries.forEach(entry => {
if (entry.isIntersecting) {
lazyLoad(entry.target);
observer.unobserve(entry.target);
}
});
});
lazyElements.forEach(element => observer.observe(element));
}
/**
* Set up showcase scroll functionality
*/
export function setupShowcaseScroll() {
// Listen for wheel events
document.addEventListener('wheel', (event) => {
const modalContent = document.querySelector('#checkpointModal .modal-content');
if (!modalContent) return;
const showcase = modalContent.querySelector('.showcase-section');
if (!showcase) return;
const carousel = showcase.querySelector('.carousel');
const scrollIndicator = showcase.querySelector('.scroll-indicator');
if (carousel?.classList.contains('collapsed') && event.deltaY > 0) {
const isNearBottom = modalContent.scrollHeight - modalContent.scrollTop - modalContent.clientHeight < 100;
if (isNearBottom) {
toggleShowcase(scrollIndicator);
event.preventDefault();
}
}
}, { passive: false });
// Use MutationObserver to set up back-to-top button when modal content is added
const observer = new MutationObserver((mutations) => {
for (const mutation of mutations) {
if (mutation.type === 'childList' && mutation.addedNodes.length) {
const checkpointModal = document.getElementById('checkpointModal');
if (checkpointModal && checkpointModal.querySelector('.modal-content')) {
setupBackToTopButton(checkpointModal.querySelector('.modal-content'));
}
}
}
});
// Start observing the document body for changes
observer.observe(document.body, { childList: true, subtree: true });
// Also try to set up the button immediately in case the modal is already open
const modalContent = document.querySelector('#checkpointModal .modal-content');
if (modalContent) {
setupBackToTopButton(modalContent);
}
}
/**
* Set up back-to-top button
*/
function setupBackToTopButton(modalContent) {
// Remove any existing scroll listeners to avoid duplicates
modalContent.onscroll = null;
// Add new scroll listener
modalContent.addEventListener('scroll', () => {
const backToTopBtn = modalContent.querySelector('.back-to-top');
if (backToTopBtn) {
if (modalContent.scrollTop > 300) {
backToTopBtn.classList.add('visible');
} else {
backToTopBtn.classList.remove('visible');
}
}
});
// Trigger a scroll event to check initial position
modalContent.dispatchEvent(new Event('scroll'));
}
/**
* Scroll to top of modal content
*/
export function scrollToTop(button) {
const modalContent = button.closest('.modal-content');
if (modalContent) {
modalContent.scrollTo({
top: 0,
behavior: 'smooth'
});
}
}

View File

@@ -0,0 +1,214 @@
/**
* CheckpointModal - Main entry point
*
* Modularized checkpoint modal component that handles checkpoint model details display
*/
import { showToast } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { modalManager } from '../../managers/ModalManager.js';
import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js';
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
import {
setupModelNameEditing,
setupBaseModelEditing,
setupFileNameEditing,
saveModelMetadata
} from './ModelMetadata.js';
import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js';
import { updateCheckpointCard } from '../../utils/cardUpdater.js';
/**
* Display the checkpoint modal with the given checkpoint data
* @param {Object} checkpoint - Checkpoint data object
*/
export function showCheckpointModal(checkpoint) {
const content = `
<div class="modal-content">
<button class="close" onclick="modalManager.closeModal('checkpointModal')">&times;</button>
<header class="modal-header">
<div class="model-name-header">
<h2 class="model-name-content" contenteditable="true" spellcheck="false">${checkpoint.model_name || 'Checkpoint Details'}</h2>
<button class="edit-model-name-btn" title="Edit model name">
<i class="fas fa-pencil-alt"></i>
</button>
</div>
${renderCompactTags(checkpoint.tags || [])}
</header>
<div class="modal-body">
<div class="info-section">
<div class="info-grid">
<div class="info-item">
<label>Version</label>
<span>${checkpoint.civitai?.name || 'N/A'}</span>
</div>
<div class="info-item">
<label>File Name</label>
<div class="file-name-wrapper">
<span id="file-name" class="file-name-content">${checkpoint.file_name || 'N/A'}</span>
<button class="edit-file-name-btn" title="Edit file name">
<i class="fas fa-pencil-alt"></i>
</button>
</div>
</div>
<div class="info-item location-size">
<div class="location-wrapper">
<label>Location</label>
<span class="file-path">${checkpoint.file_path.replace(/[^/]+$/, '')}</span>
</div>
</div>
<div class="info-item base-size">
<div class="base-wrapper">
<label>Base Model</label>
<div class="base-model-display">
<span class="base-model-content">${checkpoint.base_model || 'Unknown'}</span>
<button class="edit-base-model-btn" title="Edit base model">
<i class="fas fa-pencil-alt"></i>
</button>
</div>
</div>
<div class="size-wrapper">
<label>Size</label>
<span>${formatFileSize(checkpoint.file_size)}</span>
</div>
</div>
<div class="info-item notes">
<label>Additional Notes</label>
<div class="editable-field">
<div class="notes-content" contenteditable="true" spellcheck="false">${checkpoint.notes || 'Add your notes here...'}</div>
<button class="save-btn" onclick="saveCheckpointNotes('${checkpoint.file_path}')">
<i class="fas fa-save"></i>
</button>
</div>
</div>
<div class="info-item full-width">
<label>About this version</label>
<div class="description-text">${checkpoint.description || 'N/A'}</div>
</div>
</div>
</div>
<div class="showcase-section" data-checkpoint-id="${checkpoint.civitai?.modelId || ''}">
<div class="showcase-tabs">
<button class="tab-btn active" data-tab="showcase">Examples</button>
<button class="tab-btn" data-tab="description">Model Description</button>
</div>
<div class="tab-content">
<div id="showcase-tab" class="tab-pane active">
${renderShowcaseContent(checkpoint.civitai?.images || [])}
</div>
<div id="description-tab" class="tab-pane">
<div class="model-description-container">
<div class="model-description-loading">
<i class="fas fa-spinner fa-spin"></i> Loading model description...
</div>
<div class="model-description-content">
${checkpoint.modelDescription || ''}
</div>
</div>
</div>
</div>
<button class="back-to-top" onclick="scrollToTopCheckpoint(this)">
<i class="fas fa-arrow-up"></i>
</button>
</div>
</div>
</div>
`;
modalManager.showModal('checkpointModal', content);
setupEditableFields(checkpoint.file_path);
setupShowcaseScroll();
setupTabSwitching();
setupTagTooltip();
setupModelNameEditing(checkpoint.file_path);
setupBaseModelEditing(checkpoint.file_path);
setupFileNameEditing(checkpoint.file_path);
// If we have a model ID but no description, fetch it
if (checkpoint.civitai?.modelId && !checkpoint.modelDescription) {
loadModelDescription(checkpoint.civitai.modelId, checkpoint.file_path);
}
}
/**
* Set up editable fields in the checkpoint modal
* @param {string} filePath - The full file path of the model.
*/
function setupEditableFields(filePath) {
const editableFields = document.querySelectorAll('.editable-field [contenteditable]');
editableFields.forEach(field => {
field.addEventListener('focus', function() {
if (this.textContent === 'Add your notes here...') {
this.textContent = '';
}
});
field.addEventListener('blur', function() {
if (this.textContent.trim() === '') {
if (this.classList.contains('notes-content')) {
this.textContent = 'Add your notes here...';
}
}
});
});
// Add keydown event listeners for notes
const notesContent = document.querySelector('.notes-content');
if (notesContent) {
notesContent.addEventListener('keydown', async function(e) {
if (e.key === 'Enter') {
if (e.shiftKey) {
// Allow shift+enter for new line
return;
}
e.preventDefault();
await saveNotes(filePath);
}
});
}
}
/**
* Save checkpoint notes
* @param {string} filePath - Path to the checkpoint file
*/
async function saveNotes(filePath) {
const content = document.querySelector('.notes-content').textContent;
try {
await saveModelMetadata(filePath, { notes: content });
// Update the corresponding checkpoint card's dataset
updateCheckpointCard(filePath, { notes: content });
showToast('Notes saved successfully', 'success');
} catch (error) {
showToast('Failed to save notes', 'error');
}
}
// Export the checkpoint modal API
const checkpointModal = {
show: showCheckpointModal,
toggleShowcase,
scrollToTop
};
export { checkpointModal };
// Define global functions for use in HTML
window.toggleShowcase = function(element) {
toggleShowcase(element);
};
window.scrollToTopCheckpoint = function(button) {
scrollToTop(button);
};
window.saveCheckpointNotes = function(filePath) {
saveNotes(filePath);
};

View File

@@ -0,0 +1,74 @@
/**
* utils.js
* CheckpointModal component utility functions
*/
import { showToast } from '../../utils/uiHelpers.js';
/**
* Format file size for display
* @param {number} bytes - File size in bytes
* @returns {string} - Formatted file size
*/
export function formatFileSize(bytes) {
if (!bytes) return 'N/A';
const units = ['B', 'KB', 'MB', 'GB', 'TB'];
let size = bytes;
let unitIndex = 0;
while (size >= 1024 && unitIndex < units.length - 1) {
size /= 1024;
unitIndex++;
}
return `${size.toFixed(1)} ${units[unitIndex]}`;
}
/**
* Render compact tags
* @param {Array} tags - Array of tags
* @returns {string} HTML content
*/
export function renderCompactTags(tags) {
if (!tags || tags.length === 0) return '';
// Display up to 5 tags, with a tooltip indicator if there are more
const visibleTags = tags.slice(0, 5);
const remainingCount = Math.max(0, tags.length - 5);
return `
<div class="model-tags-container">
<div class="model-tags-compact">
${visibleTags.map(tag => `<span class="model-tag-compact">${tag}</span>`).join('')}
${remainingCount > 0 ?
`<span class="model-tag-more" data-count="${remainingCount}">+${remainingCount}</span>` :
''}
</div>
${tags.length > 0 ?
`<div class="model-tags-tooltip">
<div class="tooltip-content">
${tags.map(tag => `<span class="tooltip-tag">${tag}</span>`).join('')}
</div>
</div>` :
''}
</div>
`;
}
/**
* Set up tag tooltip functionality
*/
export function setupTagTooltip() {
const tagsContainer = document.querySelector('.model-tags-container');
const tooltip = document.querySelector('.model-tags-tooltip');
if (tagsContainer && tooltip) {
tagsContainer.addEventListener('mouseenter', () => {
tooltip.classList.add('visible');
});
tagsContainer.addEventListener('mouseleave', () => {
tooltip.classList.remove('visible');
});
}
}

View File

@@ -0,0 +1,60 @@
// CheckpointsControls.js - Specific implementation for the Checkpoints page
import { PageControls } from './PageControls.js';
import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js';
import { showToast } from '../../utils/uiHelpers.js';
import { CheckpointDownloadManager } from '../../managers/CheckpointDownloadManager.js';
/**
* CheckpointsControls class - Extends PageControls for Checkpoint-specific functionality
*/
export class CheckpointsControls extends PageControls {
constructor() {
// Initialize with 'checkpoints' page type
super('checkpoints');
// Initialize checkpoint download manager
this.downloadManager = new CheckpointDownloadManager();
// Register API methods specific to the Checkpoints page
this.registerCheckpointsAPI();
}
/**
* Register Checkpoint-specific API methods
*/
registerCheckpointsAPI() {
const checkpointsAPI = {
// Core API functions
loadMoreModels: async (resetPage = false, updateFolders = false) => {
return await loadMoreCheckpoints(resetPage, updateFolders);
},
resetAndReload: async (updateFolders = false) => {
return await resetAndReload(updateFolders);
},
refreshModels: async () => {
return await refreshCheckpoints();
},
// Add fetch from Civitai functionality for checkpoints
fetchFromCivitai: async () => {
return await fetchCivitai();
},
// Add show download modal functionality
showDownloadModal: () => {
this.downloadManager.showDownloadModal();
},
// No clearCustomFilter implementation is needed for checkpoints
// as custom filters are currently only used for LoRAs
clearCustomFilter: async () => {
showToast('No custom filter to clear', 'info');
}
};
// Register the API
this.registerAPI(checkpointsAPI);
}
}

View File

@@ -0,0 +1,146 @@
// LorasControls.js - Specific implementation for the LoRAs page
import { PageControls } from './PageControls.js';
import { loadMoreLoras, fetchCivitai, resetAndReload, refreshLoras } from '../../api/loraApi.js';
import { getSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
import { showToast } from '../../utils/uiHelpers.js';
/**
* LorasControls class - Extends PageControls for LoRA-specific functionality
*/
export class LorasControls extends PageControls {
constructor() {
// Initialize with 'loras' page type
super('loras');
// Register API methods specific to the LoRAs page
this.registerLorasAPI();
// Check for custom filters (e.g., from recipe navigation)
this.checkCustomFilters();
}
/**
* Register LoRA-specific API methods
*/
registerLorasAPI() {
const lorasAPI = {
// Core API functions
loadMoreModels: async (resetPage = false, updateFolders = false) => {
return await loadMoreLoras(resetPage, updateFolders);
},
resetAndReload: async (updateFolders = false) => {
return await resetAndReload(updateFolders);
},
refreshModels: async () => {
return await refreshLoras();
},
// LoRA-specific API functions
fetchFromCivitai: async () => {
return await fetchCivitai();
},
showDownloadModal: () => {
if (window.downloadManager) {
window.downloadManager.showDownloadModal();
} else {
console.error('Download manager not available');
}
},
toggleBulkMode: () => {
if (window.bulkManager) {
window.bulkManager.toggleBulkMode();
} else {
console.error('Bulk manager not available');
}
},
clearCustomFilter: async () => {
await this.clearCustomFilter();
}
};
// Register the API
this.registerAPI(lorasAPI);
}
/**
* Check for custom filter parameters in session storage (e.g., from recipe page navigation)
*/
checkCustomFilters() {
const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash');
const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes');
const filterRecipeName = getSessionItem('filterRecipeName');
const viewLoraDetail = getSessionItem('viewLoraDetail');
if ((filterLoraHash || filterLoraHashes) && filterRecipeName) {
// Found custom filter parameters, set up the custom filter
// Show the filter indicator
const indicator = document.getElementById('customFilterIndicator');
const filterText = indicator?.querySelector('.customFilterText');
if (indicator && filterText) {
indicator.classList.remove('hidden');
// Set text content with recipe name
const filterType = filterLoraHash && viewLoraDetail ? "Viewing LoRA from" : "Viewing LoRAs from";
const displayText = `${filterType}: ${filterRecipeName}`;
filterText.textContent = this._truncateText(displayText, 30);
filterText.setAttribute('title', displayText);
// Add pulse animation
const filterElement = indicator.querySelector('.filter-active');
if (filterElement) {
filterElement.classList.add('animate');
setTimeout(() => filterElement.classList.remove('animate'), 600);
}
}
// If we're viewing a specific LoRA detail, set up to open the modal
if (filterLoraHash && viewLoraDetail) {
this.pageState.pendingLoraHash = filterLoraHash;
}
}
}
/**
* Clear the custom filter and reload the page
*/
async clearCustomFilter() {
console.log("Clearing custom filter...");
// Remove filter parameters from session storage
removeSessionItem('recipe_to_lora_filterLoraHash');
removeSessionItem('recipe_to_lora_filterLoraHashes');
removeSessionItem('filterRecipeName');
removeSessionItem('viewLoraDetail');
// Hide the filter indicator
const indicator = document.getElementById('customFilterIndicator');
if (indicator) {
indicator.classList.add('hidden');
}
// Reset state
if (this.pageState.pendingLoraHash) {
delete this.pageState.pendingLoraHash;
}
// Reload the loras
await resetAndReload();
}
/**
* Helper to truncate text with ellipsis
* @param {string} text - Text to truncate
* @param {number} maxLength - Maximum length before truncating
* @returns {string} - Truncated text
*/
_truncateText(text, maxLength) {
return text.length > maxLength ? text.substring(0, maxLength - 3) + '...' : text;
}
}

View File

@@ -0,0 +1,388 @@
// PageControls.js - Manages controls for both LoRAs and Checkpoints pages
import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js';
import { getStorageItem, setStorageItem } from '../../utils/storageHelpers.js';
import { showToast } from '../../utils/uiHelpers.js';
/**
* PageControls class - Unified control management for model pages
*/
export class PageControls {
constructor(pageType) {
// Set the current page type in state
setCurrentPageType(pageType);
// Store the page type
this.pageType = pageType;
// Get the current page state
this.pageState = getCurrentPageState();
// Initialize state based on page type
this.initializeState();
// Store API methods
this.api = null;
// Initialize event listeners
this.initEventListeners();
console.log(`PageControls initialized for ${pageType} page`);
}
/**
* Initialize state based on page type
*/
initializeState() {
// Set default values
this.pageState.pageSize = 20;
this.pageState.isLoading = false;
this.pageState.hasMore = true;
// Load sort preference
this.loadSortPreference();
}
/**
* Register API methods for the page
* @param {Object} api - API methods for the page
*/
registerAPI(api) {
this.api = api;
console.log(`API methods registered for ${this.pageType} page`);
}
/**
* Initialize event listeners for controls
*/
initEventListeners() {
// Sort select handler
const sortSelect = document.getElementById('sortSelect');
if (sortSelect) {
sortSelect.value = this.pageState.sortBy;
sortSelect.addEventListener('change', async (e) => {
this.pageState.sortBy = e.target.value;
this.saveSortPreference(e.target.value);
await this.resetAndReload();
});
}
// Use event delegation for folder tags - this is the key fix
const folderTagsContainer = document.querySelector('.folder-tags-container');
if (folderTagsContainer) {
folderTagsContainer.addEventListener('click', (e) => {
const tag = e.target.closest('.tag');
if (tag) {
this.handleFolderClick(tag);
}
});
}
// Refresh button handler
const refreshBtn = document.querySelector('[data-action="refresh"]');
if (refreshBtn) {
refreshBtn.addEventListener('click', () => this.refreshModels());
}
// Toggle folders button
const toggleFoldersBtn = document.querySelector('.toggle-folders-btn');
if (toggleFoldersBtn) {
toggleFoldersBtn.addEventListener('click', () => this.toggleFolderTags());
}
// Clear custom filter handler
const clearFilterBtn = document.querySelector('.clear-filter');
if (clearFilterBtn) {
clearFilterBtn.addEventListener('click', () => this.clearCustomFilter());
}
// Page-specific event listeners
this.initPageSpecificListeners();
}
/**
* Initialize page-specific event listeners
*/
initPageSpecificListeners() {
// Fetch from Civitai button - available for both loras and checkpoints
const fetchButton = document.querySelector('[data-action="fetch"]');
if (fetchButton) {
fetchButton.addEventListener('click', () => this.fetchFromCivitai());
}
const downloadButton = document.querySelector('[data-action="download"]');
if (downloadButton) {
downloadButton.addEventListener('click', () => this.showDownloadModal());
}
if (this.pageType === 'loras') {
// Bulk operations button - LoRAs only
const bulkButton = document.querySelector('[data-action="bulk"]');
if (bulkButton) {
bulkButton.addEventListener('click', () => this.toggleBulkMode());
}
}
}
/**
* Toggle folder selection
* @param {HTMLElement} tagElement - The folder tag element that was clicked
*/
handleFolderClick(tagElement) {
const folder = tagElement.dataset.folder;
const wasActive = tagElement.classList.contains('active');
document.querySelectorAll('.folder-tags .tag').forEach(t => {
t.classList.remove('active');
});
if (!wasActive) {
tagElement.classList.add('active');
this.pageState.activeFolder = folder;
setStorageItem(`${this.pageType}_activeFolder`, folder);
} else {
this.pageState.activeFolder = null;
setStorageItem(`${this.pageType}_activeFolder`, null);
}
this.resetAndReload();
}
/**
* Restore folder filter from storage
*/
restoreFolderFilter() {
const activeFolder = getStorageItem(`${this.pageType}_activeFolder`);
const folderTag = activeFolder && document.querySelector(`.tag[data-folder="${activeFolder}"]`);
if (folderTag) {
folderTag.classList.add('active');
this.pageState.activeFolder = activeFolder;
this.filterByFolder(activeFolder);
}
}
/**
* Filter displayed cards by folder
* @param {string} folderPath - Folder path to filter by
*/
filterByFolder(folderPath) {
const cardSelector = this.pageType === 'loras' ? '.lora-card' : '.checkpoint-card';
document.querySelectorAll(cardSelector).forEach(card => {
card.style.display = card.dataset.folder === folderPath ? '' : 'none';
});
}
/**
* Update the folder tags display with new folder list
* @param {Array} folders - List of folder names
*/
updateFolderTags(folders) {
const folderTagsContainer = document.querySelector('.folder-tags');
if (!folderTagsContainer) return;
// Keep track of currently selected folder
const currentFolder = this.pageState.activeFolder;
// Create HTML for folder tags
const tagsHTML = folders.map(folder => {
const isActive = folder === currentFolder;
return `<div class="tag ${isActive ? 'active' : ''}" data-folder="${folder}">${folder}</div>`;
}).join('');
// Update the container
folderTagsContainer.innerHTML = tagsHTML;
// Scroll active folder into view (no need to reattach click handlers)
const activeTag = folderTagsContainer.querySelector(`.tag[data-folder="${currentFolder}"]`);
if (activeTag) {
activeTag.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}
}
/**
* Toggle visibility of folder tags
*/
toggleFolderTags() {
const folderTags = document.querySelector('.folder-tags');
const toggleBtn = document.querySelector('.toggle-folders-btn i');
if (folderTags) {
folderTags.classList.toggle('collapsed');
if (folderTags.classList.contains('collapsed')) {
// Change icon to indicate folders are hidden
toggleBtn.className = 'fas fa-folder-plus';
toggleBtn.parentElement.title = 'Show folder tags';
setStorageItem('folderTagsCollapsed', 'true');
} else {
// Change icon to indicate folders are visible
toggleBtn.className = 'fas fa-folder-minus';
toggleBtn.parentElement.title = 'Hide folder tags';
setStorageItem('folderTagsCollapsed', 'false');
}
}
}
/**
* Initialize folder tags visibility based on stored preference
*/
initFolderTagsVisibility() {
const isCollapsed = getStorageItem('folderTagsCollapsed');
if (isCollapsed) {
const folderTags = document.querySelector('.folder-tags');
const toggleBtn = document.querySelector('.toggle-folders-btn i');
if (folderTags) {
folderTags.classList.add('collapsed');
}
if (toggleBtn) {
toggleBtn.className = 'fas fa-folder-plus';
toggleBtn.parentElement.title = 'Show folder tags';
}
} else {
const toggleBtn = document.querySelector('.toggle-folders-btn i');
if (toggleBtn) {
toggleBtn.className = 'fas fa-folder-minus';
toggleBtn.parentElement.title = 'Hide folder tags';
}
}
}
/**
* Load sort preference from storage
*/
loadSortPreference() {
const savedSort = getStorageItem(`${this.pageType}_sort`);
if (savedSort) {
this.pageState.sortBy = savedSort;
const sortSelect = document.getElementById('sortSelect');
if (sortSelect) {
sortSelect.value = savedSort;
}
}
}
/**
* Save sort preference to storage
* @param {string} sortValue - The sort value to save
*/
saveSortPreference(sortValue) {
setStorageItem(`${this.pageType}_sort`, sortValue);
}
/**
* Open model page on Civitai
* @param {string} modelName - Name of the model
*/
openCivitai(modelName) {
// Get card selector based on page type
const cardSelector = this.pageType === 'loras'
? `.lora-card[data-name="${modelName}"]`
: `.checkpoint-card[data-name="${modelName}"]`;
const card = document.querySelector(cardSelector);
if (!card) return;
const metaData = JSON.parse(card.dataset.meta);
const civitaiId = metaData.modelId;
const versionId = metaData.id;
// Build URL
if (civitaiId) {
let url = `https://civitai.com/models/${civitaiId}`;
if (versionId) {
url += `?modelVersionId=${versionId}`;
}
window.open(url, '_blank');
} else {
// If no ID, try searching by name
window.open(`https://civitai.com/models?query=${encodeURIComponent(modelName)}`, '_blank');
}
}
/**
* Reset and reload the models list
*/
async resetAndReload(updateFolders = false) {
if (!this.api) {
console.error('API methods not registered');
return;
}
try {
await this.api.resetAndReload(updateFolders);
} catch (error) {
console.error(`Error reloading ${this.pageType}:`, error);
showToast(`Failed to reload ${this.pageType}: ${error.message}`, 'error');
}
}
/**
* Refresh models list
*/
async refreshModels() {
if (!this.api) {
console.error('API methods not registered');
return;
}
try {
await this.api.refreshModels();
} catch (error) {
console.error(`Error refreshing ${this.pageType}:`, error);
showToast(`Failed to refresh ${this.pageType}: ${error.message}`, 'error');
}
}
/**
* Fetch metadata from Civitai (available for both LoRAs and Checkpoints)
*/
async fetchFromCivitai() {
if (!this.api) {
console.error('API methods not registered');
return;
}
try {
await this.api.fetchFromCivitai();
} catch (error) {
console.error('Error fetching metadata:', error);
showToast('Failed to fetch metadata: ' + error.message, 'error');
}
}
/**
* Show download modal
*/
showDownloadModal() {
this.api.showDownloadModal();
}
/**
* Toggle bulk mode (LoRAs only)
*/
toggleBulkMode() {
if (this.pageType !== 'loras' || !this.api) {
console.error('Bulk mode is only available for LoRAs');
return;
}
this.api.toggleBulkMode();
}
/**
* Clear custom filter
*/
async clearCustomFilter() {
if (!this.api) {
console.error('API methods not registered');
return;
}
try {
await this.api.clearCustomFilter();
} catch (error) {
console.error('Error clearing custom filter:', error);
showToast('Failed to clear custom filter: ' + error.message, 'error');
}
}
}

View File

@@ -0,0 +1,23 @@
// Controls components index file
import { PageControls } from './PageControls.js';
import { LorasControls } from './LorasControls.js';
import { CheckpointsControls } from './CheckpointsControls.js';
// Export the classes
export { PageControls, LorasControls, CheckpointsControls };
/**
* Factory function to create the appropriate controls based on page type
* @param {string} pageType - The type of page ('loras' or 'checkpoints')
* @returns {PageControls} - The appropriate controls instance
*/
export function createPageControls(pageType) {
if (pageType === 'loras') {
return new LorasControls();
} else if (pageType === 'checkpoints') {
return new CheckpointsControls();
} else {
console.error(`Unknown page type: ${pageType}`);
return null;
}
}

View File

@@ -0,0 +1,495 @@
/**
* Initialization Component
* Manages the display of initialization progress and status
*/
import { appCore } from '../core.js';
import { getSessionItem, setSessionItem } from '../utils/storageHelpers.js';
import { state, getCurrentPageState } from '../state/index.js';
class InitializationManager {
constructor() {
this.currentTipIndex = 0;
this.tipInterval = null;
this.websocket = null;
this.progress = 0;
this.processingStartTime = null;
this.processedFilesCount = 0;
this.totalFilesCount = 0;
this.averageProcessingTime = null;
this.pageType = null; // Added page type property
}
/**
* Initialize the component
*/
initialize() {
// Initialize core application for theme and header functionality
appCore.initialize().then(() => {
console.log('Core application initialized for initialization component');
});
// Detect the current page type
this.detectPageType();
// Check session storage for saved progress
this.restoreProgress();
// Setup the tip carousel
this.setupTipCarousel();
// Connect to WebSocket for progress updates
this.connectWebSocket();
// Add event listeners for tip navigation
this.setupTipNavigation();
// Show first tip as active
document.querySelector('.tip-item').classList.add('active');
}
/**
* Detect the current page type
*/
detectPageType() {
// Get the current page type from URL or data attribute
const path = window.location.pathname;
if (path.includes('/checkpoints')) {
this.pageType = 'checkpoints';
} else if (path.includes('/loras')) {
this.pageType = 'loras';
} else {
// Default to loras if can't determine
this.pageType = 'loras';
}
console.log(`Initialization component detected page type: ${this.pageType}`);
}
/**
* Get the storage key with page type prefix
*/
getStorageKey(key) {
return `${this.pageType}_${key}`;
}
/**
* Restore progress from session storage if available
*/
restoreProgress() {
const savedProgress = getSessionItem(this.getStorageKey('initProgress'));
if (savedProgress) {
console.log(`Restoring ${this.pageType} progress from session storage:`, savedProgress);
// Restore progress percentage
if (savedProgress.progress !== undefined) {
this.updateProgress(savedProgress.progress);
}
// Restore processed files count and total files
if (savedProgress.processedFiles !== undefined) {
this.processedFilesCount = savedProgress.processedFiles;
}
if (savedProgress.totalFiles !== undefined) {
this.totalFilesCount = savedProgress.totalFiles;
}
// Restore processing time metrics if available
if (savedProgress.averageProcessingTime !== undefined) {
this.averageProcessingTime = savedProgress.averageProcessingTime;
this.updateRemainingTime();
}
// Restore progress status message
if (savedProgress.details) {
this.updateStatusMessage(savedProgress.details);
}
}
}
/**
* Connect to WebSocket for initialization progress updates
*/
connectWebSocket() {
try {
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
this.websocket = new WebSocket(`${wsProtocol}${window.location.host}/ws/init-progress`);
this.websocket.onopen = () => {
console.log('Connected to initialization progress WebSocket');
};
this.websocket.onmessage = (event) => {
this.handleProgressUpdate(JSON.parse(event.data));
};
this.websocket.onerror = (error) => {
console.error('WebSocket error:', error);
// Fall back to polling if WebSocket fails
this.fallbackToPolling();
};
this.websocket.onclose = () => {
console.log('WebSocket connection closed');
// Check if we need to fall back to polling
if (!this.pollingActive) {
this.fallbackToPolling();
}
};
} catch (error) {
console.error('Failed to connect to WebSocket:', error);
this.fallbackToPolling();
}
}
/**
* Fall back to polling if WebSocket connection fails
*/
fallbackToPolling() {
this.pollingActive = true;
this.pollProgress();
// Set a simulated progress that moves forward slowly
// This gives users feedback even if the backend isn't providing updates
let simulatedProgress = this.progress || 0;
const simulateInterval = setInterval(() => {
simulatedProgress += 0.5;
if (simulatedProgress > 95) {
clearInterval(simulateInterval);
return;
}
// Only use simulated progress if we haven't received a real update
if (this.progress < simulatedProgress) {
this.updateProgress(simulatedProgress);
}
}, 1000);
}
/**
* Poll for progress updates from the server
*/
pollProgress() {
const checkProgress = () => {
fetch('/api/init-status')
.then(response => response.json())
.then(data => {
this.handleProgressUpdate(data);
// If initialization is complete, stop polling
if (data.status !== 'complete') {
setTimeout(checkProgress, 2000);
} else {
window.location.reload();
}
})
.catch(error => {
console.error('Error polling for progress:', error);
setTimeout(checkProgress, 3000); // Try again after a longer delay
});
};
checkProgress();
}
/**
* Handle progress updates from WebSocket or polling
*/
handleProgressUpdate(data) {
if (!data) return;
// Check if this update is for our page type
if (data.pageType && data.pageType !== this.pageType) {
console.log(`Ignoring update for ${data.pageType}, we're on ${this.pageType}`);
return;
}
// If no pageType is specified in the data but we have scanner_type, map it to pageType
if (!data.pageType && data.scanner_type) {
const scannerTypeToPageType = {
'lora': 'loras',
'checkpoint': 'checkpoints'
};
if (scannerTypeToPageType[data.scanner_type] !== this.pageType) {
console.log(`Ignoring update for ${data.scanner_type}, we're on ${this.pageType}`);
return;
}
}
// Save progress data to session storage
setSessionItem(this.getStorageKey('initProgress'), {
...data,
averageProcessingTime: this.averageProcessingTime,
processedFiles: this.processedFilesCount,
totalFiles: this.totalFilesCount
});
// Update progress percentage
if (data.progress !== undefined) {
this.updateProgress(data.progress);
}
// Update stage-specific details
if (data.details) {
this.updateStatusMessage(data.details);
}
// Track files count for time estimation
if (data.stage === 'count_models' && data.details) {
const match = data.details.match(/Found (\d+)/);
if (match && match[1]) {
this.totalFilesCount = parseInt(match[1]);
}
}
// Track processed files for time estimation
if (data.stage === 'process_models' && data.details) {
const match = data.details.match(/Processing .* files: (\d+)\/(\d+)/);
if (match && match[1] && match[2]) {
const currentCount = parseInt(match[1]);
const totalCount = parseInt(match[2]);
// Make sure we have valid total count
if (totalCount > 0 && this.totalFilesCount === 0) {
this.totalFilesCount = totalCount;
}
// Start tracking processing time once we've processed some files
if (currentCount > 0 && !this.processingStartTime && this.processedFilesCount === 0) {
this.processingStartTime = Date.now();
}
// Calculate average processing time based on elapsed time and files processed
if (this.processingStartTime && currentCount > this.processedFilesCount) {
const newFiles = currentCount - this.processedFilesCount;
const elapsedTime = Date.now() - this.processingStartTime;
const timePerFile = elapsedTime / currentCount; // ms per file
// Update moving average
if (!this.averageProcessingTime) {
this.averageProcessingTime = timePerFile;
} else {
// Simple exponential moving average
this.averageProcessingTime = this.averageProcessingTime * 0.7 + timePerFile * 0.3;
}
// Update remaining time estimate
this.updateRemainingTime();
}
this.processedFilesCount = currentCount;
}
}
// If initialization is complete, reload the page
if (data.status === 'complete') {
this.showCompletionMessage();
// Remove session storage data since we're done
setSessionItem(this.getStorageKey('initProgress'), null);
// Give the user a moment to see the completion message
setTimeout(() => {
window.location.reload();
}, 1500);
}
}
/**
* Update the remaining time display based on current progress
*/
updateRemainingTime() {
if (!this.averageProcessingTime || !this.totalFilesCount || this.totalFilesCount <= 0) {
document.getElementById('remainingTime').textContent = 'Estimating...';
return;
}
const remainingFiles = this.totalFilesCount - this.processedFilesCount;
const remainingTimeMs = remainingFiles * this.averageProcessingTime;
if (remainingTimeMs <= 0) {
document.getElementById('remainingTime').textContent = 'Almost done...';
return;
}
// Format the time for display
let formattedTime;
if (remainingTimeMs < 60000) {
// Less than a minute
formattedTime = 'Less than a minute';
} else if (remainingTimeMs < 3600000) {
// Less than an hour
const minutes = Math.round(remainingTimeMs / 60000);
formattedTime = `~${minutes} minute${minutes !== 1 ? 's' : ''}`;
} else {
// Hours and minutes
const hours = Math.floor(remainingTimeMs / 3600000);
const minutes = Math.round((remainingTimeMs % 3600000) / 60000);
formattedTime = `~${hours} hour${hours !== 1 ? 's' : ''} ${minutes} minute${minutes !== 1 ? 's' : ''}`;
}
document.getElementById('remainingTime').textContent = formattedTime + ' remaining';
}
/**
* Update status message
*/
updateStatusMessage(message) {
const progressStatus = document.getElementById('progressStatus');
if (progressStatus) {
progressStatus.textContent = message;
}
}
/**
* Update the progress bar and percentage
*/
updateProgress(progress) {
this.progress = progress;
const progressBar = document.getElementById('initProgressBar');
const progressPercentage = document.getElementById('progressPercentage');
if (progressBar && progressPercentage) {
progressBar.style.width = `${progress}%`;
progressPercentage.textContent = `${Math.round(progress)}%`;
}
}
/**
* Setup the tip carousel to rotate through tips
*/
setupTipCarousel() {
const tipItems = document.querySelectorAll('.tip-item');
if (tipItems.length === 0) return;
// Show the first tip
tipItems[0].classList.add('active');
document.querySelector('.tip-dot').classList.add('active');
// Set up automatic rotation
this.tipInterval = setInterval(() => {
this.showNextTip();
}, 8000); // Change tip every 8 seconds
}
/**
* Setup tip navigation dots
*/
setupTipNavigation() {
const tipDots = document.querySelectorAll('.tip-dot');
tipDots.forEach((dot, index) => {
dot.addEventListener('click', () => {
this.showTipByIndex(index);
});
});
}
/**
* Show the next tip in the carousel
*/
showNextTip() {
const tipItems = document.querySelectorAll('.tip-item');
const tipDots = document.querySelectorAll('.tip-dot');
if (tipItems.length === 0) return;
// Hide current tip
tipItems[this.currentTipIndex].classList.remove('active');
tipDots[this.currentTipIndex].classList.remove('active');
// Calculate next index
this.currentTipIndex = (this.currentTipIndex + 1) % tipItems.length;
// Show next tip
tipItems[this.currentTipIndex].classList.add('active');
tipDots[this.currentTipIndex].classList.add('active');
}
/**
* Show a specific tip by index
*/
showTipByIndex(index) {
const tipItems = document.querySelectorAll('.tip-item');
const tipDots = document.querySelectorAll('.tip-dot');
if (index >= tipItems.length || index < 0) return;
// Hide current tip
tipItems[this.currentTipIndex].classList.remove('active');
tipDots[this.currentTipIndex].classList.remove('active');
// Update index and show selected tip
this.currentTipIndex = index;
// Show selected tip
tipItems[this.currentTipIndex].classList.add('active');
tipDots[this.currentTipIndex].classList.add('active');
// Reset interval to prevent quick tip change
if (this.tipInterval) {
clearInterval(this.tipInterval);
this.tipInterval = setInterval(() => {
this.showNextTip();
}, 8000);
}
}
/**
* Show completion message
*/
showCompletionMessage() {
// Update progress to 100%
this.updateProgress(100);
// Update status message
this.updateStatusMessage('Initialization complete!');
// Update title and subtitle
const initTitle = document.getElementById('initTitle');
const initSubtitle = document.getElementById('initSubtitle');
const remainingTime = document.getElementById('remainingTime');
if (initTitle) {
initTitle.textContent = 'Initialization Complete';
}
if (initSubtitle) {
initSubtitle.textContent = 'Reloading page...';
}
if (remainingTime) {
remainingTime.textContent = 'Done!';
}
}
/**
* Clean up resources when the component is destroyed
*/
cleanup() {
if (this.tipInterval) {
clearInterval(this.tipInterval);
}
if (this.websocket && this.websocket.readyState === WebSocket.OPEN) {
this.websocket.close();
}
}
}
// Create and export the initialization manager
export const initManager = new InitializationManager();
// Initialize the component when the DOM is loaded
document.addEventListener('DOMContentLoaded', () => {
// Only initialize if we're in initialization mode
const initContainer = document.getElementById('initializationContainer');
if (initContainer) {
initManager.initialize();
}
});
// Clean up when the page is unloaded
window.addEventListener('beforeunload', () => {
initManager.cleanup();
});

View File

@@ -0,0 +1,102 @@
/**
* ModelDescription.js
* 处理LoRA模型描述相关的功能模块
*/
import { showToast } from '../../utils/uiHelpers.js';
/**
* 设置标签页切换功能
*/
export function setupTabSwitching() {
const tabButtons = document.querySelectorAll('.showcase-tabs .tab-btn');
tabButtons.forEach(button => {
button.addEventListener('click', () => {
// Remove active class from all tabs
document.querySelectorAll('.showcase-tabs .tab-btn').forEach(btn =>
btn.classList.remove('active')
);
document.querySelectorAll('.tab-content .tab-pane').forEach(tab =>
tab.classList.remove('active')
);
// Add active class to clicked tab
button.classList.add('active');
const tabId = `${button.dataset.tab}-tab`;
document.getElementById(tabId).classList.add('active');
// If switching to description tab, make sure content is properly sized
if (button.dataset.tab === 'description') {
const descriptionContent = document.querySelector('.model-description-content');
if (descriptionContent) {
const hasContent = descriptionContent.innerHTML.trim() !== '';
document.querySelector('.model-description-loading')?.classList.add('hidden');
// If no content, show a message
if (!hasContent) {
descriptionContent.innerHTML = '<div class="no-description">No model description available</div>';
descriptionContent.classList.remove('hidden');
}
}
}
});
});
}
/**
* 加载模型描述
* @param {string} modelId - 模型ID
* @param {string} filePath - 文件路径
*/
export async function loadModelDescription(modelId, filePath) {
try {
const descriptionContainer = document.querySelector('.model-description-content');
const loadingElement = document.querySelector('.model-description-loading');
if (!descriptionContainer || !loadingElement) return;
// Show loading indicator
loadingElement.classList.remove('hidden');
descriptionContainer.classList.add('hidden');
// Try to get model description from API
const response = await fetch(`/api/lora-model-description?model_id=${modelId}&file_path=${encodeURIComponent(filePath)}`);
if (!response.ok) {
throw new Error(`Failed to fetch model description: ${response.statusText}`);
}
const data = await response.json();
if (data.success && data.description) {
// Update the description content
descriptionContainer.innerHTML = data.description;
// Process any links in the description to open in new tab
const links = descriptionContainer.querySelectorAll('a');
links.forEach(link => {
link.setAttribute('target', '_blank');
link.setAttribute('rel', 'noopener noreferrer');
});
// Show the description and hide loading indicator
descriptionContainer.classList.remove('hidden');
loadingElement.classList.add('hidden');
} else {
throw new Error(data.error || 'No description available');
}
} catch (error) {
console.error('Error loading model description:', error);
const loadingElement = document.querySelector('.model-description-loading');
if (loadingElement) {
loadingElement.innerHTML = `<div class="error-message">Failed to load model description. ${error.message}</div>`;
}
// Show empty state message in the description container
const descriptionContainer = document.querySelector('.model-description-content');
if (descriptionContainer) {
descriptionContainer.innerHTML = '<div class="no-description">No model description available</div>';
descriptionContainer.classList.remove('hidden');
}
}
}

View File

@@ -0,0 +1,474 @@
/**
* ModelMetadata.js
* 处理LoRA模型元数据编辑相关的功能模块
*/
import { showToast } from '../../utils/uiHelpers.js';
import { BASE_MODELS } from '../../utils/constants.js';
import { updateLoraCard } from '../../utils/cardUpdater.js';
/**
* 保存模型元数据到服务器
* @param {string} filePath - 文件路径
* @param {Object} data - 要保存的数据
* @returns {Promise} 保存操作的Promise
*/
export async function saveModelMetadata(filePath, data) {
const response = await fetch('/api/loras/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return response.json();
}
/**
* 设置模型名称编辑功能
* @param {string} filePath - 文件路径
*/
export function setupModelNameEditing(filePath) {
const modelNameContent = document.querySelector('.model-name-content');
const editBtn = document.querySelector('.edit-model-name-btn');
if (!modelNameContent || !editBtn) return;
// Store the file path in a data attribute for later use
modelNameContent.dataset.filePath = filePath;
// Show edit button on hover
const modelNameHeader = document.querySelector('.model-name-header');
modelNameHeader.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
modelNameHeader.addEventListener('mouseleave', () => {
if (!modelNameContent.getAttribute('data-editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
modelNameContent.setAttribute('data-editing', 'true');
modelNameContent.focus();
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
if (modelNameContent.childNodes.length > 0) {
range.setStart(modelNameContent.childNodes[0], modelNameContent.textContent.length);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
}
editBtn.classList.add('visible');
});
// Handle focus out
modelNameContent.addEventListener('blur', function() {
this.removeAttribute('data-editing');
editBtn.classList.remove('visible');
if (this.textContent.trim() === '') {
// Restore original model name if empty
const filePath = this.dataset.filePath;
const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (loraCard) {
this.textContent = loraCard.dataset.model_name;
}
}
});
// Handle enter key
modelNameContent.addEventListener('keydown', function(e) {
if (e.key === 'Enter') {
e.preventDefault();
const filePath = this.dataset.filePath;
saveModelName(filePath);
this.blur();
}
});
// Limit model name length
modelNameContent.addEventListener('input', function() {
// Limit model name length
if (this.textContent.length > 100) {
this.textContent = this.textContent.substring(0, 100);
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
range.setStart(this.childNodes[0], 100);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
showToast('Model name is limited to 100 characters', 'warning');
}
});
}
/**
* 保存模型名称
* @param {string} filePath - 文件路径
*/
async function saveModelName(filePath) {
const modelNameElement = document.querySelector('.model-name-content');
const newModelName = modelNameElement.textContent.trim();
// Validate model name
if (!newModelName) {
showToast('Model name cannot be empty', 'error');
return;
}
// Check if model name is too long (limit to 100 characters)
if (newModelName.length > 100) {
showToast('Model name is too long (maximum 100 characters)', 'error');
// Truncate the displayed text
modelNameElement.textContent = newModelName.substring(0, 100);
return;
}
try {
await saveModelMetadata(filePath, { model_name: newModelName });
// Update the corresponding lora card's dataset and display
updateLoraCard(filePath, { model_name: newModelName });
showToast('Model name updated successfully', 'success');
} catch (error) {
showToast('Failed to update model name', 'error');
}
}
/**
* 设置基础模型编辑功能
* @param {string} filePath - 文件路径
*/
export function setupBaseModelEditing(filePath) {
const baseModelContent = document.querySelector('.base-model-content');
const editBtn = document.querySelector('.edit-base-model-btn');
if (!baseModelContent || !editBtn) return;
// Store the file path in a data attribute for later use
baseModelContent.dataset.filePath = filePath;
// Show edit button on hover
const baseModelDisplay = document.querySelector('.base-model-display');
baseModelDisplay.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
baseModelDisplay.addEventListener('mouseleave', () => {
if (!baseModelDisplay.classList.contains('editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
baseModelDisplay.classList.add('editing');
// Store the original value to check for changes later
const originalValue = baseModelContent.textContent.trim();
// Create dropdown selector to replace the base model content
const currentValue = originalValue;
const dropdown = document.createElement('select');
dropdown.className = 'base-model-selector';
// Flag to track if a change was made
let valueChanged = false;
// Add options from BASE_MODELS constants
const baseModelCategories = {
'Stable Diffusion 1.x': [BASE_MODELS.SD_1_4, BASE_MODELS.SD_1_5, BASE_MODELS.SD_1_5_LCM, BASE_MODELS.SD_1_5_HYPER],
'Stable Diffusion 2.x': [BASE_MODELS.SD_2_0, BASE_MODELS.SD_2_1],
'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
'Video Models': [BASE_MODELS.SVD, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO],
'Other Models': [
BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW,
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.UNKNOWN
]
};
// Create option groups for better organization
Object.entries(baseModelCategories).forEach(([category, models]) => {
const group = document.createElement('optgroup');
group.label = category;
models.forEach(model => {
const option = document.createElement('option');
option.value = model;
option.textContent = model;
option.selected = model === currentValue;
group.appendChild(option);
});
dropdown.appendChild(group);
});
// Replace content with dropdown
baseModelContent.style.display = 'none';
baseModelDisplay.insertBefore(dropdown, editBtn);
// Hide edit button during editing
editBtn.style.display = 'none';
// Focus the dropdown
dropdown.focus();
// Handle dropdown change
dropdown.addEventListener('change', function() {
const selectedModel = this.value;
baseModelContent.textContent = selectedModel;
// Mark that a change was made if the value differs from original
if (selectedModel !== originalValue) {
valueChanged = true;
} else {
valueChanged = false;
}
});
// Function to save changes and exit edit mode
const saveAndExit = function() {
// Check if dropdown still exists and remove it
if (dropdown && dropdown.parentNode === baseModelDisplay) {
baseModelDisplay.removeChild(dropdown);
}
// Show the content and edit button
baseModelContent.style.display = '';
editBtn.style.display = '';
// Remove editing class
baseModelDisplay.classList.remove('editing');
// Only save if the value has actually changed
if (valueChanged || baseModelContent.textContent.trim() !== originalValue) {
// Get file path from the dataset
const filePath = baseModelContent.dataset.filePath;
// Save the changes, passing the original value for comparison
saveBaseModel(filePath, originalValue);
}
// Remove this event listener
document.removeEventListener('click', outsideClickHandler);
};
// Handle outside clicks to save and exit
const outsideClickHandler = function(e) {
// If click is outside the dropdown and base model display
if (!baseModelDisplay.contains(e.target)) {
saveAndExit();
}
};
// Add delayed event listener for outside clicks
setTimeout(() => {
document.addEventListener('click', outsideClickHandler);
}, 0);
// Also handle dropdown blur event
dropdown.addEventListener('blur', function(e) {
// Only save if the related target is not the edit button or inside the baseModelDisplay
if (!baseModelDisplay.contains(e.relatedTarget)) {
saveAndExit();
}
});
});
}
/**
* 保存基础模型
* @param {string} filePath - 文件路径
* @param {string} originalValue - 原始值(用于比较)
*/
async function saveBaseModel(filePath, originalValue) {
const baseModelElement = document.querySelector('.base-model-content');
const newBaseModel = baseModelElement.textContent.trim();
// Only save if the value has actually changed
if (newBaseModel === originalValue) {
return; // No change, no need to save
}
try {
await saveModelMetadata(filePath, { base_model: newBaseModel });
// Update the corresponding lora card's dataset
updateLoraCard(filePath, { base_model: newBaseModel });
showToast('Base model updated successfully', 'success');
} catch (error) {
showToast('Failed to update base model', 'error');
}
}
/**
* 设置文件名编辑功能
* @param {string} filePath - 文件路径
*/
export function setupFileNameEditing(filePath) {
const fileNameContent = document.querySelector('.file-name-content');
const editBtn = document.querySelector('.edit-file-name-btn');
if (!fileNameContent || !editBtn) return;
// Store the original file path
fileNameContent.dataset.filePath = filePath;
// Show edit button on hover
const fileNameWrapper = document.querySelector('.file-name-wrapper');
fileNameWrapper.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
fileNameWrapper.addEventListener('mouseleave', () => {
if (!fileNameWrapper.classList.contains('editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
fileNameWrapper.classList.add('editing');
fileNameContent.setAttribute('contenteditable', 'true');
fileNameContent.focus();
// Store original value for comparison later
fileNameContent.dataset.originalValue = fileNameContent.textContent.trim();
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
range.selectNodeContents(fileNameContent);
range.collapse(false);
sel.removeAllRanges();
sel.addRange(range);
editBtn.classList.add('visible');
});
// Handle keyboard events in edit mode
fileNameContent.addEventListener('keydown', function(e) {
if (!this.getAttribute('contenteditable')) return;
if (e.key === 'Enter') {
e.preventDefault();
this.blur(); // Trigger save on Enter
} else if (e.key === 'Escape') {
e.preventDefault();
// Restore original value
this.textContent = this.dataset.originalValue;
exitEditMode();
}
});
// Handle input validation
fileNameContent.addEventListener('input', function() {
if (!this.getAttribute('contenteditable')) return;
// Replace invalid characters for filenames
const invalidChars = /[\\/:*?"<>|]/g;
if (invalidChars.test(this.textContent)) {
const cursorPos = window.getSelection().getRangeAt(0).startOffset;
this.textContent = this.textContent.replace(invalidChars, '');
// Restore cursor position
const range = document.createRange();
const sel = window.getSelection();
const newPos = Math.min(cursorPos, this.textContent.length);
if (this.firstChild) {
range.setStart(this.firstChild, newPos);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
}
showToast('Invalid characters removed from filename', 'warning');
}
});
// Handle focus out - save changes
fileNameContent.addEventListener('blur', async function() {
if (!this.getAttribute('contenteditable')) return;
const newFileName = this.textContent.trim();
const originalValue = this.dataset.originalValue;
// Basic validation
if (!newFileName) {
// Restore original value if empty
this.textContent = originalValue;
showToast('File name cannot be empty', 'error');
exitEditMode();
return;
}
if (newFileName === originalValue) {
// No changes, just exit edit mode
exitEditMode();
return;
}
try {
// Get the file path from the dataset
const filePath = this.dataset.filePath;
// Call API to rename the file
const response = await fetch('/api/rename_lora', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
new_file_name: newFileName
})
});
const result = await response.json();
if (result.success) {
showToast('File name updated successfully', 'success');
// Get the new file path and update the card
const newFilePath = filePath.replace(originalValue, newFileName);
// Pass the new file_name in the updates object for proper card update
updateLoraCard(filePath, { file_name: newFileName }, newFilePath);
} else {
throw new Error(result.error || 'Unknown error');
}
} catch (error) {
console.error('Error renaming file:', error);
this.textContent = originalValue; // Restore original file name
showToast(`Failed to rename file: ${error.message}`, 'error');
} finally {
exitEditMode();
}
});
function exitEditMode() {
fileNameContent.removeAttribute('contenteditable');
fileNameWrapper.classList.remove('editing');
editBtn.classList.remove('visible');
}
}

View File

@@ -0,0 +1,68 @@
/**
* PresetTags.js
* 处理LoRA模型预设参数标签相关的功能模块
*/
import { saveModelMetadata } from './ModelMetadata.js';
import { showToast } from '../../utils/uiHelpers.js';
/**
* 解析预设参数
* @param {string} usageTips - 包含预设参数的JSON字符串
* @returns {Object} 解析后的预设参数对象
*/
export function parsePresets(usageTips) {
if (!usageTips) return {};
try {
return JSON.parse(usageTips);
} catch {
return {};
}
}
/**
* 渲染预设标签
* @param {Object} presets - 预设参数对象
* @returns {string} HTML内容
*/
export function renderPresetTags(presets) {
return Object.entries(presets).map(([key, value]) => `
<div class="preset-tag" data-key="${key}">
<span>${formatPresetKey(key)}: ${value}</span>
<i class="fas fa-times" onclick="removePreset('${key}')"></i>
</div>
`).join('');
}
/**
* 格式化预设键名
* @param {string} key - 预设键名
* @returns {string} 格式化后的键名
*/
function formatPresetKey(key) {
return key.split('_').map(word =>
word.charAt(0).toUpperCase() + word.slice(1)
).join(' ');
}
/**
* 移除预设参数
* @param {string} key - 要移除的预设键名
*/
window.removePreset = async function(key) {
const filePath = document.querySelector('#loraModal .modal-content')
.querySelector('.file-path').textContent +
document.querySelector('#loraModal .modal-content')
.querySelector('#file-name').textContent + '.safetensors';
const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
const currentPresets = parsePresets(loraCard.dataset.usage_tips);
delete currentPresets[key];
const newPresetsJson = JSON.stringify(currentPresets);
await saveModelMetadata(filePath, {
usage_tips: newPresetsJson
});
loraCard.dataset.usage_tips = newPresetsJson;
document.querySelector('.preset-tags').innerHTML = renderPresetTags(currentPresets);
};

View File

@@ -0,0 +1,234 @@
/**
* RecipeTab - Handles the recipes tab in the Lora Modal
*/
import { showToast } from '../../utils/uiHelpers.js';
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
/**
* Loads recipes that use the specified Lora and renders them in the tab
* @param {string} loraName - The display name of the Lora
* @param {string} sha256 - The SHA256 hash of the Lora
*/
export function loadRecipesForLora(loraName, sha256) {
const recipeTab = document.getElementById('recipes-tab');
if (!recipeTab) return;
// Show loading state
recipeTab.innerHTML = `
<div class="recipes-loading">
<i class="fas fa-spinner fa-spin"></i> Loading recipes...
</div>
`;
// Fetch recipes that use this Lora by hash
fetch(`/api/recipes/for-lora?hash=${encodeURIComponent(sha256.toLowerCase())}`)
.then(response => response.json())
.then(data => {
if (!data.success) {
throw new Error(data.error || 'Failed to load recipes');
}
renderRecipes(recipeTab, data.recipes, loraName, sha256);
})
.catch(error => {
console.error('Error loading recipes for Lora:', error);
recipeTab.innerHTML = `
<div class="recipes-error">
<i class="fas fa-exclamation-circle"></i>
<p>Failed to load recipes. Please try again later.</p>
</div>
`;
});
}
/**
* Renders the recipe cards in the tab
* @param {HTMLElement} tabElement - The tab element to render into
* @param {Array} recipes - Array of recipe objects
* @param {string} loraName - The display name of the Lora
* @param {string} loraHash - The hash of the Lora
*/
function renderRecipes(tabElement, recipes, loraName, loraHash) {
if (!recipes || recipes.length === 0) {
tabElement.innerHTML = `
<div class="recipes-empty">
<i class="fas fa-book-open"></i>
<p>No recipes found that use this Lora.</p>
</div>
`;
return;
}
// Create header with count and view all button
const headerElement = document.createElement('div');
headerElement.className = 'recipes-header';
headerElement.innerHTML = `
<h3>Found ${recipes.length} recipe${recipes.length > 1 ? 's' : ''} using this Lora</h3>
<button class="view-all-btn" title="View all in Recipes page">
<i class="fas fa-external-link-alt"></i> View All in Recipes
</button>
`;
// Add click handler for "View All" button
headerElement.querySelector('.view-all-btn').addEventListener('click', () => {
navigateToRecipesPage(loraName, loraHash);
});
// Create grid container for recipe cards
const cardGrid = document.createElement('div');
cardGrid.className = 'card-grid';
// Create recipe cards matching the structure in recipes.html
recipes.forEach(recipe => {
// Get basic info
const baseModel = recipe.base_model || '';
const loras = recipe.loras || [];
const lorasCount = loras.length;
const missingLorasCount = loras.filter(lora => !lora.inLibrary && !lora.isDeleted).length;
const allLorasAvailable = missingLorasCount === 0 && lorasCount > 0;
// Ensure file_url exists, fallback to file_path if needed
const imageUrl = recipe.file_url ||
(recipe.file_path ? `/loras_static/root1/preview/${recipe.file_path.split('/').pop()}` :
'/loras_static/images/no-preview.png');
// Create card element matching the structure in recipes.html
const card = document.createElement('div');
card.className = 'lora-card';
card.dataset.filePath = recipe.file_path || '';
card.dataset.title = recipe.title || '';
card.dataset.created = recipe.created_date || '';
card.dataset.id = recipe.id || '';
card.innerHTML = `
<div class="recipe-indicator" title="Recipe">R</div>
<div class="card-preview">
<img src="${imageUrl}" alt="${recipe.title}" loading="lazy">
<div class="card-header">
<div class="base-model-wrapper">
${baseModel ? `<span class="base-model-label" title="${baseModel}">${baseModel}</span>` : ''}
</div>
<div class="card-actions">
<i class="fas fa-copy" title="Copy Recipe Syntax"></i>
</div>
</div>
<div class="card-footer">
<div class="model-info">
<span class="model-name">${recipe.title}</span>
</div>
<div class="lora-count ${allLorasAvailable ? 'ready' : (lorasCount > 0 ? 'missing' : '')}"
title="${getLoraStatusTitle(lorasCount, missingLorasCount)}">
<i class="fas fa-layer-group"></i> ${lorasCount}
</div>
</div>
</div>
`;
// Add event listeners for action buttons
card.querySelector('.fa-copy').addEventListener('click', (e) => {
e.stopPropagation();
copyRecipeSyntax(recipe.id);
});
// Add click handler for the entire card
card.addEventListener('click', () => {
navigateToRecipeDetails(recipe.id);
});
// Add card to grid
cardGrid.appendChild(card);
});
// Clear loading indicator and append content
tabElement.innerHTML = '';
tabElement.appendChild(headerElement);
tabElement.appendChild(cardGrid);
}
/**
* Returns a descriptive title for the LoRA status indicator
* @param {number} totalCount - Total number of LoRAs in recipe
* @param {number} missingCount - Number of missing LoRAs
* @returns {string} Status title text
*/
function getLoraStatusTitle(totalCount, missingCount) {
if (totalCount === 0) return "No LoRAs in this recipe";
if (missingCount === 0) return "All LoRAs available - Ready to use";
return `${missingCount} of ${totalCount} LoRAs missing`;
}
/**
* Copies recipe syntax to clipboard
* @param {string} recipeId - The recipe ID
*/
function copyRecipeSyntax(recipeId) {
if (!recipeId) {
showToast('Cannot copy recipe syntax: Missing recipe ID', 'error');
return;
}
fetch(`/api/recipe/${recipeId}/syntax`)
.then(response => response.json())
.then(data => {
if (data.success && data.syntax) {
return navigator.clipboard.writeText(data.syntax);
} else {
throw new Error(data.error || 'No syntax returned');
}
})
.then(() => {
showToast('Recipe syntax copied to clipboard', 'success');
})
.catch(err => {
console.error('Failed to copy: ', err);
showToast('Failed to copy recipe syntax', 'error');
});
}
/**
* Navigates to the recipes page with filter for the current Lora
* @param {string} loraName - The Lora display name to filter by
* @param {string} loraHash - The hash of the Lora to filter by
* @param {boolean} createNew - Whether to open the create recipe dialog
*/
function navigateToRecipesPage(loraName, loraHash) {
// Close the current modal
if (window.modalManager) {
modalManager.closeModal('loraModal');
}
// Clear any previous filters first
removeSessionItem('lora_to_recipe_filterLoraName');
removeSessionItem('lora_to_recipe_filterLoraHash');
removeSessionItem('viewRecipeId');
// Store the LoRA name and hash filter in sessionStorage
setSessionItem('lora_to_recipe_filterLoraName', loraName);
setSessionItem('lora_to_recipe_filterLoraHash', loraHash);
// Directly navigate to recipes page
window.location.href = '/loras/recipes';
}
/**
* Navigates directly to a specific recipe's details
* @param {string} recipeId - The recipe ID to view
*/
function navigateToRecipeDetails(recipeId) {
// Close the current modal
if (window.modalManager) {
modalManager.closeModal('loraModal');
}
// Clear any previous filters first
removeSessionItem('filterLoraName');
removeSessionItem('filterLoraHash');
removeSessionItem('viewRecipeId');
// Store the recipe ID in sessionStorage to load on recipes page
setSessionItem('viewRecipeId', recipeId);
// Directly navigate to recipes page
window.location.href = '/loras/recipes';
}

View File

@@ -0,0 +1,501 @@
/**
* ShowcaseView.js
* 处理LoRA模型展示内容图片、视频的功能模块
*/
import { showToast } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
/**
* 渲染展示内容
* @param {Array} images - 要展示的图片/视频数组
* @returns {string} HTML内容
*/
export function renderShowcaseContent(images) {
if (!images?.length) return '<div class="no-examples">No example images available</div>';
// Filter images based on SFW setting
const showOnlySFW = state.settings.show_only_sfw;
let filteredImages = images;
let hiddenCount = 0;
if (showOnlySFW) {
filteredImages = images.filter(img => {
const nsfwLevel = img.nsfwLevel !== undefined ? img.nsfwLevel : 0;
const isSfw = nsfwLevel < NSFW_LEVELS.R;
if (!isSfw) hiddenCount++;
return isSfw;
});
}
// Show message if no images are available after filtering
if (filteredImages.length === 0) {
return `
<div class="no-examples">
<p>All example images are filtered due to NSFW content settings</p>
<p class="nsfw-filter-info">Your settings are currently set to show only safe-for-work content</p>
<p>You can change this in Settings <i class="fas fa-cog"></i></p>
</div>
`;
}
// Show hidden content notification if applicable
const hiddenNotification = hiddenCount > 0 ?
`<div class="nsfw-filter-notification">
<i class="fas fa-eye-slash"></i> ${hiddenCount} ${hiddenCount === 1 ? 'image' : 'images'} hidden due to SFW-only setting
</div>` : '';
return `
<div class="scroll-indicator" onclick="toggleShowcase(this)">
<i class="fas fa-chevron-down"></i>
<span>Scroll or click to show ${filteredImages.length} examples</span>
</div>
<div class="carousel collapsed">
${hiddenNotification}
<div class="carousel-container">
${filteredImages.map(img => {
// 计算适当的展示高度:
// 1. 保持原始宽高比
// 2. 限制最大高度为视窗高度的60%
// 3. 确保最小高度为容器宽度的40%
const aspectRatio = (img.height / img.width) * 100;
const containerWidth = 800; // modal content的最大宽度
const minHeightPercent = 40; // 最小高度为容器宽度的40%
const maxHeightPercent = (window.innerHeight * 0.6 / containerWidth) * 100;
const heightPercent = Math.max(
minHeightPercent,
Math.min(maxHeightPercent, aspectRatio)
);
// Check if image should be blurred
const nsfwLevel = img.nsfwLevel !== undefined ? img.nsfwLevel : 0;
const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13;
// Determine NSFW warning text based on level
let nsfwText = "Mature Content";
if (nsfwLevel >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (nsfwLevel >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (nsfwLevel >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
// Extract metadata from the image
const meta = img.meta || {};
const prompt = meta.prompt || '';
const negativePrompt = meta.negative_prompt || meta.negativePrompt || '';
const size = meta.Size || `${img.width}x${img.height}`;
const seed = meta.seed || '';
const model = meta.Model || '';
const steps = meta.steps || '';
const sampler = meta.sampler || '';
const cfgScale = meta.cfgScale || '';
const clipSkip = meta.clipSkip || '';
// Check if we have any meaningful generation parameters
const hasParams = seed || model || steps || sampler || cfgScale || clipSkip;
const hasPrompts = prompt || negativePrompt;
// If no metadata available, show a message
if (!hasParams && !hasPrompts) {
const metadataPanel = `
<div class="image-metadata-panel">
<div class="metadata-content">
<div class="no-metadata-message">
<i class="fas fa-info-circle"></i>
<span>No generation parameters available</span>
</div>
</div>
</div>
`;
if (img.type === 'video') {
return generateVideoWrapper(img, heightPercent, shouldBlur, nsfwText, metadataPanel);
}
return generateImageWrapper(img, heightPercent, shouldBlur, nsfwText, metadataPanel);
}
// Create a data attribute with the prompt for copying instead of trying to handle it in the onclick
// This avoids issues with quotes and special characters
const promptIndex = Math.random().toString(36).substring(2, 15);
const negPromptIndex = Math.random().toString(36).substring(2, 15);
// Create parameter tags HTML
const paramTags = `
<div class="params-tags">
${size ? `<div class="param-tag"><span class="param-name">Size:</span><span class="param-value">${size}</span></div>` : ''}
${seed ? `<div class="param-tag"><span class="param-name">Seed:</span><span class="param-value">${seed}</span></div>` : ''}
${model ? `<div class="param-tag"><span class="param-name">Model:</span><span class="param-value">${model}</span></div>` : ''}
${steps ? `<div class="param-tag"><span class="param-name">Steps:</span><span class="param-value">${steps}</span></div>` : ''}
${sampler ? `<div class="param-tag"><span class="param-name">Sampler:</span><span class="param-value">${sampler}</span></div>` : ''}
${cfgScale ? `<div class="param-tag"><span class="param-name">CFG:</span><span class="param-value">${cfgScale}</span></div>` : ''}
${clipSkip ? `<div class="param-tag"><span class="param-name">Clip Skip:</span><span class="param-value">${clipSkip}</span></div>` : ''}
</div>
`;
// Metadata panel HTML
const metadataPanel = `
<div class="image-metadata-panel">
<div class="metadata-content">
${hasParams ? paramTags : ''}
${!hasParams && !hasPrompts ? `
<div class="no-metadata-message">
<i class="fas fa-info-circle"></i>
<span>No generation parameters available</span>
</div>
` : ''}
${prompt ? `
<div class="metadata-row prompt-row">
<span class="metadata-label">Prompt:</span>
<div class="metadata-prompt-wrapper">
<div class="metadata-prompt">${prompt}</div>
<button class="copy-prompt-btn" data-prompt-index="${promptIndex}">
<i class="fas fa-copy"></i>
</button>
</div>
</div>
<div class="hidden-prompt" id="prompt-${promptIndex}" style="display:none;">${prompt}</div>
` : ''}
${negativePrompt ? `
<div class="metadata-row prompt-row">
<span class="metadata-label">Negative Prompt:</span>
<div class="metadata-prompt-wrapper">
<div class="metadata-prompt">${negativePrompt}</div>
<button class="copy-prompt-btn" data-prompt-index="${negPromptIndex}">
<i class="fas fa-copy"></i>
</button>
</div>
</div>
<div class="hidden-prompt" id="prompt-${negPromptIndex}" style="display:none;">${negativePrompt}</div>
` : ''}
</div>
</div>
`;
if (img.type === 'video') {
return generateVideoWrapper(img, heightPercent, shouldBlur, nsfwText, metadataPanel);
}
return generateImageWrapper(img, heightPercent, shouldBlur, nsfwText, metadataPanel);
}).join('')}
</div>
</div>
`;
}
/**
* 生成视频包装HTML
*/
function generateVideoWrapper(img, heightPercent, shouldBlur, nsfwText, metadataPanel) {
return `
<div class="media-wrapper ${shouldBlur ? 'nsfw-media-wrapper' : ''}" style="padding-bottom: ${heightPercent}%">
${shouldBlur ? `
<button class="toggle-blur-btn showcase-toggle-btn" title="Toggle blur">
<i class="fas fa-eye"></i>
</button>
` : ''}
<video controls autoplay muted loop crossorigin="anonymous"
referrerpolicy="no-referrer" data-src="${img.url}"
class="lazy ${shouldBlur ? 'blurred' : ''}">
<source data-src="${img.url}" type="video/mp4">
Your browser does not support video playback
</video>
${shouldBlur ? `
<div class="nsfw-overlay">
<div class="nsfw-warning">
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
</div>
</div>
` : ''}
${metadataPanel}
</div>
`;
}
/**
* 生成图片包装HTML
*/
function generateImageWrapper(img, heightPercent, shouldBlur, nsfwText, metadataPanel) {
return `
<div class="media-wrapper ${shouldBlur ? 'nsfw-media-wrapper' : ''}" style="padding-bottom: ${heightPercent}%">
${shouldBlur ? `
<button class="toggle-blur-btn showcase-toggle-btn" title="Toggle blur">
<i class="fas fa-eye"></i>
</button>
` : ''}
<img data-src="${img.url}"
alt="Preview"
crossorigin="anonymous"
referrerpolicy="no-referrer"
width="${img.width}"
height="${img.height}"
class="lazy ${shouldBlur ? 'blurred' : ''}">
${shouldBlur ? `
<div class="nsfw-overlay">
<div class="nsfw-warning">
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
</div>
</div>
` : ''}
${metadataPanel}
</div>
`;
}
/**
* 切换展示区域的显示状态
*/
export function toggleShowcase(element) {
const carousel = element.nextElementSibling;
const isCollapsed = carousel.classList.contains('collapsed');
const indicator = element.querySelector('span');
const icon = element.querySelector('i');
carousel.classList.toggle('collapsed');
if (isCollapsed) {
const count = carousel.querySelectorAll('.media-wrapper').length;
indicator.textContent = `Scroll or click to hide examples`;
icon.classList.replace('fa-chevron-down', 'fa-chevron-up');
initLazyLoading(carousel);
// Initialize NSFW content blur toggle handlers
initNsfwBlurHandlers(carousel);
// Initialize metadata panel interaction handlers
initMetadataPanelHandlers(carousel);
} else {
const count = carousel.querySelectorAll('.media-wrapper').length;
indicator.textContent = `Scroll or click to show ${count} examples`;
icon.classList.replace('fa-chevron-up', 'fa-chevron-down');
// Make sure any open metadata panels get closed
const carouselContainer = carousel.querySelector('.carousel-container');
if (carouselContainer) {
carouselContainer.style.height = '0';
setTimeout(() => {
carouselContainer.style.height = '';
}, 300);
}
}
}
/**
* 初始化元数据面板交互处理
*/
function initMetadataPanelHandlers(container) {
// Find all media wrappers
const mediaWrappers = container.querySelectorAll('.media-wrapper');
mediaWrappers.forEach(wrapper => {
// Get the metadata panel
const metadataPanel = wrapper.querySelector('.image-metadata-panel');
if (!metadataPanel) return;
// Prevent events from the metadata panel from bubbling
metadataPanel.addEventListener('click', (e) => {
e.stopPropagation();
});
// Handle copy prompt button clicks
const copyBtns = metadataPanel.querySelectorAll('.copy-prompt-btn');
copyBtns.forEach(copyBtn => {
const promptIndex = copyBtn.dataset.promptIndex;
const promptElement = wrapper.querySelector(`#prompt-${promptIndex}`);
copyBtn.addEventListener('click', async (e) => {
e.stopPropagation(); // Prevent bubbling
if (!promptElement) return;
try {
await navigator.clipboard.writeText(promptElement.textContent);
showToast('Prompt copied to clipboard', 'success');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
});
});
// Prevent scrolling in the metadata panel from scrolling the whole modal
metadataPanel.addEventListener('wheel', (e) => {
const isAtTop = metadataPanel.scrollTop === 0;
const isAtBottom = metadataPanel.scrollHeight - metadataPanel.scrollTop === metadataPanel.clientHeight;
// Only prevent default if scrolling would cause the panel to scroll
if ((e.deltaY < 0 && !isAtTop) || (e.deltaY > 0 && !isAtBottom)) {
e.stopPropagation();
}
}, { passive: true });
});
}
/**
* 初始化模糊切换处理
*/
function initNsfwBlurHandlers(container) {
// Handle toggle blur buttons
const toggleButtons = container.querySelectorAll('.toggle-blur-btn');
toggleButtons.forEach(btn => {
btn.addEventListener('click', (e) => {
e.stopPropagation();
const wrapper = btn.closest('.media-wrapper');
const media = wrapper.querySelector('img, video');
const isBlurred = media.classList.toggle('blurred');
const icon = btn.querySelector('i');
// Update the icon based on blur state
if (isBlurred) {
icon.className = 'fas fa-eye';
} else {
icon.className = 'fas fa-eye-slash';
}
// Toggle the overlay visibility
const overlay = wrapper.querySelector('.nsfw-overlay');
if (overlay) {
overlay.style.display = isBlurred ? 'flex' : 'none';
}
});
});
// Handle "Show" buttons in overlays
const showButtons = container.querySelectorAll('.show-content-btn');
showButtons.forEach(btn => {
btn.addEventListener('click', (e) => {
e.stopPropagation();
const wrapper = btn.closest('.media-wrapper');
const media = wrapper.querySelector('img, video');
media.classList.remove('blurred');
// Update the toggle button icon
const toggleBtn = wrapper.querySelector('.toggle-blur-btn');
if (toggleBtn) {
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
}
// Hide the overlay
const overlay = wrapper.querySelector('.nsfw-overlay');
if (overlay) {
overlay.style.display = 'none';
}
});
});
}
/**
* 初始化延迟加载
*/
function initLazyLoading(container) {
const lazyElements = container.querySelectorAll('.lazy');
const lazyLoad = (element) => {
if (element.tagName.toLowerCase() === 'video') {
element.src = element.dataset.src;
element.querySelector('source').src = element.dataset.src;
element.load();
} else {
element.src = element.dataset.src;
}
element.classList.remove('lazy');
};
const observer = new IntersectionObserver((entries) => {
entries.forEach(entry => {
if (entry.isIntersecting) {
lazyLoad(entry.target);
observer.unobserve(entry.target);
}
});
});
lazyElements.forEach(element => observer.observe(element));
}
/**
* 设置展示区域的滚动处理
*/
export function setupShowcaseScroll() {
// Add event listener to document for wheel events
document.addEventListener('wheel', (event) => {
// Find the active modal content
const modalContent = document.querySelector('#loraModal .modal-content');
if (!modalContent) return;
const showcase = modalContent.querySelector('.showcase-section');
if (!showcase) return;
const carousel = showcase.querySelector('.carousel');
const scrollIndicator = showcase.querySelector('.scroll-indicator');
if (carousel?.classList.contains('collapsed') && event.deltaY > 0) {
const isNearBottom = modalContent.scrollHeight - modalContent.scrollTop - modalContent.clientHeight < 100;
if (isNearBottom) {
toggleShowcase(scrollIndicator);
event.preventDefault();
}
}
}, { passive: false });
// Use MutationObserver instead of deprecated DOMNodeInserted
const observer = new MutationObserver((mutations) => {
for (const mutation of mutations) {
if (mutation.type === 'childList' && mutation.addedNodes.length) {
// Check if loraModal content was added
const loraModal = document.getElementById('loraModal');
if (loraModal && loraModal.querySelector('.modal-content')) {
setupBackToTopButton(loraModal.querySelector('.modal-content'));
}
}
}
});
// Start observing the document body for changes
observer.observe(document.body, { childList: true, subtree: true });
// Also try to set up the button immediately in case the modal is already open
const modalContent = document.querySelector('#loraModal .modal-content');
if (modalContent) {
setupBackToTopButton(modalContent);
}
}
/**
* 设置返回顶部按钮
*/
function setupBackToTopButton(modalContent) {
// Remove any existing scroll listeners to avoid duplicates
modalContent.onscroll = null;
// Add new scroll listener
modalContent.addEventListener('scroll', () => {
const backToTopBtn = modalContent.querySelector('.back-to-top');
if (backToTopBtn) {
if (modalContent.scrollTop > 300) {
backToTopBtn.classList.add('visible');
} else {
backToTopBtn.classList.remove('visible');
}
}
});
// Trigger a scroll event to check initial position
modalContent.dispatchEvent(new Event('scroll'));
}
/**
* 滚动到顶部
*/
export function scrollToTop(button) {
const modalContent = button.closest('.modal-content');
if (modalContent) {
modalContent.scrollTo({
top: 0,
behavior: 'smooth'
});
}
}

View File

@@ -0,0 +1,345 @@
/**
* TriggerWords.js
* 处理LoRA模型触发词相关的功能模块
*/
import { showToast } from '../../utils/uiHelpers.js';
import { saveModelMetadata } from './ModelMetadata.js';
/**
* 渲染触发词
* @param {Array} words - 触发词数组
* @param {string} filePath - 文件路径
* @returns {string} HTML内容
*/
export function renderTriggerWords(words, filePath) {
if (!words.length) return `
<div class="info-item full-width trigger-words">
<div class="trigger-words-header">
<label>Trigger Words</label>
<button class="edit-trigger-words-btn" data-file-path="${filePath}" title="Edit trigger words">
<i class="fas fa-pencil-alt"></i>
</button>
</div>
<div class="trigger-words-content">
<span class="no-trigger-words">No trigger word needed</span>
<div class="trigger-words-tags" style="display:none;"></div>
</div>
<div class="trigger-words-edit-controls" style="display:none;">
<button class="add-trigger-word-btn" title="Add a trigger word">
<i class="fas fa-plus"></i> Add
</button>
<button class="save-trigger-words-btn" title="Save changes">
<i class="fas fa-save"></i> Save
</button>
</div>
<div class="add-trigger-word-form" style="display:none;">
<input type="text" class="new-trigger-word-input" placeholder="Enter trigger word">
<button class="confirm-add-trigger-word-btn">Add</button>
<button class="cancel-add-trigger-word-btn">Cancel</button>
</div>
</div>
`;
return `
<div class="info-item full-width trigger-words">
<div class="trigger-words-header">
<label>Trigger Words</label>
<button class="edit-trigger-words-btn" data-file-path="${filePath}" title="Edit trigger words">
<i class="fas fa-pencil-alt"></i>
</button>
</div>
<div class="trigger-words-content">
<div class="trigger-words-tags">
${words.map(word => `
<div class="trigger-word-tag" data-word="${word}" onclick="copyTriggerWord('${word}')">
<span class="trigger-word-content">${word}</span>
<span class="trigger-word-copy">
<i class="fas fa-copy"></i>
</span>
<button class="delete-trigger-word-btn" style="display:none;" onclick="event.stopPropagation();">
<i class="fas fa-times"></i>
</button>
</div>
`).join('')}
</div>
</div>
<div class="trigger-words-edit-controls" style="display:none;">
<button class="add-trigger-word-btn" title="Add a trigger word">
<i class="fas fa-plus"></i> Add
</button>
<button class="save-trigger-words-btn" title="Save changes">
<i class="fas fa-save"></i> Save
</button>
</div>
<div class="add-trigger-word-form" style="display:none;">
<input type="text" class="new-trigger-word-input" placeholder="Enter trigger word">
<button class="confirm-add-trigger-word-btn">Add</button>
<button class="cancel-add-trigger-word-btn">Cancel</button>
</div>
</div>
`;
}
/**
* 设置触发词编辑模式
*/
export function setupTriggerWordsEditMode() {
const editBtn = document.querySelector('.edit-trigger-words-btn');
if (!editBtn) return;
editBtn.addEventListener('click', function() {
const triggerWordsSection = this.closest('.trigger-words');
const isEditMode = triggerWordsSection.classList.toggle('edit-mode');
// Toggle edit mode UI elements
const triggerWordTags = triggerWordsSection.querySelectorAll('.trigger-word-tag');
const editControls = triggerWordsSection.querySelector('.trigger-words-edit-controls');
const noTriggerWords = triggerWordsSection.querySelector('.no-trigger-words');
const tagsContainer = triggerWordsSection.querySelector('.trigger-words-tags');
if (isEditMode) {
this.innerHTML = '<i class="fas fa-times"></i>'; // Change to cancel icon
this.title = "Cancel editing";
editControls.style.display = 'flex';
// If we have no trigger words yet, hide the "No trigger word needed" text
// and show the empty tags container
if (noTriggerWords) {
noTriggerWords.style.display = 'none';
if (tagsContainer) tagsContainer.style.display = 'flex';
}
// Disable click-to-copy and show delete buttons
triggerWordTags.forEach(tag => {
tag.onclick = null;
tag.querySelector('.trigger-word-copy').style.display = 'none';
tag.querySelector('.delete-trigger-word-btn').style.display = 'block';
});
} else {
this.innerHTML = '<i class="fas fa-pencil-alt"></i>'; // Change back to edit icon
this.title = "Edit trigger words";
editControls.style.display = 'none';
// If we have no trigger words, show the "No trigger word needed" text
// and hide the empty tags container
const currentTags = triggerWordsSection.querySelectorAll('.trigger-word-tag');
if (noTriggerWords && currentTags.length === 0) {
noTriggerWords.style.display = '';
if (tagsContainer) tagsContainer.style.display = 'none';
}
// Restore original state
triggerWordTags.forEach(tag => {
const word = tag.dataset.word;
tag.onclick = () => copyTriggerWord(word);
tag.querySelector('.trigger-word-copy').style.display = 'flex';
tag.querySelector('.delete-trigger-word-btn').style.display = 'none';
});
// Hide add form if open
triggerWordsSection.querySelector('.add-trigger-word-form').style.display = 'none';
}
});
// Set up add trigger word button
const addBtn = document.querySelector('.add-trigger-word-btn');
if (addBtn) {
addBtn.addEventListener('click', function() {
const triggerWordsSection = this.closest('.trigger-words');
const addForm = triggerWordsSection.querySelector('.add-trigger-word-form');
addForm.style.display = 'flex';
addForm.querySelector('input').focus();
});
}
// Set up confirm and cancel add buttons
const confirmAddBtn = document.querySelector('.confirm-add-trigger-word-btn');
const cancelAddBtn = document.querySelector('.cancel-add-trigger-word-btn');
const triggerWordInput = document.querySelector('.new-trigger-word-input');
if (confirmAddBtn && triggerWordInput) {
confirmAddBtn.addEventListener('click', function() {
addNewTriggerWord(triggerWordInput.value);
});
// Add keydown event to input
triggerWordInput.addEventListener('keydown', function(e) {
if (e.key === 'Enter') {
e.preventDefault();
addNewTriggerWord(this.value);
}
});
}
if (cancelAddBtn) {
cancelAddBtn.addEventListener('click', function() {
const addForm = this.closest('.add-trigger-word-form');
addForm.style.display = 'none';
addForm.querySelector('input').value = '';
});
}
// Set up save button
const saveBtn = document.querySelector('.save-trigger-words-btn');
if (saveBtn) {
saveBtn.addEventListener('click', saveTriggerWords);
}
// Set up delete buttons
document.querySelectorAll('.delete-trigger-word-btn').forEach(btn => {
btn.addEventListener('click', function(e) {
e.stopPropagation();
const tag = this.closest('.trigger-word-tag');
tag.remove();
});
});
}
/**
* 添加新触发词
* @param {string} word - 要添加的触发词
*/
function addNewTriggerWord(word) {
word = word.trim();
if (!word) return;
const triggerWordsSection = document.querySelector('.trigger-words');
let tagsContainer = document.querySelector('.trigger-words-tags');
// Ensure tags container exists and is visible
if (tagsContainer) {
tagsContainer.style.display = 'flex';
} else {
// Create tags container if it doesn't exist
const contentDiv = triggerWordsSection.querySelector('.trigger-words-content');
if (contentDiv) {
tagsContainer = document.createElement('div');
tagsContainer.className = 'trigger-words-tags';
contentDiv.appendChild(tagsContainer);
}
}
if (!tagsContainer) return;
// Hide "no trigger words" message if it exists
const noTriggerWordsMsg = triggerWordsSection.querySelector('.no-trigger-words');
if (noTriggerWordsMsg) {
noTriggerWordsMsg.style.display = 'none';
}
// Validation: Check length
if (word.split(/\s+/).length > 30) {
showToast('Trigger word should not exceed 30 words', 'error');
return;
}
// Validation: Check total number
const currentTags = tagsContainer.querySelectorAll('.trigger-word-tag');
if (currentTags.length >= 10) {
showToast('Maximum 10 trigger words allowed', 'error');
return;
}
// Validation: Check for duplicates
const existingWords = Array.from(currentTags).map(tag => tag.dataset.word);
if (existingWords.includes(word)) {
showToast('This trigger word already exists', 'error');
return;
}
// Create new tag
const newTag = document.createElement('div');
newTag.className = 'trigger-word-tag';
newTag.dataset.word = word;
newTag.innerHTML = `
<span class="trigger-word-content">${word}</span>
<span class="trigger-word-copy" style="display:none;">
<i class="fas fa-copy"></i>
</span>
<button class="delete-trigger-word-btn" onclick="event.stopPropagation();">
<i class="fas fa-times"></i>
</button>
`;
// Add event listener to delete button
const deleteBtn = newTag.querySelector('.delete-trigger-word-btn');
deleteBtn.addEventListener('click', function() {
newTag.remove();
});
tagsContainer.appendChild(newTag);
// Clear and hide the input form
const triggerWordInput = document.querySelector('.new-trigger-word-input');
triggerWordInput.value = '';
document.querySelector('.add-trigger-word-form').style.display = 'none';
}
/**
* 保存触发词
*/
async function saveTriggerWords() {
const filePath = document.querySelector('.edit-trigger-words-btn').dataset.filePath;
const triggerWordTags = document.querySelectorAll('.trigger-word-tag');
const words = Array.from(triggerWordTags).map(tag => tag.dataset.word);
try {
// Special format for updating nested civitai.trainedWords
await saveModelMetadata(filePath, {
civitai: { trainedWords: words }
});
// Update UI
const editBtn = document.querySelector('.edit-trigger-words-btn');
editBtn.click(); // Exit edit mode
// Update the LoRA card's dataset
const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (loraCard) {
try {
// Create a proper structure for civitai data
let civitaiData = {};
// Parse existing data if available
if (loraCard.dataset.meta) {
civitaiData = JSON.parse(loraCard.dataset.meta);
}
// Update trainedWords property
civitaiData.trainedWords = words;
// Update the meta dataset attribute with the full civitai data
loraCard.dataset.meta = JSON.stringify(civitaiData);
} catch (e) {
console.error('Error updating civitai data:', e);
}
}
// If we saved an empty array and there's a no-trigger-words element, show it
const noTriggerWords = document.querySelector('.no-trigger-words');
const tagsContainer = document.querySelector('.trigger-words-tags');
if (words.length === 0 && noTriggerWords) {
noTriggerWords.style.display = '';
if (tagsContainer) tagsContainer.style.display = 'none';
}
showToast('Trigger words updated successfully', 'success');
} catch (error) {
console.error('Error saving trigger words:', error);
showToast('Failed to update trigger words', 'error');
}
}
/**
* 复制触发词到剪贴板
* @param {string} word - 要复制的触发词
*/
window.copyTriggerWord = async function(word) {
try {
await navigator.clipboard.writeText(word);
showToast('Trigger word copied', 'success');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
};

View File

@@ -0,0 +1,293 @@
/**
* LoraModal - 主入口点
*
* 将原始的LoraModal.js拆分成多个功能模块后的主入口文件
*/
import { showToast } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { modalManager } from '../../managers/ModalManager.js';
import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js';
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
import { renderTriggerWords, setupTriggerWordsEditMode } from './TriggerWords.js';
import { parsePresets, renderPresetTags } from './PresetTags.js';
import { loadRecipesForLora } from './RecipeTab.js'; // Add import for recipe tab
import {
setupModelNameEditing,
setupBaseModelEditing,
setupFileNameEditing,
saveModelMetadata
} from './ModelMetadata.js';
import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js';
import { updateLoraCard } from '../../utils/cardUpdater.js';
/**
* 显示LoRA模型弹窗
* @param {Object} lora - LoRA模型数据
*/
export function showLoraModal(lora) {
const escapedWords = lora.civitai?.trainedWords?.length ?
lora.civitai.trainedWords.map(word => word.replace(/'/g, '\\\'')) : [];
const content = `
<div class="modal-content">
<button class="close" onclick="modalManager.closeModal('loraModal')">&times;</button>
<header class="modal-header">
<div class="model-name-header">
<h2 class="model-name-content" contenteditable="true" spellcheck="false">${lora.model_name}</h2>
<button class="edit-model-name-btn" title="Edit model name">
<i class="fas fa-pencil-alt"></i>
</button>
</div>
${renderCompactTags(lora.tags || [])}
</header>
<div class="modal-body">
<div class="info-section">
<div class="info-grid">
<div class="info-item">
<label>Version</label>
<span>${lora.civitai.name || 'N/A'}</span>
</div>
<div class="info-item">
<label>File Name</label>
<div class="file-name-wrapper">
<span id="file-name" class="file-name-content">${lora.file_name || 'N/A'}</span>
<button class="edit-file-name-btn" title="Edit file name">
<i class="fas fa-pencil-alt"></i>
</button>
</div>
</div>
<div class="info-item location-size">
<div class="location-wrapper">
<label>Location</label>
<span class="file-path">${lora.file_path.replace(/[^/]+$/, '') || 'N/A'}</span>
</div>
</div>
<div class="info-item base-size">
<div class="base-wrapper">
<label>Base Model</label>
<div class="base-model-display">
<span class="base-model-content">${lora.base_model || 'N/A'}</span>
<button class="edit-base-model-btn" title="Edit base model">
<i class="fas fa-pencil-alt"></i>
</button>
</div>
</div>
<div class="size-wrapper">
<label>Size</label>
<span>${formatFileSize(lora.file_size)}</span>
</div>
</div>
<div class="info-item usage-tips">
<label>Usage Tips</label>
<div class="editable-field">
<div class="preset-controls">
<select id="preset-selector">
<option value="">Add preset parameter...</option>
<option value="strength_min">Strength Min</option>
<option value="strength_max">Strength Max</option>
<option value="strength">Strength</option>
<option value="clip_skip">Clip Skip</option>
</select>
<input type="number" id="preset-value" step="0.01" placeholder="Value" style="display:none;">
<button class="add-preset-btn">Add</button>
</div>
<div class="preset-tags">
${renderPresetTags(parsePresets(lora.usage_tips))}
</div>
</div>
</div>
${renderTriggerWords(escapedWords, lora.file_path)}
<div class="info-item notes">
<label>Additional Notes</label>
<div class="editable-field">
<div class="notes-content" contenteditable="true" spellcheck="false">${lora.notes || 'Add your notes here...'}</div>
<button class="save-btn" onclick="saveNotes('${lora.file_path}')">
<i class="fas fa-save"></i>
</button>
</div>
</div>
<div class="info-item full-width">
<label>About this version</label>
<div class="description-text">${lora.description || 'N/A'}</div>
</div>
</div>
</div>
<div class="showcase-section" data-lora-id="${lora.civitai?.modelId || ''}">
<div class="showcase-tabs">
<button class="tab-btn active" data-tab="showcase">Examples</button>
<button class="tab-btn" data-tab="description">Model Description</button>
<button class="tab-btn" data-tab="recipes">Recipes</button>
</div>
<div class="tab-content">
<div id="showcase-tab" class="tab-pane active">
${renderShowcaseContent(lora.civitai?.images)}
</div>
<div id="description-tab" class="tab-pane">
<div class="model-description-container">
<div class="model-description-loading">
<i class="fas fa-spinner fa-spin"></i> Loading model description...
</div>
<div class="model-description-content">
${lora.modelDescription || ''}
</div>
</div>
</div>
<div id="recipes-tab" class="tab-pane">
<div class="recipes-loading">
<i class="fas fa-spinner fa-spin"></i> Loading recipes...
</div>
</div>
</div>
<button class="back-to-top" onclick="scrollToTop(this)">
<i class="fas fa-arrow-up"></i>
</button>
</div>
</div>
</div>
`;
modalManager.showModal('loraModal', content);
setupEditableFields(lora.file_path);
setupShowcaseScroll();
setupTabSwitching();
setupTagTooltip();
setupTriggerWordsEditMode();
setupModelNameEditing(lora.file_path);
setupBaseModelEditing(lora.file_path);
setupFileNameEditing(lora.file_path);
// If we have a model ID but no description, fetch it
if (lora.civitai?.modelId && !lora.modelDescription) {
loadModelDescription(lora.civitai.modelId, lora.file_path);
}
// Load recipes for this Lora
loadRecipesForLora(lora.model_name, lora.sha256);
}
// Copy file name function
window.copyFileName = async function(fileName) {
try {
await navigator.clipboard.writeText(fileName);
showToast('File name copied', 'success');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
};
// Add save note function
window.saveNotes = async function(filePath) {
const content = document.querySelector('.notes-content').textContent;
try {
await saveModelMetadata(filePath, { notes: content });
// Update the corresponding lora card's dataset
updateLoraCard(filePath, { notes: content });
showToast('Notes saved successfully', 'success');
} catch (error) {
showToast('Failed to save notes', 'error');
}
};
function setupEditableFields(filePath) {
const editableFields = document.querySelectorAll('.editable-field [contenteditable]');
editableFields.forEach(field => {
field.addEventListener('focus', function() {
if (this.textContent === 'Add your notes here...') {
this.textContent = '';
}
});
field.addEventListener('blur', function() {
if (this.textContent.trim() === '') {
if (this.classList.contains('notes-content')) {
this.textContent = 'Add your notes here...';
}
}
});
});
const presetSelector = document.getElementById('preset-selector');
const presetValue = document.getElementById('preset-value');
const addPresetBtn = document.querySelector('.add-preset-btn');
const presetTags = document.querySelector('.preset-tags');
presetSelector.addEventListener('change', function() {
const selected = this.value;
if (selected) {
presetValue.style.display = 'inline-block';
presetValue.min = selected.includes('strength') ? -10 : 0;
presetValue.max = selected.includes('strength') ? 10 : 10;
presetValue.step = 0.5;
if (selected === 'clip_skip') {
presetValue.type = 'number';
presetValue.step = 1;
}
// Add auto-focus
setTimeout(() => presetValue.focus(), 0);
} else {
presetValue.style.display = 'none';
}
});
addPresetBtn.addEventListener('click', async function() {
const key = presetSelector.value;
const value = presetValue.value;
if (!key || !value) return;
const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
const currentPresets = parsePresets(loraCard.dataset.usage_tips);
currentPresets[key] = parseFloat(value);
const newPresetsJson = JSON.stringify(currentPresets);
await saveModelMetadata(filePath, {
usage_tips: newPresetsJson
});
// Update the card with the new usage tips
updateLoraCard(filePath, { usage_tips: newPresetsJson });
presetTags.innerHTML = renderPresetTags(currentPresets);
presetSelector.value = '';
presetValue.value = '';
presetValue.style.display = 'none';
});
// Add keydown event listeners for notes
const notesContent = document.querySelector('.notes-content');
if (notesContent) {
notesContent.addEventListener('keydown', async function(e) {
if (e.key === 'Enter') {
if (e.shiftKey) {
// Allow shift+enter for new line
return;
}
e.preventDefault();
await saveNotes(filePath);
}
});
}
// Add keydown event for preset value
presetValue.addEventListener('keydown', function(e) {
if (e.key === 'Enter') {
e.preventDefault();
addPresetBtn.click();
}
});
}
// Export functions for global access
export { toggleShowcase, scrollToTop };

View File

@@ -0,0 +1,73 @@
/**
* utils.js
* LoraModal组件的辅助函数集合
*/
import { showToast } from '../../utils/uiHelpers.js';
/**
* 格式化文件大小
* @param {number} bytes - 字节数
* @returns {string} 格式化后的文件大小
*/
export function formatFileSize(bytes) {
if (!bytes) return 'N/A';
const units = ['B', 'KB', 'MB', 'GB'];
let size = bytes;
let unitIndex = 0;
while (size >= 1024 && unitIndex < units.length - 1) {
size /= 1024;
unitIndex++;
}
return `${size.toFixed(1)} ${units[unitIndex]}`;
}
/**
* 渲染紧凑标签
* @param {Array} tags - 标签数组
* @returns {string} HTML内容
*/
export function renderCompactTags(tags) {
if (!tags || tags.length === 0) return '';
// Display up to 5 tags, with a tooltip indicator if there are more
const visibleTags = tags.slice(0, 5);
const remainingCount = Math.max(0, tags.length - 5);
return `
<div class="model-tags-container">
<div class="model-tags-compact">
${visibleTags.map(tag => `<span class="model-tag-compact">${tag}</span>`).join('')}
${remainingCount > 0 ?
`<span class="model-tag-more" data-count="${remainingCount}">+${remainingCount}</span>` :
''}
</div>
${tags.length > 0 ?
`<div class="model-tags-tooltip">
<div class="tooltip-content">
${tags.map(tag => `<span class="tooltip-tag">${tag}</span>`).join('')}
</div>
</div>` :
''}
</div>
`;
}
/**
* 设置标签提示功能
*/
export function setupTagTooltip() {
const tagsContainer = document.querySelector('.model-tags-container');
const tooltip = document.querySelector('.model-tags-tooltip');
if (tagsContainer && tooltip) {
tagsContainer.addEventListener('mouseenter', () => {
tooltip.classList.add('visible');
});
tagsContainer.addEventListener('mouseleave', () => {
tooltip.classList.remove('visible');
});
}
}

View File

@@ -4,7 +4,7 @@ import { LoadingManager } from './managers/LoadingManager.js';
import { modalManager } from './managers/ModalManager.js';
import { updateService } from './managers/UpdateService.js';
import { HeaderManager } from './components/Header.js';
import { SettingsManager } from './managers/SettingsManager.js';
import { settingsManager } from './managers/SettingsManager.js';
import { showToast, initTheme, initBackToTop, lazyLoadImages } from './utils/uiHelpers.js';
import { initializeInfiniteScroll } from './utils/infiniteScroll.js';
import { migrateStorageItems } from './utils/storageHelpers.js';
@@ -26,7 +26,7 @@ export class AppCore {
modalManager.initialize();
updateService.initialize();
window.modalManager = modalManager;
window.settingsManager = new SettingsManager();
window.settingsManager = settingsManager;
// Initialize UI components
window.headerManager = new HeaderManager();
@@ -76,4 +76,4 @@ document.addEventListener('DOMContentLoaded', () => {
export const appCore = new AppCore();
// Export common utilities for global use
export { showToast, lazyLoadImages, initializeInfiniteScroll };
export { showToast, lazyLoadImages, initializeInfiniteScroll };

View File

@@ -1,23 +1,14 @@
import { appCore } from './core.js';
import { state } from './state/index.js';
import { showLoraModal, toggleShowcase, scrollToTop } from './components/LoraModal.js';
import { loadMoreLoras, fetchCivitai, deleteModel, replacePreview, resetAndReload, refreshLoras } from './api/loraApi.js';
import {
restoreFolderFilter,
toggleFolder,
copyTriggerWord,
openCivitai,
toggleFolderTags,
initFolderTagsVisibility,
} from './utils/uiHelpers.js';
import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js';
import { DownloadManager } from './managers/DownloadManager.js';
import { toggleApiKeyVisibility } from './managers/SettingsManager.js';
import { LoraContextMenu } from './components/ContextMenu.js';
import { moveManager } from './managers/MoveManager.js';
import { showLoraModal, toggleShowcase, scrollToTop } from './components/loraModal/index.js';
import { loadMoreLoras } from './api/loraApi.js';
import { updateCardsForBulkMode } from './components/LoraCard.js';
import { bulkManager } from './managers/BulkManager.js';
import { setStorageItem, getStorageItem } from './utils/storageHelpers.js';
import { DownloadManager } from './managers/DownloadManager.js';
import { moveManager } from './managers/MoveManager.js';
import { LoraContextMenu } from './components/ContextMenu.js';
import { createPageControls } from './components/controls/index.js';
import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js';
// Initialize the LoRA page
class LoraPageManager {
@@ -29,25 +20,21 @@ class LoraPageManager {
// Initialize managers
this.downloadManager = new DownloadManager();
// Expose necessary functions to the page
this._exposeGlobalFunctions();
// Initialize page controls
this.pageControls = createPageControls('loras');
// Expose necessary functions to the page that still need global access
// These will be refactored in future updates
this._exposeRequiredGlobalFunctions();
}
_exposeGlobalFunctions() {
// Only expose what's needed for the page
_exposeRequiredGlobalFunctions() {
// Only expose what's still needed globally
// Most functionality is now handled by the PageControls component
window.loadMoreLoras = loadMoreLoras;
window.fetchCivitai = fetchCivitai;
window.deleteModel = deleteModel;
window.replacePreview = replacePreview;
window.toggleFolder = toggleFolder;
window.copyTriggerWord = copyTriggerWord;
window.showLoraModal = showLoraModal;
window.confirmDelete = confirmDelete;
window.closeDeleteModal = closeDeleteModal;
window.refreshLoras = refreshLoras;
window.openCivitai = openCivitai;
window.toggleFolderTags = toggleFolderTags;
window.toggleApiKeyVisibility = toggleApiKeyVisibility;
window.downloadManager = this.downloadManager;
window.moveManager = moveManager;
window.toggleShowcase = toggleShowcase;
@@ -64,9 +51,8 @@ class LoraPageManager {
async initialize() {
// Initialize page-specific components
this.initEventListeners();
restoreFolderFilter();
initFolderTagsVisibility();
this.pageControls.restoreFolderFilter();
this.pageControls.initFolderTagsVisibility();
new LoraContextMenu();
// Initialize cards for current bulk mode state (should be false initially)
@@ -78,38 +64,6 @@ class LoraPageManager {
// Initialize common page features (lazy loading, infinite scroll)
appCore.initializePageFeatures();
}
loadSortPreference() {
const savedSort = getStorageItem('loras_sort');
if (savedSort) {
state.sortBy = savedSort;
const sortSelect = document.getElementById('sortSelect');
if (sortSelect) {
sortSelect.value = savedSort;
}
}
}
saveSortPreference(sortValue) {
setStorageItem('loras_sort', sortValue);
}
initEventListeners() {
const sortSelect = document.getElementById('sortSelect');
if (sortSelect) {
sortSelect.value = state.sortBy;
this.loadSortPreference();
sortSelect.addEventListener('change', async (e) => {
state.sortBy = e.target.value;
this.saveSortPreference(e.target.value);
await resetAndReload();
});
}
document.querySelectorAll('.folder-tags .tag').forEach(tag => {
tag.addEventListener('click', toggleFolder);
});
}
}
// Initialize everything when DOM is ready

View File

@@ -0,0 +1,441 @@
import { modalManager } from './ModalManager.js';
import { showToast } from '../utils/uiHelpers.js';
import { LoadingManager } from './LoadingManager.js';
import { state } from '../state/index.js';
import { resetAndReload } from '../api/checkpointApi.js';
import { getStorageItem, setStorageItem } from '../utils/storageHelpers.js';
export class CheckpointDownloadManager {
constructor() {
this.currentVersion = null;
this.versions = [];
this.modelInfo = null;
this.modelVersionId = null;
this.initialized = false;
this.selectedFolder = '';
this.loadingManager = new LoadingManager();
this.folderClickHandler = null;
this.updateTargetPath = this.updateTargetPath.bind(this);
}
showDownloadModal() {
console.log('Showing checkpoint download modal...');
if (!this.initialized) {
const modal = document.getElementById('checkpointDownloadModal');
if (!modal) {
console.error('Checkpoint download modal element not found');
return;
}
this.initialized = true;
}
modalManager.showModal('checkpointDownloadModal', null, () => {
// Cleanup handler when modal closes
this.cleanupFolderBrowser();
});
this.resetSteps();
}
resetSteps() {
document.querySelectorAll('#checkpointDownloadModal .download-step').forEach(step => step.style.display = 'none');
document.getElementById('cpUrlStep').style.display = 'block';
document.getElementById('checkpointUrl').value = '';
document.getElementById('cpUrlError').textContent = '';
// Clear new folder input
const newFolderInput = document.getElementById('cpNewFolder');
if (newFolderInput) {
newFolderInput.value = '';
}
this.currentVersion = null;
this.versions = [];
this.modelInfo = null;
this.modelVersionId = null;
// Clear selected folder and remove selection from UI
this.selectedFolder = '';
const folderBrowser = document.getElementById('cpFolderBrowser');
if (folderBrowser) {
folderBrowser.querySelectorAll('.folder-item').forEach(f =>
f.classList.remove('selected'));
}
}
async validateAndFetchVersions() {
const url = document.getElementById('checkpointUrl').value.trim();
const errorElement = document.getElementById('cpUrlError');
try {
this.loadingManager.showSimpleLoading('Fetching model versions...');
const modelId = this.extractModelId(url);
if (!modelId) {
throw new Error('Invalid Civitai URL format');
}
const response = await fetch(`/api/checkpoints/civitai/versions/${modelId}`);
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
throw new Error('This model is not a Checkpoint. Please switch to the LoRAs page to download LoRA models.');
}
throw new Error('Failed to fetch model versions');
}
this.versions = await response.json();
if (!this.versions.length) {
throw new Error('No versions available for this model');
}
// If we have a version ID from URL, pre-select it
if (this.modelVersionId) {
this.currentVersion = this.versions.find(v => v.id.toString() === this.modelVersionId);
}
this.showVersionStep();
} catch (error) {
errorElement.textContent = error.message;
} finally {
this.loadingManager.hide();
}
}
extractModelId(url) {
const modelMatch = url.match(/civitai\.com\/models\/(\d+)/);
const versionMatch = url.match(/modelVersionId=(\d+)/);
if (modelMatch) {
this.modelVersionId = versionMatch ? versionMatch[1] : null;
return modelMatch[1];
}
return null;
}
showVersionStep() {
document.getElementById('cpUrlStep').style.display = 'none';
document.getElementById('cpVersionStep').style.display = 'block';
const versionList = document.getElementById('cpVersionList');
versionList.innerHTML = this.versions.map(version => {
const firstImage = version.images?.find(img => !img.url.endsWith('.mp4'));
const thumbnailUrl = firstImage ? firstImage.url : '/loras_static/images/no-preview.png';
// Use version-level size or fallback to first file
const fileSize = version.modelSizeKB ?
(version.modelSizeKB / 1024).toFixed(2) :
(version.files[0]?.sizeKB / 1024).toFixed(2);
// Use version-level existsLocally flag
const existsLocally = version.existsLocally;
const localPath = version.localPath;
// Check if this is an early access version
const isEarlyAccess = version.availability === 'EarlyAccess';
// Create early access badge if needed
let earlyAccessBadge = '';
if (isEarlyAccess) {
earlyAccessBadge = `
<div class="early-access-badge" title="Early access required">
<i class="fas fa-clock"></i> Early Access
</div>
`;
}
// Status badge for local models
const localStatus = existsLocally ?
`<div class="local-badge">
<i class="fas fa-check"></i> In Library
<div class="local-path">${localPath || ''}</div>
</div>` : '';
return `
<div class="version-item ${this.currentVersion?.id === version.id ? 'selected' : ''}
${existsLocally ? 'exists-locally' : ''}
${isEarlyAccess ? 'is-early-access' : ''}"
onclick="checkpointDownloadManager.selectVersion('${version.id}')">
<div class="version-thumbnail">
<img src="${thumbnailUrl}" alt="Version preview">
</div>
<div class="version-content">
<div class="version-header">
<h3>${version.name}</h3>
${localStatus}
</div>
<div class="version-info">
${version.baseModel ? `<div class="base-model">${version.baseModel}</div>` : ''}
${earlyAccessBadge}
</div>
<div class="version-meta">
<span><i class="fas fa-calendar"></i> ${new Date(version.createdAt).toLocaleDateString()}</span>
<span><i class="fas fa-file-archive"></i> ${fileSize} MB</span>
</div>
</div>
</div>
`;
}).join('');
// Update Next button state based on initial selection
this.updateNextButtonState();
}
selectVersion(versionId) {
this.currentVersion = this.versions.find(v => v.id.toString() === versionId.toString());
if (!this.currentVersion) return;
document.querySelectorAll('#cpVersionList .version-item').forEach(item => {
item.classList.toggle('selected', item.querySelector('h3').textContent === this.currentVersion.name);
});
// Update Next button state after selection
this.updateNextButtonState();
}
updateNextButtonState() {
const nextButton = document.querySelector('#cpVersionStep .primary-btn');
if (!nextButton) return;
const existsLocally = this.currentVersion?.existsLocally;
if (existsLocally) {
nextButton.disabled = true;
nextButton.classList.add('disabled');
nextButton.textContent = 'Already in Library';
} else {
nextButton.disabled = false;
nextButton.classList.remove('disabled');
nextButton.textContent = 'Next';
}
}
async proceedToLocation() {
if (!this.currentVersion) {
showToast('Please select a version', 'error');
return;
}
// Double-check if the version exists locally
const existsLocally = this.currentVersion.existsLocally;
if (existsLocally) {
showToast('This version already exists in your library', 'info');
return;
}
document.getElementById('cpVersionStep').style.display = 'none';
document.getElementById('cpLocationStep').style.display = 'block';
try {
// Use checkpoint roots endpoint instead of lora roots
const response = await fetch('/api/checkpoints/roots');
if (!response.ok) {
throw new Error('Failed to fetch checkpoint roots');
}
const data = await response.json();
const checkpointRoot = document.getElementById('checkpointRoot');
checkpointRoot.innerHTML = data.roots.map(root =>
`<option value="${root}">${root}</option>`
).join('');
// Set default checkpoint root if available
const defaultRoot = getStorageItem('settings', {}).default_checkpoints_root;
if (defaultRoot && data.roots.includes(defaultRoot)) {
checkpointRoot.value = defaultRoot;
}
// Initialize folder browser after loading roots
this.initializeFolderBrowser();
} catch (error) {
showToast(error.message, 'error');
}
}
backToUrl() {
document.getElementById('cpVersionStep').style.display = 'none';
document.getElementById('cpUrlStep').style.display = 'block';
}
backToVersions() {
document.getElementById('cpLocationStep').style.display = 'none';
document.getElementById('cpVersionStep').style.display = 'block';
}
async startDownload() {
const checkpointRoot = document.getElementById('checkpointRoot').value;
const newFolder = document.getElementById('cpNewFolder').value.trim();
if (!checkpointRoot) {
showToast('Please select a checkpoint root directory', 'error');
return;
}
// Construct relative path
let targetFolder = '';
if (this.selectedFolder) {
targetFolder = this.selectedFolder;
}
if (newFolder) {
targetFolder = targetFolder ?
`${targetFolder}/${newFolder}` : newFolder;
}
try {
const downloadUrl = this.currentVersion.downloadUrl;
if (!downloadUrl) {
throw new Error('No download URL available');
}
// Show enhanced loading with progress details
const updateProgress = this.loadingManager.showDownloadProgress(1);
updateProgress(0, 0, this.currentVersion.name);
// Setup WebSocket for progress updates using checkpoint-specific endpoint
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/checkpoint-progress`);
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.status === 'progress') {
// Update progress display with current progress
updateProgress(data.progress, 0, this.currentVersion.name);
// Add more detailed status messages based on progress
if (data.progress < 3) {
this.loadingManager.setStatus(`Preparing download...`);
} else if (data.progress === 3) {
this.loadingManager.setStatus(`Downloaded preview image`);
} else if (data.progress > 3 && data.progress < 100) {
this.loadingManager.setStatus(`Downloading checkpoint file`);
} else {
this.loadingManager.setStatus(`Finalizing download...`);
}
}
};
ws.onerror = (error) => {
console.error('WebSocket error:', error);
// Continue with download even if WebSocket fails
};
// Start download using checkpoint download endpoint
const response = await fetch('/api/checkpoints/download', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
download_url: downloadUrl,
checkpoint_root: checkpointRoot,
relative_path: targetFolder
})
});
if (!response.ok) {
throw new Error(await response.text());
}
showToast('Download completed successfully', 'success');
modalManager.closeModal('checkpointDownloadModal');
// Update state specifically for the checkpoints page
state.pages.checkpoints.activeFolder = targetFolder;
// Save the active folder preference to storage
setStorageItem('checkpoints_activeFolder', targetFolder);
// Update UI to show the folder as selected
document.querySelectorAll('.folder-tags .tag').forEach(tag => {
const isActive = tag.dataset.folder === targetFolder;
tag.classList.toggle('active', isActive);
if (isActive && !tag.parentNode.classList.contains('collapsed')) {
// Scroll the tag into view if folder tags are not collapsed
tag.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}
});
await resetAndReload(true); // Pass true to update folders
} catch (error) {
showToast(error.message, 'error');
} finally {
this.loadingManager.hide();
}
}
initializeFolderBrowser() {
const folderBrowser = document.getElementById('cpFolderBrowser');
if (!folderBrowser) return;
// Cleanup existing handler if any
this.cleanupFolderBrowser();
// Create new handler
this.folderClickHandler = (event) => {
const folderItem = event.target.closest('.folder-item');
if (!folderItem) return;
if (folderItem.classList.contains('selected')) {
folderItem.classList.remove('selected');
this.selectedFolder = '';
} else {
folderBrowser.querySelectorAll('.folder-item').forEach(f =>
f.classList.remove('selected'));
folderItem.classList.add('selected');
this.selectedFolder = folderItem.dataset.folder;
}
// Update path display after folder selection
this.updateTargetPath();
};
// Add the new handler
folderBrowser.addEventListener('click', this.folderClickHandler);
// Add event listeners for path updates
const checkpointRoot = document.getElementById('checkpointRoot');
const newFolder = document.getElementById('cpNewFolder');
checkpointRoot.addEventListener('change', this.updateTargetPath);
newFolder.addEventListener('input', this.updateTargetPath);
// Update initial path
this.updateTargetPath();
}
cleanupFolderBrowser() {
if (this.folderClickHandler) {
const folderBrowser = document.getElementById('cpFolderBrowser');
if (folderBrowser) {
folderBrowser.removeEventListener('click', this.folderClickHandler);
this.folderClickHandler = null;
}
}
// Remove path update listeners
const checkpointRoot = document.getElementById('checkpointRoot');
const newFolder = document.getElementById('cpNewFolder');
if (checkpointRoot) checkpointRoot.removeEventListener('change', this.updateTargetPath);
if (newFolder) newFolder.removeEventListener('input', this.updateTargetPath);
}
updateTargetPath() {
const pathDisplay = document.getElementById('cpTargetPathDisplay');
const checkpointRoot = document.getElementById('checkpointRoot').value;
const newFolder = document.getElementById('cpNewFolder').value.trim();
let fullPath = checkpointRoot || 'Select a checkpoint root directory';
if (checkpointRoot) {
if (this.selectedFolder) {
fullPath += '/' + this.selectedFolder;
}
if (newFolder) {
fullPath += '/' + newFolder;
}
}
pathDisplay.innerHTML = `<span class="path-text">${fullPath}</span>`;
}
}

View File

@@ -1,150 +0,0 @@
/**
* CheckpointSearchManager - Specialized search manager for the Checkpoints page
* Extends the base SearchManager with checkpoint-specific functionality
*/
import { SearchManager } from './SearchManager.js';
import { state } from '../state/index.js';
import { showToast } from '../utils/uiHelpers.js';
export class CheckpointSearchManager extends SearchManager {
constructor(options = {}) {
super({
page: 'checkpoints',
...options
});
this.currentSearchTerm = '';
// Store this instance in the state
if (state) {
state.searchManager = this;
}
}
async performSearch() {
const searchTerm = this.searchInput.value.trim().toLowerCase();
if (searchTerm === this.currentSearchTerm && !this.isSearching) {
return; // Avoid duplicate searches
}
this.currentSearchTerm = searchTerm;
const grid = document.getElementById('checkpointGrid');
if (!searchTerm) {
if (state) {
state.currentPage = 1;
}
this.resetAndReloadCheckpoints();
return;
}
try {
this.isSearching = true;
if (state && state.loadingManager) {
state.loadingManager.showSimpleLoading('Searching checkpoints...');
}
// Store current scroll position
const scrollPosition = window.pageYOffset || document.documentElement.scrollTop;
if (state) {
state.currentPage = 1;
state.hasMore = true;
}
const url = new URL('/api/checkpoints', window.location.origin);
url.searchParams.set('page', '1');
url.searchParams.set('page_size', '20');
url.searchParams.set('sort_by', state ? state.sortBy : 'name');
url.searchParams.set('search', searchTerm);
url.searchParams.set('fuzzy', 'true');
// Add search options
const searchOptions = this.getActiveSearchOptions();
url.searchParams.set('search_filename', searchOptions.filename.toString());
url.searchParams.set('search_modelname', searchOptions.modelname.toString());
// Always send folder parameter if there is an active folder
if (state && state.activeFolder) {
url.searchParams.set('folder', state.activeFolder);
// Add recursive parameter when recursive search is enabled
const recursive = this.recursiveSearchToggle ? this.recursiveSearchToggle.checked : false;
url.searchParams.set('recursive', recursive.toString());
}
const response = await fetch(url);
if (!response.ok) {
throw new Error('Search failed');
}
const data = await response.json();
if (searchTerm === this.currentSearchTerm && grid) {
grid.innerHTML = '';
if (data.items.length === 0) {
grid.innerHTML = '<div class="no-results">No matching checkpoints found</div>';
if (state) {
state.hasMore = false;
}
} else {
this.appendCheckpointCards(data.items);
if (state) {
state.hasMore = state.currentPage < data.total_pages;
state.currentPage++;
}
}
// Restore scroll position after content is loaded
setTimeout(() => {
window.scrollTo({
top: scrollPosition,
behavior: 'instant' // Use 'instant' to prevent animation
});
}, 10);
}
} catch (error) {
console.error('Checkpoint search error:', error);
showToast('Checkpoint search failed', 'error');
} finally {
this.isSearching = false;
if (state && state.loadingManager) {
state.loadingManager.hide();
}
}
}
resetAndReloadCheckpoints() {
// This function would be implemented in the checkpoints page
if (typeof window.loadCheckpoints === 'function') {
window.loadCheckpoints();
} else {
// Fallback to reloading the page
window.location.reload();
}
}
appendCheckpointCards(checkpoints) {
// This function would be implemented in the checkpoints page
const grid = document.getElementById('checkpointGrid');
if (!grid) return;
if (typeof window.appendCheckpointCards === 'function') {
window.appendCheckpointCards(checkpoints);
} else {
// Fallback implementation
checkpoints.forEach(checkpoint => {
const card = document.createElement('div');
card.className = 'checkpoint-card';
card.innerHTML = `
<h3>${checkpoint.name}</h3>
<p>${checkpoint.filename || 'No filename'}</p>
`;
grid.appendChild(card);
});
}
}
}

View File

@@ -80,6 +80,10 @@ export class DownloadManager {
const response = await fetch(`/api/civitai/versions/${modelId}`);
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
throw new Error('This model is not a LoRA. Please switch to the Checkpoints page to download checkpoint models.');
}
throw new Error('Failed to fetch model versions');
}

View File

@@ -1,7 +1,8 @@
import { BASE_MODELS, BASE_MODEL_CLASSES } from '../utils/constants.js';
import { state, getCurrentPageState } from '../state/index.js';
import { BASE_MODEL_CLASSES } from '../utils/constants.js';
import { getCurrentPageState } from '../state/index.js';
import { showToast, updatePanelPositions } from '../utils/uiHelpers.js';
import { loadMoreLoras } from '../api/loraApi.js';
import { loadMoreCheckpoints } from '../api/checkpointApi.js';
import { removeStorageItem, setStorageItem, getStorageItem } from '../utils/storageHelpers.js';
export class FilterManager {
@@ -70,8 +71,10 @@ export class FilterManager {
let tagsEndpoint = '/api/loras/top-tags?limit=20';
if (this.currentPage === 'recipes') {
tagsEndpoint = '/api/recipes/top-tags?limit=20';
} else if (this.currentPage === 'checkpoints') {
tagsEndpoint = '/api/checkpoints/top-tags?limit=20';
}
const response = await fetch(tagsEndpoint);
if (!response.ok) throw new Error('Failed to fetch tags');
@@ -143,8 +146,8 @@ export class FilterManager {
apiEndpoint = '/api/loras/base-models';
} else if (this.currentPage === 'recipes') {
apiEndpoint = '/api/recipes/base-models';
} else {
return; // No API endpoint for other pages
} else if (this.currentPage === 'checkpoints') {
apiEndpoint = '/api/checkpoints/base-models';
}
// Fetch base models
@@ -280,7 +283,7 @@ export class FilterManager {
// For loras page, reset the page and reload
await loadMoreLoras(true, true);
} else if (this.currentPage === 'checkpoints' && window.checkpointManager) {
await window.checkpointManager.loadCheckpoints(true);
await loadMoreCheckpoints(true);
}
// Update filter button to show active state
@@ -336,8 +339,8 @@ export class FilterManager {
await window.recipeManager.loadRecipes(true);
} else if (this.currentPage === 'loras') {
await loadMoreLoras(true, true);
} else if (this.currentPage === 'checkpoints' && window.checkpointManager) {
await window.checkpointManager.loadCheckpoints(true);
} else if (this.currentPage === 'checkpoints') {
await loadMoreCheckpoints(true);
}
showToast(`Filters cleared`, 'info');

View File

@@ -23,6 +23,32 @@ export class ModalManager {
});
}
// Add checkpointModal registration
const checkpointModal = document.getElementById('checkpointModal');
if (checkpointModal) {
this.registerModal('checkpointModal', {
element: checkpointModal,
onClose: () => {
this.getModal('checkpointModal').element.style.display = 'none';
document.body.classList.remove('modal-open');
},
closeOnOutsideClick: true
});
}
// Add checkpointDownloadModal registration
const checkpointDownloadModal = document.getElementById('checkpointDownloadModal');
if (checkpointDownloadModal) {
this.registerModal('checkpointDownloadModal', {
element: checkpointDownloadModal,
onClose: () => {
this.getModal('checkpointDownloadModal').element.style.display = 'none';
document.body.classList.remove('modal-open');
},
closeOnOutsideClick: true
});
}
const deleteModal = document.getElementById('deleteModal');
if (deleteModal) {
this.registerModal('deleteModal', {

View File

@@ -29,6 +29,7 @@ export class SearchManager {
this.initEventListeners();
this.loadSearchPreferences();
this.setupKeyboardShortcuts();
updatePanelPositions();
@@ -36,6 +37,26 @@ export class SearchManager {
window.addEventListener('resize', updatePanelPositions);
}
// Add this new method to setup keyboard shortcuts
setupKeyboardShortcuts() {
// Add global keyboard shortcut listener for Ctrl+F or Cmd+F
document.addEventListener('keydown', (e) => {
// Check for Ctrl+F (Windows/Linux) or Cmd+F (Mac)
if ((e.ctrlKey || e.metaKey) && e.key === 'f') {
// Prevent default browser search behavior
e.preventDefault();
// Focus the search input if it exists
if (this.searchInput) {
this.searchInput.focus();
// Optionally select all text in the search input for easy replacement
this.searchInput.select();
}
}
});
}
initEventListeners() {
// Search input event
if (this.searchInput) {
@@ -302,6 +323,7 @@ export class SearchManager {
pageState.searchOptions = {
filename: options.filename || false,
modelname: options.modelname || false,
tags: options.tags || false,
recursive: recursive
};
}

View File

@@ -50,6 +50,11 @@ export class SettingsManager {
observer.observe(settingsModal, { attributes: true });
}
// Add event listeners for all toggle-visibility buttons
document.querySelectorAll('.toggle-visibility').forEach(button => {
button.addEventListener('click', () => this.toggleInputVisibility(button));
});
this.initialized = true;
}
@@ -65,6 +70,12 @@ export class SettingsManager {
// Sync with state (backend will set this via template)
state.global.settings.show_only_sfw = showOnlySFWCheckbox.checked;
}
// Set video autoplay on hover setting
const autoplayOnHoverCheckbox = document.getElementById('autoplayOnHover');
if (autoplayOnHoverCheckbox) {
autoplayOnHoverCheckbox.checked = state.global.settings.autoplayOnHover || false;
}
// Load default lora root
await this.loadLoraRoots();
@@ -120,11 +131,183 @@ export class SettingsManager {
this.isOpen = !this.isOpen;
}
// Auto-save methods for different control types
// For toggle switches
async saveToggleSetting(elementId, settingKey) {
const element = document.getElementById(elementId);
if (!element) return;
const value = element.checked;
// Update frontend state
if (settingKey === 'blur_mature_content') {
state.global.settings.blurMatureContent = value;
} else if (settingKey === 'show_only_sfw') {
state.global.settings.show_only_sfw = value;
} else if (settingKey === 'autoplay_on_hover') {
state.global.settings.autoplayOnHover = value;
} else {
// For any other settings that might be added in the future
state.global.settings[settingKey] = value;
}
// Save to localStorage
setStorageItem('settings', state.global.settings);
try {
// For backend settings, make API call
if (['show_only_sfw', 'blur_mature_content', 'autoplay_on_hover'].includes(settingKey)) {
const payload = {};
payload[settingKey] = value;
const response = await fetch('/api/settings', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(payload)
});
if (!response.ok) {
throw new Error('Failed to save setting');
}
showToast(`Settings updated: ${settingKey.replace(/_/g, ' ')}`, 'success');
}
// Apply frontend settings immediately
this.applyFrontendSettings();
if (settingKey === 'show_only_sfw') {
this.reloadContent();
}
} catch (error) {
showToast('Failed to save setting: ' + error.message, 'error');
}
}
// For select dropdowns
async saveSelectSetting(elementId, settingKey) {
const element = document.getElementById(elementId);
if (!element) return;
const value = element.value;
// Update frontend state
if (settingKey === 'default_lora_root') {
state.global.settings.default_loras_root = value;
} else {
// For any other settings that might be added in the future
state.global.settings[settingKey] = value;
}
// Save to localStorage
setStorageItem('settings', state.global.settings);
try {
// For backend settings, make API call
const payload = {};
payload[settingKey] = value;
const response = await fetch('/api/settings', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(payload)
});
if (!response.ok) {
throw new Error('Failed to save setting');
}
showToast(`Settings updated: ${settingKey.replace(/_/g, ' ')}`, 'success');
} catch (error) {
showToast('Failed to save setting: ' + error.message, 'error');
}
}
// For input fields
async saveInputSetting(elementId, settingKey) {
const element = document.getElementById(elementId);
if (!element) return;
const value = element.value;
// For API key or other inputs that need to be saved on backend
try {
// Check if value has changed from existing value
const currentValue = state.global.settings[settingKey] || '';
if (value === currentValue) {
return; // No change, exit early
}
// Update state
state.global.settings[settingKey] = value;
// Save to localStorage if appropriate
if (!settingKey.includes('api_key')) { // Don't store API keys in localStorage for security
setStorageItem('settings', state.global.settings);
}
// For backend settings, make API call
const payload = {};
payload[settingKey] = value;
const response = await fetch('/api/settings', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(payload)
});
if (!response.ok) {
throw new Error('Failed to save setting');
}
showToast(`Settings updated: ${settingKey.replace(/_/g, ' ')}`, 'success');
} catch (error) {
showToast('Failed to save setting: ' + error.message, 'error');
}
}
toggleInputVisibility(button) {
const input = button.parentElement.querySelector('input');
const icon = button.querySelector('i');
if (input.type === 'password') {
input.type = 'text';
icon.className = 'fas fa-eye-slash';
} else {
input.type = 'password';
icon.className = 'fas fa-eye';
}
}
async reloadContent() {
if (this.currentPage === 'loras') {
// Reload the loras without updating folders
await resetAndReload(false);
} else if (this.currentPage === 'recipes') {
// Reload the recipes without updating folders
await window.recipeManager.loadRecipes();
} else if (this.currentPage === 'checkpoints') {
// Reload the checkpoints without updating folders
await window.checkpointsManager.loadCheckpoints();
}
}
async saveSettings() {
// Get frontend settings from UI
const blurMatureContent = document.getElementById('blurMatureContent').checked;
const showOnlySFW = document.getElementById('showOnlySFW').checked;
const defaultLoraRoot = document.getElementById('defaultLoraRoot').value;
const autoplayOnHover = document.getElementById('autoplayOnHover').checked;
// Get backend settings
const apiKey = document.getElementById('civitaiApiKey').value;
@@ -133,6 +316,7 @@ export class SettingsManager {
state.global.settings.blurMatureContent = blurMatureContent;
state.global.settings.show_only_sfw = showOnlySFW;
state.global.settings.default_loras_root = defaultLoraRoot;
state.global.settings.autoplayOnHover = autoplayOnHover;
// Save settings to localStorage
setStorageItem('settings', state.global.settings);
@@ -186,21 +370,38 @@ export class SettingsManager {
}
});
// Apply autoplay setting to existing videos in card previews
const autoplayOnHover = state.global.settings.autoplayOnHover;
document.querySelectorAll('.card-preview video').forEach(video => {
// Remove previous event listeners by cloning and replacing the element
const videoParent = video.parentElement;
const videoClone = video.cloneNode(true);
if (autoplayOnHover) {
// Pause video initially and set up mouse events for hover playback
videoClone.removeAttribute('autoplay');
videoClone.pause();
// Add mouse events to the parent element
videoParent.onmouseenter = () => videoClone.play();
videoParent.onmouseleave = () => {
videoClone.pause();
videoClone.currentTime = 0;
};
} else {
// Use default autoplay behavior
videoClone.setAttribute('autoplay', '');
videoParent.onmouseenter = null;
videoParent.onmouseleave = null;
}
videoParent.replaceChild(videoClone, video);
});
// For show_only_sfw, there's no immediate action needed as it affects content loading
// The setting will take effect on next reload
}
}
// Helper function for toggling API key visibility
export function toggleApiKeyVisibility(button) {
const input = button.parentElement.querySelector('input');
const icon = button.querySelector('i');
if (input.type === 'password') {
input.type = 'text';
icon.className = 'fas fa-eye-slash';
} else {
input.type = 'password';
icon.className = 'fas fa-eye';
}
}
// Create singleton instance
export const settingsManager = new SettingsManager();

View File

@@ -4,7 +4,7 @@ import { ImportManager } from './managers/ImportManager.js';
import { RecipeCard } from './components/RecipeCard.js';
import { RecipeModal } from './components/RecipeModal.js';
import { getCurrentPageState } from './state/index.js';
import { toggleApiKeyVisibility } from './managers/SettingsManager.js';
import { getSessionItem, removeSessionItem } from './utils/storageHelpers.js';
class RecipeManager {
constructor() {
@@ -20,6 +20,14 @@ class RecipeManager {
// Add state tracking for infinite scroll
this.pageState.isLoading = false;
this.pageState.hasMore = true;
// Custom filter state
this.customFilter = {
active: false,
loraName: null,
loraHash: null,
recipeId: null
};
}
async initialize() {
@@ -29,6 +37,9 @@ class RecipeManager {
// Set default search options if not already defined
this._initSearchOptions();
// Check for custom filter parameters in session storage
this._checkCustomFilter();
// Load initial set of recipes
await this.loadRecipes();
@@ -55,7 +66,100 @@ class RecipeManager {
// Only expose what's needed for the page
window.recipeManager = this;
window.importManager = this.importManager;
window.toggleApiKeyVisibility = toggleApiKeyVisibility;
}
_checkCustomFilter() {
// Check for Lora filter
const filterLoraName = getSessionItem('lora_to_recipe_filterLoraName');
const filterLoraHash = getSessionItem('lora_to_recipe_filterLoraHash');
// Check for specific recipe ID
const viewRecipeId = getSessionItem('viewRecipeId');
// Set custom filter if any parameter is present
if (filterLoraName || filterLoraHash || viewRecipeId) {
this.customFilter = {
active: true,
loraName: filterLoraName,
loraHash: filterLoraHash,
recipeId: viewRecipeId
};
// Show custom filter indicator
this._showCustomFilterIndicator();
}
}
_showCustomFilterIndicator() {
const indicator = document.getElementById('customFilterIndicator');
const textElement = document.getElementById('customFilterText');
if (!indicator || !textElement) return;
// Update text based on filter type
let filterText = '';
if (this.customFilter.recipeId) {
filterText = 'Viewing specific recipe';
} else if (this.customFilter.loraName) {
// Format with Lora name
const loraName = this.customFilter.loraName;
const displayName = loraName.length > 25 ?
loraName.substring(0, 22) + '...' :
loraName;
filterText = `<span>Recipes using: <span class="lora-name">${displayName}</span></span>`;
} else {
filterText = 'Filtered recipes';
}
// Update indicator text and show it
textElement.innerHTML = filterText;
// Add title attribute to show the lora name as a tooltip
if (this.customFilter.loraName) {
textElement.setAttribute('title', this.customFilter.loraName);
}
indicator.classList.remove('hidden');
// Add pulse animation
const filterElement = indicator.querySelector('.filter-active');
if (filterElement) {
filterElement.classList.add('animate');
setTimeout(() => filterElement.classList.remove('animate'), 600);
}
// Add click handler for clear filter button
const clearFilterBtn = indicator.querySelector('.clear-filter');
if (clearFilterBtn) {
clearFilterBtn.addEventListener('click', (e) => {
e.stopPropagation(); // Prevent button click from triggering
this._clearCustomFilter();
});
}
}
_clearCustomFilter() {
// Reset custom filter
this.customFilter = {
active: false,
loraName: null,
loraHash: null,
recipeId: null
};
// Hide indicator
const indicator = document.getElementById('customFilterIndicator');
if (indicator) {
indicator.classList.add('hidden');
}
// Clear any session storage items
removeSessionItem('lora_to_recipe_filterLoraName');
removeSessionItem('lora_to_recipe_filterLoraHash');
removeSessionItem('viewRecipeId');
// Reload recipes without custom filter
this.loadRecipes();
}
initEventListeners() {
@@ -83,6 +187,12 @@ class RecipeManager {
if (grid) grid.innerHTML = '';
}
// If we have a specific recipe ID to load
if (this.customFilter.active && this.customFilter.recipeId) {
await this._loadSpecificRecipe(this.customFilter.recipeId);
return;
}
// Build query parameters
const params = new URLSearchParams({
page: this.pageState.currentPage,
@@ -90,28 +200,38 @@ class RecipeManager {
sort_by: this.pageState.sortBy
});
// Add search filter if present
if (this.pageState.filters.search) {
params.append('search', this.pageState.filters.search);
// Add custom filter for Lora if present
if (this.customFilter.active && this.customFilter.loraHash) {
params.append('lora_hash', this.customFilter.loraHash);
// Add search option parameters
if (this.pageState.searchOptions) {
params.append('search_title', this.pageState.searchOptions.title.toString());
params.append('search_tags', this.pageState.searchOptions.tags.toString());
params.append('search_lora_name', this.pageState.searchOptions.loraName.toString());
params.append('search_lora_model', this.pageState.searchOptions.loraModel.toString());
params.append('fuzzy', 'true');
// Skip other filters when using custom filter
params.append('bypass_filters', 'true');
} else {
// Normal filtering logic
// Add search filter if present
if (this.pageState.filters.search) {
params.append('search', this.pageState.filters.search);
// Add search option parameters
if (this.pageState.searchOptions) {
params.append('search_title', this.pageState.searchOptions.title.toString());
params.append('search_tags', this.pageState.searchOptions.tags.toString());
params.append('search_lora_name', this.pageState.searchOptions.loraName.toString());
params.append('search_lora_model', this.pageState.searchOptions.loraModel.toString());
params.append('fuzzy', 'true');
}
}
// Add base model filters
if (this.pageState.filters.baseModel && this.pageState.filters.baseModel.length) {
params.append('base_models', this.pageState.filters.baseModel.join(','));
}
// Add tag filters
if (this.pageState.filters.tags && this.pageState.filters.tags.length) {
params.append('tags', this.pageState.filters.tags.join(','));
}
}
// Add base model filters
if (this.pageState.filters.baseModel && this.pageState.filters.baseModel.length) {
params.append('base_models', this.pageState.filters.baseModel.join(','));
}
// Add tag filters
if (this.pageState.filters.tags && this.pageState.filters.tags.length) {
params.append('tags', this.pageState.filters.tags.join(','));
}
// Fetch recipes
@@ -129,6 +249,11 @@ class RecipeManager {
// Update pagination state based on current page and total pages
this.pageState.hasMore = data.page < data.total_pages;
// Increment the page number AFTER successful loading
if (data.items.length > 0) {
this.pageState.currentPage++;
}
} catch (error) {
console.error('Error loading recipes:', error);
appCore.showToast('Failed to load recipes', 'error');
@@ -139,6 +264,46 @@ class RecipeManager {
}
}
async _loadSpecificRecipe(recipeId) {
try {
// Fetch specific recipe by ID
const response = await fetch(`/api/recipe/${recipeId}`);
if (!response.ok) {
throw new Error(`Failed to load recipe: ${response.statusText}`);
}
const recipe = await response.json();
// Create a data structure that matches the expected format
const recipeData = {
items: [recipe],
total: 1,
page: 1,
page_size: 1,
total_pages: 1
};
// Update grid with single recipe
this.updateRecipesGrid(recipeData, true);
// Pagination not needed for single recipe
this.pageState.hasMore = false;
// Show recipe details modal
setTimeout(() => {
this.showRecipeDetails(recipe);
}, 300);
} catch (error) {
console.error('Error loading specific recipe:', error);
appCore.showToast('Failed to load recipe details', 'error');
// Clear the filter and show all recipes
this._clearCustomFilter();
}
}
updateRecipesGrid(data, resetGrid = true) {
const grid = document.getElementById('recipeGrid');
if (!grid) return;

View File

@@ -0,0 +1,128 @@
/**
* Utility functions to update checkpoint cards after modal edits
*/
/**
* Update the checkpoint card after metadata edits in the modal
* @param {string} filePath - Path to the checkpoint file
* @param {Object} updates - Object containing the updates (model_name, base_model, etc)
*/
export function updateCheckpointCard(filePath, updates) {
// Find the card with matching filepath
const checkpointCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (!checkpointCard) return;
// Update card dataset and visual elements based on the updates object
Object.entries(updates).forEach(([key, value]) => {
// Update dataset
checkpointCard.dataset[key] = value;
// Update visual elements based on the property
switch(key) {
case 'name': // model_name
// Update the model name in the footer
const modelNameElement = checkpointCard.querySelector('.model-name');
if (modelNameElement) modelNameElement.textContent = value;
break;
case 'base_model':
// Update the base model label in the card header
const baseModelLabel = checkpointCard.querySelector('.base-model-label');
if (baseModelLabel) {
baseModelLabel.textContent = value;
baseModelLabel.title = value;
}
break;
case 'filepath':
// The filepath was changed (file renamed), update the dataset
checkpointCard.dataset.filepath = value;
break;
case 'tags':
// Update tags if they're displayed on the card
try {
checkpointCard.dataset.tags = JSON.stringify(value);
} catch (e) {
console.error('Failed to update tags:', e);
}
break;
// Add other properties as needed
}
});
}
/**
* Update the Lora card after metadata edits in the modal
* @param {string} filePath - Path to the Lora file
* @param {Object} updates - Object containing the updates (model_name, base_model, notes, usage_tips, etc)
* @param {string} [newFilePath] - Optional new file path if the file has been renamed
*/
export function updateLoraCard(filePath, updates, newFilePath) {
// Find the card with matching filepath
const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (!loraCard) return;
// If file was renamed, update the filepath first
if (newFilePath) {
loraCard.dataset.filepath = newFilePath;
}
// Update card dataset and visual elements based on the updates object
Object.entries(updates).forEach(([key, value]) => {
// Update dataset
loraCard.dataset[key] = value;
// Update visual elements based on the property
switch(key) {
case 'model_name':
// Update the model name in the card title
const titleElement = loraCard.querySelector('.card-title');
if (titleElement) titleElement.textContent = value;
// Also update the model name in the footer if it exists
const modelNameElement = loraCard.querySelector('.model-name');
if (modelNameElement) modelNameElement.textContent = value;
break;
case 'file_name':
// Update the file_name in the dataset
loraCard.dataset.file_name = value;
break;
case 'base_model':
// Update the base model label in the card header if it exists
const baseModelLabel = loraCard.querySelector('.base-model-label');
if (baseModelLabel) {
baseModelLabel.textContent = value;
baseModelLabel.title = value;
}
break;
case 'tags':
// Update tags if they're displayed on the card
try {
if (typeof value === 'string') {
loraCard.dataset.tags = value;
} else {
loraCard.dataset.tags = JSON.stringify(value);
}
// If there's a tag container, update its content
const tagContainer = loraCard.querySelector('.card-tags');
if (tagContainer) {
// This depends on how your tags are rendered
// You may need to update this logic based on your tag rendering function
}
} catch (e) {
console.error('Failed to update tags:', e);
}
break;
// No visual updates needed for notes, usage_tips as they're typically not shown on cards
}
});
return loraCard; // Return the updated card element for chaining
}

View File

@@ -1,5 +1,6 @@
import { state, getCurrentPageState } from '../state/index.js';
import { loadMoreLoras } from '../api/loraApi.js';
import { loadMoreCheckpoints } from '../api/checkpointApi.js';
import { debounce } from './debounce.js';
export function initializeInfiniteScroll(pageType = 'loras') {
@@ -21,7 +22,6 @@ export function initializeInfiniteScroll(pageType = 'loras') {
case 'recipes':
loadMoreFunction = () => {
if (!pageState.isLoading && pageState.hasMore) {
pageState.currentPage++;
window.recipeManager.loadRecipes(false); // false to not reset pagination
}
};
@@ -30,21 +30,25 @@ export function initializeInfiniteScroll(pageType = 'loras') {
case 'checkpoints':
loadMoreFunction = () => {
if (!pageState.isLoading && pageState.hasMore) {
pageState.currentPage++;
window.checkpointManager.loadCheckpoints(false); // false to not reset pagination
loadMoreCheckpoints(false); // false to not reset
}
};
gridId = 'checkpointGrid';
break;
case 'loras':
default:
loadMoreFunction = () => loadMoreLoras(false); // false to not reset
loadMoreFunction = () => {
if (!pageState.isLoading && pageState.hasMore) {
loadMoreLoras(false); // false to not reset
}
};
gridId = 'loraGrid';
break;
}
const debouncedLoadMore = debounce(loadMoreFunction, 100);
// Create a more robust observer with lower threshold and root margin
state.observer = new IntersectionObserver(
(entries) => {
const target = entries[0];
@@ -52,7 +56,10 @@ export function initializeInfiniteScroll(pageType = 'loras') {
debouncedLoadMore();
}
},
{ threshold: 0.1 }
{
threshold: 0.01, // Lower threshold to detect even minimal visibility
rootMargin: '0px 0px 300px 0px' // Increase bottom margin to trigger earlier
}
);
const grid = document.getElementById(gridId);
@@ -68,14 +75,14 @@ export function initializeInfiniteScroll(pageType = 'loras') {
// Create a wrapper div that will be placed after the grid
const sentinelWrapper = document.createElement('div');
sentinelWrapper.style.width = '100%';
sentinelWrapper.style.height = '10px';
sentinelWrapper.style.height = '30px'; // Increased height for better visibility
sentinelWrapper.style.margin = '0';
sentinelWrapper.style.padding = '0';
// Create the actual sentinel element
const sentinel = document.createElement('div');
sentinel.id = 'scroll-sentinel';
sentinel.style.height = '10px';
sentinel.style.height = '30px'; // Match wrapper height
// Add the sentinel to the wrapper
sentinelWrapper.appendChild(sentinel);
@@ -85,4 +92,37 @@ export function initializeInfiniteScroll(pageType = 'loras') {
state.observer.observe(sentinel);
}
}
// Add a scroll event backup to handle edge cases
const handleScroll = debounce(() => {
if (pageState.isLoading || !pageState.hasMore) return;
const sentinel = document.getElementById('scroll-sentinel');
if (!sentinel) return;
const rect = sentinel.getBoundingClientRect();
const windowHeight = window.innerHeight;
// If sentinel is within 500px of viewport bottom, load more
if (rect.top < windowHeight + 500) {
debouncedLoadMore();
}
}, 200);
// Clean up existing scroll listener if any
if (state.scrollHandler) {
window.removeEventListener('scroll', state.scrollHandler);
}
// Save reference to the handler for cleanup
state.scrollHandler = handleScroll;
window.addEventListener('scroll', state.scrollHandler);
// Check position immediately in case content is already visible
setTimeout(() => {
const sentinel = document.getElementById('scroll-sentinel');
if (sentinel && sentinel.getBoundingClientRect().top < window.innerHeight) {
debouncedLoadMore();
}
}, 100);
}

View File

@@ -1,10 +1,12 @@
import { modalManager } from '../managers/ModalManager.js';
let pendingDeletePath = null;
let pendingModelType = null;
export function showDeleteModal(filePath) {
event.stopPropagation();
export function showDeleteModal(filePath, modelType = 'lora') {
// event.stopPropagation();
pendingDeletePath = filePath;
pendingModelType = modelType;
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
const modelName = card.dataset.name;
@@ -23,11 +25,15 @@ export function showDeleteModal(filePath) {
export async function confirmDelete() {
if (!pendingDeletePath) return;
const modal = document.getElementById('deleteModal');
const card = document.querySelector(`.lora-card[data-filepath="${pendingDeletePath}"]`);
try {
const response = await fetch('/api/delete_model', {
// Use the appropriate endpoint based on model type
const endpoint = pendingModelType === 'checkpoint' ?
'/api/checkpoints/delete' :
'/api/delete_model';
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@@ -53,4 +59,6 @@ export async function confirmDelete() {
export function closeDeleteModal() {
modalManager.closeModal('deleteModal');
}
pendingDeletePath = null;
pendingModelType = null;
}

View File

@@ -69,6 +69,53 @@ export function removeStorageItem(key) {
localStorage.removeItem(key); // Also remove legacy key
}
/**
* Get an item from sessionStorage with namespace support
* @param {string} key - The key without prefix
* @param {any} defaultValue - Default value if key doesn't exist
* @returns {any} The stored value or defaultValue
*/
export function getSessionItem(key, defaultValue = null) {
// Try with prefix
const prefixedValue = sessionStorage.getItem(STORAGE_PREFIX + key);
if (prefixedValue !== null) {
// If it's a JSON string, parse it
try {
return JSON.parse(prefixedValue);
} catch (e) {
return prefixedValue;
}
}
// Return default value if key doesn't exist
return defaultValue;
}
/**
* Set an item in sessionStorage with namespace prefix
* @param {string} key - The key without prefix
* @param {any} value - The value to store
*/
export function setSessionItem(key, value) {
const prefixedKey = STORAGE_PREFIX + key;
// Convert objects and arrays to JSON strings
if (typeof value === 'object' && value !== null) {
sessionStorage.setItem(prefixedKey, JSON.stringify(value));
} else {
sessionStorage.setItem(prefixedKey, value);
}
}
/**
* Remove an item from sessionStorage with namespace prefix
* @param {string} key - The key without prefix
*/
export function removeSessionItem(key) {
sessionStorage.removeItem(STORAGE_PREFIX + key);
}
/**
* Migrate all existing localStorage items to use the prefix
* This should be called once during application initialization

View File

@@ -37,49 +37,6 @@
<link rel="preconnect" href="https://civitai.com">
<link rel="preconnect" href="https://cdnjs.cloudflare.com">
<!-- Add styles for initialization notice -->
{% if is_initializing %}
<style>
.initialization-notice {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background-color: rgba(0, 0, 0, 0.85);
z-index: 9999;
margin-top: 0;
border-radius: 0;
display: flex;
justify-content: center;
align-items: center;
color: white;
}
.notice-content {
background-color: rgba(30, 30, 30, 0.9);
border-radius: 10px;
padding: 30px;
text-align: center;
box-shadow: 0 0 20px rgba(0, 0, 0, 0.5);
max-width: 500px;
width: 80%;
}
.loading-spinner {
border: 5px solid rgba(255, 255, 255, 0.3);
border-radius: 50%;
border-top: 5px solid #fff;
width: 50px;
height: 50px;
margin: 0 auto 20px;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
</style>
{% endif %}
<script>
// 计算滚动条宽度并设置CSS变量
document.addEventListener('DOMContentLoaded', () => {
@@ -95,15 +52,7 @@
</head>
<body data-page="{% block page_id %}base{% endblock %}">
{% if is_initializing %}
<div class="initialization-notice">
<div class="notice-content">
<div class="loading-spinner"></div>
<h2>{% block init_title %}Initializing{% endblock %}</h2>
<p>{% block init_message %}Scanning and building cache. This may take a few minutes...{% endblock %}</p>
</div>
</div>
{% else %}
<!-- Header is always visible, even during initialization -->
{% include 'components/header.html' %}
<div class="page-content">
@@ -113,37 +62,23 @@
{% block additional_components %}{% endblock %}
<div class="container">
{% block content %}{% endblock %}
{% if is_initializing %}
<!-- Show initialization component when initializing -->
{% include 'components/initialization.html' %}
{% else %}
<!-- Show regular content when not initializing -->
{% block content %}{% endblock %}
{% endif %}
</div>
{% block overlay %}{% endblock %}
</div>
{% endif %}
{% block main_script %}{% endblock %}
{% if is_initializing %}
<script>
// 检查初始化状态并设置自动刷新
async function checkInitStatus() {
try {
const response = await fetch('{% block init_check_url %}/api/loras?page=1&page_size=1{% endblock %}');
if (response.ok) {
// 如果成功获取数据,说明初始化完成,刷新页面
window.location.reload();
} else {
// 如果还未完成,继续轮询
setTimeout(checkInitStatus, 2000); // 每2秒检查一次
}
} catch (error) {
// 如果出错,继续轮询
setTimeout(checkInitStatus, 2000);
}
}
// 启动状态检查
setTimeout(checkInitStatus, 1000); // 给页面完全加载的时间
</script>
<!-- Load initialization JavaScript -->
<script type="module" src="/loras_static/js/components/initialization.js"></script>
{% else %}
{% block main_script %}{% endblock %}
{% endif %}
{% block additional_scripts %}{% endblock %}

View File

@@ -8,18 +8,19 @@
{% endblock %}
{% block init_title %}Initializing Checkpoints Manager{% endblock %}
{% block init_message %}Setting up checkpoints interface. This may take a few moments...{% endblock %}
{% block init_message %}Scanning and building checkpoints cache. This may take a few moments...{% endblock %}
{% block init_check_url %}/api/checkpoints?page=1&page_size=1{% endblock %}
{% block additional_components %}
{% include 'components/checkpoint_modals.html' %}
{% endblock %}
{% block content %}
<div class="work-in-progress">
<div class="wip-content">
<i class="fas fa-tools"></i>
<h2>Checkpoints Manager</h2>
<p>This feature is currently under development and will be available soon.</p>
<p>Please check back later for updates!</p>
{% include 'components/controls.html' %}
<!-- Checkpoint cards container -->
<div class="card-grid" id="checkpointGrid">
<!-- Cards will be dynamically inserted here -->
</div>
</div>
{% endblock %}
{% block main_script %}

View File

@@ -0,0 +1,73 @@
<!-- Checkpoint Modals -->
<!-- Checkpoint details Modal -->
<div id="checkpointModal" class="modal"></div>
<!-- Download Checkpoint from URL Modal -->
<div id="checkpointDownloadModal" class="modal">
<div class="modal-content">
<button class="close" onclick="modalManager.closeModal('checkpointDownloadModal')">&times;</button>
<h2>Download Checkpoint from URL</h2>
<!-- Step 1: URL Input -->
<div class="download-step" id="cpUrlStep">
<div class="input-group">
<label for="checkpointUrl">Civitai URL:</label>
<input type="text" id="checkpointUrl" placeholder="https://civitai.com/models/..." />
<div class="error-message" id="cpUrlError"></div>
</div>
<div class="modal-actions">
<button class="primary-btn" onclick="checkpointDownloadManager.validateAndFetchVersions()">Next</button>
</div>
</div>
<!-- Step 2: Version Selection -->
<div class="download-step" id="cpVersionStep" style="display: none;">
<div class="version-list" id="cpVersionList">
<!-- Versions will be inserted here dynamically -->
</div>
<div class="modal-actions">
<button class="secondary-btn" onclick="checkpointDownloadManager.backToUrl()">Back</button>
<button class="primary-btn" onclick="checkpointDownloadManager.proceedToLocation()">Next</button>
</div>
</div>
<!-- Step 3: Location Selection -->
<div class="download-step" id="cpLocationStep" style="display: none;">
<div class="location-selection">
<!-- Move path preview to top -->
<div class="path-preview">
<label>Download Location Preview:</label>
<div class="path-display" id="cpTargetPathDisplay">
<span class="path-text">Select a Checkpoint root directory</span>
</div>
</div>
<div class="input-group">
<label>Select Checkpoint Root:</label>
<select id="checkpointRoot"></select>
</div>
<div class="input-group">
<label>Target Folder:</label>
<div class="folder-browser" id="cpFolderBrowser">
{% for folder in folders %}
{% if folder %}
<div class="folder-item" data-folder="{{ folder }}">
{{ folder }}
</div>
{% endif %}
{% endfor %}
</div>
</div>
<div class="input-group">
<label for="cpNewFolder">New Folder (optional):</label>
<input type="text" id="cpNewFolder" placeholder="Enter folder name" />
</div>
</div>
<div class="modal-actions">
<button class="secondary-btn" onclick="checkpointDownloadManager.backToVersions()">Back</button>
<button class="primary-btn" onclick="checkpointDownloadManager.startDownload()">Download</button>
</div>
</div>
</div>
</div>

View File

@@ -16,24 +16,34 @@
</select>
</div>
<div title="Refresh model list" class="control-group">
<button onclick="refreshLoras()"><i class="fas fa-sync"></i> Refresh</button>
<button data-action="refresh"><i class="fas fa-sync"></i> Refresh</button>
</div>
<div class="control-group">
<button data-action="fetch" title="Fetch from Civitai"><i class="fas fa-download"></i> Fetch</button>
</div>
<div class="control-group">
<button onclick="fetchCivitai()" class="secondary" title="Fetch from Civitai"><i class="fas fa-download"></i> Fetch</button>
</div>
<div class="control-group">
<button onclick="downloadManager.showDownloadModal()" title="Download from URL">
<button data-action="download" title="Download from URL">
<i class="fas fa-cloud-download-alt"></i> Download
</button>
</div>
<!-- Conditional buttons based on page -->
{% if request.path == '/loras' %}
<div class="control-group">
<button id="bulkOperationsBtn" onclick="bulkManager.toggleBulkMode()" title="Bulk Operations">
<button id="bulkOperationsBtn" data-action="bulk" title="Bulk Operations">
<i class="fas fa-th-large"></i> Bulk
</button>
</div>
{% endif %}
<div id="customFilterIndicator" class="control-group hidden">
<div class="filter-active">
<i class="fas fa-filter"></i> <span class="customFilterText" title=""></span>
<i class="fas fa-times-circle clear-filter"></i>
</div>
</div>
</div>
<div class="toggle-folders-container">
<button class="toggle-folders-btn icon-only" onclick="toggleFolderTags()" title="Toggle folder tags">
<button class="toggle-folders-btn icon-only" title="Toggle folder tags">
<i class="fas fa-tags"></i>
</button>
</div>

View File

@@ -74,6 +74,7 @@
{% elif request.path == '/checkpoints' %}
<div class="search-option-tag active" data-option="filename">Filename</div>
<div class="search-option-tag active" data-option="modelname">Checkpoint Name</div>
<div class="search-option-tag active" data-option="tags">Tags</div>
{% else %}
<!-- Default options for LoRAs page -->
<div class="search-option-tag active" data-option="filename">Filename</div>

View File

@@ -0,0 +1,92 @@
<!-- Initialization Component -->
<div class="initialization-container" id="initializationContainer">
<div class="initialization-content">
<div class="initialization-header">
<h2 id="initTitle">{% block init_title %}Initializing{% endblock %}</h2>
<p class="init-subtitle" id="initSubtitle">{% block init_message %}Preparing your workspace...{% endblock %}
</p>
</div>
<div class="loading-content">
<div class="loading-spinner"></div>
<div class="loading-status" id="progressStatus">Initializing...</div>
<!-- Use initialization-specific classes for the progress bar -->
<div class="init-progress-container">
<div class="init-progress-bar" id="initProgressBar"></div>
</div>
<div class="progress-details">
<span id="progressPercentage">0%</span>
<span id="remainingTime">Estimating time...</span>
</div>
</div>
<div class="tips-container">
<div class="tips-header">
<i class="fas fa-lightbulb"></i>
<h3>Tips & Tricks</h3>
</div>
<div class="tips-content">
<div class="tip-carousel" id="tipCarousel">
<div class="tip-item active">
<div class="tip-image">
<img src="/loras_static/images/tips/civitai-api.png" alt="Civitai API Setup"
onerror="this.src='/loras_static/images/no-preview.png'">
</div>
<div class="tip-text">
<h4>Civitai Integration</h4>
<p>Connect your Civitai account: Visit Profile Avatar → Settings → API Keys → Add API Key,
then paste it in Lora Manager settings.</p>
</div>
</div>
<div class="tip-item">
<div class="tip-image">
<img src="/loras_static/images/tips/civitai-download.png" alt="Civitai Download"
onerror="this.src='/loras_static/images/no-preview.png'">
</div>
<div class="tip-text">
<h4>Easy Download</h4>
<p>Use Civitai URLs to quickly download and install new models.</p>
</div>
</div>
<div class="tip-item">
<div class="tip-image">
<img src="/loras_static/images/tips/recipes.png" alt="Recipes"
onerror="this.src='/loras_static/images/no-preview.png'">
</div>
<div class="tip-text">
<h4>Save Recipes</h4>
<p>Create recipes to save your favorite model combinations for future use.</p>
</div>
</div>
<div class="tip-item">
<div class="tip-image">
<img src="/loras_static/images/tips/filter.png" alt="Filter Models"
onerror="this.src='/loras_static/images/no-preview.png'">
</div>
<div class="tip-text">
<h4>Fast Filtering</h4>
<p>Filter models by tags or base model type using the filter button in the header.</p>
</div>
</div>
<div class="tip-item">
<div class="tip-image">
<img src="/loras_static/images/tips/search.webp" alt="Quick Search"
onerror="this.src='/loras_static/images/no-preview.png'">
</div>
<div class="tip-text">
<h4>Quick Search</h4>
<p>Press Ctrl+F (Cmd+F on Mac) to quickly search within your current view.</p>
</div>
</div>
</div>
<div class="tip-navigation">
<span class="tip-dot active" data-index="0"></span>
<span class="tip-dot" data-index="1"></span>
<span class="tip-dot" data-index="2"></span>
<span class="tip-dot" data-index="3"></span>
<span class="tip-dot" data-index="4"></span>
</div>
</div>
</div>
</div>
</div>

View File

@@ -1,7 +1,7 @@
<!-- Model details Modal -->
<div id="loraModal" class="modal"></div>
<!-- Download from URL Modal -->
<!-- Download from URL Modal (for LoRAs) -->
<div id="downloadModal" class="modal">
<div class="modal-content">
<button class="close" onclick="modalManager.closeModal('downloadModal')">&times;</button>
@@ -111,4 +111,4 @@
<button class="primary-btn" onclick="moveManager.moveModel()">Move</button>
</div>
</div>
</div>
</div>

View File

@@ -17,16 +17,24 @@
<button class="close" onclick="modalManager.closeModal('settingsModal')">&times;</button>
<h2>Settings</h2>
<div class="settings-form">
<div class="input-group">
<label for="civitaiApiKey">Civitai API Key:</label>
<div class="api-key-input">
<input type="password"
id="civitaiApiKey"
placeholder="Enter your Civitai API key"
value="{{ settings.get('civitai_api_key', '') }}" />
<button class="toggle-visibility" onclick="toggleApiKeyVisibility(this)">
<i class="fas fa-eye"></i>
</button>
<div class="setting-item api-key-item">
<div class="setting-row">
<div class="setting-info">
<label for="civitaiApiKey">Civitai API Key:</label>
</div>
<div class="setting-control">
<div class="api-key-input">
<input type="password"
id="civitaiApiKey"
placeholder="Enter your Civitai API key"
value="{{ settings.get('civitai_api_key', '') }}"
onblur="settingsManager.saveInputSetting('civitaiApiKey', 'civitai_api_key')"
onkeydown="if(event.key === 'Enter') { this.blur(); }" />
<button class="toggle-visibility">
<i class="fas fa-eye"></i>
</button>
</div>
</div>
</div>
<div class="input-help">
Used for authentication when downloading models from Civitai
@@ -37,32 +45,61 @@
<h3>Content Filtering</h3>
<div class="setting-item">
<div class="setting-info">
<label for="blurMatureContent">Blur NSFW Content</label>
<div class="input-help">
Blur mature (NSFW) content preview images
<div class="setting-row">
<div class="setting-info">
<label for="blurMatureContent">Blur NSFW Content</label>
</div>
<div class="setting-control">
<label class="toggle-switch">
<input type="checkbox" id="blurMatureContent" checked
onchange="settingsManager.saveToggleSetting('blurMatureContent', 'blur_mature_content')">
<span class="toggle-slider"></span>
</label>
</div>
</div>
<div class="setting-control">
<label class="toggle-switch">
<input type="checkbox" id="blurMatureContent" checked>
<span class="toggle-slider"></span>
</label>
<div class="input-help">
Blur mature (NSFW) content preview images
</div>
</div>
<div class="setting-item">
<div class="setting-info">
<label for="showOnlySFW">Show Only SFW Results</label>
<div class="input-help">
Filter out all NSFW content when browsing and searching
<div class="setting-row">
<div class="setting-info">
<label for="showOnlySFW">Show Only SFW Results</label>
</div>
<div class="setting-control">
<label class="toggle-switch">
<input type="checkbox" id="showOnlySFW" value="{{ settings.get('show_only_sfw', False) }}" {% if settings.get('show_only_sfw', False) %}checked{% endif %}
onchange="settingsManager.saveToggleSetting('showOnlySFW', 'show_only_sfw')">
<span class="toggle-slider"></span>
</label>
</div>
</div>
<div class="setting-control">
<label class="toggle-switch">
<input type="checkbox" id="showOnlySFW" value="{{ settings.get('show_only_sfw', False) }}" {% if settings.get('show_only_sfw', False) %}checked{% endif %}>
<span class="toggle-slider"></span>
</label>
<div class="input-help">
Filter out all NSFW content when browsing and searching
</div>
</div>
</div>
<!-- Add Video Settings Section -->
<div class="settings-section">
<h3>Video Settings</h3>
<div class="setting-item">
<div class="setting-row">
<div class="setting-info">
<label for="autoplayOnHover">Autoplay Videos on Hover</label>
</div>
<div class="setting-control">
<label class="toggle-switch">
<input type="checkbox" id="autoplayOnHover"
onchange="settingsManager.saveToggleSetting('autoplayOnHover', 'autoplay_on_hover')">
<span class="toggle-slider"></span>
</label>
</div>
</div>
<div class="input-help">
Only play video previews when hovering over them
</div>
</div>
</div>
@@ -72,14 +109,16 @@
<h3>Folder Settings</h3>
<div class="setting-item">
<div class="setting-info">
<label for="defaultLoraRoot">Default LoRA Root</label>
</div>
<div class="setting-control select-control">
<select id="defaultLoraRoot">
<option value="">No Default</option>
<!-- Options will be loaded dynamically -->
</select>
<div class="setting-row">
<div class="setting-info">
<label for="defaultLoraRoot">Default LoRA Root</label>
</div>
<div class="setting-control select-control">
<select id="defaultLoraRoot" onchange="settingsManager.saveSelectSetting('defaultLoraRoot', 'default_lora_root')">
<option value="">No Default</option>
<!-- Options will be loaded dynamically -->
</select>
</div>
</div>
<div class="input-help">
Set the default LoRA root directory for downloads, imports and moves
@@ -87,18 +126,6 @@
</div>
</div>
</div>
<div class="modal-actions">
<button class="primary-btn" onclick="settingsManager.saveSettings()">Save</button>
</div>
<!-- Add the new links section -->
<div class="settings-links">
<a href="https://github.com/willmiao/ComfyUI-Lora-Manager" target="_blank" class="settings-link" title="GitHub Repository">
<i class="fab fa-github"></i>
</a>
<a href="https://github.com/willmiao/ComfyUI-Lora-Manager/issues/new" target="_blank" class="settings-link" title="Report Issue">
<i class="fas fa-bug"></i>
</a>
</div>
</div>
</div>

View File

@@ -58,6 +58,9 @@
<h3>Resources</h3>
<div class="recipe-section-actions">
<span id="recipeLorasCount"><i class="fas fa-layer-group"></i> 0 LoRAs</span>
<button class="action-btn view-loras-btn" id="viewRecipeLorasBtn" title="View all LoRAs in this recipe">
<i class="fas fa-external-link-alt"></i>
</button>
<button class="copy-btn" id="copyRecipeSyntaxBtn" title="Copy Recipe Syntax">
<i class="fas fa-copy"></i>
</button>

View File

@@ -32,6 +32,13 @@
<div title="Import recipes" class="control-group">
<button onclick="importManager.showImportModal()"><i class="fas fa-file-import"></i> Import</button>
</div>
<!-- Custom filter indicator button (hidden by default) -->
<div id="customFilterIndicator" class="control-group hidden">
<div class="filter-active">
<i class="fas fa-filter"></i> <span id="customFilterText">Filtered by LoRA</span>
<i class="fas fa-times-circle clear-filter"></i>
</div>
</div>
</div>
</div>