Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions diffrax/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def loop(
solver,
stepsize_controller,
discrete_terminating_event,
delays,
saveat,
t0,
t1,
Expand Down Expand Up @@ -194,6 +195,7 @@ def _loop_backsolve_bwd(
solver,
stepsize_controller,
discrete_terminating_event,
delays,
saveat,
t0,
t1,
Expand Down Expand Up @@ -232,6 +234,7 @@ def _loop_backsolve_bwd(
solver=solver,
stepsize_controller=stepsize_controller,
discrete_terminating_event=discrete_terminating_event,
delays=delays,
terms=adjoint_terms,
dt0=None if dt0 is None else -dt0,
max_steps=max_steps,
Expand Down Expand Up @@ -398,12 +401,16 @@ def __init__(self, **kwargs):
)
self.kwargs = kwargs

def loop(self, *, args, terms, saveat, init_state, **kwargs):
def loop(self, *, args, terms, saveat, init_state, delays, **kwargs):
if saveat.steps or saveat.dense:
raise NotImplementedError(
"Cannot use `adjoint=BacksolveAdjoint()` with "
"`saveat=Steps(steps=True)` or `saveat=Steps(dense=True)`."
)
if delays is not None:
raise NotImplementedError(
"Cannot use `delays` with `adjoint=BacksolveAdjoint()`"
)

y = init_state.y
sentinel = object()
Expand All @@ -412,7 +419,12 @@ def loop(self, *, args, terms, saveat, init_state, **kwargs):
)

final_state, aux_stats = _loop_backsolve(
(y, args, terms), self=self, saveat=saveat, init_state=init_state, **kwargs
(y, args, terms),
self=self,
saveat=saveat,
init_state=init_state,
delays=delays,
**kwargs,
)

# We only allow backpropagation through `ys`; in particular not through
Expand Down
65 changes: 54 additions & 11 deletions diffrax/integrate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools as ft
import warnings
from typing import Optional
from typing import Callable, Optional, Sequence

import equinox as eqx
import jax
Expand Down Expand Up @@ -35,7 +35,7 @@
ConstantStepSize,
StepTo,
)
from .term import AbstractTerm, WrapTerm
from .term import AbstractTerm, VectorFieldWrapper, WrapTerm


class _State(eqx.Module):
Expand Down Expand Up @@ -102,6 +102,7 @@ def loop(
solver,
stepsize_controller,
discrete_terminating_event,
delays,
saveat,
t0,
t1,
Expand Down Expand Up @@ -130,21 +131,52 @@ def body_fun(state, inplace):
# step sizes, all that jazz.
#

(y, y_error, dense_info, solver_state, solver_result) = solver.step(
terms,
state.tprev,
state.tnext,
state.y,
args,
state.solver_state,
state.made_jump,
)
if delays is None:
(y, y_error, dense_info, solver_state, solver_result) = solver.step(
terms,
state.tprev,
state.tnext,
state.y,
args,
state.solver_state,
state.made_jump,
)
else:
# TODO: double-check that these are the correct `ts_size` and
# `direction`.
history = DenseInterpolation(
ts=state.dense_ts,
ts_size=state.dense_save_index + 1,
interpolation_cls=solver.interpolation_cls,
infos=state.dense_infos,
direction=1,
)
# Look up the past solution once per delay. Each delay is called as
# `delay(state.tprev, state.y, args)` and its result is the time passed to
# `history.evaluate`. The values are collected into a tuple so they can be
# handed to the vector field as a single `history` argument.
# (The original loop called `.append` on `history_val` instead of
# `history_vals`, so nothing was ever collected.)
history_vals = tuple(
    history.evaluate(delay(state.tprev, state.y, args)) for delay in delays
)

is_vf_wrapper = lambda x: isinstance(x, VectorFieldWrapper)

def _apply_history(x):
if is_vf_wrapper(x):
vector_field = jtu.Partial(x.vector_field, history=history_vals)
return VectorFieldWrapper(vector_field)
else:
return x

terms_ = jtu.tree_map(_apply_history, terms, is_leaf=is_vf_wrapper)
# TODO: write down implicit problem wrt dense_info, using `terms_`
(y, y_error, dense_info, solver_state, solver_result) = terms_ # ...

# e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
# we get a negative value for y, and then get a NaN vector field. (And then
# everything breaks.) See #143.
y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)

# TODO: handle discontinuity detection for delays
error_order = solver.error_order(terms)
(
keep_step,
Expand Down Expand Up @@ -510,6 +542,7 @@ def diffeqsolve(
stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(),
discrete_terminating_event: Optional[AbstractDiscreteTerminatingEvent] = None,
delays: Optional[Sequence[Callable]] = None,
max_steps: Optional[int] = 16**3,
throw: bool = True,
solver_state: Optional[PyTree] = None,
Expand Down Expand Up @@ -563,6 +596,10 @@ def diffeqsolve(
- `discrete_terminating_event`: A discrete event at which to terminate the solve
early. See the page on [Events](./events.md) for more information.

- `delays`: A sequence of functions, each called as `delay(t, y, args)`, which
describe the delays used in a delay differential equation. Requires dense output
to be saved. See the page on [Delays](./delays.md) for more information.

- `max_steps`: The maximum number of steps to take before quitting the computation
unconditionally.

Expand Down Expand Up @@ -626,6 +663,9 @@ def diffeqsolve(
# Initial set-up
#

if delays is not None and not saveat.dense:
raise ValueError("Delay differential equations require saving dense output")

# Error checking
if dt0 is not None:
msg = (
Expand Down Expand Up @@ -728,6 +768,8 @@ def _promote(yi):
terms,
is_leaf=lambda x: isinstance(x, AbstractTerm),
)
if delays is not None:
delays = [lambda t, y, args, fn=fn: fn(t, y, args) * direction for fn in delays]

# Stepsize controller gets an opportunity to modify the solver.
# Note that at this point the solver could be anything so we must check any
Expand Down Expand Up @@ -841,6 +883,7 @@ def _promote(yi):
solver=solver,
stepsize_controller=stepsize_controller,
discrete_terminating_event=discrete_terminating_event,
delays=delays,
saveat=saveat,
t0=t0,
t1=t1,
Expand Down
14 changes: 14 additions & 0 deletions diffrax/term.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ def is_vf_expensive(
return False


class VectorFieldWrapper(eqx.Module):
    """Callable module that holds a vector field and forwards calls to it.

    Calling an instance with `(t, y, args)` returns whatever the wrapped
    `vector_field` returns for those same arguments.
    """

    # The wrapped callable; invoked as `vector_field(t, y, args)`.
    vector_field: Callable[[Scalar, PyTree, PyTree], PyTree]

    def __call__(self, t, y, args):
        """Evaluate the wrapped vector field at `(t, y, args)`."""
        wrapped = self.vector_field
        return wrapped(t, y, args)


class ODETerm(AbstractTerm):
r"""A term representing $f(t, y(t), args) \mathrm{d}t$. That is to say, the term
appearing on the right hand side of an ODE, in which the control is time.
Expand All @@ -169,6 +176,9 @@ class ODETerm(AbstractTerm):
"""
vector_field: Callable[[Scalar, PyTree, PyTree], PyTree]

def __init__(self, vector_field):
    """Store `vector_field` wrapped in a `VectorFieldWrapper`."""
    wrapped = VectorFieldWrapper(vector_field)
    self.vector_field = wrapped

def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
    """Return the stored vector field evaluated at `(t, y, args)`."""
    field = self.vector_field
    return field(t, y, args)

Expand Down Expand Up @@ -200,6 +210,10 @@ class _ControlTerm(AbstractTerm):
vector_field: Callable[[Scalar, PyTree, PyTree], PyTree]
control: AbstractPath

def __init__(self, vector_field, control):
    """Store `control` as given and `vector_field` wrapped in a
    `VectorFieldWrapper`.
    """
    wrapped = VectorFieldWrapper(vector_field)
    self.vector_field = wrapped
    self.control = control

def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
    """Return the stored vector field evaluated at `(t, y, args)`."""
    field = self.vector_field
    return field(t, y, args)

Expand Down