mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 14:12:11 -03:00
252 lines
11 KiB
Python
252 lines
11 KiB
Python
import os
|
|
import json
|
|
import logging
|
|
from typing import Dict, List, Callable, Awaitable
|
|
|
|
from .model_utils import determine_base_model
|
|
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
|
|
from ..config import config
|
|
from ..services.civitai_client import CivitaiClient
|
|
from ..utils.exif_utils import ExifUtils
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ModelRouteUtils:
    """Shared utilities for model routes (LoRAs, Checkpoints, etc.).

    Every public member is a static method; the class acts purely as a
    namespace and holds no state.
    """

    @staticmethod
    def _write_metadata(metadata_path: str, metadata: Dict) -> None:
        """Write ``metadata`` to ``metadata_path`` as indented UTF-8 JSON."""
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

    @staticmethod
    async def load_local_metadata(metadata_path: str) -> Dict:
        """Load local metadata file.

        Args:
            metadata_path: Path of the JSON metadata file to read.

        Returns:
            The parsed JSON content, or an empty dict when the file does not
            exist or cannot be read/parsed (the error is logged).
        """
        if os.path.exists(metadata_path):
            try:
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                # Best effort: an unreadable file is treated like a missing one.
                logger.error(f"Error loading metadata from {metadata_path}: {e}")
        return {}

    @staticmethod
    async def handle_not_found_on_civitai(metadata_path: str, local_metadata: Dict) -> None:
        """Handle case when model is not found on CivitAI.

        Sets ``local_metadata['from_civitai']`` to ``False`` (in place) and
        writes the metadata back to ``metadata_path``.
        """
        local_metadata['from_civitai'] = False
        ModelRouteUtils._write_metadata(metadata_path, local_metadata)

    @staticmethod
    async def _download_preview(metadata_path: str, preview: Dict,
                                client: "CivitaiClient") -> str:
        """Download one preview item next to the metadata file.

        Videos are stored unchanged as ``<base>.mp4``. Images are downloaded
        to a temporary file, converted to WebP and stored as ``<base>.webp``;
        if the conversion fails the raw download is kept under that name.

        Args:
            metadata_path: Path of the ``*.metadata.json`` file; its directory
                and base name determine where the preview is written.
            preview: One entry of the CivitAI ``images`` list.
            client: Client used to perform the download.

        Returns:
            The path of the saved preview, or ``''`` when nothing was saved.
        """
        url = preview.get('url')
        if not url:
            return ''

        is_video = preview.get('type') == 'video'
        preview_ext = '.mp4' if is_video else '.webp'

        # Strip two suffixes ("<name>.metadata.json" -> "<name>").
        base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
        preview_path = os.path.join(os.path.dirname(metadata_path), base_name + preview_ext)

        if is_video:
            # Download video as is
            if await client.download_preview_image(url, preview_path):
                return preview_path
            return ''

        # For images, download and then optimize to WebP
        temp_path = preview_path + ".temp"
        if not await client.download_preview_image(url, temp_path):
            return ''

        try:
            with open(temp_path, 'rb') as f:
                image_data = f.read()

            optimized_data, _ = ExifUtils.optimize_image(
                image_data=image_data,
                target_width=CARD_PREVIEW_WIDTH,
                format='webp',
                quality=85,
                preserve_metadata=True
            )

            with open(preview_path, 'wb') as f:
                f.write(optimized_data)
            return preview_path
        except Exception as e:
            logger.error(f"Error optimizing preview image: {e}")
            # If optimization fails, try to use the downloaded image directly.
            # os.replace overwrites an existing target on every platform.
            if os.path.exists(temp_path):
                try:
                    os.replace(temp_path, preview_path)
                    return preview_path
                except OSError as replace_error:
                    logger.error(f"Error saving preview image: {replace_error}")
            return ''
        finally:
            # Cleanup lives outside the optimisation try-body so a failed
            # removal can never trigger the raw-image fallback above.
            if os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError as cleanup_error:
                    logger.warning(f"Failed to remove temporary preview {temp_path}: {cleanup_error}")

    @staticmethod
    async def update_model_metadata(metadata_path: str, local_metadata: Dict,
                                    civitai_metadata: Dict, client: "CivitaiClient") -> None:
        """Update local metadata with CivitAI data.

        ``local_metadata`` is modified in place and then written to
        ``metadata_path``.

        Args:
            metadata_path: Path of the ``*.metadata.json`` file to write.
            local_metadata: Metadata dict to update.
            civitai_metadata: Model-version data returned by CivitAI.
            client: Client used for the extra model lookup and preview download.
        """
        local_metadata['civitai'] = civitai_metadata

        # Update model name if available ("model" may be absent or null)
        model_info = civitai_metadata.get('model') or {}
        if model_info.get('name'):
            local_metadata['model_name'] = model_info['name']

        # Fetch additional model metadata (description and tags) if we have model ID
        model_id = civitai_metadata.get('modelId')
        if model_id:
            model_metadata, _ = await client.get_model_metadata(str(model_id))
            if model_metadata:
                local_metadata['modelDescription'] = model_metadata.get('description', '')
                local_metadata['tags'] = model_metadata.get('tags', [])

        # Update base model
        local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))

        # Update preview if needed: none recorded, or the recorded file is gone
        current_preview = local_metadata.get('preview_url')
        if not current_preview or not os.path.exists(current_preview):
            first_preview = next(iter(civitai_metadata.get('images') or []), None)
            if first_preview:
                saved_path = await ModelRouteUtils._download_preview(
                    metadata_path, first_preview, client
                )
                if saved_path:
                    local_metadata['preview_url'] = saved_path.replace(os.sep, '/')
                    local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)

        # Save updated metadata
        ModelRouteUtils._write_metadata(metadata_path, local_metadata)

    @staticmethod
    async def fetch_and_update_model(
        sha256: str,
        file_path: str,
        model_data: dict,
        update_cache_func: Callable[[str, str, Dict], Awaitable[bool]]
    ) -> bool:
        """Fetch and update metadata for a single model

        Args:
            sha256: SHA256 hash of the model file
            file_path: Path to the model file
            model_data: The model object in cache to update
            update_cache_func: Function to update the cache with new metadata;
                awaited as ``update_cache_func(file_path, file_path, local_metadata)``

        Returns:
            bool: True if successful, False otherwise (model not found on
            CivitAI, or any error, which is logged)
        """
        client = CivitaiClient()
        try:
            metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'

            # Check if model metadata exists
            local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path)

            # Fetch metadata from Civitai
            civitai_metadata = await client.get_model_by_hash(sha256)
            if not civitai_metadata:
                # Mark as not from CivitAI if not found
                model_data['from_civitai'] = False
                await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata)
                return False

            # Update metadata
            await ModelRouteUtils.update_model_metadata(
                metadata_path,
                local_metadata,
                civitai_metadata,
                client
            )

            # Update cache object directly
            model_data.update({
                'model_name': local_metadata.get('model_name'),
                'preview_url': local_metadata.get('preview_url'),
                'from_civitai': True,
                'civitai': civitai_metadata
            })

            # Update cache using the provided function
            await update_cache_func(file_path, file_path, local_metadata)

            return True

        except Exception as e:
            logger.error(f"Error fetching CivitAI data: {e}")
            return False
        finally:
            # Always release the client, whatever the outcome.
            await client.close()

    @staticmethod
    def filter_civitai_data(data: Dict) -> Dict:
        """Filter relevant fields from CivitAI data.

        Returns a new dict holding only the whitelisted keys that are present
        in ``data``; an empty/None ``data`` yields ``{}``.
        """
        if not data:
            return {}

        fields = [
            "id", "modelId", "name", "createdAt", "updatedAt",
            "publishedAt", "trainedWords", "baseModel", "description",
            "model", "images"
        ]
        return {k: data[k] for k in fields if k in data}

    @staticmethod
    async def delete_model_files(target_dir: str, file_name: str, file_monitor=None) -> List[str]:
        """Delete model and associated files

        Args:
            target_dir: Directory containing the model files
            file_name: Base name of the model file without extension
            file_monitor: Optional file monitor to ignore delete events

        Returns:
            List of deleted files. NOTE: the main ``.safetensors`` file is
            reported as its full path (with ``/`` separators) while the
            metadata and preview files are reported by file name only.

        Raises:
            OSError: If the main model file exists but cannot be removed.
                Failures on the optional files are only logged.
        """
        patterns = [
            f"{file_name}.safetensors",  # Required
            f"{file_name}.metadata.json",
        ]

        # Add all preview file extensions
        patterns.extend(f"{file_name}{ext}" for ext in PREVIEW_EXTENSIONS)

        deleted = []
        main_file = patterns[0]
        main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')

        if os.path.exists(main_path):
            # Notify file monitor to ignore delete event if available
            if file_monitor:
                file_monitor.handler.add_ignore_path(main_path, 0)

            # Delete file
            os.remove(main_path)
            deleted.append(main_path)
        else:
            logger.warning(f"Model file not found: {main_file}")

        # Delete optional files
        for pattern in patterns[1:]:
            path = os.path.join(target_dir, pattern)
            if os.path.exists(path):
                try:
                    os.remove(path)
                    deleted.append(pattern)
                except Exception as e:
                    logger.warning(f"Failed to delete {pattern}: {e}")

        return deleted

    @staticmethod
    def get_multipart_ext(filename):
        """Get extension that may have multiple parts like .metadata.json

        Only the final path component is inspected, so dots in directory
        names never leak into the result. This is a heuristic: any name with
        two or more dots yields its last two dot-separated parts (e.g.
        ``"a.b.safetensors"`` -> ``".b.safetensors"``).

        Args:
            filename: File name or path.

        Returns:
            The extension including the leading dot, or ``""`` if there is none.
        """
        name = os.path.basename(filename)
        parts = name.split(".")
        if len(parts) > 2:  # If contains multi-part extension
            return "." + ".".join(parts[-2:])  # Take the last two parts, like ".metadata.json"
        return os.path.splitext(name)[1]  # Otherwise take the regular extension, like ".safetensors"