Merge branch 'main' into dev

This commit is contained in:
Will Miao
2025-03-13 11:45:43 +08:00
48 changed files with 3592 additions and 269 deletions

View File

@@ -163,6 +163,53 @@ class CivitaiClient:
logger.error(f"Error fetching model version info: {e}")
return None
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata (description and tags) from Civitai API
Args:
model_id: The Civitai model ID
Returns:
Tuple[Optional[Dict], int]: A tuple containing:
- A dictionary with model metadata or None if not found
- The HTTP status code from the request
"""
try:
session = await self.session
headers = self._get_request_headers()
url = f"{self.base_url}/models/{model_id}"
async with session.get(url, headers=headers) as response:
status_code = response.status
if status_code != 200:
logger.warning(f"Failed to fetch model metadata: Status {status_code}")
return None, status_code
data = await response.json()
# Extract relevant metadata
metadata = {
"description": data.get("description") or "No model description available",
"tags": data.get("tags", [])
}
if metadata["description"] or metadata["tags"]:
return metadata, status_code
else:
logger.warning(f"No metadata found for model {model_id}")
return None, status_code
except Exception as e:
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
return None, 0
# Keep old method for backward compatibility, delegating to the new one
async def get_model_description(self, model_id: str) -> Optional[str]:
"""Fetch the model description from Civitai API (Legacy method)"""
metadata, _ = await self.get_model_metadata(model_id)
return metadata.get("description") if metadata else None
async def close(self):
"""Close the session if it exists"""
if self._session is not None:

View File

@@ -51,6 +51,16 @@ class DownloadManager:
# 5. 准备元数据
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
# 5.1 获取并更新模型标签和描述信息
model_id = version_info.get('modelId')
if model_id:
model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id))
if model_metadata:
if model_metadata.get("tags"):
metadata.tags = model_metadata.get("tags", [])
if model_metadata.get("description"):
metadata.modelDescription = model_metadata.get("description", "")
# 6. 开始下载流程
result = await self._execute_download(
download_url=download_url,
@@ -86,6 +96,7 @@ class DownloadManager:
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)

View File

@@ -98,6 +98,10 @@ class LoraFileHandler(FileSystemEventHandler):
# Scan new file
lora_data = await self.scanner.scan_single_lora(file_path)
if lora_data:
# Update tags count
for tag in lora_data.get('tags', []):
self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(lora_data)
new_folders.add(lora_data['folder'])
# Update hash index
@@ -109,6 +113,16 @@ class LoraFileHandler(FileSystemEventHandler):
needs_resort = True
elif action == 'remove':
# Find the lora to remove so we can update tags count
lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if lora_to_remove:
# Update tags count by reducing counts
for tag in lora_to_remove.get('tags', []):
if tag in self.scanner._tags_count:
self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1)
if self.scanner._tags_count[tag] == 0:
del self.scanner._tags_count[tag]
# Remove from cache and hash index
logger.info(f"Removing {file_path} from cache")
self.scanner._hash_index.remove_by_path(file_path)

View File

@@ -11,6 +11,8 @@ from ..utils.file_utils import load_metadata, get_file_info
from .lora_cache import LoraCache
from difflib import SequenceMatcher
from .lora_hash_index import LoraHashIndex
from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS
import sys
logger = logging.getLogger(__name__)
@@ -35,6 +37,7 @@ class LoraScanner:
self._initialization_task: Optional[asyncio.Task] = None
self._initialized = True
self.file_monitor = None # Add this line
self._tags_count = {} # Add a dictionary to store tag counts
def set_file_monitor(self, monitor):
"""Set file monitor instance"""
@@ -91,13 +94,21 @@ class LoraScanner:
# Clear existing hash index
self._hash_index.clear()
# Clear existing tags count
self._tags_count = {}
# Scan for new data
raw_data = await self.scan_all_loras()
# Build hash index
# Build hash index and tags count
for lora_data in raw_data:
if 'sha256' in lora_data and 'file_path' in lora_data:
self._hash_index.add_entry(lora_data['sha256'], lora_data['file_path'])
# Count tags
if 'tags' in lora_data and lora_data['tags']:
for tag in lora_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Update cache
self._cache = LoraCache(
@@ -159,7 +170,8 @@ class LoraScanner:
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy: bool = False,
recursive: bool = False, base_models: list = None):
recursive: bool = False, base_models: list = None, tags: list = None,
search_options: dict = None) -> Dict:
"""Get paginated and filtered lora data
Args:
@@ -171,22 +183,39 @@ class LoraScanner:
fuzzy: Use fuzzy matching for search
recursive: Include subfolders when folder filter is applied
base_models: List of base models to filter by
tags: List of tags to filter by
search_options: Dictionary with search options (filename, modelname, tags)
"""
cache = await self.get_cached_data()
# 先获取基础数据集
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False
}
# Get the base data set
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# 应用文件夹过滤
# Apply SFW filtering if enabled
if settings.get('show_only_sfw', False):
filtered_data = [
item for item in filtered_data
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply folder filtering
if folder is not None:
if recursive:
# 递归模式:匹配所有以该文件夹开头的路径
# Recursive mode: match all paths starting with this folder
filtered_data = [
item for item in filtered_data
if item['folder'].startswith(folder + '/') or item['folder'] == folder
]
else:
# 非递归模式:只匹配确切的文件夹
# Non-recursive mode: match exact folder
filtered_data = [
item for item in filtered_data
if item['folder'] == folder
@@ -199,28 +228,27 @@ class LoraScanner:
if item.get('base_model') in base_models
]
# 应用搜索过滤
# Apply tag filtering
if tags and len(tags) > 0:
filtered_data = [
item for item in filtered_data
if any(tag in item.get('tags', []) for tag in tags)
]
# Apply search filtering
if search:
if fuzzy:
filtered_data = [
item for item in filtered_data
if any(
self.fuzzy_match(str(value), search)
for value in [
item.get('model_name', ''),
item.get('base_model', '')
]
if value
)
if self._fuzzy_search_match(item, search, search_options)
]
else:
# Original exact search logic
filtered_data = [
item for item in filtered_data
if search in str(item.get('model_name', '')).lower()
if self._exact_search_match(item, search, search_options)
]
# 计算分页
# Calculate pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
@@ -235,6 +263,44 @@ class LoraScanner:
return result
def _fuzzy_search_match(self, item: Dict, search: str, search_options: Dict) -> bool:
"""Check if an item matches the search term using fuzzy matching with search options"""
# Check filename if enabled
if search_options.get('filename', True) and self.fuzzy_match(item.get('file_name', ''), search):
return True
# Check model name if enabled
if search_options.get('modelname', True) and self.fuzzy_match(item.get('model_name', ''), search):
return True
# Check tags if enabled
if search_options.get('tags', False) and item.get('tags'):
for tag in item['tags']:
if self.fuzzy_match(tag, search):
return True
return False
def _exact_search_match(self, item: Dict, search: str, search_options: Dict) -> bool:
"""Check if an item matches the search term using exact matching with search options"""
search = search.lower()
# Check filename if enabled
if search_options.get('filename', True) and search in item.get('file_name', '').lower():
return True
# Check model name if enabled
if search_options.get('modelname', True) and search in item.get('model_name', '').lower():
return True
# Check tags if enabled
if search_options.get('tags', False) and item.get('tags'):
for tag in item['tags']:
if search in tag.lower():
return True
return False
def invalidate_cache(self):
"""Invalidate the current cache"""
self._cache = None
@@ -312,12 +378,86 @@ class LoraScanner:
# Convert to dict and add folder info
lora_data = metadata.to_dict()
# Try to fetch missing metadata from Civitai if needed
await self._fetch_missing_metadata(file_path, lora_data)
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
lora_data['folder'] = folder.replace(os.path.sep, '/')
return lora_data
async def _fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None:
"""Fetch missing description and tags from Civitai if needed
Args:
file_path: Path to the lora file
lora_data: Lora metadata dictionary to update
"""
try:
# Skip if already marked as deleted on Civitai
if lora_data.get('civitai_deleted', False):
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
return
# Check if we need to fetch additional metadata from Civitai
needs_metadata_update = False
model_id = None
# Check if we have Civitai model ID but missing metadata
if lora_data.get('civitai'):
# Try to get model ID directly from the correct location
model_id = lora_data['civitai'].get('modelId')
if model_id:
model_id = str(model_id)
# Check if tags are missing or empty
tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0
# Check if description is missing or empty
desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")
needs_metadata_update = tags_missing or desc_missing
# Fetch missing metadata if needed
if needs_metadata_update and model_id:
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
from ..services.civitai_client import CivitaiClient
client = CivitaiClient()
# Get metadata and status code
model_metadata, status_code = await client.get_model_metadata(model_id)
await client.close()
# Handle 404 status (model deleted from Civitai)
if status_code == 404:
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
# Mark as deleted to avoid future API calls
lora_data['civitai_deleted'] = True
# Save the updated metadata back to file
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(lora_data, f, indent=2, ensure_ascii=False)
# Process valid metadata if available
elif model_metadata:
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
# Update tags if they were missing
if model_metadata.get('tags') and (not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0):
lora_data['tags'] = model_metadata['tags']
# Update description if it was missing
if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")):
lora_data['modelDescription'] = model_metadata['description']
# Save the updated metadata back to file
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(lora_data, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
"""Update preview URL in cache for a specific lora
@@ -428,6 +568,15 @@ class LoraScanner:
async def update_single_lora_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
cache = await self.get_cached_data()
# Find the existing item to remove its tags from count
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []):
if tag in self._tags_count:
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
if self._tags_count[tag] == 0:
del self._tags_count[tag]
# Remove old path from hash index if exists
self._hash_index.remove_by_path(original_path)
@@ -461,6 +610,11 @@ class LoraScanner:
# Update folders list
all_folders = set(item['folder'] for item in cache.raw_data)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update tags count with the new/updated tags
if 'tags' in metadata:
for tag in metadata.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Resort cache
await cache.resort()
@@ -506,6 +660,29 @@ class LoraScanner:
"""Get hash for a LoRA by its file path"""
return self._hash_index.get_hash(file_path)
# Add new method to get top tags
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get top tags sorted by count
Args:
limit: Maximum number of tags to return
Returns:
List of dictionaries with tag name and count, sorted by count
"""
# Make sure cache is initialized
await self.get_cached_data()
# Sort tags by count in descending order
sorted_tags = sorted(
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
key=lambda x: x['count'],
reverse=True
)
# Return limited number
return sorted_tags[:limit]
async def diagnose_hash_index(self):
"""Diagnostic method to verify hash index functionality"""
print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr)

View File

@@ -37,7 +37,8 @@ class SettingsManager:
def _get_default_settings(self) -> Dict[str, Any]:
"""Return default settings"""
return {
"civitai_api_key": ""
"civitai_api_key": "",
"show_only_sfw": False
}
def get(self, key: str, default: Any = None) -> Any: