Add fetch metadata from civitai

This commit is contained in:
Will Miao
2025-01-26 13:41:16 +08:00
parent 6e9ed34b92
commit 2007e80a7d
6 changed files with 196 additions and 8 deletions

View File

@@ -9,5 +9,8 @@ WEB_DIRECTORY = "./js"
# Add custom websocket event type # Add custom websocket event type
EXTENSION_WEB_SOCKET_MESSAGE_TYPES = ["lora-scan-progress"] EXTENSION_WEB_SOCKET_MESSAGE_TYPES = ["lora-scan-progress"]
__all__ = ['NODE_CLASS_MAPPINGS'] # Add this init function to properly register routes
def init():
LorasEndpoint.add_routes()
__all__ = ['NODE_CLASS_MAPPINGS']

View File

@@ -5,11 +5,11 @@ from pathlib import Path
from aiohttp import web from aiohttp import web
from server import PromptServer from server import PromptServer
import jinja2 import jinja2
from flask import jsonify, request
from safetensors import safe_open from safetensors import safe_open
from .utils.file_utils import get_file_info, save_metadata, load_metadata, update_civitai_metadata from .utils.file_utils import get_file_info, save_metadata, load_metadata, update_civitai_metadata
from .utils.lora_metadata import extract_lora_metadata from .utils.lora_metadata import extract_lora_metadata
from typing import Dict, Optional from typing import Dict, Optional
from .services.civitai_client import CivitaiClient
class LorasEndpoint: class LorasEndpoint:
def __init__(self): def __init__(self):
@@ -33,7 +33,8 @@ class LorasEndpoint:
web.get('/loras', instance.handle_loras_request), web.get('/loras', instance.handle_loras_request),
web.static('/loras_static/previews', instance.loras_root), web.static('/loras_static/previews', instance.loras_root),
web.static('/loras_static', static_path), web.static('/loras_static', static_path),
web.post('/api/delete_model', instance.delete_model) web.post('/api/delete_model', instance.delete_model),
web.post('/api/fetch-civitai', instance.fetch_civitai)
]) ])
def send_progress(self, current, total, status="Scanning"): def send_progress(self, current, total, status="Scanning"):
@@ -105,8 +106,6 @@ class LorasEndpoint:
formatted_loras = [self.format_lora(l) for l in data] formatted_loras = [self.format_lora(l) for l in data]
folders = sorted(list(set(l['folder'] for l in data))) folders = sorted(list(set(l['folder'] for l in data)))
print("folders:",folders)
# Debug logging # Debug logging
if formatted_loras: if formatted_loras:
print(f"Sample lora data: {formatted_loras[0]}") print(f"Sample lora data: {formatted_loras[0]}")
@@ -160,6 +159,8 @@ class LorasEndpoint:
"preview_url": lora["preview_url"], "preview_url": lora["preview_url"],
"base_model": lora["base_model"], "base_model": lora["base_model"],
"folder": lora["folder"], "folder": lora["folder"],
"sha256": lora["sha256"],
"file_path": lora["file_path"],
"civitai": lora.get("civitai", {}) or {} # 确保当 civitai 为 None 时返回空字典 "civitai": lora.get("civitai", {}) or {} # 确保当 civitai 为 None 时返回空字典
} }
except Exception as e: except Exception as e:
@@ -171,6 +172,8 @@ class LorasEndpoint:
"preview_url": lora.get("preview_url", ""), "preview_url": lora.get("preview_url", ""),
"base_model": lora.get("base_model", ""), "base_model": lora.get("base_model", ""),
"folder": lora.get("folder", ""), "folder": lora.get("folder", ""),
"sha256": lora.get("sha256", ""),
"file_path": lora.get("file_path", ""),
"civitai": { "civitai": {
"id": "", "id": "",
"modelId": "", "modelId": "",
@@ -252,5 +255,71 @@ class LorasEndpoint:
except Exception as e: except Exception as e:
print(f"Error downloading preview image: {str(e)}") print(f"Error downloading preview image: {str(e)}")
async def fetch_civitai(self, request):
print("Received fetch-civitai request") # Debug log
try:
data = await request.json()
print(f"Request data: {data}") # Debug log
client = CivitaiClient()
try:
# 1. 获取CivitAI元数据
civitai_metadata = await client.get_model_by_hash(data["sha256"])
if not civitai_metadata:
return web.json_response(
{"success": False, "error": "Not found on CivitAI"},
status=404
)
# 2. 读取/创建本地元数据文件
metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
# 合并元数据
local_metadata = {}
if os.path.exists(metadata_path):
with open(metadata_path, 'r', encoding='utf-8') as f:
local_metadata = json.load(f)
# 3. 更新元数据字段
local_metadata['civitai']=civitai_metadata
# 更新模型名称优先使用CivitAI名称
if 'model' in civitai_metadata:
local_metadata['model_name'] = civitai_metadata['model'].get('name', local_metadata.get('model_name'))
# 4. 下载预览图
first_image = next((img for img in civitai_metadata.get('images', []) if img.get('type') == 'image'), None)
if first_image:
preview_extension = os.path.splitext(first_image['url'])[-1] # Get the image file extension
preview_filename = os.path.splitext(os.path.basename(data['file_path']))[0] + preview_extension
preview_path = os.path.join(os.path.dirname(data['file_path']), preview_filename)
await client.download_preview_image(first_image['url'], preview_path)
# 存储相对路径,使用正斜杠格式
local_metadata['preview_url'] = os.path.relpath(preview_path, self.loras_root).replace(os.sep, '/')
# 5. 保存更新后的元数据
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return web.json_response({
"success": True
})
except Exception as e:
print(f"Error in fetch_civitai: {str(e)}") # Debug log
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
finally:
await client.close()
except Exception as e:
print(f"Error processing request: {str(e)}") # Debug log
return web.json_response({
"success": False,
"error": f"Request processing error: {str(e)}"
}, status=400)
# 注册路由 # 注册路由
LorasEndpoint.add_routes() LorasEndpoint.add_routes()

1
services/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Empty file to mark directory as Python package

View File

@@ -0,0 +1,35 @@
import aiohttp
import os
import json
from typing import Optional, Dict
class CivitaiClient:
def __init__(self):
self.base_url = "https://civitai.com/api/v1"
self.session = aiohttp.ClientSession()
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
try:
async with self.session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
print(f"API Error: {str(e)}")
return None
async def download_preview_image(self, image_url: str, save_path: str):
try:
async with self.session.get(image_url) as response:
if response.status == 200:
content = await response.read()
with open(save_path, 'wb') as f:
f.write(content)
return True
return False
except Exception as e:
print(f"Download Error: {str(e)}")
return False
async def close(self):
await self.session.close()

View File

@@ -57,7 +57,7 @@ async function deleteModel(modelName) {
event.stopPropagation(); event.stopPropagation();
// Get the folder from the card's data attributes // Get the folder from the card's data attributes
const card = document.querySelector(`.lora-card[data-name="${modelName}"]`); const card = document.querySelector(`.lora-card[data-file_name="${modelName}"]`);
const folder = card ? card.dataset.folder : null; const folder = card ? card.dataset.folder : null;
// Show confirmation dialog // Show confirmation dialog
@@ -248,4 +248,78 @@ function preloadImages(urls) {
}); });
} }
// 新增 fetchCivitai 函数
async function fetchCivitai() {
const loadingOverlay = document.getElementById('loading-overlay');
const progressBar = document.querySelector('.progress-bar');
const loadingStatus = document.querySelector('.loading-status');
const loraCards = document.querySelectorAll('.lora-card');
// 显示进度条
loadingOverlay.style.display = 'flex';
loadingStatus.textContent = 'Fetching metadata...';
try {
// Iterate through all lora cards
for(let i = 0; i < loraCards.length; i++) {
const card = loraCards[i];
// Skip if already has metadata
if (card.dataset.meta && Object.keys(JSON.parse(card.dataset.meta)).length > 0) {
continue;
}
// Make sure these data attributes exist on your lora-card elements
const sha256 = card.dataset.sha256;
const filePath = card.dataset.filepath;
// Add validation
if (!sha256 || !filePath) {
console.warn(`Missing data for card ${card.dataset.name}:`, { sha256, filePath });
continue;
}
// Update progress
const progress = (i / loraCards.length * 100).toFixed(1);
progressBar.style.width = `${progress}%`;
loadingStatus.textContent = `Processing (${i+1}/${loraCards.length}) ${card.dataset.name}`;
// Call backend API
const response = await fetch('/api/fetch-civitai', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
sha256: sha256,
file_path: filePath
})
});
// if(!response.ok) {
// const errorText = await response.text();
// throw new Error(`HTTP error! status: ${response.status}, message: ${errorText}`);
// }
// // Optional: Update the card with new metadata
// const result = await response.json();
// if (result.success && result.metadata) {
// card.dataset.meta = JSON.stringify(result.metadata);
// // Update card display if needed
// }
}
// Completion handling
progressBar.style.width = '100%';
loadingStatus.textContent = 'Metadata update complete';
setTimeout(() => {
loadingOverlay.style.display = 'none';
// Optionally reload the page to show updated data
window.location.reload();
}, 2000);
} catch (error) {
console.warn('Error fetching metadata:', error);
}
}
initTheme(); initTheme();

View File

@@ -48,9 +48,15 @@
<div class="card-grid" id="loraGrid"> <div class="card-grid" id="loraGrid">
{% for lora in loras %} {% for lora in loras %}
<!-- 在卡片部分更新元数据展示 --> <!-- 在卡片部分更新元数据展示 -->
<div class="lora-card" data-name="{{ lora.model_name }}" data-file_name="{{ lora.file_name }}" data-folder="{{ lora.folder }}" data-meta="{{ lora.civitai | default({}) | tojson | forceescape }}"> <div class="lora-card"
data-sha256="{{ lora.sha256 }}"
data-filepath="{{ lora.file_path }}"
data-name="{{ lora.model_name }}"
data-file_name="{{ lora.file_name }}"
data-folder="{{ lora.folder }}"
data-meta="{{ lora.civitai | default({}) | tojson | forceescape }}">
<div class="card-preview"> <div class="card-preview">
<img src="{{ lora.preview_url or '/loras_static/images/no-preview.png' }}" alt="{{ lora.name }}"> <img src="{{ ('/loras_static/previews/' + lora.preview_url) if lora.preview_url else '/loras_static/images/no-preview.png' }}" alt="{{ lora.name }}">
<div class="card-header"> <div class="card-header">
<span class="base-model-label" title="{{ lora.base_model }}"> <span class="base-model-label" title="{{ lora.base_model }}">
{{ lora.base_model }} {{ lora.base_model }}