diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 4b7d8179b..85d89bff8 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; ret.lda = is_A_transposed ? k : m; - if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), @@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; - if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -231,7 +233,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 7fc9d7898..0994dbd71 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -517,7 +517,6 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor); * \brief Namespace containing C++ API of Transformer Engine. */ namespace transformer_engine { - /*! \enum DType * \brief TE datatype. 
*/ diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 370d9723c..82c50c4eb 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { } int nvte_is_non_tn_fp8_gemm_supported() { - int num_devices = transformer_engine::cuda::num_devices(); + static int num_devices = transformer_engine::cuda::num_devices(); static std::vector cache(num_devices, -1); static std::vector flags(num_devices); int device_id = transformer_engine::cuda::current_device(); diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index fa6f142b6..1b56c55bb 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -13,29 +13,31 @@ namespace transformer_engine::pytorch { /*! convert fp4 data shape back to original shape */ -std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose) { - std::vector ret; +NVTEShape convert_shape_back_from_fp4(const NVTEShape& shape, bool transpose) { + NVTEShape ret; size_t start_idx = (transpose) ? 1 : 0; - for (size_t i = start_idx; i < shape.size() - 1; ++i) { - ret.push_back(shape[i]); + size_t out_idx = 0; + + // Copy dimensions from start_idx to ndim-1 + for (size_t i = start_idx; i < shape.ndim - 1; ++i) { + ret.data[out_idx++] = shape.data[i]; } - ret.push_back(shape.back() * 2); + + // Last dimension multiplied by 2 + ret.data[out_idx++] = shape.data[shape.ndim - 1] * 2; + + // If transpose, add the first dimension if (transpose) { - ret.push_back(shape.front()); + ret.data[out_idx++] = shape.data[0]; } - return ret; -} -std::vector getTensorShape(const at::Tensor& t) { - std::vector shape; - for (auto s : t.sizes()) { - shape.push_back(s); - } - return shape; + ret.ndim = out_idx; + return ret; } -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) { +NVTEShape getTensorShape(const at::Tensor& t) { NVTEShape ret; + const c10::IntArrayRef& torch_shape = t.sizes(); ret.ndim = torch_shape.size(); constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t); NVTE_CHECK(ret.ndim < max_dimensions, @@ -112,17 +114,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( return transformer_engine::TensorWrapper(data_ptr, shape, type); } -transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type) { - return transformer_engine::TensorWrapper(data_ptr, shape, type); -} - transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - std::vector shape; - for (auto s : tensor.sizes()) { - shape.push_back(s); - } + NVTEShape shape = getTensorShape(tensor); return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); } @@ -164,14 +158,13 @@ makeTransformerEngineTensorList(std::vector> at_tensor_l } transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape, + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode) { TensorWrapper 
ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); - const std::vector meta_shape{1}; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + ret.set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape); + ret.set_scale(scale_ptr, DType::kFloat32, TensorWrapper::defaultShape); auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); @@ -179,17 +172,16 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( } transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, - const std::vector& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, - const std::vector& scale_inv_shape, - const std::vector& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) { + void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape, + NVTEScalingMode scaling_mode) { TensorWrapper ret(scaling_mode); ret.set_rowwise_data(data_ptr, type, shape); ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); - const std::vector meta_shape{1}; - ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); - ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + ret.set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape); + ret.set_scale(scale_ptr, DType::kFloat32, TensorWrapper::defaultShape); auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3 : DType::kFloat32; @@ -230,6 +222,9 @@ template size_t product(const std::vector& shape); template int64_t product(const std::vector& shape); size_t product(const NVTEShape& shape, size_t begin, size_t end) { + if (end == -1) { + end = shape.ndim; + } NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, " in a shape with ", shape.ndim, " entries"); size_t ret = 1; diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 1e1e3326c..4d2271bb2 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -100,9 +100,8 @@ class Quantizer { virtual void set_quantization_params(TensorWrapper* tensor) const = 0; /*! @brief Construct a tensor with uninitialized data */ - virtual std::pair create_tensor(const std::vector& shape, + virtual std::pair create_tensor(const NVTEShape& shape, DType dtype) const = 0; - /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor * * The PyTorch tensor's attributes are modified to match the @@ -134,11 +133,11 @@ class NoneQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override {} - std::pair create_tensor(const std::vector& shape, + std::pair create_tensor(const NVTEShape& shape, DType dtype) const override; /*! 
@brief Construct a tensor with pre-initialized data */ - std::pair create_tensor(const std::vector& shape, DType dtype, + std::pair create_tensor(const NVTEShape& shape, DType dtype, at::Tensor data) const; std::pair convert_and_update_tensor(py::object tensor) const override; @@ -160,11 +159,10 @@ class Float8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, + std::pair create_tensor(const NVTEShape& shape, DType dtype) const override; - /*! @brief Construct a tensor with pre-initialized data */ - std::pair create_tensor(const std::vector& shape, DType dtype, + std::pair create_tensor(const NVTEShape& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const; @@ -192,16 +190,15 @@ class Float8CurrentScalingQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, + std::pair create_tensor(const NVTEShape& shape, DType dtype) const override; - /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. * * The amax is zeroed out. Most TE kernels that output amax expect * amax to be initialized to zero. */ std::pair create_unquantized_tensor_with_amax( - const std::vector& shape, DType dtype, std::optional data = std::nullopt); + const NVTEShape& shape, DType dtype, std::optional data = std::nullopt); std::pair convert_and_update_tensor(py::object shape) const override; @@ -217,7 +214,6 @@ class Float8CurrentScalingQuantizer : public Quantizer { void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt); - private: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); }; @@ -251,7 +247,7 @@ class Float8BlockQuantizer : public Quantizer { // Create a python Float8BlockQuantized tensor and C++ wrapper // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. 
- std::pair create_tensor(const std::vector& shape, + std::pair create_tensor(const NVTEShape& shape, DType dtype) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -260,6 +256,11 @@ class Float8BlockQuantizer : public Quantizer { const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShape get_scale_shape(const NVTEShape& shape, bool columnwise) const; + + private: + template + ShapeT get_scale_shape_impl(size_t numel, size_t last_dim, bool columnwise) const; }; class MXFP8Quantizer : public Quantizer { @@ -272,7 +273,7 @@ class MXFP8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, + std::pair create_tensor(const NVTEShape& shape, DType dtype) const override; std::pair convert_and_update_tensor(py::object shape) const override; @@ -281,6 +282,11 @@ class MXFP8Quantizer : public Quantizer { const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShape get_scale_shape(const NVTEShape& shape, bool columnwise) const; + + private: + template + ShapeT get_scale_shape_impl(size_t numel, size_t last_dim, bool columnwise) const; }; class NVFP4Quantizer : public Quantizer { @@ -306,9 +312,8 @@ class NVFP4Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, + std::pair create_tensor(const NVTEShape& shape, DType dtype) const override; - /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer * * The amax is zeroed out. Most TE kernels that output amax expect @@ -331,15 +336,37 @@ class NVFP4Quantizer : public Quantizer { void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + NVTEShape get_scale_shape(const NVTEShape& shape, bool columnwise) const; private: + template + ShapeT get_scale_shape_impl(size_t numel, size_t last_dim, bool columnwise) const; + + public: void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); }; std::unique_ptr convert_quantizer(py::handle quantizer); -std::vector getTensorShape(const at::Tensor& t); +NVTEShape getTensorShape(const at::Tensor& t); + +template +inline NVTEShape make_nvte_1d_shape(T dim0) { + NVTEShape shape; + shape.ndim = 1; + shape.data[0] = static_cast(dim0); + return shape; +} + +template +inline NVTEShape make_nvte_2d_shape(T dim0, U dim1) { + NVTEShape shape; + shape.ndim = 2; + shape.data[0] = static_cast(dim0); + shape.data[1] = static_cast(dim1); + return shape; +} transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); @@ -423,21 +450,16 @@ inline transformer_engine::DType GetTransformerEngineDType(int DType_value) { return static_cast(DType_value); } -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type); - transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, const std::vector& shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape = {1}, + void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* 
amax_ptr, + void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor( - void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, - const std::vector& columnwise_shape, const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, - const std::vector& scale_inv_shape = {1}, - const std::vector& columnwise_scale_inv_shape = {1}, + void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape, + const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr, + void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, @@ -459,7 +481,7 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( template T product(const std::vector& shape); -size_t product(const NVTEShape& shape, size_t begin, size_t end); +size_t product(const NVTEShape& shape, size_t begin = 0, size_t end = -1); std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape); @@ -479,9 +501,7 @@ std::vector convertShape(const NVTEShape& shape); size_t roundup(const size_t value, const size_t multiple); -NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); - -std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); +NVTEShape convert_shape_back_from_fp4(const NVTEShape& shape, bool transpose); // unpack the PhiloxCudaState into CUDA tensor void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 9ea14e1af..41f7048d6 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -26,8 +26,8 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int // Construct output tensor auto quantizer_cpp = convert_quantizer(quantizer); const auto input_shape = input_nvte.shape(); - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - output_shape.back() /= shape_divisor; + NVTEShape output_shape = input_shape; + output_shape.data[output_shape.ndim - 1] /= shape_divisor; auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); @@ -137,9 +137,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i // Construct grad input tensor auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape_te = input_nvte.shape(); - const std::vector input_shape(input_shape_te.data, - input_shape_te.data + input_shape_te.ndim); + const auto input_shape = input_nvte.shape(); auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index b455e0375..03bd96791 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -59,6 +59,7 @@ 
std::pair quantizer_helper(py::handle quantizer, const std::vector &shape, DType dtype, bool create_hp_tensor_for_cs, std::optional data) { + NVTEShape nvte_shape = nvte_make_shape(shape.data(), shape.size()); std::unique_ptr T_quantizer = convert_quantizer(quantizer); TensorWrapper te_T; py::object py_T; @@ -66,27 +67,28 @@ std::pair quantizer_helper(py::handle quantizer, // high precision auto *none_quantizer = dynamic_cast(T_quantizer.get()); if (data.has_value()) { - std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype, data.value()); + std::tie(te_T, py_T) = none_quantizer->create_tensor(nvte_shape, dtype, data.value()); } else { - std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype); + std::tie(te_T, py_T) = none_quantizer->create_tensor(nvte_shape, dtype); } } else if (detail::IsFloat8Quantizers(quantizer.ptr())) { // delayed scaling; this helps initialize scale_inv auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); std::tie(te_T, py_T) = - T_quantizer_fp8->create_tensor(shape, dtype, data, std::nullopt, std::nullopt); + T_quantizer_fp8->create_tensor(nvte_shape, dtype, data, std::nullopt, std::nullopt); } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { // current scaling auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); if (create_hp_tensor_for_cs) { if (data.has_value()) { std::tie(te_T, py_T) = - T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + T_quantizer_fp8->create_unquantized_tensor_with_amax(nvte_shape, dtype, data.value()); } else { - std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(nvte_shape, dtype); } } else { - std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(nvte_shape, dtype); NVTE_CHECK( !data.has_value(), "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); @@ -162,14 +164,13 @@ std::vector fused_attn_fwd( NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); } if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - auto bias_sizes = Bias.value().sizes().vec(); - std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; + NVTEShape bias_shape = getTensorShape(Bias.value()); te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32); } auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; + NVTEShape cu_seqlens_q_shape = getTensorShape(cu_seqlens_q); auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; + NVTEShape cu_seqlens_kv_shape = getTensorShape(cu_seqlens_kv); te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, DType::kInt32); te_cu_seqlens_kv = @@ -177,38 +178,32 @@ std::vector fused_attn_fwd( if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; + NVTEShape cu_seqlens_q_padded_shape = getTensorShape(cu_seqlens_q_padded.value()); auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; + NVTEShape cu_seqlens_kv_padded_shape = getTensorShape(cu_seqlens_kv_padded.value()); te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), cu_seqlens_q_padded_shape, DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); } - if ((page_table_k.has_value()) && (page_table_v.has_value())) { - auto page_table_k_sizes = page_table_k.value().sizes().vec(); - std::vector page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()}; - auto page_table_v_sizes = page_table_v.value().sizes().vec(); - std::vector page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()}; - te_page_table_k = - makeTransformerEngineTensor(page_table_k.value().data_ptr(), page_table_k_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_page_table_v = - makeTransformerEngineTensor(page_table_v.value().data_ptr(), page_table_v_shape, - DType::kInt32, nullptr, nullptr, nullptr); + NVTEShape page_table_k_shape = getTensorShape(page_table_k.value()); + NVTEShape page_table_v_shape = getTensorShape(page_table_v.value()); + te_page_table_k = makeTransformerEngineTensor(page_table_k.value().data_ptr(), + page_table_k_shape, DType::kInt32, nullptr, + nullptr, nullptr, TensorWrapper::defaultShape); + te_page_table_v = makeTransformerEngineTensor(page_table_v.value().data_ptr(), + page_table_v_shape, DType::kInt32, nullptr, + nullptr, nullptr, TensorWrapper::defaultShape); } // softmax offset TensorWrapper te_SoftmaxOffset; if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { - auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec(); - std::vector SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()}; - te_SoftmaxOffset = - makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape, - DType::kFloat32, nullptr, nullptr, nullptr); + NVTEShape SoftmaxOffset_shape = getTensorShape(SoftmaxOffset.value()); + te_SoftmaxOffset = 
makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), + SoftmaxOffset_shape, DType::kFloat32, nullptr, + nullptr, nullptr, TensorWrapper::defaultShape); } // extract rng seed and offset @@ -460,28 +455,25 @@ std::vector fused_attn_bwd( } // create cu_seqlens tensorwrappers - auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; - auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; + NVTEShape cu_seqlens_q_shape = getTensorShape(cu_seqlens_q); + NVTEShape cu_seqlens_kv_shape = getTensorShape(cu_seqlens_kv); TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_q = makeTransformerEngineTensor( + cu_seqlens_q.data_ptr(), static_cast(cu_seqlens_q_shape), DType::kInt32, nullptr, + nullptr, nullptr, TensorWrapper::emptyShape); + te_cu_seqlens_kv = makeTransformerEngineTensor( + cu_seqlens_kv.data_ptr(), static_cast(cu_seqlens_kv_shape), DType::kInt32, + nullptr, nullptr, nullptr, TensorWrapper::emptyShape); TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { - auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; - auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; + NVTEShape cu_seqlens_q_padded_shape = getTensorShape(cu_seqlens_q_padded.value()); + NVTEShape cu_seqlens_kv_padded_shape = getTensorShape(cu_seqlens_kv_padded.value()); te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), cu_seqlens_q_padded_shape, DType::kInt32); te_cu_seqlens_kv_padded = makeTransformerEngineTensor( - cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); + cu_seqlens_kv_padded.value().data_ptr(), + static_cast(cu_seqlens_kv_padded_shape), DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors @@ -489,13 +481,11 @@ std::vector fused_attn_bwd( nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - const std::vector &signed_shape = Aux_CTX_Tensors[i].sizes().vec(); - const std::vector tmp(signed_shape.begin(), signed_shape.end()); + NVTEShape tmp = getTensorShape(Aux_CTX_Tensors[i]); NVTEBasicTensor temp_data = { Aux_CTX_Tensors[i].data_ptr(), - static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), - nvte_make_shape(tmp.data(), tmp.size())}; + static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), tmp}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index c59e3c4f6..6966e93b1 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ 
-26,11 +26,11 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle // Grad output tensor auto grad_output_torch = grad_output.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto shape = getTensorShape(grad_output_torch); + const NVTEShape &shape = getTensorShape(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); // Construct grad bias tensor - const int64_t bias_size = static_cast(shape.back()); + const int64_t bias_size = static_cast(shape.data[shape.ndim - 1]); auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype); auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch); @@ -116,17 +116,17 @@ std::vector dact_dbias( // Grad output and activation input tensors grad_output_torch = grad_output_torch.contiguous(); const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); - const auto output_shape = getTensorShape(grad_output_torch); + const NVTEShape &output_shape = getTensorShape(grad_output_torch); auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); act_input_torch = act_input_torch.contiguous(); const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); - const auto input_shape = getTensorShape(act_input_torch); + const NVTEShape &input_shape = getTensorShape(act_input_torch); // Construct tensors auto quantizer_cpp = convert_quantizer(quantizer_py); auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, grad_output_dtype); - const int64_t bias_size = static_cast(input_shape.back()); + const int64_t bias_size = static_cast(input_shape.data[input_shape.ndim - 1]); auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype); auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 3bbc99b44..53de74b47 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -56,9 +56,8 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob TensorWrapper output_cpp; py::object output_py; if (output.is_none()) { - const auto shape = get_tensor_shape(input_cpp); const auto fake_dtype = input_cpp.dtype(); - std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); + std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(input_cpp.shape(), fake_dtype); } else { std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output); } @@ -88,11 +87,7 @@ py::object dequantize(const py::handle &input, transformer_engine::DType otype) const auto &input_tensor = makeTransformerEngineTensor(input, none); NoneQuantizer q(none); - - const auto &shape = convertShape(input_tensor.shape()); - - auto [out_tensor, out] = q.create_tensor(shape, otype); - + auto [out_tensor, out] = q.create_tensor(input_tensor.shape(), otype); NVTE_SCOPED_GIL_RELEASE({ nvte_dequantize(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream()); }); @@ -180,8 +175,8 @@ std::vector multi_tensor_quantize(const std::vector &ten const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); // Construct output tensor - std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); - auto [output_cpp, output_py] = 
quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype); + // std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(input_shape, input_dtype); output_cpp_list.emplace_back(std::move(output_cpp)); output_py_list.emplace_back(std::move(output_py)); } @@ -195,7 +190,7 @@ std::vector multi_tensor_quantize(const std::vector &ten namespace { std::tuple, std::vector> bulk_allocate_fp8_blockwise_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector> retval; @@ -220,9 +215,9 @@ std::tuple, std::vector> bulk_allocate_fp // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. - auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShape &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { - std::vector shape_int64(shape.begin(), shape.end()); + std::vector shape_int64(shape.data, shape.data + shape.ndim); bool is_empty_shape = product(shape) == 0; if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); @@ -235,7 +230,7 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -273,15 +268,16 @@ std::tuple, std::vector> bulk_allocate_fp // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { columnwise_data_shapes.emplace_back(); - auto &shape = columnwise_data_shapes.back(); - shape.push_back(shape_list[i].back()); - for (size_t j = 0; j < shape_list[i].size() - 1; ++j) { - shape.push_back(shape_list[i][j]); + NVTEShape &shape = columnwise_data_shapes.back(); + shape.ndim = shape_list[i].ndim; + shape.data[0] = shape_list[i].data[shape_list[i].ndim - 1]; + for (size_t j = 0; j < shape_list[i].ndim - 1; ++j) { + shape.data[j + 1] = shape_list[i].data[j]; } columnwise_scale_shapes.emplace_back( quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); @@ -330,24 +326,23 @@ std::tuple, std::vector> bulk_allocate_fp tensor_py_list.emplace_back(Float8BlockwiseQTensorClass( rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY)); - // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp8_dtype, nullptr, - nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_data_shapes[i] : TensorWrapper::emptyShape, + columnwise_usage ? 
columnwise_data_shapes[i] : TensorWrapper::emptyShape, fp8_dtype, + nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode)); + rowwise_usage ? rowwise_scale_shapes[i] : TensorWrapper::emptyShape, + columnwise_usage ? columnwise_scale_shapes[i] : TensorWrapper::emptyShape, scaling_mode)); } return retval; } std::tuple, std::vector> bulk_allocate_mxfp8_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector> retval; @@ -371,9 +366,9 @@ std::tuple, std::vector> bulk_allocate_mx // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. - auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShape &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { - std::vector shape_int64(shape.begin(), shape.end()); + std::vector shape_int64(shape.data, shape.data + shape.ndim); bool is_empty_shape = product(shape) == 0; if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); @@ -386,13 +381,13 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + quantizer_cpp_list[i]->get_scale_shape(rowwise_data_shapes[i], false)); } // Offsets in full buffer @@ -424,7 +419,7 @@ std::tuple, std::vector> bulk_allocate_mx // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if (columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -477,17 +472,16 @@ std::tuple, std::vector> bulk_allocate_mx tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype, quantizer_py_list[i])); - // Construct C++ tensor tensor_cpp_list.emplace_back(makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp8_dtype, nullptr, - nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_data_shapes[i] : TensorWrapper::emptyShape, + columnwise_usage ? columnwise_data_shapes[i] : TensorWrapper::emptyShape, fp8_dtype, + nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode)); + rowwise_usage ? 
rowwise_scale_shapes[i] : TensorWrapper::emptyShape, + columnwise_usage ? columnwise_scale_shapes[i] : TensorWrapper::emptyShape, scaling_mode)); } return retval; @@ -497,7 +491,7 @@ std::tuple, std::vector> bulk_allocate_mx // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate std::tuple, std::vector, bool> bulk_allocate_nvfp4_tensors( - std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &shape_list, std::vector &quantizer_py_list, std::vector &quantizer_cpp_list) { init_extension(); std::tuple, std::vector, bool> retval; @@ -522,9 +516,9 @@ std::tuple, std::vector, bool> bulk_alloc // Helper function to construct tensor view // Note: Deleter holds a shared_ptr for the buffer, so the buffer // will survive until all views are deleted. - auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + auto make_torch_view = [](std::shared_ptr &buffer, const NVTEShape &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { - std::vector shape_int64(shape.begin(), shape.end()); + std::vector shape_int64(shape.data, shape.data + shape.ndim); bool is_empty_shape = product(shape) == 0; if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); @@ -535,24 +529,24 @@ std::tuple, std::vector, bool> bulk_alloc at::device(at::kCUDA).dtype(dtype)); }; - // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) - auto to_fp4_shape = [](const std::vector &shape) { - std::vector fp4_shape(shape.begin(), shape.end()); - if (!fp4_shape.empty()) { - fp4_shape.back() /= 2; + // Lambda function for converting NVTEShape shape to NVFP4 shape (last dim divided by 2) + auto to_fp4_shape = [](const NVTEShape &shape) { + NVTEShape fp4_shape = shape; + if (fp4_shape.ndim != 0) { + fp4_shape.data[fp4_shape.ndim - 1] /= 2; } return fp4_shape; }; // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; - std::vector> rowwise_data_shapes, rowwise_scale_shapes; + std::vector rowwise_data_shapes, rowwise_scale_shapes; if (rowwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_shapes.emplace_back(shape_list[i]); rowwise_scale_shapes.emplace_back( - quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + quantizer_cpp_list[i]->get_scale_shape(rowwise_data_shapes[i], false)); } // Offsets in full buffer @@ -587,7 +581,6 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), @@ -595,13 +588,13 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_rowwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, TensorWrapper::defaultShape, amax_offsets[i], torch::kFloat32)); } } // Allocate column-wise data std::vector columnwise_data_list, columnwise_scale_list, amax_columnwise_list; - std::vector> columnwise_data_shapes, columnwise_scale_shapes; + std::vector columnwise_data_shapes, columnwise_scale_shapes; if 
(columnwise_usage) { // Tensor sizes for (size_t i = 0; i < num_tensors; ++i) { @@ -609,9 +602,10 @@ std::tuple, std::vector, bool> bulk_alloc // NVFP4 on SM100 is TN only columnwise_data_shapes.emplace_back(); auto &shape = columnwise_data_shapes.back(); - shape.push_back(shape_list[i].back()); - for (size_t j = 0; j < shape_list[i].size() - 1; ++j) { - shape.push_back(shape_list[i][j]); + shape.ndim = shape_list[i].ndim; + shape.data[0] = shape_list[i].data[shape_list[i].ndim - 1]; + for (size_t j = 0; j < shape_list[i].ndim - 1; ++j) { + shape.data[j + 1] = shape_list[i].data[j]; } columnwise_scale_shapes.emplace_back( quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); @@ -649,7 +643,6 @@ std::tuple, std::vector, bool> bulk_alloc // Allocate full buffer auto buffer = std::make_shared( at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { columnwise_data_list.emplace_back(make_torch_view( @@ -657,7 +650,7 @@ std::tuple, std::vector, bool> bulk_alloc columnwise_scale_list.emplace_back( make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); amax_columnwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, TensorWrapper::defaultShape, amax_offsets[i], torch::kFloat32)); } } @@ -686,22 +679,22 @@ std::tuple, std::vector, bool> bulk_alloc auto tensor_wrapper = makeTransformerEngineTensor( rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_data_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_data_shapes[i] : std::vector{0}, fp4_dtype, + rowwise_usage ? rowwise_data_shapes[i] : TensorWrapper::emptyShape, + columnwise_usage ? columnwise_data_shapes[i] : TensorWrapper::emptyShape, fp4_dtype, /*amax_ptr=*/nullptr, /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, - rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, - columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); + rowwise_usage ? rowwise_scale_shapes[i] : TensorWrapper::emptyShape, + columnwise_usage ? 
columnwise_scale_shapes[i] : TensorWrapper::emptyShape, scaling_mode); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + TensorWrapper::defaultShape); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + TensorWrapper::defaultShape); } tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); } @@ -765,9 +758,9 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); int64_t *rng_state_ptr = static_cast(res.rng_states_tensor.data_ptr()) + i * 2; philox_unpack(philox_args, rng_state_ptr); - - res.te_rng_state_list.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr), std::vector{2}, DType::kInt64)); + const NVTEShape rng_state_shape = make_nvte_1d_shape(2); + res.te_rng_state_list.push_back(makeTransformerEngineTensor(static_cast(rng_state_ptr), + rng_state_shape, DType::kInt64)); quant_config_list_rowwise[i].set_rng_state(res.te_rng_state_list[i].data()); quant_config_list_rowwise[i].set_stochastic_rounding(true); @@ -781,7 +774,7 @@ static StochasticRngStateResources setup_stochastic_rounding_rng_states_helper( philox_unpack(philox_args_col, rng_state_ptr_colwise); res.te_rng_state_list_colwise.push_back(makeTransformerEngineTensor( - static_cast(rng_state_ptr_colwise), std::vector{2}, DType::kInt64)); + static_cast(rng_state_ptr_colwise), rng_state_shape, DType::kInt64)); quant_config_list_colwise[i].set_rng_state(res.te_rng_state_list_colwise[i].data()); quant_config_list_colwise[i].set_stochastic_rounding(true); } @@ -1003,12 +996,12 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + output_list[i].set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape); } nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), split_sections.data(), num_tensors, stream); for (size_t i = 0; i < num_tensors; i++) { - output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, TensorWrapper::defaultShape); } // Quantize tensors individually @@ -1104,27 +1097,32 @@ std::vector split_quantize(const at::Tensor &tensor, auto input_py = tensor.contiguous(); uint8_t *input_dptr = reinterpret_cast(input_py.data_ptr()); auto input_dtype = GetTransformerEngineDType(input_py.scalar_type()); - std::vector input_shape; + NVTEShape input_shape; + input_shape.ndim = 0; size_t input_size = 1; for (const auto &d : input_py.sizes()) { - input_shape.push_back(d); + input_shape.data[input_shape.ndim++] = static_cast(d); input_size *= d; } - NVTE_CHECK(input_shape.size() > 0, "Input tensor has 0 dims"); + NVTE_CHECK(input_shape.ndim > 0, "Input tensor has 0 dims"); // Split input tensor along dim 0 std::vector input_list; - std::vector> split_shapes; + std::vector split_shapes; size_t dim0_offset = 0; const size_t dim0_stride = - input_shape[0] == 0 ? 0 : input_py.element_size() * input_size / input_shape[0]; + input_shape.data[0] == 0 ? 
0 : input_py.element_size() * input_size / input_shape.data[0]; for (size_t i = 0; i < num_splits; ++i) { - NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0], - "Attempted to split tensor with shape=", input_shape, - " along dim 0 with split_sections=", split_sections); - split_shapes.push_back(input_shape); + NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape.data[0], + "Attempted to split tensor with dim 0 shape=", input_shape.data[0], + " with split_section size =", split_sections[i]); + split_shapes.emplace_back(); auto &split_shape = split_shapes.back(); - split_shape[0] = split_sections[i]; + split_shape.ndim = input_shape.ndim; + split_shape.data[0] = split_sections[i]; + for (size_t j = 1; j < input_shape.ndim; ++j) { + split_shape.data[j] = input_shape.data[j]; + } void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); input_list.emplace_back(makeTransformerEngineTensor(split_dptr, split_shape, input_dtype)); dim0_offset += split_sections[i]; diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 07ddfbeb6..8fc6f9b2d 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -40,8 +40,8 @@ bool is_low_precision(const DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } -std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, - const NVTEShape& B_shape, const bool transb) { +NVTEShape getGemmOutputShape(const NVTEShape& A_shape, const bool transa, const NVTEShape& B_shape, + const bool transb) { // Flatten outer dims to get 2D matrices const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1; const size_t A1 = A_shape.ndim > 0 ? A_shape.data[A_shape.ndim - 1] : 1; @@ -53,35 +53,37 @@ std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool tran A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")"); // Construct output dims - std::vector ret; + NVTEShape ret; + size_t idx = 0; if (transb) { - ret.emplace_back(B1); + ret.data[idx++] = B1; } else { // Unflatten B0 for (size_t i = 0; i < B_shape.ndim - 1; ++i) { - ret.emplace_back(B_shape.data[i]); + ret.data[idx++] = B_shape.data[i]; } } if (transa) { - ret.emplace_back(A0); + ret.data[idx++] = A0; } else { - ret.emplace_back(A1); + ret.data[idx++] = A1; } + ret.ndim = idx; return ret; } -bool checkGemmShape(const std::vector& expected, const NVTEShape& actual) { - if (expected.size() != actual.ndim) return false; - for (size_t i = 0; i < expected.size(); ++i) { - if (expected[i] != actual.data[i]) return false; +bool checkGemmShape(const NVTEShape& expected, const NVTEShape& actual) { + if (expected.ndim != actual.ndim) return false; + for (size_t i = 0; i < expected.ndim; ++i) { + if (expected.data[i] != actual.data[i]) return false; } return true; } } // namespace detail -std::pair createOutputTensor(const std::vector& shape, - DType dtype, py::handle quantizer) { +std::pair createOutputTensor(const NVTEShape& shape, DType dtype, + py::handle quantizer) { std::unique_ptr my_quantizer = convert_quantizer(quantizer); return my_quantizer->create_tensor(shape, dtype); } @@ -98,7 +100,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Ensure that cublasLt handle is created on the correct device, // overriding torch.cuda.set_device calls from user side. // Assumes all tensors passed are on the same device. 
- at::cuda::CUDAGuard device_guard(workspace.device()); + auto device = workspace.device(); + at::cuda::CUDAGuard device_guard(device); // Input tensors NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); @@ -196,10 +199,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans if (!grad) { auto dtype = GetATenDType(gelu_type); auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); - std::vector torch_shape; - for (auto v : D_shape) { - torch_shape.push_back(v); - } + std::vector torch_shape(D_shape.data, D_shape.data + D_shape.ndim); pre_gelu_out = at::empty(torch_shape, opts); } else { if (gelu_in.has_value()) { @@ -207,18 +207,18 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - const auto gelu_shape = gelu ? D_shape : std::vector{0}; + const auto gelu_shape = gelu ? D_shape : TensorWrapper::emptyShape; auto te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); // Workspace auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - std::vector{workspaceSize}, DType::kByte); + make_nvte_1d_shape(workspaceSize), DType::kByte); // Set an external SM Margin to all the GEMMs. // This comes in handy when DP is overlapped with GEMMs - const int device_id = at::cuda::current_device(); + const int device_id = device.index(); const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); @@ -264,7 +264,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans extra_output_tensor = makeTransformerEngineTensor(*extra_output); } else { extra_output_tensor = - makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); + makeTransformerEngineTensor(nullptr, TensorWrapper::emptyShape, DType::kByte); } // Direct GEMM call to the correct overlap @@ -365,32 +365,28 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; - auto te_A = makeTransformerEngineTensor( - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), - nvte_scaling_modeA); - auto te_B = makeTransformerEngineTensor( - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), - nvte_scaling_modeB); + auto te_A = makeTransformerEngineTensor(A.data_ptr(), make_nvte_2d_shape(A.size(0), A.size(1)), + A_type, nullptr, nullptr, A_scale_inverse.data_ptr(), + getTensorShape(A_scale_inverse), nvte_scaling_modeA); + auto te_B = makeTransformerEngineTensor(B.data_ptr(), make_nvte_2d_shape(B.size(0), B.size(1)), + B_type, nullptr, nullptr, B_scale_inverse.data_ptr(), + getTensorShape(B_scale_inverse), nvte_scaling_modeB); // TODO: D_scale_inv cannot be nullptr when D_type is FP8. 
- auto te_D = makeTransformerEngineTensor( - D.data_ptr(), - std::vector{static_cast(D.size(0)), static_cast(D.size(1))}, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr); - auto te_bias = makeTransformerEngineTensor( - bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); - auto te_counter = makeTransformerEngineTensor( - counter.data_ptr(), std::vector{static_cast(counter.size(0))}, DType::kInt32); + auto te_D = makeTransformerEngineTensor(D.data_ptr(), make_nvte_2d_shape(D.size(0), D.size(1)), + D_type, D_amax.data_ptr(), D_scale.data_ptr(), nullptr, + TensorWrapper::defaultShape); + auto te_bias = + makeTransformerEngineTensor(bias.data_ptr(), make_nvte_1d_shape(bias.size(0)), bias_type); + auto te_counter = makeTransformerEngineTensor(counter.data_ptr(), + make_nvte_1d_shape(counter.size(0)), DType::kInt32); const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; + ? make_nvte_1d_shape(pre_gelu_out.size(0)) + : make_nvte_2d_shape(pre_gelu_out.size(0), pre_gelu_out.size(1)); auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), - std::vector{workspaceSize}, DType::kByte); + make_nvte_1d_shape(workspaceSize), DType::kByte); NVTE_SCOPED_GIL_RELEASE({ nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), @@ -430,12 +426,13 @@ std::optional> te_general_grouped_gemm( // if there is single output at::Tensor out_tensor; - auto size_t_shape = + const NVTEShape nvte_D_shape = pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); bool D_numel_is_zero = false; std::vector D_shape; - for (size_t t : size_t_shape) { - D_shape.push_back(t); + for (size_t j = 0; j < nvte_D_shape.ndim; ++j) { + const size_t t = nvte_D_shape.data[j]; + D_shape.push_back(static_cast(t)); if (t == 0) { D_numel_is_zero = true; } @@ -480,10 +477,12 @@ std::optional> te_general_grouped_gemm( auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); - const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr - ? 
std::vector{static_cast(te_pre_gelu_out.size(0))} - : std::vector{static_cast(te_pre_gelu_out.size(0)), - static_cast(te_pre_gelu_out.size(1))}; + NVTEShape gelu_shape; + if (pre_gelu_out[i].data_ptr() == nullptr) { + gelu_shape = make_nvte_1d_shape(te_pre_gelu_out.size(0)); + } else { + gelu_shape = make_nvte_2d_shape(te_pre_gelu_out.size(0), te_pre_gelu_out.size(1)); + } DType gelu_type = bias_type; te_pre_gelu_out = @@ -552,7 +551,7 @@ std::optional> te_general_grouped_gemm( std::vector te_workspace_wrappers; for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), - std::vector{workspaceSize}, DType::kByte); + make_nvte_1d_shape(workspaceSize), DType::kByte); te_workspace_vector.emplace_back(wsp.data()); te_workspace_wrappers.emplace_back(std::move(wsp)); } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index d7a07724c..aa6d7c9a8 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -79,9 +79,9 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } // Tensor dimensions - const auto shape = nvte_shape_to_vector(input_nvte.shape()); - const auto outer_size = product(shape) / shape.back(); - const auto inner_size = shape.back(); + const NVTEShape &shape = input_nvte.shape(); + const auto outer_size = product(shape) / shape.data[shape.ndim - 1]; + const auto inner_size = shape.data[shape.ndim - 1]; // Tensors to save for backward pass at::Tensor mu_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); @@ -310,9 +310,9 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); // Tensor dimensions - const auto shape = nvte_shape_to_vector(input_nvte.shape()); - const auto outer_size = product(shape) / shape.back(); - const auto inner_size = shape.back(); + const NVTEShape &shape = input_nvte.shape(); + const auto outer_size = product(shape) / shape.data[shape.ndim - 1]; + const auto inner_size = shape.data[shape.ndim - 1]; // Tensors to save for backward pass at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index 6c66fda01..577005446 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -20,7 +20,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, const auto num_tensors = input_row_list.size(); // Extract properties from PyTorch tensors std::vector input_dptr_list, output_dptr_list; - std::vector> input_shape_list, output_shape_list; + std::vector input_shape_list, output_shape_list; std::vector input_type_list; void* d_input_ptr = reinterpret_cast(input.data_ptr()); void* d_output_ptr = reinterpret_cast(output.data_ptr()); @@ -34,8 +34,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - - input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + input_shape_list.push_back(make_nvte_2d_shape(input_row_list[tensor_id], input.size(1))); 
input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); // Move the output pointer to the next split. @@ -46,13 +45,13 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, d_output_ptr = reinterpret_cast(output_char_ptr); output_shape_list.push_back( - {padded_input_row_list[tensor_id], static_cast(output.size(1))}); + make_nvte_2d_shape(padded_input_row_list[tensor_id], output.size(1))); } // Construct TE tensors std::vector nvte_input_list, nvte_output_list; std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + auto make_tensor = [&tensor_wrappers](void* dptr, const NVTEShape& shape, DType dtype) -> NVTETensor { tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); return tensor_wrappers.back().data(); @@ -95,7 +94,7 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, const auto num_tensors = input_row_list.size(); // Extract properties from PyTorch tensors std::vector input_dptr_list, output_dptr_list; - std::vector> input_shape_list, output_shape_list; + std::vector input_shape_list, output_shape_list; std::vector input_type_list; void* d_input_ptr = reinterpret_cast(input.data_ptr()); void* d_output_ptr = reinterpret_cast(output.data_ptr()); @@ -109,8 +108,7 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, input_row_list[tensor_id] * input.size(1) * input.element_size(); input_char_ptr += input_dptr_offset; d_input_ptr = reinterpret_cast(input_char_ptr); - - input_shape_list.push_back({input_row_list[tensor_id], static_cast(input.size(1))}); + input_shape_list.push_back(make_nvte_2d_shape(input_row_list[tensor_id], input.size(1))); input_type_list.push_back(GetTransformerEngineDType(input.scalar_type())); // Move the output pointer to the next split. 
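// The padding/unpadding paths now collect NVTEShape values (a small fixed-capacity struct, copied
// by value) in place of nested std::vector<size_t> shapes, so registering a split no longer heap-
// allocates a shape vector. Sketch of how one split's TE tensor would be created with the
// make_tensor lambda above, reusing names from these hunks (te_split is an illustrative name only):
//
//   NVTEShape split_shape = make_nvte_2d_shape(input_row_list[tensor_id], input.size(1));
//   NVTETensor te_split = make_tensor(d_input_ptr, split_shape,
//                                     GetTransformerEngineDType(input.scalar_type()));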
@@ -121,13 +119,13 @@ void fused_multi_row_unpadding(at::Tensor input, at::Tensor output, d_output_ptr = reinterpret_cast(output_char_ptr); output_shape_list.push_back( - {unpadded_input_row_list[tensor_id], static_cast(output.size(1))}); + make_nvte_2d_shape(unpadded_input_row_list[tensor_id], output.size(1))); } // Construct TE tensors std::vector nvte_input_list, nvte_output_list; std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, + auto make_tensor = [&tensor_wrappers](void* dptr, const NVTEShape& shape, transformer_engine::DType dtype) -> NVTETensor { tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype)); return tensor_wrappers.back().data(); diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cpp b/transformer_engine/pytorch/csrc/extensions/permutation.cpp index 226705b16..09c2d5078 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cpp @@ -60,19 +60,13 @@ std::tuple> moe_permute_fwd( {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - auto input_cu = makeTransformerEngineTensor( - input.data_ptr(), - std::vector{static_cast(input.size(0)), static_cast(num_cols)}, - dtype); - auto permuted_output_cu = - makeTransformerEngineTensor(permuted_output.data_ptr(), - std::vector{static_cast(permuted_output.size(0)), - static_cast(num_cols)}, - dtype); + input.data_ptr(), make_nvte_2d_shape(input.size(0), input.size(1)), dtype); + auto permuted_output_cu = makeTransformerEngineTensor( + permuted_output.data_ptr(), + make_nvte_2d_shape(permuted_output.size(0), permuted_output.size(1)), dtype); auto sorted_row_id_cu = makeTransformerEngineTensor( - sorted_row_id_ptr, std::vector{static_cast(num_tokens * topK)}, - DType::kInt32); + sorted_row_id_ptr, make_nvte_1d_shape(num_tokens * topK), DType::kInt32); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(), @@ -97,16 +91,11 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false)); auto stream = at::cuda::getCurrentCUDAStream().stream(); - auto input_cu = makeTransformerEngineTensor( - input.data_ptr(), - std::vector{static_cast(input.size(0)), static_cast(num_cols)}, - dtype); + input.data_ptr(), make_nvte_2d_shape(input.size(0), input.size(1)), dtype); auto unpermuted_output_cu = makeTransformerEngineTensor( unpermuted_output.data_ptr(), - std::vector{static_cast(unpermuted_output.size(0)), - static_cast(num_cols)}, - dtype); + make_nvte_2d_shape(unpermuted_output.size(0), unpermuted_output.size(1)), dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); @@ -133,17 +122,11 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T auto stream = at::cuda::getCurrentCUDAStream().stream(); auto input_bwd_cu = makeTransformerEngineTensor( - input_bwd.data_ptr(), - std::vector{static_cast(input_bwd.size(0)), static_cast(num_cols)}, - dtype); + input_bwd.data_ptr(), make_nvte_2d_shape(input_bwd.size(0), num_cols), dtype); auto act_grad_cu = makeTransformerEngineTensor( - act_grad.data_ptr(), - std::vector{static_cast(act_grad.size(0)), static_cast(num_cols)}, - dtype); + act_grad.data_ptr(), 
make_nvte_2d_shape(act_grad.size(0), num_cols), dtype); auto input_fwd_cu = makeTransformerEngineTensor( - input_fwd.data_ptr(), - std::vector{static_cast(input_fwd.size(0)), static_cast(num_cols)}, - dtype); + input_fwd.data_ptr(), make_nvte_2d_shape(input_fwd.size(0), num_cols), dtype); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto prob_cu = makeTransformerEngineTensor(prob); auto prob_grad_cu = makeTransformerEngineTensor(prob_grad); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index c02d2ec61..29ddc352c 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -42,14 +42,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio for (size_t i = 0; i < num_tensors; i++) { te_amax_histories.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING)); NVTETensor& amax_history = te_amax_histories.back(); - NVTEShape amax_shape = convertTorchShape(amax_histories[i].sizes()); + NVTEShape amax_shape = getTensorShape(amax_histories[i]); NVTEBasicTensor amax_history_data = {amax_histories[i].data_ptr(), static_cast(DType::kFloat32), amax_shape}; nvte_set_tensor_param(&amax_history, kNVTERowwiseData, &amax_history_data); te_scales.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING)); NVTETensor& scale = te_scales.back(); - NVTEShape scale_shape = convertTorchShape(scales[i].sizes()); + NVTEShape scale_shape = getTensorShape(scales[i]); NVTEBasicTensor scale_data = {scales[i].data_ptr(), static_cast(DType::kFloat32), scale_shape}; nvte_set_tensor_param(&scale, kNVTERowwiseData, &scale_data); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 477d7c87e..f9628d579 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -19,16 +19,16 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional transpose_shape_int64; - if (shape.size() > 0) { - transpose_shape_int64.push_back(shape.back()); - for (size_t i = 0; i < shape.size() - 1; ++i) { - transpose_shape_int64.push_back(shape[i]); + if (shape.ndim > 0) { + transpose_shape_int64.push_back(shape.data[shape.ndim - 1]); + for (size_t i = 0; i < shape.ndim - 1; ++i) { + transpose_shape_int64.push_back(shape.data[i]); } } - const size_t M = shape.size() > 0 ? product(shape) / shape.back() : 1; - const size_t N = shape.size() > 0 ? shape.back() : 1; + const size_t M = shape.ndim > 0 ? product(shape) / shape.data[shape.ndim - 1] : 1; + const size_t N = shape.ndim > 0 ? 
shape.data[shape.ndim - 1] : 1; // Output tensor at::Tensor out; @@ -45,8 +45,8 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional{M, N}, otype); - auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector{N, M}, otype); + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), make_nvte_2d_shape(M, N), otype); + auto output_cu = makeTransformerEngineTensor(out.data_ptr(), make_nvte_2d_shape(N, M), otype); nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return out; @@ -56,15 +56,16 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out) { init_extension(); // Make sure input is contiguous - const auto &input = tensor.contiguous(); + const auto& input = tensor.contiguous(); // Allocate output tensor if needed if (!out) { - auto in_shape = getTensorShape(input); - NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")"); - std::vector out_shape_int64(in_shape.begin(), in_shape.end()); - out_shape_int64[0] = static_cast(in_shape[1]); - out_shape_int64[1] = static_cast(in_shape[0]); + const NVTEShape& in_shape = getTensorShape(input); + NVTE_CHECK(in_shape.ndim >= 2, "Invalid input tensor dimensions with ", in_shape.ndim, + " number of dimensions"); + std::vector out_shape_int64(in_shape.data, in_shape.data + in_shape.ndim); + out_shape_int64[0] = static_cast(in_shape.data[1]); + out_shape_int64[1] = static_cast(in_shape.data[0]); auto opts = at::TensorOptions().dtype(input.dtype()).device(input.device()); out = at::empty(out_shape_int64, opts); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index a73efc008..4e96b0d12 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -19,18 +19,56 @@ namespace { * The tensor is interpreted as a 2D matrix by flattening all but the * last dimension, and then transposed. */ -template -std::vector make_transpose_shape(const std::vector& shape) { - std::vector ret; - if (shape.size() > 0) { - ret.push_back(shape.back()); - for (size_t i = 0; i < shape.size() - 1; ++i) { - ret.push_back(shape[i]); +template +std::vector make_transpose_shape(const NVTEShape& shape) { + std::vector ret; + if (shape.ndim > 0) { + ret.push_back(shape.data[shape.ndim - 1]); + for (size_t i = 0; i < shape.ndim - 1; ++i) { + ret.push_back(static_cast(shape.data[i])); + } + } + return ret; +} + +// Specialization for NVTEShape +NVTEShape make_transpose_nvte_shape(const NVTEShape& shape) { + NVTEShape ret; + ret.ndim = shape.ndim; + if (shape.ndim > 0) { + ret.data[0] = shape.data[shape.ndim - 1]; + for (size_t i = 0; i < shape.ndim - 1; ++i) { + ret.data[i + 1] = shape.data[i]; } } return ret; } +/*! @brief Compare two NVTEShape objects for equality */ +inline bool shapes_equal(const NVTEShape& shape1, const NVTEShape& shape2) { + if (shape1.ndim != shape2.ndim) { + return false; + } + for (size_t i = 0; i < shape1.ndim; ++i) { + if (shape1.data[i] != shape2.data[i]) { + return false; + } + } + return true; +} + +std::string shape_to_string(const NVTEShape& shape) { + std::string s = "["; + for (size_t i = 0; i < shape.ndim; ++i) { + s += std::to_string(shape.data[i]); + if (i != shape.ndim - 1) { + s += ", "; + } + } + s += "]"; + return s; +} + /*! 
@brief Convert shape for FP4 data by dividing the last dimension by 2 */ template std::vector convert_shape_for_fp4(const std::vector& shape) { @@ -70,14 +108,14 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti this->dtype = type; } -std::pair NoneQuantizer::create_tensor(const std::vector& shape, +std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, DType dtype) const { - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, shape.data + shape.ndim); const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); return create_tensor(shape, dtype, at::empty(shape_int64, opts)); } -std::pair NoneQuantizer::create_tensor(const std::vector& shape, +std::pair NoneQuantizer::create_tensor(const NVTEShape& shape, DType dtype, at::Tensor data) const { TensorWrapper out_cpp; @@ -110,22 +148,22 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { getTensorShape(amax)); } -std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype) const { +std::pair Float8Quantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); at::Tensor scale_inv = at::empty(std::vector{1}, opts); return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); } std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional data, + const NVTEShape& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Initialize data tensor - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data && !data) { - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, shape.data + shape.ndim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); data = at::empty(shape_int64, opts); } else if (!with_data && data) { @@ -134,7 +172,7 @@ std::pair Float8Quantizer::create_tensor( py::object data_py = with_data ? 
py::cast(*data) : py::none(); // Initialize transpose tensor - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose && !transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -158,7 +196,7 @@ std::pair Float8Quantizer::create_tensor( "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, shape.data + shape.ndim); out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, @@ -169,13 +207,14 @@ std::pair Float8Quantizer::create_tensor( TensorWrapper out_cpp(this->get_scaling_mode()); if (with_data) { out_cpp.set_rowwise_data(data->data_ptr(), this->dtype, shape); - out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, + TensorWrapper::defaultShape); } if (with_transpose) { - const auto transpose_shape = make_transpose_shape(shape); + const auto& transpose_shape = make_transpose_nvte_shape(shape); out_cpp.set_columnwise_data(transpose->data_ptr(), this->dtype, transpose_shape); out_cpp.set_columnwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, - std::vector{1}); + TensorWrapper::defaultShape); } this->set_quantization_params(&out_cpp); @@ -187,8 +226,9 @@ std::pair Float8Quantizer::convert_and_update_tensor( NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); // Extract buffers from Python tensor @@ -207,19 +247,20 @@ std::pair Float8Quantizer::convert_and_update_tensor( at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); // Tensor dimensions - std::vector shape; + NVTEShape shape; if (has_transpose) { - const auto transpose_shape = getTensorShape(*transpose_tensor); - if (transpose_shape.size() > 0) { - for (size_t i = 1; i < transpose_shape.size(); ++i) { - shape.push_back(transpose_shape[i]); + const NVTEShape transpose_shape = getTensorShape(*transpose_tensor); + if (transpose_shape.ndim > 0) { + for (size_t i = 1; i < transpose_shape.ndim; ++i) { + shape.data[i - 1] = transpose_shape.data[i]; } - shape.push_back(transpose_shape.front()); + shape.data[transpose_shape.ndim - 1] = transpose_shape.data[0]; + shape.ndim = transpose_shape.ndim; } if (has_data) { - auto expected_shape = getTensorShape(*data_tensor); - NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, - ") and transpose (shape=", transpose_shape, ") do not match"); + const NVTEShape expected_shape = getTensorShape(*data_tensor); + NVTE_CHECK(shapes_equal(shape, expected_shape), "FP8 data (shape=", shape_to_string(shape), + ") 
and transpose (shape=", shape_to_string(transpose_shape), ") do not match"); } } else { // Already checked has_data == true shape = getTensorShape(*data_tensor); @@ -231,7 +272,7 @@ std::pair Float8Quantizer::convert_and_update_tensor( data_py = py::none(); tensor.attr("_data") = data_py; } else if (!has_data && need_data) { - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, shape.data + shape.ndim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); data_tensor = at::empty(shape_int64, opts); data_py = py::cast(data_tensor); @@ -260,13 +301,13 @@ std::pair Float8Quantizer::convert_and_update_tensor( if (data_tensor) { out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, - std::vector{1}); + TensorWrapper::defaultShape); } if (transpose_tensor) { - const auto transpose_shape = make_transpose_shape(shape); + const auto transpose_shape = make_transpose_nvte_shape(shape); out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, - std::vector{1}); + TensorWrapper::defaultShape); } this->set_quantization_params(&out_cpp); @@ -323,21 +364,22 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso } std::pair Float8CurrentScalingQuantizer::create_tensor( - const std::vector& shape, DType dtype) const { + const NVTEShape& shape, DType dtype) const { using namespace pybind11::literals; // Initialize data tensor at::Tensor data_tensor; - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data) { - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, shape.data + shape.ndim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); data_tensor = at::empty(shape_int64, opts); } // Initialize transpose tensor at::Tensor transpose_tensor; - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -363,7 +405,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, shape.data + shape.ndim); out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, @@ -375,13 +417,13 @@ std::pair Float8CurrentScalingQuantizer::create_tenso if (with_data) { out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape); out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, - std::vector{1}); + TensorWrapper::defaultShape); } if (with_transpose) { - const auto transpose_shape = make_transpose_shape(shape); + const auto transpose_shape = make_transpose_nvte_shape(shape); 
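// Worked example of make_transpose_nvte_shape (defined earlier in this file in this diff): the
// shape is rotated so the last dimension comes first, matching the flatten-then-transpose
// convention described in the comment above make_transpose_shape. For an illustrative [512, 2, 1024]
// input:
//
//   NVTEShape s{};
//   s.ndim = 3;
//   s.data[0] = 512;
//   s.data[1] = 2;
//   s.data[2] = 1024;
//   const NVTEShape t = make_transpose_nvte_shape(s);  // t.ndim == 3, t.data == {1024, 512, 2}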
out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape); out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, - std::vector{1}); + TensorWrapper::defaultShape); } this->set_quantization_params(&out_cpp); @@ -389,7 +431,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso } std::pair -Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, +Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const NVTEShape& shape, DType dtype, std::optional data) { amax.zero_(); @@ -408,8 +450,9 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ "Float8CurrentScalingQuantizer must output to Float8Tensor."); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); // Extract buffers from Python tensor @@ -428,19 +471,20 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); // Tensor dimensions - std::vector shape; + NVTEShape shape; if (has_transpose) { - const auto transpose_shape = getTensorShape(*transpose_tensor); - if (transpose_shape.size() > 0) { - for (size_t i = 1; i < transpose_shape.size(); ++i) { - shape.push_back(transpose_shape[i]); + const NVTEShape transpose_shape = getTensorShape(*transpose_tensor); + if (transpose_shape.ndim > 0) { + for (size_t i = 1; i < transpose_shape.ndim; ++i) { + shape.data[i - 1] = transpose_shape.data[i]; } - shape.push_back(transpose_shape.front()); + shape.data[transpose_shape.ndim - 1] = transpose_shape.data[0]; + shape.ndim = transpose_shape.ndim; } if (has_data) { - auto expected_shape = getTensorShape(*data_tensor); - NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, - ") and transpose (shape=", transpose_shape, ") do not match"); + const NVTEShape expected_shape = getTensorShape(*data_tensor); + NVTE_CHECK(shapes_equal(shape, expected_shape), "FP8 data (shape=", shape_to_string(shape), + ") and transpose (shape=", shape_to_string(transpose_shape), ") do not match"); } } else { // Already checked has_data == true shape = getTensorShape(*data_tensor); @@ -452,7 +496,7 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ data_py = py::none(); tensor.attr("_data") = data_py; } else if (!has_data && need_data) { - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, shape.data + shape.ndim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); data_tensor = at::empty(shape_int64, opts); data_py = py::cast(data_tensor); @@ -481,13 +525,13 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ if (data_tensor) { out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, - std::vector{1}); + TensorWrapper::defaultShape); } if (transpose_tensor) { - const auto transpose_shape = make_transpose_shape(shape); + const auto transpose_shape = make_transpose_nvte_shape(shape); out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); 
out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, - std::vector{1}); + TensorWrapper::defaultShape); } this->set_quantization_params(&out_cpp); @@ -531,7 +575,7 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te NVTE_SCOPED_GIL_RELEASE({ nvte_compute_scale_from_amax(out.data(), quant_config, stream); }); // Cast to FP8 - out.set_amax(nullptr, DType::kFloat32, out.defaultShape); // Avoid atomic amax updates + out.set_amax(nullptr, DType::kFloat32, TensorWrapper::defaultShape); // Avoid atomic amax updates NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); } @@ -544,7 +588,7 @@ void Float8CurrentScalingQuantizer::quantize_with_amax( TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag) { NVTE_CHECK(input.get_amax().data_ptr == amax.data_ptr(), "Input does not use the appropriate amax tensor"); - input.set_amax(nullptr, DType::kFloat32, input.defaultShape); + input.set_amax(nullptr, DType::kFloat32, TensorWrapper::defaultShape); this->quantize_impl(input, out, noop_flag, false); } @@ -560,12 +604,12 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} -std::pair Float8BlockQuantizer::create_tensor( - const std::vector& shape, DType dtype) const { +std::pair Float8BlockQuantizer::create_tensor(const NVTEShape& shape, + DType dtype) const { using namespace pybind11::literals; std::vector torch_shape; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); + for (size_t i = 0; i < shape.ndim; ++i) { + torch_shape.emplace_back(static_cast(shape.data[i])); } TensorWrapper tensor(this->get_scaling_mode()); @@ -582,29 +626,30 @@ std::pair Float8BlockQuantizer::create_tensor( if (rowwise_usage) { data_rowwise = at::empty(torch_shape, opts); auto scale_shape = get_scale_shape(shape, false); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; + size_t sinv0 = scale_shape.data[0]; + size_t sinv1 = scale_shape.data[1]; scale_inv_rowwise = at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32, - std::vector{sinv0, sinv1}); + make_nvte_2d_shape(sinv0, sinv1)); } if (columnwise_usage) { std::vector torch_columnwise_shape; - std::vector columnwise_shape; - NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ", - columnwise_shape, " torch shape: ", torch_columnwise_shape); + NVTEShape columnwise_shape; + columnwise_shape.ndim = 0; + NVTE_CHECK(torch_shape.size() == shape.ndim, "Shape expected to match torch shape. 
Shape ", + shape_to_string(columnwise_shape), " torch shape: ", torch_columnwise_shape); if (torch_shape.size() > 0) { if (!all_gather_usage) { torch_columnwise_shape.reserve(torch_shape.size()); - columnwise_shape.reserve(shape.size()); torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); - columnwise_shape.push_back(shape[shape.size() - 1]); + columnwise_shape.ndim = shape.ndim; + columnwise_shape.data[0] = shape.data[shape.ndim - 1]; for (size_t i = 0; i < torch_shape.size() - 1; ++i) { torch_columnwise_shape.push_back(torch_shape[i]); - columnwise_shape.push_back(shape[i]); + columnwise_shape.data[i + 1] = shape.data[i]; } } else { // assert we are doing 1D scaling @@ -615,15 +660,15 @@ std::pair Float8BlockQuantizer::create_tensor( } } auto scale_shape = get_scale_shape(shape, true); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; + size_t sinv0 = scale_shape.data[0]; + size_t sinv1 = scale_shape.data[1]; data_colwise = at::empty(torch_columnwise_shape, opts); scale_inv_colwise = at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape); tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32, - std::vector{sinv0, sinv1}); + make_nvte_2d_shape(sinv0, sinv1)); } this->set_quantization_params(&tensor); @@ -675,37 +720,33 @@ std::pair Float8BlockQuantizer::convert_and_update_te opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); - auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector { + auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> NVTEShape { if (!columnwise_data) { - return std::vector(); + return NVTEShape(); } if (all_gather_usage) { return getTensorShape(*columnwise_data); } - std::vector shape = getTensorShape(*columnwise_data); - std::vector shape_transposed(shape.size()); - for (size_t i = 0; i + 1 < shape.size(); ++i) { - shape_transposed[i] = shape[i + 1]; - } - if (shape.size() > 0) { - shape_transposed[shape.size() - 1] = shape[0]; - } + NVTEShape shape = getTensorShape(*columnwise_data); + NVTEShape shape_transposed = make_transpose_nvte_shape(shape); return shape_transposed; }; - std::vector shape; + NVTEShape shape, columnwise_shape; if (rowwise_data) { shape = getTensorShape(*rowwise_data); if (columnwise_data) { auto expected_shape = get_columnwise_shape(all_gather_usage); - NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, - ") and column-wise data (shape=", expected_shape, ") do not match"); + NVTE_CHECK(shapes_equal(shape, expected_shape), + "BlockwiseFP8 row-wise data (shape=", shape_to_string(shape), + ") and column-wise data (shape=", shape_to_string(expected_shape), + ") do not match"); } } else { shape = get_columnwise_shape(all_gather_usage); } std::vector torch_shape; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); + for (size_t i = 0; i < shape.ndim; ++i) { + torch_shape.emplace_back(static_cast(shape.data[i])); } // Coerce row-wise data @@ -716,8 +757,8 @@ std::pair Float8BlockQuantizer::convert_and_update_te } if (!rowwise_scale_inv) { auto scale_shape = get_scale_shape(shape, false); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; + size_t sinv0 = scale_shape.data[0]; + size_t sinv1 = scale_shape.data[1]; rowwise_scale_inv = at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); 
tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; @@ -735,17 +776,16 @@ std::pair Float8BlockQuantizer::convert_and_update_te // Coerce column-wise data if (columnwise_usage) { - std::vector columnwise_shape; std::vector torch_columnwise_shape; if (torch_shape.size() > 0) { if (!all_gather_usage) { torch_columnwise_shape.reserve(torch_shape.size()); - columnwise_shape.reserve(shape.size()); torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); - columnwise_shape.push_back(shape[shape.size() - 1]); + columnwise_shape.ndim = shape.ndim; + columnwise_shape.data[0] = shape.data[shape.ndim - 1]; for (size_t i = 0; i < torch_shape.size() - 1; ++i) { torch_columnwise_shape.push_back(torch_shape[i]); - columnwise_shape.push_back(shape[i]); + columnwise_shape.data[i + 1] = shape.data[i]; } } else { // assert we are doing 1D scaling @@ -761,8 +801,8 @@ std::pair Float8BlockQuantizer::convert_and_update_te } if (!columnwise_scale_inv) { auto scale_shape = get_scale_shape(shape, true); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; + size_t sinv0 = scale_shape.data[0]; + size_t sinv1 = scale_shape.data[1]; columnwise_scale_inv = at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; @@ -784,8 +824,7 @@ std::pair Float8BlockQuantizer::convert_and_update_te const at::Tensor& data_rowwise = tensor.attr("_rowwise_data").cast(); const at::Tensor& scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); void* scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); - const auto& rowwise_shape = getTensorShape(data_rowwise); - ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape); const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); } @@ -793,8 +832,7 @@ std::pair Float8BlockQuantizer::convert_and_update_te const at::Tensor& data_colwise = tensor.attr("_columnwise_data").cast(); const at::Tensor& scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); void* scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); - const auto& shape = getTensorShape(data_colwise); - ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, columnwise_shape); const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); } @@ -824,11 +862,31 @@ void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& o std::vector Float8BlockQuantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { size_t numel = 1; + size_t k_dim; + for (auto s : shape) { numel *= s; } + k_dim = shape.size() == 0 ? 1u : shape.back(); + + return get_scale_shape_impl>(numel, k_dim, columnwise); +} - size_t k_dim = shape.size() == 0 ? 1u : shape.back(); +NVTEShape Float8BlockQuantizer::get_scale_shape(const NVTEShape& shape, bool columnwise) const { + size_t numel = 1; + size_t k_dim; + + for (size_t i = 0; i < shape.ndim; ++i) { + numel *= shape.data[i]; + } + k_dim = shape.ndim == 0 ? 
1u : shape.data[shape.ndim - 1]; + + return get_scale_shape_impl(numel, k_dim, columnwise); +} + +template +ShapeT Float8BlockQuantizer::get_scale_shape_impl(size_t numel, size_t k_dim, + bool columnwise) const { size_t m_dim = numel / k_dim; constexpr size_t kBlockLen = 128; @@ -836,27 +894,20 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector scale_shape; - + size_t sinv0 = 0; + size_t sinv1 = 0; bool rowwise_usage = !columnwise; if (rowwise_usage) { - // rowwise scaling factor shape - size_t sinv0 = 0; - size_t sinv1 = 0; if (block_scaling_dim == 2) { - // 2D scaling is always GEMM_READY for now NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY, "2D scaling is always GEMM_READY for now."); sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); } else if (block_scaling_dim == 1) { - // 1D scaling can be GEMM_READY or COMPACT bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT; - // default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4); - // if the rowwise format is compact, the scaling factor is not be transposed if (rowwise_compact) { std::swap(sinv0, sinv1); } @@ -866,13 +917,8 @@ std::vector Float8BlockQuantizer::get_scale_shape(const std::vector Float8BlockQuantizer::get_scale_shape(const std::vector) { + return make_nvte_2d_shape(sinv0, sinv1); + } else { + return std::vector{sinv0, sinv1}; + } } MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { @@ -902,22 +952,22 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} -std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, +std::pair MXFP8Quantizer::create_tensor(const NVTEShape& shape, DType dtype) const { using namespace pybind11::literals; // Tensor dimensions - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, shape.data + shape.ndim); size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; + if (shape.ndim > 0) { + for (size_t i = 0; i < shape.ndim - 1; ++i) { + flat_first_dim *= shape.data[i]; } } - const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + const size_t flat_last_dim = shape.ndim > 0 ? 
shape.data[shape.ndim - 1] : 1; NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, - " (got shape=", shape, ")"); + " (got shape=", shape_to_string(shape), ")"); const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -926,14 +976,15 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor; const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); if (rowwise_usage) { - const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), - rowwise_scale_inv_shape.end()); + const std::vector scale_inv_shape_int64( + rowwise_scale_inv_shape.data, rowwise_scale_inv_shape.data + rowwise_scale_inv_shape.ndim); rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); } if (columnwise_usage) { - const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), - columnwise_scale_inv_shape.end()); + const std::vector scale_inv_shape_int64( + columnwise_scale_inv_shape.data, + columnwise_scale_inv_shape.data + columnwise_scale_inv_shape.ndim); columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); } @@ -1002,13 +1053,14 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( NVTE_CHECK(rowwise_data || columnwise_data, "MXFP8Tensor has no data."); // Tensor dimensions - std::vector shape; + NVTEShape shape; if (columnwise_data) { shape = getTensorShape(*columnwise_data); if (rowwise_data) { - auto expected_shape = getTensorShape(*rowwise_data); - NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, - ") and column-wise data (shape=", shape, ") do not match"); + const NVTEShape expected_shape = getTensorShape(*rowwise_data); + NVTE_CHECK(shapes_equal(shape, expected_shape), + "MXFP8 row-wise data (shape=", shape_to_string(expected_shape), + ") and column-wise data (shape=", shape_to_string(shape), ") do not match"); } } else { // Already checked columnwise_data_tensor == true shape = getTensorShape(*rowwise_data); @@ -1017,15 +1069,15 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( // Coerce row-wise data if (rowwise_usage) { if (!rowwise_data) { - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, shape.data + shape.ndim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); rowwise_data = at::empty(shape_int64, opts); tensor.attr("_rowwise_data") = *rowwise_data; } if (!rowwise_scale_inv) { const auto scale_inv_shape = get_scale_shape(shape, false); - const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); + const std::vector scale_inv_shape_int64(scale_inv_shape.data, + scale_inv_shape.data + scale_inv_shape.ndim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; @@ -1044,15 +1096,15 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( // Coerce column-wise data if (columnwise_usage) { if (!columnwise_data) { - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, 
shape.data + shape.ndim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); columnwise_data = at::empty(shape_int64, opts); tensor.attr("_columnwise_data") = *columnwise_data; } if (!columnwise_scale_inv) { const auto scale_inv_shape = get_scale_shape(shape, true); - const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); + const std::vector scale_inv_shape_int64(scale_inv_shape.data, + scale_inv_shape.data + scale_inv_shape.ndim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; @@ -1105,32 +1157,50 @@ void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, std::vector MXFP8Quantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { size_t numel = 1; + size_t last_dim; + for (auto s : shape) { numel *= s; } + last_dim = shape.empty() ? 1 : shape.back(); - auto last_dim = shape.back(); + return get_scale_shape_impl>(numel, last_dim, columnwise); +} - NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, - " (got shape=", shape, ")"); +NVTEShape MXFP8Quantizer::get_scale_shape(const NVTEShape& shape, bool columnwise) const { + size_t numel = 1; + size_t last_dim; + + for (size_t i = 0; i < shape.ndim; ++i) { + numel *= shape.data[i]; + } + last_dim = shape.ndim == 0 ? 1 : shape.data[shape.ndim - 1]; + + return get_scale_shape_impl(numel, last_dim, columnwise); +} - std::vector scale_shape; +template +ShapeT MXFP8Quantizer::get_scale_shape_impl(size_t numel, size_t last_dim, bool columnwise) const { + NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, + "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE); + size_t sinv0 = 0; + size_t sinv1 = 0; bool rowwise_usage = !columnwise; if (rowwise_usage) { - // rowwise scaling factor shape - size_t sinv0 = roundup(numel / last_dim, 128); - size_t sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); - scale_shape = {sinv0, sinv1}; + sinv0 = roundup(numel / last_dim, 128); + sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); } else { - // columnwise scaling factor shape - size_t sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); - size_t sinv1 = roundup(last_dim, 128); - scale_shape = {sinv0, sinv1}; + sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); + sinv1 = roundup(last_dim, 128); + } + + if constexpr (std::is_same_v) { + return make_nvte_2d_shape(sinv0, sinv1); + } else { + return std::vector{sinv0, sinv1}; } - return scale_shape; } NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { @@ -1169,24 +1239,24 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { columnwise_data.shape); } -std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, +std::pair NVFP4Quantizer::create_tensor(const NVTEShape& shape, DType dtype) const { using namespace pybind11::literals; // Tensor dimensions - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, shape.data + shape.ndim); size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; + if (shape.ndim > 0) { + for (size_t i = 0; i < shape.ndim - 1; ++i) { + flat_first_dim *= shape.data[i]; } 
} - const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + const size_t flat_last_dim = shape.ndim > 0 ? shape.data[shape.ndim - 1] : 1; NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", - NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); + NVFP4_BLOCK_SIZE, " (got shape=", shape_to_string(shape), ")"); NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, - " (got shape=", shape, ")"); + " (got shape=", shape_to_string(shape), ")"); const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1196,8 +1266,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); if (rowwise_usage) { - const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), - rowwise_scale_inv_shape.end()); + const std::vector scale_inv_shape_int64( + rowwise_scale_inv_shape.data, rowwise_scale_inv_shape.data + rowwise_scale_inv_shape.ndim); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel @@ -1205,13 +1275,13 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve amax_rowwise = at::empty({1}, bit32_tensor_opts); } if (columnwise_usage) { - const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), - columnwise_scale_inv_shape.end()); + const std::vector scale_inv_shape_int64( + columnwise_scale_inv_shape.data, + columnwise_scale_inv_shape.data + columnwise_scale_inv_shape.ndim); // enforce 2D shape to avoid [S, B, H] shape and B and be 1 // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero - std::vector shape_int64_2d = {static_cast(flat_first_dim), - static_cast(flat_last_dim)}; - const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + NVTEShape shape_2d = make_nvte_2d_shape(flat_first_dim, flat_last_dim); + const auto transpose_shape_int64 = make_transpose_shape(shape_2d); columnwise_data_tensor = at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts); columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); @@ -1258,19 +1328,19 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, rowwise_scale_inv_shape); - out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, TensorWrapper::defaultShape); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero - std::vector shape_2d = {flat_first_dim, flat_last_dim}; - auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + NVTEShape shape_2d = make_nvte_2d_shape(flat_first_dim, flat_last_dim); + auto col_data_shape_fp4 = make_transpose_nvte_shape(shape_2d); out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1, col_data_shape_fp4); out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), 
DType::kFloat8E4M3, columnwise_scale_inv_shape); out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, - std::vector{1}); + TensorWrapper::defaultShape); } this->set_quantization_params(&out_cpp); @@ -1280,8 +1350,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve std::pair NVFP4Quantizer::create_unquantized_tensor_with_amax( TensorWrapper& quantized_tensor, DType dtype) { // Construct tensor - auto shape = convertShape(quantized_tensor.shape()); - auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(quantized_tensor.shape(), dtype); // Register amax pointer from quantized tensor void* amax_ptr = quantized_tensor.amax(); @@ -1289,7 +1358,7 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr; } NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor."); - out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape); // Zero out amax NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream())); @@ -1318,38 +1387,39 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( NVTE_CHECK(rowwise_data || columnwise_data, "NVFP4Tensor has no data."); // Tensor dimensions, shape means original shape - std::vector shape; + NVTEShape shape; if (columnwise_data) { shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); if (rowwise_data) { - auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); - NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, - ") and column-wise data (shape=", shape, ") do not match"); + NVTEShape expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + NVTE_CHECK(shapes_equal(shape, expected_shape), + "NVFP4 row-wise data (shape=", shape_to_string(expected_shape), + ") and column-wise data (shape=", shape_to_string(shape), ") do not match"); } } else { // Already checked columnwise_data_tensor == true shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); } size_t flat_first_dim = 1; - if (shape.size() > 0) { - for (size_t i = 0; i < shape.size() - 1; ++i) { - flat_first_dim *= shape[i]; + if (shape.ndim > 0) { + for (size_t i = 0; i < shape.ndim - 1; ++i) { + flat_first_dim *= shape.data[i]; } } - const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + const size_t flat_last_dim = shape.ndim > 0 ? 
shape.data[shape.ndim - 1] : 1; // Coerce row-wise data if (rowwise_usage) { if (!rowwise_data) { - const std::vector shape_int64(shape.begin(), shape.end()); + const std::vector shape_int64(shape.data, shape.data + shape.ndim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); rowwise_data = at::empty(convert_shape_for_fp4(shape_int64), opts); tensor.attr("_rowwise_data") = *rowwise_data; } if (!rowwise_scale_inv) { const auto scale_inv_shape = get_scale_shape(shape, false); - const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); + const std::vector scale_inv_shape_int64(scale_inv_shape.data, + scale_inv_shape.data + scale_inv_shape.ndim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; @@ -1381,17 +1451,16 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( if (!columnwise_data) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero - std::vector shape_int64_2d = {static_cast(flat_first_dim), - static_cast(flat_last_dim)}; + NVTEShape shape_2d = make_nvte_2d_shape(flat_first_dim, flat_last_dim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + const auto transpose_shape_int64 = make_transpose_shape(shape_2d); columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), opts); tensor.attr("_columnwise_data") = *columnwise_data; } if (!columnwise_scale_inv) { const auto scale_inv_shape = get_scale_shape(shape, true); - const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), - scale_inv_shape.end()); + const std::vector scale_inv_shape_int64(scale_inv_shape.data, + scale_inv_shape.data + scale_inv_shape.ndim); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; @@ -1424,19 +1493,19 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, getTensorShape(*rowwise_scale_inv)); - out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, TensorWrapper::defaultShape); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero - std::vector shape_2d = {flat_first_dim, flat_last_dim}; - auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + NVTEShape shape_2d = make_nvte_2d_shape(flat_first_dim, flat_last_dim); + auto col_data_shape_fp4 = make_transpose_nvte_shape(shape_2d); out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1, col_data_shape_fp4); out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, getTensorShape(*columnwise_scale_inv)); out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32, - std::vector{1}); + TensorWrapper::defaultShape); } this->set_quantization_params(&out_cpp); @@ -1531,10 +1600,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou 
NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); // Compute amax of input tensor - out.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + out.set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape); NVTE_SCOPED_GIL_RELEASE( { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); - out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector{1}); + out.set_amax(rowwise_amax_ptr, DType::kFloat32, TensorWrapper::defaultShape); // Make sure row-wise and column-wise amaxes match if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) { @@ -1604,20 +1673,17 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail // need to convert the shape to 2D here auto colwise_data_shape = out_columnwise_data.shape; - std::vector colwise_data_shape_2d; // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again // so the multiple 2 get cancelled out - colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); size_t last_dim = 1; for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { last_dim *= colwise_data_shape.data[i]; } - colwise_data_shape_2d.push_back(last_dim); out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, static_cast(out_columnwise_data.dtype), - colwise_data_shape_2d); + make_nvte_2d_shape(colwise_data_shape.data[0], last_dim)); out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, static_cast(out_columnwise_scale_inv.dtype), out_columnwise_scale_inv.shape); @@ -1642,7 +1708,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // NOTE (frsun): This is non-intuitive, we are writing the // result of transposed RHT to the output of rowwise. rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), - std::vector{cols, rows}); + make_nvte_2d_shape(cols, rows)); NVTE_SCOPED_GIL_RELEASE({ // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. @@ -1696,7 +1762,7 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out NVTE_CHECK_CUDA(cudaMemcpyAsync(output_columnwise_amax_ptr, input_amax_ptr, sizeof(float), cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); } - input.set_amax(nullptr, DType::kFloat32, input.defaultShape); + input.set_amax(nullptr, DType::kFloat32, TensorWrapper::defaultShape); // Perform quantization this->quantize_impl(input, out, std::nullopt, false); @@ -1705,35 +1771,54 @@ void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { size_t numel = 1; + size_t last_dim; + for (auto s : shape) { numel *= s; } + last_dim = shape.empty() ? 1 : shape.back(); + + return get_scale_shape_impl>(numel, last_dim, columnwise); +} + +NVTEShape NVFP4Quantizer::get_scale_shape(const NVTEShape& shape, bool columnwise) const { + size_t numel = 1; + size_t last_dim; + + for (size_t i = 0; i < shape.ndim; ++i) { + numel *= shape.data[i]; + } + last_dim = shape.ndim == 0 ? 
@@ -1705,35 +1771,54 @@ std::vector<size_t> NVFP4Quantizer::get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const {
   size_t numel = 1;
+  size_t last_dim;
+
   for (auto s : shape) {
     numel *= s;
   }
+  last_dim = shape.empty() ? 1 : shape.back();
+
+  return get_scale_shape_impl<std::vector<size_t>>(numel, last_dim, columnwise);
+}
+
+NVTEShape NVFP4Quantizer::get_scale_shape(const NVTEShape& shape, bool columnwise) const {
+  size_t numel = 1;
+  size_t last_dim;
+
+  for (size_t i = 0; i < shape.ndim; ++i) {
+    numel *= shape.data[i];
+  }
+  last_dim = shape.ndim == 0 ? 1 : shape.data[shape.ndim - 1];
-  auto last_dim = shape.back();
+  return get_scale_shape_impl<NVTEShape>(numel, last_dim, columnwise);
+}
+
+template <typename ShapeT>
+ShapeT NVFP4Quantizer::get_scale_shape_impl(size_t numel, size_t last_dim, bool columnwise) const {
   auto flat_first_dim = numel / last_dim;
   NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ",
              NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")");
   NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0,
-             "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE,
-             " (got shape=", shape, ")");
-
-  std::vector<size_t> scale_shape;
+             "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE);
+  size_t sinv0 = 0;
+  size_t sinv1 = 0;
   bool rowwise_usage = !columnwise;
   if (rowwise_usage) {
-    // rowwise scaling factor shape
-    size_t sinv0 = roundup(flat_first_dim, 128);
-    size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4);
-    scale_shape = {sinv0, sinv1};
+    sinv0 = roundup(flat_first_dim, 128);
+    sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4);
+  } else {
+    sinv0 = roundup(last_dim, 128);
+    sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4);
+  }
+
+  if constexpr (std::is_same_v<ShapeT, NVTEShape>) {
+    return make_nvte_2d_shape(sinv0, sinv1);
   } else {
-    // columnwise scaling factor shape
-    size_t sinv0 = roundup(last_dim, 128);
-    size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4);
-    scale_shape = {sinv0, sinv1};
+    return std::vector<size_t>{sinv0, sinv1};
   }
-  return scale_shape;
 }

 }  // namespace transformer_engine::pytorch
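`get_scale_shape_impl` above derives the NVFP4 scale-inverse shape by rounding the flattened first dimension up to a multiple of 128 and the blocked last dimension up to a multiple of 4 (with the two roles swapped for column-wise usage). A standalone sketch of that arithmetic with a worked example, assuming `NVFP4_BLOCK_SIZE` is 16; the helper names here are illustrative, not from the patch:

```cpp
// Hypothetical sketch, not taken from the patch: reproduce the scale-inverse shape
// arithmetic of get_scale_shape_impl for a 256 x 1024 tensor.
#include <cstddef>
#include <cstdio>
#include <utility>

namespace {

constexpr size_t kNvfp4BlockSize = 16;  // assumption; the patch uses NVFP4_BLOCK_SIZE

constexpr size_t roundup(size_t value, size_t multiple) {
  return ((value + multiple - 1) / multiple) * multiple;
}

std::pair<size_t, size_t> nvfp4_scale_inv_shape(size_t flat_first_dim, size_t last_dim,
                                                bool columnwise) {
  if (!columnwise) {
    return {roundup(flat_first_dim, 128), roundup(last_dim / kNvfp4BlockSize, 4)};
  }
  return {roundup(last_dim, 128), roundup(flat_first_dim / kNvfp4BlockSize, 4)};
}

}  // namespace

int main() {
  const auto rowwise = nvfp4_scale_inv_shape(256, 1024, /*columnwise=*/false);
  const auto colwise = nvfp4_scale_inv_shape(256, 1024, /*columnwise=*/true);
  std::printf("row-wise scale_inv:    %zu x %zu\n", rowwise.first, rowwise.second);   // 256 x 64
  std::printf("column-wise scale_inv: %zu x %zu\n", colwise.first, colwise.second);   // 1024 x 16
  return 0;
}
```

Under those assumptions, a 256 x 1024 input gets a 256 x 64 row-wise and a 1024 x 16 column-wise scale-inverse buffer, matching the padding the hunk enforces via `roundup`.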
diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp
index 96fd2ccb3..38ce5a12d 100644
--- a/transformer_engine/pytorch/csrc/util.cpp
+++ b/transformer_engine/pytorch/csrc/util.cpp
@@ -15,7 +15,7 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
   if (input.scaling_mode() == NVTE_INVALID_SCALING) {
     NVTE_ERROR("Invalid scaling mode for swizzle.");
-  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
+  } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING ||
              input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
     return std::nullopt;
   }
@@ -26,26 +26,22 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
   const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING;
   NVTEBasicTensor scale_inv;
-  NVTEShape nvte_input_shape;
+  NVTEShape input_shape;
   if (rowwise) {
-    nvte_input_shape = input.shape();
+    input_shape = input.shape();
     scale_inv = input.get_rowwise_scale_inv();
   } else {
-    nvte_input_shape = input.get_columnwise_data().shape;
+    input_shape = input.get_columnwise_data().shape;
     scale_inv = input.get_columnwise_scale_inv();
   }
-  auto input_shape = nvte_shape_to_vector(nvte_input_shape);
-  auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
-
-  NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape.");
+  auto& scale_inv_shape = scale_inv.shape;
+  NVTE_CHECK(input_shape.ndim >= 2, "Wrong ndims for swizzle input shape.");

   // Allocate memory for swizzled output.
   auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
-  std::vector<int64_t> scale_inv_shape_int;
-  for (size_t i = 0; i < scale_inv_shape.size(); ++i) {
-    scale_inv_shape_int.push_back(static_cast<int64_t>(scale_inv_shape[i]));
-  }
+  std::vector<int64_t> scale_inv_shape_int(scale_inv_shape.data,
+                                           scale_inv_shape.data + scale_inv_shape.ndim);
   auto swizzled_scale_inv = at::empty(scale_inv_shape_int, options);
   void* scale_inv_dptr = scale_inv.data_ptr;
   void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index f3220d586..067939677 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -1343,7 +1343,7 @@ def reset_parameters(self, defer_init=False):
         elif self.parallel_mode == "column":
             set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)

-    @no_torch_dynamo()
+    @no_torch_dynamo(recursive=False)
     def forward(
         self,
         inp: torch.Tensor,