 
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
+from tensor2tensor.data_generators import video_utils
+
 from tensor2tensor.models.research import rl
 from tensor2tensor.rl import collect
 from tensor2tensor.rl.envs import tf_atari_wrappers as atari
 from tensor2tensor.rl.envs.utils import batch_env_factory
+
 from tensor2tensor.utils import registry
 
 import tensorflow as tf
 
-from tensorflow.contrib.training import HParams
-
 
 flags = tf.flags
 FLAGS = flags.FLAGS
 
 flags.DEFINE_string("agent_policy_path", "", "File with model for pong")
 
 
+class GymDiscreteProblem(video_utils.VideoProblem):
+  """Gym environment with discrete actions and rewards."""
+
+  def __init__(self, *args, **kwargs):
+    super(GymDiscreteProblem, self).__init__(*args, **kwargs)
+    self._env = None
+
+  @property
+  def num_input_frames(self):
+    """Number of frames to batch on one input."""
+    return 2
+
+  @property
+  def num_target_frames(self):
+    """Number of frames to batch on one target."""
+    return 1
+
+  @property
+  def extra_reading_spec(self):
+    """Additional data fields to store on disk and their decoders."""
+    data_fields = {
+        "action": tf.FixedLenFeature([1], tf.int64),
+        "reward": tf.FixedLenFeature([1], tf.int64)
+    }
+    decoders = {
+        "action": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="action"),
+        "reward": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="reward"),
+    }
+    return data_fields, decoders
+
+  @property
+  def is_generate_per_split(self):
+    """Whether we have a train/test split or just hold out data."""
+    return False  # Just hold out some generated data for evals.
+
+  @property
+  def env_name(self):
+    """This is the name of the Gym environment for this problem."""
+    raise NotImplementedError()
+
+  @property
+  def env(self):
+    if self._env is None:
+      self._env = gym.make(self.env_name)
+    return self._env
+
+  @property
+  def num_actions(self):
+    raise NotImplementedError()
+
+  @property
+  def num_rewards(self):
+    raise NotImplementedError()
+
+  @property
+  def num_steps(self):
+    raise NotImplementedError()
+
+  @property
+  def min_reward(self):
+    raise NotImplementedError()
+
+  def get_action(self, observation=None):
+    return self.env.action_space.sample()
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    p.input_modality = {"inputs": ("video", 256),
+                        "input_reward": ("symbol", self.num_rewards),
+                        "input_action": ("symbol", self.num_actions)}
+    p.target_modality = ("video", 256)
+    p.input_space_id = problem.SpaceID.IMAGE
+    p.target_space_id = problem.SpaceID.IMAGE
+
+  def generate_samples(self, data_dir, tmp_dir, unused_dataset_split):
+    self.env.reset()
+    action = self.get_action()
+    for _ in range(self.num_steps):
+      observation, reward, done, _ = self.env.step(action)
+      action = self.get_action(observation)
+      yield {"frame": observation,
+             "action": [action],
+             "done": [done],
+             "reward": [int(reward - self.min_reward)]}
+
+
+@registry.register_problem
+class GymPongRandom5k(GymDiscreteProblem):
+  """Pong game, random actions."""
+
+  @property
+  def env_name(self):
+    return "PongDeterministic-v4"
+
+  @property
+  def frame_height(self):
+    return 210
+
+  @property
+  def frame_width(self):
+    return 160
+
+  @property
+  def num_actions(self):
+    return 4
+
+  @property
+  def min_reward(self):
+    return -1
+
+  @property
+  def num_rewards(self):
+    return 3
+
+  @property
+  def num_steps(self):
+    return 5000
+
+
+@registry.register_problem
+class GymPongRandom50k(GymPongRandom5k):
+  """Pong game, random actions."""
+
+  @property
+  def num_steps(self):
+    return 50000
+
+
 def moviepy_editor():
   """Access to moviepy that fails gracefully without a moviepy install."""
   try:
@@ -55,11 +184,11 @@ def moviepy_editor():
 
 
 @registry.register_problem
-class GymDiscreteProblem(problem.Problem):
+class GymDiscreteProblemWithAgent(problem.Problem):
   """Gym environment with discrete actions and rewards."""
 
   def __init__(self, *args, **kwargs):
-    super(GymDiscreteProblem, self).__init__(*args, **kwargs)
+    super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
     self.num_channels = 3
     self.history_size = 2
 
@@ -68,16 +197,17 @@ def __init__(self, *args, **kwargs):
     self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})]
     self.collect_hparams = rl.atari_base()
     self.num_steps = 1000
-    self.movies = True
+    self.movies = False
     self.movies_fps = 24
     self.simulated_environment = None
     self.warm_up = 70
 
   def _setup(self):
     in_graph_wrappers = [(atari.ShiftRewardWrapper, {"add_value": 2}),
                          (atari.MemoryWrapper, {})] + self.in_graph_wrappers
-    env_hparams = HParams(in_graph_wrappers=in_graph_wrappers,
-                          simulated_environment=self.simulated_environment)
+    env_hparams = tf.contrib.training.HParams(
+        in_graph_wrappers=in_graph_wrappers,
+        simulated_environment=self.simulated_environment)
 
     generator_batch_env = batch_env_factory(
         self.environment_spec, env_hparams, num_agents=1, xvfb=False)
@@ -234,19 +364,19 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
 
 
 @registry.register_problem
-class GymSimulatedDiscreteProblem(GymDiscreteProblem):
+class GymSimulatedDiscreteProblemWithAgent(GymDiscreteProblemWithAgent):
   """Simulated gym environment with discrete actions and rewards."""
 
   def __init__(self, *args, **kwargs):
-    super(GymSimulatedDiscreteProblem, self).__init__(*args, **kwargs)
+    super(GymSimulatedDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
     # TODO(lukaszkaiser): pull it outside
     self.in_graph_wrappers = [(atari.TimeLimitWrapper, {"timelimit": 150}),
                               (atari.MaxAndSkipWrapper, {"skip": 4})]
     self.simulated_environment = True
     self.movies_fps = 2
 
   def restore_networks(self, sess):
-    super(GymSimulatedDiscreteProblem, self).restore_networks(sess)
+    super(GymSimulatedDiscreteProblemWithAgent, self).restore_networks(sess)
 
     # TODO(lukaszkaiser): adjust regexp for different models
     env_model_loader = tf.train.Saver(tf.global_variables(".*basic_conv_gen.*"))
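
For reference, one quick way to exercise the newly added random-action problems is to pull a few samples straight from the rollout generator. The snippet below is a minimal sketch, not part of this commit; it assumes GymPongRandom5k from this file is importable and that gym's Atari Pong environment is installed (the data_dir/tmp_dir arguments are unused by the generator).

import itertools

# Sketch only: sample a handful of frames from the random-action Pong rollout.
pong = GymPongRandom5k()
for sample in itertools.islice(pong.generate_samples(None, None, None), 5):
  # Each sample holds a raw frame plus the sampled action and shifted reward.
  print(sample["frame"].shape, sample["action"], sample["reward"])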