import os import json import time from pathlib import Path from aiohttp import web from server import PromptServer import jinja2 from flask import jsonify, request from safetensors import safe_open from .utils.file_utils import get_file_info, save_metadata, load_metadata, update_civitai_metadata from .utils.lora_metadata import extract_lora_metadata from typing import Dict, Optional class LorasEndpoint: def __init__(self): self.template_env = jinja2.Environment( loader=jinja2.FileSystemLoader( os.path.join(os.path.dirname(__file__), 'templates') ), autoescape=True ) # 配置Loras根目录(根据实际安装位置调整) self.loras_root = os.path.join(Path(__file__).parents[2], "models", "loras") # 添加 server 属性 self.server = PromptServer.instance @classmethod def add_routes(cls): instance = cls() app = PromptServer.instance.app static_path = os.path.join(os.path.dirname(__file__), 'static') app.add_routes([ web.get('/loras', instance.handle_loras_request), web.static('/loras_static/previews', instance.loras_root), web.static('/loras_static', static_path), web.post('/api/delete_model', instance.delete_model) ]) def send_progress(self, current, total, status="Scanning"): """Send progress through websocket""" try: if hasattr(self.server, 'send_sync'): self.server.send_sync("lora-scan-progress", { "value": current, "max": total, "status": status }) except Exception as e: print(f"Error sending progress: {str(e)}") async def scan_loras(self): loras = [] for root, _, files in os.walk(self.loras_root): safetensors_files = [f for f in files if f.endswith('.safetensors')] total_files = len(safetensors_files) for idx, filename in enumerate(safetensors_files, 1): self.send_progress(idx, total_files, f"Scanning: {filename}") file_path = os.path.join(root, filename) # Try to load existing metadata first metadata = await load_metadata(file_path) if metadata is None: # Only get file info and extract metadata if no existing metadata metadata = await get_file_info(file_path) base_model_info = await extract_lora_metadata(file_path) metadata.base_model = base_model_info['base_model'] await save_metadata(file_path, metadata) # Convert to dict for API response lora_data = metadata.to_dict() # Get relative path and remove filename to get just the folder structure rel_path = os.path.relpath(file_path, self.loras_root) folder = os.path.dirname(rel_path) # Ensure forward slashes for consistency across platforms lora_data['folder'] = folder.replace(os.path.sep, '/') loras.append(lora_data) self.send_progress(total_files, total_files, "Scan completed") return loras def clean_description(self, desc): """清理HTML格式的描述""" return desc.replace("
", "").replace("
", "\n").strip() async def get_preview_url(self, preview_path, root_dir): """生成预览图URL""" if os.path.exists(preview_path): rel_path = os.path.relpath(preview_path, self.loras_root) return f"/loras_static/previews/{rel_path.replace(os.sep, '/')}" return "/loras_static/images/no-preview.png" async def handle_loras_request(self, request): """处理Loras请求并渲染模板""" try: scan_start = time.time() data = await self.scan_loras() print(f"Scanned {len(data)} loras in {time.time()-scan_start:.2f}s") # Format the data for the template formatted_loras = [self.format_lora(l) for l in data] folders = sorted(list(set(l['folder'] for l in data))) print("folders:",folders) # Debug logging if formatted_loras: print(f"Sample lora data: {formatted_loras[0]}") else: print("Warning: No loras found") context = { "loras": formatted_loras, "folders": folders, # Only set single lora if we're viewing details "lora": formatted_loras[0] if formatted_loras else { "model_name": "", "file_name": "", "preview_url": "", "folder": "", "civitai": { "id": "", "model": "", "base_model": "", "trained_words": [], "creator": "", "downloads": 0, "images": [], "description": "" } } } template = self.template_env.get_template('loras.html') rendered = template.render(**context) return web.Response( text=rendered, content_type='text/html' ) except Exception as e: print(f"Error handling loras request: {str(e)}") import traceback print(traceback.format_exc()) # Print full stack trace return web.Response( text="Error loading loras page", content_type='text/html', status=500 ) def format_lora(self, lora): """格式化前端需要的数据结构""" try: return { "model_name": lora["model_name"], "file_name": lora["file_name"], "preview_url": lora["preview_url"], "base_model": lora["base_model"], "folder": lora["folder"], "civitai": lora.get("civitai", {}) or {} # 确保当 civitai 为 None 时返回空字典 } except Exception as e: print(f"Error formatting lora: {str(e)}") print(f"Lora data: {lora}") return { "model_name": lora.get("model_name", "Unknown"), "file_name": lora.get("file_name", ""), "preview_url": lora.get("preview_url", ""), "base_model": lora.get("base_model", ""), "folder": lora.get("folder", ""), "civitai": { "id": "", "modelId": "", "model": "", "base_model": "", "trained_words": [], "creator": "", "downloads": 0, "images": [], "description": "" } } async def delete_model(self, request): try: data = await request.json() model_name = data.get('model_name') folder = data.get('folder') # 从请求中获取folder信息 if not model_name: return web.Response(text='Model name is required', status=400) # 构建完整的目录路径 target_dir = self.loras_root if folder and folder != "root": target_dir = os.path.join(self.loras_root, folder) # List of file patterns to delete required_file = f"{model_name}.safetensors" # 主文件必须存在 optional_files = [ # 这些文件可能不存在 f"{model_name}.civitai.info", f"{model_name}.preview.png" ] deleted_files = [] # Try to delete the main safetensors file main_file_path = os.path.join(target_dir, required_file) if os.path.exists(main_file_path): try: os.remove(main_file_path) deleted_files.append(required_file) except Exception as e: print(f"Error deleting {main_file_path}: {str(e)}") return web.Response(text=f"Failed to delete main model file: {str(e)}", status=500) # Only try to delete optional files if main file was deleted for pattern in optional_files: file_path = os.path.join(target_dir, pattern) if os.path.exists(file_path): try: os.remove(file_path) deleted_files.append(pattern) except Exception as e: print(f"Error deleting optional file {file_path}: {str(e)}") else: return web.Response(text=f"Model file {required_file} not found in {folder}", status=404) return web.json_response({ 'success': True, 'deleted_files': deleted_files }) except Exception as e: return web.Response(text=str(e), status=500) async def update_civitai_info(self, file_path: str, civitai_data: Dict, preview_url: Optional[str] = None): """Update Civitai metadata and download preview image""" # Update metadata file await update_civitai_metadata(file_path, civitai_data) # Download and save preview image if URL is provided if preview_url: preview_path = f"{os.path.splitext(file_path)[0]}.preview.png" try: # Add your image download logic here # Example: # await download_image(preview_url, preview_path) pass except Exception as e: print(f"Error downloading preview image: {str(e)}") # 注册路由 LorasEndpoint.add_routes()