mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(lora-randomizer): refactor randomization logic and add input preprocessing
- Add `_preprocess_loras_input` method to handle different widget input formats - Move core randomization logic to `LoraService` for better separation of concerns - Update `_select_loras` method to use new service-based approach - Add comprehensive test fixtures for license filtering scenarios - Include debug print statement for pool config inspection during development This refactor improves code organization by centralizing business logic in the service layer while maintaining backward compatibility with existing widget inputs.
This commit is contained in:
@@ -39,6 +39,22 @@ class LoraRandomizerNode:
|
||||
FUNCTION = "randomize"
|
||||
OUTPUT_NODE = False
|
||||
|
||||
def _preprocess_loras_input(self, loras):
|
||||
"""
|
||||
Preprocess loras input to handle different widget formats.
|
||||
|
||||
Args:
|
||||
loras: Input from widget, either:
|
||||
- List of LoRA dicts (expected format)
|
||||
- Dict with '__value__' key containing the list
|
||||
|
||||
Returns:
|
||||
List of LoRA dicts
|
||||
"""
|
||||
if isinstance(loras, dict) and "__value__" in loras:
|
||||
return loras["__value__"]
|
||||
return loras
|
||||
|
||||
async def randomize(self, randomizer_config, loras, pool_config=None):
|
||||
"""
|
||||
Randomize LoRAs based on configuration and pool filters.
|
||||
@@ -53,6 +69,8 @@ class LoraRandomizerNode:
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
loras = self._preprocess_loras_input(loras)
|
||||
|
||||
roll_mode = randomizer_config.get("roll_mode", "always")
|
||||
logger.debug(f"[LoraRandomizerNode] roll_mode: {roll_mode}")
|
||||
|
||||
@@ -64,6 +82,8 @@ class LoraRandomizerNode:
|
||||
scanner, randomizer_config, loras, pool_config
|
||||
)
|
||||
|
||||
print("pool config", pool_config)
|
||||
|
||||
execution_stack = self._build_execution_stack_from_input(loras)
|
||||
|
||||
return {
|
||||
@@ -120,6 +140,8 @@ class LoraRandomizerNode:
|
||||
Returns:
|
||||
List of LoRA dicts for UI display
|
||||
"""
|
||||
from ..services.lora_service import LoraService
|
||||
|
||||
# Parse randomizer settings
|
||||
count_mode = randomizer_config.get("count_mode", "range")
|
||||
count_fixed = randomizer_config.get("count_fixed", 5)
|
||||
@@ -131,183 +153,23 @@ class LoraRandomizerNode:
|
||||
clip_strength_min = randomizer_config.get("clip_strength_min", 0.0)
|
||||
clip_strength_max = randomizer_config.get("clip_strength_max", 1.0)
|
||||
|
||||
# Determine target count
|
||||
if count_mode == "fixed":
|
||||
target_count = count_fixed
|
||||
else:
|
||||
target_count = random.randint(count_min, count_max)
|
||||
|
||||
# Extract locked LoRAs from input
|
||||
locked_loras = [lora for lora in input_loras if lora.get("locked", False)]
|
||||
locked_count = len(locked_loras)
|
||||
|
||||
# Get available loras from cache
|
||||
try:
|
||||
cache_data = await scanner.get_cached_data(force_refresh=False)
|
||||
if cache_data and hasattr(cache_data, "raw_data"):
|
||||
available_loras = cache_data.raw_data
|
||||
else:
|
||||
available_loras = []
|
||||
except Exception as e:
|
||||
logger.warning(f"[LoraRandomizerNode] Failed to get lora cache: {e}")
|
||||
available_loras = []
|
||||
|
||||
# Apply pool filters if provided
|
||||
if pool_config:
|
||||
available_loras = await self._apply_pool_filters(
|
||||
available_loras, pool_config, scanner
|
||||
)
|
||||
|
||||
# Calculate how many new LoRAs to select
|
||||
slots_needed = target_count - locked_count
|
||||
|
||||
if slots_needed < 0:
|
||||
slots_needed = 0
|
||||
# Too many locked, trim to target
|
||||
locked_loras = locked_loras[:target_count]
|
||||
locked_count = len(locked_loras)
|
||||
|
||||
# Filter out locked LoRAs from available pool
|
||||
locked_names = {lora["name"] for lora in locked_loras}
|
||||
available_pool = [
|
||||
l for l in available_loras if l["file_name"] not in locked_names
|
||||
]
|
||||
|
||||
# Ensure we don't try to select more than available
|
||||
if slots_needed > len(available_pool):
|
||||
slots_needed = len(available_pool)
|
||||
|
||||
# Random sample
|
||||
selected = []
|
||||
if slots_needed > 0:
|
||||
selected = random.sample(available_pool, slots_needed)
|
||||
|
||||
# Generate random strengths for selected LoRAs
|
||||
result_loras = []
|
||||
for lora in selected:
|
||||
model_str = round(random.uniform(model_strength_min, model_strength_max), 2)
|
||||
|
||||
if use_same_clip_strength:
|
||||
clip_str = model_str
|
||||
else:
|
||||
clip_str = round(
|
||||
random.uniform(clip_strength_min, clip_strength_max), 2
|
||||
)
|
||||
|
||||
result_loras.append(
|
||||
{
|
||||
"name": lora["file_name"],
|
||||
"strength": model_str,
|
||||
"clipStrength": clip_str,
|
||||
"active": True,
|
||||
"expanded": abs(model_str - clip_str) > 0.001,
|
||||
"locked": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Merge with locked LoRAs
|
||||
result_loras.extend(locked_loras)
|
||||
# Use LoraService to generate random LoRAs
|
||||
lora_service = LoraService(scanner)
|
||||
result_loras = await lora_service.get_random_loras(
|
||||
count=count_fixed,
|
||||
model_strength_min=model_strength_min,
|
||||
model_strength_max=model_strength_max,
|
||||
use_same_clip_strength=use_same_clip_strength,
|
||||
clip_strength_min=clip_strength_min,
|
||||
clip_strength_max=clip_strength_max,
|
||||
locked_loras=locked_loras,
|
||||
pool_config=pool_config,
|
||||
count_mode=count_mode,
|
||||
count_min=count_min,
|
||||
count_max=count_max,
|
||||
)
|
||||
|
||||
return result_loras
|
||||
|
||||
async def _apply_pool_filters(self, available_loras, pool_config, scanner):
|
||||
"""
|
||||
Apply pool_config filters to available LoRAs.
|
||||
|
||||
Args:
|
||||
available_loras: List of all LoRA dicts
|
||||
pool_config: Dict with filter settings from LoRA Pool node
|
||||
scanner: Scanner instance for accessing filter utilities
|
||||
|
||||
Returns:
|
||||
Filtered list of LoRA dicts
|
||||
"""
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.model_query import FilterCriteria
|
||||
|
||||
# Create lora service instance for filtering
|
||||
lora_service = LoraService(scanner)
|
||||
|
||||
# Extract filter parameters from pool_config
|
||||
selected_base_models = pool_config.get("baseModels", [])
|
||||
tags_dict = pool_config.get("tags", {})
|
||||
include_tags = tags_dict.get("include", [])
|
||||
exclude_tags = tags_dict.get("exclude", [])
|
||||
folders_dict = pool_config.get("folders", {})
|
||||
include_folders = folders_dict.get("include", [])
|
||||
exclude_folders = folders_dict.get("exclude", [])
|
||||
license_dict = pool_config.get("license", {})
|
||||
no_credit_required = license_dict.get("noCreditRequired", False)
|
||||
allow_selling = license_dict.get("allowSelling", False)
|
||||
|
||||
# Build tag filters dict
|
||||
tag_filters = {}
|
||||
for tag in include_tags:
|
||||
tag_filters[tag] = "include"
|
||||
for tag in exclude_tags:
|
||||
tag_filters[tag] = "exclude"
|
||||
|
||||
# Build folder filter
|
||||
# LoRA Pool uses include/exclude folders, we need to apply this logic
|
||||
# For now, we'll filter based on folder path matching
|
||||
if include_folders or exclude_folders:
|
||||
filtered = []
|
||||
for lora in available_loras:
|
||||
folder = lora.get("folder", "")
|
||||
|
||||
# Check exclude folders first
|
||||
excluded = False
|
||||
for exclude_folder in exclude_folders:
|
||||
if folder.startswith(exclude_folder):
|
||||
excluded = True
|
||||
break
|
||||
|
||||
if excluded:
|
||||
continue
|
||||
|
||||
# Check include folders
|
||||
if include_folders:
|
||||
included = False
|
||||
for include_folder in include_folders:
|
||||
if folder.startswith(include_folder):
|
||||
included = True
|
||||
break
|
||||
if not included:
|
||||
continue
|
||||
|
||||
filtered.append(lora)
|
||||
|
||||
available_loras = filtered
|
||||
|
||||
# Apply base model filter
|
||||
if selected_base_models:
|
||||
available_loras = [
|
||||
lora
|
||||
for lora in available_loras
|
||||
if lora.get("base_model") in selected_base_models
|
||||
]
|
||||
|
||||
# Apply tag filters
|
||||
if tag_filters:
|
||||
criteria = FilterCriteria(tags=tag_filters)
|
||||
available_loras = lora_service.filter_set.apply(available_loras, criteria)
|
||||
|
||||
# Apply license filters
|
||||
# Note: no_credit_required=True means filter out models where credit is NOT required
|
||||
# (i.e., keep only models where credit IS required)
|
||||
if no_credit_required:
|
||||
available_loras = [
|
||||
lora
|
||||
for lora in available_loras
|
||||
if not (lora.get("license_flags", 127) & (1 << 0))
|
||||
]
|
||||
|
||||
# allow_selling=True means keep only models where selling generated content is allowed
|
||||
if allow_selling:
|
||||
available_loras = [
|
||||
lora
|
||||
for lora in available_loras
|
||||
if bool(lora.get("license_flags", 127) & (1 << 1))
|
||||
]
|
||||
|
||||
return available_loras
|
||||
|
||||
@@ -225,12 +225,15 @@ class LoraService(BaseModelService):
|
||||
clip_strength_max: float = 1.0,
|
||||
locked_loras: Optional[List[Dict]] = None,
|
||||
pool_config: Optional[Dict] = None,
|
||||
count_mode: str = "fixed",
|
||||
count_min: int = 3,
|
||||
count_max: int = 7,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Get random LoRAs with specified strength ranges.
|
||||
|
||||
Args:
|
||||
count: Number of LoRAs to select
|
||||
count: Number of LoRAs to select (if count_mode='fixed')
|
||||
model_strength_min: Minimum model strength
|
||||
model_strength_max: Maximum model strength
|
||||
use_same_clip_strength: Whether to use same strength for clip
|
||||
@@ -238,6 +241,9 @@ class LoraService(BaseModelService):
|
||||
clip_strength_max: Maximum clip strength
|
||||
locked_loras: List of locked LoRA dicts to preserve
|
||||
pool_config: Optional pool config for filtering
|
||||
count_mode: How to determine count ('fixed' or 'range')
|
||||
count_min: Minimum count for range mode
|
||||
count_max: Maximum count for range mode
|
||||
|
||||
Returns:
|
||||
List of LoRA dicts with randomized strengths
|
||||
@@ -247,6 +253,12 @@ class LoraService(BaseModelService):
|
||||
if locked_loras is None:
|
||||
locked_loras = []
|
||||
|
||||
# Determine target count based on count_mode
|
||||
if count_mode == "fixed":
|
||||
target_count = count
|
||||
else:
|
||||
target_count = random.randint(count_min, count_max)
|
||||
|
||||
# Get available loras from cache
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
available_loras = cache.raw_data if cache else []
|
||||
@@ -259,12 +271,12 @@ class LoraService(BaseModelService):
|
||||
|
||||
# Calculate slots needed (total - locked)
|
||||
locked_count = len(locked_loras)
|
||||
slots_needed = count - locked_count
|
||||
slots_needed = target_count - locked_count
|
||||
|
||||
if slots_needed < 0:
|
||||
slots_needed = 0
|
||||
# Too many locked, trim to target
|
||||
locked_loras = locked_loras[:count]
|
||||
locked_loras = locked_loras[:target_count]
|
||||
|
||||
# Filter out locked LoRAs from available pool
|
||||
locked_names = {lora["name"] for lora in locked_loras}
|
||||
@@ -324,14 +336,19 @@ class LoraService(BaseModelService):
|
||||
"""
|
||||
from .model_query import FilterCriteria
|
||||
|
||||
# Extract filter parameters from pool_config
|
||||
selected_base_models = pool_config.get("selected_base_models", [])
|
||||
include_tags = pool_config.get("include_tags", [])
|
||||
exclude_tags = pool_config.get("exclude_tags", [])
|
||||
include_folders = pool_config.get("include_folders", [])
|
||||
exclude_folders = pool_config.get("exclude_folders", [])
|
||||
no_credit_required = pool_config.get("no_credit_required", False)
|
||||
allow_selling = pool_config.get("allow_selling", False)
|
||||
filter_section = pool_config
|
||||
|
||||
# Extract filter parameters
|
||||
selected_base_models = filter_section.get("baseModels", [])
|
||||
tags_dict = filter_section.get("tags", {})
|
||||
include_tags = tags_dict.get("include", [])
|
||||
exclude_tags = tags_dict.get("exclude", [])
|
||||
folders_dict = filter_section.get("folders", {})
|
||||
include_folders = folders_dict.get("include", [])
|
||||
exclude_folders = folders_dict.get("exclude", [])
|
||||
license_dict = filter_section.get("license", {})
|
||||
no_credit_required = license_dict.get("noCreditRequired", False)
|
||||
allow_selling = license_dict.get("allowSelling", False)
|
||||
|
||||
# Build tag filters dict
|
||||
tag_filters = {}
|
||||
@@ -384,13 +401,13 @@ class LoraService(BaseModelService):
|
||||
available_loras = self.filter_set.apply(available_loras, criteria)
|
||||
|
||||
# Apply license filters
|
||||
# Note: no_credit_required=True means filter out models where credit is NOT required
|
||||
# (i.e., keep only models where credit IS required)
|
||||
# no_credit_required=True means keep only models where credit is NOT required
|
||||
# (i.e., allowNoCredit=True, which is bit 0 = 1 in license_flags)
|
||||
if no_credit_required:
|
||||
available_loras = [
|
||||
lora
|
||||
for lora in available_loras
|
||||
if not (lora.get("license_flags", 127) & (1 << 0))
|
||||
if bool(lora.get("license_flags", 127) & (1 << 0))
|
||||
]
|
||||
|
||||
# allow_selling=True means keep only models where selling generated content is allowed
|
||||
|
||||
371
tests/services/test_lora_pool_filters.py
Normal file
371
tests/services/test_lora_pool_filters.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""Tests for LoraService pool filtering functionality."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
from py.services.lora_service import LoraService
|
||||
from py.utils.civitai_utils import build_license_flags
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lora_service():
|
||||
"""Create a LoraService instance for testing."""
|
||||
scanner = Mock()
|
||||
cache_mock = Mock()
|
||||
cache_mock.raw_data = []
|
||||
scanner.get_cached_data = AsyncMock(return_value=cache_mock)
|
||||
scanner._hash_index = Mock()
|
||||
scanner._hash_index.get_duplicate_hashes = Mock(return_value={})
|
||||
scanner._hash_index.get_duplicate_filenames = Mock(return_value={})
|
||||
|
||||
service = LoraService(scanner)
|
||||
service.filter_set = Mock()
|
||||
service.filter_set.apply = Mock(return_value=None)
|
||||
|
||||
return service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_loras():
|
||||
"""Sample loras with various license configurations."""
|
||||
return [
|
||||
{
|
||||
"file_name": "credit_required_not_for_selling.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(
|
||||
{"allowNoCredit": False, "allowCommercialUse": ["Rent"]}
|
||||
),
|
||||
},
|
||||
{
|
||||
"file_name": "no_credit_required_for_selling.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(
|
||||
{"allowNoCredit": True, "allowCommercialUse": ["Image"]}
|
||||
),
|
||||
},
|
||||
{
|
||||
"file_name": "credit_required_for_selling.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(
|
||||
{"allowNoCredit": False, "allowCommercialUse": ["Image"]}
|
||||
),
|
||||
},
|
||||
{
|
||||
"file_name": "no_credit_required_not_for_selling.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(
|
||||
{"allowNoCredit": True, "allowCommercialUse": ["Rent"]}
|
||||
),
|
||||
},
|
||||
{
|
||||
"file_name": "default_license.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_no_credit_required_true(lora_service, sample_loras):
|
||||
"""Test that no_credit_required=True keeps only models where credit is NOT required."""
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {
|
||||
"noCreditRequired": True,
|
||||
"allowSelling": False,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
|
||||
# Should keep models with allowNoCredit=True (bit 0 = 1)
|
||||
# Models: no_credit_required_for_selling, no_credit_required_not_for_selling, default_license
|
||||
assert len(filtered) == 3
|
||||
file_names = {lora["file_name"] for lora in filtered}
|
||||
assert file_names == {
|
||||
"no_credit_required_for_selling.safetensors",
|
||||
"no_credit_required_not_for_selling.safetensors",
|
||||
"default_license.safetensors",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_no_credit_required_false(lora_service, sample_loras):
|
||||
"""Test that no_credit_required=False keeps all models (no filter applied)."""
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {
|
||||
"noCreditRequired": False,
|
||||
"allowSelling": False,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
|
||||
# Should keep all models when no_credit_required=False
|
||||
assert len(filtered) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_allow_selling_true(lora_service, sample_loras):
|
||||
"""Test that allowSelling=True keeps only models where selling is allowed."""
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {
|
||||
"noCreditRequired": False,
|
||||
"allowSelling": True,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
|
||||
# Should keep models with Image permission (allowSelling)
|
||||
# Models: no_credit_required_for_selling, credit_required_for_selling, default_license
|
||||
assert len(filtered) == 3
|
||||
file_names = {lora["file_name"] for lora in filtered}
|
||||
assert file_names == {
|
||||
"no_credit_required_for_selling.safetensors",
|
||||
"credit_required_for_selling.safetensors",
|
||||
"default_license.safetensors",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_allow_selling_false(lora_service, sample_loras):
|
||||
"""Test that allowSelling=False keeps all models (no filter applied)."""
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {
|
||||
"noCreditRequired": False,
|
||||
"allowSelling": False,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
|
||||
# Should keep all models when allowSelling=False
|
||||
assert len(filtered) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_both_license_filters(lora_service, sample_loras):
|
||||
"""Test combining both no_credit_required and allowSelling filters."""
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {
|
||||
"noCreditRequired": True,
|
||||
"allowSelling": True,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
|
||||
# Should keep models where both conditions are met:
|
||||
# - allowNoCredit=True (no credit required)
|
||||
# - Image permission exists (allow selling)
|
||||
# Models: no_credit_required_for_selling, default_license
|
||||
assert len(filtered) == 2
|
||||
file_names = {lora["file_name"] for lora in filtered}
|
||||
assert file_names == {
|
||||
"no_credit_required_for_selling.safetensors",
|
||||
"default_license.safetensors",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_base_models(lora_service, sample_loras):
|
||||
"""Test filtering by base models."""
|
||||
pool_config = {
|
||||
"baseModels": ["Illustrious"],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {
|
||||
"noCreditRequired": False,
|
||||
"allowSelling": False,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
|
||||
# All sample loras have base_model="Illustrious"
|
||||
assert len(filtered) == 5
|
||||
|
||||
# Test with non-matching base model
|
||||
pool_config["baseModels"] = ["SD15"]
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
assert len(filtered) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_folders(lora_service):
|
||||
"""Test filtering by folders."""
|
||||
sample_loras = [
|
||||
{
|
||||
"file_name": "lora1.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "characters/",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "lora2.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "styles/",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "lora3.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "concepts/",
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
]
|
||||
|
||||
# Test include folders
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": ["characters/"], "exclude": []},
|
||||
"license": {
|
||||
"noCreditRequired": False,
|
||||
"allowSelling": False,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["file_name"] == "lora1.safetensors"
|
||||
|
||||
# Test exclude folders
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": [], "exclude": []},
|
||||
"folders": {"include": [], "exclude": ["characters/"]},
|
||||
"license": {
|
||||
"noCreditRequired": False,
|
||||
"allowSelling": False,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
assert len(filtered) == 2
|
||||
file_names = {lora["file_name"] for lora in filtered}
|
||||
assert file_names == {"lora2.safetensors", "lora3.safetensors"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_tags(lora_service):
|
||||
"""Test filtering by tags."""
|
||||
lora_service.filter_set.apply = Mock(side_effect=lambda data, criteria: data)
|
||||
|
||||
sample_loras = [
|
||||
{
|
||||
"file_name": "lora1.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"tags": ["anime", "character"],
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
{
|
||||
"file_name": "lora2.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "",
|
||||
"tags": ["realistic", "style"],
|
||||
"license_flags": build_license_flags(None),
|
||||
},
|
||||
]
|
||||
|
||||
pool_config = {
|
||||
"baseModels": [],
|
||||
"tags": {"include": ["anime"], "exclude": []},
|
||||
"folders": {"include": [], "exclude": []},
|
||||
"license": {
|
||||
"noCreditRequired": False,
|
||||
"allowSelling": False,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(sample_loras, pool_config)
|
||||
|
||||
# Should call filter_set.apply with tag filters
|
||||
assert lora_service.filter_set.apply.called
|
||||
call_args = lora_service.filter_set.apply.call_args
|
||||
assert call_args[0][0] == sample_loras
|
||||
assert "anime" in call_args[0][1].tags
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_filter_combined_all_filters(lora_service):
|
||||
"""Test combining all filter types."""
|
||||
test_loras = [
|
||||
{
|
||||
"file_name": "match_all.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "folder1/",
|
||||
"tags": ["tag1"],
|
||||
"license_flags": build_license_flags({"allowNoCredit": True}),
|
||||
},
|
||||
{
|
||||
"file_name": "wrong_base_model.safetensors",
|
||||
"base_model": "SD15",
|
||||
"folder": "folder1/",
|
||||
"tags": ["tag1"],
|
||||
"license_flags": build_license_flags({"allowNoCredit": True}),
|
||||
},
|
||||
{
|
||||
"file_name": "wrong_folder.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "folder2/",
|
||||
"tags": ["tag1"],
|
||||
"license_flags": build_license_flags({"allowNoCredit": True}),
|
||||
},
|
||||
{
|
||||
"file_name": "credit_required.safetensors",
|
||||
"base_model": "Illustrious",
|
||||
"folder": "folder1/",
|
||||
"tags": ["tag1"],
|
||||
"license_flags": build_license_flags({"allowNoCredit": False}),
|
||||
},
|
||||
]
|
||||
|
||||
# Mock tag filter to return all items (simulate tag1 match)
|
||||
def mock_tag_filter(data, criteria):
|
||||
return data
|
||||
|
||||
lora_service.filter_set.apply = Mock(side_effect=mock_tag_filter)
|
||||
|
||||
pool_config = {
|
||||
"baseModels": ["Illustrious"],
|
||||
"tags": {"include": ["tag1"], "exclude": []},
|
||||
"folders": {"include": ["folder1/"], "exclude": []},
|
||||
"license": {
|
||||
"noCreditRequired": True,
|
||||
"allowSelling": False,
|
||||
},
|
||||
}
|
||||
|
||||
filtered = await lora_service._apply_pool_filters(test_loras, pool_config)
|
||||
|
||||
# Should apply all filters
|
||||
assert lora_service.filter_set.apply.called
|
||||
# Only "match_all.safetensors" should match:
|
||||
# - base_model: Illustrious ✓
|
||||
# - folder: folder1/ ✓
|
||||
# - no_credit_required: True ✓ (bit 0 = 1)
|
||||
# - tags: tag1 ✓
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["file_name"] == "match_all.safetensors"
|
||||
@@ -5,7 +5,7 @@
|
||||
</div>
|
||||
<div class="section__toggles">
|
||||
<label class="toggle-item">
|
||||
<span class="toggle-item__label">No Credit</span>
|
||||
<span class="toggle-item__label">No Credit Required</span>
|
||||
<button
|
||||
type="button"
|
||||
class="toggle-switch"
|
||||
|
||||
@@ -88,16 +88,7 @@ export function useLoraRandomizerState(widget: ComponentWidget) {
|
||||
|
||||
// Add pool config if provided
|
||||
if (poolConfig) {
|
||||
// Convert pool config to backend format
|
||||
requestBody.pool_config = {
|
||||
selected_base_models: poolConfig.filters?.baseModels || [],
|
||||
include_tags: poolConfig.filters?.tags?.include || [],
|
||||
exclude_tags: poolConfig.filters?.tags?.exclude || [],
|
||||
include_folders: poolConfig.filters?.folders?.include || [],
|
||||
exclude_folders: poolConfig.filters?.folders?.exclude || [],
|
||||
no_credit_required: poolConfig.filters?.license?.noCreditRequired || false,
|
||||
allow_selling: poolConfig.filters?.license?.allowSelling || false,
|
||||
}
|
||||
requestBody.pool_config = poolConfig.filters || {}
|
||||
}
|
||||
|
||||
// Call API endpoint
|
||||
|
||||
@@ -734,10 +734,6 @@ export function addLorasWidget(node, name, opts, callback) {
|
||||
|
||||
widget.callback = callback;
|
||||
|
||||
widget.serializeValue = () => {
|
||||
return widgetValue;
|
||||
}
|
||||
|
||||
widget.onRemove = () => {
|
||||
container.remove();
|
||||
previewTooltip.cleanup();
|
||||
|
||||
@@ -283,13 +283,13 @@
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.section[data-v-8b49983c] {
|
||||
.section[data-v-66148794] {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.section__header[data-v-8b49983c] {
|
||||
.section__header[data-v-66148794] {
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.section__title[data-v-8b49983c] {
|
||||
.section__title[data-v-66148794] {
|
||||
font-size: 10px;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
@@ -297,21 +297,21 @@
|
||||
color: var(--fg-color, #fff);
|
||||
opacity: 0.6;
|
||||
}
|
||||
.section__toggles[data-v-8b49983c] {
|
||||
.section__toggles[data-v-66148794] {
|
||||
display: flex;
|
||||
gap: 16px;
|
||||
}
|
||||
.toggle-item[data-v-8b49983c] {
|
||||
.toggle-item[data-v-66148794] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
cursor: pointer;
|
||||
}
|
||||
.toggle-item__label[data-v-8b49983c] {
|
||||
.toggle-item__label[data-v-66148794] {
|
||||
font-size: 12px;
|
||||
color: var(--fg-color, #fff);
|
||||
}
|
||||
.toggle-switch[data-v-8b49983c] {
|
||||
.toggle-switch[data-v-66148794] {
|
||||
position: relative;
|
||||
width: 36px;
|
||||
height: 20px;
|
||||
@@ -320,7 +320,7 @@
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
.toggle-switch__track[data-v-8b49983c] {
|
||||
.toggle-switch__track[data-v-66148794] {
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
background: var(--comfy-input-bg, #333);
|
||||
@@ -328,11 +328,11 @@
|
||||
border-radius: 10px;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
.toggle-switch--active .toggle-switch__track[data-v-8b49983c] {
|
||||
.toggle-switch--active .toggle-switch__track[data-v-66148794] {
|
||||
background: rgba(66, 153, 225, 0.3);
|
||||
border-color: rgba(66, 153, 225, 0.6);
|
||||
}
|
||||
.toggle-switch__thumb[data-v-8b49983c] {
|
||||
.toggle-switch__thumb[data-v-66148794] {
|
||||
position: absolute;
|
||||
top: 2px;
|
||||
left: 2px;
|
||||
@@ -343,12 +343,12 @@
|
||||
transition: all 0.2s;
|
||||
opacity: 0.6;
|
||||
}
|
||||
.toggle-switch--active .toggle-switch__thumb[data-v-8b49983c] {
|
||||
.toggle-switch--active .toggle-switch__thumb[data-v-66148794] {
|
||||
transform: translateX(16px);
|
||||
background: #4299e1;
|
||||
opacity: 1;
|
||||
}
|
||||
.toggle-switch:hover .toggle-switch__thumb[data-v-8b49983c] {
|
||||
.toggle-switch:hover .toggle-switch__thumb[data-v-66148794] {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
@@ -10092,7 +10092,7 @@ const _sfc_main$d = /* @__PURE__ */ defineComponent({
|
||||
], -1)),
|
||||
createBaseVNode("div", _hoisted_2$a, [
|
||||
createBaseVNode("label", _hoisted_3$8, [
|
||||
_cache[3] || (_cache[3] = createBaseVNode("span", { class: "toggle-item__label" }, "No Credit", -1)),
|
||||
_cache[3] || (_cache[3] = createBaseVNode("span", { class: "toggle-item__label" }, "No Credit Required", -1)),
|
||||
createBaseVNode("button", {
|
||||
type: "button",
|
||||
class: normalizeClass(["toggle-switch", { "toggle-switch--active": __props.noCreditRequired }]),
|
||||
@@ -10122,7 +10122,7 @@ const _sfc_main$d = /* @__PURE__ */ defineComponent({
|
||||
};
|
||||
}
|
||||
});
|
||||
const LicenseSection = /* @__PURE__ */ _export_sfc(_sfc_main$d, [["__scopeId", "data-v-8b49983c"]]);
|
||||
const LicenseSection = /* @__PURE__ */ _export_sfc(_sfc_main$d, [["__scopeId", "data-v-66148794"]]);
|
||||
const _hoisted_1$c = { class: "preview" };
|
||||
const _hoisted_2$9 = { class: "preview__title" };
|
||||
const _hoisted_3$7 = ["disabled"];
|
||||
@@ -11756,7 +11756,6 @@ function useLoraRandomizerState(widget) {
|
||||
lastUsed.value = config.last_used || null;
|
||||
};
|
||||
const rollLoras = async (poolConfig, lockedLoras) => {
|
||||
var _a, _b, _c, _d, _e2, _f, _g, _h, _i, _j, _k, _l, _m;
|
||||
try {
|
||||
isRolling.value = true;
|
||||
const config = buildConfig();
|
||||
@@ -11775,15 +11774,7 @@ function useLoraRandomizerState(widget) {
|
||||
requestBody.count_max = config.count_max;
|
||||
}
|
||||
if (poolConfig) {
|
||||
requestBody.pool_config = {
|
||||
selected_base_models: ((_a = poolConfig.filters) == null ? void 0 : _a.baseModels) || [],
|
||||
include_tags: ((_c = (_b = poolConfig.filters) == null ? void 0 : _b.tags) == null ? void 0 : _c.include) || [],
|
||||
exclude_tags: ((_e2 = (_d = poolConfig.filters) == null ? void 0 : _d.tags) == null ? void 0 : _e2.exclude) || [],
|
||||
include_folders: ((_g = (_f = poolConfig.filters) == null ? void 0 : _f.folders) == null ? void 0 : _g.include) || [],
|
||||
exclude_folders: ((_i = (_h = poolConfig.filters) == null ? void 0 : _h.folders) == null ? void 0 : _i.exclude) || [],
|
||||
no_credit_required: ((_k = (_j = poolConfig.filters) == null ? void 0 : _j.license) == null ? void 0 : _k.noCreditRequired) || false,
|
||||
allow_selling: ((_m = (_l = poolConfig.filters) == null ? void 0 : _l.license) == null ? void 0 : _m.allowSelling) || false
|
||||
};
|
||||
requestBody.pool_config = poolConfig.filters || {};
|
||||
}
|
||||
const response = await fetch("/api/lm/loras/random-sample", {
|
||||
method: "POST",
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user