feat: add CivArchive metadata provider and support for optional source parameter in downloads

This commit is contained in:
Will Miao
2025-09-12 21:13:15 +08:00
parent 00b77581fc
commit d05076d258
6 changed files with 173 additions and 19 deletions

View File

@@ -523,6 +523,7 @@ class BaseModelRoutes(ABC):
model_version_id = request.query.get('model_version_id')
download_id = request.query.get('download_id')
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
source = request.query.get('source') # Optional source parameter
# Create a data dictionary that mimics what would be received from a POST request
data = {
@@ -538,6 +539,10 @@ class BaseModelRoutes(ABC):
data['use_default_paths'] = use_default_paths
# Add source parameter if provided
if source:
data['source'] = source
# Create a mock request object with the data
future = asyncio.get_event_loop().create_future()
future.set_result(data)

View File

@@ -36,17 +36,10 @@ class DownloadManager:
return
self._initialized = True
self._civitai_client = None # Will be lazily initialized
# Add download management
self._active_downloads = OrderedDict() # download_id -> download_info
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def _get_lora_scanner(self):
"""Get the lora scanner from registry"""
@@ -59,7 +52,7 @@ class DownloadManager:
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
save_dir: str = None, relative_path: str = '',
progress_callback=None, use_default_paths: bool = False,
download_id: str = None) -> Dict:
download_id: str = None, source: str = None) -> Dict:
"""Download model from Civitai with task tracking and concurrency control
Args:
@@ -70,6 +63,7 @@ class DownloadManager:
progress_callback: Callback function for progress updates
use_default_paths: Flag to use default paths
download_id: Unique identifier for this download task
source: Optional source parameter to specify metadata provider
Returns:
Dict with download result
@@ -93,7 +87,7 @@ class DownloadManager:
download_task = asyncio.create_task(
self._download_with_semaphore(
task_id, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths
relative_path, progress_callback, use_default_paths, source
)
)
@@ -114,7 +108,8 @@ class DownloadManager:
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
save_dir: str, relative_path: str,
progress_callback=None, use_default_paths: bool = False):
progress_callback=None, use_default_paths: bool = False,
source: str = None):
"""Execute download with semaphore to limit concurrency"""
# Update status to waiting
if task_id in self._active_downloads:
@@ -144,7 +139,7 @@ class DownloadManager:
result = await self._execute_original_download(
model_id, model_version_id, save_dir,
relative_path, tracking_callback, use_default_paths,
task_id
task_id, source
)
# Update status based on result
@@ -179,7 +174,7 @@ class DownloadManager:
async def _execute_original_download(self, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths,
download_id=None):
download_id=None, source=None):
"""Wrapper for original download_from_civitai implementation"""
try:
# Check if model version already exists in library
@@ -201,8 +196,12 @@ class DownloadManager:
if await embedding_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in embedding library'}
# Get metadata provider instead of civitai client directly
metadata_provider = await get_default_metadata_provider()
# Get metadata provider based on source parameter
if source == 'civarchive':
from .metadata_service import get_metadata_provider
metadata_provider = await get_metadata_provider('civarchive')
else:
metadata_provider = await get_default_metadata_provider()
# Get version info based on the provided identifier
version_info = await metadata_provider.get_model_version(model_id, model_version_id)
@@ -396,8 +395,6 @@ class DownloadManager:
model_type: str = "lora", download_id: str = None) -> Dict:
"""Execute the actual download process including preview images and model files"""
try:
civitai_client = await self._get_civitai_client()
# Extract original filename details
original_filename = os.path.basename(metadata.file_path)
base_name, extension = os.path.splitext(original_filename)
@@ -504,11 +501,13 @@ class DownloadManager:
# Download model file with progress tracking using downloader
downloader = await get_downloader()
# Determine if the download URL is from Civitai
use_auth = download_url.startswith("https://civitai.com/api/download/")
success, result = await downloader.download_file(
download_url,
save_path, # Use full path instead of separate dir and filename
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
use_auth=True # Model downloads need authentication
use_auth=use_auth # Only use authentication for Civitai downloads
)
if not success:

View File

@@ -52,7 +52,16 @@ async def initialize_metadata_providers():
logger.debug("Civitai API metadata provider registered")
except Exception as e:
logger.error(f"Failed to initialize Civitai API metadata provider: {e}")
# Register CivArchive provider, but do NOT add to fallback providers
try:
from .model_metadata_provider import CivArchiveModelMetadataProvider
civarchive_provider = CivArchiveModelMetadataProvider()
provider_manager.register_provider('civarchive', civarchive_provider)
logger.debug("CivArchive metadata provider registered (not included in fallback)")
except Exception as e:
logger.error(f"Failed to initialize CivArchive metadata provider: {e}")
# Set up fallback provider based on available providers
if len(providers) > 1:
# Always use Civitai API first, then Archive DB

View File

@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
import json
import aiosqlite
import logging
import aiohttp
from bs4 import BeautifulSoup
from typing import Optional, Dict, List, Tuple, Any
logger = logging.getLogger(__name__)
@@ -55,6 +57,142 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
return await self.client.get_model_metadata(model_id)
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
    """Provider that uses CivArchive HTML page parsing for metadata.

    A model page is fetched from civarchive.com and the JSON embedded in its
    ``__NEXT_DATA__`` script tag is reshaped into a version dictionary
    (``stats``, ``trainedWords``, ``files``, ``model``, ``creator``).
    Only ``get_model_version`` is implemented; every other lookup reports
    "not supported" through its usual empty return value.
    """

    def __init__(self, session: aiohttp.ClientSession = None):
        """Store an optional caller-supplied HTTP session.

        Args:
            session: Session to reuse. When omitted, one is created lazily on
                first use and is then owned (and closed) by this provider.
        """
        self.session = session
        # Only a session created by this provider is closed in close().
        self._own_session = session is None

    async def _get_session(self):
        """Get or create HTTP session"""
        if self.session is None:
            self.session = aiohttp.ClientSession()
        return self.session

    async def close(self):
        """Close HTTP session if we own it"""
        if self._own_session and self.session:
            await self.session.close()
            self.session = None

    async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
        """Not supported by CivArchive provider"""
        return None

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Not supported by CivArchive provider"""
        return None

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Get specific model version by parsing CivArchive HTML page.

        Args:
            model_id: Model identifier; required.
            version_id: Model version identifier; required.

        Returns:
            The transformed version dictionary, or None when either id is
            missing, the page does not answer with HTTP 200, the embedded
            JSON is absent or has no usable version entry, or any error
            occurs while fetching/parsing (the error is logged).
        """
        if model_id is None or version_id is None:
            return None

        try:
            # Construct CivArchive URL
            url = f"https://civarchive.com/models/{model_id}?modelVersionId={version_id}"

            session = await self._get_session()
            async with session.get(url) as response:
                if response.status != 200:
                    return None
                html_content = await response.text()

            # Parse HTML to extract JSON data
            soup = BeautifulSoup(html_content, 'html.parser')
            script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
            # .string is None for an empty tag (or one with several children);
            # json.loads(None) would raise, so treat that as "no data".
            if script_tag is None or not script_tag.string:
                return None

            json_data = json.loads(script_tag.string)
            return self._build_version_info(json_data)
        except Exception as e:
            # Best-effort provider: any network/parsing failure yields None.
            logger.error("Error fetching CivArchive model version %s/%s: %s", model_id, version_id, e)
            return None

    def _build_version_info(self, json_data: Any) -> Optional[Dict]:
        """Reshape the decoded ``__NEXT_DATA__`` payload into a version dict."""
        if not isinstance(json_data, dict):
            return None

        # ``or {}`` guards against keys that are present but null.
        page_props = (json_data.get('props') or {}).get('pageProps') or {}
        model_data = page_props.get('model')
        if not model_data or not isinstance(model_data.get('version'), dict):
            return None

        # Extract version data as base (shallow copy so the pops/renames below
        # do not touch the parsed payload's own dict).
        version = model_data['version'].copy()

        # Restructure stats: only when all three counters are present.
        if 'downloadCount' in version and 'ratingCount' in version and 'rating' in version:
            version['stats'] = {
                'downloadCount': version.pop('downloadCount'),
                'ratingCount': version.pop('ratingCount'),
                'rating': version.pop('rating')
            }

        # Rename trigger to trainedWords
        if 'trigger' in version:
            version['trainedWords'] = version.pop('trigger')

        # Transform files data to expected format
        if 'files' in version:
            version['files'] = [
                self._transform_file(file_data)
                for file_data in (version['files'] or [])
            ]

        # Add model information
        version['model'] = {
            'name': model_data.get('name'),
            'type': model_data.get('type'),
            'nsfw': model_data.get('is_nsfw', False),
            'description': model_data.get('description'),
            'tags': model_data.get('tags', [])
        }
        version['creator'] = {
            'username': model_data.get('username'),
            'image': ''
        }

        # Add source identifier
        version['source'] = 'civarchive'
        version['is_deleted'] = (json_data.get('query') or {}).get('is_deleted', False)

        return version

    @staticmethod
    def _transform_file(file_data: Dict) -> Dict:
        """Convert one CivArchive file entry into the transformed file layout."""
        mirrors = file_data.get('mirrors') or []

        # Find first available mirror (deletedAt is null)
        available_mirror = next(
            (mirror for mirror in mirrors if mirror.get('deletedAt') is None),
            None
        )

        if available_mirror:
            name = available_mirror.get('filename', file_data.get('name'))
            download_url = available_mirror.get('url')
        else:
            name = file_data.get('name')
            download_url = None

        transformed_file = {
            'id': file_data.get('id'),
            'sizeKB': file_data.get('sizeKB'),
            'name': name,
            'type': file_data.get('type'),
            'downloadUrl': download_url,
            # NOTE(review): every file is flagged primary, as before; confirm
            # that consumers cope with several primary files per version.
            'primary': True,
            'mirrors': mirrors
        }

        # Transform hash format; skip a missing/null hash instead of failing
        # the whole version lookup on ``None.upper()``.
        sha256 = file_data.get('sha256')
        if sha256:
            transformed_file['hashes'] = {
                'SHA256': sha256.upper()
            }

        return transformed_file

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Not supported by CivArchive provider - requires both model_id and version_id"""
        return None, "CivArchive provider requires both model_id and version_id"

    async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
        """Not supported by CivArchive provider"""
        return None, 404
class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses SQLite database for metadata"""

View File

@@ -632,6 +632,7 @@ class ModelRouteUtils:
}, status=400)
use_default_paths = data.get('use_default_paths', False)
source = data.get('source') # Optional source parameter
# Pass the download_id to download_from_civitai
result = await download_manager.download_from_civitai(
@@ -641,7 +642,8 @@ class ModelRouteUtils:
relative_path=data.get('relative_path', ''),
use_default_paths=use_default_paths,
progress_callback=progress_callback,
download_id=download_id # Pass download_id explicitly
download_id=download_id, # Pass download_id explicitly
source=source # Pass source parameter
)
# Include download_id in the response

View File

@@ -9,3 +9,4 @@ numpy
natsort
GitPython
aiosqlite
beautifulsoup4