mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat: add CivArchive metadata provider and support for optional source parameter in downloads
This commit is contained in:
@@ -36,17 +36,10 @@ class DownloadManager:
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
self._civitai_client = None # Will be lazily initialized
|
||||
# Add download management
|
||||
self._active_downloads = OrderedDict() # download_id -> download_info
|
||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||
self._download_tasks = {} # download_id -> asyncio.Task
|
||||
|
||||
async def _get_civitai_client(self):
|
||||
"""Lazily initialize CivitaiClient from registry"""
|
||||
if self._civitai_client is None:
|
||||
self._civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
return self._civitai_client
|
||||
|
||||
async def _get_lora_scanner(self):
|
||||
"""Get the lora scanner from registry"""
|
||||
@@ -59,7 +52,7 @@ class DownloadManager:
|
||||
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
|
||||
save_dir: str = None, relative_path: str = '',
|
||||
progress_callback=None, use_default_paths: bool = False,
|
||||
download_id: str = None) -> Dict:
|
||||
download_id: str = None, source: str = None) -> Dict:
|
||||
"""Download model from Civitai with task tracking and concurrency control
|
||||
|
||||
Args:
|
||||
@@ -70,6 +63,7 @@ class DownloadManager:
|
||||
progress_callback: Callback function for progress updates
|
||||
use_default_paths: Flag to use default paths
|
||||
download_id: Unique identifier for this download task
|
||||
source: Optional source parameter to specify metadata provider
|
||||
|
||||
Returns:
|
||||
Dict with download result
|
||||
@@ -93,7 +87,7 @@ class DownloadManager:
|
||||
download_task = asyncio.create_task(
|
||||
self._download_with_semaphore(
|
||||
task_id, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths
|
||||
relative_path, progress_callback, use_default_paths, source
|
||||
)
|
||||
)
|
||||
|
||||
@@ -114,7 +108,8 @@ class DownloadManager:
|
||||
|
||||
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
|
||||
save_dir: str, relative_path: str,
|
||||
progress_callback=None, use_default_paths: bool = False):
|
||||
progress_callback=None, use_default_paths: bool = False,
|
||||
source: str = None):
|
||||
"""Execute download with semaphore to limit concurrency"""
|
||||
# Update status to waiting
|
||||
if task_id in self._active_downloads:
|
||||
@@ -144,7 +139,7 @@ class DownloadManager:
|
||||
result = await self._execute_original_download(
|
||||
model_id, model_version_id, save_dir,
|
||||
relative_path, tracking_callback, use_default_paths,
|
||||
task_id
|
||||
task_id, source
|
||||
)
|
||||
|
||||
# Update status based on result
|
||||
@@ -179,7 +174,7 @@ class DownloadManager:
|
||||
|
||||
async def _execute_original_download(self, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths,
|
||||
download_id=None):
|
||||
download_id=None, source=None):
|
||||
"""Wrapper for original download_from_civitai implementation"""
|
||||
try:
|
||||
# Check if model version already exists in library
|
||||
@@ -201,8 +196,12 @@ class DownloadManager:
|
||||
if await embedding_scanner.check_model_version_exists(model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
||||
|
||||
# Get metadata provider instead of civitai client directly
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
# Get metadata provider based on source parameter
|
||||
if source == 'civarchive':
|
||||
from .metadata_service import get_metadata_provider
|
||||
metadata_provider = await get_metadata_provider('civarchive')
|
||||
else:
|
||||
metadata_provider = await get_default_metadata_provider()
|
||||
|
||||
# Get version info based on the provided identifier
|
||||
version_info = await metadata_provider.get_model_version(model_id, model_version_id)
|
||||
@@ -396,8 +395,6 @@ class DownloadManager:
|
||||
model_type: str = "lora", download_id: str = None) -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
try:
|
||||
civitai_client = await self._get_civitai_client()
|
||||
|
||||
# Extract original filename details
|
||||
original_filename = os.path.basename(metadata.file_path)
|
||||
base_name, extension = os.path.splitext(original_filename)
|
||||
@@ -504,11 +501,13 @@ class DownloadManager:
|
||||
|
||||
# Download model file with progress tracking using downloader
|
||||
downloader = await get_downloader()
|
||||
# Determine if the download URL is from Civitai
|
||||
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
||||
success, result = await downloader.download_file(
|
||||
download_url,
|
||||
save_path, # Use full path instead of separate dir and filename
|
||||
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
|
||||
use_auth=True # Model downloads need authentication
|
||||
use_auth=use_auth # Only use authentication for Civitai downloads
|
||||
)
|
||||
|
||||
if not success:
|
||||
|
||||
Reference in New Issue
Block a user