mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
Refactor file monitoring and model scanning; remove unused monitors and streamline model file deletion process.
This commit is contained in:
@@ -131,19 +131,6 @@ class LoraManager:
|
|||||||
|
|
||||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||||
|
|
||||||
# Get file monitors through ServiceRegistry
|
|
||||||
lora_monitor = await ServiceRegistry.get_lora_monitor()
|
|
||||||
checkpoint_monitor = await ServiceRegistry.get_checkpoint_monitor()
|
|
||||||
|
|
||||||
# Start monitors
|
|
||||||
lora_monitor.start()
|
|
||||||
logger.debug("Lora monitor started")
|
|
||||||
|
|
||||||
# Make sure checkpoint monitor has paths before starting
|
|
||||||
await checkpoint_monitor.initialize_paths()
|
|
||||||
checkpoint_monitor.start()
|
|
||||||
logger.debug("Checkpoint monitor started")
|
|
||||||
|
|
||||||
# Register DownloadManager with ServiceRegistry
|
# Register DownloadManager with ServiceRegistry
|
||||||
download_manager = await ServiceRegistry.get_download_manager()
|
download_manager = await ServiceRegistry.get_download_manager()
|
||||||
|
|||||||
@@ -1,542 +0,0 @@
|
|||||||
import os
|
|
||||||
import logging
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from watchdog.observers import Observer
|
|
||||||
from watchdog.events import FileSystemEventHandler
|
|
||||||
from typing import List, Dict, Set, Optional
|
|
||||||
from threading import Lock
|
|
||||||
|
|
||||||
from ..config import config
|
|
||||||
from .service_registry import ServiceRegistry
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Configuration constant to control file monitoring functionality
|
|
||||||
ENABLE_FILE_MONITORING = False
|
|
||||||
|
|
||||||
class BaseFileHandler(FileSystemEventHandler):
|
|
||||||
"""Base handler for file system events"""
|
|
||||||
|
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
|
||||||
self.loop = loop # Store event loop reference
|
|
||||||
self.pending_changes = set() # Pending changes
|
|
||||||
self.lock = Lock() # Thread-safe lock
|
|
||||||
self.update_task = None # Async update task
|
|
||||||
self._ignore_paths = set() # Paths to ignore
|
|
||||||
self._min_ignore_timeout = 5 # Minimum timeout in seconds
|
|
||||||
self._download_speed = 1024 * 1024 # Assume 1MB/s as base speed
|
|
||||||
|
|
||||||
# Track modified files with timestamps for debouncing
|
|
||||||
self.modified_files: Dict[str, float] = {}
|
|
||||||
self.debounce_timer = None
|
|
||||||
self.debounce_delay = 3.0 # Seconds to wait after last modification
|
|
||||||
|
|
||||||
# Track files already scheduled for processing
|
|
||||||
self.scheduled_files: Set[str] = set()
|
|
||||||
|
|
||||||
# File extensions to monitor - should be overridden by subclasses
|
|
||||||
self.file_extensions = set()
|
|
||||||
|
|
||||||
def _should_ignore(self, path: str) -> bool:
|
|
||||||
"""Check if path should be ignored"""
|
|
||||||
real_path = os.path.realpath(path) # Resolve any symbolic links
|
|
||||||
return real_path.replace(os.sep, '/') in self._ignore_paths
|
|
||||||
|
|
||||||
def add_ignore_path(self, path: str, file_size: int = 0):
|
|
||||||
"""Add path to ignore list with dynamic timeout based on file size"""
|
|
||||||
real_path = os.path.realpath(path) # Resolve any symbolic links
|
|
||||||
self._ignore_paths.add(real_path.replace(os.sep, '/'))
|
|
||||||
|
|
||||||
# Short timeout (e.g. 5 seconds) is sufficient to ignore the CREATE event
|
|
||||||
timeout = 5
|
|
||||||
|
|
||||||
self.loop.call_later(
|
|
||||||
timeout,
|
|
||||||
self._ignore_paths.discard,
|
|
||||||
real_path.replace(os.sep, '/')
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_created(self, event):
|
|
||||||
if event.is_directory:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Handle appropriate files based on extensions
|
|
||||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
|
||||||
if file_ext in self.file_extensions:
|
|
||||||
if self._should_ignore(event.src_path):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Process this file directly and ignore subsequent modifications
|
|
||||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
|
||||||
if normalized_path not in self.scheduled_files:
|
|
||||||
logger.info(f"File created: {event.src_path}")
|
|
||||||
self.scheduled_files.add(normalized_path)
|
|
||||||
self._schedule_update('add', event.src_path)
|
|
||||||
|
|
||||||
# Ignore modifications for a short period after creation
|
|
||||||
self.loop.call_later(
|
|
||||||
self.debounce_delay * 2,
|
|
||||||
self.scheduled_files.discard,
|
|
||||||
normalized_path
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_modified(self, event):
|
|
||||||
if event.is_directory:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Only process files with supported extensions
|
|
||||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
|
||||||
if file_ext in self.file_extensions:
|
|
||||||
if self._should_ignore(event.src_path):
|
|
||||||
return
|
|
||||||
|
|
||||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
|
||||||
|
|
||||||
# Skip if this file is already scheduled for processing
|
|
||||||
if normalized_path in self.scheduled_files:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Update the timestamp for this file
|
|
||||||
self.modified_files[normalized_path] = time.time()
|
|
||||||
|
|
||||||
# Cancel any existing timer
|
|
||||||
if self.debounce_timer:
|
|
||||||
self.debounce_timer.cancel()
|
|
||||||
|
|
||||||
# Set a new timer to process modified files after debounce period
|
|
||||||
self.debounce_timer = self.loop.call_later(
|
|
||||||
self.debounce_delay,
|
|
||||||
self.loop.call_soon_threadsafe,
|
|
||||||
self._process_modified_files
|
|
||||||
)
|
|
||||||
|
|
||||||
def _process_modified_files(self):
|
|
||||||
"""Process files that have been modified after debounce period"""
|
|
||||||
current_time = time.time()
|
|
||||||
files_to_process = []
|
|
||||||
|
|
||||||
# Find files that haven't been modified for debounce_delay seconds
|
|
||||||
for file_path, last_modified in list(self.modified_files.items()):
|
|
||||||
if current_time - last_modified >= self.debounce_delay:
|
|
||||||
# Only process if not already scheduled
|
|
||||||
if file_path not in self.scheduled_files:
|
|
||||||
files_to_process.append(file_path)
|
|
||||||
self.scheduled_files.add(file_path)
|
|
||||||
|
|
||||||
# Auto-remove from scheduled list after reasonable time
|
|
||||||
self.loop.call_later(
|
|
||||||
self.debounce_delay * 2,
|
|
||||||
self.scheduled_files.discard,
|
|
||||||
file_path
|
|
||||||
)
|
|
||||||
|
|
||||||
del self.modified_files[file_path]
|
|
||||||
|
|
||||||
# Process stable files
|
|
||||||
for file_path in files_to_process:
|
|
||||||
logger.info(f"Processing modified file: {file_path}")
|
|
||||||
self._schedule_update('add', file_path)
|
|
||||||
|
|
||||||
def on_deleted(self, event):
|
|
||||||
if event.is_directory:
|
|
||||||
return
|
|
||||||
|
|
||||||
file_ext = os.path.splitext(event.src_path)[1].lower()
|
|
||||||
if file_ext not in self.file_extensions:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._should_ignore(event.src_path):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Remove from scheduled files if present
|
|
||||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
|
||||||
self.scheduled_files.discard(normalized_path)
|
|
||||||
|
|
||||||
logger.info(f"File deleted: {event.src_path}")
|
|
||||||
self._schedule_update('remove', event.src_path)
|
|
||||||
|
|
||||||
def on_moved(self, event):
|
|
||||||
"""Handle file move/rename events"""
|
|
||||||
|
|
||||||
src_ext = os.path.splitext(event.src_path)[1].lower()
|
|
||||||
dest_ext = os.path.splitext(event.dest_path)[1].lower()
|
|
||||||
|
|
||||||
# If destination has supported extension, treat as new file
|
|
||||||
if dest_ext in self.file_extensions:
|
|
||||||
if self._should_ignore(event.dest_path):
|
|
||||||
return
|
|
||||||
|
|
||||||
normalized_path = os.path.realpath(event.dest_path).replace(os.sep, '/')
|
|
||||||
|
|
||||||
# Only process if not already scheduled
|
|
||||||
if normalized_path not in self.scheduled_files:
|
|
||||||
logger.info(f"File renamed/moved to: {event.dest_path}")
|
|
||||||
self.scheduled_files.add(normalized_path)
|
|
||||||
self._schedule_update('add', event.dest_path)
|
|
||||||
|
|
||||||
# Auto-remove from scheduled list after reasonable time
|
|
||||||
self.loop.call_later(
|
|
||||||
self.debounce_delay * 2,
|
|
||||||
self.scheduled_files.discard,
|
|
||||||
normalized_path
|
|
||||||
)
|
|
||||||
|
|
||||||
# If source was a supported file, treat it as deleted
|
|
||||||
if src_ext in self.file_extensions:
|
|
||||||
if self._should_ignore(event.src_path):
|
|
||||||
return
|
|
||||||
|
|
||||||
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
|
||||||
self.scheduled_files.discard(normalized_path)
|
|
||||||
|
|
||||||
logger.info(f"File moved/renamed from: {event.src_path}")
|
|
||||||
self._schedule_update('remove', event.src_path)
|
|
||||||
|
|
||||||
def _schedule_update(self, action: str, file_path: str):
|
|
||||||
"""Schedule a cache update"""
|
|
||||||
with self.lock:
|
|
||||||
# Use config method to map path
|
|
||||||
mapped_path = config.map_path_to_link(file_path)
|
|
||||||
normalized_path = mapped_path.replace(os.sep, '/')
|
|
||||||
self.pending_changes.add((action, normalized_path))
|
|
||||||
|
|
||||||
self.loop.call_soon_threadsafe(self._create_update_task)
|
|
||||||
|
|
||||||
def _create_update_task(self):
|
|
||||||
"""Create update task in the event loop"""
|
|
||||||
if self.update_task is None or self.update_task.done():
|
|
||||||
self.update_task = asyncio.create_task(self._process_changes())
|
|
||||||
|
|
||||||
async def _process_changes(self, delay: float = 2.0):
|
|
||||||
"""Process pending changes with debouncing - should be implemented by subclasses"""
|
|
||||||
raise NotImplementedError("Subclasses must implement _process_changes")
|
|
||||||
|
|
||||||
|
|
||||||
class LoraFileHandler(BaseFileHandler):
|
|
||||||
"""Handler for LoRA file system events"""
|
|
||||||
|
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
|
||||||
super().__init__(loop)
|
|
||||||
# Set supported file extensions for LoRAs
|
|
||||||
self.file_extensions = {'.safetensors'}
|
|
||||||
|
|
||||||
async def _process_changes(self, delay: float = 2.0):
|
|
||||||
"""Process pending changes with debouncing"""
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.lock:
|
|
||||||
changes = self.pending_changes.copy()
|
|
||||||
self.pending_changes.clear()
|
|
||||||
|
|
||||||
if not changes:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"Processing {len(changes)} LoRA file changes")
|
|
||||||
|
|
||||||
# Get scanner through ServiceRegistry
|
|
||||||
scanner = await ServiceRegistry.get_lora_scanner()
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
needs_resort = False
|
|
||||||
new_folders = set()
|
|
||||||
|
|
||||||
for action, file_path in changes:
|
|
||||||
try:
|
|
||||||
if action == 'add':
|
|
||||||
# Check if file already exists in cache
|
|
||||||
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
|
||||||
if existing:
|
|
||||||
logger.info(f"File {file_path} already in cache, skipping")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Scan new file
|
|
||||||
model_data = await scanner.scan_single_model(file_path)
|
|
||||||
if model_data:
|
|
||||||
# Update tags count
|
|
||||||
for tag in model_data.get('tags', []):
|
|
||||||
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
|
|
||||||
|
|
||||||
cache.raw_data.append(model_data)
|
|
||||||
new_folders.add(model_data['folder'])
|
|
||||||
# Update hash index
|
|
||||||
if 'sha256' in model_data:
|
|
||||||
scanner._hash_index.add_entry(
|
|
||||||
model_data['sha256'],
|
|
||||||
model_data['file_path']
|
|
||||||
)
|
|
||||||
needs_resort = True
|
|
||||||
|
|
||||||
elif action == 'remove':
|
|
||||||
# Find the model to remove so we can update tags count
|
|
||||||
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
|
||||||
if model_to_remove:
|
|
||||||
# Update tags count by reducing counts
|
|
||||||
for tag in model_to_remove.get('tags', []):
|
|
||||||
if tag in scanner._tags_count:
|
|
||||||
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
|
|
||||||
if scanner._tags_count[tag] == 0:
|
|
||||||
del scanner._tags_count[tag]
|
|
||||||
|
|
||||||
# Remove from cache and hash index
|
|
||||||
logger.info(f"Removing {file_path} from cache")
|
|
||||||
scanner._hash_index.remove_by_path(file_path)
|
|
||||||
cache.raw_data = [
|
|
||||||
item for item in cache.raw_data
|
|
||||||
if item['file_path'] != file_path
|
|
||||||
]
|
|
||||||
needs_resort = True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing {action} for {file_path}: {e}")
|
|
||||||
|
|
||||||
if needs_resort:
|
|
||||||
await cache.resort()
|
|
||||||
|
|
||||||
# Update folder list
|
|
||||||
all_folders = set(cache.folders) | new_folders
|
|
||||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in process_changes for LoRA: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class CheckpointFileHandler(BaseFileHandler):
|
|
||||||
"""Handler for checkpoint file system events"""
|
|
||||||
|
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
|
||||||
super().__init__(loop)
|
|
||||||
# Set supported file extensions for checkpoints
|
|
||||||
self.file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
|
||||||
|
|
||||||
async def _process_changes(self, delay: float = 2.0):
|
|
||||||
"""Process pending changes with debouncing for checkpoint files"""
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.lock:
|
|
||||||
changes = self.pending_changes.copy()
|
|
||||||
self.pending_changes.clear()
|
|
||||||
|
|
||||||
if not changes:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"Processing {len(changes)} checkpoint file changes")
|
|
||||||
|
|
||||||
# Get scanner through ServiceRegistry
|
|
||||||
scanner = await ServiceRegistry.get_checkpoint_scanner()
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
needs_resort = False
|
|
||||||
new_folders = set()
|
|
||||||
|
|
||||||
for action, file_path in changes:
|
|
||||||
try:
|
|
||||||
if action == 'add':
|
|
||||||
# Check if file already exists in cache
|
|
||||||
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
|
||||||
if existing:
|
|
||||||
logger.info(f"File {file_path} already in cache, skipping")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Scan new file
|
|
||||||
model_data = await scanner.scan_single_model(file_path)
|
|
||||||
if model_data:
|
|
||||||
# Update tags count if applicable
|
|
||||||
for tag in model_data.get('tags', []):
|
|
||||||
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
|
|
||||||
|
|
||||||
cache.raw_data.append(model_data)
|
|
||||||
new_folders.add(model_data['folder'])
|
|
||||||
# Update hash index
|
|
||||||
if 'sha256' in model_data:
|
|
||||||
scanner._hash_index.add_entry(
|
|
||||||
model_data['sha256'],
|
|
||||||
model_data['file_path']
|
|
||||||
)
|
|
||||||
needs_resort = True
|
|
||||||
|
|
||||||
elif action == 'remove':
|
|
||||||
# Find the model to remove so we can update tags count
|
|
||||||
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
|
||||||
if model_to_remove:
|
|
||||||
# Update tags count by reducing counts
|
|
||||||
for tag in model_to_remove.get('tags', []):
|
|
||||||
if tag in scanner._tags_count:
|
|
||||||
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
|
|
||||||
if scanner._tags_count[tag] == 0:
|
|
||||||
del scanner._tags_count[tag]
|
|
||||||
|
|
||||||
# Remove from cache and hash index
|
|
||||||
logger.info(f"Removing {file_path} from checkpoint cache")
|
|
||||||
scanner._hash_index.remove_by_path(file_path)
|
|
||||||
cache.raw_data = [
|
|
||||||
item for item in cache.raw_data
|
|
||||||
if item['file_path'] != file_path
|
|
||||||
]
|
|
||||||
needs_resort = True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing checkpoint {action} for {file_path}: {e}")
|
|
||||||
|
|
||||||
if needs_resort:
|
|
||||||
await cache.resort()
|
|
||||||
|
|
||||||
# Update folder list
|
|
||||||
all_folders = set(cache.folders) | new_folders
|
|
||||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in process_changes for checkpoint: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFileMonitor:
|
|
||||||
"""Base class for file monitoring"""
|
|
||||||
|
|
||||||
def __init__(self, monitor_paths: List[str]):
|
|
||||||
self.observer = Observer()
|
|
||||||
self.loop = asyncio.get_event_loop()
|
|
||||||
self.monitor_paths = set()
|
|
||||||
|
|
||||||
# Process monitor paths
|
|
||||||
for path in monitor_paths:
|
|
||||||
self.monitor_paths.add(os.path.realpath(path).replace(os.sep, '/'))
|
|
||||||
|
|
||||||
# Add mapped paths from config
|
|
||||||
for target_path in config._path_mappings.keys():
|
|
||||||
self.monitor_paths.add(target_path)
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
"""Start file monitoring"""
|
|
||||||
if not ENABLE_FILE_MONITORING:
|
|
||||||
logger.debug("File monitoring is disabled via ENABLE_FILE_MONITORING setting")
|
|
||||||
return
|
|
||||||
|
|
||||||
for path in self.monitor_paths:
|
|
||||||
try:
|
|
||||||
self.observer.schedule(self.handler, path, recursive=True)
|
|
||||||
logger.info(f"Started monitoring: {path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error monitoring {path}: {e}")
|
|
||||||
|
|
||||||
self.observer.start()
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
"""Stop file monitoring"""
|
|
||||||
if not ENABLE_FILE_MONITORING:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.observer.stop()
|
|
||||||
self.observer.join()
|
|
||||||
|
|
||||||
def rescan_links(self):
|
|
||||||
"""Rescan links when new ones are added"""
|
|
||||||
if not ENABLE_FILE_MONITORING:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Find new paths not yet being monitored
|
|
||||||
new_paths = set()
|
|
||||||
for path in config._path_mappings.keys():
|
|
||||||
real_path = os.path.realpath(path).replace(os.sep, '/')
|
|
||||||
if real_path not in self.monitor_paths:
|
|
||||||
new_paths.add(real_path)
|
|
||||||
self.monitor_paths.add(real_path)
|
|
||||||
|
|
||||||
# Add new paths to monitoring
|
|
||||||
for path in new_paths:
|
|
||||||
try:
|
|
||||||
self.observer.schedule(self.handler, path, recursive=True)
|
|
||||||
logger.info(f"Added new monitoring path: {path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error adding new monitor for {path}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class LoraFileMonitor(BaseFileMonitor):
|
|
||||||
"""Monitor for LoRA file changes"""
|
|
||||||
|
|
||||||
_instance = None
|
|
||||||
_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
def __new__(cls, monitor_paths=None):
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def __init__(self, monitor_paths=None):
|
|
||||||
if not hasattr(self, '_initialized'):
|
|
||||||
if monitor_paths is None:
|
|
||||||
from ..config import config
|
|
||||||
monitor_paths = config.loras_roots
|
|
||||||
|
|
||||||
super().__init__(monitor_paths)
|
|
||||||
self.handler = LoraFileHandler(self.loop)
|
|
||||||
self._initialized = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_instance(cls):
|
|
||||||
"""Get singleton instance with async support"""
|
|
||||||
async with cls._lock:
|
|
||||||
if cls._instance is None:
|
|
||||||
from ..config import config
|
|
||||||
cls._instance = cls(config.loras_roots)
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
|
|
||||||
class CheckpointFileMonitor(BaseFileMonitor):
|
|
||||||
"""Monitor for checkpoint file changes"""
|
|
||||||
|
|
||||||
_instance = None
|
|
||||||
_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
def __new__(cls, monitor_paths=None):
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def __init__(self, monitor_paths=None):
|
|
||||||
if not hasattr(self, '_initialized'):
|
|
||||||
if monitor_paths is None:
|
|
||||||
# Get checkpoint roots from scanner
|
|
||||||
monitor_paths = []
|
|
||||||
# We'll initialize monitor paths later when scanner is available
|
|
||||||
|
|
||||||
super().__init__(monitor_paths or [])
|
|
||||||
self.handler = CheckpointFileHandler(self.loop)
|
|
||||||
self._initialized = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_instance(cls):
|
|
||||||
"""Get singleton instance with async support"""
|
|
||||||
async with cls._lock:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = cls([])
|
|
||||||
|
|
||||||
# Now get checkpoint roots from scanner
|
|
||||||
from .checkpoint_scanner import CheckpointScanner
|
|
||||||
scanner = await CheckpointScanner.get_instance()
|
|
||||||
monitor_paths = scanner.get_model_roots()
|
|
||||||
|
|
||||||
# Update monitor paths - but don't actually monitor them
|
|
||||||
for path in monitor_paths:
|
|
||||||
real_path = os.path.realpath(path).replace(os.sep, '/')
|
|
||||||
cls._instance.monitor_paths.add(real_path)
|
|
||||||
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
"""Override start to check global enable flag"""
|
|
||||||
if not ENABLE_FILE_MONITORING:
|
|
||||||
logger.debug("Checkpoint file monitoring is disabled via ENABLE_FILE_MONITORING setting")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug("Checkpoint file monitoring is temporarily disabled")
|
|
||||||
# Skip the actual monitoring setup
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def initialize_paths(self):
|
|
||||||
"""Initialize monitor paths from scanner - currently disabled"""
|
|
||||||
if not ENABLE_FILE_MONITORING:
|
|
||||||
logger.debug("Checkpoint path initialization skipped (monitoring disabled)")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug("Checkpoint file path initialization skipped (monitoring disabled)")
|
|
||||||
pass
|
|
||||||
@@ -751,29 +751,6 @@ class ModelScanner:
|
|||||||
def get_model_roots(self) -> List[str]:
|
def get_model_roots(self) -> List[str]:
|
||||||
"""Get model root directories"""
|
"""Get model root directories"""
|
||||||
raise NotImplementedError("Subclasses must implement get_model_roots")
|
raise NotImplementedError("Subclasses must implement get_model_roots")
|
||||||
|
|
||||||
async def scan_single_model(self, file_path: str) -> Optional[Dict]:
|
|
||||||
"""Scan a single model file and return its metadata"""
|
|
||||||
try:
|
|
||||||
if not os.path.exists(os.path.realpath(file_path)):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Get basic file info
|
|
||||||
metadata = await self._get_file_info(file_path)
|
|
||||||
if not metadata:
|
|
||||||
return None
|
|
||||||
|
|
||||||
folder = self._calculate_folder(file_path)
|
|
||||||
|
|
||||||
# Ensure folder field exists
|
|
||||||
metadata_dict = metadata.to_dict()
|
|
||||||
metadata_dict['folder'] = folder or ''
|
|
||||||
|
|
||||||
return metadata_dict
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning {file_path}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
|
async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
|
||||||
"""Get model file info and metadata (extensible for different model types)"""
|
"""Get model file info and metadata (extensible for different model types)"""
|
||||||
@@ -1263,9 +1240,6 @@ class ModelScanner:
|
|||||||
'results': []
|
'results': []
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get the file monitor
|
|
||||||
file_monitor = getattr(self, 'file_monitor', None)
|
|
||||||
|
|
||||||
# Keep track of success and failures
|
# Keep track of success and failures
|
||||||
results = []
|
results = []
|
||||||
total_deleted = 0
|
total_deleted = 0
|
||||||
@@ -1286,8 +1260,7 @@ class ModelScanner:
|
|||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
deleted_files = await ModelRouteUtils.delete_model_files(
|
deleted_files = await ModelRouteUtils.delete_model_files(
|
||||||
target_dir,
|
target_dir,
|
||||||
file_name,
|
file_name
|
||||||
file_monitor
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if deleted_files:
|
if deleted_files:
|
||||||
|
|||||||
@@ -58,26 +58,6 @@ class ServiceRegistry:
|
|||||||
scanner = await CheckpointScanner.get_instance()
|
scanner = await CheckpointScanner.get_instance()
|
||||||
await cls.register_service("checkpoint_scanner", scanner)
|
await cls.register_service("checkpoint_scanner", scanner)
|
||||||
return scanner
|
return scanner
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_lora_monitor(cls):
|
|
||||||
"""Get the LoraFileMonitor instance"""
|
|
||||||
from .file_monitor import LoraFileMonitor
|
|
||||||
monitor = await cls.get_service("lora_monitor")
|
|
||||||
if monitor is None:
|
|
||||||
monitor = await LoraFileMonitor.get_instance()
|
|
||||||
await cls.register_service("lora_monitor", monitor)
|
|
||||||
return monitor
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def get_checkpoint_monitor(cls):
|
|
||||||
"""Get the CheckpointFileMonitor instance"""
|
|
||||||
from .file_monitor import CheckpointFileMonitor
|
|
||||||
monitor = await cls.get_service("checkpoint_monitor")
|
|
||||||
if monitor is None:
|
|
||||||
monitor = await CheckpointFileMonitor.get_instance()
|
|
||||||
await cls.register_service("checkpoint_monitor", monitor)
|
|
||||||
return monitor
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_civitai_client(cls):
|
async def get_civitai_client(cls):
|
||||||
@@ -95,7 +75,6 @@ class ServiceRegistry:
|
|||||||
from .download_manager import DownloadManager
|
from .download_manager import DownloadManager
|
||||||
manager = await cls.get_service("download_manager")
|
manager = await cls.get_service("download_manager")
|
||||||
if manager is None:
|
if manager is None:
|
||||||
# We'll let DownloadManager.get_instance handle file_monitor parameter
|
|
||||||
manager = await DownloadManager.get_instance()
|
manager = await DownloadManager.get_instance()
|
||||||
await cls.register_service("download_manager", manager)
|
await cls.register_service("download_manager", manager)
|
||||||
return manager
|
return manager
|
||||||
|
|||||||
@@ -209,13 +209,12 @@ class ModelRouteUtils:
|
|||||||
return {k: data[k] for k in fields if k in data}
|
return {k: data[k] for k in fields if k in data}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def delete_model_files(target_dir: str, file_name: str, file_monitor=None) -> List[str]:
|
async def delete_model_files(target_dir: str, file_name: str) -> List[str]:
|
||||||
"""Delete model and associated files
|
"""Delete model and associated files
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_dir: Directory containing the model files
|
target_dir: Directory containing the model files
|
||||||
file_name: Base name of the model file without extension
|
file_name: Base name of the model file without extension
|
||||||
file_monitor: Optional file monitor to ignore delete events
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of deleted file paths
|
List of deleted file paths
|
||||||
@@ -233,11 +232,7 @@ class ModelRouteUtils:
|
|||||||
main_file = patterns[0]
|
main_file = patterns[0]
|
||||||
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
|
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
|
||||||
|
|
||||||
if os.path.exists(main_path):
|
if os.path.exists(main_path):
|
||||||
# Notify file monitor to ignore delete event if available
|
|
||||||
if file_monitor:
|
|
||||||
file_monitor.handler.add_ignore_path(main_path, 0)
|
|
||||||
|
|
||||||
# Delete file
|
# Delete file
|
||||||
os.remove(main_path)
|
os.remove(main_path)
|
||||||
deleted.append(main_path)
|
deleted.append(main_path)
|
||||||
@@ -286,13 +281,9 @@ class ModelRouteUtils:
|
|||||||
target_dir = os.path.dirname(file_path)
|
target_dir = os.path.dirname(file_path)
|
||||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
||||||
# Get the file monitor from the scanner if available
|
|
||||||
file_monitor = getattr(scanner, 'file_monitor', None)
|
|
||||||
|
|
||||||
deleted_files = await ModelRouteUtils.delete_model_files(
|
deleted_files = await ModelRouteUtils.delete_model_files(
|
||||||
target_dir,
|
target_dir,
|
||||||
file_name,
|
file_name
|
||||||
file_monitor
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove from cache
|
# Remove from cache
|
||||||
|
|||||||
Reference in New Issue
Block a user