refactor: enhance scanner handling and add embedding support in download manager and misc routes

This commit is contained in:
Will Miao
2025-07-25 23:59:27 +08:00
parent 381bd3938a
commit 3cf9121a8c
3 changed files with 48 additions and 22 deletions

View File

@@ -632,9 +632,10 @@ class MiscRoutes:
'error': 'Parameter modelId must be an integer' 'error': 'Parameter modelId must be an integer'
}, status=400) }, status=400)
# Get both lora and checkpoint scanners # Get all scanners
lora_scanner = await ServiceRegistry.get_lora_scanner() lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
# If modelVersionId is provided, check for specific version # If modelVersionId is provided, check for specific version
if model_version_id_str: if model_version_id_str:
@@ -646,18 +647,19 @@ class MiscRoutes:
'error': 'Parameter modelVersionId must be an integer' 'error': 'Parameter modelVersionId must be an integer'
}, status=400) }, status=400)
# Check if the specific version exists in either scanner # Check lora scanner first
exists = False exists = False
model_type = None model_type = None
# Check lora scanner first
if await lora_scanner.check_model_version_exists(model_id, model_version_id): if await lora_scanner.check_model_version_exists(model_id, model_version_id):
exists = True exists = True
model_type = 'lora' model_type = 'lora'
# If not found in lora, check checkpoint scanner
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id): elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
exists = True exists = True
model_type = 'checkpoint' model_type = 'checkpoint'
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_id, model_version_id):
exists = True
model_type = 'embedding'
return web.json_response({ return web.json_response({
'success': True, 'success': True,
@@ -667,25 +669,29 @@ class MiscRoutes:
# If modelVersionId is not provided, return all version IDs for the model # If modelVersionId is not provided, return all version IDs for the model
else: else:
# Get versions from lora scanner first
lora_versions = await lora_scanner.get_model_versions_by_id(model_id) lora_versions = await lora_scanner.get_model_versions_by_id(model_id)
checkpoint_versions = [] checkpoint_versions = []
embedding_versions = []
# Only check checkpoint scanner if no lora versions found
# 优先lora其次checkpoint最后embedding
if not lora_versions: if not lora_versions:
checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id) checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id)
if not lora_versions and not checkpoint_versions:
# Determine model type and combine results embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id)
model_type = None model_type = None
versions = [] versions = []
if lora_versions: if lora_versions:
model_type = 'lora' model_type = 'lora'
versions = lora_versions versions = lora_versions
elif checkpoint_versions: elif checkpoint_versions:
model_type = 'checkpoint' model_type = 'checkpoint'
versions = checkpoint_versions versions = checkpoint_versions
elif embedding_versions:
model_type = 'embedding'
versions = embedding_versions
return web.json_response({ return web.json_response({
'success': True, 'success': True,
'modelId': model_id, 'modelId': model_id,

View File

@@ -4,7 +4,7 @@ import asyncio
from collections import OrderedDict from collections import OrderedDict
import uuid import uuid
from typing import Dict from typing import Dict
from ..utils.models import LoraMetadata, CheckpointMetadata from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
@@ -204,6 +204,8 @@ class DownloadManager:
model_type = 'checkpoint' model_type = 'checkpoint'
elif model_type_from_info in VALID_LORA_TYPES: elif model_type_from_info in VALID_LORA_TYPES:
model_type = 'lora' model_type = 'lora'
elif model_type_from_info == 'textualinversion':
model_type = 'embedding'
else: else:
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'} return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
@@ -222,6 +224,11 @@ class DownloadManager:
checkpoint_scanner = await self._get_checkpoint_scanner() checkpoint_scanner = await self._get_checkpoint_scanner()
if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id): if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'} return {'success': False, 'error': 'Model version already exists in checkpoint library'}
elif model_type == 'embedding':
# Embeddings are not checked in scanners, but we can still check if it exists
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
if await embedding_scanner.check_model_version_exists(version_model_id, version_id):
return {'success': False, 'error': 'Model version already exists in embedding library'}
# Handle use_default_paths # Handle use_default_paths
if use_default_paths: if use_default_paths:
@@ -231,11 +238,16 @@ class DownloadManager:
if not default_path: if not default_path:
return {'success': False, 'error': 'Default checkpoint root path not set in settings'} return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
save_dir = default_path save_dir = default_path
else: # model_type == 'lora' elif model_type == 'lora':
default_path = settings.get('default_lora_root') default_path = settings.get('default_lora_root')
if not default_path: if not default_path:
return {'success': False, 'error': 'Default lora root path not set in settings'} return {'success': False, 'error': 'Default lora root path not set in settings'}
save_dir = default_path save_dir = default_path
elif model_type == 'embedding':
default_path = settings.get('default_embedding_root')
if not default_path:
return {'success': False, 'error': 'Default embedding root path not set in settings'}
save_dir = default_path
# Calculate relative path using template # Calculate relative path using template
relative_path = self._calculate_relative_path(version_info) relative_path = self._calculate_relative_path(version_info)
@@ -282,9 +294,12 @@ class DownloadManager:
if model_type == "checkpoint": if model_type == "checkpoint":
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path) metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating CheckpointMetadata for {file_name}") logger.info(f"Creating CheckpointMetadata for {file_name}")
else: elif model_type == "lora":
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating LoraMetadata for {file_name}") logger.info(f"Creating LoraMetadata for {file_name}")
elif model_type == "embedding":
metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating EmbeddingMetadata for {file_name}")
# 6. Start download process # 6. Start download process
result = await self._execute_download( result = await self._execute_download(
@@ -447,9 +462,12 @@ class DownloadManager:
if model_type == "checkpoint": if model_type == "checkpoint":
scanner = await self._get_checkpoint_scanner() scanner = await self._get_checkpoint_scanner()
logger.info(f"Updating checkpoint cache for {save_path}") logger.info(f"Updating checkpoint cache for {save_path}")
else: elif model_type == "lora":
scanner = await self._get_lora_scanner() scanner = await self._get_lora_scanner()
logger.info(f"Updating lora cache for {save_path}") logger.info(f"Updating lora cache for {save_path}")
elif model_type == "embedding":
scanner = await ServiceRegistry.get_embedding_scanner()
logger.info(f"Updating embedding cache for {save_path}")
# Convert metadata to dictionary # Convert metadata to dictionary
metadata_dict = metadata.to_dict() metadata_dict = metadata.to_dict()

View File

@@ -1,11 +1,13 @@
<div class="controls"> <div class="controls">
<div class="folder-tags-container"> {% if folders|length > 1 %}
<div class="folder-tags"> <div class="folder-tags-container">
{% for folder in folders %} <div class="folder-tags">
<div class="tag" data-folder="{{ folder }}">{{ folder }}</div> {% for folder in folders %}
{% endfor %} <div class="tag" data-folder="{{ folder }}">{{ folder }}</div>
{% endfor %}
</div>
</div> </div>
</div> {% endif %}
<div class="actions"> <div class="actions">
<div class="action-buttons"> <div class="action-buttons">