From 26b36c123de55e321d1531351485ffe2527efe3e Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Thu, 25 Sep 2025 09:21:08 +0800 Subject: [PATCH 1/2] test(i18n): modularize translation validation --- tests/i18n/test_i18n.py | 1192 +++++++-------------------------------- 1 file changed, 202 insertions(+), 990 deletions(-) diff --git a/tests/i18n/test_i18n.py b/tests/i18n/test_i18n.py index 1adf0dc9..41631287 100644 --- a/tests/i18n/test_i18n.py +++ b/tests/i18n/test_i18n.py @@ -1,1035 +1,247 @@ -#!/usr/bin/env python3 -""" -Test script to verify the updated i18n system works correctly. -This tests both JavaScript loading and Python server-side functionality. +"""Regression tests for localization data and usage. + +These tests validate three key aspects of the localisation setup: + +* Every locale file is valid JSON and contains the expected sections. +* All locales expose the same translation keys as the English reference. +* Static JavaScript/HTML sources only reference available translation keys. """ -import glob +from __future__ import annotations + import json -import os import re -import sys from pathlib import Path -from typing import Any, Dict, List, Set +from typing import Dict, Iterable, Set + +import pytest ROOT_DIR = Path(__file__).resolve().parents[2] -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) +LOCALES_DIR = ROOT_DIR / "locales" +STATIC_JS_DIR = ROOT_DIR / "static" / "js" +TEMPLATES_DIR = ROOT_DIR / "templates" + +EXPECTED_LOCALES = ( + "en", + "zh-CN", + "zh-TW", + "ja", + "ru", + "de", + "fr", + "es", + "ko", +) + +REQUIRED_SECTIONS = {"common", "header", "loras", "recipes", "modals"} + +SINGLE_WORD_TRANSLATION_KEYS = { + "loading", + "error", + "success", + "warning", + "info", + "cancel", + "save", + "delete", +} + +FALSE_POSITIVES = { + "checkpoint", + "civitai_api_key", + "div", + "embedding", + "lora", + "show_only_sfw", + "model", + "type", + "name", + "value", + "id", + "class", + "style", + "src", + "href", + "data", + "width", + "height", + "size", + "format", + "version", + "url", + "path", + "file", + "folder", + "image", + "text", + "number", + "boolean", + "array", + "object", + "non.existent.key", +} + +SPECIAL_UI_HELPER_KEYS = { + "uiHelpers.workflow.loraAdded", + "uiHelpers.workflow.loraReplaced", + "uiHelpers.workflow.loraFailedToSend", + "uiHelpers.workflow.recipeAdded", + "uiHelpers.workflow.recipeReplaced", + "uiHelpers.workflow.recipeFailedToSend", +} + +JS_TRANSLATION_PATTERNS = ( + r"\btranslate\s*\(\s*['\"]([a-zA-Z0-9._-]+)['\"]", + r"\bshowToast\s*\(\s*['\"]([a-zA-Z0-9._-]+)['\"]", + r"\bt\s*\(\s*['\"]([a-zA-Z0-9._-]+)['\"]", +) + +HTML_TRANSLATION_PATTERN = ( + r"(?:\{\{|\{%)[^}]*\bt\s*\(\s*['\"]([a-zA-Z0-9._-]+)['\"][^}]*(?:\}\}|%\})" +) -def check_json_files_exist() -> bool: - """Test that all JSON locale files exist and are valid JSON.""" - print("Testing JSON locale files...") - return check_json_structure_validation() +@pytest.fixture(scope="module") +def loaded_locales() -> Dict[str, dict]: + """Load locale JSON once per test module.""" + locales: Dict[str, dict] = {} + for locale in EXPECTED_LOCALES: + path = LOCALES_DIR / f"{locale}.json" + if not path.exists(): + pytest.fail(f"Locale file {path.name} is missing", pytrace=False) -def check_locale_files_structural_consistency() -> bool: - """Test that all locale files have identical structure, line counts, and formatting.""" - print("\nTesting locale files structural consistency...") - - locales_dir = ROOT_DIR / 'locales' - if not locales_dir.exists(): - print("āŒ Locales directory does not exist!") - return False - - # Get all locale files - locale_files = [] - for file in os.listdir(locales_dir): - if file.endswith('.json'): - locale_files.append(file) - - if not locale_files: - print("āŒ No locale files found!") - return False - - # Use en.json as the reference - reference_file = 'en.json' - if reference_file not in locale_files: - print(f"āŒ Reference file {reference_file} not found!") - return False - - locale_files.remove(reference_file) - locale_files.insert(0, reference_file) # Put reference first - - success = True - - # Load and parse the reference file - reference_path = locales_dir / reference_file - try: - with open(reference_path, 'r', encoding='utf-8') as f: - reference_lines = f.readlines() - reference_content = ''.join(reference_lines) - - reference_data = json.loads(reference_content) - reference_structure = get_json_structure(reference_data) - - print(f"šŸ“‹ Reference file {reference_file}:") - print(f" Lines: {len(reference_lines)}") - print(f" Keys: {len(get_all_translation_keys(reference_data))}") - - except Exception as e: - print(f"āŒ Error reading reference file {reference_file}: {e}") - return False - - # Compare each locale file with the reference - for locale_file in locale_files[1:]: # Skip reference file - locale_path = locales_dir / locale_file - locale_name = locale_file.replace('.json', '') - try: - with open(locale_path, 'r', encoding='utf-8') as f: - locale_lines = f.readlines() - locale_content = ''.join(locale_lines) - - locale_data = json.loads(locale_content) - locale_structure = get_json_structure(locale_data) - - # Test 1: Line count consistency - if len(locale_lines) != len(reference_lines): - print(f"āŒ {locale_name}: Line count mismatch!") - print(f" Reference: {len(reference_lines)} lines") - print(f" {locale_name}: {len(locale_lines)} lines") - success = False - continue - - # Test 2: Structural consistency (key order and nesting) - structure_issues = compare_json_structures(reference_structure, locale_structure) - if structure_issues: - print(f"āŒ {locale_name}: Structure mismatch!") - for issue in structure_issues[:5]: # Show first 5 issues - print(f" - {issue}") - if len(structure_issues) > 5: - print(f" ... and {len(structure_issues) - 5} more issues") - success = False - continue - - # Test 3: Line-by-line format consistency (excluding translation values) - format_issues = compare_line_formats(reference_lines, locale_lines, locale_name) - if format_issues: - print(f"āŒ {locale_name}: Format mismatch!") - for issue in format_issues[:5]: # Show first 5 issues - print(f" - {issue}") - if len(format_issues) > 5: - print(f" ... and {len(format_issues) - 5} more issues") - success = False - continue - - # Test 4: Key completeness - reference_keys = get_all_translation_keys(reference_data) - locale_keys = get_all_translation_keys(locale_data) - - missing_keys = reference_keys - locale_keys - extra_keys = locale_keys - reference_keys - - if missing_keys or extra_keys: - print(f"āŒ {locale_name}: Key mismatch!") - if missing_keys: - print(f" Missing {len(missing_keys)} keys") - if extra_keys: - print(f" Extra {len(extra_keys)} keys") - success = False - continue - - print(f"āœ… {locale_name}: Structure and format consistent") - - except json.JSONDecodeError as e: - print(f"āŒ {locale_name}: Invalid JSON syntax: {e}") - success = False - except Exception as e: - print(f"āŒ {locale_name}: Error during validation: {e}") - success = False - - if success: - print(f"\nāœ… All {len(locale_files)} locale files have consistent structure and formatting") - - return success + data = json.loads(path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: # pragma: no cover - explicit failure message + pytest.fail(f"Locale file {path.name} contains invalid JSON: {exc}", pytrace=False) -def get_json_structure(data: Any, path: str = '') -> Dict[str, Any]: - """ - Extract the structural information from JSON data. - Returns a dictionary describing the structure without the actual values. - """ - if isinstance(data, dict): - structure = {} - for key, value in data.items(): - current_path = f"{path}.{key}" if path else key - if isinstance(value, dict): - structure[key] = get_json_structure(value, current_path) - elif isinstance(value, list): - structure[key] = {'_type': 'array', '_length': len(value)} - if value: # If array is not empty, analyze first element - structure[key]['_element_type'] = get_json_structure(value[0], f"{current_path}[0]") - else: - structure[key] = {'_type': type(value).__name__} - return structure - elif isinstance(data, list): - return {'_type': 'array', '_length': len(data)} - else: - return {'_type': type(data).__name__} + if not isinstance(data, dict): + pytest.fail( + f"Locale file {path.name} must contain a JSON object at the top level", + pytrace=False, + ) -def compare_json_structures(ref_structure: Dict[str, Any], locale_structure: Dict[str, Any], path: str = '') -> List[str]: - """ - Compare two JSON structures and return a list of differences. - """ - issues = [] - - # Check for missing keys in locale - for key in ref_structure: - current_path = f"{path}.{key}" if path else key - if key not in locale_structure: - issues.append(f"Missing key: {current_path}") - elif isinstance(ref_structure[key], dict) and '_type' not in ref_structure[key]: - # It's a nested object, recurse - if isinstance(locale_structure[key], dict) and '_type' not in locale_structure[key]: - issues.extend(compare_json_structures(ref_structure[key], locale_structure[key], current_path)) - else: - issues.append(f"Structure mismatch at {current_path}: expected object, got {type(locale_structure[key])}") - elif ref_structure[key] != locale_structure[key]: - issues.append(f"Type mismatch at {current_path}: expected {ref_structure[key]}, got {locale_structure[key]}") - - # Check for extra keys in locale - for key in locale_structure: - current_path = f"{path}.{key}" if path else key - if key not in ref_structure: - issues.append(f"Extra key: {current_path}") - - return issues + locales[locale] = data -def extract_line_structure(line: str) -> Dict[str, str]: - """ - Extract structural elements from a JSON line. - Returns indentation, key (if present), and structural characters. - """ - # Get indentation (leading whitespace) - indentation = len(line) - len(line.lstrip()) - - # Remove leading/trailing whitespace for analysis - stripped_line = line.strip() - - # Extract key if this is a key-value line - key_match = re.match(r'^"([^"]+)"\s*:\s*', stripped_line) - key = key_match.group(1) if key_match else '' - - # Extract structural characters (everything except the actual translation value) - if key: - # For key-value lines, extract everything except the value - # Handle string values in quotes with better escaping support - value_pattern = r'^"[^"]+"\s*:\s*("(?:[^"\\]|\\.)*")(.*?)$' - value_match = re.match(value_pattern, stripped_line) - if value_match: - # Preserve the structure but replace the actual string content - structural_chars = f'"{key}": "VALUE"{value_match.group(2)}' - else: - # Handle non-string values (objects, arrays, booleans, numbers) - colon_pos = stripped_line.find(':') - if colon_pos != -1: - after_colon = stripped_line[colon_pos + 1:].strip() - if after_colon.startswith('"'): - # String value - find the end quote with proper escaping - end_quote = find_closing_quote(after_colon, 1) - if end_quote != -1: - structural_chars = f'"{key}": "VALUE"{after_colon[end_quote + 1:]}' - else: - structural_chars = f'"{key}": "VALUE"' - elif after_colon.startswith('{'): - # Object value - structural_chars = f'"{key}": {{' - elif after_colon.startswith('['): - # Array value - structural_chars = f'"{key}": [' - else: - # Other values (numbers, booleans, null) - # Replace the actual value with a placeholder - value_end = find_value_end(after_colon) - structural_chars = f'"{key}": VALUE{after_colon[value_end:]}' - else: - structural_chars = stripped_line - else: - # For non key-value lines (brackets, etc.), keep as-is - structural_chars = stripped_line - - return { - 'indentation': str(indentation), - 'key': key, - 'structural_chars': structural_chars - } - -def find_value_end(text: str) -> int: - """ - Find the end of a non-string JSON value (number, boolean, null). - """ - for i, char in enumerate(text): - if char in ',}]': - return i - return len(text) - -def find_closing_quote(text: str, start: int) -> int: - """ - Find the position of the closing quote, handling escaped quotes properly. - """ - i = start - while i < len(text): - if text[i] == '"': - # Count preceding backslashes - backslash_count = 0 - j = i - 1 - while j >= 0 and text[j] == '\\': - backslash_count += 1 - j -= 1 - - # If even number of backslashes (including 0), the quote is not escaped - if backslash_count % 2 == 0: - return i - i += 1 - return -1 - -def compare_line_formats(ref_lines: List[str], locale_lines: List[str], locale_name: str) -> List[str]: - """ - Compare line-by-line formatting between reference and locale files. - Only checks structural elements (indentation, brackets, commas) and ignores translation values. - """ - issues = [] - - for i, (ref_line, locale_line) in enumerate(zip(ref_lines, locale_lines)): - line_num = i + 1 - - # Skip empty lines and lines with only whitespace - if not ref_line.strip() and not locale_line.strip(): - continue - - # Extract structural elements from each line - ref_structure = extract_line_structure(ref_line) - locale_structure = extract_line_structure(locale_line) - - # Compare structural elements with more tolerance - structure_issues = [] - - # Check indentation (must be exact) - if ref_structure['indentation'] != locale_structure['indentation']: - structure_issues.append(f"indentation ({ref_structure['indentation']} vs {locale_structure['indentation']})") - - # Check keys (must be exact for structural consistency) - if ref_structure['key'] != locale_structure['key']: - structure_issues.append(f"key ('{ref_structure['key']}' vs '{locale_structure['key']}')") - - # Check structural characters with improved normalization - ref_normalized = normalize_structural_chars(ref_structure['structural_chars']) - locale_normalized = normalize_structural_chars(locale_structure['structural_chars']) - - if ref_normalized != locale_normalized: - # Additional check: if both lines have the same key and similar structure, - # this might be a false positive due to translation content differences - if (ref_structure['key'] and locale_structure['key'] and - ref_structure['key'] == locale_structure['key']): - - # Check if the difference is only in the translation value - ref_has_string_value = '"VALUE"' in ref_normalized - locale_has_string_value = '"VALUE"' in locale_normalized - - if ref_has_string_value and locale_has_string_value: - # Both have string values, check if structure around value is same - ref_structure_only = re.sub(r'"VALUE"', '"X"', ref_normalized) - locale_structure_only = re.sub(r'"VALUE"', '"X"', locale_normalized) - - if ref_structure_only == locale_structure_only: - # Structure is actually the same, skip this as false positive - continue - - structure_issues.append(f"structure ('{ref_normalized}' vs '{locale_normalized}')") - - if structure_issues: - issues.append(f"Line {line_num}: {', '.join(structure_issues)}") - - return issues - -def normalize_structural_chars(structural_chars: str) -> str: - """ - Normalize structural characters for comparison by replacing variable content - with placeholders while preserving the actual structure. - """ - # Normalize the structural characters more carefully - normalized = structural_chars - - # Replace quoted strings with a consistent placeholder, handling escapes - # This regex matches strings while properly handling escaped quotes - string_pattern = r'"(?:[^"\\]|\\.)*"(?=\s*[,}\]:}]|$)' - - # Find all string matches and replace with placeholder - strings = re.findall(string_pattern, normalized) - for string_match in strings: - # Only replace if this looks like a translation value, not a key - if ':' in normalized: - # Check if this string comes after a colon (likely a value) - parts = normalized.split(':', 1) - if len(parts) == 2 and string_match in parts[1]: - normalized = normalized.replace(string_match, '"VALUE"', 1) - - # Normalize whitespace around structural characters - normalized = re.sub(r'\s*:\s*', ': ', normalized) - normalized = re.sub(r'\s*,\s*', ', ', normalized) - normalized = re.sub(r'\s*{\s*', '{ ', normalized) - normalized = re.sub(r'\s*}\s*', ' }', normalized) - - return normalized.strip() - -def check_locale_files_formatting_consistency() -> bool: - """Test that all locale files have identical formatting (whitespace, indentation, etc.).""" - print("\nTesting locale files formatting consistency...") - - locales_dir = ROOT_DIR / 'locales' - expected_locales = ['en', 'zh-CN', 'zh-TW', 'ja', 'ru', 'de', 'fr', 'es', 'ko'] - - # Read reference file (en.json) - reference_path = locales_dir / 'en.json' - try: - with open(reference_path, 'r', encoding='utf-8') as f: - reference_lines = f.readlines() - except Exception as e: - print(f"āŒ Error reading reference file: {e}") - return False - - success = True - - # Compare each locale file - for locale in expected_locales[1:]: # Skip 'en' as it's the reference - locale_path = locales_dir / f'{locale}.json' - - if not os.path.exists(locale_path): - print(f"āŒ {locale}.json does not exist!") - success = False - continue - - try: - with open(locale_path, 'r', encoding='utf-8') as f: - locale_lines = f.readlines() - - # Compare line count - if len(locale_lines) != len(reference_lines): - print(f"āŒ {locale}.json: Line count differs from reference") - print(f" Reference: {len(reference_lines)} lines") - print(f" {locale}: {len(locale_lines)} lines") - success = False - continue - - # Compare formatting with improved algorithm - formatting_issues = compare_line_formats(reference_lines, locale_lines, locale) - - if formatting_issues: - print(f"āŒ {locale}.json: Formatting issues found") - # Show only the first few issues to avoid spam - shown_issues = 0 - for issue in formatting_issues: - if shown_issues < 3: # Reduced from 5 to 3 - print(f" - {issue}") - shown_issues += 1 - else: - break - - if len(formatting_issues) > 3: - print(f" ... and {len(formatting_issues) - 3} more issues") - - # Provide debug info for first issue to help identify false positives - if formatting_issues: - first_issue = formatting_issues[0] - line_match = re.match(r'Line (\d+):', first_issue) - if line_match: - line_num = int(line_match.group(1)) - 1 # Convert to 0-based - if 0 <= line_num < len(reference_lines): - print(f" Debug - Reference line {line_num + 1}: {repr(reference_lines[line_num].rstrip())}") - print(f" Debug - {locale} line {line_num + 1}: {repr(locale_lines[line_num].rstrip())}") - - success = False - else: - print(f"āœ… {locale}.json: Formatting consistent with reference") - - except Exception as e: - print(f"āŒ Error validating {locale}.json: {e}") - success = False - - if success: - print("āœ… All locale files have consistent formatting") - else: - print("šŸ’” Note: Some formatting differences may be false positives due to translation content.") - print(" If translations are correct but structure appears different, the test may need refinement.") - - return success - -def check_locale_key_ordering() -> bool: - """Test that all locale files maintain the same key ordering as the reference.""" - print("\nTesting locale files key ordering...") - - locales_dir = ROOT_DIR / 'locales' - expected_locales = ['en', 'zh-CN', 'zh-TW', 'ja', 'ru', 'de', 'fr', 'es', 'ko'] - - # Load reference file - reference_path = locales_dir / 'en.json' - try: - with open(reference_path, 'r', encoding='utf-8') as f: - reference_data = json.load(f, object_pairs_hook=lambda x: x) # Preserve order - - reference_key_order = get_key_order(reference_data) - except Exception as e: - print(f"āŒ Error reading reference file: {e}") - return False - - success = True - - for locale in expected_locales[1:]: # Skip 'en' as it's the reference - locale_path = locales_dir / f'{locale}.json' - - if not os.path.exists(locale_path): - continue - - try: - with open(locale_path, 'r', encoding='utf-8') as f: - locale_data = json.load(f, object_pairs_hook=lambda x: x) # Preserve order - - locale_key_order = get_key_order(locale_data) - - if reference_key_order != locale_key_order: - print(f"āŒ {locale}.json: Key ordering differs from reference") - - # Find the first difference - for i, (ref_key, locale_key) in enumerate(zip(reference_key_order, locale_key_order)): - if ref_key != locale_key: - print(f" First difference at position {i}: '{ref_key}' vs '{locale_key}'") - break - - success = False - else: - print(f"āœ… {locale}.json: Key ordering matches reference") - - except Exception as e: - print(f"āŒ Error validating {locale}.json key ordering: {e}") - success = False - - return success - -def get_key_order(data: Any, path: str = '') -> List[str]: - """ - Extract the order of keys from nested JSON data. - Returns a list of all keys in their order of appearance. - """ - keys = [] - - if isinstance(data, list): - # Handle list of key-value pairs (from object_pairs_hook) - for key, value in data: - current_path = f"{path}.{key}" if path else key - keys.append(current_path) - if isinstance(value, list): # Nested object as list of pairs - keys.extend(get_key_order(value, current_path)) - elif isinstance(data, dict): - for key, value in data.items(): - current_path = f"{path}.{key}" if path else key - keys.append(current_path) - if isinstance(value, (dict, list)): - keys.extend(get_key_order(value, current_path)) - - return keys - -def check_server_i18n() -> bool: - """Test the Python server-side i18n system.""" - print("\nTesting Python server-side i18n...") - - try: - from py.services.server_i18n import ServerI18nManager - - # Create a new instance to test - i18n = ServerI18nManager() - - # Test that translations loaded - available_locales = i18n.get_available_locales() - if not available_locales: - print("āŒ No locales loaded in server i18n!") - return False - - print(f"āœ… Loaded {len(available_locales)} locales: {', '.join(available_locales)}") - - # Test English translations - i18n.set_locale('en') - test_key = 'common.status.loading' - translation = i18n.get_translation(test_key) - if translation == test_key: - print(f"āŒ Translation not found for key '{test_key}'") - return False - - print(f"āœ… English translation for '{test_key}': '{translation}'") - - # Test Chinese translations - i18n.set_locale('zh-CN') - translation_cn = i18n.get_translation(test_key) - if translation_cn == test_key: - print(f"āŒ Chinese translation not found for key '{test_key}'") - return False - - print(f"āœ… Chinese translation for '{test_key}': '{translation_cn}'") - - # Test parameter interpolation - param_key = 'common.itemCount' - translation_with_params = i18n.get_translation(param_key, count=42) - if '{count}' in translation_with_params: - print(f"āŒ Parameter interpolation failed for key '{param_key}'") - return False - - print(f"āœ… Parameter interpolation for '{param_key}': '{translation_with_params}'") - - print("āœ… Server-side i18n system working correctly") - return True - - except Exception as e: - print(f"āŒ Error testing server i18n: {e}") - import traceback - traceback.print_exc() - return False - -def check_translation_completeness() -> bool: - """Test that all languages have the same translation keys.""" - print("\nTesting translation completeness...") - - locales_dir = ROOT_DIR / 'locales' - - # Load English as reference - with open(locales_dir / 'en.json', 'r', encoding='utf-8') as f: - en_data = json.load(f) - - en_keys = get_all_translation_keys(en_data) - print(f"English has {len(en_keys)} translation keys") - - # Check other languages - locales = ['zh-CN', 'zh-TW', 'ja', 'ru', 'de', 'fr', 'es', 'ko'] - - for locale in locales: - with open(locales_dir / f'{locale}.json', 'r', encoding='utf-8') as f: - locale_data = json.load(f) - - locale_keys = get_all_translation_keys(locale_data) - - missing_keys = en_keys - locale_keys - extra_keys = locale_keys - en_keys - - if missing_keys: - print(f"āŒ {locale} missing keys: {len(missing_keys)}") - # Print first few missing keys - for key in sorted(missing_keys)[:5]: - print(f" - {key}") - if len(missing_keys) > 5: - print(f" ... and {len(missing_keys) - 5} more") - - if extra_keys: - print(f"āš ļø {locale} has extra keys: {len(extra_keys)}") - - if not missing_keys and not extra_keys: - print(f"āœ… {locale} has complete translations ({len(locale_keys)} keys)") - - return True + return locales -def extract_i18n_keys_from_js(file_path: str) -> Set[str]: - """Extract translation keys from JavaScript files.""" - keys = set() - - try: - with open(file_path, 'r', encoding='utf-8') as f: - content = f.read() - - # Remove comments to avoid false positives - # Remove single-line comments - content = re.sub(r'//.*$', '', content, flags=re.MULTILINE) - # Remove multi-line comments - content = re.sub(r'/\*.*?\*/', '', content, flags=re.DOTALL) - - # Pattern for translate() function calls - more specific - # Matches: translate('key.name', ...) or translate("key.name", ...) - # Must have opening parenthesis immediately after translate - translate_pattern = r"\btranslate\s*\(\s*['\"]([a-zA-Z0-9._-]+)['\"]" - translate_matches = re.findall(translate_pattern, content) - - # Filter out single words that are likely not translation keys - # Translation keys should typically have dots or be in specific namespaces - filtered_translate = [key for key in translate_matches if '.' in key or key in [ - 'loading', 'error', 'success', 'warning', 'info', 'cancel', 'save', 'delete' - ]] - keys.update(filtered_translate) - - # Pattern for showToast() function calls - more specific - # Matches: showToast('key.name', ...) or showToast("key.name", ...) - showtoast_pattern = r"\bshowToast\s*\(\s*['\"]([a-zA-Z0-9._-]+)['\"]" - showtoast_matches = re.findall(showtoast_pattern, content) - - # Filter showToast matches as well - filtered_showtoast = [key for key in showtoast_matches if '.' in key or key in [ - 'loading', 'error', 'success', 'warning', 'info', 'cancel', 'save', 'delete' - ]] - keys.update(filtered_showtoast) - - # Additional patterns for other i18n function calls you might have - # Pattern for t() function calls (if used in JavaScript) - t_pattern = r"\bt\s*\(\s*['\"]([a-zA-Z0-9._-]+)['\"]" - t_matches = re.findall(t_pattern, content) - filtered_t = [key for key in t_matches if '.' in key or key in [ - 'loading', 'error', 'success', 'warning', 'info', 'cancel', 'save', 'delete' - ]] - keys.update(filtered_t) - - except Exception as e: - print(f"āš ļø Error reading {file_path}: {e}") - - return keys +@pytest.fixture(scope="module") +def english_translation_keys(loaded_locales: Dict[str, dict]) -> Set[str]: + return collect_translation_keys(loaded_locales["en"]) -def extract_i18n_keys_from_html(file_path: str) -> Set[str]: - """Extract translation keys from HTML template files.""" - keys = set() - - try: - with open(file_path, 'r', encoding='utf-8') as f: - content = f.read() - - # Remove HTML comments to avoid false positives - content = re.sub(r'', '', content, flags=re.DOTALL) - - # Pattern for t() function calls in Jinja2 templates - # Matches: {{ t('key.name') }} or {% ... t('key.name') ... %} - # More specific pattern that ensures we're in template context - t_pattern = r"(?:\{\{|\{%)[^}]*\bt\s*\(\s*['\"]([a-zA-Z0-9._-]+)['\"][^}]*(?:\}\}|%\})" - t_matches = re.findall(t_pattern, content) - - # Filter HTML matches - filtered_t = [key for key in t_matches if '.' in key or key in [ - 'loading', 'error', 'success', 'warning', 'info', 'cancel', 'save', 'delete' - ]] - keys.update(filtered_t) - - # Also check for translate() calls in script tags within HTML - script_pattern = r']*>(.*?)' - script_matches = re.findall(script_pattern, content, flags=re.DOTALL) - for script_content in script_matches: - # Apply JavaScript extraction to script content - translate_pattern = r"\btranslate\s*\(\s*['\"]([a-zA-Z0-9._-]+)['\"]" - script_translate_matches = re.findall(translate_pattern, script_content) - filtered_script = [key for key in script_translate_matches if '.' in key] - keys.update(filtered_script) - - except Exception as e: - print(f"āš ļø Error reading {file_path}: {e}") - - return keys +@pytest.fixture(scope="module") +def static_code_translation_keys() -> Set[str]: + return gather_static_translation_keys() -def get_all_translation_keys(data: dict, prefix: str = '', include_containers: bool = False) -> Set[str]: - """ - Recursively collect translation keys. - By default only leaf keys (where the value is NOT a dict) are returned so that - structural/container nodes (e.g. 'common', 'common.actions') are not treated - as real translation entries and won't appear in the 'unused' list. - - Set include_containers=True to also include container/object nodes. - """ +def collect_translation_keys(data: dict, prefix: str = "") -> Set[str]: + """Recursively collect translation keys from a locale dictionary.""" keys: Set[str] = set() - if not isinstance(data, dict): - return keys + for key, value in data.items(): full_key = f"{prefix}.{key}" if prefix else key if isinstance(value, dict): - # Recurse first - keys.update(get_all_translation_keys(value, full_key, include_containers)) - # Optionally include container nodes - if include_containers: - keys.add(full_key) + keys.update(collect_translation_keys(value, full_key)) else: - # Leaf node: actual translatable value keys.add(full_key) + return keys -def check_static_code_analysis() -> bool: - """Test static code analysis to detect missing translation keys.""" - # print("\nTesting static code analysis for translation keys...") - - # Load English translations as reference - locales_dir = ROOT_DIR / 'locales' - with open(locales_dir / 'en.json', 'r', encoding='utf-8') as f: - en_data = json.load(f) - - available_keys = get_all_translation_keys(en_data) - # print(f"Available translation keys in en.json: {len(available_keys)}") - - # Known false positives to exclude from analysis - # These are typically HTML attributes, CSS classes, or other non-translation strings - false_positives = { - 'checkpoint', 'civitai_api_key', 'div', 'embedding', 'lora', 'show_only_sfw', - 'model', 'type', 'name', 'value', 'id', 'class', 'style', 'src', 'href', - 'data', 'width', 'height', 'size', 'format', 'version', 'url', 'path', - 'file', 'folder', 'image', 'text', 'number', 'boolean', 'array', 'object', 'non.existent.key' - } +def gather_static_translation_keys() -> Set[str]: + """Collect translation keys referenced in static JavaScript and HTML templates.""" + keys: Set[str] = set() - # Special translation keys used in uiHelpers.js but not detected by regex - uihelpers_special_keys = { - 'uiHelpers.workflow.loraAdded', - 'uiHelpers.workflow.loraReplaced', - 'uiHelpers.workflow.loraFailedToSend', - 'uiHelpers.workflow.recipeAdded', - 'uiHelpers.workflow.recipeReplaced', - 'uiHelpers.workflow.recipeFailedToSend', - } - - # Extract keys from JavaScript files - js_dir = ROOT_DIR / 'static' / 'js' - js_files = [] - if os.path.exists(js_dir): - # Recursively find all JS files - for root, dirs, files in os.walk(js_dir): - for file in files: - if file.endswith('.js'): - js_files.append(os.path.join(root, file)) - - js_keys = set() - js_files_with_keys = [] - for js_file in js_files: - file_keys = extract_i18n_keys_from_js(js_file) - # Filter out false positives - file_keys = file_keys - false_positives - js_keys.update(file_keys) - if file_keys: - rel_path = os.path.relpath(js_file, ROOT_DIR) - js_files_with_keys.append((rel_path, len(file_keys))) - # print(f" Found {len(file_keys)} keys in {rel_path}") - - # print(f"Total unique keys found in JavaScript files: {len(js_keys)}") - - # Extract keys from HTML template files - templates_dir = ROOT_DIR / 'templates' - html_files = [] - if os.path.exists(templates_dir): - html_files = glob.glob(os.path.join(templates_dir, '*.html')) - # Also check for HTML files in subdirectories - html_files.extend(glob.glob(os.path.join(templates_dir, '**', '*.html'), recursive=True)) - - html_keys = set() - html_files_with_keys = [] - for html_file in html_files: - file_keys = extract_i18n_keys_from_html(html_file) - # Filter out false positives - file_keys = file_keys - false_positives - html_keys.update(file_keys) - if file_keys: - rel_path = os.path.relpath(html_file, ROOT_DIR) - html_files_with_keys.append((rel_path, len(file_keys))) - # print(f" Found {len(file_keys)} keys in {rel_path}") - - # print(f"Total unique keys found in HTML templates: {len(html_keys)}") - - # Combine all used keys - all_used_keys = js_keys.union(html_keys) - # Add special keys from uiHelpers.js - all_used_keys.update(uihelpers_special_keys) - # print(f"Total unique keys used in code: {len(all_used_keys)}") - - # Check for missing keys - missing_keys = all_used_keys - available_keys - unused_keys = available_keys - all_used_keys - - success = True - - if missing_keys: - print(f"\nāŒ Found {len(missing_keys)} missing translation keys:") - for key in sorted(missing_keys): - print(f" - {key}") - success = False - - # Group missing keys by category for better analysis - key_categories = {} - for key in missing_keys: - category = key.split('.')[0] if '.' in key else 'root' - if category not in key_categories: - key_categories[category] = [] - key_categories[category].append(key) - - print(f"\n Missing keys by category:") - for category, keys in sorted(key_categories.items()): - print(f" {category}: {len(keys)} keys") - - # Provide helpful suggestion - print(f"\nšŸ’” If these are false positives, add them to the false_positives set in test_static_code_analysis()") - else: - print("\nāœ… All translation keys used in code are available in en.json") - - if unused_keys: - print(f"\nāš ļø Found {len(unused_keys)} unused translation keys in en.json:") - # Only show first 20 to avoid cluttering output - for key in sorted(unused_keys)[:20]: - print(f" - {key}") - if len(unused_keys) > 20: - print(f" ... and {len(unused_keys) - 20} more") + if STATIC_JS_DIR.exists(): + for file_path in STATIC_JS_DIR.rglob("*.js"): + keys.update(filter_translation_keys(extract_i18n_keys_from_js(file_path))) - # Group unused keys by category for better analysis - unused_categories = {} - for key in unused_keys: - category = key.split('.')[0] if '.' in key else 'root' - if category not in unused_categories: - unused_categories[category] = [] - unused_categories[category].append(key) - - print(f"\n Unused keys by category:") - for category, keys in sorted(unused_categories.items()): - print(f" {category}: {len(keys)} keys") - - # Summary statistics - # print(f"\nšŸ“Š Static Code Analysis Summary:") - # print(f" JavaScript files analyzed: {len(js_files)}") - # print(f" JavaScript files with translations: {len(js_files_with_keys)}") - # print(f" HTML template files analyzed: {len(html_files)}") - # print(f" HTML template files with translations: {len(html_files_with_keys)}") - # print(f" Translation keys in en.json: {len(available_keys)}") - # print(f" Translation keys used in code: {len(all_used_keys)}") - # print(f" Usage coverage: {len(all_used_keys)/len(available_keys)*100:.1f}%") - - return success + if TEMPLATES_DIR.exists(): + for file_path in TEMPLATES_DIR.rglob("*.html"): + keys.update(filter_translation_keys(extract_i18n_keys_from_html(file_path))) + + keys.update(SPECIAL_UI_HELPER_KEYS) + + return keys -def check_json_structure_validation() -> bool: - """Test JSON file structure and syntax validation.""" - print("\nTesting JSON file structure and syntax validation...") - - locales_dir = ROOT_DIR / 'locales' - if not locales_dir.exists(): - print("āŒ Locales directory does not exist!") - return False - - expected_locales = ['en', 'zh-CN', 'zh-TW', 'ja', 'ru', 'de', 'fr', 'es', 'ko'] - success = True - - for locale in expected_locales: - file_path = locales_dir / f'{locale}.json' - if not file_path.exists(): - print(f"āŒ {locale}.json does not exist!") - success = False +def filter_translation_keys(raw_keys: Iterable[str]) -> Set[str]: + """Filter out obvious false positives and non-translation identifiers.""" + filtered: Set[str] = set() + for key in raw_keys: + if key in FALSE_POSITIVES: continue - - try: - with open(file_path, 'r', encoding='utf-8') as f: - data = json.load(f) - - # Check for valid JSON structure - if not isinstance(data, dict): - print(f"āŒ {locale}.json root must be an object/dictionary") - success = False - continue - - # Check that required sections exist - required_sections = ['common', 'header', 'loras', 'recipes', 'modals'] - missing_sections = [] - for section in required_sections: - if section not in data: - missing_sections.append(section) - - if missing_sections: - print(f"āŒ {locale}.json missing required sections: {', '.join(missing_sections)}") - success = False - - # Check for empty values - empty_values = [] - def check_empty_values(obj, path=''): - if isinstance(obj, dict): - for key, value in obj.items(): - current_path = f"{path}.{key}" if path else key - if isinstance(value, dict): - check_empty_values(value, current_path) - elif isinstance(value, str) and not value.strip(): - empty_values.append(current_path) - elif value is None: - empty_values.append(current_path) - - check_empty_values(data) - - if empty_values: - print(f"āš ļø {locale}.json has {len(empty_values)} empty translation values:") - for path in empty_values[:5]: # Show first 5 - print(f" - {path}") - if len(empty_values) > 5: - print(f" ... and {len(empty_values) - 5} more") - - print(f"āœ… {locale}.json structure is valid") - - except json.JSONDecodeError as e: - print(f"āŒ {locale}.json has invalid JSON syntax: {e}") - success = False - except Exception as e: - print(f"āŒ Error validating {locale}.json: {e}") - success = False - - return success + if "." not in key and key not in SINGLE_WORD_TRANSLATION_KEYS: + continue + filtered.add(key) + return filtered -def test_json_files_are_valid(): - assert check_json_files_exist() +def extract_i18n_keys_from_js(file_path: Path) -> Set[str]: + """Extract translation keys from JavaScript sources.""" + content = file_path.read_text(encoding="utf-8") + # Remove single-line and multi-line comments to avoid false positives. + content = re.sub(r"//.*$", "", content, flags=re.MULTILINE) + content = re.sub(r"/\*.*?\*/", "", content, flags=re.DOTALL) + + matches: Set[str] = set() + for pattern in JS_TRANSLATION_PATTERNS: + matches.update(re.findall(pattern, content)) + return matches -def test_locale_structures_match_reference(): - assert check_locale_files_structural_consistency() +def extract_i18n_keys_from_html(file_path: Path) -> Set[str]: + """Extract translation keys from HTML templates.""" + content = file_path.read_text(encoding="utf-8") + content = re.sub(r"", "", content, flags=re.DOTALL) + + matches: Set[str] = set(re.findall(HTML_TRANSLATION_PATTERN, content)) + + # Inspect inline script tags as JavaScript. + for script_body in re.findall(r"]*>(.*?)", content, flags=re.DOTALL): + for pattern in JS_TRANSLATION_PATTERNS: + matches.update(re.findall(pattern, script_body)) + + return matches -def test_locale_formatting_matches_reference(): - assert check_locale_files_formatting_consistency() +@pytest.mark.parametrize("locale", EXPECTED_LOCALES) +def test_locale_files_have_expected_structure(locale: str, loaded_locales: Dict[str, dict]) -> None: + """Every locale must contain the required sections.""" + data = loaded_locales[locale] + missing_sections = sorted(REQUIRED_SECTIONS - data.keys()) + assert not missing_sections, f"{locale} locale is missing sections: {missing_sections}" -def test_locale_key_order_matches_reference(): - assert check_locale_key_ordering() +@pytest.mark.parametrize("locale", EXPECTED_LOCALES[1:]) +def test_locale_keys_match_english( + locale: str, loaded_locales: Dict[str, dict], english_translation_keys: Set[str] +) -> None: + """Locales must expose the same translation keys as English.""" + locale_keys = collect_translation_keys(loaded_locales[locale]) + + missing_keys = sorted(english_translation_keys - locale_keys) + extra_keys = sorted(locale_keys - english_translation_keys) + + assert not missing_keys, ( + f"{locale} is missing translation keys: {missing_keys[:10]}" + + ("..." if len(missing_keys) > 10 else "") + ) + assert not extra_keys, ( + f"{locale} defines unexpected translation keys: {extra_keys[:10]}" + + ("..." if len(extra_keys) > 10 else "") + ) -def test_server_side_i18n_behaves_as_expected(): - assert check_server_i18n() - - -def test_translations_are_complete(): - assert check_translation_completeness() - - -def test_static_code_analysis_is_clean(): - assert check_static_code_analysis() - - -def test_json_structure_validation(): - assert check_json_structure_validation() - - -def main(): - """Run all tests.""" - print("šŸš€ Testing updated i18n system...\n") - - success = True - - # Test JSON files structure and syntax - if not check_json_files_exist(): - success = False - - # Test comprehensive structural consistency - if not check_locale_files_structural_consistency(): - success = False - - # Test formatting consistency - if not check_locale_files_formatting_consistency(): - success = False - - # Test key ordering - if not check_locale_key_ordering(): - success = False - - # Test server i18n - if not check_server_i18n(): - success = False - - # Test translation completeness - if not check_translation_completeness(): - success = False - - # Test static code analysis - if not check_static_code_analysis(): - success = False - - print(f"\n{'šŸŽ‰ All tests passed!' if success else 'āŒ Some tests failed!'}") - return success - -if __name__ == '__main__': - main() +def test_static_sources_only_use_existing_translations( + english_translation_keys: Set[str], static_code_translation_keys: Set[str] +) -> None: + """Static code must not reference unknown translation keys.""" + missing_keys = sorted(static_code_translation_keys - english_translation_keys) + assert not missing_keys, ( + "Static sources reference missing translation keys: " + f"{missing_keys[:20]}" + ("..." if len(missing_keys) > 20 else "") + ) From 095320ef72b2d92d6d3262d9dc1fc994272af7e1 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Thu, 25 Sep 2025 09:40:25 +0800 Subject: [PATCH 2/2] test(routes): tidy lora route test imports --- tests/routes/test_embedding_routes.py | 50 +++++ tests/routes/test_lora_routes.py | 213 +++++++++++++++++++ tests/services/test_civitai_client.py | 222 ++++++++++++++++++++ tests/services/test_download_manager.py | 247 +++++++++++++++++++++++ tests/services/test_settings_manager.py | 61 ++++++ tests/services/test_websocket_manager.py | 84 ++++++++ 6 files changed, 877 insertions(+) create mode 100644 tests/routes/test_embedding_routes.py create mode 100644 tests/routes/test_lora_routes.py create mode 100644 tests/services/test_civitai_client.py create mode 100644 tests/services/test_download_manager.py create mode 100644 tests/services/test_settings_manager.py create mode 100644 tests/services/test_websocket_manager.py diff --git a/tests/routes/test_embedding_routes.py b/tests/routes/test_embedding_routes.py new file mode 100644 index 00000000..fc1782a0 --- /dev/null +++ b/tests/routes/test_embedding_routes.py @@ -0,0 +1,50 @@ +import json + +import pytest + +from py.routes.embedding_routes import EmbeddingRoutes + + +class DummyRequest: + def __init__(self, *, match_info=None): + self.match_info = match_info or {} + + +class StubEmbeddingService: + def __init__(self): + self.info = {} + + async def get_model_info_by_name(self, name): + value = self.info.get(name) + if isinstance(value, Exception): + raise value + return value + + +@pytest.fixture +def routes(): + handler = EmbeddingRoutes() + handler.service = StubEmbeddingService() + return handler + + +async def test_get_embedding_info_success(routes): + routes.service.info["demo"] = {"name": "demo"} + response = await routes.get_embedding_info(DummyRequest(match_info={"name": "demo"})) + payload = json.loads(response.text) + assert payload == {"name": "demo"} + + +async def test_get_embedding_info_missing(routes): + response = await routes.get_embedding_info(DummyRequest(match_info={"name": "missing"})) + payload = json.loads(response.text) + assert response.status == 404 + assert payload == {"error": "Embedding not found"} + + +async def test_get_embedding_info_error(routes): + routes.service.info["demo"] = RuntimeError("boom") + response = await routes.get_embedding_info(DummyRequest(match_info={"name": "demo"})) + payload = json.loads(response.text) + assert response.status == 500 + assert payload == {"error": "boom"} diff --git a/tests/routes/test_lora_routes.py b/tests/routes/test_lora_routes.py new file mode 100644 index 00000000..2b447987 --- /dev/null +++ b/tests/routes/test_lora_routes.py @@ -0,0 +1,213 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from py.routes.lora_routes import LoraRoutes +from server import PromptServer + + +class DummyRequest: + def __init__(self, *, query=None, match_info=None, json_data=None): + self.query = query or {} + self.match_info = match_info or {} + self._json_data = json_data or {} + + async def json(self): + return self._json_data + + +class StubLoraService: + def __init__(self): + self.notes = {} + self.trigger_words = {} + self.usage_tips = {} + self.previews = {} + self.civitai = {} + + async def get_lora_notes(self, name): + return self.notes.get(name) + + async def get_lora_trigger_words(self, name): + return self.trigger_words.get(name, []) + + async def get_lora_usage_tips_by_relative_path(self, path): + return self.usage_tips.get(path) + + async def get_lora_preview_url(self, name): + return self.previews.get(name) + + async def get_lora_civitai_url(self, name): + return self.civitai.get(name, {"civitai_url": ""}) + + +@pytest.fixture +def routes(): + handler = LoraRoutes() + handler.service = StubLoraService() + return handler + + +async def test_get_lora_notes_success(routes): + routes.service.notes["demo"] = "Great notes" + request = DummyRequest(query={"name": "demo"}) + + response = await routes.get_lora_notes(request) + payload = json.loads(response.text) + + assert payload == {"success": True, "notes": "Great notes"} + + +async def test_get_lora_notes_missing_name(routes): + response = await routes.get_lora_notes(DummyRequest()) + assert response.status == 400 + assert response.text == "Lora file name is required" + + +async def test_get_lora_notes_not_found(routes): + response = await routes.get_lora_notes(DummyRequest(query={"name": "missing"})) + payload = json.loads(response.text) + assert response.status == 404 + assert payload == {"success": False, "error": "LoRA not found in cache"} + + +async def test_get_lora_notes_error(routes, monkeypatch): + async def failing(*_args, **_kwargs): + raise RuntimeError("boom") + + routes.service.get_lora_notes = failing + + response = await routes.get_lora_notes(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + + assert response.status == 500 + assert payload["success"] is False + assert payload["error"] == "boom" + + +async def test_get_lora_trigger_words_success(routes): + routes.service.trigger_words["demo"] = ["trigger"] + response = await routes.get_lora_trigger_words(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + assert payload == {"success": True, "trigger_words": ["trigger"]} + + +async def test_get_lora_trigger_words_missing_name(routes): + response = await routes.get_lora_trigger_words(DummyRequest()) + assert response.status == 400 + + +async def test_get_lora_trigger_words_error(routes): + async def failing(*_args, **_kwargs): + raise RuntimeError("fail") + + routes.service.get_lora_trigger_words = failing + + response = await routes.get_lora_trigger_words(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + assert response.status == 500 + assert payload["success"] is False + + +async def test_get_usage_tips_success(routes): + routes.service.usage_tips["path"] = "tips" + response = await routes.get_lora_usage_tips_by_path(DummyRequest(query={"relative_path": "path"})) + payload = json.loads(response.text) + assert payload == {"success": True, "usage_tips": "tips"} + + +async def test_get_usage_tips_missing_param(routes): + response = await routes.get_lora_usage_tips_by_path(DummyRequest()) + assert response.status == 400 + + +async def test_get_usage_tips_error(routes): + async def failing(*_args, **_kwargs): + raise RuntimeError("bad") + + routes.service.get_lora_usage_tips_by_relative_path = failing + response = await routes.get_lora_usage_tips_by_path(DummyRequest(query={"relative_path": "path"})) + payload = json.loads(response.text) + assert response.status == 500 + assert payload["success"] is False + + +async def test_get_preview_url_success(routes): + routes.service.previews["demo"] = "http://preview" + response = await routes.get_lora_preview_url(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + assert payload == {"success": True, "preview_url": "http://preview"} + + +async def test_get_preview_url_missing(routes): + response = await routes.get_lora_preview_url(DummyRequest()) + assert response.status == 400 + + +async def test_get_preview_url_not_found(routes): + response = await routes.get_lora_preview_url(DummyRequest(query={"name": "missing"})) + payload = json.loads(response.text) + assert response.status == 404 + assert payload["success"] is False + + +async def test_get_civitai_url_success(routes): + routes.service.civitai["demo"] = {"civitai_url": "https://civitai.com"} + response = await routes.get_lora_civitai_url(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + assert payload == {"success": True, "civitai_url": "https://civitai.com"} + + +async def test_get_civitai_url_missing(routes): + response = await routes.get_lora_civitai_url(DummyRequest()) + assert response.status == 400 + + +async def test_get_civitai_url_not_found(routes): + response = await routes.get_lora_civitai_url(DummyRequest(query={"name": "missing"})) + payload = json.loads(response.text) + assert response.status == 404 + assert payload["success"] is False + + +async def test_get_civitai_url_error(routes): + async def failing(*_args, **_kwargs): + raise RuntimeError("oops") + + routes.service.get_lora_civitai_url = failing + response = await routes.get_lora_civitai_url(DummyRequest(query={"name": "demo"})) + payload = json.loads(response.text) + assert response.status == 500 + assert payload["success"] is False + + +async def test_get_trigger_words_broadcasts(monkeypatch, routes): + send_mock = MagicMock() + PromptServer.instance = SimpleNamespace(send_sync=send_mock) + + monkeypatch.setattr("py.routes.lora_routes.get_lora_info", lambda name: (f"path/{name}", [f"trigger-{name}"])) + + request = DummyRequest(json_data={"lora_names": ["one"], "node_ids": ["node"]}) + + response = await routes.get_trigger_words(request) + payload = json.loads(response.text) + + assert payload == {"success": True} + send_mock.assert_called_once_with( + "trigger_word_update", + {"id": "node", "message": "trigger-one"}, + ) + + +async def test_get_trigger_words_error(monkeypatch, routes): + async def failing_json(): + raise RuntimeError("bad json") + + request = DummyRequest(json_data=None) + request.json = failing_json + + response = await routes.get_trigger_words(request) + payload = json.loads(response.text) + assert response.status == 500 + assert payload["success"] is False diff --git a/tests/services/test_civitai_client.py b/tests/services/test_civitai_client.py new file mode 100644 index 00000000..f5283443 --- /dev/null +++ b/tests/services/test_civitai_client.py @@ -0,0 +1,222 @@ +from unittest.mock import AsyncMock + +import pytest + +from py.services import civitai_client as civitai_client_module +from py.services.civitai_client import CivitaiClient +from py.services.model_metadata_provider import ModelMetadataProviderManager + + +class DummyDownloader: + def __init__(self): + self.download_calls = [] + self.memory_calls = [] + self.request_calls = [] + + async def download_file(self, **kwargs): + self.download_calls.append(kwargs) + return True, kwargs["save_path"] + + async def download_to_memory(self, url, use_auth=False): + self.memory_calls.append({"url": url, "use_auth": use_auth}) + return True, b"bytes", {"content-type": "image/jpeg"} + + async def make_request(self, method, url, use_auth=True): + self.request_calls.append({"method": method, "url": url, "use_auth": use_auth}) + return True, {} + + +@pytest.fixture(autouse=True) +def reset_singletons(): + CivitaiClient._instance = None + ModelMetadataProviderManager._instance = None + yield + CivitaiClient._instance = None + ModelMetadataProviderManager._instance = None + + +@pytest.fixture +def downloader(monkeypatch): + instance = DummyDownloader() + monkeypatch.setattr(civitai_client_module, "get_downloader", AsyncMock(return_value=instance)) + return instance + + +async def test_download_file_uses_downloader(tmp_path, downloader): + client = await CivitaiClient.get_instance() + save_dir = tmp_path / "files" + save_dir.mkdir() + + success, path = await client.download_file( + url="https://example.invalid/model", + save_dir=str(save_dir), + default_filename="model.safetensors", + ) + + assert success is True + assert path == str(save_dir / "model.safetensors") + assert downloader.download_calls[0]["use_auth"] is True + + +async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader): + version_payload = { + "modelId": 123, + "model": {"description": "", "tags": []}, + "creator": {}, + } + model_payload = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}} + + async def fake_make_request(method, url, use_auth=True): + if url.endswith("by-hash/hash"): + return True, version_payload.copy() + if url.endswith("/models/123"): + return True, model_payload + return False, "unexpected" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result, error = await client.get_model_by_hash("hash") + + assert error is None + assert result["model"]["description"] == "desc" + assert result["model"]["tags"] == ["tag"] + assert result["creator"] == {"username": "user"} + + +async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + return False, "not found" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result, error = await client.get_model_by_hash("missing") + + assert result is None + assert error == "Model not found" + + +async def test_download_preview_image_writes_file(tmp_path, downloader): + client = await CivitaiClient.get_instance() + target = tmp_path / "preview" / "image.jpg" + + success = await client.download_preview_image("https://example.invalid/preview", str(target)) + + assert success is True + assert target.exists() + assert target.read_bytes() == b"bytes" + + +async def test_download_preview_image_failure(monkeypatch, downloader): + async def failing_download(url, use_auth=False): + return False, b"", {} + + downloader.download_to_memory = failing_download + + client = await CivitaiClient.get_instance() + target = "/tmp/ignored.jpg" + + success = await client.download_preview_image("https://example.invalid/preview", target) + + assert success is False + + +async def test_get_model_versions_success(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + return True, {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"} + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_model_versions("123") + + assert result == {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"} + + +async def test_get_model_version_by_version_id(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + if url.endswith("/model-versions/7"): + return True, { + "modelId": 321, + "model": {"description": ""}, + "files": [], + } + if url.endswith("/models/321"): + return True, {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}} + return False, "unexpected" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_model_version(version_id=7) + + assert result["model"]["description"] == "desc" + assert result["model"]["tags"] == ["tag"] + assert result["creator"] == {"username": "user"} + + +async def test_get_model_version_requires_identifier(monkeypatch, downloader): + client = await CivitaiClient.get_instance() + result = await client.get_model_version() + assert result is None + + +async def test_get_model_version_info_handles_not_found(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + return False, "not found" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result, error = await client.get_model_version_info("55") + + assert result is None + assert error == "Model not found" + + +async def test_get_model_version_info_success(monkeypatch, downloader): + expected = {"id": 55} + + async def fake_make_request(method, url, use_auth=True): + return True, expected + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result, error = await client.get_model_version_info("55") + + assert result == expected + assert error is None + + +async def test_get_image_info_returns_first_item(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + return True, {"items": [{"id": 1}, {"id": 2}]} + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_image_info("42") + + assert result == {"id": 1} + + +async def test_get_image_info_handles_missing(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True): + return True, {"items": []} + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_image_info("42") + + assert result is None diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py new file mode 100644 index 00000000..80701e7d --- /dev/null +++ b/tests/services/test_download_manager.py @@ -0,0 +1,247 @@ +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from py.services.download_manager import DownloadManager +from py.services import download_manager +from py.services.service_registry import ServiceRegistry +from py.services.settings_manager import settings + + +@pytest.fixture(autouse=True) +def reset_download_manager(): + """Ensure each test operates on a fresh singleton.""" + DownloadManager._instance = None + yield + DownloadManager._instance = None + + +@pytest.fixture(autouse=True) +def isolate_settings(monkeypatch, tmp_path): + """Point settings writes at a temporary directory to avoid touching real files.""" + default_settings = settings._get_default_settings() + default_settings.update( + { + "default_lora_root": str(tmp_path), + "default_checkpoint_root": str(tmp_path / "checkpoints"), + "default_embedding_root": str(tmp_path / "embeddings"), + "download_path_templates": { + "lora": "{base_model}/{first_tag}", + "checkpoint": "{base_model}/{first_tag}", + "embedding": "{base_model}/{first_tag}", + }, + "base_model_path_mappings": {"BaseModel": "MappedModel"}, + } + ) + monkeypatch.setattr(settings, "settings", default_settings) + monkeypatch.setattr(type(settings), "_save_settings", lambda self: None) + + +@pytest.fixture(autouse=True) +def stub_metadata(monkeypatch): + class _StubMetadata: + def __init__(self, save_path: str): + self.file_path = save_path + self.sha256 = "sha256" + self.file_name = Path(save_path).stem + + def _factory(save_path: str): + return _StubMetadata(save_path) + + def _make_class(): + @staticmethod + def from_civitai_info(_version_info, _file_info, save_path): + return _factory(save_path) + + return type("StubMetadata", (), {"from_civitai_info": from_civitai_info}) + + stub_class = _make_class() + monkeypatch.setattr(download_manager, "LoraMetadata", stub_class) + monkeypatch.setattr(download_manager, "CheckpointMetadata", stub_class) + monkeypatch.setattr(download_manager, "EmbeddingMetadata", stub_class) + + +class DummyScanner: + def __init__(self, exists: bool = False): + self.exists = exists + self.calls = [] + + async def check_model_version_exists(self, version_id): + self.calls.append(version_id) + return self.exists + + +@pytest.fixture +def scanners(monkeypatch): + lora_scanner = DummyScanner() + checkpoint_scanner = DummyScanner() + embedding_scanner = DummyScanner() + + monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner)) + monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner)) + monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=embedding_scanner)) + + return SimpleNamespace( + lora=lora_scanner, + checkpoint=checkpoint_scanner, + embedding=embedding_scanner, + ) + + +@pytest.fixture +def metadata_provider(monkeypatch): + class DummyProvider: + def __init__(self): + self.calls = [] + + async def get_model_version(self, model_id, model_version_id): + self.calls.append((model_id, model_version_id)) + return { + "id": 42, + "model": {"type": "LoRA", "tags": ["fantasy"]}, + "baseModel": "BaseModel", + "creator": {"username": "Author"}, + "files": [ + { + "primary": True, + "downloadUrl": "https://example.invalid/file.safetensors", + "name": "file.safetensors", + } + ], + } + + provider = DummyProvider() + monkeypatch.setattr( + download_manager, + "get_default_metadata_provider", + AsyncMock(return_value=provider), + ) + return provider + + +@pytest.fixture(autouse=True) +def noop_cleanup(monkeypatch): + async def _cleanup(self, task_id): + if task_id in self._active_downloads: + self._active_downloads[task_id]["cleaned"] = True + + monkeypatch.setattr(DownloadManager, "_cleanup_download_record", _cleanup) + + +async def test_download_requires_identifier(): + manager = DownloadManager() + result = await manager.download_from_civitai() + assert result == { + "success": False, + "error": "Either model_id or model_version_id must be provided", + } + + +async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata_provider, tmp_path): + manager = DownloadManager() + + captured = {} + + async def fake_execute_download( + self, + *, + download_url, + save_dir, + metadata, + version_info, + relative_path, + progress_callback, + model_type, + download_id, + ): + captured.update( + { + "download_url": download_url, + "save_dir": Path(save_dir), + "relative_path": relative_path, + "progress_callback": progress_callback, + "model_type": model_type, + "download_id": download_id, + "metadata_path": metadata.file_path, + } + ) + return {"success": True} + + monkeypatch.setattr(DownloadManager, "_execute_download", fake_execute_download, raising=False) + + result = await manager.download_from_civitai( + model_version_id=99, + save_dir=str(tmp_path), + use_default_paths=True, + progress_callback=None, + source=None, + ) + + assert result["success"] is True + assert "download_id" in result + assert manager._download_tasks == {} + assert manager._active_downloads[result["download_id"]]["status"] == "completed" + + assert captured["relative_path"] == "MappedModel/fantasy" + expected_dir = Path(settings.get("default_lora_root")) / "MappedModel" / "fantasy" + assert captured["save_dir"] == expected_dir + assert captured["model_type"] == "lora" + + +async def test_download_aborts_when_version_exists(monkeypatch, scanners, metadata_provider): + scanners.lora.exists = True + + manager = DownloadManager() + + execute_mock = AsyncMock(return_value={"success": True}) + monkeypatch.setattr(DownloadManager, "_execute_download", execute_mock) + + result = await manager.download_from_civitai(model_version_id=101, save_dir="/tmp") + + assert result["success"] is False + assert result["error"] == "Model version already exists in lora library" + assert "download_id" in result + assert execute_mock.await_count == 0 + + +async def test_download_handles_metadata_errors(monkeypatch, scanners): + async def failing_provider(*_args, **_kwargs): + return None + + monkeypatch.setattr( + download_manager, + "get_default_metadata_provider", + AsyncMock(return_value=SimpleNamespace(get_model_version=AsyncMock(return_value=None))), + ) + + manager = DownloadManager() + + result = await manager.download_from_civitai(model_version_id=5, save_dir="/tmp") + + assert result["success"] is False + assert result["error"] == "Failed to fetch model metadata" + assert "download_id" in result + + +async def test_download_rejects_unsupported_model_type(monkeypatch, scanners): + class Provider: + async def get_model_version(self, *_args, **_kwargs): + return { + "model": {"type": "Unsupported", "tags": []}, + "files": [], + } + + monkeypatch.setattr( + download_manager, + "get_default_metadata_provider", + AsyncMock(return_value=Provider()), + ) + + manager = DownloadManager() + + result = await manager.download_from_civitai(model_version_id=5, save_dir="/tmp") + + assert result["success"] is False + assert result["error"].startswith("Model type") diff --git a/tests/services/test_settings_manager.py b/tests/services/test_settings_manager.py new file mode 100644 index 00000000..7e547680 --- /dev/null +++ b/tests/services/test_settings_manager.py @@ -0,0 +1,61 @@ +import json + +import pytest + +from py.services.settings_manager import SettingsManager + + +@pytest.fixture +def manager(tmp_path, monkeypatch): + monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None) + mgr = SettingsManager() + mgr.settings_file = str(tmp_path / "settings.json") + return mgr + + +def test_environment_variable_overrides_settings(tmp_path, monkeypatch): + monkeypatch.setattr(SettingsManager, "_save_settings", lambda self: None) + monkeypatch.setenv("CIVITAI_API_KEY", "secret") + mgr = SettingsManager() + mgr.settings_file = str(tmp_path / "settings.json") + + assert mgr.get("civitai_api_key") == "secret" + + +def test_download_path_template_parses_json_string(manager): + templates = {"lora": "{author}", "checkpoint": "{author}", "embedding": "{author}"} + manager.settings["download_path_templates"] = json.dumps(templates) + + template = manager.get_download_path_template("lora") + + assert template == "{author}" + assert isinstance(manager.settings["download_path_templates"], dict) + + +def test_download_path_template_invalid_json(manager): + manager.settings["download_path_templates"] = "not json" + + template = manager.get_download_path_template("checkpoint") + + assert template == "{base_model}/{first_tag}" + assert manager.settings["download_path_templates"]["lora"] == "{base_model}/{first_tag}" + + +def test_auto_set_default_roots(manager): + manager.settings["folder_paths"] = { + "loras": ["/loras"], + "checkpoints": ["/checkpoints"], + "embeddings": ["/embeddings"], + } + + manager._auto_set_default_roots() + + assert manager.get("default_lora_root") == "/loras" + assert manager.get("default_checkpoint_root") == "/checkpoints" + assert manager.get("default_embedding_root") == "/embeddings" + + +def test_delete_setting(manager): + manager.set("example", 1) + manager.delete("example") + assert manager.get("example") is None diff --git a/tests/services/test_websocket_manager.py b/tests/services/test_websocket_manager.py new file mode 100644 index 00000000..b85c2197 --- /dev/null +++ b/tests/services/test_websocket_manager.py @@ -0,0 +1,84 @@ +from datetime import datetime, timedelta + +import pytest + +from py.services.websocket_manager import WebSocketManager + + +class DummyWebSocket: + def __init__(self): + self.messages = [] + self.closed = False + + async def send_json(self, data): + if self.closed: + raise RuntimeError("WebSocket closed") + self.messages.append(data) + + +@pytest.fixture +def manager(): + return WebSocketManager() + + +async def test_broadcast_init_progress_adds_defaults(manager): + ws = DummyWebSocket() + manager._init_websockets.add(ws) + + await manager.broadcast_init_progress({}) + + assert ws.messages == [ + { + "stage": "processing", + "progress": 0, + "details": "Processing...", + } + ] + + +async def test_broadcast_download_progress_tracks_state(manager): + ws = DummyWebSocket() + download_id = "abc" + manager._download_websockets[download_id] = ws + + await manager.broadcast_download_progress(download_id, {"progress": 55}) + + assert ws.messages == [{"progress": 55}] + assert manager.get_download_progress(download_id)["progress"] == 55 + + +async def test_broadcast_download_progress_missing_socket(manager): + await manager.broadcast_download_progress("missing", {"progress": 30}) + # Progress should be stored even without a live websocket + assert manager.get_download_progress("missing")["progress"] == 30 + + +async def test_auto_organize_progress_helpers(manager): + payload = {"status": "processing", "progress": 10} + await manager.broadcast_auto_organize_progress(payload) + + assert manager.get_auto_organize_progress() == payload + assert manager.is_auto_organize_running() is True + + manager.cleanup_auto_organize_progress() + assert manager.get_auto_organize_progress() is None + assert manager.is_auto_organize_running() is False + + +def test_cleanup_old_downloads(manager): + now = datetime.now() + manager._download_progress = { + "recent": {"progress": 10, "timestamp": now}, + "stale": {"progress": 100, "timestamp": now - timedelta(hours=48)}, + } + + manager.cleanup_old_downloads(max_age_hours=24) + + assert "stale" not in manager._download_progress + assert "recent" in manager._download_progress + + +def test_generate_download_id(manager): + download_id = manager.generate_download_id() + assert isinstance(download_id, str) + assert download_id