diff --git a/examples/configs/gdpo_math_1B.yaml b/examples/configs/gdpo_math_1B.yaml new file mode 100644 index 0000000000..47450a17a7 --- /dev/null +++ b/examples/configs/gdpo_math_1B.yaml @@ -0,0 +1,62 @@ +# GDPO: inherits from grpo_math_1B.yaml and overrides only what differs. +defaults: grpo_math_1B.yaml + +grpo: + adv_estimator: + name: "gdpo" + normalize_rewards: true + use_leave_one_out_baseline: false + +checkpointing: + checkpoint_dir: "results/gdpo" + +policy: + model_name: "Qwen/Qwen2.5-1.5B-Instruct" + logprob_batch_size: 4 + max_total_sequence_length: 1024 + megatron_cfg: + optimizer: + weight_decay: 0.0 + scheduler: + lr_decay_style: "cosine" + lr_warmup_iters: 10 + +# GDPO uses a single flat data config (GSM8K + math_gdpo_data_processor); replace parent's train/validation/default. +data: + _override_: true + + max_input_seq_length: ${policy.max_total_sequence_length} + shuffle: true + num_workers: 1 + + use_multiple_dataloader: false + + train: + dataset_name: "gsm8k" + split: train + validation: + dataset_name: "gsm8k" + split: test + + default: + prompt_file: null + system_prompt_file: "examples/prompts/gsm8k.txt" + processor: "math_gdpo_data_processor" + env_name: "math_multi_reward" + +env: + math_multi_reward: + num_workers: 8 + math_verify_impl: "hf_math_verify" + +logger: + wandb_enabled: true + wandb: + project: "gdpo-dev" + name: "gdpo-dev-logger" + swanlab: + project: "gdpo-dev" + name: "gdpo-dev-logger" + mlflow: + experiment_name: "gdpo-dev" + run_name: "gdpo-dev-logger" diff --git a/examples/prompts/gsm8k.txt b/examples/prompts/gsm8k.txt new file mode 100644 index 0000000000..3c31977100 --- /dev/null +++ b/examples/prompts/gsm8k.txt @@ -0,0 +1,17 @@ +You are a helpful AI assistant. + +For every request, you should carefully think through the math problem step by step, then provide the final answer in integer format. + +Steps for Each Request: +1. Think: Provide detailed, step-by-step reasoning, calculations, or derivations. +2. Produce Final Answer: After step-by-step reasoning, output the final answer in integer format. + +Output Format: +Your thoughts and reasoning +Final answer in integer format + +Important Notes: +1. You must include your reasoning steps inside . +2. You must always output the Final Answer within after the reasoning steps is done. +3. You should consistently work through the solution step by step before giving the final answer. +4. The final answer can only be an integer. \ No newline at end of file diff --git a/examples/run_grpo.py b/examples/run_grpo.py index 6130b99018..0e9f8bf24a 100644 --- a/examples/run_grpo.py +++ b/examples/run_grpo.py @@ -139,6 +139,13 @@ def main() -> None: "use_multiple_dataloader is not supported with async GRPO" ) + # Async GDPO is not supported + if config["grpo"]["adv_estimator"]["name"] == "gdpo": + raise NotImplementedError( + "GDPO is not supported for async training, " + "please set grpo.async_grpo.enabled to false in your config." 
+ ) + from nemo_rl.algorithms.grpo import async_grpo_train print("🚀 Running async GRPO training") diff --git a/nemo_rl/algorithms/advantage_estimator.py b/nemo_rl/algorithms/advantage_estimator.py index 6d14d8637a..976c142a44 100644 --- a/nemo_rl/algorithms/advantage_estimator.py +++ b/nemo_rl/algorithms/advantage_estimator.py @@ -16,6 +16,7 @@ This module provides different advantage estimation strategies: - GRPOAdvantageEstimator: Standard GRPO advantage with leave-one-out baseline +- GDPOAdvantageEstimator: Multi-reward GDPO (per-component baselines, sum then normalize) - ReinforcePlusPlusAdvantageEstimator: Reinforce++ with optional baseline subtraction (minus_baseline) and KL penalty in reward Reference papers: - ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/ @@ -24,8 +25,7 @@ import torch -from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl - +from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl, get_gdpo_reward_component_keys class GRPOAdvantageEstimator: """GRPO-style advantage estimator with leave-one-out baseline. @@ -37,12 +37,11 @@ def __init__(self, estimator_config: dict, loss_config: dict): self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"] self.normalize_rewards = estimator_config["normalize_rewards"] - def compute_advantage(self, prompt_ids, rewards, mask, **kwargs): + def compute_advantage(self, repeated_batch, mask, **kwargs): """Compute GRPO advantages. Args: - prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to. - rewards: Tensor of shape [batch_size] containing reward for each sample. + repeated_batch: Batch containing _input_ids_for_baseline and total_reward. mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. Used only for expanding advantages to token-level shape. **kwargs: Additional arguments (unused). @@ -50,6 +49,8 @@ def compute_advantage(self, prompt_ids, rewards, mask, **kwargs): Returns: Advantages tensor of shape [batch_size, seq_len]. """ + prompt_ids = repeated_batch["_input_ids_for_baseline"] + rewards = repeated_batch["total_reward"] baseline, std = calculate_baseline_and_std_per_prompt( prompt_ids, rewards, @@ -69,6 +70,75 @@ def compute_advantage(self, prompt_ids, rewards, mask, **kwargs): return advantages.expand(mask.shape) +class GDPOAdvantageEstimator: + """GDPO-style advantage estimator with leave-one-out baseline. + + Note: GDPO computes advantages for each reward separately over all responses for each prompt. + """ + + def __init__(self, estimator_config: dict, loss_config: dict): + self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"] + self.normalize_rewards = estimator_config["normalize_rewards"] + + def compute_advantage(self, repeated_batch, mask, **kwargs): + """Compute GDPO advantages. + + Args: + repeated_batch: Batch containing _input_ids_for_baseline and reward1, reward2, ... keys. + mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. + Used only for expanding advantages to token-level shape. + **kwargs: Additional arguments (unused). + + Returns: + Advantages tensor of shape [batch_size, seq_len]. + """ + reward_component_keys = get_gdpo_reward_component_keys(repeated_batch) + if len(reward_component_keys) < 2: + raise ValueError( + f"GDPO requires multiple reward components (reward1, reward2, ...). 
" + f"This batch has {len(reward_component_keys)} component(s). " + "Switch to GRPO by setting grpo.adv_estimator.name to 'grpo' in your config." + ) + current_input_ids = repeated_batch["_input_ids_for_baseline"] + valid = torch.ones_like( + repeated_batch[reward_component_keys[0]] + ) + leave_one_out = self.use_leave_one_out_baseline + assert current_input_ids.shape[0] == valid.shape[0], ( + "_input_ids_for_baseline must match reward batch size after dynamic_sampling; " + f"got {current_input_ids.shape[0]} vs {valid.shape[0]}" + ) + advantage_parts = [] + for key in reward_component_keys: + r = repeated_batch[key] + base, std_k = calculate_baseline_and_std_per_prompt( + current_input_ids, + r, + valid, + leave_one_out_baseline=leave_one_out, + ) + adv_k = (r - base).unsqueeze(-1) + if self.normalize_rewards: + + epsilon = 1e-6 + non_zero_std_mask = std_k > 0 + adv_k[non_zero_std_mask] = adv_k[non_zero_std_mask] / ( + std_k.unsqueeze(-1)[non_zero_std_mask] + epsilon + ) + + advantage_parts.append(adv_k) + + advantages = sum(advantage_parts) + # Normalize combined advantage to zero mean and unit std + adv_std = advantages.std() + if adv_std > 0: + advantages = (advantages - advantages.mean()) / adv_std + else: + advantages = advantages - advantages.mean() + + return advantages.expand(mask.shape) + + class ReinforcePlusPlusAdvantageEstimator: """Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward. @@ -83,20 +153,11 @@ def __init__(self, estimator_config: dict, loss_config: dict): self.kl_coef = loss_config["reference_policy_kl_penalty"] self.kl_type = loss_config["reference_policy_kl_type"] - def compute_advantage( - self, - prompt_ids, - rewards, - mask, - logprobs_policy=None, - logprobs_reference=None, - **kwargs, - ): + def compute_advantage(self, repeated_batch, mask, logprobs_policy=None, logprobs_reference=None, **kwargs): """Compute Reinforce++ advantages with optional KL penalty. Args: - prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to. - rewards: Tensor of shape [batch_size] containing reward for each sample. + repeated_batch: Batch containing _input_ids_for_baseline and total_reward. mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. Used for: (1) expanding advantages to token-level shape, (2) global normalization that only considers valid tokens. @@ -107,6 +168,8 @@ def compute_advantage( Returns: Advantages tensor of shape [batch_size, seq_len], globally normalized across valid tokens. """ + prompt_ids = repeated_batch["_input_ids_for_baseline"] + rewards = repeated_batch["total_reward"] # minus baseline if self.minus_baseline: mean, _ = calculate_baseline_and_std_per_prompt( diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index c060a05a50..f5b8892ed7 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -13,6 +13,7 @@ # limitations under the License. 
import gc import os +import re import time import warnings from concurrent.futures import ThreadPoolExecutor @@ -29,6 +30,7 @@ from nemo_rl.algorithms.advantage_estimator import ( GRPOAdvantageEstimator, + GDPOAdvantageEstimator, ReinforcePlusPlusAdvantageEstimator, ) from nemo_rl.algorithms.loss import ( @@ -46,6 +48,7 @@ log_generation_metrics_to_wandb, print_performance_metrics, set_seed, + get_gdpo_reward_component_keys ) from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import rl_collate_fn @@ -121,9 +124,9 @@ class AsyncGRPOConfig(TypedDict): class AdvEstimatorConfig(TypedDict): - """Configuration for advantage estimator (GRPO or Reinforce++).""" + """Configuration for advantage estimator (GRPO, GDPO, or Reinforce++).""" - name: str # "grpo" or "reinforce_plus_plus" + name: str # "grpo", "gdpo", or "reinforce_plus_plus" # GRPO specific normalize_rewards: NotRequired[bool] use_leave_one_out_baseline: NotRequired[bool] @@ -966,11 +969,16 @@ def scale_rewards( ) # Clamp and scale - rewards = torch.clamp(rewards, min=source_min, max=source_max) - scaled_rewards = target_min + (rewards - source_min) / ( - source_max - source_min - ) * (target_max - target_min) + def _scale(reward_tensor: torch.Tensor) -> torch.Tensor: + r = torch.clamp(reward_tensor, min=source_min, max=source_max) + return target_min + (r - source_min) / ( + source_max - source_min + ) * (target_max - target_min) + + scaled_rewards = _scale(rewards) repeated_batch["total_reward"] = scaled_rewards + for key in get_gdpo_reward_component_keys(repeated_batch): + repeated_batch[key] = _scale(repeated_batch[key]) return repeated_batch @@ -1031,7 +1039,7 @@ def _create_advantage_estimator(master_config: MasterConfig): master_config: The master configuration dictionary. Returns: - An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator). + An advantage estimator instance (GRPO, GDPO, or ReinforcePlusPlus). Raises: ValueError: If the advantage estimator name is not recognized. @@ -1055,7 +1063,14 @@ def _create_advantage_estimator(master_config: MasterConfig): ) adv_estimator_name = adv_estimator_config["name"] - if adv_estimator_name == "grpo": + if adv_estimator_name == "gdpo": + assert not _should_use_async_rollouts(master_config), ( + "GDPO is not supported for async rollouts, " + "please set policy.generation.vllm_cfg.async_engine to false in your config." + ) + adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config) + print(" ✓ Using GDPO advantage estimator (multi-reward)") + elif adv_estimator_name == "grpo": adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config) print(" ✓ Using GRPO advantage estimator") elif adv_estimator_name == "reinforce_plus_plus": @@ -1590,6 +1605,10 @@ def grpo_train( with timer.time("reward_calculation"): # Extract rewards from final_batch rewards = repeated_batch["total_reward"] + # Store input_ids in batch so that after dynamic_sampling it stays aligned with + # the (possibly filtered) batch: select_indices / from_batches / slice all + # apply to this key, so per-reward baselines use the same prompts as reward components. + repeated_batch["_input_ids_for_baseline"] = input_ids print("▶ Computing advantages...", flush=True) if master_config["grpo"].get("calculate_advantages_on_gpu"): @@ -1644,10 +1663,10 @@ def grpo_train( # If the current batch is not enough to fill the buffer during dynamic sampling, we update the cache and process the next batch. 
if not is_batch_complete: continue + gen_step_metrics = {} if hasattr(policy_generation, "get_step_metrics"): gen_step_metrics = policy_generation.get_step_metrics() - advantages = (rewards - baseline).unsqueeze(-1) # Save baseline for logging (before deletion) baseline_for_log = baseline.clone() @@ -1776,8 +1795,7 @@ def grpo_train( mask = token_mask * sample_mask.unsqueeze(-1) train_data["advantages"] = adv_estimator.compute_advantage( - prompt_ids=prompt_ids_for_adv, - rewards=rewards, + repeated_batch=repeated_batch, mask=mask, logprobs_policy=train_data["prev_logprobs"], logprobs_reference=train_data.get("reference_policy_logprobs"), @@ -2724,6 +2742,8 @@ def async_grpo_train( del prompt_batched_flat rewards = repeated_batch["total_reward"] + # All estimators read _input_ids_for_baseline from repeated_batch + repeated_batch["_input_ids_for_baseline"] = prompt_ids_for_adv print( f" 📊 Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}" @@ -2807,8 +2827,7 @@ def async_grpo_train( mask = token_mask * sample_mask.unsqueeze(-1) train_data["advantages"] = adv_estimator.compute_advantage( - prompt_ids=prompt_ids_for_adv, - rewards=rewards, + repeated_batch=repeated_batch, mask=mask, logprobs_policy=train_data["prev_logprobs"], logprobs_reference=train_data.get("reference_policy_logprobs"), diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 8e632ca5ee..bcc2281544 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -29,7 +29,12 @@ from nemo_rl.data.chat_templates import COMMON_CHAT_TEMPLATES from nemo_rl.models.policy import TokenizerConfig from nemo_rl.utils.logger import Logger +import re +def get_gdpo_reward_component_keys(batch) -> list: + """Return batch keys that are reward components (reward1, reward2, ...) in sorted order.""" + keys = [k for k in batch.keys() if re.match(r"reward\d+$", str(k))] + return sorted(keys, key=lambda k: int(re.search(r"\d+", str(k)).group())) def calculate_kl( logprobs: torch.Tensor, diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index eb48bb5204..d53a422bbb 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -20,6 +20,7 @@ DAPOMath17KDataset, DAPOMathAIME2024Dataset, ) +from nemo_rl.data.datasets.response_datasets.gsm8k import GSM8KDataset from nemo_rl.data.datasets.response_datasets.deepscaler import DeepScalerDataset from nemo_rl.data.datasets.response_datasets.general_conversations_dataset import ( GeneralConversationsJsonlDataset, @@ -55,6 +56,7 @@ "refcoco": RefCOCODataset, "squad": SquadDataset, "tulu3_sft_mixture": Tulu3SftMixtureDataset, + "gsm8k": GSM8KDataset, # load from local JSONL file or HuggingFace "openai_format": OpenAIFormatDataset, "NemoGymDataset": NemoGymDataset, @@ -94,6 +96,7 @@ def load_response_dataset(data_config: ResponseDatasetConfig): "GeneralConversationsJsonlDataset", "DAPOMath17KDataset", "DAPOMathAIME2024Dataset", + "GSM8KDataset", "DeepScalerDataset", "Geometry3KDataset", "HelpSteer3Dataset", diff --git a/nemo_rl/data/datasets/response_datasets/gsm8k.py b/nemo_rl/data/datasets/response_datasets/gsm8k.py new file mode 100644 index 0000000000..ce3affd869 --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/gsm8k.py @@ -0,0 +1,66 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from datasets import load_dataset + +from nemo_rl.data.datasets.raw_dataset import RawDataset + + +def _extract_hash_answer(text: str) -> str | None: + if "####" not in text: + return None + return text.split("####")[1].strip() + + +class GSM8KDataset(RawDataset): + """Simple wrapper around the GSM8K dataset. + + Args: + split: Split name for the dataset, default is "train" + extract_answer: Whether to extract the answer from the dataset, default is True + """ + + def __init__(self, + split: str = "train", + extract_answer: bool = True, + system_prompt_file: str | None = None, + **kwargs, + ) -> None: + self.task_name = "gsm8k" + self.extract_answer = extract_answer + + # load from huggingface + self.dataset = load_dataset("openai/gsm8k", "main")[split] + + # format the dataset + self.dataset = self.dataset.map( + self.format_data, + remove_columns=self.dataset.column_names, + ) + + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + if self.extract_answer: + answer = _extract_hash_answer(data["answer"]) + else: + answer = data["answer"] + + return { + "messages": [ + {"role": "user", "content": data["question"]}, + {"role": "assistant", "content": answer}, + ], + "task_name": self.task_name, + } diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py index 52ac9bf67d..5d3f51cccf 100644 --- a/nemo_rl/data/processors.py +++ b/nemo_rl/data/processors.py @@ -381,6 +381,68 @@ def math_data_processor( return output +# TODO: @yukih: unify to math_hf_data_processor once https://github.com/NVIDIA-NeMo/RL/issues/2060 is resolved. 
+def math_gdpo_data_processor(
+    datum_dict: dict[str, Any],
+    task_data_spec: TaskDataSpec,
+    tokenizer: TokenizerType,
+    max_seq_length: int,
+    idx: int,
+) -> DatumSpec:
+    """Process a datum dictionary (e.g. loaded via GSM8KDataset) into a DatumSpec for the math multi-reward environment."""
+    user_message = datum_dict["messages"]
+    problem = user_message[0]["content"]
+    extra_env_info = {"ground_truth": user_message[1]["content"]}
+
+    # merge system prompt and user prompt
+    message_list = []
+    # system prompt
+    if task_data_spec.system_prompt:
+        message_list.append({
+            "role": "system",
+            "content": task_data_spec.system_prompt,
+        })
+    # user prompt
+    if task_data_spec.prompt:
+        problem = task_data_spec.prompt.format(problem)
+    message_list.append({"role": "user", "content": problem})
+
+    message: list[str] = tokenizer.apply_chat_template(  # type: ignore
+        message_list,
+        tokenize=False,
+        add_generation_prompt=True,
+        add_special_tokens=False,
+    )
+    token_ids = tokenizer(
+        message, return_tensors="pt", add_special_tokens=False
+    )["input_ids"][0]
+
+    message_log: LLMMessageLogType = [
+        {"role": "user", "content": message, "token_ids": token_ids}
+    ]
+
+    length = sum(len(m["token_ids"]) for m in message_log)
+
+    loss_multiplier = 1.0
+    if length > max_seq_length:
+        # make smaller and mask out
+        for chat_message in message_log:
+            chat_message["token_ids"] = chat_message["token_ids"][
+                : min(4, max_seq_length // len(message_log))
+            ]
+        loss_multiplier = 0.0
+
+    output: DatumSpec = {
+        "message_log": message_log,
+        "length": length,
+        "extra_env_info": extra_env_info,
+        "loss_multiplier": loss_multiplier,
+        "idx": idx,
+        "task_name": datum_dict["task_name"],
+    }
+    return output
+
+
 def math_hf_data_processor(
     datum_dict: dict[str, Any],
     task_data_spec: TaskDataSpec,
@@ -698,6 +760,7 @@ def nemo_gym_data_processor(
     "helpsteer3_data_processor": helpsteer3_data_processor,
     "math_data_processor": math_data_processor,
     "math_hf_data_processor": math_hf_data_processor,
+    "math_gdpo_data_processor": math_gdpo_data_processor,
     "multichoice_qa_processor": multichoice_qa_processor,
     "sft_processor": sft_processor,
     "vlm_hf_data_processor": vlm_hf_data_processor,
diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py
index 90d52fe76e..3f02acb4e1 100644
--- a/nemo_rl/distributed/ray_actor_environment_registry.py
+++ b/nemo_rl/distributed/ray_actor_environment_registry.py
@@ -35,6 +35,7 @@
     "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL,
     "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE,
     "nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM,
+    "nemo_rl.environments.math_environment.MathMultiRewardEnvironment": PY_EXECUTABLES.SYSTEM,
     "nemo_rl.environments.vlm_environment.VLMEnvironment": PY_EXECUTABLES.SYSTEM,
     "nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM,
     "nemo_rl.environments.reward_model_environment.RewardModelEnvironment": PY_EXECUTABLES.SYSTEM,
diff --git a/nemo_rl/environments/interfaces.py b/nemo_rl/environments/interfaces.py
index b869c32df7..a8167581e5 100644
--- a/nemo_rl/environments/interfaces.py
+++ b/nemo_rl/environments/interfaces.py
@@ -44,7 +44,7 @@ class EnvironmentReturn(NamedTuple, Generic[MetadataT]):
     observations: list[dict[str, str]]
     metadata: list[MetadataT]
     next_stop_strings: list[list[str] | None] | list[None]
-    rewards: Tensor
+    rewards: Tensor  # shape may vary by environment: (N,) for single-reward envs, (N, K) for multi-reward envs
     terminateds: Tensor
     answers: list[str | None] | None
diff --git a/nemo_rl/environments/math_environment.py b/nemo_rl/environments/math_environment.py
index 8de2da805a..9f0acf655f 100644
--- a/nemo_rl/environments/math_environment.py
+++ b/nemo_rl/environments/math_environment.py
@@ -230,6 +230,120 @@ def verify(
         return results
 
 
+@ray.remote  # pragma: no cover
+class HFMultiRewardVerifyWorker:
+    def __init__(self) -> None:
+        logging.getLogger("math_multi_reward_verify").setLevel(logging.CRITICAL)
+
+        # Use Latex and plain math extraction from predictions
+        # https://github.com/huggingface/Math-Verify?tab=readme-ov-file#extraction-targets
+        self.verify_func = math_metric(
+            gold_extraction_target=(LatexExtractionConfig(),),
+            pred_extraction_target=(
+                ExprExtractionConfig(),
+                LatexExtractionConfig(),
+            ),
+        )
+
+    def verify(
+        self,
+        pred_responses: list[str],
+        ground_truths: list[str],
+        return_extracted_answer: bool = False,
+        **kwargs,
+    ) -> Union[list[list[float]], tuple[list[list[float]], list[str | None]]]:
+        """Verify the predicted responses against the ground truth and score each reward component.
+
+        Args:
+            pred_responses: list[str]. The predicted responses from the LLM.
+            ground_truths: list[str]. The ground truth responses.
+            return_extracted_answer: bool. Whether to also return the extracted answers.
+
+        Returns:
+            Union[list[list[float]], tuple[list[list[float]], list[str | None]]].
+                If return_extracted_answer is False, returns one list of scores per reward
+                component (correctness, integer answer, output format).
+                If return_extracted_answer is True, returns (scores, extracted_answers).
+        """
+        # NOTE: the <think></think> / <answer></answer> tag names below are assumed to match
+        # the output format required by the system prompt; adjust them if the prompt changes.
+        def extract_xml_answer(text: str) -> str:
+            answer = text.split("<answer>")[-1]
+            answer = answer.split("</answer>")[0]
+            return answer.strip()
+
+        def correctness_reward_func(completions, answer, **kwargs) -> list[float]:
+            extracted_responses = [extract_xml_answer(r) for r in completions]
+            return [1.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+        def int_reward_func(completions, **kwargs) -> list[float]:
+            extracted_responses = [extract_xml_answer(r) for r in completions]
+            return [1.0 if r.isdigit() else 0.0 for r in extracted_responses]
+
+        def format_reward_func(completions, **kwargs) -> list[float]:
+            """Reward function that checks if the completion has the expected think/answer format."""
+            rewards = []
+            for response in completions:
+                pattern = r"^<think>.*?</think>\n<answer>.*?</answer>$"
+                if (
+                    re.search(pattern, response, re.DOTALL)
+                    and response.count("<answer>") == 1
+                    and response.count("</answer>") == 1
+                ):
+                    rewards.append(1.0)
+                else:
+                    rewards.append(0.0)
+            return rewards
+
+        number_of_rewards = 3
+        results = [[] for _ in range(number_of_rewards)]
+        extracted_answers: list[str | None] = []
+
+        for response, ground_truth in zip(pred_responses, ground_truths):
+            try:
+                math_verify_impl = kwargs.get("math_verify_impl", "hf_math_verify")
+                if math_verify_impl == "hf_math_verify":
+                    cor_reward = correctness_reward_func([response], [ground_truth])
+                    int_reward = int_reward_func([response])
+                    format_reward = format_reward_func([response])
+                    extracted_answer = extract_xml_answer(response)
+                else:
+                    raise ValueError(
+                        f"Unknown math_verify_impl: {math_verify_impl}. Expected 'hf_math_verify'"
+                    )
+
+                results[0].extend(cor_reward)
+                results[1].extend(int_reward)
+                results[2].extend(format_reward)
+
+                if return_extracted_answer:
+                    # extract_xml_answer returns the answer string directly, so just record it
+                    extracted_answers.append(extracted_answer)
+
+            # It's possible to emit a TimeoutException and that wouldn't be caught since
+            # it actually subclasses from BaseException and math-verify itself does not
+            # catch it.
+            except (Exception, TimeoutException):
+                results[0].append(0.0)
+                results[1].append(0.0)
+                results[2].append(0.0)
+                extracted_answers.append(None)
+
+        # results layout: one list per reward component, e.g. [[1, 0, ...], [1, 1, ...], [0, 1, ...]]
+        if return_extracted_answer:
+            return results, extracted_answers
+        else:
+            return results
+
+
 class MathEnvironmentMetadata(TypedDict):
     ground_truth: str
     extracted_answer: str | None
@@ -391,3 +505,161 @@ def global_post_process_and_metrics(
     }
 
     return batch, metrics
+
+
+
+@ray.remote(max_restarts=-1, max_task_retries=-1)  # pragma: no cover
+class MathMultiRewardEnvironment(EnvironmentInterface[MathEnvironmentMetadata]):
+    def __init__(self, cfg: MathEnvConfig):
+        self.cfg = cfg
+        self.num_workers = cfg["num_workers"]
+        # TODO: split out this environment since it's doing more than just math
+        verifier_type = cfg.get("verifier_type", "math")
+        assert isinstance(verifier_type, str), (
+            f"{verifier_type=} must be a string but was {type(verifier_type)}"
+        )
+
+        worker_cls = {
+            "math": HFMultiRewardVerifyWorker,
+        }[verifier_type]
+        self.workers = [
+            worker_cls.options(  # type: ignore # (decorated with @ray.remote)
+                runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM}
+            ).remote()
+            for _ in range(self.num_workers)
+        ]
+
+    def shutdown(self) -> None:
+        # shutdown all workers
+        for worker in self.workers:
+            ray.kill(worker)
+
+    def step(
+        self,
+        message_log_batch: list[LLMMessageLogType],
+        metadata: list[MathEnvironmentMetadata],
+        return_extracted_answer: bool = False,
+    ) -> EnvironmentReturn[MathEnvironmentMetadata]:
+        """Runs a step in the multi-reward math environment.
+
+        Args:
+            message_log_batch: list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the LLM.
+            metadata: list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. The extracted answer will be stored to calculate cons@k.
+
+        Returns:
+            EnvironmentReturn: A tuple containing:
+                - list[dict[str, str]]: Observations/responses batch
+                - list[dict]: Updated metadata
+                - list[str]: Next stop strings for the next turn
+                - Tensor: Rewards tensor of shape (batch_size, num_reward_components)
+                - Tensor: Done flags tensor
+        """
+        # Extract the assistant's responses from the message history
+        # Each message list should have at least one assistant response
+        assistant_response_batch = []
+        for conversation in message_log_batch:
+            assistant_responses = [
+                str(interaction["content"])
+                for interaction in conversation
+                if interaction["role"] == "assistant"
+            ]
+            assistant_response_batch.append("".join(assistant_responses))
+
+        ground_truths = [g["ground_truth"] for g in metadata]
+
+        chunked_assistant_response_batch = chunk_list_to_workers(
+            assistant_response_batch, self.num_workers
+        )
+        chunked_ground_truths = chunk_list_to_workers(ground_truths, self.num_workers)
+
+        # Process each chunk in parallel
+        futures = [
+            self.workers[i].verify.remote(
+                chunk,
+                ground_truth_chunk,
+                return_extracted_answer,
+                math_verify_impl=self.cfg.get("math_verify_impl", "hf_math_verify"),
+            )
+            for i, (chunk, ground_truth_chunk) in enumerate(
+                zip(chunked_assistant_response_batch, chunked_ground_truths)
+            )
+        ]
+
+        worker_results = ray.get(futures)
+
+        # Flatten the per-worker results into one list of scores per reward component
+        number_of_rewards = 3
+        results = [[] for _ in range(number_of_rewards)]
+        extracted_answers: list[str | None] | None = (
+            [] if return_extracted_answer else None
+        )
+
+        for worker_result in worker_results:
+            if return_extracted_answer:
+                raise NotImplementedError(
+                    "return_extracted_answer is not yet supported by MathMultiRewardEnvironment"
+                )
+            else:
+                for i in range(number_of_rewards):
+                    results[i].extend(worker_result[i])
+
+        observations = [
+            {
+                "role": "environment",
+                "content": "Environment: correct"
+                if result
+                else "Environment: incorrect",
+            }
+            for result in results[0]  # index 0 always stores the correctness reward
+        ]
+
+        # create a tensor of rewards and done flags
+        rewards = torch.tensor(results).T.cpu()  # shape: (batch_size, num_reward_components)
+        # single-turn environment: every sample is marked done after this step
+        done = torch.ones(rewards.shape[0]).cpu()
+        next_stop_strings = [None] * len(message_log_batch)
+
+        return EnvironmentReturn(
+            observations=observations,
+            metadata=metadata,
+            next_stop_strings=next_stop_strings,
+            rewards=rewards,
+            terminateds=done,
+            answers=extracted_answers,
+        )
+
+    def global_post_process_and_metrics(
+        self, batch: BatchedDataDict[Any]
+    ) -> tuple[BatchedDataDict[Any], dict[str, float | int]]:
+        """Computes metrics for this environment given a global rollout batch.
+
+        Every rank will run this function, so you're free to use distributed
+        calculations if you'd prefer for heavy metrics.
+ """ + batch["rewards"] = ( + batch["rewards"] * batch["is_end"] + ) # set a reward of 0 for any incorrectly ended sequences + if (batch["rewards"] == 1).float().sum() > 0: + correct_solution_generation_lengths = ( + (batch["generation_lengths"] - batch["prompt_lengths"])[ + batch["rewards"] == 1 + ] + .float() + .mean() + .item() + ) + else: + correct_solution_generation_lengths = 0 + + metrics = { + # "table": table, TODO @sahilj WIP + "accuracy": batch["rewards"].mean().item(), + "pass@samples_per_prompt": calculate_pass_rate_per_prompt( + batch["text"], batch["rewards"] + ), + "fraction_of_samples_properly_ended": batch["is_end"].float().mean().item(), + "num_problems_in_batch": batch["is_end"].shape[0], + "generation_lengths": batch["generation_lengths"].float().mean().item(), + "prompt_lengths": batch["prompt_lengths"].float().mean().item(), + "correct_solution_generation_lengths": correct_solution_generation_lengths, + } + + return batch, metrics diff --git a/nemo_rl/environments/utils.py b/nemo_rl/environments/utils.py index c4227b8631..df82c7d1af 100644 --- a/nemo_rl/environments/utils.py +++ b/nemo_rl/environments/utils.py @@ -35,6 +35,9 @@ class EnvRegistryEntry(TypedDict, total=False): "math": { "actor_class_fqn": "nemo_rl.environments.math_environment.MathEnvironment", }, + "math_multi_reward": { + "actor_class_fqn": "nemo_rl.environments.math_environment.MathMultiRewardEnvironment", + }, "code": { "actor_class_fqn": "nemo_rl.environments.code_environment.CodeEnvironment", }, diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index 603a972095..52f3996244 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -327,7 +327,9 @@ def calculate_rewards( sorted_indices = sorted( range(len(all_indices_order)), key=lambda k: all_indices_order[k] ) - rewards = torch.tensor([all_rewards[i] for i in sorted_indices]) + # Stack rewards: each element may be scalar (single-reward env) or 1d (multi-reward env). + # torch.stack preserves shape: scalars -> (N,), shape (K,) -> (N, K). 
+ rewards = torch.stack([all_rewards[i] for i in sorted_indices]) env_observations = [all_env_observations[i] for i in sorted_indices] terminateds = torch.tensor([all_terminateds[i] for i in sorted_indices]) next_stop_strings = [all_next_stop_strings[i] for i in sorted_indices] @@ -374,6 +376,10 @@ def run_multi_turn_rollout( active_indices = torch.arange(batch_size) total_rewards = torch.zeros(batch_size, dtype=torch.float32) + # Multi_rewards: number of components inferred from first env_output (1 for single-reward envs) + number_of_rewards: int | None = None + multi_rewards: torch.Tensor | None = None + # Initialize stop_strings from the initial batch if present current_stop_strings = current_batch.get("stop_strings", [None] * batch_size) @@ -459,7 +465,24 @@ def run_multi_turn_rollout( # Calculate rewards and get environment feedback env_output: EnvironmentReturn = calculate_rewards(active_batch, task_to_env) - total_rewards[active_indices] += env_output.rewards + # Infer number of reward components on first turn (supports single- and multi-reward envs) + if number_of_rewards is None: + if env_output.rewards.ndim >= 2: + number_of_rewards = int(env_output.rewards.shape[1]) + multi_rewards = torch.zeros( + batch_size, number_of_rewards, dtype=torch.float32 + ) + else: + number_of_rewards = 1 + # multi_rewards left None: GRPO uses total_reward only; reward1 unused + # Accumulate rewards: env may return shape (N,) or (N, K) + if env_output.rewards.ndim >= 2: + multi_rewards[active_indices] += env_output.rewards + total_rewards[active_indices] += env_output.rewards.sum(dim=1) + else: + total_rewards[active_indices] += env_output.rewards + + # Update message log for ALL active samples with env observation # This must happen BEFORE filtering based on done flags @@ -536,6 +559,11 @@ def run_multi_turn_rollout( # Add total rewards to the final batch current_batch["total_reward"] = total_rewards current_batch["truncated"] = sample_truncated + # Expose per-component rewards (reward1, reward2, ...) for multi-reward envs only; GRPO uses total_reward + if multi_rewards is not None: + num_reward_components = multi_rewards.shape[1] + for i in range(num_reward_components): + current_batch[f"reward{i + 1}"] = multi_rewards[:, i].clone() # Calculate aggregate metrics rollout_metrics = { @@ -666,6 +694,8 @@ async def run_sample_multi_turn_rollout( # Sample-level metrics total_reward = 0.0 + reward_acc_list: list[float] = [] # per-component rewards, length set on first multi-reward + multi_reward_seen = False turn_count = 0 token_count = 0 assistant_token_count = 0 @@ -738,8 +768,17 @@ async def run_sample_multi_turn_rollout( # Get environment feedback env_output = calculate_rewards(sample_batch, task_to_env) - # Update total reward - total_reward += float(env_output.rewards[0].item()) + # Update total reward and optional per-reward signals (reward1, reward2, ... 
rewardN) + if env_output.rewards.ndim == 2 and env_output.rewards.shape[1] >= 1: + multi_reward_seen = True + n = env_output.rewards.shape[1] + if len(reward_acc_list) == 0: + reward_acc_list = [0.0] * n + total_reward += float(env_output.rewards[0].sum().item()) + for j in range(n): + reward_acc_list[j] += float(env_output.rewards[0, j].item()) + else: + total_reward += float(env_output.rewards[0].item()) # Check termination terminated = env_output.terminateds[0].item() env_obs_content = env_output.observations[0]["content"] @@ -789,6 +828,9 @@ async def run_sample_multi_turn_rollout( "stop_strings": current_stop_strings, "idx": sample_idx, } + if multi_reward_seen: + for j in range(len(reward_acc_list)): + final_sample_state[f"reward{j + 1}"] = torch.tensor(reward_acc_list[j]) # Sample metrics sample_metrics = { diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index bee4d8d2eb..ced8ffacaf 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -45,6 +45,7 @@ run_test uv run --no-sync bash ./tests/functional/dpo_automodel_lora.sh run_test uv run --no-sync bash ./tests/functional/dpo_megatron.sh run_test uv run --no-sync bash ./tests/functional/eval.sh run_test uv run --no-sync bash ./tests/functional/eval_async.sh +run_test fast uv run --no-sync bash ./tests/functional/gdpo.sh run_test fast uv run --no-sync bash ./tests/functional/grpo.sh run_test fast uv run --no-sync bash ./tests/functional/grpo_async_gym.sh run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh diff --git a/tests/functional/gdpo.sh b/tests/functional/gdpo.sh new file mode 100644 index 0000000000..ee95645e28 --- /dev/null +++ b/tests/functional/gdpo.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) 
+# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo.py \ + --config $PROJECT_ROOT/examples/configs/gdpo_math_1B.yaml \ + policy.model_name=Qwen/Qwen3-0.6B \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=4 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + cluster.num_nodes=1 \ + grpo.max_num_steps=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/gen_kl_error"]) < 0.001' \ + 'min(data["train/probs_ratio_clamped_min"]) > 0.79' \ + 'max(data["train/probs_ratio_clamped_min"]) < 1.21' \ + 'min(data["train/probs_ratio_clamped_max"]) > 0.79' \ + 'max(data["train/probs_ratio_clamped_max"]) < 1.21' diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 7a0783f132..e0bdb8d2b1 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -1642,8 +1642,9 @@ def test_grpo_advantage_estimator_zero_std(): [2.0, 2.0, 1.0, 3.0] ) # prompt 0: std=0; prompt 1: std=sqrt(2) mask = torch.ones(4, 5) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) # prompt 0: std=0 -> skip normalization, advantage=0 (reward - mean = 0) # prompt 1: With Bessel correction for 2 samples, std = sqrt(2), normalized = ±1/sqrt(2) ≈ ±0.7071 @@ -1673,8 +1674,9 @@ def test_grpo_advantage_estimator_tensor_shapes(): prompt_ids = torch.tensor([[0], [0]]) rewards = torch.tensor([1.0, 3.0]) # mean=2, std=sqrt(2) with Bessel mask = torch.ones(2, 3) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) assert result.shape == (2, 3) # Verify normalized values: (reward - mean) / std @@ -1687,8 +1689,9 @@ def test_grpo_advantage_estimator_tensor_shapes(): prompt_ids = torch.tensor([[0]] * 10) rewards = torch.arange(10, dtype=torch.float32) # 0, 1, 2, ..., 9 mask = torch.ones(10, 5) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) assert result.shape == (10, 5) # After normalization, mean should be ~0 @@ -1712,8 +1715,9 @@ def test_grpo_advantage_estimator_negative_advantages(): prompt_ids = torch.tensor([[0], [0], [0]]) rewards = torch.tensor([0.0, 2.0, 4.0]) # mean=2, deviations: -2, 0, +2 mask = torch.ones(3, 4) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + 
result = estimator.compute_advantage(repeated_batch, mask) # Verify ordering: first should be negative, middle ~0, last positive assert result[0, 0] < 0 # below mean -> negative advantage @@ -1742,8 +1746,9 @@ def test_grpo_advantage_estimator_zero_std_and_zero_advantage(): prompt_ids = torch.tensor([[0], [0], [0], [0]]) rewards = torch.tensor([5.0, 5.0, 5.0, 5.0]) # all same mask = torch.ones(4, 3) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) # All advantages should be exactly 0 expected = torch.zeros(4, 3) @@ -1768,8 +1773,9 @@ def test_grpo_advantage_estimator_small_nonzero_std(): prompt_ids = torch.tensor([[0], [0]]) rewards = torch.tensor([1.0, 1.01]) # small but detectable difference mask = torch.ones(2, 3) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) # Even with small std, normalization should still happen # After normalization, the values should be ±1/sqrt(2) (for 2 samples with Bessel) @@ -1808,8 +1814,9 @@ def test_reinforce_plus_plus_global_normalization(): ) # Shape (4, 1) for unique prompt matching rewards = torch.tensor([0.0, 1.0, 2.0, 3.0]) # mean=1.5 mask = torch.ones(4, 5) + repeated_batch = {"_input_ids_for_baseline": prompt_ids, "total_reward": rewards} - result = estimator.compute_advantage(prompt_ids, rewards, mask) + result = estimator.compute_advantage(repeated_batch, mask) # After global normalization, mean should be ~0 result_mean = (result * mask).sum() / mask.sum()
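For reference, a minimal usage sketch of the new estimator (not part of the diff itself): it assumes the GDPOAdvantageEstimator and batch keys added above, uses made-up reward values for two prompts with two generations each, and passes an empty loss_config since the GDPO estimator does not read it.

import torch

from nemo_rl.algorithms.advantage_estimator import GDPOAdvantageEstimator

estimator = GDPOAdvantageEstimator(
    {"use_leave_one_out_baseline": False, "normalize_rewards": True},
    loss_config={},  # accepted but unused by the GDPO estimator
)

# reward1/reward2 mirror the per-component keys produced by run_multi_turn_rollout
# (e.g. correctness and format); the values here are purely illustrative.
repeated_batch = {
    "_input_ids_for_baseline": torch.tensor([[0], [0], [1], [1]]),
    "reward1": torch.tensor([1.0, 0.0, 1.0, 1.0]),
    "reward2": torch.tensor([1.0, 1.0, 0.0, 1.0]),
}
mask = torch.ones(4, 8)  # response-token mask; advantages are expanded to this shape

# Each component is baselined (and optionally std-normalized) per prompt, the parts are
# summed, and the combined advantage is re-normalized before being expanded to token level.
advantages = estimator.compute_advantage(repeated_batch, mask=mask)
print(advantages.shape)  # torch.Size([4, 8])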