diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 0137d757..5268212b 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -7,7 +7,7 @@ jobs: run-tests: strategy: matrix: - python-version: [ "3.10", "3.12" ] + python-version: [ "3.11", "3.13" ] os: [ ubuntu-latest ] fail-fast: false runs-on: ${{ matrix.os }} diff --git a/benchmarks/against_scan.py b/benchmarks/against_scan.py index 6c67655f..a7771c4e 100644 --- a/benchmarks/against_scan.py +++ b/benchmarks/against_scan.py @@ -36,7 +36,7 @@ def speedtest(fn, name): # INTEGRATE WITH scan -@jax.checkpoint # pyright: ignore +@jax.checkpoint def body(carry, t): u, v, dt = carry u = u + du(t, v, None) * dt diff --git a/diffrax/_step_size_controller/clip.py b/diffrax/_step_size_controller/clip.py index 0a642d6e..cac6ec1a 100644 --- a/diffrax/_step_size_controller/clip.py +++ b/diffrax/_step_size_controller/clip.py @@ -356,7 +356,7 @@ def callback(_keep_step, _t1): step_info = None else: step_index, step_ts = controller_state.step_info - # We actaully bump `next_t0` past any `step_ts` whilst checking where to + # We actually bump `next_t0` past any `step_ts` whilst checking where to # clip `next_t1`. This is in case we have a set up like the following: # ```python # ClipStepSizeController( @@ -376,6 +376,24 @@ def callback(_keep_step, _t1): else: jump_index, jump_ts = controller_state.jump_info next_t0, made_jump2 = _bump_next_t0(next_t0, jump_ts) + # This next line is to fix + # https://github.com/patrick-kidger/diffrax/issues/713 + # TODO: should we add this to the `step_ts` branch as well? + # + # What's going on here is that we may have + # the `diffeqsolve(t0=...)` be prevbefore a jump time (for example due to a + # previous diffeqsolve targeting that time), in which case during `.init` + # we will obtain `t0 = t1 = prevbefore(jump_time)`. + # The `_bump_next_t0` will then move `next_t0` to after the `jump_time`... + # whilst leaving `next_t1` unchanged! We actually end up `next_t1 < next_t0` + # which is very not okay. + # + # The fix is to ensure that `next_t1` is itself bumped to at least this + # value. As a final detail, we need to make it `nextafter` so that we don't + # have a zero-length interval – in this case an underlying PID controller + # would just never change the interval size at all, since it acts + # multiplicatively. (And even just 1 ULP is enough to unstick it.) + next_t1 = jnp.maximum(eqxi.nextafter(next_t0), next_t1) made_jump = made_jump | made_jump2 jump_index = _find_idx_with_hint(next_t0, jump_ts, jump_index) next_t1 = _clip_t(next_t1, jump_index, jump_ts, True) diff --git a/pyproject.toml b/pyproject.toml index c7ec6d27..459deb3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ keywords = ["jax", "dynamical-systems", "differential-equations", "deep-learning license = {file = "LICENSE"} name = "diffrax" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" urls = {repository = "https://github.com/patrick-kidger/diffrax"} version = "0.7.0" diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py index 21c24e4a..d785d16e 100644 --- a/test/test_adaptive_stepsize_controller.py +++ b/test/test_adaptive_stepsize_controller.py @@ -7,6 +7,7 @@ import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu +import optimistix as optx import pytest from diffrax._step_size_controller.clip import _find_idx_with_hint from jaxtyping import Array @@ -361,3 +362,29 @@ def test_jump_at_t1_with_large_t1_in_float32(): saveat=saveat, ) assert sol.ts == jnp.array([t1]) + + +# https://github.com/patrick-kidger/diffrax/issues/713 +def test_t0_at_jump_time(): + jump_time = 0.98 + controller = diffrax.PIDController(rtol=1e-6, atol=1e-6) + controller = diffrax.ClipStepSizeController(controller, jump_ts=[jump_time]) + sol = diffrax.diffeqsolve( + diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)), + diffrax.Heun(), + t0=eqxi.prevbefore(jnp.asarray(jump_time)), + t1=1.2, + dt0=None, + y0=jnp.array([0, 0, 0, 0.0]), + stepsize_controller=controller, + event=diffrax.Event( + cond_fn=lambda t, y, args, **kw: jump_time - t, + root_finder=optx.Newton(atol=1e-4, rtol=1e-4), + direction=True, + ), + max_steps=100, + ) + # And in particular not an event. + # What used to happen was something very weird where we'd oscillate across the + # jump time. + assert sol.result == diffrax.RESULTS.successful