Overcookedv2 baseline does not run with the dependencies defined in the pyproject.toml #164

@zitr0y

Description

'''
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 613, in main
    out = jax.vmap(train_jit)(rngs)
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 283, in train
    network_params = network.init(_rng, init_hstate, init_x)
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 142, in __call__
    hidden, embedding = ScannedRNN()(hidden, rnn_in)
  File "/home/aaron/miniforge3/envs/jaxmarl310/lib/python3.10/site-packages/flax/core/axes_scan.py", line 159, in scan_fn
    debug_info = jax.api_util.debug_info("flax scan", broadcast_body,
AttributeError: module 'jax.api_util' has no attribute 'debug_info'
'''
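
The AttributeError suggests that the installed Flax calls jax.api_util.debug_info, which the installed JAX does not provide, i.e. the two packages are out of step. A quick way to confirm a mismatch like this is to print the installed versions and check for the attribute directly. This is a diagnostic sketch only, not part of JaxMARL; the helper names installed_version and report_jax_flax_compat are hypothetical.

```python
# Diagnostic sketch (hypothetical helpers, not JaxMARL code): report the
# installed JAX/Flax versions and whether jax.api_util exposes debug_info,
# the attribute that flax/core/axes_scan.py calls in the traceback above.
from importlib.metadata import PackageNotFoundError, version


def installed_version(pkg: str) -> str:
    """Return the installed distribution version, or 'not installed'."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return "not installed"


def report_jax_flax_compat() -> dict:
    """Collect package versions plus the jax.api_util.debug_info check."""
    info = {pkg: installed_version(pkg) for pkg in ("jax", "jaxlib", "flax")}
    try:
        import jax.api_util  # binds the local name `jax`

        info["jax.api_util.debug_info"] = hasattr(jax.api_util, "debug_info")
    except ImportError:
        # JAX itself is missing, so the attribute is certainly unavailable.
        info["jax.api_util.debug_info"] = False
    return info


if __name__ == "__main__":
    for key, value in report_jax_flax_compat().items():
        print(f"{key}: {value}")
```

If the last line prints False while Flax is installed, the Flax release expects a newer JAX than the one in the environment.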

This is likely a version compatibility issue. With

pip install jax==0.4.35 jaxlib==0.4.35 flax==0.8.5

this error disappears, but

'''
Traceback (most recent call last):
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 613, in main
    out = jax.vmap(train_jit)(rngs)
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 583, in train
    runner_state, metric = jax.lax.scan(
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 546, in _update_step
    update_state, loss_info = jax.lax.scan(
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 525, in _update_epoch
    train_state, total_loss = jax.lax.scan(
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 488, in _update_minbatch
    total_loss, grads = grad_fn(
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 446, in _loss_fn
    _, pi, value = network.apply(
  File "/home/aaron/Projects/personal/JaxMARL/baselines/IPPO/ippo_rnn_overcooked_v2.py", line 142, in __call__
    hidden, embedding = ScannedRNN()(hidden, rnn_in)
  File "/home/aaron/miniforge3/envs/jaxmarl310/lib/python3.10/site-packages/flax/core/axes_scan.py", line 165, in scan_fn
    c, ys = lax.scan(
TypeError: scan body function carry input and carry output must have equal types (e.g. shapes and dtypes of arrays), but they differ:

The input carry component c[1] has type float32[128] but the corresponding output carry component has type float32[1,128], so the shapes do not match.
'''

takes its place, which appears to be a shape mismatch. It might again be version-related; perhaps this time the pinned versions are too old.
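
The second error is a general jax.lax.scan constraint: the carry returned by the scan body must have the same shape and dtype as the carry passed in. Here the RNN hidden state enters as float32[128] and comes back as float32[1,128], i.e. with an extra leading axis. The sketch below reproduces that class of error outside JaxMARL; bad_step and good_step are hypothetical stand-ins, not the actual ScannedRNN code, and it does not establish where the extra axis comes from in the baseline.

```python
# Minimal reproduction (not JaxMARL code): jax.lax.scan requires the carry
# to keep the same shape and dtype on every step.
import jax
import jax.numpy as jnp

HIDDEN = 128  # matches the float32[128] carry in the traceback


def bad_step(carry, x):
    # Returns a carry with an extra leading axis: (128,) -> (1, 128).
    new_carry = (carry + x)[None, :]
    return new_carry, x.sum()


def good_step(carry, x):
    # Keeps the carry at (128,), so input and output carry types match.
    new_carry = carry + x
    return new_carry, x.sum()


init = jnp.zeros((HIDDEN,), dtype=jnp.float32)
xs = jnp.ones((4, HIDDEN), dtype=jnp.float32)  # 4 scan steps

final, ys = jax.lax.scan(good_step, init, xs)
print(final.shape, ys.shape)  # (128,) (4,)

try:
    jax.lax.scan(bad_step, init, xs)
except TypeError as err:
    # Same error class as in the traceback: carry input/output types differ.
    print("TypeError:", str(err).splitlines()[0])
```

In the baseline, the corresponding fix would presumably be to make the hidden state keep one consistent shape across steps (for example, initializing it with the same leading axes the cell returns), but which side is "wrong" depends on the Flax version in use.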

Metadata


Labels

bug (Something isn't working)
