feat: add CivArchive metadata provider and support for optional source parameter in downloads

This commit is contained in:
Will Miao
2025-09-12 21:13:15 +08:00
parent 00b77581fc
commit d05076d258
6 changed files with 173 additions and 19 deletions

View File

@@ -36,17 +36,10 @@ class DownloadManager:
return
self._initialized = True
self._civitai_client = None # Will be lazily initialized
# Add download management
self._active_downloads = OrderedDict() # download_id -> download_info
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def _get_lora_scanner(self):
"""Get the lora scanner from registry"""
@@ -59,7 +52,7 @@ class DownloadManager:
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
save_dir: str = None, relative_path: str = '',
progress_callback=None, use_default_paths: bool = False,
download_id: str = None) -> Dict:
download_id: str = None, source: str = None) -> Dict:
"""Download model from Civitai with task tracking and concurrency control
Args:
@@ -70,6 +63,7 @@ class DownloadManager:
progress_callback: Callback function for progress updates
use_default_paths: Flag to use default paths
download_id: Unique identifier for this download task
source: Optional source parameter to specify metadata provider
Returns:
Dict with download result
@@ -93,7 +87,7 @@ class DownloadManager:
download_task = asyncio.create_task(
self._download_with_semaphore(
task_id, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths
relative_path, progress_callback, use_default_paths, source
)
)
@@ -114,7 +108,8 @@ class DownloadManager:
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
save_dir: str, relative_path: str,
progress_callback=None, use_default_paths: bool = False):
progress_callback=None, use_default_paths: bool = False,
source: str = None):
"""Execute download with semaphore to limit concurrency"""
# Update status to waiting
if task_id in self._active_downloads:
@@ -144,7 +139,7 @@ class DownloadManager:
result = await self._execute_original_download(
model_id, model_version_id, save_dir,
relative_path, tracking_callback, use_default_paths,
task_id
task_id, source
)
# Update status based on result
@@ -179,7 +174,7 @@ class DownloadManager:
async def _execute_original_download(self, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths,
download_id=None):
download_id=None, source=None):
"""Wrapper for original download_from_civitai implementation"""
try:
# Check if model version already exists in library
@@ -201,8 +196,12 @@ class DownloadManager:
if await embedding_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in embedding library'}
# Get metadata provider instead of civitai client directly
metadata_provider = await get_default_metadata_provider()
# Get metadata provider based on source parameter
if source == 'civarchive':
from .metadata_service import get_metadata_provider
metadata_provider = await get_metadata_provider('civarchive')
else:
metadata_provider = await get_default_metadata_provider()
# Get version info based on the provided identifier
version_info = await metadata_provider.get_model_version(model_id, model_version_id)
@@ -396,8 +395,6 @@ class DownloadManager:
model_type: str = "lora", download_id: str = None) -> Dict:
"""Execute the actual download process including preview images and model files"""
try:
civitai_client = await self._get_civitai_client()
# Extract original filename details
original_filename = os.path.basename(metadata.file_path)
base_name, extension = os.path.splitext(original_filename)
@@ -504,11 +501,13 @@ class DownloadManager:
# Download model file with progress tracking using downloader
downloader = await get_downloader()
# Determine if the download URL is from Civitai
use_auth = download_url.startswith("https://civitai.com/api/download/")
success, result = await downloader.download_file(
download_url,
save_path, # Use full path instead of separate dir and filename
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
use_auth=True # Model downloads need authentication
use_auth=use_auth # Only use authentication for Civitai downloads
)
if not success: