refactor: Extract model-related utility functions into ModelRouteUtils for better code organization

This commit is contained in:
Will Miao
2025-04-11 10:54:19 +08:00
parent 297ff0dd25
commit 31d27ff3fa
4 changed files with 318 additions and 351 deletions

View File

@@ -2,9 +2,9 @@ import os
import json import json
import logging import logging
from aiohttp import web from aiohttp import web
from typing import Dict, List from typing import Dict
from ..utils.model_utils import determine_base_model from ..utils.routes_common import ModelRouteUtils
from ..services.file_monitor import LoraFileMonitor from ..services.file_monitor import LoraFileMonitor
from ..services.download_manager import DownloadManager from ..services.download_manager import DownloadManager
@@ -72,7 +72,19 @@ class ApiRoutes:
target_dir = os.path.dirname(file_path) target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0] file_name = os.path.splitext(os.path.basename(file_path))[0]
deleted_files = await self._delete_model_files(target_dir, file_name) deleted_files = await ModelRouteUtils.delete_model_files(
target_dir,
file_name,
self.download_manager.file_monitor
)
# Remove from cache
cache = await self.scanner.get_cached_data()
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path]
await cache.resort()
# update hash index
self.scanner._hash_index.remove_by_path(file_path)
return web.json_response({ return web.json_response({
'success': True, 'success': True,
@@ -90,14 +102,18 @@ class ApiRoutes:
metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json' metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
# Check if model is from CivitAI # Check if model is from CivitAI
local_metadata = await self._load_local_metadata(metadata_path) local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Fetch and update metadata # Fetch and update metadata
civitai_metadata = await self.civitai_client.get_model_by_hash(local_metadata["sha256"]) civitai_metadata = await self.civitai_client.get_model_by_hash(local_metadata["sha256"])
if not civitai_metadata: if not civitai_metadata:
return await self._handle_not_found_on_civitai(metadata_path, local_metadata) await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata)
return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404)
await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client) await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client)
# Update the cache
await self.scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata)
return web.json_response({"success": True}) return web.json_response({"success": True})
@@ -139,10 +155,12 @@ class ApiRoutes:
fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true' fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true'
# Parse search options # Parse search options
search_filename = request.query.get('search_filename', 'true').lower() == 'true' search_options = {
search_modelname = request.query.get('search_modelname', 'true').lower() == 'true' 'filename': request.query.get('search_filename', 'true').lower() == 'true',
search_tags = request.query.get('search_tags', 'false').lower() == 'true' 'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
recursive = request.query.get('recursive', 'false').lower() == 'true' 'tags': request.query.get('search_tags', 'false').lower() == 'true',
'recursive': request.query.get('recursive', 'false').lower() == 'true'
}
# Get filter parameters # Get filter parameters
base_models = request.query.get('base_models', None) base_models = request.query.get('base_models', None)
@@ -159,14 +177,6 @@ class ApiRoutes:
if tags: if tags:
filters['tags'] = tags.split(',') filters['tags'] = tags.split(',')
# Add search options to filters
search_options = {
'filename': search_filename,
'modelname': search_modelname,
'tags': search_tags,
'recursive': recursive
}
# Add lora hash filtering options # Add lora hash filtering options
hash_filters = {} hash_filters = {}
if lora_hash: if lora_hash:
@@ -225,67 +235,10 @@ class ApiRoutes:
"from_civitai": lora.get("from_civitai", True), "from_civitai": lora.get("from_civitai", True),
"usage_tips": lora.get("usage_tips", ""), "usage_tips": lora.get("usage_tips", ""),
"notes": lora.get("notes", ""), "notes": lora.get("notes", ""),
"civitai": self._filter_civitai_data(lora.get("civitai", {})) "civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {}))
} }
def _filter_civitai_data(self, data: Dict) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
# Private helper methods # Private helper methods
async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]:
"""Delete model and associated files"""
patterns = [
f"{file_name}.safetensors", # Required
f"{file_name}.metadata.json",
]
# 添加所有预览文件扩展名
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{file_name}{ext}")
deleted = []
main_file = patterns[0]
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
if os.path.exists(main_path):
# Notify file monitor to ignore delete event
self.download_manager.file_monitor.handler.add_ignore_path(main_path, 0)
# Delete file
os.remove(main_path)
deleted.append(main_path)
else:
logger.warning(f"Model file not found: {main_file}")
# Remove from cache
cache = await self.scanner.get_cached_data()
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != main_path]
await cache.resort()
# update hash index
self.scanner._hash_index.remove_by_path(main_path)
# Delete optional files
for pattern in patterns[1:]:
path = os.path.join(target_dir, pattern)
if os.path.exists(path):
try:
os.remove(path)
deleted.append(pattern)
except Exception as e:
logger.warning(f"Failed to delete {pattern}: {e}")
return deleted
async def _read_preview_file(self, reader) -> tuple[bytes, str]: async def _read_preview_file(self, reader) -> tuple[bytes, str]:
"""Read preview file and content type from multipart request""" """Read preview file and content type from multipart request"""
field = await reader.next() field = await reader.next()
@@ -345,66 +298,6 @@ class ApiRoutes:
except Exception as e: except Exception as e:
logger.error(f"Error updating metadata: {e}") logger.error(f"Error updating metadata: {e}")
async def _load_local_metadata(self, metadata_path: str) -> Dict:
"""Load local metadata file"""
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
return {}
async def _handle_not_found_on_civitai(self, metadata_path: str, local_metadata: Dict) -> web.Response:
"""Handle case when model is not found on CivitAI"""
local_metadata['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return web.json_response(
{"success": False, "error": "Not found on CivitAI"},
status=404
)
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
civitai_metadata: Dict, client: CivitaiClient) -> None:
"""Update local metadata with CivitAI data"""
local_metadata['civitai'] = civitai_metadata
# Update model name if available
if 'model' in civitai_metadata:
if civitai_metadata.get('model', {}).get('name'):
local_metadata['model_name'] = civitai_metadata['model']['name']
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata, _ = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Update base model
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
# Update preview if needed
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
if first_preview:
preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_filename = base_name + preview_ext
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
if await client.download_preview_image(first_preview['url'], preview_path):
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)
async def fetch_all_civitai(self, request: web.Request) -> web.Response: async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all loras in the background""" """Fetch CivitAI metadata for all loras in the background"""
try: try:
@@ -414,14 +307,14 @@ class ApiRoutes:
success = 0 success = 0
needs_resort = False needs_resort = False
# 准备要处理的 loras # Prepare loras to process
to_process = [ to_process = [
lora for lora in cache.raw_data lora for lora in cache.raw_data
if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai') # TODO: for lora not from CivitAI but added traineWords if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai', True) # TODO: for lora not from CivitAI but added traineWords
] ]
total_to_process = len(to_process) total_to_process = len(to_process)
# 发送初始进度 # Send initial progress
await ws_manager.broadcast({ await ws_manager.broadcast({
'status': 'started', 'status': 'started',
'total': total_to_process, 'total': total_to_process,
@@ -432,10 +325,11 @@ class ApiRoutes:
for lora in to_process: for lora in to_process:
try: try:
original_name = lora.get('model_name') original_name = lora.get('model_name')
if await self._fetch_and_update_single_lora( if await ModelRouteUtils.fetch_and_update_model(
sha256=lora['sha256'], sha256=lora['sha256'],
file_path=lora['file_path'], file_path=lora['file_path'],
lora=lora model_data=lora,
update_cache_func=self.scanner.update_single_model_cache
): ):
success += 1 success += 1
if original_name != lora.get('model_name'): if original_name != lora.get('model_name'):
@@ -443,7 +337,7 @@ class ApiRoutes:
processed += 1 processed += 1
# 每处理一个就发送进度更新 # Send progress update
await ws_manager.broadcast({ await ws_manager.broadcast({
'status': 'processing', 'status': 'processing',
'total': total_to_process, 'total': total_to_process,
@@ -458,7 +352,7 @@ class ApiRoutes:
if needs_resort: if needs_resort:
await cache.resort(name_only=True) await cache.resort(name_only=True)
# 发送完成消息 # Send completion message
await ws_manager.broadcast({ await ws_manager.broadcast({
'status': 'completed', 'status': 'completed',
'total': total_to_process, 'total': total_to_process,
@@ -472,7 +366,7 @@ class ApiRoutes:
}) })
except Exception as e: except Exception as e:
# 发送错误消息 # Send error message
await ws_manager.broadcast({ await ws_manager.broadcast({
'status': 'error', 'status': 'error',
'error': str(e) 'error': str(e)
@@ -480,58 +374,6 @@ class ApiRoutes:
logger.error(f"Error in fetch_all_civitai: {e}") logger.error(f"Error in fetch_all_civitai: {e}")
return web.Response(text=str(e), status=500) return web.Response(text=str(e), status=500)
async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool:
"""Fetch and update metadata for a single lora without sorting
Args:
sha256: SHA256 hash of the lora file
file_path: Path to the lora file
lora: The lora object in cache to update
Returns:
bool: True if successful, False otherwise
"""
client = CivitaiClient()
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Check if model is from CivitAI
local_metadata = await self._load_local_metadata(metadata_path)
# Fetch metadata
civitai_metadata = await client.get_model_by_hash(sha256)
if not civitai_metadata:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
lora['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return False
# Update metadata
await self._update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
client
)
# Update cache object directly
lora.update({
'model_name': local_metadata.get('model_name'),
'preview_url': local_metadata.get('preview_url'),
'from_civitai': True,
'civitai': civitai_metadata
})
return True
except Exception as e:
logger.error(f"Error fetching CivitAI data: {e}")
return False
finally:
await client.close()
async def get_lora_roots(self, request: web.Request) -> web.Response: async def get_lora_roots(self, request: web.Request) -> web.Response:
"""Get all configured LoRA root directories""" """Get all configured LoRA root directories"""
return web.json_response({ return web.json_response({
@@ -669,7 +511,7 @@ class ApiRoutes:
return web.json_response({'success': True}) return web.json_response({'success': True})
except Exception as e: except Exception as e:
logger.error(f"Error updating settings: {e}", exc_info=True) # 添加 exc_info=True 以获取完整堆栈 logger.error(f"Error updating settings: {e}", exc_info=True)
return web.Response(status=500, text=str(e)) return web.Response(status=500, text=str(e))
async def move_model(self, request: web.Request) -> web.Response: async def move_model(self, request: web.Request) -> web.Response:
@@ -731,11 +573,7 @@ class ApiRoutes:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load existing metadata # Load existing metadata
if os.path.exists(metadata_path): metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
else:
metadata = {}
# Handle nested updates (for civitai.trainedWords) # Handle nested updates (for civitai.trainedWords)
for key, value in metadata_updates.items(): for key, value in metadata_updates.items():
@@ -798,7 +636,10 @@ class ApiRoutes:
except Exception as e: except Exception as e:
logger.error(f"Error getting lora preview URL: {e}", exc_info=True) logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
return web.Response(text=str(e), status=500) return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_civitai_url(self, request: web.Request) -> web.Response: async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
"""Get the Civitai URL for a LoRA file""" """Get the Civitai URL for a LoRA file"""
@@ -921,14 +762,9 @@ class ApiRoutes:
tags = [] tags = []
if file_path: if file_path:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path): metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
try: description = metadata.get('modelDescription')
with open(metadata_path, 'r', encoding='utf-8') as f: tags = metadata.get('tags', [])
metadata = json.load(f)
description = metadata.get('modelDescription')
tags = metadata.get('tags', [])
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
# If description is not in metadata, fetch from CivitAI # If description is not in metadata, fetch from CivitAI
if not description: if not description:
@@ -943,16 +779,14 @@ class ApiRoutes:
if file_path: if file_path:
try: try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path): metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f) metadata['modelDescription'] = description
metadata['tags'] = tags
metadata['modelDescription'] = description
metadata['tags'] = tags with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
with open(metadata_path, 'w', encoding='utf-8') as f: logger.info(f"Saved model metadata to file for {file_path}")
json.dump(metadata, f, indent=2, ensure_ascii=False)
logger.info(f"Saved model metadata to file for {file_path}")
except Exception as e: except Exception as e:
logger.error(f"Error saving model metadata: {e}") logger.error(f"Error saving model metadata: {e}")
@@ -1018,12 +852,6 @@ class ApiRoutes:
'error': str(e) 'error': str(e)
}, status=500) }, status=500)
def get_multipart_ext(self, filename):
parts = filename.split(".")
if len(parts) > 2: # 如果包含多级扩展名
return "." + ".".join(parts[-2:]) # 取最后两部分,如 ".metadata.json"
return os.path.splitext(filename)[1] # 否则取普通扩展名,如 ".safetensors"
async def rename_lora(self, request: web.Request) -> web.Response: async def rename_lora(self, request: web.Request) -> web.Response:
"""Handle renaming a LoRA file and its associated files""" """Handle renaming a LoRA file and its associated files"""
try: try:
@@ -1063,7 +891,7 @@ class ApiRoutes:
f"{old_file_name}.metadata.json", f"{old_file_name}.metadata.json",
] ]
# 添加所有预览文件扩展名 # Add all preview file extensions
for ext in PREVIEW_EXTENSIONS: for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{old_file_name}{ext}") patterns.append(f"{old_file_name}{ext}")
@@ -1080,12 +908,8 @@ class ApiRoutes:
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json") metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
if os.path.exists(metadata_path): if os.path.exists(metadata_path):
try: metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
with open(metadata_path, 'r', encoding='utf-8') as f: hash_value = metadata.get('sha256')
metadata = json.load(f)
hash_value = metadata.get('sha256')
except Exception as e:
logger.error(f"Error loading metadata for rename: {e}")
# Rename all files # Rename all files
renamed_files = [] renamed_files = []
@@ -1101,7 +925,7 @@ class ApiRoutes:
for old_path, pattern in existing_files: for old_path, pattern in existing_files:
# Get the file extension like .safetensors or .metadata.json # Get the file extension like .safetensors or .metadata.json
ext = self.get_multipart_ext(pattern) ext = ModelRouteUtils.get_multipart_ext(pattern)
# Create the new path # Create the new path
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
@@ -1123,7 +947,7 @@ class ApiRoutes:
# Update preview_url if it exists # Update preview_url if it exists
if 'preview_url' in metadata and metadata['preview_url']: if 'preview_url' in metadata and metadata['preview_url']:
old_preview = metadata['preview_url'] old_preview = metadata['preview_url']
ext = self.get_multipart_ext(old_preview) ext = ModelRouteUtils.get_multipart_ext(old_preview)
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
metadata['preview_url'] = new_preview metadata['preview_url'] = new_preview

View File

@@ -1,15 +1,10 @@
import os import os
import json import json
import asyncio
from typing import Dict
import aiohttp
import jinja2 import jinja2
from aiohttp import web from aiohttp import web
import logging import logging
from datetime import datetime
from ..utils.model_utils import determine_base_model
from ..utils.routes_common import ModelRouteUtils
from ..utils.constants import NSFW_LEVELS from ..utils.constants import NSFW_LEVELS
from ..services.civitai_client import CivitaiClient from ..services.civitai_client import CivitaiClient
from ..services.websocket_manager import ws_manager from ..services.websocket_manager import ws_manager
@@ -259,21 +254,9 @@ class CheckpointsRoutes:
"from_civitai": checkpoint.get("from_civitai", True), "from_civitai": checkpoint.get("from_civitai", True),
"notes": checkpoint.get("notes", ""), "notes": checkpoint.get("notes", ""),
"model_type": checkpoint.get("model_type", "checkpoint"), "model_type": checkpoint.get("model_type", "checkpoint"),
"civitai": self._filter_civitai_data(checkpoint.get("civitai", {})) "civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
} }
def _filter_civitai_data(self, data):
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
async def fetch_all_civitai(self, request: web.Request) -> web.Response: async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all checkpoints in the background""" """Fetch CivitAI metadata for all checkpoints in the background"""
try: try:
@@ -302,10 +285,11 @@ class CheckpointsRoutes:
for cp in to_process: for cp in to_process:
try: try:
original_name = cp.get('model_name') original_name = cp.get('model_name')
if await self._fetch_and_update_single_checkpoint( if await ModelRouteUtils.fetch_and_update_model(
sha256=cp['sha256'], sha256=cp['sha256'],
file_path=cp['file_path'], file_path=cp['file_path'],
checkpoint=cp model_data=cp,
update_cache_func=self.scanner.update_single_model_cache
): ):
success += 1 success += 1
if original_name != cp.get('model_name'): if original_name != cp.get('model_name'):
@@ -350,99 +334,6 @@ class CheckpointsRoutes:
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}") logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
return web.Response(text=str(e), status=500) return web.Response(text=str(e), status=500)
async def _fetch_and_update_single_checkpoint(self, sha256: str, file_path: str, checkpoint: dict) -> bool:
"""Fetch and update metadata for a single checkpoint without sorting"""
client = CivitaiClient()
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load local metadata
local_metadata = self._load_local_metadata(metadata_path)
# Fetch metadata from Civitai
civitai_metadata = await client.get_model_by_hash(sha256)
if not civitai_metadata:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
checkpoint['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return False
# Update metadata with Civitai data
await self._update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
client
)
# Update cache object directly
checkpoint.update({
'model_name': local_metadata.get('model_name'),
'preview_url': local_metadata.get('preview_url'),
'from_civitai': True,
'civitai': civitai_metadata
})
return True
except Exception as e:
logger.error(f"Error fetching CivitAI data for checkpoint: {e}")
return False
finally:
await client.close()
def _load_local_metadata(self, metadata_path: str) -> Dict:
"""Load local metadata file"""
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
return {}
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
civitai_metadata: Dict, client: CivitaiClient) -> None:
"""Update local metadata with CivitAI data"""
local_metadata['civitai'] = civitai_metadata
# Update model name if available
if 'model' in civitai_metadata:
if civitai_metadata.get('model', {}).get('name'):
local_metadata['model_name'] = civitai_metadata['model']['name']
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata, _ = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Update base model
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
# Update preview if needed
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
if first_preview:
preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_filename = base_name + preview_ext
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
if await client.download_preview_image(first_preview['url'], preview_path):
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)
async def get_top_tags(self, request: web.Request) -> web.Response: async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency""" """Handle request for top tags sorted by frequency"""
try: try:

252
py/utils/routes_common.py Normal file
View File

@@ -0,0 +1,252 @@
import os
import json
import logging
from typing import Dict, List, Callable, Awaitable
from .model_utils import determine_base_model
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..config import config
from ..services.civitai_client import CivitaiClient
from ..utils.exif_utils import ExifUtils
logger = logging.getLogger(__name__)
class ModelRouteUtils:
    """Shared utilities for model routes (LoRAs, Checkpoints, etc.)

    Extracted from the per-model-type route classes so that LoRA and
    checkpoint routes share one implementation of metadata loading,
    CivitAI refresh, preview handling, and file deletion/renaming helpers.
    All methods are stateless; callers pass in paths, cache objects, and
    (where needed) a cache-update callback.
    """

    @staticmethod
    async def load_local_metadata(metadata_path: str) -> Dict:
        """Load a `.metadata.json` sidecar file.

        Returns the parsed dict, or {} if the file is missing or unreadable
        (parse/IO errors are logged and swallowed so callers get a safe default).
        """
        if os.path.exists(metadata_path):
            try:
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                logger.error(f"Error loading metadata from {metadata_path}: {e}")
        return {}

    @staticmethod
    async def handle_not_found_on_civitai(metadata_path: str, local_metadata: Dict) -> None:
        """Handle case when model is not found on CivitAI.

        Marks the metadata as not-from-CivitAI (mutates `local_metadata` in
        place) and persists it, so future refresh runs skip this model.
        The HTTP response is left to the caller.
        """
        local_metadata['from_civitai'] = False
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(local_metadata, f, indent=2, ensure_ascii=False)

    @staticmethod
    async def update_model_metadata(metadata_path: str, local_metadata: Dict,
                                    civitai_metadata: Dict, client: CivitaiClient) -> None:
        """Update local metadata with CivitAI data.

        Mutates `local_metadata` in place (civitai payload, model_name, tags,
        description, base_model, preview fields) and writes it back to
        `metadata_path`. Downloads a preview when none exists locally:
        videos are saved as-is (.mp4), images are converted to .webp via
        ExifUtils.optimize_image with a plain-copy fallback on failure.

        NOTE(review): assumes `civitai_metadata` always carries a 'modelId'
        key (direct index below) — confirm against CivitaiClient responses.
        """
        local_metadata['civitai'] = civitai_metadata

        # Update model name if available
        if 'model' in civitai_metadata:
            if civitai_metadata.get('model', {}).get('name'):
                local_metadata['model_name'] = civitai_metadata['model']['name']

        # Fetch additional model metadata (description and tags) if we have model ID
        model_id = civitai_metadata['modelId']
        if model_id:
            model_metadata, _ = await client.get_model_metadata(str(model_id))
            if model_metadata:
                local_metadata['modelDescription'] = model_metadata.get('description', '')
                local_metadata['tags'] = model_metadata.get('tags', [])

        # Update base model (mapped through project helper determine_base_model)
        local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))

        # Update preview if needed (missing, or pointing at a file that no longer exists)
        if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
            # First image in the CivitAI payload is used as the preview
            first_preview = next((img for img in civitai_metadata.get('images', [])), None)
            if first_preview:
                # Determine if content is video or image
                is_video = first_preview['type'] == 'video'
                if is_video:
                    # For videos use .mp4 extension
                    preview_ext = '.mp4'
                else:
                    # For images use .webp extension
                    preview_ext = '.webp'
                # Double splitext strips a possible two-part suffix
                # (e.g. "name.metadata.json" -> "name")
                base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
                preview_filename = base_name + preview_ext
                preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)

                if is_video:
                    # Download video as is
                    if await client.download_preview_image(first_preview['url'], preview_path):
                        local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
                        local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
                else:
                    # For images, download to a temp file first, then optimize to WebP
                    temp_path = preview_path + ".temp"
                    if await client.download_preview_image(first_preview['url'], temp_path):
                        try:
                            # Read the downloaded image
                            with open(temp_path, 'rb') as f:
                                image_data = f.read()

                            # Optimize and convert to WebP
                            optimized_data, _ = ExifUtils.optimize_image(
                                image_data=image_data,
                                target_width=CARD_PREVIEW_WIDTH,
                                format='webp',
                                quality=85,
                                preserve_metadata=True
                            )

                            # Save the optimized WebP image
                            with open(preview_path, 'wb') as f:
                                f.write(optimized_data)

                            # Update metadata (forward slashes for portable paths)
                            local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
                            local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)

                            # Remove the temporary file
                            if os.path.exists(temp_path):
                                os.remove(temp_path)
                        except Exception as e:
                            logger.error(f"Error optimizing preview image: {e}")
                            # If optimization fails, try to use the downloaded image directly
                            if os.path.exists(temp_path):
                                os.rename(temp_path, preview_path)
                                local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
                                local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)

        # Save updated metadata
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(local_metadata, f, indent=2, ensure_ascii=False)

    @staticmethod
    async def fetch_and_update_model(
        sha256: str,
        file_path: str,
        model_data: dict,
        update_cache_func: Callable[[str, str, Dict], Awaitable[bool]]
    ) -> bool:
        """Fetch and update metadata for a single model

        Args:
            sha256: SHA256 hash of the model file
            file_path: Path to the model file
            model_data: The model object in cache to update (mutated in place)
            update_cache_func: Function to update the cache with new metadata

        Returns:
            bool: True if successful, False otherwise (not found on CivitAI,
            or any error — errors are logged, never raised)
        """
        # One short-lived client per call; closed in the finally block
        client = CivitaiClient()
        try:
            metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'

            # Check if model metadata exists
            local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path)

            # Fetch metadata from Civitai
            civitai_metadata = await client.get_model_by_hash(sha256)
            if not civitai_metadata:
                # Mark as not from CivitAI if not found, in both the sidecar
                # file and the in-memory cache entry
                local_metadata['from_civitai'] = False
                model_data['from_civitai'] = False
                with open(metadata_path, 'w', encoding='utf-8') as f:
                    json.dump(local_metadata, f, indent=2, ensure_ascii=False)
                return False

            # Update metadata
            await ModelRouteUtils.update_model_metadata(
                metadata_path,
                local_metadata,
                civitai_metadata,
                client
            )

            # Update cache object directly
            model_data.update({
                'model_name': local_metadata.get('model_name'),
                'preview_url': local_metadata.get('preview_url'),
                'from_civitai': True,
                'civitai': civitai_metadata
            })

            # Update cache using the provided function
            await update_cache_func(file_path, file_path, local_metadata)

            return True

        except Exception as e:
            logger.error(f"Error fetching CivitAI data: {e}")
            return False
        finally:
            await client.close()

    @staticmethod
    def filter_civitai_data(data: Dict) -> Dict:
        """Filter relevant fields from CivitAI data

        Keeps only the whitelisted keys below; returns {} for falsy input.
        """
        if not data:
            return {}

        fields = [
            "id", "modelId", "name", "createdAt", "updatedAt",
            "publishedAt", "trainedWords", "baseModel", "description",
            "model", "images"
        ]
        return {k: data[k] for k in fields if k in data}

    @staticmethod
    async def delete_model_files(target_dir: str, file_name: str, file_monitor=None) -> List[str]:
        """Delete model and associated files

        Args:
            target_dir: Directory containing the model files
            file_name: Base name of the model file without extension
            file_monitor: Optional file monitor to ignore delete events

        Returns:
            List of deleted file paths
            (the main .safetensors as a full path; sidecar files as bare names)
        """
        patterns = [
            f"{file_name}.safetensors",  # Required
            f"{file_name}.metadata.json",
        ]

        # Add all preview file extensions
        for ext in PREVIEW_EXTENSIONS:
            patterns.append(f"{file_name}{ext}")

        deleted = []
        main_file = patterns[0]
        main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')

        if os.path.exists(main_path):
            # Notify file monitor to ignore delete event if available,
            # so the watcher doesn't treat our own deletion as an external change
            if file_monitor:
                file_monitor.handler.add_ignore_path(main_path, 0)

            # Delete file
            os.remove(main_path)
            deleted.append(main_path)
        else:
            logger.warning(f"Model file not found: {main_file}")

        # Delete optional files (best-effort: failures are logged, not raised)
        for pattern in patterns[1:]:
            path = os.path.join(target_dir, pattern)
            if os.path.exists(path):
                try:
                    os.remove(path)
                    deleted.append(pattern)
                except Exception as e:
                    logger.warning(f"Failed to delete {pattern}: {e}")

        return deleted

    @staticmethod
    def get_multipart_ext(filename: str) -> str:
        """Get extension that may have multiple parts like .metadata.json"""
        parts = filename.split(".")
        if len(parts) > 2:  # If contains multi-part extension
            return "." + ".".join(parts[-2:])  # Take the last two parts, like ".metadata.json"
        return os.path.splitext(filename)[1]  # Otherwise take the regular extension, like ".safetensors"

View File

@@ -1,4 +1,4 @@
import { showToast } from '../utils/uiHelpers.js'; import { showToast, openCivitai } from '../utils/uiHelpers.js';
import { state } from '../state/index.js'; import { state } from '../state/index.js';
import { showLoraModal } from './loraModal/index.js'; import { showLoraModal } from './loraModal/index.js';
import { bulkManager } from '../managers/BulkManager.js'; import { bulkManager } from '../managers/BulkManager.js';