diff --git a/pysages/grids.py b/pysages/grids.py index 6c2039ee..8be7987e 100644 --- a/pysages/grids.py +++ b/pysages/grids.py @@ -5,9 +5,9 @@ from jax import jit from jax import numpy as np -from plum import Union, parametric +from plum import parametric -from pysages.typing import JaxArray +from pysages.typing import JaxArray, Union from pysages.utils import dispatch, is_generic_subclass, prod