diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..b0b76b75 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,163 @@ +# AGENTS.md + +This file provides guidance for agentic coding assistants working in this repository. + +## Development Commands + +### Backend Development + +```bash +# Install dependencies +pip install -r requirements.txt +pip install -r requirements-dev.txt + +# Run standalone server (port 8188 by default) +python standalone.py --port 8188 + +# Run all backend tests +pytest + +# Run specific test file +pytest tests/test_recipes.py + +# Run specific test function +pytest tests/test_recipes.py::test_function_name + +# Run backend tests with coverage +COVERAGE_FILE=coverage/backend/.coverage pytest \ + --cov=py \ + --cov=standalone \ + --cov-report=term-missing \ + --cov-report=html:coverage/backend/html \ + --cov-report=xml:coverage/backend/coverage.xml \ + --cov-report=json:coverage/backend/coverage.json +``` + +### Frontend Development + +```bash +# Install frontend dependencies +npm install + +# Run frontend tests +npm test + +# Run frontend tests in watch mode +npm run test:watch + +# Run frontend tests with coverage +npm run test:coverage +``` + +## Python Code Style + +### Imports + +- Use `from __future__ import annotations` for forward references in type hints +- Group imports: standard library, third-party, local (separated by blank lines) +- Use absolute imports within `py/` package: `from ..services import X` +- Mock ComfyUI dependencies in tests using `tests/conftest.py` patterns + +### Formatting & Types + +- PEP 8 with 4-space indentation +- Type hints required for function signatures and class attributes +- Use `TYPE_CHECKING` guard for type-checking-only imports +- Prefer dataclasses for simple data containers +- Use `Optional[T]` for nullable types, `Union[T, None]` only when necessary + +### Naming Conventions + +- Files: `snake_case.py` (e.g., `model_scanner.py`, `lora_service.py`) +- Classes: `PascalCase` (e.g., `ModelScanner`, `LoraService`) +- Functions/variables: `snake_case` (e.g., `get_instance`, `model_type`) +- Constants: `UPPER_SNAKE_CASE` (e.g., `VALID_LORA_TYPES`) +- Private members: `_single_underscore` (protected), `__double_underscore` (name-mangled) + +### Error Handling + +- Use `logging.getLogger(__name__)` for module-level loggers +- Define custom exceptions in `py/services/errors.py` +- Use `asyncio.Lock` for thread-safe singleton patterns +- Raise specific exceptions with descriptive messages +- Log errors at appropriate levels (DEBUG, INFO, WARNING, ERROR, CRITICAL) + +### Async Patterns + +- Use `async def` for I/O-bound operations +- Mark async tests with `@pytest.mark.asyncio` +- Use `async with` for context managers +- Singleton pattern with class-level locks: see `ModelScanner.get_instance()` +- Use `aiohttp.web.Response` for HTTP responses + +### Testing Patterns + +- Use `pytest` with `--import-mode=importlib` +- Fixtures in `tests/conftest.py` handle ComfyUI mocking +- Use `@pytest.mark.no_settings_dir_isolation` for tests needing real paths +- Test files: `tests/test_*.py` +- Use `tmp_path_factory` for temporary directory isolation + +## JavaScript Code Style + +### Imports & Modules + +- ES modules with `import`/`export` +- Use `import { app } from "../../scripts/app.js"` for ComfyUI integration +- Export named functions/classes: `export function foo() {}` +- Widget files use `*_widget.js` suffix + +### Naming & Formatting + +- camelCase for functions, variables, object properties +- PascalCase for classes/constructors +- Constants: `UPPER_SNAKE_CASE` (e.g., `CONVERTED_TYPE`) +- Files: `snake_case.js` or `kebab-case.js` +- 2-space indentation preferred (follow existing file conventions) + +### Widget Development + +- Use `app.registerExtension()` to register ComfyUI extensions +- Use `node.addDOMWidget(name, type, element, options)` for custom widgets +- Event handlers attached via `addEventListener` or widget callbacks +- See `web/comfyui/utils.js` for shared utilities + +## Architecture Patterns + +### Service Layer + +- Use `ServiceRegistry` singleton for dependency injection +- Services follow singleton pattern via `get_instance()` class method +- Separate scanners (discovery) from services (business logic) +- Handlers in `py/routes/handlers/` implement route logic + +### Model Types + +- BaseModelService is abstract base for LoRA, Checkpoint, Embedding services +- ModelScanner provides file discovery and hash-based deduplication +- Persistent cache in SQLite via `PersistentModelCache` +- Metadata sync from CivitAI/CivArchive via `MetadataSyncService` + +### Routes & Handlers + +- Route registrars organize endpoints by domain: `ModelRouteRegistrar`, etc. +- Handlers are pure functions taking dependencies as parameters +- Use `WebSocketManager` for real-time progress updates +- Return `aiohttp.web.json_response` or `web.Response` + +### Recipe System + +- Base metadata in `py/recipes/base.py` +- Enrichment adds model metadata: `RecipeEnrichmentService` +- Parsers for different formats in `py/recipes/parsers/` + +## Important Notes + +- Always use English for comments (per copilot-instructions.md) +- Dual mode: ComfyUI plugin (uses folder_paths) vs standalone (reads settings.json) +- Detection: `os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"` +- Settings auto-saved in user directory or portable mode +- WebSocket broadcasts for real-time updates (downloads, scans) +- Symlink handling requires normalized paths +- API endpoints follow `/loras/*`, `/checkpoints/*`, `/embeddings/*` patterns +- Run `python scripts/sync_translation_keys.py` after UI string updates diff --git a/py/nodes/lora_pool.py b/py/nodes/lora_pool.py index 19cee8ea..aec2ed7e 100644 --- a/py/nodes/lora_pool.py +++ b/py/nodes/lora_pool.py @@ -20,7 +20,7 @@ class LoraPoolNode: """ NAME = "Lora Pool (LoraManager)" - CATEGORY = "Lora Manager/pools" + CATEGORY = "Lora Manager/randomizer" @classmethod def INPUT_TYPES(cls): @@ -85,10 +85,3 @@ class LoraPoolNode: }, "preview": {"matchCount": 0, "lastUpdated": 0}, } - - -# Node class mappings for ComfyUI -NODE_CLASS_MAPPINGS = {"LoraPoolNode": LoraPoolNode} - -# Display name mappings -NODE_DISPLAY_NAME_MAPPINGS = {"LoraPoolNode": "LoRA Pool (Filter)"} diff --git a/py/nodes/lora_randomizer.py b/py/nodes/lora_randomizer.py index 7eb79ab5..83990410 100644 --- a/py/nodes/lora_randomizer.py +++ b/py/nodes/lora_randomizer.py @@ -34,7 +34,7 @@ class LoraRandomizerNode: } RETURN_TYPES = ("LORA_STACK",) - RETURN_NAMES = ("lora_stack",) + RETURN_NAMES = ("LORA_STACK",) FUNCTION = "randomize" OUTPUT_NODE = False @@ -53,9 +53,6 @@ class LoraRandomizerNode: """ from ..services.service_registry import ServiceRegistry - # Get lora scanner to access available loras - scanner = await ServiceRegistry.get_lora_scanner() - # Parse randomizer settings count_mode = randomizer_config.get("count_mode", "range") count_fixed = randomizer_config.get("count_fixed", 5) @@ -68,6 +65,93 @@ class LoraRandomizerNode: clip_strength_max = randomizer_config.get("clip_strength_max", 1.0) roll_mode = randomizer_config.get("roll_mode", "frontend") + # Get lora scanner to access available loras + scanner = await ServiceRegistry.get_lora_scanner() + + # Backend roll mode: execute with input loras, return new random to UI + if roll_mode == "backend": + execution_stack = self._build_execution_stack_from_input(loras) + ui_loras = await self._generate_random_loras_for_ui( + scanner, randomizer_config, loras, pool_config + ) + logger.info( + f"[LoraRandomizerNode] Backend roll: executing with input, returning new random to UI" + ) + return {"result": (execution_stack,), "ui": {"loras": ui_loras}} + + # Frontend roll mode: use current behavior (random selection for both) + ui_loras = await self._generate_random_loras_for_ui( + scanner, randomizer_config, loras, pool_config + ) + execution_stack = self._build_execution_stack_from_input(ui_loras) + logger.info( + f"[LoraRandomizerNode] Frontend roll: executing with random selection" + ) + return {"result": (execution_stack,), "ui": {"loras": ui_loras}} + + def _build_execution_stack_from_input(self, loras): + """ + Build LORA_STACK tuple from input loras list for execution. + + Args: + loras: List of LoRA dicts with name, strength, clipStrength, active + + Returns: + List of tuples (lora_path, model_strength, clip_strength) + """ + lora_stack = [] + for lora in loras: + if not lora.get("active", False): + continue + + # Get file path + lora_path, trigger_words = get_lora_info(lora["name"]) + if not lora_path: + logger.warning( + f"[LoraRandomizerNode] Could not find path for LoRA: {lora['name']}" + ) + continue + + # Normalize path separators + lora_path = lora_path.replace("/", os.sep) + + # Extract strengths + model_strength = lora.get("strength", 1.0) + clip_strength = lora.get("clipStrength", model_strength) + + lora_stack.append((lora_path, model_strength, clip_strength)) + + logger.info( + f"[LoraRandomizerNode] Built execution stack with {len(lora_stack)} LoRAs" + ) + return lora_stack + + async def _generate_random_loras_for_ui( + self, scanner, randomizer_config, input_loras, pool_config=None + ): + """ + Generate new random loras for UI display. + + Args: + scanner: LoraScanner instance + randomizer_config: Dict with randomizer settings + input_loras: Current input loras (for extracting locked loras) + pool_config: Optional pool filters + + Returns: + List of LoRA dicts for UI display + """ + # Parse randomizer settings + count_mode = randomizer_config.get("count_mode", "range") + count_fixed = randomizer_config.get("count_fixed", 5) + count_min = randomizer_config.get("count_min", 3) + count_max = randomizer_config.get("count_max", 7) + model_strength_min = randomizer_config.get("model_strength_min", 0.0) + model_strength_max = randomizer_config.get("model_strength_max", 1.0) + use_same_clip_strength = randomizer_config.get("use_same_clip_strength", True) + clip_strength_min = randomizer_config.get("clip_strength_min", 0.0) + clip_strength_max = randomizer_config.get("clip_strength_max", 1.0) + # Determine target count if count_mode == "fixed": target_count = count_fixed @@ -75,11 +159,11 @@ class LoraRandomizerNode: target_count = random.randint(count_min, count_max) logger.info( - f"[LoraRandomizerNode] Target count: {target_count}, Roll mode: {roll_mode}" + f"[LoraRandomizerNode] Generating random LoRAs, target count: {target_count}" ) # Extract locked LoRAs from input - locked_loras = [lora for lora in loras if lora.get("locked", False)] + locked_loras = [lora for lora in input_loras if lora.get("locked", False)] locked_count = len(locked_loras) logger.info(f"[LoraRandomizerNode] Locked LoRAs: {locked_count}") @@ -106,8 +190,6 @@ class LoraRandomizerNode: ) # Calculate how many new LoRAs to select - # In frontend mode, if loras already has data, preserve unlocked ones if roll_mode requires - # For simplicity in backend mode, we regenerate all unlocked slots slots_needed = target_count - locked_count if slots_needed < 0: @@ -161,33 +243,10 @@ class LoraRandomizerNode: # Merge with locked LoRAs result_loras.extend(locked_loras) - logger.info(f"[LoraRandomizerNode] Final LoRA count: {len(result_loras)}") - - # Build LORA_STACK output - lora_stack = [] - for lora in result_loras: - if not lora.get("active", False): - continue - - # Get file path - lora_path, trigger_words = get_lora_info(lora["name"]) - if not lora_path: - logger.warning( - f"[LoraRandomizerNode] Could not find path for LoRA: {lora['name']}" - ) - continue - - # Normalize path separators - lora_path = lora_path.replace("/", os.sep) - - # Extract strengths - model_strength = lora.get("strength", 1.0) - clip_strength = lora.get("clipStrength", model_strength) - - lora_stack.append((lora_path, model_strength, clip_strength)) - - # Return format: result for workflow + ui for frontend display - return {"result": (lora_stack,), "ui": {"loras": result_loras}} + logger.info( + f"[LoraRandomizerNode] Final random LoRA count: {len(result_loras)}" + ) + return result_loras async def _apply_pool_filters(self, available_loras, pool_config, scanner): """ @@ -288,10 +347,3 @@ class LoraRandomizerNode: ] return available_loras - - -# Node class mappings for ComfyUI -NODE_CLASS_MAPPINGS = {"LoraRandomizerNode": LoraRandomizerNode} - -# Display name mappings -NODE_DISPLAY_NAME_MAPPINGS = {"LoraRandomizerNode": "LoRA Randomizer"} diff --git a/py/services/lora_service.py b/py/services/lora_service.py index 01a4759f..7beffecd 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -8,24 +8,27 @@ from ..config import config logger = logging.getLogger(__name__) + class LoraService(BaseModelService): """LoRA-specific service implementation""" - + def __init__(self, scanner, update_service=None): """Initialize LoRA service - + Args: scanner: LoRA scanner instance update_service: Optional service for remote update tracking. """ super().__init__("lora", scanner, LoraMetadata, update_service=update_service) - + async def format_response(self, lora_data: Dict) -> Dict: """Format LoRA data for API response""" return { "model_name": lora_data["model_name"], "file_name": lora_data["file_name"], - "preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")), + "preview_url": config.get_preview_static_url( + lora_data.get("preview_url", "") + ), "preview_nsfw_level": lora_data.get("preview_nsfw_level", 0), "base_model": lora_data.get("base_model", ""), "folder": lora_data["folder"], @@ -40,141 +43,170 @@ class LoraService(BaseModelService): "notes": lora_data.get("notes", ""), "favorite": lora_data.get("favorite", False), "update_available": bool(lora_data.get("update_available", False)), - "civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True) + "civitai": self.filter_civitai_data( + lora_data.get("civitai", {}), minimal=True + ), } - + async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: """Apply LoRA-specific filters""" # Handle first_letter filter for LoRAs - first_letter = kwargs.get('first_letter') + first_letter = kwargs.get("first_letter") if first_letter: data = self._filter_by_first_letter(data, first_letter) - + return data - + def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]: """Filter data by first letter of model name - + Special handling: - '#': Numbers (0-9) - '@': Special characters (not alphanumeric) - '漢': CJK characters """ filtered_data = [] - + for lora in data: - model_name = lora.get('model_name', '') + model_name = lora.get("model_name", "") if not model_name: continue - + first_char = model_name[0].upper() - - if letter == '#' and first_char.isdigit(): + + if letter == "#" and first_char.isdigit(): filtered_data.append(lora) - elif letter == '@' and not first_char.isalnum(): + elif letter == "@" and not first_char.isalnum(): # Special characters (not alphanumeric) filtered_data.append(lora) - elif letter == '漢' and self._is_cjk_character(first_char): + elif letter == "漢" and self._is_cjk_character(first_char): # CJK characters filtered_data.append(lora) elif letter.upper() == first_char: # Regular alphabet matching filtered_data.append(lora) - + return filtered_data - + def _is_cjk_character(self, char: str) -> bool: """Check if character is a CJK character""" # Define Unicode ranges for CJK characters cjk_ranges = [ - (0x4E00, 0x9FFF), # CJK Unified Ideographs - (0x3400, 0x4DBF), # CJK Unified Ideographs Extension A - (0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B - (0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C - (0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D - (0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E - (0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F - (0x30000, 0x3134F), # CJK Unified Ideographs Extension G - (0xF900, 0xFAFF), # CJK Compatibility Ideographs - (0x3300, 0x33FF), # CJK Compatibility - (0x3200, 0x32FF), # Enclosed CJK Letters and Months - (0x3100, 0x312F), # Bopomofo - (0x31A0, 0x31BF), # Bopomofo Extended - (0x3040, 0x309F), # Hiragana - (0x30A0, 0x30FF), # Katakana - (0x31F0, 0x31FF), # Katakana Phonetic Extensions - (0xAC00, 0xD7AF), # Hangul Syllables - (0x1100, 0x11FF), # Hangul Jamo - (0xA960, 0xA97F), # Hangul Jamo Extended-A - (0xD7B0, 0xD7FF), # Hangul Jamo Extended-B + (0x4E00, 0x9FFF), # CJK Unified Ideographs + (0x3400, 0x4DBF), # CJK Unified Ideographs Extension A + (0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B + (0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C + (0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D + (0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E + (0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F + (0x30000, 0x3134F), # CJK Unified Ideographs Extension G + (0xF900, 0xFAFF), # CJK Compatibility Ideographs + (0x3300, 0x33FF), # CJK Compatibility + (0x3200, 0x32FF), # Enclosed CJK Letters and Months + (0x3100, 0x312F), # Bopomofo + (0x31A0, 0x31BF), # Bopomofo Extended + (0x3040, 0x309F), # Hiragana + (0x30A0, 0x30FF), # Katakana + (0x31F0, 0x31FF), # Katakana Phonetic Extensions + (0xAC00, 0xD7AF), # Hangul Syllables + (0x1100, 0x11FF), # Hangul Jamo + (0xA960, 0xA97F), # Hangul Jamo Extended-A + (0xD7B0, 0xD7FF), # Hangul Jamo Extended-B ] - + code_point = ord(char) return any(start <= code_point <= end for start, end in cjk_ranges) - + # LoRA-specific methods async def get_letter_counts(self) -> Dict[str, int]: """Get count of LoRAs for each letter of the alphabet""" cache = await self.scanner.get_cached_data() data = cache.raw_data - + # Define letter categories letters = { - '#': 0, # Numbers - 'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0, - 'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0, - 'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0, - 'Y': 0, 'Z': 0, - '@': 0, # Special characters - '漢': 0 # CJK characters + "#": 0, # Numbers + "A": 0, + "B": 0, + "C": 0, + "D": 0, + "E": 0, + "F": 0, + "G": 0, + "H": 0, + "I": 0, + "J": 0, + "K": 0, + "L": 0, + "M": 0, + "N": 0, + "O": 0, + "P": 0, + "Q": 0, + "R": 0, + "S": 0, + "T": 0, + "U": 0, + "V": 0, + "W": 0, + "X": 0, + "Y": 0, + "Z": 0, + "@": 0, # Special characters + "漢": 0, # CJK characters } - + # Count models for each letter for lora in data: - model_name = lora.get('model_name', '') + model_name = lora.get("model_name", "") if not model_name: continue - + first_char = model_name[0].upper() - + if first_char.isdigit(): - letters['#'] += 1 + letters["#"] += 1 elif first_char in letters: letters[first_char] += 1 elif self._is_cjk_character(first_char): - letters['漢'] += 1 + letters["漢"] += 1 elif not first_char.isalnum(): - letters['@'] += 1 - + letters["@"] += 1 + return letters - + async def get_lora_trigger_words(self, lora_name: str) -> List[str]: """Get trigger words for a specific LoRA file""" cache = await self.scanner.get_cached_data() - + for lora in cache.raw_data: - if lora['file_name'] == lora_name: - civitai_data = lora.get('civitai', {}) - return civitai_data.get('trainedWords', []) - + if lora["file_name"] == lora_name: + civitai_data = lora.get("civitai", {}) + return civitai_data.get("trainedWords", []) + return [] - - async def get_lora_usage_tips_by_relative_path(self, relative_path: str) -> Optional[str]: + + async def get_lora_usage_tips_by_relative_path( + self, relative_path: str + ) -> Optional[str]: """Get usage tips for a LoRA by its relative path""" cache = await self.scanner.get_cached_data() - + for lora in cache.raw_data: - file_path = lora.get('file_path', '') + file_path = lora.get("file_path", "") if file_path: # Convert to forward slashes and extract relative path - file_path_normalized = file_path.replace('\\', '/') - relative_path = relative_path.replace('\\', '/') + file_path_normalized = file_path.replace("\\", "/") + relative_path = relative_path.replace("\\", "/") # Find the relative path part by looking for the relative_path in the full path - if file_path_normalized.endswith(relative_path) or relative_path in file_path_normalized: - return lora.get('usage_tips', '') - + if ( + file_path_normalized.endswith(relative_path) + or relative_path in file_path_normalized + ): + return lora.get("usage_tips", "") + return None - + def find_duplicate_hashes(self) -> Dict: """Find LoRAs with duplicate SHA256 hashes""" return self.scanner._hash_index.get_duplicate_hashes() @@ -192,7 +224,7 @@ class LoraService(BaseModelService): clip_strength_min: float = 0.0, clip_strength_max: float = 1.0, locked_loras: Optional[List[Dict]] = None, - pool_config: Optional[Dict] = None + pool_config: Optional[Dict] = None, ) -> List[Dict]: """ Get random LoRAs with specified strength ranges. @@ -235,10 +267,9 @@ class LoraService(BaseModelService): locked_loras = locked_loras[:count] # Filter out locked LoRAs from available pool - locked_names = {lora['name'] for lora in locked_loras} + locked_names = {lora["name"] for lora in locked_loras} available_pool = [ - l for l in available_loras - if l['model_name'] not in locked_names + l for l in available_loras if l["file_name"] not in locked_names ] # Ensure we don't try to select more than available @@ -253,9 +284,7 @@ class LoraService(BaseModelService): # Generate random strengths for selected LoRAs result_loras = [] for lora in selected: - model_str = round( - random.uniform(model_strength_min, model_strength_max), 2 - ) + model_str = round(random.uniform(model_strength_min, model_strength_max), 2) if use_same_clip_strength: clip_str = model_str @@ -264,21 +293,25 @@ class LoraService(BaseModelService): random.uniform(clip_strength_min, clip_strength_max), 2 ) - result_loras.append({ - 'name': lora['model_name'], - 'strength': model_str, - 'clipStrength': clip_str, - 'active': True, - 'expanded': abs(model_str - clip_str) > 0.001, - 'locked': False - }) + result_loras.append( + { + "name": lora["file_name"], + "strength": model_str, + "clipStrength": clip_str, + "active": True, + "expanded": abs(model_str - clip_str) > 0.001, + "locked": False, + } + ) # Merge with locked LoRAs result_loras.extend(locked_loras) return result_loras - async def _apply_pool_filters(self, available_loras: List[Dict], pool_config: Dict) -> List[Dict]: + async def _apply_pool_filters( + self, available_loras: List[Dict], pool_config: Dict + ) -> List[Dict]: """ Apply pool_config filters to available LoRAs. @@ -292,26 +325,26 @@ class LoraService(BaseModelService): from .model_query import FilterCriteria # Extract filter parameters from pool_config - selected_base_models = pool_config.get('selected_base_models', []) - include_tags = pool_config.get('include_tags', []) - exclude_tags = pool_config.get('exclude_tags', []) - include_folders = pool_config.get('include_folders', []) - exclude_folders = pool_config.get('exclude_folders', []) - no_credit_required = pool_config.get('no_credit_required', False) - allow_selling = pool_config.get('allow_selling', False) + selected_base_models = pool_config.get("selected_base_models", []) + include_tags = pool_config.get("include_tags", []) + exclude_tags = pool_config.get("exclude_tags", []) + include_folders = pool_config.get("include_folders", []) + exclude_folders = pool_config.get("exclude_folders", []) + no_credit_required = pool_config.get("no_credit_required", False) + allow_selling = pool_config.get("allow_selling", False) # Build tag filters dict tag_filters = {} for tag in include_tags: - tag_filters[tag] = 'include' + tag_filters[tag] = "include" for tag in exclude_tags: - tag_filters[tag] = 'exclude' + tag_filters[tag] = "exclude" # Build folder filter if include_folders or exclude_folders: filtered = [] for lora in available_loras: - folder = lora.get('folder', '') + folder = lora.get("folder", "") # Check exclude folders first excluded = False @@ -340,8 +373,9 @@ class LoraService(BaseModelService): # Apply base model filter if selected_base_models: available_loras = [ - lora for lora in available_loras - if lora.get('base_model') in selected_base_models + lora + for lora in available_loras + if lora.get("base_model") in selected_base_models ] # Apply tag filters @@ -352,14 +386,17 @@ class LoraService(BaseModelService): # Apply license filters if no_credit_required: available_loras = [ - lora for lora in available_loras - if not lora.get('civitai', {}).get('allowNoCredit', True) + lora + for lora in available_loras + if not lora.get("civitai", {}).get("allowNoCredit", True) ] if allow_selling: available_loras = [ - lora for lora in available_loras - if lora.get('civitai', {}).get('allowCommercialUse', ['None'])[0] != 'None' + lora + for lora in available_loras + if lora.get("civitai", {}).get("allowCommercialUse", ["None"])[0] + != "None" ] return available_loras diff --git a/vue-widgets/src/components/LoraRandomizerWidget.vue b/vue-widgets/src/components/LoraRandomizerWidget.vue index 75e4a5f3..f31132d1 100644 --- a/vue-widgets/src/components/LoraRandomizerWidget.vue +++ b/vue-widgets/src/components/LoraRandomizerWidget.vue @@ -32,12 +32,12 @@ import { onMounted } from 'vue' import LoraRandomizerSettingsView from './lora-randomizer/LoraRandomizerSettingsView.vue' import { useLoraRandomizerState } from '../composables/useLoraRandomizerState' -import type { ComponentWidget, RandomizerConfig } from '../composables/types' +import type { ComponentWidget, RandomizerConfig, LoraEntry } from '../composables/types' // Props const props = defineProps<{ widget: ComponentWidget - node: { id: number } + node: { id: number; inputs?: any[]; widgets?: any[]; graph?: any } }>() // State management @@ -48,13 +48,15 @@ const handleRoll = async () => { try { console.log('[LoraRandomizerWidget] Roll button clicked') - // Get pool config from connected input (if any) - // This would need to be passed from the node's pool_config input - const poolConfig = null // TODO: Get from node input if connected + // Get pool config from connected pool_config input + const poolConfig = (props.node as any).getPoolConfig?.() || null // Get locked loras from the loras widget - // This would need to be retrieved from the loras widget on the node - const lockedLoras: any[] = [] // TODO: Get from loras widget + const lorasWidget = props.node.widgets?.find((w: any) => w.name === "loras") + const lockedLoras: LoraEntry[] = (lorasWidget?.value || []).filter((lora: LoraEntry) => lora.locked === true) + + console.log('[LoraRandomizerWidget] Pool config:', poolConfig) + console.log('[LoraRandomizerWidget] Locked loras:', lockedLoras) // Call API to get random loras const randomLoras = await state.rollLoras(poolConfig, lockedLoras) diff --git a/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue b/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue index 8719b715..224dcb62 100644 --- a/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue +++ b/vue-widgets/src/components/lora-randomizer/LoraRandomizerSettingsView.vue @@ -154,8 +154,10 @@ :disabled="rollMode !== 'frontend' || isRolling" @click="$emit('roll')" > - 🎲 Roll - Rolling... +