mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f8e09fcde | ||
|
|
f54d480f03 | ||
|
|
e68b213fb3 | ||
|
|
132334d500 | ||
|
|
a6f04c6d7e | ||
|
|
854e8bf356 | ||
|
|
6ff883d2d3 | ||
|
|
849b97afba | ||
|
|
1bd2635864 | ||
|
|
79ab0f7b6c | ||
|
|
79011bd257 | ||
|
|
c692713ffb | ||
|
|
df9b554ce1 | ||
|
|
277a8e4682 | ||
|
|
acb52dba09 | ||
|
|
8f10765254 | ||
|
|
0653f59473 | ||
|
|
7a4b5a4667 | ||
|
|
49c4a4068b | ||
|
|
40ad590046 | ||
|
|
30374ae3e6 | ||
|
|
ab22d16bad | ||
|
|
971cd56a4a | ||
|
|
d7cb546c5f | ||
|
|
9d8b7344cd | ||
|
|
2d4f6ae7ce | ||
|
|
d9126807b0 | ||
|
|
cad5fb3fba | ||
|
|
afe23ad6b7 | ||
|
|
fc4327087b | ||
|
|
71762d788f | ||
|
|
6472e00fb0 | ||
|
|
4043846767 | ||
|
|
d3b2bc962c | ||
|
|
54f7b64821 |
18
README.md
18
README.md
@@ -18,10 +18,28 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
|||||||
|
|
||||||
[](https://youtu.be/hvKw31YpE-U)
|
[](https://youtu.be/hvKw31YpE-U)
|
||||||
|
|
||||||
|
## 🌐 Browser Extension
|
||||||
|
Enhance your Civitai browsing experience with our companion browser extension! See which models you already have, download new ones with a single click, and manage your downloads efficiently.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
<div>
|
||||||
|
<a href="https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb" style="display: inline-block; background-color: #4285F4; color: white; padding: 8px 16px; text-decoration: none; border-radius: 4px; font-weight: bold; margin: 10px 0;">
|
||||||
|
<img src="https://www.google.com/chrome/static/images/chrome-logo.svg" width="20" style="vertical-align: middle; margin-right: 8px;"> Get Extension from Chrome Web Store
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
📚 [Learn More: Complete Tutorial](https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/LoRA-Manager-Civitai-Extension-(Chrome-Extension))
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Release Notes
|
## Release Notes
|
||||||
|
|
||||||
|
### v0.8.20
|
||||||
|
* **LM Civitai Extension** - Released [browser extension through Chrome Web Store](https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb) that works seamlessly with LoRA Manager to enhance Civitai browsing experience, showing which models are already in your local library, enabling one-click downloads, and providing queue and parallel download support
|
||||||
|
* **Enhanced Lora Loader** - Added support for nunchaku, improving convenience when working with ComfyUI-nunchaku workflows, plus new template workflows for quick onboarding
|
||||||
|
* **WanVideo Integration** - Introduced WanVideo Lora Select (LoraManager) node compatible with ComfyUI-WanVideoWrapper for streamlined lora usage in video workflows, including a template workflow to help you get started quickly
|
||||||
|
|
||||||
### v0.8.19
|
### v0.8.19
|
||||||
* **Analytics Dashboard** - Added new Statistics page providing comprehensive visual analysis of model collection and usage patterns for better library insights
|
* **Analytics Dashboard** - Added new Statistics page providing comprehensive visual analysis of model collection and usage patterns for better library insights
|
||||||
* **Target Node Selection** - Enhanced workflow integration with intelligent target choosing when sending LoRAs/recipes to workflows with multiple loader/stacker nodes; a visual selector now appears showing node color, type, ID, and title for precise targeting
|
* **Target Node Selection** - Enhanced workflow integration with intelligent target choosing when sending LoRAs/recipes to workflows with multiple loader/stacker nodes; a visual selector now appears showing node color, type, ID, and title for precise targeting
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
|||||||
from .py.nodes.lora_stacker import LoraStacker
|
from .py.nodes.lora_stacker import LoraStacker
|
||||||
from .py.nodes.save_image import SaveImage
|
from .py.nodes.save_image import SaveImage
|
||||||
from .py.nodes.debug_metadata import DebugMetadata
|
from .py.nodes.debug_metadata import DebugMetadata
|
||||||
|
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect
|
||||||
# Import metadata collector to install hooks on startup
|
# Import metadata collector to install hooks on startup
|
||||||
from .py.metadata_collector import init as init_metadata_collector
|
from .py.metadata_collector import init as init_metadata_collector
|
||||||
|
|
||||||
@@ -12,7 +13,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
TriggerWordToggle.NAME: TriggerWordToggle,
|
TriggerWordToggle.NAME: TriggerWordToggle,
|
||||||
LoraStacker.NAME: LoraStacker,
|
LoraStacker.NAME: LoraStacker,
|
||||||
SaveImage.NAME: SaveImage,
|
SaveImage.NAME: SaveImage,
|
||||||
DebugMetadata.NAME: DebugMetadata
|
DebugMetadata.NAME: DebugMetadata,
|
||||||
|
WanVideoLoraSelect.NAME: WanVideoLoraSelect
|
||||||
}
|
}
|
||||||
|
|
||||||
WEB_DIRECTORY = "./web/comfyui"
|
WEB_DIRECTORY = "./web/comfyui"
|
||||||
|
|||||||
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 68 KiB |
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
84
py/config.py
84
py/config.py
@@ -22,7 +22,9 @@ class Config:
|
|||||||
# 静态路由映射字典, target to route mapping
|
# 静态路由映射字典, target to route mapping
|
||||||
self._route_mappings = {}
|
self._route_mappings = {}
|
||||||
self.loras_roots = self._init_lora_paths()
|
self.loras_roots = self._init_lora_paths()
|
||||||
self.checkpoints_roots = self._init_checkpoint_paths()
|
self.checkpoints_roots = None
|
||||||
|
self.unet_roots = None
|
||||||
|
self.base_models_roots = self._init_checkpoint_paths()
|
||||||
# 在初始化时扫描符号链接
|
# 在初始化时扫描符号链接
|
||||||
self._scan_symbolic_links()
|
self._scan_symbolic_links()
|
||||||
|
|
||||||
@@ -33,34 +35,26 @@ class Config:
|
|||||||
def save_folder_paths_to_settings(self):
|
def save_folder_paths_to_settings(self):
|
||||||
"""Save folder paths to settings.json for standalone mode to use later"""
|
"""Save folder paths to settings.json for standalone mode to use later"""
|
||||||
try:
|
try:
|
||||||
# Check if we're running in ComfyUI mode (not standalone)
|
# Check if we're running in ComfyUI mode (not standalone)
|
||||||
if hasattr(folder_paths, "get_folder_paths") and not isinstance(folder_paths, type):
|
# Load existing settings
|
||||||
# Get all relevant paths
|
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
|
||||||
lora_paths = folder_paths.get_folder_paths("loras")
|
settings = {}
|
||||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
if os.path.exists(settings_path):
|
||||||
diffuser_paths = folder_paths.get_folder_paths("diffusers")
|
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||||
unet_paths = folder_paths.get_folder_paths("unet")
|
settings = json.load(f)
|
||||||
|
|
||||||
# Load existing settings
|
# Update settings with paths
|
||||||
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
|
settings['folder_paths'] = {
|
||||||
settings = {}
|
'loras': self.loras_roots,
|
||||||
if os.path.exists(settings_path):
|
'checkpoints': self.checkpoints_roots,
|
||||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
'unet': self.unet_roots,
|
||||||
settings = json.load(f)
|
}
|
||||||
|
|
||||||
# Update settings with paths
|
# Save settings
|
||||||
settings['folder_paths'] = {
|
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||||
'loras': lora_paths,
|
json.dump(settings, f, indent=2)
|
||||||
'checkpoints': checkpoint_paths,
|
|
||||||
'diffusers': diffuser_paths,
|
logger.info("Saved folder paths to settings.json")
|
||||||
'unet': unet_paths
|
|
||||||
}
|
|
||||||
|
|
||||||
# Save settings
|
|
||||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(settings, f, indent=2)
|
|
||||||
|
|
||||||
logger.info("Saved folder paths to settings.json")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to save folder paths: {e}")
|
logger.warning(f"Failed to save folder paths: {e}")
|
||||||
|
|
||||||
@@ -86,7 +80,7 @@ class Config:
|
|||||||
for root in self.loras_roots:
|
for root in self.loras_roots:
|
||||||
self._scan_directory_links(root)
|
self._scan_directory_links(root)
|
||||||
|
|
||||||
for root in self.checkpoints_roots:
|
for root in self.base_models_roots:
|
||||||
self._scan_directory_links(root)
|
self._scan_directory_links(root)
|
||||||
|
|
||||||
def _scan_directory_links(self, root: str):
|
def _scan_directory_links(self, root: str):
|
||||||
@@ -178,30 +172,36 @@ class Config:
|
|||||||
try:
|
try:
|
||||||
# Get checkpoint paths from folder_paths
|
# Get checkpoint paths from folder_paths
|
||||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||||
diffusion_paths = folder_paths.get_folder_paths("diffusers")
|
|
||||||
unet_paths = folder_paths.get_folder_paths("unet")
|
unet_paths = folder_paths.get_folder_paths("unet")
|
||||||
|
|
||||||
# Combine all checkpoint-related paths
|
# Sort each list individually
|
||||||
all_paths = checkpoint_paths + diffusion_paths + unet_paths
|
checkpoint_paths = sorted(set(path.replace(os.sep, "/")
|
||||||
|
for path in checkpoint_paths
|
||||||
# Filter and normalize paths
|
|
||||||
paths = sorted(set(path.replace(os.sep, "/")
|
|
||||||
for path in all_paths
|
|
||||||
if os.path.exists(path)), key=lambda p: p.lower())
|
if os.path.exists(path)), key=lambda p: p.lower())
|
||||||
|
|
||||||
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(paths) if paths else "[]"))
|
unet_paths = sorted(set(path.replace(os.sep, "/")
|
||||||
|
for path in unet_paths
|
||||||
|
if os.path.exists(path)), key=lambda p: p.lower())
|
||||||
|
|
||||||
if not paths:
|
# Combine all checkpoint-related paths, ensuring checkpoint_paths are first
|
||||||
|
all_paths = checkpoint_paths + unet_paths
|
||||||
|
|
||||||
|
self.checkpoints_roots = checkpoint_paths
|
||||||
|
self.unet_roots = unet_paths
|
||||||
|
|
||||||
|
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]"))
|
||||||
|
|
||||||
|
if not all_paths:
|
||||||
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
|
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 初始化路径映射,与 LoRA 路径处理方式相同
|
# Initialize path mappings, similar to LoRA path handling
|
||||||
for path in paths:
|
for path in all_paths:
|
||||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||||
if real_path != path:
|
if real_path != path:
|
||||||
self.add_path_mapping(path, real_path)
|
self.add_path_mapping(path, real_path)
|
||||||
|
|
||||||
return paths
|
return all_paths
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error initializing checkpoint paths: {e}")
|
logger.warning(f"Error initializing checkpoint paths: {e}")
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class LoraManager:
|
|||||||
added_targets.add(real_root)
|
added_targets.add(real_root)
|
||||||
|
|
||||||
# Add static routes for each checkpoint root
|
# Add static routes for each checkpoint root
|
||||||
for idx, root in enumerate(config.checkpoints_roots, start=1):
|
for idx, root in enumerate(config.base_models_roots, start=1):
|
||||||
preview_path = f'/checkpoints_static/root{idx}/preview'
|
preview_path = f'/checkpoints_static/root{idx}/preview'
|
||||||
|
|
||||||
real_root = root
|
real_root = root
|
||||||
@@ -88,8 +88,8 @@ class LoraManager:
|
|||||||
for target_path, link_path in config._path_mappings.items():
|
for target_path, link_path in config._path_mappings.items():
|
||||||
if target_path not in added_targets:
|
if target_path not in added_targets:
|
||||||
# Determine if this is a checkpoint or lora link based on path
|
# Determine if this is a checkpoint or lora link based on path
|
||||||
is_checkpoint = any(cp_root in link_path for cp_root in config.checkpoints_roots)
|
is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots)
|
||||||
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.checkpoints_roots)
|
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots)
|
||||||
|
|
||||||
if is_checkpoint:
|
if is_checkpoint:
|
||||||
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
||||||
|
|||||||
@@ -238,25 +238,45 @@ class MetadataProcessor:
|
|||||||
|
|
||||||
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
|
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
|
||||||
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
|
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
|
||||||
|
|
||||||
# Try to match conditioning objects with those stored by CLIPTextEncodeExtractor
|
# Helper function to recursively find prompt text for a conditioning object
|
||||||
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True):
|
||||||
# For nodes with single conditioning output
|
if conditioning_obj is None:
|
||||||
if "conditioning" in prompt_data:
|
return ""
|
||||||
if pos_conditioning is not None and id(prompt_data["conditioning"]) == id(pos_conditioning):
|
|
||||||
result["prompt"] = prompt_data.get("text", "")
|
|
||||||
|
|
||||||
if neg_conditioning is not None and id(prompt_data["conditioning"]) == id(neg_conditioning):
|
# Try to match conditioning objects with those stored by extractors
|
||||||
result["negative_prompt"] = prompt_data.get("text", "")
|
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||||
|
# For nodes with single conditioning output
|
||||||
|
if "conditioning" in prompt_data:
|
||||||
|
if id(prompt_data["conditioning"]) == id(conditioning_obj):
|
||||||
|
return prompt_data.get("text", "")
|
||||||
|
|
||||||
|
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
||||||
|
if is_positive and "positive_encoded" in prompt_data:
|
||||||
|
if id(prompt_data["positive_encoded"]) == id(conditioning_obj):
|
||||||
|
if "positive_text" in prompt_data:
|
||||||
|
return prompt_data["positive_text"]
|
||||||
|
else:
|
||||||
|
orig_conditioning = prompt_data.get("orig_pos_cond", None)
|
||||||
|
if orig_conditioning is not None:
|
||||||
|
# Recursively find the prompt text for the original conditioning
|
||||||
|
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
|
||||||
|
|
||||||
|
if not is_positive and "negative_encoded" in prompt_data:
|
||||||
|
if id(prompt_data["negative_encoded"]) == id(conditioning_obj):
|
||||||
|
if "negative_text" in prompt_data:
|
||||||
|
return prompt_data["negative_text"]
|
||||||
|
else:
|
||||||
|
orig_conditioning = prompt_data.get("orig_neg_cond", None)
|
||||||
|
if orig_conditioning is not None:
|
||||||
|
# Recursively find the prompt text for the original conditioning
|
||||||
|
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False)
|
||||||
|
|
||||||
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
return ""
|
||||||
if "positive_encoded" in prompt_data:
|
|
||||||
if pos_conditioning is not None and id(prompt_data["positive_encoded"]) == id(pos_conditioning):
|
# Find prompt texts using the helper function
|
||||||
result["prompt"] = prompt_data.get("positive_text", "")
|
result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
|
||||||
|
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)
|
||||||
if "negative_encoded" in prompt_data:
|
|
||||||
if neg_conditioning is not None and id(prompt_data["negative_encoded"]) == id(neg_conditioning):
|
|
||||||
result["negative_prompt"] = prompt_data.get("negative_text", "")
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -569,6 +569,40 @@ class CFGGuiderExtractor(NodeMetadataExtractor):
|
|||||||
|
|
||||||
metadata[SAMPLING][node_id]["parameters"]["cfg"] = cfg_value
|
metadata[SAMPLING][node_id]["parameters"]["cfg"] = cfg_value
|
||||||
|
|
||||||
|
class CR_ApplyControlNetStackExtractor(NodeMetadataExtractor):
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
if not inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Save the original conditioning inputs
|
||||||
|
base_positive = inputs.get("base_positive")
|
||||||
|
base_negative = inputs.get("base_negative")
|
||||||
|
|
||||||
|
if base_positive is not None or base_negative is not None:
|
||||||
|
if node_id not in metadata[PROMPTS]:
|
||||||
|
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||||
|
|
||||||
|
metadata[PROMPTS][node_id]["orig_pos_cond"] = base_positive
|
||||||
|
metadata[PROMPTS][node_id]["orig_neg_cond"] = base_negative
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update(node_id, outputs, metadata):
|
||||||
|
# Extract transformed conditionings from outputs
|
||||||
|
# outputs structure: [(base_positive, base_negative, show_help, )]
|
||||||
|
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||||
|
first_output = outputs[0]
|
||||||
|
if isinstance(first_output, tuple) and len(first_output) >= 2:
|
||||||
|
transformed_positive = first_output[0]
|
||||||
|
transformed_negative = first_output[1]
|
||||||
|
|
||||||
|
# Save transformed conditioning objects in metadata
|
||||||
|
if node_id not in metadata[PROMPTS]:
|
||||||
|
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||||
|
|
||||||
|
metadata[PROMPTS][node_id]["positive_encoded"] = transformed_positive
|
||||||
|
metadata[PROMPTS][node_id]["negative_encoded"] = transformed_negative
|
||||||
|
|
||||||
# Registry of node-specific extractors
|
# Registry of node-specific extractors
|
||||||
# Keys are node class names
|
# Keys are node class names
|
||||||
NODE_EXTRACTORS = {
|
NODE_EXTRACTORS = {
|
||||||
@@ -594,6 +628,8 @@ NODE_EXTRACTORS = {
|
|||||||
"CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux
|
"CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux
|
||||||
"WAS_Text_to_Conditioning": CLIPTextEncodeExtractor,
|
"WAS_Text_to_Conditioning": CLIPTextEncodeExtractor,
|
||||||
"AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb
|
"AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb
|
||||||
|
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
|
||||||
|
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
|
||||||
# Latent
|
# Latent
|
||||||
"EmptyLatentImage": ImageSizeExtractor,
|
"EmptyLatentImage": ImageSizeExtractor,
|
||||||
# Flux
|
# Flux
|
||||||
|
|||||||
@@ -2,14 +2,15 @@ import logging
|
|||||||
from nodes import LoraLoader
|
from nodes import LoraLoader
|
||||||
from comfy.comfy_types import IO # type: ignore
|
from comfy.comfy_types import IO # type: ignore
|
||||||
import asyncio
|
import asyncio
|
||||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
from ..utils.utils import get_lora_info
|
||||||
|
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class LoraManagerLoader:
|
class LoraManagerLoader:
|
||||||
NAME = "Lora Loader (LoraManager)"
|
NAME = "Lora Loader (LoraManager)"
|
||||||
CATEGORY = "Lora Manager/loaders"
|
CATEGORY = "Lora Manager/loaders"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
return {
|
return {
|
||||||
@@ -37,19 +38,39 @@ class LoraManagerLoader:
|
|||||||
|
|
||||||
clip = kwargs.get('clip', None)
|
clip = kwargs.get('clip', None)
|
||||||
lora_stack = kwargs.get('lora_stack', None)
|
lora_stack = kwargs.get('lora_stack', None)
|
||||||
|
|
||||||
|
# Check if model is a Nunchaku Flux model - simplified approach
|
||||||
|
is_nunchaku_model = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_wrapper = model.model.diffusion_model
|
||||||
|
# Check if model is a Nunchaku Flux model using only class name
|
||||||
|
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||||
|
is_nunchaku_model = True
|
||||||
|
logger.info("Detected Nunchaku Flux model")
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
# Not a model with the expected structure
|
||||||
|
pass
|
||||||
|
|
||||||
# First process lora_stack if available
|
# First process lora_stack if available
|
||||||
if lora_stack:
|
if lora_stack:
|
||||||
for lora_path, model_strength, clip_strength in lora_stack:
|
for lora_path, model_strength, clip_strength in lora_stack:
|
||||||
# Apply the LoRA using the provided path and strengths
|
# Apply the LoRA using the appropriate loader
|
||||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
if is_nunchaku_model:
|
||||||
|
# Use our custom function for Flux models
|
||||||
|
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||||
|
# clip remains unchanged for Nunchaku models
|
||||||
|
else:
|
||||||
|
# Use default loader for standard models
|
||||||
|
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||||
|
|
||||||
# Extract lora name for trigger words lookup
|
# Extract lora name for trigger words lookup
|
||||||
lora_name = extract_lora_name(lora_path)
|
lora_name = extract_lora_name(lora_path)
|
||||||
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
|
|
||||||
all_trigger_words.extend(trigger_words)
|
all_trigger_words.extend(trigger_words)
|
||||||
# Add clip strength to output if different from model strength
|
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||||
if abs(model_strength - clip_strength) > 0.001:
|
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||||
else:
|
else:
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||||
@@ -68,11 +89,17 @@ class LoraManagerLoader:
|
|||||||
# Get lora path and trigger words
|
# Get lora path and trigger words
|
||||||
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
|
|
||||||
# Apply the LoRA using the resolved path with separate strengths
|
# Apply the LoRA using the appropriate loader
|
||||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
if is_nunchaku_model:
|
||||||
|
# For Nunchaku models, use our custom function
|
||||||
|
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||||
|
# clip remains unchanged
|
||||||
|
else:
|
||||||
|
# Use default loader for standard models
|
||||||
|
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||||
|
|
||||||
# Include clip strength in output if different from model strength
|
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||||
if abs(model_strength - clip_strength) > 0.001:
|
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||||
else:
|
else:
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
from comfy.comfy_types import IO # type: ignore
|
from comfy.comfy_types import IO # type: ignore
|
||||||
from ..services.lora_scanner import LoraScanner
|
|
||||||
from ..config import config
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
from ..utils.utils import get_lora_info
|
||||||
|
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -35,31 +35,11 @@ any_type = AnyType("*")
|
|||||||
# Common methods extracted from lora_loader.py and lora_stacker.py
|
# Common methods extracted from lora_loader.py and lora_stacker.py
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import copy
|
||||||
from ..services.lora_scanner import LoraScanner
|
import folder_paths
|
||||||
from ..config import config
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def get_lora_info(lora_name):
|
|
||||||
"""Get the lora path and trigger words from cache"""
|
|
||||||
scanner = await LoraScanner.get_instance()
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
|
|
||||||
for item in cache.raw_data:
|
|
||||||
if item.get('file_name') == lora_name:
|
|
||||||
file_path = item.get('file_path')
|
|
||||||
if file_path:
|
|
||||||
for root in config.loras_roots:
|
|
||||||
root = root.replace(os.sep, '/')
|
|
||||||
if file_path.startswith(root):
|
|
||||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
|
||||||
# Get trigger words from civitai metadata
|
|
||||||
civitai = item.get('civitai', {})
|
|
||||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
|
||||||
return relative_path, trigger_words
|
|
||||||
return lora_name, [] # Fallback if not found
|
|
||||||
|
|
||||||
def extract_lora_name(lora_path):
|
def extract_lora_name(lora_path):
|
||||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||||
# Get the basename without extension
|
# Get the basename without extension
|
||||||
@@ -81,4 +61,70 @@ def get_loras_list(kwargs):
|
|||||||
# Unexpected format
|
# Unexpected format
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
|
||||||
|
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
|
||||||
|
import safetensors.torch
|
||||||
|
|
||||||
|
state_dict = {}
|
||||||
|
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
|
||||||
|
for k in f.keys():
|
||||||
|
if filter_prefix and not k.startswith(filter_prefix):
|
||||||
|
continue
|
||||||
|
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def to_diffusers(input_lora):
|
||||||
|
"""Simplified version of to_diffusers for Flux LoRA conversion"""
|
||||||
|
import torch
|
||||||
|
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
||||||
|
from diffusers.loaders import FluxLoraLoaderMixin
|
||||||
|
|
||||||
|
if isinstance(input_lora, str):
|
||||||
|
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
|
||||||
|
else:
|
||||||
|
tensors = {k: v for k, v in input_lora.items()}
|
||||||
|
|
||||||
|
# Convert FP8 tensors to BF16
|
||||||
|
for k, v in tensors.items():
|
||||||
|
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
|
||||||
|
tensors[k] = v.to(torch.bfloat16)
|
||||||
|
|
||||||
|
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
|
||||||
|
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
|
||||||
|
|
||||||
|
return new_tensors
|
||||||
|
|
||||||
|
def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||||
|
"""Load a Flux LoRA for Nunchaku model"""
|
||||||
|
model_wrapper = model.model.diffusion_model
|
||||||
|
transformer = model_wrapper.model
|
||||||
|
|
||||||
|
# Save the transformer temporarily
|
||||||
|
model_wrapper.model = None
|
||||||
|
ret_model = copy.deepcopy(model) # copy everything except the model
|
||||||
|
ret_model_wrapper = ret_model.model.diffusion_model
|
||||||
|
|
||||||
|
# Restore the model and set it for the copy
|
||||||
|
model_wrapper.model = transformer
|
||||||
|
ret_model_wrapper.model = transformer
|
||||||
|
|
||||||
|
# Get full path to the LoRA file
|
||||||
|
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||||
|
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||||
|
|
||||||
|
# Convert the LoRA to diffusers format
|
||||||
|
sd = to_diffusers(lora_path)
|
||||||
|
|
||||||
|
# Handle embedding adjustment if needed
|
||||||
|
if "transformer.x_embedder.lora_A.weight" in sd:
|
||||||
|
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
|
||||||
|
assert new_in_channels % 4 == 0
|
||||||
|
new_in_channels = new_in_channels // 4
|
||||||
|
|
||||||
|
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
|
||||||
|
if old_in_channels < new_in_channels:
|
||||||
|
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
|
||||||
|
|
||||||
|
return ret_model
|
||||||
93
py/nodes/wanvideo_lora_select.py
Normal file
93
py/nodes/wanvideo_lora_select.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
from comfy.comfy_types import IO # type: ignore
|
||||||
|
import asyncio
|
||||||
|
import folder_paths # type: ignore
|
||||||
|
from ..utils.utils import get_lora_info
|
||||||
|
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class WanVideoLoraSelect:
|
||||||
|
NAME = "WanVideo Lora Select (LoraManager)"
|
||||||
|
CATEGORY = "Lora Manager/stackers"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load the LORA model with less VRAM usage, slower loading"}),
|
||||||
|
"text": (IO.STRING, {
|
||||||
|
"multiline": True,
|
||||||
|
"dynamicPrompts": True,
|
||||||
|
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||||
|
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
"optional": FlexibleOptionalInputType(any_type),
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("WANVIDLORA", IO.STRING, IO.STRING)
|
||||||
|
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
|
||||||
|
FUNCTION = "process_loras"
|
||||||
|
|
||||||
|
def process_loras(self, text, low_mem_load=False, **kwargs):
|
||||||
|
loras_list = []
|
||||||
|
all_trigger_words = []
|
||||||
|
active_loras = []
|
||||||
|
|
||||||
|
# Process existing prev_lora if available
|
||||||
|
prev_lora = kwargs.get('prev_lora', None)
|
||||||
|
if prev_lora is not None:
|
||||||
|
loras_list.extend(prev_lora)
|
||||||
|
|
||||||
|
# Get blocks if available
|
||||||
|
blocks = kwargs.get('blocks', {})
|
||||||
|
selected_blocks = blocks.get("selected_blocks", {})
|
||||||
|
layer_filter = blocks.get("layer_filter", "")
|
||||||
|
|
||||||
|
# Process loras from kwargs with support for both old and new formats
|
||||||
|
loras_from_widget = get_loras_list(kwargs)
|
||||||
|
for lora in loras_from_widget:
|
||||||
|
if not lora.get('active', False):
|
||||||
|
continue
|
||||||
|
|
||||||
|
lora_name = lora['name']
|
||||||
|
model_strength = float(lora['strength'])
|
||||||
|
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||||
|
|
||||||
|
# Get lora path and trigger words
|
||||||
|
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
|
|
||||||
|
# Create lora item for WanVideo format
|
||||||
|
lora_item = {
|
||||||
|
"path": folder_paths.get_full_path("loras", lora_path),
|
||||||
|
"strength": model_strength,
|
||||||
|
"name": lora_path.split(".")[0],
|
||||||
|
"blocks": selected_blocks,
|
||||||
|
"layer_filter": layer_filter,
|
||||||
|
"low_mem_load": low_mem_load,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add to list and collect active loras
|
||||||
|
loras_list.append(lora_item)
|
||||||
|
active_loras.append((lora_name, model_strength, clip_strength))
|
||||||
|
|
||||||
|
# Add trigger words to collection
|
||||||
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
|
# Format trigger_words for output
|
||||||
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
|
||||||
|
# Format active_loras for output
|
||||||
|
formatted_loras = []
|
||||||
|
for name, model_strength, clip_strength in active_loras:
|
||||||
|
if abs(model_strength - clip_strength) > 0.001:
|
||||||
|
# Different model and clip strengths
|
||||||
|
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||||
|
else:
|
||||||
|
# Same strength for both
|
||||||
|
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||||
|
|
||||||
|
active_loras_text = " ".join(formatted_loras)
|
||||||
|
|
||||||
|
return (loras_list, trigger_words_text, active_loras_text)
|
||||||
@@ -19,7 +19,7 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
|||||||
LORA_HASHES_REGEX = r', Lora hashes:\s*"([^"]+)"'
|
LORA_HASHES_REGEX = r', Lora hashes:\s*"([^"]+)"'
|
||||||
CIVITAI_RESOURCES_REGEX = r', Civitai resources:\s*(\[\{.*?\}\])'
|
CIVITAI_RESOURCES_REGEX = r', Civitai resources:\s*(\[\{.*?\}\])'
|
||||||
CIVITAI_METADATA_REGEX = r', Civitai metadata:\s*(\{.*?\})'
|
CIVITAI_METADATA_REGEX = r', Civitai metadata:\s*(\{.*?\})'
|
||||||
EXTRANETS_REGEX = r'<(lora|hypernet):([a-zA-Z0-9_\.\-]+):([0-9.]+)>'
|
EXTRANETS_REGEX = r'<(lora|hypernet):([^:]+):(-?[0-9.]+)>'
|
||||||
MODEL_HASH_PATTERN = r'Model hash: ([a-zA-Z0-9]+)'
|
MODEL_HASH_PATTERN = r'Model hash: ([a-zA-Z0-9]+)'
|
||||||
VAE_HASH_PATTERN = r'VAE hash: ([a-zA-Z0-9]+)'
|
VAE_HASH_PATTERN = r'VAE hash: ([a-zA-Z0-9]+)'
|
||||||
|
|
||||||
|
|||||||
@@ -50,6 +50,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
'from_civitai_image': True
|
'from_civitai_image': True
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Track already added LoRAs to prevent duplicates
|
||||||
|
added_loras = {} # key: model_version_id or hash, value: index in result["loras"]
|
||||||
|
|
||||||
# Extract prompt and negative prompt
|
# Extract prompt and negative prompt
|
||||||
if "prompt" in metadata:
|
if "prompt" in metadata:
|
||||||
result["gen_params"]["prompt"] = metadata["prompt"]
|
result["gen_params"]["prompt"] = metadata["prompt"]
|
||||||
@@ -96,11 +99,17 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
for resource in metadata["resources"]:
|
for resource in metadata["resources"]:
|
||||||
# Modified to process resources without a type field as potential LoRAs
|
# Modified to process resources without a type field as potential LoRAs
|
||||||
if resource.get("type", "lora") == "lora":
|
if resource.get("type", "lora") == "lora":
|
||||||
|
lora_hash = resource.get("hash", "")
|
||||||
|
|
||||||
|
# Skip if we've already added this LoRA by hash
|
||||||
|
if lora_hash and lora_hash in added_loras:
|
||||||
|
continue
|
||||||
|
|
||||||
lora_entry = {
|
lora_entry = {
|
||||||
'name': resource.get("name", "Unknown LoRA"),
|
'name': resource.get("name", "Unknown LoRA"),
|
||||||
'type': "lora",
|
'type': "lora",
|
||||||
'weight': float(resource.get("weight", 1.0)),
|
'weight': float(resource.get("weight", 1.0)),
|
||||||
'hash': resource.get("hash", ""),
|
'hash': lora_hash,
|
||||||
'existsLocally': False,
|
'existsLocally': False,
|
||||||
'localPath': None,
|
'localPath': None,
|
||||||
'file_name': resource.get("name", "Unknown"),
|
'file_name': resource.get("name", "Unknown"),
|
||||||
@@ -114,7 +123,6 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
# Try to get info from Civitai if hash is available
|
# Try to get info from Civitai if hash is available
|
||||||
if lora_entry['hash'] and civitai_client:
|
if lora_entry['hash'] and civitai_client:
|
||||||
try:
|
try:
|
||||||
lora_hash = lora_entry['hash']
|
|
||||||
civitai_info = await civitai_client.get_model_by_hash(lora_hash)
|
civitai_info = await civitai_client.get_model_by_hash(lora_hash)
|
||||||
|
|
||||||
populated_entry = await self.populate_lora_from_civitai(
|
populated_entry = await self.populate_lora_from_civitai(
|
||||||
@@ -129,43 +137,124 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
continue # Skip invalid LoRA types
|
continue # Skip invalid LoRA types
|
||||||
|
|
||||||
lora_entry = populated_entry
|
lora_entry = populated_entry
|
||||||
|
|
||||||
|
# If we have a version ID from Civitai, track it for deduplication
|
||||||
|
if 'id' in lora_entry and lora_entry['id']:
|
||||||
|
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
||||||
|
|
||||||
|
# Track by hash if we have it
|
||||||
|
if lora_hash:
|
||||||
|
added_loras[lora_hash] = len(result["loras"])
|
||||||
|
|
||||||
result["loras"].append(lora_entry)
|
result["loras"].append(lora_entry)
|
||||||
|
|
||||||
# Process civitaiResources array
|
# Process civitaiResources array
|
||||||
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
|
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
|
||||||
for resource in metadata["civitaiResources"]:
|
for resource in metadata["civitaiResources"]:
|
||||||
# Modified to process resources without a type field as potential LoRAs
|
# Skip resources that aren't LoRAs or LyCORIS
|
||||||
if resource.get("type") in ["lora", "lycoris"] or "type" not in resource:
|
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
|
||||||
# Initialize lora entry with the same structure as in automatic.py
|
continue
|
||||||
lora_entry = {
|
|
||||||
'id': resource.get("modelVersionId", 0),
|
|
||||||
'modelId': resource.get("modelId", 0),
|
|
||||||
'name': resource.get("modelName", "Unknown LoRA"),
|
|
||||||
'version': resource.get("modelVersionName", ""),
|
|
||||||
'type': resource.get("type", "lora"),
|
|
||||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
|
||||||
'existsLocally': False,
|
|
||||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
|
||||||
'baseModel': '',
|
|
||||||
'size': 0,
|
|
||||||
'downloadUrl': '',
|
|
||||||
'isDeleted': False
|
|
||||||
}
|
|
||||||
|
|
||||||
# Try to get info from Civitai if modelVersionId is available
|
# Get unique identifier for deduplication
|
||||||
if resource.get('modelVersionId') and civitai_client:
|
version_id = str(resource.get("modelVersionId", ""))
|
||||||
try:
|
|
||||||
version_id = str(resource.get('modelVersionId'))
|
# Skip if we've already added this LoRA
|
||||||
# Use get_model_version_info instead of get_model_version
|
if version_id and version_id in added_loras:
|
||||||
civitai_info, error = await civitai_client.get_model_version_info(version_id)
|
continue
|
||||||
|
|
||||||
if error:
|
# Initialize lora entry
|
||||||
logger.warning(f"Error getting model version info: {error}")
|
lora_entry = {
|
||||||
continue
|
'id': resource.get("modelVersionId", 0),
|
||||||
|
'modelId': resource.get("modelId", 0),
|
||||||
|
'name': resource.get("modelName", "Unknown LoRA"),
|
||||||
|
'version': resource.get("modelVersionName", ""),
|
||||||
|
'type': resource.get("type", "lora"),
|
||||||
|
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||||
|
'existsLocally': False,
|
||||||
|
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||||
|
'baseModel': '',
|
||||||
|
'size': 0,
|
||||||
|
'downloadUrl': '',
|
||||||
|
'isDeleted': False
|
||||||
|
}
|
||||||
|
|
||||||
|
# Try to get info from Civitai if modelVersionId is available
|
||||||
|
if version_id and civitai_client:
|
||||||
|
try:
|
||||||
|
# Use get_model_version_info instead of get_model_version
|
||||||
|
civitai_info, error = await civitai_client.get_model_version_info(version_id)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
logger.warning(f"Error getting model version info: {error}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
populated_entry = await self.populate_lora_from_civitai(
|
||||||
|
lora_entry,
|
||||||
|
civitai_info,
|
||||||
|
recipe_scanner,
|
||||||
|
base_model_counts
|
||||||
|
)
|
||||||
|
|
||||||
|
if populated_entry is None:
|
||||||
|
continue # Skip invalid LoRA types
|
||||||
|
|
||||||
|
lora_entry = populated_entry
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}")
|
||||||
|
|
||||||
|
# Track this LoRA in our deduplication dict
|
||||||
|
if version_id:
|
||||||
|
added_loras[version_id] = len(result["loras"])
|
||||||
|
|
||||||
|
result["loras"].append(lora_entry)
|
||||||
|
|
||||||
|
# Process additionalResources array
|
||||||
|
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
|
||||||
|
for resource in metadata["additionalResources"]:
|
||||||
|
# Skip resources that aren't LoRAs or LyCORIS
|
||||||
|
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
|
||||||
|
continue
|
||||||
|
|
||||||
|
lora_type = resource.get("type", "lora")
|
||||||
|
name = resource.get("name", "")
|
||||||
|
|
||||||
|
# Extract ID from URN format if available
|
||||||
|
version_id = None
|
||||||
|
if name and "civitai:" in name:
|
||||||
|
parts = name.split("@")
|
||||||
|
if len(parts) > 1:
|
||||||
|
version_id = parts[1]
|
||||||
|
|
||||||
|
# Skip if we've already added this LoRA
|
||||||
|
if version_id in added_loras:
|
||||||
|
continue
|
||||||
|
|
||||||
|
lora_entry = {
|
||||||
|
'name': name,
|
||||||
|
'type': lora_type,
|
||||||
|
'weight': float(resource.get("strength", 1.0)),
|
||||||
|
'hash': "",
|
||||||
|
'existsLocally': False,
|
||||||
|
'localPath': None,
|
||||||
|
'file_name': name,
|
||||||
|
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||||
|
'baseModel': '',
|
||||||
|
'size': 0,
|
||||||
|
'downloadUrl': '',
|
||||||
|
'isDeleted': False
|
||||||
|
}
|
||||||
|
|
||||||
|
# If we have a version ID and civitai client, try to get more info
|
||||||
|
if version_id and civitai_client:
|
||||||
|
try:
|
||||||
|
# Use get_model_version_info with the version ID
|
||||||
|
civitai_info, error = await civitai_client.get_model_version_info(version_id)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
logger.warning(f"Error getting model version info: {error}")
|
||||||
|
else:
|
||||||
populated_entry = await self.populate_lora_from_civitai(
|
populated_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info,
|
||||||
@@ -177,65 +266,14 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
|||||||
continue # Skip invalid LoRA types
|
continue # Skip invalid LoRA types
|
||||||
|
|
||||||
lora_entry = populated_entry
|
lora_entry = populated_entry
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching Civitai info for model version {resource.get('modelVersionId')}: {e}")
|
|
||||||
|
|
||||||
result["loras"].append(lora_entry)
|
|
||||||
|
|
||||||
# Process additionalResources array
|
|
||||||
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
|
|
||||||
for resource in metadata["additionalResources"]:
|
|
||||||
# Modified to process resources without a type field as potential LoRAs
|
|
||||||
if resource.get("type") in ["lora", "lycoris"] or "type" not in resource:
|
|
||||||
lora_type = resource.get("type", "lora")
|
|
||||||
name = resource.get("name", "")
|
|
||||||
|
|
||||||
# Extract ID from URN format if available
|
|
||||||
model_id = None
|
|
||||||
if name and "civitai:" in name:
|
|
||||||
parts = name.split("@")
|
|
||||||
if len(parts) > 1:
|
|
||||||
model_id = parts[1]
|
|
||||||
|
|
||||||
lora_entry = {
|
|
||||||
'name': name,
|
|
||||||
'type': lora_type,
|
|
||||||
'weight': float(resource.get("strength", 1.0)),
|
|
||||||
'hash': "",
|
|
||||||
'existsLocally': False,
|
|
||||||
'localPath': None,
|
|
||||||
'file_name': name,
|
|
||||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
|
||||||
'baseModel': '',
|
|
||||||
'size': 0,
|
|
||||||
'downloadUrl': '',
|
|
||||||
'isDeleted': False
|
|
||||||
}
|
|
||||||
|
|
||||||
# If we have a model ID and civitai client, try to get more info
|
|
||||||
if model_id and civitai_client:
|
|
||||||
try:
|
|
||||||
# Use get_model_version_info with the model ID
|
|
||||||
civitai_info, error = await civitai_client.get_model_version_info(model_id)
|
|
||||||
|
|
||||||
if error:
|
# Track this LoRA for deduplication
|
||||||
logger.warning(f"Error getting model version info: {error}")
|
if version_id:
|
||||||
else:
|
added_loras[version_id] = len(result["loras"])
|
||||||
populated_entry = await self.populate_lora_from_civitai(
|
except Exception as e:
|
||||||
lora_entry,
|
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
|
||||||
civitai_info,
|
|
||||||
recipe_scanner,
|
result["loras"].append(lora_entry)
|
||||||
base_model_counts
|
|
||||||
)
|
|
||||||
|
|
||||||
if populated_entry is None:
|
|
||||||
continue # Skip invalid LoRA types
|
|
||||||
|
|
||||||
lora_entry = populated_entry
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error fetching Civitai info for model ID {model_id}: {e}")
|
|
||||||
|
|
||||||
result["loras"].append(lora_entry)
|
|
||||||
|
|
||||||
# If base model wasn't found earlier, use the most common one from LoRAs
|
# If base model wasn't found earlier, use the most common one from LoRAs
|
||||||
if not result["base_model"] and base_model_counts:
|
if not result["base_model"] and base_model_counts:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Dict
|
|||||||
from server import PromptServer # type: ignore
|
from server import PromptServer # type: ignore
|
||||||
|
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
from ..nodes.utils import get_lora_info
|
from ..utils.utils import get_lora_info
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..services.websocket_manager import ws_manager
|
from ..services.websocket_manager import ws_manager
|
||||||
@@ -50,13 +50,14 @@ class ApiRoutes:
|
|||||||
app.router.add_get('/api/loras', routes.get_loras)
|
app.router.add_get('/api/loras', routes.get_loras)
|
||||||
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
|
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
|
||||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||||
|
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) # Add new WebSocket route for download progress
|
||||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route
|
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route
|
||||||
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
|
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
|
||||||
app.router.add_get('/api/folders', routes.get_folders)
|
app.router.add_get('/api/folders', routes.get_folders)
|
||||||
app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
|
app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
|
||||||
app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version)
|
app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version)
|
||||||
app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash)
|
app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash)
|
||||||
app.router.add_post('/api/download-lora', routes.download_lora)
|
app.router.add_post('/api/download-model', routes.download_model)
|
||||||
app.router.add_post('/api/move_model', routes.move_model)
|
app.router.add_post('/api/move_model', routes.move_model)
|
||||||
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
|
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
|
||||||
app.router.add_post('/api/loras/save-metadata', routes.save_metadata)
|
app.router.add_post('/api/loras/save-metadata', routes.save_metadata)
|
||||||
@@ -436,69 +437,8 @@ class ApiRoutes:
|
|||||||
"error": str(e)
|
"error": str(e)
|
||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
async def download_lora(self, request: web.Request) -> web.Response:
|
async def download_model(self, request: web.Request) -> web.Response:
|
||||||
async with self._download_lock:
|
return await ModelRouteUtils.handle_download_model(request, self.download_manager)
|
||||||
try:
|
|
||||||
if self.download_manager is None:
|
|
||||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
|
||||||
|
|
||||||
data = await request.json()
|
|
||||||
|
|
||||||
# Create progress callback
|
|
||||||
async def progress_callback(progress):
|
|
||||||
await ws_manager.broadcast({
|
|
||||||
'status': 'progress',
|
|
||||||
'progress': progress
|
|
||||||
})
|
|
||||||
|
|
||||||
# Check which identifier is provided
|
|
||||||
download_url = data.get('download_url')
|
|
||||||
model_hash = data.get('model_hash')
|
|
||||||
model_version_id = data.get('model_version_id')
|
|
||||||
|
|
||||||
# Validate that at least one identifier is provided
|
|
||||||
if not any([download_url, model_hash, model_version_id]):
|
|
||||||
return web.Response(
|
|
||||||
status=400,
|
|
||||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await self.download_manager.download_from_civitai(
|
|
||||||
download_url=download_url,
|
|
||||||
model_hash=model_hash,
|
|
||||||
model_version_id=model_version_id,
|
|
||||||
save_dir=data.get('lora_root'),
|
|
||||||
relative_path=data.get('relative_path'),
|
|
||||||
progress_callback=progress_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
if not result.get('success', False):
|
|
||||||
error_message = result.get('error', 'Unknown error')
|
|
||||||
|
|
||||||
# Return 401 for early access errors
|
|
||||||
if 'early access' in error_message.lower():
|
|
||||||
logger.warning(f"Early access download failed: {error_message}")
|
|
||||||
return web.Response(
|
|
||||||
status=401, # Use 401 status code to match Civitai's response
|
|
||||||
text=error_message
|
|
||||||
)
|
|
||||||
|
|
||||||
return web.Response(status=500, text=error_message)
|
|
||||||
|
|
||||||
return web.json_response(result)
|
|
||||||
except Exception as e:
|
|
||||||
error_message = str(e)
|
|
||||||
|
|
||||||
# Check if this might be an early access error
|
|
||||||
if '401' in error_message:
|
|
||||||
logger.warning(f"Early access error (401): {error_message}")
|
|
||||||
return web.Response(
|
|
||||||
status=401,
|
|
||||||
text="Early Access Restriction: This LoRA requires purchase. Please buy early access on Civitai.com."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.error(f"Error downloading LoRA: {error_message}")
|
|
||||||
return web.Response(status=500, text=error_message)
|
|
||||||
|
|
||||||
|
|
||||||
async def move_model(self, request: web.Request) -> web.Response:
|
async def move_model(self, request: web.Request) -> web.Response:
|
||||||
|
|||||||
@@ -54,12 +54,8 @@ class CheckpointsRoutes:
|
|||||||
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
|
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
|
||||||
app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint
|
app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint
|
||||||
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
|
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
|
||||||
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
|
|
||||||
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
|
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
|
||||||
app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint
|
app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint
|
||||||
|
|
||||||
# Add new WebSocket endpoint for checkpoint progress
|
|
||||||
app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)
|
|
||||||
|
|
||||||
# Add new routes for finding duplicates and filename conflicts
|
# Add new routes for finding duplicates and filename conflicts
|
||||||
app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints)
|
app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints)
|
||||||
@@ -542,74 +538,6 @@ class CheckpointsRoutes:
|
|||||||
"""Handle preview image replacement for checkpoints"""
|
"""Handle preview image replacement for checkpoints"""
|
||||||
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
||||||
|
|
||||||
async def download_checkpoint(self, request: web.Request) -> web.Response:
|
|
||||||
"""Handle checkpoint download request"""
|
|
||||||
async with self._download_lock:
|
|
||||||
# Get the download manager from service registry if not already initialized
|
|
||||||
if self.download_manager is None:
|
|
||||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = await request.json()
|
|
||||||
|
|
||||||
# Create progress callback that uses checkpoint-specific WebSocket
|
|
||||||
async def progress_callback(progress):
|
|
||||||
await ws_manager.broadcast_checkpoint_progress({
|
|
||||||
'status': 'progress',
|
|
||||||
'progress': progress
|
|
||||||
})
|
|
||||||
|
|
||||||
# Check which identifier is provided
|
|
||||||
download_url = data.get('download_url')
|
|
||||||
model_hash = data.get('model_hash')
|
|
||||||
model_version_id = data.get('model_version_id')
|
|
||||||
|
|
||||||
# Validate that at least one identifier is provided
|
|
||||||
if not any([download_url, model_hash, model_version_id]):
|
|
||||||
return web.Response(
|
|
||||||
status=400,
|
|
||||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await self.download_manager.download_from_civitai(
|
|
||||||
download_url=download_url,
|
|
||||||
model_hash=model_hash,
|
|
||||||
model_version_id=model_version_id,
|
|
||||||
save_dir=data.get('checkpoint_root'),
|
|
||||||
relative_path=data.get('relative_path', ''),
|
|
||||||
progress_callback=progress_callback,
|
|
||||||
model_type="checkpoint"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not result.get('success', False):
|
|
||||||
error_message = result.get('error', 'Unknown error')
|
|
||||||
|
|
||||||
# Return 401 for early access errors
|
|
||||||
if 'early access' in error_message.lower():
|
|
||||||
logger.warning(f"Early access download failed: {error_message}")
|
|
||||||
return web.Response(
|
|
||||||
status=401,
|
|
||||||
text=f"Early Access Restriction: {error_message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return web.Response(status=500, text=error_message)
|
|
||||||
|
|
||||||
return web.json_response(result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_message = str(e)
|
|
||||||
|
|
||||||
# Check if this might be an early access error
|
|
||||||
if '401' in error_message:
|
|
||||||
logger.warning(f"Early access error (401): {error_message}")
|
|
||||||
return web.Response(
|
|
||||||
status=401,
|
|
||||||
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.error(f"Error downloading checkpoint: {error_message}")
|
|
||||||
return web.Response(status=500, text=error_message)
|
|
||||||
|
|
||||||
async def get_checkpoint_roots(self, request):
|
async def get_checkpoint_roots(self, request):
|
||||||
"""Return the checkpoint root directories"""
|
"""Return the checkpoint root directories"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from ..utils.example_images_download_manager import DownloadManager
|
from ..utils.example_images_download_manager import DownloadManager
|
||||||
from ..utils.example_images_processor import ExampleImagesProcessor
|
from ..utils.example_images_processor import ExampleImagesProcessor
|
||||||
from ..utils.example_images_metadata import MetadataUpdater
|
|
||||||
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from ..utils.usage_stats import UsageStats
|
|||||||
from ..utils.lora_metadata import extract_trained_words
|
from ..utils.lora_metadata import extract_trained_words
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR
|
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
import re
|
import re
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -90,6 +91,8 @@ class MiscRoutes:
|
|||||||
# Add new route for clearing cache
|
# Add new route for clearing cache
|
||||||
app.router.add_post('/api/clear-cache', MiscRoutes.clear_cache)
|
app.router.add_post('/api/clear-cache', MiscRoutes.clear_cache)
|
||||||
|
|
||||||
|
app.router.add_get('/api/health-check', lambda request: web.json_response({'status': 'ok'}))
|
||||||
|
|
||||||
# Usage stats routes
|
# Usage stats routes
|
||||||
app.router.add_post('/api/update-usage-stats', MiscRoutes.update_usage_stats)
|
app.router.add_post('/api/update-usage-stats', MiscRoutes.update_usage_stats)
|
||||||
app.router.add_get('/api/get-usage-stats', MiscRoutes.get_usage_stats)
|
app.router.add_get('/api/get-usage-stats', MiscRoutes.get_usage_stats)
|
||||||
@@ -106,6 +109,9 @@ class MiscRoutes:
|
|||||||
# Node registry endpoints
|
# Node registry endpoints
|
||||||
app.router.add_post('/api/register-nodes', MiscRoutes.register_nodes)
|
app.router.add_post('/api/register-nodes', MiscRoutes.register_nodes)
|
||||||
app.router.add_get('/api/get-registry', MiscRoutes.get_registry)
|
app.router.add_get('/api/get-registry', MiscRoutes.get_registry)
|
||||||
|
|
||||||
|
# Add new route for checking if a model exists in the library
|
||||||
|
app.router.add_get('/api/check-model-exists', MiscRoutes.check_model_exists)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def clear_cache(request):
|
async def clear_cache(request):
|
||||||
@@ -580,3 +586,106 @@ class MiscRoutes:
|
|||||||
'error': 'Internal Error',
|
'error': 'Internal Error',
|
||||||
'message': str(e)
|
'message': str(e)
|
||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def check_model_exists(request):
|
||||||
|
"""
|
||||||
|
Check if a model with specified modelId and optionally modelVersionId exists in the library
|
||||||
|
|
||||||
|
Expects query parameters:
|
||||||
|
- modelId: int - Civitai model ID (required)
|
||||||
|
- modelVersionId: int - Civitai model version ID (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- If modelVersionId is provided: JSON with a boolean 'exists' field
|
||||||
|
- If modelVersionId is not provided: JSON with a list of modelVersionIds that exist in the library
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get the modelId and modelVersionId from query parameters
|
||||||
|
model_id_str = request.query.get('modelId')
|
||||||
|
model_version_id_str = request.query.get('modelVersionId')
|
||||||
|
|
||||||
|
# Validate modelId parameter (required)
|
||||||
|
if not model_id_str:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Missing required parameter: modelId'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert modelId to integer
|
||||||
|
model_id = int(model_id_str)
|
||||||
|
except ValueError:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Parameter modelId must be an integer'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Get both lora and checkpoint scanners
|
||||||
|
registry = ServiceRegistry.get_instance()
|
||||||
|
lora_scanner = await registry.get_lora_scanner()
|
||||||
|
checkpoint_scanner = await registry.get_checkpoint_scanner()
|
||||||
|
|
||||||
|
# If modelVersionId is provided, check for specific version
|
||||||
|
if model_version_id_str:
|
||||||
|
try:
|
||||||
|
model_version_id = int(model_version_id_str)
|
||||||
|
except ValueError:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Parameter modelVersionId must be an integer'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Check if the specific version exists in either scanner
|
||||||
|
exists = False
|
||||||
|
model_type = None
|
||||||
|
|
||||||
|
# Check lora scanner first
|
||||||
|
if await lora_scanner.check_model_version_exists(model_id, model_version_id):
|
||||||
|
exists = True
|
||||||
|
model_type = 'lora'
|
||||||
|
# If not found in lora, check checkpoint scanner
|
||||||
|
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
|
||||||
|
exists = True
|
||||||
|
model_type = 'checkpoint'
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'exists': exists,
|
||||||
|
'modelType': model_type if exists else None
|
||||||
|
})
|
||||||
|
|
||||||
|
# If modelVersionId is not provided, return all version IDs for the model
|
||||||
|
else:
|
||||||
|
# Get versions from lora scanner first
|
||||||
|
lora_versions = await lora_scanner.get_model_versions_by_id(model_id)
|
||||||
|
checkpoint_versions = []
|
||||||
|
|
||||||
|
# Only check checkpoint scanner if no lora versions found
|
||||||
|
if not lora_versions:
|
||||||
|
checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id)
|
||||||
|
|
||||||
|
# Determine model type and combine results
|
||||||
|
model_type = None
|
||||||
|
versions = []
|
||||||
|
|
||||||
|
if lora_versions:
|
||||||
|
model_type = 'lora'
|
||||||
|
versions = lora_versions
|
||||||
|
elif checkpoint_versions:
|
||||||
|
model_type = 'checkpoint'
|
||||||
|
versions = checkpoint_versions
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'modelId': model_id,
|
||||||
|
'modelType': model_type,
|
||||||
|
'versions': versions
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to check model existence: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import time
|
|||||||
import base64
|
import base64
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torch
|
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
@@ -1018,6 +1017,8 @@ class RecipeRoutes:
|
|||||||
shape_info = tensor_image.shape
|
shape_info = tensor_image.shape
|
||||||
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
|
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
# Convert tensor to numpy array
|
# Convert tensor to numpy array
|
||||||
if isinstance(tensor_image, torch.Tensor):
|
if isinstance(tensor_image, torch.Tensor):
|
||||||
image_np = tensor_image.cpu().numpy()
|
image_np = tensor_image.cpu().numpy()
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ class CheckpointScanner(ModelScanner):
|
|||||||
file_extensions=file_extensions,
|
file_extensions=file_extensions,
|
||||||
hash_index=ModelHashIndex()
|
hash_index=ModelHashIndex()
|
||||||
)
|
)
|
||||||
self._checkpoint_roots = self._init_checkpoint_roots()
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -44,27 +43,9 @@ class CheckpointScanner(ModelScanner):
|
|||||||
cls._instance = cls()
|
cls._instance = cls()
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def _init_checkpoint_roots(self) -> List[str]:
|
|
||||||
"""Initialize checkpoint roots from ComfyUI settings"""
|
|
||||||
# Get both checkpoint and diffusion_models paths
|
|
||||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
|
||||||
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
|
|
||||||
|
|
||||||
# Combine, normalize and deduplicate paths
|
|
||||||
all_paths = set()
|
|
||||||
for path in checkpoint_paths + diffusion_paths:
|
|
||||||
if os.path.exists(path):
|
|
||||||
norm_path = path.replace(os.sep, "/")
|
|
||||||
all_paths.add(norm_path)
|
|
||||||
|
|
||||||
# Sort for consistent order
|
|
||||||
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
|
|
||||||
|
|
||||||
return sorted_paths
|
|
||||||
|
|
||||||
def get_model_roots(self) -> List[str]:
|
def get_model_roots(self) -> List[str]:
|
||||||
"""Get checkpoint root directories"""
|
"""Get checkpoint root directories"""
|
||||||
return self._checkpoint_roots
|
return config.base_models_roots
|
||||||
|
|
||||||
async def scan_all_models(self) -> List[Dict]:
|
async def scan_all_models(self) -> List[Dict]:
|
||||||
"""Scan all checkpoint directories and return metadata"""
|
"""Scan all checkpoint directories and return metadata"""
|
||||||
@@ -72,7 +53,7 @@ class CheckpointScanner(ModelScanner):
|
|||||||
|
|
||||||
# Create scan tasks for each directory
|
# Create scan tasks for each directory
|
||||||
scan_tasks = []
|
scan_tasks = []
|
||||||
for root in self._checkpoint_roots:
|
for root in self.get_model_roots():
|
||||||
task = asyncio.create_task(self._scan_directory(root))
|
task = asyncio.create_task(self._scan_directory(root))
|
||||||
scan_tasks.append(task)
|
scan_tasks.append(task)
|
||||||
|
|
||||||
|
|||||||
@@ -225,7 +225,7 @@ class CivitaiClient:
|
|||||||
logger.error(f"Error fetching model versions: {e}")
|
logger.error(f"Error fetching model versions: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_model_version(self, model_id: str, version_id: str = "") -> Optional[Dict]:
|
async def get_model_version(self, model_id: int, version_id: int = None) -> Optional[Dict]:
|
||||||
"""Get specific model version with additional metadata
|
"""Get specific model version with additional metadata
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -237,6 +237,8 @@ class CivitaiClient:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = await self._ensure_fresh_session()
|
session = await self._ensure_fresh_session()
|
||||||
|
|
||||||
|
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
||||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
return None
|
return None
|
||||||
@@ -244,45 +246,28 @@ class CivitaiClient:
|
|||||||
data = await response.json()
|
data = await response.json()
|
||||||
model_versions = data.get('modelVersions', [])
|
model_versions = data.get('modelVersions', [])
|
||||||
|
|
||||||
# Find matching version
|
# Step 2: Determine the version_id to use
|
||||||
matched_version = None
|
target_version_id = version_id
|
||||||
|
if target_version_id is None:
|
||||||
if version_id:
|
target_version_id = model_versions[0].get('id')
|
||||||
# If version_id provided, find exact match
|
|
||||||
for version in model_versions:
|
# Step 3: Get detailed version info using the version_id
|
||||||
if str(version.get('id')) == str(version_id):
|
headers = self._get_request_headers()
|
||||||
matched_version = version
|
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
|
||||||
break
|
if response.status != 200:
|
||||||
else:
|
|
||||||
# If no version_id then use the first version
|
|
||||||
matched_version = model_versions[0] if model_versions else None
|
|
||||||
|
|
||||||
# If no match found, return None
|
|
||||||
if not matched_version:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Build result with modified fields
|
|
||||||
result = matched_version.copy() # Copy to avoid modifying original
|
|
||||||
|
|
||||||
# Replace index with modelId
|
version = await response.json()
|
||||||
if 'index' in result:
|
|
||||||
del result['index']
|
|
||||||
result['modelId'] = model_id
|
|
||||||
|
|
||||||
# Add model field with metadata from top level
|
# Step 4: Enrich version_info with model data
|
||||||
result['model'] = {
|
# Add description and tags from model data
|
||||||
"name": data.get("name"),
|
version['model']['description'] = data.get("description")
|
||||||
"type": data.get("type"),
|
version['model']['tags'] = data.get("tags", [])
|
||||||
"nsfw": data.get("nsfw", False),
|
|
||||||
"poi": data.get("poi", False),
|
|
||||||
"description": data.get("description"),
|
|
||||||
"tags": data.get("tags", [])
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add creator field from top level
|
# Add creator from model data
|
||||||
result['creator'] = data.get("creator")
|
version['creator'] = data.get("creator")
|
||||||
|
|
||||||
return result
|
return version
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching model version: {e}")
|
logger.error(f"Error fetching model version: {e}")
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import json
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from ..utils.models import LoraMetadata, CheckpointMetadata
|
from ..utils.models import LoraMetadata, CheckpointMetadata
|
||||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
|
from .settings_manager import settings
|
||||||
|
|
||||||
# Download to temporary file first
|
# Download to temporary file first
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -48,54 +48,98 @@ class DownloadManager:
|
|||||||
"""Get the checkpoint scanner from registry"""
|
"""Get the checkpoint scanner from registry"""
|
||||||
return await ServiceRegistry.get_checkpoint_scanner()
|
return await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
|
||||||
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
|
async def download_from_civitai(self, model_id: int,
|
||||||
model_version_id: str = None, save_dir: str = None,
|
model_version_id: int, save_dir: str = None,
|
||||||
relative_path: str = '', progress_callback=None,
|
relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict:
|
||||||
model_type: str = "lora") -> Dict:
|
|
||||||
"""Download model from Civitai
|
"""Download model from Civitai
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
download_url: Direct download URL for the model
|
model_id: Civitai model ID
|
||||||
model_hash: SHA256 hash of the model
|
model_version_id: Civitai model version ID (optional, if not provided, will download the latest version)
|
||||||
model_version_id: Civitai model version ID
|
|
||||||
save_dir: Directory to save the model to
|
save_dir: Directory to save the model to
|
||||||
relative_path: Relative path within save_dir
|
relative_path: Relative path within save_dir
|
||||||
progress_callback: Callback function for progress updates
|
progress_callback: Callback function for progress updates
|
||||||
model_type: Type of model ('lora' or 'checkpoint')
|
use_default_paths: Flag to indicate whether to use default paths
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with download result
|
Dict with download result
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Update save directory with relative path if provided
|
# Check if model version already exists in library
|
||||||
if relative_path:
|
if model_version_id is not None:
|
||||||
save_dir = os.path.join(save_dir, relative_path)
|
# Case 1: model_version_id is provided, check both scanners
|
||||||
# Create directory if it doesn't exist
|
lora_scanner = await self._get_lora_scanner()
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||||
|
|
||||||
|
# Check lora scanner first
|
||||||
|
if await lora_scanner.check_model_version_exists(model_id, model_version_id):
|
||||||
|
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||||
|
|
||||||
|
# Check checkpoint scanner
|
||||||
|
if await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
|
||||||
|
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||||
|
|
||||||
# Get civitai client
|
# Get civitai client
|
||||||
civitai_client = await self._get_civitai_client()
|
civitai_client = await self._get_civitai_client()
|
||||||
|
|
||||||
# Get version info based on the provided identifier
|
# Get version info based on the provided identifier
|
||||||
version_info = None
|
version_info = await civitai_client.get_model_version(model_id, model_version_id)
|
||||||
error_msg = None
|
|
||||||
|
|
||||||
if model_hash:
|
|
||||||
# Get model by hash
|
|
||||||
version_info = await civitai_client.get_model_by_hash(model_hash)
|
|
||||||
elif model_version_id:
|
|
||||||
# Use model version ID directly
|
|
||||||
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
|
|
||||||
elif download_url:
|
|
||||||
# Extract version ID from download URL
|
|
||||||
version_id = download_url.split('/')[-1]
|
|
||||||
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
|
|
||||||
|
|
||||||
|
|
||||||
if not version_info:
|
if not version_info:
|
||||||
if error_msg and "model not found" in error_msg.lower():
|
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||||
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
|
|
||||||
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'}
|
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
|
||||||
|
if model_type_from_info == 'checkpoint':
|
||||||
|
model_type = 'checkpoint'
|
||||||
|
elif model_type_from_info in VALID_LORA_TYPES:
|
||||||
|
model_type = 'lora'
|
||||||
|
else:
|
||||||
|
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
||||||
|
|
||||||
|
# Case 2: model_version_id was None, check after getting version_info
|
||||||
|
if model_version_id is None:
|
||||||
|
version_model_id = version_info.get('modelId')
|
||||||
|
version_id = version_info.get('id')
|
||||||
|
|
||||||
|
if model_type == 'lora':
|
||||||
|
# Check lora scanner
|
||||||
|
lora_scanner = await self._get_lora_scanner()
|
||||||
|
if await lora_scanner.check_model_version_exists(version_model_id, version_id):
|
||||||
|
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||||
|
elif model_type == 'checkpoint':
|
||||||
|
# Check checkpoint scanner
|
||||||
|
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||||
|
if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id):
|
||||||
|
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||||
|
|
||||||
|
# Handle use_default_paths
|
||||||
|
if use_default_paths:
|
||||||
|
# Set save_dir based on model type
|
||||||
|
if model_type == 'checkpoint':
|
||||||
|
default_path = settings.get('default_checkpoint_root')
|
||||||
|
if not default_path:
|
||||||
|
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
|
||||||
|
save_dir = default_path
|
||||||
|
else: # model_type == 'lora'
|
||||||
|
default_path = settings.get('default_lora_root')
|
||||||
|
if not default_path:
|
||||||
|
return {'success': False, 'error': 'Default lora root path not set in settings'}
|
||||||
|
save_dir = default_path
|
||||||
|
|
||||||
|
# Set relative_path to version_info.baseModel/first_tag if available
|
||||||
|
base_model = version_info.get('baseModel', '')
|
||||||
|
model_tags = version_info.get('model', {}).get('tags', [])
|
||||||
|
if base_model:
|
||||||
|
if model_tags:
|
||||||
|
relative_path = os.path.join(base_model, model_tags[0])
|
||||||
|
else:
|
||||||
|
relative_path = base_model
|
||||||
|
|
||||||
|
# Update save directory with relative path if provided
|
||||||
|
if relative_path:
|
||||||
|
save_dir = os.path.join(save_dir, relative_path)
|
||||||
|
# Create directory if it doesn't exist
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
# Check if this is an early access model
|
# Check if this is an early access model
|
||||||
if version_info.get('earlyAccessEndsAt'):
|
if version_info.get('earlyAccessEndsAt'):
|
||||||
@@ -137,18 +181,6 @@ class DownloadManager:
|
|||||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||||
logger.info(f"Creating LoraMetadata for {file_name}")
|
logger.info(f"Creating LoraMetadata for {file_name}")
|
||||||
|
|
||||||
# 5.1 Get and update model tags, description and creator info
|
|
||||||
model_id = version_info.get('modelId')
|
|
||||||
if model_id:
|
|
||||||
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
|
|
||||||
if model_metadata:
|
|
||||||
if model_metadata.get("tags"):
|
|
||||||
metadata.tags = model_metadata.get("tags", [])
|
|
||||||
if model_metadata.get("description"):
|
|
||||||
metadata.modelDescription = model_metadata.get("description", "")
|
|
||||||
if model_metadata.get("creator"):
|
|
||||||
metadata.civitai["creator"] = model_metadata.get("creator")
|
|
||||||
|
|
||||||
# 6. Start download process
|
# 6. Start download process
|
||||||
result = await self._execute_download(
|
result = await self._execute_download(
|
||||||
download_url=file_info.get('downloadUrl', ''),
|
download_url=file_info.get('downloadUrl', ''),
|
||||||
|
|||||||
@@ -1,11 +1,7 @@
|
|||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import shutil
|
from typing import List, Dict, Optional
|
||||||
import time
|
|
||||||
import re
|
|
||||||
from typing import List, Dict, Optional, Set
|
|
||||||
|
|
||||||
from ..utils.models import LoraMetadata
|
from ..utils.models import LoraMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
@@ -14,7 +10,6 @@ from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to Mo
|
|||||||
from .settings_manager import settings
|
from .settings_manager import settings
|
||||||
from ..utils.constants import NSFW_LEVELS
|
from ..utils.constants import NSFW_LEVELS
|
||||||
from ..utils.utils import fuzzy_match
|
from ..utils.utils import fuzzy_match
|
||||||
from .service_registry import ServiceRegistry
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -1362,3 +1362,59 @@ class ModelScanner:
|
|||||||
if file_name in self._hash_index._duplicate_filenames:
|
if file_name in self._hash_index._duplicate_filenames:
|
||||||
if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
|
if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
|
||||||
del self._hash_index._duplicate_filenames[file_name]
|
del self._hash_index._duplicate_filenames[file_name]
|
||||||
|
|
||||||
|
async def check_model_version_exists(self, model_id: int, model_version_id: int) -> bool:
|
||||||
|
"""Check if a specific model version exists in the cache
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Civitai model ID
|
||||||
|
model_version_id: Civitai model version ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the model version exists, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
if not cache or not cache.raw_data:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for item in cache.raw_data:
|
||||||
|
if (item.get('civitai') and
|
||||||
|
item['civitai'].get('modelId') == model_id and
|
||||||
|
item['civitai'].get('id') == model_version_id):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking model version existence: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_model_versions_by_id(self, model_id: int) -> List[Dict]:
|
||||||
|
"""Get all versions of a model by its ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Civitai model ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: List of version information dictionaries
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
if not cache or not cache.raw_data:
|
||||||
|
return []
|
||||||
|
|
||||||
|
versions = []
|
||||||
|
for item in cache.raw_data:
|
||||||
|
if (item.get('civitai') and
|
||||||
|
item['civitai'].get('modelId') == model_id and
|
||||||
|
item['civitai'].get('id')):
|
||||||
|
versions.append({
|
||||||
|
'versionId': item['civitai'].get('id'),
|
||||||
|
'name': item['civitai'].get('name'),
|
||||||
|
'fileName': item.get('file_name', '')
|
||||||
|
})
|
||||||
|
|
||||||
|
return versions
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting model versions: {e}")
|
||||||
|
return []
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Set, Dict, Optional
|
from typing import Set, Dict, Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -10,7 +11,7 @@ class WebSocketManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._websockets: Set[web.WebSocketResponse] = set()
|
self._websockets: Set[web.WebSocketResponse] = set()
|
||||||
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
||||||
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
|
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
|
||||||
|
|
||||||
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||||
"""Handle new WebSocket connection"""
|
"""Handle new WebSocket connection"""
|
||||||
@@ -39,19 +40,35 @@ class WebSocketManager:
|
|||||||
finally:
|
finally:
|
||||||
self._init_websockets.discard(ws)
|
self._init_websockets.discard(ws)
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
async def handle_checkpoint_connection(self, request: web.Request) -> web.WebSocketResponse:
|
async def handle_download_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||||
"""Handle new WebSocket connection for checkpoint download progress"""
|
"""Handle new WebSocket connection for download progress"""
|
||||||
ws = web.WebSocketResponse()
|
ws = web.WebSocketResponse()
|
||||||
await ws.prepare(request)
|
await ws.prepare(request)
|
||||||
self._checkpoint_websockets.add(ws)
|
|
||||||
|
# Get download_id from query parameters
|
||||||
|
download_id = request.query.get('id')
|
||||||
|
|
||||||
|
if not download_id:
|
||||||
|
# Generate a new download ID if not provided
|
||||||
|
download_id = str(uuid4())
|
||||||
|
|
||||||
|
# Store the websocket with its download ID
|
||||||
|
self._download_websockets[download_id] = ws
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Send the download ID back to the client
|
||||||
|
await ws.send_json({
|
||||||
|
'type': 'download_id',
|
||||||
|
'download_id': download_id
|
||||||
|
})
|
||||||
|
|
||||||
async for msg in ws:
|
async for msg in ws:
|
||||||
if msg.type == web.WSMsgType.ERROR:
|
if msg.type == web.WSMsgType.ERROR:
|
||||||
logger.error(f'Checkpoint WebSocket error: {ws.exception()}')
|
logger.error(f'Download WebSocket error: {ws.exception()}')
|
||||||
finally:
|
finally:
|
||||||
self._checkpoint_websockets.discard(ws)
|
if download_id in self._download_websockets:
|
||||||
|
del self._download_websockets[download_id]
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
async def broadcast(self, data: Dict):
|
async def broadcast(self, data: Dict):
|
||||||
@@ -84,17 +101,18 @@ class WebSocketManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending initialization progress: {e}")
|
logger.error(f"Error sending initialization progress: {e}")
|
||||||
|
|
||||||
async def broadcast_checkpoint_progress(self, data: Dict):
|
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
||||||
"""Broadcast checkpoint download progress to connected clients"""
|
"""Send progress update to specific download client"""
|
||||||
if not self._checkpoint_websockets:
|
if download_id not in self._download_websockets:
|
||||||
|
logger.debug(f"No WebSocket found for download ID: {download_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
for ws in self._checkpoint_websockets:
|
ws = self._download_websockets[download_id]
|
||||||
try:
|
try:
|
||||||
await ws.send_json(data)
|
await ws.send_json(data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending checkpoint progress: {e}")
|
logger.error(f"Error sending download progress: {e}")
|
||||||
|
|
||||||
def get_connected_clients_count(self) -> int:
|
def get_connected_clients_count(self) -> int:
|
||||||
"""Get number of connected clients"""
|
"""Get number of connected clients"""
|
||||||
return len(self._websockets)
|
return len(self._websockets)
|
||||||
@@ -102,10 +120,14 @@ class WebSocketManager:
|
|||||||
def get_init_clients_count(self) -> int:
|
def get_init_clients_count(self) -> int:
|
||||||
"""Get number of initialization progress clients"""
|
"""Get number of initialization progress clients"""
|
||||||
return len(self._init_websockets)
|
return len(self._init_websockets)
|
||||||
|
|
||||||
def get_checkpoint_clients_count(self) -> int:
|
def get_download_clients_count(self) -> int:
|
||||||
"""Get number of checkpoint progress clients"""
|
"""Get number of download progress clients"""
|
||||||
return len(self._checkpoint_websockets)
|
return len(self._download_websockets)
|
||||||
|
|
||||||
|
def generate_download_id(self) -> str:
|
||||||
|
"""Generate a unique download ID"""
|
||||||
|
return str(uuid4())
|
||||||
|
|
||||||
# Global instance
|
# Global instance
|
||||||
ws_manager = WebSocketManager()
|
ws_manager = WebSocketManager()
|
||||||
@@ -10,7 +10,8 @@ NSFW_LEVELS = {
|
|||||||
# Node type constants
|
# Node type constants
|
||||||
NODE_TYPES = {
|
NODE_TYPES = {
|
||||||
"Lora Loader (LoraManager)": 1,
|
"Lora Loader (LoraManager)": 1,
|
||||||
"Lora Stacker (LoraManager)": 2
|
"Lora Stacker (LoraManager)": 2,
|
||||||
|
"WanVideo Lora Select (LoraManager)": 3
|
||||||
}
|
}
|
||||||
|
|
||||||
# Default ComfyUI node color when bgcolor is null
|
# Default ComfyUI node color when bgcolor is null
|
||||||
|
|||||||
@@ -8,9 +8,11 @@ from .model_utils import determine_base_model
|
|||||||
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
|
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..services.civitai_client import CivitaiClient
|
from ..services.civitai_client import CivitaiClient
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from ..services.download_manager import DownloadManager
|
from ..services.download_manager import DownloadManager
|
||||||
|
from ..services.websocket_manager import ws_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -564,13 +566,12 @@ class ModelRouteUtils:
|
|||||||
return web.Response(text=str(e), status=500)
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response:
|
async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
||||||
"""Handle model download request
|
"""Handle model download request
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The aiohttp request
|
request: The aiohttp request
|
||||||
download_manager: Instance of DownloadManager
|
download_manager: Instance of DownloadManager
|
||||||
model_type: Type of model ('lora' or 'checkpoint')
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
web.Response: The HTTP response
|
web.Response: The HTTP response
|
||||||
@@ -578,40 +579,58 @@ class ModelRouteUtils:
|
|||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
|
|
||||||
# Create progress callback
|
# Get or generate a download ID
|
||||||
|
download_id = data.get('download_id', ws_manager.generate_download_id())
|
||||||
|
|
||||||
|
# Create progress callback with download ID
|
||||||
async def progress_callback(progress):
|
async def progress_callback(progress):
|
||||||
from ..services.websocket_manager import ws_manager
|
await ws_manager.broadcast_download_progress(download_id, {
|
||||||
await ws_manager.broadcast({
|
|
||||||
'status': 'progress',
|
'status': 'progress',
|
||||||
'progress': progress
|
'progress': progress,
|
||||||
|
'download_id': download_id
|
||||||
})
|
})
|
||||||
|
|
||||||
# Check which identifier is provided
|
# Check which identifier is provided and convert to int
|
||||||
download_url = data.get('download_url')
|
try:
|
||||||
model_hash = data.get('model_hash')
|
model_id = int(data.get('model_id'))
|
||||||
model_version_id = data.get('model_version_id')
|
except (TypeError, ValueError):
|
||||||
|
return web.Response(
|
||||||
|
status=400,
|
||||||
|
text="Invalid model_id: Must be an integer"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert model_version_id to int if provided
|
||||||
|
model_version_id = None
|
||||||
|
if data.get('model_version_id'):
|
||||||
|
try:
|
||||||
|
model_version_id = int(data.get('model_version_id'))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return web.Response(
|
||||||
|
status=400,
|
||||||
|
text="Invalid model_version_id: Must be an integer"
|
||||||
|
)
|
||||||
|
|
||||||
# Validate that at least one identifier is provided
|
# Only model_id is required, model_version_id is optional
|
||||||
if not any([download_url, model_hash, model_version_id]):
|
if not model_id:
|
||||||
return web.Response(
|
return web.Response(
|
||||||
status=400,
|
status=400,
|
||||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
text="Missing required parameter: Please provide 'model_id'"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the correct root directory based on model type
|
use_default_paths = data.get('use_default_paths', False)
|
||||||
root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root'
|
|
||||||
save_dir = data.get(root_key)
|
|
||||||
|
|
||||||
result = await download_manager.download_from_civitai(
|
result = await download_manager.download_from_civitai(
|
||||||
download_url=download_url,
|
model_id=model_id,
|
||||||
model_hash=model_hash,
|
|
||||||
model_version_id=model_version_id,
|
model_version_id=model_version_id,
|
||||||
save_dir=save_dir,
|
save_dir=data.get('model_root'),
|
||||||
relative_path=data.get('relative_path', ''),
|
relative_path=data.get('relative_path', ''),
|
||||||
progress_callback=progress_callback,
|
use_default_paths=use_default_paths,
|
||||||
model_type=model_type
|
progress_callback=progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Include download_id in the response
|
||||||
|
result['download_id'] = download_id
|
||||||
|
|
||||||
if not result.get('success', False):
|
if not result.get('success', False):
|
||||||
error_message = result.get('error', 'Unknown error')
|
error_message = result.get('error', 'Unknown error')
|
||||||
|
|
||||||
@@ -638,7 +657,7 @@ class ModelRouteUtils:
|
|||||||
text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
|
text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.error(f"Error downloading {model_type}: {error_message}")
|
logger.error(f"Error downloading model: {error_message}")
|
||||||
return web.Response(status=500, text=error_message)
|
return web.Response(status=500, text=error_message)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -693,8 +712,10 @@ class ModelRouteUtils:
|
|||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
file_path = data.get('file_path')
|
file_path = data.get('file_path')
|
||||||
model_id = data.get('model_id')
|
model_id = int(data.get('model_id'))
|
||||||
model_version_id = data.get('model_version_id')
|
model_version_id = None
|
||||||
|
if data.get('model_version_id'):
|
||||||
|
model_version_id = int(data.get('model_version_id'))
|
||||||
|
|
||||||
if not file_path or not model_id:
|
if not file_path or not model_id:
|
||||||
return web.json_response({"success": False, "error": "Both file_path and model_id are required"}, status=400)
|
return web.json_response({"success": False, "error": "Both file_path and model_id are required"}, status=400)
|
||||||
|
|||||||
@@ -1,8 +1,29 @@
|
|||||||
from difflib import SequenceMatcher
|
from difflib import SequenceMatcher
|
||||||
import requests
|
import requests
|
||||||
import tempfile
|
import tempfile
|
||||||
import re
|
import os
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
from ..config import config
|
||||||
|
|
||||||
|
async def get_lora_info(lora_name):
|
||||||
|
"""Get the lora path and trigger words from cache"""
|
||||||
|
scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
|
||||||
|
for item in cache.raw_data:
|
||||||
|
if item.get('file_name') == lora_name:
|
||||||
|
file_path = item.get('file_path')
|
||||||
|
if file_path:
|
||||||
|
for root in config.loras_roots:
|
||||||
|
root = root.replace(os.sep, '/')
|
||||||
|
if file_path.startswith(root):
|
||||||
|
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||||
|
# Get trigger words from civitai metadata
|
||||||
|
civitai = item.get('civitai', {})
|
||||||
|
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||||
|
return relative_path, trigger_words
|
||||||
|
return lora_name, []
|
||||||
|
|
||||||
def download_twitter_image(url):
|
def download_twitter_image(url):
|
||||||
"""Download image from a URL containing twitter:image meta tag
|
"""Download image from a URL containing twitter:image meta tag
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui-lora-manager"
|
name = "comfyui-lora-manager"
|
||||||
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
|
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
||||||
version = "0.8.19"
|
version = "0.8.20-beta"
|
||||||
license = {file = "LICENSE"}
|
license = {file = "LICENSE"}
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
"jinja2",
|
"jinja2",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
"watchdog",
|
|
||||||
"beautifulsoup4",
|
"beautifulsoup4",
|
||||||
"piexif",
|
"piexif",
|
||||||
"Pillow",
|
"Pillow",
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
aiohttp
|
aiohttp
|
||||||
jinja2
|
jinja2
|
||||||
safetensors
|
safetensors
|
||||||
watchdog
|
|
||||||
beautifulsoup4
|
beautifulsoup4
|
||||||
piexif
|
piexif
|
||||||
Pillow
|
Pillow
|
||||||
@@ -9,6 +8,5 @@ olefile
|
|||||||
requests
|
requests
|
||||||
toml
|
toml
|
||||||
numpy
|
numpy
|
||||||
torch
|
|
||||||
natsort
|
natsort
|
||||||
msgpack
|
msgpack
|
||||||
|
|||||||
@@ -3,6 +3,26 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
# Create mock modules for py/nodes directory - add this before any other imports
|
||||||
|
def mock_nodes_directory():
|
||||||
|
"""Create mock modules for all Python files in the py/nodes directory"""
|
||||||
|
nodes_dir = os.path.join(os.path.dirname(__file__), 'py', 'nodes')
|
||||||
|
if os.path.exists(nodes_dir):
|
||||||
|
# Create a mock module for the nodes package itself
|
||||||
|
sys.modules['py.nodes'] = type('MockNodesModule', (), {})
|
||||||
|
|
||||||
|
# Create mock modules for all Python files in the nodes directory
|
||||||
|
for file in os.listdir(nodes_dir):
|
||||||
|
if file.endswith('.py') and file != '__init__.py':
|
||||||
|
module_name = file[:-3] # Remove .py extension
|
||||||
|
full_module_name = f'py.nodes.{module_name}'
|
||||||
|
# Create empty module object
|
||||||
|
sys.modules[full_module_name] = type(f'Mock{module_name.capitalize()}Module', (), {})
|
||||||
|
print(f"Created mock module for: {full_module_name}")
|
||||||
|
|
||||||
|
# Run the mocking function before any other imports
|
||||||
|
mock_nodes_directory()
|
||||||
|
|
||||||
# Create mock folder_paths module BEFORE any other imports
|
# Create mock folder_paths module BEFORE any other imports
|
||||||
class MockFolderPaths:
|
class MockFolderPaths:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -232,7 +252,7 @@ class StandaloneLoraManager(LoraManager):
|
|||||||
added_targets.add(os.path.normpath(real_root))
|
added_targets.add(os.path.normpath(real_root))
|
||||||
|
|
||||||
# Add static routes for each checkpoint root
|
# Add static routes for each checkpoint root
|
||||||
for idx, root in enumerate(config.checkpoints_roots, start=1):
|
for idx, root in enumerate(config.base_models_roots, start=1):
|
||||||
if not os.path.exists(root):
|
if not os.path.exists(root):
|
||||||
logger.warning(f"Checkpoint root path does not exist: {root}")
|
logger.warning(f"Checkpoint root path does not exist: {root}")
|
||||||
continue
|
continue
|
||||||
@@ -268,8 +288,8 @@ class StandaloneLoraManager(LoraManager):
|
|||||||
norm_target = os.path.normpath(target_path)
|
norm_target = os.path.normpath(target_path)
|
||||||
if norm_target not in added_targets:
|
if norm_target not in added_targets:
|
||||||
# Determine if this is a checkpoint or lora link based on path
|
# Determine if this is a checkpoint or lora link based on path
|
||||||
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.checkpoints_roots)
|
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.base_models_roots)
|
||||||
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.checkpoints_roots)
|
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.base_models_roots)
|
||||||
|
|
||||||
if is_checkpoint:
|
if is_checkpoint:
|
||||||
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
||||||
|
|||||||
@@ -751,6 +751,29 @@ input:checked + .toggle-slider:before {
|
|||||||
opacity: 1;
|
opacity: 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Add styles for tab with new content indicator */
|
||||||
|
.tab-btn.has-new-content {
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tab-btn.has-new-content::after {
|
||||||
|
content: "";
|
||||||
|
position: absolute;
|
||||||
|
top: 4px;
|
||||||
|
right: 4px;
|
||||||
|
width: 8px;
|
||||||
|
height: 8px;
|
||||||
|
background-color: var(--lora-accent);
|
||||||
|
border-radius: 50%;
|
||||||
|
animation: pulse 2s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse {
|
||||||
|
0% { opacity: 1; transform: scale(1); }
|
||||||
|
50% { opacity: 0.7; transform: scale(1.1); }
|
||||||
|
100% { opacity: 1; transform: scale(1); }
|
||||||
|
}
|
||||||
|
|
||||||
/* Tab content styles */
|
/* Tab content styles */
|
||||||
.help-content {
|
.help-content {
|
||||||
padding: var(--space-1) 0;
|
padding: var(--space-1) 0;
|
||||||
@@ -817,6 +840,37 @@ input:checked + .toggle-slider:before {
|
|||||||
text-decoration: underline;
|
text-decoration: underline;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* New content badge styles */
|
||||||
|
.new-content-badge {
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
font-size: 0.7em;
|
||||||
|
font-weight: 600;
|
||||||
|
background-color: var(--lora-accent);
|
||||||
|
color: var(--lora-text);
|
||||||
|
padding: 2px 6px;
|
||||||
|
border-radius: 10px;
|
||||||
|
margin-left: 8px;
|
||||||
|
vertical-align: middle;
|
||||||
|
animation: fadeIn 0.5s ease-in-out;
|
||||||
|
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 0.5px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.new-content-badge.inline {
|
||||||
|
font-size: 0.65em;
|
||||||
|
padding: 1px 4px;
|
||||||
|
margin-left: 6px;
|
||||||
|
border-radius: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Dark theme adjustments for new content badge */
|
||||||
|
[data-theme="dark"] .new-content-badge {
|
||||||
|
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.4);
|
||||||
|
}
|
||||||
|
|
||||||
/* Update video list styles */
|
/* Update video list styles */
|
||||||
.video-list {
|
.video-list {
|
||||||
display: flex;
|
display: flex;
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ export function setupBaseModelEditing(filePath) {
|
|||||||
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
|
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
|
||||||
'Video Models': [BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO],
|
'Video Models': [BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO],
|
||||||
'Other Models': [
|
'Other Models': [
|
||||||
BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW,
|
BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.AURAFLOW,
|
||||||
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
|
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
|
||||||
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
|
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
|
||||||
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM,
|
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM,
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ export function setupBaseModelEditing(filePath) {
|
|||||||
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
|
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
|
||||||
'Video Models': [BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO],
|
'Video Models': [BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO],
|
||||||
'Other Models': [
|
'Other Models': [
|
||||||
BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW,
|
BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.AURAFLOW,
|
||||||
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
|
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
|
||||||
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
|
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
|
||||||
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM,
|
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM,
|
||||||
|
|||||||
@@ -245,11 +245,6 @@ function findLocalFile(img, index, exampleFiles) {
|
|||||||
const match = file.name.match(/image_(\d+)\./);
|
const match = file.name.match(/image_(\d+)\./);
|
||||||
return match && parseInt(match[1]) === index;
|
return match && parseInt(match[1]) === index;
|
||||||
});
|
});
|
||||||
|
|
||||||
// If not found by index, just use the same position in the array if available
|
|
||||||
if (!localFile && index < exampleFiles.length) {
|
|
||||||
localFile = exampleFiles[index];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return localFile;
|
return localFile;
|
||||||
@@ -406,9 +401,6 @@ async function handleImportFiles(files, modelHash, importContainer) {
|
|||||||
const customImages = result.custom_images || [];
|
const customImages = result.custom_images || [];
|
||||||
// Combine both arrays for rendering
|
// Combine both arrays for rendering
|
||||||
const allImages = [...regularImages, ...customImages];
|
const allImages = [...regularImages, ...customImages];
|
||||||
console.log("Regular images:", regularImages);
|
|
||||||
console.log("Custom images:", customImages);
|
|
||||||
console.log("Combined images:", allImages);
|
|
||||||
showcaseTab.innerHTML = renderShowcaseContent(allImages, updatedFilesResult.files, true);
|
showcaseTab.innerHTML = renderShowcaseContent(allImages, updatedFilesResult.files, true);
|
||||||
|
|
||||||
// Re-initialize showcase functionality
|
// Re-initialize showcase functionality
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ export class CheckpointDownloadManager {
|
|||||||
this.currentVersion = null;
|
this.currentVersion = null;
|
||||||
this.versions = [];
|
this.versions = [];
|
||||||
this.modelInfo = null;
|
this.modelInfo = null;
|
||||||
|
this.modelId = null;
|
||||||
this.modelVersionId = null;
|
this.modelVersionId = null;
|
||||||
|
|
||||||
// Clear selected folder and remove selection from UI
|
// Clear selected folder and remove selection from UI
|
||||||
@@ -79,12 +80,12 @@ export class CheckpointDownloadManager {
|
|||||||
try {
|
try {
|
||||||
this.loadingManager.showSimpleLoading('Fetching model versions...');
|
this.loadingManager.showSimpleLoading('Fetching model versions...');
|
||||||
|
|
||||||
const modelId = this.extractModelId(url);
|
this.modelId = this.extractModelId(url);
|
||||||
if (!modelId) {
|
if (!this.modelId) {
|
||||||
throw new Error('Invalid Civitai URL format');
|
throw new Error('Invalid Civitai URL format');
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await fetch(`/api/checkpoints/civitai/versions/${modelId}`);
|
const response = await fetch(`/api/checkpoints/civitai/versions/${this.modelId}`);
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
const errorData = await response.json().catch(() => ({}));
|
const errorData = await response.json().catch(() => ({}));
|
||||||
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
||||||
@@ -254,7 +255,7 @@ export class CheckpointDownloadManager {
|
|||||||
).join('');
|
).join('');
|
||||||
|
|
||||||
// Set default checkpoint root if available
|
// Set default checkpoint root if available
|
||||||
const defaultRoot = getStorageItem('settings', {}).default_checkpoints_root;
|
const defaultRoot = getStorageItem('settings', {}).default_checkpoint_root;
|
||||||
if (defaultRoot && data.roots.includes(defaultRoot)) {
|
if (defaultRoot && data.roots.includes(defaultRoot)) {
|
||||||
checkpointRoot.value = defaultRoot;
|
checkpointRoot.value = defaultRoot;
|
||||||
}
|
}
|
||||||
@@ -296,22 +297,28 @@ export class CheckpointDownloadManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const downloadUrl = this.currentVersion.downloadUrl;
|
|
||||||
if (!downloadUrl) {
|
|
||||||
throw new Error('No download URL available');
|
|
||||||
}
|
|
||||||
|
|
||||||
// Show enhanced loading with progress details
|
// Show enhanced loading with progress details
|
||||||
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
||||||
updateProgress(0, 0, this.currentVersion.name);
|
updateProgress(0, 0, this.currentVersion.name);
|
||||||
|
|
||||||
// Setup WebSocket for progress updates using checkpoint-specific endpoint
|
// Generate a unique ID for this download
|
||||||
|
const downloadId = Date.now().toString();
|
||||||
|
|
||||||
|
// Setup WebSocket for progress updates using download-specific endpoint
|
||||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/checkpoint-progress`);
|
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
|
||||||
|
|
||||||
ws.onmessage = (event) => {
|
ws.onmessage = (event) => {
|
||||||
const data = JSON.parse(event.data);
|
const data = JSON.parse(event.data);
|
||||||
if (data.status === 'progress') {
|
|
||||||
|
// Handle download ID confirmation
|
||||||
|
if (data.type === 'download_id') {
|
||||||
|
console.log(`Connected to checkpoint download progress with ID: ${data.download_id}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only process progress updates for our download
|
||||||
|
if (data.status === 'progress' && data.download_id === downloadId) {
|
||||||
// Update progress display with current progress
|
// Update progress display with current progress
|
||||||
updateProgress(data.progress, 0, this.currentVersion.name);
|
updateProgress(data.progress, 0, this.currentVersion.name);
|
||||||
|
|
||||||
@@ -333,14 +340,16 @@ export class CheckpointDownloadManager {
|
|||||||
// Continue with download even if WebSocket fails
|
// Continue with download even if WebSocket fails
|
||||||
};
|
};
|
||||||
|
|
||||||
// Start download using checkpoint download endpoint
|
// Start download using checkpoint download endpoint with download ID
|
||||||
const response = await fetch('/api/checkpoints/download', {
|
const response = await fetch('/api/download-model', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
download_url: downloadUrl,
|
model_id: this.modelId,
|
||||||
checkpoint_root: checkpointRoot,
|
model_version_id: this.currentVersion.id,
|
||||||
relative_path: targetFolder
|
model_root: checkpointRoot,
|
||||||
|
relative_path: targetFolder,
|
||||||
|
download_id: downloadId
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ export class DownloadManager {
|
|||||||
this.currentVersion = null;
|
this.currentVersion = null;
|
||||||
this.versions = [];
|
this.versions = [];
|
||||||
this.modelInfo = null;
|
this.modelInfo = null;
|
||||||
|
this.modelId = null;
|
||||||
this.modelVersionId = null;
|
this.modelVersionId = null;
|
||||||
|
|
||||||
// Clear selected folder and remove selection from UI
|
// Clear selected folder and remove selection from UI
|
||||||
@@ -81,12 +82,12 @@ export class DownloadManager {
|
|||||||
try {
|
try {
|
||||||
this.loadingManager.showSimpleLoading('Fetching model versions...');
|
this.loadingManager.showSimpleLoading('Fetching model versions...');
|
||||||
|
|
||||||
const modelId = this.extractModelId(url);
|
this.modelId = this.extractModelId(url);
|
||||||
if (!modelId) {
|
if (!this.modelId) {
|
||||||
throw new Error('Invalid Civitai URL format');
|
throw new Error('Invalid Civitai URL format');
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await fetch(`/api/civitai/versions/${modelId}`);
|
const response = await fetch(`/api/civitai/versions/${this.modelId}`);
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
const errorData = await response.json().catch(() => ({}));
|
const errorData = await response.json().catch(() => ({}));
|
||||||
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
||||||
@@ -252,24 +253,39 @@ export class DownloadManager {
|
|||||||
document.getElementById('locationStep').style.display = 'block';
|
document.getElementById('locationStep').style.display = 'block';
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/api/lora-roots');
|
// Fetch LoRA roots
|
||||||
if (!response.ok) {
|
const rootsResponse = await fetch('/api/lora-roots');
|
||||||
|
if (!rootsResponse.ok) {
|
||||||
throw new Error('Failed to fetch LoRA roots');
|
throw new Error('Failed to fetch LoRA roots');
|
||||||
}
|
}
|
||||||
|
|
||||||
const data = await response.json();
|
const rootsData = await rootsResponse.json();
|
||||||
const loraRoot = document.getElementById('loraRoot');
|
const loraRoot = document.getElementById('loraRoot');
|
||||||
loraRoot.innerHTML = data.roots.map(root =>
|
loraRoot.innerHTML = rootsData.roots.map(root =>
|
||||||
`<option value="${root}">${root}</option>`
|
`<option value="${root}">${root}</option>`
|
||||||
).join('');
|
).join('');
|
||||||
|
|
||||||
// Set default lora root if available
|
// Set default lora root if available
|
||||||
const defaultRoot = getStorageItem('settings', {}).default_loras_root;
|
const defaultRoot = getStorageItem('settings', {}).default_loras_root;
|
||||||
if (defaultRoot && data.roots.includes(defaultRoot)) {
|
if (defaultRoot && rootsData.roots.includes(defaultRoot)) {
|
||||||
loraRoot.value = defaultRoot;
|
loraRoot.value = defaultRoot;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize folder browser after loading roots
|
// Fetch folders dynamically
|
||||||
|
const foldersResponse = await fetch('/api/folders');
|
||||||
|
if (!foldersResponse.ok) {
|
||||||
|
throw new Error('Failed to fetch folders');
|
||||||
|
}
|
||||||
|
|
||||||
|
const foldersData = await foldersResponse.json();
|
||||||
|
const folderBrowser = document.getElementById('folderBrowser');
|
||||||
|
|
||||||
|
// Update folder browser with dynamic content
|
||||||
|
folderBrowser.innerHTML = foldersData.folders.map(folder =>
|
||||||
|
`<div class="folder-item" data-folder="${folder}">${folder}</div>`
|
||||||
|
).join('');
|
||||||
|
|
||||||
|
// Initialize folder browser after loading roots and folders
|
||||||
this.initializeFolderBrowser();
|
this.initializeFolderBrowser();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
showToast(error.message, 'error');
|
showToast(error.message, 'error');
|
||||||
@@ -306,22 +322,28 @@ export class DownloadManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const downloadUrl = this.currentVersion.downloadUrl;
|
|
||||||
if (!downloadUrl) {
|
|
||||||
throw new Error('No download URL available');
|
|
||||||
}
|
|
||||||
|
|
||||||
// Show enhanced loading with progress details
|
// Show enhanced loading with progress details
|
||||||
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
||||||
updateProgress(0, 0, this.currentVersion.name);
|
updateProgress(0, 0, this.currentVersion.name);
|
||||||
|
|
||||||
// Setup WebSocket for progress updates
|
// Generate a unique ID for this download
|
||||||
|
const downloadId = Date.now().toString();
|
||||||
|
|
||||||
|
// Setup WebSocket for progress updates - use download-specific endpoint
|
||||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
|
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
|
||||||
|
|
||||||
ws.onmessage = (event) => {
|
ws.onmessage = (event) => {
|
||||||
const data = JSON.parse(event.data);
|
const data = JSON.parse(event.data);
|
||||||
if (data.status === 'progress') {
|
|
||||||
|
// Handle download ID confirmation
|
||||||
|
if (data.type === 'download_id') {
|
||||||
|
console.log(`Connected to download progress with ID: ${data.download_id}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only process progress updates for our download
|
||||||
|
if (data.status === 'progress' && data.download_id === downloadId) {
|
||||||
// Update progress display with current progress
|
// Update progress display with current progress
|
||||||
updateProgress(data.progress, 0, this.currentVersion.name);
|
updateProgress(data.progress, 0, this.currentVersion.name);
|
||||||
|
|
||||||
@@ -343,14 +365,16 @@ export class DownloadManager {
|
|||||||
// Continue with download even if WebSocket fails
|
// Continue with download even if WebSocket fails
|
||||||
};
|
};
|
||||||
|
|
||||||
// Start download
|
// Start download with our download ID
|
||||||
const response = await fetch('/api/download-lora', {
|
const response = await fetch('/api/download-model', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
download_url: downloadUrl,
|
model_id: this.modelId,
|
||||||
lora_root: loraRoot,
|
model_version_id: this.currentVersion.id,
|
||||||
relative_path: targetFolder
|
model_root: loraRoot,
|
||||||
|
relative_path: targetFolder,
|
||||||
|
download_id: downloadId
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -361,6 +385,9 @@ export class DownloadManager {
|
|||||||
showToast('Download completed successfully', 'success');
|
showToast('Download completed successfully', 'success');
|
||||||
modalManager.closeModal('downloadModal');
|
modalManager.closeModal('downloadModal');
|
||||||
|
|
||||||
|
// Close WebSocket after download completes
|
||||||
|
ws.close();
|
||||||
|
|
||||||
// Update state and trigger reload with folder update
|
// Update state and trigger reload with folder update
|
||||||
state.activeFolder = targetFolder;
|
state.activeFolder = targetFolder;
|
||||||
await resetAndReload(true); // Pass true to update folders
|
await resetAndReload(true); // Pass true to update folders
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import { getStorageItem, setStorageItem } from '../utils/storageHelpers.js';
|
|||||||
export class HelpManager {
|
export class HelpManager {
|
||||||
constructor() {
|
constructor() {
|
||||||
this.lastViewedTimestamp = getStorageItem('help_last_viewed', 0);
|
this.lastViewedTimestamp = getStorageItem('help_last_viewed', 0);
|
||||||
this.latestContentTimestamp = 0; // Will be updated from server or config
|
this.latestContentTimestamp = new Date('2025-07-09').getTime(); // Will be updated from server or config
|
||||||
this.isInitialized = false;
|
this.isInitialized = false;
|
||||||
|
|
||||||
// Default latest content data - could be fetched from server
|
// Default latest content data - could be fetched from server
|
||||||
@@ -81,6 +81,9 @@ export class HelpManager {
|
|||||||
if (window.modalManager) {
|
if (window.modalManager) {
|
||||||
window.modalManager.toggleModal('helpModal');
|
window.modalManager.toggleModal('helpModal');
|
||||||
|
|
||||||
|
// Add visual indicator to Documentation tab if there's new content
|
||||||
|
this.updateDocumentationTabIndicator();
|
||||||
|
|
||||||
// Update the last viewed timestamp
|
// Update the last viewed timestamp
|
||||||
this.markContentAsViewed();
|
this.markContentAsViewed();
|
||||||
|
|
||||||
@@ -88,6 +91,16 @@ export class HelpManager {
|
|||||||
this.hideHelpBadge();
|
this.hideHelpBadge();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add visual indicator to Documentation tab for new content
|
||||||
|
*/
|
||||||
|
updateDocumentationTabIndicator() {
|
||||||
|
const docTab = document.querySelector('.tab-btn[data-tab="documentation"]');
|
||||||
|
if (docTab && this.hasNewContent()) {
|
||||||
|
docTab.classList.add('has-new-content');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mark content as viewed by saving current timestamp
|
* Mark content as viewed by saving current timestamp
|
||||||
@@ -105,7 +118,7 @@ export class HelpManager {
|
|||||||
// For now, we'll just use the hardcoded data from constructor
|
// For now, we'll just use the hardcoded data from constructor
|
||||||
|
|
||||||
// Update the timestamp with the latest data
|
// Update the timestamp with the latest data
|
||||||
this.latestContentTimestamp = this.latestVideoData.timestamp;
|
this.latestContentTimestamp = Math.max(this.latestContentTimestamp, this.latestVideoData.timestamp);
|
||||||
|
|
||||||
// Check again if we need to show the badge with this new data
|
// Check again if we need to show the badge with this new data
|
||||||
this.updateHelpBadge();
|
this.updateHelpBadge();
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import { showToast } from '../utils/uiHelpers.js';
|
|||||||
import { state, getCurrentPageState } from '../state/index.js';
|
import { state, getCurrentPageState } from '../state/index.js';
|
||||||
import { modalManager } from './ModalManager.js';
|
import { modalManager } from './ModalManager.js';
|
||||||
import { getStorageItem } from '../utils/storageHelpers.js';
|
import { getStorageItem } from '../utils/storageHelpers.js';
|
||||||
|
import { updateFolderTags } from '../api/baseModelApi.js';
|
||||||
|
|
||||||
class MoveManager {
|
class MoveManager {
|
||||||
constructor() {
|
constructor() {
|
||||||
@@ -72,32 +73,46 @@ class MoveManager {
|
|||||||
this.newFolderInput.value = '';
|
this.newFolderInput.value = '';
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/api/lora-roots');
|
// Fetch LoRA roots
|
||||||
if (!response.ok) {
|
const rootsResponse = await fetch('/api/lora-roots');
|
||||||
|
if (!rootsResponse.ok) {
|
||||||
throw new Error('Failed to fetch LoRA roots');
|
throw new Error('Failed to fetch LoRA roots');
|
||||||
}
|
}
|
||||||
|
|
||||||
const data = await response.json();
|
const rootsData = await rootsResponse.json();
|
||||||
if (!data.roots || data.roots.length === 0) {
|
if (!rootsData.roots || rootsData.roots.length === 0) {
|
||||||
throw new Error('No LoRA roots found');
|
throw new Error('No LoRA roots found');
|
||||||
}
|
}
|
||||||
|
|
||||||
// 填充LoRA根目录选择器
|
// 填充LoRA根目录选择器
|
||||||
this.loraRootSelect.innerHTML = data.roots.map(root =>
|
this.loraRootSelect.innerHTML = rootsData.roots.map(root =>
|
||||||
`<option value="${root}">${root}</option>`
|
`<option value="${root}">${root}</option>`
|
||||||
).join('');
|
).join('');
|
||||||
|
|
||||||
// Set default lora root if available
|
// Set default lora root if available
|
||||||
const defaultRoot = getStorageItem('settings', {}).default_loras_root;
|
const defaultRoot = getStorageItem('settings', {}).default_loras_root;
|
||||||
if (defaultRoot && data.roots.includes(defaultRoot)) {
|
if (defaultRoot && rootsData.roots.includes(defaultRoot)) {
|
||||||
this.loraRootSelect.value = defaultRoot;
|
this.loraRootSelect.value = defaultRoot;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fetch folders dynamically
|
||||||
|
const foldersResponse = await fetch('/api/folders');
|
||||||
|
if (!foldersResponse.ok) {
|
||||||
|
throw new Error('Failed to fetch folders');
|
||||||
|
}
|
||||||
|
|
||||||
|
const foldersData = await foldersResponse.json();
|
||||||
|
|
||||||
|
// Update folder browser with dynamic content
|
||||||
|
this.folderBrowser.innerHTML = foldersData.folders.map(folder =>
|
||||||
|
`<div class="folder-item" data-folder="${folder}">${folder}</div>`
|
||||||
|
).join('');
|
||||||
|
|
||||||
this.updatePathPreview();
|
this.updatePathPreview();
|
||||||
modalManager.showModal('moveModal');
|
modalManager.showModal('moveModal');
|
||||||
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error fetching LoRA roots:', error);
|
console.error('Error fetching LoRA roots or folders:', error);
|
||||||
showToast(error.message, 'error');
|
showToast(error.message, 'error');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -173,6 +188,17 @@ class MoveManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Refresh folder tags after successful move
|
||||||
|
try {
|
||||||
|
const foldersResponse = await fetch('/api/folders');
|
||||||
|
if (foldersResponse.ok) {
|
||||||
|
const foldersData = await foldersResponse.json();
|
||||||
|
updateFolderTags(foldersData.folders);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error refreshing folder tags:', error);
|
||||||
|
}
|
||||||
|
|
||||||
modalManager.closeModal('moveModal');
|
modalManager.closeModal('moveModal');
|
||||||
|
|
||||||
// If we were in bulk mode, exit it after successful move
|
// If we were in bulk mode, exit it after successful move
|
||||||
|
|||||||
@@ -42,6 +42,11 @@ export class SettingsManager {
|
|||||||
state.global.settings.cardInfoDisplay = 'always';
|
state.global.settings.cardInfoDisplay = 'always';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set default for defaultCheckpointRoot if undefined
|
||||||
|
if (state.global.settings.default_checkpoint_root === undefined) {
|
||||||
|
state.global.settings.default_checkpoint_root = '';
|
||||||
|
}
|
||||||
|
|
||||||
// Convert old boolean compactMode to new displayDensity string
|
// Convert old boolean compactMode to new displayDensity string
|
||||||
if (typeof state.global.settings.displayDensity === 'undefined') {
|
if (typeof state.global.settings.displayDensity === 'undefined') {
|
||||||
if (state.global.settings.compactMode === true) {
|
if (state.global.settings.compactMode === true) {
|
||||||
@@ -123,6 +128,9 @@ export class SettingsManager {
|
|||||||
// Load default lora root
|
// Load default lora root
|
||||||
await this.loadLoraRoots();
|
await this.loadLoraRoots();
|
||||||
|
|
||||||
|
// Load default checkpoint root
|
||||||
|
await this.loadCheckpointRoots();
|
||||||
|
|
||||||
// Backend settings are loaded from the template directly
|
// Backend settings are loaded from the template directly
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,6 +173,45 @@ export class SettingsManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async loadCheckpointRoots() {
|
||||||
|
try {
|
||||||
|
const defaultCheckpointRootSelect = document.getElementById('defaultCheckpointRoot');
|
||||||
|
if (!defaultCheckpointRootSelect) return;
|
||||||
|
|
||||||
|
// Fetch checkpoint roots
|
||||||
|
const response = await fetch('/api/checkpoints/roots');
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error('Failed to fetch checkpoint roots');
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
if (!data.roots || data.roots.length === 0) {
|
||||||
|
throw new Error('No checkpoint roots found');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear existing options except the first one (No Default)
|
||||||
|
const noDefaultOption = defaultCheckpointRootSelect.querySelector('option[value=""]');
|
||||||
|
defaultCheckpointRootSelect.innerHTML = '';
|
||||||
|
defaultCheckpointRootSelect.appendChild(noDefaultOption);
|
||||||
|
|
||||||
|
// Add options for each root
|
||||||
|
data.roots.forEach(root => {
|
||||||
|
const option = document.createElement('option');
|
||||||
|
option.value = root;
|
||||||
|
option.textContent = root;
|
||||||
|
defaultCheckpointRootSelect.appendChild(option);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Set selected value from settings
|
||||||
|
const defaultRoot = state.global.settings.default_checkpoint_root || '';
|
||||||
|
defaultCheckpointRootSelect.value = defaultRoot;
|
||||||
|
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error loading checkpoint roots:', error);
|
||||||
|
showToast('Failed to load checkpoint roots: ' + error.message, 'error');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
toggleSettings() {
|
toggleSettings() {
|
||||||
if (this.isOpen) {
|
if (this.isOpen) {
|
||||||
modalManager.closeModal('settingsModal');
|
modalManager.closeModal('settingsModal');
|
||||||
@@ -251,6 +298,8 @@ export class SettingsManager {
|
|||||||
// Update frontend state
|
// Update frontend state
|
||||||
if (settingKey === 'default_lora_root') {
|
if (settingKey === 'default_lora_root') {
|
||||||
state.global.settings.default_loras_root = value;
|
state.global.settings.default_loras_root = value;
|
||||||
|
} else if (settingKey === 'default_checkpoint_root') {
|
||||||
|
state.global.settings.default_checkpoint_root = value;
|
||||||
} else if (settingKey === 'display_density') {
|
} else if (settingKey === 'display_density') {
|
||||||
state.global.settings.displayDensity = value;
|
state.global.settings.displayDensity = value;
|
||||||
|
|
||||||
@@ -268,7 +317,7 @@ export class SettingsManager {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
// For backend settings, make API call
|
// For backend settings, make API call
|
||||||
if (settingKey === 'default_lora_root') {
|
if (settingKey === 'default_lora_root' || settingKey === 'default_checkpoint_root') {
|
||||||
const payload = {};
|
const payload = {};
|
||||||
payload[settingKey] = value;
|
payload[settingKey] = value;
|
||||||
|
|
||||||
@@ -414,6 +463,7 @@ export class SettingsManager {
|
|||||||
const blurMatureContent = document.getElementById('blurMatureContent').checked;
|
const blurMatureContent = document.getElementById('blurMatureContent').checked;
|
||||||
const showOnlySFW = document.getElementById('showOnlySFW').checked;
|
const showOnlySFW = document.getElementById('showOnlySFW').checked;
|
||||||
const defaultLoraRoot = document.getElementById('defaultLoraRoot').value;
|
const defaultLoraRoot = document.getElementById('defaultLoraRoot').value;
|
||||||
|
const defaultCheckpointRoot = document.getElementById('defaultCheckpointRoot').value;
|
||||||
const autoplayOnHover = document.getElementById('autoplayOnHover').checked;
|
const autoplayOnHover = document.getElementById('autoplayOnHover').checked;
|
||||||
const optimizeExampleImages = document.getElementById('optimizeExampleImages').checked;
|
const optimizeExampleImages = document.getElementById('optimizeExampleImages').checked;
|
||||||
|
|
||||||
@@ -424,6 +474,7 @@ export class SettingsManager {
|
|||||||
state.global.settings.blurMatureContent = blurMatureContent;
|
state.global.settings.blurMatureContent = blurMatureContent;
|
||||||
state.global.settings.show_only_sfw = showOnlySFW;
|
state.global.settings.show_only_sfw = showOnlySFW;
|
||||||
state.global.settings.default_loras_root = defaultLoraRoot;
|
state.global.settings.default_loras_root = defaultLoraRoot;
|
||||||
|
state.global.settings.default_checkpoint_root = defaultCheckpointRoot;
|
||||||
state.global.settings.autoplayOnHover = autoplayOnHover;
|
state.global.settings.autoplayOnHover = autoplayOnHover;
|
||||||
state.global.settings.optimizeExampleImages = optimizeExampleImages;
|
state.global.settings.optimizeExampleImages = optimizeExampleImages;
|
||||||
|
|
||||||
@@ -440,7 +491,8 @@ export class SettingsManager {
|
|||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
civitai_api_key: apiKey,
|
civitai_api_key: apiKey,
|
||||||
show_only_sfw: showOnlySFW,
|
show_only_sfw: showOnlySFW,
|
||||||
optimize_example_images: optimizeExampleImages
|
optimize_example_images: optimizeExampleImages,
|
||||||
|
default_checkpoint_root: defaultCheckpointRoot
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -128,9 +128,12 @@ export class DownloadManager {
|
|||||||
targetPath += '/' + newFolder;
|
targetPath += '/' + newFolder;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate a unique ID for this batch download
|
||||||
|
const batchDownloadId = Date.now().toString();
|
||||||
|
|
||||||
// Set up WebSocket for progress updates
|
// Set up WebSocket for progress updates
|
||||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
|
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${batchDownloadId}`);
|
||||||
|
|
||||||
// Show enhanced loading with progress details for multiple items
|
// Show enhanced loading with progress details for multiple items
|
||||||
const updateProgress = this.importManager.loadingManager.showDownloadProgress(
|
const updateProgress = this.importManager.loadingManager.showDownloadProgress(
|
||||||
@@ -145,7 +148,15 @@ export class DownloadManager {
|
|||||||
// Set up progress tracking for current download
|
// Set up progress tracking for current download
|
||||||
ws.onmessage = (event) => {
|
ws.onmessage = (event) => {
|
||||||
const data = JSON.parse(event.data);
|
const data = JSON.parse(event.data);
|
||||||
if (data.status === 'progress') {
|
|
||||||
|
// Handle download ID confirmation
|
||||||
|
if (data.type === 'download_id') {
|
||||||
|
console.log(`Connected to batch download progress with ID: ${data.download_id}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process progress updates for our current active download
|
||||||
|
if (data.status === 'progress' && data.download_id && data.download_id.startsWith(batchDownloadId)) {
|
||||||
// Update current LoRA progress
|
// Update current LoRA progress
|
||||||
currentLoraProgress = data.progress;
|
currentLoraProgress = data.progress;
|
||||||
|
|
||||||
@@ -188,16 +199,16 @@ export class DownloadManager {
|
|||||||
updateProgress(0, completedDownloads, lora.name);
|
updateProgress(0, completedDownloads, lora.name);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Download the LoRA
|
// Download the LoRA with download ID
|
||||||
const response = await fetch('/api/download-lora', {
|
const response = await fetch('/api/download-model', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
download_url: lora.downloadUrl,
|
model_id: lora.modelId,
|
||||||
model_version_id: lora.modelVersionId,
|
model_version_id: lora.id,
|
||||||
model_hash: lora.hash,
|
model_root: loraRoot,
|
||||||
lora_root: loraRoot,
|
relative_path: targetPath.replace(loraRoot + '/', ''),
|
||||||
relative_path: targetPath.replace(loraRoot + '/', '')
|
download_id: batchDownloadId
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ export const BASE_MODELS = {
|
|||||||
// Other models
|
// Other models
|
||||||
FLUX_1_D: "Flux.1 D",
|
FLUX_1_D: "Flux.1 D",
|
||||||
FLUX_1_S: "Flux.1 S",
|
FLUX_1_S: "Flux.1 S",
|
||||||
|
FLUX_1_KONTEXT: "Flux.1 Kontext",
|
||||||
AURAFLOW: "AuraFlow",
|
AURAFLOW: "AuraFlow",
|
||||||
PIXART_A: "PixArt a",
|
PIXART_A: "PixArt a",
|
||||||
PIXART_E: "PixArt E",
|
PIXART_E: "PixArt E",
|
||||||
@@ -78,6 +79,7 @@ export const BASE_MODEL_CLASSES = {
|
|||||||
// Other models
|
// Other models
|
||||||
[BASE_MODELS.FLUX_1_D]: "flux-d",
|
[BASE_MODELS.FLUX_1_D]: "flux-d",
|
||||||
[BASE_MODELS.FLUX_1_S]: "flux-s",
|
[BASE_MODELS.FLUX_1_S]: "flux-s",
|
||||||
|
[BASE_MODELS.FLUX_1_KONTEXT]: "flux-kontext",
|
||||||
[BASE_MODELS.AURAFLOW]: "auraflow",
|
[BASE_MODELS.AURAFLOW]: "auraflow",
|
||||||
[BASE_MODELS.PIXART_A]: "pixart-a",
|
[BASE_MODELS.PIXART_A]: "pixart-a",
|
||||||
[BASE_MODELS.PIXART_E]: "pixart-e",
|
[BASE_MODELS.PIXART_E]: "pixart-e",
|
||||||
@@ -106,19 +108,22 @@ export const NSFW_LEVELS = {
|
|||||||
// Node type constants
|
// Node type constants
|
||||||
export const NODE_TYPES = {
|
export const NODE_TYPES = {
|
||||||
LORA_LOADER: 1,
|
LORA_LOADER: 1,
|
||||||
LORA_STACKER: 2
|
LORA_STACKER: 2,
|
||||||
|
WAN_VIDEO_LORA_SELECT: 3
|
||||||
};
|
};
|
||||||
|
|
||||||
// Node type names to IDs mapping
|
// Node type names to IDs mapping
|
||||||
export const NODE_TYPE_NAMES = {
|
export const NODE_TYPE_NAMES = {
|
||||||
"Lora Loader (LoraManager)": NODE_TYPES.LORA_LOADER,
|
"Lora Loader (LoraManager)": NODE_TYPES.LORA_LOADER,
|
||||||
"Lora Stacker (LoraManager)": NODE_TYPES.LORA_STACKER
|
"Lora Stacker (LoraManager)": NODE_TYPES.LORA_STACKER,
|
||||||
|
"WanVideo Lora Select (LoraManager)": NODE_TYPES.WAN_VIDEO_LORA_SELECT
|
||||||
};
|
};
|
||||||
|
|
||||||
// Node type icons
|
// Node type icons
|
||||||
export const NODE_TYPE_ICONS = {
|
export const NODE_TYPE_ICONS = {
|
||||||
[NODE_TYPES.LORA_LOADER]: "fas fa-l",
|
[NODE_TYPES.LORA_LOADER]: "fas fa-l",
|
||||||
[NODE_TYPES.LORA_STACKER]: "fas fa-s"
|
[NODE_TYPES.LORA_STACKER]: "fas fa-s",
|
||||||
|
[NODE_TYPES.WAN_VIDEO_LORA_SELECT]: "fas fa-w"
|
||||||
};
|
};
|
||||||
|
|
||||||
// Default ComfyUI node color when bgcolor is null
|
// Default ComfyUI node color when bgcolor is null
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ export const apiRoutes = {
|
|||||||
delete: (id) => `/api/loras/${id}`,
|
delete: (id) => `/api/loras/${id}`,
|
||||||
update: (id) => `/api/loras/${id}`,
|
update: (id) => `/api/loras/${id}`,
|
||||||
civitai: (id) => `/api/loras/${id}/civitai`,
|
civitai: (id) => `/api/loras/${id}/civitai`,
|
||||||
download: '/api/download-lora',
|
download: '/api/download-model',
|
||||||
move: '/api/move-lora',
|
move: '/api/move-lora',
|
||||||
scan: '/api/scan-loras'
|
scan: '/api/scan-loras'
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -337,7 +337,7 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
|
|||||||
// Success case - check node count
|
// Success case - check node count
|
||||||
if (registryData.data.node_count === 0) {
|
if (registryData.data.node_count === 0) {
|
||||||
// No nodes found - show warning
|
// No nodes found - show warning
|
||||||
showToast('No Lora Loader or Lora Stacker nodes found in workflow', 'warning');
|
showToast('No supported target nodes found in workflow', 'warning');
|
||||||
return false;
|
return false;
|
||||||
} else if (registryData.data.node_count > 1) {
|
} else if (registryData.data.node_count > 1) {
|
||||||
// Multiple nodes - show selector
|
// Multiple nodes - show selector
|
||||||
|
|||||||
14
static/vendor/chart.js/chart.umd.js
vendored
Normal file
14
static/vendor/chart.js/chart.umd.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -48,13 +48,7 @@
|
|||||||
<div class="input-group">
|
<div class="input-group">
|
||||||
<label>Target Folder:</label>
|
<label>Target Folder:</label>
|
||||||
<div class="folder-browser" id="folderBrowser">
|
<div class="folder-browser" id="folderBrowser">
|
||||||
{% for folder in folders %}
|
<!-- Folders will be loaded dynamically -->
|
||||||
{% if folder %}
|
|
||||||
<div class="folder-item" data-folder="{{ folder }}">
|
|
||||||
{{ folder }}
|
|
||||||
</div>
|
|
||||||
{% endif %}
|
|
||||||
{% endfor %}
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="input-group">
|
<div class="input-group">
|
||||||
@@ -92,13 +86,7 @@
|
|||||||
<div class="input-group">
|
<div class="input-group">
|
||||||
<label>Target Folder:</label>
|
<label>Target Folder:</label>
|
||||||
<div class="folder-browser" id="moveFolderBrowser">
|
<div class="folder-browser" id="moveFolderBrowser">
|
||||||
{% for folder in folders %}
|
<!-- Folders will be loaded dynamically -->
|
||||||
{% if folder %}
|
|
||||||
<div class="folder-item" data-folder="{{ folder }}">
|
|
||||||
{{ folder }}
|
|
||||||
</div>
|
|
||||||
{% endif %}
|
|
||||||
{% endfor %}
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="input-group">
|
<div class="input-group">
|
||||||
|
|||||||
@@ -197,6 +197,23 @@
|
|||||||
Set the default LoRA root directory for downloads, imports and moves
|
Set the default LoRA root directory for downloads, imports and moves
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div class="setting-item">
|
||||||
|
<div class="setting-row">
|
||||||
|
<div class="setting-info">
|
||||||
|
<label for="defaultCheckpointRoot">Default Checkpoint Root</label>
|
||||||
|
</div>
|
||||||
|
<div class="setting-control select-control">
|
||||||
|
<select id="defaultCheckpointRoot" onchange="settingsManager.saveSelectSetting('defaultCheckpointRoot', 'default_checkpoint_root')">
|
||||||
|
<option value="">No Default</option>
|
||||||
|
<!-- Options will be loaded dynamically -->
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="input-help">
|
||||||
|
Set the default checkpoint root directory for downloads, imports and moves
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Add Layout Settings Section -->
|
<!-- Add Layout Settings Section -->
|
||||||
@@ -558,7 +575,7 @@
|
|||||||
<div class="docs-section">
|
<div class="docs-section">
|
||||||
<h4><i class="fas fa-book-open"></i> Recipes</h4>
|
<h4><i class="fas fa-book-open"></i> Recipes</h4>
|
||||||
<ul class="docs-links">
|
<ul class="docs-links">
|
||||||
<li><a href="https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/%F0%9F%93%96-Recipes-Feature-Tutorial-%E2%80%93-ComfyUI-LoRA-Manager" target="_blank">Recipes Tutorial</a></li>
|
<li><a href="https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/Recipes-Feature-Tutorial-%E2%80%93-ComfyUI-LoRA-Manager" target="_blank">Recipes Tutorial</a></li>
|
||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -568,6 +585,21 @@
|
|||||||
<li><a href="https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/Configuration" target="_blank">Configuration Options (WIP)</a></li>
|
<li><a href="https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/Configuration" target="_blank">Configuration Options (WIP)</a></li>
|
||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div class="docs-section">
|
||||||
|
<h4>
|
||||||
|
<i class="fas fa-puzzle-piece"></i> Extensions
|
||||||
|
<span class="new-content-badge">NEW</span>
|
||||||
|
</h4>
|
||||||
|
<ul class="docs-links">
|
||||||
|
<li>
|
||||||
|
<a href="https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/LoRA-Manager-Civitai-Extension-(Chrome-Extension)" target="_blank">
|
||||||
|
LM Civitai Extension
|
||||||
|
<span class="new-content-badge inline">NEW</span>
|
||||||
|
</a>
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
|
|
||||||
{% block head_scripts %}
|
{% block head_scripts %}
|
||||||
<!-- Add Chart.js for statistics page -->
|
<!-- Add Chart.js for statistics page -->
|
||||||
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.0/dist/chart.umd.js"></script>
|
<script src="/loras_static/vendor/chart.js/chart.umd.js"></script>
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
|
|
||||||
{% block init_title %}Initializing Statistics{% endblock %}
|
{% block init_title %}Initializing Statistics{% endblock %}
|
||||||
|
|||||||
@@ -76,7 +76,9 @@ app.registerExtension({
|
|||||||
|
|
||||||
// Standard mode - update a specific node
|
// Standard mode - update a specific node
|
||||||
const node = app.graph.getNodeById(+id);
|
const node = app.graph.getNodeById(+id);
|
||||||
if (!node || (node.comfyClass !== "Lora Loader (LoraManager)" && node.comfyClass !== "Lora Stacker (LoraManager)")) {
|
if (!node || (node.comfyClass !== "Lora Loader (LoraManager)" &&
|
||||||
|
node.comfyClass !== "Lora Stacker (LoraManager)" &&
|
||||||
|
node.comfyClass !== "WanVideo Lora Select (LoraManager)")) {
|
||||||
console.warn("Node not found or not a LoraLoader:", id);
|
console.warn("Node not found or not a LoraLoader:", id);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -87,7 +89,7 @@ app.registerExtension({
|
|||||||
// Helper method to update a single node's lora code
|
// Helper method to update a single node's lora code
|
||||||
updateNodeLoraCode(node, loraCode, mode) {
|
updateNodeLoraCode(node, loraCode, mode) {
|
||||||
// Update the input widget with new lora code
|
// Update the input widget with new lora code
|
||||||
const inputWidget = node.widgets[0];
|
const inputWidget = node.inputWidget;
|
||||||
if (!inputWidget) return;
|
if (!inputWidget) return;
|
||||||
|
|
||||||
// Get the current lora code
|
// Get the current lora code
|
||||||
@@ -182,6 +184,7 @@ app.registerExtension({
|
|||||||
|
|
||||||
// Update input widget callback
|
// Update input widget callback
|
||||||
const inputWidget = this.widgets[0];
|
const inputWidget = this.widgets[0];
|
||||||
|
this.inputWidget = inputWidget;
|
||||||
inputWidget.callback = (value) => {
|
inputWidget.callback = (value) => {
|
||||||
if (isUpdating) return;
|
if (isUpdating) return;
|
||||||
isUpdating = true;
|
isUpdating = true;
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ app.registerExtension({
|
|||||||
|
|
||||||
// Update input widget callback
|
// Update input widget callback
|
||||||
const inputWidget = this.widgets[0];
|
const inputWidget = this.widgets[0];
|
||||||
|
this.inputWidget = inputWidget;
|
||||||
inputWidget.callback = (value) => {
|
inputWidget.callback = (value) => {
|
||||||
if (isUpdating) return;
|
if (isUpdating) return;
|
||||||
isUpdating = true;
|
isUpdating = true;
|
||||||
|
|||||||
@@ -52,7 +52,9 @@ app.registerExtension({
|
|||||||
// Find all Lora nodes
|
// Find all Lora nodes
|
||||||
const loraNodes = [];
|
const loraNodes = [];
|
||||||
for (const node of workflow.nodes.values()) {
|
for (const node of workflow.nodes.values()) {
|
||||||
if (node.type === "Lora Loader (LoraManager)" || node.type === "Lora Stacker (LoraManager)") {
|
if (node.type === "Lora Loader (LoraManager)" ||
|
||||||
|
node.type === "Lora Stacker (LoraManager)" ||
|
||||||
|
node.type === "WanVideo Lora Select (LoraManager)") {
|
||||||
loraNodes.push({
|
loraNodes.push({
|
||||||
node_id: node.id,
|
node_id: node.id,
|
||||||
bgcolor: node.bgcolor || null,
|
bgcolor: node.bgcolor || null,
|
||||||
|
|||||||
131
web/comfyui/wanvideo_lora_select.js
Normal file
131
web/comfyui/wanvideo_lora_select.js
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
import { app } from "../../scripts/app.js";
|
||||||
|
import {
|
||||||
|
LORA_PATTERN,
|
||||||
|
getActiveLorasFromNode,
|
||||||
|
collectActiveLorasFromChain,
|
||||||
|
updateConnectedTriggerWords,
|
||||||
|
chainCallback
|
||||||
|
} from "./utils.js";
|
||||||
|
import { addLorasWidget } from "./loras_widget.js";
|
||||||
|
|
||||||
|
function mergeLoras(lorasText, lorasArr) {
|
||||||
|
const result = [];
|
||||||
|
let match;
|
||||||
|
|
||||||
|
// Reset pattern index before using
|
||||||
|
LORA_PATTERN.lastIndex = 0;
|
||||||
|
|
||||||
|
// Parse text input and create initial entries
|
||||||
|
while ((match = LORA_PATTERN.exec(lorasText)) !== null) {
|
||||||
|
const name = match[1];
|
||||||
|
const modelStrength = Number(match[2]);
|
||||||
|
// Extract clip strength if provided, otherwise use model strength
|
||||||
|
const clipStrength = match[3] ? Number(match[3]) : modelStrength;
|
||||||
|
|
||||||
|
// Find if this lora exists in the array data
|
||||||
|
const existingLora = lorasArr.find(l => l.name === name);
|
||||||
|
|
||||||
|
result.push({
|
||||||
|
name: name,
|
||||||
|
// Use existing strength if available, otherwise use input strength
|
||||||
|
strength: existingLora ? existingLora.strength : modelStrength,
|
||||||
|
active: existingLora ? existingLora.active : true,
|
||||||
|
clipStrength: existingLora ? existingLora.clipStrength : clipStrength,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
app.registerExtension({
|
||||||
|
name: "LoraManager.WanVideoLoraSelect",
|
||||||
|
|
||||||
|
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||||
|
if (nodeType.comfyClass === "WanVideo Lora Select (LoraManager)") {
|
||||||
|
chainCallback(nodeType.prototype, "onNodeCreated", async function() {
|
||||||
|
// Enable widget serialization
|
||||||
|
this.serialize_widgets = true;
|
||||||
|
|
||||||
|
// Add optional inputs
|
||||||
|
this.addInput("prev_lora", 'WANVIDLORA', {
|
||||||
|
"shape": 7 // 7 is the shape of the optional input
|
||||||
|
});
|
||||||
|
|
||||||
|
this.addInput("blocks", 'SELECTEDBLOCKS', {
|
||||||
|
"shape": 7 // 7 is the shape of the optional input
|
||||||
|
});
|
||||||
|
|
||||||
|
// Restore saved value if exists
|
||||||
|
let existingLoras = [];
|
||||||
|
if (this.widgets_values && this.widgets_values.length > 0) {
|
||||||
|
// 0 for low_mem_load, 1 for text widget, 2 for loras widget
|
||||||
|
const savedValue = this.widgets_values[2];
|
||||||
|
existingLoras = savedValue || [];
|
||||||
|
}
|
||||||
|
// Merge the loras data
|
||||||
|
const mergedLoras = mergeLoras(this.widgets[1].value, existingLoras);
|
||||||
|
|
||||||
|
// Add flag to prevent callback loops
|
||||||
|
let isUpdating = false;
|
||||||
|
|
||||||
|
const result = addLorasWidget(this, "loras", {
|
||||||
|
defaultVal: mergedLoras // Pass object directly
|
||||||
|
}, (value) => {
|
||||||
|
// Prevent recursive calls
|
||||||
|
if (isUpdating) return;
|
||||||
|
isUpdating = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Remove loras that are not in the value array
|
||||||
|
const inputWidget = this.widgets[1];
|
||||||
|
const currentLoras = value.map(l => l.name);
|
||||||
|
|
||||||
|
// Use the constant pattern here as well
|
||||||
|
let newText = inputWidget.value.replace(LORA_PATTERN, (match, name, strength) => {
|
||||||
|
return currentLoras.includes(name) ? match : '';
|
||||||
|
});
|
||||||
|
|
||||||
|
// Clean up multiple spaces and trim
|
||||||
|
newText = newText.replace(/\s+/g, ' ').trim();
|
||||||
|
|
||||||
|
inputWidget.value = newText;
|
||||||
|
|
||||||
|
// Update this node's direct trigger toggles with its own active loras
|
||||||
|
const activeLoraNames = new Set();
|
||||||
|
value.forEach(lora => {
|
||||||
|
if (lora.active) {
|
||||||
|
activeLoraNames.add(lora.name);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
updateConnectedTriggerWords(this, activeLoraNames);
|
||||||
|
} finally {
|
||||||
|
isUpdating = false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
this.lorasWidget = result.widget;
|
||||||
|
|
||||||
|
// Update input widget callback
|
||||||
|
const inputWidget = this.widgets[1];
|
||||||
|
this.inputWidget = inputWidget;
|
||||||
|
inputWidget.callback = (value) => {
|
||||||
|
if (isUpdating) return;
|
||||||
|
isUpdating = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const currentLoras = this.lorasWidget.value || [];
|
||||||
|
const mergedLoras = mergeLoras(value, currentLoras);
|
||||||
|
|
||||||
|
this.lorasWidget.value = mergedLoras;
|
||||||
|
|
||||||
|
// Update this node's direct trigger toggles with its own active loras
|
||||||
|
const activeLoraNames = getActiveLorasFromNode(this);
|
||||||
|
updateConnectedTriggerWords(this, activeLoraNames);
|
||||||
|
} finally {
|
||||||
|
isUpdating = false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
BIN
wiki-images/civitai-model-page.png
Normal file
BIN
wiki-images/civitai-model-page.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.1 MiB |
BIN
wiki-images/civitai-models-page.png
Normal file
BIN
wiki-images/civitai-models-page.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.0 MiB |
Reference in New Issue
Block a user