mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(civarchive_client): enhance request handling and context parsing
Introduce `_request_json` method for async JSON requests and improved error handling. Add static methods `_normalize_payload`, `_split_context`, `_ensure_list`, and `_build_model_info` to parse and normalize API responses. These changes improve the robustness of the CivArchiveClient by ensuring consistent data structures and handling potential API response issues gracefully.
This commit is contained in:
@@ -2,6 +2,7 @@ import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from .model_metadata_provider import CivArchiveModelMetadataProvider, ModelMetadataProviderManager
|
||||
from .downloader import get_downloader
|
||||
@@ -49,115 +50,256 @@ class CivArchiveClient:
|
||||
|
||||
self.base_url = "https://civarchive.com/api"
|
||||
|
||||
async def _request_json(
|
||||
self,
|
||||
path: str,
|
||||
params: Optional[Dict[str, str]] = None
|
||||
) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Call CivArchive API and return JSON payload"""
|
||||
downloader = await get_downloader()
|
||||
kwargs: Dict[str, Dict[str, str]] = {}
|
||||
if params:
|
||||
safe_params = {str(key): str(value) for key, value in params.items() if value is not None}
|
||||
if safe_params:
|
||||
kwargs["params"] = safe_params
|
||||
success, payload = await downloader.make_request(
|
||||
"GET",
|
||||
f"{self.base_url}{path}",
|
||||
use_auth=False,
|
||||
**kwargs
|
||||
)
|
||||
if not success:
|
||||
error = payload if isinstance(payload, str) else "Request failed"
|
||||
return None, error
|
||||
if not isinstance(payload, dict):
|
||||
return None, "Invalid response structure"
|
||||
return payload, None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_payload(payload: Dict) -> Dict:
|
||||
"""Unwrap CivArchive responses that wrap content under a data key"""
|
||||
if not isinstance(payload, dict):
|
||||
return {}
|
||||
data = payload.get("data")
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _split_context(payload: Dict) -> Tuple[Dict, Dict, List[Dict]]:
|
||||
"""Separate version payload from surrounding model context"""
|
||||
data = CivArchiveClient._normalize_payload(payload)
|
||||
context: Dict = {}
|
||||
fallback_files: List[Dict] = []
|
||||
version: Dict = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if key in {"version", "model"}:
|
||||
continue
|
||||
context[key] = value
|
||||
|
||||
if isinstance(data.get("version"), dict):
|
||||
version = data["version"]
|
||||
|
||||
model_block = data.get("model")
|
||||
if isinstance(model_block, dict):
|
||||
for key, value in model_block.items():
|
||||
if key == "version":
|
||||
if not version and isinstance(value, dict):
|
||||
version = value
|
||||
continue
|
||||
context.setdefault(key, value)
|
||||
fallback_files = fallback_files or model_block.get("files") or []
|
||||
|
||||
fallback_files = fallback_files or data.get("files") or []
|
||||
return context, version, fallback_files
|
||||
|
||||
@staticmethod
|
||||
def _ensure_list(value) -> List:
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
if value is None:
|
||||
return []
|
||||
return [value]
|
||||
|
||||
@staticmethod
|
||||
def _build_model_info(context: Dict) -> Dict:
|
||||
tags = context.get("tags")
|
||||
if not isinstance(tags, list):
|
||||
tags = list(tags) if isinstance(tags, (set, tuple)) else ([] if tags is None else [tags])
|
||||
return {
|
||||
"name": context.get("name"),
|
||||
"type": context.get("type"),
|
||||
"nsfw": bool(context.get("is_nsfw", context.get("nsfw", False))),
|
||||
"description": context.get("description"),
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_creator_info(context: Dict) -> Dict:
|
||||
username = context.get("creator_username") or context.get("username") or ""
|
||||
image = context.get("creator_image") or context.get("creator_avatar") or ""
|
||||
creator: Dict[str, Optional[str]] = {
|
||||
"username": username,
|
||||
"image": image,
|
||||
}
|
||||
if context.get("creator_name"):
|
||||
creator["name"] = context["creator_name"]
|
||||
if context.get("creator_url"):
|
||||
creator["url"] = context["creator_url"]
|
||||
return creator
|
||||
|
||||
@staticmethod
|
||||
def _transform_file_entry(file_data: Dict) -> Dict:
|
||||
mirrors = file_data.get("mirrors") or []
|
||||
if not isinstance(mirrors, list):
|
||||
mirrors = [mirrors]
|
||||
available_mirror = next(
|
||||
(mirror for mirror in mirrors if isinstance(mirror, dict) and mirror.get("deletedAt") is None),
|
||||
None
|
||||
)
|
||||
download_url = file_data.get("downloadUrl")
|
||||
if not download_url and available_mirror:
|
||||
download_url = available_mirror.get("url")
|
||||
name = file_data.get("name")
|
||||
if not name and available_mirror:
|
||||
name = available_mirror.get("filename")
|
||||
|
||||
transformed: Dict = {
|
||||
"id": file_data.get("id"),
|
||||
"sizeKB": file_data.get("sizeKB"),
|
||||
"name": name,
|
||||
"type": file_data.get("type"),
|
||||
"downloadUrl": download_url,
|
||||
"primary": True,
|
||||
# TODO: for some reason is_primary is false in CivArchive response, need to figure this out,
|
||||
# "primary": bool(file_data.get("is_primary", file_data.get("primary", False))),
|
||||
"mirrors": mirrors,
|
||||
}
|
||||
|
||||
sha256 = file_data.get("sha256")
|
||||
if sha256:
|
||||
transformed["hashes"] = {"SHA256": str(sha256).upper()}
|
||||
elif isinstance(file_data.get("hashes"), dict):
|
||||
transformed["hashes"] = file_data["hashes"]
|
||||
|
||||
if "metadata" in file_data:
|
||||
transformed["metadata"] = file_data["metadata"]
|
||||
|
||||
if file_data.get("modelVersionId") is not None:
|
||||
transformed["modelVersionId"] = file_data.get("modelVersionId")
|
||||
elif file_data.get("model_version_id") is not None:
|
||||
transformed["modelVersionId"] = file_data.get("model_version_id")
|
||||
|
||||
if file_data.get("modelId") is not None:
|
||||
transformed["modelId"] = file_data.get("modelId")
|
||||
elif file_data.get("model_id") is not None:
|
||||
transformed["modelId"] = file_data.get("model_id")
|
||||
|
||||
return transformed
|
||||
|
||||
def _transform_files(
|
||||
self,
|
||||
files: Optional[List[Dict]],
|
||||
fallback_files: Optional[List[Dict]] = None
|
||||
) -> List[Dict]:
|
||||
candidates: List[Dict] = []
|
||||
if isinstance(files, list) and files:
|
||||
candidates = files
|
||||
elif isinstance(fallback_files, list):
|
||||
candidates = fallback_files
|
||||
|
||||
transformed_files: List[Dict] = []
|
||||
for file_data in candidates:
|
||||
if isinstance(file_data, dict):
|
||||
transformed_files.append(self._transform_file_entry(file_data))
|
||||
return transformed_files
|
||||
|
||||
def _transform_version(
|
||||
self,
|
||||
context: Dict,
|
||||
version: Dict,
|
||||
fallback_files: Optional[List[Dict]] = None
|
||||
) -> Optional[Dict]:
|
||||
if not version:
|
||||
return None
|
||||
|
||||
version_copy = deepcopy(version)
|
||||
version_copy.pop("model", None)
|
||||
version_copy.pop("creator", None)
|
||||
|
||||
if "trigger" in version_copy:
|
||||
triggers = version_copy.pop("trigger")
|
||||
if isinstance(triggers, list):
|
||||
version_copy["trainedWords"] = triggers
|
||||
elif triggers is None:
|
||||
version_copy["trainedWords"] = []
|
||||
else:
|
||||
version_copy["trainedWords"] = [triggers]
|
||||
|
||||
if "trainedWords" in version_copy and isinstance(version_copy["trainedWords"], str):
|
||||
version_copy["trainedWords"] = [version_copy["trainedWords"]]
|
||||
|
||||
if "nsfw_level" in version_copy:
|
||||
version_copy["nsfwLevel"] = version_copy.pop("nsfw_level")
|
||||
elif "nsfwLevel" not in version_copy and context.get("nsfw_level") is not None:
|
||||
version_copy["nsfwLevel"] = context.get("nsfw_level")
|
||||
|
||||
stats_keys = ["downloadCount", "ratingCount", "rating"]
|
||||
stats = {key: version_copy.pop(key) for key in stats_keys if key in version_copy}
|
||||
if stats:
|
||||
version_copy["stats"] = stats
|
||||
|
||||
version_copy["files"] = self._transform_files(version_copy.get("files"), fallback_files)
|
||||
version_copy["images"] = self._ensure_list(version_copy.get("images"))
|
||||
|
||||
version_copy["model"] = self._build_model_info(context)
|
||||
version_copy["creator"] = self._build_creator_info(context)
|
||||
|
||||
version_copy["source"] = "civarchive"
|
||||
version_copy["is_deleted"] = bool(context.get("deletedAt")) or bool(version.get("deletedAt"))
|
||||
|
||||
return version_copy
|
||||
|
||||
async def _resolve_version_from_files(self, payload: Dict) -> Optional[Dict]:
|
||||
"""Fallback to fetch version data when only file metadata is available"""
|
||||
data = self._normalize_payload(payload)
|
||||
files = data.get("files") or payload.get("files") or []
|
||||
if not isinstance(files, list):
|
||||
files = [files]
|
||||
for file_data in files:
|
||||
if not isinstance(file_data, dict):
|
||||
continue
|
||||
model_id = file_data.get("model_id") or file_data.get("modelId")
|
||||
version_id = file_data.get("model_version_id") or file_data.get("modelVersionId")
|
||||
if model_id is None or version_id is None:
|
||||
continue
|
||||
resolved = await self.get_model_version(model_id, version_id)
|
||||
if resolved:
|
||||
return resolved
|
||||
return None
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Find model by SHA256 hash value using CivArchive API"""
|
||||
if "/" in model_hash:
|
||||
metadata = await self.get_model_by_url(model_hash)
|
||||
if metadata:
|
||||
return metadata, None
|
||||
else:
|
||||
return None, f"Error fetching url: {model_hash}"
|
||||
try:
|
||||
# CivArchive only supports SHA256 hashes
|
||||
url = f"{self.base_url}/sha256/{model_hash.lower()}"
|
||||
|
||||
downloader = await get_downloader()
|
||||
session = await downloader.session
|
||||
async with session.get(url) as response:
|
||||
if response.status != 200:
|
||||
if response.status == 404:
|
||||
return None, "Model not found"
|
||||
return None, f"HTTP {response.status}"
|
||||
|
||||
data = await response.json()
|
||||
|
||||
# Extract the model and version data from CivArchive structure
|
||||
model_data = data.get('model', {})
|
||||
version_data = model_data.get('version', {})
|
||||
files_data = data.get('files', {})
|
||||
payload, error = await self._request_json(f"/sha256/{model_hash.lower()}")
|
||||
if error:
|
||||
if "not found" in error.lower():
|
||||
return None, "Model not found"
|
||||
return None, error
|
||||
|
||||
context, version_data, fallback_files = self._split_context(payload)
|
||||
transformed = self._transform_version(context, version_data, fallback_files)
|
||||
if transformed:
|
||||
return transformed, None
|
||||
|
||||
resolved = await self._resolve_version_from_files(payload)
|
||||
if resolved:
|
||||
return resolved, None
|
||||
|
||||
logger.error("Error fetching version of CivArchive model by hash %s", model_hash[:10])
|
||||
return None, "No version data found"
|
||||
|
||||
if not version_data:
|
||||
if files_data:
|
||||
logger.error(f"{data}")
|
||||
# sometimes CivArc returns ONLY file info... but it can then be used to get the rest of the info...
|
||||
# actually as of now (10/25), api broke and ONLY returns 'files' info...
|
||||
for file_data in files_data:
|
||||
logger.error(f"{file_data}")
|
||||
if file_data["source"] == "civitai":
|
||||
api_data = await self.get_model_version(file_data["model_id"], file_data["model_version_id"])
|
||||
logger.error(f"{api_data}")
|
||||
logger.error(f"found CivArchive model by hash {model_hash[:10]}")
|
||||
return api_data, None
|
||||
else:
|
||||
logger.error(f"Error fetching version of CivArchive model by hash {model_hash[:10]}")
|
||||
return None, "No version data found"
|
||||
|
||||
# Transform to match expected format
|
||||
result = version_data.copy()
|
||||
|
||||
# Add model information
|
||||
result['model'] = {
|
||||
'name': model_data.get('name'),
|
||||
'type': model_data.get('type'),
|
||||
'nsfw': model_data.get('nsfw', False),
|
||||
'description': model_data.get('description'),
|
||||
'tags': model_data.get('tags', [])
|
||||
}
|
||||
|
||||
# Add creator information
|
||||
result['creator'] = {
|
||||
'username': model_data.get('username', model_data.get('creator_username')),
|
||||
'image': ''
|
||||
}
|
||||
|
||||
# Rename trigger to trainedWords for consistency
|
||||
if 'trigger' in result:
|
||||
result['trainedWords'] = result.pop('trigger')
|
||||
|
||||
# Transform stats
|
||||
if 'downloadCount' in result and 'ratingCount' in result and 'rating' in result:
|
||||
result['stats'] = {
|
||||
'downloadCount': result.pop('downloadCount'),
|
||||
'ratingCount': result.pop('ratingCount'),
|
||||
'rating': result.pop('rating')
|
||||
}
|
||||
|
||||
# Transform files to match expected format
|
||||
if 'files' in result:
|
||||
transformed_files = []
|
||||
for file_data in result['files']:
|
||||
# Find first available mirror
|
||||
available_mirror = None
|
||||
for mirror in file_data.get('mirrors', []):
|
||||
if mirror.get('deletedAt') is None:
|
||||
available_mirror = mirror
|
||||
break
|
||||
|
||||
transformed_file = {
|
||||
'id': file_data.get('id'),
|
||||
'sizeKB': file_data.get('sizeKB'),
|
||||
'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
|
||||
'type': file_data.get('type'),
|
||||
'downloadUrl': available_mirror.get('url') if available_mirror else file_data.get('downloadUrl'),
|
||||
'primary': True,
|
||||
'mirrors': file_data.get('mirrors', [])
|
||||
}
|
||||
|
||||
# Transform hash format
|
||||
if 'sha256' in file_data:
|
||||
transformed_file['hashes'] = {
|
||||
'SHA256': file_data['sha256'].upper()
|
||||
}
|
||||
|
||||
transformed_files.append(transformed_file)
|
||||
|
||||
result['files'] = transformed_files
|
||||
|
||||
# Add source identifier
|
||||
result['source'] = 'civarchive'
|
||||
|
||||
return result, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivArchive model by hash {model_hash[:10]}: {e}")
|
||||
return None, str(e)
|
||||
@@ -165,26 +307,49 @@ class CivArchiveClient:
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model using CivArchive API"""
|
||||
try:
|
||||
url = f"{self.base_url}/models/{model_id}"
|
||||
|
||||
downloader = await get_downloader()
|
||||
session = await downloader.session
|
||||
async with session.get(url) as response:
|
||||
if response.status != 200:
|
||||
payload, error = await self._request_json(f"/models/{model_id}")
|
||||
if error or payload is None:
|
||||
if error and "not found" in error.lower():
|
||||
return None
|
||||
|
||||
data = await response.json()
|
||||
|
||||
# Extract versions list
|
||||
versions = data.get('versions', [])
|
||||
|
||||
# Return in format similar to Civitai
|
||||
logger.error(f"Error fetching CivArchive model versions for {model_id}: {error}")
|
||||
return None
|
||||
|
||||
data = self._normalize_payload(payload)
|
||||
context, version_data, fallback_files = self._split_context(payload)
|
||||
|
||||
versions_meta = data.get("versions") or []
|
||||
transformed_versions: List[Dict] = []
|
||||
for meta in versions_meta:
|
||||
if not isinstance(meta, dict):
|
||||
continue
|
||||
version_id = meta.get("id")
|
||||
if version_id is None:
|
||||
continue
|
||||
target_model_id = meta.get("modelId") or model_id
|
||||
version = await self.get_model_version(target_model_id, version_id)
|
||||
if version:
|
||||
transformed_versions.append(version)
|
||||
|
||||
# Ensure the primary version is included even if versions list was empty
|
||||
primary_version = self._transform_version(context, version_data, fallback_files)
|
||||
if primary_version:
|
||||
transformed_versions.insert(0, primary_version)
|
||||
|
||||
ordered_versions: List[Dict] = []
|
||||
seen_ids = set()
|
||||
for version in transformed_versions:
|
||||
version_id = version.get("id")
|
||||
if version_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(version_id)
|
||||
ordered_versions.append(version)
|
||||
|
||||
return {
|
||||
'modelVersions': versions,
|
||||
'type': data.get('type', ''),
|
||||
'name': data.get('name', '')
|
||||
"modelVersions": ordered_versions,
|
||||
"type": context.get("type", ""),
|
||||
"name": context.get("name", ""),
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivArchive model versions for {model_id}: {e}")
|
||||
return None
|
||||
@@ -201,103 +366,37 @@ class CivArchiveClient:
|
||||
"""
|
||||
if model_id is None:
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
if version_id is not None:
|
||||
url = f"{self.base_url}/models/{model_id}?modelVersionId={version_id}"
|
||||
else:
|
||||
url = f"{self.base_url}/models/{model_id}"
|
||||
|
||||
downloader = await get_downloader()
|
||||
session = await downloader.session
|
||||
async with session.get(url) as response:
|
||||
if response.status != 200:
|
||||
params = {"modelVersionId": version_id} if version_id is not None else None
|
||||
payload, error = await self._request_json(f"/models/{model_id}", params=params)
|
||||
if error or payload is None:
|
||||
if error and "not found" in error.lower():
|
||||
return None
|
||||
|
||||
data = await response.json()
|
||||
|
||||
# Get the version data - CivArchive returns the latest/default version in 'version' field
|
||||
version_data = data.get('version', {})
|
||||
versions = data.get('versions', {})
|
||||
|
||||
# If version_id is specified, check if it matches
|
||||
logger.error(f"Error fetching CivArchive model version via API {model_id}/{version_id}: {error}")
|
||||
return None
|
||||
|
||||
context, version_data, fallback_files = self._split_context(payload)
|
||||
|
||||
if not version_data:
|
||||
return await self._resolve_version_from_files(payload)
|
||||
|
||||
if version_id is not None:
|
||||
if version_data.get('id') != version_id:
|
||||
# Version mismatch - would need to iterate through versions or make another call
|
||||
# For now, return None as CivArchive API doesn't provide easy version filtering
|
||||
logger.warning(f"Requested version {version_id} doesn't match default version {version_data.get('id')} for model {model_id}")
|
||||
return None
|
||||
if version_data.get('modelId') != model_id:
|
||||
# you can pass ANY model id, and a version number, and get the CORRECT model id from this...
|
||||
# so recall the api with the correct info now
|
||||
return await self.get_model_version(version_data.get('modelId'), version_id)
|
||||
|
||||
# Transform to expected format
|
||||
result = version_data.copy()
|
||||
|
||||
# Restructure stats
|
||||
if 'downloadCount' in result and 'ratingCount' in result and 'rating' in result:
|
||||
result['stats'] = {
|
||||
'downloadCount': result.pop('downloadCount'),
|
||||
'ratingCount': result.pop('ratingCount'),
|
||||
'rating': result.pop('rating')
|
||||
}
|
||||
|
||||
# Rename trigger to trainedWords
|
||||
if 'trigger' in result:
|
||||
result['trainedWords'] = result.pop('trigger')
|
||||
|
||||
# Transform files data
|
||||
if 'files' in result:
|
||||
transformed_files = []
|
||||
for file_data in result['files']:
|
||||
# Find first available mirror
|
||||
available_mirror = None
|
||||
for mirror in file_data.get('mirrors', []):
|
||||
if mirror.get('deletedAt') is None:
|
||||
available_mirror = mirror
|
||||
break
|
||||
|
||||
transformed_file = {
|
||||
'id': file_data.get('id'),
|
||||
'sizeKB': file_data.get('sizeKB'),
|
||||
'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
|
||||
'type': file_data.get('type'),
|
||||
'downloadUrl': available_mirror.get('url') if available_mirror else file_data.get('downloadUrl'),
|
||||
'primary': True,
|
||||
'mirrors': file_data.get('mirrors', [])
|
||||
}
|
||||
|
||||
# Transform hash format
|
||||
if 'sha256' in file_data:
|
||||
transformed_file['hashes'] = {
|
||||
'SHA256': file_data['sha256'].upper()
|
||||
}
|
||||
|
||||
transformed_files.append(transformed_file)
|
||||
|
||||
result['files'] = transformed_files
|
||||
|
||||
# Add model information
|
||||
result['model'] = {
|
||||
'name': data.get('name'),
|
||||
'type': data.get('type'),
|
||||
'nsfw': data.get('is_nsfw', False),
|
||||
'description': data.get('description'),
|
||||
'tags': data.get('tags', [])
|
||||
}
|
||||
|
||||
result['creator'] = {
|
||||
'username': data.get('username', data.get('creator_username')),
|
||||
'image': ''
|
||||
}
|
||||
|
||||
# Add source identifier
|
||||
result['source'] = 'civarchive'
|
||||
result['is_deleted'] = data.get('deletedAt') is not None
|
||||
|
||||
return result
|
||||
|
||||
raw_id = version_data.get("id")
|
||||
if raw_id != version_id:
|
||||
logger.warning(
|
||||
"Requested version %s doesn't match default version %s for model %s",
|
||||
version_id,
|
||||
raw_id,
|
||||
model_id,
|
||||
)
|
||||
return None
|
||||
actual_model_id = version_data.get("modelId")
|
||||
if actual_model_id is not None and str(actual_model_id) != str(model_id):
|
||||
return await self.get_model_version(actual_model_id, version_id)
|
||||
|
||||
return self._transform_version(context, version_data, fallback_files)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivArchive model version via API {model_id}/{version_id}: {e}")
|
||||
return None
|
||||
@@ -312,7 +411,10 @@ class CivArchiveClient:
|
||||
Returns:
|
||||
Tuple[Optional[Dict], Optional[str]]: (version_data, error_message)
|
||||
"""
|
||||
return await self.get_model_version(1, version_id)
|
||||
version = await self.get_model_version(1, version_id)
|
||||
if version is None:
|
||||
return None, "Model not found"
|
||||
return version, None
|
||||
|
||||
async def get_model_by_url(self, url) -> Optional[Dict]:
|
||||
"""Get specific model version by parsing CivArchive HTML page (legacy method)
|
||||
@@ -380,7 +482,7 @@ class CivArchiveClient:
|
||||
'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
|
||||
'type': file_data.get('type'),
|
||||
'downloadUrl': available_mirror.get('url') if available_mirror else None,
|
||||
'primary': True,
|
||||
'primary': file_data.get('is_primary', False),
|
||||
'mirrors': file_data.get('mirrors', [])
|
||||
}
|
||||
|
||||
@@ -415,5 +517,5 @@ class CivArchiveClient:
|
||||
return version
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivArchive model version (scraping) {model_id}/{version_id}: {e}")
|
||||
logger.error(f"Error fetching CivArchive model version (scraping) {url}: {e}")
|
||||
return None
|
||||
|
||||
@@ -294,7 +294,7 @@ class DownloadManager:
|
||||
await progress_callback(0)
|
||||
|
||||
# 2. Get file information
|
||||
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
||||
file_info = next((f for f in version_info.get('files', []) if f.get('primary') and f.get('type') == 'Model'), None)
|
||||
if not file_info:
|
||||
return {'success': False, 'error': 'No primary file found in metadata'}
|
||||
mirrors = file_info.get('mirrors') or []
|
||||
|
||||
239
tests/services/test_civarchive_client.py
Normal file
239
tests/services/test_civarchive_client.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import copy
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services import civarchive_client as civarchive_client_module
|
||||
from py.services.civarchive_client import CivArchiveClient
|
||||
from py.services.model_metadata_provider import ModelMetadataProviderManager
|
||||
|
||||
|
||||
class DummyDownloader:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def make_request(self, method, url, use_auth=False, **kwargs):
|
||||
self.calls.append({"method": method, "url": url, "params": kwargs.get("params")})
|
||||
return True, {}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singletons():
|
||||
CivArchiveClient._instance = None
|
||||
ModelMetadataProviderManager._instance = None
|
||||
yield
|
||||
CivArchiveClient._instance = None
|
||||
ModelMetadataProviderManager._instance = None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def downloader(monkeypatch):
|
||||
instance = DummyDownloader()
|
||||
monkeypatch.setattr(civarchive_client_module, "get_downloader", AsyncMock(return_value=instance))
|
||||
return instance
|
||||
|
||||
|
||||
def _base_civarchive_payload(version_id=1976567, *, trigger="mxpln", nsfw_level=31):
|
||||
version_name = "v2.0" if version_id != 1976567 else "v1.0"
|
||||
file_sha = "e2b7a280d6539556f23f380b3f71e4e22bc4524445c4c96526e117c6005c6ad3"
|
||||
return {
|
||||
"data": {
|
||||
"id": 1746460,
|
||||
"name": "Mixplin Style [Illustrious]",
|
||||
"type": "LORA",
|
||||
"description": "description",
|
||||
"is_nsfw": True,
|
||||
"nsfw_level": nsfw_level,
|
||||
"tags": ["art", "style"],
|
||||
"creator_username": "Ty_Lee",
|
||||
"creator_name": "Ty_Lee",
|
||||
"creator_url": "/users/Ty_Lee",
|
||||
"version": {
|
||||
"id": version_id,
|
||||
"modelId": 1746460,
|
||||
"name": version_name,
|
||||
"baseModel": "Illustrious",
|
||||
"description": "version description",
|
||||
"downloadCount": 437,
|
||||
"ratingCount": 0,
|
||||
"rating": 0,
|
||||
"nsfw_level": nsfw_level,
|
||||
"trigger": [trigger],
|
||||
"files": [
|
||||
{
|
||||
"id": 1874043,
|
||||
"name": "mxpln-illustrious-ty_lee.safetensors",
|
||||
"type": "Model",
|
||||
"sizeKB": 223124.37109375,
|
||||
"downloadUrl": "https://civitai.com/api/download/models/1976567",
|
||||
"sha256": file_sha,
|
||||
"is_primary": False,
|
||||
"mirrors": [
|
||||
{
|
||||
"filename": "mxpln-illustrious-ty_lee.safetensors",
|
||||
"url": "https://civitai.com/api/download/models/1976567",
|
||||
"deletedAt": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
{
|
||||
"id": 86403595,
|
||||
"url": "https://img.genur.art/example.png",
|
||||
"nsfwLevel": 1,
|
||||
}
|
||||
],
|
||||
},
|
||||
"versions": [
|
||||
{"id": 2042594, "name": "v2.0"},
|
||||
{"id": 1976567, "name": "v1.0"},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async def test_get_model_by_hash_transforms_payload(downloader):
|
||||
payload = _base_civarchive_payload()
|
||||
|
||||
async def fake_make_request(method, url, use_auth=False, **kwargs):
|
||||
downloader.calls.append({"url": url, "params": kwargs.get("params")})
|
||||
if url.endswith("/sha256/abc"):
|
||||
return True, copy.deepcopy(payload)
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivArchiveClient.get_instance()
|
||||
|
||||
result, error = await client.get_model_by_hash("abc")
|
||||
|
||||
assert error is None
|
||||
assert result["id"] == 1976567
|
||||
assert result["nsfwLevel"] == 31
|
||||
assert result["trainedWords"] == ["mxpln"]
|
||||
assert result["stats"] == {"downloadCount": 437, "ratingCount": 0, "rating": 0}
|
||||
assert result["model"]["name"] == "Mixplin Style [Illustrious]"
|
||||
assert result["model"]["nsfw"] is True
|
||||
assert result["creator"]["username"] == "Ty_Lee"
|
||||
assert result["creator"]["image"] == ""
|
||||
file_meta = result["files"][0]
|
||||
assert file_meta["hashes"]["SHA256"] == "E2B7A280D6539556F23F380B3F71E4E22BC4524445C4C96526E117C6005C6AD3"
|
||||
assert file_meta["mirrors"][0]["url"] == "https://civitai.com/api/download/models/1976567"
|
||||
assert file_meta["primary"] is False
|
||||
assert result["source"] == "civarchive"
|
||||
assert result["images"][0]["url"] == "https://img.genur.art/example.png"
|
||||
|
||||
|
||||
async def test_get_model_versions_fetches_each_version(downloader):
|
||||
base_url = "https://civarchive.com/api/models/1746460"
|
||||
base_payload = _base_civarchive_payload(version_id=2042594, trigger="mxpln-new", nsfw_level=5)
|
||||
other_payload = _base_civarchive_payload()
|
||||
|
||||
responses = {
|
||||
(base_url, None): base_payload,
|
||||
(base_url, (("modelVersionId", "2042594"),)): base_payload,
|
||||
(base_url, (("modelVersionId", "1976567"),)): other_payload,
|
||||
}
|
||||
|
||||
async def fake_make_request(method, url, use_auth=False, **kwargs):
|
||||
params = kwargs.get("params")
|
||||
key = (url, tuple(sorted((params or {}).items())) if params else None)
|
||||
downloader.calls.append({"url": url, "params": params})
|
||||
if key in responses:
|
||||
return True, copy.deepcopy(responses[key])
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivArchiveClient.get_instance()
|
||||
|
||||
result = await client.get_model_versions("1746460")
|
||||
|
||||
assert result["name"] == "Mixplin Style [Illustrious]"
|
||||
assert result["type"] == "LORA"
|
||||
versions = result["modelVersions"]
|
||||
assert [version["id"] for version in versions] == [2042594, 1976567]
|
||||
assert versions[0]["trainedWords"] == ["mxpln-new"]
|
||||
assert versions[1]["trainedWords"] == ["mxpln"]
|
||||
assert versions[0]["nsfwLevel"] == 5
|
||||
assert versions[1]["nsfwLevel"] == 31
|
||||
assert any(call["params"] == {"modelVersionId": "2042594"} for call in downloader.calls)
|
||||
assert any(call["params"] == {"modelVersionId": "1976567"} for call in downloader.calls)
|
||||
|
||||
|
||||
async def test_get_model_version_redirects_to_actual_model_id(downloader):
|
||||
first_payload = _base_civarchive_payload()
|
||||
first_payload["data"]["version"]["modelId"] = 222
|
||||
|
||||
base_url_request = "https://civarchive.com/api/models/111"
|
||||
redirected_url_request = "https://civarchive.com/api/models/222"
|
||||
|
||||
async def fake_make_request(method, url, use_auth=False, **kwargs):
|
||||
downloader.calls.append({"url": url, "params": kwargs.get("params")})
|
||||
params = kwargs.get("params") or {}
|
||||
if url == base_url_request:
|
||||
return True, copy.deepcopy(first_payload)
|
||||
if url == redirected_url_request and params.get("modelVersionId") == "1976567":
|
||||
return True, copy.deepcopy(_base_civarchive_payload())
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivArchiveClient.get_instance()
|
||||
|
||||
result = await client.get_model_version(model_id=111, version_id=1976567)
|
||||
|
||||
assert result is not None
|
||||
assert result["model"]["name"] == "Mixplin Style [Illustrious]"
|
||||
assert len(downloader.calls) == 2
|
||||
assert downloader.calls[1]["url"] == redirected_url_request
|
||||
|
||||
|
||||
async def test_get_model_by_hash_uses_file_fallback(downloader, monkeypatch):
|
||||
file_only_payload = {
|
||||
"data": {
|
||||
"files": [
|
||||
{
|
||||
"model_id": 1746460,
|
||||
"model_version_id": 1976567,
|
||||
"source": "civitai",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
version_payload = _base_civarchive_payload()
|
||||
|
||||
async def fake_make_request(method, url, use_auth=False, **kwargs):
|
||||
downloader.calls.append({"url": url, "params": kwargs.get("params")})
|
||||
if "/sha256/" in url:
|
||||
return True, copy.deepcopy(file_only_payload)
|
||||
if "/models/1746460" in url:
|
||||
return True, copy.deepcopy(version_payload)
|
||||
return False, "unexpected"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivArchiveClient.get_instance()
|
||||
|
||||
result, error = await client.get_model_by_hash("fallback")
|
||||
|
||||
assert error is None
|
||||
assert result["id"] == 1976567
|
||||
assert result["model"]["name"] == "Mixplin Style [Illustrious]"
|
||||
assert any("/models/1746460" in call["url"] for call in downloader.calls)
|
||||
|
||||
|
||||
async def test_get_model_by_hash_handles_not_found(downloader):
|
||||
async def fake_make_request(method, url, use_auth=False, **kwargs):
|
||||
return False, "Resource not found"
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivArchiveClient.get_instance()
|
||||
|
||||
result, error = await client.get_model_by_hash("missing")
|
||||
|
||||
assert result is None
|
||||
assert error == "Model not found"
|
||||
Reference in New Issue
Block a user