from datetime import datetime
import aiohttp
import os
import logging
import asyncio
from email.parser import Parser
from typing import Optional, Dict, Tuple, List
from urllib.parse import unquote

logger = logging.getLogger(__name__)


class CivitaiClient:
    """Async client for the Civitai REST API (v1) and for model file downloads.

    Obtain the shared instance through :meth:`get_instance`. The instance owns
    one lazily created ``aiohttp.ClientSession`` that :meth:`_ensure_fresh_session`
    closes and recreates once it is older than five minutes.
    """

    _instance = None
    # NOTE(review): created at class-definition time. On Python < 3.10 an
    # asyncio.Lock binds to the event loop current at import — confirm the
    # runtime is 3.10+ or create the lock lazily.
    _lock = asyncio.Lock()

    @classmethod
    async def get_instance(cls):
        """Get singleton instance of CivitaiClient"""
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    def __init__(self):
        # Check if already initialized for singleton pattern
        if hasattr(self, '_initialized'):
            return
        self._initialized = True

        self.base_url = "https://civitai.com/api/v1"
        self.headers = {
            'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
        }
        self._session = None
        self._session_created_at = None
        # Adjust chunk size based on storage type - consider making this configurable
        self.chunk_size = 4 * 1024 * 1024  # 4MB chunks for better HDD throughput

    @property
    async def session(self) -> aiohttp.ClientSession:
        """Lazily create and return the shared aiohttp session.

        This is an *async* property: callers must ``await self.session``.
        """
        if self._session is None:
            # Optimize TCP connection parameters
            connector = aiohttp.TCPConnector(
                ssl=True,
                limit=8,  # Increase from 3 to 8 for better parallelism
                ttl_dns_cache=300,  # Enable DNS caching with reasonable timeout
                force_close=False,  # Keep connections for reuse
                enable_cleanup_closed=True
            )
            trust_env = True  # Allow using system environment proxy settings

            # No total/sock_read timeout so large downloads are not cut off;
            # only connection establishment is bounded.
            timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=None)

            self._session = aiohttp.ClientSession(
                connector=connector,
                trust_env=trust_env,
                timeout=timeout
            )
            self._session_created_at = datetime.now()
        return self._session

    async def _ensure_fresh_session(self):
        """Return the shared session, recreating it if it is older than 5 minutes.

        NOTE(review): the session is shared by every caller of this singleton,
        so recycling it here also closes it under any download still streaming
        on it. That download then goes through its own retry/resume path.
        """
        if self._session is not None:
            if self._session_created_at is None or \
               (datetime.now() - self._session_created_at).total_seconds() > 300:  # 5 minutes
                await self.close()
        return await self.session

    def _parse_content_disposition(self, header: str) -> Optional[str]:
        """Extract a filename from a Content-Disposition header value.

        Args:
            header: Raw header value, e.g. ``attachment; filename="x.safetensors"``.

        Returns:
            The percent-decoded filename, or None if none can be found.
        """
        if not header:
            return None

        # Fast path for a plain quoted filename
        marker = 'filename="'
        if marker in header:
            start = header.index(marker) + len(marker)
            end = header.find('"', start)
            if end != -1:
                return unquote(header[start:end])

        # Fallback: let the email parser handle unquoted and RFC 2231
        # (filename*=) forms. get_filename() reads Content-Disposition;
        # get_param() would default to the Content-Type header instead.
        message = Parser().parsestr(f'Content-Disposition: {header}')
        filename = message.get_filename()
        if filename:
            return unquote(filename)
        return None

    def _get_request_headers(self) -> dict:
        """Get request headers with optional API key"""
        headers = {
            'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
            'Content-Type': 'application/json'
        }

        from .settings_manager import settings
        api_key = settings.get('civitai_api_key')
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        return headers

    async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
        """Download file with resumable downloads and retry mechanism

        Data is streamed into ``<save_path>.part`` and renamed to the final
        path on completion. An existing ``.part`` file is resumed via an HTTP
        Range request.

        Args:
            url: Download URL
            save_dir: Directory to save the file
            default_filename: Name of the saved file inside save_dir
            progress_callback: Optional async callback function for progress updates (0-100)

        Returns:
            Tuple[bool, str]: (success, save_path or error message)
        """
        max_retries = 5
        retry_count = 0
        base_delay = 2.0  # Base delay for exponential backoff

        # Initial setup
        session = await self._ensure_fresh_session()
        save_path = os.path.join(save_dir, default_filename)
        part_path = save_path + '.part'

        # Get existing file size for resume
        resume_offset = 0
        if os.path.exists(part_path):
            resume_offset = os.path.getsize(part_path)
            logger.info(f"Resuming download from offset {resume_offset} bytes")

        total_size = 0

        while retry_count <= max_retries:
            try:
                headers = self._get_request_headers()

                # Add Range header for resume if we have partial data
                if resume_offset > 0:
                    headers['Range'] = f'bytes={resume_offset}-'

                # Add Range header to allow resumable downloads
                headers['Accept-Encoding'] = 'identity'  # Disable compression for better chunked downloads

                logger.debug(f"Download attempt {retry_count + 1}/{max_retries + 1} from: {url}")
                if resume_offset > 0:
                    logger.debug(f"Requesting range from byte {resume_offset}")

                async with session.get(url, headers=headers, allow_redirects=True) as response:
                    # Handle different response codes
                    if response.status == 200:
                        # Full content response
                        if resume_offset > 0:
                            # Server doesn't support ranges, restart from beginning
                            logger.warning("Server doesn't support range requests, restarting download")
                            resume_offset = 0
                            if os.path.exists(part_path):
                                os.remove(part_path)
                    elif response.status == 206:
                        # Partial content response (resume successful)
                        content_range = response.headers.get('Content-Range')
                        if content_range:
                            # Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048").
                            # The total may be "*" (unknown); keep the fallback in that case.
                            range_parts = content_range.split('/')
                            if len(range_parts) == 2 and range_parts[1].strip().isdigit():
                                total_size = int(range_parts[1])
                        logger.info(f"Successfully resumed download from byte {resume_offset}")
                    elif response.status == 416:
                        # Range not satisfiable - file might be complete or corrupted
                        if os.path.exists(part_path):
                            part_size = os.path.getsize(part_path)
                            logger.warning(f"Range not satisfiable. Part file size: {part_size}")

                            # Try to get actual file size. aiohttp's HEAD does not follow
                            # redirects by default, and the response must be released.
                            actual_size = None
                            async with session.head(url, headers=self._get_request_headers(),
                                                    allow_redirects=True) as head_response:
                                if head_response.status == 200:
                                    actual_size = int(head_response.headers.get('content-length', 0))

                            if actual_size is not None and part_size == actual_size:
                                # File is complete, just move it into place
                                os.replace(part_path, save_path)
                                if progress_callback:
                                    await progress_callback(100)
                                return True, save_path

                            # Remove corrupted part file and restart
                            os.remove(part_path)
                        resume_offset = 0
                        continue
                    elif response.status == 401:
                        logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
                        return False, "Invalid or missing CivitAI API key, or early access restriction."
                    elif response.status == 403:
                        logger.warning(f"Forbidden access to resource: {url} (Status 403)")
                        return False, "Access forbidden: You don't have permission to download this file."
                    else:
                        logger.error(f"Download failed for {url} with status {response.status}")
                        return False, f"Download failed with status {response.status}"

                    # Get total file size for progress calculation (if not set from Content-Range)
                    if total_size == 0:
                        total_size = int(response.headers.get('content-length', 0))
                        if response.status == 206:
                            # For partial content, add the offset to get total file size
                            total_size += resume_offset

                    current_size = resume_offset
                    last_progress_report_time = datetime.now()

                    # Stream download to file with progress updates using larger buffer
                    loop = asyncio.get_running_loop()
                    mode = 'ab' if resume_offset > 0 else 'wb'
                    with open(part_path, mode) as f:
                        async for chunk in response.content.iter_chunked(self.chunk_size):
                            if chunk:
                                # Run blocking file write in executor
                                await loop.run_in_executor(None, f.write, chunk)
                                current_size += len(chunk)

                                # Limit progress update frequency to reduce overhead
                                now = datetime.now()
                                time_diff = (now - last_progress_report_time).total_seconds()

                                if progress_callback and total_size and time_diff >= 1.0:
                                    progress = (current_size / total_size) * 100
                                    await progress_callback(progress)
                                    last_progress_report_time = now

                # Download completed successfully
                # Verify file size if total_size was provided
                final_size = os.path.getsize(part_path)
                if total_size > 0 and final_size != total_size:
                    logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
                    # Don't treat this as fatal error, rename anyway

                # Atomically move .part to the final file with retries.
                # os.replace (unlike os.rename) also overwrites an existing
                # target on Windows.
                max_rename_attempts = 5
                rename_attempt = 0
                rename_success = False

                while rename_attempt < max_rename_attempts and not rename_success:
                    try:
                        os.replace(part_path, save_path)
                        rename_success = True
                    except PermissionError as e:
                        rename_attempt += 1
                        if rename_attempt < max_rename_attempts:
                            logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
                            await asyncio.sleep(2)  # Wait before retrying
                        else:
                            logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
                            return False, f"Failed to finalize download: {str(e)}"

                # Ensure 100% progress is reported
                if progress_callback:
                    await progress_callback(100)

                return True, save_path

            except (aiohttp.ClientError, aiohttp.ClientPayloadError,
                    aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
                retry_count += 1
                logger.warning(f"Network error during download (attempt {retry_count}/{max_retries + 1}): {e}")

                if retry_count <= max_retries:
                    # Calculate delay with exponential backoff
                    delay = base_delay * (2 ** (retry_count - 1))
                    logger.info(f"Retrying in {delay} seconds...")
                    await asyncio.sleep(delay)

                    # Update resume offset for next attempt
                    if os.path.exists(part_path):
                        resume_offset = os.path.getsize(part_path)
                        logger.info(f"Will resume from byte {resume_offset}")

                    # Refresh session to get new connection
                    await self.close()
                    session = await self._ensure_fresh_session()
                    continue
                else:
                    logger.error(f"Max retries exceeded for download: {e}")
                    return False, f"Network error after {max_retries + 1} attempts: {str(e)}"

            except Exception as e:
                logger.error(f"Unexpected download error: {e}")
                return False, str(e)

        return False, f"Download failed after {max_retries + 1} attempts"

    async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
        """Look up a model version by file hash.

        Returns:
            The parsed JSON response on HTTP 200, otherwise None (errors are logged).
        """
        try:
            session = await self._ensure_fresh_session()
            async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
                if response.status == 200:
                    return await response.json()
                return None
        except Exception as e:
            logger.error(f"API Error: {str(e)}")
            return None

    async def download_preview_image(self, image_url: str, save_path: str):
        """Download an image into save_path.

        Returns:
            True when the server answered 200 and the file was written, else False.
        """
        try:
            session = await self._ensure_fresh_session()
            async with session.get(image_url) as response:
                if response.status == 200:
                    content = await response.read()
                    with open(save_path, 'wb') as f:
                        f.write(content)
                    return True
                return False
        except Exception as e:
            logger.error(f"Download Error: {str(e)}")
            return False

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Fetch a model's version list and type.

        Returns:
            ``{'modelVersions': [...], 'type': str}`` on success, otherwise None.
        """
        try:
            session = await self._ensure_fresh_session()  # Use fresh session
            async with session.get(f"{self.base_url}/models/{model_id}") as response:
                if response.status != 200:
                    return None
                data = await response.json()
                # Also return model type along with versions
                return {
                    'modelVersions': data.get('modelVersions', []),
                    'type': data.get('type', '')
                }
        except Exception as e:
            logger.error(f"Error fetching model versions: {e}")
            return None

    @staticmethod
    def _enrich_version(version: Dict, model_data: Dict) -> None:
        """Copy description, tags and creator from model data into a version dict (in place)."""
        model_info = version.get('model') or {}
        version['model'] = model_info
        model_info['description'] = model_data.get("description")
        model_info['tags'] = model_data.get("tags", [])
        version['creator'] = model_data.get("creator")

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Get specific model version with additional metadata

        Args:
            model_id: The Civitai model ID (optional if version_id is provided)
            version_id: Optional specific version ID to retrieve; when omitted,
                the first entry of the model's ``modelVersions`` is used

        Returns:
            Optional[Dict]: The model version data with additional fields or None if not found
        """
        try:
            session = await self._ensure_fresh_session()
            headers = self._get_request_headers()

            # Case 1: Only version_id is provided
            if model_id is None and version_id is not None:
                # First get the version info to extract model_id
                async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response:
                    if response.status != 200:
                        return None
                    version = await response.json()

                model_id = version.get('modelId')
                if not model_id:
                    logger.error(f"No modelId found in version {version_id}")
                    return None

                # Now get the model data for additional metadata
                async with session.get(f"{self.base_url}/models/{model_id}", headers=headers) as response:
                    if response.status != 200:
                        return version  # Return version without additional metadata
                    model_data = await response.json()

                self._enrich_version(version, model_data)
                return version

            # Case 2: model_id is provided (with or without version_id)
            elif model_id is not None:
                # Step 1: Get model data to find version_id if not provided and get additional metadata
                async with session.get(f"{self.base_url}/models/{model_id}", headers=headers) as response:
                    if response.status != 200:
                        return None
                    data = await response.json()

                model_versions = data.get('modelVersions') or []

                # Step 2: Determine the version_id to use
                target_version_id = version_id
                if target_version_id is None:
                    if not model_versions:
                        logger.warning(f"No versions found for model {model_id}")
                        return None
                    target_version_id = model_versions[0].get('id')

                # Step 3: Get detailed version info using the version_id
                async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
                    if response.status != 200:
                        return None
                    version = await response.json()

                # Step 4: Enrich version_info with description, tags and creator
                self._enrich_version(version, data)
                return version

            # Case 3: Neither model_id nor version_id provided
            else:
                logger.error("Either model_id or version_id must be provided")
                return None

        except Exception as e:
            logger.error(f"Error fetching model version: {e}")
            return None

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Fetch model version metadata from Civitai

        Args:
            version_id: The Civitai model version ID

        Returns:
            Tuple[Optional[Dict], Optional[str]]: A tuple containing:
                - The model version data or None if not found
                - An error message if there was an error, or None on success
        """
        try:
            session = await self._ensure_fresh_session()
            url = f"{self.base_url}/model-versions/{version_id}"
            headers = self._get_request_headers()

            logger.debug(f"Resolving DNS for model version info: {url}")
            async with session.get(url, headers=headers) as response:
                if response.status == 200:
                    logger.debug(f"Successfully fetched model version info for: {version_id}")
                    return await response.json(), None

                # Handle specific error cases
                if response.status == 404:
                    # Try to parse the error message; the body may not be JSON
                    try:
                        error_data = await response.json()
                        error_msg = error_data.get('error', f"Model not found (status 404)")
                        logger.warning(f"Model version not found: {version_id} - {error_msg}")
                        return None, error_msg
                    except Exception:
                        return None, "Model not found (status 404)"

                # Other error cases
                logger.error(f"Failed to fetch model info for {version_id} (status {response.status})")
                return None, f"Failed to fetch model info (status {response.status})"
        except Exception as e:
            error_msg = f"Error fetching model version info: {e}"
            logger.error(error_msg)
            return None, error_msg

    async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
        """Fetch model metadata (description, tags, and creator info) from Civitai API

        Args:
            model_id: The Civitai model ID

        Returns:
            Tuple[Optional[Dict], int]: A tuple containing:
                - A dictionary with model metadata or None if not found
                - The HTTP status code from the request (0 if an exception occurred)
        """
        try:
            session = await self._ensure_fresh_session()
            headers = self._get_request_headers()
            url = f"{self.base_url}/models/{model_id}"

            async with session.get(url, headers=headers) as response:
                status_code = response.status

                if status_code != 200:
                    logger.warning(f"Failed to fetch model metadata: Status {status_code}")
                    return None, status_code

                data = await response.json()

            # "creator" may be absent or explicitly null
            creator = data.get("creator") or {}

            # Extract relevant metadata. The description fallback guarantees a
            # non-empty value, so a metadata dict is always returned on HTTP 200.
            metadata = {
                "description": data.get("description") or "No model description available",
                "tags": data.get("tags", []),
                "creator": {
                    "username": creator.get("username"),
                    "image": creator.get("image")
                }
            }
            return metadata, status_code

        except Exception as e:
            logger.error(f"Error fetching model metadata: {e}", exc_info=True)
            return None, 0

    # Keep old method for backward compatibility, delegating to the new one
    async def get_model_description(self, model_id: str) -> Optional[str]:
        """Fetch the model description from Civitai API (Legacy method)"""
        metadata, _ = await self.get_model_metadata(model_id)
        return metadata.get("description") if metadata else None

    async def close(self):
        """Close the session if it exists"""
        if self._session is not None:
            await self._session.close()
            self._session = None
            self._session_created_at = None

    async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
        """Return the lowercase SHA256 of the first file in a model version that has one.

        Returns:
            The hash string, or None if the request fails or no file carries a SHA256.
        """
        try:
            session = await self._ensure_fresh_session()
            if not session:
                return None

            url = f"{self.base_url}/model-versions/{model_version_id}"
            async with session.get(url, headers=self._get_request_headers()) as response:
                if response.status != 200:
                    return None
                # response.json() is a coroutine and must be awaited
                version_info = await response.json()

            if not version_info:
                return None

            # Get hash from the first file that has one
            for file_info in version_info.get('files') or []:
                sha256 = (file_info.get('hashes') or {}).get('SHA256')
                if sha256:
                    # Convert hash to lowercase to standardize
                    return sha256.lower()
            return None
        except Exception as e:
            logger.error(f"Error getting hash from Civitai: {e}")
            return None

    async def get_image_info(self, image_id: str) -> Optional[Dict]:
        """Fetch image information from Civitai API

        Args:
            image_id: The Civitai image ID

        Returns:
            Optional[Dict]: The first item of the response's "items" list, or None if not found
        """
        try:
            session = await self._ensure_fresh_session()
            headers = self._get_request_headers()
            url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"

            logger.debug(f"Fetching image info for ID: {image_id}")
            async with session.get(url, headers=headers) as response:
                if response.status == 200:
                    data = await response.json()
                    if data and "items" in data and len(data["items"]) > 0:
                        logger.debug(f"Successfully fetched image info for ID: {image_id}")
                        return data["items"][0]

                    logger.warning(f"No image found with ID: {image_id}")
                    return None

                logger.error(f"Failed to fetch image info for ID: {image_id} (status {response.status})")
                return None
        except Exception as e:
            error_msg = f"Error fetching image info: {e}"
            logger.error(error_msg)
            return None