-
Notifications
You must be signed in to change notification settings - Fork 140
Description
```
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 613, in main
    out = jax.vmap(train_jit)(rngs)
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 283, in train
    network_params = network.init(_rng, init_hstate, init_x)
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 142, in __call__
    hidden, embedding = ScannedRNN()(hidden, rnn_in)
  File "/home/aaron/miniforge3/envs/jaxmarl310/lib/python3.10/site-packages/flax/core/axes_scan.py", line 159, in scan_fn
    debug_info = jax.api_util.debug_info("flax scan", broadcast_body,
AttributeError: module 'jax.api_util' has no attribute 'debug_info'
```
This is likely a version compatibility issue. With
`pip install jax==0.4.35 jaxlib==0.4.35 flax==0.8.5`
this error disappears, but the following
```
Traceback (most recent call last):
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 613, in main
    out = jax.vmap(train_jit)(rngs)
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 583, in train
    runner_state, metric = jax.lax.scan(
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 546, in _update_step
    update_state, loss_info = jax.lax.scan(
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 525, in _update_epoch
    train_state, total_loss = jax.lax.scan(
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 488, in _update_minbatch
    total_loss, grads = grad_fn(
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 446, in _loss_fn
    _, pi, value = network.apply(
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 142, in __call__
    hidden, embedding = ScannedRNN()(hidden, rnn_in)
  File "/home/aaron/miniforge3/envs/jaxmarl310/lib/python3.10/site-packages/flax/core/axes_scan.py", line 165, in scan_fn
    c, ys = lax.scan(
TypeError: scan body function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:
The input carry component c[1] has type float32[128] but the corresponding output carry component has type float32[1,128], so the shapes do not match.
```
takes its place. This looks like a shape mismatch in the scan carry; it may again be version-related, perhaps because this time the installed versions are too old.