From efe110573ffaf4f5e0e3381db99cbad5b3833fca Mon Sep 17 00:00:00 2001 From: Sai Asish Y Date: Thu, 16 Apr 2026 15:46:58 -0700 Subject: [PATCH] Warn when init_noise_std and related params are ignored with tanh_normal (#661) When distribution_type='tanh_normal' (the default), init_noise_std, noise_std_type, and state_dependent_std are silently ignored because the tanh_normal branch creates a plain MLP whose output fully determines the standard deviation. Users tuning these parameters under the default get no feedback that their changes have no effect. Emit a warning listing the ignored parameters when any of them differ from their defaults. --- brax/training/agents/ppo/networks.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/brax/training/agents/ppo/networks.py b/brax/training/agents/ppo/networks.py index 1410831f8..70c5a1cad 100644 --- a/brax/training/agents/ppo/networks.py +++ b/brax/training/agents/ppo/networks.py @@ -14,6 +14,7 @@ """PPO networks.""" +import warnings from typing import Any, Literal, Mapping, Sequence, Tuple from brax.training import distribution @@ -116,6 +117,23 @@ def make_ppo_networks( value_kernel_init_kwargs = value_network_kernel_init_kwargs or {} mean_kernel_init_kwargs_ = mean_kernel_init_kwargs or {} + if distribution_type == 'tanh_normal': + ignored = [] + if init_noise_std != 1.0: + ignored.append(f'init_noise_std={init_noise_std!r}') + if noise_std_type != 'scalar': + ignored.append(f'noise_std_type={noise_std_type!r}') + if state_dependent_std: + ignored.append(f'state_dependent_std={state_dependent_std!r}') + if ignored: + warnings.warn( + f'{", ".join(ignored)} {"has" if len(ignored) == 1 else "have"}' + ' no effect with distribution_type="tanh_normal". The standard' + ' deviation is determined entirely by the policy network output.' + ' These parameters only apply to distribution_type="normal".', + stacklevel=2, + ) + parametric_action_distribution: distribution.ParametricDistribution if distribution_type == 'normal': parametric_action_distribution = distribution.NormalDistribution( @@ -163,4 +181,4 @@ def make_ppo_networks( policy_network=policy_network, value_network=value_network, parametric_action_distribution=parametric_action_distribution, - ) + ) \ No newline at end of file