checkpoint

This commit is contained in:
Will Miao
2025-04-10 09:08:36 +08:00
parent 64c9e4aeca
commit 8fdfb68741
12 changed files with 1397 additions and 484 deletions

View File

@@ -7,6 +7,8 @@ from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from typing import List, Dict, Set
from threading import Lock
from .checkpoint_scanner import CheckpointScanner
from .lora_scanner import LoraScanner
from ..config import config
@@ -330,4 +332,90 @@ class LoraFileMonitor:
self.observer.schedule(self.handler, path, recursive=True)
logger.info(f"Added new monitoring path: {path}")
except Exception as e:
logger.error(f"Error adding new monitor for {path}: {e}")
logger.error(f"Error adding new monitor for {path}: {e}")
# Add CheckpointFileMonitor class
class CheckpointFileMonitor(LoraFileMonitor):
    """File-system monitor for checkpoint model directories.

    Inherits the monitoring behaviour of ``LoraFileMonitor`` but wires in a
    ``CheckpointFileHandler`` as the event handler.
    """

    def __init__(self, scanner: CheckpointScanner, roots: List[str]):
        """Set up the observer, the handler and the set of paths to watch.

        The parent initializer is not invoked; the attributes are assigned
        directly here so the checkpoint-specific handler is the one created.

        Args:
            scanner: Checkpoint scanner; this monitor registers itself on it
                via ``scanner.set_file_monitor``.
            roots: Root directories to watch. Each one is resolved with
                ``os.path.realpath`` and stored with forward slashes.
        """
        self.scanner = scanner
        scanner.set_file_monitor(self)

        self.observer = Observer()
        self.loop = asyncio.get_event_loop()
        self.handler = CheckpointFileHandler(scanner, self.loop)

        # Normalized root directories come first ...
        self.monitor_paths = set()
        self.monitor_paths.update(
            os.path.realpath(root_dir).replace(os.sep, '/') for root_dir in roots
        )
        # ... followed by every key of the configured path mappings.
        self.monitor_paths.update(config._path_mappings.keys())
class CheckpointFileHandler(LoraFileHandler):
    """Handler for checkpoint file system events.

    Only events whose path carries one of ``supported_extensions`` are acted
    on; everything else is dropped. Apart from ``on_created``, the actual
    processing is delegated to ``LoraFileHandler``.
    """

    def __init__(self, scanner: CheckpointScanner, loop: asyncio.AbstractEventLoop):
        """Initialise the base handler and register checkpoint extensions.

        Args:
            scanner: Checkpoint scanner passed through to the base handler.
            loop: Event loop passed through to the base handler; also used
                here (as ``self.loop``) to schedule the post-creation timer.
        """
        super().__init__(scanner, loop)
        # Configure supported file extensions
        self.supported_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'}

    def _has_checkpoint_extension(self, path: str) -> bool:
        """Return True when *path* ends in a supported extension (any case)."""
        return os.path.splitext(path)[1].lower() in self.supported_extensions

    def on_created(self, event):
        """Schedule an 'add' update for a newly created checkpoint file.

        Directories, unsupported extensions, paths rejected by
        ``_should_ignore`` and paths already present in ``scheduled_files``
        are skipped.
        """
        if event.is_directory:
            return

        # Handle supported file extensions directly
        if not self._has_checkpoint_extension(event.src_path):
            return
        if self._should_ignore(event.src_path):
            return

        # Process this file directly
        normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
        if normalized_path in self.scheduled_files:
            return

        logger.info(f"Checkpoint file created: {event.src_path}")
        self.scheduled_files.add(normalized_path)
        self._schedule_update('add', event.src_path)

        # Ignore modifications for a short period after creation.
        # watchdog delivers events on its observer thread, and asyncio's
        # call_later() is not thread-safe, so the timer is armed from inside
        # the loop thread via call_soon_threadsafe().
        self.loop.call_soon_threadsafe(
            self.loop.call_later,
            self.debounce_delay * 2,
            self.scheduled_files.discard,
            normalized_path,
        )

    def on_modified(self, event):
        """Forward modifications of supported files to the base handler."""
        if event.is_directory:
            return

        # Only process supported file types
        if self._has_checkpoint_extension(event.src_path):
            super().on_modified(event)

    def on_deleted(self, event):
        """Forward deletions of supported files to the base handler."""
        if event.is_directory:
            return

        if not self._has_checkpoint_extension(event.src_path):
            return

        super().on_deleted(event)

    def on_moved(self, event):
        """Handle file move/rename events.

        The event is forwarded to the base handler when either the source or
        the destination path has a supported extension; a move between two
        unsupported names is dropped.
        """
        if (self._has_checkpoint_extension(event.dest_path)
                or self._has_checkpoint_extension(event.src_path)):
            super().on_moved(event)