Skip to content

Commit a218c87

Browse files
Refactor return type of DistributedEmbedding methods to use embedding.EmbeddingVariables (#158)
1 parent 6ca4b4f commit a218c87

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def _add_table_variable(
265265
table_specs: Sequence[embedding_spec.TableSpec],
266266
num_shards: int,
267267
add_slot_variables: bool,
268-
) -> tuple[keras.Variable, tuple[keras.Variable, ...] | None]:
268+
) -> embedding.EmbeddingVariables:
269269
stacked_table_spec = typing.cast(
270270
embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
271271
)
@@ -334,7 +334,7 @@ def _add_table_variable(
334334
slot_initializers, slot_variables
335335
)
336336

337-
return table_variable, slot_variables
337+
return embedding.EmbeddingVariables(table_variable, slot_variables)
338338

339339
@keras_utils.no_automatic_dependency_tracking
340340
def _sparsecore_init(
@@ -738,8 +738,8 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
738738
# Assign stacked table variables to the device values.
739739
keras.tree.map_structure_up_to(
740740
device_tables,
741-
lambda table_and_slot_variables,
742-
table_value: table_and_slot_variables[0].assign(table_value),
741+
lambda embedding_variables,
742+
table_value: embedding_variables.table.assign(table_value),
743743
self._table_and_slot_variables,
744744
device_tables,
745745
)
@@ -754,8 +754,10 @@ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:
754754

755755
# Extract only the table variables, not the gradient slot variables.
756756
table_variables = {
757-
name: jax.device_get(table_and_slots[0].value)
758-
for name, table_and_slots in self._table_and_slot_variables.items()
757+
name: jax.device_get(embedding_variables.table.value)
758+
for name, embedding_variables in (
759+
self._table_and_slot_variables.items()
760+
)
759761
}
760762

761763
return typing.cast(

0 commit comments

Comments
 (0)