mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git (synced 2026-03-22 13:42:12 -03:00)
391 lines · 15 KiB · Python
import json
import os
import logging
import asyncio
import shutil
import time
import sys
from typing import List, Dict, Optional, Set, Any

from ..utils.models import LoraMetadata
from ..config import config
from .model_scanner import ModelScanner
from .lora_hash_index import LoraHashIndex
from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match
from .service_registry import ServiceRegistry

logger = logging.getLogger(__name__)


class LoraScanner(ModelScanner):
    """Service for scanning and managing LoRA files"""

    _instance = None
    _lock = asyncio.Lock()

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Ensure initialization happens only once
        if not hasattr(self, '_initialized'):
            # Define supported file extensions
            file_extensions = {'.safetensors'}

            # Initialize parent class
            super().__init__(
                model_type="lora",
                model_class=LoraMetadata,
                file_extensions=file_extensions,
                hash_index=LoraHashIndex()
            )
            self._initialized = True

    @classmethod
    async def get_instance(cls):
        """Get singleton instance with async support"""
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance
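
    # Typical access pattern (illustrative sketch, not a call site copied from
    # this repository): consumers are expected to obtain the shared instance
    # asynchronously instead of constructing LoraScanner directly, e.g.
    #   scanner = await LoraScanner.get_instance()
    #   roots = scanner.get_model_roots()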

    def get_model_roots(self) -> List[str]:
        """Get lora root directories"""
        return config.loras_roots

    async def scan_all_models(self) -> List[Dict]:
        """Scan all LoRA directories and return metadata"""
        all_loras = []

        # Create scan tasks for each directory
        scan_tasks = []
        for lora_root in self.get_model_roots():
            task = asyncio.create_task(self._scan_directory(lora_root))
            scan_tasks.append(task)

        # Wait for all tasks to complete
        for task in scan_tasks:
            try:
                loras = await task
                all_loras.extend(loras)
            except Exception as e:
                logger.error(f"Error scanning directory: {e}")

        return all_loras

    async def _scan_directory(self, root_path: str) -> List[Dict]:
        """Scan a single directory for LoRA files"""
        loras = []
        original_root = root_path  # Save original root path

        async def scan_recursive(path: str, visited_paths: set):
            """Recursively scan directory, avoiding circular symlinks"""
            try:
                real_path = os.path.realpath(path)
                if real_path in visited_paths:
                    logger.debug(f"Skipping already visited path: {path}")
                    return
                visited_paths.add(real_path)

                with os.scandir(path) as it:
                    entries = list(it)
                    for entry in entries:
                        try:
                            if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
                                # Use original path instead of real path
                                file_path = entry.path.replace(os.sep, "/")
                                await self._process_single_file(file_path, original_root, loras)
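                                # Yield to the event loop between files so a
                                # large scan does not starve other coroutines.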
                                await asyncio.sleep(0)
                            elif entry.is_dir(follow_symlinks=True):
                                # For directories, continue scanning with original path
                                await scan_recursive(entry.path, visited_paths)
                        except Exception as e:
                            logger.error(f"Error processing entry {entry.path}: {e}")
            except Exception as e:
                logger.error(f"Error scanning {path}: {e}")

        await scan_recursive(root_path, set())
        return loras

    async def _process_single_file(self, file_path: str, root_path: str, loras: list):
        """Process a single file and add to results list"""
        try:
            result = await self._process_model_file(file_path, root_path)
            if result:
                loras.append(result)
        except Exception as e:
            logger.error(f"Error processing {file_path}: {e}")

    async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
                                 folder: str = None, search: str = None, fuzzy_search: bool = False,
                                 base_models: list = None, tags: list = None,
                                 search_options: dict = None, hash_filters: dict = None) -> Dict:
        """Get paginated and filtered lora data

        Args:
            page: Current page number (1-based)
            page_size: Number of items per page
            sort_by: Sort method ('name' or 'date')
            folder: Filter by folder path
            search: Search term
            fuzzy_search: Use fuzzy matching for search
            base_models: List of base models to filter by
            tags: List of tags to filter by
            search_options: Dictionary with search options (filename, modelname, tags, recursive)
            hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
        """
        cache = await self.get_cached_data()

        # Get default search options if not provided
        if search_options is None:
            search_options = {
                'filename': True,
                'modelname': True,
                'tags': False,
                'recursive': False,
            }

        # Get the base data set
        filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name

        # Apply hash filtering if provided (highest priority)
        if hash_filters:
            single_hash = hash_filters.get('single_hash')
            multiple_hashes = hash_filters.get('multiple_hashes')

            if single_hash:
                # Filter by single hash
                single_hash = single_hash.lower()  # Ensure lowercase for matching
                filtered_data = [
                    lora for lora in filtered_data
                    if lora.get('sha256', '').lower() == single_hash
                ]
            elif multiple_hashes:
                # Filter by multiple hashes
                hash_set = set(h.lower() for h in multiple_hashes)  # Convert to set for faster lookup
                filtered_data = [
                    lora for lora in filtered_data
                    if lora.get('sha256', '').lower() in hash_set
                ]
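
            # Hash filters take precedence: when they are supplied, the remaining
            # filters (SFW, folder, base model, tags, search) are skipped and
            # pagination is applied directly to the hash matches.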
            # Jump to pagination
            total_items = len(filtered_data)
            start_idx = (page - 1) * page_size
            end_idx = min(start_idx + page_size, total_items)

            result = {
                'items': filtered_data[start_idx:end_idx],
                'total': total_items,
                'page': page,
                'page_size': page_size,
                'total_pages': (total_items + page_size - 1) // page_size
            }

            return result

        # Apply SFW filtering if enabled
        if settings.get('show_only_sfw', False):
            filtered_data = [
                lora for lora in filtered_data
                if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
            ]
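            # (Entries with no recorded preview_nsfw_level pass this filter:
            # an unknown preview level is treated as safe-for-work.)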

        # Apply folder filtering
        if folder is not None:
            if search_options.get('recursive', False):
                # Recursive folder filtering - include all subfolders
                filtered_data = [
                    lora for lora in filtered_data
                    if lora['folder'].startswith(folder)
                ]
            else:
                # Exact folder filtering
                filtered_data = [
                    lora for lora in filtered_data
                    if lora['folder'] == folder
                ]

        # Apply base model filtering
        if base_models and len(base_models) > 0:
            filtered_data = [
                lora for lora in filtered_data
                if lora.get('base_model') in base_models
            ]

        # Apply tag filtering
        if tags and len(tags) > 0:
            filtered_data = [
                lora for lora in filtered_data
                if any(tag in lora.get('tags', []) for tag in tags)
            ]

        # Apply search filtering
        if search:
            search_results = []
            search_opts = search_options or {}

            for lora in filtered_data:
                # Search by file name
                if search_opts.get('filename', True):
                    if fuzzy_match(lora.get('file_name', ''), search):
                        search_results.append(lora)
                        continue

                # Search by model name
                if search_opts.get('modelname', True):
                    if fuzzy_match(lora.get('model_name', ''), search):
                        search_results.append(lora)
                        continue

                # Search by tags
                if search_opts.get('tags', False) and 'tags' in lora:
                    if any(fuzzy_match(tag, search) for tag in lora['tags']):
                        search_results.append(lora)
                        continue

            filtered_data = search_results

        # Calculate pagination
        total_items = len(filtered_data)
        start_idx = (page - 1) * page_size
        end_idx = min(start_idx + page_size, total_items)

        result = {
            'items': filtered_data[start_idx:end_idx],
            'total': total_items,
            'page': page,
            'page_size': page_size,
            'total_pages': (total_items + page_size - 1) // page_size
        }

        return result
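
    # Example call shape (illustrative values only; the actual callers live
    # elsewhere in the project):
    #   await scanner.get_paginated_data(
    #       page=1, page_size=25, sort_by='date', folder='characters',
    #       search='style', search_options={'filename': True, 'modelname': True,
    #                                       'tags': False, 'recursive': True})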

    async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Optional[Dict]:
        """Update file paths in metadata file"""
        try:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            # Update file_path
            metadata['file_path'] = lora_path.replace(os.sep, '/')

            # Update preview_url if exists
            if 'preview_url' in metadata:
                preview_dir = os.path.dirname(lora_path)
                preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
                preview_ext = os.path.splitext(metadata['preview_url'])[1]
                new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
                metadata['preview_url'] = new_preview_path.replace(os.sep, '/')

            # Save updated metadata
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2, ensure_ascii=False)

            return metadata

        except Exception as e:
            logger.error(f"Error updating metadata paths: {e}", exc_info=True)
            return None

    # Lora-specific hash index functionality
    def has_lora_hash(self, sha256: str) -> bool:
        """Check if a LoRA with given hash exists"""
        return self._hash_index.has_hash(sha256.lower())

    def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
        """Get file path for a LoRA by its hash"""
        return self._hash_index.get_path(sha256.lower())

    def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
        """Get hash for a LoRA by its file path"""
        return self._hash_index.get_hash(file_path)
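
    # Note: incoming SHA-256 digests are lowercased before lookups, so callers
    # may pass hashes in either case; file paths are used exactly as stored.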

    async def get_top_tags(self, limit: int = 20) -> List[Dict[str, Any]]:
        """Get top tags sorted by count"""
        # Make sure cache is initialized
        await self.get_cached_data()

        # Sort tags by count in descending order
        sorted_tags = sorted(
            [{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
            key=lambda x: x['count'],
            reverse=True
        )

        # Return limited number
        return sorted_tags[:limit]

    async def get_base_models(self, limit: int = 20) -> List[Dict[str, Any]]:
        """Get base models used in loras sorted by frequency"""
        # Make sure cache is initialized
        cache = await self.get_cached_data()

        # Count base model occurrences
        base_model_counts = {}
        for lora in cache.raw_data:
            if 'base_model' in lora and lora['base_model']:
                base_model = lora['base_model']
                base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1

        # Sort base models by count
        sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
        sorted_models.sort(key=lambda x: x['count'], reverse=True)

        # Return limited number
        return sorted_models[:limit]

    async def diagnose_hash_index(self):
        """Diagnostic method to verify hash index functionality"""
        print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr)

        # First check if the hash index has any entries
        if hasattr(self, '_hash_index'):
            index_entries = len(self._hash_index._hash_to_path)
            print(f"Hash index has {index_entries} entries", file=sys.stderr)

            # Print a few example entries if available
            if index_entries > 0:
                print("\nSample hash index entries:", file=sys.stderr)
                count = 0
                for hash_val, path in self._hash_index._hash_to_path.items():
                    if count < 5:  # Just show the first 5
                        print(f"Hash: {hash_val[:8]}... -> Path: {path}", file=sys.stderr)
                        count += 1
                    else:
                        break
        else:
            print("Hash index not initialized", file=sys.stderr)

        # Try looking up by a known hash for testing
        if not hasattr(self, '_hash_index') or not self._hash_index._hash_to_path:
            print("No hash entries to test lookup with", file=sys.stderr)
            return

        test_hash = next(iter(self._hash_index._hash_to_path.keys()))
        test_path = self._hash_index.get_path(test_hash)
        print(f"\nTest lookup by hash: {test_hash[:8]}... -> {test_path}", file=sys.stderr)

        # Also test reverse lookup
        test_hash_result = self._hash_index.get_hash(test_path)
        print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)

    async def get_lora_info_by_name(self, name):
        """Get LoRA information by name"""
        try:
            # Get cached data
            cache = await self.get_cached_data()

            # Find the LoRA by name
            for lora in cache.raw_data:
                if lora.get("file_name") == name:
                    return lora

            return None
        except Exception as e:
            logger.error(f"Error getting LoRA info by name: {e}", exc_info=True)
            return None