Add Civitai metadata fetching functionality for checkpoints

- Implement fetchCivitai API method to retrieve metadata from Civitai.
- Enhance CheckpointsControls to include fetch from Civitai functionality.
- Update PageControls to register fetch from Civitai event listener for both LoRAs and Checkpoints.
This commit is contained in:
Will Miao
2025-04-10 21:07:17 +08:00
parent 152ec0da0d
commit 131c3cc324
4 changed files with 271 additions and 12 deletions

View File

@@ -1,12 +1,18 @@
import os
import json
import asyncio
from typing import Dict
import aiohttp
import jinja2
from aiohttp import web
import logging
from datetime import datetime
from ..utils.model_utils import determine_base_model
from ..utils.constants import NSFW_LEVELS
from ..services.civitai_client import CivitaiClient
from ..services.websocket_manager import ws_manager
from ..services.checkpoint_scanner import CheckpointScanner
from ..config import config
from ..services.settings_manager import settings
@@ -28,6 +34,7 @@ class CheckpointsRoutes:
"""Register routes with the aiohttp app"""
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
app.router.add_get('/api/checkpoints', self.get_checkpoints)
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
@@ -267,6 +274,175 @@ class CheckpointsRoutes:
]
return {k: data[k] for k in fields if k in data}
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all checkpoints in the background"""
try:
cache = await self.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
# Prepare checkpoints to process
to_process = [
cp for cp in cache.raw_data
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
]
total_to_process = len(to_process)
# Send initial progress
await ws_manager.broadcast({
'status': 'started',
'total': total_to_process,
'processed': 0,
'success': 0
})
# Process each checkpoint
for cp in to_process:
try:
original_name = cp.get('model_name')
if await self._fetch_and_update_single_checkpoint(
sha256=cp['sha256'],
file_path=cp['file_path'],
checkpoint=cp
):
success += 1
if original_name != cp.get('model_name'):
needs_resort = True
processed += 1
# Send progress update
await ws_manager.broadcast({
'status': 'processing',
'total': total_to_process,
'processed': processed,
'success': success,
'current_name': cp.get('model_name', 'Unknown')
})
except Exception as e:
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
if needs_resort:
await cache.resort(name_only=True)
# Send completion message
await ws_manager.broadcast({
'status': 'completed',
'total': total_to_process,
'processed': processed,
'success': success
})
return web.json_response({
"success": True,
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
})
except Exception as e:
# Send error message
await ws_manager.broadcast({
'status': 'error',
'error': str(e)
})
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
return web.Response(text=str(e), status=500)
async def _fetch_and_update_single_checkpoint(self, sha256: str, file_path: str, checkpoint: dict) -> bool:
"""Fetch and update metadata for a single checkpoint without sorting"""
client = CivitaiClient()
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load local metadata
local_metadata = self._load_local_metadata(metadata_path)
# Fetch metadata from Civitai
civitai_metadata = await client.get_model_by_hash(sha256)
if not civitai_metadata:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
checkpoint['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return False
# Update metadata with Civitai data
await self._update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
client
)
# Update cache object directly
checkpoint.update({
'model_name': local_metadata.get('model_name'),
'preview_url': local_metadata.get('preview_url'),
'from_civitai': True,
'civitai': civitai_metadata
})
return True
except Exception as e:
logger.error(f"Error fetching CivitAI data for checkpoint: {e}")
return False
finally:
await client.close()
def _load_local_metadata(self, metadata_path: str) -> Dict:
"""Load local metadata file"""
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
return {}
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
civitai_metadata: Dict, client: CivitaiClient) -> None:
"""Update local metadata with CivitAI data"""
local_metadata['civitai'] = civitai_metadata
# Update model name if available
if 'model' in civitai_metadata:
if civitai_metadata.get('model', {}).get('name'):
local_metadata['model_name'] = civitai_metadata['model']['name']
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata, _ = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Update base model
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
# Update preview if needed
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
if first_preview:
preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_filename = base_name + preview_ext
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
if await client.download_preview_image(first_preview['url'], preview_path):
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:

View File

@@ -249,4 +249,82 @@ async function _uploadPreview(filePath, file) {
console.error('Error updating preview:', error);
showToast(`Failed to update preview: ${error.message}`, 'error');
}
}
// Fetch metadata from Civitai for checkpoints
export async function fetchCivitai() {
let ws = null;
await state.loadingManager.showWithProgress(async (loading) => {
try {
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
const operationComplete = new Promise((resolve, reject) => {
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
switch(data.status) {
case 'started':
loading.setStatus('Starting metadata fetch...');
break;
case 'processing':
const percent = ((data.processed / data.total) * 100).toFixed(1);
loading.setProgress(percent);
loading.setStatus(
`Processing (${data.processed}/${data.total}) ${data.current_name}`
);
break;
case 'completed':
loading.setProgress(100);
loading.setStatus(
`Completed: Updated ${data.success} of ${data.processed} checkpoints`
);
resolve();
break;
case 'error':
reject(new Error(data.error));
break;
}
};
ws.onerror = (error) => {
reject(new Error('WebSocket error: ' + error.message));
};
});
await new Promise((resolve, reject) => {
ws.onopen = resolve;
ws.onerror = reject;
});
const response = await fetch('/api/checkpoints/fetch-all-civitai', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ model_type: 'checkpoint' }) // Specify we're fetching checkpoint metadata
});
if (!response.ok) {
throw new Error('Failed to fetch metadata');
}
await operationComplete;
await resetAndReload();
} catch (error) {
console.error('Error fetching metadata:', error);
showToast('Failed to fetch metadata: ' + error.message, 'error');
} finally {
if (ws) {
ws.close();
}
}
}, {
initialMessage: 'Connecting...',
completionMessage: 'Metadata update complete'
});
}

View File

@@ -1,6 +1,6 @@
// CheckpointsControls.js - Specific implementation for the Checkpoints page
import { PageControls } from './PageControls.js';
import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints } from '../../api/checkpointApi.js';
import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js';
import { showToast } from '../../utils/uiHelpers.js';
/**
@@ -33,6 +33,11 @@ export class CheckpointsControls extends PageControls {
return await refreshCheckpoints();
},
// Add fetch from Civitai functionality for checkpoints
fetchFromCivitai: async () => {
return await fetchCivitai();
},
// No clearCustomFilter implementation is needed for checkpoints
// as custom filters are currently only used for LoRAs
clearCustomFilter: async () => {

View File

@@ -97,20 +97,20 @@ export class PageControls {
* Initialize page-specific event listeners
*/
initPageSpecificListeners() {
// Fetch from Civitai button - available for both loras and checkpoints
const fetchButton = document.querySelector('[data-action="fetch"]');
if (fetchButton) {
fetchButton.addEventListener('click', () => this.fetchFromCivitai());
}
if (this.pageType === 'loras') {
// Fetch from Civitai button
const fetchButton = document.querySelector('[data-action="fetch"]');
if (fetchButton) {
fetchButton.addEventListener('click', () => this.fetchFromCivitai());
}
// Download button
// Download button - LoRAs only
const downloadButton = document.querySelector('[data-action="download"]');
if (downloadButton) {
downloadButton.addEventListener('click', () => this.showDownloadModal());
}
// Bulk operations button
// Bulk operations button - LoRAs only
const bulkButton = document.querySelector('[data-action="bulk"]');
if (bulkButton) {
bulkButton.addEventListener('click', () => this.toggleBulkMode());
@@ -332,11 +332,11 @@ export class PageControls {
}
/**
* Fetch metadata from Civitai (LoRAs only)
* Fetch metadata from Civitai (available for both LoRAs and Checkpoints)
*/
async fetchFromCivitai() {
if (this.pageType !== 'loras' || !this.api) {
console.error('Fetch from Civitai is only available for LoRAs');
if (!this.api) {
console.error('API methods not registered');
return;
}