diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 1c8536086..de75d813e 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -51,6 +51,7 @@
     PartiallyMaterializedTensor,
 )
 
 from torch import nn
+from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard
 from torchrec.distributed.comm import get_local_rank, get_node_group_size
 from torchrec.distributed.composable.table_batched_embedding_slice import (
@@ -1071,6 +1072,13 @@ def __init__(
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
@@ -1113,6 +1121,25 @@ def init_parameters(self) -> None:
             )
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        if torch._utils_internal.justknobs_check(
+            "pytorch/torchrec:int32_rollout_killswitch"
+        ):
+            indices_dtype = (
+                torch.int32
+                if self._embedding_table_index_type == torch.int32
+                and self._embedding_table_offset_type == torch.int32
+                else torch.int64
+            )
+            offsets_dtype = (
+                torch.int32
+                if self._embedding_table_index_type == torch.int32
+                and self._embedding_table_offset_type == torch.int32
+                else torch.int64
+            )
+            return self.emb_module(
+                indices=features.values().type(dtype=indices_dtype),
+                offsets=features.offsets().type(dtype=offsets_dtype),
+            )
         return self.emb_module(
             indices=features.values().long(),
             offsets=features.offsets().long(),
@@ -1857,6 +1884,13 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
@@ -1902,6 +1936,20 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         weights = features.weights_or_none()
         if weights is not None and not torch.is_floating_point(weights):
             weights = None
+
+        indices_dtype = (
+            torch.int32
+            if self._embedding_table_index_type == torch.int32
+            and self._embedding_table_offset_type == torch.int32
+            else torch.int64
+        )
+        offsets_dtype = (
+            torch.int32
+            if self._embedding_table_index_type == torch.int32
+            and self._embedding_table_offset_type == torch.int32
+            else torch.int64
+        )
+
         if features.variable_stride_per_key() and isinstance(
             self.emb_module,
             (
@@ -1910,18 +1958,38 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
                 SSDTableBatchedEmbeddingBags,
             ),
         ):
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                per_sample_weights=weights,
-                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
-            )
+
+            if torch._utils_internal.justknobs_check(
+                "pytorch/torchrec:int32_rollout_killswitch"
+            ):
+                return self.emb_module(
+                    indices=features.values().type(dtype=indices_dtype),
+                    offsets=features.offsets().type(dtype=offsets_dtype),
+                    per_sample_weights=weights,
+                    batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+                )
+            else:
+                return self.emb_module(
+                    indices=features.values().long(),
+                    offsets=features.offsets().long(),
+                    per_sample_weights=weights,
+                    batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+                )
         else:
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                per_sample_weights=weights,
-            )
+            if torch._utils_internal.justknobs_check(
+                "pytorch/torchrec:int32_rollout_killswitch"
+            ):
+                return self.emb_module(
+                    indices=features.values().type(dtype=indices_dtype),
+                    offsets=features.offsets().type(dtype=offsets_dtype),
+                    per_sample_weights=weights,
+                )
+            else:
+                return self.emb_module(
+                    indices=features.values().long(),
+                    offsets=features.offsets().long(),
+                    per_sample_weights=weights,
+                )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
     def state_dict(
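
Minimal usage sketch (assumed, not part of this diff): the new fused_params keys default to torch.int64, so the int32 cast path only activates when both "embedding_table_index_type" and "embedding_table_offset_type" are explicitly set to torch.int32 and the "pytorch/torchrec:int32_rollout_killswitch" justknob is enabled. Assuming fused_params supplied to a sharder such as EmbeddingBagCollectionSharder propagate down to each table's config:

    import torch
    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

    # Both keys must be torch.int32; if either stays at the int64 default,
    # forward() keeps the existing .long() casts for indices and offsets.
    fused_params = {
        "embedding_table_index_type": torch.int32,
        "embedding_table_offset_type": torch.int32,
    }
    sharder = EmbeddingBagCollectionSharder(fused_params=fused_params)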