EXPERIMENTAL tensor parallel REKLS #175
Greptile Summary
This PR introduces
Confidence Score: 5/5
Safe to review further; the core TP gather/scatter logic and Kronecker-factor write-back are correct. This is explicitly marked experimental and not intended to merge yet.
The TP gather/scatter logic, divisibility guard, shard write-back, and eigenbasis rotation sequence all match the documented algorithm. The bit-exact distributed test validates correctness end-to-end. The only findings are minor annotation/guard hygiene issues that do not affect optimizer correctness or training outcomes.
emerging_optimizers/soap/rekls.py: the closure assert and TYPE_CHECKING overload annotation are worth cleaning up before any non-experimental use.
Important Files Changed
Sequence Diagram
sequenceDiagram
participant R as Rank i
participant G as tp_group (all ranks)
R->>R: "local_grad = p.grad.to(float32)"
R->>R: _apply_weight_decay_inplace(p, local_grad)
R->>G: dist.all_gather(grad_shards, local_grad)
G-->>R: "full_grad = cat(grad_shards, dim=partition_dim)"
R->>G: dist.all_gather(factor_shards_L, state["L"])
G-->>R: "full_L = cat(factor_shards_L, dim=0)"
R->>G: dist.all_gather(factor_shards_R, state["R"])
G-->>R: "full_R = cat(factor_shards_R, dim=0)"
R->>R: "pre_eigenbasis = get_eigenbasis_eigh([full_L, full_R])"
R->>R: update_kronecker_factors_kl_shampoo([full_L, full_R], full_grad, pre_eigenbasis)
R->>R: state["L"].copy_(full_L.chunk(tp_size)[rank])
R->>R: state["R"].copy_(full_R.chunk(tp_size)[rank])
R->>R: "eigenbasis, exp_avg, exp_avg_sq = update_eigenbasis_and_exp_avgs(...)"
R->>R: "full_grad_proj = precondition(full_grad, eigenbasis)"
R->>R: "full_adam_update = calculate_laprop_update(full_grad_proj, exp_avg, exp_avg_sq, ...)"
R->>R: "full_precond_update = precondition(full_adam_update, eigenbasis)"
R->>R: "p.add_(full_precond_update.chunk(tp_size, dim=partition_dim)[rank], alpha=-lr)"
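The gather, full-size update, and per-rank write-back shown in the diagram can be sketched in a single process. This is a minimal simulation, not the PR's implementation: plain Python lists stand in for tensors, `all_gather_cat` stands in for `dist.all_gather` followed by `cat`, `full_update` is a hypothetical placeholder for the Kronecker-factor and LaProp math, and sharding is along dim 0 only (the diagram shards the gradient along `partition_dim`).

```python
# Single-process simulation of the shard -> gather -> update -> write-back
# round trip. All helper names here are hypothetical; no torch.distributed is used.

def shard_rows(matrix, tp_size):
    """Split a matrix (list of rows) into tp_size equal row blocks (dim 0)."""
    rows = len(matrix)
    if rows % tp_size != 0:  # divisibility guard: shards must be equal-sized
        raise ValueError(f"{rows} rows not divisible by tp_size={tp_size}")
    step = rows // tp_size
    return [matrix[i * step:(i + 1) * step] for i in range(tp_size)]


def all_gather_cat(shards):
    """Stand-in for all_gather + cat(dim=0): rebuild the full matrix."""
    return [row[:] for shard in shards for row in shard]


def full_update(full):
    """Placeholder for the full-matrix optimizer math (assumed, not REKLS)."""
    return [[0.5 * x for x in row] for row in full]


def step_all_ranks(local_shards):
    """Each rank gathers the full matrix, updates it, and keeps only its slice."""
    tp_size = len(local_shards)
    new_shards = []
    for rank in range(tp_size):
        full = all_gather_cat(local_shards)   # identical on every rank
        updated = full_update(full)           # redundant full-size computation
        new_shards.append(shard_rows(updated, tp_size)[rank])  # write-back
    return new_shards


full_param = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
shards = shard_rows(full_param, tp_size=2)
result = all_gather_cat(step_all_ranks(shards))
reference = full_update(full_param)
```

Comparing `result` with `reference` mirrors the idea behind a bit-exact check: because every rank runs the same computation on the same gathered data, reassembling the per-rank slices should reproduce the unsharded update exactly.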
Reviews (10): Last reviewed commit: "EXP: use laprop inside REKLS"
@greptile check closure contract again
@greptile overload can't be dropped because it is in upstream pytorch.
Signed-off-by: Hao Wu <skyw@nvidia.com>
/ok to test 6eee42f
That's a fair point. The The actual runtime behavior is still made clear by the
@mkhona-nvidia this actually turned out to be a very mergeable, limited-scope version; take a quick look.
Codecov Report
❌ Patch coverage is
For experimental purposes only; DO NOT merge yet.