mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
- Added BaseModelRoutes class to handle common routes and logic for model types. - Created CheckpointRoutes class inheriting from BaseModelRoutes for checkpoint-specific routes. - Implemented CheckpointService class for handling checkpoint-related data and operations. - Developed LoraService class for managing LoRA-specific functionalities. - Introduced ModelServiceFactory to manage service and route registrations for different model types. - Established methods for fetching, filtering, and formatting model data across services. - Integrated CivitAI metadata handling within model routes and services. - Added pagination and filtering capabilities for model data retrieval.
172 lines
6.8 KiB
Python
172 lines
6.8 KiB
Python
import os
|
|
import logging
|
|
from typing import Dict, List, Optional
|
|
|
|
from .base_model_service import BaseModelService
|
|
from ..utils.models import LoraMetadata
|
|
from ..config import config
|
|
from ..utils.routes_common import ModelRouteUtils
|
|
|
|
# Module-level logger named after this module (standard logging pattern).
logger = logging.getLogger(__name__)
|
|
|
|
class LoraService(BaseModelService):
    """LoRA-specific service implementation.

    Extends ``BaseModelService`` with LoRA-only response formatting,
    first-letter filtering/counting, and per-file lookups (notes, trigger
    words, preview URL, Civitai URL) over the scanner's cached data.
    """

    # Unicode code-point ranges treated as CJK (Chinese, Japanese, Korean)
    # when bucketing names by first character. Kept at class level so the
    # table is built once instead of on every character check.
    _CJK_RANGES = (
        (0x4E00, 0x9FFF),    # CJK Unified Ideographs
        (0x3400, 0x4DBF),    # CJK Extension A
        (0x20000, 0x2A6DF),  # CJK Extension B
        (0x2A700, 0x2B73F),  # CJK Extension C
        (0x2B740, 0x2B81F),  # CJK Extension D
        (0x3040, 0x309F),    # Hiragana
        (0x30A0, 0x30FF),    # Katakana
        (0xAC00, 0xD7AF),    # Hangul Syllables
    )

    def __init__(self, scanner):
        """Initialize LoRA service.

        Args:
            scanner: LoRA scanner instance. This class calls its
                ``get_cached_data()`` coroutine and, for the duplicate
                lookups, its ``_hash_index`` attribute.
        """
        super().__init__("lora", scanner, LoraMetadata)

    async def format_response(self, lora_data: Dict) -> Dict:
        """Format one cached LoRA entry for an API response.

        Args:
            lora_data: Cached LoRA record. ``model_name``, ``file_name``,
                ``folder`` and ``file_path`` are required (``KeyError`` if
                absent); every other field falls back to a default.

        Returns:
            A new dict with the public response fields. ``file_path`` is
            normalized to forward slashes and the cached ``size`` field is
            exposed as ``file_size``.
        """
        return {
            "model_name": lora_data["model_name"],
            "file_name": lora_data["file_name"],
            "preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
            "preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
            "base_model": lora_data.get("base_model", ""),
            "folder": lora_data["folder"],
            "sha256": lora_data.get("sha256", ""),
            # Use forward slashes regardless of the host OS separator.
            "file_path": lora_data["file_path"].replace(os.sep, "/"),
            "file_size": lora_data.get("size", 0),
            "modified": lora_data.get("modified", ""),
            "tags": lora_data.get("tags", []),
            "modelDescription": lora_data.get("modelDescription", ""),
            # Entries without the flag are reported as coming from Civitai.
            "from_civitai": lora_data.get("from_civitai", True),
            "usage_tips": lora_data.get("usage_tips", ""),
            "notes": lora_data.get("notes", ""),
            "favorite": lora_data.get("favorite", False),
            "civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}))
        }

    async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
        """Apply LoRA-specific filters.

        Currently only ``first_letter`` is supported; when it is falsy the
        data is returned unchanged.
        """
        first_letter = kwargs.get('first_letter')
        if first_letter:
            data = self._filter_by_first_letter(data, first_letter)

        return data

    def _letter_bucket(self, model_name: Optional[str]) -> Optional[str]:
        """Return the first-letter bucket for a model name.

        Returns ``'CJK'`` when the name starts with a CJK character, the
        upper-cased first character for other alphabetic starts, ``'#'``
        for any other start, and ``None`` for an empty or missing name.

        The CJK test must precede ``str.isalpha()``: ideographs, kana and
        Hangul are all alphabetic to Python, so checking ``isalpha()``
        first would never place anything in the ``'CJK'`` bucket.
        """
        if not model_name:
            return None
        first_char = model_name[0]
        if self._is_cjk_character(first_char):
            return 'CJK'
        if first_char.isalpha():
            return first_char.upper()
        return '#'

    def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
        """Filter LoRAs by the first letter of their model name.

        Args:
            data: Cached LoRA records.
            letter: ``'#'`` (non-alphabetic, non-CJK start), ``'CJK'``, or a
                case-insensitive prefix such as ``'A'``.

        Returns:
            The matching records, in their original order. Records with an
            empty or missing ``model_name`` only match an empty prefix.
        """
        if letter in ('#', 'CJK'):
            # Share the bucketing used by get_letter_counts so the filter
            # returns exactly the items that were counted for that bucket.
            return [
                item for item in data
                if self._letter_bucket(item.get('model_name')) == letter
            ]

        prefix = letter.lower()
        return [
            item for item in data
            # `or ''` also covers a key that is present but set to None.
            if (item.get('model_name') or '').lower().startswith(prefix)
        ]

    def _is_cjk_character(self, char: str) -> bool:
        """Return True if *char* falls in one of the ``_CJK_RANGES`` blocks.

        An empty string returns False instead of letting ``ord`` raise.
        """
        if not char:
            return False
        code_point = ord(char)
        return any(start <= code_point <= end for start, end in self._CJK_RANGES)

    # LoRA-specific methods

    async def get_letter_counts(self) -> Dict[str, int]:
        """Count cached LoRAs per first-letter bucket.

        Keys are upper-cased letters, ``'CJK'`` or ``'#'`` (see
        ``_letter_bucket``); entries with an empty name are not counted.
        """
        cache = await self.scanner.get_cached_data()
        letter_counts: Dict[str, int] = {}

        for lora in cache.raw_data:
            bucket = self._letter_bucket(lora.get('model_name'))
            if bucket is not None:
                letter_counts[bucket] = letter_counts.get(bucket, 0) + 1

        return letter_counts

    async def _loras_with_file_name(self, lora_name: str) -> List[Dict]:
        """Return every cached LoRA whose ``file_name`` equals *lora_name*, in cache order."""
        cache = await self.scanner.get_cached_data()
        return [lora for lora in cache.raw_data if lora.get('file_name') == lora_name]

    async def get_lora_notes(self, lora_name: str) -> Optional[str]:
        """Get notes for a LoRA file.

        Returns the first matching entry's notes ('' if it has none), or
        ``None`` when no cached LoRA has that file name.
        """
        matches = await self._loras_with_file_name(lora_name)
        if not matches:
            return None
        return matches[0].get('notes', '')

    async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
        """Get the Civitai ``trainedWords`` of the first LoRA with this file name.

        Returns an empty list when the file is unknown or has no Civitai
        trigger words (including a stored ``None``).
        """
        matches = await self._loras_with_file_name(lora_name)
        if not matches:
            return []
        civitai_data = matches[0].get('civitai') or {}
        return civitai_data.get('trainedWords') or []

    async def get_lora_preview_url(self, lora_name: str) -> Optional[str]:
        """Get the static preview URL for a LoRA file.

        If several cached entries share the file name, the first one that
        has a non-empty ``preview_url`` is used. Returns ``None`` otherwise.
        """
        for lora in await self._loras_with_file_name(lora_name):
            preview_url = lora.get('preview_url')
            if preview_url:
                return config.get_preview_static_url(preview_url)

        return None

    async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]:
        """Get the Civitai model page URL for a LoRA file.

        Entries sharing the file name are scanned in cache order, and the
        first one with a Civitai ``modelId`` wins. The URL carries a
        ``modelVersionId`` query parameter when a version ``id`` is known.

        Returns:
            ``{'civitai_url', 'model_id', 'version_id'}``; all three are
            ``None`` when no match has a model id.
        """
        for lora in await self._loras_with_file_name(lora_name):
            # `or {}` also covers a 'civitai' key stored as None.
            civitai_data = lora.get('civitai') or {}
            model_id = civitai_data.get('modelId')
            if not model_id:
                continue

            version_id = civitai_data.get('id')
            civitai_url = f"https://civitai.com/models/{model_id}"
            if version_id:
                civitai_url += f"?modelVersionId={version_id}"

            return {
                'civitai_url': civitai_url,
                'model_id': str(model_id),
                'version_id': str(version_id) if version_id else None
            }

        return {'civitai_url': None, 'model_id': None, 'version_id': None}

    def find_duplicate_hashes(self) -> Dict:
        """Find LoRAs with duplicate SHA256 hashes.

        Delegates to the scanner's private ``_hash_index``; the result shape
        is whatever ``get_duplicate_hashes()`` returns.
        """
        return self.scanner._hash_index.get_duplicate_hashes()

    def find_duplicate_filenames(self) -> Dict:
        """Find LoRAs with conflicting filenames.

        Delegates to the scanner's private ``_hash_index``; the result shape
        is whatever ``get_duplicate_filenames()`` returns.
        """
        return self.scanner._hash_index.get_duplicate_filenames()