Add tag filtering checkpoint

This commit is contained in:
Will Miao
2025-03-10 13:18:56 +08:00
parent 0069f84630
commit 721bef3ff8
15 changed files with 482 additions and 50 deletions

View File

@@ -45,6 +45,7 @@ class ApiRoutes:
app.router.add_post('/loras/api/save-metadata', routes.save_metadata)
app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route
app.router.add_post('/api/move_models_bulk', routes.move_models_bulk)
app.router.add_get('/api/top-tags', routes.get_top_tags) # Add new route for top tags
# Add update check routes
UpdateRoutes.setup_routes(app)
@@ -142,6 +143,10 @@ class ApiRoutes:
'error': 'Invalid sort parameter'
}, status=400)
# Parse tags filter parameter
tags = request.query.get('tags', '').split(',')
tags = [tag.strip() for tag in tags if tag.strip()]
# Get paginated data with search and filters
result = await self.scanner.get_paginated_data(
page=page,
@@ -151,7 +156,8 @@ class ApiRoutes:
search=search,
fuzzy=fuzzy,
recursive=recursive,
base_models=base_models # Pass base models filter
base_models=base_models, # Pass base models filter
tags=tags # Add tags parameter
)
# Format the response data
@@ -190,6 +196,8 @@ class ApiRoutes:
"file_path": lora["file_path"].replace(os.sep, "/"),
"file_size": lora["size"],
"modified": lora["modified"],
"tags": lora["tags"],
"modelDescription": lora["modelDescription"],
"from_civitai": lora.get("from_civitai", True),
"usage_tips": lora.get("usage_tips", ""),
"notes": lora.get("notes", ""),
@@ -335,6 +343,14 @@ class ApiRoutes:
local_metadata['model_name'] = civitai_metadata['model'].get('name',
local_metadata.get('model_name'))
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Update base model
local_metadata['base_model'] = civitai_metadata.get('baseModel')
@@ -708,6 +724,7 @@ class ApiRoutes:
# Check if we already have the description stored in metadata
description = None
tags = []
if file_path:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
@@ -715,38 +732,70 @@ class ApiRoutes:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
description = metadata.get('modelDescription')
tags = metadata.get('tags', [])
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
# If description is not in metadata, fetch from CivitAI
if not description:
logger.info(f"Fetching model description for model ID: {model_id}")
description = await self.civitai_client.get_model_description(model_id)
logger.info(f"Fetching model metadata for model ID: {model_id}")
model_metadata = await self.civitai_client.get_model_metadata(model_id)
# Save the description to metadata if we have a file path and got a description
if file_path and description:
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
metadata['modelDescription'] = description
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
logger.info(f"Saved model description to metadata for {file_path}")
except Exception as e:
logger.error(f"Error saving model description to metadata: {e}")
if model_metadata:
description = model_metadata.get('description')
tags = model_metadata.get('tags', [])
# Save the metadata to file if we have a file path and got metadata
if file_path:
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
metadata['modelDescription'] = description
metadata['tags'] = tags
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
logger.info(f"Saved model metadata to file for {file_path}")
except Exception as e:
logger.error(f"Error saving model metadata: {e}")
return web.json_response({
'success': True,
'description': description or "<p>No model description available.</p>"
'description': description or "<p>No model description available.</p>",
'tags': tags
})
except Exception as e:
logger.error(f"Error getting model description: {e}", exc_info=True)
logger.error(f"Error getting model metadata: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
# Parse query parameters
limit = int(request.query.get('limit', '20'))
# Validate limit
if limit < 1 or limit > 100:
limit = 20 # Default to a reasonable limit
# Get top tags
top_tags = await self.scanner.get_top_tags(limit)
return web.json_response({
'success': True,
'tags': top_tags
})
except Exception as e:
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
return web.json_response({
'success': False,
'error': 'Internal server error'
}, status=500)

View File

@@ -30,6 +30,11 @@ class LoraRoutes:
"folder": lora["folder"],
"sha256": lora["sha256"],
"file_path": lora["file_path"].replace(os.sep, "/"),
"size": lora["size"],
"tags": lora["tags"],
"modelDescription": lora["modelDescription"],
"usage_tips": lora["usage_tips"],
"notes": lora["notes"],
"modified": lora["modified"],
"from_civitai": lora.get("from_civitai", True),
"civitai": self._filter_civitai_data(lora.get("civitai", {}))

View File

@@ -163,41 +163,52 @@ class CivitaiClient:
logger.error(f"Error fetching model version info: {e}")
return None
async def get_model_description(self, model_id: str) -> Optional[str]:
"""Fetch the model description from Civitai API
async def get_model_metadata(self, model_id: str) -> Optional[Dict]:
"""Fetch model metadata (description and tags) from Civitai API
Args:
model_id: The Civitai model ID
Returns:
Optional[str]: The model description HTML or None if not found
Optional[Dict]: A dictionary containing model metadata or None if not found
"""
try:
session = await self.session
headers = self._get_request_headers()
url = f"{self.base_url}/models/{model_id}"
logger.info(f"Fetching model description from {url}")
logger.info(f"Fetching model metadata from {url}")
async with session.get(url, headers=headers) as response:
if response.status != 200:
logger.warning(f"Failed to fetch model description: Status {response.status}")
logger.warning(f"Failed to fetch model metadata: Status {response.status}")
return None
data = await response.json()
description = data.get('description')
if description:
logger.info(f"Successfully retrieved description for model {model_id}")
return description
# Extract relevant metadata
metadata = {
"description": data.get("description", ""),
"tags": data.get("tags", [])
}
if metadata["description"] or metadata["tags"]:
logger.info(f"Successfully retrieved metadata for model {model_id}")
return metadata
else:
logger.warning(f"No description found for model {model_id}")
logger.warning(f"No metadata found for model {model_id}")
return None
except Exception as e:
logger.error(f"Error fetching model description: {e}", exc_info=True)
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
return None
# Keep old method for backward compatibility, delegating to the new one
async def get_model_description(self, model_id: str) -> Optional[str]:
"""Fetch the model description from Civitai API (Legacy method)"""
metadata = await self.get_model_metadata(model_id)
return metadata.get("description") if metadata else None
async def close(self):
"""Close the session if it exists"""
if self._session is not None:

View File

@@ -98,6 +98,10 @@ class LoraFileHandler(FileSystemEventHandler):
# Scan new file
lora_data = await self.scanner.scan_single_lora(file_path)
if lora_data:
# Update tags count
for tag in lora_data.get('tags', []):
self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(lora_data)
new_folders.add(lora_data['folder'])
# Update hash index
@@ -109,6 +113,16 @@ class LoraFileHandler(FileSystemEventHandler):
needs_resort = True
elif action == 'remove':
# Find the lora to remove so we can update tags count
lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if lora_to_remove:
# Update tags count by reducing counts
for tag in lora_to_remove.get('tags', []):
if tag in self.scanner._tags_count:
self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1)
if self.scanner._tags_count[tag] == 0:
del self.scanner._tags_count[tag]
# Remove from cache and hash index
logger.info(f"Removing {file_path} from cache")
self.scanner._hash_index.remove_by_path(file_path)

View File

@@ -34,6 +34,7 @@ class LoraScanner:
self._initialization_task: Optional[asyncio.Task] = None
self._initialized = True
self.file_monitor = None # Add this line
self._tags_count = {} # Add a dictionary to store tag counts
def set_file_monitor(self, monitor):
"""Set file monitor instance"""
@@ -90,13 +91,21 @@ class LoraScanner:
# Clear existing hash index
self._hash_index.clear()
# Clear existing tags count
self._tags_count = {}
# Scan for new data
raw_data = await self.scan_all_loras()
# Build hash index
# Build hash index and tags count
for lora_data in raw_data:
if 'sha256' in lora_data and 'file_path' in lora_data:
self._hash_index.add_entry(lora_data['sha256'], lora_data['file_path'])
# Count tags
if 'tags' in lora_data and lora_data['tags']:
for tag in lora_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Update cache
self._cache = LoraCache(
@@ -158,7 +167,7 @@ class LoraScanner:
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy: bool = False,
recursive: bool = False, base_models: list = None):
recursive: bool = False, base_models: list = None, tags: list = None) -> Dict:
"""Get paginated and filtered lora data
Args:
@@ -170,6 +179,7 @@ class LoraScanner:
fuzzy: Use fuzzy matching for search
recursive: Include subfolders when folder filter is applied
base_models: List of base models to filter by
tags: List of tags to filter by
"""
cache = await self.get_cached_data()
@@ -198,6 +208,13 @@ class LoraScanner:
if item.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
filtered_data = [
item for item in filtered_data
if any(tag in item.get('tags', []) for tag in tags)
]
# 应用搜索过滤
if search:
if fuzzy:
@@ -311,12 +328,67 @@ class LoraScanner:
# Convert to dict and add folder info
lora_data = metadata.to_dict()
# Try to fetch missing metadata from Civitai if needed
await self._fetch_missing_metadata(file_path, lora_data)
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
lora_data['folder'] = folder.replace(os.path.sep, '/')
return lora_data
async def _fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None:
"""Fetch missing description and tags from Civitai if needed
Args:
file_path: Path to the lora file
lora_data: Lora metadata dictionary to update
"""
try:
# Check if we need to fetch additional metadata from Civitai
needs_metadata_update = False
model_id = None
# Check if we have Civitai model ID but missing metadata
if lora_data.get('civitai'):
# Try to get model ID directly from the correct location
model_id = lora_data['civitai'].get('modelId')
if model_id:
model_id = str(model_id)
# Check if tags are missing or empty
tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0
# Check if description is missing or empty
desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")
needs_metadata_update = tags_missing or desc_missing
# Fetch missing metadata if needed
if needs_metadata_update and model_id:
logger.info(f"Fetching missing metadata for {file_path} with model ID {model_id}")
from ..services.civitai_client import CivitaiClient
client = CivitaiClient()
model_metadata = await client.get_model_metadata(model_id)
await client.close()
if model_metadata:
logger.info(f"Updating metadata for {file_path} with model ID {model_id}")
# Update tags if they were missing
if model_metadata.get('tags') and (not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0):
lora_data['tags'] = model_metadata['tags']
# Update description if it was missing
if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")):
lora_data['modelDescription'] = model_metadata['description']
# Save the updated metadata back to file
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(lora_data, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
"""Update preview URL in cache for a specific lora
@@ -427,6 +499,15 @@ class LoraScanner:
async def update_single_lora_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
cache = await self.get_cached_data()
# Find the existing item to remove its tags from count
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []):
if tag in self._tags_count:
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
if self._tags_count[tag] == 0:
del self._tags_count[tag]
# Remove old path from hash index if exists
self._hash_index.remove_by_path(original_path)
@@ -460,6 +541,11 @@ class LoraScanner:
# Update folders list
all_folders = set(item['folder'] for item in cache.raw_data)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update tags count with the new/updated tags
if 'tags' in metadata:
for tag in metadata.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Resort cache
await cache.resort()
@@ -505,3 +591,26 @@ class LoraScanner:
"""Get hash for a LoRA by its file path"""
return self._hash_index.get_hash(file_path)
# Add new method to get top tags
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get top tags sorted by count
Args:
limit: Maximum number of tags to return
Returns:
List of dictionaries with tag name and count, sorted by count
"""
# Make sure cache is initialized
await self.get_cached_data()
# Sort tags by count in descending order
sorted_tags = sorted(
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
key=lambda x: x['count'],
reverse=True
)
# Return limited number
return sorted_tags[:limit]

View File

@@ -69,6 +69,8 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
notes="",
from_civitai=True,
preview_url=normalize_path(preview_url),
tags=[],
modelDescription=""
)
# create metadata file

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, asdict
from typing import Dict, Optional
from typing import Dict, Optional, List
from datetime import datetime
import os
from .model_utils import determine_base_model
@@ -17,8 +17,15 @@ class LoraMetadata:
preview_url: str # Preview image URL
usage_tips: str = "{}" # Usage tips for the model, json string
notes: str = "" # Additional notes
from_civitai: bool = True # Whether the lora is from Civitai
from_civitai: bool = True # Whether the lora is from Civitai
civitai: Optional[Dict] = None # Civitai API data if available
tags: List[str] = None # Model tags
modelDescription: str = "" # Full model description
def __post_init__(self):
# Initialize empty lists to avoid mutable default parameter issue
if self.tags is None:
self.tags = []
@classmethod
def from_dict(cls, data: Dict) -> 'LoraMetadata':