From 188fe407b6c97bf97c7804c39eda06a0554eab30 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sun, 26 Oct 2025 10:42:09 +0800 Subject: [PATCH] feat(download): sync downloaded versions with update tracking Add automatic synchronization of downloaded model versions with the update tracking system. After a successful download, the system now resolves model and version IDs from the download response and updates the update service with the newly downloaded version along with any existing local versions. This ensures that: - Update tracking accurately reflects which versions are available locally - The system properly tracks both newly downloaded and existing versions - Failed sync operations are gracefully handled with appropriate logging - Support is included for LoRA, checkpoint, and embedding model types --- py/services/download_manager.py | 103 ++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 505d997f..17bf0e4d 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -375,6 +375,19 @@ class DownloadManager: download_id=download_id, ) + if result.get('success', False): + resolved_model_id = ( + model_id + or version_info.get('modelId') + or (version_info.get('model') or {}).get('id') + ) + await self._sync_downloaded_version( + model_type, + resolved_model_id, + version_info, + model_version_id, + ) + # If early_access_msg exists and download failed, replace error message if 'early_access_msg' in locals() and not result.get('success', False): result['error'] = early_access_msg @@ -389,6 +402,96 @@ class DownloadManager: return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."} return {'success': False, 'error': str(e)} + async def _sync_downloaded_version( + self, + model_type: str, + model_id_value, + version_info: Dict, + fallback_version_id=None, + ) -> None: + """Ensure update tracking reflects a newly downloaded version.""" + + try: + update_service = await ServiceRegistry.get_model_update_service() + except Exception as exc: + logger.debug("Skipping update sync; failed to acquire update service: %s", exc) + return + + if update_service is None: + return + + resolved_model_id = model_id_value + if resolved_model_id is None: + resolved_model_id = version_info.get('modelId') + if resolved_model_id is None: + model_info = version_info.get('model') + if isinstance(model_info, dict): + resolved_model_id = model_info.get('id') + try: + resolved_model_id = int(resolved_model_id) + except (TypeError, ValueError): + logger.debug("Skipping update sync; invalid model id: %s", resolved_model_id) + return + + version_id = version_info.get('id') + if version_id is None: + version_id = fallback_version_id + try: + version_id = int(version_id) + except (TypeError, ValueError): + logger.debug( + "Skipping update sync; invalid version id for model %s: %s", + resolved_model_id, + version_id, + ) + return + + version_ids = set() + scanner = None + try: + if model_type == 'lora': + scanner = await self._get_lora_scanner() + elif model_type == 'checkpoint': + scanner = await self._get_checkpoint_scanner() + elif model_type == 'embedding': + scanner = await ServiceRegistry.get_embedding_scanner() + except Exception as exc: + logger.debug("Failed to acquire scanner for %s models: %s", model_type, exc) + + if scanner is not None: + try: + local_versions = await scanner.get_model_versions_by_id(resolved_model_id) + except Exception as exc: + logger.debug( + "Failed to collect local versions for %s model %s: %s", + model_type, + resolved_model_id, + exc, + ) + else: + for entry in local_versions or []: + vid = entry.get('versionId') + try: + version_ids.add(int(vid)) + except (TypeError, ValueError): + continue + + version_ids.add(version_id) + + try: + await update_service.update_in_library_versions( + model_type, + resolved_model_id, + sorted(version_ids), + ) + except Exception as exc: + logger.debug( + "Failed to update in-library versions for %s model %s: %s", + model_type, + resolved_model_id, + exc, + ) + def _calculate_relative_path(self, version_info: Dict, model_type: str = 'lora') -> str: """Calculate relative path using template from settings