1515# pylint: disable=wrong-import-position
1616"""Train a PPO agent using RSL-RL for the specified environment."""
1717
18- import os
19-
20- xla_flags = os .environ .get ("XLA_FLAGS" , "" )
21- xla_flags += " --xla_gpu_triton_gemm_any=True"
22- os .environ ["XLA_FLAGS" ] = xla_flags
23- os .environ ["MUJOCO_GL" ] = "egl"
24-
2518from datetime import datetime
2619import json
20+ import os
2721
2822from absl import app
2923from absl import flags
3226import mediapy as media
3327from ml_collections import config_dict
3428import mujoco
35- from rsl_rl .runners import OnPolicyRunner
36- import torch
37- import wandb
38-
3929import mujoco_playground
4030from mujoco_playground import registry
4131from mujoco_playground import wrapper_torch
4232from mujoco_playground .config import locomotion_params
4333from mujoco_playground .config import manipulation_params
34+ from rsl_rl .runners import OnPolicyRunner
35+ import torch
36+ import warp as wp
37+
38+ try :
39+ import wandb # pylint: disable=g-import-not-at-top
40+ except ImportError :
41+ wandb = None
42+
43+ xla_flags = os .environ .get ("XLA_FLAGS" , "" )
44+ xla_flags += " --xla_gpu_triton_gemm_any=True"
45+ os .environ ["XLA_FLAGS" ] = xla_flags
46+ os .environ ["MUJOCO_GL" ] = "egl"
4447
4548# Suppress logs if you want
4649logging .set_verbosity (logging .WARNING )
7881_CAMERA = flags .DEFINE_string (
7982 "camera" , None , "Camera name to use for rendering."
8083)
84+ _WP_KERNEL_CACHE_DIR = flags .DEFINE_string (
85+ "wp_kernel_cache_dir" ,
86+ "/tmp/wp_kernel_cache_playground" ,
87+ "Path to the WP kernel cache directory." ,
88+ )
8189
8290
8391def get_rl_config (env_name : str ) -> config_dict .ConfigDict :
@@ -93,6 +101,8 @@ def main(argv):
93101 """Run training and evaluation for the specified environment using RSL-RL."""
94102 del argv # unused
95103
104+ wp .config .kernel_cache_dir = _WP_KERNEL_CACHE_DIR .value
105+
96106 # Possibly parse the device for multi-GPU
97107 if _MULTI_GPU .value :
98108 local_rank = int (os .environ .get ("LOCAL_RANK" , "0" ))
@@ -119,7 +129,7 @@ def main(argv):
119129 print(f"Experiment name: {exp_name}")
120130
121131 # Logging directory
122- logdir = os .path .abspath (os .path .join ("logs" , exp_name ))
122- logdir = os.path.abspath(os.path.join("logs", exp_name))
132+ logdir = os.path.abspath(os.path.join("/tmp/rslrl-training-logs/", exp_name))
123133 os .makedirs (logdir , exist_ok = True )
124134 print(f"Logs are being stored in: {logdir}")
125135
@@ -129,7 +139,7 @@ def main(argv):
129139 print(f"Checkpoint path: {ckpt_path}")
130140
131141 # Initialize Weights & Biases if required
132- if _USE_WANDB .value and not _PLAY_ONLY .value :
142+ if _USE_WANDB .value and not _PLAY_ONLY .value and wandb is not None :
133143 wandb .tensorboard .patch (root_logdir = logdir )
134144 wandb .init (project = "mjxrl" , name = exp_name )
135145 wandb .config .update (env_cfg .to_dict ())
@@ -152,7 +162,9 @@ def render_callback(_, state):
152162 render_trajectory .append (state )
153163
154164 # Create the environment
155- raw_env = registry .load (_ENV_NAME .value , config = env_cfg )
165+ raw_env = registry .load (
166+ _ENV_NAME .value , config = env_cfg , config_overrides = {"impl" : "jax" }
167+ )
156168 brax_env = wrapper_torch .RSLRLBraxWrapper (
157169 raw_env ,
158170 num_envs ,
@@ -186,7 +198,7 @@ def render_callback(_, state):
186198 # If resume, load from checkpoint
187199 if train_cfg .resume :
188200 resume_path = wrapper_torch .get_load_path (
189- os . path . abspath ( " logs" ) ,
201+ "/tmp/rslrl-training-logs/",
190202 load_run = train_cfg .load_run ,
191203 checkpoint = train_cfg .checkpoint ,
192204 )
@@ -206,7 +218,9 @@ def render_callback(_, state):
206218 policy = runner .get_inference_policy (device = device )
207219
208220 # Example: run a single rollout
209- eval_env = registry .load (_ENV_NAME .value , config = env_cfg )
221+ eval_env = registry .load (
222+ _ENV_NAME .value , config = env_cfg , config_overrides = {"impl" : "jax" }
223+ )
210224 jit_reset = jax .jit (eval_env .reset )
211225 jit_step = jax .jit (eval_env .step )
212226
@@ -215,18 +229,25 @@ def render_callback(_, state):
215229 rollout = [state ]
216230
217231 # We’ll assume your environment’s observation is in state.obs["state"].
218- obs_torch = wrapper_torch ._jax_to_torch (state .obs ["state" ])
232+ is_dict_obs = isinstance (eval_env .observation_size , dict )
233+ obs = state .obs ["state" ] if is_dict_obs else state .obs
234+ obs_torch = wrapper_torch ._jax_to_torch (obs )
219235
220236 for _ in range (env_cfg .episode_length ):
221237 with torch .no_grad ():
222- actions = policy (obs_torch )
238+ actions = policy ({"state" : obs_torch })
239+ actions = torch .clip (actions , - 1.0 , 1.0 ) # from wrapper_torch.py
223240 # Step environment
224241 state = jit_step (state , wrapper_torch ._torch_to_jax (actions .flatten ()))
225242 rollout .append (state )
226- obs_torch = wrapper_torch ._jax_to_torch (state .obs ["state" ])
243+ obs = state .obs ["state" ] if is_dict_obs else state .obs
244+ obs_torch = wrapper_torch ._jax_to_torch (obs )
227245 if state .done :
228246 break
229247
248+ reward_sum = sum (s .reward for s in rollout )
249+ print(f"Rollout reward: {reward_sum}")
250+
230251 # Render
231252 scene_option = mujoco .MjvOption ()
232253 scene_option .flags [mujoco .mjtVisFlag .mjVIS_TRANSPARENT ] = True
0 commit comments