diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index efd480dd..fcd47c60 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -1,12 +1,18 @@ import os import json import asyncio +from typing import Dict import aiohttp import jinja2 from aiohttp import web import logging from datetime import datetime +from ..utils.model_utils import determine_base_model + +from ..utils.constants import NSFW_LEVELS +from ..services.civitai_client import CivitaiClient +from ..services.websocket_manager import ws_manager from ..services.checkpoint_scanner import CheckpointScanner from ..config import config from ..services.settings_manager import settings @@ -28,6 +34,7 @@ class CheckpointsRoutes: """Register routes with the aiohttp app""" app.router.add_get('/checkpoints', self.handle_checkpoints_page) app.router.add_get('/api/checkpoints', self.get_checkpoints) + app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai) app.router.add_get('/api/checkpoints/base-models', self.get_base_models) app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags) app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) @@ -267,6 +274,175 @@ class CheckpointsRoutes: ] return {k: data[k] for k in fields if k in data} + async def fetch_all_civitai(self, request: web.Request) -> web.Response: + """Fetch CivitAI metadata for all checkpoints in the background""" + try: + cache = await self.scanner.get_cached_data() + total = len(cache.raw_data) + processed = 0 + success = 0 + needs_resort = False + + # Prepare checkpoints to process + to_process = [ + cp for cp in cache.raw_data + if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True) + ] + total_to_process = len(to_process) + + # Send initial progress + await ws_manager.broadcast({ + 'status': 'started', + 'total': total_to_process, + 'processed': 0, + 'success': 0 + }) + + # Process each checkpoint + for cp in to_process: + try: + original_name = cp.get('model_name') + if await self._fetch_and_update_single_checkpoint( + sha256=cp['sha256'], + file_path=cp['file_path'], + checkpoint=cp + ): + success += 1 + if original_name != cp.get('model_name'): + needs_resort = True + + processed += 1 + + # Send progress update + await ws_manager.broadcast({ + 'status': 'processing', + 'total': total_to_process, + 'processed': processed, + 'success': success, + 'current_name': cp.get('model_name', 'Unknown') + }) + + except Exception as e: + logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}") + + if needs_resort: + await cache.resort(name_only=True) + + # Send completion message + await ws_manager.broadcast({ + 'status': 'completed', + 'total': total_to_process, + 'processed': processed, + 'success': success + }) + + return web.json_response({ + "success": True, + "message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})" + }) + + except Exception as e: + # Send error message + await ws_manager.broadcast({ + 'status': 'error', + 'error': str(e) + }) + logger.error(f"Error in fetch_all_civitai for checkpoints: {e}") + return web.Response(text=str(e), status=500) + + async def _fetch_and_update_single_checkpoint(self, sha256: str, file_path: str, checkpoint: dict) -> bool: + """Fetch and update metadata for a single checkpoint without sorting""" + client = CivitaiClient() + try: + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + + # Load local metadata + local_metadata = self._load_local_metadata(metadata_path) + + # Fetch metadata from Civitai + civitai_metadata = await client.get_model_by_hash(sha256) + if not civitai_metadata: + # Mark as not from CivitAI if not found + local_metadata['from_civitai'] = False + checkpoint['from_civitai'] = False + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + return False + + # Update metadata with Civitai data + await self._update_model_metadata( + metadata_path, + local_metadata, + civitai_metadata, + client + ) + + # Update cache object directly + checkpoint.update({ + 'model_name': local_metadata.get('model_name'), + 'preview_url': local_metadata.get('preview_url'), + 'from_civitai': True, + 'civitai': civitai_metadata + }) + + return True + + except Exception as e: + logger.error(f"Error fetching CivitAI data for checkpoint: {e}") + return False + finally: + await client.close() + + def _load_local_metadata(self, metadata_path: str) -> Dict: + """Load local metadata file""" + if os.path.exists(metadata_path): + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.error(f"Error loading metadata from {metadata_path}: {e}") + return {} + + async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict, + civitai_metadata: Dict, client: CivitaiClient) -> None: + """Update local metadata with CivitAI data""" + local_metadata['civitai'] = civitai_metadata + + # Update model name if available + if 'model' in civitai_metadata: + if civitai_metadata.get('model', {}).get('name'): + local_metadata['model_name'] = civitai_metadata['model']['name'] + + # Fetch additional model metadata (description and tags) if we have model ID + model_id = civitai_metadata['modelId'] + if model_id: + model_metadata, _ = await client.get_model_metadata(str(model_id)) + if model_metadata: + local_metadata['modelDescription'] = model_metadata.get('description', '') + local_metadata['tags'] = model_metadata.get('tags', []) + + # Update base model + local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel')) + + # Update preview if needed + if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']): + first_preview = next((img for img in civitai_metadata.get('images', [])), None) + if first_preview: + preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1] + base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] + preview_filename = base_name + preview_ext + preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename) + + if await client.download_preview_image(first_preview['url'], preview_path): + local_metadata['preview_url'] = preview_path.replace(os.sep, '/') + local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0) + + # Save updated metadata + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + + await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata) + async def get_top_tags(self, request: web.Request) -> web.Response: """Handle request for top tags sorted by frequency""" try: diff --git a/static/js/api/checkpointApi.js b/static/js/api/checkpointApi.js index 1dd0c8fe..e0ed5d4f 100644 --- a/static/js/api/checkpointApi.js +++ b/static/js/api/checkpointApi.js @@ -249,4 +249,82 @@ async function _uploadPreview(filePath, file) { console.error('Error updating preview:', error); showToast(`Failed to update preview: ${error.message}`, 'error'); } +} + +// Fetch metadata from Civitai for checkpoints +export async function fetchCivitai() { + let ws = null; + + await state.loadingManager.showWithProgress(async (loading) => { + try { + const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; + const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); + + const operationComplete = new Promise((resolve, reject) => { + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + + switch(data.status) { + case 'started': + loading.setStatus('Starting metadata fetch...'); + break; + + case 'processing': + const percent = ((data.processed / data.total) * 100).toFixed(1); + loading.setProgress(percent); + loading.setStatus( + `Processing (${data.processed}/${data.total}) ${data.current_name}` + ); + break; + + case 'completed': + loading.setProgress(100); + loading.setStatus( + `Completed: Updated ${data.success} of ${data.processed} checkpoints` + ); + resolve(); + break; + + case 'error': + reject(new Error(data.error)); + break; + } + }; + + ws.onerror = (error) => { + reject(new Error('WebSocket error: ' + error.message)); + }; + }); + + await new Promise((resolve, reject) => { + ws.onopen = resolve; + ws.onerror = reject; + }); + + const response = await fetch('/api/checkpoints/fetch-all-civitai', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ model_type: 'checkpoint' }) // Specify we're fetching checkpoint metadata + }); + + if (!response.ok) { + throw new Error('Failed to fetch metadata'); + } + + await operationComplete; + + await resetAndReload(); + + } catch (error) { + console.error('Error fetching metadata:', error); + showToast('Failed to fetch metadata: ' + error.message, 'error'); + } finally { + if (ws) { + ws.close(); + } + } + }, { + initialMessage: 'Connecting...', + completionMessage: 'Metadata update complete' + }); } \ No newline at end of file diff --git a/static/js/components/controls/CheckpointsControls.js b/static/js/components/controls/CheckpointsControls.js index 2dd968e2..44c6104a 100644 --- a/static/js/components/controls/CheckpointsControls.js +++ b/static/js/components/controls/CheckpointsControls.js @@ -1,6 +1,6 @@ // CheckpointsControls.js - Specific implementation for the Checkpoints page import { PageControls } from './PageControls.js'; -import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints } from '../../api/checkpointApi.js'; +import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js'; import { showToast } from '../../utils/uiHelpers.js'; /** @@ -33,6 +33,11 @@ export class CheckpointsControls extends PageControls { return await refreshCheckpoints(); }, + // Add fetch from Civitai functionality for checkpoints + fetchFromCivitai: async () => { + return await fetchCivitai(); + }, + // No clearCustomFilter implementation is needed for checkpoints // as custom filters are currently only used for LoRAs clearCustomFilter: async () => { diff --git a/static/js/components/controls/PageControls.js b/static/js/components/controls/PageControls.js index 43498f12..0bc4f64e 100644 --- a/static/js/components/controls/PageControls.js +++ b/static/js/components/controls/PageControls.js @@ -97,20 +97,20 @@ export class PageControls { * Initialize page-specific event listeners */ initPageSpecificListeners() { + // Fetch from Civitai button - available for both loras and checkpoints + const fetchButton = document.querySelector('[data-action="fetch"]'); + if (fetchButton) { + fetchButton.addEventListener('click', () => this.fetchFromCivitai()); + } + if (this.pageType === 'loras') { - // Fetch from Civitai button - const fetchButton = document.querySelector('[data-action="fetch"]'); - if (fetchButton) { - fetchButton.addEventListener('click', () => this.fetchFromCivitai()); - } - - // Download button + // Download button - LoRAs only const downloadButton = document.querySelector('[data-action="download"]'); if (downloadButton) { downloadButton.addEventListener('click', () => this.showDownloadModal()); } - // Bulk operations button + // Bulk operations button - LoRAs only const bulkButton = document.querySelector('[data-action="bulk"]'); if (bulkButton) { bulkButton.addEventListener('click', () => this.toggleBulkMode()); @@ -332,11 +332,11 @@ export class PageControls { } /** - * Fetch metadata from Civitai (LoRAs only) + * Fetch metadata from Civitai (available for both LoRAs and Checkpoints) */ async fetchFromCivitai() { - if (this.pageType !== 'loras' || !this.api) { - console.error('Fetch from Civitai is only available for LoRAs'); + if (!this.api) { + console.error('API methods not registered'); return; }