mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-26 15:38:52 -03:00
refactor: Simplify filtering methods and enhance CJK character handling in LoraService
This commit is contained in:
@@ -7,9 +7,6 @@ from ..utils.models import LoraMetadata
|
|||||||
from ..config import config
|
from ..config import config
|
||||||
from .model_scanner import ModelScanner
|
from .model_scanner import ModelScanner
|
||||||
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
||||||
from .settings_manager import settings
|
|
||||||
from ..utils.constants import NSFW_LEVELS
|
|
||||||
from ..utils.utils import fuzzy_match
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -115,260 +112,6 @@ class LoraScanner(ModelScanner):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing {file_path}: {e}")
|
logger.error(f"Error processing {file_path}: {e}")
|
||||||
|
|
||||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
|
||||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
|
||||||
base_models: list = None, tags: list = None,
|
|
||||||
search_options: dict = None, hash_filters: dict = None,
|
|
||||||
favorites_only: bool = False, first_letter: str = None) -> Dict:
|
|
||||||
"""Get paginated and filtered lora data
|
|
||||||
|
|
||||||
Args:
|
|
||||||
page: Current page number (1-based)
|
|
||||||
page_size: Number of items per page
|
|
||||||
sort_by: Sort method ('name' or 'date')
|
|
||||||
folder: Filter by folder path
|
|
||||||
search: Search term
|
|
||||||
fuzzy_search: Use fuzzy matching for search
|
|
||||||
base_models: List of base models to filter by
|
|
||||||
tags: List of tags to filter by
|
|
||||||
search_options: Dictionary with search options (filename, modelname, tags, recursive)
|
|
||||||
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
|
|
||||||
favorites_only: Filter for favorite models only
|
|
||||||
first_letter: Filter by first letter of model name
|
|
||||||
"""
|
|
||||||
cache = await self.get_cached_data()
|
|
||||||
|
|
||||||
# Get default search options if not provided
|
|
||||||
if search_options is None:
|
|
||||||
search_options = {
|
|
||||||
'filename': True,
|
|
||||||
'modelname': True,
|
|
||||||
'tags': False,
|
|
||||||
'recursive': False,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get the base data set
|
|
||||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
|
||||||
|
|
||||||
# Apply hash filtering if provided (highest priority)
|
|
||||||
if hash_filters:
|
|
||||||
single_hash = hash_filters.get('single_hash')
|
|
||||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
|
||||||
|
|
||||||
if single_hash:
|
|
||||||
# Filter by single hash
|
|
||||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora.get('sha256', '').lower() == single_hash
|
|
||||||
]
|
|
||||||
elif multiple_hashes:
|
|
||||||
# Filter by multiple hashes
|
|
||||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora.get('sha256', '').lower() in hash_set
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Jump to pagination
|
|
||||||
total_items = len(filtered_data)
|
|
||||||
start_idx = (page - 1) * page_size
|
|
||||||
end_idx = min(start_idx + page_size, total_items)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
'items': filtered_data[start_idx:end_idx],
|
|
||||||
'total': total_items,
|
|
||||||
'page': page,
|
|
||||||
'page_size': page_size,
|
|
||||||
'total_pages': (total_items + page_size - 1) // page_size
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
# Apply SFW filtering if enabled
|
|
||||||
if settings.get('show_only_sfw', False):
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply favorites filtering if enabled
|
|
||||||
if favorites_only:
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora.get('favorite', False) is True
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply first letter filtering
|
|
||||||
if first_letter:
|
|
||||||
filtered_data = self._filter_by_first_letter(filtered_data, first_letter)
|
|
||||||
|
|
||||||
# Apply folder filtering
|
|
||||||
if folder is not None:
|
|
||||||
if search_options.get('recursive', False):
|
|
||||||
# Recursive folder filtering - include all subfolders
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora['folder'].startswith(folder)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
# Exact folder filtering
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora['folder'] == folder
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply base model filtering
|
|
||||||
if base_models and len(base_models) > 0:
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if lora.get('base_model') in base_models
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply tag filtering
|
|
||||||
if tags and len(tags) > 0:
|
|
||||||
filtered_data = [
|
|
||||||
lora for lora in filtered_data
|
|
||||||
if any(tag in lora.get('tags', []) for tag in tags)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply search filtering
|
|
||||||
if search:
|
|
||||||
search_results = []
|
|
||||||
search_opts = search_options or {}
|
|
||||||
|
|
||||||
for lora in filtered_data:
|
|
||||||
# Search by file name
|
|
||||||
if search_opts.get('filename', True):
|
|
||||||
if fuzzy_match(lora.get('file_name', ''), search):
|
|
||||||
search_results.append(lora)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by model name
|
|
||||||
if search_opts.get('modelname', True):
|
|
||||||
if fuzzy_match(lora.get('model_name', ''), search):
|
|
||||||
search_results.append(lora)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by tags
|
|
||||||
if search_opts.get('tags', False) and 'tags' in lora:
|
|
||||||
if any(fuzzy_match(tag, search) for tag in lora['tags']):
|
|
||||||
search_results.append(lora)
|
|
||||||
continue
|
|
||||||
|
|
||||||
filtered_data = search_results
|
|
||||||
|
|
||||||
# Calculate pagination
|
|
||||||
total_items = len(filtered_data)
|
|
||||||
start_idx = (page - 1) * page_size
|
|
||||||
end_idx = min(start_idx + page_size, total_items)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
'items': filtered_data[start_idx:end_idx],
|
|
||||||
'total': total_items,
|
|
||||||
'page': page,
|
|
||||||
'page_size': page_size,
|
|
||||||
'total_pages': (total_items + page_size - 1) // page_size
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _filter_by_first_letter(self, data, letter):
|
|
||||||
"""Filter data by first letter of model name
|
|
||||||
|
|
||||||
Special handling:
|
|
||||||
- '#': Numbers (0-9)
|
|
||||||
- '@': Special characters (not alphanumeric)
|
|
||||||
- '漢': CJK characters
|
|
||||||
"""
|
|
||||||
filtered_data = []
|
|
||||||
|
|
||||||
for lora in data:
|
|
||||||
model_name = lora.get('model_name', '')
|
|
||||||
if not model_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
first_char = model_name[0].upper()
|
|
||||||
|
|
||||||
if letter == '#' and first_char.isdigit():
|
|
||||||
filtered_data.append(lora)
|
|
||||||
elif letter == '@' and not first_char.isalnum():
|
|
||||||
# Special characters (not alphanumeric)
|
|
||||||
filtered_data.append(lora)
|
|
||||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
|
||||||
# CJK characters
|
|
||||||
filtered_data.append(lora)
|
|
||||||
elif letter.upper() == first_char:
|
|
||||||
# Regular alphabet matching
|
|
||||||
filtered_data.append(lora)
|
|
||||||
|
|
||||||
return filtered_data
|
|
||||||
|
|
||||||
def _is_cjk_character(self, char):
|
|
||||||
"""Check if character is a CJK character"""
|
|
||||||
# Define Unicode ranges for CJK characters
|
|
||||||
cjk_ranges = [
|
|
||||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
|
||||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
|
||||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
|
||||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
|
||||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
|
||||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
|
||||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
|
||||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
|
||||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
|
||||||
(0x3300, 0x33FF), # CJK Compatibility
|
|
||||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
|
||||||
(0x3100, 0x312F), # Bopomofo
|
|
||||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
|
||||||
(0x3040, 0x309F), # Hiragana
|
|
||||||
(0x30A0, 0x30FF), # Katakana
|
|
||||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
|
||||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
|
||||||
(0x1100, 0x11FF), # Hangul Jamo
|
|
||||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
|
||||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
|
||||||
]
|
|
||||||
|
|
||||||
code_point = ord(char)
|
|
||||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
|
||||||
|
|
||||||
async def get_letter_counts(self):
|
|
||||||
"""Get count of models for each letter of the alphabet"""
|
|
||||||
cache = await self.get_cached_data()
|
|
||||||
data = cache.sorted_by_name
|
|
||||||
|
|
||||||
# Define letter categories
|
|
||||||
letters = {
|
|
||||||
'#': 0, # Numbers
|
|
||||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
|
||||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
|
||||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
|
||||||
'Y': 0, 'Z': 0,
|
|
||||||
'@': 0, # Special characters
|
|
||||||
'漢': 0 # CJK characters
|
|
||||||
}
|
|
||||||
|
|
||||||
# Count models for each letter
|
|
||||||
for lora in data:
|
|
||||||
model_name = lora.get('model_name', '')
|
|
||||||
if not model_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
first_char = model_name[0].upper()
|
|
||||||
|
|
||||||
if first_char.isdigit():
|
|
||||||
letters['#'] += 1
|
|
||||||
elif first_char in letters:
|
|
||||||
letters[first_char] += 1
|
|
||||||
elif self._is_cjk_character(first_char):
|
|
||||||
letters['漢'] += 1
|
|
||||||
elif not first_char.isalnum():
|
|
||||||
letters['@'] += 1
|
|
||||||
|
|
||||||
return letters
|
|
||||||
|
|
||||||
# Lora-specific hash index functionality
|
# Lora-specific hash index functionality
|
||||||
def has_lora_hash(self, sha256: str) -> bool:
|
def has_lora_hash(self, sha256: str) -> bool:
|
||||||
"""Check if a LoRA with given hash exists"""
|
"""Check if a LoRA with given hash exists"""
|
||||||
@@ -382,40 +125,6 @@ class LoraScanner(ModelScanner):
|
|||||||
"""Get hash for a LoRA by its file path"""
|
"""Get hash for a LoRA by its file path"""
|
||||||
return self.get_hash_by_path(file_path)
|
return self.get_hash_by_path(file_path)
|
||||||
|
|
||||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
|
||||||
"""Get top tags sorted by count"""
|
|
||||||
# Make sure cache is initialized
|
|
||||||
await self.get_cached_data()
|
|
||||||
|
|
||||||
# Sort tags by count in descending order
|
|
||||||
sorted_tags = sorted(
|
|
||||||
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
|
|
||||||
key=lambda x: x['count'],
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return limited number
|
|
||||||
return sorted_tags[:limit]
|
|
||||||
|
|
||||||
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
|
||||||
"""Get base models used in loras sorted by frequency"""
|
|
||||||
# Make sure cache is initialized
|
|
||||||
cache = await self.get_cached_data()
|
|
||||||
|
|
||||||
# Count base model occurrences
|
|
||||||
base_model_counts = {}
|
|
||||||
for lora in cache.raw_data:
|
|
||||||
if 'base_model' in lora and lora['base_model']:
|
|
||||||
base_model = lora['base_model']
|
|
||||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
|
||||||
|
|
||||||
# Sort base models by count
|
|
||||||
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
|
|
||||||
sorted_models.sort(key=lambda x: x['count'], reverse=True)
|
|
||||||
|
|
||||||
# Return limited number
|
|
||||||
return sorted_models[:limit]
|
|
||||||
|
|
||||||
async def diagnose_hash_index(self):
|
async def diagnose_hash_index(self):
|
||||||
"""Diagnostic method to verify hash index functionality"""
|
"""Diagnostic method to verify hash index functionality"""
|
||||||
print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr)
|
print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr)
|
||||||
|
|||||||
@@ -52,60 +52,100 @@ class LoraService(BaseModelService):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||||
"""Filter LoRAs by first letter"""
|
"""Filter data by first letter of model name
|
||||||
if letter == '#':
|
|
||||||
# Filter for non-alphabetic characters
|
Special handling:
|
||||||
return [
|
- '#': Numbers (0-9)
|
||||||
item for item in data
|
- '@': Special characters (not alphanumeric)
|
||||||
if not item.get('model_name', '')[0].isalpha()
|
- '漢': CJK characters
|
||||||
]
|
"""
|
||||||
elif letter == 'CJK':
|
filtered_data = []
|
||||||
# Filter for CJK characters
|
|
||||||
return [
|
for lora in data:
|
||||||
item for item in data
|
model_name = lora.get('model_name', '')
|
||||||
if item.get('model_name', '') and self._is_cjk_character(item['model_name'][0])
|
if not model_name:
|
||||||
]
|
continue
|
||||||
else:
|
|
||||||
# Filter for specific letter
|
first_char = model_name[0].upper()
|
||||||
return [
|
|
||||||
item for item in data
|
if letter == '#' and first_char.isdigit():
|
||||||
if item.get('model_name', '').lower().startswith(letter.lower())
|
filtered_data.append(lora)
|
||||||
]
|
elif letter == '@' and not first_char.isalnum():
|
||||||
|
# Special characters (not alphanumeric)
|
||||||
|
filtered_data.append(lora)
|
||||||
|
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||||
|
# CJK characters
|
||||||
|
filtered_data.append(lora)
|
||||||
|
elif letter.upper() == first_char:
|
||||||
|
# Regular alphabet matching
|
||||||
|
filtered_data.append(lora)
|
||||||
|
|
||||||
|
return filtered_data
|
||||||
|
|
||||||
def _is_cjk_character(self, char: str) -> bool:
|
def _is_cjk_character(self, char: str) -> bool:
|
||||||
"""Check if character is CJK (Chinese, Japanese, Korean)"""
|
"""Check if character is a CJK character"""
|
||||||
|
# Define Unicode ranges for CJK characters
|
||||||
cjk_ranges = [
|
cjk_ranges = [
|
||||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||||
(0x3400, 0x4DBF), # CJK Extension A
|
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||||
(0x20000, 0x2A6DF), # CJK Extension B
|
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||||
(0x2A700, 0x2B73F), # CJK Extension C
|
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||||
(0x2B740, 0x2B81F), # CJK Extension D
|
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||||
|
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||||
|
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||||
|
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||||
|
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||||
|
(0x3300, 0x33FF), # CJK Compatibility
|
||||||
|
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||||
|
(0x3100, 0x312F), # Bopomofo
|
||||||
|
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||||
(0x3040, 0x309F), # Hiragana
|
(0x3040, 0x309F), # Hiragana
|
||||||
(0x30A0, 0x30FF), # Katakana
|
(0x30A0, 0x30FF), # Katakana
|
||||||
|
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||||
|
(0x1100, 0x11FF), # Hangul Jamo
|
||||||
|
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||||
|
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||||
]
|
]
|
||||||
|
|
||||||
char_code = ord(char)
|
code_point = ord(char)
|
||||||
return any(start <= char_code <= end for start, end in cjk_ranges)
|
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||||
|
|
||||||
# LoRA-specific methods
|
# LoRA-specific methods
|
||||||
async def get_letter_counts(self) -> Dict[str, int]:
|
async def get_letter_counts(self) -> Dict[str, int]:
|
||||||
"""Get count of LoRAs for each letter of the alphabet"""
|
"""Get count of LoRAs for each letter of the alphabet"""
|
||||||
cache = await self.scanner.get_cached_data()
|
cache = await self.scanner.get_cached_data()
|
||||||
letter_counts = {}
|
data = cache.sorted_by_name
|
||||||
|
|
||||||
for lora in cache.raw_data:
|
# Define letter categories
|
||||||
|
letters = {
|
||||||
|
'#': 0, # Numbers
|
||||||
|
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||||
|
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||||
|
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||||
|
'Y': 0, 'Z': 0,
|
||||||
|
'@': 0, # Special characters
|
||||||
|
'漢': 0 # CJK characters
|
||||||
|
}
|
||||||
|
|
||||||
|
# Count models for each letter
|
||||||
|
for lora in data:
|
||||||
model_name = lora.get('model_name', '')
|
model_name = lora.get('model_name', '')
|
||||||
if model_name:
|
if not model_name:
|
||||||
first_char = model_name[0].upper()
|
continue
|
||||||
if first_char.isalpha():
|
|
||||||
letter_counts[first_char] = letter_counts.get(first_char, 0) + 1
|
|
||||||
elif self._is_cjk_character(first_char):
|
|
||||||
letter_counts['CJK'] = letter_counts.get('CJK', 0) + 1
|
|
||||||
else:
|
|
||||||
letter_counts['#'] = letter_counts.get('#', 0) + 1
|
|
||||||
|
|
||||||
return letter_counts
|
first_char = model_name[0].upper()
|
||||||
|
|
||||||
|
if first_char.isdigit():
|
||||||
|
letters['#'] += 1
|
||||||
|
elif first_char in letters:
|
||||||
|
letters[first_char] += 1
|
||||||
|
elif self._is_cjk_character(first_char):
|
||||||
|
letters['漢'] += 1
|
||||||
|
elif not first_char.isalnum():
|
||||||
|
letters['@'] += 1
|
||||||
|
|
||||||
|
return letters
|
||||||
|
|
||||||
async def get_lora_notes(self, lora_name: str) -> Optional[str]:
|
async def get_lora_notes(self, lora_name: str) -> Optional[str]:
|
||||||
"""Get notes for a specific LoRA file"""
|
"""Get notes for a specific LoRA file"""
|
||||||
|
|||||||
Reference in New Issue
Block a user