mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(metadata): add rate limit retry support to metadata providers
This commit is contained in:
@@ -2,11 +2,12 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
from .model_metadata_provider import (
|
from .model_metadata_provider import (
|
||||||
ModelMetadataProvider,
|
ModelMetadataProvider,
|
||||||
ModelMetadataProviderManager,
|
ModelMetadataProviderManager,
|
||||||
SQLiteModelMetadataProvider,
|
SQLiteModelMetadataProvider,
|
||||||
CivitaiModelMetadataProvider,
|
CivitaiModelMetadataProvider,
|
||||||
CivArchiveModelMetadataProvider,
|
CivArchiveModelMetadataProvider,
|
||||||
FallbackMetadataProvider
|
FallbackMetadataProvider,
|
||||||
|
RateLimitRetryingProvider,
|
||||||
)
|
)
|
||||||
from .settings_manager import get_settings_manager
|
from .settings_manager import get_settings_manager
|
||||||
from .metadata_archive_manager import MetadataArchiveManager
|
from .metadata_archive_manager import MetadataArchiveManager
|
||||||
@@ -108,14 +109,24 @@ async def get_metadata_archive_manager():
|
|||||||
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||||
return MetadataArchiveManager(base_path)
|
return MetadataArchiveManager(base_path)
|
||||||
|
|
||||||
|
def _wrap_provider_with_rate_limit(provider_name: str | None, provider: ModelMetadataProvider) -> ModelMetadataProvider:
    """Return ``provider`` wrapped in a ``RateLimitRetryingProvider`` when needed.

    Providers that are already a ``FallbackMetadataProvider`` or a
    ``RateLimitRetryingProvider`` are handed back untouched; anything else
    is wrapped, with ``provider_name`` passed along as the wrapper's label.
    """
    already_handled = (FallbackMetadataProvider, RateLimitRetryingProvider)
    if not isinstance(provider, already_handled):
        return RateLimitRetryingProvider(provider, label=provider_name)
    return provider
|
||||||
|
|
||||||
|
|
||||||
async def get_metadata_provider(provider_name: str = None):
    """Get a specific metadata provider or default provider with rate-limit handling."""
    manager = await ModelMetadataProviderManager.get_instance()

    # A falsy name (None or "") selects the manager's default provider.
    if provider_name:
        selected = manager._get_provider(provider_name)
    else:
        selected = manager._get_provider()

    return _wrap_provider_with_rate_limit(provider_name, selected)
|
||||||
|
|
||||||
async def get_default_metadata_provider():
|
async def get_default_metadata_provider():
|
||||||
"""Get the default metadata provider (fallback or single provider)"""
|
"""Get the default metadata provider (fallback or single provider)"""
|
||||||
|
|||||||
@@ -41,6 +41,55 @@ def _require_aiosqlite() -> Any:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _RateLimitRetryHelper:
    """Coordinate exponential backoff retries after rate limiting.

    ``run`` awaits a coroutine function and, each time it raises
    ``RateLimitError``, sleeps for a computed delay and tries again until
    ``retry_limit`` attempts have failed.
    """

    # 2 ** 1023 is the largest power of two that still converts to a finite
    # float; ``float * (2 ** 1024)`` raises OverflowError.
    _MAX_BACKOFF_EXPONENT = 1023

    def __init__(
        self,
        *,
        retry_limit: int = 3,
        base_delay: float = 1.5,
        max_delay: float = 30.0,
        jitter_ratio: float = 0.2,
    ) -> None:
        """Store the retry policy.

        Args:
            retry_limit: Number of failed attempts after which the error is
                re-raised; values below 1 are treated as 1.
            base_delay: Delay in seconds used for the first retry; doubled
                on every further attempt.
            max_delay: Upper bound applied to every computed delay.
            jitter_ratio: Fraction of the backoff delay used as the +/-
                random jitter span; negative values are treated as 0.
        """
        self._retry_limit = max(1, retry_limit)
        self._base_delay = base_delay
        self._max_delay = max_delay
        self._jitter_ratio = max(0.0, jitter_ratio)

    async def run(self, label: str, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, retrying on ``RateLimitError``.

        Args:
            label: Provider name used in log messages and stored on the
                final error when it has no provider set.
            func: Coroutine function to call.

        Returns:
            Whatever ``func`` returns on its first successful attempt.

        Raises:
            RateLimitError: Once ``retry_limit`` attempts have been rate
                limited. Any other exception propagates immediately.
        """
        attempt = 0
        while True:
            try:
                return await func(*args, **kwargs)
            except RateLimitError as exc:
                attempt += 1
                if attempt >= self._retry_limit:
                    # Tag the error with this provider unless one is already set.
                    exc.provider = exc.provider or label
                    raise

                delay = self._calculate_delay(exc.retry_after, attempt)
                logger.warning(
                    "Provider %s rate limited request; retrying in %.2fs (attempt %s/%s)",
                    label,
                    delay,
                    attempt,
                    self._retry_limit,
                )
                await asyncio.sleep(delay)

    def _calculate_delay(self, retry_after: Optional[float], attempt: int) -> float:
        """Return the number of seconds to wait before the next attempt.

        A ``retry_after`` value from the error takes precedence and is
        clamped to ``[0, max_delay]``. Otherwise the delay is
        ``base_delay * 2 ** (attempt - 1)`` with random jitter applied,
        capped at ``max_delay``.
        """
        if retry_after is not None:
            return min(self._max_delay, max(0.0, retry_after))

        # Cap the exponent so a very large attempt count cannot raise
        # OverflowError during the int -> float conversion.
        exponent = min(max(0, attempt - 1), self._MAX_BACKOFF_EXPONENT)
        base_delay = self._base_delay * (2 ** exponent)
        infinity = float("inf")
        if base_delay == infinity:
            # Jitter arithmetic on infinity would produce NaN; the result is
            # capped at max_delay anyway.
            return self._max_delay

        jitter_span = base_delay * self._jitter_ratio
        # An infinite span would make random.uniform return NaN, so skip it.
        if 0 < jitter_span < infinity:
            base_delay += random.uniform(-jitter_span, jitter_span)

        return min(self._max_delay, max(0.0, base_delay))
|
||||||
|
|
||||||
class ModelMetadataProvider(ABC):
|
class ModelMetadataProvider(ABC):
|
||||||
"""Base abstract class for all model metadata providers"""
|
"""Base abstract class for all model metadata providers"""
|
||||||
|
|
||||||
@@ -390,6 +439,12 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
|||||||
self._rate_limit_base_delay = rate_limit_base_delay
|
self._rate_limit_base_delay = rate_limit_base_delay
|
||||||
self._rate_limit_max_delay = rate_limit_max_delay
|
self._rate_limit_max_delay = rate_limit_max_delay
|
||||||
self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio)
|
self._rate_limit_jitter_ratio = max(0.0, rate_limit_jitter_ratio)
|
||||||
|
self._rate_limit_helper = _RateLimitRetryHelper(
|
||||||
|
retry_limit=self._rate_limit_retry_limit,
|
||||||
|
base_delay=self._rate_limit_base_delay,
|
||||||
|
max_delay=self._rate_limit_max_delay,
|
||||||
|
jitter_ratio=self._rate_limit_jitter_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
for provider, label in self._iter_providers():
|
for provider, label in self._iter_providers():
|
||||||
@@ -485,44 +540,80 @@ class FallbackMetadataProvider(ModelMetadataProvider):
|
|||||||
def _iter_providers(self):
    """Pair each entry of ``self.providers`` with its label, in order."""
    providers, labels = self.providers, self._provider_labels
    return zip(providers, labels)
|
||||||
|
|
||||||
async def _call_with_rate_limit(self, label: str, func, *args, **kwargs):
    """Run ``func`` through this provider's shared rate-limit retry helper."""
    retry_helper = self._rate_limit_helper
    outcome = await retry_helper.run(label, func, *args, **kwargs)
    return outcome
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitRetryingProvider(ModelMetadataProvider):
    """Adapter that retries individual provider calls after rate limiting.

    Every metadata method is forwarded to the wrapped provider through a
    ``_RateLimitRetryHelper``; any other attribute is looked up on the
    wrapped provider directly.
    """

    def __init__(
        self,
        provider: ModelMetadataProvider,
        label: Optional[str] = None,
        *,
        rate_limit_retry_limit: int = 3,
        rate_limit_base_delay: float = 1.5,
        rate_limit_max_delay: float = 30.0,
        rate_limit_jitter_ratio: float = 0.2,
    ) -> None:
        """Wrap ``provider``.

        Args:
            provider: The provider whose calls should be retried.
            label: Name passed to the retry helper; defaults to the wrapped
                provider's class name when omitted or empty.
            rate_limit_retry_limit: Forwarded to the helper as ``retry_limit``.
            rate_limit_base_delay: Forwarded to the helper as ``base_delay``.
            rate_limit_max_delay: Forwarded to the helper as ``max_delay``.
            rate_limit_jitter_ratio: Forwarded to the helper as ``jitter_ratio``.
        """
        self._provider = provider
        self._label = label or provider.__class__.__name__
        self._rate_limit_helper = _RateLimitRetryHelper(
            retry_limit=rate_limit_retry_limit,
            base_delay=rate_limit_base_delay,
            max_delay=rate_limit_max_delay,
            jitter_ratio=rate_limit_jitter_ratio,
        )

    def __getattr__(self, item):
        """Delegate attributes not found on the wrapper to the wrapped provider."""
        # __getattr__ only runs after normal lookup has failed. If
        # ``_provider`` itself is missing (instance created via __new__
        # without __init__), reading ``self._provider`` would re-enter this
        # method and recurse until RecursionError.
        if item == "_provider":
            raise AttributeError(item)
        return getattr(self._provider, item)

    async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Forward ``get_model_by_hash`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_by_hash,
            model_hash,
        )

    async def get_model_versions(self, model_id: str) -> Optional[Dict]:
        """Forward ``get_model_versions`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_versions,
            model_id,
        )

    async def get_model_versions_bulk(
        self,
        model_ids: Sequence[int],
    ) -> Optional[Dict[int, Dict]]:
        """Forward ``get_model_versions_bulk`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_versions_bulk,
            model_ids,
        )

    async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
        """Forward ``get_model_version`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_version,
            model_id,
            version_id,
        )

    async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
        """Forward ``get_model_version_info`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_model_version_info,
            version_id,
        )

    async def get_user_models(self, username: str) -> Optional[List[Dict]]:
        """Forward ``get_user_models`` to the wrapped provider with retries."""
        return await self._rate_limit_helper.run(
            self._label,
            self._provider.get_user_models,
            username,
        )
|
||||||
|
|
||||||
class ModelMetadataProviderManager:
|
class ModelMetadataProviderManager:
|
||||||
"""Manager for selecting and using model metadata providers"""
|
"""Manager for selecting and using model metadata providers"""
|
||||||
|
|||||||
62
tests/services/test_metadata_service.py
Normal file
62
tests/services/test_metadata_service.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services import metadata_service
|
||||||
|
from py.services.model_metadata_provider import (
|
||||||
|
FallbackMetadataProvider,
|
||||||
|
ModelMetadataProvider,
|
||||||
|
RateLimitRetryingProvider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyProvider(ModelMetadataProvider):
    """Stub provider whose every lookup returns an empty result."""

    async def get_model_by_hash(self, model_hash: str):
        """Return an empty ``(result, error)`` pair."""
        return None, None

    async def get_model_versions(self, model_id: str):
        """Return ``None``."""
        return None

    async def get_model_versions_bulk(self, model_ids):
        """Return ``None``."""
        return None

    async def get_model_version(self, model_id: int = None, version_id: int = None):
        """Return ``None``."""
        return None

    async def get_model_version_info(self, version_id: str):
        """Return an empty ``(result, error)`` pair."""
        return None, None

    async def get_user_models(self, username: str):
        """Return ``None``."""
        return None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_metadata_provider_wraps_non_fallback(monkeypatch):
    """A plain provider from the manager comes back wrapped for retries."""
    raw_provider = DummyProvider()
    fake_manager = SimpleNamespace(_get_provider=lambda _name=None: raw_provider)
    get_instance_stub = AsyncMock(return_value=fake_manager)
    monkeypatch.setattr(
        metadata_service.ModelMetadataProviderManager,
        "get_instance",
        get_instance_stub,
    )

    result = await metadata_service.get_metadata_provider("dummy")

    assert result is not raw_provider
    assert isinstance(result, RateLimitRetryingProvider)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_metadata_provider_returns_fallback_as_is(monkeypatch):
    """A ``FallbackMetadataProvider`` is returned unchanged, without wrapping."""
    fallback_provider = FallbackMetadataProvider([("dummy", DummyProvider())])
    fake_manager = SimpleNamespace(_get_provider=lambda _name=None: fallback_provider)
    get_instance_stub = AsyncMock(return_value=fake_manager)
    monkeypatch.setattr(
        metadata_service.ModelMetadataProviderManager,
        "get_instance",
        get_instance_stub,
    )

    result = await metadata_service.get_metadata_provider()

    assert result is fallback_provider
|
||||||
@@ -4,7 +4,10 @@ import pytest
|
|||||||
|
|
||||||
from py.services import model_metadata_provider as provider_module
|
from py.services import model_metadata_provider as provider_module
|
||||||
from py.services.errors import RateLimitError
|
from py.services.errors import RateLimitError
|
||||||
from py.services.model_metadata_provider import FallbackMetadataProvider
|
from py.services.model_metadata_provider import (
|
||||||
|
FallbackMetadataProvider,
|
||||||
|
RateLimitRetryingProvider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RateLimitThenSuccessProvider:
|
class RateLimitThenSuccessProvider:
|
||||||
@@ -80,3 +83,37 @@ async def test_fallback_respects_retry_limit(monkeypatch):
|
|||||||
assert primary.calls == 2
|
assert primary.calls == 2
|
||||||
assert secondary.calls == 0
|
assert secondary.calls == 0
|
||||||
sleep_mock.assert_awaited_once()
|
sleep_mock.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_rate_limit_retrying_provider_retries(monkeypatch):
    """The wrapper calls the inner provider twice, sleeping once in between."""
    fake_sleep = AsyncMock()
    # Replace the real sleep and make the jitter draw return 0.0.
    monkeypatch.setattr(provider_module.asyncio, "sleep", fake_sleep)
    monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)

    flaky = RateLimitThenSuccessProvider()
    retrying = RateLimitRetryingProvider(flaky, label="inner", rate_limit_base_delay=0.1)

    payload, failure = await retrying.get_model_by_hash("abc")

    assert failure is None
    assert payload == {"id": "ok"}
    assert flaky.calls == 2
    fake_sleep.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_rate_limit_retrying_provider_respects_limit(monkeypatch):
    """With a retry limit of 2 the wrapper re-raises after the second failure."""
    fake_sleep = AsyncMock()
    # Replace the real sleep and make the jitter draw return 0.0.
    monkeypatch.setattr(provider_module.asyncio, "sleep", fake_sleep)
    monkeypatch.setattr(provider_module.random, "uniform", lambda *_: 0.0)

    always_limited = AlwaysRateLimitedProvider()
    retrying = RateLimitRetryingProvider(always_limited, label="inner", rate_limit_retry_limit=2)

    with pytest.raises(RateLimitError) as raised:
        await retrying.get_model_by_hash("abc")

    # The final error carries the wrapper's label as its provider.
    assert raised.value.provider == "inner"
    assert always_limited.calls == 2
    fake_sleep.assert_awaited_once()
|
||||||
|
|||||||
Reference in New Issue
Block a user