feat(metadata): implement metadata archive management and update settings for metadata providers

This commit is contained in:
Will Miao
2025-09-08 13:17:16 +08:00
parent 9ba3e2c204
commit 821827a375
11 changed files with 659 additions and 38 deletions

View File

@@ -4,6 +4,7 @@ from aiohttp import web
from .base_model_routes import BaseModelRoutes
from ..services.checkpoint_service import CheckpointService
from ..services.service_registry import ServiceRegistry
from ..services.model_metadata_provider import ModelMetadataProviderManager
from ..config import config
logger = logging.getLogger(__name__)
@@ -15,14 +16,14 @@ class CheckpointRoutes(BaseModelRoutes):
"""Initialize Checkpoint routes with Checkpoint service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.metadata_provider = None
self.template_name = "checkpoints.html"
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
self.service = CheckpointService(checkpoint_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
self.metadata_provider = await ModelMetadataProviderManager.get_instance()
# Initialize parent with the service
super().__init__(self.service)
@@ -66,7 +67,7 @@ class CheckpointRoutes(BaseModelRoutes):
"""Get available versions for a Civitai checkpoint model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
response = await self.metadata_provider.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")

View File

@@ -4,6 +4,7 @@ from aiohttp import web
from .base_model_routes import BaseModelRoutes
from ..services.embedding_service import EmbeddingService
from ..services.service_registry import ServiceRegistry
from ..services.model_metadata_provider import ModelMetadataProviderManager
logger = logging.getLogger(__name__)
@@ -14,14 +15,14 @@ class EmbeddingRoutes(BaseModelRoutes):
"""Initialize Embedding routes with Embedding service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.metadata_provider = None
self.template_name = "embeddings.html"
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
self.service = EmbeddingService(embedding_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
self.metadata_provider = await ModelMetadataProviderManager.get_instance()
# Initialize parent with the service
super().__init__(self.service)
@@ -61,7 +62,7 @@ class EmbeddingRoutes(BaseModelRoutes):
"""Get available versions for a Civitai embedding model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
response = await self.metadata_provider.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")

View File

@@ -7,6 +7,7 @@ from server import PromptServer # type: ignore
from .base_model_routes import BaseModelRoutes
from ..services.lora_service import LoraService
from ..services.service_registry import ServiceRegistry
from ..services.model_metadata_provider import ModelMetadataProviderManager
from ..utils.routes_common import ModelRouteUtils
from ..utils.utils import get_lora_info
@@ -19,14 +20,14 @@ class LoraRoutes(BaseModelRoutes):
"""Initialize LoRA routes with LoRA service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.metadata_provider = None
self.template_name = "loras.html"
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
lora_scanner = await ServiceRegistry.get_lora_scanner()
self.service = LoraService(lora_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
self.metadata_provider = await ModelMetadataProviderManager.get_instance()
# Initialize parent with the service
super().__init__(self.service)
@@ -217,7 +218,7 @@ class LoraRoutes(BaseModelRoutes):
"""Get available versions for a Civitai LoRA model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
response = await self.metadata_provider.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
@@ -261,8 +262,8 @@ class LoraRoutes(BaseModelRoutes):
try:
model_version_id = request.match_info.get('modelVersionId')
# Get model details from Civitai API
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
# Get model details from metadata provider
model, error_msg = await self.metadata_provider.get_model_version_info(model_version_id)
if not model:
# Log warning for failed model retrieval
@@ -288,7 +289,7 @@ class LoraRoutes(BaseModelRoutes):
"""Get CivitAI model details by hash"""
try:
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
model = await self.metadata_provider.get_model_by_hash(hash)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details by hash: {e}")

View File

@@ -11,6 +11,8 @@ from ..utils.lora_metadata import extract_trained_words
from ..config import config
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR
from ..services.service_registry import ServiceRegistry
from ..services.metadata_service import get_metadata_archive_manager, update_metadata_provider_priority
from ..services.websocket_manager import ws_manager
import re
logger = logging.getLogger(__name__)
@@ -112,6 +114,11 @@ class MiscRoutes:
# Add new route for checking if a model exists in the library
app.router.add_get('/api/check-model-exists', MiscRoutes.check_model_exists)
# Add routes for metadata archive database management
app.router.add_post('/api/download-metadata-archive', MiscRoutes.download_metadata_archive)
app.router.add_post('/api/remove-metadata-archive', MiscRoutes.remove_metadata_archive)
app.router.add_get('/api/metadata-archive-status', MiscRoutes.get_metadata_archive_status)
@staticmethod
async def clear_cache(request):
@@ -697,3 +704,108 @@ class MiscRoutes:
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def download_metadata_archive(request):
"""Download and extract the metadata archive database"""
try:
archive_manager = await get_metadata_archive_manager()
# Progress callback to send updates via WebSocket
def progress_callback(stage, message):
asyncio.create_task(ws_manager.broadcast({
'stage': stage,
'message': message,
'type': 'metadata_archive_download'
}))
# Download and extract in background
success = await archive_manager.download_and_extract_database(progress_callback)
if success:
# Update settings to enable metadata archive
settings.set('enable_metadata_archive_db', True)
# Update provider priority
await update_metadata_provider_priority()
return web.json_response({
'success': True,
'message': 'Metadata archive database downloaded and extracted successfully'
})
else:
return web.json_response({
'success': False,
'error': 'Failed to download and extract metadata archive database'
}, status=500)
except Exception as e:
logger.error(f"Error downloading metadata archive: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def remove_metadata_archive(request):
"""Remove the metadata archive database"""
try:
archive_manager = await get_metadata_archive_manager()
success = await archive_manager.remove_database()
if success:
# Update settings to disable metadata archive
settings.set('enable_metadata_archive_db', False)
# Update provider priority
await update_metadata_provider_priority()
return web.json_response({
'success': True,
'message': 'Metadata archive database removed successfully'
})
else:
return web.json_response({
'success': False,
'error': 'Failed to remove metadata archive database'
}, status=500)
except Exception as e:
logger.error(f"Error removing metadata archive: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def get_metadata_archive_status(request):
"""Get the status of metadata archive database"""
try:
archive_manager = await get_metadata_archive_manager()
is_available = archive_manager.is_database_available()
is_enabled = settings.get('enable_metadata_archive_db', False)
priority = settings.get('metadata_provider_priority', 'archive_db')
db_size = 0
if is_available:
db_path = archive_manager.get_database_path()
if db_path and os.path.exists(db_path):
db_size = os.path.getsize(db_path)
return web.json_response({
'success': True,
'isAvailable': is_available,
'isEnabled': is_enabled,
'priority': priority,
'databaseSize': db_size,
'databasePath': archive_manager.get_database_path() if is_available else None
})
except Exception as e:
logger.error(f"Error getting metadata archive status: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)