mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
Merge branch 'main' into dev
This commit is contained in:
@@ -8,7 +8,7 @@ from .utils import FlexibleOptionalInputType, any_type
|
||||
|
||||
class LoraManagerLoader:
|
||||
NAME = "Lora Loader (LoraManager)"
|
||||
CATEGORY = "loaders"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
@@ -49,11 +49,32 @@ class LoraManagerLoader:
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(self, lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
|
||||
def load_loras(self, model, clip, text, **kwargs):
|
||||
"""Loads multiple LoRAs based on the kwargs input."""
|
||||
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
lora_stack = kwargs.get('lora_stack', None)
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Apply the LoRA using the provided path and strengths
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Extract lora name for trigger words lookup
|
||||
lora_name = self.extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Then process loras from kwargs
|
||||
if 'loras' in kwargs:
|
||||
for lora in kwargs['loras']:
|
||||
if not lora.get('active', False):
|
||||
@@ -72,6 +93,7 @@ class LoraManagerLoader:
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
trigger_words_text = ", ".join(all_trigger_words) if all_trigger_words else ""
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
return (model, clip, trigger_words_text)
|
||||
91
py/nodes/lora_stacker.py
Normal file
91
py/nodes/lora_stacker.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import asyncio
|
||||
import os
|
||||
from .utils import FlexibleOptionalInputType, any_type
|
||||
|
||||
class LoraStacker:
|
||||
NAME = "Lora Stacker (LoraManager)"
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
}),
|
||||
},
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("LORA_STACK", IO.STRING)
|
||||
RETURN_NAMES = ("LORA_STACK", "trigger_words")
|
||||
FUNCTION = "stack_loras"
|
||||
|
||||
async def get_lora_info(self, lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(self, lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
|
||||
def stack_loras(self, text, **kwargs):
|
||||
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
|
||||
stack = []
|
||||
all_trigger_words = []
|
||||
|
||||
# Process existing lora_stack if available
|
||||
lora_stack = kwargs.get('lora_stack', None)
|
||||
if lora_stack:
|
||||
stack.extend(lora_stack)
|
||||
# Get trigger words from existing stack entries
|
||||
for lora_path, _, _ in lora_stack:
|
||||
lora_name = self.extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
if 'loras' in kwargs:
|
||||
for lora in kwargs['loras']:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
clip_strength = model_strength # Using same strength for both as in the original loader
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
|
||||
# Add to stack without loading
|
||||
stack.append((lora_path, model_strength, clip_strength))
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
return (stack, trigger_words_text)
|
||||
@@ -1,17 +1,18 @@
|
||||
import json
|
||||
import re
|
||||
from server import PromptServer # type: ignore
|
||||
from .utils import FlexibleOptionalInputType, any_type
|
||||
|
||||
class TriggerWordToggle:
|
||||
NAME = "TriggerWord Toggle (LoraManager)"
|
||||
CATEGORY = "lora manager"
|
||||
CATEGORY = "Lora Manager/utils"
|
||||
DESCRIPTION = "Toggle trigger words on/off"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"trigger_words": ("STRING", {"defaultInput": True, "forceInput": True}),
|
||||
"group_mode": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
"hidden": {
|
||||
@@ -23,7 +24,8 @@ class TriggerWordToggle:
|
||||
RETURN_NAMES = ("filtered_trigger_words",)
|
||||
FUNCTION = "process_trigger_words"
|
||||
|
||||
def process_trigger_words(self, trigger_words, id, **kwargs):
|
||||
def process_trigger_words(self, id, group_mode, **kwargs):
|
||||
trigger_words = kwargs.get("trigger_words", "")
|
||||
# Send trigger words to frontend
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": id,
|
||||
@@ -41,20 +43,33 @@ class TriggerWordToggle:
|
||||
if isinstance(trigger_data, str):
|
||||
trigger_data = json.loads(trigger_data)
|
||||
|
||||
# Create dictionaries to track active state of words
|
||||
# Create dictionaries to track active state of words or groups
|
||||
active_state = {item['text']: item.get('active', False) for item in trigger_data}
|
||||
|
||||
# Split original trigger words
|
||||
original_words = [word.strip() for word in trigger_words.split(',')]
|
||||
|
||||
# Filter words: keep those not in toggle_trigger_words or those that are active
|
||||
filtered_words = [word for word in original_words if word not in active_state or active_state[word]]
|
||||
|
||||
# Join them in the same format as input
|
||||
if filtered_words:
|
||||
filtered_triggers = ', '.join(filtered_words)
|
||||
if group_mode:
|
||||
# Split by two or more consecutive commas to get groups
|
||||
groups = re.split(r',{2,}', trigger_words)
|
||||
# Remove leading/trailing whitespace from each group
|
||||
groups = [group.strip() for group in groups]
|
||||
|
||||
# Filter groups: keep those not in toggle_trigger_words or those that are active
|
||||
filtered_groups = [group for group in groups if group not in active_state or active_state[group]]
|
||||
|
||||
if filtered_groups:
|
||||
filtered_triggers = ', '.join(filtered_groups)
|
||||
else:
|
||||
filtered_triggers = ""
|
||||
else:
|
||||
filtered_triggers = ""
|
||||
# Original behavior for individual words mode
|
||||
original_words = [word.strip() for word in trigger_words.split(',')]
|
||||
# Filter out empty strings
|
||||
original_words = [word for word in original_words if word]
|
||||
filtered_words = [word for word in original_words if word not in active_state or active_state[word]]
|
||||
|
||||
if filtered_words:
|
||||
filtered_triggers = ', '.join(filtered_words)
|
||||
else:
|
||||
filtered_triggers = ""
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing trigger words: {e}")
|
||||
|
||||
@@ -4,6 +4,7 @@ class AnyType(str):
|
||||
def __ne__(self, __value: object) -> bool:
|
||||
return False
|
||||
|
||||
# Credit to Regis Gaughan, III (rgthree)
|
||||
class FlexibleOptionalInputType(dict):
|
||||
"""A special class to make flexible nodes that pass data to our python handlers.
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import logging
|
||||
from aiohttp import web
|
||||
from typing import Dict, List
|
||||
|
||||
from ..utils.model_utils import determine_base_model
|
||||
|
||||
from ..services.file_monitor import LoraFileMonitor
|
||||
from ..services.download_manager import DownloadManager
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
@@ -42,9 +44,11 @@ class ApiRoutes:
|
||||
app.router.add_post('/api/download-lora', routes.download_lora)
|
||||
app.router.add_post('/api/settings', routes.update_settings)
|
||||
app.router.add_post('/api/move_model', routes.move_model)
|
||||
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
|
||||
app.router.add_post('/loras/api/save-metadata', routes.save_metadata)
|
||||
app.router.add_get('/api/lora-preview-url', routes.get_lora_preview_url) # Add new route
|
||||
app.router.add_post('/api/move_models_bulk', routes.move_models_bulk)
|
||||
app.router.add_get('/api/top-tags', routes.get_top_tags) # Add new route for top tags
|
||||
app.router.add_get('/api/recipes', cls.handle_get_recipes)
|
||||
|
||||
# Add update check routes
|
||||
@@ -132,6 +136,11 @@ class ApiRoutes:
|
||||
base_models = request.query.get('base_models', '').split(',')
|
||||
base_models = [model.strip() for model in base_models if model.strip()]
|
||||
|
||||
# Parse search options
|
||||
search_filename = request.query.get('search_filename', 'true').lower() == 'true'
|
||||
search_modelname = request.query.get('search_modelname', 'true').lower() == 'true'
|
||||
search_tags = request.query.get('search_tags', 'false').lower() == 'true'
|
||||
|
||||
# Validate parameters
|
||||
if page < 1 or page_size < 1 or page_size > 100:
|
||||
return web.json_response({
|
||||
@@ -143,6 +152,10 @@ class ApiRoutes:
|
||||
'error': 'Invalid sort parameter'
|
||||
}, status=400)
|
||||
|
||||
# Parse tags filter parameter
|
||||
tags = request.query.get('tags', '').split(',')
|
||||
tags = [tag.strip() for tag in tags if tag.strip()]
|
||||
|
||||
# Get paginated data with search and filters
|
||||
result = await self.scanner.get_paginated_data(
|
||||
page=page,
|
||||
@@ -152,7 +165,13 @@ class ApiRoutes:
|
||||
search=search,
|
||||
fuzzy=fuzzy,
|
||||
recursive=recursive,
|
||||
base_models=base_models # Pass base models filter
|
||||
base_models=base_models, # Pass base models filter
|
||||
tags=tags, # Add tags parameter
|
||||
search_options={
|
||||
'filename': search_filename,
|
||||
'modelname': search_modelname,
|
||||
'tags': search_tags
|
||||
}
|
||||
)
|
||||
|
||||
# Format the response data
|
||||
@@ -185,12 +204,15 @@ class ApiRoutes:
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
||||
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"file_size": lora["size"],
|
||||
"modified": lora["modified"],
|
||||
"tags": lora["tags"],
|
||||
"modelDescription": lora["modelDescription"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"usage_tips": lora.get("usage_tips", ""),
|
||||
"notes": lora.get("notes", ""),
|
||||
@@ -333,8 +355,16 @@ class ApiRoutes:
|
||||
|
||||
# Update model name if available
|
||||
if 'model' in civitai_metadata:
|
||||
local_metadata['model_name'] = civitai_metadata['model'].get('name',
|
||||
local_metadata.get('model_name'))
|
||||
if civitai_metadata.get('model', {}).get('name'):
|
||||
local_metadata['model_name'] = determine_base_model(civitai_metadata['model']['name'])
|
||||
|
||||
# Fetch additional model metadata (description and tags) if we have model ID
|
||||
model_id = civitai_metadata['modelId']
|
||||
if model_id:
|
||||
model_metadata, _ = await client.get_model_metadata(str(model_id))
|
||||
if model_metadata:
|
||||
local_metadata['modelDescription'] = model_metadata.get('description', '')
|
||||
local_metadata['tags'] = model_metadata.get('tags', [])
|
||||
|
||||
# Update base model
|
||||
local_metadata['base_model'] = civitai_metadata.get('baseModel')
|
||||
@@ -350,6 +380,7 @@ class ApiRoutes:
|
||||
|
||||
if await client.download_preview_image(first_preview['url'], preview_path):
|
||||
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
|
||||
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
@@ -369,7 +400,7 @@ class ApiRoutes:
|
||||
# 准备要处理的 loras
|
||||
to_process = [
|
||||
lora for lora in cache.raw_data
|
||||
if lora.get('sha256') and not lora.get('civitai') and lora.get('from_civitai')
|
||||
if lora.get('sha256') and (not lora.get('civitai') or 'id' not in lora.get('civitai')) and lora.get('from_civitai') # TODO: for lora not from CivitAI but added traineWords
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
@@ -547,6 +578,8 @@ class ApiRoutes:
|
||||
# Validate and update settings
|
||||
if 'civitai_api_key' in data:
|
||||
settings.set('civitai_api_key', data['civitai_api_key'])
|
||||
if 'show_only_sfw' in data:
|
||||
settings.set('show_only_sfw', data['show_only_sfw'])
|
||||
|
||||
return web.json_response({'success': True})
|
||||
except Exception as e:
|
||||
@@ -602,8 +635,15 @@ class ApiRoutes:
|
||||
else:
|
||||
metadata = {}
|
||||
|
||||
# Update metadata with new values
|
||||
metadata.update(metadata_updates)
|
||||
# Handle nested updates (for civitai.trainedWords)
|
||||
for key, value in metadata_updates.items():
|
||||
if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict):
|
||||
# Deep update for nested dictionaries
|
||||
for nested_key, nested_value in value.items():
|
||||
metadata[key][nested_key] = nested_value
|
||||
else:
|
||||
# Regular update for top-level keys
|
||||
metadata[key] = value
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
@@ -694,6 +734,97 @@ class ApiRoutes:
|
||||
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_lora_model_description(self, request: web.Request) -> web.Response:
|
||||
"""Get model description for a Lora model"""
|
||||
try:
|
||||
# Get parameters
|
||||
model_id = request.query.get('model_id')
|
||||
file_path = request.query.get('file_path')
|
||||
|
||||
if not model_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Model ID is required'
|
||||
}, status=400)
|
||||
|
||||
# Check if we already have the description stored in metadata
|
||||
description = None
|
||||
tags = []
|
||||
if file_path:
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
description = metadata.get('modelDescription')
|
||||
tags = metadata.get('tags', [])
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading metadata from {metadata_path}: {e}")
|
||||
|
||||
# If description is not in metadata, fetch from CivitAI
|
||||
if not description:
|
||||
logger.info(f"Fetching model metadata for model ID: {model_id}")
|
||||
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
|
||||
|
||||
if model_metadata:
|
||||
description = model_metadata.get('description')
|
||||
tags = model_metadata.get('tags', [])
|
||||
|
||||
# Save the metadata to file if we have a file path and got metadata
|
||||
if file_path:
|
||||
try:
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
if os.path.exists(metadata_path):
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
metadata['modelDescription'] = description
|
||||
metadata['tags'] = tags
|
||||
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"Saved model metadata to file for {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model metadata: {e}")
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'description': description or "<p>No model description available.</p>",
|
||||
'tags': tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model metadata: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get top tags
|
||||
top_tags = await self.scanner.get_top_tags(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': top_tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Internal server error'
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_get_recipes(request):
|
||||
"""API endpoint for getting paginated recipes"""
|
||||
|
||||
@@ -28,10 +28,16 @@ class LoraRoutes:
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
||||
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"size": lora["size"],
|
||||
"tags": lora["tags"],
|
||||
"modelDescription": lora["modelDescription"],
|
||||
"usage_tips": lora["usage_tips"],
|
||||
"notes": lora["notes"],
|
||||
"modified": lora["modified"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
|
||||
|
||||
@@ -163,6 +163,53 @@ class CivitaiClient:
|
||||
logger.error(f"Error fetching model version info: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||
"""Fetch model metadata (description and tags) from Civitai API
|
||||
|
||||
Args:
|
||||
model_id: The Civitai model ID
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict], int]: A tuple containing:
|
||||
- A dictionary with model metadata or None if not found
|
||||
- The HTTP status code from the request
|
||||
"""
|
||||
try:
|
||||
session = await self.session
|
||||
headers = self._get_request_headers()
|
||||
url = f"{self.base_url}/models/{model_id}"
|
||||
|
||||
async with session.get(url, headers=headers) as response:
|
||||
status_code = response.status
|
||||
|
||||
if status_code != 200:
|
||||
logger.warning(f"Failed to fetch model metadata: Status {status_code}")
|
||||
return None, status_code
|
||||
|
||||
data = await response.json()
|
||||
|
||||
# Extract relevant metadata
|
||||
metadata = {
|
||||
"description": data.get("description") or "No model description available",
|
||||
"tags": data.get("tags", [])
|
||||
}
|
||||
|
||||
if metadata["description"] or metadata["tags"]:
|
||||
return metadata, status_code
|
||||
else:
|
||||
logger.warning(f"No metadata found for model {model_id}")
|
||||
return None, status_code
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
|
||||
return None, 0
|
||||
|
||||
# Keep old method for backward compatibility, delegating to the new one
|
||||
async def get_model_description(self, model_id: str) -> Optional[str]:
|
||||
"""Fetch the model description from Civitai API (Legacy method)"""
|
||||
metadata, _ = await self.get_model_metadata(model_id)
|
||||
return metadata.get("description") if metadata else None
|
||||
|
||||
async def close(self):
|
||||
"""Close the session if it exists"""
|
||||
if self._session is not None:
|
||||
|
||||
@@ -51,6 +51,16 @@ class DownloadManager:
|
||||
# 5. 准备元数据
|
||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
|
||||
# 5.1 获取并更新模型标签和描述信息
|
||||
model_id = version_info.get('modelId')
|
||||
if model_id:
|
||||
model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id))
|
||||
if model_metadata:
|
||||
if model_metadata.get("tags"):
|
||||
metadata.tags = model_metadata.get("tags", [])
|
||||
if model_metadata.get("description"):
|
||||
metadata.modelDescription = model_metadata.get("description", "")
|
||||
|
||||
# 6. 开始下载流程
|
||||
result = await self._execute_download(
|
||||
download_url=download_url,
|
||||
@@ -86,6 +96,7 @@ class DownloadManager:
|
||||
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
|
||||
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
@@ -98,6 +98,10 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
# Scan new file
|
||||
lora_data = await self.scanner.scan_single_lora(file_path)
|
||||
if lora_data:
|
||||
# Update tags count
|
||||
for tag in lora_data.get('tags', []):
|
||||
self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1
|
||||
|
||||
cache.raw_data.append(lora_data)
|
||||
new_folders.add(lora_data['folder'])
|
||||
# Update hash index
|
||||
@@ -109,6 +113,16 @@ class LoraFileHandler(FileSystemEventHandler):
|
||||
needs_resort = True
|
||||
|
||||
elif action == 'remove':
|
||||
# Find the lora to remove so we can update tags count
|
||||
lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if lora_to_remove:
|
||||
# Update tags count by reducing counts
|
||||
for tag in lora_to_remove.get('tags', []):
|
||||
if tag in self.scanner._tags_count:
|
||||
self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1)
|
||||
if self.scanner._tags_count[tag] == 0:
|
||||
del self.scanner._tags_count[tag]
|
||||
|
||||
# Remove from cache and hash index
|
||||
logger.info(f"Removing {file_path} from cache")
|
||||
self.scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
@@ -11,6 +11,8 @@ from ..utils.file_utils import load_metadata, get_file_info
|
||||
from .lora_cache import LoraCache
|
||||
from difflib import SequenceMatcher
|
||||
from .lora_hash_index import LoraHashIndex
|
||||
from .settings_manager import settings
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,6 +37,7 @@ class LoraScanner:
|
||||
self._initialization_task: Optional[asyncio.Task] = None
|
||||
self._initialized = True
|
||||
self.file_monitor = None # Add this line
|
||||
self._tags_count = {} # Add a dictionary to store tag counts
|
||||
|
||||
def set_file_monitor(self, monitor):
|
||||
"""Set file monitor instance"""
|
||||
@@ -91,13 +94,21 @@ class LoraScanner:
|
||||
# Clear existing hash index
|
||||
self._hash_index.clear()
|
||||
|
||||
# Clear existing tags count
|
||||
self._tags_count = {}
|
||||
|
||||
# Scan for new data
|
||||
raw_data = await self.scan_all_loras()
|
||||
|
||||
# Build hash index
|
||||
# Build hash index and tags count
|
||||
for lora_data in raw_data:
|
||||
if 'sha256' in lora_data and 'file_path' in lora_data:
|
||||
self._hash_index.add_entry(lora_data['sha256'], lora_data['file_path'])
|
||||
|
||||
# Count tags
|
||||
if 'tags' in lora_data and lora_data['tags']:
|
||||
for tag in lora_data['tags']:
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
# Update cache
|
||||
self._cache = LoraCache(
|
||||
@@ -159,7 +170,8 @@ class LoraScanner:
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||
folder: str = None, search: str = None, fuzzy: bool = False,
|
||||
recursive: bool = False, base_models: list = None):
|
||||
recursive: bool = False, base_models: list = None, tags: list = None,
|
||||
search_options: dict = None) -> Dict:
|
||||
"""Get paginated and filtered lora data
|
||||
|
||||
Args:
|
||||
@@ -171,22 +183,39 @@ class LoraScanner:
|
||||
fuzzy: Use fuzzy matching for search
|
||||
recursive: Include subfolders when folder filter is applied
|
||||
base_models: List of base models to filter by
|
||||
tags: List of tags to filter by
|
||||
search_options: Dictionary with search options (filename, modelname, tags)
|
||||
"""
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# 先获取基础数据集
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# 应用文件夹过滤
|
||||
# Apply SFW filtering if enabled
|
||||
if settings.get('show_only_sfw', False):
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if recursive:
|
||||
# 递归模式:匹配所有以该文件夹开头的路径
|
||||
# Recursive mode: match all paths starting with this folder
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if item['folder'].startswith(folder + '/') or item['folder'] == folder
|
||||
]
|
||||
else:
|
||||
# 非递归模式:只匹配确切的文件夹
|
||||
# Non-recursive mode: match exact folder
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if item['folder'] == folder
|
||||
@@ -199,28 +228,27 @@ class LoraScanner:
|
||||
if item.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# 应用搜索过滤
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if any(tag in item.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
if fuzzy:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if any(
|
||||
self.fuzzy_match(str(value), search)
|
||||
for value in [
|
||||
item.get('model_name', ''),
|
||||
item.get('base_model', '')
|
||||
]
|
||||
if value
|
||||
)
|
||||
if self._fuzzy_search_match(item, search, search_options)
|
||||
]
|
||||
else:
|
||||
# Original exact search logic
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if search in str(item.get('model_name', '')).lower()
|
||||
if self._exact_search_match(item, search, search_options)
|
||||
]
|
||||
|
||||
# 计算分页
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
@@ -235,6 +263,44 @@ class LoraScanner:
|
||||
|
||||
return result
|
||||
|
||||
def _fuzzy_search_match(self, item: Dict, search: str, search_options: Dict) -> bool:
|
||||
"""Check if an item matches the search term using fuzzy matching with search options"""
|
||||
# Check filename if enabled
|
||||
if search_options.get('filename', True) and self.fuzzy_match(item.get('file_name', ''), search):
|
||||
return True
|
||||
|
||||
# Check model name if enabled
|
||||
if search_options.get('modelname', True) and self.fuzzy_match(item.get('model_name', ''), search):
|
||||
return True
|
||||
|
||||
# Check tags if enabled
|
||||
if search_options.get('tags', False) and item.get('tags'):
|
||||
for tag in item['tags']:
|
||||
if self.fuzzy_match(tag, search):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _exact_search_match(self, item: Dict, search: str, search_options: Dict) -> bool:
|
||||
"""Check if an item matches the search term using exact matching with search options"""
|
||||
search = search.lower()
|
||||
|
||||
# Check filename if enabled
|
||||
if search_options.get('filename', True) and search in item.get('file_name', '').lower():
|
||||
return True
|
||||
|
||||
# Check model name if enabled
|
||||
if search_options.get('modelname', True) and search in item.get('model_name', '').lower():
|
||||
return True
|
||||
|
||||
# Check tags if enabled
|
||||
if search_options.get('tags', False) and item.get('tags'):
|
||||
for tag in item['tags']:
|
||||
if search in tag.lower():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def invalidate_cache(self):
|
||||
"""Invalidate the current cache"""
|
||||
self._cache = None
|
||||
@@ -312,12 +378,86 @@ class LoraScanner:
|
||||
|
||||
# Convert to dict and add folder info
|
||||
lora_data = metadata.to_dict()
|
||||
# Try to fetch missing metadata from Civitai if needed
|
||||
await self._fetch_missing_metadata(file_path, lora_data)
|
||||
rel_path = os.path.relpath(file_path, root_path)
|
||||
folder = os.path.dirname(rel_path)
|
||||
lora_data['folder'] = folder.replace(os.path.sep, '/')
|
||||
|
||||
return lora_data
|
||||
|
||||
async def _fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None:
|
||||
"""Fetch missing description and tags from Civitai if needed
|
||||
|
||||
Args:
|
||||
file_path: Path to the lora file
|
||||
lora_data: Lora metadata dictionary to update
|
||||
"""
|
||||
try:
|
||||
# Skip if already marked as deleted on Civitai
|
||||
if lora_data.get('civitai_deleted', False):
|
||||
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
|
||||
return
|
||||
|
||||
# Check if we need to fetch additional metadata from Civitai
|
||||
needs_metadata_update = False
|
||||
model_id = None
|
||||
|
||||
# Check if we have Civitai model ID but missing metadata
|
||||
if lora_data.get('civitai'):
|
||||
# Try to get model ID directly from the correct location
|
||||
model_id = lora_data['civitai'].get('modelId')
|
||||
|
||||
if model_id:
|
||||
model_id = str(model_id)
|
||||
# Check if tags are missing or empty
|
||||
tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0
|
||||
|
||||
# Check if description is missing or empty
|
||||
desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")
|
||||
|
||||
needs_metadata_update = tags_missing or desc_missing
|
||||
|
||||
# Fetch missing metadata if needed
|
||||
if needs_metadata_update and model_id:
|
||||
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
client = CivitaiClient()
|
||||
|
||||
# Get metadata and status code
|
||||
model_metadata, status_code = await client.get_model_metadata(model_id)
|
||||
await client.close()
|
||||
|
||||
# Handle 404 status (model deleted from Civitai)
|
||||
if status_code == 404:
|
||||
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
|
||||
# Mark as deleted to avoid future API calls
|
||||
lora_data['civitai_deleted'] = True
|
||||
|
||||
# Save the updated metadata back to file
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(lora_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Process valid metadata if available
|
||||
elif model_metadata:
|
||||
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
|
||||
|
||||
# Update tags if they were missing
|
||||
if model_metadata.get('tags') and (not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0):
|
||||
lora_data['tags'] = model_metadata['tags']
|
||||
|
||||
# Update description if it was missing
|
||||
if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")):
|
||||
lora_data['modelDescription'] = model_metadata['description']
|
||||
|
||||
# Save the updated metadata back to file
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(lora_data, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
|
||||
|
||||
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
|
||||
"""Update preview URL in cache for a specific lora
|
||||
|
||||
@@ -428,6 +568,15 @@ class LoraScanner:
|
||||
async def update_single_lora_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Find the existing item to remove its tags from count
|
||||
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
|
||||
if existing_item and 'tags' in existing_item:
|
||||
for tag in existing_item.get('tags', []):
|
||||
if tag in self._tags_count:
|
||||
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
|
||||
if self._tags_count[tag] == 0:
|
||||
del self._tags_count[tag]
|
||||
|
||||
# Remove old path from hash index if exists
|
||||
self._hash_index.remove_by_path(original_path)
|
||||
|
||||
@@ -461,6 +610,11 @@ class LoraScanner:
|
||||
# Update folders list
|
||||
all_folders = set(item['folder'] for item in cache.raw_data)
|
||||
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
# Update tags count with the new/updated tags
|
||||
if 'tags' in metadata:
|
||||
for tag in metadata.get('tags', []):
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
# Resort cache
|
||||
await cache.resort()
|
||||
@@ -506,6 +660,29 @@ class LoraScanner:
|
||||
"""Get hash for a LoRA by its file path"""
|
||||
return self._hash_index.get_hash(file_path)
|
||||
|
||||
# Add new method to get top tags
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get top tags sorted by count
|
||||
|
||||
Args:
|
||||
limit: Maximum number of tags to return
|
||||
|
||||
Returns:
|
||||
List of dictionaries with tag name and count, sorted by count
|
||||
"""
|
||||
# Make sure cache is initialized
|
||||
await self.get_cached_data()
|
||||
|
||||
# Sort tags by count in descending order
|
||||
sorted_tags = sorted(
|
||||
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
|
||||
key=lambda x: x['count'],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Return limited number
|
||||
return sorted_tags[:limit]
|
||||
|
||||
async def diagnose_hash_index(self):
|
||||
"""Diagnostic method to verify hash index functionality"""
|
||||
print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr)
|
||||
|
||||
@@ -37,7 +37,8 @@ class SettingsManager:
|
||||
def _get_default_settings(self) -> Dict[str, Any]:
|
||||
"""Return default settings"""
|
||||
return {
|
||||
"civitai_api_key": ""
|
||||
"civitai_api_key": "",
|
||||
"show_only_sfw": False
|
||||
}
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
|
||||
8
py/utils/constants.py
Normal file
8
py/utils/constants.py
Normal file
@@ -0,0 +1,8 @@
|
||||
NSFW_LEVELS = {
|
||||
"PG": 1,
|
||||
"PG13": 2,
|
||||
"R": 4,
|
||||
"X": 8,
|
||||
"XXX": 16,
|
||||
"Blocked": 32, # Probably not actually visible through the API without being logged in on model owner account?
|
||||
}
|
||||
@@ -4,6 +4,8 @@ import hashlib
|
||||
import json
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .model_utils import determine_base_model
|
||||
|
||||
from .lora_metadata import extract_lora_metadata
|
||||
from .models import LoraMetadata
|
||||
|
||||
@@ -69,6 +71,8 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
|
||||
notes="",
|
||||
from_civitai=True,
|
||||
preview_url=normalize_path(preview_url),
|
||||
tags=[],
|
||||
modelDescription=""
|
||||
)
|
||||
|
||||
# create metadata file
|
||||
@@ -103,9 +107,18 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
|
||||
data = json.load(f)
|
||||
|
||||
needs_update = False
|
||||
|
||||
# Check and normalize base model name
|
||||
normalized_base_model = determine_base_model(data['base_model'])
|
||||
if data['base_model'] != normalized_base_model:
|
||||
data['base_model'] = normalized_base_model
|
||||
needs_update = True
|
||||
|
||||
if data['file_path'] != normalize_path(data['file_path']):
|
||||
data['file_path'] = normalize_path(data['file_path'])
|
||||
# Compare paths without extensions
|
||||
stored_path_base = os.path.splitext(data['file_path'])[0]
|
||||
current_path_base = os.path.splitext(normalize_path(file_path))[0]
|
||||
if stored_path_base != current_path_base:
|
||||
data['file_path'] = normalize_path(file_path)
|
||||
needs_update = True
|
||||
|
||||
preview_url = data.get('preview_url', '')
|
||||
@@ -116,8 +129,21 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
|
||||
if new_preview_url != preview_url:
|
||||
data['preview_url'] = new_preview_url
|
||||
needs_update = True
|
||||
elif preview_url != normalize_path(preview_url):
|
||||
data['preview_url'] = normalize_path(preview_url)
|
||||
else:
|
||||
# Compare preview paths without extensions
|
||||
stored_preview_base = os.path.splitext(preview_url)[0]
|
||||
current_preview_base = os.path.splitext(normalize_path(preview_url))[0]
|
||||
if stored_preview_base != current_preview_base:
|
||||
data['preview_url'] = normalize_path(preview_url)
|
||||
needs_update = True
|
||||
|
||||
# Ensure all fields are present
|
||||
if 'tags' not in data:
|
||||
data['tags'] = []
|
||||
needs_update = True
|
||||
|
||||
if 'modelDescription' not in data:
|
||||
data['modelDescription'] = ""
|
||||
needs_update = True
|
||||
|
||||
if needs_update:
|
||||
|
||||
@@ -8,7 +8,8 @@ BASE_MODEL_MAPPING = {
|
||||
"sd-v2": "SD 2.0",
|
||||
"flux1": "Flux.1 D",
|
||||
"flux.1 d": "Flux.1 D",
|
||||
"illustrious": "IL",
|
||||
"illustrious": "Illustrious",
|
||||
"il": "Illustrious",
|
||||
"pony": "Pony",
|
||||
"Hunyuan Video": "Hunyuan Video"
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, List
|
||||
from datetime import datetime
|
||||
import os
|
||||
from .model_utils import determine_base_model
|
||||
@@ -15,10 +15,18 @@ class LoraMetadata:
|
||||
sha256: str # SHA256 hash of the file
|
||||
base_model: str # Base model (SD1.5/SD2.1/SDXL/etc.)
|
||||
preview_url: str # Preview image URL
|
||||
preview_nsfw_level: int = 0 # NSFW level of the preview image
|
||||
usage_tips: str = "{}" # Usage tips for the model, json string
|
||||
notes: str = "" # Additional notes
|
||||
from_civitai: bool = True # Whether the lora is from Civitai
|
||||
from_civitai: bool = True # Whether the lora is from Civitai
|
||||
civitai: Optional[Dict] = None # Civitai API data if available
|
||||
tags: List[str] = None # Model tags
|
||||
modelDescription: str = "" # Full model description
|
||||
|
||||
def __post_init__(self):
|
||||
# Initialize empty lists to avoid mutable default parameter issue
|
||||
if self.tags is None:
|
||||
self.tags = []
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'LoraMetadata':
|
||||
@@ -42,6 +50,7 @@ class LoraMetadata:
|
||||
sha256=file_info['hashes'].get('SHA256', ''),
|
||||
base_model=base_model,
|
||||
preview_url=None, # Will be updated after preview download
|
||||
preview_nsfw_level=0, # Will be updated after preview download, it is decided by the nsfw level of the preview image
|
||||
from_civitai=True,
|
||||
civitai=version_info
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user