# Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git
# (synced 2026-03-22 05:32:12 -03:00; 531 lines, 19 KiB, Python)
import asyncio
import copy
import logging
import os
from typing import Any, Optional, Dict, Tuple, List, Sequence
from urllib.parse import quote

from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
from .downloader import get_downloader
from .errors import RateLimitError, ResourceNotFoundError
from ..utils.civitai_utils import resolve_license_payload


# Module-level logger, named after this module (``__name__``) so its records
# can be filtered separately from other components.
logger = logging.getLogger(name=__name__)


class CivitaiClient:
|
|
_instance = None
|
|
_lock = asyncio.Lock()
|
|
|
|
@classmethod
|
|
async def get_instance(cls):
|
|
"""Get singleton instance of CivitaiClient"""
|
|
async with cls._lock:
|
|
if cls._instance is None:
|
|
cls._instance = cls()
|
|
|
|
# Register this client as a metadata provider
|
|
provider_manager = await ModelMetadataProviderManager.get_instance()
|
|
provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True)
|
|
|
|
return cls._instance
|
|
|
|
def __init__(self):
|
|
# Check if already initialized for singleton pattern
|
|
if hasattr(self, '_initialized'):
|
|
return
|
|
self._initialized = True
|
|
|
|
self.base_url = "https://civitai.com/api/v1"
|
|
|
|
async def _make_request(
|
|
self,
|
|
method: str,
|
|
url: str,
|
|
*,
|
|
use_auth: bool = False,
|
|
**kwargs,
|
|
) -> Tuple[bool, Dict | str]:
|
|
"""Wrapper around downloader.make_request that surfaces rate limits."""
|
|
|
|
downloader = await get_downloader()
|
|
success, result = await downloader.make_request(
|
|
method,
|
|
url,
|
|
use_auth=use_auth,
|
|
**kwargs,
|
|
)
|
|
if not success and isinstance(result, RateLimitError):
|
|
if result.provider is None:
|
|
result.provider = "civitai_api"
|
|
raise result
|
|
return success, result
|
|
|
|
@staticmethod
|
|
def _remove_comfy_metadata(model_version: Optional[Dict]) -> None:
|
|
"""Remove Comfy-specific metadata from model version images."""
|
|
if not isinstance(model_version, dict):
|
|
return
|
|
|
|
images = model_version.get("images")
|
|
if not isinstance(images, list):
|
|
return
|
|
|
|
for image in images:
|
|
if not isinstance(image, dict):
|
|
continue
|
|
|
|
meta = image.get("meta")
|
|
if isinstance(meta, dict) and "comfy" in meta:
|
|
meta.pop("comfy", None)
|
|
|
|
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
|
"""Download file with resumable downloads and retry mechanism
|
|
|
|
Args:
|
|
url: Download URL
|
|
save_dir: Directory to save the file
|
|
default_filename: Fallback filename if none provided in headers
|
|
progress_callback: Optional async callback function for progress updates (0-100)
|
|
|
|
Returns:
|
|
Tuple[bool, str]: (success, save_path or error message)
|
|
"""
|
|
downloader = await get_downloader()
|
|
save_path = os.path.join(save_dir, default_filename)
|
|
|
|
# Use unified downloader with CivitAI authentication
|
|
success, result = await downloader.download_file(
|
|
url=url,
|
|
save_path=save_path,
|
|
progress_callback=progress_callback,
|
|
use_auth=True, # Enable CivitAI authentication
|
|
allow_resume=True
|
|
)
|
|
|
|
return success, result
|
|
|
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
|
try:
|
|
success, version = await self._make_request(
|
|
'GET',
|
|
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
|
use_auth=True
|
|
)
|
|
if not success:
|
|
message = str(version)
|
|
if "not found" in message.lower():
|
|
return None, "Model not found"
|
|
|
|
logger.error("Failed to fetch model info for %s: %s", model_hash[:10], message)
|
|
return None, message
|
|
|
|
model_id = version.get('modelId')
|
|
if model_id:
|
|
model_data = await self._fetch_model_data(model_id)
|
|
if model_data:
|
|
self._enrich_version_with_model_data(version, model_data)
|
|
|
|
self._remove_comfy_metadata(version)
|
|
return version, None
|
|
except RateLimitError:
|
|
raise
|
|
except Exception as exc:
|
|
logger.error("API Error: %s", exc)
|
|
return None, str(exc)
|
|
|
|
async def download_preview_image(self, image_url: str, save_path: str):
|
|
try:
|
|
downloader = await get_downloader()
|
|
success, content, headers = await downloader.download_to_memory(
|
|
image_url,
|
|
use_auth=False # Preview images don't need auth
|
|
)
|
|
if success:
|
|
# Ensure directory exists
|
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
with open(save_path, 'wb') as f:
|
|
f.write(content)
|
|
return True
|
|
return False
|
|
except Exception as e:
|
|
logger.error(f"Download Error: {str(e)}")
|
|
return False
|
|
|
|
@staticmethod
|
|
def _extract_error_message(payload: Any) -> str:
|
|
"""Return a human-readable error message from an API payload."""
|
|
|
|
def _from_value(value: Any) -> str:
|
|
if isinstance(value, str):
|
|
return value
|
|
if isinstance(value, dict):
|
|
for key in ("message", "error", "detail", "details"):
|
|
if key in value:
|
|
candidate = _from_value(value[key])
|
|
if candidate:
|
|
return candidate
|
|
if isinstance(value, list):
|
|
for item in value:
|
|
candidate = _from_value(item)
|
|
if candidate:
|
|
return candidate
|
|
return ""
|
|
|
|
return _from_value(payload)
|
|
|
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
|
"""Get all versions of a model with local availability info"""
|
|
try:
|
|
success, result = await self._make_request(
|
|
'GET',
|
|
f"{self.base_url}/models/{model_id}",
|
|
use_auth=True
|
|
)
|
|
if success:
|
|
# Also return model type along with versions
|
|
return {
|
|
'modelVersions': result.get('modelVersions', []),
|
|
'type': result.get('type', ''),
|
|
'name': result.get('name', '')
|
|
}
|
|
message = self._extract_error_message(result)
|
|
if message and 'not found' in message.lower():
|
|
raise ResourceNotFoundError(f"Resource not found for model {model_id}")
|
|
if message:
|
|
raise RuntimeError(message)
|
|
return None
|
|
except RateLimitError:
|
|
raise
|
|
except ResourceNotFoundError as exc:
|
|
logger.info("Model %s is no longer available on Civitai: %s", model_id, exc)
|
|
raise
|
|
except Exception as e:
|
|
logger.error("Error fetching model versions: %s", e, exc_info=True)
|
|
raise
|
|
|
|
async def get_model_versions_bulk(
|
|
self, model_ids: Sequence[int]
|
|
) -> Optional[Dict[int, Dict]]:
|
|
"""Fetch model metadata for multiple ids using the batch API."""
|
|
|
|
deduped: Dict[int, None] = {}
|
|
for raw_id in model_ids:
|
|
try:
|
|
normalized = int(raw_id)
|
|
except (TypeError, ValueError):
|
|
continue
|
|
deduped.setdefault(normalized, None)
|
|
|
|
normalized_ids = [str(model_id) for model_id in deduped.keys()]
|
|
if not normalized_ids:
|
|
return {}
|
|
|
|
try:
|
|
query = ",".join(normalized_ids)
|
|
success, result = await self._make_request(
|
|
'GET',
|
|
f"{self.base_url}/models",
|
|
use_auth=True,
|
|
params={'ids': query},
|
|
)
|
|
if not success:
|
|
return None
|
|
|
|
items = result.get('items') if isinstance(result, dict) else None
|
|
if not isinstance(items, list):
|
|
return {}
|
|
|
|
payload: Dict[int, Dict] = {}
|
|
for item in items:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
model_id = item.get('id')
|
|
try:
|
|
normalized_id = int(model_id)
|
|
except (TypeError, ValueError):
|
|
continue
|
|
payload[normalized_id] = {
|
|
'modelVersions': item.get('modelVersions', []),
|
|
'type': item.get('type', ''),
|
|
'name': item.get('name', ''),
|
|
'allowNoCredit': item.get('allowNoCredit'),
|
|
'allowCommercialUse': item.get('allowCommercialUse'),
|
|
'allowDerivatives': item.get('allowDerivatives'),
|
|
'allowDifferentLicense': item.get('allowDifferentLicense'),
|
|
}
|
|
return payload
|
|
except RateLimitError:
|
|
raise
|
|
except Exception as exc:
|
|
logger.error(f"Error fetching model versions in bulk: {exc}")
|
|
return None
|
|
|
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
|
"""Get specific model version with additional metadata."""
|
|
try:
|
|
if model_id is None and version_id is not None:
|
|
return await self._get_version_by_id_only(version_id)
|
|
|
|
if model_id is not None:
|
|
return await self._get_version_with_model_id(model_id, version_id)
|
|
|
|
logger.error("Either model_id or version_id must be provided")
|
|
return None
|
|
|
|
except RateLimitError:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error fetching model version: {e}")
|
|
return None
|
|
|
|
async def _get_version_by_id_only(self, version_id: int) -> Optional[Dict]:
|
|
version = await self._fetch_version_by_id(version_id)
|
|
if version is None:
|
|
return None
|
|
|
|
model_id = version.get('modelId')
|
|
if not model_id:
|
|
logger.error(f"No modelId found in version {version_id}")
|
|
return None
|
|
|
|
model_data = await self._fetch_model_data(model_id)
|
|
if model_data:
|
|
self._enrich_version_with_model_data(version, model_data)
|
|
|
|
self._remove_comfy_metadata(version)
|
|
return version
|
|
|
|
async def _get_version_with_model_id(self, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
|
model_data = await self._fetch_model_data(model_id)
|
|
if not model_data:
|
|
return None
|
|
|
|
target_version = self._select_target_version(model_data, model_id, version_id)
|
|
if target_version is None:
|
|
return None
|
|
|
|
target_version_id = target_version.get('id')
|
|
version = await self._fetch_version_by_id(target_version_id) if target_version_id else None
|
|
|
|
if version is None:
|
|
model_hash = self._extract_primary_model_hash(target_version)
|
|
if model_hash:
|
|
version = await self._fetch_version_by_hash(model_hash)
|
|
else:
|
|
logger.warning(
|
|
f"No primary model hash found for model {model_id} version {target_version_id}"
|
|
)
|
|
|
|
if version is None:
|
|
version = self._build_version_from_model_data(target_version, model_id, model_data)
|
|
|
|
self._enrich_version_with_model_data(version, model_data)
|
|
self._remove_comfy_metadata(version)
|
|
return version
|
|
|
|
async def _fetch_model_data(self, model_id: int) -> Optional[Dict]:
|
|
success, data = await self._make_request(
|
|
'GET',
|
|
f"{self.base_url}/models/{model_id}",
|
|
use_auth=True
|
|
)
|
|
if success:
|
|
return data
|
|
logger.warning(f"Failed to fetch model data for model {model_id}")
|
|
return None
|
|
|
|
async def _fetch_version_by_id(self, version_id: Optional[int]) -> Optional[Dict]:
|
|
if version_id is None:
|
|
return None
|
|
|
|
success, version = await self._make_request(
|
|
'GET',
|
|
f"{self.base_url}/model-versions/{version_id}",
|
|
use_auth=True
|
|
)
|
|
if success:
|
|
return version
|
|
|
|
logger.warning(f"Failed to fetch version by id {version_id}")
|
|
return None
|
|
|
|
async def _fetch_version_by_hash(self, model_hash: Optional[str]) -> Optional[Dict]:
|
|
if not model_hash:
|
|
return None
|
|
|
|
success, version = await self._make_request(
|
|
'GET',
|
|
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
|
use_auth=True
|
|
)
|
|
if success:
|
|
return version
|
|
|
|
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
|
return None
|
|
|
|
def _select_target_version(self, model_data: Dict, model_id: int, version_id: Optional[int]) -> Optional[Dict]:
|
|
model_versions = model_data.get('modelVersions', [])
|
|
if not model_versions:
|
|
logger.warning(f"No model versions found for model {model_id}")
|
|
return None
|
|
|
|
if version_id is not None:
|
|
target_version = next(
|
|
(item for item in model_versions if item.get('id') == version_id),
|
|
None
|
|
)
|
|
if target_version is None:
|
|
logger.warning(
|
|
f"Version {version_id} not found for model {model_id}, defaulting to first version"
|
|
)
|
|
return model_versions[0]
|
|
return target_version
|
|
|
|
return model_versions[0]
|
|
|
|
def _extract_primary_model_hash(self, version_entry: Dict) -> Optional[str]:
|
|
for file_info in version_entry.get('files', []):
|
|
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
|
hashes = file_info.get('hashes', {})
|
|
model_hash = hashes.get('SHA256')
|
|
if model_hash:
|
|
return model_hash
|
|
return None
|
|
|
|
def _build_version_from_model_data(self, version_entry: Dict, model_id: int, model_data: Dict) -> Dict:
|
|
version = copy.deepcopy(version_entry)
|
|
version.pop('index', None)
|
|
version['modelId'] = model_id
|
|
version['model'] = {
|
|
'name': model_data.get('name'),
|
|
'type': model_data.get('type'),
|
|
'nsfw': model_data.get('nsfw'),
|
|
'poi': model_data.get('poi')
|
|
}
|
|
return version
|
|
|
|
def _enrich_version_with_model_data(self, version: Dict, model_data: Dict) -> None:
|
|
model_info = version.get('model')
|
|
if not isinstance(model_info, dict):
|
|
model_info = {}
|
|
version['model'] = model_info
|
|
|
|
model_info['description'] = model_data.get("description")
|
|
model_info['tags'] = model_data.get("tags", [])
|
|
version['creator'] = model_data.get("creator")
|
|
|
|
license_payload = resolve_license_payload(model_data)
|
|
for field, value in license_payload.items():
|
|
model_info[field] = value
|
|
|
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
|
"""Fetch model version metadata from Civitai
|
|
|
|
Args:
|
|
version_id: The Civitai model version ID
|
|
|
|
Returns:
|
|
Tuple[Optional[Dict], Optional[str]]: A tuple containing:
|
|
- The model version data or None if not found
|
|
- An error message if there was an error, or None on success
|
|
"""
|
|
try:
|
|
url = f"{self.base_url}/model-versions/{version_id}"
|
|
|
|
logger.debug(f"Resolving DNS for model version info: {url}")
|
|
success, result = await self._make_request(
|
|
'GET',
|
|
url,
|
|
use_auth=True
|
|
)
|
|
|
|
if success:
|
|
logger.debug(f"Successfully fetched model version info for: {version_id}")
|
|
self._remove_comfy_metadata(result)
|
|
return result, None
|
|
|
|
# Handle specific error cases
|
|
if "not found" in str(result):
|
|
error_msg = f"Model not found"
|
|
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
|
return None, error_msg
|
|
|
|
# Other error cases
|
|
logger.error(f"Failed to fetch model info for {version_id}: {result}")
|
|
return None, str(result)
|
|
except RateLimitError:
|
|
raise
|
|
except Exception as e:
|
|
error_msg = f"Error fetching model version info: {e}"
|
|
logger.error(error_msg)
|
|
return None, error_msg
|
|
|
|
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
|
"""Fetch image information from Civitai API
|
|
|
|
Args:
|
|
image_id: The Civitai image ID
|
|
|
|
Returns:
|
|
Optional[Dict]: The image data or None if not found
|
|
"""
|
|
try:
|
|
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
|
|
|
logger.debug(f"Fetching image info for ID: {image_id}")
|
|
success, result = await self._make_request(
|
|
'GET',
|
|
url,
|
|
use_auth=True
|
|
)
|
|
|
|
if success:
|
|
if result and "items" in result and len(result["items"]) > 0:
|
|
logger.debug(f"Successfully fetched image info for ID: {image_id}")
|
|
return result["items"][0]
|
|
logger.warning(f"No image found with ID: {image_id}")
|
|
return None
|
|
|
|
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
|
|
return None
|
|
except RateLimitError:
|
|
raise
|
|
except Exception as e:
|
|
error_msg = f"Error fetching image info: {e}"
|
|
logger.error(error_msg)
|
|
return None
|
|
|
|
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
|
"""Fetch all models for a specific Civitai user."""
|
|
if not username:
|
|
return None
|
|
|
|
try:
|
|
url = f"{self.base_url}/models?username={username}"
|
|
success, result = await self._make_request(
|
|
'GET',
|
|
url,
|
|
use_auth=True
|
|
)
|
|
|
|
if not success:
|
|
logger.error("Failed to fetch models for %s: %s", username, result)
|
|
return None
|
|
|
|
items = result.get("items") if isinstance(result, dict) else None
|
|
if not isinstance(items, list):
|
|
return []
|
|
|
|
for model in items:
|
|
versions = model.get("modelVersions")
|
|
if not isinstance(versions, list):
|
|
continue
|
|
for version in versions:
|
|
self._remove_comfy_metadata(version)
|
|
|
|
return items
|
|
except RateLimitError:
|
|
raise
|
|
except Exception as exc: # pragma: no cover - defensive logging
|
|
logger.error("Error fetching models for %s: %s", username, exc)
|
|
return None
|