Add new API routes for base models and update existing routes

- Introduced a new endpoint for retrieving base models used in loras, enhancing the API functionality.
- Updated the existing top-tags route to reflect the new URL structure under '/api/loras'.
- Modified the FilterManager to accommodate the new base models API, ensuring proper data fetching and display on the loras page.
- Improved error handling and logging for base model retrieval, enhancing overall robustness of the application.
This commit is contained in:
Will Miao
2025-03-20 15:19:05 +08:00
parent addf92d966
commit 4a47dc2073
3 changed files with 101 additions and 68 deletions

View File

@@ -49,7 +49,8 @@ class ApiRoutes:
app.router.add_post('/loras/api/save-metadata', routes.save_metadata)
app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route
app.router.add_post('/api/move_models_bulk', routes.move_models_bulk)
app.router.add_get('/api/top-tags', routes.get_top_tags) # Add new route for top tags
app.router.add_get('/api/loras/top-tags', routes.get_top_tags) # Add new route for top tags
app.router.add_get('/api/loras/base-models', routes.get_base_models) # Add new route for base models
# Add update check routes
UpdateRoutes.setup_routes(app)
@@ -841,4 +842,28 @@ class ApiRoutes:
return web.json_response({
'success': False,
'error': 'Internal server error'
}, status=500)
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in loras"""
try:
# Parse query parameters
limit = int(request.query.get('limit', '20'))
# Validate limit
if limit < 1 or limit > 100:
limit = 20 # Default to a reasonable limit
# Get base models
base_models = await self.scanner.get_base_models(limit)
return web.json_response({
'success': True,
'base_models': base_models
})
except Exception as e:
logger.error(f"Error retrieving base models: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)

View File

@@ -702,6 +702,32 @@ class LoraScanner:
# Return limited number
return sorted_tags[:limit]
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get base models used in loras sorted by frequency
Args:
limit: Maximum number of base models to return
Returns:
List of dictionaries with base model name and count, sorted by count
"""
# Make sure cache is initialized
cache = await self.get_cached_data()
# Count base model occurrences
base_model_counts = {}
for lora in cache.raw_data:
if 'base_model' in lora and lora['base_model']:
base_model = lora['base_model']
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
# Sort base models by count
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
sorted_models.sort(key=lambda x: x['count'], reverse=True)
# Return limited number
return sorted_models[:limit]
async def diagnose_hash_index(self):
"""Diagnostic method to verify hash index functionality"""

View File

@@ -66,7 +66,7 @@ export class FilterManager {
tagsContainer.innerHTML = '<div class="tags-loading">Loading tags...</div>';
// Determine the API endpoint based on the page type
let tagsEndpoint = '/api/top-tags?limit=20';
let tagsEndpoint = '/api/loras/top-tags?limit=20';
if (this.currentPage === 'recipes') {
tagsEndpoint = '/api/recipes/top-tags?limit=20';
}
@@ -136,80 +136,62 @@ export class FilterManager {
const baseModelTagsContainer = document.getElementById('baseModelTags');
if (!baseModelTagsContainer) return;
// Set the appropriate API endpoint based on current page
let apiEndpoint = '';
if (this.currentPage === 'loras') {
// Use predefined base models for loras page
baseModelTagsContainer.innerHTML = '';
Object.entries(BASE_MODELS).forEach(([key, value]) => {
const tag = document.createElement('div');
tag.className = `filter-tag base-model-tag ${BASE_MODEL_CLASSES[value]}`;
tag.dataset.baseModel = value;
tag.innerHTML = value;
// Add click handler to toggle selection and automatically apply
tag.addEventListener('click', async () => {
tag.classList.toggle('active');
if (tag.classList.contains('active')) {
if (!this.filters.baseModel.includes(value)) {
this.filters.baseModel.push(value);
}
} else {
this.filters.baseModel = this.filters.baseModel.filter(model => model !== value);
}
this.updateActiveFiltersCount();
// Auto-apply filter when tag is clicked
await this.applyFilters(false);
});
baseModelTagsContainer.appendChild(tag);
});
apiEndpoint = '/api/loras/base-models';
} else if (this.currentPage === 'recipes') {
// Fetch base models for recipes
fetch('/api/recipes/base-models')
.then(response => response.json())
.then(data => {
if (data.success && data.base_models) {
baseModelTagsContainer.innerHTML = '';
apiEndpoint = '/api/recipes/base-models';
} else {
return; // No API endpoint for other pages
}
// Fetch base models
fetch(apiEndpoint)
.then(response => response.json())
.then(data => {
if (data.success && data.base_models) {
baseModelTagsContainer.innerHTML = '';
data.base_models.forEach(model => {
const tag = document.createElement('div');
// Add base model classes only for the loras page
const baseModelClass = (this.currentPage === 'loras' && BASE_MODEL_CLASSES[model.name])
? BASE_MODEL_CLASSES[model.name]
: '';
tag.className = `filter-tag base-model-tag ${baseModelClass}`;
tag.dataset.baseModel = model.name;
tag.innerHTML = `${model.name} <span class="tag-count">${model.count}</span>`;
data.base_models.forEach(model => {
const tag = document.createElement('div');
tag.className = `filter-tag base-model-tag`;
tag.dataset.baseModel = model.name;
tag.innerHTML = `${model.name} <span class="tag-count">${model.count}</span>`;
// Add click handler to toggle selection and automatically apply
tag.addEventListener('click', async () => {
tag.classList.toggle('active');
// Add click handler to toggle selection and automatically apply
tag.addEventListener('click', async () => {
tag.classList.toggle('active');
if (tag.classList.contains('active')) {
if (!this.filters.baseModel.includes(model.name)) {
this.filters.baseModel.push(model.name);
}
} else {
this.filters.baseModel = this.filters.baseModel.filter(m => m !== model.name);
if (tag.classList.contains('active')) {
if (!this.filters.baseModel.includes(model.name)) {
this.filters.baseModel.push(model.name);
}
this.updateActiveFiltersCount();
// Auto-apply filter when tag is clicked
await this.applyFilters(false);
});
} else {
this.filters.baseModel = this.filters.baseModel.filter(m => m !== model.name);
}
baseModelTagsContainer.appendChild(tag);
this.updateActiveFiltersCount();
// Auto-apply filter when tag is clicked
await this.applyFilters(false);
});
// Update selections based on stored filters
this.updateTagSelections();
}
})
.catch(error => {
console.error('Error fetching base models:', error);
baseModelTagsContainer.innerHTML = '<div class="tags-error">Failed to load base models</div>';
});
}
baseModelTagsContainer.appendChild(tag);
});
// Update selections based on stored filters
this.updateTagSelections();
}
})
.catch(error => {
console.error(`Error fetching base models for ${this.currentPage}:`, error);
baseModelTagsContainer.innerHTML = '<div class="tags-error">Failed to load base models</div>';
});
}
toggleFilterPanel() {