mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
checkpoint
This commit is contained in:
@@ -1,39 +1,39 @@
|
||||
from operator import itemgetter
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from typing import List, Dict, Set
|
||||
from typing import List, Dict, Set, Optional
|
||||
from threading import Lock
|
||||
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
from .lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraFileHandler(FileSystemEventHandler):
|
||||
"""Handler for LoRA file system events"""
|
||||
class BaseFileHandler(FileSystemEventHandler):
|
||||
"""Base handler for file system events"""
|
||||
|
||||
def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop):
|
||||
self.scanner = scanner
|
||||
self.loop = loop # 存储事件循环引用
|
||||
self.pending_changes = set() # 待处理的变更
|
||||
self.lock = Lock() # 线程安全锁
|
||||
self.update_task = None # 异步更新任务
|
||||
self._ignore_paths = set() # Add ignore paths set
|
||||
self._min_ignore_timeout = 5 # minimum timeout in seconds
|
||||
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
self.loop = loop # Store event loop reference
|
||||
self.pending_changes = set() # Pending changes
|
||||
self.lock = Lock() # Thread-safe lock
|
||||
self.update_task = None # Async update task
|
||||
self._ignore_paths = set() # Paths to ignore
|
||||
self._min_ignore_timeout = 5 # Minimum timeout in seconds
|
||||
self._download_speed = 1024 * 1024 # Assume 1MB/s as base speed
|
||||
|
||||
# Track modified files with timestamps for debouncing
|
||||
self.modified_files: Dict[str, float] = {}
|
||||
self.debounce_timer = None
|
||||
self.debounce_delay = 3.0 # seconds to wait after last modification
|
||||
self.debounce_delay = 3.0 # Seconds to wait after last modification
|
||||
|
||||
# Track files that are already scheduled for processing
|
||||
# Track files already scheduled for processing
|
||||
self.scheduled_files: Set[str] = set()
|
||||
|
||||
# File extensions to monitor - should be overridden by subclasses
|
||||
self.file_extensions = set()
|
||||
|
||||
def _should_ignore(self, path: str) -> bool:
|
||||
"""Check if path should be ignored"""
|
||||
@@ -58,35 +58,33 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
# Handle safetensors files directly
|
||||
if event.src_path.endswith('.safetensors'):
|
||||
# Handle appropriate files based on extensions
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext in self.file_extensions:
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
# We'll process this file directly and ignore subsequent modifications
|
||||
# to prevent duplicate processing
|
||||
# Process this file directly and ignore subsequent modifications
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
if normalized_path not in self.scheduled_files:
|
||||
logger.info(f"LoRA file created: {event.src_path}")
|
||||
logger.info(f"File created: {event.src_path}")
|
||||
self.scheduled_files.add(normalized_path)
|
||||
self._schedule_update('add', event.src_path)
|
||||
|
||||
# Ignore modifications for a short period after creation
|
||||
# This helps avoid duplicate processing
|
||||
self.loop.call_later(
|
||||
self.debounce_delay * 2,
|
||||
self.scheduled_files.discard,
|
||||
normalized_path
|
||||
)
|
||||
|
||||
# For browser downloads, we'll catch them when they're renamed to .safetensors
|
||||
|
||||
def on_modified(self, event):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
# Only process safetensors files
|
||||
if event.src_path.endswith('.safetensors'):
|
||||
# Only process files with supported extensions
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext in self.file_extensions:
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
@@ -134,12 +132,17 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
|
||||
# Process stable files
|
||||
for file_path in files_to_process:
|
||||
logger.info(f"Processing modified LoRA file: {file_path}")
|
||||
logger.info(f"Processing modified file: {file_path}")
|
||||
self._schedule_update('add', file_path)
|
||||
|
||||
def on_deleted(self, event):
|
||||
if event.is_directory or not event.src_path.endswith('.safetensors'):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext not in self.file_extensions:
|
||||
return
|
||||
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
@@ -147,14 +150,17 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
self.scheduled_files.discard(normalized_path)
|
||||
|
||||
logger.info(f"LoRA file deleted: {event.src_path}")
|
||||
logger.info(f"File deleted: {event.src_path}")
|
||||
self._schedule_update('remove', event.src_path)
|
||||
|
||||
def on_moved(self, event):
|
||||
"""Handle file move/rename events"""
|
||||
|
||||
# If destination is a safetensors file, treat it as a new file
|
||||
if event.dest_path.endswith('.safetensors'):
|
||||
src_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
dest_ext = os.path.splitext(event.dest_path)[1].lower()
|
||||
|
||||
# If destination has supported extension, treat as new file
|
||||
if dest_ext in self.file_extensions:
|
||||
if self._should_ignore(event.dest_path):
|
||||
return
|
||||
|
||||
@@ -162,7 +168,7 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
|
||||
# Only process if not already scheduled
|
||||
if normalized_path not in self.scheduled_files:
|
||||
logger.info(f"LoRA file renamed/moved to: {event.dest_path}")
|
||||
logger.info(f"File renamed/moved to: {event.dest_path}")
|
||||
self.scheduled_files.add(normalized_path)
|
||||
self._schedule_update('add', event.dest_path)
|
||||
|
||||
@@ -173,21 +179,21 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
normalized_path
|
||||
)
|
||||
|
||||
# If source was a safetensors file, treat it as deleted
|
||||
if event.src_path.endswith('.safetensors'):
|
||||
# If source was a supported file, treat it as deleted
|
||||
if src_ext in self.file_extensions:
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
self.scheduled_files.discard(normalized_path)
|
||||
|
||||
logger.info(f"LoRA file moved/renamed from: {event.src_path}")
|
||||
logger.info(f"File moved/renamed from: {event.src_path}")
|
||||
self._schedule_update('remove', event.src_path)
|
||||
|
||||
def _schedule_update(self, action: str, file_path: str): #file_path is a real path
|
||||
def _schedule_update(self, action: str, file_path: str):
|
||||
"""Schedule a cache update"""
|
||||
with self.lock:
|
||||
# 使用 config 中的方法映射路径
|
||||
# Use config method to map path
|
||||
mapped_path = config.map_path_to_link(file_path)
|
||||
normalized_path = mapped_path.replace(os.sep, '/')
|
||||
self.pending_changes.add((action, normalized_path))
|
||||
@@ -198,7 +204,20 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
"""Create update task in the event loop"""
|
||||
if self.update_task is None or self.update_task.done():
|
||||
self.update_task = asyncio.create_task(self._process_changes())
|
||||
|
||||
async def _process_changes(self, delay: float = 2.0):
|
||||
"""Process pending changes with debouncing - should be implemented by subclasses"""
|
||||
raise NotImplementedError("Subclasses must implement _process_changes")
|
||||
|
||||
|
||||
class LoraFileHandler(BaseFileHandler):
|
||||
"""Handler for LoRA file system events"""
|
||||
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
super().__init__(loop)
|
||||
# Set supported file extensions for LoRAs
|
||||
self.file_extensions = {'.safetensors'}
|
||||
|
||||
async def _process_changes(self, delay: float = 2.0):
|
||||
"""Process pending changes with debouncing"""
|
||||
await asyncio.sleep(delay)
|
||||
@@ -211,9 +230,11 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
if not changes:
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(changes)} file changes")
|
||||
logger.info(f"Processing {len(changes)} LoRA file changes")
|
||||
|
||||
cache = await self.scanner.get_cached_data()
|
||||
# Get scanner through ServiceRegistry
|
||||
scanner = await ServiceRegistry.get_lora_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
needs_resort = False
|
||||
new_folders = set()
|
||||
|
||||
@@ -227,36 +248,36 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
continue
|
||||
|
||||
# Scan new file
|
||||
lora_data = await self.scanner.scan_single_lora(file_path)
|
||||
if lora_data:
|
||||
model_data = await scanner.scan_single_model(file_path)
|
||||
if model_data:
|
||||
# Update tags count
|
||||
for tag in lora_data.get('tags', []):
|
||||
self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1
|
||||
for tag in model_data.get('tags', []):
|
||||
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
|
||||
|
||||
cache.raw_data.append(lora_data)
|
||||
new_folders.add(lora_data['folder'])
|
||||
cache.raw_data.append(model_data)
|
||||
new_folders.add(model_data['folder'])
|
||||
# Update hash index
|
||||
if 'sha256' in lora_data:
|
||||
self.scanner._hash_index.add_entry(
|
||||
lora_data['sha256'],
|
||||
lora_data['file_path']
|
||||
if 'sha256' in model_data:
|
||||
scanner._hash_index.add_entry(
|
||||
model_data['sha256'],
|
||||
model_data['file_path']
|
||||
)
|
||||
needs_resort = True
|
||||
|
||||
elif action == 'remove':
|
||||
# Find the lora to remove so we can update tags count
|
||||
lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if lora_to_remove:
|
||||
# Find the model to remove so we can update tags count
|
||||
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if model_to_remove:
|
||||
# Update tags count by reducing counts
|
||||
for tag in lora_to_remove.get('tags', []):
|
||||
if tag in self.scanner._tags_count:
|
||||
self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1)
|
||||
if self.scanner._tags_count[tag] == 0:
|
||||
del self.scanner._tags_count[tag]
|
||||
for tag in model_to_remove.get('tags', []):
|
||||
if tag in scanner._tags_count:
|
||||
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
|
||||
if scanner._tags_count[tag] == 0:
|
||||
del scanner._tags_count[tag]
|
||||
|
||||
# Remove from cache and hash index
|
||||
logger.info(f"Removing {file_path} from cache")
|
||||
self.scanner._hash_index.remove_by_path(file_path)
|
||||
scanner._hash_index.remove_by_path(file_path)
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data
|
||||
if item['file_path'] != file_path
|
||||
@@ -274,59 +295,140 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_changes: {e}")
|
||||
logger.error(f"Error in process_changes for LoRA: {e}")
|
||||
|
||||
|
||||
class LoraFileMonitor:
|
||||
"""Monitor for LoRA file changes"""
|
||||
class CheckpointFileHandler(BaseFileHandler):
|
||||
"""Handler for checkpoint file system events"""
|
||||
|
||||
def __init__(self, scanner: LoraScanner, roots: List[str]):
|
||||
self.scanner = scanner
|
||||
scanner.set_file_monitor(self)
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
super().__init__(loop)
|
||||
# Set supported file extensions for checkpoints
|
||||
self.file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||
|
||||
async def _process_changes(self, delay: float = 2.0):
|
||||
"""Process pending changes with debouncing for checkpoint files"""
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
with self.lock:
|
||||
changes = self.pending_changes.copy()
|
||||
self.pending_changes.clear()
|
||||
|
||||
if not changes:
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(changes)} checkpoint file changes")
|
||||
|
||||
# Get scanner through ServiceRegistry
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
needs_resort = False
|
||||
new_folders = set()
|
||||
|
||||
for action, file_path in changes:
|
||||
try:
|
||||
if action == 'add':
|
||||
# Check if file already exists in cache
|
||||
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if existing:
|
||||
logger.info(f"File {file_path} already in cache, skipping")
|
||||
continue
|
||||
|
||||
# Scan new file
|
||||
model_data = await scanner.scan_single_model(file_path)
|
||||
if model_data:
|
||||
# Update tags count if applicable
|
||||
for tag in model_data.get('tags', []):
|
||||
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
|
||||
|
||||
cache.raw_data.append(model_data)
|
||||
new_folders.add(model_data['folder'])
|
||||
# Update hash index
|
||||
if 'sha256' in model_data:
|
||||
scanner._hash_index.add_entry(
|
||||
model_data['sha256'],
|
||||
model_data['file_path']
|
||||
)
|
||||
needs_resort = True
|
||||
|
||||
elif action == 'remove':
|
||||
# Find the model to remove so we can update tags count
|
||||
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if model_to_remove:
|
||||
# Update tags count by reducing counts
|
||||
for tag in model_to_remove.get('tags', []):
|
||||
if tag in scanner._tags_count:
|
||||
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
|
||||
if scanner._tags_count[tag] == 0:
|
||||
del scanner._tags_count[tag]
|
||||
|
||||
# Remove from cache and hash index
|
||||
logger.info(f"Removing {file_path} from checkpoint cache")
|
||||
scanner._hash_index.remove_by_path(file_path)
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data
|
||||
if item['file_path'] != file_path
|
||||
]
|
||||
needs_resort = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing checkpoint {action} for {file_path}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
# Update folder list
|
||||
all_folders = set(cache.folders) | new_folders
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_changes for checkpoint: {e}")
|
||||
|
||||
|
||||
class BaseFileMonitor:
|
||||
"""Base class for file monitoring"""
|
||||
|
||||
def __init__(self, monitor_paths: List[str]):
|
||||
self.observer = Observer()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.handler = LoraFileHandler(scanner, self.loop)
|
||||
|
||||
# 使用已存在的路径映射
|
||||
self.monitor_paths = set()
|
||||
for root in roots:
|
||||
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
|
||||
|
||||
# Process monitor paths
|
||||
for path in monitor_paths:
|
||||
self.monitor_paths.add(os.path.realpath(path).replace(os.sep, '/'))
|
||||
|
||||
# 添加所有已映射的目标路径
|
||||
# Add mapped paths from config
|
||||
for target_path in config._path_mappings.keys():
|
||||
self.monitor_paths.add(target_path)
|
||||
|
||||
|
||||
def start(self):
|
||||
"""Start monitoring"""
|
||||
for path_info in self.monitor_paths:
|
||||
"""Start file monitoring"""
|
||||
for path in self.monitor_paths:
|
||||
try:
|
||||
if isinstance(path_info, tuple):
|
||||
# 对于链接,监控目标路径
|
||||
_, target_path = path_info
|
||||
self.observer.schedule(self.handler, target_path, recursive=True)
|
||||
logger.info(f"Started monitoring target path: {target_path}")
|
||||
else:
|
||||
# 对于普通路径,直接监控
|
||||
self.observer.schedule(self.handler, path_info, recursive=True)
|
||||
logger.info(f"Started monitoring: {path_info}")
|
||||
self.observer.schedule(self.handler, path, recursive=True)
|
||||
logger.info(f"Started monitoring: {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring {path_info}: {e}")
|
||||
logger.error(f"Error monitoring {path}: {e}")
|
||||
|
||||
self.observer.start()
|
||||
|
||||
|
||||
def stop(self):
|
||||
"""Stop monitoring"""
|
||||
"""Stop file monitoring"""
|
||||
self.observer.stop()
|
||||
self.observer.join()
|
||||
|
||||
|
||||
def rescan_links(self):
|
||||
"""重新扫描链接(当添加新的链接时调用)"""
|
||||
"""Rescan links when new ones are added"""
|
||||
# Find new paths not yet being monitored
|
||||
new_paths = set()
|
||||
for path in self.monitor_paths.copy():
|
||||
self._add_link_targets(path)
|
||||
for path in config._path_mappings.keys():
|
||||
real_path = os.path.realpath(path).replace(os.sep, '/')
|
||||
if real_path not in self.monitor_paths:
|
||||
new_paths.add(real_path)
|
||||
self.monitor_paths.add(real_path)
|
||||
|
||||
# 添加新发现的路径到监控
|
||||
new_paths = self.monitor_paths - set(self.observer.watches.keys())
|
||||
# Add new paths to monitoring
|
||||
for path in new_paths:
|
||||
try:
|
||||
self.observer.schedule(self.handler, path, recursive=True)
|
||||
@@ -334,88 +436,86 @@ class LoraFileMonitor:
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding new monitor for {path}: {e}")
|
||||
|
||||
# Add CheckpointFileMonitor class
|
||||
|
||||
class CheckpointFileMonitor(LoraFileMonitor):
|
||||
class LoraFileMonitor(BaseFileMonitor):
|
||||
"""Monitor for LoRA file changes"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls, monitor_paths=None):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, monitor_paths=None):
|
||||
if not hasattr(self, '_initialized'):
|
||||
if monitor_paths is None:
|
||||
from ..config import config
|
||||
monitor_paths = config.loras_roots
|
||||
|
||||
super().__init__(monitor_paths)
|
||||
self.handler = LoraFileHandler(self.loop)
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
from ..config import config
|
||||
cls._instance = cls(config.loras_roots)
|
||||
return cls._instance
|
||||
|
||||
|
||||
class CheckpointFileMonitor(BaseFileMonitor):
|
||||
"""Monitor for checkpoint file changes"""
|
||||
|
||||
def __init__(self, scanner: CheckpointScanner, roots: List[str]):
|
||||
# Reuse most of the LoraFileMonitor functionality, but with a different handler
|
||||
self.scanner = scanner
|
||||
scanner.set_file_monitor(self)
|
||||
self.observer = Observer()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.handler = CheckpointFileHandler(scanner, self.loop)
|
||||
|
||||
# Use existing path mappings
|
||||
self.monitor_paths = set()
|
||||
for root in roots:
|
||||
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
|
||||
|
||||
# Add all mapped target paths
|
||||
for target_path in config._path_mappings.keys():
|
||||
self.monitor_paths.add(target_path)
|
||||
|
||||
class CheckpointFileHandler(LoraFileHandler):
|
||||
"""Handler for checkpoint file system events"""
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, scanner: CheckpointScanner, loop: asyncio.AbstractEventLoop):
|
||||
super().__init__(scanner, loop)
|
||||
# Configure supported file extensions
|
||||
self.supported_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'}
|
||||
|
||||
def on_created(self, event):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
# Handle supported file extensions directly
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext in self.supported_extensions:
|
||||
if self._should_ignore(event.src_path):
|
||||
return
|
||||
|
||||
# Process this file directly
|
||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||
if normalized_path not in self.scheduled_files:
|
||||
logger.info(f"Checkpoint file created: {event.src_path}")
|
||||
self.scheduled_files.add(normalized_path)
|
||||
self._schedule_update('add', event.src_path)
|
||||
|
||||
# Ignore modifications for a short period after creation
|
||||
self.loop.call_later(
|
||||
self.debounce_delay * 2,
|
||||
self.scheduled_files.discard,
|
||||
normalized_path
|
||||
)
|
||||
def __new__(cls, monitor_paths=None):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def on_modified(self, event):
|
||||
if event.is_directory:
|
||||
return
|
||||
|
||||
# Only process supported file types
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext in self.supported_extensions:
|
||||
super().on_modified(event)
|
||||
def __init__(self, monitor_paths=None):
|
||||
if not hasattr(self, '_initialized'):
|
||||
if monitor_paths is None:
|
||||
# Get checkpoint roots from scanner
|
||||
monitor_paths = []
|
||||
# We'll initialize monitor paths later when scanner is available
|
||||
|
||||
def on_deleted(self, event):
|
||||
if event.is_directory:
|
||||
return
|
||||
super().__init__(monitor_paths or [])
|
||||
self.handler = CheckpointFileHandler(self.loop)
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls([])
|
||||
|
||||
# Now get checkpoint roots from scanner
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
monitor_paths = scanner.get_model_roots()
|
||||
|
||||
# Update monitor paths
|
||||
for path in monitor_paths:
|
||||
real_path = os.path.realpath(path).replace(os.sep, '/')
|
||||
cls._instance.monitor_paths.add(real_path)
|
||||
|
||||
return cls._instance
|
||||
|
||||
async def initialize_paths(self):
|
||||
"""Initialize monitor paths from scanner"""
|
||||
if not self.monitor_paths:
|
||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
monitor_paths = scanner.get_model_roots()
|
||||
|
||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
if file_ext not in self.supported_extensions:
|
||||
return
|
||||
|
||||
super().on_deleted(event)
|
||||
|
||||
def on_moved(self, event):
|
||||
"""Handle file move/rename events"""
|
||||
src_ext = os.path.splitext(event.src_path)[1].lower()
|
||||
dest_ext = os.path.splitext(event.dest_path)[1].lower()
|
||||
|
||||
# If destination has supported extension, treat as new file
|
||||
if dest_ext in self.supported_extensions:
|
||||
super().on_moved(event)
|
||||
|
||||
# If source was supported extension, treat as deleted
|
||||
elif src_ext in self.supported_extensions:
|
||||
super().on_moved(event)
|
||||
# Update monitor paths
|
||||
for path in monitor_paths:
|
||||
real_path = os.path.realpath(path).replace(os.sep, '/')
|
||||
self.monitor_paths.add(real_path)
|
||||
Reference in New Issue
Block a user