Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions flax/nnx/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
self.carry_init = carry_init

def __call__(
self, carry: tuple[Array, Array], inputs: Array
self, carry: tuple[Array, Array], inputs: Array, *, out_sharding=None
) -> tuple[tuple[Array, Array], Array]: # type: ignore[override]
r"""A long short-term memory (LSTM) cell.

Expand All @@ -202,15 +202,16 @@ def __call__(
initialized using ``LSTMCell.initialize_carry``.
inputs: an ndarray with the input for the current time step.
All dimensions except the final are considered batch dimensions.
out_sharding: sharding to apply to the output of each Linear layer.

Returns:
A tuple with the new carry and the output.
"""
c, h = carry
i = self.gate_fn(self.ii(inputs) + self.hi(h))
f = self.gate_fn(self.if_(inputs) + self.hf(h))
g = self.activation_fn(self.ig(inputs) + self.hg(h))
o = self.gate_fn(self.io(inputs) + self.ho(h))
i = self.gate_fn(self.ii(inputs, out_sharding=out_sharding) + self.hi(h, out_sharding=out_sharding))
f = self.gate_fn(self.if_(inputs, out_sharding=out_sharding) + self.hf(h, out_sharding=out_sharding))
g = self.activation_fn(self.ig(inputs, out_sharding=out_sharding) + self.hg(h, out_sharding=out_sharding))
o = self.gate_fn(self.io(inputs, out_sharding=out_sharding) + self.ho(h, out_sharding=out_sharding))
new_c = f * c + i * g
new_h = o * self.activation_fn(new_c)
return (new_c, new_h), new_h
Expand Down Expand Up @@ -371,7 +372,7 @@ def __init__(
self.carry_init = carry_init

def __call__(
self, carry: tuple[Array, Array], inputs: Array
self, carry: tuple[Array, Array], inputs: Array, *, out_sharding=None
) -> tuple[tuple[Array, Array], Array]: # type: ignore[override]
r"""An optimized long short-term memory (LSTM) cell.

Expand All @@ -380,14 +381,15 @@ def __call__(
``OptimizedLSTMCell.initialize_carry``.
inputs: an ndarray with the input for the current time step.
All dimensions except the final are considered batch dimensions.
out_sharding: sharding to apply to the output of each Linear layer.

Returns:
A tuple with the new carry and the output.
"""
c, h = carry

# Compute combined transformations for inputs and hidden state
y = self.dense_i(inputs) + self.dense_h(h)
y = self.dense_i(inputs, out_sharding=out_sharding) + self.dense_h(h, out_sharding=out_sharding)

# Split the combined transformations into individual gates
i, f, g, o = jnp.split(y, indices_or_sections=4, axis=-1)
Expand Down Expand Up @@ -534,8 +536,20 @@ def __init__(

self.carry_init = carry_init

def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override]
new_carry = self.dense_i(inputs) + self.dense_h(carry)
def __call__(self, carry: Array, inputs: Array, *, out_sharding=None) -> tuple[Array, Array]: # type: ignore[override]
"""Simple RNN cell.

Args:
carry: the hidden state of the RNN cell,
initialized using ``SimpleCell.initialize_carry``.
inputs: an ndarray with the input for the current time step.
All dimensions except the final are considered batch dimensions.
out_sharding: sharding to apply to the output of each Linear layer.

Returns:
A tuple with the new carry and the output.
"""
new_carry = self.dense_i(inputs, out_sharding=out_sharding) + self.dense_h(carry, out_sharding=out_sharding)
if self.residual:
new_carry += carry
new_carry = self.activation_fn(new_carry)
Expand Down Expand Up @@ -691,23 +705,24 @@ def __init__(

self.carry_init = carry_init

def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override]
def __call__(self, carry: Array, inputs: Array, *, out_sharding=None) -> tuple[Array, Array]: # type: ignore[override]
"""Gated recurrent unit (GRU) cell.

Args:
carry: the hidden state of the GRU cell,
initialized using ``GRUCell.initialize_carry``.
inputs: an ndarray with the input for the current time step.
All dimensions except the final are considered batch dimensions.
out_sharding: sharding to apply to the output of each Linear layer.

Returns:
A tuple with the new carry and the output.
"""
h = carry

# Compute combined transformations for inputs and hidden state
x_transformed = self.dense_i(inputs)
h_transformed = self.dense_h(h)
x_transformed = self.dense_i(inputs, out_sharding=out_sharding)
h_transformed = self.dense_h(h, out_sharding=out_sharding)

# Split the combined transformations into individual components
xi_r, xi_z, xi_n = jnp.split(x_transformed, 3, axis=-1)
Expand Down
21 changes: 21 additions & 0 deletions tests/nnx/nn/recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,5 +642,26 @@ def __call__(self, x):
self.assertEqual(model.lstm.cell.recurrent_dropout.rngs.count[...], 1)


class TestCellSharding(absltest.TestCase):
  def test_out_sharding_signature(self):
    """Every RNN cell's ``__call__`` must accept an explicit ``out_sharding`` kwarg."""
    rngs = nnx.Rngs(0)

    # All four cells share the same constructor arguments in this test.
    common_kwargs = {'in_features': 4, 'hidden_features': 4}
    cell_classes = (
      nnx.SimpleCell,
      nnx.LSTMCell,
      nnx.OptimizedLSTMCell,
      nnx.GRUCell,
    )

    for cell_cls in cell_classes:
      with self.subTest(cell_cls=cell_cls.__name__):
        cell = cell_cls(**common_kwargs, rngs=rngs)
        init_carry = cell.initialize_carry((1, 4), rngs=rngs)
        step_input = jnp.ones((1, 4))
        # Passing out_sharding=None must be accepted without error.
        result = cell(init_carry, step_input, out_sharding=None)
        # Every cell returns a (new_carry, output) pair (or equivalent 2-element structure).
        self.assertLen(result, 2)


if __name__ == '__main__':
absltest.main()