Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions example/common/utils.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "example/common/utils.h"

#include "gflags/gflags.h"
#include "glog/logging.h"

namespace infini_train {

float ConvertBF16ToFloat(void *ptr) {
Expand Down Expand Up @@ -61,4 +64,11 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s
ifs.seekg(base + std::streamoff(len * sizeof(float)));
}

void ValidateDistributedOptimizerFlags(bool use_distributed_optimizer) {
gflags::CommandLineFlagInfo zero_stage_info;
CHECK(gflags::GetCommandLineFlagInfo("zero_stage", &zero_stage_info));
CHECK(use_distributed_optimizer || zero_stage_info.is_default)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里还有必要保留 use_distributed_optimizer 参数吗?感觉是不是用统一的一个参数来区分 zero1/zero2/zero3 比较合适,想体现切分语义的话可以用这个:
https://github.com/NVIDIA/Megatron-LM/blob/0010683033b67cb091719ccb1d8195a524b91356/docs/user-guide/parallelism-guide.md?plain=1#L53

--data-parallel-sharding-strategy [no_shard | optim | optim_grads | optim_grads_params]

no_shard            ≈ DDP

optim               ≈ ZeRO-1,只切 optimizer states

optim_grads         ≈ ZeRO-2,切 optimizer states + gradients

optim_grads_params  ≈ ZeRO-3,切 optimizer states + gradients + parameters

或者简单一点就保留 zero_stage 参数(默认值 0)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个我想的是,由于现在的 zero-1/2 是基于 DistOpt 的基建展开做的,而 Megatron 的三级 zero 是基于 FSDP 的基建做的(也就是你引用的这部分),后续咱们如果支持 fsdp 的话,可能就存在两种 zero 的实现路径。你也可以看到你引用的这个文档里,使用上也得先写一行 --use-megatron-fsdp 的 flag,后面再带着这个切分方法的参数。

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我感觉 megatron 这里本身可能存在一些历史遗留问题,我们不一定需要完全照搬;另外之后如果真的需要支持 FSDP,也不应该长期保留两套 zero 实现路径,更合理的方向可能还是统一抽象成一套 zero 语义或者只保留一种语义,否则长期维护成本太高。

<< "--zero_stage requires --use_distributed_optimizer=true.";
}

} // namespace infini_train
2 changes: 2 additions & 0 deletions example/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len);

void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt);

// Validates the command-line flags related to the distributed optimizer.
// CHECK-fails if the "zero_stage" gflag cannot be looked up, or if --zero_stage is not at its
// default (per gflags' is_default) while use_distributed_optimizer is false, i.e.
// --zero_stage requires --use_distributed_optimizer=true.
void ValidateDistributedOptimizerFlags(bool use_distributed_optimizer);

} // namespace infini_train
11 changes: 8 additions & 3 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
#include "example/common/utils.h"
#include "example/gpt2/checkpoint_loader.h"
#include "example/gpt2/config.h"

Expand All @@ -58,6 +59,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -114,6 +116,7 @@ const std::unordered_map<std::string, nn::TransformerConfig> kModelToConfigs = {
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
// Only accept --zero_stage values in the closed range [1, 3].
DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -252,8 +255,8 @@ void Train(const nn::parallel::Rank &rank) {
model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, device, model_config.GetChunkSize());
if (ddp_world_size > 1) {
auto ddp_config
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{
.use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage};
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
Expand All @@ -265,7 +268,8 @@ void Train(const nn::parallel::Rank &rank) {
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer,
.zero_stage = FLAGS_zero_stage};
model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
}

Expand Down Expand Up @@ -447,6 +451,7 @@ void Train(const nn::parallel::Rank &rank) {
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
ValidateDistributedOptimizerFlags(FLAGS_use_distributed_optimizer);

auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check);
nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
Expand Down
11 changes: 8 additions & 3 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
#include "example/common/utils.h"
#include "example/llama3/checkpoint_loader.h"
#include "example/llama3/config.h"

Expand All @@ -57,6 +58,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)");
Comment thread
chen2021673 marked this conversation as resolved.
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -100,6 +102,7 @@ constexpr char kDtypeBF16[] = "bfloat16";
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
// Only accept --zero_stage values in the closed range [1, 3].
DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -222,8 +225,8 @@ void Train(const nn::parallel::Rank &rank) {
model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, device, model_config.GetChunkSize());
if (ddp_world_size > 1) {
auto ddp_config
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{
.use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage};
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
Expand All @@ -236,7 +239,8 @@ void Train(const nn::parallel::Rank &rank) {
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.

auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer,
.zero_stage = FLAGS_zero_stage};
model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
}

Expand Down Expand Up @@ -422,6 +426,7 @@ void Train(const nn::parallel::Rank &rank) {
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
ValidateDistributedOptimizerFlags(FLAGS_use_distributed_optimizer);

auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check);
nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
Expand Down
6 changes: 6 additions & 0 deletions infini_train/include/autograd/function_hook.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ namespace infini_train::autograd {
class PostAccumulateGradHook {
public:
virtual void operator()(const std::shared_ptr<Tensor> &tensor) = 0;

// ZeRO-2: Use this function to take over AccumulateGrad::Backward
virtual bool TryBypassAccumulate(const std::shared_ptr<Tensor> &, const std::shared_ptr<Tensor> &, bool, float) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PostAccumulateGradHook 语义是在 AccumulateGrad 完成梯度累积后执行 hook,不太合适承担 bypass accumulate 的职责。对于 zero-2 这种需要接管梯度累积路径的场景,更合适的方式是引入一个 PreAccumulateGradHook,在执行梯度累计前进行拦截与处理。

return false;
}

virtual ~PostAccumulateGradHook() = default;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ class DistributedDataParallelConfig {
// In this case, grad reduce is triggered immediately when a grad is ready or till all grads are ready.
bool overlap_grad_reduce = true;

// ZeRO-DP stage for memory optimization (only takes effect when use_distributed_optimizer=true)
// ZeRO-1: Optimizer states partitioning, by default
// ZeRO-2: Gradients partitioning
// ZeRO-3: Parameters partitioning
int zero_stage = 1;

// Whether to overlap parameter all-gather with forward compute.
bool overlap_param_gather = true;

Expand All @@ -59,7 +65,7 @@ class DistributedDataParallelConfig {
// Maximum number of parameters in each ParamAndGradBucket.
// NOTE(zbl): This is distinct from DDP Reducer's MB-based bucket caps.
// TODO(zbl): To unify the definition of bucket_size argument for users
size_t bucket_size_in_elements = 40000000;
size_t bucket_size_in_elements = 1000000;

// Whether to pad bucket sizes to improve NCCL bus bandwidth utilization.
bool pad_buckets_for_high_nccl_busbw = false;
Expand Down
25 changes: 23 additions & 2 deletions infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ namespace infini_train::nn::parallel {
class ParamAndGradBucket {
public:
ParamAndGradBucket(const std::vector<std::shared_ptr<Tensor>> &params, const std::shared_ptr<Tensor> &param_data,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议加个注释描述一下每个参数的含义

const std::shared_ptr<Tensor> &grad_data, size_t offset, size_t num_elements_unpadded,
float gradient_scaling_factor, size_t bucket_id);
DataType param_dtype, const std::shared_ptr<Tensor> &grad_data, DataType grad_dtype,
size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id);

size_t bucket_id() const { return bucket_id_; }

Expand All @@ -33,6 +33,10 @@ class ParamAndGradBucket {

const std::shared_ptr<Tensor> &grad_data() const { return grad_data_; }

DataType param_dtype() const { return param_dtype_; }

DataType grad_dtype() const { return grad_dtype_; }

size_t offset() const { return offset_; }

size_t num_elements_unpadded() const { return num_elements_unpadded_; }
Expand All @@ -49,6 +53,8 @@ class ParamAndGradBucket {
std::vector<std::shared_ptr<Tensor>> params_;
std::shared_ptr<Tensor> param_data_;
std::shared_ptr<Tensor> grad_data_;
DataType param_dtype_;
DataType grad_dtype_;

size_t offset_ = 0;
size_t num_elements_unpadded_ = 0;
Expand All @@ -73,6 +79,11 @@ class ParamAndGradBucketGroup {
// Start grad reduce
void StartGradSync();

// Accumulate a parameter's grad into the bucket buffer.
// ZeRO-2: Use this function to take over autograd::AccumulateGrad::Backward
void AccumulateParamGrad(const std::shared_ptr<Tensor> &parameter, const std::shared_ptr<Tensor> &grad,
bool overwrite, float learning_rate);

// Wait for gradient reduce to complete
void FinishGradSync();

Expand All @@ -87,6 +98,9 @@ class ParamAndGradBucketGroup {

const std::vector<std::shared_ptr<ParamAndGradBucket>> &buckets() const { return buckets_; }

// ZeRO-2: Get a bucket's local grad shard buffer
std::shared_ptr<Tensor> GetLocalGradShardBuffer(size_t bucket_idx) const;

const DistributedDataParallelConfig &config() const { return ddp_config_; }

private:
Expand All @@ -98,12 +112,19 @@ class ParamAndGradBucketGroup {

std::unordered_set<Tensor *> params_;
std::unordered_set<Tensor *> params_with_grad_;
// Tensor -> (Bucket, Bucket Index)
std::unordered_map<Tensor *, std::pair<std::shared_ptr<ParamAndGradBucket>, size_t>> param_to_bucket_;

// TODO(zbl): Implement CoalescedWork for aggregate works
// According to Megatron-LM's _coalescing_manager
std::vector<std::shared_ptr<Work>> grad_reduce_work_list_;
std::vector<size_t> grad_reduce_bucket_indices_;
std::vector<std::shared_ptr<Work>> param_gather_work_list_;

// ZeRO-2: persistent grad shard buffers and temporary full grad buffers
std::vector<std::shared_ptr<Tensor>> grad_shard_buffer_list_;
std::vector<std::shared_ptr<Tensor>> temp_full_grad_buffer_list_;

std::shared_ptr<ParamAndGradBucketGroup> next_param_gather_bucket_group_ = nullptr;

std::vector<std::vector<std::shared_ptr<Tensor>>> param_buffer_shard_list_;
Expand Down
10 changes: 8 additions & 2 deletions infini_train/src/autograd/accumulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,15 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
"running before autograd). The grad is not cast and will be used as-is.";
}

const bool overwrite = tensor_->ConsumeGradOverwriteFlag();
auto hook = tensor_->post_accumulate_grad_hook();
if (hook && hook->TryBypassAccumulate(tensor_, grad_output, overwrite, learning_rate_)) {
tensor_->ResetAccumulator();
return {};
}

if (grad) {
if (tensor_->ConsumeGradOverwriteFlag()) {
if (overwrite) {
// If the tensor is marked to overwrite its current grad on next grad update
// See notes in `infini_train::nn::parallel::Reducer::PrepareForBackward()`
// NOTE(zbl): must copy, cannot change grad buffer address
Expand All @@ -48,7 +55,6 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
auto new_grad = std::make_shared<Tensor>(*grad_output.get(), 0, grad_output->Dims());
tensor_->set_grad(new_grad);
}
auto hook = tensor_->post_accumulate_grad_hook();
if (hook != nullptr) {
(*hook)(tensor_->grad());
}
Expand Down
Loading
Loading