mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 13:42:12 -03:00
297 lines
11 KiB
Python
297 lines
11 KiB
Python
import os
|
||
import logging
|
||
import time
|
||
import asyncio
|
||
from typing import List, Dict, Optional
|
||
from dataclasses import dataclass
|
||
from operator import itemgetter
|
||
from ..config import config
|
||
from ..utils.file_utils import load_metadata, get_file_info
|
||
from .lora_cache import LoraCache
|
||
from difflib import SequenceMatcher
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class LoraScanner:
|
||
"""Service for scanning and managing LoRA files"""
|
||
|
||
_instance = None
|
||
_lock = asyncio.Lock()
|
||
|
||
def __new__(cls):
|
||
if cls._instance is None:
|
||
cls._instance = super().__new__(cls)
|
||
return cls._instance
|
||
|
||
def __init__(self):
|
||
# 确保初始化只执行一次
|
||
if not hasattr(self, '_initialized'):
|
||
self._cache: Optional[LoraCache] = None
|
||
self._initialization_lock = asyncio.Lock()
|
||
self._initialization_task: Optional[asyncio.Task] = None
|
||
self._initialized = True
|
||
|
||
@classmethod
|
||
async def get_instance(cls):
|
||
"""Get singleton instance with async support"""
|
||
async with cls._lock:
|
||
if cls._instance is None:
|
||
cls._instance = cls()
|
||
return cls._instance
|
||
|
||
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
|
||
"""Get cached LoRA data, refresh if needed"""
|
||
async with self._initialization_lock:
|
||
|
||
# 如果缓存未初始化但需要响应请求,返回空缓存
|
||
if self._cache is None and not force_refresh:
|
||
return LoraCache(
|
||
raw_data=[],
|
||
sorted_by_name=[],
|
||
sorted_by_date=[],
|
||
folders=[]
|
||
)
|
||
|
||
# 如果正在初始化,等待完成
|
||
if self._initialization_task and not self._initialization_task.done():
|
||
try:
|
||
await self._initialization_task
|
||
except Exception as e:
|
||
logger.error(f"Cache initialization failed: {e}")
|
||
self._initialization_task = None
|
||
|
||
if (self._cache is None or force_refresh):
|
||
|
||
# 创建新的初始化任务
|
||
if not self._initialization_task or self._initialization_task.done():
|
||
self._initialization_task = asyncio.create_task(self._initialize_cache())
|
||
|
||
try:
|
||
await self._initialization_task
|
||
except Exception as e:
|
||
logger.error(f"Cache initialization failed: {e}")
|
||
# 如果缓存已存在,继续使用旧缓存
|
||
if self._cache is None:
|
||
raise # 如果没有缓存,则抛出异常
|
||
|
||
return self._cache
|
||
|
||
async def _initialize_cache(self) -> None:
|
||
"""Initialize or refresh the cache"""
|
||
# Scan for new data
|
||
raw_data = await self.scan_all_loras()
|
||
|
||
# Update cache
|
||
self._cache = LoraCache(
|
||
raw_data=raw_data,
|
||
sorted_by_name=[],
|
||
sorted_by_date=[],
|
||
folders=[]
|
||
)
|
||
|
||
# Call resort_cache to create sorted views
|
||
await self._cache.resort()
|
||
|
||
def fuzzy_match(self, text: str, pattern: str, threshold: float = 0.7) -> bool:
|
||
"""
|
||
Check if text matches pattern using fuzzy matching.
|
||
Returns True if similarity ratio is above threshold.
|
||
"""
|
||
if not pattern or not text:
|
||
return False
|
||
|
||
# Convert both to lowercase for case-insensitive matching
|
||
text = text.lower()
|
||
pattern = pattern.lower()
|
||
|
||
# Split pattern into words
|
||
search_words = pattern.split()
|
||
|
||
# Check each word
|
||
for word in search_words:
|
||
# First check if word is a substring (faster)
|
||
if word in text:
|
||
continue
|
||
|
||
# If not found as substring, try fuzzy matching
|
||
# Check if any part of the text matches this word
|
||
found_match = False
|
||
for text_part in text.split():
|
||
ratio = SequenceMatcher(None, text_part, word).ratio()
|
||
if ratio >= threshold:
|
||
found_match = True
|
||
break
|
||
|
||
if not found_match:
|
||
return False
|
||
|
||
# All words found either as substrings or fuzzy matches
|
||
return True
|
||
|
||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||
folder: str = None, search: str = None, fuzzy: bool = False):
|
||
cache = await self.get_cached_data()
|
||
|
||
# 先获取基础数据集
|
||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||
|
||
# 应用文件夹过滤
|
||
if folder is not None:
|
||
filtered_data = [item for item in filtered_data if item['folder'] == folder]
|
||
|
||
# 应用搜索过滤(只匹配model_name)
|
||
if search:
|
||
if fuzzy:
|
||
filtered_data = [
|
||
item for item in filtered_data
|
||
if any(
|
||
self.fuzzy_match(str(value), search)
|
||
for value in [
|
||
item.get('model_name', ''),
|
||
item.get('base_model', '')
|
||
]
|
||
if value
|
||
)
|
||
]
|
||
else:
|
||
# Original exact search logic
|
||
filtered_data = [
|
||
item for item in filtered_data
|
||
if search in str(item.get('model_name', '')).lower()
|
||
]
|
||
|
||
# 计算分页
|
||
total_items = len(filtered_data)
|
||
start_idx = (page - 1) * page_size
|
||
end_idx = min(start_idx + page_size, total_items)
|
||
|
||
result = {
|
||
'items': filtered_data[start_idx:end_idx],
|
||
'total': total_items,
|
||
'page': page,
|
||
'page_size': page_size,
|
||
'total_pages': (total_items + page_size - 1) // page_size
|
||
}
|
||
|
||
return result
|
||
|
||
def invalidate_cache(self):
|
||
"""Invalidate the current cache"""
|
||
self._cache = None
|
||
|
||
async def scan_all_loras(self) -> List[Dict]:
|
||
"""Scan all LoRA directories and return metadata"""
|
||
all_loras = []
|
||
|
||
# 分目录异步扫描
|
||
scan_tasks = []
|
||
for loras_root in config.loras_roots:
|
||
task = asyncio.create_task(self._scan_directory(loras_root))
|
||
scan_tasks.append(task)
|
||
|
||
for task in scan_tasks:
|
||
try:
|
||
loras = await task
|
||
all_loras.extend(loras)
|
||
except Exception as e:
|
||
logger.error(f"Error scanning directory: {e}")
|
||
|
||
return all_loras
|
||
|
||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||
"""Scan a single directory for LoRA files"""
|
||
loras = []
|
||
|
||
# 使用异步安全的目录遍历方式
|
||
async def scan_recursive(path: str):
|
||
try:
|
||
with os.scandir(path) as it:
|
||
entries = list(it) # 同步获取目录条目
|
||
for entry in entries:
|
||
if entry.is_file() and entry.name.endswith('.safetensors'):
|
||
file_path = entry.path.replace(os.sep, "/")
|
||
await self._process_single_file(file_path, root_path, loras)
|
||
await asyncio.sleep(0) # 释放事件循环
|
||
elif entry.is_dir():
|
||
await scan_recursive(entry.path)
|
||
except Exception as e:
|
||
logger.error(f"Error scanning {path}: {e}")
|
||
|
||
await scan_recursive(root_path)
|
||
return loras
|
||
|
||
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
||
"""处理单个文件并添加到结果列表"""
|
||
try:
|
||
result = await self._process_lora_file(file_path, root_path)
|
||
if result:
|
||
loras.append(result)
|
||
except Exception as e:
|
||
logger.error(f"Error processing {file_path}: {e}")
|
||
|
||
async def _process_lora_file(self, file_path: str, root_path: str) -> Dict:
|
||
"""Process a single LoRA file and return its metadata"""
|
||
# Try loading existing metadata
|
||
metadata = await load_metadata(file_path)
|
||
|
||
if metadata is None:
|
||
# Create new metadata if none exists
|
||
metadata = await get_file_info(file_path)
|
||
|
||
# Convert to dict and add folder info
|
||
lora_data = metadata.to_dict()
|
||
rel_path = os.path.relpath(file_path, root_path)
|
||
folder = os.path.dirname(rel_path)
|
||
lora_data['folder'] = folder.replace(os.path.sep, '/')
|
||
|
||
return lora_data
|
||
|
||
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
|
||
"""Update preview URL in cache for a specific lora
|
||
|
||
Args:
|
||
file_path: The file path of the lora to update
|
||
preview_url: The new preview URL
|
||
|
||
Returns:
|
||
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
|
||
"""
|
||
if self._cache is None:
|
||
return False
|
||
|
||
return await self._cache.update_preview_url(file_path, preview_url)
|
||
|
||
async def scan_single_lora(self, file_path: str) -> Optional[Dict]:
|
||
"""Scan a single LoRA file and return its metadata"""
|
||
try:
|
||
if not os.path.exists(file_path):
|
||
return None
|
||
|
||
# 获取基本文件信息
|
||
metadata = await get_file_info(file_path)
|
||
if not metadata:
|
||
return None
|
||
|
||
# 计算相对于 lora_roots 的文件夹路径
|
||
folder = None
|
||
file_dir = os.path.dirname(file_path)
|
||
for root in config.loras_roots:
|
||
if file_dir.startswith(root):
|
||
rel_path = os.path.relpath(file_dir, root)
|
||
if rel_path == '.':
|
||
folder = '' # 根目录
|
||
else:
|
||
folder = rel_path.replace(os.sep, '/')
|
||
break
|
||
|
||
# 确保 folder 字段存在
|
||
metadata_dict = metadata.to_dict()
|
||
metadata_dict['folder'] = folder or ''
|
||
|
||
return metadata_dict
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error scanning {file_path}: {e}")
|
||
return None
|
||
|