refactor: Extract model-related utility functions into ModelRouteUtils for better code organization

This commit is contained in:
Will Miao
2025-04-11 10:54:19 +08:00
parent 297ff0dd25
commit 31d27ff3fa
4 changed files with 318 additions and 351 deletions

View File

@@ -2,9 +2,9 @@ import os
import json
import logging
from aiohttp import web
from typing import Dict, List
from typing import Dict
from ..utils.model_utils import determine_base_model
from ..utils.routes_common import ModelRouteUtils
from ..services.file_monitor import LoraFileMonitor
from ..services.download_manager import DownloadManager
@@ -72,7 +72,19 @@ class ApiRoutes:
target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0]
deleted_files = await self._delete_model_files(target_dir, file_name)
deleted_files = await ModelRouteUtils.delete_model_files(
target_dir,
file_name,
self.download_manager.file_monitor
)
# Remove from cache
cache = await self.scanner.get_cached_data()
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path]
await cache.resort()
# update hash index
self.scanner._hash_index.remove_by_path(file_path)
return web.json_response({
'success': True,
@@ -90,14 +102,18 @@ class ApiRoutes:
metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
# Check if model is from CivitAI
local_metadata = await self._load_local_metadata(metadata_path)
local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Fetch and update metadata
civitai_metadata = await self.civitai_client.get_model_by_hash(local_metadata["sha256"])
if not civitai_metadata:
return await self._handle_not_found_on_civitai(metadata_path, local_metadata)
await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata)
return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404)
await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client)
await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client)
# Update the cache
await self.scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata)
return web.json_response({"success": True})
@@ -139,10 +155,12 @@ class ApiRoutes:
fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true'
# Parse search options
search_filename = request.query.get('search_filename', 'true').lower() == 'true'
search_modelname = request.query.get('search_modelname', 'true').lower() == 'true'
search_tags = request.query.get('search_tags', 'false').lower() == 'true'
recursive = request.query.get('recursive', 'false').lower() == 'true'
search_options = {
'filename': request.query.get('search_filename', 'true').lower() == 'true',
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
'tags': request.query.get('search_tags', 'false').lower() == 'true',
'recursive': request.query.get('recursive', 'false').lower() == 'true'
}
# Get filter parameters
base_models = request.query.get('base_models', None)
@@ -159,14 +177,6 @@ class ApiRoutes:
if tags:
filters['tags'] = tags.split(',')
# Add search options to filters
search_options = {
'filename': search_filename,
'modelname': search_modelname,
'tags': search_tags,
'recursive': recursive
}
# Add lora hash filtering options
hash_filters = {}
if lora_hash:
@@ -225,67 +235,10 @@ class ApiRoutes:
"from_civitai": lora.get("from_civitai", True),
"usage_tips": lora.get("usage_tips", ""),
"notes": lora.get("notes", ""),
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
"civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {}))
}
def _filter_civitai_data(self, data: Dict) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
# Private helper methods
async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]:
"""Delete model and associated files"""
patterns = [
f"{file_name}.safetensors", # Required
f"{file_name}.metadata.json",
]
# 添加所有预览文件扩展名
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{file_name}{ext}")
deleted = []
main_file = patterns[0]
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
if os.path.exists(main_path):
# Notify file monitor to ignore delete event
self.download_manager.file_monitor.handler.add_ignore_path(main_path, 0)
# Delete file
os.remove(main_path)
deleted.append(main_path)
else:
logger.warning(f"Model file not found: {main_file}")
# Remove from cache
cache = await self.scanner.get_cached_data()
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != main_path]
await cache.resort()
# update hash index
self.scanner._hash_index.remove_by_path(main_path)
# Delete optional files
for pattern in patterns[1:]:
path = os.path.join(target_dir, pattern)
if os.path.exists(path):
try:
os.remove(path)
deleted.append(pattern)
except Exception as e:
logger.warning(f"Failed to delete {pattern}: {e}")
return deleted
async def _read_preview_file(self, reader) -> tuple[bytes, str]:
"""Read preview file and content type from multipart request"""
field = await reader.next()
@@ -345,66 +298,6 @@ class ApiRoutes:
except Exception as e:
logger.error(f"Error updating metadata: {e}")
async def _load_local_metadata(self, metadata_path: str) -> Dict:
"""Load local metadata file"""
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
return {}
async def _handle_not_found_on_civitai(self, metadata_path: str, local_metadata: Dict) -> web.Response:
"""Handle case when model is not found on CivitAI"""
local_metadata['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return web.json_response(
{"success": False, "error": "Not found on CivitAI"},
status=404
)
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
civitai_metadata: Dict, client: CivitaiClient) -> None:
"""Update local metadata with CivitAI data"""
local_metadata['civitai'] = civitai_metadata
# Update model name if available
if 'model' in civitai_metadata:
if civitai_metadata.get('model', {}).get('name'):
local_metadata['model_name'] = civitai_metadata['model']['name']
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata, _ = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Update base model
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
# Update preview if needed
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
if first_preview:
preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_filename = base_name + preview_ext
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
if await client.download_preview_image(first_preview['url'], preview_path):
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all loras in the background"""
try:
@@ -414,14 +307,14 @@ class ApiRoutes:
success = 0
needs_resort = False
# 准备要处理的 loras
# Prepare loras to process
to_process = [
lora for lora in cache.raw_data
if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai') # TODO: for lora not from CivitAI but added traineWords
if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai', True) # TODO: for lora not from CivitAI but added traineWords
]
total_to_process = len(to_process)
# 发送初始进度
# Send initial progress
await ws_manager.broadcast({
'status': 'started',
'total': total_to_process,
@@ -432,10 +325,11 @@ class ApiRoutes:
for lora in to_process:
try:
original_name = lora.get('model_name')
if await self._fetch_and_update_single_lora(
if await ModelRouteUtils.fetch_and_update_model(
sha256=lora['sha256'],
file_path=lora['file_path'],
lora=lora
model_data=lora,
update_cache_func=self.scanner.update_single_model_cache
):
success += 1
if original_name != lora.get('model_name'):
@@ -443,7 +337,7 @@ class ApiRoutes:
processed += 1
# 每处理一个就发送进度更新
# Send progress update
await ws_manager.broadcast({
'status': 'processing',
'total': total_to_process,
@@ -458,7 +352,7 @@ class ApiRoutes:
if needs_resort:
await cache.resort(name_only=True)
# 发送完成消息
# Send completion message
await ws_manager.broadcast({
'status': 'completed',
'total': total_to_process,
@@ -472,7 +366,7 @@ class ApiRoutes:
})
except Exception as e:
# 发送错误消息
# Send error message
await ws_manager.broadcast({
'status': 'error',
'error': str(e)
@@ -480,58 +374,6 @@ class ApiRoutes:
logger.error(f"Error in fetch_all_civitai: {e}")
return web.Response(text=str(e), status=500)
async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool:
"""Fetch and update metadata for a single lora without sorting
Args:
sha256: SHA256 hash of the lora file
file_path: Path to the lora file
lora: The lora object in cache to update
Returns:
bool: True if successful, False otherwise
"""
client = CivitaiClient()
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Check if model is from CivitAI
local_metadata = await self._load_local_metadata(metadata_path)
# Fetch metadata
civitai_metadata = await client.get_model_by_hash(sha256)
if not civitai_metadata:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
lora['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return False
# Update metadata
await self._update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
client
)
# Update cache object directly
lora.update({
'model_name': local_metadata.get('model_name'),
'preview_url': local_metadata.get('preview_url'),
'from_civitai': True,
'civitai': civitai_metadata
})
return True
except Exception as e:
logger.error(f"Error fetching CivitAI data: {e}")
return False
finally:
await client.close()
async def get_lora_roots(self, request: web.Request) -> web.Response:
"""Get all configured LoRA root directories"""
return web.json_response({
@@ -669,7 +511,7 @@ class ApiRoutes:
return web.json_response({'success': True})
except Exception as e:
logger.error(f"Error updating settings: {e}", exc_info=True) # 添加 exc_info=True 以获取完整堆栈
logger.error(f"Error updating settings: {e}", exc_info=True)
return web.Response(status=500, text=str(e))
async def move_model(self, request: web.Request) -> web.Response:
@@ -731,11 +573,7 @@ class ApiRoutes:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load existing metadata
if os.path.exists(metadata_path):
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
else:
metadata = {}
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Handle nested updates (for civitai.trainedWords)
for key, value in metadata_updates.items():
@@ -798,7 +636,10 @@ class ApiRoutes:
except Exception as e:
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
"""Get the Civitai URL for a LoRA file"""
@@ -921,14 +762,9 @@ class ApiRoutes:
tags = []
if file_path:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
description = metadata.get('modelDescription')
tags = metadata.get('tags', [])
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
description = metadata.get('modelDescription')
tags = metadata.get('tags', [])
# If description is not in metadata, fetch from CivitAI
if not description:
@@ -943,16 +779,14 @@ class ApiRoutes:
if file_path:
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
metadata['modelDescription'] = description
metadata['tags'] = tags
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
logger.info(f"Saved model metadata to file for {file_path}")
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
metadata['modelDescription'] = description
metadata['tags'] = tags
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
logger.info(f"Saved model metadata to file for {file_path}")
except Exception as e:
logger.error(f"Error saving model metadata: {e}")
@@ -1018,12 +852,6 @@ class ApiRoutes:
'error': str(e)
}, status=500)
def get_multipart_ext(self, filename):
parts = filename.split(".")
if len(parts) > 2: # 如果包含多级扩展名
return "." + ".".join(parts[-2:]) # 取最后两部分,如 ".metadata.json"
return os.path.splitext(filename)[1] # 否则取普通扩展名,如 ".safetensors"
async def rename_lora(self, request: web.Request) -> web.Response:
"""Handle renaming a LoRA file and its associated files"""
try:
@@ -1063,7 +891,7 @@ class ApiRoutes:
f"{old_file_name}.metadata.json",
]
# 添加所有预览文件扩展名
# Add all preview file extensions
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{old_file_name}{ext}")
@@ -1080,12 +908,8 @@ class ApiRoutes:
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
hash_value = metadata.get('sha256')
except Exception as e:
logger.error(f"Error loading metadata for rename: {e}")
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
hash_value = metadata.get('sha256')
# Rename all files
renamed_files = []
@@ -1101,7 +925,7 @@ class ApiRoutes:
for old_path, pattern in existing_files:
# Get the file extension like .safetensors or .metadata.json
ext = self.get_multipart_ext(pattern)
ext = ModelRouteUtils.get_multipart_ext(pattern)
# Create the new path
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
@@ -1123,7 +947,7 @@ class ApiRoutes:
# Update preview_url if it exists
if 'preview_url' in metadata and metadata['preview_url']:
old_preview = metadata['preview_url']
ext = self.get_multipart_ext(old_preview)
ext = ModelRouteUtils.get_multipart_ext(old_preview)
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
metadata['preview_url'] = new_preview

View File

@@ -1,15 +1,10 @@
import os
import json
import asyncio
from typing import Dict
import aiohttp
import jinja2
from aiohttp import web
import logging
from datetime import datetime
from ..utils.model_utils import determine_base_model
from ..utils.routes_common import ModelRouteUtils
from ..utils.constants import NSFW_LEVELS
from ..services.civitai_client import CivitaiClient
from ..services.websocket_manager import ws_manager
@@ -259,21 +254,9 @@ class CheckpointsRoutes:
"from_civitai": checkpoint.get("from_civitai", True),
"notes": checkpoint.get("notes", ""),
"model_type": checkpoint.get("model_type", "checkpoint"),
"civitai": self._filter_civitai_data(checkpoint.get("civitai", {}))
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
}
def _filter_civitai_data(self, data):
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all checkpoints in the background"""
try:
@@ -302,10 +285,11 @@ class CheckpointsRoutes:
for cp in to_process:
try:
original_name = cp.get('model_name')
if await self._fetch_and_update_single_checkpoint(
if await ModelRouteUtils.fetch_and_update_model(
sha256=cp['sha256'],
file_path=cp['file_path'],
checkpoint=cp
model_data=cp,
update_cache_func=self.scanner.update_single_model_cache
):
success += 1
if original_name != cp.get('model_name'):
@@ -350,99 +334,6 @@ class CheckpointsRoutes:
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
return web.Response(text=str(e), status=500)
async def _fetch_and_update_single_checkpoint(self, sha256: str, file_path: str, checkpoint: dict) -> bool:
"""Fetch and update metadata for a single checkpoint without sorting"""
client = CivitaiClient()
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load local metadata
local_metadata = self._load_local_metadata(metadata_path)
# Fetch metadata from Civitai
civitai_metadata = await client.get_model_by_hash(sha256)
if not civitai_metadata:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
checkpoint['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return False
# Update metadata with Civitai data
await self._update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
client
)
# Update cache object directly
checkpoint.update({
'model_name': local_metadata.get('model_name'),
'preview_url': local_metadata.get('preview_url'),
'from_civitai': True,
'civitai': civitai_metadata
})
return True
except Exception as e:
logger.error(f"Error fetching CivitAI data for checkpoint: {e}")
return False
finally:
await client.close()
def _load_local_metadata(self, metadata_path: str) -> Dict:
"""Load local metadata file"""
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
return {}
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
civitai_metadata: Dict, client: CivitaiClient) -> None:
"""Update local metadata with CivitAI data"""
local_metadata['civitai'] = civitai_metadata
# Update model name if available
if 'model' in civitai_metadata:
if civitai_metadata.get('model', {}).get('name'):
local_metadata['model_name'] = civitai_metadata['model']['name']
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata, _ = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Update base model
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
# Update preview if needed
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
if first_preview:
preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_filename = base_name + preview_ext
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
if await client.download_preview_image(first_preview['url'], preview_path):
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:

252
py/utils/routes_common.py Normal file
View File

@@ -0,0 +1,252 @@
import os
import json
import logging
from typing import Dict, List, Callable, Awaitable
from .model_utils import determine_base_model
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..config import config
from ..services.civitai_client import CivitaiClient
from ..utils.exif_utils import ExifUtils
logger = logging.getLogger(__name__)
class ModelRouteUtils:
"""Shared utilities for model routes (LoRAs, Checkpoints, etc.)"""
@staticmethod
async def load_local_metadata(metadata_path: str) -> Dict:
"""Load local metadata file"""
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
return {}
@staticmethod
async def handle_not_found_on_civitai(metadata_path: str, local_metadata: Dict) -> None:
"""Handle case when model is not found on CivitAI"""
local_metadata['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
@staticmethod
async def update_model_metadata(metadata_path: str, local_metadata: Dict,
civitai_metadata: Dict, client: CivitaiClient) -> None:
"""Update local metadata with CivitAI data"""
local_metadata['civitai'] = civitai_metadata
# Update model name if available
if 'model' in civitai_metadata:
if civitai_metadata.get('model', {}).get('name'):
local_metadata['model_name'] = civitai_metadata['model']['name']
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata, _ = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Update base model
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
# Update preview if needed
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
if first_preview:
# Determine if content is video or image
is_video = first_preview['type'] == 'video'
if is_video:
# For videos use .mp4 extension
preview_ext = '.mp4'
else:
# For images use .webp extension
preview_ext = '.webp'
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_filename = base_name + preview_ext
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
if is_video:
# Download video as is
if await client.download_preview_image(first_preview['url'], preview_path):
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
else:
# For images, download and then optimize to WebP
temp_path = preview_path + ".temp"
if await client.download_preview_image(first_preview['url'], temp_path):
try:
# Read the downloaded image
with open(temp_path, 'rb') as f:
image_data = f.read()
# Optimize and convert to WebP
optimized_data, _ = ExifUtils.optimize_image(
image_data=image_data,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
)
# Save the optimized WebP image
with open(preview_path, 'wb') as f:
f.write(optimized_data)
# Update metadata
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Remove the temporary file
if os.path.exists(temp_path):
os.remove(temp_path)
except Exception as e:
logger.error(f"Error optimizing preview image: {e}")
# If optimization fails, try to use the downloaded image directly
if os.path.exists(temp_path):
os.rename(temp_path, preview_path)
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
@staticmethod
async def fetch_and_update_model(
sha256: str,
file_path: str,
model_data: dict,
update_cache_func: Callable[[str, str, Dict], Awaitable[bool]]
) -> bool:
"""Fetch and update metadata for a single model
Args:
sha256: SHA256 hash of the model file
file_path: Path to the model file
model_data: The model object in cache to update
update_cache_func: Function to update the cache with new metadata
Returns:
bool: True if successful, False otherwise
"""
client = CivitaiClient()
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Check if model metadata exists
local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Fetch metadata from Civitai
civitai_metadata = await client.get_model_by_hash(sha256)
if not civitai_metadata:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
model_data['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return False
# Update metadata
await ModelRouteUtils.update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
client
)
# Update cache object directly
model_data.update({
'model_name': local_metadata.get('model_name'),
'preview_url': local_metadata.get('preview_url'),
'from_civitai': True,
'civitai': civitai_metadata
})
# Update cache using the provided function
await update_cache_func(file_path, file_path, local_metadata)
return True
except Exception as e:
logger.error(f"Error fetching CivitAI data: {e}")
return False
finally:
await client.close()
@staticmethod
def filter_civitai_data(data: Dict) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
@staticmethod
async def delete_model_files(target_dir: str, file_name: str, file_monitor=None) -> List[str]:
"""Delete model and associated files
Args:
target_dir: Directory containing the model files
file_name: Base name of the model file without extension
file_monitor: Optional file monitor to ignore delete events
Returns:
List of deleted file paths
"""
patterns = [
f"{file_name}.safetensors", # Required
f"{file_name}.metadata.json",
]
# Add all preview file extensions
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{file_name}{ext}")
deleted = []
main_file = patterns[0]
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
if os.path.exists(main_path):
# Notify file monitor to ignore delete event if available
if file_monitor:
file_monitor.handler.add_ignore_path(main_path, 0)
# Delete file
os.remove(main_path)
deleted.append(main_path)
else:
logger.warning(f"Model file not found: {main_file}")
# Delete optional files
for pattern in patterns[1:]:
path = os.path.join(target_dir, pattern)
if os.path.exists(path):
try:
os.remove(path)
deleted.append(pattern)
except Exception as e:
logger.warning(f"Failed to delete {pattern}: {e}")
return deleted
@staticmethod
def get_multipart_ext(filename):
"""Get extension that may have multiple parts like .metadata.json"""
parts = filename.split(".")
if len(parts) > 2: # If contains multi-part extension
return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json"
return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors"

View File

@@ -1,4 +1,4 @@
import { showToast } from '../utils/uiHelpers.js';
import { showToast, openCivitai } from '../utils/uiHelpers.js';
import { state } from '../state/index.js';
import { showLoraModal } from './loraModal/index.js';
import { bulkManager } from '../managers/BulkManager.js';