fix(civitai): use red-only api host (#897)

This commit is contained in:
Will Miao
2026-04-16 12:06:34 +08:00
parent af2146f96c
commit 406d5fea6a
6 changed files with 145 additions and 150 deletions

View File

@@ -30,7 +30,7 @@ class CivitaiBaseModelService:
DEFAULT_CACHE_TTL = 7 * 24 * 60 * 60 DEFAULT_CACHE_TTL = 7 * 24 * 60 * 60
# Civitai API endpoint for enums # Civitai API endpoint for enums
CIVITAI_ENUMS_URL = "https://civitai.com/api/v1/enums" CIVITAI_ENUMS_URL = "https://civitai.red/api/v1/enums"
@classmethod @classmethod
async def get_instance(cls) -> CivitaiBaseModelService: async def get_instance(cls) -> CivitaiBaseModelService:

View File

@@ -9,7 +9,7 @@ from .model_metadata_provider import (
) )
from .downloader import get_downloader from .downloader import get_downloader
from .errors import RateLimitError, ResourceNotFoundError from .errors import RateLimitError, ResourceNotFoundError
from ..utils.civitai_utils import extract_civitai_page_host, resolve_license_payload from ..utils.civitai_utils import resolve_license_payload
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -39,24 +39,10 @@ class CivitaiClient:
return return
self._initialized = True self._initialized = True
self.base_url = "https://civitai.com/api/v1" self.base_url = "https://civitai.red/api/v1"
self._image_info_api_hosts = ("civitai.com", "civitai.red")
def _build_image_info_url(self, host: str, image_id: str) -> str: def _build_image_info_url(self, image_id: str) -> str:
return f"https://{host}/api/v1/images?imageId={image_id}&nsfw=X" return f"{self.base_url}/images?imageId={image_id}&nsfw=X"
def _resolve_image_info_hosts(self, source_url: str | None) -> List[str]:
preferred_host = extract_civitai_page_host(source_url)
if preferred_host in self._image_info_api_hosts:
return [
preferred_host,
*[
host
for host in self._image_info_api_hosts
if host != preferred_host
],
]
return list(self._image_info_api_hosts)
async def _make_request( async def _make_request(
self, self,
@@ -207,7 +193,9 @@ class CivitaiClient:
"""Get all versions of a model with local availability info""" """Get all versions of a model with local availability info"""
try: try:
success, result = await self._make_request( success, result = await self._make_request(
"GET", f"{self.base_url}/models/{model_id}", use_auth=True "GET",
f"{self.base_url}/models/{model_id}",
use_auth=True,
) )
if success: if success:
# Also return model type along with versions # Also return model type along with versions
@@ -363,7 +351,9 @@ class CivitaiClient:
async def _fetch_model_data(self, model_id: int) -> Optional[Dict]: async def _fetch_model_data(self, model_id: int) -> Optional[Dict]:
success, data = await self._make_request( success, data = await self._make_request(
"GET", f"{self.base_url}/models/{model_id}", use_auth=True "GET",
f"{self.base_url}/models/{model_id}",
use_auth=True,
) )
if success: if success:
return data return data
@@ -375,7 +365,9 @@ class CivitaiClient:
return None return None
success, version = await self._make_request( success, version = await self._make_request(
"GET", f"{self.base_url}/model-versions/{version_id}", use_auth=True "GET",
f"{self.base_url}/model-versions/{version_id}",
use_auth=True,
) )
if success: if success:
return version return version
@@ -388,7 +380,9 @@ class CivitaiClient:
return None return None
success, version = await self._make_request( success, version = await self._make_request(
"GET", f"{self.base_url}/model-versions/by-hash/{model_hash}", use_auth=True "GET",
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True,
) )
if success: if success:
return version return version
@@ -470,13 +464,11 @@ class CivitaiClient:
try: try:
url = f"{self.base_url}/model-versions/{version_id}" url = f"{self.base_url}/model-versions/{version_id}"
logger.debug(f"Resolving DNS for model version info: {url}") logger.debug("Resolving Civitai model version info: %s", url)
success, result = await self._make_request("GET", url, use_auth=True) success, result = await self._make_request("GET", url, use_auth=True)
if success: if success:
logger.debug( logger.debug("Successfully fetched model version info for: %s", version_id)
f"Successfully fetched model version info for: {version_id}"
)
self._remove_comfy_metadata(result) self._remove_comfy_metadata(result)
return result, None return result, None
@@ -503,99 +495,51 @@ class CivitaiClient:
Args: Args:
image_id: The Civitai image ID image_id: The Civitai image ID
source_url: Optional original image page URL used to prioritize source_url: Original image page URL. Accepted for caller compatibility;
``civitai.com`` vs ``civitai.red`` image lookups. API requests always target ``civitai.red``.
Returns: Returns:
Optional[Dict]: The image data or None if not found Optional[Dict]: The image data or None if not found
""" """
try: try:
requested_id = int(image_id) requested_id = int(image_id)
candidate_hosts = self._resolve_image_info_hosts(source_url) url = self._build_image_info_url(image_id)
last_error: Any = None success, result = await self._make_request("GET", url, use_auth=True)
logger.debug(
"Fetching image info for ID %s with host order %s",
image_id,
candidate_hosts,
)
for index, host in enumerate(candidate_hosts): if not success:
url = self._build_image_info_url(host, image_id) logger.error(
success, result = await self._make_request("GET", url, use_auth=True) "Failed to fetch image info for ID %s from civitai.red: %s",
image_id,
if not success: result,
last_error = result )
if index < len(candidate_hosts) - 1:
logger.warning(
"Failed to fetch image info for ID %s from %s: %s. Trying fallback host.",
image_id,
host,
result,
)
continue
logger.error(
"Failed to fetch image info for ID %s from %s: %s",
image_id,
host,
result,
)
return None
if result and "items" in result and isinstance(result["items"], list):
items = result["items"]
for item in items:
if isinstance(item, dict) and item.get("id") == requested_id:
logger.debug(
"Successfully fetched image info for ID %s from %s",
image_id,
host,
)
return item
returned_ids = [
item.get("id")
for item in items
if isinstance(item, dict) and "id" in item
]
if index < len(candidate_hosts) - 1:
logger.info(
"No matching image for requested ID %s from %s; trying fallback host. Returned %d item(s) with IDs: %s",
image_id,
host,
len(items),
returned_ids,
)
continue
logger.warning(
"CivitAI API returned no matching image for requested ID %s from %s. Returned %d item(s) with IDs: %s. This may indicate the image was deleted, hidden, or there is a database lag.",
image_id,
host,
len(items),
returned_ids,
)
return None
if index < len(candidate_hosts) - 1:
logger.info(
"No image found with ID %s from %s; trying fallback host",
image_id,
host,
)
continue
logger.warning("No image found with ID: %s", image_id)
return None return None
if last_error is not None: if result and "items" in result and isinstance(result["items"], list):
logger.error( items = result["items"]
"Failed to fetch image info for ID %s from all candidate hosts: %s",
for item in items:
if isinstance(item, dict) and item.get("id") == requested_id:
logger.debug(
"Successfully fetched image info for ID %s from civitai.red",
image_id,
)
return item
returned_ids = [
item.get("id")
for item in items
if isinstance(item, dict) and "id" in item
]
logger.warning(
"CivitAI API returned no matching image for requested ID %s from civitai.red. Returned %d item(s) with IDs: %s. This may indicate the image was deleted, hidden, or there is a database lag.",
image_id, image_id,
last_error, len(items),
returned_ids,
) )
return None
logger.warning("No image found with ID: %s", image_id)
return None return None
except RateLimitError: except RateLimitError:
raise raise
@@ -614,8 +558,12 @@ class CivitaiClient:
return None return None
try: try:
url = f"{self.base_url}/models?username={username}" success, result = await self._make_request(
success, result = await self._make_request("GET", url, use_auth=True) "GET",
f"{self.base_url}/models",
use_auth=True,
params={"username": username},
)
if not success: if not success:
logger.error("Failed to fetch models for %s: %s", username, result) logger.error("Failed to fetch models for %s: %s", username, result)

View File

@@ -31,6 +31,11 @@ import tempfile
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CIVITAI_DOWNLOAD_URL_PREFIXES = (
"https://civitai.com/api/download/",
"https://civitai.red/api/download/",
)
class DownloadManager: class DownloadManager:
_instance = None _instance = None
@@ -647,12 +652,12 @@ class DownloadManager:
civitai_urls = [ civitai_urls = [
u u
for u in download_urls for u in download_urls
if u.startswith("https://civitai.com/api/download/") if u.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES)
] ]
non_civitai_urls = [ non_civitai_urls = [
u u
for u in download_urls for u in download_urls
if not u.startswith("https://civitai.com/api/download/") if not u.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES)
] ]
download_urls = non_civitai_urls + civitai_urls download_urls = non_civitai_urls + civitai_urls
else: else:
@@ -1133,7 +1138,7 @@ class DownloadManager:
pause_control.update_stall_timeout(downloader.stall_timeout) pause_control.update_stall_timeout(downloader.stall_timeout)
last_error = None last_error = None
for download_url in download_urls: for download_url in download_urls:
use_auth = download_url.startswith("https://civitai.com/api/download/") use_auth = download_url.startswith(CIVITAI_DOWNLOAD_URL_PREFIXES)
download_kwargs = { download_kwargs = {
"progress_callback": lambda progress, snapshot=None: ( "progress_callback": lambda progress, snapshot=None: (
self._handle_download_progress( self._handle_download_progress(

View File

@@ -94,7 +94,7 @@ class DummyDoctorScanner:
class DummyCivitaiClient: class DummyCivitaiClient:
def __init__(self, *, success=True, result=None): def __init__(self, *, success=True, result=None):
self.base_url = 'https://civitai.com/api/v1' self.base_url = 'https://civitai.red/api/v1'
self._success = success self._success = success
self._result = result if result is not None else {'items': []} self._result = result if result is not None else {'items': []}

View File

@@ -62,6 +62,12 @@ async def test_download_file_uses_downloader(tmp_path, downloader):
assert downloader.download_calls[0]["use_auth"] is True assert downloader.download_calls[0]["use_auth"] is True
async def test_client_defaults_to_red_api_host(downloader):
client = await CivitaiClient.get_instance()
assert client.base_url == "https://civitai.red/api/v1"
async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader): async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader):
version_payload = { version_payload = {
"modelId": 123, "modelId": 123,
@@ -551,36 +557,12 @@ async def test_get_image_info_prefers_red_host_for_red_source(monkeypatch, downl
] ]
async def test_get_image_info_falls_back_from_com_to_red(monkeypatch, downloader): async def test_get_image_info_uses_red_host_even_for_red_source(monkeypatch, downloader):
requested_urls = [] requested_urls = []
async def fake_make_request(method, url, use_auth=True, **kwargs): async def fake_make_request(method, url, use_auth=True, **kwargs):
requested_urls.append(url) requested_urls.append(url)
if url.startswith("https://civitai.com/"): return True, {"items": [{"id": 124950237, "name": "target"}]}
return True, {"items": []}
return True, {"items": [{"id": 124950237, "name": "fallback"}]}
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_image_info("124950237")
assert result == {"id": 124950237, "name": "fallback"}
assert requested_urls == [
"https://civitai.com/api/v1/images?imageId=124950237&nsfw=X",
"https://civitai.red/api/v1/images?imageId=124950237&nsfw=X",
]
async def test_get_image_info_falls_back_from_red_to_com(monkeypatch, downloader):
requested_urls = []
async def fake_make_request(method, url, use_auth=True, **kwargs):
requested_urls.append(url)
if url.startswith("https://civitai.red/"):
return True, {"items": []}
return True, {"items": [{"id": 124950237, "name": "fallback"}]}
downloader.make_request = fake_make_request downloader.make_request = fake_make_request
@@ -590,21 +572,18 @@ async def test_get_image_info_falls_back_from_red_to_com(monkeypatch, downloader
"124950237", source_url="https://civitai.red/images/124950237" "124950237", source_url="https://civitai.red/images/124950237"
) )
assert result == {"id": 124950237, "name": "fallback"} assert result == {"id": 124950237, "name": "target"}
assert requested_urls == [ assert requested_urls == [
"https://civitai.red/api/v1/images?imageId=124950237&nsfw=X", "https://civitai.red/api/v1/images?imageId=124950237&nsfw=X",
"https://civitai.com/api/v1/images?imageId=124950237&nsfw=X",
] ]
async def test_get_image_info_falls_back_after_request_failure(monkeypatch, downloader): async def test_get_image_info_does_not_fall_back_after_request_failure(monkeypatch, downloader):
requested_urls = [] requested_urls = []
async def fake_make_request(method, url, use_auth=True, **kwargs): async def fake_make_request(method, url, use_auth=True, **kwargs):
requested_urls.append(url) requested_urls.append(url)
if url.startswith("https://civitai.red/"): return False, "403 forbidden"
return False, "403 forbidden"
return True, {"items": [{"id": 124950237, "name": "fallback"}]}
downloader.make_request = fake_make_request downloader.make_request = fake_make_request
@@ -614,10 +593,9 @@ async def test_get_image_info_falls_back_after_request_failure(monkeypatch, down
"124950237", source_url="https://civitai.red/images/124950237" "124950237", source_url="https://civitai.red/images/124950237"
) )
assert result == {"id": 124950237, "name": "fallback"} assert result is None
assert requested_urls == [ assert requested_urls == [
"https://civitai.red/api/v1/images?imageId=124950237&nsfw=X", "https://civitai.red/api/v1/images?imageId=124950237&nsfw=X",
"https://civitai.com/api/v1/images?imageId=124950237&nsfw=X",
] ]

View File

@@ -7,7 +7,10 @@ from unittest.mock import AsyncMock
import pytest import pytest
from py.services.download_manager import DownloadManager from py.services.download_manager import (
CIVITAI_DOWNLOAD_URL_PREFIXES,
DownloadManager,
)
from py.services import download_manager from py.services import download_manager
from py.services.service_registry import ServiceRegistry from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import SettingsManager, get_settings_manager from py.services.settings_manager import SettingsManager, get_settings_manager
@@ -309,6 +312,67 @@ async def test_execute_download_respects_blur_setting(monkeypatch, tmp_path):
assert stored_preview and stored_preview.endswith(".jpeg") assert stored_preview and stored_preview.endswith(".jpeg")
@pytest.mark.asyncio
async def test_execute_download_uses_auth_for_red_civitai_downloads(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
self.preview_nsfw_level = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
metadata = DummyMetadata(target_path)
recorded_use_auth = []
class DummyDownloader:
stall_timeout = None
async def download_file(self, url, path, progress_callback=None, use_auth=None, **_kwargs):
recorded_use_auth.append((url, use_auth))
Path(path).write_bytes(b"model")
return True, None
monkeypatch.setattr(
download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())
)
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(
DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)
)
result = await manager._execute_download(
download_urls=["https://civitai.red/api/download/models/119514"],
save_dir=str(save_dir),
metadata=metadata,
version_info={"images": []},
relative_path="",
progress_callback=None,
model_type="lora",
download_id=None,
)
assert result == {"success": True}
assert recorded_use_auth == [("https://civitai.red/api/download/models/119514", True)]
assert "https://civitai.red/api/download/".startswith(CIVITAI_DOWNLOAD_URL_PREFIXES)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_civarchive_source_uses_civarchive_provider( async def test_civarchive_source_uses_civarchive_provider(
monkeypatch, scanners, tmp_path monkeypatch, scanners, tmp_path