mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
825 lines
35 KiB
Python
825 lines
35 KiB
Python
import os
|
|
import json
|
|
import logging
|
|
from aiohttp import web
|
|
from typing import Dict, List
|
|
|
|
from ..utils.model_utils import determine_base_model
|
|
|
|
from ..services.file_monitor import LoraFileMonitor
|
|
from ..services.download_manager import DownloadManager
|
|
from ..services.civitai_client import CivitaiClient
|
|
from ..config import config
|
|
from ..services.lora_scanner import LoraScanner
|
|
from operator import itemgetter
|
|
from ..services.websocket_manager import ws_manager
|
|
from ..services.settings_manager import settings
|
|
import asyncio
|
|
from .update_routes import UpdateRoutes
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class ApiRoutes:
|
|
"""API route handlers for LoRA management"""
|
|
|
|
def __init__(self, file_monitor: LoraFileMonitor):
    """Create the service objects used by the route handlers.

    Args:
        file_monitor: File monitor instance; it is handed to the
            ``DownloadManager`` created here.
    """
    # Collaborators used by the handlers below.
    self.scanner = LoraScanner()
    self.civitai_client = CivitaiClient()
    self.download_manager = DownloadManager(file_monitor)

    # ``download_lora`` holds this lock for the duration of a download,
    # so download requests are handled one at a time.
    self._download_lock = asyncio.Lock()
|
|
|
|
@classmethod
def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
    """Register API routes.

    Creates one ``ApiRoutes`` instance, binds its handlers to the
    application's router and then registers the update-check routes.

    Args:
        app: The aiohttp application to register the routes on.
        monitor: File monitor passed to the ``ApiRoutes`` constructor.
    """
    routes = cls(monitor)

    # ``cleanup`` looks up ``cls._instance`` to close the CivitAI client;
    # remember the instance here so that lookup can actually succeed.
    cls._instance = routes

    add_get = app.router.add_get
    add_post = app.router.add_post

    # (registration function, path, handler) in registration order.
    route_table = (
        (add_post, '/api/delete_model', routes.delete_model),
        (add_post, '/api/fetch-civitai', routes.fetch_civitai),
        (add_post, '/api/replace_preview', routes.replace_preview),
        (add_get, '/api/loras', routes.get_loras),
        (add_post, '/api/fetch-all-civitai', routes.fetch_all_civitai),
        (add_get, '/ws/fetch-progress', ws_manager.handle_connection),
        (add_get, '/api/lora-roots', routes.get_lora_roots),
        (add_get, '/api/civitai/versions/{model_id}', routes.get_civitai_versions),
        (add_post, '/api/download-lora', routes.download_lora),
        (add_post, '/api/settings', routes.update_settings),
        (add_post, '/api/move_model', routes.move_model),
        (add_get, '/api/lora-model-description', routes.get_lora_model_description),
        (add_post, '/loras/api/save-metadata', routes.save_metadata),
        (add_get, '/api/lora-preview-url', routes.get_lora_preview_url),
        (add_post, '/api/move_models_bulk', routes.move_models_bulk),
        (add_get, '/api/top-tags', routes.get_top_tags),
    )
    for register, path, handler in route_table:
        register(path, handler)

    # Add update check routes
    UpdateRoutes.setup_routes(app)
|
|
|
|
async def delete_model(self, request: web.Request) -> web.Response:
    """Handle a model deletion request.

    Expects a JSON body with a ``file_path`` key. The directory and the
    extension-less file name are derived from it and passed to
    ``_delete_model_files``.

    Returns:
        A JSON response with ``success`` and ``deleted_files`` on success,
        a 400 text response when ``file_path`` is missing or empty, or a
        500 text response carrying the exception text on any error.
    """
    try:
        payload = await request.json()
        model_path = payload.get('file_path')
        if model_path:
            directory, base_name = os.path.split(model_path)
            stem = os.path.splitext(base_name)[0]

            removed = await self._delete_model_files(directory, stem)
            body = {'success': True, 'deleted_files': removed}
            return web.json_response(body)

        return web.Response(text='Model path is required', status=400)
    except Exception as e:
        logger.error(f"Error deleting model: {e}", exc_info=True)
        return web.Response(text=str(e), status=500)
|
|
|
|
async def fetch_civitai(self, request: web.Request) -> web.Response:
    """Handle a CivitAI metadata fetch request for a single model.

    Expects a JSON body with a ``file_path`` key. The sibling
    ``.metadata.json`` file is loaded, its ``sha256`` value is looked up
    on CivitAI, and the local metadata is updated with the result.

    Returns:
        ``{"success": True}`` on success; a 400 JSON error when
        ``file_path`` is missing or the local metadata has no ``sha256``;
        the response of ``_handle_not_found_on_civitai`` when CivitAI
        returns nothing for the hash; a 500 JSON error on any exception.
    """
    try:
        data = await request.json()

        file_path = data.get('file_path')
        if not file_path:
            return web.json_response(
                {"success": False, "error": "file_path is required"},
                status=400
            )
        metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'

        # ``_load_local_metadata`` returns {} when the file is missing or
        # unreadable, so the hash must be checked before it is used.
        local_metadata = await self._load_local_metadata(metadata_path)
        sha256 = local_metadata.get("sha256")
        if not sha256:
            return web.json_response(
                {"success": False, "error": "Local metadata is missing or has no sha256 hash"},
                status=400
            )

        # Fetch and update metadata
        civitai_metadata = await self.civitai_client.get_model_by_hash(sha256)
        if not civitai_metadata:
            return await self._handle_not_found_on_civitai(metadata_path, local_metadata)

        await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client)

        return web.json_response({"success": True})

    except Exception as e:
        logger.error(f"Error fetching from CivitAI: {e}", exc_info=True)
        return web.json_response({"success": False, "error": str(e)}, status=500)
|
|
|
|
async def replace_preview(self, request: web.Request) -> web.Response:
    """Handle a preview replacement request sent as multipart form data.

    The parts are consumed in order: first the preview file, then the
    model path. The preview is written next to the model, the metadata
    file is updated, and the scanner cache is told about the new preview.

    Returns:
        A JSON response with ``success`` and the static ``preview_url``,
        or a 500 text response carrying the exception text on any error.
    """
    try:
        multipart = await request.multipart()

        # Order matters: the helpers read consecutive parts of the body.
        image_bytes, mime_type = await self._read_preview_file(multipart)
        lora_path = await self._read_model_path(multipart)

        saved_path = await self._save_preview_file(lora_path, image_bytes, mime_type)
        await self._update_preview_metadata(lora_path, saved_path)

        # Keep the scanner cache in sync with the file just written.
        await self.scanner.update_preview_in_cache(lora_path, saved_path)

        body = {
            "success": True,
            "preview_url": config.get_preview_static_url(saved_path)
        }
        return web.json_response(body)
    except Exception as e:
        logger.error(f"Error replacing preview: {e}", exc_info=True)
        return web.Response(text=str(e), status=500)
|
|
|
|
async def get_loras(self, request: web.Request) -> web.Response:
    """Handle a paginated LoRA listing request.

    Query parameters read here: ``page``, ``page_size``, ``sort_by``
    (``date`` or ``name``), ``folder``, ``search``, ``fuzzy``,
    ``recursive``, ``base_models`` and ``tags`` (comma-separated), and the
    search options ``search_filename``, ``search_modelname``,
    ``search_tags``.

    Returns:
        A JSON page of formatted items plus paging information and the
        cache's folder list; a 400 JSON error for invalid pagination
        (including non-integer values) or an invalid sort key; a 500 JSON
        error on any other exception.
    """
    try:
        query = request.query

        def flag(name: str, default: str) -> bool:
            """Read a boolean query parameter written as 'true'/'false'."""
            return query.get(name, default).lower() == 'true'

        def csv_values(name: str) -> List[str]:
            """Read a comma-separated query parameter, dropping empty entries."""
            return [part.strip() for part in query.get(name, '').split(',') if part.strip()]

        # Non-integer paging values are a client error, not a server error.
        try:
            page = int(query.get('page', '1'))
            page_size = int(query.get('page_size', '20'))
        except ValueError:
            return web.json_response({
                'error': 'Invalid pagination parameters'
            }, status=400)

        sort_by = query.get('sort_by', 'name')
        folder = query.get('folder')
        search = query.get('search', '').lower()
        fuzzy = flag('fuzzy', 'false')
        recursive = flag('recursive', 'false')

        # Base models filter
        base_models = csv_values('base_models')

        # Search options
        search_options = {
            'filename': flag('search_filename', 'true'),
            'modelname': flag('search_modelname', 'true'),
            'tags': flag('search_tags', 'false')
        }

        # Validate parameters
        if page < 1 or page_size < 1 or page_size > 100:
            return web.json_response({
                'error': 'Invalid pagination parameters'
            }, status=400)

        if sort_by not in ['date', 'name']:
            return web.json_response({
                'error': 'Invalid sort parameter'
            }, status=400)

        # Tags filter
        tags = csv_values('tags')

        # Get paginated data with search and filters
        result = await self.scanner.get_paginated_data(
            page=page,
            page_size=page_size,
            sort_by=sort_by,
            folder=folder,
            search=search,
            fuzzy=fuzzy,
            recursive=recursive,
            base_models=base_models,
            tags=tags,
            search_options=search_options
        )

        # Format the response data
        formatted_items = [
            self._format_lora_response(item)
            for item in result['items']
        ]

        # The folder list comes from the scanner cache
        cache = await self.scanner.get_cached_data()

        return web.json_response({
            'items': formatted_items,
            'total': result['total'],
            'page': result['page'],
            'page_size': result['page_size'],
            'total_pages': result['total_pages'],
            'folders': cache.folders
        })

    except Exception as e:
        logger.error(f"Error in get_loras: {str(e)}", exc_info=True)
        return web.json_response({
            'error': 'Internal server error'
        }, status=500)
|
|
|
|
def _format_lora_response(self, lora: Dict) -> Dict:
    """Build the API representation of one LoRA entry.

    Keys read with ``[...]`` must be present in ``lora`` (a missing one
    raises ``KeyError``); keys read with ``.get`` fall back to the default
    shown. The preview path is converted with
    ``config.get_preview_static_url`` and path separators in ``file_path``
    are normalized to ``/``.
    """
    formatted: Dict = {}
    formatted["model_name"] = lora["model_name"]
    formatted["file_name"] = lora["file_name"]
    formatted["preview_url"] = config.get_preview_static_url(lora["preview_url"])
    formatted["preview_nsfw_level"] = lora.get("preview_nsfw_level", 0)
    formatted["base_model"] = lora["base_model"]
    formatted["folder"] = lora["folder"]
    formatted["sha256"] = lora["sha256"]
    formatted["file_path"] = lora["file_path"].replace(os.sep, "/")
    formatted["file_size"] = lora["size"]
    formatted["modified"] = lora["modified"]
    formatted["tags"] = lora["tags"]
    formatted["modelDescription"] = lora["modelDescription"]
    formatted["from_civitai"] = lora.get("from_civitai", True)
    formatted["usage_tips"] = lora.get("usage_tips", "")
    formatted["notes"] = lora.get("notes", "")
    formatted["civitai"] = self._filter_civitai_data(lora.get("civitai", {}))
    return formatted
|
|
|
|
def _filter_civitai_data(self, data: Dict) -> Dict:
|
|
"""Filter relevant fields from CivitAI data"""
|
|
if not data:
|
|
return {}
|
|
|
|
fields = [
|
|
"id", "modelId", "name", "createdAt", "updatedAt",
|
|
"publishedAt", "trainedWords", "baseModel", "description",
|
|
"model", "images"
|
|
]
|
|
return {k: data[k] for k in fields if k in data}
|
|
|
|
# Private helper methods
|
|
async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]:
    """Delete a model file and the files that accompany it.

    Args:
        target_dir: Directory that contains the model.
        file_name: Model file name without extension.

    Returns:
        The deleted entries: the full (``/``-separated) path for the
        ``.safetensors`` file and the bare file name for each companion
        file that was removed.
    """
    model_file = f"{file_name}.safetensors"  # Required
    companion_suffixes = (
        ".metadata.json",
        ".preview.png",
        ".preview.jpg",
        ".preview.jpeg",
        ".preview.webp",
        ".preview.mp4",
        ".png",
        ".jpg",
        ".jpeg",
        ".webp",
        ".mp4",
    )
    companions = [f"{file_name}{suffix}" for suffix in companion_suffixes]

    removed = []
    model_path = os.path.join(target_dir, model_file).replace(os.sep, '/')

    if not os.path.exists(model_path):
        logger.warning(f"Model file not found: {model_file}")
    else:
        # Tell the file monitor to ignore the delete event for this path
        self.download_manager.file_monitor.handler.add_ignore_path(model_path, 0)

        os.remove(model_path)
        removed.append(model_path)

    # Drop the entry from the scanner cache and re-sort it
    cache = await self.scanner.get_cached_data()
    cache.raw_data = [entry for entry in cache.raw_data if entry['file_path'] != model_path]
    await cache.resort()

    # Companion files are optional; a failed removal is only logged
    for companion in companions:
        candidate = os.path.join(target_dir, companion)
        if not os.path.exists(candidate):
            continue
        try:
            os.remove(candidate)
            removed.append(companion)
        except Exception as e:
            logger.warning(f"Failed to delete {companion}: {e}")

    return removed
|
|
|
|
async def _read_preview_file(self, reader) -> tuple[bytes, str]:
|
|
"""Read preview file and content type from multipart request"""
|
|
field = await reader.next()
|
|
if field.name != 'preview_file':
|
|
raise ValueError("Expected 'preview_file' field")
|
|
content_type = field.headers.get('Content-Type', 'image/png')
|
|
return await field.read(), content_type
|
|
|
|
async def _read_model_path(self, reader) -> str:
|
|
"""Read model path from multipart request"""
|
|
field = await reader.next()
|
|
if field.name != 'model_path':
|
|
raise ValueError("Expected 'model_path' field")
|
|
return (await field.read()).decode()
|
|
|
|
async def _save_preview_file(self, model_path: str, preview_data: bytes, content_type: str) -> str:
|
|
"""Save preview file and return its path"""
|
|
# Determine file extension based on content type
|
|
if content_type.startswith('video/'):
|
|
extension = '.preview.mp4'
|
|
else:
|
|
extension = '.preview.png'
|
|
|
|
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
|
folder = os.path.dirname(model_path)
|
|
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/')
|
|
|
|
with open(preview_path, 'wb') as f:
|
|
f.write(preview_data)
|
|
|
|
return preview_path
|
|
|
|
async def _update_preview_metadata(self, model_path: str, preview_path: str):
|
|
"""Update preview path in metadata"""
|
|
metadata_path = os.path.splitext(model_path)[0] + '.metadata.json'
|
|
if os.path.exists(metadata_path):
|
|
try:
|
|
with open(metadata_path, 'r', encoding='utf-8') as f:
|
|
metadata = json.load(f)
|
|
|
|
# Update preview_url directly in the metadata dict
|
|
metadata['preview_url'] = preview_path
|
|
|
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
|
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
|
except Exception as e:
|
|
logger.error(f"Error updating metadata: {e}")
|
|
|
|
async def _load_local_metadata(self, metadata_path: str) -> Dict:
|
|
"""Load local metadata file"""
|
|
if os.path.exists(metadata_path):
|
|
try:
|
|
with open(metadata_path, 'r', encoding='utf-8') as f:
|
|
return json.load(f)
|
|
except Exception as e:
|
|
logger.error(f"Error loading metadata from {metadata_path}: {e}")
|
|
return {}
|
|
|
|
async def _handle_not_found_on_civitai(self, metadata_path: str, local_metadata: Dict) -> web.Response:
    """Record that a model is not on CivitAI and build the 404 response.

    Sets ``from_civitai`` to ``False`` in ``local_metadata``, writes the
    metadata back to ``metadata_path`` and returns a JSON error response
    with status 404.
    """
    local_metadata['from_civitai'] = False

    with open(metadata_path, 'w', encoding='utf-8') as dst:
        json.dump(local_metadata, dst, indent=2, ensure_ascii=False)

    body = {"success": False, "error": "Not found on CivitAI"}
    return web.json_response(body, status=404)
|
|
|
|
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
                                 civitai_metadata: Dict, client: CivitaiClient) -> None:
    """Update local metadata with CivitAI data and persist it.

    Args:
        metadata_path: Path of the ``.metadata.json`` file to write.
        local_metadata: Metadata dict to update in place; its
            ``file_path`` entry is used as the cache key at the end.
        civitai_metadata: Version data returned by CivitAI for the model.
        client: CivitAI client used for the extra model lookup and the
            preview download.

    Raises:
        KeyError: If ``local_metadata`` has no ``file_path`` entry (raised
            after the metadata file has been written).
    """
    local_metadata['civitai'] = civitai_metadata

    # Update model name if available
    model_name = (civitai_metadata.get('model') or {}).get('name')
    if model_name:
        local_metadata['model_name'] = model_name

    # Fetch description and tags only when a model ID is present; a
    # missing key is treated the same as an empty value.
    model_id = civitai_metadata.get('modelId')
    if model_id:
        model_metadata, _ = await client.get_model_metadata(str(model_id))
        if model_metadata:
            local_metadata['modelDescription'] = model_metadata.get('description', '')
            local_metadata['tags'] = model_metadata.get('tags', [])

    # Update base model
    local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))

    # Download a preview only when none is recorded or the recorded file is gone
    current_preview = local_metadata.get('preview_url')
    if not current_preview or not os.path.exists(current_preview):
        images = civitai_metadata.get('images') or []
        first_preview = images[0] if images else None
        if first_preview and first_preview.get('url'):
            preview_url = first_preview['url']
            if first_preview.get('type') == 'video':
                preview_ext = '.mp4'
            else:
                preview_ext = os.path.splitext(preview_url)[-1]

            # Strip both extensions of "<name>.metadata.json" to get "<name>"
            base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
            preview_path = os.path.join(os.path.dirname(metadata_path), base_name + preview_ext)

            if await client.download_preview_image(preview_url, preview_path):
                local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
                local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)

    # Save updated metadata
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(local_metadata, f, indent=2, ensure_ascii=False)

    await self.scanner.update_single_lora_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)
|
|
|
|
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
    """Fetch CivitAI metadata for every cached lora that still lacks it.

    A lora is processed when it has a ``sha256``, is flagged
    ``from_civitai`` and has no ``civitai`` data with an ``id``. Progress
    is broadcast through ``ws_manager`` (``started``, ``processing`` per
    lora, ``completed`` or ``error``).

    Returns:
        A JSON summary on completion, or a 500 text response carrying the
        exception text when the run as a whole fails.
    """
    try:
        cache = await self.scanner.get_cached_data()
        library_size = len(cache.raw_data)
        done = 0
        updated = 0
        resort_required = False

        def is_pending(entry) -> bool:
            """Tell whether an entry still needs CivitAI data."""
            if not entry.get('sha256'):
                return False
            civitai = entry.get('civitai')
            if civitai and 'id' in civitai:
                return False
            # TODO: handle loras not from CivitAI that had trainedWords added
            return bool(entry.get('from_civitai'))

        # Collect the loras to process
        pending = [entry for entry in cache.raw_data if is_pending(entry)]
        pending_count = len(pending)

        # Send the initial progress message
        await ws_manager.broadcast({
            'status': 'started',
            'total': pending_count,
            'processed': 0,
            'success': 0
        })

        for entry in pending:
            try:
                name_before = entry.get('model_name')
                was_updated = await self._fetch_and_update_single_lora(
                    sha256=entry['sha256'],
                    file_path=entry['file_path'],
                    lora=entry
                )
                if was_updated:
                    updated += 1
                    # A changed model name invalidates the name ordering
                    if name_before != entry.get('model_name'):
                        resort_required = True

                done += 1

                # Send a progress update after each lora
                await ws_manager.broadcast({
                    'status': 'processing',
                    'total': pending_count,
                    'processed': done,
                    'success': updated,
                    'current_name': entry.get('model_name', 'Unknown')
                })

            except Exception as e:
                logger.error(f"Error fetching CivitAI data for {entry['file_path']}: {e}")

        if resort_required:
            await cache.resort(name_only=True)

        # Send the completion message
        await ws_manager.broadcast({
            'status': 'completed',
            'total': pending_count,
            'processed': done,
            'success': updated
        })

        return web.json_response({
            "success": True,
            "message": f"Successfully updated {updated} of {done} processed loras (total: {library_size})"
        })

    except Exception as e:
        # Send the error message
        await ws_manager.broadcast({
            'status': 'error',
            'error': str(e)
        })
        logger.error(f"Error in fetch_all_civitai: {e}")
        return web.Response(text=str(e), status=500)
|
|
|
|
async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool:
    """Fetch and store CivitAI metadata for one lora, without re-sorting.

    A dedicated ``CivitaiClient`` is created for the call and closed
    before returning.

    Args:
        sha256: SHA256 hash of the lora file
        file_path: Path to the lora file
        lora: The cache entry to update in place

    Returns:
        bool: True if the metadata was updated; False if the hash was not
        found on CivitAI or an error occurred (errors are logged)
    """
    civitai = CivitaiClient()
    try:
        metadata_file = os.path.splitext(file_path)[0] + '.metadata.json'
        stored = await self._load_local_metadata(metadata_file)

        remote = await civitai.get_model_by_hash(sha256)
        if remote:
            await self._update_model_metadata(metadata_file, stored, remote, civitai)

            # Mirror the refreshed values into the cache entry
            lora.update({
                'model_name': stored.get('model_name'),
                'preview_url': stored.get('preview_url'),
                'from_civitai': True,
                'civitai': remote
            })
            return True

        # Not found: flag both the metadata file and the cache entry
        stored['from_civitai'] = False
        lora['from_civitai'] = False
        with open(metadata_file, 'w', encoding='utf-8') as dst:
            json.dump(stored, dst, indent=2, ensure_ascii=False)
        return False

    except Exception as e:
        logger.error(f"Error fetching CivitAI data: {e}")
        return False
    finally:
        await civitai.close()
|
|
|
|
async def get_lora_roots(self, request: web.Request) -> web.Response:
    """Return ``config.loras_roots`` as JSON under the ``roots`` key."""
    body = {'roots': config.loras_roots}
    return web.json_response(body)
|
|
|
|
async def get_civitai_versions(self, request: web.Request) -> web.Response:
    """Return the versions of a CivitAI model, annotated with local availability.

    The model ID comes from the ``model_id`` URL segment. Every file entry
    that carries a ``SHA256`` hash gets an ``existsLocally`` flag and, when
    the scanner knows the hash, a ``localPath``.

    Returns:
        The annotated version list as JSON, a 404 text response when no
        versions are returned, or a 500 text response on any error.
    """
    try:
        model_id = request.match_info['model_id']
        versions = await self.civitai_client.get_model_versions(model_id)
        if not versions:
            return web.Response(status=404, text="Model not found")

        # Annotate each file that has a hash with its local availability
        for version in versions:
            for entry in version.get('files', []):
                digest = entry.get('hashes', {}).get('SHA256')
                if not digest:
                    continue
                entry['existsLocally'] = self.scanner.has_lora_hash(digest)
                if entry['existsLocally']:
                    entry['localPath'] = self.scanner.get_lora_path_by_hash(digest)

        return web.json_response(versions)
    except Exception as e:
        logger.error(f"Error fetching model versions: {e}")
        return web.Response(status=500, text=str(e))
|
|
|
|
async def download_lora(self, request: web.Request) -> web.Response:
    """Download a LoRA from CivitAI.

    Reads ``download_url``, ``lora_root`` and ``relative_path`` from the
    JSON body and hands them to the download manager. Progress values are
    broadcast through ``ws_manager``. The whole request runs while holding
    ``self._download_lock``.

    Returns:
        The download manager's result as JSON when it reports success, or
        a 500 text response with the reported error / exception text.
    """
    async with self._download_lock:
        try:
            payload = await request.json()

            async def report_progress(progress):
                """Forward a progress value to the websocket clients."""
                message = {'status': 'progress', 'progress': progress}
                await ws_manager.broadcast(message)

            outcome = await self.download_manager.download_from_civitai(
                download_url=payload.get('download_url'),
                save_dir=payload.get('lora_root'),
                relative_path=payload.get('relative_path'),
                progress_callback=report_progress
            )

            if outcome.get('success', False):
                return web.json_response(outcome)

            return web.Response(status=500, text=outcome.get('error', 'Unknown error'))
        except Exception as e:
            logger.error(f"Error downloading LoRA: {e}")
            return web.Response(status=500, text=str(e))
|
|
|
|
async def update_settings(self, request: web.Request) -> web.Response:
    """Update application settings from a JSON body.

    Only the ``civitai_api_key`` and ``show_only_sfw`` keys are applied;
    any other key in the body is ignored.

    Returns:
        ``{'success': True}`` as JSON, or a 500 text response carrying the
        exception text on any error.
    """
    try:
        payload = await request.json()

        # Apply only the settings this endpoint accepts
        for key in ('civitai_api_key', 'show_only_sfw'):
            if key in payload:
                settings.set(key, payload[key])

        return web.json_response({'success': True})
    except Exception as e:
        # exc_info=True adds the full stack trace to the log record
        logger.error(f"Error updating settings: {e}", exc_info=True)
        return web.Response(status=500, text=str(e))
|
|
|
|
async def move_model(self, request: web.Request) -> web.Response:
    """Handle a request to move one model.

    Expects ``file_path`` and ``target_path`` in the JSON body and
    delegates the move to the scanner.

    Returns:
        ``{'success': True}`` as JSON when the scanner reports success; a
        400 text response when either value is missing; a 500 text
        response when the move fails or an exception is raised.
    """
    try:
        payload = await request.json()
        source = payload.get('file_path')
        destination = payload.get('target_path')

        if not (source and destination):
            return web.Response(text='File path and target path are required', status=400)

        # The scanner performs the actual move
        moved = await self.scanner.move_model(source, destination)
        if not moved:
            return web.Response(text='Failed to move model', status=500)

        return web.json_response({'success': True})
    except Exception as e:
        logger.error(f"Error moving model: {e}", exc_info=True)
        return web.Response(text=str(e), status=500)
|
|
|
|
@classmethod
async def cleanup(cls):
    """Close the CivitAI client on application shutdown.

    Acts only when the class has an ``_instance`` attribute; otherwise it
    returns without doing anything.
    """
    if not hasattr(cls, '_instance'):
        return
    await cls._instance.civitai_client.close()
|
|
|
|
async def save_metadata(self, request: web.Request) -> web.Response:
    """Merge metadata updates from the request into a model's metadata file.

    The JSON body must contain ``file_path``; every other key is treated
    as an update. A dict value is merged key by key into an existing dict
    entry (one level deep); any other value replaces the entry. The result
    is written to the ``.metadata.json`` file and pushed to the scanner
    cache, which is re-sorted by name when ``model_name`` was part of the
    update.

    Returns:
        ``{'success': True}`` as JSON, a 400 text response when
        ``file_path`` is missing, or a 500 text response on any error.
    """
    try:
        payload = await request.json()
        file_path = payload.get('file_path')
        if not file_path:
            return web.Response(text='File path is required', status=400)

        # Everything except the path itself is an update
        updates = {key: value for key, value in payload.items() if key != 'file_path'}

        metadata_file = os.path.splitext(file_path)[0] + '.metadata.json'

        # Start from the stored metadata when there is any
        stored = {}
        if os.path.exists(metadata_file):
            with open(metadata_file, 'r', encoding='utf-8') as src:
                stored = json.load(src)

        for key, value in updates.items():
            mergeable = (
                isinstance(value, dict)
                and key in stored
                and isinstance(stored[key], dict)
            )
            if mergeable:
                # Nested merge (e.g. civitai.trainedWords)
                stored[key].update(value)
            else:
                stored[key] = value

        with open(metadata_file, 'w', encoding='utf-8') as dst:
            json.dump(stored, dst, indent=2, ensure_ascii=False)

        await self.scanner.update_single_lora_cache(file_path, file_path, stored)

        # A new model name changes the name ordering of the cache
        if 'model_name' in updates:
            cache = await self.scanner.get_cached_data()
            await cache.resort(name_only=True)

        return web.json_response({'success': True})
    except Exception as e:
        logger.error(f"Error saving metadata: {e}", exc_info=True)
        return web.Response(text=str(e), status=500)
|
|
|
|
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
    """Return the static preview URL of the lora named in the ``name`` query parameter.

    Only the first cache entry whose ``file_name`` equals the requested
    name is considered.

    Returns:
        A JSON response with ``success`` and ``preview_url``; a 400 text
        response when ``name`` is missing; a 404 JSON error when no entry
        matches or the match has no usable preview; a 500 text response on
        any exception.
    """
    try:
        lora_name = request.query.get('name')
        if not lora_name:
            return web.Response(text='Lora file name is required', status=400)

        cache = await self.scanner.get_cached_data()

        # First entry with the requested file name, if any
        match = next(
            (entry for entry in cache.raw_data if entry['file_name'] == lora_name),
            None
        )
        if match is not None:
            preview_path = match.get('preview_url')
            if preview_path:
                # Convert the preview path to a static URL
                static_url = config.get_preview_static_url(preview_path)
                if static_url:
                    return web.json_response({
                        'success': True,
                        'preview_url': static_url
                    })

        return web.json_response({
            'success': False,
            'error': 'No preview URL found for the specified lora'
        }, status=404)
    except Exception as e:
        logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
        return web.Response(text=str(e), status=500)
|
|
|
|
async def move_models_bulk(self, request: web.Request) -> web.Response:
    """Handle a request to move several models to one target path.

    Expects ``file_paths`` (list) and ``target_path`` in the JSON body.
    The models are moved one after another through the scanner.

    Returns:
        A JSON success message when every move succeeded; a JSON success
        message plus per-file ``results`` when only some succeeded; a 400
        text response when input is missing; a 500 text response when no
        move succeeded or an exception is raised.
    """
    try:
        payload = await request.json()
        file_paths = payload.get('file_paths', [])
        target_path = payload.get('target_path')

        if not (file_paths and target_path):
            return web.Response(text='File paths and target path are required', status=400)

        outcomes = []
        for source in file_paths:
            moved = await self.scanner.move_model(source, target_path)
            outcomes.append({"path": source, "success": moved})

        moved_count = len([item for item in outcomes if item["success"]])

        if moved_count == 0:
            return web.Response(text='Failed to move any models', status=500)

        if moved_count == len(file_paths):
            return web.json_response({
                'success': True,
                'message': f'Successfully moved {moved_count} models'
            })

        # Partial success: include the per-file outcome
        return web.json_response({
            'success': True,
            'message': f'Moved {moved_count} of {len(file_paths)} models',
            'results': outcomes
        })
    except Exception as e:
        logger.error(f"Error moving models in bulk: {e}", exc_info=True)
        return web.Response(text=str(e), status=500)
|
|
|
|
async def get_lora_model_description(self, request: web.Request) -> web.Response:
    """Return the description and tags of a model.

    Query parameters: ``model_id`` (required) and ``file_path``
    (optional). When ``file_path`` is given, the description and tags
    stored in its ``.metadata.json`` file are used. If no description is
    stored, the model metadata is fetched from CivitAI and, when the
    metadata file exists, written back into it.

    Returns:
        A JSON response with ``success``, ``description`` (a placeholder
        paragraph when none is available) and ``tags``; a 400 JSON error
        when ``model_id`` is missing; a 500 JSON error on any exception.
    """
    try:
        model_id = request.query.get('model_id')
        file_path = request.query.get('file_path')

        if not model_id:
            return web.json_response({
                'success': False,
                'error': 'Model ID is required'
            }, status=400)

        description = None
        tags = []
        metadata_path = None
        if file_path:
            metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'

        # Prefer what is already stored next to the model
        if metadata_path and os.path.exists(metadata_path):
            try:
                with open(metadata_path, 'r', encoding='utf-8') as src:
                    stored = json.load(src)
                description = stored.get('modelDescription')
                tags = stored.get('tags', [])
            except Exception as e:
                logger.error(f"Error loading metadata from {metadata_path}: {e}")

        # Nothing stored: ask CivitAI
        if not description:
            logger.info(f"Fetching model metadata for model ID: {model_id}")
            model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)

            if model_metadata:
                description = model_metadata.get('description')
                tags = model_metadata.get('tags', [])

                # Persist the fetched values when a metadata file exists
                if file_path:
                    try:
                        if os.path.exists(metadata_path):
                            with open(metadata_path, 'r', encoding='utf-8') as src:
                                stored = json.load(src)

                            stored['modelDescription'] = description
                            stored['tags'] = tags

                            with open(metadata_path, 'w', encoding='utf-8') as dst:
                                json.dump(stored, dst, indent=2, ensure_ascii=False)
                            logger.info(f"Saved model metadata to file for {file_path}")
                    except Exception as e:
                        logger.error(f"Error saving model metadata: {e}")

        return web.json_response({
            'success': True,
            'description': description or "<p>No model description available.</p>",
            'tags': tags
        })
    except Exception as e:
        logger.error(f"Error getting model metadata: {e}", exc_info=True)
        return web.json_response({
            'success': False,
            'error': str(e)
        }, status=500)
|
|
|
|
async def get_top_tags(self, request: web.Request) -> web.Response:
    """Handle a request for the top tags sorted by frequency.

    The optional ``limit`` query parameter selects how many tags are
    requested from the scanner. A value that is not an integer or lies
    outside 1..100 falls back to 20.

    Returns:
        A JSON response with ``success`` and ``tags``, or a 500 JSON error
        on any exception.
    """
    default_limit = 20
    try:
        # A non-numeric limit is handled like an out-of-range one
        # instead of surfacing as a server error.
        try:
            limit = int(request.query.get('limit', str(default_limit)))
        except ValueError:
            limit = default_limit

        if limit < 1 or limit > 100:
            limit = default_limit

        top_tags = await self.scanner.get_top_tags(limit)

        return web.json_response({
            'success': True,
            'tags': top_tags
        })

    except Exception as e:
        logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
        return web.json_response({
            'success': False,
            'error': 'Internal server error'
        }, status=500)
|