Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs_nnx/guides/demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ from flax import nnx
class Block(nnx.Module):
def __init__(self, din, dout, *, rngs):
self.linear = nnx.Linear(din, dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs)

def __call__(self, x):
return nnx.relu(self.bn(self.linear(x)))
Expand Down
4 changes: 2 additions & 2 deletions docs_nnx/guides/performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import optax
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs)
self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

def __call__(self, x):
Expand Down
4 changes: 2 additions & 2 deletions docs_nnx/guides/pytree.md
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,8 @@ NNX Modules are `Pytree`s that have two additional methods for tracking intermedi
class Block(nnx.Module):
def __init__(self, din: int, dout: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dout, rngs=rngs)
self.bn = nnx.BatchNorm(dout, rngs=rngs)
self.dropout = nnx.Dropout(0.1, rngs=rngs)
self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs)
self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs)

def __call__(self, x):
y = nnx.relu(self.dropout(self.bn(self.linear(x))))
Expand Down
10 changes: 5 additions & 5 deletions docs_nnx/guides/randomness.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ from flax import nnx
class Model(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1)
self.drop = nnx.Dropout(0.1, deterministic=False)

def __call__(self, x, *, rngs):
return nnx.relu(self.drop(self.linear(x), rngs=rngs))
Expand Down Expand Up @@ -90,7 +90,7 @@ Specifically, this will use the RngStream `rngs.params` for weight initialization
The `nnx.Dropout` module also requires a random state, but it requires this state at *call* time rather than initialization. Once again, we can pass it random state using the `rngs` keyword argument.

```{code-cell} ipython3
dropout = nnx.Dropout(0.5)
dropout = nnx.Dropout(0.5, deterministic=False)
```

```{code-cell} ipython3
Expand Down Expand Up @@ -159,7 +159,7 @@ Say you want to train a model that uses dropout on a batch of data. You don't wa
class Model(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1)
self.drop = nnx.Dropout(0.1, deterministic=False)

def __call__(self, x, rngs):
return nnx.relu(self.drop(self.linear(x), rngs=rngs))
Expand Down Expand Up @@ -199,7 +199,7 @@ So far, we have looked at passing random state directly to each Module when it g
class Model(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs)
self.drop = nnx.Dropout(0.1, deterministic=False, rngs=rngs)

def __call__(self, x):
return nnx.relu(self.drop(self.linear(x)))
Expand Down Expand Up @@ -296,7 +296,7 @@ class Count(nnx.Variable): pass
class RNNCell(nnx.Module):
def __init__(self, din, dout, rngs):
self.linear = nnx.Linear(dout + din, dout, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')
self.drop = nnx.Dropout(0.1, deterministic=False, rngs=rngs, rng_collection='recurrent_dropout')
self.dout = dout
self.count = Count(jnp.array(0, jnp.uint32))

Expand Down
4 changes: 2 additions & 2 deletions docs_nnx/hijax/hijax.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ print("mutable =", v_ref)
class Block(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.1, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs)
self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs)
self.linear_out = Linear(dmid, dout, rngs=rngs)

def __call__(self, x):
Expand Down
4 changes: 2 additions & 2 deletions docs_nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ The example below shows how to define a simple `MLP` by subclassing `Module`. Th
class MLP(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
self.linear1 = Linear(din, dmid, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.1)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.1, deterministic=False)
self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs)
self.linear2 = Linear(dmid, dout, rngs=rngs)

def __call__(self, x: jax.Array, rngs: nnx.Rngs):
Expand Down
8 changes: 4 additions & 4 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,12 +444,12 @@ def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool =
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... self.dropout = nnx.Dropout(0.5, deterministic=False)
... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
... self.dropout = nnx.Dropout(0.5)
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
(None, None)
>>> new_block = nnx.view(block, deterministic=True, use_running_average=True)
>>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
(True, True)
Expand All @@ -459,7 +459,7 @@ def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool =
>>> new_block = nnx.view(block, only=nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
(True, False)
(True, None)

Args:
node: the object to create a copy of.
Expand Down
15 changes: 12 additions & 3 deletions flax/nnx/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def __init__(
self,
num_features: int,
*,
use_running_average: bool = False,
use_running_average: bool | None = None,
axis: int = -1,
momentum: float = 0.99,
epsilon: float = 1e-5,
Expand Down Expand Up @@ -364,8 +364,17 @@ def __call__(
use_running_average = first_from(
use_running_average,
self.use_running_average,
error_msg="""No `use_running_average` argument was provided to BatchNorm
as either a __call__ argument, class attribute, or nnx.flag.""",
error_msg=(
'No `use_running_average` argument was provided to BatchNorm.'
' Consider one of the following options:\n\n'
'1. Pass `use_running_average` to the BatchNorm constructor:\n\n'
' self.bn = nnx.BatchNorm(..., use_running_average=True/False)\n\n'
'2. Pass `use_running_average` to the BatchNorm __call__:\n\n'
' self.bn(x, use_running_average=True/False)\n\n'
'3. Use `nnx.view` to create a view of the model with a'
' specific `use_running_average` value:\n\n'
' model_view = nnx.view(model, use_running_average=True/False)\n'
),
)
feature_axes = _canonicalize_axes(x.ndim, self.axis)
reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
Expand Down
15 changes: 12 additions & 3 deletions flax/nnx/nn/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
rate: float,
*,
broadcast_dims: Sequence[int] = (),
deterministic: bool = False,
deterministic: bool | None = None,
rng_collection: str = 'dropout',
rngs: rnglib.Rngs | rnglib.RngStream | None = None,
):
Expand Down Expand Up @@ -117,8 +117,17 @@ def __call__(
deterministic = first_from(
deterministic,
self.deterministic,
error_msg="""No `deterministic` argument was provided to Dropout
as either a __call__ argument or class attribute""",
error_msg=(
'No `deterministic` argument was provided to Dropout.'
' Consider one of the following options:\n\n'
'1. Pass `deterministic` to the Dropout constructor:\n\n'
' self.dropout = nnx.Dropout(..., deterministic=True/False)\n\n'
'2. Pass `deterministic` to the Dropout __call__:\n\n'
' self.dropout(x, deterministic=True/False)\n\n'
'3. Use `nnx.view` to create a view of the model with a'
' specific `deterministic` value:\n\n'
' model_view = nnx.view(model, deterministic=True/False)\n'
),
)

if (self.rate == 0.0) or deterministic:
Expand Down
2 changes: 1 addition & 1 deletion tests/nnx/bridge/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def test_pure_nnx_submodule(self):
class NNXLayer(nnx.Module):
def __init__(self, dim, dropout, rngs):
self.linear = nnx.Linear(dim, dim, use_bias=False, rngs=rngs)
self.dropout = nnx.Dropout(dropout, rngs=rngs)
self.dropout = nnx.Dropout(dropout, deterministic=False, rngs=rngs)
self.count = nnx.Intermediate(jnp.array([0.]))
def __call__(self, x):
# Required check to avoid state update in `init()`. Can this be avoided?
Expand Down
6 changes: 3 additions & 3 deletions tests/nnx/bridge/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def test_nnx_to_linen_multiple_rngs(self):
class NNXInner(nnx.Module):
def __init__(self, din, dout, rngs):
self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.5, deterministic=False, rngs=rngs)
def __call__(self, x):
return self.dropout(x @ self.w[...])

Expand Down Expand Up @@ -423,7 +423,7 @@ def test_nnx_to_linen_pytree_structure_consistency(self):
class NNXInner(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.5, deterministic=False, rngs=rngs)
def __call__(self, x):
return self.dropout(x @ self.w)

Expand Down Expand Up @@ -476,7 +476,7 @@ def __init__(self, din, dout, dropout_rate, rngs):
self.w = nnx.Param(
nnx.with_partitioning(nnx.initializers.lecun_normal(), sharding=('in', 'out')
)(rngs.params(), (din, dout)))
self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
self.dropout = nnx.Dropout(rate=dropout_rate, deterministic=False, rngs=rngs)
def __call__(self, x):
return self.dropout(x @ self.w)

Expand Down
10 changes: 5 additions & 5 deletions tests/nnx/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class Model(nnx.Module):

def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs)
self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

def __call__(self, x):
Expand Down Expand Up @@ -177,7 +177,7 @@ def loss_fn(model: Model):
new_model = nnx.graph.view(model, use_running_average=False)

for _i in range(3):
train_step(model, x, y)
train_step(new_model, x, y)

assert new_model.block1.linear is new_model.block2.linear
assert new_model.block1.linear.bias is not None
Expand Down Expand Up @@ -468,8 +468,8 @@ def test_example_mutable_arrays(self):
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, use_running_average=False, rngs=rngs)
self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

def __call__(self, x):
Expand Down
4 changes: 2 additions & 2 deletions tests/nnx/nn/recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def __init__(
**kwargs,
)
self.recurrent_dropout = nnx.Dropout(
rate=dropout_rate, rng_collection='recurrent_dropout', rngs=rngs
rate=dropout_rate, deterministic=False, rng_collection='recurrent_dropout', rngs=rngs
)

def __call__(self, carry, x):
Expand All @@ -615,7 +615,7 @@ def __init__(
dropout_rate=recurrent_dropout_rate,
)
self.lstm = nnx.RNN(cell, broadcast_rngs='recurrent_dropout')
self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)
self.dropout = nnx.Dropout(dropout_rate, deterministic=False, rngs=rngs)
self.dense = nnx.Linear(
in_features=hidden_features, out_features=1, rngs=rngs
)
Expand Down
4 changes: 2 additions & 2 deletions tests/nnx/nn/stochastic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_dropout_rng_override(self):
np.testing.assert_allclose(y1, y2)

def test_dropout_arg_override(self):
m = nnx.Dropout(rate=0.5)
m = nnx.Dropout(rate=0.5, deterministic=False)
x = jnp.ones((1, 10))

# deterministic call arg provided
Expand All @@ -89,7 +89,7 @@ def test_dropout_arg_override(self):
m(x)

def test_dropout_arg_override_view(self):
m = nnx.Dropout(rate=0.5)
m = nnx.Dropout(rate=0.5, deterministic=False)
x = jnp.ones((1, 10))

# deterministic call arg provided
Expand Down
2 changes: 1 addition & 1 deletion tests/nnx/rngs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_reseed(self, graph):
class Model(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(2, 3, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs)

def __call__(self, x):
return self.dropout(self.linear(x))
Expand Down
4 changes: 2 additions & 2 deletions tests/nnx/spmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ def test_out_sharding_dropout(self):
replicated_array = jnp.arange(8).reshape(2, 4).astype(jnp.float32)
sharded_array = reshard(replicated_array, P("X", None))
layers = [
nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0)),
nnx.Dropout(rate=0.5, broadcast_dims=(1,), rngs=nnx.Rngs(0)),
nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(0)),
nnx.Dropout(rate=0.5, deterministic=False, broadcast_dims=(1,), rngs=nnx.Rngs(0)),
]
for layer in layers:
assert 'float32[2@X,4]' in str(jax.typeof(layer(sharded_array)))
Expand Down
12 changes: 6 additions & 6 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2874,7 +2874,7 @@ def __init__(self, input_size, hidden_size, rngs):
self.linear = nnx.Linear(
hidden_size + input_size, hidden_size, rngs=rngs
)
self.drop = nnx.Dropout(0.1, rngs=rngs)
self.drop = nnx.Dropout(0.1, deterministic=False, rngs=rngs)
self.hidden_size = hidden_size

def __call__(self, carry, x) -> tuple[jax.Array, jax.Array]:
Expand Down Expand Up @@ -3381,7 +3381,7 @@ def test_state_axes_simple(self):
class Block(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(2, 3, rngs=rngs)
self.bn = nnx.BatchNorm(3, rngs=rngs)
self.bn = nnx.BatchNorm(3, use_running_average=False, rngs=rngs)
self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs)

def __call__(self, x: jax.Array) -> jax.Array:
Expand Down Expand Up @@ -3415,7 +3415,7 @@ def test_split_rngs_decorator_simple(self):
class Block(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(2, 3, rngs=rngs)
self.bn = nnx.BatchNorm(3, rngs=rngs)
self.bn = nnx.BatchNorm(3, use_running_average=False, rngs=rngs)
self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs)

def __call__(self, x: jax.Array) -> jax.Array:
Expand Down Expand Up @@ -3455,7 +3455,7 @@ def test_state_axes_super_simple(self):
class Block(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(2, 3, rngs=rngs)
self.bn = nnx.BatchNorm(3, rngs=rngs)
self.bn = nnx.BatchNorm(3, use_running_average=False, rngs=rngs)
self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs)

def __call__(self, x: jax.Array) -> jax.Array:
Expand Down Expand Up @@ -3696,8 +3696,8 @@ def test_example(self):
class Model(nnx.Module):
def __init__(self, din, dout, *, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dout, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
self.bn = nnx.BatchNorm(dout, rngs=rngs)
self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs)
self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs)

def __call__(self, x):
return nnx.relu(self.dropout(self.bn(self.linear(x))))
Expand Down
Loading