-
Notifications
You must be signed in to change notification settings - Fork 2
Add Primus Turbo grouped GEMM support for MoE sparse matmul #83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -855,6 +855,28 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments): | |
| use_tokamax_backend=self.config.use_tokamax_gmm, | ||
| is_fsdp_shard_on_exp=self.config.fsdp_shard_on_exp, | ||
| ) | ||
| elif self.config.use_turbo_grouped_gemm: | ||
| try: | ||
| from primus_turbo.jax.lax.grouped_gemm import grouped_gemm as turbo_grouped_gemm | ||
| except ImportError: | ||
| raise ImportError("use_turbo_grouped_gemm=True requires the primus_turbo package to be installed.") | ||
| if not getattr(turbo_grouped_gemm, "_logged", False): | ||
| max_logging.log("Using primus_turbo grouped_gemm in MoE sparse matmul") | ||
| turbo_grouped_gemm._logged = True | ||
| # Thread-local x64: CK kernel requires int64 group_sizes, but | ||
| # global x64 breaks argsort (XLA-ROCm s32/s64 scatter mismatch). | ||
| # Use jax.experimental.enable_x64() for thread-local scope, | ||
| # safe for concurrent shard_map threads. | ||
| # Remove this once primus_turbo accepts int32 group_lens natively. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a GitHub issue for this in the |
||
| with jax.experimental.enable_x64(): | ||
| output = turbo_grouped_gemm( | ||
| inputs, | ||
| kernel, | ||
| group_sizes.astype(jnp.int64), | ||
| transA=False, | ||
| transB=False, | ||
| num_cu=-1, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @qianghan-amd Do we want to allow users to tune this, maybe as an additional flag in the config? I think the optimal CU count depends on the number of experts and the GEMM sizes, right? |
||
| ) | ||
| else: | ||
| rhs_inputs = kernel | ||
| if isinstance(kernel, aqt.QTensor): | ||
|
|
@@ -1445,7 +1467,10 @@ def get_einsum( | |
| def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument | ||
| # simply skip kwargs, since aqt einsum doesn't support any kwargs | ||
| # like precision | ||
| is_aqt = not ( isinstance(self.quant, quantizations.Fp8Quantization) or isinstance(self.quant, quantizations.NANOOFp8Quantization) ) | ||
| is_aqt = not ( | ||
| isinstance(self.quant, quantizations.Fp8Quantization) | ||
| or isinstance(self.quant, quantizations.NANOOFp8Quantization) | ||
| ) | ||
| kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype} | ||
| return self.quant.einsum(**kw)(*args) # pytype: disable=attribute-error | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If `getattr(turbo_grouped_gemm, "_logged", False)` returns `True`, don't we still want to print `max_logging.log("Using primus_turbo grouped_gemm in MoE sparse matmul")`?