mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 08:26:45 -03:00
fix(network): add offline cooldown guard for remote metadata requests
This commit is contained in:
@@ -16,6 +16,10 @@ import jinja2
|
||||
|
||||
from ...config import config
|
||||
from ...services.download_coordinator import DownloadCoordinator
|
||||
from ...services.connectivity_guard import (
|
||||
OFFLINE_FRIENDLY_MESSAGE,
|
||||
is_expected_offline_error,
|
||||
)
|
||||
from ...services.metadata_sync_service import MetadataSyncService
|
||||
from ...services.model_file_service import ModelMoveService
|
||||
from ...services.preview_asset_service import PreviewAssetService
|
||||
@@ -504,6 +508,11 @@ class ModelManagementHandler:
|
||||
formatted_metadata = await self._service.format_response(model_data)
|
||||
return web.json_response({"success": True, "metadata": formatted_metadata})
|
||||
except Exception as exc:
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Error fetching from CivitAI: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
@@ -550,6 +559,11 @@ class ModelManagementHandler:
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
@@ -1858,6 +1872,11 @@ class ModelUpdateHandler:
|
||||
status=429,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Failed to fetch license info: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
@@ -1946,9 +1965,12 @@ class ModelUpdateHandler:
|
||||
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.error(
|
||||
"Failed to refresh model updates: %s", exc, exc_info=True
|
||||
)
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||
status=503,
|
||||
)
|
||||
self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
serialized_records = []
|
||||
|
||||
@@ -3,6 +3,11 @@ import copy
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional, Dict, Tuple, List, Sequence
|
||||
from .connectivity_guard import (
|
||||
OFFLINE_FRIENDLY_MESSAGE,
|
||||
is_expected_offline_error,
|
||||
is_offline_cooldown_error,
|
||||
)
|
||||
from .model_metadata_provider import (
|
||||
CivitaiModelMetadataProvider,
|
||||
ModelMetadataProviderManager,
|
||||
@@ -65,6 +70,8 @@ class CivitaiClient:
|
||||
if result.provider is None:
|
||||
result.provider = "civitai_api"
|
||||
raise result
|
||||
if not success and is_offline_cooldown_error(result):
|
||||
return False, OFFLINE_FRIENDLY_MESSAGE
|
||||
return success, result
|
||||
|
||||
@staticmethod
|
||||
@@ -124,6 +131,8 @@ class CivitaiClient:
|
||||
)
|
||||
if not success:
|
||||
message = str(version)
|
||||
if is_expected_offline_error(message):
|
||||
return None, OFFLINE_FRIENDLY_MESSAGE
|
||||
if "not found" in message.lower():
|
||||
return None, "Model not found"
|
||||
|
||||
@@ -164,6 +173,9 @@ class CivitaiClient:
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
if is_expected_offline_error(str(e)):
|
||||
logger.debug("Preview download skipped due to offline state.")
|
||||
return False
|
||||
logger.error(f"Download Error: {str(e)}")
|
||||
return False
|
||||
|
||||
@@ -207,6 +219,9 @@ class CivitaiClient:
|
||||
message = self._extract_error_message(result)
|
||||
if message and "not found" in message.lower():
|
||||
raise ResourceNotFoundError(f"Resource not found for model {model_id}")
|
||||
if is_expected_offline_error(message):
|
||||
logger.info("Civitai request skipped: %s", OFFLINE_FRIENDLY_MESSAGE)
|
||||
return None
|
||||
if message:
|
||||
raise RuntimeError(message)
|
||||
return None
|
||||
@@ -357,6 +372,8 @@ class CivitaiClient:
|
||||
)
|
||||
if success:
|
||||
return data
|
||||
if is_expected_offline_error(data):
|
||||
return None
|
||||
logger.warning(f"Failed to fetch model data for model {model_id}")
|
||||
return None
|
||||
|
||||
@@ -371,6 +388,8 @@ class CivitaiClient:
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
if is_expected_offline_error(version):
|
||||
return None
|
||||
|
||||
logger.warning(f"Failed to fetch version by id {version_id}")
|
||||
return None
|
||||
@@ -386,6 +405,8 @@ class CivitaiClient:
|
||||
)
|
||||
if success:
|
||||
return version
|
||||
if is_expected_offline_error(version):
|
||||
return None
|
||||
|
||||
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
||||
return None
|
||||
@@ -473,6 +494,8 @@ class CivitaiClient:
|
||||
return result, None
|
||||
|
||||
# Handle specific error cases
|
||||
if is_expected_offline_error(result):
|
||||
return None, OFFLINE_FRIENDLY_MESSAGE
|
||||
if "not found" in str(result):
|
||||
error_msg = f"Model not found"
|
||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||
@@ -507,6 +530,8 @@ class CivitaiClient:
|
||||
success, result = await self._make_request("GET", url, use_auth=True)
|
||||
|
||||
if not success:
|
||||
if is_expected_offline_error(result):
|
||||
return None
|
||||
logger.error(
|
||||
"Failed to fetch image info for ID %s from civitai.red: %s",
|
||||
image_id,
|
||||
@@ -566,6 +591,9 @@ class CivitaiClient:
|
||||
)
|
||||
|
||||
if not success:
|
||||
if is_expected_offline_error(result):
|
||||
logger.info("User model fetch skipped: %s", OFFLINE_FRIENDLY_MESSAGE)
|
||||
return None
|
||||
logger.error("Failed to fetch models for %s: %s", username, result)
|
||||
return None
|
||||
|
||||
|
||||
147
py/services/connectivity_guard.py
Normal file
147
py/services/connectivity_guard.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""In-memory connectivity guard to suppress repeated network retries when offline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OFFLINE_COOLDOWN_ERROR = "offline_cooldown"
|
||||
OFFLINE_FRIENDLY_MESSAGE = "Network offline, will retry automatically later"
|
||||
|
||||
|
||||
def is_offline_cooldown_error(value: Any) -> bool:
|
||||
"""Return True when a response payload represents guard short-circuit."""
|
||||
return isinstance(value, str) and value == OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
|
||||
def is_expected_offline_error(value: Any) -> bool:
|
||||
"""Return True when payload is an expected offline-related result."""
|
||||
if is_offline_cooldown_error(value):
|
||||
return True
|
||||
if not isinstance(value, str):
|
||||
return False
|
||||
normalized = value.lower()
|
||||
return "network offline" in normalized or "offline" in normalized
|
||||
|
||||
|
||||
class ConnectivityGuard:
|
||||
"""Tracks network failures and gates outbound requests during cooldown."""
|
||||
|
||||
_instance: "ConnectivityGuard | None" = None
|
||||
_instance_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "ConnectivityGuard":
|
||||
async with cls._instance_lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(self, "_initialized"):
|
||||
return
|
||||
self._initialized = True
|
||||
self.online = True
|
||||
self.failure_count = 0
|
||||
self.cooldown_until: datetime | None = None
|
||||
self.base_backoff_seconds = 30
|
||||
self.max_backoff_seconds = 300
|
||||
self.failure_threshold = 3
|
||||
|
||||
def _now(self) -> datetime:
|
||||
return datetime.now()
|
||||
|
||||
def in_cooldown(self) -> bool:
|
||||
if self.cooldown_until is None:
|
||||
return False
|
||||
return self._now() < self.cooldown_until
|
||||
|
||||
def cooldown_remaining_seconds(self) -> float:
|
||||
if self.cooldown_until is None:
|
||||
return 0.0
|
||||
return max(0.0, (self.cooldown_until - self._now()).total_seconds())
|
||||
|
||||
def should_block_request(self) -> bool:
|
||||
return self.in_cooldown()
|
||||
|
||||
def register_success(self) -> None:
|
||||
was_offline = (not self.online) or self.cooldown_until is not None
|
||||
self.online = True
|
||||
self.failure_count = 0
|
||||
self.cooldown_until = None
|
||||
if was_offline:
|
||||
logger.info("Connectivity restored; requests resumed.")
|
||||
|
||||
def register_network_failure(self, exc: Exception) -> None:
|
||||
self.online = False
|
||||
self.failure_count += 1
|
||||
|
||||
if self.failure_count < self.failure_threshold:
|
||||
logger.debug(
|
||||
"Network failure tracked (%d/%d): %s",
|
||||
self.failure_count,
|
||||
self.failure_threshold,
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
retry_step = self.failure_count - self.failure_threshold
|
||||
backoff = min(
|
||||
self.max_backoff_seconds,
|
||||
self.base_backoff_seconds * (2**retry_step),
|
||||
)
|
||||
should_log_warning = not self.in_cooldown()
|
||||
self.cooldown_until = self._now() + timedelta(seconds=backoff)
|
||||
|
||||
if should_log_warning:
|
||||
logger.warning(
|
||||
"Connectivity offline; enter cooldown for %ss after %d network failures.",
|
||||
int(backoff),
|
||||
self.failure_count,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Cooldown still active; failure_count=%d, backoff=%ss.",
|
||||
self.failure_count,
|
||||
int(backoff),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_network_unreachable_error(exc: Exception) -> bool:
|
||||
"""Return whether the exception should count as connectivity failure."""
|
||||
if isinstance(exc, asyncio.CancelledError):
|
||||
return False
|
||||
|
||||
if isinstance(
|
||||
exc,
|
||||
(
|
||||
asyncio.TimeoutError,
|
||||
TimeoutError,
|
||||
ConnectionRefusedError,
|
||||
socket.gaierror,
|
||||
aiohttp.ServerTimeoutError,
|
||||
aiohttp.ConnectionTimeoutError,
|
||||
aiohttp.ClientConnectorError,
|
||||
aiohttp.ClientConnectionError,
|
||||
),
|
||||
):
|
||||
return True
|
||||
|
||||
if isinstance(exc, OSError) and exc.errno in {
|
||||
errno.ENETUNREACH,
|
||||
errno.EHOSTUNREACH,
|
||||
errno.ETIMEDOUT,
|
||||
errno.ECONNREFUSED,
|
||||
}:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -20,6 +20,10 @@ from datetime import datetime, timedelta
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
|
||||
from ..services.settings_manager import get_settings_manager
|
||||
from .connectivity_guard import (
|
||||
OFFLINE_COOLDOWN_ERROR,
|
||||
ConnectivityGuard,
|
||||
)
|
||||
from .errors import RateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -797,6 +801,10 @@ class Downloader:
|
||||
Returns:
|
||||
Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested)
|
||||
"""
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
if guard.should_block_request():
|
||||
return False, OFFLINE_COOLDOWN_ERROR, None
|
||||
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
@@ -819,6 +827,7 @@ class Downloader:
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
guard.register_success()
|
||||
if return_headers:
|
||||
return True, content, dict(response.headers)
|
||||
else:
|
||||
@@ -837,6 +846,12 @@ class Downloader:
|
||||
return False, error_msg, None
|
||||
|
||||
except Exception as e:
|
||||
if guard.is_network_unreachable_error(e):
|
||||
guard.register_network_failure(e)
|
||||
if guard.should_block_request():
|
||||
return False, OFFLINE_COOLDOWN_ERROR, None
|
||||
logger.debug("Network unavailable during memory download: %s", e)
|
||||
return False, str(e), None
|
||||
logger.error(f"Error downloading to memory from {url}: {e}")
|
||||
return False, str(e), None
|
||||
|
||||
@@ -857,6 +872,10 @@ class Downloader:
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
|
||||
"""
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
if guard.should_block_request():
|
||||
return False, OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
@@ -878,11 +897,18 @@ class Downloader:
|
||||
url, headers=headers, proxy=self.proxy_url
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
guard.register_success()
|
||||
return True, dict(response.headers)
|
||||
else:
|
||||
return False, f"Head request failed with status {response.status}"
|
||||
|
||||
except Exception as e:
|
||||
if guard.is_network_unreachable_error(e):
|
||||
guard.register_network_failure(e)
|
||||
if guard.should_block_request():
|
||||
return False, OFFLINE_COOLDOWN_ERROR
|
||||
logger.debug("Network unavailable during header probe: %s", e)
|
||||
return False, str(e)
|
||||
logger.error(f"Error getting headers from {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
@@ -907,6 +933,10 @@ class Downloader:
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, response data or error message)
|
||||
"""
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
if guard.should_block_request():
|
||||
return False, OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
@@ -930,6 +960,7 @@ class Downloader:
|
||||
method, url, headers=headers, **kwargs
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
guard.register_success()
|
||||
# Try to parse as JSON, fall back to text
|
||||
try:
|
||||
data = await response.json()
|
||||
@@ -960,6 +991,12 @@ class Downloader:
|
||||
return False, f"Request failed with status {response.status}"
|
||||
|
||||
except Exception as e:
|
||||
if guard.is_network_unreachable_error(e):
|
||||
guard.register_network_failure(e)
|
||||
if guard.should_block_request():
|
||||
return False, OFFLINE_COOLDOWN_ERROR
|
||||
logger.debug("Network unavailable for %s %s: %s", method, url, e)
|
||||
return False, str(e)
|
||||
logger.error(f"Error making {method} request to {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
||||
from ..services.settings_manager import SettingsManager
|
||||
from ..utils.civitai_utils import resolve_license_payload
|
||||
from ..utils.model_utils import determine_base_model
|
||||
from .connectivity_guard import OFFLINE_FRIENDLY_MESSAGE, is_expected_offline_error
|
||||
from .errors import RateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -274,11 +275,18 @@ class MetadataSyncService:
|
||||
else "No provider returned metadata"
|
||||
)
|
||||
|
||||
resolved_error = last_error or default_error
|
||||
if is_expected_offline_error(resolved_error):
|
||||
resolved_error = OFFLINE_FRIENDLY_MESSAGE
|
||||
|
||||
error_msg = (
|
||||
f"Error fetching metadata: {last_error or default_error} "
|
||||
f"Error fetching metadata: {resolved_error} "
|
||||
f"(model_name={model_data.get('model_name', '')})"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
if is_expected_offline_error(resolved_error):
|
||||
logger.info(error_msg)
|
||||
else:
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
model_data["from_civitai"] = True
|
||||
@@ -347,6 +355,9 @@ class MetadataSyncService:
|
||||
return False, error_msg
|
||||
except Exception as exc: # pragma: no cover - error path
|
||||
error_msg = f"Error fetching metadata: {exc}"
|
||||
if is_expected_offline_error(str(exc)):
|
||||
logger.info(OFFLINE_FRIENDLY_MESSAGE)
|
||||
return False, OFFLINE_FRIENDLY_MESSAGE
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False, error_msg
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
from py.services import civitai_client as civitai_client_module
|
||||
from py.services.civitai_client import CivitaiClient
|
||||
from py.services.connectivity_guard import OFFLINE_COOLDOWN_ERROR, OFFLINE_FRIENDLY_MESSAGE
|
||||
from py.services.errors import RateLimitError, ResourceNotFoundError
|
||||
from py.services.model_metadata_provider import ModelMetadataProviderManager
|
||||
|
||||
@@ -115,6 +116,20 @@ async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
|
||||
assert error == "Model not found"
|
||||
|
||||
|
||||
async def test_get_model_by_hash_handles_offline_cooldown(downloader):
|
||||
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||
return False, OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result, error = await client.get_model_by_hash("missing")
|
||||
|
||||
assert result is None
|
||||
assert error == OFFLINE_FRIENDLY_MESSAGE
|
||||
|
||||
|
||||
async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader):
|
||||
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||
return False, RateLimitError("limited", retry_after=4)
|
||||
|
||||
74
tests/services/test_connectivity_guard.py
Normal file
74
tests/services/test_connectivity_guard.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import asyncio
|
||||
import errno
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.connectivity_guard import (
|
||||
OFFLINE_COOLDOWN_ERROR,
|
||||
ConnectivityGuard,
|
||||
)
|
||||
from py.services.downloader import Downloader
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_connectivity_guard_singleton():
|
||||
ConnectivityGuard._instance = None
|
||||
yield
|
||||
ConnectivityGuard._instance = None
|
||||
|
||||
|
||||
async def test_connectivity_guard_enters_cooldown_after_threshold():
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
|
||||
assert guard.online is True
|
||||
assert guard.should_block_request() is False
|
||||
|
||||
guard.register_network_failure(OSError(errno.ENETUNREACH, "unreachable"))
|
||||
guard.register_network_failure(asyncio.TimeoutError("timeout"))
|
||||
|
||||
assert guard.should_block_request() is False
|
||||
assert guard.failure_count == 2
|
||||
|
||||
guard.register_network_failure(ConnectionRefusedError("refused"))
|
||||
|
||||
assert guard.online is False
|
||||
assert guard.failure_count == 3
|
||||
assert guard.should_block_request() is True
|
||||
assert guard.cooldown_remaining_seconds() > 0
|
||||
|
||||
|
||||
async def test_connectivity_guard_recovers_after_success():
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
guard.online = False
|
||||
guard.failure_count = 5
|
||||
guard.cooldown_until = datetime.now() + timedelta(seconds=90)
|
||||
|
||||
guard.register_success()
|
||||
|
||||
assert guard.online is True
|
||||
assert guard.failure_count == 0
|
||||
assert guard.cooldown_until is None
|
||||
assert guard.should_block_request() is False
|
||||
|
||||
|
||||
async def test_downloader_short_circuits_all_request_helpers_during_cooldown():
|
||||
guard = await ConnectivityGuard.get_instance()
|
||||
guard.cooldown_until = datetime.now() + timedelta(seconds=30)
|
||||
guard.online = False
|
||||
guard.failure_count = 3
|
||||
|
||||
downloader = Downloader()
|
||||
|
||||
ok, payload = await downloader.make_request("GET", "https://example.invalid")
|
||||
assert ok is False
|
||||
assert payload == OFFLINE_COOLDOWN_ERROR
|
||||
|
||||
ok, payload, headers = await downloader.download_to_memory("https://example.invalid")
|
||||
assert ok is False
|
||||
assert payload == OFFLINE_COOLDOWN_ERROR
|
||||
assert headers is None
|
||||
|
||||
ok, payload = await downloader.get_response_headers("https://example.invalid")
|
||||
assert ok is False
|
||||
assert payload == OFFLINE_COOLDOWN_ERROR
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.connectivity_guard import OFFLINE_COOLDOWN_ERROR, OFFLINE_FRIENDLY_MESSAGE
|
||||
from py.services.errors import RateLimitError
|
||||
from py.services.metadata_sync_service import MetadataSyncService
|
||||
|
||||
@@ -243,17 +244,32 @@ async def test_fetch_and_update_model_handles_missing_remote_metadata(tmp_path):
|
||||
|
||||
assert not ok
|
||||
assert "Model not found" in error
|
||||
assert model_data["from_civitai"] is False
|
||||
assert model_data["civitai_deleted"] is True
|
||||
|
||||
helpers.metadata_manager.hydrate_model_data.assert_not_awaited()
|
||||
assert model_data["hydrated"] is True
|
||||
|
||||
helpers.metadata_manager.save_metadata.assert_awaited_once()
|
||||
call_args = helpers.metadata_manager.save_metadata.await_args
|
||||
assert call_args.args[0].endswith("model.safetensors")
|
||||
assert "folder" not in call_args.args[1]
|
||||
assert call_args.args[1]["hydrated"] is True
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_update_model_returns_friendly_offline_message(tmp_path):
|
||||
helpers = build_service()
|
||||
helpers.default_provider.get_model_by_hash.return_value = (None, OFFLINE_COOLDOWN_ERROR)
|
||||
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
model_data = {
|
||||
"model_name": "Local",
|
||||
"folder": "root",
|
||||
"file_path": str(model_path),
|
||||
}
|
||||
update_cache = AsyncMock(return_value=True)
|
||||
|
||||
ok, error = await helpers.service.fetch_and_update_model(
|
||||
sha256="abc",
|
||||
file_path=str(model_path),
|
||||
model_data=model_data,
|
||||
update_cache_func=update_cache,
|
||||
)
|
||||
|
||||
assert ok is False
|
||||
assert error is not None
|
||||
assert OFFLINE_FRIENDLY_MESSAGE in error
|
||||
update_cache.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user