@poonai poonai commented Nov 20, 2025

Hello,

I'm new to JAX and numerical computing, and I'm willing to invest the time to learn by implementing numerical methods. After this, I plan to add support for DAEs and additional Rosenbrock methods. I would appreciate your guidance on getting this PR merged.

Thanks

Owner

@patrick-kidger patrick-kidger left a comment

Hey there! This is super awesome. I've been meaning to add Rosenbrock methods for a while, so I'd love to get this in. I have some fairly nitty comments but the structure of this PR already looks excellent.

control = terms.contr(t0, t1)

# common L.H.S
A = (lx.MatrixLinearOperator(eye) / (control * self.tableau.γ[0])) - (
Owner

You can use lx.IdentityLinearOperator here.

class Ros3p(AbstractAdaptiveSolver):
r"""Ros3p method.
3rd order Rosenbrock method for solving stiff equation. Uses a 1st order local linear
Owner

I suspect it may make more sense to use third-order Hermite interpolation by default. (This is the usual standard interpolation method most of the time.)

)
)

# stage 1
Owner

So this might be a bit tricky, but it's quite important that the stages be wrapped up into a lax.scan. This is so that we don't compile the user-supplied vector field multiple times, as that hugely increases compile time for nontrivial vector fields.
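
A toy illustration of the compile-time point, not taken from the PR: the vector field and the stage coefficients below are made-up stand-ins. A Python loop over stages traces the vector field once per stage, whereas `lax.scan` traces its body far fewer times (typically once), so an expensive user-supplied vector field is not inlined repeatedly into the compiled program.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Counters are bumped at trace time only, so they count how often the
# vector field is traced into the program, not how often it runs.
trace_count = {"loop": 0, "scan": 0}


def make_vector_field(key):
    def vector_field(y):
        trace_count[key] += 1
        return -y + jnp.sin(y)  # stand-in for a user-supplied vector field

    return vector_field


num_stages = 4
coeffs = jnp.array([0.1, 0.2, 0.3, 0.4])  # hypothetical per-stage coefficients


@jax.jit
def stages_unrolled(y):
    f = make_vector_field("loop")
    for i in range(num_stages):  # unrolled: one trace of f per stage
        y = y + coeffs[i] * f(y)
    return y


@jax.jit
def stages_scanned(y):
    f = make_vector_field("scan")

    def body(carry, c):
        return carry + c * f(carry), None

    y, _ = lax.scan(body, y, coeffs)  # the body is traced, then reused per stage
    return y


y0 = jnp.array([1.0, 2.0])
out_loop = stages_unrolled(y0)
out_scan = stages_scanned(y0)
print(trace_count)
```

Both functions compute the same toy update; only the number of times the vector field appears in the traced program differs.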

solver_state: _SolverState,
made_jump: BoolScalarLike,
) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]:
del made_jump, solver_state
Owner

I've not checked, is this method definitely not FSAL?

f(1.0)


def test_ros3p():
Owner

I think we have a few tests that run pretty much every solver; it would be good to add Ros3p to these as well.



_tableau = _RosenbrockTableau(
m_sol=jnp.array([2.0, 0.5773502691896258, 0.4226497308103742]),
Owner

I think these should be regular numpy arrays, to avoid initialising the JAX backend (which happens the first time an array is created) whilst Diffrax is being imported.


def step(
self,
terms: AbstractTerm[ArrayLike, ArrayLike],
Author

Additional work is required to make this work with MultiTerm. Reading about other Rosenbrock methods will help me design the proper PyTree abstraction, so for now I've limited the term structure to a simple ODE.

I can implement this now, or include it in the next PR along with the next Rosenbrock method.

Owner

Actually, I think it should already work as-is. A MultiTerm is an AbstractTerm already.

Try diffeqsolve(MultiTerm(ODETerm(...), ODETerm(...)), Ros3p(), ...) and see what happens?

Owner

Ah, what I can see here though is that the evolving state and the control are being restricted to specifically arrays, and not the general pytree-of-arrays that we aim to support.

Take a look at equinox.internal.ω for a helper that we use ubiquitously to make it easy to work with pytree-valued state.

Alternatively, it would be fairly straightforward to use jax.flatten_util.ravel_pytree before and after the code you already have.

(Could you add a test that uses pytree-valued state to be sure that whichever choice you make works?)
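
The `ravel_pytree` route mentioned above could look roughly like the sketch below. `array_step` is a made-up stand-in for step logic that only understands flat arrays; it is not the PR's code.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def array_step(y, dt):
    # Stand-in for existing logic that only handles a flat 1D array.
    return y + dt * (-y)


def pytree_step(y_tree, dt):
    # Flatten the pytree-valued state, run the array-only code, then restore
    # the original structure with the unravel function.
    flat_y, unravel = ravel_pytree(y_tree)
    flat_y1 = array_step(flat_y, dt)
    return unravel(flat_y1)


y0 = {"position": jnp.array([1.0, 2.0]), "velocity": (jnp.array(3.0),)}
y1 = pytree_step(y0, 0.1)
```

A test along these lines, comparing the tree structure of the input and output states, is one way to check that pytree-valued state survives a step.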

# and Heun if the solver is Stratonovich.
@pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders())
@pytest.mark.parametrize("dtype", (jnp.float64,))
@pytest.mark.skip(reason="This test is failing in the main the branch")
Author

@poonai poonai Nov 28, 2025

This test has been fixed in the dev branch.

Should I raise the PR against the dev branch?

I changed the base branch to dev.

Author

poonai commented Nov 28, 2025

@patrick-kidger I've addressed your comments. Please review them when you get a chance. In the meantime, I'll start reading about other methods.

Thanks

@poonai poonai changed the base branch from main to dev November 28, 2025 11:24
Owner

@patrick-kidger patrick-kidger left a comment

This looks really good to me! I have various small comments – and one complicated compiler optimization one – but the overall structure looks very reasonable to me. :)

Comment on lines 1038 to 1040
if isinstance(solver, Ros3p):
# TODO: add complex dtype support to ros3p.
raise ValueError("Ros3p does not support complex dtypes.")
Owner

I think this check could probably be moved to Ros3p.step? I try to keep the core diffeqsolve code solver-agnostic where possible.


Comment on lines 192 to 197
u = u.at[stage].set(stage_u)
return u, vf

u, stage_vf = lax.scan(
f=body, init=u, xs=jnp.arange(start_stage, self.tableau.num_stages)
)
Owner

So this is some super inside baseball here, but: performing a scan, whose body function both reads from and writes to the same buffer, will be inefficient when reverse-mode differentiated.

This particular combination of things is something that XLA fails to optimize well.

For this reason we have equinox.internal.scan instead, which can be used like so:

u, stage_vf = equinox.internal.scan(
    body, u, jnp.arange(self.tableau.num_stages),
    buffers=lambda x: x,
    kind="checkpointed",
    checkpoints=self.tableau.num_stages
)

The interesting part here is the buffers argument, which is used to specify a path to the particular arrays that will be the subject of this inplace-updating behaviour. The body function will then be called with an array wrapper that does some smart things to avoid the XLA issue.

(You've bumped straight into one of the most difficult JAX issues to tackle, I'm afraid!)


Other than that, note that this will error if start_stage is a traced value, i.e. if made_jump is not False, as then jnp.arange(start_stage, self.tableau.num_stages) will not be an array of known size.

I think the easiest way to tackle this would be to remove the FSAL logic and just evaluate all stages on each step. (It is possible to still do FSAL, it's just very complicated – c.f. AbstractRungeKutta.step – let's not do that now 😁) To be sure we fix this, can you add a test that includes jumps?
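
A small sketch of the traced-`start_stage` problem and of the "evaluate all stages" fix, using plain `jax.lax.scan` and made-up stage values rather than the PR's code. Under `jit`, a traced start makes the length of `jnp.arange` unknown at trace time, which JAX rejects. Iterating over a statically sized range avoids the issue.

```python
import jax
import jax.numpy as jnp

NUM_STAGES = 4


@jax.jit
def stages_from(start_stage):
    # start_stage is traced here, so the array length is not known at trace time.
    return jnp.arange(start_stage, NUM_STAGES)


try:
    stages_from(1)
    traced_start_failed = False
except TypeError:  # jax.errors.ConcretizationTypeError subclasses TypeError
    traced_start_failed = True


@jax.jit
def sum_all_stages(stage_values):
    # Static bounds: always visit every stage, regardless of any traced flag.
    def body(total, stage):
        return total + stage_values[stage], None

    total, _ = jax.lax.scan(body, jnp.zeros(()), jnp.arange(NUM_STAGES))
    return total


total = sum_all_stages(jnp.array([1.0, 2.0, 3.0, 4.0]))
```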


b = vf + vf_increment + ((control * γ[stage]) * time_derivative)
# solving Ax=b
stage_u = lx.linear_solve(A, b).value
Owner

Can you add a self.linear_solver argument that specifies the linear solver to use? (C.f. where this appears elsewhere in Diffrax, e.g. as an attribute of VeryChord)

@poonai poonai changed the title add support for ros3p rosenbrock method add support rosenbrock method Dec 7, 2025
@poonai poonai changed the title add support rosenbrock method add support for rosenbrock method Dec 7, 2025
@poonai poonai requested a review from patrick-kidger December 7, 2025 08:57
Author

@poonai poonai left a comment

I've added 4 more Rosenbrock methods. There isn't actually much code change; it's mostly just more tableaus.

I'm also working on the promised changes. I'll try to add those before you take a look. You can also merge this if you find this version good enough.

# TODO: add complex dtype support.
raise ValueError("rosenbrock does not support complex dtypes.")

if isinstance(terms, ODETerm):
Author

The current Rosenbrock method is not intended to solve CDEs or SDEs. It is specifically aimed at stiff ODEs and index-1 DAEs. Because of this, I have explicitly rejected other types of terms.

To support MultiTerm, the vector fields and Jacobians need to be added together so that Rosenbrock still sees a single effective system. However, this currently requires changes in MultiTerm, because we cannot assume whether term.vf returns a single vector field or a tuple of vector fields.

I have now figured out the changes needed to support DAEs, and it turns out to be simpler than I initially expected. I plan to introduce a dedicated DAETerm to make this work.

I will implement both DAE and MultiTerm support together as part of the next PR.

Owner

> To support MultiTerm, the vector fields and Jacobians need to be added together so that Rosenbrock still sees a single effective system.

I think this should already happen in the methods of MultiTerm!

What I'd suggest: write the code to accept specifically an ODETerm with a pytree-of-arrays as the state. And then, I suspect you'll find that the code you've written will very nearly 'just work' (with no changes) if you pass it any kind of AbstractTerm, including MultiTerm.

> I plan to introduce a dedicated DAETerm to make this work.

I'd definitely ask that you hold off on this! I do have plans to support DAEs but I think it will be possible in a far more general way than this.

solver_state: _SolverState,
made_jump: BoolScalarLike,
) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]:
del solver_state, made_jump
Author

I've removed FSAL 😓

@poonai poonai changed the title add support for rosenbrock method WIP: add support for rosenbrock method Dec 12, 2025
Signed-off-by: balaji <rbalajis25@gmail.com>
Signed-off-by: balaji <rbalajis25@gmail.com>
@poonai poonai changed the title WIP: add support for rosenbrock method add support for rosenbrock method Dec 23, 2025
order = scipy.stats.linregress(exponents, errors).slope # pyright: ignore
# We accept quite a wide range. Improving this test would be nice.
assert -0.9 < order - solver.order(term) < 0.9
assert -0.9 < order - solver.order(term)
Author

rodas5p is 6th order for this particular problem.
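
For readers unfamiliar with this kind of test, here is a hedged sketch of how an empirical convergence order is estimated, and why a one-sided bound tolerates a solver that converges faster than its nominal order. It uses explicit Euler on y' = -y as a stand-in problem, and `np.polyfit` in place of the `scipy.stats.linregress` call in the diff. Neither choice is from the PR.

```python
import numpy as np


def euler_error(num_steps):
    # Integrate y' = -y on [0, 1] with explicit Euler; compare with exp(-1).
    h = 1.0 / num_steps
    y = 1.0
    for _ in range(num_steps):
        y = y + h * (-y)
    return abs(y - np.exp(-1.0))


# log2 of the step sizes, and log2 of the resulting errors.
step_exponents = np.array([-3.0, -4.0, -5.0, -6.0, -7.0])
error_exponents = np.array(
    [np.log2(euler_error(int(2.0 ** (-e)))) for e in step_exponents]
)

# The slope of log(error) against log(step size) is the observed order.
order = np.polyfit(step_exponents, error_exponents, 1)[0]

expected_order = 1  # nominal order of explicit Euler
# One-sided check: fails if convergence is too slow, passes if it is faster.
assert -0.9 < order - expected_order
```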

Author

poonai commented Dec 24, 2025

Hi,

I initially used Julia’s implementation as a reference and copied their coefficients. However, it turned out that those were derived coefficients and did not follow the original paper directly. As a result, I took a significant detour to extract the coefficients myself. Later, @gstein3m generously shared the coefficients from the MATLAB code with me.

At first, I considered removing all methods except ros3p. However, I decided that adding a higher-order solver would be more beneficial to the community. That’s why it took quite some time to address your comments.

return jtu.tree_map(_eval, self.coeffs)


class RodasInterpolation(AbstractLocalInterpolation):
Author

Added a new interpolation method for the Rodas class of methods.

dt = jtu.tree_leaves(control)[0]
eye = jnp.eye(len(time_derivative))
if self.rodas:
A = lx.MatrixLinearOperator(eye - dt * γ[0] * jacobian)
Author

I could not make lx.IdentityLinearOperator work with lx.MatrixLinearOperator:

a = jnp.eye(3)
print(a.shape)  # shape (3,3)

RESULT = lx.IdentityLinearOperator(a.shape)  -lx.MatrixLinearOperator(a)  # fails
print(lx.IdentityLinearOperator((1,) * 3).as_matrix().shape)  # shape (3,3) 
RESULT = lx.IdentityLinearOperator((1,) * 3) -lx.MatrixLinearOperator(a)  # fails
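
For context, the dense form that the diff falls back to, `eye - dt * γ * jacobian`, can be sketched in plain JAX as below. This is only an illustration of the linear system being assembled: it shows linearly implicit Euler (a one-stage Rosenbrock method with γ = 1) on a made-up stiff linear problem, with `jnp.linalg.solve` standing in for `lx.linear_solve`.

```python
import jax
import jax.numpy as jnp


def vector_field(y):
    # Stiff linear test problem (hypothetical): two widely separated decay rates.
    return jnp.array([-1000.0, -1.0]) * y


def linearly_implicit_euler_step(y0, dt, gamma=1.0):
    jacobian = jax.jacfwd(vector_field)(y0)  # dense Jacobian at y0
    eye = jnp.eye(y0.shape[0])
    A = eye - dt * gamma * jacobian  # common left-hand side
    b = dt * vector_field(y0)  # right-hand side for the single stage
    increment = jnp.linalg.solve(A, b)  # solve A x = b
    return y0 + increment


y0 = jnp.array([1.0, 1.0])
y1 = linearly_implicit_euler_step(y0, dt=0.1)
```

Even with a step size far larger than the fast time scale, the update stays bounded, which is the reason for solving a linear system with the Jacobian rather than stepping explicitly.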


jacobian = jax.jacfwd(lambda y: terms.vf_prod(t0, y, args, identity))(y0)
jacobian, _ = fu.ravel_pytree(jacobian)
jacobian = jnp.reshape(jacobian, time_derivative.shape * 2)
Author

I had to materialize the Jacobian, as it is also used on the right-hand side of every stage equation.
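
One possible way to materialize a dense Jacobian for pytree-valued state is sketched below. It differs from the diff (which ravels the Jacobian pytree and reshapes it): here the state is raveled first, so `jax.jacfwd` directly returns a 2D matrix. The vector field is a made-up example.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def vector_field(t, y, args):
    # Hypothetical pytree-valued vector field with a known Jacobian.
    return {"a": -2.0 * y["a"] + y["b"], "b": -3.0 * y["b"]}


def dense_jacobian(t, y, args):
    flat_y, unravel = ravel_pytree(y)

    def flat_vf(flat):
        # Evaluate the vector field on the restored pytree, then flatten the result.
        out, _ = ravel_pytree(vector_field(t, unravel(flat), args))
        return out

    return jax.jacfwd(flat_vf)(flat_y)  # shape (n, n) for n state entries


y0 = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
J = dense_jacobian(0.0, y0, None)
```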

control = terms.contr(t0, t1)
identity = jtu.tree_map(lambda leaf: jnp.ones_like(leaf), control)

time_derivative = jax.jacfwd(lambda t: terms.vf_prod(t, y0, args, identity))(t0)
Author

I used vf_prod as a workaround to handle MultiTerm, since it sums each term's vector field.
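
The time-derivative line in the diff differentiates the vector field with respect to `t` at fixed `y0`. A minimal sketch of that idea on a made-up non-autonomous vector field, using a plain function instead of `terms.vf_prod`:

```python
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    # Hypothetical non-autonomous vector field: f(t, y) = sin(t) * y.
    return jnp.sin(t) * y


t0 = jnp.asarray(0.5)
y0 = jnp.array([1.0, 2.0])

# Partial derivative of f with respect to t, holding y fixed at y0.
time_derivative = jax.jacfwd(lambda t: vector_field(t, y0, None))(t0)
```

For this example the result should match the analytic derivative cos(t0) * y0, which makes it a convenient sanity check.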
