feat: deduplicate and merge checkpoint and unet paths in configuration. See #338 and #312

This commit is contained in:
Will Miao
2025-08-04 10:48:48 +08:00
parent 28fe3e7b7a
commit bf7b07ba74

View File

@@ -204,16 +204,20 @@ class Config:
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
unet_map[real_path] = unet_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
# Merge both maps and deduplicate by real path
merged_map = {}
for real_path, orig_path in {**checkpoint_map, **unet_map}.items():
if real_path not in merged_map:
merged_map[real_path] = orig_path
# Now sort and use only the deduplicated real paths
unique_checkpoint_paths = sorted(checkpoint_map.values(), key=lambda p: p.lower())
unique_unet_paths = sorted(unet_map.values(), key=lambda p: p.lower())
unique_paths = sorted(merged_map.values(), key=lambda p: p.lower())
# Store individual paths in class properties
self.checkpoints_roots = unique_checkpoint_paths
self.unet_roots = unique_unet_paths
# Split back into checkpoints and unet roots for class properties
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_map.values()]
self.unet_roots = [p for p in unique_paths if p in unet_map.values()]
# Combine all checkpoint-related paths for return value
all_paths = unique_checkpoint_paths + unique_unet_paths
all_paths = unique_paths
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]"))