refactor: Consolidate preview file extensions into constants for improved maintainability

This commit is contained in:
Will Miao
2025-04-11 06:19:15 +08:00
parent 86810d9f03
commit 7393e92b21
4 changed files with 30 additions and 39 deletions

View File

@@ -17,6 +17,7 @@ from ..services.settings_manager import settings
import asyncio
from .update_routes import UpdateRoutes
from ..services.recipe_scanner import RecipeScanner
from ..utils.constants import PREVIEW_EXTENSIONS
logger = logging.getLogger(__name__)
@@ -244,18 +245,12 @@ class ApiRoutes:
patterns = [
f"{file_name}.safetensors", # Required
f"{file_name}.metadata.json",
f"{file_name}.preview.png",
f"{file_name}.preview.jpg",
f"{file_name}.preview.jpeg",
f"{file_name}.preview.webp",
f"{file_name}.preview.mp4",
f"{file_name}.png",
f"{file_name}.jpg",
f"{file_name}.jpeg",
f"{file_name}.webp",
f"{file_name}.mp4"
]
# 添加所有预览文件扩展名
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{file_name}{ext}")
deleted = []
main_file = patterns[0]
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
@@ -1054,18 +1049,12 @@ class ApiRoutes:
patterns = [
f"{old_file_name}.safetensors", # Required
f"{old_file_name}.metadata.json",
f"{old_file_name}.preview.png",
f"{old_file_name}.preview.jpg",
f"{old_file_name}.preview.jpeg",
f"{old_file_name}.preview.webp",
f"{old_file_name}.preview.mp4",
f"{old_file_name}.png",
f"{old_file_name}.jpg",
f"{old_file_name}.jpeg",
f"{old_file_name}.webp",
f"{old_file_name}.mp4"
]
# 添加所有预览文件扩展名
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{old_file_name}{ext}")
# Find all matching files
existing_files = []
for pattern in patterns:

View File

@@ -11,6 +11,7 @@ from ..config import config
from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, save_metadata
from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS
logger = logging.getLogger(__name__)
@@ -384,9 +385,7 @@ class ModelScanner:
shutil.move(source_metadata, target_metadata)
metadata = await self._update_metadata_paths(target_metadata, target_file)
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions:
for ext in PREVIEW_EXTENSIONS:
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
if os.path.exists(source_preview):
target_preview = os.path.join(target_path, f"{base_name}{ext}")
@@ -491,10 +490,8 @@ class ModelScanner:
return None
base_name = os.path.splitext(file_path)[0]
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions:
for ext in PREVIEW_EXTENSIONS:
preview_path = f"{base_name}{ext}"
if os.path.exists(preview_path):
return config.get_preview_static_url(preview_path)

View File

@@ -5,4 +5,18 @@ NSFW_LEVELS = {
"X": 8,
"XXX": 16,
"Blocked": 32, # Probably not actually visible through the API without being logged in on model owner account?
}
}
# 预览文件扩展名
PREVIEW_EXTENSIONS = [
'.preview.png',
'.preview.jpeg',
'.preview.jpg',
'.preview.webp',
'.preview.mp4',
'.png',
'.jpeg',
'.jpg',
'.webp',
'.mp4'
]

View File

@@ -8,6 +8,7 @@ from typing import Dict, Optional, Type
from .model_utils import determine_base_model
from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata
from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata
from .constants import PREVIEW_EXTENSIONS
logger = logging.getLogger(__name__)
@@ -21,19 +22,9 @@ async def calculate_sha256(file_path: str) -> str:
def find_preview_file(base_name: str, dir_path: str) -> str:
"""Find preview file for given base name in directory"""
preview_patterns = [
f"{base_name}.preview.png",
f"{base_name}.preview.jpg",
f"{base_name}.preview.jpeg",
f"{base_name}.preview.mp4",
f"{base_name}.png",
f"{base_name}.jpg",
f"{base_name}.jpeg",
f"{base_name}.mp4"
]
for pattern in preview_patterns:
full_pattern = os.path.join(dir_path, pattern)
for ext in PREVIEW_EXTENSIONS:
full_pattern = os.path.join(dir_path, f"{base_name}{ext}")
if os.path.exists(full_pattern):
return full_pattern.replace(os.sep, "/")
return ""