diff --git a/.gitignore b/.gitignore index 708ef925..811d7168 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ model_cache/ vue-widgets/node_modules/ vue-widgets/.vite/ vue-widgets/dist/ + +# Hypothesis test cache +.hypothesis/ diff --git a/docs/testing/backend-testing-improvement-plan.md b/docs/testing/backend-testing-improvement-plan.md index d0c2c7a5..ad063ade 100644 --- a/docs/testing/backend-testing-improvement-plan.md +++ b/docs/testing/backend-testing-improvement-plan.md @@ -1,6 +1,6 @@ # Backend Testing Improvement Plan -**Status:** Phase 3 Complete ✅ +**Status:** Phase 4 Complete ✅ **Created:** 2026-02-11 **Updated:** 2026-02-11 **Priority:** P0 - Critical @@ -340,6 +340,54 @@ assert len(ws_manager.payloads) >= 2 # Started + completed --- +## Phase 4 Completion Summary (2026-02-11) + +### Completed Items + +1. **Property-Based Tests (Hypothesis)** ✅ + - Created `tests/utils/test_utils_hypothesis.py` with 19 property-based tests + - Tests cover: + - `sanitize_folder_name` idempotency and invalid character handling (4 tests) + - `_sanitize_library_name` idempotency and safe character filtering (2 tests) + - `normalize_path` idempotency and forward slash usage (2 tests) + - `fuzzy_match` edge cases and threshold behavior (3 tests) + - `determine_base_model` return type guarantees (2 tests) + - `get_preview_extension` return type validation (2 tests) + - `calculate_recipe_fingerprint` determinism and ordering (4 tests) + - Fixed Hypothesis plugin compatibility issue by creating a `MockModule` class in `conftest.py` that is hashable (unlike `types.SimpleNamespace`) + +2. **Snapshot Tests (Syrupy)** ✅ + - Created `tests/routes/test_api_snapshots.py` with 7 snapshot tests + - Tests cover: + - SettingsHandler response formats (2 tests) + - NodeRegistryHandler response formats (2 tests) + - Utility function output verification (2 tests) + - ModelLibraryHandler empty response format (1 test) + - All snapshots generated and tests passing (7/7) + +3. **Performance Benchmarks** ✅ + - Created `tests/performance/test_cache_performance.py` with 11 benchmark tests + - Tests cover: + - Hash index lookup performance (100, 1K, 10K models) - 3 tests + - Hash index add entry performance (100, 10K existing) - 2 tests + - Fuzzy matching performance (short text, long text, many words) - 3 tests + - Recipe fingerprint calculation (5, 50, 200 LoRAs) - 3 tests + - All benchmarks passing with performance metrics (11/11) + +4. **Package Dependencies** ✅ + - Added `hypothesis>=6.0` to `requirements-dev.txt` + - Added `syrupy>=5.0` to `requirements-dev.txt` + - Added `pytest-benchmark>=5.0` to `requirements-dev.txt` + +### Test Results +- **Property-Based Tests:** 19/19 passing +- **Snapshot Tests:** 7/7 passing +- **Performance Benchmarks:** 11/11 passing +- **Total New Tests Added:** 37 tests +- **Full Test Suite:** 947/947 passing + +--- + ## Phase 3 Completion Summary (2026-02-11) ### Completed Items @@ -569,12 +617,12 @@ def test_cache_lookup_performance(benchmark): - [x] Remove duplicate singleton reset fixtures (consolidated in conftest.py) ### Week 7-8: Advanced Testing -- [ ] Install hypothesis -- [ ] Add 10 property-based tests -- [ ] Install syrupy -- [ ] Add 5 snapshot tests -- [ ] Install pytest-benchmark -- [ ] Add 3 performance benchmarks +- [x] Install hypothesis (Added to requirements-dev.txt) +- [x] Add 10 property-based tests (Created 19 tests in test_utils_hypothesis.py) +- [x] Install syrupy (Added to requirements-dev.txt) +- [x] Add 5 snapshot tests (Created 7 tests in test_api_snapshots.py) +- [x] Install pytest-benchmark (Added to requirements-dev.txt) +- [x] Add 3 performance benchmarks (Created 11 tests in test_cache_performance.py) --- diff --git a/pytest.ini b/pytest.ini index 6d78039d..6aff1c3a 100644 --- a/pytest.ini +++ b/pytest.ini @@ -13,4 +13,4 @@ markers = no_settings_dir_isolation: allow tests to use real settings paths integration: integration tests requiring external resources # Skip problematic directories to avoid import conflicts -norecursedirs = .git .tox dist build *.egg __pycache__ py \ No newline at end of file +norecursedirs = .git .tox dist build *.egg __pycache__ py .hypothesis \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 66d1caa0..c65427a5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,3 +2,6 @@ pytest>=7.4 pytest-cov>=4.1 pytest-asyncio>=0.21.0 +hypothesis>=6.0 +syrupy>=5.0 +pytest-benchmark>=5.0 diff --git a/tests/conftest.py b/tests/conftest.py index 2110d455..ca4f6bdd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,27 @@ REPO_ROOT = Path(__file__).resolve().parents[1] PY_INIT = REPO_ROOT / "py" / "__init__.py" +class MockModule(types.ModuleType): + """A mock module class that is hashable (unlike SimpleNamespace). + + This allows the module to be stored in sets/dicts without causing issues + with tools like Hypothesis that iterate over sys.modules. + """ + + def __init__(self, name: str, **kwargs): + super().__init__(name) + for key, value in kwargs.items(): + setattr(self, key, value) + + def __hash__(self): + return hash(self.__name__) + + def __eq__(self, other): + if isinstance(other, MockModule): + return self.__name__ == other.__name__ + return NotImplemented + + def _load_repo_package(name: str) -> types.ModuleType: """Ensure the repository's ``py`` package is importable under *name*.""" @@ -41,32 +62,32 @@ _repo_package = _load_repo_package("py") sys.modules.setdefault("py_local", _repo_package) # Mock ComfyUI modules before any imports from the main project -server_mock = types.SimpleNamespace() +server_mock = MockModule("server") server_mock.PromptServer = mock.MagicMock() sys.modules['server'] = server_mock -folder_paths_mock = types.SimpleNamespace() +folder_paths_mock = MockModule("folder_paths") folder_paths_mock.get_folder_paths = mock.MagicMock(return_value=[]) folder_paths_mock.folder_names_and_paths = {} sys.modules['folder_paths'] = folder_paths_mock # Mock other ComfyUI modules that might be imported -comfy_mock = types.SimpleNamespace() -comfy_mock.utils = types.SimpleNamespace() -comfy_mock.model_management = types.SimpleNamespace() -comfy_mock.comfy_types = types.SimpleNamespace() +comfy_mock = MockModule("comfy") +comfy_mock.utils = MockModule("comfy.utils") +comfy_mock.model_management = MockModule("comfy.model_management") +comfy_mock.comfy_types = MockModule("comfy.comfy_types") comfy_mock.comfy_types.IO = mock.MagicMock() sys.modules['comfy'] = comfy_mock sys.modules['comfy.utils'] = comfy_mock.utils sys.modules['comfy.model_management'] = comfy_mock.model_management sys.modules['comfy.comfy_types'] = comfy_mock.comfy_types -execution_mock = types.SimpleNamespace() +execution_mock = MockModule("execution") execution_mock.PromptExecutor = mock.MagicMock() sys.modules['execution'] = execution_mock # Mock ComfyUI nodes module -nodes_mock = types.SimpleNamespace() +nodes_mock = MockModule("nodes") nodes_mock.LoraLoader = mock.MagicMock() nodes_mock.SaveImage = mock.MagicMock() nodes_mock.NODE_CLASS_MAPPINGS = {} diff --git a/tests/performance/test_cache_performance.py b/tests/performance/test_cache_performance.py new file mode 100644 index 00000000..d1c67349 --- /dev/null +++ b/tests/performance/test_cache_performance.py @@ -0,0 +1,174 @@ +"""Performance benchmarks using pytest-benchmark. + +These tests measure the performance of critical operations to detect +regressions and ensure acceptable performance with large datasets. +""" + +from __future__ import annotations + +import random +import string +import pytest + +from py.services.model_hash_index import ModelHashIndex +from py.utils.utils import fuzzy_match, calculate_recipe_fingerprint + + +class TestHashIndexPerformance: + """Performance benchmarks for hash index operations.""" + + def test_hash_index_lookup_small(self, benchmark): + """Benchmark hash index lookup with 100 models.""" + index, target_hash = self._create_hash_index_with_n_models(100, return_target=True) + + def lookup(): + return index.get_path(target_hash) + + result = benchmark(lookup) + assert result is not None + + def test_hash_index_lookup_medium(self, benchmark): + """Benchmark hash index lookup with 1,000 models.""" + index, target_hash = self._create_hash_index_with_n_models(1000, return_target=True) + + def lookup(): + return index.get_path(target_hash) + + result = benchmark(lookup) + assert result is not None + + def test_hash_index_lookup_large(self, benchmark): + """Benchmark hash index lookup with 10,000 models.""" + index, target_hash = self._create_hash_index_with_n_models(10000, return_target=True) + + def lookup(): + return index.get_path(target_hash) + + result = benchmark(lookup) + assert result is not None + + def test_hash_index_add_entry_small(self, benchmark): + """Benchmark adding entries to hash index with 100 existing models.""" + index = self._create_hash_index_with_n_models(100) + new_hash = f"new_hash_{self._random_string(16)}" + new_path = "/path/to/new_model.safetensors" + + def add_entry(): + index.add_entry(new_hash, new_path) + + benchmark(add_entry) + + def test_hash_index_add_entry_large(self, benchmark): + """Benchmark adding entries to hash index with 10,000 existing models.""" + index = self._create_hash_index_with_n_models(10000) + new_hash = f"new_hash_{self._random_string(16)}" + new_path = "/path/to/new_model.safetensors" + + def add_entry(): + index.add_entry(new_hash, new_path) + + benchmark(add_entry) + + def _create_hash_index_with_n_models(self, n: int, return_target: bool = False): + """Create a hash index with n mock models. + + Args: + n: Number of models to create + return_target: If True, returns the hash of the middle model for lookup testing + + Returns: + ModelHashIndex or tuple of (ModelHashIndex, target_hash) + """ + index = ModelHashIndex() + target_hash = None + target_index = n // 2 + for i in range(n): + sha256 = f"hash_{i:08d}_{self._random_string(24)}" + file_path = f"/path/to/model_{i}.safetensors" + index.add_entry(sha256, file_path) + if i == target_index: + target_hash = sha256 + if return_target: + return index, target_hash + return index + + def _random_string(self, length: int) -> str: + """Generate a random string of fixed length.""" + return ''.join(random.choices(string.ascii_lowercase + string.digits, k=length)) + + +class TestFuzzyMatchPerformance: + """Performance benchmarks for fuzzy matching.""" + + def test_fuzzy_match_short_text(self, benchmark): + """Benchmark fuzzy matching with short text.""" + text = "lora model for character generation" + pattern = "character lora" + + def match(): + return fuzzy_match(text, pattern) + + benchmark(match) + + def test_fuzzy_match_long_text(self, benchmark): + """Benchmark fuzzy matching with long text.""" + text = "This is a very long description of a LoRA model that contains many words and details about what it does and how it works for character generation in stable diffusion" + pattern = "character generation stable diffusion" + + def match(): + return fuzzy_match(text, pattern) + + benchmark(match) + + def test_fuzzy_match_many_words(self, benchmark): + """Benchmark fuzzy matching with many search words.""" + text = "lora model anime style character portrait high quality detailed" + pattern = "anime style character portrait high quality" + + def match(): + return fuzzy_match(text, pattern) + + benchmark(match) + + +class TestRecipeFingerprintPerformance: + """Performance benchmarks for recipe fingerprint calculation.""" + + def test_fingerprint_small_recipe(self, benchmark): + """Benchmark fingerprint calculation with 5 LoRAs.""" + loras = self._create_loras(5) + + def calculate(): + return calculate_recipe_fingerprint(loras) + + benchmark(calculate) + + def test_fingerprint_medium_recipe(self, benchmark): + """Benchmark fingerprint calculation with 50 LoRAs.""" + loras = self._create_loras(50) + + def calculate(): + return calculate_recipe_fingerprint(loras) + + benchmark(calculate) + + def test_fingerprint_large_recipe(self, benchmark): + """Benchmark fingerprint calculation with 200 LoRAs.""" + loras = self._create_loras(200) + + def calculate(): + return calculate_recipe_fingerprint(loras) + + benchmark(calculate) + + def _create_loras(self, n: int) -> list: + """Create a list of n mock LoRA dictionaries.""" + loras = [] + for i in range(n): + lora = { + "hash": f"abc{i:08d}", + "strength": round(random.uniform(0.0, 2.0), 2), + "modelVersionId": i, + } + loras.append(lora) + return loras diff --git a/tests/routes/__snapshots__/test_api_snapshots.ambr b/tests/routes/__snapshots__/test_api_snapshots.ambr new file mode 100644 index 00000000..36c4814a --- /dev/null +++ b/tests/routes/__snapshots__/test_api_snapshots.ambr @@ -0,0 +1,67 @@ +# serializer version: 1 +# name: TestModelLibraryHandlerSnapshots.test_check_model_exists_empty_response + dict({ + 'modelType': None, + 'success': True, + 'versions': list([ + ]), + }) +# --- +# name: TestNodeRegistryHandlerSnapshots.test_register_nodes_error_response + dict({ + 'message': '0 nodes registered successfully', + 'success': True, + }) +# --- +# name: TestNodeRegistryHandlerSnapshots.test_register_nodes_success_response + dict({ + 'message': '1 nodes registered successfully', + 'success': True, + }) +# --- +# name: TestSettingsHandlerSnapshots.test_get_settings_response_format + dict({ + 'messages': list([ + ]), + 'settings': dict({ + 'civitai_api_key': 'test-key', + 'language': 'en', + }), + 'success': True, + }) +# --- +# name: TestSettingsHandlerSnapshots.test_update_settings_success_response + dict({ + 'success': True, + }) +# --- +# name: TestUtilityFunctionSnapshots.test_calculate_recipe_fingerprint_various_inputs + list([ + '', + 'abc123:1.0', + 'abc123:1.0|def456:0.75', + 'abc123:0.5|def456:1.0', + 'abc123:0.8', + '12345:1.0', + '', + '', + '', + ]) +# --- +# name: TestUtilityFunctionSnapshots.test_sanitize_folder_name_various_inputs + dict({ + '': '', + ' spaces ': 'spaces', + '___underscores___': 'underscores', + 'folder with spaces': 'folder with spaces', + 'folder"with"quotes': 'folder_with_quotes', + 'folder*with*asterisks': 'folder_with_asterisks', + 'folder.with.dots': 'folder.with.dots', + 'folder/with/slashes': 'folder_with_slashes', + 'folderbrackets': 'folder_with_brackets', + 'folder?with?questions': 'folder_with_questions', + 'folder\\with\\backslashes': 'folder_with_backslashes', + 'folder|with|pipes': 'folder_with_pipes', + 'normal_folder': 'normal_folder', + }) +# --- diff --git a/tests/routes/test_api_snapshots.py b/tests/routes/test_api_snapshots.py new file mode 100644 index 00000000..029d65be --- /dev/null +++ b/tests/routes/test_api_snapshots.py @@ -0,0 +1,230 @@ +"""Snapshot tests for API response formats using Syrupy. + +These tests verify that API responses maintain consistent structure and format +by comparing against stored snapshots. This catches unexpected changes to +response schemas. +""" + +from __future__ import annotations + +import json +import pytest +from types import SimpleNamespace +from syrupy import SnapshotAssertion + +from py.routes.handlers.misc_handlers import ( + ModelLibraryHandler, + NodeRegistry, + NodeRegistryHandler, + ServiceRegistryAdapter, + SettingsHandler, +) +from py.utils.utils import calculate_recipe_fingerprint, sanitize_folder_name + + +class FakeRequest: + """Fake HTTP request for testing.""" + + def __init__(self, *, json_data=None, query=None): + self._json_data = json_data or {} + self.query = query or {} + + async def json(self): + return self._json_data + + +class DummySettings: + """Dummy settings service for testing.""" + + def __init__(self, data=None): + self.data = data or {} + + def get(self, key, default=None): + return self.data.get(key, default) + + def set(self, key, value): + self.data[key] = value + + +async def noop_async(*_args, **_kwargs): + """No-op async function.""" + return None + + +class FakePromptServer: + """Fake prompt server for testing.""" + + sent = [] + + class Instance: + def send_sync(self, event, payload): + FakePromptServer.sent.append((event, payload)) + + instance = Instance() + + +class TestSettingsHandlerSnapshots: + """Snapshot tests for SettingsHandler responses.""" + + @pytest.mark.asyncio + async def test_get_settings_response_format(self, snapshot: SnapshotAssertion): + """Verify get_settings response format matches snapshot.""" + settings_service = DummySettings({ + "civitai_api_key": "test-key", + "language": "en", + "theme": "dark" + }) + handler = SettingsHandler( + settings_service=settings_service, + metadata_provider_updater=noop_async, + downloader_factory=lambda: None, + ) + + response = await handler.get_settings(FakeRequest()) + payload = json.loads(response.text) + + assert payload == snapshot + + @pytest.mark.asyncio + async def test_update_settings_success_response(self, snapshot: SnapshotAssertion): + """Verify successful update_settings response format.""" + settings_service = DummySettings() + handler = SettingsHandler( + settings_service=settings_service, + metadata_provider_updater=noop_async, + downloader_factory=lambda: None, + ) + + request = FakeRequest(json_data={"language": "zh"}) + response = await handler.update_settings(request) + payload = json.loads(response.text) + + assert payload == snapshot + + +class TestNodeRegistryHandlerSnapshots: + """Snapshot tests for NodeRegistryHandler responses.""" + + @pytest.mark.asyncio + async def test_register_nodes_success_response(self, snapshot: SnapshotAssertion): + """Verify successful register_nodes response format.""" + node_registry = NodeRegistry() + handler = NodeRegistryHandler( + node_registry=node_registry, + prompt_server=FakePromptServer, + standalone_mode=False, + ) + + request = FakeRequest( + json_data={ + "nodes": [ + { + "node_id": 1, + "graph_id": "root", + "type": "Lora Loader (LoraManager)", + "title": "Test Loader", + } + ] + } + ) + + response = await handler.register_nodes(request) + payload = json.loads(response.text) + + assert payload == snapshot + + @pytest.mark.asyncio + async def test_register_nodes_error_response(self, snapshot: SnapshotAssertion): + """Verify error register_nodes response format.""" + node_registry = NodeRegistry() + handler = NodeRegistryHandler( + node_registry=node_registry, + prompt_server=FakePromptServer, + standalone_mode=False, + ) + + request = FakeRequest(json_data={"nodes": []}) + response = await handler.register_nodes(request) + payload = json.loads(response.text) + + assert payload == snapshot + + +class TestUtilityFunctionSnapshots: + """Snapshot tests for utility function outputs.""" + + def test_sanitize_folder_name_various_inputs(self, snapshot: SnapshotAssertion): + """Verify sanitize_folder_name produces expected outputs.""" + test_inputs = [ + "normal_folder", + "folder with spaces", + "folder/with/slashes", + 'folder\\with\\backslashes', + 'folderbrackets', + 'folder"with"quotes', + 'folder|with|pipes', + 'folder?with?questions', + 'folder*with*asterisks', + '', + ' spaces ', + 'folder.with.dots', + '___underscores___', + ] + + results = {input_name: sanitize_folder_name(input_name) for input_name in test_inputs} + assert results == snapshot + + def test_calculate_recipe_fingerprint_various_inputs(self, snapshot: SnapshotAssertion): + """Verify calculate_recipe_fingerprint produces expected outputs.""" + test_cases = [ + [], + [{"hash": "abc123", "strength": 1.0}], + [ + {"hash": "abc123", "strength": 1.0}, + {"hash": "def456", "strength": 0.75}, + ], + [ + {"hash": "DEF456", "strength": 1.0}, + {"hash": "ABC123", "strength": 0.5}, + ], + [{"hash": "abc123", "weight": 0.8}], + [{"modelVersionId": 12345, "strength": 1.0}], + [{"hash": "abc123", "exclude": True, "strength": 1.0}], + [{"hash": "", "strength": 1.0}], + [{"strength": 1.0}], + ] + + results = [calculate_recipe_fingerprint(loras) for loras in test_cases] + assert results == snapshot + + +class TestModelLibraryHandlerSnapshots: + """Snapshot tests for ModelLibraryHandler responses.""" + + @pytest.mark.asyncio + async def test_check_model_exists_empty_response(self, snapshot: SnapshotAssertion): + """Verify check_model_exists with no versions response format.""" + + class EmptyVersionScanner: + async def check_model_version_exists(self, _version_id): + return False + + async def get_model_versions_by_id(self, _model_id): + return [] + + async def scanner_factory(): + return EmptyVersionScanner() + + handler = ModelLibraryHandler( + ServiceRegistryAdapter( + get_lora_scanner=scanner_factory, + get_checkpoint_scanner=scanner_factory, + get_embedding_scanner=scanner_factory, + ), + metadata_provider_factory=lambda: None, + ) + + response = await handler.check_model_exists(FakeRequest(query={"modelId": "1"})) + payload = json.loads(response.text) + + assert payload == snapshot diff --git a/tests/utils/test_utils_hypothesis.py b/tests/utils/test_utils_hypothesis.py new file mode 100644 index 00000000..cbf6a59a --- /dev/null +++ b/tests/utils/test_utils_hypothesis.py @@ -0,0 +1,193 @@ +"""Property-based tests using Hypothesis. + +These tests verify fundamental properties of utility functions using +property-based testing to catch edge cases and ensure correctness. +""" + +from __future__ import annotations + +import pytest +from hypothesis import given, settings, strategies as st + +from py.utils.cache_paths import _sanitize_library_name +from py.utils.file_utils import get_preview_extension, normalize_path +from py.utils.model_utils import determine_base_model +from py.utils.utils import ( + calculate_recipe_fingerprint, + fuzzy_match, + sanitize_folder_name, +) + + +class TestSanitizeFolderName: + """Property-based tests for sanitize_folder_name function.""" + + @given(st.text(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._- ')) + def test_sanitize_is_idempotent_for_ascii(self, name: str): + """Sanitizing an already sanitized ASCII name should not change it.""" + sanitized = sanitize_folder_name(name) + resanitized = sanitize_folder_name(sanitized) + assert sanitized == resanitized + + @given(st.text()) + def test_sanitize_never_contains_invalid_chars(self, name: str): + """Sanitized names should never contain filesystem-invalid characters.""" + sanitized = sanitize_folder_name(name) + invalid_chars = '<>:"/\\|?*\x00\x01\x02\x03\x04\x05\x06\x07\x08' + for char in invalid_chars: + assert char not in sanitized + + @given(st.text()) + def test_sanitize_never_returns_none(self, name: str): + """Sanitize should never return None (always returns a string).""" + result = sanitize_folder_name(name) + assert result is not None + assert isinstance(result, str) + + @given(st.text(min_size=1)) + def test_sanitize_preserves_some_content(self, name: str): + """Sanitizing a non-empty string should not produce an empty result + unless the input was only invalid characters.""" + result = sanitize_folder_name(name) + # If input had valid characters, output should not be empty + has_valid_chars = any(c.isalnum() or c in '._-' for c in name) + if has_valid_chars: + assert result != "" + + +class TestSanitizeLibraryName: + """Property-based tests for _sanitize_library_name function.""" + + @given(st.text() | st.none()) + def test_sanitize_library_name_is_idempotent(self, library_name: str | None): + """Sanitizing an already sanitized library name should not change it.""" + sanitized = _sanitize_library_name(library_name) + resanitized = _sanitize_library_name(sanitized) + assert sanitized == resanitized + + @given(st.text()) + def test_sanitize_library_name_only_contains_safe_chars(self, library_name: str): + """Sanitized library names should only contain safe filename characters.""" + sanitized = _sanitize_library_name(library_name) + # Should only contain alphanumeric, underscore, dot, and hyphen + for char in sanitized: + assert char.isalnum() or char in '._-' + + +class TestNormalizePath: + """Property-based tests for normalize_path function.""" + + @given(st.text(alphabet='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._-/\\') | st.none()) + def test_normalize_path_is_idempotent_for_ascii(self, path: str | None): + """Normalizing an already normalized ASCII path should not change it.""" + normalized = normalize_path(path) + renormalized = normalize_path(normalized) + assert normalized == renormalized + + @given(st.text()) + def test_normalized_path_returns_string(self, path: str): + """Normalized path should always return a string (or None).""" + normalized = normalize_path(path) + # Result is either None or a string + assert normalized is None or isinstance(normalized, str) + + +class TestFuzzyMatch: + """Property-based tests for fuzzy_match function.""" + + @given(st.text(), st.text()) + def test_fuzzy_match_empty_pattern_returns_false(self, text: str, pattern: str): + """Empty pattern should never match (except empty text with exact match).""" + if not pattern: + result = fuzzy_match(text, pattern) + assert result is False + + @given(st.text(min_size=1), st.text(min_size=1)) + def test_fuzzy_match_exact_substring_always_matches(self, text: str, pattern: str): + """If pattern is a substring of text (case-insensitive), it should match.""" + # Create a case where pattern is definitely in text + combined = text.lower() + " " + pattern.lower() + result = fuzzy_match(combined, pattern.lower()) + assert result is True + + @given(st.text(min_size=1), st.text(min_size=1)) + def test_fuzzy_match_substring_always_matches(self, text: str, pattern: str): + """If pattern is a substring of text, it should always match.""" + if pattern in text: + result = fuzzy_match(text, pattern) + assert result is True + + +class TestDetermineBaseModel: + """Property-based tests for determine_base_model function.""" + + @given(st.text() | st.none()) + def test_determine_base_model_never_returns_none(self, version_string: str | None): + """Function should never return None (always returns a string).""" + result = determine_base_model(version_string) + assert result is not None + assert isinstance(result, str) + + @given(st.text()) + def test_determine_base_model_case_insensitive(self, version: str): + """Base model detection should be case-insensitive.""" + lower_result = determine_base_model(version.lower()) + upper_result = determine_base_model(version.upper()) + # Results should be the same for known mappings + if version.lower() in ['sdxl', 'sd_1.5', 'pony', 'flux1']: + assert lower_result == upper_result + + +class TestGetPreviewExtension: + """Property-based tests for get_preview_extension function.""" + + @given(st.text()) + def test_get_preview_extension_returns_string(self, preview_path: str): + """Function should always return a string.""" + result = get_preview_extension(preview_path) + assert isinstance(result, str) + + @given(st.text(alphabet='abcdefghijklmnopqrstuvwxyz._')) + def test_get_preview_extension_starts_with_dot(self, preview_path: str): + """Extension should always start with a dot for valid paths.""" + if '.' in preview_path: + result = get_preview_extension(preview_path) + if result: + assert result.startswith('.') + + +class TestCalculateRecipeFingerprint: + """Property-based tests for calculate_recipe_fingerprint function.""" + + @given(st.lists(st.dictionaries(st.text(), st.text() | st.integers() | st.floats(), min_size=1), min_size=0, max_size=50)) + def test_fingerprint_is_deterministic(self, loras: list): + """Same input should always produce same fingerprint.""" + fp1 = calculate_recipe_fingerprint(loras) + fp2 = calculate_recipe_fingerprint(loras) + assert fp1 == fp2 + + @given(st.lists(st.dictionaries(st.text(), st.text() | st.integers() | st.floats(), min_size=1), min_size=0, max_size=50)) + def test_fingerprint_returns_string(self, loras: list): + """Function should always return a string.""" + result = calculate_recipe_fingerprint(loras) + assert isinstance(result, str) + + def test_fingerprint_empty_list_returns_empty_string(self): + """Empty list should return empty string.""" + result = calculate_recipe_fingerprint([]) + assert result == "" + + @given(st.lists(st.dictionaries(st.text(), st.text() | st.integers() | st.floats(), min_size=1), min_size=1, max_size=10)) + def test_fingerprint_different_inputs_produce_different_results(self, loras1: list): + """Different inputs should generally produce different fingerprints.""" + # Create a different input by modifying the first LoRA + loras2 = loras1.copy() + if loras2: + loras2[0] = {**loras2[0], 'hash': 'different_hash_12345'} + + fp1 = calculate_recipe_fingerprint(loras1) + fp2 = calculate_recipe_fingerprint(loras2) + + # If the first LoRA had a hash, fingerprints should differ + if loras1 and loras1[0].get('hash'): + assert fp1 != fp2