diff --git a/py/nodes/lora_randomizer.py b/py/nodes/lora_randomizer.py
index 9e610776..abb0793d 100644
--- a/py/nodes/lora_randomizer.py
+++ b/py/nodes/lora_randomizer.py
@@ -39,6 +39,22 @@ class LoraRandomizerNode:
FUNCTION = "randomize"
OUTPUT_NODE = False
+ def _preprocess_loras_input(self, loras):
+ """
+ Preprocess loras input to handle different widget formats.
+
+ Args:
+ loras: Input from widget, either:
+ - List of LoRA dicts (expected format)
+ - Dict with '__value__' key containing the list
+
+ Returns:
+ List of LoRA dicts
+ """
+ if isinstance(loras, dict) and "__value__" in loras:
+ return loras["__value__"]
+ return loras
+
async def randomize(self, randomizer_config, loras, pool_config=None):
"""
Randomize LoRAs based on configuration and pool filters.
@@ -53,6 +69,8 @@ class LoraRandomizerNode:
"""
from ..services.service_registry import ServiceRegistry
+ loras = self._preprocess_loras_input(loras)
+
roll_mode = randomizer_config.get("roll_mode", "always")
logger.debug(f"[LoraRandomizerNode] roll_mode: {roll_mode}")
@@ -64,6 +82,8 @@ class LoraRandomizerNode:
scanner, randomizer_config, loras, pool_config
)
+ print("pool config", pool_config)
+
execution_stack = self._build_execution_stack_from_input(loras)
return {
@@ -120,6 +140,8 @@ class LoraRandomizerNode:
Returns:
List of LoRA dicts for UI display
"""
+ from ..services.lora_service import LoraService
+
# Parse randomizer settings
count_mode = randomizer_config.get("count_mode", "range")
count_fixed = randomizer_config.get("count_fixed", 5)
@@ -131,183 +153,23 @@ class LoraRandomizerNode:
clip_strength_min = randomizer_config.get("clip_strength_min", 0.0)
clip_strength_max = randomizer_config.get("clip_strength_max", 1.0)
- # Determine target count
- if count_mode == "fixed":
- target_count = count_fixed
- else:
- target_count = random.randint(count_min, count_max)
-
# Extract locked LoRAs from input
locked_loras = [lora for lora in input_loras if lora.get("locked", False)]
- locked_count = len(locked_loras)
- # Get available loras from cache
- try:
- cache_data = await scanner.get_cached_data(force_refresh=False)
- if cache_data and hasattr(cache_data, "raw_data"):
- available_loras = cache_data.raw_data
- else:
- available_loras = []
- except Exception as e:
- logger.warning(f"[LoraRandomizerNode] Failed to get lora cache: {e}")
- available_loras = []
-
- # Apply pool filters if provided
- if pool_config:
- available_loras = await self._apply_pool_filters(
- available_loras, pool_config, scanner
- )
-
- # Calculate how many new LoRAs to select
- slots_needed = target_count - locked_count
-
- if slots_needed < 0:
- slots_needed = 0
- # Too many locked, trim to target
- locked_loras = locked_loras[:target_count]
- locked_count = len(locked_loras)
-
- # Filter out locked LoRAs from available pool
- locked_names = {lora["name"] for lora in locked_loras}
- available_pool = [
- l for l in available_loras if l["file_name"] not in locked_names
- ]
-
- # Ensure we don't try to select more than available
- if slots_needed > len(available_pool):
- slots_needed = len(available_pool)
-
- # Random sample
- selected = []
- if slots_needed > 0:
- selected = random.sample(available_pool, slots_needed)
-
- # Generate random strengths for selected LoRAs
- result_loras = []
- for lora in selected:
- model_str = round(random.uniform(model_strength_min, model_strength_max), 2)
-
- if use_same_clip_strength:
- clip_str = model_str
- else:
- clip_str = round(
- random.uniform(clip_strength_min, clip_strength_max), 2
- )
-
- result_loras.append(
- {
- "name": lora["file_name"],
- "strength": model_str,
- "clipStrength": clip_str,
- "active": True,
- "expanded": abs(model_str - clip_str) > 0.001,
- "locked": False,
- }
- )
-
- # Merge with locked LoRAs
- result_loras.extend(locked_loras)
+ # Use LoraService to generate random LoRAs
+ lora_service = LoraService(scanner)
+ result_loras = await lora_service.get_random_loras(
+ count=count_fixed,
+ model_strength_min=model_strength_min,
+ model_strength_max=model_strength_max,
+ use_same_clip_strength=use_same_clip_strength,
+ clip_strength_min=clip_strength_min,
+ clip_strength_max=clip_strength_max,
+ locked_loras=locked_loras,
+ pool_config=pool_config,
+ count_mode=count_mode,
+ count_min=count_min,
+ count_max=count_max,
+ )
return result_loras
-
- async def _apply_pool_filters(self, available_loras, pool_config, scanner):
- """
- Apply pool_config filters to available LoRAs.
-
- Args:
- available_loras: List of all LoRA dicts
- pool_config: Dict with filter settings from LoRA Pool node
- scanner: Scanner instance for accessing filter utilities
-
- Returns:
- Filtered list of LoRA dicts
- """
- from ..services.lora_service import LoraService
- from ..services.model_query import FilterCriteria
-
- # Create lora service instance for filtering
- lora_service = LoraService(scanner)
-
- # Extract filter parameters from pool_config
- selected_base_models = pool_config.get("baseModels", [])
- tags_dict = pool_config.get("tags", {})
- include_tags = tags_dict.get("include", [])
- exclude_tags = tags_dict.get("exclude", [])
- folders_dict = pool_config.get("folders", {})
- include_folders = folders_dict.get("include", [])
- exclude_folders = folders_dict.get("exclude", [])
- license_dict = pool_config.get("license", {})
- no_credit_required = license_dict.get("noCreditRequired", False)
- allow_selling = license_dict.get("allowSelling", False)
-
- # Build tag filters dict
- tag_filters = {}
- for tag in include_tags:
- tag_filters[tag] = "include"
- for tag in exclude_tags:
- tag_filters[tag] = "exclude"
-
- # Build folder filter
- # LoRA Pool uses include/exclude folders, we need to apply this logic
- # For now, we'll filter based on folder path matching
- if include_folders or exclude_folders:
- filtered = []
- for lora in available_loras:
- folder = lora.get("folder", "")
-
- # Check exclude folders first
- excluded = False
- for exclude_folder in exclude_folders:
- if folder.startswith(exclude_folder):
- excluded = True
- break
-
- if excluded:
- continue
-
- # Check include folders
- if include_folders:
- included = False
- for include_folder in include_folders:
- if folder.startswith(include_folder):
- included = True
- break
- if not included:
- continue
-
- filtered.append(lora)
-
- available_loras = filtered
-
- # Apply base model filter
- if selected_base_models:
- available_loras = [
- lora
- for lora in available_loras
- if lora.get("base_model") in selected_base_models
- ]
-
- # Apply tag filters
- if tag_filters:
- criteria = FilterCriteria(tags=tag_filters)
- available_loras = lora_service.filter_set.apply(available_loras, criteria)
-
- # Apply license filters
- # Note: no_credit_required=True means filter out models where credit is NOT required
- # (i.e., keep only models where credit IS required)
- if no_credit_required:
- available_loras = [
- lora
- for lora in available_loras
- if not (lora.get("license_flags", 127) & (1 << 0))
- ]
-
- # allow_selling=True means keep only models where selling generated content is allowed
- if allow_selling:
- available_loras = [
- lora
- for lora in available_loras
- if bool(lora.get("license_flags", 127) & (1 << 1))
- ]
-
- return available_loras
diff --git a/py/services/lora_service.py b/py/services/lora_service.py
index d5940389..3c900423 100644
--- a/py/services/lora_service.py
+++ b/py/services/lora_service.py
@@ -225,12 +225,15 @@ class LoraService(BaseModelService):
clip_strength_max: float = 1.0,
locked_loras: Optional[List[Dict]] = None,
pool_config: Optional[Dict] = None,
+ count_mode: str = "fixed",
+ count_min: int = 3,
+ count_max: int = 7,
) -> List[Dict]:
"""
Get random LoRAs with specified strength ranges.
Args:
- count: Number of LoRAs to select
+ count: Number of LoRAs to select (if count_mode='fixed')
model_strength_min: Minimum model strength
model_strength_max: Maximum model strength
use_same_clip_strength: Whether to use same strength for clip
@@ -238,6 +241,9 @@ class LoraService(BaseModelService):
clip_strength_max: Maximum clip strength
locked_loras: List of locked LoRA dicts to preserve
pool_config: Optional pool config for filtering
+ count_mode: How to determine count ('fixed' or 'range')
+ count_min: Minimum count for range mode
+ count_max: Maximum count for range mode
Returns:
List of LoRA dicts with randomized strengths
@@ -247,6 +253,12 @@ class LoraService(BaseModelService):
if locked_loras is None:
locked_loras = []
+ # Determine target count based on count_mode
+ if count_mode == "fixed":
+ target_count = count
+ else:
+ target_count = random.randint(count_min, count_max)
+
# Get available loras from cache
cache = await self.scanner.get_cached_data(force_refresh=False)
available_loras = cache.raw_data if cache else []
@@ -259,12 +271,12 @@ class LoraService(BaseModelService):
# Calculate slots needed (total - locked)
locked_count = len(locked_loras)
- slots_needed = count - locked_count
+ slots_needed = target_count - locked_count
if slots_needed < 0:
slots_needed = 0
# Too many locked, trim to target
- locked_loras = locked_loras[:count]
+ locked_loras = locked_loras[:target_count]
# Filter out locked LoRAs from available pool
locked_names = {lora["name"] for lora in locked_loras}
@@ -324,14 +336,19 @@ class LoraService(BaseModelService):
"""
from .model_query import FilterCriteria
- # Extract filter parameters from pool_config
- selected_base_models = pool_config.get("selected_base_models", [])
- include_tags = pool_config.get("include_tags", [])
- exclude_tags = pool_config.get("exclude_tags", [])
- include_folders = pool_config.get("include_folders", [])
- exclude_folders = pool_config.get("exclude_folders", [])
- no_credit_required = pool_config.get("no_credit_required", False)
- allow_selling = pool_config.get("allow_selling", False)
+ filter_section = pool_config
+
+ # Extract filter parameters
+ selected_base_models = filter_section.get("baseModels", [])
+ tags_dict = filter_section.get("tags", {})
+ include_tags = tags_dict.get("include", [])
+ exclude_tags = tags_dict.get("exclude", [])
+ folders_dict = filter_section.get("folders", {})
+ include_folders = folders_dict.get("include", [])
+ exclude_folders = folders_dict.get("exclude", [])
+ license_dict = filter_section.get("license", {})
+ no_credit_required = license_dict.get("noCreditRequired", False)
+ allow_selling = license_dict.get("allowSelling", False)
# Build tag filters dict
tag_filters = {}
@@ -384,13 +401,13 @@ class LoraService(BaseModelService):
available_loras = self.filter_set.apply(available_loras, criteria)
# Apply license filters
- # Note: no_credit_required=True means filter out models where credit is NOT required
- # (i.e., keep only models where credit IS required)
+ # no_credit_required=True means keep only models where credit is NOT required
+ # (i.e., allowNoCredit=True, which is bit 0 = 1 in license_flags)
if no_credit_required:
available_loras = [
lora
for lora in available_loras
- if not (lora.get("license_flags", 127) & (1 << 0))
+ if bool(lora.get("license_flags", 127) & (1 << 0))
]
# allow_selling=True means keep only models where selling generated content is allowed
diff --git a/tests/services/test_lora_pool_filters.py b/tests/services/test_lora_pool_filters.py
new file mode 100644
index 00000000..f0169a71
--- /dev/null
+++ b/tests/services/test_lora_pool_filters.py
@@ -0,0 +1,371 @@
+"""Tests for LoraService pool filtering functionality."""
+
+import pytest
+from unittest.mock import Mock, AsyncMock
+
+from py.services.lora_service import LoraService
+from py.utils.civitai_utils import build_license_flags
+
+
+@pytest.fixture
+def lora_service():
+ """Create a LoraService instance for testing."""
+ scanner = Mock()
+ cache_mock = Mock()
+ cache_mock.raw_data = []
+ scanner.get_cached_data = AsyncMock(return_value=cache_mock)
+ scanner._hash_index = Mock()
+ scanner._hash_index.get_duplicate_hashes = Mock(return_value={})
+ scanner._hash_index.get_duplicate_filenames = Mock(return_value={})
+
+ service = LoraService(scanner)
+ service.filter_set = Mock()
+ service.filter_set.apply = Mock(return_value=None)
+
+ return service
+
+
+@pytest.fixture
+def sample_loras():
+ """Sample loras with various license configurations."""
+ return [
+ {
+ "file_name": "credit_required_not_for_selling.safetensors",
+ "base_model": "Illustrious",
+ "folder": "",
+ "license_flags": build_license_flags(
+ {"allowNoCredit": False, "allowCommercialUse": ["Rent"]}
+ ),
+ },
+ {
+ "file_name": "no_credit_required_for_selling.safetensors",
+ "base_model": "Illustrious",
+ "folder": "",
+ "license_flags": build_license_flags(
+ {"allowNoCredit": True, "allowCommercialUse": ["Image"]}
+ ),
+ },
+ {
+ "file_name": "credit_required_for_selling.safetensors",
+ "base_model": "Illustrious",
+ "folder": "",
+ "license_flags": build_license_flags(
+ {"allowNoCredit": False, "allowCommercialUse": ["Image"]}
+ ),
+ },
+ {
+ "file_name": "no_credit_required_not_for_selling.safetensors",
+ "base_model": "Illustrious",
+ "folder": "",
+ "license_flags": build_license_flags(
+ {"allowNoCredit": True, "allowCommercialUse": ["Rent"]}
+ ),
+ },
+ {
+ "file_name": "default_license.safetensors",
+ "base_model": "Illustrious",
+ "folder": "",
+ "license_flags": build_license_flags(None),
+ },
+ ]
+
+
+@pytest.mark.asyncio
+async def test_pool_filter_no_credit_required_true(lora_service, sample_loras):
+ """Test that no_credit_required=True keeps only models where credit is NOT required."""
+ pool_config = {
+ "baseModels": [],
+ "tags": {"include": [], "exclude": []},
+ "folders": {"include": [], "exclude": []},
+ "license": {
+ "noCreditRequired": True,
+ "allowSelling": False,
+ },
+ }
+
+ filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
+
+ # Should keep models with allowNoCredit=True (bit 0 = 1)
+ # Models: no_credit_required_for_selling, no_credit_required_not_for_selling, default_license
+ assert len(filtered) == 3
+ file_names = {lora["file_name"] for lora in filtered}
+ assert file_names == {
+ "no_credit_required_for_selling.safetensors",
+ "no_credit_required_not_for_selling.safetensors",
+ "default_license.safetensors",
+ }
+
+
+@pytest.mark.asyncio
+async def test_pool_filter_no_credit_required_false(lora_service, sample_loras):
+ """Test that no_credit_required=False keeps all models (no filter applied)."""
+ pool_config = {
+ "baseModels": [],
+ "tags": {"include": [], "exclude": []},
+ "folders": {"include": [], "exclude": []},
+ "license": {
+ "noCreditRequired": False,
+ "allowSelling": False,
+ },
+ }
+
+ filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
+
+ # Should keep all models when no_credit_required=False
+ assert len(filtered) == 5
+
+
+@pytest.mark.asyncio
+async def test_pool_filter_allow_selling_true(lora_service, sample_loras):
+ """Test that allowSelling=True keeps only models where selling is allowed."""
+ pool_config = {
+ "baseModels": [],
+ "tags": {"include": [], "exclude": []},
+ "folders": {"include": [], "exclude": []},
+ "license": {
+ "noCreditRequired": False,
+ "allowSelling": True,
+ },
+ }
+
+ filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
+
+ # Should keep models with Image permission (allowSelling)
+ # Models: no_credit_required_for_selling, credit_required_for_selling, default_license
+ assert len(filtered) == 3
+ file_names = {lora["file_name"] for lora in filtered}
+ assert file_names == {
+ "no_credit_required_for_selling.safetensors",
+ "credit_required_for_selling.safetensors",
+ "default_license.safetensors",
+ }
+
+
+@pytest.mark.asyncio
+async def test_pool_filter_allow_selling_false(lora_service, sample_loras):
+ """Test that allowSelling=False keeps all models (no filter applied)."""
+ pool_config = {
+ "baseModels": [],
+ "tags": {"include": [], "exclude": []},
+ "folders": {"include": [], "exclude": []},
+ "license": {
+ "noCreditRequired": False,
+ "allowSelling": False,
+ },
+ }
+
+ filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
+
+ # Should keep all models when allowSelling=False
+ assert len(filtered) == 5
+
+
+@pytest.mark.asyncio
+async def test_pool_filter_both_license_filters(lora_service, sample_loras):
+ """Test combining both no_credit_required and allowSelling filters."""
+ pool_config = {
+ "baseModels": [],
+ "tags": {"include": [], "exclude": []},
+ "folders": {"include": [], "exclude": []},
+ "license": {
+ "noCreditRequired": True,
+ "allowSelling": True,
+ },
+ }
+
+ filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
+
+ # Should keep models where both conditions are met:
+ # - allowNoCredit=True (no credit required)
+ # - Image permission exists (allow selling)
+ # Models: no_credit_required_for_selling, default_license
+ assert len(filtered) == 2
+ file_names = {lora["file_name"] for lora in filtered}
+ assert file_names == {
+ "no_credit_required_for_selling.safetensors",
+ "default_license.safetensors",
+ }
+
+
+@pytest.mark.asyncio
+async def test_pool_filter_base_models(lora_service, sample_loras):
+ """Test filtering by base models."""
+ pool_config = {
+ "baseModels": ["Illustrious"],
+ "tags": {"include": [], "exclude": []},
+ "folders": {"include": [], "exclude": []},
+ "license": {
+ "noCreditRequired": False,
+ "allowSelling": False,
+ },
+ }
+
+ filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
+
+ # All sample loras have base_model="Illustrious"
+ assert len(filtered) == 5
+
+ # Test with non-matching base model
+ pool_config["baseModels"] = ["SD15"]
+ filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
+ assert len(filtered) == 0
+
+
+@pytest.mark.asyncio
+async def test_pool_filter_folders(lora_service):
+ """Test filtering by folders."""
+ sample_loras = [
+ {
+ "file_name": "lora1.safetensors",
+ "base_model": "Illustrious",
+ "folder": "characters/",
+ "license_flags": build_license_flags(None),
+ },
+ {
+ "file_name": "lora2.safetensors",
+ "base_model": "Illustrious",
+ "folder": "styles/",
+ "license_flags": build_license_flags(None),
+ },
+ {
+ "file_name": "lora3.safetensors",
+ "base_model": "Illustrious",
+ "folder": "concepts/",
+ "license_flags": build_license_flags(None),
+ },
+ ]
+
+ # Test include folders
+ pool_config = {
+ "baseModels": [],
+ "tags": {"include": [], "exclude": []},
+ "folders": {"include": ["characters/"], "exclude": []},
+ "license": {
+ "noCreditRequired": False,
+ "allowSelling": False,
+ },
+ }
+
+ filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
+ assert len(filtered) == 1
+ assert filtered[0]["file_name"] == "lora1.safetensors"
+
+ # Test exclude folders
+ pool_config = {
+ "baseModels": [],
+ "tags": {"include": [], "exclude": []},
+ "folders": {"include": [], "exclude": ["characters/"]},
+ "license": {
+ "noCreditRequired": False,
+ "allowSelling": False,
+ },
+ }
+
+ filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
+ assert len(filtered) == 2
+ file_names = {lora["file_name"] for lora in filtered}
+ assert file_names == {"lora2.safetensors", "lora3.safetensors"}
+
+
+@pytest.mark.asyncio
+async def test_pool_filter_tags(lora_service):
+ """Test filtering by tags."""
+ lora_service.filter_set.apply = Mock(side_effect=lambda data, criteria: data)
+
+ sample_loras = [
+ {
+ "file_name": "lora1.safetensors",
+ "base_model": "Illustrious",
+ "folder": "",
+ "tags": ["anime", "character"],
+ "license_flags": build_license_flags(None),
+ },
+ {
+ "file_name": "lora2.safetensors",
+ "base_model": "Illustrious",
+ "folder": "",
+ "tags": ["realistic", "style"],
+ "license_flags": build_license_flags(None),
+ },
+ ]
+
+ pool_config = {
+ "baseModels": [],
+ "tags": {"include": ["anime"], "exclude": []},
+ "folders": {"include": [], "exclude": []},
+ "license": {
+ "noCreditRequired": False,
+ "allowSelling": False,
+ },
+ }
+
+ filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
+
+ # Should call filter_set.apply with tag filters
+ assert lora_service.filter_set.apply.called
+ call_args = lora_service.filter_set.apply.call_args
+ assert call_args[0][0] == sample_loras
+ assert "anime" in call_args[0][1].tags
+
+
+@pytest.mark.asyncio
+async def test_pool_filter_combined_all_filters(lora_service):
+ """Test combining all filter types."""
+ test_loras = [
+ {
+ "file_name": "match_all.safetensors",
+ "base_model": "Illustrious",
+ "folder": "folder1/",
+ "tags": ["tag1"],
+ "license_flags": build_license_flags({"allowNoCredit": True}),
+ },
+ {
+ "file_name": "wrong_base_model.safetensors",
+ "base_model": "SD15",
+ "folder": "folder1/",
+ "tags": ["tag1"],
+ "license_flags": build_license_flags({"allowNoCredit": True}),
+ },
+ {
+ "file_name": "wrong_folder.safetensors",
+ "base_model": "Illustrious",
+ "folder": "folder2/",
+ "tags": ["tag1"],
+ "license_flags": build_license_flags({"allowNoCredit": True}),
+ },
+ {
+ "file_name": "credit_required.safetensors",
+ "base_model": "Illustrious",
+ "folder": "folder1/",
+ "tags": ["tag1"],
+ "license_flags": build_license_flags({"allowNoCredit": False}),
+ },
+ ]
+
+ # Mock tag filter to return all items (simulate tag1 match)
+ def mock_tag_filter(data, criteria):
+ return data
+
+ lora_service.filter_set.apply = Mock(side_effect=mock_tag_filter)
+
+ pool_config = {
+ "baseModels": ["Illustrious"],
+ "tags": {"include": ["tag1"], "exclude": []},
+ "folders": {"include": ["folder1/"], "exclude": []},
+ "license": {
+ "noCreditRequired": True,
+ "allowSelling": False,
+ },
+ }
+
+ filtered = await lora_service._apply_pool_filters(test_loras, pool_config)
+
+ # Should apply all filters
+ assert lora_service.filter_set.apply.called
+ # Only "match_all.safetensors" should match:
+ # - base_model: Illustrious ✓
+ # - folder: folder1/ ✓
+ # - no_credit_required: True ✓ (bit 0 = 1)
+ # - tags: tag1 ✓
+ assert len(filtered) == 1
+ assert filtered[0]["file_name"] == "match_all.safetensors"
diff --git a/vue-widgets/src/components/lora-pool/sections/LicenseSection.vue b/vue-widgets/src/components/lora-pool/sections/LicenseSection.vue
index 5b2e2986..d175436f 100644
--- a/vue-widgets/src/components/lora-pool/sections/LicenseSection.vue
+++ b/vue-widgets/src/components/lora-pool/sections/LicenseSection.vue
@@ -5,7 +5,7 @@