62 changes: 62 additions & 0 deletions examples/configs/gdpo_math_1B.yaml
@@ -0,0 +1,62 @@
# GDPO: inherits from grpo_math_1B.yaml and overrides only what differs.
defaults: grpo_math_1B.yaml

grpo:
  adv_estimator:
    name: "gdpo"
    normalize_rewards: true
    use_leave_one_out_baseline: false

checkpointing:
  checkpoint_dir: "results/gdpo"

policy:
  model_name: "Qwen/Qwen2.5-1.5B-Instruct"
  logprob_batch_size: 4
  max_total_sequence_length: 1024
  megatron_cfg:
    optimizer:
      weight_decay: 0.0
    scheduler:
      lr_decay_style: "cosine"
      lr_warmup_iters: 10

# GDPO uses a single flat data config (GSM8K + math_gdpo_data_processor); replace parent's train/validation/default.
data:
  _override_: true

  max_input_seq_length: ${policy.max_total_sequence_length}
  shuffle: true
  num_workers: 1

  use_multiple_dataloader: false

  train:
    dataset_name: "gsm8k"
    split: train
  validation:
    dataset_name: "gsm8k"
    split: test

  default:
    prompt_file: null
    system_prompt_file: "examples/prompts/gsm8k.txt"
    processor: "math_gdpo_data_processor"
    env_name: "math_multi_reward"

env:
  math_multi_reward:
    num_workers: 8
    math_verify_impl: "hf_math_verify"

logger:
  wandb_enabled: true
  wandb:
    project: "gdpo-dev"
    name: "gdpo-dev-logger"
  swanlab:
    project: "gdpo-dev"
    name: "gdpo-dev-logger"
  mlflow:
    experiment_name: "gdpo-dev"
    run_name: "gdpo-dev-logger"
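
To sanity-check these overrides without launching a run, the file can be loaded directly. A minimal sketch, assuming PyYAML; it reads only this child file and does not resolve the defaults: inheritance or ${...} interpolation, which the repo's config loader is assumed to handle:

    import yaml

    with open("examples/configs/gdpo_math_1B.yaml") as f:
        cfg = yaml.safe_load(f)

    assert cfg["grpo"]["adv_estimator"]["name"] == "gdpo"
    assert cfg["data"]["_override_"] is True  # per the comment above: data replaces the parent's section
    print(cfg["env"]["math_multi_reward"])    # {'num_workers': 8, 'math_verify_impl': 'hf_math_verify'}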
17 changes: 17 additions & 0 deletions examples/prompts/gsm8k.txt
@@ -0,0 +1,17 @@
You are a helpful AI assistant.

For every request, you should carefully think through the math problem step by step, then provide the final answer in integer format.

Steps for Each Request:
1. Think: Provide detailed, step-by-step reasoning, calculations, or derivations.
2. Produce Final Answer: After step-by-step reasoning, output the final answer in integer format.

Output Format:
<think>Your thoughts and reasoning</think>
<answer>Final answer in integer format</answer>

Important Notes:
1. You must include your reasoning steps inside <think>.
2. You must always output the final answer within <answer> after the reasoning steps are done.
3. You should consistently work through the solution step by step before giving the final answer.
4. The final answer can only be an integer.
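
For illustration, here is a hypothetical response (not part of the prompt file) that satisfies the required format:

<think>There are 3 pens with 12 chickens each, so 3 x 12 = 36 chickens in total.</think>
<answer>36</answer>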
7 changes: 7 additions & 0 deletions examples/run_grpo.py
@@ -139,6 +139,13 @@ def main() -> None:
            "use_multiple_dataloader is not supported with async GRPO"
        )

        # Async GDPO is not supported
        if config["grpo"]["adv_estimator"]["name"] == "gdpo":
            raise NotImplementedError(
                "GDPO is not supported for async training, "
                "please set grpo.async_grpo.enabled to false in your config."
            )

        from nemo_rl.algorithms.grpo import async_grpo_train

        print("🚀 Running async GRPO training")
95 changes: 79 additions & 16 deletions nemo_rl/algorithms/advantage_estimator.py
@@ -16,6 +16,7 @@
This module provides different advantage estimation strategies:
- GRPOAdvantageEstimator: Standard GRPO advantage with leave-one-out baseline
- GDPOAdvantageEstimator: Multi-reward GDPO (per-component baselines, sum then normalize)
- ReinforcePlusPlusAdvantageEstimator: Reinforce++ with optional baseline subtraction (minus_baseline) and KL penalty in reward
Reference papers:
- ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/
@@ -24,8 +25,7 @@

import torch

from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl
from nemo_rl.algorithms.utils import (
    calculate_baseline_and_std_per_prompt,
    calculate_kl,
    get_gdpo_reward_component_keys,
)


class GRPOAdvantageEstimator:
"""GRPO-style advantage estimator with leave-one-out baseline.
@@ -37,19 +37,20 @@ def __init__(self, estimator_config: dict, loss_config: dict):
        self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"]
        self.normalize_rewards = estimator_config["normalize_rewards"]

    def compute_advantage(self, prompt_ids, rewards, mask, **kwargs):
    def compute_advantage(self, repeated_batch, mask, **kwargs):
        """Compute GRPO advantages.

        Args:
            prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to.
            rewards: Tensor of shape [batch_size] containing reward for each sample.
            repeated_batch: Batch containing _input_ids_for_baseline and total_reward.
            mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
                Used only for expanding advantages to token-level shape.
            **kwargs: Additional arguments (unused).

        Returns:
            Advantages tensor of shape [batch_size, seq_len].
        """
        prompt_ids = repeated_batch["_input_ids_for_baseline"]
        rewards = repeated_batch["total_reward"]
        baseline, std = calculate_baseline_and_std_per_prompt(
            prompt_ids,
            rewards,
@@ -69,6 +70,75 @@ def compute_advantage(self, prompt_ids, rewards, mask, **kwargs):
        return advantages.expand(mask.shape)


class GDPOAdvantageEstimator:
    """GDPO-style advantage estimator with leave-one-out baseline.

    Note: GDPO computes advantages for each reward component separately over all responses for each prompt.
    """

    def __init__(self, estimator_config: dict, loss_config: dict):
        self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"]
        self.normalize_rewards = estimator_config["normalize_rewards"]

    def compute_advantage(self, repeated_batch, mask, **kwargs):
        """Compute GDPO advantages.

        Args:
            repeated_batch: Batch containing _input_ids_for_baseline and reward1, reward2, ... keys.
            mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
                Used only for expanding advantages to token-level shape.
            **kwargs: Additional arguments (unused).

        Returns:
            Advantages tensor of shape [batch_size, seq_len].
        """
        reward_component_keys = get_gdpo_reward_component_keys(repeated_batch)
        if len(reward_component_keys) < 2:
            raise ValueError(
                "GDPO requires multiple reward components (reward1, reward2, ...). "
                f"This batch has {len(reward_component_keys)} component(s). "
                "Switch to GRPO by setting grpo.adv_estimator.name to 'grpo' in your config."
            )
        current_input_ids = repeated_batch["_input_ids_for_baseline"]
        valid = torch.ones_like(repeated_batch[reward_component_keys[0]])
        leave_one_out = self.use_leave_one_out_baseline
        assert current_input_ids.shape[0] == valid.shape[0], (
            "_input_ids_for_baseline must match reward batch size after dynamic_sampling; "
            f"got {current_input_ids.shape[0]} vs {valid.shape[0]}"
        )

        # Per-component advantage: subtract the per-prompt baseline, then
        # (optionally) normalize by the per-prompt std of that component.
        advantage_parts = []
        for key in reward_component_keys:
            r = repeated_batch[key]
            base, std_k = calculate_baseline_and_std_per_prompt(
                current_input_ids,
                r,
                valid,
                leave_one_out_baseline=leave_one_out,
            )
            adv_k = (r - base).unsqueeze(-1)
            if self.normalize_rewards:
                epsilon = 1e-6
                non_zero_std_mask = std_k > 0
                adv_k[non_zero_std_mask] = adv_k[non_zero_std_mask] / (
                    std_k.unsqueeze(-1)[non_zero_std_mask] + epsilon
                )
            advantage_parts.append(adv_k)

        # Sum the per-component advantages, then normalize the combined
        # advantage to zero mean and unit std.
        advantages = sum(advantage_parts)
        adv_std = advantages.std()
        if adv_std > 0:
            advantages = (advantages - advantages.mean()) / adv_std
        else:
            advantages = advantages - advantages.mean()

        return advantages.expand(mask.shape)
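
To make the per-component arithmetic above concrete, here is a minimal standalone sketch with hypothetical rewards; a plain per-prompt mean/std stands in for calculate_baseline_and_std_per_prompt, so the numbers differ from the leave-one-out variant:

    import torch

    # Two prompts, two sampled responses each; two reward components per response.
    prompt_ids = torch.tensor([0, 0, 1, 1])
    reward1 = torch.tensor([1.0, 0.0, 1.0, 1.0])  # e.g. answer correctness
    reward2 = torch.tensor([0.5, 0.9, 0.2, 0.4])  # e.g. format adherence

    def per_prompt_baseline_std(prompt_ids, r):
        # Plain per-prompt mean/std, standing in for calculate_baseline_and_std_per_prompt.
        base, std = torch.zeros_like(r), torch.zeros_like(r)
        for p in prompt_ids.unique():
            sel = prompt_ids == p
            base[sel] = r[sel].mean()
            std[sel] = r[sel].std(unbiased=False)
        return base, std

    parts = []
    for r in (reward1, reward2):
        base, std = per_prompt_baseline_std(prompt_ids, r)
        adv = (r - base).unsqueeze(-1)
        nz = std > 0
        adv[nz] = adv[nz] / (std.unsqueeze(-1)[nz] + 1e-6)  # per-component normalization
        parts.append(adv)

    combined = sum(parts)
    combined = (combined - combined.mean()) / combined.std()  # combined normalization
    print(combined.squeeze(-1))  # one scalar advantage per response, expanded over tokens later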


class ReinforcePlusPlusAdvantageEstimator:
"""Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward.
@@ -83,20 +153,11 @@ def __init__(self, estimator_config: dict, loss_config: dict):
        self.kl_coef = loss_config["reference_policy_kl_penalty"]
        self.kl_type = loss_config["reference_policy_kl_type"]

    def compute_advantage(
        self,
        prompt_ids,
        rewards,
        mask,
        logprobs_policy=None,
        logprobs_reference=None,
        **kwargs,
    ):
    def compute_advantage(self, repeated_batch, mask, logprobs_policy=None, logprobs_reference=None, **kwargs):
        """Compute Reinforce++ advantages with optional KL penalty.

        Args:
            prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to.
            rewards: Tensor of shape [batch_size] containing reward for each sample.
            repeated_batch: Batch containing _input_ids_for_baseline and total_reward.
            mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
                Used for: (1) expanding advantages to token-level shape, (2) global normalization
                that only considers valid tokens.
Expand All @@ -107,6 +168,8 @@ def compute_advantage(
Returns:
Advantages tensor of shape [batch_size, seq_len], globally normalized across valid tokens.
"""
prompt_ids = repeated_batch["_input_ids_for_baseline"]
rewards = repeated_batch["total_reward"]
# minus baseline
if self.minus_baseline:
mean, _ = calculate_baseline_and_std_per_prompt(
45 changes: 32 additions & 13 deletions nemo_rl/algorithms/grpo.py
@@ -13,6 +13,7 @@
# limitations under the License.
import gc
import os
import re
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
@@ -29,6 +30,7 @@

from nemo_rl.algorithms.advantage_estimator import (
    GRPOAdvantageEstimator,
    GDPOAdvantageEstimator,
    ReinforcePlusPlusAdvantageEstimator,
)
from nemo_rl.algorithms.loss import (
@@ -46,6 +48,7 @@
    log_generation_metrics_to_wandb,
    print_performance_metrics,
    set_seed,
    get_gdpo_reward_component_keys,
)
from nemo_rl.data import DataConfig
from nemo_rl.data.collate_fn import rl_collate_fn
@@ -121,9 +124,9 @@ class AsyncGRPOConfig(TypedDict):


class AdvEstimatorConfig(TypedDict):
    """Configuration for advantage estimator (GRPO or Reinforce++)."""
    """Configuration for advantage estimator (GRPO, GDPO, or Reinforce++)."""

    name: str  # "grpo" or "reinforce_plus_plus"
    name: str  # "grpo", "gdpo", or "reinforce_plus_plus"
    # GRPO specific
    normalize_rewards: NotRequired[bool]
    use_leave_one_out_baseline: NotRequired[bool]
@@ -966,11 +969,16 @@ def scale_rewards(
    )

    # Clamp and scale
    rewards = torch.clamp(rewards, min=source_min, max=source_max)
    scaled_rewards = target_min + (rewards - source_min) / (
        source_max - source_min
    ) * (target_max - target_min)
    def _scale(reward_tensor: torch.Tensor) -> torch.Tensor:
        r = torch.clamp(reward_tensor, min=source_min, max=source_max)
        return target_min + (r - source_min) / (
            source_max - source_min
        ) * (target_max - target_min)

    scaled_rewards = _scale(rewards)
    repeated_batch["total_reward"] = scaled_rewards
    # Apply the same mapping to each GDPO reward component.
    for key in get_gdpo_reward_component_keys(repeated_batch):
        repeated_batch[key] = _scale(repeated_batch[key])

    return repeated_batch
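
As a quick numeric check of the mapping, with assumed example bounds (not values from any config): clamping into the source range happens first, then the linear rescale into the target range.

    import torch

    source_min, source_max = 0.0, 1.0    # assumed source range
    target_min, target_max = -1.0, 1.0   # assumed target range

    def _scale(reward_tensor: torch.Tensor) -> torch.Tensor:
        r = torch.clamp(reward_tensor, min=source_min, max=source_max)
        return target_min + (r - source_min) / (source_max - source_min) * (target_max - target_min)

    print(_scale(torch.tensor([0.0, 0.25, 1.0, 1.5])))
    # tensor([-1.0000, -0.5000,  1.0000,  1.0000]); 1.5 is clamped to 1.0 before rescaling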

@@ -1031,7 +1039,7 @@ def _create_advantage_estimator(master_config: MasterConfig):
        master_config: The master configuration dictionary.

    Returns:
        An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator).
        An advantage estimator instance (GRPO, GDPO, or ReinforcePlusPlus).

    Raises:
        ValueError: If the advantage estimator name is not recognized.
@@ -1055,7 +1063,14 @@
    )

    adv_estimator_name = adv_estimator_config["name"]
    if adv_estimator_name == "grpo":
    if adv_estimator_name == "gdpo":
        assert not _should_use_async_rollouts(master_config), (
            "GDPO is not supported for async rollouts, "
            "please set policy.generation.vllm_cfg.async_engine to false in your config."
        )
        adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config)
        print(" ✓ Using GDPO advantage estimator (multi-reward)")
    elif adv_estimator_name == "grpo":
        adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config)
        print(" ✓ Using GRPO advantage estimator")
    elif adv_estimator_name == "reinforce_plus_plus":
@@ -1590,6 +1605,10 @@ def grpo_train(
        with timer.time("reward_calculation"):
            # Extract rewards from final_batch
            rewards = repeated_batch["total_reward"]
            # Store input_ids in the batch so that after dynamic_sampling it stays aligned with
            # the (possibly filtered) batch: select_indices / from_batches / slice all
            # apply to this key, so per-reward baselines use the same prompts as reward components.
            repeated_batch["_input_ids_for_baseline"] = input_ids

        print("▶ Computing advantages...", flush=True)
        if master_config["grpo"].get("calculate_advantages_on_gpu"):
@@ -1644,10 +1663,10 @@
            # If the current batch is not enough to fill the buffer during dynamic sampling,
            # we update the cache and process the next batch.
            if not is_batch_complete:
                continue

            gen_step_metrics = {}
            if hasattr(policy_generation, "get_step_metrics"):
                gen_step_metrics = policy_generation.get_step_metrics()
            advantages = (rewards - baseline).unsqueeze(-1)

            # Save baseline for logging (before deletion)
            baseline_for_log = baseline.clone()
@@ -1776,8 +1795,7 @@
            mask = token_mask * sample_mask.unsqueeze(-1)

            train_data["advantages"] = adv_estimator.compute_advantage(
                prompt_ids=prompt_ids_for_adv,
                rewards=rewards,
                repeated_batch=repeated_batch,
                mask=mask,
                logprobs_policy=train_data["prev_logprobs"],
                logprobs_reference=train_data.get("reference_policy_logprobs"),
@@ -2724,6 +2742,8 @@ def async_grpo_train(
            del prompt_batched_flat

            rewards = repeated_batch["total_reward"]
            # All estimators read _input_ids_for_baseline from repeated_batch
            repeated_batch["_input_ids_for_baseline"] = prompt_ids_for_adv

            print(
                f" 📊 Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}"
@@ -2807,8 +2827,7 @@
            mask = token_mask * sample_mask.unsqueeze(-1)

            train_data["advantages"] = adv_estimator.compute_advantage(
                prompt_ids=prompt_ids_for_adv,
                rewards=rewards,
                repeated_batch=repeated_batch,
                mask=mask,
                logprobs_policy=train_data["prev_logprobs"],
                logprobs_reference=train_data.get("reference_policy_logprobs"),
5 changes: 5 additions & 0 deletions nemo_rl/algorithms/utils.py
@@ -29,7 +29,12 @@
from nemo_rl.data.chat_templates import COMMON_CHAT_TEMPLATES
from nemo_rl.models.policy import TokenizerConfig
from nemo_rl.utils.logger import Logger
import re


def get_gdpo_reward_component_keys(batch) -> list:
    """Return batch keys that are reward components (reward1, reward2, ...) in sorted order."""
    keys = [k for k in batch.keys() if re.match(r"reward\d+$", str(k))]
    return sorted(keys, key=lambda k: int(re.search(r"\d+", str(k)).group()))
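
A quick usage sketch with hypothetical batch keys; the numeric sort keeps reward10 after reward2, where a plain lexicographic sort would misplace it:

    batch = {
        "total_reward": 1.0,
        "reward2": 0.3,
        "reward1": 0.7,
        "reward10": 0.1,
        "rewards": 0.0,  # ignored: does not match reward\d+$
    }
    print(get_gdpo_reward_component_keys(batch))
    # ['reward1', 'reward2', 'reward10']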

def calculate_kl(
    logprobs: torch.Tensor,