mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 13:12:12 -03:00
Refactor _prepare_checkpoint_paths() to return a tuple instead of having side effects on instance variables. This prevents extra unet paths from being incorrectly classified as checkpoints when processing extra paths.
- Changed return type from List[str] to Tuple[List[str], List[str], List[str]] (all_paths, checkpoint_roots, unet_roots)
- Updated _init_checkpoint_paths() and _apply_library_paths() callers
- Fixed extra paths processing to properly isolate main and extra roots
- Updated test_checkpoint_path_overlap.py tests for the new API
This ensures models in extra unet paths are correctly identified as diffusion_model type and don't appear in the checkpoints list.
170 lines
6.6 KiB
Python
170 lines
6.6 KiB
Python
"""Tests for checkpoint path overlap detection."""
|
|
|
|
import logging
|
|
import os
|
|
|
|
import pytest
|
|
|
|
from py import config as config_module
|
|
|
|
|
|
def _normalize(path: str) -> str:
|
|
return os.path.normpath(path).replace(os.sep, "/")
|
|
|
|
|
|
class TestCheckpointPathOverlap:
    """Test detection of overlapping paths between checkpoints and unet.

    Every test calls ``Config._prepare_checkpoint_paths`` on an instance
    created with ``Config.__new__`` (so ``Config.__init__`` is not run) and
    checks the returned ``(all_paths, checkpoint_roots, unet_roots)`` tuple
    together with the warnings captured by ``caplog``.
    """

    def _make_config(self):
        """Return a bare ``Config`` carrying only the attributes set here."""
        config = config_module.Config.__new__(config_module.Config)
        config._path_mappings = {}
        config._preview_root_paths = set()
        config._cached_fingerprint = None
        return config

    def _prepare_paths(self, caplog, checkpoint_paths, unet_paths):
        """Run ``_prepare_checkpoint_paths`` while capturing WARNING records.

        Returns whatever the method under test returns; the tests unpack it
        as ``(all_paths, checkpoint_roots, unet_roots)``.
        """
        with caplog.at_level(logging.WARNING, logger=config_module.logger.name):
            config = self._make_config()
            return config._prepare_checkpoint_paths(checkpoint_paths, unet_paths)

    def _overlap_warnings(self, caplog):
        """Return captured WARNING messages that mention overlapping paths."""
        return [
            record.message
            for record in caplog.records
            if record.levelname == "WARNING"
            and "overlapping paths" in record.message.lower()
        ]

    def _symlink_dir_or_skip(self, link, target):
        """Create directory symlink ``link -> target`` or skip the test.

        Symlink creation can fail for reasons unrelated to the code under
        test (for example missing privileges on Windows); skipping keeps
        such environments from reporting a spurious error.
        """
        try:
            link.symlink_to(target, target_is_directory=True)
        except (OSError, NotImplementedError) as exc:
            pytest.skip(f"directory symlinks are not available: {exc}")

    def test_overlapping_paths_prioritizes_checkpoints(
        self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
    ):
        """Test that overlapping paths prioritize checkpoints for backward compatibility."""
        # Create a shared physical folder
        shared_dir = tmp_path / "shared_models"
        shared_dir.mkdir()

        # Create two symlinks pointing to the same physical folder
        checkpoints_link = tmp_path / "checkpoints"
        unet_link = tmp_path / "unet"
        self._symlink_dir_or_skip(checkpoints_link, shared_dir)
        self._symlink_dir_or_skip(unet_link, shared_dir)

        # Call the method under test - it returns a tuple
        all_paths, checkpoint_roots, unet_roots = self._prepare_paths(
            caplog, [str(checkpoints_link)], [str(unet_link)]
        )

        # Verify exactly one overlap warning was logged and that it names
        # both model kinds as well as the backward-compatibility fallback
        warning_messages = self._overlap_warnings(caplog)
        assert len(warning_messages) == 1
        message = warning_messages[0].lower()
        assert "checkpoints" in message
        assert "diffusion_models" in message or "unet" in message
        assert "falling back" in message or "backward compatibility" in message

        # Verify only one path is returned (deduplication still works)
        assert len(all_paths) == 1
        # Prioritizes checkpoints path for backward compatibility
        assert _normalize(all_paths[0]) == _normalize(str(checkpoints_link))

        # Verify checkpoint_roots has the path (prioritized)
        assert len(checkpoint_roots) == 1
        assert _normalize(checkpoint_roots[0]) == _normalize(str(checkpoints_link))

        # Verify unet_roots is empty (overlapping paths removed)
        assert unet_roots == []

    def test_non_overlapping_paths_no_warning(
        self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
    ):
        """Test that non-overlapping paths do not trigger a warning."""
        # Create separate physical folders
        checkpoints_dir = tmp_path / "checkpoints"
        checkpoints_dir.mkdir()
        unet_dir = tmp_path / "unet"
        unet_dir.mkdir()

        all_paths, checkpoint_roots, unet_roots = self._prepare_paths(
            caplog, [str(checkpoints_dir)], [str(unet_dir)]
        )

        # Verify no overlapping paths warning was logged
        assert len(self._overlap_warnings(caplog)) == 0

        # Verify both paths are returned
        assert len(all_paths) == 2
        normalized_result = [_normalize(p) for p in all_paths]
        assert _normalize(str(checkpoints_dir)) in normalized_result
        assert _normalize(str(unet_dir)) in normalized_result

        # Verify both roots are properly set
        assert len(checkpoint_roots) == 1
        assert len(unet_roots) == 1

    def test_partial_overlap_prioritizes_checkpoints(
        self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
    ):
        """Test partial overlap - overlapping paths prioritize checkpoints."""
        # Create folders
        shared_dir = tmp_path / "shared"
        shared_dir.mkdir()
        separate_checkpoint = tmp_path / "separate_ckpt"
        separate_checkpoint.mkdir()
        separate_unet = tmp_path / "separate_unet"
        separate_unet.mkdir()

        # Create the symlink used by both lists; the other folders are separate
        shared_link = tmp_path / "shared_link"
        self._symlink_dir_or_skip(shared_link, shared_dir)

        # One checkpoint path overlaps with one unet path
        all_paths, checkpoint_roots, unet_roots = self._prepare_paths(
            caplog,
            [str(shared_link), str(separate_checkpoint)],
            [str(shared_link), str(separate_unet)],
        )

        # Verify warning was logged for the overlapping path
        warning_messages = self._overlap_warnings(caplog)
        assert len(warning_messages) == 1

        # Verify 3 unique paths (shared counted once as checkpoint, plus separate ones)
        assert len(all_paths) == 3

        # Verify the overlapping path appears in warning message
        assert (
            str(shared_link.name) in warning_messages[0]
            or str(shared_dir.name) in warning_messages[0]
        )

        # Verify checkpoint_roots includes both checkpoint paths (including the shared one)
        assert len(checkpoint_roots) == 2
        checkpoint_normalized = [_normalize(p) for p in checkpoint_roots]
        assert _normalize(str(shared_link)) in checkpoint_normalized
        assert _normalize(str(separate_checkpoint)) in checkpoint_normalized

        # Verify unet_roots only includes the non-overlapping unet path
        assert len(unet_roots) == 1
        assert _normalize(unet_roots[0]) == _normalize(str(separate_unet))
|