From 1a77509020f65accae9ac0ad068761e244b8a1bd Mon Sep 17 00:00:00 2001 From: Neil Kichler Date: Fri, 28 Jun 2024 13:31:53 +0200 Subject: [PATCH] Support Jax version 4.25.0 and higher. --- autobound/jax/jaxpr_editor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/autobound/jax/jaxpr_editor.py b/autobound/jax/jaxpr_editor.py index d5bca20..5c2306c 100644 --- a/autobound/jax/jaxpr_editor.py +++ b/autobound/jax/jaxpr_editor.py @@ -169,7 +169,11 @@ def vertex_to_var_or_literal(vertex): if vertex[0]: _, count, suffix, aval = vertex if count not in count_to_var: - count_to_var[count] = jax.core.Var(count, suffix, aval) + if jax.__version__ >= '0.4.25': + # count argument was removed in jax 0.4.25: https://github.com/google/jax/pull/10573 + count_to_var[count] = jax.core.Var(suffix, aval) + else: + count_to_var[count] = jax.core.Var(count, suffix, aval) return count_to_var[count] else: _, val, aval = vertex