refactor: Update checkpoint handling to use base_models_roots and streamline path management

This commit is contained in:
Will Miao
2025-07-02 21:29:41 +08:00
parent 30374ae3e6
commit 40ad590046
4 changed files with 50 additions and 69 deletions

View File

@@ -22,7 +22,9 @@ class Config:
# 静态路由映射字典, target to route mapping
self._route_mappings = {}
self.loras_roots = self._init_lora_paths()
self.checkpoints_roots = self._init_checkpoint_paths()
self.checkpoints_roots = None
self.unet_roots = None
self.base_models_roots = self._init_checkpoint_paths()
# 在初始化时扫描符号链接
self._scan_symbolic_links()
@@ -33,34 +35,26 @@ class Config:
def save_folder_paths_to_settings(self):
"""Save folder paths to settings.json for standalone mode to use later"""
try:
# Check if we're running in ComfyUI mode (not standalone)
if hasattr(folder_paths, "get_folder_paths") and not isinstance(folder_paths, type):
# Get all relevant paths
lora_paths = folder_paths.get_folder_paths("loras")
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffuser_paths = folder_paths.get_folder_paths("diffusers")
unet_paths = folder_paths.get_folder_paths("unet")
# Load existing settings
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
settings = {}
if os.path.exists(settings_path):
with open(settings_path, 'r', encoding='utf-8') as f:
settings = json.load(f)
# Update settings with paths
settings['folder_paths'] = {
'loras': lora_paths,
'checkpoints': checkpoint_paths,
'diffusers': diffuser_paths,
'unet': unet_paths
}
# Save settings
with open(settings_path, 'w', encoding='utf-8') as f:
json.dump(settings, f, indent=2)
logger.info("Saved folder paths to settings.json")
# Check if we're running in ComfyUI mode (not standalone)
# Load existing settings
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
settings = {}
if os.path.exists(settings_path):
with open(settings_path, 'r', encoding='utf-8') as f:
settings = json.load(f)
# Update settings with paths
settings['folder_paths'] = {
'loras': self.loras_roots,
'checkpoints': self.checkpoints_roots,
'unet': self.unet_roots,
}
# Save settings
with open(settings_path, 'w', encoding='utf-8') as f:
json.dump(settings, f, indent=2)
logger.info("Saved folder paths to settings.json")
except Exception as e:
logger.warning(f"Failed to save folder paths: {e}")
@@ -86,7 +80,7 @@ class Config:
for root in self.loras_roots:
self._scan_directory_links(root)
for root in self.checkpoints_roots:
for root in self.base_models_roots:
self._scan_directory_links(root)
def _scan_directory_links(self, root: str):
@@ -178,30 +172,36 @@ class Config:
try:
# Get checkpoint paths from folder_paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusers")
unet_paths = folder_paths.get_folder_paths("unet")
# Combine all checkpoint-related paths
all_paths = checkpoint_paths + diffusion_paths + unet_paths
# Filter and normalize paths
paths = sorted(set(path.replace(os.sep, "/")
for path in all_paths
# Sort each list individually
checkpoint_paths = sorted(set(path.replace(os.sep, "/")
for path in checkpoint_paths
if os.path.exists(path)), key=lambda p: p.lower())
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(paths) if paths else "[]"))
unet_paths = sorted(set(path.replace(os.sep, "/")
for path in unet_paths
if os.path.exists(path)), key=lambda p: p.lower())
if not paths:
# Combine all checkpoint-related paths, ensuring checkpoint_paths are first
all_paths = checkpoint_paths + unet_paths
self.checkpoints_roots = checkpoint_paths
self.unet_roots = unet_paths
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]"))
if not all_paths:
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
return []
# 初始化路径映射,与 LoRA 路径处理方式相同
for path in paths:
# Initialize path mappings, similar to LoRA path handling
for path in all_paths:
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
if real_path != path:
self.add_path_mapping(path, real_path)
return paths
return all_paths
except Exception as e:
logger.warning(f"Error initializing checkpoint paths: {e}")
return []

View File

@@ -62,7 +62,7 @@ class LoraManager:
added_targets.add(real_root)
# Add static routes for each checkpoint root
for idx, root in enumerate(config.checkpoints_roots, start=1):
for idx, root in enumerate(config.base_models_roots, start=1):
preview_path = f'/checkpoints_static/root{idx}/preview'
real_root = root
@@ -88,8 +88,8 @@ class LoraManager:
for target_path, link_path in config._path_mappings.items():
if target_path not in added_targets:
# Determine if this is a checkpoint or lora link based on path
is_checkpoint = any(cp_root in link_path for cp_root in config.checkpoints_roots)
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.checkpoints_roots)
is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots)
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots)
if is_checkpoint:
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'

View File

@@ -33,7 +33,6 @@ class CheckpointScanner(ModelScanner):
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
self._checkpoint_roots = self._init_checkpoint_roots()
self._initialized = True
@classmethod
@@ -44,27 +43,9 @@ class CheckpointScanner(ModelScanner):
cls._instance = cls()
return cls._instance
def _init_checkpoint_roots(self) -> List[str]:
"""Initialize checkpoint roots from ComfyUI settings"""
# Get both checkpoint and diffusion_models paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
# Combine, normalize and deduplicate paths
all_paths = set()
for path in checkpoint_paths + diffusion_paths:
if os.path.exists(path):
norm_path = path.replace(os.sep, "/")
all_paths.add(norm_path)
# Sort for consistent order
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
return sorted_paths
def get_model_roots(self) -> List[str]:
"""Get checkpoint root directories"""
return self._checkpoint_roots
return config.base_models_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all checkpoint directories and return metadata"""
@@ -72,7 +53,7 @@ class CheckpointScanner(ModelScanner):
# Create scan tasks for each directory
scan_tasks = []
for root in self._checkpoint_roots:
for root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(root))
scan_tasks.append(task)

View File

@@ -252,7 +252,7 @@ class StandaloneLoraManager(LoraManager):
added_targets.add(os.path.normpath(real_root))
# Add static routes for each checkpoint root
for idx, root in enumerate(config.checkpoints_roots, start=1):
for idx, root in enumerate(config.base_models_roots, start=1):
if not os.path.exists(root):
logger.warning(f"Checkpoint root path does not exist: {root}")
continue
@@ -288,8 +288,8 @@ class StandaloneLoraManager(LoraManager):
norm_target = os.path.normpath(target_path)
if norm_target not in added_targets:
# Determine if this is a checkpoint or lora link based on path
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.checkpoints_roots)
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.checkpoints_roots)
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.base_models_roots)
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.base_models_roots)
if is_checkpoint:
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'