diff --git a/tests/DETEST/run.py b/tests/DETEST/run.py index 8f3e9ca9c..0842f53aa 100644 --- a/tests/DETEST/run.py +++ b/tests/DETEST/run.py @@ -22,8 +22,8 @@ def __call__(self, t, y): def main(): sol = dict() - for method in ['dopri5', 'adams']: - for tol in [1e-3, 1e-6, 1e-9]: + for method in ['dopri5', 'classic_dopri5']: + for tol in [1e-3, 1e-4, 1e-5, 1e-6]: print('======= {} | tol={:e} ======='.format(method, tol)) nfes = [] times = [] diff --git a/torchdiffeq/_impl/dopri5.py b/torchdiffeq/_impl/dopri5.py index 1a925ef1d..bcc610632 100644 --- a/torchdiffeq/_impl/dopri5.py +++ b/torchdiffeq/_impl/dopri5.py @@ -1,18 +1,39 @@ import torch from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver +alpha=torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=torch.float64) +beta=[ + torch.tensor([1 / 5], dtype=torch.float64), + torch.tensor([3 / 40, 9 / 40], dtype=torch.float64), + torch.tensor([44 / 45, -56 / 15, 32 / 9], dtype=torch.float64), + torch.tensor([19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=torch.float64), + torch.tensor([9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], dtype=torch.float64), + torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], dtype=torch.float64), +] +c_sol=torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], dtype=torch.float64) + + +_CLASSIC_DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau( + alpha=alpha, beta=beta, c_sol=c_sol, + c_error=torch.tensor([ + 35 / 384 - 5179 / 57600, + 0, + 500 / 1113 - 7571 / 16695, + 125 / 192 - 393 / 640, + -2187 / 6784 - -92097 / 339200, + 11 / 84 - 187 / 2100, + -1 / 40, + ], dtype=torch.float64), +) + +DPS_C_MID = torch.tensor([ + 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2, + 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 +], dtype=torch.float64) + _DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau( - alpha=torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=torch.float64), - beta=[ - torch.tensor([1 / 5], dtype=torch.float64), - torch.tensor([3 / 40, 9 / 40], dtype=torch.float64), - torch.tensor([44 / 45, -56 / 15, 32 / 9], dtype=torch.float64), - torch.tensor([19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=torch.float64), - torch.tensor([9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], dtype=torch.float64), - torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], dtype=torch.float64), - ], - c_sol=torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], dtype=torch.float64), + alpha=alpha, beta=beta, c_sol=c_sol, c_error=torch.tensor([ 35 / 384 - 1951 / 21600, 0, @@ -24,13 +45,16 @@ ], dtype=torch.float64), ) -DPS_C_MID = torch.tensor([ - 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2, - 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 -], dtype=torch.float64) - class Dopri5Solver(RKAdaptiveStepsizeODESolver): order = 5 tableau = _DORMAND_PRINCE_SHAMPINE_TABLEAU mid = DPS_C_MID + + +class ClassicDopri5Solver(RKAdaptiveStepsizeODESolver): + order = 5 + tableau = _CLASSIC_DORMAND_PRINCE_SHAMPINE_TABLEAU + mid = DPS_C_MID + + diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index a174219ad..2b47eaf48 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -1,6 +1,6 @@ import torch from torch.autograd.functional import vjp -from .dopri5 import Dopri5Solver +from .dopri5 import ClassicDopri5Solver, Dopri5Solver from .bosh3 import Bosh3Solver from .adaptive_heun import AdaptiveHeunSolver from .fehlberg2 import Fehlberg2 @@ -13,6 +13,7 @@ SOLVERS = { 'dopri8': Dopri8Solver, 'dopri5': Dopri5Solver, + 'classic_dopri5': ClassicDopri5Solver, 'bosh3': Bosh3Solver, 'fehlberg2': Fehlberg2, 'adaptive_heun': AdaptiveHeunSolver,