Skip to content

Commit 5a1eaa0

Browse files
btaba and copybara-github
authored and committed
Fix bug with env_cfg in train_jax_ppo.
PiperOrigin-RevId: 881051403 Change-Id: Ifab6a9684de21e4780478c9be21624efc03a8455
1 parent e604cac commit 5a1eaa0

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

learning/train_jax_ppo.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -177,6 +177,9 @@
177177
"warp_kernel_cache_dir", None,
178178
"Directory for caching compiled Warp kernels.",
179179
)
180+
_LOGDIR = flags.DEFINE_string(
181+
"logdir", None, "Directory for logging."
182+
)
180183

181184

182185
def get_rl_config(env_name: str) -> config_dict.ConfigDict:
@@ -277,16 +280,18 @@ def main(argv):
277280
ppo_params.network_factory.policy_obs_key = _POLICY_OBS_KEY.value
278281
if _VALUE_OBS_KEY.present:
279282
ppo_params.network_factory.value_obs_key = _VALUE_OBS_KEY.value
283+
280284
if _VISION.value:
281285
env_cfg.vision = True
286+
env_cfg.vision_config.nworld = ppo_params.num_envs
287+
282288
env_cfg_overrides = {}
283289
if _PLAYGROUND_CONFIG_OVERRIDES.value is not None:
284290
env_cfg_overrides = json.loads(_PLAYGROUND_CONFIG_OVERRIDES.value)
285291

286-
if _VISION.value:
287-
env_cfg_overrides["vision"] = True
288-
env_cfg_overrides["vision_config.nworld"] = ppo_params.num_envs
289-
env = registry.load(_ENV_NAME.value, config_overrides=env_cfg_overrides)
292+
env = registry.load(
293+
_ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
294+
)
290295
if _RUN_EVALS.present:
291296
ppo_params.run_evals = _RUN_EVALS.value
292297
if _LOG_TRAINING_METRICS.present:
@@ -308,7 +313,7 @@ def main(argv):
308313
print(f"Experiment name: {exp_name}")
309314

310315
# Set up logging directory
311-
logdir = epath.Path("logs").resolve() / exp_name
316+
logdir = epath.Path(_LOGDIR.value or "logs").resolve() / exp_name
312317
logdir.mkdir(parents=True, exist_ok=True)
313318
print(f"Logs are being stored in: {logdir}")
314319

@@ -417,11 +422,12 @@ def progress(num_steps, metrics):
417422
)
418423

419424
# Load evaluation environment.
420-
eval_cfg_overrides = dict(env_cfg_overrides)
425+
eval_env_cfg = config_dict.ConfigDict(env_cfg)
421426
if _VISION.value:
422-
eval_cfg_overrides["vision"] = True
423-
eval_cfg_overrides["vision_config.nworld"] = num_eval_envs
424-
eval_env = registry.load(_ENV_NAME.value, config_overrides=eval_cfg_overrides)
427+
eval_env_cfg.vision_config.nworld = num_eval_envs
428+
eval_env = registry.load(
429+
_ENV_NAME.value, config=eval_env_cfg, config_overrides=env_cfg_overrides
430+
)
425431

426432
policy_params_fn = lambda *args: None
427433
if _RSCOPE_ENVS.value:
@@ -475,9 +481,11 @@ def policy_params_fn(current_step, make_policy, params): # pylint: disable=unus
475481
jit_inference_fn = jax.jit(inference_fn)
476482

477483
# For inference rollouts, create env with vision disabled.
478-
env_cfg.vision_config.nworld = _NUM_VIDEOS.value
484+
infer_env_cfg = config_dict.ConfigDict(env_cfg)
485+
if _VISION.value:
486+
infer_env_cfg.vision_config.nworld = _NUM_VIDEOS.value
479487
infer_env = registry.load(
480-
_ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
488+
_ENV_NAME.value, config=infer_env_cfg, config_overrides=env_cfg_overrides
481489
)
482490

483491
# Run evaluation rollouts matching how training handles batched environments.
@@ -546,8 +554,8 @@ def do_rollout(state, rng):
546554
frames = infer_env.render(
547555
traj, height=480, width=640, scene_option=scene_option
548556
)
549-
media.write_video(f"rollout{i}.mp4", frames, fps=fps)
550-
print(f"Rollout video saved as 'rollout{i}.mp4'.")
557+
media.write_video(logdir / f"rollout{i}.mp4", frames, fps=fps)
558+
print(f"Rollout video saved as '{logdir}/rollout{i}.mp4'.")
551559

552560

553561
def run():

0 commit comments

Comments
 (0)