mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 14:12:11 -03:00
refactor: Implement base model routes and services for LoRA and Checkpoint
- Added BaseModelRoutes class to handle common routes and logic for model types. - Created CheckpointRoutes class inheriting from BaseModelRoutes for checkpoint-specific routes. - Implemented CheckpointService class for handling checkpoint-related data and operations. - Developed LoraService class for managing LoRA-specific functionalities. - Introduced ModelServiceFactory to manage service and route registrations for different model types. - Established methods for fetching, filtering, and formatting model data across services. - Integrated CivitAI metadata handling within model routes and services. - Added pagination and filtering capabilities for model data retrieval.
This commit is contained in:
248
py/services/base_model_service.py
Normal file
248
py/services/base_model_service.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Type, Set
|
||||
import logging
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from .settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseModelService(ABC):
|
||||
"""Base service class for all model types"""
|
||||
|
||||
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
model_type: Type of model (lora, checkpoint, etc.)
|
||||
scanner: Model scanner instance
|
||||
metadata_class: Metadata class for this model type
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.scanner = scanner
|
||||
self.metadata_class = metadata_class
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||
base_models: list = None, tags: list = None,
|
||||
search_options: dict = None, hash_filters: dict = None,
|
||||
favorites_only: bool = False, **kwargs) -> Dict:
|
||||
"""Get paginated and filtered model data
|
||||
|
||||
Args:
|
||||
page: Page number (1-based)
|
||||
page_size: Number of items per page
|
||||
sort_by: Sort criteria ('name' or 'date')
|
||||
folder: Folder filter
|
||||
search: Search term
|
||||
fuzzy_search: Whether to use fuzzy search
|
||||
base_models: List of base models to filter by
|
||||
tags: List of tags to filter by
|
||||
search_options: Search options dict
|
||||
hash_filters: Hash filtering options
|
||||
favorites_only: Filter for favorites only
|
||||
**kwargs: Additional model-specific filters
|
||||
|
||||
Returns:
|
||||
Dict containing paginated results
|
||||
"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
|
||||
|
||||
# Jump to pagination for hash filters
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
# Apply common filters
|
||||
filtered_data = await self._apply_common_filters(
|
||||
filtered_data, folder, base_models, tags, favorites_only, search_options
|
||||
)
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
filtered_data = await self._apply_search_filters(
|
||||
filtered_data, search, fuzzy_search, search_options
|
||||
)
|
||||
|
||||
# Apply model-specific filters
|
||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||
"""Apply hash-based filtering"""
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower()
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes)
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
|
||||
base_models: list = None, tags: list = None,
|
||||
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
|
||||
"""Apply common filters that work across all model types"""
|
||||
# Apply SFW filtering if enabled in settings
|
||||
if settings.get('show_only_sfw', False):
|
||||
data = [
|
||||
item for item in data
|
||||
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
data = [
|
||||
item for item in data
|
||||
if item.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options and search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
data = [
|
||||
item for item in data
|
||||
if item['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
data = [
|
||||
item for item in data
|
||||
if item['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
data = [
|
||||
item for item in data
|
||||
if item.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
data = [
|
||||
item for item in data
|
||||
if any(tag in item.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_search_filters(self, data: List[Dict], search: str,
|
||||
fuzzy_search: bool, search_options: dict) -> List[Dict]:
|
||||
"""Apply search filtering"""
|
||||
search_results = []
|
||||
|
||||
for item in data:
|
||||
# Search by file name
|
||||
if search_options.get('filename', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(item.get('file_name', ''), search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in item.get('file_name', '').lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_options.get('modelname', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(item.get('model_name', ''), search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in item.get('model_name', '').lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_options.get('tags', False) and 'tags' in item:
|
||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
|
||||
for tag in item['tags']):
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
return search_results
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||
return data
|
||||
|
||||
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||
"""Apply pagination to filtered data"""
|
||||
total_items = len(data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
return {
|
||||
'items': data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def format_response(self, model_data: Dict) -> Dict:
|
||||
"""Format model data for API response - must be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
# Common service methods that delegate to scanner
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get top tags sorted by frequency"""
|
||||
return await self.scanner.get_top_tags(limit)
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get base models sorted by frequency"""
|
||||
return await self.scanner.get_base_models(limit)
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if a model with given hash exists"""
|
||||
return self.scanner.has_hash(sha256)
|
||||
|
||||
def get_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a model by its hash"""
|
||||
return self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a model by its file path"""
|
||||
return self.scanner.get_hash_by_path(file_path)
|
||||
|
||||
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
|
||||
"""Trigger model scanning"""
|
||||
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
|
||||
|
||||
async def get_model_info_by_name(self, name: str):
|
||||
"""Get model information by name"""
|
||||
return await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get model root directories"""
|
||||
return self.scanner.get_model_roots()
|
||||
@@ -1,14 +1,12 @@
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional, Set
|
||||
import folder_paths # type: ignore
|
||||
from typing import List, Dict
|
||||
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -109,4 +107,31 @@ class CheckpointScanner(ModelScanner):
|
||||
if result:
|
||||
checkpoints.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
# Checkpoint-specific hash index functionality
|
||||
def has_checkpoint_hash(self, sha256: str) -> bool:
|
||||
"""Check if a checkpoint with given hash exists"""
|
||||
return self.has_hash(sha256)
|
||||
|
||||
def get_checkpoint_path_by_hash(self, sha256: str) -> str:
|
||||
"""Get file path for a checkpoint by its hash"""
|
||||
return self.get_path_by_hash(sha256)
|
||||
|
||||
def get_checkpoint_hash_by_path(self, file_path: str) -> str:
|
||||
"""Get hash for a checkpoint by its file path"""
|
||||
return self.get_hash_by_path(file_path)
|
||||
|
||||
async def get_checkpoint_info_by_name(self, name):
|
||||
"""Get checkpoint information by name"""
|
||||
try:
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
for checkpoint in cache.raw_data:
|
||||
if checkpoint.get("file_name") == name:
|
||||
return checkpoint
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint info by name: {e}", exc_info=True)
|
||||
return None
|
||||
51
py/services/checkpoint_service.py
Normal file
51
py/services/checkpoint_service.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointService(BaseModelService):
|
||||
"""Checkpoint-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize Checkpoint service
|
||||
|
||||
Args:
|
||||
scanner: Checkpoint scanner instance
|
||||
"""
|
||||
super().__init__("checkpoint", scanner, CheckpointMetadata)
|
||||
|
||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||
"""Format Checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint_data["model_name"],
|
||||
"file_name": checkpoint_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint_data.get("base_model", ""),
|
||||
"folder": checkpoint_data["folder"],
|
||||
"sha256": checkpoint_data.get("sha256", ""),
|
||||
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint_data.get("size", 0),
|
||||
"modified": checkpoint_data.get("modified", ""),
|
||||
"tags": checkpoint_data.get("tags", []),
|
||||
"modelDescription": checkpoint_data.get("modelDescription", ""),
|
||||
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||
"notes": checkpoint_data.get("notes", ""),
|
||||
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}))
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Checkpoints with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Checkpoints with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
172
py/services/lora_service.py
Normal file
172
py/services/lora_service.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraService(BaseModelService):
|
||||
"""LoRA-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize LoRA service
|
||||
|
||||
Args:
|
||||
scanner: LoRA scanner instance
|
||||
"""
|
||||
super().__init__("lora", scanner, LoraMetadata)
|
||||
|
||||
async def format_response(self, lora_data: Dict) -> Dict:
|
||||
"""Format LoRA data for API response"""
|
||||
return {
|
||||
"model_name": lora_data["model_name"],
|
||||
"file_name": lora_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||
"base_model": lora_data.get("base_model", ""),
|
||||
"folder": lora_data["folder"],
|
||||
"sha256": lora_data.get("sha256", ""),
|
||||
"file_path": lora_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": lora_data.get("size", 0),
|
||||
"modified": lora_data.get("modified", ""),
|
||||
"tags": lora_data.get("tags", []),
|
||||
"modelDescription": lora_data.get("modelDescription", ""),
|
||||
"from_civitai": lora_data.get("from_civitai", True),
|
||||
"usage_tips": lora_data.get("usage_tips", ""),
|
||||
"notes": lora_data.get("notes", ""),
|
||||
"favorite": lora_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}))
|
||||
}
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply LoRA-specific filters"""
|
||||
# Handle first_letter filter for LoRAs
|
||||
first_letter = kwargs.get('first_letter')
|
||||
if first_letter:
|
||||
data = self._filter_by_first_letter(data, first_letter)
|
||||
|
||||
return data
|
||||
|
||||
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||
"""Filter LoRAs by first letter"""
|
||||
if letter == '#':
|
||||
# Filter for non-alphabetic characters
|
||||
return [
|
||||
item for item in data
|
||||
if not item.get('model_name', '')[0].isalpha()
|
||||
]
|
||||
elif letter == 'CJK':
|
||||
# Filter for CJK characters
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('model_name', '') and self._is_cjk_character(item['model_name'][0])
|
||||
]
|
||||
else:
|
||||
# Filter for specific letter
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('model_name', '').lower().startswith(letter.lower())
|
||||
]
|
||||
|
||||
def _is_cjk_character(self, char: str) -> bool:
|
||||
"""Check if character is CJK (Chinese, Japanese, Korean)"""
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Extension D
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
]
|
||||
|
||||
char_code = ord(char)
|
||||
return any(start <= char_code <= end for start, end in cjk_ranges)
|
||||
|
||||
# LoRA-specific methods
|
||||
async def get_letter_counts(self) -> Dict[str, int]:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
letter_counts = {}
|
||||
|
||||
for lora in cache.raw_data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if model_name:
|
||||
first_char = model_name[0].upper()
|
||||
if first_char.isalpha():
|
||||
letter_counts[first_char] = letter_counts.get(first_char, 0) + 1
|
||||
elif self._is_cjk_character(first_char):
|
||||
letter_counts['CJK'] = letter_counts.get('CJK', 0) + 1
|
||||
else:
|
||||
letter_counts['#'] = letter_counts.get('#', 0) + 1
|
||||
|
||||
return letter_counts
|
||||
|
||||
async def get_lora_notes(self, lora_name: str) -> Optional[str]:
|
||||
"""Get notes for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
return lora.get('notes', '')
|
||||
|
||||
return None
|
||||
|
||||
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
return civitai_data.get('trainedWords', [])
|
||||
|
||||
return []
|
||||
|
||||
async def get_lora_preview_url(self, lora_name: str) -> Optional[str]:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
preview_url = lora.get('preview_url')
|
||||
if preview_url:
|
||||
return config.get_preview_static_url(preview_url)
|
||||
|
||||
return None
|
||||
|
||||
async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
model_id = civitai_data.get('modelId')
|
||||
version_id = civitai_data.get('id')
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
|
||||
return {
|
||||
'civitai_url': civitai_url,
|
||||
'model_id': str(model_id),
|
||||
'version_id': str(version_id) if version_id else None
|
||||
}
|
||||
|
||||
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find LoRAs with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
137
py/services/model_service_factory.py
Normal file
137
py/services/model_service_factory.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from typing import Dict, Type, Any
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelServiceFactory:
|
||||
"""Factory for managing model services and routes"""
|
||||
|
||||
_services: Dict[str, Type] = {}
|
||||
_routes: Dict[str, Type] = {}
|
||||
_initialized_services: Dict[str, Any] = {}
|
||||
_initialized_routes: Dict[str, Any] = {}
|
||||
|
||||
@classmethod
|
||||
def register_model_type(cls, model_type: str, service_class: Type, route_class: Type):
|
||||
"""Register a new model type with its service and route classes
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier (e.g., 'lora', 'checkpoint')
|
||||
service_class: The service class for this model type
|
||||
route_class: The route class for this model type
|
||||
"""
|
||||
cls._services[model_type] = service_class
|
||||
cls._routes[model_type] = route_class
|
||||
logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}")
|
||||
|
||||
@classmethod
|
||||
def get_service_class(cls, model_type: str) -> Type:
|
||||
"""Get service class for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The service class for the model type
|
||||
|
||||
Raises:
|
||||
ValueError: If model type is not registered
|
||||
"""
|
||||
if model_type not in cls._services:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
return cls._services[model_type]
|
||||
|
||||
@classmethod
|
||||
def get_route_class(cls, model_type: str) -> Type:
|
||||
"""Get route class for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The route class for the model type
|
||||
|
||||
Raises:
|
||||
ValueError: If model type is not registered
|
||||
"""
|
||||
if model_type not in cls._routes:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
return cls._routes[model_type]
|
||||
|
||||
@classmethod
|
||||
def get_route_instance(cls, model_type: str):
|
||||
"""Get or create route instance for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The route instance for the model type
|
||||
"""
|
||||
if model_type not in cls._initialized_routes:
|
||||
route_class = cls.get_route_class(model_type)
|
||||
cls._initialized_routes[model_type] = route_class()
|
||||
return cls._initialized_routes[model_type]
|
||||
|
||||
@classmethod
|
||||
def setup_all_routes(cls, app):
|
||||
"""Setup routes for all registered model types
|
||||
|
||||
Args:
|
||||
app: The aiohttp application instance
|
||||
"""
|
||||
logger.info(f"Setting up routes for {len(cls._services)} registered model types")
|
||||
|
||||
for model_type in cls._services.keys():
|
||||
try:
|
||||
routes_instance = cls.get_route_instance(model_type)
|
||||
routes_instance.setup_routes(app)
|
||||
logger.info(f"Successfully set up routes for {model_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
def get_registered_types(cls) -> list:
|
||||
"""Get list of all registered model types
|
||||
|
||||
Returns:
|
||||
List of registered model type identifiers
|
||||
"""
|
||||
return list(cls._services.keys())
|
||||
|
||||
@classmethod
|
||||
def is_registered(cls, model_type: str) -> bool:
|
||||
"""Check if a model type is registered
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
True if the model type is registered, False otherwise
|
||||
"""
|
||||
return model_type in cls._services
|
||||
|
||||
@classmethod
|
||||
def clear_registrations(cls):
|
||||
"""Clear all registrations - mainly for testing purposes"""
|
||||
cls._services.clear()
|
||||
cls._routes.clear()
|
||||
cls._initialized_services.clear()
|
||||
cls._initialized_routes.clear()
|
||||
logger.info("Cleared all model type registrations")
|
||||
|
||||
|
||||
def register_default_model_types():
|
||||
"""Register the default model types (LoRA and Checkpoint)"""
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..routes.lora_routes import LoraRoutes
|
||||
from ..routes.checkpoint_routes import CheckpointRoutes
|
||||
|
||||
# Register LoRA model type
|
||||
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
|
||||
|
||||
# Register Checkpoint model type
|
||||
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
|
||||
|
||||
logger.info("Registered default model types: lora, checkpoint")
|
||||
@@ -7,106 +7,176 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar('T') # Define a type variable for service types
|
||||
|
||||
class ServiceRegistry:
|
||||
"""Centralized registry for service singletons"""
|
||||
"""Central registry for managing singleton services"""
|
||||
|
||||
_instance = None
|
||||
_services: Dict[str, Any] = {}
|
||||
_lock = asyncio.Lock()
|
||||
_locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""Get singleton instance of the registry"""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
async def register_service(cls, name: str, service: Any) -> None:
|
||||
"""Register a service instance with the registry
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
service: Service instance to register
|
||||
"""
|
||||
cls._services[name] = service
|
||||
logger.debug(f"Registered service: {name}")
|
||||
|
||||
@classmethod
|
||||
async def register_service(cls, service_name: str, service_instance: Any) -> None:
|
||||
"""Register a service instance with the registry"""
|
||||
registry = cls.get_instance()
|
||||
async with cls._lock:
|
||||
registry._services[service_name] = service_instance
|
||||
logger.debug(f"Registered service: {service_name}")
|
||||
async def get_service(cls, name: str) -> Optional[Any]:
|
||||
"""Get a service instance by name
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
|
||||
Returns:
|
||||
Service instance or None if not found
|
||||
"""
|
||||
return cls._services.get(name)
|
||||
|
||||
@classmethod
|
||||
async def get_service(cls, service_name: str) -> Any:
|
||||
"""Get a service instance by name"""
|
||||
registry = cls.get_instance()
|
||||
async with cls._lock:
|
||||
if service_name not in registry._services:
|
||||
logger.debug(f"Service {service_name} not found in registry")
|
||||
return None
|
||||
return registry._services[service_name]
|
||||
def _get_lock(cls, name: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a service
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
|
||||
Returns:
|
||||
AsyncIO lock for the service
|
||||
"""
|
||||
if name not in cls._locks:
|
||||
cls._locks[name] = asyncio.Lock()
|
||||
return cls._locks[name]
|
||||
|
||||
@classmethod
|
||||
def get_service_sync(cls, service_name: str) -> Any:
|
||||
"""Get a service instance by name (synchronous version)"""
|
||||
registry = cls.get_instance()
|
||||
if service_name not in registry._services:
|
||||
logger.debug(f"Service {service_name} not found in registry")
|
||||
return None
|
||||
return registry._services[service_name]
|
||||
|
||||
# Convenience methods for common services
|
||||
@classmethod
|
||||
async def get_lora_scanner(cls):
|
||||
"""Get the LoraScanner instance"""
|
||||
from .lora_scanner import LoraScanner
|
||||
scanner = await cls.get_service("lora_scanner")
|
||||
if scanner is None:
|
||||
scanner = await LoraScanner.get_instance()
|
||||
await cls.register_service("lora_scanner", scanner)
|
||||
return scanner
|
||||
"""Get or create LoRA scanner instance"""
|
||||
service_name = "lora_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .lora_scanner import LoraScanner
|
||||
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.info(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_checkpoint_scanner(cls):
|
||||
"""Get the CheckpointScanner instance"""
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
scanner = await cls.get_service("checkpoint_scanner")
|
||||
if scanner is None:
|
||||
"""Get or create Checkpoint scanner instance"""
|
||||
service_name = "checkpoint_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
await cls.register_service("checkpoint_scanner", scanner)
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_civitai_client(cls):
|
||||
"""Get the CivitaiClient instance"""
|
||||
from .civitai_client import CivitaiClient
|
||||
client = await cls.get_service("civitai_client")
|
||||
if client is None:
|
||||
client = await CivitaiClient.get_instance()
|
||||
await cls.register_service("civitai_client", client)
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def get_download_manager(cls):
|
||||
"""Get the DownloadManager instance"""
|
||||
from .download_manager import DownloadManager
|
||||
manager = await cls.get_service("download_manager")
|
||||
if manager is None:
|
||||
manager = await DownloadManager.get_instance()
|
||||
await cls.register_service("download_manager", manager)
|
||||
return manager
|
||||
|
||||
cls._services[service_name] = scanner
|
||||
logger.info(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_recipe_scanner(cls):
|
||||
"""Get the RecipeScanner instance"""
|
||||
from .recipe_scanner import RecipeScanner
|
||||
scanner = await cls.get_service("recipe_scanner")
|
||||
if scanner is None:
|
||||
lora_scanner = await cls.get_lora_scanner()
|
||||
scanner = RecipeScanner(lora_scanner)
|
||||
await cls.register_service("recipe_scanner", scanner)
|
||||
return scanner
|
||||
|
||||
"""Get or create Recipe scanner instance"""
|
||||
service_name = "recipe_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .recipe_scanner import RecipeScanner
|
||||
|
||||
scanner = await RecipeScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.info(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_civitai_client(cls):
|
||||
"""Get or create CivitAI client instance"""
|
||||
service_name = "civitai_client"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .civitai_client import CivitaiClient
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
cls._services[service_name] = client
|
||||
logger.info(f"Created and registered {service_name}")
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def get_download_manager(cls):
|
||||
"""Get or create Download manager instance"""
|
||||
service_name = "download_manager"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .download_manager import DownloadManager
|
||||
|
||||
manager = DownloadManager()
|
||||
cls._services[service_name] = manager
|
||||
logger.info(f"Created and registered {service_name}")
|
||||
return manager
|
||||
|
||||
@classmethod
|
||||
async def get_websocket_manager(cls):
|
||||
"""Get the WebSocketManager instance"""
|
||||
from .websocket_manager import ws_manager
|
||||
manager = await cls.get_service("websocket_manager")
|
||||
if manager is None:
|
||||
# ws_manager is already a global instance in websocket_manager.py
|
||||
"""Get or create WebSocket manager instance"""
|
||||
service_name = "websocket_manager"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .websocket_manager import ws_manager
|
||||
await cls.register_service("websocket_manager", ws_manager)
|
||||
manager = ws_manager
|
||||
return manager
|
||||
|
||||
cls._services[service_name] = ws_manager
|
||||
logger.info(f"Registered {service_name}")
|
||||
return ws_manager
|
||||
|
||||
@classmethod
|
||||
def clear_services(cls):
|
||||
"""Clear all registered services - mainly for testing"""
|
||||
cls._services.clear()
|
||||
cls._locks.clear()
|
||||
logger.info("Cleared all registered services")
|
||||
Reference in New Issue
Block a user