mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
fix: isolate extra unet paths from checkpoints to prevent type misclassification
Refactor _prepare_checkpoint_paths() to return a tuple instead of having side effects on instance variables. This prevents extra unet paths from being incorrectly classified as checkpoints when processing extra paths. - Changed return type from List[str] to Tuple[List[str], List[str], List[str]] (all_paths, checkpoint_roots, unet_roots) - Updated _init_checkpoint_paths() and _apply_library_paths() callers - Fixed extra paths processing to properly isolate main and extra roots - Updated test_checkpoint_path_overlap.py tests for new API This ensures models in extra unet paths are correctly identified as diffusion_model type and don't appear in checkpoints list.
This commit is contained in:
@@ -36,8 +36,8 @@ class TestCheckpointPathOverlap:
|
||||
config._preview_root_paths = set()
|
||||
config._cached_fingerprint = None
|
||||
|
||||
# Call the method under test
|
||||
result = config._prepare_checkpoint_paths(
|
||||
# Call the method under test - now returns a tuple
|
||||
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||
[str(checkpoints_link)], [str(unet_link)]
|
||||
)
|
||||
|
||||
@@ -50,21 +50,27 @@ class TestCheckpointPathOverlap:
|
||||
]
|
||||
assert len(warning_messages) == 1
|
||||
assert "checkpoints" in warning_messages[0].lower()
|
||||
assert "diffusion_models" in warning_messages[0].lower() or "unet" in warning_messages[0].lower()
|
||||
assert (
|
||||
"diffusion_models" in warning_messages[0].lower()
|
||||
or "unet" in warning_messages[0].lower()
|
||||
)
|
||||
# Verify warning mentions backward compatibility fallback
|
||||
assert "falling back" in warning_messages[0].lower() or "backward compatibility" in warning_messages[0].lower()
|
||||
assert (
|
||||
"falling back" in warning_messages[0].lower()
|
||||
or "backward compatibility" in warning_messages[0].lower()
|
||||
)
|
||||
|
||||
# Verify only one path is returned (deduplication still works)
|
||||
assert len(result) == 1
|
||||
assert len(all_paths) == 1
|
||||
# Prioritizes checkpoints path for backward compatibility
|
||||
assert _normalize(result[0]) == _normalize(str(checkpoints_link))
|
||||
assert _normalize(all_paths[0]) == _normalize(str(checkpoints_link))
|
||||
|
||||
# Verify checkpoints_roots has the path (prioritized)
|
||||
assert len(config.checkpoints_roots) == 1
|
||||
assert _normalize(config.checkpoints_roots[0]) == _normalize(str(checkpoints_link))
|
||||
# Verify checkpoint_roots has the path (prioritized)
|
||||
assert len(checkpoint_roots) == 1
|
||||
assert _normalize(checkpoint_roots[0]) == _normalize(str(checkpoints_link))
|
||||
|
||||
# Verify unet_roots is empty (overlapping paths removed)
|
||||
assert config.unet_roots == []
|
||||
assert unet_roots == []
|
||||
|
||||
def test_non_overlapping_paths_no_warning(
|
||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
|
||||
@@ -83,7 +89,7 @@ class TestCheckpointPathOverlap:
|
||||
config._preview_root_paths = set()
|
||||
config._cached_fingerprint = None
|
||||
|
||||
result = config._prepare_checkpoint_paths(
|
||||
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||
[str(checkpoints_dir)], [str(unet_dir)]
|
||||
)
|
||||
|
||||
@@ -97,14 +103,14 @@ class TestCheckpointPathOverlap:
|
||||
assert len(warning_messages) == 0
|
||||
|
||||
# Verify both paths are returned
|
||||
assert len(result) == 2
|
||||
normalized_result = [_normalize(p) for p in result]
|
||||
assert len(all_paths) == 2
|
||||
normalized_result = [_normalize(p) for p in all_paths]
|
||||
assert _normalize(str(checkpoints_dir)) in normalized_result
|
||||
assert _normalize(str(unet_dir)) in normalized_result
|
||||
|
||||
# Verify both roots are properly set
|
||||
assert len(config.checkpoints_roots) == 1
|
||||
assert len(config.unet_roots) == 1
|
||||
assert len(checkpoint_roots) == 1
|
||||
assert len(unet_roots) == 1
|
||||
|
||||
def test_partial_overlap_prioritizes_checkpoints(
|
||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path, caplog
|
||||
@@ -129,9 +135,9 @@ class TestCheckpointPathOverlap:
|
||||
config._cached_fingerprint = None
|
||||
|
||||
# One checkpoint path overlaps with one unet path
|
||||
result = config._prepare_checkpoint_paths(
|
||||
all_paths, checkpoint_roots, unet_roots = config._prepare_checkpoint_paths(
|
||||
[str(shared_link), str(separate_checkpoint)],
|
||||
[str(shared_link), str(separate_unet)]
|
||||
[str(shared_link), str(separate_unet)],
|
||||
)
|
||||
|
||||
# Verify warning was logged for the overlapping path
|
||||
@@ -144,17 +150,20 @@ class TestCheckpointPathOverlap:
|
||||
assert len(warning_messages) == 1
|
||||
|
||||
# Verify 3 unique paths (shared counted once as checkpoint, plus separate ones)
|
||||
assert len(result) == 3
|
||||
assert len(all_paths) == 3
|
||||
|
||||
# Verify the overlapping path appears in warning message
|
||||
assert str(shared_link.name) in warning_messages[0] or str(shared_dir.name) in warning_messages[0]
|
||||
assert (
|
||||
str(shared_link.name) in warning_messages[0]
|
||||
or str(shared_dir.name) in warning_messages[0]
|
||||
)
|
||||
|
||||
# Verify checkpoints_roots includes both checkpoint paths (including the shared one)
|
||||
assert len(config.checkpoints_roots) == 2
|
||||
checkpoint_normalized = [_normalize(p) for p in config.checkpoints_roots]
|
||||
# Verify checkpoint_roots includes both checkpoint paths (including the shared one)
|
||||
assert len(checkpoint_roots) == 2
|
||||
checkpoint_normalized = [_normalize(p) for p in checkpoint_roots]
|
||||
assert _normalize(str(shared_link)) in checkpoint_normalized
|
||||
assert _normalize(str(separate_checkpoint)) in checkpoint_normalized
|
||||
|
||||
# Verify unet_roots only includes the non-overlapping unet path
|
||||
assert len(config.unet_roots) == 1
|
||||
assert _normalize(config.unet_roots[0]) == _normalize(str(separate_unet))
|
||||
assert len(unet_roots) == 1
|
||||
assert _normalize(unet_roots[0]) == _normalize(str(separate_unet))
|
||||
|
||||
158
tests/test_checkpoint_loaders.py
Normal file
158
tests/test_checkpoint_loaders.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Tests for checkpoint and unet loaders with extra folder paths support"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
|
||||
|
||||
# Get project root directory (ComfyUI-Lora-Manager folder)
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class TestCheckpointLoaderLM:
|
||||
"""Test CheckpointLoaderLM node"""
|
||||
|
||||
def test_class_attributes(self):
|
||||
"""Test that CheckpointLoaderLM has required class attributes"""
|
||||
# Import in a way that doesn't require ComfyUI
|
||||
import ast
|
||||
|
||||
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "checkpoint_loader.py")
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
tree = ast.parse(f.read())
|
||||
|
||||
# Find CheckpointLoaderLM class
|
||||
classes = {
|
||||
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
|
||||
}
|
||||
assert "CheckpointLoaderLM" in classes
|
||||
|
||||
cls = classes["CheckpointLoaderLM"]
|
||||
|
||||
# Check for NAME attribute
|
||||
name_attr = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.Assign)
|
||||
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
|
||||
]
|
||||
assert len(name_attr) > 0, "CheckpointLoaderLM should have NAME attribute"
|
||||
|
||||
# Check for CATEGORY attribute
|
||||
cat_attr = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.Assign)
|
||||
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
|
||||
]
|
||||
assert len(cat_attr) > 0, "CheckpointLoaderLM should have CATEGORY attribute"
|
||||
|
||||
# Check for INPUT_TYPES method
|
||||
input_types = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
|
||||
]
|
||||
assert len(input_types) > 0, "CheckpointLoaderLM should have INPUT_TYPES method"
|
||||
|
||||
# Check for load_checkpoint method
|
||||
load_method = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "load_checkpoint"
|
||||
]
|
||||
assert len(load_method) > 0, (
|
||||
"CheckpointLoaderLM should have load_checkpoint method"
|
||||
)
|
||||
|
||||
|
||||
class TestUNETLoaderLM:
|
||||
"""Test UNETLoaderLM node"""
|
||||
|
||||
def test_class_attributes(self):
|
||||
"""Test that UNETLoaderLM has required class attributes"""
|
||||
# Import in a way that doesn't require ComfyUI
|
||||
import ast
|
||||
|
||||
filepath = os.path.join(PROJECT_ROOT, "py", "nodes", "unet_loader.py")
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
tree = ast.parse(f.read())
|
||||
|
||||
# Find UNETLoaderLM class
|
||||
classes = {
|
||||
node.name: node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
|
||||
}
|
||||
assert "UNETLoaderLM" in classes
|
||||
|
||||
cls = classes["UNETLoaderLM"]
|
||||
|
||||
# Check for NAME attribute
|
||||
name_attr = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.Assign)
|
||||
and any(t.id == "NAME" for t in n.targets if isinstance(t, ast.Name))
|
||||
]
|
||||
assert len(name_attr) > 0, "UNETLoaderLM should have NAME attribute"
|
||||
|
||||
# Check for CATEGORY attribute
|
||||
cat_attr = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.Assign)
|
||||
and any(t.id == "CATEGORY" for t in n.targets if isinstance(t, ast.Name))
|
||||
]
|
||||
assert len(cat_attr) > 0, "UNETLoaderLM should have CATEGORY attribute"
|
||||
|
||||
# Check for INPUT_TYPES method
|
||||
input_types = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "INPUT_TYPES"
|
||||
]
|
||||
assert len(input_types) > 0, "UNETLoaderLM should have INPUT_TYPES method"
|
||||
|
||||
# Check for load_unet method
|
||||
load_method = [
|
||||
n
|
||||
for n in cls.body
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "load_unet"
|
||||
]
|
||||
assert len(load_method) > 0, "UNETLoaderLM should have load_unet method"
|
||||
|
||||
|
||||
class TestUtils:
|
||||
"""Test utility functions"""
|
||||
|
||||
def test_get_checkpoint_info_absolute_exists(self):
|
||||
"""Test that get_checkpoint_info_absolute function exists in utils"""
|
||||
import ast
|
||||
|
||||
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
tree = ast.parse(f.read())
|
||||
|
||||
functions = [
|
||||
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
|
||||
]
|
||||
assert "get_checkpoint_info_absolute" in functions, (
|
||||
"get_checkpoint_info_absolute should exist"
|
||||
)
|
||||
|
||||
def test_format_model_name_for_comfyui_exists(self):
|
||||
"""Test that _format_model_name_for_comfyui function exists in utils"""
|
||||
import ast
|
||||
|
||||
filepath = os.path.join(PROJECT_ROOT, "py", "utils", "utils.py")
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
tree = ast.parse(f.read())
|
||||
|
||||
functions = [
|
||||
node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
|
||||
]
|
||||
assert "_format_model_name_for_comfyui" in functions, (
|
||||
"_format_model_name_for_comfyui should exist"
|
||||
)
|
||||
Reference in New Issue
Block a user