diff --git a/docs_nnx/guides/demo.md b/docs_nnx/guides/demo.md index 1f423a77e..90bb3f867 100644 --- a/docs_nnx/guides/demo.md +++ b/docs_nnx/guides/demo.md @@ -25,7 +25,7 @@ from flax import nnx class Block(nnx.Module): def __init__(self, din, dout, *, rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) - self.bn = nnx.BatchNorm(dout, rngs=rngs) + self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs) def __call__(self, x): return nnx.relu(self.bn(self.linear(x))) diff --git a/docs_nnx/guides/performance.md b/docs_nnx/guides/performance.md index 0c29ac815..23a898534 100644 --- a/docs_nnx/guides/performance.md +++ b/docs_nnx/guides/performance.md @@ -26,8 +26,8 @@ import optax class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) - self.bn = nnx.BatchNorm(dmid, rngs=rngs) - self.dropout = nnx.Dropout(0.2, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs) + self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): diff --git a/docs_nnx/guides/pytree.md b/docs_nnx/guides/pytree.md index 2c9f46caf..eee0a0b87 100644 --- a/docs_nnx/guides/pytree.md +++ b/docs_nnx/guides/pytree.md @@ -436,8 +436,8 @@ NNX Modules are `Pytree`s that have two additional methods for traking intermedi class Block(nnx.Module): def __init__(self, din: int, dout: int, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) - self.bn = nnx.BatchNorm(dout, rngs=rngs) - self.dropout = nnx.Dropout(0.1, rngs=rngs) + self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs) + self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) def __call__(self, x): y = nnx.relu(self.dropout(self.bn(self.linear(x)))) diff --git a/docs_nnx/guides/randomness.md b/docs_nnx/guides/randomness.md index c7955c5ba..3a3881a27 100644 --- a/docs_nnx/guides/randomness.md +++ b/docs_nnx/guides/randomness.md @@ -18,7 +18,7 @@ from flax import nnx class Model(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(20, 10, rngs=rngs) - self.drop = nnx.Dropout(0.1) + self.drop = nnx.Dropout(0.1, deterministic=False) def __call__(self, x, *, rngs): return nnx.relu(self.drop(self.linear(x), rngs=rngs)) @@ -90,7 +90,7 @@ Specifically, this will use the RngSteam `rngs.params` for weight initialization The `nnx.Dropout` module also requires a random state, but it requires this state at *call* time rather than initialization. Once again, we can pass it random state using the `rngs` keyword argument. ```{code-cell} ipython3 -dropout = nnx.Dropout(0.5) +dropout = nnx.Dropout(0.5, deterministic=False) ``` ```{code-cell} ipython3 @@ -159,7 +159,7 @@ Say you want to train a model that uses dropout on a batch of data. You don't wa class Model(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(20, 10, rngs=rngs) - self.drop = nnx.Dropout(0.1) + self.drop = nnx.Dropout(0.1, deterministic=False) def __call__(self, x, rngs): return nnx.relu(self.drop(self.linear(x), rngs=rngs)) @@ -199,7 +199,7 @@ So far, we have looked at passing random state directly to each Module when it g class Model(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(20, 10, rngs=rngs) - self.drop = nnx.Dropout(0.1, rngs=rngs) + self.drop = nnx.Dropout(0.1, deterministic=False, rngs=rngs) def __call__(self, x): return nnx.relu(self.drop(self.linear(x))) @@ -296,7 +296,7 @@ class Count(nnx.Variable): pass class RNNCell(nnx.Module): def __init__(self, din, dout, rngs): self.linear = nnx.Linear(dout + din, dout, rngs=rngs) - self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout') + self.drop = nnx.Dropout(0.1, deterministic=False, rngs=rngs, rng_collection='recurrent_dropout') self.dout = dout self.count = Count(jnp.array(0, jnp.uint32)) diff --git a/docs_nnx/hijax/hijax.md b/docs_nnx/hijax/hijax.md index 6d6123022..08d4a9f35 100644 --- a/docs_nnx/hijax/hijax.md +++ b/docs_nnx/hijax/hijax.md @@ -153,8 +153,8 @@ print("mutable =", v_ref) class Block(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = Linear(din, dmid, rngs=rngs) - self.bn = nnx.BatchNorm(dmid, rngs=rngs) - self.dropout = nnx.Dropout(0.1, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs) + self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) self.linear_out = Linear(dmid, dout, rngs=rngs) def __call__(self, x): diff --git a/docs_nnx/nnx_basics.md b/docs_nnx/nnx_basics.md index f29074de7..85784fa89 100644 --- a/docs_nnx/nnx_basics.md +++ b/docs_nnx/nnx_basics.md @@ -96,8 +96,8 @@ The example below shows how to define a simple `MLP` by subclassing `Module`. Th class MLP(nnx.Module): def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): self.linear1 = Linear(din, dmid, rngs=rngs) - self.dropout = nnx.Dropout(rate=0.1) - self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.dropout = nnx.Dropout(rate=0.1, deterministic=False) + self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs) self.linear2 = Linear(dmid, dout, rngs=rngs) def __call__(self, x: jax.Array, rngs: nnx.Rngs): diff --git a/flax/nnx/module.py b/flax/nnx/module.py index 5944ae4ca..21e9a31be 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -444,12 +444,12 @@ def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool = >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) - ... self.dropout = nnx.Dropout(0.5, deterministic=False) - ... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs) + ... self.dropout = nnx.Dropout(0.5) + ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) ... >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.dropout.deterministic, block.batch_norm.use_running_average - (False, False) + (None, None) >>> new_block = nnx.view(block, deterministic=True, use_running_average=True) >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average (True, True) @@ -459,7 +459,7 @@ def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool = >>> new_block = nnx.view(block, only=nnx.Dropout, deterministic=True) >>> # Only the dropout will be modified >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average - (True, False) + (True, None) Args: node: the object to create a copy of. diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index a91b3e3e7..1766da5b0 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -290,7 +290,7 @@ def __init__( self, num_features: int, *, - use_running_average: bool = False, + use_running_average: bool | None = None, axis: int = -1, momentum: float = 0.99, epsilon: float = 1e-5, @@ -364,8 +364,17 @@ def __call__( use_running_average = first_from( use_running_average, self.use_running_average, - error_msg="""No `use_running_average` argument was provided to BatchNorm - as either a __call__ argument, class attribute, or nnx.flag.""", + error_msg=( + 'No `use_running_average` argument was provided to BatchNorm.' + ' Consider one of the following options:\n\n' + '1. Pass `use_running_average` to the BatchNorm constructor:\n\n' + ' self.bn = nnx.BatchNorm(..., use_running_average=True/False)\n\n' + '2. Pass `use_running_average` to the BatchNorm __call__:\n\n' + ' self.bn(x, use_running_average=True/False)\n\n' + '3. Use `nnx.view` to create a view of the model with a' + ' specific `use_running_average` value:\n\n' + ' model_view = nnx.view(model, use_running_average=True/False)\n' + ), ) feature_axes = _canonicalize_axes(x.ndim, self.axis) reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py index cea31f35e..4dcc78925 100644 --- a/flax/nnx/nn/stochastic.py +++ b/flax/nnx/nn/stochastic.py @@ -73,7 +73,7 @@ def __init__( rate: float, *, broadcast_dims: Sequence[int] = (), - deterministic: bool = False, + deterministic: bool | None = None, rng_collection: str = 'dropout', rngs: rnglib.Rngs | rnglib.RngStream | None = None, ): @@ -117,8 +117,17 @@ def __call__( deterministic = first_from( deterministic, self.deterministic, - error_msg="""No `deterministic` argument was provided to Dropout - as either a __call__ argument or class attribute""", + error_msg=( + 'No `deterministic` argument was provided to Dropout.' + ' Consider one of the following options:\n\n' + '1. Pass `deterministic` to the Dropout constructor:\n\n' + ' self.dropout = nnx.Dropout(..., deterministic=True/False)\n\n' + '2. Pass `deterministic` to the Dropout __call__:\n\n' + ' self.dropout(x, deterministic=True/False)\n\n' + '3. Use `nnx.view` to create a view of the model with a' + ' specific `deterministic` value:\n\n' + ' model_view = nnx.view(model, deterministic=True/False)\n' + ), ) if (self.rate == 0.0) or deterministic: diff --git a/tests/nnx/bridge/module_test.py b/tests/nnx/bridge/module_test.py index 4f790e798..ce2c7d322 100644 --- a/tests/nnx/bridge/module_test.py +++ b/tests/nnx/bridge/module_test.py @@ -280,7 +280,7 @@ def test_pure_nnx_submodule(self): class NNXLayer(nnx.Module): def __init__(self, dim, dropout, rngs): self.linear = nnx.Linear(dim, dim, use_bias=False, rngs=rngs) - self.dropout = nnx.Dropout(dropout, rngs=rngs) + self.dropout = nnx.Dropout(dropout, deterministic=False, rngs=rngs) self.count = nnx.Intermediate(jnp.array([0.])) def __call__(self, x): # Required check to avoid state update in `init()`. Can this be avoided? diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py index b162655d9..6d687984d 100644 --- a/tests/nnx/bridge/wrappers_test.py +++ b/tests/nnx/bridge/wrappers_test.py @@ -270,7 +270,7 @@ def test_nnx_to_linen_multiple_rngs(self): class NNXInner(nnx.Module): def __init__(self, din, dout, rngs): self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout))) - self.dropout = nnx.Dropout(rate=0.5, rngs=rngs) + self.dropout = nnx.Dropout(rate=0.5, deterministic=False, rngs=rngs) def __call__(self, x): return self.dropout(x @ self.w[...]) @@ -423,7 +423,7 @@ def test_nnx_to_linen_pytree_structure_consistency(self): class NNXInner(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout))) - self.dropout = nnx.Dropout(rate=0.5, rngs=rngs) + self.dropout = nnx.Dropout(rate=0.5, deterministic=False, rngs=rngs) def __call__(self, x): return self.dropout(x @ self.w) @@ -476,7 +476,7 @@ def __init__(self, din, dout, dropout_rate, rngs): self.w = nnx.Param( nnx.with_partitioning(nnx.initializers.lecun_normal(), sharding=('in', 'out') )(rngs.params(), (din, dout))) - self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs) + self.dropout = nnx.Dropout(rate=dropout_rate, deterministic=False, rngs=rngs) def __call__(self, x): return self.dropout(x @ self.w) diff --git a/tests/nnx/integration_test.py b/tests/nnx/integration_test.py index 675662184..a679020cf 100644 --- a/tests/nnx/integration_test.py +++ b/tests/nnx/integration_test.py @@ -36,8 +36,8 @@ class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) - self.bn = nnx.BatchNorm(dmid, rngs=rngs) - self.dropout = nnx.Dropout(0.2, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs) + self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): @@ -177,7 +177,7 @@ def loss_fn(model: Model): new_model = nnx.graph.view(model, use_running_average=False) for _i in range(3): - train_step(model, x, y) + train_step(new_model, x, y) assert new_model.block1.linear is new_model.block2.linear assert new_model.block1.linear.bias is not None @@ -468,8 +468,8 @@ def test_example_mutable_arrays(self): class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) - self.bn = nnx.BatchNorm(dmid, rngs=rngs) - self.dropout = nnx.Dropout(0.2, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs) + self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): diff --git a/tests/nnx/nn/recurrent_test.py b/tests/nnx/nn/recurrent_test.py index 034baff2f..9747d4ac6 100644 --- a/tests/nnx/nn/recurrent_test.py +++ b/tests/nnx/nn/recurrent_test.py @@ -589,7 +589,7 @@ def __init__( **kwargs, ) self.recurrent_dropout = nnx.Dropout( - rate=dropout_rate, rng_collection='recurrent_dropout', rngs=rngs + rate=dropout_rate, deterministic=False, rng_collection='recurrent_dropout', rngs=rngs ) def __call__(self, carry, x): @@ -615,7 +615,7 @@ def __init__( dropout_rate=recurrent_dropout_rate, ) self.lstm = nnx.RNN(cell, broadcast_rngs='recurrent_dropout') - self.dropout = nnx.Dropout(dropout_rate, rngs=rngs) + self.dropout = nnx.Dropout(dropout_rate, deterministic=False, rngs=rngs) self.dense = nnx.Linear( in_features=hidden_features, out_features=1, rngs=rngs ) diff --git a/tests/nnx/nn/stochastic_test.py b/tests/nnx/nn/stochastic_test.py index 296776de0..0ef35735d 100644 --- a/tests/nnx/nn/stochastic_test.py +++ b/tests/nnx/nn/stochastic_test.py @@ -67,7 +67,7 @@ def test_dropout_rng_override(self): np.testing.assert_allclose(y1, y2) def test_dropout_arg_override(self): - m = nnx.Dropout(rate=0.5) + m = nnx.Dropout(rate=0.5, deterministic=False) x = jnp.ones((1, 10)) # deterministic call arg provided @@ -89,7 +89,7 @@ def test_dropout_arg_override(self): m(x) def test_dropout_arg_override_view(self): - m = nnx.Dropout(rate=0.5) + m = nnx.Dropout(rate=0.5, deterministic=False) x = jnp.ones((1, 10)) # deterministic call arg provided diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index 817198936..20e70179c 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -177,7 +177,7 @@ def test_reseed(self, graph): class Model(nnx.Module): def __init__(self, rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) - self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) def __call__(self, x): return self.dropout(self.linear(x)) diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 8680ae858..ecfce393a 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -259,8 +259,8 @@ def test_out_sharding_dropout(self): replicated_array = jnp.arange(8).reshape(2, 4).astype(jnp.float32) sharded_array = reshard(replicated_array, P("X", None)) layers = [ - nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0)), - nnx.Dropout(rate=0.5, broadcast_dims=(1,), rngs=nnx.Rngs(0)), + nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(0)), + nnx.Dropout(rate=0.5, deterministic=False, broadcast_dims=(1,), rngs=nnx.Rngs(0)), ] for layer in layers: assert 'float32[2@X,4]' in str(jax.typeof(layer(sharded_array))) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 55913b147..7808dbc3f 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -2874,7 +2874,7 @@ def __init__(self, input_size, hidden_size, rngs): self.linear = nnx.Linear( hidden_size + input_size, hidden_size, rngs=rngs ) - self.drop = nnx.Dropout(0.1, rngs=rngs) + self.drop = nnx.Dropout(0.1, deterministic=False, rngs=rngs) self.hidden_size = hidden_size def __call__(self, carry, x) -> tuple[jax.Array, jax.Array]: @@ -3381,7 +3381,7 @@ def test_state_axes_simple(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) - self.bn = nnx.BatchNorm(3, rngs=rngs) + self.bn = nnx.BatchNorm(3, use_running_average=False, rngs=rngs) self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: @@ -3415,7 +3415,7 @@ def test_split_rngs_decorator_simple(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) - self.bn = nnx.BatchNorm(3, rngs=rngs) + self.bn = nnx.BatchNorm(3, use_running_average=False, rngs=rngs) self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: @@ -3455,7 +3455,7 @@ def test_state_axes_super_simple(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) - self.bn = nnx.BatchNorm(3, rngs=rngs) + self.bn = nnx.BatchNorm(3, use_running_average=False, rngs=rngs) self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: @@ -3696,8 +3696,8 @@ def test_example(self): class Model(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) - self.dropout = nnx.Dropout(0.5, rngs=rngs) - self.bn = nnx.BatchNorm(dout, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs) def __call__(self, x): return nnx.relu(self.dropout(self.bn(self.linear(x))))