From 801aa2e87624c088182769e593bce5a84d16325c Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Tue, 8 Apr 2025 12:23:51 +0800 Subject: [PATCH] Enhance Lora and recipe integration with improved filtering and UI updates - Added support for filtering LoRAs by hash in both API and UI components. - Implemented session storage management for custom filter states when navigating between recipes and LoRAs. - Introduced a new button in the recipe modal to view associated LoRAs, enhancing user navigation. - Updated CSS styles for new UI elements, including a custom filter indicator and LoRA view button. - Refactored existing JavaScript components to streamline the handling of filter parameters and improve maintainability. --- py/routes/api_routes.py | 109 +++++++++-------- py/routes/recipe_routes.py | 4 +- py/services/lora_scanner.py | 125 +++++++++++--------- py/services/recipe_scanner.py | 2 +- static/css/components/recipe-modal.css | 29 +++++ static/js/api/loraApi.js | 23 ++++ static/js/components/RecipeModal.js | 74 +++++++++++- static/js/components/loraModal/RecipeTab.js | 20 +--- static/js/loras.js | 79 ++++++++++++- static/js/recipes.js | 36 +----- templates/components/controls.html | 6 + templates/components/recipe_modal.html | 3 + 12 files changed, 352 insertions(+), 158 deletions(-) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 95b224d9..0dff5524 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -132,13 +132,9 @@ class ApiRoutes: page = int(request.query.get('page', '1')) page_size = int(request.query.get('page_size', '20')) sort_by = request.query.get('sort_by', 'name') - folder = request.query.get('folder') - search = request.query.get('search', '').lower() - fuzzy = request.query.get('fuzzy', 'false').lower() == 'true' - - # Parse base models filter parameter - base_models = request.query.get('base_models', '').split(',') - base_models = [model.strip() for model in base_models if model.strip()] + folder = request.query.get('folder', None) + search = request.query.get('search', None) + fuzzy_search = request.query.get('fuzzy', 'false').lower() == 'true' # Parse search options search_filename = request.query.get('search_filename', 'true').lower() == 'true' @@ -146,62 +142,68 @@ class ApiRoutes: search_tags = request.query.get('search_tags', 'false').lower() == 'true' recursive = request.query.get('recursive', 'false').lower() == 'true' - # Validate parameters - if page < 1 or page_size < 1 or page_size > 100: - return web.json_response({ - 'error': 'Invalid pagination parameters' - }, status=400) + # Get filter parameters + base_models = request.query.get('base_models', None) + tags = request.query.get('tags', None) - if sort_by not in ['date', 'name']: - return web.json_response({ - 'error': 'Invalid sort parameter' - }, status=400) + # New parameters for recipe filtering + lora_hash = request.query.get('lora_hash', None) + lora_hashes = request.query.get('lora_hashes', None) - # Parse tags filter parameter - tags = request.query.get('tags', '').split(',') - tags = [tag.strip() for tag in tags if tag.strip()] + # Parse filter parameters + filters = {} + if base_models: + filters['base_model'] = base_models.split(',') + if tags: + filters['tags'] = tags.split(',') - # Get paginated data with search and filters - result = await self.scanner.get_paginated_data( - page=page, - page_size=page_size, - sort_by=sort_by, + # Add search options to filters + search_options = { + 'filename': search_filename, + 'modelname': search_modelname, + 'tags': search_tags, + 'recursive': recursive + } + + # Add lora hash filtering options + hash_filters = {} + if lora_hash: + hash_filters['single_hash'] = lora_hash.lower() + elif lora_hashes: + hash_filters['multiple_hashes'] = [h.lower() for h in lora_hashes.split(',')] + + # Get file data + data = await self.scanner.get_paginated_data( + page, + page_size, + sort_by=sort_by, folder=folder, search=search, - fuzzy=fuzzy, - base_models=base_models, # Pass base models filter - tags=tags, # Add tags parameter - search_options={ - 'filename': search_filename, - 'modelname': search_modelname, - 'tags': search_tags, - 'recursive': recursive - } + fuzzy_search=fuzzy_search, + base_models=filters.get('base_model', None), + tags=filters.get('tags', None), + search_options=search_options, + hash_filters=hash_filters ) - - # Format the response data - formatted_items = [ - self._format_lora_response(item) - for item in result['items'] - ] # Get all available folders from cache cache = await self.scanner.get_cached_data() - return web.json_response({ - 'items': formatted_items, - 'total': result['total'], - 'page': result['page'], - 'page_size': result['page_size'], - 'total_pages': result['total_pages'], - 'folders': cache.folders - }) + # Convert output to match expected format + result = { + 'items': [self._format_lora_response(lora) for lora in data['items']], + 'folders': cache.folders, + 'total': data['total'], + 'page': data['page'], + 'page_size': data['page_size'], + 'total_pages': data['total_pages'] + } + + return web.json_response(result) except Exception as e: - logger.error(f"Error in get_loras: {str(e)}", exc_info=True) - return web.json_response({ - 'error': 'Internal server error' - }, status=500) + logger.error(f"Error retrieving loras: {e}", exc_info=True) + return web.json_response({"error": str(e)}, status=500) def _format_lora_response(self, lora: Dict) -> Dict: """Format LoRA data for API response""" @@ -831,7 +833,10 @@ class ApiRoutes: except Exception as e: logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True) - return web.Response(text=str(e), status=500) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) async def move_models_bulk(self, request: web.Request) -> web.Response: """Handle bulk model move request""" diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 7493087e..32de5722 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -103,7 +103,6 @@ class RecipeRoutes: # New parameter: get LoRA hash filter lora_hash = request.query.get('lora_hash', None) - bypass_filters = request.query.get('bypass_filters', 'false').lower() == 'true' # Parse filter parameters filters = {} @@ -128,8 +127,7 @@ class RecipeRoutes: search=search, filters=filters, search_options=search_options, - lora_hash=lora_hash, - bypass_filters=bypass_filters + lora_hash=lora_hash ) # Format the response data with static URLs for file paths diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index f322c9f2..96bcce45 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -136,9 +136,9 @@ class LoraScanner: ) async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', - folder: str = None, search: str = None, fuzzy: bool = False, + folder: str = None, search: str = None, fuzzy_search: bool = False, base_models: list = None, tags: list = None, - search_options: dict = None) -> Dict: + search_options: dict = None, hash_filters: dict = None) -> Dict: """Get paginated and filtered lora data Args: @@ -147,10 +147,11 @@ class LoraScanner: sort_by: Sort method ('name' or 'date') folder: Filter by folder path search: Search term - fuzzy: Use fuzzy matching for search + fuzzy_search: Use fuzzy matching for search base_models: List of base models to filter by tags: List of tags to filter by search_options: Dictionary with search options (filename, modelname, tags, recursive) + hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes) """ cache = await self.get_cached_data() @@ -160,90 +161,108 @@ class LoraScanner: 'filename': True, 'modelname': True, 'tags': False, - 'recursive': False + 'recursive': False, } # Get the base data set filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name + # Apply hash filtering if provided (highest priority) + if hash_filters: + single_hash = hash_filters.get('single_hash') + multiple_hashes = hash_filters.get('multiple_hashes') + + if single_hash: + # Filter by single hash + single_hash = single_hash.lower() # Ensure lowercase for matching + filtered_data = [ + lora for lora in filtered_data + if lora.get('sha256', '').lower() == single_hash + ] + elif multiple_hashes: + # Filter by multiple hashes + hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup + filtered_data = [ + lora for lora in filtered_data + if lora.get('sha256', '').lower() in hash_set + ] + + + # Jump to pagination + total_items = len(filtered_data) + start_idx = (page - 1) * page_size + end_idx = min(start_idx + page_size, total_items) + + result = { + 'items': filtered_data[start_idx:end_idx], + 'total': total_items, + 'page': page, + 'page_size': page_size, + 'total_pages': (total_items + page_size - 1) // page_size + } + + return result + # Apply SFW filtering if enabled if settings.get('show_only_sfw', False): filtered_data = [ - item for item in filtered_data - if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R'] + lora for lora in filtered_data + if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['NSFW'] ] # Apply folder filtering if folder is not None: if search_options.get('recursive', False): - # Recursive mode: match all paths starting with this folder + # Recursive folder filtering - include all subfolders filtered_data = [ - item for item in filtered_data - if item['folder'].startswith(folder + '/') or item['folder'] == folder + lora for lora in filtered_data + if lora['folder'].startswith(folder) ] else: - # Non-recursive mode: match exact folder + # Exact folder filtering filtered_data = [ - item for item in filtered_data - if item['folder'] == folder + lora for lora in filtered_data + if lora['folder'] == folder ] # Apply base model filtering if base_models and len(base_models) > 0: filtered_data = [ - item for item in filtered_data - if item.get('base_model') in base_models + lora for lora in filtered_data + if lora.get('base_model') in base_models ] # Apply tag filtering if tags and len(tags) > 0: filtered_data = [ - item for item in filtered_data - if any(tag in item.get('tags', []) for tag in tags) + lora for lora in filtered_data + if any(tag in lora.get('tags', []) for tag in tags) ] # Apply search filtering if search: search_results = [] - for item in filtered_data: - # Check filename if enabled - if search_options.get('filename', True): - if fuzzy: - if fuzzy_match(item.get('file_name', ''), search): - search_results.append(item) - continue - else: - if search.lower() in item.get('file_name', '').lower(): - search_results.append(item) - continue - - # Check model name if enabled - if search_options.get('modelname', True): - if fuzzy: - if fuzzy_match(item.get('model_name', ''), search): - search_results.append(item) - continue - else: - if search.lower() in item.get('model_name', '').lower(): - search_results.append(item) - continue - - # Check tags if enabled - if search_options.get('tags', False) and item.get('tags'): - found_tag = False - for tag in item['tags']: - if fuzzy: - if fuzzy_match(tag, search): - found_tag = True - break - else: - if search.lower() in tag.lower(): - found_tag = True - break - if found_tag: - search_results.append(item) + search_opts = search_options or {} + + for lora in filtered_data: + # Search by file name + if search_opts.get('filename', True): + if fuzzy_match(lora.get('file_name', ''), search): + search_results.append(lora) continue + # Search by model name + if search_opts.get('modelname', True): + if fuzzy_match(lora.get('model_name', ''), search): + search_results.append(lora) + continue + + # Search by tags + if search_opts.get('tags', False) and 'tags' in lora: + if any(fuzzy_match(tag, search) for tag in lora['tags']): + search_results.append(lora) + continue + filtered_data = search_results # Calculate pagination diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index ef4644aa..ec3310ee 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -330,7 +330,7 @@ class RecipeScanner: logger.error(f"Error getting base model for lora: {e}") return None - async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = False): + async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True): """Get paginated and filtered recipe data Args: diff --git a/static/css/components/recipe-modal.css b/static/css/components/recipe-modal.css index 9e546c0c..298faa1a 100644 --- a/static/css/components/recipe-modal.css +++ b/static/css/components/recipe-modal.css @@ -400,6 +400,27 @@ gap: var(--space-1); } +/* View LoRAs button */ +.view-loras-btn { + background: none; + border: none; + color: var(--text-color); + opacity: 0.7; + cursor: pointer; + padding: 4px 8px; + border-radius: var(--border-radius-xs); + transition: all 0.2s; + display: flex; + align-items: center; + justify-content: center; +} + +.view-loras-btn:hover { + opacity: 1; + background: var(--lora-surface); + color: var(--lora-accent); +} + #recipeLorasCount { font-size: 0.9em; color: var(--text-color); @@ -433,6 +454,14 @@ will-change: transform; /* Create a new containing block for absolutely positioned descendants */ transform: translateZ(0); + cursor: pointer; /* Make it clear the item is clickable */ + transition: transform 0.2s ease, box-shadow 0.2s ease, border-color 0.2s ease; +} + +.recipe-lora-item:hover { + transform: translateY(-1px); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08); + border-color: var(--lora-accent); } .recipe-lora-item.exists-locally { diff --git a/static/js/api/loraApi.js b/static/js/api/loraApi.js index 95066e65..8ebee93c 100644 --- a/static/js/api/loraApi.js +++ b/static/js/api/loraApi.js @@ -4,6 +4,7 @@ import { createLoraCard } from '../components/LoraCard.js'; import { initializeInfiniteScroll } from '../utils/infiniteScroll.js'; import { showDeleteModal } from '../utils/modalUtils.js'; import { toggleFolder } from '../utils/uiHelpers.js'; +import { getSessionItem } from '../utils/storageHelpers.js'; export async function loadMoreLoras(resetPage = false, updateFolders = false) { const pageState = getCurrentPageState(); @@ -57,6 +58,28 @@ export async function loadMoreLoras(resetPage = false, updateFolders = false) { } } + // Check for recipe-based filtering parameters from session storage + const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash'); + const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes'); + + console.log('Filter Lora Hash:', filterLoraHash); + console.log('Filter Lora Hashes:', filterLoraHashes); + + // Add hash filter parameter if present + if (filterLoraHash) { + params.append('lora_hash', filterLoraHash); + } + // Add multiple hashes filter if present + else if (filterLoraHashes) { + try { + if (Array.isArray(filterLoraHashes) && filterLoraHashes.length > 0) { + params.append('lora_hashes', filterLoraHashes.join(',')); + } + } catch (error) { + console.error('Error parsing lora hashes from session storage:', error); + } + } + const response = await fetch(`/api/loras?${params}`); if (!response.ok) { throw new Error(`Failed to fetch loras: ${response.statusText}`); diff --git a/static/js/components/RecipeModal.js b/static/js/components/RecipeModal.js index 5ae4483f..e0f52af5 100644 --- a/static/js/components/RecipeModal.js +++ b/static/js/components/RecipeModal.js @@ -1,6 +1,7 @@ // Recipe Modal Component import { showToast } from '../utils/uiHelpers.js'; import { state } from '../state/index.js'; +import { setSessionItem, removeSessionItem } from '../utils/storageHelpers.js'; class RecipeModal { constructor() { @@ -294,7 +295,7 @@ class RecipeModal { } else { // No generation parameters available if (promptElement) promptElement.textContent = 'No prompt information available'; - if (negativePromptElement) negativePromptElement.textContent = 'No negative prompt information available'; + if (negativePromptElement) promptElement.textContent = 'No negative prompt information available'; if (otherParamsElement) otherParamsElement.innerHTML = '