feat(routes): integrate CivitAI model version retrieval for various model types

This commit is contained in:
Will Miao
2025-09-17 15:47:30 +08:00
parent 1cddeee264
commit 933e2fc01d
4 changed files with 126 additions and 201 deletions

View File

@@ -14,6 +14,7 @@ from ..services.settings_manager import settings
from ..services.server_i18n import server_i18n
from ..services.model_file_service import ModelFileService, ModelMoveService
from ..services.websocket_progress_callback import WebSocketProgressCallback
from ..services.metadata_service import get_default_metadata_provider
from ..config import config
logger = logging.getLogger(__name__)
@@ -84,14 +85,17 @@ class BaseModelRoutes(ABC):
# Autocomplete route
app.router.add_get(f'/api/{prefix}/relative-paths', self.get_relative_paths)
# Common CivitAI integration
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions)
app.router.add_get(f'/api/{prefix}/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
app.router.add_get(f'/api/{prefix}/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
# Common Download management
app.router.add_post(f'/api/download-model', self.download_model)
app.router.add_get(f'/api/download-model-get', self.download_model_get)
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
# app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
# Add generic page route
app.router.add_get(f'/{prefix}', self.handle_models_page)
@@ -704,10 +708,101 @@ class BaseModelRoutes(ABC):
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai model with local availability info"""
# This will be implemented by subclasses as they need CivitAI client access
return web.json_response({
"error": "Not implemented in base class"
}, status=501)
try:
model_id = request.match_info['model_id']
metadata_provider = await get_default_metadata_provider()
response = await metadata_provider.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - allow subclasses to override validation
if not self._validate_civitai_model_type(model_type):
return web.json_response({
'error': f"Model type mismatch. Expected {self._get_expected_model_types()}, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the model file (type="Model" and primary=true) in the files list
model_file = self._find_model_file(version.get('files', []))
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching {self.model_type} model versions: {e}")
return web.Response(status=500, text=str(e))
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID"""
try:
model_version_id = request.match_info.get('modelVersionId')
# Get model details from metadata provider
metadata_provider = await get_default_metadata_provider()
model, error_msg = await metadata_provider.get_model_version_info(model_version_id)
if not model:
# Log warning for failed model retrieval
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
# Determine status code based on error message
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
return web.json_response({
"success": False,
"error": error_msg or "Failed to fetch model information"
}, status=status_code)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by hash"""
try:
hash = request.match_info.get('hash')
metadata_provider = await get_default_metadata_provider()
model = await metadata_provider.get_model_by_hash(hash)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details by hash: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
def _validate_civitai_model_type(self, model_type: str) -> bool:
"""Validate CivitAI model type - to be overridden by subclasses"""
return True # Default: accept all types
def _get_expected_model_types(self) -> str:
"""Get expected model types string for error messages - to be overridden by subclasses"""
return "any model type"
def _find_model_file(self, files: list) -> dict:
"""Find the appropriate model file from the files list - can be overridden by subclasses"""
# Find the primary model file (type="Model" and primary=true) in the files list
return next((file for file in files if file.get('type') == 'Model' and file.get('primary') == True), None)
# Common model move handlers
async def move_model(self, request: web.Request) -> web.Response:

View File

@@ -36,9 +36,6 @@ class CheckpointRoutes(BaseModelRoutes):
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup Checkpoint-specific routes"""
# Checkpoint-specific CivitAI integration
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
# Checkpoint info by name
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
@@ -46,6 +43,14 @@ class CheckpointRoutes(BaseModelRoutes):
app.router.add_get(f'/api/{prefix}/checkpoints_roots', self.get_checkpoints_roots)
app.router.add_get(f'/api/{prefix}/unet_roots', self.get_unet_roots)
def _validate_civitai_model_type(self, model_type: str) -> bool:
"""Validate CivitAI model type for Checkpoint"""
return model_type.lower() == 'checkpoint'
def _get_expected_model_types(self) -> str:
"""Get expected model types string for error messages"""
return "Checkpoint"
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific checkpoint by name"""
try:
@@ -61,54 +66,6 @@ class CheckpointRoutes(BaseModelRoutes):
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai checkpoint model with local availability info"""
try:
model_id = request.match_info['model_id']
metadata_provider = await get_default_metadata_provider()
response = await metadata_provider.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be Checkpoint
if model_type.lower() != 'checkpoint':
return web.json_response({
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching checkpoint model versions: {e}")
return web.Response(status=500, text=str(e))
async def get_checkpoints_roots(self, request: web.Request) -> web.Response:
"""Return the list of checkpoint roots from config"""
try:

View File

@@ -35,12 +35,17 @@ class EmbeddingRoutes(BaseModelRoutes):
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup Embedding-specific routes"""
# Embedding-specific CivitAI integration
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_embedding)
# Embedding info by name
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_embedding_info)
def _validate_civitai_model_type(self, model_type: str) -> bool:
"""Validate CivitAI model type for Embedding"""
return model_type.lower() in ['textualinversion', 'embedding']
def _get_expected_model_types(self) -> str:
"""Get expected model types string for error messages"""
return "TextualInversion/Embedding"
async def get_embedding_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific embedding by name"""
try:
@@ -55,51 +60,3 @@ class EmbeddingRoutes(BaseModelRoutes):
except Exception as e:
logger.error(f"Error in get_embedding_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_civitai_versions_embedding(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai embedding model with local availability info"""
try:
model_id = request.match_info['model_id']
metadata_provider = await get_default_metadata_provider()
response = await metadata_provider.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be TextualInversion (Embedding)
if model_type.lower() not in ['textualinversion', 'embedding']:
return web.json_response({
'error': f"Model type mismatch. Expected TextualInversion/Embedding, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching embedding model versions: {e}")
return web.Response(status=500, text=str(e))

View File

@@ -44,11 +44,6 @@ class LoraRoutes(BaseModelRoutes):
app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
app.router.add_get(f'/api/{prefix}/usage-tips-by-path', self.get_lora_usage_tips_by_path)
# CivitAI integration with LoRA-specific validation
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora)
app.router.add_get(f'/api/{prefix}/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
app.router.add_get(f'/api/{prefix}/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
# ComfyUI integration
app.router.add_post(f'/api/{prefix}/get_trigger_words', self.get_trigger_words)
@@ -76,6 +71,15 @@ class LoraRoutes(BaseModelRoutes):
return params
def _validate_civitai_model_type(self, model_type: str) -> bool:
"""Validate CivitAI model type for LoRA"""
from ..utils.constants import VALID_LORA_TYPES
return model_type.lower() in VALID_LORA_TYPES
def _get_expected_model_types(self) -> str:
"""Get expected model types string for error messages"""
return "LORA, LoCon, or DORA"
# LoRA-specific route handlers
async def get_letter_counts(self, request: web.Request) -> web.Response:
"""Get count of LoRAs for each letter of the alphabet"""
@@ -210,94 +214,6 @@ class LoraRoutes(BaseModelRoutes):
'error': str(e)
}, status=500)
# CivitAI integration methods
async def get_civitai_versions_lora(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai LoRA model with local availability info"""
try:
model_id = request.match_info['model_id']
metadata_provider = await get_default_metadata_provider()
response = await metadata_provider.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be LORA, LoCon, or DORA
from ..utils.constants import VALID_LORA_TYPES
if model_type.lower() not in VALID_LORA_TYPES:
return web.json_response({
'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the model file (type="Model") in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching LoRA model versions: {e}")
return web.Response(status=500, text=str(e))
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID"""
try:
model_version_id = request.match_info.get('modelVersionId')
# Get model details from metadata provider
metadata_provider = await get_default_metadata_provider()
model, error_msg = await metadata_provider.get_model_version_info(model_version_id)
if not model:
# Log warning for failed model retrieval
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
# Determine status code based on error message
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
return web.json_response({
"success": False,
"error": error_msg or "Failed to fetch model information"
}, status=status_code)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by hash"""
try:
hash = request.match_info.get('hash')
metadata_provider = await get_default_metadata_provider()
model = await metadata_provider.get_model_by_hash(hash)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details by hash: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for specified LoRA models"""
try: