refactor: Implement common endpoint handlers for model management in ModelRouteUtils and update routes in CheckpointsRoutes

This commit is contained in:
Will Miao
2025-04-11 12:06:05 +08:00
parent 56670066c7
commit e991dc061d
5 changed files with 210 additions and 80 deletions

View File

@@ -2,6 +2,7 @@ import os
import json
import logging
from typing import Dict, List, Callable, Awaitable
from aiohttp import web
from .model_utils import determine_base_model
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
@@ -249,4 +250,175 @@ class ModelRouteUtils:
parts = filename.split(".")
if len(parts) > 2: # If contains multi-part extension
return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json"
return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors"
return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors"
# New common endpoint handlers
@staticmethod
async def handle_delete_model(request: web.Request, scanner) -> web.Response:
"""Handle model deletion request
Args:
request: The aiohttp request
scanner: The model scanner instance with cache management methods
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
file_path = data.get('file_path')
if not file_path:
return web.Response(text='Model path is required', status=400)
target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0]
# Get the file monitor from the scanner if available
file_monitor = getattr(scanner, 'file_monitor', None)
deleted_files = await ModelRouteUtils.delete_model_files(
target_dir,
file_name,
file_monitor
)
# Remove from cache
cache = await scanner.get_cached_data()
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path]
await cache.resort()
# Update hash index if available
if hasattr(scanner, '_hash_index') and scanner._hash_index:
scanner._hash_index.remove_by_path(file_path)
return web.json_response({
'success': True,
'deleted_files': deleted_files
})
except Exception as e:
logger.error(f"Error deleting model: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
@staticmethod
async def handle_fetch_civitai(request: web.Request, scanner) -> web.Response:
"""Handle CivitAI metadata fetch request
Args:
request: The aiohttp request
scanner: The model scanner instance with cache management methods
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
# Check if model metadata exists
local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
if not local_metadata or not local_metadata.get('sha256'):
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
# Create a client for fetching from Civitai
client = CivitaiClient()
try:
# Fetch and update metadata
civitai_metadata = await client.get_model_by_hash(local_metadata["sha256"])
if not civitai_metadata:
await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata)
return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404)
await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, client)
# Update the cache
await scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata)
return web.json_response({"success": True})
finally:
await client.close()
except Exception as e:
logger.error(f"Error fetching from CivitAI: {e}", exc_info=True)
return web.json_response({"success": False, "error": str(e)}, status=500)
@staticmethod
async def handle_replace_preview(request: web.Request, scanner) -> web.Response:
"""Handle preview image replacement request
Args:
request: The aiohttp request
scanner: The model scanner instance with methods to update cache
Returns:
web.Response: The HTTP response
"""
try:
reader = await request.multipart()
# Read preview file data
field = await reader.next()
if field.name != 'preview_file':
raise ValueError("Expected 'preview_file' field")
content_type = field.headers.get('Content-Type', 'image/png')
preview_data = await field.read()
# Read model path
field = await reader.next()
if field.name != 'model_path':
raise ValueError("Expected 'model_path' field")
model_path = (await field.read()).decode()
# Save preview file
base_name = os.path.splitext(os.path.basename(model_path))[0]
folder = os.path.dirname(model_path)
# Determine if content is video or image
if content_type.startswith('video/'):
# For videos, keep original format and use .mp4 extension
extension = '.mp4'
optimized_data = preview_data
else:
# For images, optimize and convert to WebP
optimized_data, _ = ExifUtils.optimize_image(
image_data=preview_data,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
)
extension = '.webp' # Use .webp without .preview part
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/')
with open(preview_path, 'wb') as f:
f.write(optimized_data)
# Update preview path in metadata
metadata_path = os.path.splitext(model_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
# Update preview_url directly in the metadata dict
metadata['preview_url'] = preview_path
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Error updating metadata: {e}")
# Update preview URL in scanner cache
if hasattr(scanner, 'update_preview_in_cache'):
await scanner.update_preview_in_cache(model_path, preview_path)
return web.json_response({
"success": True,
"preview_url": config.get_preview_static_url(preview_path)
})
except Exception as e:
logger.error(f"Error replacing preview: {e}", exc_info=True)
return web.Response(text=str(e), status=500)