From 80ab846a2acf3333bd97c9875a0ba7ff32895852 Mon Sep 17 00:00:00 2001 From: cx Date: Wed, 27 May 2026 03:32:00 +0000 Subject: [PATCH] refactor(autocast): move autocast from dispatcher to autograd boundary Apply autocast in Function::Apply before Forward/SetupContext so saved tensors are stored in the actual forward compute dtype, instead of guessing it from output->Dtype() inside SetupContext. Drops the duplicate cast workaround in Matmul/Linear::SetupContext. Add AutocastByName entry on AutocastContext (keyed by op name string) and extend GetBaseOpName to strip the "Function" suffix used by autograd::Function::type_. Remove the now-redundant autocast hook from Dispatcher::Call; backward kernels and internal helpers are no longer accidentally re-cast. Add direct common.h includes to elementwise/gather kernels that previously relied on the dispatcher.h -> autocast.h -> common.h transitive include chain. --- infini_train/include/autocast.h | 35 ++++---------------- infini_train/include/dispatcher.h | 2 -- infini_train/src/autograd/function.cc | 12 +++++-- infini_train/src/autograd/linear.cc | 16 +-------- infini_train/src/autograd/matmul.cc | 17 +--------- infini_train/src/kernels/cpu/elementwise.cc | 1 + infini_train/src/kernels/cuda/elementwise.cu | 1 + infini_train/src/kernels/cuda/gather.cu | 1 + 8 files changed, 22 insertions(+), 63 deletions(-) diff --git a/infini_train/include/autocast.h b/infini_train/include/autocast.h index 499c586f..cbb97c3f 100644 --- a/infini_train/include/autocast.h +++ b/infini_train/include/autocast.h @@ -11,25 +11,10 @@ namespace infini_train { namespace { inline std::string_view GetBaseOpName(std::string_view op) { - constexpr std::string_view forward_suffix = "Forward"; - constexpr std::string_view backward_suffix = "Backward"; - - // Check for "Forward" suffix - if (op.size() >= forward_suffix.size()) { - const auto suffix_pos = op.size() - forward_suffix.size(); - if (op.substr(suffix_pos) == forward_suffix) { - return op.substr(0, suffix_pos); - } - } - - // Check for "Backward" suffix - if (op.size() >= backward_suffix.size()) { - const auto suffix_pos = op.size() - backward_suffix.size(); - if (op.substr(suffix_pos) == backward_suffix) { - return op.substr(0, suffix_pos); - } + constexpr std::string_view function_suffix = "Function"; + if (op.size() >= function_suffix.size() && op.substr(op.size() - function_suffix.size()) == function_suffix) { + return op.substr(0, op.size() - function_suffix.size()); } - return op; } }; // namespace @@ -97,18 +82,14 @@ struct AutocastContext { Device::DeviceType device_type = Device::DeviceType::kCPU; // Target device type (CPU/GPU) DataType autocast_dtype = DataType::kBFLOAT16; // The data type used for autocasting - template void Autocast(std::pair key, ArgsT &...args) { + // Cast a parameter pack of tensors (or shared_ptr) according to the cast policy + // associated with `op_name`. Called from autograd::Function::Apply with type_ as op_name. + template void Autocast(std::string_view op_name, ArgsT &...args) { if (!enabled) { return; } - if (device_type != key.first) { - LOG_LOC(FATAL, "In AutocastContext::Autocast(): the AutocastContext device_type is different from the one " - "passed in. Don't know what to do."); - return; - } - - auto map_it = kOpCastPolicyMap.find(GetBaseOpName(key.second)); + auto map_it = kOpCastPolicyMap.find(GetBaseOpName(op_name)); if (map_it == kOpCastPolicyMap.end()) { return; } @@ -132,7 +113,6 @@ struct AutocastContext { } }; - // Process each argument auto cast_arg = [&](auto &arg) { using T = std::decay_t; if constexpr (std::is_same_v>) { @@ -156,7 +136,6 @@ struct AutocastContext { } }; - // Apply casting to each argument (cast_arg(args), ...); } }; diff --git a/infini_train/include/dispatcher.h b/infini_train/include/dispatcher.h index 638df76a..2c390ec8 100644 --- a/infini_train/include/dispatcher.h +++ b/infini_train/include/dispatcher.h @@ -6,7 +6,6 @@ #include "glog/logging.h" -#include "infini_train/include/autocast.h" #include "infini_train/include/device.h" #ifdef PROFILE_MODE #include "infini_train/include/profiler.h" @@ -73,7 +72,6 @@ class Dispatcher { template RetT Call(KeyT key, ArgsT... args) const { auto kernel = this->GetKernel(key); - tls_autocast_context.Autocast(key, args...); #ifdef PROFILE_MODE SetProfileContext(key.second, key.first); #endif diff --git a/infini_train/src/autograd/function.cc b/infini_train/src/autograd/function.cc index f7eb35c7..a09d2004 100644 --- a/infini_train/src/autograd/function.cc +++ b/infini_train/src/autograd/function.cc @@ -2,6 +2,7 @@ #include "glog/logging.h" +#include "infini_train/include/autocast.h" #include "infini_train/include/autograd/accumulate.h" #include "infini_train/include/autograd/function_hook.h" #include "infini_train/include/autograd/grad_mode.h" @@ -46,12 +47,19 @@ std::vector> Function::Apply(const std::vector AccumulateGrad / non-leaf -> grad_fn). + auto compute_inputs = input_tensors; + for (auto &t : compute_inputs) { tls_autocast_context.Autocast(type_, t); } + std::vector> output_tensors; { autograd::NoGradGuard no_grad; // no_grad in autograd.Function.Forward() - output_tensors = Forward(input_tensors); - SetupContext(input_tensors, output_tensors); + output_tensors = Forward(compute_inputs); + SetupContext(compute_inputs, output_tensors); } // Call forward post-hooks diff --git a/infini_train/src/autograd/linear.cc b/infini_train/src/autograd/linear.cc index ff0283ce..76602b03 100644 --- a/infini_train/src/autograd/linear.cc +++ b/infini_train/src/autograd/linear.cc @@ -20,26 +20,12 @@ void Linear::SetupContext(const std::vector> &input_tens const std::vector> &output_tensors) { const auto &input = input_tensors[0]; const auto &weight = input_tensors[1]; - // Cast saved tensors to forward compute dtype (output dtype) so backward - // computes in the same precision as forward, matching PyTorch's behavior. - // FIXME: An extra cast (input/weight -> compute_dtype) is performed here because - // autocast runs before autograd. The correct approach is to adjust the ordering or - // integration of autocast and autograd so that autograd receives already-cast tensors, - // avoiding the redundant cast. - - // FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be - // determined by autocast, not derived from output_tensors[0]->Dtype(). - auto compute_dtype = output_tensors[0]->Dtype(); bool need_input = needs_input_grad_.size() > 0 && needs_input_grad_[0]; bool need_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1]; - auto cast = [&](const std::shared_ptr &t) { - return t->Dtype() == compute_dtype ? t : std::make_shared(t->To(compute_dtype)); - }; - // grad_input needs weight, grad_weight needs input - saved_tensors_ = {need_weight ? cast(input) : nullptr, need_input ? cast(weight) : nullptr}; + saved_tensors_ = {need_weight ? input : nullptr, need_input ? weight : nullptr}; transpose_ = true; bias_ = input_tensors.size() == 3; diff --git a/infini_train/src/autograd/matmul.cc b/infini_train/src/autograd/matmul.cc index 8ddfc578..1f24dc21 100644 --- a/infini_train/src/autograd/matmul.cc +++ b/infini_train/src/autograd/matmul.cc @@ -20,28 +20,13 @@ void Matmul::SetupContext(const std::vector> &input_tens const auto &input1 = input_tensors[0]; const auto &input2 = input_tensors[1]; const auto &output = output_tensors[0]; - // Cast saved tensors to forward compute dtype (output dtype) so backward - // computes in the same precision as forward, matching PyTorch's behavior. - - // FIXME: An extra cast (input1/input2 -> compute_dtype) is performed here because - // autocast runs before autograd. The correct approach is to adjust the ordering or - // integration of autocast and autograd so that autograd receives already-cast tensors, - // avoiding the redundant cast. - - // FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be - // determined by autocast, not derived from output->Dtype(). - auto compute_dtype = output->Dtype(); // grad_input1 = grad_output @ input2^T, so input2 is needed // grad_input2 = grad_output^T @ input1, so input1 is needed bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0]; bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1]; - auto cast = [&](const std::shared_ptr &t) { - return t->Dtype() == compute_dtype ? t : std::make_shared(t->To(compute_dtype)); - }; - - saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? cast(input2) : nullptr}; + saved_tensors_ = {need_grad_input2 ? input1 : nullptr, need_grad_input1 ? input2 : nullptr}; input1_dims_ = input1->Dims(); input2_dims_ = input2->Dims(); out_features_ = output->Dims()[0]; diff --git a/infini_train/src/kernels/cpu/elementwise.cc b/infini_train/src/kernels/cpu/elementwise.cc index 71058b51..213abdcf 100644 --- a/infini_train/src/kernels/cpu/elementwise.cc +++ b/infini_train/src/kernels/cpu/elementwise.cc @@ -6,6 +6,7 @@ #include "glog/logging.h" +#include "infini_train/include/common/common.h" #include "infini_train/include/device.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu index fe63e0b2..c332799e 100644 --- a/infini_train/src/kernels/cuda/elementwise.cu +++ b/infini_train/src/kernels/cuda/elementwise.cu @@ -2,6 +2,7 @@ #include +#include "infini_train/include/common/common.h" #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" #include "infini_train/include/core/runtime/device_guard.h" diff --git a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu index 12d0567d..1c553842 100644 --- a/infini_train/src/kernels/cuda/gather.cu +++ b/infini_train/src/kernels/cuda/gather.cu @@ -1,5 +1,6 @@ #include "glog/logging.h" +#include "infini_train/include/common/common.h" #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dispatcher.h"