mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
checkpoint
This commit is contained in:
@@ -73,7 +73,7 @@ class Config:
|
|||||||
"""添加静态路由映射"""
|
"""添加静态路由映射"""
|
||||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||||
self._route_mappings[normalized_path] = route
|
self._route_mappings[normalized_path] = route
|
||||||
logger.info(f"Added route mapping: {normalized_path} -> {route}")
|
# logger.info(f"Added route mapping: {normalized_path} -> {route}")
|
||||||
|
|
||||||
def map_path_to_link(self, path: str) -> str:
|
def map_path_to_link(self, path: str) -> str:
|
||||||
"""将目标路径映射回符号链接路径"""
|
"""将目标路径映射回符号链接路径"""
|
||||||
|
|||||||
@@ -7,10 +7,12 @@ from .routes.api_routes import ApiRoutes
|
|||||||
from .routes.recipe_routes import RecipeRoutes
|
from .routes.recipe_routes import RecipeRoutes
|
||||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||||
from .services.lora_scanner import LoraScanner
|
from .services.lora_scanner import LoraScanner
|
||||||
|
from .services.checkpoint_scanner import CheckpointScanner
|
||||||
from .services.recipe_scanner import RecipeScanner
|
from .services.recipe_scanner import RecipeScanner
|
||||||
from .services.file_monitor import LoraFileMonitor
|
from .services.file_monitor import LoraFileMonitor, CheckpointFileMonitor
|
||||||
from .services.lora_cache import LoraCache
|
from .services.lora_cache import LoraCache
|
||||||
from .services.recipe_cache import RecipeCache
|
from .services.recipe_cache import RecipeCache
|
||||||
|
from .services.model_cache import ModelCache
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -23,7 +25,7 @@ class LoraManager:
|
|||||||
"""Initialize and register all routes"""
|
"""Initialize and register all routes"""
|
||||||
app = PromptServer.instance.app
|
app = PromptServer.instance.app
|
||||||
|
|
||||||
added_targets = set() # 用于跟踪已添加的目标路径
|
added_targets = set() # Track already added target paths
|
||||||
|
|
||||||
# Add static routes for each lora root
|
# Add static routes for each lora root
|
||||||
for idx, root in enumerate(config.loras_roots, start=1):
|
for idx, root in enumerate(config.loras_roots, start=1):
|
||||||
@@ -35,15 +37,34 @@ class LoraManager:
|
|||||||
if link == root:
|
if link == root:
|
||||||
real_root = target
|
real_root = target
|
||||||
break
|
break
|
||||||
# 为原始路径添加静态路由
|
# Add static route for original path
|
||||||
app.router.add_static(preview_path, real_root)
|
app.router.add_static(preview_path, real_root)
|
||||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||||
|
|
||||||
# 记录路由映射
|
# Record route mapping
|
||||||
config.add_route_mapping(real_root, preview_path)
|
config.add_route_mapping(real_root, preview_path)
|
||||||
added_targets.add(real_root)
|
added_targets.add(real_root)
|
||||||
|
|
||||||
# 为符号链接的目标路径添加额外的静态路由
|
# Add static routes for each checkpoint root
|
||||||
|
checkpoint_scanner = CheckpointScanner()
|
||||||
|
for idx, root in enumerate(checkpoint_scanner.get_model_roots(), start=1):
|
||||||
|
preview_path = f'/checkpoints_static/root{idx}/preview'
|
||||||
|
|
||||||
|
real_root = root
|
||||||
|
if root in config._path_mappings.values():
|
||||||
|
for target, link in config._path_mappings.items():
|
||||||
|
if link == root:
|
||||||
|
real_root = target
|
||||||
|
break
|
||||||
|
# Add static route for original path
|
||||||
|
app.router.add_static(preview_path, real_root)
|
||||||
|
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||||
|
|
||||||
|
# Record route mapping
|
||||||
|
config.add_route_mapping(real_root, preview_path)
|
||||||
|
added_targets.add(real_root)
|
||||||
|
|
||||||
|
# Add static routes for symlink target paths
|
||||||
link_idx = 1
|
link_idx = 1
|
||||||
|
|
||||||
for target_path, link_path in config._path_mappings.items():
|
for target_path, link_path in config._path_mappings.items():
|
||||||
@@ -59,37 +80,47 @@ class LoraManager:
|
|||||||
app.router.add_static('/loras_static', config.static_path)
|
app.router.add_static('/loras_static', config.static_path)
|
||||||
|
|
||||||
# Setup feature routes
|
# Setup feature routes
|
||||||
routes = LoraRoutes()
|
lora_routes = LoraRoutes()
|
||||||
checkpoints_routes = CheckpointsRoutes()
|
checkpoints_routes = CheckpointsRoutes()
|
||||||
|
|
||||||
# Setup file monitoring
|
# Setup file monitoring
|
||||||
monitor = LoraFileMonitor(routes.scanner, config.loras_roots)
|
lora_monitor = LoraFileMonitor(lora_routes.scanner, config.loras_roots)
|
||||||
monitor.start()
|
lora_monitor.start()
|
||||||
|
|
||||||
routes.setup_routes(app)
|
checkpoint_monitor = CheckpointFileMonitor(checkpoints_routes.scanner, checkpoints_routes.scanner.get_model_roots())
|
||||||
|
checkpoint_monitor.start()
|
||||||
|
|
||||||
|
lora_routes.setup_routes(app)
|
||||||
checkpoints_routes.setup_routes(app)
|
checkpoints_routes.setup_routes(app)
|
||||||
ApiRoutes.setup_routes(app, monitor)
|
ApiRoutes.setup_routes(app, lora_monitor)
|
||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
|
|
||||||
# Store monitor in app for cleanup
|
# Store monitors in app for cleanup
|
||||||
app['lora_monitor'] = monitor
|
app['lora_monitor'] = lora_monitor
|
||||||
|
app['checkpoint_monitor'] = checkpoint_monitor
|
||||||
|
|
||||||
|
logger.info("PromptServer app: ", app)
|
||||||
|
|
||||||
# Schedule cache initialization using the application's startup handler
|
# Schedule cache initialization using the application's startup handler
|
||||||
app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner, routes.recipe_scanner))
|
app.on_startup.append(lambda app: cls._schedule_cache_init(
|
||||||
|
lora_routes.scanner,
|
||||||
|
checkpoints_routes.scanner,
|
||||||
|
lora_routes.recipe_scanner
|
||||||
|
))
|
||||||
|
|
||||||
# Add cleanup
|
# Add cleanup
|
||||||
app.on_shutdown.append(cls._cleanup)
|
app.on_shutdown.append(cls._cleanup)
|
||||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _schedule_cache_init(cls, scanner: LoraScanner, recipe_scanner: RecipeScanner):
|
async def _schedule_cache_init(cls, lora_scanner, checkpoint_scanner, recipe_scanner):
|
||||||
"""Schedule cache initialization in the running event loop"""
|
"""Schedule cache initialization in the running event loop"""
|
||||||
try:
|
try:
|
||||||
# 创建低优先级的初始化任务
|
# Create low-priority initialization tasks
|
||||||
lora_task = asyncio.create_task(cls._initialize_lora_cache(scanner), name='lora_cache_init')
|
lora_task = asyncio.create_task(cls._initialize_lora_cache(lora_scanner), name='lora_cache_init')
|
||||||
|
checkpoint_task = asyncio.create_task(cls._initialize_checkpoint_cache(checkpoint_scanner), name='checkpoint_cache_init')
|
||||||
# Schedule recipe cache initialization with a delay to let lora scanner initialize first
|
recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner), name='recipe_cache_init')
|
||||||
recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner, delay=2), name='recipe_cache_init')
|
logger.info("Cache initialization tasks scheduled to run in background")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}")
|
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}")
|
||||||
|
|
||||||
@@ -97,26 +128,45 @@ class LoraManager:
|
|||||||
async def _initialize_lora_cache(cls, scanner: LoraScanner):
|
async def _initialize_lora_cache(cls, scanner: LoraScanner):
|
||||||
"""Initialize lora cache in background"""
|
"""Initialize lora cache in background"""
|
||||||
try:
|
try:
|
||||||
# 设置初始缓存占位
|
# Set initial placeholder cache
|
||||||
scanner._cache = LoraCache(
|
scanner._cache = LoraCache(
|
||||||
raw_data=[],
|
raw_data=[],
|
||||||
sorted_by_name=[],
|
sorted_by_name=[],
|
||||||
sorted_by_date=[],
|
sorted_by_date=[],
|
||||||
folders=[]
|
folders=[]
|
||||||
)
|
)
|
||||||
|
# 使用线程池执行耗时操作
|
||||||
# 分阶段加载缓存
|
loop = asyncio.get_event_loop()
|
||||||
await scanner.get_cached_data(force_refresh=True)
|
await loop.run_in_executor(
|
||||||
|
None, # 使用默认线程池
|
||||||
|
lambda: scanner.get_cached_data_sync(force_refresh=True) # 创建同步版本的方法
|
||||||
|
)
|
||||||
|
# Load cache in phases
|
||||||
|
# await scanner.get_cached_data(force_refresh=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LoRA Manager: Error initializing lora cache: {e}")
|
logger.error(f"LoRA Manager: Error initializing lora cache: {e}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _initialize_recipe_cache(cls, scanner: RecipeScanner, delay: float = 2.0):
|
async def _initialize_checkpoint_cache(cls, scanner: CheckpointScanner):
|
||||||
|
"""Initialize checkpoint cache in background"""
|
||||||
|
try:
|
||||||
|
# Set initial placeholder cache
|
||||||
|
scanner._cache = ModelCache(
|
||||||
|
raw_data=[],
|
||||||
|
sorted_by_name=[],
|
||||||
|
sorted_by_date=[],
|
||||||
|
folders=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load cache in phases
|
||||||
|
await scanner.get_cached_data(force_refresh=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LoRA Manager: Error initializing checkpoint cache: {e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _initialize_recipe_cache(cls, scanner: RecipeScanner):
|
||||||
"""Initialize recipe cache in background with a delay"""
|
"""Initialize recipe cache in background with a delay"""
|
||||||
try:
|
try:
|
||||||
# Wait for the specified delay to let lora scanner initialize first
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
# Set initial empty cache
|
# Set initial empty cache
|
||||||
scanner._cache = RecipeCache(
|
scanner._cache = RecipeCache(
|
||||||
raw_data=[],
|
raw_data=[],
|
||||||
@@ -134,3 +184,6 @@ class LoraManager:
|
|||||||
"""Cleanup resources"""
|
"""Cleanup resources"""
|
||||||
if 'lora_monitor' in app:
|
if 'lora_monitor' in app:
|
||||||
app['lora_monitor'].stop()
|
app['lora_monitor'].stop()
|
||||||
|
|
||||||
|
if 'checkpoint_monitor' in app:
|
||||||
|
app['checkpoint_monitor'].stop()
|
||||||
|
|||||||
@@ -1,44 +1,146 @@
|
|||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import jinja2
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from ..services.checkpoint_scanner import CheckpointScanner
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..services.settings_manager import settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
|
|
||||||
|
|
||||||
class CheckpointsRoutes:
|
class CheckpointsRoutes:
|
||||||
"""Route handlers for Checkpoints management endpoints"""
|
"""API routes for checkpoint management"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.template_env = jinja2.Environment(
|
self.scanner = CheckpointScanner()
|
||||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
|
||||||
autoescape=True
|
|
||||||
)
|
|
||||||
|
|
||||||
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
|
def setup_routes(self, app):
|
||||||
"""Handle GET /checkpoints request"""
|
"""Register routes with the aiohttp app"""
|
||||||
|
app.router.add_get('/lora_manager/api/checkpoints', self.get_checkpoints)
|
||||||
|
app.router.add_get('/lora_manager/api/checkpoints/scan', self.scan_checkpoints)
|
||||||
|
app.router.add_get('/lora_manager/api/checkpoints/info/{name}', self.get_checkpoint_info)
|
||||||
|
|
||||||
|
async def get_checkpoints(self, request):
|
||||||
|
"""Get paginated checkpoint data"""
|
||||||
try:
|
try:
|
||||||
template = self.template_env.get_template('checkpoints.html')
|
# Parse query parameters
|
||||||
rendered = template.render(
|
page = int(request.query.get('page', '1'))
|
||||||
is_initializing=False,
|
page_size = min(int(request.query.get('page_size', '20')), 100)
|
||||||
settings=settings,
|
sort_by = request.query.get('sort', 'name')
|
||||||
request=request
|
folder = request.query.get('folder', None)
|
||||||
|
search = request.query.get('search', None)
|
||||||
|
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
||||||
|
base_models = request.query.getall('base_model', [])
|
||||||
|
tags = request.query.getall('tag', [])
|
||||||
|
|
||||||
|
# Process search options
|
||||||
|
search_options = {
|
||||||
|
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
||||||
|
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
||||||
|
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
||||||
|
'recursive': request.query.get('recursive', 'false').lower() == 'true',
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process hash filters if provided
|
||||||
|
hash_filters = {}
|
||||||
|
if 'hash' in request.query:
|
||||||
|
hash_filters['single_hash'] = request.query['hash']
|
||||||
|
elif 'hashes' in request.query:
|
||||||
|
try:
|
||||||
|
hash_list = json.loads(request.query['hashes'])
|
||||||
|
if isinstance(hash_list, list):
|
||||||
|
hash_filters['multiple_hashes'] = hash_list
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Get data from scanner
|
||||||
|
result = await self.get_paginated_data(
|
||||||
|
page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
sort_by=sort_by,
|
||||||
|
folder=folder,
|
||||||
|
search=search,
|
||||||
|
fuzzy_search=fuzzy_search,
|
||||||
|
base_models=base_models,
|
||||||
|
tags=tags,
|
||||||
|
search_options=search_options,
|
||||||
|
hash_filters=hash_filters
|
||||||
)
|
)
|
||||||
|
|
||||||
return web.Response(
|
# Return as JSON
|
||||||
text=rendered,
|
return web.json_response(result)
|
||||||
content_type='text/html'
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
|
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
|
||||||
return web.Response(
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
text="Error loading checkpoints page",
|
|
||||||
status=500
|
|
||||||
)
|
|
||||||
|
|
||||||
def setup_routes(self, app: web.Application):
|
async def get_paginated_data(self, page, page_size, sort_by='name',
|
||||||
"""Register routes with the application"""
|
folder=None, search=None, fuzzy_search=False,
|
||||||
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
base_models=None, tags=None,
|
||||||
|
search_options=None, hash_filters=None):
|
||||||
|
"""Get paginated and filtered checkpoint data"""
|
||||||
|
cache = await self.scanner.get_cached_data()
|
||||||
|
|
||||||
|
# Implement similar filtering logic as in LoraScanner
|
||||||
|
# (Adapt code from LoraScanner.get_paginated_data)
|
||||||
|
# ...
|
||||||
|
|
||||||
|
# For now, a simplified implementation:
|
||||||
|
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||||
|
|
||||||
|
# Apply basic folder filtering if needed
|
||||||
|
if folder is not None:
|
||||||
|
filtered_data = [
|
||||||
|
cp for cp in filtered_data
|
||||||
|
if cp['folder'] == folder
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply basic search if needed
|
||||||
|
if search:
|
||||||
|
filtered_data = [
|
||||||
|
cp for cp in filtered_data
|
||||||
|
if search.lower() in cp['file_name'].lower() or
|
||||||
|
search.lower() in cp['model_name'].lower()
|
||||||
|
]
|
||||||
|
|
||||||
|
# Calculate pagination
|
||||||
|
total_items = len(filtered_data)
|
||||||
|
start_idx = (page - 1) * page_size
|
||||||
|
end_idx = min(start_idx + page_size, total_items)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
'items': filtered_data[start_idx:end_idx],
|
||||||
|
'total': total_items,
|
||||||
|
'page': page,
|
||||||
|
'page_size': page_size,
|
||||||
|
'total_pages': (total_items + page_size - 1) // page_size
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def scan_checkpoints(self, request):
|
||||||
|
"""Force a rescan of checkpoint files"""
|
||||||
|
try:
|
||||||
|
await self.scanner.get_cached_data(force_refresh=True)
|
||||||
|
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def get_checkpoint_info(self, request):
|
||||||
|
"""Get detailed information for a specific checkpoint by name"""
|
||||||
|
try:
|
||||||
|
name = request.match_info.get('name', '')
|
||||||
|
checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name)
|
||||||
|
|
||||||
|
if checkpoint_info:
|
||||||
|
return web.json_response(checkpoint_info)
|
||||||
|
else:
|
||||||
|
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|||||||
131
py/services/checkpoint_scanner.py
Normal file
131
py/services/checkpoint_scanner.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Dict, Optional, Set
|
||||||
|
import folder_paths # type: ignore
|
||||||
|
|
||||||
|
from ..utils.models import CheckpointMetadata
|
||||||
|
from ..config import config
|
||||||
|
from .model_scanner import ModelScanner
|
||||||
|
from .model_hash_index import ModelHashIndex
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class CheckpointScanner(ModelScanner):
|
||||||
|
"""Service for scanning and managing checkpoint files"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not hasattr(self, '_initialized'):
|
||||||
|
# Define supported file extensions
|
||||||
|
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'}
|
||||||
|
super().__init__(
|
||||||
|
model_type="checkpoint",
|
||||||
|
model_class=CheckpointMetadata,
|
||||||
|
file_extensions=file_extensions,
|
||||||
|
hash_index=ModelHashIndex()
|
||||||
|
)
|
||||||
|
self._checkpoint_roots = self._init_checkpoint_roots()
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_instance(cls):
|
||||||
|
"""Get singleton instance with async support"""
|
||||||
|
async with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def _init_checkpoint_roots(self) -> List[str]:
|
||||||
|
"""Initialize checkpoint roots from ComfyUI settings"""
|
||||||
|
# Get both checkpoint and diffusion_models paths
|
||||||
|
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||||
|
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
|
||||||
|
|
||||||
|
# Combine, normalize and deduplicate paths
|
||||||
|
all_paths = set()
|
||||||
|
for path in checkpoint_paths + diffusion_paths:
|
||||||
|
if os.path.exists(path):
|
||||||
|
norm_path = path.replace(os.sep, "/")
|
||||||
|
all_paths.add(norm_path)
|
||||||
|
|
||||||
|
# Sort for consistent order
|
||||||
|
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
|
||||||
|
logger.info(f"Found checkpoint roots: {sorted_paths}")
|
||||||
|
|
||||||
|
return sorted_paths
|
||||||
|
|
||||||
|
def get_model_roots(self) -> List[str]:
|
||||||
|
"""Get checkpoint root directories"""
|
||||||
|
return self._checkpoint_roots
|
||||||
|
|
||||||
|
async def scan_all_models(self) -> List[Dict]:
|
||||||
|
"""Scan all checkpoint directories and return metadata"""
|
||||||
|
all_checkpoints = []
|
||||||
|
|
||||||
|
# Create scan tasks for each directory
|
||||||
|
scan_tasks = []
|
||||||
|
for root in self._checkpoint_roots:
|
||||||
|
task = asyncio.create_task(self._scan_directory(root))
|
||||||
|
scan_tasks.append(task)
|
||||||
|
|
||||||
|
# Wait for all tasks to complete
|
||||||
|
for task in scan_tasks:
|
||||||
|
try:
|
||||||
|
checkpoints = await task
|
||||||
|
all_checkpoints.extend(checkpoints)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning checkpoint directory: {e}")
|
||||||
|
|
||||||
|
return all_checkpoints
|
||||||
|
|
||||||
|
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||||
|
"""Scan a directory for checkpoint files"""
|
||||||
|
checkpoints = []
|
||||||
|
original_root = root_path
|
||||||
|
|
||||||
|
async def scan_recursive(path: str, visited_paths: set):
|
||||||
|
try:
|
||||||
|
real_path = os.path.realpath(path)
|
||||||
|
if real_path in visited_paths:
|
||||||
|
logger.debug(f"Skipping already visited path: {path}")
|
||||||
|
return
|
||||||
|
visited_paths.add(real_path)
|
||||||
|
|
||||||
|
with os.scandir(path) as it:
|
||||||
|
entries = list(it)
|
||||||
|
for entry in entries:
|
||||||
|
try:
|
||||||
|
if entry.is_file(follow_symlinks=True):
|
||||||
|
# Check if file has supported extension
|
||||||
|
ext = os.path.splitext(entry.name)[1].lower()
|
||||||
|
if ext in self.file_extensions:
|
||||||
|
file_path = entry.path.replace(os.sep, "/")
|
||||||
|
await self._process_single_file(file_path, original_root, checkpoints)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
elif entry.is_dir(follow_symlinks=True):
|
||||||
|
# For directories, continue scanning with original path
|
||||||
|
await scan_recursive(entry.path, visited_paths)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning {path}: {e}")
|
||||||
|
|
||||||
|
await scan_recursive(root_path, set())
|
||||||
|
return checkpoints
|
||||||
|
|
||||||
|
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
|
||||||
|
"""Process a single checkpoint file and add to results"""
|
||||||
|
try:
|
||||||
|
result = await self._process_model_file(file_path, root_path)
|
||||||
|
if result:
|
||||||
|
checkpoints.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing {file_path}: {e}")
|
||||||
@@ -7,6 +7,8 @@ from watchdog.observers import Observer
|
|||||||
from watchdog.events import FileSystemEventHandler
|
from watchdog.events import FileSystemEventHandler
|
||||||
from typing import List, Dict, Set
|
from typing import List, Dict, Set
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
|
from .checkpoint_scanner import CheckpointScanner
|
||||||
from .lora_scanner import LoraScanner
|
from .lora_scanner import LoraScanner
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
|
||||||
@@ -331,3 +333,89 @@ class LoraFileMonitor:
|
|||||||
logger.info(f"Added new monitoring path: {path}")
|
logger.info(f"Added new monitoring path: {path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding new monitor for {path}: {e}")
|
logger.error(f"Error adding new monitor for {path}: {e}")
|
||||||
|
|
||||||
|
# Add CheckpointFileMonitor class
|
||||||
|
|
||||||
|
class CheckpointFileMonitor(LoraFileMonitor):
|
||||||
|
"""Monitor for checkpoint file changes"""
|
||||||
|
|
||||||
|
def __init__(self, scanner: CheckpointScanner, roots: List[str]):
|
||||||
|
# Reuse most of the LoraFileMonitor functionality, but with a different handler
|
||||||
|
self.scanner = scanner
|
||||||
|
scanner.set_file_monitor(self)
|
||||||
|
self.observer = Observer()
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
self.handler = CheckpointFileHandler(scanner, self.loop)
|
||||||
|
|
||||||
|
# Use existing path mappings
|
||||||
|
self.monitor_paths = set()
|
||||||
|
for root in roots:
|
||||||
|
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/'))
|
||||||
|
|
||||||
|
# Add all mapped target paths
|
||||||
|
for target_path in config._path_mappings.keys():
|
||||||
|
self.monitor_paths.add(target_path)
|
||||||
|
|
||||||
|
class CheckpointFileHandler(LoraFileHandler):
|
||||||
|
"""Handler for checkpoint file system events"""
|
||||||
|
|
||||||
|
def __init__(self, scanner: CheckpointScanner, loop: asyncio.AbstractEventLoop):
|
||||||
|
super().__init__(scanner, loop)
|
||||||
|
# Configure supported file extensions
|
||||||
|
self.supported_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft'}
|
||||||
|
|
||||||
|
def on_created(self, event):
|
||||||
|
if event.is_directory:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle supported file extensions directly
|
||||||
|
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||||
|
if file_ext in self.supported_extensions:
|
||||||
|
if self._should_ignore(event.src_path):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Process this file directly
|
||||||
|
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
|
||||||
|
if normalized_path not in self.scheduled_files:
|
||||||
|
logger.info(f"Checkpoint file created: {event.src_path}")
|
||||||
|
self.scheduled_files.add(normalized_path)
|
||||||
|
self._schedule_update('add', event.src_path)
|
||||||
|
|
||||||
|
# Ignore modifications for a short period after creation
|
||||||
|
self.loop.call_later(
|
||||||
|
self.debounce_delay * 2,
|
||||||
|
self.scheduled_files.discard,
|
||||||
|
normalized_path
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_modified(self, event):
|
||||||
|
if event.is_directory:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Only process supported file types
|
||||||
|
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||||
|
if file_ext in self.supported_extensions:
|
||||||
|
super().on_modified(event)
|
||||||
|
|
||||||
|
def on_deleted(self, event):
|
||||||
|
if event.is_directory:
|
||||||
|
return
|
||||||
|
|
||||||
|
file_ext = os.path.splitext(event.src_path)[1].lower()
|
||||||
|
if file_ext not in self.supported_extensions:
|
||||||
|
return
|
||||||
|
|
||||||
|
super().on_deleted(event)
|
||||||
|
|
||||||
|
def on_moved(self, event):
|
||||||
|
"""Handle file move/rename events"""
|
||||||
|
src_ext = os.path.splitext(event.src_path)[1].lower()
|
||||||
|
dest_ext = os.path.splitext(event.dest_path)[1].lower()
|
||||||
|
|
||||||
|
# If destination has supported extension, treat as new file
|
||||||
|
if dest_ext in self.supported_extensions:
|
||||||
|
super().on_moved(event)
|
||||||
|
|
||||||
|
# If source was supported extension, treat as deleted
|
||||||
|
elif src_ext in self.supported_extensions:
|
||||||
|
super().on_moved(event)
|
||||||
@@ -4,13 +4,11 @@ import logging
|
|||||||
import asyncio
|
import asyncio
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional, Set
|
||||||
|
|
||||||
from ..utils.models import LoraMetadata
|
from ..utils.models import LoraMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.file_utils import load_metadata, get_file_info, normalize_path, find_preview_file, save_metadata
|
from .model_scanner import ModelScanner
|
||||||
from ..utils.lora_metadata import extract_lora_metadata
|
|
||||||
from .lora_cache import LoraCache
|
|
||||||
from .lora_hash_index import LoraHashIndex
|
from .lora_hash_index import LoraHashIndex
|
||||||
from .settings_manager import settings
|
from .settings_manager import settings
|
||||||
from ..utils.constants import NSFW_LEVELS
|
from ..utils.constants import NSFW_LEVELS
|
||||||
@@ -19,7 +17,7 @@ import sys
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class LoraScanner:
|
class LoraScanner(ModelScanner):
|
||||||
"""Service for scanning and managing LoRA files"""
|
"""Service for scanning and managing LoRA files"""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
@@ -31,19 +29,19 @@ class LoraScanner:
|
|||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 确保初始化只执行一次
|
# Ensure initialization happens only once
|
||||||
if not hasattr(self, '_initialized'):
|
if not hasattr(self, '_initialized'):
|
||||||
self._cache: Optional[LoraCache] = None
|
# Define supported file extensions
|
||||||
self._hash_index = LoraHashIndex()
|
file_extensions = {'.safetensors'}
|
||||||
self._initialization_lock = asyncio.Lock()
|
|
||||||
self._initialization_task: Optional[asyncio.Task] = None
|
|
||||||
self._initialized = True
|
|
||||||
self.file_monitor = None # Add this line
|
|
||||||
self._tags_count = {} # Add a dictionary to store tag counts
|
|
||||||
|
|
||||||
def set_file_monitor(self, monitor):
|
# Initialize parent class
|
||||||
"""Set file monitor instance"""
|
super().__init__(
|
||||||
self.file_monitor = monitor
|
model_type="lora",
|
||||||
|
model_class=LoraMetadata,
|
||||||
|
file_extensions=file_extensions,
|
||||||
|
hash_index=LoraHashIndex()
|
||||||
|
)
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_instance(cls):
|
async def get_instance(cls):
|
||||||
@@ -53,87 +51,72 @@ class LoraScanner:
|
|||||||
cls._instance = cls()
|
cls._instance = cls()
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache:
|
def get_model_roots(self) -> List[str]:
|
||||||
"""Get cached LoRA data, refresh if needed"""
|
"""Get lora root directories"""
|
||||||
async with self._initialization_lock:
|
return config.loras_roots
|
||||||
|
|
||||||
# 如果缓存未初始化但需要响应请求,返回空缓存
|
async def scan_all_models(self) -> List[Dict]:
|
||||||
if self._cache is None and not force_refresh:
|
"""Scan all LoRA directories and return metadata"""
|
||||||
return LoraCache(
|
all_loras = []
|
||||||
raw_data=[],
|
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[],
|
|
||||||
folders=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果正在初始化,等待完成
|
# Create scan tasks for each directory
|
||||||
if self._initialization_task and not self._initialization_task.done():
|
scan_tasks = []
|
||||||
try:
|
for lora_root in self.get_model_roots():
|
||||||
await self._initialization_task
|
task = asyncio.create_task(self._scan_directory(lora_root))
|
||||||
except Exception as e:
|
scan_tasks.append(task)
|
||||||
logger.error(f"Cache initialization failed: {e}")
|
|
||||||
self._initialization_task = None
|
|
||||||
|
|
||||||
if (self._cache is None or force_refresh):
|
# Wait for all tasks to complete
|
||||||
|
for task in scan_tasks:
|
||||||
|
try:
|
||||||
|
loras = await task
|
||||||
|
all_loras.extend(loras)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning directory: {e}")
|
||||||
|
|
||||||
# 创建新的初始化任务
|
return all_loras
|
||||||
if not self._initialization_task or self._initialization_task.done():
|
|
||||||
self._initialization_task = asyncio.create_task(self._initialize_cache())
|
|
||||||
|
|
||||||
try:
|
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||||
await self._initialization_task
|
"""Scan a single directory for LoRA files"""
|
||||||
except Exception as e:
|
loras = []
|
||||||
logger.error(f"Cache initialization failed: {e}")
|
original_root = root_path # Save original root path
|
||||||
# 如果缓存已存在,继续使用旧缓存
|
|
||||||
if self._cache is None:
|
|
||||||
raise # 如果没有缓存,则抛出异常
|
|
||||||
|
|
||||||
return self._cache
|
async def scan_recursive(path: str, visited_paths: set):
|
||||||
|
"""Recursively scan directory, avoiding circular symlinks"""
|
||||||
|
try:
|
||||||
|
real_path = os.path.realpath(path)
|
||||||
|
if real_path in visited_paths:
|
||||||
|
logger.debug(f"Skipping already visited path: {path}")
|
||||||
|
return
|
||||||
|
visited_paths.add(real_path)
|
||||||
|
|
||||||
async def _initialize_cache(self) -> None:
|
with os.scandir(path) as it:
|
||||||
"""Initialize or refresh the cache"""
|
entries = list(it)
|
||||||
|
for entry in entries:
|
||||||
|
try:
|
||||||
|
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
||||||
|
# Use original path instead of real path
|
||||||
|
file_path = entry.path.replace(os.sep, "/")
|
||||||
|
await self._process_single_file(file_path, original_root, loras)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
elif entry.is_dir(follow_symlinks=True):
|
||||||
|
# For directories, continue scanning with original path
|
||||||
|
await scan_recursive(entry.path, visited_paths)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning {path}: {e}")
|
||||||
|
|
||||||
|
await scan_recursive(root_path, set())
|
||||||
|
return loras
|
||||||
|
|
||||||
|
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
||||||
|
"""Process a single file and add to results list"""
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
result = await self._process_model_file(file_path, root_path)
|
||||||
# Clear existing hash index
|
if result:
|
||||||
self._hash_index.clear()
|
loras.append(result)
|
||||||
|
|
||||||
# Clear existing tags count
|
|
||||||
self._tags_count = {}
|
|
||||||
|
|
||||||
# Scan for new data
|
|
||||||
raw_data = await self.scan_all_loras()
|
|
||||||
|
|
||||||
# Build hash index and tags count
|
|
||||||
for lora_data in raw_data:
|
|
||||||
if 'sha256' in lora_data and 'file_path' in lora_data:
|
|
||||||
self._hash_index.add_entry(lora_data['sha256'].lower(), lora_data['file_path'])
|
|
||||||
|
|
||||||
# Count tags
|
|
||||||
if 'tags' in lora_data and lora_data['tags']:
|
|
||||||
for tag in lora_data['tags']:
|
|
||||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
|
||||||
|
|
||||||
# Update cache
|
|
||||||
self._cache = LoraCache(
|
|
||||||
raw_data=raw_data,
|
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[],
|
|
||||||
folders=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call resort_cache to create sorted views
|
|
||||||
await self._cache.resort()
|
|
||||||
|
|
||||||
self._initialization_task = None
|
|
||||||
logger.info(f"LoRA Manager: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} loras")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LoRA Manager: Error initializing cache: {e}")
|
logger.error(f"Error processing {file_path}: {e}")
|
||||||
self._cache = LoraCache(
|
|
||||||
raw_data=[],
|
|
||||||
sorted_by_name=[],
|
|
||||||
sorted_by_date=[],
|
|
||||||
folders=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||||
@@ -280,240 +263,14 @@ class LoraScanner:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def invalidate_cache(self):
|
|
||||||
"""Invalidate the current cache"""
|
|
||||||
self._cache = None
|
|
||||||
|
|
||||||
async def scan_all_loras(self) -> List[Dict]:
|
|
||||||
"""Scan all LoRA directories and return metadata"""
|
|
||||||
all_loras = []
|
|
||||||
|
|
||||||
# 分目录异步扫描
|
|
||||||
scan_tasks = []
|
|
||||||
for loras_root in config.loras_roots:
|
|
||||||
task = asyncio.create_task(self._scan_directory(loras_root))
|
|
||||||
scan_tasks.append(task)
|
|
||||||
|
|
||||||
for task in scan_tasks:
|
|
||||||
try:
|
|
||||||
loras = await task
|
|
||||||
all_loras.extend(loras)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning directory: {e}")
|
|
||||||
|
|
||||||
return all_loras
|
|
||||||
|
|
||||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
|
||||||
"""Scan a single directory for LoRA files"""
|
|
||||||
loras = []
|
|
||||||
original_root = root_path # 保存原始根路径
|
|
||||||
|
|
||||||
async def scan_recursive(path: str, visited_paths: set):
|
|
||||||
"""递归扫描目录,避免循环链接"""
|
|
||||||
try:
|
|
||||||
real_path = os.path.realpath(path)
|
|
||||||
if real_path in visited_paths:
|
|
||||||
logger.debug(f"Skipping already visited path: {path}")
|
|
||||||
return
|
|
||||||
visited_paths.add(real_path)
|
|
||||||
|
|
||||||
with os.scandir(path) as it:
|
|
||||||
entries = list(it)
|
|
||||||
for entry in entries:
|
|
||||||
try:
|
|
||||||
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'):
|
|
||||||
# 使用原始路径而不是真实路径
|
|
||||||
file_path = entry.path.replace(os.sep, "/")
|
|
||||||
await self._process_single_file(file_path, original_root, loras)
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
elif entry.is_dir(follow_symlinks=True):
|
|
||||||
# 对于目录,使用原始路径继续扫描
|
|
||||||
await scan_recursive(entry.path, visited_paths)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning {path}: {e}")
|
|
||||||
|
|
||||||
await scan_recursive(root_path, set())
|
|
||||||
return loras
|
|
||||||
|
|
||||||
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
|
||||||
"""处理单个文件并添加到结果列表"""
|
|
||||||
try:
|
|
||||||
result = await self._process_lora_file(file_path, root_path)
|
|
||||||
if result:
|
|
||||||
loras.append(result)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing {file_path}: {e}")
|
|
||||||
|
|
||||||
async def _process_lora_file(self, file_path: str, root_path: str) -> Dict:
|
|
||||||
"""Process a single LoRA file and return its metadata"""
|
|
||||||
# Try loading existing metadata
|
|
||||||
metadata = await load_metadata(file_path)
|
|
||||||
|
|
||||||
if metadata is None:
|
|
||||||
# Try to find and use .civitai.info file first
|
|
||||||
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
|
|
||||||
if os.path.exists(civitai_info_path):
|
|
||||||
try:
|
|
||||||
with open(civitai_info_path, 'r', encoding='utf-8') as f:
|
|
||||||
version_info = json.load(f)
|
|
||||||
|
|
||||||
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
|
||||||
if file_info:
|
|
||||||
# Create a minimal file_info with the required fields
|
|
||||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
|
||||||
file_info['name'] = file_name
|
|
||||||
|
|
||||||
# Use from_civitai_info to create metadata
|
|
||||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, file_path)
|
|
||||||
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
|
|
||||||
await save_metadata(file_path, metadata)
|
|
||||||
logger.debug(f"Created metadata from .civitai.info for {file_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
|
|
||||||
|
|
||||||
# If still no metadata, create new metadata using get_file_info
|
|
||||||
if metadata is None:
|
|
||||||
metadata = await get_file_info(file_path)
|
|
||||||
|
|
||||||
# Convert to dict and add folder info
|
|
||||||
lora_data = metadata.to_dict()
|
|
||||||
# Try to fetch missing metadata from Civitai if needed
|
|
||||||
await self._fetch_missing_metadata(file_path, lora_data)
|
|
||||||
rel_path = os.path.relpath(file_path, root_path)
|
|
||||||
folder = os.path.dirname(rel_path)
|
|
||||||
lora_data['folder'] = folder.replace(os.path.sep, '/')
|
|
||||||
|
|
||||||
return lora_data
|
|
||||||
|
|
||||||
async def _fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None:
|
|
||||||
"""Fetch missing description and tags from Civitai if needed
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the lora file
|
|
||||||
lora_data: Lora metadata dictionary to update
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Skip if already marked as deleted on Civitai
|
|
||||||
if lora_data.get('civitai_deleted', False):
|
|
||||||
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if we need to fetch additional metadata from Civitai
|
|
||||||
needs_metadata_update = False
|
|
||||||
model_id = None
|
|
||||||
|
|
||||||
# Check if we have Civitai model ID but missing metadata
|
|
||||||
if lora_data.get('civitai'):
|
|
||||||
# Try to get model ID directly from the correct location
|
|
||||||
model_id = lora_data['civitai'].get('modelId')
|
|
||||||
|
|
||||||
if model_id:
|
|
||||||
model_id = str(model_id)
|
|
||||||
# Check if tags are missing or empty
|
|
||||||
tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0
|
|
||||||
|
|
||||||
# Check if description is missing or empty
|
|
||||||
desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")
|
|
||||||
|
|
||||||
needs_metadata_update = tags_missing or desc_missing
|
|
||||||
|
|
||||||
# Fetch missing metadata if needed
|
|
||||||
if needs_metadata_update and model_id:
|
|
||||||
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
|
|
||||||
from ..services.civitai_client import CivitaiClient
|
|
||||||
client = CivitaiClient()
|
|
||||||
|
|
||||||
# Get metadata and status code
|
|
||||||
model_metadata, status_code = await client.get_model_metadata(model_id)
|
|
||||||
await client.close()
|
|
||||||
|
|
||||||
# Handle 404 status (model deleted from Civitai)
|
|
||||||
if status_code == 404:
|
|
||||||
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
|
|
||||||
# Mark as deleted to avoid future API calls
|
|
||||||
lora_data['civitai_deleted'] = True
|
|
||||||
|
|
||||||
# Save the updated metadata back to file
|
|
||||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
|
||||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(lora_data, f, indent=2, ensure_ascii=False)
|
|
||||||
|
|
||||||
# Process valid metadata if available
|
|
||||||
elif model_metadata:
|
|
||||||
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
|
|
||||||
|
|
||||||
# Update tags if they were missing
|
|
||||||
if model_metadata.get('tags') and (not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0):
|
|
||||||
lora_data['tags'] = model_metadata['tags']
|
|
||||||
|
|
||||||
# Update description if it was missing
|
|
||||||
if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")):
|
|
||||||
lora_data['modelDescription'] = model_metadata['description']
|
|
||||||
|
|
||||||
# Save the updated metadata back to file
|
|
||||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
|
||||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(lora_data, f, indent=2, ensure_ascii=False)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
|
|
||||||
|
|
||||||
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
|
|
||||||
"""Update preview URL in cache for a specific lora
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: The file path of the lora to update
|
|
||||||
preview_url: The new preview URL
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
|
|
||||||
"""
|
|
||||||
if self._cache is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return await self._cache.update_preview_url(file_path, preview_url)
|
|
||||||
|
|
||||||
async def scan_single_lora(self, file_path: str) -> Optional[Dict]:
|
|
||||||
"""Scan a single LoRA file and return its metadata"""
|
|
||||||
try:
|
|
||||||
if not os.path.exists(os.path.realpath(file_path)):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 获取基本文件信息
|
|
||||||
metadata = await get_file_info(file_path)
|
|
||||||
if not metadata:
|
|
||||||
return None
|
|
||||||
|
|
||||||
folder = self._calculate_folder(file_path)
|
|
||||||
|
|
||||||
# 确保 folder 字段存在
|
|
||||||
metadata_dict = metadata.to_dict()
|
|
||||||
metadata_dict['folder'] = folder or ''
|
|
||||||
|
|
||||||
return metadata_dict
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error scanning {file_path}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _calculate_folder(self, file_path: str) -> str:
|
|
||||||
"""Calculate the folder path for a LoRA file"""
|
|
||||||
# 使用原始路径计算相对路径
|
|
||||||
for root in config.loras_roots:
|
|
||||||
if file_path.startswith(root):
|
|
||||||
rel_path = os.path.relpath(file_path, root)
|
|
||||||
return os.path.dirname(rel_path).replace(os.path.sep, '/')
|
|
||||||
return ''
|
|
||||||
|
|
||||||
async def move_model(self, source_path: str, target_path: str) -> bool:
|
async def move_model(self, source_path: str, target_path: str) -> bool:
|
||||||
"""Move a model and its associated files to a new location"""
|
"""Move a model and its associated files to a new location"""
|
||||||
try:
|
try:
|
||||||
# 保持原始路径格式
|
# Keep original path format
|
||||||
source_path = source_path.replace(os.sep, '/')
|
source_path = source_path.replace(os.sep, '/')
|
||||||
target_path = target_path.replace(os.sep, '/')
|
target_path = target_path.replace(os.sep, '/')
|
||||||
|
|
||||||
# 其余代码保持不变
|
# Rest of the code remains unchanged
|
||||||
base_name = os.path.splitext(os.path.basename(source_path))[0]
|
base_name = os.path.splitext(os.path.basename(source_path))[0]
|
||||||
source_dir = os.path.dirname(source_path)
|
source_dir = os.path.dirname(source_path)
|
||||||
|
|
||||||
@@ -521,7 +278,7 @@ class LoraScanner:
|
|||||||
|
|
||||||
target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/')
|
target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/')
|
||||||
|
|
||||||
# 使用真实路径进行文件操作
|
# Use real paths for file operations
|
||||||
real_source = os.path.realpath(source_path)
|
real_source = os.path.realpath(source_path)
|
||||||
real_target = os.path.realpath(target_lora)
|
real_target = os.path.realpath(target_lora)
|
||||||
|
|
||||||
@@ -537,7 +294,7 @@ class LoraScanner:
|
|||||||
file_size
|
file_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用真实路径进行文件操作
|
# Use real paths for file operations
|
||||||
shutil.move(real_source, real_target)
|
shutil.move(real_source, real_target)
|
||||||
|
|
||||||
# Move associated files
|
# Move associated files
|
||||||
@@ -648,7 +405,7 @@ class LoraScanner:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating metadata paths: {e}", exc_info=True)
|
logger.error(f"Error updating metadata paths: {e}", exc_info=True)
|
||||||
|
|
||||||
# Add new methods for hash index functionality
|
# Lora-specific hash index functionality
|
||||||
def has_lora_hash(self, sha256: str) -> bool:
|
def has_lora_hash(self, sha256: str) -> bool:
|
||||||
"""Check if a LoRA with given hash exists"""
|
"""Check if a LoRA with given hash exists"""
|
||||||
return self._hash_index.has_hash(sha256.lower())
|
return self._hash_index.has_hash(sha256.lower())
|
||||||
@@ -681,16 +438,8 @@ class LoraScanner:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Add new method to get top tags
|
|
||||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||||
"""Get top tags sorted by count
|
"""Get top tags sorted by count"""
|
||||||
|
|
||||||
Args:
|
|
||||||
limit: Maximum number of tags to return
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dictionaries with tag name and count, sorted by count
|
|
||||||
"""
|
|
||||||
# Make sure cache is initialized
|
# Make sure cache is initialized
|
||||||
await self.get_cached_data()
|
await self.get_cached_data()
|
||||||
|
|
||||||
@@ -705,14 +454,7 @@ class LoraScanner:
|
|||||||
return sorted_tags[:limit]
|
return sorted_tags[:limit]
|
||||||
|
|
||||||
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||||
"""Get base models used in loras sorted by frequency
|
"""Get base models used in loras sorted by frequency"""
|
||||||
|
|
||||||
Args:
|
|
||||||
limit: Maximum number of base models to return
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dictionaries with base model name and count, sorted by count
|
|
||||||
"""
|
|
||||||
# Make sure cache is initialized
|
# Make sure cache is initialized
|
||||||
cache = await self.get_cached_data()
|
cache = await self.get_cached_data()
|
||||||
|
|
||||||
|
|||||||
64
py/services/model_cache.py
Normal file
64
py/services/model_cache.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import List, Dict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelCache:
|
||||||
|
"""Cache structure for model data"""
|
||||||
|
raw_data: List[Dict]
|
||||||
|
sorted_by_name: List[Dict]
|
||||||
|
sorted_by_date: List[Dict]
|
||||||
|
folders: List[str]
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def resort(self, name_only: bool = False):
|
||||||
|
"""Resort all cached data views"""
|
||||||
|
async with self._lock:
|
||||||
|
self.sorted_by_name = sorted(
|
||||||
|
self.raw_data,
|
||||||
|
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||||
|
)
|
||||||
|
if not name_only:
|
||||||
|
self.sorted_by_date = sorted(
|
||||||
|
self.raw_data,
|
||||||
|
key=itemgetter('modified'),
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
# Update folder list
|
||||||
|
all_folders = set(l['folder'] for l in self.raw_data)
|
||||||
|
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||||
|
|
||||||
|
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
|
||||||
|
"""Update preview_url for a specific model in all cached data
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: The file path of the model to update
|
||||||
|
preview_url: The new preview URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the update was successful, False if the model wasn't found
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
# Update in raw_data
|
||||||
|
for item in self.raw_data:
|
||||||
|
if item['file_path'] == file_path:
|
||||||
|
item['preview_url'] = preview_url
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return False # Model not found
|
||||||
|
|
||||||
|
# Update in sorted lists (references to the same dict objects)
|
||||||
|
for item in self.sorted_by_name:
|
||||||
|
if item['file_path'] == file_path:
|
||||||
|
item['preview_url'] = preview_url
|
||||||
|
break
|
||||||
|
|
||||||
|
for item in self.sorted_by_date:
|
||||||
|
if item['file_path'] == file_path:
|
||||||
|
item['preview_url'] = preview_url
|
||||||
|
break
|
||||||
|
|
||||||
|
return True
|
||||||
78
py/services/model_hash_index.py
Normal file
78
py/services/model_hash_index.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from typing import Dict, Optional, Set
|
||||||
|
|
||||||
|
class ModelHashIndex:
|
||||||
|
"""Index for looking up models by hash or path"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._hash_to_path: Dict[str, str] = {}
|
||||||
|
self._path_to_hash: Dict[str, str] = {}
|
||||||
|
|
||||||
|
def add_entry(self, sha256: str, file_path: str) -> None:
|
||||||
|
"""Add or update hash index entry"""
|
||||||
|
if not sha256 or not file_path:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ensure hash is lowercase for consistency
|
||||||
|
sha256 = sha256.lower()
|
||||||
|
|
||||||
|
# Remove old path mapping if hash exists
|
||||||
|
if sha256 in self._hash_to_path:
|
||||||
|
old_path = self._hash_to_path[sha256]
|
||||||
|
if old_path in self._path_to_hash:
|
||||||
|
del self._path_to_hash[old_path]
|
||||||
|
|
||||||
|
# Remove old hash mapping if path exists
|
||||||
|
if file_path in self._path_to_hash:
|
||||||
|
old_hash = self._path_to_hash[file_path]
|
||||||
|
if old_hash in self._hash_to_path:
|
||||||
|
del self._hash_to_path[old_hash]
|
||||||
|
|
||||||
|
# Add new mappings
|
||||||
|
self._hash_to_path[sha256] = file_path
|
||||||
|
self._path_to_hash[file_path] = sha256
|
||||||
|
|
||||||
|
def remove_by_path(self, file_path: str) -> None:
|
||||||
|
"""Remove entry by file path"""
|
||||||
|
if file_path in self._path_to_hash:
|
||||||
|
hash_val = self._path_to_hash[file_path]
|
||||||
|
if hash_val in self._hash_to_path:
|
||||||
|
del self._hash_to_path[hash_val]
|
||||||
|
del self._path_to_hash[file_path]
|
||||||
|
|
||||||
|
def remove_by_hash(self, sha256: str) -> None:
|
||||||
|
"""Remove entry by hash"""
|
||||||
|
sha256 = sha256.lower()
|
||||||
|
if sha256 in self._hash_to_path:
|
||||||
|
path = self._hash_to_path[sha256]
|
||||||
|
if path in self._path_to_hash:
|
||||||
|
del self._path_to_hash[path]
|
||||||
|
del self._hash_to_path[sha256]
|
||||||
|
|
||||||
|
def has_hash(self, sha256: str) -> bool:
|
||||||
|
"""Check if hash exists in index"""
|
||||||
|
return sha256.lower() in self._hash_to_path
|
||||||
|
|
||||||
|
def get_path(self, sha256: str) -> Optional[str]:
|
||||||
|
"""Get file path for a hash"""
|
||||||
|
return self._hash_to_path.get(sha256.lower())
|
||||||
|
|
||||||
|
def get_hash(self, file_path: str) -> Optional[str]:
|
||||||
|
"""Get hash for a file path"""
|
||||||
|
return self._path_to_hash.get(file_path)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear all entries"""
|
||||||
|
self._hash_to_path.clear()
|
||||||
|
self._path_to_hash.clear()
|
||||||
|
|
||||||
|
def get_all_hashes(self) -> Set[str]:
|
||||||
|
"""Get all hashes in the index"""
|
||||||
|
return set(self._hash_to_path.keys())
|
||||||
|
|
||||||
|
def get_all_paths(self) -> Set[str]:
|
||||||
|
"""Get all file paths in the index"""
|
||||||
|
return set(self._path_to_hash.keys())
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Get number of entries"""
|
||||||
|
return len(self._hash_to_path)
|
||||||
554
py/services/model_scanner.py
Normal file
554
py/services/model_scanner.py
Normal file
@@ -0,0 +1,554 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import shutil
|
||||||
|
from typing import List, Dict, Optional, Type, Set
|
||||||
|
|
||||||
|
from ..utils.models import BaseModelMetadata
|
||||||
|
from ..config import config
|
||||||
|
from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, save_metadata
|
||||||
|
from .model_cache import ModelCache
|
||||||
|
from .model_hash_index import ModelHashIndex
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ModelScanner:
|
||||||
|
"""Base service for scanning and managing model files"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
|
||||||
|
"""Initialize the scanner
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: Type of model (lora, checkpoint, etc.)
|
||||||
|
model_class: Class used to create metadata instances
|
||||||
|
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
|
||||||
|
hash_index: Hash index instance (optional)
|
||||||
|
"""
|
||||||
|
self.model_type = model_type
|
||||||
|
self.model_class = model_class
|
||||||
|
self.file_extensions = file_extensions
|
||||||
|
self._cache = None
|
||||||
|
self._hash_index = hash_index or ModelHashIndex()
|
||||||
|
self._initialization_lock = asyncio.Lock()
|
||||||
|
self._initialization_task = None
|
||||||
|
self.file_monitor = None
|
||||||
|
self._tags_count = {} # Dictionary to store tag counts
|
||||||
|
|
||||||
|
def set_file_monitor(self, monitor):
|
||||||
|
"""Set file monitor instance"""
|
||||||
|
self.file_monitor = monitor
|
||||||
|
|
||||||
|
async def get_cached_data(self, force_refresh: bool = False) -> ModelCache:
|
||||||
|
"""Get cached model data, refresh if needed"""
|
||||||
|
async with self._initialization_lock:
|
||||||
|
# Return empty cache if not initialized and no refresh requested
|
||||||
|
if self._cache is None and not force_refresh:
|
||||||
|
return ModelCache(
|
||||||
|
raw_data=[],
|
||||||
|
sorted_by_name=[],
|
||||||
|
sorted_by_date=[],
|
||||||
|
folders=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for ongoing initialization if any
|
||||||
|
if self._initialization_task and not self._initialization_task.done():
|
||||||
|
try:
|
||||||
|
await self._initialization_task
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Cache initialization failed: {e}")
|
||||||
|
self._initialization_task = None
|
||||||
|
|
||||||
|
if (self._cache is None or force_refresh):
|
||||||
|
# Create new initialization task
|
||||||
|
if not self._initialization_task or self._initialization_task.done():
|
||||||
|
self._initialization_task = asyncio.create_task(self._initialize_cache())
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._initialization_task
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Cache initialization failed: {e}")
|
||||||
|
# Continue using old cache if it exists
|
||||||
|
if self._cache is None:
|
||||||
|
raise # Raise exception if no cache available
|
||||||
|
|
||||||
|
return self._cache
|
||||||
|
|
||||||
|
async def _initialize_cache(self) -> None:
|
||||||
|
"""Initialize or refresh the cache"""
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
# Clear existing hash index
|
||||||
|
self._hash_index.clear()
|
||||||
|
|
||||||
|
# Clear existing tags count
|
||||||
|
self._tags_count = {}
|
||||||
|
|
||||||
|
# Scan for new data
|
||||||
|
raw_data = await self.scan_all_models()
|
||||||
|
|
||||||
|
# Build hash index and tags count
|
||||||
|
for model_data in raw_data:
|
||||||
|
if 'sha256' in model_data and 'file_path' in model_data:
|
||||||
|
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
|
||||||
|
|
||||||
|
# Count tags
|
||||||
|
if 'tags' in model_data and model_data['tags']:
|
||||||
|
for tag in model_data['tags']:
|
||||||
|
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
self._cache = ModelCache(
|
||||||
|
raw_data=raw_data,
|
||||||
|
sorted_by_name=[],
|
||||||
|
sorted_by_date=[],
|
||||||
|
folders=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resort cache
|
||||||
|
await self._cache.resort()
|
||||||
|
|
||||||
|
self._initialization_task = None
|
||||||
|
logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
|
||||||
|
self._cache = ModelCache(
|
||||||
|
raw_data=[],
|
||||||
|
sorted_by_name=[],
|
||||||
|
sorted_by_date=[],
|
||||||
|
folders=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# These methods should be implemented in child classes
|
||||||
|
async def scan_all_models(self) -> List[Dict]:
|
||||||
|
"""Scan all model directories and return metadata"""
|
||||||
|
raise NotImplementedError("Subclasses must implement scan_all_models")
|
||||||
|
|
||||||
|
def get_model_roots(self) -> List[str]:
|
||||||
|
"""Get model root directories"""
|
||||||
|
raise NotImplementedError("Subclasses must implement get_model_roots")
|
||||||
|
|
||||||
|
async def scan_single_model(self, file_path: str) -> Optional[Dict]:
|
||||||
|
"""Scan a single model file and return its metadata"""
|
||||||
|
try:
|
||||||
|
if not os.path.exists(os.path.realpath(file_path)):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get basic file info
|
||||||
|
metadata = await self._get_file_info(file_path)
|
||||||
|
if not metadata:
|
||||||
|
return None
|
||||||
|
|
||||||
|
folder = self._calculate_folder(file_path)
|
||||||
|
|
||||||
|
# Ensure folder field exists
|
||||||
|
metadata_dict = metadata.to_dict()
|
||||||
|
metadata_dict['folder'] = folder or ''
|
||||||
|
|
||||||
|
return metadata_dict
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error scanning {file_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
|
||||||
|
"""Get model file info and metadata (extensible for different model types)"""
|
||||||
|
# Implementation may vary by model type - override in subclasses if needed
|
||||||
|
return await get_file_info(file_path, self.model_class)
|
||||||
|
|
||||||
|
def _calculate_folder(self, file_path: str) -> str:
|
||||||
|
"""Calculate the folder path for a model file"""
|
||||||
|
# Use original path to calculate relative path
|
||||||
|
for root in self.get_model_roots():
|
||||||
|
if file_path.startswith(root):
|
||||||
|
rel_path = os.path.relpath(file_path, root)
|
||||||
|
return os.path.dirname(rel_path).replace(os.path.sep, '/')
|
||||||
|
return ''
|
||||||
|
|
||||||
|
# Common methods shared between scanners
|
||||||
|
async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
|
||||||
|
"""Process a single model file and return its metadata"""
|
||||||
|
# Try loading existing metadata
|
||||||
|
metadata = await load_metadata(file_path, self.model_class)
|
||||||
|
|
||||||
|
if metadata is None:
|
||||||
|
# Try to find and use .civitai.info file first
|
||||||
|
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
|
||||||
|
if os.path.exists(civitai_info_path):
|
||||||
|
try:
|
||||||
|
with open(civitai_info_path, 'r', encoding='utf-8') as f:
|
||||||
|
version_info = json.load(f)
|
||||||
|
|
||||||
|
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
||||||
|
if file_info:
|
||||||
|
# Create a minimal file_info with the required fields
|
||||||
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
file_info['name'] = file_name
|
||||||
|
|
||||||
|
# Use from_civitai_info to create metadata
|
||||||
|
metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
|
||||||
|
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
|
||||||
|
await save_metadata(file_path, metadata)
|
||||||
|
logger.debug(f"Created metadata from .civitai.info for {file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
|
||||||
|
|
||||||
|
# If still no metadata, create new metadata
|
||||||
|
if metadata is None:
|
||||||
|
metadata = await self._get_file_info(file_path)
|
||||||
|
|
||||||
|
# Convert to dict and add folder info
|
||||||
|
model_data = metadata.to_dict()
|
||||||
|
|
||||||
|
# Try to fetch missing metadata from Civitai if needed
|
||||||
|
await self._fetch_missing_metadata(file_path, model_data)
|
||||||
|
rel_path = os.path.relpath(file_path, root_path)
|
||||||
|
folder = os.path.dirname(rel_path)
|
||||||
|
model_data['folder'] = folder.replace(os.path.sep, '/')
|
||||||
|
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None:
    """Fetch missing description and tags from Civitai if needed.

    Mutates *model_data* in place and rewrites the ``.metadata.json``
    sidecar when anything changed. A 404 from Civitai marks the model
    as deleted so later scans skip the network call entirely.

    Never raises: all failures are logged and swallowed (best-effort).
    """
    try:
        # Skip if already marked as deleted on Civitai
        if model_data.get('civitai_deleted', False):
            logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
            return

        needs_metadata_update = False
        model_id = None

        # Check if we have a Civitai model ID but missing metadata
        if model_data.get('civitai'):
            model_id = model_data['civitai'].get('modelId')
            if model_id:
                model_id = str(model_id)
                tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0
                desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")
                needs_metadata_update = tags_missing or desc_missing

        if needs_metadata_update and model_id:
            logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
            from ..services.civitai_client import CivitaiClient
            client = CivitaiClient()
            # FIX: always close the client, even if the request raises;
            # previously an exception leaked the underlying session.
            try:
                model_metadata, status_code = await client.get_model_metadata(model_id)
            finally:
                await client.close()

            def _persist() -> None:
                # Write the updated model_data back to the sidecar file.
                metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
                with open(metadata_path, 'w', encoding='utf-8') as f:
                    json.dump(model_data, f, indent=2, ensure_ascii=False)

            # Handle 404 status (model deleted from Civitai)
            if status_code == 404:
                logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
                model_data['civitai_deleted'] = True
                _persist()
            # Process valid metadata if available
            elif model_metadata:
                logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")

                # Update tags only if they were missing
                if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0):
                    model_data['tags'] = model_metadata['tags']

                # Update description only if it was missing
                if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")):
                    model_data['modelDescription'] = model_metadata['description']

                _persist()
    except Exception as e:
        logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
|
async def _scan_directory(self, root_path: str) -> List[Dict]:
    """Recursively scan *root_path* for supported model files.

    Symlinks are followed, but a visited-set of resolved real paths
    guards against cycles. Returns one metadata dict per model file.
    """
    found: List[Dict] = []
    top_level = root_path

    async def walk(current: str, seen: set):
        try:
            resolved = os.path.realpath(current)
            # Cycle guard: a symlink may point back into an ancestor.
            if resolved in seen:
                logger.debug(f"Skipping already visited path: {current}")
                return
            seen.add(resolved)

            with os.scandir(current) as it:
                for entry in list(it):
                    try:
                        if entry.is_file(follow_symlinks=True):
                            # Only files with a supported extension are models
                            if os.path.splitext(entry.name)[1].lower() in self.file_extensions:
                                normalized = entry.path.replace(os.sep, "/")
                                await self._process_single_file(normalized, top_level, found)
                                # Yield to the event loop between files
                                await asyncio.sleep(0)
                        elif entry.is_dir(follow_symlinks=True):
                            # Recurse using the original (non-resolved) path
                            await walk(entry.path, seen)
                    except Exception as e:
                        logger.error(f"Error processing entry {entry.path}: {e}")
        except Exception as e:
            logger.error(f"Error scanning {current}: {e}")

    await walk(root_path, set())
    return found
|
async def _process_single_file(self, file_path: str, root_path: str, models_list: list):
    """Process one model file and append its metadata dict to *models_list*.

    Errors are logged and swallowed so a single bad file cannot abort
    a whole directory scan.
    """
    try:
        data = await self._process_model_file(file_path, root_path)
        if data:
            models_list.append(data)
    except Exception as e:
        logger.error(f"Error processing {file_path}: {e}")
|
async def move_model(self, source_path: str, target_path: str) -> bool:
    """Move a model and its associated files to a new location.

    *source_path* is the model file; *target_path* is the destination
    DIRECTORY (created if missing). Also moves the ``.metadata.json``
    sidecar (rewriting the paths stored inside it) and the first
    matching preview image/video, then updates the in-memory cache.

    Returns True on success, False on any failure (logged, not raised).
    """
    try:
        # Keep original path format
        source_path = source_path.replace(os.sep, '/')
        target_path = target_path.replace(os.sep, '/')

        # Get file extension from source
        file_ext = os.path.splitext(source_path)[1]

        # If no extension or not in supported extensions, return False
        if not file_ext or file_ext.lower() not in self.file_extensions:
            logger.error(f"Invalid file extension for model: {file_ext}")
            return False

        base_name = os.path.splitext(os.path.basename(source_path))[0]
        source_dir = os.path.dirname(source_path)

        os.makedirs(target_path, exist_ok=True)

        target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/')

        # Use real paths for file operations (resolve symlinks)
        real_source = os.path.realpath(source_path)
        real_target = os.path.realpath(target_file)

        file_size = os.path.getsize(real_source)

        # Tell the file monitor to ignore the move we are about to make,
        # so it does not treat it as an external delete/create.
        if self.file_monitor:
            self.file_monitor.handler.add_ignore_path(
                real_source,
                file_size
            )
            self.file_monitor.handler.add_ignore_path(
                real_target,
                file_size
            )

        # Use real paths for file operations
        shutil.move(real_source, real_target)

        # Move associated files: metadata sidecar first
        source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
        metadata = None
        if os.path.exists(source_metadata):
            target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
            shutil.move(source_metadata, target_metadata)
            # Rewrite file_path/preview_url inside the moved sidecar
            metadata = await self._update_metadata_paths(target_metadata, target_file)

        # Move preview file if exists (first match wins; order = preference)
        preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
                              '.png', '.jpeg', '.jpg', '.mp4']
        for ext in preview_extensions:
            source_preview = os.path.join(source_dir, f"{base_name}{ext}")
            if os.path.exists(source_preview):
                target_preview = os.path.join(target_path, f"{base_name}{ext}")
                shutil.move(source_preview, target_preview)
                break

        # Update cache
        await self.update_single_model_cache(source_path, target_file, metadata)

        return True

    except Exception as e:
        logger.error(f"Error moving model: {e}", exc_info=True)
        return False
|
async def _update_metadata_paths(self, metadata_path: str, model_path: str) -> Dict:
|
||||||
|
"""Update file paths in metadata file"""
|
||||||
|
try:
|
||||||
|
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
|
||||||
|
# Update file_path
|
||||||
|
metadata['file_path'] = model_path.replace(os.sep, '/')
|
||||||
|
|
||||||
|
# Update preview_url if exists
|
||||||
|
if 'preview_url' in metadata:
|
||||||
|
preview_dir = os.path.dirname(model_path)
|
||||||
|
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
|
||||||
|
preview_ext = os.path.splitext(metadata['preview_url'])[1]
|
||||||
|
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
|
||||||
|
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
|
||||||
|
|
||||||
|
# Save updated metadata
|
||||||
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating metadata paths: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
    """Update cache after a model has been moved or modified.

    Removes the entry for *original_path* (tags count, hash index,
    raw_data) and, if *metadata* is given, inserts it for *new_path*
    and refreshes the folder list, tag counts, and sort order.

    Always returns True.
    """
    cache = await self.get_cached_data()

    # Find the existing item to remove its tags from the count.
    # Captured BEFORE removal so its folder can be reused below.
    existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
    if existing_item and 'tags' in existing_item:
        for tag in existing_item.get('tags', []):
            if tag in self._tags_count:
                self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
                if self._tags_count[tag] == 0:
                    del self._tags_count[tag]

    # Remove old path from hash index if exists
    self._hash_index.remove_by_path(original_path)

    # Remove the old entry from raw_data
    cache.raw_data = [
        item for item in cache.raw_data
        if item['file_path'] != original_path
    ]

    if metadata:
        # If this is an update to an existing path (not a move), preserve its folder
        if original_path == new_path:
            # FIX: the old entry was already removed from raw_data above, so
            # re-scanning raw_data for original_path could never match; use
            # the existing_item captured before removal instead.
            existing_folder = existing_item.get('folder') if existing_item else None
            if existing_folder:
                metadata['folder'] = existing_folder
            else:
                metadata['folder'] = self._calculate_folder(new_path)
        else:
            # For moved files, recalculate the folder
            metadata['folder'] = self._calculate_folder(new_path)

        # Add the updated metadata to raw_data
        cache.raw_data.append(metadata)

        # Update hash index with new path
        if 'sha256' in metadata:
            self._hash_index.add_entry(metadata['sha256'].lower(), new_path)

        # Update folders list (case-insensitive sort)
        all_folders = set(item['folder'] for item in cache.raw_data)
        cache.folders = sorted(list(all_folders), key=lambda x: x.lower())

        # Update tags count with the new/updated tags
        if 'tags' in metadata:
            for tag in metadata.get('tags', []):
                self._tags_count[tag] = self._tags_count.get(tag, 0) + 1

        # Resort cache
        await cache.resort()

    return True
|
# Hash index functionality (common for all model types)
def has_hash(self, sha256: str) -> bool:
    """Return True if any indexed model carries this SHA256 hash (case-insensitive)."""
    normalized = sha256.lower()
    return self._hash_index.has_hash(normalized)
|
def get_path_by_hash(self, sha256: str) -> Optional[str]:
    """Resolve a model's SHA256 hash (case-insensitive) to its file path, or None."""
    normalized = sha256.lower()
    return self._hash_index.get_path(normalized)
|
def get_hash_by_path(self, file_path: str) -> Optional[str]:
    """Reverse lookup: return the SHA256 hash indexed for *file_path*, or None."""
    return self._hash_index.get_hash(file_path)
|
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
    """Resolve a model hash to the static URL of its preview asset.

    Looks up the model's file path via the hash index, then probes the
    known preview suffixes next to it (explicit ``.preview.*`` names
    take precedence over bare image/video extensions). Returns None if
    the hash is unknown or no preview file exists on disk.
    """
    model_path = self._hash_index.get_path(sha256.lower())
    if not model_path:
        return None

    stem = os.path.splitext(model_path)[0]
    candidates = ('.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
                  '.png', '.jpeg', '.jpg', '.mp4')

    for suffix in candidates:
        preview = f"{stem}{suffix}"
        if os.path.exists(preview):
            # Convert the filesystem path to a servable static URL
            return config.get_preview_static_url(preview)

    return None
|
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
    """Return up to *limit* tags ordered by descending usage count."""
    # Ensure the cache (and therefore the tag counter) is populated
    await self.get_cached_data()

    ranked = sorted(
        ({"tag": tag, "count": count} for tag, count in self._tags_count.items()),
        key=lambda entry: entry['count'],
        reverse=True
    )

    return ranked[:limit]
|
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
    """Return up to *limit* base models ordered by how often they appear."""
    # Ensure the cache is populated before counting
    cache = await self.get_cached_data()

    # Tally occurrences of each non-empty base_model value
    counts = {}
    for entry in cache.raw_data:
        name = entry.get('base_model')
        if name:
            counts[name] = counts.get(name, 0) + 1

    ranked = sorted(
        ({'name': name, 'count': count} for name, count in counts.items()),
        key=lambda item: item['count'],
        reverse=True
    )

    return ranked[:limit]
|
async def get_model_info_by_name(self, name):
    """Look up a model's metadata dict by its file name; None if absent or on error."""
    try:
        cache = await self.get_cached_data()
        # Linear scan over the cached entries; first file_name match wins
        return next(
            (entry for entry in cache.raw_data if entry.get("file_name") == name),
            None
        )
    except Exception as e:
        logger.error(f"Error getting model info by name: {e}", exc_info=True)
        return None
@@ -2,12 +2,12 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Optional
|
import time
|
||||||
|
from typing import Dict, Optional, Type
|
||||||
|
|
||||||
from .model_utils import determine_base_model
|
from .model_utils import determine_base_model
|
||||||
|
from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata
|
||||||
from .lora_metadata import extract_lora_metadata
|
from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata
|
||||||
from .models import LoraMetadata
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def calculate_sha256(file_path: str) -> str:
    """Calculate SHA256 hash of a file"""
    digest = hashlib.sha256()
    # Read in 128 KiB chunks to bound memory on multi-GB model files
    with open(file_path, "rb") as f:
        while True:
            chunk = f.read(128 * 1024)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
@@ -42,8 +42,8 @@ def normalize_path(path: str) -> str:
|
|||||||
"""Normalize file path to use forward slashes"""
|
"""Normalize file path to use forward slashes"""
|
||||||
return path.replace(os.sep, "/") if path else path
|
return path.replace(os.sep, "/") if path else path
|
||||||
|
|
||||||
async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
|
async def get_file_info(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
|
||||||
"""Get basic file information as LoraMetadata object"""
|
"""Get basic file information as a model metadata object"""
|
||||||
# First check if file actually exists and resolve symlinks
|
# First check if file actually exists and resolve symlinks
|
||||||
try:
|
try:
|
||||||
real_path = os.path.realpath(file_path)
|
real_path = os.path.realpath(file_path)
|
||||||
@@ -74,27 +74,52 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
|
|||||||
try:
|
try:
|
||||||
# If we didn't get SHA256 from the .json file, calculate it
|
# If we didn't get SHA256 from the .json file, calculate it
|
||||||
if not sha256:
|
if not sha256:
|
||||||
|
start_time = time.time()
|
||||||
sha256 = await calculate_sha256(real_path)
|
sha256 = await calculate_sha256(real_path)
|
||||||
|
logger.debug(f"Calculated SHA256 for {file_path} in {time.time() - start_time:.2f} seconds")
|
||||||
|
|
||||||
metadata = LoraMetadata(
|
# Create default metadata based on model class
|
||||||
file_name=base_name,
|
if model_class == CheckpointMetadata:
|
||||||
model_name=base_name,
|
metadata = CheckpointMetadata(
|
||||||
file_path=normalize_path(file_path),
|
file_name=base_name,
|
||||||
size=os.path.getsize(real_path),
|
model_name=base_name,
|
||||||
modified=os.path.getmtime(real_path),
|
file_path=normalize_path(file_path),
|
||||||
sha256=sha256,
|
size=os.path.getsize(real_path),
|
||||||
base_model="Unknown", # Will be updated later
|
modified=os.path.getmtime(real_path),
|
||||||
usage_tips="",
|
sha256=sha256,
|
||||||
notes="",
|
base_model="Unknown", # Will be updated later
|
||||||
from_civitai=True,
|
preview_url=normalize_path(preview_url),
|
||||||
preview_url=normalize_path(preview_url),
|
tags=[],
|
||||||
tags=[],
|
modelDescription="",
|
||||||
modelDescription=""
|
model_type="checkpoint"
|
||||||
)
|
)
|
||||||
|
|
||||||
# create metadata file
|
# Extract checkpoint-specific metadata
|
||||||
base_model_info = await extract_lora_metadata(real_path)
|
# model_info = await extract_checkpoint_metadata(real_path)
|
||||||
metadata.base_model = base_model_info['base_model']
|
# metadata.base_model = model_info['base_model']
|
||||||
|
# if 'model_type' in model_info:
|
||||||
|
# metadata.model_type = model_info['model_type']
|
||||||
|
|
||||||
|
else: # Default to LoraMetadata
|
||||||
|
metadata = LoraMetadata(
|
||||||
|
file_name=base_name,
|
||||||
|
model_name=base_name,
|
||||||
|
file_path=normalize_path(file_path),
|
||||||
|
size=os.path.getsize(real_path),
|
||||||
|
modified=os.path.getmtime(real_path),
|
||||||
|
sha256=sha256,
|
||||||
|
base_model="Unknown", # Will be updated later
|
||||||
|
usage_tips="{}",
|
||||||
|
preview_url=normalize_path(preview_url),
|
||||||
|
tags=[],
|
||||||
|
modelDescription=""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract lora-specific metadata
|
||||||
|
model_info = await extract_lora_metadata(real_path)
|
||||||
|
metadata.base_model = model_info['base_model']
|
||||||
|
|
||||||
|
# Save metadata to file
|
||||||
await save_metadata(file_path, metadata)
|
await save_metadata(file_path, metadata)
|
||||||
|
|
||||||
return metadata
|
return metadata
|
||||||
@@ -102,7 +127,7 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
|
|||||||
logger.error(f"Error getting file info for {file_path}: {e}")
|
logger.error(f"Error getting file info for {file_path}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
|
async def save_metadata(file_path: str, metadata: BaseModelMetadata) -> None:
|
||||||
"""Save metadata to .metadata.json file"""
|
"""Save metadata to .metadata.json file"""
|
||||||
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||||
try:
|
try:
|
||||||
@@ -115,7 +140,7 @@ async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error saving metadata to {metadata_path}: {str(e)}")
|
print(f"Error saving metadata to {metadata_path}: {str(e)}")
|
||||||
|
|
||||||
async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
|
async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
|
||||||
"""Load metadata from .metadata.json file"""
|
"""Load metadata from .metadata.json file"""
|
||||||
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||||
try:
|
try:
|
||||||
@@ -163,11 +188,21 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
|
|||||||
data['modelDescription'] = ""
|
data['modelDescription'] = ""
|
||||||
needs_update = True
|
needs_update = True
|
||||||
|
|
||||||
|
# For checkpoint metadata
|
||||||
|
if model_class == CheckpointMetadata and 'model_type' not in data:
|
||||||
|
data['model_type'] = "checkpoint"
|
||||||
|
needs_update = True
|
||||||
|
|
||||||
|
# For lora metadata
|
||||||
|
if model_class == LoraMetadata and 'usage_tips' not in data:
|
||||||
|
data['usage_tips'] = "{}"
|
||||||
|
needs_update = True
|
||||||
|
|
||||||
if needs_update:
|
if needs_update:
|
||||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
return LoraMetadata.from_dict(data)
|
return model_class.from_dict(data)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading metadata from {metadata_path}: {str(e)}")
|
print(f"Error loading metadata from {metadata_path}: {str(e)}")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from .model_utils import determine_base_model
|
from .model_utils import determine_base_model
|
||||||
|
import os
|
||||||
|
|
||||||
async def extract_lora_metadata(file_path: str) -> Dict:
|
async def extract_lora_metadata(file_path: str) -> Dict:
|
||||||
"""Extract essential metadata from safetensors file"""
|
"""Extract essential metadata from safetensors file"""
|
||||||
@@ -14,3 +15,66 @@ async def extract_lora_metadata(file_path: str) -> Dict:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading metadata from {file_path}: {str(e)}")
|
print(f"Error reading metadata from {file_path}: {str(e)}")
|
||||||
return {"base_model": "Unknown"}
|
return {"base_model": "Unknown"}
|
||||||
|
|
||||||
|
async def extract_checkpoint_metadata(file_path: str) -> dict:
    """Extract metadata from a checkpoint file to determine model type and base model.

    First applies filename heuristics, then (for .safetensors files)
    peeks at the file header for authoritative modelspec fields.
    Returns a dict with 'base_model' and 'model_type' keys; defaults
    are 'Unknown' / 'checkpoint'.
    """
    model_info = {
        'base_model': 'Unknown',
        'model_type': 'checkpoint'
    }
    try:
        # Analyze filename for clues about the model
        filename = os.path.basename(file_path).lower()

        # Detect base model from filename
        if 'xl' in filename or 'sdxl' in filename:
            model_info['base_model'] = 'SDXL'
        elif 'sd3' in filename:
            model_info['base_model'] = 'SD3'
        elif 'sd2' in filename or 'v2' in filename:
            model_info['base_model'] = 'SD2.x'
        elif 'sd1' in filename or 'v1' in filename:
            model_info['base_model'] = 'SD1.5'

        # Detect model type from filename
        if 'inpaint' in filename:
            model_info['model_type'] = 'inpainting'
        elif 'anime' in filename:
            model_info['model_type'] = 'anime'
        elif 'realistic' in filename:
            model_info['model_type'] = 'realistic'

        # Try to peek at the safetensors file structure if available.
        # FIX: wrapped in its own try so an unreadable header no longer
        # discards the filename-derived hints gathered above.
        if file_path.endswith('.safetensors'):
            try:
                import json
                import struct

                with open(file_path, 'rb') as f:
                    # safetensors layout: 8-byte little-endian header size,
                    # then a JSON header of that many bytes
                    header_size = struct.unpack('<Q', f.read(8))[0]
                    header = json.loads(f.read(header_size))

                # Look for specific keys to identify model type
                metadata = header.get('__metadata__', {})
                if metadata:
                    # A second text encoder tower is an SDXL signature
                    if any(key.startswith('conditioner.embedders.1') for key in header):
                        model_info['base_model'] = 'SDXL'

                    # Explicit modelspec fields override heuristics
                    if metadata.get('modelspec.architecture') == 'SD-XL':
                        model_info['base_model'] = 'SDXL'
                    elif metadata.get('modelspec.architecture') == 'SD-3':
                        model_info['base_model'] = 'SD3'

                    # Check for specific use case
                    if metadata.get('modelspec.purpose') == 'inpainting':
                        model_info['model_type'] = 'inpainting'
            except Exception as e:
                # FIX: this module has no `logger` (sibling functions use
                # print), so logger.error here raised NameError
                print(f"Error extracting checkpoint metadata for {file_path}: {e}")

        return model_info

    except Exception as e:
        print(f"Error extracting checkpoint metadata for {file_path}: {e}")
        # Return default values
        return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
@@ -5,20 +5,19 @@ import os
|
|||||||
from .model_utils import determine_base_model
|
from .model_utils import determine_base_model
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoraMetadata:
|
class BaseModelMetadata:
|
||||||
"""Represents the metadata structure for a Lora model"""
|
"""Base class for all model metadata structures"""
|
||||||
file_name: str # The filename without extension of the lora
|
file_name: str # The filename without extension
|
||||||
model_name: str # The lora's name defined by the creator, initially same as file_name
|
model_name: str # The model's name defined by the creator
|
||||||
file_path: str # Full path to the safetensors file
|
file_path: str # Full path to the model file
|
||||||
size: int # File size in bytes
|
size: int # File size in bytes
|
||||||
modified: float # Last modified timestamp
|
modified: float # Last modified timestamp
|
||||||
sha256: str # SHA256 hash of the file
|
sha256: str # SHA256 hash of the file
|
||||||
base_model: str # Base model (SD1.5/SD2.1/SDXL/etc.)
|
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
|
||||||
preview_url: str # Preview image URL
|
preview_url: str # Preview image URL
|
||||||
preview_nsfw_level: int = 0 # NSFW level of the preview image
|
preview_nsfw_level: int = 0 # NSFW level of the preview image
|
||||||
usage_tips: str = "{}" # Usage tips for the model, json string
|
|
||||||
notes: str = "" # Additional notes
|
notes: str = "" # Additional notes
|
||||||
from_civitai: bool = True # Whether the lora is from Civitai
|
from_civitai: bool = True # Whether from Civitai
|
||||||
civitai: Optional[Dict] = None # Civitai API data if available
|
civitai: Optional[Dict] = None # Civitai API data if available
|
||||||
tags: List[str] = None # Model tags
|
tags: List[str] = None # Model tags
|
||||||
modelDescription: str = "" # Full model description
|
modelDescription: str = "" # Full model description
|
||||||
@@ -29,32 +28,11 @@ class LoraMetadata:
|
|||||||
self.tags = []
|
self.tags = []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict) -> 'LoraMetadata':
|
def from_dict(cls, data: Dict) -> 'BaseModelMetadata':
|
||||||
"""Create LoraMetadata instance from dictionary"""
|
"""Create instance from dictionary"""
|
||||||
# Create a copy of the data to avoid modifying the input
|
|
||||||
data_copy = data.copy()
|
data_copy = data.copy()
|
||||||
return cls(**data_copy)
|
return cls(**data_copy)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
|
|
||||||
"""Create LoraMetadata instance from Civitai version info"""
|
|
||||||
file_name = file_info['name']
|
|
||||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
file_name=os.path.splitext(file_name)[0],
|
|
||||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
|
||||||
file_path=save_path.replace(os.sep, '/'),
|
|
||||||
size=file_info.get('sizeKB', 0) * 1024,
|
|
||||||
modified=datetime.now().timestamp(),
|
|
||||||
sha256=file_info['hashes'].get('SHA256', '').lower(),
|
|
||||||
base_model=base_model,
|
|
||||||
preview_url=None, # Will be updated after preview download
|
|
||||||
preview_nsfw_level=0, # Will be updated after preview download, it is decided by the nsfw level of the preview image
|
|
||||||
from_civitai=True,
|
|
||||||
civitai=version_info
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
def to_dict(self) -> Dict:
|
||||||
"""Convert to dictionary for JSON serialization"""
|
"""Convert to dictionary for JSON serialization"""
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
@@ -76,30 +54,54 @@ class LoraMetadata:
|
|||||||
self.file_path = file_path.replace(os.sep, '/')
|
self.file_path = file_path.replace(os.sep, '/')
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CheckpointMetadata:
|
class LoraMetadata(BaseModelMetadata):
|
||||||
|
"""Represents the metadata structure for a Lora model"""
|
||||||
|
usage_tips: str = "{}" # Usage tips for the model, json string
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
|
||||||
|
"""Create LoraMetadata instance from Civitai version info"""
|
||||||
|
file_name = file_info['name']
|
||||||
|
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
file_name=os.path.splitext(file_name)[0],
|
||||||
|
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||||
|
file_path=save_path.replace(os.sep, '/'),
|
||||||
|
size=file_info.get('sizeKB', 0) * 1024,
|
||||||
|
modified=datetime.now().timestamp(),
|
||||||
|
sha256=file_info['hashes'].get('SHA256', '').lower(),
|
||||||
|
base_model=base_model,
|
||||||
|
preview_url=None, # Will be updated after preview download
|
||||||
|
preview_nsfw_level=0, # Will be updated after preview download
|
||||||
|
from_civitai=True,
|
||||||
|
civitai=version_info
|
||||||
|
)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CheckpointMetadata(BaseModelMetadata):
|
||||||
"""Represents the metadata structure for a Checkpoint model"""
|
"""Represents the metadata structure for a Checkpoint model"""
|
||||||
file_name: str # The filename without extension
|
|
||||||
model_name: str # The checkpoint's name defined by the creator
|
|
||||||
file_path: str # Full path to the model file
|
|
||||||
size: int # File size in bytes
|
|
||||||
modified: float # Last modified timestamp
|
|
||||||
sha256: str # SHA256 hash of the file
|
|
||||||
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
|
|
||||||
preview_url: str # Preview image URL
|
|
||||||
preview_nsfw_level: int = 0 # NSFW level of the preview image
|
|
||||||
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
|
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
|
||||||
notes: str = "" # Additional notes
|
|
||||||
from_civitai: bool = True # Whether from Civitai
|
|
||||||
civitai: Optional[Dict] = None # Civitai API data if available
|
|
||||||
tags: List[str] = None # Model tags
|
|
||||||
modelDescription: str = "" # Full model description
|
|
||||||
|
|
||||||
# Additional checkpoint-specific fields
|
@classmethod
|
||||||
resolution: Optional[str] = None # Native resolution (e.g., 512x512, 1024x1024)
|
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
|
||||||
vae_included: bool = False # Whether VAE is included in the checkpoint
|
"""Create CheckpointMetadata instance from Civitai version info"""
|
||||||
architecture: str = "" # Model architecture (if known)
|
file_name = file_info['name']
|
||||||
|
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||||
|
model_type = version_info.get('type', 'checkpoint')
|
||||||
|
|
||||||
def __post_init__(self):
|
return cls(
|
||||||
if self.tags is None:
|
file_name=os.path.splitext(file_name)[0],
|
||||||
self.tags = []
|
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||||
|
file_path=save_path.replace(os.sep, '/'),
|
||||||
|
size=file_info.get('sizeKB', 0) * 1024,
|
||||||
|
modified=datetime.now().timestamp(),
|
||||||
|
sha256=file_info['hashes'].get('SHA256', '').lower(),
|
||||||
|
base_model=base_model,
|
||||||
|
preview_url=None, # Will be updated after preview download
|
||||||
|
preview_nsfw_level=0,
|
||||||
|
from_civitai=True,
|
||||||
|
civitai=version_info,
|
||||||
|
model_type=model_type
|
||||||
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user