diff --git a/py/config.py b/py/config.py index dbb8909d..368b8f69 100644 --- a/py/config.py +++ b/py/config.py @@ -645,6 +645,23 @@ class Config: checkpoint_map = self._dedupe_existing_paths(checkpoint_paths) unet_map = self._dedupe_existing_paths(unet_paths) + # Detect when checkpoints and unet share the same physical location + # This is a configuration issue that can cause duplicate model entries + overlapping_real_paths = set(checkpoint_map.keys()) & set(unet_map.keys()) + if overlapping_real_paths: + logger.warning( + "Detected overlapping paths between 'checkpoints' and 'diffusion_models' (unet). " + "They should not point to the same physical folder as they are different model types. " + "Please fix your ComfyUI path configuration to separate these folders. " + "Falling back to 'checkpoints' for backward compatibility. " + "Overlapping real paths: %s", + [checkpoint_map.get(rp, rp) for rp in overlapping_real_paths] + ) + # Remove overlapping paths from unet_map to prioritize checkpoints + for rp in overlapping_real_paths: + if rp in unet_map: + del unet_map[rp] + merged_map: Dict[str, str] = {} for real_path, original in {**checkpoint_map, **unet_map}.items(): if real_path not in merged_map: diff --git a/tests/config/test_checkpoint_path_overlap.py b/tests/config/test_checkpoint_path_overlap.py new file mode 100644 index 00000000..2019e785 --- /dev/null +++ b/tests/config/test_checkpoint_path_overlap.py @@ -0,0 +1,160 @@ +"""Tests for checkpoint path overlap detection.""" + +import logging +import os + +import pytest + +from py import config as config_module + + +def _normalize(path: str) -> str: + return os.path.normpath(path).replace(os.sep, "/") + + +class TestCheckpointPathOverlap: + """Test detection of overlapping paths between checkpoints and unet.""" + + def test_overlapping_paths_prioritizes_checkpoints( + self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog + ): + """Test that overlapping paths prioritize checkpoints for backward compatibility.""" + # Create a shared physical folder + shared_dir = tmp_path / "shared_models" + shared_dir.mkdir() + + # Create two symlinks pointing to the same physical folder + checkpoints_link = tmp_path / "checkpoints" + unet_link = tmp_path / "unet" + checkpoints_link.symlink_to(shared_dir, target_is_directory=True) + unet_link.symlink_to(shared_dir, target_is_directory=True) + + # Create Config instance with overlapping paths + with caplog.at_level(logging.WARNING, logger=config_module.logger.name): + config = config_module.Config.__new__(config_module.Config) + config._path_mappings = {} + config._preview_root_paths = set() + config._cached_fingerprint = None + + # Call the method under test + result = config._prepare_checkpoint_paths( + [str(checkpoints_link)], [str(unet_link)] + ) + + # Verify warning was logged + warning_messages = [ + record.message + for record in caplog.records + if record.levelname == "WARNING" + and "overlapping paths" in record.message.lower() + ] + assert len(warning_messages) == 1 + assert "checkpoints" in warning_messages[0].lower() + assert "diffusion_models" in warning_messages[0].lower() or "unet" in warning_messages[0].lower() + # Verify warning mentions backward compatibility fallback + assert "falling back" in warning_messages[0].lower() or "backward compatibility" in warning_messages[0].lower() + + # Verify only one path is returned (deduplication still works) + assert len(result) == 1 + # Prioritizes checkpoints path for backward compatibility + assert _normalize(result[0]) == _normalize(str(checkpoints_link)) + + # Verify checkpoints_roots has the path (prioritized) + assert len(config.checkpoints_roots) == 1 + assert _normalize(config.checkpoints_roots[0]) == _normalize(str(checkpoints_link)) + + # Verify unet_roots is empty (overlapping paths removed) + assert config.unet_roots == [] + + def test_non_overlapping_paths_no_warning( + self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog + ): + """Test that non-overlapping paths do not trigger a warning.""" + # Create separate physical folders + checkpoints_dir = tmp_path / "checkpoints" + checkpoints_dir.mkdir() + unet_dir = tmp_path / "unet" + unet_dir.mkdir() + + # Create Config instance with separate paths + with caplog.at_level(logging.WARNING, logger=config_module.logger.name): + config = config_module.Config.__new__(config_module.Config) + config._path_mappings = {} + config._preview_root_paths = set() + config._cached_fingerprint = None + + result = config._prepare_checkpoint_paths( + [str(checkpoints_dir)], [str(unet_dir)] + ) + + # Verify no overlapping paths warning was logged + warning_messages = [ + record.message + for record in caplog.records + if record.levelname == "WARNING" + and "overlapping paths" in record.message.lower() + ] + assert len(warning_messages) == 0 + + # Verify both paths are returned + assert len(result) == 2 + normalized_result = [_normalize(p) for p in result] + assert _normalize(str(checkpoints_dir)) in normalized_result + assert _normalize(str(unet_dir)) in normalized_result + + # Verify both roots are properly set + assert len(config.checkpoints_roots) == 1 + assert len(config.unet_roots) == 1 + + def test_partial_overlap_prioritizes_checkpoints( + self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog + ): + """Test partial overlap - overlapping paths prioritize checkpoints.""" + # Create folders + shared_dir = tmp_path / "shared" + shared_dir.mkdir() + separate_checkpoint = tmp_path / "separate_ckpt" + separate_checkpoint.mkdir() + separate_unet = tmp_path / "separate_unet" + separate_unet.mkdir() + + # Create symlinks - one shared, others separate + shared_link = tmp_path / "shared_link" + shared_link.symlink_to(shared_dir, target_is_directory=True) + + with caplog.at_level(logging.WARNING, logger=config_module.logger.name): + config = config_module.Config.__new__(config_module.Config) + config._path_mappings = {} + config._preview_root_paths = set() + config._cached_fingerprint = None + + # One checkpoint path overlaps with one unet path + result = config._prepare_checkpoint_paths( + [str(shared_link), str(separate_checkpoint)], + [str(shared_link), str(separate_unet)] + ) + + # Verify warning was logged for the overlapping path + warning_messages = [ + record.message + for record in caplog.records + if record.levelname == "WARNING" + and "overlapping paths" in record.message.lower() + ] + assert len(warning_messages) == 1 + + # Verify 3 unique paths (shared counted once as checkpoint, plus separate ones) + assert len(result) == 3 + + # Verify the overlapping path appears in warning message + assert str(shared_link.name) in warning_messages[0] or str(shared_dir.name) in warning_messages[0] + + # Verify checkpoints_roots includes both checkpoint paths (including the shared one) + assert len(config.checkpoints_roots) == 2 + checkpoint_normalized = [_normalize(p) for p in config.checkpoints_roots] + assert _normalize(str(shared_link)) in checkpoint_normalized + assert _normalize(str(separate_checkpoint)) in checkpoint_normalized + + # Verify unet_roots only includes the non-overlapping unet path + assert len(config.unet_roots) == 1 + assert _normalize(config.unet_roots[0]) == _normalize(str(separate_unet))