diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index 3b41643e..c7ac2842 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -1333,71 +1333,56 @@ class RecipeScanner: # Always use lowercase hash for consistency hash_value = hash_value.lower() - # Get recipes directory - recipes_dir = self.recipes_dir - if not recipes_dir or not os.path.exists(recipes_dir): - logger.warning(f"Recipes directory not found: {recipes_dir}") + # Get cache + cache = await self.get_cached_data() + if not cache or not cache.raw_data: + return 0, 0 + + file_updated_count = 0 + cache_updated_count = 0 + + # Find recipes that need updating from the cache + recipes_to_update = [] + for recipe in cache.raw_data: + loras = recipe.get('loras', []) + if not isinstance(loras, list): + continue + + has_match = False + for lora in loras: + if not isinstance(lora, dict): + continue + if (lora.get('hash') or '').lower() == hash_value: + if lora.get('file_name') != new_file_name: + lora['file_name'] = new_file_name + has_match = True + + if has_match: + recipes_to_update.append(recipe) + cache_updated_count += 1 + + if not recipes_to_update: return 0, 0 - # Check if cache is initialized - cache_initialized = self._cache is not None - cache_updated_count = 0 - file_updated_count = 0 - - # Get all recipe JSON files in the recipes directory - recipe_files = [] - for root, _, files in os.walk(recipes_dir): - for file in files: - if file.lower().endswith('.recipe.json'): - recipe_files.append(os.path.join(root, file)) - - # Process each recipe file - for recipe_path in recipe_files: - try: - # Load the recipe data - with open(recipe_path, 'r', encoding='utf-8') as f: - recipe_data = json.load(f) - - # Skip if no loras or invalid structure - if not recipe_data or not isinstance(recipe_data, dict) or 'loras' not in recipe_data: + # Persist changes to disk + async with self._mutation_lock: + for recipe in recipes_to_update: + recipe_id = recipe.get('id') + if not recipe_id: continue - - # Check if any lora has matching hash - file_updated = False - for lora in recipe_data.get('loras', []): - if 'hash' in lora and lora['hash'].lower() == hash_value: - # Update file_name - old_file_name = lora.get('file_name', '') - lora['file_name'] = new_file_name - file_updated = True - logger.info(f"Updated file_name in recipe {recipe_path}: {old_file_name} -> {new_file_name}") - - # If updated, save the file - if file_updated: - with open(recipe_path, 'w', encoding='utf-8') as f: - json.dump(recipe_data, f, indent=4, ensure_ascii=False) - file_updated_count += 1 - # Also update in cache if it exists - if cache_initialized: - recipe_id = recipe_data.get('id') - if recipe_id: - for cache_item in self._cache.raw_data: - if cache_item.get('id') == recipe_id: - # Replace loras array with updated version - cache_item['loras'] = recipe_data['loras'] - cache_updated_count += 1 - break - - except Exception as e: - logger.error(f"Error updating recipe file {recipe_path}: {e}") - import traceback - traceback.print_exc(file=sys.stderr) + recipe_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json") + try: + self._write_recipe_file(recipe_path, recipe) + file_updated_count += 1 + logger.info(f"Updated file_name in recipe {recipe_path}: -> {new_file_name}") + except Exception as e: + logger.error(f"Error updating recipe file {recipe_path}: {e}") - # Resort cache if updates were made - if cache_initialized and cache_updated_count > 0: - await self._cache.resort() - logger.info(f"Resorted recipe cache after updating {cache_updated_count} items") + # We don't necessarily need to resort because LoRA file_name isn't a sort key, + # but we might want to schedule a resort if we're paranoid or if searching relies on sorted state. + # Given it's a rename of a dependency, search results might change if searching by LoRA name. + self._schedule_resort() return file_updated_count, cache_updated_count diff --git a/tests/services/test_recipe_scanner.py b/tests/services/test_recipe_scanner.py index cdca7c46..822eef43 100644 --- a/tests/services/test_recipe_scanner.py +++ b/tests/services/test_recipe_scanner.py @@ -445,3 +445,64 @@ async def test_load_recipe_persists_deleted_flag_on_invalid_version(monkeypatch, persisted = json.loads(recipe_path.read_text()) assert persisted["loras"][0]["isDeleted"] is True + + +@pytest.mark.asyncio +async def test_update_lora_filename_by_hash_updates_affected_recipes(tmp_path: Path, recipe_scanner): + scanner, _ = recipe_scanner + recipes_dir = Path(config.loras_roots[0]) / "recipes" + recipes_dir.mkdir(parents=True, exist_ok=True) + + # Recipe 1: Contains the LoRA with hash "hash1" + recipe1_id = "recipe1" + recipe1_path = recipes_dir / f"{recipe1_id}.recipe.json" + recipe1_data = { + "id": recipe1_id, + "file_path": str(tmp_path / "img1.png"), + "title": "Recipe 1", + "modified": 0.0, + "created_date": 0.0, + "loras": [ + {"file_name": "old_name", "hash": "hash1"}, + {"file_name": "other_lora", "hash": "hash2"} + ], + } + recipe1_path.write_text(json.dumps(recipe1_data)) + await scanner.add_recipe(dict(recipe1_data)) + + # Recipe 2: Does NOT contain the LoRA + recipe2_id = "recipe2" + recipe2_path = recipes_dir / f"{recipe2_id}.recipe.json" + recipe2_data = { + "id": recipe2_id, + "file_path": str(tmp_path / "img2.png"), + "title": "Recipe 2", + "modified": 0.0, + "created_date": 0.0, + "loras": [ + {"file_name": "other_lora", "hash": "hash2"} + ], + } + recipe2_path.write_text(json.dumps(recipe2_data)) + await scanner.add_recipe(dict(recipe2_data)) + + # Update LoRA name for "hash1" (using different case to test normalization) + new_name = "new_name" + file_count, cache_count = await scanner.update_lora_filename_by_hash("HASH1", new_name) + + assert file_count == 1 + assert cache_count == 1 + + # Check file on disk + persisted1 = json.loads(recipe1_path.read_text()) + assert persisted1["loras"][0]["file_name"] == new_name + assert persisted1["loras"][1]["file_name"] == "other_lora" + + # Verify Recipe 2 unchanged + persisted2 = json.loads(recipe2_path.read_text()) + assert persisted2["loras"][0]["file_name"] == "other_lora" + + # Check cache + cache = await scanner.get_cached_data() + cached1 = next(r for r in cache.raw_data if r["id"] == recipe1_id) + assert cached1["loras"][0]["file_name"] == new_name