Skip to content

Commit 4981d05

Browse files
btaba authored and copybara-github committed
Add uv index for source install. Add warp impl to training script.
PiperOrigin-RevId: 795515399 Change-Id: I4e3397893d5a417c1888dc2b10ff39eabfcfda26
1 parent 619e6cb commit 4981d05

File tree

3 files changed

+112
-65
lines changed

3 files changed

+112
-65
lines changed

README.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@ A comprehensive suite of GPU-accelerated environments for robot learning researc
88

99
Features include:
1010

11-
- Classic control environments from `dm_control` reimplemented in MJX.
11+
- Classic control environments from `dm_control`.
1212
- Quadruped and bipedal locomotion environments.
1313
- Non-prehensile and dexterous manipulation environments.
1414
- Vision-based support available via [Madrona-MJX](https://github.com/shacklettbp/madrona_mjx).
1515

1616
For more details, check out the project [website](https://playground.mujoco.org/).
1717

18+
> [!NOTE]
19+
> We now support training with both the MuJoCo MJX JAX implementation, as well as the [MuJoCo Warp](https://github.com/google-deepmind/mujoco_warp) implementation at HEAD. See MuJoCo 3.3.5 [release notes](https://mujoco.readthedocs.io/en/stable/changelog.html#version-3-3-5-august-8-2025) under `MJX` for more details.
20+
1821
## Installation
1922

2023
You can install MuJoCo Playground directly from PyPI:
@@ -23,6 +26,14 @@ You can install MuJoCo Playground directly from PyPI:
2326
pip install playground
2427
```
2528

29+
> [!WARNING]
30+
> The `playground` release may depend on pre-release versions of `mujoco` and
31+
> `warp-lang`, in which case you can try `pip install playground
32+
> --extra-index-url=https://py.mujoco.org
33+
> --extra-index-url=https://pypi.nvidia.com/warp-lang/`.
34+
> If there are still version mismatches, please open a github issue, and install
35+
> from source.
36+
2637
### From Source
2738

2839
> [!IMPORTANT]

learning/train_jax_ppo.py

Lines changed: 75 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ==============================================================================
1515
"""Train a PPO agent using JAX on the specified environment."""
1616

17-
from datetime import datetime
17+
import datetime
1818
import functools
1919
import json
2020
import os
@@ -28,22 +28,20 @@
2828
from brax.training.agents.ppo import networks_vision as ppo_networks_vision
2929
from brax.training.agents.ppo import train as ppo
3030
from etils import epath
31-
from flax.training import orbax_utils
3231
import jax
3332
import jax.numpy as jp
3433
import mediapy as media
3534
from ml_collections import config_dict
3635
import mujoco
37-
from orbax import checkpoint as ocp
38-
from tensorboardX import SummaryWriter
39-
import wandb
40-
4136
import mujoco_playground
4237
from mujoco_playground import registry
4338
from mujoco_playground import wrapper
4439
from mujoco_playground.config import dm_control_suite_params
4540
from mujoco_playground.config import locomotion_params
4641
from mujoco_playground.config import manipulation_params
42+
import tensorboardX
43+
import wandb
44+
4745

4846
xla_flags = os.environ.get("XLA_FLAGS", "")
4947
xla_flags += " --xla_gpu_triton_gemm_any=True"
@@ -69,6 +67,7 @@
6967
"LeapCubeReorient",
7068
f"Name of the environment. One of {', '.join(registry.ALL_ENVS)}",
7169
)
70+
_IMPL = flags.DEFINE_enum("impl", "jax", ["jax", "warp"], "MJX implementation")
7271
_VISION = flags.DEFINE_boolean("vision", False, "Use vision input")
7372
_LOAD_CHECKPOINT_PATH = flags.DEFINE_string(
7473
"load_checkpoint_path", None, "Path to load checkpoint from"
@@ -92,6 +91,9 @@
9291
_NUM_TIMESTEPS = flags.DEFINE_integer(
9392
"num_timesteps", 1_000_000, "Number of timesteps"
9493
)
94+
_NUM_VIDEOS = flags.DEFINE_integer(
95+
"num_videos", 1, "Number of videos to record after training."
96+
)
9597
_NUM_EVALS = flags.DEFINE_integer("num_evals", 5, "Number of evaluations")
9698
_REWARD_SCALING = flags.DEFINE_float("reward_scaling", 0.1, "Reward scaling")
9799
_EPISODE_LENGTH = flags.DEFINE_integer("episode_length", 1000, "Episode length")
@@ -269,7 +271,7 @@ def main(argv):
269271
print(f"PPO Training Parameters:\n{ppo_params}")
270272

271273
# Generate unique experiment name
272-
now = datetime.now()
274+
now = datetime.datetime.now()
273275
timestamp = now.strftime("%Y%m%d-%H%M%S")
274276
exp_name = f"{_ENV_NAME.value}-{timestamp}"
275277
if _SUFFIX.value is not None:
@@ -289,7 +291,7 @@ def main(argv):
289291

290292
# Initialize TensorBoard if required
291293
if _USE_TB.value and not _PLAY_ONLY.value:
292-
writer = SummaryWriter(logdir)
294+
writer = tensorboardX.SummaryWriter(logdir)
293295

294296
# Handle checkpoint loading
295297
if _LOAD_CHECKPOINT_PATH.value is not None:
@@ -393,18 +395,26 @@ def progress(num_steps, metrics):
393395
f" reward={metrics['episode/sum_reward']:.3f}"
394396
)
395397

396-
# Load evaluation environment
397-
eval_env = (
398-
None if _VISION.value else registry.load(_ENV_NAME.value, config=env_cfg)
399-
)
398+
# Load evaluation environment.
399+
config_overrides = {"impl": _IMPL.value}
400+
eval_env = None
401+
if not _VISION.value:
402+
eval_env = registry.load(
403+
_ENV_NAME.value, config=env_cfg, config_overrides=config_overrides
404+
)
405+
num_envs = 1
406+
if _VISION.value:
407+
num_envs = env_cfg.vision_config.render_batch_size
400408

401409
policy_params_fn = lambda *args: None
402410
if _RSCOPE_ENVS.value:
403411
# Interactive visualisation of policy checkpoints
404412
from rscope import brax as rscope_utils
405413

406414
if not _VISION.value:
407-
rscope_env = registry.load(_ENV_NAME.value, config=env_cfg)
415+
rscope_env = registry.load(
416+
_ENV_NAME.value, config=env_cfg, config_overrides=config_overrides
417+
)
408418
rscope_env = wrapper.wrap_for_brax_training(
409419
rscope_env,
410420
episode_length=ppo_params.episode_length,
@@ -433,7 +443,7 @@ def policy_params_fn(current_step, make_policy, params): # pylint: disable=unus
433443
environment=env,
434444
progress_fn=progress,
435445
policy_params_fn=policy_params_fn,
436-
eval_env=None if _VISION.value else eval_env,
446+
eval_env=eval_env,
437447
)
438448

439449
print("Done training.")
@@ -443,63 +453,69 @@ def policy_params_fn(current_step, make_policy, params): # pylint: disable=unus
443453

444454
print("Starting inference...")
445455

446-
# Create inference function
456+
# Create inference function.
447457
inference_fn = make_inference_fn(params, deterministic=True)
448458
jit_inference_fn = jax.jit(inference_fn)
449459

450-
# Prepare for evaluation
451-
eval_env = (
452-
None if _VISION.value else registry.load(_ENV_NAME.value, config=env_cfg)
453-
)
454-
num_envs = 1
455-
if _VISION.value:
456-
eval_env = env
457-
num_envs = env_cfg.vision_config.render_batch_size
458-
459-
jit_reset = jax.jit(eval_env.reset)
460-
jit_step = jax.jit(eval_env.step)
461-
462-
rng = jax.random.PRNGKey(123)
463-
rng, reset_rng = jax.random.split(rng)
464-
if _VISION.value:
465-
reset_rng = jp.asarray(jax.random.split(reset_rng, num_envs))
466-
state = jit_reset(reset_rng)
467-
state0 = (
468-
jax.tree_util.tree_map(lambda x: x[0], state) if _VISION.value else state
469-
)
470-
rollout = [state0]
471-
472-
# Run evaluation rollout
473-
for _ in range(env_cfg.episode_length):
474-
act_rng, rng = jax.random.split(rng)
475-
ctrl, _ = jit_inference_fn(state.obs, act_rng)
476-
state = jit_step(state, ctrl)
477-
state0 = (
478-
jax.tree_util.tree_map(lambda x: x[0], state)
479-
if _VISION.value
480-
else state
460+
# Run evaluation rollouts.
461+
def do_rollout(rng, state):
462+
empty_data = state.data.__class__(
463+
**{k: None for k in state.data.__annotations__}
464+
) # pytype: disable=attribute-error
465+
empty_traj = state.__class__(**{k: None for k in state.__annotations__}) # pytype: disable=attribute-error
466+
empty_traj = empty_traj.replace(data=empty_data)
467+
468+
def step(carry, _):
469+
state, rng = carry
470+
rng, act_key = jax.random.split(rng)
471+
act = jit_inference_fn(state.obs, act_key)[0]
472+
state = eval_env.step(state, act)
473+
traj_data = empty_traj.tree_replace({
474+
"data.qpos": state.data.qpos,
475+
"data.qvel": state.data.qvel,
476+
"data.time": state.data.time,
477+
"data.ctrl": state.data.ctrl,
478+
"data.mocap_pos": state.data.mocap_pos,
479+
"data.mocap_quat": state.data.mocap_quat,
480+
"data.xfrc_applied": state.data.xfrc_applied,
481+
})
482+
if _VISION.value:
483+
traj_data = jax.tree_util.tree_map(lambda x: x[0], traj_data)
484+
return (state, rng), traj_data
485+
486+
_, traj = jax.lax.scan(
487+
step, (state, rng), None, length=_EPISODE_LENGTH.value
481488
)
482-
rollout.append(state0)
483-
if state0.done:
484-
break
489+
return traj
485490

486-
# Render and save the rollout
491+
rng = jax.random.split(jax.random.PRNGKey(_SEED.value), _NUM_VIDEOS.value)
492+
reset_states = jax.jit(jax.vmap(eval_env.reset))(rng)
493+
if _VISION.value:
494+
reset_states = jax.tree_util.tree_map(lambda x: x[0], reset_states)
495+
traj_stacked = jax.jit(jax.vmap(do_rollout))(rng, reset_states)
496+
trajectories = [None] * _NUM_VIDEOS.value
497+
for i in range(_NUM_VIDEOS.value):
498+
t = jax.tree.map(lambda x, i=i: x[i], traj_stacked)
499+
trajectories[i] = [
500+
jax.tree.map(lambda x, j=j: x[j], t)
501+
for j in range(_EPISODE_LENGTH.value)
502+
]
503+
504+
# Render and save the rollout.
487505
render_every = 2
488506
fps = 1.0 / eval_env.dt / render_every
489507
print(f"FPS for rendering: {fps}")
490-
491-
traj = rollout[::render_every]
492-
493508
scene_option = mujoco.MjvOption()
494509
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
495510
scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = False
496511
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False
497-
498-
frames = eval_env.render(
499-
traj, height=480, width=640, scene_option=scene_option
500-
)
501-
media.write_video("rollout.mp4", frames, fps=fps)
502-
print("Rollout video saved as 'rollout.mp4'.")
512+
for i, rollout in enumerate(trajectories):
513+
traj = rollout[::render_every]
514+
frames = eval_env.render(
515+
traj, height=480, width=640, scene_option=scene_option
516+
)
517+
media.write_video(f"rollout{i}.mp4", frames, fps=fps)
518+
print(f"Rollout video saved as 'rollout{i}.mp4'.")
503519

504520

505521
if __name__ == "__main__":

pyproject.toml

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,37 @@ classifiers = [
2323
"Topic :: Scientific/Engineering",
2424
]
2525
dependencies = [
26-
"brax>=0.12.1",
26+
"brax>=0.12.5",
2727
"etils",
2828
"flax",
2929
"jax",
3030
"lxml",
31+
"mediapy",
3132
"ml_collections",
32-
"mujoco-mjx>=3.2.7",
33-
"mujoco>=3.2.7",
33+
"mujoco-mjx>=3.3.6.dev",
34+
"mujoco>=3.3.6.dev",
35+
"orbax-checkpoint>=0.11.22",
3436
"tqdm",
37+
"warp-lang>=1.9.0.dev",
38+
"wandb",
3539
]
3640
keywords = ["mjx", "mujoco", "sim2real", "reinforcement learning"]
3741

42+
[[tool.uv.index]]
43+
name = "mujoco"
44+
url = "https://py.mujoco.org"
45+
explicit = true
46+
47+
[[tool.uv.index]]
48+
name = "nvidia"
49+
url = "https://pypi.nvidia.com"
50+
explicit = true
51+
52+
[tool.uv.sources]
53+
mujoco = { index = "mujoco" }
54+
warp-lang = { index = "nvidia" }
55+
mujoco-mjx = { git = "https://github.com/google-deepmind/mujoco", rev = "977f94e9dfba12e1861fe5cbc56c50bfb734a78e", subdirectory = "mjx" }
56+
3857
[project.optional-dependencies]
3958
test = [
4059
"absl-py",
@@ -44,9 +63,7 @@ test = [
4463
]
4564
notebooks = [
4665
"matplotlib",
47-
"mediapy",
4866
"jupyter",
49-
"wandb",
5067
]
5168
dev = [
5269
"playground[test]",
@@ -63,6 +80,9 @@ all = [
6380
"playground[notebooks]",
6481
]
6582

83+
[tool.hatch.metadata]
84+
allow-direct-references = true
85+
6686
[tool.hatch.build.targets.wheel]
6787
packages = ["mujoco_playground"]
6888

0 commit comments

Comments (0)