diff --git a/examples/configs/gdpo_math_1B.yaml b/examples/configs/gdpo_math_1B.yaml
new file mode 100644
index 0000000000..47450a17a7
--- /dev/null
+++ b/examples/configs/gdpo_math_1B.yaml
@@ -0,0 +1,62 @@
+# GDPO: inherits from grpo_math_1B.yaml and overrides only what differs.
+defaults: grpo_math_1B.yaml
+
+grpo:
+  adv_estimator:
+    name: "gdpo"
+    normalize_rewards: true
+    use_leave_one_out_baseline: false
+
+checkpointing:
+  checkpoint_dir: "results/gdpo"
+
+policy:
+  model_name: "Qwen/Qwen2.5-1.5B-Instruct"
+  logprob_batch_size: 4
+  max_total_sequence_length: 1024
+  megatron_cfg:
+    optimizer:
+      weight_decay: 0.0
+    scheduler:
+      lr_decay_style: "cosine"
+      lr_warmup_iters: 10
+
+# GDPO uses a single flat data config (GSM8K + math_gdpo_data_processor); it replaces the parent's train/validation/default sections.
+data:
+  _override_: true
+
+  max_input_seq_length: ${policy.max_total_sequence_length}
+  shuffle: true
+  num_workers: 1
+
+  use_multiple_dataloader: false
+
+  train:
+    dataset_name: "gsm8k"
+    split: train
+  validation:
+    dataset_name: "gsm8k"
+    split: test
+
+  default:
+    prompt_file: null
+    system_prompt_file: "examples/prompts/gsm8k.txt"
+    processor: "math_gdpo_data_processor"
+    env_name: "math_multi_reward"
+
+env:
+  math_multi_reward:
+    num_workers: 8
+    math_verify_impl: "hf_math_verify"
+
+logger:
+  wandb_enabled: true
+  wandb:
+    project: "gdpo-dev"
+    name: "gdpo-dev-logger"
+  swanlab:
+    project: "gdpo-dev"
+    name: "gdpo-dev-logger"
+  mlflow:
+    experiment_name: "gdpo-dev"
+    run_name: "gdpo-dev-logger"
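+
+# Example launch (illustrative; mirrors the overrides used in tests/functional/gdpo.sh):
+#   uv run examples/run_grpo.py --config examples/configs/gdpo_math_1B.yaml \
+#     grpo.num_prompts_per_step=2 grpo.num_generations_per_prompt=4 cluster.gpus_per_node=2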
diff --git a/examples/prompts/gsm8k.txt b/examples/prompts/gsm8k.txt
new file mode 100644
index 0000000000..3c31977100
--- /dev/null
+++ b/examples/prompts/gsm8k.txt
@@ -0,0 +1,17 @@
+You are a helpful AI assistant.
+
+For every request, you should carefully think through the math problem step by step, then provide the final answer in integer format.
+
+Steps for Each Request:
+1. Think: Provide detailed, step-by-step reasoning, calculations, or derivations.
+2. Produce Final Answer: After step-by-step reasoning, output the final answer in integer format.
+
+Output Format:
+<think>Your thoughts and reasoning</think>
+<answer>Final answer in integer format</answer>
+
+Important Notes:
+1. You must include your reasoning steps inside <think></think>.
+2. You must always output the final answer within <answer></answer> after the reasoning is done.
+3. You should consistently work through the solution step by step before giving the final answer.
+4. The final answer can only be an integer.
\ No newline at end of file
diff --git a/examples/run_grpo.py b/examples/run_grpo.py
index 6130b99018..0e9f8bf24a 100644
--- a/examples/run_grpo.py
+++ b/examples/run_grpo.py
@@ -139,6 +139,13 @@ def main() -> None:
"use_multiple_dataloader is not supported with async GRPO"
)
+ # Async GDPO is not supported
+ if config["grpo"]["adv_estimator"]["name"] == "gdpo":
+ raise NotImplementedError(
+ "GDPO is not supported for async training, "
+ "please set grpo.async_grpo.enabled to false in your config."
+ )
+
from nemo_rl.algorithms.grpo import async_grpo_train
print("🚀 Running async GRPO training")
diff --git a/nemo_rl/algorithms/advantage_estimator.py b/nemo_rl/algorithms/advantage_estimator.py
index 6d14d8637a..976c142a44 100644
--- a/nemo_rl/algorithms/advantage_estimator.py
+++ b/nemo_rl/algorithms/advantage_estimator.py
@@ -16,6 +16,7 @@
This module provides different advantage estimation strategies:
- GRPOAdvantageEstimator: Standard GRPO advantage with leave-one-out baseline
+- GDPOAdvantageEstimator: Multi-reward GDPO (per-component baselines, sum then normalize)
- ReinforcePlusPlusAdvantageEstimator: Reinforce++ with optional baseline subtraction (minus_baseline) and KL penalty in reward
Reference papers:
- ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/
@@ -24,8 +25,7 @@
import torch
-from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl
+from nemo_rl.algorithms.utils import (
+    calculate_baseline_and_std_per_prompt,
+    calculate_kl,
+    get_gdpo_reward_component_keys,
+)
class GRPOAdvantageEstimator:
"""GRPO-style advantage estimator with leave-one-out baseline.
@@ -37,12 +37,11 @@ def __init__(self, estimator_config: dict, loss_config: dict):
self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"]
self.normalize_rewards = estimator_config["normalize_rewards"]
- def compute_advantage(self, prompt_ids, rewards, mask, **kwargs):
+ def compute_advantage(self, repeated_batch, mask, **kwargs):
"""Compute GRPO advantages.
Args:
- prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to.
- rewards: Tensor of shape [batch_size] containing reward for each sample.
+ repeated_batch: Batch containing _input_ids_for_baseline and total_reward.
mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
Used only for expanding advantages to token-level shape.
**kwargs: Additional arguments (unused).
@@ -50,6 +49,8 @@ def compute_advantage(self, prompt_ids, rewards, mask, **kwargs):
Returns:
Advantages tensor of shape [batch_size, seq_len].
"""
+ prompt_ids = repeated_batch["_input_ids_for_baseline"]
+ rewards = repeated_batch["total_reward"]
baseline, std = calculate_baseline_and_std_per_prompt(
prompt_ids,
rewards,
@@ -69,6 +70,75 @@ def compute_advantage(self, prompt_ids, rewards, mask, **kwargs):
return advantages.expand(mask.shape)
+class GDPOAdvantageEstimator:
+    """GDPO-style advantage estimator over multiple reward components.
+
+    Note: GDPO computes a per-prompt baseline and advantage for each reward
+    component separately, sums the per-component advantages, and renormalizes
+    the sum to zero mean and unit std.
+    """
+
+    def __init__(self, estimator_config: dict, loss_config: dict):
+        self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"]
+        self.normalize_rewards = estimator_config["normalize_rewards"]
+
+    def compute_advantage(self, repeated_batch, mask, **kwargs):
+ """Compute GDPO advantages.
+
+ Args:
+ repeated_batch: Batch containing _input_ids_for_baseline and reward1, reward2, ... keys.
+ mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
+ Used only for expanding advantages to token-level shape.
+ **kwargs: Additional arguments (unused).
+
+ Returns:
+ Advantages tensor of shape [batch_size, seq_len].
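+
+        Example (illustrative; one prompt with two samples, normalize_rewards=True,
+        mean baseline):
+            reward1 = [1, 0] -> baseline 0.5, normalized part ~ [+0.71, -0.71]
+            reward2 = [1, 1] -> zero std, normalization skipped, part [0, 0]
+            The summed parts, renormalized to zero mean and unit std, remain
+            ~ [+0.71, -0.71].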
+ """
+ reward_component_keys = get_gdpo_reward_component_keys(repeated_batch)
+ if len(reward_component_keys) < 2:
+ raise ValueError(
+ f"GDPO requires multiple reward components (reward1, reward2, ...). "
+ f"This batch has {len(reward_component_keys)} component(s). "
+ "Switch to GRPO by setting grpo.adv_estimator.name to 'grpo' in your config."
+ )
+ current_input_ids = repeated_batch["_input_ids_for_baseline"]
+ valid = torch.ones_like(
+ repeated_batch[reward_component_keys[0]]
+ )
+ leave_one_out = self.use_leave_one_out_baseline
+ assert current_input_ids.shape[0] == valid.shape[0], (
+ "_input_ids_for_baseline must match reward batch size after dynamic_sampling; "
+ f"got {current_input_ids.shape[0]} vs {valid.shape[0]}"
+ )
+ advantage_parts = []
+ for key in reward_component_keys:
+ r = repeated_batch[key]
+ base, std_k = calculate_baseline_and_std_per_prompt(
+ current_input_ids,
+ r,
+ valid,
+ leave_one_out_baseline=leave_one_out,
+ )
+ adv_k = (r - base).unsqueeze(-1)
+ if self.normalize_rewards:
+
+ epsilon = 1e-6
+ non_zero_std_mask = std_k > 0
+ adv_k[non_zero_std_mask] = adv_k[non_zero_std_mask] / (
+ std_k.unsqueeze(-1)[non_zero_std_mask] + epsilon
+ )
+
+ advantage_parts.append(adv_k)
+
+ advantages = sum(advantage_parts)
+ # Normalize combined advantage to zero mean and unit std
+ adv_std = advantages.std()
+ if adv_std > 0:
+ advantages = (advantages - advantages.mean()) / adv_std
+ else:
+ advantages = advantages - advantages.mean()
+
+ return advantages.expand(mask.shape)
+
+
class ReinforcePlusPlusAdvantageEstimator:
"""Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward.
@@ -83,20 +153,11 @@ def __init__(self, estimator_config: dict, loss_config: dict):
self.kl_coef = loss_config["reference_policy_kl_penalty"]
self.kl_type = loss_config["reference_policy_kl_type"]
- def compute_advantage(
- self,
- prompt_ids,
- rewards,
- mask,
- logprobs_policy=None,
- logprobs_reference=None,
- **kwargs,
- ):
+ def compute_advantage(self, repeated_batch, mask, logprobs_policy=None, logprobs_reference=None, **kwargs):
"""Compute Reinforce++ advantages with optional KL penalty.
Args:
- prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to.
- rewards: Tensor of shape [batch_size] containing reward for each sample.
+ repeated_batch: Batch containing _input_ids_for_baseline and total_reward.
mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding.
Used for: (1) expanding advantages to token-level shape, (2) global normalization
that only considers valid tokens.
@@ -107,6 +168,8 @@ def compute_advantage(
Returns:
Advantages tensor of shape [batch_size, seq_len], globally normalized across valid tokens.
"""
+ prompt_ids = repeated_batch["_input_ids_for_baseline"]
+ rewards = repeated_batch["total_reward"]
# minus baseline
if self.minus_baseline:
mean, _ = calculate_baseline_and_std_per_prompt(
diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index c060a05a50..f5b8892ed7 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -13,6 +13,7 @@
# limitations under the License.
import gc
import os
+import re
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
@@ -29,6 +30,7 @@
from nemo_rl.algorithms.advantage_estimator import (
GRPOAdvantageEstimator,
+ GDPOAdvantageEstimator,
ReinforcePlusPlusAdvantageEstimator,
)
from nemo_rl.algorithms.loss import (
@@ -46,6 +48,7 @@
log_generation_metrics_to_wandb,
print_performance_metrics,
set_seed,
+ get_gdpo_reward_component_keys,
)
from nemo_rl.data import DataConfig
from nemo_rl.data.collate_fn import rl_collate_fn
@@ -121,9 +124,9 @@ class AsyncGRPOConfig(TypedDict):
class AdvEstimatorConfig(TypedDict):
- """Configuration for advantage estimator (GRPO or Reinforce++)."""
+ """Configuration for advantage estimator (GRPO, GDPO, or Reinforce++)."""
- name: str # "grpo" or "reinforce_plus_plus"
+ name: str # "grpo", "gdpo", or "reinforce_plus_plus"
# GRPO specific
normalize_rewards: NotRequired[bool]
use_leave_one_out_baseline: NotRequired[bool]
@@ -966,11 +969,16 @@ def scale_rewards(
)
# Clamp and scale
- rewards = torch.clamp(rewards, min=source_min, max=source_max)
- scaled_rewards = target_min + (rewards - source_min) / (
- source_max - source_min
- ) * (target_max - target_min)
+ def _scale(reward_tensor: torch.Tensor) -> torch.Tensor:
+ r = torch.clamp(reward_tensor, min=source_min, max=source_max)
+ return target_min + (r - source_min) / (
+ source_max - source_min
+ ) * (target_max - target_min)
+
+ scaled_rewards = _scale(rewards)
repeated_batch["total_reward"] = scaled_rewards
+ for key in get_gdpo_reward_component_keys(repeated_batch):
+ repeated_batch[key] = _scale(repeated_batch[key])
return repeated_batch
@@ -1031,7 +1039,7 @@ def _create_advantage_estimator(master_config: MasterConfig):
master_config: The master configuration dictionary.
Returns:
- An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator).
+ An advantage estimator instance (GRPO, GDPO, or ReinforcePlusPlus).
Raises:
ValueError: If the advantage estimator name is not recognized.
@@ -1055,7 +1063,14 @@ def _create_advantage_estimator(master_config: MasterConfig):
)
adv_estimator_name = adv_estimator_config["name"]
- if adv_estimator_name == "grpo":
+ if adv_estimator_name == "gdpo":
+ assert not _should_use_async_rollouts(master_config), (
+ "GDPO is not supported for async rollouts, "
+ "please set policy.generation.vllm_cfg.async_engine to false in your config."
+ )
+ adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config)
+ print(" ✓ Using GDPO advantage estimator (multi-reward)")
+ elif adv_estimator_name == "grpo":
adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config)
print(" ✓ Using GRPO advantage estimator")
elif adv_estimator_name == "reinforce_plus_plus":
@@ -1590,6 +1605,10 @@ def grpo_train(
with timer.time("reward_calculation"):
# Extract rewards from final_batch
rewards = repeated_batch["total_reward"]
+ # Store input_ids in the batch so that it stays aligned with the (possibly
+ # filtered) batch after dynamic_sampling: select_indices / from_batches / slice
+ # all apply to this key, so per-reward baselines use the same prompts as the reward components.
+ repeated_batch["_input_ids_for_baseline"] = input_ids
print("â–¶ Computing advantages...", flush=True)
if master_config["grpo"].get("calculate_advantages_on_gpu"):
@@ -1644,10 +1663,10 @@ def grpo_train(
# If the current batch is not enough to fill the buffer during dynamic sampling, we update the cache and process the next batch.
if not is_batch_complete:
continue
+
gen_step_metrics = {}
if hasattr(policy_generation, "get_step_metrics"):
gen_step_metrics = policy_generation.get_step_metrics()
- advantages = (rewards - baseline).unsqueeze(-1)
# Save baseline for logging (before deletion)
baseline_for_log = baseline.clone()
@@ -1776,8 +1795,7 @@ def grpo_train(
mask = token_mask * sample_mask.unsqueeze(-1)
train_data["advantages"] = adv_estimator.compute_advantage(
- prompt_ids=prompt_ids_for_adv,
- rewards=rewards,
+ repeated_batch=repeated_batch,
mask=mask,
logprobs_policy=train_data["prev_logprobs"],
logprobs_reference=train_data.get("reference_policy_logprobs"),
@@ -2724,6 +2742,8 @@ def async_grpo_train(
del prompt_batched_flat
rewards = repeated_batch["total_reward"]
+ # All estimators read _input_ids_for_baseline from repeated_batch
+ repeated_batch["_input_ids_for_baseline"] = prompt_ids_for_adv
print(
f" 📊 Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}"
@@ -2807,8 +2827,7 @@ def async_grpo_train(
mask = token_mask * sample_mask.unsqueeze(-1)
train_data["advantages"] = adv_estimator.compute_advantage(
- prompt_ids=prompt_ids_for_adv,
- rewards=rewards,
+ repeated_batch=repeated_batch,
mask=mask,
logprobs_policy=train_data["prev_logprobs"],
logprobs_reference=train_data.get("reference_policy_logprobs"),
diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py
index 8e632ca5ee..bcc2281544 100644
--- a/nemo_rl/algorithms/utils.py
+++ b/nemo_rl/algorithms/utils.py
@@ -29,7 +29,12 @@
from nemo_rl.data.chat_templates import COMMON_CHAT_TEMPLATES
from nemo_rl.models.policy import TokenizerConfig
from nemo_rl.utils.logger import Logger
+import re
+
+
+def get_gdpo_reward_component_keys(batch) -> list[str]:
+    """Return batch keys that are reward components (reward1, reward2, ...), sorted numerically."""
+    keys = [k for k in batch.keys() if re.match(r"reward\d+$", str(k))]
+    return sorted(keys, key=lambda k: int(re.search(r"\d+", str(k)).group()))
+
+
def calculate_kl(
logprobs: torch.Tensor,
diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py
index eb48bb5204..d53a422bbb 100644
--- a/nemo_rl/data/datasets/response_datasets/__init__.py
+++ b/nemo_rl/data/datasets/response_datasets/__init__.py
@@ -20,6 +20,7 @@
DAPOMath17KDataset,
DAPOMathAIME2024Dataset,
)
+from nemo_rl.data.datasets.response_datasets.gsm8k import GSM8KDataset
from nemo_rl.data.datasets.response_datasets.deepscaler import DeepScalerDataset
from nemo_rl.data.datasets.response_datasets.general_conversations_dataset import (
GeneralConversationsJsonlDataset,
@@ -55,6 +56,7 @@
"refcoco": RefCOCODataset,
"squad": SquadDataset,
"tulu3_sft_mixture": Tulu3SftMixtureDataset,
+ "gsm8k": GSM8KDataset,
# load from local JSONL file or HuggingFace
"openai_format": OpenAIFormatDataset,
"NemoGymDataset": NemoGymDataset,
@@ -94,6 +96,7 @@ def load_response_dataset(data_config: ResponseDatasetConfig):
"GeneralConversationsJsonlDataset",
"DAPOMath17KDataset",
"DAPOMathAIME2024Dataset",
+ "GSM8KDataset",
"DeepScalerDataset",
"Geometry3KDataset",
"HelpSteer3Dataset",
diff --git a/nemo_rl/data/datasets/response_datasets/gsm8k.py b/nemo_rl/data/datasets/response_datasets/gsm8k.py
new file mode 100644
index 0000000000..ce3affd869
--- /dev/null
+++ b/nemo_rl/data/datasets/response_datasets/gsm8k.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from datasets import load_dataset
+
+from nemo_rl.data.datasets.raw_dataset import RawDataset
+
+
+def _extract_hash_answer(text: str) -> str | None:
+    if "####" not in text:
+        return None
+    return text.split("####")[1].strip()
+
+
+class GSM8KDataset(RawDataset):
+    """Simple wrapper around the GSM8K dataset.
+
+    Args:
+        split: Split name for the dataset, default is "train"
+        extract_answer: Whether to extract the final answer after "####", default is True
+        system_prompt_file: Accepted for config compatibility and unused here; the
+            system prompt is applied by the data processor.
+    """
+
+    def __init__(
+        self,
+        split: str = "train",
+        extract_answer: bool = True,
+        system_prompt_file: str | None = None,
+        **kwargs,
+    ) -> None:
+        self.task_name = "gsm8k"
+        self.extract_answer = extract_answer
+
+        # load from huggingface
+        self.dataset = load_dataset("openai/gsm8k", "main")[split]
+
+        # format the dataset
+        self.dataset = self.dataset.map(
+            self.format_data,
+            remove_columns=self.dataset.column_names,
+        )
+
+    def format_data(self, data: dict[str, Any]) -> dict[str, Any]:
+        if self.extract_answer:
+            answer = _extract_hash_answer(data["answer"])
+        else:
+            answer = data["answer"]
+
+        return {
+            "messages": [
+                {"role": "user", "content": data["question"]},
+                {"role": "assistant", "content": answer},
+            ],
+            "task_name": self.task_name,
+        }
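+
+
+# Minimal usage sketch (illustrative; downloads from the HF Hub):
+#   ds = GSM8KDataset(split="train")
+#   ds.dataset[0]["messages"][1]["content"]  # "72" for the first train example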
diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py
index 52ac9bf67d..5d3f51cccf 100644
--- a/nemo_rl/data/processors.py
+++ b/nemo_rl/data/processors.py
@@ -381,6 +381,68 @@ def math_data_processor(
return output
+# TODO: @yukih: unify to math_hf_data_processor once https://github.com/NVIDIA-NeMo/RL/issues/2060 is resolved.
+def math_gdpo_data_processor(
+    datum_dict: dict[str, Any],
+    task_data_spec: TaskDataSpec,
+    tokenizer: TokenizerType,
+    max_seq_length: int,
+    idx: int,
+) -> DatumSpec:
+ """Process a datum dictionary (directly loaded from data/hf_datasets/openmathinstruct2.py) into a DatumSpec for the Reward Model Environment."""
+    user_message = datum_dict["messages"]
+    problem = user_message[0]["content"]
+    extra_env_info = {"ground_truth": user_message[1]["content"]}
+
+    # merge system prompt and user prompt
+    message_list = []
+    # system prompt
+    if task_data_spec.system_prompt:
+        message_list.append({
+            "role": "system",
+            "content": task_data_spec.system_prompt,
+        })
+    # user prompt
+    if task_data_spec.prompt:
+        problem = task_data_spec.prompt.format(problem)
+    message_list.append({"role": "user", "content": problem})
+
+    message: str = tokenizer.apply_chat_template(  # type: ignore
+        message_list,
+        tokenize=False,
+        add_generation_prompt=True,
+        add_special_tokens=False,
+    )
+    token_ids = tokenizer(
+        message, return_tensors="pt", add_special_tokens=False
+    )["input_ids"][0]
+
+    message_log: LLMMessageLogType = [
+        {"role": "user", "content": message, "token_ids": token_ids}
+    ]
+
+    length = sum(len(m["token_ids"]) for m in message_log)
+
+    loss_multiplier = 1.0
+    if length > max_seq_length:
+        # make smaller and mask out
+        for chat_message in message_log:
+            chat_message["token_ids"] = chat_message["token_ids"][
+                : min(4, max_seq_length // len(message_log))
+            ]
+        loss_multiplier = 0.0
+
+    output: DatumSpec = {
+        "message_log": message_log,
+        "length": length,
+        "extra_env_info": extra_env_info,
+        "loss_multiplier": loss_multiplier,
+        "idx": idx,
+        "task_name": datum_dict["task_name"],
+    }
+    return output
+
+
def math_hf_data_processor(
datum_dict: dict[str, Any],
task_data_spec: TaskDataSpec,
@@ -698,6 +760,7 @@ def nemo_gym_data_processor(
"helpsteer3_data_processor": helpsteer3_data_processor,
"math_data_processor": math_data_processor,
"math_hf_data_processor": math_hf_data_processor,
+ "math_gdpo_data_processor": math_gdpo_data_processor,
"multichoice_qa_processor": multichoice_qa_processor,
"sft_processor": sft_processor,
"vlm_hf_data_processor": vlm_hf_data_processor,
diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py
index 90d52fe76e..3f02acb4e1 100644
--- a/nemo_rl/distributed/ray_actor_environment_registry.py
+++ b/nemo_rl/distributed/ray_actor_environment_registry.py
@@ -35,6 +35,7 @@
"nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL,
"nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE,
"nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM,
+ "nemo_rl.environments.math_environment.MathMultiRewardEnvironment" : PY_EXECUTABLES.SYSTEM,
"nemo_rl.environments.vlm_environment.VLMEnvironment": PY_EXECUTABLES.SYSTEM,
"nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM,
"nemo_rl.environments.reward_model_environment.RewardModelEnvironment": PY_EXECUTABLES.SYSTEM,
diff --git a/nemo_rl/environments/interfaces.py b/nemo_rl/environments/interfaces.py
index b869c32df7..a8167581e5 100644
--- a/nemo_rl/environments/interfaces.py
+++ b/nemo_rl/environments/interfaces.py
@@ -44,7 +44,7 @@ class EnvironmentReturn(NamedTuple, Generic[MetadataT]):
observations: list[dict[str, str]]
metadata: list[MetadataT]
next_stop_strings: list[list[str] | None] | list[None]
- rewards: Tensor
+ rewards: Tensor  # (batch,) for single-reward envs; (batch, num_rewards) for multi-reward envs
terminateds: Tensor
answers: list[str | None] | None
diff --git a/nemo_rl/environments/math_environment.py b/nemo_rl/environments/math_environment.py
index 8de2da805a..9f0acf655f 100644
--- a/nemo_rl/environments/math_environment.py
+++ b/nemo_rl/environments/math_environment.py
@@ -230,6 +230,120 @@ def verify(
return results
+@ray.remote  # pragma: no cover
+class HFMultiRewardVerifyWorker:
+    def __init__(self) -> None:
+        logging.getLogger("math_multi_reward_verify").setLevel(logging.CRITICAL)
+
+        # Use Latex and plain math extraction from predictions
+        # https://github.com/huggingface/Math-Verify?tab=readme-ov-file#extraction-targets
+        # Note: kept for parity with HFVerifyWorker; the reward functions below use
+        # string matching on <answer> tags rather than math_verify.
+        self.verify_func = math_metric(
+            gold_extraction_target=(LatexExtractionConfig(),),
+            pred_extraction_target=(
+                ExprExtractionConfig(),
+                LatexExtractionConfig(),
+            ),
+        )
+
+    def verify(
+        self,
+        pred_responses: list[str],
+        ground_truths: list[str],
+        return_extracted_answer: bool = False,
+        **kwargs,
+    ) -> Union[list[list[float]], tuple[list[list[float]], list[str | None]]]:
+        """Verify the correctness of the predicted responses against the ground truth.
+
+        Args:
+            pred_responses: list[str]. The predicted responses from the LLM.
+            ground_truths: list[str]. The ground truth responses.
+
+        Returns:
+            Union[list[list[float]], tuple[list[list[float]], list[str | None]]].
+            If return_extracted_answer is False, returns one score list per reward
+            component (correctness, integer format, format).
+            If return_extracted_answer is True, returns (scores, extracted_answers).
+        """
+        def extract_xml_answer(text: str) -> str:
+            answer = text.split("<answer>")[-1]
+            answer = answer.split("</answer>")[0]
+            return answer.strip()
+
+        def correctness_reward_func(completions, answer, **kwargs) -> list[float]:
+            extracted_responses = [extract_xml_answer(r) for r in completions]
+            return [1.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+        def int_reward_func(completions, **kwargs) -> list[float]:
+            extracted_responses = [extract_xml_answer(r) for r in completions]
+            return [1.0 if r.isdigit() else 0.0 for r in extracted_responses]
+
+        def format_reward_func(completions, **kwargs) -> list[float]:
+            """Reward function that checks if the completion has a specific format."""
+            rewards = []
+            for response in completions:
+                pattern = r"^<think>.*?</think>\n<answer>.*?</answer>$"
+                if (
+                    re.search(pattern, response, re.DOTALL)
+                    and response.count("<answer>") == 1
+                    and response.count("</answer>") == 1
+                ):
+                    rewards.append(1.0)
+                else:
+                    rewards.append(0.0)
+            return rewards
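+        # A conforming completion (illustrative):
+        #   "<think>2 + 2 = 4</think>\n<answer>4</answer>"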
+
+        number_of_rewards = 3
+        results = [[] for _ in range(number_of_rewards)]
+        extracted_answers: list[str | None] = []
+
+        for response, ground_truth in zip(pred_responses, ground_truths):
+            try:
+                math_verify_impl = kwargs.get("math_verify_impl", "hf_math_verify")
+                if math_verify_impl == "hf_math_verify":
+                    cor_reward = correctness_reward_func([response], [ground_truth])
+                    int_reward = int_reward_func([response])
+                    format_reward = format_reward_func([response])
+                    extracted_answer = extract_xml_answer(response)
+                else:
+                    raise ValueError(
+                        f"Unknown math_verify_impl: {math_verify_impl}. Expected 'hf_math_verify'"
+                    )
+
+                results[0].extend(cor_reward)
+                results[1].extend(int_reward)
+                results[2].extend(format_reward)
+
+                if return_extracted_answer:
+                    # Unlike HFVerifyWorker, which gets (gold, prediction) pairs back
+                    # from math_verify, extract_xml_answer returns a plain string.
+                    extracted_answers.append(extracted_answer)
+
+            # math-verify can raise TimeoutException, which subclasses BaseException
+            # rather than Exception, so catch it explicitly.
+            except (Exception, TimeoutException):
+                results[0].append(0.0)
+                results[1].append(0.0)
+                results[2].append(0.0)
+                extracted_answers.append(None)
+
+        # results holds one list per reward component, e.g.
+        # [[1.0, 0.0, ...], [1.0, 1.0, ...], [0.0, 1.0, ...]]
+        if return_extracted_answer:
+            return results, extracted_answers
+        else:
+            return results
+
+
class MathEnvironmentMetadata(TypedDict):
ground_truth: str
extracted_answer: str | None
@@ -391,3 +505,161 @@ def global_post_process_and_metrics(
}
return batch, metrics
+
+
+
+@ray.remote(max_restarts=-1, max_task_retries=-1)  # pragma: no cover
+class MathMultiRewardEnvironment(EnvironmentInterface[MathEnvironmentMetadata]):
+    def __init__(self, cfg: MathEnvConfig):
+        self.cfg = cfg
+        self.num_workers = cfg["num_workers"]
+        # TODO: split out this environment since it's doing more than just math
+        verifier_type = cfg.get("verifier_type", "math")
+        assert isinstance(verifier_type, str), (
+            f"{verifier_type=} must be a string but was {type(verifier_type)}"
+        )
+
+        worker_cls = {
+            "math": HFMultiRewardVerifyWorker,
+        }[verifier_type]
+        self.workers = [
+            worker_cls.options(  # type: ignore # (decorated with @ray.remote)
+                runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM}
+            ).remote()
+            for _ in range(self.num_workers)
+        ]
+
+    def shutdown(self) -> None:
+        # shutdown all workers
+        for worker in self.workers:
+            ray.kill(worker)
+
+    def step(
+        self,
+        message_log_batch: list[LLMMessageLogType],
+        metadata: list[MathEnvironmentMetadata],
+        return_extracted_answer: bool = False,
+    ) -> EnvironmentReturn[MathEnvironmentMetadata]:
+        """Runs a step in the math environment.
+
+        Args:
+            message_log_batch: list[LLMMessageLogType]. A batch of OpenAI-API-like message logs that represent interactions with the LLM.
+            metadata: list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. The extracted answer will be stored to calculate cons@k.
+
+        Returns:
+            EnvironmentReturn: A tuple containing:
+            - list[dict[str, str]]: Observations/responses batch
+            - list[dict]: Updated metadata
+            - list[str]: Next stop strings for the next turn
+            - Tensor: Rewards tensor of shape (batch_size, num_rewards)
+            - Tensor: Done flags tensor
+        """
+        # Extract the assistant's responses from the message history
+        # Each message list should have at least one assistant response
+        assistant_response_batch = []
+        for conversation in message_log_batch:
+            assistant_responses = [
+                str(interaction["content"])
+                for interaction in conversation
+                if interaction["role"] == "assistant"
+            ]
+            assistant_response_batch.append("".join(assistant_responses))
+
+        ground_truths = [g["ground_truth"] for g in metadata]
+
+        chunked_assistant_response_batch = chunk_list_to_workers(
+            assistant_response_batch, self.num_workers
+        )
+        chunked_ground_truths = chunk_list_to_workers(ground_truths, self.num_workers)
+
+        # Process each chunk in parallel
+        futures = [
+            self.workers[i].verify.remote(
+                chunk,
+                ground_truth_chunk,
+                return_extracted_answer,
+                math_verify_impl=self.cfg.get("math_verify_impl", "hf_math_verify"),
+            )
+            for i, (chunk, ground_truth_chunk) in enumerate(
+                zip(chunked_assistant_response_batch, chunked_ground_truths)
+            )
+        ]
+
+        worker_results = ray.get(futures)
+
+        # Flatten the results and extract both scores and answers
+        number_of_rewards = 3  # must match HFMultiRewardVerifyWorker
+        results = [[] for _ in range(number_of_rewards)]
+        extracted_answers: list[str | None] | None = (
+            [] if return_extracted_answer else None
+        )
+
+        for worker_result in worker_results:
+            if return_extracted_answer:
+                raise NotImplementedError(
+                    "return_extracted_answer is not supported by MathMultiRewardEnvironment"
+                )
+            else:
+                for i in range(number_of_rewards):
+                    results[i].extend(worker_result[i])
+
+        observations = [
+            {
+                "role": "environment",
+                "content": "Environment: correct"
+                if result
+                else "Environment: incorrect",
+            }
+            for result in results[0]  # index 0 always stores the correctness reward
+        ]
+
+        # create a tensor of rewards and done flags
+        rewards = torch.tensor(results).T.cpu()  # shape (batch_size, num_rewards)
+        # single-turn environment: every sample terminates after this step
+        done = torch.ones(rewards.shape[0]).cpu()
+        next_stop_strings = [None] * len(message_log_batch)
+
+        return EnvironmentReturn(
+            observations=observations,
+            metadata=metadata,
+            next_stop_strings=next_stop_strings,
+            rewards=rewards,
+            terminateds=done,
+            answers=extracted_answers,
+        )
+
+    def global_post_process_and_metrics(
+        self, batch: BatchedDataDict[Any]
+    ) -> tuple[BatchedDataDict[Any], dict[str, float | int]]:
+        """Computes metrics for this environment given a global rollout batch.
+
+        Every rank will run this function, so you're free to use distributed
+        calculations if you'd prefer for heavy metrics.
+        """
+        batch["rewards"] = (
+            batch["rewards"] * batch["is_end"]
+        )  # set a reward of 0 for any incorrectly ended sequences
+        if (batch["rewards"] == 1).float().sum() > 0:
+            correct_solution_generation_lengths = (
+                (batch["generation_lengths"] - batch["prompt_lengths"])[
+                    batch["rewards"] == 1
+                ]
+                .float()
+                .mean()
+                .item()
+            )
+        else:
+            correct_solution_generation_lengths = 0
+
+        metrics = {
+            # "table": table, TODO @sahilj WIP
+            "accuracy": batch["rewards"].mean().item(),
+            "pass@samples_per_prompt": calculate_pass_rate_per_prompt(
+                batch["text"], batch["rewards"]
+            ),
+            "fraction_of_samples_properly_ended": batch["is_end"].float().mean().item(),
+            "num_problems_in_batch": batch["is_end"].shape[0],
+            "generation_lengths": batch["generation_lengths"].float().mean().item(),
+            "prompt_lengths": batch["prompt_lengths"].float().mean().item(),
+            "correct_solution_generation_lengths": correct_solution_generation_lengths,
+        }
+
+        return batch, metrics
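+
+
+# This environment is registered as "math_multi_reward" (see
+# nemo_rl/environments/utils.py) and configured via env.math_multi_reward in
+# examples/configs/gdpo_math_1B.yaml.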
diff --git a/nemo_rl/environments/utils.py b/nemo_rl/environments/utils.py
index c4227b8631..df82c7d1af 100644
--- a/nemo_rl/environments/utils.py
+++ b/nemo_rl/environments/utils.py
@@ -35,6 +35,9 @@ class EnvRegistryEntry(TypedDict, total=False):
"math": {
"actor_class_fqn": "nemo_rl.environments.math_environment.MathEnvironment",
},
+ "math_multi_reward": {
+ "actor_class_fqn": "nemo_rl.environments.math_environment.MathMultiRewardEnvironment",
+ },
"code": {
"actor_class_fqn": "nemo_rl.environments.code_environment.CodeEnvironment",
},
diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py
index 603a972095..52f3996244 100644
--- a/nemo_rl/experience/rollouts.py
+++ b/nemo_rl/experience/rollouts.py
@@ -327,7 +327,9 @@ def calculate_rewards(
sorted_indices = sorted(
range(len(all_indices_order)), key=lambda k: all_indices_order[k]
)
- rewards = torch.tensor([all_rewards[i] for i in sorted_indices])
+ # Stack rewards: each element may be scalar (single-reward env) or 1d (multi-reward env).
+ # torch.stack preserves shape: scalars -> (N,), shape (K,) -> (N, K).
+ rewards = torch.stack([all_rewards[i] for i in sorted_indices])
env_observations = [all_env_observations[i] for i in sorted_indices]
terminateds = torch.tensor([all_terminateds[i] for i in sorted_indices])
next_stop_strings = [all_next_stop_strings[i] for i in sorted_indices]
@@ -374,6 +376,10 @@ def run_multi_turn_rollout(
active_indices = torch.arange(batch_size)
total_rewards = torch.zeros(batch_size, dtype=torch.float32)
+ # Multi-reward support: number of components is inferred from the first env_output (1 for single-reward envs)
+ number_of_rewards: int | None = None
+ multi_rewards: torch.Tensor | None = None
+
# Initialize stop_strings from the initial batch if present
current_stop_strings = current_batch.get("stop_strings", [None] * batch_size)
@@ -459,7 +465,24 @@ def run_multi_turn_rollout(
# Calculate rewards and get environment feedback
env_output: EnvironmentReturn = calculate_rewards(active_batch, task_to_env)
- total_rewards[active_indices] += env_output.rewards
+ # Infer number of reward components on first turn (supports single- and multi-reward envs)
+ if number_of_rewards is None:
+ if env_output.rewards.ndim >= 2:
+ number_of_rewards = int(env_output.rewards.shape[1])
+ multi_rewards = torch.zeros(
+ batch_size, number_of_rewards, dtype=torch.float32
+ )
+ else:
+ number_of_rewards = 1
+ # multi_rewards left None: GRPO uses total_reward only; reward1 unused
+ # Accumulate rewards: env may return shape (N,) or (N, K)
+ if env_output.rewards.ndim >= 2:
+ multi_rewards[active_indices] += env_output.rewards
+ total_rewards[active_indices] += env_output.rewards.sum(dim=1)
+ else:
+ total_rewards[active_indices] += env_output.rewards
+
# Update message log for ALL active samples with env observation
# This must happen BEFORE filtering based on done flags
@@ -536,6 +559,11 @@ def run_multi_turn_rollout(
# Add total rewards to the final batch
current_batch["total_reward"] = total_rewards
current_batch["truncated"] = sample_truncated
+ # Expose per-component rewards (reward1, reward2, ...) for multi-reward envs only; GRPO uses total_reward
+ if multi_rewards is not None:
+ num_reward_components = multi_rewards.shape[1]
+ for i in range(num_reward_components):
+ current_batch[f"reward{i + 1}"] = multi_rewards[:, i].clone()
# Calculate aggregate metrics
rollout_metrics = {
@@ -666,6 +694,8 @@ async def run_sample_multi_turn_rollout(
# Sample-level metrics
total_reward = 0.0
+ reward_acc_list: list[float] = []  # per-component reward sums; length set on the first multi-reward step
+ multi_reward_seen = False
turn_count = 0
token_count = 0
assistant_token_count = 0
@@ -738,8 +768,17 @@ async def run_sample_multi_turn_rollout(
# Get environment feedback
env_output = calculate_rewards(sample_batch, task_to_env)
- # Update total reward
- total_reward += float(env_output.rewards[0].item())
+ # Update total reward and optional per-reward signals (reward1, reward2, ... rewardN)
+ if env_output.rewards.ndim == 2 and env_output.rewards.shape[1] >= 1:
+ multi_reward_seen = True
+ n = env_output.rewards.shape[1]
+ if len(reward_acc_list) == 0:
+ reward_acc_list = [0.0] * n
+ total_reward += float(env_output.rewards[0].sum().item())
+ for j in range(n):
+ reward_acc_list[j] += float(env_output.rewards[0, j].item())
+ else:
+ total_reward += float(env_output.rewards[0].item())
# Check termination
terminated = env_output.terminateds[0].item()
env_obs_content = env_output.observations[0]["content"]
@@ -789,6 +828,9 @@ async def run_sample_multi_turn_rollout(
"stop_strings": current_stop_strings,
"idx": sample_idx,
}
+ if multi_reward_seen:
+ for j in range(len(reward_acc_list)):
+ final_sample_state[f"reward{j + 1}"] = torch.tensor(reward_acc_list[j])
# Sample metrics
sample_metrics = {
diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh
index bee4d8d2eb..ced8ffacaf 100644
--- a/tests/functional/L1_Functional_Tests_GPU.sh
+++ b/tests/functional/L1_Functional_Tests_GPU.sh
@@ -45,6 +45,7 @@ run_test uv run --no-sync bash ./tests/functional/dpo_automodel_lora.sh
run_test uv run --no-sync bash ./tests/functional/dpo_megatron.sh
run_test uv run --no-sync bash ./tests/functional/eval.sh
run_test uv run --no-sync bash ./tests/functional/eval_async.sh
+run_test fast uv run --no-sync bash ./tests/functional/gdpo.sh
run_test fast uv run --no-sync bash ./tests/functional/grpo.sh
run_test fast uv run --no-sync bash ./tests/functional/grpo_async_gym.sh
run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh
diff --git a/tests/functional/gdpo.sh b/tests/functional/gdpo.sh
new file mode 100644
index 0000000000..ee95645e28
--- /dev/null
+++ b/tests/functional/gdpo.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+EXP_NAME=$(basename $0 .sh)
+EXP_DIR=$SCRIPT_DIR/$EXP_NAME
+LOG_DIR=$EXP_DIR/logs
+JSON_METRICS=$EXP_DIR/metrics.json
+RUN_LOG=$EXP_DIR/run.log
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $EXP_DIR $LOG_DIR
+mkdir -p $EXP_DIR $LOG_DIR
+
+cd $PROJECT_ROOT
+uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
+ $PROJECT_ROOT/examples/run_grpo.py \
+ --config $PROJECT_ROOT/examples/configs/gdpo_math_1B.yaml \
+ policy.model_name=Qwen/Qwen3-0.6B \
+ grpo.num_prompts_per_step=2 \
+ grpo.num_generations_per_prompt=4 \
+ policy.train_global_batch_size=4 \
+ policy.train_micro_batch_size=1 \
+ cluster.gpus_per_node=2 \
+ cluster.num_nodes=1 \
+ grpo.max_num_steps=2 \
+ logger.tensorboard_enabled=true \
+ logger.log_dir=$LOG_DIR \
+ logger.wandb_enabled=false \
+ logger.monitor_gpus=true \
+ checkpointing.enabled=false \
+ $@ \
+ 2>&1 | tee $RUN_LOG
+
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+uv run tests/check_metrics.py $JSON_METRICS \
+ 'max(data["train/gen_kl_error"]) < 0.001' \
+ 'min(data["train/probs_ratio_clamped_min"]) > 0.79' \
+ 'max(data["train/probs_ratio_clamped_min"]) < 1.21' \
+ 'min(data["train/probs_ratio_clamped_max"]) > 0.79' \
+ 'max(data["train/probs_ratio_clamped_max"]) < 1.21'
diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py
index 7a0783f132..e0bdb8d2b1 100644
--- a/tests/unit/algorithms/test_grpo.py
+++ b/tests/unit/algorithms/test_grpo.py
@@ -1642,8 +1642,9 @@ def test_grpo_advantage_estimator_zero_std():
[2.0, 2.0, 1.0, 3.0]
) # prompt 0: std=0; prompt 1: std=sqrt(2)
mask = torch.ones(4, 5)
+ repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards}
- result = estimator.compute_advantage(prompt_ids, rewards, mask)
+ result = estimator.compute_advantage(repeated_batch, mask)
# prompt 0: std=0 -> skip normalization, advantage=0 (reward - mean = 0)
# prompt 1: With Bessel correction for 2 samples, std = sqrt(2), normalized = ±1/sqrt(2) ≈ ±0.7071
@@ -1673,8 +1674,9 @@ def test_grpo_advantage_estimator_tensor_shapes():
prompt_ids = torch.tensor([[0], [0]])
rewards = torch.tensor([1.0, 3.0]) # mean=2, std=sqrt(2) with Bessel
mask = torch.ones(2, 3)
+ repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards}
- result = estimator.compute_advantage(prompt_ids, rewards, mask)
+ result = estimator.compute_advantage(repeated_batch, mask)
assert result.shape == (2, 3)
# Verify normalized values: (reward - mean) / std
@@ -1687,8 +1689,9 @@ def test_grpo_advantage_estimator_tensor_shapes():
prompt_ids = torch.tensor([[0]] * 10)
rewards = torch.arange(10, dtype=torch.float32) # 0, 1, 2, ..., 9
mask = torch.ones(10, 5)
+ repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards}
- result = estimator.compute_advantage(prompt_ids, rewards, mask)
+ result = estimator.compute_advantage(repeated_batch, mask)
assert result.shape == (10, 5)
# After normalization, mean should be ~0
@@ -1712,8 +1715,9 @@ def test_grpo_advantage_estimator_negative_advantages():
prompt_ids = torch.tensor([[0], [0], [0]])
rewards = torch.tensor([0.0, 2.0, 4.0]) # mean=2, deviations: -2, 0, +2
mask = torch.ones(3, 4)
+ repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards}
- result = estimator.compute_advantage(prompt_ids, rewards, mask)
+ result = estimator.compute_advantage(repeated_batch, mask)
# Verify ordering: first should be negative, middle ~0, last positive
assert result[0, 0] < 0 # below mean -> negative advantage
@@ -1742,8 +1746,9 @@ def test_grpo_advantage_estimator_zero_std_and_zero_advantage():
prompt_ids = torch.tensor([[0], [0], [0], [0]])
rewards = torch.tensor([5.0, 5.0, 5.0, 5.0]) # all same
mask = torch.ones(4, 3)
+ repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards}
- result = estimator.compute_advantage(prompt_ids, rewards, mask)
+ result = estimator.compute_advantage(repeated_batch, mask)
# All advantages should be exactly 0
expected = torch.zeros(4, 3)
@@ -1768,8 +1773,9 @@ def test_grpo_advantage_estimator_small_nonzero_std():
prompt_ids = torch.tensor([[0], [0]])
rewards = torch.tensor([1.0, 1.01]) # small but detectable difference
mask = torch.ones(2, 3)
+ repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards}
- result = estimator.compute_advantage(prompt_ids, rewards, mask)
+ result = estimator.compute_advantage(repeated_batch, mask)
# Even with small std, normalization should still happen
# After normalization, the values should be ±1/sqrt(2) (for 2 samples with Bessel)
@@ -1808,8 +1814,9 @@ def test_reinforce_plus_plus_global_normalization():
) # Shape (4, 1) for unique prompt matching
rewards = torch.tensor([0.0, 1.0, 2.0, 3.0]) # mean=1.5
mask = torch.ones(4, 5)
+ repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards}
- result = estimator.compute_advantage(prompt_ids, rewards, mask)
+ result = estimator.compute_advantage(repeated_batch, mask)
# After global normalization, mean should be ~0
result_mean = (result * mask).sum() / mask.sum()