mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 22:22:11 -03:00
Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4e22cd375 | ||
|
|
9bc92736a7 | ||
|
|
111b34d05c | ||
|
|
07d9599a2f | ||
|
|
d8194f211d | ||
|
|
51a6374c33 | ||
|
|
aa6c6035b6 | ||
|
|
44b4a7ffbb | ||
|
|
e5bb018d22 | ||
|
|
79b8a6536e | ||
|
|
3de31cd06a | ||
|
|
c579b54d40 | ||
|
|
0a52575e8b | ||
|
|
23c9a98f66 | ||
|
|
796fc33b5b | ||
|
|
dc4c11ddd2 | ||
|
|
d389e4d5d4 | ||
|
|
8cb78ad931 | ||
|
|
85f987d15c | ||
|
|
b12079e0f6 | ||
|
|
dcf5c6167a | ||
|
|
b395d3f487 | ||
|
|
37662cad10 | ||
|
|
aa1673063d | ||
|
|
f51f49eb60 | ||
|
|
54c9bac961 | ||
|
|
e70fd73bdd | ||
|
|
9bb9e7b64d | ||
|
|
f64c03543a | ||
|
|
51374de1a1 | ||
|
|
afcc12f263 | ||
|
|
88c5482366 | ||
|
|
bbf7295c32 | ||
|
|
ca5e23e68c | ||
|
|
eadb1487ae | ||
|
|
1faa70fc77 | ||
|
|
30d7c007de | ||
|
|
f54f6a4402 | ||
|
|
7b41cdec65 | ||
|
|
fb6a652a57 | ||
|
|
ea34d753c1 |
35
README.md
35
README.md
@@ -20,6 +20,18 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
|||||||
|
|
||||||
## Release Notes
|
## Release Notes
|
||||||
|
|
||||||
|
### v0.8.9
|
||||||
|
* **Favorites System** - New functionality to bookmark your favorite LoRAs and checkpoints for quick access and better organization
|
||||||
|
* **Enhanced UI Controls** - Increased model card button sizes for improved usability and easier interaction
|
||||||
|
* **Smoother Page Transitions** - Optimized interface switching between pages, eliminating flash issues particularly noticeable in dark theme
|
||||||
|
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
|
||||||
|
|
||||||
|
### v0.8.8
|
||||||
|
* **Real-time TriggerWord Updates** - Enhanced TriggerWord Toggle node to instantly update when connected Lora Loader or Lora Stacker nodes change, without requiring workflow execution
|
||||||
|
* **Optimized Metadata Recovery** - Improved utilization of existing .civitai.info files for faster initialization and preservation of metadata from models deleted from CivitAI
|
||||||
|
* **Migration Acceleration** - Further speed improvements for users transitioning from A1111/Forge environments
|
||||||
|
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
|
||||||
|
|
||||||
### v0.8.7
|
### v0.8.7
|
||||||
* **Enhanced Context Menu** - Added comprehensive context menu functionality to Recipes and Checkpoints pages for improved workflow
|
* **Enhanced Context Menu** - Added comprehensive context menu functionality to Recipes and Checkpoints pages for improved workflow
|
||||||
* **Interactive LoRA Strength Control** - Implemented drag functionality in LoRA Loader for intuitive strength adjustment
|
* **Interactive LoRA Strength Control** - Implemented drag functionality in LoRA Loader for intuitive strength adjustment
|
||||||
@@ -140,7 +152,7 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
|||||||
```bash
|
```bash
|
||||||
git clone https://github.com/willmiao/ComfyUI-Lora-Manager.git
|
git clone https://github.com/willmiao/ComfyUI-Lora-Manager.git
|
||||||
cd ComfyUI-Lora-Manager
|
cd ComfyUI-Lora-Manager
|
||||||
pip install requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -163,21 +175,28 @@ pip install requirements.txt
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Thank you for your interest in contributing to ComfyUI LoRA Manager! As this project is currently in its early stages and undergoing rapid development and refactoring, we are temporarily not accepting pull requests.
|
||||||
|
|
||||||
|
However, your feedback and ideas are extremely valuable to us:
|
||||||
|
- Please feel free to open issues for any bugs you encounter
|
||||||
|
- Submit feature requests through GitHub issues
|
||||||
|
- Share your suggestions for improvements
|
||||||
|
|
||||||
|
We appreciate your understanding and look forward to potentially accepting code contributions once the project architecture stabilizes.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Credits
|
## Credits
|
||||||
|
|
||||||
This project has been inspired by and benefited from other excellent ComfyUI extensions:
|
This project has been inspired by and benefited from other excellent ComfyUI extensions:
|
||||||
|
|
||||||
- [ComfyUI-SaveImageWithMetaData](https://github.com/Comfy-Community/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality
|
- [ComfyUI-SaveImageWithMetaData](https://github.com/nkchocoai/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality
|
||||||
- [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) - For the lora loader functionality
|
- [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) - For the lora loader functionality
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Contributing
|
|
||||||
|
|
||||||
If you have suggestions, bug reports, or improvements, feel free to open an issue or contribute directly to the codebase. Pull requests are always welcome!
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## ☕ Support
|
## ☕ Support
|
||||||
|
|
||||||
If you find this project helpful, consider supporting its development:
|
If you find this project helpful, consider supporting its development:
|
||||||
|
|||||||
30
py/config.py
30
py/config.py
@@ -103,21 +103,29 @@ class Config:
|
|||||||
|
|
||||||
def _init_lora_paths(self) -> List[str]:
|
def _init_lora_paths(self) -> List[str]:
|
||||||
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
||||||
paths = sorted(set(path.replace(os.sep, "/")
|
raw_paths = folder_paths.get_folder_paths("loras")
|
||||||
for path in folder_paths.get_folder_paths("loras")
|
|
||||||
if os.path.exists(path)), key=lambda p: p.lower())
|
|
||||||
print("Found LoRA roots:", "\n - " + "\n - ".join(paths))
|
|
||||||
|
|
||||||
if not paths:
|
# Normalize and resolve symlinks, store mapping from resolved -> original
|
||||||
|
path_map = {}
|
||||||
|
for path in raw_paths:
|
||||||
|
if os.path.exists(path):
|
||||||
|
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||||
|
path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||||
|
|
||||||
|
# Now sort and use only the deduplicated real paths
|
||||||
|
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||||
|
print("Found LoRA roots:", "\n - " + "\n - ".join(unique_paths))
|
||||||
|
|
||||||
|
if not unique_paths:
|
||||||
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
||||||
|
|
||||||
# 初始化路径映射
|
for original_path in unique_paths:
|
||||||
for path in paths:
|
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
if real_path != original_path:
|
||||||
if real_path != path:
|
self.add_path_mapping(original_path, real_path)
|
||||||
self.add_path_mapping(path, real_path)
|
|
||||||
|
|
||||||
return paths
|
return unique_paths
|
||||||
|
|
||||||
|
|
||||||
def _init_checkpoint_paths(self) -> List[str]:
|
def _init_checkpoint_paths(self) -> List[str]:
|
||||||
"""Initialize and validate checkpoint paths from ComfyUI settings"""
|
"""Initialize and validate checkpoint paths from ComfyUI settings"""
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ from .routes.lora_routes import LoraRoutes
|
|||||||
from .routes.api_routes import ApiRoutes
|
from .routes.api_routes import ApiRoutes
|
||||||
from .routes.recipe_routes import RecipeRoutes
|
from .routes.recipe_routes import RecipeRoutes
|
||||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||||
|
from .routes.update_routes import UpdateRoutes
|
||||||
|
from .routes.usage_stats_routes import UsageStatsRoutes
|
||||||
from .services.service_registry import ServiceRegistry
|
from .services.service_registry import ServiceRegistry
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -92,6 +94,8 @@ class LoraManager:
|
|||||||
checkpoints_routes.setup_routes(app)
|
checkpoints_routes.setup_routes(app)
|
||||||
ApiRoutes.setup_routes(app)
|
ApiRoutes.setup_routes(app)
|
||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
|
UpdateRoutes.setup_routes(app)
|
||||||
|
UsageStatsRoutes.setup_routes(app) # Register usage stats routes
|
||||||
|
|
||||||
# Schedule service initialization
|
# Schedule service initialization
|
||||||
app.on_startup.append(lambda app: cls._initialize_services())
|
app.on_startup.append(lambda app: cls._initialize_services())
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
"""Constants used by the metadata collector"""
|
"""Constants used by the metadata collector"""
|
||||||
|
|
||||||
# Individual category constants
|
# Metadata collection constants
|
||||||
|
|
||||||
|
# Metadata categories
|
||||||
MODELS = "models"
|
MODELS = "models"
|
||||||
PROMPTS = "prompts"
|
PROMPTS = "prompts"
|
||||||
SAMPLING = "sampling"
|
SAMPLING = "sampling"
|
||||||
LORAS = "loras"
|
LORAS = "loras"
|
||||||
SIZE = "size"
|
SIZE = "size"
|
||||||
IMAGES = "images" # Added new category for image results
|
IMAGES = "images"
|
||||||
|
|
||||||
# Collection of categories for iteration
|
# Complete list of categories to track
|
||||||
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] # Added IMAGES to categories
|
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from ..services.lora_scanner import LoraScanner
|
|||||||
from ..config import config
|
from ..config import config
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from .utils import FlexibleOptionalInputType, any_type
|
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -32,48 +32,6 @@ class LoraManagerLoader:
|
|||||||
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
|
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
|
||||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||||
FUNCTION = "load_loras"
|
FUNCTION = "load_loras"
|
||||||
|
|
||||||
async def get_lora_info(self, lora_name):
|
|
||||||
"""Get the lora path and trigger words from cache"""
|
|
||||||
scanner = await LoraScanner.get_instance()
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
|
|
||||||
for item in cache.raw_data:
|
|
||||||
if item.get('file_name') == lora_name:
|
|
||||||
file_path = item.get('file_path')
|
|
||||||
if file_path:
|
|
||||||
for root in config.loras_roots:
|
|
||||||
root = root.replace(os.sep, '/')
|
|
||||||
if file_path.startswith(root):
|
|
||||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
|
||||||
# Get trigger words from civitai metadata
|
|
||||||
civitai = item.get('civitai', {})
|
|
||||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
|
||||||
return relative_path, trigger_words
|
|
||||||
return lora_name, [] # Fallback if not found
|
|
||||||
|
|
||||||
def extract_lora_name(self, lora_path):
|
|
||||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
|
||||||
# Get the basename without extension
|
|
||||||
basename = os.path.basename(lora_path)
|
|
||||||
return os.path.splitext(basename)[0]
|
|
||||||
|
|
||||||
def _get_loras_list(self, kwargs):
|
|
||||||
"""Helper to extract loras list from either old or new kwargs format"""
|
|
||||||
if 'loras' not in kwargs:
|
|
||||||
return []
|
|
||||||
|
|
||||||
loras_data = kwargs['loras']
|
|
||||||
# Handle new format: {'loras': {'__value__': [...]}}
|
|
||||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
|
||||||
return loras_data['__value__']
|
|
||||||
# Handle old format: {'loras': [...]}
|
|
||||||
elif isinstance(loras_data, list):
|
|
||||||
return loras_data
|
|
||||||
# Unexpected format
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def load_loras(self, model, text, **kwargs):
|
def load_loras(self, model, text, **kwargs):
|
||||||
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
||||||
@@ -89,14 +47,14 @@ class LoraManagerLoader:
|
|||||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||||
|
|
||||||
# Extract lora name for trigger words lookup
|
# Extract lora name for trigger words lookup
|
||||||
lora_name = self.extract_lora_name(lora_path)
|
lora_name = extract_lora_name(lora_path)
|
||||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
|
|
||||||
all_trigger_words.extend(trigger_words)
|
all_trigger_words.extend(trigger_words)
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||||
|
|
||||||
# Then process loras from kwargs with support for both old and new formats
|
# Then process loras from kwargs with support for both old and new formats
|
||||||
loras_list = self._get_loras_list(kwargs)
|
loras_list = get_loras_list(kwargs)
|
||||||
for lora in loras_list:
|
for lora in loras_list:
|
||||||
if not lora.get('active', False):
|
if not lora.get('active', False):
|
||||||
continue
|
continue
|
||||||
@@ -105,7 +63,7 @@ class LoraManagerLoader:
|
|||||||
strength = float(lora['strength'])
|
strength = float(lora['strength'])
|
||||||
|
|
||||||
# Get lora path and trigger words
|
# Get lora path and trigger words
|
||||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
|
|
||||||
# Apply the LoRA using the resolved path
|
# Apply the LoRA using the resolved path
|
||||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)
|
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from ..services.lora_scanner import LoraScanner
|
|||||||
from ..config import config
|
from ..config import config
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from .utils import FlexibleOptionalInputType, any_type
|
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -29,48 +29,6 @@ class LoraStacker:
|
|||||||
RETURN_TYPES = ("LORA_STACK", IO.STRING, IO.STRING)
|
RETURN_TYPES = ("LORA_STACK", IO.STRING, IO.STRING)
|
||||||
RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
|
RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
|
||||||
FUNCTION = "stack_loras"
|
FUNCTION = "stack_loras"
|
||||||
|
|
||||||
async def get_lora_info(self, lora_name):
|
|
||||||
"""Get the lora path and trigger words from cache"""
|
|
||||||
scanner = await LoraScanner.get_instance()
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
|
|
||||||
for item in cache.raw_data:
|
|
||||||
if item.get('file_name') == lora_name:
|
|
||||||
file_path = item.get('file_path')
|
|
||||||
if file_path:
|
|
||||||
for root in config.loras_roots:
|
|
||||||
root = root.replace(os.sep, '/')
|
|
||||||
if file_path.startswith(root):
|
|
||||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
|
||||||
# Get trigger words from civitai metadata
|
|
||||||
civitai = item.get('civitai', {})
|
|
||||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
|
||||||
return relative_path, trigger_words
|
|
||||||
return lora_name, [] # Fallback if not found
|
|
||||||
|
|
||||||
def extract_lora_name(self, lora_path):
|
|
||||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
|
||||||
# Get the basename without extension
|
|
||||||
basename = os.path.basename(lora_path)
|
|
||||||
return os.path.splitext(basename)[0]
|
|
||||||
|
|
||||||
def _get_loras_list(self, kwargs):
|
|
||||||
"""Helper to extract loras list from either old or new kwargs format"""
|
|
||||||
if 'loras' not in kwargs:
|
|
||||||
return []
|
|
||||||
|
|
||||||
loras_data = kwargs['loras']
|
|
||||||
# Handle new format: {'loras': {'__value__': [...]}}
|
|
||||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
|
||||||
return loras_data['__value__']
|
|
||||||
# Handle old format: {'loras': [...]}
|
|
||||||
elif isinstance(loras_data, list):
|
|
||||||
return loras_data
|
|
||||||
# Unexpected format
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def stack_loras(self, text, **kwargs):
|
def stack_loras(self, text, **kwargs):
|
||||||
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
|
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
|
||||||
@@ -84,12 +42,12 @@ class LoraStacker:
|
|||||||
stack.extend(lora_stack)
|
stack.extend(lora_stack)
|
||||||
# Get trigger words from existing stack entries
|
# Get trigger words from existing stack entries
|
||||||
for lora_path, _, _ in lora_stack:
|
for lora_path, _, _ in lora_stack:
|
||||||
lora_name = self.extract_lora_name(lora_path)
|
lora_name = extract_lora_name(lora_path)
|
||||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
all_trigger_words.extend(trigger_words)
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
# Process loras from kwargs with support for both old and new formats
|
# Process loras from kwargs with support for both old and new formats
|
||||||
loras_list = self._get_loras_list(kwargs)
|
loras_list = get_loras_list(kwargs)
|
||||||
for lora in loras_list:
|
for lora in loras_list:
|
||||||
if not lora.get('active', False):
|
if not lora.get('active', False):
|
||||||
continue
|
continue
|
||||||
@@ -99,7 +57,7 @@ class LoraStacker:
|
|||||||
clip_strength = model_strength # Using same strength for both as in the original loader
|
clip_strength = model_strength # Using same strength for both as in the original loader
|
||||||
|
|
||||||
# Get lora path and trigger words
|
# Get lora path and trigger words
|
||||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
|
|
||||||
# Add to stack without loading
|
# Add to stack without loading
|
||||||
# replace '/' with os.sep to avoid different OS path format
|
# replace '/' with os.sep to avoid different OS path format
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import re
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import folder_paths # type: ignore
|
import folder_paths # type: ignore
|
||||||
from ..services.lora_scanner import LoraScanner
|
from ..services.lora_scanner import LoraScanner
|
||||||
|
from ..services.checkpoint_scanner import CheckpointScanner
|
||||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||||
from ..metadata_collector import get_metadata
|
from ..metadata_collector import get_metadata
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
@@ -53,18 +54,55 @@ class SaveImage:
|
|||||||
async def get_lora_hash(self, lora_name):
|
async def get_lora_hash(self, lora_name):
|
||||||
"""Get the lora hash from cache"""
|
"""Get the lora hash from cache"""
|
||||||
scanner = await LoraScanner.get_instance()
|
scanner = await LoraScanner.get_instance()
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
|
|
||||||
|
# Use the new direct filename lookup method
|
||||||
|
hash_value = scanner.get_hash_by_filename(lora_name)
|
||||||
|
if hash_value:
|
||||||
|
return hash_value
|
||||||
|
|
||||||
|
# Fallback to old method for compatibility
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
for item in cache.raw_data:
|
for item in cache.raw_data:
|
||||||
if item.get('file_name') == lora_name:
|
if item.get('file_name') == lora_name:
|
||||||
return item.get('sha256')
|
return item.get('sha256')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_checkpoint_hash(self, checkpoint_path):
|
||||||
|
"""Get the checkpoint hash from cache"""
|
||||||
|
scanner = await CheckpointScanner.get_instance()
|
||||||
|
|
||||||
|
if not checkpoint_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract basename without extension
|
||||||
|
checkpoint_name = os.path.basename(checkpoint_path)
|
||||||
|
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||||
|
|
||||||
|
# Try direct filename lookup first
|
||||||
|
hash_value = scanner.get_hash_by_filename(checkpoint_name)
|
||||||
|
if hash_value:
|
||||||
|
return hash_value
|
||||||
|
|
||||||
|
# Fallback to old method for compatibility
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
normalized_path = checkpoint_path.replace('\\', '/')
|
||||||
|
|
||||||
|
for item in cache.raw_data:
|
||||||
|
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
|
||||||
|
return item.get('sha256')
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def format_metadata(self, metadata_dict):
|
async def format_metadata(self, metadata_dict):
|
||||||
"""Format metadata in the requested format similar to userComment example"""
|
"""Format metadata in the requested format similar to userComment example"""
|
||||||
if not metadata_dict:
|
if not metadata_dict:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
# Helper function to only add parameter if value is not None
|
||||||
|
def add_param_if_not_none(param_list, label, value):
|
||||||
|
if value is not None:
|
||||||
|
param_list.append(f"{label}: {value}")
|
||||||
|
|
||||||
# Extract the prompt and negative prompt
|
# Extract the prompt and negative prompt
|
||||||
prompt = metadata_dict.get('prompt', '')
|
prompt = metadata_dict.get('prompt', '')
|
||||||
negative_prompt = metadata_dict.get('negative_prompt', '')
|
negative_prompt = metadata_dict.get('negative_prompt', '')
|
||||||
@@ -100,7 +138,11 @@ class SaveImage:
|
|||||||
|
|
||||||
# Add standard parameters in the correct order
|
# Add standard parameters in the correct order
|
||||||
if 'steps' in metadata_dict:
|
if 'steps' in metadata_dict:
|
||||||
params.append(f"Steps: {metadata_dict.get('steps')}")
|
add_param_if_not_none(params, "Steps", metadata_dict.get('steps'))
|
||||||
|
|
||||||
|
# Combine sampler and scheduler information
|
||||||
|
sampler_name = None
|
||||||
|
scheduler_name = None
|
||||||
|
|
||||||
if 'sampler' in metadata_dict:
|
if 'sampler' in metadata_dict:
|
||||||
sampler = metadata_dict.get('sampler')
|
sampler = metadata_dict.get('sampler')
|
||||||
@@ -123,7 +165,6 @@ class SaveImage:
|
|||||||
'ddim': 'DDIM'
|
'ddim': 'DDIM'
|
||||||
}
|
}
|
||||||
sampler_name = sampler_mapping.get(sampler, sampler)
|
sampler_name = sampler_mapping.get(sampler, sampler)
|
||||||
params.append(f"Sampler: {sampler_name}")
|
|
||||||
|
|
||||||
if 'scheduler' in metadata_dict:
|
if 'scheduler' in metadata_dict:
|
||||||
scheduler = metadata_dict.get('scheduler')
|
scheduler = metadata_dict.get('scheduler')
|
||||||
@@ -135,38 +176,48 @@ class SaveImage:
|
|||||||
'sgm_quadratic': 'SGM Quadratic'
|
'sgm_quadratic': 'SGM Quadratic'
|
||||||
}
|
}
|
||||||
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
|
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
|
||||||
params.append(f"Schedule type: {scheduler_name}")
|
|
||||||
|
|
||||||
# CFG scale (cfg_scale in metadata_dict)
|
# Add combined sampler and scheduler information
|
||||||
if 'cfg_scale' in metadata_dict:
|
if sampler_name:
|
||||||
params.append(f"CFG scale: {metadata_dict.get('cfg_scale')}")
|
if scheduler_name:
|
||||||
|
params.append(f"Sampler: {sampler_name} {scheduler_name}")
|
||||||
|
else:
|
||||||
|
params.append(f"Sampler: {sampler_name}")
|
||||||
|
|
||||||
|
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
|
||||||
|
if 'guidance' in metadata_dict:
|
||||||
|
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance'))
|
||||||
|
elif 'cfg_scale' in metadata_dict:
|
||||||
|
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale'))
|
||||||
elif 'cfg' in metadata_dict:
|
elif 'cfg' in metadata_dict:
|
||||||
params.append(f"CFG scale: {metadata_dict.get('cfg')}")
|
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg'))
|
||||||
|
|
||||||
# Seed
|
# Seed
|
||||||
if 'seed' in metadata_dict:
|
if 'seed' in metadata_dict:
|
||||||
params.append(f"Seed: {metadata_dict.get('seed')}")
|
add_param_if_not_none(params, "Seed", metadata_dict.get('seed'))
|
||||||
|
|
||||||
# Size
|
# Size
|
||||||
if 'size' in metadata_dict:
|
if 'size' in metadata_dict:
|
||||||
params.append(f"Size: {metadata_dict.get('size')}")
|
add_param_if_not_none(params, "Size", metadata_dict.get('size'))
|
||||||
|
|
||||||
# Model info
|
# Model info
|
||||||
if 'checkpoint' in metadata_dict:
|
if 'checkpoint' in metadata_dict:
|
||||||
# Ensure checkpoint is a string before processing
|
# Ensure checkpoint is a string before processing
|
||||||
checkpoint = metadata_dict.get('checkpoint')
|
checkpoint = metadata_dict.get('checkpoint')
|
||||||
if checkpoint is not None:
|
if checkpoint is not None:
|
||||||
# Handle both string and other types safely
|
# Get model hash
|
||||||
if isinstance(checkpoint, str):
|
model_hash = await self.get_checkpoint_hash(checkpoint)
|
||||||
# Extract basename without path
|
|
||||||
checkpoint = os.path.basename(checkpoint)
|
|
||||||
# Remove extension if present
|
|
||||||
checkpoint = os.path.splitext(checkpoint)[0]
|
|
||||||
else:
|
|
||||||
# Convert non-string to string
|
|
||||||
checkpoint = str(checkpoint)
|
|
||||||
|
|
||||||
params.append(f"Model: {checkpoint}")
|
# Extract basename without path
|
||||||
|
checkpoint_name = os.path.basename(checkpoint)
|
||||||
|
# Remove extension if present
|
||||||
|
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||||
|
|
||||||
|
# Add model hash if available
|
||||||
|
if model_hash:
|
||||||
|
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
|
||||||
|
else:
|
||||||
|
params.append(f"Model: {checkpoint_name}")
|
||||||
|
|
||||||
# Add LoRA hashes if available
|
# Add LoRA hashes if available
|
||||||
if lora_hashes:
|
if lora_hashes:
|
||||||
@@ -284,7 +335,7 @@ class SaveImage:
|
|||||||
if add_counter_to_filename:
|
if add_counter_to_filename:
|
||||||
# Use counter + i to ensure unique filenames for all images in batch
|
# Use counter + i to ensure unique filenames for all images in batch
|
||||||
current_counter = counter + i
|
current_counter = counter + i
|
||||||
base_filename += f"_{current_counter:05}"
|
base_filename += f"_{current_counter:05}_"
|
||||||
|
|
||||||
# Set file extension and prepare saving parameters
|
# Set file extension and prepare saving parameters
|
||||||
if file_format == "png":
|
if file_format == "png":
|
||||||
|
|||||||
@@ -47,10 +47,10 @@ class TriggerWordToggle:
|
|||||||
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
|
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
|
||||||
|
|
||||||
# Send trigger words to frontend
|
# Send trigger words to frontend
|
||||||
PromptServer.instance.send_sync("trigger_word_update", {
|
# PromptServer.instance.send_sync("trigger_word_update", {
|
||||||
"id": id,
|
# "id": id,
|
||||||
"message": trigger_words
|
# "message": trigger_words
|
||||||
})
|
# })
|
||||||
|
|
||||||
filtered_triggers = trigger_words
|
filtered_triggers = trigger_words
|
||||||
|
|
||||||
|
|||||||
@@ -30,4 +30,55 @@ class FlexibleOptionalInputType(dict):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
any_type = AnyType("*")
|
any_type = AnyType("*")
|
||||||
|
|
||||||
|
# Common methods extracted from lora_loader.py and lora_stacker.py
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from ..services.lora_scanner import LoraScanner
|
||||||
|
from ..config import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def get_lora_info(lora_name):
|
||||||
|
"""Get the lora path and trigger words from cache"""
|
||||||
|
scanner = await LoraScanner.get_instance()
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
|
||||||
|
for item in cache.raw_data:
|
||||||
|
if item.get('file_name') == lora_name:
|
||||||
|
file_path = item.get('file_path')
|
||||||
|
if file_path:
|
||||||
|
for root in config.loras_roots:
|
||||||
|
root = root.replace(os.sep, '/')
|
||||||
|
if file_path.startswith(root):
|
||||||
|
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||||
|
# Get trigger words from civitai metadata
|
||||||
|
civitai = item.get('civitai', {})
|
||||||
|
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||||
|
return relative_path, trigger_words
|
||||||
|
return lora_name, [] # Fallback if not found
|
||||||
|
|
||||||
|
def extract_lora_name(lora_path):
|
||||||
|
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||||
|
# Get the basename without extension
|
||||||
|
basename = os.path.basename(lora_path)
|
||||||
|
return os.path.splitext(basename)[0]
|
||||||
|
|
||||||
|
def get_loras_list(kwargs):
|
||||||
|
"""Helper to extract loras list from either old or new kwargs format"""
|
||||||
|
if 'loras' not in kwargs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
loras_data = kwargs['loras']
|
||||||
|
# Handle new format: {'loras': {'__value__': [...]}}
|
||||||
|
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||||
|
return loras_data['__value__']
|
||||||
|
# Handle old format: {'loras': [...]}
|
||||||
|
elif isinstance(loras_data, list):
|
||||||
|
return loras_data
|
||||||
|
# Unexpected format
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||||
|
return []
|
||||||
@@ -3,8 +3,10 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
from server import PromptServer # type: ignore
|
||||||
|
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
|
from ..nodes.utils import get_lora_info
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..services.websocket_manager import ws_manager
|
from ..services.websocket_manager import ws_manager
|
||||||
@@ -64,6 +66,9 @@ class ApiRoutes:
|
|||||||
app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL
|
app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL
|
||||||
app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files
|
app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files
|
||||||
app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files
|
app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files
|
||||||
|
|
||||||
|
# Add the new trigger words route
|
||||||
|
app.router.add_post('/loramanager/get_trigger_words', routes.get_trigger_words)
|
||||||
|
|
||||||
# Add update check routes
|
# Add update check routes
|
||||||
UpdateRoutes.setup_routes(app)
|
UpdateRoutes.setup_routes(app)
|
||||||
@@ -120,6 +125,7 @@ class ApiRoutes:
|
|||||||
# Get filter parameters
|
# Get filter parameters
|
||||||
base_models = request.query.get('base_models', None)
|
base_models = request.query.get('base_models', None)
|
||||||
tags = request.query.get('tags', None)
|
tags = request.query.get('tags', None)
|
||||||
|
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # New parameter
|
||||||
|
|
||||||
# New parameters for recipe filtering
|
# New parameters for recipe filtering
|
||||||
lora_hash = request.query.get('lora_hash', None)
|
lora_hash = request.query.get('lora_hash', None)
|
||||||
@@ -150,7 +156,8 @@ class ApiRoutes:
|
|||||||
base_models=filters.get('base_model', None),
|
base_models=filters.get('base_model', None),
|
||||||
tags=filters.get('tags', None),
|
tags=filters.get('tags', None),
|
||||||
search_options=search_options,
|
search_options=search_options,
|
||||||
hash_filters=hash_filters
|
hash_filters=hash_filters,
|
||||||
|
favorites_only=favorites_only # Pass favorites_only parameter
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all available folders from cache
|
# Get all available folders from cache
|
||||||
@@ -190,6 +197,7 @@ class ApiRoutes:
|
|||||||
"from_civitai": lora.get("from_civitai", True),
|
"from_civitai": lora.get("from_civitai", True),
|
||||||
"usage_tips": lora.get("usage_tips", ""),
|
"usage_tips": lora.get("usage_tips", ""),
|
||||||
"notes": lora.get("notes", ""),
|
"notes": lora.get("notes", ""),
|
||||||
|
"favorite": lora.get("favorite", False), # Include favorite status in response
|
||||||
"civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {}))
|
"civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1021,4 +1029,35 @@ class ApiRoutes:
|
|||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': False,
|
'success': False,
|
||||||
'error': str(e)
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get trigger words for specified LoRA models"""
|
||||||
|
try:
|
||||||
|
json_data = await request.json()
|
||||||
|
lora_names = json_data.get("lora_names", [])
|
||||||
|
node_ids = json_data.get("node_ids", [])
|
||||||
|
|
||||||
|
all_trigger_words = []
|
||||||
|
for lora_name in lora_names:
|
||||||
|
_, trigger_words = await get_lora_info(lora_name)
|
||||||
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
|
# Format the trigger words
|
||||||
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
|
||||||
|
# Send update to all connected trigger word toggle nodes
|
||||||
|
for node_id in node_ids:
|
||||||
|
PromptServer.instance.send_sync("trigger_word_update", {
|
||||||
|
"id": node_id,
|
||||||
|
"message": trigger_words_text
|
||||||
|
})
|
||||||
|
|
||||||
|
return web.json_response({"success": True})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting trigger words: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
}, status=500)
|
}, status=500)
|
||||||
@@ -69,6 +69,7 @@ class CheckpointsRoutes:
|
|||||||
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
||||||
base_models = request.query.getall('base_model', [])
|
base_models = request.query.getall('base_model', [])
|
||||||
tags = request.query.getall('tag', [])
|
tags = request.query.getall('tag', [])
|
||||||
|
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # Add favorites_only parameter
|
||||||
|
|
||||||
# Process search options
|
# Process search options
|
||||||
search_options = {
|
search_options = {
|
||||||
@@ -101,7 +102,8 @@ class CheckpointsRoutes:
|
|||||||
base_models=base_models,
|
base_models=base_models,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
search_options=search_options,
|
search_options=search_options,
|
||||||
hash_filters=hash_filters
|
hash_filters=hash_filters,
|
||||||
|
favorites_only=favorites_only # Pass favorites_only parameter
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format response items
|
# Format response items
|
||||||
@@ -123,7 +125,8 @@ class CheckpointsRoutes:
|
|||||||
async def get_paginated_data(self, page, page_size, sort_by='name',
|
async def get_paginated_data(self, page, page_size, sort_by='name',
|
||||||
folder=None, search=None, fuzzy_search=False,
|
folder=None, search=None, fuzzy_search=False,
|
||||||
base_models=None, tags=None,
|
base_models=None, tags=None,
|
||||||
search_options=None, hash_filters=None):
|
search_options=None, hash_filters=None,
|
||||||
|
favorites_only=False): # Add favorites_only parameter with default False
|
||||||
"""Get paginated and filtered checkpoint data"""
|
"""Get paginated and filtered checkpoint data"""
|
||||||
cache = await self.scanner.get_cached_data()
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
@@ -181,6 +184,13 @@ class CheckpointsRoutes:
|
|||||||
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Apply favorites filtering if enabled
|
||||||
|
if favorites_only:
|
||||||
|
filtered_data = [
|
||||||
|
cp for cp in filtered_data
|
||||||
|
if cp.get('favorite', False) is True
|
||||||
|
]
|
||||||
|
|
||||||
# Apply folder filtering
|
# Apply folder filtering
|
||||||
if folder is not None:
|
if folder is not None:
|
||||||
if search_options.get('recursive', False):
|
if search_options.get('recursive', False):
|
||||||
@@ -276,6 +286,7 @@ class CheckpointsRoutes:
|
|||||||
"from_civitai": checkpoint.get("from_civitai", True),
|
"from_civitai": checkpoint.get("from_civitai", True),
|
||||||
"notes": checkpoint.get("notes", ""),
|
"notes": checkpoint.get("notes", ""),
|
||||||
"model_type": checkpoint.get("model_type", "checkpoint"),
|
"model_type": checkpoint.get("model_type", "checkpoint"),
|
||||||
|
"favorite": checkpoint.get("favorite", False),
|
||||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
|
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -74,6 +74,9 @@ class RecipeRoutes:
|
|||||||
|
|
||||||
# Add route to get recipes for a specific Lora
|
# Add route to get recipes for a specific Lora
|
||||||
app.router.add_get('/api/recipes/for-lora', routes.get_recipes_for_lora)
|
app.router.add_get('/api/recipes/for-lora', routes.get_recipes_for_lora)
|
||||||
|
|
||||||
|
# Add new endpoint for scanning and rebuilding the recipe cache
|
||||||
|
app.router.add_get('/api/recipes/scan', routes.scan_recipes)
|
||||||
|
|
||||||
async def _init_cache(self, app):
|
async def _init_cache(self, app):
|
||||||
"""Initialize cache on startup"""
|
"""Initialize cache on startup"""
|
||||||
@@ -1255,3 +1258,24 @@ class RecipeRoutes:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting recipes for Lora: {str(e)}")
|
logger.error(f"Error getting recipes for Lora: {str(e)}")
|
||||||
return web.json_response({'success': False, 'error': str(e)}, status=500)
|
return web.json_response({'success': False, 'error': str(e)}, status=500)
|
||||||
|
|
||||||
|
async def scan_recipes(self, request: web.Request) -> web.Response:
|
||||||
|
"""API endpoint for scanning and rebuilding the recipe cache"""
|
||||||
|
try:
|
||||||
|
# Ensure services are initialized
|
||||||
|
await self.init_services()
|
||||||
|
|
||||||
|
# Force refresh the recipe cache
|
||||||
|
logger.info("Manually triggering recipe cache rebuild")
|
||||||
|
await self.recipe_scanner.get_cached_data(force_refresh=True)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'message': 'Recipe cache refreshed successfully'
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error refreshing recipe cache: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|||||||
69
py/routes/usage_stats_routes.py
Normal file
69
py/routes/usage_stats_routes.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import logging
|
||||||
|
from aiohttp import web
|
||||||
|
from ..utils.usage_stats import UsageStats
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class UsageStatsRoutes:
|
||||||
|
"""Routes for handling usage statistics updates"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def setup_routes(app):
|
||||||
|
"""Register usage stats routes"""
|
||||||
|
app.router.add_post('/loras/api/update-usage-stats', UsageStatsRoutes.update_usage_stats)
|
||||||
|
app.router.add_get('/loras/api/get-usage-stats', UsageStatsRoutes.get_usage_stats)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def update_usage_stats(request):
|
||||||
|
"""
|
||||||
|
Update usage statistics based on a prompt_id
|
||||||
|
|
||||||
|
Expects a JSON body with:
|
||||||
|
{
|
||||||
|
"prompt_id": "string"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Parse the request body
|
||||||
|
data = await request.json()
|
||||||
|
prompt_id = data.get('prompt_id')
|
||||||
|
|
||||||
|
if not prompt_id:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Missing prompt_id'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Call the UsageStats to process this prompt_id synchronously
|
||||||
|
usage_stats = UsageStats()
|
||||||
|
await usage_stats.process_execution(prompt_id)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update usage stats: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_usage_stats(request):
|
||||||
|
"""Get current usage statistics"""
|
||||||
|
try:
|
||||||
|
usage_stats = UsageStats()
|
||||||
|
stats = await usage_stats.get_stats()
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'data': stats
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get usage stats: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
26
py/server_routes.py
Normal file
26
py/server_routes.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from aiohttp import web
|
||||||
|
from server import PromptServer
|
||||||
|
from .nodes.utils import get_lora_info
|
||||||
|
|
||||||
|
@PromptServer.instance.routes.post("/loramanager/get_trigger_words")
|
||||||
|
async def get_trigger_words(request):
|
||||||
|
json_data = await request.json()
|
||||||
|
lora_names = json_data.get("lora_names", [])
|
||||||
|
node_ids = json_data.get("node_ids", [])
|
||||||
|
|
||||||
|
all_trigger_words = []
|
||||||
|
for lora_name in lora_names:
|
||||||
|
_, trigger_words = await get_lora_info(lora_name)
|
||||||
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
|
# Format the trigger words
|
||||||
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
|
||||||
|
# Send update to all connected trigger word toggle nodes
|
||||||
|
for node_id in node_ids:
|
||||||
|
PromptServer.instance.send_sync("trigger_word_update", {
|
||||||
|
"id": node_id,
|
||||||
|
"message": trigger_words_text
|
||||||
|
})
|
||||||
|
|
||||||
|
return web.json_response({"success": True})
|
||||||
@@ -34,6 +34,7 @@ class CivitaiClient:
|
|||||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
|
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
|
||||||
}
|
}
|
||||||
self._session = None
|
self._session = None
|
||||||
|
self._session_created_at = None
|
||||||
# Set default buffer size to 1MB for higher throughput
|
# Set default buffer size to 1MB for higher throughput
|
||||||
self.chunk_size = 1024 * 1024
|
self.chunk_size = 1024 * 1024
|
||||||
|
|
||||||
@@ -44,8 +45,8 @@ class CivitaiClient:
|
|||||||
# Optimize TCP connection parameters
|
# Optimize TCP connection parameters
|
||||||
connector = aiohttp.TCPConnector(
|
connector = aiohttp.TCPConnector(
|
||||||
ssl=True,
|
ssl=True,
|
||||||
limit=10, # Increase parallel connections
|
limit=3, # Further reduced from 5 to 3
|
||||||
ttl_dns_cache=300, # DNS cache time
|
ttl_dns_cache=0, # Disabled DNS caching completely
|
||||||
force_close=False, # Keep connections for reuse
|
force_close=False, # Keep connections for reuse
|
||||||
enable_cleanup_closed=True
|
enable_cleanup_closed=True
|
||||||
)
|
)
|
||||||
@@ -57,7 +58,18 @@ class CivitaiClient:
|
|||||||
trust_env=trust_env,
|
trust_env=trust_env,
|
||||||
timeout=timeout
|
timeout=timeout
|
||||||
)
|
)
|
||||||
|
self._session_created_at = datetime.now()
|
||||||
return self._session
|
return self._session
|
||||||
|
|
||||||
|
async def _ensure_fresh_session(self):
|
||||||
|
"""Refresh session if it's been open too long"""
|
||||||
|
if self._session is not None:
|
||||||
|
if not hasattr(self, '_session_created_at') or \
|
||||||
|
(datetime.now() - self._session_created_at).total_seconds() > 300: # 5 minutes
|
||||||
|
await self.close()
|
||||||
|
self._session = None
|
||||||
|
|
||||||
|
return await self.session
|
||||||
|
|
||||||
def _parse_content_disposition(self, header: str) -> str:
|
def _parse_content_disposition(self, header: str) -> str:
|
||||||
"""Parse filename from content-disposition header"""
|
"""Parse filename from content-disposition header"""
|
||||||
@@ -103,13 +115,15 @@ class CivitaiClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str]: (success, save_path or error message)
|
Tuple[bool, str]: (success, save_path or error message)
|
||||||
"""
|
"""
|
||||||
session = await self.session
|
logger.debug(f"Resolving DNS for: {url}")
|
||||||
|
session = await self._ensure_fresh_session()
|
||||||
try:
|
try:
|
||||||
headers = self._get_request_headers()
|
headers = self._get_request_headers()
|
||||||
|
|
||||||
# Add Range header to allow resumable downloads
|
# Add Range header to allow resumable downloads
|
||||||
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
|
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
|
||||||
|
|
||||||
|
logger.debug(f"Starting download from: {url}")
|
||||||
async with session.get(url, headers=headers, allow_redirects=True) as response:
|
async with session.get(url, headers=headers, allow_redirects=True) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
# Handle 401 unauthorized responses
|
# Handle 401 unauthorized responses
|
||||||
@@ -124,6 +138,7 @@ class CivitaiClient:
|
|||||||
return False, "Access forbidden: You don't have permission to download this file."
|
return False, "Access forbidden: You don't have permission to download this file."
|
||||||
|
|
||||||
# Generic error response for other status codes
|
# Generic error response for other status codes
|
||||||
|
logger.error(f"Download failed for {url} with status {response.status}")
|
||||||
return False, f"Download failed with status {response.status}"
|
return False, f"Download failed with status {response.status}"
|
||||||
|
|
||||||
# Get filename from content-disposition header
|
# Get filename from content-disposition header
|
||||||
@@ -170,7 +185,7 @@ class CivitaiClient:
|
|||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||||
try:
|
try:
|
||||||
session = await self.session
|
session = await self._ensure_fresh_session()
|
||||||
async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
|
async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
return await response.json()
|
return await response.json()
|
||||||
@@ -181,7 +196,7 @@ class CivitaiClient:
|
|||||||
|
|
||||||
async def download_preview_image(self, image_url: str, save_path: str):
|
async def download_preview_image(self, image_url: str, save_path: str):
|
||||||
try:
|
try:
|
||||||
session = await self.session
|
session = await self._ensure_fresh_session()
|
||||||
async with session.get(image_url) as response:
|
async with session.get(image_url) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
content = await response.read()
|
content = await response.read()
|
||||||
@@ -196,7 +211,7 @@ class CivitaiClient:
|
|||||||
async def get_model_versions(self, model_id: str) -> List[Dict]:
|
async def get_model_versions(self, model_id: str) -> List[Dict]:
|
||||||
"""Get all versions of a model with local availability info"""
|
"""Get all versions of a model with local availability info"""
|
||||||
try:
|
try:
|
||||||
session = await self.session # 等待获取 session
|
session = await self._ensure_fresh_session() # Use fresh session
|
||||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
return None
|
return None
|
||||||
@@ -222,12 +237,14 @@ class CivitaiClient:
|
|||||||
- An error message if there was an error, or None on success
|
- An error message if there was an error, or None on success
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = await self.session
|
session = await self._ensure_fresh_session()
|
||||||
url = f"{self.base_url}/model-versions/{version_id}"
|
url = f"{self.base_url}/model-versions/{version_id}"
|
||||||
headers = self._get_request_headers()
|
headers = self._get_request_headers()
|
||||||
|
|
||||||
|
logger.debug(f"Resolving DNS for model version info: {url}")
|
||||||
async with session.get(url, headers=headers) as response:
|
async with session.get(url, headers=headers) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
|
logger.debug(f"Successfully fetched model version info for: {version_id}")
|
||||||
return await response.json(), None
|
return await response.json(), None
|
||||||
|
|
||||||
# Handle specific error cases
|
# Handle specific error cases
|
||||||
@@ -242,6 +259,7 @@ class CivitaiClient:
|
|||||||
return None, "Model not found (status 404)"
|
return None, "Model not found (status 404)"
|
||||||
|
|
||||||
# Other error cases
|
# Other error cases
|
||||||
|
logger.error(f"Failed to fetch model info for {version_id} (status {response.status})")
|
||||||
return None, f"Failed to fetch model info (status {response.status})"
|
return None, f"Failed to fetch model info (status {response.status})"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error fetching model version info: {e}"
|
error_msg = f"Error fetching model version info: {e}"
|
||||||
@@ -260,7 +278,7 @@ class CivitaiClient:
|
|||||||
- The HTTP status code from the request
|
- The HTTP status code from the request
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = await self.session
|
session = await self._ensure_fresh_session()
|
||||||
headers = self._get_request_headers()
|
headers = self._get_request_headers()
|
||||||
url = f"{self.base_url}/models/{model_id}"
|
url = f"{self.base_url}/models/{model_id}"
|
||||||
|
|
||||||
@@ -304,10 +322,11 @@ class CivitaiClient:
|
|||||||
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
|
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
|
||||||
"""Get hash from Civitai API"""
|
"""Get hash from Civitai API"""
|
||||||
try:
|
try:
|
||||||
if not self._session:
|
session = await self._ensure_fresh_session()
|
||||||
|
if not session:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
version_info = await self._session.get(f"{self.base_url}/model-versions/{model_version_id}")
|
version_info = await session.get(f"{self.base_url}/model-versions/{model_version_id}")
|
||||||
|
|
||||||
if not version_info or not version_info.json().get('files'):
|
if not version_info or not version_info.json().get('files'):
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -88,16 +88,16 @@ class DownloadManager:
|
|||||||
version_info = None
|
version_info = None
|
||||||
error_msg = None
|
error_msg = None
|
||||||
|
|
||||||
if download_url:
|
if model_hash:
|
||||||
# Extract version ID from download URL
|
# Get model by hash
|
||||||
version_id = download_url.split('/')[-1]
|
version_info = await civitai_client.get_model_by_hash(model_hash)
|
||||||
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
|
|
||||||
elif model_version_id:
|
elif model_version_id:
|
||||||
# Use model version ID directly
|
# Use model version ID directly
|
||||||
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
|
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
|
||||||
elif model_hash:
|
elif download_url:
|
||||||
# Get model by hash
|
# Extract version ID from download URL
|
||||||
version_info = await civitai_client.get_model_by_hash(model_hash)
|
version_id = download_url.split('/')[-1]
|
||||||
|
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
|
||||||
|
|
||||||
|
|
||||||
if not version_info:
|
if not version_info:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import List, Dict, Optional, Set
|
|||||||
from ..utils.models import LoraMetadata
|
from ..utils.models import LoraMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from .model_scanner import ModelScanner
|
from .model_scanner import ModelScanner
|
||||||
from .lora_hash_index import LoraHashIndex
|
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
||||||
from .settings_manager import settings
|
from .settings_manager import settings
|
||||||
from ..utils.constants import NSFW_LEVELS
|
from ..utils.constants import NSFW_LEVELS
|
||||||
from ..utils.utils import fuzzy_match
|
from ..utils.utils import fuzzy_match
|
||||||
@@ -35,12 +35,12 @@ class LoraScanner(ModelScanner):
|
|||||||
# Define supported file extensions
|
# Define supported file extensions
|
||||||
file_extensions = {'.safetensors'}
|
file_extensions = {'.safetensors'}
|
||||||
|
|
||||||
# Initialize parent class
|
# Initialize parent class with ModelHashIndex
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_type="lora",
|
model_type="lora",
|
||||||
model_class=LoraMetadata,
|
model_class=LoraMetadata,
|
||||||
file_extensions=file_extensions,
|
file_extensions=file_extensions,
|
||||||
hash_index=LoraHashIndex()
|
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||||
)
|
)
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
@@ -122,7 +122,8 @@ class LoraScanner(ModelScanner):
|
|||||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||||
base_models: list = None, tags: list = None,
|
base_models: list = None, tags: list = None,
|
||||||
search_options: dict = None, hash_filters: dict = None) -> Dict:
|
search_options: dict = None, hash_filters: dict = None,
|
||||||
|
favorites_only: bool = False) -> Dict:
|
||||||
"""Get paginated and filtered lora data
|
"""Get paginated and filtered lora data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -136,6 +137,7 @@ class LoraScanner(ModelScanner):
|
|||||||
tags: List of tags to filter by
|
tags: List of tags to filter by
|
||||||
search_options: Dictionary with search options (filename, modelname, tags, recursive)
|
search_options: Dictionary with search options (filename, modelname, tags, recursive)
|
||||||
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
|
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
|
||||||
|
favorites_only: Filter for favorite models only
|
||||||
"""
|
"""
|
||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
|
|
||||||
@@ -194,6 +196,13 @@ class LoraScanner(ModelScanner):
|
|||||||
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Apply favorites filtering if enabled
|
||||||
|
if favorites_only:
|
||||||
|
filtered_data = [
|
||||||
|
lora for lora in filtered_data
|
||||||
|
if lora.get('favorite', False) is True
|
||||||
|
]
|
||||||
|
|
||||||
# Apply folder filtering
|
# Apply folder filtering
|
||||||
if folder is not None:
|
if folder is not None:
|
||||||
if search_options.get('recursive', False):
|
if search_options.get('recursive', False):
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
from typing import Dict, Optional, Set
|
from typing import Dict, Optional, Set
|
||||||
|
import os
|
||||||
|
|
||||||
class ModelHashIndex:
|
class ModelHashIndex:
|
||||||
"""Index for looking up models by hash or path"""
|
"""Index for looking up models by hash or path"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._hash_to_path: Dict[str, str] = {}
|
self._hash_to_path: Dict[str, str] = {}
|
||||||
self._path_to_hash: Dict[str, str] = {}
|
self._filename_to_hash: Dict[str, str] = {} # Changed from path_to_hash to filename_to_hash
|
||||||
|
|
||||||
def add_entry(self, sha256: str, file_path: str) -> None:
|
def add_entry(self, sha256: str, file_path: str) -> None:
|
||||||
"""Add or update hash index entry"""
|
"""Add or update hash index entry"""
|
||||||
@@ -15,37 +16,47 @@ class ModelHashIndex:
|
|||||||
# Ensure hash is lowercase for consistency
|
# Ensure hash is lowercase for consistency
|
||||||
sha256 = sha256.lower()
|
sha256 = sha256.lower()
|
||||||
|
|
||||||
|
# Extract filename without extension
|
||||||
|
filename = self._get_filename_from_path(file_path)
|
||||||
|
|
||||||
# Remove old path mapping if hash exists
|
# Remove old path mapping if hash exists
|
||||||
if sha256 in self._hash_to_path:
|
if sha256 in self._hash_to_path:
|
||||||
old_path = self._hash_to_path[sha256]
|
old_path = self._hash_to_path[sha256]
|
||||||
if old_path in self._path_to_hash:
|
old_filename = self._get_filename_from_path(old_path)
|
||||||
del self._path_to_hash[old_path]
|
if old_filename in self._filename_to_hash:
|
||||||
|
del self._filename_to_hash[old_filename]
|
||||||
|
|
||||||
# Remove old hash mapping if path exists
|
# Remove old hash mapping if filename exists
|
||||||
if file_path in self._path_to_hash:
|
if filename in self._filename_to_hash:
|
||||||
old_hash = self._path_to_hash[file_path]
|
old_hash = self._filename_to_hash[filename]
|
||||||
if old_hash in self._hash_to_path:
|
if old_hash in self._hash_to_path:
|
||||||
del self._hash_to_path[old_hash]
|
del self._hash_to_path[old_hash]
|
||||||
|
|
||||||
# Add new mappings
|
# Add new mappings
|
||||||
self._hash_to_path[sha256] = file_path
|
self._hash_to_path[sha256] = file_path
|
||||||
self._path_to_hash[file_path] = sha256
|
self._filename_to_hash[filename] = sha256
|
||||||
|
|
||||||
|
def _get_filename_from_path(self, file_path: str) -> str:
|
||||||
|
"""Extract filename without extension from path"""
|
||||||
|
return os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
||||||
def remove_by_path(self, file_path: str) -> None:
|
def remove_by_path(self, file_path: str) -> None:
|
||||||
"""Remove entry by file path"""
|
"""Remove entry by file path"""
|
||||||
if file_path in self._path_to_hash:
|
filename = self._get_filename_from_path(file_path)
|
||||||
hash_val = self._path_to_hash[file_path]
|
if filename in self._filename_to_hash:
|
||||||
|
hash_val = self._filename_to_hash[filename]
|
||||||
if hash_val in self._hash_to_path:
|
if hash_val in self._hash_to_path:
|
||||||
del self._hash_to_path[hash_val]
|
del self._hash_to_path[hash_val]
|
||||||
del self._path_to_hash[file_path]
|
del self._filename_to_hash[filename]
|
||||||
|
|
||||||
def remove_by_hash(self, sha256: str) -> None:
|
def remove_by_hash(self, sha256: str) -> None:
|
||||||
"""Remove entry by hash"""
|
"""Remove entry by hash"""
|
||||||
sha256 = sha256.lower()
|
sha256 = sha256.lower()
|
||||||
if sha256 in self._hash_to_path:
|
if sha256 in self._hash_to_path:
|
||||||
path = self._hash_to_path[sha256]
|
path = self._hash_to_path[sha256]
|
||||||
if path in self._path_to_hash:
|
filename = self._get_filename_from_path(path)
|
||||||
del self._path_to_hash[path]
|
if filename in self._filename_to_hash:
|
||||||
|
del self._filename_to_hash[filename]
|
||||||
del self._hash_to_path[sha256]
|
del self._hash_to_path[sha256]
|
||||||
|
|
||||||
def has_hash(self, sha256: str) -> bool:
|
def has_hash(self, sha256: str) -> bool:
|
||||||
@@ -58,20 +69,27 @@ class ModelHashIndex:
|
|||||||
|
|
||||||
def get_hash(self, file_path: str) -> Optional[str]:
|
def get_hash(self, file_path: str) -> Optional[str]:
|
||||||
"""Get hash for a file path"""
|
"""Get hash for a file path"""
|
||||||
return self._path_to_hash.get(file_path)
|
filename = self._get_filename_from_path(file_path)
|
||||||
|
return self._filename_to_hash.get(filename)
|
||||||
|
|
||||||
|
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||||
|
"""Get hash for a filename without extension"""
|
||||||
|
# Strip extension if present to make the function more flexible
|
||||||
|
filename = os.path.splitext(filename)[0]
|
||||||
|
return self._filename_to_hash.get(filename)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear all entries"""
|
"""Clear all entries"""
|
||||||
self._hash_to_path.clear()
|
self._hash_to_path.clear()
|
||||||
self._path_to_hash.clear()
|
self._filename_to_hash.clear()
|
||||||
|
|
||||||
def get_all_hashes(self) -> Set[str]:
|
def get_all_hashes(self) -> Set[str]:
|
||||||
"""Get all hashes in the index"""
|
"""Get all hashes in the index"""
|
||||||
return set(self._hash_to_path.keys())
|
return set(self._hash_to_path.keys())
|
||||||
|
|
||||||
def get_all_paths(self) -> Set[str]:
|
def get_all_filenames(self) -> Set[str]:
|
||||||
"""Get all file paths in the index"""
|
"""Get all filenames in the index"""
|
||||||
return set(self._path_to_hash.keys())
|
return set(self._filename_to_hash.keys())
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Get number of entries"""
|
"""Get number of entries"""
|
||||||
|
|||||||
@@ -292,7 +292,7 @@ class ModelScanner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If force refresh is requested, initialize the cache directly
|
# If force refresh is requested, initialize the cache directly
|
||||||
if force_refresh:
|
if (force_refresh):
|
||||||
if self._cache is None:
|
if self._cache is None:
|
||||||
# For initial creation, do a full initialization
|
# For initial creation, do a full initialization
|
||||||
await self._initialize_cache()
|
await self._initialize_cache()
|
||||||
@@ -553,9 +553,36 @@ class ModelScanner:
|
|||||||
logger.debug(f"Created metadata from .civitai.info for {file_path}")
|
logger.debug(f"Created metadata from .civitai.info for {file_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
|
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
|
||||||
|
else:
|
||||||
|
# Check if metadata exists but civitai field is empty - try to restore from civitai.info
|
||||||
|
if metadata.civitai is None or metadata.civitai == {}:
|
||||||
|
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
|
||||||
|
if os.path.exists(civitai_info_path):
|
||||||
|
try:
|
||||||
|
with open(civitai_info_path, 'r', encoding='utf-8') as f:
|
||||||
|
version_info = json.load(f)
|
||||||
|
|
||||||
|
logger.debug(f"Restoring missing civitai data from .civitai.info for {file_path}")
|
||||||
|
metadata.civitai = version_info
|
||||||
|
|
||||||
|
# Ensure tags are also updated if they're missing
|
||||||
|
if (not metadata.tags or len(metadata.tags) == 0) and 'model' in version_info:
|
||||||
|
if 'tags' in version_info['model']:
|
||||||
|
metadata.tags = version_info['model']['tags']
|
||||||
|
|
||||||
|
# Also restore description if missing
|
||||||
|
if (not metadata.modelDescription or metadata.modelDescription == "") and 'model' in version_info:
|
||||||
|
if 'description' in version_info['model']:
|
||||||
|
metadata.modelDescription = version_info['model']['description']
|
||||||
|
|
||||||
|
# Save the updated metadata
|
||||||
|
await save_metadata(file_path, metadata)
|
||||||
|
logger.debug(f"Updated metadata with civitai info for {file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error restoring civitai data from .civitai.info for {file_path}: {e}")
|
||||||
|
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = await self._get_file_info(file_path)
|
metadata = await self._get_file_info(file_path)
|
||||||
|
|
||||||
model_data = metadata.to_dict()
|
model_data = metadata.to_dict()
|
||||||
|
|
||||||
@@ -709,6 +736,12 @@ class ModelScanner:
|
|||||||
shutil.move(source_metadata, target_metadata)
|
shutil.move(source_metadata, target_metadata)
|
||||||
metadata = await self._update_metadata_paths(target_metadata, target_file)
|
metadata = await self._update_metadata_paths(target_metadata, target_file)
|
||||||
|
|
||||||
|
# Move civitai.info file if exists
|
||||||
|
source_civitai = os.path.join(source_dir, f"{base_name}.civitai.info")
|
||||||
|
if os.path.exists(source_civitai):
|
||||||
|
target_civitai = os.path.join(target_path, f"{base_name}.civitai.info")
|
||||||
|
shutil.move(source_civitai, target_civitai)
|
||||||
|
|
||||||
for ext in PREVIEW_EXTENSIONS:
|
for ext in PREVIEW_EXTENSIONS:
|
||||||
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
|
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
|
||||||
if os.path.exists(source_preview):
|
if os.path.exists(source_preview):
|
||||||
@@ -805,6 +838,10 @@ class ModelScanner:
|
|||||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||||
"""Get hash for a model by its file path"""
|
"""Get hash for a model by its file path"""
|
||||||
return self._hash_index.get_hash(file_path)
|
return self._hash_index.get_hash(file_path)
|
||||||
|
|
||||||
|
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||||
|
"""Get hash for a model by its filename without path"""
|
||||||
|
return self._hash_index.get_hash_by_filename(filename)
|
||||||
|
|
||||||
# TODO: Adjust this method to use metadata instead of finding the file
|
# TODO: Adjust this method to use metadata instead of finding the file
|
||||||
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
|
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ class BaseModelMetadata:
|
|||||||
civitai: Optional[Dict] = None # Civitai API data if available
|
civitai: Optional[Dict] = None # Civitai API data if available
|
||||||
tags: List[str] = None # Model tags
|
tags: List[str] = None # Model tags
|
||||||
modelDescription: str = "" # Full model description
|
modelDescription: str = "" # Full model description
|
||||||
|
civitai_deleted: bool = False # Whether deleted from Civitai
|
||||||
|
favorite: bool = False # Whether the model is a favorite
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Initialize empty lists to avoid mutable default parameter issue
|
# Initialize empty lists to avoid mutable default parameter issue
|
||||||
@@ -64,6 +66,15 @@ class LoraMetadata(BaseModelMetadata):
|
|||||||
file_name = file_info['name']
|
file_name = file_info['name']
|
||||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||||
|
|
||||||
|
# Extract tags and description if available
|
||||||
|
tags = []
|
||||||
|
description = ""
|
||||||
|
if 'model' in version_info:
|
||||||
|
if 'tags' in version_info['model']:
|
||||||
|
tags = version_info['model']['tags']
|
||||||
|
if 'description' in version_info['model']:
|
||||||
|
description = version_info['model']['description']
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
file_name=os.path.splitext(file_name)[0],
|
file_name=os.path.splitext(file_name)[0],
|
||||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||||
@@ -75,7 +86,9 @@ class LoraMetadata(BaseModelMetadata):
|
|||||||
preview_url=None, # Will be updated after preview download
|
preview_url=None, # Will be updated after preview download
|
||||||
preview_nsfw_level=0, # Will be updated after preview download
|
preview_nsfw_level=0, # Will be updated after preview download
|
||||||
from_civitai=True,
|
from_civitai=True,
|
||||||
civitai=version_info
|
civitai=version_info,
|
||||||
|
tags=tags,
|
||||||
|
modelDescription=description
|
||||||
)
|
)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -90,6 +103,15 @@ class CheckpointMetadata(BaseModelMetadata):
|
|||||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||||
model_type = version_info.get('type', 'checkpoint')
|
model_type = version_info.get('type', 'checkpoint')
|
||||||
|
|
||||||
|
# Extract tags and description if available
|
||||||
|
tags = []
|
||||||
|
description = ""
|
||||||
|
if 'model' in version_info:
|
||||||
|
if 'tags' in version_info['model']:
|
||||||
|
tags = version_info['model']['tags']
|
||||||
|
if 'description' in version_info['model']:
|
||||||
|
description = version_info['model']['description']
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
file_name=os.path.splitext(file_name)[0],
|
file_name=os.path.splitext(file_name)[0],
|
||||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||||
@@ -102,6 +124,8 @@ class CheckpointMetadata(BaseModelMetadata):
|
|||||||
preview_nsfw_level=0,
|
preview_nsfw_level=0,
|
||||||
from_civitai=True,
|
from_civitai=True,
|
||||||
civitai=version_info,
|
civitai=version_info,
|
||||||
model_type=model_type
|
model_type=model_type,
|
||||||
|
tags=tags,
|
||||||
|
modelDescription=description
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -45,14 +45,14 @@ class RecipeMetadataParser(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info: Dict[str, Any],
|
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
|
||||||
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Dict[str, Any]:
|
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Populate a lora entry with information from Civitai API response
|
Populate a lora entry with information from Civitai API response
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
lora_entry: The lora entry to populate
|
lora_entry: The lora entry to populate
|
||||||
civitai_info: The response from Civitai API
|
civitai_info_tuple: The response tuple from Civitai API (data, error_msg)
|
||||||
recipe_scanner: Optional recipe scanner for local file lookup
|
recipe_scanner: Optional recipe scanner for local file lookup
|
||||||
base_model_counts: Optional dict to track base model counts
|
base_model_counts: Optional dict to track base model counts
|
||||||
hash_value: Optional hash value to use if not available in civitai_info
|
hash_value: Optional hash value to use if not available in civitai_info
|
||||||
@@ -61,6 +61,9 @@ class RecipeMetadataParser(ABC):
|
|||||||
The populated lora_entry dict
|
The populated lora_entry dict
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Unpack the tuple to get the actual data
|
||||||
|
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||||
|
|
||||||
if civitai_info and civitai_info.get("error") != "Model not found":
|
if civitai_info and civitai_info.get("error") != "Model not found":
|
||||||
# Check if this is an early access lora
|
# Check if this is an early access lora
|
||||||
if civitai_info.get('earlyAccessEndsAt'):
|
if civitai_info.get('earlyAccessEndsAt'):
|
||||||
@@ -94,8 +97,9 @@ class RecipeMetadataParser(ABC):
|
|||||||
|
|
||||||
# Process file information if available
|
# Process file information if available
|
||||||
if 'files' in civitai_info:
|
if 'files' in civitai_info:
|
||||||
|
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||||
model_file = next((file for file in civitai_info.get('files', [])
|
model_file = next((file for file in civitai_info.get('files', [])
|
||||||
if file.get('type') == 'Model'), None)
|
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||||
|
|
||||||
if model_file:
|
if model_file:
|
||||||
# Get size
|
# Get size
|
||||||
@@ -241,11 +245,11 @@ class RecipeFormatParser(RecipeMetadataParser):
|
|||||||
# Try to get additional info from Civitai if we have a model version ID
|
# Try to get additional info from Civitai if we have a model version ID
|
||||||
if lora.get('modelVersionId') and civitai_client:
|
if lora.get('modelVersionId') and civitai_client:
|
||||||
try:
|
try:
|
||||||
civitai_info = await civitai_client.get_model_version_info(lora['modelVersionId'])
|
civitai_info_tuple = await civitai_client.get_model_version_info(lora['modelVersionId'])
|
||||||
# Populate lora entry with Civitai info
|
# Populate lora entry with Civitai info
|
||||||
lora_entry = await self.populate_lora_from_civitai(
|
lora_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info_tuple,
|
||||||
recipe_scanner,
|
recipe_scanner,
|
||||||
None, # No need to track base model counts
|
None, # No need to track base model counts
|
||||||
lora['hash']
|
lora['hash']
|
||||||
@@ -336,12 +340,13 @@ class StandardMetadataParser(RecipeMetadataParser):
|
|||||||
# Get additional info from Civitai if client is available
|
# Get additional info from Civitai if client is available
|
||||||
if civitai_client:
|
if civitai_client:
|
||||||
try:
|
try:
|
||||||
civitai_info = await civitai_client.get_model_version_info(model_version_id)
|
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
|
||||||
# Populate lora entry with Civitai info
|
# Populate lora entry with Civitai info
|
||||||
lora_entry = await self.populate_lora_from_civitai(
|
lora_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info_tuple,
|
||||||
recipe_scanner
|
recipe_scanner,
|
||||||
|
base_model_counts
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Civitai info for LoRA: {e}")
|
logger.error(f"Error fetching Civitai info for LoRA: {e}")
|
||||||
@@ -621,11 +626,11 @@ class ComfyMetadataParser(RecipeMetadataParser):
|
|||||||
# Get additional info from Civitai if client is available
|
# Get additional info from Civitai if client is available
|
||||||
if civitai_client:
|
if civitai_client:
|
||||||
try:
|
try:
|
||||||
civitai_info = await civitai_client.get_model_version_info(model_version_id)
|
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
|
||||||
# Populate lora entry with Civitai info
|
# Populate lora entry with Civitai info
|
||||||
lora_entry = await self.populate_lora_from_civitai(
|
lora_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info_tuple,
|
||||||
recipe_scanner
|
recipe_scanner
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -660,7 +665,8 @@ class ComfyMetadataParser(RecipeMetadataParser):
|
|||||||
# Get additional checkpoint info from Civitai
|
# Get additional checkpoint info from Civitai
|
||||||
if civitai_client:
|
if civitai_client:
|
||||||
try:
|
try:
|
||||||
civitai_info = await civitai_client.get_model_version_info(checkpoint_version_id)
|
civitai_info_tuple = await civitai_client.get_model_version_info(checkpoint_version_id)
|
||||||
|
civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||||
# Populate checkpoint with Civitai info
|
# Populate checkpoint with Civitai info
|
||||||
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)
|
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
267
py/utils/usage_stats.py
Normal file
267
py/utils/usage_stats.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Set
|
||||||
|
|
||||||
|
from ..config import config
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
from ..metadata_collector.metadata_registry import MetadataRegistry
|
||||||
|
from ..metadata_collector.constants import MODELS, LORAS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class UsageStats:
|
||||||
|
"""Track usage statistics for models and save to JSON"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_lock = asyncio.Lock() # For thread safety
|
||||||
|
|
||||||
|
# Default stats file name
|
||||||
|
STATS_FILENAME = "lora_manager_stats.json"
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._initialized = False
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Initialize stats storage
|
||||||
|
self.stats = {
|
||||||
|
"checkpoints": {}, # sha256 -> count
|
||||||
|
"loras": {}, # sha256 -> count
|
||||||
|
"total_executions": 0,
|
||||||
|
"last_save_time": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Queue for prompt_ids to process
|
||||||
|
self.pending_prompt_ids = set()
|
||||||
|
|
||||||
|
# Load existing stats if available
|
||||||
|
self._stats_file_path = self._get_stats_file_path()
|
||||||
|
self._load_stats()
|
||||||
|
|
||||||
|
# Save interval in seconds
|
||||||
|
self.save_interval = 90 # 1.5 minutes
|
||||||
|
|
||||||
|
# Start background task to process queued prompt_ids
|
||||||
|
self._bg_task = asyncio.create_task(self._background_processor())
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
logger.info("Usage statistics tracker initialized")
|
||||||
|
|
||||||
|
def _get_stats_file_path(self) -> str:
|
||||||
|
"""Get the path to the stats JSON file"""
|
||||||
|
if not config.loras_roots or len(config.loras_roots) == 0:
|
||||||
|
# Fallback to temporary directory if no lora roots
|
||||||
|
return os.path.join(config.temp_directory, self.STATS_FILENAME)
|
||||||
|
|
||||||
|
# Use the first lora root
|
||||||
|
return os.path.join(config.loras_roots[0], self.STATS_FILENAME)
|
||||||
|
|
||||||
|
def _load_stats(self):
|
||||||
|
"""Load existing statistics from file"""
|
||||||
|
try:
|
||||||
|
if os.path.exists(self._stats_file_path):
|
||||||
|
with open(self._stats_file_path, 'r', encoding='utf-8') as f:
|
||||||
|
loaded_stats = json.load(f)
|
||||||
|
|
||||||
|
# Update our stats with loaded data
|
||||||
|
if isinstance(loaded_stats, dict):
|
||||||
|
# Update individual sections to maintain structure
|
||||||
|
if "checkpoints" in loaded_stats and isinstance(loaded_stats["checkpoints"], dict):
|
||||||
|
self.stats["checkpoints"] = loaded_stats["checkpoints"]
|
||||||
|
|
||||||
|
if "loras" in loaded_stats and isinstance(loaded_stats["loras"], dict):
|
||||||
|
self.stats["loras"] = loaded_stats["loras"]
|
||||||
|
|
||||||
|
if "total_executions" in loaded_stats:
|
||||||
|
self.stats["total_executions"] = loaded_stats["total_executions"]
|
||||||
|
|
||||||
|
logger.info(f"Loaded usage statistics from {self._stats_file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading usage statistics: {e}")
|
||||||
|
|
||||||
|
async def save_stats(self, force=False):
|
||||||
|
"""Save statistics to file"""
|
||||||
|
try:
|
||||||
|
# Only save if it's been at least save_interval since last save or force is True
|
||||||
|
current_time = time.time()
|
||||||
|
if not force and (current_time - self.stats.get("last_save_time", 0)) < self.save_interval:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Use a lock to prevent concurrent writes
|
||||||
|
async with self._lock:
|
||||||
|
# Update last save time
|
||||||
|
self.stats["last_save_time"] = current_time
|
||||||
|
|
||||||
|
# Create directory if it doesn't exist
|
||||||
|
os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)
|
||||||
|
|
||||||
|
# Write to a temporary file first, then move it to avoid corruption
|
||||||
|
temp_path = f"{self._stats_file_path}.tmp"
|
||||||
|
with open(temp_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(self.stats, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
# Replace the old file with the new one
|
||||||
|
os.replace(temp_path, self._stats_file_path)
|
||||||
|
|
||||||
|
logger.debug(f"Saved usage statistics to {self._stats_file_path}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving usage statistics: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def register_execution(self, prompt_id):
|
||||||
|
"""Register a completed execution by prompt_id for later processing"""
|
||||||
|
if prompt_id:
|
||||||
|
self.pending_prompt_ids.add(prompt_id)
|
||||||
|
|
||||||
|
async def _background_processor(self):
|
||||||
|
"""Background task to process queued prompt_ids"""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# Wait a short interval before checking for new prompt_ids
|
||||||
|
await asyncio.sleep(5) # Check every 5 seconds
|
||||||
|
|
||||||
|
# Process any pending prompt_ids
|
||||||
|
if self.pending_prompt_ids:
|
||||||
|
async with self._lock:
|
||||||
|
# Get a copy of the set and clear original
|
||||||
|
prompt_ids = self.pending_prompt_ids.copy()
|
||||||
|
self.pending_prompt_ids.clear()
|
||||||
|
|
||||||
|
# Process each prompt_id
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
for prompt_id in prompt_ids:
|
||||||
|
try:
|
||||||
|
metadata = registry.get_metadata(prompt_id)
|
||||||
|
await self._process_metadata(metadata)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing prompt_id {prompt_id}: {e}")
|
||||||
|
|
||||||
|
# Periodically save stats
|
||||||
|
await self.save_stats()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Task was cancelled, clean up
|
||||||
|
await self.save_stats(force=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in background processing task: {e}", exc_info=True)
|
||||||
|
# Restart the task after a delay if it fails
|
||||||
|
asyncio.create_task(self._restart_background_task())
|
||||||
|
|
||||||
|
async def _restart_background_task(self):
|
||||||
|
"""Restart the background task after a delay"""
|
||||||
|
await asyncio.sleep(30) # Wait 30 seconds before restarting
|
||||||
|
self._bg_task = asyncio.create_task(self._background_processor())
|
||||||
|
|
||||||
|
async def _process_metadata(self, metadata):
|
||||||
|
"""Process metadata from an execution"""
|
||||||
|
if not metadata or not isinstance(metadata, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Increment total executions count
|
||||||
|
self.stats["total_executions"] += 1
|
||||||
|
|
||||||
|
# Process checkpoints
|
||||||
|
if MODELS in metadata and isinstance(metadata[MODELS], dict):
|
||||||
|
await self._process_checkpoints(metadata[MODELS])
|
||||||
|
|
||||||
|
# Process loras
|
||||||
|
if LORAS in metadata and isinstance(metadata[LORAS], dict):
|
||||||
|
await self._process_loras(metadata[LORAS])
|
||||||
|
|
||||||
|
async def _process_checkpoints(self, models_data):
|
||||||
|
"""Process checkpoint models from metadata"""
|
||||||
|
try:
|
||||||
|
# Get checkpoint scanner service
|
||||||
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
if not checkpoint_scanner:
|
||||||
|
logger.warning("Checkpoint scanner not available for usage tracking")
|
||||||
|
return
|
||||||
|
|
||||||
|
for node_id, model_info in models_data.items():
|
||||||
|
if not isinstance(model_info, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if this is a checkpoint model
|
||||||
|
model_type = model_info.get("type")
|
||||||
|
if model_type == "checkpoint":
|
||||||
|
model_name = model_info.get("name")
|
||||||
|
if not model_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Clean up filename (remove extension if present)
|
||||||
|
model_filename = os.path.splitext(os.path.basename(model_name))[0]
|
||||||
|
|
||||||
|
# Get hash for this checkpoint
|
||||||
|
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
|
||||||
|
if model_hash:
|
||||||
|
# Update stats for this checkpoint
|
||||||
|
self.stats["checkpoints"][model_hash] = self.stats["checkpoints"].get(model_hash, 0) + 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _process_loras(self, loras_data):
|
||||||
|
"""Process LoRA models from metadata"""
|
||||||
|
try:
|
||||||
|
# Get LoRA scanner service
|
||||||
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
|
if not lora_scanner:
|
||||||
|
logger.warning("LoRA scanner not available for usage tracking")
|
||||||
|
return
|
||||||
|
|
||||||
|
for node_id, lora_info in loras_data.items():
|
||||||
|
if not isinstance(lora_info, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get the list of LoRAs from standardized format
|
||||||
|
lora_list = lora_info.get("lora_list", [])
|
||||||
|
for lora in lora_list:
|
||||||
|
if not isinstance(lora, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
lora_name = lora.get("name")
|
||||||
|
if not lora_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get hash for this LoRA
|
||||||
|
lora_hash = lora_scanner.get_hash_by_filename(lora_name)
|
||||||
|
if lora_hash:
|
||||||
|
# Update stats for this LoRA
|
||||||
|
self.stats["loras"][lora_hash] = self.stats["loras"].get(lora_hash, 0) + 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def get_stats(self):
|
||||||
|
"""Get current usage statistics"""
|
||||||
|
return self.stats
|
||||||
|
|
||||||
|
async def get_model_usage_count(self, model_type, sha256):
|
||||||
|
"""Get usage count for a specific model by hash"""
|
||||||
|
if model_type == "checkpoint":
|
||||||
|
return self.stats["checkpoints"].get(sha256, 0)
|
||||||
|
elif model_type == "lora":
|
||||||
|
return self.stats["loras"].get(sha256, 0)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def process_execution(self, prompt_id):
|
||||||
|
"""Process a prompt execution immediately (synchronous approach)"""
|
||||||
|
if not prompt_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Process metadata for this prompt_id
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
metadata = registry.get_metadata(prompt_id)
|
||||||
|
if metadata:
|
||||||
|
await self._process_metadata(metadata)
|
||||||
|
# Save stats if needed
|
||||||
|
await self.save_stats()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui-lora-manager"
|
name = "comfyui-lora-manager"
|
||||||
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
|
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
|
||||||
version = "0.8.7"
|
version = "0.8.9"
|
||||||
license = {file = "LICENSE"}
|
license = {file = "LICENSE"}
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
@@ -12,7 +12,8 @@ dependencies = [
|
|||||||
"piexif",
|
"piexif",
|
||||||
"Pillow",
|
"Pillow",
|
||||||
"olefile", # for getting rid of warning message
|
"olefile", # for getting rid of warning message
|
||||||
"requests"
|
"requests",
|
||||||
|
"toml"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
|||||||
@@ -6,4 +6,5 @@ beautifulsoup4
|
|||||||
piexif
|
piexif
|
||||||
Pillow
|
Pillow
|
||||||
olefile
|
olefile
|
||||||
requests
|
requests
|
||||||
|
toml
|
||||||
@@ -59,6 +59,16 @@ html, body {
|
|||||||
--scrollbar-width: 8px; /* 添加滚动条宽度变量 */
|
--scrollbar-width: 8px; /* 添加滚动条宽度变量 */
|
||||||
}
|
}
|
||||||
|
|
||||||
|
html[data-theme="dark"] {
|
||||||
|
background-color: #1a1a1a !important;
|
||||||
|
color-scheme: dark;
|
||||||
|
}
|
||||||
|
|
||||||
|
html[data-theme="light"] {
|
||||||
|
background-color: #ffffff !important;
|
||||||
|
color-scheme: light;
|
||||||
|
}
|
||||||
|
|
||||||
[data-theme="dark"] {
|
[data-theme="dark"] {
|
||||||
--bg-color: #1a1a1a;
|
--bg-color: #1a1a1a;
|
||||||
--text-color: #e0e0e0;
|
--text-color: #e0e0e0;
|
||||||
|
|||||||
@@ -192,12 +192,43 @@
|
|||||||
margin-left: var(--space-1);
|
margin-left: var(--space-1);
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
color: white;
|
color: white;
|
||||||
transition: opacity 0.2s;
|
transition: opacity 0.2s, transform 0.15s ease;
|
||||||
font-size: 0.9em;
|
font-size: 1.0em; /* Increased from 0.9em for better visibility */
|
||||||
|
width: 16px; /* Fixed width for consistent spacing */
|
||||||
|
height: 16px; /* Fixed height for larger touch target */
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
border-radius: 50%;
|
||||||
|
padding: 4px; /* Add padding to increase clickable area */
|
||||||
|
box-sizing: content-box; /* Ensure padding adds to dimensions */
|
||||||
|
position: relative; /* For proper positioning */
|
||||||
|
margin: 0; /* Reset margin */
|
||||||
|
}
|
||||||
|
|
||||||
|
.card-actions i::before {
|
||||||
|
position: absolute; /* Position the icon glyph */
|
||||||
|
top: 50%;
|
||||||
|
left: 50%;
|
||||||
|
transform: translate(-50%, -50%); /* Center the icon */
|
||||||
|
}
|
||||||
|
|
||||||
|
.card-actions {
|
||||||
|
display: flex;
|
||||||
|
gap: var(--space-1); /* Use gap instead of margin for spacing between icons */
|
||||||
|
align-items: center;
|
||||||
}
|
}
|
||||||
|
|
||||||
.card-actions i:hover {
|
.card-actions i:hover {
|
||||||
opacity: 0.8;
|
opacity: 0.9;
|
||||||
|
transform: scale(1.1);
|
||||||
|
background-color: rgba(255, 255, 255, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Style for active favorites */
|
||||||
|
.favorite-active {
|
||||||
|
color: #ffc107 !important; /* Gold color for favorites */
|
||||||
|
text-shadow: 0 0 5px rgba(255, 193, 7, 0.5);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* 响应式设计 */
|
/* 响应式设计 */
|
||||||
|
|||||||
@@ -81,6 +81,22 @@
|
|||||||
opacity: 1;
|
opacity: 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Controls */
|
||||||
|
.control-group button.favorite-filter {
|
||||||
|
position: relative;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-group button.favorite-filter.active {
|
||||||
|
background: var(--lora-accent);
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.control-group button.favorite-filter i {
|
||||||
|
margin-right: 4px;
|
||||||
|
color: #ffc107;
|
||||||
|
}
|
||||||
|
|
||||||
/* Active state for buttons that can be toggled */
|
/* Active state for buttons that can be toggled */
|
||||||
.control-group button.active {
|
.control-group button.active {
|
||||||
background: var(--lora-accent);
|
background: var(--lora-accent);
|
||||||
|
|||||||
@@ -45,6 +45,11 @@ export async function loadMoreModels(options = {}) {
|
|||||||
params.append('folder', pageState.activeFolder);
|
params.append('folder', pageState.activeFolder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add favorites filter parameter if enabled
|
||||||
|
if (pageState.showFavoritesOnly) {
|
||||||
|
params.append('favorites_only', 'true');
|
||||||
|
}
|
||||||
|
|
||||||
// Add search parameters if there's a search term
|
// Add search parameters if there's a search term
|
||||||
if (pageState.filters?.search) {
|
if (pageState.filters?.search) {
|
||||||
params.append('search', pageState.filters.search);
|
params.append('search', pageState.filters.search);
|
||||||
|
|||||||
@@ -62,8 +62,13 @@ export async function refreshSingleCheckpointMetadata(filePath) {
|
|||||||
return refreshSingleModelMetadata(filePath, 'checkpoint');
|
return refreshSingleModelMetadata(filePath, 'checkpoint');
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save checkpoint metadata (similar to the Lora version)
|
/**
|
||||||
export async function saveCheckpointMetadata(filePath, data) {
|
* Save model metadata to the server
|
||||||
|
* @param {string} filePath - Path to the model file
|
||||||
|
* @param {Object} data - Metadata to save
|
||||||
|
* @returns {Promise} - Promise that resolves with the server response
|
||||||
|
*/
|
||||||
|
export async function saveModelMetadata(filePath, data) {
|
||||||
const response = await fetch('/api/checkpoints/save-metadata', {
|
const response = await fetch('/api/checkpoints/save-metadata', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
@@ -79,5 +84,5 @@ export async function saveCheckpointMetadata(filePath, data) {
|
|||||||
throw new Error('Failed to save metadata');
|
throw new Error('Failed to save metadata');
|
||||||
}
|
}
|
||||||
|
|
||||||
return await response.json();
|
return response.json();
|
||||||
}
|
}
|
||||||
@@ -9,6 +9,31 @@ import {
|
|||||||
refreshSingleModelMetadata
|
refreshSingleModelMetadata
|
||||||
} from './baseModelApi.js';
|
} from './baseModelApi.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save model metadata to the server
|
||||||
|
* @param {string} filePath - File path
|
||||||
|
* @param {Object} data - Data to save
|
||||||
|
* @returns {Promise} Promise of the save operation
|
||||||
|
*/
|
||||||
|
export async function saveModelMetadata(filePath, data) {
|
||||||
|
const response = await fetch('/api/loras/save-metadata', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
file_path: filePath,
|
||||||
|
...data
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error('Failed to save metadata');
|
||||||
|
}
|
||||||
|
|
||||||
|
return response.json();
|
||||||
|
}
|
||||||
|
|
||||||
export async function loadMoreLoras(resetPage = false, updateFolders = false) {
|
export async function loadMoreLoras(resetPage = false, updateFolders = false) {
|
||||||
return loadMoreModels({
|
return loadMoreModels({
|
||||||
resetPage,
|
resetPage,
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import { showToast } from '../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||||
import { state } from '../state/index.js';
|
import { state } from '../state/index.js';
|
||||||
import { showCheckpointModal } from './checkpointModal/index.js';
|
import { showCheckpointModal } from './checkpointModal/index.js';
|
||||||
import { NSFW_LEVELS } from '../utils/constants.js';
|
import { NSFW_LEVELS } from '../utils/constants.js';
|
||||||
import { replaceCheckpointPreview as apiReplaceCheckpointPreview } from '../api/checkpointApi.js';
|
import { replaceCheckpointPreview as apiReplaceCheckpointPreview, saveModelMetadata } from '../api/checkpointApi.js';
|
||||||
|
|
||||||
export function createCheckpointCard(checkpoint) {
|
export function createCheckpointCard(checkpoint) {
|
||||||
const card = document.createElement('div');
|
const card = document.createElement('div');
|
||||||
@@ -17,6 +17,7 @@ export function createCheckpointCard(checkpoint) {
|
|||||||
card.dataset.from_civitai = checkpoint.from_civitai;
|
card.dataset.from_civitai = checkpoint.from_civitai;
|
||||||
card.dataset.notes = checkpoint.notes || '';
|
card.dataset.notes = checkpoint.notes || '';
|
||||||
card.dataset.base_model = checkpoint.base_model || 'Unknown';
|
card.dataset.base_model = checkpoint.base_model || 'Unknown';
|
||||||
|
card.dataset.favorite = checkpoint.favorite ? 'true' : 'false';
|
||||||
|
|
||||||
// Store metadata if available
|
// Store metadata if available
|
||||||
if (checkpoint.civitai) {
|
if (checkpoint.civitai) {
|
||||||
@@ -65,6 +66,9 @@ export function createCheckpointCard(checkpoint) {
|
|||||||
const isVideo = previewUrl.endsWith('.mp4');
|
const isVideo = previewUrl.endsWith('.mp4');
|
||||||
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
|
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
|
||||||
|
|
||||||
|
// Get favorite status from checkpoint data
|
||||||
|
const isFavorite = checkpoint.favorite === true;
|
||||||
|
|
||||||
card.innerHTML = `
|
card.innerHTML = `
|
||||||
<div class="card-preview ${shouldBlur ? 'blurred' : ''}">
|
<div class="card-preview ${shouldBlur ? 'blurred' : ''}">
|
||||||
${isVideo ?
|
${isVideo ?
|
||||||
@@ -82,6 +86,9 @@ export function createCheckpointCard(checkpoint) {
|
|||||||
${checkpoint.base_model}
|
${checkpoint.base_model}
|
||||||
</span>
|
</span>
|
||||||
<div class="card-actions">
|
<div class="card-actions">
|
||||||
|
<i class="${isFavorite ? 'fas fa-star favorite-active' : 'far fa-star'}"
|
||||||
|
title="${isFavorite ? 'Remove from favorites' : 'Add to favorites'}">
|
||||||
|
</i>
|
||||||
<i class="fas fa-globe"
|
<i class="fas fa-globe"
|
||||||
title="${checkpoint.from_civitai ? 'View on Civitai' : 'Not available from Civitai'}"
|
title="${checkpoint.from_civitai ? 'View on Civitai' : 'Not available from Civitai'}"
|
||||||
${!checkpoint.from_civitai ? 'style="opacity: 0.5; cursor: not-allowed"' : ''}>
|
${!checkpoint.from_civitai ? 'style="opacity: 0.5; cursor: not-allowed"' : ''}>
|
||||||
@@ -198,27 +205,46 @@ export function createCheckpointCard(checkpoint) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Favorite button click event
|
||||||
|
card.querySelector('.fa-star')?.addEventListener('click', async e => {
|
||||||
|
e.stopPropagation();
|
||||||
|
const starIcon = e.currentTarget;
|
||||||
|
const isFavorite = starIcon.classList.contains('fas');
|
||||||
|
const newFavoriteState = !isFavorite;
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Save the new favorite state to the server
|
||||||
|
await saveModelMetadata(card.dataset.filepath, {
|
||||||
|
favorite: newFavoriteState
|
||||||
|
});
|
||||||
|
|
||||||
|
// Update the UI
|
||||||
|
if (newFavoriteState) {
|
||||||
|
starIcon.classList.remove('far');
|
||||||
|
starIcon.classList.add('fas', 'favorite-active');
|
||||||
|
starIcon.title = 'Remove from favorites';
|
||||||
|
card.dataset.favorite = 'true';
|
||||||
|
showToast('Added to favorites', 'success');
|
||||||
|
} else {
|
||||||
|
starIcon.classList.remove('fas', 'favorite-active');
|
||||||
|
starIcon.classList.add('far');
|
||||||
|
starIcon.title = 'Add to favorites';
|
||||||
|
card.dataset.favorite = 'false';
|
||||||
|
showToast('Removed from favorites', 'success');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to update favorite status:', error);
|
||||||
|
showToast('Failed to update favorite status', 'error');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// Copy button click event
|
// Copy button click event
|
||||||
card.querySelector('.fa-copy')?.addEventListener('click', async e => {
|
card.querySelector('.fa-copy')?.addEventListener('click', async e => {
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
const checkpointName = card.dataset.file_name;
|
const checkpointName = card.dataset.file_name;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Modern clipboard API
|
await copyToClipboard(checkpointName, 'Checkpoint name copied');
|
||||||
if (navigator.clipboard && window.isSecureContext) {
|
|
||||||
await navigator.clipboard.writeText(checkpointName);
|
|
||||||
} else {
|
|
||||||
// Fallback for older browsers
|
|
||||||
const textarea = document.createElement('textarea');
|
|
||||||
textarea.value = checkpointName;
|
|
||||||
textarea.style.position = 'absolute';
|
|
||||||
textarea.style.left = '-99999px';
|
|
||||||
document.body.appendChild(textarea);
|
|
||||||
textarea.select();
|
|
||||||
document.execCommand('copy');
|
|
||||||
document.body.removeChild(textarea);
|
|
||||||
}
|
|
||||||
showToast('Checkpoint name copied', 'success');
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Copy failed:', err);
|
console.error('Copy failed:', err);
|
||||||
showToast('Copy failed', 'error');
|
showToast('Copy failed', 'error');
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { BaseContextMenu } from './BaseContextMenu.js';
|
import { BaseContextMenu } from './BaseContextMenu.js';
|
||||||
import { refreshSingleCheckpointMetadata, saveCheckpointMetadata } from '../../api/checkpointApi.js';
|
import { refreshSingleCheckpointMetadata, saveModelMetadata } from '../../api/checkpointApi.js';
|
||||||
import { showToast, getNSFWLevelName } from '../../utils/uiHelpers.js';
|
import { showToast, getNSFWLevelName } from '../../utils/uiHelpers.js';
|
||||||
import { NSFW_LEVELS } from '../../utils/constants.js';
|
import { NSFW_LEVELS } from '../../utils/constants.js';
|
||||||
import { getStorageItem } from '../../utils/storageHelpers.js';
|
import { getStorageItem } from '../../utils/storageHelpers.js';
|
||||||
@@ -82,7 +82,7 @@ export class CheckpointContextMenu extends BaseContextMenu {
|
|||||||
if (!filePath) return;
|
if (!filePath) return;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await saveCheckpointMetadata(filePath, { preview_nsfw_level: level });
|
await saveModelMetadata(filePath, { preview_nsfw_level: level });
|
||||||
|
|
||||||
// Update card data
|
// Update card data
|
||||||
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
|
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { BaseContextMenu } from './BaseContextMenu.js';
|
import { BaseContextMenu } from './BaseContextMenu.js';
|
||||||
import { refreshSingleLoraMetadata } from '../../api/loraApi.js';
|
import { refreshSingleLoraMetadata, saveModelMetadata } from '../../api/loraApi.js';
|
||||||
import { showToast, getNSFWLevelName } from '../../utils/uiHelpers.js';
|
import { showToast, getNSFWLevelName } from '../../utils/uiHelpers.js';
|
||||||
import { NSFW_LEVELS } from '../../utils/constants.js';
|
import { NSFW_LEVELS } from '../../utils/constants.js';
|
||||||
import { getStorageItem } from '../../utils/storageHelpers.js';
|
import { getStorageItem } from '../../utils/storageHelpers.js';
|
||||||
@@ -111,22 +111,7 @@ export class LoraContextMenu extends BaseContextMenu {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async saveModelMetadata(filePath, data) {
|
async saveModelMetadata(filePath, data) {
|
||||||
const response = await fetch('/api/loras/save-metadata', {
|
return saveModelMetadata(filePath, data);
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
|
||||||
file_path: filePath,
|
|
||||||
...data
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
throw new Error('Failed to save metadata');
|
|
||||||
}
|
|
||||||
|
|
||||||
return await response.json();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
updateCardBlurEffect(card, level) {
|
updateCardBlurEffect(card, level) {
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import { showToast, openCivitai } from '../utils/uiHelpers.js';
|
import { showToast, openCivitai, copyToClipboard } from '../utils/uiHelpers.js';
|
||||||
import { state } from '../state/index.js';
|
import { state } from '../state/index.js';
|
||||||
import { showLoraModal } from './loraModal/index.js';
|
import { showLoraModal } from './loraModal/index.js';
|
||||||
import { bulkManager } from '../managers/BulkManager.js';
|
import { bulkManager } from '../managers/BulkManager.js';
|
||||||
import { NSFW_LEVELS } from '../utils/constants.js';
|
import { NSFW_LEVELS } from '../utils/constants.js';
|
||||||
import { replacePreview, deleteModel } from '../api/loraApi.js'
|
import { replacePreview, deleteModel, saveModelMetadata } from '../api/loraApi.js'
|
||||||
|
|
||||||
export function createLoraCard(lora) {
|
export function createLoraCard(lora) {
|
||||||
const card = document.createElement('div');
|
const card = document.createElement('div');
|
||||||
@@ -20,6 +20,7 @@ export function createLoraCard(lora) {
|
|||||||
card.dataset.usage_tips = lora.usage_tips;
|
card.dataset.usage_tips = lora.usage_tips;
|
||||||
card.dataset.notes = lora.notes;
|
card.dataset.notes = lora.notes;
|
||||||
card.dataset.meta = JSON.stringify(lora.civitai || {});
|
card.dataset.meta = JSON.stringify(lora.civitai || {});
|
||||||
|
card.dataset.favorite = lora.favorite ? 'true' : 'false';
|
||||||
|
|
||||||
// Store tags and model description
|
// Store tags and model description
|
||||||
if (lora.tags && Array.isArray(lora.tags)) {
|
if (lora.tags && Array.isArray(lora.tags)) {
|
||||||
@@ -65,6 +66,9 @@ export function createLoraCard(lora) {
|
|||||||
const isVideo = previewUrl.endsWith('.mp4');
|
const isVideo = previewUrl.endsWith('.mp4');
|
||||||
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
|
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
|
||||||
|
|
||||||
|
// Get favorite status from the lora data
|
||||||
|
const isFavorite = lora.favorite === true;
|
||||||
|
|
||||||
card.innerHTML = `
|
card.innerHTML = `
|
||||||
<div class="card-preview ${shouldBlur ? 'blurred' : ''}">
|
<div class="card-preview ${shouldBlur ? 'blurred' : ''}">
|
||||||
${isVideo ?
|
${isVideo ?
|
||||||
@@ -82,6 +86,9 @@ export function createLoraCard(lora) {
|
|||||||
${lora.base_model}
|
${lora.base_model}
|
||||||
</span>
|
</span>
|
||||||
<div class="card-actions">
|
<div class="card-actions">
|
||||||
|
<i class="${isFavorite ? 'fas fa-star favorite-active' : 'far fa-star'}"
|
||||||
|
title="${isFavorite ? 'Remove from favorites' : 'Add to favorites'}">
|
||||||
|
</i>
|
||||||
<i class="fas fa-globe"
|
<i class="fas fa-globe"
|
||||||
title="${lora.from_civitai ? 'View on Civitai' : 'Not available from Civitai'}"
|
title="${lora.from_civitai ? 'View on Civitai' : 'Not available from Civitai'}"
|
||||||
${!lora.from_civitai ? 'style="opacity: 0.5; cursor: not-allowed"' : ''}>
|
${!lora.from_civitai ? 'style="opacity: 0.5; cursor: not-allowed"' : ''}>
|
||||||
@@ -135,6 +142,7 @@ export function createLoraCard(lora) {
|
|||||||
base_model: card.dataset.base_model,
|
base_model: card.dataset.base_model,
|
||||||
usage_tips: card.dataset.usage_tips,
|
usage_tips: card.dataset.usage_tips,
|
||||||
notes: card.dataset.notes,
|
notes: card.dataset.notes,
|
||||||
|
favorite: card.dataset.favorite === 'true',
|
||||||
// Parse civitai metadata from the card's dataset
|
// Parse civitai metadata from the card's dataset
|
||||||
civitai: (() => {
|
civitai: (() => {
|
||||||
try {
|
try {
|
||||||
@@ -198,6 +206,39 @@ export function createLoraCard(lora) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Favorite button click event
|
||||||
|
card.querySelector('.fa-star')?.addEventListener('click', async e => {
|
||||||
|
e.stopPropagation();
|
||||||
|
const starIcon = e.currentTarget;
|
||||||
|
const isFavorite = starIcon.classList.contains('fas');
|
||||||
|
const newFavoriteState = !isFavorite;
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Save the new favorite state to the server
|
||||||
|
await saveModelMetadata(card.dataset.filepath, {
|
||||||
|
favorite: newFavoriteState
|
||||||
|
});
|
||||||
|
|
||||||
|
// Update the UI
|
||||||
|
if (newFavoriteState) {
|
||||||
|
starIcon.classList.remove('far');
|
||||||
|
starIcon.classList.add('fas', 'favorite-active');
|
||||||
|
starIcon.title = 'Remove from favorites';
|
||||||
|
card.dataset.favorite = 'true';
|
||||||
|
showToast('Added to favorites', 'success');
|
||||||
|
} else {
|
||||||
|
starIcon.classList.remove('fas', 'favorite-active');
|
||||||
|
starIcon.classList.add('far');
|
||||||
|
starIcon.title = 'Add to favorites';
|
||||||
|
card.dataset.favorite = 'false';
|
||||||
|
showToast('Removed from favorites', 'success');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to update favorite status:', error);
|
||||||
|
showToast('Failed to update favorite status', 'error');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// Copy button click event
|
// Copy button click event
|
||||||
card.querySelector('.fa-copy')?.addEventListener('click', async e => {
|
card.querySelector('.fa-copy')?.addEventListener('click', async e => {
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
@@ -205,26 +246,7 @@ export function createLoraCard(lora) {
|
|||||||
const strength = usageTips.strength || 1;
|
const strength = usageTips.strength || 1;
|
||||||
const loraSyntax = `<lora:${card.dataset.file_name}:${strength}>`;
|
const loraSyntax = `<lora:${card.dataset.file_name}:${strength}>`;
|
||||||
|
|
||||||
try {
|
await copyToClipboard(loraSyntax, 'LoRA syntax copied');
|
||||||
// Modern clipboard API
|
|
||||||
if (navigator.clipboard && window.isSecureContext) {
|
|
||||||
await navigator.clipboard.writeText(loraSyntax);
|
|
||||||
} else {
|
|
||||||
// Fallback for older browsers
|
|
||||||
const textarea = document.createElement('textarea');
|
|
||||||
textarea.value = loraSyntax;
|
|
||||||
textarea.style.position = 'absolute';
|
|
||||||
textarea.style.left = '-99999px';
|
|
||||||
document.body.appendChild(textarea);
|
|
||||||
textarea.select();
|
|
||||||
document.execCommand('copy');
|
|
||||||
document.body.removeChild(textarea);
|
|
||||||
}
|
|
||||||
showToast('LoRA syntax copied', 'success');
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Copy failed:', err);
|
|
||||||
showToast('Copy failed', 'error');
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Civitai button click event
|
// Civitai button click event
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
// Recipe Card Component
|
// Recipe Card Component
|
||||||
import { showToast } from '../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||||
import { modalManager } from '../managers/ModalManager.js';
|
import { modalManager } from '../managers/ModalManager.js';
|
||||||
|
|
||||||
class RecipeCard {
|
class RecipeCard {
|
||||||
@@ -109,14 +109,11 @@ class RecipeCard {
|
|||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
if (data.success && data.syntax) {
|
if (data.success && data.syntax) {
|
||||||
return navigator.clipboard.writeText(data.syntax);
|
return copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
|
||||||
} else {
|
} else {
|
||||||
throw new Error(data.error || 'No syntax returned');
|
throw new Error(data.error || 'No syntax returned');
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.then(() => {
|
|
||||||
showToast('Recipe syntax copied to clipboard', 'success');
|
|
||||||
})
|
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
console.error('Failed to copy: ', err);
|
console.error('Failed to copy: ', err);
|
||||||
showToast('Failed to copy recipe syntax', 'error');
|
showToast('Failed to copy recipe syntax', 'error');
|
||||||
@@ -279,4 +276,4 @@ class RecipeCard {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export { RecipeCard };
|
export { RecipeCard };
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
// Recipe Modal Component
|
// Recipe Modal Component
|
||||||
import { showToast } from '../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||||
import { state } from '../state/index.js';
|
import { state } from '../state/index.js';
|
||||||
import { setSessionItem, removeSessionItem } from '../utils/storageHelpers.js';
|
import { setSessionItem, removeSessionItem } from '../utils/storageHelpers.js';
|
||||||
|
|
||||||
@@ -747,9 +747,8 @@ class RecipeModal {
|
|||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (data.success && data.syntax) {
|
if (data.success && data.syntax) {
|
||||||
// Copy to clipboard
|
// Use the centralized copyToClipboard utility function
|
||||||
await navigator.clipboard.writeText(data.syntax);
|
await copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
|
||||||
showToast('Recipe syntax copied to clipboard', 'success');
|
|
||||||
} else {
|
} else {
|
||||||
throw new Error(data.error || 'No syntax returned from server');
|
throw new Error(data.error || 'No syntax returned from server');
|
||||||
}
|
}
|
||||||
@@ -761,12 +760,7 @@ class RecipeModal {
|
|||||||
|
|
||||||
// Helper method to copy text to clipboard
|
// Helper method to copy text to clipboard
|
||||||
copyToClipboard(text, successMessage) {
|
copyToClipboard(text, successMessage) {
|
||||||
navigator.clipboard.writeText(text).then(() => {
|
copyToClipboard(text, successMessage);
|
||||||
showToast(successMessage, 'success');
|
|
||||||
}).catch(err => {
|
|
||||||
console.error('Failed to copy text: ', err);
|
|
||||||
showToast('Failed to copy text', 'error');
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new method to handle downloading missing LoRAs
|
// Add new method to handle downloading missing LoRAs
|
||||||
|
|||||||
@@ -5,31 +5,7 @@
|
|||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast } from '../../utils/uiHelpers.js';
|
||||||
import { BASE_MODELS } from '../../utils/constants.js';
|
import { BASE_MODELS } from '../../utils/constants.js';
|
||||||
import { updateCheckpointCard } from '../../utils/cardUpdater.js';
|
import { updateCheckpointCard } from '../../utils/cardUpdater.js';
|
||||||
|
import { saveModelMetadata } from '../../api/checkpointApi.js';
|
||||||
/**
|
|
||||||
* Save model metadata to the server
|
|
||||||
* @param {string} filePath - Path to the model file
|
|
||||||
* @param {Object} data - Metadata to save
|
|
||||||
* @returns {Promise} - Promise that resolves with the server response
|
|
||||||
*/
|
|
||||||
export async function saveModelMetadata(filePath, data) {
|
|
||||||
const response = await fetch('/api/checkpoints/save-metadata', {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
|
||||||
file_path: filePath,
|
|
||||||
...data
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
throw new Error('Failed to save metadata');
|
|
||||||
}
|
|
||||||
|
|
||||||
return response.json();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set up model name editing functionality
|
* Set up model name editing functionality
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* ShowcaseView.js
|
* ShowcaseView.js
|
||||||
* Handles showcase content (images, videos) display for checkpoint modal
|
* Handles showcase content (images, videos) display for checkpoint modal
|
||||||
*/
|
*/
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||||
import { state } from '../../state/index.js';
|
import { state } from '../../state/index.js';
|
||||||
import { NSFW_LEVELS } from '../../utils/constants.js';
|
import { NSFW_LEVELS } from '../../utils/constants.js';
|
||||||
|
|
||||||
@@ -307,8 +307,7 @@ function initMetadataPanelHandlers(container) {
|
|||||||
if (!promptElement) return;
|
if (!promptElement) return;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await navigator.clipboard.writeText(promptElement.textContent);
|
await copyToClipboard(promptElement.textContent, 'Prompt copied to clipboard');
|
||||||
showToast('Prompt copied to clipboard', 'success');
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Copy failed:', err);
|
console.error('Copy failed:', err);
|
||||||
showToast('Copy failed', 'error');
|
showToast('Copy failed', 'error');
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
|
|||||||
import {
|
import {
|
||||||
setupModelNameEditing,
|
setupModelNameEditing,
|
||||||
setupBaseModelEditing,
|
setupBaseModelEditing,
|
||||||
setupFileNameEditing,
|
setupFileNameEditing
|
||||||
saveModelMetadata
|
|
||||||
} from './ModelMetadata.js';
|
} from './ModelMetadata.js';
|
||||||
|
import { saveModelMetadata } from '../../api/checkpointApi.js';
|
||||||
import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js';
|
import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js';
|
||||||
import { updateCheckpointCard } from '../../utils/cardUpdater.js';
|
import { updateCheckpointCard } from '../../utils/cardUpdater.js';
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
import { PageControls } from './PageControls.js';
|
import { PageControls } from './PageControls.js';
|
||||||
import { loadMoreLoras, fetchCivitai, resetAndReload, refreshLoras } from '../../api/loraApi.js';
|
import { loadMoreLoras, fetchCivitai, resetAndReload, refreshLoras } from '../../api/loraApi.js';
|
||||||
import { getSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
|
import { getSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* LorasControls class - Extends PageControls for LoRA-specific functionality
|
* LorasControls class - Extends PageControls for LoRA-specific functionality
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// PageControls.js - Manages controls for both LoRAs and Checkpoints pages
|
// PageControls.js - Manages controls for both LoRAs and Checkpoints pages
|
||||||
import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js';
|
import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js';
|
||||||
import { getStorageItem, setStorageItem } from '../../utils/storageHelpers.js';
|
import { getStorageItem, setStorageItem, getSessionItem, setSessionItem } from '../../utils/storageHelpers.js';
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast } from '../../utils/uiHelpers.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -26,6 +26,9 @@ export class PageControls {
|
|||||||
// Initialize event listeners
|
// Initialize event listeners
|
||||||
this.initEventListeners();
|
this.initEventListeners();
|
||||||
|
|
||||||
|
// Initialize favorites filter button state
|
||||||
|
this.initFavoritesFilter();
|
||||||
|
|
||||||
console.log(`PageControls initialized for ${pageType} page`);
|
console.log(`PageControls initialized for ${pageType} page`);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,6 +124,12 @@ export class PageControls {
|
|||||||
bulkButton.addEventListener('click', () => this.toggleBulkMode());
|
bulkButton.addEventListener('click', () => this.toggleBulkMode());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Favorites filter button handler
|
||||||
|
const favoriteFilterBtn = document.getElementById('favoriteFilterBtn');
|
||||||
|
if (favoriteFilterBtn) {
|
||||||
|
favoriteFilterBtn.addEventListener('click', () => this.toggleFavoritesOnly());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -385,4 +394,50 @@ export class PageControls {
|
|||||||
showToast('Failed to clear custom filter: ' + error.message, 'error');
|
showToast('Failed to clear custom filter: ' + error.message, 'error');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize the favorites filter button state
|
||||||
|
*/
|
||||||
|
initFavoritesFilter() {
|
||||||
|
const favoriteFilterBtn = document.getElementById('favoriteFilterBtn');
|
||||||
|
if (favoriteFilterBtn) {
|
||||||
|
// Get current state from session storage with page-specific key
|
||||||
|
const storageKey = `show_favorites_only_${this.pageType}`;
|
||||||
|
const showFavoritesOnly = getSessionItem(storageKey, false);
|
||||||
|
|
||||||
|
// Update button state
|
||||||
|
if (showFavoritesOnly) {
|
||||||
|
favoriteFilterBtn.classList.add('active');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update app state
|
||||||
|
this.pageState.showFavoritesOnly = showFavoritesOnly;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Toggle favorites-only filter and reload models
|
||||||
|
*/
|
||||||
|
async toggleFavoritesOnly() {
|
||||||
|
const favoriteFilterBtn = document.getElementById('favoriteFilterBtn');
|
||||||
|
|
||||||
|
// Toggle the filter state in storage
|
||||||
|
const storageKey = `show_favorites_only_${this.pageType}`;
|
||||||
|
const currentState = this.pageState.showFavoritesOnly;
|
||||||
|
const newState = !currentState;
|
||||||
|
|
||||||
|
// Update session storage
|
||||||
|
setSessionItem(storageKey, newState);
|
||||||
|
|
||||||
|
// Update state
|
||||||
|
this.pageState.showFavoritesOnly = newState;
|
||||||
|
|
||||||
|
// Update button appearance
|
||||||
|
if (favoriteFilterBtn) {
|
||||||
|
favoriteFilterBtn.classList.toggle('active', newState);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload models with new filter
|
||||||
|
await this.resetAndReload(true);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -5,31 +5,7 @@
|
|||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast } from '../../utils/uiHelpers.js';
|
||||||
import { BASE_MODELS } from '../../utils/constants.js';
|
import { BASE_MODELS } from '../../utils/constants.js';
|
||||||
import { updateLoraCard } from '../../utils/cardUpdater.js';
|
import { updateLoraCard } from '../../utils/cardUpdater.js';
|
||||||
|
import { saveModelMetadata } from '../../api/loraApi.js';
|
||||||
/**
|
|
||||||
* 保存模型元数据到服务器
|
|
||||||
* @param {string} filePath - 文件路径
|
|
||||||
* @param {Object} data - 要保存的数据
|
|
||||||
* @returns {Promise} 保存操作的Promise
|
|
||||||
*/
|
|
||||||
export async function saveModelMetadata(filePath, data) {
|
|
||||||
const response = await fetch('/api/loras/save-metadata', {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
|
||||||
file_path: filePath,
|
|
||||||
...data
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
throw new Error('Failed to save metadata');
|
|
||||||
}
|
|
||||||
|
|
||||||
return response.json();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 设置模型名称编辑功能
|
* 设置模型名称编辑功能
|
||||||
|
|||||||
@@ -2,8 +2,7 @@
|
|||||||
* PresetTags.js
|
* PresetTags.js
|
||||||
* 处理LoRA模型预设参数标签相关的功能模块
|
* 处理LoRA模型预设参数标签相关的功能模块
|
||||||
*/
|
*/
|
||||||
import { saveModelMetadata } from './ModelMetadata.js';
|
import { saveModelMetadata } from '../../api/loraApi.js';
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 解析预设参数
|
* 解析预设参数
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
/**
|
/**
|
||||||
* RecipeTab - Handles the recipes tab in the Lora Modal
|
* RecipeTab - Handles the recipes tab in the Lora Modal
|
||||||
*/
|
*/
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||||
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
|
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -172,14 +172,11 @@ function copyRecipeSyntax(recipeId) {
|
|||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
if (data.success && data.syntax) {
|
if (data.success && data.syntax) {
|
||||||
return navigator.clipboard.writeText(data.syntax);
|
return copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
|
||||||
} else {
|
} else {
|
||||||
throw new Error(data.error || 'No syntax returned');
|
throw new Error(data.error || 'No syntax returned');
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.then(() => {
|
|
||||||
showToast('Recipe syntax copied to clipboard', 'success');
|
|
||||||
})
|
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
console.error('Failed to copy: ', err);
|
console.error('Failed to copy: ', err);
|
||||||
showToast('Failed to copy recipe syntax', 'error');
|
showToast('Failed to copy recipe syntax', 'error');
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* ShowcaseView.js
|
* ShowcaseView.js
|
||||||
* 处理LoRA模型展示内容(图片、视频)的功能模块
|
* 处理LoRA模型展示内容(图片、视频)的功能模块
|
||||||
*/
|
*/
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||||
import { state } from '../../state/index.js';
|
import { state } from '../../state/index.js';
|
||||||
import { NSFW_LEVELS } from '../../utils/constants.js';
|
import { NSFW_LEVELS } from '../../utils/constants.js';
|
||||||
|
|
||||||
@@ -311,8 +311,7 @@ function initMetadataPanelHandlers(container) {
|
|||||||
if (!promptElement) return;
|
if (!promptElement) return;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await navigator.clipboard.writeText(promptElement.textContent);
|
await copyToClipboard(promptElement.textContent, 'Prompt copied to clipboard');
|
||||||
showToast('Prompt copied to clipboard', 'success');
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Copy failed:', err);
|
console.error('Copy failed:', err);
|
||||||
showToast('Copy failed', 'error');
|
showToast('Copy failed', 'error');
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
* TriggerWords.js
|
* TriggerWords.js
|
||||||
* 处理LoRA模型触发词相关的功能模块
|
* 处理LoRA模型触发词相关的功能模块
|
||||||
*/
|
*/
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||||
import { saveModelMetadata } from './ModelMetadata.js';
|
import { saveModelMetadata } from '../../api/loraApi.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 渲染触发词
|
* 渲染触发词
|
||||||
@@ -336,23 +336,7 @@ async function saveTriggerWords() {
|
|||||||
*/
|
*/
|
||||||
window.copyTriggerWord = async function(word) {
|
window.copyTriggerWord = async function(word) {
|
||||||
try {
|
try {
|
||||||
// Modern clipboard API - with fallback for non-secure contexts
|
await copyToClipboard(word, 'Trigger word copied');
|
||||||
if (navigator.clipboard && window.isSecureContext) {
|
|
||||||
await navigator.clipboard.writeText(word);
|
|
||||||
} else {
|
|
||||||
// Fallback for older browsers or non-secure contexts
|
|
||||||
const textarea = document.createElement('textarea');
|
|
||||||
textarea.value = word;
|
|
||||||
textarea.style.position = 'absolute';
|
|
||||||
textarea.style.left = '-99999px';
|
|
||||||
document.body.appendChild(textarea);
|
|
||||||
textarea.select();
|
|
||||||
const success = document.execCommand('copy');
|
|
||||||
document.body.removeChild(textarea);
|
|
||||||
|
|
||||||
if (!success) throw new Error('Copy command failed');
|
|
||||||
}
|
|
||||||
showToast('Trigger word copied', 'success');
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Copy failed:', err);
|
console.error('Copy failed:', err);
|
||||||
showToast('Copy failed', 'error');
|
showToast('Copy failed', 'error');
|
||||||
|
|||||||
@@ -3,8 +3,7 @@
|
|||||||
*
|
*
|
||||||
* 将原始的LoraModal.js拆分成多个功能模块后的主入口文件
|
* 将原始的LoraModal.js拆分成多个功能模块后的主入口文件
|
||||||
*/
|
*/
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||||
import { state } from '../../state/index.js';
|
|
||||||
import { modalManager } from '../../managers/ModalManager.js';
|
import { modalManager } from '../../managers/ModalManager.js';
|
||||||
import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js';
|
import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js';
|
||||||
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
|
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
|
||||||
@@ -14,9 +13,9 @@ import { loadRecipesForLora } from './RecipeTab.js'; // Add import for recipe ta
|
|||||||
import {
|
import {
|
||||||
setupModelNameEditing,
|
setupModelNameEditing,
|
||||||
setupBaseModelEditing,
|
setupBaseModelEditing,
|
||||||
setupFileNameEditing,
|
setupFileNameEditing
|
||||||
saveModelMetadata
|
|
||||||
} from './ModelMetadata.js';
|
} from './ModelMetadata.js';
|
||||||
|
import { saveModelMetadata } from '../../api/loraApi.js';
|
||||||
import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js';
|
import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js';
|
||||||
import { updateLoraCard } from '../../utils/cardUpdater.js';
|
import { updateLoraCard } from '../../utils/cardUpdater.js';
|
||||||
|
|
||||||
@@ -174,8 +173,7 @@ export function showLoraModal(lora) {
|
|||||||
// Copy file name function
|
// Copy file name function
|
||||||
window.copyFileName = async function(fileName) {
|
window.copyFileName = async function(fileName) {
|
||||||
try {
|
try {
|
||||||
await navigator.clipboard.writeText(fileName);
|
await copyToClipboard(fileName, 'File name copied');
|
||||||
showToast('File name copied', 'success');
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Copy failed:', err);
|
console.error('Copy failed:', err);
|
||||||
showToast('Copy failed', 'error');
|
showToast('Copy failed', 'error');
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { state } from '../state/index.js';
|
import { state } from '../state/index.js';
|
||||||
import { showToast } from '../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||||
import { updateCardsForBulkMode } from '../components/LoraCard.js';
|
import { updateCardsForBulkMode } from '../components/LoraCard.js';
|
||||||
|
|
||||||
export class BulkManager {
|
export class BulkManager {
|
||||||
@@ -205,13 +205,7 @@ export class BulkManager {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
await copyToClipboard(loraSyntaxes.join(', '), `Copied ${loraSyntaxes.length} LoRA syntaxes to clipboard`);
|
||||||
await navigator.clipboard.writeText(loraSyntaxes.join(', '));
|
|
||||||
showToast(`Copied ${loraSyntaxes.length} LoRA syntaxes to clipboard`, 'success');
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Copy failed:', err);
|
|
||||||
showToast('Copy failed', 'error');
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create and show the thumbnail strip of selected LoRAs
|
// Create and show the thumbnail strip of selected LoRAs
|
||||||
|
|||||||
@@ -268,6 +268,32 @@ class RecipeManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Refreshes the recipe list by first rebuilding the cache and then loading recipes
|
||||||
|
*/
|
||||||
|
async refreshRecipes() {
|
||||||
|
try {
|
||||||
|
// Call the new endpoint to rebuild the recipe cache
|
||||||
|
const response = await fetch('/api/recipes/scan');
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const data = await response.json();
|
||||||
|
throw new Error(data.error || 'Failed to refresh recipe cache');
|
||||||
|
}
|
||||||
|
|
||||||
|
// After successful cache rebuild, load the recipes
|
||||||
|
await this.loadRecipes(true);
|
||||||
|
|
||||||
|
appCore.showToast('Refresh complete', 'success');
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error refreshing recipes:', error);
|
||||||
|
appCore.showToast(error.message || 'Failed to refresh recipes', 'error');
|
||||||
|
|
||||||
|
// Still try to load recipes even if scan failed
|
||||||
|
await this.loadRecipes(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async _loadSpecificRecipe(recipeId) {
|
async _loadSpecificRecipe(recipeId) {
|
||||||
try {
|
try {
|
||||||
// Fetch specific recipe by ID
|
// Fetch specific recipe by ID
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ export const state = {
|
|||||||
bulkMode: false,
|
bulkMode: false,
|
||||||
selectedLoras: new Set(),
|
selectedLoras: new Set(),
|
||||||
loraMetadataCache: new Map(),
|
loraMetadataCache: new Map(),
|
||||||
|
showFavoritesOnly: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
recipes: {
|
recipes: {
|
||||||
@@ -61,7 +62,8 @@ export const state = {
|
|||||||
tags: [],
|
tags: [],
|
||||||
search: ''
|
search: ''
|
||||||
},
|
},
|
||||||
pageSize: 20
|
pageSize: 20,
|
||||||
|
showFavoritesOnly: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
checkpoints: {
|
checkpoints: {
|
||||||
@@ -80,7 +82,8 @@ export const state = {
|
|||||||
filters: {
|
filters: {
|
||||||
baseModel: [],
|
baseModel: [],
|
||||||
tags: []
|
tags: []
|
||||||
}
|
},
|
||||||
|
showFavoritesOnly: false,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,40 @@ import { state } from '../state/index.js';
|
|||||||
import { resetAndReload } from '../api/loraApi.js';
|
import { resetAndReload } from '../api/loraApi.js';
|
||||||
import { getStorageItem, setStorageItem } from './storageHelpers.js';
|
import { getStorageItem, setStorageItem } from './storageHelpers.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility function to copy text to clipboard with fallback for older browsers
|
||||||
|
* @param {string} text - The text to copy to clipboard
|
||||||
|
* @param {string} successMessage - Optional success message to show in toast
|
||||||
|
* @returns {Promise<boolean>} - Promise that resolves to true if copy was successful
|
||||||
|
*/
|
||||||
|
export async function copyToClipboard(text, successMessage = 'Copied to clipboard') {
|
||||||
|
try {
|
||||||
|
// Modern clipboard API
|
||||||
|
if (navigator.clipboard && window.isSecureContext) {
|
||||||
|
await navigator.clipboard.writeText(text);
|
||||||
|
} else {
|
||||||
|
// Fallback for older browsers
|
||||||
|
const textarea = document.createElement('textarea');
|
||||||
|
textarea.value = text;
|
||||||
|
textarea.style.position = 'absolute';
|
||||||
|
textarea.style.left = '-99999px';
|
||||||
|
document.body.appendChild(textarea);
|
||||||
|
textarea.select();
|
||||||
|
document.execCommand('copy');
|
||||||
|
document.body.removeChild(textarea);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (successMessage) {
|
||||||
|
showToast(successMessage, 'success');
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Copy failed:', err);
|
||||||
|
showToast('Copy failed', 'error');
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export function showToast(message, type = 'info') {
|
export function showToast(message, type = 'info') {
|
||||||
const toast = document.createElement('div');
|
const toast = document.createElement('div');
|
||||||
toast.className = `toast toast-${type}`;
|
toast.className = `toast toast-${type}`;
|
||||||
@@ -80,13 +114,55 @@ export function restoreFolderFilter() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function initTheme() {
|
export function initTheme() {
|
||||||
document.body.dataset.theme = getStorageItem('theme') || 'dark';
|
const savedTheme = getStorageItem('theme') || 'auto';
|
||||||
|
applyTheme(savedTheme);
|
||||||
|
|
||||||
|
// Update theme when system preference changes (for 'auto' mode)
|
||||||
|
window.matchMedia('(prefers-color-scheme: dark)').addEventListener('change', () => {
|
||||||
|
const currentTheme = getStorageItem('theme') || 'auto';
|
||||||
|
if (currentTheme === 'auto') {
|
||||||
|
applyTheme('auto');
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
export function toggleTheme() {
|
export function toggleTheme() {
|
||||||
const theme = document.body.dataset.theme === 'light' ? 'dark' : 'light';
|
const currentTheme = getStorageItem('theme') || 'auto';
|
||||||
document.body.dataset.theme = theme;
|
let newTheme;
|
||||||
setStorageItem('theme', theme);
|
|
||||||
|
if (currentTheme === 'dark') {
|
||||||
|
newTheme = 'light';
|
||||||
|
} else {
|
||||||
|
newTheme = 'dark';
|
||||||
|
}
|
||||||
|
|
||||||
|
setStorageItem('theme', newTheme);
|
||||||
|
applyTheme(newTheme);
|
||||||
|
|
||||||
|
// Force a repaint to ensure theme changes are applied immediately
|
||||||
|
document.body.style.display = 'none';
|
||||||
|
document.body.offsetHeight; // Trigger a reflow
|
||||||
|
document.body.style.display = '';
|
||||||
|
|
||||||
|
return newTheme;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a new helper function to apply the theme
|
||||||
|
function applyTheme(theme) {
|
||||||
|
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
|
||||||
|
const htmlElement = document.documentElement;
|
||||||
|
|
||||||
|
// Remove any existing theme attributes
|
||||||
|
htmlElement.removeAttribute('data-theme');
|
||||||
|
|
||||||
|
// Apply the appropriate theme
|
||||||
|
if (theme === 'dark' || (theme === 'auto' && prefersDark)) {
|
||||||
|
htmlElement.setAttribute('data-theme', 'dark');
|
||||||
|
document.body.dataset.theme = 'dark';
|
||||||
|
} else {
|
||||||
|
htmlElement.setAttribute('data-theme', 'light');
|
||||||
|
document.body.dataset.theme = 'light';
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function toggleFolder(tag) {
|
export function toggleFolder(tag) {
|
||||||
@@ -108,12 +184,6 @@ export function toggleFolder(tag) {
|
|||||||
resetAndReload();
|
resetAndReload();
|
||||||
}
|
}
|
||||||
|
|
||||||
export function copyTriggerWord(word) {
|
|
||||||
navigator.clipboard.writeText(word).then(() => {
|
|
||||||
showToast('Trigger word copied', 'success');
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
function filterByFolder(folderPath) {
|
function filterByFolder(folderPath) {
|
||||||
document.querySelectorAll('.lora-card').forEach(card => {
|
document.querySelectorAll('.lora-card').forEach(card => {
|
||||||
card.style.display = card.dataset.folder === folderPath ? '' : 'none';
|
card.style.display = card.dataset.folder === folderPath ? '' : 'none';
|
||||||
|
|||||||
6
static/vendor/font-awesome/css/all.min.css
vendored
Normal file
6
static/vendor/font-awesome/css/all.min.css
vendored
Normal file
File diff suppressed because one or more lines are too long
BIN
static/vendor/font-awesome/webfonts/fa-brands-400.woff2
vendored
Normal file
BIN
static/vendor/font-awesome/webfonts/fa-brands-400.woff2
vendored
Normal file
Binary file not shown.
BIN
static/vendor/font-awesome/webfonts/fa-regular-400.ttf
vendored
Normal file
BIN
static/vendor/font-awesome/webfonts/fa-regular-400.ttf
vendored
Normal file
Binary file not shown.
BIN
static/vendor/font-awesome/webfonts/fa-regular-400.woff2
vendored
Normal file
BIN
static/vendor/font-awesome/webfonts/fa-regular-400.woff2
vendored
Normal file
Binary file not shown.
BIN
static/vendor/font-awesome/webfonts/fa-solid-900.woff2
vendored
Normal file
BIN
static/vendor/font-awesome/webfonts/fa-solid-900.woff2
vendored
Normal file
Binary file not shown.
@@ -6,7 +6,7 @@
|
|||||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
<link rel="stylesheet" href="/loras_static/css/style.css">
|
<link rel="stylesheet" href="/loras_static/css/style.css">
|
||||||
{% block page_css %}{% endblock %}
|
{% block page_css %}{% endblock %}
|
||||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css"
|
<link rel="stylesheet" href="/loras_static/vendor/font-awesome/css/all.min.css"
|
||||||
crossorigin="anonymous" referrerpolicy="no-referrer">
|
crossorigin="anonymous" referrerpolicy="no-referrer">
|
||||||
<link rel="icon" type="image/png" sizes="32x32" href="/loras_static/images/favicon-32x32.png">
|
<link rel="icon" type="image/png" sizes="32x32" href="/loras_static/images/favicon-32x32.png">
|
||||||
<link rel="icon" type="image/png" sizes="16x16" href="/loras_static/images/favicon-16x16.png">
|
<link rel="icon" type="image/png" sizes="16x16" href="/loras_static/images/favicon-16x16.png">
|
||||||
@@ -17,7 +17,7 @@
|
|||||||
{% block preload %}{% endblock %}
|
{% block preload %}{% endblock %}
|
||||||
|
|
||||||
<!-- 优化字体加载 -->
|
<!-- 优化字体加载 -->
|
||||||
<link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/webfonts/fa-solid-900.woff2"
|
<link rel="preload" href="/loras_static/vendor/font-awesome/webfonts/fa-solid-900.woff2"
|
||||||
as="font" type="font/woff2" crossorigin>
|
as="font" type="font/woff2" crossorigin>
|
||||||
|
|
||||||
<!-- 添加性能监控 -->
|
<!-- 添加性能监控 -->
|
||||||
@@ -35,7 +35,7 @@
|
|||||||
|
|
||||||
<!-- 添加资源加载策略 -->
|
<!-- 添加资源加载策略 -->
|
||||||
<link rel="preconnect" href="https://civitai.com">
|
<link rel="preconnect" href="https://civitai.com">
|
||||||
<link rel="preconnect" href="https://cdnjs.cloudflare.com">
|
<!-- <link rel="preconnect" href="https://cdnjs.cloudflare.com"> -->
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
// 计算滚动条宽度并设置CSS变量
|
// 计算滚动条宽度并设置CSS变量
|
||||||
@@ -48,6 +48,20 @@
|
|||||||
document.documentElement.style.setProperty('--scrollbar-width', scrollbarWidth + 'px');
|
document.documentElement.style.setProperty('--scrollbar-width', scrollbarWidth + 'px');
|
||||||
});
|
});
|
||||||
</script>
|
</script>
|
||||||
|
<script>
|
||||||
|
(function() {
|
||||||
|
// Apply theme immediately based on stored preference
|
||||||
|
const STORAGE_PREFIX = 'lora_manager_';
|
||||||
|
const savedTheme = localStorage.getItem(STORAGE_PREFIX + 'theme') || 'auto';
|
||||||
|
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
|
||||||
|
|
||||||
|
if (savedTheme === 'dark' || (savedTheme === 'auto' && prefersDark)) {
|
||||||
|
document.documentElement.setAttribute('data-theme', 'dark');
|
||||||
|
} else {
|
||||||
|
document.documentElement.setAttribute('data-theme', 'light');
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
</script>
|
||||||
{% block head_scripts %}{% endblock %}
|
{% block head_scripts %}{% endblock %}
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,11 @@
|
|||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
<div class="control-group">
|
||||||
|
<button id="favoriteFilterBtn" data-action="toggle-favorites" class="favorite-filter" title="Show favorites only">
|
||||||
|
<i class="fas fa-star"></i> Favorites
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
<div id="customFilterIndicator" class="control-group hidden">
|
<div id="customFilterIndicator" class="control-group hidden">
|
||||||
<div class="filter-active">
|
<div class="filter-active">
|
||||||
<i class="fas fa-filter"></i> <span class="customFilterText" title=""></span>
|
<i class="fas fa-filter"></i> <span class="customFilterText" title=""></span>
|
||||||
|
|||||||
@@ -37,7 +37,7 @@
|
|||||||
<div class="controls">
|
<div class="controls">
|
||||||
<div class="action-buttons">
|
<div class="action-buttons">
|
||||||
<div title="Refresh recipe list" class="control-group">
|
<div title="Refresh recipe list" class="control-group">
|
||||||
<button onclick="recipeManager.loadRecipes(true)"><i class="fas fa-sync"></i> Refresh</button>
|
<button onclick="recipeManager.refreshRecipes()"><i class="fas fa-sync"></i> Refresh</button>
|
||||||
</div>
|
</div>
|
||||||
<div title="Import recipes" class="control-group">
|
<div title="Import recipes" class="control-group">
|
||||||
<button onclick="importManager.showImportModal()"><i class="fas fa-file-import"></i> Import</button>
|
<button onclick="importManager.showImportModal()"><i class="fas fa-file-import"></i> Import</button>
|
||||||
|
|||||||
@@ -927,10 +927,6 @@ export function addLorasWidget(node, name, opts, callback) {
|
|||||||
// Function to directly save the recipe without dialog
|
// Function to directly save the recipe without dialog
|
||||||
async function saveRecipeDirectly(widget) {
|
async function saveRecipeDirectly(widget) {
|
||||||
try {
|
try {
|
||||||
// Get the workflow data from the ComfyUI app
|
|
||||||
const prompt = await app.graphToPrompt();
|
|
||||||
console.log('Prompt:', prompt);
|
|
||||||
|
|
||||||
// Show loading toast
|
// Show loading toast
|
||||||
if (app && app.extensionManager && app.extensionManager.toast) {
|
if (app && app.extensionManager && app.extensionManager.toast) {
|
||||||
app.extensionManager.toast.add({
|
app.extensionManager.toast.add({
|
||||||
@@ -941,14 +937,9 @@ async function saveRecipeDirectly(widget) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare the data - only send workflow JSON
|
|
||||||
const formData = new FormData();
|
|
||||||
formData.append('workflow_json', JSON.stringify(prompt.output));
|
|
||||||
|
|
||||||
// Send the request
|
// Send the request
|
||||||
const response = await fetch('/api/recipes/save-from-widget', {
|
const response = await fetch('/api/recipes/save-from-widget', {
|
||||||
method: 'POST',
|
method: 'POST'
|
||||||
body: formData
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await response.json();
|
const result = await response.json();
|
||||||
|
|||||||
@@ -9,6 +9,54 @@ async function getLorasWidgetModule() {
|
|||||||
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
|
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Function to get connected trigger toggle nodes
|
||||||
|
function getConnectedTriggerToggleNodes(node) {
|
||||||
|
const connectedNodes = [];
|
||||||
|
|
||||||
|
// Check if node has outputs
|
||||||
|
if (node.outputs && node.outputs.length > 0) {
|
||||||
|
// For each output slot
|
||||||
|
for (const output of node.outputs) {
|
||||||
|
// Check if this output has any links
|
||||||
|
if (output.links && output.links.length > 0) {
|
||||||
|
// For each link, get the target node
|
||||||
|
for (const linkId of output.links) {
|
||||||
|
const link = app.graph.links[linkId];
|
||||||
|
if (link) {
|
||||||
|
const targetNode = app.graph.getNodeById(link.target_id);
|
||||||
|
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
||||||
|
connectedNodes.push(targetNode.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return connectedNodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to update trigger words for connected toggle nodes
|
||||||
|
function updateConnectedTriggerWords(node, text) {
|
||||||
|
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
|
||||||
|
if (connectedNodeIds.length > 0) {
|
||||||
|
const loraNames = new Set();
|
||||||
|
let match;
|
||||||
|
LORA_PATTERN.lastIndex = 0;
|
||||||
|
while ((match = LORA_PATTERN.exec(text)) !== null) {
|
||||||
|
loraNames.add(match[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
fetch("/loramanager/get_trigger_words", {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({
|
||||||
|
lora_names: Array.from(loraNames),
|
||||||
|
node_ids: connectedNodeIds
|
||||||
|
})
|
||||||
|
}).catch(err => console.error("Error fetching trigger words:", err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function mergeLoras(lorasText, lorasArr) {
|
function mergeLoras(lorasText, lorasArr) {
|
||||||
const result = [];
|
const result = [];
|
||||||
let match;
|
let match;
|
||||||
@@ -99,6 +147,9 @@ app.registerExtension({
|
|||||||
newText = newText.replace(/\s+/g, ' ').trim();
|
newText = newText.replace(/\s+/g, ' ').trim();
|
||||||
|
|
||||||
inputWidget.value = newText;
|
inputWidget.value = newText;
|
||||||
|
|
||||||
|
// Add this line to update trigger words when lorasWidget changes cause inputWidget value to change
|
||||||
|
updateConnectedTriggerWords(node, newText);
|
||||||
} finally {
|
} finally {
|
||||||
isUpdating = false;
|
isUpdating = false;
|
||||||
}
|
}
|
||||||
@@ -117,6 +168,9 @@ app.registerExtension({
|
|||||||
const mergedLoras = mergeLoras(value, currentLoras);
|
const mergedLoras = mergeLoras(value, currentLoras);
|
||||||
|
|
||||||
node.lorasWidget.value = mergedLoras;
|
node.lorasWidget.value = mergedLoras;
|
||||||
|
|
||||||
|
// Replace the existing trigger word update code with the new function
|
||||||
|
updateConnectedTriggerWords(node, value);
|
||||||
} finally {
|
} finally {
|
||||||
isUpdating = false;
|
isUpdating = false;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,58 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
import { addLorasWidget } from "./loras_widget.js";
|
import { dynamicImportByVersion } from "./utils.js";
|
||||||
|
|
||||||
// Extract pattern into a constant for consistent use
|
// Extract pattern into a constant for consistent use
|
||||||
const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)>/g;
|
const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)>/g;
|
||||||
|
|
||||||
|
// Function to get the appropriate loras widget based on ComfyUI version
|
||||||
|
async function getLorasWidgetModule() {
|
||||||
|
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to get connected trigger toggle nodes
|
||||||
|
function getConnectedTriggerToggleNodes(node) {
|
||||||
|
const connectedNodes = [];
|
||||||
|
|
||||||
|
if (node.outputs && node.outputs.length > 0) {
|
||||||
|
for (const output of node.outputs) {
|
||||||
|
if (output.links && output.links.length > 0) {
|
||||||
|
for (const linkId of output.links) {
|
||||||
|
const link = app.graph.links[linkId];
|
||||||
|
if (link) {
|
||||||
|
const targetNode = app.graph.getNodeById(link.target_id);
|
||||||
|
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
||||||
|
connectedNodes.push(targetNode.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return connectedNodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to update trigger words for connected toggle nodes
|
||||||
|
function updateConnectedTriggerWords(node, text) {
|
||||||
|
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
|
||||||
|
if (connectedNodeIds.length > 0) {
|
||||||
|
const loraNames = new Set();
|
||||||
|
let match;
|
||||||
|
LORA_PATTERN.lastIndex = 0;
|
||||||
|
while ((match = LORA_PATTERN.exec(text)) !== null) {
|
||||||
|
loraNames.add(match[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
fetch("/loramanager/get_trigger_words", {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({
|
||||||
|
lora_names: Array.from(loraNames),
|
||||||
|
node_ids: connectedNodeIds
|
||||||
|
})
|
||||||
|
}).catch(err => console.error("Error fetching trigger words:", err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function mergeLoras(lorasText, lorasArr) {
|
function mergeLoras(lorasText, lorasArr) {
|
||||||
const result = [];
|
const result = [];
|
||||||
let match;
|
let match;
|
||||||
@@ -40,7 +89,7 @@ app.registerExtension({
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Wait for node to be properly initialized
|
// Wait for node to be properly initialized
|
||||||
requestAnimationFrame(() => {
|
requestAnimationFrame(async () => {
|
||||||
// Restore saved value if exists
|
// Restore saved value if exists
|
||||||
let existingLoras = [];
|
let existingLoras = [];
|
||||||
if (node.widgets_values && node.widgets_values.length > 0) {
|
if (node.widgets_values && node.widgets_values.length > 0) {
|
||||||
@@ -64,7 +113,10 @@ app.registerExtension({
|
|||||||
// Add flag to prevent callback loops
|
// Add flag to prevent callback loops
|
||||||
let isUpdating = false;
|
let isUpdating = false;
|
||||||
|
|
||||||
// Get the widget object directly from the returned object
|
// Dynamically load the appropriate widget module
|
||||||
|
const lorasModule = await getLorasWidgetModule();
|
||||||
|
const { addLorasWidget } = lorasModule;
|
||||||
|
|
||||||
const result = addLorasWidget(node, "loras", {
|
const result = addLorasWidget(node, "loras", {
|
||||||
defaultVal: mergedLoras // Pass object directly
|
defaultVal: mergedLoras // Pass object directly
|
||||||
}, (value) => {
|
}, (value) => {
|
||||||
@@ -86,6 +138,9 @@ app.registerExtension({
|
|||||||
newText = newText.replace(/\s+/g, ' ').trim();
|
newText = newText.replace(/\s+/g, ' ').trim();
|
||||||
|
|
||||||
inputWidget.value = newText;
|
inputWidget.value = newText;
|
||||||
|
|
||||||
|
// Update trigger words when lorasWidget changes
|
||||||
|
updateConnectedTriggerWords(node, newText);
|
||||||
} finally {
|
} finally {
|
||||||
isUpdating = false;
|
isUpdating = false;
|
||||||
}
|
}
|
||||||
@@ -104,6 +159,9 @@ app.registerExtension({
|
|||||||
const mergedLoras = mergeLoras(value, currentLoras);
|
const mergedLoras = mergeLoras(value, currentLoras);
|
||||||
|
|
||||||
node.lorasWidget.value = mergedLoras;
|
node.lorasWidget.value = mergedLoras;
|
||||||
|
|
||||||
|
// Update trigger words when input changes
|
||||||
|
updateConnectedTriggerWords(node, value);
|
||||||
} finally {
|
} finally {
|
||||||
isUpdating = false;
|
isUpdating = false;
|
||||||
}
|
}
|
||||||
|
|||||||
36
web/comfyui/usage_stats.js
Normal file
36
web/comfyui/usage_stats.js
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
// ComfyUI extension to track model usage statistics
|
||||||
|
import { app } from "../../scripts/app.js";
|
||||||
|
import { api } from "../../scripts/api.js";
|
||||||
|
|
||||||
|
// Register the extension
|
||||||
|
app.registerExtension({
|
||||||
|
name: "ComfyUI-Lora-Manager.UsageStats",
|
||||||
|
|
||||||
|
init() {
|
||||||
|
// Listen for successful executions
|
||||||
|
api.addEventListener("execution_success", ({ detail }) => {
|
||||||
|
if (detail && detail.prompt_id) {
|
||||||
|
this.updateUsageStats(detail.prompt_id);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
async updateUsageStats(promptId) {
|
||||||
|
try {
|
||||||
|
// Call backend endpoint with the prompt_id
|
||||||
|
const response = await fetch(`/loras/api/update-usage-stats`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ prompt_id: promptId }),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
console.warn("Failed to update usage statistics:", response.statusText);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error updating usage statistics:", error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user