diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index 35f4c088..aa912477 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -290,27 +290,7 @@ class RecipeQueryHandler: if not lora_hash: return web.json_response({"success": False, "error": "Lora hash is required"}, status=400) - cache = await recipe_scanner.get_cached_data() - matching_recipes = [] - for recipe in getattr(cache, "raw_data", []): - for lora in recipe.get("loras", []): - if lora.get("hash", "").lower() == lora_hash.lower(): - matching_recipes.append(recipe) - break - - lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) - for recipe in matching_recipes: - for lora in recipe.get("loras", []): - hash_value = (lora.get("hash") or "").lower() - if hash_value and lora_scanner is not None: - lora["inLibrary"] = lora_scanner.has_hash(hash_value) - lora["preview_url"] = lora_scanner.get_preview_url_by_hash(hash_value) - lora["localPath"] = lora_scanner.get_path_by_hash(hash_value) - if recipe.get("file_path"): - recipe["file_url"] = self._format_recipe_file_url(recipe["file_path"]) - else: - recipe["file_url"] = "/loras_static/images/no-preview.png" - + matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash) return web.json_response({"success": True, "recipes": matching_recipes}) except Exception as exc: self._logger.error("Error getting recipes for Lora: %s", exc) @@ -384,50 +364,15 @@ class RecipeQueryHandler: raise RuntimeError("Recipe scanner unavailable") recipe_id = request.match_info["recipe_id"] - cache = await recipe_scanner.get_cached_data() - recipe = next( - (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), - None, - ) - if not recipe: + try: + syntax_parts = await recipe_scanner.get_recipe_syntax_tokens(recipe_id) + except RecipeNotFoundError: return web.json_response({"error": "Recipe not found"}, status=404) - loras = recipe.get("loras", []) - if not loras: + if not syntax_parts: return web.json_response({"error": "No LoRAs found in this recipe"}, status=400) - lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) - hash_index = getattr(lora_scanner, "_hash_index", None) - - lora_syntax_parts = [] - for lora in loras: - if lora.get("isDeleted", False): - continue - hash_value = (lora.get("hash") or "").lower() - if not hash_value or lora_scanner is None or not lora_scanner.has_hash(hash_value): - continue - - file_name = None - if hash_value and hash_index is not None and hasattr(hash_index, "_hash_to_path"): - file_path = hash_index._hash_to_path.get(hash_value) - if file_path: - file_name = os.path.splitext(os.path.basename(file_path))[0] - - if not file_name and lora.get("modelVersionId") and lora_scanner is not None: - all_loras = await lora_scanner.get_cached_data() - for cached_lora in getattr(all_loras, "raw_data", []): - civitai_info = cached_lora.get("civitai") - if civitai_info and civitai_info.get("id") == lora.get("modelVersionId"): - file_name = os.path.splitext(os.path.basename(cached_lora["path"]))[0] - break - - if not file_name: - file_name = lora.get("file_name", "unknown-lora") - - strength = lora.get("strength", 1.0) - lora_syntax_parts.append(f"") - - return web.json_response({"success": True, "syntax": " ".join(lora_syntax_parts)}) + return web.json_response({"success": True, "syntax": " ".join(syntax_parts)}) except Exception as exc: self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) diff --git a/py/services/recipe_cache.py b/py/services/recipe_cache.py index b1f52246..ac28b3aa 100644 --- a/py/services/recipe_cache.py +++ b/py/services/recipe_cache.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Dict +from typing import Iterable, List, Dict, Optional from dataclasses import dataclass from operator import itemgetter from natsort import natsorted @@ -10,77 +10,115 @@ class RecipeCache: raw_data: List[Dict] sorted_by_name: List[Dict] sorted_by_date: List[Dict] - + def __post_init__(self): self._lock = asyncio.Lock() async def resort(self, name_only: bool = False): """Resort all cached data views""" async with self._lock: - self.sorted_by_name = natsorted( - self.raw_data, - key=lambda x: x.get('title', '').lower() # Case-insensitive sort - ) - if not name_only: - self.sorted_by_date = sorted( - self.raw_data, - key=itemgetter('created_date', 'file_path'), - reverse=True - ) - - async def update_recipe_metadata(self, recipe_id: str, metadata: Dict) -> bool: + self._resort_locked(name_only=name_only) + + async def update_recipe_metadata(self, recipe_id: str, metadata: Dict, *, resort: bool = True) -> bool: """Update metadata for a specific recipe in all cached data - + Args: recipe_id: The ID of the recipe to update metadata: The new metadata - + Returns: bool: True if the update was successful, False if the recipe wasn't found """ + async with self._lock: + for item in self.raw_data: + if str(item.get('id')) == str(recipe_id): + item.update(metadata) + if resort: + self._resort_locked() + return True + return False # Recipe not found + + async def add_recipe(self, recipe_data: Dict, *, resort: bool = False) -> None: + """Add a new recipe to the cache.""" - # Update in raw_data - for item in self.raw_data: - if item.get('id') == recipe_id: - item.update(metadata) - break - else: - return False # Recipe not found - - # Resort to reflect changes - await self.resort() - return True - - async def add_recipe(self, recipe_data: Dict) -> None: - """Add a new recipe to the cache - - Args: - recipe_data: The recipe data to add - """ async with self._lock: self.raw_data.append(recipe_data) - await self.resort() + if resort: + self._resort_locked() + + async def remove_recipe(self, recipe_id: str, *, resort: bool = False) -> Optional[Dict]: + """Remove a recipe from the cache by ID. - async def remove_recipe(self, recipe_id: str) -> bool: - """Remove a recipe from the cache by ID - Args: recipe_id: The ID of the recipe to remove - + Returns: - bool: True if the recipe was found and removed, False otherwise + The removed recipe data if found, otherwise ``None``. """ - # Find the recipe in raw_data - recipe_index = next((i for i, recipe in enumerate(self.raw_data) - if recipe.get('id') == recipe_id), None) - - if recipe_index is None: - return False - - # Remove from raw_data - self.raw_data.pop(recipe_index) - - # Resort to update sorted lists - await self.resort() - - return True \ No newline at end of file + + async with self._lock: + for index, recipe in enumerate(self.raw_data): + if str(recipe.get('id')) == str(recipe_id): + removed = self.raw_data.pop(index) + if resort: + self._resort_locked() + return removed + return None + + async def bulk_remove(self, recipe_ids: Iterable[str], *, resort: bool = False) -> List[Dict]: + """Remove multiple recipes from the cache.""" + + id_set = {str(recipe_id) for recipe_id in recipe_ids} + if not id_set: + return [] + + async with self._lock: + removed = [item for item in self.raw_data if str(item.get('id')) in id_set] + if not removed: + return [] + + self.raw_data = [item for item in self.raw_data if str(item.get('id')) not in id_set] + if resort: + self._resort_locked() + return removed + + async def replace_recipe(self, recipe_id: str, new_data: Dict, *, resort: bool = False) -> bool: + """Replace cached data for a recipe.""" + + async with self._lock: + for index, recipe in enumerate(self.raw_data): + if str(recipe.get('id')) == str(recipe_id): + self.raw_data[index] = new_data + if resort: + self._resort_locked() + return True + return False + + async def get_recipe(self, recipe_id: str) -> Optional[Dict]: + """Return a shallow copy of a cached recipe.""" + + async with self._lock: + for recipe in self.raw_data: + if str(recipe.get('id')) == str(recipe_id): + return dict(recipe) + return None + + async def snapshot(self) -> List[Dict]: + """Return a copy of all cached recipes.""" + + async with self._lock: + return [dict(item) for item in self.raw_data] + + def _resort_locked(self, *, name_only: bool = False) -> None: + """Sort cached views. Caller must hold ``_lock``.""" + + self.sorted_by_name = natsorted( + self.raw_data, + key=lambda x: x.get('title', '').lower() + ) + if not name_only: + self.sorted_by_date = sorted( + self.raw_data, + key=itemgetter('created_date', 'file_path'), + reverse=True + ) \ No newline at end of file diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index ca5a20ac..9a82b237 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -3,13 +3,14 @@ import logging import asyncio import json import time -from typing import List, Dict, Optional, Any, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from ..config import config from .recipe_cache import RecipeCache from .service_registry import ServiceRegistry from .lora_scanner import LoraScanner from .metadata_service import get_default_metadata_provider -from ..utils.utils import fuzzy_match +from .recipes.errors import RecipeNotFoundError +from ..utils.utils import calculate_recipe_fingerprint, fuzzy_match from natsort import natsorted import sys @@ -46,6 +47,8 @@ class RecipeScanner: self._initialization_lock = asyncio.Lock() self._initialization_task: Optional[asyncio.Task] = None self._is_initializing = False + self._mutation_lock = asyncio.Lock() + self._resort_tasks: Set[asyncio.Task] = set() if lora_scanner: self._lora_scanner = lora_scanner self._initialized = True @@ -191,6 +194,22 @@ class RecipeScanner: # Clean up the event loop loop.close() + def _schedule_resort(self, *, name_only: bool = False) -> None: + """Schedule a background resort of the recipe cache.""" + + if not self._cache: + return + + async def _resort_wrapper() -> None: + try: + await self._cache.resort(name_only=name_only) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Recipe Scanner: error resorting cache: %s", exc, exc_info=True) + + task = asyncio.create_task(_resort_wrapper()) + self._resort_tasks.add(task) + task.add_done_callback(lambda finished: self._resort_tasks.discard(finished)) + @property def recipes_dir(self) -> str: """Get path to recipes directory""" @@ -255,7 +274,45 @@ class RecipeScanner: # Return the cache (may be empty or partially initialized) return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) - + + async def refresh_cache(self, force: bool = False) -> RecipeCache: + """Public helper to refresh or return the recipe cache.""" + + return await self.get_cached_data(force_refresh=force) + + async def add_recipe(self, recipe_data: Dict[str, Any]) -> None: + """Add a recipe to the in-memory cache.""" + + if not recipe_data: + return + + cache = await self.get_cached_data() + await cache.add_recipe(recipe_data, resort=False) + self._schedule_resort() + + async def remove_recipe(self, recipe_id: str) -> bool: + """Remove a recipe from the cache by ID.""" + + if not recipe_id: + return False + + cache = await self.get_cached_data() + removed = await cache.remove_recipe(recipe_id, resort=False) + if removed is None: + return False + + self._schedule_resort() + return True + + async def bulk_remove(self, recipe_ids: Iterable[str]) -> int: + """Remove multiple recipes from the cache.""" + + cache = await self.get_cached_data() + removed = await cache.bulk_remove(recipe_ids, resort=False) + if removed: + self._schedule_resort() + return len(removed) + async def scan_all_recipes(self) -> List[Dict]: """Scan all recipe JSON files and return metadata""" recipes = [] @@ -326,7 +383,6 @@ class RecipeScanner: # Calculate and update fingerprint if missing if 'loras' in recipe_data and 'fingerprint' not in recipe_data: - from ..utils.utils import calculate_recipe_fingerprint fingerprint = calculate_recipe_fingerprint(recipe_data['loras']) recipe_data['fingerprint'] = fingerprint @@ -497,9 +553,36 @@ class RecipeScanner: logger.error(f"Error getting base model for lora: {e}") return None + def _enrich_lora_entry(self, lora: Dict[str, Any]) -> Dict[str, Any]: + """Populate convenience fields for a LoRA entry.""" + + if not lora or not self._lora_scanner: + return lora + + hash_value = (lora.get('hash') or '').lower() + if not hash_value: + return lora + + try: + lora['inLibrary'] = self._lora_scanner.has_hash(hash_value) + lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(hash_value) + lora['localPath'] = self._lora_scanner.get_path_by_hash(hash_value) + except Exception as exc: # pragma: no cover - defensive logging + logger.debug("Error enriching lora entry %s: %s", hash_value, exc) + + return lora + + async def get_local_lora(self, name: str) -> Optional[Dict[str, Any]]: + """Lookup a local LoRA model by name.""" + + if not self._lora_scanner or not name: + return None + + return await self._lora_scanner.get_model_info_by_name(name) + async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True): """Get paginated and filtered recipe data - + Args: page: Current page number (1-based) page_size: Number of items per page @@ -598,16 +681,12 @@ class RecipeScanner: # Get paginated items paginated_items = filtered_data[start_idx:end_idx] - + # Add inLibrary information for each lora for item in paginated_items: if 'loras' in item: - for lora in item['loras']: - if 'hash' in lora and lora['hash']: - lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower()) - lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower()) - lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower()) - + item['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in item['loras']] + result = { 'items': paginated_items, 'total': total_items, @@ -653,13 +732,8 @@ class RecipeScanner: # Add lora metadata if 'loras' in formatted_recipe: - for lora in formatted_recipe['loras']: - if 'hash' in lora and lora['hash']: - lora_hash = lora['hash'].lower() - lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash) - lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash) - lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash) - + formatted_recipe['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in formatted_recipe['loras']] + return formatted_recipe def _format_file_url(self, file_path: str) -> str: @@ -717,26 +791,159 @@ class RecipeScanner: # Save updated recipe with open(recipe_json_path, 'w', encoding='utf-8') as f: json.dump(recipe_data, f, indent=4, ensure_ascii=False) - + # Update the cache if it exists if self._cache is not None: - await self._cache.update_recipe_metadata(recipe_id, metadata) - + await self._cache.update_recipe_metadata(recipe_id, metadata, resort=False) + self._schedule_resort() + # If the recipe has an image, update its EXIF metadata from ..utils.exif_utils import ExifUtils image_path = recipe_data.get('file_path') if image_path and os.path.exists(image_path): ExifUtils.append_recipe_metadata(image_path, recipe_data) - + return True except Exception as e: import logging logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True) return False + async def update_lora_entry( + self, + recipe_id: str, + lora_index: int, + *, + target_name: str, + target_lora: Optional[Dict[str, Any]] = None, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Update a specific LoRA entry within a recipe. + + Returns the updated recipe data and the refreshed LoRA metadata. + """ + + if target_name is None: + raise ValueError("target_name must be provided") + + recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_json_path): + raise RecipeNotFoundError("Recipe not found") + + async with self._mutation_lock: + with open(recipe_json_path, 'r', encoding='utf-8') as file_obj: + recipe_data = json.load(file_obj) + + loras = recipe_data.get('loras', []) + if lora_index >= len(loras): + raise RecipeNotFoundError("LoRA index out of range in recipe") + + lora_entry = loras[lora_index] + lora_entry['isDeleted'] = False + lora_entry['exclude'] = False + lora_entry['file_name'] = target_name + + if target_lora is not None: + sha_value = target_lora.get('sha256') or target_lora.get('sha') + if sha_value: + lora_entry['hash'] = sha_value.lower() + + civitai_info = target_lora.get('civitai') or {} + if civitai_info: + lora_entry['modelName'] = civitai_info.get('model', {}).get('name', '') + lora_entry['modelVersionName'] = civitai_info.get('name', '') + lora_entry['modelVersionId'] = civitai_info.get('id') + + recipe_data['fingerprint'] = calculate_recipe_fingerprint(recipe_data.get('loras', [])) + recipe_data['modified'] = time.time() + + with open(recipe_json_path, 'w', encoding='utf-8') as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + cache = await self.get_cached_data() + replaced = await cache.replace_recipe(recipe_id, recipe_data, resort=False) + if not replaced: + await cache.add_recipe(recipe_data, resort=False) + self._schedule_resort() + + updated_lora = dict(lora_entry) + if target_lora is not None: + preview_url = target_lora.get('preview_url') + if preview_url: + updated_lora['preview_url'] = config.get_preview_static_url(preview_url) + if target_lora.get('file_path'): + updated_lora['localPath'] = target_lora['file_path'] + + updated_lora = self._enrich_lora_entry(updated_lora) + return recipe_data, updated_lora + + async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]: + """Return recipes that reference a given LoRA hash.""" + + if not lora_hash: + return [] + + normalized_hash = lora_hash.lower() + cache = await self.get_cached_data() + matching_recipes: List[Dict[str, Any]] = [] + + for recipe in cache.raw_data: + loras = recipe.get('loras', []) + if any((entry.get('hash') or '').lower() == normalized_hash for entry in loras): + recipe_copy = {**recipe} + recipe_copy['loras'] = [self._enrich_lora_entry(dict(entry)) for entry in loras] + recipe_copy['file_url'] = self._format_file_url(recipe.get('file_path')) + matching_recipes.append(recipe_copy) + + return matching_recipes + + async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]: + """Build LoRA syntax tokens for a recipe.""" + + cache = await self.get_cached_data() + recipe = await cache.get_recipe(recipe_id) + if recipe is None: + raise RecipeNotFoundError("Recipe not found") + + loras = recipe.get('loras', []) + if not loras: + return [] + + lora_cache = None + if self._lora_scanner is not None: + lora_cache = await self._lora_scanner.get_cached_data() + + syntax_parts: List[str] = [] + for lora in loras: + if lora.get('isDeleted', False): + continue + + file_name = None + hash_value = (lora.get('hash') or '').lower() + if hash_value and self._lora_scanner is not None and hasattr(self._lora_scanner, '_hash_index'): + file_path = self._lora_scanner._hash_index.get_path(hash_value) + if file_path: + file_name = os.path.splitext(os.path.basename(file_path))[0] + + if not file_name and lora.get('modelVersionId') and lora_cache is not None: + for cached_lora in getattr(lora_cache, 'raw_data', []): + civitai_info = cached_lora.get('civitai') + if civitai_info and civitai_info.get('id') == lora.get('modelVersionId'): + cached_path = cached_lora.get('path') or cached_lora.get('file_path') + if cached_path: + file_name = os.path.splitext(os.path.basename(cached_path))[0] + break + + if not file_name: + file_name = lora.get('file_name', 'unknown-lora') + + strength = lora.get('strength', 1.0) + syntax_parts.append(f"") + + return syntax_parts + async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]: """Update file_name in all recipes that contain a LoRA with the specified hash. - + Args: hash_value: The SHA256 hash value of the LoRA new_file_name: The new file_name to set diff --git a/py/services/recipes/persistence_service.py b/py/services/recipes/persistence_service.py index 945680df..078ac906 100644 --- a/py/services/recipes/persistence_service.py +++ b/py/services/recipes/persistence_service.py @@ -1,7 +1,6 @@ """Services encapsulating recipe persistence workflows.""" from __future__ import annotations -import asyncio import base64 import json import os @@ -123,7 +122,7 @@ class RecipePersistenceService: self._exif_utils.append_recipe_metadata(image_path, recipe_data) matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id) - await self._update_cache(recipe_scanner, recipe_data) + await recipe_scanner.add_recipe(recipe_data) return PersistenceResult( { @@ -154,7 +153,7 @@ class RecipePersistenceService: if image_path and os.path.exists(image_path): os.remove(image_path) - await self._remove_from_cache(recipe_scanner, recipe_id) + await recipe_scanner.remove_recipe(recipe_id) return PersistenceResult({"success": True, "message": "Recipe deleted successfully"}) async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult: @@ -185,40 +184,16 @@ class RecipePersistenceService: if not os.path.exists(recipe_path): raise RecipeNotFoundError("Recipe not found") - lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) - target_lora = None if lora_scanner is None else await lora_scanner.get_model_info_by_name(target_name) + target_lora = await recipe_scanner.get_local_lora(target_name) if not target_lora: raise RecipeNotFoundError(f"Local LoRA not found with name: {target_name}") - with open(recipe_path, "r", encoding="utf-8") as file_obj: - recipe_data = json.load(file_obj) - - loras = recipe_data.get("loras", []) - if lora_index >= len(loras): - raise RecipeNotFoundError("LoRA index out of range in recipe") - - lora = loras[lora_index] - lora["isDeleted"] = False - lora["exclude"] = False - lora["file_name"] = target_name - if "sha256" in target_lora: - lora["hash"] = target_lora["sha256"].lower() - if target_lora.get("civitai"): - lora["modelName"] = target_lora["civitai"]["model"]["name"] - lora["modelVersionName"] = target_lora["civitai"]["name"] - lora["modelVersionId"] = target_lora["civitai"]["id"] - - recipe_data["fingerprint"] = calculate_recipe_fingerprint(recipe_data.get("loras", [])) - - with open(recipe_path, "w", encoding="utf-8") as file_obj: - json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) - - updated_lora = dict(lora) - updated_lora["inLibrary"] = True - updated_lora["preview_url"] = config.get_preview_static_url(target_lora["preview_url"]) - updated_lora["localPath"] = target_lora["file_path"] - - await self._refresh_cache_after_update(recipe_scanner, recipe_id, recipe_data) + recipe_data, updated_lora = await recipe_scanner.update_lora_entry( + recipe_id, + lora_index, + target_name=target_name, + target_lora=target_lora, + ) image_path = recipe_data.get("file_path") if image_path and os.path.exists(image_path): @@ -276,7 +251,7 @@ class RecipePersistenceService: failed_recipes.append({"id": recipe_id, "reason": str(exc)}) if deleted_recipes: - await self._bulk_remove_from_cache(recipe_scanner, deleted_recipes) + await recipe_scanner.bulk_remove(deleted_recipes) return PersistenceResult( { @@ -314,14 +289,11 @@ class RecipePersistenceService: if not lora_matches: raise RecipeValidationError("No LoRAs found in the generation metadata") - lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) loras_data = [] base_model_counts: Dict[str, int] = {} for name, strength in lora_matches: - lora_info = None - if lora_scanner is not None: - lora_info = await lora_scanner.get_model_info_by_name(name) + lora_info = await recipe_scanner.get_local_lora(name) lora_data = { "file_name": name, "strength": float(strength), @@ -366,7 +338,7 @@ class RecipePersistenceService: json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) self._exif_utils.append_recipe_metadata(image_path, recipe_data) - await self._update_cache(recipe_scanner, recipe_data) + await recipe_scanner.add_recipe(recipe_data) return PersistenceResult( { @@ -422,45 +394,6 @@ class RecipePersistenceService: matches.remove(exclude_id) return matches - async def _update_cache(self, recipe_scanner, recipe_data: dict[str, Any]) -> None: - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - cache.raw_data.append(recipe_data) - asyncio.create_task(cache.resort()) - self._logger.info("Added recipe %s to cache", recipe_data.get("id")) - - async def _remove_from_cache(self, recipe_scanner, recipe_id: str) -> None: - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - cache.raw_data = [item for item in cache.raw_data if str(item.get("id", "")) != recipe_id] - asyncio.create_task(cache.resort()) - self._logger.info("Removed recipe %s from cache", recipe_id) - - async def _bulk_remove_from_cache(self, recipe_scanner, recipe_ids: Iterable[str]) -> None: - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - recipe_ids_set = set(recipe_ids) - cache.raw_data = [item for item in cache.raw_data if item.get("id") not in recipe_ids_set] - asyncio.create_task(cache.resort()) - self._logger.info("Removed %s recipes from cache", len(recipe_ids_set)) - - async def _refresh_cache_after_update( - self, - recipe_scanner, - recipe_id: str, - recipe_data: dict[str, Any], - ) -> None: - cache = getattr(recipe_scanner, "_cache", None) - if cache is not None: - for cache_item in cache.raw_data: - if cache_item.get("id") == recipe_id: - cache_item.update({ - "loras": recipe_data.get("loras", []), - "fingerprint": recipe_data.get("fingerprint"), - }) - asyncio.create_task(cache.resort()) - break - def _derive_recipe_name(self, lora_matches: list[tuple[str, str]]) -> str: recipe_name_parts = [f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]] recipe_name = "_".join(recipe_name_parts) diff --git a/py/services/recipes/sharing_service.py b/py/services/recipes/sharing_service.py index 7c365bba..47ab9718 100644 --- a/py/services/recipes/sharing_service.py +++ b/py/services/recipes/sharing_service.py @@ -38,11 +38,7 @@ class RecipeSharingService: async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SharingResult: """Prepare a temporary downloadable copy of a recipe image.""" - cache = await recipe_scanner.get_cached_data() - recipe = next( - (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), - None, - ) + recipe = await recipe_scanner.get_recipe_by_id(recipe_id) if not recipe: raise RecipeNotFoundError("Recipe not found") @@ -81,11 +77,7 @@ class RecipeSharingService: self._cleanup_entry(recipe_id) raise RecipeNotFoundError("Shared recipe file not found") - cache = await recipe_scanner.get_cached_data() - recipe = next( - (r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id), - None, - ) + recipe = await recipe_scanner.get_recipe_by_id(recipe_id) filename_base = ( f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id ) diff --git a/tests/services/test_recipe_scanner.py b/tests/services/test_recipe_scanner.py new file mode 100644 index 00000000..63c18f25 --- /dev/null +++ b/tests/services/test_recipe_scanner.py @@ -0,0 +1,185 @@ +import asyncio +import json +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from py.config import config +from py.services.recipe_scanner import RecipeScanner +from py.utils.utils import calculate_recipe_fingerprint + + +class StubHashIndex: + def __init__(self) -> None: + self._hash_to_path: dict[str, str] = {} + + def get_path(self, hash_value: str) -> str | None: + return self._hash_to_path.get(hash_value) + + +class StubLoraScanner: + def __init__(self) -> None: + self._hash_index = StubHashIndex() + self._hash_meta: dict[str, dict[str, str]] = {} + self._models_by_name: dict[str, dict] = {} + self._cache = SimpleNamespace(raw_data=[]) + + async def get_cached_data(self): + return self._cache + + def has_hash(self, hash_value: str) -> bool: + return hash_value.lower() in self._hash_meta + + def get_preview_url_by_hash(self, hash_value: str) -> str: + meta = self._hash_meta.get(hash_value.lower()) + return meta.get("preview_url", "") if meta else "" + + def get_path_by_hash(self, hash_value: str) -> str | None: + meta = self._hash_meta.get(hash_value.lower()) + return meta.get("path") if meta else None + + async def get_model_info_by_name(self, name: str): + return self._models_by_name.get(name) + + def register_model(self, name: str, info: dict) -> None: + self._models_by_name[name] = info + hash_value = (info.get("sha256") or "").lower() + if hash_value: + self._hash_meta[hash_value] = { + "path": info.get("file_path", ""), + "preview_url": info.get("preview_url", ""), + } + self._hash_index._hash_to_path[hash_value] = info.get("file_path", "") + self._cache.raw_data.append({ + "sha256": info.get("sha256", ""), + "path": info.get("file_path", ""), + "civitai": info.get("civitai", {}), + }) + + +@pytest.fixture +def recipe_scanner(tmp_path: Path, monkeypatch): + RecipeScanner._instance = None + monkeypatch.setattr(config, "loras_roots", [str(tmp_path)]) + stub = StubLoraScanner() + scanner = RecipeScanner(lora_scanner=stub) + asyncio.run(scanner.refresh_cache(force=True)) + yield scanner, stub + RecipeScanner._instance = None + + +async def test_add_recipe_during_concurrent_reads(recipe_scanner): + scanner, _ = recipe_scanner + + initial_recipe = { + "id": "one", + "file_path": "path/a.png", + "title": "First", + "modified": 1.0, + "created_date": 1.0, + "loras": [], + } + await scanner.add_recipe(initial_recipe) + + new_recipe = { + "id": "two", + "file_path": "path/b.png", + "title": "Second", + "modified": 2.0, + "created_date": 2.0, + "loras": [], + } + + async def reader_task(): + for _ in range(5): + cache = await scanner.get_cached_data() + _ = [item["id"] for item in cache.raw_data] + await asyncio.sleep(0) + + await asyncio.gather(reader_task(), reader_task(), scanner.add_recipe(new_recipe)) + await asyncio.sleep(0) + cache = await scanner.get_cached_data() + + assert {item["id"] for item in cache.raw_data} == {"one", "two"} + assert len(cache.sorted_by_name) == len(cache.raw_data) + + +async def test_remove_recipe_during_reads(recipe_scanner): + scanner, _ = recipe_scanner + + recipe_ids = ["alpha", "beta", "gamma"] + for index, recipe_id in enumerate(recipe_ids): + await scanner.add_recipe({ + "id": recipe_id, + "file_path": f"path/{recipe_id}.png", + "title": recipe_id, + "modified": float(index), + "created_date": float(index), + "loras": [], + }) + + async def reader_task(): + for _ in range(5): + cache = await scanner.get_cached_data() + _ = list(cache.sorted_by_date) + await asyncio.sleep(0) + + await asyncio.gather(reader_task(), scanner.remove_recipe("beta")) + await asyncio.sleep(0) + cache = await scanner.get_cached_data() + + assert {item["id"] for item in cache.raw_data} == {"alpha", "gamma"} + + +async def test_update_lora_entry_updates_cache_and_file(tmp_path: Path, recipe_scanner): + scanner, stub = recipe_scanner + recipes_dir = Path(config.loras_roots[0]) / "recipes" + recipes_dir.mkdir(parents=True, exist_ok=True) + + recipe_id = "recipe-1" + recipe_path = recipes_dir / f"{recipe_id}.recipe.json" + recipe_data = { + "id": recipe_id, + "file_path": str(tmp_path / "image.png"), + "title": "Original", + "modified": 0.0, + "created_date": 0.0, + "loras": [ + {"file_name": "old", "strength": 1.0, "hash": "", "isDeleted": True, "exclude": True}, + ], + } + recipe_path.write_text(json.dumps(recipe_data)) + + await scanner.add_recipe(dict(recipe_data)) + + target_hash = "abc123" + target_info = { + "sha256": target_hash, + "file_path": str(tmp_path / "loras" / "target.safetensors"), + "preview_url": "preview.png", + "civitai": {"id": 42, "name": "v1", "model": {"name": "Target"}}, + } + stub.register_model("target", target_info) + + updated_recipe, updated_lora = await scanner.update_lora_entry( + recipe_id, + 0, + target_name="target", + target_lora=target_info, + ) + + assert updated_lora["inLibrary"] is True + assert updated_lora["localPath"] == target_info["file_path"] + assert updated_lora["hash"] == target_hash + + with recipe_path.open("r", encoding="utf-8") as file_obj: + persisted = json.load(file_obj) + + expected_fingerprint = calculate_recipe_fingerprint(persisted["loras"]) + assert persisted["fingerprint"] == expected_fingerprint + + cache = await scanner.get_cached_data() + cached_recipe = next(item for item in cache.raw_data if item["id"] == recipe_id) + assert cached_recipe["loras"][0]["hash"] == target_hash + assert cached_recipe["fingerprint"] == expected_fingerprint diff --git a/tests/services/test_recipe_services.py b/tests/services/test_recipe_services.py index e57abf2f..81a15424 100644 --- a/tests/services/test_recipe_services.py +++ b/tests/services/test_recipe_services.py @@ -108,6 +108,10 @@ async def test_save_recipe_reports_duplicates(tmp_path): self.last_fingerprint = fingerprint return ["existing"] + async def add_recipe(self, recipe_data): + self._cache.raw_data.append(recipe_data) + await self._cache.resort() + scanner = DummyScanner(tmp_path) service = RecipePersistenceService( exif_utils=exif_utils,