Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit f9f41eb

Browse files
authored
Merge pull request #348 from medicode/set_precision_recall
Set precision and recall metrics
2 parents c34d16c + 5685b10 commit f9f41eb

File tree

1 file changed

+46
-1
lines changed

1 file changed

+46
-1
lines changed

tensor2tensor/utils/metrics.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ class Metrics(object):
4343
ROUGE_2_F = "rouge_2_fscore"
4444
ROUGE_L_F = "rouge_L_fscore"
4545
EDIT_DISTANCE = "edit_distance"
46-
46+
SET_PRECISION = 'set_precision'
47+
SET_RECALL = 'set_recall'
4748

4849
def padded_rmse(predictions, labels, weights_fn=common_layers.weights_all):
4950
predictions, labels = common_layers.pad_with_zeros(predictions, labels)
@@ -188,6 +189,48 @@ def padded_accuracy(predictions,
188189
padded_labels = tf.to_int32(padded_labels)
189190
return tf.to_float(tf.equal(outputs, padded_labels)), weights
190191

192+
def set_precision(predictions,
193+
labels,
194+
weights_fn=common_layers.weights_nonzero):
195+
"""Precision of set predictions.
196+
197+
Args:
198+
predictions : A Tensor of scores of shape (batch, nlabels)
199+
labels: A Tensor of int32s giving true set elements of shape (batch, seq_length)
200+
201+
Returns:
202+
hits: A Tensor of shape (batch, nlabels)
203+
weights: A Tensor of shape (batch, nlabels)
204+
"""
205+
with tf.variable_scope("set_precision", values=[predictions, labels]):
206+
labels = tf.squeeze(labels, [2, 3])
207+
labels = tf.one_hot(labels, predictions.shape[-1])
208+
labels = tf.reduce_max(labels, axis=1)
209+
labels = tf.cast(labels, tf.bool)
210+
predictions = predictions > 0
211+
return tf.to_float(tf.equal(labels, predictions)), tf.to_float(predictions)
212+
213+
def set_recall(predictions,
214+
labels,
215+
weights_fn=common_layers.weights_nonzero):
216+
"""Recall of set predictions.
217+
218+
Args:
219+
predictions : A Tensor of scores of shape (batch, nlabels)
220+
labels: A Tensor of int32s giving true set elements of shape (batch, seq_length)
221+
222+
Returns:
223+
hits: A Tensor of shape (batch, nlabels)
224+
weights: A Tensor of shape (batch, nlabels)
225+
"""
226+
with tf.variable_scope("set_recall", values=[predictions, labels]):
227+
labels = tf.squeeze(labels, [2, 3])
228+
labels = tf.one_hot(labels, predictions.shape[-1])
229+
labels = tf.reduce_max(labels, axis=1)
230+
labels = tf.cast(labels, tf.bool)
231+
predictions = predictions > 0
232+
return tf.to_float(tf.equal(labels, predictions)), tf.to_float(labels)
233+
191234

192235
def create_evaluation_metrics(problems, model_hparams):
193236
"""Creates the evaluation metrics for the model.
@@ -278,4 +321,6 @@ def wrapped_metric_fn():
278321
Metrics.ROUGE_2_F: rouge.rouge_2_fscore,
279322
Metrics.ROUGE_L_F: rouge.rouge_l_fscore,
280323
Metrics.EDIT_DISTANCE: sequence_edit_distance,
324+
Metrics.SET_PRECISION: set_precision,
325+
Metrics.SET_RECALL: set_recall,
281326
}

0 commit comments

Comments
 (0)