mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #458 from willmiao/codex/expose-first-class-operations-on-recipescanner
feat: expose recipe scanner mutation APIs
This commit is contained in:
@@ -290,27 +290,7 @@ class RecipeQueryHandler:
|
|||||||
if not lora_hash:
|
if not lora_hash:
|
||||||
return web.json_response({"success": False, "error": "Lora hash is required"}, status=400)
|
return web.json_response({"success": False, "error": "Lora hash is required"}, status=400)
|
||||||
|
|
||||||
cache = await recipe_scanner.get_cached_data()
|
matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash)
|
||||||
matching_recipes = []
|
|
||||||
for recipe in getattr(cache, "raw_data", []):
|
|
||||||
for lora in recipe.get("loras", []):
|
|
||||||
if lora.get("hash", "").lower() == lora_hash.lower():
|
|
||||||
matching_recipes.append(recipe)
|
|
||||||
break
|
|
||||||
|
|
||||||
lora_scanner = getattr(recipe_scanner, "_lora_scanner", None)
|
|
||||||
for recipe in matching_recipes:
|
|
||||||
for lora in recipe.get("loras", []):
|
|
||||||
hash_value = (lora.get("hash") or "").lower()
|
|
||||||
if hash_value and lora_scanner is not None:
|
|
||||||
lora["inLibrary"] = lora_scanner.has_hash(hash_value)
|
|
||||||
lora["preview_url"] = lora_scanner.get_preview_url_by_hash(hash_value)
|
|
||||||
lora["localPath"] = lora_scanner.get_path_by_hash(hash_value)
|
|
||||||
if recipe.get("file_path"):
|
|
||||||
recipe["file_url"] = self._format_recipe_file_url(recipe["file_path"])
|
|
||||||
else:
|
|
||||||
recipe["file_url"] = "/loras_static/images/no-preview.png"
|
|
||||||
|
|
||||||
return web.json_response({"success": True, "recipes": matching_recipes})
|
return web.json_response({"success": True, "recipes": matching_recipes})
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._logger.error("Error getting recipes for Lora: %s", exc)
|
self._logger.error("Error getting recipes for Lora: %s", exc)
|
||||||
@@ -384,50 +364,15 @@ class RecipeQueryHandler:
|
|||||||
raise RuntimeError("Recipe scanner unavailable")
|
raise RuntimeError("Recipe scanner unavailable")
|
||||||
|
|
||||||
recipe_id = request.match_info["recipe_id"]
|
recipe_id = request.match_info["recipe_id"]
|
||||||
cache = await recipe_scanner.get_cached_data()
|
try:
|
||||||
recipe = next(
|
syntax_parts = await recipe_scanner.get_recipe_syntax_tokens(recipe_id)
|
||||||
(r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id),
|
except RecipeNotFoundError:
|
||||||
None,
|
|
||||||
)
|
|
||||||
if not recipe:
|
|
||||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||||
|
|
||||||
loras = recipe.get("loras", [])
|
if not syntax_parts:
|
||||||
if not loras:
|
|
||||||
return web.json_response({"error": "No LoRAs found in this recipe"}, status=400)
|
return web.json_response({"error": "No LoRAs found in this recipe"}, status=400)
|
||||||
|
|
||||||
lora_scanner = getattr(recipe_scanner, "_lora_scanner", None)
|
return web.json_response({"success": True, "syntax": " ".join(syntax_parts)})
|
||||||
hash_index = getattr(lora_scanner, "_hash_index", None)
|
|
||||||
|
|
||||||
lora_syntax_parts = []
|
|
||||||
for lora in loras:
|
|
||||||
if lora.get("isDeleted", False):
|
|
||||||
continue
|
|
||||||
hash_value = (lora.get("hash") or "").lower()
|
|
||||||
if not hash_value or lora_scanner is None or not lora_scanner.has_hash(hash_value):
|
|
||||||
continue
|
|
||||||
|
|
||||||
file_name = None
|
|
||||||
if hash_value and hash_index is not None and hasattr(hash_index, "_hash_to_path"):
|
|
||||||
file_path = hash_index._hash_to_path.get(hash_value)
|
|
||||||
if file_path:
|
|
||||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
|
||||||
|
|
||||||
if not file_name and lora.get("modelVersionId") and lora_scanner is not None:
|
|
||||||
all_loras = await lora_scanner.get_cached_data()
|
|
||||||
for cached_lora in getattr(all_loras, "raw_data", []):
|
|
||||||
civitai_info = cached_lora.get("civitai")
|
|
||||||
if civitai_info and civitai_info.get("id") == lora.get("modelVersionId"):
|
|
||||||
file_name = os.path.splitext(os.path.basename(cached_lora["path"]))[0]
|
|
||||||
break
|
|
||||||
|
|
||||||
if not file_name:
|
|
||||||
file_name = lora.get("file_name", "unknown-lora")
|
|
||||||
|
|
||||||
strength = lora.get("strength", 1.0)
|
|
||||||
lora_syntax_parts.append(f"<lora:{file_name}:{strength}>")
|
|
||||||
|
|
||||||
return web.json_response({"success": True, "syntax": " ".join(lora_syntax_parts)})
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True)
|
self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True)
|
||||||
return web.json_response({"error": str(exc)}, status=500)
|
return web.json_response({"error": str(exc)}, status=500)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Dict
|
from typing import Iterable, List, Dict, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from natsort import natsorted
|
from natsort import natsorted
|
||||||
@@ -10,77 +10,115 @@ class RecipeCache:
|
|||||||
raw_data: List[Dict]
|
raw_data: List[Dict]
|
||||||
sorted_by_name: List[Dict]
|
sorted_by_name: List[Dict]
|
||||||
sorted_by_date: List[Dict]
|
sorted_by_date: List[Dict]
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
async def resort(self, name_only: bool = False):
|
async def resort(self, name_only: bool = False):
|
||||||
"""Resort all cached data views"""
|
"""Resort all cached data views"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self.sorted_by_name = natsorted(
|
self._resort_locked(name_only=name_only)
|
||||||
self.raw_data,
|
|
||||||
key=lambda x: x.get('title', '').lower() # Case-insensitive sort
|
async def update_recipe_metadata(self, recipe_id: str, metadata: Dict, *, resort: bool = True) -> bool:
|
||||||
)
|
|
||||||
if not name_only:
|
|
||||||
self.sorted_by_date = sorted(
|
|
||||||
self.raw_data,
|
|
||||||
key=itemgetter('created_date', 'file_path'),
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
async def update_recipe_metadata(self, recipe_id: str, metadata: Dict) -> bool:
|
|
||||||
"""Update metadata for a specific recipe in all cached data
|
"""Update metadata for a specific recipe in all cached data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
recipe_id: The ID of the recipe to update
|
recipe_id: The ID of the recipe to update
|
||||||
metadata: The new metadata
|
metadata: The new metadata
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the update was successful, False if the recipe wasn't found
|
bool: True if the update was successful, False if the recipe wasn't found
|
||||||
"""
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
for item in self.raw_data:
|
||||||
|
if str(item.get('id')) == str(recipe_id):
|
||||||
|
item.update(metadata)
|
||||||
|
if resort:
|
||||||
|
self._resort_locked()
|
||||||
|
return True
|
||||||
|
return False # Recipe not found
|
||||||
|
|
||||||
|
async def add_recipe(self, recipe_data: Dict, *, resort: bool = False) -> None:
|
||||||
|
"""Add a new recipe to the cache."""
|
||||||
|
|
||||||
# Update in raw_data
|
|
||||||
for item in self.raw_data:
|
|
||||||
if item.get('id') == recipe_id:
|
|
||||||
item.update(metadata)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
return False # Recipe not found
|
|
||||||
|
|
||||||
# Resort to reflect changes
|
|
||||||
await self.resort()
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def add_recipe(self, recipe_data: Dict) -> None:
|
|
||||||
"""Add a new recipe to the cache
|
|
||||||
|
|
||||||
Args:
|
|
||||||
recipe_data: The recipe data to add
|
|
||||||
"""
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self.raw_data.append(recipe_data)
|
self.raw_data.append(recipe_data)
|
||||||
await self.resort()
|
if resort:
|
||||||
|
self._resort_locked()
|
||||||
|
|
||||||
|
async def remove_recipe(self, recipe_id: str, *, resort: bool = False) -> Optional[Dict]:
|
||||||
|
"""Remove a recipe from the cache by ID.
|
||||||
|
|
||||||
async def remove_recipe(self, recipe_id: str) -> bool:
|
|
||||||
"""Remove a recipe from the cache by ID
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
recipe_id: The ID of the recipe to remove
|
recipe_id: The ID of the recipe to remove
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the recipe was found and removed, False otherwise
|
The removed recipe data if found, otherwise ``None``.
|
||||||
"""
|
"""
|
||||||
# Find the recipe in raw_data
|
|
||||||
recipe_index = next((i for i, recipe in enumerate(self.raw_data)
|
async with self._lock:
|
||||||
if recipe.get('id') == recipe_id), None)
|
for index, recipe in enumerate(self.raw_data):
|
||||||
|
if str(recipe.get('id')) == str(recipe_id):
|
||||||
if recipe_index is None:
|
removed = self.raw_data.pop(index)
|
||||||
return False
|
if resort:
|
||||||
|
self._resort_locked()
|
||||||
# Remove from raw_data
|
return removed
|
||||||
self.raw_data.pop(recipe_index)
|
return None
|
||||||
|
|
||||||
# Resort to update sorted lists
|
async def bulk_remove(self, recipe_ids: Iterable[str], *, resort: bool = False) -> List[Dict]:
|
||||||
await self.resort()
|
"""Remove multiple recipes from the cache."""
|
||||||
|
|
||||||
return True
|
id_set = {str(recipe_id) for recipe_id in recipe_ids}
|
||||||
|
if not id_set:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
removed = [item for item in self.raw_data if str(item.get('id')) in id_set]
|
||||||
|
if not removed:
|
||||||
|
return []
|
||||||
|
|
||||||
|
self.raw_data = [item for item in self.raw_data if str(item.get('id')) not in id_set]
|
||||||
|
if resort:
|
||||||
|
self._resort_locked()
|
||||||
|
return removed
|
||||||
|
|
||||||
|
async def replace_recipe(self, recipe_id: str, new_data: Dict, *, resort: bool = False) -> bool:
|
||||||
|
"""Replace cached data for a recipe."""
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
for index, recipe in enumerate(self.raw_data):
|
||||||
|
if str(recipe.get('id')) == str(recipe_id):
|
||||||
|
self.raw_data[index] = new_data
|
||||||
|
if resort:
|
||||||
|
self._resort_locked()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_recipe(self, recipe_id: str) -> Optional[Dict]:
|
||||||
|
"""Return a shallow copy of a cached recipe."""
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
for recipe in self.raw_data:
|
||||||
|
if str(recipe.get('id')) == str(recipe_id):
|
||||||
|
return dict(recipe)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def snapshot(self) -> List[Dict]:
|
||||||
|
"""Return a copy of all cached recipes."""
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
return [dict(item) for item in self.raw_data]
|
||||||
|
|
||||||
|
def _resort_locked(self, *, name_only: bool = False) -> None:
|
||||||
|
"""Sort cached views. Caller must hold ``_lock``."""
|
||||||
|
|
||||||
|
self.sorted_by_name = natsorted(
|
||||||
|
self.raw_data,
|
||||||
|
key=lambda x: x.get('title', '').lower()
|
||||||
|
)
|
||||||
|
if not name_only:
|
||||||
|
self.sorted_by_date = sorted(
|
||||||
|
self.raw_data,
|
||||||
|
key=itemgetter('created_date', 'file_path'),
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
@@ -3,13 +3,14 @@ import logging
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from .recipe_cache import RecipeCache
|
from .recipe_cache import RecipeCache
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .lora_scanner import LoraScanner
|
from .lora_scanner import LoraScanner
|
||||||
from .metadata_service import get_default_metadata_provider
|
from .metadata_service import get_default_metadata_provider
|
||||||
from ..utils.utils import fuzzy_match
|
from .recipes.errors import RecipeNotFoundError
|
||||||
|
from ..utils.utils import calculate_recipe_fingerprint, fuzzy_match
|
||||||
from natsort import natsorted
|
from natsort import natsorted
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@@ -46,6 +47,8 @@ class RecipeScanner:
|
|||||||
self._initialization_lock = asyncio.Lock()
|
self._initialization_lock = asyncio.Lock()
|
||||||
self._initialization_task: Optional[asyncio.Task] = None
|
self._initialization_task: Optional[asyncio.Task] = None
|
||||||
self._is_initializing = False
|
self._is_initializing = False
|
||||||
|
self._mutation_lock = asyncio.Lock()
|
||||||
|
self._resort_tasks: Set[asyncio.Task] = set()
|
||||||
if lora_scanner:
|
if lora_scanner:
|
||||||
self._lora_scanner = lora_scanner
|
self._lora_scanner = lora_scanner
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
@@ -191,6 +194,22 @@ class RecipeScanner:
|
|||||||
# Clean up the event loop
|
# Clean up the event loop
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
def _schedule_resort(self, *, name_only: bool = False) -> None:
|
||||||
|
"""Schedule a background resort of the recipe cache."""
|
||||||
|
|
||||||
|
if not self._cache:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _resort_wrapper() -> None:
|
||||||
|
try:
|
||||||
|
await self._cache.resort(name_only=name_only)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error("Recipe Scanner: error resorting cache: %s", exc, exc_info=True)
|
||||||
|
|
||||||
|
task = asyncio.create_task(_resort_wrapper())
|
||||||
|
self._resort_tasks.add(task)
|
||||||
|
task.add_done_callback(lambda finished: self._resort_tasks.discard(finished))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def recipes_dir(self) -> str:
|
def recipes_dir(self) -> str:
|
||||||
"""Get path to recipes directory"""
|
"""Get path to recipes directory"""
|
||||||
@@ -255,7 +274,45 @@ class RecipeScanner:
|
|||||||
|
|
||||||
# Return the cache (may be empty or partially initialized)
|
# Return the cache (may be empty or partially initialized)
|
||||||
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
||||||
|
|
||||||
|
async def refresh_cache(self, force: bool = False) -> RecipeCache:
|
||||||
|
"""Public helper to refresh or return the recipe cache."""
|
||||||
|
|
||||||
|
return await self.get_cached_data(force_refresh=force)
|
||||||
|
|
||||||
|
async def add_recipe(self, recipe_data: Dict[str, Any]) -> None:
|
||||||
|
"""Add a recipe to the in-memory cache."""
|
||||||
|
|
||||||
|
if not recipe_data:
|
||||||
|
return
|
||||||
|
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
await cache.add_recipe(recipe_data, resort=False)
|
||||||
|
self._schedule_resort()
|
||||||
|
|
||||||
|
async def remove_recipe(self, recipe_id: str) -> bool:
|
||||||
|
"""Remove a recipe from the cache by ID."""
|
||||||
|
|
||||||
|
if not recipe_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
removed = await cache.remove_recipe(recipe_id, resort=False)
|
||||||
|
if removed is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._schedule_resort()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def bulk_remove(self, recipe_ids: Iterable[str]) -> int:
|
||||||
|
"""Remove multiple recipes from the cache."""
|
||||||
|
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
removed = await cache.bulk_remove(recipe_ids, resort=False)
|
||||||
|
if removed:
|
||||||
|
self._schedule_resort()
|
||||||
|
return len(removed)
|
||||||
|
|
||||||
async def scan_all_recipes(self) -> List[Dict]:
|
async def scan_all_recipes(self) -> List[Dict]:
|
||||||
"""Scan all recipe JSON files and return metadata"""
|
"""Scan all recipe JSON files and return metadata"""
|
||||||
recipes = []
|
recipes = []
|
||||||
@@ -326,7 +383,6 @@ class RecipeScanner:
|
|||||||
|
|
||||||
# Calculate and update fingerprint if missing
|
# Calculate and update fingerprint if missing
|
||||||
if 'loras' in recipe_data and 'fingerprint' not in recipe_data:
|
if 'loras' in recipe_data and 'fingerprint' not in recipe_data:
|
||||||
from ..utils.utils import calculate_recipe_fingerprint
|
|
||||||
fingerprint = calculate_recipe_fingerprint(recipe_data['loras'])
|
fingerprint = calculate_recipe_fingerprint(recipe_data['loras'])
|
||||||
recipe_data['fingerprint'] = fingerprint
|
recipe_data['fingerprint'] = fingerprint
|
||||||
|
|
||||||
@@ -497,9 +553,36 @@ class RecipeScanner:
|
|||||||
logger.error(f"Error getting base model for lora: {e}")
|
logger.error(f"Error getting base model for lora: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _enrich_lora_entry(self, lora: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Populate convenience fields for a LoRA entry."""
|
||||||
|
|
||||||
|
if not lora or not self._lora_scanner:
|
||||||
|
return lora
|
||||||
|
|
||||||
|
hash_value = (lora.get('hash') or '').lower()
|
||||||
|
if not hash_value:
|
||||||
|
return lora
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora['inLibrary'] = self._lora_scanner.has_hash(hash_value)
|
||||||
|
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(hash_value)
|
||||||
|
lora['localPath'] = self._lora_scanner.get_path_by_hash(hash_value)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.debug("Error enriching lora entry %s: %s", hash_value, exc)
|
||||||
|
|
||||||
|
return lora
|
||||||
|
|
||||||
|
async def get_local_lora(self, name: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Lookup a local LoRA model by name."""
|
||||||
|
|
||||||
|
if not self._lora_scanner or not name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await self._lora_scanner.get_model_info_by_name(name)
|
||||||
|
|
||||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True):
|
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True):
|
||||||
"""Get paginated and filtered recipe data
|
"""Get paginated and filtered recipe data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
page: Current page number (1-based)
|
page: Current page number (1-based)
|
||||||
page_size: Number of items per page
|
page_size: Number of items per page
|
||||||
@@ -598,16 +681,12 @@ class RecipeScanner:
|
|||||||
|
|
||||||
# Get paginated items
|
# Get paginated items
|
||||||
paginated_items = filtered_data[start_idx:end_idx]
|
paginated_items = filtered_data[start_idx:end_idx]
|
||||||
|
|
||||||
# Add inLibrary information for each lora
|
# Add inLibrary information for each lora
|
||||||
for item in paginated_items:
|
for item in paginated_items:
|
||||||
if 'loras' in item:
|
if 'loras' in item:
|
||||||
for lora in item['loras']:
|
item['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in item['loras']]
|
||||||
if 'hash' in lora and lora['hash']:
|
|
||||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower())
|
|
||||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
|
||||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower())
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
'items': paginated_items,
|
'items': paginated_items,
|
||||||
'total': total_items,
|
'total': total_items,
|
||||||
@@ -653,13 +732,8 @@ class RecipeScanner:
|
|||||||
|
|
||||||
# Add lora metadata
|
# Add lora metadata
|
||||||
if 'loras' in formatted_recipe:
|
if 'loras' in formatted_recipe:
|
||||||
for lora in formatted_recipe['loras']:
|
formatted_recipe['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in formatted_recipe['loras']]
|
||||||
if 'hash' in lora and lora['hash']:
|
|
||||||
lora_hash = lora['hash'].lower()
|
|
||||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash)
|
|
||||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
|
|
||||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash)
|
|
||||||
|
|
||||||
return formatted_recipe
|
return formatted_recipe
|
||||||
|
|
||||||
def _format_file_url(self, file_path: str) -> str:
|
def _format_file_url(self, file_path: str) -> str:
|
||||||
@@ -717,26 +791,159 @@ class RecipeScanner:
|
|||||||
# Save updated recipe
|
# Save updated recipe
|
||||||
with open(recipe_json_path, 'w', encoding='utf-8') as f:
|
with open(recipe_json_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
|
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
# Update the cache if it exists
|
# Update the cache if it exists
|
||||||
if self._cache is not None:
|
if self._cache is not None:
|
||||||
await self._cache.update_recipe_metadata(recipe_id, metadata)
|
await self._cache.update_recipe_metadata(recipe_id, metadata, resort=False)
|
||||||
|
self._schedule_resort()
|
||||||
|
|
||||||
# If the recipe has an image, update its EXIF metadata
|
# If the recipe has an image, update its EXIF metadata
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
image_path = recipe_data.get('file_path')
|
image_path = recipe_data.get('file_path')
|
||||||
if image_path and os.path.exists(image_path):
|
if image_path and os.path.exists(image_path):
|
||||||
ExifUtils.append_recipe_metadata(image_path, recipe_data)
|
ExifUtils.append_recipe_metadata(image_path, recipe_data)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import logging
|
import logging
|
||||||
logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True)
|
logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def update_lora_entry(
|
||||||
|
self,
|
||||||
|
recipe_id: str,
|
||||||
|
lora_index: int,
|
||||||
|
*,
|
||||||
|
target_name: str,
|
||||||
|
target_lora: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
|
"""Update a specific LoRA entry within a recipe.
|
||||||
|
|
||||||
|
Returns the updated recipe data and the refreshed LoRA metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if target_name is None:
|
||||||
|
raise ValueError("target_name must be provided")
|
||||||
|
|
||||||
|
recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
|
||||||
|
if not os.path.exists(recipe_json_path):
|
||||||
|
raise RecipeNotFoundError("Recipe not found")
|
||||||
|
|
||||||
|
async with self._mutation_lock:
|
||||||
|
with open(recipe_json_path, 'r', encoding='utf-8') as file_obj:
|
||||||
|
recipe_data = json.load(file_obj)
|
||||||
|
|
||||||
|
loras = recipe_data.get('loras', [])
|
||||||
|
if lora_index >= len(loras):
|
||||||
|
raise RecipeNotFoundError("LoRA index out of range in recipe")
|
||||||
|
|
||||||
|
lora_entry = loras[lora_index]
|
||||||
|
lora_entry['isDeleted'] = False
|
||||||
|
lora_entry['exclude'] = False
|
||||||
|
lora_entry['file_name'] = target_name
|
||||||
|
|
||||||
|
if target_lora is not None:
|
||||||
|
sha_value = target_lora.get('sha256') or target_lora.get('sha')
|
||||||
|
if sha_value:
|
||||||
|
lora_entry['hash'] = sha_value.lower()
|
||||||
|
|
||||||
|
civitai_info = target_lora.get('civitai') or {}
|
||||||
|
if civitai_info:
|
||||||
|
lora_entry['modelName'] = civitai_info.get('model', {}).get('name', '')
|
||||||
|
lora_entry['modelVersionName'] = civitai_info.get('name', '')
|
||||||
|
lora_entry['modelVersionId'] = civitai_info.get('id')
|
||||||
|
|
||||||
|
recipe_data['fingerprint'] = calculate_recipe_fingerprint(recipe_data.get('loras', []))
|
||||||
|
recipe_data['modified'] = time.time()
|
||||||
|
|
||||||
|
with open(recipe_json_path, 'w', encoding='utf-8') as file_obj:
|
||||||
|
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
replaced = await cache.replace_recipe(recipe_id, recipe_data, resort=False)
|
||||||
|
if not replaced:
|
||||||
|
await cache.add_recipe(recipe_data, resort=False)
|
||||||
|
self._schedule_resort()
|
||||||
|
|
||||||
|
updated_lora = dict(lora_entry)
|
||||||
|
if target_lora is not None:
|
||||||
|
preview_url = target_lora.get('preview_url')
|
||||||
|
if preview_url:
|
||||||
|
updated_lora['preview_url'] = config.get_preview_static_url(preview_url)
|
||||||
|
if target_lora.get('file_path'):
|
||||||
|
updated_lora['localPath'] = target_lora['file_path']
|
||||||
|
|
||||||
|
updated_lora = self._enrich_lora_entry(updated_lora)
|
||||||
|
return recipe_data, updated_lora
|
||||||
|
|
||||||
|
async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]:
|
||||||
|
"""Return recipes that reference a given LoRA hash."""
|
||||||
|
|
||||||
|
if not lora_hash:
|
||||||
|
return []
|
||||||
|
|
||||||
|
normalized_hash = lora_hash.lower()
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
matching_recipes: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
for recipe in cache.raw_data:
|
||||||
|
loras = recipe.get('loras', [])
|
||||||
|
if any((entry.get('hash') or '').lower() == normalized_hash for entry in loras):
|
||||||
|
recipe_copy = {**recipe}
|
||||||
|
recipe_copy['loras'] = [self._enrich_lora_entry(dict(entry)) for entry in loras]
|
||||||
|
recipe_copy['file_url'] = self._format_file_url(recipe.get('file_path'))
|
||||||
|
matching_recipes.append(recipe_copy)
|
||||||
|
|
||||||
|
return matching_recipes
|
||||||
|
|
||||||
|
async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]:
|
||||||
|
"""Build LoRA syntax tokens for a recipe."""
|
||||||
|
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
recipe = await cache.get_recipe(recipe_id)
|
||||||
|
if recipe is None:
|
||||||
|
raise RecipeNotFoundError("Recipe not found")
|
||||||
|
|
||||||
|
loras = recipe.get('loras', [])
|
||||||
|
if not loras:
|
||||||
|
return []
|
||||||
|
|
||||||
|
lora_cache = None
|
||||||
|
if self._lora_scanner is not None:
|
||||||
|
lora_cache = await self._lora_scanner.get_cached_data()
|
||||||
|
|
||||||
|
syntax_parts: List[str] = []
|
||||||
|
for lora in loras:
|
||||||
|
if lora.get('isDeleted', False):
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_name = None
|
||||||
|
hash_value = (lora.get('hash') or '').lower()
|
||||||
|
if hash_value and self._lora_scanner is not None and hasattr(self._lora_scanner, '_hash_index'):
|
||||||
|
file_path = self._lora_scanner._hash_index.get_path(hash_value)
|
||||||
|
if file_path:
|
||||||
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
||||||
|
if not file_name and lora.get('modelVersionId') and lora_cache is not None:
|
||||||
|
for cached_lora in getattr(lora_cache, 'raw_data', []):
|
||||||
|
civitai_info = cached_lora.get('civitai')
|
||||||
|
if civitai_info and civitai_info.get('id') == lora.get('modelVersionId'):
|
||||||
|
cached_path = cached_lora.get('path') or cached_lora.get('file_path')
|
||||||
|
if cached_path:
|
||||||
|
file_name = os.path.splitext(os.path.basename(cached_path))[0]
|
||||||
|
break
|
||||||
|
|
||||||
|
if not file_name:
|
||||||
|
file_name = lora.get('file_name', 'unknown-lora')
|
||||||
|
|
||||||
|
strength = lora.get('strength', 1.0)
|
||||||
|
syntax_parts.append(f"<lora:{file_name}:{strength}>")
|
||||||
|
|
||||||
|
return syntax_parts
|
||||||
|
|
||||||
async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]:
|
async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]:
|
||||||
"""Update file_name in all recipes that contain a LoRA with the specified hash.
|
"""Update file_name in all recipes that contain a LoRA with the specified hash.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hash_value: The SHA256 hash value of the LoRA
|
hash_value: The SHA256 hash value of the LoRA
|
||||||
new_file_name: The new file_name to set
|
new_file_name: The new file_name to set
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Services encapsulating recipe persistence workflows."""
|
"""Services encapsulating recipe persistence workflows."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -123,7 +122,7 @@ class RecipePersistenceService:
|
|||||||
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
|
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
|
||||||
|
|
||||||
matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id)
|
matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id)
|
||||||
await self._update_cache(recipe_scanner, recipe_data)
|
await recipe_scanner.add_recipe(recipe_data)
|
||||||
|
|
||||||
return PersistenceResult(
|
return PersistenceResult(
|
||||||
{
|
{
|
||||||
@@ -154,7 +153,7 @@ class RecipePersistenceService:
|
|||||||
if image_path and os.path.exists(image_path):
|
if image_path and os.path.exists(image_path):
|
||||||
os.remove(image_path)
|
os.remove(image_path)
|
||||||
|
|
||||||
await self._remove_from_cache(recipe_scanner, recipe_id)
|
await recipe_scanner.remove_recipe(recipe_id)
|
||||||
return PersistenceResult({"success": True, "message": "Recipe deleted successfully"})
|
return PersistenceResult({"success": True, "message": "Recipe deleted successfully"})
|
||||||
|
|
||||||
async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult:
|
async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult:
|
||||||
@@ -185,40 +184,16 @@ class RecipePersistenceService:
|
|||||||
if not os.path.exists(recipe_path):
|
if not os.path.exists(recipe_path):
|
||||||
raise RecipeNotFoundError("Recipe not found")
|
raise RecipeNotFoundError("Recipe not found")
|
||||||
|
|
||||||
lora_scanner = getattr(recipe_scanner, "_lora_scanner", None)
|
target_lora = await recipe_scanner.get_local_lora(target_name)
|
||||||
target_lora = None if lora_scanner is None else await lora_scanner.get_model_info_by_name(target_name)
|
|
||||||
if not target_lora:
|
if not target_lora:
|
||||||
raise RecipeNotFoundError(f"Local LoRA not found with name: {target_name}")
|
raise RecipeNotFoundError(f"Local LoRA not found with name: {target_name}")
|
||||||
|
|
||||||
with open(recipe_path, "r", encoding="utf-8") as file_obj:
|
recipe_data, updated_lora = await recipe_scanner.update_lora_entry(
|
||||||
recipe_data = json.load(file_obj)
|
recipe_id,
|
||||||
|
lora_index,
|
||||||
loras = recipe_data.get("loras", [])
|
target_name=target_name,
|
||||||
if lora_index >= len(loras):
|
target_lora=target_lora,
|
||||||
raise RecipeNotFoundError("LoRA index out of range in recipe")
|
)
|
||||||
|
|
||||||
lora = loras[lora_index]
|
|
||||||
lora["isDeleted"] = False
|
|
||||||
lora["exclude"] = False
|
|
||||||
lora["file_name"] = target_name
|
|
||||||
if "sha256" in target_lora:
|
|
||||||
lora["hash"] = target_lora["sha256"].lower()
|
|
||||||
if target_lora.get("civitai"):
|
|
||||||
lora["modelName"] = target_lora["civitai"]["model"]["name"]
|
|
||||||
lora["modelVersionName"] = target_lora["civitai"]["name"]
|
|
||||||
lora["modelVersionId"] = target_lora["civitai"]["id"]
|
|
||||||
|
|
||||||
recipe_data["fingerprint"] = calculate_recipe_fingerprint(recipe_data.get("loras", []))
|
|
||||||
|
|
||||||
with open(recipe_path, "w", encoding="utf-8") as file_obj:
|
|
||||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
|
||||||
|
|
||||||
updated_lora = dict(lora)
|
|
||||||
updated_lora["inLibrary"] = True
|
|
||||||
updated_lora["preview_url"] = config.get_preview_static_url(target_lora["preview_url"])
|
|
||||||
updated_lora["localPath"] = target_lora["file_path"]
|
|
||||||
|
|
||||||
await self._refresh_cache_after_update(recipe_scanner, recipe_id, recipe_data)
|
|
||||||
|
|
||||||
image_path = recipe_data.get("file_path")
|
image_path = recipe_data.get("file_path")
|
||||||
if image_path and os.path.exists(image_path):
|
if image_path and os.path.exists(image_path):
|
||||||
@@ -276,7 +251,7 @@ class RecipePersistenceService:
|
|||||||
failed_recipes.append({"id": recipe_id, "reason": str(exc)})
|
failed_recipes.append({"id": recipe_id, "reason": str(exc)})
|
||||||
|
|
||||||
if deleted_recipes:
|
if deleted_recipes:
|
||||||
await self._bulk_remove_from_cache(recipe_scanner, deleted_recipes)
|
await recipe_scanner.bulk_remove(deleted_recipes)
|
||||||
|
|
||||||
return PersistenceResult(
|
return PersistenceResult(
|
||||||
{
|
{
|
||||||
@@ -314,14 +289,11 @@ class RecipePersistenceService:
|
|||||||
if not lora_matches:
|
if not lora_matches:
|
||||||
raise RecipeValidationError("No LoRAs found in the generation metadata")
|
raise RecipeValidationError("No LoRAs found in the generation metadata")
|
||||||
|
|
||||||
lora_scanner = getattr(recipe_scanner, "_lora_scanner", None)
|
|
||||||
loras_data = []
|
loras_data = []
|
||||||
base_model_counts: Dict[str, int] = {}
|
base_model_counts: Dict[str, int] = {}
|
||||||
|
|
||||||
for name, strength in lora_matches:
|
for name, strength in lora_matches:
|
||||||
lora_info = None
|
lora_info = await recipe_scanner.get_local_lora(name)
|
||||||
if lora_scanner is not None:
|
|
||||||
lora_info = await lora_scanner.get_model_info_by_name(name)
|
|
||||||
lora_data = {
|
lora_data = {
|
||||||
"file_name": name,
|
"file_name": name,
|
||||||
"strength": float(strength),
|
"strength": float(strength),
|
||||||
@@ -366,7 +338,7 @@ class RecipePersistenceService:
|
|||||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
|
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
|
||||||
await self._update_cache(recipe_scanner, recipe_data)
|
await recipe_scanner.add_recipe(recipe_data)
|
||||||
|
|
||||||
return PersistenceResult(
|
return PersistenceResult(
|
||||||
{
|
{
|
||||||
@@ -422,45 +394,6 @@ class RecipePersistenceService:
|
|||||||
matches.remove(exclude_id)
|
matches.remove(exclude_id)
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
async def _update_cache(self, recipe_scanner, recipe_data: dict[str, Any]) -> None:
|
|
||||||
cache = getattr(recipe_scanner, "_cache", None)
|
|
||||||
if cache is not None:
|
|
||||||
cache.raw_data.append(recipe_data)
|
|
||||||
asyncio.create_task(cache.resort())
|
|
||||||
self._logger.info("Added recipe %s to cache", recipe_data.get("id"))
|
|
||||||
|
|
||||||
async def _remove_from_cache(self, recipe_scanner, recipe_id: str) -> None:
|
|
||||||
cache = getattr(recipe_scanner, "_cache", None)
|
|
||||||
if cache is not None:
|
|
||||||
cache.raw_data = [item for item in cache.raw_data if str(item.get("id", "")) != recipe_id]
|
|
||||||
asyncio.create_task(cache.resort())
|
|
||||||
self._logger.info("Removed recipe %s from cache", recipe_id)
|
|
||||||
|
|
||||||
async def _bulk_remove_from_cache(self, recipe_scanner, recipe_ids: Iterable[str]) -> None:
|
|
||||||
cache = getattr(recipe_scanner, "_cache", None)
|
|
||||||
if cache is not None:
|
|
||||||
recipe_ids_set = set(recipe_ids)
|
|
||||||
cache.raw_data = [item for item in cache.raw_data if item.get("id") not in recipe_ids_set]
|
|
||||||
asyncio.create_task(cache.resort())
|
|
||||||
self._logger.info("Removed %s recipes from cache", len(recipe_ids_set))
|
|
||||||
|
|
||||||
async def _refresh_cache_after_update(
|
|
||||||
self,
|
|
||||||
recipe_scanner,
|
|
||||||
recipe_id: str,
|
|
||||||
recipe_data: dict[str, Any],
|
|
||||||
) -> None:
|
|
||||||
cache = getattr(recipe_scanner, "_cache", None)
|
|
||||||
if cache is not None:
|
|
||||||
for cache_item in cache.raw_data:
|
|
||||||
if cache_item.get("id") == recipe_id:
|
|
||||||
cache_item.update({
|
|
||||||
"loras": recipe_data.get("loras", []),
|
|
||||||
"fingerprint": recipe_data.get("fingerprint"),
|
|
||||||
})
|
|
||||||
asyncio.create_task(cache.resort())
|
|
||||||
break
|
|
||||||
|
|
||||||
def _derive_recipe_name(self, lora_matches: list[tuple[str, str]]) -> str:
|
def _derive_recipe_name(self, lora_matches: list[tuple[str, str]]) -> str:
|
||||||
recipe_name_parts = [f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]]
|
recipe_name_parts = [f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]]
|
||||||
recipe_name = "_".join(recipe_name_parts)
|
recipe_name = "_".join(recipe_name_parts)
|
||||||
|
|||||||
@@ -38,11 +38,7 @@ class RecipeSharingService:
|
|||||||
async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SharingResult:
|
async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SharingResult:
|
||||||
"""Prepare a temporary downloadable copy of a recipe image."""
|
"""Prepare a temporary downloadable copy of a recipe image."""
|
||||||
|
|
||||||
cache = await recipe_scanner.get_cached_data()
|
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||||
recipe = next(
|
|
||||||
(r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if not recipe:
|
if not recipe:
|
||||||
raise RecipeNotFoundError("Recipe not found")
|
raise RecipeNotFoundError("Recipe not found")
|
||||||
|
|
||||||
@@ -81,11 +77,7 @@ class RecipeSharingService:
|
|||||||
self._cleanup_entry(recipe_id)
|
self._cleanup_entry(recipe_id)
|
||||||
raise RecipeNotFoundError("Shared recipe file not found")
|
raise RecipeNotFoundError("Shared recipe file not found")
|
||||||
|
|
||||||
cache = await recipe_scanner.get_cached_data()
|
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||||
recipe = next(
|
|
||||||
(r for r in getattr(cache, "raw_data", []) if str(r.get("id", "")) == recipe_id),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
filename_base = (
|
filename_base = (
|
||||||
f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id
|
f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id
|
||||||
)
|
)
|
||||||
|
|||||||
185
tests/services/test_recipe_scanner.py
Normal file
185
tests/services/test_recipe_scanner.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.config import config
|
||||||
|
from py.services.recipe_scanner import RecipeScanner
|
||||||
|
from py.utils.utils import calculate_recipe_fingerprint
|
||||||
|
|
||||||
|
|
||||||
|
class StubHashIndex:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._hash_to_path: dict[str, str] = {}
|
||||||
|
|
||||||
|
def get_path(self, hash_value: str) -> str | None:
|
||||||
|
return self._hash_to_path.get(hash_value)
|
||||||
|
|
||||||
|
|
||||||
|
class StubLoraScanner:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._hash_index = StubHashIndex()
|
||||||
|
self._hash_meta: dict[str, dict[str, str]] = {}
|
||||||
|
self._models_by_name: dict[str, dict] = {}
|
||||||
|
self._cache = SimpleNamespace(raw_data=[])
|
||||||
|
|
||||||
|
async def get_cached_data(self):
|
||||||
|
return self._cache
|
||||||
|
|
||||||
|
def has_hash(self, hash_value: str) -> bool:
|
||||||
|
return hash_value.lower() in self._hash_meta
|
||||||
|
|
||||||
|
def get_preview_url_by_hash(self, hash_value: str) -> str:
|
||||||
|
meta = self._hash_meta.get(hash_value.lower())
|
||||||
|
return meta.get("preview_url", "") if meta else ""
|
||||||
|
|
||||||
|
def get_path_by_hash(self, hash_value: str) -> str | None:
|
||||||
|
meta = self._hash_meta.get(hash_value.lower())
|
||||||
|
return meta.get("path") if meta else None
|
||||||
|
|
||||||
|
async def get_model_info_by_name(self, name: str):
|
||||||
|
return self._models_by_name.get(name)
|
||||||
|
|
||||||
|
def register_model(self, name: str, info: dict) -> None:
|
||||||
|
self._models_by_name[name] = info
|
||||||
|
hash_value = (info.get("sha256") or "").lower()
|
||||||
|
if hash_value:
|
||||||
|
self._hash_meta[hash_value] = {
|
||||||
|
"path": info.get("file_path", ""),
|
||||||
|
"preview_url": info.get("preview_url", ""),
|
||||||
|
}
|
||||||
|
self._hash_index._hash_to_path[hash_value] = info.get("file_path", "")
|
||||||
|
self._cache.raw_data.append({
|
||||||
|
"sha256": info.get("sha256", ""),
|
||||||
|
"path": info.get("file_path", ""),
|
||||||
|
"civitai": info.get("civitai", {}),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def recipe_scanner(tmp_path: Path, monkeypatch):
|
||||||
|
RecipeScanner._instance = None
|
||||||
|
monkeypatch.setattr(config, "loras_roots", [str(tmp_path)])
|
||||||
|
stub = StubLoraScanner()
|
||||||
|
scanner = RecipeScanner(lora_scanner=stub)
|
||||||
|
asyncio.run(scanner.refresh_cache(force=True))
|
||||||
|
yield scanner, stub
|
||||||
|
RecipeScanner._instance = None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_recipe_during_concurrent_reads(recipe_scanner):
|
||||||
|
scanner, _ = recipe_scanner
|
||||||
|
|
||||||
|
initial_recipe = {
|
||||||
|
"id": "one",
|
||||||
|
"file_path": "path/a.png",
|
||||||
|
"title": "First",
|
||||||
|
"modified": 1.0,
|
||||||
|
"created_date": 1.0,
|
||||||
|
"loras": [],
|
||||||
|
}
|
||||||
|
await scanner.add_recipe(initial_recipe)
|
||||||
|
|
||||||
|
new_recipe = {
|
||||||
|
"id": "two",
|
||||||
|
"file_path": "path/b.png",
|
||||||
|
"title": "Second",
|
||||||
|
"modified": 2.0,
|
||||||
|
"created_date": 2.0,
|
||||||
|
"loras": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def reader_task():
|
||||||
|
for _ in range(5):
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
_ = [item["id"] for item in cache.raw_data]
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
await asyncio.gather(reader_task(), reader_task(), scanner.add_recipe(new_recipe))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
|
||||||
|
assert {item["id"] for item in cache.raw_data} == {"one", "two"}
|
||||||
|
assert len(cache.sorted_by_name) == len(cache.raw_data)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_remove_recipe_during_reads(recipe_scanner):
|
||||||
|
scanner, _ = recipe_scanner
|
||||||
|
|
||||||
|
recipe_ids = ["alpha", "beta", "gamma"]
|
||||||
|
for index, recipe_id in enumerate(recipe_ids):
|
||||||
|
await scanner.add_recipe({
|
||||||
|
"id": recipe_id,
|
||||||
|
"file_path": f"path/{recipe_id}.png",
|
||||||
|
"title": recipe_id,
|
||||||
|
"modified": float(index),
|
||||||
|
"created_date": float(index),
|
||||||
|
"loras": [],
|
||||||
|
})
|
||||||
|
|
||||||
|
async def reader_task():
|
||||||
|
for _ in range(5):
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
_ = list(cache.sorted_by_date)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
await asyncio.gather(reader_task(), scanner.remove_recipe("beta"))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
|
||||||
|
assert {item["id"] for item in cache.raw_data} == {"alpha", "gamma"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_update_lora_entry_updates_cache_and_file(tmp_path: Path, recipe_scanner):
|
||||||
|
scanner, stub = recipe_scanner
|
||||||
|
recipes_dir = Path(config.loras_roots[0]) / "recipes"
|
||||||
|
recipes_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
recipe_id = "recipe-1"
|
||||||
|
recipe_path = recipes_dir / f"{recipe_id}.recipe.json"
|
||||||
|
recipe_data = {
|
||||||
|
"id": recipe_id,
|
||||||
|
"file_path": str(tmp_path / "image.png"),
|
||||||
|
"title": "Original",
|
||||||
|
"modified": 0.0,
|
||||||
|
"created_date": 0.0,
|
||||||
|
"loras": [
|
||||||
|
{"file_name": "old", "strength": 1.0, "hash": "", "isDeleted": True, "exclude": True},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
recipe_path.write_text(json.dumps(recipe_data))
|
||||||
|
|
||||||
|
await scanner.add_recipe(dict(recipe_data))
|
||||||
|
|
||||||
|
target_hash = "abc123"
|
||||||
|
target_info = {
|
||||||
|
"sha256": target_hash,
|
||||||
|
"file_path": str(tmp_path / "loras" / "target.safetensors"),
|
||||||
|
"preview_url": "preview.png",
|
||||||
|
"civitai": {"id": 42, "name": "v1", "model": {"name": "Target"}},
|
||||||
|
}
|
||||||
|
stub.register_model("target", target_info)
|
||||||
|
|
||||||
|
updated_recipe, updated_lora = await scanner.update_lora_entry(
|
||||||
|
recipe_id,
|
||||||
|
0,
|
||||||
|
target_name="target",
|
||||||
|
target_lora=target_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert updated_lora["inLibrary"] is True
|
||||||
|
assert updated_lora["localPath"] == target_info["file_path"]
|
||||||
|
assert updated_lora["hash"] == target_hash
|
||||||
|
|
||||||
|
with recipe_path.open("r", encoding="utf-8") as file_obj:
|
||||||
|
persisted = json.load(file_obj)
|
||||||
|
|
||||||
|
expected_fingerprint = calculate_recipe_fingerprint(persisted["loras"])
|
||||||
|
assert persisted["fingerprint"] == expected_fingerprint
|
||||||
|
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
cached_recipe = next(item for item in cache.raw_data if item["id"] == recipe_id)
|
||||||
|
assert cached_recipe["loras"][0]["hash"] == target_hash
|
||||||
|
assert cached_recipe["fingerprint"] == expected_fingerprint
|
||||||
@@ -108,6 +108,10 @@ async def test_save_recipe_reports_duplicates(tmp_path):
|
|||||||
self.last_fingerprint = fingerprint
|
self.last_fingerprint = fingerprint
|
||||||
return ["existing"]
|
return ["existing"]
|
||||||
|
|
||||||
|
async def add_recipe(self, recipe_data):
|
||||||
|
self._cache.raw_data.append(recipe_data)
|
||||||
|
await self._cache.resort()
|
||||||
|
|
||||||
scanner = DummyScanner(tmp_path)
|
scanner = DummyScanner(tmp_path)
|
||||||
service = RecipePersistenceService(
|
service = RecipePersistenceService(
|
||||||
exif_utils=exif_utils,
|
exif_utils=exif_utils,
|
||||||
|
|||||||
Reference in New Issue
Block a user