From 40ad5900464bbd37e0b2b4f3508628e74e110dbc Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Jul 2025 21:29:41 +0800 Subject: [PATCH] refactor: Update checkpoint handling to use base_models_roots and streamline path management --- py/config.py | 84 +++++++++++++++---------------- py/lora_manager.py | 6 +-- py/services/checkpoint_scanner.py | 23 +-------- standalone.py | 6 +-- 4 files changed, 50 insertions(+), 69 deletions(-) diff --git a/py/config.py b/py/config.py index 3bc7a930..3782abfa 100644 --- a/py/config.py +++ b/py/config.py @@ -22,7 +22,9 @@ class Config: # 静态路由映射字典, target to route mapping self._route_mappings = {} self.loras_roots = self._init_lora_paths() - self.checkpoints_roots = self._init_checkpoint_paths() + self.checkpoints_roots = None + self.unet_roots = None + self.base_models_roots = self._init_checkpoint_paths() # 在初始化时扫描符号链接 self._scan_symbolic_links() @@ -33,34 +35,26 @@ class Config: def save_folder_paths_to_settings(self): """Save folder paths to settings.json for standalone mode to use later""" try: - # Check if we're running in ComfyUI mode (not standalone) - if hasattr(folder_paths, "get_folder_paths") and not isinstance(folder_paths, type): - # Get all relevant paths - lora_paths = folder_paths.get_folder_paths("loras") - checkpoint_paths = folder_paths.get_folder_paths("checkpoints") - diffuser_paths = folder_paths.get_folder_paths("diffusers") - unet_paths = folder_paths.get_folder_paths("unet") - - # Load existing settings - settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json') - settings = {} - if os.path.exists(settings_path): - with open(settings_path, 'r', encoding='utf-8') as f: - settings = json.load(f) - - # Update settings with paths - settings['folder_paths'] = { - 'loras': lora_paths, - 'checkpoints': checkpoint_paths, - 'diffusers': diffuser_paths, - 'unet': unet_paths - } - - # Save settings - with open(settings_path, 'w', encoding='utf-8') as f: - json.dump(settings, f, indent=2) - - logger.info("Saved folder paths to settings.json") + # Check if we're running in ComfyUI mode (not standalone) + # Load existing settings + settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json') + settings = {} + if os.path.exists(settings_path): + with open(settings_path, 'r', encoding='utf-8') as f: + settings = json.load(f) + + # Update settings with paths + settings['folder_paths'] = { + 'loras': self.loras_roots, + 'checkpoints': self.checkpoints_roots, + 'unet': self.unet_roots, + } + + # Save settings + with open(settings_path, 'w', encoding='utf-8') as f: + json.dump(settings, f, indent=2) + + logger.info("Saved folder paths to settings.json") except Exception as e: logger.warning(f"Failed to save folder paths: {e}") @@ -86,7 +80,7 @@ class Config: for root in self.loras_roots: self._scan_directory_links(root) - for root in self.checkpoints_roots: + for root in self.base_models_roots: self._scan_directory_links(root) def _scan_directory_links(self, root: str): @@ -178,30 +172,36 @@ class Config: try: # Get checkpoint paths from folder_paths checkpoint_paths = folder_paths.get_folder_paths("checkpoints") - diffusion_paths = folder_paths.get_folder_paths("diffusers") unet_paths = folder_paths.get_folder_paths("unet") - # Combine all checkpoint-related paths - all_paths = checkpoint_paths + diffusion_paths + unet_paths - - # Filter and normalize paths - paths = sorted(set(path.replace(os.sep, "/") - for path in all_paths + # Sort each list individually + checkpoint_paths = sorted(set(path.replace(os.sep, "/") + for path in checkpoint_paths if os.path.exists(path)), key=lambda p: p.lower()) - logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(paths) if paths else "[]")) + unet_paths = sorted(set(path.replace(os.sep, "/") + for path in unet_paths + if os.path.exists(path)), key=lambda p: p.lower()) - if not paths: + # Combine all checkpoint-related paths, ensuring checkpoint_paths are first + all_paths = checkpoint_paths + unet_paths + + self.checkpoints_roots = checkpoint_paths + self.unet_roots = unet_paths + + logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]")) + + if not all_paths: logger.warning("No valid checkpoint folders found in ComfyUI configuration") return [] - # 初始化路径映射,与 LoRA 路径处理方式相同 - for path in paths: + # Initialize path mappings, similar to LoRA path handling + for path in all_paths: real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') if real_path != path: self.add_path_mapping(path, real_path) - return paths + return all_paths except Exception as e: logger.warning(f"Error initializing checkpoint paths: {e}") return [] diff --git a/py/lora_manager.py b/py/lora_manager.py index 51d1613a..02d6b5db 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -62,7 +62,7 @@ class LoraManager: added_targets.add(real_root) # Add static routes for each checkpoint root - for idx, root in enumerate(config.checkpoints_roots, start=1): + for idx, root in enumerate(config.base_models_roots, start=1): preview_path = f'/checkpoints_static/root{idx}/preview' real_root = root @@ -88,8 +88,8 @@ class LoraManager: for target_path, link_path in config._path_mappings.items(): if target_path not in added_targets: # Determine if this is a checkpoint or lora link based on path - is_checkpoint = any(cp_root in link_path for cp_root in config.checkpoints_roots) - is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.checkpoints_roots) + is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots) + is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots) if is_checkpoint: route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview' diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index c895cb25..26733ab5 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -33,7 +33,6 @@ class CheckpointScanner(ModelScanner): file_extensions=file_extensions, hash_index=ModelHashIndex() ) - self._checkpoint_roots = self._init_checkpoint_roots() self._initialized = True @classmethod @@ -44,27 +43,9 @@ class CheckpointScanner(ModelScanner): cls._instance = cls() return cls._instance - def _init_checkpoint_roots(self) -> List[str]: - """Initialize checkpoint roots from ComfyUI settings""" - # Get both checkpoint and diffusion_models paths - checkpoint_paths = folder_paths.get_folder_paths("checkpoints") - diffusion_paths = folder_paths.get_folder_paths("diffusion_models") - - # Combine, normalize and deduplicate paths - all_paths = set() - for path in checkpoint_paths + diffusion_paths: - if os.path.exists(path): - norm_path = path.replace(os.sep, "/") - all_paths.add(norm_path) - - # Sort for consistent order - sorted_paths = sorted(all_paths, key=lambda p: p.lower()) - - return sorted_paths - def get_model_roots(self) -> List[str]: """Get checkpoint root directories""" - return self._checkpoint_roots + return config.base_models_roots async def scan_all_models(self) -> List[Dict]: """Scan all checkpoint directories and return metadata""" @@ -72,7 +53,7 @@ class CheckpointScanner(ModelScanner): # Create scan tasks for each directory scan_tasks = [] - for root in self._checkpoint_roots: + for root in self.get_model_roots(): task = asyncio.create_task(self._scan_directory(root)) scan_tasks.append(task) diff --git a/standalone.py b/standalone.py index b27128ba..0120ab53 100644 --- a/standalone.py +++ b/standalone.py @@ -252,7 +252,7 @@ class StandaloneLoraManager(LoraManager): added_targets.add(os.path.normpath(real_root)) # Add static routes for each checkpoint root - for idx, root in enumerate(config.checkpoints_roots, start=1): + for idx, root in enumerate(config.base_models_roots, start=1): if not os.path.exists(root): logger.warning(f"Checkpoint root path does not exist: {root}") continue @@ -288,8 +288,8 @@ class StandaloneLoraManager(LoraManager): norm_target = os.path.normpath(target_path) if norm_target not in added_targets: # Determine if this is a checkpoint or lora link based on path - is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.checkpoints_roots) - is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.checkpoints_roots) + is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.base_models_roots) + is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.base_models_roots) if is_checkpoint: route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'