Skip to content

Commit 610557c

Browse files
Merge pull request #611 from AndruGomes13:dr_fix
PiperOrigin-RevId: 795062674 Change-Id: Ib32da6ef0c4e703b5f06a02fc05841dc0a518aad
2 parents 4d0cbe2 + 10017d6 commit 610557c

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

brax/envs/wrappers/training.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# pylint:disable=g-multiple-import, g-importing-member
1616
"""Wrappers to support Brax training."""
1717

18+
import contextlib
1819
from typing import Callable, Dict, Optional, Tuple
1920

2021
from brax.base import System
@@ -230,23 +231,28 @@ def __init__(
230231
super().__init__(env)
231232
self._sys_v, self._in_axes = randomization_fn(self.sys)
232233

233-
def _env_fn(self, sys: System) -> Env:
234-
env = self.env
235-
env.unwrapped.sys = sys
236-
return env
234+
@contextlib.contextmanager
235+
def v_env_fn(self, new_sys: System):
236+
env = self.env.unwrapped
237+
old_sys = env.sys # pytype: disable=attribute-error
238+
try:
239+
env.sys = new_sys
240+
yield env
241+
finally:
242+
env.unwrapped.sys = old_sys
237243

238244
def reset(self, rng: jax.Array) -> State:
239245
def reset(sys, rng):
240-
env = self._env_fn(sys=sys)
241-
return env.reset(rng)
246+
with self.v_env_fn(sys) as v_env:
247+
return v_env.reset(rng)
242248

243249
state = jax.vmap(reset, in_axes=[self._in_axes, 0])(self._sys_v, rng)
244250
return state
245251

246252
def step(self, state: State, action: jax.Array) -> State:
247253
def step(sys, s, a):
248-
env = self._env_fn(sys=sys)
249-
return env.step(s, a)
254+
with self.v_env_fn(sys) as v_env:
255+
return v_env.step(s, a)
250256

251257
res = jax.vmap(step, in_axes=[self._in_axes, 0, 0])(
252258
self._sys_v, state, action

0 commit comments

Comments
 (0)