diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 6b4ef91d..3d088b72 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -577,6 +577,59 @@ class CivitaiClient: logger.error(error_msg) return None + async def get_model_versions_by_hashes( + self, hashes: List[str] + ) -> Optional[List[Dict]]: + """Fetch full version details for up to 100 SHA256 hashes via the batch endpoint. + + Uses POST /api/v1/model-versions/by-hash which returns full version + details including ``usageControl`` and ``earlyAccessEndsAt`` that are + not available from the model-level API. + + Args: + hashes: List of SHA256 hashes (max 100 per batch; auto-split). + + Returns: + List of version dicts or None on failure. + """ + if not hashes: + return [] + + BATCH_SIZE = 100 + all_versions: List[Dict] = [] + + for start in range(0, len(hashes), BATCH_SIZE): + batch = hashes[start : start + BATCH_SIZE] + try: + success, result = await self._make_request( + "POST", + f"{self.base_url}/model-versions/by-hash", + use_auth=True, + json=batch, + ) + if not success: + logger.warning( + "Batch by-hash request failed for %d hashes: %s", + len(batch), + result, + ) + continue + + if isinstance(result, list): + all_versions.extend(result) + else: + logger.debug( + "Unexpected by-hash response type: %s", type(result) + ) + except RateLimitError: + raise + except Exception as exc: # pragma: no cover - defensive logging + logger.error( + "Error fetching model versions by hashes: %s", exc + ) + + return all_versions if all_versions else None + async def get_user_models(self, username: str) -> Optional[List[Dict]]: """Fetch all models for a specific Civitai user.""" if not username: diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py index ea33a6be..b80f5d48 100644 --- a/py/services/model_metadata_provider.py +++ b/py/services/model_metadata_provider.py @@ -108,6 +108,18 @@ class ModelMetadataProvider(ABC): ) -> Optional[Dict[int, Dict]]: """Fetch model versions for multiple model ids when supported.""" raise NotImplementedError + + async def get_model_versions_by_hashes( + self, hashes: List[str] + ) -> Optional[List[Dict]]: + """Fetch full version details for multiple SHA256 hashes. + + Used specifically to retrieve ``usageControl`` which is only + available from the per-version / by-hash API, not from model-level + responses. Providers that cannot resolve hashes should let the + default ``NotImplementedError`` propagate. + """ + raise NotImplementedError @abstractmethod async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: @@ -140,6 +152,11 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider): self, model_ids: Sequence[int] ) -> Optional[Dict[int, Dict]]: return await self.client.get_model_versions_bulk(model_ids) + + async def get_model_versions_by_hashes( + self, hashes: List[str] + ) -> Optional[List[Dict]]: + return await self.client.get_model_versions_by_hashes(hashes) async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: return await self.client.get_model_version(model_id, version_id) @@ -519,6 +536,32 @@ class FallbackMetadataProvider(ModelMetadataProvider): continue return None, "No provider could retrieve the data" + async def get_model_versions_by_hashes( + self, hashes: List[str] + ) -> Optional[List[Dict]]: + for provider, label in self._iter_providers(): + try: + result = await self._call_with_rate_limit( + label, + provider.get_model_versions_by_hashes, + hashes, + ) + if result is not None: + return result + except NotImplementedError: + continue + except RateLimitError as exc: + exc.provider = exc.provider or label + raise exc + except Exception as e: + logger.debug( + "Provider %s failed for get_model_versions_by_hashes: %s", + label, + e, + ) + continue + return None + async def get_user_models(self, username: str) -> Optional[List[Dict]]: for provider, label in self._iter_providers(): try: @@ -593,6 +636,15 @@ class RateLimitRetryingProvider(ModelMetadataProvider): model_ids, ) + async def get_model_versions_by_hashes( + self, hashes: List[str] + ) -> Optional[List[Dict]]: + return await self._rate_limit_helper.run( + self._label, + self._provider.get_model_versions_by_hashes, + hashes, + ) + async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: return await self._rate_limit_helper.run( self._label, @@ -669,6 +721,17 @@ class ModelMetadataProviderManager: provider = self._get_provider(provider_name) return await provider.get_model_version_info(version_id) + async def get_model_versions_by_hashes( + self, + hashes: List[str], + provider_name: str = None, + ) -> Optional[List[Dict]]: + provider = self._get_provider(provider_name) + try: + return await provider.get_model_versions_by_hashes(hashes) + except NotImplementedError: + return None + async def get_user_models(self, username: str, provider_name: str = None) -> Optional[List[Dict]]: """Fetch models owned by the specified user""" provider = self._get_provider(provider_name) diff --git a/py/services/model_update_service.py b/py/services/model_update_service.py index cfcba24b..0bfcfd23 100644 --- a/py/services/model_update_service.py +++ b/py/services/model_update_service.py @@ -989,6 +989,11 @@ class ModelUpdateService: fallback_attempted = True try: response = await metadata_provider.get_model_versions(model_id) + if response is not None: + await self._enrich_version_entries( + metadata_provider, + {model_id: response}, + ) except RateLimitError: raise except ResourceNotFoundError as exc: @@ -1083,6 +1088,136 @@ class ModelUpdateService: self._upsert_record(record) return record + async def _enrich_version_entries( + self, + metadata_provider, + responses_by_model_id: Dict[int, Mapping], + ) -> None: + """Enrich version entries with ``usageControl`` via batch hash endpoint. + + The model-level API does not include ``usageControl`` on version + entries. This method collects SHA256 hashes from every version's + primary model file, calls ``POST /api/v1/model-versions/by-hash`` + (up to 100 hashes per request), and injects ``usageControl`` + + ``earlyAccessEndsAt`` into each version entry dict in-place. + """ + if not metadata_provider or not responses_by_model_id: + return + + hashes_by_version: Dict[int, str] = {} + for response in responses_by_model_id.values(): + hashes_by_version.update( + self._collect_hashes_from_response(response) + ) + + if not hashes_by_version: + return + + version_ids_by_hash: Dict[str, List[int]] = {} + for version_id, sha256 in hashes_by_version.items(): + version_ids_by_hash.setdefault(sha256, []).append(version_id) + + all_hashes = list(version_ids_by_hash.keys()) + BATCH_SIZE = 100 + + enrichment: Dict[int, Dict] = {} + try: + for start in range(0, len(all_hashes), BATCH_SIZE): + batch = all_hashes[start : start + BATCH_SIZE] + try: + enriched = await metadata_provider.get_model_versions_by_hashes( + batch + ) + except NotImplementedError: + return + except RateLimitError: + raise + except Exception: + continue + + if not enriched: + continue + + for entry in enriched: + if not isinstance(entry, dict): + continue + version_id = entry.get("id") + if version_id is None: + continue + enrichment[version_id] = { + "usageControl": _normalize_string( + entry.get("usageControl") + ), + "earlyAccessEndsAt": _normalize_string( + entry.get("earlyAccessEndsAt") + ), + } + except RateLimitError: + raise + + if not enrichment: + return + + for response in responses_by_model_id.values(): + versions = response.get("modelVersions") + if not isinstance(versions, list): + continue + for version in versions: + if not isinstance(version, dict): + continue + version_id = version.get("id") + if version_id not in enrichment: + continue + extra = enrichment[version_id] + if extra.get("usageControl") and not version.get("usageControl"): + version["usageControl"] = extra["usageControl"] + if extra.get("earlyAccessEndsAt") and not version.get( + "earlyAccessEndsAt" + ): + version["earlyAccessEndsAt"] = extra["earlyAccessEndsAt"] + + @staticmethod + def _collect_hashes_from_response(response: Mapping) -> Dict[int, str]: + """Extract ``{version_id: sha256}`` from a model-level API response. + + Returns an empty dict if the response structure is unexpected. + """ + result: Dict[int, str] = {} + versions = response.get("modelVersions") + if not isinstance(versions, list): + return result + for entry in versions: + if not isinstance(entry, dict): + continue + version_id = _normalize_int(entry.get("id")) + if version_id is None: + continue + sha256 = ModelUpdateService._extract_sha256_from_version_entry(entry) + if sha256: + result[version_id] = sha256 + return result + + @staticmethod + def _extract_sha256_from_version_entry(entry: Mapping) -> Optional[str]: + """Return the SHA256 hash from the primary model file of a version entry.""" + files = entry.get("files") + if not isinstance(files, list): + return None + for file_info in files: + if not isinstance(file_info, dict): + continue + if file_info.get("type") != "Model": + continue + primary = file_info.get("primary") + if primary is not True and str(primary).strip().lower() != "true": + continue + hashes = file_info.get("hashes") + if isinstance(hashes, dict): + sha256 = hashes.get("SHA256") + if sha256: + return sha256 + return None + async def _fetch_model_versions_bulk( self, metadata_provider, @@ -1134,6 +1269,7 @@ class ModelUpdateService: len(aggregated), provider_name, ) + await self._enrich_version_entries(metadata_provider, aggregated) return aggregated async def _collect_local_versions( @@ -1261,6 +1397,7 @@ class ModelUpdateService: sort_index=sort_map.get(version_id, index), early_access_ends_at=remote_version.early_access_ends_at, is_early_access=remote_version.is_early_access, + usage_control=remote_version.usage_control, ) ) diff --git a/static/css/components/lora-modal/versions.css b/static/css/components/lora-modal/versions.css index 25dd81fe..9239218d 100644 --- a/static/css/components/lora-modal/versions.css +++ b/static/css/components/lora-modal/versions.css @@ -387,6 +387,10 @@ cursor: not-allowed; } +.version-action-disabled-wrapper { + display: inline-flex; +} + .versions-loading-state, .versions-empty, .versions-error { diff --git a/static/js/components/shared/ModelVersionsTab.js b/static/js/components/shared/ModelVersionsTab.js index 0e88fa2b..c00a2e98 100644 --- a/static/js/components/shared/ModelVersionsTab.js +++ b/static/js/components/shared/ModelVersionsTab.js @@ -241,7 +241,7 @@ function buildActionButton(label, variant, action, options = {}) { if (action) { attributes.push(`data-version-action="${escapeHtml(action)}"`); } - if (options.title) { + if (!options.disabled && options.title) { attributes.push(`title="${escapeHtml(options.title)}"`); attributes.push(`aria-label="${escapeHtml(options.title)}"`); } @@ -251,7 +251,11 @@ function buildActionButton(label, variant, action, options = {}) { if (options.extraAttributes) { attributes.push(options.extraAttributes); } - return ``; + const buttonHtml = ``; + if (options.disabled && options.title) { + return `${buttonHtml}`; + } + return buttonHtml; } const DISPLAY_FILTER_MODES = Object.freeze({