Is it still the case? @junpenglao https://github.com/pymc-devs/pymc/blob/02cbac6592506a547117b2daa3f05847d16fe063/pymc/sampling/jax.py#L454