From 9d64ebc5d632ea68081a51998ba9aa83279832ee Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sat, 25 Jan 2025 22:19:32 +0800 Subject: [PATCH] Update metadata structure --- nodes.py | 172 +++++++++++-------------------------------- static/js/script.js | 24 +++++- templates/loras.html | 26 ++++--- utils/file_utils.py | 33 +++++---- 4 files changed, 100 insertions(+), 155 deletions(-) diff --git a/nodes.py b/nodes.py index 779ad743..70b043fd 100644 --- a/nodes.py +++ b/nodes.py @@ -1,4 +1,3 @@ -# nodes.py 更新后的核心代码 import os import json import time @@ -22,6 +21,8 @@ class LorasEndpoint: ) # 配置Loras根目录(根据实际安装位置调整) self.loras_root = os.path.join(Path(__file__).parents[2], "models", "loras") + # 添加 server 属性 + self.server = PromptServer.instance @classmethod def add_routes(cls): @@ -37,109 +38,45 @@ class LorasEndpoint: def send_progress(self, current, total, status="Scanning"): """Send progress through websocket""" - if self.server and hasattr(self.server, 'send_sync'): - self.server.send_sync("lora-scan-progress", { - "value": current, - "max": total, - "status": status - }) + try: + if hasattr(self.server, 'send_sync'): + self.server.send_sync("lora-scan-progress", { + "value": current, + "max": total, + "status": status + }) + except Exception as e: + print(f"Error sending progress: {str(e)}") async def scan_loras(self): - """扫描Loras目录并返回结构化数据""" loras = [] - folders = set() - - # 遍历Loras目录(包含子目录) for root, _, files in os.walk(self.loras_root): - rel_path = os.path.relpath(root, self.loras_root) - if rel_path == ".": - current_folder = "root" - else: - current_folder = rel_path.replace(os.sep, "/") - folders.add(current_folder) - - for file in files: - safetensors_files = [f for f in files if f.endswith('.safetensors')] - total_files = len(safetensors_files) + safetensors_files = [f for f in files if f.endswith('.safetensors')] + total_files = len(safetensors_files) + + for idx, filename in enumerate(safetensors_files, 1): + self.send_progress(idx, total_files, f"Scanning: {filename}") - # 识别模型文件 - if file.endswith('.safetensors'): - base_name = os.path.splitext(file)[0] - model_path = os.path.join(root, file) - - # Get basic file info and metadata - file_info = await get_file_info(model_path) - base_model_info = await extract_lora_metadata(model_path) - file_info.update(base_model_info) - - # Load existing metadata or create new one - metadata = await load_metadata(model_path) - if not metadata: - # First time scanning this file - await save_metadata(model_path, file_info) - metadata = file_info - else: - # Update basic file info in existing metadata - metadata.update(file_info) - await save_metadata(model_path, metadata) - - # Add civitai data to return value if exists - if 'civitai' in metadata: - metadata.update(metadata['civitai']) - - # 查找预览图 - preview_path = os.path.join(root, f"{base_name}.preview.png") - preview_url = await self.get_preview_url(preview_path, root) if os.path.exists(preview_path) else None - - loras.append({ - "name": base_name, - "folder": current_folder, - "path": model_path, - "preview_url": preview_url, - "metadata": metadata, - "size": os.path.getsize(model_path), - "modified": os.path.getmtime(model_path) - }) - + file_path = os.path.join(root, filename) + + # Try to load existing metadata first + metadata = await load_metadata(file_path) + + if metadata is None: + # Only get file info and extract metadata if no existing metadata + metadata = await get_file_info(file_path) + base_model_info = await extract_lora_metadata(file_path) + metadata.base_model = base_model_info['base_model'] + await save_metadata(file_path, metadata) + + # Convert to dict for API response + lora_data = metadata.to_dict() + + loras.append(lora_data) + self.send_progress(total_files, total_files, "Scan completed") - return { - "loras": sorted(loras, key=lambda x: x["name"].lower()), - "folders": sorted(folders) - } + return loras - async def parse_model_metadata(self, file_path): - """从safetensors文件中提取元数据""" - try: - with safe_open(file_path, framework="pt", device="cpu") as f: - metadata = f.metadata() - if metadata: - return metadata - except Exception as e: - print(f"Error reading metadata from {file_path}: {str(e)}") - return {} - - async def parse_metadata(self, meta_file): - """解析元数据文件""" - try: - if os.path.exists(meta_file): - with open(meta_file, 'r', encoding='utf-8') as f: - meta = json.load(f) - return { - "id": meta.get("id"), - "modelId": meta.get("modelId"), - "model": meta.get("model", {}).get("name"), - "base_model": meta.get("baseModel"), - "trained_words": meta.get("trainedWords", []), - "creator": meta.get("creator", {}).get("username"), - "downloads": meta.get("stats", {}).get("downloadCount", 0), - "images": [img["url"] for img in meta.get("images", [])[:3]], - "description": self.clean_description( - meta.get("model", {}).get("description", "") - ) - } - except Exception as e: - print(f"Error parsing metadata {meta_file}: {str(e)}") - return {} def clean_description(self, desc): """清理HTML格式的描述""" @@ -157,10 +94,10 @@ class LorasEndpoint: try: scan_start = time.time() data = await self.scan_loras() - print(f"Scanned {len(data['loras'])} loras in {time.time()-scan_start:.2f}s") + print(f"Scanned {len(data)} loras in {time.time()-scan_start:.2f}s") # Format the data for the template - formatted_loras = [self.format_lora(l) for l in data["loras"]] + formatted_loras = [self.format_lora(l) for l in data] # Debug logging if formatted_loras: @@ -169,17 +106,13 @@ class LorasEndpoint: print("Warning: No loras found") context = { - "folders": data.get("folders", []), "loras": formatted_loras, # Only set single lora if we're viewing details "lora": formatted_loras[0] if formatted_loras else { - "name": "", - "folder": "", + "model_name": "", "file_name": "", "preview_url": "", - "modified": "", - "size": "0MB", - "meta": { + "civitai": { "id": "", "model": "", "base_model": "", @@ -211,37 +144,20 @@ class LorasEndpoint: def format_lora(self, lora): """格式化前端需要的数据结构""" try: - metadata = lora.get("metadata", {}) - return { - "name": lora["name"], - "folder": lora["folder"], + "model_name": lora["model_name"], + "file_name": lora["file_name"], "preview_url": lora["preview_url"], - "modified": time.strftime("%Y-%m-%d %H:%M", - time.localtime(lora["modified"])), - "size": f"{lora['size']/1024/1024:.1f}MB", - "meta": { - "id": metadata.get("id", ""), - "modelId": metadata.get("modelId", ""), - "model": metadata.get("model", ""), - "base_model": metadata.get("base_model", ""), - "trained_words": metadata.get("trained_words", []), - "creator": metadata.get("creator", ""), - "downloads": metadata.get("downloads", 0), - "images": metadata.get("images", []), - "description": metadata.get("description", "") - } + "civitai": lora.get("civitai", {}) or {} # 确保当 civitai 为 None 时返回空字典 } except Exception as e: print(f"Error formatting lora: {str(e)}") print(f"Lora data: {lora}") return { - "name": lora.get("name", "Unknown"), - "folder": lora.get("folder", ""), + "model_name": lora.get("model_name", "Unknown"), + "file_name": lora.get("file_name", ""), "preview_url": lora.get("preview_url", ""), - "modified": "", - "size": "0MB", - "meta": { + "civitai": { "id": "", "modelId": "", "model": "", diff --git a/static/js/script.js b/static/js/script.js index cd45705d..2cff5be4 100644 --- a/static/js/script.js +++ b/static/js/script.js @@ -173,10 +173,23 @@ document.addEventListener('DOMContentLoaded', function() { const progressBar = document.querySelector('.progress-bar'); const loadingStatus = document.querySelector('.loading-status'); - // Show loading overlay initially - loadingOverlay.style.display = 'flex'; + // 默认隐藏 loading overlay + loadingOverlay.style.display = 'none'; + + const api = new EventTarget(); + window.api = api; + + const ws = new WebSocket(`ws://${window.location.host}/ws`); + + ws.onmessage = function(event) { + const data = JSON.parse(event.data); + if (data.type === 'lora-scan-progress') { + // 当收到扫描进度消息时显示 overlay + loadingOverlay.style.display = 'flex'; + api.dispatchEvent(new CustomEvent('lora-scan-progress', { detail: data })); + } + }; - // Listen for progress updates api.addEventListener("lora-scan-progress", (event) => { const data = event.detail; const progress = (data.value / data.max) * 100; @@ -186,9 +199,12 @@ document.addEventListener('DOMContentLoaded', function() { loadingStatus.textContent = data.status; if (data.value === data.max) { - // Hide loading overlay when scan is complete + // 确保在扫描完成时隐藏 overlay setTimeout(() => { loadingOverlay.style.display = 'none'; + // 重置进度条 + progressBar.style.width = '0%'; + progressBar.setAttribute('aria-valuenow', 0); }, 500); } }); diff --git a/templates/loras.html b/templates/loras.html index b313de77..9514618f 100644 --- a/templates/loras.html +++ b/templates/loras.html @@ -14,6 +14,15 @@ +
@@ -39,34 +48,33 @@
{% for lora in loras %} -
+
{{ lora.name }}
- - {{ lora.meta.base_model if lora.meta and lora.meta.base_model else 'Unknown' }} + + {{ lora.base_model }}
+ onclick="event.stopPropagation(); openCivitai('{{ lora.file_name }}')"> + onclick="event.stopPropagation(); navigator.clipboard.writeText(this.closest('.lora-card').dataset.file_name)"> + onclick="event.stopPropagation(); deleteModel('{{ lora.file_name }}')">
diff --git a/utils/file_utils.py b/utils/file_utils.py index 19bafa35..8eec80ac 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -2,6 +2,7 @@ import os import hashlib import json from typing import Dict, Optional +from .models import LoraMetadata async def calculate_sha256(file_path: str) -> str: """Calculate SHA256 hash of a file""" @@ -11,35 +12,39 @@ async def calculate_sha256(file_path: str) -> str: sha256_hash.update(byte_block) return sha256_hash.hexdigest() -async def get_file_info(file_path: str) -> Dict: - """Get basic file information""" - return { - "name": os.path.splitext(os.path.basename(file_path))[0], - "file_path": file_path, - "size": os.path.getsize(file_path), - "modified": os.path.getmtime(file_path), - "sha256": await calculate_sha256(file_path) - } +async def get_file_info(file_path: str) -> LoraMetadata: + """Get basic file information as LoraMetadata object""" + return LoraMetadata( + file_name=os.path.splitext(os.path.basename(file_path))[0], + model_name=os.path.splitext(os.path.basename(file_path))[0], + file_path=file_path, + size=os.path.getsize(file_path), + modified=os.path.getmtime(file_path), + sha256=await calculate_sha256(file_path), + base_model="Unknown", # Will be updated later + preview_url="", + ) -async def save_metadata(file_path: str, metadata: Dict) -> None: +async def save_metadata(file_path: str, metadata: LoraMetadata) -> None: """Save metadata to .metadata.json file""" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" try: with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(metadata, f, indent=2, ensure_ascii=False) + json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) except Exception as e: print(f"Error saving metadata to {metadata_path}: {str(e)}") -async def load_metadata(file_path: str) -> Dict: +async def load_metadata(file_path: str) -> Optional[LoraMetadata]: """Load metadata from .metadata.json file""" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" try: if os.path.exists(metadata_path): with open(metadata_path, 'r', encoding='utf-8') as f: - return json.load(f) + data = json.load(f) + return LoraMetadata.from_dict(data) except Exception as e: print(f"Error loading metadata from {metadata_path}: {str(e)}") - return {} + return None async def update_civitai_metadata(file_path: str, civitai_data: Dict) -> None: """Update metadata file with Civitai data"""