mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca692ed0f2 | ||
|
|
af499565d3 | ||
|
|
fe2d7e3a9e | ||
|
|
9f69822221 | ||
|
|
bb43f047c2 | ||
|
|
2356662492 | ||
|
|
1624a45093 | ||
|
|
dcb9983786 | ||
|
|
83d1828905 | ||
|
|
6a281cf3ee | ||
|
|
ed1cd39a6c | ||
|
|
dda19b3920 | ||
|
|
25139ca922 | ||
|
|
3cd57a582c | ||
|
|
d3903ac655 |
@@ -34,6 +34,11 @@ Enhance your Civitai browsing experience with our companion browser extension! S
|
||||
|
||||
## Release Notes
|
||||
|
||||
### v0.8.28
|
||||
* **Autocomplete for Node Inputs** - Instantly find and add LoRAs by filename directly in Lora Loader, Lora Stacker, and WanVideo Lora Select nodes. Autocomplete suggestions include preview tooltips and preset weights, allowing you to quickly select LoRAs without opening the LoRA Manager UI.
|
||||
* **Duplicate Notification Control** - Added a switch to duplicates mode, enabling users to turn off duplicate model notifications for a more streamlined experience.
|
||||
* **Download Example Images from Context Menu** - Introduced a new context menu option to download example images for individual models.
|
||||
|
||||
### v0.8.27
|
||||
* **User Experience Enhancements** - Improved the model download target folder selection with path input autocomplete and interactive folder tree navigation, making it easier and faster to choose where models are saved.
|
||||
* **Default Path Option for Downloads** - Added a "Use Default Path" option when downloading models. When enabled, models are automatically organized and stored according to your configured path template settings.
|
||||
|
||||
@@ -339,44 +339,8 @@ class MetadataProcessor:
|
||||
is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"
|
||||
|
||||
if is_custom_advanced:
|
||||
# For SamplerCustomAdvanced, trace specific inputs
|
||||
|
||||
# 1. Trace sigmas input to find BasicScheduler
|
||||
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", "BasicScheduler", max_depth=5)
|
||||
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
|
||||
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||
params["steps"] = scheduler_params.get("steps")
|
||||
params["scheduler"] = scheduler_params.get("scheduler")
|
||||
|
||||
# 2. Trace sampler input to find KSamplerSelect
|
||||
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
|
||||
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
|
||||
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
|
||||
params["sampler"] = sampler_params.get("sampler_name")
|
||||
|
||||
# 3. Trace guider input for CFGGuider and CLIPTextEncode
|
||||
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
|
||||
if guider_node_id and guider_node_id in prompt.original_prompt:
|
||||
# Check if the guider node is a CFGGuider
|
||||
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
|
||||
# Extract cfg value from the CFGGuider
|
||||
if guider_node_id in metadata.get(SAMPLING, {}):
|
||||
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
|
||||
params["cfg_scale"] = cfg_params.get("cfg")
|
||||
|
||||
# Find CLIPTextEncode for positive prompt
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
# Find CLIPTextEncode for negative prompt
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
else:
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
# For SamplerCustomAdvanced, use the new handler method
|
||||
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||
|
||||
else:
|
||||
# For standard samplers, match conditioning objects to prompts
|
||||
@@ -401,6 +365,9 @@ class MetadataProcessor:
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
|
||||
# For SamplerCustom, handle any additional parameters
|
||||
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||
|
||||
# Size extraction is same for all sampler types
|
||||
# Check if the sampler itself has size information (from latent_image)
|
||||
@@ -454,3 +421,59 @@ class MetadataProcessor:
|
||||
"""Convert metadata to JSON string"""
|
||||
params = MetadataProcessor.to_dict(metadata, id)
|
||||
return json.dumps(params, indent=4)
|
||||
|
||||
@staticmethod
|
||||
def handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params):
|
||||
"""
|
||||
Handle parameter extraction for SamplerCustomAdvanced nodes
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- prompt: The prompt object containing node connections
|
||||
- primary_sampler_id: ID of the SamplerCustomAdvanced node
|
||||
- params: Parameters dictionary to update
|
||||
"""
|
||||
if not prompt.original_prompt or primary_sampler_id not in prompt.original_prompt:
|
||||
return
|
||||
|
||||
sampler_inputs = prompt.original_prompt[primary_sampler_id].get("inputs", {})
|
||||
|
||||
# 1. Trace sigmas input to find BasicScheduler (only if sigmas input exists)
|
||||
if "sigmas" in sampler_inputs:
|
||||
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", None, max_depth=5)
|
||||
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
|
||||
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||
params["steps"] = scheduler_params.get("steps")
|
||||
params["scheduler"] = scheduler_params.get("scheduler")
|
||||
|
||||
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
|
||||
if "sampler" in sampler_inputs:
|
||||
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
|
||||
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
|
||||
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
|
||||
params["sampler"] = sampler_params.get("sampler_name")
|
||||
|
||||
# 3. Trace guider input for CFGGuider and CLIPTextEncode
|
||||
if "guider" in sampler_inputs:
|
||||
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
|
||||
if guider_node_id and guider_node_id in prompt.original_prompt:
|
||||
# Check if the guider node is a CFGGuider
|
||||
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
|
||||
# Extract cfg value from the CFGGuider
|
||||
if guider_node_id in metadata.get(SAMPLING, {}):
|
||||
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
|
||||
params["cfg_scale"] = cfg_params.get("cfg")
|
||||
|
||||
# Find CLIPTextEncode for positive prompt
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
# Find CLIPTextEncode for negative prompt
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
else:
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
@@ -642,6 +642,7 @@ NODE_EXTRACTORS = {
|
||||
# Sampling
|
||||
"KSampler": SamplerExtractor,
|
||||
"KSamplerAdvanced": KSamplerAdvancedExtractor,
|
||||
"SamplerCustom": KSamplerAdvancedExtractor,
|
||||
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor,
|
||||
"TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes
|
||||
"TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes
|
||||
@@ -652,9 +653,11 @@ NODE_EXTRACTORS = {
|
||||
# Sampling Selectors
|
||||
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
|
||||
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
||||
"AlignYourStepsScheduler": BasicSchedulerExtractor, # Add AlignYourStepsScheduler
|
||||
# Loaders
|
||||
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
||||
"comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader
|
||||
"CheckpointLoaderSimpleWithImages": CheckpointLoaderExtractor, # CheckpointLoader|pysssss
|
||||
"TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes
|
||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
|
||||
@@ -54,6 +54,7 @@ class BaseModelRoutes(ABC):
|
||||
app.router.add_post(f'/api/{prefix}/move_model', self.move_model)
|
||||
app.router.add_post(f'/api/{prefix}/move_models_bulk', self.move_models_bulk)
|
||||
app.router.add_get(f'/api/{prefix}/auto-organize', self.auto_organize_models)
|
||||
app.router.add_get(f'/api/{prefix}/auto-organize-progress', self.get_auto_organize_progress)
|
||||
|
||||
# Common query routes
|
||||
app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags)
|
||||
@@ -65,6 +66,12 @@ class BaseModelRoutes(ABC):
|
||||
app.router.add_get(f'/api/{prefix}/unified-folder-tree', self.get_unified_folder_tree)
|
||||
app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models)
|
||||
app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts)
|
||||
app.router.add_get(f'/api/{prefix}/get-notes', self.get_model_notes)
|
||||
app.router.add_get(f'/api/{prefix}/preview-url', self.get_model_preview_url)
|
||||
app.router.add_get(f'/api/{prefix}/civitai-url', self.get_model_civitai_url)
|
||||
|
||||
# Autocomplete route
|
||||
app.router.add_get(f'/api/{prefix}/relative-paths', self.get_relative_paths)
|
||||
|
||||
# Common Download management
|
||||
app.router.add_post(f'/api/download-model', self.download_model)
|
||||
@@ -743,6 +750,43 @@ class BaseModelRoutes(ABC):
|
||||
|
||||
async def auto_organize_models(self, request: web.Request) -> web.Response:
|
||||
"""Auto-organize all models based on current settings"""
|
||||
try:
|
||||
# Check if auto-organize is already running
|
||||
if ws_manager.is_auto_organize_running():
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Auto-organize is already running. Please wait for it to complete.'
|
||||
}, status=409)
|
||||
|
||||
# Acquire lock to prevent concurrent auto-organize operations
|
||||
auto_organize_lock = await ws_manager.get_auto_organize_lock()
|
||||
|
||||
if auto_organize_lock.locked():
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Auto-organize is already running. Please wait for it to complete.'
|
||||
}, status=409)
|
||||
|
||||
async with auto_organize_lock:
|
||||
return await self._perform_auto_organize()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in auto_organize_models: {e}", exc_info=True)
|
||||
|
||||
# Send error message via WebSocket and cleanup
|
||||
await ws_manager.broadcast_auto_organize_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def _perform_auto_organize(self) -> web.Response:
|
||||
"""Perform the actual auto-organize operation"""
|
||||
try:
|
||||
# Get all models from cache
|
||||
cache = await self.service.scanner.get_cached_data()
|
||||
@@ -751,6 +795,11 @@ class BaseModelRoutes(ABC):
|
||||
# Get model roots for this scanner
|
||||
model_roots = self.service.get_model_roots()
|
||||
if not model_roots:
|
||||
await ws_manager.broadcast_auto_organize_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'error',
|
||||
'error': 'No model roots configured'
|
||||
})
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No model roots configured'
|
||||
@@ -769,7 +818,7 @@ class BaseModelRoutes(ABC):
|
||||
skipped_count = 0
|
||||
|
||||
# Send initial progress via WebSocket
|
||||
await ws_manager.broadcast({
|
||||
await ws_manager.broadcast_auto_organize_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'started',
|
||||
'total': total_models,
|
||||
@@ -900,7 +949,7 @@ class BaseModelRoutes(ABC):
|
||||
processed += 1
|
||||
|
||||
# Send progress update after each batch
|
||||
await ws_manager.broadcast({
|
||||
await ws_manager.broadcast_auto_organize_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'processing',
|
||||
'total': total_models,
|
||||
@@ -914,7 +963,7 @@ class BaseModelRoutes(ABC):
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Send completion message
|
||||
await ws_manager.broadcast({
|
||||
await ws_manager.broadcast_auto_organize_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'cleaning',
|
||||
'total': total_models,
|
||||
@@ -933,7 +982,7 @@ class BaseModelRoutes(ABC):
|
||||
cleanup_counts[root] = removed
|
||||
|
||||
# Send cleanup completed message
|
||||
await ws_manager.broadcast({
|
||||
await ws_manager.broadcast_auto_organize_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'completed',
|
||||
'total': total_models,
|
||||
@@ -968,15 +1017,132 @@ class BaseModelRoutes(ABC):
|
||||
return web.json_response(response_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in auto_organize_models: {e}", exc_info=True)
|
||||
logger.error(f"Error in _perform_auto_organize: {e}", exc_info=True)
|
||||
|
||||
# Send error message via WebSocket
|
||||
await ws_manager.broadcast({
|
||||
await ws_manager.broadcast_auto_organize_progress({
|
||||
'type': 'auto_organize_progress',
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
raise e
|
||||
|
||||
async def get_auto_organize_progress(self, request: web.Request) -> web.Response:
|
||||
"""Get current auto-organize progress for polling"""
|
||||
try:
|
||||
progress_data = ws_manager.get_auto_organize_progress()
|
||||
|
||||
if progress_data is None:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No auto-organize operation in progress'
|
||||
}, status=404)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'progress': progress_data
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting auto-organize progress: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_model_notes(self, request: web.Request) -> web.Response:
|
||||
"""Get notes for a specific model file"""
|
||||
try:
|
||||
model_name = request.query.get('name')
|
||||
if not model_name:
|
||||
return web.Response(text=f'{self.model_type.capitalize()} file name is required', status=400)
|
||||
|
||||
notes = await self.service.get_model_notes(model_name)
|
||||
if notes is not None:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'notes': notes
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f'{self.model_type.capitalize()} not found in cache'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting {self.model_type} notes: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_model_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a model file"""
|
||||
try:
|
||||
model_name = request.query.get('name')
|
||||
if not model_name:
|
||||
return web.Response(text=f'{self.model_type.capitalize()} file name is required', status=400)
|
||||
|
||||
preview_url = await self.service.get_model_preview_url(model_name)
|
||||
if preview_url:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'preview_url': preview_url
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f'No preview URL found for the specified {self.model_type}'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting {self.model_type} preview URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_model_civitai_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the Civitai URL for a model file"""
|
||||
try:
|
||||
model_name = request.query.get('name')
|
||||
if not model_name:
|
||||
return web.Response(text=f'{self.model_type.capitalize()} file name is required', status=400)
|
||||
|
||||
result = await self.service.get_model_civitai_url(model_name)
|
||||
if result['civitai_url']:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
**result
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f'No Civitai data found for the specified {self.model_type}'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting {self.model_type} Civitai URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_relative_paths(self, request: web.Request) -> web.Response:
|
||||
"""Get model relative file paths for autocomplete functionality"""
|
||||
try:
|
||||
search = request.query.get('search', '').strip()
|
||||
limit = min(int(request.query.get('limit', '15')), 50) # Max 50 items
|
||||
|
||||
matching_paths = await self.service.search_relative_paths(search, limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'relative_paths': matching_paths
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting relative paths for autocomplete: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
from ..utils.example_images_download_manager import DownloadManager
|
||||
from ..utils.example_images_processor import ExampleImagesProcessor
|
||||
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
||||
from ..services.websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,6 +21,7 @@ class ExampleImagesRoutes:
|
||||
app.router.add_get('/api/example-image-files', ExampleImagesRoutes.get_example_image_files)
|
||||
app.router.add_get('/api/has-example-images', ExampleImagesRoutes.has_example_images)
|
||||
app.router.add_post('/api/delete-example-image', ExampleImagesRoutes.delete_example_image)
|
||||
app.router.add_post('/api/force-download-example-images', ExampleImagesRoutes.force_download_example_images)
|
||||
|
||||
@staticmethod
|
||||
async def download_example_images(request):
|
||||
@@ -64,4 +66,9 @@ class ExampleImagesRoutes:
|
||||
@staticmethod
|
||||
async def delete_example_image(request):
|
||||
"""Delete a custom example image for a model"""
|
||||
return await ExampleImagesProcessor.delete_custom_image(request)
|
||||
return await ExampleImagesProcessor.delete_custom_image(request)
|
||||
|
||||
@staticmethod
|
||||
async def force_download_example_images(request):
|
||||
"""Force download example images for specific models"""
|
||||
return await DownloadManager.start_force_download(request)
|
||||
@@ -43,11 +43,9 @@ class LoraRoutes(BaseModelRoutes):
|
||||
"""Setup LoRA-specific routes"""
|
||||
# LoRA-specific query routes
|
||||
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
|
||||
app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
|
||||
app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
|
||||
app.router.add_get(f'/api/{prefix}/preview-url', self.get_lora_preview_url)
|
||||
app.router.add_get(f'/api/{prefix}/civitai-url', self.get_lora_civitai_url)
|
||||
app.router.add_get(f'/api/{prefix}/model-description', self.get_lora_model_description)
|
||||
app.router.add_get(f'/api/{prefix}/usage-tips-by-path', self.get_lora_usage_tips_by_path)
|
||||
|
||||
# CivitAI integration with LoRA-specific validation
|
||||
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora)
|
||||
@@ -143,6 +141,26 @@ class LoraRoutes(BaseModelRoutes):
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_usage_tips_by_path(self, request: web.Request) -> web.Response:
|
||||
"""Get usage tips for a LoRA by its relative path"""
|
||||
try:
|
||||
relative_path = request.query.get('relative_path')
|
||||
if not relative_path:
|
||||
return web.Response(text='Relative path is required', status=400)
|
||||
|
||||
usage_tips = await self.service.get_lora_usage_tips_by_relative_path(relative_path)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'usage_tips': usage_tips or ''
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora usage tips by path: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Type
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
@@ -330,4 +331,92 @@ class BaseModelService(ABC):
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
return unified_tree
|
||||
return unified_tree
|
||||
|
||||
async def get_model_notes(self, model_name: str) -> Optional[str]:
|
||||
"""Get notes for a specific model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
return model.get('notes', '')
|
||||
|
||||
return None
|
||||
|
||||
async def get_model_preview_url(self, model_name: str) -> Optional[str]:
|
||||
"""Get the static preview URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
preview_url = model.get('preview_url')
|
||||
if preview_url:
|
||||
from ..config import config
|
||||
return config.get_preview_static_url(preview_url)
|
||||
|
||||
return None
|
||||
|
||||
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
|
||||
"""Get the Civitai URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
civitai_data = model.get('civitai', {})
|
||||
model_id = civitai_data.get('modelId')
|
||||
version_id = civitai_data.get('id')
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
|
||||
return {
|
||||
'civitai_url': civitai_url,
|
||||
'model_id': str(model_id),
|
||||
'version_id': str(version_id) if version_id else None
|
||||
}
|
||||
|
||||
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||
|
||||
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
|
||||
"""Search model relative file paths for autocomplete functionality"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
matching_paths = []
|
||||
search_lower = search_term.lower()
|
||||
|
||||
# Get model roots for path calculation
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
for model in cache.raw_data:
|
||||
file_path = model.get('file_path', '')
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
# Calculate relative path from model root
|
||||
relative_path = None
|
||||
for root in model_roots:
|
||||
# Normalize paths for comparison
|
||||
normalized_root = os.path.normpath(root).replace(os.sep, '/')
|
||||
normalized_file = os.path.normpath(file_path).replace(os.sep, '/')
|
||||
|
||||
if normalized_file.startswith(normalized_root):
|
||||
# Remove root and leading slash to get relative path
|
||||
relative_path = normalized_file[len(normalized_root):].lstrip('/')
|
||||
break
|
||||
|
||||
if relative_path and search_lower in relative_path.lower():
|
||||
matching_paths.append(relative_path)
|
||||
|
||||
if len(matching_paths) >= limit * 2: # Get more for better sorting
|
||||
break
|
||||
|
||||
# Sort by relevance (exact matches first, then by length)
|
||||
matching_paths.sort(key=lambda x: (
|
||||
not x.lower().startswith(search_lower), # Exact prefix matches first
|
||||
len(x), # Then by length (shorter first)
|
||||
x.lower() # Then alphabetically
|
||||
))
|
||||
|
||||
return matching_paths[:limit]
|
||||
@@ -352,7 +352,11 @@ class DownloadManager:
|
||||
base_model = version_info.get('baseModel', '')
|
||||
|
||||
# Get author from creator data
|
||||
author = version_info.get('creator', {}).get('username', 'Anonymous')
|
||||
creator_info = version_info.get('creator')
|
||||
if creator_info and isinstance(creator_info, dict):
|
||||
author = creator_info.get('username') or 'Anonymous'
|
||||
else:
|
||||
author = 'Anonymous'
|
||||
|
||||
# Apply mapping if available
|
||||
base_model_mappings = settings.get('base_model_path_mappings', {})
|
||||
|
||||
@@ -147,16 +147,6 @@ class LoraService(BaseModelService):
|
||||
|
||||
return letters
|
||||
|
||||
async def get_lora_notes(self, lora_name: str) -> Optional[str]:
|
||||
"""Get notes for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
return lora.get('notes', '')
|
||||
|
||||
return None
|
||||
|
||||
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
@@ -168,41 +158,21 @@ class LoraService(BaseModelService):
|
||||
|
||||
return []
|
||||
|
||||
async def get_lora_preview_url(self, lora_name: str) -> Optional[str]:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
async def get_lora_usage_tips_by_relative_path(self, relative_path: str) -> Optional[str]:
|
||||
"""Get usage tips for a LoRA by its relative path"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
preview_url = lora.get('preview_url')
|
||||
if preview_url:
|
||||
return config.get_preview_static_url(preview_url)
|
||||
file_path = lora.get('file_path', '')
|
||||
if file_path:
|
||||
# Convert to forward slashes and extract relative path
|
||||
file_path_normalized = file_path.replace('\\', '/')
|
||||
# Find the relative path part by looking for the relative_path in the full path
|
||||
if file_path_normalized.endswith(relative_path) or relative_path in file_path_normalized:
|
||||
return lora.get('usage_tips', '')
|
||||
|
||||
return None
|
||||
|
||||
async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
model_id = civitai_data.get('modelId')
|
||||
version_id = civitai_data.get('id')
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
|
||||
return {
|
||||
'civitai_url': civitai_url,
|
||||
'model_id': str(model_id),
|
||||
'version_id': str(version_id) if version_id else None
|
||||
}
|
||||
|
||||
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
@@ -16,6 +16,9 @@ class WebSocketManager:
|
||||
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
|
||||
# Add progress tracking dictionary
|
||||
self._download_progress: Dict[str, Dict] = {}
|
||||
# Add auto-organize progress tracking
|
||||
self._auto_organize_progress: Optional[Dict] = None
|
||||
self._auto_organize_lock = asyncio.Lock()
|
||||
|
||||
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection"""
|
||||
@@ -134,6 +137,33 @@ class WebSocketManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending download progress: {e}")
|
||||
|
||||
async def broadcast_auto_organize_progress(self, data: Dict):
|
||||
"""Broadcast auto-organize progress to connected clients"""
|
||||
# Store progress data in memory
|
||||
self._auto_organize_progress = data
|
||||
|
||||
# Broadcast via WebSocket
|
||||
await self.broadcast(data)
|
||||
|
||||
def get_auto_organize_progress(self) -> Optional[Dict]:
|
||||
"""Get current auto-organize progress"""
|
||||
return self._auto_organize_progress
|
||||
|
||||
def cleanup_auto_organize_progress(self):
|
||||
"""Clear auto-organize progress data"""
|
||||
self._auto_organize_progress = None
|
||||
|
||||
def is_auto_organize_running(self) -> bool:
|
||||
"""Check if auto-organize is currently running"""
|
||||
if not self._auto_organize_progress:
|
||||
return False
|
||||
status = self._auto_organize_progress.get('status')
|
||||
return status in ['started', 'processing', 'cleaning']
|
||||
|
||||
async def get_auto_organize_lock(self):
|
||||
"""Get the auto-organize lock"""
|
||||
return self._auto_organize_lock
|
||||
|
||||
def get_download_progress(self, download_id: str) -> Optional[Dict]:
|
||||
"""Get progress information for a specific download"""
|
||||
return self._download_progress.get(download_id)
|
||||
|
||||
@@ -6,8 +6,10 @@ import time
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .example_images_processor import ExampleImagesProcessor
|
||||
from .example_images_metadata import MetadataUpdater
|
||||
from ..services.websocket_manager import ws_manager # Add this import at the top
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -431,4 +433,364 @@ class DownloadManager:
|
||||
with open(progress_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(progress_data, f, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save progress file: {e}")
|
||||
logger.error(f"Failed to save progress file: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def start_force_download(request):
|
||||
"""
|
||||
Force download example images for specific models
|
||||
|
||||
Expects a JSON body with:
|
||||
{
|
||||
"model_hashes": ["hash1", "hash2", ...], # List of model hashes to download
|
||||
"output_dir": "path/to/output", # Base directory to save example images
|
||||
"optimize": true, # Whether to optimize images (default: true)
|
||||
"model_types": ["lora", "checkpoint"], # Model types to process (default: both)
|
||||
"delay": 1.0 # Delay between downloads (default: 1.0)
|
||||
}
|
||||
"""
|
||||
global download_task, is_downloading, download_progress
|
||||
|
||||
if is_downloading:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download already in progress'
|
||||
}, status=400)
|
||||
|
||||
try:
|
||||
# Parse the request body
|
||||
data = await request.json()
|
||||
model_hashes = data.get('model_hashes', [])
|
||||
output_dir = data.get('output_dir')
|
||||
optimize = data.get('optimize', True)
|
||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
||||
|
||||
if not model_hashes:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing model_hashes parameter'
|
||||
}, status=400)
|
||||
|
||||
if not output_dir:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing output_dir parameter'
|
||||
}, status=400)
|
||||
|
||||
# Create the output directory
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Initialize progress tracking
|
||||
download_progress['total'] = len(model_hashes)
|
||||
download_progress['completed'] = 0
|
||||
download_progress['current_model'] = ''
|
||||
download_progress['status'] = 'running'
|
||||
download_progress['errors'] = []
|
||||
download_progress['last_error'] = None
|
||||
download_progress['start_time'] = time.time()
|
||||
download_progress['end_time'] = None
|
||||
download_progress['processed_models'] = set()
|
||||
download_progress['refreshed_models'] = set()
|
||||
download_progress['failed_models'] = set()
|
||||
|
||||
# Set download status to downloading
|
||||
is_downloading = True
|
||||
|
||||
# Execute the download function directly instead of creating a background task
|
||||
result = await DownloadManager._download_specific_models_example_images_sync(
|
||||
model_hashes,
|
||||
output_dir,
|
||||
optimize,
|
||||
model_types,
|
||||
delay
|
||||
)
|
||||
|
||||
# Set download status to not downloading
|
||||
is_downloading = False
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': 'Force download completed',
|
||||
'result': result
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Set download status to not downloading
|
||||
is_downloading = False
|
||||
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay):
|
||||
"""Download example images for specific models only - synchronous version"""
|
||||
global download_progress
|
||||
|
||||
# Create independent download session
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=True,
|
||||
limit=3,
|
||||
force_close=False,
|
||||
enable_cleanup_closed=True
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
|
||||
independent_session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
trust_env=True,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
try:
|
||||
# Get scanners
|
||||
scanners = []
|
||||
if 'lora' in model_types:
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
scanners.append(('lora', lora_scanner))
|
||||
|
||||
if 'checkpoint' in model_types:
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
scanners.append(('checkpoint', checkpoint_scanner))
|
||||
|
||||
if 'embedding' in model_types:
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
scanners.append(('embedding', embedding_scanner))
|
||||
|
||||
# Find the specified models
|
||||
models_to_process = []
|
||||
for scanner_type, scanner in scanners:
|
||||
cache = await scanner.get_cached_data()
|
||||
if cache and cache.raw_data:
|
||||
for model in cache.raw_data:
|
||||
if model.get('sha256') in model_hashes:
|
||||
models_to_process.append((scanner_type, model, scanner))
|
||||
|
||||
# Update total count based on found models
|
||||
download_progress['total'] = len(models_to_process)
|
||||
logger.debug(f"Found {download_progress['total']} models to process")
|
||||
|
||||
# Send initial progress via WebSocket
|
||||
await ws_manager.broadcast({
|
||||
'type': 'example_images_progress',
|
||||
'processed': 0,
|
||||
'total': download_progress['total'],
|
||||
'status': 'running',
|
||||
'current_model': ''
|
||||
})
|
||||
|
||||
# Process each model
|
||||
success_count = 0
|
||||
for i, (scanner_type, model, scanner) in enumerate(models_to_process):
|
||||
# Force process this model regardless of previous status
|
||||
was_successful = await DownloadManager._process_specific_model(
|
||||
scanner_type, model, scanner,
|
||||
output_dir, optimize, independent_session
|
||||
)
|
||||
|
||||
if was_successful:
|
||||
success_count += 1
|
||||
|
||||
# Update progress
|
||||
download_progress['completed'] += 1
|
||||
|
||||
# Send progress update via WebSocket
|
||||
await ws_manager.broadcast({
|
||||
'type': 'example_images_progress',
|
||||
'processed': download_progress['completed'],
|
||||
'total': download_progress['total'],
|
||||
'status': 'running',
|
||||
'current_model': download_progress['current_model']
|
||||
})
|
||||
|
||||
# Only add delay after remote download, and not after processing the last model
|
||||
if was_successful and i < len(models_to_process) - 1 and download_progress['status'] == 'running':
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Mark as completed
|
||||
download_progress['status'] = 'completed'
|
||||
download_progress['end_time'] = time.time()
|
||||
logger.debug(f"Forced example images download completed: {download_progress['completed']}/{download_progress['total']} models processed")
|
||||
|
||||
# Send final progress via WebSocket
|
||||
await ws_manager.broadcast({
|
||||
'type': 'example_images_progress',
|
||||
'processed': download_progress['completed'],
|
||||
'total': download_progress['total'],
|
||||
'status': 'completed',
|
||||
'current_model': ''
|
||||
})
|
||||
|
||||
return {
|
||||
'total': download_progress['total'],
|
||||
'processed': download_progress['completed'],
|
||||
'successful': success_count,
|
||||
'errors': download_progress['errors']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error during forced example images download: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
download_progress['errors'].append(error_msg)
|
||||
download_progress['last_error'] = error_msg
|
||||
download_progress['status'] = 'error'
|
||||
download_progress['end_time'] = time.time()
|
||||
|
||||
# Send error status via WebSocket
|
||||
await ws_manager.broadcast({
|
||||
'type': 'example_images_progress',
|
||||
'processed': download_progress['completed'],
|
||||
'total': download_progress['total'],
|
||||
'status': 'error',
|
||||
'error': error_msg,
|
||||
'current_model': ''
|
||||
})
|
||||
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Close the independent session
|
||||
try:
|
||||
await independent_session.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing download session: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, independent_session):
|
||||
"""Process a specific model for forced download, ignoring previous download status"""
|
||||
global download_progress
|
||||
|
||||
# Check if download is paused
|
||||
while download_progress['status'] == 'paused':
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check if download should continue
|
||||
if download_progress['status'] != 'running':
|
||||
logger.info(f"Download stopped: {download_progress['status']}")
|
||||
return False
|
||||
|
||||
model_hash = model.get('sha256', '').lower()
|
||||
model_name = model.get('model_name', 'Unknown')
|
||||
model_file_path = model.get('file_path', '')
|
||||
model_file_name = model.get('file_name', '')
|
||||
|
||||
try:
|
||||
# Update current model info
|
||||
download_progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
||||
|
||||
# Create model directory
|
||||
model_dir = os.path.join(output_dir, model_hash)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# First check for local example images - local processing doesn't need delay
|
||||
local_images_processed = await ExampleImagesProcessor.process_local_examples(
|
||||
model_file_path, model_file_name, model_name, model_dir, optimize
|
||||
)
|
||||
|
||||
# If we processed local images, update metadata
|
||||
if local_images_processed:
|
||||
await MetadataUpdater.update_metadata_from_local_examples(
|
||||
model_hash, model, scanner_type, scanner, model_dir
|
||||
)
|
||||
download_progress['processed_models'].add(model_hash)
|
||||
return False # Return False to indicate no remote download happened
|
||||
|
||||
# If no local images, try to download from remote
|
||||
elif model.get('civitai') and model.get('civitai', {}).get('images'):
|
||||
images = model.get('civitai', {}).get('images', [])
|
||||
|
||||
success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
||||
model_hash, model_name, images, model_dir, optimize, independent_session
|
||||
)
|
||||
|
||||
# If metadata is stale, try to refresh it
|
||||
if is_stale and model_hash not in download_progress['refreshed_models']:
|
||||
await MetadataUpdater.refresh_model_metadata(
|
||||
model_hash, model_name, scanner_type, scanner
|
||||
)
|
||||
|
||||
# Get the updated model data
|
||||
updated_model = await MetadataUpdater.get_updated_model(
|
||||
model_hash, scanner
|
||||
)
|
||||
|
||||
if updated_model and updated_model.get('civitai', {}).get('images'):
|
||||
# Retry download with updated metadata
|
||||
updated_images = updated_model.get('civitai', {}).get('images', [])
|
||||
success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
||||
model_hash, model_name, updated_images, model_dir, optimize, independent_session
|
||||
)
|
||||
|
||||
# Combine failed images from both attempts
|
||||
failed_images.extend(additional_failed_images)
|
||||
|
||||
download_progress['refreshed_models'].add(model_hash)
|
||||
|
||||
# For forced downloads, remove failed images from metadata
|
||||
if failed_images:
|
||||
# Create a copy of images excluding failed ones
|
||||
await DownloadManager._remove_failed_images_from_metadata(
|
||||
model_hash, model_name, failed_images, scanner
|
||||
)
|
||||
|
||||
# Mark as processed
|
||||
if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones
|
||||
download_progress['processed_models'].add(model_hash)
|
||||
|
||||
return True # Return True to indicate a remote download happened
|
||||
else:
|
||||
logger.debug(f"No civitai images available for model {model_name}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
download_progress['errors'].append(error_msg)
|
||||
download_progress['last_error'] = error_msg
|
||||
return False # Return False on exception
|
||||
|
||||
@staticmethod
|
||||
async def _remove_failed_images_from_metadata(model_hash, model_name, failed_images, scanner):
|
||||
"""Remove failed images from model metadata"""
|
||||
try:
|
||||
# Get current model data
|
||||
model_data = await MetadataUpdater.get_updated_model(model_hash, scanner)
|
||||
if not model_data:
|
||||
logger.warning(f"Could not find model data for {model_name} to remove failed images")
|
||||
return
|
||||
|
||||
if not model_data.get('civitai', {}).get('images'):
|
||||
logger.warning(f"No images in metadata for {model_name}")
|
||||
return
|
||||
|
||||
# Get current images
|
||||
current_images = model_data['civitai']['images']
|
||||
|
||||
# Filter out failed images
|
||||
updated_images = [img for img in current_images if img.get('url') not in failed_images]
|
||||
|
||||
# If images were removed, update metadata
|
||||
if len(updated_images) < len(current_images):
|
||||
removed_count = len(current_images) - len(updated_images)
|
||||
logger.info(f"Removing {removed_count} failed images from metadata for {model_name}")
|
||||
|
||||
# Update the images list
|
||||
model_data['civitai']['images'] = updated_images
|
||||
|
||||
# Save metadata to file
|
||||
file_path = model_data.get('file_path')
|
||||
if file_path:
|
||||
# Create a copy of model data without 'folder' field
|
||||
model_copy = model_data.copy()
|
||||
model_copy.pop('folder', None)
|
||||
|
||||
# Write metadata to file
|
||||
await MetadataManager.save_metadata(file_path, model_copy)
|
||||
logger.info(f"Saved updated metadata for {model_name} after removing failed images")
|
||||
|
||||
# Update the scanner cache
|
||||
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
|
||||
@@ -102,6 +102,78 @@ class ExampleImagesProcessor:
|
||||
|
||||
return model_success, False # (success, is_metadata_stale)
|
||||
|
||||
@staticmethod
|
||||
async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, independent_session):
|
||||
"""Download images for a single model with tracking of failed image URLs
|
||||
|
||||
Returns:
|
||||
tuple: (success, is_stale_metadata, failed_images) - whether download was successful, whether metadata is stale, list of failed image URLs
|
||||
"""
|
||||
model_success = True
|
||||
failed_images = []
|
||||
|
||||
for i, image in enumerate(model_images):
|
||||
image_url = image.get('url')
|
||||
if not image_url:
|
||||
continue
|
||||
|
||||
# Get image filename from URL
|
||||
image_filename = os.path.basename(image_url.split('?')[0])
|
||||
image_ext = os.path.splitext(image_filename)[1].lower()
|
||||
|
||||
# Handle images and videos
|
||||
is_image = image_ext in SUPPORTED_MEDIA_EXTENSIONS['images']
|
||||
is_video = image_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
|
||||
|
||||
if not (is_image or is_video):
|
||||
logger.debug(f"Skipping unsupported file type: {image_filename}")
|
||||
continue
|
||||
|
||||
# Use 0-based indexing instead of 1-based indexing
|
||||
save_filename = f"image_{i}{image_ext}"
|
||||
|
||||
# If optimizing images and this is a Civitai image, use their pre-optimized WebP version
|
||||
if is_image and optimize and 'civitai.com' in image_url:
|
||||
image_url = ExampleImagesProcessor.get_civitai_optimized_url(image_url)
|
||||
save_filename = f"image_{i}.webp"
|
||||
|
||||
# Check if already downloaded
|
||||
save_path = os.path.join(model_dir, save_filename)
|
||||
if os.path.exists(save_path):
|
||||
logger.debug(f"File already exists: {save_path}")
|
||||
continue
|
||||
|
||||
# Download the file
|
||||
try:
|
||||
logger.debug(f"Downloading {save_filename} for {model_name}")
|
||||
|
||||
# Download directly using the independent session
|
||||
async with independent_session.get(image_url, timeout=60) as response:
|
||||
if response.status == 200:
|
||||
with open(save_path, 'wb') as f:
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
elif response.status == 404:
|
||||
error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale"
|
||||
logger.warning(error_msg)
|
||||
model_success = False # Mark the model as failed due to 404 error
|
||||
failed_images.append(image_url) # Track failed URL
|
||||
# Return early to trigger metadata refresh attempt
|
||||
return False, True, failed_images # (success, is_metadata_stale, failed_images)
|
||||
else:
|
||||
error_msg = f"Failed to download file: {image_url}, status code: {response.status}"
|
||||
logger.warning(error_msg)
|
||||
model_success = False # Mark the model as failed
|
||||
failed_images.append(image_url) # Track failed URL
|
||||
except Exception as e:
|
||||
error_msg = f"Error downloading file {image_url}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
model_success = False # Mark the model as failed
|
||||
failed_images.append(image_url) # Track failed URL
|
||||
|
||||
return model_success, False, failed_images # (success, is_metadata_stale, failed_images)
|
||||
|
||||
@staticmethod
|
||||
async def process_local_examples(model_file_path, model_file_name, model_name, model_dir, optimize):
|
||||
"""Process local example images
|
||||
|
||||
@@ -156,7 +156,7 @@ def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora'
|
||||
if civitai_data and civitai_data.get('id') is not None:
|
||||
base_model = civitai_data.get('baseModel', '')
|
||||
# Get author from civitai creator data
|
||||
author = civitai_data.get('creator', {}).get('username', 'Anonymous')
|
||||
author = civitai_data.get('creator', {}).get('username') or 'Anonymous'
|
||||
else:
|
||||
# Fallback to model_data fields for non-CivitAI models
|
||||
base_model = model_data.get('base_model', '')
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "comfyui-lora-manager"
|
||||
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
||||
version = "0.8.27"
|
||||
version = "0.8.28"
|
||||
license = {file = "LICENSE"}
|
||||
dependencies = [
|
||||
"aiohttp",
|
||||
|
||||
@@ -12,7 +12,9 @@
|
||||
z-index: var(--z-overlay);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-width: 300px;
|
||||
min-width: 420px;
|
||||
max-width: 900px;
|
||||
width: auto;
|
||||
transition: all 0.4s cubic-bezier(0.175, 0.885, 0.32, 1.275);
|
||||
opacity: 0;
|
||||
}
|
||||
@@ -48,6 +50,8 @@
|
||||
color: var(--text-color);
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
white-space: nowrap;
|
||||
min-height: 36px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
@@ -105,6 +109,8 @@
|
||||
@media (max-width: 768px) {
|
||||
.bulk-operations-panel {
|
||||
width: calc(100% - 40px);
|
||||
min-width: unset;
|
||||
max-width: unset;
|
||||
left: 20px;
|
||||
transform: none;
|
||||
border-radius: var(--border-radius-sm);
|
||||
|
||||
@@ -165,7 +165,8 @@ export const DOWNLOAD_ENDPOINTS = {
|
||||
download: '/api/download-model',
|
||||
downloadGet: '/api/download-model-get',
|
||||
cancelGet: '/api/cancel-download-get',
|
||||
progress: '/api/download-progress'
|
||||
progress: '/api/download-progress',
|
||||
exampleImages: '/api/force-download-example-images' // New endpoint for downloading example images
|
||||
};
|
||||
|
||||
// WebSocket endpoints
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { state, getCurrentPageState } from '../state/index.js';
|
||||
import { showToast, updateFolderTags } from '../utils/uiHelpers.js';
|
||||
import { getSessionItem, saveMapToStorage } from '../utils/storageHelpers.js';
|
||||
import { getStorageItem, getSessionItem, saveMapToStorage } from '../utils/storageHelpers.js';
|
||||
import {
|
||||
getCompleteApiConfig,
|
||||
getCurrentModelType,
|
||||
@@ -435,7 +435,9 @@ export class BaseModelApiClient {
|
||||
}
|
||||
|
||||
await operationComplete;
|
||||
|
||||
|
||||
resetAndReload(false);
|
||||
showToast('Metadata update complete', 'success');
|
||||
} catch (error) {
|
||||
console.error('Error fetching metadata:', error);
|
||||
showToast('Failed to fetch metadata: ' + error.message, 'error');
|
||||
@@ -853,4 +855,102 @@ export class BaseModelApiClient {
|
||||
state.loadingManager.hide();
|
||||
}
|
||||
}
|
||||
|
||||
async downloadExampleImages(modelHashes, modelTypes = null) {
|
||||
let ws = null;
|
||||
|
||||
await state.loadingManager.showWithProgress(async (loading) => {
|
||||
try {
|
||||
// Connect to WebSocket for progress updates
|
||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||
ws = new WebSocket(`${wsProtocol}${window.location.host}${WS_ENDPOINTS.fetchProgress}`);
|
||||
|
||||
const operationComplete = new Promise((resolve, reject) => {
|
||||
ws.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
if (data.type !== 'example_images_progress') return;
|
||||
|
||||
switch(data.status) {
|
||||
case 'running':
|
||||
const percent = ((data.processed / data.total) * 100).toFixed(1);
|
||||
loading.setProgress(percent);
|
||||
loading.setStatus(
|
||||
`Processing (${data.processed}/${data.total}) ${data.current_model || ''}`
|
||||
);
|
||||
break;
|
||||
|
||||
case 'completed':
|
||||
loading.setProgress(100);
|
||||
loading.setStatus(
|
||||
`Completed: Downloaded example images for ${data.processed} models`
|
||||
);
|
||||
resolve();
|
||||
break;
|
||||
|
||||
case 'error':
|
||||
reject(new Error(data.error));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = (error) => {
|
||||
reject(new Error('WebSocket error: ' + error.message));
|
||||
};
|
||||
});
|
||||
|
||||
// Wait for WebSocket connection to establish
|
||||
await new Promise((resolve, reject) => {
|
||||
ws.onopen = resolve;
|
||||
ws.onerror = reject;
|
||||
});
|
||||
|
||||
// Get the output directory from storage
|
||||
const outputDir = getStorageItem('example_images_path', '');
|
||||
if (!outputDir) {
|
||||
throw new Error('Please set the example images path in the settings first.');
|
||||
}
|
||||
|
||||
// Determine optimize setting
|
||||
const optimize = state.global?.settings?.optimizeExampleImages ?? true;
|
||||
|
||||
// Make the API request to start the download process
|
||||
const response = await fetch(DOWNLOAD_ENDPOINTS.exampleImages, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model_hashes: modelHashes,
|
||||
output_dir: outputDir,
|
||||
optimize: optimize,
|
||||
model_types: modelTypes || [this.apiConfig.config.singularName]
|
||||
})
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
throw new Error(errorData.error || 'Failed to download example images');
|
||||
}
|
||||
|
||||
// Wait for the operation to complete via WebSocket
|
||||
await operationComplete;
|
||||
|
||||
showToast('Successfully downloaded example images!', 'success');
|
||||
return true;
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error downloading example images:', error);
|
||||
showToast(`Failed to download example images: ${error.message}`, 'error');
|
||||
throw error;
|
||||
} finally {
|
||||
if (ws) {
|
||||
ws.close();
|
||||
}
|
||||
}
|
||||
}, {
|
||||
initialMessage: 'Starting example images download...',
|
||||
completionMessage: 'Example images download complete'
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
import { showToast, getNSFWLevelName, openExampleImagesFolder } from '../../utils/uiHelpers.js';
|
||||
import { modalManager } from '../../managers/ModalManager.js';
|
||||
import { state } from '../../state/index.js';
|
||||
import { getModelApiClient } from '../../api/modelApiFactory.js';
|
||||
|
||||
// Mixin with shared functionality for LoraContextMenu and CheckpointContextMenu
|
||||
export const ModelContextMenuMixin = {
|
||||
@@ -202,6 +203,9 @@ export const ModelContextMenuMixin = {
|
||||
case 'preview':
|
||||
openExampleImagesFolder(this.currentCard.dataset.sha256);
|
||||
return true;
|
||||
case 'download-examples':
|
||||
this.downloadExampleImages();
|
||||
return true;
|
||||
case 'civitai':
|
||||
if (this.currentCard.dataset.from_civitai === 'true') {
|
||||
if (this.currentCard.querySelector('.fa-globe')) {
|
||||
@@ -222,5 +226,21 @@ export const ModelContextMenuMixin = {
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
},
|
||||
|
||||
// Download example images method
|
||||
async downloadExampleImages() {
|
||||
const modelHash = this.currentCard.dataset.sha256;
|
||||
if (!modelHash) {
|
||||
showToast('Model hash not available', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const apiClient = getModelApiClient();
|
||||
await apiClient.downloadExampleImages([modelHash]);
|
||||
} catch (error) {
|
||||
console.error('Error downloading example images:', error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -3,7 +3,7 @@ import { showToast } from '../utils/uiHelpers.js';
|
||||
import { state, getCurrentPageState } from '../state/index.js';
|
||||
import { formatDate } from '../utils/formatters.js';
|
||||
import { resetAndReload} from '../api/modelApiFactory.js';
|
||||
import { LoadingManager } from '../managers/LoadingManager.js';
|
||||
import { getShowDuplicatesNotification, setShowDuplicatesNotification } from '../utils/storageHelpers.js';
|
||||
|
||||
export class ModelDuplicatesManager {
|
||||
constructor(pageManager, modelType = 'loras') {
|
||||
@@ -12,13 +12,21 @@ export class ModelDuplicatesManager {
|
||||
this.inDuplicateMode = false;
|
||||
this.selectedForDeletion = new Set();
|
||||
this.modelType = modelType; // Use the provided modelType or default to 'loras'
|
||||
|
||||
|
||||
// Verification tracking
|
||||
this.verifiedGroups = new Set(); // Track which groups have been verified
|
||||
this.mismatchedFiles = new Map(); // Map file paths to actual hashes for mismatched files
|
||||
|
||||
// Loading manager for verification process
|
||||
this.loadingManager = new LoadingManager();
|
||||
// Badge visibility preference
|
||||
this.showBadge = getShowDuplicatesNotification(); // Default to true (show badge)
|
||||
|
||||
// Event handler references for cleanup
|
||||
this.badgeToggleHandler = null;
|
||||
this.helpTooltipHandlers = {
|
||||
mouseenter: null,
|
||||
mouseleave: null,
|
||||
click: null
|
||||
};
|
||||
|
||||
// Bind methods
|
||||
this.renderModelCard = this.renderModelCard.bind(this);
|
||||
@@ -66,7 +74,16 @@ export class ModelDuplicatesManager {
|
||||
const badge = document.getElementById('duplicatesBadge');
|
||||
if (!badge) return;
|
||||
|
||||
// Check if badge should be hidden based on user preference
|
||||
if (!this.showBadge && !this.inDuplicateMode) {
|
||||
badge.style.display = 'none';
|
||||
badge.textContent = '';
|
||||
badge.classList.remove('pulse');
|
||||
return;
|
||||
}
|
||||
|
||||
if (count > 0) {
|
||||
badge.style.display = 'inline-flex';
|
||||
badge.textContent = count;
|
||||
badge.classList.add('pulse');
|
||||
} else {
|
||||
@@ -136,6 +153,9 @@ export class ModelDuplicatesManager {
|
||||
|
||||
// Setup help tooltip behavior
|
||||
this.setupHelpTooltip();
|
||||
|
||||
// Setup badge toggle control
|
||||
this.setupBadgeToggle();
|
||||
}
|
||||
|
||||
// Disable virtual scrolling if active
|
||||
@@ -173,6 +193,9 @@ export class ModelDuplicatesManager {
|
||||
const pageState = getCurrentPageState();
|
||||
pageState.duplicatesMode = false;
|
||||
|
||||
// Clean up event handlers before hiding banner
|
||||
this.cleanupEventHandlers();
|
||||
|
||||
// Hide duplicates banner
|
||||
const banner = document.getElementById('duplicatesBanner');
|
||||
if (banner) {
|
||||
@@ -672,7 +695,11 @@ export class ModelDuplicatesManager {
|
||||
|
||||
if (!helpIcon || !helpTooltip) return;
|
||||
|
||||
helpIcon.addEventListener('mouseenter', (e) => {
|
||||
// Clean up existing handlers first
|
||||
this.cleanupHelpTooltipHandlers();
|
||||
|
||||
// Create new handler functions and store references
|
||||
this.helpTooltipHandlers.mouseenter = (e) => {
|
||||
// Get the container's positioning context
|
||||
const bannerContent = helpIcon.closest('.banner-content');
|
||||
|
||||
@@ -693,18 +720,22 @@ export class ModelDuplicatesManager {
|
||||
// Reposition relative to container if too close to right edge
|
||||
helpTooltip.style.left = `${bannerContent.offsetWidth - tooltipRect.width - 20}px`;
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// Rest of the event listeners remain unchanged
|
||||
helpIcon.addEventListener('mouseleave', () => {
|
||||
this.helpTooltipHandlers.mouseleave = () => {
|
||||
helpTooltip.style.display = 'none';
|
||||
});
|
||||
};
|
||||
|
||||
document.addEventListener('click', (e) => {
|
||||
this.helpTooltipHandlers.click = (e) => {
|
||||
if (!helpIcon.contains(e.target)) {
|
||||
helpTooltip.style.display = 'none';
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// Add event listeners
|
||||
helpIcon.addEventListener('mouseenter', this.helpTooltipHandlers.mouseenter);
|
||||
helpIcon.addEventListener('mouseleave', this.helpTooltipHandlers.mouseleave);
|
||||
document.addEventListener('click', this.helpTooltipHandlers.click);
|
||||
}
|
||||
|
||||
// Handle verify hashes button click
|
||||
@@ -719,7 +750,7 @@ export class ModelDuplicatesManager {
|
||||
}
|
||||
|
||||
// Show loading state
|
||||
this.loadingManager.showSimpleLoading('Verifying hashes...');
|
||||
state.loadingManager.showSimpleLoading('Verifying hashes...');
|
||||
|
||||
// Get file paths for all models in the group
|
||||
const filePaths = group.models.map(model => model.file_path);
|
||||
@@ -772,7 +803,87 @@ export class ModelDuplicatesManager {
|
||||
showToast('Failed to verify hashes: ' + error.message, 'error');
|
||||
} finally {
|
||||
// Hide loading state
|
||||
this.loadingManager.hide();
|
||||
state.loadingManager.hide();
|
||||
}
|
||||
}
|
||||
|
||||
// Add this new method for badge toggle setup
|
||||
setupBadgeToggle() {
|
||||
const toggleControl = document.getElementById('badgeToggleControl');
|
||||
const toggleInput = document.getElementById('badgeToggleInput');
|
||||
|
||||
if (!toggleControl || !toggleInput) return;
|
||||
|
||||
// Clean up existing handler first
|
||||
this.cleanupBadgeToggleHandler();
|
||||
|
||||
// Set initial state based on stored preference (default to true/checked)
|
||||
toggleInput.checked = this.showBadge;
|
||||
|
||||
// Create and store the handler function
|
||||
this.badgeToggleHandler = (e) => {
|
||||
this.showBadge = e.target.checked;
|
||||
setShowDuplicatesNotification(this.showBadge);
|
||||
|
||||
// Update badge visibility immediately if not in duplicate mode
|
||||
if (!this.inDuplicateMode) {
|
||||
this.updateDuplicatesBadge(this.duplicateGroups.length);
|
||||
}
|
||||
|
||||
showToast(
|
||||
this.showBadge ? 'Duplicates notification will be shown' : 'Duplicates notification will be hidden',
|
||||
'info'
|
||||
);
|
||||
};
|
||||
|
||||
// Add change event listener
|
||||
toggleInput.addEventListener('change', this.badgeToggleHandler);
|
||||
}
|
||||
|
||||
// Clean up all event handlers
|
||||
cleanupEventHandlers() {
|
||||
this.cleanupBadgeToggleHandler();
|
||||
this.cleanupHelpTooltipHandlers();
|
||||
}
|
||||
|
||||
// Clean up badge toggle event handler
|
||||
cleanupBadgeToggleHandler() {
|
||||
if (this.badgeToggleHandler) {
|
||||
const toggleInput = document.getElementById('badgeToggleInput');
|
||||
if (toggleInput) {
|
||||
toggleInput.removeEventListener('change', this.badgeToggleHandler);
|
||||
}
|
||||
this.badgeToggleHandler = null;
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up help tooltip event handlers
|
||||
cleanupHelpTooltipHandlers() {
|
||||
const helpIcon = document.getElementById('duplicatesHelp');
|
||||
|
||||
if (helpIcon && this.helpTooltipHandlers.mouseenter) {
|
||||
helpIcon.removeEventListener('mouseenter', this.helpTooltipHandlers.mouseenter);
|
||||
}
|
||||
|
||||
if (helpIcon && this.helpTooltipHandlers.mouseleave) {
|
||||
helpIcon.removeEventListener('mouseleave', this.helpTooltipHandlers.mouseleave);
|
||||
}
|
||||
|
||||
if (this.helpTooltipHandlers.click) {
|
||||
document.removeEventListener('click', this.helpTooltipHandlers.click);
|
||||
}
|
||||
|
||||
// Reset handler references
|
||||
this.helpTooltipHandlers = {
|
||||
mouseenter: null,
|
||||
mouseleave: null,
|
||||
click: null
|
||||
};
|
||||
|
||||
// Hide tooltip if it's visible
|
||||
const helpTooltip = document.getElementById('duplicatesHelpTooltip');
|
||||
if (helpTooltip) {
|
||||
helpTooltip.style.display = 'none';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,18 +273,27 @@ function showExampleAccessModal(card, modelType) {
|
||||
if (hasRemoteExamples) {
|
||||
downloadBtn.classList.remove('disabled');
|
||||
downloadBtn.removeAttribute('title');
|
||||
downloadBtn.onclick = () => {
|
||||
downloadBtn.onclick = async () => {
|
||||
// Get the model hash
|
||||
const modelHash = card.dataset.sha256;
|
||||
if (!modelHash) {
|
||||
showToast('Missing model hash information.', 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
// Close the modal
|
||||
modalManager.closeModal('exampleAccessModal');
|
||||
// Open settings modal and scroll to example images section
|
||||
const settingsModal = document.getElementById('settingsModal');
|
||||
if (settingsModal) {
|
||||
modalManager.showModal('settingsModal');
|
||||
setTimeout(() => {
|
||||
const exampleSection = settingsModal.querySelector('.settings-section:nth-child(7)');
|
||||
if (exampleSection) {
|
||||
exampleSection.scrollIntoView({ behavior: 'smooth' });
|
||||
}
|
||||
}, 300);
|
||||
|
||||
try {
|
||||
// Use the appropriate model API client to download examples
|
||||
const apiClient = getModelApiClient(modelType);
|
||||
await apiClient.downloadExampleImages([modelHash]);
|
||||
|
||||
// Open the example images folder if successful
|
||||
openExampleImagesFolder(modelHash);
|
||||
} catch (error) {
|
||||
console.error('Error downloading example images:', error);
|
||||
// Error already shown by the API client
|
||||
}
|
||||
};
|
||||
} else {
|
||||
|
||||
@@ -203,7 +203,6 @@ export class BulkManager {
|
||||
|
||||
toggleCardSelection(card) {
|
||||
const filepath = card.dataset.filepath;
|
||||
const pageState = getCurrentPageState();
|
||||
|
||||
if (card.classList.contains('selected')) {
|
||||
card.classList.remove('selected');
|
||||
|
||||
@@ -254,4 +254,20 @@ export function resetDismissedBanner(bannerId) {
|
||||
const dismissedBanners = getStorageItem('dismissed_banners', []);
|
||||
const updatedBanners = dismissedBanners.filter(id => id !== bannerId);
|
||||
setStorageItem('dismissed_banners', updatedBanners);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the show duplicates notification preference
|
||||
* @returns {boolean} True if notification should be shown (default: true)
|
||||
*/
|
||||
export function getShowDuplicatesNotification() {
|
||||
return getStorageItem('show_duplicates_notification', true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the show duplicates notification preference
|
||||
* @param {boolean} show - Whether to show the notification
|
||||
*/
|
||||
export function setShowDuplicatesNotification(show) {
|
||||
setStorageItem('show_duplicates_notification', show);
|
||||
}
|
||||
@@ -18,6 +18,7 @@
|
||||
<div class="context-menu-item" data-action="relink-civitai"><i class="fas fa-link"></i> Re-link to Civitai</div>
|
||||
<div class="context-menu-item" data-action="copyname"><i class="fas fa-copy"></i> Copy Model Filename</div>
|
||||
<div class="context-menu-item" data-action="preview"><i class="fas fa-folder-open"></i> Open Examples Folder</div>
|
||||
<div class="context-menu-item" data-action="download-examples"><i class="fas fa-download"></i> Download Example Images</div>
|
||||
<div class="context-menu-item" data-action="replace-preview"><i class="fas fa-image"></i> Replace Preview</div>
|
||||
<div class="context-menu-item" data-action="set-nsfw"><i class="fas fa-exclamation-triangle"></i> Set Content Rating</div>
|
||||
<div class="context-menu-separator"></div>
|
||||
@@ -29,27 +30,7 @@
|
||||
|
||||
{% block content %}
|
||||
{% include 'components/controls.html' %}
|
||||
|
||||
<!-- Duplicates banner (hidden by default) -->
|
||||
<div id="duplicatesBanner" class="duplicates-banner" style="display: none;">
|
||||
<div class="banner-content">
|
||||
<i class="fas fa-exclamation-triangle"></i>
|
||||
<span id="duplicatesCount">Found 0 duplicate groups</span>
|
||||
<i class="fas fa-question-circle help-icon" id="duplicatesHelp" aria-label="Help information"></i>
|
||||
<div class="banner-actions">
|
||||
<button class="btn-delete-selected disabled" onclick="modelDuplicatesManager.deleteSelectedDuplicates()">
|
||||
Delete Selected (<span id="duplicatesSelectedCount">0</span>)
|
||||
</button>
|
||||
<button class="btn-exit-mode" onclick="modelDuplicatesManager.exitDuplicateMode()">
|
||||
<i class="fas fa-times"></i> Exit Mode
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="help-tooltip" id="duplicatesHelpTooltip">
|
||||
<p>Identical hashes mean identical model files, even if they have different names or previews.</p>
|
||||
<p>Keep only one version (preferably with better metadata/previews) and safely delete the others.</p>
|
||||
</div>
|
||||
</div>
|
||||
{% include 'components/duplicates_banner.html' %}
|
||||
|
||||
<!-- Checkpoint cards container -->
|
||||
<div class="card-grid" id="modelGrid">
|
||||
|
||||
@@ -23,6 +23,9 @@
|
||||
<div class="context-menu-item" data-action="preview">
|
||||
<i class="fas fa-folder-open"></i> Open Examples Folder
|
||||
</div>
|
||||
<div class="context-menu-item" data-action="download-examples">
|
||||
<i class="fas fa-download"></i> Download Example Images
|
||||
</div>
|
||||
<div class="context-menu-item" data-action="replace-preview">
|
||||
<i class="fas fa-image"></i> Replace Preview
|
||||
</div>
|
||||
|
||||
27
templates/components/duplicates_banner.html
Normal file
27
templates/components/duplicates_banner.html
Normal file
@@ -0,0 +1,27 @@
|
||||
<!-- Duplicates banner (hidden by default) -->
|
||||
<div id="duplicatesBanner" class="duplicates-banner" style="display: none;">
|
||||
<div class="banner-content">
|
||||
<i class="fas fa-exclamation-triangle"></i>
|
||||
<span id="duplicatesCount">Found 0 duplicate groups</span>
|
||||
<i class="fas fa-question-circle help-icon" id="duplicatesHelp" aria-label="Help information"></i>
|
||||
<div class="banner-actions">
|
||||
<div class="setting-contro" id="badgeToggleControl">
|
||||
<span>Show Duplicates Notification:</span>
|
||||
<label class="toggle-switch">
|
||||
<input type="checkbox" id="badgeToggleInput">
|
||||
<span class="toggle-slider"></span>
|
||||
</label>
|
||||
</div>
|
||||
<button class="btn-delete-selected disabled" onclick="modelDuplicatesManager.deleteSelectedDuplicates()">
|
||||
Delete Selected (<span id="duplicatesSelectedCount">0</span>)
|
||||
</button>
|
||||
<button class="btn-exit-mode" onclick="modelDuplicatesManager.exitDuplicateMode()">
|
||||
<i class="fas fa-times"></i> Exit Mode
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="help-tooltip" id="duplicatesHelpTooltip">
|
||||
<p>Identical hashes mean identical model files, even if they have different names or previews.</p>
|
||||
<p>Keep only one version (preferably with better metadata/previews) and safely delete the others.</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -18,6 +18,7 @@
|
||||
<div class="context-menu-item" data-action="relink-civitai"><i class="fas fa-link"></i> Re-link to Civitai</div>
|
||||
<div class="context-menu-item" data-action="copyname"><i class="fas fa-copy"></i> Copy Model Filename</div>
|
||||
<div class="context-menu-item" data-action="preview"><i class="fas fa-folder-open"></i> Open Examples Folder</div>
|
||||
<div class="context-menu-item" data-action="download-examples"><i class="fas fa-download"></i> Download Example Images</div>
|
||||
<div class="context-menu-item" data-action="replace-preview"><i class="fas fa-image"></i> Replace Preview</div>
|
||||
<div class="context-menu-item" data-action="set-nsfw"><i class="fas fa-exclamation-triangle"></i> Set Content Rating</div>
|
||||
<div class="context-menu-separator"></div>
|
||||
@@ -29,27 +30,7 @@
|
||||
|
||||
{% block content %}
|
||||
{% include 'components/controls.html' %}
|
||||
|
||||
<!-- Duplicates banner (hidden by default) -->
|
||||
<div id="duplicatesBanner" class="duplicates-banner" style="display: none;">
|
||||
<div class="banner-content">
|
||||
<i class="fas fa-exclamation-triangle"></i>
|
||||
<span id="duplicatesCount">Found 0 duplicate groups</span>
|
||||
<i class="fas fa-question-circle help-icon" id="duplicatesHelp" aria-label="Help information"></i>
|
||||
<div class="banner-actions">
|
||||
<button class="btn-delete-selected disabled" onclick="modelDuplicatesManager.deleteSelectedDuplicates()">
|
||||
Delete Selected (<span id="duplicatesSelectedCount">0</span>)
|
||||
</button>
|
||||
<button class="btn-exit-mode" onclick="modelDuplicatesManager.exitDuplicateMode()">
|
||||
<i class="fas fa-times"></i> Exit Mode
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="help-tooltip" id="duplicatesHelpTooltip">
|
||||
<p>Identical hashes mean identical model files, even if they have different names or previews.</p>
|
||||
<p>Keep only one version (preferably with better metadata/previews) and safely delete the others.</p>
|
||||
</div>
|
||||
</div>
|
||||
{% include 'components/duplicates_banner.html' %}
|
||||
|
||||
<!-- Embedding cards container -->
|
||||
<div class="card-grid" id="modelGrid">
|
||||
|
||||
@@ -16,27 +16,7 @@
|
||||
{% block content %}
|
||||
{% include 'components/controls.html' %}
|
||||
{% include 'components/alphabet_bar.html' %}
|
||||
|
||||
<!-- Duplicates banner (hidden by default) -->
|
||||
<div id="duplicatesBanner" class="duplicates-banner" style="display: none;">
|
||||
<div class="banner-content">
|
||||
<i class="fas fa-exclamation-triangle"></i>
|
||||
<span id="duplicatesCount">Found 0 duplicate groups</span>
|
||||
<i class="fas fa-question-circle help-icon" id="duplicatesHelp" aria-label="Help information"></i>
|
||||
<div class="banner-actions">
|
||||
<button class="btn-delete-selected disabled" onclick="modelDuplicatesManager.deleteSelectedDuplicates()">
|
||||
Delete Selected (<span id="duplicatesSelectedCount">0</span>)
|
||||
</button>
|
||||
<button class="btn-exit-mode" onclick="modelDuplicatesManager.exitDuplicateMode()">
|
||||
<i class="fas fa-times"></i> Exit Mode
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="help-tooltip" id="duplicatesHelpTooltip">
|
||||
<p>Identical hashes mean identical model files, even if they have different names or previews.</p>
|
||||
<p>Keep only one version (preferably with better metadata/previews) and safely delete the others.</p>
|
||||
</div>
|
||||
</div>
|
||||
{% include 'components/duplicates_banner.html' %}
|
||||
|
||||
<!-- Lora卡片容器 -->
|
||||
<div class="card-grid" id="modelGrid">
|
||||
|
||||
452
web/comfyui/autocomplete.js
Normal file
452
web/comfyui/autocomplete.js
Normal file
@@ -0,0 +1,452 @@
|
||||
import { api } from "../../scripts/api.js";
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { TextAreaCaretHelper } from "./textarea_caret_helper.js";
|
||||
|
||||
class AutoComplete {
|
||||
constructor(inputElement, modelType = 'loras', options = {}) {
|
||||
this.inputElement = inputElement;
|
||||
this.modelType = modelType;
|
||||
this.options = {
|
||||
maxItems: 15,
|
||||
minChars: 1,
|
||||
debounceDelay: 200,
|
||||
showPreview: true,
|
||||
...options
|
||||
};
|
||||
|
||||
this.dropdown = null;
|
||||
this.selectedIndex = -1;
|
||||
this.items = [];
|
||||
this.debounceTimer = null;
|
||||
this.isVisible = false;
|
||||
this.currentSearchTerm = '';
|
||||
this.previewTooltip = null;
|
||||
|
||||
// Initialize TextAreaCaretHelper
|
||||
this.helper = new TextAreaCaretHelper(inputElement, () => app.canvas.ds.scale);
|
||||
|
||||
this.init();
|
||||
}
|
||||
|
||||
init() {
|
||||
this.createDropdown();
|
||||
this.bindEvents();
|
||||
}
|
||||
|
||||
createDropdown() {
|
||||
this.dropdown = document.createElement('div');
|
||||
this.dropdown.className = 'comfy-autocomplete-dropdown';
|
||||
|
||||
// Apply new color scheme
|
||||
this.dropdown.style.cssText = `
|
||||
position: absolute;
|
||||
z-index: 10000;
|
||||
overflow-y: visible;
|
||||
background-color: rgba(40, 44, 52, 0.95);
|
||||
border: 1px solid rgba(226, 232, 240, 0.2);
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
|
||||
display: none;
|
||||
font-family: Arial, sans-serif;
|
||||
font-size: 14px;
|
||||
min-width: 200px;
|
||||
width: auto;
|
||||
backdrop-filter: blur(8px);
|
||||
-webkit-backdrop-filter: blur(8px);
|
||||
`;
|
||||
|
||||
// Custom scrollbar styles with new color scheme
|
||||
const style = document.createElement('style');
|
||||
style.textContent = `
|
||||
.comfy-autocomplete-dropdown::-webkit-scrollbar {
|
||||
width: 8px;
|
||||
}
|
||||
.comfy-autocomplete-dropdown::-webkit-scrollbar-track {
|
||||
background: rgba(40, 44, 52, 0.3);
|
||||
border-radius: 4px;
|
||||
}
|
||||
.comfy-autocomplete-dropdown::-webkit-scrollbar-thumb {
|
||||
background: rgba(226, 232, 240, 0.2);
|
||||
border-radius: 4px;
|
||||
}
|
||||
.comfy-autocomplete-dropdown::-webkit-scrollbar-thumb:hover {
|
||||
background: rgba(226, 232, 240, 0.4);
|
||||
}
|
||||
`;
|
||||
document.head.appendChild(style);
|
||||
|
||||
// Append to body to avoid overflow issues
|
||||
document.body.appendChild(this.dropdown);
|
||||
|
||||
// Initialize preview tooltip if needed
|
||||
if (this.options.showPreview && this.modelType === 'loras') {
|
||||
this.initPreviewTooltip();
|
||||
}
|
||||
}
|
||||
|
||||
initPreviewTooltip() {
|
||||
// Dynamically import and create preview tooltip
|
||||
import('./loras_widget_components.js').then(module => {
|
||||
this.previewTooltip = new module.PreviewTooltip();
|
||||
}).catch(err => {
|
||||
console.warn('Failed to load preview tooltip:', err);
|
||||
});
|
||||
}
|
||||
|
||||
bindEvents() {
|
||||
// Handle input changes
|
||||
this.inputElement.addEventListener('input', (e) => {
|
||||
this.handleInput(e.target.value);
|
||||
});
|
||||
|
||||
// Handle keyboard navigation
|
||||
this.inputElement.addEventListener('keydown', (e) => {
|
||||
this.handleKeyDown(e);
|
||||
});
|
||||
|
||||
// Handle focus out to hide dropdown
|
||||
this.inputElement.addEventListener('blur', (e) => {
|
||||
// Delay hiding to allow for clicks on dropdown items
|
||||
setTimeout(() => {
|
||||
this.hide();
|
||||
}, 150);
|
||||
});
|
||||
|
||||
// Handle clicks outside to hide dropdown
|
||||
document.addEventListener('click', (e) => {
|
||||
if (!this.dropdown.contains(e.target) && e.target !== this.inputElement) {
|
||||
this.hide();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
handleInput(value = '') {
|
||||
// Clear previous debounce timer
|
||||
if (this.debounceTimer) {
|
||||
clearTimeout(this.debounceTimer);
|
||||
}
|
||||
|
||||
// Get the search term (text after last comma)
|
||||
const searchTerm = this.getSearchTerm(value);
|
||||
|
||||
if (searchTerm.length < this.options.minChars) {
|
||||
this.hide();
|
||||
return;
|
||||
}
|
||||
|
||||
// Debounce the search
|
||||
this.debounceTimer = setTimeout(() => {
|
||||
this.search(searchTerm);
|
||||
}, this.options.debounceDelay);
|
||||
}
|
||||
|
||||
getSearchTerm(value) {
|
||||
const lastCommaIndex = value.lastIndexOf(',');
|
||||
if (lastCommaIndex === -1) {
|
||||
return value.trim();
|
||||
}
|
||||
return value.substring(lastCommaIndex + 1).trim();
|
||||
}
|
||||
|
||||
async search(term = '') {
|
||||
try {
|
||||
this.currentSearchTerm = term;
|
||||
const response = await api.fetchApi(`/${this.modelType}/relative-paths?search=${encodeURIComponent(term)}&limit=${this.options.maxItems}`);
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success && data.relative_paths && data.relative_paths.length > 0) {
|
||||
this.items = data.relative_paths;
|
||||
this.render();
|
||||
this.show();
|
||||
} else {
|
||||
this.items = [];
|
||||
this.hide();
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Autocomplete search error:', error);
|
||||
this.items = [];
|
||||
this.hide();
|
||||
}
|
||||
}
|
||||
|
||||
render() {
|
||||
this.dropdown.innerHTML = '';
|
||||
this.selectedIndex = -1;
|
||||
|
||||
// Early return if no items to prevent empty dropdown
|
||||
if (!this.items || this.items.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.items.forEach((relativePath, index) => {
|
||||
const item = document.createElement('div');
|
||||
item.className = 'comfy-autocomplete-item';
|
||||
|
||||
// Create highlighted content
|
||||
const highlightedContent = this.highlightMatch(relativePath, this.currentSearchTerm);
|
||||
item.innerHTML = highlightedContent;
|
||||
|
||||
// Apply item styles with new color scheme
|
||||
item.style.cssText = `
|
||||
padding: 8px 12px;
|
||||
cursor: pointer;
|
||||
color: rgba(226, 232, 240, 0.8);
|
||||
border-bottom: 1px solid rgba(226, 232, 240, 0.1);
|
||||
transition: all 0.2s ease;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
position: relative;
|
||||
`;
|
||||
|
||||
// Hover and selection handlers
|
||||
item.addEventListener('mouseenter', () => {
|
||||
this.selectItem(index);
|
||||
this.showPreviewForItem(relativePath, item);
|
||||
});
|
||||
|
||||
item.addEventListener('mouseleave', () => {
|
||||
this.hidePreview();
|
||||
});
|
||||
|
||||
// Click handler
|
||||
item.addEventListener('click', () => {
|
||||
this.insertSelection(relativePath);
|
||||
});
|
||||
|
||||
this.dropdown.appendChild(item);
|
||||
});
|
||||
|
||||
// Remove border from last item
|
||||
if (this.dropdown.lastChild) {
|
||||
this.dropdown.lastChild.style.borderBottom = 'none';
|
||||
}
|
||||
}
|
||||
|
||||
highlightMatch(text, searchTerm) {
|
||||
if (!searchTerm) return text;
|
||||
|
||||
const regex = new RegExp(`(${searchTerm.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')})`, 'gi');
|
||||
return text.replace(regex, '<span style="background-color: rgba(66, 153, 225, 0.3); color: white; padding: 1px 2px; border-radius: 2px;">$1</span>');
|
||||
}
|
||||
|
||||
showPreviewForItem(relativePath, itemElement) {
|
||||
if (!this.previewTooltip) return;
|
||||
|
||||
// Extract filename without extension for preview
|
||||
const fileName = relativePath.split('/').pop();
|
||||
const loraName = fileName.replace(/\.(safetensors|ckpt|pt|bin)$/i, '');
|
||||
|
||||
// Get item position for tooltip positioning
|
||||
const rect = itemElement.getBoundingClientRect();
|
||||
const x = rect.right + 10;
|
||||
const y = rect.top;
|
||||
|
||||
this.previewTooltip.show(loraName, x, y);
|
||||
}
|
||||
|
||||
hidePreview() {
|
||||
if (this.previewTooltip) {
|
||||
this.previewTooltip.hide();
|
||||
}
|
||||
}
|
||||
|
||||
show() {
|
||||
if (!this.items || this.items.length === 0) {
|
||||
this.hide();
|
||||
return;
|
||||
}
|
||||
|
||||
// Position dropdown at cursor position using TextAreaCaretHelper
|
||||
this.positionAtCursor();
|
||||
this.dropdown.style.display = 'block';
|
||||
this.isVisible = true;
|
||||
}
|
||||
|
||||
positionAtCursor() {
|
||||
const position = this.helper.getCursorOffset();
|
||||
this.dropdown.style.left = (position.left ?? 0) + "px";
|
||||
this.dropdown.style.top = (position.top ?? 0) + "px";
|
||||
this.dropdown.style.maxHeight = (window.innerHeight - position.top) + "px";
|
||||
|
||||
// Adjust width to fit content
|
||||
// Temporarily show the dropdown to measure content width
|
||||
const originalDisplay = this.dropdown.style.display;
|
||||
this.dropdown.style.display = 'block';
|
||||
this.dropdown.style.visibility = 'hidden';
|
||||
|
||||
// Measure the content width
|
||||
let maxWidth = 200; // minimum width
|
||||
const items = this.dropdown.querySelectorAll('.comfy-autocomplete-item');
|
||||
items.forEach(item => {
|
||||
const itemWidth = item.scrollWidth + 24; // Add padding
|
||||
maxWidth = Math.max(maxWidth, itemWidth);
|
||||
});
|
||||
|
||||
// Set the width and restore visibility
|
||||
this.dropdown.style.width = Math.min(maxWidth, 400) + 'px'; // Cap at 400px
|
||||
this.dropdown.style.visibility = 'visible';
|
||||
this.dropdown.style.display = originalDisplay;
|
||||
}
|
||||
|
||||
getCaretPosition() {
|
||||
return this.inputElement.selectionStart || 0;
|
||||
}
|
||||
|
||||
hide() {
|
||||
this.dropdown.style.display = 'none';
|
||||
this.isVisible = false;
|
||||
this.selectedIndex = -1;
|
||||
|
||||
// Hide preview tooltip
|
||||
this.hidePreview();
|
||||
|
||||
// Clear selection styles from all items
|
||||
const items = this.dropdown.querySelectorAll('.comfy-autocomplete-item');
|
||||
items.forEach(item => {
|
||||
item.classList.remove('comfy-autocomplete-item-selected');
|
||||
item.style.backgroundColor = '';
|
||||
});
|
||||
}
|
||||
|
||||
selectItem(index) {
|
||||
// Remove previous selection
|
||||
const prevSelected = this.dropdown.querySelector('.comfy-autocomplete-item-selected');
|
||||
if (prevSelected) {
|
||||
prevSelected.classList.remove('comfy-autocomplete-item-selected');
|
||||
prevSelected.style.backgroundColor = '';
|
||||
}
|
||||
|
||||
// Add new selection
|
||||
if (index >= 0 && index < this.items.length) {
|
||||
this.selectedIndex = index;
|
||||
const item = this.dropdown.children[index];
|
||||
item.classList.add('comfy-autocomplete-item-selected');
|
||||
item.style.backgroundColor = 'rgba(66, 153, 225, 0.2)';
|
||||
|
||||
// Scroll into view if needed
|
||||
item.scrollIntoView({ block: 'nearest' });
|
||||
|
||||
// Show preview for selected item
|
||||
if (this.options.showPreview) {
|
||||
this.showPreviewForItem(this.items[index], item);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
handleKeyDown(e) {
|
||||
if (!this.isVisible) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (e.key) {
|
||||
case 'ArrowDown':
|
||||
e.preventDefault();
|
||||
this.selectItem(Math.min(this.selectedIndex + 1, this.items.length - 1));
|
||||
break;
|
||||
|
||||
case 'ArrowUp':
|
||||
e.preventDefault();
|
||||
this.selectItem(Math.max(this.selectedIndex - 1, 0));
|
||||
break;
|
||||
|
||||
case 'Enter':
|
||||
e.preventDefault();
|
||||
if (this.selectedIndex >= 0 && this.selectedIndex < this.items.length) {
|
||||
this.insertSelection(this.items[this.selectedIndex]);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'Escape':
|
||||
e.preventDefault();
|
||||
this.hide();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
async insertSelection(relativePath) {
|
||||
// Extract just the filename for LoRA name
|
||||
const fileName = relativePath.split('/').pop().replace(/\.(safetensors|ckpt|pt|bin)$/i, '');
|
||||
|
||||
// Get usage tips and extract strength
|
||||
let strength = 1.0; // Default strength
|
||||
try {
|
||||
const response = await api.fetchApi(`/loras/usage-tips-by-path?relative_path=${encodeURIComponent(relativePath)}`);
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
if (data.success && data.usage_tips) {
|
||||
// Parse JSON string and extract strength
|
||||
try {
|
||||
const usageTips = JSON.parse(data.usage_tips);
|
||||
if (usageTips.strength && typeof usageTips.strength === 'number') {
|
||||
strength = usageTips.strength;
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.warn('Failed to parse usage tips JSON:', parseError);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Failed to fetch usage tips:', error);
|
||||
}
|
||||
|
||||
// Format the LoRA code with strength
|
||||
const loraCode = `<lora:${fileName}:${strength}>, `;
|
||||
|
||||
const currentValue = this.inputElement.value;
|
||||
const caretPos = this.getCaretPosition();
|
||||
const lastCommaIndex = currentValue.lastIndexOf(',', caretPos - 1);
|
||||
|
||||
let newValue;
|
||||
let newCaretPos;
|
||||
|
||||
if (lastCommaIndex === -1) {
|
||||
// No comma found before cursor, replace from start or current search term start
|
||||
const searchTerm = this.getSearchTerm(currentValue.substring(0, caretPos));
|
||||
const searchStartPos = caretPos - searchTerm.length;
|
||||
newValue = currentValue.substring(0, searchStartPos) + loraCode + currentValue.substring(caretPos);
|
||||
newCaretPos = searchStartPos + loraCode.length;
|
||||
} else {
|
||||
// Replace text after last comma before cursor
|
||||
const afterCommaPos = lastCommaIndex + 1;
|
||||
// Skip whitespace after comma
|
||||
let insertPos = afterCommaPos;
|
||||
while (insertPos < caretPos && /\s/.test(currentValue[insertPos])) {
|
||||
insertPos++;
|
||||
}
|
||||
|
||||
newValue = currentValue.substring(0, insertPos) + loraCode + currentValue.substring(caretPos);
|
||||
newCaretPos = insertPos + loraCode.length;
|
||||
}
|
||||
|
||||
this.inputElement.value = newValue;
|
||||
|
||||
// Trigger input event to notify about the change
|
||||
const event = new Event('input', { bubbles: true });
|
||||
this.inputElement.dispatchEvent(event);
|
||||
|
||||
this.hide();
|
||||
|
||||
// Focus back to input and position cursor
|
||||
this.inputElement.focus();
|
||||
this.inputElement.setSelectionRange(newCaretPos, newCaretPos);
|
||||
}
|
||||
|
||||
destroy() {
|
||||
if (this.debounceTimer) {
|
||||
clearTimeout(this.debounceTimer);
|
||||
}
|
||||
|
||||
if (this.previewTooltip) {
|
||||
this.previewTooltip.cleanup();
|
||||
}
|
||||
|
||||
if (this.dropdown && this.dropdown.parentNode) {
|
||||
this.dropdown.parentNode.removeChild(this.dropdown);
|
||||
}
|
||||
|
||||
// Remove event listeners would be added here if we tracked them
|
||||
}
|
||||
}
|
||||
|
||||
export { AutoComplete };
|
||||
@@ -5,7 +5,8 @@ import {
|
||||
collectActiveLorasFromChain,
|
||||
updateConnectedTriggerWords,
|
||||
chainCallback,
|
||||
mergeLoras
|
||||
mergeLoras,
|
||||
setupInputWidgetWithAutocomplete
|
||||
} from "./utils.js";
|
||||
import { addLorasWidget } from "./loras_widget.js";
|
||||
|
||||
@@ -144,8 +145,9 @@ app.registerExtension({
|
||||
}
|
||||
);
|
||||
|
||||
// Clean up multiple spaces and trim
|
||||
newText = newText.replace(/\s+/g, " ").trim();
|
||||
// Clean up multiple spaces, extra commas, and trim; remove trailing comma if it's the only content
|
||||
newText = newText.replace(/\s+/g, " ").replace(/,\s*,+/g, ",").trim();
|
||||
if (newText === ",") newText = "";
|
||||
|
||||
inputWidget.value = newText;
|
||||
} finally {
|
||||
@@ -158,7 +160,8 @@ app.registerExtension({
|
||||
const inputWidget = this.widgets[0];
|
||||
inputWidget.options.getMaxHeight = () => 100;
|
||||
this.inputWidget = inputWidget;
|
||||
inputWidget.callback = (value) => {
|
||||
|
||||
const originalCallback = (value) => {
|
||||
if (isUpdating) return;
|
||||
isUpdating = true;
|
||||
|
||||
@@ -172,6 +175,9 @@ app.registerExtension({
|
||||
}
|
||||
};
|
||||
|
||||
// Setup input widget with autocomplete
|
||||
inputWidget.callback = setupInputWidgetWithAutocomplete(this, inputWidget, originalCallback);
|
||||
|
||||
// Register this node with the backend
|
||||
this.registerNode = async () => {
|
||||
try {
|
||||
|
||||
@@ -5,7 +5,8 @@ import {
|
||||
collectActiveLorasFromChain,
|
||||
updateConnectedTriggerWords,
|
||||
chainCallback,
|
||||
mergeLoras
|
||||
mergeLoras,
|
||||
setupInputWidgetWithAutocomplete
|
||||
} from "./utils.js";
|
||||
import { addLorasWidget } from "./loras_widget.js";
|
||||
|
||||
@@ -52,8 +53,9 @@ app.registerExtension({
|
||||
return currentLoras.includes(name) ? match : '';
|
||||
});
|
||||
|
||||
// Clean up multiple spaces and trim
|
||||
newText = newText.replace(/\s+/g, ' ').trim();
|
||||
// Clean up multiple spaces, extra commas, and trim; remove trailing comma if it's the only content
|
||||
newText = newText.replace(/\s+/g, " ").replace(/,\s*,+/g, ",").trim();
|
||||
if (newText === ",") newText = "";
|
||||
|
||||
inputWidget.value = newText;
|
||||
|
||||
@@ -79,7 +81,8 @@ app.registerExtension({
|
||||
const inputWidget = this.widgets[0];
|
||||
inputWidget.options.getMaxHeight = () => 100;
|
||||
this.inputWidget = inputWidget;
|
||||
inputWidget.callback = (value) => {
|
||||
// Wrap the callback with autocomplete setup
|
||||
const originalCallback = (value) => {
|
||||
if (isUpdating) return;
|
||||
isUpdating = true;
|
||||
|
||||
@@ -99,6 +102,7 @@ app.registerExtension({
|
||||
isUpdating = false;
|
||||
}
|
||||
};
|
||||
inputWidget.callback = setupInputWidgetWithAutocomplete(this, inputWidget, originalCallback);
|
||||
|
||||
// Register this node with the backend
|
||||
this.registerNode = async () => {
|
||||
|
||||
@@ -219,18 +219,26 @@ export class PreviewTooltip {
|
||||
display: 'none',
|
||||
overflow: 'hidden',
|
||||
maxWidth: '300px',
|
||||
pointerEvents: 'none', // Prevent interference with autocomplete
|
||||
});
|
||||
document.body.appendChild(this.element);
|
||||
this.hideTimeout = null;
|
||||
this.isFromAutocomplete = false;
|
||||
|
||||
// Add global click event to hide tooltip
|
||||
document.addEventListener('click', () => this.hide());
|
||||
// Modified event listeners for autocomplete compatibility
|
||||
this.globalClickHandler = (e) => {
|
||||
// Don't hide if click is on autocomplete dropdown
|
||||
if (!e.target.closest('.comfy-autocomplete-dropdown')) {
|
||||
this.hide();
|
||||
}
|
||||
};
|
||||
document.addEventListener('click', this.globalClickHandler);
|
||||
|
||||
// Add scroll event listener
|
||||
document.addEventListener('scroll', () => this.hide(), true);
|
||||
this.globalScrollHandler = () => this.hide();
|
||||
document.addEventListener('scroll', this.globalScrollHandler, true);
|
||||
}
|
||||
|
||||
async show(loraName, x, y) {
|
||||
async show(loraName, x, y, fromAutocomplete = false) {
|
||||
try {
|
||||
// Clear previous hide timer
|
||||
if (this.hideTimeout) {
|
||||
@@ -238,8 +246,12 @@ export class PreviewTooltip {
|
||||
this.hideTimeout = null;
|
||||
}
|
||||
|
||||
// Track if this is from autocomplete
|
||||
this.isFromAutocomplete = fromAutocomplete;
|
||||
|
||||
// Don't redisplay the same lora preview
|
||||
if (this.element.style.display === 'block' && this.currentLora === loraName) {
|
||||
this.position(x, y);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -300,7 +312,7 @@ export class PreviewTooltip {
|
||||
left: '0',
|
||||
right: '0',
|
||||
padding: '8px',
|
||||
color: 'rgba(255, 255, 255, 0.95)',
|
||||
color: 'white',
|
||||
fontSize: '13px',
|
||||
fontFamily: "'Inter', 'Segoe UI', system-ui, -apple-system, sans-serif",
|
||||
background: 'linear-gradient(transparent, rgba(0, 0, 0, 0.8))',
|
||||
@@ -349,6 +361,10 @@ export class PreviewTooltip {
|
||||
top = y - rect.height - 10;
|
||||
}
|
||||
|
||||
// Ensure minimum distance from edges
|
||||
left = Math.max(10, Math.min(left, viewportWidth - rect.width - 10));
|
||||
top = Math.max(10, Math.min(top, viewportHeight - rect.height - 10));
|
||||
|
||||
Object.assign(this.element.style, {
|
||||
left: `${left}px`,
|
||||
top: `${top}px`
|
||||
@@ -362,6 +378,7 @@ export class PreviewTooltip {
|
||||
this.hideTimeout = setTimeout(() => {
|
||||
this.element.style.display = 'none';
|
||||
this.currentLora = null;
|
||||
this.isFromAutocomplete = false;
|
||||
// Stop video playback
|
||||
const video = this.element.querySelector('video');
|
||||
if (video) {
|
||||
@@ -376,9 +393,9 @@ export class PreviewTooltip {
|
||||
if (this.hideTimeout) {
|
||||
clearTimeout(this.hideTimeout);
|
||||
}
|
||||
// Remove all event listeners
|
||||
document.removeEventListener('click', () => this.hide());
|
||||
document.removeEventListener('scroll', () => this.hide(), true);
|
||||
// Remove event listeners properly
|
||||
document.removeEventListener('click', this.globalClickHandler);
|
||||
document.removeEventListener('scroll', this.globalScrollHandler, true);
|
||||
this.element.remove();
|
||||
}
|
||||
}
|
||||
|
||||
332
web/comfyui/textarea_caret_helper.js
Normal file
332
web/comfyui/textarea_caret_helper.js
Normal file
@@ -0,0 +1,332 @@
|
||||
/*
|
||||
https://github.com/component/textarea-caret-position
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 Jonathan Ong me@jongleberry.com
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
const getCaretCoordinates = (function () {
|
||||
// We'll copy the properties below into the mirror div.
|
||||
// Note that some browsers, such as Firefox, do not concatenate properties
|
||||
// into their shorthand (e.g. padding-top, padding-bottom etc. -> padding),
|
||||
// so we have to list every single property explicitly.
|
||||
var properties = [
|
||||
"direction", // RTL support
|
||||
"boxSizing",
|
||||
"width", // on Chrome and IE, exclude the scrollbar, so the mirror div wraps exactly as the textarea does
|
||||
"height",
|
||||
"overflowX",
|
||||
"overflowY", // copy the scrollbar for IE
|
||||
|
||||
"borderTopWidth",
|
||||
"borderRightWidth",
|
||||
"borderBottomWidth",
|
||||
"borderLeftWidth",
|
||||
"borderStyle",
|
||||
|
||||
"paddingTop",
|
||||
"paddingRight",
|
||||
"paddingBottom",
|
||||
"paddingLeft",
|
||||
|
||||
// https://developer.mozilla.org/en-US/docs/Web/CSS/font
|
||||
"fontStyle",
|
||||
"fontVariant",
|
||||
"fontWeight",
|
||||
"fontStretch",
|
||||
"fontSize",
|
||||
"fontSizeAdjust",
|
||||
"lineHeight",
|
||||
"fontFamily",
|
||||
|
||||
"textAlign",
|
||||
"textTransform",
|
||||
"textIndent",
|
||||
"textDecoration", // might not make a difference, but better be safe
|
||||
|
||||
"letterSpacing",
|
||||
"wordSpacing",
|
||||
|
||||
"tabSize",
|
||||
"MozTabSize",
|
||||
];
|
||||
|
||||
var isBrowser = typeof window !== "undefined";
|
||||
var isFirefox = isBrowser && window.mozInnerScreenX != null;
|
||||
|
||||
return function getCaretCoordinates(element, position, options) {
|
||||
if (!isBrowser) {
|
||||
throw new Error("textarea-caret-position#getCaretCoordinates should only be called in a browser");
|
||||
}
|
||||
|
||||
var debug = (options && options.debug) || false;
|
||||
if (debug) {
|
||||
var el = document.querySelector("#input-textarea-caret-position-mirror-div");
|
||||
if (el) el.parentNode.removeChild(el);
|
||||
}
|
||||
|
||||
// The mirror div will replicate the textarea's style
|
||||
var div = document.createElement("div");
|
||||
div.id = "input-textarea-caret-position-mirror-div";
|
||||
document.body.appendChild(div);
|
||||
|
||||
var style = div.style;
|
||||
var computed = window.getComputedStyle ? window.getComputedStyle(element) : element.currentStyle; // currentStyle for IE < 9
|
||||
var isInput = element.nodeName === "INPUT";
|
||||
|
||||
// Default textarea styles
|
||||
style.whiteSpace = "pre-wrap";
|
||||
if (!isInput) style.wordWrap = "break-word"; // only for textarea-s
|
||||
|
||||
// Position off-screen
|
||||
style.position = "absolute"; // required to return coordinates properly
|
||||
if (!debug) style.visibility = "hidden"; // not 'display: none' because we want rendering
|
||||
|
||||
// Transfer the element's properties to the div
|
||||
properties.forEach(function (prop) {
|
||||
if (isInput && prop === "lineHeight") {
|
||||
// Special case for <input>s because text is rendered centered and line height may be != height
|
||||
if (computed.boxSizing === "border-box") {
|
||||
var height = parseInt(computed.height);
|
||||
var outerHeight =
|
||||
parseInt(computed.paddingTop) +
|
||||
parseInt(computed.paddingBottom) +
|
||||
parseInt(computed.borderTopWidth) +
|
||||
parseInt(computed.borderBottomWidth);
|
||||
var targetHeight = outerHeight + parseInt(computed.lineHeight);
|
||||
if (height > targetHeight) {
|
||||
style.lineHeight = height - outerHeight + "px";
|
||||
} else if (height === targetHeight) {
|
||||
style.lineHeight = computed.lineHeight;
|
||||
} else {
|
||||
style.lineHeight = 0;
|
||||
}
|
||||
} else {
|
||||
style.lineHeight = computed.height;
|
||||
}
|
||||
} else {
|
||||
style[prop] = computed[prop];
|
||||
}
|
||||
});
|
||||
|
||||
if (isFirefox) {
|
||||
// Firefox lies about the overflow property for textareas: https://bugzilla.mozilla.org/show_bug.cgi?id=984275
|
||||
if (element.scrollHeight > parseInt(computed.height)) style.overflowY = "scroll";
|
||||
} else {
|
||||
style.overflow = "hidden"; // for Chrome to not render a scrollbar; IE keeps overflowY = 'scroll'
|
||||
}
|
||||
|
||||
div.textContent = element.value.substring(0, position);
|
||||
// The second special handling for input type="text" vs textarea:
|
||||
// spaces need to be replaced with non-breaking spaces - http://stackoverflow.com/a/13402035/1269037
|
||||
if (isInput) div.textContent = div.textContent.replace(/\s/g, "\u00a0");
|
||||
|
||||
var span = document.createElement("span");
|
||||
// Wrapping must be replicated *exactly*, including when a long word gets
|
||||
// onto the next line, with whitespace at the end of the line before (#7).
|
||||
// The *only* reliable way to do that is to copy the *entire* rest of the
|
||||
// textarea's content into the <span> created at the caret position.
|
||||
// For inputs, just '.' would be enough, but no need to bother.
|
||||
span.textContent = element.value.substring(position) || "."; // || because a completely empty faux span doesn't render at all
|
||||
div.appendChild(span);
|
||||
|
||||
var coordinates = {
|
||||
top: span.offsetTop + parseInt(computed["borderTopWidth"]),
|
||||
left: span.offsetLeft + parseInt(computed["borderLeftWidth"]),
|
||||
height: parseInt(computed["lineHeight"]),
|
||||
};
|
||||
|
||||
if (debug) {
|
||||
span.style.backgroundColor = "#aaa";
|
||||
} else {
|
||||
document.body.removeChild(div);
|
||||
}
|
||||
|
||||
return coordinates;
|
||||
};
|
||||
})();
|
||||
|
||||
/*
|
||||
Key functions from:
|
||||
https://github.com/yuku/textcomplete
|
||||
© Yuku Takahashi - This software is licensed under the MIT license.
|
||||
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 Jonathan Ong me@jongleberry.com
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
const CHAR_CODE_ZERO = "0".charCodeAt(0);
|
||||
const CHAR_CODE_NINE = "9".charCodeAt(0);
|
||||
|
||||
export class TextAreaCaretHelper {
|
||||
constructor(el, getScale) {
|
||||
this.el = el;
|
||||
this.getScale = getScale;
|
||||
}
|
||||
|
||||
#calculateElementOffset() {
|
||||
const rect = this.el.getBoundingClientRect();
|
||||
const owner = this.el.ownerDocument;
|
||||
if (owner == null) {
|
||||
throw new Error("Given element does not belong to document");
|
||||
}
|
||||
const { defaultView, documentElement } = owner;
|
||||
if (defaultView == null) {
|
||||
throw new Error("Given element does not belong to window");
|
||||
}
|
||||
const offset = {
|
||||
top: rect.top + defaultView.pageYOffset,
|
||||
left: rect.left + defaultView.pageXOffset,
|
||||
};
|
||||
if (documentElement) {
|
||||
offset.top -= documentElement.clientTop;
|
||||
offset.left -= documentElement.clientLeft;
|
||||
}
|
||||
return offset;
|
||||
}
|
||||
|
||||
#isDigit(charCode) {
|
||||
return CHAR_CODE_ZERO <= charCode && charCode <= CHAR_CODE_NINE;
|
||||
}
|
||||
|
||||
#getLineHeightPx() {
|
||||
const computedStyle = getComputedStyle(this.el);
|
||||
const lineHeight = computedStyle.lineHeight;
|
||||
// If the char code starts with a digit, it is either a value in pixels,
|
||||
// or unitless, as per:
|
||||
// https://drafts.csswg.org/css2/visudet.html#propdef-line-height
|
||||
// https://drafts.csswg.org/css2/cascade.html#computed-value
|
||||
if (this.#isDigit(lineHeight.charCodeAt(0))) {
|
||||
const floatLineHeight = parseFloat(lineHeight);
|
||||
// In real browsers the value is *always* in pixels, even for unit-less
|
||||
// line-heights. However, we still check as per the spec.
|
||||
return this.#isDigit(lineHeight.charCodeAt(lineHeight.length - 1))
|
||||
? floatLineHeight * parseFloat(computedStyle.fontSize)
|
||||
: floatLineHeight;
|
||||
}
|
||||
// Otherwise, the value is "normal".
|
||||
// If the line-height is "normal", calculate by font-size
|
||||
return this.#calculateLineHeightPx(this.el.nodeName, computedStyle);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns calculated line-height of the given node in pixels.
|
||||
*/
|
||||
#calculateLineHeightPx(nodeName, computedStyle) {
|
||||
const body = document.body;
|
||||
if (!body) return 0;
|
||||
|
||||
const tempNode = document.createElement(nodeName);
|
||||
tempNode.innerHTML = " ";
|
||||
Object.assign(tempNode.style, {
|
||||
fontSize: computedStyle.fontSize,
|
||||
fontFamily: computedStyle.fontFamily,
|
||||
padding: "0",
|
||||
position: "absolute",
|
||||
});
|
||||
body.appendChild(tempNode);
|
||||
|
||||
// Make sure textarea has only 1 row
|
||||
if (tempNode instanceof HTMLTextAreaElement) {
|
||||
tempNode.rows = 1;
|
||||
}
|
||||
|
||||
// Assume the height of the element is the line-height
|
||||
const height = tempNode.offsetHeight;
|
||||
body.removeChild(tempNode);
|
||||
|
||||
return height;
|
||||
}
|
||||
|
||||
getCursorOffset() {
|
||||
const scale = this.getScale();
|
||||
const elOffset = this.#calculateElementOffset();
|
||||
const elScroll = this.#getElScroll();
|
||||
const cursorPosition = this.#getCursorPosition();
|
||||
const lineHeight = this.#getLineHeightPx();
|
||||
const top = elOffset.top - (elScroll.top * scale) + (cursorPosition.top + lineHeight) * scale;
|
||||
const left = elOffset.left - elScroll.left + cursorPosition.left;
|
||||
const clientTop = this.el.getBoundingClientRect().top;
|
||||
if (this.el.dir !== "rtl") {
|
||||
return { top, left, lineHeight, clientTop };
|
||||
} else {
|
||||
const right = document.documentElement ? document.documentElement.clientWidth - left : 0;
|
||||
return { top, right, lineHeight, clientTop };
|
||||
}
|
||||
}
|
||||
|
||||
#getElScroll() {
|
||||
return { top: this.el.scrollTop, left: this.el.scrollLeft };
|
||||
}
|
||||
|
||||
#getCursorPosition() {
|
||||
return getCaretCoordinates(this.el, this.el.selectionEnd);
|
||||
}
|
||||
|
||||
getBeforeCursor() {
|
||||
return this.el.selectionStart !== this.el.selectionEnd ? null : this.el.value.substring(0, this.el.selectionEnd);
|
||||
}
|
||||
|
||||
getAfterCursor() {
|
||||
return this.el.value.substring(this.el.selectionEnd);
|
||||
}
|
||||
|
||||
insertAtCursor(value, offset, finalOffset) {
|
||||
if (this.el.selectionStart != null) {
|
||||
const startPos = this.el.selectionStart;
|
||||
const endPos = this.el.selectionEnd;
|
||||
|
||||
// Move selection to beginning of offset
|
||||
this.el.selectionStart = this.el.selectionStart + offset;
|
||||
|
||||
// Using execCommand to support undo, but since it's officially
|
||||
// 'deprecated' we need a backup solution, but it won't support undo :(
|
||||
let pasted = true;
|
||||
try {
|
||||
if (!document.execCommand("insertText", false, value)) {
|
||||
pasted = false;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("Error caught during execCommand:", e);
|
||||
pasted = false;
|
||||
}
|
||||
|
||||
if (!pasted) {
|
||||
console.error(
|
||||
"execCommand unsuccessful; not supported. Adding text manually, no undo support.");
|
||||
textarea.setRangeText(modifiedText, this.el.selectionStart, this.el.selectionEnd, 'end');
|
||||
}
|
||||
|
||||
this.el.selectionEnd = this.el.selectionStart = startPos + value.length + offset + (finalOffset ?? 0);
|
||||
} else {
|
||||
// Using execCommand to support undo, but since it's officially
|
||||
// 'deprecated' we need a backup solution, but it won't support undo :(
|
||||
let pasted = true;
|
||||
try {
|
||||
if (!document.execCommand("insertText", false, value)) {
|
||||
pasted = false;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("Error caught during execCommand:", e);
|
||||
pasted = false;
|
||||
}
|
||||
|
||||
if (!pasted) {
|
||||
console.error(
|
||||
"execCommand unsuccessful; not supported. Adding text manually, no undo support.");
|
||||
this.el.value += value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
export const CONVERTED_TYPE = 'converted-widget';
|
||||
import { AutoComplete } from "./autocomplete.js";
|
||||
|
||||
export function chainCallback(object, property, callback) {
|
||||
if (object == undefined) {
|
||||
@@ -226,4 +227,58 @@ export function mergeLoras(lorasText, lorasArr) {
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize autocomplete for an input widget and setup cleanup
|
||||
* @param {Object} node - The node instance
|
||||
* @param {Object} inputWidget - The input widget to add autocomplete to
|
||||
* @param {Function} originalCallback - The original callback function
|
||||
* @returns {Function} Enhanced callback function with autocomplete
|
||||
*/
|
||||
export function setupInputWidgetWithAutocomplete(node, inputWidget, originalCallback) {
|
||||
let autocomplete = null;
|
||||
|
||||
// Enhanced callback that initializes autocomplete and calls original callback
|
||||
const enhancedCallback = (value) => {
|
||||
// Initialize autocomplete on first callback if not already done
|
||||
if (!autocomplete && inputWidget.inputEl) {
|
||||
autocomplete = new AutoComplete(inputWidget.inputEl, 'loras', {
|
||||
maxItems: 15,
|
||||
minChars: 1,
|
||||
debounceDelay: 200
|
||||
});
|
||||
// Store reference for cleanup
|
||||
node.autocomplete = autocomplete;
|
||||
}
|
||||
|
||||
// Call the original callback
|
||||
if (originalCallback) {
|
||||
originalCallback(value);
|
||||
}
|
||||
};
|
||||
|
||||
// Setup cleanup on node removal
|
||||
setupAutocompleteCleanup(node);
|
||||
|
||||
return enhancedCallback;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setup autocomplete cleanup when node is removed
|
||||
* @param {Object} node - The node instance
|
||||
*/
|
||||
export function setupAutocompleteCleanup(node) {
|
||||
// Override onRemoved to cleanup autocomplete
|
||||
const originalOnRemoved = node.onRemoved;
|
||||
node.onRemoved = function() {
|
||||
if (this.autocomplete) {
|
||||
this.autocomplete.destroy();
|
||||
this.autocomplete = null;
|
||||
}
|
||||
|
||||
if (originalOnRemoved) {
|
||||
originalOnRemoved.call(this);
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -4,7 +4,8 @@ import {
|
||||
getActiveLorasFromNode,
|
||||
updateConnectedTriggerWords,
|
||||
chainCallback,
|
||||
mergeLoras
|
||||
mergeLoras,
|
||||
setupInputWidgetWithAutocomplete
|
||||
} from "./utils.js";
|
||||
import { addLorasWidget } from "./loras_widget.js";
|
||||
|
||||
@@ -56,8 +57,9 @@ app.registerExtension({
|
||||
return currentLoras.includes(name) ? match : '';
|
||||
});
|
||||
|
||||
// Clean up multiple spaces and trim
|
||||
newText = newText.replace(/\s+/g, ' ').trim();
|
||||
// Clean up multiple spaces, extra commas, and trim; remove trailing comma if it's the only content
|
||||
newText = newText.replace(/\s+/g, " ").replace(/,\s*,+/g, ",").trim();
|
||||
if (newText === ",") newText = "";
|
||||
|
||||
inputWidget.value = newText;
|
||||
|
||||
@@ -80,7 +82,8 @@ app.registerExtension({
|
||||
const inputWidget = this.widgets[1];
|
||||
inputWidget.options.getMaxHeight = () => 100;
|
||||
this.inputWidget = inputWidget;
|
||||
inputWidget.callback = (value) => {
|
||||
// Wrap the callback with autocomplete setup
|
||||
const originalCallback = (value) => {
|
||||
if (isUpdating) return;
|
||||
isUpdating = true;
|
||||
|
||||
@@ -97,6 +100,7 @@ app.registerExtension({
|
||||
isUpdating = false;
|
||||
}
|
||||
};
|
||||
inputWidget.callback = setupInputWidgetWithAutocomplete(this, inputWidget, originalCallback);
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user