mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
fix(autocomplete): improve tag search ranking with popularity-based sorting
- Add LOG10(post_count) weighting to BM25 score for better relevance ranking - Prioritize tag_name prefix matches above alias matches using CASE statement - Remove frontend re-scoring logic to trust backend sorting results - Fix pagination consistency: page N+1 scores <= page N minimum score Key improvements: - '1girl' (6M posts) now ranks #1 instead of #149 for search '1' - tag_name prefix matches always appear before alias matches - Popular tags rank higher than obscure ones with same prefix - Consistent ordering across pagination boundaries Test coverage: - Add test_search_tag_name_prefix_match_priority - Add test_search_ranks_popular_tags_higher - Add test_search_pagination_ordering_consistency - Add test_search_rank_score_includes_popularity_weight - Update test data with 15 tags starting with '1' Fixes issues with autocomplete dropdown showing inconsistent results when scrolling through paginated search results.
This commit is contained in:
@@ -156,4 +156,115 @@ describe('AutoComplete widget interactions', () => {
|
||||
expect(highlighted).toContain('detail');
|
||||
expect(highlighted).not.toMatch(/beta<\/span>/i);
|
||||
});
|
||||
|
||||
it('handles arrow key navigation with virtual scrolling', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
const mockItems = Array.from({ length: 50 }, (_, i) => `model_${i.toString().padStart(2, '0')}.safetensors`);
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, relative_paths: mockItems }),
|
||||
});
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('model');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'loras', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
enableVirtualScroll: true,
|
||||
itemHeight: 40,
|
||||
visibleItems: 15,
|
||||
pageSize: 20,
|
||||
});
|
||||
|
||||
input.value = 'model';
|
||||
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
await Promise.resolve();
|
||||
|
||||
expect(autoComplete.items.length).toBeGreaterThan(0);
|
||||
expect(autoComplete.selectedIndex).toBe(0);
|
||||
|
||||
const initialSelectedEl = autoComplete.contentContainer?.querySelector('.comfy-autocomplete-item-selected');
|
||||
expect(initialSelectedEl).toBeDefined();
|
||||
|
||||
const arrowDownEvent = new KeyboardEvent('keydown', { key: 'ArrowDown', bubbles: true });
|
||||
input.dispatchEvent(arrowDownEvent);
|
||||
|
||||
expect(autoComplete.selectedIndex).toBe(1);
|
||||
|
||||
const secondSelectedEl = autoComplete.contentContainer?.querySelector('.comfy-autocomplete-item-selected');
|
||||
expect(secondSelectedEl).toBeDefined();
|
||||
expect(secondSelectedEl?.dataset.index).toBe('1');
|
||||
|
||||
const arrowUpEvent = new KeyboardEvent('keydown', { key: 'ArrowUp', bubbles: true });
|
||||
input.dispatchEvent(arrowUpEvent);
|
||||
|
||||
expect(autoComplete.selectedIndex).toBe(0);
|
||||
|
||||
const firstSelectedElAgain = autoComplete.contentContainer?.querySelector('.comfy-autocomplete-item-selected');
|
||||
expect(firstSelectedElAgain).toBeDefined();
|
||||
expect(firstSelectedElAgain?.dataset.index).toBe('0');
|
||||
});
|
||||
|
||||
it('maintains selection when scrolling to invisible items', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
const mockItems = Array.from({ length: 100 }, (_, i) => `item_${i.toString().padStart(3, '0')}.safetensors`);
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, relative_paths: mockItems }),
|
||||
});
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('item');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.style.width = '400px';
|
||||
input.style.height = '200px';
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'loras', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
enableVirtualScroll: true,
|
||||
itemHeight: 40,
|
||||
visibleItems: 15,
|
||||
pageSize: 20,
|
||||
});
|
||||
|
||||
input.value = 'item';
|
||||
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
await Promise.resolve();
|
||||
|
||||
expect(autoComplete.items.length).toBeGreaterThan(0);
|
||||
|
||||
autoComplete.selectedIndex = 14;
|
||||
|
||||
const scrollTopBefore = autoComplete.scrollContainer?.scrollTop || 0;
|
||||
|
||||
const arrowDownEvent = new KeyboardEvent('keydown', { key: 'ArrowDown', bubbles: true });
|
||||
input.dispatchEvent(arrowDownEvent);
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
await Promise.resolve();
|
||||
|
||||
expect(autoComplete.selectedIndex).toBe(15);
|
||||
|
||||
const selectedEl = autoComplete.contentContainer?.querySelector('.comfy-autocomplete-item-selected');
|
||||
expect(selectedEl).toBeDefined();
|
||||
expect(selectedEl?.dataset.index).toBe('15');
|
||||
|
||||
const scrollTopAfter = autoComplete.scrollContainer?.scrollTop || 0;
|
||||
expect(scrollTopAfter).toBeGreaterThanOrEqual(scrollTopBefore);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -31,10 +31,27 @@ def temp_db_path():
|
||||
@pytest.fixture
|
||||
def temp_csv_path():
|
||||
"""Create a temporary CSV file with test data."""
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False, encoding="utf-8") as f:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".csv", delete=False, encoding="utf-8"
|
||||
) as f:
|
||||
# Write test data in the same format as danbooru_e621_merged.csv
|
||||
# Format: tag_name,category,post_count,aliases
|
||||
# Include multiple tags starting with "1" to test popularity-based ranking
|
||||
f.write('1girl,0,6008644,"1girls,sole_female"\n')
|
||||
f.write('1boy,0,1405457,"1boys,sole_male"\n')
|
||||
f.write('1:1,14,377032,""\n')
|
||||
f.write('16:9,14,152866,""\n')
|
||||
f.write('1other,0,70962,""\n')
|
||||
f.write('16:10,14,14739,""\n')
|
||||
f.write('1990s_(style),0,9369,""\n')
|
||||
f.write('1_eye,0,7179,""\n')
|
||||
f.write('1:2,14,5865,""\n')
|
||||
f.write('1980s_(style),0,5665,""\n')
|
||||
f.write('1koma,0,4384,""\n')
|
||||
f.write('1_horn,0,2122,""\n')
|
||||
f.write('101_dalmatian_street,3,1933,""\n')
|
||||
f.write('1upgobbo,3,1731,""\n')
|
||||
f.write('14:9,14,1038,""\n')
|
||||
f.write('highres,5,5256195,"high_res,high_resolution,hires"\n')
|
||||
f.write('solo,0,5000954,"alone,female_solo,single"\n')
|
||||
f.write('hatsune_miku,4,500000,"miku"\n')
|
||||
@@ -86,7 +103,7 @@ class TestTagFTSIndexBuild:
|
||||
fts.build_index()
|
||||
|
||||
assert fts.is_ready() is True
|
||||
assert fts.get_indexed_count() == 10
|
||||
assert fts.get_indexed_count() == 24
|
||||
|
||||
def test_build_index_nonexistent_csv(self, temp_db_path):
|
||||
"""Test that build_index handles missing CSV gracefully."""
|
||||
@@ -187,6 +204,76 @@ class TestTagFTSIndexSearch:
|
||||
results = populated_fts.search("girl", limit=1)
|
||||
assert len(results) <= 1
|
||||
|
||||
def test_search_tag_name_prefix_match_priority(self, populated_fts):
|
||||
"""Test that tag_name prefix matches rank higher than alias matches."""
|
||||
results = populated_fts.search("1", limit=20)
|
||||
|
||||
assert len(results) > 0, "Should return results for '1'"
|
||||
|
||||
# Find first alias match (if any)
|
||||
first_alias_idx = None
|
||||
for i, result in enumerate(results):
|
||||
if result.get("matched_alias"):
|
||||
first_alias_idx = i
|
||||
break
|
||||
|
||||
# All tag_name prefix matches should come before alias matches
|
||||
if first_alias_idx is not None:
|
||||
for i in range(first_alias_idx):
|
||||
assert results[i]["tag_name"].lower().startswith("1"), (
|
||||
f"Tag at index {i} should start with '1' before alias matches"
|
||||
)
|
||||
|
||||
def test_search_ranks_popular_tags_higher(self, populated_fts):
|
||||
"""Test that tags with higher post_count rank higher among prefix matches."""
|
||||
results = populated_fts.search("1", limit=20)
|
||||
|
||||
# Filter to only tag_name prefix matches
|
||||
prefix_matches = [r for r in results if r["tag_name"].lower().startswith("1")]
|
||||
|
||||
assert len(prefix_matches) > 1, "Should have multiple prefix matches"
|
||||
|
||||
# Verify descending post_count order among prefix matches
|
||||
for i in range(len(prefix_matches) - 1):
|
||||
assert (
|
||||
prefix_matches[i]["post_count"] >= prefix_matches[i + 1]["post_count"]
|
||||
), (
|
||||
f"Tags should be sorted by post_count: {prefix_matches[i]['tag_name']} ({prefix_matches[i]['post_count']}) >= {prefix_matches[i + 1]['tag_name']} ({prefix_matches[i + 1]['post_count']})"
|
||||
)
|
||||
|
||||
def test_search_pagination_ordering_consistency(self, populated_fts):
|
||||
"""Test that pagination maintains consistent ordering."""
|
||||
page1 = populated_fts.search("1", limit=10, offset=0)
|
||||
page2 = populated_fts.search("1", limit=10, offset=10)
|
||||
|
||||
assert len(page1) > 0, "Page 1 should have results"
|
||||
assert len(page2) > 0, "Page 2 should have results"
|
||||
|
||||
# Page 2 scores should all be <= Page 1 min score
|
||||
page1_min_score = min(r["rank_score"] for r in page1)
|
||||
page2_max_score = max(r["rank_score"] for r in page2)
|
||||
|
||||
assert page2_max_score <= page1_min_score, (
|
||||
f"Page 2 max score ({page2_max_score}) should be <= Page 1 min score ({page1_min_score})"
|
||||
)
|
||||
|
||||
def test_search_rank_score_includes_popularity_weight(self, populated_fts):
|
||||
"""Test that rank_score includes post_count popularity weighting."""
|
||||
results = populated_fts.search("1", limit=5)
|
||||
|
||||
assert len(results) >= 2, "Need at least 2 results to compare"
|
||||
|
||||
# 1girl has 6M posts, should have higher rank_score than tags with fewer posts
|
||||
girl_result = next((r for r in results if r["tag_name"] == "1girl"), None)
|
||||
assert girl_result is not None, "1girl should be in results"
|
||||
|
||||
# Find a tag with significantly fewer posts
|
||||
low_post_result = next((r for r in results if r["post_count"] < 10000), None)
|
||||
if low_post_result:
|
||||
assert girl_result["rank_score"] > low_post_result["rank_score"], (
|
||||
f"1girl (6M posts) should have higher score than {low_post_result['tag_name']} ({low_post_result['post_count']} posts)"
|
||||
)
|
||||
|
||||
|
||||
class TestAliasSearch:
|
||||
"""Tests for alias search functionality."""
|
||||
@@ -204,7 +291,9 @@ class TestAliasSearch:
|
||||
results = populated_fts.search("miku")
|
||||
|
||||
assert len(results) >= 1
|
||||
hatsune_result = next((r for r in results if r["tag_name"] == "hatsune_miku"), None)
|
||||
hatsune_result = next(
|
||||
(r for r in results if r["tag_name"] == "hatsune_miku"), None
|
||||
)
|
||||
assert hatsune_result is not None
|
||||
assert hatsune_result["matched_alias"] == "miku"
|
||||
|
||||
@@ -214,7 +303,9 @@ class TestAliasSearch:
|
||||
results = populated_fts.search("hatsune")
|
||||
|
||||
assert len(results) >= 1
|
||||
hatsune_result = next((r for r in results if r["tag_name"] == "hatsune_miku"), None)
|
||||
hatsune_result = next(
|
||||
(r for r in results if r["tag_name"] == "hatsune_miku"), None
|
||||
)
|
||||
assert hatsune_result is not None
|
||||
assert "matched_alias" not in hatsune_result
|
||||
|
||||
@@ -301,7 +392,9 @@ class TestSlashPrefixAliases:
|
||||
@pytest.fixture
|
||||
def fts_with_slash_aliases(self, temp_db_path):
|
||||
"""Create an FTS index with slash-prefixed aliases."""
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False, encoding="utf-8") as f:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".csv", delete=False, encoding="utf-8"
|
||||
) as f:
|
||||
# Format: tag_name,category,post_count,aliases
|
||||
f.write('long_hair,0,4350743,"/lh,longhair,very_long_hair"\n')
|
||||
f.write('breasts,0,3439214,"/b,boobs,oppai"\n')
|
||||
@@ -380,7 +473,15 @@ class TestCategoryMappings:
|
||||
|
||||
def test_category_name_to_ids_complete(self):
|
||||
"""Test that CATEGORY_NAME_TO_IDS includes all expected names."""
|
||||
expected_names = ["general", "artist", "copyright", "character", "meta", "species", "lore"]
|
||||
expected_names = [
|
||||
"general",
|
||||
"artist",
|
||||
"copyright",
|
||||
"character",
|
||||
"meta",
|
||||
"species",
|
||||
"lore",
|
||||
]
|
||||
for name in expected_names:
|
||||
assert name in CATEGORY_NAME_TO_IDS
|
||||
assert isinstance(CATEGORY_NAME_TO_IDS[name], list)
|
||||
|
||||
Reference in New Issue
Block a user