Files
ComfyUI-Lora-Manager/py/services/civitai_client.py
Will Miao 9ba3e2c204 feat(metadata): implement metadata providers and initialize metadata service
- Added ModelMetadataProvider and CivitaiModelMetadataProvider for handling model metadata.
- Introduced SQLiteModelMetadataProvider for SQLite database integration.
- Created metadata_service.py to initialize and configure metadata providers.
- Updated CivitaiClient to register as a metadata provider.
- Refactored download_manager to use the new download_file method.
- Added SQL schema for models, model_versions, and model_files.
- Updated requirements.txt to include aiosqlite.
2025-09-08 22:41:17 +08:00

519 lines
24 KiB
Python

from datetime import datetime
import aiohttp
import os
import logging
import asyncio
from typing import Optional, Dict, Tuple, List
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
logger = logging.getLogger(__name__)
# Async client for the Civitai REST API (base URL "https://civitai.com/api/v1").
# It is used as a singleton: callers obtain the shared object via get_instance().
class CivitaiClient:
# Cached singleton instance; stays None until get_instance() creates it.
_instance = None
# Held by get_instance() while it checks for / creates the singleton.
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls):
    """Return the shared CivitaiClient, building it on the first call.

    When the instance is created it is also registered with the
    ModelMetadataProviderManager under the name 'civitai'.
    """
    async with cls._lock:
        if cls._instance is not None:
            return cls._instance

        client = cls()
        cls._instance = client

        # Register this client as a metadata provider
        manager = await ModelMetadataProviderManager.get_instance()
        manager.register_provider('civitai', CivitaiModelMetadataProvider(client), True)
        return cls._instance
def __init__(self):
# Check if already initialized for singleton pattern
if hasattr(self, '_initialized'):
return
self._initialized = True
self.base_url = "https://civitai.com/api/v1"
self.headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
}
self._session = None
self._session_created_at = None
# Adjust chunk size based on storage type - consider making this configurable
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better HDD throughput
@property
async def session(self) -> aiohttp.ClientSession:
    """Awaitable property returning the shared aiohttp session.

    The session is built on first access and its creation time is stored in
    ``_session_created_at``; later accesses return the same object.
    """
    if self._session is not None:
        return self._session

    # Up to 8 pooled connections, DNS results cached for 300 s, keep-alive on.
    tcp_connector = aiohttp.TCPConnector(
        ssl=True,
        limit=8,
        ttl_dns_cache=300,
        force_close=False,
        enable_cleanup_closed=True,
    )
    # No total or socket-read limit (large files); 60 s allowed for connecting.
    request_timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=None)

    self._session = aiohttp.ClientSession(
        connector=tcp_connector,
        trust_env=True,  # Allow using system environment proxy settings
        timeout=request_timeout,
    )
    self._session_created_at = datetime.now()
    return self._session
async def _ensure_fresh_session(self):
"""Refresh session if it's been open too long"""
if self._session is not None:
if not hasattr(self, '_session_created_at') or \
(datetime.now() - self._session_created_at).total_seconds() > 300: # 5 minutes
await self.close()
self._session = None
return await self.session
def _get_request_headers(self) -> dict:
    """Build the HTTP headers for a Civitai request.

    The result always carries the User-Agent and a JSON Content-Type; a
    Bearer Authorization header is appended when the 'civitai_api_key'
    setting holds a truthy value.
    """
    request_headers = {
        'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
        'Content-Type': 'application/json'
    }

    # Settings are imported at call time, as in the rest of this method's history.
    from .settings_manager import settings

    token = settings.get('civitai_api_key')
    if not token:
        return request_headers

    request_headers['Authorization'] = f'Bearer {token}'
    return request_headers
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
    """Download a file with resumable transfers and a retry mechanism.

    Data is streamed into ``<save_dir>/<default_filename>.part`` and renamed to
    the final path once the transfer finishes. If a ``.part`` file already
    exists, the download resumes from its size using an HTTP Range request.
    Network errors are retried up to 5 times with exponential backoff.

    Args:
        url: Download URL
        save_dir: Directory to save the file (this method does not create it)
        default_filename: Name of the saved file; response headers are not
            consulted for a filename
        progress_callback: Optional async callback receiving the progress
            percentage; called at most about once per second while streaming
            and once with 100 on success

    Returns:
        Tuple[bool, str]: (success, save_path or error message)
    """
    max_retries = 5
    retry_count = 0
    base_delay = 2.0  # Base delay for exponential backoff

    # Initial setup
    session = await self._ensure_fresh_session()
    save_path = os.path.join(save_dir, default_filename)
    part_path = save_path + '.part'

    # Get existing file size for resume
    resume_offset = 0
    if os.path.exists(part_path):
        resume_offset = os.path.getsize(part_path)
        logger.info(f"Resuming download from offset {resume_offset} bytes")

    total_size = 0

    while retry_count <= max_retries:
        try:
            headers = self._get_request_headers()

            # Add Range header for resume if we have partial data
            if resume_offset > 0:
                headers['Range'] = f'bytes={resume_offset}-'

            # Disable compression so byte offsets match the stored file
            headers['Accept-Encoding'] = 'identity'

            logger.debug(f"Download attempt {retry_count + 1}/{max_retries + 1} from: {url}")
            if resume_offset > 0:
                logger.debug(f"Requesting range from byte {resume_offset}")

            async with session.get(url, headers=headers, allow_redirects=True) as response:
                # Handle different response codes
                if response.status == 200:
                    # Full content response
                    if resume_offset > 0:
                        # Server doesn't support ranges, restart from beginning
                        logger.warning("Server doesn't support range requests, restarting download")
                        resume_offset = 0
                        if os.path.exists(part_path):
                            os.remove(part_path)
                elif response.status == 206:
                    # Partial content response (resume successful)
                    content_range = response.headers.get('Content-Range')
                    if content_range:
                        # Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048").
                        # The total may be "*" when the server does not know it; in that
                        # case fall back to Content-Length below instead of failing.
                        range_parts = content_range.split('/')
                        if len(range_parts) == 2 and range_parts[1].strip().isdigit():
                            total_size = int(range_parts[1])
                    logger.info(f"Successfully resumed download from byte {resume_offset}")
                elif response.status == 416:
                    # Range not satisfiable - file might be complete or corrupted
                    if os.path.exists(part_path):
                        part_size = os.path.getsize(part_path)
                        logger.warning(f"Range not satisfiable. Part file size: {part_size}")

                        # Try to get actual file size. The HEAD response is used as a
                        # context manager so its connection is released, and redirects
                        # are followed just like for the GET above.
                        already_complete = False
                        async with session.head(
                            url,
                            headers=self._get_request_headers(),
                            allow_redirects=True,
                        ) as head_response:
                            if head_response.status == 200:
                                actual_size = int(head_response.headers.get('content-length', 0))
                                already_complete = part_size == actual_size

                        if already_complete:
                            # File is complete, just rename it
                            os.rename(part_path, save_path)
                            if progress_callback:
                                await progress_callback(100)
                            return True, save_path

                        # Remove corrupted part file and restart
                        os.remove(part_path)
                    resume_offset = 0
                    continue
                elif response.status == 401:
                    logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
                    return False, "Invalid or missing CivitAI API key, or early access restriction."
                elif response.status == 403:
                    logger.warning(f"Forbidden access to resource: {url} (Status 403)")
                    return False, "Access forbidden: You don't have permission to download this file."
                else:
                    logger.error(f"Download failed for {url} with status {response.status}")
                    return False, f"Download failed with status {response.status}"

                # Get total file size for progress calculation (if not set from Content-Range)
                if total_size == 0:
                    total_size = int(response.headers.get('content-length', 0))
                    if response.status == 206:
                        # For partial content, add the offset to get total file size
                        total_size += resume_offset

                current_size = resume_offset
                last_progress_report_time = datetime.now()

                # Stream download to file with progress updates using larger buffer
                loop = asyncio.get_running_loop()
                mode = 'ab' if resume_offset > 0 else 'wb'
                with open(part_path, mode) as f:
                    async for chunk in response.content.iter_chunked(self.chunk_size):
                        if chunk:
                            # Run blocking file write in executor
                            await loop.run_in_executor(None, f.write, chunk)
                            current_size += len(chunk)

                            # Limit progress update frequency to reduce overhead
                            now = datetime.now()
                            time_diff = (now - last_progress_report_time).total_seconds()

                            if progress_callback and total_size and time_diff >= 1.0:
                                progress = (current_size / total_size) * 100
                                await progress_callback(progress)
                                last_progress_report_time = now

                # Download completed successfully
                # Verify file size if total_size was provided
                final_size = os.path.getsize(part_path)
                if total_size > 0 and final_size != total_size:
                    logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
                    # Don't treat this as fatal error, rename anyway

                # Rename .part to final file, retrying while the file is still locked
                max_rename_attempts = 5
                rename_attempt = 0
                rename_success = False

                while rename_attempt < max_rename_attempts and not rename_success:
                    try:
                        os.rename(part_path, save_path)
                        rename_success = True
                    except PermissionError as e:
                        rename_attempt += 1
                        if rename_attempt < max_rename_attempts:
                            logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
                            await asyncio.sleep(2)  # Wait before retrying
                        else:
                            logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
                            return False, f"Failed to finalize download: {str(e)}"

                # Ensure 100% progress is reported
                if progress_callback:
                    await progress_callback(100)

                return True, save_path

        except (aiohttp.ClientError, aiohttp.ClientPayloadError,
                aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
            retry_count += 1
            logger.warning(f"Network error during download (attempt {retry_count}/{max_retries + 1}): {e}")

            if retry_count <= max_retries:
                # Calculate delay with exponential backoff
                delay = base_delay * (2 ** (retry_count - 1))
                logger.info(f"Retrying in {delay} seconds...")
                await asyncio.sleep(delay)

                # Update resume offset for next attempt
                if os.path.exists(part_path):
                    resume_offset = os.path.getsize(part_path)
                    logger.info(f"Will resume from byte {resume_offset}")

                # Refresh session to get new connection
                await self.close()
                session = await self._ensure_fresh_session()
                continue
            else:
                logger.error(f"Max retries exceeded for download: {e}")
                return False, f"Network error after {max_retries + 1} attempts: {str(e)}"

        except Exception as e:
            logger.error(f"Unexpected download error: {e}")
            return False, str(e)

    return False, f"Download failed after {max_retries + 1} attempts"
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
    """Look up a model version on Civitai by file hash.

    Returns the decoded JSON body when the API answers 200. Any other status,
    or any exception (which is logged), yields None.
    """
    try:
        http = await self._ensure_fresh_session()
        endpoint = f"{self.base_url}/model-versions/by-hash/{model_hash}"
        async with http.get(endpoint) as response:
            if response.status != 200:
                return None
            return await response.json()
    except Exception as e:
        logger.error(f"API Error: {str(e)}")
        return None
async def download_preview_image(self, image_url: str, save_path: str) -> bool:
    """Download an image and write it to ``save_path``.

    Args:
        image_url: URL of the image to fetch
        save_path: Destination file path; an existing file is overwritten

    Returns:
        bool: True when the server answered 200 and the file was written,
        False for any other status or when an exception occurred (logged).
    """
    try:
        session = await self._ensure_fresh_session()
        async with session.get(image_url) as response:
            if response.status != 200:
                return False
            content = await response.read()

        with open(save_path, 'wb') as f:
            f.write(content)
        return True
    except Exception as e:
        # Report through the module logger like every other method, not print()
        logger.error(f"Download Error: {str(e)}")
        return False
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
    """Fetch the version list and type of a model.

    Args:
        model_id: The Civitai model ID

    Returns:
        Optional[Dict]: ``{'modelVersions': [...], 'type': '...'}`` taken from
        the model response (missing keys default to ``[]`` and ``''``), or
        None when the status is not 200 or an exception occurred (logged).
    """
    try:
        session = await self._ensure_fresh_session()  # Use fresh session
        async with session.get(f"{self.base_url}/models/{model_id}") as response:
            if response.status != 200:
                return None
            data = await response.json()
            # Also return model type along with versions
            return {
                'modelVersions': data.get('modelVersions', []),
                'type': data.get('type', '')
            }
    except Exception as e:
        logger.error(f"Error fetching model versions: {e}")
        return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
    """Get a specific model version enriched with model-level metadata.

    The version JSON gains ``model.description``, ``model.tags`` and
    ``creator`` copied from the parent model's data.

    Args:
        model_id: The Civitai model ID (optional if version_id is provided)
        version_id: Optional specific version ID to retrieve; when omitted,
            the first entry of the model's ``modelVersions`` list is used

    Returns:
        Optional[Dict]: The model version data with additional fields, or None
        if it could not be fetched. When only version_id is given and the
        model lookup fails, the un-enriched version data is returned.
    """
    def enrich(version: Dict, model_data: Dict) -> Dict:
        """Copy description, tags and creator from model_data into version."""
        model_info = version.get('model')
        if not isinstance(model_info, dict):
            # Missing or null 'model': create it rather than dropping the version
            model_info = {}
            version['model'] = model_info
        model_info['description'] = model_data.get("description")
        model_info['tags'] = model_data.get("tags", [])
        version['creator'] = model_data.get("creator")
        return version

    try:
        session = await self._ensure_fresh_session()
        headers = self._get_request_headers()

        # Case 1: Only version_id is provided
        if model_id is None and version_id is not None:
            # First get the version info to extract model_id
            async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response:
                if response.status != 200:
                    return None

                version = await response.json()
                model_id = version.get('modelId')

                if not model_id:
                    logger.error(f"No modelId found in version {version_id}")
                    return None

            # Now get the model data for additional metadata
            async with session.get(f"{self.base_url}/models/{model_id}") as response:
                if response.status != 200:
                    return version  # Return version without additional metadata

                model_data = await response.json()
                return enrich(version, model_data)

        # Case 2: model_id is provided (with or without version_id)
        elif model_id is not None:
            # Step 1: Get model data to find version_id if not provided and get additional metadata
            async with session.get(f"{self.base_url}/models/{model_id}") as response:
                if response.status != 200:
                    return None

                data = await response.json()
                model_versions = data.get('modelVersions', [])

                # Step 2: Determine the version_id to use
                target_version_id = version_id
                if target_version_id is None:
                    if not model_versions:
                        # Nothing to pick a default version from
                        logger.warning(f"No versions found for model {model_id}")
                        return None
                    target_version_id = model_versions[0].get('id')

            # Step 3: Get detailed version info using the version_id
            async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
                if response.status != 200:
                    return None

                version = await response.json()

                # Step 4: Enrich version_info with model data
                return enrich(version, data)

        # Case 3: Neither model_id nor version_id provided
        else:
            logger.error("Either model_id or version_id must be provided")
            return None

    except Exception as e:
        logger.error(f"Error fetching model version: {e}")
        return None
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
    """Fetch model version metadata from Civitai

    Args:
        version_id: The Civitai model version ID

    Returns:
        Tuple[Optional[Dict], Optional[str]]: A tuple containing:
            - The model version data or None if not found
            - An error message if there was an error, or None on success
    """
    try:
        session = await self._ensure_fresh_session()
        url = f"{self.base_url}/model-versions/{version_id}"
        headers = self._get_request_headers()

        logger.debug(f"Resolving DNS for model version info: {url}")
        async with session.get(url, headers=headers) as response:
            if response.status == 200:
                logger.debug(f"Successfully fetched model version info for: {version_id}")
                return await response.json(), None

            # Handle specific error cases
            if response.status == 404:
                # Try to parse the error message
                try:
                    error_data = await response.json()
                    error_msg = error_data.get('error', "Model not found (status 404)")
                    logger.warning(f"Model version not found: {version_id} - {error_msg}")
                    return None, error_msg
                except Exception:
                    # Body was not usable JSON. Only Exception is caught so that
                    # task cancellation and KeyboardInterrupt still propagate.
                    return None, "Model not found (status 404)"

            # Other error cases
            logger.error(f"Failed to fetch model info for {version_id} (status {response.status})")
            return None, f"Failed to fetch model info (status {response.status})"
    except Exception as e:
        error_msg = f"Error fetching model version info: {e}"
        logger.error(error_msg)
        return None, error_msg
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
    """Fetch model metadata (description, tags, and creator info) from Civitai API

    Args:
        model_id: The Civitai model ID

    Returns:
        Tuple[Optional[Dict], int]: A tuple containing:
            - A dictionary with keys 'description', 'tags' and 'creator'
              ('username', 'image'), or None if the request failed
            - The HTTP status code from the request (0 if an exception occurred)
    """
    try:
        session = await self._ensure_fresh_session()
        headers = self._get_request_headers()
        url = f"{self.base_url}/models/{model_id}"

        async with session.get(url, headers=headers) as response:
            status_code = response.status

            if status_code != 200:
                logger.warning(f"Failed to fetch model metadata: Status {status_code}")
                return None, status_code

            data = await response.json()

            # The API may send "creator": null; treat that like a missing creator
            creator = data.get("creator") or {}

            # Extract relevant metadata. The description always falls back to a
            # non-empty placeholder, so the result is never considered empty.
            metadata = {
                "description": data.get("description") or "No model description available",
                "tags": data.get("tags", []),
                "creator": {
                    "username": creator.get("username"),
                    "image": creator.get("image")
                }
            }
            return metadata, status_code

    except Exception as e:
        logger.error(f"Error fetching model metadata: {e}", exc_info=True)
        return None, 0
async def close(self):
    """Close the current session, if any, and clear the reference to it."""
    if self._session is None:
        return
    await self._session.close()
    self._session = None
async def get_image_info(self, image_id: str) -> Optional[Dict]:
    """Fetch image information from Civitai API

    Args:
        image_id: The Civitai image ID

    Returns:
        Optional[Dict]: The first entry of the response's "items" list, or
        None when the list is absent/empty, the status is not 200, or an
        exception occurred (logged).
    """
    try:
        http = await self._ensure_fresh_session()
        request_headers = self._get_request_headers()
        endpoint = f"{self.base_url}/images?imageId={image_id}&nsfw=X"

        logger.debug(f"Fetching image info for ID: {image_id}")
        async with http.get(endpoint, headers=request_headers) as response:
            if response.status != 200:
                logger.error(f"Failed to fetch image info for ID: {image_id} (status {response.status})")
                return None

            payload = await response.json()
            has_items = bool(payload) and "items" in payload and len(payload["items"]) > 0
            if not has_items:
                logger.warning(f"No image found with ID: {image_id}")
                return None

            logger.debug(f"Successfully fetched image info for ID: {image_id}")
            return payload["items"][0]
    except Exception as e:
        logger.error(f"Error fetching image info: {e}")
        return None