Skip to content

Brax V2 pbd is about 3x slower than V1 #371

@imoneoi

Description

@imoneoi

Why is the positional (PBD) backend in Brax v2 about 3x slower than in v1 on the humanoid environment? I noticed that the substep count and dt differ between the two versions. However, v1 performs 533 PBD steps in total and v2 performs 666, only about 1.25x more, so that difference alone should not produce a 3x slowdown.

Here are the benchmark results:

JAX 0.4.8 @ RTX 3090 * 1

V2 Humanoid PBD

Pop size 1024 FPS 212603.0
Pop size 2048 FPS 341043.9
Pop size 4096 FPS 430086.4
Pop size 10240 FPS 411593.8

V1 Humanoid PBD

Pop size 1024 FPS 441713.6
Pop size 2048 FPS 873183.5
Pop size 4096 FPS 1237109.4
Pop size 10240 FPS 1310780.7

JAX 0.4.11 @ RTX 3090 * 1

V2 Humanoid PBD

Pop size 1024 FPS 199116.0
Pop size 2048 FPS 328199.3
Pop size 4096 FPS 415347.5
Pop size 10240 FPS 403450.6

V1 Humanoid PBD

Pop size 1024 FPS 467750.8
Pop size 2048 FPS 926400.9
Pop size 4096 FPS 1372671.0
Pop size 10240 FPS 1417486.8

The benchmark code is as follows:

from functools import partial
import time

import jax
import jax.numpy as jnp

from brax import envs


def create_env(env_name, is_v2, max_steps_per_episode=1000):
    """Build a batched, episodic Brax environment for benchmarking.

    Args:
        env_name: Name passed to ``envs.get_environment``.
        is_v2: If true, use the V2 API (``backend="positional"`` and the
            ``envs.wrapper`` module). Otherwise use the V1 API (default
            backend and the ``envs.wrappers`` module).
        max_steps_per_episode: Episode length given to ``EpisodeWrapper``
            (the action repeat is fixed at 1).

    Returns:
        The environment wrapped as ``VmapWrapper(EpisodeWrapper(base))``.
    """
    # The two Brax APIs differ in how the backend is chosen and in the name
    # of the module holding the wrappers; pick both up front, then wrap once.
    if not is_v2:
        base_env = envs.get_environment(env_name)
        wrapper_mod = envs.wrappers
    else:
        base_env = envs.get_environment(env_name, backend="positional")
        wrapper_mod = envs.wrapper

    episodic_env = wrapper_mod.EpisodeWrapper(base_env, max_steps_per_episode, 1)
    batched_env = wrapper_mod.VmapWrapper(episodic_env)
    return batched_env


@partial(jax.jit, static_argnames=["env", "pop_size", "steps"])
def benchmark(
    env,
    seed:     int = 0,
    pop_size: int = 10240,
    steps:    int = 1000,
):
    """Roll out ``steps`` steps of ``pop_size`` parallel envs under one jit.

    The rollout is a single ``jax.lax.scan``. Actions are drawn uniformly
    from [-1, 1] at every step.

    Args:
        env: Batched environment exposing ``reset(keys)``,
            ``step(state, action)`` and ``action_size``. It is a static jit
            argument, so it must be hashable.
        seed: Integer PRNG seed.
        pop_size: Number of parallel environments (static).
        steps: Number of steps to roll out (static).

    Returns:
        The ``(final_state, None)`` pair produced by ``jax.lax.scan``.
    """
    init_state_key, act_key = jax.random.split(jax.random.PRNGKey(seed))
    init_state = env.reset(jax.random.split(init_state_key, pop_size))

    # Generate each step's actions inside the scan from a per-step key rather
    # than materialising a (steps, pop_size, action_size) array up front. At
    # the default sizes that array holds ~10.24M values per action dimension
    # (~41 MB each, assuming JAX's default float32). That bloats memory and
    # adds bandwidth unrelated to the env.step cost being benchmarked.
    step_keys = jax.random.split(act_key, steps)

    def _step_env(state, key):
        act = jax.random.uniform(
            key, (pop_size, env.action_size), minval=-1, maxval=1
        )
        return env.step(state, act), None

    return jax.lax.scan(_step_env, init_state, step_keys)


def _block_until_ready(tree):
    """Wait for every array leaf of ``tree`` to finish computing."""
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), tree)


def main(
    env_name="humanoid",
    is_v2=False,
    pop_sizes=(1024, 2048, 4096, 10240),
    steps=1000,
):
    """Benchmark env throughput (env-steps per second) for several batch sizes.

    Args:
        env_name: Environment to benchmark.
        is_v2: Select the Brax V2 (positional backend) API instead of V1.
        pop_sizes: Batch sizes to benchmark, in order.
        steps: Rollout length per benchmark call.
    """
    env = create_env(env_name, is_v2)

    for pop_size in pop_sizes:
        conf = dict(env=env, steps=steps, pop_size=pop_size)

        # Warm-up: pop_size and steps are static jit args, so the first call
        # for each pop_size compiles. Keep compilation out of the timing.
        _block_until_ready(benchmark(**conf))

        # perf_counter is monotonic and high-resolution, unlike time.time().
        t_start = time.perf_counter()
        _block_until_ready(benchmark(**conf))
        elapsed = time.perf_counter() - t_start

        fps = (pop_size * steps) / elapsed
        print(f"Pop size {pop_size} FPS {fps:.1f}")


if __name__ == "__main__":
    main()

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions