feat: implement same_base update strategy for model annotations

Add support for configurable update flag strategy with new "same_base" mode that considers base model versions when determining update availability. The strategy is controlled by the "update_flag_strategy" setting.

When strategy is set to "same_base":
- Uses get_records_bulk instead of has_updates_bulk
- Compares model versions against highest local versions per base model
- Provides more granular update detection based on base model relationships

Fallback to existing bulk or individual update checks when:
- Strategy is not "same_base"
- Bulk operations fail
- Records are unavailable

This enables more precise update flagging for models sharing common bases.
This commit is contained in:
Will Miao
2025-11-17 19:26:41 +08:00
parent 8158441a92
commit 0e73db0669
5 changed files with 458 additions and 48 deletions

View File

@@ -267,20 +267,49 @@ class BaseModelService(ABC):
if not ordered_ids:
return annotated
strategy_value = self.settings.get("update_flag_strategy")
if isinstance(strategy_value, str) and strategy_value.strip():
strategy = strategy_value.strip().lower()
else:
strategy = "any"
same_base_mode = strategy == "same_base"
records = None
resolved: Optional[Dict[int, bool]] = None
bulk_method = getattr(self.update_service, "has_updates_bulk", None)
if callable(bulk_method):
try:
resolved = await bulk_method(self.model_type, ordered_ids)
except Exception as exc:
logger.error(
"Failed to resolve update status in bulk for %s models (%s): %s",
self.model_type,
ordered_ids,
exc,
exc_info=True,
)
resolved = None
if same_base_mode:
record_method = getattr(self.update_service, "get_records_bulk", None)
if callable(record_method):
try:
records = await record_method(self.model_type, ordered_ids)
resolved = {
model_id: record.has_update()
for model_id, record in records.items()
}
except Exception as exc:
logger.error(
"Failed to resolve update records in bulk for %s models (%s): %s",
self.model_type,
ordered_ids,
exc,
exc_info=True,
)
records = None
resolved = None
if resolved is None:
bulk_method = getattr(self.update_service, "has_updates_bulk", None)
if callable(bulk_method):
try:
resolved = await bulk_method(self.model_type, ordered_ids)
except Exception as exc:
logger.error(
"Failed to resolve update status in bulk for %s models (%s): %s",
self.model_type,
ordered_ids,
exc,
exc_info=True,
)
resolved = None
if resolved is None:
tasks = [
@@ -301,8 +330,24 @@ class BaseModelService(ABC):
resolved[model_id] = bool(result)
for model_id, items_for_id in id_to_items.items():
flag = bool(resolved.get(model_id, False))
default_flag = bool(resolved.get(model_id, False)) if resolved else False
record = records.get(model_id) if records else None
base_highest_versions = (
self._build_highest_local_versions_by_base(record) if same_base_mode and record else {}
)
for item in items_for_id:
if same_base_mode and record is not None:
base_model = self._extract_base_model(item)
normalized_base = self._normalize_base_model_name(base_model)
threshold_version = base_highest_versions.get(normalized_base) if normalized_base else None
if threshold_version is None:
threshold_version = self._extract_version_id(item)
flag = record.has_update_for_base(
threshold_version,
base_model,
)
else:
flag = default_flag
item['update_available'] = flag
return annotated
@@ -319,7 +364,71 @@ class BaseModelService(ABC):
return int(value)
except (TypeError, ValueError):
return None
@staticmethod
def _extract_version_id(item: Dict) -> Optional[int]:
civitai = item.get('civitai') if isinstance(item, dict) else None
if not isinstance(civitai, dict):
return None
value = civitai.get('id')
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
@staticmethod
def _extract_base_model(item: Dict) -> Optional[str]:
value = item.get('base_model')
if value is None:
return None
if isinstance(value, str):
candidate = value.strip()
else:
try:
candidate = str(value).strip()
except Exception:
return None
return candidate if candidate else None
@staticmethod
def _normalize_base_model_name(value: Optional[str]) -> Optional[str]:
"""Return a lowercased, trimmed base model name for comparison."""
if value is None:
return None
if isinstance(value, str):
candidate = value.strip()
else:
try:
candidate = str(value).strip()
except Exception:
return None
return candidate.lower() if candidate else None
def _build_highest_local_versions_by_base(self, record) -> Dict[str, int]:
"""Return the highest local version id known for each normalized base model."""
if record is None:
return {}
highest_by_base: Dict[str, int] = {}
for version in getattr(record, "versions", []):
if not getattr(version, "is_in_library", False):
continue
normalized_base = self._normalize_base_model_name(getattr(version, "base_model", None))
if normalized_base is None:
continue
version_id = getattr(version, "version_id", None)
if version_id is None:
continue
current_max = highest_by_base.get(normalized_base)
if current_max is None or version_id > current_max:
highest_by_base[normalized_base] = version_id
return highest_by_base
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
"""Apply pagination to filtered data"""
total_items = len(data)

View File

@@ -17,6 +17,41 @@ from ..utils.preview_selection import select_preview_media
logger = logging.getLogger(__name__)
def _normalize_int(value) -> Optional[int]:
"""Safely convert a value to an integer."""
try:
if value is None:
return None
return int(value)
except (TypeError, ValueError):
return None
def _normalize_string(value) -> Optional[str]:
"""Return a stripped string or None if the value is empty."""
if value is None:
return None
if isinstance(value, str):
stripped = value.strip()
return stripped or None
try:
normalized = str(value).strip()
return normalized or None
except Exception:
return None
def _normalize_base_model(value) -> Optional[str]:
"""Normalize base-model names for case-insensitive comparison."""
normalized = _normalize_string(value)
if normalized is None:
return None
return normalized.lower()
@dataclass
class ModelVersionRecord:
"""Persisted metadata for a single model version."""
@@ -85,6 +120,47 @@ class ModelUpdateRecord:
return True
return False
def has_update_for_base(
self,
local_version_id: Optional[int],
local_base_model: Optional[str],
) -> bool:
"""Return True when a newer remote version with the same base model exists."""
if self.should_ignore_model:
return False
normalized_base = _normalize_base_model(local_base_model)
if normalized_base is None:
return False
threshold = _normalize_int(local_version_id)
if threshold is None:
highest_local = None
for version in self.versions:
if not version.is_in_library:
continue
version_base = _normalize_base_model(version.base_model)
if version_base != normalized_base:
continue
if highest_local is None or version.version_id > highest_local:
highest_local = version.version_id
threshold = highest_local
if threshold is None:
return False
for version in self.versions:
if version.is_in_library or version.should_ignore:
continue
version_base = _normalize_base_model(version.base_model)
if version_base != normalized_base:
continue
if version.version_id > threshold:
return True
return False
class ModelUpdateService:
"""Persist and query remote model version metadata."""
@@ -628,6 +704,20 @@ class ModelUpdateService:
for model_id in normalized_ids
}
async def get_records_bulk(
self,
model_type: str,
model_ids: Sequence[int],
) -> Dict[int, ModelUpdateRecord]:
"""Return cached update records for the requested models."""
normalized_ids = self._normalize_sequence(model_ids)
if not normalized_ids:
return {}
async with self._lock:
return self._get_records_bulk(model_type, normalized_ids)
async def _refresh_single_model(
self,
model_type: str,
@@ -799,7 +889,7 @@ class ModelUpdateService:
)
continue
for key, value in response.items():
normalized_key = self._normalize_int(key)
normalized_key = _normalize_int(key)
if normalized_key is None:
continue
if isinstance(value, Mapping):
@@ -832,8 +922,8 @@ class ModelUpdateService:
civitai = item.get("civitai") if isinstance(item, dict) else None
if not isinstance(civitai, dict):
continue
model_id = self._normalize_int(civitai.get("modelId"))
version_id = self._normalize_int(civitai.get("id"))
model_id = _normalize_int(civitai.get("modelId"))
version_id = _normalize_int(civitai.get("id"))
if model_id is None or version_id is None:
continue
if target_set is not None and model_id not in target_set:
@@ -973,35 +1063,14 @@ class ModelUpdateService:
return True
return (now - record.last_checked_at) >= self._ttl_seconds
@staticmethod
def _normalize_int(value) -> Optional[int]:
try:
if value is None:
return None
return int(value)
except (TypeError, ValueError):
return None
def _normalize_sequence(self, values: Sequence[int]) -> List[int]:
normalized = [
item
for item in (self._normalize_int(value) for value in values)
for item in (_normalize_int(value) for value in values)
if item is not None
]
return sorted(dict.fromkeys(normalized))
@staticmethod
def _normalize_string(value) -> Optional[str]:
if value is None:
return None
if isinstance(value, str):
stripped = value.strip()
return stripped or None
try:
return str(value)
except Exception: # pragma: no cover - defensive conversion
return None
def _extract_versions(self, response) -> Optional[List[ModelVersionRecord]]:
if not isinstance(response, Mapping):
return None
@@ -1014,12 +1083,12 @@ class ModelUpdateService:
for index, entry in enumerate(versions):
if not isinstance(entry, Mapping):
continue
version_id = self._normalize_int(entry.get("id"))
version_id = _normalize_int(entry.get("id"))
if version_id is None:
continue
name = self._normalize_string(entry.get("name"))
base_model = self._normalize_string(entry.get("baseModel"))
released_at = self._normalize_string(entry.get("publishedAt") or entry.get("createdAt"))
name = _normalize_string(entry.get("name"))
base_model = _normalize_string(entry.get("baseModel"))
released_at = _normalize_string(entry.get("publishedAt") or entry.get("createdAt"))
size_bytes = self._extract_size_bytes(entry.get("files"))
preview_url = self._extract_preview_url(entry.get("images"))
extracted.append(
@@ -1152,11 +1221,11 @@ class ModelUpdateService:
name=row["name"],
base_model=row["base_model"],
released_at=row["released_at"],
size_bytes=self._normalize_int(row["size_bytes"]),
size_bytes=_normalize_int(row["size_bytes"]),
preview_url=row["preview_url"],
is_in_library=bool(row["is_in_library"]),
should_ignore=bool(row["should_ignore"]),
sort_index=self._normalize_int(row["sort_index"]) or 0,
sort_index=_normalize_int(row["sort_index"]) or 0,
)
)

View File

@@ -61,6 +61,7 @@ DEFAULT_SETTINGS: Dict[str, Any] = {
"priority_tags": DEFAULT_PRIORITY_TAG_CONFIG.copy(),
"model_name_display": "model_name",
"model_card_footer_action": "example_images",
"update_flag_strategy": "any",
}