Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 76706ef

Browse files
author
Ryan Sepassi
committed
Make output of fn in @recompute_grad a list to avoid trying to concat tuple and list
PiperOrigin-RevId: 169632380
1 parent 6237729 commit 76706ef

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tensor2tensor/layers/rev_block.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def _recompute_grad(fn, args):
348348
def grad_fn(inputs, variables, outputs, output_grads):
349349
del outputs
350350
# recompute outputs
351-
outputs = fn(*inputs)
351+
outputs = list(fn(*inputs))
352352
grads = tf.gradients(outputs, inputs + variables, output_grads)
353353
grad_inputs = grads[:len(inputs)]
354354
grad_vars = grads[len(inputs):]

0 commit comments

Comments
 (0)