Files
ComfyUI-Lora-Manager/services/lora_scanner.py
2025-02-04 23:15:10 +08:00

251 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import logging
import time
import asyncio
from typing import List, Dict, Optional
from dataclasses import dataclass
from operator import itemgetter
from ..config import config
from ..utils.file_utils import load_metadata, get_file_info
from .lora_cache import LoraCache
logger = logging.getLogger(__name__)
class LoraScanner:
"""Service for scanning and managing LoRA files"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
# 确保初始化只执行一次
if not hasattr(self, '_initialized'):
self._cache: Optional[LoraCache] = None
self._initialization_lock = asyncio.Lock()
self._initialization_task: Optional[asyncio.Task] = None
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
"""Get cached LoRA data, refresh if needed"""
async with self._initialization_lock:
# 如果缓存未初始化但需要响应请求,返回空缓存
if self._cache is None and not force_refresh:
return LoraCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# 如果正在初始化,等待完成
if self._initialization_task and not self._initialization_task.done():
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
self._initialization_task = None
if (self._cache is None or force_refresh):
# 创建新的初始化任务
if not self._initialization_task or self._initialization_task.done():
self._initialization_task = asyncio.create_task(self._initialize_cache())
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
# 如果缓存已存在,继续使用旧缓存
if self._cache is None:
raise # 如果没有缓存,则抛出异常
return self._cache
async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache"""
# Scan for new data
raw_data = await self.scan_all_loras()
# Update cache
self._cache = LoraCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Call resort_cache to create sorted views
await self._cache.resort()
async def get_paginated_data(self,
page: int,
page_size: int,
sort_by: str = 'date',
folder: Optional[str] = None,
search: Optional[str] = None) -> Dict:
"""Get paginated LoRA data with search support"""
cache = await self.get_cached_data()
# 先获取基础数据集
data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# 应用文件夹过滤
if folder is not None:
data = [item for item in data if item['folder'] == folder]
# 应用搜索过滤只匹配model_name
if search:
search = search.lower().strip()
before_search = len(data)
data = [item for item in data
if search in item['model_name'].lower()]
# 计算分页
total_items = len(data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
def invalidate_cache(self):
"""Invalidate the current cache"""
self._cache = None
async def scan_all_loras(self) -> List[Dict]:
"""Scan all LoRA directories and return metadata"""
all_loras = []
# 分目录异步扫描
scan_tasks = []
for loras_root in config.loras_roots:
task = asyncio.create_task(self._scan_directory(loras_root))
scan_tasks.append(task)
for task in scan_tasks:
try:
loras = await task
all_loras.extend(loras)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_loras
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for LoRA files"""
loras = []
# 使用异步安全的目录遍历方式
async def scan_recursive(path: str):
try:
with os.scandir(path) as it:
entries = list(it) # 同步获取目录条目
for entry in entries:
if entry.is_file() and entry.name.endswith('.safetensors'):
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, root_path, loras)
await asyncio.sleep(0) # 释放事件循环
elif entry.is_dir():
await scan_recursive(entry.path)
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path)
return loras
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
"""处理单个文件并添加到结果列表"""
try:
result = await self._process_lora_file(file_path, root_path)
if result:
loras.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
async def _process_lora_file(self, file_path: str, root_path: str) -> Dict:
"""Process a single LoRA file and return its metadata"""
# Try loading existing metadata
metadata = await load_metadata(file_path)
if metadata is None:
# Create new metadata if none exists
metadata = await get_file_info(file_path)
# Convert to dict and add folder info
lora_data = metadata.to_dict()
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
lora_data['folder'] = folder.replace(os.path.sep, '/')
return lora_data
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
"""Update preview URL in cache for a specific lora
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
"""
if self._cache is None:
return False
return await self._cache.update_preview_url(file_path, preview_url)
async def scan_single_lora(self, file_path: str) -> Optional[Dict]:
"""Scan a single LoRA file and return its metadata"""
try:
if not os.path.exists(file_path):
return None
# 获取基本文件信息
metadata = await get_file_info(file_path)
if not metadata:
return None
# 计算相对于 lora_roots 的文件夹路径
folder = None
file_dir = os.path.dirname(file_path)
for root in config.loras_roots:
if file_dir.startswith(root):
rel_path = os.path.relpath(file_dir, root)
if rel_path == '.':
folder = '' # 根目录
else:
folder = rel_path.replace(os.sep, '/')
break
# 确保 folder 字段存在
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = folder or ''
return metadata_dict
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None