feat(download): sync downloaded versions with update tracking

Add automatic synchronization of downloaded model versions with the update tracking system. After a successful download, the system now resolves model and version IDs from the download response and updates the update service with the newly downloaded version along with any existing local versions.

This ensures that:
- Update tracking accurately reflects which versions are available locally
- The system properly tracks both newly downloaded and existing versions
- Failed sync operations are gracefully handled with appropriate logging
- Support is included for LoRA, checkpoint, and embedding model types
This commit is contained in:
Will Miao
2025-10-26 10:42:09 +08:00
parent 600afdcd92
commit 188fe407b6

View File

@@ -375,6 +375,19 @@ class DownloadManager:
download_id=download_id,
)
if result.get('success', False):
resolved_model_id = (
model_id
or version_info.get('modelId')
or (version_info.get('model') or {}).get('id')
)
await self._sync_downloaded_version(
model_type,
resolved_model_id,
version_info,
model_version_id,
)
# If early_access_msg exists and download failed, replace error message
if 'early_access_msg' in locals() and not result.get('success', False):
result['error'] = early_access_msg
@@ -389,6 +402,96 @@ class DownloadManager:
return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."}
return {'success': False, 'error': str(e)}
async def _sync_downloaded_version(
self,
model_type: str,
model_id_value,
version_info: Dict,
fallback_version_id=None,
) -> None:
"""Ensure update tracking reflects a newly downloaded version."""
try:
update_service = await ServiceRegistry.get_model_update_service()
except Exception as exc:
logger.debug("Skipping update sync; failed to acquire update service: %s", exc)
return
if update_service is None:
return
resolved_model_id = model_id_value
if resolved_model_id is None:
resolved_model_id = version_info.get('modelId')
if resolved_model_id is None:
model_info = version_info.get('model')
if isinstance(model_info, dict):
resolved_model_id = model_info.get('id')
try:
resolved_model_id = int(resolved_model_id)
except (TypeError, ValueError):
logger.debug("Skipping update sync; invalid model id: %s", resolved_model_id)
return
version_id = version_info.get('id')
if version_id is None:
version_id = fallback_version_id
try:
version_id = int(version_id)
except (TypeError, ValueError):
logger.debug(
"Skipping update sync; invalid version id for model %s: %s",
resolved_model_id,
version_id,
)
return
version_ids = set()
scanner = None
try:
if model_type == 'lora':
scanner = await self._get_lora_scanner()
elif model_type == 'checkpoint':
scanner = await self._get_checkpoint_scanner()
elif model_type == 'embedding':
scanner = await ServiceRegistry.get_embedding_scanner()
except Exception as exc:
logger.debug("Failed to acquire scanner for %s models: %s", model_type, exc)
if scanner is not None:
try:
local_versions = await scanner.get_model_versions_by_id(resolved_model_id)
except Exception as exc:
logger.debug(
"Failed to collect local versions for %s model %s: %s",
model_type,
resolved_model_id,
exc,
)
else:
for entry in local_versions or []:
vid = entry.get('versionId')
try:
version_ids.add(int(vid))
except (TypeError, ValueError):
continue
version_ids.add(version_id)
try:
await update_service.update_in_library_versions(
model_type,
resolved_model_id,
sorted(version_ids),
)
except Exception as exc:
logger.debug(
"Failed to update in-library versions for %s model %s: %s",
model_type,
resolved_model_id,
exc,
)
def _calculate_relative_path(self, version_info: Dict, model_type: str = 'lora') -> str:
"""Calculate relative path using template from settings