mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-14 09:07:36 -03:00
Compare commits
8 Commits
761108bfd1
...
89e26d9292
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
89e26d9292 | ||
|
|
fc19a145ff | ||
|
|
34f03d6495 | ||
|
|
9443175abc | ||
|
|
dc5072628f | ||
|
|
ff4b8ec849 | ||
|
|
7ab271c752 | ||
|
|
5a7f4dc88b |
77
py/config.py
77
py/config.py
@@ -26,20 +26,44 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def _resolve_valid_default_root(
|
def _resolve_valid_default_root(
|
||||||
current: str, primary_paths: List[str], name: str
|
current: str, primary_paths: List[str], allowed_paths: List[str], name: str
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return a valid default root from the current primary path set."""
|
"""Return a valid default root from the current primary/extra path set."""
|
||||||
|
|
||||||
valid_paths = [path for path in primary_paths if isinstance(path, str) and path.strip()]
|
valid_paths = [path for path in primary_paths if isinstance(path, str) and path.strip()]
|
||||||
if not valid_paths:
|
fallback_paths: List[str] = []
|
||||||
return ""
|
seen: Set[str] = set()
|
||||||
|
for path in allowed_paths:
|
||||||
|
if not isinstance(path, str):
|
||||||
|
continue
|
||||||
|
stripped = path.strip()
|
||||||
|
if not stripped or stripped in seen:
|
||||||
|
continue
|
||||||
|
seen.add(stripped)
|
||||||
|
fallback_paths.append(stripped)
|
||||||
|
|
||||||
if current in valid_paths:
|
allowed = set(fallback_paths)
|
||||||
|
|
||||||
|
if current and current in allowed:
|
||||||
return current
|
return current
|
||||||
|
|
||||||
|
if not valid_paths:
|
||||||
|
if not fallback_paths:
|
||||||
|
return ""
|
||||||
|
if current:
|
||||||
|
logger.info(
|
||||||
|
"Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots",
|
||||||
|
name,
|
||||||
|
current,
|
||||||
|
fallback_paths[0],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Auto-setting %s to '%s'", name, fallback_paths[0])
|
||||||
|
return fallback_paths[0]
|
||||||
|
|
||||||
if current:
|
if current:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Repaired stale %s from '%s' to '%s'",
|
"Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots",
|
||||||
name,
|
name,
|
||||||
current,
|
current,
|
||||||
valid_paths[0],
|
valid_paths[0],
|
||||||
@@ -226,39 +250,76 @@ class Config:
|
|||||||
default_lora_root = _resolve_valid_default_root(
|
default_lora_root = _resolve_valid_default_root(
|
||||||
comfy_library.get("default_lora_root", ""),
|
comfy_library.get("default_lora_root", ""),
|
||||||
list(self.loras_roots or []),
|
list(self.loras_roots or []),
|
||||||
|
list(self.loras_roots or [])
|
||||||
|
+ list(comfy_library.get("extra_folder_paths", {}).get("loras", []) or []),
|
||||||
"default_lora_root",
|
"default_lora_root",
|
||||||
)
|
)
|
||||||
|
|
||||||
default_checkpoint_root = _resolve_valid_default_root(
|
default_checkpoint_root = _resolve_valid_default_root(
|
||||||
comfy_library.get("default_checkpoint_root", ""),
|
comfy_library.get("default_checkpoint_root", ""),
|
||||||
list(self.checkpoints_roots or []),
|
list(self.checkpoints_roots or []),
|
||||||
|
list(self.checkpoints_roots or [])
|
||||||
|
+ list(comfy_library.get("extra_folder_paths", {}).get("checkpoints", []) or []),
|
||||||
"default_checkpoint_root",
|
"default_checkpoint_root",
|
||||||
)
|
)
|
||||||
|
|
||||||
default_embedding_root = _resolve_valid_default_root(
|
default_embedding_root = _resolve_valid_default_root(
|
||||||
comfy_library.get("default_embedding_root", ""),
|
comfy_library.get("default_embedding_root", ""),
|
||||||
list(self.embeddings_roots or []),
|
list(self.embeddings_roots or []),
|
||||||
|
list(self.embeddings_roots or [])
|
||||||
|
+ list(comfy_library.get("extra_folder_paths", {}).get("embeddings", []) or []),
|
||||||
"default_embedding_root",
|
"default_embedding_root",
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = dict(comfy_library.get("metadata", {}))
|
metadata = dict(comfy_library.get("metadata", {}))
|
||||||
metadata.setdefault("display_name", "ComfyUI")
|
metadata.setdefault("display_name", "ComfyUI")
|
||||||
metadata["source"] = "comfyui"
|
metadata["source"] = "comfyui"
|
||||||
|
extra_folder_paths = {}
|
||||||
|
if isinstance(comfy_library, Mapping):
|
||||||
|
existing_extra_paths = comfy_library.get("extra_folder_paths", {})
|
||||||
|
if isinstance(existing_extra_paths, Mapping):
|
||||||
|
extra_folder_paths = {
|
||||||
|
key: list(value) if isinstance(value, list) else []
|
||||||
|
for key, value in existing_extra_paths.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
active_library_name = settings_service.get_active_library_name()
|
||||||
|
should_activate = (
|
||||||
|
active_library_name == "comfyui"
|
||||||
|
or self._should_activate_comfy_library(libraries, libraries_changed)
|
||||||
|
)
|
||||||
|
|
||||||
settings_service.upsert_library(
|
settings_service.upsert_library(
|
||||||
"comfyui",
|
"comfyui",
|
||||||
folder_paths=target_folder_paths,
|
folder_paths=target_folder_paths,
|
||||||
|
extra_folder_paths=extra_folder_paths,
|
||||||
default_lora_root=default_lora_root,
|
default_lora_root=default_lora_root,
|
||||||
default_checkpoint_root=default_checkpoint_root,
|
default_checkpoint_root=default_checkpoint_root,
|
||||||
default_embedding_root=default_embedding_root,
|
default_embedding_root=default_embedding_root,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
activate=True,
|
activate=should_activate,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Updated 'comfyui' library with current folder paths")
|
if should_activate:
|
||||||
|
logger.info("Updated 'comfyui' library with current folder paths")
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Updated 'comfyui' library with current folder paths without activating it"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to save folder paths: {e}")
|
logger.warning(f"Failed to save folder paths: {e}")
|
||||||
|
|
||||||
|
def _should_activate_comfy_library(
|
||||||
|
self, libraries: Mapping[str, Any], libraries_changed: bool
|
||||||
|
) -> bool:
|
||||||
|
"""Return whether startup sync should make the ComfyUI library active."""
|
||||||
|
|
||||||
|
if libraries_changed:
|
||||||
|
return True
|
||||||
|
if not libraries:
|
||||||
|
return True
|
||||||
|
return "comfyui" in libraries and len(libraries) == 1
|
||||||
|
|
||||||
def _is_link(self, path: str) -> bool:
|
def _is_link(self, path: str) -> bool:
|
||||||
try:
|
try:
|
||||||
if os.path.islink(path):
|
if os.path.islink(path):
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ import jinja2
|
|||||||
|
|
||||||
from ...config import config
|
from ...config import config
|
||||||
from ...services.download_coordinator import DownloadCoordinator
|
from ...services.download_coordinator import DownloadCoordinator
|
||||||
|
from ...services.connectivity_guard import (
|
||||||
|
OFFLINE_FRIENDLY_MESSAGE,
|
||||||
|
is_expected_offline_error,
|
||||||
|
)
|
||||||
from ...services.metadata_sync_service import MetadataSyncService
|
from ...services.metadata_sync_service import MetadataSyncService
|
||||||
from ...services.model_file_service import ModelMoveService
|
from ...services.model_file_service import ModelMoveService
|
||||||
from ...services.preview_asset_service import PreviewAssetService
|
from ...services.preview_asset_service import PreviewAssetService
|
||||||
@@ -504,6 +508,11 @@ class ModelManagementHandler:
|
|||||||
formatted_metadata = await self._service.format_response(model_data)
|
formatted_metadata = await self._service.format_response(model_data)
|
||||||
return web.json_response({"success": True, "metadata": formatted_metadata})
|
return web.json_response({"success": True, "metadata": formatted_metadata})
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
if is_expected_offline_error(str(exc)):
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||||
|
status=503,
|
||||||
|
)
|
||||||
self._logger.error("Error fetching from CivitAI: %s", exc, exc_info=True)
|
self._logger.error("Error fetching from CivitAI: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
@@ -550,6 +559,11 @@ class ModelManagementHandler:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
if is_expected_offline_error(str(exc)):
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||||
|
status=503,
|
||||||
|
)
|
||||||
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
|
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
@@ -1858,6 +1872,11 @@ class ModelUpdateHandler:
|
|||||||
status=429,
|
status=429,
|
||||||
)
|
)
|
||||||
except Exception as exc: # pragma: no cover - defensive log
|
except Exception as exc: # pragma: no cover - defensive log
|
||||||
|
if is_expected_offline_error(str(exc)):
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||||
|
status=503,
|
||||||
|
)
|
||||||
self._logger.error("Failed to fetch license info: %s", exc, exc_info=True)
|
self._logger.error("Failed to fetch license info: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
@@ -1946,9 +1965,12 @@ class ModelUpdateHandler:
|
|||||||
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
{"success": False, "error": str(exc) or "Rate limited"}, status=429
|
||||||
)
|
)
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
self._logger.error(
|
if is_expected_offline_error(str(exc)):
|
||||||
"Failed to refresh model updates: %s", exc, exc_info=True
|
return web.json_response(
|
||||||
)
|
{"success": False, "error": OFFLINE_FRIENDLY_MESSAGE},
|
||||||
|
status=503,
|
||||||
|
)
|
||||||
|
self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
serialized_records = []
|
serialized_records = []
|
||||||
|
|||||||
@@ -3,6 +3,11 @@ import copy
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional, Dict, Tuple, List, Sequence
|
from typing import Any, Optional, Dict, Tuple, List, Sequence
|
||||||
|
from .connectivity_guard import (
|
||||||
|
OFFLINE_FRIENDLY_MESSAGE,
|
||||||
|
is_expected_offline_error,
|
||||||
|
is_offline_cooldown_error,
|
||||||
|
)
|
||||||
from .model_metadata_provider import (
|
from .model_metadata_provider import (
|
||||||
CivitaiModelMetadataProvider,
|
CivitaiModelMetadataProvider,
|
||||||
ModelMetadataProviderManager,
|
ModelMetadataProviderManager,
|
||||||
@@ -65,6 +70,8 @@ class CivitaiClient:
|
|||||||
if result.provider is None:
|
if result.provider is None:
|
||||||
result.provider = "civitai_api"
|
result.provider = "civitai_api"
|
||||||
raise result
|
raise result
|
||||||
|
if not success and is_offline_cooldown_error(result):
|
||||||
|
return False, OFFLINE_FRIENDLY_MESSAGE
|
||||||
return success, result
|
return success, result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -124,6 +131,8 @@ class CivitaiClient:
|
|||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
message = str(version)
|
message = str(version)
|
||||||
|
if is_expected_offline_error(message):
|
||||||
|
return None, OFFLINE_FRIENDLY_MESSAGE
|
||||||
if "not found" in message.lower():
|
if "not found" in message.lower():
|
||||||
return None, "Model not found"
|
return None, "Model not found"
|
||||||
|
|
||||||
@@ -164,6 +173,9 @@ class CivitaiClient:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if is_expected_offline_error(str(e)):
|
||||||
|
logger.debug("Preview download skipped due to offline state.")
|
||||||
|
return False
|
||||||
logger.error(f"Download Error: {str(e)}")
|
logger.error(f"Download Error: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -207,6 +219,9 @@ class CivitaiClient:
|
|||||||
message = self._extract_error_message(result)
|
message = self._extract_error_message(result)
|
||||||
if message and "not found" in message.lower():
|
if message and "not found" in message.lower():
|
||||||
raise ResourceNotFoundError(f"Resource not found for model {model_id}")
|
raise ResourceNotFoundError(f"Resource not found for model {model_id}")
|
||||||
|
if is_expected_offline_error(message):
|
||||||
|
logger.info("Civitai request skipped: %s", OFFLINE_FRIENDLY_MESSAGE)
|
||||||
|
return None
|
||||||
if message:
|
if message:
|
||||||
raise RuntimeError(message)
|
raise RuntimeError(message)
|
||||||
return None
|
return None
|
||||||
@@ -357,6 +372,8 @@ class CivitaiClient:
|
|||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
return data
|
return data
|
||||||
|
if is_expected_offline_error(data):
|
||||||
|
return None
|
||||||
logger.warning(f"Failed to fetch model data for model {model_id}")
|
logger.warning(f"Failed to fetch model data for model {model_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -371,6 +388,8 @@ class CivitaiClient:
|
|||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
return version
|
return version
|
||||||
|
if is_expected_offline_error(version):
|
||||||
|
return None
|
||||||
|
|
||||||
logger.warning(f"Failed to fetch version by id {version_id}")
|
logger.warning(f"Failed to fetch version by id {version_id}")
|
||||||
return None
|
return None
|
||||||
@@ -386,6 +405,8 @@ class CivitaiClient:
|
|||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
return version
|
return version
|
||||||
|
if is_expected_offline_error(version):
|
||||||
|
return None
|
||||||
|
|
||||||
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
logger.warning(f"Failed to fetch version by hash {model_hash}")
|
||||||
return None
|
return None
|
||||||
@@ -473,6 +494,8 @@ class CivitaiClient:
|
|||||||
return result, None
|
return result, None
|
||||||
|
|
||||||
# Handle specific error cases
|
# Handle specific error cases
|
||||||
|
if is_expected_offline_error(result):
|
||||||
|
return None, OFFLINE_FRIENDLY_MESSAGE
|
||||||
if "not found" in str(result):
|
if "not found" in str(result):
|
||||||
error_msg = f"Model not found"
|
error_msg = f"Model not found"
|
||||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||||
@@ -507,6 +530,8 @@ class CivitaiClient:
|
|||||||
success, result = await self._make_request("GET", url, use_auth=True)
|
success, result = await self._make_request("GET", url, use_auth=True)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
|
if is_expected_offline_error(result):
|
||||||
|
return None
|
||||||
logger.error(
|
logger.error(
|
||||||
"Failed to fetch image info for ID %s from civitai.red: %s",
|
"Failed to fetch image info for ID %s from civitai.red: %s",
|
||||||
image_id,
|
image_id,
|
||||||
@@ -566,6 +591,9 @@ class CivitaiClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
|
if is_expected_offline_error(result):
|
||||||
|
logger.info("User model fetch skipped: %s", OFFLINE_FRIENDLY_MESSAGE)
|
||||||
|
return None
|
||||||
logger.error("Failed to fetch models for %s: %s", username, result)
|
logger.error("Failed to fetch models for %s: %s", username, result)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
204
py/services/connectivity_guard.py
Normal file
204
py/services/connectivity_guard.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
"""In-memory connectivity guard to suppress repeated network retries when offline."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import errno
|
||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OFFLINE_COOLDOWN_ERROR = "offline_cooldown"
|
||||||
|
OFFLINE_FRIENDLY_MESSAGE = "Network offline, will retry automatically later"
|
||||||
|
|
||||||
|
|
||||||
|
def is_offline_cooldown_error(value: Any) -> bool:
|
||||||
|
"""Return True when a response payload represents guard short-circuit."""
|
||||||
|
return isinstance(value, str) and value == OFFLINE_COOLDOWN_ERROR
|
||||||
|
|
||||||
|
|
||||||
|
def is_expected_offline_error(value: Any) -> bool:
|
||||||
|
"""Return True when payload is an expected offline-related result."""
|
||||||
|
if is_offline_cooldown_error(value):
|
||||||
|
return True
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return False
|
||||||
|
normalized = value.lower()
|
||||||
|
return "network offline" in normalized or "offline" in normalized
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectivityGuard:
|
||||||
|
"""Tracks network failures and gates outbound requests during cooldown."""
|
||||||
|
|
||||||
|
_instance: "ConnectivityGuard | None" = None
|
||||||
|
_instance_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_instance(cls) -> "ConnectivityGuard":
|
||||||
|
async with cls._instance_lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
if hasattr(self, "_initialized"):
|
||||||
|
return
|
||||||
|
self._initialized = True
|
||||||
|
self._default_destination = "__global__"
|
||||||
|
self._destination_states: dict[str, _DestinationState] = {
|
||||||
|
self._default_destination: _DestinationState()
|
||||||
|
}
|
||||||
|
self.base_backoff_seconds = 30
|
||||||
|
self.max_backoff_seconds = 300
|
||||||
|
self.failure_threshold = 3
|
||||||
|
|
||||||
|
@property
|
||||||
|
def online(self) -> bool:
|
||||||
|
return self._state_for_destination(None).online
|
||||||
|
|
||||||
|
@online.setter
|
||||||
|
def online(self, value: bool) -> None:
|
||||||
|
self._state_for_destination(None).online = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def failure_count(self) -> int:
|
||||||
|
return self._state_for_destination(None).failure_count
|
||||||
|
|
||||||
|
@failure_count.setter
|
||||||
|
def failure_count(self, value: int) -> None:
|
||||||
|
self._state_for_destination(None).failure_count = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cooldown_until(self) -> datetime | None:
|
||||||
|
return self._state_for_destination(None).cooldown_until
|
||||||
|
|
||||||
|
@cooldown_until.setter
|
||||||
|
def cooldown_until(self, value: datetime | None) -> None:
|
||||||
|
self._state_for_destination(None).cooldown_until = value
|
||||||
|
|
||||||
|
def _now(self) -> datetime:
|
||||||
|
return datetime.now()
|
||||||
|
|
||||||
|
def _normalize_destination(self, destination: str | None) -> str:
|
||||||
|
if destination is None or not destination.strip():
|
||||||
|
return self._default_destination
|
||||||
|
return destination.lower().strip()
|
||||||
|
|
||||||
|
def _state_for_destination(self, destination: str | None) -> "_DestinationState":
|
||||||
|
destination_key = self._normalize_destination(destination)
|
||||||
|
if destination_key not in self._destination_states:
|
||||||
|
self._destination_states[destination_key] = _DestinationState()
|
||||||
|
return self._destination_states[destination_key]
|
||||||
|
|
||||||
|
def in_cooldown(self, destination: str | None = None) -> bool:
|
||||||
|
state = self._state_for_destination(destination)
|
||||||
|
if state.cooldown_until is None:
|
||||||
|
return False
|
||||||
|
return self._now() < state.cooldown_until
|
||||||
|
|
||||||
|
def cooldown_remaining_seconds(self, destination: str | None = None) -> float:
|
||||||
|
state = self._state_for_destination(destination)
|
||||||
|
if state.cooldown_until is None:
|
||||||
|
return 0.0
|
||||||
|
return max(0.0, (state.cooldown_until - self._now()).total_seconds())
|
||||||
|
|
||||||
|
def should_block_request(self, destination: str | None = None) -> bool:
|
||||||
|
return self.in_cooldown(destination)
|
||||||
|
|
||||||
|
def register_success(self, destination: str | None = None) -> None:
|
||||||
|
destination_key = self._normalize_destination(destination)
|
||||||
|
state = self._state_for_destination(destination_key)
|
||||||
|
was_offline = (not state.online) or state.cooldown_until is not None
|
||||||
|
state.online = True
|
||||||
|
state.failure_count = 0
|
||||||
|
state.cooldown_until = None
|
||||||
|
if was_offline:
|
||||||
|
logger.info(
|
||||||
|
"Connectivity restored for destination '%s'; requests resumed.",
|
||||||
|
destination_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
def register_network_failure(
|
||||||
|
self, exc: Exception, destination: str | None = None
|
||||||
|
) -> None:
|
||||||
|
destination_key = self._normalize_destination(destination)
|
||||||
|
state = self._state_for_destination(destination_key)
|
||||||
|
state.online = False
|
||||||
|
state.failure_count += 1
|
||||||
|
|
||||||
|
if state.failure_count < self.failure_threshold:
|
||||||
|
logger.debug(
|
||||||
|
"Network failure tracked for destination '%s' (%d/%d): %s",
|
||||||
|
destination_key,
|
||||||
|
state.failure_count,
|
||||||
|
self.failure_threshold,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
retry_step = state.failure_count - self.failure_threshold
|
||||||
|
backoff = min(
|
||||||
|
self.max_backoff_seconds,
|
||||||
|
self.base_backoff_seconds * (2**retry_step),
|
||||||
|
)
|
||||||
|
should_log_warning = not self.in_cooldown(destination_key)
|
||||||
|
state.cooldown_until = self._now() + timedelta(seconds=backoff)
|
||||||
|
|
||||||
|
if should_log_warning:
|
||||||
|
logger.warning(
|
||||||
|
"Connectivity offline for destination '%s'; enter cooldown for %ss after %d network failures.",
|
||||||
|
destination_key,
|
||||||
|
int(backoff),
|
||||||
|
state.failure_count,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"Cooldown still active for destination '%s'; failure_count=%d, backoff=%ss.",
|
||||||
|
destination_key,
|
||||||
|
state.failure_count,
|
||||||
|
int(backoff),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_network_unreachable_error(exc: Exception) -> bool:
|
||||||
|
"""Return whether the exception should count as connectivity failure."""
|
||||||
|
if isinstance(exc, asyncio.CancelledError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if isinstance(
|
||||||
|
exc,
|
||||||
|
(
|
||||||
|
asyncio.TimeoutError,
|
||||||
|
TimeoutError,
|
||||||
|
ConnectionRefusedError,
|
||||||
|
socket.gaierror,
|
||||||
|
aiohttp.ServerTimeoutError,
|
||||||
|
aiohttp.ConnectionTimeoutError,
|
||||||
|
aiohttp.ClientConnectorError,
|
||||||
|
aiohttp.ClientConnectionError,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if isinstance(exc, OSError) and exc.errno in {
|
||||||
|
errno.ENETUNREACH,
|
||||||
|
errno.EHOSTUNREACH,
|
||||||
|
errno.ETIMEDOUT,
|
||||||
|
errno.ECONNREFUSED,
|
||||||
|
}:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _DestinationState:
|
||||||
|
online: bool = True
|
||||||
|
failure_count: int = 0
|
||||||
|
cooldown_until: datetime | None = None
|
||||||
@@ -18,8 +18,14 @@ from collections import deque
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from email.utils import parsedate_to_datetime
|
from email.utils import parsedate_to_datetime
|
||||||
|
from urllib.parse import urlparse
|
||||||
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
|
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
|
||||||
from ..services.settings_manager import get_settings_manager
|
from ..services.settings_manager import get_settings_manager
|
||||||
|
from .connectivity_guard import (
|
||||||
|
OFFLINE_COOLDOWN_ERROR,
|
||||||
|
OFFLINE_FRIENDLY_MESSAGE,
|
||||||
|
ConnectivityGuard,
|
||||||
|
)
|
||||||
from .errors import RateLimitError
|
from .errors import RateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -797,6 +803,10 @@ class Downloader:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested)
|
Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested)
|
||||||
"""
|
"""
|
||||||
|
guard = await ConnectivityGuard.get_instance()
|
||||||
|
if guard.should_block_request():
|
||||||
|
return False, OFFLINE_FRIENDLY_MESSAGE, None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session = await self.session
|
session = await self.session
|
||||||
# Debug log for proxy mode at request time
|
# Debug log for proxy mode at request time
|
||||||
@@ -819,6 +829,7 @@ class Downloader:
|
|||||||
) as response:
|
) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
content = await response.read()
|
content = await response.read()
|
||||||
|
guard.register_success(destination)
|
||||||
if return_headers:
|
if return_headers:
|
||||||
return True, content, dict(response.headers)
|
return True, content, dict(response.headers)
|
||||||
else:
|
else:
|
||||||
@@ -837,6 +848,12 @@ class Downloader:
|
|||||||
return False, error_msg, None
|
return False, error_msg, None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if guard.is_network_unreachable_error(e):
|
||||||
|
guard.register_network_failure(e)
|
||||||
|
if guard.should_block_request():
|
||||||
|
return False, OFFLINE_FRIENDLY_MESSAGE, None
|
||||||
|
logger.debug("Network unavailable during memory download: %s", e)
|
||||||
|
return False, str(e), None
|
||||||
logger.error(f"Error downloading to memory from {url}: {e}")
|
logger.error(f"Error downloading to memory from {url}: {e}")
|
||||||
return False, str(e), None
|
return False, str(e), None
|
||||||
|
|
||||||
@@ -857,6 +874,11 @@ class Downloader:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
|
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
|
||||||
"""
|
"""
|
||||||
|
guard = await ConnectivityGuard.get_instance()
|
||||||
|
destination = self._guard_destination(url)
|
||||||
|
if guard.should_block_request(destination):
|
||||||
|
return False, OFFLINE_COOLDOWN_ERROR
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session = await self.session
|
session = await self.session
|
||||||
# Debug log for proxy mode at request time
|
# Debug log for proxy mode at request time
|
||||||
@@ -878,11 +900,18 @@ class Downloader:
|
|||||||
url, headers=headers, proxy=self.proxy_url
|
url, headers=headers, proxy=self.proxy_url
|
||||||
) as response:
|
) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
|
guard.register_success(destination)
|
||||||
return True, dict(response.headers)
|
return True, dict(response.headers)
|
||||||
else:
|
else:
|
||||||
return False, f"Head request failed with status {response.status}"
|
return False, f"Head request failed with status {response.status}"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if guard.is_network_unreachable_error(e):
|
||||||
|
guard.register_network_failure(e, destination)
|
||||||
|
if guard.should_block_request(destination):
|
||||||
|
return False, OFFLINE_COOLDOWN_ERROR
|
||||||
|
logger.debug("Network unavailable during header probe: %s", e)
|
||||||
|
return False, str(e)
|
||||||
logger.error(f"Error getting headers from {url}: {e}")
|
logger.error(f"Error getting headers from {url}: {e}")
|
||||||
return False, str(e)
|
return False, str(e)
|
||||||
|
|
||||||
@@ -907,6 +936,11 @@ class Downloader:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, Union[Dict, str]]: (success, response data or error message)
|
Tuple[bool, Union[Dict, str]]: (success, response data or error message)
|
||||||
"""
|
"""
|
||||||
|
guard = await ConnectivityGuard.get_instance()
|
||||||
|
destination = self._guard_destination(url)
|
||||||
|
if guard.should_block_request(destination):
|
||||||
|
return False, OFFLINE_COOLDOWN_ERROR
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session = await self.session
|
session = await self.session
|
||||||
# Debug log for proxy mode at request time
|
# Debug log for proxy mode at request time
|
||||||
@@ -930,6 +964,7 @@ class Downloader:
|
|||||||
method, url, headers=headers, **kwargs
|
method, url, headers=headers, **kwargs
|
||||||
) as response:
|
) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
|
guard.register_success(destination)
|
||||||
# Try to parse as JSON, fall back to text
|
# Try to parse as JSON, fall back to text
|
||||||
try:
|
try:
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
@@ -960,6 +995,12 @@ class Downloader:
|
|||||||
return False, f"Request failed with status {response.status}"
|
return False, f"Request failed with status {response.status}"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if guard.is_network_unreachable_error(e):
|
||||||
|
guard.register_network_failure(e, destination)
|
||||||
|
if guard.should_block_request(destination):
|
||||||
|
return False, OFFLINE_COOLDOWN_ERROR
|
||||||
|
logger.debug("Network unavailable for %s %s: %s", method, url, e)
|
||||||
|
return False, str(e)
|
||||||
logger.error(f"Error making {method} request to {url}: {e}")
|
logger.error(f"Error making {method} request to {url}: {e}")
|
||||||
return False, str(e)
|
return False, str(e)
|
||||||
|
|
||||||
@@ -1010,6 +1051,14 @@ class Downloader:
|
|||||||
delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo)
|
delta = retry_datetime - datetime.now(tz=retry_datetime.tzinfo)
|
||||||
return max(0.0, delta.total_seconds())
|
return max(0.0, delta.total_seconds())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _guard_destination(url: str) -> str:
|
||||||
|
"""Build per-destination connectivity guard scope from request URL."""
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
if parsed_url.hostname:
|
||||||
|
return parsed_url.hostname.lower()
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
# Global instance accessor
|
# Global instance accessor
|
||||||
async def get_downloader() -> Downloader:
|
async def get_downloader() -> Downloader:
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
|||||||
from ..services.settings_manager import SettingsManager
|
from ..services.settings_manager import SettingsManager
|
||||||
from ..utils.civitai_utils import resolve_license_payload
|
from ..utils.civitai_utils import resolve_license_payload
|
||||||
from ..utils.model_utils import determine_base_model
|
from ..utils.model_utils import determine_base_model
|
||||||
|
from .connectivity_guard import OFFLINE_FRIENDLY_MESSAGE, is_expected_offline_error
|
||||||
from .errors import RateLimitError
|
from .errors import RateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -274,11 +275,18 @@ class MetadataSyncService:
|
|||||||
else "No provider returned metadata"
|
else "No provider returned metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
resolved_error = last_error or default_error
|
||||||
|
if is_expected_offline_error(resolved_error):
|
||||||
|
resolved_error = OFFLINE_FRIENDLY_MESSAGE
|
||||||
|
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"Error fetching metadata: {last_error or default_error} "
|
f"Error fetching metadata: {resolved_error} "
|
||||||
f"(model_name={model_data.get('model_name', '')})"
|
f"(model_name={model_data.get('model_name', '')})"
|
||||||
)
|
)
|
||||||
logger.error(error_msg)
|
if is_expected_offline_error(resolved_error):
|
||||||
|
logger.info(error_msg)
|
||||||
|
else:
|
||||||
|
logger.error(error_msg)
|
||||||
return False, error_msg
|
return False, error_msg
|
||||||
|
|
||||||
model_data["from_civitai"] = True
|
model_data["from_civitai"] = True
|
||||||
@@ -347,6 +355,9 @@ class MetadataSyncService:
|
|||||||
return False, error_msg
|
return False, error_msg
|
||||||
except Exception as exc: # pragma: no cover - error path
|
except Exception as exc: # pragma: no cover - error path
|
||||||
error_msg = f"Error fetching metadata: {exc}"
|
error_msg = f"Error fetching metadata: {exc}"
|
||||||
|
if is_expected_offline_error(str(exc)):
|
||||||
|
logger.info(OFFLINE_FRIENDLY_MESSAGE)
|
||||||
|
return False, OFFLINE_FRIENDLY_MESSAGE
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
return False, error_msg
|
return False, error_msg
|
||||||
|
|
||||||
|
|||||||
@@ -763,34 +763,29 @@ class SettingsManager:
|
|||||||
if self._preserve_disk_template:
|
if self._preserve_disk_template:
|
||||||
return
|
return
|
||||||
|
|
||||||
folder_paths = self.settings.get("folder_paths", {})
|
|
||||||
updated = False
|
updated = False
|
||||||
|
|
||||||
def _check_and_auto_set(key: str, setting_key: str) -> bool:
|
def _check_and_auto_set(key: str, setting_key: str) -> bool:
|
||||||
"""Repair default roots when empty or no longer present."""
|
"""Repair default roots when empty or no longer present."""
|
||||||
current = self.settings.get(setting_key, "")
|
current = self.settings.get(setting_key, "")
|
||||||
candidates = folder_paths.get(key, [])
|
primary_candidates = self._get_valid_root_candidates(key)
|
||||||
if not isinstance(candidates, list) or not candidates:
|
if not primary_candidates:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Filter valid path strings
|
allowed_roots = self._get_allowed_roots(key)
|
||||||
valid_paths = [p for p in candidates if isinstance(p, str) and p.strip()]
|
if current and current in allowed_roots:
|
||||||
if not valid_paths:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if current in valid_paths:
|
self.settings[setting_key] = primary_candidates[0]
|
||||||
return False
|
|
||||||
|
|
||||||
self.settings[setting_key] = valid_paths[0]
|
|
||||||
if current:
|
if current:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Repaired stale %s from '%s' to '%s'",
|
"Repaired stale %s from '%s' to '%s' because it is not present in primary or extra roots",
|
||||||
setting_key,
|
setting_key,
|
||||||
current,
|
current,
|
||||||
valid_paths[0],
|
primary_candidates[0],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Auto-set %s to '%s'", setting_key, valid_paths[0])
|
logger.info("Auto-set %s to '%s'", setting_key, primary_candidates[0])
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Process all model types
|
# Process all model types
|
||||||
@@ -813,6 +808,33 @@ class SettingsManager:
|
|||||||
else:
|
else:
|
||||||
self._save_settings()
|
self._save_settings()
|
||||||
|
|
||||||
|
def _get_valid_root_candidates(self, key: str) -> List[str]:
|
||||||
|
"""Return stable root candidates, preferring primary roots over extra roots."""
|
||||||
|
|
||||||
|
candidates: List[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for mapping_key in ("folder_paths", "extra_folder_paths"):
|
||||||
|
raw_paths = self.settings.get(mapping_key, {})
|
||||||
|
if not isinstance(raw_paths, Mapping):
|
||||||
|
continue
|
||||||
|
values = raw_paths.get(key, [])
|
||||||
|
if not isinstance(values, list):
|
||||||
|
continue
|
||||||
|
for value in values:
|
||||||
|
if not isinstance(value, str):
|
||||||
|
continue
|
||||||
|
normalized = value.strip()
|
||||||
|
if not normalized or normalized in seen:
|
||||||
|
continue
|
||||||
|
seen.add(normalized)
|
||||||
|
candidates.append(normalized)
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
def _get_allowed_roots(self, key: str) -> set[str]:
|
||||||
|
"""Return all valid roots for a model type, including extra roots."""
|
||||||
|
|
||||||
|
return set(self._get_valid_root_candidates(key))
|
||||||
|
|
||||||
def _check_environment_variables(self) -> None:
|
def _check_environment_variables(self) -> None:
|
||||||
"""Check for environment variables and update settings if needed"""
|
"""Check for environment variables and update settings if needed"""
|
||||||
env_api_key = os.environ.get("CIVITAI_API_KEY")
|
env_api_key = os.environ.get("CIVITAI_API_KEY")
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ def test_save_paths_renames_default_library(monkeypatch: pytest.MonkeyPatch, tmp
|
|||||||
self.delete_calls = []
|
self.delete_calls = []
|
||||||
self.upsert_calls = []
|
self.upsert_calls = []
|
||||||
self._renamed = False
|
self._renamed = False
|
||||||
|
self.active_library = "default"
|
||||||
|
|
||||||
def get_libraries(self):
|
def get_libraries(self):
|
||||||
if self._renamed:
|
if self._renamed:
|
||||||
@@ -62,6 +63,11 @@ def test_save_paths_renames_default_library(monkeypatch: pytest.MonkeyPatch, tmp
|
|||||||
def rename_library(self, old_name: str, new_name: str):
|
def rename_library(self, old_name: str, new_name: str):
|
||||||
self.rename_calls.append((old_name, new_name))
|
self.rename_calls.append((old_name, new_name))
|
||||||
self._renamed = True
|
self._renamed = True
|
||||||
|
if self.active_library == old_name:
|
||||||
|
self.active_library = new_name
|
||||||
|
|
||||||
|
def get_active_library_name(self):
|
||||||
|
return self.active_library
|
||||||
|
|
||||||
def delete_library(self, name: str): # pragma: no cover - defensive guard
|
def delete_library(self, name: str): # pragma: no cover - defensive guard
|
||||||
self.delete_calls.append(name)
|
self.delete_calls.append(name)
|
||||||
@@ -104,6 +110,7 @@ def test_save_paths_logs_warning_when_upsert_fails(
|
|||||||
class RaisingSettingsService:
|
class RaisingSettingsService:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.upsert_attempts = []
|
self.upsert_attempts = []
|
||||||
|
self.active_library = "comfyui"
|
||||||
|
|
||||||
def get_libraries(self):
|
def get_libraries(self):
|
||||||
return {
|
return {
|
||||||
@@ -116,6 +123,9 @@ def test_save_paths_logs_warning_when_upsert_fails(
|
|||||||
def rename_library(self, *_):
|
def rename_library(self, *_):
|
||||||
raise AssertionError("rename_library should not be invoked")
|
raise AssertionError("rename_library should not be invoked")
|
||||||
|
|
||||||
|
def get_active_library_name(self):
|
||||||
|
return self.active_library
|
||||||
|
|
||||||
def upsert_library(self, name: str, **payload):
|
def upsert_library(self, name: str, **payload):
|
||||||
self.upsert_attempts.append((name, payload))
|
self.upsert_attempts.append((name, payload))
|
||||||
raise RuntimeError("boom")
|
raise RuntimeError("boom")
|
||||||
@@ -135,6 +145,8 @@ def test_save_paths_repairs_empty_default_roots(monkeypatch: pytest.MonkeyPatch,
|
|||||||
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
||||||
|
|
||||||
class FakeSettingsService:
|
class FakeSettingsService:
|
||||||
|
active_library = "comfyui"
|
||||||
|
|
||||||
def get_libraries(self):
|
def get_libraries(self):
|
||||||
return {
|
return {
|
||||||
"comfyui": {
|
"comfyui": {
|
||||||
@@ -148,6 +160,9 @@ def test_save_paths_repairs_empty_default_roots(monkeypatch: pytest.MonkeyPatch,
|
|||||||
def rename_library(self, *_):
|
def rename_library(self, *_):
|
||||||
raise AssertionError("rename_library should not be invoked")
|
raise AssertionError("rename_library should not be invoked")
|
||||||
|
|
||||||
|
def get_active_library_name(self):
|
||||||
|
return self.active_library
|
||||||
|
|
||||||
def upsert_library(self, name: str, **payload):
|
def upsert_library(self, name: str, **payload):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
@@ -167,6 +182,8 @@ def test_save_paths_repairs_stale_default_roots(monkeypatch: pytest.MonkeyPatch,
|
|||||||
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
||||||
|
|
||||||
class FakeSettingsService:
|
class FakeSettingsService:
|
||||||
|
active_library = "comfyui"
|
||||||
|
|
||||||
def get_libraries(self):
|
def get_libraries(self):
|
||||||
return {
|
return {
|
||||||
"comfyui": {
|
"comfyui": {
|
||||||
@@ -180,6 +197,9 @@ def test_save_paths_repairs_stale_default_roots(monkeypatch: pytest.MonkeyPatch,
|
|||||||
def rename_library(self, *_):
|
def rename_library(self, *_):
|
||||||
raise AssertionError("rename_library should not be invoked")
|
raise AssertionError("rename_library should not be invoked")
|
||||||
|
|
||||||
|
def get_active_library_name(self):
|
||||||
|
return self.active_library
|
||||||
|
|
||||||
def upsert_library(self, name: str, **payload):
|
def upsert_library(self, name: str, **payload):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
@@ -199,6 +219,8 @@ def test_save_paths_keeps_valid_default_roots(monkeypatch: pytest.MonkeyPatch, t
|
|||||||
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
||||||
|
|
||||||
class FakeSettingsService:
|
class FakeSettingsService:
|
||||||
|
active_library = "comfyui"
|
||||||
|
|
||||||
def get_libraries(self):
|
def get_libraries(self):
|
||||||
return {
|
return {
|
||||||
"comfyui": {
|
"comfyui": {
|
||||||
@@ -212,6 +234,9 @@ def test_save_paths_keeps_valid_default_roots(monkeypatch: pytest.MonkeyPatch, t
|
|||||||
def rename_library(self, *_):
|
def rename_library(self, *_):
|
||||||
raise AssertionError("rename_library should not be invoked")
|
raise AssertionError("rename_library should not be invoked")
|
||||||
|
|
||||||
|
def get_active_library_name(self):
|
||||||
|
return self.active_library
|
||||||
|
|
||||||
def upsert_library(self, name: str, **payload):
|
def upsert_library(self, name: str, **payload):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
@@ -258,6 +283,7 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path):
|
|||||||
self.rename_calls = []
|
self.rename_calls = []
|
||||||
self.delete_calls = []
|
self.delete_calls = []
|
||||||
self.upsert_calls = []
|
self.upsert_calls = []
|
||||||
|
self.active_library = "default"
|
||||||
|
|
||||||
def get_libraries(self):
|
def get_libraries(self):
|
||||||
return self.libraries
|
return self.libraries
|
||||||
@@ -265,6 +291,8 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path):
|
|||||||
def rename_library(self, old_name: str, new_name: str):
|
def rename_library(self, old_name: str, new_name: str):
|
||||||
self.rename_calls.append((old_name, new_name))
|
self.rename_calls.append((old_name, new_name))
|
||||||
self.libraries[new_name] = self.libraries.pop(old_name)
|
self.libraries[new_name] = self.libraries.pop(old_name)
|
||||||
|
if self.active_library == old_name:
|
||||||
|
self.active_library = new_name
|
||||||
|
|
||||||
def delete_library(self, name: str):
|
def delete_library(self, name: str):
|
||||||
self.delete_calls.append(name)
|
self.delete_calls.append(name)
|
||||||
@@ -273,6 +301,11 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path):
|
|||||||
def upsert_library(self, name: str, **payload):
|
def upsert_library(self, name: str, **payload):
|
||||||
self.upsert_calls.append((name, payload))
|
self.upsert_calls.append((name, payload))
|
||||||
self.libraries[name] = {**payload}
|
self.libraries[name] = {**payload}
|
||||||
|
if payload.get("activate"):
|
||||||
|
self.active_library = name
|
||||||
|
|
||||||
|
def get_active_library_name(self):
|
||||||
|
return self.active_library
|
||||||
|
|
||||||
fake_settings = FakeSettingsService()
|
fake_settings = FakeSettingsService()
|
||||||
monkeypatch.setattr(settings_manager_module, "settings", fake_settings)
|
monkeypatch.setattr(settings_manager_module, "settings", fake_settings)
|
||||||
@@ -313,6 +346,156 @@ def test_save_paths_removes_template_default_library(monkeypatch, tmp_path):
|
|||||||
assert payload["activate"] is True
|
assert payload["activate"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_paths_keeps_default_roots_in_extra_paths(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||||
|
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
||||||
|
extra_lora_dir = tmp_path / "extra_loras"
|
||||||
|
extra_checkpoint_dir = tmp_path / "extra_checkpoints"
|
||||||
|
extra_embedding_dir = tmp_path / "extra_embeddings"
|
||||||
|
|
||||||
|
for directory in (extra_lora_dir, extra_checkpoint_dir, extra_embedding_dir):
|
||||||
|
directory.mkdir()
|
||||||
|
|
||||||
|
class FakeSettingsService:
|
||||||
|
active_library = "comfyui"
|
||||||
|
|
||||||
|
def get_libraries(self):
|
||||||
|
return {
|
||||||
|
"comfyui": {
|
||||||
|
"folder_paths": {key: list(value) for key, value in folder_paths.items()},
|
||||||
|
"extra_folder_paths": {
|
||||||
|
"loras": [str(extra_lora_dir)],
|
||||||
|
"checkpoints": [str(extra_checkpoint_dir)],
|
||||||
|
"embeddings": [str(extra_embedding_dir)],
|
||||||
|
},
|
||||||
|
"default_lora_root": str(extra_lora_dir),
|
||||||
|
"default_checkpoint_root": str(extra_checkpoint_dir),
|
||||||
|
"default_embedding_root": str(extra_embedding_dir),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def rename_library(self, *_):
|
||||||
|
raise AssertionError("rename_library should not be invoked")
|
||||||
|
|
||||||
|
def get_active_library_name(self):
|
||||||
|
return self.active_library
|
||||||
|
|
||||||
|
def upsert_library(self, name: str, **payload):
|
||||||
|
self.name = name
|
||||||
|
self.payload = payload
|
||||||
|
|
||||||
|
fake_settings = FakeSettingsService()
|
||||||
|
monkeypatch.setattr(settings_manager_module, "settings", fake_settings)
|
||||||
|
|
||||||
|
config_module.Config()
|
||||||
|
|
||||||
|
assert fake_settings.name == "comfyui"
|
||||||
|
assert fake_settings.payload["extra_folder_paths"]["loras"] == [str(extra_lora_dir).replace("\\", "/")]
|
||||||
|
assert fake_settings.payload["extra_folder_paths"]["checkpoints"] == [
|
||||||
|
str(extra_checkpoint_dir).replace("\\", "/")
|
||||||
|
]
|
||||||
|
assert fake_settings.payload["extra_folder_paths"]["embeddings"] == [
|
||||||
|
str(extra_embedding_dir).replace("\\", "/")
|
||||||
|
]
|
||||||
|
assert fake_settings.payload["default_lora_root"] == str(extra_lora_dir).replace("\\", "/")
|
||||||
|
assert fake_settings.payload["default_checkpoint_root"] == str(extra_checkpoint_dir).replace("\\", "/")
|
||||||
|
assert fake_settings.payload["default_embedding_root"] == str(extra_embedding_dir).replace("\\", "/")
|
||||||
|
assert fake_settings.payload["activate"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_paths_repairs_empty_default_roots_to_extra_paths_when_primary_missing(
|
||||||
|
monkeypatch: pytest.MonkeyPatch, tmp_path
|
||||||
|
):
|
||||||
|
_setup_config_environment(monkeypatch, tmp_path)
|
||||||
|
extra_lora_dir = tmp_path / "extra_loras"
|
||||||
|
extra_lora_dir.mkdir()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config_module.folder_paths,
|
||||||
|
"get_folder_paths",
|
||||||
|
lambda kind: [] if kind == "loras" else [],
|
||||||
|
)
|
||||||
|
|
||||||
|
class FakeSettingsService:
|
||||||
|
active_library = "comfyui"
|
||||||
|
|
||||||
|
def get_libraries(self):
|
||||||
|
return {
|
||||||
|
"comfyui": {
|
||||||
|
"folder_paths": {
|
||||||
|
"loras": [],
|
||||||
|
"checkpoints": [],
|
||||||
|
"unet": [],
|
||||||
|
"embeddings": [],
|
||||||
|
},
|
||||||
|
"extra_folder_paths": {
|
||||||
|
"loras": [str(extra_lora_dir)],
|
||||||
|
},
|
||||||
|
"default_lora_root": "",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def rename_library(self, *_):
|
||||||
|
raise AssertionError("rename_library should not be invoked")
|
||||||
|
|
||||||
|
def get_active_library_name(self):
|
||||||
|
return self.active_library
|
||||||
|
|
||||||
|
def upsert_library(self, name: str, **payload):
|
||||||
|
self.name = name
|
||||||
|
self.payload = payload
|
||||||
|
|
||||||
|
fake_settings = FakeSettingsService()
|
||||||
|
monkeypatch.setattr(settings_manager_module, "settings", fake_settings)
|
||||||
|
|
||||||
|
config_module.Config()
|
||||||
|
|
||||||
|
assert fake_settings.name == "comfyui"
|
||||||
|
assert fake_settings.payload["default_lora_root"] == str(extra_lora_dir).replace("\\", "/")
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_paths_does_not_activate_comfyui_library_when_another_library_is_active(
|
||||||
|
monkeypatch: pytest.MonkeyPatch, tmp_path
|
||||||
|
):
|
||||||
|
folder_paths = _setup_config_environment(monkeypatch, tmp_path)
|
||||||
|
|
||||||
|
class FakeSettingsService:
|
||||||
|
def __init__(self):
|
||||||
|
self.active_library = "studio"
|
||||||
|
self.upsert_calls = []
|
||||||
|
|
||||||
|
def get_libraries(self):
|
||||||
|
return {
|
||||||
|
"studio": {
|
||||||
|
"folder_paths": {"loras": ["/studio/loras"]},
|
||||||
|
},
|
||||||
|
"comfyui": {
|
||||||
|
"folder_paths": {key: list(value) for key, value in folder_paths.items()},
|
||||||
|
"default_lora_root": folder_paths["loras"][0],
|
||||||
|
"default_checkpoint_root": folder_paths["checkpoints"][0],
|
||||||
|
"default_embedding_root": folder_paths["embeddings"][0],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def rename_library(self, *_):
|
||||||
|
raise AssertionError("rename_library should not be invoked")
|
||||||
|
|
||||||
|
def get_active_library_name(self):
|
||||||
|
return self.active_library
|
||||||
|
|
||||||
|
def upsert_library(self, name: str, **payload):
|
||||||
|
self.upsert_calls.append((name, payload))
|
||||||
|
|
||||||
|
fake_settings = FakeSettingsService()
|
||||||
|
monkeypatch.setattr(settings_manager_module, "settings", fake_settings)
|
||||||
|
|
||||||
|
config_module.Config()
|
||||||
|
|
||||||
|
assert len(fake_settings.upsert_calls) == 1
|
||||||
|
name, payload = fake_settings.upsert_calls[0]
|
||||||
|
assert name == "comfyui"
|
||||||
|
assert payload["activate"] is False
|
||||||
|
|
||||||
|
|
||||||
def test_apply_library_settings_merges_extra_paths(monkeypatch, tmp_path):
|
def test_apply_library_settings_merges_extra_paths(monkeypatch, tmp_path):
|
||||||
"""Test that apply_library_settings correctly merges folder_paths with extra_folder_paths."""
|
"""Test that apply_library_settings correctly merges folder_paths with extra_folder_paths."""
|
||||||
loras_dir = tmp_path / "loras"
|
loras_dir = tmp_path / "loras"
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import pytest
|
|||||||
|
|
||||||
from py.services import civitai_client as civitai_client_module
|
from py.services import civitai_client as civitai_client_module
|
||||||
from py.services.civitai_client import CivitaiClient
|
from py.services.civitai_client import CivitaiClient
|
||||||
|
from py.services.connectivity_guard import OFFLINE_COOLDOWN_ERROR, OFFLINE_FRIENDLY_MESSAGE
|
||||||
from py.services.errors import RateLimitError, ResourceNotFoundError
|
from py.services.errors import RateLimitError, ResourceNotFoundError
|
||||||
from py.services.model_metadata_provider import ModelMetadataProviderManager
|
from py.services.model_metadata_provider import ModelMetadataProviderManager
|
||||||
|
|
||||||
@@ -115,6 +116,20 @@ async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
|
|||||||
assert error == "Model not found"
|
assert error == "Model not found"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_model_by_hash_handles_offline_cooldown(downloader):
|
||||||
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
|
return False, OFFLINE_COOLDOWN_ERROR
|
||||||
|
|
||||||
|
downloader.make_request = fake_make_request
|
||||||
|
|
||||||
|
client = await CivitaiClient.get_instance()
|
||||||
|
|
||||||
|
result, error = await client.get_model_by_hash("missing")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
assert error == OFFLINE_FRIENDLY_MESSAGE
|
||||||
|
|
||||||
|
|
||||||
async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader):
|
async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader):
|
||||||
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
return False, RateLimitError("limited", retry_after=4)
|
return False, RateLimitError("limited", retry_after=4)
|
||||||
|
|||||||
125
tests/services/test_connectivity_guard.py
Normal file
125
tests/services/test_connectivity_guard.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import asyncio
|
||||||
|
import errno
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.connectivity_guard import (
|
||||||
|
OFFLINE_COOLDOWN_ERROR,
|
||||||
|
OFFLINE_FRIENDLY_MESSAGE,
|
||||||
|
ConnectivityGuard,
|
||||||
|
)
|
||||||
|
from py.services.downloader import Downloader
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_connectivity_guard_singleton():
|
||||||
|
ConnectivityGuard._instance = None
|
||||||
|
yield
|
||||||
|
ConnectivityGuard._instance = None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_connectivity_guard_enters_cooldown_after_threshold():
|
||||||
|
guard = await ConnectivityGuard.get_instance()
|
||||||
|
|
||||||
|
assert guard.online is True
|
||||||
|
assert guard.should_block_request() is False
|
||||||
|
|
||||||
|
guard.register_network_failure(OSError(errno.ENETUNREACH, "unreachable"))
|
||||||
|
guard.register_network_failure(asyncio.TimeoutError("timeout"))
|
||||||
|
|
||||||
|
assert guard.should_block_request() is False
|
||||||
|
assert guard.failure_count == 2
|
||||||
|
|
||||||
|
guard.register_network_failure(ConnectionRefusedError("refused"))
|
||||||
|
|
||||||
|
assert guard.online is False
|
||||||
|
assert guard.failure_count == 3
|
||||||
|
assert guard.should_block_request() is True
|
||||||
|
assert guard.cooldown_remaining_seconds() > 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_connectivity_guard_scopes_cooldown_to_destination():
|
||||||
|
guard = await ConnectivityGuard.get_instance()
|
||||||
|
|
||||||
|
destination_a = "civitai.com"
|
||||||
|
destination_b = "api.github.com"
|
||||||
|
|
||||||
|
guard.register_network_failure(
|
||||||
|
OSError(errno.ENETUNREACH, "unreachable"),
|
||||||
|
destination_a,
|
||||||
|
)
|
||||||
|
guard.register_network_failure(asyncio.TimeoutError("timeout"), destination_a)
|
||||||
|
guard.register_network_failure(ConnectionRefusedError("refused"), destination_a)
|
||||||
|
|
||||||
|
assert guard.should_block_request(destination_a) is True
|
||||||
|
assert guard.should_block_request(destination_b) is False
|
||||||
|
|
||||||
|
guard.register_success(destination_a)
|
||||||
|
assert guard.should_block_request(destination_a) is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_connectivity_guard_recovers_after_success():
|
||||||
|
guard = await ConnectivityGuard.get_instance()
|
||||||
|
guard.online = False
|
||||||
|
guard.failure_count = 5
|
||||||
|
guard.cooldown_until = datetime.now() + timedelta(seconds=90)
|
||||||
|
|
||||||
|
guard.register_success()
|
||||||
|
|
||||||
|
assert guard.online is True
|
||||||
|
assert guard.failure_count == 0
|
||||||
|
assert guard.cooldown_until is None
|
||||||
|
assert guard.should_block_request() is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_downloader_short_circuits_all_request_helpers_during_cooldown():
|
||||||
|
guard = await ConnectivityGuard.get_instance()
|
||||||
|
destination = "example.invalid"
|
||||||
|
guard.register_network_failure(
|
||||||
|
OSError(errno.ENETUNREACH, "unreachable"),
|
||||||
|
destination,
|
||||||
|
)
|
||||||
|
guard.register_network_failure(asyncio.TimeoutError("timeout"), destination)
|
||||||
|
guard.register_network_failure(
|
||||||
|
ConnectionRefusedError("refused"),
|
||||||
|
destination,
|
||||||
|
)
|
||||||
|
|
||||||
|
downloader = Downloader()
|
||||||
|
|
||||||
|
ok, payload = await downloader.make_request("GET", f"https://{destination}")
|
||||||
|
assert ok is False
|
||||||
|
assert payload == OFFLINE_COOLDOWN_ERROR
|
||||||
|
|
||||||
|
ok, payload, headers = await downloader.download_to_memory(f"https://{destination}")
|
||||||
|
assert ok is False
|
||||||
|
assert payload == OFFLINE_FRIENDLY_MESSAGE
|
||||||
|
assert headers is None
|
||||||
|
|
||||||
|
ok, payload = await downloader.get_response_headers(f"https://{destination}")
|
||||||
|
assert ok is False
|
||||||
|
assert payload == OFFLINE_COOLDOWN_ERROR
|
||||||
|
|
||||||
|
|
||||||
|
async def test_downloader_only_short_circuits_requests_for_same_destination():
|
||||||
|
guard = await ConnectivityGuard.get_instance()
|
||||||
|
guard.register_network_failure(
|
||||||
|
OSError(errno.ENETUNREACH, "unreachable"),
|
||||||
|
"example.invalid",
|
||||||
|
)
|
||||||
|
guard.register_network_failure(asyncio.TimeoutError("timeout"), "example.invalid")
|
||||||
|
guard.register_network_failure(
|
||||||
|
ConnectionRefusedError("refused"),
|
||||||
|
"example.invalid",
|
||||||
|
)
|
||||||
|
|
||||||
|
downloader = Downloader()
|
||||||
|
ok, payload = await downloader.make_request("GET", "https://example.invalid")
|
||||||
|
assert ok is False
|
||||||
|
assert payload == OFFLINE_COOLDOWN_ERROR
|
||||||
|
|
||||||
|
assert (
|
||||||
|
guard.should_block_request(downloader._guard_destination("https://example.com"))
|
||||||
|
is False
|
||||||
|
)
|
||||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from py.services.connectivity_guard import OFFLINE_COOLDOWN_ERROR, OFFLINE_FRIENDLY_MESSAGE
|
||||||
from py.services.errors import RateLimitError
|
from py.services.errors import RateLimitError
|
||||||
from py.services.metadata_sync_service import MetadataSyncService
|
from py.services.metadata_sync_service import MetadataSyncService
|
||||||
|
|
||||||
@@ -243,17 +244,32 @@ async def test_fetch_and_update_model_handles_missing_remote_metadata(tmp_path):
|
|||||||
|
|
||||||
assert not ok
|
assert not ok
|
||||||
assert "Model not found" in error
|
assert "Model not found" in error
|
||||||
assert model_data["from_civitai"] is False
|
|
||||||
assert model_data["civitai_deleted"] is True
|
|
||||||
|
|
||||||
helpers.metadata_manager.hydrate_model_data.assert_not_awaited()
|
|
||||||
assert model_data["hydrated"] is True
|
|
||||||
|
|
||||||
helpers.metadata_manager.save_metadata.assert_awaited_once()
|
@pytest.mark.asyncio
|
||||||
call_args = helpers.metadata_manager.save_metadata.await_args
|
async def test_fetch_and_update_model_returns_friendly_offline_message(tmp_path):
|
||||||
assert call_args.args[0].endswith("model.safetensors")
|
helpers = build_service()
|
||||||
assert "folder" not in call_args.args[1]
|
helpers.default_provider.get_model_by_hash.return_value = (None, OFFLINE_COOLDOWN_ERROR)
|
||||||
assert call_args.args[1]["hydrated"] is True
|
|
||||||
|
model_path = tmp_path / "model.safetensors"
|
||||||
|
model_data = {
|
||||||
|
"model_name": "Local",
|
||||||
|
"folder": "root",
|
||||||
|
"file_path": str(model_path),
|
||||||
|
}
|
||||||
|
update_cache = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
ok, error = await helpers.service.fetch_and_update_model(
|
||||||
|
sha256="abc",
|
||||||
|
file_path=str(model_path),
|
||||||
|
model_data=model_data,
|
||||||
|
update_cache_func=update_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ok is False
|
||||||
|
assert error is not None
|
||||||
|
assert OFFLINE_FRIENDLY_MESSAGE in error
|
||||||
|
update_cache.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -332,6 +332,43 @@ def test_auto_set_default_roots_keeps_valid_values(manager):
|
|||||||
assert manager.get("default_embedding_root") == "/embeddings"
|
assert manager.get("default_embedding_root") == "/embeddings"
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_set_default_roots_keeps_valid_extra_values(manager):
|
||||||
|
manager.settings["default_lora_root"] = "/extra-loras"
|
||||||
|
manager.settings["default_checkpoint_root"] = "/extra-checkpoints"
|
||||||
|
manager.settings["default_embedding_root"] = "/extra-embeddings"
|
||||||
|
manager.settings["default_unet_root"] = "/extra-unet"
|
||||||
|
|
||||||
|
manager.settings["folder_paths"] = {
|
||||||
|
"loras": ["/loras"],
|
||||||
|
"checkpoints": ["/checkpoints"],
|
||||||
|
"unet": ["/unet"],
|
||||||
|
"embeddings": ["/embeddings"],
|
||||||
|
}
|
||||||
|
manager.settings["extra_folder_paths"] = {
|
||||||
|
"loras": ["/extra-loras"],
|
||||||
|
"checkpoints": ["/extra-checkpoints"],
|
||||||
|
"unet": ["/extra-unet"],
|
||||||
|
"embeddings": ["/extra-embeddings"],
|
||||||
|
}
|
||||||
|
|
||||||
|
manager._auto_set_default_roots()
|
||||||
|
|
||||||
|
assert manager.get("default_lora_root") == "/extra-loras"
|
||||||
|
assert manager.get("default_checkpoint_root") == "/extra-checkpoints"
|
||||||
|
assert manager.get("default_unet_root") == "/extra-unet"
|
||||||
|
assert manager.get("default_embedding_root") == "/extra-embeddings"
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_set_default_roots_falls_back_to_extra_when_primary_missing(manager):
|
||||||
|
manager.settings["default_lora_root"] = ""
|
||||||
|
manager.settings["folder_paths"] = {"loras": []}
|
||||||
|
manager.settings["extra_folder_paths"] = {"loras": ["/extra-loras"]}
|
||||||
|
|
||||||
|
manager._auto_set_default_roots()
|
||||||
|
|
||||||
|
assert manager.get("default_lora_root") == "/extra-loras"
|
||||||
|
|
||||||
|
|
||||||
def test_delete_setting(manager):
|
def test_delete_setting(manager):
|
||||||
manager.set("example", 1)
|
manager.set("example", 1)
|
||||||
manager.delete("example")
|
manager.delete("example")
|
||||||
|
|||||||
Reference in New Issue
Block a user