From bf7b07ba745b2104fe062776967c07ed2cb16cfd Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 4 Aug 2025 10:48:48 +0800 Subject: [PATCH] feat: deduplicate and merge checkpoint and unet paths in configuration. See #338 and #312 --- py/config.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/py/config.py b/py/config.py index b0f0f059..f915c275 100644 --- a/py/config.py +++ b/py/config.py @@ -204,16 +204,20 @@ class Config: real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') unet_map[real_path] = unet_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen + # Merge both maps and deduplicate by real path + merged_map = {} + for real_path, orig_path in {**checkpoint_map, **unet_map}.items(): + if real_path not in merged_map: + merged_map[real_path] = orig_path + # Now sort and use only the deduplicated real paths - unique_checkpoint_paths = sorted(checkpoint_map.values(), key=lambda p: p.lower()) - unique_unet_paths = sorted(unet_map.values(), key=lambda p: p.lower()) + unique_paths = sorted(merged_map.values(), key=lambda p: p.lower()) - # Store individual paths in class properties - self.checkpoints_roots = unique_checkpoint_paths - self.unet_roots = unique_unet_paths + # Split back into checkpoints and unet roots for class properties + self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_map.values()] + self.unet_roots = [p for p in unique_paths if p in unet_map.values()] - # Combine all checkpoint-related paths for return value - all_paths = unique_checkpoint_paths + unique_unet_paths + all_paths = unique_paths logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]"))