Skip to content

Commit ebc9fd0

Browse files
bridgesignbtaba
andauthored
VectorEnv gym visulization correction (#549)
* VectorEnv gym visulization correction This is a solution in context of #535 * Added tests. Allow width-height configuration * Reduce batch size * Reduce size of test * Change render to list pipeline states before conversion * Update gym_test.py * Update gym_test.py --------- Co-authored-by: btaba <btaba@google.com>
1 parent 01ca8ca commit ebc9fd0

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

brax/envs/wrappers/gym.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,12 @@ def step(self, action):
7777
def seed(self, seed: int = 0):
7878
self._key = jax.random.PRNGKey(seed)
7979

80-
def render(self, mode='human'):
80+
def render(self, mode='human', width=256, height=256):
8181
if mode == 'rgb_array':
8282
sys, state = self._env.sys, self._state
8383
if state is None:
8484
raise RuntimeError('must call reset or step before rendering')
85-
return image.render_array(sys, state.pipeline_state, 256, 256)
85+
return image.render_array(sys, state.pipeline_state, width=width, height=height)
8686
else:
8787
return super().render(mode=mode) # just raise an exception
8888

@@ -143,11 +143,12 @@ def step(self, action):
143143
def seed(self, seed: int = 0):
144144
self._key = jax.random.PRNGKey(seed)
145145

146-
def render(self, mode='human'):
146+
def render(self, mode='human', width=256, height=256):
147147
if mode == 'rgb_array':
148148
sys, state = self._env.sys, self._state
149149
if state is None:
150150
raise RuntimeError('must call reset or step before rendering')
151-
return image.render_array(sys, state.pipeline_state.take(0), 256, 256)
151+
state_list = [state.take(i).pipeline_state for i in range(self.num_envs)]
152+
return np.stack(image.render_array(sys, state_list, width=width, height=height))
152153
else:
153154
return super().render(mode=mode) # just raise an exception

brax/envs/wrappers/gym_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_vector_action_space(self):
4646
np.testing.assert_array_equal(
4747
env.action_space.high,
4848
np.tile(base_env.sys.actuator.ctrl_range[:, 1], [256, 1]),
49-
)
49+
)
5050

5151

5252
if __name__ == '__main__':

0 commit comments

Comments
 (0)