Reorganize python files

This commit is contained in:
Will Miao
2025-02-24 20:41:16 +08:00
parent f0cd77e7e5
commit 2d72044d66
20 changed files with 5 additions and 7 deletions

1
py/services/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Empty file to mark directory as Python package

View File

@@ -0,0 +1,171 @@
from datetime import datetime
import aiohttp
import os
import json
import logging
from email.parser import Parser
from typing import Optional, Dict, Tuple
from urllib.parse import unquote
from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
class CivitaiClient:
def __init__(self):
self.base_url = "https://civitai.com/api/v1"
self.headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
}
self._session = None
@property
async def session(self) -> aiohttp.ClientSession:
"""Lazy initialize the session"""
if self._session is None:
connector = aiohttp.TCPConnector(ssl=True)
trust_env = True # 允许使用系统环境变量中的代理设置
self._session = aiohttp.ClientSession(connector=connector, trust_env=trust_env)
return self._session
def _parse_content_disposition(self, header: str) -> str:
"""Parse filename from content-disposition header"""
if not header:
return None
# Handle quoted filenames
if 'filename="' in header:
start = header.index('filename="') + 10
end = header.index('"', start)
return unquote(header[start:end])
# Fallback to original parsing
disposition = Parser().parsestr(f'Content-Disposition: {header}')
filename = disposition.get_param('filename')
if filename:
return unquote(filename)
return None
def _get_request_headers(self) -> dict:
"""Get request headers with optional API key"""
headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
'Content-Type': 'application/json'
}
from .settings_manager import settings
api_key = settings.get('civitai_api_key')
if (api_key):
headers['Authorization'] = f'Bearer {api_key}'
return headers
async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
"""Download file with content-disposition support and progress tracking
Args:
url: Download URL
save_dir: Directory to save the file
default_filename: Fallback filename if none provided in headers
progress_callback: Optional async callback function for progress updates (0-100)
Returns:
Tuple[bool, str]: (success, save_path or error message)
"""
session = await self.session
try:
headers = self._get_request_headers()
async with session.get(url, headers=headers, allow_redirects=True) as response:
if response.status != 200:
return False, f"Download failed with status {response.status}"
# Get filename from content-disposition header
content_disposition = response.headers.get('Content-Disposition')
filename = self._parse_content_disposition(content_disposition)
if not filename:
filename = default_filename
save_path = os.path.join(save_dir, filename)
# Get total file size for progress calculation
total_size = int(response.headers.get('content-length', 0))
current_size = 0
# Stream download to file with progress updates
with open(save_path, 'wb') as f:
async for chunk in response.content.iter_chunked(8192):
if chunk:
f.write(chunk)
current_size += len(chunk)
if progress_callback and total_size:
progress = (current_size / total_size) * 100
await progress_callback(progress)
# Ensure 100% progress is reported
if progress_callback:
await progress_callback(100)
return True, save_path
except Exception as e:
logger.error(f"Download error: {e}")
return False, str(e)
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
try:
session = await self.session
async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
logger.error(f"API Error: {str(e)}")
return None
async def download_preview_image(self, image_url: str, save_path: str):
try:
session = await self.session
async with session.get(image_url) as response:
if response.status == 200:
content = await response.read()
with open(save_path, 'wb') as f:
f.write(content)
return True
return False
except Exception as e:
print(f"Download Error: {str(e)}")
return False
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Fetch all versions of a model"""
try:
session = await self.session
url = f"{self.base_url}/models/{model_id}"
async with session.get(url, headers=self.headers) as response:
if response.status == 200:
data = await response.json()
return data.get('modelVersions', [])
return None
except Exception as e:
logger.error(f"Error fetching model versions: {e}")
return None
async def get_model_version_info(self, version_id: str) -> Optional[Dict]:
"""Fetch model version metadata from Civitai"""
try:
session = await self.session
url = f"{self.base_url}/model-versions/{version_id}"
headers = self._get_request_headers()
async with session.get(url, headers=headers) as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
logger.error(f"Error fetching model version info: {e}")
return None
async def close(self):
"""Close the session if it exists"""
if self._session is not None:
await self._session.close()
self._session = None

View File

@@ -0,0 +1,154 @@
import logging
import os
import json
from typing import Optional, Dict
from .civitai_client import CivitaiClient
from .file_monitor import LoraFileMonitor
from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
class DownloadManager:
def __init__(self, file_monitor: Optional[LoraFileMonitor] = None):
self.civitai_client = CivitaiClient()
self.file_monitor = file_monitor
async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '',
progress_callback=None) -> Dict:
try:
# Update save directory with relative path if provided
if relative_path:
save_dir = os.path.join(save_dir, relative_path)
# Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Get version info
version_id = download_url.split('/')[-1]
version_info = await self.civitai_client.get_model_version_info(version_id)
if not version_info:
return {'success': False, 'error': 'Failed to fetch model metadata'}
# Report initial progress
if progress_callback:
await progress_callback(0)
# 2. 获取文件信息
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if not file_info:
return {'success': False, 'error': 'No primary file found in metadata'}
# 3. 准备下载
file_name = file_info['name']
save_path = os.path.join(save_dir, file_name)
file_size = file_info.get('sizeKB', 0) * 1024
# 4. 通知文件监控系统
self.file_monitor.handler.add_ignore_path(
save_path.replace(os.sep, '/'),
file_size
)
# 5. 准备元数据
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
# 6. 开始下载流程
result = await self._execute_download(
download_url=download_url,
save_dir=save_dir,
metadata=metadata,
version_info=version_info,
relative_path=relative_path,
progress_callback=progress_callback
)
return result
except Exception as e:
logger.error(f"Error in download_from_civitai: {e}", exc_info=True)
return {'success': False, 'error': str(e)}
async def _execute_download(self, download_url: str, save_dir: str,
metadata: LoraMetadata, version_info: Dict,
relative_path: str, progress_callback=None) -> Dict:
"""Execute the actual download process including preview images and model files"""
try:
save_path = metadata.file_path
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
# Download preview image if available
images = version_info.get('images', [])
if images:
# Report preview download progress
if progress_callback:
await progress_callback(5) # 5% progress for starting preview download
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
metadata.preview_url = preview_path.replace(os.sep, '/')
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# Report preview download completion
if progress_callback:
await progress_callback(10) # 10% progress after preview download
# Download model file with progress tracking
success, result = await self.civitai_client._download_file(
download_url,
save_dir,
os.path.basename(save_path),
progress_callback=lambda p: self._handle_download_progress(p, progress_callback)
)
if not success:
# Clean up files on failure
for path in [save_path, metadata_path, metadata.preview_url]:
if path and os.path.exists(path):
os.remove(path)
return {'success': False, 'error': result}
# 4. 更新文件信息(大小和修改时间)
metadata.update_file_info(save_path)
# 5. 最终更新元数据
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# 6. update lora cache
cache = await self.file_monitor.scanner.get_cached_data()
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = relative_path
cache.raw_data.append(metadata_dict)
await cache.resort()
all_folders = set(cache.folders)
all_folders.add(relative_path)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Report 100% completion
if progress_callback:
await progress_callback(100)
return {
'success': True
}
except Exception as e:
logger.error(f"Error in _execute_download: {e}", exc_info=True)
# Clean up partial downloads
for path in [save_path, metadata_path]:
if path and os.path.exists(path):
os.remove(path)
return {'success': False, 'error': str(e)}
async def _handle_download_progress(self, file_progress: float, progress_callback):
"""Convert file download progress to overall progress
Args:
file_progress: Progress of file download (0-100)
progress_callback: Callback function for progress updates
"""
if progress_callback:
# Scale file progress to 10-100 range (after preview download)
overall_progress = 10 + (file_progress * 0.9) # 90% of progress for file download
await progress_callback(round(overall_progress))

184
py/services/file_monitor.py Normal file
View File

@@ -0,0 +1,184 @@
from operator import itemgetter
import os
import logging
import asyncio
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent
from typing import List
from threading import Lock
from .lora_scanner import LoraScanner
from ..config import config
logger = logging.getLogger(__name__)
class LoraFileHandler(FileSystemEventHandler):
"""Handler for LoRA file system events"""
def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop):
self.scanner = scanner
self.loop = loop # 存储事件循环引用
self.pending_changes = set() # 待处理的变更
self.lock = Lock() # 线程安全锁
self.update_task = None # 异步更新任务
self._ignore_paths = set() # Add ignore paths set
self._min_ignore_timeout = 5 # minimum timeout in seconds
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
def _should_ignore(self, path: str) -> bool:
"""Check if path should be ignored"""
real_path = os.path.realpath(path) # Resolve any symbolic links
return real_path.replace(os.sep, '/') in self._ignore_paths
def add_ignore_path(self, path: str, file_size: int = 0):
"""Add path to ignore list with dynamic timeout based on file size"""
real_path = os.path.realpath(path) # Resolve any symbolic links
self._ignore_paths.add(real_path.replace(os.sep, '/'))
# Short timeout (e.g. 5 seconds) is sufficient to ignore the CREATE event
timeout = 5
asyncio.get_event_loop().call_later(
timeout,
self._ignore_paths.discard,
real_path.replace(os.sep, '/')
)
def on_created(self, event):
if event.is_directory or not event.src_path.endswith('.safetensors'):
return
if self._should_ignore(event.src_path):
return
logger.info(f"LoRA file created: {event.src_path}")
self._schedule_update('add', event.src_path)
def on_deleted(self, event):
if event.is_directory or not event.src_path.endswith('.safetensors'):
return
if self._should_ignore(event.src_path):
return
logger.info(f"LoRA file deleted: {event.src_path}")
self._schedule_update('remove', event.src_path)
def _schedule_update(self, action: str, file_path: str): #file_path is a real path
"""Schedule a cache update"""
with self.lock:
# 使用 config 中的方法映射路径
mapped_path = config.map_path_to_link(file_path)
normalized_path = mapped_path.replace(os.sep, '/')
self.pending_changes.add((action, normalized_path))
self.loop.call_soon_threadsafe(self._create_update_task)
def _create_update_task(self):
"""Create update task in the event loop"""
if self.update_task is None or self.update_task.done():
self.update_task = asyncio.create_task(self._process_changes())
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing"""
await asyncio.sleep(delay)
try:
with self.lock:
changes = self.pending_changes.copy()
self.pending_changes.clear()
if not changes:
return
logger.info(f"Processing {len(changes)} file changes")
cache = await self.scanner.get_cached_data() # 先完成可能的初始化
needs_resort = False
new_folders = set() # 用于收集新的文件夹
for action, file_path in changes:
try:
if action == 'add':
# 扫描新文件
lora_data = await self.scanner.scan_single_lora(file_path)
if lora_data:
cache.raw_data.append(lora_data)
new_folders.add(lora_data['folder']) # 收集新文件夹
needs_resort = True
elif action == 'remove':
# 从缓存中移除
logger.info(f"Removing {file_path} from cache")
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
]
needs_resort = True
except Exception as e:
logger.error(f"Error processing {action} for {file_path}: {e}")
if needs_resort:
await cache.resort()
# 更新文件夹列表,包括新添加的文件夹
all_folders = set(cache.folders) | new_folders
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
except Exception as e:
logger.error(f"Error in process_changes: {e}")
class LoraFileMonitor:
"""Monitor for LoRA file changes"""
def __init__(self, scanner: LoraScanner, roots: List[str]):
self.scanner = scanner
scanner.set_file_monitor(self)
self.observer = Observer()
self.loop = asyncio.get_event_loop()
self.handler = LoraFileHandler(scanner, self.loop)
# 使用已存在的路径映射
self.monitor_paths = set()
for root in roots:
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
# 添加所有已映射的目标路径
for target_path in config._path_mappings.keys():
self.monitor_paths.add(target_path)
def start(self):
"""Start monitoring"""
for path_info in self.monitor_paths:
try:
if isinstance(path_info, tuple):
# 对于链接,监控目标路径
_, target_path = path_info
self.observer.schedule(self.handler, target_path, recursive=True)
logger.info(f"Started monitoring target path: {target_path}")
else:
# 对于普通路径,直接监控
self.observer.schedule(self.handler, path_info, recursive=True)
logger.info(f"Started monitoring: {path_info}")
except Exception as e:
logger.error(f"Error monitoring {path_info}: {e}")
self.observer.start()
def stop(self):
"""Stop monitoring"""
self.observer.stop()
self.observer.join()
def rescan_links(self):
"""重新扫描链接(当添加新的链接时调用)"""
new_paths = set()
for path in self.monitor_paths.copy():
self._add_link_targets(path)
# 添加新发现的路径到监控
new_paths = self.monitor_paths - set(self.observer.watches.keys())
for path in new_paths:
try:
self.observer.schedule(self.handler, path, recursive=True)
logger.info(f"Added new monitoring path: {path}")
except Exception as e:
logger.error(f"Error adding new monitor for {path}: {e}")

64
py/services/lora_cache.py Normal file
View File

@@ -0,0 +1,64 @@
import asyncio
from typing import List, Dict
from dataclasses import dataclass
from operator import itemgetter
@dataclass
class LoraCache:
"""Cache structure for LoRA data"""
raw_data: List[Dict]
sorted_by_name: List[Dict]
sorted_by_date: List[Dict]
folders: List[str]
def __post_init__(self):
self._lock = asyncio.Lock()
async def resort(self, name_only: bool = False):
"""Resort all cached data views"""
async with self._lock:
self.sorted_by_name = sorted(
self.raw_data,
key=lambda x: x['model_name'].lower() # Case-insensitive sort
)
if not name_only:
self.sorted_by_date = sorted(
self.raw_data,
key=itemgetter('modified'),
reverse=True
)
# Update folder list
all_folders = set(l['folder'] for l in self.raw_data)
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
"""Update preview_url for a specific lora in all cached data
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if the lora wasn't found
"""
async with self._lock:
# Update in raw_data
for item in self.raw_data:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
else:
return False # Lora not found
# Update in sorted lists (references to the same dict objects)
for item in self.sorted_by_name:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
for item in self.sorted_by_date:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
return True

439
py/services/lora_scanner.py Normal file
View File

@@ -0,0 +1,439 @@
import json
import os
import logging
import asyncio
import shutil
from typing import List, Dict, Optional
from dataclasses import dataclass
from operator import itemgetter
from ..config import config
from ..utils.file_utils import load_metadata, get_file_info
from .lora_cache import LoraCache
from difflib import SequenceMatcher
logger = logging.getLogger(__name__)
class LoraScanner:
"""Service for scanning and managing LoRA files"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
# 确保初始化只执行一次
if not hasattr(self, '_initialized'):
self._cache: Optional[LoraCache] = None
self._initialization_lock = asyncio.Lock()
self._initialization_task: Optional[asyncio.Task] = None
self._initialized = True
self.file_monitor = None # Add this line
def set_file_monitor(self, monitor):
"""Set file monitor instance"""
self.file_monitor = monitor
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
"""Get cached LoRA data, refresh if needed"""
async with self._initialization_lock:
# 如果缓存未初始化但需要响应请求,返回空缓存
if self._cache is None and not force_refresh:
return LoraCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# 如果正在初始化,等待完成
if self._initialization_task and not self._initialization_task.done():
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
self._initialization_task = None
if (self._cache is None or force_refresh):
# 创建新的初始化任务
if not self._initialization_task or self._initialization_task.done():
self._initialization_task = asyncio.create_task(self._initialize_cache())
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
# 如果缓存已存在,继续使用旧缓存
if self._cache is None:
raise # 如果没有缓存,则抛出异常
return self._cache
async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache"""
# Scan for new data
raw_data = await self.scan_all_loras()
# Update cache
self._cache = LoraCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Call resort_cache to create sorted views
await self._cache.resort()
def fuzzy_match(self, text: str, pattern: str, threshold: float = 0.7) -> bool:
"""
Check if text matches pattern using fuzzy matching.
Returns True if similarity ratio is above threshold.
"""
if not pattern or not text:
return False
# Convert both to lowercase for case-insensitive matching
text = text.lower()
pattern = pattern.lower()
# Split pattern into words
search_words = pattern.split()
# Check each word
for word in search_words:
# First check if word is a substring (faster)
if word in text:
continue
# If not found as substring, try fuzzy matching
# Check if any part of the text matches this word
found_match = False
for text_part in text.split():
ratio = SequenceMatcher(None, text_part, word).ratio()
if ratio >= threshold:
found_match = True
break
if not found_match:
return False
# All words found either as substrings or fuzzy matches
return True
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy: bool = False,
recursive: bool = False):
"""Get paginated and filtered lora data
Args:
page: Current page number (1-based)
page_size: Number of items per page
sort_by: Sort method ('name' or 'date')
folder: Filter by folder path
search: Search term
fuzzy: Use fuzzy matching for search
recursive: Include subfolders when folder filter is applied
"""
cache = await self.get_cached_data()
# 先获取基础数据集
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# 应用文件夹过滤
if folder is not None:
if recursive:
# 递归模式:匹配所有以该文件夹开头的路径
filtered_data = [
item for item in filtered_data
if item['folder'].startswith(folder + '/') or item['folder'] == folder
]
else:
# 非递归模式:只匹配确切的文件夹
filtered_data = [
item for item in filtered_data
if item['folder'] == folder
]
# 应用搜索过滤
if search:
if fuzzy:
filtered_data = [
item for item in filtered_data
if any(
self.fuzzy_match(str(value), search)
for value in [
item.get('model_name', ''),
item.get('base_model', '')
]
if value
)
]
else:
# Original exact search logic
filtered_data = [
item for item in filtered_data
if search in str(item.get('model_name', '')).lower()
]
# 计算分页
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
def invalidate_cache(self):
"""Invalidate the current cache"""
self._cache = None
async def scan_all_loras(self) -> List[Dict]:
"""Scan all LoRA directories and return metadata"""
all_loras = []
# 分目录异步扫描
scan_tasks = []
for loras_root in config.loras_roots:
task = asyncio.create_task(self._scan_directory(loras_root))
scan_tasks.append(task)
for task in scan_tasks:
try:
loras = await task
all_loras.extend(loras)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_loras
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for LoRA files"""
loras = []
original_root = root_path # 保存原始根路径
async def scan_recursive(path: str, visited_paths: set):
"""递归扫描目录,避免循环链接"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'):
# 使用原始路径而不是真实路径
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, loras)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# 对于目录,使用原始路径继续扫描
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return loras
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
"""处理单个文件并添加到结果列表"""
try:
result = await self._process_lora_file(file_path, root_path)
if result:
loras.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
async def _process_lora_file(self, file_path: str, root_path: str) -> Dict:
"""Process a single LoRA file and return its metadata"""
# Try loading existing metadata
metadata = await load_metadata(file_path)
if metadata is None:
# Create new metadata if none exists
metadata = await get_file_info(file_path)
# Convert to dict and add folder info
lora_data = metadata.to_dict()
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
lora_data['folder'] = folder.replace(os.path.sep, '/')
return lora_data
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
"""Update preview URL in cache for a specific lora
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
"""
if self._cache is None:
return False
return await self._cache.update_preview_url(file_path, preview_url)
async def scan_single_lora(self, file_path: str) -> Optional[Dict]:
"""Scan a single LoRA file and return its metadata"""
try:
if not os.path.exists(os.path.realpath(file_path)):
return None
# 获取基本文件信息
metadata = await get_file_info(file_path)
if not metadata:
return None
folder = self._calculate_folder(file_path)
# 确保 folder 字段存在
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = folder or ''
return metadata_dict
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None
def _calculate_folder(self, file_path: str) -> str:
"""Calculate the folder path for a LoRA file"""
# 使用原始路径计算相对路径
for root in config.loras_roots:
if file_path.startswith(root):
rel_path = os.path.relpath(file_path, root)
return os.path.dirname(rel_path).replace(os.path.sep, '/')
return ''
async def move_model(self, source_path: str, target_path: str) -> bool:
"""Move a model and its associated files to a new location"""
try:
# 保持原始路径格式
source_path = source_path.replace(os.sep, '/')
target_path = target_path.replace(os.sep, '/')
# 其余代码保持不变
base_name = os.path.splitext(os.path.basename(source_path))[0]
source_dir = os.path.dirname(source_path)
os.makedirs(target_path, exist_ok=True)
target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/')
# 使用真实路径进行文件操作
real_source = os.path.realpath(source_path)
real_target = os.path.realpath(target_lora)
file_size = os.path.getsize(real_source)
if self.file_monitor:
self.file_monitor.handler.add_ignore_path(
real_source,
file_size
)
self.file_monitor.handler.add_ignore_path(
real_target,
file_size
)
# 使用真实路径进行文件操作
shutil.move(real_source, real_target)
# Move associated files
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
if os.path.exists(source_metadata):
target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
shutil.move(source_metadata, target_metadata)
metadata = await self._update_metadata_paths(target_metadata, target_lora)
# Move preview file if exists
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions:
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
if os.path.exists(source_preview):
target_preview = os.path.join(target_path, f"{base_name}{ext}")
shutil.move(source_preview, target_preview)
break
# Update cache
await self.update_single_lora_cache(source_path, target_lora, metadata)
return True
except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True)
return False
async def update_single_lora_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
cache = await self.get_cached_data()
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != original_path
]
if metadata:
metadata['folder'] = self._calculate_folder(new_path)
cache.raw_data.append(metadata)
all_folders = set(cache.folders)
all_folders.add(metadata['folder'])
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Resort cache
await cache.resort()
async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict:
"""Update file paths in metadata file"""
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
# Update file_path
metadata['file_path'] = lora_path.replace(os.sep, '/')
# Update preview_url if exists
if 'preview_url' in metadata:
preview_dir = os.path.dirname(lora_path)
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
preview_ext = os.path.splitext(metadata['preview_url'])[1]
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
return metadata
except Exception as e:
logger.error(f"Error updating metadata paths: {e}", exc_info=True)

View File

@@ -0,0 +1,46 @@
import os
import json
import logging
from typing import Any, Dict
logger = logging.getLogger(__name__)
class SettingsManager:
def __init__(self):
self.settings_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
self.settings = self._load_settings()
def _load_settings(self) -> Dict[str, Any]:
"""Load settings from file"""
if os.path.exists(self.settings_file):
try:
with open(self.settings_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading settings: {e}")
return self._get_default_settings()
def _get_default_settings(self) -> Dict[str, Any]:
"""Return default settings"""
return {
"civitai_api_key": ""
}
def get(self, key: str, default: Any = None) -> Any:
"""Get setting value"""
return self.settings.get(key, default)
def set(self, key: str, value: Any) -> None:
"""Set setting value and save"""
self.settings[key] = value
self._save_settings()
def _save_settings(self) -> None:
"""Save settings to file"""
try:
with open(self.settings_file, 'w', encoding='utf-8') as f:
json.dump(self.settings, f, indent=2)
except Exception as e:
logger.error(f"Error saving settings: {e}")
settings = SettingsManager()

View File

@@ -0,0 +1,43 @@
import logging
from aiohttp import web
from typing import Set, Dict, Optional
logger = logging.getLogger(__name__)
class WebSocketManager:
"""Manages WebSocket connections and broadcasts"""
def __init__(self):
self._websockets: Set[web.WebSocketResponse] = set()
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self._websockets.add(ws)
try:
async for msg in ws:
if msg.type == web.WSMsgType.ERROR:
logger.error(f'WebSocket error: {ws.exception()}')
finally:
self._websockets.discard(ws)
return ws
async def broadcast(self, data: Dict):
"""Broadcast message to all connected clients"""
if not self._websockets:
return
for ws in self._websockets:
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending progress: {e}")
def get_connected_clients_count(self) -> int:
"""Get number of connected clients"""
return len(self._websockets)
# Global instance
ws_manager = WebSocketManager()