@@ -43,7 +43,8 @@ class Metrics(object):
43
43
ROUGE_2_F = "rouge_2_fscore"
44
44
ROUGE_L_F = "rouge_L_fscore"
45
45
EDIT_DISTANCE = "edit_distance"
46
-
46
+ SET_PRECISION = 'set_precision'
47
+ SET_RECALL = 'set_recall'
47
48
48
49
def padded_rmse (predictions , labels , weights_fn = common_layers .weights_all ):
49
50
predictions , labels = common_layers .pad_with_zeros (predictions , labels )
@@ -188,6 +189,48 @@ def padded_accuracy(predictions,
188
189
padded_labels = tf .to_int32 (padded_labels )
189
190
return tf .to_float (tf .equal (outputs , padded_labels )), weights
190
191
192
+ def set_precision (predictions ,
193
+ labels ,
194
+ weights_fn = common_layers .weights_nonzero ):
195
+ """Precision of set predictions.
196
+
197
+ Args:
198
+ predictions : A Tensor of scores of shape (batch, nlabels)
199
+ labels: A Tensor of int32s giving true set elements of shape (batch, seq_length)
200
+
201
+ Returns:
202
+ hits: A Tensor of shape (batch, nlabels)
203
+ weights: A Tensor of shape (batch, nlabels)
204
+ """
205
+ with tf .variable_scope ("set_precision" , values = [predictions , labels ]):
206
+ labels = tf .squeeze (labels , [2 , 3 ])
207
+ labels = tf .one_hot (labels , predictions .shape [- 1 ])
208
+ labels = tf .reduce_max (labels , axis = 1 )
209
+ labels = tf .cast (labels , tf .bool )
210
+ predictions = predictions > 0
211
+ return tf .to_float (tf .equal (labels , predictions )), tf .to_float (predictions )
212
+
213
+ def set_recall (predictions ,
214
+ labels ,
215
+ weights_fn = common_layers .weights_nonzero ):
216
+ """Recall of set predictions.
217
+
218
+ Args:
219
+ predictions : A Tensor of scores of shape (batch, nlabels)
220
+ labels: A Tensor of int32s giving true set elements of shape (batch, seq_length)
221
+
222
+ Returns:
223
+ hits: A Tensor of shape (batch, nlabels)
224
+ weights: A Tensor of shape (batch, nlabels)
225
+ """
226
+ with tf .variable_scope ("set_recall" , values = [predictions , labels ]):
227
+ labels = tf .squeeze (labels , [2 , 3 ])
228
+ labels = tf .one_hot (labels , predictions .shape [- 1 ])
229
+ labels = tf .reduce_max (labels , axis = 1 )
230
+ labels = tf .cast (labels , tf .bool )
231
+ predictions = predictions > 0
232
+ return tf .to_float (tf .equal (labels , predictions )), tf .to_float (labels )
233
+
191
234
192
235
def create_evaluation_metrics (problems , model_hparams ):
193
236
"""Creates the evaluation metrics for the model.
@@ -278,4 +321,6 @@ def wrapped_metric_fn():
278
321
Metrics .ROUGE_2_F : rouge .rouge_2_fscore ,
279
322
Metrics .ROUGE_L_F : rouge .rouge_l_fscore ,
280
323
Metrics .EDIT_DISTANCE : sequence_edit_distance ,
324
+ Metrics .SET_PRECISION : set_precision ,
325
+ Metrics .SET_RECALL : set_recall ,
281
326
}
0 commit comments