1414# ==============================================================================
1515"""Train a PPO agent using JAX on the specified environment."""
1616
17- from datetime import datetime
17+ import datetime
1818import functools
1919import json
2020import os
2828from brax .training .agents .ppo import networks_vision as ppo_networks_vision
2929from brax .training .agents .ppo import train as ppo
3030from etils import epath
31- from flax .training import orbax_utils
3231import jax
3332import jax .numpy as jp
3433import mediapy as media
3534from ml_collections import config_dict
3635import mujoco
37- from orbax import checkpoint as ocp
38- from tensorboardX import SummaryWriter
39- import wandb
40-
4136import mujoco_playground
4237from mujoco_playground import registry
4338from mujoco_playground import wrapper
4439from mujoco_playground .config import dm_control_suite_params
4540from mujoco_playground .config import locomotion_params
4641from mujoco_playground .config import manipulation_params
42+ import tensorboardX
43+ import wandb
44+
4745
4846xla_flags = os .environ .get ("XLA_FLAGS" , "" )
4947xla_flags += " --xla_gpu_triton_gemm_any=True"
6967 "LeapCubeReorient" ,
7068 f"Name of the environment. One of { ', ' .join (registry .ALL_ENVS )} " ,
7169)
70+ _IMPL = flags .DEFINE_enum ("impl" , "jax" , ["jax" , "warp" ], "MJX implementation" )
7271_VISION = flags .DEFINE_boolean ("vision" , False , "Use vision input" )
7372_LOAD_CHECKPOINT_PATH = flags .DEFINE_string (
7473 "load_checkpoint_path" , None , "Path to load checkpoint from"
9291_NUM_TIMESTEPS = flags .DEFINE_integer (
9392 "num_timesteps" , 1_000_000 , "Number of timesteps"
9493)
94+ _NUM_VIDEOS = flags .DEFINE_integer (
95+ "num_videos" , 1 , "Number of videos to record after training."
96+ )
9597_NUM_EVALS = flags .DEFINE_integer ("num_evals" , 5 , "Number of evaluations" )
9698_REWARD_SCALING = flags .DEFINE_float ("reward_scaling" , 0.1 , "Reward scaling" )
9799_EPISODE_LENGTH = flags .DEFINE_integer ("episode_length" , 1000 , "Episode length" )
@@ -269,7 +271,7 @@ def main(argv):
269271 print (f"PPO Training Parameters:\n { ppo_params } " )
270272
271273 # Generate unique experiment name
272- now = datetime .now ()
274+ now = datetime .datetime . now ()
273275 timestamp = now .strftime ("%Y%m%d-%H%M%S" )
274276 exp_name = f"{ _ENV_NAME .value } -{ timestamp } "
275277 if _SUFFIX .value is not None :
@@ -289,7 +291,7 @@ def main(argv):
289291
290292 # Initialize TensorBoard if required
291293 if _USE_TB .value and not _PLAY_ONLY .value :
292- writer = SummaryWriter (logdir )
294+ writer = tensorboardX . SummaryWriter (logdir )
293295
294296 # Handle checkpoint loading
295297 if _LOAD_CHECKPOINT_PATH .value is not None :
@@ -393,18 +395,26 @@ def progress(num_steps, metrics):
393395 f" reward={ metrics ['episode/sum_reward' ]:.3f} "
394396 )
395397
396- # Load evaluation environment
397- eval_env = (
398- None if _VISION .value else registry .load (_ENV_NAME .value , config = env_cfg )
399- )
398+ # Load evaluation environment.
399+ config_overrides = {"impl" : _IMPL .value }
400+ eval_env = None
401+ if not _VISION .value :
402+ eval_env = registry .load (
403+ _ENV_NAME .value , config = env_cfg , config_overrides = config_overrides
404+ )
405+ num_envs = 1
406+ if _VISION .value :
407+ num_envs = env_cfg .vision_config .render_batch_size
400408
401409 policy_params_fn = lambda * args : None
402410 if _RSCOPE_ENVS .value :
403411 # Interactive visualisation of policy checkpoints
404412 from rscope import brax as rscope_utils
405413
406414 if not _VISION .value :
407- rscope_env = registry .load (_ENV_NAME .value , config = env_cfg )
415+ rscope_env = registry .load (
416+ _ENV_NAME .value , config = env_cfg , config_overrides = config_overrides
417+ )
408418 rscope_env = wrapper .wrap_for_brax_training (
409419 rscope_env ,
410420 episode_length = ppo_params .episode_length ,
@@ -433,7 +443,7 @@ def policy_params_fn(current_step, make_policy, params): # pylint: disable=unus
433443 environment = env ,
434444 progress_fn = progress ,
435445 policy_params_fn = policy_params_fn ,
436- eval_env = None if _VISION . value else eval_env ,
446+ eval_env = eval_env ,
437447 )
438448
439449 print ("Done training." )
@@ -443,63 +453,69 @@ def policy_params_fn(current_step, make_policy, params): # pylint: disable=unus
443453
444454 print ("Starting inference..." )
445455
446- # Create inference function
456+ # Create inference function.
447457 inference_fn = make_inference_fn (params , deterministic = True )
448458 jit_inference_fn = jax .jit (inference_fn )
449459
450- # Prepare for evaluation
451- eval_env = (
452- None if _VISION .value else registry .load (_ENV_NAME .value , config = env_cfg )
453- )
454- num_envs = 1
455- if _VISION .value :
456- eval_env = env
457- num_envs = env_cfg .vision_config .render_batch_size
458-
459- jit_reset = jax .jit (eval_env .reset )
460- jit_step = jax .jit (eval_env .step )
461-
462- rng = jax .random .PRNGKey (123 )
463- rng , reset_rng = jax .random .split (rng )
464- if _VISION .value :
465- reset_rng = jp .asarray (jax .random .split (reset_rng , num_envs ))
466- state = jit_reset (reset_rng )
467- state0 = (
468- jax .tree_util .tree_map (lambda x : x [0 ], state ) if _VISION .value else state
469- )
470- rollout = [state0 ]
471-
472- # Run evaluation rollout
473- for _ in range (env_cfg .episode_length ):
474- act_rng , rng = jax .random .split (rng )
475- ctrl , _ = jit_inference_fn (state .obs , act_rng )
476- state = jit_step (state , ctrl )
477- state0 = (
478- jax .tree_util .tree_map (lambda x : x [0 ], state )
479- if _VISION .value
480- else state
460+ # Run evaluation rollouts.
461+ def do_rollout (rng , state ):
462+ empty_data = state .data .__class__ (
463+ ** {k : None for k in state .data .__annotations__ }
464+ ) # pytype: disable=attribute-error
465+ empty_traj = state .__class__ (** {k : None for k in state .__annotations__ }) # pytype: disable=attribute-error
466+ empty_traj = empty_traj .replace (data = empty_data )
467+
468+ def step (carry , _ ):
469+ state , rng = carry
470+ rng , act_key = jax .random .split (rng )
471+ act = jit_inference_fn (state .obs , act_key )[0 ]
472+ state = eval_env .step (state , act )
473+ traj_data = empty_traj .tree_replace ({
474+ "data.qpos" : state .data .qpos ,
475+ "data.qvel" : state .data .qvel ,
476+ "data.time" : state .data .time ,
477+ "data.ctrl" : state .data .ctrl ,
478+ "data.mocap_pos" : state .data .mocap_pos ,
479+ "data.mocap_quat" : state .data .mocap_quat ,
480+ "data.xfrc_applied" : state .data .xfrc_applied ,
481+ })
482+ if _VISION .value :
483+ traj_data = jax .tree_util .tree_map (lambda x : x [0 ], traj_data )
484+ return (state , rng ), traj_data
485+
486+ _ , traj = jax .lax .scan (
487+ step , (state , rng ), None , length = _EPISODE_LENGTH .value
481488 )
482- rollout .append (state0 )
483- if state0 .done :
484- break
489+ return traj
485490
486- # Render and save the rollout
491+ rng = jax .random .split (jax .random .PRNGKey (_SEED .value ), _NUM_VIDEOS .value )
492+ reset_states = jax .jit (jax .vmap (eval_env .reset ))(rng )
493+ if _VISION .value :
494+ reset_states = jax .tree_util .tree_map (lambda x : x [0 ], reset_states )
495+ traj_stacked = jax .jit (jax .vmap (do_rollout ))(rng , reset_states )
496+ trajectories = [None ] * _NUM_VIDEOS .value
497+ for i in range (_NUM_VIDEOS .value ):
498+ t = jax .tree .map (lambda x , i = i : x [i ], traj_stacked )
499+ trajectories [i ] = [
500+ jax .tree .map (lambda x , j = j : x [j ], t )
501+ for j in range (_EPISODE_LENGTH .value )
502+ ]
503+
504+ # Render and save the rollout.
487505 render_every = 2
488506 fps = 1.0 / eval_env .dt / render_every
489507 print (f"FPS for rendering: { fps } " )
490-
491- traj = rollout [::render_every ]
492-
493508 scene_option = mujoco .MjvOption ()
494509 scene_option .flags [mujoco .mjtVisFlag .mjVIS_TRANSPARENT ] = False
495510 scene_option .flags [mujoco .mjtVisFlag .mjVIS_PERTFORCE ] = False
496511 scene_option .flags [mujoco .mjtVisFlag .mjVIS_CONTACTFORCE ] = False
497-
498- frames = eval_env .render (
499- traj , height = 480 , width = 640 , scene_option = scene_option
500- )
501- media .write_video ("rollout.mp4" , frames , fps = fps )
502- print ("Rollout video saved as 'rollout.mp4'." )
512+ for i , rollout in enumerate (trajectories ):
513+ traj = rollout [::render_every ]
514+ frames = eval_env .render (
515+ traj , height = 480 , width = 640 , scene_option = scene_option
516+ )
517+ media .write_video (f"rollout{ i } .mp4" , frames , fps = fps )
518+ print (f"Rollout video saved as 'rollout{ i } .mp4'." )
503519
504520
505521if __name__ == "__main__" :
0 commit comments