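Why is the positional (PBD) backend in Brax v2 roughly 3x slower than v1 on the humanoid environment? (In the benchmark results, the gap ranges from about 2x at population size 1024 to about 3.5x at population size 10240.) I noticed that the number of substeps and `dt` differ between the two versions, but v1 performs 533 PBD steps in total and v2 performs 666 — only about 1.25x as many — so that difference alone should not cause such a large slowdown.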
Here are the benchmark results:
JAX 0.4.8, 1x RTX 3090
V2 Humanoid PBD
Pop size 1024 FPS 212603.0
Pop size 2048 FPS 341043.9
Pop size 4096 FPS 430086.4
Pop size 10240 FPS 411593.8
V1 Humanoid PBD
Pop size 1024 FPS 441713.6
Pop size 2048 FPS 873183.5
Pop size 4096 FPS 1237109.4
Pop size 10240 FPS 1310780.7
JAX 0.4.11, 1x RTX 3090
V2 Humanoid PBD
Pop size 1024 FPS 199116.0
Pop size 2048 FPS 328199.3
Pop size 4096 FPS 415347.5
Pop size 10240 FPS 403450.6
V1 Humanoid PBD
Pop size 1024 FPS 467750.8
Pop size 2048 FPS 926400.9
Pop size 4096 FPS 1372671.0
Pop size 10240 FPS 1417486.8
The benchmark code is as follows:
from functools import partial
import time
import jax
import jax.numpy as jnp
from brax import envs
def create_env(env_name, is_v2, max_steps_per_episode=1000):
    """Build an episode-limited, vmapped Brax environment for either API.

    Args:
        env_name: name of the environment to look up via
            ``envs.get_environment``.
        is_v2: when true, use the Brax v2 API and request the
            ``"positional"`` backend; otherwise use the v1 API, whose
            ``get_environment`` call here passes no backend argument.
        max_steps_per_episode: episode length given to ``EpisodeWrapper``.

    Returns:
        The environment wrapped first in ``EpisodeWrapper`` and then in
        ``VmapWrapper``.
    """
    # The two API generations expose their wrappers under different module
    # names (``envs.wrapper`` vs ``envs.wrappers``); pick the right one once.
    if is_v2:
        base_env = envs.get_environment(env_name, backend="positional")
        wrapper_mod = envs.wrapper
    else:
        base_env = envs.get_environment(env_name)
        wrapper_mod = envs.wrappers

    # NOTE(review): the trailing ``1`` is presumably the action-repeat factor;
    # confirm against the installed Brax EpisodeWrapper signature.
    episodic_env = wrapper_mod.EpisodeWrapper(base_env, max_steps_per_episode, 1)
    return wrapper_mod.VmapWrapper(episodic_env)
@partial(jax.jit, static_argnames=("env", "pop_size", "steps"))
def benchmark(
    env,
    seed: int = 0,
    pop_size: int = 10240,
    steps: int = 1000,
):
    """Roll out ``pop_size`` parallel environments for ``steps`` random actions.

    The whole rollout is compiled as a single jitted program. ``env``,
    ``pop_size`` and ``steps`` are static arguments (``env`` must therefore be
    hashable); ``seed`` is traced.

    The compiled program also contains ``env.reset`` and the sampling of the
    full ``(steps, pop_size, env.action_size)`` action tensor, so timing a call
    measures more than ``env.step`` alone.

    Args:
        env: environment exposing ``reset(keys)``, ``step(state, action)`` and
            ``action_size``; presumably already vmapped over the population.
        seed: integer seed for ``jax.random.PRNGKey``.
        pop_size: number of reset keys / parallel environments.
        steps: number of ``env.step`` calls scanned over.

    Returns:
        The ``(final_state, None)`` pair produced by ``jax.lax.scan``.
    """
    root_key = jax.random.PRNGKey(seed)
    reset_key, action_key = jax.random.split(root_key)

    # One reset key per member of the population.
    start = env.reset(jax.random.split(reset_key, pop_size))

    # Pre-sample every action up front, uniformly in [-1, 1).
    actions = jax.random.uniform(
        action_key,
        shape=(steps, pop_size, env.action_size),
        minval=-1,
        maxval=1,
    )

    # Thread the env state through the scan; nothing is stacked per step.
    return jax.lax.scan(
        lambda state, action: (env.step(state, action), None),
        start,
        actions,
    )
def main(
    env_name="humanoid",
    is_v2=False,
    pop_sizes=(1024, 2048, 4096, 10240),
    steps=1000,
):
    """Print environment-step throughput (FPS) for several population sizes.

    For each population size the jitted rollout runs once to compile, then
    runs again while being timed. FPS is ``pop_size * steps`` divided by the
    elapsed time of the timed run. That run also includes ``env.reset`` and
    action sampling, because both are compiled into ``benchmark``.

    Args:
        env_name: environment to benchmark.
        is_v2: ``True`` for the Brax v2 positional backend, ``False`` for the
            v1 API.
        pop_sizes: population (batch) sizes to time; one output line each.
        steps: number of environment steps per rollout.
    """
    env = create_env(env_name, is_v2)

    def _wait(tree):
        # JAX dispatches work asynchronously; block on every leaf so the timer
        # covers the device computation, not just the dispatch.
        # (``jax.tree_map`` is deprecated in favour of ``jax.tree_util.tree_map``.)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), tree)

    for pop_size in pop_sizes:
        conf = dict(env=env, steps=steps, pop_size=pop_size)

        # Warm-up call: compiles for this (env, pop_size, steps) combination so
        # compilation time is excluded from the measurement below.
        _wait(benchmark(**conf))

        # perf_counter is monotonic and high-resolution, unlike time.time().
        t_start = time.perf_counter()
        _wait(benchmark(**conf))
        elapsed = time.perf_counter() - t_start

        fps = (pop_size * steps) / elapsed
        print(f"Pop size {pop_size} FPS {fps:.1f}")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
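Why is the positional (PBD) backend in Brax v2 roughly 3x slower than v1 on the humanoid environment? (In the benchmark results, the gap ranges from about 2x at population size 1024 to about 3.5x at population size 10240.) I noticed that the number of substeps and `dt` differ between the two versions, but v1 performs 533 PBD steps in total and v2 performs 666 — only about 1.25x as many — so that difference alone should not cause such a large slowdown.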
Here are the benchmark results:
The benchmark code is as follows: