Compare commits

..

35 Commits

Author SHA1 Message Date
Will Miao
6f8e09fcde chore: Update version to 0.8.20-beta in pyproject.toml 2025-07-10 18:48:56 +08:00
Will Miao
f54d480f03 refactor: Update section title and improve alignment in README for Browser Extension 2025-07-10 18:43:12 +08:00
Will Miao
e68b213fb3 feat: Add LM Civitai Extension details to README and update release notes for v0.8.20 2025-07-10 18:37:22 +08:00
Will Miao
132334d500 feat: Add new content indicators for Documentation tab and update links in modals 2025-07-10 17:39:59 +08:00
Will Miao
a6f04c6d7e refactor: Remove unused imports and dependencies from utils, recipe_routes, requirements, and pyproject files. See #278 2025-07-10 16:36:28 +08:00
Will Miao
854e8bf356 feat: Adjust CivitaiClient.get_model_version logic to handle API changes — querying by model ID no longer includes image generation metadata. Fixes #279 2025-07-10 15:29:34 +08:00
Will Miao
6ff883d2d3 fix: Update diffusers version requirement to >=0.33.1 in requirements.txt. See #278 2025-07-10 10:55:13 +08:00
Will Miao
849b97afba feat: Add CR_ApplyControlNetStack extractor and enhance prompt conditioning handling in metadata processing. Fixes #277 2025-07-10 09:26:53 +08:00
Will Miao
1bd2635864 feat: Add smZ_CLIPTextEncode extractor to NODE_EXTRACTORS. See #277 2025-07-09 22:56:56 +08:00
Will Miao
79ab0f7b6c refactor: Update folder loading to fetch dynamically from API in DownloadManager and MoveManager. Fixes #274 2025-07-09 20:29:49 +08:00
Will Miao
79011bd257 refactor: Update model_id and model_version_id types to integers and add validation in routes 2025-07-09 14:21:49 +08:00
Will Miao
c692713ffb refactor: Simplify model version existence checks and enhance version retrieval methods in scanners 2025-07-09 10:26:03 +08:00
pixelpaws
df9b554ce1 Merge pull request #267 from younyokel/patch-2
Update requirements.txt
2025-07-08 21:24:49 +08:00
Will Miao
277a8e4682 Add wiki images 2025-07-08 10:05:43 +08:00
Will Miao
acb52dba09 refactor: Remove redundant local file fallback and debug logs in showcase file handling 2025-07-07 16:34:19 +08:00
Will Miao
8f10765254 feat: Add health check route to MiscRoutes for server status monitoring 2025-07-06 21:40:47 +08:00
Will Miao
0653f59473 feat: Enhance relative path handling in download manager to include base model 2025-07-03 10:28:52 +08:00
Will Miao
7a4b5a4667 feat: Implement download progress WebSocket and enhance download manager with unique IDs 2025-07-02 23:48:35 +08:00
Will Miao
49c4a4068b feat: Add default checkpoint root setting with dynamic options in settings modal 2025-07-02 21:46:21 +08:00
Will Miao
40ad590046 refactor: Update checkpoint handling to use base_models_roots and streamline path management 2025-07-02 21:29:41 +08:00
Will Miao
30374ae3e6 feat: Add ServiceRegistry import to routes_common.py for improved service management 2025-07-02 19:24:04 +08:00
Will Miao
ab22d16bad feat: Rename download endpoint from /api/download-lora to /api/download-model and update related logic 2025-07-02 19:21:25 +08:00
Will Miao
971cd56a4a feat: Update WebSocket endpoint for checkpoint progress and adjust related routes 2025-07-02 18:38:02 +08:00
Will Miao
d7cb546c5f refactor: Simplify model download handling by consolidating download logic and updating parameter usage 2025-07-02 18:25:42 +08:00
Will Miao
9d8b7344cd feat: Enhance Civitai image metadata parser to prevent duplicate LoRAs 2025-07-02 16:50:19 +08:00
Will Miao
2d4f6ae7ce feat: Add route to check if a model exists in the library 2025-07-02 14:45:19 +08:00
Edward Johan
d9126807b0 Update requirements.txt 2025-07-01 00:13:29 +05:00
Will Miao
cad5fb3fba feat: Add mock module creation for py/nodes directory to prevent loading modules from the nodes directory 2025-06-30 20:19:37 +08:00
Will Miao
afe23ad6b7 fix: Update project description for clarity and engagement 2025-06-30 15:21:50 +08:00
Will Miao
fc4327087b Add WanVideo Lora Select node and related functionality. Fixes #266
- Implemented the WanVideo Lora Select node in Python with input handling for low memory loading and LORA syntax processing.
- Updated the JavaScript side to register the new node and manage its widget interactions.
- Enhanced constants files to include the new node type and its corresponding ID.
- Modified existing Lora Loader and Stacker references to accommodate the new node in various workflows and UI components.
- Added example workflow JSON for the new node to demonstrate its usage.
2025-06-30 15:10:34 +08:00
Will Miao
71762d788f Add Lora Loader node support for Nunchaku SVDQuant FLUX model architecture with template workflow. Fixes #255 2025-06-29 23:57:50 +08:00
Will Miao
6472e00fb0 fix: Update EXTRANETS_REGEX to allow for hyphens in hypernet identifiers. Fixes #264 2025-06-29 16:48:02 +08:00
pixelpaws
4043846767 Merge pull request #261 from Rauks/add-flux-kontext
feat: Add "Flux.1 Kontext" base model
2025-06-28 21:10:51 +08:00
Karl Woditsch
d3b2bc962c feat: Add "Flux.1 Kontext" base model 2025-06-28 15:01:26 +02:00
Will Miao
54f7b64821 Replace Chart.js CDN link with local path for statistics page. Fixes #260 2025-06-28 20:53:00 +08:00
55 changed files with 1330 additions and 580 deletions

View File

@@ -18,10 +18,28 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
[![One-Click LoRA Integration Tutorial](https://img.youtube.com/vi/hvKw31YpE-U/0.jpg)](https://youtu.be/hvKw31YpE-U) [![One-Click LoRA Integration Tutorial](https://img.youtube.com/vi/hvKw31YpE-U/0.jpg)](https://youtu.be/hvKw31YpE-U)
## 🌐 Browser Extension
Enhance your Civitai browsing experience with our companion browser extension! See which models you already have, download new ones with a single click, and manage your downloads efficiently.
![LM Civitai Extension Preview](https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/wiki-images/civitai-models-page.png)
<div>
<a href="https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb" style="display: inline-block; background-color: #4285F4; color: white; padding: 8px 16px; text-decoration: none; border-radius: 4px; font-weight: bold; margin: 10px 0;">
<img src="https://www.google.com/chrome/static/images/chrome-logo.svg" width="20" style="vertical-align: middle; margin-right: 8px;"> Get Extension from Chrome Web Store
</a>
</div>
📚 [Learn More: Complete Tutorial](https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/LoRA-Manager-Civitai-Extension-(Chrome-Extension))
--- ---
## Release Notes ## Release Notes
### v0.8.20
* **LM Civitai Extension** - Released [browser extension through Chrome Web Store](https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb) that works seamlessly with LoRA Manager to enhance Civitai browsing experience, showing which models are already in your local library, enabling one-click downloads, and providing queue and parallel download support
* **Enhanced Lora Loader** - Added support for nunchaku, improving convenience when working with ComfyUI-nunchaku workflows, plus new template workflows for quick onboarding
* **WanVideo Integration** - Introduced WanVideo Lora Select (LoraManager) node compatible with ComfyUI-WanVideoWrapper for streamlined lora usage in video workflows, including a template workflow to help you get started quickly
### v0.8.19 ### v0.8.19
* **Analytics Dashboard** - Added new Statistics page providing comprehensive visual analysis of model collection and usage patterns for better library insights * **Analytics Dashboard** - Added new Statistics page providing comprehensive visual analysis of model collection and usage patterns for better library insights
* **Target Node Selection** - Enhanced workflow integration with intelligent target choosing when sending LoRAs/recipes to workflows with multiple loader/stacker nodes; a visual selector now appears showing node color, type, ID, and title for precise targeting * **Target Node Selection** - Enhanced workflow integration with intelligent target choosing when sending LoRAs/recipes to workflows with multiple loader/stacker nodes; a visual selector now appears showing node color, type, ID, and title for precise targeting

View File

@@ -4,6 +4,7 @@ from .py.nodes.trigger_word_toggle import TriggerWordToggle
from .py.nodes.lora_stacker import LoraStacker from .py.nodes.lora_stacker import LoraStacker
from .py.nodes.save_image import SaveImage from .py.nodes.save_image import SaveImage
from .py.nodes.debug_metadata import DebugMetadata from .py.nodes.debug_metadata import DebugMetadata
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect
# Import metadata collector to install hooks on startup # Import metadata collector to install hooks on startup
from .py.metadata_collector import init as init_metadata_collector from .py.metadata_collector import init as init_metadata_collector
@@ -12,7 +13,8 @@ NODE_CLASS_MAPPINGS = {
TriggerWordToggle.NAME: TriggerWordToggle, TriggerWordToggle.NAME: TriggerWordToggle,
LoraStacker.NAME: LoraStacker, LoraStacker.NAME: LoraStacker,
SaveImage.NAME: SaveImage, SaveImage.NAME: SaveImage,
DebugMetadata.NAME: DebugMetadata DebugMetadata.NAME: DebugMetadata,
WanVideoLoraSelect.NAME: WanVideoLoraSelect
} }
WEB_DIRECTORY = "./web/comfyui" WEB_DIRECTORY = "./web/comfyui"

Binary file not shown.

After

Width:  |  Height:  |  Size: 68 KiB

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -22,7 +22,9 @@ class Config:
# 静态路由映射字典, target to route mapping # 静态路由映射字典, target to route mapping
self._route_mappings = {} self._route_mappings = {}
self.loras_roots = self._init_lora_paths() self.loras_roots = self._init_lora_paths()
self.checkpoints_roots = self._init_checkpoint_paths() self.checkpoints_roots = None
self.unet_roots = None
self.base_models_roots = self._init_checkpoint_paths()
# 在初始化时扫描符号链接 # 在初始化时扫描符号链接
self._scan_symbolic_links() self._scan_symbolic_links()
@@ -33,34 +35,26 @@ class Config:
def save_folder_paths_to_settings(self): def save_folder_paths_to_settings(self):
"""Save folder paths to settings.json for standalone mode to use later""" """Save folder paths to settings.json for standalone mode to use later"""
try: try:
# Check if we're running in ComfyUI mode (not standalone) # Check if we're running in ComfyUI mode (not standalone)
if hasattr(folder_paths, "get_folder_paths") and not isinstance(folder_paths, type): # Load existing settings
# Get all relevant paths settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
lora_paths = folder_paths.get_folder_paths("loras") settings = {}
checkpoint_paths = folder_paths.get_folder_paths("checkpoints") if os.path.exists(settings_path):
diffuser_paths = folder_paths.get_folder_paths("diffusers") with open(settings_path, 'r', encoding='utf-8') as f:
unet_paths = folder_paths.get_folder_paths("unet") settings = json.load(f)
# Load existing settings # Update settings with paths
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json') settings['folder_paths'] = {
settings = {} 'loras': self.loras_roots,
if os.path.exists(settings_path): 'checkpoints': self.checkpoints_roots,
with open(settings_path, 'r', encoding='utf-8') as f: 'unet': self.unet_roots,
settings = json.load(f) }
# Update settings with paths # Save settings
settings['folder_paths'] = { with open(settings_path, 'w', encoding='utf-8') as f:
'loras': lora_paths, json.dump(settings, f, indent=2)
'checkpoints': checkpoint_paths,
'diffusers': diffuser_paths, logger.info("Saved folder paths to settings.json")
'unet': unet_paths
}
# Save settings
with open(settings_path, 'w', encoding='utf-8') as f:
json.dump(settings, f, indent=2)
logger.info("Saved folder paths to settings.json")
except Exception as e: except Exception as e:
logger.warning(f"Failed to save folder paths: {e}") logger.warning(f"Failed to save folder paths: {e}")
@@ -86,7 +80,7 @@ class Config:
for root in self.loras_roots: for root in self.loras_roots:
self._scan_directory_links(root) self._scan_directory_links(root)
for root in self.checkpoints_roots: for root in self.base_models_roots:
self._scan_directory_links(root) self._scan_directory_links(root)
def _scan_directory_links(self, root: str): def _scan_directory_links(self, root: str):
@@ -178,30 +172,36 @@ class Config:
try: try:
# Get checkpoint paths from folder_paths # Get checkpoint paths from folder_paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints") checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusers")
unet_paths = folder_paths.get_folder_paths("unet") unet_paths = folder_paths.get_folder_paths("unet")
# Combine all checkpoint-related paths # Sort each list individually
all_paths = checkpoint_paths + diffusion_paths + unet_paths checkpoint_paths = sorted(set(path.replace(os.sep, "/")
for path in checkpoint_paths
# Filter and normalize paths
paths = sorted(set(path.replace(os.sep, "/")
for path in all_paths
if os.path.exists(path)), key=lambda p: p.lower()) if os.path.exists(path)), key=lambda p: p.lower())
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(paths) if paths else "[]")) unet_paths = sorted(set(path.replace(os.sep, "/")
for path in unet_paths
if os.path.exists(path)), key=lambda p: p.lower())
if not paths: # Combine all checkpoint-related paths, ensuring checkpoint_paths are first
all_paths = checkpoint_paths + unet_paths
self.checkpoints_roots = checkpoint_paths
self.unet_roots = unet_paths
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]"))
if not all_paths:
logger.warning("No valid checkpoint folders found in ComfyUI configuration") logger.warning("No valid checkpoint folders found in ComfyUI configuration")
return [] return []
# 初始化路径映射,与 LoRA 路径处理方式相同 # Initialize path mappings, similar to LoRA path handling
for path in paths: for path in all_paths:
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
if real_path != path: if real_path != path:
self.add_path_mapping(path, real_path) self.add_path_mapping(path, real_path)
return paths return all_paths
except Exception as e: except Exception as e:
logger.warning(f"Error initializing checkpoint paths: {e}") logger.warning(f"Error initializing checkpoint paths: {e}")
return [] return []

View File

@@ -62,7 +62,7 @@ class LoraManager:
added_targets.add(real_root) added_targets.add(real_root)
# Add static routes for each checkpoint root # Add static routes for each checkpoint root
for idx, root in enumerate(config.checkpoints_roots, start=1): for idx, root in enumerate(config.base_models_roots, start=1):
preview_path = f'/checkpoints_static/root{idx}/preview' preview_path = f'/checkpoints_static/root{idx}/preview'
real_root = root real_root = root
@@ -88,8 +88,8 @@ class LoraManager:
for target_path, link_path in config._path_mappings.items(): for target_path, link_path in config._path_mappings.items():
if target_path not in added_targets: if target_path not in added_targets:
# Determine if this is a checkpoint or lora link based on path # Determine if this is a checkpoint or lora link based on path
is_checkpoint = any(cp_root in link_path for cp_root in config.checkpoints_roots) is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots)
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.checkpoints_roots) is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots)
if is_checkpoint: if is_checkpoint:
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview' route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'

View File

@@ -238,25 +238,45 @@ class MetadataProcessor:
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning") pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning") neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
# Try to match conditioning objects with those stored by CLIPTextEncodeExtractor # Helper function to recursively find prompt text for a conditioning object
for prompt_node_id, prompt_data in metadata[PROMPTS].items(): def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True):
# For nodes with single conditioning output if conditioning_obj is None:
if "conditioning" in prompt_data: return ""
if pos_conditioning is not None and id(prompt_data["conditioning"]) == id(pos_conditioning):
result["prompt"] = prompt_data.get("text", "")
if neg_conditioning is not None and id(prompt_data["conditioning"]) == id(neg_conditioning): # Try to match conditioning objects with those stored by extractors
result["negative_prompt"] = prompt_data.get("text", "") for prompt_node_id, prompt_data in metadata[PROMPTS].items():
# For nodes with single conditioning output
if "conditioning" in prompt_data:
if id(prompt_data["conditioning"]) == id(conditioning_obj):
return prompt_data.get("text", "")
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
if is_positive and "positive_encoded" in prompt_data:
if id(prompt_data["positive_encoded"]) == id(conditioning_obj):
if "positive_text" in prompt_data:
return prompt_data["positive_text"]
else:
orig_conditioning = prompt_data.get("orig_pos_cond", None)
if orig_conditioning is not None:
# Recursively find the prompt text for the original conditioning
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
if not is_positive and "negative_encoded" in prompt_data:
if id(prompt_data["negative_encoded"]) == id(conditioning_obj):
if "negative_text" in prompt_data:
return prompt_data["negative_text"]
else:
orig_conditioning = prompt_data.get("orig_neg_cond", None)
if orig_conditioning is not None:
# Recursively find the prompt text for the original conditioning
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False)
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader) return ""
if "positive_encoded" in prompt_data:
if pos_conditioning is not None and id(prompt_data["positive_encoded"]) == id(pos_conditioning): # Find prompt texts using the helper function
result["prompt"] = prompt_data.get("positive_text", "") result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)
if "negative_encoded" in prompt_data:
if neg_conditioning is not None and id(prompt_data["negative_encoded"]) == id(neg_conditioning):
result["negative_prompt"] = prompt_data.get("negative_text", "")
return result return result

View File

@@ -569,6 +569,40 @@ class CFGGuiderExtractor(NodeMetadataExtractor):
metadata[SAMPLING][node_id]["parameters"]["cfg"] = cfg_value metadata[SAMPLING][node_id]["parameters"]["cfg"] = cfg_value
class CR_ApplyControlNetStackExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
# Save the original conditioning inputs
base_positive = inputs.get("base_positive")
base_negative = inputs.get("base_negative")
if base_positive is not None or base_negative is not None:
if node_id not in metadata[PROMPTS]:
metadata[PROMPTS][node_id] = {"node_id": node_id}
metadata[PROMPTS][node_id]["orig_pos_cond"] = base_positive
metadata[PROMPTS][node_id]["orig_neg_cond"] = base_negative
@staticmethod
def update(node_id, outputs, metadata):
# Extract transformed conditionings from outputs
# outputs structure: [(base_positive, base_negative, show_help, )]
if outputs and isinstance(outputs, list) and len(outputs) > 0:
first_output = outputs[0]
if isinstance(first_output, tuple) and len(first_output) >= 2:
transformed_positive = first_output[0]
transformed_negative = first_output[1]
# Save transformed conditioning objects in metadata
if node_id not in metadata[PROMPTS]:
metadata[PROMPTS][node_id] = {"node_id": node_id}
metadata[PROMPTS][node_id]["positive_encoded"] = transformed_positive
metadata[PROMPTS][node_id]["negative_encoded"] = transformed_negative
# Registry of node-specific extractors # Registry of node-specific extractors
# Keys are node class names # Keys are node class names
NODE_EXTRACTORS = { NODE_EXTRACTORS = {
@@ -594,6 +628,8 @@ NODE_EXTRACTORS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux "CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux
"WAS_Text_to_Conditioning": CLIPTextEncodeExtractor, "WAS_Text_to_Conditioning": CLIPTextEncodeExtractor,
"AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb "AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
# Latent # Latent
"EmptyLatentImage": ImageSizeExtractor, "EmptyLatentImage": ImageSizeExtractor,
# Flux # Flux

View File

@@ -2,14 +2,15 @@ import logging
from nodes import LoraLoader from nodes import LoraLoader
from comfy.comfy_types import IO # type: ignore from comfy.comfy_types import IO # type: ignore
import asyncio import asyncio
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list from ..utils.utils import get_lora_info
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LoraManagerLoader: class LoraManagerLoader:
NAME = "Lora Loader (LoraManager)" NAME = "Lora Loader (LoraManager)"
CATEGORY = "Lora Manager/loaders" CATEGORY = "Lora Manager/loaders"
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
@@ -37,19 +38,39 @@ class LoraManagerLoader:
clip = kwargs.get('clip', None) clip = kwargs.get('clip', None)
lora_stack = kwargs.get('lora_stack', None) lora_stack = kwargs.get('lora_stack', None)
# Check if model is a Nunchaku Flux model - simplified approach
is_nunchaku_model = False
try:
model_wrapper = model.model.diffusion_model
# Check if model is a Nunchaku Flux model using only class name
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
is_nunchaku_model = True
logger.info("Detected Nunchaku Flux model")
except (AttributeError, TypeError):
# Not a model with the expected structure
pass
# First process lora_stack if available # First process lora_stack if available
if lora_stack: if lora_stack:
for lora_path, model_strength, clip_strength in lora_stack: for lora_path, model_strength, clip_strength in lora_stack:
# Apply the LoRA using the provided path and strengths # Apply the LoRA using the appropriate loader
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength) if is_nunchaku_model:
# Use our custom function for Flux models
model = nunchaku_load_lora(model, lora_path, model_strength)
# clip remains unchanged for Nunchaku models
else:
# Use default loader for standard models
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Extract lora name for trigger words lookup # Extract lora name for trigger words lookup
lora_name = extract_lora_name(lora_path) lora_name = extract_lora_name(lora_path)
_, trigger_words = asyncio.run(get_lora_info(lora_name)) _, trigger_words = asyncio.run(get_lora_info(lora_name))
all_trigger_words.extend(trigger_words) all_trigger_words.extend(trigger_words)
# Add clip strength to output if different from model strength # Add clip strength to output if different from model strength (except for Nunchaku models)
if abs(model_strength - clip_strength) > 0.001: if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}") loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else: else:
loaded_loras.append(f"{lora_name}: {model_strength}") loaded_loras.append(f"{lora_name}: {model_strength}")
@@ -68,11 +89,17 @@ class LoraManagerLoader:
# Get lora path and trigger words # Get lora path and trigger words
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name)) lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
# Apply the LoRA using the resolved path with separate strengths # Apply the LoRA using the appropriate loader
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength) if is_nunchaku_model:
# For Nunchaku models, use our custom function
model = nunchaku_load_lora(model, lora_path, model_strength)
# clip remains unchanged
else:
# Use default loader for standard models
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Include clip strength in output if different from model strength # Include clip strength in output if different from model strength and not a Nunchaku model
if abs(model_strength - clip_strength) > 0.001: if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}") loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else: else:
loaded_loras.append(f"{lora_name}: {model_strength}") loaded_loras.append(f"{lora_name}: {model_strength}")

View File

@@ -1,9 +1,9 @@
from comfy.comfy_types import IO # type: ignore from comfy.comfy_types import IO # type: ignore
from ..services.lora_scanner import LoraScanner
from ..config import config
import asyncio import asyncio
import os import os
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list from ..utils.utils import get_lora_info
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -35,31 +35,11 @@ any_type = AnyType("*")
# Common methods extracted from lora_loader.py and lora_stacker.py # Common methods extracted from lora_loader.py and lora_stacker.py
import os import os
import logging import logging
import asyncio import copy
from ..services.lora_scanner import LoraScanner import folder_paths
from ..config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def get_lora_info(lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(lora_path): def extract_lora_name(lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension # Get the basename without extension
@@ -81,4 +61,70 @@ def get_loras_list(kwargs):
# Unexpected format # Unexpected format
else: else:
logger.warning(f"Unexpected loras format: {type(loras_data)}") logger.warning(f"Unexpected loras format: {type(loras_data)}")
return [] return []
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
import safetensors.torch
state_dict = {}
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
for k in f.keys():
if filter_prefix and not k.startswith(filter_prefix):
continue
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
return state_dict
def to_diffusers(input_lora):
"""Simplified version of to_diffusers for Flux LoRA conversion"""
import torch
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
from diffusers.loaders import FluxLoraLoaderMixin
if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else:
tensors = {k: v for k, v in input_lora.items()}
# Convert FP8 tensors to BF16
for k, v in tensors.items():
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
tensors[k] = v.to(torch.bfloat16)
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
return new_tensors
def nunchaku_load_lora(model, lora_name, lora_strength):
"""Load a Flux LoRA for Nunchaku model"""
model_wrapper = model.model.diffusion_model
transformer = model_wrapper.model
# Save the transformer temporarily
model_wrapper.model = None
ret_model = copy.deepcopy(model) # copy everything except the model
ret_model_wrapper = ret_model.model.diffusion_model
# Restore the model and set it for the copy
model_wrapper.model = transformer
ret_model_wrapper.model = transformer
# Get full path to the LoRA file
lora_path = folder_paths.get_full_path("loras", lora_name)
ret_model_wrapper.loras.append((lora_path, lora_strength))
# Convert the LoRA to diffusers format
sd = to_diffusers(lora_path)
# Handle embedding adjustment if needed
if "transformer.x_embedder.lora_A.weight" in sd:
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
assert new_in_channels % 4 == 0
new_in_channels = new_in_channels // 4
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
if old_in_channels < new_in_channels:
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
return ret_model

View File

@@ -0,0 +1,93 @@
from comfy.comfy_types import IO # type: ignore
import asyncio
import folder_paths # type: ignore
from ..utils.utils import get_lora_info
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
import logging
logger = logging.getLogger(__name__)
class WanVideoLoraSelect:
    """ComfyUI node that assembles a WanVideo-format LoRA stack.

    Collects LoRA entries from the flexible widget inputs (and an optional
    upstream ``prev_lora`` stack), resolves each LoRA's on-disk path and
    trigger words, and emits:
      - the accumulated WANVIDLORA list,
      - the concatenated trigger words,
      - a ``<lora:name:strength>`` summary of the active LoRAs.
    """

    NAME = "WanVideo Lora Select (LoraManager)"
    CATEGORY = "Lora Manager/stackers"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load the LORA model with less VRAM usage, slower loading"}),
                "text": (IO.STRING, {
                    "multiline": True,
                    "dynamicPrompts": True,
                    "tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
                    "placeholder": "LoRA syntax input: <lora:name:strength>"
                }),
            },
            "optional": FlexibleOptionalInputType(any_type),
        }

    RETURN_TYPES = ("WANVIDLORA", IO.STRING, IO.STRING)
    RETURN_NAMES = ("lora", "trigger_words", "active_loras")
    FUNCTION = "process_loras"

    def process_loras(self, text, low_mem_load=False, **kwargs):
        """Build the WanVideo LoRA stack from widget inputs.

        Args:
            text: LoRA syntax string from the text widget.
                NOTE(review): this parameter is never read in the body —
                only widget entries from **kwargs are processed; confirm
                whether text parsing is intended here.
            low_mem_load: forwarded into each LoRA item so WanVideo loads
                it with reduced VRAM usage.
            **kwargs: flexible inputs; may contain 'prev_lora' (upstream
                stack to extend), 'blocks' (block-selection dict), and the
                LoRA widget entries consumed by get_loras_list().

        Returns:
            tuple: (loras_list, trigger_words_text, active_loras_text)
        """
        loras_list = []
        all_trigger_words = []
        active_loras = []

        # Extend an upstream stack if one was chained in.
        prev_lora = kwargs.get('prev_lora', None)
        if prev_lora is not None:
            loras_list.extend(prev_lora)

        # Block selection / layer filtering applies uniformly to every LoRA
        # added by this node.
        blocks = kwargs.get('blocks', {})
        selected_blocks = blocks.get("selected_blocks", {})
        layer_filter = blocks.get("layer_filter", "")

        # Process loras from kwargs with support for both old and new formats.
        loras_from_widget = get_loras_list(kwargs)

        for lora in loras_from_widget:
            if not lora.get('active', False):
                continue

            lora_name = lora['name']
            model_strength = float(lora['strength'])
            # clipStrength is optional; default to the model strength.
            clip_strength = float(lora.get('clipStrength', model_strength))

            # Resolve relative path and trigger words from the manager's index.
            lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))

            # Build the WanVideo-format entry. Strip only the final extension
            # (rsplit) so names containing dots are preserved intact.
            lora_item = {
                "path": folder_paths.get_full_path("loras", lora_path),
                "strength": model_strength,
                "name": lora_path.rsplit(".", 1)[0],
                "blocks": selected_blocks,
                "layer_filter": layer_filter,
                "low_mem_load": low_mem_load,
            }

            loras_list.append(lora_item)
            active_loras.append((lora_name, model_strength, clip_strength))
            all_trigger_words.extend(trigger_words)

        # Trigger-word groups are joined with ',, ' (manager convention).
        trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""

        # Render the active-LoRA summary in <lora:...> syntax.
        formatted_loras = []
        for name, model_strength, clip_strength in active_loras:
            if abs(model_strength - clip_strength) > 0.001:
                # Distinct model and clip strengths: emit both.
                formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
            else:
                # Same strength for both: emit the short form.
                formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")

        active_loras_text = " ".join(formatted_loras)

        return (loras_list, trigger_words_text, active_loras_text)

View File

@@ -19,7 +19,7 @@ class AutomaticMetadataParser(RecipeMetadataParser):
LORA_HASHES_REGEX = r', Lora hashes:\s*"([^"]+)"' LORA_HASHES_REGEX = r', Lora hashes:\s*"([^"]+)"'
CIVITAI_RESOURCES_REGEX = r', Civitai resources:\s*(\[\{.*?\}\])' CIVITAI_RESOURCES_REGEX = r', Civitai resources:\s*(\[\{.*?\}\])'
CIVITAI_METADATA_REGEX = r', Civitai metadata:\s*(\{.*?\})' CIVITAI_METADATA_REGEX = r', Civitai metadata:\s*(\{.*?\})'
EXTRANETS_REGEX = r'<(lora|hypernet):([a-zA-Z0-9_\.\-]+):([0-9.]+)>' EXTRANETS_REGEX = r'<(lora|hypernet):([^:]+):(-?[0-9.]+)>'
MODEL_HASH_PATTERN = r'Model hash: ([a-zA-Z0-9]+)' MODEL_HASH_PATTERN = r'Model hash: ([a-zA-Z0-9]+)'
VAE_HASH_PATTERN = r'VAE hash: ([a-zA-Z0-9]+)' VAE_HASH_PATTERN = r'VAE hash: ([a-zA-Z0-9]+)'

View File

@@ -50,6 +50,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
'from_civitai_image': True 'from_civitai_image': True
} }
# Track already added LoRAs to prevent duplicates
added_loras = {} # key: model_version_id or hash, value: index in result["loras"]
# Extract prompt and negative prompt # Extract prompt and negative prompt
if "prompt" in metadata: if "prompt" in metadata:
result["gen_params"]["prompt"] = metadata["prompt"] result["gen_params"]["prompt"] = metadata["prompt"]
@@ -96,11 +99,17 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
for resource in metadata["resources"]: for resource in metadata["resources"]:
# Modified to process resources without a type field as potential LoRAs # Modified to process resources without a type field as potential LoRAs
if resource.get("type", "lora") == "lora": if resource.get("type", "lora") == "lora":
lora_hash = resource.get("hash", "")
# Skip if we've already added this LoRA by hash
if lora_hash and lora_hash in added_loras:
continue
lora_entry = { lora_entry = {
'name': resource.get("name", "Unknown LoRA"), 'name': resource.get("name", "Unknown LoRA"),
'type': "lora", 'type': "lora",
'weight': float(resource.get("weight", 1.0)), 'weight': float(resource.get("weight", 1.0)),
'hash': resource.get("hash", ""), 'hash': lora_hash,
'existsLocally': False, 'existsLocally': False,
'localPath': None, 'localPath': None,
'file_name': resource.get("name", "Unknown"), 'file_name': resource.get("name", "Unknown"),
@@ -114,7 +123,6 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Try to get info from Civitai if hash is available # Try to get info from Civitai if hash is available
if lora_entry['hash'] and civitai_client: if lora_entry['hash'] and civitai_client:
try: try:
lora_hash = lora_entry['hash']
civitai_info = await civitai_client.get_model_by_hash(lora_hash) civitai_info = await civitai_client.get_model_by_hash(lora_hash)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
@@ -129,43 +137,124 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
continue # Skip invalid LoRA types continue # Skip invalid LoRA types
lora_entry = populated_entry lora_entry = populated_entry
# If we have a version ID from Civitai, track it for deduplication
if 'id' in lora_entry and lora_entry['id']:
added_loras[str(lora_entry['id'])] = len(result["loras"])
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}") logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
# Track by hash if we have it
if lora_hash:
added_loras[lora_hash] = len(result["loras"])
result["loras"].append(lora_entry) result["loras"].append(lora_entry)
# Process civitaiResources array # Process civitaiResources array
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list): if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
for resource in metadata["civitaiResources"]: for resource in metadata["civitaiResources"]:
# Modified to process resources without a type field as potential LoRAs # Skip resources that aren't LoRAs or LyCORIS
if resource.get("type") in ["lora", "lycoris"] or "type" not in resource: if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
# Initialize lora entry with the same structure as in automatic.py continue
lora_entry = {
'id': resource.get("modelVersionId", 0),
'modelId': resource.get("modelId", 0),
'name': resource.get("modelName", "Unknown LoRA"),
'version': resource.get("modelVersionName", ""),
'type': resource.get("type", "lora"),
'weight': round(float(resource.get("weight", 1.0)), 2),
'existsLocally': False,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# Try to get info from Civitai if modelVersionId is available # Get unique identifier for deduplication
if resource.get('modelVersionId') and civitai_client: version_id = str(resource.get("modelVersionId", ""))
try:
version_id = str(resource.get('modelVersionId')) # Skip if we've already added this LoRA
# Use get_model_version_info instead of get_model_version if version_id and version_id in added_loras:
civitai_info, error = await civitai_client.get_model_version_info(version_id) continue
if error: # Initialize lora entry
logger.warning(f"Error getting model version info: {error}") lora_entry = {
continue 'id': resource.get("modelVersionId", 0),
'modelId': resource.get("modelId", 0),
'name': resource.get("modelName", "Unknown LoRA"),
'version': resource.get("modelVersionName", ""),
'type': resource.get("type", "lora"),
'weight': round(float(resource.get("weight", 1.0)), 2),
'existsLocally': False,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# Try to get info from Civitai if modelVersionId is available
if version_id and civitai_client:
try:
# Use get_model_version_info instead of get_model_version
civitai_info, error = await civitai_client.get_model_version_info(version_id)
if error:
logger.warning(f"Error getting model version info: {error}")
continue
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts
)
if populated_entry is None:
continue # Skip invalid LoRA types
lora_entry = populated_entry
except Exception as e:
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}")
# Track this LoRA in our deduplication dict
if version_id:
added_loras[version_id] = len(result["loras"])
result["loras"].append(lora_entry)
# Process additionalResources array
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
for resource in metadata["additionalResources"]:
# Skip resources that aren't LoRAs or LyCORIS
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
continue
lora_type = resource.get("type", "lora")
name = resource.get("name", "")
# Extract ID from URN format if available
version_id = None
if name and "civitai:" in name:
parts = name.split("@")
if len(parts) > 1:
version_id = parts[1]
# Skip if we've already added this LoRA
if version_id in added_loras:
continue
lora_entry = {
'name': name,
'type': lora_type,
'weight': float(resource.get("strength", 1.0)),
'hash': "",
'existsLocally': False,
'localPath': None,
'file_name': name,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# If we have a version ID and civitai client, try to get more info
if version_id and civitai_client:
try:
# Use get_model_version_info with the version ID
civitai_info, error = await civitai_client.get_model_version_info(version_id)
if error:
logger.warning(f"Error getting model version info: {error}")
else:
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
@@ -177,65 +266,14 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
continue # Skip invalid LoRA types continue # Skip invalid LoRA types
lora_entry = populated_entry lora_entry = populated_entry
except Exception as e:
logger.error(f"Error fetching Civitai info for model version {resource.get('modelVersionId')}: {e}")
result["loras"].append(lora_entry)
# Process additionalResources array
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
for resource in metadata["additionalResources"]:
# Modified to process resources without a type field as potential LoRAs
if resource.get("type") in ["lora", "lycoris"] or "type" not in resource:
lora_type = resource.get("type", "lora")
name = resource.get("name", "")
# Extract ID from URN format if available
model_id = None
if name and "civitai:" in name:
parts = name.split("@")
if len(parts) > 1:
model_id = parts[1]
lora_entry = {
'name': name,
'type': lora_type,
'weight': float(resource.get("strength", 1.0)),
'hash': "",
'existsLocally': False,
'localPath': None,
'file_name': name,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# If we have a model ID and civitai client, try to get more info
if model_id and civitai_client:
try:
# Use get_model_version_info with the model ID
civitai_info, error = await civitai_client.get_model_version_info(model_id)
if error: # Track this LoRA for deduplication
logger.warning(f"Error getting model version info: {error}") if version_id:
else: added_loras[version_id] = len(result["loras"])
populated_entry = await self.populate_lora_from_civitai( except Exception as e:
lora_entry, logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
civitai_info,
recipe_scanner, result["loras"].append(lora_entry)
base_model_counts
)
if populated_entry is None:
continue # Skip invalid LoRA types
lora_entry = populated_entry
except Exception as e:
logger.error(f"Error fetching Civitai info for model ID {model_id}: {e}")
result["loras"].append(lora_entry)
# If base model wasn't found earlier, use the most common one from LoRAs # If base model wasn't found earlier, use the most common one from LoRAs
if not result["base_model"] and base_model_counts: if not result["base_model"] and base_model_counts:

View File

@@ -6,7 +6,7 @@ from typing import Dict
from server import PromptServer # type: ignore from server import PromptServer # type: ignore
from ..utils.routes_common import ModelRouteUtils from ..utils.routes_common import ModelRouteUtils
from ..nodes.utils import get_lora_info from ..utils.utils import get_lora_info
from ..config import config from ..config import config
from ..services.websocket_manager import ws_manager from ..services.websocket_manager import ws_manager
@@ -50,13 +50,14 @@ class ApiRoutes:
app.router.add_get('/api/loras', routes.get_loras) app.router.add_get('/api/loras', routes.get_loras)
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai) app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) # Add new WebSocket route for download progress
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route
app.router.add_get('/api/lora-roots', routes.get_lora_roots) app.router.add_get('/api/lora-roots', routes.get_lora_roots)
app.router.add_get('/api/folders', routes.get_folders) app.router.add_get('/api/folders', routes.get_folders)
app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions) app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version) app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version)
app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash) app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash)
app.router.add_post('/api/download-lora', routes.download_lora) app.router.add_post('/api/download-model', routes.download_model)
app.router.add_post('/api/move_model', routes.move_model) app.router.add_post('/api/move_model', routes.move_model)
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
app.router.add_post('/api/loras/save-metadata', routes.save_metadata) app.router.add_post('/api/loras/save-metadata', routes.save_metadata)
@@ -436,69 +437,8 @@ class ApiRoutes:
"error": str(e) "error": str(e)
}, status=500) }, status=500)
async def download_lora(self, request: web.Request) -> web.Response: async def download_model(self, request: web.Request) -> web.Response:
async with self._download_lock: return await ModelRouteUtils.handle_download_model(request, self.download_manager)
try:
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
# Create progress callback
async def progress_callback(progress):
await ws_manager.broadcast({
'status': 'progress',
'progress': progress
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_version_id = data.get('model_version_id')
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
)
result = await self.download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_version_id=model_version_id,
save_dir=data.get('lora_root'),
relative_path=data.get('relative_path'),
progress_callback=progress_callback
)
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')
# Return 401 for early access errors
if 'early access' in error_message.lower():
logger.warning(f"Early access download failed: {error_message}")
return web.Response(
status=401, # Use 401 status code to match Civitai's response
text=error_message
)
return web.Response(status=500, text=error_message)
return web.json_response(result)
except Exception as e:
error_message = str(e)
# Check if this might be an early access error
if '401' in error_message:
logger.warning(f"Early access error (401): {error_message}")
return web.Response(
status=401,
text="Early Access Restriction: This LoRA requires purchase. Please buy early access on Civitai.com."
)
logger.error(f"Error downloading LoRA: {error_message}")
return web.Response(status=500, text=error_message)
async def move_model(self, request: web.Request) -> web.Response: async def move_model(self, request: web.Request) -> web.Response:

View File

@@ -54,12 +54,8 @@ class CheckpointsRoutes:
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai) app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview) app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint
# Add new WebSocket endpoint for checkpoint progress
app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)
# Add new routes for finding duplicates and filename conflicts # Add new routes for finding duplicates and filename conflicts
app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints) app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints)
@@ -542,74 +538,6 @@ class CheckpointsRoutes:
"""Handle preview image replacement for checkpoints""" """Handle preview image replacement for checkpoints"""
return await ModelRouteUtils.handle_replace_preview(request, self.scanner) return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def download_checkpoint(self, request: web.Request) -> web.Response:
"""Handle checkpoint download request"""
async with self._download_lock:
# Get the download manager from service registry if not already initialized
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
try:
data = await request.json()
# Create progress callback that uses checkpoint-specific WebSocket
async def progress_callback(progress):
await ws_manager.broadcast_checkpoint_progress({
'status': 'progress',
'progress': progress
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_version_id = data.get('model_version_id')
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
)
result = await self.download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_version_id=model_version_id,
save_dir=data.get('checkpoint_root'),
relative_path=data.get('relative_path', ''),
progress_callback=progress_callback,
model_type="checkpoint"
)
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')
# Return 401 for early access errors
if 'early access' in error_message.lower():
logger.warning(f"Early access download failed: {error_message}")
return web.Response(
status=401,
text=f"Early Access Restriction: {error_message}"
)
return web.Response(status=500, text=error_message)
return web.json_response(result)
except Exception as e:
error_message = str(e)
# Check if this might be an early access error
if '401' in error_message:
logger.warning(f"Early access error (401): {error_message}")
return web.Response(
status=401,
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
)
logger.error(f"Error downloading checkpoint: {error_message}")
return web.Response(status=500, text=error_message)
async def get_checkpoint_roots(self, request): async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories""" """Return the checkpoint root directories"""
try: try:

View File

@@ -1,7 +1,6 @@
import logging import logging
from ..utils.example_images_download_manager import DownloadManager from ..utils.example_images_download_manager import DownloadManager
from ..utils.example_images_processor import ExampleImagesProcessor from ..utils.example_images_processor import ExampleImagesProcessor
from ..utils.example_images_metadata import MetadataUpdater
from ..utils.example_images_file_manager import ExampleImagesFileManager from ..utils.example_images_file_manager import ExampleImagesFileManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -10,6 +10,7 @@ from ..utils.usage_stats import UsageStats
from ..utils.lora_metadata import extract_trained_words from ..utils.lora_metadata import extract_trained_words
from ..config import config from ..config import config
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR
from ..services.service_registry import ServiceRegistry
import re import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -90,6 +91,8 @@ class MiscRoutes:
# Add new route for clearing cache # Add new route for clearing cache
app.router.add_post('/api/clear-cache', MiscRoutes.clear_cache) app.router.add_post('/api/clear-cache', MiscRoutes.clear_cache)
app.router.add_get('/api/health-check', lambda request: web.json_response({'status': 'ok'}))
# Usage stats routes # Usage stats routes
app.router.add_post('/api/update-usage-stats', MiscRoutes.update_usage_stats) app.router.add_post('/api/update-usage-stats', MiscRoutes.update_usage_stats)
app.router.add_get('/api/get-usage-stats', MiscRoutes.get_usage_stats) app.router.add_get('/api/get-usage-stats', MiscRoutes.get_usage_stats)
@@ -106,6 +109,9 @@ class MiscRoutes:
# Node registry endpoints # Node registry endpoints
app.router.add_post('/api/register-nodes', MiscRoutes.register_nodes) app.router.add_post('/api/register-nodes', MiscRoutes.register_nodes)
app.router.add_get('/api/get-registry', MiscRoutes.get_registry) app.router.add_get('/api/get-registry', MiscRoutes.get_registry)
# Add new route for checking if a model exists in the library
app.router.add_get('/api/check-model-exists', MiscRoutes.check_model_exists)
@staticmethod @staticmethod
async def clear_cache(request): async def clear_cache(request):
@@ -580,3 +586,106 @@ class MiscRoutes:
'error': 'Internal Error', 'error': 'Internal Error',
'message': str(e) 'message': str(e)
}, status=500) }, status=500)
@staticmethod
async def check_model_exists(request):
    """
    Check if a model with specified modelId and optionally modelVersionId exists in the library

    Expects query parameters:
    - modelId: int - Civitai model ID (required)
    - modelVersionId: int - Civitai model version ID (optional)

    Returns:
    - If modelVersionId is provided: JSON with a boolean 'exists' field
      (plus 'modelType' when found)
    - If modelVersionId is not provided: JSON with a list of modelVersionIds
      that exist in the library
    """
    try:
        # Query parameters arrive as strings; validate before converting.
        model_id_str = request.query.get('modelId')
        model_version_id_str = request.query.get('modelVersionId')

        if not model_id_str:
            return web.json_response({
                'success': False,
                'error': 'Missing required parameter: modelId'
            }, status=400)

        try:
            model_id = int(model_id_str)
        except ValueError:
            return web.json_response({
                'success': False,
                'error': 'Parameter modelId must be an integer'
            }, status=400)

        # Both scanners may be consulted; either can legitimately be absent.
        registry = ServiceRegistry.get_instance()
        lora_scanner = await registry.get_lora_scanner()
        checkpoint_scanner = await registry.get_checkpoint_scanner()

        if model_version_id_str:
            # Specific-version lookup.
            try:
                model_version_id = int(model_version_id_str)
            except ValueError:
                return web.json_response({
                    'success': False,
                    'error': 'Parameter modelVersionId must be an integer'
                }, status=400)

            exists = False
            model_type = None

            # Check lora scanner first, then fall back to checkpoints.
            # Guard both scanners against being unavailable.
            if lora_scanner and await lora_scanner.check_model_version_exists(model_id, model_version_id):
                exists = True
                model_type = 'lora'
            elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
                exists = True
                model_type = 'checkpoint'

            return web.json_response({
                'success': True,
                'exists': exists,
                'modelType': model_type if exists else None
            })
        else:
            # No version requested: list every known version ID for the model.
            lora_versions = []
            if lora_scanner:
                lora_versions = await lora_scanner.get_model_versions_by_id(model_id)

            checkpoint_versions = []
            # Only consult the checkpoint scanner when no lora versions were
            # found AND the scanner is actually available (previously this
            # call could raise AttributeError on a missing scanner).
            if not lora_versions and checkpoint_scanner:
                checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id)

            model_type = None
            versions = []
            if lora_versions:
                model_type = 'lora'
                versions = lora_versions
            elif checkpoint_versions:
                model_type = 'checkpoint'
                versions = checkpoint_versions

            return web.json_response({
                'success': True,
                'modelId': model_id,
                'modelType': model_type,
                'versions': versions
            })
    except Exception as e:
        logger.error(f"Failed to check model existence: {e}", exc_info=True)
        return web.json_response({
            'success': False,
            'error': str(e)
        }, status=500)

View File

@@ -3,7 +3,6 @@ import time
import base64 import base64
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import torch
import io import io
import logging import logging
from aiohttp import web from aiohttp import web
@@ -1018,6 +1017,8 @@ class RecipeRoutes:
shape_info = tensor_image.shape shape_info = tensor_image.shape
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}") logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
import torch
# Convert tensor to numpy array # Convert tensor to numpy array
if isinstance(tensor_image, torch.Tensor): if isinstance(tensor_image, torch.Tensor):
image_np = tensor_image.cpu().numpy() image_np = tensor_image.cpu().numpy()

View File

@@ -33,7 +33,6 @@ class CheckpointScanner(ModelScanner):
file_extensions=file_extensions, file_extensions=file_extensions,
hash_index=ModelHashIndex() hash_index=ModelHashIndex()
) )
self._checkpoint_roots = self._init_checkpoint_roots()
self._initialized = True self._initialized = True
@classmethod @classmethod
@@ -44,27 +43,9 @@ class CheckpointScanner(ModelScanner):
cls._instance = cls() cls._instance = cls()
return cls._instance return cls._instance
def _init_checkpoint_roots(self) -> List[str]:
"""Initialize checkpoint roots from ComfyUI settings"""
# Get both checkpoint and diffusion_models paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
# Combine, normalize and deduplicate paths
all_paths = set()
for path in checkpoint_paths + diffusion_paths:
if os.path.exists(path):
norm_path = path.replace(os.sep, "/")
all_paths.add(norm_path)
# Sort for consistent order
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
return sorted_paths
def get_model_roots(self) -> List[str]: def get_model_roots(self) -> List[str]:
"""Get checkpoint root directories""" """Get checkpoint root directories"""
return self._checkpoint_roots return config.base_models_roots
async def scan_all_models(self) -> List[Dict]: async def scan_all_models(self) -> List[Dict]:
"""Scan all checkpoint directories and return metadata""" """Scan all checkpoint directories and return metadata"""
@@ -72,7 +53,7 @@ class CheckpointScanner(ModelScanner):
# Create scan tasks for each directory # Create scan tasks for each directory
scan_tasks = [] scan_tasks = []
for root in self._checkpoint_roots: for root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(root)) task = asyncio.create_task(self._scan_directory(root))
scan_tasks.append(task) scan_tasks.append(task)

View File

@@ -225,7 +225,7 @@ class CivitaiClient:
logger.error(f"Error fetching model versions: {e}") logger.error(f"Error fetching model versions: {e}")
return None return None
async def get_model_version(self, model_id: str, version_id: str = "") -> Optional[Dict]: async def get_model_version(self, model_id: int, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata """Get specific model version with additional metadata
Args: Args:
@@ -237,6 +237,8 @@ class CivitaiClient:
""" """
try: try:
session = await self._ensure_fresh_session() session = await self._ensure_fresh_session()
# Step 1: Get model data to find version_id if not provided and get additional metadata
async with session.get(f"{self.base_url}/models/{model_id}") as response: async with session.get(f"{self.base_url}/models/{model_id}") as response:
if response.status != 200: if response.status != 200:
return None return None
@@ -244,45 +246,28 @@ class CivitaiClient:
data = await response.json() data = await response.json()
model_versions = data.get('modelVersions', []) model_versions = data.get('modelVersions', [])
# Find matching version # Step 2: Determine the version_id to use
matched_version = None target_version_id = version_id
if target_version_id is None:
if version_id: target_version_id = model_versions[0].get('id')
# If version_id provided, find exact match
for version in model_versions: # Step 3: Get detailed version info using the version_id
if str(version.get('id')) == str(version_id): headers = self._get_request_headers()
matched_version = version async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
break if response.status != 200:
else:
# If no version_id then use the first version
matched_version = model_versions[0] if model_versions else None
# If no match found, return None
if not matched_version:
return None return None
# Build result with modified fields
result = matched_version.copy() # Copy to avoid modifying original
# Replace index with modelId version = await response.json()
if 'index' in result:
del result['index']
result['modelId'] = model_id
# Add model field with metadata from top level # Step 4: Enrich version_info with model data
result['model'] = { # Add description and tags from model data
"name": data.get("name"), version['model']['description'] = data.get("description")
"type": data.get("type"), version['model']['tags'] = data.get("tags", [])
"nsfw": data.get("nsfw", False),
"poi": data.get("poi", False),
"description": data.get("description"),
"tags": data.get("tags", [])
}
# Add creator field from top level # Add creator from model data
result['creator'] = data.get("creator") version['creator'] = data.get("creator")
return result return version
except Exception as e: except Exception as e:
logger.error(f"Error fetching model version: {e}") logger.error(f"Error fetching model version: {e}")

View File

@@ -1,13 +1,13 @@
import logging import logging
import os import os
import json
import asyncio import asyncio
from typing import Dict from typing import Dict
from ..utils.models import LoraMetadata, CheckpointMetadata from ..utils.models import LoraMetadata, CheckpointMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
from .settings_manager import settings
# Download to temporary file first # Download to temporary file first
import tempfile import tempfile
@@ -48,54 +48,98 @@ class DownloadManager:
"""Get the checkpoint scanner from registry""" """Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner() return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, download_url: str = None, model_hash: str = None, async def download_from_civitai(self, model_id: int,
model_version_id: str = None, save_dir: str = None, model_version_id: int, save_dir: str = None,
relative_path: str = '', progress_callback=None, relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict:
model_type: str = "lora") -> Dict:
"""Download model from Civitai """Download model from Civitai
Args: Args:
download_url: Direct download URL for the model model_id: Civitai model ID
model_hash: SHA256 hash of the model model_version_id: Civitai model version ID (optional, if not provided, will download the latest version)
model_version_id: Civitai model version ID
save_dir: Directory to save the model to save_dir: Directory to save the model to
relative_path: Relative path within save_dir relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates progress_callback: Callback function for progress updates
model_type: Type of model ('lora' or 'checkpoint') use_default_paths: Flag to indicate whether to use default paths
Returns: Returns:
Dict with download result Dict with download result
""" """
try: try:
# Update save directory with relative path if provided # Check if model version already exists in library
if relative_path: if model_version_id is not None:
save_dir = os.path.join(save_dir, relative_path) # Case 1: model_version_id is provided, check both scanners
# Create directory if it doesn't exist lora_scanner = await self._get_lora_scanner()
os.makedirs(save_dir, exist_ok=True) checkpoint_scanner = await self._get_checkpoint_scanner()
# Check lora scanner first
if await lora_scanner.check_model_version_exists(model_id, model_version_id):
return {'success': False, 'error': 'Model version already exists in lora library'}
# Check checkpoint scanner
if await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
# Get civitai client # Get civitai client
civitai_client = await self._get_civitai_client() civitai_client = await self._get_civitai_client()
# Get version info based on the provided identifier # Get version info based on the provided identifier
version_info = None version_info = await civitai_client.get_model_version(model_id, model_version_id)
error_msg = None
if model_hash:
# Get model by hash
version_info = await civitai_client.get_model_by_hash(model_hash)
elif model_version_id:
# Use model version ID directly
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
elif download_url:
# Extract version ID from download URL
version_id = download_url.split('/')[-1]
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
if not version_info: if not version_info:
if error_msg and "model not found" in error_msg.lower(): return {'success': False, 'error': 'Failed to fetch model metadata'}
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'} model_type_from_info = version_info.get('model', {}).get('type', '').lower()
if model_type_from_info == 'checkpoint':
model_type = 'checkpoint'
elif model_type_from_info in VALID_LORA_TYPES:
model_type = 'lora'
else:
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
# Case 2: model_version_id was None, check after getting version_info
if model_version_id is None:
version_model_id = version_info.get('modelId')
version_id = version_info.get('id')
if model_type == 'lora':
# Check lora scanner
lora_scanner = await self._get_lora_scanner()
if await lora_scanner.check_model_version_exists(version_model_id, version_id):
return {'success': False, 'error': 'Model version already exists in lora library'}
elif model_type == 'checkpoint':
# Check checkpoint scanner
checkpoint_scanner = await self._get_checkpoint_scanner()
if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
# Handle use_default_paths
if use_default_paths:
# Set save_dir based on model type
if model_type == 'checkpoint':
default_path = settings.get('default_checkpoint_root')
if not default_path:
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
save_dir = default_path
else: # model_type == 'lora'
default_path = settings.get('default_lora_root')
if not default_path:
return {'success': False, 'error': 'Default lora root path not set in settings'}
save_dir = default_path
# Set relative_path to version_info.baseModel/first_tag if available
base_model = version_info.get('baseModel', '')
model_tags = version_info.get('model', {}).get('tags', [])
if base_model:
if model_tags:
relative_path = os.path.join(base_model, model_tags[0])
else:
relative_path = base_model
# Update save directory with relative path if provided
if relative_path:
save_dir = os.path.join(save_dir, relative_path)
# Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Check if this is an early access model # Check if this is an early access model
if version_info.get('earlyAccessEndsAt'): if version_info.get('earlyAccessEndsAt'):
@@ -137,18 +181,6 @@ class DownloadManager:
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating LoraMetadata for {file_name}") logger.info(f"Creating LoraMetadata for {file_name}")
# 5.1 Get and update model tags, description and creator info
model_id = version_info.get('modelId')
if model_id:
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
if model_metadata:
if model_metadata.get("tags"):
metadata.tags = model_metadata.get("tags", [])
if model_metadata.get("description"):
metadata.modelDescription = model_metadata.get("description", "")
if model_metadata.get("creator"):
metadata.civitai["creator"] = model_metadata.get("creator")
# 6. Start download process # 6. Start download process
result = await self._execute_download( result = await self._execute_download(
download_url=file_info.get('downloadUrl', ''), download_url=file_info.get('downloadUrl', ''),

View File

@@ -1,11 +1,7 @@
import json
import os import os
import logging import logging
import asyncio import asyncio
import shutil from typing import List, Dict, Optional
import time
import re
from typing import List, Dict, Optional, Set
from ..utils.models import LoraMetadata from ..utils.models import LoraMetadata
from ..config import config from ..config import config
@@ -14,7 +10,6 @@ from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to Mo
from .settings_manager import settings from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match from ..utils.utils import fuzzy_match
from .service_registry import ServiceRegistry
import sys import sys
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1362,3 +1362,59 @@ class ModelScanner:
if file_name in self._hash_index._duplicate_filenames: if file_name in self._hash_index._duplicate_filenames:
if len(self._hash_index._duplicate_filenames[file_name]) <= 1: if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
del self._hash_index._duplicate_filenames[file_name] del self._hash_index._duplicate_filenames[file_name]
async def check_model_version_exists(self, model_id: int, model_version_id: int) -> bool:
"""Check if a specific model version exists in the cache
Args:
model_id: Civitai model ID
model_version_id: Civitai model version ID
Returns:
bool: True if the model version exists, False otherwise
"""
try:
cache = await self.get_cached_data()
if not cache or not cache.raw_data:
return False
for item in cache.raw_data:
if (item.get('civitai') and
item['civitai'].get('modelId') == model_id and
item['civitai'].get('id') == model_version_id):
return True
return False
except Exception as e:
logger.error(f"Error checking model version existence: {e}")
return False
async def get_model_versions_by_id(self, model_id: int) -> List[Dict]:
"""Get all versions of a model by its ID
Args:
model_id: Civitai model ID
Returns:
List[Dict]: List of version information dictionaries
"""
try:
cache = await self.get_cached_data()
if not cache or not cache.raw_data:
return []
versions = []
for item in cache.raw_data:
if (item.get('civitai') and
item['civitai'].get('modelId') == model_id and
item['civitai'].get('id')):
versions.append({
'versionId': item['civitai'].get('id'),
'name': item['civitai'].get('name'),
'fileName': item.get('file_name', '')
})
return versions
except Exception as e:
logger.error(f"Error getting model versions: {e}")
return []

View File

@@ -1,6 +1,7 @@
import logging import logging
from aiohttp import web from aiohttp import web
from typing import Set, Dict, Optional from typing import Set, Dict, Optional
from uuid import uuid4
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -10,7 +11,7 @@ class WebSocketManager:
def __init__(self): def __init__(self):
self._websockets: Set[web.WebSocketResponse] = set() self._websockets: Set[web.WebSocketResponse] = set()
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse: async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection""" """Handle new WebSocket connection"""
@@ -39,19 +40,35 @@ class WebSocketManager:
finally: finally:
self._init_websockets.discard(ws) self._init_websockets.discard(ws)
return ws return ws
async def handle_checkpoint_connection(self, request: web.Request) -> web.WebSocketResponse: async def handle_download_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection for checkpoint download progress""" """Handle new WebSocket connection for download progress"""
ws = web.WebSocketResponse() ws = web.WebSocketResponse()
await ws.prepare(request) await ws.prepare(request)
self._checkpoint_websockets.add(ws)
# Get download_id from query parameters
download_id = request.query.get('id')
if not download_id:
# Generate a new download ID if not provided
download_id = str(uuid4())
# Store the websocket with its download ID
self._download_websockets[download_id] = ws
try: try:
# Send the download ID back to the client
await ws.send_json({
'type': 'download_id',
'download_id': download_id
})
async for msg in ws: async for msg in ws:
if msg.type == web.WSMsgType.ERROR: if msg.type == web.WSMsgType.ERROR:
logger.error(f'Checkpoint WebSocket error: {ws.exception()}') logger.error(f'Download WebSocket error: {ws.exception()}')
finally: finally:
self._checkpoint_websockets.discard(ws) if download_id in self._download_websockets:
del self._download_websockets[download_id]
return ws return ws
async def broadcast(self, data: Dict): async def broadcast(self, data: Dict):
@@ -84,17 +101,18 @@ class WebSocketManager:
except Exception as e: except Exception as e:
logger.error(f"Error sending initialization progress: {e}") logger.error(f"Error sending initialization progress: {e}")
async def broadcast_checkpoint_progress(self, data: Dict): async def broadcast_download_progress(self, download_id: str, data: Dict):
"""Broadcast checkpoint download progress to connected clients""" """Send progress update to specific download client"""
if not self._checkpoint_websockets: if download_id not in self._download_websockets:
logger.debug(f"No WebSocket found for download ID: {download_id}")
return return
for ws in self._checkpoint_websockets: ws = self._download_websockets[download_id]
try: try:
await ws.send_json(data) await ws.send_json(data)
except Exception as e: except Exception as e:
logger.error(f"Error sending checkpoint progress: {e}") logger.error(f"Error sending download progress: {e}")
def get_connected_clients_count(self) -> int: def get_connected_clients_count(self) -> int:
"""Get number of connected clients""" """Get number of connected clients"""
return len(self._websockets) return len(self._websockets)
@@ -102,10 +120,14 @@ class WebSocketManager:
def get_init_clients_count(self) -> int: def get_init_clients_count(self) -> int:
"""Get number of initialization progress clients""" """Get number of initialization progress clients"""
return len(self._init_websockets) return len(self._init_websockets)
def get_checkpoint_clients_count(self) -> int: def get_download_clients_count(self) -> int:
"""Get number of checkpoint progress clients""" """Get number of download progress clients"""
return len(self._checkpoint_websockets) return len(self._download_websockets)
def generate_download_id(self) -> str:
"""Generate a unique download ID"""
return str(uuid4())
# Global instance # Global instance
ws_manager = WebSocketManager() ws_manager = WebSocketManager()

View File

@@ -10,7 +10,8 @@ NSFW_LEVELS = {
# Node type constants # Node type constants
NODE_TYPES = { NODE_TYPES = {
"Lora Loader (LoraManager)": 1, "Lora Loader (LoraManager)": 1,
"Lora Stacker (LoraManager)": 2 "Lora Stacker (LoraManager)": 2,
"WanVideo Lora Select (LoraManager)": 3
} }
# Default ComfyUI node color when bgcolor is null # Default ComfyUI node color when bgcolor is null

View File

@@ -8,9 +8,11 @@ from .model_utils import determine_base_model
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..config import config from ..config import config
from ..services.civitai_client import CivitaiClient from ..services.civitai_client import CivitaiClient
from ..services.service_registry import ServiceRegistry
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from ..services.download_manager import DownloadManager from ..services.download_manager import DownloadManager
from ..services.websocket_manager import ws_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -564,13 +566,12 @@ class ModelRouteUtils:
return web.Response(text=str(e), status=500) return web.Response(text=str(e), status=500)
@staticmethod @staticmethod
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response: async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response:
"""Handle model download request """Handle model download request
Args: Args:
request: The aiohttp request request: The aiohttp request
download_manager: Instance of DownloadManager download_manager: Instance of DownloadManager
model_type: Type of model ('lora' or 'checkpoint')
Returns: Returns:
web.Response: The HTTP response web.Response: The HTTP response
@@ -578,40 +579,58 @@ class ModelRouteUtils:
try: try:
data = await request.json() data = await request.json()
# Create progress callback # Get or generate a download ID
download_id = data.get('download_id', ws_manager.generate_download_id())
# Create progress callback with download ID
async def progress_callback(progress): async def progress_callback(progress):
from ..services.websocket_manager import ws_manager await ws_manager.broadcast_download_progress(download_id, {
await ws_manager.broadcast({
'status': 'progress', 'status': 'progress',
'progress': progress 'progress': progress,
'download_id': download_id
}) })
# Check which identifier is provided # Check which identifier is provided and convert to int
download_url = data.get('download_url') try:
model_hash = data.get('model_hash') model_id = int(data.get('model_id'))
model_version_id = data.get('model_version_id') except (TypeError, ValueError):
return web.Response(
status=400,
text="Invalid model_id: Must be an integer"
)
# Convert model_version_id to int if provided
model_version_id = None
if data.get('model_version_id'):
try:
model_version_id = int(data.get('model_version_id'))
except (TypeError, ValueError):
return web.Response(
status=400,
text="Invalid model_version_id: Must be an integer"
)
# Validate that at least one identifier is provided # Only model_id is required, model_version_id is optional
if not any([download_url, model_hash, model_version_id]): if not model_id:
return web.Response( return web.Response(
status=400, status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'" text="Missing required parameter: Please provide 'model_id'"
) )
# Use the correct root directory based on model type use_default_paths = data.get('use_default_paths', False)
root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root'
save_dir = data.get(root_key)
result = await download_manager.download_from_civitai( result = await download_manager.download_from_civitai(
download_url=download_url, model_id=model_id,
model_hash=model_hash,
model_version_id=model_version_id, model_version_id=model_version_id,
save_dir=save_dir, save_dir=data.get('model_root'),
relative_path=data.get('relative_path', ''), relative_path=data.get('relative_path', ''),
progress_callback=progress_callback, use_default_paths=use_default_paths,
model_type=model_type progress_callback=progress_callback
) )
# Include download_id in the response
result['download_id'] = download_id
if not result.get('success', False): if not result.get('success', False):
error_message = result.get('error', 'Unknown error') error_message = result.get('error', 'Unknown error')
@@ -638,7 +657,7 @@ class ModelRouteUtils:
text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com." text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
) )
logger.error(f"Error downloading {model_type}: {error_message}") logger.error(f"Error downloading model: {error_message}")
return web.Response(status=500, text=error_message) return web.Response(status=500, text=error_message)
@staticmethod @staticmethod
@@ -693,8 +712,10 @@ class ModelRouteUtils:
try: try:
data = await request.json() data = await request.json()
file_path = data.get('file_path') file_path = data.get('file_path')
model_id = data.get('model_id') model_id = int(data.get('model_id'))
model_version_id = data.get('model_version_id') model_version_id = None
if data.get('model_version_id'):
model_version_id = int(data.get('model_version_id'))
if not file_path or not model_id: if not file_path or not model_id:
return web.json_response({"success": False, "error": "Both file_path and model_id are required"}, status=400) return web.json_response({"success": False, "error": "Both file_path and model_id are required"}, status=400)

View File

@@ -1,8 +1,29 @@
from difflib import SequenceMatcher from difflib import SequenceMatcher
import requests import requests
import tempfile import tempfile
import re import os
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from ..services.service_registry import ServiceRegistry
from ..config import config
async def get_lora_info(lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await ServiceRegistry.get_lora_scanner()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, []
def download_twitter_image(url): def download_twitter_image(url):
"""Download image from a URL containing twitter:image meta tag """Download image from a URL containing twitter:image meta tag

View File

@@ -1,13 +1,12 @@
[project] [project]
name = "comfyui-lora-manager" name = "comfyui-lora-manager"
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration." description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
version = "0.8.19" version = "0.8.20-beta"
license = {file = "LICENSE"} license = {file = "LICENSE"}
dependencies = [ dependencies = [
"aiohttp", "aiohttp",
"jinja2", "jinja2",
"safetensors", "safetensors",
"watchdog",
"beautifulsoup4", "beautifulsoup4",
"piexif", "piexif",
"Pillow", "Pillow",

View File

@@ -1,7 +1,6 @@
aiohttp aiohttp
jinja2 jinja2
safetensors safetensors
watchdog
beautifulsoup4 beautifulsoup4
piexif piexif
Pillow Pillow
@@ -9,6 +8,5 @@ olefile
requests requests
toml toml
numpy numpy
torch
natsort natsort
msgpack msgpack

View File

@@ -3,6 +3,26 @@ import os
import sys import sys
import json import json
# Create mock modules for py/nodes directory - add this before any other imports
def mock_nodes_directory():
"""Create mock modules for all Python files in the py/nodes directory"""
nodes_dir = os.path.join(os.path.dirname(__file__), 'py', 'nodes')
if os.path.exists(nodes_dir):
# Create a mock module for the nodes package itself
sys.modules['py.nodes'] = type('MockNodesModule', (), {})
# Create mock modules for all Python files in the nodes directory
for file in os.listdir(nodes_dir):
if file.endswith('.py') and file != '__init__.py':
module_name = file[:-3] # Remove .py extension
full_module_name = f'py.nodes.{module_name}'
# Create empty module object
sys.modules[full_module_name] = type(f'Mock{module_name.capitalize()}Module', (), {})
print(f"Created mock module for: {full_module_name}")
# Run the mocking function before any other imports
mock_nodes_directory()
# Create mock folder_paths module BEFORE any other imports # Create mock folder_paths module BEFORE any other imports
class MockFolderPaths: class MockFolderPaths:
@staticmethod @staticmethod
@@ -232,7 +252,7 @@ class StandaloneLoraManager(LoraManager):
added_targets.add(os.path.normpath(real_root)) added_targets.add(os.path.normpath(real_root))
# Add static routes for each checkpoint root # Add static routes for each checkpoint root
for idx, root in enumerate(config.checkpoints_roots, start=1): for idx, root in enumerate(config.base_models_roots, start=1):
if not os.path.exists(root): if not os.path.exists(root):
logger.warning(f"Checkpoint root path does not exist: {root}") logger.warning(f"Checkpoint root path does not exist: {root}")
continue continue
@@ -268,8 +288,8 @@ class StandaloneLoraManager(LoraManager):
norm_target = os.path.normpath(target_path) norm_target = os.path.normpath(target_path)
if norm_target not in added_targets: if norm_target not in added_targets:
# Determine if this is a checkpoint or lora link based on path # Determine if this is a checkpoint or lora link based on path
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.checkpoints_roots) is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.base_models_roots)
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.checkpoints_roots) is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.base_models_roots)
if is_checkpoint: if is_checkpoint:
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview' route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'

View File

@@ -751,6 +751,29 @@ input:checked + .toggle-slider:before {
opacity: 1; opacity: 1;
} }
/* Add styles for tab with new content indicator */
.tab-btn.has-new-content {
position: relative;
}
.tab-btn.has-new-content::after {
content: "";
position: absolute;
top: 4px;
right: 4px;
width: 8px;
height: 8px;
background-color: var(--lora-accent);
border-radius: 50%;
animation: pulse 2s infinite;
}
@keyframes pulse {
0% { opacity: 1; transform: scale(1); }
50% { opacity: 0.7; transform: scale(1.1); }
100% { opacity: 1; transform: scale(1); }
}
/* Tab content styles */ /* Tab content styles */
.help-content { .help-content {
padding: var(--space-1) 0; padding: var(--space-1) 0;
@@ -817,6 +840,37 @@ input:checked + .toggle-slider:before {
text-decoration: underline; text-decoration: underline;
} }
/* New content badge styles */
.new-content-badge {
display: inline-flex;
align-items: center;
justify-content: center;
font-size: 0.7em;
font-weight: 600;
background-color: var(--lora-accent);
color: var(--lora-text);
padding: 2px 6px;
border-radius: 10px;
margin-left: 8px;
vertical-align: middle;
animation: fadeIn 0.5s ease-in-out;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
text-transform: uppercase;
letter-spacing: 0.5px;
}
.new-content-badge.inline {
font-size: 0.65em;
padding: 1px 4px;
margin-left: 6px;
border-radius: 8px;
}
/* Dark theme adjustments for new content badge */
[data-theme="dark"] .new-content-badge {
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.4);
}
/* Update video list styles */ /* Update video list styles */
.video-list { .video-list {
display: flex; display: flex;

View File

@@ -179,7 +179,7 @@ export function setupBaseModelEditing(filePath) {
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER], 'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
'Video Models': [BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO], 'Video Models': [BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO],
'Other Models': [ 'Other Models': [
BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW, BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.AURAFLOW,
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1, BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI, BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM, BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM,

View File

@@ -180,7 +180,7 @@ export function setupBaseModelEditing(filePath) {
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER], 'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
'Video Models': [BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO], 'Video Models': [BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO],
'Other Models': [ 'Other Models': [
BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW, BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.AURAFLOW,
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1, BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI, BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM, BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM,

View File

@@ -245,11 +245,6 @@ function findLocalFile(img, index, exampleFiles) {
const match = file.name.match(/image_(\d+)\./); const match = file.name.match(/image_(\d+)\./);
return match && parseInt(match[1]) === index; return match && parseInt(match[1]) === index;
}); });
// If not found by index, just use the same position in the array if available
if (!localFile && index < exampleFiles.length) {
localFile = exampleFiles[index];
}
} }
return localFile; return localFile;
@@ -406,9 +401,6 @@ async function handleImportFiles(files, modelHash, importContainer) {
const customImages = result.custom_images || []; const customImages = result.custom_images || [];
// Combine both arrays for rendering // Combine both arrays for rendering
const allImages = [...regularImages, ...customImages]; const allImages = [...regularImages, ...customImages];
console.log("Regular images:", regularImages);
console.log("Custom images:", customImages);
console.log("Combined images:", allImages);
showcaseTab.innerHTML = renderShowcaseContent(allImages, updatedFilesResult.files, true); showcaseTab.innerHTML = renderShowcaseContent(allImages, updatedFilesResult.files, true);
// Re-initialize showcase functionality // Re-initialize showcase functionality

View File

@@ -61,6 +61,7 @@ export class CheckpointDownloadManager {
this.currentVersion = null; this.currentVersion = null;
this.versions = []; this.versions = [];
this.modelInfo = null; this.modelInfo = null;
this.modelId = null;
this.modelVersionId = null; this.modelVersionId = null;
// Clear selected folder and remove selection from UI // Clear selected folder and remove selection from UI
@@ -79,12 +80,12 @@ export class CheckpointDownloadManager {
try { try {
this.loadingManager.showSimpleLoading('Fetching model versions...'); this.loadingManager.showSimpleLoading('Fetching model versions...');
const modelId = this.extractModelId(url); this.modelId = this.extractModelId(url);
if (!modelId) { if (!this.modelId) {
throw new Error('Invalid Civitai URL format'); throw new Error('Invalid Civitai URL format');
} }
const response = await fetch(`/api/checkpoints/civitai/versions/${modelId}`); const response = await fetch(`/api/checkpoints/civitai/versions/${this.modelId}`);
if (!response.ok) { if (!response.ok) {
const errorData = await response.json().catch(() => ({})); const errorData = await response.json().catch(() => ({}));
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) { if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
@@ -254,7 +255,7 @@ export class CheckpointDownloadManager {
).join(''); ).join('');
// Set default checkpoint root if available // Set default checkpoint root if available
const defaultRoot = getStorageItem('settings', {}).default_checkpoints_root; const defaultRoot = getStorageItem('settings', {}).default_checkpoint_root;
if (defaultRoot && data.roots.includes(defaultRoot)) { if (defaultRoot && data.roots.includes(defaultRoot)) {
checkpointRoot.value = defaultRoot; checkpointRoot.value = defaultRoot;
} }
@@ -296,22 +297,28 @@ export class CheckpointDownloadManager {
} }
try { try {
const downloadUrl = this.currentVersion.downloadUrl;
if (!downloadUrl) {
throw new Error('No download URL available');
}
// Show enhanced loading with progress details // Show enhanced loading with progress details
const updateProgress = this.loadingManager.showDownloadProgress(1); const updateProgress = this.loadingManager.showDownloadProgress(1);
updateProgress(0, 0, this.currentVersion.name); updateProgress(0, 0, this.currentVersion.name);
// Setup WebSocket for progress updates using checkpoint-specific endpoint // Generate a unique ID for this download
const downloadId = Date.now().toString();
// Setup WebSocket for progress updates using download-specific endpoint
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/checkpoint-progress`); const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
ws.onmessage = (event) => { ws.onmessage = (event) => {
const data = JSON.parse(event.data); const data = JSON.parse(event.data);
if (data.status === 'progress') {
// Handle download ID confirmation
if (data.type === 'download_id') {
console.log(`Connected to checkpoint download progress with ID: ${data.download_id}`);
return;
}
// Only process progress updates for our download
if (data.status === 'progress' && data.download_id === downloadId) {
// Update progress display with current progress // Update progress display with current progress
updateProgress(data.progress, 0, this.currentVersion.name); updateProgress(data.progress, 0, this.currentVersion.name);
@@ -333,14 +340,16 @@ export class CheckpointDownloadManager {
// Continue with download even if WebSocket fails // Continue with download even if WebSocket fails
}; };
// Start download using checkpoint download endpoint // Start download using checkpoint download endpoint with download ID
const response = await fetch('/api/checkpoints/download', { const response = await fetch('/api/download-model', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ body: JSON.stringify({
download_url: downloadUrl, model_id: this.modelId,
checkpoint_root: checkpointRoot, model_version_id: this.currentVersion.id,
relative_path: targetFolder model_root: checkpointRoot,
relative_path: targetFolder,
download_id: downloadId
}) })
}); });

View File

@@ -63,6 +63,7 @@ export class DownloadManager {
this.currentVersion = null; this.currentVersion = null;
this.versions = []; this.versions = [];
this.modelInfo = null; this.modelInfo = null;
this.modelId = null;
this.modelVersionId = null; this.modelVersionId = null;
// Clear selected folder and remove selection from UI // Clear selected folder and remove selection from UI
@@ -81,12 +82,12 @@ export class DownloadManager {
try { try {
this.loadingManager.showSimpleLoading('Fetching model versions...'); this.loadingManager.showSimpleLoading('Fetching model versions...');
const modelId = this.extractModelId(url); this.modelId = this.extractModelId(url);
if (!modelId) { if (!this.modelId) {
throw new Error('Invalid Civitai URL format'); throw new Error('Invalid Civitai URL format');
} }
const response = await fetch(`/api/civitai/versions/${modelId}`); const response = await fetch(`/api/civitai/versions/${this.modelId}`);
if (!response.ok) { if (!response.ok) {
const errorData = await response.json().catch(() => ({})); const errorData = await response.json().catch(() => ({}));
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) { if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
@@ -252,24 +253,39 @@ export class DownloadManager {
document.getElementById('locationStep').style.display = 'block'; document.getElementById('locationStep').style.display = 'block';
try { try {
const response = await fetch('/api/lora-roots'); // Fetch LoRA roots
if (!response.ok) { const rootsResponse = await fetch('/api/lora-roots');
if (!rootsResponse.ok) {
throw new Error('Failed to fetch LoRA roots'); throw new Error('Failed to fetch LoRA roots');
} }
const data = await response.json(); const rootsData = await rootsResponse.json();
const loraRoot = document.getElementById('loraRoot'); const loraRoot = document.getElementById('loraRoot');
loraRoot.innerHTML = data.roots.map(root => loraRoot.innerHTML = rootsData.roots.map(root =>
`<option value="${root}">${root}</option>` `<option value="${root}">${root}</option>`
).join(''); ).join('');
// Set default lora root if available // Set default lora root if available
const defaultRoot = getStorageItem('settings', {}).default_loras_root; const defaultRoot = getStorageItem('settings', {}).default_loras_root;
if (defaultRoot && data.roots.includes(defaultRoot)) { if (defaultRoot && rootsData.roots.includes(defaultRoot)) {
loraRoot.value = defaultRoot; loraRoot.value = defaultRoot;
} }
// Initialize folder browser after loading roots // Fetch folders dynamically
const foldersResponse = await fetch('/api/folders');
if (!foldersResponse.ok) {
throw new Error('Failed to fetch folders');
}
const foldersData = await foldersResponse.json();
const folderBrowser = document.getElementById('folderBrowser');
// Update folder browser with dynamic content
folderBrowser.innerHTML = foldersData.folders.map(folder =>
`<div class="folder-item" data-folder="${folder}">${folder}</div>`
).join('');
// Initialize folder browser after loading roots and folders
this.initializeFolderBrowser(); this.initializeFolderBrowser();
} catch (error) { } catch (error) {
showToast(error.message, 'error'); showToast(error.message, 'error');
@@ -306,22 +322,28 @@ export class DownloadManager {
} }
try { try {
const downloadUrl = this.currentVersion.downloadUrl;
if (!downloadUrl) {
throw new Error('No download URL available');
}
// Show enhanced loading with progress details // Show enhanced loading with progress details
const updateProgress = this.loadingManager.showDownloadProgress(1); const updateProgress = this.loadingManager.showDownloadProgress(1);
updateProgress(0, 0, this.currentVersion.name); updateProgress(0, 0, this.currentVersion.name);
// Setup WebSocket for progress updates // Generate a unique ID for this download
const downloadId = Date.now().toString();
// Setup WebSocket for progress updates - use download-specific endpoint
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
ws.onmessage = (event) => { ws.onmessage = (event) => {
const data = JSON.parse(event.data); const data = JSON.parse(event.data);
if (data.status === 'progress') {
// Handle download ID confirmation
if (data.type === 'download_id') {
console.log(`Connected to download progress with ID: ${data.download_id}`);
return;
}
// Only process progress updates for our download
if (data.status === 'progress' && data.download_id === downloadId) {
// Update progress display with current progress // Update progress display with current progress
updateProgress(data.progress, 0, this.currentVersion.name); updateProgress(data.progress, 0, this.currentVersion.name);
@@ -343,14 +365,16 @@ export class DownloadManager {
// Continue with download even if WebSocket fails // Continue with download even if WebSocket fails
}; };
// Start download // Start download with our download ID
const response = await fetch('/api/download-lora', { const response = await fetch('/api/download-model', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ body: JSON.stringify({
download_url: downloadUrl, model_id: this.modelId,
lora_root: loraRoot, model_version_id: this.currentVersion.id,
relative_path: targetFolder model_root: loraRoot,
relative_path: targetFolder,
download_id: downloadId
}) })
}); });
@@ -361,6 +385,9 @@ export class DownloadManager {
showToast('Download completed successfully', 'success'); showToast('Download completed successfully', 'success');
modalManager.closeModal('downloadModal'); modalManager.closeModal('downloadModal');
// Close WebSocket after download completes
ws.close();
// Update state and trigger reload with folder update // Update state and trigger reload with folder update
state.activeFolder = targetFolder; state.activeFolder = targetFolder;
await resetAndReload(true); // Pass true to update folders await resetAndReload(true); // Pass true to update folders

View File

@@ -6,7 +6,7 @@ import { getStorageItem, setStorageItem } from '../utils/storageHelpers.js';
export class HelpManager { export class HelpManager {
constructor() { constructor() {
this.lastViewedTimestamp = getStorageItem('help_last_viewed', 0); this.lastViewedTimestamp = getStorageItem('help_last_viewed', 0);
this.latestContentTimestamp = 0; // Will be updated from server or config this.latestContentTimestamp = new Date('2025-07-09').getTime(); // Will be updated from server or config
this.isInitialized = false; this.isInitialized = false;
// Default latest content data - could be fetched from server // Default latest content data - could be fetched from server
@@ -81,6 +81,9 @@ export class HelpManager {
if (window.modalManager) { if (window.modalManager) {
window.modalManager.toggleModal('helpModal'); window.modalManager.toggleModal('helpModal');
// Add visual indicator to Documentation tab if there's new content
this.updateDocumentationTabIndicator();
// Update the last viewed timestamp // Update the last viewed timestamp
this.markContentAsViewed(); this.markContentAsViewed();
@@ -88,6 +91,16 @@ export class HelpManager {
this.hideHelpBadge(); this.hideHelpBadge();
} }
} }
/**
* Add visual indicator to Documentation tab for new content
*/
updateDocumentationTabIndicator() {
const docTab = document.querySelector('.tab-btn[data-tab="documentation"]');
if (docTab && this.hasNewContent()) {
docTab.classList.add('has-new-content');
}
}
/** /**
* Mark content as viewed by saving current timestamp * Mark content as viewed by saving current timestamp
@@ -105,7 +118,7 @@ export class HelpManager {
// For now, we'll just use the hardcoded data from constructor // For now, we'll just use the hardcoded data from constructor
// Update the timestamp with the latest data // Update the timestamp with the latest data
this.latestContentTimestamp = this.latestVideoData.timestamp; this.latestContentTimestamp = Math.max(this.latestContentTimestamp, this.latestVideoData.timestamp);
// Check again if we need to show the badge with this new data // Check again if we need to show the badge with this new data
this.updateHelpBadge(); this.updateHelpBadge();

View File

@@ -2,6 +2,7 @@ import { showToast } from '../utils/uiHelpers.js';
import { state, getCurrentPageState } from '../state/index.js'; import { state, getCurrentPageState } from '../state/index.js';
import { modalManager } from './ModalManager.js'; import { modalManager } from './ModalManager.js';
import { getStorageItem } from '../utils/storageHelpers.js'; import { getStorageItem } from '../utils/storageHelpers.js';
import { updateFolderTags } from '../api/baseModelApi.js';
class MoveManager { class MoveManager {
constructor() { constructor() {
@@ -72,32 +73,46 @@ class MoveManager {
this.newFolderInput.value = ''; this.newFolderInput.value = '';
try { try {
const response = await fetch('/api/lora-roots'); // Fetch LoRA roots
if (!response.ok) { const rootsResponse = await fetch('/api/lora-roots');
if (!rootsResponse.ok) {
throw new Error('Failed to fetch LoRA roots'); throw new Error('Failed to fetch LoRA roots');
} }
const data = await response.json(); const rootsData = await rootsResponse.json();
if (!data.roots || data.roots.length === 0) { if (!rootsData.roots || rootsData.roots.length === 0) {
throw new Error('No LoRA roots found'); throw new Error('No LoRA roots found');
} }
// 填充LoRA根目录选择器 // 填充LoRA根目录选择器
this.loraRootSelect.innerHTML = data.roots.map(root => this.loraRootSelect.innerHTML = rootsData.roots.map(root =>
`<option value="${root}">${root}</option>` `<option value="${root}">${root}</option>`
).join(''); ).join('');
// Set default lora root if available // Set default lora root if available
const defaultRoot = getStorageItem('settings', {}).default_loras_root; const defaultRoot = getStorageItem('settings', {}).default_loras_root;
if (defaultRoot && data.roots.includes(defaultRoot)) { if (defaultRoot && rootsData.roots.includes(defaultRoot)) {
this.loraRootSelect.value = defaultRoot; this.loraRootSelect.value = defaultRoot;
} }
// Fetch folders dynamically
const foldersResponse = await fetch('/api/folders');
if (!foldersResponse.ok) {
throw new Error('Failed to fetch folders');
}
const foldersData = await foldersResponse.json();
// Update folder browser with dynamic content
this.folderBrowser.innerHTML = foldersData.folders.map(folder =>
`<div class="folder-item" data-folder="${folder}">${folder}</div>`
).join('');
this.updatePathPreview(); this.updatePathPreview();
modalManager.showModal('moveModal'); modalManager.showModal('moveModal');
} catch (error) { } catch (error) {
console.error('Error fetching LoRA roots:', error); console.error('Error fetching LoRA roots or folders:', error);
showToast(error.message, 'error'); showToast(error.message, 'error');
} }
} }
@@ -173,6 +188,17 @@ class MoveManager {
} }
} }
// Refresh folder tags after successful move
try {
const foldersResponse = await fetch('/api/folders');
if (foldersResponse.ok) {
const foldersData = await foldersResponse.json();
updateFolderTags(foldersData.folders);
}
} catch (error) {
console.error('Error refreshing folder tags:', error);
}
modalManager.closeModal('moveModal'); modalManager.closeModal('moveModal');
// If we were in bulk mode, exit it after successful move // If we were in bulk mode, exit it after successful move

View File

@@ -42,6 +42,11 @@ export class SettingsManager {
state.global.settings.cardInfoDisplay = 'always'; state.global.settings.cardInfoDisplay = 'always';
} }
// Set default for defaultCheckpointRoot if undefined
if (state.global.settings.default_checkpoint_root === undefined) {
state.global.settings.default_checkpoint_root = '';
}
// Convert old boolean compactMode to new displayDensity string // Convert old boolean compactMode to new displayDensity string
if (typeof state.global.settings.displayDensity === 'undefined') { if (typeof state.global.settings.displayDensity === 'undefined') {
if (state.global.settings.compactMode === true) { if (state.global.settings.compactMode === true) {
@@ -123,6 +128,9 @@ export class SettingsManager {
// Load default lora root // Load default lora root
await this.loadLoraRoots(); await this.loadLoraRoots();
// Load default checkpoint root
await this.loadCheckpointRoots();
// Backend settings are loaded from the template directly // Backend settings are loaded from the template directly
} }
@@ -165,6 +173,45 @@ export class SettingsManager {
} }
} }
async loadCheckpointRoots() {
try {
const defaultCheckpointRootSelect = document.getElementById('defaultCheckpointRoot');
if (!defaultCheckpointRootSelect) return;
// Fetch checkpoint roots
const response = await fetch('/api/checkpoints/roots');
if (!response.ok) {
throw new Error('Failed to fetch checkpoint roots');
}
const data = await response.json();
if (!data.roots || data.roots.length === 0) {
throw new Error('No checkpoint roots found');
}
// Clear existing options except the first one (No Default)
const noDefaultOption = defaultCheckpointRootSelect.querySelector('option[value=""]');
defaultCheckpointRootSelect.innerHTML = '';
defaultCheckpointRootSelect.appendChild(noDefaultOption);
// Add options for each root
data.roots.forEach(root => {
const option = document.createElement('option');
option.value = root;
option.textContent = root;
defaultCheckpointRootSelect.appendChild(option);
});
// Set selected value from settings
const defaultRoot = state.global.settings.default_checkpoint_root || '';
defaultCheckpointRootSelect.value = defaultRoot;
} catch (error) {
console.error('Error loading checkpoint roots:', error);
showToast('Failed to load checkpoint roots: ' + error.message, 'error');
}
}
toggleSettings() { toggleSettings() {
if (this.isOpen) { if (this.isOpen) {
modalManager.closeModal('settingsModal'); modalManager.closeModal('settingsModal');
@@ -251,6 +298,8 @@ export class SettingsManager {
// Update frontend state // Update frontend state
if (settingKey === 'default_lora_root') { if (settingKey === 'default_lora_root') {
state.global.settings.default_loras_root = value; state.global.settings.default_loras_root = value;
} else if (settingKey === 'default_checkpoint_root') {
state.global.settings.default_checkpoint_root = value;
} else if (settingKey === 'display_density') { } else if (settingKey === 'display_density') {
state.global.settings.displayDensity = value; state.global.settings.displayDensity = value;
@@ -268,7 +317,7 @@ export class SettingsManager {
try { try {
// For backend settings, make API call // For backend settings, make API call
if (settingKey === 'default_lora_root') { if (settingKey === 'default_lora_root' || settingKey === 'default_checkpoint_root') {
const payload = {}; const payload = {};
payload[settingKey] = value; payload[settingKey] = value;
@@ -414,6 +463,7 @@ export class SettingsManager {
const blurMatureContent = document.getElementById('blurMatureContent').checked; const blurMatureContent = document.getElementById('blurMatureContent').checked;
const showOnlySFW = document.getElementById('showOnlySFW').checked; const showOnlySFW = document.getElementById('showOnlySFW').checked;
const defaultLoraRoot = document.getElementById('defaultLoraRoot').value; const defaultLoraRoot = document.getElementById('defaultLoraRoot').value;
const defaultCheckpointRoot = document.getElementById('defaultCheckpointRoot').value;
const autoplayOnHover = document.getElementById('autoplayOnHover').checked; const autoplayOnHover = document.getElementById('autoplayOnHover').checked;
const optimizeExampleImages = document.getElementById('optimizeExampleImages').checked; const optimizeExampleImages = document.getElementById('optimizeExampleImages').checked;
@@ -424,6 +474,7 @@ export class SettingsManager {
state.global.settings.blurMatureContent = blurMatureContent; state.global.settings.blurMatureContent = blurMatureContent;
state.global.settings.show_only_sfw = showOnlySFW; state.global.settings.show_only_sfw = showOnlySFW;
state.global.settings.default_loras_root = defaultLoraRoot; state.global.settings.default_loras_root = defaultLoraRoot;
state.global.settings.default_checkpoint_root = defaultCheckpointRoot;
state.global.settings.autoplayOnHover = autoplayOnHover; state.global.settings.autoplayOnHover = autoplayOnHover;
state.global.settings.optimizeExampleImages = optimizeExampleImages; state.global.settings.optimizeExampleImages = optimizeExampleImages;
@@ -440,7 +491,8 @@ export class SettingsManager {
body: JSON.stringify({ body: JSON.stringify({
civitai_api_key: apiKey, civitai_api_key: apiKey,
show_only_sfw: showOnlySFW, show_only_sfw: showOnlySFW,
optimize_example_images: optimizeExampleImages optimize_example_images: optimizeExampleImages,
default_checkpoint_root: defaultCheckpointRoot
}) })
}); });

View File

@@ -128,9 +128,12 @@ export class DownloadManager {
targetPath += '/' + newFolder; targetPath += '/' + newFolder;
} }
// Generate a unique ID for this batch download
const batchDownloadId = Date.now().toString();
// Set up WebSocket for progress updates // Set up WebSocket for progress updates
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${batchDownloadId}`);
// Show enhanced loading with progress details for multiple items // Show enhanced loading with progress details for multiple items
const updateProgress = this.importManager.loadingManager.showDownloadProgress( const updateProgress = this.importManager.loadingManager.showDownloadProgress(
@@ -145,7 +148,15 @@ export class DownloadManager {
// Set up progress tracking for current download // Set up progress tracking for current download
ws.onmessage = (event) => { ws.onmessage = (event) => {
const data = JSON.parse(event.data); const data = JSON.parse(event.data);
if (data.status === 'progress') {
// Handle download ID confirmation
if (data.type === 'download_id') {
console.log(`Connected to batch download progress with ID: ${data.download_id}`);
return;
}
// Process progress updates for our current active download
if (data.status === 'progress' && data.download_id && data.download_id.startsWith(batchDownloadId)) {
// Update current LoRA progress // Update current LoRA progress
currentLoraProgress = data.progress; currentLoraProgress = data.progress;
@@ -188,16 +199,16 @@ export class DownloadManager {
updateProgress(0, completedDownloads, lora.name); updateProgress(0, completedDownloads, lora.name);
try { try {
// Download the LoRA // Download the LoRA with download ID
const response = await fetch('/api/download-lora', { const response = await fetch('/api/download-model', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ body: JSON.stringify({
download_url: lora.downloadUrl, model_id: lora.modelId,
model_version_id: lora.modelVersionId, model_version_id: lora.id,
model_hash: lora.hash, model_root: loraRoot,
lora_root: loraRoot, relative_path: targetPath.replace(loraRoot + '/', ''),
relative_path: targetPath.replace(loraRoot + '/', '') download_id: batchDownloadId
}) })
}); });

View File

@@ -24,6 +24,7 @@ export const BASE_MODELS = {
// Other models // Other models
FLUX_1_D: "Flux.1 D", FLUX_1_D: "Flux.1 D",
FLUX_1_S: "Flux.1 S", FLUX_1_S: "Flux.1 S",
FLUX_1_KONTEXT: "Flux.1 Kontext",
AURAFLOW: "AuraFlow", AURAFLOW: "AuraFlow",
PIXART_A: "PixArt a", PIXART_A: "PixArt a",
PIXART_E: "PixArt E", PIXART_E: "PixArt E",
@@ -78,6 +79,7 @@ export const BASE_MODEL_CLASSES = {
// Other models // Other models
[BASE_MODELS.FLUX_1_D]: "flux-d", [BASE_MODELS.FLUX_1_D]: "flux-d",
[BASE_MODELS.FLUX_1_S]: "flux-s", [BASE_MODELS.FLUX_1_S]: "flux-s",
[BASE_MODELS.FLUX_1_KONTEXT]: "flux-kontext",
[BASE_MODELS.AURAFLOW]: "auraflow", [BASE_MODELS.AURAFLOW]: "auraflow",
[BASE_MODELS.PIXART_A]: "pixart-a", [BASE_MODELS.PIXART_A]: "pixart-a",
[BASE_MODELS.PIXART_E]: "pixart-e", [BASE_MODELS.PIXART_E]: "pixart-e",
@@ -106,19 +108,22 @@ export const NSFW_LEVELS = {
// Node type constants // Node type constants
export const NODE_TYPES = { export const NODE_TYPES = {
LORA_LOADER: 1, LORA_LOADER: 1,
LORA_STACKER: 2 LORA_STACKER: 2,
WAN_VIDEO_LORA_SELECT: 3
}; };
// Node type names to IDs mapping // Node type names to IDs mapping
export const NODE_TYPE_NAMES = { export const NODE_TYPE_NAMES = {
"Lora Loader (LoraManager)": NODE_TYPES.LORA_LOADER, "Lora Loader (LoraManager)": NODE_TYPES.LORA_LOADER,
"Lora Stacker (LoraManager)": NODE_TYPES.LORA_STACKER "Lora Stacker (LoraManager)": NODE_TYPES.LORA_STACKER,
"WanVideo Lora Select (LoraManager)": NODE_TYPES.WAN_VIDEO_LORA_SELECT
}; };
// Node type icons // Node type icons
export const NODE_TYPE_ICONS = { export const NODE_TYPE_ICONS = {
[NODE_TYPES.LORA_LOADER]: "fas fa-l", [NODE_TYPES.LORA_LOADER]: "fas fa-l",
[NODE_TYPES.LORA_STACKER]: "fas fa-s" [NODE_TYPES.LORA_STACKER]: "fas fa-s",
[NODE_TYPES.WAN_VIDEO_LORA_SELECT]: "fas fa-w"
}; };
// Default ComfyUI node color when bgcolor is null // Default ComfyUI node color when bgcolor is null

View File

@@ -7,7 +7,7 @@ export const apiRoutes = {
delete: (id) => `/api/loras/${id}`, delete: (id) => `/api/loras/${id}`,
update: (id) => `/api/loras/${id}`, update: (id) => `/api/loras/${id}`,
civitai: (id) => `/api/loras/${id}/civitai`, civitai: (id) => `/api/loras/${id}/civitai`,
download: '/api/download-lora', download: '/api/download-model',
move: '/api/move-lora', move: '/api/move-lora',
scan: '/api/scan-loras' scan: '/api/scan-loras'
}, },

View File

@@ -337,7 +337,7 @@ export async function sendLoraToWorkflow(loraSyntax, replaceMode = false, syntax
// Success case - check node count // Success case - check node count
if (registryData.data.node_count === 0) { if (registryData.data.node_count === 0) {
// No nodes found - show warning // No nodes found - show warning
showToast('No Lora Loader or Lora Stacker nodes found in workflow', 'warning'); showToast('No supported target nodes found in workflow', 'warning');
return false; return false;
} else if (registryData.data.node_count > 1) { } else if (registryData.data.node_count > 1) {
// Multiple nodes - show selector // Multiple nodes - show selector

14
static/vendor/chart.js/chart.umd.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@@ -48,13 +48,7 @@
<div class="input-group"> <div class="input-group">
<label>Target Folder:</label> <label>Target Folder:</label>
<div class="folder-browser" id="folderBrowser"> <div class="folder-browser" id="folderBrowser">
{% for folder in folders %} <!-- Folders will be loaded dynamically -->
{% if folder %}
<div class="folder-item" data-folder="{{ folder }}">
{{ folder }}
</div>
{% endif %}
{% endfor %}
</div> </div>
</div> </div>
<div class="input-group"> <div class="input-group">
@@ -92,13 +86,7 @@
<div class="input-group"> <div class="input-group">
<label>Target Folder:</label> <label>Target Folder:</label>
<div class="folder-browser" id="moveFolderBrowser"> <div class="folder-browser" id="moveFolderBrowser">
{% for folder in folders %} <!-- Folders will be loaded dynamically -->
{% if folder %}
<div class="folder-item" data-folder="{{ folder }}">
{{ folder }}
</div>
{% endif %}
{% endfor %}
</div> </div>
</div> </div>
<div class="input-group"> <div class="input-group">

View File

@@ -197,6 +197,23 @@
Set the default LoRA root directory for downloads, imports and moves Set the default LoRA root directory for downloads, imports and moves
</div> </div>
</div> </div>
<div class="setting-item">
<div class="setting-row">
<div class="setting-info">
<label for="defaultCheckpointRoot">Default Checkpoint Root</label>
</div>
<div class="setting-control select-control">
<select id="defaultCheckpointRoot" onchange="settingsManager.saveSelectSetting('defaultCheckpointRoot', 'default_checkpoint_root')">
<option value="">No Default</option>
<!-- Options will be loaded dynamically -->
</select>
</div>
</div>
<div class="input-help">
Set the default checkpoint root directory for downloads, imports and moves
</div>
</div>
</div> </div>
<!-- Add Layout Settings Section --> <!-- Add Layout Settings Section -->
@@ -558,7 +575,7 @@
<div class="docs-section"> <div class="docs-section">
<h4><i class="fas fa-book-open"></i> Recipes</h4> <h4><i class="fas fa-book-open"></i> Recipes</h4>
<ul class="docs-links"> <ul class="docs-links">
<li><a href="https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/%F0%9F%93%96-Recipes-Feature-Tutorial-%E2%80%93-ComfyUI-LoRA-Manager" target="_blank">Recipes Tutorial</a></li> <li><a href="https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/Recipes-Feature-Tutorial-%E2%80%93-ComfyUI-LoRA-Manager" target="_blank">Recipes Tutorial</a></li>
</ul> </ul>
</div> </div>
@@ -568,6 +585,21 @@
<li><a href="https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/Configuration" target="_blank">Configuration Options (WIP)</a></li> <li><a href="https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/Configuration" target="_blank">Configuration Options (WIP)</a></li>
</ul> </ul>
</div> </div>
<div class="docs-section">
<h4>
<i class="fas fa-puzzle-piece"></i> Extensions
<span class="new-content-badge">NEW</span>
</h4>
<ul class="docs-links">
<li>
<a href="https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/LoRA-Manager-Civitai-Extension-(Chrome-Extension)" target="_blank">
LM Civitai Extension
<span class="new-content-badge inline">NEW</span>
</a>
</li>
</ul>
</div>
</div> </div>
</div> </div>
</div> </div>

View File

@@ -11,7 +11,7 @@
{% block head_scripts %} {% block head_scripts %}
<!-- Add Chart.js for statistics page --> <!-- Add Chart.js for statistics page -->
<script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.0/dist/chart.umd.js"></script> <script src="/loras_static/vendor/chart.js/chart.umd.js"></script>
{% endblock %} {% endblock %}
{% block init_title %}Initializing Statistics{% endblock %} {% block init_title %}Initializing Statistics{% endblock %}

View File

@@ -76,7 +76,9 @@ app.registerExtension({
// Standard mode - update a specific node // Standard mode - update a specific node
const node = app.graph.getNodeById(+id); const node = app.graph.getNodeById(+id);
if (!node || (node.comfyClass !== "Lora Loader (LoraManager)" && node.comfyClass !== "Lora Stacker (LoraManager)")) { if (!node || (node.comfyClass !== "Lora Loader (LoraManager)" &&
node.comfyClass !== "Lora Stacker (LoraManager)" &&
node.comfyClass !== "WanVideo Lora Select (LoraManager)")) {
console.warn("Node not found or not a LoraLoader:", id); console.warn("Node not found or not a LoraLoader:", id);
return; return;
} }
@@ -87,7 +89,7 @@ app.registerExtension({
// Helper method to update a single node's lora code // Helper method to update a single node's lora code
updateNodeLoraCode(node, loraCode, mode) { updateNodeLoraCode(node, loraCode, mode) {
// Update the input widget with new lora code // Update the input widget with new lora code
const inputWidget = node.widgets[0]; const inputWidget = node.inputWidget;
if (!inputWidget) return; if (!inputWidget) return;
// Get the current lora code // Get the current lora code
@@ -182,6 +184,7 @@ app.registerExtension({
// Update input widget callback // Update input widget callback
const inputWidget = this.widgets[0]; const inputWidget = this.widgets[0];
this.inputWidget = inputWidget;
inputWidget.callback = (value) => { inputWidget.callback = (value) => {
if (isUpdating) return; if (isUpdating) return;
isUpdating = true; isUpdating = true;

View File

@@ -105,6 +105,7 @@ app.registerExtension({
// Update input widget callback // Update input widget callback
const inputWidget = this.widgets[0]; const inputWidget = this.widgets[0];
this.inputWidget = inputWidget;
inputWidget.callback = (value) => { inputWidget.callback = (value) => {
if (isUpdating) return; if (isUpdating) return;
isUpdating = true; isUpdating = true;

View File

@@ -52,7 +52,9 @@ app.registerExtension({
// Find all Lora nodes // Find all Lora nodes
const loraNodes = []; const loraNodes = [];
for (const node of workflow.nodes.values()) { for (const node of workflow.nodes.values()) {
if (node.type === "Lora Loader (LoraManager)" || node.type === "Lora Stacker (LoraManager)") { if (node.type === "Lora Loader (LoraManager)" ||
node.type === "Lora Stacker (LoraManager)" ||
node.type === "WanVideo Lora Select (LoraManager)") {
loraNodes.push({ loraNodes.push({
node_id: node.id, node_id: node.id,
bgcolor: node.bgcolor || null, bgcolor: node.bgcolor || null,

View File

@@ -0,0 +1,131 @@
import { app } from "../../scripts/app.js";
import {
LORA_PATTERN,
getActiveLorasFromNode,
collectActiveLorasFromChain,
updateConnectedTriggerWords,
chainCallback
} from "./utils.js";
import { addLorasWidget } from "./loras_widget.js";
/**
 * Merge lora entries parsed from the node's text widget with previously
 * saved widget data.
 *
 * Each lora tag matched by LORA_PATTERN in `lorasText` yields one entry,
 * in text order. When an entry with the same name already exists in
 * `lorasArr`, its saved strength / active / clipStrength values take
 * precedence over the freshly parsed ones.
 *
 * @param {string} lorasText - Raw widget text containing lora tags.
 * @param {Array<{name: string, strength: number, active: boolean, clipStrength: number}>} lorasArr - Previously saved lora entries.
 * @returns {Array<{name: string, strength: number, active: boolean, clipStrength: number}>} Merged entries.
 */
function mergeLoras(lorasText, lorasArr) {
    // LORA_PATTERN is a shared global-flag regex; clear any stale scan
    // position left behind by a previous caller before iterating.
    LORA_PATTERN.lastIndex = 0;

    const merged = [];
    for (let match = LORA_PATTERN.exec(lorasText); match !== null; match = LORA_PATTERN.exec(lorasText)) {
        const name = match[1];
        const parsedStrength = Number(match[2]);
        // A missing clip strength falls back to the model strength.
        const parsedClip = match[3] ? Number(match[3]) : parsedStrength;

        // Saved widget state wins over freshly parsed values.
        const saved = lorasArr.find((entry) => entry.name === name);
        merged.push({
            name,
            strength: saved ? saved.strength : parsedStrength,
            active: saved ? saved.active : true,
            clipStrength: saved ? saved.clipStrength : parsedClip,
        });
    }
    return merged;
}
// Extension for the "WanVideo Lora Select (LoraManager)" node.
// On node creation it wires two mutually-synchronized widgets:
//   - widgets[1]: the raw text widget holding `<lora:...>` syntax
//     (index 0 is low_mem_load — see the widgets_values comment below)
//   - a custom "loras" structured widget added via addLorasWidget
// Each widget's callback rewrites the other, so a shared `isUpdating` flag
// breaks the resulting callback cycle in both directions.
app.registerExtension({
    name: "LoraManager.WanVideoLoraSelect",
    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeType.comfyClass === "WanVideo Lora Select (LoraManager)") {
            chainCallback(nodeType.prototype, "onNodeCreated", async function() {
                // Enable widget serialization
                this.serialize_widgets = true;

                // Add optional inputs
                this.addInput("prev_lora", 'WANVIDLORA', {
                    "shape": 7 // 7 is the shape of the optional input
                });
                this.addInput("blocks", 'SELECTEDBLOCKS', {
                    "shape": 7 // 7 is the shape of the optional input
                });

                // Restore saved value if exists
                let existingLoras = [];
                if (this.widgets_values && this.widgets_values.length > 0) {
                    // 0 for low_mem_load, 1 for text widget, 2 for loras widget
                    const savedValue = this.widgets_values[2];
                    existingLoras = savedValue || [];
                }
                // Merge the loras data: text is the source of truth for which
                // loras exist; saved entries contribute strength/active state.
                const mergedLoras = mergeLoras(this.widgets[1].value, existingLoras);

                // Add flag to prevent callback loops
                let isUpdating = false;

                // Structured loras widget; its callback syncs changes back
                // into the text widget.
                const result = addLorasWidget(this, "loras", {
                    defaultVal: mergedLoras // Pass object directly
                }, (value) => {
                    // Prevent recursive calls
                    if (isUpdating) return;
                    isUpdating = true;

                    try {
                        // Remove loras that are not in the value array
                        const inputWidget = this.widgets[1];
                        const currentLoras = value.map(l => l.name);

                        // Use the constant pattern here as well: drop any
                        // <lora:...> tag whose name is no longer in the widget.
                        let newText = inputWidget.value.replace(LORA_PATTERN, (match, name, strength) => {
                            return currentLoras.includes(name) ? match : '';
                        });

                        // Clean up multiple spaces and trim
                        newText = newText.replace(/\s+/g, ' ').trim();
                        inputWidget.value = newText;

                        // Update this node's direct trigger toggles with its own active loras
                        const activeLoraNames = new Set();
                        value.forEach(lora => {
                            if (lora.active) {
                                activeLoraNames.add(lora.name);
                            }
                        });
                        updateConnectedTriggerWords(this, activeLoraNames);
                    } finally {
                        // Always release the re-entrancy guard, even if the
                        // sync above throws.
                        isUpdating = false;
                    }
                });

                // Keep handles to both widgets for other code paths
                // (e.g. external lora-code updates) to reach them.
                this.lorasWidget = result.widget;

                // Update input widget callback: typing in the text widget
                // re-parses and pushes into the structured widget.
                const inputWidget = this.widgets[1];
                this.inputWidget = inputWidget;
                inputWidget.callback = (value) => {
                    if (isUpdating) return;
                    isUpdating = true;

                    try {
                        const currentLoras = this.lorasWidget.value || [];
                        const mergedLoras = mergeLoras(value, currentLoras);
                        this.lorasWidget.value = mergedLoras;

                        // Update this node's direct trigger toggles with its own active loras
                        const activeLoraNames = getActiveLorasFromNode(this);
                        updateConnectedTriggerWords(this, activeLoraNames);
                    } finally {
                        isUpdating = false;
                    }
                };
            });
        }
    },
});

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 MiB