mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
refactor(routes): replace ModelMetadataProviderManager with get_default_metadata_provider in checkpoint, embedding, and lora routes
This commit is contained in:
@@ -4,7 +4,7 @@ from aiohttp import web
|
|||||||
from .base_model_routes import BaseModelRoutes
|
from .base_model_routes import BaseModelRoutes
|
||||||
from ..services.checkpoint_service import CheckpointService
|
from ..services.checkpoint_service import CheckpointService
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.model_metadata_provider import ModelMetadataProviderManager
|
from ..services.metadata_service import get_default_metadata_provider
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -16,14 +16,12 @@ class CheckpointRoutes(BaseModelRoutes):
|
|||||||
"""Initialize Checkpoint routes with Checkpoint service"""
|
"""Initialize Checkpoint routes with Checkpoint service"""
|
||||||
# Service will be initialized later via setup_routes
|
# Service will be initialized later via setup_routes
|
||||||
self.service = None
|
self.service = None
|
||||||
self.metadata_provider = None
|
|
||||||
self.template_name = "checkpoints.html"
|
self.template_name = "checkpoints.html"
|
||||||
|
|
||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
self.service = CheckpointService(checkpoint_scanner)
|
self.service = CheckpointService(checkpoint_scanner)
|
||||||
self.metadata_provider = await ModelMetadataProviderManager.get_instance()
|
|
||||||
|
|
||||||
# Initialize parent with the service
|
# Initialize parent with the service
|
||||||
super().__init__(self.service)
|
super().__init__(self.service)
|
||||||
@@ -67,7 +65,8 @@ class CheckpointRoutes(BaseModelRoutes):
|
|||||||
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||||
try:
|
try:
|
||||||
model_id = request.match_info['model_id']
|
model_id = request.match_info['model_id']
|
||||||
response = await self.metadata_provider.get_model_versions(model_id)
|
metadata_provider = await get_default_metadata_provider()
|
||||||
|
response = await metadata_provider.get_model_versions(model_id)
|
||||||
if not response or not response.get('modelVersions'):
|
if not response or not response.get('modelVersions'):
|
||||||
return web.Response(status=404, text="Model not found")
|
return web.Response(status=404, text="Model not found")
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from aiohttp import web
|
|||||||
from .base_model_routes import BaseModelRoutes
|
from .base_model_routes import BaseModelRoutes
|
||||||
from ..services.embedding_service import EmbeddingService
|
from ..services.embedding_service import EmbeddingService
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.model_metadata_provider import ModelMetadataProviderManager
|
from ..services.metadata_service import get_default_metadata_provider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -15,14 +15,12 @@ class EmbeddingRoutes(BaseModelRoutes):
|
|||||||
"""Initialize Embedding routes with Embedding service"""
|
"""Initialize Embedding routes with Embedding service"""
|
||||||
# Service will be initialized later via setup_routes
|
# Service will be initialized later via setup_routes
|
||||||
self.service = None
|
self.service = None
|
||||||
self.metadata_provider = None
|
|
||||||
self.template_name = "embeddings.html"
|
self.template_name = "embeddings.html"
|
||||||
|
|
||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
self.service = EmbeddingService(embedding_scanner)
|
self.service = EmbeddingService(embedding_scanner)
|
||||||
self.metadata_provider = await ModelMetadataProviderManager.get_instance()
|
|
||||||
|
|
||||||
# Initialize parent with the service
|
# Initialize parent with the service
|
||||||
super().__init__(self.service)
|
super().__init__(self.service)
|
||||||
@@ -62,7 +60,8 @@ class EmbeddingRoutes(BaseModelRoutes):
|
|||||||
"""Get available versions for a Civitai embedding model with local availability info"""
|
"""Get available versions for a Civitai embedding model with local availability info"""
|
||||||
try:
|
try:
|
||||||
model_id = request.match_info['model_id']
|
model_id = request.match_info['model_id']
|
||||||
response = await self.metadata_provider.get_model_versions(model_id)
|
metadata_provider = await get_default_metadata_provider()
|
||||||
|
response = await metadata_provider.get_model_versions(model_id)
|
||||||
if not response or not response.get('modelVersions'):
|
if not response or not response.get('modelVersions'):
|
||||||
return web.Response(status=404, text="Model not found")
|
return web.Response(status=404, text="Model not found")
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,7 @@ from server import PromptServer # type: ignore
|
|||||||
from .base_model_routes import BaseModelRoutes
|
from .base_model_routes import BaseModelRoutes
|
||||||
from ..services.lora_service import LoraService
|
from ..services.lora_service import LoraService
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.model_metadata_provider import ModelMetadataProviderManager
|
from ..services.metadata_service import get_default_metadata_provider
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
|
||||||
from ..utils.utils import get_lora_info
|
from ..utils.utils import get_lora_info
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -20,14 +19,12 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
"""Initialize LoRA routes with LoRA service"""
|
"""Initialize LoRA routes with LoRA service"""
|
||||||
# Service will be initialized later via setup_routes
|
# Service will be initialized later via setup_routes
|
||||||
self.service = None
|
self.service = None
|
||||||
self.metadata_provider = None
|
|
||||||
self.template_name = "loras.html"
|
self.template_name = "loras.html"
|
||||||
|
|
||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
"""Initialize services from ServiceRegistry"""
|
"""Initialize services from ServiceRegistry"""
|
||||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
self.service = LoraService(lora_scanner)
|
self.service = LoraService(lora_scanner)
|
||||||
self.metadata_provider = await ModelMetadataProviderManager.get_instance()
|
|
||||||
|
|
||||||
# Initialize parent with the service
|
# Initialize parent with the service
|
||||||
super().__init__(self.service)
|
super().__init__(self.service)
|
||||||
@@ -218,7 +215,8 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
"""Get available versions for a Civitai LoRA model with local availability info"""
|
"""Get available versions for a Civitai LoRA model with local availability info"""
|
||||||
try:
|
try:
|
||||||
model_id = request.match_info['model_id']
|
model_id = request.match_info['model_id']
|
||||||
response = await self.metadata_provider.get_model_versions(model_id)
|
metadata_provider = await get_default_metadata_provider()
|
||||||
|
response = await metadata_provider.get_model_versions(model_id)
|
||||||
if not response or not response.get('modelVersions'):
|
if not response or not response.get('modelVersions'):
|
||||||
return web.Response(status=404, text="Model not found")
|
return web.Response(status=404, text="Model not found")
|
||||||
|
|
||||||
@@ -263,8 +261,9 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
model_version_id = request.match_info.get('modelVersionId')
|
model_version_id = request.match_info.get('modelVersionId')
|
||||||
|
|
||||||
# Get model details from metadata provider
|
# Get model details from metadata provider
|
||||||
model, error_msg = await self.metadata_provider.get_model_version_info(model_version_id)
|
metadata_provider = await get_default_metadata_provider()
|
||||||
|
model, error_msg = await metadata_provider.get_model_version_info(model_version_id)
|
||||||
|
|
||||||
if not model:
|
if not model:
|
||||||
# Log warning for failed model retrieval
|
# Log warning for failed model retrieval
|
||||||
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
|
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
|
||||||
@@ -289,7 +288,8 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
"""Get CivitAI model details by hash"""
|
"""Get CivitAI model details by hash"""
|
||||||
try:
|
try:
|
||||||
hash = request.match_info.get('hash')
|
hash = request.match_info.get('hash')
|
||||||
model = await self.metadata_provider.get_model_by_hash(hash)
|
metadata_provider = await get_default_metadata_provider()
|
||||||
|
model = await metadata_provider.get_model_by_hash(hash)
|
||||||
return web.json_response(model)
|
return web.json_response(model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching model details by hash: {e}")
|
logger.error(f"Error fetching model details by hash: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user