From 721bef3ff8efcfe9ffa873ea93ff5d2f5c85507e Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 10 Mar 2025 13:18:56 +0800 Subject: [PATCH] Add tag filtering checkpoint --- py/routes/api_routes.py | 89 +++++++++++++---- py/routes/lora_routes.py | 5 + py/services/civitai_client.py | 33 ++++--- py/services/file_monitor.py | 14 +++ py/services/lora_scanner.py | 113 +++++++++++++++++++++- py/utils/file_utils.py | 2 + py/utils/models.py | 11 ++- static/css/components/lora-modal.css | 39 ++++++++ static/css/components/search-filter.css | 40 +++++++- static/js/api/loraApi.js | 12 ++- static/js/components/LoraCard.js | 12 ++- static/js/components/LoraModal.js | 30 +++++- static/js/managers/FilterManager.js | 122 ++++++++++++++++++++++-- static/js/state/index.js | 3 +- templates/components/controls.html | 7 ++ 15 files changed, 482 insertions(+), 50 deletions(-) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 49d760dd..cde1284a 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -45,6 +45,7 @@ class ApiRoutes: app.router.add_post('/loras/api/save-metadata', routes.save_metadata) app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route app.router.add_post('/api/move_models_bulk', routes.move_models_bulk) + app.router.add_get('/api/top-tags', routes.get_top_tags) # Add new route for top tags # Add update check routes UpdateRoutes.setup_routes(app) @@ -142,6 +143,10 @@ class ApiRoutes: 'error': 'Invalid sort parameter' }, status=400) + # Parse tags filter parameter + tags = request.query.get('tags', '').split(',') + tags = [tag.strip() for tag in tags if tag.strip()] + # Get paginated data with search and filters result = await self.scanner.get_paginated_data( page=page, @@ -151,7 +156,8 @@ class ApiRoutes: search=search, fuzzy=fuzzy, recursive=recursive, - base_models=base_models # Pass base models filter + base_models=base_models, # Pass base models filter + tags=tags # Add tags parameter ) # Format the response data @@ -190,6 +196,8 @@ class ApiRoutes: "file_path": lora["file_path"].replace(os.sep, "/"), "file_size": lora["size"], "modified": lora["modified"], + "tags": lora["tags"], + "modelDescription": lora["modelDescription"], "from_civitai": lora.get("from_civitai", True), "usage_tips": lora.get("usage_tips", ""), "notes": lora.get("notes", ""), @@ -335,6 +343,14 @@ class ApiRoutes: local_metadata['model_name'] = civitai_metadata['model'].get('name', local_metadata.get('model_name')) + # Fetch additional model metadata (description and tags) if we have model ID + model_id = civitai_metadata['modelId'] + if model_id: + model_metadata = await client.get_model_metadata(str(model_id)) + if model_metadata: + local_metadata['modelDescription'] = model_metadata.get('description', '') + local_metadata['tags'] = model_metadata.get('tags', []) + # Update base model local_metadata['base_model'] = civitai_metadata.get('baseModel') @@ -708,6 +724,7 @@ class ApiRoutes: # Check if we already have the description stored in metadata description = None + tags = [] if file_path: metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' if os.path.exists(metadata_path): @@ -715,38 +732,70 @@ class ApiRoutes: with open(metadata_path, 'r', encoding='utf-8') as f: metadata = json.load(f) description = metadata.get('modelDescription') + tags = metadata.get('tags', []) except Exception as e: logger.error(f"Error loading metadata from {metadata_path}: {e}") # If description is not in metadata, fetch from CivitAI if not description: - logger.info(f"Fetching model description for model ID: {model_id}") - description = await self.civitai_client.get_model_description(model_id) + logger.info(f"Fetching model metadata for model ID: {model_id}") + model_metadata = await self.civitai_client.get_model_metadata(model_id) - # Save the description to metadata if we have a file path and got a description - if file_path and description: - try: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - if os.path.exists(metadata_path): - with open(metadata_path, 'r', encoding='utf-8') as f: - metadata = json.load(f) - - metadata['modelDescription'] = description - - with open(metadata_path, 'w', encoding='utf-8') as f: - json.dump(metadata, f, indent=2, ensure_ascii=False) - logger.info(f"Saved model description to metadata for {file_path}") - except Exception as e: - logger.error(f"Error saving model description to metadata: {e}") + if model_metadata: + description = model_metadata.get('description') + tags = model_metadata.get('tags', []) + + # Save the metadata to file if we have a file path and got metadata + if file_path: + try: + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + if os.path.exists(metadata_path): + with open(metadata_path, 'r', encoding='utf-8') as f: + metadata = json.load(f) + + metadata['modelDescription'] = description + metadata['tags'] = tags + + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=2, ensure_ascii=False) + logger.info(f"Saved model metadata to file for {file_path}") + except Exception as e: + logger.error(f"Error saving model metadata: {e}") return web.json_response({ 'success': True, - 'description': description or "
No model description available.
" + 'description': description or "No model description available.
", + 'tags': tags }) except Exception as e: - logger.error(f"Error getting model description: {e}", exc_info=True) + logger.error(f"Error getting model metadata: {e}", exc_info=True) return web.json_response({ 'success': False, 'error': str(e) }, status=500) + + async def get_top_tags(self, request: web.Request) -> web.Response: + """Handle request for top tags sorted by frequency""" + try: + # Parse query parameters + limit = int(request.query.get('limit', '20')) + + # Validate limit + if limit < 1 or limit > 100: + limit = 20 # Default to a reasonable limit + + # Get top tags + top_tags = await self.scanner.get_top_tags(limit) + + return web.json_response({ + 'success': True, + 'tags': top_tags + }) + + except Exception as e: + logger.error(f"Error getting top tags: {str(e)}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': 'Internal server error' + }, status=500) diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 64c71385..5359aacc 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -30,6 +30,11 @@ class LoraRoutes: "folder": lora["folder"], "sha256": lora["sha256"], "file_path": lora["file_path"].replace(os.sep, "/"), + "size": lora["size"], + "tags": lora["tags"], + "modelDescription": lora["modelDescription"], + "usage_tips": lora["usage_tips"], + "notes": lora["notes"], "modified": lora["modified"], "from_civitai": lora.get("from_civitai", True), "civitai": self._filter_civitai_data(lora.get("civitai", {})) diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index a264d45b..389fc491 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -163,41 +163,52 @@ class CivitaiClient: logger.error(f"Error fetching model version info: {e}") return None - async def get_model_description(self, model_id: str) -> Optional[str]: - """Fetch the model description from Civitai API + async def get_model_metadata(self, model_id: str) -> Optional[Dict]: + """Fetch model metadata (description and tags) from Civitai API Args: model_id: The Civitai model ID Returns: - Optional[str]: The model description HTML or None if not found + Optional[Dict]: A dictionary containing model metadata or None if not found """ try: session = await self.session headers = self._get_request_headers() url = f"{self.base_url}/models/{model_id}" - logger.info(f"Fetching model description from {url}") + logger.info(f"Fetching model metadata from {url}") async with session.get(url, headers=headers) as response: if response.status != 200: - logger.warning(f"Failed to fetch model description: Status {response.status}") + logger.warning(f"Failed to fetch model metadata: Status {response.status}") return None data = await response.json() - description = data.get('description') - if description: - logger.info(f"Successfully retrieved description for model {model_id}") - return description + # Extract relevant metadata + metadata = { + "description": data.get("description", ""), + "tags": data.get("tags", []) + } + + if metadata["description"] or metadata["tags"]: + logger.info(f"Successfully retrieved metadata for model {model_id}") + return metadata else: - logger.warning(f"No description found for model {model_id}") + logger.warning(f"No metadata found for model {model_id}") return None except Exception as e: - logger.error(f"Error fetching model description: {e}", exc_info=True) + logger.error(f"Error fetching model metadata: {e}", exc_info=True) return None + # Keep old method for backward compatibility, delegating to the new one + async def get_model_description(self, model_id: str) -> Optional[str]: + """Fetch the model description from Civitai API (Legacy method)""" + metadata = await self.get_model_metadata(model_id) + return metadata.get("description") if metadata else None + async def close(self): """Close the session if it exists""" if self._session is not None: diff --git a/py/services/file_monitor.py b/py/services/file_monitor.py index 3bef2dd5..33b53448 100644 --- a/py/services/file_monitor.py +++ b/py/services/file_monitor.py @@ -98,6 +98,10 @@ class LoraFileHandler(FileSystemEventHandler): # Scan new file lora_data = await self.scanner.scan_single_lora(file_path) if lora_data: + # Update tags count + for tag in lora_data.get('tags', []): + self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1 + cache.raw_data.append(lora_data) new_folders.add(lora_data['folder']) # Update hash index @@ -109,6 +113,16 @@ class LoraFileHandler(FileSystemEventHandler): needs_resort = True elif action == 'remove': + # Find the lora to remove so we can update tags count + lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None) + if lora_to_remove: + # Update tags count by reducing counts + for tag in lora_to_remove.get('tags', []): + if tag in self.scanner._tags_count: + self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1) + if self.scanner._tags_count[tag] == 0: + del self.scanner._tags_count[tag] + # Remove from cache and hash index logger.info(f"Removing {file_path} from cache") self.scanner._hash_index.remove_by_path(file_path) diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index 857a4369..6777b577 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -34,6 +34,7 @@ class LoraScanner: self._initialization_task: Optional[asyncio.Task] = None self._initialized = True self.file_monitor = None # Add this line + self._tags_count = {} # Add a dictionary to store tag counts def set_file_monitor(self, monitor): """Set file monitor instance""" @@ -90,13 +91,21 @@ class LoraScanner: # Clear existing hash index self._hash_index.clear() + # Clear existing tags count + self._tags_count = {} + # Scan for new data raw_data = await self.scan_all_loras() - # Build hash index + # Build hash index and tags count for lora_data in raw_data: if 'sha256' in lora_data and 'file_path' in lora_data: self._hash_index.add_entry(lora_data['sha256'], lora_data['file_path']) + + # Count tags + if 'tags' in lora_data and lora_data['tags']: + for tag in lora_data['tags']: + self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 # Update cache self._cache = LoraCache( @@ -158,7 +167,7 @@ class LoraScanner: async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', folder: str = None, search: str = None, fuzzy: bool = False, - recursive: bool = False, base_models: list = None): + recursive: bool = False, base_models: list = None, tags: list = None) -> Dict: """Get paginated and filtered lora data Args: @@ -170,6 +179,7 @@ class LoraScanner: fuzzy: Use fuzzy matching for search recursive: Include subfolders when folder filter is applied base_models: List of base models to filter by + tags: List of tags to filter by """ cache = await self.get_cached_data() @@ -198,6 +208,13 @@ class LoraScanner: if item.get('base_model') in base_models ] + # Apply tag filtering + if tags and len(tags) > 0: + filtered_data = [ + item for item in filtered_data + if any(tag in item.get('tags', []) for tag in tags) + ] + # 应用搜索过滤 if search: if fuzzy: @@ -311,12 +328,67 @@ class LoraScanner: # Convert to dict and add folder info lora_data = metadata.to_dict() + # Try to fetch missing metadata from Civitai if needed + await self._fetch_missing_metadata(file_path, lora_data) rel_path = os.path.relpath(file_path, root_path) folder = os.path.dirname(rel_path) lora_data['folder'] = folder.replace(os.path.sep, '/') return lora_data + async def _fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None: + """Fetch missing description and tags from Civitai if needed + + Args: + file_path: Path to the lora file + lora_data: Lora metadata dictionary to update + """ + try: + # Check if we need to fetch additional metadata from Civitai + needs_metadata_update = False + model_id = None + + # Check if we have Civitai model ID but missing metadata + if lora_data.get('civitai'): + # Try to get model ID directly from the correct location + model_id = lora_data['civitai'].get('modelId') + + if model_id: + model_id = str(model_id) + # Check if tags are missing or empty + tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0 + + # Check if description is missing or empty + desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "") + + needs_metadata_update = tags_missing or desc_missing + + # Fetch missing metadata if needed + if needs_metadata_update and model_id: + logger.info(f"Fetching missing metadata for {file_path} with model ID {model_id}") + from ..services.civitai_client import CivitaiClient + client = CivitaiClient() + model_metadata = await client.get_model_metadata(model_id) + await client.close() + + if model_metadata: + logger.info(f"Updating metadata for {file_path} with model ID {model_id}") + + # Update tags if they were missing + if model_metadata.get('tags') and (not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0): + lora_data['tags'] = model_metadata['tags'] + + # Update description if it was missing + if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")): + lora_data['modelDescription'] = model_metadata['description'] + + # Save the updated metadata back to file + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(lora_data, f, indent=2, ensure_ascii=False) + except Exception as e: + logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}") + async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool: """Update preview URL in cache for a specific lora @@ -427,6 +499,15 @@ class LoraScanner: async def update_single_lora_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool: cache = await self.get_cached_data() + # Find the existing item to remove its tags from count + existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None) + if existing_item and 'tags' in existing_item: + for tag in existing_item.get('tags', []): + if tag in self._tags_count: + self._tags_count[tag] = max(0, self._tags_count[tag] - 1) + if self._tags_count[tag] == 0: + del self._tags_count[tag] + # Remove old path from hash index if exists self._hash_index.remove_by_path(original_path) @@ -460,6 +541,11 @@ class LoraScanner: # Update folders list all_folders = set(item['folder'] for item in cache.raw_data) cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) + + # Update tags count with the new/updated tags + if 'tags' in metadata: + for tag in metadata.get('tags', []): + self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 # Resort cache await cache.resort() @@ -505,3 +591,26 @@ class LoraScanner: """Get hash for a LoRA by its file path""" return self._hash_index.get_hash(file_path) + # Add new method to get top tags + async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: + """Get top tags sorted by count + + Args: + limit: Maximum number of tags to return + + Returns: + List of dictionaries with tag name and count, sorted by count + """ + # Make sure cache is initialized + await self.get_cached_data() + + # Sort tags by count in descending order + sorted_tags = sorted( + [{"tag": tag, "count": count} for tag, count in self._tags_count.items()], + key=lambda x: x['count'], + reverse=True + ) + + # Return limited number + return sorted_tags[:limit] + diff --git a/py/utils/file_utils.py b/py/utils/file_utils.py index 2aaa1ed2..f9758e12 100644 --- a/py/utils/file_utils.py +++ b/py/utils/file_utils.py @@ -69,6 +69,8 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]: notes="", from_civitai=True, preview_url=normalize_path(preview_url), + tags=[], + modelDescription="" ) # create metadata file diff --git a/py/utils/models.py b/py/utils/models.py index 0a39e065..49568fee 100644 --- a/py/utils/models.py +++ b/py/utils/models.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, asdict -from typing import Dict, Optional +from typing import Dict, Optional, List from datetime import datetime import os from .model_utils import determine_base_model @@ -17,8 +17,15 @@ class LoraMetadata: preview_url: str # Preview image URL usage_tips: str = "{}" # Usage tips for the model, json string notes: str = "" # Additional notes - from_civitai: bool = True # Whether the lora is from Civitai + from_civitai: bool = True # Whether the lora is from Civitai civitai: Optional[Dict] = None # Civitai API data if available + tags: List[str] = None # Model tags + modelDescription: str = "" # Full model description + + def __post_init__(self): + # Initialize empty lists to avoid mutable default parameter issue + if self.tags is None: + self.tags = [] @classmethod def from_dict(cls, data: Dict) -> 'LoraMetadata': diff --git a/static/css/components/lora-modal.css b/static/css/components/lora-modal.css index 8cb990fb..3e7bebc6 100644 --- a/static/css/components/lora-modal.css +++ b/static/css/components/lora-modal.css @@ -699,4 +699,43 @@ [data-theme="dark"] .model-description-content pre, [data-theme="dark"] .model-description-content code { background: rgba(255, 255, 255, 0.05); +} + +/* Model Tags styles */ +.model-tags { + display: flex; + flex-wrap: wrap; + gap: 6px; + margin-top: 8px; + margin-bottom: 4px; +} + +.model-tag { + background: var(--lora-surface); + border: 1px solid var(--lora-border); + border-radius: var(--border-radius-xs); + padding: 3px 8px; + font-size: 0.8em; + color: var(--lora-accent); + cursor: pointer; + display: inline-flex; + align-items: center; + gap: 5px; + transition: all 0.2s; +} + +.model-tag i { + font-size: 0.85em; + opacity: 0.6; + color: var(--text-color); +} + +.model-tag:hover { + background: oklch(var(--lora-accent) / 0.1); + border-color: var(--lora-accent); +} + +.model-tag:hover i { + opacity: 1; + color: var(--lora-accent); } \ No newline at end of file diff --git a/static/css/components/search-filter.css b/static/css/components/search-filter.css index 50f7fc40..5b0c440e 100644 --- a/static/css/components/search-filter.css +++ b/static/css/components/search-filter.css @@ -237,6 +237,44 @@ border-color: var(--lora-accent); } +/* Tag filter styles */ +.tag-filter { + display: flex; + align-items: center; + justify-content: space-between; + min-width: 60px; +} + +.tag-count { + background: rgba(0, 0, 0, 0.1); + padding: 1px 6px; + border-radius: 10px; + font-size: 0.8em; + margin-left: 4px; +} + +[data-theme="dark"] .tag-count { + background: rgba(255, 255, 255, 0.1); +} + +.tag-filter.active .tag-count { + background: rgba(255, 255, 255, 0.3); + color: white; +} + +.tags-loading, .tags-error, .no-tags { + width: 100%; + padding: 8px; + text-align: center; + font-size: 0.9em; + color: var(--text-color); + opacity: 0.7; +} + +.tags-error { + color: var(--lora-error); +} + /* Filter actions */ .filter-actions { display: flex; @@ -276,4 +314,4 @@ right: 20px; top: 140px; } -} \ No newline at end of file +} \ No newline at end of file diff --git a/static/js/api/loraApi.js b/static/js/api/loraApi.js index 132923f1..fe961412 100644 --- a/static/js/api/loraApi.js +++ b/static/js/api/loraApi.js @@ -32,9 +32,15 @@ export async function loadMoreLoras(boolUpdateFolders = false) { } // Add filter parameters if active - if (state.filters && state.filters.baseModel && state.filters.baseModel.length > 0) { - // Convert the array of base models to a comma-separated string - params.append('base_models', state.filters.baseModel.join(',')); + if (state.filters) { + if (state.filters.tags && state.filters.tags.length > 0) { + // Convert the array of tags to a comma-separated string + params.append('tags', state.filters.tags.join(',')); + } + if (state.filters.baseModel && state.filters.baseModel.length > 0) { + // Convert the array of base models to a comma-separated string + params.append('base_models', state.filters.baseModel.join(',')); + } } console.log('Loading loras with params:', params.toString()); diff --git a/static/js/components/LoraCard.js b/static/js/components/LoraCard.js index 99985f4d..35de22ea 100644 --- a/static/js/components/LoraCard.js +++ b/static/js/components/LoraCard.js @@ -18,6 +18,14 @@ export function createLoraCard(lora) { card.dataset.usage_tips = lora.usage_tips; card.dataset.notes = lora.notes; card.dataset.meta = JSON.stringify(lora.civitai || {}); + + // Store tags and model description + if (lora.tags && Array.isArray(lora.tags)) { + card.dataset.tags = JSON.stringify(lora.tags); + } + if (lora.modelDescription) { + card.dataset.modelDescription = lora.modelDescription; + } // Apply selection state if in bulk mode and this card is in the selected set if (state.bulkMode && state.selectedLoras.has(lora.file_path)) { @@ -86,7 +94,9 @@ export function createLoraCard(lora) { base_model: card.dataset.base_model, usage_tips: card.dataset.usage_tips, notes: card.dataset.notes, - civitai: JSON.parse(card.dataset.meta || '{}') + civitai: JSON.parse(card.dataset.meta || '{}'), + tags: JSON.parse(card.dataset.tags || '[]'), + modelDescription: card.dataset.modelDescription || '' }; showLoraModal(loraMeta); } diff --git a/static/js/components/LoraModal.js b/static/js/components/LoraModal.js index ac55487d..4f4b3c3d 100644 --- a/static/js/components/LoraModal.js +++ b/static/js/components/LoraModal.js @@ -15,6 +15,7 @@ export function showLoraModal(lora) { + ${renderTags(lora.tags || [])}