From 1cce37cfdf28d91d65f970b0606d6ffdfda2bec1 Mon Sep 17 00:00:00 2001 From: Jeff Carpenter Date: Wed, 22 Apr 2026 22:45:54 -0700 Subject: [PATCH] Fix remat failures in offload.py PiperOrigin-RevId: 904235597 --- parallax/offload.py | 26 +++++++++++++++++++------- parallax/offload_test.py | 4 ++-- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/parallax/offload.py b/parallax/offload.py index 4317b59..e0b14f3 100644 --- a/parallax/offload.py +++ b/parallax/offload.py @@ -16,7 +16,7 @@ import functools import inspect -from typing import Callable, Protocol, runtime_checkable +from typing import Any, Callable, Protocol, runtime_checkable from flax import nnx import jax @@ -172,18 +172,30 @@ def compute_vjp(state_primal, input_primal, cotangent): return final_grads_state +class RemattedLayer(nnx.Module): + + def __init__(self, layer: nnx.Module, rematted_call: Callable[..., Any]): + self.layer = layer + self.rematted_call = rematted_call + + def __call__(self, *args, **kwargs): + return self.rematted_call(self.layer, *args, **kwargs) + + def __getattr__(self, name): + return getattr(self.layer, name) + + def remat_model(model: nnx.Module) -> nnx.Module: """Takes an NNX Module and returns one with all layers rematerialized.""" # TODO(jeffcarp): Generalize this to work with non-Sequential models. new_model = nnx.clone(model) for i, layer in enumerate(new_model.layers): signature = inspect.signature(layer.__call__) - # Assumes `input` will always be the first parameter of the Module. - static_argnums = tuple(range(1, len(signature.parameters))) - new_model.layers[i] = nnx.remat( - new_model.layers[i], - static_argnums=static_argnums, - ) + unbound_call = layer.__class__.__call__ + # Shift static_argnums by 1 because of `self` in unbound call. + static_argnums = tuple(i + 1 for i in range(1, len(signature.parameters))) + rematted_call = nnx.remat(unbound_call, static_argnums=static_argnums) + new_model.layers[i] = RemattedLayer(layer, rematted_call) return new_model diff --git a/parallax/offload_test.py b/parallax/offload_test.py index 615e121..fb22f32 100644 --- a/parallax/offload_test.py +++ b/parallax/offload_test.py @@ -100,8 +100,8 @@ def loss_fn(model, inputs): np.testing.assert_array_equal(act_loss, exp_loss) # Assert all model weights match after gradient update. np.testing.assert_allclose( - actual_model.layers[0].kernel.value, - reference_model.layers[0].kernel.value, + actual_model.layers[0].kernel[...], + reference_model.layers[0].kernel[...], atol=1e-3, rtol=1e-1, )