Skip to content

Commit 847628b

Browse files
Hilly12recml authors
authored andcommitted
Fix breakage in sparsecore preprocessor.
PiperOrigin-RevId: 785645239
1 parent aa8e6a6 commit 847628b

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

recml/examples/dlrm_experiment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def sparsecore_config(self) -> sparsecore.SparsecoreConfig:
122122
for f in self.features.sparse_features()
123123
},
124124
optimizer=self.embedding_optimizer,
125+
allow_id_dropping=True,
125126
)
126127
object.__setattr__(self, '_sparsecore_config', sparsecore_config)
127128
return sparsecore_config

recml/layers/linen/sparsecore.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ class SparsecoreConfig:
105105
sharding_strategy: The sharding strategy to use for the embedding table.
106106
Defaults to 'MOD' sharding. See the sparsecore documentation for more
107107
details.
108+
allow_id_dropping: Whether to allow dropping of IDs that do not fit within
109+
the XLA buffers allocated for each partition. Defaults to False.
108110
num_sc_per_device: The number of sparsecores per Jax device. By default, a
109111
fixed mapping is used to determine this based on device 0. This may fail
110112
on newer TPU architectures if the mapping is not updated of if device 0 is
@@ -163,6 +165,7 @@ def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array:
163165
optimizer: OptimizerSpec
164166
sharding_axis: str | int = 0
165167
sharding_strategy: str = 'MOD'
168+
allow_id_dropping: bool = False
166169

167170
# TODO(aahil): Come up with better defaults / heuristics here.
168171
max_ids_per_partition_fn: Callable[[str, int], int] = dataclasses.field(
@@ -339,7 +342,7 @@ def _to_np(x: Any) -> np.ndarray:
339342
global_device_count=self.sparsecore_config.global_device_count,
340343
num_sc_per_device=self.sparsecore_config.num_sc_per_device,
341344
sharding_strategy=self.sparsecore_config.sharding_strategy,
342-
allow_id_dropping=False,
345+
allow_id_dropping=self.sparsecore_config.allow_id_dropping,
343346
batch_number=self._batch_number,
344347
)
345348

0 commit comments

Comments
 (0)