177177     "warp_kernel_cache_dir", None,
178178     "Directory for caching compiled Warp kernels.",
179179 )
180+ _LOGDIR = flags.DEFINE_string(
181+     "logdir", None, "Directory for logging."
182+ )
180183
181184
182185def get_rl_config (env_name : str ) -> config_dict .ConfigDict :
@@ -277,16 +280,18 @@ def main(argv):
277280 ppo_params.network_factory.policy_obs_key = _POLICY_OBS_KEY.value
278281 if _VALUE_OBS_KEY.present:
279282   ppo_params.network_factory.value_obs_key = _VALUE_OBS_KEY.value
283+
280284 if _VISION.value:
281285   env_cfg.vision = True
286+   env_cfg.vision_config.nworld = ppo_params.num_envs
287+
282288 env_cfg_overrides = {}
283289 if _PLAYGROUND_CONFIG_OVERRIDES .value is not None :
284290 env_cfg_overrides = json .loads (_PLAYGROUND_CONFIG_OVERRIDES .value )
285291
286- if _VISION.value:
287-   env_cfg_overrides["vision"] = True
288-   env_cfg_overrides["vision_config.nworld"] = ppo_params.num_envs
289- env = registry.load(_ENV_NAME.value, config_overrides=env_cfg_overrides)
292+ env = registry.load(
293+     _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
294+ )
290295 if _RUN_EVALS .present :
291296 ppo_params .run_evals = _RUN_EVALS .value
292297 if _LOG_TRAINING_METRICS .present :
@@ -308,7 +313,7 @@ def main(argv):
308313 print (f"Experiment name: { exp_name } " )
309314
310315 # Set up logging directory
311- logdir = epath.Path("logs").resolve() / exp_name
316+ logdir = epath.Path(_LOGDIR.value or "logs").resolve() / exp_name
312317 logdir .mkdir (parents = True , exist_ok = True )
313318 print (f"Logs are being stored in: { logdir } " )
314319
@@ -417,11 +422,12 @@ def progress(num_steps, metrics):
417422 )
418423
419424 # Load evaluation environment.
420- eval_cfg_overrides = dict(env_cfg_overrides)
425+ eval_env_cfg = config_dict.ConfigDict(env_cfg)
421426 if _VISION.value:
422-   eval_cfg_overrides["vision"] = True
423-   eval_cfg_overrides["vision_config.nworld"] = num_eval_envs
424- eval_env = registry.load(_ENV_NAME.value, config_overrides=eval_cfg_overrides)
427+   eval_env_cfg.vision_config.nworld = num_eval_envs
428+ eval_env = registry.load(
429+     _ENV_NAME.value, config=eval_env_cfg, config_overrides=env_cfg_overrides
430+ )
425431
426432 policy_params_fn = lambda * args : None
427433 if _RSCOPE_ENVS .value :
@@ -475,9 +481,11 @@ def policy_params_fn(current_step, make_policy, params): # pylint: disable=unus
475481 jit_inference_fn = jax .jit (inference_fn )
476482
477483 # For inference rollouts, create env with vision disabled.
478- env_cfg.vision_config.nworld = _NUM_VIDEOS.value
484+ infer_env_cfg = config_dict.ConfigDict(env_cfg)
485+ if _VISION.value:
486+   infer_env_cfg.vision_config.nworld = _NUM_VIDEOS.value
479487 infer_env = registry.load(
480-     _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
488+     _ENV_NAME.value, config=infer_env_cfg, config_overrides=env_cfg_overrides
481489 )
482490
483491 # Run evaluation rollouts matching how training handles batched environments.
@@ -546,8 +554,8 @@ def do_rollout(state, rng):
546554 frames = infer_env .render (
547555 traj , height = 480 , width = 640 , scene_option = scene_option
548556 )
549- media.write_video(f"rollout{i}.mp4", frames, fps=fps)
550- print(f"Rollout video saved as 'rollout{i}.mp4'.")
557+ media.write_video(logdir / f"rollout{i}.mp4", frames, fps=fps)
558+ print(f"Rollout video saved as '{logdir}/rollout{i}.mp4'.")
551559
552560
553561def run ():
0 commit comments