This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit b1e7708

nshazeer authored and Ryan Sepassi committed
transformer_symshard research model - Symmetrically sharded version of Transformer.
PiperOrigin-RevId: 186915433
1 parent ba2f9c8 commit b1e7708

3 files changed: +420 -2 lines changed

tensor2tensor/layers/common_layers.py

Lines changed: 3 additions & 2 deletions
@@ -216,11 +216,12 @@ def dropout_no_scaling(x, keep_prob):
 
 
 def embedding(x, vocab_size, dense_size, name=None, reuse=None, multiplier=1.0,
-              symbol_dropout_rate=0.0):
+              symbol_dropout_rate=0.0, embedding_var=None):
   """Embed x of type int64 into dense vectors, reducing to max 4 dimensions."""
   with tf.variable_scope(
       name, default_name="embedding", values=[x], reuse=reuse):
-    embedding_var = tf.get_variable("kernel", [vocab_size, dense_size])
+    if embedding_var is None:
+      embedding_var = tf.get_variable("kernel", [vocab_size, dense_size])
     # On the backwards pass, we want to convert the gradient from
     # an indexed-slices to a regular tensor before sending it back to the
     # parameter server. This avoids excess computation on the parameter server.
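
The new `embedding_var` argument lets a caller hand `embedding()` a variable created elsewhere instead of having the function allocate its own "kernel"; presumably this is the hook transformer_symshard uses to manage embedding weights outside the function. A minimal TF1-style sketch (not from this commit; the sizes, placeholders, and variable name are illustrative) of sharing one externally created variable across two embedding calls:

    import tensorflow as tf
    from tensor2tensor.layers import common_layers

    vocab_size, hidden_size = 32000, 512  # illustrative sizes

    # Create the embedding weights once, outside embedding(); with the new
    # embedding_var argument, embedding() skips tf.get_variable() and uses this.
    shared_kernel = tf.get_variable("shared_kernel", [vocab_size, hidden_size])

    inputs = tf.placeholder(tf.int64, shape=[None, None])   # token ids
    targets = tf.placeholder(tf.int64, shape=[None, None])  # token ids

    # Both lookups now share the same externally managed weights.
    inputs_emb = common_layers.embedding(
        inputs, vocab_size, hidden_size, embedding_var=shared_kernel)
    targets_emb = common_layers.embedding(
        targets, vocab_size, hidden_size, embedding_var=shared_kernel)
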

tensor2tensor/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -48,5 +48,6 @@
 from tensor2tensor.models.research import transformer_moe
 from tensor2tensor.models.research import transformer_revnet
 from tensor2tensor.models.research import transformer_sketch
+from tensor2tensor.models.research import transformer_symshard
 from tensor2tensor.models.research import transformer_vae
 # pylint: enable=unused-import
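
The added import in models/__init__.py is there for its side effect: importing the module runs its model-registration decorators, which is how tensor2tensor models become reachable by name. A hedged sketch of looking the new model up, assuming the standard T2T registry and assuming the module registers itself under the name "transformer_symshard":

    # Importing tensor2tensor.models triggers registration of all research
    # models, including the newly added transformer_symshard module.
    from tensor2tensor import models  # noqa: F401
    from tensor2tensor.utils import registry

    # Fetch the model class by its registered name; "transformer_symshard" is
    # an assumption about the name the new module registers itself under.
    model_cls = registry.model("transformer_symshard")
    print(model_cls)
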
