refactor: Consolidate preview file extensions into constants for improved maintainability

This commit is contained in:
Will Miao
2025-04-11 06:19:15 +08:00
parent 86810d9f03
commit 7393e92b21
4 changed files with 30 additions and 39 deletions

View File

@@ -17,6 +17,7 @@ from ..services.settings_manager import settings
import asyncio import asyncio
from .update_routes import UpdateRoutes from .update_routes import UpdateRoutes
from ..services.recipe_scanner import RecipeScanner from ..services.recipe_scanner import RecipeScanner
from ..utils.constants import PREVIEW_EXTENSIONS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -244,18 +245,12 @@ class ApiRoutes:
patterns = [ patterns = [
f"{file_name}.safetensors", # Required f"{file_name}.safetensors", # Required
f"{file_name}.metadata.json", f"{file_name}.metadata.json",
f"{file_name}.preview.png",
f"{file_name}.preview.jpg",
f"{file_name}.preview.jpeg",
f"{file_name}.preview.webp",
f"{file_name}.preview.mp4",
f"{file_name}.png",
f"{file_name}.jpg",
f"{file_name}.jpeg",
f"{file_name}.webp",
f"{file_name}.mp4"
] ]
# 添加所有预览文件扩展名
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{file_name}{ext}")
deleted = [] deleted = []
main_file = patterns[0] main_file = patterns[0]
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/') main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
@@ -1054,18 +1049,12 @@ class ApiRoutes:
patterns = [ patterns = [
f"{old_file_name}.safetensors", # Required f"{old_file_name}.safetensors", # Required
f"{old_file_name}.metadata.json", f"{old_file_name}.metadata.json",
f"{old_file_name}.preview.png",
f"{old_file_name}.preview.jpg",
f"{old_file_name}.preview.jpeg",
f"{old_file_name}.preview.webp",
f"{old_file_name}.preview.mp4",
f"{old_file_name}.png",
f"{old_file_name}.jpg",
f"{old_file_name}.jpeg",
f"{old_file_name}.webp",
f"{old_file_name}.mp4"
] ]
# 添加所有预览文件扩展名
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{old_file_name}{ext}")
# Find all matching files # Find all matching files
existing_files = [] existing_files = []
for pattern in patterns: for pattern in patterns:

View File

@@ -11,6 +11,7 @@ from ..config import config
from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, save_metadata from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, save_metadata
from .model_cache import ModelCache from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -384,9 +385,7 @@ class ModelScanner:
shutil.move(source_metadata, target_metadata) shutil.move(source_metadata, target_metadata)
metadata = await self._update_metadata_paths(target_metadata, target_file) metadata = await self._update_metadata_paths(target_metadata, target_file)
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4', for ext in PREVIEW_EXTENSIONS:
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions:
source_preview = os.path.join(source_dir, f"{base_name}{ext}") source_preview = os.path.join(source_dir, f"{base_name}{ext}")
if os.path.exists(source_preview): if os.path.exists(source_preview):
target_preview = os.path.join(target_path, f"{base_name}{ext}") target_preview = os.path.join(target_path, f"{base_name}{ext}")
@@ -491,10 +490,8 @@ class ModelScanner:
return None return None
base_name = os.path.splitext(file_path)[0] base_name = os.path.splitext(file_path)[0]
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions: for ext in PREVIEW_EXTENSIONS:
preview_path = f"{base_name}{ext}" preview_path = f"{base_name}{ext}"
if os.path.exists(preview_path): if os.path.exists(preview_path):
return config.get_preview_static_url(preview_path) return config.get_preview_static_url(preview_path)

View File

@@ -5,4 +5,18 @@ NSFW_LEVELS = {
"X": 8, "X": 8,
"XXX": 16, "XXX": 16,
"Blocked": 32, # Probably not actually visible through the API without being logged in on model owner account? "Blocked": 32, # Probably not actually visible through the API without being logged in on model owner account?
} }
# 预览文件扩展名
PREVIEW_EXTENSIONS = [
'.preview.png',
'.preview.jpeg',
'.preview.jpg',
'.preview.webp',
'.preview.mp4',
'.png',
'.jpeg',
'.jpg',
'.webp',
'.mp4'
]

View File

@@ -8,6 +8,7 @@ from typing import Dict, Optional, Type
from .model_utils import determine_base_model from .model_utils import determine_base_model
from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata
from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata
from .constants import PREVIEW_EXTENSIONS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -21,19 +22,9 @@ async def calculate_sha256(file_path: str) -> str:
def find_preview_file(base_name: str, dir_path: str) -> str: def find_preview_file(base_name: str, dir_path: str) -> str:
"""Find preview file for given base name in directory""" """Find preview file for given base name in directory"""
preview_patterns = [
f"{base_name}.preview.png",
f"{base_name}.preview.jpg",
f"{base_name}.preview.jpeg",
f"{base_name}.preview.mp4",
f"{base_name}.png",
f"{base_name}.jpg",
f"{base_name}.jpeg",
f"{base_name}.mp4"
]
for pattern in preview_patterns: for ext in PREVIEW_EXTENSIONS:
full_pattern = os.path.join(dir_path, pattern) full_pattern = os.path.join(dir_path, f"{base_name}{ext}")
if os.path.exists(full_pattern): if os.path.exists(full_pattern):
return full_pattern.replace(os.sep, "/") return full_pattern.replace(os.sep, "/")
return "" return ""