feat: Add JSON parsing for base_model_path_mappings and refactor path handling in DownloadManager

This commit is contained in:
Will Miao
2025-07-21 07:37:34 +08:00
parent 0c883433c1
commit 124002a472
2 changed files with 57 additions and 56 deletions

View File

@@ -1,3 +1,4 @@
import json
import logging import logging
import os import os
import sys import sys
@@ -179,6 +180,16 @@ class MiscRoutes:
if old_path != value: if old_path != value:
logger.info(f"Example images path changed to {value} - server restart required") logger.info(f"Example images path changed to {value} - server restart required")
# Special handling for base_model_path_mappings - parse JSON string
if key == 'base_model_path_mappings' and value:
try:
value = json.loads(value)
except json.JSONDecodeError:
return web.json_response({
'success': False,
'error': f"Invalid JSON format for base_model_path_mappings: {value}"
})
# Save to settings # Save to settings
settings.set(key, value) settings.set(key, value)

View File

@@ -1,7 +1,6 @@
import logging import logging
import os import os
import asyncio import asyncio
import yaml
from typing import Dict from typing import Dict
from ..utils.models import LoraMetadata, CheckpointMetadata from ..utils.models import LoraMetadata, CheckpointMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
@@ -34,36 +33,6 @@ class DownloadManager:
self._initialized = True self._initialized = True
self._civitai_client = None # Will be lazily initialized self._civitai_client = None # Will be lazily initialized
self._path_mappings = self._load_path_mappings()
def _load_path_mappings(self):
"""Load path mappings from YAML configuration"""
path_mappings = {
'base_models': {},
'model_tags': {}
}
# Path to the configuration file
config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'path_mappings.yaml')
try:
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
mappings = yaml.safe_load(f)
if mappings and isinstance(mappings, dict):
if 'base_models' in mappings and isinstance(mappings['base_models'], dict):
path_mappings['base_models'] = mappings['base_models']
if 'model_tags' in mappings and isinstance(mappings['model_tags'], dict):
path_mappings['model_tags'] = mappings['model_tags']
logger.info(f"Loaded path mappings from {config_path}")
else:
logger.info(f"Path mappings configuration file not found at {config_path}, using default mappings")
except Exception as e:
logger.error(f"Error loading path mappings: {e}", exc_info=True)
return path_mappings
async def _get_civitai_client(self): async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry""" """Lazily initialize CivitaiClient from registry"""
@@ -157,31 +126,8 @@ class DownloadManager:
return {'success': False, 'error': 'Default lora root path not set in settings'} return {'success': False, 'error': 'Default lora root path not set in settings'}
save_dir = default_path save_dir = default_path
# Set relative_path to version_info.baseModel/prioritized_tag # Calculate relative path using template
base_model = version_info.get('baseModel', '') relative_path = self._calculate_relative_path(version_info)
model_tags = version_info.get('model', {}).get('tags', [])
if base_model:
# Apply base model mapping if available
mapped_base_model = self._path_mappings['base_models'].get(base_model, base_model)
# Find the first Civitai model tag that exists in model_tags
prioritized_tag = None
for civitai_tag in CIVITAI_MODEL_TAGS:
if civitai_tag in model_tags:
prioritized_tag = civitai_tag
break
# If no Civitai model tag found, fallback to first tag
if prioritized_tag is None and model_tags:
prioritized_tag = model_tags[0]
if prioritized_tag:
# Apply tag mapping if available
mapped_tag = self._path_mappings['model_tags'].get(prioritized_tag, prioritized_tag)
relative_path = os.path.join(mapped_base_model, mapped_tag)
else:
relative_path = mapped_base_model
# Update save directory with relative path if provided # Update save directory with relative path if provided
if relative_path: if relative_path:
@@ -250,6 +196,50 @@ class DownloadManager:
return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."} return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."}
return {'success': False, 'error': str(e)} return {'success': False, 'error': str(e)}
def _calculate_relative_path(self, version_info: Dict) -> str:
"""Calculate relative path using template from settings
Args:
version_info: Version info from Civitai API
Returns:
Relative path string
"""
# Get path template from settings, default to '{base_model}/{first_tag}'
path_template = settings.get('download_path_template', '{base_model}/{first_tag}')
# If template is empty, return empty path (flat structure)
if not path_template:
return ''
# Get base model name
base_model = version_info.get('baseModel', '')
# Apply mapping if available
base_model_mappings = settings.get('base_model_path_mappings', {})
mapped_base_model = base_model_mappings.get(base_model, base_model)
# Get model tags
model_tags = version_info.get('model', {}).get('tags', [])
# Find the first Civitai model tag that exists in model_tags
first_tag = ''
for civitai_tag in CIVITAI_MODEL_TAGS:
if civitai_tag in model_tags:
first_tag = civitai_tag
break
# If no Civitai model tag found, fallback to first tag
if not first_tag and model_tags:
first_tag = model_tags[0]
# Format the template with available data
formatted_path = path_template
formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
formatted_path = formatted_path.replace('{first_tag}', first_tag)
return formatted_path
async def _execute_download(self, download_url: str, save_dir: str, async def _execute_download(self, download_url: str, save_dir: str,
metadata, version_info: Dict, metadata, version_info: Dict,
relative_path: str, progress_callback=None, relative_path: str, progress_callback=None,