mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: add CivArchive metadata provider and support for optional source parameter in downloads
This commit is contained in:
@@ -523,6 +523,7 @@ class BaseModelRoutes(ABC):
|
|||||||
model_version_id = request.query.get('model_version_id')
|
model_version_id = request.query.get('model_version_id')
|
||||||
download_id = request.query.get('download_id')
|
download_id = request.query.get('download_id')
|
||||||
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
|
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
|
||||||
|
source = request.query.get('source') # Optional source parameter
|
||||||
|
|
||||||
# Create a data dictionary that mimics what would be received from a POST request
|
# Create a data dictionary that mimics what would be received from a POST request
|
||||||
data = {
|
data = {
|
||||||
@@ -538,6 +539,10 @@ class BaseModelRoutes(ABC):
|
|||||||
|
|
||||||
data['use_default_paths'] = use_default_paths
|
data['use_default_paths'] = use_default_paths
|
||||||
|
|
||||||
|
# Add source parameter if provided
|
||||||
|
if source:
|
||||||
|
data['source'] = source
|
||||||
|
|
||||||
# Create a mock request object with the data
|
# Create a mock request object with the data
|
||||||
future = asyncio.get_event_loop().create_future()
|
future = asyncio.get_event_loop().create_future()
|
||||||
future.set_result(data)
|
future.set_result(data)
|
||||||
|
|||||||
@@ -36,17 +36,10 @@ class DownloadManager:
|
|||||||
return
|
return
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
self._civitai_client = None # Will be lazily initialized
|
|
||||||
# Add download management
|
# Add download management
|
||||||
self._active_downloads = OrderedDict() # download_id -> download_info
|
self._active_downloads = OrderedDict() # download_id -> download_info
|
||||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||||
self._download_tasks = {} # download_id -> asyncio.Task
|
self._download_tasks = {} # download_id -> asyncio.Task
|
||||||
|
|
||||||
async def _get_civitai_client(self):
|
|
||||||
"""Lazily initialize CivitaiClient from registry"""
|
|
||||||
if self._civitai_client is None:
|
|
||||||
self._civitai_client = await ServiceRegistry.get_civitai_client()
|
|
||||||
return self._civitai_client
|
|
||||||
|
|
||||||
async def _get_lora_scanner(self):
|
async def _get_lora_scanner(self):
|
||||||
"""Get the lora scanner from registry"""
|
"""Get the lora scanner from registry"""
|
||||||
@@ -59,7 +52,7 @@ class DownloadManager:
|
|||||||
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
|
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
|
||||||
save_dir: str = None, relative_path: str = '',
|
save_dir: str = None, relative_path: str = '',
|
||||||
progress_callback=None, use_default_paths: bool = False,
|
progress_callback=None, use_default_paths: bool = False,
|
||||||
download_id: str = None) -> Dict:
|
download_id: str = None, source: str = None) -> Dict:
|
||||||
"""Download model from Civitai with task tracking and concurrency control
|
"""Download model from Civitai with task tracking and concurrency control
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -70,6 +63,7 @@ class DownloadManager:
|
|||||||
progress_callback: Callback function for progress updates
|
progress_callback: Callback function for progress updates
|
||||||
use_default_paths: Flag to use default paths
|
use_default_paths: Flag to use default paths
|
||||||
download_id: Unique identifier for this download task
|
download_id: Unique identifier for this download task
|
||||||
|
source: Optional source parameter to specify metadata provider
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with download result
|
Dict with download result
|
||||||
@@ -93,7 +87,7 @@ class DownloadManager:
|
|||||||
download_task = asyncio.create_task(
|
download_task = asyncio.create_task(
|
||||||
self._download_with_semaphore(
|
self._download_with_semaphore(
|
||||||
task_id, model_id, model_version_id, save_dir,
|
task_id, model_id, model_version_id, save_dir,
|
||||||
relative_path, progress_callback, use_default_paths
|
relative_path, progress_callback, use_default_paths, source
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -114,7 +108,8 @@ class DownloadManager:
|
|||||||
|
|
||||||
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
|
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
|
||||||
save_dir: str, relative_path: str,
|
save_dir: str, relative_path: str,
|
||||||
progress_callback=None, use_default_paths: bool = False):
|
progress_callback=None, use_default_paths: bool = False,
|
||||||
|
source: str = None):
|
||||||
"""Execute download with semaphore to limit concurrency"""
|
"""Execute download with semaphore to limit concurrency"""
|
||||||
# Update status to waiting
|
# Update status to waiting
|
||||||
if task_id in self._active_downloads:
|
if task_id in self._active_downloads:
|
||||||
@@ -144,7 +139,7 @@ class DownloadManager:
|
|||||||
result = await self._execute_original_download(
|
result = await self._execute_original_download(
|
||||||
model_id, model_version_id, save_dir,
|
model_id, model_version_id, save_dir,
|
||||||
relative_path, tracking_callback, use_default_paths,
|
relative_path, tracking_callback, use_default_paths,
|
||||||
task_id
|
task_id, source
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update status based on result
|
# Update status based on result
|
||||||
@@ -179,7 +174,7 @@ class DownloadManager:
|
|||||||
|
|
||||||
async def _execute_original_download(self, model_id, model_version_id, save_dir,
|
async def _execute_original_download(self, model_id, model_version_id, save_dir,
|
||||||
relative_path, progress_callback, use_default_paths,
|
relative_path, progress_callback, use_default_paths,
|
||||||
download_id=None):
|
download_id=None, source=None):
|
||||||
"""Wrapper for original download_from_civitai implementation"""
|
"""Wrapper for original download_from_civitai implementation"""
|
||||||
try:
|
try:
|
||||||
# Check if model version already exists in library
|
# Check if model version already exists in library
|
||||||
@@ -201,8 +196,12 @@ class DownloadManager:
|
|||||||
if await embedding_scanner.check_model_version_exists(model_version_id):
|
if await embedding_scanner.check_model_version_exists(model_version_id):
|
||||||
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
||||||
|
|
||||||
# Get metadata provider instead of civitai client directly
|
# Get metadata provider based on source parameter
|
||||||
metadata_provider = await get_default_metadata_provider()
|
if source == 'civarchive':
|
||||||
|
from .metadata_service import get_metadata_provider
|
||||||
|
metadata_provider = await get_metadata_provider('civarchive')
|
||||||
|
else:
|
||||||
|
metadata_provider = await get_default_metadata_provider()
|
||||||
|
|
||||||
# Get version info based on the provided identifier
|
# Get version info based on the provided identifier
|
||||||
version_info = await metadata_provider.get_model_version(model_id, model_version_id)
|
version_info = await metadata_provider.get_model_version(model_id, model_version_id)
|
||||||
@@ -396,8 +395,6 @@ class DownloadManager:
|
|||||||
model_type: str = "lora", download_id: str = None) -> Dict:
|
model_type: str = "lora", download_id: str = None) -> Dict:
|
||||||
"""Execute the actual download process including preview images and model files"""
|
"""Execute the actual download process including preview images and model files"""
|
||||||
try:
|
try:
|
||||||
civitai_client = await self._get_civitai_client()
|
|
||||||
|
|
||||||
# Extract original filename details
|
# Extract original filename details
|
||||||
original_filename = os.path.basename(metadata.file_path)
|
original_filename = os.path.basename(metadata.file_path)
|
||||||
base_name, extension = os.path.splitext(original_filename)
|
base_name, extension = os.path.splitext(original_filename)
|
||||||
@@ -504,11 +501,13 @@ class DownloadManager:
|
|||||||
|
|
||||||
# Download model file with progress tracking using downloader
|
# Download model file with progress tracking using downloader
|
||||||
downloader = await get_downloader()
|
downloader = await get_downloader()
|
||||||
|
# Determine if the download URL is from Civitai
|
||||||
|
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
||||||
success, result = await downloader.download_file(
|
success, result = await downloader.download_file(
|
||||||
download_url,
|
download_url,
|
||||||
save_path, # Use full path instead of separate dir and filename
|
save_path, # Use full path instead of separate dir and filename
|
||||||
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
|
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
|
||||||
use_auth=True # Model downloads need authentication
|
use_auth=use_auth # Only use authentication for Civitai downloads
|
||||||
)
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
|
|||||||
@@ -52,7 +52,16 @@ async def initialize_metadata_providers():
|
|||||||
logger.debug("Civitai API metadata provider registered")
|
logger.debug("Civitai API metadata provider registered")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize Civitai API metadata provider: {e}")
|
logger.error(f"Failed to initialize Civitai API metadata provider: {e}")
|
||||||
|
|
||||||
|
# Register CivArchive provider, but do NOT add to fallback providers
|
||||||
|
try:
|
||||||
|
from .model_metadata_provider import CivArchiveModelMetadataProvider
|
||||||
|
civarchive_provider = CivArchiveModelMetadataProvider()
|
||||||
|
provider_manager.register_provider('civarchive', civarchive_provider)
|
||||||
|
logger.debug("CivArchive metadata provider registered (not included in fallback)")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize CivArchive metadata provider: {e}")
|
||||||
|
|
||||||
# Set up fallback provider based on available providers
|
# Set up fallback provider based on available providers
|
||||||
if len(providers) > 1:
|
if len(providers) > 1:
|
||||||
# Always use Civitai API first, then Archive DB
|
# Always use Civitai API first, then Archive DB
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
|
|||||||
import json
|
import json
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
import logging
|
import logging
|
||||||
|
import aiohttp
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
from typing import Optional, Dict, List, Tuple, Any
|
from typing import Optional, Dict, List, Tuple, Any
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -55,6 +57,142 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
|||||||
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||||
return await self.client.get_model_metadata(model_id)
|
return await self.client.get_model_metadata(model_id)
|
||||||
|
|
||||||
|
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||||
|
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
||||||
|
|
||||||
|
def __init__(self, session: aiohttp.ClientSession = None):
|
||||||
|
self.session = session
|
||||||
|
self._own_session = session is None
|
||||||
|
|
||||||
|
async def _get_session(self):
|
||||||
|
"""Get or create HTTP session"""
|
||||||
|
if self.session is None:
|
||||||
|
self.session = aiohttp.ClientSession()
|
||||||
|
return self.session
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close HTTP session if we own it"""
|
||||||
|
if self._own_session and self.session:
|
||||||
|
await self.session.close()
|
||||||
|
self.session = None
|
||||||
|
|
||||||
|
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||||
|
"""Not supported by CivArchive provider"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||||
|
"""Not supported by CivArchive provider"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||||
|
"""Get specific model version by parsing CivArchive HTML page"""
|
||||||
|
if model_id is None or version_id is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Construct CivArchive URL
|
||||||
|
url = f"https://civarchive.com/models/{model_id}?modelVersionId={version_id}"
|
||||||
|
|
||||||
|
session = await self._get_session()
|
||||||
|
async with session.get(url) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return None
|
||||||
|
|
||||||
|
html_content = await response.text()
|
||||||
|
|
||||||
|
# Parse HTML to extract JSON data
|
||||||
|
soup = BeautifulSoup(html_content, 'html.parser')
|
||||||
|
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
|
||||||
|
|
||||||
|
if not script_tag:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Parse JSON content
|
||||||
|
json_data = json.loads(script_tag.string)
|
||||||
|
model_data = json_data.get('props', {}).get('pageProps', {}).get('model')
|
||||||
|
|
||||||
|
if not model_data or 'version' not in model_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract version data as base
|
||||||
|
version = model_data['version'].copy()
|
||||||
|
|
||||||
|
# Restructure stats
|
||||||
|
if 'downloadCount' in version and 'ratingCount' in version and 'rating' in version:
|
||||||
|
version['stats'] = {
|
||||||
|
'downloadCount': version.pop('downloadCount'),
|
||||||
|
'ratingCount': version.pop('ratingCount'),
|
||||||
|
'rating': version.pop('rating')
|
||||||
|
}
|
||||||
|
|
||||||
|
# Rename trigger to trainedWords
|
||||||
|
if 'trigger' in version:
|
||||||
|
version['trainedWords'] = version.pop('trigger')
|
||||||
|
|
||||||
|
# Transform files data to expected format
|
||||||
|
if 'files' in version:
|
||||||
|
transformed_files = []
|
||||||
|
for file_data in version['files']:
|
||||||
|
# Find first available mirror (deletedAt is null)
|
||||||
|
available_mirror = None
|
||||||
|
for mirror in file_data.get('mirrors', []):
|
||||||
|
if mirror.get('deletedAt') is None:
|
||||||
|
available_mirror = mirror
|
||||||
|
break
|
||||||
|
|
||||||
|
# Create transformed file entry
|
||||||
|
transformed_file = {
|
||||||
|
'id': file_data.get('id'),
|
||||||
|
'sizeKB': file_data.get('sizeKB'),
|
||||||
|
'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
|
||||||
|
'type': file_data.get('type'),
|
||||||
|
'downloadUrl': available_mirror.get('url') if available_mirror else None,
|
||||||
|
'primary': True,
|
||||||
|
'mirrors': file_data.get('mirrors', [])
|
||||||
|
}
|
||||||
|
|
||||||
|
# Transform hash format
|
||||||
|
if 'sha256' in file_data:
|
||||||
|
transformed_file['hashes'] = {
|
||||||
|
'SHA256': file_data['sha256'].upper()
|
||||||
|
}
|
||||||
|
|
||||||
|
transformed_files.append(transformed_file)
|
||||||
|
|
||||||
|
version['files'] = transformed_files
|
||||||
|
|
||||||
|
# Add model information
|
||||||
|
version['model'] = {
|
||||||
|
'name': model_data.get('name'),
|
||||||
|
'type': model_data.get('type'),
|
||||||
|
'nsfw': model_data.get('is_nsfw', False),
|
||||||
|
'description': model_data.get('description'),
|
||||||
|
'tags': model_data.get('tags', [])
|
||||||
|
}
|
||||||
|
|
||||||
|
version['creator'] = {
|
||||||
|
'username': model_data.get('username'),
|
||||||
|
'image': ''
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add source identifier
|
||||||
|
version['source'] = 'civarchive'
|
||||||
|
version['is_deleted'] = json_data.get('query', {}).get('is_deleted', False)
|
||||||
|
|
||||||
|
return version
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching CivArchive model version {model_id}/{version_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
|
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
||||||
|
return None, "CivArchive provider requires both model_id and version_id"
|
||||||
|
|
||||||
|
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
|
||||||
|
"""Not supported by CivArchive provider"""
|
||||||
|
return None, 404
|
||||||
|
|
||||||
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
||||||
"""Provider that uses SQLite database for metadata"""
|
"""Provider that uses SQLite database for metadata"""
|
||||||
|
|
||||||
|
|||||||
@@ -632,6 +632,7 @@ class ModelRouteUtils:
|
|||||||
}, status=400)
|
}, status=400)
|
||||||
|
|
||||||
use_default_paths = data.get('use_default_paths', False)
|
use_default_paths = data.get('use_default_paths', False)
|
||||||
|
source = data.get('source') # Optional source parameter
|
||||||
|
|
||||||
# Pass the download_id to download_from_civitai
|
# Pass the download_id to download_from_civitai
|
||||||
result = await download_manager.download_from_civitai(
|
result = await download_manager.download_from_civitai(
|
||||||
@@ -641,7 +642,8 @@ class ModelRouteUtils:
|
|||||||
relative_path=data.get('relative_path', ''),
|
relative_path=data.get('relative_path', ''),
|
||||||
use_default_paths=use_default_paths,
|
use_default_paths=use_default_paths,
|
||||||
progress_callback=progress_callback,
|
progress_callback=progress_callback,
|
||||||
download_id=download_id # Pass download_id explicitly
|
download_id=download_id, # Pass download_id explicitly
|
||||||
|
source=source # Pass source parameter
|
||||||
)
|
)
|
||||||
|
|
||||||
# Include download_id in the response
|
# Include download_id in the response
|
||||||
|
|||||||
@@ -9,3 +9,4 @@ numpy
|
|||||||
natsort
|
natsort
|
||||||
GitPython
|
GitPython
|
||||||
aiosqlite
|
aiosqlite
|
||||||
|
beautifulsoup4
|
||||||
|
|||||||
Reference in New Issue
Block a user