@@ -66,6 +66,8 @@ flags.DEFINE_integer("num_shards", 0, "How many shards to use. Ignored for "
66
66
"registered Problems." )
67
67
flags .DEFINE_integer ("max_cases" , 0 ,
68
68
"Maximum number of cases to generate (unbounded if 0)." )
69
+ flags .DEFINE_bool ("only_list" , False ,
70
+ "If true, we only list the problems that will be generated." )
69
71
flags .DEFINE_integer ("random_seed" , 429459 , "Random seed to use." )
70
72
flags .DEFINE_integer ("task_id" , - 1 , "For distributed data generation." )
71
73
flags .DEFINE_string ("t2t_usr_dir" , "" ,
@@ -81,33 +83,33 @@ _SUPPORTED_PROBLEM_GENERATORS = {
81
83
"algorithmic_algebra_inverse" : (
82
84
lambda : algorithmic_math .algebra_inverse (26 , 0 , 2 , 100000 ),
83
85
lambda : algorithmic_math .algebra_inverse (26 , 3 , 3 , 10000 )),
84
- "wmt_parsing_tokens_8k " : (
86
+ "parsing_english_ptb8k " : (
85
87
lambda : wmt .parsing_token_generator (
86
88
FLAGS .data_dir , FLAGS .tmp_dir , True , 2 ** 13 ),
87
89
lambda : wmt .parsing_token_generator (
88
90
FLAGS .data_dir , FLAGS .tmp_dir , False , 2 ** 13 )),
89
- "wsj_parsing_tokens_16k " : (
91
+ "parsing_english_ptb16k " : (
90
92
lambda : wsj_parsing .parsing_token_generator (
91
93
FLAGS .data_dir , FLAGS .tmp_dir , True , 2 ** 14 , 2 ** 9 ),
92
94
lambda : wsj_parsing .parsing_token_generator (
93
95
FLAGS .data_dir , FLAGS .tmp_dir , False , 2 ** 14 , 2 ** 9 )),
94
- "wmt_ende_bpe32k " : (
96
+ "translate_ende_wmt_bpe32k " : (
95
97
lambda : wmt .ende_bpe_token_generator (
96
98
FLAGS .data_dir , FLAGS .tmp_dir , True ),
97
99
lambda : wmt .ende_bpe_token_generator (
98
100
FLAGS .data_dir , FLAGS .tmp_dir , False )),
99
- "lm1b_32k " : (
101
+ "languagemodel_1b32k " : (
100
102
lambda : lm1b .generator (FLAGS .tmp_dir , True ),
101
103
lambda : lm1b .generator (FLAGS .tmp_dir , False )
102
104
),
103
- "lm1b_characters " : (
105
+ "languagemodel_1b_characters " : (
104
106
lambda : lm1b .generator (FLAGS .tmp_dir , True , characters = True ),
105
107
lambda : lm1b .generator (FLAGS .tmp_dir , False , characters = True )
106
108
),
107
109
"image_celeba_tune" : (
108
110
lambda : image .celeba_generator (FLAGS .tmp_dir , 162770 ),
109
111
lambda : image .celeba_generator (FLAGS .tmp_dir , 19867 , 162770 )),
110
- "snli_32k " : (
112
+ "inference_snli32k " : (
111
113
lambda : snli .snli_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
112
114
lambda : snli .snli_token_generator (FLAGS .tmp_dir , False , 2 ** 15 ),
113
115
),
@@ -181,7 +183,11 @@ def main(_):
181
183
"Data will be written to default data_dir=%s." ,
182
184
FLAGS .data_dir )
183
185
184
- tf .logging .info ("Generating problems:\n * %s\n " % "\n * " .join (problems ))
186
+ tf .logging .info ("Generating problems:\n %s"
187
+ % registry .display_list_by_prefix (problems ,
188
+ starting_spaces = 4 ))
189
+ if FLAGS .only_list :
190
+ return
185
191
for problem in problems :
186
192
set_random_seed ()
187
193
@@ -210,7 +216,7 @@ def generate_data_for_problem(problem):
210
216
211
217
212
218
def generate_data_for_registered_problem (problem_name ):
213
- tf .logging .info ("Generating training data for %s." , problem_name )
219
+ tf .logging .info ("Generating data for %s." , problem_name )
214
220
if FLAGS .num_shards :
215
221
raise ValueError ("--num_shards should not be set for registered Problem." )
216
222
problem = registry .problem (problem_name )
0 commit comments