From 62f06302f0fb91cc9c96241dbd28ce8c1572e0de Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 10 Sep 2025 20:29:26 +0800 Subject: [PATCH] refactor(routes): replace ModelMetadataProviderManager with get_default_metadata_provider in checkpoint, embedding, and lora routes --- py/routes/checkpoint_routes.py | 7 +++---- py/routes/embedding_routes.py | 7 +++---- py/routes/lora_routes.py | 16 ++++++++-------- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index b93700cf..a0f6a027 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -4,7 +4,7 @@ from aiohttp import web from .base_model_routes import BaseModelRoutes from ..services.checkpoint_service import CheckpointService from ..services.service_registry import ServiceRegistry -from ..services.model_metadata_provider import ModelMetadataProviderManager +from ..services.metadata_service import get_default_metadata_provider from ..config import config logger = logging.getLogger(__name__) @@ -16,14 +16,12 @@ class CheckpointRoutes(BaseModelRoutes): """Initialize Checkpoint routes with Checkpoint service""" # Service will be initialized later via setup_routes self.service = None - self.metadata_provider = None self.template_name = "checkpoints.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() self.service = CheckpointService(checkpoint_scanner) - self.metadata_provider = await ModelMetadataProviderManager.get_instance() # Initialize parent with the service super().__init__(self.service) @@ -67,7 +65,8 @@ class CheckpointRoutes(BaseModelRoutes): """Get available versions for a Civitai checkpoint model with local availability info""" try: model_id = request.match_info['model_id'] - response = await self.metadata_provider.get_model_versions(model_id) + metadata_provider = await get_default_metadata_provider() + response = await metadata_provider.get_model_versions(model_id) if not response or not response.get('modelVersions'): return web.Response(status=404, text="Model not found") diff --git a/py/routes/embedding_routes.py b/py/routes/embedding_routes.py index 65a66824..ab028666 100644 --- a/py/routes/embedding_routes.py +++ b/py/routes/embedding_routes.py @@ -4,7 +4,7 @@ from aiohttp import web from .base_model_routes import BaseModelRoutes from ..services.embedding_service import EmbeddingService from ..services.service_registry import ServiceRegistry -from ..services.model_metadata_provider import ModelMetadataProviderManager +from ..services.metadata_service import get_default_metadata_provider logger = logging.getLogger(__name__) @@ -15,14 +15,12 @@ class EmbeddingRoutes(BaseModelRoutes): """Initialize Embedding routes with Embedding service""" # Service will be initialized later via setup_routes self.service = None - self.metadata_provider = None self.template_name = "embeddings.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" embedding_scanner = await ServiceRegistry.get_embedding_scanner() self.service = EmbeddingService(embedding_scanner) - self.metadata_provider = await ModelMetadataProviderManager.get_instance() # Initialize parent with the service super().__init__(self.service) @@ -62,7 +60,8 @@ class EmbeddingRoutes(BaseModelRoutes): """Get available versions for a Civitai embedding model with local availability info""" try: model_id = request.match_info['model_id'] - response = await self.metadata_provider.get_model_versions(model_id) + metadata_provider = await get_default_metadata_provider() + response = await metadata_provider.get_model_versions(model_id) if not response or not response.get('modelVersions'): return web.Response(status=404, text="Model not found") diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 4c1c0467..4e261004 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -7,8 +7,7 @@ from server import PromptServer # type: ignore from .base_model_routes import BaseModelRoutes from ..services.lora_service import LoraService from ..services.service_registry import ServiceRegistry -from ..services.model_metadata_provider import ModelMetadataProviderManager -from ..utils.routes_common import ModelRouteUtils +from ..services.metadata_service import get_default_metadata_provider from ..utils.utils import get_lora_info logger = logging.getLogger(__name__) @@ -20,14 +19,12 @@ class LoraRoutes(BaseModelRoutes): """Initialize LoRA routes with LoRA service""" # Service will be initialized later via setup_routes self.service = None - self.metadata_provider = None self.template_name = "loras.html" async def initialize_services(self): """Initialize services from ServiceRegistry""" lora_scanner = await ServiceRegistry.get_lora_scanner() self.service = LoraService(lora_scanner) - self.metadata_provider = await ModelMetadataProviderManager.get_instance() # Initialize parent with the service super().__init__(self.service) @@ -218,7 +215,8 @@ class LoraRoutes(BaseModelRoutes): """Get available versions for a Civitai LoRA model with local availability info""" try: model_id = request.match_info['model_id'] - response = await self.metadata_provider.get_model_versions(model_id) + metadata_provider = await get_default_metadata_provider() + response = await metadata_provider.get_model_versions(model_id) if not response or not response.get('modelVersions'): return web.Response(status=404, text="Model not found") @@ -263,8 +261,9 @@ class LoraRoutes(BaseModelRoutes): model_version_id = request.match_info.get('modelVersionId') # Get model details from metadata provider - model, error_msg = await self.metadata_provider.get_model_version_info(model_version_id) - + metadata_provider = await get_default_metadata_provider() + model, error_msg = await metadata_provider.get_model_version_info(model_version_id) + if not model: # Log warning for failed model retrieval logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}") @@ -289,7 +288,8 @@ class LoraRoutes(BaseModelRoutes): """Get CivitAI model details by hash""" try: hash = request.match_info.get('hash') - model = await self.metadata_provider.get_model_by_hash(hash) + metadata_provider = await get_default_metadata_provider() + model = await metadata_provider.get_model_by_hash(hash) return web.json_response(model) except Exception as e: logger.error(f"Error fetching model details by hash: {e}")