From 124002a472cc1099df7e580b2f2ecab216ace557 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 21 Jul 2025 07:37:34 +0800 Subject: [PATCH] feat: Add JSON parsing for base_model_path_mappings and refactor path handling in DownloadManager --- py/routes/misc_routes.py | 11 ++++ py/services/download_manager.py | 102 ++++++++++++++------------------ 2 files changed, 57 insertions(+), 56 deletions(-) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 54146ce4..019aa82e 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -1,3 +1,4 @@ +import json import logging import os import sys @@ -179,6 +180,16 @@ class MiscRoutes: if old_path != value: logger.info(f"Example images path changed to {value} - server restart required") + # Special handling for base_model_path_mappings - parse JSON string + if key == 'base_model_path_mappings' and value: + try: + value = json.loads(value) + except json.JSONDecodeError: + return web.json_response({ + 'success': False, + 'error': f"Invalid JSON format for base_model_path_mappings: {value}" + }) + # Save to settings settings.set(key, value) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index d53256a8..d0ca7094 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -1,7 +1,6 @@ import logging import os import asyncio -import yaml from typing import Dict from ..utils.models import LoraMetadata, CheckpointMetadata from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS @@ -34,36 +33,6 @@ class DownloadManager: self._initialized = True self._civitai_client = None # Will be lazily initialized - self._path_mappings = self._load_path_mappings() - - def _load_path_mappings(self): - """Load path mappings from YAML configuration""" - path_mappings = { - 'base_models': {}, - 'model_tags': {} - } - - # Path to the configuration file - config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'path_mappings.yaml') - - try: - if os.path.exists(config_path): - with open(config_path, 'r', encoding='utf-8') as f: - mappings = yaml.safe_load(f) - - if mappings and isinstance(mappings, dict): - if 'base_models' in mappings and isinstance(mappings['base_models'], dict): - path_mappings['base_models'] = mappings['base_models'] - if 'model_tags' in mappings and isinstance(mappings['model_tags'], dict): - path_mappings['model_tags'] = mappings['model_tags'] - - logger.info(f"Loaded path mappings from {config_path}") - else: - logger.info(f"Path mappings configuration file not found at {config_path}, using default mappings") - except Exception as e: - logger.error(f"Error loading path mappings: {e}", exc_info=True) - - return path_mappings async def _get_civitai_client(self): """Lazily initialize CivitaiClient from registry""" @@ -157,31 +126,8 @@ class DownloadManager: return {'success': False, 'error': 'Default lora root path not set in settings'} save_dir = default_path - # Set relative_path to version_info.baseModel/prioritized_tag - base_model = version_info.get('baseModel', '') - model_tags = version_info.get('model', {}).get('tags', []) - - if base_model: - # Apply base model mapping if available - mapped_base_model = self._path_mappings['base_models'].get(base_model, base_model) - - # Find the first Civitai model tag that exists in model_tags - prioritized_tag = None - for civitai_tag in CIVITAI_MODEL_TAGS: - if civitai_tag in model_tags: - prioritized_tag = civitai_tag - break - - # If no Civitai model tag found, fallback to first tag - if prioritized_tag is None and model_tags: - prioritized_tag = model_tags[0] - - if prioritized_tag: - # Apply tag mapping if available - mapped_tag = self._path_mappings['model_tags'].get(prioritized_tag, prioritized_tag) - relative_path = os.path.join(mapped_base_model, mapped_tag) - else: - relative_path = mapped_base_model + # Calculate relative path using template + relative_path = self._calculate_relative_path(version_info) # Update save directory with relative path if provided if relative_path: @@ -250,6 +196,50 @@ class DownloadManager: return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."} return {'success': False, 'error': str(e)} + def _calculate_relative_path(self, version_info: Dict) -> str: + """Calculate relative path using template from settings + + Args: + version_info: Version info from Civitai API + + Returns: + Relative path string + """ + # Get path template from settings, default to '{base_model}/{first_tag}' + path_template = settings.get('download_path_template', '{base_model}/{first_tag}') + + # If template is empty, return empty path (flat structure) + if not path_template: + return '' + + # Get base model name + base_model = version_info.get('baseModel', '') + + # Apply mapping if available + base_model_mappings = settings.get('base_model_path_mappings', {}) + mapped_base_model = base_model_mappings.get(base_model, base_model) + + # Get model tags + model_tags = version_info.get('model', {}).get('tags', []) + + # Find the first Civitai model tag that exists in model_tags + first_tag = '' + for civitai_tag in CIVITAI_MODEL_TAGS: + if civitai_tag in model_tags: + first_tag = civitai_tag + break + + # If no Civitai model tag found, fallback to first tag + if not first_tag and model_tags: + first_tag = model_tags[0] + + # Format the template with available data + formatted_path = path_template + formatted_path = formatted_path.replace('{base_model}', mapped_base_model) + formatted_path = formatted_path.replace('{first_tag}', first_tag) + + return formatted_path + async def _execute_download(self, download_url: str, save_dir: str, metadata, version_info: Dict, relative_path: str, progress_callback=None,