Implement move to folder

This commit is contained in:
Will Miao
2025-02-18 19:01:02 +08:00
parent bbc5aea08c
commit 52b01d1bce
8 changed files with 333 additions and 14 deletions

View File

@@ -57,6 +57,8 @@ class LoraFileHandler(FileSystemEventHandler):
def on_deleted(self, event):
if event.is_directory or not event.src_path.endswith('.safetensors'):
return
if self._should_ignore(event.src_path):
return
logger.info(f"LoRA file deleted: {event.src_path}")
self._schedule_update('remove', event.src_path)
@@ -131,6 +133,7 @@ class LoraFileMonitor:
def __init__(self, scanner: LoraScanner, roots: List[str]):
self.scanner = scanner
scanner.set_file_monitor(self)
self.roots = roots
self.observer = Observer()
# 获取当前运行的事件循环

View File

@@ -1,7 +1,9 @@
import json
import os
import logging
import time
import asyncio
import shutil
from typing import List, Dict, Optional
from dataclasses import dataclass
from operator import itemgetter
@@ -30,6 +32,11 @@ class LoraScanner:
self._initialization_lock = asyncio.Lock()
self._initialization_task: Optional[asyncio.Task] = None
self._initialized = True
self.file_monitor = None # Add this line
def set_file_monitor(self, monitor):
"""Set file monitor instance"""
self.file_monitor = monitor
@classmethod
async def get_instance(cls):
@@ -39,7 +46,7 @@ class LoraScanner:
cls._instance = cls()
return cls._instance
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
"""Get cached LoRA data, refresh if needed"""
async with self._initialization_lock:
@@ -295,17 +302,7 @@ class LoraScanner:
if not metadata:
return None
# 计算相对于 lora_roots 的文件夹路径
folder = None
file_dir = os.path.dirname(file_path)
for root in config.loras_roots:
if file_dir.startswith(root):
rel_path = os.path.relpath(file_dir, root)
if rel_path == '.':
folder = '' # 根目录
else:
folder = rel_path.replace(os.sep, '/')
break
folder = self._calculate_folder(file_path)
# 确保 folder 字段存在
metadata_dict = metadata.to_dict()
@@ -316,4 +313,119 @@ class LoraScanner:
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None
def _calculate_folder(self, file_path: str) -> str:
"""Calculate the folder path for a LoRA file"""
for root in config.loras_roots:
if file_path.startswith(root):
rel_path = os.path.relpath(file_path, root)
return os.path.dirname(rel_path).replace(os.path.sep, '/')
return ''
async def move_model(self, source_path: str, target_path: str) -> bool:
"""Move a model and its associated files to a new location
Args:
source_path: Full path to the source lora file
target_path: Full path to the target directory
Returns:
bool: True if successful, False otherwise
"""
try:
# Ensure paths are normalized
source_path = source_path.replace(os.sep, '/')
target_path = target_path.replace(os.sep, '/')
# Get base name without extension
base_name = os.path.splitext(os.path.basename(source_path))[0]
source_dir = os.path.dirname(source_path)
# Create target directory if it doesn't exist
os.makedirs(target_path, exist_ok=True)
# Calculate target lora path
target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/')
# Get source file size for timeout calculation
file_size = os.path.getsize(source_path)
# Tell file monitor to ignore these paths
if self.file_monitor:
self.file_monitor.handler.add_ignore_path(
source_path,
file_size
)
self.file_monitor.handler.add_ignore_path(
target_lora,
file_size
)
# Move main lora file
shutil.move(source_path, target_lora)
# Move associated files
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
if os.path.exists(source_metadata):
target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
shutil.move(source_metadata, target_metadata)
lora_data = await self._update_metadata_paths(target_metadata, target_lora)
# Move preview file if exists
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions:
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
if os.path.exists(source_preview):
target_preview = os.path.join(target_path, f"{base_name}{ext}")
shutil.move(source_preview, target_preview)
break
# Update cache folders
cache = await self.get_cached_data()
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != source_path
]
if lora_data:
cache.raw_data.append(lora_data)
all_folders = set(cache.folders)
all_folders.add(lora_data['folder'])
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Resort cache
await cache.resort()
return True
except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True)
return False
async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict:
"""Update file paths in metadata file"""
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
# Update file_path
metadata['file_path'] = lora_path.replace(os.sep, '/')
# Update preview_url if exists
if 'preview_url' in metadata:
preview_dir = os.path.dirname(lora_path)
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
preview_ext = os.path.splitext(metadata['preview_url'])[1]
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
metadata['folder'] = self._calculate_folder(lora_path)
return metadata
except Exception as e:
logger.error(f"Error updating metadata paths: {e}", exc_info=True)