mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
checkpoint
This commit is contained in:
@@ -5,10 +5,7 @@ from .routes.lora_routes import LoraRoutes
|
||||
from .routes.api_routes import ApiRoutes
|
||||
from .routes.recipe_routes import RecipeRoutes
|
||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||
from .services.lora_scanner import LoraScanner
|
||||
from .services.checkpoint_scanner import CheckpointScanner
|
||||
from .services.recipe_scanner import RecipeScanner
|
||||
from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor
|
||||
from .services.service_registry import ServiceRegistry
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -41,8 +38,10 @@ class LoraManager:
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Get checkpoint scanner instance
|
||||
checkpoint_scanner = asyncio.run(ServiceRegistry.get_checkpoint_scanner())
|
||||
|
||||
# Add static routes for each checkpoint root
|
||||
checkpoint_scanner = CheckpointScanner()
|
||||
for idx, root in enumerate(checkpoint_scanner.get_model_roots(), start=1):
|
||||
preview_path = f'/checkpoints_static/root{idx}/preview'
|
||||
|
||||
@@ -79,51 +78,89 @@ class LoraManager:
|
||||
lora_routes = LoraRoutes()
|
||||
checkpoints_routes = CheckpointsRoutes()
|
||||
|
||||
# Setup file monitoring
|
||||
lora_monitor = LoraFileMonitor(lora_routes.scanner, config.loras_roots)
|
||||
lora_monitor.start()
|
||||
|
||||
checkpoint_monitor = CheckpointFileMonitor(checkpoints_routes.scanner, checkpoints_routes.scanner.get_model_roots())
|
||||
checkpoint_monitor.start()
|
||||
|
||||
# Initialize routes
|
||||
lora_routes.setup_routes(app)
|
||||
checkpoints_routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app, lora_monitor)
|
||||
ApiRoutes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
|
||||
# Store monitors in app for cleanup
|
||||
app['lora_monitor'] = lora_monitor
|
||||
app['checkpoint_monitor'] = checkpoint_monitor
|
||||
|
||||
logger.info("PromptServer app: ", app)
|
||||
|
||||
# Schedule cache initialization using the application's startup handler
|
||||
app.on_startup.append(lambda app: cls._schedule_cache_init(
|
||||
lora_routes.scanner,
|
||||
checkpoints_routes.scanner,
|
||||
lora_routes.recipe_scanner
|
||||
))
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
# Add cleanup
|
||||
app.on_shutdown.append(cls._cleanup)
|
||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||
|
||||
@classmethod
|
||||
async def _schedule_cache_init(cls, lora_scanner, checkpoint_scanner, recipe_scanner):
|
||||
"""Schedule cache initialization in the running event loop"""
|
||||
async def _initialize_services(cls):
|
||||
"""Initialize all services using the ServiceRegistry"""
|
||||
try:
|
||||
logger.info("LoRA Manager: Initializing services via ServiceRegistry")
|
||||
|
||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
logger.info("CivitaiClient registered in ServiceRegistry")
|
||||
|
||||
# Get file monitors through ServiceRegistry
|
||||
lora_monitor = await ServiceRegistry.get_lora_monitor()
|
||||
checkpoint_monitor = await ServiceRegistry.get_checkpoint_monitor()
|
||||
|
||||
# Start monitors
|
||||
lora_monitor.start()
|
||||
logger.info("Lora monitor started")
|
||||
|
||||
# Make sure checkpoint monitor has paths before starting
|
||||
await checkpoint_monitor.initialize_paths()
|
||||
checkpoint_monitor.start()
|
||||
logger.info("Checkpoint monitor started")
|
||||
|
||||
# Register DownloadManager with ServiceRegistry
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
logger.info("DownloadManager registered in ServiceRegistry")
|
||||
|
||||
# Initialize WebSocket manager
|
||||
ws_manager = await ServiceRegistry.get_websocket_manager()
|
||||
logger.info("WebSocketManager registered in ServiceRegistry")
|
||||
|
||||
# Initialize scanners in background
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Initialize recipe scanner if needed
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
|
||||
# Create low-priority initialization tasks
|
||||
lora_task = asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
|
||||
checkpoint_task = asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
|
||||
recipe_task = asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
|
||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
|
||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
|
||||
logger.info("LoRA Manager: All services initialized and background tasks scheduled")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}")
|
||||
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _cleanup(cls, app):
|
||||
"""Cleanup resources"""
|
||||
if 'lora_monitor' in app:
|
||||
app['lora_monitor'].stop()
|
||||
"""Cleanup resources using ServiceRegistry"""
|
||||
try:
|
||||
logger.info("LoRA Manager: Cleaning up services")
|
||||
|
||||
if 'checkpoint_monitor' in app:
|
||||
app['checkpoint_monitor'].stop()
|
||||
# Get monitors from ServiceRegistry
|
||||
lora_monitor = await ServiceRegistry.get_service("lora_monitor")
|
||||
if lora_monitor:
|
||||
lora_monitor.stop()
|
||||
logger.info("Stopped LoRA monitor")
|
||||
|
||||
checkpoint_monitor = await ServiceRegistry.get_service("checkpoint_monitor")
|
||||
if checkpoint_monitor:
|
||||
checkpoint_monitor.stop()
|
||||
logger.info("Stopped checkpoint monitor")
|
||||
|
||||
# Close CivitaiClient gracefully
|
||||
civitai_client = await ServiceRegistry.get_service("civitai_client")
|
||||
if civitai_client:
|
||||
await civitai_client.close()
|
||||
logger.info("Closed CivitaiClient connection")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}", exc_info=True)
|
||||
|
||||
Reference in New Issue
Block a user