diff --git a/locales/de.json b/locales/de.json index 0a92283b..316d661c 100644 --- a/locales/de.json +++ b/locales/de.json @@ -223,7 +223,11 @@ "noCreditRequired": "Kein Credit erforderlich", "allowSellingGeneratedContent": "Verkauf erlaubt", "noTags": "Keine Tags", - "clearAll": "Alle Filter löschen" + "clearAll": "Alle Filter löschen", + "any": "Beliebig", + "all": "Alle", + "tagLogicAny": "Jedes Tag abgleichen (ODER)", + "tagLogicAll": "Alle Tags abgleichen (UND)" }, "theme": { "toggle": "Theme wechseln", diff --git a/locales/en.json b/locales/en.json index 61bb02ab..970549ae 100644 --- a/locales/en.json +++ b/locales/en.json @@ -223,7 +223,11 @@ "noCreditRequired": "No Credit Required", "allowSellingGeneratedContent": "Allow Selling", "noTags": "No tags", - "clearAll": "Clear All Filters" + "clearAll": "Clear All Filters", + "any": "Any", + "all": "All", + "tagLogicAny": "Match any tag (OR)", + "tagLogicAll": "Match all tags (AND)" }, "theme": { "toggle": "Toggle theme", diff --git a/locales/es.json b/locales/es.json index cf5a2662..904c8878 100644 --- a/locales/es.json +++ b/locales/es.json @@ -223,7 +223,11 @@ "noCreditRequired": "Sin crédito requerido", "allowSellingGeneratedContent": "Venta permitida", "noTags": "Sin etiquetas", - "clearAll": "Limpiar todos los filtros" + "clearAll": "Limpiar todos los filtros", + "any": "Cualquiera", + "all": "Todos", + "tagLogicAny": "Coincidir con cualquier etiqueta (O)", + "tagLogicAll": "Coincidir con todas las etiquetas (Y)" }, "theme": { "toggle": "Cambiar tema", diff --git a/locales/fr.json b/locales/fr.json index 792ce01f..edca5d25 100644 --- a/locales/fr.json +++ b/locales/fr.json @@ -223,7 +223,11 @@ "noCreditRequired": "Crédit non requis", "allowSellingGeneratedContent": "Vente autorisée", "noTags": "Aucun tag", - "clearAll": "Effacer tous les filtres" + "clearAll": "Effacer tous les filtres", + "any": "N'importe quel", + "all": "Tous", + "tagLogicAny": "Correspondre à n'importe quel tag (OU)", + "tagLogicAll": "Correspondre à tous les tags (ET)" }, "theme": { "toggle": "Basculer le thème", diff --git a/locales/he.json b/locales/he.json index e5eb3d91..24f5a6c8 100644 --- a/locales/he.json +++ b/locales/he.json @@ -223,7 +223,11 @@ "noCreditRequired": "ללא קרדיט נדרש", "allowSellingGeneratedContent": "אפשר מכירה", "noTags": "ללא תגיות", - "clearAll": "נקה את כל המסננים" + "clearAll": "נקה את כל המסננים", + "any": "כלשהו", + "all": "כל התגים", + "tagLogicAny": "התאם כל תג (או)", + "tagLogicAll": "התאם את כל התגים (וגם)" }, "theme": { "toggle": "החלף ערכת נושא", diff --git a/locales/ja.json b/locales/ja.json index 03b6f666..d357391c 100644 --- a/locales/ja.json +++ b/locales/ja.json @@ -223,7 +223,11 @@ "noCreditRequired": "クレジット不要", "allowSellingGeneratedContent": "販売許可", "noTags": "タグなし", - "clearAll": "すべてのフィルタをクリア" + "clearAll": "すべてのフィルタをクリア", + "any": "いずれか", + "all": "すべて", + "tagLogicAny": "いずれかのタグに一致 (OR)", + "tagLogicAll": "すべてのタグに一致 (AND)" }, "theme": { "toggle": "テーマの切り替え", diff --git a/locales/ko.json b/locales/ko.json index 5819af65..ddd44127 100644 --- a/locales/ko.json +++ b/locales/ko.json @@ -223,7 +223,11 @@ "noCreditRequired": "크레딧 표기 없음", "allowSellingGeneratedContent": "판매 허용", "noTags": "태그 없음", - "clearAll": "모든 필터 지우기" + "clearAll": "모든 필터 지우기", + "any": "아무", + "all": "모두", + "tagLogicAny": "모든 태그 일치 (OR)", + "tagLogicAll": "모든 태그 일치 (AND)" }, "theme": { "toggle": "테마 토글", diff --git a/locales/ru.json b/locales/ru.json index 1ce928cf..c810423d 100644 --- a/locales/ru.json +++ b/locales/ru.json @@ -223,7 +223,11 @@ "noCreditRequired": "Без указания авторства", "allowSellingGeneratedContent": "Продажа разрешена", "noTags": "Без тегов", - "clearAll": "Очистить все фильтры" + "clearAll": "Очистить все фильтры", + "any": "Любой", + "all": "Все", + "tagLogicAny": "Совпадение с любым тегом (ИЛИ)", + "tagLogicAll": "Совпадение со всеми тегами (И)" }, "theme": { "toggle": "Переключить тему", diff --git a/locales/zh-CN.json b/locales/zh-CN.json index 6d9aab11..a3cb1c86 100644 --- a/locales/zh-CN.json +++ b/locales/zh-CN.json @@ -223,7 +223,11 @@ "noCreditRequired": "无需署名", "allowSellingGeneratedContent": "允许销售", "noTags": "无标签", - "clearAll": "清除所有筛选" + "clearAll": "清除所有筛选", + "any": "任一", + "all": "全部", + "tagLogicAny": "匹配任一标签 (或)", + "tagLogicAll": "匹配所有标签 (与)" }, "theme": { "toggle": "切换主题", diff --git a/locales/zh-TW.json b/locales/zh-TW.json index 4ecaf4bd..b9c662fc 100644 --- a/locales/zh-TW.json +++ b/locales/zh-TW.json @@ -223,7 +223,11 @@ "noCreditRequired": "無需署名", "allowSellingGeneratedContent": "允許銷售", "noTags": "無標籤", - "clearAll": "清除所有篩選" + "clearAll": "清除所有篩選", + "any": "任一", + "all": "全部", + "tagLogicAny": "符合任一票籤 (或)", + "tagLogicAll": "符合所有標籤 (與)" }, "theme": { "toggle": "切換主題", diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 62aae575..7cb9cd6e 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -270,6 +270,11 @@ class ModelListingHandler: request.query.get("update_available_only", "false").lower() == "true" ) + # Tag logic: "any" (OR) or "all" (AND) for include tags + tag_logic = request.query.get("tag_logic", "any").lower() + if tag_logic not in ("any", "all"): + tag_logic = "any" + # New license-based query filters credit_required = request.query.get("credit_required") if credit_required is not None: @@ -298,6 +303,7 @@ class ModelListingHandler: "fuzzy_search": fuzzy_search, "base_models": base_models, "tags": tag_filters, + "tag_logic": tag_logic, "search_options": search_options, "hash_filters": hash_filters, "favorites_only": favorites_only, diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 7faab2d2..fbd9a6f2 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -81,6 +81,7 @@ class BaseModelService(ABC): update_available_only: bool = False, credit_required: Optional[bool] = None, allow_selling_generated_content: Optional[bool] = None, + tag_logic: str = "any", **kwargs, ) -> Dict: """Get paginated and filtered model data""" @@ -109,6 +110,7 @@ class BaseModelService(ABC): tags=tags, favorites_only=favorites_only, search_options=search_options, + tag_logic=tag_logic, ) if search: @@ -241,6 +243,7 @@ class BaseModelService(ABC): tags: Optional[Dict[str, str]] = None, favorites_only: bool = False, search_options: dict = None, + tag_logic: str = "any", ) -> List[Dict]: """Apply common filters that work across all model types""" normalized_options = self.search_strategy.normalize_options(search_options) @@ -253,6 +256,7 @@ class BaseModelService(ABC): tags=tags, favorites_only=favorites_only, search_options=normalized_options, + tag_logic=tag_logic, ) return self.filter_set.apply(data, criteria) diff --git a/py/services/model_query.py b/py/services/model_query.py index 4666c5e6..0cc91c40 100644 --- a/py/services/model_query.py +++ b/py/services/model_query.py @@ -99,6 +99,7 @@ class FilterCriteria: favorites_only: bool = False search_options: Optional[Dict[str, Any]] = None model_types: Optional[Sequence[str]] = None + tag_logic: str = "any" # "any" (OR) or "all" (AND) class ModelCacheRepository: @@ -300,11 +301,29 @@ class ModelFilterSet: include_tags = {tag for tag in tag_filters if tag} if include_tags: + tag_logic = criteria.tag_logic.lower() if criteria.tag_logic else "any" def matches_include(item_tags): if not item_tags and "__no_tags__" in include_tags: return True - return any(tag in include_tags for tag in (item_tags or [])) + if tag_logic == "all": + # AND logic: item must have ALL include tags + # Special case: __no_tags__ is handled separately + non_special_tags = include_tags - {"__no_tags__"} + if "__no_tags__" in include_tags: + # If __no_tags__ is selected along with other tags, + # treat it as "no tags OR (all other tags)" + if not item_tags: + return True + # Otherwise, check if all non-special tags match + if non_special_tags: + return all(tag in (item_tags or []) for tag in non_special_tags) + return True + # Normal case: all tags must match + return all(tag in (item_tags or []) for tag in non_special_tags) + else: + # OR logic (default): item must have ANY include tag + return any(tag in include_tags for tag in (item_tags or [])) items = [item for item in items if matches_include(item.get("tags"))] diff --git a/static/css/components/search-filter.css b/static/css/components/search-filter.css index d1671283..b5ddee4a 100644 --- a/static/css/components/search-filter.css +++ b/static/css/components/search-filter.css @@ -673,6 +673,57 @@ +/* Tag Logic Toggle Styles */ +.filter-section-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 8px; +} + +.filter-section-header h4 { + margin: 0; +} + +.tag-logic-toggle { + display: flex; + background-color: var(--lora-surface); + border: 1px solid var(--border-color); + border-radius: var(--border-radius-sm); + overflow: hidden; +} + +.tag-logic-option { + background: none; + border: none; + padding: 2px 8px; + font-size: 11px; + cursor: pointer; + color: var(--text-color); + opacity: 0.7; + transition: all 0.2s ease; + font-weight: 500; +} + +.tag-logic-option:hover { + opacity: 1; + background-color: var(--lora-surface-hover); +} + +.tag-logic-option.active { + background-color: var(--lora-accent); + color: white; + opacity: 1; +} + +.tag-logic-option:first-child { + border-right: 1px solid var(--border-color); +} + +.tag-logic-option.active:first-child { + border-right: 1px solid rgba(255, 255, 255, 0.3); +} + /* Mobile adjustments */ @media (max-width: 768px) { .search-options-panel, diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js index 9e2a9d09..4d144f6a 100644 --- a/static/js/api/baseModelApi.js +++ b/static/js/api/baseModelApi.js @@ -924,6 +924,11 @@ export class BaseModelApiClient { params.append('model_type', type); }); } + + // Add tag logic parameter (any = OR, all = AND) + if (pageState.filters.tagLogic) { + params.append('tag_logic', pageState.filters.tagLogic); + } } this._addModelSpecificParams(params, pageState); diff --git a/static/js/managers/FilterManager.js b/static/js/managers/FilterManager.js index 3cf61233..1b62eab1 100644 --- a/static/js/managers/FilterManager.js +++ b/static/js/managers/FilterManager.js @@ -63,6 +63,9 @@ export class FilterManager { this.initializeLicenseFilters(); } + // Initialize tag logic toggle + this.initializeTagLogicToggle(); + // Add click handler for filter button if (this.filterButton) { this.filterButton.addEventListener('click', () => { @@ -84,6 +87,45 @@ export class FilterManager { this.loadFiltersFromStorage(); } + initializeTagLogicToggle() { + const toggleContainer = document.getElementById('tagLogicToggle'); + if (!toggleContainer) return; + + const options = toggleContainer.querySelectorAll('.tag-logic-option'); + + options.forEach(option => { + option.addEventListener('click', async () => { + const value = option.dataset.value; + if (this.filters.tagLogic === value) return; + + this.filters.tagLogic = value; + this.updateTagLogicToggleUI(); + + // Auto-apply filter when logic changes + await this.applyFilters(false); + }); + }); + + // Set initial state + this.updateTagLogicToggleUI(); + } + + updateTagLogicToggleUI() { + const toggleContainer = document.getElementById('tagLogicToggle'); + if (!toggleContainer) return; + + const options = toggleContainer.querySelectorAll('.tag-logic-option'); + const currentLogic = this.filters.tagLogic || 'any'; + + options.forEach(option => { + if (option.dataset.value === currentLogic) { + option.classList.add('active'); + } else { + option.classList.remove('active'); + } + }); + } + async loadTopTags() { try { // Show loading state @@ -573,9 +615,13 @@ export class FilterManager { baseModel: [], tags: {}, license: {}, - modelTypes: [] + modelTypes: [], + tagLogic: 'any' }); + // Update tag logic toggle UI + this.updateTagLogicToggleUI(); + // Update state const pageState = getCurrentPageState(); pageState.filters = this.cloneFilters(); @@ -620,6 +666,7 @@ export class FilterManager { pageState.filters = this.cloneFilters(); this.updateTagSelections(); + this.updateTagLogicToggleUI(); this.updateActiveFiltersCount(); if (this.hasActiveFilters()) { @@ -655,7 +702,8 @@ export class FilterManager { baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [], tags: this.normalizeTagFilters(source.tags), license: this.shouldShowLicenseFilters() ? this.normalizeLicenseFilters(source.license) : {}, - modelTypes: this.normalizeModelTypeFilters(source.modelTypes) + modelTypes: this.normalizeModelTypeFilters(source.modelTypes), + tagLogic: source.tagLogic || 'any' }; } @@ -737,7 +785,8 @@ export class FilterManager { baseModel: [...(this.filters.baseModel || [])], tags: { ...(this.filters.tags || {}) }, license: { ...(this.filters.license || {}) }, - modelTypes: [...(this.filters.modelTypes || [])] + modelTypes: [...(this.filters.modelTypes || [])], + tagLogic: this.filters.tagLogic || 'any' }; } diff --git a/templates/components/header.html b/templates/components/header.html index b980c924..b5494b70 100644 --- a/templates/components/header.html +++ b/templates/components/header.html @@ -150,7 +150,13 @@
-

{{ t('header.filter.modelTags') }}

+
+

{{ t('header.filter.modelTags') }}

+
+ + +
+
{{ t('common.status.loading') }}
diff --git a/tests/frontend/managers/FilterManager.tagLogic.test.js b/tests/frontend/managers/FilterManager.tagLogic.test.js new file mode 100644 index 00000000..07081163 --- /dev/null +++ b/tests/frontend/managers/FilterManager.tagLogic.test.js @@ -0,0 +1,290 @@ +import { describe, it, expect, beforeEach, vi } from 'vitest'; + +// Mock dependencies +vi.mock('../../../static/js/state/index.js', () => ({ + getCurrentPageState: vi.fn(() => ({ + filters: {}, + })), + state: { + currentPageType: 'loras', + loadingManager: { + showSimpleLoading: vi.fn(), + hide: vi.fn(), + }, + }, +})); + +vi.mock('../../../static/js/utils/uiHelpers.js', () => ({ + showToast: vi.fn(), + updatePanelPositions: vi.fn(), +})); + +vi.mock('../../../static/js/api/modelApiFactory.js', () => ({ + getModelApiClient: vi.fn(() => ({ + loadMoreWithVirtualScroll: vi.fn().mockResolvedValue(), + })), +})); + +vi.mock('../../../static/js/utils/storageHelpers.js', () => ({ + getStorageItem: vi.fn(), + setStorageItem: vi.fn(), + removeStorageItem: vi.fn(), +})); + +vi.mock('../../../static/js/utils/i18nHelpers.js', () => ({ + translate: vi.fn((key, _params, fallback) => fallback || key), +})); + +vi.mock('../../../static/js/managers/FilterPresetManager.js', () => ({ + FilterPresetManager: vi.fn().mockImplementation(() => ({ + renderPresets: vi.fn(), + saveActivePreset: vi.fn(), + restoreActivePreset: vi.fn(), + updateAddButtonState: vi.fn(), + hasEmptyWildcardResult: vi.fn(() => false), + })), + EMPTY_WILDCARD_MARKER: '__EMPTY_WILDCARD_RESULT__', +})); + +import { FilterManager } from '../../../static/js/managers/FilterManager.js'; +import { getStorageItem, setStorageItem } from '../../../static/js/utils/storageHelpers.js'; + +describe('FilterManager - Tag Logic', () => { + let manager; + let mockFilterPanel; + let mockTagLogicToggle; + + beforeEach(() => { + vi.clearAllMocks(); + + // Setup DOM mocks + mockFilterPanel = document.createElement('div'); + mockFilterPanel.id = 'filterPanel'; + mockFilterPanel.classList.add('hidden'); + + mockTagLogicToggle = document.createElement('div'); + mockTagLogicToggle.id = 'tagLogicToggle'; + + // Create tag logic options + const anyOption = document.createElement('button'); + anyOption.className = 'tag-logic-option'; + anyOption.dataset.value = 'any'; + mockTagLogicToggle.appendChild(anyOption); + + const allOption = document.createElement('button'); + allOption.className = 'tag-logic-option'; + allOption.dataset.value = 'all'; + mockTagLogicToggle.appendChild(allOption); + + document.body.appendChild(mockFilterPanel); + document.body.appendChild(mockTagLogicToggle); + + // Mock getElementById + const originalGetElementById = document.getElementById; + document.getElementById = vi.fn((id) => { + if (id === 'filterPanel') return mockFilterPanel; + if (id === 'tagLogicToggle') return mockTagLogicToggle; + if (id === 'filterButton') return document.createElement('button'); + if (id === 'activeFiltersCount') return document.createElement('span'); + if (id === 'baseModelTags') return document.createElement('div'); + if (id === 'modelTypeTags') return document.createElement('div'); + return originalGetElementById.call(document, id); + }); + }); + + describe('initializeFilters', () => { + it('should default tagLogic to "any" when not provided', () => { + manager = new FilterManager({ page: 'loras' }); + + expect(manager.filters.tagLogic).toBe('any'); + }); + + it('should use provided tagLogic value', () => { + getStorageItem.mockReturnValue({ + tagLogic: 'all', + tags: {}, + baseModel: [], + }); + + manager = new FilterManager({ page: 'loras' }); + + expect(manager.filters.tagLogic).toBe('all'); + }); + }); + + describe('initializeTagLogicToggle', () => { + it('should set "any" option as active by default', () => { + manager = new FilterManager({ page: 'loras' }); + + // Ensure filters.tagLogic is set to default + manager.filters.tagLogic = 'any'; + + const anyOption = mockTagLogicToggle.querySelector('[data-value="any"]'); + const allOption = mockTagLogicToggle.querySelector('[data-value="all"]'); + + // Manually update UI to ensure correct state + manager.updateTagLogicToggleUI(); + + expect(manager.filters.tagLogic).toBe('any'); + expect(anyOption.classList.contains('active')).toBe(true); + expect(allOption.classList.contains('active')).toBe(false); + }); + + it('should set "all" option as active when tagLogic is "all"', () => { + getStorageItem.mockReturnValue({ + tagLogic: 'all', + tags: {}, + baseModel: [], + }); + + manager = new FilterManager({ page: 'loras' }); + + // Ensure filters.tagLogic is set correctly + manager.filters.tagLogic = 'all'; + + const anyOption = mockTagLogicToggle.querySelector('[data-value="any"]'); + const allOption = mockTagLogicToggle.querySelector('[data-value="all"]'); + + // Manually update UI to ensure correct state + manager.updateTagLogicToggleUI(); + + expect(manager.filters.tagLogic).toBe('all'); + expect(anyOption.classList.contains('active')).toBe(false); + expect(allOption.classList.contains('active')).toBe(true); + }); + }); + + describe('updateTagLogicToggleUI', () => { + it('should update UI when tagLogic changes', () => { + // Clear any existing active classes first + mockTagLogicToggle.querySelectorAll('.tag-logic-option').forEach(el => { + el.classList.remove('active'); + }); + + manager = new FilterManager({ page: 'loras' }); + + let anyOption = mockTagLogicToggle.querySelector('[data-value="any"]'); + let allOption = mockTagLogicToggle.querySelector('[data-value="all"]'); + + // Ensure initial state + manager.filters.tagLogic = 'any'; + manager.updateTagLogicToggleUI(); + expect(anyOption.classList.contains('active')).toBe(true); + expect(allOption.classList.contains('active')).toBe(false); + + // Change to "all" + manager.filters.tagLogic = 'all'; + manager.updateTagLogicToggleUI(); + + expect(anyOption.classList.contains('active')).toBe(false); + expect(allOption.classList.contains('active')).toBe(true); + }); + }); + + describe('cloneFilters', () => { + it('should include tagLogic in cloned filters', () => { + manager = new FilterManager({ page: 'loras' }); + manager.filters.tagLogic = 'all'; + + const cloned = manager.cloneFilters(); + + expect(cloned.tagLogic).toBe('all'); + }); + }); + + describe('clearFilters', () => { + it('should reset tagLogic to "any"', () => { + getStorageItem.mockReturnValue({ + tagLogic: 'all', + tags: { anime: 'include' }, + baseModel: ['SDXL'], + }); + + manager = new FilterManager({ page: 'loras' }); + expect(manager.filters.tagLogic).toBe('all'); + + manager.clearFilters(); + + expect(manager.filters.tagLogic).toBe('any'); + }); + + it('should update UI after clearing', () => { + getStorageItem.mockReturnValue({ + tagLogic: 'all', + tags: {}, + baseModel: [], + }); + + manager = new FilterManager({ page: 'loras' }); + + const anyOption = mockTagLogicToggle.querySelector('[data-value="any"]'); + const allOption = mockTagLogicToggle.querySelector('[data-value="all"]'); + + // Initially "all" is active + expect(allOption.classList.contains('active')).toBe(true); + + manager.clearFilters(); + + // After clear, "any" should be active + expect(anyOption.classList.contains('active')).toBe(true); + expect(allOption.classList.contains('active')).toBe(false); + }); + }); + + describe('loadFiltersFromStorage', () => { + it('should restore tagLogic from storage', () => { + getStorageItem.mockReturnValue({ + tagLogic: 'all', + tags: { anime: 'include' }, + baseModel: [], + }); + + manager = new FilterManager({ page: 'loras' }); + + expect(manager.filters.tagLogic).toBe('all'); + expect(manager.filters.tags).toEqual({ anime: 'include' }); + }); + + it('should default to "any" when no tagLogic in storage', () => { + getStorageItem.mockReturnValue({ + tags: {}, + baseModel: [], + }); + + manager = new FilterManager({ page: 'loras' }); + + expect(manager.filters.tagLogic).toBe('any'); + }); + }); + + describe('tag logic toggle interaction', () => { + it('should update tagLogic when clicking "all" option', async () => { + manager = new FilterManager({ page: 'loras' }); + + const allOption = mockTagLogicToggle.querySelector('[data-value="all"]'); + + // Simulate click + allOption.click(); + + // Wait for async operation + await new Promise(resolve => setTimeout(resolve, 0)); + + expect(manager.filters.tagLogic).toBe('all'); + }); + + it('should not change tagLogic when clicking already active option', async () => { + manager = new FilterManager({ page: 'loras' }); + + const anyOption = mockTagLogicToggle.querySelector('[data-value="any"]'); + const applyFiltersSpy = vi.spyOn(manager, 'applyFilters'); + + // Click already active option + anyOption.click(); + + await new Promise(resolve => setTimeout(resolve, 0)); + + // applyFilters should not be called since value didn't change + expect(applyFiltersSpy).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/tests/routes/test_tag_logic_param_parsing.py b/tests/routes/test_tag_logic_param_parsing.py new file mode 100644 index 00000000..d10006a0 --- /dev/null +++ b/tests/routes/test_tag_logic_param_parsing.py @@ -0,0 +1,166 @@ +"""Tests for tag_logic parameter parsing in model handlers.""" + +import pytest +from unittest.mock import Mock +from aiohttp import web +from aiohttp.test_utils import TestClient, TestServer + +import sys +import types + +folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: []) +sys.modules.setdefault("folder_paths", folder_paths_stub) + +from py.routes.handlers.model_handlers import ModelListingHandler + + +class MockService: + """Mock service for testing.""" + + def __init__(self): + self.model_type = "test-model" + + async def get_paginated_data(self, **kwargs): + # Store the kwargs for verification + self.last_call_kwargs = kwargs + return { + "items": [], + "total": 0, + "page": 1, + "page_size": 20, + "total_pages": 0, + } + + async def format_response(self, item): + return item + + +def parse_specific_params(request): + """No specific params for testing.""" + return {} + + +@pytest.fixture +def handler(): + service = MockService() + logger = Mock() + return ModelListingHandler( + service=service, + parse_specific_params=parse_specific_params, + logger=logger, + ), service + + +async def make_request(handler, query_string=""): + """Helper to create a request and call get_models.""" + app = web.Application() + + async def test_handler(request): + return await handler.get_models(request) + + app.router.add_get("/test", test_handler) + server = TestServer(app) + client = TestClient(server) + await client.start_server() + + try: + response = await client.get(f"/test?{query_string}") + return response + finally: + await client.close() + + +@pytest.mark.asyncio +async def test_tag_logic_param_default_is_any(handler): + """Test that tag_logic defaults to 'any' when not provided.""" + h, service = handler + + response = await make_request(h, "tag_include=anime&tag_include=realistic") + assert response.status == 200 + + # Verify tag_logic was set to 'any' by default + assert service.last_call_kwargs["tag_logic"] == "any" + + +@pytest.mark.asyncio +async def test_tag_logic_param_explicit_any(handler): + """Test that tag_logic='any' is correctly parsed.""" + h, service = handler + + response = await make_request(h, "tag_include=anime&tag_logic=any") + assert response.status == 200 + + assert service.last_call_kwargs["tag_logic"] == "any" + + +@pytest.mark.asyncio +async def test_tag_logic_param_explicit_all(handler): + """Test that tag_logic='all' is correctly parsed.""" + h, service = handler + + response = await make_request(h, "tag_include=anime&tag_include=realistic&tag_logic=all") + assert response.status == 200 + + assert service.last_call_kwargs["tag_logic"] == "all" + + +@pytest.mark.asyncio +async def test_tag_logic_param_case_insensitive(handler): + """Test that tag_logic values are case insensitive.""" + h, service = handler + + # Test uppercase + response = await make_request(h, "tag_logic=ALL") + assert response.status == 200 + assert service.last_call_kwargs["tag_logic"] == "all" + + # Test mixed case + response = await make_request(h, "tag_logic=Any") + assert response.status == 200 + assert service.last_call_kwargs["tag_logic"] == "any" + + +@pytest.mark.asyncio +async def test_tag_logic_param_invalid_value_defaults_to_any(handler): + """Test that invalid tag_logic values default to 'any'.""" + h, service = handler + + response = await make_request(h, "tag_logic=invalid") + assert response.status == 200 + + # Should default to 'any' for invalid values + assert service.last_call_kwargs["tag_logic"] == "any" + + +@pytest.mark.asyncio +async def test_tag_logic_param_with_other_filters(handler): + """Test that tag_logic works correctly with other filter parameters.""" + h, service = handler + + query = ( + "tag_include=anime&" + "tag_include=character&" + "tag_exclude=nsfw&" + "base_model=SDXL&" + "tag_logic=all" + ) + response = await make_request(h, query) + assert response.status == 200 + + assert service.last_call_kwargs["tag_logic"] == "all" + assert service.last_call_kwargs["base_models"] == ["SDXL"] + assert "anime" in service.last_call_kwargs["tags"] + assert "character" in service.last_call_kwargs["tags"] + assert "nsfw" in service.last_call_kwargs["tags"] + + +@pytest.mark.asyncio +async def test_tag_logic_without_include_tags(handler): + """Test that tag_logic is still passed even without include tags.""" + h, service = handler + + response = await make_request(h, "tag_logic=all&base_model=SDXL") + assert response.status == 200 + + # tag_logic should still be set even without tag filters + assert service.last_call_kwargs["tag_logic"] == "all" diff --git a/tests/services/test_tag_logic_filter.py b/tests/services/test_tag_logic_filter.py new file mode 100644 index 00000000..0a8b2d76 --- /dev/null +++ b/tests/services/test_tag_logic_filter.py @@ -0,0 +1,276 @@ +"""Tests for tag logic (OR/AND) filtering functionality.""" + +import pytest +from py.services.model_query import ModelFilterSet, FilterCriteria + + +class StubSettings: + def get(self, key, default=None): + return default + + +class TestTagLogicFilter: + """Test cases for tag_logic parameter in FilterCriteria.""" + + def test_tag_logic_any_returns_items_with_any_tag(self): + """Test that tag_logic='any' (OR) returns items matching any include tag.""" + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime"]}, + {"name": "m2", "tags": ["realistic"]}, + {"name": "m3", "tags": ["anime", "realistic"]}, + {"name": "m4", "tags": ["style"]}, + {"name": "m5", "tags": []}, + ] + + # Include anime OR realistic (should match m1, m2, m3) + criteria = FilterCriteria( + tags={"anime": "include", "realistic": "include"}, + tag_logic="any" + ) + result = filter_set.apply(data, criteria) + assert len(result) == 3 + assert {item["name"] for item in result} == {"m1", "m2", "m3"} + + def test_tag_logic_all_returns_items_with_all_tags(self): + """Test that tag_logic='all' (AND) returns only items matching all include tags.""" + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime"]}, + {"name": "m2", "tags": ["realistic"]}, + {"name": "m3", "tags": ["anime", "realistic"]}, + {"name": "m4", "tags": ["style"]}, + {"name": "m5", "tags": []}, + ] + + # Include anime AND realistic (should match only m3) + criteria = FilterCriteria( + tags={"anime": "include", "realistic": "include"}, + tag_logic="all" + ) + result = filter_set.apply(data, criteria) + assert len(result) == 1 + assert result[0]["name"] == "m3" + + def test_tag_logic_all_with_single_tag(self): + """Test that tag_logic='all' with single tag works same as 'any'.""" + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime"]}, + {"name": "m2", "tags": ["realistic"]}, + {"name": "m3", "tags": ["anime", "realistic"]}, + ] + + # Include only anime with 'all' logic + criteria = FilterCriteria( + tags={"anime": "include"}, + tag_logic="all" + ) + result = filter_set.apply(data, criteria) + assert len(result) == 2 + assert {item["name"] for item in result} == {"m1", "m3"} + + def test_tag_logic_any_with_exclude_tags(self): + """Test that tag_logic='any' works correctly with exclude tags.""" + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime"]}, + {"name": "m2", "tags": ["realistic"]}, + {"name": "m3", "tags": ["anime", "realistic"]}, + {"name": "m4", "tags": ["nsfw"]}, + {"name": "m5", "tags": ["anime", "nsfw"]}, + ] + + # Include anime OR realistic, exclude nsfw + criteria = FilterCriteria( + tags={ + "anime": "include", + "realistic": "include", + "nsfw": "exclude" + }, + tag_logic="any" + ) + result = filter_set.apply(data, criteria) + # Should match m1 (anime), m2 (realistic), m3 (both) + # m4 excluded by nsfw, m5 excluded by nsfw + assert len(result) == 3 + assert {item["name"] for item in result} == {"m1", "m2", "m3"} + + def test_tag_logic_all_with_exclude_tags(self): + """Test that tag_logic='all' works correctly with exclude tags.""" + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime", "character"]}, + {"name": "m2", "tags": ["realistic", "character"]}, + {"name": "m3", "tags": ["anime", "realistic", "character"]}, + {"name": "m4", "tags": ["anime", "character", "nsfw"]}, + ] + + # Include anime AND character, exclude nsfw + criteria = FilterCriteria( + tags={ + "anime": "include", + "character": "include", + "nsfw": "exclude" + }, + tag_logic="all" + ) + result = filter_set.apply(data, criteria) + # m1: has anime+character, no nsfw ✓ + # m2: missing anime ✗ + # m3: has anime+character, no nsfw ✓ + # m4: has anime+character but also nsfw ✗ + assert len(result) == 2 + assert {item["name"] for item in result} == {"m1", "m3"} + + def test_tag_logic_all_with_no_tags_special_case(self): + """Test tag_logic='all' with __no_tags__ special tag. + + When __no_tags__ is used with 'all' logic along with regular tags, + the behavior is: items with no tags are returned (since they satisfy + __no_tags__), OR items that have all the regular tags. + This is because __no_tags__ is a special condition that can't be ANDed + with regular tags in a meaningful way. + """ + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime"]}, + {"name": "m2", "tags": []}, + {"name": "m3", "tags": None}, + {"name": "m4", "tags": ["anime", "character"]}, + ] + + # Include anime AND __no_tags__ with 'all' logic + # Implementation treats this as: no tags OR (all regular tags) + criteria = FilterCriteria( + tags={"anime": "include", "__no_tags__": "include"}, + tag_logic="all" + ) + result = filter_set.apply(data, criteria) + # Items with no tags: m2, m3 + # Items with all regular tags (anime): m1, m4 + # Combined: m1, m2, m3, m4 (all items) + assert len(result) == 4 + + def test_tag_logic_any_with_no_tags_special_case(self): + """Test tag_logic='any' with __no_tags__ special tag.""" + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime"]}, + {"name": "m2", "tags": []}, + {"name": "m3", "tags": None}, + {"name": "m4", "tags": ["realistic"]}, + ] + + # Include anime OR __no_tags__ + criteria = FilterCriteria( + tags={"anime": "include", "__no_tags__": "include"}, + tag_logic="any" + ) + result = filter_set.apply(data, criteria) + # Should match m1 (anime), m2 (no tags), m3 (no tags) + assert len(result) == 3 + assert {item["name"] for item in result} == {"m1", "m2", "m3"} + + def test_tag_logic_default_is_any(self): + """Test that default tag_logic is 'any' when not specified.""" + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime"]}, + {"name": "m2", "tags": ["realistic"]}, + {"name": "m3", "tags": ["anime", "realistic"]}, + ] + + # Not specifying tag_logic should default to 'any' + criteria = FilterCriteria( + tags={"anime": "include", "realistic": "include"} + ) + result = filter_set.apply(data, criteria) + # Should match m1, m2, m3 (OR behavior) + assert len(result) == 3 + assert {item["name"] for item in result} == {"m1", "m2", "m3"} + + def test_tag_logic_case_insensitive(self): + """Test that tag_logic values are case insensitive.""" + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime"]}, + {"name": "m2", "tags": ["realistic"]}, + {"name": "m3", "tags": ["anime", "realistic"]}, + ] + + # Test uppercase 'ALL' + criteria = FilterCriteria( + tags={"anime": "include", "realistic": "include"}, + tag_logic="ALL" + ) + result = filter_set.apply(data, criteria) + assert len(result) == 1 + assert result[0]["name"] == "m3" + + # Test mixed case 'Any' + criteria = FilterCriteria( + tags={"anime": "include", "realistic": "include"}, + tag_logic="Any" + ) + result = filter_set.apply(data, criteria) + assert len(result) == 3 + + def test_tag_logic_all_with_three_tags(self): + """Test tag_logic='all' with three include tags.""" + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime"]}, + {"name": "m2", "tags": ["anime", "character"]}, + {"name": "m3", "tags": ["anime", "character", "style"]}, + {"name": "m4", "tags": ["character", "style"]}, + ] + + # Include anime AND character AND style + criteria = FilterCriteria( + tags={ + "anime": "include", + "character": "include", + "style": "include" + }, + tag_logic="all" + ) + result = filter_set.apply(data, criteria) + # Only m3 has all three tags + assert len(result) == 1 + assert result[0]["name"] == "m3" + + def test_tag_logic_empty_include_tags(self): + """Test that empty include tags with any logic returns all items.""" + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime"]}, + {"name": "m2", "tags": ["realistic"]}, + ] + + # Only exclude tags, no include tags + criteria = FilterCriteria( + tags={"nsfw": "exclude"}, + tag_logic="all" + ) + result = filter_set.apply(data, criteria) + # Both should match since no include filters + assert len(result) == 2 + + def test_tag_logic_with_none_tags_field(self): + """Test tag_logic handles items with None tags field.""" + filter_set = ModelFilterSet(StubSettings()) + data = [ + {"name": "m1", "tags": ["anime", "realistic"]}, + {"name": "m2", "tags": None}, + {"name": "m3", "tags": ["anime"]}, + ] + + criteria = FilterCriteria( + tags={"anime": "include", "realistic": "include"}, + tag_logic="all" + ) + result = filter_set.apply(data, criteria) + # Only m1 has both anime and realistic + assert len(result) == 1 + assert result[0]["name"] == "m1"