-
Notifications
You must be signed in to change notification settings - Fork 45
feat: Support ZeRO-2 based on DistributedOptimizer #110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
23e7f8c
47b3941
e5b4492
e9610d9
477a198
5f573f2
2e3d7fe
c94de4a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,12 @@ namespace infini_train::autograd { | |
| class PostAccumulateGradHook { | ||
| public: | ||
| virtual void operator()(const std::shared_ptr<Tensor> &tensor) = 0; | ||
|
|
||
| // ZeRO-2: Use this function to take over AccumulateGrad::Backward | ||
| virtual bool TryBypassAccumulate(const std::shared_ptr<Tensor> &, const std::shared_ptr<Tensor> &, bool, float) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. PostAccumulateGradHook 语义是在 AccumulateGrad 完成梯度累积后执行 hook,不太适合承担 bypass accumulate 的职责。对于 zero-2 这种需要接管梯度累积路径的场景,更合适的方式是引入一个 PreAccumulateGradHook,在执行梯度累积前进行拦截与处理。 |
||
| return false; | ||
| } | ||
|
|
||
| virtual ~PostAccumulateGradHook() = default; | ||
| }; | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,8 +22,8 @@ namespace infini_train::nn::parallel { | |
| class ParamAndGradBucket { | ||
| public: | ||
| ParamAndGradBucket(const std::vector<std::shared_ptr<Tensor>> &params, const std::shared_ptr<Tensor> &param_data, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 建议加个注释描述一下每个参数的含义 |
||
| const std::shared_ptr<Tensor> &grad_data, size_t offset, size_t num_elements_unpadded, | ||
| float gradient_scaling_factor, size_t bucket_id); | ||
| DataType param_dtype, const std::shared_ptr<Tensor> &grad_data, DataType grad_dtype, | ||
| size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id); | ||
|
|
||
| size_t bucket_id() const { return bucket_id_; } | ||
|
|
||
|
|
@@ -33,6 +33,10 @@ class ParamAndGradBucket { | |
|
|
||
| const std::shared_ptr<Tensor> &grad_data() const { return grad_data_; } | ||
|
|
||
| DataType param_dtype() const { return param_dtype_; } | ||
|
|
||
| DataType grad_dtype() const { return grad_dtype_; } | ||
|
|
||
| size_t offset() const { return offset_; } | ||
|
|
||
| size_t num_elements_unpadded() const { return num_elements_unpadded_; } | ||
|
|
@@ -49,6 +53,8 @@ class ParamAndGradBucket { | |
| std::vector<std::shared_ptr<Tensor>> params_; | ||
| std::shared_ptr<Tensor> param_data_; | ||
| std::shared_ptr<Tensor> grad_data_; | ||
| DataType param_dtype_; | ||
| DataType grad_dtype_; | ||
|
|
||
| size_t offset_ = 0; | ||
| size_t num_elements_unpadded_ = 0; | ||
|
|
@@ -73,6 +79,11 @@ class ParamAndGradBucketGroup { | |
| // Start grad reduce | ||
| void StartGradSync(); | ||
|
|
||
| // Accumulate a parameter grad into bucket buffer | ||
| // ZeRO-2: Use this function to take over autograd::AccumulateGrad::Backward | ||
| void AccumulateParamGrad(const std::shared_ptr<Tensor> &parameter, const std::shared_ptr<Tensor> &grad, | ||
| bool overwrite, float learning_rate); | ||
|
|
||
| // Wait for gradient reduce to complete | ||
| void FinishGradSync(); | ||
|
|
||
|
|
@@ -87,6 +98,9 @@ class ParamAndGradBucketGroup { | |
|
|
||
| const std::vector<std::shared_ptr<ParamAndGradBucket>> &buckets() const { return buckets_; } | ||
|
|
||
| // ZeRO-2: Get a bucket's local grad shard buffer | ||
| std::shared_ptr<Tensor> GetLocalGradShardBuffer(size_t bucket_idx) const; | ||
|
|
||
| const DistributedDataParallelConfig &config() const { return ddp_config_; } | ||
|
|
||
| private: | ||
|
|
@@ -98,12 +112,19 @@ class ParamAndGradBucketGroup { | |
|
|
||
| std::unordered_set<Tensor *> params_; | ||
| std::unordered_set<Tensor *> params_with_grad_; | ||
| // Tensor -> (Bucket, Bucket Index) | ||
| std::unordered_map<Tensor *, std::pair<std::shared_ptr<ParamAndGradBucket>, size_t>> param_to_bucket_; | ||
|
|
||
| // TODO(zbl): Implement CoalescedWork for aggregate works | ||
| // According to Megatron-LM's _coalescing_manager | ||
| std::vector<std::shared_ptr<Work>> grad_reduce_work_list_; | ||
| std::vector<size_t> grad_reduce_bucket_indices_; | ||
| std::vector<std::shared_ptr<Work>> param_gather_work_list_; | ||
|
|
||
| // ZeRO-2: persistent grad shard buffers and temporary full grad buffers | ||
| std::vector<std::shared_ptr<Tensor>> grad_shard_buffer_list_; | ||
| std::vector<std::shared_ptr<Tensor>> temp_full_grad_buffer_list_; | ||
|
|
||
| std::shared_ptr<ParamAndGradBucketGroup> next_param_gather_bucket_group_ = nullptr; | ||
|
|
||
| std::vector<std::vector<std::shared_ptr<Tensor>>> param_buffer_shard_list_; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里还有必要保留 use_distributed_optimizer 参数吗?感觉是不是用统一的一个参数来区分 zero1/zero2/zero3 比较合适,想体现切分语义的话可以用这个:
https://github.com/NVIDIA/Megatron-LM/blob/0010683033b67cb091719ccb1d8195a524b91356/docs/user-guide/parallelism-guide.md?plain=1#L53
或者简单一点就保留 zero_stage 参数(默认值 0)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个我想的是,由于现在的 zero-1/2 是基于 DistOpt 的基建展开做的,而 Megatron 的三级 zero 是基于 FSDP 的基建做的(也就是你引用的这部分),后续咱们如果支持 fsdp 的话,可能就存在两种 zero 的实现路径。你也可以看到你引用的这个文档里,使用上也得先写一行
`--use-megatron-fsdp` 的 flag,后面再带着这个切分方法的参数。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我感觉 megatron 这里本身可能存在一些历史遗留问题,我们不一定需要完全照搬;另外之后如果真的需要支持 FSDP,也不应该长期保留两套 zero 实现路径,更合理的方向可能还是统一抽象成一套 zero 语义或者只保留一种语义,否则长期维护成本太高。