Merge pull request #619 from willmiao/codex/add-check-for-updates-in-bulk-context-menu

feat(bulk): add selected update check
This commit is contained in:
pixelpaws
2025-10-29 08:56:26 +08:00
committed by GitHub
19 changed files with 484 additions and 36 deletions

View File

@@ -1045,6 +1045,21 @@ class ModelUpdateHandler:
force_refresh = self._parse_bool(request.query.get("force")) or self._parse_bool(
payload.get("force")
)
raw_model_ids = payload.get("modelIds")
if raw_model_ids is None:
raw_model_ids = payload.get("model_ids")
target_model_ids: list[int] = []
if isinstance(raw_model_ids, (list, tuple, set)):
for value in raw_model_ids:
normalized = self._normalize_model_id(value)
if normalized is not None:
target_model_ids.append(normalized)
if target_model_ids:
target_model_ids = sorted(set(target_model_ids))
provider = await self._get_civitai_provider()
if provider is None:
return web.json_response(
@@ -1057,6 +1072,7 @@ class ModelUpdateHandler:
self._service.scanner,
provider,
force_refresh=force_refresh,
target_model_ids=target_model_ids or None,
)
except RateLimitError as exc:
return web.json_response(

View File

@@ -277,15 +277,33 @@ class ModelUpdateService:
metadata_provider,
*,
force_refresh: bool = False,
target_model_ids: Optional[Sequence[int]] = None,
) -> Dict[int, ModelUpdateRecord]:
"""Refresh update information for every model present in the cache."""
local_versions = await self._collect_local_versions(scanner)
normalized_targets = (
self._normalize_sequence(target_model_ids)
if target_model_ids is not None
else []
)
target_filter = normalized_targets or None
local_versions = await self._collect_local_versions(
scanner,
target_model_ids=target_filter,
)
total_models = len(local_versions)
if total_models == 0:
logger.info(
"No %s models found while refreshing update metadata", model_type
)
if target_filter:
logger.info(
"No %s models matched requested ids %s while refreshing update metadata",
model_type,
target_filter,
)
else:
logger.info(
"No %s models found while refreshing update metadata", model_type
)
return {}
logger.info(
@@ -683,12 +701,23 @@ class ModelUpdateService:
)
return aggregated
async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
async def _collect_local_versions(
self,
scanner,
*,
target_model_ids: Optional[Sequence[int]] = None,
) -> Dict[int, List[int]]:
cache = await scanner.get_cached_data()
mapping: Dict[int, set[int]] = {}
if not cache or not getattr(cache, "raw_data", None):
return {}
target_set = None
if target_model_ids:
target_set = set(target_model_ids)
if not target_set:
return {}
for item in cache.raw_data:
civitai = item.get("civitai") if isinstance(item, dict) else None
if not isinstance(civitai, dict):
@@ -697,6 +726,8 @@ class ModelUpdateService:
version_id = self._normalize_int(civitai.get("id"))
if model_id is None or version_id is None:
continue
if target_set is not None and model_id not in target_set:
continue
mapping.setdefault(model_id, set()).add(version_id)
return {model_id: sorted(ids) for model_id, ids in mapping.items()}