feat: remove prewarm cache and improve recipe scanner initialization

- Remove prewarm_cache startup hook from BaseRecipeRoutes
- Add post-scan task management to RecipeScanner for proper cleanup
- Ensure LoRA scanner initialization completes before recipe enrichment
- Schedule post-scan enrichment after cache initialization
- Improve error handling and task cancellation during shutdown
This commit is contained in:
Will Miao
2025-12-16 21:00:04 +08:00
parent 7e133e4b9d
commit 3382d83aee
4 changed files with 205 additions and 23 deletions

View File

@@ -79,26 +79,8 @@ class BaseRecipeRoutes:
return return
app.on_startup.append(self.attach_dependencies) app.on_startup.append(self.attach_dependencies)
app.on_startup.append(self.prewarm_cache)
self._startup_hooks_registered = True self._startup_hooks_registered = True
async def prewarm_cache(self, app: web.Application | None = None) -> None:
    """Eagerly populate the recipe and LoRA caches during startup.

    Best-effort: any failure is logged and swallowed so application
    startup is never blocked by a cold-cache problem.
    """
    try:
        # Dependencies must be wired before the scanners are usable.
        await self.attach_dependencies(app)
        lora_scanner = self.lora_scanner
        if lora_scanner is not None:
            await lora_scanner.get_cached_data()
            # Touch the hash index so its lazy mapping is materialized now
            # rather than on the first request.
            index = getattr(lora_scanner, "_hash_index", None)
            if index is not None and hasattr(index, "_hash_to_path"):
                _ = len(index._hash_to_path)
        recipe_scanner = self.recipe_scanner
        if recipe_scanner is not None:
            await recipe_scanner.get_cached_data(force_refresh=True)
    except Exception as exc:
        logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True)
def to_route_mapping(self) -> Mapping[str, Callable]: def to_route_mapping(self) -> Mapping[str, Callable]:
"""Return a mapping of handler name to coroutine for registrar binding.""" """Return a mapping of handler name to coroutine for registrar binding."""

View File

@@ -64,6 +64,7 @@ class RecipeScanner:
self._initialization_task: Optional[asyncio.Task] = None self._initialization_task: Optional[asyncio.Task] = None
self._is_initializing = False self._is_initializing = False
self._mutation_lock = asyncio.Lock() self._mutation_lock = asyncio.Lock()
self._post_scan_task: Optional[asyncio.Task] = None
self._resort_tasks: Set[asyncio.Task] = set() self._resort_tasks: Set[asyncio.Task] = set()
if lora_scanner: if lora_scanner:
self._lora_scanner = lora_scanner self._lora_scanner = lora_scanner
@@ -84,6 +85,10 @@ class RecipeScanner:
task.cancel() task.cancel()
self._resort_tasks.clear() self._resort_tasks.clear()
if self._post_scan_task and not self._post_scan_task.done():
self._post_scan_task.cancel()
self._post_scan_task = None
self._cache = None self._cache = None
self._initialization_task = None self._initialization_task = None
self._is_initializing = False self._is_initializing = False
@@ -105,6 +110,8 @@ class RecipeScanner:
async def initialize_in_background(self) -> None: async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool""" """Initialize cache in background using thread pool"""
try: try:
await self._wait_for_lora_scanner()
# Set initial empty cache to avoid None reference errors # Set initial empty cache to avoid None reference errors
if self._cache is None: if self._cache is None:
self._cache = RecipeCache( self._cache = RecipeCache(
@@ -115,6 +122,7 @@ class RecipeScanner:
# Mark as initializing to prevent concurrent initializations # Mark as initializing to prevent concurrent initializations
self._is_initializing = True self._is_initializing = True
self._initialization_task = asyncio.current_task()
try: try:
# Start timer # Start timer
@@ -126,11 +134,14 @@ class RecipeScanner:
None, # Use default thread pool None, # Use default thread pool
self._initialize_recipe_cache_sync # Run synchronous version in thread self._initialize_recipe_cache_sync # Run synchronous version in thread
) )
if cache is not None:
self._cache = cache
# Calculate elapsed time and log it # Calculate elapsed time and log it
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
recipe_count = len(cache.raw_data) if cache and hasattr(cache, 'raw_data') else 0 recipe_count = len(cache.raw_data) if cache and hasattr(cache, 'raw_data') else 0
logger.info(f"Recipe cache initialized in {elapsed_time:.2f} seconds. Found {recipe_count} recipes") logger.info(f"Recipe cache initialized in {elapsed_time:.2f} seconds. Found {recipe_count} recipes")
self._schedule_post_scan_enrichment()
finally: finally:
# Mark initialization as complete regardless of outcome # Mark initialization as complete regardless of outcome
self._is_initializing = False self._is_initializing = False
@@ -237,6 +248,88 @@ class RecipeScanner:
# Clean up the event loop # Clean up the event loop
loop.close() loop.close()
async def _wait_for_lora_scanner(self) -> None:
"""Ensure the LoRA scanner has initialized before recipe enrichment."""
if not getattr(self, "_lora_scanner", None):
return
lora_scanner = self._lora_scanner
cache_ready = getattr(lora_scanner, "_cache", None) is not None
# If cache is already available, we can proceed
if cache_ready:
return
# Await an existing initialization task if present
task = getattr(lora_scanner, "_initialization_task", None)
if task and hasattr(task, "done") and not task.done():
try:
await task
except Exception: # pragma: no cover - defensive guard
pass
if getattr(lora_scanner, "_cache", None) is not None:
return
# Otherwise, request initialization and proceed once it completes
try:
await lora_scanner.initialize_in_background()
except Exception as exc: # pragma: no cover - defensive guard
logger.debug("Recipe Scanner: LoRA init request failed: %s", exc)
def _schedule_post_scan_enrichment(self) -> None:
"""Kick off a non-blocking enrichment pass to fill remote metadata."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
if self._post_scan_task and not self._post_scan_task.done():
return
async def _run_enrichment():
try:
await self._enrich_cache_metadata()
except asyncio.CancelledError:
raise
except Exception as exc: # pragma: no cover - defensive guard
logger.error("Recipe Scanner: error during post-scan enrichment: %s", exc, exc_info=True)
self._post_scan_task = loop.create_task(_run_enrichment(), name="recipe_cache_enrichment")
async def _enrich_cache_metadata(self) -> None:
"""Perform remote metadata enrichment after the initial scan."""
cache = self._cache
if cache is None or not getattr(cache, "raw_data", None):
return
for index, recipe in enumerate(list(cache.raw_data)):
try:
metadata_updated = await self._update_lora_information(recipe)
if metadata_updated:
recipe_id = recipe.get("id")
if recipe_id:
recipe_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
if os.path.exists(recipe_path):
try:
self._write_recipe_file(recipe_path, recipe)
except Exception as exc: # pragma: no cover - best-effort persistence
logger.debug("Recipe Scanner: could not persist recipe %s: %s", recipe_id, exc)
except asyncio.CancelledError:
raise
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Recipe Scanner: error enriching recipe %s: %s", recipe.get("id"), exc, exc_info=True)
if index % 10 == 0:
await asyncio.sleep(0)
try:
await cache.resort()
except Exception as exc: # pragma: no cover - defensive logging
logger.debug("Recipe Scanner: error resorting cache after enrichment: %s", exc)
def _schedule_resort(self, *, name_only: bool = False) -> None: def _schedule_resort(self, *, name_only: bool = False) -> None:
"""Schedule a background resort of the recipe cache.""" """Schedule a background resort of the recipe cache."""
@@ -438,7 +531,7 @@ class RecipeScanner:
recipe_data['gen_params'] = {} recipe_data['gen_params'] = {}
# Update lora information with local paths and availability # Update lora information with local paths and availability
await self._update_lora_information(recipe_data) lora_metadata_updated = await self._update_lora_information(recipe_data)
if recipe_data.get('checkpoint'): if recipe_data.get('checkpoint'):
checkpoint_entry = self._normalize_checkpoint_entry(recipe_data['checkpoint']) checkpoint_entry = self._normalize_checkpoint_entry(recipe_data['checkpoint'])
@@ -459,6 +552,12 @@ class RecipeScanner:
logger.info(f"Added fingerprint to recipe: {recipe_path}") logger.info(f"Added fingerprint to recipe: {recipe_path}")
except Exception as e: except Exception as e:
logger.error(f"Error writing updated recipe with fingerprint: {e}") logger.error(f"Error writing updated recipe with fingerprint: {e}")
elif lora_metadata_updated:
# Persist updates such as marking invalid entries as deleted
try:
self._write_recipe_file(recipe_path, recipe_data)
except Exception as e:
logger.error(f"Error writing updated recipe metadata: {e}")
return recipe_data return recipe_data
except Exception as e: except Exception as e:
@@ -519,7 +618,13 @@ class RecipeScanner:
logger.warning(f"Marked lora with modelVersionId {model_version_id} as deleted") logger.warning(f"Marked lora with modelVersionId {model_version_id} as deleted")
metadata_updated = True metadata_updated = True
else: else:
logger.debug(f"Could not get hash for modelVersionId {model_version_id}") # No hash returned; mark as deleted to avoid repeated lookups
lora['isDeleted'] = True
metadata_updated = True
logger.warning(
"Marked lora with modelVersionId %s as deleted after failed hash lookup",
model_version_id,
)
# If has hash but no file_name, look up in lora library # If has hash but no file_name, look up in lora library
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']): if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):

View File

@@ -103,8 +103,7 @@ def test_register_startup_hooks_appends_once():
] ]
assert routes.attach_dependencies in startup_bound_to_routes assert routes.attach_dependencies in startup_bound_to_routes
assert routes.prewarm_cache in startup_bound_to_routes assert len(startup_bound_to_routes) == 1
assert len(startup_bound_to_routes) == 2
def test_to_route_mapping_uses_handler_set(): def test_to_route_mapping_uses_handler_set():
@@ -212,4 +211,4 @@ def test_recipe_routes_setup_routes_uses_registrar(monkeypatch: pytest.MonkeyPat
if isinstance(getattr(cb, "__self__", None), recipe_routes.RecipeRoutes) if isinstance(getattr(cb, "__self__", None), recipe_routes.RecipeRoutes)
} }
assert {type(cb.__self__) for cb in recipe_callbacks} == {recipe_routes.RecipeRoutes} assert {type(cb.__self__) for cb in recipe_callbacks} == {recipe_routes.RecipeRoutes}
assert {cb.__name__ for cb in recipe_callbacks} == {"attach_dependencies", "prewarm_cache"} assert {cb.__name__ for cb in recipe_callbacks} == {"attach_dependencies"}

View File

@@ -349,3 +349,99 @@ def test_enrich_formats_absolute_preview_paths(recipe_scanner, tmp_path):
enriched = scanner._enrich_lora_entry(dict(lora)) enriched = scanner._enrich_lora_entry(dict(lora))
assert enriched["preview_url"] == config.get_preview_static_url(str(preview_path)) assert enriched["preview_url"] == config.get_preview_static_url(str(preview_path))
@pytest.mark.asyncio
async def test_initialize_waits_for_lora_scanner(monkeypatch):
    """Recipe initialization must first drive the LoRA scanner to readiness."""
    ready_flag = asyncio.Event()
    call_count = 0

    class StubLoraScanner:
        def __init__(self):
            self._cache = None
            self._is_initializing = True

        async def initialize_in_background(self):
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0)
            self._cache = SimpleNamespace(raw_data=[])
            self._is_initializing = False
            ready_flag.set()

    stub = StubLoraScanner()
    scanner = RecipeScanner(lora_scanner=stub)
    await scanner.initialize_in_background()

    # The stub was initialized exactly once, and the recipe cache was built.
    assert ready_flag.is_set()
    assert call_count == 1
    assert scanner._cache is not None
@pytest.mark.asyncio
async def test_invalid_model_version_marked_deleted_and_not_retried(monkeypatch, recipe_scanner):
    """A failed hash lookup marks the LoRA deleted and suppresses retries."""
    scanner, _ = recipe_scanner
    recipes_dir = Path(config.loras_roots[0]) / "recipes"
    recipes_dir.mkdir(parents=True, exist_ok=True)

    recipe = {
        "id": "invalid-version",
        "file_path": str(recipes_dir / "invalid-version.webp"),
        "title": "Invalid",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [{"modelVersionId": 999, "file_name": "", "hash": ""}],
    }
    await scanner.add_recipe(dict(recipe))

    lookups = []

    async def fake_get_hash(model_version_id):
        lookups.append(model_version_id)
        return None

    monkeypatch.setattr(scanner, "_get_hash_from_civitai", fake_get_hash)

    # First pass performs the remote lookup and flags the entry as deleted.
    assert await scanner._update_lora_information(recipe) is True
    assert recipe["loras"][0]["isDeleted"] is True
    assert len(lookups) == 1

    # Once flagged as deleted, no further remote lookup is attempted.
    assert await scanner._update_lora_information(recipe) is False
    assert len(lookups) == 1
@pytest.mark.asyncio
async def test_load_recipe_persists_deleted_flag_on_invalid_version(monkeypatch, recipe_scanner, tmp_path):
    """Loading a recipe with an invalid version persists isDeleted to disk."""
    scanner, _ = recipe_scanner
    recipes_dir = Path(config.loras_roots[0]) / "recipes"
    recipes_dir.mkdir(parents=True, exist_ok=True)

    recipe_id = "persist-invalid"
    recipe_path = recipes_dir / f"{recipe_id}.recipe.json"
    payload = {
        "id": recipe_id,
        "file_path": str(recipes_dir / f"{recipe_id}.webp"),
        "title": "Invalid",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [{"modelVersionId": 1234, "file_name": "", "hash": ""}],
    }
    recipe_path.write_text(json.dumps(payload))

    async def fake_get_hash(model_version_id):
        # Simulate Civitai returning no hash for this version.
        return None

    monkeypatch.setattr(scanner, "_get_hash_from_civitai", fake_get_hash)

    loaded = await scanner._load_recipe_file(str(recipe_path))
    assert loaded["loras"][0]["isDeleted"] is True

    # The deleted flag must also have been written back to the file.
    on_disk = json.loads(recipe_path.read_text())
    assert on_disk["loras"][0]["isDeleted"] is True