Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2e900a8

Browse files
authored
Merge pull request #93 from kolloldas/fix_issue_79
Fix issue 79: Mismatch in Logits and Labels
2 parents 228feae + 4a1d7da commit 2e900a8

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

tensor2tensor/models/common_layers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,7 +1285,7 @@ def pad_with_zeros(logits, labels):
12851285
logits, labels = pad_to_same_length(logits, labels)
12861286
if len(labels.shape.as_list()) == 3: # 2-d labels.
12871287
logits, labels = pad_to_same_length(logits, labels, axis=2)
1288-
return labels
1288+
return logits, labels
12891289

12901290

12911291
def weights_nonzero(labels):
@@ -1351,8 +1351,8 @@ def padded_cross_entropy(logits,
13511351
confidence = 1.0 - label_smoothing
13521352
vocab_size = tf.shape(logits)[-1]
13531353
with tf.name_scope("padded_cross_entropy", [logits, labels]):
1354-
pad_labels = pad_with_zeros(logits, labels)
1355-
xent = smoothing_cross_entropy(logits, pad_labels, vocab_size, confidence)
1354+
pad_logits, pad_labels = pad_with_zeros(logits, labels)
1355+
xent = smoothing_cross_entropy(pad_logits, pad_labels, vocab_size, confidence)
13561356
weights = weights_fn(pad_labels)
13571357
if not reduce_sum:
13581358
return xent * weights, weights

tensor2tensor/utils/metrics.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def padded_accuracy_topk(predictions,
3737
weights_fn=common_layers.weights_nonzero):
3838
"""Percentage of times that top-k predictions matches labels on non-0s."""
3939
with tf.variable_scope("padded_accuracy_topk", values=[predictions, labels]):
40-
padded_labels = common_layers.pad_with_zeros(predictions, labels)
40+
padded_predictions, padded_labels = common_layers.pad_with_zeros(predictions, labels)
4141
weights = weights_fn(padded_labels)
42-
effective_k = tf.minimum(k, tf.shape(predictions)[-1])
43-
_, outputs = tf.nn.top_k(predictions, k=effective_k)
42+
effective_k = tf.minimum(k, tf.shape(padded_predictions)[-1])
43+
_, outputs = tf.nn.top_k(padded_predictions, k=effective_k)
4444
outputs = tf.to_int32(outputs)
4545
padded_labels = tf.expand_dims(padded_labels, axis=-1)
4646
padded_labels += tf.zeros_like(outputs) # Pad to same shape.
@@ -61,9 +61,9 @@ def padded_sequence_accuracy(predictions,
6161
"""Percentage of times that predictions matches labels everywhere (non-0)."""
6262
with tf.variable_scope(
6363
"padded_sequence_accuracy", values=[predictions, labels]):
64-
padded_labels = common_layers.pad_with_zeros(predictions, labels)
64+
paded_predictions, padded_labels = common_layers.pad_with_zeros(predictions, labels)
6565
weights = weights_fn(padded_labels)
66-
outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
66+
outputs = tf.to_int32(tf.argmax(paded_predictions, axis=-1))
6767
not_correct = tf.to_float(tf.not_equal(outputs, padded_labels)) * weights
6868
axis = list(range(1, len(outputs.get_shape())))
6969
correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis))
@@ -84,9 +84,9 @@ def padded_accuracy(predictions,
8484
weights_fn=common_layers.weights_nonzero):
8585
"""Percentage of times that predictions matches labels on non-0s."""
8686
with tf.variable_scope("padded_accuracy", values=[predictions, labels]):
87-
padded_labels = common_layers.pad_with_zeros(predictions, labels)
87+
padded_predictions, padded_labels = common_layers.pad_with_zeros(predictions, labels)
8888
weights = weights_fn(padded_labels)
89-
outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
89+
outputs = tf.to_int32(tf.argmax(padded_predictions, axis=-1))
9090
return tf.to_float(tf.equal(outputs, padded_labels)), weights
9191

9292

0 commit comments

Comments
 (0)