From 3cf9121a8ce06462668e3966278dc89b85f648c3 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 25 Jul 2025 23:59:27 +0800 Subject: [PATCH] refactor: enhance scanner handling and add embedding support in download manager and misc routes --- py/routes/misc_routes.py | 30 ++++++++++++++++++------------ py/services/download_manager.py | 26 ++++++++++++++++++++++---- templates/components/controls.html | 14 ++++++++------ 3 files changed, 48 insertions(+), 22 deletions(-) diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 596e0323..4fdd9485 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -632,9 +632,10 @@ class MiscRoutes: 'error': 'Parameter modelId must be an integer' }, status=400) - # Get both lora and checkpoint scanners + # Get all scanners lora_scanner = await ServiceRegistry.get_lora_scanner() checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() + embedding_scanner = await ServiceRegistry.get_embedding_scanner() # If modelVersionId is provided, check for specific version if model_version_id_str: @@ -646,18 +647,19 @@ class MiscRoutes: 'error': 'Parameter modelVersionId must be an integer' }, status=400) - # Check if the specific version exists in either scanner + # Check lora scanner first exists = False model_type = None - - # Check lora scanner first + if await lora_scanner.check_model_version_exists(model_id, model_version_id): exists = True model_type = 'lora' - # If not found in lora, check checkpoint scanner elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id): exists = True model_type = 'checkpoint' + elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_id, model_version_id): + exists = True + model_type = 'embedding' return web.json_response({ 'success': True, @@ -667,25 +669,29 @@ class MiscRoutes: # If modelVersionId is not provided, return all version IDs for the model else: - # Get versions from lora scanner first lora_versions = await lora_scanner.get_model_versions_by_id(model_id) checkpoint_versions = [] - - # Only check checkpoint scanner if no lora versions found + embedding_versions = [] + + # 优先lora,其次checkpoint,最后embedding if not lora_versions: checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id) - - # Determine model type and combine results + if not lora_versions and not checkpoint_versions: + embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id) + model_type = None versions = [] - + if lora_versions: model_type = 'lora' versions = lora_versions elif checkpoint_versions: model_type = 'checkpoint' versions = checkpoint_versions - + elif embedding_versions: + model_type = 'embedding' + versions = embedding_versions + return web.json_response({ 'success': True, 'modelId': model_id, diff --git a/py/services/download_manager.py b/py/services/download_manager.py index e98e5498..aacdc362 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -4,7 +4,7 @@ import asyncio from collections import OrderedDict import uuid from typing import Dict -from ..utils.models import LoraMetadata, CheckpointMetadata +from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager @@ -204,6 +204,8 @@ class DownloadManager: model_type = 'checkpoint' elif model_type_from_info in VALID_LORA_TYPES: model_type = 'lora' + elif model_type_from_info == 'textualinversion': + model_type = 'embedding' else: return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'} @@ -222,6 +224,11 @@ class DownloadManager: checkpoint_scanner = await self._get_checkpoint_scanner() if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id): return {'success': False, 'error': 'Model version already exists in checkpoint library'} + elif model_type == 'embedding': + # Embeddings are not checked in scanners, but we can still check if it exists + embedding_scanner = await ServiceRegistry.get_embedding_scanner() + if await embedding_scanner.check_model_version_exists(version_model_id, version_id): + return {'success': False, 'error': 'Model version already exists in embedding library'} # Handle use_default_paths if use_default_paths: @@ -231,11 +238,16 @@ class DownloadManager: if not default_path: return {'success': False, 'error': 'Default checkpoint root path not set in settings'} save_dir = default_path - else: # model_type == 'lora' + elif model_type == 'lora': default_path = settings.get('default_lora_root') if not default_path: return {'success': False, 'error': 'Default lora root path not set in settings'} save_dir = default_path + elif model_type == 'embedding': + default_path = settings.get('default_embedding_root') + if not default_path: + return {'success': False, 'error': 'Default embedding root path not set in settings'} + save_dir = default_path # Calculate relative path using template relative_path = self._calculate_relative_path(version_info) @@ -282,9 +294,12 @@ class DownloadManager: if model_type == "checkpoint": metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path) logger.info(f"Creating CheckpointMetadata for {file_name}") - else: + elif model_type == "lora": metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) logger.info(f"Creating LoraMetadata for {file_name}") + elif model_type == "embedding": + metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path) + logger.info(f"Creating EmbeddingMetadata for {file_name}") # 6. Start download process result = await self._execute_download( @@ -447,9 +462,12 @@ class DownloadManager: if model_type == "checkpoint": scanner = await self._get_checkpoint_scanner() logger.info(f"Updating checkpoint cache for {save_path}") - else: + elif model_type == "lora": scanner = await self._get_lora_scanner() logger.info(f"Updating lora cache for {save_path}") + elif model_type == "embedding": + scanner = await ServiceRegistry.get_embedding_scanner() + logger.info(f"Updating embedding cache for {save_path}") # Convert metadata to dictionary metadata_dict = metadata.to_dict() diff --git a/templates/components/controls.html b/templates/components/controls.html index 3c36b1bd..68f03db1 100644 --- a/templates/components/controls.html +++ b/templates/components/controls.html @@ -1,11 +1,13 @@
-
-
- {% for folder in folders %} -
{{ folder }}
- {% endfor %} + {% if folders|length > 1 %} +
+
+ {% for folder in folders %} +
{{ folder }}
+ {% endfor %} +
-
+ {% endif %}