From d05076d258f42a068b24db8938c726014ba9ba4a Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 12 Sep 2025 21:13:15 +0800 Subject: [PATCH] feat: add CivArchive metadata provider and support for optional source parameter in downloads --- py/routes/base_model_routes.py | 5 + py/services/download_manager.py | 33 +++--- py/services/metadata_service.py | 11 +- py/services/model_metadata_provider.py | 138 +++++++++++++++++++++++++ py/utils/routes_common.py | 4 +- requirements.txt | 1 + 6 files changed, 173 insertions(+), 19 deletions(-) diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index be0f2695..a7970f0b 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -523,6 +523,7 @@ class BaseModelRoutes(ABC): model_version_id = request.query.get('model_version_id') download_id = request.query.get('download_id') use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' + source = request.query.get('source') # Optional source parameter # Create a data dictionary that mimics what would be received from a POST request data = { @@ -538,6 +539,10 @@ class BaseModelRoutes(ABC): data['use_default_paths'] = use_default_paths + # Add source parameter if provided + if source: + data['source'] = source + # Create a mock request object with the data future = asyncio.get_event_loop().create_future() future.set_result(data) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 9f090b20..6638d7d2 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -36,17 +36,10 @@ class DownloadManager: return self._initialized = True - self._civitai_client = None # Will be lazily initialized # Add download management self._active_downloads = OrderedDict() # download_id -> download_info self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads self._download_tasks = {} # download_id -> asyncio.Task - - async def _get_civitai_client(self): - """Lazily initialize CivitaiClient from registry""" - if self._civitai_client is None: - self._civitai_client = await ServiceRegistry.get_civitai_client() - return self._civitai_client async def _get_lora_scanner(self): """Get the lora scanner from registry""" @@ -59,7 +52,7 @@ class DownloadManager: async def download_from_civitai(self, model_id: int = None, model_version_id: int = None, save_dir: str = None, relative_path: str = '', progress_callback=None, use_default_paths: bool = False, - download_id: str = None) -> Dict: + download_id: str = None, source: str = None) -> Dict: """Download model from Civitai with task tracking and concurrency control Args: @@ -70,6 +63,7 @@ class DownloadManager: progress_callback: Callback function for progress updates use_default_paths: Flag to use default paths download_id: Unique identifier for this download task + source: Optional source parameter to specify metadata provider Returns: Dict with download result @@ -93,7 +87,7 @@ class DownloadManager: download_task = asyncio.create_task( self._download_with_semaphore( task_id, model_id, model_version_id, save_dir, - relative_path, progress_callback, use_default_paths + relative_path, progress_callback, use_default_paths, source ) ) @@ -114,7 +108,8 @@ class DownloadManager: async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int, save_dir: str, relative_path: str, - progress_callback=None, use_default_paths: bool = False): + progress_callback=None, use_default_paths: bool = False, + source: str = None): """Execute download with semaphore to limit concurrency""" # Update status to waiting if task_id in self._active_downloads: @@ -144,7 +139,7 @@ class DownloadManager: result = await self._execute_original_download( model_id, model_version_id, save_dir, relative_path, tracking_callback, use_default_paths, - task_id + task_id, source ) # Update status based on result @@ -179,7 +174,7 @@ class DownloadManager: async def _execute_original_download(self, model_id, model_version_id, save_dir, relative_path, progress_callback, use_default_paths, - download_id=None): + download_id=None, source=None): """Wrapper for original download_from_civitai implementation""" try: # Check if model version already exists in library @@ -201,8 +196,12 @@ class DownloadManager: if await embedding_scanner.check_model_version_exists(model_version_id): return {'success': False, 'error': 'Model version already exists in embedding library'} - # Get metadata provider instead of civitai client directly - metadata_provider = await get_default_metadata_provider() + # Get metadata provider based on source parameter + if source == 'civarchive': + from .metadata_service import get_metadata_provider + metadata_provider = await get_metadata_provider('civarchive') + else: + metadata_provider = await get_default_metadata_provider() # Get version info based on the provided identifier version_info = await metadata_provider.get_model_version(model_id, model_version_id) @@ -396,8 +395,6 @@ class DownloadManager: model_type: str = "lora", download_id: str = None) -> Dict: """Execute the actual download process including preview images and model files""" try: - civitai_client = await self._get_civitai_client() - # Extract original filename details original_filename = os.path.basename(metadata.file_path) base_name, extension = os.path.splitext(original_filename) @@ -504,11 +501,13 @@ class DownloadManager: # Download model file with progress tracking using downloader downloader = await get_downloader() + # Determine if the download URL is from Civitai + use_auth = download_url.startswith("https://civitai.com/api/download/") success, result = await downloader.download_file( download_url, save_path, # Use full path instead of separate dir and filename progress_callback=lambda p: self._handle_download_progress(p, progress_callback), - use_auth=True # Model downloads need authentication + use_auth=use_auth # Only use authentication for Civitai downloads ) if not success: diff --git a/py/services/metadata_service.py b/py/services/metadata_service.py index 4c20b0b8..6a4f9dd8 100644 --- a/py/services/metadata_service.py +++ b/py/services/metadata_service.py @@ -52,7 +52,16 @@ async def initialize_metadata_providers(): logger.debug("Civitai API metadata provider registered") except Exception as e: logger.error(f"Failed to initialize Civitai API metadata provider: {e}") - + + # Register CivArchive provider, but do NOT add to fallback providers + try: + from .model_metadata_provider import CivArchiveModelMetadataProvider + civarchive_provider = CivArchiveModelMetadataProvider() + provider_manager.register_provider('civarchive', civarchive_provider) + logger.debug("CivArchive metadata provider registered (not included in fallback)") + except Exception as e: + logger.error(f"Failed to initialize CivArchive metadata provider: {e}") + # Set up fallback provider based on available providers if len(providers) > 1: # Always use Civitai API first, then Archive DB diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py index 9f54a2e7..b02d0c0e 100644 --- a/py/services/model_metadata_provider.py +++ b/py/services/model_metadata_provider.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod import json import aiosqlite import logging +import aiohttp +from bs4 import BeautifulSoup from typing import Optional, Dict, List, Tuple, Any logger = logging.getLogger(__name__) @@ -55,6 +57,142 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider): async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: return await self.client.get_model_metadata(model_id) +class CivArchiveModelMetadataProvider(ModelMetadataProvider): + """Provider that uses CivArchive HTML page parsing for metadata""" + + def __init__(self, session: aiohttp.ClientSession = None): + self.session = session + self._own_session = session is None + + async def _get_session(self): + """Get or create HTTP session""" + if self.session is None: + self.session = aiohttp.ClientSession() + return self.session + + async def close(self): + """Close HTTP session if we own it""" + if self._own_session and self.session: + await self.session.close() + self.session = None + + async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: + """Not supported by CivArchive provider""" + return None + + async def get_model_versions(self, model_id: str) -> Optional[Dict]: + """Not supported by CivArchive provider""" + return None + + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: + """Get specific model version by parsing CivArchive HTML page""" + if model_id is None or version_id is None: + return None + + try: + # Construct CivArchive URL + url = f"https://civarchive.com/models/{model_id}?modelVersionId={version_id}" + + session = await self._get_session() + async with session.get(url) as response: + if response.status != 200: + return None + + html_content = await response.text() + + # Parse HTML to extract JSON data + soup = BeautifulSoup(html_content, 'html.parser') + script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'}) + + if not script_tag: + return None + + # Parse JSON content + json_data = json.loads(script_tag.string) + model_data = json_data.get('props', {}).get('pageProps', {}).get('model') + + if not model_data or 'version' not in model_data: + return None + + # Extract version data as base + version = model_data['version'].copy() + + # Restructure stats + if 'downloadCount' in version and 'ratingCount' in version and 'rating' in version: + version['stats'] = { + 'downloadCount': version.pop('downloadCount'), + 'ratingCount': version.pop('ratingCount'), + 'rating': version.pop('rating') + } + + # Rename trigger to trainedWords + if 'trigger' in version: + version['trainedWords'] = version.pop('trigger') + + # Transform files data to expected format + if 'files' in version: + transformed_files = [] + for file_data in version['files']: + # Find first available mirror (deletedAt is null) + available_mirror = None + for mirror in file_data.get('mirrors', []): + if mirror.get('deletedAt') is None: + available_mirror = mirror + break + + # Create transformed file entry + transformed_file = { + 'id': file_data.get('id'), + 'sizeKB': file_data.get('sizeKB'), + 'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'), + 'type': file_data.get('type'), + 'downloadUrl': available_mirror.get('url') if available_mirror else None, + 'primary': True, + 'mirrors': file_data.get('mirrors', []) + } + + # Transform hash format + if 'sha256' in file_data: + transformed_file['hashes'] = { + 'SHA256': file_data['sha256'].upper() + } + + transformed_files.append(transformed_file) + + version['files'] = transformed_files + + # Add model information + version['model'] = { + 'name': model_data.get('name'), + 'type': model_data.get('type'), + 'nsfw': model_data.get('is_nsfw', False), + 'description': model_data.get('description'), + 'tags': model_data.get('tags', []) + } + + version['creator'] = { + 'username': model_data.get('username'), + 'image': '' + } + + # Add source identifier + version['source'] = 'civarchive' + version['is_deleted'] = json_data.get('query', {}).get('is_deleted', False) + + return version + + except Exception as e: + logger.error(f"Error fetching CivArchive model version {model_id}/{version_id}: {e}") + return None + + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: + """Not supported by CivArchive provider - requires both model_id and version_id""" + return None, "CivArchive provider requires both model_id and version_id" + + async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: + """Not supported by CivArchive provider""" + return None, 404 + class SQLiteModelMetadataProvider(ModelMetadataProvider): """Provider that uses SQLite database for metadata""" diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 80765f7b..3be2677f 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -632,6 +632,7 @@ class ModelRouteUtils: }, status=400) use_default_paths = data.get('use_default_paths', False) + source = data.get('source') # Optional source parameter # Pass the download_id to download_from_civitai result = await download_manager.download_from_civitai( @@ -641,7 +642,8 @@ class ModelRouteUtils: relative_path=data.get('relative_path', ''), use_default_paths=use_default_paths, progress_callback=progress_callback, - download_id=download_id # Pass download_id explicitly + download_id=download_id, # Pass download_id explicitly + source=source # Pass source parameter ) # Include download_id in the response diff --git a/requirements.txt b/requirements.txt index 9e280c64..4051dc74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ numpy natsort GitPython aiosqlite +beautifulsoup4