# Dependency imports

import gym
- import numpy as np
+ import os
+ from tensorflow.contrib.training import HParams
+ from collections import deque

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
-
from tensor2tensor.models.research import rl
- from tensor2tensor.rl import rl_trainer_lib  # pylint: disable=unused-import
- from tensor2tensor.rl.envs import atari_wrappers
-
- from tensor2tensor.utils import metrics
from tensor2tensor.utils import registry
+ from tensor2tensor.rl.envs.utils import batch_env_factory
+ from tensor2tensor.rl.envs.tf_atari_wrappers import MemoryWrapper, TimeLimitWrapper
+ from tensor2tensor.rl.envs.tf_atari_wrappers import MaxAndSkipWrapper
+ from tensor2tensor.rl.envs.tf_atari_wrappers import PongT2TGeneratorHackWrapper
+ from tensor2tensor.rl import collect

import tensorflow as tf


+ def moviepy_editor():
+   """Access to moviepy to allow for import of this file without a moviepy install."""
+   try:
+     from moviepy import editor  # pylint: disable=g-import-not-at-top
+   except ImportError:
+     raise ImportError("pip install moviepy to record videos")
+   return editor
+
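+ # Note: moviepy_editor() is only used at the end of generator() below, to stitch
+ # the dumped PNG frames into an .mp4 via ImageSequenceClip; the lazy import keeps
+ # moviepy an optional dependency for everything else in this file.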
flags = tf.flags
FLAGS = flags.FLAGS

- flags.DEFINE_string("model_path", "", "File with model for pong")
-
+ flags.DEFINE_string("agent_policy_path", "", "File with model for pong")
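+ # Presumed usage (not shown in this commit): with the problem registered below,
+ # data generation would go through the usual datagen entry point, e.g.
+ #   t2t-datagen --problem=gym_discrete_problem \
+ #     --data_dir=$DATA_DIR --tmp_dir=$TMP_DIR --agent_policy_path=$POLICY_CKPT
+ # assuming the registry derives the snake_case name gym_discrete_problem from
+ # the GymDiscreteProblem class name.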

+ @registry.register_problem
class GymDiscreteProblem(problem.Problem):
  """Gym environment with discrete actions and rewards."""

  def __init__(self, *args, **kwargs):
    super(GymDiscreteProblem, self).__init__(*args, **kwargs)
-     self._env = None
+     self.num_channels = 3
+     self.history_size = 2
+
+     # defaults
+     self.environment_spec = lambda: gym.make("PongNoFrameskip-v4")
+     self.in_graph_wrappers = [(MaxAndSkipWrapper, {"skip": 4})]
+     self.collect_hparams = rl.atari_base()
+     self.num_steps = 1000
+     self.movies = True
+     self.movies_fps = 24
+     self.simulated_environment = None
+     self.warm_up = 70
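+
+     # These defaults are consumed by _setup() and generator() below:
+     # environment_spec builds the raw gym environment, in_graph_wrappers are
+     # applied inside the TF graph by batch_env_factory, num_steps is the number
+     # of examples to emit, warm_up is how many initial transitions are collected
+     # but not yielded, and movies/movies_fps control dumping PNG frames and
+     # stitching them into an .mp4.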
+
+   def _setup(self):
+     # TODO: remove PongT2TGeneratorHackWrapper by writing a modality.
+
+     in_graph_wrappers = [(PongT2TGeneratorHackWrapper, {"add_value": 2}),
+                          (MemoryWrapper, {})] + self.in_graph_wrappers
+     env_hparams = HParams(in_graph_wrappers=in_graph_wrappers,
+                           simulated_environment=self.simulated_environment)
+
+     generator_batch_env = \
+       batch_env_factory(self.environment_spec, env_hparams, num_agents=1, xvfb=False)
+
+     with tf.variable_scope("", reuse=tf.AUTO_REUSE):
+       policy_lambda = self.collect_hparams.network
+       policy_factory = tf.make_template(
+           "network",
+           functools.partial(policy_lambda, self.environment_spec().action_space,
+                             self.collect_hparams),
+           create_scope_now_=True,
+           unique_name_="network")

-   def example_reading_spec(self, label_repr=None):
+     with tf.variable_scope("", reuse=tf.AUTO_REUSE):
+       sample_policy = lambda policy: 0 * policy.sample()
+       # sample_policy = lambda policy: 0

+       self.collect_hparams.epoch_length = 10
+       _, self.collect_trigger_op = collect.define_collect(
+           policy_factory, generator_batch_env, self.collect_hparams,
+           eval_phase=False, policy_to_actions_lambda=sample_policy,
+           scope="define_collect")
+
+     self.available_data_size_op = MemoryWrapper.singleton._speculum.size()
+     self.data_get_op = MemoryWrapper.singleton._speculum.dequeue()
+     self.history_buffer = deque(maxlen=self.history_size + 1)
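+
+     # At this point the TF graph holds: a batched single-agent environment wrapped
+     # with MemoryWrapper, a reusable policy template, a collect op that steps the
+     # environment (with actions forced to zero by sample_policy above), and the
+     # size/dequeue handles on MemoryWrapper's queue that generator() reads from.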
+
+   def example_reading_spec(self, label_repr=None):
    data_fields = {
-         "inputs": tf.FixedLenFeature([210, 160, 3], tf.int64),
-         "inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64),
-         "targets": tf.FixedLenFeature([210, 160, 3], tf.int64),
-         "action": tf.FixedLenFeature([1], tf.int64),
-         "reward": tf.FixedLenFeature([1], tf.int64)
+         "targets_encoded": tf.FixedLenFeature((), tf.string),
+         "image/format": tf.FixedLenFeature((), tf.string),
+         "action": tf.FixedLenFeature([1], tf.int64),
+         "reward": tf.FixedLenFeature([1], tf.int64),
+         # "done": tf.FixedLenFeature([1], tf.int64)
    }

-     return data_fields, None
+     for x in range(self.history_size):
+       data_fields["inputs_encoded_{}".format(x)] = tf.FixedLenFeature((), tf.string)

-   def eval_metrics(self):
-     return [metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ,
-             metrics.Metrics.NEG_LOG_PERPLEXITY, metrics.Metrics.IMAGE_SUMMARY]

-   @property
-   def env_name(self):
-     # This is the name of the Gym environment for this problem.
-     raise NotImplementedError()
+     data_items_to_decoders = {
+         "targets":
+             tf.contrib.slim.tfexample_decoder.Image(
+                 image_key="targets_encoded",
+                 format_key="image/format",
+                 shape=[210, 160, 3],
+                 channels=3),

-   @property
-   def env(self):
-     if self._env is None:
-       self._env = gym.make(self.env_name)
-     return self._env
+         # Just do a pass-through.
+         "action": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="action"),
+         "reward": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="reward"),
+     }

-   @property
-   def num_channels(self):
-     return 3
+     for x in range(self.history_size):
+       data_items_to_decoders["inputs_{}".format(x)] = tf.contrib.slim.tfexample_decoder.Image(
+           image_key="inputs_encoded_{}".format(x),
+           format_key="image/format",
+           shape=[210, 160, 3],
+           channels=3)
+
+     return data_fields, data_items_to_decoders
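+
+     # Shape note: with the default history_size = 2, a decoded example is expected
+     # to carry "inputs_0", "inputs_1" and "targets" as [210, 160, 3] frames
+     # (PNG-encoded on disk, decoded by the Image decoders above) together with
+     # length-1 int64 "action" and "reward" features.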
+
+   # def preprocess_example(self, example, mode, hparams):
+   #   if not self._was_reversed:
+   #     for x in range(self.history_size):
+   #       input_name = "inputs_{}".format(x)
+   #       example[input_name] = tf.image.per_image_standardization(example[input_name])
+   #   return example

  @property
  def num_actions(self):
-     raise NotImplementedError()
+     return 4

  @property
  def num_rewards(self):
-     raise NotImplementedError()
-
-   @property
-   def num_steps(self):
-     raise NotImplementedError()
+     return 2

  @property
  def num_shards(self):
@@ -108,35 +168,70 @@ def get_action(self, observation=None):

  def hparams(self, defaults, unused_model_hparams):
    p = defaults
-     p.input_modality = {"inputs": ("image", 256),
-                         "inputs_prev": ("image", 256),
-                         "reward": ("symbol", self.num_rewards),
-                         "action": ("symbol", self.num_actions)}
-     p.target_modality = ("image", 256)
+     # The hard-coded +1 after "symbol" refers to the fact that 0 is a special
+     # symbol meaning padding; when the symbols are e.g. 0, 1, 2, 3 we shift
+     # them to 1, 2, 3, 4.
+     p.input_modality = {"action": ("symbol:identity", self.num_actions)}
+
+     for x in range(self.history_size):
+       p.input_modality["inputs_{}".format(x)] = ("image", 256)
+
+     p.target_modality = {"targets": ("image", 256),
+                          "reward": ("symbol", self.num_rewards + 1),
+                          # "done": ("symbol", 2 + 1)
+                         }
+
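+     # Net effect: the model consumes the stacked input frames plus the chosen
+     # action, and is trained to predict both the next frame ("targets") and the
+     # reward; a "done" prediction is left commented out for now.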
    p.input_space_id = problem.SpaceID.IMAGE
    p.target_space_id = problem.SpaceID.IMAGE

+   def restore_networks(self, sess):
+     model_saver = tf.train.Saver(
+         tf.global_variables(".*network_parameters.*"))
+     if FLAGS.agent_policy_path:
+       model_saver.restore(sess, FLAGS.agent_policy_path)
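+     # If --agent_policy_path is empty the policy keeps its random initialization;
+     # generation still runs because sample_policy in _setup() multiplies the
+     # sampled actions by zero.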
+
  def generator(self, data_dir, tmp_dir):
-     self.env.reset()
-     action = self.get_action()
-     prev_observation, observation = None, None
-     for _ in range(self.num_steps):
-       prev_prev_observation = prev_observation
-       prev_observation = observation
-       observation, reward, done, _ = self.env.step(action)
-       action = self.get_action(observation)
-       if done:
-         self.env.reset()
-       def flatten(nparray):
-         flat1 = [x for sublist in nparray.tolist() for x in sublist]
-         return [x for sublist in flat1 for x in sublist]
-       if prev_prev_observation is not None:
-         yield {"inputs_prev": flatten(prev_prev_observation),
-                "inputs": flatten(prev_observation),
-                "action": [action],
-                "done": [done],
-                "reward": [int(reward)],
-                "targets": flatten(observation)}
+     self._setup()
+     clip_files = []
+     with tf.Session() as sess:
+       sess.run(tf.global_variables_initializer())
+       self.restore_networks(sess)
+
+       pieces_generated = 0
+       while pieces_generated < self.num_steps + self.warm_up:
+         available_data_size = sess.run(self.available_data_size_op)
+         if available_data_size > 0:
+           observ, reward, action, done = sess.run(self.data_get_op)
+           self.history_buffer.append(observ)
+
+           if self.movies and pieces_generated > self.warm_up:
+             file_name = os.path.join(tmp_dir, 'output_{}.png'.format(pieces_generated))
+             clip_files.append(file_name)
+             with open(file_name, 'wb') as f:
+               f.write(observ)
+
+           if len(self.history_buffer) == self.history_size + 1:
+             pieces_generated += 1
+             ret_dict = {
+                 "targets_encoded": [observ],
+                 "image/format": ["png"],
+                 "action": [int(action)],
+                 # "done": [bool(done)],
+                 "reward": [int(reward)],
+             }
+             for i, v in enumerate(list(self.history_buffer)[:-1]):
+               ret_dict["inputs_encoded_{}".format(i)] = [v]
+             if pieces_generated > self.warm_up:
+               yield ret_dict
+         else:
+           sess.run(self.collect_trigger_op)
+       if self.movies:
+         # print(clip_files)
+         clip = moviepy_editor().ImageSequenceClip(clip_files, fps=self.movies_fps)
+         clip.write_videofile(os.path.join(data_dir, 'output_{}.mp4'.format(self.name)),
+                              fps=self.movies_fps, codec='mpeg4')
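+       # Loop summary: transitions (frame, reward, action, done) are dequeued from
+       # MemoryWrapper's in-graph queue, and the collect op is run whenever the
+       # queue is empty. A rolling buffer of history_size + 1 frames pairs
+       # history_size input frames with the next frame as the target, and the
+       # first warm_up pieces are collected but not yielded.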
+

  def generate_data(self, data_dir, tmp_dir, task_id=-1):
    train_paths = self.training_filepaths(
@@ -150,93 +245,23 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):


@registry.register_problem
- class GymPongRandom5k(GymDiscreteProblem):
-   """Pong game, random actions."""
-
-   @property
-   def env_name(self):
-     return "PongDeterministic-v4"
-
-   @property
-   def num_actions(self):
-     return 4
-
-   @property
-   def num_rewards(self):
-     return 2
-
-   @property
-   def num_steps(self):
-     return 5000
-
-
- @registry.register_problem
- class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
-   """Pong game, loaded actions."""
+ class GymSimulatedDiscreteProblem(GymDiscreteProblem):
+   """Simulated gym environment with discrete actions and rewards."""

  def __init__(self, *args, **kwargs):
-     super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
-     self._env = None
-     self._last_policy_op = None
-     self._max_frame_pl = None
-     self._last_action = self.env.action_space.sample()
-     self._skip = 4
-     self._skip_step = 0
-     self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
-                                 dtype=np.uint8)
-
-   def generator(self, data_dir, tmp_dir):
-     env_spec = lambda: atari_wrappers.wrap_atari(  # pylint: disable=g-long-lambda
-         gym.make(self.env_name),
-         warp=False,
-         frame_skip=4,
-         frame_stack=False)
-     hparams = rl.atari_base()
-     with tf.variable_scope("train", reuse=tf.AUTO_REUSE):
-       policy_lambda = hparams.network
-       policy_factory = tf.make_template(
-           "network",
-           functools.partial(policy_lambda, env_spec().action_space, hparams))
-       self._max_frame_pl = tf.placeholder(
-           tf.float32, self.env.observation_space.shape)
-       actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(
-           self._max_frame_pl, 0), 0))
-       policy = actor_critic.policy
-       self._last_policy_op = policy.mode()
-     with tf.Session() as sess:
-       model_saver = tf.train.Saver(
-           tf.global_variables(".*network_parameters.*"))
-       model_saver.restore(sess, FLAGS.model_path)
-       for item in super(GymPongTrajectoriesFromPolicy,
-                         self).generator(data_dir, tmp_dir):
-         yield item
-
-   # TODO(blazej0): For training of atari agents wrappers are usually used.
-   # Below we have a hacky solution which is a workaround to be used together
-   # with atari_wrappers.MaxAndSkipEnv.
-   def get_action(self, observation=None):
-     if self._skip_step == self._skip - 2: self._obs_buffer[0] = observation
-     if self._skip_step == self._skip - 1: self._obs_buffer[1] = observation
-     self._skip_step = (self._skip_step + 1) % self._skip
-     if self._skip_step == 0:
-       max_frame = self._obs_buffer.max(axis=0)
-       self._last_action = int(tf.get_default_session().run(
-           self._last_policy_op,
-           feed_dict={self._max_frame_pl: max_frame})[0, 0])
-     return self._last_action
-
-   @property
-   def env_name(self):
-     return "PongDeterministic-v4"
-
-   @property
-   def num_actions(self):
-     return 4
-
-   @property
-   def num_rewards(self):
-     return 2
-
-   @property
-   def num_steps(self):
-     return 5000
+     super(GymSimulatedDiscreteProblem, self).__init__(*args, **kwargs)
+     # TODO: pull it outside.
+     self.in_graph_wrappers = [(TimeLimitWrapper, {"timelimit": 150}),
+                               (MaxAndSkipWrapper, {"skip": 4})]
+     self.simulated_environment = True
+     self.movies_fps = 2
+
+   def restore_networks(self, sess):
+     super(GymSimulatedDiscreteProblem, self).restore_networks(sess)
+
+     # TODO: adjust regexp for different models.
+     env_model_loader = tf.train.Saver(tf.global_variables(".*basic_conv_gen.*"))
+     sess = tf.get_default_session()
+
+     ckpts = tf.train.get_checkpoint_state(FLAGS.output_dir)
+     ckpt = ckpts.model_checkpoint_path
+     env_model_loader.restore(sess, ckpt)
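+     # In the simulated setting two sets of variables are restored: the agent
+     # policy (via the parent class, from --agent_policy_path) and the learned
+     # environment model (the ".*basic_conv_gen.*" variables) from the newest
+     # checkpoint under FLAGS.output_dir.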