diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 505d997f..17bf0e4d 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -375,6 +375,19 @@ class DownloadManager: download_id=download_id, ) + if result.get('success', False): + resolved_model_id = ( + model_id + or version_info.get('modelId') + or (version_info.get('model') or {}).get('id') + ) + await self._sync_downloaded_version( + model_type, + resolved_model_id, + version_info, + model_version_id, + ) + # If early_access_msg exists and download failed, replace error message if 'early_access_msg' in locals() and not result.get('success', False): result['error'] = early_access_msg @@ -389,6 +402,96 @@ class DownloadManager: return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."} return {'success': False, 'error': str(e)} + async def _sync_downloaded_version( + self, + model_type: str, + model_id_value, + version_info: Dict, + fallback_version_id=None, + ) -> None: + """Ensure update tracking reflects a newly downloaded version.""" + + try: + update_service = await ServiceRegistry.get_model_update_service() + except Exception as exc: + logger.debug("Skipping update sync; failed to acquire update service: %s", exc) + return + + if update_service is None: + return + + resolved_model_id = model_id_value + if resolved_model_id is None: + resolved_model_id = version_info.get('modelId') + if resolved_model_id is None: + model_info = version_info.get('model') + if isinstance(model_info, dict): + resolved_model_id = model_info.get('id') + try: + resolved_model_id = int(resolved_model_id) + except (TypeError, ValueError): + logger.debug("Skipping update sync; invalid model id: %s", resolved_model_id) + return + + version_id = version_info.get('id') + if version_id is None: + version_id = fallback_version_id + try: + version_id = int(version_id) + except (TypeError, ValueError): + logger.debug( + "Skipping update sync; invalid version id for model %s: %s", + resolved_model_id, + version_id, + ) + return + + version_ids = set() + scanner = None + try: + if model_type == 'lora': + scanner = await self._get_lora_scanner() + elif model_type == 'checkpoint': + scanner = await self._get_checkpoint_scanner() + elif model_type == 'embedding': + scanner = await ServiceRegistry.get_embedding_scanner() + except Exception as exc: + logger.debug("Failed to acquire scanner for %s models: %s", model_type, exc) + + if scanner is not None: + try: + local_versions = await scanner.get_model_versions_by_id(resolved_model_id) + except Exception as exc: + logger.debug( + "Failed to collect local versions for %s model %s: %s", + model_type, + resolved_model_id, + exc, + ) + else: + for entry in local_versions or []: + vid = entry.get('versionId') + try: + version_ids.add(int(vid)) + except (TypeError, ValueError): + continue + + version_ids.add(version_id) + + try: + await update_service.update_in_library_versions( + model_type, + resolved_model_id, + sorted(version_ids), + ) + except Exception as exc: + logger.debug( + "Failed to update in-library versions for %s model %s: %s", + model_type, + resolved_model_id, + exc, + ) + def _calculate_relative_path(self, version_info: Dict, model_type: str = 'lora') -> str: """Calculate relative path using template from settings