diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index 12a8aeb1..efd480dd 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -10,6 +10,7 @@ from datetime import datetime from ..services.checkpoint_scanner import CheckpointScanner from ..config import config from ..services.settings_manager import settings +from ..utils.utils import fuzzy_match logger = logging.getLogger(__name__) @@ -25,9 +26,12 @@ class CheckpointsRoutes: def setup_routes(self, app): """Register routes with the aiohttp app""" - app.router.add_get('/lora_manager/api/checkpoints', self.get_checkpoints) - app.router.add_get('/lora_manager/api/checkpoints/scan', self.scan_checkpoints) - app.router.add_get('/lora_manager/api/checkpoints/info/{name}', self.get_checkpoint_info) + app.router.add_get('/checkpoints', self.handle_checkpoints_page) + app.router.add_get('/api/checkpoints', self.get_checkpoints) + app.router.add_get('/api/checkpoints/base-models', self.get_base_models) + app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags) + app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) + app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info) async def get_checkpoints(self, request): """Get paginated checkpoint data""" @@ -76,8 +80,17 @@ class CheckpointsRoutes: hash_filters=hash_filters ) + # Format response items + formatted_result = { + 'items': [self._format_checkpoint_response(cp) for cp in result['items']], + 'total': result['total'], + 'page': result['page'], + 'page_size': result['page_size'], + 'total_pages': result['total_pages'] + } + # Return as JSON - return web.json_response(result) + return web.json_response(formatted_result) except Exception as e: logger.error(f"Error in get_checkpoints: {e}", exc_info=True) @@ -90,28 +103,122 @@ class CheckpointsRoutes: """Get paginated and filtered checkpoint data""" cache = await self.scanner.get_cached_data() - # Implement similar filtering logic as in LoraScanner - # (Adapt code from LoraScanner.get_paginated_data) - # ... - - # For now, a simplified implementation: + # Get default search options if not provided + if search_options is None: + search_options = { + 'filename': True, + 'modelname': True, + 'tags': False, + 'recursive': False, + } + + # Get the base data set filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name - # Apply basic folder filtering if needed + # Apply hash filtering if provided (highest priority) + if hash_filters: + single_hash = hash_filters.get('single_hash') + multiple_hashes = hash_filters.get('multiple_hashes') + + if single_hash: + # Filter by single hash + single_hash = single_hash.lower() # Ensure lowercase for matching + filtered_data = [ + cp for cp in filtered_data + if cp.get('sha256', '').lower() == single_hash + ] + elif multiple_hashes: + # Filter by multiple hashes + hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup + filtered_data = [ + cp for cp in filtered_data + if cp.get('sha256', '').lower() in hash_set + ] + + # Jump to pagination + total_items = len(filtered_data) + start_idx = (page - 1) * page_size + end_idx = min(start_idx + page_size, total_items) + + result = { + 'items': filtered_data[start_idx:end_idx], + 'total': total_items, + 'page': page, + 'page_size': page_size, + 'total_pages': (total_items + page_size - 1) // page_size + } + + return result + + # Apply SFW filtering if enabled in settings + if settings.get('show_only_sfw', False): + filtered_data = [ + cp for cp in filtered_data + if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R'] + ] + + # Apply folder filtering if folder is not None: + if search_options.get('recursive', False): + # Recursive folder filtering - include all subfolders + filtered_data = [ + cp for cp in filtered_data + if cp['folder'].startswith(folder) + ] + else: + # Exact folder filtering + filtered_data = [ + cp for cp in filtered_data + if cp['folder'] == folder + ] + + # Apply base model filtering + if base_models and len(base_models) > 0: filtered_data = [ cp for cp in filtered_data - if cp['folder'] == folder + if cp.get('base_model') in base_models ] - # Apply basic search if needed + # Apply tag filtering + if tags and len(tags) > 0: + filtered_data = [ + cp for cp in filtered_data + if any(tag in cp.get('tags', []) for tag in tags) + ] + + # Apply search filtering if search: - filtered_data = [ - cp for cp in filtered_data - if search.lower() in cp['file_name'].lower() or - search.lower() in cp['model_name'].lower() - ] - + search_results = [] + + for cp in filtered_data: + # Search by file name + if search_options.get('filename', True): + if fuzzy_search: + if fuzzy_match(cp.get('file_name', ''), search): + search_results.append(cp) + continue + elif search.lower() in cp.get('file_name', '').lower(): + search_results.append(cp) + continue + + # Search by model name + if search_options.get('modelname', True): + if fuzzy_search: + if fuzzy_match(cp.get('model_name', ''), search): + search_results.append(cp) + continue + elif search.lower() in cp.get('model_name', '').lower(): + search_results.append(cp) + continue + + # Search by tags + if search_options.get('tags', False) and 'tags' in cp: + if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']): + search_results.append(cp) + continue + + filtered_data = search_results + # Calculate pagination total_items = len(filtered_data) start_idx = (page - 1) * page_size @@ -127,6 +234,88 @@ class CheckpointsRoutes: return result + def _format_checkpoint_response(self, checkpoint): + """Format checkpoint data for API response""" + return { + "model_name": checkpoint["model_name"], + "file_name": checkpoint["file_name"], + "preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")), + "preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0), + "base_model": checkpoint.get("base_model", ""), + "folder": checkpoint["folder"], + "sha256": checkpoint.get("sha256", ""), + "file_path": checkpoint["file_path"].replace(os.sep, "/"), + "file_size": checkpoint.get("size", 0), + "modified": checkpoint.get("modified", ""), + "tags": checkpoint.get("tags", []), + "modelDescription": checkpoint.get("modelDescription", ""), + "from_civitai": checkpoint.get("from_civitai", True), + "notes": checkpoint.get("notes", ""), + "model_type": checkpoint.get("model_type", "checkpoint"), + "civitai": self._filter_civitai_data(checkpoint.get("civitai", {})) + } + + def _filter_civitai_data(self, data): + """Filter relevant fields from CivitAI data""" + if not data: + return {} + + fields = [ + "id", "modelId", "name", "createdAt", "updatedAt", + "publishedAt", "trainedWords", "baseModel", "description", + "model", "images" + ] + return {k: data[k] for k in fields if k in data} + + async def get_top_tags(self, request: web.Request) -> web.Response: + """Handle request for top tags sorted by frequency""" + try: + # Parse query parameters + limit = int(request.query.get('limit', '20')) + + # Validate limit + if limit < 1 or limit > 100: + limit = 20 # Default to a reasonable limit + + # Get top tags + top_tags = await self.scanner.get_top_tags(limit) + + return web.json_response({ + 'success': True, + 'tags': top_tags + }) + + except Exception as e: + logger.error(f"Error getting top tags: {str(e)}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': 'Internal server error' + }, status=500) + + async def get_base_models(self, request: web.Request) -> web.Response: + """Get base models used in loras""" + try: + # Parse query parameters + limit = int(request.query.get('limit', '20')) + + # Validate limit + if limit < 1 or limit > 100: + limit = 20 # Default to a reasonable limit + + # Get base models + base_models = await self.scanner.get_base_models(limit) + + return web.json_response({ + 'success': True, + 'base_models': base_models + }) + except Exception as e: + logger.error(f"Error retrieving base models: {e}") + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + async def scan_checkpoints(self, request): """Force a rescan of checkpoint files""" try: diff --git a/static/js/api/checkpointApi.js b/static/js/api/checkpointApi.js new file mode 100644 index 00000000..cbf1ca0b --- /dev/null +++ b/static/js/api/checkpointApi.js @@ -0,0 +1,247 @@ +import { state, getCurrentPageState } from '../state/index.js'; +import { showToast } from '../utils/uiHelpers.js'; +import { confirmDelete } from '../utils/modalUtils.js'; +import { createCheckpointCard } from '../components/CheckpointCard.js'; + +// Load more checkpoints with pagination +export async function loadMoreCheckpoints(resetPagination = true) { + try { + const pageState = getCurrentPageState(); + + // Don't load if we're already loading or there are no more items + if (pageState.isLoading || (!resetPagination && !pageState.hasMore)) { + return; + } + + // Set loading state + pageState.isLoading = true; + document.body.classList.add('loading'); + + // Reset pagination if requested + if (resetPagination) { + pageState.currentPage = 1; + const grid = document.getElementById('checkpointGrid'); + if (grid) grid.innerHTML = ''; + } + + // Build API URL with parameters + const params = new URLSearchParams({ + page: pageState.currentPage, + page_size: pageState.pageSize || 20, + sort: pageState.sortBy || 'name' + }); + + // Add folder filter if active + if (pageState.activeFolder) { + params.append('folder', pageState.activeFolder); + } + + // Add search if available + if (pageState.filters && pageState.filters.search) { + params.append('search', pageState.filters.search); + + // Add search options + if (pageState.searchOptions) { + params.append('search_filename', pageState.searchOptions.filename.toString()); + params.append('search_modelname', pageState.searchOptions.modelname.toString()); + params.append('recursive', pageState.searchOptions.recursive.toString()); + } + } + + // Add base model filters + if (pageState.filters && pageState.filters.baseModel && pageState.filters.baseModel.length > 0) { + pageState.filters.baseModel.forEach(model => { + params.append('base_model', model); + }); + } + + // Add tags filters + if (pageState.filters && pageState.filters.tags && pageState.filters.tags.length > 0) { + pageState.filters.tags.forEach(tag => { + params.append('tag', tag); + }); + } + + // Execute fetch + const response = await fetch(`/api/checkpoints?${params.toString()}`); + + if (!response.ok) { + throw new Error(`Failed to load checkpoints: ${response.status} ${response.statusText}`); + } + + const data = await response.json(); + + // Update state with response data + pageState.hasMore = data.page < data.total_pages; + + // Update UI with checkpoints + const grid = document.getElementById('checkpointGrid'); + if (!grid) { + return; + } + + // Clear grid if this is the first page + if (resetPagination) { + grid.innerHTML = ''; + } + + // Check for empty result + if (data.items.length === 0 && resetPagination) { + grid.innerHTML = ` +
+ `; + return; + } + + // Render checkpoint cards + data.items.forEach(checkpoint => { + const card = createCheckpointCard(checkpoint); + grid.appendChild(card); + }); + } catch (error) { + console.error('Error loading checkpoints:', error); + showToast('Failed to load checkpoints', 'error'); + } finally { + // Clear loading state + const pageState = getCurrentPageState(); + pageState.isLoading = false; + document.body.classList.remove('loading'); + } +} + +// Reset and reload checkpoints +export async function resetAndReload() { + const pageState = getCurrentPageState(); + pageState.currentPage = 1; + pageState.hasMore = true; + await loadMoreCheckpoints(true); +} + +// Refresh checkpoints +export async function refreshCheckpoints() { + try { + showToast('Scanning for checkpoints...', 'info'); + const response = await fetch('/api/checkpoints/scan'); + + if (!response.ok) { + throw new Error(`Failed to scan checkpoints: ${response.status} ${response.statusText}`); + } + + await resetAndReload(); + showToast('Checkpoints refreshed successfully', 'success'); + } catch (error) { + console.error('Error refreshing checkpoints:', error); + showToast('Failed to refresh checkpoints', 'error'); + } +} + +// Delete a checkpoint +export function deleteCheckpoint(filePath) { + confirmDelete('Are you sure you want to delete this checkpoint?', () => { + _performDelete(filePath); + }); +} + +// Private function to perform the delete operation +async function _performDelete(filePath) { + try { + showToast('Deleting checkpoint...', 'info'); + + const response = await fetch('/api/model/delete', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + file_path: filePath, + model_type: 'checkpoint' + }) + }); + + if (!response.ok) { + throw new Error(`Failed to delete checkpoint: ${response.status} ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success) { + // Remove the card from UI + const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); + if (card) { + card.remove(); + } + + showToast('Checkpoint deleted successfully', 'success'); + } else { + throw new Error(data.error || 'Failed to delete checkpoint'); + } + } catch (error) { + console.error('Error deleting checkpoint:', error); + showToast(`Failed to delete checkpoint: ${error.message}`, 'error'); + } +} + +// Replace checkpoint preview +export function replaceCheckpointPreview(filePath) { + // Open file picker + const input = document.createElement('input'); + input.type = 'file'; + input.accept = 'image/*'; + input.onchange = async (e) => { + if (!e.target.files.length) return; + + const file = e.target.files[0]; + await _uploadPreview(filePath, file); + }; + input.click(); +} + +// Upload a preview image +async function _uploadPreview(filePath, file) { + try { + showToast('Uploading preview...', 'info'); + + const formData = new FormData(); + formData.append('file', file); + formData.append('file_path', filePath); + formData.append('model_type', 'checkpoint'); + + const response = await fetch('/api/model/preview', { + method: 'POST', + body: formData + }); + + if (!response.ok) { + throw new Error(`Failed to upload preview: ${response.status} ${response.statusText}`); + } + + const data = await response.json(); + + if (data.success) { + // Update the preview in UI + const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`); + if (card) { + const img = card.querySelector('.card-preview img'); + if (img) { + // Add timestamp to prevent caching + const timestamp = new Date().getTime(); + if (data.preview_url) { + img.src = `${data.preview_url}?t=${timestamp}`; + } else { + img.src = `/api/model/preview_image?path=${encodeURIComponent(filePath)}&t=${timestamp}`; + } + } + } + + showToast('Preview updated successfully', 'success'); + } else { + throw new Error(data.error || 'Failed to update preview'); + } + } catch (error) { + console.error('Error updating preview:', error); + showToast(`Failed to update preview: ${error.message}`, 'error'); + } +} \ No newline at end of file diff --git a/static/js/checkpoints.js b/static/js/checkpoints.js index ea149a2f..8b563f8c 100644 --- a/static/js/checkpoints.js +++ b/static/js/checkpoints.js @@ -1,36 +1,128 @@ import { appCore } from './core.js'; -import { state, initPageState } from './state/index.js'; +import { state, getCurrentPageState } from './state/index.js'; +import { + loadMoreCheckpoints, + resetAndReload, + refreshCheckpoints, + deleteCheckpoint, + replaceCheckpointPreview +} from './api/checkpointApi.js'; +import { + restoreFolderFilter, + toggleFolder, + openCivitai, + showToast +} from './utils/uiHelpers.js'; +import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; +import { toggleApiKeyVisibility } from './managers/SettingsManager.js'; +import { initializeInfiniteScroll } from './utils/infiniteScroll.js'; +import { setStorageItem, getStorageItem } from './utils/storageHelpers.js'; // Initialize the Checkpoints page class CheckpointsPageManager { constructor() { - // Initialize any necessary state - this.initialized = false; + // Get page state + this.pageState = getCurrentPageState(); + + // Set default values + this.pageState.pageSize = 20; + this.pageState.isLoading = false; + this.pageState.hasMore = true; + + // Expose functions to window object + this._exposeGlobalFunctions(); + } + + _exposeGlobalFunctions() { + // API functions + window.loadCheckpoints = (reset = true) => this.loadCheckpoints(reset); + window.refreshCheckpoints = refreshCheckpoints; + window.deleteCheckpoint = deleteCheckpoint; + window.replaceCheckpointPreview = replaceCheckpointPreview; + + // UI helper functions + window.toggleFolder = toggleFolder; + window.openCivitai = openCivitai; + window.confirmDelete = confirmDelete; + window.closeDeleteModal = closeDeleteModal; + window.toggleApiKeyVisibility = toggleApiKeyVisibility; + + // Add reference to this manager + window.checkpointManager = this; } async initialize() { - if (this.initialized) return; + // Initialize event listeners + this._initEventListeners(); - // Initialize page state - initPageState('checkpoints'); + // Restore folder filters if available + restoreFolderFilter('checkpoints'); - // Initialize core application - await appCore.initialize(); + // Load sort preference + this._loadSortPreference(); - // Initialize page-specific components - this._initializeWorkInProgress(); + // Load initial checkpoints + await this.loadCheckpoints(); - this.initialized = true; + // Initialize infinite scroll + initializeInfiniteScroll('checkpoints'); + + // Initialize common page features + appCore.initializePageFeatures(); + + console.log('Checkpoints Manager initialized'); } - _initializeWorkInProgress() { - // Add any work-in-progress specific initialization here - console.log('Checkpoints Manager is under development'); + _initEventListeners() { + // Sort select handler + const sortSelect = document.getElementById('sortSelect'); + if (sortSelect) { + sortSelect.addEventListener('change', async (e) => { + this.pageState.sortBy = e.target.value; + this._saveSortPreference(e.target.value); + await resetAndReload(); + }); + } + + // Folder tags handler + document.querySelectorAll('.folder-tags .tag').forEach(tag => { + tag.addEventListener('click', toggleFolder); + }); + + // Refresh button handler + const refreshBtn = document.getElementById('refreshBtn'); + if (refreshBtn) { + refreshBtn.addEventListener('click', () => refreshCheckpoints()); + } + } + + _loadSortPreference() { + const savedSort = getStorageItem('checkpoints_sort'); + if (savedSort) { + this.pageState.sortBy = savedSort; + const sortSelect = document.getElementById('sortSelect'); + if (sortSelect) { + sortSelect.value = savedSort; + } + } + } + + _saveSortPreference(sortValue) { + setStorageItem('checkpoints_sort', sortValue); + } + + // Load checkpoints with optional pagination reset + async loadCheckpoints(resetPage = true) { + await loadMoreCheckpoints(resetPage); } } // Initialize everything when DOM is ready document.addEventListener('DOMContentLoaded', async () => { + // Initialize core application + await appCore.initialize(); + + // Initialize checkpoints page const checkpointsPage = new CheckpointsPageManager(); await checkpointsPage.initialize(); }); diff --git a/static/js/components/CheckpointCard.js b/static/js/components/CheckpointCard.js new file mode 100644 index 00000000..a9246b8e --- /dev/null +++ b/static/js/components/CheckpointCard.js @@ -0,0 +1,147 @@ +import { showToast } from '../utils/uiHelpers.js'; +import { state } from '../state/index.js'; +import { CheckpointModal } from './CheckpointModal.js'; + +// Create an instance of the modal +const checkpointModal = new CheckpointModal(); + +export function createCheckpointCard(checkpoint) { + const card = document.createElement('div'); + card.className = 'lora-card'; // Reuse the same class for styling + card.dataset.sha256 = checkpoint.sha256; + card.dataset.filepath = checkpoint.file_path; + card.dataset.name = checkpoint.model_name; + card.dataset.file_name = checkpoint.file_name; + card.dataset.folder = checkpoint.folder; + card.dataset.modified = checkpoint.modified; + card.dataset.file_size = checkpoint.file_size; + card.dataset.from_civitai = checkpoint.from_civitai; + card.dataset.base_model = checkpoint.base_model || 'Unknown'; + + // Store metadata if available + if (checkpoint.civitai) { + card.dataset.meta = JSON.stringify(checkpoint.civitai || {}); + } + + // Store tags if available + if (checkpoint.tags && Array.isArray(checkpoint.tags)) { + card.dataset.tags = JSON.stringify(checkpoint.tags); + } + + // Determine preview URL + const previewUrl = checkpoint.preview_url || '/loras_static/images/no-preview.png'; + const version = state.previewVersions ? state.previewVersions.get(checkpoint.file_path) : null; + const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl; + + card.innerHTML = ` +
+ This feature is currently under development and will be available soon.
-Please check back later for updates!
+ {% include 'components/controls.html' %} + +