7272 f"Name of the environment. One of { ', ' .join (registry .ALL_ENVS )} " ,
7373)
7474_IMPL = flags .DEFINE_enum ("impl" , "jax" , ["jax" , "warp" ], "MJX implementation" )
75+ _PLAYGROUND_CONFIG_OVERRIDES = flags .DEFINE_string (
76+ "playground_config_overrides" ,
77+ None ,
78+ "Overrides for the playground env config." ,
79+ )
7580_VISION = flags .DEFINE_boolean ("vision" , False , "Use vision input" )
7681_LOAD_CHECKPOINT_PATH = flags .DEFINE_string (
7782 "load_checkpoint_path" , None , "Path to load checkpoint from"
@@ -264,7 +269,12 @@ def main(argv):
264269 if _VISION .value :
265270 env_cfg .vision = True
266271 env_cfg .vision_config .render_batch_size = ppo_params .num_envs
267- env = registry .load (_ENV_NAME .value , config = env_cfg )
272+ env_cfg_overrides = {}
273+ if _PLAYGROUND_CONFIG_OVERRIDES .value is not None :
274+ env_cfg_overrides = json .loads (_PLAYGROUND_CONFIG_OVERRIDES .value )
275+ env = registry .load (
276+ _ENV_NAME .value , config = env_cfg , config_overrides = env_cfg_overrides
277+ )
268278 if _RUN_EVALS .present :
269279 ppo_params .run_evals = _RUN_EVALS .value
270280 if _LOG_TRAINING_METRICS .present :
@@ -273,6 +283,8 @@ def main(argv):
273283 ppo_params .training_metrics_steps = _TRAINING_METRICS_STEPS .value
274284
275285 print (f"Environment Config:\n { env_cfg } " )
286+ if env_cfg_overrides :
287+ print (f"Environment Config Overrides:\n { env_cfg_overrides } \n " )
276288 print (f"PPO Training Parameters:\n { ppo_params } " )
277289
278290 # Generate unique experiment name
@@ -408,7 +420,9 @@ def progress(num_steps, metrics):
408420 # Load evaluation environment.
409421 eval_env = None
410422 if not _VISION .value :
411- eval_env = registry .load (_ENV_NAME .value , config = env_cfg )
423+ eval_env = registry .load (
424+ _ENV_NAME .value , config = env_cfg , config_overrides = env_cfg_overrides
425+ )
412426 num_envs = 1
413427 if _VISION .value :
414428 num_envs = env_cfg .vision_config .render_batch_size
@@ -419,7 +433,9 @@ def progress(num_steps, metrics):
419433 from rscope import brax as rscope_utils
420434
421435 if not _VISION .value :
422- rscope_env = registry .load (_ENV_NAME .value , config = env_cfg )
436+ rscope_env = registry .load (
437+ _ENV_NAME .value , config = env_cfg , config_overrides = env_cfg_overrides
438+ )
423439 rscope_env = wrapper .wrap_for_brax_training (
424440 rscope_env ,
425441 episode_length = ppo_params .episode_length ,
0 commit comments