diff --git a/py/lora_manager.py b/py/lora_manager.py index b37d46a2..3f706e29 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -5,10 +5,7 @@ from .routes.lora_routes import LoraRoutes from .routes.api_routes import ApiRoutes from .routes.recipe_routes import RecipeRoutes from .routes.checkpoints_routes import CheckpointsRoutes -from .services.lora_scanner import LoraScanner -from .services.checkpoint_scanner import CheckpointScanner -from .services.recipe_scanner import RecipeScanner -from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor +from .services.service_registry import ServiceRegistry import logging logger = logging.getLogger(__name__) @@ -41,8 +38,10 @@ class LoraManager: config.add_route_mapping(real_root, preview_path) added_targets.add(real_root) + # Get checkpoint scanner instance + checkpoint_scanner = asyncio.run(ServiceRegistry.get_checkpoint_scanner()) + # Add static routes for each checkpoint root - checkpoint_scanner = CheckpointScanner() for idx, root in enumerate(checkpoint_scanner.get_model_roots(), start=1): preview_path = f'/checkpoints_static/root{idx}/preview' @@ -79,51 +78,89 @@ class LoraManager: lora_routes = LoraRoutes() checkpoints_routes = CheckpointsRoutes() - # Setup file monitoring - lora_monitor = LoraFileMonitor(lora_routes.scanner, config.loras_roots) - lora_monitor.start() - - checkpoint_monitor = CheckpointFileMonitor(checkpoints_routes.scanner, checkpoints_routes.scanner.get_model_roots()) - checkpoint_monitor.start() - + # Initialize routes lora_routes.setup_routes(app) checkpoints_routes.setup_routes(app) - ApiRoutes.setup_routes(app, lora_monitor) + ApiRoutes.setup_routes(app) RecipeRoutes.setup_routes(app) - # Store monitors in app for cleanup - app['lora_monitor'] = lora_monitor - app['checkpoint_monitor'] = checkpoint_monitor - - logger.info("PromptServer app: ", app) - - # Schedule cache initialization using the application's startup handler - app.on_startup.append(lambda app: 
cls._schedule_cache_init( - lora_routes.scanner, - checkpoints_routes.scanner, - lora_routes.recipe_scanner - )) + # Schedule service initialization + app.on_startup.append(lambda app: cls._initialize_services()) # Add cleanup app.on_shutdown.append(cls._cleanup) app.on_shutdown.append(ApiRoutes.cleanup) @classmethod - async def _schedule_cache_init(cls, lora_scanner, checkpoint_scanner, recipe_scanner): - """Schedule cache initialization in the running event loop""" + async def _initialize_services(cls): + """Initialize all services using the ServiceRegistry""" try: + logger.info("LoRA Manager: Initializing services via ServiceRegistry") + + # Initialize CivitaiClient first to ensure it's ready for other services + civitai_client = await ServiceRegistry.get_civitai_client() + logger.info("CivitaiClient registered in ServiceRegistry") + + # Get file monitors through ServiceRegistry + lora_monitor = await ServiceRegistry.get_lora_monitor() + checkpoint_monitor = await ServiceRegistry.get_checkpoint_monitor() + + # Start monitors + lora_monitor.start() + logger.info("Lora monitor started") + + # Make sure checkpoint monitor has paths before starting + await checkpoint_monitor.initialize_paths() + checkpoint_monitor.start() + logger.info("Checkpoint monitor started") + + # Register DownloadManager with ServiceRegistry + download_manager = await ServiceRegistry.get_download_manager() + logger.info("DownloadManager registered in ServiceRegistry") + + # Initialize WebSocket manager + ws_manager = await ServiceRegistry.get_websocket_manager() + logger.info("WebSocketManager registered in ServiceRegistry") + + # Initialize scanners in background + lora_scanner = await ServiceRegistry.get_lora_scanner() + checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + + # Initialize recipe scanner if needed + recipe_scanner = await ServiceRegistry.get_recipe_scanner() + # Create low-priority initialization tasks - lora_task = 
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init') - checkpoint_task = asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init') - recipe_task = asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init') + asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init') + asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init') + asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init') + + logger.info("LoRA Manager: All services initialized and background tasks scheduled") + except Exception as e: - logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}") + logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True) @classmethod async def _cleanup(cls, app): - """Cleanup resources""" - if 'lora_monitor' in app: - app['lora_monitor'].stop() + """Cleanup resources using ServiceRegistry""" + try: + logger.info("LoRA Manager: Cleaning up services") - if 'checkpoint_monitor' in app: - app['checkpoint_monitor'].stop() + # Get monitors from ServiceRegistry + lora_monitor = await ServiceRegistry.get_service("lora_monitor") + if lora_monitor: + lora_monitor.stop() + logger.info("Stopped LoRA monitor") + + checkpoint_monitor = await ServiceRegistry.get_service("checkpoint_monitor") + if checkpoint_monitor: + checkpoint_monitor.stop() + logger.info("Stopped checkpoint monitor") + + # Close CivitaiClient gracefully + civitai_client = await ServiceRegistry.get_service("civitai_client") + if civitai_client: + await civitai_client.close() + logger.info("Closed CivitaiClient connection") + + except Exception as e: + logger.error(f"Error during cleanup: {e}", exc_info=True) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 2c63827e..b7506bb8 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -19,22 +19,33 @@ from 
.update_routes import UpdateRoutes from ..services.recipe_scanner import RecipeScanner from ..utils.constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH from ..utils.exif_utils import ExifUtils +from ..services.service_registry import ServiceRegistry logger = logging.getLogger(__name__) class ApiRoutes: """API route handlers for LoRA management""" - def __init__(self, file_monitor: LoraFileMonitor): - self.scanner = LoraScanner() - self.civitai_client = CivitaiClient() - self.download_manager = DownloadManager(file_monitor) + def __init__(self): + self.scanner = None # Will be initialized in setup_routes + self.civitai_client = None # Will be initialized in setup_routes + self.download_manager = None # Will be initialized in setup_routes self._download_lock = asyncio.Lock() + async def initialize_services(self): + """Initialize services from ServiceRegistry""" + self.scanner = await ServiceRegistry.get_lora_scanner() + self.civitai_client = await ServiceRegistry.get_civitai_client() + self.download_manager = await ServiceRegistry.get_download_manager() + @classmethod - def setup_routes(cls, app: web.Application, monitor: LoraFileMonitor): + def setup_routes(cls, app: web.Application): """Register API routes""" - routes = cls(monitor) + routes = cls() + + # Schedule service initialization on app startup + app.on_startup.append(lambda _: routes.initialize_services()) + app.router.add_post('/api/delete_model', routes.delete_model) app.router.add_post('/api/fetch-civitai', routes.fetch_civitai) app.router.add_post('/api/replace_preview', routes.replace_preview) @@ -63,19 +74,28 @@ class ApiRoutes: async def delete_model(self, request: web.Request) -> web.Response: """Handle model deletion request""" + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() return await ModelRouteUtils.handle_delete_model(request, self.scanner) async def fetch_civitai(self, request: web.Request) -> web.Response: """Handle CivitAI metadata fetch request""" + 
if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner) async def replace_preview(self, request: web.Request) -> web.Response: """Handle preview image replacement request""" + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() return await ModelRouteUtils.handle_replace_preview(request, self.scanner) async def get_loras(self, request: web.Request) -> web.Response: """Handle paginated LoRA data request""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + # Parse query parameters page = int(request.query.get('page', '1')) page_size = int(request.query.get('page_size', '20')) @@ -231,6 +251,9 @@ class ApiRoutes: async def fetch_all_civitai(self, request: web.Request) -> web.Response: """Fetch CivitAI metadata for all loras in the background""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + cache = await self.scanner.get_cached_data() total = len(cache.raw_data) processed = 0 @@ -312,6 +335,9 @@ class ApiRoutes: async def get_folders(self, request: web.Request) -> web.Response: """Get all folders in the cache""" + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + cache = await self.scanner.get_cached_data() return web.json_response({ 'folders': cache.folders @@ -320,6 +346,12 @@ class ApiRoutes: async def get_civitai_versions(self, request: web.Request) -> web.Response: """Get available versions for a Civitai model with local availability info""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + + if self.civitai_client is None: + self.civitai_client = await ServiceRegistry.get_civitai_client() + model_id = request.match_info['model_id'] versions = await self.civitai_client.get_model_versions(model_id) if not versions: @@ -353,9 +385,12 @@ class ApiRoutes: async def 
get_civitai_model(self, request: web.Request) -> web.Response: """Get CivitAI model details by model version ID or hash""" try: - model_version_id = request.match_info['modelVersionId'] + if self.civitai_client is None: + self.civitai_client = await ServiceRegistry.get_civitai_client() + + model_version_id = request.match_info.get('modelVersionId') if not model_version_id: - hash = request.match_info['hash'] + hash = request.match_info.get('hash') model = await self.civitai_client.get_model_by_hash(hash) return web.json_response(model) @@ -370,6 +405,9 @@ class ApiRoutes: async def download_lora(self, request: web.Request) -> web.Response: async with self._download_lock: try: + if self.download_manager is None: + self.download_manager = await ServiceRegistry.get_download_manager() + data = await request.json() # Create progress callback @@ -447,6 +485,9 @@ class ApiRoutes: async def move_model(self, request: web.Request) -> web.Response: """Handle model move request""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + data = await request.json() file_path = data.get('file_path') # full path of the model file, e.g. /path/to/model.safetensors target_path = data.get('target_path') # folder path to move the model to, e.g. 
/path/to/target_folder @@ -485,12 +526,17 @@ class ApiRoutes: @classmethod async def cleanup(cls): """Add cleanup method for application shutdown""" - if hasattr(cls, '_instance'): - await cls._instance.civitai_client.close() + # Now we don't need to store an instance, as services are managed by ServiceRegistry + civitai_client = await ServiceRegistry.get_civitai_client() + if civitai_client: + await civitai_client.close() async def save_metadata(self, request: web.Request) -> web.Response: """Handle saving metadata updates""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + data = await request.json() file_path = data.get('file_path') if not file_path: @@ -536,6 +582,9 @@ class ApiRoutes: async def get_lora_preview_url(self, request: web.Request) -> web.Response: """Get the static preview URL for a LoRA file""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + # Get lora file name from query parameters lora_name = request.query.get('name') if not lora_name: @@ -574,6 +623,9 @@ class ApiRoutes: async def get_lora_civitai_url(self, request: web.Request) -> web.Response: """Get the Civitai URL for a LoRA file""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + # Get lora file name from query parameters lora_name = request.query.get('name') if not lora_name: @@ -619,6 +671,9 @@ class ApiRoutes: async def move_models_bulk(self, request: web.Request) -> web.Response: """Handle bulk model move request""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + data = await request.json() file_paths = data.get('file_paths', []) # list of full paths of the model files, e.g. ["/path/to/model1.safetensors", "/path/to/model2.safetensors"] target_path = data.get('target_path') # folder path to move the models to, e.g. 
"/path/to/target_folder" @@ -677,6 +732,9 @@ class ApiRoutes: async def get_lora_model_description(self, request: web.Request) -> web.Response: """Get model description for a Lora model""" try: + if self.civitai_client is None: + self.civitai_client = await ServiceRegistry.get_civitai_client() + # Get parameters model_id = request.query.get('model_id') file_path = request.query.get('file_path') @@ -736,6 +794,9 @@ class ApiRoutes: async def get_top_tags(self, request: web.Request) -> web.Response: """Handle request for top tags sorted by frequency""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + # Parse query parameters limit = int(request.query.get('limit', '20')) @@ -761,6 +822,9 @@ class ApiRoutes: async def get_base_models(self, request: web.Request) -> web.Response: """Get base models used in loras""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + # Parse query parameters limit = int(request.query.get('limit', '20')) @@ -785,6 +849,12 @@ class ApiRoutes: async def rename_lora(self, request: web.Request) -> web.Response: """Handle renaming a LoRA file and its associated files""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + + if self.download_manager is None: + self.download_manager = await ServiceRegistry.get_download_manager() + data = await request.json() file_path = data.get('file_path') new_file_name = data.get('new_file_name') @@ -891,7 +961,7 @@ class ApiRoutes: # Update recipe files and cache if hash is available if hash_value: - recipe_scanner = RecipeScanner(self.scanner) + recipe_scanner = await ServiceRegistry.get_recipe_scanner() recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name) logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed LoRA") diff --git a/py/routes/checkpoints_routes.py 
b/py/routes/checkpoints_routes.py index 1d3627a5..e50ff042 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -11,6 +11,7 @@ from ..services.civitai_client import CivitaiClient from ..services.websocket_manager import ws_manager from ..services.checkpoint_scanner import CheckpointScanner from ..services.download_manager import DownloadManager +from ..services.service_registry import ServiceRegistry from ..config import config from ..services.settings_manager import settings from ..utils.utils import fuzzy_match @@ -21,16 +22,24 @@ class CheckpointsRoutes: """API routes for checkpoint management""" def __init__(self): - self.scanner = CheckpointScanner() + self.scanner = None # Will be initialized in setup_routes self.template_env = jinja2.Environment( loader=jinja2.FileSystemLoader(config.templates_path), autoescape=True ) - self.download_manager = DownloadManager() + self.download_manager = None # Will be initialized in setup_routes self._download_lock = asyncio.Lock() + async def initialize_services(self): + """Initialize services from ServiceRegistry""" + self.scanner = await ServiceRegistry.get_checkpoint_scanner() + self.download_manager = await ServiceRegistry.get_download_manager() + def setup_routes(self, app): """Register routes with the aiohttp app""" + # Schedule service initialization on app startup + app.on_startup.append(lambda _: self.initialize_services()) + app.router.add_get('/checkpoints', self.handle_checkpoints_page) app.router.add_get('/api/checkpoints', self.get_checkpoints) app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai) @@ -488,10 +497,9 @@ class CheckpointsRoutes: async def download_checkpoint(self, request: web.Request) -> web.Response: """Handle checkpoint download request""" async with self._download_lock: - # Initialize DownloadManager with the file monitor if the scanner has one - if not hasattr(self, 'download_manager') or self.download_manager is None: - 
file_monitor = getattr(self.scanner, 'file_monitor', None) - self.download_manager = DownloadManager(file_monitor) + # Get the download manager from service registry if not already initialized + if self.download_manager is None: + self.download_manager = await ServiceRegistry.get_download_manager() # Use the common download handler with model_type="checkpoint" return await ModelRouteUtils.handle_download_model( @@ -503,6 +511,9 @@ class CheckpointsRoutes: async def get_checkpoint_roots(self, request): """Return the checkpoint root directories""" try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_checkpoint_scanner() + roots = self.scanner.get_model_roots() return web.json_response({ "success": True, diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index a2c392fc..91e1d2eb 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -6,7 +6,8 @@ import logging from ..services.lora_scanner import LoraScanner from ..services.recipe_scanner import RecipeScanner from ..config import config -from ..services.settings_manager import settings # Add this import +from ..services.settings_manager import settings +from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import logger = logging.getLogger(__name__) logging.getLogger('asyncio').setLevel(logging.CRITICAL) @@ -15,13 +16,24 @@ class LoraRoutes: """Route handlers for LoRA management endpoints""" def __init__(self): - self.scanner = LoraScanner() - self.recipe_scanner = RecipeScanner(self.scanner) + # Initialize service references as None, will be set during async init + self.scanner = None + self.recipe_scanner = None self.template_env = jinja2.Environment( loader=jinja2.FileSystemLoader(config.templates_path), autoescape=True ) + async def init_services(self): + """Initialize services from ServiceRegistry""" + if self.scanner is None: + self.scanner = await ServiceRegistry.get_lora_scanner() + logger.info("LoraRoutes: Retrieved LoraScanner from 
ServiceRegistry") + + if self.recipe_scanner is None: + self.recipe_scanner = await ServiceRegistry.get_recipe_scanner() + logger.info("LoraRoutes: Retrieved RecipeScanner from ServiceRegistry") + def format_lora_data(self, lora: Dict) -> Dict: """Format LoRA data for template rendering""" return { @@ -58,7 +70,10 @@ class LoraRoutes: async def handle_loras_page(self, request: web.Request) -> web.Response: """Handle GET /loras request""" try: - # 检查缓存初始化状态,根据initialize_in_background的工作方式调整判断逻辑 + # Ensure services are initialized + await self.init_services() + + # Check if the LoraScanner is initializing is_initializing = ( self.scanner._cache is None or len(self.scanner._cache.raw_data) == 0 or @@ -66,30 +81,29 @@ class LoraRoutes: ) if is_initializing: - # 如果正在初始化,返回一个只包含加载提示的页面 + # If still initializing, return loading page template = self.template_env.get_template('loras.html') rendered = template.render( - folders=[], # 空文件夹列表 - is_initializing=True, # 新增标志 - settings=settings, # Pass settings to template - request=request # Pass the request object to the template + folders=[], + is_initializing=True, + settings=settings, + request=request ) logger.info("Loras page is initializing, returning loading page") else: - # 正常流程 - 获取已经初始化好的缓存数据 + # Normal flow - get data from initialized cache try: cache = await self.scanner.get_cached_data(force_refresh=False) template = self.template_env.get_template('loras.html') rendered = template.render( folders=cache.folders, is_initializing=False, - settings=settings, # Pass settings to template - request=request # Pass the request object to the template + settings=settings, + request=request ) except Exception as cache_error: logger.error(f"Error loading cache data: {cache_error}") - # 如果获取缓存失败,也显示初始化页面 template = self.template_env.get_template('loras.html') rendered = template.render( folders=[], @@ -114,7 +128,10 @@ class LoraRoutes: async def handle_recipes_page(self, request: web.Request) -> web.Response: """Handle GET 
/loras/recipes request""" try: - # 检查缓存初始化状态,与handle_loras_page保持一致的逻辑 + # Ensure services are initialized + await self.init_services() + + # Check if the RecipeScanner is initializing is_initializing = ( self.recipe_scanner._cache is None or len(self.recipe_scanner._cache.raw_data) == 0 or @@ -183,5 +200,13 @@ class LoraRoutes: def setup_routes(self, app: web.Application): """Register routes with the application""" + # Add an app startup handler to initialize services + app.on_startup.append(self._on_startup) + + # Register routes app.router.add_get('/loras', self.handle_loras_page) app.router.add_get('/loras/recipes', self.handle_recipes_page) + + async def _on_startup(self, app): + """Initialize services when the app starts""" + await self.init_services() diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index bffc208e..76dd7467 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -16,6 +16,7 @@ from ..services.lora_scanner import LoraScanner from ..config import config from ..workflow.parser import WorkflowParser from ..utils.utils import download_civitai_image +from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import logger = logging.getLogger(__name__) @@ -23,13 +24,24 @@ class RecipeRoutes: """API route handlers for Recipe management""" def __init__(self): - self.recipe_scanner = RecipeScanner(LoraScanner()) - self.civitai_client = CivitaiClient() + # Initialize service references as None, will be set during async init + self.recipe_scanner = None + self.civitai_client = None self.parser = WorkflowParser() # Pre-warm the cache self._init_cache_task = None + async def init_services(self): + """Initialize services from ServiceRegistry""" + if self.recipe_scanner is None: + self.recipe_scanner = await ServiceRegistry.get_recipe_scanner() + logger.info("RecipeRoutes: Retrieved RecipeScanner from ServiceRegistry") + + if self.civitai_client is None: + self.civitai_client = await 
ServiceRegistry.get_civitai_client() + logger.info("RecipeRoutes: Retrieved CivitaiClient from ServiceRegistry") + @classmethod def setup_routes(cls, app: web.Application): """Register API routes""" @@ -68,7 +80,10 @@ class RecipeRoutes: async def _init_cache(self, app): """Initialize cache on startup""" try: - # First, ensure the lora scanner is fully initialized + # Initialize services first + await self.init_services() + + # Now that services are initialized, get the lora scanner lora_scanner = self.recipe_scanner._lora_scanner # Get lora cache to ensure it's initialized @@ -86,6 +101,9 @@ class RecipeRoutes: async def get_recipes(self, request: web.Request) -> web.Response: """API endpoint for getting paginated recipes""" try: + # Ensure services are initialized + await self.init_services() + # Get query parameters with defaults page = int(request.query.get('page', '1')) page_size = int(request.query.get('page_size', '20')) @@ -155,6 +173,9 @@ class RecipeRoutes: async def get_recipe_detail(self, request: web.Request) -> web.Response: """Get detailed information about a specific recipe""" try: + # Ensure services are initialized + await self.init_services() + recipe_id = request.match_info['recipe_id'] # Use the new get_recipe_by_id method from recipe_scanner @@ -208,6 +229,9 @@ class RecipeRoutes: """Analyze an uploaded image or URL for recipe metadata""" temp_path = None try: + # Ensure services are initialized + await self.init_services() + # Check if request contains multipart data (image) or JSON data (url) content_type = request.headers.get('Content-Type', '') @@ -326,6 +350,9 @@ class RecipeRoutes: async def save_recipe(self, request: web.Request) -> web.Response: """Save a recipe to the recipes folder""" try: + # Ensure services are initialized + await self.init_services() + reader = await request.multipart() # Process form data @@ -527,6 +554,9 @@ class RecipeRoutes: async def delete_recipe(self, request: web.Request) -> web.Response: """Delete a 
recipe by ID""" try: + # Ensure services are initialized + await self.init_services() + recipe_id = request.match_info['recipe_id'] # Get recipes directory @@ -574,6 +604,9 @@ class RecipeRoutes: async def get_top_tags(self, request: web.Request) -> web.Response: """Get top tags used in recipes""" try: + # Ensure services are initialized + await self.init_services() + # Get limit parameter with default limit = int(request.query.get('limit', '20')) @@ -606,6 +639,9 @@ class RecipeRoutes: async def get_base_models(self, request: web.Request) -> web.Response: """Get base models used in recipes""" try: + # Ensure services are initialized + await self.init_services() + # Get all recipes from cache cache = await self.recipe_scanner.get_cached_data() @@ -634,6 +670,9 @@ class RecipeRoutes: async def share_recipe(self, request: web.Request) -> web.Response: """Process a recipe image for sharing by adding metadata to EXIF""" try: + # Ensure services are initialized + await self.init_services() + recipe_id = request.match_info['recipe_id'] # Get all recipes from cache @@ -693,6 +732,9 @@ class RecipeRoutes: async def download_shared_recipe(self, request: web.Request) -> web.Response: """Serve a processed recipe image for download""" try: + # Ensure services are initialized + await self.init_services() + recipe_id = request.match_info['recipe_id'] # Check if we have this shared recipe @@ -749,6 +791,9 @@ class RecipeRoutes: async def save_recipe_from_widget(self, request: web.Request) -> web.Response: """Save a recipe from the LoRAs widget""" try: + # Ensure services are initialized + await self.init_services() + reader = await request.multipart() # Process form data @@ -923,6 +968,9 @@ class RecipeRoutes: async def get_recipe_syntax(self, request: web.Request) -> web.Response: """Generate recipe syntax for LoRAs in the recipe, looking up proper file names using hash_index""" try: + # Ensure services are initialized + await self.init_services() + recipe_id = 
request.match_info['recipe_id'] # Get all recipes from cache @@ -1003,6 +1051,9 @@ class RecipeRoutes: async def update_recipe(self, request: web.Request) -> web.Response: """Update recipe metadata (name and tags)""" try: + # Ensure services are initialized + await self.init_services() + recipe_id = request.match_info['recipe_id'] data = await request.json() @@ -1030,6 +1081,9 @@ class RecipeRoutes: async def reconnect_lora(self, request: web.Request) -> web.Response: """Reconnect a deleted LoRA in a recipe to a local LoRA file""" try: + # Ensure services are initialized + await self.init_services() + # Parse request data data = await request.json() @@ -1140,6 +1194,9 @@ class RecipeRoutes: async def get_recipes_for_lora(self, request: web.Request) -> web.Response: """Get recipes that use a specific Lora""" try: + # Ensure services are initialized + await self.init_services() + lora_hash = request.query.get('hash') # Hash is required diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index 4cc77b6a..acaaa461 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -8,6 +8,7 @@ from ..utils.models import CheckpointMetadata from ..config import config from .model_scanner import ModelScanner from .model_hash_index import ModelHashIndex +from .service_registry import ServiceRegistry logger = logging.getLogger(__name__) diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index fbd77739..24227f8d 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -3,6 +3,7 @@ import aiohttp import os import json import logging +import asyncio from email.parser import Parser from typing import Optional, Dict, Tuple, List from urllib.parse import unquote @@ -11,7 +12,23 @@ from ..utils.models import LoraMetadata logger = logging.getLogger(__name__) class CivitaiClient: + _instance = None + _lock = asyncio.Lock() + + @classmethod + async def get_instance(cls): + """Get 
singleton instance of CivitaiClient""" + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + def __init__(self): + # Check if already initialized for singleton pattern + if hasattr(self, '_initialized'): + return + self._initialized = True + self.base_url = "https://civitai.com/api/v1" self.headers = { 'User-Agent': 'ComfyUI-LoRA-Manager/1.0' diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 91786336..7231fdfb 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -1,12 +1,13 @@ import logging import os import json -from typing import Optional, Dict +import asyncio +from typing import Optional, Dict, Any from .civitai_client import CivitaiClient -from .file_monitor import LoraFileMonitor from ..utils.models import LoraMetadata, CheckpointMetadata from ..utils.constants import CARD_PREVIEW_WIDTH from ..utils.exif_utils import ExifUtils +from .service_registry import ServiceRegistry # Download to temporary file first import tempfile @@ -14,9 +15,46 @@ import tempfile logger = logging.getLogger(__name__) class DownloadManager: - def __init__(self, file_monitor: Optional[LoraFileMonitor] = None): - self.civitai_client = CivitaiClient() - self.file_monitor = file_monitor + _instance = None + _lock = asyncio.Lock() + + @classmethod + async def get_instance(cls): + """Get singleton instance of DownloadManager""" + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + # Check if already initialized for singleton pattern + if hasattr(self, '_initialized'): + return + self._initialized = True + + self._civitai_client = None # Will be lazily initialized + + async def _get_civitai_client(self): + """Lazily initialize CivitaiClient from registry""" + if self._civitai_client is None: + self._civitai_client = await ServiceRegistry.get_civitai_client() + return self._civitai_client + + async def 
_get_lora_monitor(self): + """Get the lora file monitor from registry""" + return await ServiceRegistry.get_lora_monitor() + + async def _get_checkpoint_monitor(self): + """Get the checkpoint file monitor from registry""" + return await ServiceRegistry.get_checkpoint_monitor() + + async def _get_lora_scanner(self): + """Get the lora scanner from registry""" + return await ServiceRegistry.get_lora_scanner() + + async def _get_checkpoint_scanner(self): + """Get the checkpoint scanner from registry""" + return await ServiceRegistry.get_checkpoint_scanner() async def download_from_civitai(self, download_url: str = None, model_hash: str = None, model_version_id: str = None, save_dir: str = None, @@ -43,19 +81,22 @@ class DownloadManager: # Create directory if it doesn't exist os.makedirs(save_dir, exist_ok=True) + # Get civitai client + civitai_client = await self._get_civitai_client() + # Get version info based on the provided identifier version_info = None if download_url: # Extract version ID from download URL version_id = download_url.split('/')[-1] - version_info = await self.civitai_client.get_model_version_info(version_id) + version_info = await civitai_client.get_model_version_info(version_id) elif model_version_id: # Use model version ID directly - version_info = await self.civitai_client.get_model_version_info(model_version_id) + version_info = await civitai_client.get_model_version_info(model_version_id) elif model_hash: # Get model by hash - version_info = await self.civitai_client.get_model_by_hash(model_hash) + version_info = await civitai_client.get_model_by_hash(model_hash) if not version_info: @@ -95,8 +136,9 @@ class DownloadManager: file_size = file_info.get('sizeKB', 0) * 1024 # 4. 
Notify file monitor - use normalized path and file size - if self.file_monitor and self.file_monitor.handler: - self.file_monitor.handler.add_ignore_path( + file_monitor = await self._get_lora_monitor() if model_type == "lora" else await self._get_checkpoint_monitor() + if file_monitor and file_monitor.handler: + file_monitor.handler.add_ignore_path( save_path.replace(os.sep, '/'), file_size ) @@ -112,7 +154,7 @@ class DownloadManager: # 5.1 Get and update model tags and description model_id = version_info.get('modelId') if model_id: - model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id)) + model_metadata, _ = await civitai_client.get_model_metadata(str(model_id)) if model_metadata: if model_metadata.get("tags"): metadata.tags = model_metadata.get("tags", []) @@ -146,6 +188,7 @@ class DownloadManager: model_type: str = "lora") -> Dict: """Execute the actual download process including preview images and model files""" try: + civitai_client = await self._get_civitai_client() save_path = metadata.file_path metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' @@ -165,7 +208,7 @@ class DownloadManager: preview_path = os.path.splitext(save_path)[0] + preview_ext # Download video directly - if await self.civitai_client.download_preview_image(images[0]['url'], preview_path): + if await civitai_client.download_preview_image(images[0]['url'], preview_path): metadata.preview_url = preview_path.replace(os.sep, '/') metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) with open(metadata_path, 'w', encoding='utf-8') as f: @@ -176,7 +219,7 @@ class DownloadManager: temp_path = temp_file.name # Download the original image to temp path - if await self.civitai_client.download_preview_image(images[0]['url'], temp_path): + if await civitai_client.download_preview_image(images[0]['url'], temp_path): # Optimize and convert to WebP preview_path = os.path.splitext(save_path)[0] + '.webp' @@ -210,7 +253,7 @@ class DownloadManager: await 
progress_callback(3) # 3% progress after preview download # Download model file with progress tracking - success, result = await self.civitai_client._download_file( + success, result = await civitai_client._download_file( download_url, save_dir, os.path.basename(save_path), @@ -232,13 +275,14 @@ class DownloadManager: json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) # 6. Update cache based on model type - if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"): - cache = await self.file_monitor.checkpoint_scanner.get_cached_data() + if model_type == "checkpoint": + scanner = await self._get_checkpoint_scanner() logger.info(f"Updating checkpoint cache for {save_path}") else: - cache = await self.file_monitor.scanner.get_cached_data() + scanner = await self._get_lora_scanner() logger.info(f"Updating lora cache for {save_path}") + cache = await scanner.get_cached_data() metadata_dict = metadata.to_dict() metadata_dict['folder'] = relative_path cache.raw_data.append(metadata_dict) @@ -248,10 +292,7 @@ class DownloadManager: cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) # Update the hash index with the new model entry - if model_type == "checkpoint" and hasattr(self.file_monitor, "checkpoint_scanner"): - self.file_monitor.checkpoint_scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) - else: - self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) + scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) # Report 100% completion if progress_callback: diff --git a/py/services/file_monitor.py b/py/services/file_monitor.py index 1b9438f3..e702c98d 100644 --- a/py/services/file_monitor.py +++ b/py/services/file_monitor.py @@ -1,39 +1,39 @@ -from operator import itemgetter import os import logging import asyncio import time from watchdog.observers import Observer from watchdog.events import 
FileSystemEventHandler -from typing import List, Dict, Set +from typing import List, Dict, Set, Optional from threading import Lock -from .checkpoint_scanner import CheckpointScanner -from .lora_scanner import LoraScanner from ..config import config +from .service_registry import ServiceRegistry logger = logging.getLogger(__name__) -class LoraFileHandler(FileSystemEventHandler): - """Handler for LoRA file system events""" +class BaseFileHandler(FileSystemEventHandler): + """Base handler for file system events""" - def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop): - self.scanner = scanner - self.loop = loop # 存储事件循环引用 - self.pending_changes = set() # 待处理的变更 - self.lock = Lock() # 线程安全锁 - self.update_task = None # 异步更新任务 - self._ignore_paths = set() # Add ignore paths set - self._min_ignore_timeout = 5 # minimum timeout in seconds - self._download_speed = 1024 * 1024 # assume 1MB/s as base speed + def __init__(self, loop: asyncio.AbstractEventLoop): + self.loop = loop # Store event loop reference + self.pending_changes = set() # Pending changes + self.lock = Lock() # Thread-safe lock + self.update_task = None # Async update task + self._ignore_paths = set() # Paths to ignore + self._min_ignore_timeout = 5 # Minimum timeout in seconds + self._download_speed = 1024 * 1024 # Assume 1MB/s as base speed # Track modified files with timestamps for debouncing self.modified_files: Dict[str, float] = {} self.debounce_timer = None - self.debounce_delay = 3.0 # seconds to wait after last modification + self.debounce_delay = 3.0 # Seconds to wait after last modification - # Track files that are already scheduled for processing + # Track files already scheduled for processing self.scheduled_files: Set[str] = set() + + # File extensions to monitor - should be overridden by subclasses + self.file_extensions = set() def _should_ignore(self, path: str) -> bool: """Check if path should be ignored""" @@ -58,35 +58,33 @@ class 
LoraFileHandler(FileSystemEventHandler): if event.is_directory: return - # Handle safetensors files directly - if event.src_path.endswith('.safetensors'): + # Handle appropriate files based on extensions + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext in self.file_extensions: if self._should_ignore(event.src_path): return - # We'll process this file directly and ignore subsequent modifications - # to prevent duplicate processing + # Process this file directly and ignore subsequent modifications normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/') if normalized_path not in self.scheduled_files: - logger.info(f"LoRA file created: {event.src_path}") + logger.info(f"File created: {event.src_path}") self.scheduled_files.add(normalized_path) self._schedule_update('add', event.src_path) # Ignore modifications for a short period after creation - # This helps avoid duplicate processing self.loop.call_later( self.debounce_delay * 2, self.scheduled_files.discard, normalized_path ) - # For browser downloads, we'll catch them when they're renamed to .safetensors - def on_modified(self, event): if event.is_directory: return - # Only process safetensors files - if event.src_path.endswith('.safetensors'): + # Only process files with supported extensions + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext in self.file_extensions: if self._should_ignore(event.src_path): return @@ -134,12 +132,17 @@ class LoraFileHandler(FileSystemEventHandler): # Process stable files for file_path in files_to_process: - logger.info(f"Processing modified LoRA file: {file_path}") + logger.info(f"Processing modified file: {file_path}") self._schedule_update('add', file_path) def on_deleted(self, event): - if event.is_directory or not event.src_path.endswith('.safetensors'): + if event.is_directory: return + + file_ext = os.path.splitext(event.src_path)[1].lower() + if file_ext not in self.file_extensions: + return + if 
self._should_ignore(event.src_path): return @@ -147,14 +150,17 @@ class LoraFileHandler(FileSystemEventHandler): normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/') self.scheduled_files.discard(normalized_path) - logger.info(f"LoRA file deleted: {event.src_path}") + logger.info(f"File deleted: {event.src_path}") self._schedule_update('remove', event.src_path) def on_moved(self, event): """Handle file move/rename events""" - # If destination is a safetensors file, treat it as a new file - if event.dest_path.endswith('.safetensors'): + src_ext = os.path.splitext(event.src_path)[1].lower() + dest_ext = os.path.splitext(event.dest_path)[1].lower() + + # If destination has supported extension, treat as new file + if dest_ext in self.file_extensions: if self._should_ignore(event.dest_path): return @@ -162,7 +168,7 @@ class LoraFileHandler(FileSystemEventHandler): # Only process if not already scheduled if normalized_path not in self.scheduled_files: - logger.info(f"LoRA file renamed/moved to: {event.dest_path}") + logger.info(f"File renamed/moved to: {event.dest_path}") self.scheduled_files.add(normalized_path) self._schedule_update('add', event.dest_path) @@ -173,21 +179,21 @@ class LoraFileHandler(FileSystemEventHandler): normalized_path ) - # If source was a safetensors file, treat it as deleted - if event.src_path.endswith('.safetensors'): + # If source was a supported file, treat it as deleted + if src_ext in self.file_extensions: if self._should_ignore(event.src_path): return normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/') self.scheduled_files.discard(normalized_path) - logger.info(f"LoRA file moved/renamed from: {event.src_path}") + logger.info(f"File moved/renamed from: {event.src_path}") self._schedule_update('remove', event.src_path) - def _schedule_update(self, action: str, file_path: str): #file_path is a real path + def _schedule_update(self, action: str, file_path: str): """Schedule a cache update""" with 
self.lock: - # 使用 config 中的方法映射路径 + # Use config method to map path mapped_path = config.map_path_to_link(file_path) normalized_path = mapped_path.replace(os.sep, '/') self.pending_changes.add((action, normalized_path)) @@ -198,7 +204,20 @@ class LoraFileHandler(FileSystemEventHandler): """Create update task in the event loop""" if self.update_task is None or self.update_task.done(): self.update_task = asyncio.create_task(self._process_changes()) + + async def _process_changes(self, delay: float = 2.0): + """Process pending changes with debouncing - should be implemented by subclasses""" + raise NotImplementedError("Subclasses must implement _process_changes") + +class LoraFileHandler(BaseFileHandler): + """Handler for LoRA file system events""" + + def __init__(self, loop: asyncio.AbstractEventLoop): + super().__init__(loop) + # Set supported file extensions for LoRAs + self.file_extensions = {'.safetensors'} + async def _process_changes(self, delay: float = 2.0): """Process pending changes with debouncing""" await asyncio.sleep(delay) @@ -211,9 +230,11 @@ class LoraFileHandler(FileSystemEventHandler): if not changes: return - logger.info(f"Processing {len(changes)} file changes") + logger.info(f"Processing {len(changes)} LoRA file changes") - cache = await self.scanner.get_cached_data() + # Get scanner through ServiceRegistry + scanner = await ServiceRegistry.get_lora_scanner() + cache = await scanner.get_cached_data() needs_resort = False new_folders = set() @@ -227,36 +248,36 @@ class LoraFileHandler(FileSystemEventHandler): continue # Scan new file - lora_data = await self.scanner.scan_single_lora(file_path) - if lora_data: + model_data = await scanner.scan_single_model(file_path) + if model_data: # Update tags count - for tag in lora_data.get('tags', []): - self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1 + for tag in model_data.get('tags', []): + scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1 - 
cache.raw_data.append(lora_data) - new_folders.add(lora_data['folder']) + cache.raw_data.append(model_data) + new_folders.add(model_data['folder']) # Update hash index - if 'sha256' in lora_data: - self.scanner._hash_index.add_entry( - lora_data['sha256'], - lora_data['file_path'] + if 'sha256' in model_data: + scanner._hash_index.add_entry( + model_data['sha256'], + model_data['file_path'] ) needs_resort = True elif action == 'remove': - # Find the lora to remove so we can update tags count - lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None) - if lora_to_remove: + # Find the model to remove so we can update tags count + model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None) + if model_to_remove: # Update tags count by reducing counts - for tag in lora_to_remove.get('tags', []): - if tag in self.scanner._tags_count: - self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1) - if self.scanner._tags_count[tag] == 0: - del self.scanner._tags_count[tag] + for tag in model_to_remove.get('tags', []): + if tag in scanner._tags_count: + scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1) + if scanner._tags_count[tag] == 0: + del scanner._tags_count[tag] # Remove from cache and hash index logger.info(f"Removing {file_path} from cache") - self.scanner._hash_index.remove_by_path(file_path) + scanner._hash_index.remove_by_path(file_path) cache.raw_data = [ item for item in cache.raw_data if item['file_path'] != file_path @@ -274,59 +295,140 @@ class LoraFileHandler(FileSystemEventHandler): cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) except Exception as e: - logger.error(f"Error in process_changes: {e}") + logger.error(f"Error in process_changes for LoRA: {e}") -class LoraFileMonitor: - """Monitor for LoRA file changes""" +class CheckpointFileHandler(BaseFileHandler): + """Handler for checkpoint file system events""" - def __init__(self, 
scanner: LoraScanner, roots: List[str]): - self.scanner = scanner - scanner.set_file_monitor(self) + def __init__(self, loop: asyncio.AbstractEventLoop): + super().__init__(loop) + # Set supported file extensions for checkpoints + self.file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'} + + async def _process_changes(self, delay: float = 2.0): + """Process pending changes with debouncing for checkpoint files""" + await asyncio.sleep(delay) + + try: + with self.lock: + changes = self.pending_changes.copy() + self.pending_changes.clear() + + if not changes: + return + + logger.info(f"Processing {len(changes)} checkpoint file changes") + + # Get scanner through ServiceRegistry + scanner = await ServiceRegistry.get_checkpoint_scanner() + cache = await scanner.get_cached_data() + needs_resort = False + new_folders = set() + + for action, file_path in changes: + try: + if action == 'add': + # Check if file already exists in cache + existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None) + if existing: + logger.info(f"File {file_path} already in cache, skipping") + continue + + # Scan new file + model_data = await scanner.scan_single_model(file_path) + if model_data: + # Update tags count if applicable + for tag in model_data.get('tags', []): + scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1 + + cache.raw_data.append(model_data) + new_folders.add(model_data['folder']) + # Update hash index + if 'sha256' in model_data: + scanner._hash_index.add_entry( + model_data['sha256'], + model_data['file_path'] + ) + needs_resort = True + + elif action == 'remove': + # Find the model to remove so we can update tags count + model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None) + if model_to_remove: + # Update tags count by reducing counts + for tag in model_to_remove.get('tags', []): + if tag in scanner._tags_count: + scanner._tags_count[tag] = max(0, 
scanner._tags_count[tag] - 1) + if scanner._tags_count[tag] == 0: + del scanner._tags_count[tag] + + # Remove from cache and hash index + logger.info(f"Removing {file_path} from checkpoint cache") + scanner._hash_index.remove_by_path(file_path) + cache.raw_data = [ + item for item in cache.raw_data + if item['file_path'] != file_path + ] + needs_resort = True + + except Exception as e: + logger.error(f"Error processing checkpoint {action} for {file_path}: {e}") + + if needs_resort: + await cache.resort() + + # Update folder list + all_folders = set(cache.folders) | new_folders + cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) + + except Exception as e: + logger.error(f"Error in process_changes for checkpoint: {e}") + + +class BaseFileMonitor: + """Base class for file monitoring""" + + def __init__(self, monitor_paths: List[str]): self.observer = Observer() self.loop = asyncio.get_event_loop() - self.handler = LoraFileHandler(scanner, self.loop) - - # 使用已存在的路径映射 self.monitor_paths = set() - for root in roots: - self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/')) + + # Process monitor paths + for path in monitor_paths: + self.monitor_paths.add(os.path.realpath(path).replace(os.sep, '/')) - # 添加所有已映射的目标路径 + # Add mapped paths from config for target_path in config._path_mappings.keys(): self.monitor_paths.add(target_path) - + def start(self): - """Start monitoring""" - for path_info in self.monitor_paths: + """Start file monitoring""" + for path in self.monitor_paths: try: - if isinstance(path_info, tuple): - # 对于链接,监控目标路径 - _, target_path = path_info - self.observer.schedule(self.handler, target_path, recursive=True) - logger.info(f"Started monitoring target path: {target_path}") - else: - # 对于普通路径,直接监控 - self.observer.schedule(self.handler, path_info, recursive=True) - logger.info(f"Started monitoring: {path_info}") + self.observer.schedule(self.handler, path, recursive=True) + logger.info(f"Started monitoring: {path}") except 
Exception as e: - logger.error(f"Error monitoring {path_info}: {e}") + logger.error(f"Error monitoring {path}: {e}") self.observer.start() - + def stop(self): - """Stop monitoring""" + """Stop file monitoring""" self.observer.stop() self.observer.join() - + def rescan_links(self): - """重新扫描链接(当添加新的链接时调用)""" + """Rescan links when new ones are added""" + # Find new paths not yet being monitored new_paths = set() - for path in self.monitor_paths.copy(): - self._add_link_targets(path) + for path in config._path_mappings.keys(): + real_path = os.path.realpath(path).replace(os.sep, '/') + if real_path not in self.monitor_paths: + new_paths.add(real_path) + self.monitor_paths.add(real_path) - # 添加新发现的路径到监控 - new_paths = self.monitor_paths - set(self.observer.watches.keys()) + # Add new paths to monitoring for path in new_paths: try: self.observer.schedule(self.handler, path, recursive=True) @@ -334,88 +436,86 @@ class LoraFileMonitor: except Exception as e: logger.error(f"Error adding new monitor for {path}: {e}") -# Add CheckpointFileMonitor class -class CheckpointFileMonitor(LoraFileMonitor): +class LoraFileMonitor(BaseFileMonitor): + """Monitor for LoRA file changes""" + + _instance = None + _lock = asyncio.Lock() + + def __new__(cls, monitor_paths=None): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, monitor_paths=None): + if not hasattr(self, '_initialized'): + if monitor_paths is None: + from ..config import config + monitor_paths = config.loras_roots + + super().__init__(monitor_paths) + self.handler = LoraFileHandler(self.loop) + self._initialized = True + + @classmethod + async def get_instance(cls): + """Get singleton instance with async support""" + async with cls._lock: + if cls._instance is None: + from ..config import config + cls._instance = cls(config.loras_roots) + return cls._instance + + +class CheckpointFileMonitor(BaseFileMonitor): """Monitor for checkpoint file changes""" - def 
__init__(self, scanner: CheckpointScanner, roots: List[str]): - # Reuse most of the LoraFileMonitor functionality, but with a different handler - self.scanner = scanner - scanner.set_file_monitor(self) - self.observer = Observer() - self.loop = asyncio.get_event_loop() - self.handler = CheckpointFileHandler(scanner, self.loop) - - # Use existing path mappings - self.monitor_paths = set() - for root in roots: - self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/')) - - # Add all mapped target paths - for target_path in config._path_mappings.keys(): - self.monitor_paths.add(target_path) - -class CheckpointFileHandler(LoraFileHandler): - """Handler for checkpoint file system events""" + _instance = None + _lock = asyncio.Lock() - def __init__(self, scanner: CheckpointScanner, loop: asyncio.AbstractEventLoop): - super().__init__(scanner, loop) - # Configure supported file extensions - self.supported_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'} - - def on_created(self, event): - if event.is_directory: - return - - # Handle supported file extensions directly - file_ext = os.path.splitext(event.src_path)[1].lower() - if file_ext in self.supported_extensions: - if self._should_ignore(event.src_path): - return - - # Process this file directly - normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/') - if normalized_path not in self.scheduled_files: - logger.info(f"Checkpoint file created: {event.src_path}") - self.scheduled_files.add(normalized_path) - self._schedule_update('add', event.src_path) - - # Ignore modifications for a short period after creation - self.loop.call_later( - self.debounce_delay * 2, - self.scheduled_files.discard, - normalized_path - ) + def __new__(cls, monitor_paths=None): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance - def on_modified(self, event): - if event.is_directory: - return - - # Only process supported file types - file_ext = 
os.path.splitext(event.src_path)[1].lower() - if file_ext in self.supported_extensions: - super().on_modified(event) + def __init__(self, monitor_paths=None): + if not hasattr(self, '_initialized'): + if monitor_paths is None: + # Get checkpoint roots from scanner + monitor_paths = [] + # We'll initialize monitor paths later when scanner is available - def on_deleted(self, event): - if event.is_directory: - return + super().__init__(monitor_paths or []) + self.handler = CheckpointFileHandler(self.loop) + self._initialized = True + + @classmethod + async def get_instance(cls): + """Get singleton instance with async support""" + async with cls._lock: + if cls._instance is None: + cls._instance = cls([]) + + # Now get checkpoint roots from scanner + from .checkpoint_scanner import CheckpointScanner + scanner = await CheckpointScanner.get_instance() + monitor_paths = scanner.get_model_roots() + + # Update monitor paths + for path in monitor_paths: + real_path = os.path.realpath(path).replace(os.sep, '/') + cls._instance.monitor_paths.add(real_path) + + return cls._instance + + async def initialize_paths(self): + """Initialize monitor paths from scanner""" + if not self.monitor_paths: + scanner = await ServiceRegistry.get_checkpoint_scanner() + monitor_paths = scanner.get_model_roots() - file_ext = os.path.splitext(event.src_path)[1].lower() - if file_ext not in self.supported_extensions: - return - - super().on_deleted(event) - - def on_moved(self, event): - """Handle file move/rename events""" - src_ext = os.path.splitext(event.src_path)[1].lower() - dest_ext = os.path.splitext(event.dest_path)[1].lower() - - # If destination has supported extension, treat as new file - if dest_ext in self.supported_extensions: - super().on_moved(event) - - # If source was supported extension, treat as deleted - elif src_ext in self.supported_extensions: - super().on_moved(event) \ No newline at end of file + # Update monitor paths + for path in monitor_paths: + real_path = 
os.path.realpath(path).replace(os.sep, '/') + self.monitor_paths.add(real_path) \ No newline at end of file diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index a57c44e0..29908ef9 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -13,6 +13,7 @@ from .lora_hash_index import LoraHashIndex from .settings_manager import settings from ..utils.constants import NSFW_LEVELS from ..utils.utils import fuzzy_match +from .service_registry import ServiceRegistry import sys logger = logging.getLogger(__name__) diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 04d0cb2a..edba07f6 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -12,13 +12,13 @@ from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, from .model_cache import ModelCache from .model_hash_index import ModelHashIndex from ..utils.constants import PREVIEW_EXTENSIONS +from .service_registry import ServiceRegistry logger = logging.getLogger(__name__) class ModelScanner: """Base service for scanning and managing model files""" - _instance = None _lock = asyncio.Lock() def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None): @@ -35,14 +35,17 @@ class ModelScanner: self.file_extensions = file_extensions self._cache = None self._hash_index = hash_index or ModelHashIndex() - self.file_monitor = None self._tags_count = {} # Dictionary to store tag counts self._is_initializing = False # Flag to track initialization state + + # Register this service + asyncio.create_task(self._register_service()) + + async def _register_service(self): + """Register this instance with the ServiceRegistry""" + service_name = f"{self.model_type}_scanner" + await ServiceRegistry.register_service(service_name, self) - def set_file_monitor(self, monitor): - """Set file monitor instance""" - self.file_monitor = monitor - async def 
initialize_in_background(self) -> None: """Initialize cache in background using thread pool""" try: @@ -366,12 +369,20 @@ class ModelScanner: file_size = os.path.getsize(real_source) - if self.file_monitor: - self.file_monitor.handler.add_ignore_path( + # Get the appropriate file monitor through ServiceRegistry + if self.model_type == "lora": + monitor = await ServiceRegistry.get_lora_monitor() + elif self.model_type == "checkpoint": + monitor = await ServiceRegistry.get_checkpoint_monitor() + else: + monitor = None + + if monitor: + monitor.handler.add_ignore_path( real_source, file_size ) - self.file_monitor.handler.add_ignore_path( + monitor.handler.add_ignore_path( real_target, file_size ) diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index 588e4e43..5f3e3ffd 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -5,8 +5,8 @@ import json from typing import List, Dict, Optional, Any, Tuple from ..config import config from .recipe_cache import RecipeCache +from .service_registry import ServiceRegistry from .lora_scanner import LoraScanner -from .civitai_client import CivitaiClient from ..utils.utils import fuzzy_match import sys @@ -18,11 +18,22 @@ class RecipeScanner: _instance = None _lock = asyncio.Lock() + @classmethod + async def get_instance(cls, lora_scanner: Optional[LoraScanner] = None): + """Get singleton instance of RecipeScanner""" + async with cls._lock: + if cls._instance is None: + if not lora_scanner: + # Get lora scanner from service registry if not provided + lora_scanner = await ServiceRegistry.get_lora_scanner() + cls._instance = cls(lora_scanner) + return cls._instance + def __new__(cls, lora_scanner: Optional[LoraScanner] = None): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._lora_scanner = lora_scanner - cls._instance._civitai_client = CivitaiClient() + cls._instance._civitai_client = None # Will be lazily initialized return cls._instance def 
__init__(self, lora_scanner: Optional[LoraScanner] = None): @@ -36,6 +47,12 @@ class RecipeScanner: self._lora_scanner = lora_scanner self._initialized = True + async def _get_civitai_client(self): + """Lazily initialize CivitaiClient from registry""" + if self._civitai_client is None: + self._civitai_client = await ServiceRegistry.get_civitai_client() + return self._civitai_client + async def initialize_in_background(self) -> None: """Initialize cache in background using thread pool""" try: @@ -306,10 +323,13 @@ class RecipeScanner: async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]: """Get hash from Civitai API""" try: - if not self._civitai_client: + # Get CivitaiClient from ServiceRegistry + civitai_client = await self._get_civitai_client() + if not civitai_client: + logger.error("Failed to get CivitaiClient from ServiceRegistry") return None - version_info = await self._civitai_client.get_model_version_info(model_version_id) + version_info = await civitai_client.get_model_version_info(model_version_id) if not version_info or not version_info.get('files'): logger.debug(f"No files found in version info for ID: {model_version_id}") @@ -329,10 +349,12 @@ class RecipeScanner: async def _get_model_version_name(self, model_version_id: str) -> Optional[str]: """Get model version name from Civitai API""" try: - if not self._civitai_client: + # Get CivitaiClient from ServiceRegistry + civitai_client = await self._get_civitai_client() + if not civitai_client: return None - version_info = await self._civitai_client.get_model_version_info(model_version_id) + version_info = await civitai_client.get_model_version_info(model_version_id) if version_info and 'name' in version_info: return version_info['name'] diff --git a/py/services/service_registry.py b/py/services/service_registry.py new file mode 100644 index 00000000..17f20ba7 --- /dev/null +++ b/py/services/service_registry.py @@ -0,0 +1,124 @@ +import asyncio +import logging +from typing 
import asyncio
import logging
from typing import Optional, Dict, Any, TypeVar, Type, Callable, Awaitable

logger = logging.getLogger(__name__)

T = TypeVar('T')  # Reserved for future generic typing of the getters


class ServiceRegistry:
    """Centralized registry for service singletons.

    All state lives at class level; access goes through classmethods and is
    serialized with an asyncio.Lock so concurrent initializers cannot race
    on registration or lookup. Service modules are imported lazily inside
    each getter to avoid import cycles at module load time.
    """

    _instance = None
    _services: Dict[str, Any] = {}
    _lock = asyncio.Lock()

    @classmethod
    def get_instance(cls):
        """Get singleton instance of the registry."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    async def register_service(cls, service_name: str, service_instance: Any) -> None:
        """Register a service instance with the registry under *service_name*.

        Overwrites any previously registered instance with the same name.
        """
        registry = cls.get_instance()
        async with cls._lock:
            registry._services[service_name] = service_instance
            logger.debug(f"Registered service: {service_name}")

    @classmethod
    async def get_service(cls, service_name: str) -> Any:
        """Return the service registered under *service_name*, or None.

        A miss is the normal first step of every lazy getter below, so it is
        logged at debug level — warning here would spam the log on startup.
        """
        registry = cls.get_instance()
        async with cls._lock:
            if service_name not in registry._services:
                logger.debug(f"Service {service_name} not found in registry")
                return None
            return registry._services[service_name]

    @classmethod
    async def _get_or_create(cls, service_name: str,
                             factory: Callable[[], Awaitable[Any]]) -> Any:
        """Return the registered service, creating it via *factory* on first use.

        Shared implementation for all the convenience getters: look up the
        name, and on a miss await the factory coroutine and register its
        result before returning it.
        """
        service = await cls.get_service(service_name)
        if service is None:
            service = await factory()
            await cls.register_service(service_name, service)
        return service

    # Convenience methods for common services

    @classmethod
    async def get_lora_scanner(cls):
        """Get the LoraScanner instance."""
        from .lora_scanner import LoraScanner
        return await cls._get_or_create("lora_scanner", LoraScanner.get_instance)

    @classmethod
    async def get_checkpoint_scanner(cls):
        """Get the CheckpointScanner instance."""
        from .checkpoint_scanner import CheckpointScanner
        return await cls._get_or_create("checkpoint_scanner", CheckpointScanner.get_instance)

    @classmethod
    async def get_lora_monitor(cls):
        """Get the LoraFileMonitor instance."""
        from .file_monitor import LoraFileMonitor
        return await cls._get_or_create("lora_monitor", LoraFileMonitor.get_instance)

    @classmethod
    async def get_checkpoint_monitor(cls):
        """Get the CheckpointFileMonitor instance."""
        from .file_monitor import CheckpointFileMonitor
        return await cls._get_or_create("checkpoint_monitor", CheckpointFileMonitor.get_instance)

    @classmethod
    async def get_civitai_client(cls):
        """Get the CivitaiClient instance."""
        from .civitai_client import CivitaiClient
        return await cls._get_or_create("civitai_client", CivitaiClient.get_instance)

    @classmethod
    async def get_download_manager(cls):
        """Get the DownloadManager instance.

        DownloadManager.get_instance handles its own file_monitor parameter,
        so no extra wiring is needed here.
        """
        from .download_manager import DownloadManager
        return await cls._get_or_create("download_manager", DownloadManager.get_instance)

    @classmethod
    async def get_recipe_scanner(cls):
        """Get the RecipeScanner instance (constructed over the LoraScanner)."""
        from .recipe_scanner import RecipeScanner

        async def _factory():
            lora_scanner = await cls.get_lora_scanner()
            return RecipeScanner(lora_scanner)

        return await cls._get_or_create("recipe_scanner", _factory)

    @classmethod
    async def get_websocket_manager(cls):
        """Get the WebSocketManager instance.

        ws_manager is already a module-level singleton in websocket_manager.py,
        so the factory just hands back that global for registration.
        """
        from .websocket_manager import ws_manager

        async def _factory():
            return ws_manager

        return await cls._get_or_create("websocket_manager", _factory)