feat: add CivArchive metadata provider and support for optional source parameter in downloads

This commit is contained in:
Will Miao
2025-09-12 21:13:15 +08:00
parent 00b77581fc
commit d05076d258
6 changed files with 173 additions and 19 deletions

View File

@@ -523,6 +523,7 @@ class BaseModelRoutes(ABC):
model_version_id = request.query.get('model_version_id') model_version_id = request.query.get('model_version_id')
download_id = request.query.get('download_id') download_id = request.query.get('download_id')
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
source = request.query.get('source') # Optional source parameter
# Create a data dictionary that mimics what would be received from a POST request # Create a data dictionary that mimics what would be received from a POST request
data = { data = {
@@ -538,6 +539,10 @@ class BaseModelRoutes(ABC):
data['use_default_paths'] = use_default_paths data['use_default_paths'] = use_default_paths
# Add source parameter if provided
if source:
data['source'] = source
# Create a mock request object with the data # Create a mock request object with the data
future = asyncio.get_event_loop().create_future() future = asyncio.get_event_loop().create_future()
future.set_result(data) future.set_result(data)

View File

@@ -36,17 +36,10 @@ class DownloadManager:
return return
self._initialized = True self._initialized = True
self._civitai_client = None # Will be lazily initialized
# Add download management # Add download management
self._active_downloads = OrderedDict() # download_id -> download_info self._active_downloads = OrderedDict() # download_id -> download_info
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task self._download_tasks = {} # download_id -> asyncio.Task
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def _get_lora_scanner(self): async def _get_lora_scanner(self):
"""Get the lora scanner from registry""" """Get the lora scanner from registry"""
@@ -59,7 +52,7 @@ class DownloadManager:
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None, async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
save_dir: str = None, relative_path: str = '', save_dir: str = None, relative_path: str = '',
progress_callback=None, use_default_paths: bool = False, progress_callback=None, use_default_paths: bool = False,
download_id: str = None) -> Dict: download_id: str = None, source: str = None) -> Dict:
"""Download model from Civitai with task tracking and concurrency control """Download model from Civitai with task tracking and concurrency control
Args: Args:
@@ -70,6 +63,7 @@ class DownloadManager:
progress_callback: Callback function for progress updates progress_callback: Callback function for progress updates
use_default_paths: Flag to use default paths use_default_paths: Flag to use default paths
download_id: Unique identifier for this download task download_id: Unique identifier for this download task
source: Optional source parameter to specify metadata provider
Returns: Returns:
Dict with download result Dict with download result
@@ -93,7 +87,7 @@ class DownloadManager:
download_task = asyncio.create_task( download_task = asyncio.create_task(
self._download_with_semaphore( self._download_with_semaphore(
task_id, model_id, model_version_id, save_dir, task_id, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths relative_path, progress_callback, use_default_paths, source
) )
) )
@@ -114,7 +108,8 @@ class DownloadManager:
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int, async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
save_dir: str, relative_path: str, save_dir: str, relative_path: str,
progress_callback=None, use_default_paths: bool = False): progress_callback=None, use_default_paths: bool = False,
source: str = None):
"""Execute download with semaphore to limit concurrency""" """Execute download with semaphore to limit concurrency"""
# Update status to waiting # Update status to waiting
if task_id in self._active_downloads: if task_id in self._active_downloads:
@@ -144,7 +139,7 @@ class DownloadManager:
result = await self._execute_original_download( result = await self._execute_original_download(
model_id, model_version_id, save_dir, model_id, model_version_id, save_dir,
relative_path, tracking_callback, use_default_paths, relative_path, tracking_callback, use_default_paths,
task_id task_id, source
) )
# Update status based on result # Update status based on result
@@ -179,7 +174,7 @@ class DownloadManager:
async def _execute_original_download(self, model_id, model_version_id, save_dir, async def _execute_original_download(self, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths, relative_path, progress_callback, use_default_paths,
download_id=None): download_id=None, source=None):
"""Wrapper for original download_from_civitai implementation""" """Wrapper for original download_from_civitai implementation"""
try: try:
# Check if model version already exists in library # Check if model version already exists in library
@@ -201,8 +196,12 @@ class DownloadManager:
if await embedding_scanner.check_model_version_exists(model_version_id): if await embedding_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in embedding library'} return {'success': False, 'error': 'Model version already exists in embedding library'}
# Get metadata provider instead of civitai client directly # Get metadata provider based on source parameter
metadata_provider = await get_default_metadata_provider() if source == 'civarchive':
from .metadata_service import get_metadata_provider
metadata_provider = await get_metadata_provider('civarchive')
else:
metadata_provider = await get_default_metadata_provider()
# Get version info based on the provided identifier # Get version info based on the provided identifier
version_info = await metadata_provider.get_model_version(model_id, model_version_id) version_info = await metadata_provider.get_model_version(model_id, model_version_id)
@@ -396,8 +395,6 @@ class DownloadManager:
model_type: str = "lora", download_id: str = None) -> Dict: model_type: str = "lora", download_id: str = None) -> Dict:
"""Execute the actual download process including preview images and model files""" """Execute the actual download process including preview images and model files"""
try: try:
civitai_client = await self._get_civitai_client()
# Extract original filename details # Extract original filename details
original_filename = os.path.basename(metadata.file_path) original_filename = os.path.basename(metadata.file_path)
base_name, extension = os.path.splitext(original_filename) base_name, extension = os.path.splitext(original_filename)
@@ -504,11 +501,13 @@ class DownloadManager:
# Download model file with progress tracking using downloader # Download model file with progress tracking using downloader
downloader = await get_downloader() downloader = await get_downloader()
# Determine if the download URL is from Civitai
use_auth = download_url.startswith("https://civitai.com/api/download/")
success, result = await downloader.download_file( success, result = await downloader.download_file(
download_url, download_url,
save_path, # Use full path instead of separate dir and filename save_path, # Use full path instead of separate dir and filename
progress_callback=lambda p: self._handle_download_progress(p, progress_callback), progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
use_auth=True # Model downloads need authentication use_auth=use_auth # Only use authentication for Civitai downloads
) )
if not success: if not success:

View File

@@ -52,7 +52,16 @@ async def initialize_metadata_providers():
logger.debug("Civitai API metadata provider registered") logger.debug("Civitai API metadata provider registered")
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize Civitai API metadata provider: {e}") logger.error(f"Failed to initialize Civitai API metadata provider: {e}")
# Register CivArchive provider, but do NOT add to fallback providers
try:
from .model_metadata_provider import CivArchiveModelMetadataProvider
civarchive_provider = CivArchiveModelMetadataProvider()
provider_manager.register_provider('civarchive', civarchive_provider)
logger.debug("CivArchive metadata provider registered (not included in fallback)")
except Exception as e:
logger.error(f"Failed to initialize CivArchive metadata provider: {e}")
# Set up fallback provider based on available providers # Set up fallback provider based on available providers
if len(providers) > 1: if len(providers) > 1:
# Always use Civitai API first, then Archive DB # Always use Civitai API first, then Archive DB

View File

@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
import json import json
import aiosqlite import aiosqlite
import logging import logging
import aiohttp
from bs4 import BeautifulSoup
from typing import Optional, Dict, List, Tuple, Any from typing import Optional, Dict, List, Tuple, Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -55,6 +57,142 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
return await self.client.get_model_metadata(model_id) return await self.client.get_model_metadata(model_id)
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
    """Provider that uses CivArchive HTML page parsing for metadata.

    Only ``get_model_version`` is implemented: it fetches the CivArchive
    model page, reads the JSON embedded in the page's ``__NEXT_DATA__``
    script tag and reshapes it (``stats``, ``trainedWords``, ``files``,
    ``model``, ``creator``) before returning it. Every other lookup
    reports "not found" because it is not supported by this provider.
    """

    def __init__(self, session: aiohttp.ClientSession = None):
        """Store an optional shared HTTP session.

        Args:
            session: Existing session to reuse. When omitted, a session is
                created lazily on first use and is owned (and closed) by
                this provider.
        """
        self.session = session
        # Only sessions created by this provider are closed in close().
        self._own_session = session is None

    async def _get_session(self):
        """Get or create HTTP session."""
        if self.session is None:
            self.session = aiohttp.ClientSession()
        return self.session

    async def close(self):
        """Close HTTP session if we own it."""
        if self._own_session and self.session:
            await self.session.close()
            self.session = None

    async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
        """Not supported by CivArchive provider; always returns None."""
        return None

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Not supported by CivArchive provider; always returns None."""
        return None

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Get specific model version by parsing CivArchive HTML page.

        Args:
            model_id: Model identifier used in the page URL.
            version_id: Version identifier passed as ``modelVersionId``.

        Returns:
            The reshaped version dict, or None when either id is missing,
            the page does not answer with HTTP 200, the embedded JSON is
            absent or has no version data, or any error occurs (errors are
            logged, never raised).
        """
        if model_id is None or version_id is None:
            return None

        try:
            # Construct CivArchive URL
            url = f"https://civarchive.com/models/{model_id}?modelVersionId={version_id}"

            session = await self._get_session()
            async with session.get(url) as response:
                if response.status != 200:
                    return None
                html_content = await response.text()

            return self._parse_version_page(html_content)
        except Exception as e:
            # Best-effort provider: any network/parse failure means "not found".
            logger.error(f"Error fetching CivArchive model version {model_id}/{version_id}: {e}")
            return None

    def _parse_version_page(self, html_content: str) -> Optional[Dict]:
        """Extract and reshape the version data embedded in a model page."""
        soup = BeautifulSoup(html_content, 'html.parser')
        script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
        # .string is None when the tag is empty or has several children;
        # json.loads(None) would raise, so treat that as "no data".
        if script_tag is None or not script_tag.string:
            return None

        json_data = json.loads(script_tag.string)
        model_data = json_data.get('props', {}).get('pageProps', {}).get('model')
        if not model_data:
            return None

        # A JSON null (or any non-object) version cannot be reshaped.
        version_data = model_data.get('version')
        if not isinstance(version_data, dict):
            return None

        # Work on a shallow copy so the parsed page data is left untouched.
        version = version_data.copy()

        # Restructure stats (only when all three counters are present)
        if 'downloadCount' in version and 'ratingCount' in version and 'rating' in version:
            version['stats'] = {
                'downloadCount': version.pop('downloadCount'),
                'ratingCount': version.pop('ratingCount'),
                'rating': version.pop('rating')
            }

        # Rename trigger to trainedWords
        if 'trigger' in version:
            version['trainedWords'] = version.pop('trigger')

        # Transform files data to expected format; a null list becomes empty
        if 'files' in version:
            version['files'] = [
                self._transform_file(file_data)
                for file_data in (version['files'] or [])
            ]

        # Add model information
        version['model'] = {
            'name': model_data.get('name'),
            'type': model_data.get('type'),
            'nsfw': model_data.get('is_nsfw', False),
            'description': model_data.get('description'),
            'tags': model_data.get('tags', [])
        }

        version['creator'] = {
            'username': model_data.get('username'),
            'image': ''
        }

        # Add source identifier
        version['source'] = 'civarchive'
        version['is_deleted'] = json_data.get('query', {}).get('is_deleted', False)

        return version

    @staticmethod
    def _transform_file(file_data: Dict) -> Dict:
        """Convert one CivArchive file entry into the expected file format."""
        # 'mirrors' may be missing or null; both are treated as "no mirrors".
        mirrors = file_data.get('mirrors') or []

        # Find first available mirror (deletedAt is null)
        available_mirror = next(
            (mirror for mirror in mirrors if mirror.get('deletedAt') is None),
            None
        )

        fallback_name = file_data.get('name')
        if available_mirror:
            file_name = available_mirror.get('filename', fallback_name)
            download_url = available_mirror.get('url')
        else:
            file_name = fallback_name
            download_url = None

        transformed_file = {
            'id': file_data.get('id'),
            'sizeKB': file_data.get('sizeKB'),
            'name': file_name,
            'type': file_data.get('type'),
            'downloadUrl': download_url,
            # NOTE(review): every file is flagged as primary - confirm this is
            # intended when a version lists more than one file.
            'primary': True,
            'mirrors': mirrors
        }

        # Transform hash format; skip a missing/null hash instead of failing
        sha256 = file_data.get('sha256')
        if isinstance(sha256, str):
            transformed_file['hashes'] = {
                'SHA256': sha256.upper()
            }

        return transformed_file

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Not supported by CivArchive provider - requires both model_id and version_id."""
        return None, "CivArchive provider requires both model_id and version_id"

    async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
        """Not supported by CivArchive provider; always returns (None, 404)."""
        return None, 404
class SQLiteModelMetadataProvider(ModelMetadataProvider): class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses SQLite database for metadata""" """Provider that uses SQLite database for metadata"""

View File

@@ -632,6 +632,7 @@ class ModelRouteUtils:
}, status=400) }, status=400)
use_default_paths = data.get('use_default_paths', False) use_default_paths = data.get('use_default_paths', False)
source = data.get('source') # Optional source parameter
# Pass the download_id to download_from_civitai # Pass the download_id to download_from_civitai
result = await download_manager.download_from_civitai( result = await download_manager.download_from_civitai(
@@ -641,7 +642,8 @@ class ModelRouteUtils:
relative_path=data.get('relative_path', ''), relative_path=data.get('relative_path', ''),
use_default_paths=use_default_paths, use_default_paths=use_default_paths,
progress_callback=progress_callback, progress_callback=progress_callback,
download_id=download_id # Pass download_id explicitly download_id=download_id, # Pass download_id explicitly
source=source # Pass source parameter
) )
# Include download_id in the response # Include download_id in the response

View File

@@ -9,3 +9,4 @@ numpy
natsort natsort
GitPython GitPython
aiosqlite aiosqlite
beautifulsoup4