mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(randomizer): add LoRA locking and roll modes
- Implement LoRA locking to prevent specific LoRAs from being changed during randomization - Add visual styling for locked state with amber accents and distinct backgrounds - Introduce `roll_mode` configuration with 'backend' (execute current selection while generating new) and 'frontend' (execute newly generated selection) behaviors - Move LoraPoolNode to 'Lora Manager/randomizer' category and remove standalone class mappings - Standardize RETURN_NAMES in LoraRandomizerNode for consistency
This commit is contained in:
163
AGENTS.md
Normal file
163
AGENTS.md
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
# AGENTS.md
|
||||||
|
|
||||||
|
This file provides guidance for agentic coding assistants working in this repository.
|
||||||
|
|
||||||
|
## Development Commands
|
||||||
|
|
||||||
|
### Backend Development
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install dependencies
|
||||||
|
pip install -r requirements.txt
|
||||||
|
pip install -r requirements-dev.txt
|
||||||
|
|
||||||
|
# Run standalone server (port 8188 by default)
|
||||||
|
python standalone.py --port 8188
|
||||||
|
|
||||||
|
# Run all backend tests
|
||||||
|
pytest
|
||||||
|
|
||||||
|
# Run specific test file
|
||||||
|
pytest tests/test_recipes.py
|
||||||
|
|
||||||
|
# Run specific test function
|
||||||
|
pytest tests/test_recipes.py::test_function_name
|
||||||
|
|
||||||
|
# Run backend tests with coverage
|
||||||
|
COVERAGE_FILE=coverage/backend/.coverage pytest \
|
||||||
|
--cov=py \
|
||||||
|
--cov=standalone \
|
||||||
|
--cov-report=term-missing \
|
||||||
|
--cov-report=html:coverage/backend/html \
|
||||||
|
--cov-report=xml:coverage/backend/coverage.xml \
|
||||||
|
--cov-report=json:coverage/backend/coverage.json
|
||||||
|
```
|
||||||
|
|
||||||
|
### Frontend Development
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install frontend dependencies
|
||||||
|
npm install
|
||||||
|
|
||||||
|
# Run frontend tests
|
||||||
|
npm test
|
||||||
|
|
||||||
|
# Run frontend tests in watch mode
|
||||||
|
npm run test:watch
|
||||||
|
|
||||||
|
# Run frontend tests with coverage
|
||||||
|
npm run test:coverage
|
||||||
|
```
|
||||||
|
|
||||||
|
## Python Code Style
|
||||||
|
|
||||||
|
### Imports
|
||||||
|
|
||||||
|
- Use `from __future__ import annotations` for forward references in type hints
|
||||||
|
- Group imports: standard library, third-party, local (separated by blank lines)
|
||||||
|
- Use absolute imports within `py/` package: `from ..services import X`
|
||||||
|
- Mock ComfyUI dependencies in tests using `tests/conftest.py` patterns
|
||||||
|
|
||||||
|
### Formatting & Types
|
||||||
|
|
||||||
|
- PEP 8 with 4-space indentation
|
||||||
|
- Type hints required for function signatures and class attributes
|
||||||
|
- Use `TYPE_CHECKING` guard for type-checking-only imports
|
||||||
|
- Prefer dataclasses for simple data containers
|
||||||
|
- Use `Optional[T]` for nullable types, `Union[T, None]` only when necessary
|
||||||
|
|
||||||
|
### Naming Conventions
|
||||||
|
|
||||||
|
- Files: `snake_case.py` (e.g., `model_scanner.py`, `lora_service.py`)
|
||||||
|
- Classes: `PascalCase` (e.g., `ModelScanner`, `LoraService`)
|
||||||
|
- Functions/variables: `snake_case` (e.g., `get_instance`, `model_type`)
|
||||||
|
- Constants: `UPPER_SNAKE_CASE` (e.g., `VALID_LORA_TYPES`)
|
||||||
|
- Private members: `_single_underscore` (protected), `__double_underscore` (name-mangled)
|
||||||
|
|
||||||
|
### Error Handling
|
||||||
|
|
||||||
|
- Use `logging.getLogger(__name__)` for module-level loggers
|
||||||
|
- Define custom exceptions in `py/services/errors.py`
|
||||||
|
- Use `asyncio.Lock` for thread-safe singleton patterns
|
||||||
|
- Raise specific exceptions with descriptive messages
|
||||||
|
- Log errors at appropriate levels (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||||
|
|
||||||
|
### Async Patterns
|
||||||
|
|
||||||
|
- Use `async def` for I/O-bound operations
|
||||||
|
- Mark async tests with `@pytest.mark.asyncio`
|
||||||
|
- Use `async with` for context managers
|
||||||
|
- Singleton pattern with class-level locks: see `ModelScanner.get_instance()`
|
||||||
|
- Use `aiohttp.web.Response` for HTTP responses
|
||||||
|
|
||||||
|
### Testing Patterns
|
||||||
|
|
||||||
|
- Use `pytest` with `--import-mode=importlib`
|
||||||
|
- Fixtures in `tests/conftest.py` handle ComfyUI mocking
|
||||||
|
- Use `@pytest.mark.no_settings_dir_isolation` for tests needing real paths
|
||||||
|
- Test files: `tests/test_*.py`
|
||||||
|
- Use `tmp_path_factory` for temporary directory isolation
|
||||||
|
|
||||||
|
## JavaScript Code Style
|
||||||
|
|
||||||
|
### Imports & Modules
|
||||||
|
|
||||||
|
- ES modules with `import`/`export`
|
||||||
|
- Use `import { app } from "../../scripts/app.js"` for ComfyUI integration
|
||||||
|
- Export named functions/classes: `export function foo() {}`
|
||||||
|
- Widget files use `*_widget.js` suffix
|
||||||
|
|
||||||
|
### Naming & Formatting
|
||||||
|
|
||||||
|
- camelCase for functions, variables, object properties
|
||||||
|
- PascalCase for classes/constructors
|
||||||
|
- Constants: `UPPER_SNAKE_CASE` (e.g., `CONVERTED_TYPE`)
|
||||||
|
- Files: `snake_case.js` or `kebab-case.js`
|
||||||
|
- 2-space indentation preferred (follow existing file conventions)
|
||||||
|
|
||||||
|
### Widget Development
|
||||||
|
|
||||||
|
- Use `app.registerExtension()` to register ComfyUI extensions
|
||||||
|
- Use `node.addDOMWidget(name, type, element, options)` for custom widgets
|
||||||
|
- Event handlers attached via `addEventListener` or widget callbacks
|
||||||
|
- See `web/comfyui/utils.js` for shared utilities
|
||||||
|
|
||||||
|
## Architecture Patterns
|
||||||
|
|
||||||
|
### Service Layer
|
||||||
|
|
||||||
|
- Use `ServiceRegistry` singleton for dependency injection
|
||||||
|
- Services follow singleton pattern via `get_instance()` class method
|
||||||
|
- Separate scanners (discovery) from services (business logic)
|
||||||
|
- Handlers in `py/routes/handlers/` implement route logic
|
||||||
|
|
||||||
|
### Model Types
|
||||||
|
|
||||||
|
- BaseModelService is abstract base for LoRA, Checkpoint, Embedding services
|
||||||
|
- ModelScanner provides file discovery and hash-based deduplication
|
||||||
|
- Persistent cache in SQLite via `PersistentModelCache`
|
||||||
|
- Metadata sync from CivitAI/CivArchive via `MetadataSyncService`
|
||||||
|
|
||||||
|
### Routes & Handlers
|
||||||
|
|
||||||
|
- Route registrars organize endpoints by domain: `ModelRouteRegistrar`, etc.
|
||||||
|
- Handlers are pure functions taking dependencies as parameters
|
||||||
|
- Use `WebSocketManager` for real-time progress updates
|
||||||
|
- Return `aiohttp.web.json_response` or `web.Response`
|
||||||
|
|
||||||
|
### Recipe System
|
||||||
|
|
||||||
|
- Base metadata in `py/recipes/base.py`
|
||||||
|
- Enrichment adds model metadata: `RecipeEnrichmentService`
|
||||||
|
- Parsers for different formats in `py/recipes/parsers/`
|
||||||
|
|
||||||
|
## Important Notes
|
||||||
|
|
||||||
|
- Always use English for comments (per copilot-instructions.md)
|
||||||
|
- Dual mode: ComfyUI plugin (uses folder_paths) vs standalone (reads settings.json)
|
||||||
|
- Detection: `os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"`
|
||||||
|
- Settings auto-saved in user directory or portable mode
|
||||||
|
- WebSocket broadcasts for real-time updates (downloads, scans)
|
||||||
|
- Symlink handling requires normalized paths
|
||||||
|
- API endpoints follow `/loras/*`, `/checkpoints/*`, `/embeddings/*` patterns
|
||||||
|
- Run `python scripts/sync_translation_keys.py` after UI string updates
|
||||||
@@ -20,7 +20,7 @@ class LoraPoolNode:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
NAME = "Lora Pool (LoraManager)"
|
NAME = "Lora Pool (LoraManager)"
|
||||||
CATEGORY = "Lora Manager/pools"
|
CATEGORY = "Lora Manager/randomizer"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
@@ -85,10 +85,3 @@ class LoraPoolNode:
|
|||||||
},
|
},
|
||||||
"preview": {"matchCount": 0, "lastUpdated": 0},
|
"preview": {"matchCount": 0, "lastUpdated": 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# Node class mappings for ComfyUI
|
|
||||||
NODE_CLASS_MAPPINGS = {"LoraPoolNode": LoraPoolNode}
|
|
||||||
|
|
||||||
# Display name mappings
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {"LoraPoolNode": "LoRA Pool (Filter)"}
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class LoraRandomizerNode:
|
|||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("LORA_STACK",)
|
RETURN_TYPES = ("LORA_STACK",)
|
||||||
RETURN_NAMES = ("lora_stack",)
|
RETURN_NAMES = ("LORA_STACK",)
|
||||||
|
|
||||||
FUNCTION = "randomize"
|
FUNCTION = "randomize"
|
||||||
OUTPUT_NODE = False
|
OUTPUT_NODE = False
|
||||||
@@ -53,9 +53,6 @@ class LoraRandomizerNode:
|
|||||||
"""
|
"""
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
|
||||||
# Get lora scanner to access available loras
|
|
||||||
scanner = await ServiceRegistry.get_lora_scanner()
|
|
||||||
|
|
||||||
# Parse randomizer settings
|
# Parse randomizer settings
|
||||||
count_mode = randomizer_config.get("count_mode", "range")
|
count_mode = randomizer_config.get("count_mode", "range")
|
||||||
count_fixed = randomizer_config.get("count_fixed", 5)
|
count_fixed = randomizer_config.get("count_fixed", 5)
|
||||||
@@ -68,6 +65,93 @@ class LoraRandomizerNode:
|
|||||||
clip_strength_max = randomizer_config.get("clip_strength_max", 1.0)
|
clip_strength_max = randomizer_config.get("clip_strength_max", 1.0)
|
||||||
roll_mode = randomizer_config.get("roll_mode", "frontend")
|
roll_mode = randomizer_config.get("roll_mode", "frontend")
|
||||||
|
|
||||||
|
# Get lora scanner to access available loras
|
||||||
|
scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
|
|
||||||
|
# Backend roll mode: execute with input loras, return new random to UI
|
||||||
|
if roll_mode == "backend":
|
||||||
|
execution_stack = self._build_execution_stack_from_input(loras)
|
||||||
|
ui_loras = await self._generate_random_loras_for_ui(
|
||||||
|
scanner, randomizer_config, loras, pool_config
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[LoraRandomizerNode] Backend roll: executing with input, returning new random to UI"
|
||||||
|
)
|
||||||
|
return {"result": (execution_stack,), "ui": {"loras": ui_loras}}
|
||||||
|
|
||||||
|
# Frontend roll mode: use current behavior (random selection for both)
|
||||||
|
ui_loras = await self._generate_random_loras_for_ui(
|
||||||
|
scanner, randomizer_config, loras, pool_config
|
||||||
|
)
|
||||||
|
execution_stack = self._build_execution_stack_from_input(ui_loras)
|
||||||
|
logger.info(
|
||||||
|
f"[LoraRandomizerNode] Frontend roll: executing with random selection"
|
||||||
|
)
|
||||||
|
return {"result": (execution_stack,), "ui": {"loras": ui_loras}}
|
||||||
|
|
||||||
|
def _build_execution_stack_from_input(self, loras):
|
||||||
|
"""
|
||||||
|
Build LORA_STACK tuple from input loras list for execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loras: List of LoRA dicts with name, strength, clipStrength, active
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tuples (lora_path, model_strength, clip_strength)
|
||||||
|
"""
|
||||||
|
lora_stack = []
|
||||||
|
for lora in loras:
|
||||||
|
if not lora.get("active", False):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get file path
|
||||||
|
lora_path, trigger_words = get_lora_info(lora["name"])
|
||||||
|
if not lora_path:
|
||||||
|
logger.warning(
|
||||||
|
f"[LoraRandomizerNode] Could not find path for LoRA: {lora['name']}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Normalize path separators
|
||||||
|
lora_path = lora_path.replace("/", os.sep)
|
||||||
|
|
||||||
|
# Extract strengths
|
||||||
|
model_strength = lora.get("strength", 1.0)
|
||||||
|
clip_strength = lora.get("clipStrength", model_strength)
|
||||||
|
|
||||||
|
lora_stack.append((lora_path, model_strength, clip_strength))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[LoraRandomizerNode] Built execution stack with {len(lora_stack)} LoRAs"
|
||||||
|
)
|
||||||
|
return lora_stack
|
||||||
|
|
||||||
|
async def _generate_random_loras_for_ui(
|
||||||
|
self, scanner, randomizer_config, input_loras, pool_config=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate new random loras for UI display.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scanner: LoraScanner instance
|
||||||
|
randomizer_config: Dict with randomizer settings
|
||||||
|
input_loras: Current input loras (for extracting locked loras)
|
||||||
|
pool_config: Optional pool filters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LoRA dicts for UI display
|
||||||
|
"""
|
||||||
|
# Parse randomizer settings
|
||||||
|
count_mode = randomizer_config.get("count_mode", "range")
|
||||||
|
count_fixed = randomizer_config.get("count_fixed", 5)
|
||||||
|
count_min = randomizer_config.get("count_min", 3)
|
||||||
|
count_max = randomizer_config.get("count_max", 7)
|
||||||
|
model_strength_min = randomizer_config.get("model_strength_min", 0.0)
|
||||||
|
model_strength_max = randomizer_config.get("model_strength_max", 1.0)
|
||||||
|
use_same_clip_strength = randomizer_config.get("use_same_clip_strength", True)
|
||||||
|
clip_strength_min = randomizer_config.get("clip_strength_min", 0.0)
|
||||||
|
clip_strength_max = randomizer_config.get("clip_strength_max", 1.0)
|
||||||
|
|
||||||
# Determine target count
|
# Determine target count
|
||||||
if count_mode == "fixed":
|
if count_mode == "fixed":
|
||||||
target_count = count_fixed
|
target_count = count_fixed
|
||||||
@@ -75,11 +159,11 @@ class LoraRandomizerNode:
|
|||||||
target_count = random.randint(count_min, count_max)
|
target_count = random.randint(count_min, count_max)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[LoraRandomizerNode] Target count: {target_count}, Roll mode: {roll_mode}"
|
f"[LoraRandomizerNode] Generating random LoRAs, target count: {target_count}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract locked LoRAs from input
|
# Extract locked LoRAs from input
|
||||||
locked_loras = [lora for lora in loras if lora.get("locked", False)]
|
locked_loras = [lora for lora in input_loras if lora.get("locked", False)]
|
||||||
locked_count = len(locked_loras)
|
locked_count = len(locked_loras)
|
||||||
|
|
||||||
logger.info(f"[LoraRandomizerNode] Locked LoRAs: {locked_count}")
|
logger.info(f"[LoraRandomizerNode] Locked LoRAs: {locked_count}")
|
||||||
@@ -106,8 +190,6 @@ class LoraRandomizerNode:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Calculate how many new LoRAs to select
|
# Calculate how many new LoRAs to select
|
||||||
# In frontend mode, if loras already has data, preserve unlocked ones if roll_mode requires
|
|
||||||
# For simplicity in backend mode, we regenerate all unlocked slots
|
|
||||||
slots_needed = target_count - locked_count
|
slots_needed = target_count - locked_count
|
||||||
|
|
||||||
if slots_needed < 0:
|
if slots_needed < 0:
|
||||||
@@ -161,33 +243,10 @@ class LoraRandomizerNode:
|
|||||||
# Merge with locked LoRAs
|
# Merge with locked LoRAs
|
||||||
result_loras.extend(locked_loras)
|
result_loras.extend(locked_loras)
|
||||||
|
|
||||||
logger.info(f"[LoraRandomizerNode] Final LoRA count: {len(result_loras)}")
|
logger.info(
|
||||||
|
f"[LoraRandomizerNode] Final random LoRA count: {len(result_loras)}"
|
||||||
# Build LORA_STACK output
|
)
|
||||||
lora_stack = []
|
return result_loras
|
||||||
for lora in result_loras:
|
|
||||||
if not lora.get("active", False):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get file path
|
|
||||||
lora_path, trigger_words = get_lora_info(lora["name"])
|
|
||||||
if not lora_path:
|
|
||||||
logger.warning(
|
|
||||||
f"[LoraRandomizerNode] Could not find path for LoRA: {lora['name']}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Normalize path separators
|
|
||||||
lora_path = lora_path.replace("/", os.sep)
|
|
||||||
|
|
||||||
# Extract strengths
|
|
||||||
model_strength = lora.get("strength", 1.0)
|
|
||||||
clip_strength = lora.get("clipStrength", model_strength)
|
|
||||||
|
|
||||||
lora_stack.append((lora_path, model_strength, clip_strength))
|
|
||||||
|
|
||||||
# Return format: result for workflow + ui for frontend display
|
|
||||||
return {"result": (lora_stack,), "ui": {"loras": result_loras}}
|
|
||||||
|
|
||||||
async def _apply_pool_filters(self, available_loras, pool_config, scanner):
|
async def _apply_pool_filters(self, available_loras, pool_config, scanner):
|
||||||
"""
|
"""
|
||||||
@@ -288,10 +347,3 @@ class LoraRandomizerNode:
|
|||||||
]
|
]
|
||||||
|
|
||||||
return available_loras
|
return available_loras
|
||||||
|
|
||||||
|
|
||||||
# Node class mappings for ComfyUI
|
|
||||||
NODE_CLASS_MAPPINGS = {"LoraRandomizerNode": LoraRandomizerNode}
|
|
||||||
|
|
||||||
# Display name mappings
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {"LoraRandomizerNode": "LoRA Randomizer"}
|
|
||||||
|
|||||||
@@ -8,24 +8,27 @@ from ..config import config
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LoraService(BaseModelService):
|
class LoraService(BaseModelService):
|
||||||
"""LoRA-specific service implementation"""
|
"""LoRA-specific service implementation"""
|
||||||
|
|
||||||
def __init__(self, scanner, update_service=None):
|
def __init__(self, scanner, update_service=None):
|
||||||
"""Initialize LoRA service
|
"""Initialize LoRA service
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scanner: LoRA scanner instance
|
scanner: LoRA scanner instance
|
||||||
update_service: Optional service for remote update tracking.
|
update_service: Optional service for remote update tracking.
|
||||||
"""
|
"""
|
||||||
super().__init__("lora", scanner, LoraMetadata, update_service=update_service)
|
super().__init__("lora", scanner, LoraMetadata, update_service=update_service)
|
||||||
|
|
||||||
async def format_response(self, lora_data: Dict) -> Dict:
|
async def format_response(self, lora_data: Dict) -> Dict:
|
||||||
"""Format LoRA data for API response"""
|
"""Format LoRA data for API response"""
|
||||||
return {
|
return {
|
||||||
"model_name": lora_data["model_name"],
|
"model_name": lora_data["model_name"],
|
||||||
"file_name": lora_data["file_name"],
|
"file_name": lora_data["file_name"],
|
||||||
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
|
"preview_url": config.get_preview_static_url(
|
||||||
|
lora_data.get("preview_url", "")
|
||||||
|
),
|
||||||
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||||
"base_model": lora_data.get("base_model", ""),
|
"base_model": lora_data.get("base_model", ""),
|
||||||
"folder": lora_data["folder"],
|
"folder": lora_data["folder"],
|
||||||
@@ -40,141 +43,170 @@ class LoraService(BaseModelService):
|
|||||||
"notes": lora_data.get("notes", ""),
|
"notes": lora_data.get("notes", ""),
|
||||||
"favorite": lora_data.get("favorite", False),
|
"favorite": lora_data.get("favorite", False),
|
||||||
"update_available": bool(lora_data.get("update_available", False)),
|
"update_available": bool(lora_data.get("update_available", False)),
|
||||||
"civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
|
"civitai": self.filter_civitai_data(
|
||||||
|
lora_data.get("civitai", {}), minimal=True
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||||
"""Apply LoRA-specific filters"""
|
"""Apply LoRA-specific filters"""
|
||||||
# Handle first_letter filter for LoRAs
|
# Handle first_letter filter for LoRAs
|
||||||
first_letter = kwargs.get('first_letter')
|
first_letter = kwargs.get("first_letter")
|
||||||
if first_letter:
|
if first_letter:
|
||||||
data = self._filter_by_first_letter(data, first_letter)
|
data = self._filter_by_first_letter(data, first_letter)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||||
"""Filter data by first letter of model name
|
"""Filter data by first letter of model name
|
||||||
|
|
||||||
Special handling:
|
Special handling:
|
||||||
- '#': Numbers (0-9)
|
- '#': Numbers (0-9)
|
||||||
- '@': Special characters (not alphanumeric)
|
- '@': Special characters (not alphanumeric)
|
||||||
- '漢': CJK characters
|
- '漢': CJK characters
|
||||||
"""
|
"""
|
||||||
filtered_data = []
|
filtered_data = []
|
||||||
|
|
||||||
for lora in data:
|
for lora in data:
|
||||||
model_name = lora.get('model_name', '')
|
model_name = lora.get("model_name", "")
|
||||||
if not model_name:
|
if not model_name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
first_char = model_name[0].upper()
|
first_char = model_name[0].upper()
|
||||||
|
|
||||||
if letter == '#' and first_char.isdigit():
|
if letter == "#" and first_char.isdigit():
|
||||||
filtered_data.append(lora)
|
filtered_data.append(lora)
|
||||||
elif letter == '@' and not first_char.isalnum():
|
elif letter == "@" and not first_char.isalnum():
|
||||||
# Special characters (not alphanumeric)
|
# Special characters (not alphanumeric)
|
||||||
filtered_data.append(lora)
|
filtered_data.append(lora)
|
||||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
elif letter == "漢" and self._is_cjk_character(first_char):
|
||||||
# CJK characters
|
# CJK characters
|
||||||
filtered_data.append(lora)
|
filtered_data.append(lora)
|
||||||
elif letter.upper() == first_char:
|
elif letter.upper() == first_char:
|
||||||
# Regular alphabet matching
|
# Regular alphabet matching
|
||||||
filtered_data.append(lora)
|
filtered_data.append(lora)
|
||||||
|
|
||||||
return filtered_data
|
return filtered_data
|
||||||
|
|
||||||
def _is_cjk_character(self, char: str) -> bool:
|
def _is_cjk_character(self, char: str) -> bool:
|
||||||
"""Check if character is a CJK character"""
|
"""Check if character is a CJK character"""
|
||||||
# Define Unicode ranges for CJK characters
|
# Define Unicode ranges for CJK characters
|
||||||
cjk_ranges = [
|
cjk_ranges = [
|
||||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||||
(0x3300, 0x33FF), # CJK Compatibility
|
(0x3300, 0x33FF), # CJK Compatibility
|
||||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||||
(0x3100, 0x312F), # Bopomofo
|
(0x3100, 0x312F), # Bopomofo
|
||||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||||
(0x3040, 0x309F), # Hiragana
|
(0x3040, 0x309F), # Hiragana
|
||||||
(0x30A0, 0x30FF), # Katakana
|
(0x30A0, 0x30FF), # Katakana
|
||||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||||
(0x1100, 0x11FF), # Hangul Jamo
|
(0x1100, 0x11FF), # Hangul Jamo
|
||||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||||
]
|
]
|
||||||
|
|
||||||
code_point = ord(char)
|
code_point = ord(char)
|
||||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||||
|
|
||||||
# LoRA-specific methods
|
# LoRA-specific methods
|
||||||
async def get_letter_counts(self) -> Dict[str, int]:
|
async def get_letter_counts(self) -> Dict[str, int]:
|
||||||
"""Get count of LoRAs for each letter of the alphabet"""
|
"""Get count of LoRAs for each letter of the alphabet"""
|
||||||
cache = await self.scanner.get_cached_data()
|
cache = await self.scanner.get_cached_data()
|
||||||
data = cache.raw_data
|
data = cache.raw_data
|
||||||
|
|
||||||
# Define letter categories
|
# Define letter categories
|
||||||
letters = {
|
letters = {
|
||||||
'#': 0, # Numbers
|
"#": 0, # Numbers
|
||||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
"A": 0,
|
||||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
"B": 0,
|
||||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
"C": 0,
|
||||||
'Y': 0, 'Z': 0,
|
"D": 0,
|
||||||
'@': 0, # Special characters
|
"E": 0,
|
||||||
'漢': 0 # CJK characters
|
"F": 0,
|
||||||
|
"G": 0,
|
||||||
|
"H": 0,
|
||||||
|
"I": 0,
|
||||||
|
"J": 0,
|
||||||
|
"K": 0,
|
||||||
|
"L": 0,
|
||||||
|
"M": 0,
|
||||||
|
"N": 0,
|
||||||
|
"O": 0,
|
||||||
|
"P": 0,
|
||||||
|
"Q": 0,
|
||||||
|
"R": 0,
|
||||||
|
"S": 0,
|
||||||
|
"T": 0,
|
||||||
|
"U": 0,
|
||||||
|
"V": 0,
|
||||||
|
"W": 0,
|
||||||
|
"X": 0,
|
||||||
|
"Y": 0,
|
||||||
|
"Z": 0,
|
||||||
|
"@": 0, # Special characters
|
||||||
|
"漢": 0, # CJK characters
|
||||||
}
|
}
|
||||||
|
|
||||||
# Count models for each letter
|
# Count models for each letter
|
||||||
for lora in data:
|
for lora in data:
|
||||||
model_name = lora.get('model_name', '')
|
model_name = lora.get("model_name", "")
|
||||||
if not model_name:
|
if not model_name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
first_char = model_name[0].upper()
|
first_char = model_name[0].upper()
|
||||||
|
|
||||||
if first_char.isdigit():
|
if first_char.isdigit():
|
||||||
letters['#'] += 1
|
letters["#"] += 1
|
||||||
elif first_char in letters:
|
elif first_char in letters:
|
||||||
letters[first_char] += 1
|
letters[first_char] += 1
|
||||||
elif self._is_cjk_character(first_char):
|
elif self._is_cjk_character(first_char):
|
||||||
letters['漢'] += 1
|
letters["漢"] += 1
|
||||||
elif not first_char.isalnum():
|
elif not first_char.isalnum():
|
||||||
letters['@'] += 1
|
letters["@"] += 1
|
||||||
|
|
||||||
return letters
|
return letters
|
||||||
|
|
||||||
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
||||||
"""Get trigger words for a specific LoRA file"""
|
"""Get trigger words for a specific LoRA file"""
|
||||||
cache = await self.scanner.get_cached_data()
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
for lora in cache.raw_data:
|
for lora in cache.raw_data:
|
||||||
if lora['file_name'] == lora_name:
|
if lora["file_name"] == lora_name:
|
||||||
civitai_data = lora.get('civitai', {})
|
civitai_data = lora.get("civitai", {})
|
||||||
return civitai_data.get('trainedWords', [])
|
return civitai_data.get("trainedWords", [])
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_lora_usage_tips_by_relative_path(self, relative_path: str) -> Optional[str]:
|
async def get_lora_usage_tips_by_relative_path(
|
||||||
|
self, relative_path: str
|
||||||
|
) -> Optional[str]:
|
||||||
"""Get usage tips for a LoRA by its relative path"""
|
"""Get usage tips for a LoRA by its relative path"""
|
||||||
cache = await self.scanner.get_cached_data()
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
for lora in cache.raw_data:
|
for lora in cache.raw_data:
|
||||||
file_path = lora.get('file_path', '')
|
file_path = lora.get("file_path", "")
|
||||||
if file_path:
|
if file_path:
|
||||||
# Convert to forward slashes and extract relative path
|
# Convert to forward slashes and extract relative path
|
||||||
file_path_normalized = file_path.replace('\\', '/')
|
file_path_normalized = file_path.replace("\\", "/")
|
||||||
relative_path = relative_path.replace('\\', '/')
|
relative_path = relative_path.replace("\\", "/")
|
||||||
# Find the relative path part by looking for the relative_path in the full path
|
# Find the relative path part by looking for the relative_path in the full path
|
||||||
if file_path_normalized.endswith(relative_path) or relative_path in file_path_normalized:
|
if (
|
||||||
return lora.get('usage_tips', '')
|
file_path_normalized.endswith(relative_path)
|
||||||
|
or relative_path in file_path_normalized
|
||||||
|
):
|
||||||
|
return lora.get("usage_tips", "")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def find_duplicate_hashes(self) -> Dict:
|
def find_duplicate_hashes(self) -> Dict:
|
||||||
"""Find LoRAs with duplicate SHA256 hashes"""
|
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||||
return self.scanner._hash_index.get_duplicate_hashes()
|
return self.scanner._hash_index.get_duplicate_hashes()
|
||||||
@@ -192,7 +224,7 @@ class LoraService(BaseModelService):
|
|||||||
clip_strength_min: float = 0.0,
|
clip_strength_min: float = 0.0,
|
||||||
clip_strength_max: float = 1.0,
|
clip_strength_max: float = 1.0,
|
||||||
locked_loras: Optional[List[Dict]] = None,
|
locked_loras: Optional[List[Dict]] = None,
|
||||||
pool_config: Optional[Dict] = None
|
pool_config: Optional[Dict] = None,
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Get random LoRAs with specified strength ranges.
|
Get random LoRAs with specified strength ranges.
|
||||||
@@ -235,10 +267,9 @@ class LoraService(BaseModelService):
|
|||||||
locked_loras = locked_loras[:count]
|
locked_loras = locked_loras[:count]
|
||||||
|
|
||||||
# Filter out locked LoRAs from available pool
|
# Filter out locked LoRAs from available pool
|
||||||
locked_names = {lora['name'] for lora in locked_loras}
|
locked_names = {lora["name"] for lora in locked_loras}
|
||||||
available_pool = [
|
available_pool = [
|
||||||
l for l in available_loras
|
l for l in available_loras if l["file_name"] not in locked_names
|
||||||
if l['model_name'] not in locked_names
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Ensure we don't try to select more than available
|
# Ensure we don't try to select more than available
|
||||||
@@ -253,9 +284,7 @@ class LoraService(BaseModelService):
|
|||||||
# Generate random strengths for selected LoRAs
|
# Generate random strengths for selected LoRAs
|
||||||
result_loras = []
|
result_loras = []
|
||||||
for lora in selected:
|
for lora in selected:
|
||||||
model_str = round(
|
model_str = round(random.uniform(model_strength_min, model_strength_max), 2)
|
||||||
random.uniform(model_strength_min, model_strength_max), 2
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_same_clip_strength:
|
if use_same_clip_strength:
|
||||||
clip_str = model_str
|
clip_str = model_str
|
||||||
@@ -264,21 +293,25 @@ class LoraService(BaseModelService):
|
|||||||
random.uniform(clip_strength_min, clip_strength_max), 2
|
random.uniform(clip_strength_min, clip_strength_max), 2
|
||||||
)
|
)
|
||||||
|
|
||||||
result_loras.append({
|
result_loras.append(
|
||||||
'name': lora['model_name'],
|
{
|
||||||
'strength': model_str,
|
"name": lora["file_name"],
|
||||||
'clipStrength': clip_str,
|
"strength": model_str,
|
||||||
'active': True,
|
"clipStrength": clip_str,
|
||||||
'expanded': abs(model_str - clip_str) > 0.001,
|
"active": True,
|
||||||
'locked': False
|
"expanded": abs(model_str - clip_str) > 0.001,
|
||||||
})
|
"locked": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Merge with locked LoRAs
|
# Merge with locked LoRAs
|
||||||
result_loras.extend(locked_loras)
|
result_loras.extend(locked_loras)
|
||||||
|
|
||||||
return result_loras
|
return result_loras
|
||||||
|
|
||||||
async def _apply_pool_filters(self, available_loras: List[Dict], pool_config: Dict) -> List[Dict]:
|
async def _apply_pool_filters(
|
||||||
|
self, available_loras: List[Dict], pool_config: Dict
|
||||||
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Apply pool_config filters to available LoRAs.
|
Apply pool_config filters to available LoRAs.
|
||||||
|
|
||||||
@@ -292,26 +325,26 @@ class LoraService(BaseModelService):
|
|||||||
from .model_query import FilterCriteria
|
from .model_query import FilterCriteria
|
||||||
|
|
||||||
# Extract filter parameters from pool_config
|
# Extract filter parameters from pool_config
|
||||||
selected_base_models = pool_config.get('selected_base_models', [])
|
selected_base_models = pool_config.get("selected_base_models", [])
|
||||||
include_tags = pool_config.get('include_tags', [])
|
include_tags = pool_config.get("include_tags", [])
|
||||||
exclude_tags = pool_config.get('exclude_tags', [])
|
exclude_tags = pool_config.get("exclude_tags", [])
|
||||||
include_folders = pool_config.get('include_folders', [])
|
include_folders = pool_config.get("include_folders", [])
|
||||||
exclude_folders = pool_config.get('exclude_folders', [])
|
exclude_folders = pool_config.get("exclude_folders", [])
|
||||||
no_credit_required = pool_config.get('no_credit_required', False)
|
no_credit_required = pool_config.get("no_credit_required", False)
|
||||||
allow_selling = pool_config.get('allow_selling', False)
|
allow_selling = pool_config.get("allow_selling", False)
|
||||||
|
|
||||||
# Build tag filters dict
|
# Build tag filters dict
|
||||||
tag_filters = {}
|
tag_filters = {}
|
||||||
for tag in include_tags:
|
for tag in include_tags:
|
||||||
tag_filters[tag] = 'include'
|
tag_filters[tag] = "include"
|
||||||
for tag in exclude_tags:
|
for tag in exclude_tags:
|
||||||
tag_filters[tag] = 'exclude'
|
tag_filters[tag] = "exclude"
|
||||||
|
|
||||||
# Build folder filter
|
# Build folder filter
|
||||||
if include_folders or exclude_folders:
|
if include_folders or exclude_folders:
|
||||||
filtered = []
|
filtered = []
|
||||||
for lora in available_loras:
|
for lora in available_loras:
|
||||||
folder = lora.get('folder', '')
|
folder = lora.get("folder", "")
|
||||||
|
|
||||||
# Check exclude folders first
|
# Check exclude folders first
|
||||||
excluded = False
|
excluded = False
|
||||||
@@ -340,8 +373,9 @@ class LoraService(BaseModelService):
|
|||||||
# Apply base model filter
|
# Apply base model filter
|
||||||
if selected_base_models:
|
if selected_base_models:
|
||||||
available_loras = [
|
available_loras = [
|
||||||
lora for lora in available_loras
|
lora
|
||||||
if lora.get('base_model') in selected_base_models
|
for lora in available_loras
|
||||||
|
if lora.get("base_model") in selected_base_models
|
||||||
]
|
]
|
||||||
|
|
||||||
# Apply tag filters
|
# Apply tag filters
|
||||||
@@ -352,14 +386,17 @@ class LoraService(BaseModelService):
|
|||||||
# Apply license filters
|
# Apply license filters
|
||||||
if no_credit_required:
|
if no_credit_required:
|
||||||
available_loras = [
|
available_loras = [
|
||||||
lora for lora in available_loras
|
lora
|
||||||
if not lora.get('civitai', {}).get('allowNoCredit', True)
|
for lora in available_loras
|
||||||
|
if not lora.get("civitai", {}).get("allowNoCredit", True)
|
||||||
]
|
]
|
||||||
|
|
||||||
if allow_selling:
|
if allow_selling:
|
||||||
available_loras = [
|
available_loras = [
|
||||||
lora for lora in available_loras
|
lora
|
||||||
if lora.get('civitai', {}).get('allowCommercialUse', ['None'])[0] != 'None'
|
for lora in available_loras
|
||||||
|
if lora.get("civitai", {}).get("allowCommercialUse", ["None"])[0]
|
||||||
|
!= "None"
|
||||||
]
|
]
|
||||||
|
|
||||||
return available_loras
|
return available_loras
|
||||||
|
|||||||
@@ -32,12 +32,12 @@
|
|||||||
import { onMounted } from 'vue'
|
import { onMounted } from 'vue'
|
||||||
import LoraRandomizerSettingsView from './lora-randomizer/LoraRandomizerSettingsView.vue'
|
import LoraRandomizerSettingsView from './lora-randomizer/LoraRandomizerSettingsView.vue'
|
||||||
import { useLoraRandomizerState } from '../composables/useLoraRandomizerState'
|
import { useLoraRandomizerState } from '../composables/useLoraRandomizerState'
|
||||||
import type { ComponentWidget, RandomizerConfig } from '../composables/types'
|
import type { ComponentWidget, RandomizerConfig, LoraEntry } from '../composables/types'
|
||||||
|
|
||||||
// Props
|
// Props
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
widget: ComponentWidget
|
widget: ComponentWidget
|
||||||
node: { id: number }
|
node: { id: number; inputs?: any[]; widgets?: any[]; graph?: any }
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
// State management
|
// State management
|
||||||
@@ -48,13 +48,15 @@ const handleRoll = async () => {
|
|||||||
try {
|
try {
|
||||||
console.log('[LoraRandomizerWidget] Roll button clicked')
|
console.log('[LoraRandomizerWidget] Roll button clicked')
|
||||||
|
|
||||||
// Get pool config from connected input (if any)
|
// Get pool config from connected pool_config input
|
||||||
// This would need to be passed from the node's pool_config input
|
const poolConfig = (props.node as any).getPoolConfig?.() || null
|
||||||
const poolConfig = null // TODO: Get from node input if connected
|
|
||||||
|
|
||||||
// Get locked loras from the loras widget
|
// Get locked loras from the loras widget
|
||||||
// This would need to be retrieved from the loras widget on the node
|
const lorasWidget = props.node.widgets?.find((w: any) => w.name === "loras")
|
||||||
const lockedLoras: any[] = [] // TODO: Get from loras widget
|
const lockedLoras: LoraEntry[] = (lorasWidget?.value || []).filter((lora: LoraEntry) => lora.locked === true)
|
||||||
|
|
||||||
|
console.log('[LoraRandomizerWidget] Pool config:', poolConfig)
|
||||||
|
console.log('[LoraRandomizerWidget] Locked loras:', lockedLoras)
|
||||||
|
|
||||||
// Call API to get random loras
|
// Call API to get random loras
|
||||||
const randomLoras = await state.rollLoras(poolConfig, lockedLoras)
|
const randomLoras = await state.rollLoras(poolConfig, lockedLoras)
|
||||||
|
|||||||
@@ -154,8 +154,10 @@
|
|||||||
:disabled="rollMode !== 'frontend' || isRolling"
|
:disabled="rollMode !== 'frontend' || isRolling"
|
||||||
@click="$emit('roll')"
|
@click="$emit('roll')"
|
||||||
>
|
>
|
||||||
<span v-if="!isRolling">🎲 Roll</span>
|
<span class="roll-button__content">
|
||||||
<span v-else>Rolling...</span>
|
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect width="18" height="18" x="3" y="3" rx="2" ry="2"></rect><path d="M8 8h.01"></path><path d="M16 16h.01"></path><path d="M16 8h.01"></path><path d="M8 16h.01"></path></svg>
|
||||||
|
Roll
|
||||||
|
</span>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
<div class="roll-mode-selector">
|
<div class="roll-mode-selector">
|
||||||
@@ -329,7 +331,7 @@ defineEmits<{
|
|||||||
}
|
}
|
||||||
|
|
||||||
.roll-button {
|
.roll-button {
|
||||||
padding: 6px 16px;
|
padding: 8px 16px;
|
||||||
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%);
|
background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%);
|
||||||
border: none;
|
border: none;
|
||||||
border-radius: 4px;
|
border-radius: 4px;
|
||||||
@@ -339,6 +341,9 @@ defineEmits<{
|
|||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
transition: all 0.2s;
|
transition: all 0.2s;
|
||||||
white-space: nowrap;
|
white-space: nowrap;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
}
|
}
|
||||||
|
|
||||||
.roll-button:hover:not(:disabled) {
|
.roll-button:hover:not(:disabled) {
|
||||||
@@ -356,4 +361,10 @@ defineEmits<{
|
|||||||
cursor: not-allowed;
|
cursor: not-allowed;
|
||||||
background: linear-gradient(135deg, #52525b 0%, #3f3f46 100%);
|
background: linear-gradient(135deg, #52525b 0%, #3f3f46 100%);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.roll-button__content {
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 6px;
|
||||||
|
}
|
||||||
</style>
|
</style>
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import type { LoraPoolConfig, LegacyLoraPoolConfig, RandomizerConfig } from './c
|
|||||||
|
|
||||||
// @ts-ignore - ComfyUI external module
|
// @ts-ignore - ComfyUI external module
|
||||||
import { app } from '../../../scripts/app.js'
|
import { app } from '../../../scripts/app.js'
|
||||||
|
// @ts-ignore
|
||||||
|
import { getPoolConfigFromConnectedNode, getActiveLorasFromNode, updateConnectedTriggerWords, updateDownstreamLoaders } from '../../web/comfyui/utils.js'
|
||||||
|
|
||||||
const vueApps = new Map<number, VueApp>()
|
const vueApps = new Map<number, VueApp>()
|
||||||
|
|
||||||
@@ -119,6 +121,9 @@ function createLoraRandomizerWidget(node) {
|
|||||||
internalValue = v
|
internalValue = v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add method to get pool config from connected node
|
||||||
|
node.getPoolConfig = () => getPoolConfigFromConnectedNode(node)
|
||||||
|
|
||||||
// Handle roll event from Vue component
|
// Handle roll event from Vue component
|
||||||
widget.onRoll = (randomLoras: any[]) => {
|
widget.onRoll = (randomLoras: any[]) => {
|
||||||
console.log('[createLoraRandomizerWidget] Roll event received:', randomLoras)
|
console.log('[createLoraRandomizerWidget] Roll event received:', randomLoras)
|
||||||
@@ -181,10 +186,21 @@ app.registerExtension({
|
|||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
async LORAS(node: any) {
|
async LORAS(node: any) {
|
||||||
if (!addLorasWidgetCache) {
|
if (!addLorasWidgetCache) {
|
||||||
|
// @ts-ignore
|
||||||
const module = await import(/* @vite-ignore */ '../loras_widget.js')
|
const module = await import(/* @vite-ignore */ '../loras_widget.js')
|
||||||
addLorasWidgetCache = module.addLorasWidget
|
addLorasWidgetCache = module.addLorasWidget
|
||||||
}
|
}
|
||||||
return addLorasWidgetCache(node, 'loras', {}, null)
|
// Check if this is a randomizer node to enable lock buttons
|
||||||
|
const isRandomizerNode = node.comfyClass === 'Lora Randomizer (LoraManager)'
|
||||||
|
|
||||||
|
console.log(node)
|
||||||
|
|
||||||
|
// For randomizer nodes, add a callback to update connected trigger words
|
||||||
|
const callback = isRandomizerNode ? (value: any) => {
|
||||||
|
updateDownstreamLoaders(node)
|
||||||
|
} : null
|
||||||
|
|
||||||
|
return addLorasWidgetCache(node, 'loras', { isRandomizerNode }, callback)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -175,6 +175,20 @@
|
|||||||
box-shadow: 0 0 0 1px rgba(66, 153, 225, 0.3) !important;
|
box-shadow: 0 0 0 1px rgba(66, 153, 225, 0.3) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.lm-lora-entry[data-selected="true"][data-locked="true"] {
|
||||||
|
background-color: rgba(66, 153, 225, 0.2) !important;
|
||||||
|
border-left: 3px solid rgba(245, 158, 11, 0.8) !important;
|
||||||
|
border-right: 1px solid rgba(66, 153, 225, 0.6) !important;
|
||||||
|
border-top: 1px solid rgba(66, 153, 225, 0.6) !important;
|
||||||
|
border-bottom: 1px solid rgba(66, 153, 225, 0.6) !important;
|
||||||
|
box-shadow: 0 0 0 1px rgba(66, 153, 225, 0.3), inset 0 0 20px rgba(245, 158, 11, 0.04) !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lm-lora-entry[data-selected="true"][data-locked="true"][data-active="false"] {
|
||||||
|
background-color: rgba(48, 42, 36, 0.5) !important;
|
||||||
|
border-left: 3px solid rgba(245, 158, 11, 0.6) !important;
|
||||||
|
}
|
||||||
|
|
||||||
.lm-lora-name {
|
.lm-lora-name {
|
||||||
margin-left: 4px;
|
margin-left: 4px;
|
||||||
flex: 1;
|
flex: 1;
|
||||||
@@ -236,7 +250,6 @@
|
|||||||
justify-content: center;
|
justify-content: center;
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
user-select: none;
|
user-select: none;
|
||||||
font-size: 12px;
|
|
||||||
color: rgba(226, 232, 240, 0.8);
|
color: rgba(226, 232, 240, 0.8);
|
||||||
transition: all 0.2s ease;
|
transition: all 0.2s ease;
|
||||||
}
|
}
|
||||||
@@ -246,6 +259,11 @@
|
|||||||
transform: scale(1.2);
|
transform: scale(1.2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.lm-lora-arrow svg {
|
||||||
|
width: 12px;
|
||||||
|
height: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
.lm-lora-expand-button {
|
.lm-lora-expand-button {
|
||||||
width: 20px;
|
width: 20px;
|
||||||
height: 20px;
|
height: 20px;
|
||||||
@@ -254,7 +272,6 @@
|
|||||||
justify-content: center;
|
justify-content: center;
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
user-select: none;
|
user-select: none;
|
||||||
font-size: 10px;
|
|
||||||
color: rgba(226, 232, 240, 0.7);
|
color: rgba(226, 232, 240, 0.7);
|
||||||
background-color: rgba(45, 55, 72, 0.3);
|
background-color: rgba(45, 55, 72, 0.3);
|
||||||
border: 1px solid rgba(226, 232, 240, 0.2);
|
border: 1px solid rgba(226, 232, 240, 0.2);
|
||||||
@@ -262,7 +279,6 @@
|
|||||||
transition: all 0.2s ease;
|
transition: all 0.2s ease;
|
||||||
flex-shrink: 0;
|
flex-shrink: 0;
|
||||||
padding: 0;
|
padding: 0;
|
||||||
line-height: 1;
|
|
||||||
box-sizing: border-box;
|
box-sizing: border-box;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,13 +330,17 @@
|
|||||||
justify-content: center;
|
justify-content: center;
|
||||||
cursor: grab;
|
cursor: grab;
|
||||||
user-select: none;
|
user-select: none;
|
||||||
font-size: 14px;
|
|
||||||
color: rgba(226, 232, 240, 0.6);
|
color: rgba(226, 232, 240, 0.6);
|
||||||
transition: all 0.2s ease;
|
transition: all 0.2s ease;
|
||||||
margin-right: 8px;
|
margin-right: 8px;
|
||||||
flex-shrink: 0;
|
flex-shrink: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.lm-lora-drag-handle svg {
|
||||||
|
width: 14px;
|
||||||
|
height: 14px;
|
||||||
|
}
|
||||||
|
|
||||||
.lm-lora-drag-handle:hover {
|
.lm-lora-drag-handle:hover {
|
||||||
color: rgba(226, 232, 240, 0.9);
|
color: rgba(226, 232, 240, 0.9);
|
||||||
transform: scale(1.1);
|
transform: scale(1.1);
|
||||||
@@ -330,6 +350,83 @@
|
|||||||
cursor: grabbing;
|
cursor: grabbing;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.lm-lora-lock-button {
|
||||||
|
width: 20px;
|
||||||
|
height: 20px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
cursor: pointer;
|
||||||
|
user-select: none;
|
||||||
|
background-color: rgba(45, 55, 72, 0.3);
|
||||||
|
border: 1px solid rgba(226, 232, 240, 0.2);
|
||||||
|
border-radius: 3px;
|
||||||
|
transition: all 0.2s ease;
|
||||||
|
flex-shrink: 0;
|
||||||
|
padding: 0;
|
||||||
|
box-sizing: border-box;
|
||||||
|
margin-right: 8px;
|
||||||
|
color: rgba(226, 232, 240, 0.8);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lm-lora-lock-button svg {
|
||||||
|
width: 14px;
|
||||||
|
height: 14px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lm-lora-lock-button:hover {
|
||||||
|
background-color: rgba(66, 153, 225, 0.2);
|
||||||
|
border-color: rgba(66, 153, 225, 0.4);
|
||||||
|
color: rgba(226, 232, 240, 0.95);
|
||||||
|
transform: scale(1.05);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lm-lora-lock-button:active {
|
||||||
|
transform: scale(0.95);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lm-lora-lock-button:focus {
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.lm-lora-lock-button:focus-visible {
|
||||||
|
box-shadow: 0 0 0 2px rgba(66, 153, 225, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lm-lora-lock-button--locked {
|
||||||
|
background-color: rgba(245, 158, 11, 0.25);
|
||||||
|
border-color: rgba(245, 158, 11, 0.7);
|
||||||
|
color: rgba(251, 191, 36, 0.95);
|
||||||
|
box-shadow: 0 0 8px rgba(245, 158, 11, 0.15);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lm-lora-lock-button--locked:hover {
|
||||||
|
background-color: rgba(245, 158, 11, 0.35);
|
||||||
|
border-color: rgba(245, 158, 11, 0.85);
|
||||||
|
box-shadow: 0 0 12px rgba(245, 158, 11, 0.25);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Visual styling for locked lora entries */
|
||||||
|
.lm-lora-entry[data-locked="true"] {
|
||||||
|
background-color: rgba(60, 50, 40, 0.75);
|
||||||
|
border-left: 3px solid rgba(245, 158, 11, 0.7);
|
||||||
|
box-shadow: inset 0 0 20px rgba(245, 158, 11, 0.04);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lm-lora-entry[data-locked="true"]:hover {
|
||||||
|
background-color: rgba(65, 55, 45, 0.85);
|
||||||
|
box-shadow: inset 0 0 24px rgba(245, 158, 11, 0.06);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lm-lora-entry[data-locked="true"][data-active="false"] {
|
||||||
|
background-color: rgba(48, 42, 36, 0.6);
|
||||||
|
border-left: 3px solid rgba(245, 158, 11, 0.5);
|
||||||
|
}
|
||||||
|
|
||||||
|
.lm-lora-entry[data-locked="true"][data-active="false"]:hover {
|
||||||
|
background-color: rgba(52, 46, 40, 0.7);
|
||||||
|
}
|
||||||
|
|
||||||
body.lm-lora-strength-dragging,
|
body.lm-lora-strength-dragging,
|
||||||
body.lm-lora-strength-dragging * {
|
body.lm-lora-strength-dragging * {
|
||||||
cursor: ew-resize !important;
|
cursor: ew-resize !important;
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import {
|
|||||||
getActiveLorasFromNode,
|
getActiveLorasFromNode,
|
||||||
collectActiveLorasFromChain,
|
collectActiveLorasFromChain,
|
||||||
updateConnectedTriggerWords,
|
updateConnectedTriggerWords,
|
||||||
|
updateDownstreamLoaders,
|
||||||
chainCallback,
|
chainCallback,
|
||||||
mergeLoras,
|
mergeLoras,
|
||||||
setupInputWidgetWithAutocomplete,
|
setupInputWidgetWithAutocomplete,
|
||||||
@@ -169,41 +170,3 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
// Helper function to find and update downstream Lora Loader nodes
|
|
||||||
function updateDownstreamLoaders(startNode, visited = new Set()) {
|
|
||||||
const nodeKey = getNodeKey(startNode);
|
|
||||||
if (!nodeKey || visited.has(nodeKey)) return;
|
|
||||||
visited.add(nodeKey);
|
|
||||||
|
|
||||||
// Check each output link
|
|
||||||
if (startNode.outputs) {
|
|
||||||
for (const output of startNode.outputs) {
|
|
||||||
if (output.links) {
|
|
||||||
for (const linkId of output.links) {
|
|
||||||
const link = getLinkFromGraph(startNode.graph, linkId);
|
|
||||||
if (link) {
|
|
||||||
const targetNode = startNode.graph?.getNodeById?.(link.target_id);
|
|
||||||
|
|
||||||
// If target is a Lora Loader, collect all active loras in the chain and update
|
|
||||||
if (
|
|
||||||
targetNode &&
|
|
||||||
targetNode.comfyClass === "Lora Loader (LoraManager)"
|
|
||||||
) {
|
|
||||||
const allActiveLoraNames =
|
|
||||||
collectActiveLorasFromChain(targetNode);
|
|
||||||
updateConnectedTriggerWords(targetNode, allActiveLoraNames);
|
|
||||||
}
|
|
||||||
// If target is another Lora Stacker, recursively check its outputs
|
|
||||||
else if (
|
|
||||||
targetNode &&
|
|
||||||
targetNode.comfyClass === "Lora Stacker (LoraManager)"
|
|
||||||
) {
|
|
||||||
updateDownstreamLoaders(targetNode, visited);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { createToggle, createArrowButton, createDragHandle, updateEntrySelection, createExpandButton, updateExpandButtonState } from "./loras_widget_components.js";
|
import { createToggle, createArrowButton, createDragHandle, updateEntrySelection, createExpandButton, updateExpandButtonState, createLockButton, updateLockButtonState } from "./loras_widget_components.js";
|
||||||
import {
|
import {
|
||||||
parseLoraValue,
|
parseLoraValue,
|
||||||
formatLoraValue,
|
formatLoraValue,
|
||||||
@@ -27,6 +27,9 @@ export function addLorasWidget(node, name, opts, callback) {
|
|||||||
// Set initial height using CSS variables approach
|
// Set initial height using CSS variables approach
|
||||||
const defaultHeight = 200;
|
const defaultHeight = 200;
|
||||||
|
|
||||||
|
// Check if this is a randomizer node (lock button instead of drag handle)
|
||||||
|
const isRandomizerNode = opts?.isRandomizerNode === true;
|
||||||
|
|
||||||
// Initialize default value
|
// Initialize default value
|
||||||
const defaultValue = opts?.defaultVal || [];
|
const defaultValue = opts?.defaultVal || [];
|
||||||
const onSelectionChange = typeof opts?.onSelectionChange === "function"
|
const onSelectionChange = typeof opts?.onSelectionChange === "function"
|
||||||
@@ -255,32 +258,52 @@ export function addLorasWidget(node, name, opts, callback) {
|
|||||||
const loraEl = document.createElement("div");
|
const loraEl = document.createElement("div");
|
||||||
loraEl.className = "lm-lora-entry";
|
loraEl.className = "lm-lora-entry";
|
||||||
|
|
||||||
// Store lora name and active state in dataset for selection
|
// Store lora name, active state, and locked state in dataset
|
||||||
loraEl.dataset.loraName = name;
|
loraEl.dataset.loraName = name;
|
||||||
loraEl.dataset.active = active ? "true" : "false";
|
loraEl.dataset.active = active ? "true" : "false";
|
||||||
|
loraEl.dataset.locked = (loraData.locked || false) ? "true" : "false";
|
||||||
|
|
||||||
// Add click handler for selection
|
// Add click handler for selection
|
||||||
loraEl.addEventListener('click', (e) => {
|
loraEl.addEventListener('click', (e) => {
|
||||||
// Skip if clicking on interactive elements
|
// Skip if clicking on interactive elements
|
||||||
if (e.target.closest('.lm-lora-toggle') ||
|
if (e.target.closest('.lm-lora-toggle') ||
|
||||||
e.target.closest('input') ||
|
e.target.closest('input') ||
|
||||||
e.target.closest('.lm-lora-arrow') ||
|
e.target.closest('.lm-lora-arrow') ||
|
||||||
e.target.closest('.lm-lora-drag-handle') ||
|
e.target.closest('.lm-lora-drag-handle') ||
|
||||||
|
e.target.closest('.lm-lora-lock-button') ||
|
||||||
e.target.closest('.lm-lora-expand-button')) {
|
e.target.closest('.lm-lora-expand-button')) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
selectLora(name);
|
selectLora(name);
|
||||||
container.focus(); // Focus container for keyboard events
|
container.focus(); // Focus container for keyboard events
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create drag handle for reordering
|
// Conditionally create drag handle OR lock button
|
||||||
const dragHandle = createDragHandle();
|
let dragHandleOrLockButton;
|
||||||
|
|
||||||
// Initialize reorder drag functionality
|
if (isRandomizerNode) {
|
||||||
initReorderDrag(dragHandle, name, widget, renderLoras);
|
// For randomizer node, show lock button instead of drag handle
|
||||||
|
const isLocked = loraData.locked || false;
|
||||||
|
dragHandleOrLockButton = createLockButton(isLocked, (newLocked) => {
|
||||||
|
// Update this lora's locked state
|
||||||
|
const lorasData = parseLoraValue(widget.value);
|
||||||
|
const loraIndex = lorasData.findIndex(l => l.name === name);
|
||||||
|
|
||||||
|
if (loraIndex >= 0) {
|
||||||
|
lorasData[loraIndex].locked = newLocked;
|
||||||
|
const newValue = formatLoraValue(lorasData);
|
||||||
|
updateWidgetValue(newValue);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// For other nodes, show drag handle
|
||||||
|
dragHandleOrLockButton = createDragHandle();
|
||||||
|
// Initialize reorder drag functionality
|
||||||
|
initReorderDrag(dragHandleOrLockButton, name, widget, renderLoras);
|
||||||
|
}
|
||||||
|
|
||||||
// Create toggle for this lora
|
// Create toggle for this lora
|
||||||
const toggle = createToggle(active, (newActive) => {
|
const toggle = createToggle(active, (newActive) => {
|
||||||
@@ -481,8 +504,8 @@ export function addLorasWidget(node, name, opts, callback) {
|
|||||||
// Assemble entry
|
// Assemble entry
|
||||||
const leftSection = document.createElement("div");
|
const leftSection = document.createElement("div");
|
||||||
leftSection.className = "lm-lora-entry-left";
|
leftSection.className = "lm-lora-entry-left";
|
||||||
|
|
||||||
leftSection.appendChild(dragHandle); // Add drag handle first
|
leftSection.appendChild(dragHandleOrLockButton); // Add drag handle or lock button first
|
||||||
leftSection.appendChild(toggle);
|
leftSection.appendChild(toggle);
|
||||||
leftSection.appendChild(expandButton); // Add expand button
|
leftSection.appendChild(expandButton); // Add expand button
|
||||||
leftSection.appendChild(nameEl);
|
leftSection.appendChild(nameEl);
|
||||||
@@ -685,16 +708,17 @@ export function addLorasWidget(node, name, opts, callback) {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
// Apply existing clip strength values and transfer them to the new value
|
// Apply existing clip strength values and transfer them to the new value
|
||||||
const updatedValue = uniqueValue.map(lora => {
|
const updatedValue = uniqueValue.map(lora => {
|
||||||
// For new loras, default clip strength to model strength and expanded to false
|
// For new loras, default clip strength to model strength and expanded to false
|
||||||
// unless clipStrength is already different from strength
|
// unless clipStrength is already different from strength
|
||||||
const clipStrength = lora.clipStrength || lora.strength;
|
const clipStrength = lora.clipStrength || lora.strength;
|
||||||
return {
|
return {
|
||||||
...lora,
|
...lora,
|
||||||
clipStrength: clipStrength,
|
clipStrength: clipStrength,
|
||||||
expanded: lora.hasOwnProperty('expanded') ?
|
expanded: lora.hasOwnProperty('expanded') ?
|
||||||
lora.expanded :
|
lora.expanded :
|
||||||
Number(clipStrength) !== Number(lora.strength)
|
Number(clipStrength) !== Number(lora.strength),
|
||||||
|
locked: lora.hasOwnProperty('locked') ? lora.locked : false // Initialize locked to false if not present
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ export function updateToggleStyle(toggleEl, active) {
|
|||||||
export function createArrowButton(direction, onClick) {
|
export function createArrowButton(direction, onClick) {
|
||||||
const button = document.createElement("div");
|
const button = document.createElement("div");
|
||||||
button.className = `lm-lora-arrow lm-lora-arrow-${direction}`;
|
button.className = `lm-lora-arrow lm-lora-arrow-${direction}`;
|
||||||
button.textContent = direction === "left" ? "◀" : "▶";
|
button.innerHTML = direction === "left"
|
||||||
|
? `<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="m15 18-6-6 6-6"/></svg>`
|
||||||
|
: `<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="m9 18 6-6-6-6"/></svg>`;
|
||||||
|
|
||||||
button.addEventListener("click", (e) => {
|
button.addEventListener("click", (e) => {
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
@@ -36,7 +38,7 @@ export function createArrowButton(direction, onClick) {
|
|||||||
export function createDragHandle() {
|
export function createDragHandle() {
|
||||||
const handle = document.createElement("div");
|
const handle = document.createElement("div");
|
||||||
handle.className = "lm-lora-drag-handle";
|
handle.className = "lm-lora-drag-handle";
|
||||||
handle.innerHTML = "≡";
|
handle.innerHTML = `<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="9" cy="12" r="1"/><circle cx="9" cy="5" r="1"/><circle cx="9" cy="19" r="1"/><circle cx="15" cy="12" r="1"/><circle cx="15" cy="5" r="1"/><circle cx="15" cy="19" r="1"/></svg>`;
|
||||||
handle.title = "Drag to reorder LoRA";
|
handle.title = "Drag to reorder LoRA";
|
||||||
return handle;
|
return handle;
|
||||||
}
|
}
|
||||||
@@ -102,10 +104,42 @@ export function createExpandButton(isExpanded, onClick) {
|
|||||||
// Helper function to update expand button state
|
// Helper function to update expand button state
|
||||||
export function updateExpandButtonState(button, isExpanded) {
|
export function updateExpandButtonState(button, isExpanded) {
|
||||||
if (isExpanded) {
|
if (isExpanded) {
|
||||||
button.innerHTML = "▼"; // Down arrow for expanded
|
button.innerHTML = `<svg width="10" height="10" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="m6 9 6 6 6-6"/></svg>`;
|
||||||
button.title = "Collapse clip controls";
|
button.title = "Collapse clip controls";
|
||||||
} else {
|
} else {
|
||||||
button.innerHTML = "▶"; // Right arrow for collapsed
|
button.innerHTML = `<svg width="10" height="10" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="m9 18 6-6-6-6"/></svg>`;
|
||||||
button.title = "Expand clip controls";
|
button.title = "Expand clip controls";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Function to create lock button
|
||||||
|
export function createLockButton(isLocked, onChange) {
|
||||||
|
const button = document.createElement("button");
|
||||||
|
button.className = "lm-lora-lock-button";
|
||||||
|
button.type = "button";
|
||||||
|
button.tabIndex = -1;
|
||||||
|
|
||||||
|
// Set icon based on locked state
|
||||||
|
updateLockButtonState(button, isLocked);
|
||||||
|
|
||||||
|
button.addEventListener("click", (e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
e.stopPropagation();
|
||||||
|
onChange(!isLocked);
|
||||||
|
});
|
||||||
|
|
||||||
|
return button;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to update lock button state
|
||||||
|
export function updateLockButtonState(button, isLocked) {
|
||||||
|
if (isLocked) {
|
||||||
|
button.innerHTML = `<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect width="18" height="11" x="3" y="11" rx="2" ry="2"></rect><path d="M7 11V7a5 5 0 0 1 10 0v4"></path></svg>`;
|
||||||
|
button.title = "Unlock this LoRA (allow re-rolling)";
|
||||||
|
button.classList.add("lm-lora-lock-button--locked");
|
||||||
|
} else {
|
||||||
|
button.innerHTML = `<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect width="18" height="11" x="3" y="11" rx="2" ry="2"></rect><path d="M7 11V7a5 5 0 0 1 9.9-1"></path></svg>`;
|
||||||
|
button.title = "Lock this LoRA (prevent re-rolling)";
|
||||||
|
button.classList.remove("lm-lora-lock-button--locked");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -239,6 +239,7 @@ app.registerExtension({
|
|||||||
|
|
||||||
// Handle trigger word updates from Python
|
// Handle trigger word updates from Python
|
||||||
handleTriggerWordUpdate(id, graphId, message) {
|
handleTriggerWordUpdate(id, graphId, message) {
|
||||||
|
console.log('trigger word update: ', id, graphId, message);
|
||||||
const node = getNodeFromGraph(graphId, id);
|
const node = getNodeFromGraph(graphId, id);
|
||||||
if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") {
|
if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") {
|
||||||
console.warn("Node not found or not a TriggerWordToggle:", id);
|
console.warn("Node not found or not a TriggerWordToggle:", id);
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ export function getConnectedInputStackers(node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
|
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
|
||||||
if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") {
|
if (sourceNode && (sourceNode.comfyClass === "Lora Stacker (LoraManager)" || sourceNode.comfyClass === "Lora Randomizer (LoraManager)")) {
|
||||||
connectedStackers.push(sourceNode);
|
connectedStackers.push(sourceNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -274,9 +274,13 @@ export function getConnectedTriggerToggleNodes(node) {
|
|||||||
export function getActiveLorasFromNode(node) {
|
export function getActiveLorasFromNode(node) {
|
||||||
const activeLoraNames = new Set();
|
const activeLoraNames = new Set();
|
||||||
|
|
||||||
// For lorasWidget style entries (array of objects)
|
let lorasWidget = node.lorasWidget;
|
||||||
if (node.lorasWidget && node.lorasWidget.value) {
|
if (!lorasWidget && node.widgets) {
|
||||||
node.lorasWidget.value.forEach(lora => {
|
lorasWidget = node.widgets.find(w => w.name === 'loras');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lorasWidget && lorasWidget.value) {
|
||||||
|
lorasWidget.value.forEach(lora => {
|
||||||
if (lora.active) {
|
if (lora.active) {
|
||||||
activeLoraNames.add(lora.name);
|
activeLoraNames.add(lora.name);
|
||||||
}
|
}
|
||||||
@@ -324,6 +328,8 @@ export function updateConnectedTriggerWords(node, loraNames) {
|
|||||||
.map((connectedNode) => getNodeReference(connectedNode))
|
.map((connectedNode) => getNodeReference(connectedNode))
|
||||||
.filter((reference) => reference !== null);
|
.filter((reference) => reference !== null);
|
||||||
|
|
||||||
|
console.log('node ids: ', nodeIds, loraNames);
|
||||||
|
|
||||||
if (nodeIds.length === 0) {
|
if (nodeIds.length === 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -467,3 +473,78 @@ export function forwardMiddleMouseToCanvas(container) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get connected Lora Pool node from pool_config input
|
||||||
|
export function getConnectedPoolConfigNode(node) {
|
||||||
|
if (!node?.inputs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const input of node.inputs) {
|
||||||
|
if (input.name !== "pool_config" || !input.link) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const link = getLinkFromGraph(node.graph, input.link);
|
||||||
|
if (!link) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const sourceNode = node.graph?.getNodeById?.(link.origin_id);
|
||||||
|
if (sourceNode && sourceNode.comfyClass === "Lora Pool (LoraManager)") {
|
||||||
|
return sourceNode;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get pool config widget value from connected Lora Pool node
|
||||||
|
export function getPoolConfigFromConnectedNode(node) {
|
||||||
|
const poolNode = getConnectedPoolConfigNode(node);
|
||||||
|
if (!poolNode) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const poolWidget = poolNode.widgets?.find(w => w.name === "pool_config");
|
||||||
|
return poolWidget?.value || null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to find and update downstream Lora Loader nodes
|
||||||
|
export function updateDownstreamLoaders(startNode, visited = new Set()) {
|
||||||
|
const nodeKey = getNodeKey(startNode);
|
||||||
|
if (!nodeKey || visited.has(nodeKey)) return;
|
||||||
|
visited.add(nodeKey);
|
||||||
|
|
||||||
|
// Check each output link
|
||||||
|
if (startNode.outputs) {
|
||||||
|
for (const output of startNode.outputs) {
|
||||||
|
if (output.links) {
|
||||||
|
for (const linkId of output.links) {
|
||||||
|
const link = getLinkFromGraph(startNode.graph, linkId);
|
||||||
|
if (link) {
|
||||||
|
const targetNode = startNode.graph?.getNodeById?.(link.target_id);
|
||||||
|
|
||||||
|
// If target is a Lora Loader, collect all active loras in the chain and update
|
||||||
|
if (
|
||||||
|
targetNode &&
|
||||||
|
targetNode.comfyClass === "Lora Loader (LoraManager)"
|
||||||
|
) {
|
||||||
|
const allActiveLoraNames =
|
||||||
|
collectActiveLorasFromChain(targetNode);
|
||||||
|
updateConnectedTriggerWords(targetNode, allActiveLoraNames);
|
||||||
|
}
|
||||||
|
// If target is another Lora Stacker or Lora Randomizer, recursively check its outputs
|
||||||
|
else if (
|
||||||
|
targetNode &&
|
||||||
|
(targetNode.comfyClass === "Lora Stacker (LoraManager)" ||
|
||||||
|
targetNode.comfyClass === "Lora Randomizer (LoraManager)")
|
||||||
|
) {
|
||||||
|
updateDownstreamLoaders(targetNode, visited);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user