Reorganize python files

This commit is contained in:
Will Miao
2025-02-24 20:41:16 +08:00
parent f0cd77e7e5
commit 2d72044d66
20 changed files with 5 additions and 7 deletions

0
py/__init__.py Normal file
View File

121
py/config.py Normal file
View File

@@ -0,0 +1,121 @@
import os
import platform
import folder_paths # type: ignore
from typing import List
import logging
logger = logging.getLogger(__name__)
class Config:
    """Global configuration for LoRA Manager.

    Discovers the LoRA root folders from ComfyUI's settings and keeps two
    mappings used to serve preview files through static routes:

      * ``_path_mappings``: symlink target path -> symlink path
      * ``_route_mappings``: real filesystem path -> static URL route
    """

    def __init__(self):
        plugin_root = os.path.dirname(os.path.dirname(__file__))
        self.templates_path = os.path.join(plugin_root, 'templates')
        self.static_path = os.path.join(plugin_root, 'static')
        # Symlink mapping: target path -> link path
        self._path_mappings = {}
        # Static route mapping: target path -> route
        self._route_mappings = {}
        self.loras_roots = self._init_lora_paths()
        # Discover symlinked directories inside the roots at startup
        self._scan_symbolic_links()

    def _is_link(self, path: str) -> bool:
        """Return True if *path* is a symlink (or a Windows reparse point/junction)."""
        try:
            if os.path.islink(path):
                return True
            if platform.system() == 'Windows':
                try:
                    import ctypes
                    FILE_ATTRIBUTE_REPARSE_POINT = 0x400
                    attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path))
                    # bool(): `attrs & flag` yields an int, but the signature promises bool
                    return attrs != -1 and bool(attrs & FILE_ATTRIBUTE_REPARSE_POINT)
                except Exception as e:
                    logger.error(f"Error checking Windows reparse point: {e}")
            return False
        except Exception as e:
            logger.error(f"Error checking link status for {path}: {e}")
            return False

    def _scan_symbolic_links(self):
        """Scan every LoRA root directory for symbolic links."""
        for root in self.loras_roots:
            self._scan_directory_links(root)

    def _scan_directory_links(self, root: str):
        """Recursively scan *root*, recording symlinked directories and following them."""
        try:
            with os.scandir(root) as it:
                for entry in it:
                    if self._is_link(entry.path):
                        target_path = os.path.realpath(entry.path)
                        if os.path.isdir(target_path):
                            self.add_path_mapping(entry.path, target_path)
                            self._scan_directory_links(target_path)
                    elif entry.is_dir(follow_symlinks=False):
                        self._scan_directory_links(entry.path)
        except Exception as e:
            logger.error(f"Error scanning links in {root}: {e}")

    def add_path_mapping(self, link_path: str, target_path: str):
        """Record a symlink mapping.

        Args:
            link_path: path of the symbolic link
            target_path: real path the link resolves to
        """
        normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
        normalized_target = os.path.normpath(target_path).replace(os.sep, '/')
        # Keyed by target so map_path_to_link() can translate real paths back
        self._path_mappings[normalized_target] = normalized_link
        logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}")

    def add_route_mapping(self, path: str, route: str):
        """Record the static URL route that serves files under *path*."""
        normalized_path = os.path.normpath(path).replace(os.sep, '/')
        self._route_mappings[normalized_path] = route
        logger.info(f"Added route mapping: {normalized_path} -> {route}")

    def map_path_to_link(self, path: str) -> str:
        """Translate a real (target) path back to its symlink path, if one is mapped."""
        normalized_path = os.path.normpath(path).replace(os.sep, '/')
        for target_path, link_path in self._path_mappings.items():
            # Prefix match: the path lives somewhere under a mapped target
            if normalized_path.startswith(target_path):
                return normalized_path.replace(target_path, link_path, 1)
        return path

    def _init_lora_paths(self) -> List[str]:
        """Initialize and validate LoRA paths from ComfyUI settings.

        Raises:
            ValueError: if no configured loras folder exists on disk.
        """
        paths = list(set(path.replace(os.sep, "/")
                    for path in folder_paths.get_folder_paths("loras")
                    if os.path.exists(path)))
        # Use the module logger instead of print so output lands in the log
        logger.info("Found LoRA roots:\n - " + "\n - ".join(paths))
        if not paths:
            raise ValueError("No valid loras folders found in ComfyUI configuration")
        for path in paths:
            real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
            if real_path != path:
                # The configured root is itself a link: record target -> link
                self.add_path_mapping(path, real_path)
        return paths

    def get_preview_static_url(self, preview_path: str) -> str:
        """Convert a local preview file path to its static URL ('' if unmapped)."""
        if not preview_path:
            return ""
        real_path = os.path.realpath(preview_path).replace(os.sep, '/')
        for path, route in self._route_mappings.items():
            if real_path.startswith(path):
                relative_path = os.path.relpath(real_path, path)
                return f'{route}/{relative_path.replace(os.sep, "/")}'
        return ""
# Global config instance
# Created at import time; every module in the plugin shares this singleton.
config = Config()

108
py/lora_manager.py Normal file
View File

@@ -0,0 +1,108 @@
import asyncio
import os
from server import PromptServer # type: ignore
from .config import config
from .routes.lora_routes import LoraRoutes
from .routes.api_routes import ApiRoutes
from .services.lora_scanner import LoraScanner
from .services.file_monitor import LoraFileMonitor
from .services.lora_cache import LoraCache
import logging
logger = logging.getLogger(__name__)
class LoraManager:
    """Main entry point for LoRA Manager plugin.

    Wires static preview routes, page/API routes, file monitoring and
    background cache initialization into ComfyUI's aiohttp application.
    """

    @classmethod
    def add_routes(cls):
        """Initialize and register all routes on the PromptServer app."""
        app = PromptServer.instance.app
        added_targets = set()  # target paths that already have a static route

        # Add static routes for each lora root
        for idx, root in enumerate(config.loras_roots, start=1):
            preview_path = f'/loras_static/root{idx}/preview'
            # If the root itself is a symlink, serve the real target directory
            real_root = root
            if root in config._path_mappings.values():
                for target, link in config._path_mappings.items():
                    if link == root:
                        real_root = target
                        break
            app.router.add_static(preview_path, real_root)
            logger.info(f"Added static route {preview_path} -> {real_root}")
            # Record the route so preview paths can be turned into URLs
            config.add_route_mapping(real_root, preview_path)
            added_targets.add(real_root)

        # Add extra static routes for symlink targets not covered above
        link_idx = 1
        for target_path, link_path in config._path_mappings.items():
            if target_path not in added_targets:
                route_path = f'/loras_static/link_{link_idx}/preview'
                app.router.add_static(route_path, target_path)
                logger.info(f"Added static route for link target {route_path} -> {target_path}")
                config.add_route_mapping(target_path, route_path)
                added_targets.add(target_path)
                link_idx += 1

        # Add static route for plugin assets
        app.router.add_static('/loras_static', config.static_path)

        # Setup feature routes
        routes = LoraRoutes()

        # Setup file monitoring
        monitor = LoraFileMonitor(routes.scanner, config.loras_roots)
        monitor.start()

        routes.setup_routes(app)
        ApiRoutes.setup_routes(app, monitor)

        # Store monitor in app for cleanup
        app['lora_monitor'] = monitor

        # Schedule cache initialization using the application's startup handler
        app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner))

        # Add cleanup
        app.on_shutdown.append(cls._cleanup)
        app.on_shutdown.append(ApiRoutes.cleanup)

    @classmethod
    async def _schedule_cache_init(cls, scanner: LoraScanner):
        """Schedule cache initialization in the running event loop."""
        try:
            # Fire-and-forget, low-priority initialization task
            asyncio.create_task(cls._initialize_cache(scanner), name='lora_cache_init')
        except Exception as e:
            # logger instead of print: keep log output consistent with the rest of the module
            logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}")

    @classmethod
    async def _initialize_cache(cls, scanner: LoraScanner):
        """Initialize the scanner cache in the background."""
        try:
            # Install an empty placeholder cache so pages can render while scanning
            scanner._cache = LoraCache(
                raw_data=[],
                sorted_by_name=[],
                sorted_by_date=[],
                folders=[]
            )
            # Populate the cache (full refresh)
            await scanner.get_cached_data(force_refresh=True)
            logger.info("LoRA Manager: Cache initialization completed")
        except Exception as e:
            logger.error(f"LoRA Manager: Error initializing cache: {e}")

    @classmethod
    async def _cleanup(cls, app):
        """Stop the file monitor on application shutdown."""
        if 'lora_monitor' in app:
            app['lora_monitor'].stop()

75
py/nodes/lora_loader.py Normal file
View File

@@ -0,0 +1,75 @@
import re
from nodes import LoraLoader
from comfy.comfy_types import IO # type: ignore
from ..services.lora_scanner import LoraScanner
from ..config import config
import asyncio
import os
class LoraManagerLoader:
    """ComfyUI node that applies one or more LoRAs parsed from a text prompt.

    The text input uses ``<lora:name:strength>`` tokens; each token is resolved
    against the LoRA Manager scanner cache and applied with LoraLoader.
    """

    NAME = "Lora Loader (LoraManager)"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "clip": ("CLIP",),
                "text": (IO.STRING, {
                    "multiline": True,
                    "dynamicPrompts": True,
                    "tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation"
                }),
            },
        }

    RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
    RETURN_NAMES = ("MODEL", "CLIP", "loaded_loras", "trigger_words")
    FUNCTION = "load_loras"

    async def get_lora_info(self, lora_name):
        """Resolve a lora's root-relative path and trigger words from the scanner cache.

        Returns:
            (relative_path, trigger_words) on a cache hit, otherwise
            (lora_name, []) so the raw name is passed through to LoraLoader.
        """
        scanner = await LoraScanner.get_instance()
        cache = await scanner.get_cached_data()
        for item in cache.raw_data:
            if item.get('file_name') == lora_name:
                file_path = item.get('file_path')
                if file_path:
                    for root in config.loras_roots:
                        root = root.replace(os.sep, '/')
                        if file_path.startswith(root):
                            relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
                            # Trigger words come from the CivitAI metadata when available
                            civitai = item.get('civitai', {})
                            trigger_words = civitai.get('trainedWords', []) if civitai else []
                            return relative_path, trigger_words
        return lora_name, []  # Fallback if not found

    def load_loras(self, model, clip, text):
        """Loads multiple LoRAs based on the text input format."""
        # Accept an optional leading '-' so negative strengths
        # (e.g. <lora:name:-0.5>) are matched; LoraLoader supports them.
        lora_pattern = r'<lora:([^:]+):(-?[\d\.]+)>'
        lora_matches = re.finditer(lora_pattern, text)
        loaded_loras = []
        all_trigger_words = []
        for match in lora_matches:
            lora_name = match.group(1)
            strength = float(match.group(2))
            # Get lora path and trigger words
            lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
            # Apply the LoRA using the resolved path
            model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)
            loaded_loras.append(f"{lora_name}: {strength}")
            # Add trigger words to collection
            all_trigger_words.extend(trigger_words)
        loaded_loras_text = "\n".join(loaded_loras) if loaded_loras else "No LoRAs loaded"
        trigger_words_text = ", ".join(all_trigger_words) if all_trigger_words else ""
        return (model, clip, loaded_loras_text, trigger_words_text)

595
py/routes/api_routes.py Normal file
View File

@@ -0,0 +1,595 @@
import os
import json
import logging
from aiohttp import web
from typing import Dict, List
from ..services.file_monitor import LoraFileMonitor
from ..services.download_manager import DownloadManager
from ..services.civitai_client import CivitaiClient
from ..config import config
from ..services.lora_scanner import LoraScanner
from operator import itemgetter
from ..services.websocket_manager import ws_manager
from ..services.settings_manager import settings
import asyncio
logger = logging.getLogger(__name__)
class ApiRoutes:
    """API route handlers for LoRA management.

    One instance is created by :meth:`setup_routes` and shared by every
    registered handler; :meth:`cleanup` closes its CivitAI client on
    application shutdown.
    """

    def __init__(self, file_monitor: LoraFileMonitor):
        self.scanner = LoraScanner()
        self.civitai_client = CivitaiClient()
        self.download_manager = DownloadManager(file_monitor)
        # Serializes /api/download-lora requests (one download at a time)
        self._download_lock = asyncio.Lock()

    @classmethod
    def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor):
        """Register API routes."""
        routes = cls(monitor)
        # Keep a handle on the instance so cleanup() can close its resources.
        # (cleanup() checks cls._instance; without this assignment it was never
        # set, so the CivitAI client session was never closed.)
        cls._instance = routes
        app.router.add_post('/api/delete_model', routes.delete_model)
        app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
        app.router.add_post('/api/replace_preview', routes.replace_preview)
        app.router.add_get('/api/loras', routes.get_loras)
        app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
        app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
        app.router.add_get('/api/lora-roots', routes.get_lora_roots)
        app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
        app.router.add_post('/api/download-lora', routes.download_lora)
        app.router.add_post('/api/settings', routes.update_settings)
        app.router.add_post('/api/move_model', routes.move_model)
        app.router.add_post('/loras/api/save-metadata', routes.save_metadata)

    async def delete_model(self, request: web.Request) -> web.Response:
        """Handle model deletion request."""
        try:
            data = await request.json()
            file_path = data.get('file_path')
            if not file_path:
                return web.Response(text='Model path is required', status=400)
            target_dir = os.path.dirname(file_path)
            file_name = os.path.splitext(os.path.basename(file_path))[0]
            deleted_files = await self._delete_model_files(target_dir, file_name)
            return web.json_response({
                'success': True,
                'deleted_files': deleted_files
            })
        except Exception as e:
            logger.error(f"Error deleting model: {e}", exc_info=True)
            return web.Response(text=str(e), status=500)

    async def fetch_civitai(self, request: web.Request) -> web.Response:
        """Handle CivitAI metadata fetch request for a single model."""
        try:
            data = await request.json()
            metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
            # Check if model is from CivitAI
            local_metadata = await self._load_local_metadata(metadata_path)
            # Fetch and update metadata (lookup is by file hash)
            civitai_metadata = await self.civitai_client.get_model_by_hash(local_metadata["sha256"])
            if not civitai_metadata:
                return await self._handle_not_found_on_civitai(metadata_path, local_metadata)
            await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client)
            return web.json_response({"success": True})
        except Exception as e:
            logger.error(f"Error fetching from CivitAI: {e}", exc_info=True)
            return web.json_response({"success": False, "error": str(e)}, status=500)

    async def replace_preview(self, request: web.Request) -> web.Response:
        """Handle preview image replacement request (multipart upload)."""
        try:
            reader = await request.multipart()
            preview_data, content_type = await self._read_preview_file(reader)
            model_path = await self._read_model_path(reader)
            preview_path = await self._save_preview_file(model_path, preview_data, content_type)
            await self._update_preview_metadata(model_path, preview_path)
            # Update preview URL in scanner cache
            await self.scanner.update_preview_in_cache(model_path, preview_path)
            return web.json_response({
                "success": True,
                "preview_url": config.get_preview_static_url(preview_path)
            })
        except Exception as e:
            logger.error(f"Error replacing preview: {e}", exc_info=True)
            return web.Response(text=str(e), status=500)

    async def get_loras(self, request: web.Request) -> web.Response:
        """Handle paginated LoRA data request."""
        try:
            # Parse query parameters
            page = int(request.query.get('page', '1'))
            page_size = int(request.query.get('page_size', '20'))
            sort_by = request.query.get('sort_by', 'name')
            folder = request.query.get('folder')
            search = request.query.get('search', '').lower()
            fuzzy = request.query.get('fuzzy', 'false').lower() == 'true'
            recursive = request.query.get('recursive', 'false').lower() == 'true'
            # Validate parameters
            if page < 1 or page_size < 1 or page_size > 100:
                return web.json_response({
                    'error': 'Invalid pagination parameters'
                }, status=400)
            if sort_by not in ['date', 'name']:
                return web.json_response({
                    'error': 'Invalid sort parameter'
                }, status=400)
            # Get paginated data with search
            result = await self.scanner.get_paginated_data(
                page=page,
                page_size=page_size,
                sort_by=sort_by,
                folder=folder,
                search=search,
                fuzzy=fuzzy,
                recursive=recursive  # search subfolders as well
            )
            # Format the response data
            formatted_items = [
                self._format_lora_response(item)
                for item in result['items']
            ]
            # Get all available folders from cache
            cache = await self.scanner.get_cached_data()
            return web.json_response({
                'items': formatted_items,
                'total': result['total'],
                'page': result['page'],
                'page_size': result['page_size'],
                'total_pages': result['total_pages'],
                'folders': cache.folders
            })
        except Exception as e:
            logger.error(f"Error in get_loras: {str(e)}", exc_info=True)
            return web.json_response({
                'error': 'Internal server error'
            }, status=500)

    def _format_lora_response(self, lora: Dict) -> Dict:
        """Format LoRA data for API response."""
        return {
            "model_name": lora["model_name"],
            "file_name": lora["file_name"],
            "preview_url": config.get_preview_static_url(lora["preview_url"]),
            "base_model": lora["base_model"],
            "folder": lora["folder"],
            "sha256": lora["sha256"],
            "file_path": lora["file_path"].replace(os.sep, "/"),
            "modified": lora["modified"],
            "from_civitai": lora.get("from_civitai", True),
            "usage_tips": lora.get("usage_tips", ""),
            "notes": lora.get("notes", ""),
            "civitai": self._filter_civitai_data(lora.get("civitai", {}))
        }

    def _filter_civitai_data(self, data: Dict) -> Dict:
        """Filter relevant fields from CivitAI data."""
        if not data:
            return {}
        fields = [
            "id", "modelId", "name", "createdAt", "updatedAt",
            "publishedAt", "trainedWords", "baseModel", "description",
            "model", "images"
        ]
        return {k: data[k] for k in fields if k in data}

    # Private helper methods

    async def _delete_model_files(self, target_dir: str, file_name: str) -> List[str]:
        """Delete the model file and its associated metadata/preview files.

        Returns:
            List of deleted entries (full path for the main file, bare
            filenames for the companions — kept for API compatibility).
        """
        patterns = [
            f"{file_name}.safetensors",  # Required
            f"{file_name}.metadata.json",
            f"{file_name}.preview.png",
            f"{file_name}.preview.jpg",
            f"{file_name}.preview.jpeg",
            f"{file_name}.preview.webp",
            f"{file_name}.preview.mp4",
            f"{file_name}.png",
            f"{file_name}.jpg",
            f"{file_name}.jpeg",
            f"{file_name}.webp",
            f"{file_name}.mp4"
        ]
        deleted = []
        main_file = patterns[0]
        main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
        if os.path.exists(main_path):
            # Notify file monitor to ignore the delete event we are about to cause
            self.download_manager.file_monitor.handler.add_ignore_path(main_path, 0)
            os.remove(main_path)
            deleted.append(main_path)
        else:
            logger.warning(f"Model file not found: {main_file}")
        # Remove from cache regardless of on-disk state
        cache = await self.scanner.get_cached_data()
        cache.raw_data = [item for item in cache.raw_data if item['file_path'] != main_path]
        await cache.resort()
        # Delete optional files
        for pattern in patterns[1:]:
            path = os.path.join(target_dir, pattern)
            if os.path.exists(path):
                try:
                    os.remove(path)
                    deleted.append(pattern)
                except Exception as e:
                    logger.warning(f"Failed to delete {pattern}: {e}")
        return deleted

    async def _read_preview_file(self, reader) -> tuple[bytes, str]:
        """Read preview file bytes and content type from a multipart request."""
        field = await reader.next()
        if field.name != 'preview_file':
            raise ValueError("Expected 'preview_file' field")
        content_type = field.headers.get('Content-Type', 'image/png')
        return await field.read(), content_type

    async def _read_model_path(self, reader) -> str:
        """Read the model path field from a multipart request."""
        field = await reader.next()
        if field.name != 'model_path':
            raise ValueError("Expected 'model_path' field")
        return (await field.read()).decode()

    async def _save_preview_file(self, model_path: str, preview_data: bytes, content_type: str) -> str:
        """Save preview file next to the model and return its path."""
        # Determine file extension based on content type
        if content_type.startswith('video/'):
            extension = '.preview.mp4'
        else:
            extension = '.preview.png'
        base_name = os.path.splitext(os.path.basename(model_path))[0]
        folder = os.path.dirname(model_path)
        preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/')
        with open(preview_path, 'wb') as f:
            f.write(preview_data)
        return preview_path

    async def _update_preview_metadata(self, model_path: str, preview_path: str):
        """Persist the new preview path in the model's metadata file."""
        metadata_path = os.path.splitext(model_path)[0] + '.metadata.json'
        if os.path.exists(metadata_path):
            try:
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
                # Update preview_url directly in the metadata dict
                metadata['preview_url'] = preview_path
                with open(metadata_path, 'w', encoding='utf-8') as f:
                    json.dump(metadata, f, indent=2, ensure_ascii=False)
            except Exception as e:
                logger.error(f"Error updating metadata: {e}")

    async def _load_local_metadata(self, metadata_path: str) -> Dict:
        """Load a local metadata file; returns {} when missing or unreadable."""
        if os.path.exists(metadata_path):
            try:
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                logger.error(f"Error loading metadata from {metadata_path}: {e}")
        return {}

    async def _handle_not_found_on_civitai(self, metadata_path: str, local_metadata: Dict) -> web.Response:
        """Mark a model as not-from-CivitAI and return a 404 response."""
        local_metadata['from_civitai'] = False
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(local_metadata, f, indent=2, ensure_ascii=False)
        return web.json_response(
            {"success": False, "error": "Not found on CivitAI"},
            status=404
        )

    async def _update_model_metadata(self, metadata_path: str, local_metadata: Dict,
                                     civitai_metadata: Dict, client: CivitaiClient) -> None:
        """Update local metadata with CivitAI data and refresh the cache entry."""
        local_metadata['civitai'] = civitai_metadata
        # Update model name if available
        if 'model' in civitai_metadata:
            local_metadata['model_name'] = civitai_metadata['model'].get('name',
                local_metadata.get('model_name'))
        # Update base model
        local_metadata['base_model'] = civitai_metadata.get('baseModel')
        # Update preview if needed
        if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
            first_preview = next((img for img in civitai_metadata.get('images', [])), None)
            if first_preview:
                preview_ext = '.mp4' if first_preview['type'] == 'video' else os.path.splitext(first_preview['url'])[-1]
                base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
                preview_filename = base_name + preview_ext
                preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
                if await client.download_preview_image(first_preview['url'], preview_path):
                    local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
        # Save updated metadata
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(local_metadata, f, indent=2, ensure_ascii=False)
        await self.scanner.update_single_lora_cache(local_metadata['file_path'], local_metadata['file_path'], local_metadata)

    async def fetch_all_civitai(self, request: web.Request) -> web.Response:
        """Fetch CivitAI metadata for all loras, broadcasting progress over WebSocket."""
        try:
            cache = await self.scanner.get_cached_data()
            total = len(cache.raw_data)
            processed = 0
            success = 0
            needs_resort = False
            # Only loras with a hash, no CivitAI data yet, and not marked as local
            to_process = [
                lora for lora in cache.raw_data
                if lora.get('sha256') and not lora.get('civitai') and lora.get('from_civitai')
            ]
            total_to_process = len(to_process)
            # Broadcast initial progress
            await ws_manager.broadcast({
                'status': 'started',
                'total': total_to_process,
                'processed': 0,
                'success': 0
            })
            for lora in to_process:
                try:
                    original_name = lora.get('model_name')
                    if await self._fetch_and_update_single_lora(
                        sha256=lora['sha256'],
                        file_path=lora['file_path'],
                        lora=lora
                    ):
                        success += 1
                        # A renamed model invalidates the name-sorted cache
                        if original_name != lora.get('model_name'):
                            needs_resort = True
                    processed += 1
                    # Broadcast a progress update after each item
                    await ws_manager.broadcast({
                        'status': 'processing',
                        'total': total_to_process,
                        'processed': processed,
                        'success': success,
                        'current_name': lora.get('model_name', 'Unknown')
                    })
                except Exception as e:
                    logger.error(f"Error fetching CivitAI data for {lora['file_path']}: {e}")
            if needs_resort:
                await cache.resort(name_only=True)
            # Broadcast completion
            await ws_manager.broadcast({
                'status': 'completed',
                'total': total_to_process,
                'processed': processed,
                'success': success
            })
            return web.json_response({
                "success": True,
                "message": f"Successfully updated {success} of {processed} processed loras (total: {total})"
            })
        except Exception as e:
            # Broadcast the failure so clients stop waiting
            await ws_manager.broadcast({
                'status': 'error',
                'error': str(e)
            })
            logger.error(f"Error in fetch_all_civitai: {e}")
            return web.Response(text=str(e), status=500)

    async def _fetch_and_update_single_lora(self, sha256: str, file_path: str, lora: dict) -> bool:
        """Fetch and update metadata for a single lora without sorting.

        Args:
            sha256: SHA256 hash of the lora file
            file_path: Path to the lora file
            lora: The lora object in cache to update

        Returns:
            bool: True if successful, False otherwise
        """
        client = CivitaiClient()
        try:
            metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
            # Check if model is from CivitAI
            local_metadata = await self._load_local_metadata(metadata_path)
            # Fetch metadata
            civitai_metadata = await client.get_model_by_hash(sha256)
            if not civitai_metadata:
                # Mark as not from CivitAI if not found
                local_metadata['from_civitai'] = False
                lora['from_civitai'] = False
                with open(metadata_path, 'w', encoding='utf-8') as f:
                    json.dump(local_metadata, f, indent=2, ensure_ascii=False)
                return False
            # Update metadata
            await self._update_model_metadata(
                metadata_path,
                local_metadata,
                civitai_metadata,
                client
            )
            # Update cache object directly
            lora.update({
                'model_name': local_metadata.get('model_name'),
                'preview_url': local_metadata.get('preview_url'),
                'from_civitai': True,
                'civitai': civitai_metadata
            })
            return True
        except Exception as e:
            logger.error(f"Error fetching CivitAI data: {e}")
            return False
        finally:
            await client.close()

    async def get_lora_roots(self, request: web.Request) -> web.Response:
        """Get all configured LoRA root directories."""
        return web.json_response({
            'roots': config.loras_roots
        })

    async def get_civitai_versions(self, request: web.Request) -> web.Response:
        """Get available versions for a Civitai model."""
        try:
            model_id = request.match_info['model_id']
            versions = await self.civitai_client.get_model_versions(model_id)
            if not versions:
                return web.Response(status=404, text="Model not found")
            return web.json_response(versions)
        except Exception as e:
            logger.error(f"Error fetching model versions: {e}")
            return web.Response(status=500, text=str(e))

    async def download_lora(self, request: web.Request) -> web.Response:
        """Download a lora from CivitAI; downloads are serialized by a lock."""
        async with self._download_lock:
            try:
                data = await request.json()

                # Forward byte-level progress to connected WebSocket clients
                async def progress_callback(progress):
                    await ws_manager.broadcast({
                        'status': 'progress',
                        'progress': progress
                    })

                result = await self.download_manager.download_from_civitai(
                    download_url=data.get('download_url'),
                    save_dir=data.get('lora_root'),
                    relative_path=data.get('relative_path'),
                    progress_callback=progress_callback
                )
                if not result.get('success', False):
                    return web.Response(status=500, text=result.get('error', 'Unknown error'))
                return web.json_response(result)
            except Exception as e:
                logger.error(f"Error downloading LoRA: {e}")
                return web.Response(status=500, text=str(e))

    async def update_settings(self, request: web.Request) -> web.Response:
        """Update application settings."""
        try:
            data = await request.json()
            # Validate and update settings
            if 'civitai_api_key' in data:
                settings.set('civitai_api_key', data['civitai_api_key'])
            return web.json_response({'success': True})
        except Exception as e:
            # exc_info=True to capture the full stack trace
            logger.error(f"Error updating settings: {e}", exc_info=True)
            return web.Response(status=500, text=str(e))

    async def move_model(self, request: web.Request) -> web.Response:
        """Handle model move request."""
        try:
            data = await request.json()
            file_path = data.get('file_path')
            target_path = data.get('target_path')
            if not file_path or not target_path:
                return web.Response(text='File path and target path are required', status=400)
            # Call scanner to handle the move operation
            success = await self.scanner.move_model(file_path, target_path)
            if success:
                return web.json_response({'success': True})
            else:
                return web.Response(text='Failed to move model', status=500)
        except Exception as e:
            logger.error(f"Error moving model: {e}", exc_info=True)
            return web.Response(text=str(e), status=500)

    @classmethod
    async def cleanup(cls):
        """Close shared resources on application shutdown."""
        if hasattr(cls, '_instance'):
            await cls._instance.civitai_client.close()

    async def save_metadata(self, request: web.Request) -> web.Response:
        """Handle saving metadata updates."""
        try:
            data = await request.json()
            file_path = data.get('file_path')
            if not file_path:
                return web.Response(text='File path is required', status=400)
            # Remove file path from data to avoid saving it
            metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
            # Get metadata file path
            metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
            # Load existing metadata
            if os.path.exists(metadata_path):
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
            else:
                metadata = {}
            # Update metadata with new values
            metadata.update(metadata_updates)
            # Save updated metadata
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2, ensure_ascii=False)
            # Update cache
            await self.scanner.update_single_lora_cache(file_path, file_path, metadata)
            return web.json_response({'success': True})
        except Exception as e:
            logger.error(f"Error saving metadata: {e}", exc_info=True)
            return web.Response(text=str(e), status=500)

91
py/routes/lora_routes.py Normal file
View File

@@ -0,0 +1,91 @@
import os
from aiohttp import web
import jinja2
from typing import Dict, List
import logging
from ..services.lora_scanner import LoraScanner
from ..config import config
from ..services.settings_manager import settings # Add this import
logger = logging.getLogger(__name__)
# Raise asyncio's log level so only CRITICAL messages get through —
# NOTE(review): this silences asyncio logging process-wide, presumably to hide
# noisy connection-reset tracebacks; confirm this is intended globally.
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
class LoraRoutes:
    """Route handlers for LoRA management endpoints (the /loras page)."""

    def __init__(self):
        self.scanner = LoraScanner()
        self.template_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(config.templates_path),
            autoescape=True
        )

    def format_lora_data(self, lora: Dict) -> Dict:
        """Format LoRA data for template rendering."""
        return {
            "model_name": lora["model_name"],
            "file_name": lora["file_name"],
            "preview_url": config.get_preview_static_url(lora["preview_url"]),
            "base_model": lora["base_model"],
            "folder": lora["folder"],
            "sha256": lora["sha256"],
            "file_path": lora["file_path"].replace(os.sep, "/"),
            "modified": lora["modified"],
            "from_civitai": lora.get("from_civitai", True),
            "civitai": self._filter_civitai_data(lora.get("civitai", {}))
        }

    def _filter_civitai_data(self, data: Dict) -> Dict:
        """Filter relevant fields from CivitAI data."""
        if not data:
            return {}
        fields = [
            "id", "modelId", "name", "createdAt", "updatedAt",
            "publishedAt", "trainedWords", "baseModel", "description",
            "model", "images"
        ]
        return {k: data[k] for k in fields if k in data}

    async def handle_loras_page(self, request: web.Request) -> web.Response:
        """Handle GET /loras request."""
        try:
            # Don't wait for cache construction: when the scanner cache is
            # absent or still empty, render the page in "initializing" mode.
            # (The hasattr guard that used to sit here was redundant — the
            # first clause already dereferences self.scanner._cache.)
            is_initializing = (
                self.scanner._cache is None or
                len(self.scanner._cache.raw_data) == 0
            )
            folders = []
            if not is_initializing:
                cache = await self.scanner.get_cached_data()
                folders = cache.folders
            template = self.template_env.get_template('loras.html')
            rendered = template.render(
                folders=folders,
                is_initializing=is_initializing,
                settings=settings  # Pass settings to template
            )
            return web.Response(
                text=rendered,
                content_type='text/html'
            )
        except Exception as e:
            logger.error(f"Error handling loras request: {e}", exc_info=True)
            return web.Response(
                text="Error loading loras page",
                status=500
            )

    def setup_routes(self, app: web.Application):
        """Register routes with the application."""
        app.router.add_get('/loras', self.handle_loras_page)

1
py/services/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Empty file to mark directory as Python package

View File

@@ -0,0 +1,171 @@
from datetime import datetime
import aiohttp
import os
import json
import logging
from email.parser import Parser
from typing import Optional, Dict, Tuple
from urllib.parse import unquote
from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
class CivitaiClient:
def __init__(self):
self.base_url = "https://civitai.com/api/v1"
self.headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
}
self._session = None
@property
async def session(self) -> aiohttp.ClientSession:
"""Lazy initialize the session"""
if self._session is None:
connector = aiohttp.TCPConnector(ssl=True)
trust_env = True # 允许使用系统环境变量中的代理设置
self._session = aiohttp.ClientSession(connector=connector, trust_env=trust_env)
return self._session
    def _parse_content_disposition(self, header: Optional[str]) -> Optional[str]:
        """Parse the filename from a Content-Disposition header.

        Returns None when the header is missing or carries no filename.
        (Annotation fixed: the original declared ``-> str`` but returns None
        on both the empty-header and no-filename paths.)
        """
        if not header:
            return None
        # Handle quoted filenames, e.g. filename="my model.safetensors"
        if 'filename="' in header:
            start = header.index('filename="') + 10
            end = header.index('"', start)
            return unquote(header[start:end])
        # Fallback to original parsing via the email header parser
        disposition = Parser().parsestr(f'Content-Disposition: {header}')
        filename = disposition.get_param('filename')
        if filename:
            return unquote(filename)
        return None
def _get_request_headers(self) -> dict:
"""Get request headers with optional API key"""
headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
'Content-Type': 'application/json'
}
from .settings_manager import settings
api_key = settings.get('civitai_api_key')
if (api_key):
headers['Authorization'] = f'Bearer {api_key}'
return headers
async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
"""Download file with content-disposition support and progress tracking
Args:
url: Download URL
save_dir: Directory to save the file
default_filename: Fallback filename if none provided in headers
progress_callback: Optional async callback function for progress updates (0-100)
Returns:
Tuple[bool, str]: (success, save_path or error message)
"""
session = await self.session
try:
headers = self._get_request_headers()
async with session.get(url, headers=headers, allow_redirects=True) as response:
if response.status != 200:
return False, f"Download failed with status {response.status}"
# Get filename from content-disposition header
content_disposition = response.headers.get('Content-Disposition')
filename = self._parse_content_disposition(content_disposition)
if not filename:
filename = default_filename
save_path = os.path.join(save_dir, filename)
# Get total file size for progress calculation
total_size = int(response.headers.get('content-length', 0))
current_size = 0
# Stream download to file with progress updates
with open(save_path, 'wb') as f:
async for chunk in response.content.iter_chunked(8192):
if chunk:
f.write(chunk)
current_size += len(chunk)
if progress_callback and total_size:
progress = (current_size / total_size) * 100
await progress_callback(progress)
# Ensure 100% progress is reported
if progress_callback:
await progress_callback(100)
return True, save_path
except Exception as e:
logger.error(f"Download error: {e}")
return False, str(e)
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
try:
session = await self.session
async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
logger.error(f"API Error: {str(e)}")
return None
async def download_preview_image(self, image_url: str, save_path: str):
try:
session = await self.session
async with session.get(image_url) as response:
if response.status == 200:
content = await response.read()
with open(save_path, 'wb') as f:
f.write(content)
return True
return False
except Exception as e:
print(f"Download Error: {str(e)}")
return False
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Fetch all versions of a model"""
try:
session = await self.session
url = f"{self.base_url}/models/{model_id}"
async with session.get(url, headers=self.headers) as response:
if response.status == 200:
data = await response.json()
return data.get('modelVersions', [])
return None
except Exception as e:
logger.error(f"Error fetching model versions: {e}")
return None
async def get_model_version_info(self, version_id: str) -> Optional[Dict]:
"""Fetch model version metadata from Civitai"""
try:
session = await self.session
url = f"{self.base_url}/model-versions/{version_id}"
headers = self._get_request_headers()
async with session.get(url, headers=headers) as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
logger.error(f"Error fetching model version info: {e}")
return None
async def close(self):
"""Close the session if it exists"""
if self._session is not None:
await self._session.close()
self._session = None

View File

@@ -0,0 +1,154 @@
import logging
import os
import json
from typing import Optional, Dict
from .civitai_client import CivitaiClient
from .file_monitor import LoraFileMonitor
from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
class DownloadManager:
    """Coordinates downloading LoRA files (plus metadata and previews) from Civitai."""

    def __init__(self, file_monitor: Optional[LoraFileMonitor] = None):
        self.civitai_client = CivitaiClient()
        # Optional: used to suppress filesystem events caused by our own writes.
        self.file_monitor = file_monitor

    async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '',
                                  progress_callback=None) -> Dict:
        """Download a model version into save_dir[/relative_path].

        Returns {'success': True} or {'success': False, 'error': <message>}.
        """
        try:
            # Update save directory with relative path if provided
            if relative_path:
                save_dir = os.path.join(save_dir, relative_path)
            os.makedirs(save_dir, exist_ok=True)

            # 1. Fetch version metadata (the URL ends with the version id)
            version_id = download_url.split('/')[-1]
            version_info = await self.civitai_client.get_model_version_info(version_id)
            if not version_info:
                return {'success': False, 'error': 'Failed to fetch model metadata'}

            # Report initial progress
            if progress_callback:
                await progress_callback(0)

            # 2. Locate the primary file entry
            file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
            if not file_info:
                return {'success': False, 'error': 'No primary file found in metadata'}

            # 3. Prepare the download target
            file_name = file_info['name']
            save_path = os.path.join(save_dir, file_name)
            file_size = file_info.get('sizeKB', 0) * 1024

            # 4. Tell the file monitor to ignore our own write events.
            # BUGFIX: file_monitor is Optional — guard against None.
            if self.file_monitor:
                self.file_monitor.handler.add_ignore_path(
                    save_path.replace(os.sep, '/'),
                    file_size
                )

            # 5. Build initial metadata from the Civitai payload
            metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)

            # 6. Run the download pipeline
            result = await self._execute_download(
                download_url=download_url,
                save_dir=save_dir,
                metadata=metadata,
                version_info=version_info,
                relative_path=relative_path,
                progress_callback=progress_callback
            )
            return result
        except Exception as e:
            logger.error(f"Error in download_from_civitai: {e}", exc_info=True)
            return {'success': False, 'error': str(e)}

    async def _execute_download(self, download_url: str, save_dir: str,
                              metadata: LoraMetadata, version_info: Dict,
                              relative_path: str, progress_callback=None) -> Dict:
        """Execute the actual download process including preview images and model files"""
        # BUGFIX: bind these before the try so the except-cleanup below can
        # never hit an unbound local.
        save_path = metadata.file_path
        metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
        try:
            # Download preview image if available
            images = version_info.get('images', [])
            if images:
                if progress_callback:
                    await progress_callback(5)  # 5% progress for starting preview download

                preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
                preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
                if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
                    metadata.preview_url = preview_path.replace(os.sep, '/')
                    with open(metadata_path, 'w', encoding='utf-8') as f:
                        json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)

                if progress_callback:
                    await progress_callback(10)  # 10% progress after preview download

            # Download model file; its 0-100 progress is scaled to 10-100 overall
            success, result = await self.civitai_client._download_file(
                download_url,
                save_dir,
                os.path.basename(save_path),
                progress_callback=lambda p: self._handle_download_progress(p, progress_callback)
            )

            if not success:
                # Clean up partial files on failure
                for path in [save_path, metadata_path, metadata.preview_url]:
                    if path and os.path.exists(path):
                        os.remove(path)
                return {'success': False, 'error': result}

            # Refresh size/modified now that the file is on disk, then persist
            metadata.update_file_info(save_path)
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)

            # Update the in-memory lora cache.
            # BUGFIX: previously raised AttributeError (aborting an otherwise
            # successful download) when file_monitor was None.
            if self.file_monitor:
                cache = await self.file_monitor.scanner.get_cached_data()
                metadata_dict = metadata.to_dict()
                metadata_dict['folder'] = relative_path
                cache.raw_data.append(metadata_dict)
                await cache.resort()
                all_folders = set(cache.folders)
                all_folders.add(relative_path)
                cache.folders = sorted(list(all_folders), key=lambda x: x.lower())

            # Report 100% completion
            if progress_callback:
                await progress_callback(100)

            return {
                'success': True
            }
        except Exception as e:
            logger.error(f"Error in _execute_download: {e}", exc_info=True)
            # Clean up partial downloads
            for path in [save_path, metadata_path]:
                if path and os.path.exists(path):
                    os.remove(path)
            return {'success': False, 'error': str(e)}

    async def _handle_download_progress(self, file_progress: float, progress_callback):
        """Convert file download progress to overall progress

        Args:
            file_progress: Progress of file download (0-100)
            progress_callback: Callback function for progress updates
        """
        if progress_callback:
            # Scale file progress to 10-100 range (after preview download)
            overall_progress = 10 + (file_progress * 0.9)  # 90% of progress for file download
            await progress_callback(round(overall_progress))

184
py/services/file_monitor.py Normal file
View File

@@ -0,0 +1,184 @@
from operator import itemgetter
import os
import logging
import asyncio
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent
from typing import List
from threading import Lock
from .lora_scanner import LoraScanner
from ..config import config
logger = logging.getLogger(__name__)
class LoraFileHandler(FileSystemEventHandler):
    """Handler for LoRA file system events.

    Watchdog invokes the on_* callbacks on its own observer thread, so all
    interaction with asyncio goes through the stored event loop via
    call_soon_threadsafe.
    """

    def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop):
        self.scanner = scanner
        self.loop = loop  # event loop that runs the async cache updates
        self.pending_changes = set()  # accumulated (action, path) changes
        self.lock = Lock()  # thread-safety for pending_changes
        self.update_task = None  # debounced async update task
        self._ignore_paths = set()  # paths whose events we deliberately skip
        self._min_ignore_timeout = 5  # minimum timeout in seconds (currently unused)
        self._download_speed = 1024 * 1024  # assume 1MB/s as base speed (currently unused)

    def _should_ignore(self, path: str) -> bool:
        """Check if path should be ignored"""
        real_path = os.path.realpath(path)  # Resolve any symbolic links
        return real_path.replace(os.sep, '/') in self._ignore_paths

    def add_ignore_path(self, path: str, file_size: int = 0):
        """Add path to ignore list with a short expiry.

        A short timeout is sufficient to suppress the CREATE event produced
        by our own downloads/moves. ``file_size`` is accepted for future
        size-based timeouts but currently unused.
        """
        real_path = os.path.realpath(path)  # Resolve any symbolic links
        key = real_path.replace(os.sep, '/')
        self._ignore_paths.add(key)
        timeout = 5
        # BUGFIX: asyncio.get_event_loop() fails (or grabs the wrong loop)
        # when called from non-async threads such as watchdog's observer
        # thread; schedule the expiry on the stored loop thread-safely.
        self.loop.call_soon_threadsafe(
            self.loop.call_later,
            timeout,
            self._ignore_paths.discard,
            key
        )

    def on_created(self, event):
        # Only react to real .safetensors files we did not write ourselves.
        if event.is_directory or not event.src_path.endswith('.safetensors'):
            return
        if self._should_ignore(event.src_path):
            return
        logger.info(f"LoRA file created: {event.src_path}")
        self._schedule_update('add', event.src_path)

    def on_deleted(self, event):
        if event.is_directory or not event.src_path.endswith('.safetensors'):
            return
        if self._should_ignore(event.src_path):
            return
        logger.info(f"LoRA file deleted: {event.src_path}")
        self._schedule_update('remove', event.src_path)

    def _schedule_update(self, action: str, file_path: str):  # file_path is a real path
        """Queue a cache update for a filesystem change (thread-safe)."""
        with self.lock:
            # Map the real path back to its link path (if any) via config.
            mapped_path = config.map_path_to_link(file_path)
            normalized_path = mapped_path.replace(os.sep, '/')
            self.pending_changes.add((action, normalized_path))
            self.loop.call_soon_threadsafe(self._create_update_task)

    def _create_update_task(self):
        """Create update task in the event loop"""
        if self.update_task is None or self.update_task.done():
            self.update_task = asyncio.create_task(self._process_changes())

    async def _process_changes(self, delay: float = 2.0):
        """Process pending changes with debouncing"""
        await asyncio.sleep(delay)

        try:
            with self.lock:
                changes = self.pending_changes.copy()
                self.pending_changes.clear()

            if not changes:
                return

            logger.info(f"Processing {len(changes)} file changes")

            cache = await self.scanner.get_cached_data()  # completes any pending initialization
            needs_resort = False
            new_folders = set()  # folders introduced by added files

            for action, file_path in changes:
                try:
                    if action == 'add':
                        # Scan the newly created file and add it to the cache
                        lora_data = await self.scanner.scan_single_lora(file_path)
                        if lora_data:
                            cache.raw_data.append(lora_data)
                            new_folders.add(lora_data['folder'])
                            needs_resort = True
                    elif action == 'remove':
                        # Drop the deleted file from the cache
                        logger.info(f"Removing {file_path} from cache")
                        cache.raw_data = [
                            item for item in cache.raw_data
                            if item['file_path'] != file_path
                        ]
                        needs_resort = True
                except Exception as e:
                    logger.error(f"Error processing {action} for {file_path}: {e}")

            if needs_resort:
                await cache.resort()
                # Refresh the folder list, including any new folders
                all_folders = set(cache.folders) | new_folders
                cache.folders = sorted(list(all_folders), key=lambda x: x.lower())

        except Exception as e:
            logger.error(f"Error in process_changes: {e}")
class LoraFileMonitor:
    """Monitor for LoRA file changes"""

    def __init__(self, scanner: LoraScanner, roots: List[str]):
        self.scanner = scanner
        scanner.set_file_monitor(self)
        self.observer = Observer()
        self.loop = asyncio.get_event_loop()
        self.handler = LoraFileHandler(scanner, self.loop)

        # Paths to watch: the lora roots plus every mapped link target
        # already discovered by config.
        self.monitor_paths = set()
        for root in roots:
            self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
        for target_path in config._path_mappings.keys():
            self.monitor_paths.add(target_path)

        # Paths already registered with the observer. observer.watches is
        # keyed by ObservedWatch objects, not strings, so we track our own.
        self._scheduled_paths = set()

    def _schedule_watch(self, path: str):
        """Register one path with the observer, skipping duplicates."""
        if path in self._scheduled_paths:
            return
        try:
            self.observer.schedule(self.handler, path, recursive=True)
            self._scheduled_paths.add(path)
            logger.info(f"Started monitoring: {path}")
        except Exception as e:
            logger.error(f"Error monitoring {path}: {e}")

    def start(self):
        """Start monitoring"""
        # BUGFIX: monitor_paths only ever holds strings; the former
        # isinstance(path_info, tuple) branch was dead code.
        for path in self.monitor_paths:
            self._schedule_watch(path)
        self.observer.start()

    def stop(self):
        """Stop monitoring"""
        self.observer.stop()
        self.observer.join()

    def rescan_links(self):
        """Re-scan link targets (call after new links are added).

        BUGFIX: previously called a nonexistent self._add_link_targets and
        diffed against observer.watches keys (which are ObservedWatch
        objects, not path strings), so it always raised. It now refreshes
        the targets from config and schedules any path not yet watched.
        """
        for target_path in config._path_mappings.keys():
            self.monitor_paths.add(target_path)
        for path in self.monitor_paths:
            self._schedule_watch(path)

64
py/services/lora_cache.py Normal file
View File

@@ -0,0 +1,64 @@
import asyncio
from typing import List, Dict
from dataclasses import dataclass
from operator import itemgetter
@dataclass
class LoraCache:
    """Cache structure for LoRA data"""
    raw_data: List[Dict]
    sorted_by_name: List[Dict]
    sorted_by_date: List[Dict]
    folders: List[str]

    def __post_init__(self):
        # Serializes re-sorts and updates coming from concurrent coroutines.
        self._lock = asyncio.Lock()

    async def resort(self, name_only: bool = False):
        """Rebuild the sorted views (and the folder list) from raw_data."""
        async with self._lock:
            self.sorted_by_name = sorted(
                self.raw_data,
                key=lambda entry: entry['model_name'].lower()  # case-insensitive
            )
            if not name_only:
                self.sorted_by_date = sorted(
                    self.raw_data,
                    key=itemgetter('modified'),
                    reverse=True
                )
            # Rebuild the unique, case-insensitively sorted folder list.
            unique_folders = {entry['folder'] for entry in self.raw_data}
            self.folders = sorted(unique_folders, key=lambda name: name.lower())

    async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
        """Update preview_url for a specific lora in all cached data

        Args:
            file_path: The file path of the lora to update
            preview_url: The new preview URL

        Returns:
            bool: True if the update was successful, False if the lora wasn't found
        """
        async with self._lock:
            match = next(
                (entry for entry in self.raw_data if entry['file_path'] == file_path),
                None
            )
            if match is None:
                return False
            match['preview_url'] = preview_url
            # Patch the sorted views as well (they normally reference the
            # same dict objects, but patch defensively in case they diverge).
            for view in (self.sorted_by_name, self.sorted_by_date):
                for entry in view:
                    if entry['file_path'] == file_path:
                        entry['preview_url'] = preview_url
                        break
            return True

439
py/services/lora_scanner.py Normal file
View File

@@ -0,0 +1,439 @@
import json
import os
import logging
import asyncio
import shutil
from typing import List, Dict, Optional
from dataclasses import dataclass
from operator import itemgetter
from ..config import config
from ..utils.file_utils import load_metadata, get_file_info
from .lora_cache import LoraCache
from difflib import SequenceMatcher
logger = logging.getLogger(__name__)
class LoraScanner:
    """Service for scanning and managing LoRA files"""

    _instance = None
    _lock = asyncio.Lock()

    def __new__(cls):
        # Process-wide singleton: all callers share one cache.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard so the singleton state is only initialized once, even though
        # __init__ runs on every LoraScanner() call.
        if not hasattr(self, '_initialized'):
            self._cache: Optional[LoraCache] = None
            self._initialization_lock = asyncio.Lock()
            self._initialization_task: Optional[asyncio.Task] = None
            self._initialized = True
            self.file_monitor = None  # set later via set_file_monitor

    def set_file_monitor(self, monitor):
        """Set file monitor instance"""
        self.file_monitor = monitor

    @classmethod
    async def get_instance(cls):
        """Get singleton instance with async support"""
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
        """Get cached LoRA data, refresh if needed.

        When the cache is not yet initialized and no refresh was requested,
        an empty cache is returned instead of blocking the caller.
        """
        async with self._initialization_lock:
            if self._cache is None and not force_refresh:
                return LoraCache(
                    raw_data=[],
                    sorted_by_name=[],
                    sorted_by_date=[],
                    folders=[]
                )

            # If an initialization is already running, wait for it to finish.
            if self._initialization_task and not self._initialization_task.done():
                try:
                    await self._initialization_task
                except Exception as e:
                    logger.error(f"Cache initialization failed: {e}")
                    self._initialization_task = None

            if (self._cache is None or force_refresh):
                # Start a fresh initialization task if none is in flight.
                if not self._initialization_task or self._initialization_task.done():
                    self._initialization_task = asyncio.create_task(self._initialize_cache())
                try:
                    await self._initialization_task
                except Exception as e:
                    logger.error(f"Cache initialization failed: {e}")
                    # Keep serving a pre-existing (stale) cache if we have one.
                    if self._cache is None:
                        raise

            return self._cache

    async def _initialize_cache(self) -> None:
        """Initialize or refresh the cache"""
        raw_data = await self.scan_all_loras()
        self._cache = LoraCache(
            raw_data=raw_data,
            sorted_by_name=[],
            sorted_by_date=[],
            folders=[]
        )
        # Build the sorted views and folder list.
        await self._cache.resort()

    def fuzzy_match(self, text: str, pattern: str, threshold: float = 0.7) -> bool:
        """
        Check if text matches pattern using fuzzy matching.
        Returns True if similarity ratio is above threshold.
        """
        if not pattern or not text:
            return False

        # Case-insensitive comparison throughout.
        text = text.lower()
        pattern = pattern.lower()

        # Each search word must match somewhere in the text.
        search_words = pattern.split()
        for word in search_words:
            # Cheap substring test first
            if word in text:
                continue

            # Fall back to fuzzy matching against each text word
            found_match = False
            for text_part in text.split():
                ratio = SequenceMatcher(None, text_part, word).ratio()
                if ratio >= threshold:
                    found_match = True
                    break
            if not found_match:
                return False

        # All words found either as substrings or fuzzy matches
        return True

    async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
                                 folder: str = None, search: str = None, fuzzy: bool = False,
                                 recursive: bool = False):
        """Get paginated and filtered lora data

        Args:
            page: Current page number (1-based)
            page_size: Number of items per page
            sort_by: Sort method ('name' or 'date')
            folder: Filter by folder path
            search: Search term
            fuzzy: Use fuzzy matching for search
            recursive: Include subfolders when folder filter is applied
        """
        cache = await self.get_cached_data()

        # Start from the requested sorted view
        filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name

        # Apply the folder filter
        if folder is not None:
            if recursive:
                # Recursive: match the folder itself and everything below it
                filtered_data = [
                    item for item in filtered_data
                    if item['folder'].startswith(folder + '/') or item['folder'] == folder
                ]
            else:
                # Non-recursive: exact folder match only
                filtered_data = [
                    item for item in filtered_data
                    if item['folder'] == folder
                ]

        # Apply the search filter
        if search:
            if fuzzy:
                filtered_data = [
                    item for item in filtered_data
                    if any(
                        self.fuzzy_match(str(value), search)
                        for value in [
                            item.get('model_name', ''),
                            item.get('base_model', '')
                        ]
                        if value
                    )
                ]
            else:
                # Exact (substring) search on the model name
                filtered_data = [
                    item for item in filtered_data
                    if search in str(item.get('model_name', '')).lower()
                ]

        # Pagination
        total_items = len(filtered_data)
        start_idx = (page - 1) * page_size
        end_idx = min(start_idx + page_size, total_items)

        result = {
            'items': filtered_data[start_idx:end_idx],
            'total': total_items,
            'page': page,
            'page_size': page_size,
            'total_pages': (total_items + page_size - 1) // page_size
        }
        return result

    def invalidate_cache(self):
        """Invalidate the current cache"""
        self._cache = None

    async def scan_all_loras(self) -> List[Dict]:
        """Scan all LoRA directories and return metadata"""
        all_loras = []

        # Scan each configured root concurrently
        scan_tasks = [
            asyncio.create_task(self._scan_directory(loras_root))
            for loras_root in config.loras_roots
        ]
        for task in scan_tasks:
            try:
                all_loras.extend(await task)
            except Exception as e:
                logger.error(f"Error scanning directory: {e}")

        return all_loras

    async def _scan_directory(self, root_path: str) -> List[Dict]:
        """Scan a single directory for LoRA files"""
        loras = []
        original_root = root_path  # keep the original root for relative paths

        async def scan_recursive(path: str, visited_paths: set):
            """Recursively scan a directory, avoiding symlink cycles."""
            try:
                real_path = os.path.realpath(path)
                if real_path in visited_paths:
                    logger.debug(f"Skipping already visited path: {path}")
                    return
                visited_paths.add(real_path)

                with os.scandir(path) as it:
                    entries = list(it)
                    for entry in entries:
                        try:
                            if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'):
                                # Use the original (link) path, not the real path
                                file_path = entry.path.replace(os.sep, "/")
                                await self._process_single_file(file_path, original_root, loras)
                                await asyncio.sleep(0)  # yield to the event loop
                            elif entry.is_dir(follow_symlinks=True):
                                # Recurse using the original path
                                await scan_recursive(entry.path, visited_paths)
                        except Exception as e:
                            logger.error(f"Error processing entry {entry.path}: {e}")
            except Exception as e:
                logger.error(f"Error scanning {path}: {e}")

        await scan_recursive(root_path, set())
        return loras

    async def _process_single_file(self, file_path: str, root_path: str, loras: list):
        """Process one file and append its metadata to the result list."""
        try:
            result = await self._process_lora_file(file_path, root_path)
            if result:
                loras.append(result)
        except Exception as e:
            logger.error(f"Error processing {file_path}: {e}")

    async def _process_lora_file(self, file_path: str, root_path: str) -> Dict:
        """Process a single LoRA file and return its metadata"""
        # Try loading existing metadata first
        metadata = await load_metadata(file_path)
        if metadata is None:
            # Create new metadata if none exists
            metadata = await get_file_info(file_path)

        # Convert to dict and add folder info (relative to the root)
        lora_data = metadata.to_dict()
        rel_path = os.path.relpath(file_path, root_path)
        folder = os.path.dirname(rel_path)
        lora_data['folder'] = folder.replace(os.path.sep, '/')

        return lora_data

    async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
        """Update preview URL in cache for a specific lora

        Args:
            file_path: The file path of the lora to update
            preview_url: The new preview URL

        Returns:
            bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
        """
        if self._cache is None:
            return False
        return await self._cache.update_preview_url(file_path, preview_url)

    async def scan_single_lora(self, file_path: str) -> Optional[Dict]:
        """Scan a single LoRA file and return its metadata"""
        try:
            if not os.path.exists(os.path.realpath(file_path)):
                return None

            metadata = await get_file_info(file_path)
            if not metadata:
                return None

            # Ensure the folder field is always present
            folder = self._calculate_folder(file_path)
            metadata_dict = metadata.to_dict()
            metadata_dict['folder'] = folder or ''
            return metadata_dict
        except Exception as e:
            logger.error(f"Error scanning {file_path}: {e}")
            return None

    def _calculate_folder(self, file_path: str) -> str:
        """Calculate the folder path for a LoRA file"""
        # Use the original (non-resolved) path for the relative computation.
        # NOTE(review): startswith assumes roots and file_path use the same
        # separator convention — confirm roots are forward-slash normalized.
        for root in config.loras_roots:
            if file_path.startswith(root):
                rel_path = os.path.relpath(file_path, root)
                return os.path.dirname(rel_path).replace(os.path.sep, '/')
        return ''

    async def move_model(self, source_path: str, target_path: str) -> bool:
        """Move a model and its associated files to a new location"""
        try:
            # Keep forward-slash path format throughout
            source_path = source_path.replace(os.sep, '/')
            target_path = target_path.replace(os.sep, '/')

            base_name = os.path.splitext(os.path.basename(source_path))[0]
            source_dir = os.path.dirname(source_path)

            os.makedirs(target_path, exist_ok=True)
            target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/')

            # Use real paths for the actual file operations
            real_source = os.path.realpath(source_path)
            real_target = os.path.realpath(target_lora)

            # Suppress our own filesystem events during the move
            file_size = os.path.getsize(real_source)
            if self.file_monitor:
                self.file_monitor.handler.add_ignore_path(real_source, file_size)
                self.file_monitor.handler.add_ignore_path(real_target, file_size)

            shutil.move(real_source, real_target)

            # Move associated metadata, if any.
            # BUGFIX: metadata was previously unbound (NameError) when no
            # sidecar .metadata.json existed.
            metadata = None
            source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
            if os.path.exists(source_metadata):
                target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
                shutil.move(source_metadata, target_metadata)
                metadata = await self._update_metadata_paths(target_metadata, target_lora)

            # Move the first preview file found (most specific pattern first)
            preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
                                  '.png', '.jpeg', '.jpg', '.mp4']
            for ext in preview_extensions:
                source_preview = os.path.join(source_dir, f"{base_name}{ext}")
                if os.path.exists(source_preview):
                    target_preview = os.path.join(target_path, f"{base_name}{ext}")
                    shutil.move(source_preview, target_preview)
                    break

            # Update cache
            await self.update_single_lora_cache(source_path, target_lora, metadata)
            return True
        except Exception as e:
            logger.error(f"Error moving model: {e}", exc_info=True)
            return False

    async def update_single_lora_cache(self, original_path: str, new_path: str, metadata: Dict) -> None:
        """Replace one lora's cache entry after a move (metadata may be None)."""
        cache = await self.get_cached_data()
        cache.raw_data = [
            item for item in cache.raw_data
            if item['file_path'] != original_path
        ]
        if metadata:
            metadata['folder'] = self._calculate_folder(new_path)
            cache.raw_data.append(metadata)
            all_folders = set(cache.folders)
            all_folders.add(metadata['folder'])
            cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
        # Resort cache
        await cache.resort()

    async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict:
        """Update file paths in metadata file"""
        try:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            # Update file_path
            metadata['file_path'] = lora_path.replace(os.sep, '/')

            # Update preview_url if exists
            if 'preview_url' in metadata:
                preview_dir = os.path.dirname(lora_path)
                preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
                preview_ext = os.path.splitext(metadata['preview_url'])[1]
                new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
                metadata['preview_url'] = new_preview_path.replace(os.sep, '/')

            # Save updated metadata
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2, ensure_ascii=False)
            return metadata
        except Exception as e:
            logger.error(f"Error updating metadata paths: {e}", exc_info=True)
            # Explicit: callers (move_model) treat None as "no metadata".
            return None

View File

@@ -0,0 +1,46 @@
import os
import json
import logging
from typing import Any, Dict
logger = logging.getLogger(__name__)
class SettingsManager:
    """Loads and persists LoRA-Manager settings from a JSON file."""

    def __init__(self):
        # settings.json lives in the project root (two levels above this module).
        project_root = os.path.dirname(os.path.dirname(__file__))
        self.settings_file = os.path.join(project_root, 'settings.json')
        self.settings = self._load_settings()

    def _load_settings(self) -> Dict[str, Any]:
        """Load settings from file"""
        if not os.path.exists(self.settings_file):
            return self._get_default_settings()
        try:
            with open(self.settings_file, 'r', encoding='utf-8') as handle:
                return json.load(handle)
        except Exception as e:
            logger.error(f"Error loading settings: {e}")
            return self._get_default_settings()

    def _get_default_settings(self) -> Dict[str, Any]:
        """Return default settings"""
        return {"civitai_api_key": ""}

    def get(self, key: str, default: Any = None) -> Any:
        """Get setting value"""
        return self.settings.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Set setting value and save"""
        self.settings[key] = value
        self._save_settings()

    def _save_settings(self) -> None:
        """Write the current settings dict to disk as JSON."""
        try:
            with open(self.settings_file, 'w', encoding='utf-8') as handle:
                json.dump(self.settings, handle, indent=2)
        except Exception as e:
            logger.error(f"Error saving settings: {e}")


# Shared singleton used across the service modules
settings = SettingsManager()

View File

@@ -0,0 +1,43 @@
import logging
from aiohttp import web
from typing import Set, Dict, Optional
logger = logging.getLogger(__name__)
class WebSocketManager:
    """Manages WebSocket connections and broadcasts"""

    def __init__(self):
        self._websockets: Set[web.WebSocketResponse] = set()

    async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
        """Handle new WebSocket connection"""
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        self._websockets.add(ws)
        try:
            async for msg in ws:
                if msg.type == web.WSMsgType.ERROR:
                    logger.error(f'WebSocket error: {ws.exception()}')
        finally:
            # Always drop the connection from the broadcast set on exit.
            self._websockets.discard(ws)
        return ws

    async def broadcast(self, data: Dict):
        """Broadcast message to all connected clients"""
        if not self._websockets:
            return
        # BUGFIX: iterate over a snapshot — clients may connect/disconnect
        # (mutating the set) while we await individual sends, which would
        # raise "set changed size during iteration".
        for ws in list(self._websockets):
            try:
                await ws.send_json(data)
            except Exception as e:
                logger.error(f"Error sending progress: {e}")

    def get_connected_clients_count(self) -> int:
        """Get number of connected clients"""
        return len(self._websockets)

# Global instance
ws_manager = WebSocketManager()

1
py/utils/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Empty file to mark directory as Python package

137
py/utils/file_utils.py Normal file
View File

@@ -0,0 +1,137 @@
import logging
import os
import hashlib
import json
from typing import Dict, Optional
from .lora_metadata import extract_lora_metadata
from .models import LoraMetadata
logger = logging.getLogger(__name__)
async def calculate_sha256(file_path: str) -> str:
    """Return the SHA256 hex digest of the file at file_path."""
    digest = hashlib.sha256()
    with open(file_path, "rb") as handle:
        # Stream the file in fixed-size chunks to bound memory usage.
        while chunk := handle.read(4096):
            digest.update(chunk)
    return digest.hexdigest()
def _find_preview_file(base_name: str, dir_path: str) -> str:
"""Find preview file for given base name in directory"""
preview_patterns = [
f"{base_name}.preview.png",
f"{base_name}.preview.jpg",
f"{base_name}.preview.jpeg",
f"{base_name}.preview.mp4",
f"{base_name}.png",
f"{base_name}.jpg",
f"{base_name}.jpeg",
f"{base_name}.mp4"
]
for pattern in preview_patterns:
full_pattern = os.path.join(dir_path, pattern)
if os.path.exists(full_pattern):
return full_pattern.replace(os.sep, "/")
return ""
def normalize_path(path: str) -> str:
    """Normalize file path to use forward slashes.

    Falsy inputs ("" or None) are returned unchanged.
    """
    if not path:
        return path
    return path.replace(os.sep, "/")
async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
    """Get basic file information as LoraMetadata object"""
    # Resolve symlinks first and bail out if the target no longer exists.
    try:
        resolved = os.path.realpath(file_path)
        if not os.path.exists(resolved):
            return None
    except Exception as e:
        logger.error(f"Error checking file existence for {file_path}: {e}")
        return None

    stem = os.path.splitext(os.path.basename(file_path))[0]
    parent_dir = os.path.dirname(file_path)
    preview = _find_preview_file(stem, parent_dir)

    try:
        metadata = LoraMetadata(
            file_name=stem,
            model_name=stem,
            file_path=normalize_path(file_path),
            size=os.path.getsize(resolved),
            modified=os.path.getmtime(resolved),
            sha256=await calculate_sha256(resolved),
            base_model="Unknown",  # refined below from the safetensors header
            usage_tips="",
            notes="",
            from_civitai=True,
            preview_url=normalize_path(preview),
        )
        # Read the base model from the safetensors header and persist a
        # sidecar .metadata.json next to the lora file.
        base_model_info = await extract_lora_metadata(resolved)
        metadata.base_model = base_model_info['base_model']
        await save_metadata(file_path, metadata)
        return metadata
    except Exception as e:
        logger.error(f"Error getting file info for {file_path}: {e}")
        return None
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
    """Save metadata to the sidecar .metadata.json next to the model file.

    Paths inside the payload are normalized to forward slashes before
    writing so the JSON stays portable across platforms.
    """
    metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
    try:
        metadata_dict = metadata.to_dict()
        metadata_dict['file_path'] = normalize_path(metadata_dict['file_path'])
        metadata_dict['preview_url'] = normalize_path(metadata_dict['preview_url'])
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata_dict, f, indent=2, ensure_ascii=False)
    except Exception as e:
        # BUGFIX: was print(); route failures through the module logger.
        logger.error(f"Error saving metadata to {metadata_path}: {str(e)}")
async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
    """Load metadata from the sidecar ``.metadata.json`` file, if present.

    While loading, stale entries are self-healed: stored paths are
    re-normalized and a missing preview image is re-resolved next to the
    model file; the sidecar is rewritten only when something changed.
    Returns None when the sidecar is absent or unreadable.
    """
    metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
    try:
        if os.path.exists(metadata_path):
            with open(metadata_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            needs_update = False

            # Re-normalize the stored model path. Use .get so a hand-edited
            # or older sidecar without the key doesn't raise KeyError.
            stored_path = data.get('file_path', '')
            normalized_path = normalize_path(stored_path)
            if stored_path != normalized_path:
                data['file_path'] = normalized_path
                needs_update = True

            preview_url = data.get('preview_url', '')
            if not preview_url or not os.path.exists(preview_url):
                # Preview missing or stale: search next to the model file.
                base_name = os.path.splitext(os.path.basename(file_path))[0]
                dir_path = os.path.dirname(file_path)
                new_preview_url = normalize_path(_find_preview_file(base_name, dir_path))
                if new_preview_url != preview_url:
                    data['preview_url'] = new_preview_url
                    needs_update = True
            elif preview_url != normalize_path(preview_url):
                data['preview_url'] = normalize_path(preview_url)
                needs_update = True

            if needs_update:
                with open(metadata_path, 'w', encoding='utf-8') as f:
                    json.dump(data, f, indent=2, ensure_ascii=False)

            return LoraMetadata.from_dict(data)
    except Exception as e:
        # Module logger instead of print so errors respect logging config.
        logger.error(f"Error loading metadata from {metadata_path}: {str(e)}")
    return None
async def update_civitai_metadata(file_path: str, civitai_data: Dict) -> None:
    """Update the sidecar metadata file with data fetched from Civitai.

    Loads the existing metadata, attaches the Civitai payload and writes it
    back. No-op when no metadata file exists for *file_path*.
    """
    metadata = await load_metadata(file_path)
    if metadata is None:
        # No sidecar yet (or it failed to load) - nothing to update.
        return
    # LoraMetadata is a dataclass, not a dict: the original
    # ``metadata['civitai'] = civitai_data`` raised TypeError at runtime.
    metadata.update_civitai_info(civitai_data)
    await save_metadata(file_path, metadata)

16
py/utils/lora_metadata.py Normal file
View File

@@ -0,0 +1,16 @@
from safetensors import safe_open
from typing import Dict
from .model_utils import determine_base_model
async def extract_lora_metadata(file_path: str) -> Dict:
    """Extract essential metadata from safetensors file"""
    base_model = "Unknown"
    try:
        with safe_open(file_path, framework="pt", device="cpu") as handle:
            header = handle.metadata()
        if header:
            # Only the ss_base_model_version entry is of interest here.
            base_model = determine_base_model(header.get("ss_base_model_version"))
    except Exception as e:
        print(f"Error reading metadata from {file_path}: {str(e)}")
    return {"base_model": base_model}

25
py/utils/model_utils.py Normal file
View File

@@ -0,0 +1,25 @@
from typing import Optional
# Base model mapping based on version string
# Maps version-string fragments to display names. Insertion order matters:
# more specific fragments come first (e.g. "sd-v2-1" before "sd-v2").
BASE_MODEL_MAPPING = {
    "sd-v1-5": "SD1.5",
    "sd-v2-1": "SD2.1",
    "sdxl": "SDXL",
    "sd-v2": "SD2.0",
    "flux1": "Flux.1 D",
    "flux.1 d": "Flux.1 D",
    "illustrious": "IL",
    "pony": "Pony"
}


def determine_base_model(version_string: Optional[str]) -> str:
    """Determine base model from version string in safetensors metadata"""
    if not version_string:
        return "Unknown"
    needle = version_string.lower()
    # First mapping fragment found inside the version string wins.
    return next(
        (label for fragment, label in BASE_MODEL_MAPPING.items() if fragment in needle),
        "Unknown",
    )

68
py/utils/models.py Normal file
View File

@@ -0,0 +1,68 @@
from dataclasses import dataclass, asdict
from typing import Dict, Optional
from datetime import datetime
import os
from .model_utils import determine_base_model
@dataclass
class LoraMetadata:
    """Represents the metadata structure for a Lora model.

    Serialized to/from the ``<model>.metadata.json`` sidecar file; path
    fields are stored with forward slashes for cross-platform portability.
    """
    file_name: str  # The filename without extension of the lora
    model_name: str  # The lora's name defined by the creator, initially same as file_name
    file_path: str  # Full path to the safetensors file
    size: int  # File size in bytes
    modified: float  # Last modified timestamp (epoch seconds)
    sha256: str  # SHA256 hash of the file
    base_model: str  # Base model (SD1.5/SD2.1/SDXL/etc.)
    preview_url: str  # Preview image URL
    usage_tips: str = "{}"  # Usage tips for the model, json string
    notes: str = ""  # Additional notes
    from_civitai: bool = True  # Whether the lora is from Civitai
    civitai: Optional[Dict] = None  # Civitai API data if available
@classmethod
def from_dict(cls, data: Dict) -> 'LoraMetadata':
"""Create LoraMetadata instance from dictionary"""
# Create a copy of the data to avoid modifying the input
data_copy = data.copy()
return cls(**data_copy)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
"""Create LoraMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', ''),
base_model=base_model,
preview_url=None, # Will be updated after preview download
from_civitai=True,
civitai=version_info
)
def to_dict(self) -> Dict:
"""Convert to dictionary for JSON serialization"""
return asdict(self)
@property
def modified_datetime(self) -> datetime:
"""Convert modified timestamp to datetime object"""
return datetime.fromtimestamp(self.modified)
def update_civitai_info(self, civitai_data: Dict) -> None:
"""Update Civitai information"""
self.civitai = civitai_data
    def update_file_info(self, file_path: str) -> None:
        """Refresh size, modified time and stored path from the file on disk.

        Silently does nothing when *file_path* does not exist.
        """
        if os.path.exists(file_path):
            self.size = os.path.getsize(file_path)
            self.modified = os.path.getmtime(file_path)
            # Store with forward slashes, matching the sidecar-file convention.
            self.file_path = file_path.replace(os.sep, '/')