commit ad6137d355fa747f27f0e3207e4bb62765243f83 Author: Will Miao <13051207myq@gmail.com> Date: Sat Jan 25 19:22:02 2025 +0800 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..ba0430d2 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 00000000..265926e6 --- /dev/null +++ b/__init__.py @@ -0,0 +1,13 @@ +from .nodes import LorasEndpoint + +NODE_CLASS_MAPPINGS = { + "LorasEndpoint": LorasEndpoint +} + +WEB_DIRECTORY = "./js" + +# Add custom websocket event type +EXTENSION_WEB_SOCKET_MESSAGE_TYPES = ["lora-scan-progress"] + +__all__ = ['NODE_CLASS_MAPPINGS'] + diff --git a/manifest.json b/manifest.json new file mode 100644 index 00000000..3d75d681 --- /dev/null +++ b/manifest.json @@ -0,0 +1,7 @@ +{ + "name": "Loras Endpoint", + "version": "1.0.0", + "author": "Your Name", + "project": "https://github.com/your/repository", + "description": "Adds /loras endpoint to ComfyUI" +} \ No newline at end of file diff --git a/nodes.py b/nodes.py new file mode 100644 index 00000000..779ad743 --- /dev/null +++ b/nodes.py @@ -0,0 +1,326 @@ +# nodes.py 更新后的核心代码 +import os +import json +import time +from pathlib import Path +from aiohttp import web +from server import PromptServer +import jinja2 +from flask import jsonify, request +from safetensors import safe_open +from .utils.file_utils import get_file_info, save_metadata, load_metadata, update_civitai_metadata +from .utils.lora_metadata import extract_lora_metadata +from typing import Dict, Optional + +class LorasEndpoint: + def __init__(self): + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader( + os.path.join(os.path.dirname(__file__), 'templates') + ), + autoescape=True + ) + # 配置Loras根目录(根据实际安装位置调整) + self.loras_root = os.path.join(Path(__file__).parents[2], "models", "loras") + + @classmethod + def add_routes(cls): + instance = cls() + app = PromptServer.instance.app + static_path = os.path.join(os.path.dirname(__file__), 'static') + app.add_routes([ + web.get('/loras', instance.handle_loras_request), + web.static('/loras_static/previews', instance.loras_root), + web.static('/loras_static', static_path), + web.post('/api/delete_model', instance.delete_model) + ]) + + def send_progress(self, current, total, status="Scanning"): + """Send progress through websocket""" + if self.server and hasattr(self.server, 'send_sync'): + self.server.send_sync("lora-scan-progress", { + "value": current, + "max": total, + "status": status + }) + + async def scan_loras(self): + """扫描Loras目录并返回结构化数据""" + loras = [] + folders = set() + + # 遍历Loras目录(包含子目录) + for root, _, files in os.walk(self.loras_root): + rel_path = os.path.relpath(root, self.loras_root) + if rel_path == ".": + current_folder = "root" + else: + current_folder = rel_path.replace(os.sep, "/") + folders.add(current_folder) + + for file in files: + safetensors_files = [f for f in files if f.endswith('.safetensors')] + total_files = len(safetensors_files) + + # 识别模型文件 + if file.endswith('.safetensors'): + base_name = os.path.splitext(file)[0] + model_path = os.path.join(root, file) + + # Get basic file info and metadata + file_info = await get_file_info(model_path) + base_model_info = await extract_lora_metadata(model_path) + file_info.update(base_model_info) + + # Load existing metadata or create new one + metadata = await load_metadata(model_path) + if not metadata: + # First time scanning this file + await save_metadata(model_path, file_info) + metadata = file_info + else: + # Update basic file info in existing metadata + metadata.update(file_info) + await save_metadata(model_path, metadata) + + # Add civitai data to return value if exists + if 'civitai' in metadata: + metadata.update(metadata['civitai']) + + # 查找预览图 + preview_path = os.path.join(root, f"{base_name}.preview.png") + preview_url = await self.get_preview_url(preview_path, root) if os.path.exists(preview_path) else None + + loras.append({ + "name": base_name, + "folder": current_folder, + "path": model_path, + "preview_url": preview_url, + "metadata": metadata, + "size": os.path.getsize(model_path), + "modified": os.path.getmtime(model_path) + }) + + self.send_progress(total_files, total_files, "Scan completed") + return { + "loras": sorted(loras, key=lambda x: x["name"].lower()), + "folders": sorted(folders) + } + + async def parse_model_metadata(self, file_path): + """从safetensors文件中提取元数据""" + try: + with safe_open(file_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + if metadata: + return metadata + except Exception as e: + print(f"Error reading metadata from {file_path}: {str(e)}") + return {} + + async def parse_metadata(self, meta_file): + """解析元数据文件""" + try: + if os.path.exists(meta_file): + with open(meta_file, 'r', encoding='utf-8') as f: + meta = json.load(f) + return { + "id": meta.get("id"), + "modelId": meta.get("modelId"), + "model": meta.get("model", {}).get("name"), + "base_model": meta.get("baseModel"), + "trained_words": meta.get("trainedWords", []), + "creator": meta.get("creator", {}).get("username"), + "downloads": meta.get("stats", {}).get("downloadCount", 0), + "images": [img["url"] for img in meta.get("images", [])[:3]], + "description": self.clean_description( + meta.get("model", {}).get("description", "") + ) + } + except Exception as e: + print(f"Error parsing metadata {meta_file}: {str(e)}") + return {} + + def clean_description(self, desc): + """清理HTML格式的描述""" + return desc.replace("

", "").replace("

", "\n").strip() + + async def get_preview_url(self, preview_path, root_dir): + """生成预览图URL""" + if os.path.exists(preview_path): + rel_path = os.path.relpath(preview_path, self.loras_root) + return f"/loras_static/previews/{rel_path.replace(os.sep, '/')}" + return "/loras_static/images/no-preview.png" + + async def handle_loras_request(self, request): + """处理Loras请求并渲染模板""" + try: + scan_start = time.time() + data = await self.scan_loras() + print(f"Scanned {len(data['loras'])} loras in {time.time()-scan_start:.2f}s") + + # Format the data for the template + formatted_loras = [self.format_lora(l) for l in data["loras"]] + + # Debug logging + if formatted_loras: + print(f"Sample lora data: {formatted_loras[0]}") + else: + print("Warning: No loras found") + + context = { + "folders": data.get("folders", []), + "loras": formatted_loras, + # Only set single lora if we're viewing details + "lora": formatted_loras[0] if formatted_loras else { + "name": "", + "folder": "", + "file_name": "", + "preview_url": "", + "modified": "", + "size": "0MB", + "meta": { + "id": "", + "model": "", + "base_model": "", + "trained_words": [], + "creator": "", + "downloads": 0, + "images": [], + "description": "" + } + } + } + + template = self.template_env.get_template('loras.html') + rendered = template.render(**context) + return web.Response( + text=rendered, + content_type='text/html' + ) + except Exception as e: + print(f"Error handling loras request: {str(e)}") + import traceback + print(traceback.format_exc()) # Print full stack trace + return web.Response( + text="Error loading loras page", + content_type='text/html', + status=500 + ) + + def format_lora(self, lora): + """格式化前端需要的数据结构""" + try: + metadata = lora.get("metadata", {}) + + return { + "name": lora["name"], + "folder": lora["folder"], + "preview_url": lora["preview_url"], + "modified": time.strftime("%Y-%m-%d %H:%M", + time.localtime(lora["modified"])), + "size": f"{lora['size']/1024/1024:.1f}MB", + "meta": { + "id": metadata.get("id", ""), + "modelId": metadata.get("modelId", ""), + "model": metadata.get("model", ""), + "base_model": metadata.get("base_model", ""), + "trained_words": metadata.get("trained_words", []), + "creator": metadata.get("creator", ""), + "downloads": metadata.get("downloads", 0), + "images": metadata.get("images", []), + "description": metadata.get("description", "") + } + } + except Exception as e: + print(f"Error formatting lora: {str(e)}") + print(f"Lora data: {lora}") + return { + "name": lora.get("name", "Unknown"), + "folder": lora.get("folder", ""), + "preview_url": lora.get("preview_url", ""), + "modified": "", + "size": "0MB", + "meta": { + "id": "", + "modelId": "", + "model": "", + "base_model": "", + "trained_words": [], + "creator": "", + "downloads": 0, + "images": [], + "description": "" + } + } + + async def delete_model(self, request): + try: + data = await request.json() + model_name = data.get('model_name') + folder = data.get('folder') # 从请求中获取folder信息 + if not model_name: + return web.Response(text='Model name is required', status=400) + + # 构建完整的目录路径 + target_dir = self.loras_root + if folder and folder != "root": + target_dir = os.path.join(self.loras_root, folder) + + # List of file patterns to delete + required_file = f"{model_name}.safetensors" # 主文件必须存在 + optional_files = [ # 这些文件可能不存在 + f"{model_name}.civitai.info", + f"{model_name}.preview.png" + ] + + deleted_files = [] + + # Try to delete the main safetensors file + main_file_path = os.path.join(target_dir, required_file) + if os.path.exists(main_file_path): + try: + os.remove(main_file_path) + deleted_files.append(required_file) + except Exception as e: + print(f"Error deleting {main_file_path}: {str(e)}") + return web.Response(text=f"Failed to delete main model file: {str(e)}", status=500) + + # Only try to delete optional files if main file was deleted + for pattern in optional_files: + file_path = os.path.join(target_dir, pattern) + if os.path.exists(file_path): + try: + os.remove(file_path) + deleted_files.append(pattern) + except Exception as e: + print(f"Error deleting optional file {file_path}: {str(e)}") + else: + return web.Response(text=f"Model file {required_file} not found in {folder}", status=404) + + return web.json_response({ + 'success': True, + 'deleted_files': deleted_files + }) + + except Exception as e: + return web.Response(text=str(e), status=500) + + async def update_civitai_info(self, file_path: str, civitai_data: Dict, preview_url: Optional[str] = None): + """Update Civitai metadata and download preview image""" + # Update metadata file + await update_civitai_metadata(file_path, civitai_data) + + # Download and save preview image if URL is provided + if preview_url: + preview_path = f"{os.path.splitext(file_path)[0]}.preview.png" + try: + # Add your image download logic here + # Example: + # await download_image(preview_url, preview_path) + pass + except Exception as e: + print(f"Error downloading preview image: {str(e)}") + +# 注册路由 +LorasEndpoint.add_routes() \ No newline at end of file diff --git a/static/css/style.css b/static/css/style.css new file mode 100644 index 00000000..9b41e5ab --- /dev/null +++ b/static/css/style.css @@ -0,0 +1,315 @@ +:root { + --bg-color: #ffffff; + --text-color: #333333; + --card-bg: #ffffff; + --border-color: #e0e0e0; +} + +[data-theme="dark"] { + --bg-color: #1a1a1a; + --text-color: #e0e0e0; + --card-bg: #2d2d2d; + --border-color: #404040; +} + +body { + margin: 0; + font-family: 'Segoe UI', sans-serif; + background: var(--background-color); +} + +.container { + max-width: 1400px; + margin: 20px auto; + padding: 0 15px; +} + +/* 文件夹标签样式 */ +.folder-tags { + display: flex; + gap: 8px; + overflow-x: auto; + padding: 10px 0; +} + +.tag { + padding: 6px 12px; + border-radius: 20px; + background: #e0e0e0; + cursor: pointer; + white-space: nowrap; + transition: all 0.3s; +} + +.tag.active { + background: var(--primary-color); + color: white; +} + +/* 卡片网格布局 */ +.card-grid { + display: grid; + grid-template-columns: repeat(auto-fill, minmax(240px, 1fr)); + gap: 12px; + margin-top: 20px; +} + +.lora-card { + background: var(--card-bg); + border-radius: 12px; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); + overflow: hidden; + transition: transform 0.2s; + aspect-ratio: 896/1152; + max-width: 240px; + margin: 0 auto; +} + +.lora-card:hover { + transform: translateY(-5px); +} + +/* Card Preview and Footer Overlay */ +.card-preview { + position: relative; + width: 100%; + height: 100%; +} + +.card-preview img { + width: 100%; + height: 100%; + object-fit: cover; +} + +.card-footer { + position: absolute; + bottom: 0; + left: 0; + right: 0; + background: linear-gradient(transparent, rgba(0, 0, 0, 0.85)); + color: white; + padding: 15px; + display: flex; + justify-content: space-between; + align-items: flex-end; + min-height: 80px; +} + +.model-name { + font-weight: bold; + text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.5); +} + +.model-meta { + font-size: 0.9em; + opacity: 0.9; +} + +.card-header { + position: absolute; + top: 0; + left: 0; + right: 0; + background: linear-gradient(rgba(0, 0, 0, 0.85), transparent); + color: white; + padding: 15px; + display: flex; + justify-content: space-between; + align-items: center; + z-index: 1; +} + +.card-actions i { + margin-left: 10px; + cursor: pointer; + color: white; + transition: opacity 0.2s; +} + +.card-actions i:hover { + opacity: 0.8; +} + +/* 响应式设计 */ +@media (max-width: 768px) { + .card-grid { + grid-template-columns: 1fr; + } + + .controls { + flex-direction: column; + gap: 15px; + } +} + +/* 新增元数据相关样式 */ +.model-info { + flex: 1; +} + +.model-meta { + font-size: 0.8em; + color: #666; + margin-top: 4px; +} + +.base-model { + display: inline-block; + background: #f0f0f0; + padding: 2px 6px; + border-radius: 4px; + margin-right: 6px; +} + +.file-size, +.modified { + display: block; + margin-top: 2px; +} + +.tooltip { + position: relative; + cursor: help; +} + +.tooltip::after { + content: attr(data-tooltip); + position: absolute; + bottom: 120%; + left: 50%; + transform: translateX(-50%); + background: rgba(0, 0, 0, 0.8); + color: white; + padding: 4px 8px; + border-radius: 4px; + font-size: 0.8em; + white-space: nowrap; + opacity: 0; + transition: opacity 0.2s; + pointer-events: none; +} + +.tooltip:hover::after { + opacity: 1; +} + +/* 模态窗口样式 */ +.modal { + display: none; + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background: rgba(0, 0, 0, 0.8); + z-index: 1000; + overflow-y: auto; /* 允许模态窗口内容滚动 */ +} + +/* 当模态窗口打开时,禁止body滚动 */ +body.modal-open { + overflow: hidden; +} + +.modal-content { + position: relative; + max-width: 800px; + margin: 2rem auto; + background: var(--card-bg); + border-radius: 12px; + padding: 20px; +} + +.carousel { + display: grid; + grid-auto-flow: column; + gap: 1rem; + overflow-x: auto; + scroll-snap-type: x mandatory; +} + +.carousel img { + scroll-snap-align: start; + max-height: 60vh; + object-fit: contain; +} + +/* 主题切换按钮 */ +.theme-toggle { + position: fixed; + top: 20px; + right: 20px; + cursor: pointer; + padding: 8px; + border-radius: 50%; + background: var(--card-bg); +} + +.base-model-label { + max-width: 120px; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + display: inline-block; + color: white; + text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.5); + background: rgba(255, 255, 255, 0.2); + padding: 2px 8px; + border-radius: 4px; + backdrop-filter: blur(2px); +} + +.loading-overlay { + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background: rgba(0, 0, 0, 0.7); + display: flex; + justify-content: center; + align-items: center; + z-index: 1000; +} + +.loading-content { + background: #fff; + padding: 2rem; + border-radius: 8px; + text-align: center; +} + +.loading-spinner { + border: 4px solid #f3f3f3; + border-top: 4px solid #3498db; + border-radius: 50%; + width: 40px; + height: 40px; + animation: spin 1s linear infinite; + margin: 0 auto 1rem; +} + +.loading-status { + margin-bottom: 1rem; + color: #333; +} + +.progress-container { + width: 300px; + background-color: #f0f0f0; + border-radius: 4px; + overflow: hidden; +} + +.progress-bar { + width: 0%; + height: 20px; + background-color: #4CAF50; + transition: width 0.3s ease; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} \ No newline at end of file diff --git a/static/images/no-preview.png b/static/images/no-preview.png new file mode 100644 index 00000000..e2beb269 Binary files /dev/null and b/static/images/no-preview.png differ diff --git a/static/images/theme-toggle.svg b/static/images/theme-toggle.svg new file mode 100644 index 00000000..75e4e6d9 --- /dev/null +++ b/static/images/theme-toggle.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/static/js/script.js b/static/js/script.js new file mode 100644 index 00000000..cd45705d --- /dev/null +++ b/static/js/script.js @@ -0,0 +1,229 @@ +// 排序功能 +function sortCards(sortBy) { + const grid = document.getElementById('loraGrid'); + const cards = Array.from(grid.children); + + cards.sort((a, b) => { + switch(sortBy) { + case 'name': + return a.dataset.name.localeCompare(b.dataset.name); + case 'date': + return new Date(b.dataset.date) - new Date(a.dataset.date); + case 'size': + return parseFloat(b.dataset.size) - parseFloat(a.dataset.size); + } + }); + + cards.forEach(card => grid.appendChild(card)); +} + +// 文件夹筛选 +document.querySelectorAll('.tag').forEach(tag => { + tag.addEventListener('click', () => { + document.querySelectorAll('.tag').forEach(t => t.classList.remove('active')); + tag.classList.add('active'); + const folder = tag.dataset.folder; + filterByFolder(folder); + }); +}); + +function filterByFolder(folder) { + document.querySelectorAll('.lora-card').forEach(card => { + card.style.display = card.dataset.folder === folder ? 'block' : 'none'; + }); +} + +// 刷新功能 +async function refreshLoras() { + try { + const response = await fetch('/loras?refresh=true'); + if (response.ok) { + location.reload(); + } + } catch (error) { + console.error('Refresh failed:', error); + } +} + +// 占位功能函数 +function openCivitai(loraName) { + // 从卡片的data-meta属性中获取civitai ID + const loraCard = document.querySelector(`.lora-card[data-name="${loraName}"]`); + if (!loraCard) return; + + const metaData = JSON.parse(loraCard.dataset.meta); + const civitaiId = metaData.modelId; // 使用modelId作为civitai模型ID + const versionId = metaData.id; // 使用id作为版本ID + + // 构建URL + if (civitaiId) { + let url = `https://civitai.com/models/${civitaiId}`; + if (versionId) { + url += `?modelVersionId=${versionId}`; + } + window.open(url, '_blank'); + } else { + // 如果没有ID,尝试使用名称搜索 + window.open(`https://civitai.com/models?query=${encodeURIComponent(loraName)}`, '_blank'); + } +} + +async function deleteModel(modelName) { + // Prevent event bubbling + event.stopPropagation(); + + // Get the folder from the card's data attributes + const card = document.querySelector(`.lora-card[data-name="${modelName}"]`); + const folder = card ? card.dataset.folder : null; + + // Show confirmation dialog + const confirmed = confirm(`Are you sure you want to delete "${modelName}" and all associated files?`); + + if (confirmed) { + try { + const response = await fetch('/api/delete_model', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model_name: modelName, + folder: folder + }) + }); + + if (response.ok) { + // Remove the card from UI + if (card) { + card.remove(); + } + // Show success message + alert('Model deleted successfully'); + } else { + const error = await response.text(); + alert(`Failed to delete model: ${error}`); + } + } catch (error) { + alert(`Error deleting model: ${error}`); + } + } +} + +// 初始化排序 +document.getElementById('sortSelect').addEventListener('change', (e) => { + sortCards(e.target.value); +}); + +// 添加搜索功能 +document.getElementById('searchInput')?.addEventListener('input', (e) => { + const term = e.target.value.toLowerCase(); + document.querySelectorAll('.lora-card').forEach(card => { + const match = card.dataset.name.toLowerCase().includes(term) || + card.dataset.folder.toLowerCase().includes(term); + card.style.display = match ? 'block' : 'none'; + }); +}); + +// 模态窗口管理 +let currentLora = null; +let currentImageIndex = 0; + +document.querySelectorAll('.lora-card').forEach(card => { + card.addEventListener('click', () => { + currentLora = JSON.parse(card.dataset.meta); + showModal(currentLora); + }); +}); + +function showModal(lora) { + const modal = document.getElementById('loraModal'); + modal.innerHTML = ` + + `; + modal.style.display = 'block'; + document.body.classList.add('modal-open'); + + // 添加点击事件监听器 + modal.onclick = function(event) { + // 如果点击的是模态窗口的背景(不是内容区域),则关闭模态窗口 + if (event.target === modal) { + closeModal(); + } + }; +} + +function closeModal() { + const modal = document.getElementById('loraModal'); + modal.style.display = 'none'; + document.body.classList.remove('modal-open'); + // 移除点击事件监听器 + modal.onclick = null; +} + +// WebSocket handling for progress updates +document.addEventListener('DOMContentLoaded', function() { + const loadingOverlay = document.getElementById('loading-overlay'); + const progressBar = document.querySelector('.progress-bar'); + const loadingStatus = document.querySelector('.loading-status'); + + // Show loading overlay initially + loadingOverlay.style.display = 'flex'; + + // Listen for progress updates + api.addEventListener("lora-scan-progress", (event) => { + const data = event.detail; + const progress = (data.value / data.max) * 100; + + progressBar.style.width = `${progress}%`; + progressBar.setAttribute('aria-valuenow', progress); + loadingStatus.textContent = data.status; + + if (data.value === data.max) { + // Hide loading overlay when scan is complete + setTimeout(() => { + loadingOverlay.style.display = 'none'; + }, 500); + } + }); +}); + +// 主题切换 +function toggleTheme() { + const theme = document.body.dataset.theme || 'light'; + document.body.dataset.theme = theme === 'light' ? 'dark' : 'light'; + localStorage.setItem('theme', document.body.dataset.theme); +} + +// 初始化主题 +function initTheme() { + const savedTheme = localStorage.getItem('theme') || 'light'; + document.body.dataset.theme = savedTheme; +} + +// 检测系统主题 +window.matchMedia('(prefers-color-scheme: dark)').addListener(e => { + document.body.dataset.theme = e.matches ? 'dark' : 'light'; +}); + +// 键盘导航 +document.addEventListener('keydown', (e) => { + if (e.key === 'Escape') closeModal(); + if (e.key === 'ArrowLeft') prevImage(); + if (e.key === 'ArrowRight') nextImage(); +}); + +// 图片预加载 +function preloadImages(urls) { + urls.forEach(url => { + new Image().src = url; + }); +} + +initTheme(); \ No newline at end of file diff --git a/templates/loras.html b/templates/loras.html new file mode 100644 index 00000000..b313de77 --- /dev/null +++ b/templates/loras.html @@ -0,0 +1,80 @@ + + + + LoRA Management + + + + + +
+ Theme +
+ + + + + +
+ +
+
+ {% for folder in folders %} +
{{ folder }}
+ {% endfor %} +
+ +
+ + + +
+
+ + +
+ {% for lora in loras %} + +
+
+ {{ lora.name }} +
+ + {{ lora.meta.base_model if lora.meta and lora.meta.base_model else 'Unknown' }} + +
+ + + +
+
+ +
+
+ {% endfor %} +
+
+ + + + \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 00000000..3aae30f6 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +# Empty file to mark directory as Python package \ No newline at end of file diff --git a/utils/file_utils.py b/utils/file_utils.py new file mode 100644 index 00000000..19bafa35 --- /dev/null +++ b/utils/file_utils.py @@ -0,0 +1,48 @@ +import os +import hashlib +import json +from typing import Dict, Optional + +async def calculate_sha256(file_path: str) -> str: + """Calculate SHA256 hash of a file""" + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() + +async def get_file_info(file_path: str) -> Dict: + """Get basic file information""" + return { + "name": os.path.splitext(os.path.basename(file_path))[0], + "file_path": file_path, + "size": os.path.getsize(file_path), + "modified": os.path.getmtime(file_path), + "sha256": await calculate_sha256(file_path) + } + +async def save_metadata(file_path: str, metadata: Dict) -> None: + """Save metadata to .metadata.json file""" + metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" + try: + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + except Exception as e: + print(f"Error saving metadata to {metadata_path}: {str(e)}") + +async def load_metadata(file_path: str) -> Dict: + """Load metadata from .metadata.json file""" + metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" + try: + if os.path.exists(metadata_path): + with open(metadata_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + print(f"Error loading metadata from {metadata_path}: {str(e)}") + return {} + +async def update_civitai_metadata(file_path: str, civitai_data: Dict) -> None: + """Update metadata file with Civitai data""" + metadata = await load_metadata(file_path) + metadata['civitai'] = civitai_data + await save_metadata(file_path, metadata) \ No newline at end of file diff --git a/utils/lora_metadata.py b/utils/lora_metadata.py new file mode 100644 index 00000000..c04cd829 --- /dev/null +++ b/utils/lora_metadata.py @@ -0,0 +1,16 @@ +from safetensors import safe_open +from typing import Dict +from .model_utils import determine_base_model + +async def extract_lora_metadata(file_path: str) -> Dict: + """Extract essential metadata from safetensors file""" + try: + with safe_open(file_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + if metadata: + # Only extract base_model from ss_base_model_version + base_model = determine_base_model(metadata.get("ss_base_model_version")) + return {"base_model": base_model} + except Exception as e: + print(f"Error reading metadata from {file_path}: {str(e)}") + return {"base_model": "Unknown"} \ No newline at end of file diff --git a/utils/model_utils.py b/utils/model_utils.py new file mode 100644 index 00000000..58e75cb0 --- /dev/null +++ b/utils/model_utils.py @@ -0,0 +1,22 @@ +from typing import Dict, Optional + +# Base model mapping based on version string +BASE_MODEL_MAPPING = { + "sd-v1-5": "SD1.5", + "sd-v2-1": "SD2.1", + "sdxl": "SDXL", + "sd-v2": "SD2.0", + "flux1": "Flux1.D", +} + +def determine_base_model(version_string: Optional[str]) -> str: + """Determine base model from version string in safetensors metadata""" + if not version_string: + return "Unknown" + + version_lower = version_string.lower() + for key, value in BASE_MODEL_MAPPING.items(): + if key in version_lower: + return value + + return "Unknown" \ No newline at end of file