diff --git a/py/routes/embedding_routes.py b/py/routes/embedding_routes.py index 70b5b26b..eefa8bdd 100644 --- a/py/routes/embedding_routes.py +++ b/py/routes/embedding_routes.py @@ -40,11 +40,11 @@ class EmbeddingRoutes(BaseModelRoutes): def _validate_civitai_model_type(self, model_type: str) -> bool: """Validate CivitAI model type for Embedding""" - return model_type.lower() in ['textualinversion', 'embedding'] + return model_type.lower() == 'textualinversion' def _get_expected_model_types(self) -> str: """Get expected model types string for error messages""" - return "TextualInversion/Embedding" + return "TextualInversion" async def get_embedding_info(self, request: web.Request) -> web.Response: """Get detailed information for a specific embedding by name""" diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 89fadace..5cae7002 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -12,7 +12,7 @@ from ..utils.lora_metadata import extract_trained_words from ..config import config from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR from ..services.service_registry import ServiceRegistry -from ..services.metadata_service import get_metadata_archive_manager, update_metadata_providers +from ..services.metadata_service import get_metadata_archive_manager, update_metadata_providers, get_metadata_provider from ..services.websocket_manager import ws_manager from ..services.downloader import get_downloader logger = logging.getLogger(__name__) @@ -119,6 +119,9 @@ class MiscRoutes: app.router.add_post('/api/download-metadata-archive', MiscRoutes.download_metadata_archive) app.router.add_post('/api/remove-metadata-archive', MiscRoutes.remove_metadata_archive) app.router.add_get('/api/metadata-archive-status', MiscRoutes.get_metadata_archive_status) + + # Add route for checking model versions in library + app.router.add_get('/api/model-versions-status', MiscRoutes.get_model_versions_status) @staticmethod async def get_settings(request): @@ -832,6 +835,113 @@ class MiscRoutes: 'success': False, 'error': str(e) }, status=500) + + @staticmethod + async def get_model_versions_status(request): + """ + Get all versions of a model from metadata provider and check their library status + + Expects query parameters: + - modelId: int - Civitai model ID (required) + + Returns: + - JSON with model type and versions list, each version includes 'inLibrary' flag + """ + try: + # Get the modelId from query parameters + model_id_str = request.query.get('modelId') + + # Validate modelId parameter (required) + if not model_id_str: + return web.json_response({ + 'success': False, + 'error': 'Missing required parameter: modelId' + }, status=400) + + try: + # Convert modelId to integer + model_id = int(model_id_str) + except ValueError: + return web.json_response({ + 'success': False, + 'error': 'Parameter modelId must be an integer' + }, status=400) + + # Get metadata provider + metadata_provider = await get_metadata_provider() + if not metadata_provider: + return web.json_response({ + 'success': False, + 'error': 'Metadata provider not available' + }, status=503) + + # Get model versions from metadata provider + response = await metadata_provider.get_model_versions(model_id) + if not response or not response.get('modelVersions'): + return web.json_response({ + 'success': False, + 'error': 'Model not found' + }, status=404) + + versions = response.get('modelVersions', []) + model_name = response.get('name', '') + model_type = response.get('type', '').lower() + + # Determine scanner based on model type + scanner = None + normalized_type = None + + if model_type in ['lora', 'locon', 'dora']: + scanner = await ServiceRegistry.get_lora_scanner() + normalized_type = 'lora' + elif model_type == 'checkpoint': + scanner = await ServiceRegistry.get_checkpoint_scanner() + normalized_type = 'checkpoint' + elif model_type == 'textualinversion': + scanner = await ServiceRegistry.get_embedding_scanner() + normalized_type = 'embedding' + else: + return web.json_response({ + 'success': False, + 'error': f'Model type "{model_type}" is not supported' + }, status=400) + + if not scanner: + return web.json_response({ + 'success': False, + 'error': f'Scanner for type "{normalized_type}" is not available' + }, status=503) + + # Get local versions from scanner + local_versions = await scanner.get_model_versions_by_id(model_id) + local_version_ids = set(version['versionId'] for version in local_versions) + + # Add inLibrary flag to each version + enriched_versions = [] + for version in versions: + version_id = version.get('id') + enriched_version = { + 'id': version_id, + 'name': version.get('name', ''), + 'thumbnailUrl': version.get('images')[0]['url'] if version.get('images') else None, + 'inLibrary': version_id in local_version_ids + } + enriched_versions.append(enriched_version) + + return web.json_response({ + 'success': True, + 'modelId': model_id, + 'modelName': model_name, + 'modelType': model_type, + 'versions': enriched_versions + }) + + except Exception as e: + logger.error(f"Failed to get model versions status: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) @staticmethod async def open_file_location(request): diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index e037ba35..463bd036 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -122,7 +122,8 @@ class CivitaiClient: # Also return model type along with versions return { 'modelVersions': result.get('modelVersions', []), - 'type': result.get('type', '') + 'type': result.get('type', ''), + 'name': result.get('name', '') } return None except Exception as e: diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py index 9957b849..ee38f373 100644 --- a/py/services/model_metadata_provider.py +++ b/py/services/model_metadata_provider.py @@ -224,6 +224,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider): model_data = json.loads(model_row['data']) model_type = model_row['type'] + model_name = model_row['name'] # Get all versions for this model versions_query = """ @@ -260,7 +261,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider): return { 'modelVersions': model_versions, - 'type': model_type + 'type': model_type, + 'name': model_name } async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: