@@ -69,44 +69,11 @@ def input_fn(mode, params):
69
69
},
70
70
}
71
71
72
- def decode_record (record ):
73
- """Serialized Example to dict of <feature name, Tensor>."""
74
- data_fields , _ = problem .example_reading_spec ()
75
- decoded = tf .parse_single_example (record , features = data_fields )
76
- decoded ["inputs" ] = decoded ["inputs" ].values
77
- decoded ["targets" ] = decoded ["targets" ].values
78
- return decoded
79
-
80
- data_files = tf .contrib .slim .parallel_reader .get_data_files (
81
- problem .filepattern (data_dir , mode ))
82
- dataset = tf .data .TFRecordDataset (data_files )
83
- dataset = dataset .map (decode_record , num_parallel_calls = num_threads )
84
-
85
- def _preprocess (example , problem , hparams , mode ):
86
- example = problem .preprocess_example (example , mode , hparams )
87
- # We do not want int64s as they are not supported on TPUs.
88
- example = data_reader .cast_int64_to_int32 (example )
89
- return example
90
-
91
- dataset = dataset .map (
92
- lambda ex : _preprocess (ex , problem , hparams , mode ),
93
- num_parallel_calls = num_threads )
94
-
95
72
def _valid_size (example ):
"""Filter predicate: delegates to data_reader.example_valid_size with the batching scheme's min_length and max_length."""
96
73
return data_reader .example_valid_size (
97
74
example , batching_scheme ["min_length" ], batching_scheme ["max_length" ])
98
75
99
- dataset = dataset .filter (_valid_size )
100
- # TODO(rsepassi): In eval mode, should not repeat
101
- dataset = dataset .repeat (None )
102
- dataset = data_reader .padded_batch (dataset , batch_size ,
103
- batching_scheme ["padded_shapes" ])
104
-
105
- if not is_training :
106
- dataset = dataset .map (
107
- lambda f : pad_batch (f , batch_size ), num_parallel_calls = num_threads )
108
-
109
- def shape_def (example ):
76
+ def define_shapes (example ):
110
77
"""Set the right shapes for the features."""
111
78
inputs = example ["inputs" ]
112
79
targets = example ["targets" ]
@@ -130,7 +97,22 @@ def shape_def(example):
130
97
131
98
return example
132
99
133
- dataset = dataset .map (shape_def , num_parallel_calls = num_threads )
100
+ dataset = problem .dataset (
101
+ mode = mode , data_dir = data_dir , num_threads = num_threads , hparams = hparams )
102
+ dataset = dataset .map (
103
+ data_reader .cast_int64_to_int32 , num_threads = num_threads )
104
+ dataset = dataset .filter (_valid_size )
105
+ if is_training :
106
+ dataset = dataset .shuffle (100 )
107
+ # TODO(rsepassi): In eval mode the dataset should not repeat. It repeats for
108
+ # now because the TPU seems to crash if it runs out of data during eval.
109
+ dataset = dataset .repeat (None )
110
+ dataset = data_reader .padded_batch (dataset , batch_size ,
111
+ batching_scheme ["padded_shapes" ])
112
+ if not is_training :
113
+ dataset = dataset .map (
114
+ lambda f : pad_batch (f , batch_size ), num_parallel_calls = num_threads )
115
+ dataset = dataset .map (define_shapes , num_parallel_calls = num_threads )
134
116
dataset = dataset .prefetch (1 )
135
117
features = dataset .make_one_shot_iterator ().get_next ()
136
118
0 commit comments