feat(routes): integrate CivitAI model version retrieval for various model types

This commit is contained in:
Will Miao
2025-09-17 15:47:30 +08:00
parent 1cddeee264
commit 933e2fc01d
4 changed files with 126 additions and 201 deletions

View File

@@ -14,6 +14,7 @@ from ..services.settings_manager import settings
from ..services.server_i18n import server_i18n
from ..services.model_file_service import ModelFileService, ModelMoveService
from ..services.websocket_progress_callback import WebSocketProgressCallback
from ..services.metadata_service import get_default_metadata_provider
from ..config import config
logger = logging.getLogger(__name__)
@@ -84,14 +85,17 @@ class BaseModelRoutes(ABC):
# Autocomplete route
app.router.add_get(f'/api/{prefix}/relative-paths', self.get_relative_paths)
# Common CivitAI integration
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions)
app.router.add_get(f'/api/{prefix}/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
app.router.add_get(f'/api/{prefix}/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
# Common Download management
app.router.add_post(f'/api/download-model', self.download_model)
app.router.add_get(f'/api/download-model-get', self.download_model_get)
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
# app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
# Add generic page route
app.router.add_get(f'/{prefix}', self.handle_models_page)
@@ -704,10 +708,101 @@ class BaseModelRoutes(ABC):
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai model with local availability info"""
# This will be implemented by subclasses as they need CivitAI client access
return web.json_response({
"error": "Not implemented in base class"
}, status=501)
try:
model_id = request.match_info['model_id']
metadata_provider = await get_default_metadata_provider()
response = await metadata_provider.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - allow subclasses to override validation
if not self._validate_civitai_model_type(model_type):
return web.json_response({
'error': f"Model type mismatch. Expected {self._get_expected_model_types()}, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the model file (type="Model" and primary=true) in the files list
model_file = self._find_model_file(version.get('files', []))
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching {self.model_type} model versions: {e}")
return web.Response(status=500, text=str(e))
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID"""
try:
model_version_id = request.match_info.get('modelVersionId')
# Get model details from metadata provider
metadata_provider = await get_default_metadata_provider()
model, error_msg = await metadata_provider.get_model_version_info(model_version_id)
if not model:
# Log warning for failed model retrieval
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
# Determine status code based on error message
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
return web.json_response({
"success": False,
"error": error_msg or "Failed to fetch model information"
}, status=status_code)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by hash"""
try:
hash = request.match_info.get('hash')
metadata_provider = await get_default_metadata_provider()
model = await metadata_provider.get_model_by_hash(hash)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details by hash: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
def _validate_civitai_model_type(self, model_type: str) -> bool:
"""Validate CivitAI model type - to be overridden by subclasses"""
return True # Default: accept all types
def _get_expected_model_types(self) -> str:
"""Get expected model types string for error messages - to be overridden by subclasses"""
return "any model type"
def _find_model_file(self, files: list) -> dict:
"""Find the appropriate model file from the files list - can be overridden by subclasses"""
# Find the primary model file (type="Model" and primary=true) in the files list
return next((file for file in files if file.get('type') == 'Model' and file.get('primary') == True), None)
# Common model move handlers
async def move_model(self, request: web.Request) -> web.Response: