From b0758f42894cc0bb6d4febfa56b43eea7d7879e5 Mon Sep 17 00:00:00 2001
From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
Date: Sat, 24 Sep 2022 16:25:00 -0700
Subject: [PATCH] In progress commit on branch delay.

---
 diffrax/adjoint.py   | 16 +++++++++--
 diffrax/integrate.py | 65 ++++++++++++++++++++++++++++++++++++--------
 diffrax/term.py      | 14 ++++++++++
 3 files changed, 82 insertions(+), 13 deletions(-)

diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py
index f4452a81..97c9769b 100644
--- a/diffrax/adjoint.py
+++ b/diffrax/adjoint.py
@@ -23,6 +23,7 @@ def loop(
         solver,
         stepsize_controller,
         discrete_terminating_event,
+        delays,
         saveat,
         t0,
         t1,
@@ -194,6 +195,7 @@ def _loop_backsolve_bwd(
     solver,
     stepsize_controller,
     discrete_terminating_event,
+    delays,
     saveat,
     t0,
     t1,
@@ -232,6 +234,7 @@ def _loop_backsolve_bwd(
         solver=solver,
         stepsize_controller=stepsize_controller,
         discrete_terminating_event=discrete_terminating_event,
+        delays=delays,
         terms=adjoint_terms,
         dt0=None if dt0 is None else -dt0,
         max_steps=max_steps,
@@ -398,12 +401,16 @@ def __init__(self, **kwargs):
             )
         self.kwargs = kwargs
 
-    def loop(self, *, args, terms, saveat, init_state, **kwargs):
+    def loop(self, *, args, terms, saveat, init_state, delays, **kwargs):
         if saveat.steps or saveat.dense:
             raise NotImplementedError(
                 "Cannot use `adjoint=BacksolveAdjoint()` with "
                 "`saveat=Steps(steps=True)` or `saveat=Steps(dense=True)`."
             )
+        if delays is not None:
+            raise NotImplementedError(
+                "Cannot use `delays` with `adjoint=BacksolveAdjoint()`"
+            )
 
         y = init_state.y
         sentinel = object()
@@ -412,7 +419,12 @@ def loop(self, *, args, terms, saveat, init_state, **kwargs):
         )
 
         final_state, aux_stats = _loop_backsolve(
-            (y, args, terms), self=self, saveat=saveat, init_state=init_state, **kwargs
+            (y, args, terms),
+            self=self,
+            saveat=saveat,
+            init_state=init_state,
+            delays=delays,
+            **kwargs,
         )
 
         # We only allow backpropagation through `ys`; in particular not through
diff --git a/diffrax/integrate.py b/diffrax/integrate.py
index 683038a9..968bf3fb 100644
--- a/diffrax/integrate.py
+++ b/diffrax/integrate.py
@@ -1,6 +1,6 @@
 import functools as ft
 import warnings
-from typing import Optional
+from typing import Callable, Optional, Sequence
 
 import equinox as eqx
 import jax
@@ -35,7 +35,7 @@
     ConstantStepSize,
     StepTo,
 )
-from .term import AbstractTerm, WrapTerm
+from .term import AbstractTerm, VectorFieldWrapper, WrapTerm
 
 
 class _State(eqx.Module):
@@ -102,6 +102,7 @@ def loop(
     solver,
     stepsize_controller,
     discrete_terminating_event,
+    delays,
     saveat,
     t0,
     t1,
@@ -130,21 +131,52 @@ def body_fun(state, inplace):
         # step sizes, all that jazz.
         #
 
-        (y, y_error, dense_info, solver_state, solver_result) = solver.step(
-            terms,
-            state.tprev,
-            state.tnext,
-            state.y,
-            args,
-            state.solver_state,
-            state.made_jump,
-        )
+        if delays is None:
+            (y, y_error, dense_info, solver_state, solver_result) = solver.step(
+                terms,
+                state.tprev,
+                state.tnext,
+                state.y,
+                args,
+                state.solver_state,
+                state.made_jump,
+            )
+        else:
+            # TODO: double-check that these are the correct `ts_size` and
+            # `direction`.
+            history = DenseInterpolation(
+                ts=state.dense_ts,
+                ts_size=state.dense_save_index + 1,
+                interpolation_cls=solver.interpolation_cls,
+                infos=state.dense_infos,
+                direction=1,
+            )
+            history_vals = []
+            for delay in delays:
+                delay_val = delay(state.tprev, state.y, args)
+                history_val = history.evaluate(delay_val)
+                history_vals.append(history_val)
+            history_vals = tuple(history_vals)
+
+            is_vf_wrapper = lambda x: isinstance(x, VectorFieldWrapper)
+
+            def _apply_history(x):
+                if is_vf_wrapper(x):
+                    vector_field = jtu.Partial(x.vector_field, history=history_vals)
+                    return VectorFieldWrapper(vector_field)
+                else:
+                    return x
+
+            terms_ = jtu.tree_map(_apply_history, terms, is_leaf=is_vf_wrapper)
+            # TODO: write down implicit problem wrt dense_info, using `terms_`
+            (y, y_error, dense_info, solver_state, solver_result) = terms_  # ...
 
         # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
         # we get a negative value for y, and then get a NaN vector field. (And then
         # everything breaks.) See #143.
         y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)
 
+        # TODO: handle discontinuity detection for delays
         error_order = solver.error_order(terms)
         (
             keep_step,
@@ -510,6 +542,7 @@ def diffeqsolve(
     stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
     adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(),
     discrete_terminating_event: Optional[AbstractDiscreteTerminatingEvent] = None,
+    delays: Optional[Sequence[Callable]] = None,
     max_steps: Optional[int] = 16**3,
     throw: bool = True,
     solver_state: Optional[PyTree] = None,
@@ -563,6 +596,10 @@ def diffeqsolve(
     - `discrete_terminating_event`: A discrete event at which to terminate the solve
         early. See the page on [Events](./events.md) for more information.
 
+    - `delays`: A tuple of functions, which describe the delays used in a delay
+        differential equation. See the page on [Delays](./delays.md) for more
+        information.
+
     - `max_steps`: The maximum number of steps to take before quitting the computation
         unconditionally.
 
@@ -626,6 +663,9 @@ def diffeqsolve(
     # Initial set-up
     #
 
+    if delays is not None and not saveat.dense:
+        raise ValueError("Delay differential equations require saving dense output")
+
     # Error checking
     if dt0 is not None:
         msg = (
@@ -728,6 +768,8 @@ def _promote(yi):
         terms,
         is_leaf=lambda x: isinstance(x, AbstractTerm),
     )
+    if delays is not None:
+        delays = [lambda t, y, args, fn=fn: fn(t, y, args) * direction for fn in delays]
 
     # Stepsize controller gets an opportunity to modify the solver.
     # Note that at this point the solver could be anything so we must check any
@@ -841,6 +883,7 @@ def _promote(yi):
         solver=solver,
         stepsize_controller=stepsize_controller,
         discrete_terminating_event=discrete_terminating_event,
+        delays=delays,
         saveat=saveat,
         t0=t0,
         t1=t1,
diff --git a/diffrax/term.py b/diffrax/term.py
index de557070..4361f5a3 100644
--- a/diffrax/term.py
+++ b/diffrax/term.py
@@ -152,6 +152,13 @@ def is_vf_expensive(
         return False
 
 
+class VectorFieldWrapper(eqx.Module):
+    vector_field: Callable[[Scalar, PyTree, PyTree], PyTree]
+
+    def __call__(self, t, y, args):
+        return self.vector_field(t, y, args)
+
+
 class ODETerm(AbstractTerm):
     r"""A term representing $f(t, y(t), args) \mathrm{d}t$. That is to say, the term
     appearing on the right hand side of an ODE, in which the control is time.
@@ -169,6 +176,9 @@ class ODETerm(AbstractTerm):
     """
     vector_field: Callable[[Scalar, PyTree, PyTree], PyTree]
 
+    def __init__(self, vector_field):
+        self.vector_field = VectorFieldWrapper(vector_field)
+
     def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
         return self.vector_field(t, y, args)
 
@@ -200,6 +210,10 @@ class _ControlTerm(AbstractTerm):
     vector_field: Callable[[Scalar, PyTree, PyTree], PyTree]
     control: AbstractPath
 
+    def __init__(self, vector_field, control):
+        self.vector_field = VectorFieldWrapper(vector_field)
+        self.control = control
+
     def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
         return self.vector_field(t, y, args)
 