diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py
index 6346737ef..c7371afc0 100644
--- a/flax/nnx/nn/recurrent.py
+++ b/flax/nnx/nn/recurrent.py
@@ -193,7 +193,7 @@ def __init__(
     self.carry_init = carry_init
 
   def __call__(
-    self, carry: tuple[Array, Array], inputs: Array
+    self, carry: tuple[Array, Array], inputs: Array, *, out_sharding=None
   ) -> tuple[tuple[Array, Array], Array]:  # type: ignore[override]
     r"""A long short-term memory (LSTM) cell.
 
@@ -202,15 +202,16 @@ def __call__(
         initialized using ``LSTMCell.initialize_carry``.
       inputs: an ndarray with the input for the current time step.
         All dimensions except the final are considered batch dimensions.
+      out_sharding: sharding to apply to the output of each Linear layer.
 
     Returns:
       A tuple with the new carry and the output.
     """
     c, h = carry
-    i = self.gate_fn(self.ii(inputs) + self.hi(h))
-    f = self.gate_fn(self.if_(inputs) + self.hf(h))
-    g = self.activation_fn(self.ig(inputs) + self.hg(h))
-    o = self.gate_fn(self.io(inputs) + self.ho(h))
+    i = self.gate_fn(self.ii(inputs, out_sharding=out_sharding) + self.hi(h, out_sharding=out_sharding))
+    f = self.gate_fn(self.if_(inputs, out_sharding=out_sharding) + self.hf(h, out_sharding=out_sharding))
+    g = self.activation_fn(self.ig(inputs, out_sharding=out_sharding) + self.hg(h, out_sharding=out_sharding))
+    o = self.gate_fn(self.io(inputs, out_sharding=out_sharding) + self.ho(h, out_sharding=out_sharding))
     new_c = f * c + i * g
     new_h = o * self.activation_fn(new_c)
     return (new_c, new_h), new_h
@@ -371,7 +372,7 @@ def __init__(
     self.carry_init = carry_init
 
   def __call__(
-    self, carry: tuple[Array, Array], inputs: Array
+    self, carry: tuple[Array, Array], inputs: Array, *, out_sharding=None
   ) -> tuple[tuple[Array, Array], Array]:  # type: ignore[override]
     r"""An optimized long short-term memory (LSTM) cell.
 
@@ -380,6 +381,8 @@ def __call__(
         ``LSTMCell.initialize_carry``.
       inputs: an ndarray with the input for the current time step.
         All dimensions except the final are considered batch dimensions.
+      out_sharding: sharding to apply to the output of each Linear layer,
+        i.e. to the combined projections before they are split into gates.
 
     Returns:
       A tuple with the new carry and the output.
@@ -387,7 +390,7 @@ def __call__(
     c, h = carry
 
     # Compute combined transformations for inputs and hidden state
-    y = self.dense_i(inputs) + self.dense_h(h)
+    y = self.dense_i(inputs, out_sharding=out_sharding) + self.dense_h(h, out_sharding=out_sharding)
 
     # Split the combined transformations into individual gates
     i, f, g, o = jnp.split(y, indices_or_sections=4, axis=-1)
@@ -534,8 +537,20 @@ def __init__(
 
     self.carry_init = carry_init
 
-  def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]:  # type: ignore[override]
-    new_carry = self.dense_i(inputs) + self.dense_h(carry)
+  def __call__(self, carry: Array, inputs: Array, *, out_sharding=None) -> tuple[Array, Array]:  # type: ignore[override]
+    """Simple RNN cell.
+
+    Args:
+      carry: the hidden state of the RNN cell,
+        initialized using ``SimpleCell.initialize_carry``.
+      inputs: an ndarray with the input for the current time step.
+        All dimensions except the final are considered batch dimensions.
+      out_sharding: sharding to apply to the output of each Linear layer.
+
+    Returns:
+      A tuple with the new carry and the output.
+    """
+    new_carry = self.dense_i(inputs, out_sharding=out_sharding) + self.dense_h(carry, out_sharding=out_sharding)
     if self.residual:
       new_carry += carry
     new_carry = self.activation_fn(new_carry)
@@ -691,7 +706,7 @@ def __init__(
 
     self.carry_init = carry_init
 
-  def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]:  # type: ignore[override]
+  def __call__(self, carry: Array, inputs: Array, *, out_sharding=None) -> tuple[Array, Array]:  # type: ignore[override]
     """Gated recurrent unit (GRU) cell.
 
     Args:
@@ -699,6 +714,8 @@ def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type:
         initialized using ``GRUCell.initialize_carry``.
       inputs: an ndarray with the input for the current time step.
         All dimensions except the final are considered batch dimensions.
+      out_sharding: sharding to apply to the output of each Linear layer,
+        i.e. to the combined projections before they are split.
 
     Returns:
       A tuple with the new carry and the output.
@@ -706,8 +723,8 @@ def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type:
     h = carry
 
     # Compute combined transformations for inputs and hidden state
-    x_transformed = self.dense_i(inputs)
-    h_transformed = self.dense_h(h)
+    x_transformed = self.dense_i(inputs, out_sharding=out_sharding)
+    h_transformed = self.dense_h(h, out_sharding=out_sharding)
 
     # Split the combined transformations into individual components
     xi_r, xi_z, xi_n = jnp.split(x_transformed, 3, axis=-1)
diff --git a/tests/nnx/nn/recurrent_test.py b/tests/nnx/nn/recurrent_test.py
index 034baff2f..d0d296dc1 100644
--- a/tests/nnx/nn/recurrent_test.py
+++ b/tests/nnx/nn/recurrent_test.py
@@ -642,5 +642,30 @@ def __call__(self, x):
     self.assertEqual(model.lstm.cell.recurrent_dropout.rngs.count[...], 1)
 
 
+class TestCellSharding(absltest.TestCase):
+  def test_out_sharding_signature(self):
+    """All cells accept ``out_sharding``; ``None`` matches the default call."""
+    rngs = nnx.Rngs(0)
+
+    cell_types = [
+      (nnx.SimpleCell, {'in_features': 4, 'hidden_features': 4}),
+      (nnx.LSTMCell, {'in_features': 4, 'hidden_features': 4}),
+      (nnx.OptimizedLSTMCell, {'in_features': 4, 'hidden_features': 4}),
+      (nnx.GRUCell, {'in_features': 4, 'hidden_features': 4}),
+    ]
+
+    for cell_cls, kwargs in cell_types:
+      with self.subTest(cell_cls=cell_cls.__name__):
+        model = cell_cls(**kwargs, rngs=rngs)
+        carry = model.initialize_carry((1, 4), rngs=rngs)
+        x = jnp.ones((1, 4))
+        _, expected = model(carry, x)
+        # Passing the default explicitly must not change the result.
+        out = model(carry, x, out_sharding=None)
+        self.assertLen(out, 2)  # All cells return (new_carry, output)
+        self.assertEqual(out[1].shape, (1, 4))
+        self.assertTrue(jnp.allclose(out[1], expected))
+
+
 if __name__ == '__main__':
   absltest.main()