@@ -265,7 +265,7 @@ def _add_table_variable(
265
265
table_specs : Sequence [embedding_spec .TableSpec ],
266
266
num_shards : int ,
267
267
add_slot_variables : bool ,
268
- ) -> tuple [ keras . Variable , tuple [ keras . Variable , ...] | None ] :
268
+ ) -> embedding . EmbeddingVariables :
269
269
stacked_table_spec = typing .cast (
270
270
embedding_spec .StackedTableSpec , table_specs [0 ].stacked_table_spec
271
271
)
@@ -334,7 +334,7 @@ def _add_table_variable(
334
334
slot_initializers , slot_variables
335
335
)
336
336
337
- return table_variable , slot_variables
337
+ return embedding . EmbeddingVariables ( table_variable , slot_variables )
338
338
339
339
@keras_utils .no_automatic_dependency_tracking
340
340
def _sparsecore_init (
@@ -738,8 +738,8 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
738
738
# Assign stacked table variables to the device values.
739
739
keras .tree .map_structure_up_to (
740
740
device_tables ,
741
- lambda table_and_slot_variables ,
742
- table_value : table_and_slot_variables [ 0 ] .assign (table_value ),
741
+ lambda embedding_variables ,
742
+ table_value : embedding_variables . table .assign (table_value ),
743
743
self ._table_and_slot_variables ,
744
744
device_tables ,
745
745
)
@@ -754,8 +754,10 @@ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:
754
754
755
755
# Extract only the table variables, not the gradient slot variables.
756
756
table_variables = {
757
- name : jax .device_get (table_and_slots [0 ].value )
758
- for name , table_and_slots in self ._table_and_slot_variables .items ()
757
+ name : jax .device_get (embedding_variables .table .value )
758
+ for name , embedding_variables in (
759
+ self ._table_and_slot_variables .items ()
760
+ )
759
761
}
760
762
761
763
return typing .cast (
0 commit comments