feat(civarchive_client): enhance request handling and context parsing

Introduce `_request_json` method for async JSON requests and improved error handling. Add static methods `_normalize_payload`, `_split_context`, `_ensure_list`, and `_build_model_info` to parse and normalize API responses. These changes improve the robustness of the CivArchiveClient by ensuring consistent data structures and handling potential API response issues gracefully.
This commit is contained in:
Will Miao
2025-10-11 13:07:29 +08:00
parent 7d560bf07a
commit 1f60160e8b
3 changed files with 561 additions and 220 deletions

View File

@@ -2,6 +2,7 @@ import os
import json import json
import logging import logging
import asyncio import asyncio
from copy import deepcopy
from typing import Optional, Dict, Tuple, List from typing import Optional, Dict, Tuple, List
from .model_metadata_provider import CivArchiveModelMetadataProvider, ModelMetadataProviderManager from .model_metadata_provider import CivArchiveModelMetadataProvider, ModelMetadataProviderManager
from .downloader import get_downloader from .downloader import get_downloader
@@ -49,115 +50,256 @@ class CivArchiveClient:
self.base_url = "https://civarchive.com/api" self.base_url = "https://civarchive.com/api"
async def _request_json(
self,
path: str,
params: Optional[Dict[str, str]] = None
) -> Tuple[Optional[Dict], Optional[str]]:
"""Call CivArchive API and return JSON payload"""
downloader = await get_downloader()
kwargs: Dict[str, Dict[str, str]] = {}
if params:
safe_params = {str(key): str(value) for key, value in params.items() if value is not None}
if safe_params:
kwargs["params"] = safe_params
success, payload = await downloader.make_request(
"GET",
f"{self.base_url}{path}",
use_auth=False,
**kwargs
)
if not success:
error = payload if isinstance(payload, str) else "Request failed"
return None, error
if not isinstance(payload, dict):
return None, "Invalid response structure"
return payload, None
@staticmethod
def _normalize_payload(payload: Dict) -> Dict:
"""Unwrap CivArchive responses that wrap content under a data key"""
if not isinstance(payload, dict):
return {}
data = payload.get("data")
if isinstance(data, dict):
return data
return payload
@staticmethod
def _split_context(payload: Dict) -> Tuple[Dict, Dict, List[Dict]]:
"""Separate version payload from surrounding model context"""
data = CivArchiveClient._normalize_payload(payload)
context: Dict = {}
fallback_files: List[Dict] = []
version: Dict = {}
for key, value in data.items():
if key in {"version", "model"}:
continue
context[key] = value
if isinstance(data.get("version"), dict):
version = data["version"]
model_block = data.get("model")
if isinstance(model_block, dict):
for key, value in model_block.items():
if key == "version":
if not version and isinstance(value, dict):
version = value
continue
context.setdefault(key, value)
fallback_files = fallback_files or model_block.get("files") or []
fallback_files = fallback_files or data.get("files") or []
return context, version, fallback_files
@staticmethod
def _ensure_list(value) -> List:
if isinstance(value, list):
return value
if value is None:
return []
return [value]
@staticmethod
def _build_model_info(context: Dict) -> Dict:
tags = context.get("tags")
if not isinstance(tags, list):
tags = list(tags) if isinstance(tags, (set, tuple)) else ([] if tags is None else [tags])
return {
"name": context.get("name"),
"type": context.get("type"),
"nsfw": bool(context.get("is_nsfw", context.get("nsfw", False))),
"description": context.get("description"),
"tags": tags,
}
@staticmethod
def _build_creator_info(context: Dict) -> Dict:
username = context.get("creator_username") or context.get("username") or ""
image = context.get("creator_image") or context.get("creator_avatar") or ""
creator: Dict[str, Optional[str]] = {
"username": username,
"image": image,
}
if context.get("creator_name"):
creator["name"] = context["creator_name"]
if context.get("creator_url"):
creator["url"] = context["creator_url"]
return creator
@staticmethod
def _transform_file_entry(file_data: Dict) -> Dict:
mirrors = file_data.get("mirrors") or []
if not isinstance(mirrors, list):
mirrors = [mirrors]
available_mirror = next(
(mirror for mirror in mirrors if isinstance(mirror, dict) and mirror.get("deletedAt") is None),
None
)
download_url = file_data.get("downloadUrl")
if not download_url and available_mirror:
download_url = available_mirror.get("url")
name = file_data.get("name")
if not name and available_mirror:
name = available_mirror.get("filename")
transformed: Dict = {
"id": file_data.get("id"),
"sizeKB": file_data.get("sizeKB"),
"name": name,
"type": file_data.get("type"),
"downloadUrl": download_url,
"primary": True,
# TODO: for some reason is_primary is false in CivArchive response, need to figure this out,
# "primary": bool(file_data.get("is_primary", file_data.get("primary", False))),
"mirrors": mirrors,
}
sha256 = file_data.get("sha256")
if sha256:
transformed["hashes"] = {"SHA256": str(sha256).upper()}
elif isinstance(file_data.get("hashes"), dict):
transformed["hashes"] = file_data["hashes"]
if "metadata" in file_data:
transformed["metadata"] = file_data["metadata"]
if file_data.get("modelVersionId") is not None:
transformed["modelVersionId"] = file_data.get("modelVersionId")
elif file_data.get("model_version_id") is not None:
transformed["modelVersionId"] = file_data.get("model_version_id")
if file_data.get("modelId") is not None:
transformed["modelId"] = file_data.get("modelId")
elif file_data.get("model_id") is not None:
transformed["modelId"] = file_data.get("model_id")
return transformed
def _transform_files(
self,
files: Optional[List[Dict]],
fallback_files: Optional[List[Dict]] = None
) -> List[Dict]:
candidates: List[Dict] = []
if isinstance(files, list) and files:
candidates = files
elif isinstance(fallback_files, list):
candidates = fallback_files
transformed_files: List[Dict] = []
for file_data in candidates:
if isinstance(file_data, dict):
transformed_files.append(self._transform_file_entry(file_data))
return transformed_files
def _transform_version(
self,
context: Dict,
version: Dict,
fallback_files: Optional[List[Dict]] = None
) -> Optional[Dict]:
if not version:
return None
version_copy = deepcopy(version)
version_copy.pop("model", None)
version_copy.pop("creator", None)
if "trigger" in version_copy:
triggers = version_copy.pop("trigger")
if isinstance(triggers, list):
version_copy["trainedWords"] = triggers
elif triggers is None:
version_copy["trainedWords"] = []
else:
version_copy["trainedWords"] = [triggers]
if "trainedWords" in version_copy and isinstance(version_copy["trainedWords"], str):
version_copy["trainedWords"] = [version_copy["trainedWords"]]
if "nsfw_level" in version_copy:
version_copy["nsfwLevel"] = version_copy.pop("nsfw_level")
elif "nsfwLevel" not in version_copy and context.get("nsfw_level") is not None:
version_copy["nsfwLevel"] = context.get("nsfw_level")
stats_keys = ["downloadCount", "ratingCount", "rating"]
stats = {key: version_copy.pop(key) for key in stats_keys if key in version_copy}
if stats:
version_copy["stats"] = stats
version_copy["files"] = self._transform_files(version_copy.get("files"), fallback_files)
version_copy["images"] = self._ensure_list(version_copy.get("images"))
version_copy["model"] = self._build_model_info(context)
version_copy["creator"] = self._build_creator_info(context)
version_copy["source"] = "civarchive"
version_copy["is_deleted"] = bool(context.get("deletedAt")) or bool(version.get("deletedAt"))
return version_copy
async def _resolve_version_from_files(self, payload: Dict) -> Optional[Dict]:
"""Fallback to fetch version data when only file metadata is available"""
data = self._normalize_payload(payload)
files = data.get("files") or payload.get("files") or []
if not isinstance(files, list):
files = [files]
for file_data in files:
if not isinstance(file_data, dict):
continue
model_id = file_data.get("model_id") or file_data.get("modelId")
version_id = file_data.get("model_version_id") or file_data.get("modelVersionId")
if model_id is None or version_id is None:
continue
resolved = await self.get_model_version(model_id, version_id)
if resolved:
return resolved
return None
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]: async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by SHA256 hash value using CivArchive API""" """Find model by SHA256 hash value using CivArchive API"""
if "/" in model_hash:
metadata = await self.get_model_by_url(model_hash)
if metadata:
return metadata, None
else:
return None, f"Error fetching url: {model_hash}"
try: try:
# CivArchive only supports SHA256 hashes payload, error = await self._request_json(f"/sha256/{model_hash.lower()}")
url = f"{self.base_url}/sha256/{model_hash.lower()}" if error:
if "not found" in error.lower():
downloader = await get_downloader()
session = await downloader.session
async with session.get(url) as response:
if response.status != 200:
if response.status == 404:
return None, "Model not found" return None, "Model not found"
return None, f"HTTP {response.status}" return None, error
data = await response.json() context, version_data, fallback_files = self._split_context(payload)
transformed = self._transform_version(context, version_data, fallback_files)
if transformed:
return transformed, None
# Extract the model and version data from CivArchive structure resolved = await self._resolve_version_from_files(payload)
model_data = data.get('model', {}) if resolved:
version_data = model_data.get('version', {}) return resolved, None
files_data = data.get('files', {})
if not version_data: logger.error("Error fetching version of CivArchive model by hash %s", model_hash[:10])
if files_data:
logger.error(f"{data}")
# sometimes CivArc returns ONLY file info... but it can then be used to get the rest of the info...
# actually as of now (10/25), api broke and ONLY returns 'files' info...
for file_data in files_data:
logger.error(f"{file_data}")
if file_data["source"] == "civitai":
api_data = await self.get_model_version(file_data["model_id"], file_data["model_version_id"])
logger.error(f"{api_data}")
logger.error(f"found CivArchive model by hash {model_hash[:10]}")
return api_data, None
else:
logger.error(f"Error fetching version of CivArchive model by hash {model_hash[:10]}")
return None, "No version data found" return None, "No version data found"
# Transform to match expected format
result = version_data.copy()
# Add model information
result['model'] = {
'name': model_data.get('name'),
'type': model_data.get('type'),
'nsfw': model_data.get('nsfw', False),
'description': model_data.get('description'),
'tags': model_data.get('tags', [])
}
# Add creator information
result['creator'] = {
'username': model_data.get('username', model_data.get('creator_username')),
'image': ''
}
# Rename trigger to trainedWords for consistency
if 'trigger' in result:
result['trainedWords'] = result.pop('trigger')
# Transform stats
if 'downloadCount' in result and 'ratingCount' in result and 'rating' in result:
result['stats'] = {
'downloadCount': result.pop('downloadCount'),
'ratingCount': result.pop('ratingCount'),
'rating': result.pop('rating')
}
# Transform files to match expected format
if 'files' in result:
transformed_files = []
for file_data in result['files']:
# Find first available mirror
available_mirror = None
for mirror in file_data.get('mirrors', []):
if mirror.get('deletedAt') is None:
available_mirror = mirror
break
transformed_file = {
'id': file_data.get('id'),
'sizeKB': file_data.get('sizeKB'),
'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
'type': file_data.get('type'),
'downloadUrl': available_mirror.get('url') if available_mirror else file_data.get('downloadUrl'),
'primary': True,
'mirrors': file_data.get('mirrors', [])
}
# Transform hash format
if 'sha256' in file_data:
transformed_file['hashes'] = {
'SHA256': file_data['sha256'].upper()
}
transformed_files.append(transformed_file)
result['files'] = transformed_files
# Add source identifier
result['source'] = 'civarchive'
return result, None
except Exception as e: except Exception as e:
logger.error(f"Error fetching CivArchive model by hash {model_hash[:10]}: {e}") logger.error(f"Error fetching CivArchive model by hash {model_hash[:10]}: {e}")
return None, str(e) return None, str(e)
@@ -165,24 +307,47 @@ class CivArchiveClient:
async def get_model_versions(self, model_id: str) -> Optional[Dict]: async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Get all versions of a model using CivArchive API""" """Get all versions of a model using CivArchive API"""
try: try:
url = f"{self.base_url}/models/{model_id}" payload, error = await self._request_json(f"/models/{model_id}")
if error or payload is None:
downloader = await get_downloader() if error and "not found" in error.lower():
session = await downloader.session return None
async with session.get(url) as response: logger.error(f"Error fetching CivArchive model versions for {model_id}: {error}")
if response.status != 200:
return None return None
data = await response.json() data = self._normalize_payload(payload)
context, version_data, fallback_files = self._split_context(payload)
# Extract versions list versions_meta = data.get("versions") or []
versions = data.get('versions', []) transformed_versions: List[Dict] = []
for meta in versions_meta:
if not isinstance(meta, dict):
continue
version_id = meta.get("id")
if version_id is None:
continue
target_model_id = meta.get("modelId") or model_id
version = await self.get_model_version(target_model_id, version_id)
if version:
transformed_versions.append(version)
# Ensure the primary version is included even if versions list was empty
primary_version = self._transform_version(context, version_data, fallback_files)
if primary_version:
transformed_versions.insert(0, primary_version)
ordered_versions: List[Dict] = []
seen_ids = set()
for version in transformed_versions:
version_id = version.get("id")
if version_id in seen_ids:
continue
seen_ids.add(version_id)
ordered_versions.append(version)
# Return in format similar to Civitai
return { return {
'modelVersions': versions, "modelVersions": ordered_versions,
'type': data.get('type', ''), "type": context.get("type", ""),
'name': data.get('name', '') "name": context.get("name", ""),
} }
except Exception as e: except Exception as e:
@@ -203,100 +368,34 @@ class CivArchiveClient:
return None return None
try: try:
if version_id is not None: params = {"modelVersionId": version_id} if version_id is not None else None
url = f"{self.base_url}/models/{model_id}?modelVersionId={version_id}" payload, error = await self._request_json(f"/models/{model_id}", params=params)
else: if error or payload is None:
url = f"{self.base_url}/models/{model_id}" if error and "not found" in error.lower():
return None
downloader = await get_downloader() logger.error(f"Error fetching CivArchive model version via API {model_id}/{version_id}: {error}")
session = await downloader.session
async with session.get(url) as response:
if response.status != 200:
return None return None
data = await response.json() context, version_data, fallback_files = self._split_context(payload)
# Get the version data - CivArchive returns the latest/default version in 'version' field if not version_data:
version_data = data.get('version', {}) return await self._resolve_version_from_files(payload)
versions = data.get('versions', {})
# If version_id is specified, check if it matches
if version_id is not None: if version_id is not None:
if version_data.get('id') != version_id: raw_id = version_data.get("id")
# Version mismatch - would need to iterate through versions or make another call if raw_id != version_id:
# For now, return None as CivArchive API doesn't provide easy version filtering logger.warning(
logger.warning(f"Requested version {version_id} doesn't match default version {version_data.get('id')} for model {model_id}") "Requested version %s doesn't match default version %s for model %s",
version_id,
raw_id,
model_id,
)
return None return None
if version_data.get('modelId') != model_id: actual_model_id = version_data.get("modelId")
# you can pass ANY model id, and a version number, and get the CORRECT model id from this... if actual_model_id is not None and str(actual_model_id) != str(model_id):
# so recall the api with the correct info now return await self.get_model_version(actual_model_id, version_id)
return await self.get_model_version(version_data.get('modelId'), version_id)
# Transform to expected format return self._transform_version(context, version_data, fallback_files)
result = version_data.copy()
# Restructure stats
if 'downloadCount' in result and 'ratingCount' in result and 'rating' in result:
result['stats'] = {
'downloadCount': result.pop('downloadCount'),
'ratingCount': result.pop('ratingCount'),
'rating': result.pop('rating')
}
# Rename trigger to trainedWords
if 'trigger' in result:
result['trainedWords'] = result.pop('trigger')
# Transform files data
if 'files' in result:
transformed_files = []
for file_data in result['files']:
# Find first available mirror
available_mirror = None
for mirror in file_data.get('mirrors', []):
if mirror.get('deletedAt') is None:
available_mirror = mirror
break
transformed_file = {
'id': file_data.get('id'),
'sizeKB': file_data.get('sizeKB'),
'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
'type': file_data.get('type'),
'downloadUrl': available_mirror.get('url') if available_mirror else file_data.get('downloadUrl'),
'primary': True,
'mirrors': file_data.get('mirrors', [])
}
# Transform hash format
if 'sha256' in file_data:
transformed_file['hashes'] = {
'SHA256': file_data['sha256'].upper()
}
transformed_files.append(transformed_file)
result['files'] = transformed_files
# Add model information
result['model'] = {
'name': data.get('name'),
'type': data.get('type'),
'nsfw': data.get('is_nsfw', False),
'description': data.get('description'),
'tags': data.get('tags', [])
}
result['creator'] = {
'username': data.get('username', data.get('creator_username')),
'image': ''
}
# Add source identifier
result['source'] = 'civarchive'
result['is_deleted'] = data.get('deletedAt') is not None
return result
except Exception as e: except Exception as e:
logger.error(f"Error fetching CivArchive model version via API {model_id}/{version_id}: {e}") logger.error(f"Error fetching CivArchive model version via API {model_id}/{version_id}: {e}")
@@ -312,7 +411,10 @@ class CivArchiveClient:
Returns: Returns:
Tuple[Optional[Dict], Optional[str]]: (version_data, error_message) Tuple[Optional[Dict], Optional[str]]: (version_data, error_message)
""" """
return await self.get_model_version(1, version_id) version = await self.get_model_version(1, version_id)
if version is None:
return None, "Model not found"
return version, None
async def get_model_by_url(self, url) -> Optional[Dict]: async def get_model_by_url(self, url) -> Optional[Dict]:
"""Get specific model version by parsing CivArchive HTML page (legacy method) """Get specific model version by parsing CivArchive HTML page (legacy method)
@@ -380,7 +482,7 @@ class CivArchiveClient:
'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'), 'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
'type': file_data.get('type'), 'type': file_data.get('type'),
'downloadUrl': available_mirror.get('url') if available_mirror else None, 'downloadUrl': available_mirror.get('url') if available_mirror else None,
'primary': True, 'primary': file_data.get('is_primary', False),
'mirrors': file_data.get('mirrors', []) 'mirrors': file_data.get('mirrors', [])
} }
@@ -415,5 +517,5 @@ class CivArchiveClient:
return version return version
except Exception as e: except Exception as e:
logger.error(f"Error fetching CivArchive model version (scraping) {model_id}/{version_id}: {e}") logger.error(f"Error fetching CivArchive model version (scraping) {url}: {e}")
return None return None

View File

@@ -294,7 +294,7 @@ class DownloadManager:
await progress_callback(0) await progress_callback(0)
# 2. Get file information # 2. Get file information
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) file_info = next((f for f in version_info.get('files', []) if f.get('primary') and f.get('type') == 'Model'), None)
if not file_info: if not file_info:
return {'success': False, 'error': 'No primary file found in metadata'} return {'success': False, 'error': 'No primary file found in metadata'}
mirrors = file_info.get('mirrors') or [] mirrors = file_info.get('mirrors') or []

View File

@@ -0,0 +1,239 @@
import copy
from unittest.mock import AsyncMock
import pytest
from py.services import civarchive_client as civarchive_client_module
from py.services.civarchive_client import CivArchiveClient
from py.services.model_metadata_provider import ModelMetadataProviderManager
class DummyDownloader:
def __init__(self):
self.calls = []
async def make_request(self, method, url, use_auth=False, **kwargs):
self.calls.append({"method": method, "url": url, "params": kwargs.get("params")})
return True, {}
@pytest.fixture(autouse=True)
def reset_singletons():
CivArchiveClient._instance = None
ModelMetadataProviderManager._instance = None
yield
CivArchiveClient._instance = None
ModelMetadataProviderManager._instance = None
@pytest.fixture
def downloader(monkeypatch):
instance = DummyDownloader()
monkeypatch.setattr(civarchive_client_module, "get_downloader", AsyncMock(return_value=instance))
return instance
def _base_civarchive_payload(version_id=1976567, *, trigger="mxpln", nsfw_level=31):
version_name = "v2.0" if version_id != 1976567 else "v1.0"
file_sha = "e2b7a280d6539556f23f380b3f71e4e22bc4524445c4c96526e117c6005c6ad3"
return {
"data": {
"id": 1746460,
"name": "Mixplin Style [Illustrious]",
"type": "LORA",
"description": "description",
"is_nsfw": True,
"nsfw_level": nsfw_level,
"tags": ["art", "style"],
"creator_username": "Ty_Lee",
"creator_name": "Ty_Lee",
"creator_url": "/users/Ty_Lee",
"version": {
"id": version_id,
"modelId": 1746460,
"name": version_name,
"baseModel": "Illustrious",
"description": "version description",
"downloadCount": 437,
"ratingCount": 0,
"rating": 0,
"nsfw_level": nsfw_level,
"trigger": [trigger],
"files": [
{
"id": 1874043,
"name": "mxpln-illustrious-ty_lee.safetensors",
"type": "Model",
"sizeKB": 223124.37109375,
"downloadUrl": "https://civitai.com/api/download/models/1976567",
"sha256": file_sha,
"is_primary": False,
"mirrors": [
{
"filename": "mxpln-illustrious-ty_lee.safetensors",
"url": "https://civitai.com/api/download/models/1976567",
"deletedAt": None,
}
],
}
],
"images": [
{
"id": 86403595,
"url": "https://img.genur.art/example.png",
"nsfwLevel": 1,
}
],
},
"versions": [
{"id": 2042594, "name": "v2.0"},
{"id": 1976567, "name": "v1.0"},
],
}
}
async def test_get_model_by_hash_transforms_payload(downloader):
payload = _base_civarchive_payload()
async def fake_make_request(method, url, use_auth=False, **kwargs):
downloader.calls.append({"url": url, "params": kwargs.get("params")})
if url.endswith("/sha256/abc"):
return True, copy.deepcopy(payload)
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivArchiveClient.get_instance()
result, error = await client.get_model_by_hash("abc")
assert error is None
assert result["id"] == 1976567
assert result["nsfwLevel"] == 31
assert result["trainedWords"] == ["mxpln"]
assert result["stats"] == {"downloadCount": 437, "ratingCount": 0, "rating": 0}
assert result["model"]["name"] == "Mixplin Style [Illustrious]"
assert result["model"]["nsfw"] is True
assert result["creator"]["username"] == "Ty_Lee"
assert result["creator"]["image"] == ""
file_meta = result["files"][0]
assert file_meta["hashes"]["SHA256"] == "E2B7A280D6539556F23F380B3F71E4E22BC4524445C4C96526E117C6005C6AD3"
assert file_meta["mirrors"][0]["url"] == "https://civitai.com/api/download/models/1976567"
assert file_meta["primary"] is False
assert result["source"] == "civarchive"
assert result["images"][0]["url"] == "https://img.genur.art/example.png"
async def test_get_model_versions_fetches_each_version(downloader):
base_url = "https://civarchive.com/api/models/1746460"
base_payload = _base_civarchive_payload(version_id=2042594, trigger="mxpln-new", nsfw_level=5)
other_payload = _base_civarchive_payload()
responses = {
(base_url, None): base_payload,
(base_url, (("modelVersionId", "2042594"),)): base_payload,
(base_url, (("modelVersionId", "1976567"),)): other_payload,
}
async def fake_make_request(method, url, use_auth=False, **kwargs):
params = kwargs.get("params")
key = (url, tuple(sorted((params or {}).items())) if params else None)
downloader.calls.append({"url": url, "params": params})
if key in responses:
return True, copy.deepcopy(responses[key])
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivArchiveClient.get_instance()
result = await client.get_model_versions("1746460")
assert result["name"] == "Mixplin Style [Illustrious]"
assert result["type"] == "LORA"
versions = result["modelVersions"]
assert [version["id"] for version in versions] == [2042594, 1976567]
assert versions[0]["trainedWords"] == ["mxpln-new"]
assert versions[1]["trainedWords"] == ["mxpln"]
assert versions[0]["nsfwLevel"] == 5
assert versions[1]["nsfwLevel"] == 31
assert any(call["params"] == {"modelVersionId": "2042594"} for call in downloader.calls)
assert any(call["params"] == {"modelVersionId": "1976567"} for call in downloader.calls)
async def test_get_model_version_redirects_to_actual_model_id(downloader):
first_payload = _base_civarchive_payload()
first_payload["data"]["version"]["modelId"] = 222
base_url_request = "https://civarchive.com/api/models/111"
redirected_url_request = "https://civarchive.com/api/models/222"
async def fake_make_request(method, url, use_auth=False, **kwargs):
downloader.calls.append({"url": url, "params": kwargs.get("params")})
params = kwargs.get("params") or {}
if url == base_url_request:
return True, copy.deepcopy(first_payload)
if url == redirected_url_request and params.get("modelVersionId") == "1976567":
return True, copy.deepcopy(_base_civarchive_payload())
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivArchiveClient.get_instance()
result = await client.get_model_version(model_id=111, version_id=1976567)
assert result is not None
assert result["model"]["name"] == "Mixplin Style [Illustrious]"
assert len(downloader.calls) == 2
assert downloader.calls[1]["url"] == redirected_url_request
async def test_get_model_by_hash_uses_file_fallback(downloader, monkeypatch):
file_only_payload = {
"data": {
"files": [
{
"model_id": 1746460,
"model_version_id": 1976567,
"source": "civitai",
}
]
}
}
version_payload = _base_civarchive_payload()
async def fake_make_request(method, url, use_auth=False, **kwargs):
downloader.calls.append({"url": url, "params": kwargs.get("params")})
if "/sha256/" in url:
return True, copy.deepcopy(file_only_payload)
if "/models/1746460" in url:
return True, copy.deepcopy(version_payload)
return False, "unexpected"
downloader.make_request = fake_make_request
client = await CivArchiveClient.get_instance()
result, error = await client.get_model_by_hash("fallback")
assert error is None
assert result["id"] == 1976567
assert result["model"]["name"] == "Mixplin Style [Illustrious]"
assert any("/models/1746460" in call["url"] for call in downloader.calls)
async def test_get_model_by_hash_handles_not_found(downloader):
async def fake_make_request(method, url, use_auth=False, **kwargs):
return False, "Resource not found"
downloader.make_request = fake_make_request
client = await CivArchiveClient.get_instance()
result, error = await client.get_model_by_hash("missing")
assert result is None
assert error == "Model not found"