This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 32396a9

Merge pull request #701 from deepsense-ai/rl_model_experiment
RL model-based experiment
2 parents: 01369c9 + f7c28e8

20 files changed (+1076, -488 lines)

tensor2tensor/data_generators/gym.py

Lines changed: 175 additions & 150 deletions
@@ -24,76 +24,136 @@
 # Dependency imports
 
 import gym
-import numpy as np
+import os
+from tensorflow.contrib.training import HParams
+from collections import deque
 
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
-
 from tensor2tensor.models.research import rl
-from tensor2tensor.rl import rl_trainer_lib # pylint: disable=unused-import
-from tensor2tensor.rl.envs import atari_wrappers
-
-from tensor2tensor.utils import metrics
 from tensor2tensor.utils import registry
+from tensor2tensor.rl.envs.utils import batch_env_factory
+from tensor2tensor.rl.envs.tf_atari_wrappers import MemoryWrapper, TimeLimitWrapper
+from tensor2tensor.rl.envs.tf_atari_wrappers import MaxAndSkipWrapper
+from tensor2tensor.rl.envs.tf_atari_wrappers import PongT2TGeneratorHackWrapper
+from tensor2tensor.rl import collect
 
 import tensorflow as tf
 
 
+def moviepy_editor():
+  """Access to moviepy to allow for import of this file without a moviepy install."""
+  try:
+    from moviepy import editor  # pylint: disable=g-import-not-at-top
+  except ImportError:
+    raise ImportError("pip install moviepy to record videos")
+  return editor
+
 flags = tf.flags
 FLAGS = flags.FLAGS
 
-flags.DEFINE_string("model_path", "", "File with model for pong")
-
+flags.DEFINE_string("agent_policy_path", "", "File with model for pong")
 
+@registry.register_problem
 class GymDiscreteProblem(problem.Problem):
   """Gym environment with discrete actions and rewards."""
 
   def __init__(self, *args, **kwargs):
     super(GymDiscreteProblem, self).__init__(*args, **kwargs)
-    self._env = None
+    self.num_channels = 3
+    self.history_size = 2
+
+    # defaults
+    self.environment_spec = lambda: gym.make("PongNoFrameskip-v4")
+    self.in_graph_wrappers = [(MaxAndSkipWrapper, {"skip": 4})]
+    self.collect_hparams = rl.atari_base()
+    self.num_steps = 1000
+    self.movies = True
+    self.movies_fps = 24
+    self.simulated_environment = None
+    self.warm_up = 70
+
+  def _setup(self):
+    # TODO: remove PongT2TGeneratorHackWrapper by writing a modality
+
+    in_graph_wrappers = [(PongT2TGeneratorHackWrapper, {"add_value": 2}),
+                         (MemoryWrapper, {})] + self.in_graph_wrappers
+    env_hparams = HParams(in_graph_wrappers=in_graph_wrappers,
+                          simulated_environment=self.simulated_environment)
+
+    generator_batch_env = \
+      batch_env_factory(self.environment_spec, env_hparams, num_agents=1, xvfb=False)
+
+    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
+      policy_lambda = self.collect_hparams.network
+      policy_factory = tf.make_template(
+          "network",
+          functools.partial(policy_lambda, self.environment_spec().action_space, self.collect_hparams),
+          create_scope_now_=True,
+          unique_name_="network")
 
-  def example_reading_spec(self, label_repr=None):
+    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
+      sample_policy = lambda policy: 0*policy.sample()
+      # sample_policy = lambda policy: 0
 
+    self.collect_hparams.epoch_length = 10
+    _, self.collect_trigger_op = collect.define_collect(
+        policy_factory, generator_batch_env, self.collect_hparams,
+        eval_phase=False, policy_to_actions_lambda=sample_policy, scope="define_collect")
+
+    self.avilable_data_size_op = MemoryWrapper.singleton._speculum.size()
+    self.data_get_op = MemoryWrapper.singleton._speculum.dequeue()
+    self.history_buffer = deque(maxlen=self.history_size+1)
+
+  def example_reading_spec(self, label_repr=None):
     data_fields = {
-        "inputs": tf.FixedLenFeature([210, 160, 3], tf.int64),
-        "inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64),
-        "targets": tf.FixedLenFeature([210, 160, 3], tf.int64),
-        "action": tf.FixedLenFeature([1], tf.int64),
-        "reward": tf.FixedLenFeature([1], tf.int64)
+        "targets_encoded": tf.FixedLenFeature((), tf.string),
+        "image/format": tf.FixedLenFeature((), tf.string),
+        "action": tf.FixedLenFeature([1], tf.int64),
+        "reward": tf.FixedLenFeature([1], tf.int64),
+        # "done": tf.FixedLenFeature([1], tf.int64)
     }
 
-    return data_fields, None
+    for x in range(self.history_size):
+      data_fields["inputs_encoded_{}".format(x)] = tf.FixedLenFeature((), tf.string)
 
-  def eval_metrics(self):
-    return [metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ,
-            metrics.Metrics.NEG_LOG_PERPLEXITY, metrics.Metrics.IMAGE_SUMMARY]
 
-  @property
-  def env_name(self):
-    # This is the name of the Gym environment for this problem.
-    raise NotImplementedError()
+    data_items_to_decoders = {
+        "targets":
+            tf.contrib.slim.tfexample_decoder.Image(
+                image_key="targets_encoded",
+                format_key="image/format",
+                shape=[210, 160, 3],
+                channels=3),
 
-  @property
-  def env(self):
-    if self._env is None:
-      self._env = gym.make(self.env_name)
-    return self._env
+        #Just do a pass through
+        "action":tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="action"),
+        "reward":tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="reward"),
+    }
 
-  @property
-  def num_channels(self):
-    return 3
+    for x in range(self.history_size):
+      data_items_to_decoders["inputs_{}".format(x)] = tf.contrib.slim.tfexample_decoder.Image(
+          image_key="inputs_encoded_{}".format(x),
+          format_key="image/format",
+          shape=[210, 160, 3],
+          channels=3)
+
+    return data_fields, data_items_to_decoders
+
+  # def preprocess_example(self, example, mode, hparams):
+  #   if not self._was_reversed:
+  #     for x in range(self.history_size):
+  #       input_name = "inputs_{}".format(x)
+  #       example[input_name] = tf.image.per_image_standardization(example[input_name])
+  #   return example
 
   @property
   def num_actions(self):
-    raise NotImplementedError()
+    return 4
 
   @property
   def num_rewards(self):
-    raise NotImplementedError()
-
-  @property
-  def num_steps(self):
-    raise NotImplementedError()
+    return 2
 
   @property
   def num_shards(self):
@@ -108,35 +168,70 @@ def get_action(self, observation=None):
 
   def hparams(self, defaults, unused_model_hparams):
     p = defaults
-    p.input_modality = {"inputs": ("image", 256),
-                        "inputs_prev": ("image", 256),
-                        "reward": ("symbol", self.num_rewards),
-                        "action": ("symbol", self.num_actions)}
-    p.target_modality = ("image", 256)
+    # hard coded +1 after "symbol" refers to the fact
+    # that 0 is a special symbol meaning padding
+    # when symbols are e.g. 0, 1, 2, 3 we
+    # shift them to 0, 1, 2, 3, 4
+    p.input_modality = {"action": ("symbol:identity", self.num_actions)}
+
+    for x in range(self.history_size):
+      p.input_modality["inputs_{}".format(x)] = ("image", 256)
+
+    p.target_modality = {"targets": ("image", 256),
+                         "reward": ("symbol", self.num_rewards+1),
+                         # "done": ("symbol", 2+1)
+                         }
+
     p.input_space_id = problem.SpaceID.IMAGE
     p.target_space_id = problem.SpaceID.IMAGE
 
+  def restore_networks(self, sess):
+    model_saver = tf.train.Saver(
+        tf.global_variables(".*network_parameters.*"))
+    if FLAGS.agent_policy_path:
+      model_saver.restore(sess, FLAGS.agent_policy_path)
+
   def generator(self, data_dir, tmp_dir):
-    self.env.reset()
-    action = self.get_action()
-    prev_observation, observation = None, None
-    for _ in range(self.num_steps):
-      prev_prev_observation = prev_observation
-      prev_observation = observation
-      observation, reward, done, _ = self.env.step(action)
-      action = self.get_action(observation)
-      if done:
-        self.env.reset()
-      def flatten(nparray):
-        flat1 = [x for sublist in nparray.tolist() for x in sublist]
-        return [x for sublist in flat1 for x in sublist]
-      if prev_prev_observation is not None:
-        yield {"inputs_prev": flatten(prev_prev_observation),
-               "inputs": flatten(prev_observation),
-               "action": [action],
-               "done": [done],
-               "reward": [int(reward)],
-               "targets": flatten(observation)}
+    self._setup()
+    clip_files = []
+    with tf.Session() as sess:
+      sess.run(tf.global_variables_initializer())
+      self.restore_networks(sess)
+
+      pieces_generated = 0
+      while pieces_generated<self.num_steps + self.warm_up:
+        avilable_data_size = sess.run(self.avilable_data_size_op)
+        if avilable_data_size>0:
+          observ, reward, action, done = sess.run(self.data_get_op)
+          self.history_buffer.append(observ)
+
+          if self.movies==True and pieces_generated>self.warm_up:
+            file_name = os.path.join(tmp_dir,'output_{}.png'.format(pieces_generated))
+            clip_files.append(file_name)
+            with open(file_name, 'wb') as f:
+              f.write(observ)
+
+          if len(self.history_buffer)==self.history_size+1:
+            pieces_generated += 1
+            ret_dict = {
+                "targets_encoded": [observ],
+                "image/format": ["png"],
+                "action": [int(action)],
+                # "done": [bool(done)],
+                "reward": [int(reward)],
+            }
+            for i, v in enumerate(list(self.history_buffer)[:-1]):
+              ret_dict["inputs_encoded_{}".format(i)] = [v]
+            if pieces_generated>self.warm_up:
+              yield ret_dict
+        else:
+          sess.run(self.collect_trigger_op)
+    if self.movies:
+      # print(clip_files)
+      clip = moviepy_editor().ImageSequenceClip(clip_files, fps=self.movies_fps)
+      clip.write_videofile(os.path.join(data_dir, 'output_{}.mp4'.format(self.name)),
+                           fps=self.movies_fps, codec='mpeg4')
+
 
   def generate_data(self, data_dir, tmp_dir, task_id=-1):
     train_paths = self.training_filepaths(
@@ -150,93 +245,23 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
 
 
 @registry.register_problem
-class GymPongRandom5k(GymDiscreteProblem):
-  """Pong game, random actions."""
-
-  @property
-  def env_name(self):
-    return "PongDeterministic-v4"
-
-  @property
-  def num_actions(self):
-    return 4
-
-  @property
-  def num_rewards(self):
-    return 2
-
-  @property
-  def num_steps(self):
-    return 5000
-
-
-@registry.register_problem
-class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
-  """Pong game, loaded actions."""
+class GymSimulatedDiscreteProblem(GymDiscreteProblem):
+  """Simulated gym environment with discrete actions and rewards."""
 
   def __init__(self, *args, **kwargs):
-    super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
-    self._env = None
-    self._last_policy_op = None
-    self._max_frame_pl = None
-    self._last_action = self.env.action_space.sample()
-    self._skip = 4
-    self._skip_step = 0
-    self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
-                                dtype=np.uint8)
-
-  def generator(self, data_dir, tmp_dir):
-    env_spec = lambda: atari_wrappers.wrap_atari( # pylint: disable=g-long-lambda
-        gym.make(self.env_name),
-        warp=False,
-        frame_skip=4,
-        frame_stack=False)
-    hparams = rl.atari_base()
-    with tf.variable_scope("train", reuse=tf.AUTO_REUSE):
-      policy_lambda = hparams.network
-      policy_factory = tf.make_template(
-          "network",
-          functools.partial(policy_lambda, env_spec().action_space, hparams))
-      self._max_frame_pl = tf.placeholder(
-          tf.float32, self.env.observation_space.shape)
-      actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(
-          self._max_frame_pl, 0), 0))
-      policy = actor_critic.policy
-      self._last_policy_op = policy.mode()
-    with tf.Session() as sess:
-      model_saver = tf.train.Saver(
-          tf.global_variables(".*network_parameters.*"))
-      model_saver.restore(sess, FLAGS.model_path)
-      for item in super(GymPongTrajectoriesFromPolicy,
-                        self).generator(data_dir, tmp_dir):
-        yield item
-
-  # TODO(blazej0): For training of atari agents wrappers are usually used.
-  # Below we have a hacky solution which is a workaround to be used together
-  # with atari_wrappers.MaxAndSkipEnv.
-  def get_action(self, observation=None):
-    if self._skip_step == self._skip - 2: self._obs_buffer[0] = observation
-    if self._skip_step == self._skip - 1: self._obs_buffer[1] = observation
-    self._skip_step = (self._skip_step + 1) % self._skip
-    if self._skip_step == 0:
-      max_frame = self._obs_buffer.max(axis=0)
-      self._last_action = int(tf.get_default_session().run(
-          self._last_policy_op,
-          feed_dict={self._max_frame_pl: max_frame})[0, 0])
-    return self._last_action
-
-  @property
-  def env_name(self):
-    return "PongDeterministic-v4"
-
-  @property
-  def num_actions(self):
-    return 4
-
-  @property
-  def num_rewards(self):
-    return 2
-
-  @property
-  def num_steps(self):
-    return 5000
+    super(GymSimulatedDiscreteProblem, self).__init__(*args, **kwargs)
+    #TODO: pull it outside
+    self.in_graph_wrappers = [(TimeLimitWrapper, {"timelimit": 150}), (MaxAndSkipWrapper, {"skip": 4})]
+    self.simulated_environment = True
+    self.movies_fps = 2
+
+  def restore_networks(self, sess):
+    super(GymSimulatedDiscreteProblem, self).restore_networks(sess)
+
+    #TODO: adjust regexp for different models
+    env_model_loader = tf.train.Saver(tf.global_variables(".*basic_conv_gen.*"))
+    sess = tf.get_default_session()
+
+    ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
+    ckpt = ckpts.model_checkpoint_path
+    env_model_loader.restore(sess, ckpt)
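
Not part of the commit, but for orientation: a minimal sketch of how the newly registered problem could be driven end to end. The problem name (gym_discrete_problem, the snake_case form the t2t registry derives from the class name), the checkpoint path, and the data/tmp directories are assumptions, not values taken from the diff.

# collect_pong_rollouts.py -- hedged sketch, not from this commit.
# Hypothetical invocation:
#   python collect_pong_rollouts.py --agent_policy_path=/path/to/ppo_policy.ckpt
import tensorflow as tf

from tensor2tensor.data_generators import gym as gym_problems  # pylint: disable=unused-import
from tensor2tensor.utils import registry


def main(_):
  # generate_data() calls generator(), which builds the in-graph collect loop,
  # restores the agent policy if --agent_policy_path is set, and writes
  # TFRecord shards holding PNG-encoded frame history plus action and reward.
  gym_problem = registry.problem("gym_discrete_problem")
  gym_problem.generate_data(data_dir="/tmp/t2t_data", tmp_dir="/tmp/t2t_tmp")


if __name__ == "__main__":
  tf.app.run()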
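
Likewise a sketch, again an assumption rather than part of the diff, of reading one generated example back with the fields and decoders that the new example_reading_spec() returns; the shard glob pattern and the history_size == 2 item names are assumed.

# decode_example_sketch.py -- hedged sketch of the new on-disk format.
import tensorflow as tf

from tensor2tensor.data_generators import gym as gym_problems  # pylint: disable=unused-import
from tensor2tensor.utils import registry

gym_problem = registry.problem("gym_discrete_problem")
data_fields, data_items_to_decoders = gym_problem.example_reading_spec()

# Assumed shard naming: <problem name>-train-*; adjust to the actual files.
files = tf.gfile.Glob("/tmp/t2t_data/gym_discrete_problem-train*")
serialized = tf.data.TFRecordDataset(files).make_one_shot_iterator().get_next()

decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(
    data_fields, data_items_to_decoders)
items = ["inputs_0", "inputs_1", "targets", "action", "reward"]
inputs_0, inputs_1, targets, action, reward = decoder.decode(serialized, items)

with tf.Session() as sess:
  decoded = sess.run([inputs_0, inputs_1, targets, action, reward])
  # Two uint8 history frames and the target frame ([210, 160, 3] each),
  # plus the action and reward ids.
  print([getattr(t, "shape", t) for t in decoded])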

tensor2tensor/layers/modalities.py

Lines changed: 0 additions & 1 deletion
@@ -100,7 +100,6 @@ def bottom_simple(self, x, name, reuse):
         x = tf.squeeze(x, axis=3)
       while len(x.get_shape()) < 3:
         x = tf.expand_dims(x, axis=-1)
-
       var = self._get_weights()
       x = common_layers.dropout_no_scaling(
           x, 1.0 - self._model_hparams.symbol_dropout)

0 commit comments
