diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 27eafbecdc..404fae85fd 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -598,7 +598,8 @@ def test_sanity_grouped_linear( # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. bs = bs * 16 num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) - + if os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0") == "0" and single_param: + pytest.skip("single parameter grouped linear requires NVTE_GROUPED_LINEAR_SINGLE_PARAM=1") skip_unsupported_backward_override("grouped_linear", fp8_recipe, backward_override) if fp8_recipe is not None: fp8_recipe = copy.deepcopy(fp8_recipe)