Skip to content

Commit fb5b173

Browse files
committed
[Layers] Adjust var name in EV
1 parent 31c42f8 commit fb5b173

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

deepray/layers/embedding_variable.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(
7272
**kwargs,
7373
):
7474
super(EmbeddingVariable, self).__init__(name=name)
75-
self.embedding_size = embedding_dim
75+
self.embedding_dim = embedding_dim
7676
self.with_unique = with_unique
7777
self.world_size = get_world_size()
7878

@@ -138,21 +138,21 @@ def unique_read(self, ids, *args, **kwargs):
138138
unique_ids, idx = tf.unique(ids_flat)
139139
unique_embeddings = self.read(unique_ids)
140140
embeddings_flat = tf.gather(unique_embeddings, idx)
141-
embeddings_shape = tf.concat([tf.shape(ids), tf.constant(self.embedding_size, shape=(1,))], 0)
141+
embeddings_shape = tf.concat([tf.shape(ids), tf.constant(self.embedding_dim, shape=(1,))], 0)
142142
embeddings = tf.reshape(embeddings_flat, embeddings_shape)
143143
return embeddings
144144

145145
def hvd_read(self, ids, *args, **kwargs):
146146
"""
147147
Compute embedding output for feature ids. The output shape will be (shape(ids),
148-
embedding_size).
148+
embedding_dim).
149149
150150
Args:
151151
ids: feature ids of the input. It should be same dtype as the key_dtype
152152
of the layer.
153153
154154
Returns:
155-
A embedding output with shape (shape(ids), embedding_size).
155+
A embedding output with shape (shape(ids), embedding_dim).
156156
"""
157157
is_ragged = isinstance(ids, tf.RaggedTensor)
158158

@@ -161,7 +161,7 @@ def hvd_read(self, ids, *args, **kwargs):
161161
ids = ids.flat_values
162162

163163
input_shape = tf.shape(ids)
164-
embeddings_shape = tf.concat([input_shape, [self.embedding_size]], 0)
164+
embeddings_shape = tf.concat([input_shape, [self.embedding_dim]], 0)
165165

166166
ids_flat = tf.reshape(ids, [-1])
167167

@@ -179,7 +179,7 @@ def distributed_lookup(ids):
179179
lookup_result, _ = hvd.alltoall(lookup_result, splits=remote_sizes, name=f"{self.name}_alltoall_embeddings")
180180

181181
input_shape = tf.shape(ids)
182-
recover_shape = tf.concat((input_shape, (self.embedding_size,)), axis=0)
182+
recover_shape = tf.concat((input_shape, (self.embedding_dim,)), axis=0)
183183
gather_indices = tf.expand_dims(tf.concat(gather_indices, axis=0), axis=-1)
184184
lookup_result = tf.scatter_nd(gather_indices, lookup_result, recover_shape)
185185
return lookup_result

0 commit comments

Comments
 (0)