From 4a47dc20734f722c937b4fb64b773a847f7649e9 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 20 Mar 2025 15:19:05 +0800 Subject: [PATCH] Add new API routes for base models and update existing routes - Introduced a new endpoint for retrieving base models used in loras, enhancing the API functionality. - Updated the existing top-tags route to reflect the new URL structure under '/api/loras'. - Modified the FilterManager to accommodate the new base models API, ensuring proper data fetching and display on the loras page. - Improved error handling and logging for base model retrieval, enhancing overall robustness of the application. --- py/routes/api_routes.py | 27 ++++++- py/services/lora_scanner.py | 26 +++++++ static/js/managers/FilterManager.js | 116 ++++++++++++---------------- 3 files changed, 101 insertions(+), 68 deletions(-) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index ceb685b3..c86d0d6a 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -49,7 +49,8 @@ class ApiRoutes: app.router.add_post('/loras/api/save-metadata', routes.save_metadata) app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route app.router.add_post('/api/move_models_bulk', routes.move_models_bulk) - app.router.add_get('/api/top-tags', routes.get_top_tags) # Add new route for top tags + app.router.add_get('/api/loras/top-tags', routes.get_top_tags) # Add new route for top tags + app.router.add_get('/api/loras/base-models', routes.get_base_models) # Add new route for base models # Add update check routes UpdateRoutes.setup_routes(app) @@ -841,4 +842,28 @@ class ApiRoutes: return web.json_response({ 'success': False, 'error': 'Internal server error' + }, status=500) + + async def get_base_models(self, request: web.Request) -> web.Response: + """Get base models used in loras""" + try: + # Parse query parameters + limit = int(request.query.get('limit', '20')) + + # Validate limit + if limit < 1 or limit > 100: + limit = 20 # Default to a reasonable limit + + # Get base models + base_models = await self.scanner.get_base_models(limit) + + return web.json_response({ + 'success': True, + 'base_models': base_models + }) + except Exception as e: + logger.error(f"Error retrieving base models: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) }, status=500) \ No newline at end of file diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index cb3cfabb..de24f067 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -702,6 +702,32 @@ class LoraScanner: # Return limited number return sorted_tags[:limit] + + async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]: + """Get base models used in loras sorted by frequency + + Args: + limit: Maximum number of base models to return + + Returns: + List of dictionaries with base model name and count, sorted by count + """ + # Make sure cache is initialized + cache = await self.get_cached_data() + + # Count base model occurrences + base_model_counts = {} + for lora in cache.raw_data: + if 'base_model' in lora and lora['base_model']: + base_model = lora['base_model'] + base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 + + # Sort base models by count + sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()] + sorted_models.sort(key=lambda x: x['count'], reverse=True) + + # Return limited number + return sorted_models[:limit] async def diagnose_hash_index(self): """Diagnostic method to verify hash index functionality""" diff --git a/static/js/managers/FilterManager.js b/static/js/managers/FilterManager.js index 11e49e8b..36897430 100644 --- a/static/js/managers/FilterManager.js +++ b/static/js/managers/FilterManager.js @@ -66,7 +66,7 @@ export class FilterManager { tagsContainer.innerHTML = '
'; // Determine the API endpoint based on the page type - let tagsEndpoint = '/api/top-tags?limit=20'; + let tagsEndpoint = '/api/loras/top-tags?limit=20'; if (this.currentPage === 'recipes') { tagsEndpoint = '/api/recipes/top-tags?limit=20'; } @@ -136,80 +136,62 @@ export class FilterManager { const baseModelTagsContainer = document.getElementById('baseModelTags'); if (!baseModelTagsContainer) return; + // Set the appropriate API endpoint based on current page + let apiEndpoint = ''; if (this.currentPage === 'loras') { - // Use predefined base models for loras page - baseModelTagsContainer.innerHTML = ''; - - Object.entries(BASE_MODELS).forEach(([key, value]) => { - const tag = document.createElement('div'); - tag.className = `filter-tag base-model-tag ${BASE_MODEL_CLASSES[value]}`; - tag.dataset.baseModel = value; - tag.innerHTML = value; - - // Add click handler to toggle selection and automatically apply - tag.addEventListener('click', async () => { - tag.classList.toggle('active'); - - if (tag.classList.contains('active')) { - if (!this.filters.baseModel.includes(value)) { - this.filters.baseModel.push(value); - } - } else { - this.filters.baseModel = this.filters.baseModel.filter(model => model !== value); - } - - this.updateActiveFiltersCount(); - - // Auto-apply filter when tag is clicked - await this.applyFilters(false); - }); - - baseModelTagsContainer.appendChild(tag); - }); + apiEndpoint = '/api/loras/base-models'; } else if (this.currentPage === 'recipes') { - // Fetch base models for recipes - fetch('/api/recipes/base-models') - .then(response => response.json()) - .then(data => { - if (data.success && data.base_models) { - baseModelTagsContainer.innerHTML = ''; + apiEndpoint = '/api/recipes/base-models'; + } else { + return; // No API endpoint for other pages + } + + // Fetch base models + fetch(apiEndpoint) + .then(response => response.json()) + .then(data => { + if (data.success && data.base_models) { + baseModelTagsContainer.innerHTML = ''; + + data.base_models.forEach(model => { + const tag = document.createElement('div'); + // Add base model classes only for the loras page + const baseModelClass = (this.currentPage === 'loras' && BASE_MODEL_CLASSES[model.name]) + ? BASE_MODEL_CLASSES[model.name] + : ''; + tag.className = `filter-tag base-model-tag ${baseModelClass}`; + tag.dataset.baseModel = model.name; + tag.innerHTML = `${model.name} ${model.count}`; - data.base_models.forEach(model => { - const tag = document.createElement('div'); - tag.className = `filter-tag base-model-tag`; - tag.dataset.baseModel = model.name; - tag.innerHTML = `${model.name} ${model.count}`; + // Add click handler to toggle selection and automatically apply + tag.addEventListener('click', async () => { + tag.classList.toggle('active'); - // Add click handler to toggle selection and automatically apply - tag.addEventListener('click', async () => { - tag.classList.toggle('active'); - - if (tag.classList.contains('active')) { - if (!this.filters.baseModel.includes(model.name)) { - this.filters.baseModel.push(model.name); - } - } else { - this.filters.baseModel = this.filters.baseModel.filter(m => m !== model.name); + if (tag.classList.contains('active')) { + if (!this.filters.baseModel.includes(model.name)) { + this.filters.baseModel.push(model.name); } - - this.updateActiveFiltersCount(); - - // Auto-apply filter when tag is clicked - await this.applyFilters(false); - }); + } else { + this.filters.baseModel = this.filters.baseModel.filter(m => m !== model.name); + } - baseModelTagsContainer.appendChild(tag); + this.updateActiveFiltersCount(); + + // Auto-apply filter when tag is clicked + await this.applyFilters(false); }); - // Update selections based on stored filters - this.updateTagSelections(); - } - }) - .catch(error => { - console.error('Error fetching base models:', error); - baseModelTagsContainer.innerHTML = ''; - }); - } + baseModelTagsContainer.appendChild(tag); + }); + + // Update selections based on stored filters + this.updateTagSelections(); + } + }) + .catch(error => { + console.error(`Error fetching base models for ${this.currentPage}:`, error); + baseModelTagsContainer.innerHTML = ''; + }); } toggleFilterPanel() {