feat(randomizer): add LoRA locking and roll modes

- Implement LoRA locking to prevent specific LoRAs from being changed during randomization
- Add visual styling for locked state with amber accents and distinct backgrounds
- Introduce `roll_mode` configuration with 'backend' (execute current selection while generating new) and 'frontend' (execute newly generated selection) behaviors
- Move LoraPoolNode to 'Lora Manager/randomizer' category and remove standalone class mappings
- Standardize RETURN_NAMES in LoraRandomizerNode for consistency
This commit is contained in:
Will Miao
2026-01-12 21:53:47 +08:00
parent 177b20263d
commit bce6b0e610
13 changed files with 706 additions and 232 deletions

View File

@@ -8,24 +8,27 @@ from ..config import config
logger = logging.getLogger(__name__)
class LoraService(BaseModelService):
"""LoRA-specific service implementation"""
def __init__(self, scanner, update_service=None):
"""Initialize LoRA service
Args:
scanner: LoRA scanner instance
update_service: Optional service for remote update tracking.
"""
super().__init__("lora", scanner, LoraMetadata, update_service=update_service)
async def format_response(self, lora_data: Dict) -> Dict:
"""Format LoRA data for API response"""
return {
"model_name": lora_data["model_name"],
"file_name": lora_data["file_name"],
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
"preview_url": config.get_preview_static_url(
lora_data.get("preview_url", "")
),
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
"base_model": lora_data.get("base_model", ""),
"folder": lora_data["folder"],
@@ -40,141 +43,170 @@ class LoraService(BaseModelService):
"notes": lora_data.get("notes", ""),
"favorite": lora_data.get("favorite", False),
"update_available": bool(lora_data.get("update_available", False)),
"civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
"civitai": self.filter_civitai_data(
lora_data.get("civitai", {}), minimal=True
),
}
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply LoRA-specific filters"""
# Handle first_letter filter for LoRAs
first_letter = kwargs.get('first_letter')
first_letter = kwargs.get("first_letter")
if first_letter:
data = self._filter_by_first_letter(data, first_letter)
return data
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
"""Filter data by first letter of model name
Special handling:
- '#': Numbers (0-9)
- '@': Special characters (not alphanumeric)
- '': CJK characters
"""
filtered_data = []
for lora in data:
model_name = lora.get('model_name', '')
model_name = lora.get("model_name", "")
if not model_name:
continue
first_char = model_name[0].upper()
if letter == '#' and first_char.isdigit():
if letter == "#" and first_char.isdigit():
filtered_data.append(lora)
elif letter == '@' and not first_char.isalnum():
elif letter == "@" and not first_char.isalnum():
# Special characters (not alphanumeric)
filtered_data.append(lora)
elif letter == '' and self._is_cjk_character(first_char):
elif letter == "" and self._is_cjk_character(first_char):
# CJK characters
filtered_data.append(lora)
elif letter.upper() == first_char:
# Regular alphabet matching
filtered_data.append(lora)
return filtered_data
def _is_cjk_character(self, char: str) -> bool:
"""Check if character is a CJK character"""
# Define Unicode ranges for CJK characters
cjk_ranges = [
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
(0x3300, 0x33FF), # CJK Compatibility
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
(0x3100, 0x312F), # Bopomofo
(0x31A0, 0x31BF), # Bopomofo Extended
(0x3040, 0x309F), # Hiragana
(0x30A0, 0x30FF), # Katakana
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
(0xAC00, 0xD7AF), # Hangul Syllables
(0x1100, 0x11FF), # Hangul Jamo
(0xA960, 0xA97F), # Hangul Jamo Extended-A
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
(0x3300, 0x33FF), # CJK Compatibility
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
(0x3100, 0x312F), # Bopomofo
(0x31A0, 0x31BF), # Bopomofo Extended
(0x3040, 0x309F), # Hiragana
(0x30A0, 0x30FF), # Katakana
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
(0xAC00, 0xD7AF), # Hangul Syllables
(0x1100, 0x11FF), # Hangul Jamo
(0xA960, 0xA97F), # Hangul Jamo Extended-A
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
]
code_point = ord(char)
return any(start <= code_point <= end for start, end in cjk_ranges)
# LoRA-specific methods
async def get_letter_counts(self) -> Dict[str, int]:
"""Get count of LoRAs for each letter of the alphabet"""
cache = await self.scanner.get_cached_data()
data = cache.raw_data
# Define letter categories
letters = {
'#': 0, # Numbers
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
'Y': 0, 'Z': 0,
'@': 0, # Special characters
'': 0 # CJK characters
"#": 0, # Numbers
"A": 0,
"B": 0,
"C": 0,
"D": 0,
"E": 0,
"F": 0,
"G": 0,
"H": 0,
"I": 0,
"J": 0,
"K": 0,
"L": 0,
"M": 0,
"N": 0,
"O": 0,
"P": 0,
"Q": 0,
"R": 0,
"S": 0,
"T": 0,
"U": 0,
"V": 0,
"W": 0,
"X": 0,
"Y": 0,
"Z": 0,
"@": 0, # Special characters
"": 0, # CJK characters
}
# Count models for each letter
for lora in data:
model_name = lora.get('model_name', '')
model_name = lora.get("model_name", "")
if not model_name:
continue
first_char = model_name[0].upper()
if first_char.isdigit():
letters['#'] += 1
letters["#"] += 1
elif first_char in letters:
letters[first_char] += 1
elif self._is_cjk_character(first_char):
letters[''] += 1
letters[""] += 1
elif not first_char.isalnum():
letters['@'] += 1
letters["@"] += 1
return letters
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
"""Get trigger words for a specific LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
civitai_data = lora.get('civitai', {})
return civitai_data.get('trainedWords', [])
if lora["file_name"] == lora_name:
civitai_data = lora.get("civitai", {})
return civitai_data.get("trainedWords", [])
return []
async def get_lora_usage_tips_by_relative_path(self, relative_path: str) -> Optional[str]:
async def get_lora_usage_tips_by_relative_path(
self, relative_path: str
) -> Optional[str]:
"""Get usage tips for a LoRA by its relative path"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
file_path = lora.get('file_path', '')
file_path = lora.get("file_path", "")
if file_path:
# Convert to forward slashes and extract relative path
file_path_normalized = file_path.replace('\\', '/')
relative_path = relative_path.replace('\\', '/')
file_path_normalized = file_path.replace("\\", "/")
relative_path = relative_path.replace("\\", "/")
# Find the relative path part by looking for the relative_path in the full path
if file_path_normalized.endswith(relative_path) or relative_path in file_path_normalized:
return lora.get('usage_tips', '')
if (
file_path_normalized.endswith(relative_path)
or relative_path in file_path_normalized
):
return lora.get("usage_tips", "")
return None
def find_duplicate_hashes(self) -> Dict:
"""Find LoRAs with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes()
@@ -192,7 +224,7 @@ class LoraService(BaseModelService):
clip_strength_min: float = 0.0,
clip_strength_max: float = 1.0,
locked_loras: Optional[List[Dict]] = None,
pool_config: Optional[Dict] = None
pool_config: Optional[Dict] = None,
) -> List[Dict]:
"""
Get random LoRAs with specified strength ranges.
@@ -235,10 +267,9 @@ class LoraService(BaseModelService):
locked_loras = locked_loras[:count]
# Filter out locked LoRAs from available pool
locked_names = {lora['name'] for lora in locked_loras}
locked_names = {lora["name"] for lora in locked_loras}
available_pool = [
l for l in available_loras
if l['model_name'] not in locked_names
l for l in available_loras if l["file_name"] not in locked_names
]
# Ensure we don't try to select more than available
@@ -253,9 +284,7 @@ class LoraService(BaseModelService):
# Generate random strengths for selected LoRAs
result_loras = []
for lora in selected:
model_str = round(
random.uniform(model_strength_min, model_strength_max), 2
)
model_str = round(random.uniform(model_strength_min, model_strength_max), 2)
if use_same_clip_strength:
clip_str = model_str
@@ -264,21 +293,25 @@ class LoraService(BaseModelService):
random.uniform(clip_strength_min, clip_strength_max), 2
)
result_loras.append({
'name': lora['model_name'],
'strength': model_str,
'clipStrength': clip_str,
'active': True,
'expanded': abs(model_str - clip_str) > 0.001,
'locked': False
})
result_loras.append(
{
"name": lora["file_name"],
"strength": model_str,
"clipStrength": clip_str,
"active": True,
"expanded": abs(model_str - clip_str) > 0.001,
"locked": False,
}
)
# Merge with locked LoRAs
result_loras.extend(locked_loras)
return result_loras
async def _apply_pool_filters(self, available_loras: List[Dict], pool_config: Dict) -> List[Dict]:
async def _apply_pool_filters(
self, available_loras: List[Dict], pool_config: Dict
) -> List[Dict]:
"""
Apply pool_config filters to available LoRAs.
@@ -292,26 +325,26 @@ class LoraService(BaseModelService):
from .model_query import FilterCriteria
# Extract filter parameters from pool_config
selected_base_models = pool_config.get('selected_base_models', [])
include_tags = pool_config.get('include_tags', [])
exclude_tags = pool_config.get('exclude_tags', [])
include_folders = pool_config.get('include_folders', [])
exclude_folders = pool_config.get('exclude_folders', [])
no_credit_required = pool_config.get('no_credit_required', False)
allow_selling = pool_config.get('allow_selling', False)
selected_base_models = pool_config.get("selected_base_models", [])
include_tags = pool_config.get("include_tags", [])
exclude_tags = pool_config.get("exclude_tags", [])
include_folders = pool_config.get("include_folders", [])
exclude_folders = pool_config.get("exclude_folders", [])
no_credit_required = pool_config.get("no_credit_required", False)
allow_selling = pool_config.get("allow_selling", False)
# Build tag filters dict
tag_filters = {}
for tag in include_tags:
tag_filters[tag] = 'include'
tag_filters[tag] = "include"
for tag in exclude_tags:
tag_filters[tag] = 'exclude'
tag_filters[tag] = "exclude"
# Build folder filter
if include_folders or exclude_folders:
filtered = []
for lora in available_loras:
folder = lora.get('folder', '')
folder = lora.get("folder", "")
# Check exclude folders first
excluded = False
@@ -340,8 +373,9 @@ class LoraService(BaseModelService):
# Apply base model filter
if selected_base_models:
available_loras = [
lora for lora in available_loras
if lora.get('base_model') in selected_base_models
lora
for lora in available_loras
if lora.get("base_model") in selected_base_models
]
# Apply tag filters
@@ -352,14 +386,17 @@ class LoraService(BaseModelService):
# Apply license filters
if no_credit_required:
available_loras = [
lora for lora in available_loras
if not lora.get('civitai', {}).get('allowNoCredit', True)
lora
for lora in available_loras
if not lora.get("civitai", {}).get("allowNoCredit", True)
]
if allow_selling:
available_loras = [
lora for lora in available_loras
if lora.get('civitai', {}).get('allowCommercialUse', ['None'])[0] != 'None'
lora
for lora in available_loras
if lora.get("civitai", {}).get("allowCommercialUse", ["None"])[0]
!= "None"
]
return available_loras