|
15 | 15 | # pylint:disable=g-multiple-import, g-importing-member |
16 | 16 | """Wrappers to support Brax training.""" |
17 | 17 |
|
| 18 | +import contextlib |
18 | 19 | from typing import Callable, Dict, Optional, Tuple |
19 | 20 |
|
20 | 21 | from brax.base import System |
@@ -230,23 +231,28 @@ def __init__( |
230 | 231 | super().__init__(env) |
231 | 232 | self._sys_v, self._in_axes = randomization_fn(self.sys) |
232 | 233 |
|
233 | | - def _env_fn(self, sys: System) -> Env: |
234 | | - env = self.env |
235 | | - env.unwrapped.sys = sys |
236 | | - return env |
| 234 | + @contextlib.contextmanager |
| 235 | + def v_env_fn(self, new_sys: System): |
| 236 | + env = self.env.unwrapped |
| 237 | + old_sys = env.sys # pytype: disable=attribute-error |
| 238 | + try: |
| 239 | + env.sys = new_sys |
| 240 | + yield env |
| 241 | + finally: |
| 242 | + env.unwrapped.sys = old_sys |
237 | 243 |
|
238 | 244 | def reset(self, rng: jax.Array) -> State: |
239 | 245 | def reset(sys, rng): |
240 | | - env = self._env_fn(sys=sys) |
241 | | - return env.reset(rng) |
| 246 | + with self.v_env_fn(sys) as v_env: |
| 247 | + return v_env.reset(rng) |
242 | 248 |
|
243 | 249 | state = jax.vmap(reset, in_axes=[self._in_axes, 0])(self._sys_v, rng) |
244 | 250 | return state |
245 | 251 |
|
246 | 252 | def step(self, state: State, action: jax.Array) -> State: |
247 | 253 | def step(sys, s, a): |
248 | | - env = self._env_fn(sys=sys) |
249 | | - return env.step(s, a) |
| 254 | + with self.v_env_fn(sys) as v_env: |
| 255 | + return v_env.step(s, a) |
250 | 256 |
|
251 | 257 | res = jax.vmap(step, in_axes=[self._in_axes, 0, 0])( |
252 | 258 | self._sys_v, state, action |
|
0 commit comments