Files
ComfyUI-Lora-Manager/tests/test_persistent_recipe_cache.py
Will Miao e335a527d4 test: Complete Phase 2 - Integration & Coverage improvements
- Create tests/integration/ directory with conftest.py fixtures
- Add 7 download flow integration tests (test_download_flow.py)
- Add 9 recipe flow integration tests (test_recipe_flow.py)
- Add 12 ModelLifecycleService tests (exclude_model, bulk_delete, error paths)
- Add 5 PersistentRecipeCache concurrent access tests
- Update backend-testing-improvement-plan.md with Phase 2 completion

Total: 28 new tests, all passing (51/51)
2026-02-11 10:55:19 +08:00

468 lines
16 KiB
Python

"""Tests for PersistentRecipeCache."""
import json
import os
import tempfile
import threading
import time
from typing import Dict, List

import pytest

from py.services.persistent_recipe_cache import PersistentRecipeCache, PersistedRecipeData


@pytest.fixture
def temp_db_path(tmp_path):
    """Return the path of a not-yet-existing SQLite database file.

    The path points into pytest's per-test ``tmp_path`` directory. The file is
    deliberately NOT created here, so tests asserting that
    ``PersistentRecipeCache`` creates the database (``os.path.exists``) are
    meaningful. Cleanup of the directory, including any SQLite side files
    (``-wal``, ``-shm``, ``-journal``), is left to pytest's tmp_path
    management. That also avoids ``os.unlink`` errors on platforms that can
    refuse to delete a file a test's connection still holds open.

    Returns:
        The database path as a ``str``, matching the type previously yielded.
    """
    return str(tmp_path / "recipe_cache.sqlite")


@pytest.fixture
def sample_recipes() -> List[Dict]:
    """Provide two freshly built recipe dicts for the cache tests.

    The first record fills in every field the tests inspect: a non-empty
    folder, a checkpoint, two loras and two tags. The second is sparse: an
    empty folder, no ``checkpoint`` key, a single lora and an empty tag list.
    This lets round-trip tests cover both shapes.
    """
    full_record = {
        "id": "recipe-001",
        "file_path": "/path/to/image1.png",
        "title": "Test Recipe 1",
        "folder": "folder1",
        "base_model": "SD1.5",
        "fingerprint": "abc123",
        "created_date": 1700000000.0,
        "modified": 1700000100.0,
        "favorite": True,
        "repair_version": 3,
        "preview_nsfw_level": 1,
        "loras": [
            {"hash": "hash1", "file_name": "lora1", "strength": 0.8},
            {"hash": "hash2", "file_name": "lora2", "strength": 1.0},
        ],
        "checkpoint": {"name": "model.safetensors", "hash": "cphash"},
        "gen_params": {"prompt": "test prompt", "negative_prompt": "bad"},
        "tags": ["tag1", "tag2"],
    }

    sparse_record = {
        "id": "recipe-002",
        "file_path": "/path/to/image2.png",
        "title": "Test Recipe 2",
        "folder": "",
        "base_model": "SDXL",
        "fingerprint": "def456",
        "created_date": 1700000200.0,
        "modified": 1700000300.0,
        "favorite": False,
        "repair_version": 2,
        "preview_nsfw_level": 0,
        "loras": [{"hash": "hash3", "file_name": "lora3", "strength": 0.5}],
        "gen_params": {"prompt": "another prompt"},
        "tags": [],
    }

    return [full_record, sparse_record]


class TestPersistentRecipeCache:
    """Tests for PersistentRecipeCache class."""

    @staticmethod
    def _start_and_join(threads, timeout=30.0):
        """Run ``threads`` to completion, failing instead of hanging on a deadlock.

        Threads are marked as daemons before starting so a stuck worker cannot
        keep the interpreter alive after the test run. Each ``join`` is bounded
        by ``timeout`` seconds, and any thread still alive afterwards fails the
        calling test.
        """
        for thread in threads:
            thread.daemon = True
            thread.start()
        for thread in threads:
            thread.join(timeout)
        stuck = [thread.name for thread in threads if thread.is_alive()]
        assert not stuck, f"Threads still running after {timeout}s: {stuck}"

    def test_init_creates_db(self, temp_db_path):
        """Test that initialization creates the database."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        assert cache.is_enabled()
        # The database file must exist once the cache has been constructed.
        assert os.path.exists(temp_db_path)

    def test_save_and_load_roundtrip(self, temp_db_path, sample_recipes):
        """Test save and load cycle preserves data."""
        cache = PersistentRecipeCache(db_path=temp_db_path)

        # Save recipes
        json_paths = {
            "recipe-001": "/path/to/recipe-001.recipe.json",
            "recipe-002": "/path/to/recipe-002.recipe.json",
        }
        cache.save_cache(sample_recipes, json_paths)

        # Load recipes
        loaded = cache.load_cache()
        assert loaded is not None
        assert len(loaded.raw_data) == 2

        # Verify first recipe
        r1 = next(r for r in loaded.raw_data if r["id"] == "recipe-001")
        assert r1["title"] == "Test Recipe 1"
        assert r1["folder"] == "folder1"
        assert r1["base_model"] == "SD1.5"
        assert r1["fingerprint"] == "abc123"
        assert r1["favorite"] is True
        assert r1["repair_version"] == 3
        assert len(r1["loras"]) == 2
        assert r1["loras"][0]["hash"] == "hash1"
        assert r1["checkpoint"]["name"] == "model.safetensors"
        assert r1["gen_params"]["prompt"] == "test prompt"
        assert r1["tags"] == ["tag1", "tag2"]

        # Verify second recipe
        r2 = next(r for r in loaded.raw_data if r["id"] == "recipe-002")
        assert r2["title"] == "Test Recipe 2"
        assert r2["folder"] == ""
        assert r2["favorite"] is False

    def test_empty_cache_returns_none(self, temp_db_path):
        """Test that loading empty cache returns None."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        loaded = cache.load_cache()
        assert loaded is None

    def test_update_single_recipe(self, temp_db_path, sample_recipes):
        """Test updating a single recipe."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        cache.save_cache(sample_recipes)

        # Update a (shallow) copy so the fixture's dict is left untouched.
        updated_recipe = dict(sample_recipes[0])
        updated_recipe["title"] = "Updated Title"
        updated_recipe["favorite"] = False
        cache.update_recipe(updated_recipe, "/path/to/recipe-001.recipe.json")

        # Load and verify
        loaded = cache.load_cache()
        assert loaded is not None
        r1 = next(r for r in loaded.raw_data if r["id"] == "recipe-001")
        assert r1["title"] == "Updated Title"
        assert r1["favorite"] is False

    def test_remove_recipe(self, temp_db_path, sample_recipes):
        """Test removing a recipe."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        cache.save_cache(sample_recipes)

        # Remove a recipe
        cache.remove_recipe("recipe-001")

        # Load and verify
        loaded = cache.load_cache()
        assert loaded is not None
        assert len(loaded.raw_data) == 1
        assert loaded.raw_data[0]["id"] == "recipe-002"

    def test_get_indexed_recipe_ids(self, temp_db_path, sample_recipes):
        """Test getting all indexed recipe IDs."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        cache.save_cache(sample_recipes)

        ids = cache.get_indexed_recipe_ids()
        assert ids == {"recipe-001", "recipe-002"}

    def test_get_recipe_count(self, temp_db_path, sample_recipes):
        """Test getting recipe count."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        assert cache.get_recipe_count() == 0

        cache.save_cache(sample_recipes)
        assert cache.get_recipe_count() == 2

        cache.remove_recipe("recipe-001")
        assert cache.get_recipe_count() == 1

    def test_file_stats(self, temp_db_path, sample_recipes):
        """Test file stats tracking."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        json_paths = {
            "recipe-001": "/path/to/recipe-001.recipe.json",
            "recipe-002": "/path/to/recipe-002.recipe.json",
        }
        cache.save_cache(sample_recipes, json_paths)

        stats = cache.get_file_stats()
        # The JSON paths above do not exist on disk. Only the number of
        # tracked entries is checked here; the per-entry stat values for
        # missing files are presumably zeroed (TODO confirm against the
        # implementation).
        assert len(stats) == 2

    def test_disabled_cache(self, temp_db_path, sample_recipes, monkeypatch):
        """Test that disabled cache returns None."""
        monkeypatch.setenv("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "1")
        cache = PersistentRecipeCache(db_path=temp_db_path)
        assert not cache.is_enabled()
        cache.save_cache(sample_recipes)
        assert cache.load_cache() is None

    def test_invalid_recipe_skipped(self, temp_db_path):
        """Test that recipes without ID are skipped."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        recipes = [
            {"title": "No ID recipe"},  # Missing ID
            {"id": "valid-001", "title": "Valid recipe"},
        ]
        cache.save_cache(recipes)

        loaded = cache.load_cache()
        assert loaded is not None
        assert len(loaded.raw_data) == 1
        assert loaded.raw_data[0]["id"] == "valid-001"

    def test_get_default_singleton(self, monkeypatch, tmp_path):
        """Test singleton behavior."""
        # Point the default cache at a per-test location; pytest manages the
        # cleanup of tmp_path.
        monkeypatch.setenv(
            "LORA_MANAGER_RECIPE_CACHE_DB", str(tmp_path / "test.sqlite")
        )
        PersistentRecipeCache.clear_instances()
        try:
            cache1 = PersistentRecipeCache.get_default("test_lib")
            cache2 = PersistentRecipeCache.get_default("test_lib")
            assert cache1 is cache2

            cache3 = PersistentRecipeCache.get_default("other_lib")
            assert cache1 is not cache3
        finally:
            # Always reset the registry so a failed assertion cannot leak
            # singleton instances into later tests.
            PersistentRecipeCache.clear_instances()

    def test_loras_json_handling(self, temp_db_path):
        """Test that complex loras data is preserved."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        recipes = [
            {
                "id": "complex-001",
                "title": "Complex Loras",
                "loras": [
                    {
                        "hash": "abc123",
                        "file_name": "test_lora",
                        "strength": 0.75,
                        "modelVersionId": 12345,
                        "modelName": "Test Model",
                        "isDeleted": False,
                    },
                    {
                        "hash": "def456",
                        "file_name": "another_lora",
                        "strength": 1.0,
                        "clip_strength": 0.8,
                    },
                ],
            }
        ]
        cache.save_cache(recipes)

        loaded = cache.load_cache()
        assert loaded is not None
        loras = loaded.raw_data[0]["loras"]
        assert len(loras) == 2
        assert loras[0]["modelVersionId"] == 12345
        assert loras[1]["clip_strength"] == 0.8

    # =========================================================================
    # Tests for concurrent access (from Phase 2 improvement plan)
    # =========================================================================

    def test_concurrent_reads_do_not_corrupt_data(self, temp_db_path, sample_recipes):
        """Verify concurrent reads don't corrupt database state."""
        reader_count = 5
        reads_per_thread = 10
        cache = PersistentRecipeCache(db_path=temp_db_path)
        cache.save_cache(sample_recipes)

        results = []
        errors = []

        def read_operation():
            try:
                for _ in range(reads_per_thread):
                    loaded = cache.load_cache()
                    if loaded is not None:
                        results.append(len(loaded.raw_data))
                    time.sleep(0.01)
            except Exception as e:
                errors.append(str(e))

        # Start multiple reader threads
        threads = [threading.Thread(target=read_operation) for _ in range(reader_count)]
        self._start_and_join(threads)

        # No errors should occur
        assert len(errors) == 0, f"Errors during concurrent reads: {errors}"
        # Every read must have returned data; otherwise the consistency check
        # below would pass vacuously on an empty ``results`` list.
        expected_reads = reader_count * reads_per_thread
        assert len(results) == expected_reads, (
            f"Only {len(results)} of {expected_reads} reads returned data"
        )
        # All reads should return consistent data
        assert all(count == 2 for count in results), "Inconsistent read results"

    def test_concurrent_write_and_read(self, temp_db_path, sample_recipes):
        """Verify thread safety under concurrent writes and reads."""
        writer_count = 2
        writes_per_thread = 5
        cache = PersistentRecipeCache(db_path=temp_db_path)
        cache.save_cache(sample_recipes)

        write_errors = []
        read_errors = []
        # Record each finished write with list.append rather than a shared
        # ``counter[0] += 1``: the latter is a read-modify-write that two
        # writer threads can interleave and lose increments on.
        completed_writes = []

        def write_operation():
            try:
                for i in range(writes_per_thread):
                    recipe = {
                        "id": f"concurrent-{i}",
                        "title": f"Concurrent Recipe {i}",
                    }
                    cache.update_recipe(recipe)
                    completed_writes.append(i)
                    time.sleep(0.02)
            except Exception as e:
                write_errors.append(str(e))

        def read_operation():
            try:
                for _ in range(10):
                    cache.load_cache()
                    cache.get_recipe_count()
                    time.sleep(0.01)
            except Exception as e:
                read_errors.append(str(e))

        # Mix of read and write threads
        threads = (
            [threading.Thread(target=write_operation) for _ in range(writer_count)]
            + [threading.Thread(target=read_operation) for _ in range(3)]
        )
        self._start_and_join(threads)

        # No errors should occur
        assert len(write_errors) == 0, f"Write errors: {write_errors}"
        assert len(read_errors) == 0, f"Read errors: {read_errors}"
        # Every write should complete successfully
        assert len(completed_writes) == writer_count * writes_per_thread
        # The database must still be readable after the concurrent traffic.
        assert cache.load_cache() is not None

    def test_concurrent_updates_to_same_recipe(self, temp_db_path):
        """Verify concurrent updates to the same recipe don't corrupt data."""
        thread_count = 3
        updates_per_thread = 5
        cache = PersistentRecipeCache(db_path=temp_db_path)

        # Initialize with one recipe
        initial_recipe = {
            "id": "concurrent-update",
            "title": "Initial Title",
            "version": 1,
        }
        cache.save_cache([initial_recipe])

        errors = []
        successful_updates = []

        def update_operation(thread_id):
            try:
                for i in range(updates_per_thread):
                    recipe = {
                        "id": "concurrent-update",
                        "title": f"Title from thread {thread_id} update {i}",
                        "version": i + 1,
                    }
                    cache.update_recipe(recipe)
                    successful_updates.append((thread_id, i))
            except Exception as e:
                errors.append(f"Thread {thread_id}: {e}")

        # Multiple threads updating the same recipe
        threads = [
            threading.Thread(target=update_operation, args=(i,))
            for i in range(thread_count)
        ]
        self._start_and_join(threads)

        # No errors should occur
        assert len(errors) == 0, f"Update errors: {errors}"
        # All updates should complete
        assert len(successful_updates) == thread_count * updates_per_thread

        # Final state should be valid: exactly one row, and its title must be
        # one of the values actually written (not torn or mixed data).
        final_count = cache.get_recipe_count()
        assert final_count == 1
        loaded = cache.load_cache()
        assert loaded is not None
        written_titles = {
            f"Title from thread {t} update {i}"
            for t in range(thread_count)
            for i in range(updates_per_thread)
        }
        assert loaded.raw_data[0]["title"] in written_titles

    def test_schema_initialization_thread_safety(self, temp_db_path):
        """Verify schema initialization is thread-safe."""
        errors = []
        initialized_caches = []

        def create_cache():
            try:
                cache = PersistentRecipeCache(db_path=temp_db_path)
                initialized_caches.append(cache)
            except Exception as e:
                errors.append(str(e))

        # Multiple threads creating cache simultaneously
        threads = [threading.Thread(target=create_cache) for _ in range(5)]
        self._start_and_join(threads)

        # No errors should occur
        assert len(errors) == 0, f"Initialization errors: {errors}"
        # All caches should be created
        assert len(initialized_caches) == 5

    def test_concurrent_save_and_remove(self, temp_db_path, sample_recipes):
        """Verify concurrent save and remove operations don't corrupt database."""
        cache = PersistentRecipeCache(db_path=temp_db_path)

        errors = []
        # Each key is only mutated by its own single thread, so plain
        # increments are safe here.
        operation_counts = {"saves": 0, "removes": 0}

        def save_operation():
            try:
                for i in range(5):
                    recipes = [
                        {"id": f"recipe-{j}", "title": f"Recipe {j}"}
                        for j in range(i * 2, i * 2 + 2)
                    ]
                    cache.save_cache(recipes)
                    operation_counts["saves"] += 1
                    time.sleep(0.015)
            except Exception as e:
                errors.append(f"Save error: {e}")

        def remove_operation():
            try:
                for i in range(5):
                    cache.remove_recipe(f"recipe-{i}")
                    operation_counts["removes"] += 1
                    time.sleep(0.02)
            except Exception as e:
                errors.append(f"Remove error: {e}")

        # Concurrent save and remove threads
        threads = [
            threading.Thread(target=save_operation),
            threading.Thread(target=remove_operation),
        ]
        self._start_and_join(threads)

        # No errors should occur
        assert len(errors) == 0, f"Operation errors: {errors}"
        # Operations should complete
        assert operation_counts["saves"] == 5
        assert operation_counts["removes"] == 5
        # Integrity check: the set of indexed ids must agree with the row count.
        assert len(cache.get_indexed_recipe_ids()) == cache.get_recipe_count()