Add Civitai metadata fetching functionality for checkpoints

- Implement fetchCivitai API method to retrieve metadata from Civitai.
- Enhance CheckpointsControls to include fetch from Civitai functionality.
- Update PageControls to register fetch from Civitai event listener for both LoRAs and Checkpoints.
This commit is contained in:
Will Miao
2025-04-10 21:07:17 +08:00
parent 152ec0da0d
commit 131c3cc324
4 changed files with 271 additions and 12 deletions

View File

@@ -1,12 +1,18 @@
import os import os
import json import json
import asyncio import asyncio
from typing import Dict
import aiohttp import aiohttp
import jinja2 import jinja2
from aiohttp import web from aiohttp import web
import logging import logging
from datetime import datetime from datetime import datetime
from ..utils.model_utils import determine_base_model
from ..utils.constants import NSFW_LEVELS
from ..services.civitai_client import CivitaiClient
from ..services.websocket_manager import ws_manager
from ..services.checkpoint_scanner import CheckpointScanner from ..services.checkpoint_scanner import CheckpointScanner
from ..config import config from ..config import config
from ..services.settings_manager import settings from ..services.settings_manager import settings
@@ -28,6 +34,7 @@ class CheckpointsRoutes:
"""Register routes with the aiohttp app""" """Register routes with the aiohttp app"""
app.router.add_get('/checkpoints', self.handle_checkpoints_page) app.router.add_get('/checkpoints', self.handle_checkpoints_page)
app.router.add_get('/api/checkpoints', self.get_checkpoints) app.router.add_get('/api/checkpoints', self.get_checkpoints)
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
app.router.add_get('/api/checkpoints/base-models', self.get_base_models) app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags) app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
@@ -267,6 +274,175 @@ class CheckpointsRoutes:
] ]
return {k: data[k] for k in fields if k in data} return {k: data[k] for k in fields if k in data}
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all checkpoints in the background"""
try:
cache = await self.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
# Prepare checkpoints to process
to_process = [
cp for cp in cache.raw_data
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
]
total_to_process = len(to_process)
# Send initial progress
await ws_manager.broadcast({
'status': 'started',
'total': total_to_process,
'processed': 0,
'success': 0
})
# Process each checkpoint
for cp in to_process:
try:
original_name = cp.get('model_name')
if await self._fetch_and_update_single_checkpoint(
sha256=cp['sha256'],
file_path=cp['file_path'],
checkpoint=cp
):
success += 1
if original_name != cp.get('model_name'):
needs_resort = True
processed += 1
# Send progress update
await ws_manager.broadcast({
'status': 'processing',
'total': total_to_process,
'processed': processed,
'success': success,
'current_name': cp.get('model_name', 'Unknown')
})
except Exception as e:
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
if needs_resort:
await cache.resort(name_only=True)
# Send completion message
await ws_manager.broadcast({
'status': 'completed',
'total': total_to_process,
'processed': processed,
'success': success
})
return web.json_response({
"success": True,
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
})
except Exception as e:
# Send error message
await ws_manager.broadcast({
'status': 'error',
'error': str(e)
})
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
return web.Response(text=str(e), status=500)
async def _fetch_and_update_single_checkpoint(self, sha256: str, file_path: str, checkpoint: dict) -> bool:
"""Fetch and update metadata for a single checkpoint without sorting"""
client = CivitaiClient()
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load local metadata
local_metadata = self._load_local_metadata(metadata_path)
# Fetch metadata from Civitai
civitai_metadata = await client.get_model_by_hash(sha256)
if not civitai_metadata:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
checkpoint['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return False
# Update metadata with Civitai data
await self._update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
client
)
# Update cache object directly
checkpoint.update({
'model_name': local_metadata.get('model_name'),
'preview_url': local_metadata.get('preview_url'),
'from_civitai': True,
'civitai': civitai_metadata
})
return True
except Exception as e:
logger.error(f"Error fetching CivitAI data for checkpoint: {e}")
return False
finally:
await client.close()
def _load_local_metadata(self, metadata_path: str) -> Dict:
"""Load local metadata file"""
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
return {}
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
civitai_metadata: Dict, client: CivitaiClient) -> None:
"""Update local metadata with CivitAI data"""
local_metadata['civitai'] = civitai_metadata
# Update model name if available
if 'model' in civitai_metadata:
if civitai_metadata.get('model', {}).get('name'):
local_metadata['model_name'] = civitai_metadata['model']['name']
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata, _ = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Update base model
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
# Update preview if needed
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
if first_preview:
preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_filename = base_name + preview_ext
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
if await client.download_preview_image(first_preview['url'], preview_path):
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)
async def get_top_tags(self, request: web.Request) -> web.Response: async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency""" """Handle request for top tags sorted by frequency"""
try: try:

View File

@@ -249,4 +249,82 @@ async function _uploadPreview(filePath, file) {
console.error('Error updating preview:', error); console.error('Error updating preview:', error);
showToast(`Failed to update preview: ${error.message}`, 'error'); showToast(`Failed to update preview: ${error.message}`, 'error');
} }
}
// Fetch metadata from Civitai for checkpoints
export async function fetchCivitai() {
let ws = null;
await state.loadingManager.showWithProgress(async (loading) => {
try {
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
const operationComplete = new Promise((resolve, reject) => {
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
switch(data.status) {
case 'started':
loading.setStatus('Starting metadata fetch...');
break;
case 'processing':
const percent = ((data.processed / data.total) * 100).toFixed(1);
loading.setProgress(percent);
loading.setStatus(
`Processing (${data.processed}/${data.total}) ${data.current_name}`
);
break;
case 'completed':
loading.setProgress(100);
loading.setStatus(
`Completed: Updated ${data.success} of ${data.processed} checkpoints`
);
resolve();
break;
case 'error':
reject(new Error(data.error));
break;
}
};
ws.onerror = (error) => {
reject(new Error('WebSocket error: ' + error.message));
};
});
await new Promise((resolve, reject) => {
ws.onopen = resolve;
ws.onerror = reject;
});
const response = await fetch('/api/checkpoints/fetch-all-civitai', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ model_type: 'checkpoint' }) // Specify we're fetching checkpoint metadata
});
if (!response.ok) {
throw new Error('Failed to fetch metadata');
}
await operationComplete;
await resetAndReload();
} catch (error) {
console.error('Error fetching metadata:', error);
showToast('Failed to fetch metadata: ' + error.message, 'error');
} finally {
if (ws) {
ws.close();
}
}
}, {
initialMessage: 'Connecting...',
completionMessage: 'Metadata update complete'
});
} }

View File

@@ -1,6 +1,6 @@
// CheckpointsControls.js - Specific implementation for the Checkpoints page // CheckpointsControls.js - Specific implementation for the Checkpoints page
import { PageControls } from './PageControls.js'; import { PageControls } from './PageControls.js';
import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints } from '../../api/checkpointApi.js'; import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js';
import { showToast } from '../../utils/uiHelpers.js'; import { showToast } from '../../utils/uiHelpers.js';
/** /**
@@ -33,6 +33,11 @@ export class CheckpointsControls extends PageControls {
return await refreshCheckpoints(); return await refreshCheckpoints();
}, },
// Add fetch from Civitai functionality for checkpoints
fetchFromCivitai: async () => {
return await fetchCivitai();
},
// No clearCustomFilter implementation is needed for checkpoints // No clearCustomFilter implementation is needed for checkpoints
// as custom filters are currently only used for LoRAs // as custom filters are currently only used for LoRAs
clearCustomFilter: async () => { clearCustomFilter: async () => {

View File

@@ -97,20 +97,20 @@ export class PageControls {
* Initialize page-specific event listeners * Initialize page-specific event listeners
*/ */
initPageSpecificListeners() { initPageSpecificListeners() {
// Fetch from Civitai button - available for both loras and checkpoints
const fetchButton = document.querySelector('[data-action="fetch"]');
if (fetchButton) {
fetchButton.addEventListener('click', () => this.fetchFromCivitai());
}
if (this.pageType === 'loras') { if (this.pageType === 'loras') {
// Fetch from Civitai button // Download button - LoRAs only
const fetchButton = document.querySelector('[data-action="fetch"]');
if (fetchButton) {
fetchButton.addEventListener('click', () => this.fetchFromCivitai());
}
// Download button
const downloadButton = document.querySelector('[data-action="download"]'); const downloadButton = document.querySelector('[data-action="download"]');
if (downloadButton) { if (downloadButton) {
downloadButton.addEventListener('click', () => this.showDownloadModal()); downloadButton.addEventListener('click', () => this.showDownloadModal());
} }
// Bulk operations button // Bulk operations button - LoRAs only
const bulkButton = document.querySelector('[data-action="bulk"]'); const bulkButton = document.querySelector('[data-action="bulk"]');
if (bulkButton) { if (bulkButton) {
bulkButton.addEventListener('click', () => this.toggleBulkMode()); bulkButton.addEventListener('click', () => this.toggleBulkMode());
@@ -332,11 +332,11 @@ export class PageControls {
} }
/** /**
* Fetch metadata from Civitai (LoRAs only) * Fetch metadata from Civitai (available for both LoRAs and Checkpoints)
*/ */
async fetchFromCivitai() { async fetchFromCivitai() {
if (this.pageType !== 'loras' || !this.api) { if (!this.api) {
console.error('Fetch from Civitai is only available for LoRAs'); console.error('API methods not registered');
return; return;
} }