mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
feat: implement same_base update strategy for model annotations
Add support for configurable update flag strategy with new "same_base" mode that considers base model versions when determining update availability. The strategy is controlled by the "update_flag_strategy" setting. When strategy is set to "same_base": - Uses get_records_bulk instead of has_updates_bulk - Compares model versions against highest local versions per base model - Provides more granular update detection based on base model relationships Fallback to existing bulk or individual update checks when: - Strategy is not "same_base" - Bulk operations fail - Records are unavailable This enables more precise update flagging for models sharing common bases.
This commit is contained in:
@@ -267,20 +267,49 @@ class BaseModelService(ABC):
|
||||
if not ordered_ids:
|
||||
return annotated
|
||||
|
||||
strategy_value = self.settings.get("update_flag_strategy")
|
||||
if isinstance(strategy_value, str) and strategy_value.strip():
|
||||
strategy = strategy_value.strip().lower()
|
||||
else:
|
||||
strategy = "any"
|
||||
same_base_mode = strategy == "same_base"
|
||||
|
||||
records = None
|
||||
resolved: Optional[Dict[int, bool]] = None
|
||||
bulk_method = getattr(self.update_service, "has_updates_bulk", None)
|
||||
if callable(bulk_method):
|
||||
try:
|
||||
resolved = await bulk_method(self.model_type, ordered_ids)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to resolve update status in bulk for %s models (%s): %s",
|
||||
self.model_type,
|
||||
ordered_ids,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
resolved = None
|
||||
if same_base_mode:
|
||||
record_method = getattr(self.update_service, "get_records_bulk", None)
|
||||
if callable(record_method):
|
||||
try:
|
||||
records = await record_method(self.model_type, ordered_ids)
|
||||
resolved = {
|
||||
model_id: record.has_update()
|
||||
for model_id, record in records.items()
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to resolve update records in bulk for %s models (%s): %s",
|
||||
self.model_type,
|
||||
ordered_ids,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
records = None
|
||||
resolved = None
|
||||
|
||||
if resolved is None:
|
||||
bulk_method = getattr(self.update_service, "has_updates_bulk", None)
|
||||
if callable(bulk_method):
|
||||
try:
|
||||
resolved = await bulk_method(self.model_type, ordered_ids)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"Failed to resolve update status in bulk for %s models (%s): %s",
|
||||
self.model_type,
|
||||
ordered_ids,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
resolved = None
|
||||
|
||||
if resolved is None:
|
||||
tasks = [
|
||||
@@ -301,8 +330,24 @@ class BaseModelService(ABC):
|
||||
resolved[model_id] = bool(result)
|
||||
|
||||
for model_id, items_for_id in id_to_items.items():
|
||||
flag = bool(resolved.get(model_id, False))
|
||||
default_flag = bool(resolved.get(model_id, False)) if resolved else False
|
||||
record = records.get(model_id) if records else None
|
||||
base_highest_versions = (
|
||||
self._build_highest_local_versions_by_base(record) if same_base_mode and record else {}
|
||||
)
|
||||
for item in items_for_id:
|
||||
if same_base_mode and record is not None:
|
||||
base_model = self._extract_base_model(item)
|
||||
normalized_base = self._normalize_base_model_name(base_model)
|
||||
threshold_version = base_highest_versions.get(normalized_base) if normalized_base else None
|
||||
if threshold_version is None:
|
||||
threshold_version = self._extract_version_id(item)
|
||||
flag = record.has_update_for_base(
|
||||
threshold_version,
|
||||
base_model,
|
||||
)
|
||||
else:
|
||||
flag = default_flag
|
||||
item['update_available'] = flag
|
||||
|
||||
return annotated
|
||||
@@ -319,7 +364,71 @@ class BaseModelService(ABC):
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _extract_version_id(item: Dict) -> Optional[int]:
|
||||
civitai = item.get('civitai') if isinstance(item, dict) else None
|
||||
if not isinstance(civitai, dict):
|
||||
return None
|
||||
value = civitai.get('id')
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_base_model(item: Dict) -> Optional[str]:
|
||||
value = item.get('base_model')
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
candidate = value.strip()
|
||||
else:
|
||||
try:
|
||||
candidate = str(value).strip()
|
||||
except Exception:
|
||||
return None
|
||||
return candidate if candidate else None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_base_model_name(value: Optional[str]) -> Optional[str]:
|
||||
"""Return a lowercased, trimmed base model name for comparison."""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
candidate = value.strip()
|
||||
else:
|
||||
try:
|
||||
candidate = str(value).strip()
|
||||
except Exception:
|
||||
return None
|
||||
return candidate.lower() if candidate else None
|
||||
|
||||
def _build_highest_local_versions_by_base(self, record) -> Dict[str, int]:
|
||||
"""Return the highest local version id known for each normalized base model."""
|
||||
|
||||
if record is None:
|
||||
return {}
|
||||
|
||||
highest_by_base: Dict[str, int] = {}
|
||||
for version in getattr(record, "versions", []):
|
||||
if not getattr(version, "is_in_library", False):
|
||||
continue
|
||||
normalized_base = self._normalize_base_model_name(getattr(version, "base_model", None))
|
||||
if normalized_base is None:
|
||||
continue
|
||||
version_id = getattr(version, "version_id", None)
|
||||
if version_id is None:
|
||||
continue
|
||||
current_max = highest_by_base.get(normalized_base)
|
||||
if current_max is None or version_id > current_max:
|
||||
highest_by_base[normalized_base] = version_id
|
||||
|
||||
return highest_by_base
|
||||
|
||||
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||
"""Apply pagination to filtered data"""
|
||||
total_items = len(data)
|
||||
|
||||
Reference in New Issue
Block a user