diff --git a/parallax/offload.py b/parallax/offload.py index 4317b59..e0b14f3 100644 --- a/parallax/offload.py +++ b/parallax/offload.py @@ -16,7 +16,7 @@ import functools import inspect -from typing import Callable, Protocol, runtime_checkable +from typing import Any, Callable, Protocol, runtime_checkable from flax import nnx import jax @@ -172,18 +172,30 @@ def compute_vjp(state_primal, input_primal, cotangent): return final_grads_state +class RemattedLayer(nnx.Module): + + def __init__(self, layer: nnx.Module, rematted_call: Callable[..., Any]): + self.layer = layer + self.rematted_call = rematted_call + + def __call__(self, *args, **kwargs): + return self.rematted_call(self.layer, *args, **kwargs) + + def __getattr__(self, name): + return getattr(self.layer, name) + + def remat_model(model: nnx.Module) -> nnx.Module: """Takes an NNX Module and returns one with all layers rematerialized.""" # TODO(jeffcarp): Generalize this to work with non-Sequential models. new_model = nnx.clone(model) for i, layer in enumerate(new_model.layers): signature = inspect.signature(layer.__call__) - # Assumes `input` will always be the first parameter of the Module. - static_argnums = tuple(range(1, len(signature.parameters))) - new_model.layers[i] = nnx.remat( - new_model.layers[i], - static_argnums=static_argnums, - ) + unbound_call = layer.__class__.__call__ + # Shift static_argnums by 1 because of `self` in unbound call. + static_argnums = tuple(i + 1 for i in range(1, len(signature.parameters))) + rematted_call = nnx.remat(unbound_call, static_argnums=static_argnums) + new_model.layers[i] = RemattedLayer(layer, rematted_call) return new_model diff --git a/parallax/offload_test.py b/parallax/offload_test.py index 615e121..fb22f32 100644 --- a/parallax/offload_test.py +++ b/parallax/offload_test.py @@ -100,8 +100,8 @@ def loss_fn(model, inputs): np.testing.assert_array_equal(act_loss, exp_loss) # Assert all model weights match after gradient update. np.testing.assert_allclose( - actual_model.layers[0].kernel.value, - reference_model.layers[0].kernel.value, + actual_model.layers[0].kernel[...], + reference_model.layers[0].kernel[...], atol=1e-3, rtol=1e-1, )