Skip to content

EXPERIMENTAL tensor parallel REKLS#175

Open
skyw wants to merge 18 commits into
mainfrom
skyw/tp_rekls_exp
Open

EXPERIMENTAL tensor parallel REKLS#175
skyw wants to merge 18 commits into
mainfrom
skyw/tp_rekls_exp

Conversation

@skyw
Copy link
Copy Markdown
Contributor

@skyw skyw commented May 6, 2026

For experimental purpose only, DONOT merge yet.

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 6, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 6, 2026

Greptile Summary

This PR introduces TpRekls, a tensor-parallel variant of the REKLS optimizer reimplemented from scratch (not inheriting SOAP) to keep TP bookkeeping isolated. It also adds tp_utils.all_gather_grad_and_kronecker_factors_tp, get_pg_size/get_pg_rank utilities, and distributed CPU tests.

  • Core TP step: each rank all-gathers its sharded gradient and Kronecker-factor shards, runs the full KL-Shampoo correction + dual-eigh eigenbasis update on the assembled tensors, then writes back only the local shard of L/R and the local slice of the preconditioned update. Adam moment state (exp_avg, exp_avg_sq) is kept full-size and duplicated on every rank.
  • Divisibility guard: _init_group validates that both parameter dimensions are divisible by tp_size, giving a clear error instead of a silent shape mismatch during gather.
  • CI expansion: the shell script now loops over all test_distributed_*_cpu.py files for nproc_per_node in {8, 4}, plus an explicit n=1 single-process run for the REKLS test.

Confidence Score: 5/5

Safe to review further; the core TP gather/scatter logic and Kronecker-factor write-back are correct. This is explicitly marked experimental and not intended to merge yet.

The TP gather/scatter logic, divisibility guard, shard write-back, and eigenbasis rotation sequence all match the documented algorithm. The bit-exact distributed test validates correctness end-to-end. The only findings are minor annotation/guard hygiene issues that do not affect optimizer correctness or training outcomes.

emerging_optimizers/soap/rekls.py — the closure assert and TYPE_CHECKING overload annotation are worth cleaning up before any non-experimental use.

Important Files Changed

Filename Overview
emerging_optimizers/soap/rekls.py Adds TpRekls: a tensor-parallel REKLS optimizer that all-gathers sharded gradients and Kronecker factors, runs the full KL-Shampoo + eigenbasis update on every rank, then writes back local shards. Logic appears correct; minor type-annotation inconsistency and assert-vs-raise guard noted.
emerging_optimizers/soap/tp_utils.py New helper that all-gathers a partitioned gradient (along partition_dim) and Kronecker-factor shards (always along dim 0) via three separate dist.all_gather calls, then concatenates them. Straightforward and correct for the sharding contract described in the docstring.
emerging_optimizers/utils/init.py Adds get_pg_size and get_pg_rank with safe fallback to 1/0 when distributed is not initialized or group is None; exported through all.
emerging_optimizers/soap/soap.py Docstring-only improvement to update_kronecker_factors_kl_shampoo: expands description to include the full KL-Shampoo math and renames the eigval_exp parameter description. No logic changes.
tests/test_distributed_rekls_cpu.py Distributed integration test verifying TpRekls produces bit-identical updates to non-distributed REKLS across 5 steps and mixed partition_dims. The tearDownModule destroy_process_group call is unreachable (flagged previously).
tests/test_distributed_soap_utils_cpu.py Tests all_gather_grad_and_kronecker_factors_tp: verifies gathered tensors match the non-distributed full tensors, and that in-place Kronecker factor updates on gathered tensors match the reference. Same unreachable tearDownModule issue as the rekls test.
tests/ci/L0_Tests_CPU.sh Refactors CI to loop over all test_distributed_*_cpu.py files for n in {8,4}, then adds an explicit n=1 single-process run for TpRekls. Shape divisibility for n=8 is satisfied by all parameterized test shapes.

Sequence Diagram

sequenceDiagram
    participant R as Rank i
    participant G as tp_group (all ranks)

    R->>R: "local_grad = p.grad.to(float32)"
    R->>R: _apply_weight_decay_inplace(p, local_grad)
    R->>G: dist.all_gather(grad_shards, local_grad)
    G-->>R: "full_grad = cat(grad_shards, dim=partition_dim)"
    R->>G: dist.all_gather(factor_shards_L, state["L"])
    G-->>R: "full_L = cat(factor_shards_L, dim=0)"
    R->>G: dist.all_gather(factor_shards_R, state["R"])
    G-->>R: "full_R = cat(factor_shards_R, dim=0)"
    R->>R: "pre_eigenbasis = get_eigenbasis_eigh([full_L, full_R])"
    R->>R: update_kronecker_factors_kl_shampoo([full_L, full_R], full_grad, pre_eigenbasis)
    R->>R: state["L"].copy_(full_L.chunk(tp_size)[rank])
    R->>R: state["R"].copy_(full_R.chunk(tp_size)[rank])
    R->>R: "eigenbasis, exp_avg, exp_avg_sq = update_eigenbasis_and_exp_avgs(...)"
    R->>R: "full_grad_proj = precondition(full_grad, eigenbasis)"
    R->>R: "full_adam_update = calculate_laprop_update(full_grad_proj, exp_avg, exp_avg_sq, ...)"
    R->>R: "full_precond_update = precondition(full_adam_update, eigenbasis)"
    R->>R: "p.add_(full_precond_update.chunk(tp_size, dim=partition_dim)[rank], alpha=-lr)"
Loading

Reviews (10): Last reviewed commit: "EXP: use laprop inside REKLS" | Re-trigger Greptile

Comment thread emerging_optimizers/soap/rekls.py
Comment thread tests/test_distributed_rekls_cpu.py Outdated
Comment thread tests/test_distributed_soap_utils_cpu.py Outdated
@skyw skyw marked this pull request as draft May 6, 2026 20:32
@skyw skyw marked this pull request as ready for review May 6, 2026 20:40
@skyw skyw marked this pull request as draft May 6, 2026 20:40
@skyw
Copy link
Copy Markdown
Contributor Author

skyw commented May 6, 2026

@greptile check closure contract again

@skyw skyw marked this pull request as ready for review May 6, 2026 22:33
Comment thread emerging_optimizers/soap/rekls.py
@skyw
Copy link
Copy Markdown
Contributor Author

skyw commented May 6, 2026

@greptile overload can't be dropped because it is in upstream pytorch.

skyw added 12 commits May 6, 2026 15:54
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw skyw force-pushed the skyw/tp_rekls_exp branch from e61f0f2 to 6eee42f Compare May 6, 2026 22:54
@skyw
Copy link
Copy Markdown
Contributor Author

skyw commented May 6, 2026

/ok to test 6eee42f

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 6, 2026

That's a fair point. The @overload stubs are guarded under if TYPE_CHECKING: so they're only visible to type checkers (mypy/pyright), not at runtime — and since TpRekls inherits from optim.Optimizer, which declares those two overloads in its upstream type stubs, carrying them here is necessary to satisfy the @override contract and avoid type errors.

The actual runtime behavior is still made clear by the assert closure is None, "No support for closure" at the top of the implementation, so there's no ambiguity for callers at runtime. Acknowledged — no action needed here.

Tip: You can give me more context with [.greptile/rules.md](https://www.greptile.com/docs/code-review/custom-standards) so I get this right next time.

@skyw
Copy link
Copy Markdown
Contributor Author

skyw commented May 6, 2026

@mkhona-nvidia this actually turned out to be a very mergeable limited scope version, take a quick look.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 6, 2026

Test Results

   75 files  + 25    135 suites  +25   1m 29s ⏱️ +9s
1 044 tests + 13  1 044 ✅ + 13  0 💤 ±0  0 ❌ ±0 
2 470 runs  +157  2 470 ✅ +157  0 💤 ±0  0 ❌ ±0 

Results for commit 6eee42f. ± Comparison against base commit 04ee932.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 6, 2026

Codecov Report

❌ Patch coverage is 88.88889% with 12 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
emerging_optimizers/soap/rekls.py 90.47% 4 Missing and 4 partials ⚠️
emerging_optimizers/utils/__init__.py 55.55% 2 Missing and 2 partials ⚠️

📢 Thoughts on this report? Let us know!

skyw added 6 commits May 14, 2026 09:01
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant