Skip to content

Commit 1d5fbd7

Browse files
Brax Teambtaba
authored andcommitted
Internal change
PiperOrigin-RevId: 520465701 Change-Id: I6d34c4a7398ade1ed882c66bf6e9cdac12f39101
1 parent fc8f410 commit 1d5fbd7

File tree

407 files changed

+10330
-14492
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

407 files changed

+10330
-14492
lines changed

MANIFEST.in

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
include brax/tests/testdata/cylinder.stl
2-
include brax/v2/envs/assets/*.xml
3-
recursive-include brax/v2/test_data *.xml *.stl *.obj
4-
recursive-include brax/v2/visualizer *
1+
include brax/envs/assets/*.xml
2+
recursive-include brax/test_data *.xml *.stl *.obj
3+
recursive-include brax/visualizer *

README.md

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
<img src="https://github.com/google/brax/raw/main/docs/img/brax_logo.gif" width="336" height="80" alt="BRAX"/>
22

3-
Brax is a differentiable physics engine that simulates environments made up of
4-
rigid bodies, joints, and actuators. Brax is written in
5-
[JAX](https://github.com/google/jax) and is designed for use on acceleration
6-
hardware. It is both efficient for single-device simulation, and scalable to
7-
massively parallel simulation on multiple devices, without the need for pesky
8-
datacenters.
3+
Brax is a fast and fully differentiable physics engine used for research and
4+
development of robotics, human perception, materials science, reinforcement
5+
learning, and other simulation-heavy applications.
96

10-
<img src="https://github.com/google/brax/raw/main/docs/img/ant.gif" width="150" height="107"/><img src="https://github.com/google/brax/raw/main/docs/img/fetch.gif" width="150" height="107"/><img src="https://github.com/google/brax/raw/main/docs/img/grasp.gif" width="150" height="107"/><img src="https://github.com/google/brax/raw/main/docs/img/halfcheetah.gif" width="150" height="107"/><img src="https://github.com/google/brax/raw/main/docs/img/humanoid.gif" width="150" height="107"/>
7+
Brax is written in [JAX](https://github.com/google/jax) and is designed for use
8+
on acceleration hardware. It is both efficient for single-device simulation, and
9+
scalable to massively parallel simulation on multiple devices, without the need
10+
for pesky datacenters.
1111

12-
*Some policies trained via Brax. Brax simulates these environments at millions
13-
of physics steps per second on TPU.*
12+
<img src="https://github.com/google/brax/raw/main/docs/img/humanoid_v2.gif" width="250" height="250"/><img src="https://github.com/google/brax/raw/main/docs/img/a1.gif" width="250" height="250"/><img src="https://github.com/google/brax/raw/main/docs/img/ant_v2.gif" width="250" height="250"/><img src="https://github.com/google/brax/raw/main/docs/img/ur5e.gif" width="250" height="250"/>
1413

15-
Brax also includes a suite of learning algorithms that train agents in seconds
14+
Brax simulates environments at millions of physics steps per second on TPU, and includes a suite of learning algorithms that train agents in seconds
1615
to minutes:
1716

1817
* Baseline learning algorithms such as
@@ -22,17 +21,34 @@ to minutes:
2221
[evolutionary strategies](https://github.com/google/brax/blob/main/brax/training/agents/es).
2322
* Learning algorithms that leverage the differentiability of the simulator, such as [analytic policy gradients](https://github.com/google/brax/blob/main/brax/training/agents/apg).
2423

24+
## One API, Three Pipelines
25+
26+
Brax offers three distinct physics pipelines that are easy to swap:
27+
28+
* [Generalized](https://github.com/google/brax/blob/main/brax/v2/generalized/)
29+
calculates motion in [generalized coordinates](https://en.wikipedia.org/wiki/Generalized_coordinates) using the same accurate robot
30+
dynamics algorithms as [MuJoCo](https://mujoco.org/) and [TDS](https://github.com/erwincoumans/tiny-differentiable-simulator).
31+
* [Positional](https://github.com/google/brax/blob/main/brax/v2/positional/)
32+
uses [Position Based Dynamics](https://matthias-research.github.io/pages/publications/posBasedDyn.pdf),
33+
a fast but stable method of resolving joint and collision constraints.
34+
* [Spring](https://github.com/google/brax/blob/main/brax/v2/spring/) provides
35+
fast and cheap simulation for rapid experimentation, using simple impulse-based
36+
methods often found in video games.
37+
38+
These pipelines share the same API and can run side-by-side within the same
39+
simulation. This makes Brax well suited for experiments in transfer learning
40+
and closing the gap between simulation and the real world.
41+
2542
## Quickstart: Colab in the Cloud
2643

2744
Explore Brax easily and quickly through a series of colab notebooks:
2845

2946
* [Brax Basics](https://colab.research.google.com/github/google/brax/blob/main/notebooks/basics.ipynb) introduces the Brax API, and shows how to simulate basic physics primitives.
30-
* [Brax Environments](https://colab.research.google.com/github/google/brax/blob/main/notebooks/environments.ipynb) shows how to operate and visualize Brax environments. It also demonstrates converting Brax environments to Gym environments, and how to use Brax via other ML frameworks such as PyTorch.
31-
* [Brax Training with TPU](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb) introduces Brax's training algorithms, and lets you train your own policies directly within the colab. It also demonstrates loading and saving policies.
32-
* [Brax Training with PyTorch on GPU](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training_torch.ipynb) demonstrates how Brax can be used in other ML frameworks for fast training, in this case PyTorch.
33-
* [Brax Multi-Agent](https://colab.research.google.com/github/google/brax/blob/main/notebooks/multiagent.ipynb) measures Brax's performance on multi-agent simulation, with many bodies in the environment at once.
47+
* [Brax Training](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb)
48+
introduces the Brax v2 API, and shows how to train a policy with the
49+
generalized backend.
3450

35-
## Using Brax locally
51+
## Using Brax Locally
3652

3753
To install Brax from pypi, install it with:
3854

bin/learn

Lines changed: 0 additions & 7 deletions
This file was deleted.

brax/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 The Brax Authors.
1+
# Copyright 2023 The Brax Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
1616

1717
__version__ = '0.1.2'
1818

19-
from brax.physics.base import Info
20-
from brax.physics.base import QP
21-
from brax.physics.config_pb2 import Config
22-
from brax.physics.system import System
19+
from brax.base import Motion
20+
from brax.base import State
21+
from brax.base import System
22+
from brax.base import Transform
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 The Brax Authors.
1+
# Copyright 2023 The Brax Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -15,8 +15,8 @@
1515
# pylint:disable=g-multiple-import
1616
"""Functions for applying actuators to a physics pipeline."""
1717

18-
from brax.v2 import scan
19-
from brax.v2.base import System
18+
from brax import scan
19+
from brax.base import System
2020
from jax import numpy as jp
2121

2222

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 The Brax Authors.
1+
# Copyright 2023 The Brax Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -17,11 +17,11 @@
1717

1818
from absl.testing import absltest
1919
from absl.testing import parameterized
20-
from brax.v2 import actuator
21-
from brax.v2 import test_utils
22-
from brax.v2.generalized import pipeline as g_pipeline
23-
from brax.v2.positional import pipeline as p_pipeline
24-
from brax.v2.spring import pipeline as s_pipeline
20+
from brax import actuator
21+
from brax import test_utils
22+
from brax.generalized import pipeline as g_pipeline
23+
from brax.positional import pipeline as p_pipeline
24+
from brax.spring import pipeline as s_pipeline
2525
import jax
2626
from jax import numpy as jp
2727
import mujoco

brax/v2/base.py renamed to brax/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 The Brax Authors.
1+
# Copyright 2023 The Brax Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
1616
import functools
1717
from typing import Any, List, Optional, Sequence, Tuple, Union
1818

19-
from brax.v2 import math
19+
from brax import math
2020
from flax import struct
2121
from jax import numpy as jp
2222
from jax import vmap
@@ -247,7 +247,7 @@ def mul(self, m: Motion) -> 'Force':
247247
class Link(Base):
248248
"""A rigid segment of an articulated body.
249249
250-
Links are connected to eachother by joints. By moving (rotating or
250+
Links are connected to each other by joints. By moving (rotating or
251251
translating) the joints, the entire system can be articulated.
252252
253253
Attributes:
@@ -586,16 +586,16 @@ def qd_idx(self, link_type: str) -> jp.ndarray:
586586
return jp.array(idxs)
587587

588588
def q_size(self) -> int:
589-
"""Returns the size of the q vector (joint position) for this sytem."""
589+
"""Returns the size of the q vector (joint position) for this system."""
590590
return sum([Q_WIDTHS[t] for t in self.link_types])
591591

592592
def qd_size(self) -> int:
593-
"""Returns the size of the qd vector (joint velocity) for this sytem."""
593+
"""Returns the size of the qd vector (joint velocity) for this system."""
594594
return sum([QD_WIDTHS[t] for t in self.link_types])
595595

596596
def act_size(self) -> int:
597597
"""Returns the act dimension for the system."""
598-
return sum([QD_WIDTHS[self.link_types[i]] for i in self.actuator_link_id])
598+
return sum({'m': 1, 'p': 1}[act_typ] for act_typ in self.actuator_types)
599599

600600

601601
# below are some operation dispatch derivations
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 The Brax Authors.
1+
# Copyright 2023 The Brax Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -16,8 +16,8 @@
1616
# pylint:disable=g-multiple-import
1717
from typing import Tuple
1818

19-
from brax.v2 import math
20-
from brax.v2.base import Motion, System, Transform
19+
from brax import math
20+
from brax.base import Motion, System, Transform
2121
import jax
2222
from jax import numpy as jp
2323

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 The Brax Authors.
1+
# Copyright 2023 The Brax Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -15,10 +15,10 @@
1515
"""Tests for com."""
1616
# pylint:disable=g-multiple-import
1717
from absl.testing import absltest
18-
from brax.v2 import kinematics
19-
from brax.v2 import math
20-
from brax.v2 import test_utils
21-
from brax.v2.spring import com
18+
from brax import com
19+
from brax import kinematics
20+
from brax import math
21+
from brax import test_utils
2222
import jax
2323
from jax import numpy as jp
2424
import numpy as np

brax/envs/__init__.py

Lines changed: 50 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 The Brax Authors.
1+
# Copyright 2023 The Brax Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,116 +13,92 @@
1313
# limitations under the License.
1414

1515
# pylint:disable=g-multiple-import
16-
"""Some example environments to help get started quickly with brax."""
16+
"""Environments for training and evaluating policies."""
1717

1818
import functools
19-
from typing import Callable, Optional, Type, Union, overload
19+
from typing import Optional, Type
2020

21-
from brax.envs import acrobot
2221
from brax.envs import ant
2322
from brax.envs import fast
24-
from brax.envs import fetch
25-
from brax.envs import grasp
2623
from brax.envs import half_cheetah
2724
from brax.envs import hopper
2825
from brax.envs import humanoid
29-
from brax.envs import humanoid_standup
26+
from brax.envs import humanoidstandup
3027
from brax.envs import inverted_double_pendulum
3128
from brax.envs import inverted_pendulum
3229
from brax.envs import pusher
3330
from brax.envs import reacher
34-
from brax.envs import reacherangle
35-
from brax.envs import swimmer
36-
from brax.envs import ur5e
3731
from brax.envs import walker2d
38-
from brax.envs import wrappers
39-
from brax.envs.env import Env, State, Wrapper
40-
import gym
41-
32+
from brax.envs import wrapper
33+
from brax.envs.env import Env, State
4234

4335
_envs = {
44-
'acrobot': acrobot.Acrobot,
45-
'ant': functools.partial(ant.Ant, use_contact_forces=True),
36+
'ant': ant.Ant,
4637
'fast': fast.Fast,
47-
'fetch': fetch.Fetch,
48-
'grasp': grasp.Grasp,
4938
'halfcheetah': half_cheetah.Halfcheetah,
5039
'hopper': hopper.Hopper,
5140
'humanoid': humanoid.Humanoid,
52-
'humanoidstandup': humanoid_standup.HumanoidStandup,
41+
'humanoidstandup': humanoidstandup.HumanoidStandup,
5342
'inverted_pendulum': inverted_pendulum.InvertedPendulum,
5443
'inverted_double_pendulum': inverted_double_pendulum.InvertedDoublePendulum,
5544
'pusher': pusher.Pusher,
5645
'reacher': reacher.Reacher,
57-
'reacherangle': reacherangle.ReacherAngle,
58-
'swimmer': swimmer.Swimmer,
59-
'ur5e': ur5e.Ur5e,
6046
'walker2d': walker2d.Walker2d,
6147
}
6248

6349

64-
def get_environment(env_name, **kwargs) -> Env:
50+
51+
def get_environment(env_name: str, **kwargs) -> Env:
52+
"""Returns an environment from the environment registry.
53+
54+
Args:
55+
env_name: environment name string
56+
**kwargs: keyword arguments that get passed to the Env class constructor
57+
58+
Returns:
59+
env: an environment
60+
"""
6561
return _envs[env_name](**kwargs)
6662

6763

6864
def register_environment(env_name: str, env_class: Type[Env]):
65+
"""Adds an environment to the registry.
66+
67+
Args:
68+
env_name: environment name string
69+
env_class: the Env class to add to the registry
70+
"""
6971
_envs[env_name] = env_class
7072

7173

72-
def create(env_name: str,
73-
episode_length: int = 1000,
74-
action_repeat: int = 1,
75-
auto_reset: bool = True,
76-
batch_size: Optional[int] = None,
77-
eval_metrics: bool = False,
78-
**kwargs) -> Env:
79-
"""Creates an Env with a specified brax system."""
74+
def create(
75+
env_name: str,
76+
episode_length: int = 1000,
77+
action_repeat: int = 1,
78+
auto_reset: bool = True,
79+
batch_size: Optional[int] = None,
80+
**kwargs,
81+
) -> Env:
82+
"""Creates an environment from the registry.
83+
84+
Args:
85+
env_name: environment name string
86+
episode_length: length of episode
87+
action_repeat: how many repeated actions to take per environment step
88+
auto_reset: whether to auto reset the environment after an episode is done
89+
batch_size: the number of environments to batch together
90+
**kwargs: keyword argments that get passed to the Env class constructor
91+
92+
Returns:
93+
env: an environment
94+
"""
8095
env = _envs[env_name](**kwargs)
96+
8197
if episode_length is not None:
82-
env = wrappers.EpisodeWrapper(env, episode_length, action_repeat)
98+
env = wrapper.EpisodeWrapper(env, episode_length, action_repeat)
8399
if batch_size:
84-
env = wrappers.VectorWrapper(env, batch_size)
100+
env = wrapper.VmapWrapper(env, batch_size)
85101
if auto_reset:
86-
env = wrappers.AutoResetWrapper(env)
87-
if eval_metrics:
88-
env = wrappers.EvalWrapper(env)
102+
env = wrapper.AutoResetWrapper(env)
89103

90104
return env # type: ignore
91-
92-
93-
def create_fn(env_name: str, **kwargs) -> Callable[..., Env]:
94-
"""Returns a function that when called, creates an Env."""
95-
return functools.partial(create, env_name, **kwargs)
96-
97-
98-
@overload
99-
def create_gym_env(env_name: str,
100-
batch_size: None = None,
101-
seed: int = 0,
102-
backend: Optional[str] = None,
103-
**kwargs) -> gym.Env:
104-
...
105-
106-
107-
@overload
108-
def create_gym_env(env_name: str,
109-
batch_size: int,
110-
seed: int = 0,
111-
backend: Optional[str] = None,
112-
**kwargs) -> gym.vector.VectorEnv:
113-
...
114-
115-
116-
def create_gym_env(env_name: str,
117-
batch_size: Optional[int] = None,
118-
seed: int = 0,
119-
backend: Optional[str] = None,
120-
**kwargs) -> Union[gym.Env, gym.vector.VectorEnv]:
121-
"""Creates a `gym.Env` or `gym.vector.VectorEnv` from a Brax environment."""
122-
environment = create(env_name=env_name, batch_size=batch_size, **kwargs)
123-
if batch_size is None:
124-
return wrappers.GymWrapper(environment, seed=seed, backend=backend)
125-
if batch_size <= 0:
126-
raise ValueError(
127-
'`batch_size` should either be None or a positive integer.')
128-
return wrappers.VectorGymWrapper(environment, seed=seed, backend=backend)

0 commit comments

Comments
 (0)