mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 14:12:11 -03:00
- Introduced a new `Downloader` class to centralize HTTP/HTTPS download management. - Replaced direct `aiohttp` session handling with the unified downloader in `MetadataArchiveManager`, `DownloadManager`, and `ExampleImagesProcessor`. - Added support for resumable downloads, progress tracking, and error handling in the new downloader. - Updated methods to utilize the downloader's capabilities for downloading files and images, improving code maintainability and readability.
330 lines
13 KiB
Python
330 lines
13 KiB
Python
from datetime import datetime
|
|
import os
|
|
import logging
|
|
import asyncio
|
|
from typing import Optional, Dict, Tuple, List
|
|
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
|
from .downloader import get_downloader
|
|
|
|
# Module-scoped logger; its name follows this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
class CivitaiClient:
    """Async client for the Civitai REST API (``https://civitai.com/api/v1``).

    Obtain the shared instance via :meth:`get_instance`, which also registers
    the client as the ``'civitai'`` metadata provider. All HTTP traffic goes
    through the shared downloader returned by ``get_downloader()``. Public
    methods log failures and return ``None``/``False``/error tuples instead of
    raising.
    """

    _instance = None
    # NOTE(review): created at class-definition time. On Python < 3.10 an
    # asyncio.Lock binds to the event loop that is current at creation;
    # confirm the minimum supported Python version.
    _lock = asyncio.Lock()

    @classmethod
    async def get_instance(cls):
        """Get the singleton instance of CivitaiClient.

        On first call this creates the instance and registers it with
        ModelMetadataProviderManager under the name 'civitai'. The trailing
        ``True`` passed to ``register_provider`` presumably marks it as the
        default provider — verify against ModelMetadataProviderManager.
        """
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls()

                # Register this client as a metadata provider
                provider_manager = await ModelMetadataProviderManager.get_instance()
                provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True)

            return cls._instance

    def __init__(self):
        # Check if already initialized for singleton pattern
        if hasattr(self, '_initialized'):
            return
        self._initialized = True

        self.base_url = "https://civitai.com/api/v1"

    @staticmethod
    def _enrich_version_with_model_data(version: Dict, model_data: Dict) -> Dict:
        """Copy description, tags and creator from a model payload onto a version payload.

        Mutates and returns ``version``. If ``version['model']`` is missing or
        not a dict, a fresh dict is created so enrichment never raises on a
        sparse version payload.
        """
        model_info = version.get('model')
        if not isinstance(model_info, dict):
            model_info = {}
            version['model'] = model_info
        model_info['description'] = model_data.get("description")
        model_info['tags'] = model_data.get("tags", [])
        version['creator'] = model_data.get("creator")
        return version

    async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
        """Download a file via the shared downloader with resume enabled.

        Args:
            url: Download URL
            save_dir: Directory to save the file
            default_filename: Filename to save as; the target path is
                ``os.path.join(save_dir, default_filename)``
            progress_callback: Optional async callback function for progress updates (0-100)

        Returns:
            Tuple[bool, str]: (success, save_path or error message), as
            returned by the downloader.
        """
        downloader = await get_downloader()
        save_path = os.path.join(save_dir, default_filename)

        # Use unified downloader with CivitAI authentication
        success, result = await downloader.download_file(
            url=url,
            save_path=save_path,
            progress_callback=progress_callback,
            use_auth=True,  # Enable CivitAI authentication
            allow_resume=True
        )

        return success, result

    async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
        """Look up a model version by file hash.

        Returns:
            Optional[Dict]: The decoded response on success, otherwise None
            (errors are logged, not raised).
        """
        try:
            downloader = await get_downloader()
            success, result = await downloader.make_request(
                'GET',
                f"{self.base_url}/model-versions/by-hash/{model_hash}",
                use_auth=True
            )
            if success:
                return result
            return None
        except Exception as e:
            logger.error(f"API Error: {str(e)}")
            return None

    async def download_preview_image(self, image_url: str, save_path: str):
        """Download an image into memory and write it to ``save_path``.

        The parent directory is created when needed. A bare filename (no
        directory component) is written relative to the current directory.

        Returns:
            bool: True when the image was fetched and written, False otherwise
            (errors are logged, not raised).
        """
        try:
            downloader = await get_downloader()
            success, content = await downloader.download_to_memory(
                image_url,
                use_auth=False  # Preview images don't need auth
            )
            if not success:
                return False

            # os.path.dirname() returns '' for a bare filename, and
            # os.makedirs('') raises FileNotFoundError — only create real dirs.
            parent_dir = os.path.dirname(save_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            with open(save_path, 'wb') as f:
                f.write(content)
            return True
        except Exception as e:
            logger.error(f"Download Error: {str(e)}")
            return False

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Get all versions of a model together with the model type.

        Returns:
            Optional[Dict]: ``{'modelVersions': [...], 'type': str}`` on
            success, otherwise None (errors are logged, not raised).
        """
        try:
            downloader = await get_downloader()
            success, result = await downloader.make_request(
                'GET',
                f"{self.base_url}/models/{model_id}",
                use_auth=True
            )
            if success:
                # Also return model type along with versions
                return {
                    'modelVersions': result.get('modelVersions', []),
                    'type': result.get('type', '')
                }
            return None
        except Exception as e:
            logger.error(f"Error fetching model versions: {e}")
            return None

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Get a specific model version enriched with model-level metadata.

        Args:
            model_id: The Civitai model ID (optional if version_id is provided)
            version_id: Optional specific version ID to retrieve. When only
                model_id is given, the first entry of the model's
                ``modelVersions`` list is used.

        Returns:
            Optional[Dict]: The version payload with ``model.description``,
            ``model.tags`` and ``creator`` copied from the model payload, or
            None if not found or on error (errors are logged, not raised).
        """
        # Case 3: Neither model_id nor version_id provided
        if model_id is None and version_id is None:
            logger.error("Either model_id or version_id must be provided")
            return None

        try:
            downloader = await get_downloader()

            # Case 1: Only version_id is provided
            if model_id is None:
                # First get the version info to extract model_id
                success, version = await downloader.make_request(
                    'GET',
                    f"{self.base_url}/model-versions/{version_id}",
                    use_auth=True
                )
                if not success:
                    return None

                model_id = version.get('modelId')
                if not model_id:
                    logger.error(f"No modelId found in version {version_id}")
                    return None

                # Enrichment is best-effort: if the model request fails, the
                # bare version payload is still returned.
                success, model_data = await downloader.make_request(
                    'GET',
                    f"{self.base_url}/models/{model_id}",
                    use_auth=True
                )
                if success:
                    self._enrich_version_with_model_data(version, model_data)
                return version

            # Case 2: model_id is provided (with or without version_id)
            # Step 1: Get model data (needed for metadata and, if absent, the version id)
            success, data = await downloader.make_request(
                'GET',
                f"{self.base_url}/models/{model_id}",
                use_auth=True
            )
            if not success:
                return None

            # Step 2: Determine the version_id to use
            target_version_id = version_id
            if target_version_id is None:
                model_versions = data.get('modelVersions') or []
                if not model_versions:
                    logger.error(f"No versions found for model {model_id}")
                    return None
                target_version_id = model_versions[0].get('id')
                if target_version_id is None:
                    logger.error(f"First version of model {model_id} has no id")
                    return None

            # Step 3: Get detailed version info using the version_id
            success, version = await downloader.make_request(
                'GET',
                f"{self.base_url}/model-versions/{target_version_id}",
                use_auth=True
            )
            if not success:
                return None

            # Step 4: Enrich version info with description, tags and creator
            return self._enrich_version_with_model_data(version, data)

        except Exception as e:
            logger.error(f"Error fetching model version: {e}")
            return None

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Fetch model version metadata from Civitai

        Args:
            version_id: The Civitai model version ID

        Returns:
            Tuple[Optional[Dict], Optional[str]]: A tuple containing:
                - The model version data or None if not found
                - An error message if there was an error, or None on success
        """
        try:
            downloader = await get_downloader()
            url = f"{self.base_url}/model-versions/{version_id}"

            logger.debug(f"Fetching model version info: {url}")
            success, result = await downloader.make_request(
                'GET',
                url,
                use_auth=True
            )

            if success:
                logger.debug(f"Successfully fetched model version info for: {version_id}")
                return result, None

            # Handle specific error cases. NOTE(review): the status is sniffed
            # from the error text; confirm the downloader's error format.
            if "404" in str(result):
                error_msg = "Model not found (status 404)"
                logger.warning(f"Model version not found: {version_id} - {error_msg}")
                return None, error_msg

            # Other error cases
            logger.error(f"Failed to fetch model info for {version_id}: {result}")
            return None, str(result)
        except Exception as e:
            error_msg = f"Error fetching model version info: {e}"
            logger.error(error_msg)
            return None, error_msg

    async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
        """Fetch model metadata (description, tags, and creator info) from Civitai API

        Args:
            model_id: The Civitai model ID

        Returns:
            Tuple[Optional[Dict], int]: A tuple containing:
                - A dictionary with model metadata or None if not found
                - The HTTP status code (200 on success; 404/401/403 when it can
                  be recognised in the error text, otherwise 0)
        """
        try:
            downloader = await get_downloader()
            url = f"{self.base_url}/models/{model_id}"

            success, result = await downloader.make_request(
                'GET',
                url,
                use_auth=True
            )

            if not success:
                # Try to extract status code from error message
                status_code = 0
                if "404" in str(result):
                    status_code = 404
                elif "401" in str(result):
                    status_code = 401
                elif "403" in str(result):
                    status_code = 403
                logger.warning(f"Failed to fetch model metadata: {result}")
                return None, status_code

            # The API may send "creator": null; fall back to an empty dict so
            # the lookups below don't raise AttributeError.
            creator = result.get("creator") or {}

            # Extract relevant metadata
            metadata = {
                "description": result.get("description") or "No model description available",
                "tags": result.get("tags", []),
                "creator": {
                    "username": creator.get("username"),
                    "image": creator.get("image")
                }
            }

            # NOTE: the description fallback above is always truthy, so this
            # check currently always succeeds; kept for the case where the
            # fallback is removed.
            if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
                return metadata, 200
            else:
                logger.warning(f"No metadata found for model {model_id}")
                return None, 200

        except Exception as e:
            logger.error(f"Error fetching model metadata: {e}", exc_info=True)
            return None, 0

    async def get_image_info(self, image_id: str) -> Optional[Dict]:
        """Fetch image information from Civitai API

        Args:
            image_id: The Civitai image ID

        Returns:
            Optional[Dict]: The first entry of the response's ``items`` list,
            or None if not found or on error (errors are logged, not raised).
        """
        try:
            downloader = await get_downloader()
            url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"

            logger.debug(f"Fetching image info for ID: {image_id}")
            success, result = await downloader.make_request(
                'GET',
                url,
                use_auth=True
            )

            if success:
                if result and "items" in result and len(result["items"]) > 0:
                    logger.debug(f"Successfully fetched image info for ID: {image_id}")
                    return result["items"][0]
                logger.warning(f"No image found with ID: {image_id}")
                return None

            logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
            return None
        except Exception as e:
            error_msg = f"Error fetching image info: {e}"
            logger.error(error_msg)
            return None
|