Files
ComfyUI-Lora-Manager/py/services/model_service_factory.py
Will Miao 3fc72d6bc1 feat(config): replace symlink scanning with cached mapping system
- Add `get_settings_dir` import for cache directory resolution
- Replace `_scan_symbolic_links` and `_rebuild_preview_roots` with unified `_initialize_symlink_mappings` method
- Implement fingerprint-based cache validation using root mtimes, inodes, and noise-aware timestamps
- Add helper methods for path normalization, cache location, and symlink root aggregation
- Improve performance by avoiding redundant symlink traversal when directory structure is unchanged
2025-12-09 19:51:01 +08:00

136 lines
4.6 KiB
Python

from typing import Dict, Type, Any
import logging
logger = logging.getLogger(__name__)
class ModelServiceFactory:
"""Factory for managing model services and routes"""
_services: Dict[str, Type] = {}
_routes: Dict[str, Type] = {}
_initialized_services: Dict[str, Any] = {}
_initialized_routes: Dict[str, Any] = {}
@classmethod
def register_model_type(cls, model_type: str, service_class: Type, route_class: Type):
"""Register a new model type with its service and route classes
Args:
model_type: The model type identifier (e.g., 'lora', 'checkpoint')
service_class: The service class for this model type
route_class: The route class for this model type
"""
cls._services[model_type] = service_class
cls._routes[model_type] = route_class
@classmethod
def get_service_class(cls, model_type: str) -> Type:
"""Get service class for a model type
Args:
model_type: The model type identifier
Returns:
The service class for the model type
Raises:
ValueError: If model type is not registered
"""
if model_type not in cls._services:
raise ValueError(f"Unknown model type: {model_type}")
return cls._services[model_type]
@classmethod
def get_route_class(cls, model_type: str) -> Type:
"""Get route class for a model type
Args:
model_type: The model type identifier
Returns:
The route class for the model type
Raises:
ValueError: If model type is not registered
"""
if model_type not in cls._routes:
raise ValueError(f"Unknown model type: {model_type}")
return cls._routes[model_type]
@classmethod
def get_route_instance(cls, model_type: str):
"""Get or create route instance for a model type
Args:
model_type: The model type identifier
Returns:
The route instance for the model type
"""
if model_type not in cls._initialized_routes:
route_class = cls.get_route_class(model_type)
cls._initialized_routes[model_type] = route_class()
return cls._initialized_routes[model_type]
@classmethod
def setup_all_routes(cls, app):
"""Setup routes for all registered model types
Args:
app: The aiohttp application instance
"""
for model_type in cls._services.keys():
try:
routes_instance = cls.get_route_instance(model_type)
routes_instance.setup_routes(app)
except Exception as e:
logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True)
@classmethod
def get_registered_types(cls) -> list:
"""Get list of all registered model types
Returns:
List of registered model type identifiers
"""
return list(cls._services.keys())
@classmethod
def is_registered(cls, model_type: str) -> bool:
"""Check if a model type is registered
Args:
model_type: The model type identifier
Returns:
True if the model type is registered, False otherwise
"""
return model_type in cls._services
@classmethod
def clear_registrations(cls):
"""Clear all registrations - mainly for testing purposes"""
cls._services.clear()
cls._routes.clear()
cls._initialized_services.clear()
cls._initialized_routes.clear()
logger.info("Cleared all model type registrations")
def register_default_model_types():
"""Register the default model types (LoRA, Checkpoint, and Embedding)"""
from ..services.lora_service import LoraService
from ..services.checkpoint_service import CheckpointService
from ..services.embedding_service import EmbeddingService
from ..routes.lora_routes import LoraRoutes
from ..routes.checkpoint_routes import CheckpointRoutes
from ..routes.embedding_routes import EmbeddingRoutes
# Register LoRA model type
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
# Register Checkpoint model type
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
# Register Embedding model type
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)