import asyncio
import json
import logging
import os
from datetime import datetime
from email.parser import Parser
from email.utils import collapse_rfc2231_value
from typing import Optional, Dict, Tuple, List
from urllib.parse import unquote

import aiohttp

from ..utils.models import LoraMetadata

logger = logging.getLogger(__name__)


class CivitaiClient:
    """Async client for the Civitai REST API (metadata lookups and downloads).

    A shared instance is obtained through :meth:`get_instance`. The underlying
    ``aiohttp.ClientSession`` is created lazily and recycled once it has been
    open for longer than ``_SESSION_MAX_AGE_SECONDS``.
    """

    _instance = None
    _lock = asyncio.Lock()

    # Sessions older than this many seconds are closed and recreated.
    _SESSION_MAX_AGE_SECONDS = 300

    @classmethod
    async def get_instance(cls):
        """Get singleton instance of CivitaiClient"""
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    def __init__(self):
        # Check if already initialized for singleton pattern
        if hasattr(self, '_initialized'):
            return
        self._initialized = True

        self.base_url = "https://civitai.com/api/v1"
        self.headers = {
            'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
        }
        self._session = None
        self._session_created_at = None
        # Set default buffer size to 1MB for higher throughput
        self.chunk_size = 1024 * 1024

    @property
    async def session(self) -> aiohttp.ClientSession:
        """Lazy initialize the session.

        This is an async property: callers must ``await client.session``.
        """
        if self._session is None:
            # Optimize TCP connection parameters
            connector = aiohttp.TCPConnector(
                ssl=True,
                limit=3,  # Further reduced from 5 to 3
                ttl_dns_cache=0,  # Disabled DNS caching completely
                force_close=False,  # Keep connections for reuse
                enable_cleanup_closed=True
            )
            trust_env = True  # Allow using system environment proxy settings
            # Configure timeout parameters
            timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
            self._session = aiohttp.ClientSession(
                connector=connector,
                trust_env=trust_env,
                timeout=timeout
            )
            self._session_created_at = datetime.now()
        return self._session

    async def _ensure_fresh_session(self):
        """Return a usable session, recreating it if it is stale or closed.

        A session is considered stale when it has no recorded creation time or
        has been open for more than ``_SESSION_MAX_AGE_SECONDS``.
        """
        if self._session is not None:
            created_at = getattr(self, '_session_created_at', None)
            # A missing timestamp is treated as "stale" instead of letting the
            # subtraction below fail on None.
            stale = (
                created_at is None
                or (datetime.now() - created_at).total_seconds() > self._SESSION_MAX_AGE_SECONDS
            )
            if stale or self._session.closed:
                await self.close()
        return await self.session

    def _parse_content_disposition(self, header: str) -> Optional[str]:
        """Parse filename from content-disposition header.

        Args:
            header: Raw ``Content-Disposition`` header value (may be empty/None).

        Returns:
            The URL-decoded filename, or None when the header carries none.
        """
        if not header:
            return None

        # Handle quoted filenames
        marker = 'filename="'
        start = header.find(marker)
        if start != -1:
            start += len(marker)
            end = header.find('"', start)
            # An unterminated quote falls through to the generic parser
            # instead of raising ValueError.
            if end != -1:
                return unquote(header[start:end])

        # Fallback: let the email parser handle unquoted / RFC 2231 values.
        # get_param() inspects Content-Type unless told otherwise, so the
        # header name must be passed explicitly.
        disposition = Parser().parsestr(f'Content-Disposition: {header}')
        filename = disposition.get_param('filename', header='content-disposition')
        if not filename:
            return None
        if isinstance(filename, tuple):
            # RFC 2231 extended value (filename*=charset'lang'value); it is
            # already percent-decoded by the parser.
            return collapse_rfc2231_value(filename)
        return unquote(filename)

    def _get_request_headers(self) -> dict:
        """Get request headers with optional API key"""
        headers = {
            'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
            'Content-Type': 'application/json'
        }

        from .settings_manager import settings
        api_key = settings.get('civitai_api_key')
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        return headers

    async def _download_file(self, url: str, save_dir: str, default_filename: str,
                             progress_callback=None) -> Tuple[bool, str]:
        """Download file with content-disposition support and progress tracking

        Args:
            url: Download URL
            save_dir: Directory to save the file
            default_filename: Fallback filename if none provided in headers
            progress_callback: Optional async callback function for progress updates (0-100)

        Returns:
            Tuple[bool, str]: (success, save_path or error message)
        """
        logger.debug(f"Resolving DNS for: {url}")
        session = await self._ensure_fresh_session()
        try:
            headers = self._get_request_headers()

            # Disable compression for better chunked downloads
            headers['Accept-Encoding'] = 'identity'

            logger.debug(f"Starting download from: {url}")
            async with session.get(url, headers=headers, allow_redirects=True) as response:
                if response.status != 200:
                    # Handle 401 unauthorized responses
                    if response.status == 401:
                        logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
                        return False, "Invalid or missing CivitAI API key, or early access restriction."

                    # Handle other client errors that might be permission-related
                    if response.status == 403:
                        logger.warning(f"Forbidden access to resource: {url} (Status 403)")
                        return False, "Access forbidden: You don't have permission to download this file."

                    # Generic error response for other status codes
                    logger.error(f"Download failed for {url} with status {response.status}")
                    return False, f"Download failed with status {response.status}"

                # Get filename from content-disposition header
                content_disposition = response.headers.get('Content-Disposition')
                filename = self._parse_content_disposition(content_disposition)
                if filename:
                    # The header is server-controlled: drop any directory
                    # components so the file cannot escape save_dir.
                    filename = os.path.basename(filename.replace('\\', '/'))
                if not filename or filename in ('.', '..'):
                    filename = default_filename

                save_path = os.path.join(save_dir, filename)

                # Get total file size for progress calculation
                total_size = int(response.headers.get('content-length', 0))
                current_size = 0
                last_progress_report_time = datetime.now()

                # Stream download to file with progress updates using larger buffer
                with open(save_path, 'wb') as f:
                    async for chunk in response.content.iter_chunked(self.chunk_size):
                        if chunk:
                            f.write(chunk)
                            current_size += len(chunk)

                            # Limit progress update frequency to reduce overhead
                            now = datetime.now()
                            time_diff = (now - last_progress_report_time).total_seconds()

                            if progress_callback and total_size and time_diff >= 0.5:
                                progress = (current_size / total_size) * 100
                                await progress_callback(progress)
                                last_progress_report_time = now

                # Ensure 100% progress is reported
                if progress_callback:
                    await progress_callback(100)

                return True, save_path

        except aiohttp.ClientError as e:
            logger.error(f"Network error during download: {e}")
            return False, f"Network error: {str(e)}"
        except Exception as e:
            logger.error(f"Download error: {e}")
            return False, str(e)

    async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
        """Look up a model version by file hash.

        Returns:
            The decoded JSON response on HTTP 200, otherwise None (including
            on any request error, which is logged).
        """
        try:
            session = await self._ensure_fresh_session()
            async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
                if response.status == 200:
                    return await response.json()
                return None
        except Exception as e:
            logger.error(f"API Error: {str(e)}")
            return None

    async def download_preview_image(self, image_url: str, save_path: str):
        """Download ``image_url`` to ``save_path``.

        Returns:
            True when the image was fetched (HTTP 200) and written, False on
            any other status or on error.
        """
        try:
            session = await self._ensure_fresh_session()
            async with session.get(image_url) as response:
                if response.status == 200:
                    content = await response.read()
                    with open(save_path, 'wb') as f:
                        f.write(content)
                    return True
                return False
        except Exception as e:
            # Use the module logger like every other method instead of print().
            logger.error(f"Download Error: {str(e)}")
            return False

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Get all versions of a model together with the model type.

        Returns:
            A dict with keys ``'modelVersions'`` (list) and ``'type'`` (str),
            or None when the request fails or returns a non-200 status.
        """
        try:
            session = await self._ensure_fresh_session()  # Use fresh session
            async with session.get(f"{self.base_url}/models/{model_id}") as response:
                if response.status != 200:
                    return None
                data = await response.json()
                # Also return model type along with versions
                return {
                    'modelVersions': data.get('modelVersions', []),
                    'type': data.get('type', '')
                }
        except Exception as e:
            logger.error(f"Error fetching model versions: {e}")
            return None

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Fetch model version metadata from Civitai

        Args:
            version_id: The Civitai model version ID

        Returns:
            Tuple[Optional[Dict], Optional[str]]: A tuple containing:
                - The model version data or None if not found
                - An error message if there was an error, or None on success
        """
        try:
            session = await self._ensure_fresh_session()
            url = f"{self.base_url}/model-versions/{version_id}"
            headers = self._get_request_headers()

            logger.debug(f"Resolving DNS for model version info: {url}")
            async with session.get(url, headers=headers) as response:
                if response.status == 200:
                    logger.debug(f"Successfully fetched model version info for: {version_id}")
                    return await response.json(), None

                # Handle specific error cases
                if response.status == 404:
                    # Try to parse the error message
                    try:
                        error_data = await response.json()
                        error_msg = error_data.get('error', "Model not found (status 404)")
                        logger.warning(f"Model version not found: {version_id} - {error_msg}")
                        return None, error_msg
                    except Exception:
                        # Not a bare except: task cancellation must propagate.
                        return None, "Model not found (status 404)"

                # Other error cases
                logger.error(f"Failed to fetch model info for {version_id} (status {response.status})")
                return None, f"Failed to fetch model info (status {response.status})"
        except Exception as e:
            error_msg = f"Error fetching model version info: {e}"
            logger.error(error_msg)
            return None, error_msg

    async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
        """Fetch model metadata (description, tags, and creator info) from Civitai API

        Args:
            model_id: The Civitai model ID

        Returns:
            Tuple[Optional[Dict], int]: A tuple containing:
                - A dictionary with model metadata or None if not found
                - The HTTP status code from the request (0 if the request
                  itself raised)
        """
        try:
            session = await self._ensure_fresh_session()
            headers = self._get_request_headers()
            url = f"{self.base_url}/models/{model_id}"

            async with session.get(url, headers=headers) as response:
                status_code = response.status

                if status_code != 200:
                    logger.warning(f"Failed to fetch model metadata: Status {status_code}")
                    return None, status_code

                data = await response.json()

                # "creator" may be present but null in the payload.
                creator = data.get("creator") or {}

                # Extract relevant metadata. The description always carries a
                # fallback string, so the result is never empty.
                metadata = {
                    "description": data.get("description") or "No model description available",
                    "tags": data.get("tags", []),
                    "creator": {
                        "username": creator.get("username"),
                        "image": creator.get("image")
                    }
                }
                return metadata, status_code

        except Exception as e:
            logger.error(f"Error fetching model metadata: {e}", exc_info=True)
            return None, 0

    # Keep old method for backward compatibility, delegating to the new one
    async def get_model_description(self, model_id: str) -> Optional[str]:
        """Fetch the model description from Civitai API (Legacy method)"""
        metadata, _ = await self.get_model_metadata(model_id)
        return metadata.get("description") if metadata else None

    async def close(self):
        """Close the session if it exists"""
        if self._session is not None:
            await self._session.close()
            self._session = None
            self._session_created_at = None

    async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
        """Get the SHA256 hash of a model version from the Civitai API.

        Args:
            model_version_id: The Civitai model version ID

        Returns:
            The lowercase SHA256 of the first listed file that has one, or
            None if the request fails or no file carries a SHA256 hash.
        """
        try:
            session = await self._ensure_fresh_session()
            if not session:
                return None

            url = f"{self.base_url}/model-versions/{model_version_id}"
            # The context manager releases the connection; json() is a
            # coroutine and must be awaited.
            async with session.get(url) as response:
                if response.status != 200:
                    return None
                version_info = await response.json()

            if not version_info:
                return None

            for file_info in version_info.get('files') or []:
                sha256 = (file_info.get('hashes') or {}).get('SHA256')
                if sha256:
                    # Convert hash to lowercase to standardize
                    return sha256.lower()

            return None
        except Exception as e:
            logger.error(f"Error getting hash from Civitai: {e}")
            return None