From fcca8e943e4ea66f282c935ab9e9b54fb322aff3 Mon Sep 17 00:00:00 2001 From: aarushisingh04 Date: Sun, 15 Feb 2026 15:47:29 +0530 Subject: [PATCH 1/2] feat(nnx): Add out_sharding to recurrent cell __call__ methods --- flax/nnx/nn/recurrent.py | 40 ++++++++++++++++++++++++---------- tests/nnx/nn/recurrent_test.py | 21 ++++++++++++++++++ 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py index 6346737ef..d7989c7c7 100644 --- a/flax/nnx/nn/recurrent.py +++ b/flax/nnx/nn/recurrent.py @@ -193,7 +193,7 @@ def __init__( self.carry_init = carry_init def __call__( - self, carry: tuple[Array, Array], inputs: Array + self, carry: tuple[Array, Array], inputs: Array, *, out_sharding=None ) -> tuple[tuple[Array, Array], Array]: # type: ignore[override] r"""A long short-term memory (LSTM) cell. @@ -202,15 +202,16 @@ def __call__( initialized using ``LSTMCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. + out_sharding: sharding to apply to the output of each Linear layer. Returns: A tuple with the new carry and the output. """ c, h = carry - i = self.gate_fn(self.ii(inputs) + self.hi(h)) - f = self.gate_fn(self.if_(inputs) + self.hf(h)) - g = self.activation_fn(self.ig(inputs) + self.hg(h)) - o = self.gate_fn(self.io(inputs) + self.ho(h)) + i = self.gate_fn(self.ii(inputs, out_sharding=out_sharding) + self.hi(h, out_sharding=out_sharding)) + f = self.gate_fn(self.if_(inputs, out_sharding=out_sharding) + self.hf(h, out_sharding=out_sharding)) + g = self.activation_fn(self.ig(inputs, out_sharding=out_sharding) + self.hg(h, out_sharding=out_sharding)) + o = self.gate_fn(self.io(inputs, out_sharding=out_sharding) + self.ho(h, out_sharding=out_sharding)) new_c = f * c + i * g new_h = o * self.activation_fn(new_c) return (new_c, new_h), new_h @@ -371,7 +372,7 @@ def __init__( self.carry_init = carry_init def __call__( - self, carry: tuple[Array, Array], inputs: Array + self, carry: tuple[Array, Array], inputs: Array, *, out_sharding=None ) -> tuple[tuple[Array, Array], Array]: # type: ignore[override] r"""An optimized long short-term memory (LSTM) cell. @@ -380,6 +381,7 @@ def __call__( ``LSTMCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. + out_sharding: sharding to apply to the output of each Linear layer. Returns: A tuple with the new carry and the output. @@ -387,7 +389,7 @@ def __call__( c, h = carry # Compute combined transformations for inputs and hidden state - y = self.dense_i(inputs) + self.dense_h(h) + y = self.dense_i(inputs, out_sharding=out_sharding) + self.dense_h(h, out_sharding=out_sharding) # Split the combined transformations into individual gates i, f, g, o = jnp.split(y, indices_or_sections=4, axis=-1) @@ -534,8 +536,21 @@ def __init__( self.carry_init = carry_init - def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override] - new_carry = self.dense_i(inputs) + self.dense_h(carry) + def __call__(self, carry: Array, inputs: Array, *, out_sharding=None) -> tuple[Array, Array]: # type: ignore[override] + """Simple RNN cell. + + Args: + carry: the hidden state of the RNN cell, + initialized using ``SimpleCell.initialize_carry``. + inputs: an ndarray with the input for the current time step. + All dimensions except the final are considered batch dimensions. + out_sharding: the sharding of the output. If None, the output is not + sharded. + + Returns: + A tuple with the new carry and the output. + """ + new_carry = self.dense_i(inputs, out_sharding=out_sharding) + self.dense_h(carry, out_sharding=out_sharding) if self.residual: new_carry += carry new_carry = self.activation_fn(new_carry) @@ -691,7 +706,7 @@ def __init__( self.carry_init = carry_init - def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override] + def __call__(self, carry: Array, inputs: Array, *, out_sharding=None) -> tuple[Array, Array]: # type: ignore[override] """Gated recurrent unit (GRU) cell. Args: @@ -699,6 +714,7 @@ def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: initialized using ``GRUCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. + out_sharding: sharding to apply to the output of each Linear layer. Returns: A tuple with the new carry and the output. @@ -706,8 +722,8 @@ def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: h = carry # Compute combined transformations for inputs and hidden state - x_transformed = self.dense_i(inputs) - h_transformed = self.dense_h(h) + x_transformed = self.dense_i(inputs, out_sharding=out_sharding) + h_transformed = self.dense_h(h, out_sharding=out_sharding) # Split the combined transformations into individual components xi_r, xi_z, xi_n = jnp.split(x_transformed, 3, axis=-1) diff --git a/tests/nnx/nn/recurrent_test.py b/tests/nnx/nn/recurrent_test.py index 034baff2f..d0d296dc1 100644 --- a/tests/nnx/nn/recurrent_test.py +++ b/tests/nnx/nn/recurrent_test.py @@ -642,5 +642,26 @@ def __call__(self, x): self.assertEqual(model.lstm.cell.recurrent_dropout.rngs.count[...], 1) +class TestCellSharding(absltest.TestCase): + def test_out_sharding_signature(self): + rngs = nnx.Rngs(0) + + cell_types = [ + (nnx.SimpleCell, {'in_features': 4, 'hidden_features': 4}), + (nnx.LSTMCell, {'in_features': 4, 'hidden_features': 4}), + (nnx.OptimizedLSTMCell, {'in_features': 4, 'hidden_features': 4}), + (nnx.GRUCell, {'in_features': 4, 'hidden_features': 4}), + ] + + for cell_cls, kwargs in cell_types: + with self.subTest(cell_cls=cell_cls.__name__): + model = cell_cls(**kwargs, rngs=rngs) + carry = model.initialize_carry((1, 4), rngs=rngs) + x = jnp.ones((1, 4)) + # Just verify it accepts out_sharding=None without error + out = model(carry, x, out_sharding=None) + self.assertLen(out, 2) # All cells return (new_carry, output) or equivalent structure + + if __name__ == '__main__': absltest.main() From 331ff8e5d7803bf24081cf800697da2f06e83a8f Mon Sep 17 00:00:00 2001 From: Aarushi Singh <110608667+aarushisingh04@users.noreply.github.com> Date: Sun, 15 Feb 2026 20:43:44 +0530 Subject: [PATCH 2/2] Update flax/nnx/nn/recurrent.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- flax/nnx/nn/recurrent.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py index d7989c7c7..c7371afc0 100644 --- a/flax/nnx/nn/recurrent.py +++ b/flax/nnx/nn/recurrent.py @@ -544,8 +544,7 @@ def __call__(self, carry: Array, inputs: Array, *, out_sharding=None) -> tuple[A initialized using ``SimpleCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. - out_sharding: the sharding of the output. If None, the output is not - sharded. + out_sharding: sharding to apply to the output of each Linear layer. Returns: A tuple with the new carry and the output.