Hi,
I think this is partially related to #60 as it involves storing some information after an accepted step, but the difference is that I actually need to access the last known information from inside the ODE function.
I have an ODE function that requires the use of a nonlinear solver to compute the derivatives. At the moment, I'm using a fixed initial guess for NewtonNonlinearSolver, but this is inefficient. What I'd like to do is, after an accepted step, store the found root and use it as the initial guess during the next integration step. I was doing this in torchdiffeq successfully, but I can't see an equivalent way in Diffrax.
As a (contrived) example: the code below performs a nonlinear solve inside the ODE function, each time starting from a poor initial guess, so it takes about 10 iterations to converge on every call. If I set init_x = 0.9, which is a much better guess in this case, it takes two or three iterations, so the potential benefit is clear (especially for more expensive nonlinear functions). I wouldn't expect this to cause weird issues with gradients, because backpropagating through NewtonNonlinearSolver shouldn't depend on the initial guess.
Thanks!
from diffrax import diffeqsolve, ODETerm, Dopri5, NewtonNonlinearSolver
import jax.debug
import jax.numpy as jnp

# Fixed (poor) initial guess, reused on every call to the ODE function.
init_x = 0.1
nl_solver = NewtonNonlinearSolver(rtol=1e-3, atol=1e-6)


def f_nonlinear(x, y):
    # Residual whose root in x is needed to evaluate the derivative.
    return jnp.cos(y * x) - x**3


def f(t, y, args):
    # Solve f_nonlinear(x, y) = 0 for x, always starting from init_x.
    sol = nl_solver(f_nonlinear, init_x, y)
    jax.debug.print(
        "t={t}, {n} iterations, x={x}",
        t=t,
        n=sol.num_steps,
        x=sol.root,
    )
    return -sol.root


term = ODETerm(f)
solver = Dopri5()
y0 = 1.0
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
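
To make the warm-start idea concrete, here is a rough pure-Python sketch with no Diffrax or JAX involved. It is only an illustration: the helpers `newton` and `solve_all`, the hand-written derivative `df_nonlinear`, and the list of `y` values (standing in for successive calls to the ODE function) are all made up for this example. The mutable `guess` variable is the kind of side effect I don't see how to carry over into a Diffrax vector field.

```python
import math


def f_nonlinear(x, y):
    # Same residual as in the example above, using math instead of jax.numpy.
    return math.cos(y * x) - x**3


def df_nonlinear(x, y):
    # Derivative of f_nonlinear with respect to x (written by hand for this sketch).
    return -y * math.sin(y * x) - 3 * x**2


def newton(x0, y, rtol=1e-3, atol=1e-6, max_steps=50):
    # Plain Newton iteration; returns (root, number of steps taken).
    x = x0
    for n in range(1, max_steps + 1):
        step = f_nonlinear(x, y) / df_nonlinear(x, y)
        x = x - step
        if abs(step) <= atol + rtol * abs(x):
            return x, n
    raise RuntimeError("Newton iteration did not converge")


def solve_all(ys, init_x, warm_start):
    # Solve once per y; optionally reuse the previous root as the next guess.
    guess = init_x
    roots, total_steps = [], 0
    for y in ys:
        root, n = newton(guess, y)
        roots.append(root)
        total_steps += n
        if warm_start:
            guess = root  # hypothetical "store the found root" step
    return roots, total_steps


# Slowly varying y, as one might see across consecutive integration steps.
ys = [1.0, 0.99, 0.98, 0.97, 0.96]
cold_roots, cold_steps = solve_all(ys, 0.1, warm_start=False)
warm_roots, warm_steps = solve_all(ys, 0.1, warm_start=True)
print("cold start, total Newton steps:", cold_steps)
print("warm start, total Newton steps:", warm_steps)
```

Both variants should land on the same roots; the warm-started one just gets there in fewer Newton steps after the first solve, which is the saving I'm after.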