mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Add Civitai metadata fetching functionality for checkpoints
- Implement fetchCivitai API method to retrieve metadata from Civitai. - Enhance CheckpointsControls to include fetch from Civitai functionality. - Update PageControls to register fetch from Civitai event listener for both LoRAs and Checkpoints.
This commit is contained in:
@@ -1,12 +1,18 @@
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
import aiohttp
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from ..utils.model_utils import determine_base_model
|
||||
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.checkpoint_scanner import CheckpointScanner
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
@@ -28,6 +34,7 @@ class CheckpointsRoutes:
|
||||
"""Register routes with the aiohttp app"""
|
||||
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
||||
app.router.add_get('/api/checkpoints', self.get_checkpoints)
|
||||
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
|
||||
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
|
||||
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
|
||||
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
|
||||
@@ -267,6 +274,175 @@ class CheckpointsRoutes:
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all checkpoints in the background"""
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
# Prepare checkpoints to process
|
||||
to_process = [
|
||||
cp for cp in cache.raw_data
|
||||
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
# Send initial progress
|
||||
await ws_manager.broadcast({
|
||||
'status': 'started',
|
||||
'total': total_to_process,
|
||||
'processed': 0,
|
||||
'success': 0
|
||||
})
|
||||
|
||||
# Process each checkpoint
|
||||
for cp in to_process:
|
||||
try:
|
||||
original_name = cp.get('model_name')
|
||||
if await self._fetch_and_update_single_checkpoint(
|
||||
sha256=cp['sha256'],
|
||||
file_path=cp['file_path'],
|
||||
checkpoint=cp
|
||||
):
|
||||
success += 1
|
||||
if original_name != cp.get('model_name'):
|
||||
needs_resort = True
|
||||
|
||||
processed += 1
|
||||
|
||||
# Send progress update
|
||||
await ws_manager.broadcast({
|
||||
'status': 'processing',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success,
|
||||
'current_name': cp.get('model_name', 'Unknown')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
# Send completion message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'completed',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Send error message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def _fetch_and_update_single_checkpoint(self, sha256: str, file_path: str, checkpoint: dict) -> bool:
|
||||
"""Fetch and update metadata for a single checkpoint without sorting"""
|
||||
client = CivitaiClient()
|
||||
try:
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load local metadata
|
||||
local_metadata = self._load_local_metadata(metadata_path)
|
||||
|
||||
# Fetch metadata from Civitai
|
||||
civitai_metadata = await client.get_model_by_hash(sha256)
|
||||
if not civitai_metadata:
|
||||
# Mark as not from CivitAI if not found
|
||||
local_metadata['from_civitai'] = False
|
||||
checkpoint['from_civitai'] = False
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
|
||||
return False
|
||||
|
||||
# Update metadata with Civitai data
|
||||
await self._update_model_metadata(
|
||||
metadata_path,
|
||||
local_metadata,
|
||||
civitai_metadata,
|
||||
client
|
||||
)
|
||||
|
||||
# Update cache object directly
|
||||
checkpoint.update({
|
||||
'model_name': local_metadata.get('model_name'),
|
||||
'preview_url': local_metadata.get('preview_url'),
|
||||
'from_civitai': True,
|
||||
'civitai': civitai_metadata
|
||||
})
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for checkpoint: {e}")
|
||||
return False
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
def _load_local_metadata(self, metadata_path: str) -> Dict:
|
||||
"""Load local metadata file"""
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading metadata from {metadata_path}: {e}")
|
||||
return {}
|
||||
|
||||
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
|
||||
civitai_metadata: Dict, client: CivitaiClient) -> None:
|
||||
"""Update local metadata with CivitAI data"""
|
||||
local_metadata['civitai'] = civitai_metadata
|
||||
|
||||
# Update model name if available
|
||||
if 'model' in civitai_metadata:
|
||||
if civitai_metadata.get('model', {}).get('name'):
|
||||
local_metadata['model_name'] = civitai_metadata['model']['name']
|
||||
|
||||
# Fetch additional model metadata (description and tags) if we have model ID
|
||||
model_id = civitai_metadata['modelId']
|
||||
if model_id:
|
||||
model_metadata, _ = await client.get_model_metadata(str(model_id))
|
||||
if model_metadata:
|
||||
local_metadata['modelDescription'] = model_metadata.get('description', '')
|
||||
local_metadata['tags'] = model_metadata.get('tags', [])
|
||||
|
||||
# Update base model
|
||||
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
|
||||
|
||||
# Update preview if needed
|
||||
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
|
||||
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
|
||||
if first_preview:
|
||||
preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
|
||||
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||
preview_filename = base_name + preview_ext
|
||||
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
|
||||
|
||||
if await client.download_preview_image(first_preview['url'], preview_path):
|
||||
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
|
||||
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
|
||||
@@ -249,4 +249,82 @@ async function _uploadPreview(filePath, file) {
|
||||
console.error('Error updating preview:', error);
|
||||
showToast(`Failed to update preview: ${error.message}`, 'error');
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch metadata from Civitai for checkpoints
|
||||
export async function fetchCivitai() {
|
||||
let ws = null;
|
||||
|
||||
await state.loadingManager.showWithProgress(async (loading) => {
|
||||
try {
|
||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
|
||||
|
||||
const operationComplete = new Promise((resolve, reject) => {
|
||||
ws.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
|
||||
switch(data.status) {
|
||||
case 'started':
|
||||
loading.setStatus('Starting metadata fetch...');
|
||||
break;
|
||||
|
||||
case 'processing':
|
||||
const percent = ((data.processed / data.total) * 100).toFixed(1);
|
||||
loading.setProgress(percent);
|
||||
loading.setStatus(
|
||||
`Processing (${data.processed}/${data.total}) ${data.current_name}`
|
||||
);
|
||||
break;
|
||||
|
||||
case 'completed':
|
||||
loading.setProgress(100);
|
||||
loading.setStatus(
|
||||
`Completed: Updated ${data.success} of ${data.processed} checkpoints`
|
||||
);
|
||||
resolve();
|
||||
break;
|
||||
|
||||
case 'error':
|
||||
reject(new Error(data.error));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = (error) => {
|
||||
reject(new Error('WebSocket error: ' + error.message));
|
||||
};
|
||||
});
|
||||
|
||||
await new Promise((resolve, reject) => {
|
||||
ws.onopen = resolve;
|
||||
ws.onerror = reject;
|
||||
});
|
||||
|
||||
const response = await fetch('/api/checkpoints/fetch-all-civitai', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ model_type: 'checkpoint' }) // Specify we're fetching checkpoint metadata
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to fetch metadata');
|
||||
}
|
||||
|
||||
await operationComplete;
|
||||
|
||||
await resetAndReload();
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error fetching metadata:', error);
|
||||
showToast('Failed to fetch metadata: ' + error.message, 'error');
|
||||
} finally {
|
||||
if (ws) {
|
||||
ws.close();
|
||||
}
|
||||
}
|
||||
}, {
|
||||
initialMessage: 'Connecting...',
|
||||
completionMessage: 'Metadata update complete'
|
||||
});
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
// CheckpointsControls.js - Specific implementation for the Checkpoints page
|
||||
import { PageControls } from './PageControls.js';
|
||||
import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints } from '../../api/checkpointApi.js';
|
||||
import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js';
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
|
||||
/**
|
||||
@@ -33,6 +33,11 @@ export class CheckpointsControls extends PageControls {
|
||||
return await refreshCheckpoints();
|
||||
},
|
||||
|
||||
// Add fetch from Civitai functionality for checkpoints
|
||||
fetchFromCivitai: async () => {
|
||||
return await fetchCivitai();
|
||||
},
|
||||
|
||||
// No clearCustomFilter implementation is needed for checkpoints
|
||||
// as custom filters are currently only used for LoRAs
|
||||
clearCustomFilter: async () => {
|
||||
|
||||
@@ -97,20 +97,20 @@ export class PageControls {
|
||||
* Initialize page-specific event listeners
|
||||
*/
|
||||
initPageSpecificListeners() {
|
||||
// Fetch from Civitai button - available for both loras and checkpoints
|
||||
const fetchButton = document.querySelector('[data-action="fetch"]');
|
||||
if (fetchButton) {
|
||||
fetchButton.addEventListener('click', () => this.fetchFromCivitai());
|
||||
}
|
||||
|
||||
if (this.pageType === 'loras') {
|
||||
// Fetch from Civitai button
|
||||
const fetchButton = document.querySelector('[data-action="fetch"]');
|
||||
if (fetchButton) {
|
||||
fetchButton.addEventListener('click', () => this.fetchFromCivitai());
|
||||
}
|
||||
|
||||
// Download button
|
||||
// Download button - LoRAs only
|
||||
const downloadButton = document.querySelector('[data-action="download"]');
|
||||
if (downloadButton) {
|
||||
downloadButton.addEventListener('click', () => this.showDownloadModal());
|
||||
}
|
||||
|
||||
// Bulk operations button
|
||||
// Bulk operations button - LoRAs only
|
||||
const bulkButton = document.querySelector('[data-action="bulk"]');
|
||||
if (bulkButton) {
|
||||
bulkButton.addEventListener('click', () => this.toggleBulkMode());
|
||||
@@ -332,11 +332,11 @@ export class PageControls {
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch metadata from Civitai (LoRAs only)
|
||||
* Fetch metadata from Civitai (available for both LoRAs and Checkpoints)
|
||||
*/
|
||||
async fetchFromCivitai() {
|
||||
if (this.pageType !== 'loras' || !this.api) {
|
||||
console.error('Fetch from Civitai is only available for LoRAs');
|
||||
if (!this.api) {
|
||||
console.error('API methods not registered');
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user