feat: implement auto-organize models endpoint with batch processing and error handling

This commit is contained in:
Will Miao
2025-08-08 19:13:12 +08:00
parent 286f4ff384
commit a920921570
3 changed files with 284 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
import os
import json import json
import logging import logging
from aiohttp import web from aiohttp import web
@@ -10,6 +11,8 @@ import jinja2
from ..utils.routes_common import ModelRouteUtils from ..utils.routes_common import ModelRouteUtils
from ..services.websocket_manager import ws_manager from ..services.websocket_manager import ws_manager
from ..services.settings_manager import settings from ..services.settings_manager import settings
from ..utils.utils import calculate_relative_path_for_model
from ..utils.constants import AUTO_ORGANIZE_BATCH_SIZE
from ..config import config from ..config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -50,6 +53,7 @@ class BaseModelRoutes(ABC):
app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates) app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates)
app.router.add_post(f'/api/{prefix}/move_model', self.move_model) app.router.add_post(f'/api/{prefix}/move_model', self.move_model)
app.router.add_post(f'/api/{prefix}/move_models_bulk', self.move_models_bulk) app.router.add_post(f'/api/{prefix}/move_models_bulk', self.move_models_bulk)
app.router.add_get(f'/api/{prefix}/auto-organize', self.auto_organize_models)
# Common query routes # Common query routes
app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags) app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags)
@@ -736,3 +740,223 @@ class BaseModelRoutes(ABC):
except Exception as e: except Exception as e:
logger.error(f"Error moving models in bulk: {e}", exc_info=True) logger.error(f"Error moving models in bulk: {e}", exc_info=True)
return web.Response(text=str(e), status=500) return web.Response(text=str(e), status=500)
async def auto_organize_models(self, request: web.Request) -> web.Response:
"""Auto-organize all models based on current settings"""
try:
# Get all models from cache
cache = await self.service.scanner.get_cached_data()
all_models = cache.raw_data
# Get model roots for this scanner
model_roots = self.service.get_model_roots()
if not model_roots:
return web.json_response({
'success': False,
'error': 'No model roots configured'
}, status=400)
# Check if flat structure is configured
path_template = settings.get('download_path_template', '{base_model}/{first_tag}')
is_flat_structure = not path_template
# Prepare results tracking
results = []
total_models = len(all_models)
processed = 0
success_count = 0
failure_count = 0
skipped_count = 0
# Send initial progress via WebSocket
await ws_manager.broadcast({
'type': 'auto_organize_progress',
'status': 'started',
'total': total_models,
'processed': 0,
'success': 0,
'failures': 0,
'skipped': 0
})
# Process models in batches
for i in range(0, total_models, AUTO_ORGANIZE_BATCH_SIZE):
batch = all_models[i:i + AUTO_ORGANIZE_BATCH_SIZE]
for model in batch:
try:
file_path = model.get('file_path')
if not file_path:
if len(results) < 100: # Limit detailed results
results.append({
"model": model.get('model_name', 'Unknown'),
"success": False,
"message": "No file path found"
})
failure_count += 1
processed += 1
continue
# Find which model root this file belongs to
current_root = None
for root in model_roots:
# Normalize paths for comparison
normalized_root = os.path.normpath(root).replace(os.sep, '/')
normalized_file = os.path.normpath(file_path).replace(os.sep, '/')
if normalized_file.startswith(normalized_root):
current_root = root
break
if not current_root:
if len(results) < 100: # Limit detailed results
results.append({
"model": model.get('model_name', 'Unknown'),
"success": False,
"message": "Model file not found in any configured root directory"
})
failure_count += 1
processed += 1
continue
# Handle flat structure case
if is_flat_structure:
current_dir = os.path.dirname(file_path)
# Check if already in root directory
if os.path.normpath(current_dir) == os.path.normpath(current_root):
skipped_count += 1
processed += 1
continue
# Move to root directory for flat structure
target_dir = current_root
else:
# Calculate new relative path based on settings
new_relative_path = calculate_relative_path_for_model(model)
# If no relative path calculated (insufficient metadata), skip
if not new_relative_path:
if len(results) < 100: # Limit detailed results
results.append({
"model": model.get('model_name', 'Unknown'),
"success": False,
"message": "Skipped - insufficient metadata for organization"
})
skipped_count += 1
processed += 1
continue
# Calculate target directory
target_dir = os.path.join(current_root, new_relative_path).replace(os.sep, '/')
current_dir = os.path.dirname(file_path)
# Skip if already in correct location
if os.path.normpath(current_dir) == os.path.normpath(target_dir):
skipped_count += 1
processed += 1
continue
# Check if target file would conflict
file_name = os.path.basename(file_path)
target_file_path = os.path.join(target_dir, file_name)
if os.path.exists(target_file_path):
if len(results) < 100: # Limit detailed results
results.append({
"model": model.get('model_name', 'Unknown'),
"success": False,
"message": f"Target file already exists: {target_file_path}"
})
failure_count += 1
processed += 1
continue
# Perform the move
success = await self.service.scanner.move_model(file_path, target_dir)
if success:
success_count += 1
else:
if len(results) < 100: # Limit detailed results
results.append({
"model": model.get('model_name', 'Unknown'),
"success": False,
"message": "Failed to move model"
})
failure_count += 1
processed += 1
except Exception as e:
logger.error(f"Error processing model {model.get('model_name', 'Unknown')}: {e}", exc_info=True)
if len(results) < 100: # Limit detailed results
results.append({
"model": model.get('model_name', 'Unknown'),
"success": False,
"message": f"Error: {str(e)}"
})
failure_count += 1
processed += 1
# Send progress update after each batch
await ws_manager.broadcast({
'type': 'auto_organize_progress',
'status': 'processing',
'total': total_models,
'processed': processed,
'success': success_count,
'failures': failure_count,
'skipped': skipped_count
})
# Small delay between batches to prevent overwhelming the system
await asyncio.sleep(0.1)
# Send completion message
await ws_manager.broadcast({
'type': 'auto_organize_progress',
'status': 'completed',
'total': total_models,
'processed': processed,
'success': success_count,
'failures': failure_count,
'skipped': skipped_count
})
# Prepare response with limited details
response_data = {
'success': True,
'message': f'Auto-organize completed: {success_count} moved, {skipped_count} skipped, {failure_count} failed out of {total_models} total',
'summary': {
'total': total_models,
'success': success_count,
'skipped': skipped_count,
'failures': failure_count,
'organization_type': 'flat' if is_flat_structure else 'structured'
}
}
# Only include detailed results if under limit
if len(results) <= 100:
response_data['results'] = results
else:
response_data['results_truncated'] = True
response_data['sample_results'] = results[:50] # Show first 50 as sample
return web.json_response(response_data)
except Exception as e:
logger.error(f"Error in auto_organize_models: {e}", exc_info=True)
# Send error message via WebSocket
await ws_manager.broadcast({
'type': 'auto_organize_progress',
'status': 'error',
'error': str(e)
})
return web.json_response({
'success': False,
'error': str(e)
}, status=500)

View File

@@ -48,6 +48,9 @@ SUPPORTED_MEDIA_EXTENSIONS = {
# Valid Lora types # Valid Lora types
VALID_LORA_TYPES = ['lora', 'locon', 'dora'] VALID_LORA_TYPES = ['lora', 'locon', 'dora']
# Auto-organize settings
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system
# Civitai model tags in priority order for subfolder organization # Civitai model tags in priority order for subfolder organization
CIVITAI_MODEL_TAGS = [ CIVITAI_MODEL_TAGS = [
'character', 'style', 'concept', 'clothing', 'character', 'style', 'concept', 'clothing',

View File

@@ -1,7 +1,10 @@
from difflib import SequenceMatcher from difflib import SequenceMatcher
import os import os
from typing import Dict
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..config import config from ..config import config
from ..services.settings_manager import settings
from .constants import CIVITAI_MODEL_TAGS
import asyncio import asyncio
def get_lora_info(lora_name): def get_lora_info(lora_name):
@@ -128,3 +131,56 @@ def calculate_recipe_fingerprint(loras):
fingerprint = "|".join([f"{hash_value}:{strength}" for hash_value, strength in valid_loras]) fingerprint = "|".join([f"{hash_value}:{strength}" for hash_value, strength in valid_loras])
return fingerprint return fingerprint
def calculate_relative_path_for_model(model_data: Dict) -> str:
"""Calculate relative path for existing model using template from settings
Args:
model_data: Model data from scanner cache
Returns:
Relative path string (empty string for flat structure)
"""
# Get path template from settings, default to '{base_model}/{first_tag}'
path_template = settings.get('download_path_template', '{base_model}/{first_tag}')
# If template is empty, return empty path (flat structure)
if not path_template:
return ''
# Get base model name from model metadata
civitai_data = model_data.get('civitai', {})
# For CivitAI models, prefer civitai data; for non-CivitAI models, use model_data directly
if civitai_data:
base_model = civitai_data.get('baseModel', '')
else:
# Fallback to model_data fields for non-CivitAI models
base_model = model_data.get('base_model', '')
model_tags = model_data.get('tags', [])
# Apply mapping if available
base_model_mappings = settings.get('base_model_path_mappings', {})
mapped_base_model = base_model_mappings.get(base_model, base_model)
# Find the first Civitai model tag that exists in model_tags
first_tag = ''
for civitai_tag in CIVITAI_MODEL_TAGS:
if civitai_tag in model_tags:
first_tag = civitai_tag
break
# If no Civitai model tag found, fallback to first tag
if not first_tag and model_tags:
first_tag = model_tags[0]
if not first_tag:
first_tag = 'no tags' # Default if no tags available
# Format the template with available data
formatted_path = path_template
formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
formatted_path = formatted_path.replace('{first_tag}', first_tag)
return formatted_path