@@ -72,7 +72,7 @@ def __init__(
7272 ** kwargs ,
7373 ):
7474 super (EmbeddingVariable , self ).__init__ (name = name )
75- self .embedding_size = embedding_dim
75+ self .embedding_dim = embedding_dim
7676 self .with_unique = with_unique
7777 self .world_size = get_world_size ()
7878
@@ -138,21 +138,21 @@ def unique_read(self, ids, *args, **kwargs):
138138 unique_ids , idx = tf .unique (ids_flat )
139139 unique_embeddings = self .read (unique_ids )
140140 embeddings_flat = tf .gather (unique_embeddings , idx )
141- embeddings_shape = tf .concat ([tf .shape (ids ), tf .constant (self .embedding_size , shape = (1 ,))], 0 )
141+ embeddings_shape = tf .concat ([tf .shape (ids ), tf .constant (self .embedding_dim , shape = (1 ,))], 0 )
142142 embeddings = tf .reshape (embeddings_flat , embeddings_shape )
143143 return embeddings
144144
145145 def hvd_read (self , ids , * args , ** kwargs ):
146146 """
147147 Compute embedding output for feature ids. The output shape will be (shape(ids),
148- embedding_size ).
148+ embedding_dim ).
149149
150150 Args:
151151 ids: feature ids of the input. It should be same dtype as the key_dtype
152152 of the layer.
153153
154154 Returns:
155- A embedding output with shape (shape(ids), embedding_size ).
155+ A embedding output with shape (shape(ids), embedding_dim ).
156156 """
157157 is_ragged = isinstance (ids , tf .RaggedTensor )
158158
@@ -161,7 +161,7 @@ def hvd_read(self, ids, *args, **kwargs):
161161 ids = ids .flat_values
162162
163163 input_shape = tf .shape (ids )
164- embeddings_shape = tf .concat ([input_shape , [self .embedding_size ]], 0 )
164+ embeddings_shape = tf .concat ([input_shape , [self .embedding_dim ]], 0 )
165165
166166 ids_flat = tf .reshape (ids , [- 1 ])
167167
@@ -179,7 +179,7 @@ def distributed_lookup(ids):
179179 lookup_result , _ = hvd .alltoall (lookup_result , splits = remote_sizes , name = f"{ self .name } _alltoall_embeddings" )
180180
181181 input_shape = tf .shape (ids )
182- recover_shape = tf .concat ((input_shape , (self .embedding_size ,)), axis = 0 )
182+ recover_shape = tf .concat ((input_shape , (self .embedding_dim ,)), axis = 0 )
183183 gather_indices = tf .expand_dims (tf .concat (gather_indices , axis = 0 ), axis = - 1 )
184184 lookup_result = tf .scatter_nd (gather_indices , lookup_result , recover_shape )
185185 return lookup_result
0 commit comments