Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit f7c28e8

Browse files
author
Błażej O
committed
Cleaning up rl environments.
1 parent 45460e6 commit f7c28e8

File tree

3 files changed

+15
-151
lines changed

3 files changed

+15
-151
lines changed

tensor2tensor/rl/envs/atari_wrappers.py

Lines changed: 0 additions & 139 deletions
This file was deleted.

tensor2tensor/rl/envs/utils.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -281,21 +281,14 @@ def _worker(self, constructor, conn):
281281
conn.send((self._EXCEPTION, stacktrace))
282282
conn.close()
283283

284-
def batch_env_factory(environment_spec, hparams, num_agents, xvfb=False):
284+
def batch_env_factory(environment_lambda, hparams, num_agents, xvfb=False):
285285
# define env
286286
wrappers = hparams.in_graph_wrappers if hasattr(hparams, "in_graph_wrappers") else []
287287

288288
if hparams.simulated_environment:
289289
batch_env = define_simulated_batch_env(num_agents)
290290
else:
291-
if environment_spec == "stacked_pong":
292-
environment_spec = lambda: gym.make("PongNoFrameskip-v4")
293-
wrappers = [(tf_atari_wrappers.MaxAndSkipEnv, {"skip": 4})]
294-
if isinstance(environment_spec, str):
295-
env_lambda = lambda: gym.make(environment_spec)
296-
else:
297-
env_lambda = environment_spec
298-
batch_env = define_batch_env(env_lambda, num_agents, xvfb=xvfb) # TODO -video?
291+
batch_env = define_batch_env(environment_lambda, num_agents, xvfb=xvfb) # TODO -video?
299292
for w in wrappers:
300293
batch_env = w[0](batch_env, **w[1])
301294
return batch_env

tensor2tensor/rl/rl_trainer_lib.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
from tensor2tensor.models.research import rl # pylint: disable=unused-import
3030
from tensor2tensor.rl import collect
3131
from tensor2tensor.rl import ppo
32-
from tensor2tensor.rl.envs import atari_wrappers
3332
from tensor2tensor.rl.envs import utils
33+
from tensor2tensor.rl.envs import tf_atari_wrappers
34+
3435

3536
import tensorflow as tf
3637

@@ -41,7 +42,17 @@ def define_train(hparams, environment_spec, event_dir):
4142
"""Define the training setup."""
4243
policy_lambda = hparams.network
4344

44-
batch_env = utils.batch_env_factory(environment_spec, hparams, num_agents=hparams.num_agents)
45+
if environment_spec == "stacked_pong":
46+
environment_spec = lambda: gym.make("PongNoFrameskip-v4")
47+
wrappers = hparams.in_graph_wrappers if hasattr(hparams, "in_graph_wrappers") else []
48+
wrappers.append((tf_atari_wrappers.MaxAndSkipEnv, {"skip": 4}))
49+
hparams.in_graph_wrappers = wrappers
50+
if isinstance(environment_spec, str):
51+
env_lambda = lambda: gym.make(environment_spec)
52+
else:
53+
env_lambda = environment_spec
54+
55+
batch_env = utils.batch_env_factory(env_lambda, hparams, num_agents=hparams.num_agents)
4556

4657
policy_factory = tf.make_template(
4758
"network",
@@ -54,7 +65,6 @@ def define_train(hparams, environment_spec, event_dir):
5465
summary = tf.summary.merge([collect_summary, ppo_summary])
5566

5667
with tf.variable_scope("eval", reuse=tf.AUTO_REUSE):
57-
env_lambda = lambda: gym.make("PongNoFrameskip-v4")
5868
eval_env_lambda = env_lambda
5969
if event_dir and hparams.video_during_eval:
6070
# Some environments reset environments automatically, when reached done

0 commit comments

Comments (0)