mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-26 07:35:44 -03:00
Initial commit
This commit is contained in:
326
nodes.py
Normal file
326
nodes.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# nodes.py 更新后的核心代码
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from aiohttp import web
|
||||
from server import PromptServer
|
||||
import jinja2
|
||||
from flask import jsonify, request
|
||||
from safetensors import safe_open
|
||||
from .utils.file_utils import get_file_info, save_metadata, load_metadata, update_civitai_metadata
|
||||
from .utils.lora_metadata import extract_lora_metadata
|
||||
from typing import Dict, Optional
|
||||
|
||||
class LorasEndpoint:
|
||||
def __init__(self):
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(
|
||||
os.path.join(os.path.dirname(__file__), 'templates')
|
||||
),
|
||||
autoescape=True
|
||||
)
|
||||
# 配置Loras根目录(根据实际安装位置调整)
|
||||
self.loras_root = os.path.join(Path(__file__).parents[2], "models", "loras")
|
||||
|
||||
@classmethod
|
||||
def add_routes(cls):
|
||||
instance = cls()
|
||||
app = PromptServer.instance.app
|
||||
static_path = os.path.join(os.path.dirname(__file__), 'static')
|
||||
app.add_routes([
|
||||
web.get('/loras', instance.handle_loras_request),
|
||||
web.static('/loras_static/previews', instance.loras_root),
|
||||
web.static('/loras_static', static_path),
|
||||
web.post('/api/delete_model', instance.delete_model)
|
||||
])
|
||||
|
||||
def send_progress(self, current, total, status="Scanning"):
|
||||
"""Send progress through websocket"""
|
||||
if self.server and hasattr(self.server, 'send_sync'):
|
||||
self.server.send_sync("lora-scan-progress", {
|
||||
"value": current,
|
||||
"max": total,
|
||||
"status": status
|
||||
})
|
||||
|
||||
async def scan_loras(self):
|
||||
"""扫描Loras目录并返回结构化数据"""
|
||||
loras = []
|
||||
folders = set()
|
||||
|
||||
# 遍历Loras目录(包含子目录)
|
||||
for root, _, files in os.walk(self.loras_root):
|
||||
rel_path = os.path.relpath(root, self.loras_root)
|
||||
if rel_path == ".":
|
||||
current_folder = "root"
|
||||
else:
|
||||
current_folder = rel_path.replace(os.sep, "/")
|
||||
folders.add(current_folder)
|
||||
|
||||
for file in files:
|
||||
safetensors_files = [f for f in files if f.endswith('.safetensors')]
|
||||
total_files = len(safetensors_files)
|
||||
|
||||
# 识别模型文件
|
||||
if file.endswith('.safetensors'):
|
||||
base_name = os.path.splitext(file)[0]
|
||||
model_path = os.path.join(root, file)
|
||||
|
||||
# Get basic file info and metadata
|
||||
file_info = await get_file_info(model_path)
|
||||
base_model_info = await extract_lora_metadata(model_path)
|
||||
file_info.update(base_model_info)
|
||||
|
||||
# Load existing metadata or create new one
|
||||
metadata = await load_metadata(model_path)
|
||||
if not metadata:
|
||||
# First time scanning this file
|
||||
await save_metadata(model_path, file_info)
|
||||
metadata = file_info
|
||||
else:
|
||||
# Update basic file info in existing metadata
|
||||
metadata.update(file_info)
|
||||
await save_metadata(model_path, metadata)
|
||||
|
||||
# Add civitai data to return value if exists
|
||||
if 'civitai' in metadata:
|
||||
metadata.update(metadata['civitai'])
|
||||
|
||||
# 查找预览图
|
||||
preview_path = os.path.join(root, f"{base_name}.preview.png")
|
||||
preview_url = await self.get_preview_url(preview_path, root) if os.path.exists(preview_path) else None
|
||||
|
||||
loras.append({
|
||||
"name": base_name,
|
||||
"folder": current_folder,
|
||||
"path": model_path,
|
||||
"preview_url": preview_url,
|
||||
"metadata": metadata,
|
||||
"size": os.path.getsize(model_path),
|
||||
"modified": os.path.getmtime(model_path)
|
||||
})
|
||||
|
||||
self.send_progress(total_files, total_files, "Scan completed")
|
||||
return {
|
||||
"loras": sorted(loras, key=lambda x: x["name"].lower()),
|
||||
"folders": sorted(folders)
|
||||
}
|
||||
|
||||
async def parse_model_metadata(self, file_path):
|
||||
"""从safetensors文件中提取元数据"""
|
||||
try:
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata:
|
||||
return metadata
|
||||
except Exception as e:
|
||||
print(f"Error reading metadata from {file_path}: {str(e)}")
|
||||
return {}
|
||||
|
||||
async def parse_metadata(self, meta_file):
|
||||
"""解析元数据文件"""
|
||||
try:
|
||||
if os.path.exists(meta_file):
|
||||
with open(meta_file, 'r', encoding='utf-8') as f:
|
||||
meta = json.load(f)
|
||||
return {
|
||||
"id": meta.get("id"),
|
||||
"modelId": meta.get("modelId"),
|
||||
"model": meta.get("model", {}).get("name"),
|
||||
"base_model": meta.get("baseModel"),
|
||||
"trained_words": meta.get("trainedWords", []),
|
||||
"creator": meta.get("creator", {}).get("username"),
|
||||
"downloads": meta.get("stats", {}).get("downloadCount", 0),
|
||||
"images": [img["url"] for img in meta.get("images", [])[:3]],
|
||||
"description": self.clean_description(
|
||||
meta.get("model", {}).get("description", "")
|
||||
)
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error parsing metadata {meta_file}: {str(e)}")
|
||||
return {}
|
||||
|
||||
def clean_description(self, desc):
|
||||
"""清理HTML格式的描述"""
|
||||
return desc.replace("<p>", "").replace("</p>", "\n").strip()
|
||||
|
||||
async def get_preview_url(self, preview_path, root_dir):
|
||||
"""生成预览图URL"""
|
||||
if os.path.exists(preview_path):
|
||||
rel_path = os.path.relpath(preview_path, self.loras_root)
|
||||
return f"/loras_static/previews/{rel_path.replace(os.sep, '/')}"
|
||||
return "/loras_static/images/no-preview.png"
|
||||
|
||||
async def handle_loras_request(self, request):
|
||||
"""处理Loras请求并渲染模板"""
|
||||
try:
|
||||
scan_start = time.time()
|
||||
data = await self.scan_loras()
|
||||
print(f"Scanned {len(data['loras'])} loras in {time.time()-scan_start:.2f}s")
|
||||
|
||||
# Format the data for the template
|
||||
formatted_loras = [self.format_lora(l) for l in data["loras"]]
|
||||
|
||||
# Debug logging
|
||||
if formatted_loras:
|
||||
print(f"Sample lora data: {formatted_loras[0]}")
|
||||
else:
|
||||
print("Warning: No loras found")
|
||||
|
||||
context = {
|
||||
"folders": data.get("folders", []),
|
||||
"loras": formatted_loras,
|
||||
# Only set single lora if we're viewing details
|
||||
"lora": formatted_loras[0] if formatted_loras else {
|
||||
"name": "",
|
||||
"folder": "",
|
||||
"file_name": "",
|
||||
"preview_url": "",
|
||||
"modified": "",
|
||||
"size": "0MB",
|
||||
"meta": {
|
||||
"id": "",
|
||||
"model": "",
|
||||
"base_model": "",
|
||||
"trained_words": [],
|
||||
"creator": "",
|
||||
"downloads": 0,
|
||||
"images": [],
|
||||
"description": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(**context)
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error handling loras request: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc()) # Print full stack trace
|
||||
return web.Response(
|
||||
text="Error loading loras page",
|
||||
content_type='text/html',
|
||||
status=500
|
||||
)
|
||||
|
||||
def format_lora(self, lora):
|
||||
"""格式化前端需要的数据结构"""
|
||||
try:
|
||||
metadata = lora.get("metadata", {})
|
||||
|
||||
return {
|
||||
"name": lora["name"],
|
||||
"folder": lora["folder"],
|
||||
"preview_url": lora["preview_url"],
|
||||
"modified": time.strftime("%Y-%m-%d %H:%M",
|
||||
time.localtime(lora["modified"])),
|
||||
"size": f"{lora['size']/1024/1024:.1f}MB",
|
||||
"meta": {
|
||||
"id": metadata.get("id", ""),
|
||||
"modelId": metadata.get("modelId", ""),
|
||||
"model": metadata.get("model", ""),
|
||||
"base_model": metadata.get("base_model", ""),
|
||||
"trained_words": metadata.get("trained_words", []),
|
||||
"creator": metadata.get("creator", ""),
|
||||
"downloads": metadata.get("downloads", 0),
|
||||
"images": metadata.get("images", []),
|
||||
"description": metadata.get("description", "")
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Error formatting lora: {str(e)}")
|
||||
print(f"Lora data: {lora}")
|
||||
return {
|
||||
"name": lora.get("name", "Unknown"),
|
||||
"folder": lora.get("folder", ""),
|
||||
"preview_url": lora.get("preview_url", ""),
|
||||
"modified": "",
|
||||
"size": "0MB",
|
||||
"meta": {
|
||||
"id": "",
|
||||
"modelId": "",
|
||||
"model": "",
|
||||
"base_model": "",
|
||||
"trained_words": [],
|
||||
"creator": "",
|
||||
"downloads": 0,
|
||||
"images": [],
|
||||
"description": ""
|
||||
}
|
||||
}
|
||||
|
||||
async def delete_model(self, request):
|
||||
try:
|
||||
data = await request.json()
|
||||
model_name = data.get('model_name')
|
||||
folder = data.get('folder') # 从请求中获取folder信息
|
||||
if not model_name:
|
||||
return web.Response(text='Model name is required', status=400)
|
||||
|
||||
# 构建完整的目录路径
|
||||
target_dir = self.loras_root
|
||||
if folder and folder != "root":
|
||||
target_dir = os.path.join(self.loras_root, folder)
|
||||
|
||||
# List of file patterns to delete
|
||||
required_file = f"{model_name}.safetensors" # 主文件必须存在
|
||||
optional_files = [ # 这些文件可能不存在
|
||||
f"{model_name}.civitai.info",
|
||||
f"{model_name}.preview.png"
|
||||
]
|
||||
|
||||
deleted_files = []
|
||||
|
||||
# Try to delete the main safetensors file
|
||||
main_file_path = os.path.join(target_dir, required_file)
|
||||
if os.path.exists(main_file_path):
|
||||
try:
|
||||
os.remove(main_file_path)
|
||||
deleted_files.append(required_file)
|
||||
except Exception as e:
|
||||
print(f"Error deleting {main_file_path}: {str(e)}")
|
||||
return web.Response(text=f"Failed to delete main model file: {str(e)}", status=500)
|
||||
|
||||
# Only try to delete optional files if main file was deleted
|
||||
for pattern in optional_files:
|
||||
file_path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
os.remove(file_path)
|
||||
deleted_files.append(pattern)
|
||||
except Exception as e:
|
||||
print(f"Error deleting optional file {file_path}: {str(e)}")
|
||||
else:
|
||||
return web.Response(text=f"Model file {required_file} not found in {folder}", status=404)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'deleted_files': deleted_files
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def update_civitai_info(self, file_path: str, civitai_data: Dict, preview_url: Optional[str] = None):
|
||||
"""Update Civitai metadata and download preview image"""
|
||||
# Update metadata file
|
||||
await update_civitai_metadata(file_path, civitai_data)
|
||||
|
||||
# Download and save preview image if URL is provided
|
||||
if preview_url:
|
||||
preview_path = f"{os.path.splitext(file_path)[0]}.preview.png"
|
||||
try:
|
||||
# Add your image download logic here
|
||||
# Example:
|
||||
# await download_image(preview_url, preview_path)
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Error downloading preview image: {str(e)}")
|
||||
|
||||
# 注册路由
|
||||
LorasEndpoint.add_routes()
|
||||
Reference in New Issue
Block a user