diff --git a/py/nodes/lora_randomizer.py b/py/nodes/lora_randomizer.py index 9e610776..abb0793d 100644 --- a/py/nodes/lora_randomizer.py +++ b/py/nodes/lora_randomizer.py @@ -39,6 +39,22 @@ class LoraRandomizerNode: FUNCTION = "randomize" OUTPUT_NODE = False + def _preprocess_loras_input(self, loras): + """ + Preprocess loras input to handle different widget formats. + + Args: + loras: Input from widget, either: + - List of LoRA dicts (expected format) + - Dict with '__value__' key containing the list + + Returns: + List of LoRA dicts + """ + if isinstance(loras, dict) and "__value__" in loras: + return loras["__value__"] + return loras + async def randomize(self, randomizer_config, loras, pool_config=None): """ Randomize LoRAs based on configuration and pool filters. @@ -53,6 +69,8 @@ class LoraRandomizerNode: """ from ..services.service_registry import ServiceRegistry + loras = self._preprocess_loras_input(loras) + roll_mode = randomizer_config.get("roll_mode", "always") logger.debug(f"[LoraRandomizerNode] roll_mode: {roll_mode}") @@ -64,6 +82,8 @@ class LoraRandomizerNode: scanner, randomizer_config, loras, pool_config ) + print("pool config", pool_config) + execution_stack = self._build_execution_stack_from_input(loras) return { @@ -120,6 +140,8 @@ class LoraRandomizerNode: Returns: List of LoRA dicts for UI display """ + from ..services.lora_service import LoraService + # Parse randomizer settings count_mode = randomizer_config.get("count_mode", "range") count_fixed = randomizer_config.get("count_fixed", 5) @@ -131,183 +153,23 @@ class LoraRandomizerNode: clip_strength_min = randomizer_config.get("clip_strength_min", 0.0) clip_strength_max = randomizer_config.get("clip_strength_max", 1.0) - # Determine target count - if count_mode == "fixed": - target_count = count_fixed - else: - target_count = random.randint(count_min, count_max) - # Extract locked LoRAs from input locked_loras = [lora for lora in input_loras if lora.get("locked", False)] - locked_count = len(locked_loras) - # Get available loras from cache - try: - cache_data = await scanner.get_cached_data(force_refresh=False) - if cache_data and hasattr(cache_data, "raw_data"): - available_loras = cache_data.raw_data - else: - available_loras = [] - except Exception as e: - logger.warning(f"[LoraRandomizerNode] Failed to get lora cache: {e}") - available_loras = [] - - # Apply pool filters if provided - if pool_config: - available_loras = await self._apply_pool_filters( - available_loras, pool_config, scanner - ) - - # Calculate how many new LoRAs to select - slots_needed = target_count - locked_count - - if slots_needed < 0: - slots_needed = 0 - # Too many locked, trim to target - locked_loras = locked_loras[:target_count] - locked_count = len(locked_loras) - - # Filter out locked LoRAs from available pool - locked_names = {lora["name"] for lora in locked_loras} - available_pool = [ - l for l in available_loras if l["file_name"] not in locked_names - ] - - # Ensure we don't try to select more than available - if slots_needed > len(available_pool): - slots_needed = len(available_pool) - - # Random sample - selected = [] - if slots_needed > 0: - selected = random.sample(available_pool, slots_needed) - - # Generate random strengths for selected LoRAs - result_loras = [] - for lora in selected: - model_str = round(random.uniform(model_strength_min, model_strength_max), 2) - - if use_same_clip_strength: - clip_str = model_str - else: - clip_str = round( - random.uniform(clip_strength_min, clip_strength_max), 2 - ) - - result_loras.append( - { - "name": lora["file_name"], - "strength": model_str, - "clipStrength": clip_str, - "active": True, - "expanded": abs(model_str - clip_str) > 0.001, - "locked": False, - } - ) - - # Merge with locked LoRAs - result_loras.extend(locked_loras) + # Use LoraService to generate random LoRAs + lora_service = LoraService(scanner) + result_loras = await lora_service.get_random_loras( + count=count_fixed, + model_strength_min=model_strength_min, + model_strength_max=model_strength_max, + use_same_clip_strength=use_same_clip_strength, + clip_strength_min=clip_strength_min, + clip_strength_max=clip_strength_max, + locked_loras=locked_loras, + pool_config=pool_config, + count_mode=count_mode, + count_min=count_min, + count_max=count_max, + ) return result_loras - - async def _apply_pool_filters(self, available_loras, pool_config, scanner): - """ - Apply pool_config filters to available LoRAs. - - Args: - available_loras: List of all LoRA dicts - pool_config: Dict with filter settings from LoRA Pool node - scanner: Scanner instance for accessing filter utilities - - Returns: - Filtered list of LoRA dicts - """ - from ..services.lora_service import LoraService - from ..services.model_query import FilterCriteria - - # Create lora service instance for filtering - lora_service = LoraService(scanner) - - # Extract filter parameters from pool_config - selected_base_models = pool_config.get("baseModels", []) - tags_dict = pool_config.get("tags", {}) - include_tags = tags_dict.get("include", []) - exclude_tags = tags_dict.get("exclude", []) - folders_dict = pool_config.get("folders", {}) - include_folders = folders_dict.get("include", []) - exclude_folders = folders_dict.get("exclude", []) - license_dict = pool_config.get("license", {}) - no_credit_required = license_dict.get("noCreditRequired", False) - allow_selling = license_dict.get("allowSelling", False) - - # Build tag filters dict - tag_filters = {} - for tag in include_tags: - tag_filters[tag] = "include" - for tag in exclude_tags: - tag_filters[tag] = "exclude" - - # Build folder filter - # LoRA Pool uses include/exclude folders, we need to apply this logic - # For now, we'll filter based on folder path matching - if include_folders or exclude_folders: - filtered = [] - for lora in available_loras: - folder = lora.get("folder", "") - - # Check exclude folders first - excluded = False - for exclude_folder in exclude_folders: - if folder.startswith(exclude_folder): - excluded = True - break - - if excluded: - continue - - # Check include folders - if include_folders: - included = False - for include_folder in include_folders: - if folder.startswith(include_folder): - included = True - break - if not included: - continue - - filtered.append(lora) - - available_loras = filtered - - # Apply base model filter - if selected_base_models: - available_loras = [ - lora - for lora in available_loras - if lora.get("base_model") in selected_base_models - ] - - # Apply tag filters - if tag_filters: - criteria = FilterCriteria(tags=tag_filters) - available_loras = lora_service.filter_set.apply(available_loras, criteria) - - # Apply license filters - # Note: no_credit_required=True means filter out models where credit is NOT required - # (i.e., keep only models where credit IS required) - if no_credit_required: - available_loras = [ - lora - for lora in available_loras - if not (lora.get("license_flags", 127) & (1 << 0)) - ] - - # allow_selling=True means keep only models where selling generated content is allowed - if allow_selling: - available_loras = [ - lora - for lora in available_loras - if bool(lora.get("license_flags", 127) & (1 << 1)) - ] - - return available_loras diff --git a/py/services/lora_service.py b/py/services/lora_service.py index d5940389..3c900423 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -225,12 +225,15 @@ class LoraService(BaseModelService): clip_strength_max: float = 1.0, locked_loras: Optional[List[Dict]] = None, pool_config: Optional[Dict] = None, + count_mode: str = "fixed", + count_min: int = 3, + count_max: int = 7, ) -> List[Dict]: """ Get random LoRAs with specified strength ranges. Args: - count: Number of LoRAs to select + count: Number of LoRAs to select (if count_mode='fixed') model_strength_min: Minimum model strength model_strength_max: Maximum model strength use_same_clip_strength: Whether to use same strength for clip @@ -238,6 +241,9 @@ class LoraService(BaseModelService): clip_strength_max: Maximum clip strength locked_loras: List of locked LoRA dicts to preserve pool_config: Optional pool config for filtering + count_mode: How to determine count ('fixed' or 'range') + count_min: Minimum count for range mode + count_max: Maximum count for range mode Returns: List of LoRA dicts with randomized strengths @@ -247,6 +253,12 @@ class LoraService(BaseModelService): if locked_loras is None: locked_loras = [] + # Determine target count based on count_mode + if count_mode == "fixed": + target_count = count + else: + target_count = random.randint(count_min, count_max) + # Get available loras from cache cache = await self.scanner.get_cached_data(force_refresh=False) available_loras = cache.raw_data if cache else [] @@ -259,12 +271,12 @@ class LoraService(BaseModelService): # Calculate slots needed (total - locked) locked_count = len(locked_loras) - slots_needed = count - locked_count + slots_needed = target_count - locked_count if slots_needed < 0: slots_needed = 0 # Too many locked, trim to target - locked_loras = locked_loras[:count] + locked_loras = locked_loras[:target_count] # Filter out locked LoRAs from available pool locked_names = {lora["name"] for lora in locked_loras} @@ -324,14 +336,19 @@ class LoraService(BaseModelService): """ from .model_query import FilterCriteria - # Extract filter parameters from pool_config - selected_base_models = pool_config.get("selected_base_models", []) - include_tags = pool_config.get("include_tags", []) - exclude_tags = pool_config.get("exclude_tags", []) - include_folders = pool_config.get("include_folders", []) - exclude_folders = pool_config.get("exclude_folders", []) - no_credit_required = pool_config.get("no_credit_required", False) - allow_selling = pool_config.get("allow_selling", False) + filter_section = pool_config + + # Extract filter parameters + selected_base_models = filter_section.get("baseModels", []) + tags_dict = filter_section.get("tags", {}) + include_tags = tags_dict.get("include", []) + exclude_tags = tags_dict.get("exclude", []) + folders_dict = filter_section.get("folders", {}) + include_folders = folders_dict.get("include", []) + exclude_folders = folders_dict.get("exclude", []) + license_dict = filter_section.get("license", {}) + no_credit_required = license_dict.get("noCreditRequired", False) + allow_selling = license_dict.get("allowSelling", False) # Build tag filters dict tag_filters = {} @@ -384,13 +401,13 @@ class LoraService(BaseModelService): available_loras = self.filter_set.apply(available_loras, criteria) # Apply license filters - # Note: no_credit_required=True means filter out models where credit is NOT required - # (i.e., keep only models where credit IS required) + # no_credit_required=True means keep only models where credit is NOT required + # (i.e., allowNoCredit=True, which is bit 0 = 1 in license_flags) if no_credit_required: available_loras = [ lora for lora in available_loras - if not (lora.get("license_flags", 127) & (1 << 0)) + if bool(lora.get("license_flags", 127) & (1 << 0)) ] # allow_selling=True means keep only models where selling generated content is allowed diff --git a/tests/services/test_lora_pool_filters.py b/tests/services/test_lora_pool_filters.py new file mode 100644 index 00000000..f0169a71 --- /dev/null +++ b/tests/services/test_lora_pool_filters.py @@ -0,0 +1,371 @@ +"""Tests for LoraService pool filtering functionality.""" + +import pytest +from unittest.mock import Mock, AsyncMock + +from py.services.lora_service import LoraService +from py.utils.civitai_utils import build_license_flags + + +@pytest.fixture +def lora_service(): + """Create a LoraService instance for testing.""" + scanner = Mock() + cache_mock = Mock() + cache_mock.raw_data = [] + scanner.get_cached_data = AsyncMock(return_value=cache_mock) + scanner._hash_index = Mock() + scanner._hash_index.get_duplicate_hashes = Mock(return_value={}) + scanner._hash_index.get_duplicate_filenames = Mock(return_value={}) + + service = LoraService(scanner) + service.filter_set = Mock() + service.filter_set.apply = Mock(return_value=None) + + return service + + +@pytest.fixture +def sample_loras(): + """Sample loras with various license configurations.""" + return [ + { + "file_name": "credit_required_not_for_selling.safetensors", + "base_model": "Illustrious", + "folder": "", + "license_flags": build_license_flags( + {"allowNoCredit": False, "allowCommercialUse": ["Rent"]} + ), + }, + { + "file_name": "no_credit_required_for_selling.safetensors", + "base_model": "Illustrious", + "folder": "", + "license_flags": build_license_flags( + {"allowNoCredit": True, "allowCommercialUse": ["Image"]} + ), + }, + { + "file_name": "credit_required_for_selling.safetensors", + "base_model": "Illustrious", + "folder": "", + "license_flags": build_license_flags( + {"allowNoCredit": False, "allowCommercialUse": ["Image"]} + ), + }, + { + "file_name": "no_credit_required_not_for_selling.safetensors", + "base_model": "Illustrious", + "folder": "", + "license_flags": build_license_flags( + {"allowNoCredit": True, "allowCommercialUse": ["Rent"]} + ), + }, + { + "file_name": "default_license.safetensors", + "base_model": "Illustrious", + "folder": "", + "license_flags": build_license_flags(None), + }, + ] + + +@pytest.mark.asyncio +async def test_pool_filter_no_credit_required_true(lora_service, sample_loras): + """Test that no_credit_required=True keeps only models where credit is NOT required.""" + pool_config = { + "baseModels": [], + "tags": {"include": [], "exclude": []}, + "folders": {"include": [], "exclude": []}, + "license": { + "noCreditRequired": True, + "allowSelling": False, + }, + } + + filtered = await lora_service._apply_pool_filters(sample_loras, pool_config) + + # Should keep models with allowNoCredit=True (bit 0 = 1) + # Models: no_credit_required_for_selling, no_credit_required_not_for_selling, default_license + assert len(filtered) == 3 + file_names = {lora["file_name"] for lora in filtered} + assert file_names == { + "no_credit_required_for_selling.safetensors", + "no_credit_required_not_for_selling.safetensors", + "default_license.safetensors", + } + + +@pytest.mark.asyncio +async def test_pool_filter_no_credit_required_false(lora_service, sample_loras): + """Test that no_credit_required=False keeps all models (no filter applied).""" + pool_config = { + "baseModels": [], + "tags": {"include": [], "exclude": []}, + "folders": {"include": [], "exclude": []}, + "license": { + "noCreditRequired": False, + "allowSelling": False, + }, + } + + filtered = await lora_service._apply_pool_filters(sample_loras, pool_config) + + # Should keep all models when no_credit_required=False + assert len(filtered) == 5 + + +@pytest.mark.asyncio +async def test_pool_filter_allow_selling_true(lora_service, sample_loras): + """Test that allowSelling=True keeps only models where selling is allowed.""" + pool_config = { + "baseModels": [], + "tags": {"include": [], "exclude": []}, + "folders": {"include": [], "exclude": []}, + "license": { + "noCreditRequired": False, + "allowSelling": True, + }, + } + + filtered = await lora_service._apply_pool_filters(sample_loras, pool_config) + + # Should keep models with Image permission (allowSelling) + # Models: no_credit_required_for_selling, credit_required_for_selling, default_license + assert len(filtered) == 3 + file_names = {lora["file_name"] for lora in filtered} + assert file_names == { + "no_credit_required_for_selling.safetensors", + "credit_required_for_selling.safetensors", + "default_license.safetensors", + } + + +@pytest.mark.asyncio +async def test_pool_filter_allow_selling_false(lora_service, sample_loras): + """Test that allowSelling=False keeps all models (no filter applied).""" + pool_config = { + "baseModels": [], + "tags": {"include": [], "exclude": []}, + "folders": {"include": [], "exclude": []}, + "license": { + "noCreditRequired": False, + "allowSelling": False, + }, + } + + filtered = await lora_service._apply_pool_filters(sample_loras, pool_config) + + # Should keep all models when allowSelling=False + assert len(filtered) == 5 + + +@pytest.mark.asyncio +async def test_pool_filter_both_license_filters(lora_service, sample_loras): + """Test combining both no_credit_required and allowSelling filters.""" + pool_config = { + "baseModels": [], + "tags": {"include": [], "exclude": []}, + "folders": {"include": [], "exclude": []}, + "license": { + "noCreditRequired": True, + "allowSelling": True, + }, + } + + filtered = await lora_service._apply_pool_filters(sample_loras, pool_config) + + # Should keep models where both conditions are met: + # - allowNoCredit=True (no credit required) + # - Image permission exists (allow selling) + # Models: no_credit_required_for_selling, default_license + assert len(filtered) == 2 + file_names = {lora["file_name"] for lora in filtered} + assert file_names == { + "no_credit_required_for_selling.safetensors", + "default_license.safetensors", + } + + +@pytest.mark.asyncio +async def test_pool_filter_base_models(lora_service, sample_loras): + """Test filtering by base models.""" + pool_config = { + "baseModels": ["Illustrious"], + "tags": {"include": [], "exclude": []}, + "folders": {"include": [], "exclude": []}, + "license": { + "noCreditRequired": False, + "allowSelling": False, + }, + } + + filtered = await lora_service._apply_pool_filters(sample_loras, pool_config) + + # All sample loras have base_model="Illustrious" + assert len(filtered) == 5 + + # Test with non-matching base model + pool_config["baseModels"] = ["SD15"] + filtered = await lora_service._apply_pool_filters(sample_loras, pool_config) + assert len(filtered) == 0 + + +@pytest.mark.asyncio +async def test_pool_filter_folders(lora_service): + """Test filtering by folders.""" + sample_loras = [ + { + "file_name": "lora1.safetensors", + "base_model": "Illustrious", + "folder": "characters/", + "license_flags": build_license_flags(None), + }, + { + "file_name": "lora2.safetensors", + "base_model": "Illustrious", + "folder": "styles/", + "license_flags": build_license_flags(None), + }, + { + "file_name": "lora3.safetensors", + "base_model": "Illustrious", + "folder": "concepts/", + "license_flags": build_license_flags(None), + }, + ] + + # Test include folders + pool_config = { + "baseModels": [], + "tags": {"include": [], "exclude": []}, + "folders": {"include": ["characters/"], "exclude": []}, + "license": { + "noCreditRequired": False, + "allowSelling": False, + }, + } + + filtered = await lora_service._apply_pool_filters(sample_loras, pool_config) + assert len(filtered) == 1 + assert filtered[0]["file_name"] == "lora1.safetensors" + + # Test exclude folders + pool_config = { + "baseModels": [], + "tags": {"include": [], "exclude": []}, + "folders": {"include": [], "exclude": ["characters/"]}, + "license": { + "noCreditRequired": False, + "allowSelling": False, + }, + } + + filtered = await lora_service._apply_pool_filters(sample_loras, pool_config) + assert len(filtered) == 2 + file_names = {lora["file_name"] for lora in filtered} + assert file_names == {"lora2.safetensors", "lora3.safetensors"} + + +@pytest.mark.asyncio +async def test_pool_filter_tags(lora_service): + """Test filtering by tags.""" + lora_service.filter_set.apply = Mock(side_effect=lambda data, criteria: data) + + sample_loras = [ + { + "file_name": "lora1.safetensors", + "base_model": "Illustrious", + "folder": "", + "tags": ["anime", "character"], + "license_flags": build_license_flags(None), + }, + { + "file_name": "lora2.safetensors", + "base_model": "Illustrious", + "folder": "", + "tags": ["realistic", "style"], + "license_flags": build_license_flags(None), + }, + ] + + pool_config = { + "baseModels": [], + "tags": {"include": ["anime"], "exclude": []}, + "folders": {"include": [], "exclude": []}, + "license": { + "noCreditRequired": False, + "allowSelling": False, + }, + } + + filtered = await lora_service._apply_pool_filters(sample_loras, pool_config) + + # Should call filter_set.apply with tag filters + assert lora_service.filter_set.apply.called + call_args = lora_service.filter_set.apply.call_args + assert call_args[0][0] == sample_loras + assert "anime" in call_args[0][1].tags + + +@pytest.mark.asyncio +async def test_pool_filter_combined_all_filters(lora_service): + """Test combining all filter types.""" + test_loras = [ + { + "file_name": "match_all.safetensors", + "base_model": "Illustrious", + "folder": "folder1/", + "tags": ["tag1"], + "license_flags": build_license_flags({"allowNoCredit": True}), + }, + { + "file_name": "wrong_base_model.safetensors", + "base_model": "SD15", + "folder": "folder1/", + "tags": ["tag1"], + "license_flags": build_license_flags({"allowNoCredit": True}), + }, + { + "file_name": "wrong_folder.safetensors", + "base_model": "Illustrious", + "folder": "folder2/", + "tags": ["tag1"], + "license_flags": build_license_flags({"allowNoCredit": True}), + }, + { + "file_name": "credit_required.safetensors", + "base_model": "Illustrious", + "folder": "folder1/", + "tags": ["tag1"], + "license_flags": build_license_flags({"allowNoCredit": False}), + }, + ] + + # Mock tag filter to return all items (simulate tag1 match) + def mock_tag_filter(data, criteria): + return data + + lora_service.filter_set.apply = Mock(side_effect=mock_tag_filter) + + pool_config = { + "baseModels": ["Illustrious"], + "tags": {"include": ["tag1"], "exclude": []}, + "folders": {"include": ["folder1/"], "exclude": []}, + "license": { + "noCreditRequired": True, + "allowSelling": False, + }, + } + + filtered = await lora_service._apply_pool_filters(test_loras, pool_config) + + # Should apply all filters + assert lora_service.filter_set.apply.called + # Only "match_all.safetensors" should match: + # - base_model: Illustrious ✓ + # - folder: folder1/ ✓ + # - no_credit_required: True ✓ (bit 0 = 1) + # - tags: tag1 ✓ + assert len(filtered) == 1 + assert filtered[0]["file_name"] == "match_all.safetensors" diff --git a/vue-widgets/src/components/lora-pool/sections/LicenseSection.vue b/vue-widgets/src/components/lora-pool/sections/LicenseSection.vue index 5b2e2986..d175436f 100644 --- a/vue-widgets/src/components/lora-pool/sections/LicenseSection.vue +++ b/vue-widgets/src/components/lora-pool/sections/LicenseSection.vue @@ -5,7 +5,7 @@