Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 7058eda

Browse files
Lukasz KaiserRyan Sepassi
authored andcommitted
VideoProblem API with dynamic frame composition, video modality and making gym problem and one model work.
PiperOrigin-RevId: 192798025
1 parent 6c629ea commit 7058eda

File tree

10 files changed

+576
-81
lines changed

10 files changed

+576
-81
lines changed

tensor2tensor/data_generators/all_problems.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,45 +20,46 @@
2020

2121
import importlib
2222

23+
2324
modules = [
24-
'tensor2tensor.data_generators.algorithmic',
25-
'tensor2tensor.data_generators.algorithmic_math',
26-
'tensor2tensor.data_generators.audio',
27-
'tensor2tensor.data_generators.celeba',
28-
'tensor2tensor.data_generators.cifar',
29-
'tensor2tensor.data_generators.cipher',
30-
'tensor2tensor.data_generators.cnn_dailymail',
31-
'tensor2tensor.data_generators.desc2code',
32-
'tensor2tensor.data_generators.fsns',
33-
'tensor2tensor.data_generators.gene_expression',
34-
'tensor2tensor.data_generators.gym',
35-
'tensor2tensor.data_generators.ice_parsing',
36-
'tensor2tensor.data_generators.imagenet',
37-
'tensor2tensor.data_generators.imdb',
38-
'tensor2tensor.data_generators.librispeech',
39-
'tensor2tensor.data_generators.lm1b',
40-
'tensor2tensor.data_generators.mnist',
41-
'tensor2tensor.data_generators.mscoco',
42-
'tensor2tensor.data_generators.multinli',
43-
'tensor2tensor.data_generators.ocr',
44-
'tensor2tensor.data_generators.problem_hparams',
45-
'tensor2tensor.data_generators.ptb',
46-
'tensor2tensor.data_generators.snli',
47-
'tensor2tensor.data_generators.squad',
48-
'tensor2tensor.data_generators.translate_encs',
49-
'tensor2tensor.data_generators.translate_ende',
50-
'tensor2tensor.data_generators.translate_enfr',
51-
'tensor2tensor.data_generators.translate_enmk',
52-
'tensor2tensor.data_generators.translate_envi',
53-
'tensor2tensor.data_generators.translate_enzh',
54-
'tensor2tensor.data_generators.twentybn',
55-
'tensor2tensor.data_generators.wiki',
56-
'tensor2tensor.data_generators.wsj_parsing',
25+
"tensor2tensor.data_generators.algorithmic",
26+
"tensor2tensor.data_generators.algorithmic_math",
27+
"tensor2tensor.data_generators.audio",
28+
"tensor2tensor.data_generators.celeba",
29+
"tensor2tensor.data_generators.cifar",
30+
"tensor2tensor.data_generators.cipher",
31+
"tensor2tensor.data_generators.cnn_dailymail",
32+
"tensor2tensor.data_generators.desc2code",
33+
"tensor2tensor.data_generators.fsns",
34+
"tensor2tensor.data_generators.gene_expression",
35+
"tensor2tensor.data_generators.gym",
36+
"tensor2tensor.data_generators.ice_parsing",
37+
"tensor2tensor.data_generators.imagenet",
38+
"tensor2tensor.data_generators.imdb",
39+
"tensor2tensor.data_generators.librispeech",
40+
"tensor2tensor.data_generators.lm1b",
41+
"tensor2tensor.data_generators.mnist",
42+
"tensor2tensor.data_generators.mscoco",
43+
"tensor2tensor.data_generators.multinli",
44+
"tensor2tensor.data_generators.ocr",
45+
"tensor2tensor.data_generators.problem_hparams",
46+
"tensor2tensor.data_generators.ptb",
47+
"tensor2tensor.data_generators.snli",
48+
"tensor2tensor.data_generators.squad",
49+
"tensor2tensor.data_generators.translate_encs",
50+
"tensor2tensor.data_generators.translate_ende",
51+
"tensor2tensor.data_generators.translate_enfr",
52+
"tensor2tensor.data_generators.translate_enmk",
53+
"tensor2tensor.data_generators.translate_envi",
54+
"tensor2tensor.data_generators.translate_enzh",
55+
"tensor2tensor.data_generators.twentybn",
56+
"tensor2tensor.data_generators.wiki",
57+
"tensor2tensor.data_generators.wsj_parsing",
5758
]
5859

5960

6061
for module in modules:
6162
try:
6263
importlib.import_module(module)
63-
except ImportError:
64-
pass
64+
except ImportError as error:
65+
print("Did not import module: %s; Cause: %s" % (module, str(error)))

tensor2tensor/data_generators/gym.py

Lines changed: 140 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,152 @@
2828

2929
from tensor2tensor.data_generators import generator_utils
3030
from tensor2tensor.data_generators import problem
31+
from tensor2tensor.data_generators import video_utils
32+
3133
from tensor2tensor.models.research import rl
3234
from tensor2tensor.rl import collect
3335
from tensor2tensor.rl.envs import tf_atari_wrappers as atari
3436
from tensor2tensor.rl.envs.utils import batch_env_factory
37+
3538
from tensor2tensor.utils import registry
3639

3740
import tensorflow as tf
3841

39-
from tensorflow.contrib.training import HParams
40-
4142

4243
flags = tf.flags
4344
FLAGS = flags.FLAGS
4445

4546
flags.DEFINE_string("agent_policy_path", "", "File with model for pong")
4647

4748

49+
class GymDiscreteProblem(video_utils.VideoProblem):
50+
"""Gym environment with discrete actions and rewards."""
51+
52+
def __init__(self, *args, **kwargs):
53+
super(GymDiscreteProblem, self).__init__(*args, **kwargs)
54+
self._env = None
55+
56+
@property
57+
def num_input_frames(self):
58+
"""Number of frames to batch on one input."""
59+
return 2
60+
61+
@property
62+
def num_target_frames(self):
63+
"""Number of frames to batch on one target."""
64+
return 1
65+
66+
@property
67+
def extra_reading_spec(self):
68+
"""Additional data fields to store on disk and their decoders."""
69+
data_fields = {
70+
"action": tf.FixedLenFeature([1], tf.int64),
71+
"reward": tf.FixedLenFeature([1], tf.int64)
72+
}
73+
decoders = {
74+
"action": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="action"),
75+
"reward": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="reward"),
76+
}
77+
return data_fields, decoders
78+
79+
@property
80+
def is_generate_per_split(self):
81+
"""Whether we have a train/test split or just hold out data."""
82+
return False # Just hold out some generated data for evals.
83+
84+
@property
85+
def env_name(self):
86+
"""This is the name of the Gym environment for this problem."""
87+
raise NotImplementedError()
88+
89+
@property
90+
def env(self):
91+
if self._env is None:
92+
self._env = gym.make(self.env_name)
93+
return self._env
94+
95+
@property
96+
def num_actions(self):
97+
raise NotImplementedError()
98+
99+
@property
100+
def num_rewards(self):
101+
raise NotImplementedError()
102+
103+
@property
104+
def num_steps(self):
105+
raise NotImplementedError()
106+
107+
@property
108+
def min_reward(self):
109+
raise NotImplementedError()
110+
111+
def get_action(self, observation=None):
112+
return self.env.action_space.sample()
113+
114+
def hparams(self, defaults, unused_model_hparams):
115+
p = defaults
116+
p.input_modality = {"inputs": ("video", 256),
117+
"input_reward": ("symbol", self.num_rewards),
118+
"input_action": ("symbol", self.num_actions)}
119+
p.target_modality = ("video", 256)
120+
p.input_space_id = problem.SpaceID.IMAGE
121+
p.target_space_id = problem.SpaceID.IMAGE
122+
123+
def generate_samples(self, data_dir, tmp_dir, unused_dataset_split):
124+
self.env.reset()
125+
action = self.get_action()
126+
for _ in range(self.num_steps):
127+
observation, reward, done, _ = self.env.step(action)
128+
action = self.get_action(observation)
129+
yield {"frame": observation,
130+
"action": [action],
131+
"done": [done],
132+
"reward": [int(reward - self.min_reward)]}
133+
134+
135+
@registry.register_problem
136+
class GymPongRandom5k(GymDiscreteProblem):
137+
"""Pong game, random actions."""
138+
139+
@property
140+
def env_name(self):
141+
return "PongDeterministic-v4"
142+
143+
@property
144+
def frame_height(self):
145+
return 210
146+
147+
@property
148+
def frame_width(self):
149+
return 160
150+
151+
@property
152+
def num_actions(self):
153+
return 4
154+
155+
@property
156+
def min_reward(self):
157+
return -1
158+
159+
@property
160+
def num_rewards(self):
161+
return 3
162+
163+
@property
164+
def num_steps(self):
165+
return 5000
166+
167+
168+
@registry.register_problem
169+
class GymPongRandom50k(GymPongRandom5k):
170+
"""Pong game, random actions."""
171+
172+
@property
173+
def num_steps(self):
174+
return 50000
175+
176+
48177
def moviepy_editor():
49178
"""Access to moviepy that fails gracefully without a moviepy install."""
50179
try:
@@ -55,11 +184,11 @@ def moviepy_editor():
55184

56185

57186
@registry.register_problem
58-
class GymDiscreteProblem(problem.Problem):
187+
class GymDiscreteProblemWithAgent(problem.Problem):
59188
"""Gym environment with discrete actions and rewards."""
60189

61190
def __init__(self, *args, **kwargs):
62-
super(GymDiscreteProblem, self).__init__(*args, **kwargs)
191+
super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
63192
self.num_channels = 3
64193
self.history_size = 2
65194

@@ -68,16 +197,17 @@ def __init__(self, *args, **kwargs):
68197
self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})]
69198
self.collect_hparams = rl.atari_base()
70199
self.num_steps = 1000
71-
self.movies = True
200+
self.movies = False
72201
self.movies_fps = 24
73202
self.simulated_environment = None
74203
self.warm_up = 70
75204

76205
def _setup(self):
77206
in_graph_wrappers = [(atari.ShiftRewardWrapper, {"add_value": 2}),
78207
(atari.MemoryWrapper, {})] + self.in_graph_wrappers
79-
env_hparams = HParams(in_graph_wrappers=in_graph_wrappers,
80-
simulated_environment=self.simulated_environment)
208+
env_hparams = tf.contrib.training.HParams(
209+
in_graph_wrappers=in_graph_wrappers,
210+
simulated_environment=self.simulated_environment)
81211

82212
generator_batch_env = batch_env_factory(
83213
self.environment_spec, env_hparams, num_agents=1, xvfb=False)
@@ -234,19 +364,19 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
234364

235365

236366
@registry.register_problem
237-
class GymSimulatedDiscreteProblem(GymDiscreteProblem):
367+
class GymSimulatedDiscreteProblemWithAgent(GymDiscreteProblemWithAgent):
238368
"""Simulated gym environment with discrete actions and rewards."""
239369

240370
def __init__(self, *args, **kwargs):
241-
super(GymSimulatedDiscreteProblem, self).__init__(*args, **kwargs)
371+
super(GymSimulatedDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
242372
# TODO(lukaszkaiser): pull it outside
243373
self.in_graph_wrappers = [(atari.TimeLimitWrapper, {"timelimit": 150}),
244374
(atari.MaxAndSkipWrapper, {"skip": 4})]
245375
self.simulated_environment = True
246376
self.movies_fps = 2
247377

248378
def restore_networks(self, sess):
249-
super(GymSimulatedDiscreteProblem, self).restore_networks(sess)
379+
super(GymSimulatedDiscreteProblemWithAgent, self).restore_networks(sess)
250380

251381
# TODO(lukaszkaiser): adjust regexp for different models
252382
env_model_loader = tf.train.Saver(tf.global_variables(".*basic_conv_gen.*"))

tensor2tensor/data_generators/image_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
158158
self.dev_filepaths(data_dir, self.dev_shards, shuffled=False))
159159

160160

161-
def _encoded_images(images):
161+
def encode_images_as_png(images):
162162
if context.in_eager_mode():
163163
for image in images:
164164
yield tf.image.encode_png(image).numpy()
@@ -195,7 +195,7 @@ def image_generator(images, labels):
195195
if not images:
196196
raise ValueError("Must provide some images for the generator.")
197197
width, height, _ = images[0].shape
198-
for (enc_image, label) in zip(_encoded_images(images), labels):
198+
for (enc_image, label) in zip(encode_images_as_png(images), labels):
199199
yield {
200200
"image/encoded": [enc_image],
201201
"image/format": ["png"],

tensor2tensor/data_generators/problem.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -781,9 +781,7 @@ def define_shapes(example):
781781
batch_size_means_tokens = False
782782
else:
783783
tf.logging.warning(
784-
"Shapes are not fully defined. Assuming batch_size means tokens. "
785-
"Override batch_size_means_tokens() "
786-
"in your problem subclass if this is undesired behavior.")
784+
"Shapes are not fully defined. Assuming batch_size means tokens.")
787785
batch_size_means_tokens = True
788786

789787
# Batching

0 commit comments

Comments
 (0)