mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat: implement same_base update strategy for model annotations
Add support for configurable update flag strategy with new "same_base" mode that considers base model versions when determining update availability. The strategy is controlled by the "update_flag_strategy" setting. When strategy is set to "same_base": - Uses get_records_bulk instead of has_updates_bulk - Compares model versions against highest local versions per base model - Provides more granular update detection based on base model relationships Fallback to existing bulk or individual update checks when: - Strategy is not "same_base" - Bulk operations fail - Records are unavailable This enables more precise update flagging for models sharing common bases.
This commit is contained in:
@@ -17,6 +17,41 @@ from ..utils.preview_selection import select_preview_media
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_int(value) -> Optional[int]:
|
||||
"""Safely convert a value to an integer."""
|
||||
|
||||
try:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_string(value) -> Optional[str]:
|
||||
"""Return a stripped string or None if the value is empty."""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
return stripped or None
|
||||
try:
|
||||
normalized = str(value).strip()
|
||||
return normalized or None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_base_model(value) -> Optional[str]:
|
||||
"""Normalize base-model names for case-insensitive comparison."""
|
||||
|
||||
normalized = _normalize_string(value)
|
||||
if normalized is None:
|
||||
return None
|
||||
return normalized.lower()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelVersionRecord:
|
||||
"""Persisted metadata for a single model version."""
|
||||
@@ -85,6 +120,47 @@ class ModelUpdateRecord:
|
||||
return True
|
||||
return False
|
||||
|
||||
def has_update_for_base(
|
||||
self,
|
||||
local_version_id: Optional[int],
|
||||
local_base_model: Optional[str],
|
||||
) -> bool:
|
||||
"""Return True when a newer remote version with the same base model exists."""
|
||||
|
||||
if self.should_ignore_model:
|
||||
return False
|
||||
|
||||
normalized_base = _normalize_base_model(local_base_model)
|
||||
if normalized_base is None:
|
||||
return False
|
||||
|
||||
threshold = _normalize_int(local_version_id)
|
||||
if threshold is None:
|
||||
highest_local = None
|
||||
for version in self.versions:
|
||||
if not version.is_in_library:
|
||||
continue
|
||||
version_base = _normalize_base_model(version.base_model)
|
||||
if version_base != normalized_base:
|
||||
continue
|
||||
if highest_local is None or version.version_id > highest_local:
|
||||
highest_local = version.version_id
|
||||
threshold = highest_local
|
||||
|
||||
if threshold is None:
|
||||
return False
|
||||
|
||||
for version in self.versions:
|
||||
if version.is_in_library or version.should_ignore:
|
||||
continue
|
||||
version_base = _normalize_base_model(version.base_model)
|
||||
if version_base != normalized_base:
|
||||
continue
|
||||
if version.version_id > threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class ModelUpdateService:
|
||||
"""Persist and query remote model version metadata."""
|
||||
@@ -628,6 +704,20 @@ class ModelUpdateService:
|
||||
for model_id in normalized_ids
|
||||
}
|
||||
|
||||
async def get_records_bulk(
|
||||
self,
|
||||
model_type: str,
|
||||
model_ids: Sequence[int],
|
||||
) -> Dict[int, ModelUpdateRecord]:
|
||||
"""Return cached update records for the requested models."""
|
||||
|
||||
normalized_ids = self._normalize_sequence(model_ids)
|
||||
if not normalized_ids:
|
||||
return {}
|
||||
|
||||
async with self._lock:
|
||||
return self._get_records_bulk(model_type, normalized_ids)
|
||||
|
||||
async def _refresh_single_model(
|
||||
self,
|
||||
model_type: str,
|
||||
@@ -799,7 +889,7 @@ class ModelUpdateService:
|
||||
)
|
||||
continue
|
||||
for key, value in response.items():
|
||||
normalized_key = self._normalize_int(key)
|
||||
normalized_key = _normalize_int(key)
|
||||
if normalized_key is None:
|
||||
continue
|
||||
if isinstance(value, Mapping):
|
||||
@@ -832,8 +922,8 @@ class ModelUpdateService:
|
||||
civitai = item.get("civitai") if isinstance(item, dict) else None
|
||||
if not isinstance(civitai, dict):
|
||||
continue
|
||||
model_id = self._normalize_int(civitai.get("modelId"))
|
||||
version_id = self._normalize_int(civitai.get("id"))
|
||||
model_id = _normalize_int(civitai.get("modelId"))
|
||||
version_id = _normalize_int(civitai.get("id"))
|
||||
if model_id is None or version_id is None:
|
||||
continue
|
||||
if target_set is not None and model_id not in target_set:
|
||||
@@ -973,35 +1063,14 @@ class ModelUpdateService:
|
||||
return True
|
||||
return (now - record.last_checked_at) >= self._ttl_seconds
|
||||
|
||||
@staticmethod
|
||||
def _normalize_int(value) -> Optional[int]:
|
||||
try:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
def _normalize_sequence(self, values: Sequence[int]) -> List[int]:
|
||||
normalized = [
|
||||
item
|
||||
for item in (self._normalize_int(value) for value in values)
|
||||
for item in (_normalize_int(value) for value in values)
|
||||
if item is not None
|
||||
]
|
||||
return sorted(dict.fromkeys(normalized))
|
||||
|
||||
@staticmethod
|
||||
def _normalize_string(value) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
return stripped or None
|
||||
try:
|
||||
return str(value)
|
||||
except Exception: # pragma: no cover - defensive conversion
|
||||
return None
|
||||
|
||||
def _extract_versions(self, response) -> Optional[List[ModelVersionRecord]]:
|
||||
if not isinstance(response, Mapping):
|
||||
return None
|
||||
@@ -1014,12 +1083,12 @@ class ModelUpdateService:
|
||||
for index, entry in enumerate(versions):
|
||||
if not isinstance(entry, Mapping):
|
||||
continue
|
||||
version_id = self._normalize_int(entry.get("id"))
|
||||
version_id = _normalize_int(entry.get("id"))
|
||||
if version_id is None:
|
||||
continue
|
||||
name = self._normalize_string(entry.get("name"))
|
||||
base_model = self._normalize_string(entry.get("baseModel"))
|
||||
released_at = self._normalize_string(entry.get("publishedAt") or entry.get("createdAt"))
|
||||
name = _normalize_string(entry.get("name"))
|
||||
base_model = _normalize_string(entry.get("baseModel"))
|
||||
released_at = _normalize_string(entry.get("publishedAt") or entry.get("createdAt"))
|
||||
size_bytes = self._extract_size_bytes(entry.get("files"))
|
||||
preview_url = self._extract_preview_url(entry.get("images"))
|
||||
extracted.append(
|
||||
@@ -1152,11 +1221,11 @@ class ModelUpdateService:
|
||||
name=row["name"],
|
||||
base_model=row["base_model"],
|
||||
released_at=row["released_at"],
|
||||
size_bytes=self._normalize_int(row["size_bytes"]),
|
||||
size_bytes=_normalize_int(row["size_bytes"]),
|
||||
preview_url=row["preview_url"],
|
||||
is_in_library=bool(row["is_in_library"]),
|
||||
should_ignore=bool(row["should_ignore"]),
|
||||
sort_index=self._normalize_int(row["sort_index"]) or 0,
|
||||
sort_index=_normalize_int(row["sort_index"]) or 0,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user