mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Refactor and optimize code for improved readability and maintainability
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from .lora_manager import LoraManager
|
||||
from .nodes.lora_gateway import LoRAGateway
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
@@ -10,4 +11,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
|
||||
WEB_DIRECTORY = "./js"
|
||||
|
||||
# Register routes on import
|
||||
LoraManager.add_routes()
|
||||
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', 'WEB_DIRECTORY']
|
||||
37
config.py
Normal file
37
config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import os
|
||||
import folder_paths # type: ignore
|
||||
from typing import List
|
||||
|
||||
class Config:
|
||||
"""Global configuration for LoRA Manager"""
|
||||
|
||||
def __init__(self):
|
||||
self.loras_roots = self._init_lora_paths()
|
||||
self.templates_path = os.path.join(os.path.dirname(__file__), 'templates')
|
||||
self.static_path = os.path.join(os.path.dirname(__file__), 'static')
|
||||
|
||||
def _init_lora_paths(self) -> List[str]:
|
||||
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
||||
paths = [path.replace(os.sep, "/")
|
||||
for path in folder_paths.get_folder_paths("loras")
|
||||
if os.path.exists(path)]
|
||||
|
||||
if not paths:
|
||||
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
||||
|
||||
return paths
|
||||
|
||||
def get_preview_static_url(self, preview_path: str) -> str:
|
||||
"""Convert local preview path to static URL"""
|
||||
if not preview_path:
|
||||
return ""
|
||||
|
||||
for idx, root in enumerate(self.loras_roots, start=1):
|
||||
if preview_path.startswith(root):
|
||||
relative_path = os.path.relpath(preview_path, root)
|
||||
return f'/loras_static/root{idx}/preview/{relative_path.replace(os.sep, "/")}'
|
||||
|
||||
return ""
|
||||
|
||||
# Global config instance
|
||||
config = Config()
|
||||
389
lora_manager.py
389
lora_manager.py
@@ -1,379 +1,24 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from aiohttp import web
|
||||
from server import PromptServer
|
||||
import jinja2
|
||||
from safetensors import safe_open
|
||||
from .utils.file_utils import get_file_info, save_metadata, load_metadata, update_civitai_metadata
|
||||
from .utils.lora_metadata import extract_lora_metadata
|
||||
from typing import Dict, Optional
|
||||
from .services.civitai_client import CivitaiClient
|
||||
import folder_paths
|
||||
import logging
|
||||
|
||||
class LorasEndpoint:
|
||||
def __init__(self):
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(
|
||||
os.path.join(os.path.dirname(__file__), 'templates')
|
||||
),
|
||||
autoescape=True
|
||||
)
|
||||
# Configure Loras root directories (from ComfyUI folder paths settings)
|
||||
self.loras_roots = [path.replace(os.sep, "/") for path in folder_paths.get_folder_paths("loras") if os.path.exists(path)]
|
||||
if not self.loras_roots:
|
||||
raise ValueError("No valid loras folders found")
|
||||
print(f"Loras roots: {self.loras_roots}") # debug log
|
||||
|
||||
self.server = PromptServer.instance
|
||||
from server import PromptServer # type: ignore
|
||||
from .config import config
|
||||
from .routes.lora_routes import LoraRoutes
|
||||
from .routes.api_routes import ApiRoutes
|
||||
|
||||
class LoraManager:
|
||||
"""Main entry point for LoRA Manager plugin"""
|
||||
|
||||
@classmethod
|
||||
def add_routes(cls):
|
||||
instance = cls()
|
||||
"""Initialize and register all routes"""
|
||||
app = PromptServer.instance.app
|
||||
static_path = os.path.join(os.path.dirname(__file__), 'static')
|
||||
|
||||
# Generate multiple static paths based on the number of folders in instance.loras_roots
|
||||
for idx, root in enumerate(instance.loras_roots, start=1):
|
||||
# Create different static paths for each folder, like /loras_static/root1/preview
|
||||
|
||||
# Add static routes for each lora root
|
||||
for idx, root in enumerate(config.loras_roots, start=1):
|
||||
preview_path = f'/loras_static/root{idx}/preview'
|
||||
app.add_routes([web.static(preview_path, root)])
|
||||
|
||||
app.add_routes([
|
||||
web.get('/loras', instance.handle_loras_request),
|
||||
# web.static('/loras_static/previews', instance.loras_root),
|
||||
web.static('/loras_static', static_path),
|
||||
web.post('/api/delete_model', instance.delete_model),
|
||||
web.post('/api/fetch-civitai', instance.fetch_civitai),
|
||||
web.post('/api/replace_preview', instance.replace_preview),
|
||||
])
|
||||
|
||||
|
||||
async def scan_loras(self):
|
||||
loras = []
|
||||
for loras_root in self.loras_roots:
|
||||
for root, _, files in os.walk(loras_root):
|
||||
safetensors_files = [f for f in files if f.endswith('.safetensors')]
|
||||
|
||||
for filename in safetensors_files:
|
||||
file_path = os.path.join(root, filename).replace(os.sep, "/")
|
||||
|
||||
# Try to load existing metadata first
|
||||
metadata = await load_metadata(file_path)
|
||||
|
||||
if metadata is None:
|
||||
# Only get file info and extract metadata if no existing metadata
|
||||
metadata = await get_file_info(file_path)
|
||||
base_model_info = await extract_lora_metadata(file_path)
|
||||
metadata.base_model = base_model_info['base_model']
|
||||
await save_metadata(file_path, metadata)
|
||||
|
||||
# Convert to dict for API response
|
||||
lora_data = metadata.to_dict()
|
||||
# Get relative path and remove filename to get just the folder structure
|
||||
rel_path = os.path.relpath(file_path, loras_root)
|
||||
folder = os.path.dirname(rel_path)
|
||||
# Ensure forward slashes for consistency across platforms
|
||||
lora_data['folder'] = folder.replace(os.path.sep, '/')
|
||||
|
||||
loras.append(lora_data)
|
||||
|
||||
return loras
|
||||
|
||||
|
||||
def clean_description(self, desc):
|
||||
"""清理HTML格式的描述"""
|
||||
return desc.replace("<p>", "").replace("</p>", "\n").strip()
|
||||
|
||||
async def handle_loras_request(self, request):
|
||||
"""处理Loras请求并渲染模板"""
|
||||
try:
|
||||
scan_start = time.time()
|
||||
data = await self.scan_loras()
|
||||
print(f"Lora Manager: Scanned {len(data)} loras in {time.time()-scan_start:.2f}s")
|
||||
|
||||
# Format the data for the template
|
||||
formatted_loras = [self.format_lora(l) for l in data]
|
||||
folders = sorted(list(set(l['folder'] for l in data)))
|
||||
|
||||
context = {
|
||||
"loras": formatted_loras,
|
||||
"folders": folders
|
||||
}
|
||||
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(**context)
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error handling loras request: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc()) # Print full stack trace
|
||||
return web.Response(
|
||||
text="Error loading loras page",
|
||||
content_type='text/html',
|
||||
status=500
|
||||
)
|
||||
app.router.add_static(preview_path, root)
|
||||
|
||||
def filter_civitai_data(self, civitai_data):
|
||||
if not civitai_data:
|
||||
return {}
|
||||
|
||||
required_fields = [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images"
|
||||
]
|
||||
|
||||
return {k: civitai_data[k] for k in required_fields if k in civitai_data}
|
||||
|
||||
def format_lora(self, lora):
|
||||
"""格式化前端需要的数据结构"""
|
||||
return {
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": self.get_static_url_for_preview(lora["preview_url"]),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"modified": lora["modified"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"civitai": self.filter_civitai_data(lora.get("civitai", {}))
|
||||
}
|
||||
|
||||
def get_static_url_for_preview(self, preview_url):
|
||||
"""
|
||||
Determines which loras_root the preview_url belongs to and
|
||||
returns the corresponding static URL.
|
||||
"""
|
||||
for idx, root in enumerate(self.loras_roots, start=1):
|
||||
# Check if preview_url belongs to current root
|
||||
if preview_url.startswith(root):
|
||||
# Get relative path and generate static URL
|
||||
relative_path = os.path.relpath(preview_url, root)
|
||||
static_url = f'/loras_static/root{idx}/preview/{relative_path.replace(os.sep, "/")}'
|
||||
return static_url
|
||||
# Add static route for plugin assets
|
||||
app.router.add_static('/loras_static', config.static_path)
|
||||
|
||||
# If no matching root found, return empty string
|
||||
return ""
|
||||
|
||||
|
||||
async def delete_model(self, request):
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path') # 从请求中获取file_path信息
|
||||
if not file_path:
|
||||
return web.Response(text='Model full path is required', status=400)
|
||||
|
||||
# 构建完整的目录路径
|
||||
target_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
# List of file patterns to delete
|
||||
required_file = f"{file_name}.safetensors" # 主文件必须存在
|
||||
optional_files = [ # 这些文件可能不存在
|
||||
f"{file_name}.metadata.json",
|
||||
f"{file_name}.preview.png",
|
||||
f"{file_name}.preview.jpg",
|
||||
f"{file_name}.preview.jpeg",
|
||||
f"{file_name}.preview.webp",
|
||||
f"{file_name}.preview.mp4",
|
||||
f"{file_name}.png",
|
||||
f"{file_name}.jpg",
|
||||
f"{file_name}.jpeg",
|
||||
f"{file_name}.webp",
|
||||
f"{file_name}.mp4"
|
||||
]
|
||||
|
||||
deleted_files = []
|
||||
|
||||
# Try to delete the main safetensors file
|
||||
main_file_path = os.path.join(target_dir, required_file)
|
||||
if os.path.exists(main_file_path):
|
||||
try:
|
||||
os.remove(main_file_path)
|
||||
deleted_files.append(required_file)
|
||||
except Exception as e:
|
||||
print(f"Error deleting {main_file_path}: {str(e)}")
|
||||
return web.Response(text=f"Failed to delete main model file: {str(e)}", status=500)
|
||||
|
||||
# Only try to delete optional files if main file was deleted
|
||||
for pattern in optional_files:
|
||||
file_path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
os.remove(file_path)
|
||||
deleted_files.append(pattern)
|
||||
except Exception as e:
|
||||
print(f"Error deleting optional file {file_path}: {str(e)}")
|
||||
else:
|
||||
return web.Response(text=f"Model file {required_file} not found in {target_dir}", status=404)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'deleted_files': deleted_files
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def update_civitai_info(self, file_path: str, civitai_data: Dict, preview_url: Optional[str] = None):
|
||||
"""Update Civitai metadata and download preview image"""
|
||||
# Update metadata file
|
||||
await update_civitai_metadata(file_path, civitai_data)
|
||||
|
||||
# Download and save preview image if URL is provided
|
||||
if preview_url:
|
||||
preview_path = f"{os.path.splitext(file_path)[0]}.preview.png"
|
||||
try:
|
||||
# Add your image download logic here
|
||||
# Example:
|
||||
# await download_image(preview_url, preview_path)
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Error downloading preview image: {str(e)}")
|
||||
|
||||
async def fetch_civitai(self, request):
|
||||
try:
|
||||
data = await request.json()
|
||||
client = CivitaiClient()
|
||||
|
||||
try:
|
||||
metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
|
||||
|
||||
local_metadata = {}
|
||||
if os.path.exists(metadata_path):
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
local_metadata = json.load(f)
|
||||
|
||||
|
||||
if not local_metadata.get('from_civitai', True):
|
||||
return web.json_response(
|
||||
{"success": True, "Notice": "Not from CivitAI"},
|
||||
status=200
|
||||
)
|
||||
|
||||
# 1. 获取CivitAI元数据
|
||||
civitai_metadata = await client.get_model_by_hash(data["sha256"])
|
||||
if not civitai_metadata:
|
||||
local_metadata['from_civitai'] = False
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Not found on CivitAI"},
|
||||
status=404
|
||||
)
|
||||
|
||||
local_metadata['civitai']=civitai_metadata
|
||||
|
||||
# 更新模型名称(优先使用CivitAI名称)
|
||||
if 'model' in civitai_metadata:
|
||||
local_metadata['model_name'] = civitai_metadata['model'].get('name', local_metadata.get('model_name'))
|
||||
# update base model
|
||||
local_metadata['base_model'] = civitai_metadata.get('baseModel')
|
||||
|
||||
# 4. 下载预览图
|
||||
# Check if existing preview is valid and the file exists
|
||||
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
|
||||
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
|
||||
if first_preview:
|
||||
|
||||
preview_extension = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1] # Get the file extension
|
||||
preview_filename = os.path.splitext(os.path.basename(data['file_path']))[0] + '.preview' + preview_extension
|
||||
preview_path = os.path.join(os.path.dirname(data['file_path']), preview_filename)
|
||||
await client.download_preview_image(first_preview['url'], preview_path)
|
||||
# 存储相对路径,使用正斜杠格式
|
||||
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
|
||||
|
||||
# 5. 保存更新后的元数据
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
|
||||
|
||||
return web.json_response({
|
||||
"success": True
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in fetch_civitai: {str(e)}") # Debug log
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing request: {str(e)}") # Debug log
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": f"Request processing error: {str(e)}"
|
||||
}, status=400)
|
||||
|
||||
async def replace_preview(self, request):
|
||||
try:
|
||||
reader = await request.multipart()
|
||||
|
||||
# Get the preview_file field first
|
||||
file_field = await reader.next()
|
||||
if file_field.name != 'preview_file':
|
||||
raise ValueError("Expected 'preview_file' field first")
|
||||
preview_data = await file_field.read()
|
||||
|
||||
# Get the file model_path field
|
||||
name_field = await reader.next()
|
||||
if name_field.name != 'model_path':
|
||||
raise ValueError("Expected 'model_path' field second")
|
||||
model_path = (await name_field.read()).decode()
|
||||
|
||||
# Get the content type from the file field headers
|
||||
content_type = file_field.headers.get('Content-Type', '')
|
||||
|
||||
print(f"Received preview file: {model_path} ({content_type})") # Debug log
|
||||
|
||||
# Determine file extension based on content type
|
||||
if content_type.startswith('video/'):
|
||||
extension = '.preview.mp4'
|
||||
else:
|
||||
extension = '.preview.png'
|
||||
|
||||
# Construct the preview file path
|
||||
base_name = os.path.splitext(os.path.basename(model_path))[0] # Remove original extension
|
||||
preview_name = base_name + extension
|
||||
# Get the folder path from the model_path
|
||||
folder = os.path.dirname(model_path)
|
||||
preview_path = os.path.join(folder, preview_name).replace(os.sep, '/')
|
||||
|
||||
# Save the preview file
|
||||
with open(preview_path, 'wb') as f:
|
||||
f.write(preview_data)
|
||||
|
||||
# Update metadata if it exists
|
||||
metadata_path = os.path.join(folder, base_name + '.metadata.json')
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
# Update the preview_url to match the new file name
|
||||
metadata['preview_url'] = preview_path
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
print(f"Error updating metadata: {str(e)}")
|
||||
# Continue even if metadata update fails
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"preview_url": self.get_static_url_for_preview(preview_path)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error replacing preview: {str(e)}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
# 注册路由
|
||||
LorasEndpoint.add_routes()
|
||||
# Setup feature routes
|
||||
LoraRoutes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app)
|
||||
@@ -1,6 +1,3 @@
|
||||
from ..lora_manager import LorasEndpoint
|
||||
|
||||
|
||||
class LoRAGateway:
|
||||
"""
|
||||
LoRA Gateway Node
|
||||
@@ -15,10 +12,4 @@ class LoRAGateway:
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "register_services"
|
||||
CATEGORY = "LoRA Management"
|
||||
|
||||
@classmethod
|
||||
def register_services(cls):
|
||||
# Service registration logic
|
||||
LorasEndpoint.add_routes()
|
||||
return ()
|
||||
CATEGORY = "LoRA Management"
|
||||
228
routes/api_routes.py
Normal file
228
routes/api_routes.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from typing import Dict, List
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
from ..utils.file_utils import update_civitai_metadata, load_metadata
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ApiRoutes:
|
||||
"""API route handlers for LoRA management"""
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application):
|
||||
"""Register API routes"""
|
||||
routes = cls()
|
||||
app.router.add_post('/api/delete_model', routes.delete_model)
|
||||
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
|
||||
app.router.add_post('/api/replace_preview', routes.replace_preview)
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle model deletion request"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='Model path is required', status=400)
|
||||
|
||||
target_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
deleted_files = await self._delete_model_files(target_dir, file_name)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'deleted_files': deleted_files
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting model: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request"""
|
||||
client = CivitaiClient()
|
||||
try:
|
||||
data = await request.json()
|
||||
metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
|
||||
|
||||
# Check if model is from CivitAI
|
||||
local_metadata = await self._load_local_metadata(metadata_path)
|
||||
if not local_metadata.get('from_civitai', True):
|
||||
return web.json_response({"success": True, "notice": "Not from CivitAI"})
|
||||
|
||||
# Fetch and update metadata
|
||||
civitai_metadata = await client.get_model_by_hash(data["sha256"])
|
||||
if not civitai_metadata:
|
||||
return await self._handle_not_found_on_civitai(metadata_path, local_metadata)
|
||||
|
||||
await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, client)
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching from CivitAI: {e}", exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(e)}, status=500)
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement request"""
|
||||
try:
|
||||
reader = await request.multipart()
|
||||
preview_data, content_type = await self._read_preview_file(reader)
|
||||
model_path = await self._read_model_path(reader)
|
||||
|
||||
preview_path = await self._save_preview_file(model_path, preview_data, content_type)
|
||||
await self._update_preview_metadata(model_path, preview_path)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"preview_url": config.get_preview_static_url(preview_path)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error replacing preview: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
# Private helper methods
|
||||
async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]:
|
||||
"""Delete model and associated files"""
|
||||
patterns = [
|
||||
f"{file_name}.safetensors", # Required
|
||||
f"{file_name}.metadata.json",
|
||||
f"{file_name}.preview.png",
|
||||
f"{file_name}.preview.jpg",
|
||||
f"{file_name}.preview.jpeg",
|
||||
f"{file_name}.preview.webp",
|
||||
f"{file_name}.preview.mp4",
|
||||
f"{file_name}.png",
|
||||
f"{file_name}.jpg",
|
||||
f"{file_name}.jpeg",
|
||||
f"{file_name}.webp",
|
||||
f"{file_name}.mp4"
|
||||
]
|
||||
|
||||
deleted = []
|
||||
main_file = patterns[0]
|
||||
main_path = os.path.join(target_dir, main_file)
|
||||
|
||||
if not os.path.exists(main_path):
|
||||
raise web.HTTPNotFound(text=f"Model file not found: {main_file}")
|
||||
|
||||
# Delete main file first
|
||||
os.remove(main_path)
|
||||
deleted.append(main_file)
|
||||
|
||||
# Delete optional files
|
||||
for pattern in patterns[1:]:
|
||||
path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
deleted.append(pattern)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {pattern}: {e}")
|
||||
|
||||
return deleted
|
||||
|
||||
async def _read_preview_file(self, reader) -> tuple[bytes, str]:
|
||||
"""Read preview file and content type from multipart request"""
|
||||
field = await reader.next()
|
||||
if field.name != 'preview_file':
|
||||
raise ValueError("Expected 'preview_file' field")
|
||||
content_type = field.headers.get('Content-Type', 'image/png')
|
||||
return await field.read(), content_type
|
||||
|
||||
async def _read_model_path(self, reader) -> str:
|
||||
"""Read model path from multipart request"""
|
||||
field = await reader.next()
|
||||
if field.name != 'model_path':
|
||||
raise ValueError("Expected 'model_path' field")
|
||||
return (await field.read()).decode()
|
||||
|
||||
async def _save_preview_file(self, model_path: str, preview_data: bytes, content_type: str) -> str:
|
||||
"""Save preview file and return its path"""
|
||||
# Determine file extension based on content type
|
||||
if content_type.startswith('video/'):
|
||||
extension = '.preview.mp4'
|
||||
else:
|
||||
extension = '.preview.png'
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||
folder = os.path.dirname(model_path)
|
||||
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/')
|
||||
|
||||
with open(preview_path, 'wb') as f:
|
||||
f.write(preview_data)
|
||||
|
||||
return preview_path
|
||||
|
||||
async def _update_preview_metadata(self, model_path: str, preview_path: str):
|
||||
"""Update preview path in metadata"""
|
||||
metadata_path = os.path.splitext(model_path)[0] + '.metadata.json'
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# Update preview_url directly in the metadata dict
|
||||
metadata['preview_url'] = preview_path
|
||||
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating metadata: {e}")
|
||||
|
||||
async def _load_local_metadata(self, metadata_path: str) -> Dict:
|
||||
"""Load local metadata file"""
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading metadata from {metadata_path}: {e}")
|
||||
return {}
|
||||
|
||||
async def _handle_not_found_on_civitai(self, metadata_path: str, local_metadata: Dict) -> web.Response:
|
||||
"""Handle case when model is not found on CivitAI"""
|
||||
local_metadata['from_civitai'] = False
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Not found on CivitAI"},
|
||||
status=404
|
||||
)
|
||||
|
||||
async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
|
||||
civitai_metadata: Dict, client: CivitaiClient) -> None:
|
||||
"""Update local metadata with CivitAI data"""
|
||||
local_metadata['civitai'] = civitai_metadata
|
||||
|
||||
# Update model name if available
|
||||
if 'model' in civitai_metadata:
|
||||
local_metadata['model_name'] = civitai_metadata['model'].get('name',
|
||||
local_metadata.get('model_name'))
|
||||
|
||||
# Update base model
|
||||
local_metadata['base_model'] = civitai_metadata.get('baseModel')
|
||||
|
||||
# Update preview if needed
|
||||
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
|
||||
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
|
||||
if first_preview:
|
||||
preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
|
||||
# Fix: Get base name without .metadata.json
|
||||
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||
preview_filename = base_name + '.preview' + preview_ext
|
||||
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
|
||||
|
||||
if await client.download_preview_image(first_preview['url'], preview_path):
|
||||
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
|
||||
81
routes/lora_routes.py
Normal file
81
routes/lora_routes.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
from aiohttp import web
|
||||
import jinja2
|
||||
from typing import Dict, List
|
||||
import logging
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraRoutes:
|
||||
"""Route handlers for LoRA management endpoints"""
|
||||
|
||||
def __init__(self):
|
||||
self.scanner = LoraScanner()
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
def format_lora_data(self, lora: Dict) -> Dict:
|
||||
"""Format LoRA data for template rendering"""
|
||||
return {
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"modified": lora["modified"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
|
||||
}
|
||||
|
||||
def _filter_civitai_data(self, data: Dict) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras request"""
|
||||
try:
|
||||
# Scan for loras
|
||||
loras = await self.scanner.scan_all_loras()
|
||||
|
||||
# Format data for template
|
||||
formatted_loras = [self.format_lora_data(l) for l in loras]
|
||||
folders = sorted(list(set(l['folder'] for l in loras)))
|
||||
|
||||
# Render template
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
loras=formatted_loras,
|
||||
folders=folders
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling loras request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading loras page",
|
||||
status=500
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
routes = cls()
|
||||
app.router.add_get('/loras', routes.handle_loras_page)
|
||||
60
services/lora_scanner.py
Normal file
60
services/lora_scanner.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Dict
|
||||
from ..config import config
|
||||
from ..utils.file_utils import load_metadata, get_file_info, save_metadata
|
||||
from ..utils.lora_metadata import extract_lora_metadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraScanner:
|
||||
"""Service for scanning and managing LoRA files"""
|
||||
|
||||
async def scan_all_loras(self) -> List[Dict]:
|
||||
"""Scan all LoRA directories and return metadata"""
|
||||
all_loras = []
|
||||
|
||||
for loras_root in config.loras_roots:
|
||||
try:
|
||||
loras = await self._scan_directory(loras_root)
|
||||
all_loras.extend(loras)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory {loras_root}: {e}")
|
||||
|
||||
return all_loras
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a single directory for LoRA files"""
|
||||
loras = []
|
||||
|
||||
for root, _, files in os.walk(root_path):
|
||||
for filename in (f for f in files if f.endswith('.safetensors')):
|
||||
try:
|
||||
file_path = os.path.join(root, filename).replace(os.sep, "/")
|
||||
lora_data = await self._process_lora_file(file_path, root_path)
|
||||
if lora_data:
|
||||
loras.append(lora_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {filename}: {e}")
|
||||
|
||||
return loras
|
||||
|
||||
async def _process_lora_file(self, file_path: str, root_path: str) -> Dict:
|
||||
"""Process a single LoRA file and return its metadata"""
|
||||
# Try loading existing metadata
|
||||
metadata = await load_metadata(file_path)
|
||||
|
||||
if metadata is None:
|
||||
# Create new metadata if none exists
|
||||
metadata = await get_file_info(file_path)
|
||||
base_model_info = await extract_lora_metadata(file_path)
|
||||
metadata.base_model = base_model_info['base_model']
|
||||
await save_metadata(file_path, metadata)
|
||||
|
||||
# Convert to dict and add folder info
|
||||
lora_data = metadata.to_dict()
|
||||
rel_path = os.path.relpath(file_path, root_path)
|
||||
folder = os.path.dirname(rel_path)
|
||||
lora_data['folder'] = folder.replace(os.path.sep, '/')
|
||||
|
||||
return lora_data
|
||||
@@ -28,7 +28,7 @@ def _find_preview_file(base_name: str, dir_path: str) -> str:
|
||||
for pattern in preview_patterns:
|
||||
full_pattern = os.path.join(dir_path, pattern)
|
||||
if os.path.exists(full_pattern):
|
||||
return full_pattern.replace("\\", "/")
|
||||
return full_pattern.replace(os.sep, "/")
|
||||
return ""
|
||||
|
||||
async def get_file_info(file_path: str) -> LoraMetadata:
|
||||
|
||||
Reference in New Issue
Block a user