diff --git a/__init__.py b/__init__.py index 265926e6..9381872f 100644 --- a/__init__.py +++ b/__init__.py @@ -9,5 +9,8 @@ WEB_DIRECTORY = "./js" # Add custom websocket event type EXTENSION_WEB_SOCKET_MESSAGE_TYPES = ["lora-scan-progress"] -__all__ = ['NODE_CLASS_MAPPINGS'] +# Add this init function to properly register routes +def init(): + LorasEndpoint.add_routes() +__all__ = ['NODE_CLASS_MAPPINGS'] \ No newline at end of file diff --git a/nodes.py b/nodes.py index fbd8e5ba..3849bedb 100644 --- a/nodes.py +++ b/nodes.py @@ -5,11 +5,11 @@ from pathlib import Path from aiohttp import web from server import PromptServer import jinja2 -from flask import jsonify, request from safetensors import safe_open from .utils.file_utils import get_file_info, save_metadata, load_metadata, update_civitai_metadata from .utils.lora_metadata import extract_lora_metadata from typing import Dict, Optional +from .services.civitai_client import CivitaiClient class LorasEndpoint: def __init__(self): @@ -33,7 +33,8 @@ class LorasEndpoint: web.get('/loras', instance.handle_loras_request), web.static('/loras_static/previews', instance.loras_root), web.static('/loras_static', static_path), - web.post('/api/delete_model', instance.delete_model) + web.post('/api/delete_model', instance.delete_model), + web.post('/api/fetch-civitai', instance.fetch_civitai) ]) def send_progress(self, current, total, status="Scanning"): @@ -105,8 +106,6 @@ class LorasEndpoint: formatted_loras = [self.format_lora(l) for l in data] folders = sorted(list(set(l['folder'] for l in data))) - print("folders:",folders) - # Debug logging if formatted_loras: print(f"Sample lora data: {formatted_loras[0]}") @@ -160,6 +159,8 @@ class LorasEndpoint: "preview_url": lora["preview_url"], "base_model": lora["base_model"], "folder": lora["folder"], + "sha256": lora["sha256"], + "file_path": lora["file_path"], "civitai": lora.get("civitai", {}) or {} # 确保当 civitai 为 None 时返回空字典 } except Exception as e: @@ -171,6 +172,8 @@ class LorasEndpoint: "preview_url": lora.get("preview_url", ""), "base_model": lora.get("base_model", ""), "folder": lora.get("folder", ""), + "sha256": lora.get("sha256", ""), + "file_path": lora.get("file_path", ""), "civitai": { "id": "", "modelId": "", @@ -252,5 +255,71 @@ class LorasEndpoint: except Exception as e: print(f"Error downloading preview image: {str(e)}") + async def fetch_civitai(self, request): + print("Received fetch-civitai request") # Debug log + try: + data = await request.json() + print(f"Request data: {data}") # Debug log + client = CivitaiClient() + + try: + # 1. 获取CivitAI元数据 + civitai_metadata = await client.get_model_by_hash(data["sha256"]) + if not civitai_metadata: + return web.json_response( + {"success": False, "error": "Not found on CivitAI"}, + status=404 + ) + + # 2. 读取/创建本地元数据文件 + metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json' + + # 合并元数据 + local_metadata = {} + if os.path.exists(metadata_path): + with open(metadata_path, 'r', encoding='utf-8') as f: + local_metadata = json.load(f) + + # 3. 更新元数据字段 + local_metadata['civitai']=civitai_metadata + + # 更新模型名称(优先使用CivitAI名称) + if 'model' in civitai_metadata: + local_metadata['model_name'] = civitai_metadata['model'].get('name', local_metadata.get('model_name')) + + # 4. 下载预览图 + first_image = next((img for img in civitai_metadata.get('images', []) if img.get('type') == 'image'), None) + if first_image: + preview_extension = os.path.splitext(first_image['url'])[-1] # Get the image file extension + preview_filename = os.path.splitext(os.path.basename(data['file_path']))[0] + preview_extension + preview_path = os.path.join(os.path.dirname(data['file_path']), preview_filename) + await client.download_preview_image(first_image['url'], preview_path) + # 存储相对路径,使用正斜杠格式 + local_metadata['preview_url'] = os.path.relpath(preview_path, self.loras_root).replace(os.sep, '/') + + # 5. 保存更新后的元数据 + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(local_metadata, f, indent=2, ensure_ascii=False) + + return web.json_response({ + "success": True + }) + + except Exception as e: + print(f"Error in fetch_civitai: {str(e)}") # Debug log + return web.json_response({ + "success": False, + "error": str(e) + }, status=500) + finally: + await client.close() + + except Exception as e: + print(f"Error processing request: {str(e)}") # Debug log + return web.json_response({ + "success": False, + "error": f"Request processing error: {str(e)}" + }, status=400) + # 注册路由 LorasEndpoint.add_routes() \ No newline at end of file diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 00000000..3aae30f6 --- /dev/null +++ b/services/__init__.py @@ -0,0 +1 @@ +# Empty file to mark directory as Python package \ No newline at end of file diff --git a/services/civitai_client.py b/services/civitai_client.py new file mode 100644 index 00000000..b34c0149 --- /dev/null +++ b/services/civitai_client.py @@ -0,0 +1,35 @@ +import aiohttp +import os +import json +from typing import Optional, Dict + +class CivitaiClient: + def __init__(self): + self.base_url = "https://civitai.com/api/v1" + self.session = aiohttp.ClientSession() + + async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: + try: + async with self.session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response: + if response.status == 200: + return await response.json() + return None + except Exception as e: + print(f"API Error: {str(e)}") + return None + + async def download_preview_image(self, image_url: str, save_path: str): + try: + async with self.session.get(image_url) as response: + if response.status == 200: + content = await response.read() + with open(save_path, 'wb') as f: + f.write(content) + return True + return False + except Exception as e: + print(f"Download Error: {str(e)}") + return False + + async def close(self): + await self.session.close() \ No newline at end of file diff --git a/static/js/script.js b/static/js/script.js index c73b3432..8b0f64e4 100644 --- a/static/js/script.js +++ b/static/js/script.js @@ -57,7 +57,7 @@ async function deleteModel(modelName) { event.stopPropagation(); // Get the folder from the card's data attributes - const card = document.querySelector(`.lora-card[data-name="${modelName}"]`); + const card = document.querySelector(`.lora-card[data-file_name="${modelName}"]`); const folder = card ? card.dataset.folder : null; // Show confirmation dialog @@ -248,4 +248,78 @@ function preloadImages(urls) { }); } +// 新增 fetchCivitai 函数 +async function fetchCivitai() { + const loadingOverlay = document.getElementById('loading-overlay'); + const progressBar = document.querySelector('.progress-bar'); + const loadingStatus = document.querySelector('.loading-status'); + const loraCards = document.querySelectorAll('.lora-card'); + + // 显示进度条 + loadingOverlay.style.display = 'flex'; + loadingStatus.textContent = 'Fetching metadata...'; + + try { + // Iterate through all lora cards + for(let i = 0; i < loraCards.length; i++) { + const card = loraCards[i]; + // Skip if already has metadata + if (card.dataset.meta && Object.keys(JSON.parse(card.dataset.meta)).length > 0) { + continue; + } + + // Make sure these data attributes exist on your lora-card elements + const sha256 = card.dataset.sha256; + const filePath = card.dataset.filepath; + + // Add validation + if (!sha256 || !filePath) { + console.warn(`Missing data for card ${card.dataset.name}:`, { sha256, filePath }); + continue; + } + + // Update progress + const progress = (i / loraCards.length * 100).toFixed(1); + progressBar.style.width = `${progress}%`; + loadingStatus.textContent = `Processing (${i+1}/${loraCards.length}) ${card.dataset.name}`; + + // Call backend API + const response = await fetch('/api/fetch-civitai', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + sha256: sha256, + file_path: filePath + }) + }); + + // if(!response.ok) { + // const errorText = await response.text(); + // throw new Error(`HTTP error! status: ${response.status}, message: ${errorText}`); + // } + + // // Optional: Update the card with new metadata + // const result = await response.json(); + // if (result.success && result.metadata) { + // card.dataset.meta = JSON.stringify(result.metadata); + // // Update card display if needed + // } + } + + // Completion handling + progressBar.style.width = '100%'; + loadingStatus.textContent = 'Metadata update complete'; + setTimeout(() => { + loadingOverlay.style.display = 'none'; + // Optionally reload the page to show updated data + window.location.reload(); + }, 2000); + + } catch (error) { + console.warn('Error fetching metadata:', error); + } +} + initTheme(); \ No newline at end of file diff --git a/templates/loras.html b/templates/loras.html index c05bde25..966b70bc 100644 --- a/templates/loras.html +++ b/templates/loras.html @@ -48,9 +48,15 @@
+