mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Add Civitai metadata fetching functionality for checkpoints
- Implement fetchCivitai API method to retrieve metadata from Civitai. - Enhance CheckpointsControls to include fetch from Civitai functionality. - Update PageControls to register fetch from Civitai event listener for both LoRAs and Checkpoints.
This commit is contained in:
@@ -1,12 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Dict
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import jinja2
|
import jinja2
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ..utils.model_utils import determine_base_model
|
||||||
|
|
||||||
|
from ..utils.constants import NSFW_LEVELS
|
||||||
|
from ..services.civitai_client import CivitaiClient
|
||||||
|
from ..services.websocket_manager import ws_manager
|
||||||
from ..services.checkpoint_scanner import CheckpointScanner
|
from ..services.checkpoint_scanner import CheckpointScanner
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import settings
|
||||||
@@ -28,6 +34,7 @@ class CheckpointsRoutes:
|
|||||||
"""Register routes with the aiohttp app"""
|
"""Register routes with the aiohttp app"""
|
||||||
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
||||||
app.router.add_get('/api/checkpoints', self.get_checkpoints)
|
app.router.add_get('/api/checkpoints', self.get_checkpoints)
|
||||||
|
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
|
||||||
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
|
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
|
||||||
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
|
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
|
||||||
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
|
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
|
||||||
@@ -267,6 +274,175 @@ class CheckpointsRoutes:
|
|||||||
]
|
]
|
||||||
return {k: data[k] for k in fields if k in data}
|
return {k: data[k] for k in fields if k in data}
|
||||||
|
|
||||||
|
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||||
|
"""Fetch CivitAI metadata for all checkpoints in the background"""
|
||||||
|
try:
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
total = len(cache.raw_data)
|
||||||
|
processed = 0
|
||||||
|
success = 0
|
||||||
|
needs_resort = False
|
||||||
|
|
||||||
|
# Prepare checkpoints to process
|
||||||
|
to_process = [
|
||||||
|
cp for cp in cache.raw_data
|
||||||
|
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
|
||||||
|
]
|
||||||
|
total_to_process = len(to_process)
|
||||||
|
|
||||||
|
# Send initial progress
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'started',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': 0,
|
||||||
|
'success': 0
|
||||||
|
})
|
||||||
|
|
||||||
|
# Process each checkpoint
|
||||||
|
for cp in to_process:
|
||||||
|
try:
|
||||||
|
original_name = cp.get('model_name')
|
||||||
|
if await self._fetch_and_update_single_checkpoint(
|
||||||
|
sha256=cp['sha256'],
|
||||||
|
file_path=cp['file_path'],
|
||||||
|
checkpoint=cp
|
||||||
|
):
|
||||||
|
success += 1
|
||||||
|
if original_name != cp.get('model_name'):
|
||||||
|
needs_resort = True
|
||||||
|
|
||||||
|
processed += 1
|
||||||
|
|
||||||
|
# Send progress update
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'processing',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': processed,
|
||||||
|
'success': success,
|
||||||
|
'current_name': cp.get('model_name', 'Unknown')
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
|
||||||
|
|
||||||
|
if needs_resort:
|
||||||
|
await cache.resort(name_only=True)
|
||||||
|
|
||||||
|
# Send completion message
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'completed',
|
||||||
|
'total': total_to_process,
|
||||||
|
'processed': processed,
|
||||||
|
'success': success
|
||||||
|
})
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Send error message
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'error',
|
||||||
|
'error': str(e)
|
||||||
|
})
|
||||||
|
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
|
||||||
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
|
async def _fetch_and_update_single_checkpoint(self, sha256: str, file_path: str, checkpoint: dict) -> bool:
|
||||||
|
"""Fetch and update metadata for a single checkpoint without sorting"""
|
||||||
|
client = CivitaiClient()
|
||||||
|
try:
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||||
|
|
||||||
|
# Load local metadata
|
||||||
|
local_metadata = self._load_local_metadata(metadata_path)
|
||||||
|
|
||||||
|
# Fetch metadata from Civitai
|
||||||
|
civitai_metadata = await client.get_model_by_hash(sha256)
|
||||||
|
if not civitai_metadata:
|
||||||
|
# Mark as not from CivitAI if not found
|
||||||
|
local_metadata['from_civitai'] = False
|
||||||
|
checkpoint['from_civitai'] = False
|
||||||
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Update metadata with Civitai data
|
||||||
|
await self._update_model_metadata(
|
||||||
|
metadata_path,
|
||||||
|
local_metadata,
|
||||||
|
civitai_metadata,
|
||||||
|
client
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update cache object directly
|
||||||
|
checkpoint.update({
|
||||||
|
'model_name': local_metadata.get('model_name'),
|
||||||
|
'preview_url': local_metadata.get('preview_url'),
|
||||||
|
'from_civitai': True,
|
||||||
|
'civitai': civitai_metadata
|
||||||
|
})
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching CivitAI data for checkpoint: {e}")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
def _load_local_metadata(self, metadata_path: str) -> Dict:
|
||||||
|
"""Load local metadata file"""
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
|
try:
|
||||||
|
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||||
|
return json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading metadata from {metadata_path}: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
|
||||||
|
civitai_metadata: Dict, client: CivitaiClient) -> None:
|
||||||
|
"""Update local metadata with CivitAI data"""
|
||||||
|
local_metadata['civitai'] = civitai_metadata
|
||||||
|
|
||||||
|
# Update model name if available
|
||||||
|
if 'model' in civitai_metadata:
|
||||||
|
if civitai_metadata.get('model', {}).get('name'):
|
||||||
|
local_metadata['model_name'] = civitai_metadata['model']['name']
|
||||||
|
|
||||||
|
# Fetch additional model metadata (description and tags) if we have model ID
|
||||||
|
model_id = civitai_metadata['modelId']
|
||||||
|
if model_id:
|
||||||
|
model_metadata, _ = await client.get_model_metadata(str(model_id))
|
||||||
|
if model_metadata:
|
||||||
|
local_metadata['modelDescription'] = model_metadata.get('description', '')
|
||||||
|
local_metadata['tags'] = model_metadata.get('tags', [])
|
||||||
|
|
||||||
|
# Update base model
|
||||||
|
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
|
||||||
|
|
||||||
|
# Update preview if needed
|
||||||
|
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
|
||||||
|
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
|
||||||
|
if first_preview:
|
||||||
|
preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
|
||||||
|
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||||
|
preview_filename = base_name + preview_ext
|
||||||
|
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
|
||||||
|
|
||||||
|
if await client.download_preview_image(first_preview['url'], preview_path):
|
||||||
|
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
|
||||||
|
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
|
||||||
|
|
||||||
|
# Save updated metadata
|
||||||
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
await self.scanner.update_single_model_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)
|
||||||
|
|
||||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||||
"""Handle request for top tags sorted by frequency"""
|
"""Handle request for top tags sorted by frequency"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -249,4 +249,82 @@ async function _uploadPreview(filePath, file) {
|
|||||||
console.error('Error updating preview:', error);
|
console.error('Error updating preview:', error);
|
||||||
showToast(`Failed to update preview: ${error.message}`, 'error');
|
showToast(`Failed to update preview: ${error.message}`, 'error');
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch metadata from Civitai for checkpoints
|
||||||
|
export async function fetchCivitai() {
|
||||||
|
let ws = null;
|
||||||
|
|
||||||
|
await state.loadingManager.showWithProgress(async (loading) => {
|
||||||
|
try {
|
||||||
|
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||||
|
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
|
||||||
|
|
||||||
|
const operationComplete = new Promise((resolve, reject) => {
|
||||||
|
ws.onmessage = (event) => {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
|
||||||
|
switch(data.status) {
|
||||||
|
case 'started':
|
||||||
|
loading.setStatus('Starting metadata fetch...');
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 'processing':
|
||||||
|
const percent = ((data.processed / data.total) * 100).toFixed(1);
|
||||||
|
loading.setProgress(percent);
|
||||||
|
loading.setStatus(
|
||||||
|
`Processing (${data.processed}/${data.total}) ${data.current_name}`
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 'completed':
|
||||||
|
loading.setProgress(100);
|
||||||
|
loading.setStatus(
|
||||||
|
`Completed: Updated ${data.success} of ${data.processed} checkpoints`
|
||||||
|
);
|
||||||
|
resolve();
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 'error':
|
||||||
|
reject(new Error(data.error));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onerror = (error) => {
|
||||||
|
reject(new Error('WebSocket error: ' + error.message));
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
await new Promise((resolve, reject) => {
|
||||||
|
ws.onopen = resolve;
|
||||||
|
ws.onerror = reject;
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await fetch('/api/checkpoints/fetch-all-civitai', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ model_type: 'checkpoint' }) // Specify we're fetching checkpoint metadata
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error('Failed to fetch metadata');
|
||||||
|
}
|
||||||
|
|
||||||
|
await operationComplete;
|
||||||
|
|
||||||
|
await resetAndReload();
|
||||||
|
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error fetching metadata:', error);
|
||||||
|
showToast('Failed to fetch metadata: ' + error.message, 'error');
|
||||||
|
} finally {
|
||||||
|
if (ws) {
|
||||||
|
ws.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, {
|
||||||
|
initialMessage: 'Connecting...',
|
||||||
|
completionMessage: 'Metadata update complete'
|
||||||
|
});
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
// CheckpointsControls.js - Specific implementation for the Checkpoints page
|
// CheckpointsControls.js - Specific implementation for the Checkpoints page
|
||||||
import { PageControls } from './PageControls.js';
|
import { PageControls } from './PageControls.js';
|
||||||
import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints } from '../../api/checkpointApi.js';
|
import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js';
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast } from '../../utils/uiHelpers.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -33,6 +33,11 @@ export class CheckpointsControls extends PageControls {
|
|||||||
return await refreshCheckpoints();
|
return await refreshCheckpoints();
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// Add fetch from Civitai functionality for checkpoints
|
||||||
|
fetchFromCivitai: async () => {
|
||||||
|
return await fetchCivitai();
|
||||||
|
},
|
||||||
|
|
||||||
// No clearCustomFilter implementation is needed for checkpoints
|
// No clearCustomFilter implementation is needed for checkpoints
|
||||||
// as custom filters are currently only used for LoRAs
|
// as custom filters are currently only used for LoRAs
|
||||||
clearCustomFilter: async () => {
|
clearCustomFilter: async () => {
|
||||||
|
|||||||
@@ -97,20 +97,20 @@ export class PageControls {
|
|||||||
* Initialize page-specific event listeners
|
* Initialize page-specific event listeners
|
||||||
*/
|
*/
|
||||||
initPageSpecificListeners() {
|
initPageSpecificListeners() {
|
||||||
|
// Fetch from Civitai button - available for both loras and checkpoints
|
||||||
|
const fetchButton = document.querySelector('[data-action="fetch"]');
|
||||||
|
if (fetchButton) {
|
||||||
|
fetchButton.addEventListener('click', () => this.fetchFromCivitai());
|
||||||
|
}
|
||||||
|
|
||||||
if (this.pageType === 'loras') {
|
if (this.pageType === 'loras') {
|
||||||
// Fetch from Civitai button
|
// Download button - LoRAs only
|
||||||
const fetchButton = document.querySelector('[data-action="fetch"]');
|
|
||||||
if (fetchButton) {
|
|
||||||
fetchButton.addEventListener('click', () => this.fetchFromCivitai());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Download button
|
|
||||||
const downloadButton = document.querySelector('[data-action="download"]');
|
const downloadButton = document.querySelector('[data-action="download"]');
|
||||||
if (downloadButton) {
|
if (downloadButton) {
|
||||||
downloadButton.addEventListener('click', () => this.showDownloadModal());
|
downloadButton.addEventListener('click', () => this.showDownloadModal());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bulk operations button
|
// Bulk operations button - LoRAs only
|
||||||
const bulkButton = document.querySelector('[data-action="bulk"]');
|
const bulkButton = document.querySelector('[data-action="bulk"]');
|
||||||
if (bulkButton) {
|
if (bulkButton) {
|
||||||
bulkButton.addEventListener('click', () => this.toggleBulkMode());
|
bulkButton.addEventListener('click', () => this.toggleBulkMode());
|
||||||
@@ -332,11 +332,11 @@ export class PageControls {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fetch metadata from Civitai (LoRAs only)
|
* Fetch metadata from Civitai (available for both LoRAs and Checkpoints)
|
||||||
*/
|
*/
|
||||||
async fetchFromCivitai() {
|
async fetchFromCivitai() {
|
||||||
if (this.pageType !== 'loras' || !this.api) {
|
if (!this.api) {
|
||||||
console.error('Fetch from Civitai is only available for LoRAs');
|
console.error('API methods not registered');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user