mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
fix: isolate extra unet paths from checkpoints to prevent type misclassification
Refactor _prepare_checkpoint_paths() to return a tuple instead of having side effects on instance variables. This prevents extra unet paths from being incorrectly classified as checkpoints when processing extra paths. - Changed return type from List[str] to Tuple[List[str], List[str], List[str]] (all_paths, checkpoint_roots, unet_roots) - Updated _init_checkpoint_paths() and _apply_library_paths() callers - Fixed extra paths processing to properly isolate main and extra roots - Updated test_checkpoint_path_overlap.py tests for new API This ensures models in extra unet paths are correctly identified as diffusion_model type and don't appear in checkpoints list.
This commit is contained in:
@@ -36,8 +36,8 @@ class TestCheckpointPathOverlap:
|
||||
config._preview_root_paths = set()
|
||||
config._cached_fingerprint = None
|
||||
|
||||
# Call the method under test
|
||||
result = config._prepare_checkpoint_paths(
|
||||
# Call the method under test - now returns a tuple
|
||||
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||
[str(checkpoints_link)], [str(unet_link)]
|
||||
)
|
||||
|
||||
@@ -50,21 +50,27 @@ class TestCheckpointPathOverlap:
|
||||
]
|
||||
assert len(warning_messages) == 1
|
||||
assert "checkpoints" in warning_messages[0].lower()
|
||||
assert "diffusion_models" in warning_messages[0].lower() or "unet" in warning_messages[0].lower()
|
||||
assert (
|
||||
"diffusion_models" in warning_messages[0].lower()
|
||||
or "unet" in warning_messages[0].lower()
|
||||
)
|
||||
# Verify warning mentions backward compatibility fallback
|
||||
assert "falling back" in warning_messages[0].lower() or "backward compatibility" in warning_messages[0].lower()
|
||||
assert (
|
||||
"falling back" in warning_messages[0].lower()
|
||||
or "backward compatibility" in warning_messages[0].lower()
|
||||
)
|
||||
|
||||
# Verify only one path is returned (deduplication still works)
|
||||
assert len(result) == 1
|
||||
assert len(all_paths) == 1
|
||||
# Prioritizes checkpoints path for backward compatibility
|
||||
assert _normalize(result[0]) == _normalize(str(checkpoints_link))
|
||||
assert _normalize(all_paths[0]) == _normalize(str(checkpoints_link))
|
||||
|
||||
# Verify checkpoints_roots has the path (prioritized)
|
||||
assert len(config.checkpoints_roots) == 1
|
||||
assert _normalize(config.checkpoints_roots[0]) == _normalize(str(checkpoints_link))
|
||||
# Verify checkpoint_roots has the path (prioritized)
|
||||
assert len(checkpoint_roots) == 1
|
||||
assert _normalize(checkpoint_roots[0]) == _normalize(str(checkpoints_link))
|
||||
|
||||
# Verify unet_roots is empty (overlapping paths removed)
|
||||
assert config.unet_roots == []
|
||||
assert unet_roots == []
|
||||
|
||||
def test_non_overlapping_paths_no_warning(
|
||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
|
||||
@@ -83,7 +89,7 @@ class TestCheckpointPathOverlap:
|
||||
config._preview_root_paths = set()
|
||||
config._cached_fingerprint = None
|
||||
|
||||
result = config._prepare_checkpoint_paths(
|
||||
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||
[str(checkpoints_dir)], [str(unet_dir)]
|
||||
)
|
||||
|
||||
@@ -97,14 +103,14 @@ class TestCheckpointPathOverlap:
|
||||
assert len(warning_messages) == 0
|
||||
|
||||
# Verify both paths are returned
|
||||
assert len(result) == 2
|
||||
normalized_result = [_normalize(p) for p in result]
|
||||
assert len(all_paths) == 2
|
||||
normalized_result = [_normalize(p) for p in all_paths]
|
||||
assert _normalize(str(checkpoints_dir)) in normalized_result
|
||||
assert _normalize(str(unet_dir)) in normalized_result
|
||||
|
||||
# Verify both roots are properly set
|
||||
assert len(config.checkpoints_roots) == 1
|
||||
assert len(config.unet_roots) == 1
|
||||
assert len(checkpoint_roots) == 1
|
||||
assert len(unet_roots) == 1
|
||||
|
||||
def test_partial_overlap_prioritizes_checkpoints(
|
||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
|
||||
@@ -129,9 +135,9 @@ class TestCheckpointPathOverlap:
|
||||
config._cached_fingerprint = None
|
||||
|
||||
# One checkpoint path overlaps with one unet path
|
||||
result = config._prepare_checkpoint_paths(
|
||||
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||
[str(shared_link), str(separate_checkpoint)],
|
||||
[str(shared_link), str(separate_unet)]
|
||||
[str(shared_link), str(separate_unet)],
|
||||
)
|
||||
|
||||
# Verify warning was logged for the overlapping path
|
||||
@@ -144,17 +150,20 @@ class TestCheckpointPathOverlap:
|
||||
assert len(warning_messages) == 1
|
||||
|
||||
# Verify 3 unique paths (shared counted once as checkpoint, plus separate ones)
|
||||
assert len(result) == 3
|
||||
assert len(all_paths) == 3
|
||||
|
||||
# Verify the overlapping path appears in warning message
|
||||
assert str(shared_link.name) in warning_messages[0] or str(shared_dir.name) in warning_messages[0]
|
||||
assert (
|
||||
str(shared_link.name) in warning_messages[0]
|
||||
or str(shared_dir.name) in warning_messages[0]
|
||||
)
|
||||
|
||||
# Verify checkpoints_roots includes both checkpoint paths (including the shared one)
|
||||
assert len(config.checkpoints_roots) == 2
|
||||
checkpoint_normalized = [_normalize(p) for p in config.checkpoints_roots]
|
||||
# Verify checkpoint_roots includes both checkpoint paths (including the shared one)
|
||||
assert len(checkpoint_roots) == 2
|
||||
checkpoint_normalized = [_normalize(p) for p in checkpoint_roots]
|
||||
assert _normalize(str(shared_link)) in checkpoint_normalized
|
||||
assert _normalize(str(separate_checkpoint)) in checkpoint_normalized
|
||||
|
||||
# Verify unet_roots only includes the non-overlapping unet path
|
||||
assert len(config.unet_roots) == 1
|
||||
assert _normalize(config.unet_roots[0]) == _normalize(str(separate_unet))
|
||||
assert len(unet_roots) == 1
|
||||
assert _normalize(unet_roots[0]) == _normalize(str(separate_unet))
|
||||
|
||||
Reference in New Issue
Block a user