From ceb40404fbd75c8565b812040a449467ab5d893e Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Thu, 5 Feb 2026 16:48:37 +0000 Subject: [PATCH 1/9] Fix sequence padding for mxfp8 training Signed-off-by: Guyue Huang --- nemo_rl/models/megatron/data.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/megatron/data.py b/nemo_rl/models/megatron/data.py index f884e95e1b..a146bf976f 100644 --- a/nemo_rl/models/megatron/data.py +++ b/nemo_rl/models/megatron/data.py @@ -547,7 +547,6 @@ def _get_pack_sequence_parameters_for_megatron( cp_size = megatron_cfg["context_parallel_size"] fp8_cfg = megatron_cfg.get("fp8_cfg", None) or {} use_fp8 = fp8_cfg.get("enabled", False) - use_blockwise_fp8 = fp8_cfg.get("fp8_recipe", None) == "blockwise" # individual sequence needs to be splitted to CP domain, and to TP domain when SP is enabled. pad_individual_seqs_to_multiple_of = 1 @@ -558,7 +557,11 @@ def _get_pack_sequence_parameters_for_megatron( # packed sequence length, after splitted to TP and CP domains, needs to be divisible by 128 if using blockwise FP8, and divisible by 16 if using other FP8 recipes. if use_fp8: - divisor = 128 if use_blockwise_fp8 else 16 + divisor = 16 + if fp8_cfg.get("fp8_recipe", None) == "blockwise": + divisor = 128 + elif fp8_cfg.get("fp8_recipe", None) == "mxfp8": + divisor = 32 pad_packed_seq_to_multiple_of = divisor if cp_size > 1: pad_packed_seq_to_multiple_of *= cp_size * 2 From b68b5cafe3db1736f2d8ed83dd02db4b9bf22244 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Mon, 9 Feb 2026 21:21:36 +0000 Subject: [PATCH 2/9] Fix test coverage Signed-off-by: Guyue Huang --- .../models/megatron/test_megatron_data.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/unit/models/megatron/test_megatron_data.py b/tests/unit/models/megatron/test_megatron_data.py index 28b4d55479..ea71d70be1 100644 --- a/tests/unit/models/megatron/test_megatron_data.py +++ b/tests/unit/models/megatron/test_megatron_data.py @@ -1606,6 +1606,69 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): "success": False, "error": f"Expected pad_individual=1, pad_packed=1, pad_to={test_seq_len}, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } + + # Test 15: FP8 with MXFP8 recipe + megatron_cfg = { + "tensor_model_parallel_size": 1, + "sequence_parallel": False, + "pipeline_model_parallel_size": 1, + "context_parallel_size": 1, + "fp8_cfg": {"enabled": True, "fp8_recipe": "mxfp8"}, + } + + pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( + megatron_cfg, max_seq_len + ) + + if pad_individual != 1 or pad_packed != 32 or pad_to is not None: + return { + "success": False, + "error": f"Expected pad_individual=1, pad_packed=32, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", + } + + # Test 16: FP8 with MXFP8 recipe, CP, and TP+SP + megatron_cfg = { + "tensor_model_parallel_size": 2, + "sequence_parallel": True, + "pipeline_model_parallel_size": 1, + "context_parallel_size": 4, + "fp8_cfg": {"enabled": True, "fp8_recipe": "mxfp8"}, + } + + pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( + megatron_cfg, max_seq_len + ) + + expected_individual = 4 * 2 * 2 # cp_size * 2 * tp_size + expected_packed = 32 * 4 * 2 * 2 # divisor * cp_size * 2 * tp_size + + if pad_individual != expected_individual or pad_packed != expected_packed or pad_to is not None: + return { + "success": False, + "error": f"Expected 
pad_individual={expected_individual}, pad_packed={expected_packed}, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", + } + + # Test 17: FP8 with MXFP8 recipe, CP, TP+SP, and PP + megatron_cfg = { + "tensor_model_parallel_size": 2, + "sequence_parallel": True, + "pipeline_model_parallel_size": 4, + "context_parallel_size": 4, + "fp8_cfg": {"enabled": True, "fp8_recipe": "mxfp8"}, + } + + pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( + megatron_cfg, max_seq_len + ) + + expected_individual = 4 * 2 * 2 # cp_size * 2 * tp_size + expected_packed = 32 * 4 * 2 * 2 * 4 # divisor * cp_size * 2 * tp_size * pp_size + + if pad_individual != expected_individual or pad_packed != expected_packed or pad_to != max_seq_len: + return { + "success": False, + "error": f"Expected pad_individual={expected_individual}, pad_packed={expected_packed}, pad_to={max_seq_len}, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", + } return {"success": True, "error": None} From b98dac56d95f9065d4d1e9aeb6ad918723d2babe Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Mon, 9 Feb 2026 21:49:48 +0000 Subject: [PATCH 3/9] Fix lint Signed-off-by: Guyue Huang --- .../models/megatron/test_megatron_data.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/unit/models/megatron/test_megatron_data.py b/tests/unit/models/megatron/test_megatron_data.py index ea71d70be1..0841621d92 100644 --- a/tests/unit/models/megatron/test_megatron_data.py +++ b/tests/unit/models/megatron/test_megatron_data.py @@ -1606,7 +1606,7 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): "success": False, "error": f"Expected pad_individual=1, pad_packed=1, pad_to={test_seq_len}, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } - + # Test 15: FP8 with MXFP8 recipe megatron_cfg = { "tensor_model_parallel_size": 1, @@ -1642,12 +1642,16 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): expected_individual = 4 * 2 * 2 # cp_size * 2 * tp_size expected_packed = 32 * 4 * 2 * 2 # divisor * cp_size * 2 * tp_size - if pad_individual != expected_individual or pad_packed != expected_packed or pad_to is not None: + if ( + pad_individual != expected_individual + or pad_packed != expected_packed + or pad_to is not None + ): return { "success": False, "error": f"Expected pad_individual={expected_individual}, pad_packed={expected_packed}, pad_to=None, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}", } - + # Test 17: FP8 with MXFP8 recipe, CP, TP+SP, and PP megatron_cfg = { "tensor_model_parallel_size": 2, @@ -1656,15 +1660,21 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): "context_parallel_size": 4, "fp8_cfg": {"enabled": True, "fp8_recipe": "mxfp8"}, } - + pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( megatron_cfg, max_seq_len ) - + expected_individual = 4 * 2 * 2 # cp_size * 2 * tp_size - expected_packed = 32 * 4 * 2 * 2 * 4 # divisor * cp_size * 2 * tp_size * pp_size - - if pad_individual != expected_individual or pad_packed != expected_packed or pad_to != max_seq_len: + expected_packed = ( + 32 * 4 * 2 * 2 * 4 + ) # divisor * cp_size * 2 * tp_size * pp_size + + if ( + pad_individual != expected_individual + or pad_packed != expected_packed + or pad_to != max_seq_len + ): return { "success": False, "error": f"Expected pad_individual={expected_individual}, 
pad_packed={expected_packed}, pad_to={max_seq_len}, got pad_individual={pad_individual}, pad_packed={pad_packed}, pad_to={pad_to}",

From 416d6dfa9de2fec68b038dfdf6e6200ed744783c Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Mon, 9 Feb 2026 22:29:34 +0000
Subject: [PATCH 4/9] Fix lint

Signed-off-by: Guyue Huang
---
 tests/unit/models/megatron/test_megatron_data.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/models/megatron/test_megatron_data.py b/tests/unit/models/megatron/test_megatron_data.py
index 0841621d92..2ed91b3574 100644
--- a/tests/unit/models/megatron/test_megatron_data.py
+++ b/tests/unit/models/megatron/test_megatron_data.py
@@ -1668,7 +1668,7 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self):
         expected_individual = 4 * 2 * 2  # cp_size * 2 * tp_size
         expected_packed = (
             32 * 4 * 2 * 2 * 4
-        )  # divisor * cp_size * 2 * tp_size * pp_size 
+        )  # divisor * cp_size * 2 * tp_size * pp_size

         if (
             pad_individual != expected_individual

From 52e66ae3e09db83c7e1f48ed42699187abc355b2 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 10 Feb 2026 17:07:07 +0000
Subject: [PATCH 5/9] Fix unit test expectations

Signed-off-by: root
---
 tests/unit/models/megatron/test_megatron_data.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/tests/unit/models/megatron/test_megatron_data.py b/tests/unit/models/megatron/test_megatron_data.py
index 2ed91b3574..e698720954 100644
--- a/tests/unit/models/megatron/test_megatron_data.py
+++ b/tests/unit/models/megatron/test_megatron_data.py
@@ -1329,7 +1329,7 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self):
             "pipeline_model_parallel_size": 1,
             "context_parallel_size": 1,
         }
-        max_seq_len = 1024
+        max_seq_len = 1023

         pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron(
             megatron_cfg, max_seq_len
@@ -1515,10 +1515,14 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self):

         expected_individual = 2 * 2 * 2  # cp_size * 2 * tp_size
         expected_packed = 16 * 2 * 2 * 2  # divisor * cp_size * 2 * tp_size
+
+        def _round_up_to_multiple_of(x, y):
+            return (x + y - 1) // y * y
+
         if (
             pad_individual != expected_individual
             or pad_packed != expected_packed
-            or pad_to != max_seq_len
+            or pad_to != _round_up_to_multiple_of(max_seq_len, expected_packed)
         ):
             return {
                 "success": False,
@@ -1666,14 +1670,12 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self):
         )

         expected_individual = 4 * 2 * 2  # cp_size * 2 * tp_size
-        expected_packed = (
-            32 * 4 * 2 * 2 * 4
-        )  # divisor * cp_size * 2 * tp_size * pp_size
+        expected_packed = 32 * 4 * 2 * 2  # divisor * cp_size * 2 * tp_size

         if (
             pad_individual != expected_individual
             or pad_packed != expected_packed
-            or pad_to != max_seq_len
+            or pad_to != _round_up_to_multiple_of(max_seq_len, expected_packed)
         ):
             return {
                 "success": False,

From f2cadd6f793947434538279b9ee243f0bc01e8e0 Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Wed, 11 Feb 2026 10:56:05 -0800
Subject: [PATCH 6/9] Make sure fp8_cfg placeholder is in all YAMLs

Signed-off-by: Guyue Huang
---
 examples/configs/distillation_math.yaml | 6 ++
 .../configs/distillation_math_megatron.yaml | 6 ++
 examples/configs/dpo.yaml | 8 ++-
 examples/configs/grpo_math_1B.yaml | 6 +-
 ...3.1-8b-instruct-2n8g-megatron-fp8-e2e.yaml | 3 -
 ...oonlight-16ba3b-4n8g-megatron-fp8-e2e.yaml | 10 +--
 ...ma3.1-8b-instruct-2n8g-fp8-async-1off.yaml | 3 -
 examples/configs/rm.yaml | 5 ++
 examples/configs/sft.yaml | 6 ++
 .../sft_openmathinstruct2_megatron.yaml | 13 ++--
 
examples/configs/vlm_grpo_3B.yaml | 6 ++ examples/configs/vlm_grpo_3B_megatron.yaml | 5 ++ nemo_rl/models/megatron/data.py | 4 +- .../configs/grpo_math_1B.yaml | 6 +- .../models/megatron/test_megatron_data.py | 63 ++++++++++++++++--- 15 files changed, 115 insertions(+), 35 deletions(-) diff --git a/examples/configs/distillation_math.yaml b/examples/configs/distillation_math.yaml index 67ff8a71d2..bb3743a4fc 100644 --- a/examples/configs/distillation_math.yaml +++ b/examples/configs/distillation_math.yaml @@ -155,6 +155,12 @@ policy: &POLICY_BASE use_custom_fsdp: false data_parallel_sharding_strategy: "optim_grads_params" + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false + scheduler: - name: "torch.optim.lr_scheduler.LinearLR" kwargs: diff --git a/examples/configs/distillation_math_megatron.yaml b/examples/configs/distillation_math_megatron.yaml index ae2fbcd3e1..76151678f1 100644 --- a/examples/configs/distillation_math_megatron.yaml +++ b/examples/configs/distillation_math_megatron.yaml @@ -106,6 +106,12 @@ policy: &POLICY_BASE use_custom_fsdp: false data_parallel_sharding_strategy: "optim_grads_params" + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false + generation: backend: "vllm" max_new_tokens: ${..max_total_sequence_length} # refer to local policy/teacher config diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 579b53f264..fec83a6199 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -177,7 +177,13 @@ policy: overlap_param_gather: true data_parallel_sharding_strategy: "optim_grads_params" use_custom_fsdp: false - + + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false + data: max_input_seq_length: ${policy.max_total_sequence_length} shuffle: true diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 69e25f8f4f..da7d911119 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -189,7 +189,11 @@ policy: use_custom_fsdp: false data_parallel_sharding_strategy: "optim_grads_params" - fp8_cfg: null + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false env_vars: null diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-megatron-fp8-e2e.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-megatron-fp8-e2e.yaml index 6411c6fb49..5b348e9c5a 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-megatron-fp8-e2e.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-megatron-fp8-e2e.yaml @@ -33,9 +33,6 @@ policy: lr_warmup_init: 5.0e-08 fp8_cfg: enabled: true - fp8: e4m3 - fp8_recipe: blockwise - fp8_param: false env_vars: NVTE_FP8_BLOCK_SCALING_FP32_SCALES: '1' generation: diff --git a/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron-fp8-e2e.yaml b/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron-fp8-e2e.yaml index 27108c55c7..54b8d6671f 100644 --- a/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron-fp8-e2e.yaml +++ b/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron-fp8-e2e.yaml @@ -28,9 +28,6 @@ policy: apply_rope_fusion: false fp8_cfg: enabled: true - fp8: e4m3 - fp8_recipe: blockwise - fp8_param: false optimizer: lr: 1.0e-06 use_precision_aware_optimizer: false @@ -43,10 +40,9 @@ policy: precision: fp8 use_deep_gemm: true gpu_memory_utilization: 0.5 - quantization_ignored_layer_kws: [ - a_proj, - b_proj - ] 
+ quantization_ignored_layer_kws: + - a_proj + - b_proj logger: monitor_gpus: false wandb: diff --git a/examples/configs/recipes/llm/performance/grpo-llama3.1-8b-instruct-2n8g-fp8-async-1off.yaml b/examples/configs/recipes/llm/performance/grpo-llama3.1-8b-instruct-2n8g-fp8-async-1off.yaml index b32786f7d7..e3f935cdc0 100644 --- a/examples/configs/recipes/llm/performance/grpo-llama3.1-8b-instruct-2n8g-fp8-async-1off.yaml +++ b/examples/configs/recipes/llm/performance/grpo-llama3.1-8b-instruct-2n8g-fp8-async-1off.yaml @@ -5,9 +5,6 @@ policy: megatron_cfg: fp8_cfg: enabled: true - fp8: "e4m3" - fp8_recipe: "blockwise" - fp8_param: false env_vars: NVTE_FP8_BLOCK_SCALING_FP32_SCALES: "1" generation: diff --git a/examples/configs/rm.yaml b/examples/configs/rm.yaml index 9e89d6b199..52239ec8cd 100644 --- a/examples/configs/rm.yaml +++ b/examples/configs/rm.yaml @@ -128,6 +128,11 @@ policy: overlap_param_gather: false data_parallel_sharding_strategy: "optim_grads_params" + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false data: max_input_seq_length: ${policy.max_total_sequence_length} diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 6d53d7f606..7b90a90c38 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -175,6 +175,12 @@ policy: data_parallel_sharding_strategy: "optim_grads_params" use_custom_fsdp: false + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false + data: max_input_seq_length: ${policy.max_total_sequence_length} add_bos: true diff --git a/examples/configs/sft_openmathinstruct2_megatron.yaml b/examples/configs/sft_openmathinstruct2_megatron.yaml index faca12e0ae..40f62473ac 100644 --- a/examples/configs/sft_openmathinstruct2_megatron.yaml +++ b/examples/configs/sft_openmathinstruct2_megatron.yaml @@ -100,14 +100,11 @@ policy: env_vars: PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False" - ## fp8 training currently not supported - #fp8_cfg: - # enabled: true - # fp8: hybrid - # fp8_recipe: delayed - # fp8_param: true # false gives the following error: "RuntimeError: /TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:116 in function CanonicalizeGemmInput: Assertion failed: !is_fp8_dtype(ret.Atype). Input A is missing column-wise usage" - # fp8_dot_product_attention: false #true - # fp8_multi_head_attention: false #true + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false dynamic_batching: enabled: false diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index 9c6049d85d..5d0c1aae2a 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -159,6 +159,12 @@ policy: use_custom_fsdp: false data_parallel_sharding_strategy: "optim_grads_params" + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false + # dynamic_batching improves performance by ensuring logprob and training microbatches # have a sufficent number of tokens to maximize GPU utilization. 
Specifically, variable length diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index 8e13681629..6d9016503a 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -189,6 +189,11 @@ policy: overlap_param_gather: true use_custom_fsdp: false data_parallel_sharding_strategy: optim_grads_params + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false data: max_input_seq_length: ${policy.max_total_sequence_length} shuffle: true diff --git a/nemo_rl/models/megatron/data.py b/nemo_rl/models/megatron/data.py index a146bf976f..5adbec29c7 100644 --- a/nemo_rl/models/megatron/data.py +++ b/nemo_rl/models/megatron/data.py @@ -558,9 +558,9 @@ def _get_pack_sequence_parameters_for_megatron( # packed sequence length, after splitted to TP and CP domains, needs to be divisible by 128 if using blockwise FP8, and divisible by 16 if using other FP8 recipes. if use_fp8: divisor = 16 - if fp8_cfg.get("fp8_recipe", None) == "blockwise": + if fp8_cfg["fp8_recipe"] == "blockwise": divisor = 128 - elif fp8_cfg.get("fp8_recipe", None) == "mxfp8": + elif fp8_cfg["fp8_recipe"] == "mxfp8": divisor = 32 pad_packed_seq_to_multiple_of = divisor if cp_size > 1: diff --git a/research/template_project/configs/grpo_math_1B.yaml b/research/template_project/configs/grpo_math_1B.yaml index ef968b717e..758a1def74 100644 --- a/research/template_project/configs/grpo_math_1B.yaml +++ b/research/template_project/configs/grpo_math_1B.yaml @@ -136,7 +136,11 @@ policy: use_custom_fsdp: false data_parallel_sharding_strategy: "optim_grads_params" - fp8_cfg: null + fp8_cfg: + enabled: false + fp8: "e4m3" + fp8_recipe: "blockwise" + fp8_param: false env_vars: null diff --git a/tests/unit/models/megatron/test_megatron_data.py b/tests/unit/models/megatron/test_megatron_data.py index e698720954..6e381d2933 100644 --- a/tests/unit/models/megatron/test_megatron_data.py +++ b/tests/unit/models/megatron/test_megatron_data.py @@ -1443,7 +1443,12 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): "sequence_parallel": False, "pipeline_model_parallel_size": 1, "context_parallel_size": 1, - "fp8_cfg": {"enabled": True}, + "fp8_cfg": { + "enabled": True, + "fp8": "hybrid", + "fp8_recipe": "tensorwise", + "fp8_param": False, + }, } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( @@ -1462,7 +1467,12 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): "sequence_parallel": False, "pipeline_model_parallel_size": 1, "context_parallel_size": 1, - "fp8_cfg": {"enabled": True, "fp8_recipe": "blockwise"}, + "fp8_cfg": { + "enabled": True, + "fp8": "e4m3", + "fp8_recipe": "blockwise", + "fp8_param": False, + }, } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( @@ -1481,7 +1491,12 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): "sequence_parallel": True, "pipeline_model_parallel_size": 1, "context_parallel_size": 4, - "fp8_cfg": {"enabled": True, "fp8_recipe": "blockwise"}, + "fp8_cfg": { + "enabled": True, + "fp8": "e4m3", + "fp8_recipe": "blockwise", + "fp8_param": False, + }, } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( @@ -1506,7 +1521,12 @@ def run_all_get_pack_sequence_parameters_for_megatron_tests(self): "sequence_parallel": True, "pipeline_model_parallel_size": 4, "context_parallel_size": 2, - "fp8_cfg": {"enabled": True, "fp8_recipe": "other"}, + "fp8_cfg": { + 
"enabled": True, + "fp8": "hybrid", + "fp8_recipe": "tensorwise", + "fp8_param": False, + }, } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( @@ -1535,7 +1555,12 @@ def _round_up_to_multiple_of(x, y): "sequence_parallel": False, "pipeline_model_parallel_size": 1, "context_parallel_size": 1, - "fp8_cfg": {"enabled": False}, + "fp8_cfg": { + "enabled": False, + "fp8": "e4m3", + "fp8_recipe": "blockwise", + "fp8_param": False, + }, } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( @@ -1573,7 +1598,12 @@ def _round_up_to_multiple_of(x, y): "sequence_parallel": True, "pipeline_model_parallel_size": 1, "context_parallel_size": 8, - "fp8_cfg": {"enabled": True, "fp8_recipe": "blockwise"}, + "fp8_cfg": { + "enabled": True, + "fp8": "e4m3", + "fp8_recipe": "blockwise", + "fp8_param": False, + }, } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( @@ -1617,7 +1647,12 @@ def _round_up_to_multiple_of(x, y): "sequence_parallel": False, "pipeline_model_parallel_size": 1, "context_parallel_size": 1, - "fp8_cfg": {"enabled": True, "fp8_recipe": "mxfp8"}, + "fp8_cfg": { + "enabled": True, + "fp8": "e4m3", + "fp8_recipe": "mxfp8", + "fp8_param": False, + }, } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( @@ -1636,7 +1671,12 @@ def _round_up_to_multiple_of(x, y): "sequence_parallel": True, "pipeline_model_parallel_size": 1, "context_parallel_size": 4, - "fp8_cfg": {"enabled": True, "fp8_recipe": "mxfp8"}, + "fp8_cfg": { + "enabled": True, + "fp8": "e4m3", + "fp8_recipe": "mxfp8", + "fp8_param": False, + }, } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( @@ -1662,7 +1702,12 @@ def _round_up_to_multiple_of(x, y): "sequence_parallel": True, "pipeline_model_parallel_size": 4, "context_parallel_size": 4, - "fp8_cfg": {"enabled": True, "fp8_recipe": "mxfp8"}, + "fp8_cfg": { + "enabled": True, + "fp8": "e4m3", + "fp8_recipe": "mxfp8", + "fp8_param": False, + }, } pad_individual, pad_packed, pad_to = _get_pack_sequence_parameters_for_megatron( From 2a6e7a4e2738715e8aa9c5f7ac8b94c9cd1d40dc Mon Sep 17 00:00:00 2001 From: root Date: Mon, 16 Feb 2026 22:43:26 +0000 Subject: [PATCH 7/9] Fix crashes for mxfp8 fp8_param_gather Signed-off-by: root --- nemo_rl/models/megatron/setup.py | 6 +++++ .../policy/workers/megatron_policy_worker.py | 26 +++++++++++++++---- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index 24bfdb0605..b7c7926524 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -608,6 +608,12 @@ def _create_megatron_config( data_parallel_sharding_strategy=config["megatron_cfg"][ "distributed_data_parallel_config" ]["data_parallel_sharding_strategy"], + fp8_param_gather=config["megatron_cfg"]["optimizer"].get( + "reuse_grad_buf_for_mxfp8_param_ag", False + ), + reuse_grad_buf_for_mxfp8_param_ag=config["megatron_cfg"]["optimizer"].get( + "reuse_grad_buf_for_mxfp8_param_ag", False + ), ), scheduler=SchedulerConfig(**config["megatron_cfg"]["scheduler"]), dataset=None, diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 48ba0623e2..ec1d9e80a5 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -382,6 +382,20 @@ def train( self.model.zero_grad_buffer() 
self.optimizer.zero_grad()
+
+            from megatron.bridge.training.train import (
+                _handle_mxfp8_param_buffer_copy,
+            )
+
+            _handle_mxfp8_param_buffer_copy(
+                optimizer=self.optimizer,
+                reuse_grad_buf_for_mxfp8_param_ag=self.cfg["megatron_cfg"][
+                    "optimizer"
+                ]["reuse_grad_buf_for_mxfp8_param_ag"],
+                overlap_param_gather=self.cfg["megatron_cfg"][
+                    "distributed_data_parallel_config"
+                ]["overlap_param_gather"],
+            )
+
             # Forward pass.
             forward_backward_func = get_forward_backward_func()
             losses_reduced = forward_backward_func(
@@ -533,6 +547,8 @@ def get_logprobs(
         We use the convention that the logprob of the first token is 0 so that the
         sequence length is maintained. The logprob of input token i is specified at
         position i in the output logprobs tensor.
         """
+        self.model.zero_grad_buffer()
+
         no_grad = torch.no_grad()
         no_grad.__enter__()
         logprob_batch_size = (
@@ -1358,13 +1374,13 @@ def broadcast_weights_for_collective(
         )

     def prepare_for_lp_inference(self):
-        self.model = self.move_model(self.model, "cuda", move_grads=False)
+        self.model = self.move_model(self.model, "cuda", move_grads=True)
         self.model.eval()

-        # offload grads to cpu
-        self.model = self.move_model(
-            self.model, "cpu", move_params=False, move_grads=True
-        )  # get rid of grad buffers
+        # # offload grads to cpu
+        # self.model = self.move_model(
+        #     self.model, "cpu", move_params=False, move_grads=True
+        # )  # get rid of grad buffers

         # offload optimizer to cpu
         torch.randn(1).cuda()  # wake up torch allocator

From 8de41359991274e402a101cc663051c8c576ba67 Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Tue, 17 Feb 2026 02:07:40 +0000
Subject: [PATCH 8/9] Pass fp8_recipe and overlap_param_gather to OptimizerConfig

Signed-off-by: Guyue Huang
---
 nemo_rl/models/megatron/setup.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py
index fdaa48e2c4..b6ecd00483 100644
--- a/nemo_rl/models/megatron/setup.py
+++ b/nemo_rl/models/megatron/setup.py
@@ -588,7 +588,11 @@ def _create_megatron_config(
             global_batch_size=config["train_global_batch_size"],  # ignored
             train_iters=config["megatron_cfg"]["train_iters"],
         ),
-        optimizer=OptimizerConfig(**config["megatron_cfg"]["optimizer"]),
+        optimizer=OptimizerConfig(
+            fp8_recipe=config["megatron_cfg"]["fp8_cfg"]["fp8_recipe"],
+            overlap_param_gather=config["megatron_cfg"]["distributed_data_parallel_config"]["overlap_param_gather"],
+            **config["megatron_cfg"]["optimizer"]
+        ),
         ddp=DistributedDataParallelConfig(
             check_for_nan_in_grad=True,
             grad_reduce_in_fp32=config["megatron_cfg"][

From 16652bfc83fe53207dd7a5f02b9613ec3e1f52ef Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Tue, 17 Feb 2026 05:13:09 +0000
Subject: [PATCH 9/9] Fix for logprob difference in mxfp8

Signed-off-by: Guyue Huang
---
 .../policy/workers/megatron_policy_worker.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py
index eedd8682d8..9cdbdcf8b6 100644
--- a/nemo_rl/models/policy/workers/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -479,6 +479,20 @@ def get_logprobs(
         """
         self.model.zero_grad_buffer()

+        from megatron.bridge.training.train import (
+            _handle_mxfp8_param_buffer_copy,
+        )
+
+        _handle_mxfp8_param_buffer_copy(
+            optimizer=self.optimizer,
+            reuse_grad_buf_for_mxfp8_param_ag=self.cfg["megatron_cfg"][
+                "optimizer"
+            ]["reuse_grad_buf_for_mxfp8_param_ag"],
+            overlap_param_gather=self.cfg["megatron_cfg"][
+
"distributed_data_parallel_config" + ]["overlap_param_gather"], + ) + no_grad = torch.no_grad() no_grad.__enter__() logprob_batch_size = (