Expose the pre-reset observation of each step as info['obs_st'].

training.py: `reset` stores the initial observation under info['obs_st'];
`step` stores the observation returned by `self.env.step` under the same key,
captured before `obs` is rebuilt from `first_obs` through `where_done`.
training_test.py: adds `test_autoreset_termination`, which compares
info['obs_st'] with `obs` on every step for the "ant" and "halfcheetah" envs.

Review notes:
* Line breaks and indentation were reconstructed (2-space indents, blank
  context lines inferred from the hunk line counts); confirm against the
  target tree before applying.
* The `index` lines are omitted because the post-image blob ids of the
  original patch do not describe the revised added lines.

diff --git a/brax/envs/wrappers/training.py b/brax/envs/wrappers/training.py
--- a/brax/envs/wrappers/training.py
+++ b/brax/envs/wrappers/training.py
@@ -133,6 +133,8 @@ def reset(self, rng: jax.Array) -> State:
     state = self.env.reset(rng)
     state.info['first_pipeline_state'] = state.pipeline_state
     state.info['first_obs'] = state.obs
+    # `step` writes this key too; set it here so it exists from the start.
+    state.info['obs_st'] = state.obs
     return state
 
   def step(self, state: State, action: jax.Array) -> State:
@@ -143,6 +145,10 @@ def step(self, state: State, action: jax.Array) -> State:
     state = state.replace(done=jp.zeros_like(state.done))
     state = self.env.step(state, action)
 
+    # Capture this step's observation before `obs` is rebuilt below from
+    # `first_obs` via `where_done`.
+    obs_st = state.obs
+
     def where_done(x, y):
       done = state.done
       if done.shape and done.shape[0] != x.shape[0]:
@@ -155,6 +161,7 @@ def where_done(x, y):
         where_done, state.info['first_pipeline_state'], state.pipeline_state
     )
     obs = jax.tree.map(where_done, state.info['first_obs'], state.obs)
+    state.info['obs_st'] = obs_st
     return state.replace(pipeline_state=pipeline_state, obs=obs)
 
 
diff --git a/brax/envs/wrappers/training_test.py b/brax/envs/wrappers/training_test.py
--- a/brax/envs/wrappers/training_test.py
+++ b/brax/envs/wrappers/training_test.py
@@ -26,6 +26,41 @@
 
 class TrainingTest(absltest.TestCase):
 
+  def test_autoreset_termination(self):
+    for env_id in ["ant", "halfcheetah"]:
+      with self.subTest(env_id=env_id):
+        self._run_termination(env_id)
+
+  def _run_termination(self, env_id):
+    """Steps past env.episode_length and checks info["obs_st"] each step."""
+    env = envs.create(env_id)
+    key = jax.random.PRNGKey(42)
+    max_steps_in_episode = env.episode_length
+
+    state = jax.jit(env.reset)(key)
+    action = jp.zeros(env.sys.act_size())
+
+    env_step_fn = jax.jit(env.step)
+
+    def step_fn(state, _):
+      next_state = env_step_fn(state, action)
+      return next_state, (next_state.obs, next_state.done, next_state.info)
+
+    _, (observations, dones, infos) = jax.lax.scan(
+        f=step_fn, init=state, xs=None, length=max_steps_in_episode + 1
+    )
+
+    observations_step = infos["obs_st"]
+    # Should have at least finished once
+    self.assertGreaterEqual(float(jp.sum(dones)), 1)
+    for obs, done, obs_st in zip(observations, dones, observations_step):
+      if done:
+        # Ensure we stored the last obs from the finished episode,
+        # which differs from the first obs of the new episode
+        self.assertFalse(jp.array_equal(obs_st, obs))
+      else:
+        self.assertTrue(jp.array_equal(obs_st, obs))
+
   def test_domain_randomization_wrapper(self):
     def rand(sys, rng):
       @jax.vmap