checkpoint

This commit is contained in:
Will Miao
2025-04-10 09:08:36 +08:00
parent 64c9e4aeca
commit 8fdfb68741
12 changed files with 1397 additions and 484 deletions

View File

@@ -73,7 +73,7 @@ class Config:
"""添加静态路由映射"""
normalized_path = os.path.normpath(path).replace(os.sep, '/')
self._route_mappings[normalized_path] = route
logger.info(f"Added route mapping: {normalized_path} -> {route}")
# logger.info(f"Added route mapping: {normalized_path} -> {route}")
def map_path_to_link(self, path: str) -> str:
"""将目标路径映射回符号链接路径"""

View File

@@ -7,10 +7,12 @@ from .routes.api_routes import ApiRoutes
from .routes.recipe_routes import RecipeRoutes
from .routes.checkpoints_routes import CheckpointsRoutes
from .services.lora_scanner import LoraScanner
from .services.checkpoint_scanner import CheckpointScanner
from .services.recipe_scanner import RecipeScanner
from .services.file_monitor import LoraFileMonitor
from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor
from .services.lora_cache import LoraCache
from .services.recipe_cache import RecipeCache
from .services.model_cache import ModelCache
import logging
logger = logging.getLogger(__name__)
@@ -23,7 +25,7 @@ class LoraManager:
"""Initialize and register all routes"""
app = PromptServer.instance.app
added_targets = set() # 用于跟踪已添加的目标路径
added_targets = set() # Track already added target paths
# Add static routes for each lora root
for idx, root in enumerate(config.loras_roots, start=1):
@@ -35,15 +37,34 @@ class LoraManager:
if link == root:
real_root = target
break
# 为原始路径添加静态路由
# Add static route for original path
app.router.add_static(preview_path, real_root)
logger.info(f"Added static route {preview_path} -> {real_root}")
# 记录路由映射
# Record route mapping
config.add_route_mapping(real_root, preview_path)
added_targets.add(real_root)
# 为符号链接的目标路径添加额外的静态路由
# Add static routes for each checkpoint root
checkpoint_scanner = CheckpointScanner()
for idx, root in enumerate(checkpoint_scanner.get_model_roots(), start=1):
preview_path = f'/checkpoints_static/root{idx}/preview'
real_root = root
if root in config._path_mappings.values():
for target, link in config._path_mappings.items():
if link == root:
real_root = target
break
# Add static route for original path
app.router.add_static(preview_path, real_root)
logger.info(f"Added static route {preview_path} -> {real_root}")
# Record route mapping
config.add_route_mapping(real_root, preview_path)
added_targets.add(real_root)
# Add static routes for symlink target paths
link_idx = 1
for target_path, link_path in config._path_mappings.items():
@@ -59,37 +80,47 @@ class LoraManager:
app.router.add_static('/loras_static', config.static_path)
# Setup feature routes
routes = LoraRoutes()
lora_routes = LoraRoutes()
checkpoints_routes = CheckpointsRoutes()
# Setup file monitoring
monitor = LoraFileMonitor(routes.scanner, config.loras_roots)
monitor.start()
lora_monitor = LoraFileMonitor(lora_routes.scanner, config.loras_roots)
lora_monitor.start()
routes.setup_routes(app)
checkpoint_monitor = CheckpointFileMonitor(checkpoints_routes.scanner, checkpoints_routes.scanner.get_model_roots())
checkpoint_monitor.start()
lora_routes.setup_routes(app)
checkpoints_routes.setup_routes(app)
ApiRoutes.setup_routes(app, monitor)
ApiRoutes.setup_routes(app, lora_monitor)
RecipeRoutes.setup_routes(app)
# Store monitor in app for cleanup
app['lora_monitor'] = monitor
# Store monitors in app for cleanup
app['lora_monitor'] = lora_monitor
app['checkpoint_monitor'] = checkpoint_monitor
logger.info("PromptServer app: ", app)
# Schedule cache initialization using the application's startup handler
app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner, routes.recipe_scanner))
app.on_startup.append(lambda app: cls._schedule_cache_init(
lora_routes.scanner,
checkpoints_routes.scanner,
lora_routes.recipe_scanner
))
# Add cleanup
app.on_shutdown.append(cls._cleanup)
app.on_shutdown.append(ApiRoutes.cleanup)
@classmethod
async def _schedule_cache_init(cls, scanner: LoraScanner, recipe_scanner: RecipeScanner):
async def _schedule_cache_init(cls, lora_scanner, checkpoint_scanner, recipe_scanner):
"""Schedule cache initialization in the running event loop"""
try:
# 创建低优先级的初始化任务
lora_task = asyncio.create_task(cls._initialize_lora_cache(scanner), name='lora_cache_init')
# Schedule recipe cache initialization with a delay to let lora scanner initialize first
recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner, delay=2), name='recipe_cache_init')
# Create low-priority initialization tasks
lora_task = asyncio.create_task(cls._initialize_lora_cache(lora_scanner), name='lora_cache_init')
checkpoint_task = asyncio.create_task(cls._initialize_checkpoint_cache(checkpoint_scanner), name='checkpoint_cache_init')
recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner), name='recipe_cache_init')
logger.info("Cache initialization tasks scheduled to run in background")
except Exception as e:
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}")
@@ -97,26 +128,45 @@ class LoraManager:
async def _initialize_lora_cache(cls, scanner: LoraScanner):
"""Initialize lora cache in background"""
try:
# 设置初始缓存占位
# Set initial placeholder cache
scanner._cache = LoraCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# 分阶段加载缓存
await scanner.get_cached_data(force_refresh=True)
# 使用线程池执行耗时操作
loop = asyncio.get_event_loop()
await loop.run_in_executor(
None, # 使用默认线程池
lambda: scanner.get_cached_data_sync(force_refresh=True) # 创建同步版本的方法
)
# Load cache in phases
# await scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing lora cache: {e}")
@classmethod
async def _initialize_recipe_cache(cls, scanner: RecipeScanner, delay: float = 2.0):
"""Initialize recipe cache in background with a delay"""
async def _initialize_checkpoint_cache(cls, scanner: CheckpointScanner):
"""Initialize checkpoint cache in background"""
try:
# Wait for the specified delay to let lora scanner initialize first
await asyncio.sleep(delay)
# Set initial placeholder cache
scanner._cache = ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Load cache in phases
await scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing checkpoint cache: {e}")
@classmethod
async def _initialize_recipe_cache(cls, scanner: RecipeScanner):
"""Initialize recipe cache in background with a delay"""
try:
# Set initial empty cache
scanner._cache = RecipeCache(
raw_data=[],
@@ -134,3 +184,6 @@ class LoraManager:
"""Cleanup resources"""
if 'lora_monitor' in app:
app['lora_monitor'].stop()
if 'checkpoint_monitor' in app:
app['checkpoint_monitor'].stop()

View File

@@ -1,44 +1,146 @@
import os
import json
import asyncio
import aiohttp
from aiohttp import web
import jinja2
import logging
from datetime import datetime
from ..services.checkpoint_scanner import CheckpointScanner
from ..config import config
from ..services.settings_manager import settings
logger = logging.getLogger(__name__)
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
class CheckpointsRoutes:
"""Route handlers for Checkpoints management endpoints"""
"""API routes for checkpoint management"""
def __init__(self):
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
self.scanner = CheckpointScanner()
def setup_routes(self, app):
"""Register routes with the aiohttp app"""
app.router.add_get('/lora_manager/api/checkpoints', self.get_checkpoints)
app.router.add_get('/lora_manager/api/checkpoints/scan', self.scan_checkpoints)
app.router.add_get('/lora_manager/api/checkpoints/info/{name}', self.get_checkpoint_info)
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
"""Handle GET /checkpoints request"""
async def get_checkpoints(self, request):
"""Get paginated checkpoint data"""
try:
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
is_initializing=False,
settings=settings,
request=request
# Parse query parameters
page = int(request.query.get('page', '1'))
page_size = min(int(request.query.get('page_size', '20')), 100)
sort_by = request.query.get('sort', 'name')
folder = request.query.get('folder', None)
search = request.query.get('search', None)
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
base_models = request.query.getall('base_model', [])
tags = request.query.getall('tag', [])
# Process search options
search_options = {
'filename': request.query.get('search_filename', 'true').lower() == 'true',
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
'tags': request.query.get('search_tags', 'false').lower() == 'true',
'recursive': request.query.get('recursive', 'false').lower() == 'true',
}
# Process hash filters if provided
hash_filters = {}
if 'hash' in request.query:
hash_filters['single_hash'] = request.query['hash']
elif 'hashes' in request.query:
try:
hash_list = json.loads(request.query['hashes'])
if isinstance(hash_list, list):
hash_filters['multiple_hashes'] = hash_list
except (json.JSONDecodeError, TypeError):
pass
# Get data from scanner
result = await self.get_paginated_data(
page=page,
page_size=page_size,
sort_by=sort_by,
folder=folder,
search=search,
fuzzy_search=fuzzy_search,
base_models=base_models,
tags=tags,
search_options=search_options,
hash_filters=hash_filters
)
return web.Response(
text=rendered,
content_type='text/html'
)
# Return as JSON
return web.json_response(result)
except Exception as e:
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
return web.Response(
text="Error loading checkpoints page",
status=500
)
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
def setup_routes(self, app: web.Application):
"""Register routes with the application"""
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
async def get_paginated_data(self, page, page_size, sort_by='name',
folder=None, search=None, fuzzy_search=False,
base_models=None, tags=None,
search_options=None, hash_filters=None):
"""Get paginated and filtered checkpoint data"""
cache = await self.scanner.get_cached_data()
# Implement similar filtering logic as in LoraScanner
# (Adapt code from LoraScanner.get_paginated_data)
# ...
# For now, a simplified implementation:
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply basic folder filtering if needed
if folder is not None:
filtered_data = [
cp for cp in filtered_data
if cp['folder'] == folder
]
# Apply basic search if needed
if search:
filtered_data = [
cp for cp in filtered_data
if search.lower() in cp['file_name'].lower() or
search.lower() in cp['model_name'].lower()
]
# Calculate pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
async def scan_checkpoints(self, request):
"""Force a rescan of checkpoint files"""
try:
await self.scanner.get_cached_data(force_refresh=True)
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
except Exception as e:
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_checkpoint_info(self, request):
"""Get detailed information for a specific checkpoint by name"""
try:
name = request.match_info.get('name', '')
checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name)
if checkpoint_info:
return web.json_response(checkpoint_info)
else:
return web.json_response({"error": "Checkpoint not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)

View File

@@ -0,0 +1,131 @@
import os
import logging
import asyncio
from typing import List, Dict, Optional, Set
import folder_paths # type: ignore
from ..utils.models import CheckpointMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
logger = logging.getLogger(__name__)
class CheckpointScanner(ModelScanner):
"""Service for scanning and managing checkpoint files"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not hasattr(self, '_initialized'):
# Define supported file extensions
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'}
super().__init__(
model_type="checkpoint",
model_class=CheckpointMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
self._checkpoint_roots = self._init_checkpoint_roots()
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def _init_checkpoint_roots(self) -> List[str]:
"""Initialize checkpoint roots from ComfyUI settings"""
# Get both checkpoint and diffusion_models paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
# Combine, normalize and deduplicate paths
all_paths = set()
for path in checkpoint_paths + diffusion_paths:
if os.path.exists(path):
norm_path = path.replace(os.sep, "/")
all_paths.add(norm_path)
# Sort for consistent order
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
logger.info(f"Found checkpoint roots: {sorted_paths}")
return sorted_paths
def get_model_roots(self) -> List[str]:
"""Get checkpoint root directories"""
return self._checkpoint_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all checkpoint directories and return metadata"""
all_checkpoints = []
# Create scan tasks for each directory
scan_tasks = []
for root in self._checkpoint_roots:
task = asyncio.create_task(self._scan_directory(root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
checkpoints = await task
all_checkpoints.extend(checkpoints)
except Exception as e:
logger.error(f"Error scanning checkpoint directory: {e}")
return all_checkpoints
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a directory for checkpoint files"""
checkpoints = []
original_root = root_path
async def scan_recursive(path: str, visited_paths: set):
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
# Check if file has supported extension
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, checkpoints)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return checkpoints
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
"""Process a single checkpoint file and add to results"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
checkpoints.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")

View File

@@ -7,6 +7,8 @@ from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from typing import List, Dict, Set
from threading import Lock
from .checkpoint_scanner import CheckpointScanner
from .lora_scanner import LoraScanner
from ..config import config
@@ -330,4 +332,90 @@ class LoraFileMonitor:
self.observer.schedule(self.handler, path, recursive=True)
logger.info(f"Added new monitoring path: {path}")
except Exception as e:
logger.error(f"Error adding new monitor for {path}: {e}")
logger.error(f"Error adding new monitor for {path}: {e}")
# Add CheckpointFileMonitor class
class CheckpointFileMonitor(LoraFileMonitor):
"""Monitor for checkpoint file changes"""
def __init__(self, scanner: CheckpointScanner, roots: List[str]):
# Reuse most of the LoraFileMonitor functionality, but with a different handler
self.scanner = scanner
scanner.set_file_monitor(self)
self.observer = Observer()
self.loop = asyncio.get_event_loop()
self.handler = CheckpointFileHandler(scanner, self.loop)
# Use existing path mappings
self.monitor_paths = set()
for root in roots:
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
# Add all mapped target paths
for target_path in config._path_mappings.keys():
self.monitor_paths.add(target_path)
class CheckpointFileHandler(LoraFileHandler):
"""Handler for checkpoint file system events"""
def __init__(self, scanner: CheckpointScanner, loop: asyncio.AbstractEventLoop):
super().__init__(scanner, loop)
# Configure supported file extensions
self.supported_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'}
def on_created(self, event):
if event.is_directory:
return
# Handle supported file extensions directly
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.supported_extensions:
if self._should_ignore(event.src_path):
return
# Process this file directly
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
if normalized_path not in self.scheduled_files:
logger.info(f"Checkpoint file created: {event.src_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.src_path)
# Ignore modifications for a short period after creation
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
normalized_path
)
def on_modified(self, event):
if event.is_directory:
return
# Only process supported file types
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.supported_extensions:
super().on_modified(event)
def on_deleted(self, event):
if event.is_directory:
return
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext not in self.supported_extensions:
return
super().on_deleted(event)
def on_moved(self, event):
"""Handle file move/rename events"""
src_ext = os.path.splitext(event.src_path)[1].lower()
dest_ext = os.path.splitext(event.dest_path)[1].lower()
# If destination has supported extension, treat as new file
if dest_ext in self.supported_extensions:
super().on_moved(event)
# If source was supported extension, treat as deleted
elif src_ext in self.supported_extensions:
super().on_moved(event)

View File

@@ -4,13 +4,11 @@ import logging
import asyncio
import shutil
import time
from typing import List, Dict, Optional
from typing import List, Dict, Optional, Set
from ..utils.models import LoraMetadata
from ..config import config
from ..utils.file_utils import load_metadata, get_file_info, normalize_path, find_preview_file, save_metadata
from ..utils.lora_metadata import extract_lora_metadata
from .lora_cache import LoraCache
from .model_scanner import ModelScanner
from .lora_hash_index import LoraHashIndex
from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS
@@ -19,7 +17,7 @@ import sys
logger = logging.getLogger(__name__)
class LoraScanner:
class LoraScanner(ModelScanner):
"""Service for scanning and managing LoRA files"""
_instance = None
@@ -31,20 +29,20 @@ class LoraScanner:
return cls._instance
def __init__(self):
# 确保初始化只执行一次
# Ensure initialization happens only once
if not hasattr(self, '_initialized'):
self._cache: Optional[LoraCache] = None
self._hash_index = LoraHashIndex()
self._initialization_lock = asyncio.Lock()
self._initialization_task: Optional[asyncio.Task] = None
# Define supported file extensions
file_extensions = {'.safetensors'}
# Initialize parent class
super().__init__(
model_type="lora",
model_class=LoraMetadata,
file_extensions=file_extensions,
hash_index=LoraHashIndex()
)
self._initialized = True
self.file_monitor = None # Add this line
self._tags_count = {} # Add a dictionary to store tag counts
def set_file_monitor(self, monitor):
"""Set file monitor instance"""
self.file_monitor = monitor
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
@@ -52,89 +50,74 @@ class LoraScanner:
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
"""Get cached LoRA data, refresh if needed"""
async with self._initialization_lock:
def get_model_roots(self) -> List[str]:
"""Get lora root directories"""
return config.loras_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all LoRA directories and return metadata"""
all_loras = []
# Create scan tasks for each directory
scan_tasks = []
for lora_root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(lora_root))
scan_tasks.append(task)
# 如果缓存未初始化但需要响应请求,返回空缓存
if self._cache is None and not force_refresh:
return LoraCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# 如果正在初始化,等待完成
if self._initialization_task and not self._initialization_task.done():
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
self._initialization_task = None
if (self._cache is None or force_refresh):
# Wait for all tasks to complete
for task in scan_tasks:
try:
loras = await task
all_loras.extend(loras)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
# 创建新的初始化任务
if not self._initialization_task or self._initialization_task.done():
self._initialization_task = asyncio.create_task(self._initialize_cache())
return all_loras
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for LoRA files"""
loras = []
original_root = root_path # Save original root path
async def scan_recursive(path: str, visited_paths: set):
"""Recursively scan directory, avoiding circular symlinks"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
# 如果缓存已存在,继续使用旧缓存
if self._cache is None:
raise # 如果没有缓存,则抛出异常
return self._cache
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
# Use original path instead of real path
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, loras)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache"""
await scan_recursive(root_path, set())
return loras
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
"""Process a single file and add to results list"""
try:
start_time = time.time()
# Clear existing hash index
self._hash_index.clear()
# Clear existing tags count
self._tags_count = {}
# Scan for new data
raw_data = await self.scan_all_loras()
# Build hash index and tags count
for lora_data in raw_data:
if 'sha256' in lora_data and 'file_path' in lora_data:
self._hash_index.add_entry(lora_data['sha256'].lower(), lora_data['file_path'])
# Count tags
if 'tags' in lora_data and lora_data['tags']:
for tag in lora_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Update cache
self._cache = LoraCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Call resort_cache to create sorted views
await self._cache.resort()
self._initialization_task = None
logger.info(f"LoRA Manager: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} loras")
result = await self._process_model_file(file_path, root_path)
if result:
loras.append(result)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing cache: {e}")
self._cache = LoraCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
logger.error(f"Error processing {file_path}: {e}")
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy_search: bool = False,
base_models: list = None, tags: list = None,
@@ -280,240 +263,14 @@ class LoraScanner:
return result
def invalidate_cache(self):
"""Invalidate the current cache"""
self._cache = None
async def scan_all_loras(self) -> List[Dict]:
"""Scan all LoRA directories and return metadata"""
all_loras = []
# 分目录异步扫描
scan_tasks = []
for loras_root in config.loras_roots:
task = asyncio.create_task(self._scan_directory(loras_root))
scan_tasks.append(task)
for task in scan_tasks:
try:
loras = await task
all_loras.extend(loras)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_loras
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for LoRA files"""
loras = []
original_root = root_path # 保存原始根路径
async def scan_recursive(path: str, visited_paths: set):
"""递归扫描目录,避免循环链接"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'):
# 使用原始路径而不是真实路径
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, loras)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# 对于目录,使用原始路径继续扫描
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return loras
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
"""处理单个文件并添加到结果列表"""
try:
result = await self._process_lora_file(file_path, root_path)
if result:
loras.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
async def _process_lora_file(self, file_path: str, root_path: str) -> Dict:
"""Process a single LoRA file and return its metadata"""
# Try loading existing metadata
metadata = await load_metadata(file_path)
if metadata is None:
# Try to find and use .civitai.info file first
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
if os.path.exists(civitai_info_path):
try:
with open(civitai_info_path, 'r', encoding='utf-8') as f:
version_info = json.load(f)
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if file_info:
# Create a minimal file_info with the required fields
file_name = os.path.splitext(os.path.basename(file_path))[0]
file_info['name'] = file_name
# Use from_civitai_info to create metadata
metadata = LoraMetadata.from_civitai_info(version_info, file_info, file_path)
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
await save_metadata(file_path, metadata)
logger.debug(f"Created metadata from .civitai.info for {file_path}")
except Exception as e:
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
# If still no metadata, create new metadata using get_file_info
if metadata is None:
metadata = await get_file_info(file_path)
# Convert to dict and add folder info
lora_data = metadata.to_dict()
# Try to fetch missing metadata from Civitai if needed
await self._fetch_missing_metadata(file_path, lora_data)
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
lora_data['folder'] = folder.replace(os.path.sep, '/')
return lora_data
async def _fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None:
"""Fetch missing description and tags from Civitai if needed
Args:
file_path: Path to the lora file
lora_data: Lora metadata dictionary to update
"""
try:
# Skip if already marked as deleted on Civitai
if lora_data.get('civitai_deleted', False):
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
return
# Check if we need to fetch additional metadata from Civitai
needs_metadata_update = False
model_id = None
# Check if we have Civitai model ID but missing metadata
if lora_data.get('civitai'):
# Try to get model ID directly from the correct location
model_id = lora_data['civitai'].get('modelId')
if model_id:
model_id = str(model_id)
# Check if tags are missing or empty
tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0
# Check if description is missing or empty
desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")
needs_metadata_update = tags_missing or desc_missing
# Fetch missing metadata if needed
if needs_metadata_update and model_id:
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
from ..services.civitai_client import CivitaiClient
client = CivitaiClient()
# Get metadata and status code
model_metadata, status_code = await client.get_model_metadata(model_id)
await client.close()
# Handle 404 status (model deleted from Civitai)
if status_code == 404:
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
# Mark as deleted to avoid future API calls
lora_data['civitai_deleted'] = True
# Save the updated metadata back to file
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(lora_data, f, indent=2, ensure_ascii=False)
# Process valid metadata if available
elif model_metadata:
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
# Update tags if they were missing
if model_metadata.get('tags') and (not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0):
lora_data['tags'] = model_metadata['tags']
# Update description if it was missing
if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")):
lora_data['modelDescription'] = model_metadata['description']
# Save the updated metadata back to file
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(lora_data, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
"""Update preview URL in cache for a specific lora
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
"""
if self._cache is None:
return False
return await self._cache.update_preview_url(file_path, preview_url)
async def scan_single_lora(self, file_path: str) -> Optional[Dict]:
"""Scan a single LoRA file and return its metadata"""
try:
if not os.path.exists(os.path.realpath(file_path)):
return None
# 获取基本文件信息
metadata = await get_file_info(file_path)
if not metadata:
return None
folder = self._calculate_folder(file_path)
# 确保 folder 字段存在
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = folder or ''
return metadata_dict
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None
def _calculate_folder(self, file_path: str) -> str:
"""Calculate the folder path for a LoRA file"""
# 使用原始路径计算相对路径
for root in config.loras_roots:
if file_path.startswith(root):
rel_path = os.path.relpath(file_path, root)
return os.path.dirname(rel_path).replace(os.path.sep, '/')
return ''
async def move_model(self, source_path: str, target_path: str) -> bool:
"""Move a model and its associated files to a new location"""
try:
# 保持原始路径格式
# Keep original path format
source_path = source_path.replace(os.sep, '/')
target_path = target_path.replace(os.sep, '/')
# 其余代码保持不变
# Rest of the code remains unchanged
base_name = os.path.splitext(os.path.basename(source_path))[0]
source_dir = os.path.dirname(source_path)
@@ -521,7 +278,7 @@ class LoraScanner:
target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/')
# 使用真实路径进行文件操作
# Use real paths for file operations
real_source = os.path.realpath(source_path)
real_target = os.path.realpath(target_lora)
@@ -537,7 +294,7 @@ class LoraScanner:
file_size
)
# 使用真实路径进行文件操作
# Use real paths for file operations
shutil.move(real_source, real_target)
# Move associated files
@@ -648,7 +405,7 @@ class LoraScanner:
except Exception as e:
logger.error(f"Error updating metadata paths: {e}", exc_info=True)
# Add new methods for hash index functionality
# Lora-specific hash index functionality
def has_lora_hash(self, sha256: str) -> bool:
"""Check if a LoRA with given hash exists"""
return self._hash_index.has_hash(sha256.lower())
@@ -681,16 +438,8 @@ class LoraScanner:
return None
# Add new method to get top tags
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get top tags sorted by count
Args:
limit: Maximum number of tags to return
Returns:
List of dictionaries with tag name and count, sorted by count
"""
"""Get top tags sorted by count"""
# Make sure cache is initialized
await self.get_cached_data()
@@ -705,14 +454,7 @@ class LoraScanner:
return sorted_tags[:limit]
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get base models used in loras sorted by frequency
Args:
limit: Maximum number of base models to return
Returns:
List of dictionaries with base model name and count, sorted by count
"""
"""Get base models used in loras sorted by frequency"""
# Make sure cache is initialized
cache = await self.get_cached_data()

View File

@@ -0,0 +1,64 @@
import asyncio
from typing import List, Dict
from dataclasses import dataclass
from operator import itemgetter
@dataclass
class ModelCache:
"""Cache structure for model data"""
raw_data: List[Dict]
sorted_by_name: List[Dict]
sorted_by_date: List[Dict]
folders: List[str]
def __post_init__(self):
self._lock = asyncio.Lock()
async def resort(self, name_only: bool = False):
"""Resort all cached data views"""
async with self._lock:
self.sorted_by_name = sorted(
self.raw_data,
key=lambda x: x['model_name'].lower() # Case-insensitive sort
)
if not name_only:
self.sorted_by_date = sorted(
self.raw_data,
key=itemgetter('modified'),
reverse=True
)
# Update folder list
all_folders = set(l['folder'] for l in self.raw_data)
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
"""Update preview_url for a specific model in all cached data
Args:
file_path: The file path of the model to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if the model wasn't found
"""
async with self._lock:
# Update in raw_data
for item in self.raw_data:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
else:
return False # Model not found
# Update in sorted lists (references to the same dict objects)
for item in self.sorted_by_name:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
for item in self.sorted_by_date:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
return True

View File

@@ -0,0 +1,78 @@
from typing import Dict, Optional, Set
class ModelHashIndex:
"""Index for looking up models by hash or path"""
def __init__(self):
self._hash_to_path: Dict[str, str] = {}
self._path_to_hash: Dict[str, str] = {}
def add_entry(self, sha256: str, file_path: str) -> None:
"""Add or update hash index entry"""
if not sha256 or not file_path:
return
# Ensure hash is lowercase for consistency
sha256 = sha256.lower()
# Remove old path mapping if hash exists
if sha256 in self._hash_to_path:
old_path = self._hash_to_path[sha256]
if old_path in self._path_to_hash:
del self._path_to_hash[old_path]
# Remove old hash mapping if path exists
if file_path in self._path_to_hash:
old_hash = self._path_to_hash[file_path]
if old_hash in self._hash_to_path:
del self._hash_to_path[old_hash]
# Add new mappings
self._hash_to_path[sha256] = file_path
self._path_to_hash[file_path] = sha256
def remove_by_path(self, file_path: str) -> None:
"""Remove entry by file path"""
if file_path in self._path_to_hash:
hash_val = self._path_to_hash[file_path]
if hash_val in self._hash_to_path:
del self._hash_to_path[hash_val]
del self._path_to_hash[file_path]
def remove_by_hash(self, sha256: str) -> None:
"""Remove entry by hash"""
sha256 = sha256.lower()
if sha256 in self._hash_to_path:
path = self._hash_to_path[sha256]
if path in self._path_to_hash:
del self._path_to_hash[path]
del self._hash_to_path[sha256]
def has_hash(self, sha256: str) -> bool:
"""Check if hash exists in index"""
return sha256.lower() in self._hash_to_path
def get_path(self, sha256: str) -> Optional[str]:
"""Get file path for a hash"""
return self._hash_to_path.get(sha256.lower())
def get_hash(self, file_path: str) -> Optional[str]:
"""Get hash for a file path"""
return self._path_to_hash.get(file_path)
def clear(self) -> None:
"""Clear all entries"""
self._hash_to_path.clear()
self._path_to_hash.clear()
def get_all_hashes(self) -> Set[str]:
"""Get all hashes in the index"""
return set(self._hash_to_path.keys())
def get_all_paths(self) -> Set[str]:
"""Get all file paths in the index"""
return set(self._path_to_hash.keys())
def __len__(self) -> int:
"""Get number of entries"""
return len(self._hash_to_path)

View File

@@ -0,0 +1,554 @@
import json
import os
import logging
import asyncio
import time
import shutil
from typing import List, Dict, Optional, Type, Set
from ..utils.models import BaseModelMetadata
from ..config import config
from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, save_metadata
from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex
logger = logging.getLogger(__name__)
class ModelScanner:
"""Base service for scanning and managing model files"""
_instance = None
_lock = asyncio.Lock()
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
"""Initialize the scanner
Args:
model_type: Type of model (lora, checkpoint, etc.)
model_class: Class used to create metadata instances
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
hash_index: Hash index instance (optional)
"""
self.model_type = model_type
self.model_class = model_class
self.file_extensions = file_extensions
self._cache = None
self._hash_index = hash_index or ModelHashIndex()
self._initialization_lock = asyncio.Lock()
self._initialization_task = None
self.file_monitor = None
self._tags_count = {} # Dictionary to store tag counts
def set_file_monitor(self, monitor):
"""Set file monitor instance"""
self.file_monitor = monitor
async def get_cached_data(self, force_refresh: bool = False) -> ModelCache:
"""Get cached model data, refresh if needed"""
async with self._initialization_lock:
# Return empty cache if not initialized and no refresh requested
if self._cache is None and not force_refresh:
return ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Wait for ongoing initialization if any
if self._initialization_task and not self._initialization_task.done():
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
self._initialization_task = None
if (self._cache is None or force_refresh):
# Create new initialization task
if not self._initialization_task or self._initialization_task.done():
self._initialization_task = asyncio.create_task(self._initialize_cache())
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
# Continue using old cache if it exists
if self._cache is None:
raise # Raise exception if no cache available
return self._cache
async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache"""
try:
start_time = time.time()
# Clear existing hash index
self._hash_index.clear()
# Clear existing tags count
self._tags_count = {}
# Scan for new data
raw_data = await self.scan_all_models()
# Build hash index and tags count
for model_data in raw_data:
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Count tags
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Update cache
self._cache = ModelCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Resort cache
await self._cache.resort()
self._initialization_task = None
logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
self._cache = ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# These methods should be implemented in child classes
async def scan_all_models(self) -> List[Dict]:
"""Scan all model directories and return metadata"""
raise NotImplementedError("Subclasses must implement scan_all_models")
def get_model_roots(self) -> List[str]:
"""Get model root directories"""
raise NotImplementedError("Subclasses must implement get_model_roots")
async def scan_single_model(self, file_path: str) -> Optional[Dict]:
"""Scan a single model file and return its metadata"""
try:
if not os.path.exists(os.path.realpath(file_path)):
return None
# Get basic file info
metadata = await self._get_file_info(file_path)
if not metadata:
return None
folder = self._calculate_folder(file_path)
# Ensure folder field exists
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = folder or ''
return metadata_dict
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None
async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
"""Get model file info and metadata (extensible for different model types)"""
# Implementation may vary by model type - override in subclasses if needed
return await get_file_info(file_path, self.model_class)
def _calculate_folder(self, file_path: str) -> str:
"""Calculate the folder path for a model file"""
# Use original path to calculate relative path
for root in self.get_model_roots():
if file_path.startswith(root):
rel_path = os.path.relpath(file_path, root)
return os.path.dirname(rel_path).replace(os.path.sep, '/')
return ''
# Common methods shared between scanners
async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
"""Process a single model file and return its metadata"""
# Try loading existing metadata
metadata = await load_metadata(file_path, self.model_class)
if metadata is None:
# Try to find and use .civitai.info file first
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
if os.path.exists(civitai_info_path):
try:
with open(civitai_info_path, 'r', encoding='utf-8') as f:
version_info = json.load(f)
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if file_info:
# Create a minimal file_info with the required fields
file_name = os.path.splitext(os.path.basename(file_path))[0]
file_info['name'] = file_name
# Use from_civitai_info to create metadata
metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
await save_metadata(file_path, metadata)
logger.debug(f"Created metadata from .civitai.info for {file_path}")
except Exception as e:
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
# If still no metadata, create new metadata
if metadata is None:
metadata = await self._get_file_info(file_path)
# Convert to dict and add folder info
model_data = metadata.to_dict()
# Try to fetch missing metadata from Civitai if needed
await self._fetch_missing_metadata(file_path, model_data)
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
model_data['folder'] = folder.replace(os.path.sep, '/')
return model_data
async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None:
"""Fetch missing description and tags from Civitai if needed"""
try:
# Skip if already marked as deleted on Civitai
if model_data.get('civitai_deleted', False):
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
return
# Check if we need to fetch additional metadata from Civitai
needs_metadata_update = False
model_id = None
# Check if we have Civitai model ID but missing metadata
if model_data.get('civitai'):
model_id = model_data['civitai'].get('modelId')
if model_id:
model_id = str(model_id)
# Check if tags or description are missing
tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0
desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")
needs_metadata_update = tags_missing or desc_missing
# Fetch missing metadata if needed
if needs_metadata_update and model_id:
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
from ..services.civitai_client import CivitaiClient
client = CivitaiClient()
# Get metadata and status code
model_metadata, status_code = await client.get_model_metadata(model_id)
await client.close()
# Handle 404 status (model deleted from Civitai)
if status_code == 404:
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
model_data['civitai_deleted'] = True
# Save the updated metadata
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(model_data, f, indent=2, ensure_ascii=False)
# Process valid metadata if available
elif model_metadata:
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
# Update tags if they were missing
if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0):
model_data['tags'] = model_metadata['tags']
# Update description if it was missing
if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")):
model_data['modelDescription'] = model_metadata['description']
# Save the updated metadata
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(model_data, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Base implementation for directory scanning"""
models = []
original_root = root_path
async def scan_recursive(path: str, visited_paths: set):
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
# Check if file has supported extension
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, models)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return models
async def _process_single_file(self, file_path: str, root_path: str, models_list: list):
"""Process a single file and add to results list"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
models_list.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
async def move_model(self, source_path: str, target_path: str) -> bool:
"""Move a model and its associated files to a new location"""
try:
# Keep original path format
source_path = source_path.replace(os.sep, '/')
target_path = target_path.replace(os.sep, '/')
# Get file extension from source
file_ext = os.path.splitext(source_path)[1]
# If no extension or not in supported extensions, return False
if not file_ext or file_ext.lower() not in self.file_extensions:
logger.error(f"Invalid file extension for model: {file_ext}")
return False
base_name = os.path.splitext(os.path.basename(source_path))[0]
source_dir = os.path.dirname(source_path)
os.makedirs(target_path, exist_ok=True)
target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/')
# Use real paths for file operations
real_source = os.path.realpath(source_path)
real_target = os.path.realpath(target_file)
file_size = os.path.getsize(real_source)
if self.file_monitor:
self.file_monitor.handler.add_ignore_path(
real_source,
file_size
)
self.file_monitor.handler.add_ignore_path(
real_target,
file_size
)
# Use real paths for file operations
shutil.move(real_source, real_target)
# Move associated files
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
metadata = None
if os.path.exists(source_metadata):
target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
shutil.move(source_metadata, target_metadata)
metadata = await self._update_metadata_paths(target_metadata, target_file)
# Move preview file if exists
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions:
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
if os.path.exists(source_preview):
target_preview = os.path.join(target_path, f"{base_name}{ext}")
shutil.move(source_preview, target_preview)
break
# Update cache
await self.update_single_model_cache(source_path, target_file, metadata)
return True
except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True)
return False
async def _update_metadata_paths(self, metadata_path: str, model_path: str) -> Dict:
"""Update file paths in metadata file"""
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
# Update file_path
metadata['file_path'] = model_path.replace(os.sep, '/')
# Update preview_url if exists
if 'preview_url' in metadata:
preview_dir = os.path.dirname(model_path)
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
preview_ext = os.path.splitext(metadata['preview_url'])[1]
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
return metadata
except Exception as e:
logger.error(f"Error updating metadata paths: {e}", exc_info=True)
return None
async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
"""Update cache after a model has been moved or modified"""
cache = await self.get_cached_data()
# Find the existing item to remove its tags from count
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []):
if tag in self._tags_count:
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
if self._tags_count[tag] == 0:
del self._tags_count[tag]
# Remove old path from hash index if exists
self._hash_index.remove_by_path(original_path)
# Remove the old entry from raw_data
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != original_path
]
if metadata:
# If this is an update to an existing path (not a move), ensure folder is preserved
if original_path == new_path:
# Find the folder from existing entries or calculate it
existing_folder = next((item['folder'] for item in cache.raw_data
if item['file_path'] == original_path), None)
if existing_folder:
metadata['folder'] = existing_folder
else:
metadata['folder'] = self._calculate_folder(new_path)
else:
# For moved files, recalculate the folder
metadata['folder'] = self._calculate_folder(new_path)
# Add the updated metadata to raw_data
cache.raw_data.append(metadata)
# Update hash index with new path
if 'sha256' in metadata:
self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
# Update folders list
all_folders = set(item['folder'] for item in cache.raw_data)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update tags count with the new/updated tags
if 'tags' in metadata:
for tag in metadata.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Resort cache
await cache.resort()
return True
# Hash index functionality (common for all model types)
def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists"""
return self._hash_index.has_hash(sha256.lower())
def get_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a model by its hash"""
return self._hash_index.get_path(sha256.lower())
def get_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a model by its file path"""
return self._hash_index.get_hash(file_path)
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
"""Get preview static URL for a model by its hash"""
# Get the file path first
file_path = self._hash_index.get_path(sha256.lower())
if not file_path:
return None
# Determine the preview file path (typically same name with different extension)
base_name = os.path.splitext(file_path)[0]
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions:
preview_path = f"{base_name}{ext}"
if os.path.exists(preview_path):
# Convert to static URL using config
return config.get_preview_static_url(preview_path)
return None
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get top tags sorted by count"""
# Make sure cache is initialized
await self.get_cached_data()
# Sort tags by count in descending order
sorted_tags = sorted(
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
key=lambda x: x['count'],
reverse=True
)
# Return limited number
return sorted_tags[:limit]
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get base models sorted by frequency"""
# Make sure cache is initialized
cache = await self.get_cached_data()
# Count base model occurrences
base_model_counts = {}
for model in cache.raw_data:
if 'base_model' in model and model['base_model']:
base_model = model['base_model']
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
# Sort base models by count
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
sorted_models.sort(key=lambda x: x['count'], reverse=True)
# Return limited number
return sorted_models[:limit]
async def get_model_info_by_name(self, name):
"""Get model information by name"""
try:
# Get cached data
cache = await self.get_cached_data()
# Find the model by name
for model in cache.raw_data:
if model.get("file_name") == name:
return model
return None
except Exception as e:
logger.error(f"Error getting model info by name: {e}", exc_info=True)
return None

View File

@@ -2,12 +2,12 @@ import logging
import os
import hashlib
import json
from typing import Dict, Optional
import time
from typing import Dict, Optional, Type
from .model_utils import determine_base_model
from .lora_metadata import extract_lora_metadata
from .models import LoraMetadata
from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata
from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata
logger = logging.getLogger(__name__)
@@ -15,7 +15,7 @@ async def calculate_sha256(file_path: str) -> str:
"""Calculate SHA256 hash of a file"""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(4096), b""):
for byte_block in iter(lambda: f.read(128 * 1024), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
@@ -42,8 +42,8 @@ def normalize_path(path: str) -> str:
"""Normalize file path to use forward slashes"""
return path.replace(os.sep, "/") if path else path
async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
"""Get basic file information as LoraMetadata object"""
async def get_file_info(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Get basic file information as a model metadata object"""
# First check if file actually exists and resolve symlinks
try:
real_path = os.path.realpath(file_path)
@@ -74,27 +74,52 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
try:
# If we didn't get SHA256 from the .json file, calculate it
if not sha256:
start_time = time.time()
sha256 = await calculate_sha256(real_path)
logger.debug(f"Calculated SHA256 for {file_path} in {time.time() - start_time:.2f} seconds")
# Create default metadata based on model class
if model_class == CheckpointMetadata:
metadata = CheckpointMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
preview_url=normalize_path(preview_url),
tags=[],
modelDescription="",
model_type="checkpoint"
)
metadata = LoraMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
usage_tips="",
notes="",
from_civitai=True,
preview_url=normalize_path(preview_url),
tags=[],
modelDescription=""
)
# Extract checkpoint-specific metadata
# model_info = await extract_checkpoint_metadata(real_path)
# metadata.base_model = model_info['base_model']
# if 'model_type' in model_info:
# metadata.model_type = model_info['model_type']
else: # Default to LoraMetadata
metadata = LoraMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
usage_tips="{}",
preview_url=normalize_path(preview_url),
tags=[],
modelDescription=""
)
# Extract lora-specific metadata
model_info = await extract_lora_metadata(real_path)
metadata.base_model = model_info['base_model']
# create metadata file
base_model_info = await extract_lora_metadata(real_path)
metadata.base_model = base_model_info['base_model']
# Save metadata to file
await save_metadata(file_path, metadata)
return metadata
@@ -102,7 +127,7 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
logger.error(f"Error getting file info for {file_path}: {e}")
return None
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
async def save_metadata(file_path: str, metadata: BaseModelMetadata) -> None:
"""Save metadata to .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try:
@@ -115,7 +140,7 @@ async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
except Exception as e:
print(f"Error saving metadata to {metadata_path}: {str(e)}")
async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Load metadata from .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try:
@@ -162,12 +187,22 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
if 'modelDescription' not in data:
data['modelDescription'] = ""
needs_update = True
# For checkpoint metadata
if model_class == CheckpointMetadata and 'model_type' not in data:
data['model_type'] = "checkpoint"
needs_update = True
# For lora metadata
if model_class == LoraMetadata and 'usage_tips' not in data:
data['usage_tips'] = "{}"
needs_update = True
if needs_update:
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False)
return LoraMetadata.from_dict(data)
return model_class.from_dict(data)
except Exception as e:
print(f"Error loading metadata from {metadata_path}: {str(e)}")

View File

@@ -1,6 +1,7 @@
from safetensors import safe_open
from typing import Dict
from .model_utils import determine_base_model
import os
async def extract_lora_metadata(file_path: str) -> Dict:
"""Extract essential metadata from safetensors file"""
@@ -13,4 +14,67 @@ async def extract_lora_metadata(file_path: str) -> Dict:
return {"base_model": base_model}
except Exception as e:
print(f"Error reading metadata from {file_path}: {str(e)}")
return {"base_model": "Unknown"}
return {"base_model": "Unknown"}
async def extract_checkpoint_metadata(file_path: str) -> dict:
"""Extract metadata from a checkpoint file to determine model type and base model"""
try:
# Analyze filename for clues about the model
filename = os.path.basename(file_path).lower()
model_info = {
'base_model': 'Unknown',
'model_type': 'checkpoint'
}
# Detect base model from filename
if 'xl' in filename or 'sdxl' in filename:
model_info['base_model'] = 'SDXL'
elif 'sd3' in filename:
model_info['base_model'] = 'SD3'
elif 'sd2' in filename or 'v2' in filename:
model_info['base_model'] = 'SD2.x'
elif 'sd1' in filename or 'v1' in filename:
model_info['base_model'] = 'SD1.5'
# Detect model type from filename
if 'inpaint' in filename:
model_info['model_type'] = 'inpainting'
elif 'anime' in filename:
model_info['model_type'] = 'anime'
elif 'realistic' in filename:
model_info['model_type'] = 'realistic'
# Try to peek at the safetensors file structure if available
if file_path.endswith('.safetensors'):
import json
import struct
with open(file_path, 'rb') as f:
header_size = struct.unpack('<Q', f.read(8))[0]
header_json = f.read(header_size)
header = json.loads(header_json)
# Look for specific keys to identify model type
metadata = header.get('__metadata__', {})
if metadata:
# Try to determine if it's SDXL
if any(key.startswith('conditioner.embedders.1') for key in header):
model_info['base_model'] = 'SDXL'
# Look for model type info
if metadata.get('modelspec.architecture') == 'SD-XL':
model_info['base_model'] = 'SDXL'
elif metadata.get('modelspec.architecture') == 'SD-3':
model_info['base_model'] = 'SD3'
# Check for specific use case
if metadata.get('modelspec.purpose') == 'inpainting':
model_info['model_type'] = 'inpainting'
return model_info
except Exception as e:
logger.error(f"Error extracting checkpoint metadata for {file_path}: {e}")
# Return default values
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}

View File

@@ -5,20 +5,19 @@ import os
from .model_utils import determine_base_model
@dataclass
class LoraMetadata:
"""Represents the metadata structure for a Lora model"""
file_name: str # The filename without extension of the lora
model_name: str # The lora's name defined by the creator, initially same as file_name
file_path: str # Full path to the safetensors file
class BaseModelMetadata:
"""Base class for all model metadata structures"""
file_name: str # The filename without extension
model_name: str # The model's name defined by the creator
file_path: str # Full path to the model file
size: int # File size in bytes
modified: float # Last modified timestamp
sha256: str # SHA256 hash of the file
base_model: str # Base model (SD1.5/SD2.1/SDXL/etc.)
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
preview_url: str # Preview image URL
preview_nsfw_level: int = 0 # NSFW level of the preview image
usage_tips: str = "{}" # Usage tips for the model, json string
notes: str = "" # Additional notes
from_civitai: bool = True # Whether the lora is from Civitai
from_civitai: bool = True # Whether from Civitai
civitai: Optional[Dict] = None # Civitai API data if available
tags: List[str] = None # Model tags
modelDescription: str = "" # Full model description
@@ -29,32 +28,11 @@ class LoraMetadata:
self.tags = []
@classmethod
def from_dict(cls, data: Dict) -> 'LoraMetadata':
"""Create LoraMetadata instance from dictionary"""
# Create a copy of the data to avoid modifying the input
def from_dict(cls, data: Dict) -> 'BaseModelMetadata':
"""Create instance from dictionary"""
data_copy = data.copy()
return cls(**data_copy)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
"""Create LoraMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0, # Will be updated after preview download, it is decided by the nsfw level of the preview image
from_civitai=True,
civitai=version_info
)
def to_dict(self) -> Dict:
"""Convert to dictionary for JSON serialization"""
return asdict(self)
@@ -76,30 +54,54 @@ class LoraMetadata:
self.file_path = file_path.replace(os.sep, '/')
@dataclass
class CheckpointMetadata:
"""Represents the metadata structure for a Checkpoint model"""
file_name: str # The filename without extension
model_name: str # The checkpoint's name defined by the creator
file_path: str # Full path to the model file
size: int # File size in bytes
modified: float # Last modified timestamp
sha256: str # SHA256 hash of the file
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
preview_url: str # Preview image URL
preview_nsfw_level: int = 0 # NSFW level of the preview image
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
notes: str = "" # Additional notes
from_civitai: bool = True # Whether from Civitai
civitai: Optional[Dict] = None # Civitai API data if available
tags: List[str] = None # Model tags
modelDescription: str = "" # Full model description
# Additional checkpoint-specific fields
resolution: Optional[str] = None # Native resolution (e.g., 512x512, 1024x1024)
vae_included: bool = False # Whether VAE is included in the checkpoint
architecture: str = "" # Model architecture (if known)
def __post_init__(self):
if self.tags is None:
self.tags = []
class LoraMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Lora model"""
usage_tips: str = "{}" # Usage tips for the model, json string
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
"""Create LoraMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0, # Will be updated after preview download
from_civitai=True,
civitai=version_info
)
@dataclass
class CheckpointMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Checkpoint model"""
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
"""Create CheckpointMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
model_type = version_info.get('type', 'checkpoint')
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0,
from_civitai=True,
civitai=version_info,
model_type=model_type
)