mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
258 lines
8.8 KiB
Python
258 lines
8.8 KiB
Python
import asyncio
|
|
import logging
|
|
from typing import Optional, Dict, Any, TypeVar, Type
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
T = TypeVar('T') # Define a type variable for service types
|
|
|
|
class ServiceRegistry:
|
|
"""Central registry for managing singleton services"""
|
|
|
|
_services: Dict[str, Any] = {}
|
|
_locks: Dict[str, asyncio.Lock] = {}
|
|
|
|
@classmethod
|
|
async def register_service(cls, name: str, service: Any) -> None:
|
|
"""Register a service instance with the registry
|
|
|
|
Args:
|
|
name: Service name identifier
|
|
service: Service instance to register
|
|
"""
|
|
cls._services[name] = service
|
|
logger.debug(f"Registered service: {name}")
|
|
|
|
@classmethod
|
|
async def get_service(cls, name: str) -> Optional[Any]:
|
|
"""Get a service instance by name
|
|
|
|
Args:
|
|
name: Service name identifier
|
|
|
|
Returns:
|
|
Service instance or None if not found
|
|
"""
|
|
return cls._services.get(name)
|
|
|
|
@classmethod
|
|
def get_service_sync(cls, name: str) -> Optional[Any]:
|
|
"""Synchronously get a service instance by name
|
|
|
|
Args:
|
|
name: Service name identifier
|
|
|
|
Returns:
|
|
Service instance or None if not found
|
|
"""
|
|
return cls._services.get(name)
|
|
|
|
@classmethod
|
|
def _get_lock(cls, name: str) -> asyncio.Lock:
|
|
"""Get or create a lock for a service
|
|
|
|
Args:
|
|
name: Service name identifier
|
|
|
|
Returns:
|
|
AsyncIO lock for the service
|
|
"""
|
|
if name not in cls._locks:
|
|
cls._locks[name] = asyncio.Lock()
|
|
return cls._locks[name]
|
|
|
|
@classmethod
|
|
async def get_lora_scanner(cls):
|
|
"""Get or create LoRA scanner instance"""
|
|
service_name = "lora_scanner"
|
|
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
async with cls._get_lock(service_name):
|
|
# Double-check after acquiring lock
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
# Import here to avoid circular imports
|
|
from .lora_scanner import LoraScanner
|
|
|
|
scanner = await LoraScanner.get_instance()
|
|
cls._services[service_name] = scanner
|
|
logger.debug(f"Created and registered {service_name}")
|
|
return scanner
|
|
|
|
@classmethod
|
|
async def get_checkpoint_scanner(cls):
|
|
"""Get or create Checkpoint scanner instance"""
|
|
service_name = "checkpoint_scanner"
|
|
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
async with cls._get_lock(service_name):
|
|
# Double-check after acquiring lock
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
# Import here to avoid circular imports
|
|
from .checkpoint_scanner import CheckpointScanner
|
|
|
|
scanner = await CheckpointScanner.get_instance()
|
|
cls._services[service_name] = scanner
|
|
logger.debug(f"Created and registered {service_name}")
|
|
return scanner
|
|
|
|
@classmethod
|
|
async def get_recipe_scanner(cls):
|
|
"""Get or create Recipe scanner instance"""
|
|
service_name = "recipe_scanner"
|
|
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
async with cls._get_lock(service_name):
|
|
# Double-check after acquiring lock
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
# Import here to avoid circular imports
|
|
from .recipe_scanner import RecipeScanner
|
|
|
|
scanner = await RecipeScanner.get_instance()
|
|
cls._services[service_name] = scanner
|
|
logger.debug(f"Created and registered {service_name}")
|
|
return scanner
|
|
|
|
@classmethod
|
|
async def get_civitai_client(cls):
|
|
"""Get or create CivitAI client instance"""
|
|
service_name = "civitai_client"
|
|
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
async with cls._get_lock(service_name):
|
|
# Double-check after acquiring lock
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
# Import here to avoid circular imports
|
|
from .civitai_client import CivitaiClient
|
|
|
|
client = await CivitaiClient.get_instance()
|
|
cls._services[service_name] = client
|
|
logger.debug(f"Created and registered {service_name}")
|
|
return client
|
|
|
|
@classmethod
|
|
async def get_model_update_service(cls):
|
|
"""Get or create the model update tracking service."""
|
|
|
|
service_name = "model_update_service"
|
|
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
async with cls._get_lock(service_name):
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
from .model_update_service import ModelUpdateService
|
|
from .persistent_model_cache import get_persistent_cache
|
|
|
|
cache = get_persistent_cache()
|
|
service = ModelUpdateService(cache.get_database_path())
|
|
cls._services[service_name] = service
|
|
logger.debug(f"Created and registered {service_name}")
|
|
return service
|
|
|
|
@classmethod
|
|
async def get_civarchive_client(cls):
|
|
"""Get or create CivArchive client instance"""
|
|
service_name = "civarchive_client"
|
|
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
async with cls._get_lock(service_name):
|
|
# Double-check after acquiring lock
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
# Import here to avoid circular imports
|
|
from .civarchive_client import CivArchiveClient
|
|
|
|
client = await CivArchiveClient.get_instance()
|
|
cls._services[service_name] = client
|
|
logger.debug(f"Created and registered {service_name}")
|
|
return client
|
|
|
|
@classmethod
|
|
async def get_download_manager(cls):
|
|
"""Get or create Download manager instance"""
|
|
service_name = "download_manager"
|
|
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
async with cls._get_lock(service_name):
|
|
# Double-check after acquiring lock
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
# Import here to avoid circular imports
|
|
from .download_manager import DownloadManager
|
|
|
|
manager = DownloadManager()
|
|
cls._services[service_name] = manager
|
|
logger.debug(f"Created and registered {service_name}")
|
|
return manager
|
|
|
|
@classmethod
|
|
async def get_websocket_manager(cls):
|
|
"""Get or create WebSocket manager instance"""
|
|
service_name = "websocket_manager"
|
|
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
async with cls._get_lock(service_name):
|
|
# Double-check after acquiring lock
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
# Import here to avoid circular imports
|
|
from .websocket_manager import ws_manager
|
|
|
|
cls._services[service_name] = ws_manager
|
|
logger.debug(f"Registered {service_name}")
|
|
return ws_manager
|
|
|
|
@classmethod
|
|
async def get_embedding_scanner(cls):
|
|
"""Get or create Embedding scanner instance"""
|
|
service_name = "embedding_scanner"
|
|
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
async with cls._get_lock(service_name):
|
|
# Double-check after acquiring lock
|
|
if service_name in cls._services:
|
|
return cls._services[service_name]
|
|
|
|
# Import here to avoid circular imports
|
|
from .embedding_scanner import EmbeddingScanner
|
|
|
|
scanner = await EmbeddingScanner.get_instance()
|
|
cls._services[service_name] = scanner
|
|
logger.debug(f"Created and registered {service_name}")
|
|
return scanner
|
|
|
|
@classmethod
|
|
def clear_services(cls):
|
|
"""Clear all registered services - mainly for testing"""
|
|
cls._services.clear()
|
|
cls._locks.clear()
|
|
logger.info("Cleared all registered services") |