Skip to content

Commit bfac262

Browse files
btaba authored and copybara-github committed
internal change
PiperOrigin-RevId: 832442898 Change-Id: Idf68bc40b5d63c29681a5852599967e2d149a22c
1 parent 13f4bfc commit bfac262

File tree

2 files changed

+46
-25
lines changed

2 files changed

+46
-25
lines changed

learning/train_rsl_rl.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,9 @@
1515
# pylint: disable=wrong-import-position
1616
"""Train a PPO agent using RSL-RL for the specified environment."""
1717

18-
import os
19-
20-
xla_flags = os.environ.get("XLA_FLAGS", "")
21-
xla_flags += " --xla_gpu_triton_gemm_any=True"
22-
os.environ["XLA_FLAGS"] = xla_flags
23-
os.environ["MUJOCO_GL"] = "egl"
24-
2518
from datetime import datetime
2619
import json
20+
import os
2721

2822
from absl import app
2923
from absl import flags
@@ -32,15 +26,24 @@
3226
import mediapy as media
3327
from ml_collections import config_dict
3428
import mujoco
35-
from rsl_rl.runners import OnPolicyRunner
36-
import torch
37-
import wandb
38-
3929
import mujoco_playground
4030
from mujoco_playground import registry
4131
from mujoco_playground import wrapper_torch
4232
from mujoco_playground.config import locomotion_params
4333
from mujoco_playground.config import manipulation_params
34+
from rsl_rl.runners import OnPolicyRunner
35+
import torch
36+
import warp as wp
37+
38+
try:
39+
import wandb # pylint: disable=g-import-not-at-top
40+
except ImportError:
41+
wandb = None
42+
43+
xla_flags = os.environ.get("XLA_FLAGS", "")
44+
xla_flags += " --xla_gpu_triton_gemm_any=True"
45+
os.environ["XLA_FLAGS"] = xla_flags
46+
os.environ["MUJOCO_GL"] = "egl"
4447

4548
# Suppress logs if you want
4649
logging.set_verbosity(logging.WARNING)
@@ -78,6 +81,11 @@
7881
_CAMERA = flags.DEFINE_string(
7982
"camera", None, "Camera name to use for rendering."
8083
)
84+
_WP_KERNEL_CACHE_DIR = flags.DEFINE_string(
85+
"wp_kernel_cache_dir",
86+
"/tmp/wp_kernel_cache_playground",
87+
"Path to the WP kernel cache directory.",
88+
)
8189

8290

8391
def get_rl_config(env_name: str) -> config_dict.ConfigDict:
@@ -93,6 +101,8 @@ def main(argv):
93101
"""Run training and evaluation for the specified environment using RSL-RL."""
94102
del argv # unused
95103

104+
wp.config.kernel_cache_dir = _WP_KERNEL_CACHE_DIR.value
105+
96106
# Possibly parse the device for multi-GPU
97107
if _MULTI_GPU.value:
98108
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
@@ -119,7 +129,7 @@ def main(argv):
119129
print(f"Experiment name: {exp_name}")
120130

121131
# Logging directory
122-
logdir = os.path.abspath(os.path.join("logs", exp_name))
132+
logdir = os.path.abspath(os.path.join("/tmp/rslrl-training-logs/", exp_name))
123133
os.makedirs(logdir, exist_ok=True)
124134
print(f"Logs are being stored in: {logdir}")
125135

@@ -129,7 +139,7 @@ def main(argv):
129139
print(f"Checkpoint path: {ckpt_path}")
130140

131141
# Initialize Weights & Biases if required
132-
if _USE_WANDB.value and not _PLAY_ONLY.value:
142+
if _USE_WANDB.value and not _PLAY_ONLY.value and wandb is not None:
133143
wandb.tensorboard.patch(root_logdir=logdir)
134144
wandb.init(project="mjxrl", name=exp_name)
135145
wandb.config.update(env_cfg.to_dict())
@@ -152,7 +162,9 @@ def render_callback(_, state):
152162
render_trajectory.append(state)
153163

154164
# Create the environment
155-
raw_env = registry.load(_ENV_NAME.value, config=env_cfg)
165+
raw_env = registry.load(
166+
_ENV_NAME.value, config=env_cfg, config_overrides={"impl": "jax"}
167+
)
156168
brax_env = wrapper_torch.RSLRLBraxWrapper(
157169
raw_env,
158170
num_envs,
@@ -186,7 +198,7 @@ def render_callback(_, state):
186198
# If resume, load from checkpoint
187199
if train_cfg.resume:
188200
resume_path = wrapper_torch.get_load_path(
189-
os.path.abspath("logs"),
201+
"/tmp/rslrl-training-logs/",
190202
load_run=train_cfg.load_run,
191203
checkpoint=train_cfg.checkpoint,
192204
)
@@ -206,7 +218,9 @@ def render_callback(_, state):
206218
policy = runner.get_inference_policy(device=device)
207219

208220
# Example: run a single rollout
209-
eval_env = registry.load(_ENV_NAME.value, config=env_cfg)
221+
eval_env = registry.load(
222+
_ENV_NAME.value, config=env_cfg, config_overrides={"impl": "jax"}
223+
)
210224
jit_reset = jax.jit(eval_env.reset)
211225
jit_step = jax.jit(eval_env.step)
212226

@@ -215,18 +229,25 @@ def render_callback(_, state):
215229
rollout = [state]
216230

217231
# We’ll assume your environment’s observation is in state.obs["state"].
218-
obs_torch = wrapper_torch._jax_to_torch(state.obs["state"])
232+
is_dict_obs = isinstance(eval_env.observation_size, dict)
233+
obs = state.obs["state"] if is_dict_obs else state.obs
234+
obs_torch = wrapper_torch._jax_to_torch(obs)
219235

220236
for _ in range(env_cfg.episode_length):
221237
with torch.no_grad():
222-
actions = policy(obs_torch)
238+
actions = policy({"state": obs_torch})
239+
actions = torch.clip(actions, -1.0, 1.0) # from wrapper_torch.py
223240
# Step environment
224241
state = jit_step(state, wrapper_torch._torch_to_jax(actions.flatten()))
225242
rollout.append(state)
226-
obs_torch = wrapper_torch._jax_to_torch(state.obs["state"])
243+
obs = state.obs["state"] if is_dict_obs else state.obs
244+
obs_torch = wrapper_torch._jax_to_torch(obs)
227245
if state.done:
228246
break
229247

248+
reward_sum = sum(s.reward for s in rollout)
249+
print(f"Rollout reward: {reward_sum}")
250+
230251
# Render
231252
scene_option = mujoco.MjvOption()
232253
scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = True

mujoco_playground/config/manipulation_params.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,18 +216,18 @@ def rsl_rl_config(env_name: str, unused_impl: Optional[str] = None) -> config_di
216216
value_loss_coef=1.0,
217217
use_clipped_value_loss=True,
218218
clip_param=0.2,
219-
entropy_coef=0.001,
220-
num_learning_epochs=5,
219+
entropy_coef=0.01,
220+
num_learning_epochs=4,
221221
# mini batch size = num_envs*nsteps / nminibatches
222-
num_mini_batches=4,
223-
learning_rate=3.0e-4, # 5.e-4
222+
num_mini_batches=8,
223+
learning_rate=1e-3,
224224
schedule="adaptive", # could be adaptive, fixed
225-
gamma=0.99,
225+
gamma=0.97,
226226
lam=0.95,
227227
desired_kl=0.01,
228228
max_grad_norm=1.0,
229229
),
230-
num_steps_per_env=24, # per iteration
230+
num_steps_per_env=40, # per iteration
231231
max_iterations=100000, # number of policy updates
232232
empirical_normalization=True,
233233
# logging

0 commit comments

Comments
 (0)