mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
feat(randomizer): add LoRA locking and roll modes
- Implement LoRA locking to prevent specific LoRAs from being changed during randomization - Add visual styling for locked state with amber accents and distinct backgrounds - Introduce `roll_mode` configuration with 'backend' (execute current selection while generating new) and 'frontend' (execute newly generated selection) behaviors - Move LoraPoolNode to 'Lora Manager/randomizer' category and remove standalone class mappings - Standardize RETURN_NAMES in LoraRandomizerNode for consistency
This commit is contained in:
@@ -20,7 +20,7 @@ class LoraPoolNode:
|
||||
"""
|
||||
|
||||
NAME = "Lora Pool (LoraManager)"
|
||||
CATEGORY = "Lora Manager/pools"
|
||||
CATEGORY = "Lora Manager/randomizer"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
@@ -85,10 +85,3 @@ class LoraPoolNode:
|
||||
},
|
||||
"preview": {"matchCount": 0, "lastUpdated": 0},
|
||||
}
|
||||
|
||||
|
||||
# Node class mappings for ComfyUI
|
||||
NODE_CLASS_MAPPINGS = {"LoraPoolNode": LoraPoolNode}
|
||||
|
||||
# Display name mappings
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {"LoraPoolNode": "LoRA Pool (Filter)"}
|
||||
|
||||
@@ -34,7 +34,7 @@ class LoraRandomizerNode:
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LORA_STACK",)
|
||||
RETURN_NAMES = ("lora_stack",)
|
||||
RETURN_NAMES = ("LORA_STACK",)
|
||||
|
||||
FUNCTION = "randomize"
|
||||
OUTPUT_NODE = False
|
||||
@@ -53,9 +53,6 @@ class LoraRandomizerNode:
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
# Get lora scanner to access available loras
|
||||
scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Parse randomizer settings
|
||||
count_mode = randomizer_config.get("count_mode", "range")
|
||||
count_fixed = randomizer_config.get("count_fixed", 5)
|
||||
@@ -68,6 +65,93 @@ class LoraRandomizerNode:
|
||||
clip_strength_max = randomizer_config.get("clip_strength_max", 1.0)
|
||||
roll_mode = randomizer_config.get("roll_mode", "frontend")
|
||||
|
||||
# Get lora scanner to access available loras
|
||||
scanner = await ServiceRegistry.get_lora_scanner()
|
||||
|
||||
# Backend roll mode: execute with input loras, return new random to UI
|
||||
if roll_mode == "backend":
|
||||
execution_stack = self._build_execution_stack_from_input(loras)
|
||||
ui_loras = await self._generate_random_loras_for_ui(
|
||||
scanner, randomizer_config, loras, pool_config
|
||||
)
|
||||
logger.info(
|
||||
f"[LoraRandomizerNode] Backend roll: executing with input, returning new random to UI"
|
||||
)
|
||||
return {"result": (execution_stack,), "ui": {"loras": ui_loras}}
|
||||
|
||||
# Frontend roll mode: use current behavior (random selection for both)
|
||||
ui_loras = await self._generate_random_loras_for_ui(
|
||||
scanner, randomizer_config, loras, pool_config
|
||||
)
|
||||
execution_stack = self._build_execution_stack_from_input(ui_loras)
|
||||
logger.info(
|
||||
f"[LoraRandomizerNode] Frontend roll: executing with random selection"
|
||||
)
|
||||
return {"result": (execution_stack,), "ui": {"loras": ui_loras}}
|
||||
|
||||
def _build_execution_stack_from_input(self, loras):
|
||||
"""
|
||||
Build LORA_STACK tuple from input loras list for execution.
|
||||
|
||||
Args:
|
||||
loras: List of LoRA dicts with name, strength, clipStrength, active
|
||||
|
||||
Returns:
|
||||
List of tuples (lora_path, model_strength, clip_strength)
|
||||
"""
|
||||
lora_stack = []
|
||||
for lora in loras:
|
||||
if not lora.get("active", False):
|
||||
continue
|
||||
|
||||
# Get file path
|
||||
lora_path, trigger_words = get_lora_info(lora["name"])
|
||||
if not lora_path:
|
||||
logger.warning(
|
||||
f"[LoraRandomizerNode] Could not find path for LoRA: {lora['name']}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Normalize path separators
|
||||
lora_path = lora_path.replace("/", os.sep)
|
||||
|
||||
# Extract strengths
|
||||
model_strength = lora.get("strength", 1.0)
|
||||
clip_strength = lora.get("clipStrength", model_strength)
|
||||
|
||||
lora_stack.append((lora_path, model_strength, clip_strength))
|
||||
|
||||
logger.info(
|
||||
f"[LoraRandomizerNode] Built execution stack with {len(lora_stack)} LoRAs"
|
||||
)
|
||||
return lora_stack
|
||||
|
||||
async def _generate_random_loras_for_ui(
|
||||
self, scanner, randomizer_config, input_loras, pool_config=None
|
||||
):
|
||||
"""
|
||||
Generate new random loras for UI display.
|
||||
|
||||
Args:
|
||||
scanner: LoraScanner instance
|
||||
randomizer_config: Dict with randomizer settings
|
||||
input_loras: Current input loras (for extracting locked loras)
|
||||
pool_config: Optional pool filters
|
||||
|
||||
Returns:
|
||||
List of LoRA dicts for UI display
|
||||
"""
|
||||
# Parse randomizer settings
|
||||
count_mode = randomizer_config.get("count_mode", "range")
|
||||
count_fixed = randomizer_config.get("count_fixed", 5)
|
||||
count_min = randomizer_config.get("count_min", 3)
|
||||
count_max = randomizer_config.get("count_max", 7)
|
||||
model_strength_min = randomizer_config.get("model_strength_min", 0.0)
|
||||
model_strength_max = randomizer_config.get("model_strength_max", 1.0)
|
||||
use_same_clip_strength = randomizer_config.get("use_same_clip_strength", True)
|
||||
clip_strength_min = randomizer_config.get("clip_strength_min", 0.0)
|
||||
clip_strength_max = randomizer_config.get("clip_strength_max", 1.0)
|
||||
|
||||
# Determine target count
|
||||
if count_mode == "fixed":
|
||||
target_count = count_fixed
|
||||
@@ -75,11 +159,11 @@ class LoraRandomizerNode:
|
||||
target_count = random.randint(count_min, count_max)
|
||||
|
||||
logger.info(
|
||||
f"[LoraRandomizerNode] Target count: {target_count}, Roll mode: {roll_mode}"
|
||||
f"[LoraRandomizerNode] Generating random LoRAs, target count: {target_count}"
|
||||
)
|
||||
|
||||
# Extract locked LoRAs from input
|
||||
locked_loras = [lora for lora in loras if lora.get("locked", False)]
|
||||
locked_loras = [lora for lora in input_loras if lora.get("locked", False)]
|
||||
locked_count = len(locked_loras)
|
||||
|
||||
logger.info(f"[LoraRandomizerNode] Locked LoRAs: {locked_count}")
|
||||
@@ -106,8 +190,6 @@ class LoraRandomizerNode:
|
||||
)
|
||||
|
||||
# Calculate how many new LoRAs to select
|
||||
# In frontend mode, if loras already has data, preserve unlocked ones if roll_mode requires
|
||||
# For simplicity in backend mode, we regenerate all unlocked slots
|
||||
slots_needed = target_count - locked_count
|
||||
|
||||
if slots_needed < 0:
|
||||
@@ -161,33 +243,10 @@ class LoraRandomizerNode:
|
||||
# Merge with locked LoRAs
|
||||
result_loras.extend(locked_loras)
|
||||
|
||||
logger.info(f"[LoraRandomizerNode] Final LoRA count: {len(result_loras)}")
|
||||
|
||||
# Build LORA_STACK output
|
||||
lora_stack = []
|
||||
for lora in result_loras:
|
||||
if not lora.get("active", False):
|
||||
continue
|
||||
|
||||
# Get file path
|
||||
lora_path, trigger_words = get_lora_info(lora["name"])
|
||||
if not lora_path:
|
||||
logger.warning(
|
||||
f"[LoraRandomizerNode] Could not find path for LoRA: {lora['name']}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Normalize path separators
|
||||
lora_path = lora_path.replace("/", os.sep)
|
||||
|
||||
# Extract strengths
|
||||
model_strength = lora.get("strength", 1.0)
|
||||
clip_strength = lora.get("clipStrength", model_strength)
|
||||
|
||||
lora_stack.append((lora_path, model_strength, clip_strength))
|
||||
|
||||
# Return format: result for workflow + ui for frontend display
|
||||
return {"result": (lora_stack,), "ui": {"loras": result_loras}}
|
||||
logger.info(
|
||||
f"[LoraRandomizerNode] Final random LoRA count: {len(result_loras)}"
|
||||
)
|
||||
return result_loras
|
||||
|
||||
async def _apply_pool_filters(self, available_loras, pool_config, scanner):
|
||||
"""
|
||||
@@ -288,10 +347,3 @@ class LoraRandomizerNode:
|
||||
]
|
||||
|
||||
return available_loras
|
||||
|
||||
|
||||
# Node class mappings for ComfyUI
|
||||
NODE_CLASS_MAPPINGS = {"LoraRandomizerNode": LoraRandomizerNode}
|
||||
|
||||
# Display name mappings
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {"LoraRandomizerNode": "LoRA Randomizer"}
|
||||
|
||||
@@ -8,24 +8,27 @@ from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoraService(BaseModelService):
|
||||
"""LoRA-specific service implementation"""
|
||||
|
||||
|
||||
def __init__(self, scanner, update_service=None):
|
||||
"""Initialize LoRA service
|
||||
|
||||
|
||||
Args:
|
||||
scanner: LoRA scanner instance
|
||||
update_service: Optional service for remote update tracking.
|
||||
"""
|
||||
super().__init__("lora", scanner, LoraMetadata, update_service=update_service)
|
||||
|
||||
|
||||
async def format_response(self, lora_data: Dict) -> Dict:
|
||||
"""Format LoRA data for API response"""
|
||||
return {
|
||||
"model_name": lora_data["model_name"],
|
||||
"file_name": lora_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
|
||||
"preview_url": config.get_preview_static_url(
|
||||
lora_data.get("preview_url", "")
|
||||
),
|
||||
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||
"base_model": lora_data.get("base_model", ""),
|
||||
"folder": lora_data["folder"],
|
||||
@@ -40,141 +43,170 @@ class LoraService(BaseModelService):
|
||||
"notes": lora_data.get("notes", ""),
|
||||
"favorite": lora_data.get("favorite", False),
|
||||
"update_available": bool(lora_data.get("update_available", False)),
|
||||
"civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
|
||||
"civitai": self.filter_civitai_data(
|
||||
lora_data.get("civitai", {}), minimal=True
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply LoRA-specific filters"""
|
||||
# Handle first_letter filter for LoRAs
|
||||
first_letter = kwargs.get('first_letter')
|
||||
first_letter = kwargs.get("first_letter")
|
||||
if first_letter:
|
||||
data = self._filter_by_first_letter(data, first_letter)
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||
"""Filter data by first letter of model name
|
||||
|
||||
|
||||
Special handling:
|
||||
- '#': Numbers (0-9)
|
||||
- '@': Special characters (not alphanumeric)
|
||||
- '漢': CJK characters
|
||||
"""
|
||||
filtered_data = []
|
||||
|
||||
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
model_name = lora.get("model_name", "")
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if letter == '#' and first_char.isdigit():
|
||||
|
||||
if letter == "#" and first_char.isdigit():
|
||||
filtered_data.append(lora)
|
||||
elif letter == '@' and not first_char.isalnum():
|
||||
elif letter == "@" and not first_char.isalnum():
|
||||
# Special characters (not alphanumeric)
|
||||
filtered_data.append(lora)
|
||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||
elif letter == "漢" and self._is_cjk_character(first_char):
|
||||
# CJK characters
|
||||
filtered_data.append(lora)
|
||||
elif letter.upper() == first_char:
|
||||
# Regular alphabet matching
|
||||
filtered_data.append(lora)
|
||||
|
||||
|
||||
return filtered_data
|
||||
|
||||
|
||||
def _is_cjk_character(self, char: str) -> bool:
|
||||
"""Check if character is a CJK character"""
|
||||
# Define Unicode ranges for CJK characters
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
(0x3300, 0x33FF), # CJK Compatibility
|
||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||
(0x3100, 0x312F), # Bopomofo
|
||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
(0x1100, 0x11FF), # Hangul Jamo
|
||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
(0x3300, 0x33FF), # CJK Compatibility
|
||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||
(0x3100, 0x312F), # Bopomofo
|
||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
(0x1100, 0x11FF), # Hangul Jamo
|
||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||
]
|
||||
|
||||
|
||||
code_point = ord(char)
|
||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||
|
||||
|
||||
# LoRA-specific methods
|
||||
async def get_letter_counts(self) -> Dict[str, int]:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
data = cache.raw_data
|
||||
|
||||
|
||||
# Define letter categories
|
||||
letters = {
|
||||
'#': 0, # Numbers
|
||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||
'Y': 0, 'Z': 0,
|
||||
'@': 0, # Special characters
|
||||
'漢': 0 # CJK characters
|
||||
"#": 0, # Numbers
|
||||
"A": 0,
|
||||
"B": 0,
|
||||
"C": 0,
|
||||
"D": 0,
|
||||
"E": 0,
|
||||
"F": 0,
|
||||
"G": 0,
|
||||
"H": 0,
|
||||
"I": 0,
|
||||
"J": 0,
|
||||
"K": 0,
|
||||
"L": 0,
|
||||
"M": 0,
|
||||
"N": 0,
|
||||
"O": 0,
|
||||
"P": 0,
|
||||
"Q": 0,
|
||||
"R": 0,
|
||||
"S": 0,
|
||||
"T": 0,
|
||||
"U": 0,
|
||||
"V": 0,
|
||||
"W": 0,
|
||||
"X": 0,
|
||||
"Y": 0,
|
||||
"Z": 0,
|
||||
"@": 0, # Special characters
|
||||
"漢": 0, # CJK characters
|
||||
}
|
||||
|
||||
|
||||
# Count models for each letter
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
model_name = lora.get("model_name", "")
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
|
||||
if first_char.isdigit():
|
||||
letters['#'] += 1
|
||||
letters["#"] += 1
|
||||
elif first_char in letters:
|
||||
letters[first_char] += 1
|
||||
elif self._is_cjk_character(first_char):
|
||||
letters['漢'] += 1
|
||||
letters["漢"] += 1
|
||||
elif not first_char.isalnum():
|
||||
letters['@'] += 1
|
||||
|
||||
letters["@"] += 1
|
||||
|
||||
return letters
|
||||
|
||||
|
||||
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
return civitai_data.get('trainedWords', [])
|
||||
|
||||
if lora["file_name"] == lora_name:
|
||||
civitai_data = lora.get("civitai", {})
|
||||
return civitai_data.get("trainedWords", [])
|
||||
|
||||
return []
|
||||
|
||||
async def get_lora_usage_tips_by_relative_path(self, relative_path: str) -> Optional[str]:
|
||||
|
||||
async def get_lora_usage_tips_by_relative_path(
|
||||
self, relative_path: str
|
||||
) -> Optional[str]:
|
||||
"""Get usage tips for a LoRA by its relative path"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
|
||||
for lora in cache.raw_data:
|
||||
file_path = lora.get('file_path', '')
|
||||
file_path = lora.get("file_path", "")
|
||||
if file_path:
|
||||
# Convert to forward slashes and extract relative path
|
||||
file_path_normalized = file_path.replace('\\', '/')
|
||||
relative_path = relative_path.replace('\\', '/')
|
||||
file_path_normalized = file_path.replace("\\", "/")
|
||||
relative_path = relative_path.replace("\\", "/")
|
||||
# Find the relative path part by looking for the relative_path in the full path
|
||||
if file_path_normalized.endswith(relative_path) or relative_path in file_path_normalized:
|
||||
return lora.get('usage_tips', '')
|
||||
|
||||
if (
|
||||
file_path_normalized.endswith(relative_path)
|
||||
or relative_path in file_path_normalized
|
||||
):
|
||||
return lora.get("usage_tips", "")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
@@ -192,7 +224,7 @@ class LoraService(BaseModelService):
|
||||
clip_strength_min: float = 0.0,
|
||||
clip_strength_max: float = 1.0,
|
||||
locked_loras: Optional[List[Dict]] = None,
|
||||
pool_config: Optional[Dict] = None
|
||||
pool_config: Optional[Dict] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Get random LoRAs with specified strength ranges.
|
||||
@@ -235,10 +267,9 @@ class LoraService(BaseModelService):
|
||||
locked_loras = locked_loras[:count]
|
||||
|
||||
# Filter out locked LoRAs from available pool
|
||||
locked_names = {lora['name'] for lora in locked_loras}
|
||||
locked_names = {lora["name"] for lora in locked_loras}
|
||||
available_pool = [
|
||||
l for l in available_loras
|
||||
if l['model_name'] not in locked_names
|
||||
l for l in available_loras if l["file_name"] not in locked_names
|
||||
]
|
||||
|
||||
# Ensure we don't try to select more than available
|
||||
@@ -253,9 +284,7 @@ class LoraService(BaseModelService):
|
||||
# Generate random strengths for selected LoRAs
|
||||
result_loras = []
|
||||
for lora in selected:
|
||||
model_str = round(
|
||||
random.uniform(model_strength_min, model_strength_max), 2
|
||||
)
|
||||
model_str = round(random.uniform(model_strength_min, model_strength_max), 2)
|
||||
|
||||
if use_same_clip_strength:
|
||||
clip_str = model_str
|
||||
@@ -264,21 +293,25 @@ class LoraService(BaseModelService):
|
||||
random.uniform(clip_strength_min, clip_strength_max), 2
|
||||
)
|
||||
|
||||
result_loras.append({
|
||||
'name': lora['model_name'],
|
||||
'strength': model_str,
|
||||
'clipStrength': clip_str,
|
||||
'active': True,
|
||||
'expanded': abs(model_str - clip_str) > 0.001,
|
||||
'locked': False
|
||||
})
|
||||
result_loras.append(
|
||||
{
|
||||
"name": lora["file_name"],
|
||||
"strength": model_str,
|
||||
"clipStrength": clip_str,
|
||||
"active": True,
|
||||
"expanded": abs(model_str - clip_str) > 0.001,
|
||||
"locked": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Merge with locked LoRAs
|
||||
result_loras.extend(locked_loras)
|
||||
|
||||
return result_loras
|
||||
|
||||
async def _apply_pool_filters(self, available_loras: List[Dict], pool_config: Dict) -> List[Dict]:
|
||||
async def _apply_pool_filters(
|
||||
self, available_loras: List[Dict], pool_config: Dict
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Apply pool_config filters to available LoRAs.
|
||||
|
||||
@@ -292,26 +325,26 @@ class LoraService(BaseModelService):
|
||||
from .model_query import FilterCriteria
|
||||
|
||||
# Extract filter parameters from pool_config
|
||||
selected_base_models = pool_config.get('selected_base_models', [])
|
||||
include_tags = pool_config.get('include_tags', [])
|
||||
exclude_tags = pool_config.get('exclude_tags', [])
|
||||
include_folders = pool_config.get('include_folders', [])
|
||||
exclude_folders = pool_config.get('exclude_folders', [])
|
||||
no_credit_required = pool_config.get('no_credit_required', False)
|
||||
allow_selling = pool_config.get('allow_selling', False)
|
||||
selected_base_models = pool_config.get("selected_base_models", [])
|
||||
include_tags = pool_config.get("include_tags", [])
|
||||
exclude_tags = pool_config.get("exclude_tags", [])
|
||||
include_folders = pool_config.get("include_folders", [])
|
||||
exclude_folders = pool_config.get("exclude_folders", [])
|
||||
no_credit_required = pool_config.get("no_credit_required", False)
|
||||
allow_selling = pool_config.get("allow_selling", False)
|
||||
|
||||
# Build tag filters dict
|
||||
tag_filters = {}
|
||||
for tag in include_tags:
|
||||
tag_filters[tag] = 'include'
|
||||
tag_filters[tag] = "include"
|
||||
for tag in exclude_tags:
|
||||
tag_filters[tag] = 'exclude'
|
||||
tag_filters[tag] = "exclude"
|
||||
|
||||
# Build folder filter
|
||||
if include_folders or exclude_folders:
|
||||
filtered = []
|
||||
for lora in available_loras:
|
||||
folder = lora.get('folder', '')
|
||||
folder = lora.get("folder", "")
|
||||
|
||||
# Check exclude folders first
|
||||
excluded = False
|
||||
@@ -340,8 +373,9 @@ class LoraService(BaseModelService):
|
||||
# Apply base model filter
|
||||
if selected_base_models:
|
||||
available_loras = [
|
||||
lora for lora in available_loras
|
||||
if lora.get('base_model') in selected_base_models
|
||||
lora
|
||||
for lora in available_loras
|
||||
if lora.get("base_model") in selected_base_models
|
||||
]
|
||||
|
||||
# Apply tag filters
|
||||
@@ -352,14 +386,17 @@ class LoraService(BaseModelService):
|
||||
# Apply license filters
|
||||
if no_credit_required:
|
||||
available_loras = [
|
||||
lora for lora in available_loras
|
||||
if not lora.get('civitai', {}).get('allowNoCredit', True)
|
||||
lora
|
||||
for lora in available_loras
|
||||
if not lora.get("civitai", {}).get("allowNoCredit", True)
|
||||
]
|
||||
|
||||
if allow_selling:
|
||||
available_loras = [
|
||||
lora for lora in available_loras
|
||||
if lora.get('civitai', {}).get('allowCommercialUse', ['None'])[0] != 'None'
|
||||
lora
|
||||
for lora in available_loras
|
||||
if lora.get("civitai", {}).get("allowCommercialUse", ["None"])[0]
|
||||
!= "None"
|
||||
]
|
||||
|
||||
return available_loras
|
||||
|
||||
Reference in New Issue
Block a user