diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py
index e9fc2da9e1..b6ecd00483 100644
--- a/nemo_rl/models/megatron/setup.py
+++ b/nemo_rl/models/megatron/setup.py
@@ -588,7 +588,11 @@ def _create_megatron_config(
             global_batch_size=config["train_global_batch_size"],  # ignored
             train_iters=config["megatron_cfg"]["train_iters"],
         ),
-        optimizer=OptimizerConfig(**config["megatron_cfg"]["optimizer"]),
+        optimizer=OptimizerConfig(
+            fp8_recipe=config["megatron_cfg"]["fp8_cfg"]["fp8_recipe"],
+            overlap_param_gather=config["megatron_cfg"]["distributed_data_parallel_config"]["overlap_param_gather"],
+            **config["megatron_cfg"]["optimizer"]
+        ),
         ddp=DistributedDataParallelConfig(
             check_for_nan_in_grad=True,
             grad_reduce_in_fp32=config["megatron_cfg"][
@@ -609,6 +613,12 @@
             data_parallel_sharding_strategy=config["megatron_cfg"][
                 "distributed_data_parallel_config"
             ]["data_parallel_sharding_strategy"],
+            fp8_param_gather=config["megatron_cfg"]["optimizer"].get(
+                "reuse_grad_buf_for_mxfp8_param_ag", False
+            ),
+            reuse_grad_buf_for_mxfp8_param_ag=config["megatron_cfg"]["optimizer"].get(
+                "reuse_grad_buf_for_mxfp8_param_ag", False
+            ),
         ),
         scheduler=SchedulerConfig(**config["megatron_cfg"]["scheduler"]),
         dataset=None,
diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py
index d9a1c3d8a3..9cdbdcf8b6 100644
--- a/nemo_rl/models/policy/workers/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -326,6 +326,20 @@ def train(
         self.model.zero_grad_buffer()
         self.optimizer.zero_grad()
 
+        from megatron.bridge.training.train import (
+            _handle_mxfp8_param_buffer_copy,
+        )
+
+        _handle_mxfp8_param_buffer_copy(
+            optimizer=self.optimizer,
+            reuse_grad_buf_for_mxfp8_param_ag=self.cfg["megatron_cfg"][
+                "optimizer"
+            ]["reuse_grad_buf_for_mxfp8_param_ag"],
+            overlap_param_gather=self.cfg["megatron_cfg"][
+                "distributed_data_parallel_config"
+            ]["overlap_param_gather"],
+        )
+
         # Forward pass.
         losses_reduced = megatron_forward_backward(
             model=self.model,
@@ -463,6 +477,22 @@ def get_logprobs(
         We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
         The logprob of input token i is specified at position i in the output logprobs tensor.
         """
+        self.model.zero_grad_buffer()
+
+        from megatron.bridge.training.train import (
+            _handle_mxfp8_param_buffer_copy,
+        )
+
+        _handle_mxfp8_param_buffer_copy(
+            optimizer=self.optimizer,
+            reuse_grad_buf_for_mxfp8_param_ag=self.cfg["megatron_cfg"][
+                "optimizer"
+            ]["reuse_grad_buf_for_mxfp8_param_ag"],
+            overlap_param_gather=self.cfg["megatron_cfg"][
+                "distributed_data_parallel_config"
+            ]["overlap_param_gather"],
+        )
+
         no_grad = torch.no_grad()
         no_grad.__enter__()
         logprob_batch_size = (
@@ -1057,13 +1087,13 @@ def broadcast_weights_for_collective(
         )
 
     def prepare_for_lp_inference(self):
-        self.model = self.move_model(self.model, "cuda", move_grads=False)
+        self.model = self.move_model(self.model, "cuda", move_grads=True)
         self.model.eval()
 
-        # offload grads to cpu
-        self.model = self.move_model(
-            self.model, "cpu", move_params=False, move_grads=True
-        )  # get rid of grad buffers
+        # # offload grads to cpu
+        # self.model = self.move_model(
+        #     self.model, "cpu", move_params=False, move_grads=True
+        # )  # get rid of grad buffers
 
         # offload optimizer to cpu
         torch.randn(1).cuda()  # wake up torch allocator
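
For reference, below is a minimal sketch (not part of the patch) of the `megatron_cfg` structure the changes above read. The key names come straight from the diff; the surrounding layout and the example values are assumptions, not the actual nemo_rl defaults.

```python
# Hypothetical sketch of the "megatron_cfg" keys consumed by this patch.
# Key names are taken from the diff; layout and example values are assumptions.
megatron_cfg = {
    "optimizer": {
        # New flag read by _create_megatron_config and by the worker's
        # _handle_mxfp8_param_buffer_copy calls in train()/get_logprobs().
        "reuse_grad_buf_for_mxfp8_param_ag": True,
        # ... existing OptimizerConfig kwargs ...
    },
    "fp8_cfg": {
        # Forwarded to OptimizerConfig(fp8_recipe=...).
        "fp8_recipe": "mxfp8",
    },
    "distributed_data_parallel_config": {
        # Forwarded to OptimizerConfig and to _handle_mxfp8_param_buffer_copy.
        "overlap_param_gather": True,
        # ... existing DistributedDataParallelConfig kwargs ...
    },
}
```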