mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 22:22:11 -03:00
Better Civ Archive support (adds API) (#549)
* add CivArchive API * Oops, missed committing this part when I updated codebase to latest version * Adjust API for version fetching and solve the broken API (hash gives only files, not models - likely to be fixed but in the meantime...) * add asyncio import to allow timeout cooldown --------- Co-authored-by: Scruffy Nerf <Scruffynerf@duck.com>
This commit is contained in:
419
py/services/civarchive_client.py
Normal file
419
py/services/civarchive_client.py
Normal file
@@ -0,0 +1,419 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from .model_metadata_provider import CivArchiveModelMetadataProvider, ModelMetadataProviderManager
|
||||
from .downloader import get_downloader
|
||||
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
except ImportError as exc:
|
||||
BeautifulSoup = None # type: ignore[assignment]
|
||||
_BS4_IMPORT_ERROR = exc
|
||||
else:
|
||||
_BS4_IMPORT_ERROR = None
|
||||
|
||||
def _require_beautifulsoup():
    """Return the BeautifulSoup class, raising a helpful error if bs4 is missing.

    Raises:
        RuntimeError: chained to the original ImportError when bs4 is not installed.
    """
    if BeautifulSoup is not None:
        return BeautifulSoup
    raise RuntimeError(
        "BeautifulSoup (bs4) is required for CivArchive client. "
        "Install it with 'pip install beautifulsoup4'."
    ) from _BS4_IMPORT_ERROR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CivArchiveClient:
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
async def get_instance(cls):
    """Get the singleton instance of CivArchiveClient.

    On first creation, also registers the client with the
    ModelMetadataProviderManager under the 'civarchive' name.
    """
    async with cls._lock:
        if cls._instance is None:
            instance = cls()
            cls._instance = instance

            # First creation: expose this client through the metadata
            # provider registry. (Third argument False — presumably
            # "not default"; confirm against register_provider's signature.)
            manager = await ModelMetadataProviderManager.get_instance()
            manager.register_provider('civarchive', CivArchiveModelMetadataProvider(instance), False)

        return cls._instance
|
||||
|
||||
def __init__(self):
    """Initialize the client; repeat calls on the singleton are no-ops."""
    # Singleton guard: only the first construction sets state.
    if getattr(self, '_initialized', False):
        return
    self._initialized = True

    # Base endpoint for all CivArchive API requests.
    self.base_url = "https://civarchive.com/api"
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
    """Find model by SHA256 hash value using CivArchive API.

    Args:
        model_hash: SHA256 hash of the model file. A value containing '/'
            is treated as a site-relative URL and resolved through
            get_model_by_url() instead.

    Returns:
        Tuple[Optional[Dict], Optional[str]]: (version_metadata, error_message)
    """
    # A "hash" containing '/' is really a URL fragment -- delegate to the
    # HTML-scraping path.
    if "/" in model_hash:
        metadata = await self.get_model_by_url(model_hash)
        if metadata:
            return metadata, None
        else:
            return None, f"Error fetching url: {model_hash}"
    try:
        # CivArchive only supports SHA256 hashes
        url = f"{self.base_url}/sha256/{model_hash.lower()}"

        downloader = await get_downloader()
        session = await downloader.session
        async with session.get(url) as response:
            if response.status != 200:
                if response.status == 404:
                    return None, "Model not found"
                return None, f"HTTP {response.status}"

            data = await response.json()

            # Extract the model and version data from CivArchive structure
            model_data = data.get('model', {})
            version_data = model_data.get('version', {})
            # Fixed: 'files' is a list of file records; the default was {}
            # (iterating a dict here would yield keys, not records).
            files_data = data.get('files', [])

            if not version_data:
                if files_data:
                    # Fixed: payload dumps and success messages were logged
                    # at ERROR level; demoted to DEBUG.
                    logger.debug(f"{data}")
                    # sometimes CivArc returns ONLY file info... but it can then be used to get the rest of the info...
                    # actually as of now (10/25), api broke and ONLY returns 'files' info...
                    for file_data in files_data:
                        logger.debug(f"{file_data}")
                        if file_data["source"] == "civitai":
                            api_data = await self.get_model_version(file_data["model_id"], file_data["model_version_id"])
                            logger.debug(f"{api_data}")
                            logger.debug(f"found CivArchive model by hash {model_hash[:10]}")
                            return api_data, None
                    # NOTE(review): when no civitai-sourced file is found we
                    # fall through and build a result from the empty
                    # version_data, matching the original control flow --
                    # confirm this near-empty success result is intended.
                else:
                    logger.error(f"Error fetching version of CivArchive model by hash {model_hash[:10]}")
                    return None, "No version data found"

            # Transform to match expected (Civitai-style) format
            result = version_data.copy()

            # Add model information
            result['model'] = {
                'name': model_data.get('name'),
                'type': model_data.get('type'),
                'nsfw': model_data.get('nsfw', False),
                'description': model_data.get('description'),
                'tags': model_data.get('tags', [])
            }

            # Add creator information
            result['creator'] = {
                'username': model_data.get('username', model_data.get('creator_username')),
                'image': ''
            }

            # Rename trigger to trainedWords for consistency
            if 'trigger' in result:
                result['trainedWords'] = result.pop('trigger')

            # Transform stats into a nested dict
            if 'downloadCount' in result and 'ratingCount' in result and 'rating' in result:
                result['stats'] = {
                    'downloadCount': result.pop('downloadCount'),
                    'ratingCount': result.pop('ratingCount'),
                    'rating': result.pop('rating')
                }

            # Transform files to match expected format
            if 'files' in result:
                transformed_files = []
                for file_data in result['files']:
                    # Find first available (non-deleted) mirror
                    available_mirror = None
                    for mirror in file_data.get('mirrors', []):
                        if mirror.get('deletedAt') is None:
                            available_mirror = mirror
                            break

                    transformed_file = {
                        'id': file_data.get('id'),
                        'sizeKB': file_data.get('sizeKB'),
                        'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
                        'type': file_data.get('type'),
                        'downloadUrl': available_mirror.get('url') if available_mirror else file_data.get('downloadUrl'),
                        'primary': True,
                        'mirrors': file_data.get('mirrors', [])
                    }

                    # Transform hash format
                    if 'sha256' in file_data:
                        transformed_file['hashes'] = {
                            'SHA256': file_data['sha256'].upper()
                        }

                    transformed_files.append(transformed_file)

                result['files'] = transformed_files

            # Add source identifier
            result['source'] = 'civarchive'

            return result, None

    except Exception as e:
        logger.error(f"Error fetching CivArchive model by hash {model_hash[:10]}: {e}")
        return None, str(e)
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
    """Get all versions of a model using the CivArchive API.

    Args:
        model_id: CivArchive/Civitai model identifier.

    Returns:
        Optional[Dict]: {'modelVersions', 'type', 'name'} in a Civitai-like
        shape, or None on any HTTP or parsing failure.
    """
    endpoint = f"{self.base_url}/models/{model_id}"
    try:
        downloader = await get_downloader()
        session = await downloader.session
        async with session.get(endpoint) as response:
            if response.status != 200:
                return None
            payload = await response.json()

        # Mirror the Civitai response layout for downstream consumers.
        return {
            'modelVersions': payload.get('versions', []),
            'type': payload.get('type', ''),
            'name': payload.get('name', '')
        }

    except Exception as exc:
        logger.error(f"Error fetching CivArchive model versions for {model_id}: {exc}")
        return None
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
    """Get specific model version using CivArchive API.

    Args:
        model_id: The model ID (required). If the API reports a different
            owning model for the resolved version, the call is retried once
            with the corrected model ID.
        version_id: Optional specific version ID to filter to.

    Returns:
        Optional[Dict]: The model version data or None if not found
    """
    if model_id is None:
        return None

    try:
        if version_id is not None:
            url = f"{self.base_url}/models/{model_id}?modelVersionId={version_id}"
        else:
            url = f"{self.base_url}/models/{model_id}"

        downloader = await get_downloader()
        session = await downloader.session
        async with session.get(url) as response:
            if response.status != 200:
                return None

            data = await response.json()

            # Get the version data - CivArchive returns the latest/default version in 'version' field
            version_data = data.get('version', {})
            # (Fixed: removed unused local `versions = data.get('versions', {})`.)

            # If version_id is specified, check if it matches.
            # Fixed: compare ids as strings -- callers such as
            # get_model_version_info pass version_id as str while the API
            # returns int, which previously always failed this check.
            if version_id is not None:
                if str(version_data.get('id')) != str(version_id):
                    # Version mismatch - would need to iterate through versions or make another call
                    # For now, return None as CivArchive API doesn't provide easy version filtering
                    logger.warning(f"Requested version {version_id} doesn't match default version {version_data.get('id')} for model {model_id}")
                    return None
            if str(version_data.get('modelId')) != str(model_id):
                # you can pass ANY model id, and a version number, and get the CORRECT model id from this...
                # so recall the api with the correct info now
                return await self.get_model_version(version_data.get('modelId'), version_id)

            # Transform to expected (Civitai-style) format
            result = version_data.copy()

            # Restructure stats
            if 'downloadCount' in result and 'ratingCount' in result and 'rating' in result:
                result['stats'] = {
                    'downloadCount': result.pop('downloadCount'),
                    'ratingCount': result.pop('ratingCount'),
                    'rating': result.pop('rating')
                }

            # Rename trigger to trainedWords
            if 'trigger' in result:
                result['trainedWords'] = result.pop('trigger')

            # Transform files data
            if 'files' in result:
                transformed_files = []
                for file_data in result['files']:
                    # Find first available (non-deleted) mirror
                    available_mirror = None
                    for mirror in file_data.get('mirrors', []):
                        if mirror.get('deletedAt') is None:
                            available_mirror = mirror
                            break

                    transformed_file = {
                        'id': file_data.get('id'),
                        'sizeKB': file_data.get('sizeKB'),
                        'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
                        'type': file_data.get('type'),
                        'downloadUrl': available_mirror.get('url') if available_mirror else file_data.get('downloadUrl'),
                        'primary': True,
                        'mirrors': file_data.get('mirrors', [])
                    }

                    # Transform hash format
                    if 'sha256' in file_data:
                        transformed_file['hashes'] = {
                            'SHA256': file_data['sha256'].upper()
                        }

                    transformed_files.append(transformed_file)

                result['files'] = transformed_files

            # Add model information (top-level fields of the API payload)
            result['model'] = {
                'name': data.get('name'),
                'type': data.get('type'),
                'nsfw': data.get('is_nsfw', False),
                'description': data.get('description'),
                'tags': data.get('tags', [])
            }

            result['creator'] = {
                'username': data.get('username', data.get('creator_username')),
                'image': ''
            }

            # Add source identifier and deletion flag
            result['source'] = 'civarchive'
            result['is_deleted'] = data.get('deletedAt') is not None

            return result

    except Exception as e:
        logger.error(f"Error fetching CivArchive model version via API {model_id}/{version_id}: {e}")
        return None
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
    """Fetch model version metadata using a known bogus model lookup.

    CivArchive lacks a direct version lookup API; get_model_version() corrects
    the bogus model id (1) from the API's reported owning model.

    Args:
        version_id: The model version ID

    Returns:
        Tuple[Optional[Dict], Optional[str]]: (version_data, error_message)
    """
    # Fixed: previously returned the bare result of get_model_version()
    # (dict or None), violating the declared tuple contract that the
    # provider layer delegates to.
    version = await self.get_model_version(1, version_id)
    if version is None:
        return None, f"Version {version_id} not found on CivArchive"
    return version, None
|
||||
|
||||
async def get_model_by_url(self, url) -> Optional[Dict]:
    """Get a specific model version by parsing a CivArchive HTML page (legacy method).

    This is the original HTML scraping implementation, kept for sites that are
    not covered by the API. The primary get_model_version() now uses the API.

    Args:
        url: Site-relative CivArchive path (appended to https://civarchive.com/).

    Returns:
        Optional[Dict]: Civitai-style version metadata, or None on any failure.
    """

    try:
        # Construct full CivArchive URL from the relative path
        url = f"https://civarchive.com/{url}"
        downloader = await get_downloader()
        session = await downloader.session
        async with session.get(url) as response:
            if response.status != 200:
                return None

            html_content = await response.text()

            # Parse HTML to extract the embedded Next.js JSON payload
            soup_parser = _require_beautifulsoup()
            soup = soup_parser(html_content, 'html.parser')
            script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})

            if not script_tag:
                return None

            # Parse JSON content
            json_data = json.loads(script_tag.string)
            model_data = json_data.get('props', {}).get('pageProps', {}).get('model')

            if not model_data or 'version' not in model_data:
                return None

            # Extract version data as base
            version = model_data['version'].copy()

            # Restructure stats
            if 'downloadCount' in version and 'ratingCount' in version and 'rating' in version:
                version['stats'] = {
                    'downloadCount': version.pop('downloadCount'),
                    'ratingCount': version.pop('ratingCount'),
                    'rating': version.pop('rating')
                }

            # Rename trigger to trainedWords
            if 'trigger' in version:
                version['trainedWords'] = version.pop('trigger')

            # Transform files data to expected format
            if 'files' in version:
                transformed_files = []
                for file_data in version['files']:
                    # Find first available mirror (deletedAt is null)
                    available_mirror = None
                    for mirror in file_data.get('mirrors', []):
                        if mirror.get('deletedAt') is None:
                            available_mirror = mirror
                            break

                    # Create transformed file entry
                    transformed_file = {
                        'id': file_data.get('id'),
                        'sizeKB': file_data.get('sizeKB'),
                        'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
                        'type': file_data.get('type'),
                        'downloadUrl': available_mirror.get('url') if available_mirror else None,
                        'primary': True,
                        'mirrors': file_data.get('mirrors', [])
                    }

                    # Transform hash format
                    if 'sha256' in file_data:
                        transformed_file['hashes'] = {
                            'SHA256': file_data['sha256'].upper()
                        }

                    transformed_files.append(transformed_file)

                version['files'] = transformed_files

            # Add model information
            version['model'] = {
                'name': model_data.get('name'),
                'type': model_data.get('type'),
                'nsfw': model_data.get('is_nsfw', False),
                'description': model_data.get('description'),
                'tags': model_data.get('tags', [])
            }

            version['creator'] = {
                'username': model_data.get('username'),
                'image': ''
            }

            # Add source identifier
            version['source'] = 'civarchive'
            version['is_deleted'] = json_data.get('query', {}).get('is_deleted', False)

            return version

    except Exception as e:
        # Fixed: previously referenced undefined model_id/version_id here,
        # which raised NameError from inside the exception handler instead
        # of logging and returning None.
        logger.error(f"Error fetching CivArchive model version (scraping) {url}: {e}")
        return None
|
||||
@@ -4,6 +4,7 @@ from .model_metadata_provider import (
|
||||
ModelMetadataProviderManager,
|
||||
SQLiteModelMetadataProvider,
|
||||
CivitaiModelMetadataProvider,
|
||||
CivArchiveModelMetadataProvider,
|
||||
FallbackMetadataProvider
|
||||
)
|
||||
from .settings_manager import get_settings_manager
|
||||
@@ -54,26 +55,28 @@ async def initialize_metadata_providers():
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Civitai API metadata provider: {e}")
|
||||
|
||||
# Register CivArchive provider, but do NOT add to fallback providers
|
||||
# Register CivArchive provider, and also add to fallback providers
|
||||
try:
|
||||
from .model_metadata_provider import CivArchiveModelMetadataProvider
|
||||
civarchive_provider = CivArchiveModelMetadataProvider()
|
||||
provider_manager.register_provider('civarchive', civarchive_provider)
|
||||
logger.debug("CivArchive metadata provider registered (not included in fallback)")
|
||||
civarchive_client = await ServiceRegistry.get_civarchive_client()
|
||||
civarchive_provider = CivitaiModelMetadataProvider(civarchive_client)
|
||||
provider_manager.register_provider('civarchive_api', civarchive_provider)
|
||||
providers.append(('civarchive_api', civarchive_provider))
|
||||
logger.debug("CivArchive metadata provider registered (also included in fallback)")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize CivArchive metadata provider: {e}")
|
||||
|
||||
# Set up fallback provider based on available providers
|
||||
if len(providers) > 1:
|
||||
# Always use Civitai API first, then Archive DB
|
||||
# Always use Civarchive, then Civitai API, then Archive DB
|
||||
ordered_providers = []
|
||||
ordered_providers.extend([p[1] for p in providers if p[0] == 'civarchive_api'])
|
||||
ordered_providers.extend([p[1] for p in providers if p[0] == 'civitai_api'])
|
||||
ordered_providers.extend([p[1] for p in providers if p[0] == 'sqlite'])
|
||||
|
||||
if ordered_providers:
|
||||
fallback_provider = FallbackMetadataProvider(ordered_providers)
|
||||
provider_manager.register_provider('fallback', fallback_provider, is_default=True)
|
||||
logger.debug(f"Fallback metadata provider registered with {len(ordered_providers)} providers, Civitai API first")
|
||||
logger.info(f"Fallback metadata provider registered with {len(ordered_providers)} providers, Civarchive first")
|
||||
elif len(providers) == 1:
|
||||
# Only one provider available, set it as default
|
||||
provider_name, provider = providers[0]
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
||||
|
||||
@@ -169,6 +170,46 @@ class MetadataSyncService:
|
||||
enable_archive = self._settings.get("enable_metadata_archive_db", False)
|
||||
|
||||
try:
|
||||
metadata_provider = await self._get_provider("civarchive_api")
|
||||
tryagain = True
|
||||
delay = 5
|
||||
|
||||
while tryagain:
|
||||
civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
|
||||
tryagain = False
|
||||
if not civitai_metadata or error:
|
||||
if error == "HTTP 429":
|
||||
error_msg = (f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')} sha256={sha256})")
|
||||
logger.error(error_msg)
|
||||
delay = delay * 2
|
||||
await asyncio.sleep(delay)
|
||||
tryagain = True
|
||||
continue
|
||||
if error == "Model not found":
|
||||
model_data["from_civitai"] = False
|
||||
model_data["civitai_deleted"] = True
|
||||
#model_data["db_checked"] = enable_archive
|
||||
model_data["last_checked_at"] = datetime.now().timestamp()
|
||||
data_to_save = model_data.copy()
|
||||
data_to_save.pop("folder", None)
|
||||
await self._metadata_manager.save_metadata(file_path, data_to_save)
|
||||
await asyncio.sleep(1)
|
||||
if error == "No version data found":
|
||||
error_msg = (f"Error - No civitai version found: (model_name={model_data.get('model_name', '')} sha256={sha256})")
|
||||
logger.error(error_msg)
|
||||
error = False
|
||||
if civitai_metadata.get('files'):
|
||||
for file in civitai_metadata['files']:
|
||||
logger.error(f"{file}")
|
||||
if 'tensorart' in file['url'] or "seaart" in file['url']:
|
||||
civitai_metadata, error = await metadata_provider.get_model_by_hash(file['url'])
|
||||
error_msg = (f"Error fetching metadata: {error} {civitai_metadata}")
|
||||
logger.error(error_msg)
|
||||
if error or not civitai_metadata:
|
||||
error_msg = (f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')} sha256={sha256})")
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
if model_data.get("civitai_deleted") is True:
|
||||
if not enable_archive or model_data.get("db_checked") is True:
|
||||
if not enable_archive:
|
||||
|
||||
@@ -88,122 +88,22 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
return await self.client.get_user_models(username)
|
||||
|
||||
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
||||
"""Provider that uses CivArchive HTML page parsing for metadata"""
|
||||
"""Provider that uses CivArchive API for metadata"""
|
||||
|
||||
def __init__(self, civarchive_client):
|
||||
self.client = civarchive_client
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None, "CivArchive provider does not support hash lookup"
|
||||
return await self.client.get_model_by_hash(model_hash)
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
return None
|
||||
return await self.client.get_model_versions(model_id)
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version by parsing CivArchive HTML page"""
|
||||
if model_id is None or version_id is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Construct CivArchive URL
|
||||
url = f"https://civarchive.com/models/{model_id}?modelVersionId={version_id}"
|
||||
|
||||
downloader = await get_downloader()
|
||||
session = await downloader.session
|
||||
async with session.get(url) as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
html_content = await response.text()
|
||||
|
||||
# Parse HTML to extract JSON data
|
||||
soup_parser = _require_beautifulsoup()
|
||||
soup = soup_parser(html_content, 'html.parser')
|
||||
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
|
||||
|
||||
if not script_tag:
|
||||
return None
|
||||
|
||||
# Parse JSON content
|
||||
json_data = json.loads(script_tag.string)
|
||||
model_data = json_data.get('props', {}).get('pageProps', {}).get('model')
|
||||
|
||||
if not model_data or 'version' not in model_data:
|
||||
return None
|
||||
|
||||
# Extract version data as base
|
||||
version = model_data['version'].copy()
|
||||
|
||||
# Restructure stats
|
||||
if 'downloadCount' in version and 'ratingCount' in version and 'rating' in version:
|
||||
version['stats'] = {
|
||||
'downloadCount': version.pop('downloadCount'),
|
||||
'ratingCount': version.pop('ratingCount'),
|
||||
'rating': version.pop('rating')
|
||||
}
|
||||
|
||||
# Rename trigger to trainedWords
|
||||
if 'trigger' in version:
|
||||
version['trainedWords'] = version.pop('trigger')
|
||||
|
||||
# Transform files data to expected format
|
||||
if 'files' in version:
|
||||
transformed_files = []
|
||||
for file_data in version['files']:
|
||||
# Find first available mirror (deletedAt is null)
|
||||
available_mirror = None
|
||||
for mirror in file_data.get('mirrors', []):
|
||||
if mirror.get('deletedAt') is None:
|
||||
available_mirror = mirror
|
||||
break
|
||||
|
||||
# Create transformed file entry
|
||||
transformed_file = {
|
||||
'id': file_data.get('id'),
|
||||
'sizeKB': file_data.get('sizeKB'),
|
||||
'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
|
||||
'type': file_data.get('type'),
|
||||
'downloadUrl': available_mirror.get('url') if available_mirror else None,
|
||||
'primary': True,
|
||||
'mirrors': file_data.get('mirrors', [])
|
||||
}
|
||||
|
||||
# Transform hash format
|
||||
if 'sha256' in file_data:
|
||||
transformed_file['hashes'] = {
|
||||
'SHA256': file_data['sha256'].upper()
|
||||
}
|
||||
|
||||
transformed_files.append(transformed_file)
|
||||
|
||||
version['files'] = transformed_files
|
||||
|
||||
# Add model information
|
||||
version['model'] = {
|
||||
'name': model_data.get('name'),
|
||||
'type': model_data.get('type'),
|
||||
'nsfw': model_data.get('is_nsfw', False),
|
||||
'description': model_data.get('description'),
|
||||
'tags': model_data.get('tags', [])
|
||||
}
|
||||
|
||||
version['creator'] = {
|
||||
'username': model_data.get('username'),
|
||||
'image': ''
|
||||
}
|
||||
|
||||
# Add source identifier
|
||||
version['source'] = 'civarchive'
|
||||
version['is_deleted'] = json_data.get('query', {}).get('is_deleted', False)
|
||||
|
||||
return version
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivArchive model version {model_id}/{version_id}: {e}")
|
||||
return None
|
||||
return await self.client.get_model_version(model_id, version_id)
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Not supported by CivArchive provider - requires both model_id and version_id"""
|
||||
return None, "CivArchive provider requires both model_id and version_id"
|
||||
return await self.client.get_model_version_info(version_id)
|
||||
|
||||
async def get_user_models(self, username: str) -> Optional[List[Dict]]:
|
||||
"""Not supported by CivArchive provider"""
|
||||
|
||||
@@ -144,6 +144,27 @@ class ServiceRegistry:
|
||||
cls._services[service_name] = client
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return client
|
||||
|
||||
@classmethod
async def get_civarchive_client(cls):
    """Get or create the CivArchive client instance (lazily, once)."""
    service_name = "civarchive_client"

    # Fast path: already created and registered.
    if service_name in cls._services:
        return cls._services[service_name]

    async with cls._get_lock(service_name):
        # Double-check after acquiring lock: another task may have won the race.
        if service_name in cls._services:
            return cls._services[service_name]

        # Import here to avoid circular imports
        from .civarchive_client import CivArchiveClient

        client = await CivArchiveClient.get_instance()
        cls._services[service_name] = client
        logger.debug(f"Created and registered {service_name}")
        return client
|
||||
|
||||
@classmethod
|
||||
async def get_download_manager(cls):
|
||||
|
||||
Reference in New Issue
Block a user