diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..b0b76b75 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,163 @@ +# AGENTS.md + +This file provides guidance for agentic coding assistants working in this repository. + +## Development Commands + +### Backend Development + +```bash +# Install dependencies +pip install -r requirements.txt +pip install -r requirements-dev.txt + +# Run standalone server (port 8188 by default) +python standalone.py --port 8188 + +# Run all backend tests +pytest + +# Run specific test file +pytest tests/test_recipes.py + +# Run specific test function +pytest tests/test_recipes.py::test_function_name + +# Run backend tests with coverage +COVERAGE_FILE=coverage/backend/.coverage pytest \ + --cov=py \ + --cov=standalone \ + --cov-report=term-missing \ + --cov-report=html:coverage/backend/html \ + --cov-report=xml:coverage/backend/coverage.xml \ + --cov-report=json:coverage/backend/coverage.json +``` + +### Frontend Development + +```bash +# Install frontend dependencies +npm install + +# Run frontend tests +npm test + +# Run frontend tests in watch mode +npm run test:watch + +# Run frontend tests with coverage +npm run test:coverage +``` + +## Python Code Style + +### Imports + +- Use `from __future__ import annotations` for forward references in type hints +- Group imports: standard library, third-party, local (separated by blank lines) +- Use absolute imports within `py/` package: `from ..services import X` +- Mock ComfyUI dependencies in tests using `tests/conftest.py` patterns + +### Formatting & Types + +- PEP 8 with 4-space indentation +- Type hints required for function signatures and class attributes +- Use `TYPE_CHECKING` guard for type-checking-only imports +- Prefer dataclasses for simple data containers +- Use `Optional[T]` for nullable types, `Union[T, None]` only when necessary + +### Naming Conventions + +- Files: `snake_case.py` (e.g., `model_scanner.py`, `lora_service.py`) +- Classes: `PascalCase` (e.g., `ModelScanner`, `LoraService`) +- Functions/variables: `snake_case` (e.g., `get_instance`, `model_type`) +- Constants: `UPPER_SNAKE_CASE` (e.g., `VALID_LORA_TYPES`) +- Private members: `_single_underscore` (protected), `__double_underscore` (name-mangled) + +### Error Handling + +- Use `logging.getLogger(__name__)` for module-level loggers +- Define custom exceptions in `py/services/errors.py` +- Use `asyncio.Lock` for thread-safe singleton patterns +- Raise specific exceptions with descriptive messages +- Log errors at appropriate levels (DEBUG, INFO, WARNING, ERROR, CRITICAL) + +### Async Patterns + +- Use `async def` for I/O-bound operations +- Mark async tests with `@pytest.mark.asyncio` +- Use `async with` for context managers +- Singleton pattern with class-level locks: see `ModelScanner.get_instance()` +- Use `aiohttp.web.Response` for HTTP responses + +### Testing Patterns + +- Use `pytest` with `--import-mode=importlib` +- Fixtures in `tests/conftest.py` handle ComfyUI mocking +- Use `@pytest.mark.no_settings_dir_isolation` for tests needing real paths +- Test files: `tests/test_*.py` +- Use `tmp_path_factory` for temporary directory isolation + +## JavaScript Code Style + +### Imports & Modules + +- ES modules with `import`/`export` +- Use `import { app } from "../../scripts/app.js"` for ComfyUI integration +- Export named functions/classes: `export function foo() {}` +- Widget files use `*_widget.js` suffix + +### Naming & Formatting + +- camelCase for functions, variables, object properties +- PascalCase for classes/constructors +- Constants: `UPPER_SNAKE_CASE` (e.g., `CONVERTED_TYPE`) +- Files: `snake_case.js` or `kebab-case.js` +- 2-space indentation preferred (follow existing file conventions) + +### Widget Development + +- Use `app.registerExtension()` to register ComfyUI extensions +- Use `node.addDOMWidget(name, type, element, options)` for custom widgets +- Event handlers attached via `addEventListener` or widget callbacks +- See `web/comfyui/utils.js` for shared utilities + +## Architecture Patterns + +### Service Layer + +- Use `ServiceRegistry` singleton for dependency injection +- Services follow singleton pattern via `get_instance()` class method +- Separate scanners (discovery) from services (business logic) +- Handlers in `py/routes/handlers/` implement route logic + +### Model Types + +- BaseModelService is abstract base for LoRA, Checkpoint, Embedding services +- ModelScanner provides file discovery and hash-based deduplication +- Persistent cache in SQLite via `PersistentModelCache` +- Metadata sync from CivitAI/CivArchive via `MetadataSyncService` + +### Routes & Handlers + +- Route registrars organize endpoints by domain: `ModelRouteRegistrar`, etc. +- Handlers are pure functions taking dependencies as parameters +- Use `WebSocketManager` for real-time progress updates +- Return `aiohttp.web.json_response` or `web.Response` + +### Recipe System + +- Base metadata in `py/recipes/base.py` +- Enrichment adds model metadata: `RecipeEnrichmentService` +- Parsers for different formats in `py/recipes/parsers/` + +## Important Notes + +- Always use English for comments (per copilot-instructions.md) +- Dual mode: ComfyUI plugin (uses folder_paths) vs standalone (reads settings.json) +- Detection: `os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"` +- Settings auto-saved in user directory or portable mode +- WebSocket broadcasts for real-time updates (downloads, scans) +- Symlink handling requires normalized paths +- API endpoints follow `/loras/*`, `/checkpoints/*`, `/embeddings/*` patterns +- Run `python scripts/sync_translation_keys.py` after UI string updates diff --git a/py/nodes/lora_pool.py b/py/nodes/lora_pool.py index 19cee8ea..aec2ed7e 100644 --- a/py/nodes/lora_pool.py +++ b/py/nodes/lora_pool.py @@ -20,7 +20,7 @@ class LoraPoolNode: """ NAME = "Lora Pool (LoraManager)" - CATEGORY = "Lora Manager/pools" + CATEGORY = "Lora Manager/randomizer" @classmethod def INPUT_TYPES(cls): @@ -85,10 +85,3 @@ class LoraPoolNode: }, "preview": {"matchCount": 0, "lastUpdated": 0}, } - - -# Node class mappings for ComfyUI -NODE_CLASS_MAPPINGS = {"LoraPoolNode": LoraPoolNode} - -# Display name mappings -NODE_DISPLAY_NAME_MAPPINGS = {"LoraPoolNode": "LoRA Pool (Filter)"} diff --git a/py/nodes/lora_randomizer.py b/py/nodes/lora_randomizer.py index 7eb79ab5..83990410 100644 --- a/py/nodes/lora_randomizer.py +++ b/py/nodes/lora_randomizer.py @@ -34,7 +34,7 @@ class LoraRandomizerNode: } RETURN_TYPES = ("LORA_STACK",) - RETURN_NAMES = ("lora_stack",) + RETURN_NAMES = ("LORA_STACK",) FUNCTION = "randomize" OUTPUT_NODE = False @@ -53,9 +53,6 @@ class LoraRandomizerNode: """ from ..services.service_registry import ServiceRegistry - # Get lora scanner to access available loras - scanner = await ServiceRegistry.get_lora_scanner() - # Parse randomizer settings count_mode = randomizer_config.get("count_mode", "range") count_fixed = randomizer_config.get("count_fixed", 5) @@ -68,6 +65,93 @@ class LoraRandomizerNode: clip_strength_max = randomizer_config.get("clip_strength_max", 1.0) roll_mode = randomizer_config.get("roll_mode", "frontend") + # Get lora scanner to access available loras + scanner = await ServiceRegistry.get_lora_scanner() + + # Backend roll mode: execute with input loras, return new random to UI + if roll_mode == "backend": + execution_stack = self._build_execution_stack_from_input(loras) + ui_loras = await self._generate_random_loras_for_ui( + scanner, randomizer_config, loras, pool_config + ) + logger.info( + f"[LoraRandomizerNode] Backend roll: executing with input, returning new random to UI" + ) + return {"result": (execution_stack,), "ui": {"loras": ui_loras}} + + # Frontend roll mode: use current behavior (random selection for both) + ui_loras = await self._generate_random_loras_for_ui( + scanner, randomizer_config, loras, pool_config + ) + execution_stack = self._build_execution_stack_from_input(ui_loras) + logger.info( + f"[LoraRandomizerNode] Frontend roll: executing with random selection" + ) + return {"result": (execution_stack,), "ui": {"loras": ui_loras}} + + def _build_execution_stack_from_input(self, loras): + """ + Build LORA_STACK tuple from input loras list for execution. + + Args: + loras: List of LoRA dicts with name, strength, clipStrength, active + + Returns: + List of tuples (lora_path, model_strength, clip_strength) + """ + lora_stack = [] + for lora in loras: + if not lora.get("active", False): + continue + + # Get file path + lora_path, trigger_words = get_lora_info(lora["name"]) + if not lora_path: + logger.warning( + f"[LoraRandomizerNode] Could not find path for LoRA: {lora['name']}" + ) + continue + + # Normalize path separators + lora_path = lora_path.replace("/", os.sep) + + # Extract strengths + model_strength = lora.get("strength", 1.0) + clip_strength = lora.get("clipStrength", model_strength) + + lora_stack.append((lora_path, model_strength, clip_strength)) + + logger.info( + f"[LoraRandomizerNode] Built execution stack with {len(lora_stack)} LoRAs" + ) + return lora_stack + + async def _generate_random_loras_for_ui( + self, scanner, randomizer_config, input_loras, pool_config=None + ): + """ + Generate new random loras for UI display. + + Args: + scanner: LoraScanner instance + randomizer_config: Dict with randomizer settings + input_loras: Current input loras (for extracting locked loras) + pool_config: Optional pool filters + + Returns: + List of LoRA dicts for UI display + """ + # Parse randomizer settings + count_mode = randomizer_config.get("count_mode", "range") + count_fixed = randomizer_config.get("count_fixed", 5) + count_min = randomizer_config.get("count_min", 3) + count_max = randomizer_config.get("count_max", 7) + model_strength_min = randomizer_config.get("model_strength_min", 0.0) + model_strength_max = randomizer_config.get("model_strength_max", 1.0) + use_same_clip_strength = randomizer_config.get("use_same_clip_strength", True) + clip_strength_min = randomizer_config.get("clip_strength_min", 0.0) + clip_strength_max = randomizer_config.get("clip_strength_max", 1.0) + # Determine target count if count_mode == "fixed": target_count = count_fixed @@ -75,11 +159,11 @@ class LoraRandomizerNode: target_count = random.randint(count_min, count_max) logger.info( - f"[LoraRandomizerNode] Target count: {target_count}, Roll mode: {roll_mode}" + f"[LoraRandomizerNode] Generating random LoRAs, target count: {target_count}" ) # Extract locked LoRAs from input - locked_loras = [lora for lora in loras if lora.get("locked", False)] + locked_loras = [lora for lora in input_loras if lora.get("locked", False)] locked_count = len(locked_loras) logger.info(f"[LoraRandomizerNode] Locked LoRAs: {locked_count}") @@ -106,8 +190,6 @@ class LoraRandomizerNode: ) # Calculate how many new LoRAs to select - # In frontend mode, if loras already has data, preserve unlocked ones if roll_mode requires - # For simplicity in backend mode, we regenerate all unlocked slots slots_needed = target_count - locked_count if slots_needed < 0: @@ -161,33 +243,10 @@ class LoraRandomizerNode: # Merge with locked LoRAs result_loras.extend(locked_loras) - logger.info(f"[LoraRandomizerNode] Final LoRA count: {len(result_loras)}") - - # Build LORA_STACK output - lora_stack = [] - for lora in result_loras: - if not lora.get("active", False): - continue - - # Get file path - lora_path, trigger_words = get_lora_info(lora["name"]) - if not lora_path: - logger.warning( - f"[LoraRandomizerNode] Could not find path for LoRA: {lora['name']}" - ) - continue - - # Normalize path separators - lora_path = lora_path.replace("/", os.sep) - - # Extract strengths - model_strength = lora.get("strength", 1.0) - clip_strength = lora.get("clipStrength", model_strength) - - lora_stack.append((lora_path, model_strength, clip_strength)) - - # Return format: result for workflow + ui for frontend display - return {"result": (lora_stack,), "ui": {"loras": result_loras}} + logger.info( + f"[LoraRandomizerNode] Final random LoRA count: {len(result_loras)}" + ) + return result_loras async def _apply_pool_filters(self, available_loras, pool_config, scanner): """ @@ -288,10 +347,3 @@ class LoraRandomizerNode: ] return available_loras - - -# Node class mappings for ComfyUI -NODE_CLASS_MAPPINGS = {"LoraRandomizerNode": LoraRandomizerNode} - -# Display name mappings -NODE_DISPLAY_NAME_MAPPINGS = {"LoraRandomizerNode": "LoRA Randomizer"} diff --git a/py/services/lora_service.py b/py/services/lora_service.py index 01a4759f..7beffecd 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -8,24 +8,27 @@ from ..config import config logger = logging.getLogger(__name__) + class LoraService(BaseModelService): """LoRA-specific service implementation""" - + def __init__(self, scanner, update_service=None): """Initialize LoRA service - + Args: scanner: LoRA scanner instance update_service: Optional service for remote update tracking. """ super().__init__("lora", scanner, LoraMetadata, update_service=update_service) - + async def format_response(self, lora_data: Dict) -> Dict: """Format LoRA data for API response""" return { "model_name": lora_data["model_name"], "file_name": lora_data["file_name"], - "preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")), + "preview_url": config.get_preview_static_url( + lora_data.get("preview_url", "") + ), "preview_nsfw_level": lora_data.get("preview_nsfw_level", 0), "base_model": lora_data.get("base_model", ""), "folder": lora_data["folder"], @@ -40,141 +43,170 @@ class LoraService(BaseModelService): "notes": lora_data.get("notes", ""), "favorite": lora_data.get("favorite", False), "update_available": bool(lora_data.get("update_available", False)), - "civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True) + "civitai": self.filter_civitai_data( + lora_data.get("civitai", {}), minimal=True + ), } - + async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: """Apply LoRA-specific filters""" # Handle first_letter filter for LoRAs - first_letter = kwargs.get('first_letter') + first_letter = kwargs.get("first_letter") if first_letter: data = self._filter_by_first_letter(data, first_letter) - + return data - + def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]: """Filter data by first letter of model name - + Special handling: - '#': Numbers (0-9) - '@': Special characters (not alphanumeric) - '漢': CJK characters """ filtered_data = [] - + for lora in data: - model_name = lora.get('model_name', '') + model_name = lora.get("model_name", "") if not model_name: continue - + first_char = model_name[0].upper() - - if letter == '#' and first_char.isdigit(): + + if letter == "#" and first_char.isdigit(): filtered_data.append(lora) - elif letter == '@' and not first_char.isalnum(): + elif letter == "@" and not first_char.isalnum(): # Special characters (not alphanumeric) filtered_data.append(lora) - elif letter == '漢' and self._is_cjk_character(first_char): + elif letter == "漢" and self._is_cjk_character(first_char): # CJK characters filtered_data.append(lora) elif letter.upper() == first_char: # Regular alphabet matching filtered_data.append(lora) - + return filtered_data - + def _is_cjk_character(self, char: str) -> bool: """Check if character is a CJK character""" # Define Unicode ranges for CJK characters cjk_ranges = [ - (0x4E00, 0x9FFF), # CJK Unified Ideographs - (0x3400, 0x4DBF), # CJK Unified Ideographs Extension A - (0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B - (0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C - (0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D - (0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E - (0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F - (0x30000, 0x3134F), # CJK Unified Ideographs Extension G - (0xF900, 0xFAFF), # CJK Compatibility Ideographs - (0x3300, 0x33FF), # CJK Compatibility - (0x3200, 0x32FF), # Enclosed CJK Letters and Months - (0x3100, 0x312F), # Bopomofo - (0x31A0, 0x31BF), # Bopomofo Extended - (0x3040, 0x309F), # Hiragana - (0x30A0, 0x30FF), # Katakana - (0x31F0, 0x31FF), # Katakana Phonetic Extensions - (0xAC00, 0xD7AF), # Hangul Syllables - (0x1100, 0x11FF), # Hangul Jamo - (0xA960, 0xA97F), # Hangul Jamo Extended-A - (0xD7B0, 0xD7FF), # Hangul Jamo Extended-B + (0x4E00, 0x9FFF), # CJK Unified Ideographs + (0x3400, 0x4DBF), # CJK Unified Ideographs Extension A + (0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B + (0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C + (0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D + (0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E + (0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F + (0x30000, 0x3134F), # CJK Unified Ideographs Extension G + (0xF900, 0xFAFF), # CJK Compatibility Ideographs + (0x3300, 0x33FF), # CJK Compatibility + (0x3200, 0x32FF), # Enclosed CJK Letters and Months + (0x3100, 0x312F), # Bopomofo + (0x31A0, 0x31BF), # Bopomofo Extended + (0x3040, 0x309F), # Hiragana + (0x30A0, 0x30FF), # Katakana + (0x31F0, 0x31FF), # Katakana Phonetic Extensions + (0xAC00, 0xD7AF), # Hangul Syllables + (0x1100, 0x11FF), # Hangul Jamo + (0xA960, 0xA97F), # Hangul Jamo Extended-A + (0xD7B0, 0xD7FF), # Hangul Jamo Extended-B ] - + code_point = ord(char) return any(start <= code_point <= end for start, end in cjk_ranges) - + # LoRA-specific methods async def get_letter_counts(self) -> Dict[str, int]: """Get count of LoRAs for each letter of the alphabet""" cache = await self.scanner.get_cached_data() data = cache.raw_data - + # Define letter categories letters = { - '#': 0, # Numbers - 'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0, - 'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0, - 'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0, - 'Y': 0, 'Z': 0, - '@': 0, # Special characters - '漢': 0 # CJK characters + "#": 0, # Numbers + "A": 0, + "B": 0, + "C": 0, + "D": 0, + "E": 0, + "F": 0, + "G": 0, + "H": 0, + "I": 0, + "J": 0, + "K": 0, + "L": 0, + "M": 0, + "N": 0, + "O": 0, + "P": 0, + "Q": 0, + "R": 0, + "S": 0, + "T": 0, + "U": 0, + "V": 0, + "W": 0, + "X": 0, + "Y": 0, + "Z": 0, + "@": 0, # Special characters + "漢": 0, # CJK characters } - + # Count models for each letter for lora in data: - model_name = lora.get('model_name', '') + model_name = lora.get("model_name", "") if not model_name: continue - + first_char = model_name[0].upper() - + if first_char.isdigit(): - letters['#'] += 1 + letters["#"] += 1 elif first_char in letters: letters[first_char] += 1 elif self._is_cjk_character(first_char): - letters['漢'] += 1 + letters["漢"] += 1 elif not first_char.isalnum(): - letters['@'] += 1 - + letters["@"] += 1 + return letters - + async def get_lora_trigger_words(self, lora_name: str) -> List[str]: """Get trigger words for a specific LoRA file""" cache = await self.scanner.get_cached_data() - + for lora in cache.raw_data: - if lora['file_name'] == lora_name: - civitai_data = lora.get('civitai', {}) - return civitai_data.get('trainedWords', []) - + if lora["file_name"] == lora_name: + civitai_data = lora.get("civitai", {}) + return civitai_data.get("trainedWords", []) + return [] - - async def get_lora_usage_tips_by_relative_path(self, relative_path: str) -> Optional[str]: + + async def get_lora_usage_tips_by_relative_path( + self, relative_path: str + ) -> Optional[str]: """Get usage tips for a LoRA by its relative path""" cache = await self.scanner.get_cached_data() - + for lora in cache.raw_data: - file_path = lora.get('file_path', '') + file_path = lora.get("file_path", "") if file_path: # Convert to forward slashes and extract relative path - file_path_normalized = file_path.replace('\\', '/') - relative_path = relative_path.replace('\\', '/') + file_path_normalized = file_path.replace("\\", "/") + relative_path = relative_path.replace("\\", "/") # Find the relative path part by looking for the relative_path in the full path - if file_path_normalized.endswith(relative_path) or relative_path in file_path_normalized: - return lora.get('usage_tips', '') - + if ( + file_path_normalized.endswith(relative_path) + or relative_path in file_path_normalized + ): + return lora.get("usage_tips", "") + return None - + def find_duplicate_hashes(self) -> Dict: """Find LoRAs with duplicate SHA256 hashes""" return self.scanner._hash_index.get_duplicate_hashes() @@ -192,7 +224,7 @@ class LoraService(BaseModelService): clip_strength_min: float = 0.0, clip_strength_max: float = 1.0, locked_loras: Optional[List[Dict]] = None, - pool_config: Optional[Dict] = None + pool_config: Optional[Dict] = None, ) -> List[Dict]: """ Get random LoRAs with specified strength ranges. @@ -235,10 +267,9 @@ class LoraService(BaseModelService): locked_loras = locked_loras[:count] # Filter out locked LoRAs from available pool - locked_names = {lora['name'] for lora in locked_loras} + locked_names = {lora["name"] for lora in locked_loras} available_pool = [ - l for l in available_loras - if l['model_name'] not in locked_names + l for l in available_loras if l["file_name"] not in locked_names ] # Ensure we don't try to select more than available @@ -253,9 +284,7 @@ class LoraService(BaseModelService): # Generate random strengths for selected LoRAs result_loras = [] for lora in selected: - model_str = round( - random.uniform(model_strength_min, model_strength_max), 2 - ) + model_str = round(random.uniform(model_strength_min, model_strength_max), 2) if use_same_clip_strength: clip_str = model_str @@ -264,21 +293,25 @@ class LoraService(BaseModelService): random.uniform(clip_strength_min, clip_strength_max), 2 ) - result_loras.append({ - 'name': lora['model_name'], - 'strength': model_str, - 'clipStrength': clip_str, - 'active': True, - 'expanded': abs(model_str - clip_str) > 0.001, - 'locked': False - }) + result_loras.append( + { + "name": lora["file_name"], + "strength": model_str, + "clipStrength": clip_str, + "active": True, + "expanded": abs(model_str - clip_str) > 0.001, + "locked": False, + } + ) # Merge with locked LoRAs result_loras.extend(locked_loras) return result_loras - async def _apply_pool_filters(self, available_loras: List[Dict], pool_config: Dict) -> List[Dict]: + async def _apply_pool_filters( + self, available_loras: List[Dict], pool_config: Dict + ) -> List[Dict]: """ Apply pool_config filters to available LoRAs. @@ -292,26 +325,26 @@ class LoraService(BaseModelService): from .model_query import FilterCriteria # Extract filter parameters from pool_config - selected_base_models = pool_config.get('selected_base_models', []) - include_tags = pool_config.get('include_tags', []) - exclude_tags = pool_config.get('exclude_tags', []) - include_folders = pool_config.get('include_folders', []) - exclude_folders = pool_config.get('exclude_folders', []) - no_credit_required = pool_config.get('no_credit_required', False) - allow_selling = pool_config.get('allow_selling', False) + selected_base_models = pool_config.get("selected_base_models", []) + include_tags = pool_config.get("include_tags", []) + exclude_tags = pool_config.get("exclude_tags", []) + include_folders = pool_config.get("include_folders", []) + exclude_folders = pool_config.get("exclude_folders", []) + no_credit_required = pool_config.get("no_credit_required", False) + allow_selling = pool_config.get("allow_selling", False) # Build tag filters dict tag_filters = {} for tag in include_tags: - tag_filters[tag] = 'include' + tag_filters[tag] = "include" for tag in exclude_tags: - tag_filters[tag] = 'exclude' + tag_filters[tag] = "exclude" # Build folder filter if include_folders or exclude_folders: filtered = [] for lora in available_loras: - folder = lora.get('folder', '') + folder = lora.get("folder", "") # Check exclude folders first excluded = False @@ -340,8 +373,9 @@ class LoraService(BaseModelService): # Apply base model filter if selected_base_models: available_loras = [ - lora for lora in available_loras - if lora.get('base_model') in selected_base_models + lora + for lora in available_loras + if lora.get("base_model") in selected_base_models ] # Apply tag filters @@ -352,14 +386,17 @@ class LoraService(BaseModelService): # Apply license filters if no_credit_required: available_loras = [ - lora for lora in available_loras - if not lora.get('civitai', {}).get('allowNoCredit', True) + lora + for lora in available_loras + if not lora.get("civitai", {}).get("allowNoCredit", True) ] if allow_selling: available_loras = [ - lora for lora in available_loras - if lora.get('civitai', {}).get('allowCommercialUse', ['None'])[0] != 'None' + lora + for lora in available_loras + if lora.get("civitai", {}).get("allowCommercialUse", ["None"])[0] + != "None" ] return available_loras diff --git a/vue-widgets/src/components/LoraRandomizerWidget.vue b/vue-widgets/src/components/LoraRandomizerWidget.vue index 75e4a5f3..f31132d1 100644 --- a/vue-widgets/src/components/LoraRandomizerWidget.vue +++ b/vue-widgets/src/components/LoraRandomizerWidget.vue @@ -32,12 +32,12 @@ import { onMounted } from 'vue' import LoraRandomizerSettingsView from './lora-randomizer/LoraRandomizerSettingsView.vue' import { useLoraRandomizerState } from '../composables/useLoraRandomizerState' -import type { ComponentWidget, RandomizerConfig } from '../composables/types' +import type { ComponentWidget, RandomizerConfig, LoraEntry } from '../composables/types' // Props const props = defineProps<{ widget: ComponentWidget - node: { id: number } + node: { id: number; inputs?: any[]; widgets?: any[]; graph?: any } }>() // State management @@ -48,13 +48,15 @@ const handleRoll = async () => { try { console.log('[LoraRandomizerWidget] Roll button clicked') - // Get pool config from connected input (if any) - // This would need to be passed from the node's pool_config input - const poolConfig = null // TODO: Get from node input if connected + // Get pool config from connected pool_config input + const poolConfig = (props.node as any).getPoolConfig?.() || null // Get locked loras from the loras widget - // This would need to be retrieved from the loras widget on the node - const lockedLoras: any[] = [] // TODO: Get from loras widget + const lorasWidget = props.node.widgets?.find((w: any) => w.name === "loras") + const lockedLoras: LoraEntry[] = (lorasWidget?.value || []).filter((lora: LoraEntry) => lora.locked === true) + + console.log('[LoraRandomizerWidget] Pool config:', poolConfig) + console.log('[LoraRandomizerWidget] Locked loras:', lockedLoras) // Call API to get random loras const randomLoras = await state.rollLoras(poolConfig, lockedLoras) diff --git a/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue b/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue index 8719b715..224dcb62 100644 --- a/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue +++ b/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue @@ -154,8 +154,10 @@ :disabled="rollMode !== 'frontend' || isRolling" @click="$emit('roll')" > - 🎲 Roll - Rolling... + + + Roll +
@@ -329,7 +331,7 @@ defineEmits<{ } .roll-button { - padding: 6px 16px; + padding: 8px 16px; background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%); border: none; border-radius: 4px; @@ -339,6 +341,9 @@ defineEmits<{ cursor: pointer; transition: all 0.2s; white-space: nowrap; + display: flex; + align-items: center; + justify-content: center; } .roll-button:hover:not(:disabled) { @@ -356,4 +361,10 @@ defineEmits<{ cursor: not-allowed; background: linear-gradient(135deg, #52525b 0%, #3f3f46 100%); } + +.roll-button__content { + display: inline-flex; + align-items: center; + gap: 6px; +} diff --git a/vue-widgets/src/main.ts b/vue-widgets/src/main.ts index aa5add04..1fd31f1f 100644 --- a/vue-widgets/src/main.ts +++ b/vue-widgets/src/main.ts @@ -6,6 +6,8 @@ import type { LoraPoolConfig, LegacyLoraPoolConfig, RandomizerConfig } from './c // @ts-ignore - ComfyUI external module import { app } from '../../../scripts/app.js' +// @ts-ignore +import { getPoolConfigFromConnectedNode, getActiveLorasFromNode, updateConnectedTriggerWords, updateDownstreamLoaders } from '../../web/comfyui/utils.js' const vueApps = new Map() @@ -119,6 +121,9 @@ function createLoraRandomizerWidget(node) { internalValue = v } + // Add method to get pool config from connected node + node.getPoolConfig = () => getPoolConfigFromConnectedNode(node) + // Handle roll event from Vue component widget.onRoll = (randomLoras: any[]) => { console.log('[createLoraRandomizerWidget] Roll event received:', randomLoras) @@ -181,10 +186,21 @@ app.registerExtension({ // @ts-ignore async LORAS(node: any) { if (!addLorasWidgetCache) { + // @ts-ignore const module = await import(/* @vite-ignore */ '../loras_widget.js') addLorasWidgetCache = module.addLorasWidget } - return addLorasWidgetCache(node, 'loras', {}, null) + // Check if this is a randomizer node to enable lock buttons + const isRandomizerNode = node.comfyClass === 'Lora Randomizer (LoraManager)' + + console.log(node) + + // For randomizer nodes, add a callback to update connected trigger words + const callback = isRandomizerNode ? (value: any) => { + updateDownstreamLoaders(node) + } : null + + return addLorasWidgetCache(node, 'loras', { isRandomizerNode }, callback) } } } diff --git a/web/comfyui/lm_styles.css b/web/comfyui/lm_styles.css index 4c7b2641..a66b07de 100644 --- a/web/comfyui/lm_styles.css +++ b/web/comfyui/lm_styles.css @@ -175,6 +175,20 @@ box-shadow: 0 0 0 1px rgba(66, 153, 225, 0.3) !important; } +.lm-lora-entry[data-selected="true"][data-locked="true"] { + background-color: rgba(66, 153, 225, 0.2) !important; + border-left: 3px solid rgba(245, 158, 11, 0.8) !important; + border-right: 1px solid rgba(66, 153, 225, 0.6) !important; + border-top: 1px solid rgba(66, 153, 225, 0.6) !important; + border-bottom: 1px solid rgba(66, 153, 225, 0.6) !important; + box-shadow: 0 0 0 1px rgba(66, 153, 225, 0.3), inset 0 0 20px rgba(245, 158, 11, 0.04) !important; +} + +.lm-lora-entry[data-selected="true"][data-locked="true"][data-active="false"] { + background-color: rgba(48, 42, 36, 0.5) !important; + border-left: 3px solid rgba(245, 158, 11, 0.6) !important; +} + .lm-lora-name { margin-left: 4px; flex: 1; @@ -236,7 +250,6 @@ justify-content: center; cursor: pointer; user-select: none; - font-size: 12px; color: rgba(226, 232, 240, 0.8); transition: all 0.2s ease; } @@ -246,6 +259,11 @@ transform: scale(1.2); } +.lm-lora-arrow svg { + width: 12px; + height: 12px; +} + .lm-lora-expand-button { width: 20px; height: 20px; @@ -254,7 +272,6 @@ justify-content: center; cursor: pointer; user-select: none; - font-size: 10px; color: rgba(226, 232, 240, 0.7); background-color: rgba(45, 55, 72, 0.3); border: 1px solid rgba(226, 232, 240, 0.2); @@ -262,7 +279,6 @@ transition: all 0.2s ease; flex-shrink: 0; padding: 0; - line-height: 1; box-sizing: border-box; } @@ -314,13 +330,17 @@ justify-content: center; cursor: grab; user-select: none; - font-size: 14px; color: rgba(226, 232, 240, 0.6); transition: all 0.2s ease; margin-right: 8px; flex-shrink: 0; } +.lm-lora-drag-handle svg { + width: 14px; + height: 14px; +} + .lm-lora-drag-handle:hover { color: rgba(226, 232, 240, 0.9); transform: scale(1.1); @@ -330,6 +350,83 @@ cursor: grabbing; } +.lm-lora-lock-button { + width: 20px; + height: 20px; + display: flex; + align-items: center; + justify-content: center; + cursor: pointer; + user-select: none; + background-color: rgba(45, 55, 72, 0.3); + border: 1px solid rgba(226, 232, 240, 0.2); + border-radius: 3px; + transition: all 0.2s ease; + flex-shrink: 0; + padding: 0; + box-sizing: border-box; + margin-right: 8px; + color: rgba(226, 232, 240, 0.8); +} + +.lm-lora-lock-button svg { + width: 14px; + height: 14px; +} + +.lm-lora-lock-button:hover { + background-color: rgba(66, 153, 225, 0.2); + border-color: rgba(66, 153, 225, 0.4); + color: rgba(226, 232, 240, 0.95); + transform: scale(1.05); +} + +.lm-lora-lock-button:active { + transform: scale(0.95); +} + +.lm-lora-lock-button:focus { + outline: none; +} + +.lm-lora-lock-button:focus-visible { + box-shadow: 0 0 0 2px rgba(66, 153, 225, 0.5); +} + +.lm-lora-lock-button--locked { + background-color: rgba(245, 158, 11, 0.25); + border-color: rgba(245, 158, 11, 0.7); + color: rgba(251, 191, 36, 0.95); + box-shadow: 0 0 8px rgba(245, 158, 11, 0.15); +} + +.lm-lora-lock-button--locked:hover { + background-color: rgba(245, 158, 11, 0.35); + border-color: rgba(245, 158, 11, 0.85); + box-shadow: 0 0 12px rgba(245, 158, 11, 0.25); +} + +/* Visual styling for locked lora entries */ +.lm-lora-entry[data-locked="true"] { + background-color: rgba(60, 50, 40, 0.75); + border-left: 3px solid rgba(245, 158, 11, 0.7); + box-shadow: inset 0 0 20px rgba(245, 158, 11, 0.04); +} + +.lm-lora-entry[data-locked="true"]:hover { + background-color: rgba(65, 55, 45, 0.85); + box-shadow: inset 0 0 24px rgba(245, 158, 11, 0.06); +} + +.lm-lora-entry[data-locked="true"][data-active="false"] { + background-color: rgba(48, 42, 36, 0.6); + border-left: 3px solid rgba(245, 158, 11, 0.5); +} + +.lm-lora-entry[data-locked="true"][data-active="false"]:hover { + background-color: rgba(52, 46, 40, 0.7); +} + body.lm-lora-strength-dragging, body.lm-lora-strength-dragging * { cursor: ew-resize !important; diff --git a/web/comfyui/lora_stacker.js b/web/comfyui/lora_stacker.js index 6c278550..818e9851 100644 --- a/web/comfyui/lora_stacker.js +++ b/web/comfyui/lora_stacker.js @@ -3,6 +3,7 @@ import { getActiveLorasFromNode, collectActiveLorasFromChain, updateConnectedTriggerWords, + updateDownstreamLoaders, chainCallback, mergeLoras, setupInputWidgetWithAutocomplete, @@ -169,41 +170,3 @@ app.registerExtension({ } }, }); - -// Helper function to find and update downstream Lora Loader nodes -function updateDownstreamLoaders(startNode, visited = new Set()) { - const nodeKey = getNodeKey(startNode); - if (!nodeKey || visited.has(nodeKey)) return; - visited.add(nodeKey); - - // Check each output link - if (startNode.outputs) { - for (const output of startNode.outputs) { - if (output.links) { - for (const linkId of output.links) { - const link = getLinkFromGraph(startNode.graph, linkId); - if (link) { - const targetNode = startNode.graph?.getNodeById?.(link.target_id); - - // If target is a Lora Loader, collect all active loras in the chain and update - if ( - targetNode && - targetNode.comfyClass === "Lora Loader (LoraManager)" - ) { - const allActiveLoraNames = - collectActiveLorasFromChain(targetNode); - updateConnectedTriggerWords(targetNode, allActiveLoraNames); - } - // If target is another Lora Stacker, recursively check its outputs - else if ( - targetNode && - targetNode.comfyClass === "Lora Stacker (LoraManager)" - ) { - updateDownstreamLoaders(targetNode, visited); - } - } - } - } - } - } -} diff --git a/web/comfyui/loras_widget.js b/web/comfyui/loras_widget.js index 80048274..bcb394a7 100644 --- a/web/comfyui/loras_widget.js +++ b/web/comfyui/loras_widget.js @@ -1,4 +1,4 @@ -import { createToggle, createArrowButton, createDragHandle, updateEntrySelection, createExpandButton, updateExpandButtonState } from "./loras_widget_components.js"; +import { createToggle, createArrowButton, createDragHandle, updateEntrySelection, createExpandButton, updateExpandButtonState, createLockButton, updateLockButtonState } from "./loras_widget_components.js"; import { parseLoraValue, formatLoraValue, @@ -27,6 +27,9 @@ export function addLorasWidget(node, name, opts, callback) { // Set initial height using CSS variables approach const defaultHeight = 200; + // Check if this is a randomizer node (lock button instead of drag handle) + const isRandomizerNode = opts?.isRandomizerNode === true; + // Initialize default value const defaultValue = opts?.defaultVal || []; const onSelectionChange = typeof opts?.onSelectionChange === "function" @@ -255,32 +258,52 @@ export function addLorasWidget(node, name, opts, callback) { const loraEl = document.createElement("div"); loraEl.className = "lm-lora-entry"; - // Store lora name and active state in dataset for selection + // Store lora name, active state, and locked state in dataset loraEl.dataset.loraName = name; loraEl.dataset.active = active ? "true" : "false"; + loraEl.dataset.locked = (loraData.locked || false) ? "true" : "false"; // Add click handler for selection loraEl.addEventListener('click', (e) => { // Skip if clicking on interactive elements - if (e.target.closest('.lm-lora-toggle') || - e.target.closest('input') || + if (e.target.closest('.lm-lora-toggle') || + e.target.closest('input') || e.target.closest('.lm-lora-arrow') || e.target.closest('.lm-lora-drag-handle') || + e.target.closest('.lm-lora-lock-button') || e.target.closest('.lm-lora-expand-button')) { return; } - + e.preventDefault(); e.stopPropagation(); selectLora(name); container.focus(); // Focus container for keyboard events }); - // Create drag handle for reordering - const dragHandle = createDragHandle(); - - // Initialize reorder drag functionality - initReorderDrag(dragHandle, name, widget, renderLoras); + // Conditionally create drag handle OR lock button + let dragHandleOrLockButton; + + if (isRandomizerNode) { + // For randomizer node, show lock button instead of drag handle + const isLocked = loraData.locked || false; + dragHandleOrLockButton = createLockButton(isLocked, (newLocked) => { + // Update this lora's locked state + const lorasData = parseLoraValue(widget.value); + const loraIndex = lorasData.findIndex(l => l.name === name); + + if (loraIndex >= 0) { + lorasData[loraIndex].locked = newLocked; + const newValue = formatLoraValue(lorasData); + updateWidgetValue(newValue); + } + }); + } else { + // For other nodes, show drag handle + dragHandleOrLockButton = createDragHandle(); + // Initialize reorder drag functionality + initReorderDrag(dragHandleOrLockButton, name, widget, renderLoras); + } // Create toggle for this lora const toggle = createToggle(active, (newActive) => { @@ -481,8 +504,8 @@ export function addLorasWidget(node, name, opts, callback) { // Assemble entry const leftSection = document.createElement("div"); leftSection.className = "lm-lora-entry-left"; - - leftSection.appendChild(dragHandle); // Add drag handle first + + leftSection.appendChild(dragHandleOrLockButton); // Add drag handle or lock button first leftSection.appendChild(toggle); leftSection.appendChild(expandButton); // Add expand button leftSection.appendChild(nameEl); @@ -685,16 +708,17 @@ export function addLorasWidget(node, name, opts, callback) { }, []); // Apply existing clip strength values and transfer them to the new value - const updatedValue = uniqueValue.map(lora => { + const updatedValue = uniqueValue.map(lora => { // For new loras, default clip strength to model strength and expanded to false // unless clipStrength is already different from strength const clipStrength = lora.clipStrength || lora.strength; return { ...lora, clipStrength: clipStrength, - expanded: lora.hasOwnProperty('expanded') ? - lora.expanded : - Number(clipStrength) !== Number(lora.strength) + expanded: lora.hasOwnProperty('expanded') ? + lora.expanded : + Number(clipStrength) !== Number(lora.strength), + locked: lora.hasOwnProperty('locked') ? lora.locked : false // Initialize locked to false if not present }; }); diff --git a/web/comfyui/loras_widget_components.js b/web/comfyui/loras_widget_components.js index ffed8297..3e453884 100644 --- a/web/comfyui/loras_widget_components.js +++ b/web/comfyui/loras_widget_components.js @@ -22,7 +22,9 @@ export function updateToggleStyle(toggleEl, active) { export function createArrowButton(direction, onClick) { const button = document.createElement("div"); button.className = `lm-lora-arrow lm-lora-arrow-${direction}`; - button.textContent = direction === "left" ? "◀" : "▶"; + button.innerHTML = direction === "left" + ? `` + : ``; button.addEventListener("click", (e) => { e.stopPropagation(); @@ -36,7 +38,7 @@ export function createArrowButton(direction, onClick) { export function createDragHandle() { const handle = document.createElement("div"); handle.className = "lm-lora-drag-handle"; - handle.innerHTML = "≡"; + handle.innerHTML = ``; handle.title = "Drag to reorder LoRA"; return handle; } @@ -102,10 +104,42 @@ export function createExpandButton(isExpanded, onClick) { // Helper function to update expand button state export function updateExpandButtonState(button, isExpanded) { if (isExpanded) { - button.innerHTML = "▼"; // Down arrow for expanded + button.innerHTML = ``; button.title = "Collapse clip controls"; } else { - button.innerHTML = "▶"; // Right arrow for collapsed + button.innerHTML = ``; button.title = "Expand clip controls"; } } + +// Function to create lock button +export function createLockButton(isLocked, onChange) { + const button = document.createElement("button"); + button.className = "lm-lora-lock-button"; + button.type = "button"; + button.tabIndex = -1; + + // Set icon based on locked state + updateLockButtonState(button, isLocked); + + button.addEventListener("click", (e) => { + e.preventDefault(); + e.stopPropagation(); + onChange(!isLocked); + }); + + return button; +} + +// Helper function to update lock button state +export function updateLockButtonState(button, isLocked) { + if (isLocked) { + button.innerHTML = ``; + button.title = "Unlock this LoRA (allow re-rolling)"; + button.classList.add("lm-lora-lock-button--locked"); + } else { + button.innerHTML = ``; + button.title = "Lock this LoRA (prevent re-rolling)"; + button.classList.remove("lm-lora-lock-button--locked"); + } +} diff --git a/web/comfyui/trigger_word_toggle.js b/web/comfyui/trigger_word_toggle.js index 7ff85f78..4ab6a640 100644 --- a/web/comfyui/trigger_word_toggle.js +++ b/web/comfyui/trigger_word_toggle.js @@ -239,6 +239,7 @@ app.registerExtension({ // Handle trigger word updates from Python handleTriggerWordUpdate(id, graphId, message) { + console.log('trigger word update: ', id, graphId, message); const node = getNodeFromGraph(graphId, id); if (!node || node.comfyClass !== "TriggerWord Toggle (LoraManager)") { console.warn("Node not found or not a TriggerWordToggle:", id); diff --git a/web/comfyui/utils.js b/web/comfyui/utils.js index 60dc3b43..3277b850 100644 --- a/web/comfyui/utils.js +++ b/web/comfyui/utils.js @@ -233,7 +233,7 @@ export function getConnectedInputStackers(node) { } const sourceNode = node.graph?.getNodeById?.(link.origin_id); - if (sourceNode && sourceNode.comfyClass === "Lora Stacker (LoraManager)") { + if (sourceNode && (sourceNode.comfyClass === "Lora Stacker (LoraManager)" || sourceNode.comfyClass === "Lora Randomizer (LoraManager)")) { connectedStackers.push(sourceNode); } } @@ -274,9 +274,13 @@ export function getConnectedTriggerToggleNodes(node) { export function getActiveLorasFromNode(node) { const activeLoraNames = new Set(); - // For lorasWidget style entries (array of objects) - if (node.lorasWidget && node.lorasWidget.value) { - node.lorasWidget.value.forEach(lora => { + let lorasWidget = node.lorasWidget; + if (!lorasWidget && node.widgets) { + lorasWidget = node.widgets.find(w => w.name === 'loras'); + } + + if (lorasWidget && lorasWidget.value) { + lorasWidget.value.forEach(lora => { if (lora.active) { activeLoraNames.add(lora.name); } @@ -324,6 +328,8 @@ export function updateConnectedTriggerWords(node, loraNames) { .map((connectedNode) => getNodeReference(connectedNode)) .filter((reference) => reference !== null); + console.log('node ids: ', nodeIds, loraNames); + if (nodeIds.length === 0) { return; } @@ -467,3 +473,78 @@ export function forwardMiddleMouseToCanvas(container) { } }); } + +// Get connected Lora Pool node from pool_config input +export function getConnectedPoolConfigNode(node) { + if (!node?.inputs) { + return null; + } + + for (const input of node.inputs) { + if (input.name !== "pool_config" || !input.link) { + continue; + } + + const link = getLinkFromGraph(node.graph, input.link); + if (!link) { + continue; + } + + const sourceNode = node.graph?.getNodeById?.(link.origin_id); + if (sourceNode && sourceNode.comfyClass === "Lora Pool (LoraManager)") { + return sourceNode; + } + } + + return null; +} + +// Get pool config widget value from connected Lora Pool node +export function getPoolConfigFromConnectedNode(node) { + const poolNode = getConnectedPoolConfigNode(node); + if (!poolNode) { + return null; + } + + const poolWidget = poolNode.widgets?.find(w => w.name === "pool_config"); + return poolWidget?.value || null; +} + +// Helper function to find and update downstream Lora Loader nodes +export function updateDownstreamLoaders(startNode, visited = new Set()) { + const nodeKey = getNodeKey(startNode); + if (!nodeKey || visited.has(nodeKey)) return; + visited.add(nodeKey); + + // Check each output link + if (startNode.outputs) { + for (const output of startNode.outputs) { + if (output.links) { + for (const linkId of output.links) { + const link = getLinkFromGraph(startNode.graph, linkId); + if (link) { + const targetNode = startNode.graph?.getNodeById?.(link.target_id); + + // If target is a Lora Loader, collect all active loras in the chain and update + if ( + targetNode && + targetNode.comfyClass === "Lora Loader (LoraManager)" + ) { + const allActiveLoraNames = + collectActiveLorasFromChain(targetNode); + updateConnectedTriggerWords(targetNode, allActiveLoraNames); + } + // If target is another Lora Stacker or Lora Randomizer, recursively check its outputs + else if ( + targetNode && + (targetNode.comfyClass === "Lora Stacker (LoraManager)" || + targetNode.comfyClass === "Lora Randomizer (LoraManager)") + ) { + updateDownstreamLoaders(targetNode, visited); + } + } + } + } + } + } +}