Update metadata structure

This commit is contained in:
Will Miao
2025-01-25 22:19:32 +08:00
parent 1100363427
commit 9d64ebc5d6
4 changed files with 100 additions and 155 deletions

172
nodes.py
View File

@@ -1,4 +1,3 @@
# nodes.py 更新后的核心代码
import os
import json
import time
@@ -22,6 +21,8 @@ class LorasEndpoint:
)
# 配置Loras根目录根据实际安装位置调整
self.loras_root = os.path.join(Path(__file__).parents[2], "models", "loras")
# 添加 server 属性
self.server = PromptServer.instance
@classmethod
def add_routes(cls):
@@ -37,109 +38,45 @@ class LorasEndpoint:
def send_progress(self, current, total, status="Scanning"):
"""Send progress through websocket"""
if self.server and hasattr(self.server, 'send_sync'):
self.server.send_sync("lora-scan-progress", {
"value": current,
"max": total,
"status": status
})
try:
if hasattr(self.server, 'send_sync'):
self.server.send_sync("lora-scan-progress", {
"value": current,
"max": total,
"status": status
})
except Exception as e:
print(f"Error sending progress: {str(e)}")
async def scan_loras(self):
"""扫描Loras目录并返回结构化数据"""
loras = []
folders = set()
# 遍历Loras目录包含子目录
for root, _, files in os.walk(self.loras_root):
rel_path = os.path.relpath(root, self.loras_root)
if rel_path == ".":
current_folder = "root"
else:
current_folder = rel_path.replace(os.sep, "/")
folders.add(current_folder)
for file in files:
safetensors_files = [f for f in files if f.endswith('.safetensors')]
total_files = len(safetensors_files)
safetensors_files = [f for f in files if f.endswith('.safetensors')]
total_files = len(safetensors_files)
for idx, filename in enumerate(safetensors_files, 1):
self.send_progress(idx, total_files, f"Scanning: {filename}")
# 识别模型文件
if file.endswith('.safetensors'):
base_name = os.path.splitext(file)[0]
model_path = os.path.join(root, file)
# Get basic file info and metadata
file_info = await get_file_info(model_path)
base_model_info = await extract_lora_metadata(model_path)
file_info.update(base_model_info)
# Load existing metadata or create new one
metadata = await load_metadata(model_path)
if not metadata:
# First time scanning this file
await save_metadata(model_path, file_info)
metadata = file_info
else:
# Update basic file info in existing metadata
metadata.update(file_info)
await save_metadata(model_path, metadata)
# Add civitai data to return value if exists
if 'civitai' in metadata:
metadata.update(metadata['civitai'])
# 查找预览图
preview_path = os.path.join(root, f"{base_name}.preview.png")
preview_url = await self.get_preview_url(preview_path, root) if os.path.exists(preview_path) else None
loras.append({
"name": base_name,
"folder": current_folder,
"path": model_path,
"preview_url": preview_url,
"metadata": metadata,
"size": os.path.getsize(model_path),
"modified": os.path.getmtime(model_path)
})
file_path = os.path.join(root, filename)
# Try to load existing metadata first
metadata = await load_metadata(file_path)
if metadata is None:
# Only get file info and extract metadata if no existing metadata
metadata = await get_file_info(file_path)
base_model_info = await extract_lora_metadata(file_path)
metadata.base_model = base_model_info['base_model']
await save_metadata(file_path, metadata)
# Convert to dict for API response
lora_data = metadata.to_dict()
loras.append(lora_data)
self.send_progress(total_files, total_files, "Scan completed")
return {
"loras": sorted(loras, key=lambda x: x["name"].lower()),
"folders": sorted(folders)
}
return loras
async def parse_model_metadata(self, file_path):
"""从safetensors文件中提取元数据"""
try:
with safe_open(file_path, framework="pt", device="cpu") as f:
metadata = f.metadata()
if metadata:
return metadata
except Exception as e:
print(f"Error reading metadata from {file_path}: {str(e)}")
return {}
async def parse_metadata(self, meta_file):
"""解析元数据文件"""
try:
if os.path.exists(meta_file):
with open(meta_file, 'r', encoding='utf-8') as f:
meta = json.load(f)
return {
"id": meta.get("id"),
"modelId": meta.get("modelId"),
"model": meta.get("model", {}).get("name"),
"base_model": meta.get("baseModel"),
"trained_words": meta.get("trainedWords", []),
"creator": meta.get("creator", {}).get("username"),
"downloads": meta.get("stats", {}).get("downloadCount", 0),
"images": [img["url"] for img in meta.get("images", [])[:3]],
"description": self.clean_description(
meta.get("model", {}).get("description", "")
)
}
except Exception as e:
print(f"Error parsing metadata {meta_file}: {str(e)}")
return {}
def clean_description(self, desc):
"""清理HTML格式的描述"""
@@ -157,10 +94,10 @@ class LorasEndpoint:
try:
scan_start = time.time()
data = await self.scan_loras()
print(f"Scanned {len(data['loras'])} loras in {time.time()-scan_start:.2f}s")
print(f"Scanned {len(data)} loras in {time.time()-scan_start:.2f}s")
# Format the data for the template
formatted_loras = [self.format_lora(l) for l in data["loras"]]
formatted_loras = [self.format_lora(l) for l in data]
# Debug logging
if formatted_loras:
@@ -169,17 +106,13 @@ class LorasEndpoint:
print("Warning: No loras found")
context = {
"folders": data.get("folders", []),
"loras": formatted_loras,
# Only set single lora if we're viewing details
"lora": formatted_loras[0] if formatted_loras else {
"name": "",
"folder": "",
"model_name": "",
"file_name": "",
"preview_url": "",
"modified": "",
"size": "0MB",
"meta": {
"civitai": {
"id": "",
"model": "",
"base_model": "",
@@ -211,37 +144,20 @@ class LorasEndpoint:
def format_lora(self, lora):
"""格式化前端需要的数据结构"""
try:
metadata = lora.get("metadata", {})
return {
"name": lora["name"],
"folder": lora["folder"],
"model_name": lora["model_name"],
"file_name": lora["file_name"],
"preview_url": lora["preview_url"],
"modified": time.strftime("%Y-%m-%d %H:%M",
time.localtime(lora["modified"])),
"size": f"{lora['size']/1024/1024:.1f}MB",
"meta": {
"id": metadata.get("id", ""),
"modelId": metadata.get("modelId", ""),
"model": metadata.get("model", ""),
"base_model": metadata.get("base_model", ""),
"trained_words": metadata.get("trained_words", []),
"creator": metadata.get("creator", ""),
"downloads": metadata.get("downloads", 0),
"images": metadata.get("images", []),
"description": metadata.get("description", "")
}
"civitai": lora.get("civitai", {}) or {} # 确保当 civitai 为 None 时返回空字典
}
except Exception as e:
print(f"Error formatting lora: {str(e)}")
print(f"Lora data: {lora}")
return {
"name": lora.get("name", "Unknown"),
"folder": lora.get("folder", ""),
"model_name": lora.get("model_name", "Unknown"),
"file_name": lora.get("file_name", ""),
"preview_url": lora.get("preview_url", ""),
"modified": "",
"size": "0MB",
"meta": {
"civitai": {
"id": "",
"modelId": "",
"model": "",