mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
Introduce a new PersistentRecipeCache service that stores recipe metadata in an SQLite database to significantly reduce application startup time. By persisting recipe data between sessions, the cache eliminates the need to walk directories and parse JSON files on each launch.

Key features:
- Thread-safe singleton implementation with library-specific instances
- Automatic schema initialization and migration support
- JSON serialization for complex recipe fields (LoRAs, checkpoints, generation parameters, tags)
- File system monitoring with mtime/size validation for cache invalidation
- Environment variable toggle (LORA_MANAGER_DISABLE_PERSISTENT_CACHE) for debugging
- Comprehensive test suite covering save/load cycles, cache invalidation, and edge cases

The cache improves the user experience by enabling near-instantaneous recipe loading after the initial cache population, while file change detection keeps the cached data consistent.
258 lines
8.9 KiB
Python
"""Tests for PersistentRecipeCache."""
|
|
|
|
import json
|
|
import os
|
|
import tempfile
|
|
from typing import Dict, List
|
|
|
|
import pytest
|
|
|
|
from py.services.persistent_recipe_cache import PersistentRecipeCache, PersistedRecipeData
|
|
|
|
|
|
@pytest.fixture
def temp_db_path(tmp_path):
    """Return a path for a fresh, not-yet-existing SQLite database file.

    The path lives inside pytest's per-test ``tmp_path`` directory, so the
    database file and any ``-wal``/``-shm`` side files SQLite writes next to
    it are managed by pytest's temp-directory handling.

    Handing out a path that does not exist yet has two benefits:

    * Tests can genuinely observe that ``PersistentRecipeCache`` creates the
      file. A pre-created empty file would make that check vacuous.
    * The fixture never has to unlink a database that a live connection may
      still hold open, which fails on Windows.
    """
    return str(tmp_path / "recipe_cache.sqlite")
|
|
|
|
|
|
@pytest.fixture
def sample_recipes() -> List[Dict]:
    """Provide two representative recipe dicts for cache tests.

    The first recipe sets every field the tests inspect, including a
    checkpoint, two LoRAs and a non-empty tag list. The second recipe has
    no ``checkpoint`` key and an empty tag list, so optional/empty fields
    are exercised too.
    """
    fully_populated = dict(
        id="recipe-001",
        file_path="/path/to/image1.png",
        title="Test Recipe 1",
        folder="folder1",
        base_model="SD1.5",
        fingerprint="abc123",
        created_date=1700000000.0,
        modified=1700000100.0,
        favorite=True,
        repair_version=3,
        preview_nsfw_level=1,
        loras=[
            dict(hash="hash1", file_name="lora1", strength=0.8),
            dict(hash="hash2", file_name="lora2", strength=1.0),
        ],
        checkpoint=dict(name="model.safetensors", hash="cphash"),
        gen_params=dict(prompt="test prompt", negative_prompt="bad"),
        tags=["tag1", "tag2"],
    )

    # Deliberately omits "checkpoint" and has no tags.
    sparse = dict(
        id="recipe-002",
        file_path="/path/to/image2.png",
        title="Test Recipe 2",
        folder="",
        base_model="SDXL",
        fingerprint="def456",
        created_date=1700000200.0,
        modified=1700000300.0,
        favorite=False,
        repair_version=2,
        preview_nsfw_level=0,
        loras=[dict(hash="hash3", file_name="lora3", strength=0.5)],
        gen_params=dict(prompt="another prompt"),
        tags=[],
    )

    return [fully_populated, sparse]
|
|
|
|
|
|
class TestPersistentRecipeCache:
    """Tests for the PersistentRecipeCache class."""

    def test_init_creates_db(self, tmp_path):
        """Constructing a cache creates its database file on disk."""
        db_path = tmp_path / "fresh_recipe_cache.sqlite"
        # Start from a path that does not exist yet, so the existence check
        # below proves the cache itself created the file.
        assert not db_path.exists()

        cache = PersistentRecipeCache(db_path=str(db_path))
        assert cache.is_enabled()
        assert db_path.exists()

    def test_save_and_load_roundtrip(self, temp_db_path, sample_recipes):
        """A save followed by a load preserves scalar and JSON-encoded fields."""
        cache = PersistentRecipeCache(db_path=temp_db_path)

        json_paths = {
            "recipe-001": "/path/to/recipe-001.recipe.json",
            "recipe-002": "/path/to/recipe-002.recipe.json",
        }
        cache.save_cache(sample_recipes, json_paths)

        loaded = cache.load_cache()
        assert loaded is not None
        assert len(loaded.raw_data) == 2
        # Index by id so a missing recipe fails with a KeyError naming it,
        # instead of a bare StopIteration from next().
        by_id = {recipe["id"]: recipe for recipe in loaded.raw_data}

        r1 = by_id["recipe-001"]
        assert r1["title"] == "Test Recipe 1"
        assert r1["folder"] == "folder1"
        assert r1["base_model"] == "SD1.5"
        assert r1["fingerprint"] == "abc123"
        assert r1["favorite"] is True
        assert r1["repair_version"] == 3
        assert len(r1["loras"]) == 2
        assert r1["loras"][0]["hash"] == "hash1"
        assert r1["checkpoint"]["name"] == "model.safetensors"
        assert r1["gen_params"]["prompt"] == "test prompt"
        assert r1["tags"] == ["tag1", "tag2"]

        r2 = by_id["recipe-002"]
        assert r2["title"] == "Test Recipe 2"
        assert r2["folder"] == ""
        assert r2["favorite"] is False

    def test_empty_cache_returns_none(self, temp_db_path):
        """Loading a cache that was never populated returns None."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        loaded = cache.load_cache()
        assert loaded is None

    def test_update_single_recipe(self, temp_db_path, sample_recipes):
        """update_recipe overwrites an existing entry in place."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        cache.save_cache(sample_recipes)

        # Shallow copy is enough: only top-level keys are changed.
        updated_recipe = dict(sample_recipes[0])
        updated_recipe["title"] = "Updated Title"
        updated_recipe["favorite"] = False
        cache.update_recipe(updated_recipe, "/path/to/recipe-001.recipe.json")

        loaded = cache.load_cache()
        assert loaded is not None
        # Updating an existing id must not add a duplicate row.
        assert len(loaded.raw_data) == 2
        by_id = {recipe["id"]: recipe for recipe in loaded.raw_data}

        r1 = by_id["recipe-001"]
        assert r1["title"] == "Updated Title"
        assert r1["favorite"] is False

    def test_remove_recipe(self, temp_db_path, sample_recipes):
        """remove_recipe deletes only the requested entry."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        cache.save_cache(sample_recipes)

        cache.remove_recipe("recipe-001")

        loaded = cache.load_cache()
        assert loaded is not None
        assert len(loaded.raw_data) == 1
        assert loaded.raw_data[0]["id"] == "recipe-002"

    def test_get_indexed_recipe_ids(self, temp_db_path, sample_recipes):
        """get_indexed_recipe_ids returns the ids of every saved recipe."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        cache.save_cache(sample_recipes)

        ids = cache.get_indexed_recipe_ids()
        assert ids == {"recipe-001", "recipe-002"}

    def test_get_recipe_count(self, temp_db_path, sample_recipes):
        """get_recipe_count tracks saves and removals."""
        cache = PersistentRecipeCache(db_path=temp_db_path)
        assert cache.get_recipe_count() == 0

        cache.save_cache(sample_recipes)
        assert cache.get_recipe_count() == 2

        cache.remove_recipe("recipe-001")
        assert cache.get_recipe_count() == 1

    def test_file_stats(self, temp_db_path, sample_recipes):
        """save_cache records one file-stats entry per recipe JSON path."""
        cache = PersistentRecipeCache(db_path=temp_db_path)

        json_paths = {
            "recipe-001": "/path/to/recipe-001.recipe.json",
            "recipe-002": "/path/to/recipe-002.recipe.json",
        }
        cache.save_cache(sample_recipes, json_paths)

        stats = cache.get_file_stats()
        # The JSON paths do not exist on disk, so each entry is expected to
        # hold placeholder stats (the original note says (0.0, 0)). Only the
        # number of tracked entries is asserted here.
        assert len(stats) == 2

    def test_disabled_cache(self, temp_db_path, sample_recipes, monkeypatch):
        """With the disable env var set, saving is a no-op and loading returns None."""
        monkeypatch.setenv("LORA_MANAGER_DISABLE_PERSISTENT_CACHE", "1")

        cache = PersistentRecipeCache(db_path=temp_db_path)
        assert not cache.is_enabled()
        cache.save_cache(sample_recipes)
        assert cache.load_cache() is None

    def test_invalid_recipe_skipped(self, temp_db_path):
        """Recipes without an "id" key are skipped on save."""
        cache = PersistentRecipeCache(db_path=temp_db_path)

        recipes = [
            {"title": "No ID recipe"},  # Missing ID
            {"id": "valid-001", "title": "Valid recipe"},
        ]
        cache.save_cache(recipes)

        loaded = cache.load_cache()
        assert loaded is not None
        assert len(loaded.raw_data) == 1
        assert loaded.raw_data[0]["id"] == "valid-001"

    def test_get_default_singleton(self, monkeypatch, tmp_path):
        """get_default returns one shared instance per library name."""
        monkeypatch.setenv(
            "LORA_MANAGER_RECIPE_CACHE_DB", str(tmp_path / "test.sqlite")
        )

        PersistentRecipeCache.clear_instances()
        try:
            cache1 = PersistentRecipeCache.get_default("test_lib")
            cache2 = PersistentRecipeCache.get_default("test_lib")
            assert cache1 is cache2

            cache3 = PersistentRecipeCache.get_default("other_lib")
            assert cache1 is not cache3
        finally:
            # Always reset the registry, even when an assertion above fails,
            # so singletons bound to this temp path cannot leak into later
            # tests.
            PersistentRecipeCache.clear_instances()

    def test_loras_json_handling(self, temp_db_path):
        """Arbitrary extra keys inside LoRA entries survive a round trip."""
        cache = PersistentRecipeCache(db_path=temp_db_path)

        recipes = [
            {
                "id": "complex-001",
                "title": "Complex Loras",
                "loras": [
                    {
                        "hash": "abc123",
                        "file_name": "test_lora",
                        "strength": 0.75,
                        "modelVersionId": 12345,
                        "modelName": "Test Model",
                        "isDeleted": False,
                    },
                    {
                        "hash": "def456",
                        "file_name": "another_lora",
                        "strength": 1.0,
                        "clip_strength": 0.8,
                    },
                ],
            }
        ]
        cache.save_cache(recipes)

        loaded = cache.load_cache()
        assert loaded is not None
        loras = loaded.raw_data[0]["loras"]
        assert len(loras) == 2
        assert loras[0]["modelVersionId"] == 12345
        assert loras[1]["clip_strength"] == 0.8
|