checkpoint

This commit is contained in:
Will Miao
2025-04-11 20:22:12 +08:00
parent 1db49a4dd4
commit 0618541527
13 changed files with 793 additions and 276 deletions

View File

@@ -8,6 +8,7 @@ from ..utils.models import CheckpointMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)

View File

@@ -3,6 +3,7 @@ import aiohttp
import os
import json
import logging
import asyncio
from email.parser import Parser
from typing import Optional, Dict, Tuple, List
from urllib.parse import unquote
@@ -11,7 +12,23 @@ from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
class CivitaiClient:
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls):
"""Get singleton instance of CivitaiClient"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
# Check if already initialized for singleton pattern
if hasattr(self, '_initialized'):
return
self._initialized = True
self.base_url = "https://civitai.com/api/v1"
self.headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'

View File

@@ -1,12 +1,13 @@
import logging
import os
import json
from typing import Optional, Dict
import asyncio
from typing import Optional, Dict, Any
from .civitai_client import CivitaiClient
from .file_monitor import LoraFileMonitor
from ..utils.models import LoraMetadata, CheckpointMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from .service_registry import ServiceRegistry
# Download to temporary file first
import tempfile
@@ -14,9 +15,46 @@ import tempfile
logger = logging.getLogger(__name__)
class DownloadManager:
def __init__(self, file_monitor: Optional[LoraFileMonitor] = None):
self.civitai_client = CivitaiClient()
self.file_monitor = file_monitor
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls):
"""Get singleton instance of DownloadManager"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
# Check if already initialized for singleton pattern
if hasattr(self, '_initialized'):
return
self._initialized = True
self._civitai_client = None # Will be lazily initialized
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def _get_lora_monitor(self):
"""Get the lora file monitor from registry"""
return await ServiceRegistry.get_lora_monitor()
async def _get_checkpoint_monitor(self):
"""Get the checkpoint file monitor from registry"""
return await ServiceRegistry.get_checkpoint_monitor()
async def _get_lora_scanner(self):
"""Get the lora scanner from registry"""
return await ServiceRegistry.get_lora_scanner()
async def _get_checkpoint_scanner(self):
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
model_version_id: str = None, save_dir: str = None,
@@ -43,19 +81,22 @@ class DownloadManager:
# Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Get civitai client
civitai_client = await self._get_civitai_client()
# Get version info based on the provided identifier
version_info = None
if download_url:
# Extract version ID from download URL
version_id = download_url.split('/')[-1]
version_info = await self.civitai_client.get_model_version_info(version_id)
version_info = await civitai_client.get_model_version_info(version_id)
elif model_version_id:
# Use model version ID directly
version_info = await self.civitai_client.get_model_version_info(model_version_id)
version_info = await civitai_client.get_model_version_info(model_version_id)
elif model_hash:
# Get model by hash
version_info = await self.civitai_client.get_model_by_hash(model_hash)
version_info = await civitai_client.get_model_by_hash(model_hash)
if not version_info:
@@ -95,8 +136,9 @@ class DownloadManager:
file_size = file_info.get('sizeKB', 0) * 1024
# 4. Notify file monitor - use normalized path and file size
if self.file_monitor and self.file_monitor.handler:
self.file_monitor.handler.add_ignore_path(
file_monitor = await self._get_lora_monitor() if model_type == "lora" else await self._get_checkpoint_monitor()
if file_monitor and file_monitor.handler:
file_monitor.handler.add_ignore_path(
save_path.replace(os.sep, '/'),
file_size
)
@@ -112,7 +154,7 @@ class DownloadManager:
# 5.1 Get and update model tags and description
model_id = version_info.get('modelId')
if model_id:
model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id))
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
if model_metadata:
if model_metadata.get("tags"):
metadata.tags = model_metadata.get("tags", [])
@@ -146,6 +188,7 @@ class DownloadManager:
model_type: str = "lora") -> Dict:
"""Execute the actual download process including preview images and model files"""
try:
civitai_client = await self._get_civitai_client()
save_path = metadata.file_path
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
@@ -165,7 +208,7 @@ class DownloadManager:
preview_path = os.path.splitext(save_path)[0] + preview_ext
# Download video directly
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
if await civitai_client.download_preview_image(images[0]['url'], preview_path):
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
with open(metadata_path, 'w', encoding='utf-8') as f:
@@ -176,7 +219,7 @@ class DownloadManager:
temp_path = temp_file.name
# Download the original image to temp path
if await self.civitai_client.download_preview_image(images[0]['url'], temp_path):
if await civitai_client.download_preview_image(images[0]['url'], temp_path):
# Optimize and convert to WebP
preview_path = os.path.splitext(save_path)[0] + '.webp'
@@ -210,7 +253,7 @@ class DownloadManager:
await progress_callback(3) # 3% progress after preview download
# Download model file with progress tracking
success, result = await self.civitai_client._download_file(
success, result = await civitai_client._download_file(
download_url,
save_dir,
os.path.basename(save_path),
@@ -232,13 +275,14 @@ class DownloadManager:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# 6. Update cache based on model type
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
cache = await self.file_monitor.checkpoint_scanner.get_cached_data()
if model_type == "checkpoint":
scanner = await self._get_checkpoint_scanner()
logger.info(f"Updating checkpoint cache for {save_path}")
else:
cache = await self.file_monitor.scanner.get_cached_data()
scanner = await self._get_lora_scanner()
logger.info(f"Updating lora cache for {save_path}")
cache = await scanner.get_cached_data()
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = relative_path
cache.raw_data.append(metadata_dict)
@@ -248,10 +292,7 @@ class DownloadManager:
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update the hash index with the new model entry
if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"):
self.file_monitor.checkpoint_scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
else:
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Report 100% completion
if progress_callback:

View File

@@ -1,39 +1,39 @@
from operator import itemgetter
import os
import logging
import asyncio
import time
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from typing import List, Dict, Set
from typing import List, Dict, Set, Optional
from threading import Lock
from .checkpoint_scanner import CheckpointScanner
from .lora_scanner import LoraScanner
from ..config import config
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class LoraFileHandler(FileSystemEventHandler):
"""Handler for LoRA file system events"""
class BaseFileHandler(FileSystemEventHandler):
"""Base handler for file system events"""
def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop):
self.scanner = scanner
self.loop = loop # 存储事件循环引用
self.pending_changes = set() # 待处理的变更
self.lock = Lock() # 线程安全锁
self.update_task = None # 异步更新任务
self._ignore_paths = set() # Add ignore paths set
self._min_ignore_timeout = 5 # minimum timeout in seconds
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
def __init__(self, loop: asyncio.AbstractEventLoop):
self.loop = loop # Store event loop reference
self.pending_changes = set() # Pending changes
self.lock = Lock() # Thread-safe lock
self.update_task = None # Async update task
self._ignore_paths = set() # Paths to ignore
self._min_ignore_timeout = 5 # Minimum timeout in seconds
self._download_speed = 1024 * 1024 # Assume 1MB/s as base speed
# Track modified files with timestamps for debouncing
self.modified_files: Dict[str, float] = {}
self.debounce_timer = None
self.debounce_delay = 3.0 # seconds to wait after last modification
self.debounce_delay = 3.0 # Seconds to wait after last modification
# Track files that are already scheduled for processing
# Track files already scheduled for processing
self.scheduled_files: Set[str] = set()
# File extensions to monitor - should be overridden by subclasses
self.file_extensions = set()
def _should_ignore(self, path: str) -> bool:
"""Check if path should be ignored"""
@@ -58,35 +58,33 @@ class LoraFileHandler(FileSystemEventHandler):
if event.is_directory:
return
# Handle safetensors files directly
if event.src_path.endswith('.safetensors'):
# Handle appropriate files based on extensions
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
# We'll process this file directly and ignore subsequent modifications
# to prevent duplicate processing
# Process this file directly and ignore subsequent modifications
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
if normalized_path not in self.scheduled_files:
logger.info(f"LoRA file created: {event.src_path}")
logger.info(f"File created: {event.src_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.src_path)
# Ignore modifications for a short period after creation
# This helps avoid duplicate processing
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
normalized_path
)
# For browser downloads, we'll catch them when they're renamed to .safetensors
def on_modified(self, event):
if event.is_directory:
return
# Only process safetensors files
if event.src_path.endswith('.safetensors'):
# Only process files with supported extensions
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
@@ -134,12 +132,17 @@ class LoraFileHandler(FileSystemEventHandler):
# Process stable files
for file_path in files_to_process:
logger.info(f"Processing modified LoRA file: {file_path}")
logger.info(f"Processing modified file: {file_path}")
self._schedule_update('add', file_path)
def on_deleted(self, event):
if event.is_directory or not event.src_path.endswith('.safetensors'):
if event.is_directory:
return
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext not in self.file_extensions:
return
if self._should_ignore(event.src_path):
return
@@ -147,14 +150,17 @@ class LoraFileHandler(FileSystemEventHandler):
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
self.scheduled_files.discard(normalized_path)
logger.info(f"LoRA file deleted: {event.src_path}")
logger.info(f"File deleted: {event.src_path}")
self._schedule_update('remove', event.src_path)
def on_moved(self, event):
"""Handle file move/rename events"""
# If destination is a safetensors file, treat it as a new file
if event.dest_path.endswith('.safetensors'):
src_ext = os.path.splitext(event.src_path)[1].lower()
dest_ext = os.path.splitext(event.dest_path)[1].lower()
# If destination has supported extension, treat as new file
if dest_ext in self.file_extensions:
if self._should_ignore(event.dest_path):
return
@@ -162,7 +168,7 @@ class LoraFileHandler(FileSystemEventHandler):
# Only process if not already scheduled
if normalized_path not in self.scheduled_files:
logger.info(f"LoRA file renamed/moved to: {event.dest_path}")
logger.info(f"File renamed/moved to: {event.dest_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.dest_path)
@@ -173,21 +179,21 @@ class LoraFileHandler(FileSystemEventHandler):
normalized_path
)
# If source was a safetensors file, treat it as deleted
if event.src_path.endswith('.safetensors'):
# If source was a supported file, treat it as deleted
if src_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
self.scheduled_files.discard(normalized_path)
logger.info(f"LoRA file moved/renamed from: {event.src_path}")
logger.info(f"File moved/renamed from: {event.src_path}")
self._schedule_update('remove', event.src_path)
def _schedule_update(self, action: str, file_path: str): #file_path is a real path
def _schedule_update(self, action: str, file_path: str):
"""Schedule a cache update"""
with self.lock:
# 使用 config 中的方法映射路径
# Use config method to map path
mapped_path = config.map_path_to_link(file_path)
normalized_path = mapped_path.replace(os.sep, '/')
self.pending_changes.add((action, normalized_path))
@@ -198,7 +204,20 @@ class LoraFileHandler(FileSystemEventHandler):
"""Create update task in the event loop"""
if self.update_task is None or self.update_task.done():
self.update_task = asyncio.create_task(self._process_changes())
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing - should be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement _process_changes")
class LoraFileHandler(BaseFileHandler):
"""Handler for LoRA file system events"""
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
# Set supported file extensions for LoRAs
self.file_extensions = {'.safetensors'}
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing"""
await asyncio.sleep(delay)
@@ -211,9 +230,11 @@ class LoraFileHandler(FileSystemEventHandler):
if not changes:
return
logger.info(f"Processing {len(changes)} file changes")
logger.info(f"Processing {len(changes)} LoRA file changes")
cache = await self.scanner.get_cached_data()
# Get scanner through ServiceRegistry
scanner = await ServiceRegistry.get_lora_scanner()
cache = await scanner.get_cached_data()
needs_resort = False
new_folders = set()
@@ -227,36 +248,36 @@ class LoraFileHandler(FileSystemEventHandler):
continue
# Scan new file
lora_data = await self.scanner.scan_single_lora(file_path)
if lora_data:
model_data = await scanner.scan_single_model(file_path)
if model_data:
# Update tags count
for tag in lora_data.get('tags', []):
self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1
for tag in model_data.get('tags', []):
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(lora_data)
new_folders.add(lora_data['folder'])
cache.raw_data.append(model_data)
new_folders.add(model_data['folder'])
# Update hash index
if 'sha256' in lora_data:
self.scanner._hash_index.add_entry(
lora_data['sha256'],
lora_data['file_path']
if 'sha256' in model_data:
scanner._hash_index.add_entry(
model_data['sha256'],
model_data['file_path']
)
needs_resort = True
elif action == 'remove':
# Find the lora to remove so we can update tags count
lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if lora_to_remove:
# Find the model to remove so we can update tags count
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if model_to_remove:
# Update tags count by reducing counts
for tag in lora_to_remove.get('tags', []):
if tag in self.scanner._tags_count:
self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1)
if self.scanner._tags_count[tag] == 0:
del self.scanner._tags_count[tag]
for tag in model_to_remove.get('tags', []):
if tag in scanner._tags_count:
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
if scanner._tags_count[tag] == 0:
del scanner._tags_count[tag]
# Remove from cache and hash index
logger.info(f"Removing {file_path} from cache")
self.scanner._hash_index.remove_by_path(file_path)
scanner._hash_index.remove_by_path(file_path)
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
@@ -274,59 +295,140 @@ class LoraFileHandler(FileSystemEventHandler):
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
except Exception as e:
logger.error(f"Error in process_changes: {e}")
logger.error(f"Error in process_changes for LoRA: {e}")
class LoraFileMonitor:
"""Monitor for LoRA file changes"""
class CheckpointFileHandler(BaseFileHandler):
"""Handler for checkpoint file system events"""
def __init__(self, scanner: LoraScanner, roots: List[str]):
self.scanner = scanner
scanner.set_file_monitor(self)
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
# Set supported file extensions for checkpoints
self.file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing for checkpoint files"""
await asyncio.sleep(delay)
try:
with self.lock:
changes = self.pending_changes.copy()
self.pending_changes.clear()
if not changes:
return
logger.info(f"Processing {len(changes)} checkpoint file changes")
# Get scanner through ServiceRegistry
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
needs_resort = False
new_folders = set()
for action, file_path in changes:
try:
if action == 'add':
# Check if file already exists in cache
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if existing:
logger.info(f"File {file_path} already in cache, skipping")
continue
# Scan new file
model_data = await scanner.scan_single_model(file_path)
if model_data:
# Update tags count if applicable
for tag in model_data.get('tags', []):
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(model_data)
new_folders.add(model_data['folder'])
# Update hash index
if 'sha256' in model_data:
scanner._hash_index.add_entry(
model_data['sha256'],
model_data['file_path']
)
needs_resort = True
elif action == 'remove':
# Find the model to remove so we can update tags count
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if model_to_remove:
# Update tags count by reducing counts
for tag in model_to_remove.get('tags', []):
if tag in scanner._tags_count:
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
if scanner._tags_count[tag] == 0:
del scanner._tags_count[tag]
# Remove from cache and hash index
logger.info(f"Removing {file_path} from checkpoint cache")
scanner._hash_index.remove_by_path(file_path)
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
]
needs_resort = True
except Exception as e:
logger.error(f"Error processing checkpoint {action} for {file_path}: {e}")
if needs_resort:
await cache.resort()
# Update folder list
all_folders = set(cache.folders) | new_folders
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
except Exception as e:
logger.error(f"Error in process_changes for checkpoint: {e}")
class BaseFileMonitor:
"""Base class for file monitoring"""
def __init__(self, monitor_paths: List[str]):
self.observer = Observer()
self.loop = asyncio.get_event_loop()
self.handler = LoraFileHandler(scanner, self.loop)
# 使用已存在的路径映射
self.monitor_paths = set()
for root in roots:
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
# Process monitor paths
for path in monitor_paths:
self.monitor_paths.add(os.path.realpath(path).replace(os.sep, '/'))
# 添加所有已映射的目标路径
# Add mapped paths from config
for target_path in config._path_mappings.keys():
self.monitor_paths.add(target_path)
def start(self):
"""Start monitoring"""
for path_info in self.monitor_paths:
"""Start file monitoring"""
for path in self.monitor_paths:
try:
if isinstance(path_info, tuple):
# 对于链接,监控目标路径
_, target_path = path_info
self.observer.schedule(self.handler, target_path, recursive=True)
logger.info(f"Started monitoring target path: {target_path}")
else:
# 对于普通路径,直接监控
self.observer.schedule(self.handler, path_info, recursive=True)
logger.info(f"Started monitoring: {path_info}")
self.observer.schedule(self.handler, path, recursive=True)
logger.info(f"Started monitoring: {path}")
except Exception as e:
logger.error(f"Error monitoring {path_info}: {e}")
logger.error(f"Error monitoring {path}: {e}")
self.observer.start()
def stop(self):
"""Stop monitoring"""
"""Stop file monitoring"""
self.observer.stop()
self.observer.join()
def rescan_links(self):
"""重新扫描链接(当添加新的链接时调用)"""
"""Rescan links when new ones are added"""
# Find new paths not yet being monitored
new_paths = set()
for path in self.monitor_paths.copy():
self._add_link_targets(path)
for path in config._path_mappings.keys():
real_path = os.path.realpath(path).replace(os.sep, '/')
if real_path not in self.monitor_paths:
new_paths.add(real_path)
self.monitor_paths.add(real_path)
# 添加新发现的路径到监控
new_paths = self.monitor_paths - set(self.observer.watches.keys())
# Add new paths to monitoring
for path in new_paths:
try:
self.observer.schedule(self.handler, path, recursive=True)
@@ -334,88 +436,86 @@ class LoraFileMonitor:
except Exception as e:
logger.error(f"Error adding new monitor for {path}: {e}")
# Add CheckpointFileMonitor class
class CheckpointFileMonitor(LoraFileMonitor):
class LoraFileMonitor(BaseFileMonitor):
"""Monitor for LoRA file changes"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls, monitor_paths=None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, monitor_paths=None):
if not hasattr(self, '_initialized'):
if monitor_paths is None:
from ..config import config
monitor_paths = config.loras_roots
super().__init__(monitor_paths)
self.handler = LoraFileHandler(self.loop)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
from ..config import config
cls._instance = cls(config.loras_roots)
return cls._instance
class CheckpointFileMonitor(BaseFileMonitor):
"""Monitor for checkpoint file changes"""
def __init__(self, scanner: CheckpointScanner, roots: List[str]):
# Reuse most of the LoraFileMonitor functionality, but with a different handler
self.scanner = scanner
scanner.set_file_monitor(self)
self.observer = Observer()
self.loop = asyncio.get_event_loop()
self.handler = CheckpointFileHandler(scanner, self.loop)
# Use existing path mappings
self.monitor_paths = set()
for root in roots:
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
# Add all mapped target paths
for target_path in config._path_mappings.keys():
self.monitor_paths.add(target_path)
class CheckpointFileHandler(LoraFileHandler):
"""Handler for checkpoint file system events"""
_instance = None
_lock = asyncio.Lock()
def __init__(self, scanner: CheckpointScanner, loop: asyncio.AbstractEventLoop):
super().__init__(scanner, loop)
# Configure supported file extensions
self.supported_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'}
def on_created(self, event):
if event.is_directory:
return
# Handle supported file extensions directly
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.supported_extensions:
if self._should_ignore(event.src_path):
return
# Process this file directly
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
if normalized_path not in self.scheduled_files:
logger.info(f"Checkpoint file created: {event.src_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.src_path)
# Ignore modifications for a short period after creation
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
normalized_path
)
def __new__(cls, monitor_paths=None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def on_modified(self, event):
if event.is_directory:
return
# Only process supported file types
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.supported_extensions:
super().on_modified(event)
def __init__(self, monitor_paths=None):
if not hasattr(self, '_initialized'):
if monitor_paths is None:
# Get checkpoint roots from scanner
monitor_paths = []
# We'll initialize monitor paths later when scanner is available
def on_deleted(self, event):
if event.is_directory:
return
super().__init__(monitor_paths or [])
self.handler = CheckpointFileHandler(self.loop)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls([])
# Now get checkpoint roots from scanner
from .checkpoint_scanner import CheckpointScanner
scanner = await CheckpointScanner.get_instance()
monitor_paths = scanner.get_model_roots()
# Update monitor paths
for path in monitor_paths:
real_path = os.path.realpath(path).replace(os.sep, '/')
cls._instance.monitor_paths.add(real_path)
return cls._instance
async def initialize_paths(self):
"""Initialize monitor paths from scanner"""
if not self.monitor_paths:
scanner = await ServiceRegistry.get_checkpoint_scanner()
monitor_paths = scanner.get_model_roots()
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext not in self.supported_extensions:
return
super().on_deleted(event)
def on_moved(self, event):
"""Handle file move/rename events"""
src_ext = os.path.splitext(event.src_path)[1].lower()
dest_ext = os.path.splitext(event.dest_path)[1].lower()
# If destination has supported extension, treat as new file
if dest_ext in self.supported_extensions:
super().on_moved(event)
# If source was supported extension, treat as deleted
elif src_ext in self.supported_extensions:
super().on_moved(event)
# Update monitor paths
for path in monitor_paths:
real_path = os.path.realpath(path).replace(os.sep, '/')
self.monitor_paths.add(real_path)

View File

@@ -13,6 +13,7 @@ from .lora_hash_index import LoraHashIndex
from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match
from .service_registry import ServiceRegistry
import sys
logger = logging.getLogger(__name__)

View File

@@ -12,13 +12,13 @@ from ..utils.file_utils import load_metadata, get_file_info, find_preview_file,
from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class ModelScanner:
"""Base service for scanning and managing model files"""
_instance = None
_lock = asyncio.Lock()
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
@@ -35,14 +35,17 @@ class ModelScanner:
self.file_extensions = file_extensions
self._cache = None
self._hash_index = hash_index or ModelHashIndex()
self.file_monitor = None
self._tags_count = {} # Dictionary to store tag counts
self._is_initializing = False # Flag to track initialization state
# Register this service
asyncio.create_task(self._register_service())
async def _register_service(self):
"""Register this instance with the ServiceRegistry"""
service_name = f"{self.model_type}_scanner"
await ServiceRegistry.register_service(service_name, self)
def set_file_monitor(self, monitor):
"""Set file monitor instance"""
self.file_monitor = monitor
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
@@ -366,12 +369,20 @@ class ModelScanner:
file_size = os.path.getsize(real_source)
if self.file_monitor:
self.file_monitor.handler.add_ignore_path(
# Get the appropriate file monitor through ServiceRegistry
if self.model_type == "lora":
monitor = await ServiceRegistry.get_lora_monitor()
elif self.model_type == "checkpoint":
monitor = await ServiceRegistry.get_checkpoint_monitor()
else:
monitor = None
if monitor:
monitor.handler.add_ignore_path(
real_source,
file_size
)
self.file_monitor.handler.add_ignore_path(
monitor.handler.add_ignore_path(
real_target,
file_size
)

View File

@@ -5,8 +5,8 @@ import json
from typing import List, Dict, Optional, Any, Tuple
from ..config import config
from .recipe_cache import RecipeCache
from .service_registry import ServiceRegistry
from .lora_scanner import LoraScanner
from .civitai_client import CivitaiClient
from ..utils.utils import fuzzy_match
import sys
@@ -18,11 +18,22 @@ class RecipeScanner:
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls, lora_scanner: Optional[LoraScanner] = None):
"""Get singleton instance of RecipeScanner"""
async with cls._lock:
if cls._instance is None:
if not lora_scanner:
# Get lora scanner from service registry if not provided
lora_scanner = await ServiceRegistry.get_lora_scanner()
cls._instance = cls(lora_scanner)
return cls._instance
def __new__(cls, lora_scanner: Optional[LoraScanner] = None):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._lora_scanner = lora_scanner
cls._instance._civitai_client = CivitaiClient()
cls._instance._civitai_client = None # Will be lazily initialized
return cls._instance
def __init__(self, lora_scanner: Optional[LoraScanner] = None):
@@ -36,6 +47,12 @@ class RecipeScanner:
self._lora_scanner = lora_scanner
self._initialized = True
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
@@ -306,10 +323,13 @@ class RecipeScanner:
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
"""Get hash from Civitai API"""
try:
if not self._civitai_client:
# Get CivitaiClient from ServiceRegistry
civitai_client = await self._get_civitai_client()
if not civitai_client:
logger.error("Failed to get CivitaiClient from ServiceRegistry")
return None
version_info = await self._civitai_client.get_model_version_info(model_version_id)
version_info = await civitai_client.get_model_version_info(model_version_id)
if not version_info or not version_info.get('files'):
logger.debug(f"No files found in version info for ID: {model_version_id}")
@@ -329,10 +349,12 @@ class RecipeScanner:
async def _get_model_version_name(self, model_version_id: str) -> Optional[str]:
"""Get model version name from Civitai API"""
try:
if not self._civitai_client:
# Get CivitaiClient from ServiceRegistry
civitai_client = await self._get_civitai_client()
if not civitai_client:
return None
version_info = await self._civitai_client.get_model_version_info(model_version_id)
version_info = await civitai_client.get_model_version_info(model_version_id)
if version_info and 'name' in version_info:
return version_info['name']

View File

@@ -0,0 +1,124 @@
import asyncio
import logging
from typing import Optional, Dict, Any, TypeVar, Type
logger = logging.getLogger(__name__)
T = TypeVar('T') # Define a type variable for service types
class ServiceRegistry:
"""Centralized registry for service singletons"""
_instance = None
_services: Dict[str, Any] = {}
_lock = asyncio.Lock()
@classmethod
def get_instance(cls):
"""Get singleton instance of the registry"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
async def register_service(cls, service_name: str, service_instance: Any) -> None:
"""Register a service instance with the registry"""
registry = cls.get_instance()
async with cls._lock:
registry._services[service_name] = service_instance
logger.debug(f"Registered service: {service_name}")
@classmethod
async def get_service(cls, service_name: str) -> Any:
"""Get a service instance by name"""
registry = cls.get_instance()
async with cls._lock:
if service_name not in registry._services:
logger.warning(f"Service {service_name} not found in registry")
return None
return registry._services[service_name]
# Convenience methods for common services
@classmethod
async def get_lora_scanner(cls):
"""Get the LoraScanner instance"""
from .lora_scanner import LoraScanner
scanner = await cls.get_service("lora_scanner")
if scanner is None:
scanner = await LoraScanner.get_instance()
await cls.register_service("lora_scanner", scanner)
return scanner
@classmethod
async def get_checkpoint_scanner(cls):
"""Get the CheckpointScanner instance"""
from .checkpoint_scanner import CheckpointScanner
scanner = await cls.get_service("checkpoint_scanner")
if scanner is None:
scanner = await CheckpointScanner.get_instance()
await cls.register_service("checkpoint_scanner", scanner)
return scanner
@classmethod
async def get_lora_monitor(cls):
"""Get the LoraFileMonitor instance"""
from .file_monitor import LoraFileMonitor
monitor = await cls.get_service("lora_monitor")
if monitor is None:
monitor = await LoraFileMonitor.get_instance()
await cls.register_service("lora_monitor", monitor)
return monitor
@classmethod
async def get_checkpoint_monitor(cls):
"""Get the CheckpointFileMonitor instance"""
from .file_monitor import CheckpointFileMonitor
monitor = await cls.get_service("checkpoint_monitor")
if monitor is None:
monitor = await CheckpointFileMonitor.get_instance()
await cls.register_service("checkpoint_monitor", monitor)
return monitor
@classmethod
async def get_civitai_client(cls):
"""Get the CivitaiClient instance"""
from .civitai_client import CivitaiClient
client = await cls.get_service("civitai_client")
if client is None:
client = await CivitaiClient.get_instance()
await cls.register_service("civitai_client", client)
return client
@classmethod
async def get_download_manager(cls):
"""Get the DownloadManager instance"""
from .download_manager import DownloadManager
manager = await cls.get_service("download_manager")
if manager is None:
# We'll let DownloadManager.get_instance handle file_monitor parameter
manager = await DownloadManager.get_instance()
await cls.register_service("download_manager", manager)
return manager
@classmethod
async def get_recipe_scanner(cls):
"""Get the RecipeScanner instance"""
from .recipe_scanner import RecipeScanner
scanner = await cls.get_service("recipe_scanner")
if scanner is None:
lora_scanner = await cls.get_lora_scanner()
scanner = RecipeScanner(lora_scanner)
await cls.register_service("recipe_scanner", scanner)
return scanner
@classmethod
async def get_websocket_manager(cls):
"""Get the WebSocketManager instance"""
from .websocket_manager import ws_manager
manager = await cls.get_service("websocket_manager")
if manager is None:
# ws_manager is already a global instance in websocket_manager.py
from .websocket_manager import ws_manager
await cls.register_service("websocket_manager", ws_manager)
manager = ws_manager
return manager