feat: adding wandb table log feature, showing concrete test samples#2018
feat: adding wandb table log feature, showing concrete test samples#2018vinhngx wants to merge 1 commit intoNVIDIA-NeMo:mainfrom
Conversation
Signed-off-by: Vinh Nguyen <vinhn@nvidia.com>
📝 WalkthroughWalkthroughAdds WandB table logging for validation samples during GRPO training. Introduces configuration field Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
nemo_rl/utils/logger.py (1)
77-90:⚠️ Potential issue | 🟡 MinorDocument
num_val_samples_to_login the TypedDict interface.
Please add a Google‑style class docstring (Attributes section) that states the purpose, valid values (≥0), and recommended default (0 disables). This satisfies the config-key documentation requirement at the definition site.✍️ Suggested docstring update
class LoggerConfig(TypedDict): + """Logger configuration. + + Attributes: + num_val_samples_to_log (int | None): Number of validation samples to log to WandB + tables. Use 0 to disable. Recommended default: 0. + """ log_dir: str wandb_enabled: bool swanlab_enabled: bool @@ num_val_samples_to_print: NotRequired[int] num_val_samples_to_log: NotRequired[int]As per coding guidelines, "When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemo_rl/utils/logger.py` around lines 77 - 90, Add a Google‑style class docstring to the LoggerConfig TypedDict that includes an "Attributes" section documenting the new config key "num_val_samples_to_log" (purpose: number of validation samples to log, valid values: integer >= 0, recommended default: 0 disables logging), and ensure the description notes type and behavior; update any exemplar YAMLs under examples/configs/*.yaml to reflect the default as well. Reference: LoggerConfig and the key name num_val_samples_to_log when locating where to add the docstring.nemo_rl/algorithms/grpo.py (1)
2154-2166:⚠️ Potential issue | 🟡 MinorAvoid non‑None defaults in code for config values.
get(..., 0)introduces a hidden default. Gate the call on presence of the key instead.✅ Suggested change (no hidden default)
- _log_validation_samples_to_wandb( - logger=logger, - message_logs=all_message_logs, - rewards=total_rewards, - step=step, - num_samples=master_config["logger"].get("num_val_samples_to_log", 0), - ) + num_val_samples_to_log = master_config["logger"].get("num_val_samples_to_log") + if num_val_samples_to_log is not None: + _log_validation_samples_to_wandb( + logger=logger, + message_logs=all_message_logs, + rewards=total_rewards, + step=step, + num_samples=num_val_samples_to_log, + )As per coding guidelines, "YAML is the single source of truth for configuration defaults; do not set non-None defaults in code for configuration values".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemo_rl/algorithms/grpo.py` around lines 2154 - 2166, The code currently uses master_config["logger"].get("num_val_samples_to_log", 0) which embeds a non-None default in code; instead check for the key and only pass the value when present. Modify the call to _log_validation_samples_to_wandb: compute num_samples only if "num_val_samples_to_log" in master_config["logger"] (e.g. num_samples = master_config["logger"]["num_val_samples_to_log"]) and call _log_validation_samples_to_wandb with that num_samples; otherwise call _log_validation_samples_to_wandb without the num_samples argument (or pass None if the function accepts it), removing the .get(..., 0) usage and avoiding a hidden default. Ensure the change touches the invocation site of _log_validation_samples_to_wandb and not logger.log_batched_dict_as_jsonl.
🧹 Nitpick comments (2)
docs/design-docs/logger.md (1)
167-167: Tighten the sentence for clarity.
Suggested edit: “When set to >0, the logger uploads a table at each validation step with columns input, output, and reward.”🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/design-docs/logger.md` at line 167, The sentence is wordy—tighten it to clearly state behavior for num_val_samples_to_log: replace the current line with a concise version like “When set to >0, the logger uploads a table at each validation step with columns input, output, and reward.” Keep references to the resulting WandB tables (val/generations and val/generations_step_{step}) and preserve the mention of num_val_samples_to_log and the three column names (input, output, reward) so readers can find related settings and outputs.examples/configs/grpo_math_1B.yaml (1)
331-334: Clarify the WandB-disabled logging knobs.
Withwandb_enabled: false,num_val_samples_to_log: 16has no effect. Consider setting it to0here (or enabling WandB) to avoid confusion in the example.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/configs/grpo_math_1B.yaml` around lines 331 - 334, The example config shows wandb_enabled and num_val_samples_to_log but with wandb_enabled: false the logging count is a no-op; update the example to avoid confusion by either setting num_val_samples_to_log: 0 when wandb_enabled: false or flip wandb_enabled to true—change the num_val_samples_to_log key (and/or wandb_enabled) in the config so the settings are consistent (refer to wandb_enabled and num_val_samples_to_log).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@nemo_rl/algorithms/grpo.py`:
- Around line 2178-2238: Add Google-style docstrings to the three helper
functions _message_content_to_str, _select_sample_indices_for_logging, and
_log_validation_samples_to_wandb: for each function include a short one-line
summary, args section with types and descriptions (e.g., content: Any; rewards:
list[float]; message_logs: list[LLMMessageLogType]; step: int; num_samples: int;
logger: Optional[Logger]), a Returns section describing the return type and
meaning, and any Raises or Notes if applicable; ensure the docstrings follow the
Google style (triple-quoted, Sphinx-parsable) and attach them immediately above
each function definition without changing function signatures or behavior.
- Around line 191-195: Add a Google-style docstring to the GRPOLoggerConfig
TypedDict class explaining the purpose of each key and documenting the new key
num_val_samples_to_log: describe its purpose (number of validation samples to
write to wandb table), valid values/types (int >= 0), and a recommended default
(e.g., 0 or a small positive int); update the class docstring to list
num_val_samples_to_print and num_val_samples_to_log with types and defaults, and
update the exemplar YAMLs under examples/configs/*.yaml to reflect the chosen
default for num_val_samples_to_log so examples match the documented default.
---
Outside diff comments:
In `@nemo_rl/algorithms/grpo.py`:
- Around line 2154-2166: The code currently uses
master_config["logger"].get("num_val_samples_to_log", 0) which embeds a non-None
default in code; instead check for the key and only pass the value when present.
Modify the call to _log_validation_samples_to_wandb: compute num_samples only if
"num_val_samples_to_log" in master_config["logger"] (e.g. num_samples =
master_config["logger"]["num_val_samples_to_log"]) and call
_log_validation_samples_to_wandb with that num_samples; otherwise call
_log_validation_samples_to_wandb without the num_samples argument (or pass None
if the function accepts it), removing the .get(..., 0) usage and avoiding a
hidden default. Ensure the change touches the invocation site of
_log_validation_samples_to_wandb and not logger.log_batched_dict_as_jsonl.
In `@nemo_rl/utils/logger.py`:
- Around line 77-90: Add a Google‑style class docstring to the LoggerConfig
TypedDict that includes an "Attributes" section documenting the new config key
"num_val_samples_to_log" (purpose: number of validation samples to log, valid
values: integer >= 0, recommended default: 0 disables logging), and ensure the
description notes type and behavior; update any exemplar YAMLs under
examples/configs/*.yaml to reflect the default as well. Reference: LoggerConfig
and the key name num_val_samples_to_log when locating where to add the
docstring.
---
Nitpick comments:
In `@docs/design-docs/logger.md`:
- Line 167: The sentence is wordy—tighten it to clearly state behavior for
num_val_samples_to_log: replace the current line with a concise version like
“When set to >0, the logger uploads a table at each validation step with columns
input, output, and reward.” Keep references to the resulting WandB tables
(val/generations and val/generations_step_{step}) and preserve the mention of
num_val_samples_to_log and the three column names (input, output, reward) so
readers can find related settings and outputs.
In `@examples/configs/grpo_math_1B.yaml`:
- Around line 331-334: The example config shows wandb_enabled and
num_val_samples_to_log but with wandb_enabled: false the logging count is a
no-op; update the example to avoid confusion by either setting
num_val_samples_to_log: 0 when wandb_enabled: false or flip wandb_enabled to
true—change the num_val_samples_to_log key (and/or wandb_enabled) in the config
so the settings are consistent (refer to wandb_enabled and
num_val_samples_to_log).
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
docs/design-docs/logger.mdexamples/configs/grpo_math_1B.yamlnemo_rl/algorithms/grpo.pynemo_rl/utils/logger.py
| class GRPOLoggerConfig(LoggerConfig): | ||
| num_val_samples_to_print: int # number of val samples to print to stdout | ||
| num_val_samples_to_log: NotRequired[ | ||
| int | ||
| ] # number of val samples to log to wandb table |
There was a problem hiding this comment.
Add a Google‑style docstring for GRPOLoggerConfig (new key docs + default).
num_val_samples_to_log is a new config key and should be documented with purpose, valid values, and recommended default at the type definition site.
✍️ Suggested docstring update
class GRPOLoggerConfig(LoggerConfig):
+ """Logger config for GRPO training.
+
+ Attributes:
+ num_val_samples_to_print (int): Number of validation samples to print to stdout.
+ num_val_samples_to_log (int | None): Number of validation samples to log to WandB
+ tables. Use 0 to disable. Recommended default: 0.
+ """
num_val_samples_to_print: int # number of val samples to print to stdout
num_val_samples_to_log: NotRequired[
int
] # number of val samples to log to wandb tableAs per coding guidelines, "When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@nemo_rl/algorithms/grpo.py` around lines 191 - 195, Add a Google-style
docstring to the GRPOLoggerConfig TypedDict class explaining the purpose of each
key and documenting the new key num_val_samples_to_log: describe its purpose
(number of validation samples to write to wandb table), valid values/types (int
>= 0), and a recommended default (e.g., 0 or a small positive int); update the
class docstring to list num_val_samples_to_print and num_val_samples_to_log with
types and defaults, and update the exemplar YAMLs under examples/configs/*.yaml
to reflect the chosen default for num_val_samples_to_log so examples match the
documented default.
| def _message_content_to_str(content: Any) -> str: | ||
| if isinstance(content, list): | ||
| return "".join(str(part) for part in content) | ||
| return str(content) | ||
|
|
||
|
|
||
| def _select_sample_indices_for_logging( | ||
| rewards: list[float], num_samples: int | ||
| ) -> list[int]: | ||
| if num_samples <= 0: | ||
| return [] | ||
| indices = list(range(len(rewards))) | ||
| if len(indices) <= num_samples: | ||
| return indices | ||
| sorted_indices = sorted(indices, key=lambda i: rewards[i], reverse=True) | ||
| half = num_samples // 2 | ||
| selected = sorted_indices[:half] + sorted_indices[-half:] | ||
| if num_samples % 2 == 1: | ||
| middle_idx = len(sorted_indices) // 2 | ||
| selected.append(sorted_indices[middle_idx]) | ||
| return selected[:num_samples] | ||
|
|
||
|
|
||
| def _log_validation_samples_to_wandb( | ||
| logger: Optional[Logger], | ||
| message_logs: list[LLMMessageLogType], | ||
| rewards: list[float], | ||
| step: int, | ||
| num_samples: int, | ||
| ) -> None: | ||
| if logger is None or logger.wandb_logger is None: | ||
| return | ||
| if not message_logs or not rewards or num_samples <= 0: | ||
| return | ||
|
|
||
| num_samples = min(num_samples, len(message_logs)) | ||
| indices = _select_sample_indices_for_logging(rewards, num_samples) | ||
| rows: list[list[Any]] = [] | ||
| for idx in indices: | ||
| message_log = message_logs[idx] | ||
| prompt = "" | ||
| response = "" | ||
| for msg in message_log: | ||
| role = msg.get("role") | ||
| content = _message_content_to_str(msg.get("content", "")) | ||
| if role == "user" and not prompt: | ||
| prompt = content | ||
| elif role == "assistant": | ||
| response = content | ||
| rows.append([prompt, response, float(rewards[idx])]) | ||
|
|
||
| if rows: | ||
| table = Table(columns=["input", "output", "reward"], data=rows) | ||
| logger.wandb_logger.log_metrics( | ||
| { | ||
| "val/generations": table, | ||
| f"val/generations_step_{step}": table, | ||
| }, | ||
| step=step, | ||
| prefix="", | ||
| ) |
There was a problem hiding this comment.
Add Google‑style docstrings to new helper functions.
These helpers are new public‑ish utilities (even if underscored) and should have Sphinx‑parseable docstrings.
✍️ Suggested docstrings
def _message_content_to_str(content: Any) -> str:
+ """Convert message content to a string.
+
+ Args:
+ content: Message content value (string, list, etc.).
+
+ Returns:
+ String representation of the content. Lists are joined element-wise.
+ """
if isinstance(content, list):
return "".join(str(part) for part in content)
return str(content)
def _select_sample_indices_for_logging(
rewards: list[float], num_samples: int
) -> list[int]:
+ """Select indices for logging a mix of high/low reward samples.
+
+ Args:
+ rewards: Reward values aligned with message logs.
+ num_samples: Number of indices to select.
+
+ Returns:
+ List of selected indices (best/worst and optional middle).
+ """
if num_samples <= 0:
return []
@@
def _log_validation_samples_to_wandb(
logger: Optional[Logger],
message_logs: list[LLMMessageLogType],
rewards: list[float],
step: int,
num_samples: int,
) -> None:
+ """Log validation samples to WandB as a table.
+
+ Args:
+ logger: Logger with an initialized WandB backend.
+ message_logs: Message logs containing prompts and responses.
+ rewards: Reward values aligned with message logs.
+ step: Validation step number.
+ num_samples: Number of samples to log.
+ """
if logger is None or logger.wandb_logger is None:
returnAs per coding guidelines, "Use Google style docstrings for classes and functions, which can be parsed by Sphinx".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@nemo_rl/algorithms/grpo.py` around lines 2178 - 2238, Add Google-style
docstrings to the three helper functions _message_content_to_str,
_select_sample_indices_for_logging, and _log_validation_samples_to_wandb: for
each function include a short one-line summary, args section with types and
descriptions (e.g., content: Any; rewards: list[float]; message_logs:
list[LLMMessageLogType]; step: int; num_samples: int; logger: Optional[Logger]),
a Returns section describing the return type and meaning, and any Raises or
Notes if applicable; ensure the docstrings follow the Google style
(triple-quoted, Sphinx-parsable) and attach them immediately above each function
definition without changing function signatures or behavior.
What does this PR do?
VERL has this useful feature, where at each validation step, a number of input/output samples are included in Wandb for easy manual inspection:

This PR add a similar functionality to NemoRL:

Add optional logging of validation input/output samples to WandB as tables so runs can be inspected manually in the WandB UI (prompt, response, reward per sample).
Issues
Usage
Enable WandB and set
num_val_samples_to_login the logger config. At each validation step, a subset of samples (mix of high- and low-reward) is logged as a WandB table with columns input, output, and reward.Before your PR is "Ready for review"
Pre checks:
Additional Information
Files changed: nemo_rl/algorithms/grpo.py (wandb Table logging and helpers), nemo_rl/utils/logger.py (num_val_samples_to_log in LoggerConfig), examples/configs/grpo_math_1B.yaml (config option), docs/design-docs/logger.md (Validation Table Logging section).
No tables are logged when wandb_enabled is false or num_val_samples_to_log is 0.
Summary by CodeRabbit
Release Notes
New Features
Documentation