mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Additional info: Now prioritizes using the Civitai Images API to fetch image and generation metadata. Even NSFW images can now be imported via URL.
380 lines
16 KiB
Python
380 lines
16 KiB
Python
from datetime import datetime
|
|
import aiohttp
|
|
import os
|
|
import json
|
|
import logging
|
|
import asyncio
|
|
from email.parser import Parser
|
|
from typing import Optional, Dict, Tuple, List
|
|
from urllib.parse import unquote
|
|
from ..utils.models import LoraMetadata
|
|
|
|
# Per-module logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
class CivitaiClient:
    """Async client for the Civitai REST API (v1) and for model file downloads.

    Intended to be shared through :meth:`get_instance`. The client lazily owns a
    single ``aiohttp.ClientSession`` that is recycled once it is older than five
    minutes (see :meth:`_ensure_fresh_session`).
    """

    _instance = None
    # Serialises singleton creation in get_instance().
    _lock = asyncio.Lock()

    @classmethod
    async def get_instance(cls):
        """Get singleton instance of CivitaiClient"""
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance

    def __init__(self):
        # Check if already initialized for singleton pattern
        if hasattr(self, '_initialized'):
            return
        self._initialized = True

        self.base_url = "https://civitai.com/api/v1"
        self.headers = {
            'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
        }
        self._session = None
        self._session_created_at = None
        # Number of streamed downloads currently reading from self._session.
        # While non-zero, _ensure_fresh_session must not close the session.
        self._active_downloads = 0
        # Set default buffer size to 1MB for higher throughput
        self.chunk_size = 1024 * 1024

    @property
    async def session(self) -> "aiohttp.ClientSession":
        """Return the shared HTTP session, creating it when missing or closed.

        This is an *async* property: callers write ``await client.session``.
        """
        if self._session is None or self._session.closed:
            # Optimize TCP connection parameters
            connector = aiohttp.TCPConnector(
                ssl=True,
                limit=3,  # at most 3 simultaneous connections
                ttl_dns_cache=0,  # disable DNS caching
                force_close=False,  # keep connections for reuse
                enable_cleanup_closed=True,
            )
            # No overall deadline (large downloads), but bound connect and per-read waits.
            timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
            self._session = aiohttp.ClientSession(
                connector=connector,
                trust_env=True,  # allow system environment proxy settings
                timeout=timeout,
            )
            self._session_created_at = datetime.now()
        return self._session

    async def _ensure_fresh_session(self):
        """Return a usable session, recycling it once it is older than 5 minutes.

        A closed session is always replaced. An expired session is replaced only
        while no streamed download is in progress: closing the session closes its
        connections and would abort that transfer mid-stream.

        NOTE(review): short API requests are not tracked, so a request running
        concurrently with a recycle can still be interrupted.
        """
        if self._session is not None:
            created_at = getattr(self, '_session_created_at', None)
            expired = (
                created_at is None
                or (datetime.now() - created_at).total_seconds() > 300  # 5 minutes
            )
            downloads_in_flight = getattr(self, '_active_downloads', 0) > 0
            if self._session.closed or (expired and not downloads_in_flight):
                await self.close()

        return await self.session

    @staticmethod
    def _sanitize_filename(filename: str) -> Optional[str]:
        """Reduce a server-supplied name to a bare file name (no directories).

        Prevents ``../`` or absolute paths in a Content-Disposition header from
        escaping the download directory. Returns None if nothing usable remains.
        """
        name = os.path.basename(filename.replace('\\', '/'))
        if name in ('', '.', '..'):
            return None
        return name

    def _parse_content_disposition(self, header: Optional[str]) -> Optional[str]:
        """Extract a safe filename from a Content-Disposition header.

        Handles ``filename="..."`` (percent-decoded), unquoted ``filename=...`` and
        RFC 2231 ``filename*=charset''...`` forms.

        Args:
            header: Raw header value, or None/empty when the response had none.

        Returns:
            A bare file name, or None if the header carries no usable filename.
        """
        if not header:
            return None

        # Handle quoted filenames (only when the closing quote is present)
        marker = 'filename="'
        start = header.find(marker)
        if start != -1:
            start += len(marker)
            end = header.find('"', start)
            if end != -1:
                return self._sanitize_filename(unquote(header[start:end]))

        # Fallback: let the email package parse the parameters. get_param reads
        # Content-Type unless told otherwise, so the header name must be explicit.
        from email.utils import collapse_rfc2231_value

        disposition = Parser().parsestr(f'Content-Disposition: {header}')
        raw = disposition.get_param('filename', header='content-disposition')
        if not raw:
            return None
        if isinstance(raw, tuple):
            # RFC 2231 value -> (charset, language, text); already percent-decoded
            filename = collapse_rfc2231_value(raw)
        else:
            filename = unquote(raw)
        return self._sanitize_filename(filename)

    def _get_request_headers(self) -> dict:
        """Build request headers, adding a Bearer token when a Civitai API key is set."""
        headers = {
            'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
            'Content-Type': 'application/json'
        }

        # Imported lazily; presumably to avoid an import cycle — TODO confirm
        from .settings_manager import settings
        api_key = settings.get('civitai_api_key')
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'

        return headers

    @staticmethod
    def _discard_partial_file(path: str) -> None:
        """Best-effort removal of a file left incomplete by a failed download."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"Could not remove incomplete download {path}: {e}")

    async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
        """Download file with content-disposition support and progress tracking

        The file name comes from the Content-Disposition header, reduced to a
        bare name so it cannot escape ``save_dir``; otherwise ``default_filename``
        is used. If the transfer fails after the file was opened, the incomplete
        file is deleted.

        Args:
            url: Download URL
            save_dir: Directory to save the file
            default_filename: Fallback filename if none provided in headers
            progress_callback: Optional async callback function for progress updates (0-100)

        Returns:
            Tuple[bool, str]: (success, save_path or error message)
        """
        logger.debug(f"Resolving DNS for: {url}")
        session = await self._ensure_fresh_session()
        # Mark the session busy so _ensure_fresh_session won't close it mid-stream.
        self._active_downloads = getattr(self, '_active_downloads', 0) + 1
        partial_path = None  # set while the target file is open but incomplete
        try:
            headers = self._get_request_headers()
            # Disable compression for better chunked downloads
            headers['Accept-Encoding'] = 'identity'

            logger.debug(f"Starting download from: {url}")
            async with session.get(url, headers=headers, allow_redirects=True) as response:
                if response.status != 200:
                    # Handle 401 unauthorized responses
                    if response.status == 401:
                        logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
                        return False, "Invalid or missing CivitAI API key, or early access restriction."

                    # Handle other client errors that might be permission-related
                    if response.status == 403:
                        logger.warning(f"Forbidden access to resource: {url} (Status 403)")
                        return False, "Access forbidden: You don't have permission to download this file."

                    # Generic error response for other status codes
                    logger.error(f"Download failed for {url} with status {response.status}")
                    return False, f"Download failed with status {response.status}"

                content_disposition = response.headers.get('Content-Disposition')
                filename = self._parse_content_disposition(content_disposition) or default_filename
                save_path = os.path.join(save_dir, filename)

                # Get total file size for progress calculation (0 = unknown)
                total_size = int(response.headers.get('content-length', 0))
                current_size = 0
                last_progress_report_time = datetime.now()

                partial_path = save_path
                with open(save_path, 'wb') as f:
                    async for chunk in response.content.iter_chunked(self.chunk_size):
                        if not chunk:
                            continue
                        f.write(chunk)
                        current_size += len(chunk)

                        # Limit progress update frequency to reduce overhead
                        now = datetime.now()
                        time_diff = (now - last_progress_report_time).total_seconds()
                        if progress_callback and total_size and time_diff >= 0.5:
                            await progress_callback((current_size / total_size) * 100)
                            last_progress_report_time = now
                # The file is complete; keep it even if the final callback fails.
                partial_path = None

                # Ensure 100% progress is reported
                if progress_callback:
                    await progress_callback(100)

                return True, save_path

        except aiohttp.ClientError as e:
            logger.error(f"Network error during download: {e}")
            return False, f"Network error: {str(e)}"
        except Exception as e:
            logger.error(f"Download error: {e}")
            return False, str(e)
        finally:
            self._active_downloads -= 1
            if partial_path is not None:
                self._discard_partial_file(partial_path)

    async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
        """Look up a model version by file hash; return its JSON payload or None."""
        try:
            session = await self._ensure_fresh_session()
            url = f"{self.base_url}/model-versions/by-hash/{model_hash}"
            async with session.get(url, headers=self._get_request_headers()) as response:
                if response.status == 200:
                    return await response.json()
                return None
        except Exception as e:
            logger.error(f"API Error: {str(e)}")
            return None

    async def download_preview_image(self, image_url: str, save_path: str):
        """Download an image to ``save_path``; return True on HTTP 200, else False."""
        try:
            session = await self._ensure_fresh_session()
            async with session.get(image_url) as response:
                if response.status == 200:
                    content = await response.read()
                    with open(save_path, 'wb') as f:
                        f.write(content)
                    return True
                return False
        except Exception as e:
            logger.error(f"Download Error: {str(e)}")
            return False

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Get all versions of a model plus the model type.

        Returns:
            ``{'modelVersions': [...], 'type': str}`` on success, otherwise None.
        """
        try:
            session = await self._ensure_fresh_session()
            url = f"{self.base_url}/models/{model_id}"
            async with session.get(url, headers=self._get_request_headers()) as response:
                if response.status != 200:
                    return None
                data = await response.json()
                # Also return model type along with versions
                return {
                    'modelVersions': data.get('modelVersions', []),
                    'type': data.get('type', '')
                }
        except Exception as e:
            logger.error(f"Error fetching model versions: {e}")
            return None

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Fetch model version metadata from Civitai

        Args:
            version_id: The Civitai model version ID

        Returns:
            Tuple[Optional[Dict], Optional[str]]: A tuple containing:
            - The model version data or None if not found
            - An error message if there was an error, or None on success
        """
        try:
            session = await self._ensure_fresh_session()
            url = f"{self.base_url}/model-versions/{version_id}"
            headers = self._get_request_headers()

            logger.debug(f"Resolving DNS for model version info: {url}")
            async with session.get(url, headers=headers) as response:
                if response.status == 200:
                    logger.debug(f"Successfully fetched model version info for: {version_id}")
                    return await response.json(), None

                # Handle specific error cases
                if response.status == 404:
                    # Try to parse the error message
                    try:
                        error_data = await response.json()
                        error_msg = error_data.get('error', "Model not found (status 404)")
                        logger.warning(f"Model version not found: {version_id} - {error_msg}")
                        return None, error_msg
                    except Exception:
                        # Body was not a JSON object; report a generic 404.
                        return None, "Model not found (status 404)"

                # Other error cases
                logger.error(f"Failed to fetch model info for {version_id} (status {response.status})")
                return None, f"Failed to fetch model info (status {response.status})"
        except Exception as e:
            error_msg = f"Error fetching model version info: {e}"
            logger.error(error_msg)
            return None, error_msg

    async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
        """Fetch model metadata (description, tags, and creator info) from Civitai API

        Args:
            model_id: The Civitai model ID

        Returns:
            Tuple[Optional[Dict], int]: A tuple containing:
            - A dictionary with model metadata or None if not found
            - The HTTP status code from the request (0 if the request raised)
        """
        try:
            session = await self._ensure_fresh_session()
            headers = self._get_request_headers()
            url = f"{self.base_url}/models/{model_id}"

            async with session.get(url, headers=headers) as response:
                status_code = response.status

                if status_code != 200:
                    logger.warning(f"Failed to fetch model metadata: Status {status_code}")
                    return None, status_code

                data = await response.json()

                # Use `or` fallbacks so explicit JSON nulls don't break extraction.
                creator = data.get("creator") or {}
                metadata = {
                    "description": data.get("description") or "No model description available",
                    "tags": data.get("tags") or [],
                    "creator": {
                        "username": creator.get("username"),
                        "image": creator.get("image")
                    }
                }

                # NOTE(review): "description" always has a placeholder fallback, so
                # this condition is always true and the else branch is unreachable.
                if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
                    return metadata, status_code
                else:
                    logger.warning(f"No metadata found for model {model_id}")
                    return None, status_code

        except Exception as e:
            logger.error(f"Error fetching model metadata: {e}", exc_info=True)
            return None, 0

    # Keep old method for backward compatibility, delegating to the new one
    async def get_model_description(self, model_id: str) -> Optional[str]:
        """Fetch the model description from Civitai API (Legacy method)"""
        metadata, _ = await self.get_model_metadata(model_id)
        return metadata.get("description") if metadata else None

    async def close(self):
        """Close the session if it exists"""
        if self._session is not None:
            await self._session.close()
            self._session = None
            self._session_created_at = None

    async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
        """Return the lower-cased SHA256 of the first file of a model version that has one.

        Returns None when the request fails, the version has no files, or no
        file lists a SHA256 hash.
        """
        try:
            session = await self._ensure_fresh_session()
            if not session:
                return None

            url = f"{self.base_url}/model-versions/{model_version_id}"
            async with session.get(url, headers=self._get_request_headers()) as response:
                if response.status != 200:
                    return None
                version_info = await response.json()

            # First file that carries a SHA256 wins
            for file_info in (version_info or {}).get('files') or []:
                sha256 = (file_info.get('hashes') or {}).get('SHA256')
                if sha256:
                    # Convert hash to lowercase to standardize
                    return sha256.lower()

            return None
        except Exception as e:
            logger.error(f"Error getting hash from Civitai: {e}")
            return None

    async def get_image_info(self, image_id: str) -> Optional[Dict]:
        """Fetch image information from Civitai API

        Requests ``nsfw=X`` so images at any NSFW level can be resolved.

        Args:
            image_id: The Civitai image ID

        Returns:
            Optional[Dict]: The image data or None if not found
        """
        try:
            session = await self._ensure_fresh_session()
            headers = self._get_request_headers()
            url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"

            logger.debug(f"Fetching image info for ID: {image_id}")
            async with session.get(url, headers=headers) as response:
                if response.status == 200:
                    data = await response.json()
                    if data and "items" in data and len(data["items"]) > 0:
                        logger.debug(f"Successfully fetched image info for ID: {image_id}")
                        return data["items"][0]
                    logger.warning(f"No image found with ID: {image_id}")
                    return None

                logger.error(f"Failed to fetch image info for ID: {image_id} (status {response.status})")
                return None
        except Exception as e:
            error_msg = f"Error fetching image info: {e}"
            logger.error(error_msg)
            return None