refactor: Implement base model routes and services for LoRA and Checkpoint

- Added BaseModelRoutes class to handle common routes and logic for model types.
- Created CheckpointRoutes class inheriting from BaseModelRoutes for checkpoint-specific routes.
- Implemented CheckpointService class for handling checkpoint-related data and operations.
- Developed LoraService class for managing LoRA-specific functionalities.
- Introduced ModelServiceFactory to manage service and route registrations for different model types.
- Established methods for fetching, filtering, and formatting model data across services.
- Integrated CivitAI metadata handling within model routes and services.
- Added pagination and filtering capabilities for model data retrieval.
This commit is contained in:
Will Miao
2025-07-23 14:39:02 +08:00
parent ee609e8eac
commit a2b81ea099
15 changed files with 2185 additions and 1385 deletions

View File

@@ -0,0 +1,248 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Type, Set
import logging
from ..utils.models import BaseModelMetadata
from ..utils.constants import NSFW_LEVELS
from .settings_manager import settings
from ..utils.utils import fuzzy_match
logger = logging.getLogger(__name__)
class BaseModelService(ABC):
"""Base service class for all model types"""
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
"""Initialize the service
Args:
model_type: Type of model (lora, checkpoint, etc.)
scanner: Model scanner instance
metadata_class: Metadata class for this model type
"""
self.model_type = model_type
self.scanner = scanner
self.metadata_class = metadata_class
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy_search: bool = False,
base_models: list = None, tags: list = None,
search_options: dict = None, hash_filters: dict = None,
favorites_only: bool = False, **kwargs) -> Dict:
"""Get paginated and filtered model data
Args:
page: Page number (1-based)
page_size: Number of items per page
sort_by: Sort criteria ('name' or 'date')
folder: Folder filter
search: Search term
fuzzy_search: Whether to use fuzzy search
base_models: List of base models to filter by
tags: List of tags to filter by
search_options: Search options dict
hash_filters: Hash filtering options
favorites_only: Filter for favorites only
**kwargs: Additional model-specific filters
Returns:
Dict containing paginated results
"""
cache = await self.scanner.get_cached_data()
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False,
'recursive': False,
}
# Get the base data set
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply hash filtering if provided (highest priority)
if hash_filters:
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
# Jump to pagination for hash filters
return self._paginate(filtered_data, page, page_size)
# Apply common filters
filtered_data = await self._apply_common_filters(
filtered_data, folder, base_models, tags, favorites_only, search_options
)
# Apply search filtering
if search:
filtered_data = await self._apply_search_filters(
filtered_data, search, fuzzy_search, search_options
)
# Apply model-specific filters
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
return self._paginate(filtered_data, page, page_size)
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
"""Apply hash-based filtering"""
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower()
return [
item for item in data
if item.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes)
return [
item for item in data
if item.get('sha256', '').lower() in hash_set
]
return data
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
base_models: list = None, tags: list = None,
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
"""Apply common filters that work across all model types"""
# Apply SFW filtering if enabled in settings
if settings.get('show_only_sfw', False):
data = [
item for item in data
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply favorites filtering if enabled
if favorites_only:
data = [
item for item in data
if item.get('favorite', False) is True
]
# Apply folder filtering
if folder is not None:
if search_options and search_options.get('recursive', False):
# Recursive folder filtering - include all subfolders
data = [
item for item in data
if item['folder'].startswith(folder)
]
else:
# Exact folder filtering
data = [
item for item in data
if item['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
data = [
item for item in data
if item.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
data = [
item for item in data
if any(tag in item.get('tags', []) for tag in tags)
]
return data
async def _apply_search_filters(self, data: List[Dict], search: str,
fuzzy_search: bool, search_options: dict) -> List[Dict]:
"""Apply search filtering"""
search_results = []
for item in data:
# Search by file name
if search_options.get('filename', True):
if fuzzy_search:
if fuzzy_match(item.get('file_name', ''), search):
search_results.append(item)
continue
elif search.lower() in item.get('file_name', '').lower():
search_results.append(item)
continue
# Search by model name
if search_options.get('modelname', True):
if fuzzy_search:
if fuzzy_match(item.get('model_name', ''), search):
search_results.append(item)
continue
elif search.lower() in item.get('model_name', '').lower():
search_results.append(item)
continue
# Search by tags
if search_options.get('tags', False) and 'tags' in item:
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
for tag in item['tags']):
search_results.append(item)
continue
return search_results
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply model-specific filters - to be overridden by subclasses if needed"""
return data
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
"""Apply pagination to filtered data"""
total_items = len(data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
return {
'items': data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
@abstractmethod
async def format_response(self, model_data: Dict) -> Dict:
"""Format model data for API response - must be implemented by subclasses"""
pass
# Common service methods that delegate to scanner
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
"""Get top tags sorted by frequency"""
return await self.scanner.get_top_tags(limit)
async def get_base_models(self, limit: int = 20) -> List[Dict]:
"""Get base models sorted by frequency"""
return await self.scanner.get_base_models(limit)
def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists"""
return self.scanner.has_hash(sha256)
def get_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a model by its hash"""
return self.scanner.get_path_by_hash(sha256)
def get_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a model by its file path"""
return self.scanner.get_hash_by_path(file_path)
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
"""Trigger model scanning"""
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
async def get_model_info_by_name(self, name: str):
"""Get model information by name"""
return await self.scanner.get_model_info_by_name(name)
def get_model_roots(self) -> List[str]:
"""Get model root directories"""
return self.scanner.get_model_roots()

View File

@@ -1,14 +1,12 @@
import os
import logging
import asyncio
from typing import List, Dict, Optional, Set
import folder_paths # type: ignore
from typing import List, Dict
from ..utils.models import CheckpointMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
@@ -109,4 +107,31 @@ class CheckpointScanner(ModelScanner):
if result:
checkpoints.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
logger.error(f"Error processing {file_path}: {e}")
# Checkpoint-specific hash index functionality
def has_checkpoint_hash(self, sha256: str) -> bool:
"""Check if a checkpoint with given hash exists"""
return self.has_hash(sha256)
def get_checkpoint_path_by_hash(self, sha256: str) -> str:
"""Get file path for a checkpoint by its hash"""
return self.get_path_by_hash(sha256)
def get_checkpoint_hash_by_path(self, file_path: str) -> str:
"""Get hash for a checkpoint by its file path"""
return self.get_hash_by_path(file_path)
async def get_checkpoint_info_by_name(self, name):
"""Get checkpoint information by name"""
try:
cache = await self.get_cached_data()
for checkpoint in cache.raw_data:
if checkpoint.get("file_name") == name:
return checkpoint
return None
except Exception as e:
logger.error(f"Error getting checkpoint info by name: {e}", exc_info=True)
return None

View File

@@ -0,0 +1,51 @@
import os
import logging
from typing import Dict, List, Optional
from .base_model_service import BaseModelService
from ..utils.models import CheckpointMetadata
from ..config import config
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__)
class CheckpointService(BaseModelService):
"""Checkpoint-specific service implementation"""
def __init__(self, scanner):
"""Initialize Checkpoint service
Args:
scanner: Checkpoint scanner instance
"""
super().__init__("checkpoint", scanner, CheckpointMetadata)
async def format_response(self, checkpoint_data: Dict) -> Dict:
"""Format Checkpoint data for API response"""
return {
"model_name": checkpoint_data["model_name"],
"file_name": checkpoint_data["file_name"],
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
"base_model": checkpoint_data.get("base_model", ""),
"folder": checkpoint_data["folder"],
"sha256": checkpoint_data.get("sha256", ""),
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
"file_size": checkpoint_data.get("size", 0),
"modified": checkpoint_data.get("modified", ""),
"tags": checkpoint_data.get("tags", []),
"modelDescription": checkpoint_data.get("modelDescription", ""),
"from_civitai": checkpoint_data.get("from_civitai", True),
"notes": checkpoint_data.get("notes", ""),
"model_type": checkpoint_data.get("model_type", "checkpoint"),
"favorite": checkpoint_data.get("favorite", False),
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}))
}
def find_duplicate_hashes(self) -> Dict:
"""Find Checkpoints with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes()
def find_duplicate_filenames(self) -> Dict:
"""Find Checkpoints with conflicting filenames"""
return self.scanner._hash_index.get_duplicate_filenames()

172
py/services/lora_service.py Normal file
View File

@@ -0,0 +1,172 @@
import os
import logging
from typing import Dict, List, Optional
from .base_model_service import BaseModelService
from ..utils.models import LoraMetadata
from ..config import config
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__)
class LoraService(BaseModelService):
"""LoRA-specific service implementation"""
def __init__(self, scanner):
"""Initialize LoRA service
Args:
scanner: LoRA scanner instance
"""
super().__init__("lora", scanner, LoraMetadata)
async def format_response(self, lora_data: Dict) -> Dict:
"""Format LoRA data for API response"""
return {
"model_name": lora_data["model_name"],
"file_name": lora_data["file_name"],
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
"base_model": lora_data.get("base_model", ""),
"folder": lora_data["folder"],
"sha256": lora_data.get("sha256", ""),
"file_path": lora_data["file_path"].replace(os.sep, "/"),
"file_size": lora_data.get("size", 0),
"modified": lora_data.get("modified", ""),
"tags": lora_data.get("tags", []),
"modelDescription": lora_data.get("modelDescription", ""),
"from_civitai": lora_data.get("from_civitai", True),
"usage_tips": lora_data.get("usage_tips", ""),
"notes": lora_data.get("notes", ""),
"favorite": lora_data.get("favorite", False),
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}))
}
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply LoRA-specific filters"""
# Handle first_letter filter for LoRAs
first_letter = kwargs.get('first_letter')
if first_letter:
data = self._filter_by_first_letter(data, first_letter)
return data
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
"""Filter LoRAs by first letter"""
if letter == '#':
# Filter for non-alphabetic characters
return [
item for item in data
if not item.get('model_name', '')[0].isalpha()
]
elif letter == 'CJK':
# Filter for CJK characters
return [
item for item in data
if item.get('model_name', '') and self._is_cjk_character(item['model_name'][0])
]
else:
# Filter for specific letter
return [
item for item in data
if item.get('model_name', '').lower().startswith(letter.lower())
]
def _is_cjk_character(self, char: str) -> bool:
"""Check if character is CJK (Chinese, Japanese, Korean)"""
cjk_ranges = [
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DBF), # CJK Extension A
(0x20000, 0x2A6DF), # CJK Extension B
(0x2A700, 0x2B73F), # CJK Extension C
(0x2B740, 0x2B81F), # CJK Extension D
(0x3040, 0x309F), # Hiragana
(0x30A0, 0x30FF), # Katakana
(0xAC00, 0xD7AF), # Hangul Syllables
]
char_code = ord(char)
return any(start <= char_code <= end for start, end in cjk_ranges)
# LoRA-specific methods
async def get_letter_counts(self) -> Dict[str, int]:
"""Get count of LoRAs for each letter of the alphabet"""
cache = await self.scanner.get_cached_data()
letter_counts = {}
for lora in cache.raw_data:
model_name = lora.get('model_name', '')
if model_name:
first_char = model_name[0].upper()
if first_char.isalpha():
letter_counts[first_char] = letter_counts.get(first_char, 0) + 1
elif self._is_cjk_character(first_char):
letter_counts['CJK'] = letter_counts.get('CJK', 0) + 1
else:
letter_counts['#'] = letter_counts.get('#', 0) + 1
return letter_counts
async def get_lora_notes(self, lora_name: str) -> Optional[str]:
"""Get notes for a specific LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
return lora.get('notes', '')
return None
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
"""Get trigger words for a specific LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
civitai_data = lora.get('civitai', {})
return civitai_data.get('trainedWords', [])
return []
async def get_lora_preview_url(self, lora_name: str) -> Optional[str]:
"""Get the static preview URL for a LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
preview_url = lora.get('preview_url')
if preview_url:
return config.get_preview_static_url(preview_url)
return None
async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]:
"""Get the Civitai URL for a LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
civitai_data = lora.get('civitai', {})
model_id = civitai_data.get('modelId')
version_id = civitai_data.get('id')
if model_id:
civitai_url = f"https://civitai.com/models/{model_id}"
if version_id:
civitai_url += f"?modelVersionId={version_id}"
return {
'civitai_url': civitai_url,
'model_id': str(model_id),
'version_id': str(version_id) if version_id else None
}
return {'civitai_url': None, 'model_id': None, 'version_id': None}
def find_duplicate_hashes(self) -> Dict:
"""Find LoRAs with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes()
def find_duplicate_filenames(self) -> Dict:
"""Find LoRAs with conflicting filenames"""
return self.scanner._hash_index.get_duplicate_filenames()

View File

@@ -0,0 +1,137 @@
from typing import Dict, Type, Any
import logging
logger = logging.getLogger(__name__)
class ModelServiceFactory:
"""Factory for managing model services and routes"""
_services: Dict[str, Type] = {}
_routes: Dict[str, Type] = {}
_initialized_services: Dict[str, Any] = {}
_initialized_routes: Dict[str, Any] = {}
@classmethod
def register_model_type(cls, model_type: str, service_class: Type, route_class: Type):
"""Register a new model type with its service and route classes
Args:
model_type: The model type identifier (e.g., 'lora', 'checkpoint')
service_class: The service class for this model type
route_class: The route class for this model type
"""
cls._services[model_type] = service_class
cls._routes[model_type] = route_class
logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}")
@classmethod
def get_service_class(cls, model_type: str) -> Type:
"""Get service class for a model type
Args:
model_type: The model type identifier
Returns:
The service class for the model type
Raises:
ValueError: If model type is not registered
"""
if model_type not in cls._services:
raise ValueError(f"Unknown model type: {model_type}")
return cls._services[model_type]
@classmethod
def get_route_class(cls, model_type: str) -> Type:
"""Get route class for a model type
Args:
model_type: The model type identifier
Returns:
The route class for the model type
Raises:
ValueError: If model type is not registered
"""
if model_type not in cls._routes:
raise ValueError(f"Unknown model type: {model_type}")
return cls._routes[model_type]
@classmethod
def get_route_instance(cls, model_type: str):
"""Get or create route instance for a model type
Args:
model_type: The model type identifier
Returns:
The route instance for the model type
"""
if model_type not in cls._initialized_routes:
route_class = cls.get_route_class(model_type)
cls._initialized_routes[model_type] = route_class()
return cls._initialized_routes[model_type]
@classmethod
def setup_all_routes(cls, app):
"""Setup routes for all registered model types
Args:
app: The aiohttp application instance
"""
logger.info(f"Setting up routes for {len(cls._services)} registered model types")
for model_type in cls._services.keys():
try:
routes_instance = cls.get_route_instance(model_type)
routes_instance.setup_routes(app)
logger.info(f"Successfully set up routes for {model_type}")
except Exception as e:
logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True)
@classmethod
def get_registered_types(cls) -> list:
"""Get list of all registered model types
Returns:
List of registered model type identifiers
"""
return list(cls._services.keys())
@classmethod
def is_registered(cls, model_type: str) -> bool:
"""Check if a model type is registered
Args:
model_type: The model type identifier
Returns:
True if the model type is registered, False otherwise
"""
return model_type in cls._services
@classmethod
def clear_registrations(cls):
"""Clear all registrations - mainly for testing purposes"""
cls._services.clear()
cls._routes.clear()
cls._initialized_services.clear()
cls._initialized_routes.clear()
logger.info("Cleared all model type registrations")
def register_default_model_types():
"""Register the default model types (LoRA and Checkpoint)"""
from ..services.lora_service import LoraService
from ..services.checkpoint_service import CheckpointService
from ..routes.lora_routes import LoraRoutes
from ..routes.checkpoint_routes import CheckpointRoutes
# Register LoRA model type
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
# Register Checkpoint model type
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
logger.info("Registered default model types: lora, checkpoint")

View File

@@ -7,106 +7,176 @@ logger = logging.getLogger(__name__)
T = TypeVar('T') # Define a type variable for service types
class ServiceRegistry:
"""Centralized registry for service singletons"""
"""Central registry for managing singleton services"""
_instance = None
_services: Dict[str, Any] = {}
_lock = asyncio.Lock()
_locks: Dict[str, asyncio.Lock] = {}
@classmethod
def get_instance(cls):
"""Get singleton instance of the registry"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def register_service(cls, name: str, service: Any) -> None:
"""Register a service instance with the registry
Args:
name: Service name identifier
service: Service instance to register
"""
cls._services[name] = service
logger.debug(f"Registered service: {name}")
@classmethod
async def register_service(cls, service_name: str, service_instance: Any) -> None:
"""Register a service instance with the registry"""
registry = cls.get_instance()
async with cls._lock:
registry._services[service_name] = service_instance
logger.debug(f"Registered service: {service_name}")
async def get_service(cls, name: str) -> Optional[Any]:
"""Get a service instance by name
Args:
name: Service name identifier
Returns:
Service instance or None if not found
"""
return cls._services.get(name)
@classmethod
async def get_service(cls, service_name: str) -> Any:
"""Get a service instance by name"""
registry = cls.get_instance()
async with cls._lock:
if service_name not in registry._services:
logger.debug(f"Service {service_name} not found in registry")
return None
return registry._services[service_name]
def _get_lock(cls, name: str) -> asyncio.Lock:
"""Get or create a lock for a service
Args:
name: Service name identifier
Returns:
AsyncIO lock for the service
"""
if name not in cls._locks:
cls._locks[name] = asyncio.Lock()
return cls._locks[name]
@classmethod
def get_service_sync(cls, service_name: str) -> Any:
"""Get a service instance by name (synchronous version)"""
registry = cls.get_instance()
if service_name not in registry._services:
logger.debug(f"Service {service_name} not found in registry")
return None
return registry._services[service_name]
# Convenience methods for common services
@classmethod
async def get_lora_scanner(cls):
"""Get the LoraScanner instance"""
from .lora_scanner import LoraScanner
scanner = await cls.get_service("lora_scanner")
if scanner is None:
scanner = await LoraScanner.get_instance()
await cls.register_service("lora_scanner", scanner)
return scanner
"""Get or create LoRA scanner instance"""
service_name = "lora_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .lora_scanner import LoraScanner
scanner = await LoraScanner.get_instance()
cls._services[service_name] = scanner
logger.info(f"Created and registered {service_name}")
return scanner
@classmethod
async def get_checkpoint_scanner(cls):
"""Get the CheckpointScanner instance"""
from .checkpoint_scanner import CheckpointScanner
scanner = await cls.get_service("checkpoint_scanner")
if scanner is None:
"""Get or create Checkpoint scanner instance"""
service_name = "checkpoint_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .checkpoint_scanner import CheckpointScanner
scanner = await CheckpointScanner.get_instance()
await cls.register_service("checkpoint_scanner", scanner)
return scanner
@classmethod
async def get_civitai_client(cls):
"""Get the CivitaiClient instance"""
from .civitai_client import CivitaiClient
client = await cls.get_service("civitai_client")
if client is None:
client = await CivitaiClient.get_instance()
await cls.register_service("civitai_client", client)
return client
@classmethod
async def get_download_manager(cls):
"""Get the DownloadManager instance"""
from .download_manager import DownloadManager
manager = await cls.get_service("download_manager")
if manager is None:
manager = await DownloadManager.get_instance()
await cls.register_service("download_manager", manager)
return manager
cls._services[service_name] = scanner
logger.info(f"Created and registered {service_name}")
return scanner
@classmethod
async def get_recipe_scanner(cls):
"""Get the RecipeScanner instance"""
from .recipe_scanner import RecipeScanner
scanner = await cls.get_service("recipe_scanner")
if scanner is None:
lora_scanner = await cls.get_lora_scanner()
scanner = RecipeScanner(lora_scanner)
await cls.register_service("recipe_scanner", scanner)
return scanner
"""Get or create Recipe scanner instance"""
service_name = "recipe_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .recipe_scanner import RecipeScanner
scanner = await RecipeScanner.get_instance()
cls._services[service_name] = scanner
logger.info(f"Created and registered {service_name}")
return scanner
@classmethod
async def get_civitai_client(cls):
"""Get or create CivitAI client instance"""
service_name = "civitai_client"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .civitai_client import CivitaiClient
client = await CivitaiClient.get_instance()
cls._services[service_name] = client
logger.info(f"Created and registered {service_name}")
return client
@classmethod
async def get_download_manager(cls):
"""Get or create Download manager instance"""
service_name = "download_manager"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .download_manager import DownloadManager
manager = DownloadManager()
cls._services[service_name] = manager
logger.info(f"Created and registered {service_name}")
return manager
@classmethod
async def get_websocket_manager(cls):
"""Get the WebSocketManager instance"""
from .websocket_manager import ws_manager
manager = await cls.get_service("websocket_manager")
if manager is None:
# ws_manager is already a global instance in websocket_manager.py
"""Get or create WebSocket manager instance"""
service_name = "websocket_manager"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .websocket_manager import ws_manager
await cls.register_service("websocket_manager", ws_manager)
manager = ws_manager
return manager
cls._services[service_name] = ws_manager
logger.info(f"Registered {service_name}")
return ws_manager
@classmethod
def clear_services(cls):
"""Clear all registered services - mainly for testing"""
cls._services.clear()
cls._locks.clear()
logger.info("Cleared all registered services")