Files
ComfyUI-Lora-Manager/py/services/lora_service.py
Will Miao a2b81ea099 refactor: Implement base model routes and services for LoRA and Checkpoint
- Added BaseModelRoutes class to handle common routes and logic for model types.
- Created CheckpointRoutes class inheriting from BaseModelRoutes for checkpoint-specific routes.
- Implemented CheckpointService class for handling checkpoint-related data and operations.
- Developed LoraService class for managing LoRA-specific functionalities.
- Introduced ModelServiceFactory to manage service and route registrations for different model types.
- Established methods for fetching, filtering, and formatting model data across services.
- Integrated CivitAI metadata handling within model routes and services.
- Added pagination and filtering capabilities for model data retrieval.
2025-07-23 14:39:02 +08:00

172 lines
6.8 KiB
Python

import os
import logging
from typing import Dict, List, Optional
from .base_model_service import BaseModelService
from ..utils.models import LoraMetadata
from ..config import config
from ..utils.routes_common import ModelRouteUtils
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class LoraService(BaseModelService):
    """LoRA-specific service implementation.

    Adds LoRA-only behaviour on top of ``BaseModelService``: API response
    formatting, first-letter filtering and counting, and per-file lookups
    (notes, trigger words, preview URL, Civitai URL).
    """

    # Inclusive code point ranges treated as CJK when bucketing by first letter.
    _CJK_RANGES = (
        (0x4E00, 0x9FFF),    # CJK Unified Ideographs
        (0x3400, 0x4DBF),    # CJK Extension A
        (0x20000, 0x2A6DF),  # CJK Extension B
        (0x2A700, 0x2B73F),  # CJK Extension C
        (0x2B740, 0x2B81F),  # CJK Extension D
        (0x3040, 0x309F),    # Hiragana
        (0x30A0, 0x30FF),    # Katakana
        (0xAC00, 0xD7AF),    # Hangul Syllables
    )

    def __init__(self, scanner):
        """Initialize LoRA service.

        Args:
            scanner: LoRA scanner instance
        """
        super().__init__("lora", scanner, LoraMetadata)

    async def format_response(self, lora_data: Dict) -> Dict:
        """Format LoRA data for API response.

        Args:
            lora_data: Raw LoRA record. ``model_name``, ``file_name``,
                ``folder`` and ``file_path`` are required keys; every other
                field falls back to a default when absent.

        Returns:
            A new dict with the fields exposed by the API. ``file_path`` is
            normalized to forward slashes and ``preview_url`` is mapped
            through ``config.get_preview_static_url``.
        """
        return {
            "model_name": lora_data["model_name"],
            "file_name": lora_data["file_name"],
            "preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
            "preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
            "base_model": lora_data.get("base_model", ""),
            "folder": lora_data["folder"],
            "sha256": lora_data.get("sha256", ""),
            "file_path": lora_data["file_path"].replace(os.sep, "/"),
            # The stored record keeps the size under "size"; the API calls it "file_size".
            "file_size": lora_data.get("size", 0),
            "modified": lora_data.get("modified", ""),
            "tags": lora_data.get("tags", []),
            "modelDescription": lora_data.get("modelDescription", ""),
            "from_civitai": lora_data.get("from_civitai", True),
            "usage_tips": lora_data.get("usage_tips", ""),
            "notes": lora_data.get("notes", ""),
            "favorite": lora_data.get("favorite", False),
            "civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {})),
        }

    async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
        """Apply LoRA-specific filters.

        Args:
            data: LoRA records to filter.
            **kwargs: Filter options; only ``first_letter`` is read here.

        Returns:
            ``data`` unchanged when no ``first_letter`` is given, otherwise
            the records selected by ``_filter_by_first_letter``.
        """
        first_letter = kwargs.get('first_letter')
        if first_letter:
            data = self._filter_by_first_letter(data, first_letter)
        return data

    def _first_letter_bucket(self, model_name) -> Optional[str]:
        """Classify a model name by its first character.

        Returns:
            ``'CJK'`` for a CJK first character, the upper-cased character
            for any other alphabetic one, ``'#'`` for everything else, and
            ``None`` when the name is empty or missing.
        """
        if not model_name:
            return None
        first_char = model_name[0]
        # CJK must be tested before isalpha(): str.isalpha() is True for
        # ideographs, kana and Hangul, which would otherwise never reach
        # the CJK bucket.
        if self._is_cjk_character(first_char):
            return 'CJK'
        if first_char.isalpha():
            return first_char.upper()
        return '#'

    def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
        """Filter LoRAs by first letter.

        Args:
            data: LoRA records to filter.
            letter: ``'#'`` selects names whose first character is neither
                alphabetic nor CJK, ``'CJK'`` selects names starting with a
                CJK character, anything else is a case-insensitive prefix.

        Returns:
            A new list with the matching records. Records with an empty or
            missing ``model_name`` never match ``'#'`` or ``'CJK'``.
        """
        if letter in ('#', 'CJK'):
            # Use the same classification as get_letter_counts so the
            # counts and the filtered lists always agree.
            return [
                item for item in data
                if self._first_letter_bucket(item.get('model_name')) == letter
            ]

        prefix = letter.lower()
        return [
            item for item in data
            if (item.get('model_name') or '').lower().startswith(prefix)
        ]

    def _is_cjk_character(self, char: str) -> bool:
        """Check if character is CJK (Chinese, Japanese, Korean).

        Args:
            char: A single character.

        Returns:
            True when the code point lies in one of ``_CJK_RANGES``.
        """
        char_code = ord(char)
        return any(start <= char_code <= end for start, end in self._CJK_RANGES)

    # LoRA-specific methods

    async def get_letter_counts(self) -> Dict[str, int]:
        """Get count of LoRAs per first-letter bucket.

        Returns:
            Mapping from bucket (upper-cased letter, ``'CJK'`` or ``'#'``)
            to the number of cached LoRAs in it. Records without a model
            name are not counted.
        """
        cache = await self.scanner.get_cached_data()
        letter_counts: Dict[str, int] = {}
        for lora in cache.raw_data:
            bucket = self._first_letter_bucket(lora.get('model_name'))
            if bucket is not None:
                letter_counts[bucket] = letter_counts.get(bucket, 0) + 1
        return letter_counts

    async def _find_by_file_name(self, lora_name: str):
        """Return a lazy iterator over cached LoRAs whose file name matches.

        The iterator is lazy so callers that only need the first match stop
        scanning early, while callers that need a later duplicate can keep
        iterating.
        """
        cache = await self.scanner.get_cached_data()
        return (
            lora for lora in cache.raw_data
            if lora.get('file_name') == lora_name
        )

    async def get_lora_notes(self, lora_name: str) -> Optional[str]:
        """Get notes for a specific LoRA file.

        Returns:
            The notes of the first LoRA with that file name (``''`` when it
            has none), or ``None`` when no such LoRA is cached.
        """
        lora = next(await self._find_by_file_name(lora_name), None)
        if lora is None:
            return None
        return lora.get('notes', '')

    async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
        """Get trigger words for a specific LoRA file.

        Returns:
            The ``trainedWords`` list from the Civitai metadata of the first
            LoRA with that file name, or ``[]`` when the LoRA or the
            metadata is missing.
        """
        lora = next(await self._find_by_file_name(lora_name), None)
        if lora is None:
            return []
        # "civitai" may be present but None; treat that like missing metadata.
        civitai_data = lora.get('civitai') or {}
        return civitai_data.get('trainedWords') or []

    async def get_lora_preview_url(self, lora_name: str) -> Optional[str]:
        """Get the static preview URL for a LoRA file.

        Returns:
            The static URL for the first matching LoRA that has a preview,
            or ``None`` when there is none.
        """
        for lora in await self._find_by_file_name(lora_name):
            preview_url = lora.get('preview_url')
            if preview_url:
                return config.get_preview_static_url(preview_url)
        return None

    async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]:
        """Get the Civitai URL for a LoRA file.

        Returns:
            A dict with ``civitai_url``, ``model_id`` and ``version_id``,
            taken from the first matching LoRA that has a Civitai model id.
            All three values are ``None`` when no such LoRA exists.
        """
        for lora in await self._find_by_file_name(lora_name):
            # "civitai" may be present but None; treat that like missing metadata.
            civitai_data = lora.get('civitai') or {}
            model_id = civitai_data.get('modelId')
            version_id = civitai_data.get('id')
            if not model_id:
                continue

            civitai_url = f"https://civitai.com/models/{model_id}"
            if version_id:
                civitai_url += f"?modelVersionId={version_id}"
            return {
                'civitai_url': civitai_url,
                'model_id': str(model_id),
                'version_id': str(version_id) if version_id else None,
            }
        return {'civitai_url': None, 'model_id': None, 'version_id': None}

    def find_duplicate_hashes(self) -> Dict:
        """Find LoRAs with duplicate SHA256 hashes.

        Delegates to the scanner's hash index.
        """
        return self.scanner._hash_index.get_duplicate_hashes()

    def find_duplicate_filenames(self) -> Dict:
        """Find LoRAs with conflicting filenames.

        Delegates to the scanner's hash index.
        """
        return self.scanner._hash_index.get_duplicate_filenames()