Skip to content
Open
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
/*************************************************************************
* This file was modified for portability to AMDGPU
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
Expand Down Expand Up @@ -68,6 +70,28 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
int per_tensor, int max_chunks_per_tensor,
cudaStream_t stream);

/*! \brief Computes cumulative L2 norm for a list of tensors from precomputed chunk metadata.
*
* The tensor list is not passed as tensor handles; it is described by the caller-supplied
* arrays \p addresses, \p sizes, \p block_to_tensor and \p chunk_offsets.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] chunk_size       Chunk size the metadata was built with (presumably a number of
*                             tensor elements per chunk -- TODO confirm).
* \param     noop_flag        No-op flag tensor. NOTE(review): presumably a non-zero value makes
*                             the kernel exit early -- confirm against the implementation.
* \param[in] input_dtype      Data type of the input tensors.
* \param[in] addresses        Read-only array of tensor addresses stored as 64-bit integers.
*                             NOTE(review): host vs. device residency is not visible here.
* \param[in] sizes            Read-only array of sizes (presumably one element count per
*                             tensor -- TODO confirm).
* \param[in] block_to_tensor  Read-only array, presumably mapping each chunk/block index to the
*                             index of the tensor it belongs to -- TODO confirm.
* \param[in] chunk_offsets    Read-only array, presumably giving each chunk's position within
*                             its tensor -- TODO confirm.
* \param[in] total_chunks     Number of chunks described by the metadata arrays.
* \param     output           Output tensor (presumably per-chunk partial results -- TODO
*                             confirm).
* \param     ret              Result tensor (presumably the final L2 norm -- TODO confirm).
* \param[in] stream           CUDA stream used for the operation.
*/
void nvte_multi_tensor_l2norm_cuda_custom(int chunk_size, NVTETensor noop_flag,
const NVTEDType input_dtype, const int64_t *addresses,
const int *sizes, const int *block_to_tensor,
const int *chunk_offsets, int total_chunks,
NVTETensor output,
NVTETensor ret, cudaStream_t stream);

/*! \brief Computes cumulative L2 norm for a list of tensors after unscaling from precomputed
* chunk metadata.
*
* The tensor list is not passed as tensor handles; it is described by the caller-supplied
* arrays \p addresses, \p sizes, \p block_to_tensor and \p chunk_offsets.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] chunk_size       Chunk size the metadata was built with (presumably a number of
*                             tensor elements per chunk -- TODO confirm).
* \param     noop_flag        No-op flag tensor. NOTE(review): presumably a non-zero value makes
*                             the kernel exit early -- confirm against the implementation.
* \param[in] input_dtype      Data type of the input tensors.
* \param[in] addresses        Read-only array of tensor addresses stored as 64-bit integers.
*                             NOTE(review): host vs. device residency is not visible here.
* \param[in] sizes            Read-only array of sizes (presumably one element count per
*                             tensor -- TODO confirm).
* \param[in] block_to_tensor  Read-only array, presumably mapping each chunk/block index to the
*                             index of the tensor it belongs to -- TODO confirm.
* \param[in] chunk_offsets    Read-only array, presumably giving each chunk's position within
*                             its tensor -- TODO confirm.
* \param[in] total_chunks     Number of chunks described by the metadata arrays.
* \param     output           Output tensor (presumably per-chunk partial results -- TODO
*                             confirm).
* \param     ret              Result tensor (presumably the final L2 norm -- TODO confirm).
* \param     inv_scale        Tensor holding the scale used for unscaling (presumably the
*                             inverse scale multiplied into the inputs before the norm is
*                             taken -- TODO confirm).
* \param[in] stream           CUDA stream used for the operation.
*/
void nvte_multi_tensor_unscale_l2norm_cuda_custom(
int chunk_size, NVTETensor noop_flag, const NVTEDType input_dtype,
const int64_t *addresses, const int *sizes, const int *block_to_tensor,
const int *chunk_offsets, int total_chunks,
NVTETensor output, NVTETensor ret, NVTETensor inv_scale, cudaStream_t stream);

/*! \brief Compute and apply gradient update to parameters for Adam optimizer.
*
* \warning This API is **experimental** and subject to change.
Expand Down Expand Up @@ -120,6 +144,35 @@ void nvte_multi_tensor_adam_param_remainder_cuda(
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, cudaStream_t stream);

/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* where the master parameters only store the remainder bits.
* Uses precomputed chunk metadata instead of TensorListMetadata.
*
* \warning This API is **experimental** and subject to change.
*
* NOTE(review): the metadata pointers here are non-const and \p sizes is `int64_t *`, whereas
* nvte_multi_tensor_l2norm_cuda_custom declares them `const` with `const int *sizes`.
* Confirm the difference is intentional before relying on a shared metadata builder.
*
* \param[in] chunk_size       Chunk size the metadata was built with (presumably a number of
*                             tensor elements per chunk -- TODO confirm).
* \param     noop_flag        No-op flag tensor. NOTE(review): presumably a non-zero value makes
*                             the kernel exit early -- confirm against the implementation.
* \param[in] grad_dtype       Data type of the gradient tensors.
* \param[in] moment_dtype     Data type of the optimizer moment tensors.
* \param     addresses        Array of tensor addresses stored as 64-bit integers.
*                             NOTE(review): layout across the tensor lists and host vs. device
*                             residency are not visible here.
* \param     sizes            Array of 64-bit sizes (presumably one element count per
*                             tensor -- TODO confirm).
* \param     block_to_tensor  Array presumably mapping each chunk/block index to the index of
*                             the tensor it belongs to -- TODO confirm.
* \param     chunk_offsets    Array presumably giving each chunk's position within its
*                             tensor -- TODO confirm.
* \param[in] total_chunks     Number of chunks described by the metadata arrays.
* \param[in] lr               Learning rate.
* \param[in] beta1            Adam coefficient for the first-moment estimate.
* \param[in] beta2            Adam coefficient for the second-moment estimate.
* \param[in] epsilon          Adam epsilon term (numerical-stability constant).
* \param[in] step             Optimizer step count (presumably used for bias correction --
*                             TODO confirm).
* \param[in] mode             Integer selecting the update variant; the meaning of each value
*                             is not visible here -- TODO document.
* \param[in] bias_correction  Integer flag; presumably non-zero enables bias correction --
*                             TODO confirm.
* \param[in] weight_decay     Weight decay coefficient.
* \param[in] stream           CUDA stream used for the operation.
*/
void nvte_multi_tensor_adam_param_remainder_cuda_custom(
int chunk_size, NVTETensor noop_flag, NVTEDType grad_dtype, NVTEDType moment_dtype,
int64_t *addresses, int64_t *sizes, int *block_to_tensor, int *chunk_offsets,
int total_chunks,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, cudaStream_t stream);

/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* (4-list: g, p, m, v).
* Uses precomputed chunk metadata instead of TensorListMetadata.
*
* \warning This API is **experimental** and subject to change.
*
* NOTE(review): the brief describes four tensor lists (g, p, m, v) but the signature also takes
* \p has_master; how a master-parameter list fits into \p addresses is not visible here --
* confirm and document.
*
* \param[in] chunk_size       Chunk size the metadata was built with (presumably a number of
*                             tensor elements per chunk -- TODO confirm).
* \param     noop_flag        No-op flag tensor. NOTE(review): presumably a non-zero value makes
*                             the kernel exit early -- confirm against the implementation.
* \param[in] grad_dtype       Data type of the gradient tensors.
* \param[in] param_dtype      Data type of the parameter tensors.
* \param[in] moment_dtype     Data type of the optimizer moment tensors.
* \param     addresses        Array of tensor addresses stored as 64-bit integers.
*                             NOTE(review): layout across the tensor lists and host vs. device
*                             residency are not visible here.
* \param     sizes            Array of 64-bit sizes (presumably one element count per
*                             tensor -- TODO confirm).
* \param     block_to_tensor  Array presumably mapping each chunk/block index to the index of
*                             the tensor it belongs to -- TODO confirm.
* \param     chunk_offsets    Array presumably giving each chunk's position within its
*                             tensor -- TODO confirm.
* \param[in] total_chunks     Number of chunks described by the metadata arrays.
* \param[in] has_master       Integer flag; presumably non-zero means master parameters are
*                             present -- TODO confirm.
* \param[in] lr               Learning rate.
* \param[in] beta1            Adam coefficient for the first-moment estimate.
* \param[in] beta2            Adam coefficient for the second-moment estimate.
* \param[in] epsilon          Adam epsilon term (numerical-stability constant).
* \param[in] step             Optimizer step count (presumably used for bias correction --
*                             TODO confirm).
* \param[in] mode             Integer selecting the update variant; the meaning of each value
*                             is not visible here -- TODO document.
* \param[in] bias_correction  Integer flag; presumably non-zero enables bias correction --
*                             TODO confirm.
* \param[in] weight_decay     Weight decay coefficient.
* \param[in] stream           CUDA stream used for the operation.
*/
void nvte_multi_tensor_adam_cuda_custom(
int chunk_size, NVTETensor noop_flag, NVTEDType grad_dtype, NVTEDType param_dtype,
NVTEDType moment_dtype,
int64_t *addresses, int64_t *sizes, int *block_to_tensor, int *chunk_offsets,
int total_chunks, int has_master,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, cudaStream_t stream);

/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* when model parameters are in Float8 precision.
*
Expand Down
Loading
Loading