diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index a35bf2b2..331cadb4 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -2,9 +2,9 @@ import os import json import logging from aiohttp import web -from typing import Dict, List +from typing import Dict -from ..utils.model_utils import determine_base_model +from ..utils.routes_common import ModelRouteUtils from ..services.file_monitor import LoraFileMonitor from ..services.download_manager import DownloadManager @@ -72,7 +72,19 @@ class ApiRoutes: target_dir = os.path.dirname(file_path) file_name = os.path.splitext(os.path.basename(file_path))[0] - deleted_files = await self._delete_model_files(target_dir, file_name) + deleted_files = await ModelRouteUtils.delete_model_files( + target_dir, + file_name, + self.download_manager.file_monitor + ) + + # Remove from cache + cache = await self.scanner.get_cached_data() + cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path] + await cache.resort() + + # update hash index + self.scanner._hash_index.remove_by_path(file_path) return web.json_response({ 'success': True, @@ -90,14 +102,18 @@ class ApiRoutes: metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json' # Check if model is from CivitAI - local_metadata = await self._load_local_metadata(metadata_path) + local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path) # Fetch and update metadata civitai_metadata = await self.civitai_client.get_model_by_hash(local_metadata["sha256"]) if not civitai_metadata: - return await self._handle_not_found_on_civitai(metadata_path, local_metadata) + await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata) + return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404) - await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client) + await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client) + + # Update the cache + await self.scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata) return web.json_response({"success": True}) @@ -139,10 +155,12 @@ class ApiRoutes: fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true' # Parse search options - search_filename = request.query.get('search_filename', 'true').lower() == 'true' - search_modelname = request.query.get('search_modelname', 'true').lower() == 'true' - search_tags = request.query.get('search_tags', 'false').lower() == 'true' - recursive = request.query.get('recursive', 'false').lower() == 'true' + search_options = { + 'filename': request.query.get('search_filename', 'true').lower() == 'true', + 'modelname': request.query.get('search_modelname', 'true').lower() == 'true', + 'tags': request.query.get('search_tags', 'false').lower() == 'true', + 'recursive': request.query.get('recursive', 'false').lower() == 'true' + } # Get filter parameters base_models = request.query.get('base_models', None) @@ -159,14 +177,6 @@ class ApiRoutes: if tags: filters['tags'] = tags.split(',') - # Add search options to filters - search_options = { - 'filename': search_filename, - 'modelname': search_modelname, - 'tags': search_tags, - 'recursive': recursive - } - # Add lora hash filtering options hash_filters = {} if lora_hash: @@ -225,67 +235,10 @@ class ApiRoutes: "from_civitai": lora.get("from_civitai", True), "usage_tips": lora.get("usage_tips", ""), "notes": lora.get("notes", ""), - "civitai": self._filter_civitai_data(lora.get("civitai", {})) + "civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {})) } - def _filter_civitai_data(self, data: Dict) -> Dict: - """Filter relevant fields from CivitAI data""" - if not data: - return {} - - fields = [ - "id", "modelId", "name", "createdAt", "updatedAt", - "publishedAt", "trainedWords", "baseModel", "description", - "model", "images" - ] - return {k: data[k] for k in fields if k in data} - # Private helper methods - async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]: - """Delete model and associated files""" - patterns = [ - f"{file_name}.safetensors", # Required - f"{file_name}.metadata.json", - ] - - # 添加所有预览文件扩展名 - for ext in PREVIEW_EXTENSIONS: - patterns.append(f"{file_name}{ext}") - - deleted = [] - main_file = patterns[0] - main_path = os.path.join(target_dir, main_file).replace(os.sep, '/') - - if os.path.exists(main_path): - # Notify file monitor to ignore delete event - self.download_manager.file_monitor.handler.add_ignore_path(main_path, 0) - - # Delete file - os.remove(main_path) - deleted.append(main_path) - else: - logger.warning(f"Model file not found: {main_file}") - - # Remove from cache - cache = await self.scanner.get_cached_data() - cache.raw_data = [item for item in cache.raw_data if item['file_path'] != main_path] - await cache.resort() - - # update hash index - self.scanner._hash_index.remove_by_path(main_path) - - # Delete optional files - for pattern in patterns[1:]: - path = os.path.join(target_dir, pattern) - if os.path.exists(path): - try: - os.remove(path) - deleted.append(pattern) - except Exception as e: - logger.warning(f"Failed to delete {pattern}: {e}") - - return deleted - async def _read_preview_file(self, reader) -> tuple[bytes, str]: """Read preview file and content type from multipart request""" field = await reader.next() @@ -345,66 +298,6 @@ class ApiRoutes: except Exception as e: logger.error(f"Error updating metadata: {e}") - async def _load_local_metadata(self, metadata_path: str) -> Dict: - """Load local metadata file""" - if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception as e: - logger.error(f"Error loading metadata from {metadata_path}: {e}") - return {} - - async def _handle_not_found_on_civitai(self, metadata_path: str, local_metadata: Dict) -> web.Response: - """Handle case when model is not found on CivitAI""" - local_metadata['from_civitai'] = False - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - return web.json_response( - {"success": False, "error": "Not found on CivitAI"}, - status=404 - ) - - async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict, - civitai_metadata: Dict, client: CivitaiClient) -> None: - """Update local metadata with CivitAI data""" - local_metadata['civitai'] = civitai_metadata - - # Update model name if available - if 'model' in civitai_metadata: - if civitai_metadata.get('model', {}).get('name'): - local_metadata['model_name'] = civitai_metadata['model']['name'] - - # Fetch additional model metadata (description and tags) if we have model ID - model_id = civitai_metadata['modelId'] - if model_id: - model_metadata, _ = await client.get_model_metadata(str(model_id)) - if model_metadata: - local_metadata['modelDescription'] = model_metadata.get('description', '') - local_metadata['tags'] = model_metadata.get('tags', []) - - # Update base model - local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel')) - - # Update preview if needed - if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): - first_preview = next((img for img in civitai_metadata.get('images', [])), None) - if first_preview: - preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1] - base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] - preview_filename = base_name + preview_ext - preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) - - if await client.download_preview_image(first_preview['url'], preview_path): - local_metadata['preview_url'] = preview_path.replace(os.sep, '/') - local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) - - # Save updated metadata - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - - await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata) - async def fetch_all_civitai(self, request: web.Request) -> web.Response: """Fetch CivitAI metadata for all loras in the background""" try: @@ -414,14 +307,14 @@ class ApiRoutes: success = 0 needs_resort = False - # 准备要处理的 loras + # Prepare loras to process to_process = [ lora for lora in cache.raw_data - if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai') # TODO: for lora not from CivitAI but added traineWords + if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai', True) # TODO: for lora not from CivitAI but added traineWords ] total_to_process = len(to_process) - # 发送初始进度 + # Send initial progress await ws_manager.broadcast({ 'status': 'started', 'total': total_to_process, @@ -432,10 +325,11 @@ class ApiRoutes: for lora in to_process: try: original_name = lora.get('model_name') - if await self._fetch_and_update_single_lora( + if await ModelRouteUtils.fetch_and_update_model( sha256=lora['sha256'], file_path=lora['file_path'], - lora=lora + model_data=lora, + update_cache_func=self.scanner.update_single_model_cache ): success += 1 if original_name != lora.get('model_name'): @@ -443,7 +337,7 @@ class ApiRoutes: processed += 1 - # 每处理一个就发送进度更新 + # Send progress update await ws_manager.broadcast({ 'status': 'processing', 'total': total_to_process, @@ -458,7 +352,7 @@ class ApiRoutes: if needs_resort: await cache.resort(name_only=True) - # 发送完成消息 + # Send completion message await ws_manager.broadcast({ 'status': 'completed', 'total': total_to_process, @@ -472,7 +366,7 @@ class ApiRoutes: }) except Exception as e: - # 发送错误消息 + # Send error message await ws_manager.broadcast({ 'status': 'error', 'error': str(e) @@ -480,58 +374,6 @@ class ApiRoutes: logger.error(f"Error in fetch_all_civitai: {e}") return web.Response(text=str(e), status=500) - async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool: - """Fetch and update metadata for a single lora without sorting - - Args: - sha256: SHA256 hash of the lora file - file_path: Path to the lora file - lora: The lora object in cache to update - - Returns: - bool: True if successful, False otherwise - """ - client = CivitaiClient() - try: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - - # Check if model is from CivitAI - local_metadata = await self._load_local_metadata(metadata_path) - - # Fetch metadata - civitai_metadata = await client.get_model_by_hash(sha256) - if not civitai_metadata: - # Mark as not from CivitAI if not found - local_metadata['from_civitai'] = False - lora['from_civitai'] = False - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - return False - - # Update metadata - await self._update_model_metadata( - metadata_path, - local_metadata, - civitai_metadata, - client - ) - - # Update cache object directly - lora.update({ - 'model_name': local_metadata.get('model_name'), - 'preview_url': local_metadata.get('preview_url'), - 'from_civitai': True, - 'civitai': civitai_metadata - }) - - return True - - except Exception as e: - logger.error(f"Error fetching CivitAI data: {e}") - return False - finally: - await client.close() - async def get_lora_roots(self, request: web.Request) -> web.Response: """Get all configured LoRA root directories""" return web.json_response({ @@ -669,7 +511,7 @@ class ApiRoutes: return web.json_response({'success': True}) except Exception as e: - logger.error(f"Error updating settings: {e}", exc_info=True) # 添加 exc_info=True 以获取完整堆栈 + logger.error(f"Error updating settings: {e}", exc_info=True) return web.Response(status=500, text=str(e)) async def move_model(self, request: web.Request) -> web.Response: @@ -731,11 +573,7 @@ class ApiRoutes: metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' # Load existing metadata - if os.path.exists(metadata_path): - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - else: - metadata = {} + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) # Handle nested updates (for civitai.trainedWords) for key, value in metadata_updates.items(): @@ -798,7 +636,10 @@ class ApiRoutes: except Exception as e: logger.error(f"Error getting lora preview URL: {e}", exc_info=True) - return web.Response(text=str(e), status=500) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) async def get_lora_civitai_url(self, request: web.Request) -> web.Response: """Get the Civitai URL for a LoRA file""" @@ -921,14 +762,9 @@ class ApiRoutes: tags = [] if file_path: metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - description = metadata.get('modelDescription') - tags = metadata.get('tags', []) - except Exception as e: - logger.error(f"Error loading metadata from {metadata_path}: {e}") + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + description = metadata.get('modelDescription') + tags = metadata.get('tags', []) # If description is not in metadata, fetch from CivitAI if not description: @@ -943,16 +779,14 @@ class ApiRoutes: if file_path: try: metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - if os.path.exists(metadata_path): - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - - metadata['modelDescription'] = description - metadata['tags'] = tags - - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(metadata, f, indent=2, ensure_ascii=False) - logger.info(f"Saved model metadata to file for {file_path}") + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + metadata['modelDescription'] = description + metadata['tags'] = tags + + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + logger.info(f"Saved model metadata to file for {file_path}") except Exception as e: logger.error(f"Error saving model metadata: {e}") @@ -1018,12 +852,6 @@ class ApiRoutes: 'error': str(e) }, status=500) - def get_multipart_ext(self, filename): - parts = filename.split(".") - if len(parts) > 2: # 如果包含多级扩展名 - return "." + ".".join(parts[-2:]) # 取最后两部分,如 ".metadata.json" - return os.path.splitext(filename)[1] # 否则取普通扩展名,如 ".safetensors" - async def rename_lora(self, request: web.Request) -> web.Response: """Handle renaming a LoRA file and its associated files""" try: @@ -1063,7 +891,7 @@ class ApiRoutes: f"{old_file_name}.metadata.json", ] - # 添加所有预览文件扩展名 + # Add all preview file extensions for ext in PREVIEW_EXTENSIONS: patterns.append(f"{old_file_name}{ext}") @@ -1080,12 +908,8 @@ class ApiRoutes: metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json") if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - hash_value = metadata.get('sha256') - except Exception as e: - logger.error(f"Error loading metadata for rename: {e}") + metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + hash_value = metadata.get('sha256') # Rename all files renamed_files = [] @@ -1101,7 +925,7 @@ class ApiRoutes: for old_path, pattern in existing_files: # Get the file extension like .safetensors or .metadata.json - ext = self.get_multipart_ext(pattern) + ext = ModelRouteUtils.get_multipart_ext(pattern) # Create the new path new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') @@ -1123,7 +947,7 @@ class ApiRoutes: # Update preview_url if it exists if 'preview_url' in metadata and metadata['preview_url']: old_preview = metadata['preview_url'] - ext = self.get_multipart_ext(old_preview) + ext = ModelRouteUtils.get_multipart_ext(old_preview) new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') metadata['preview_url'] = new_preview diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index fcd47c60..af16f732 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -1,15 +1,10 @@ import os import json -import asyncio -from typing import Dict -import aiohttp import jinja2 from aiohttp import web import logging -from datetime import datetime - -from ..utils.model_utils import determine_base_model +from ..utils.routes_common import ModelRouteUtils from ..utils.constants import NSFW_LEVELS from ..services.civitai_client import CivitaiClient from ..services.websocket_manager import ws_manager @@ -259,21 +254,9 @@ class CheckpointsRoutes: "from_civitai": checkpoint.get("from_civitai", True), "notes": checkpoint.get("notes", ""), "model_type": checkpoint.get("model_type", "checkpoint"), - "civitai": self._filter_civitai_data(checkpoint.get("civitai", {})) + "civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {})) } - def _filter_civitai_data(self, data): - """Filter relevant fields from CivitAI data""" - if not data: - return {} - - fields = [ - "id", "modelId", "name", "createdAt", "updatedAt", - "publishedAt", "trainedWords", "baseModel", "description", - "model", "images" - ] - return {k: data[k] for k in fields if k in data} - async def fetch_all_civitai(self, request: web.Request) -> web.Response: """Fetch CivitAI metadata for all checkpoints in the background""" try: @@ -302,10 +285,11 @@ class CheckpointsRoutes: for cp in to_process: try: original_name = cp.get('model_name') - if await self._fetch_and_update_single_checkpoint( + if await ModelRouteUtils.fetch_and_update_model( sha256=cp['sha256'], file_path=cp['file_path'], - checkpoint=cp + model_data=cp, + update_cache_func=self.scanner.update_single_model_cache ): success += 1 if original_name != cp.get('model_name'): @@ -350,99 +334,6 @@ class CheckpointsRoutes: logger.error(f"Error in fetch_all_civitai for checkpoints: {e}") return web.Response(text=str(e), status=500) - async def _fetch_and_update_single_checkpoint(self, sha256: str, file_path: str, checkpoint: dict) -> bool: - """Fetch and update metadata for a single checkpoint without sorting""" - client = CivitaiClient() - try: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - - # Load local metadata - local_metadata = self._load_local_metadata(metadata_path) - - # Fetch metadata from Civitai - civitai_metadata = await client.get_model_by_hash(sha256) - if not civitai_metadata: - # Mark as not from CivitAI if not found - local_metadata['from_civitai'] = False - checkpoint['from_civitai'] = False - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - return False - - # Update metadata with Civitai data - await self._update_model_metadata( - metadata_path, - local_metadata, - civitai_metadata, - client - ) - - # Update cache object directly - checkpoint.update({ - 'model_name': local_metadata.get('model_name'), - 'preview_url': local_metadata.get('preview_url'), - 'from_civitai': True, - 'civitai': civitai_metadata - }) - - return True - - except Exception as e: - logger.error(f"Error fetching CivitAI data for checkpoint: {e}") - return False - finally: - await client.close() - - def _load_local_metadata(self, metadata_path: str) -> Dict: - """Load local metadata file""" - if os.path.exists(metadata_path): - try: - with open(metadata_path, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception as e: - logger.error(f"Error loading metadata from {metadata_path}: {e}") - return {} - - async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict, - civitai_metadata: Dict, client: CivitaiClient) -> None: - """Update local metadata with CivitAI data""" - local_metadata['civitai'] = civitai_metadata - - # Update model name if available - if 'model' in civitai_metadata: - if civitai_metadata.get('model', {}).get('name'): - local_metadata['model_name'] = civitai_metadata['model']['name'] - - # Fetch additional model metadata (description and tags) if we have model ID - model_id = civitai_metadata['modelId'] - if model_id: - model_metadata, _ = await client.get_model_metadata(str(model_id)) - if model_metadata: - local_metadata['modelDescription'] = model_metadata.get('description', '') - local_metadata['tags'] = model_metadata.get('tags', []) - - # Update base model - local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel')) - - # Update preview if needed - if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): - first_preview = next((img for img in civitai_metadata.get('images', [])), None) - if first_preview: - preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1] - base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] - preview_filename = base_name + preview_ext - preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) - - if await client.download_preview_image(first_preview['url'], preview_path): - local_metadata['preview_url'] = preview_path.replace(os.sep, '/') - local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) - - # Save updated metadata - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(local_metadata, f, indent=2, ensure_ascii=False) - - await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata) - async def get_top_tags(self, request: web.Request) -> web.Response: """Handle request for top tags sorted by frequency""" try: diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py new file mode 100644 index 00000000..69ea63a1 --- /dev/null +++ b/py/utils/routes_common.py @@ -0,0 +1,252 @@ +import os +import json +import logging +from typing import Dict, List, Callable, Awaitable + +from .model_utils import determine_base_model +from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH +from ..config import config +from ..services.civitai_client import CivitaiClient +from ..utils.exif_utils import ExifUtils + +logger = logging.getLogger(__name__) + + +class ModelRouteUtils: + """Shared utilities for model routes (LoRAs, Checkpoints, etc.)""" + + @staticmethod + async def load_local_metadata(metadata_path: str) -> Dict: + """Load local metadata file""" + if os.path.exists(metadata_path): + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.error(f"Error loading metadata from {metadata_path}: {e}") + return {} + + @staticmethod + async def handle_not_found_on_civitai(metadata_path: str, local_metadata: Dict) -> None: + """Handle case when model is not found on CivitAI""" + local_metadata['from_civitai'] = False + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + + @staticmethod + async def update_model_metadata(metadata_path: str, local_metadata: Dict, + civitai_metadata: Dict, client: CivitaiClient) -> None: + """Update local metadata with CivitAI data""" + local_metadata['civitai'] = civitai_metadata + + # Update model name if available + if 'model' in civitai_metadata: + if civitai_metadata.get('model', {}).get('name'): + local_metadata['model_name'] = civitai_metadata['model']['name'] + + # Fetch additional model metadata (description and tags) if we have model ID + model_id = civitai_metadata['modelId'] + if model_id: + model_metadata, _ = await client.get_model_metadata(str(model_id)) + if model_metadata: + local_metadata['modelDescription'] = model_metadata.get('description', '') + local_metadata['tags'] = model_metadata.get('tags', []) + + # Update base model + local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel')) + + # Update preview if needed + if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): + first_preview = next((img for img in civitai_metadata.get('images', [])), None) + if first_preview: + # Determine if content is video or image + is_video = first_preview['type'] == 'video' + + if is_video: + # For videos use .mp4 extension + preview_ext = '.mp4' + else: + # For images use .webp extension + preview_ext = '.webp' + + base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] + preview_filename = base_name + preview_ext + preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) + + if is_video: + # Download video as is + if await client.download_preview_image(first_preview['url'], preview_path): + local_metadata['preview_url'] = preview_path.replace(os.sep, '/') + local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + else: + # For images, download and then optimize to WebP + temp_path = preview_path + ".temp" + if await client.download_preview_image(first_preview['url'], temp_path): + try: + # Read the downloaded image + with open(temp_path, 'rb') as f: + image_data = f.read() + + # Optimize and convert to WebP + optimized_data, _ = ExifUtils.optimize_image( + image_data=image_data, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=True + ) + + # Save the optimized WebP image + with open(preview_path, 'wb') as f: + f.write(optimized_data) + + # Update metadata + local_metadata['preview_url'] = preview_path.replace(os.sep, '/') + local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + + # Remove the temporary file + if os.path.exists(temp_path): + os.remove(temp_path) + + except Exception as e: + logger.error(f"Error optimizing preview image: {e}") + # If optimization fails, try to use the downloaded image directly + if os.path.exists(temp_path): + os.rename(temp_path, preview_path) + local_metadata['preview_url'] = preview_path.replace(os.sep, '/') + local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + + # Save updated metadata + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + + @staticmethod + async def fetch_and_update_model( + sha256: str, + file_path: str, + model_data: dict, + update_cache_func: Callable[[str, str, Dict], Awaitable[bool]] + ) -> bool: + """Fetch and update metadata for a single model + + Args: + sha256: SHA256 hash of the model file + file_path: Path to the model file + model_data: The model object in cache to update + update_cache_func: Function to update the cache with new metadata + + Returns: + bool: True if successful, False otherwise + """ + client = CivitaiClient() + try: + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + + # Check if model metadata exists + local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path) + + # Fetch metadata from Civitai + civitai_metadata = await client.get_model_by_hash(sha256) + if not civitai_metadata: + # Mark as not from CivitAI if not found + local_metadata['from_civitai'] = False + model_data['from_civitai'] = False + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + return False + + # Update metadata + await ModelRouteUtils.update_model_metadata( + metadata_path, + local_metadata, + civitai_metadata, + client + ) + + # Update cache object directly + model_data.update({ + 'model_name': local_metadata.get('model_name'), + 'preview_url': local_metadata.get('preview_url'), + 'from_civitai': True, + 'civitai': civitai_metadata + }) + + # Update cache using the provided function + await update_cache_func(file_path, file_path, local_metadata) + + return True + + except Exception as e: + logger.error(f"Error fetching CivitAI data: {e}") + return False + finally: + await client.close() + + @staticmethod + def filter_civitai_data(data: Dict) -> Dict: + """Filter relevant fields from CivitAI data""" + if not data: + return {} + + fields = [ + "id", "modelId", "name", "createdAt", "updatedAt", + "publishedAt", "trainedWords", "baseModel", "description", + "model", "images" + ] + return {k: data[k] for k in fields if k in data} + + @staticmethod + async def delete_model_files(target_dir: str, file_name: str, file_monitor=None) -> List[str]: + """Delete model and associated files + + Args: + target_dir: Directory containing the model files + file_name: Base name of the model file without extension + file_monitor: Optional file monitor to ignore delete events + + Returns: + List of deleted file paths + """ + patterns = [ + f"{file_name}.safetensors", # Required + f"{file_name}.metadata.json", + ] + + # Add all preview file extensions + for ext in PREVIEW_EXTENSIONS: + patterns.append(f"{file_name}{ext}") + + deleted = [] + main_file = patterns[0] + main_path = os.path.join(target_dir, main_file).replace(os.sep, '/') + + if os.path.exists(main_path): + # Notify file monitor to ignore delete event if available + if file_monitor: + file_monitor.handler.add_ignore_path(main_path, 0) + + # Delete file + os.remove(main_path) + deleted.append(main_path) + else: + logger.warning(f"Model file not found: {main_file}") + + # Delete optional files + for pattern in patterns[1:]: + path = os.path.join(target_dir, pattern) + if os.path.exists(path): + try: + os.remove(path) + deleted.append(pattern) + except Exception as e: + logger.warning(f"Failed to delete {pattern}: {e}") + + return deleted + + @staticmethod + def get_multipart_ext(filename): + """Get extension that may have multiple parts like .metadata.json""" + parts = filename.split(".") + if len(parts) > 2: # If contains multi-part extension + return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json" + return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors" \ No newline at end of file diff --git a/static/js/components/LoraCard.js b/static/js/components/LoraCard.js index e6534081..86737b59 100644 --- a/static/js/components/LoraCard.js +++ b/static/js/components/LoraCard.js @@ -1,4 +1,4 @@ -import { showToast } from '../utils/uiHelpers.js'; +import { showToast, openCivitai } from '../utils/uiHelpers.js'; import { state } from '../state/index.js'; import { showLoraModal } from './loraModal/index.js'; import { bulkManager } from '../managers/BulkManager.js';