mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Add filter-related endpoints to RecipeRoutes for top tags and base models. Enhance get_paginated_data method in RecipeScanner to support filtering by base model and tags. Implement logic to retrieve and count occurrences of top tags and base models from cached recipes.
This commit is contained in:
@@ -40,6 +40,10 @@ class RecipeRoutes:
|
||||
app.router.add_post('/api/recipes/save', routes.save_recipe)
|
||||
app.router.add_delete('/api/recipe/{recipe_id}', routes.delete_recipe)
|
||||
|
||||
# Add new filter-related endpoints
|
||||
app.router.add_get('/api/recipes/top-tags', routes.get_top_tags)
|
||||
app.router.add_get('/api/recipes/base-models', routes.get_base_models)
|
||||
|
||||
# Start cache initialization
|
||||
app.on_startup.append(routes._init_cache)
|
||||
|
||||
@@ -80,12 +84,24 @@ class RecipeRoutes:
|
||||
sort_by = request.query.get('sort_by', 'date')
|
||||
search = request.query.get('search', None)
|
||||
|
||||
# Get filter parameters
|
||||
base_models = request.query.get('base_models', None)
|
||||
tags = request.query.get('tags', None)
|
||||
|
||||
# Parse filter parameters
|
||||
filters = {}
|
||||
if base_models:
|
||||
filters['base_model'] = base_models.split(',')
|
||||
if tags:
|
||||
filters['tags'] = tags.split(',')
|
||||
|
||||
# Get paginated data
|
||||
result = await self.recipe_scanner.get_paginated_data(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
search=search
|
||||
search=search,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
# Format the response data with static URLs for file paths
|
||||
@@ -521,4 +537,64 @@ class RecipeRoutes:
|
||||
return web.json_response({"success": True, "message": "Recipe deleted successfully"})
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting recipe: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Get top tags used in recipes"""
|
||||
try:
|
||||
# Get limit parameter with default
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Get all recipes from cache
|
||||
cache = await self.recipe_scanner.get_cached_data()
|
||||
|
||||
# Count tag occurrences
|
||||
tag_counts = {}
|
||||
for recipe in cache.raw_data:
|
||||
if 'tags' in recipe and recipe['tags']:
|
||||
for tag in recipe['tags']:
|
||||
tag_counts[tag] = tag_counts.get(tag, 0) + 1
|
||||
|
||||
# Sort tags by count and limit results
|
||||
sorted_tags = [{'tag': tag, 'count': count} for tag, count in tag_counts.items()]
|
||||
sorted_tags.sort(key=lambda x: x['count'], reverse=True)
|
||||
top_tags = sorted_tags[:limit]
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': top_tags
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving top tags: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in recipes"""
|
||||
try:
|
||||
# Get all recipes from cache
|
||||
cache = await self.recipe_scanner.get_cached_data()
|
||||
|
||||
# Count base model occurrences
|
||||
base_model_counts = {}
|
||||
for recipe in cache.raw_data:
|
||||
if 'base_model' in recipe and recipe['base_model']:
|
||||
base_model = recipe['base_model']
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
# Sort base models by count
|
||||
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
|
||||
sorted_models.sort(key=lambda x: x['count'], reverse=True)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'base_models': sorted_models
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving base models: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
@@ -374,7 +374,7 @@ class RecipeScanner:
|
||||
logger.error(f"Error getting base model for lora: {e}")
|
||||
return None
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None):
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None):
|
||||
"""Get paginated and filtered recipe data
|
||||
|
||||
Args:
|
||||
@@ -382,6 +382,7 @@ class RecipeScanner:
|
||||
page_size: Number of items per page
|
||||
sort_by: Sort method ('name' or 'date')
|
||||
search: Search term
|
||||
filters: Dictionary of filters to apply
|
||||
"""
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
@@ -395,6 +396,22 @@ class RecipeScanner:
|
||||
if search.lower() in str(item.get('title', '')).lower() or
|
||||
search.lower() in str(item.get('prompt', '')).lower()
|
||||
]
|
||||
|
||||
# Apply additional filters
|
||||
if filters:
|
||||
# Filter by base model
|
||||
if 'base_model' in filters and filters['base_model']:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if item.get('base_model', '') in filters['base_model']
|
||||
]
|
||||
|
||||
# Filter by tags
|
||||
if 'tags' in filters and filters['tags']:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if any(tag in item.get('tags', []) for tag in filters['tags'])
|
||||
]
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
|
||||
Reference in New Issue
Block a user