diff --git a/py/services/download_manager.py b/py/services/download_manager.py index cce8cc84..198631a6 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -9,6 +9,7 @@ from urllib.parse import urlparse from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES from ..utils.civitai_utils import rewrite_preview_url +from ..utils.utils import sanitize_folder_name from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager from .service_registry import ServiceRegistry @@ -427,8 +428,8 @@ class DownloadManager: formatted_path = formatted_path.replace('{base_model}', mapped_base_model) formatted_path = formatted_path.replace('{first_tag}', first_tag) formatted_path = formatted_path.replace('{author}', author) - formatted_path = formatted_path.replace('{model_name}', model_info.get('name', '')) - formatted_path = formatted_path.replace('{version_name}', version_info.get('name', '')) + formatted_path = formatted_path.replace('{model_name}', sanitize_folder_name(model_info.get('name', ''))) + formatted_path = formatted_path.replace('{version_name}', sanitize_folder_name(version_info.get('name', ''))) if model_type == 'embedding': formatted_path = formatted_path.replace(' ', '_') diff --git a/py/utils/utils.py b/py/utils/utils.py index fa0d0fc0..77b2a6c2 100644 --- a/py/utils/utils.py +++ b/py/utils/utils.py @@ -1,5 +1,6 @@ from difflib import SequenceMatcher import os +import re from typing import Dict from ..services.service_registry import ServiceRegistry from ..config import config @@ -85,6 +86,41 @@ def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool: # All words found either as substrings or fuzzy matches return True +def sanitize_folder_name(name: str, replacement: str = "_") -> str: + """Sanitize a folder name by removing or replacing invalid characters. + + Args: + name: The original folder name. + replacement: The character to use when replacing invalid characters. + + Returns: + A sanitized folder name safe to use across common filesystems. + """ + + if not name: + return "" + + # Replace invalid characters commonly restricted on Windows and POSIX + invalid_chars_pattern = r'[<>:"/\\|?*\x00-\x1f]' + sanitized = re.sub(invalid_chars_pattern, replacement, name) + + # Trim whitespace introduced during sanitization + sanitized = sanitized.strip() + + # Collapse repeated replacement characters to a single instance + if replacement: + sanitized = re.sub(f"{re.escape(replacement)}+", replacement, sanitized) + sanitized = sanitized.strip(replacement) + + # Remove trailing spaces or periods which are invalid on Windows + sanitized = sanitized.rstrip(" .") + + if not sanitized: + return "unnamed" + + return sanitized + + def calculate_recipe_fingerprint(loras): """ Calculate a unique fingerprint for a recipe based on its LoRAs. @@ -175,11 +211,11 @@ def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora' first_tag = 'no tags' # Default if no tags available # Format the template with available data - model_name = model_data.get('model_name', '') + model_name = sanitize_folder_name(model_data.get('model_name', '')) version_name = '' if isinstance(civitai_data, dict): - version_name = civitai_data.get('name') or '' + version_name = sanitize_folder_name(civitai_data.get('name') or '') formatted_path = path_template formatted_path = formatted_path.replace('{base_model}', mapped_base_model) diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py index d5acff88..4306adf7 100644 --- a/tests/services/test_download_manager.py +++ b/tests/services/test_download_manager.py @@ -340,6 +340,22 @@ def test_relative_path_supports_model_and_version_placeholders(): assert relative_path == "Fancy Model/Version One" +def test_relative_path_sanitizes_model_and_version_placeholders(): + manager = DownloadManager() + settings_manager = get_settings_manager() + settings_manager.settings["download_path_templates"]["lora"] = "{model_name}/{version_name}" + + version_info = { + "baseModel": "BaseModel", + "name": "Version:One?", + "model": {"name": "Fancy:Model*", "tags": []}, + } + + relative_path = manager._calculate_relative_path(version_info, "lora") + + assert relative_path == "Fancy_Model/Version_One" + + async def test_execute_download_retries_urls(monkeypatch, tmp_path): manager = DownloadManager() diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 952c5944..a0b34258 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -4,6 +4,7 @@ from py.services.settings_manager import SettingsManager, get_settings_manager from py.utils.utils import ( calculate_recipe_fingerprint, calculate_relative_path_for_model, + sanitize_folder_name, ) @@ -68,6 +69,21 @@ def test_calculate_relative_path_supports_model_and_version(isolated_settings): assert relative_path == "Fancy Model/Version One" +def test_calculate_relative_path_sanitizes_model_and_version_names(isolated_settings): + isolated_settings["download_path_templates"]["lora"] = "{model_name}/{version_name}" + + model_data = { + "model_name": "Fancy:Model*", + "base_model": "SDXL", + "tags": ["tag"], + "civitai": {"id": 1, "name": "Version:One?", "creator": {"username": "Creator"}}, + } + + relative_path = calculate_relative_path_for_model(model_data, "lora") + + assert relative_path == "Fancy_Model/Version_One" + + def test_calculate_recipe_fingerprint_filters_and_sorts(): loras = [ {"hash": "ABC", "strength": 0.1234}, @@ -84,3 +100,17 @@ def test_calculate_recipe_fingerprint_filters_and_sorts(): def test_calculate_recipe_fingerprint_empty_input(): assert calculate_recipe_fingerprint([]) == "" + + +@pytest.mark.parametrize( + "original, expected", + [ + ("ValidName", "ValidName"), + ("Invalid:Name", "Invalid_Name"), + ("Trailing. ", "Trailing"), + ("", ""), + (":::", "unnamed"), + ], +) +def test_sanitize_folder_name(original, expected): + assert sanitize_folder_name(original) == expected