@@ -105,6 +105,8 @@ class SparsecoreConfig:
105
105
sharding_strategy: The sharding strategy to use for the embedding table.
106
106
Defaults to 'MOD' sharding. See the sparsecore documentation for more
107
107
details.
108
+ allow_id_dropping: Whether to allow dropping of IDs that do not fit within
109
+ the XLA buffers allocated for each partition. Defaults to False.
108
110
num_sc_per_device: The number of sparsecores per Jax device. By default, a
109
111
fixed mapping is used to determine this based on device 0. This may fail
110
112
on newer TPU architectures if the mapping is not updated or if device 0 is
@@ -163,6 +165,7 @@ def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array:
163
165
optimizer : OptimizerSpec
164
166
sharding_axis : str | int = 0
165
167
sharding_strategy : str = 'MOD'
168
+ allow_id_dropping : bool = False
166
169
167
170
# TODO(aahil): Come up with better defaults / heuristics here.
168
171
max_ids_per_partition_fn : Callable [[str , int ], int ] = dataclasses .field (
@@ -339,7 +342,7 @@ def _to_np(x: Any) -> np.ndarray:
339
342
global_device_count = self .sparsecore_config .global_device_count ,
340
343
num_sc_per_device = self .sparsecore_config .num_sc_per_device ,
341
344
sharding_strategy = self .sparsecore_config .sharding_strategy ,
342
- allow_id_dropping = False ,
345
+ allow_id_dropping = self . sparsecore_config . allow_id_dropping ,
343
346
batch_number = self ._batch_number ,
344
347
)
345
348
0 commit comments