feat(routes): add model versions status endpoint and enhance metadata retrieval

This commit is contained in:
Will Miao
2025-09-17 22:06:59 +08:00
parent 933e2fc01d
commit ded17c1479
4 changed files with 118 additions and 5 deletions

View File

@@ -40,11 +40,11 @@ class EmbeddingRoutes(BaseModelRoutes):
def _validate_civitai_model_type(self, model_type: str) -> bool:
"""Validate CivitAI model type for Embedding"""
return model_type.lower() in ['textualinversion', 'embedding']
return model_type.lower() == 'textualinversion'
def _get_expected_model_types(self) -> str:
"""Get expected model types string for error messages"""
return "TextualInversion/Embedding"
return "TextualInversion"
async def get_embedding_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific embedding by name"""

View File

@@ -12,7 +12,7 @@ from ..utils.lora_metadata import extract_trained_words
from ..config import config
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR
from ..services.service_registry import ServiceRegistry
from ..services.metadata_service import get_metadata_archive_manager, update_metadata_providers
from ..services.metadata_service import get_metadata_archive_manager, update_metadata_providers, get_metadata_provider
from ..services.websocket_manager import ws_manager
from ..services.downloader import get_downloader
logger = logging.getLogger(__name__)
@@ -119,6 +119,9 @@ class MiscRoutes:
app.router.add_post('/api/download-metadata-archive', MiscRoutes.download_metadata_archive)
app.router.add_post('/api/remove-metadata-archive', MiscRoutes.remove_metadata_archive)
app.router.add_get('/api/metadata-archive-status', MiscRoutes.get_metadata_archive_status)
# Add route for checking model versions in library
app.router.add_get('/api/model-versions-status', MiscRoutes.get_model_versions_status)
@staticmethod
async def get_settings(request):
@@ -832,6 +835,113 @@ class MiscRoutes:
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def get_model_versions_status(request):
"""
Get all versions of a model from metadata provider and check their library status
Expects query parameters:
- modelId: int - Civitai model ID (required)
Returns:
- JSON with model type and versions list, each version includes 'inLibrary' flag
"""
try:
# Get the modelId from query parameters
model_id_str = request.query.get('modelId')
# Validate modelId parameter (required)
if not model_id_str:
return web.json_response({
'success': False,
'error': 'Missing required parameter: modelId'
}, status=400)
try:
# Convert modelId to integer
model_id = int(model_id_str)
except ValueError:
return web.json_response({
'success': False,
'error': 'Parameter modelId must be an integer'
}, status=400)
# Get metadata provider
metadata_provider = await get_metadata_provider()
if not metadata_provider:
return web.json_response({
'success': False,
'error': 'Metadata provider not available'
}, status=503)
# Get model versions from metadata provider
response = await metadata_provider.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.json_response({
'success': False,
'error': 'Model not found'
}, status=404)
versions = response.get('modelVersions', [])
model_name = response.get('name', '')
model_type = response.get('type', '').lower()
# Determine scanner based on model type
scanner = None
normalized_type = None
if model_type in ['lora', 'locon', 'dora']:
scanner = await ServiceRegistry.get_lora_scanner()
normalized_type = 'lora'
elif model_type == 'checkpoint':
scanner = await ServiceRegistry.get_checkpoint_scanner()
normalized_type = 'checkpoint'
elif model_type == 'textualinversion':
scanner = await ServiceRegistry.get_embedding_scanner()
normalized_type = 'embedding'
else:
return web.json_response({
'success': False,
'error': f'Model type "{model_type}" is not supported'
}, status=400)
if not scanner:
return web.json_response({
'success': False,
'error': f'Scanner for type "{normalized_type}" is not available'
}, status=503)
# Get local versions from scanner
local_versions = await scanner.get_model_versions_by_id(model_id)
local_version_ids = set(version['versionId'] for version in local_versions)
# Add inLibrary flag to each version
enriched_versions = []
for version in versions:
version_id = version.get('id')
enriched_version = {
'id': version_id,
'name': version.get('name', ''),
'thumbnailUrl': version.get('images')[0]['url'] if version.get('images') else None,
'inLibrary': version_id in local_version_ids
}
enriched_versions.append(enriched_version)
return web.json_response({
'success': True,
'modelId': model_id,
'modelName': model_name,
'modelType': model_type,
'versions': enriched_versions
})
except Exception as e:
logger.error(f"Failed to get model versions status: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def open_file_location(request):

View File

@@ -122,7 +122,8 @@ class CivitaiClient:
# Also return model type along with versions
return {
'modelVersions': result.get('modelVersions', []),
'type': result.get('type', '')
'type': result.get('type', ''),
'name': result.get('name', '')
}
return None
except Exception as e:

View File

@@ -224,6 +224,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
model_data = json.loads(model_row['data'])
model_type = model_row['type']
model_name = model_row['name']
# Get all versions for this model
versions_query = """
@@ -260,7 +261,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
return {
'modelVersions': model_versions,
'type': model_type
'type': model_type,
'name': model_name
}
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: