Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions parallax/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import functools
import inspect
from typing import Callable, Protocol, runtime_checkable
from typing import Any, Callable, Protocol, runtime_checkable

from flax import nnx
import jax
Expand Down Expand Up @@ -172,18 +172,30 @@ def compute_vjp(state_primal, input_primal, cotangent):
return final_grads_state


class RemattedLayer(nnx.Module):
  """Drop-in wrapper that routes a layer's calls through a rematted function.

  `rematted_call` is expected to be the layer class's *unbound* `__call__`
  wrapped with `nnx.remat`, so the wrapped layer must be passed explicitly
  in the `self` slot on every invocation. Attribute access (e.g.
  `model.layers[0].kernel`) is forwarded to the wrapped layer so existing
  call sites keep working unchanged.
  """

  def __init__(self, layer: nnx.Module, rematted_call: Callable[..., Any]):
    self.layer = layer
    self.rematted_call = rematted_call

  def __call__(self, *args, **kwargs):
    # `rematted_call` is unbound, so the wrapped layer fills the `self` slot.
    return self.rematted_call(self.layer, *args, **kwargs)

  def __getattr__(self, name):
    # Guard the attributes this wrapper owns: if normal lookup failed for
    # them (e.g. during copy/deepcopy before __init__ has populated the
    # instance), delegating below would re-enter __getattr__ forever.
    if name in ('layer', 'rematted_call'):
      raise AttributeError(name)
    # Delegate everything else to the wrapped layer so the wrapper is
    # transparent to callers inspecting layer attributes.
    return getattr(self.layer, name)


def remat_model(model: nnx.Module) -> nnx.Module:
  """Takes an NNX Module and returns one with all layers rematerialized.

  Each layer in `model.layers` is replaced by a `RemattedLayer` that calls
  the layer class's unbound `__call__` wrapped in `nnx.remat`, so
  intermediate activations are recomputed during the backward pass instead
  of being stored.

  Args:
    model: A Sequential-style Module exposing a mutable `layers` sequence.

  Returns:
    A clone of `model` (the input is not mutated) whose layers are wrapped
    in `RemattedLayer`.
  """
  # TODO(jeffcarp): Generalize this to work with non-Sequential models.
  new_model = nnx.clone(model)
  for i, layer in enumerate(new_model.layers):
    # Bound-method signature: excludes `self`, so parameters are
    # (input, *rest). Assumes `input` will always be the first parameter
    # of the Module; everything after it is treated as static.
    signature = inspect.signature(layer.__call__)
    # Remat the *unbound* class __call__ so the layer (graph node) is an
    # explicit first argument rather than captured state.
    unbound_call = layer.__class__.__call__
    # Shift static_argnums by 1 because of `self` in the unbound call:
    # position 0 is `self`, position 1 is `input`, positions 2.. are static.
    static_argnums = tuple(
        pos + 1 for pos in range(1, len(signature.parameters))
    )
    rematted_call = nnx.remat(unbound_call, static_argnums=static_argnums)
    new_model.layers[i] = RemattedLayer(layer, rematted_call)
  return new_model


Expand Down
4 changes: 2 additions & 2 deletions parallax/offload_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def loss_fn(model, inputs):
np.testing.assert_array_equal(act_loss, exp_loss)
# Assert all model weights match after gradient update.
np.testing.assert_allclose(
actual_model.layers[0].kernel.value,
reference_model.layers[0].kernel.value,
actual_model.layers[0].kernel[...],
reference_model.layers[0].kernel[...],
atol=1e-3,
rtol=1e-1,
)
Expand Down