
Commit d9fcec5

basilwong authored and facebook-github-bot committed
batched_embedding_kernel int32 support behind jk (#3164)
Summary:

### tl;dr

After this diff stack, int32 indices and offsets will be supported for FBGEMM embedding lookup kernels. This can then be enabled via config on APS.

### Implementation

https://docs.google.com/document/d/1GoFghmJcDSGf6XhVkoTJs4C0jTemvpGe1fCNi6oQDRo/edit?usp=sharing

### Context

https://docs.google.com/document/d/1YVfxsafqXkxAAdRyXbjmSH4AEz3-6DBiTGjs1rT8ZHQ/edit?usp=sharing

### Diff specific changes

Puts the ability to cast to int32 behind a JK killswitch, which we can turn off at any time in torchrec.

### JK

https://www.internalfb.com/intern/justknobs/?name=fbgemm_gpu%2Ffeatures&name=fbgemm_gpu%2Ffeatures#INT32_INDICES

Differential Revision: D77843259
1 parent 89e7771 commit d9fcec5
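
For orientation, here is a minimal sketch of how the new behavior might be opted into from the model side. The two fused_params keys and their torch.int64 defaults come straight from the diff below; the sharder wiring around them (EmbeddingBagCollectionSharder) is an assumption about how fused_params typically reach the batched kernels, not part of this change.

```python
import torch

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

# The diff below reads these two keys from config.fused_params, each
# defaulting to torch.int64. Both must be torch.int32 for the kernels to
# receive int32 indices/offsets, and only while the JK killswitch is enabled.
fused_params = {
    "embedding_table_index_type": torch.int32,
    "embedding_table_offset_type": torch.int32,
}

# Assumed wiring: fused_params passed to a sharder are what eventually land
# on the GroupedEmbeddingConfig consumed by the batched embedding kernels.
sharder = EmbeddingBagCollectionSharder(fused_params=fused_params)
```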

File tree

1 file changed: +79 -11 lines

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 79 additions & 11 deletions
@@ -51,6 +51,7 @@
     PartiallyMaterializedTensor,
 )
 from torch import nn
+
 from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard
 from torchrec.distributed.comm import get_local_rank, get_node_group_size
 from torchrec.distributed.composable.table_batched_embedding_slice import (
@@ -1071,6 +1072,13 @@ def __init__(
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
@@ -1113,6 +1121,25 @@ def init_parameters(self) -> None:
             )
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        if torch._utils_internal.justknobs_check(
+            "pytorch/torchrec:int32_rollout_killswitch"
+        ):
+            indices_dtype = (
+                torch.int32
+                if self._embedding_table_index_type == torch.int32
+                and self._embedding_table_offset_type == torch.int32
+                else torch.int64
+            )
+            offsets_dtype = (
+                torch.int32
+                if self._embedding_table_index_type == torch.int32
+                and self._embedding_table_offset_type == torch.int32
+                else torch.int64
+            )
+            return self.emb_module(
+                indices=features.values().type(dtype=indices_dtype),
+                offsets=features.offsets().type(dtype=offsets_dtype),
+            )
         return self.emb_module(
             indices=features.values().long(),
             offsets=features.offsets().long(),
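
The gated branch above is deliberately all-or-nothing: indices and offsets only drop to int32 when both configured dtypes are int32; otherwise both stay int64. A small self-contained sketch of that rule (the helper name select_lookup_dtypes is illustrative, not from the diff):

```python
import torch
from typing import Tuple


def select_lookup_dtypes(
    index_type: torch.dtype, offset_type: torch.dtype
) -> Tuple[torch.dtype, torch.dtype]:
    # Both configured dtypes must be int32; any other combination keeps
    # the existing int64 path for both indices and offsets.
    use_int32 = index_type == torch.int32 and offset_type == torch.int32
    dtype = torch.int32 if use_int32 else torch.int64
    return dtype, dtype


assert select_lookup_dtypes(torch.int32, torch.int32) == (torch.int32, torch.int32)
assert select_lookup_dtypes(torch.int32, torch.int64) == (torch.int64, torch.int64)
```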
@@ -1857,6 +1884,13 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
@@ -1902,6 +1936,20 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         weights = features.weights_or_none()
         if weights is not None and not torch.is_floating_point(weights):
             weights = None
+
+        indices_dtype = (
+            torch.int32
+            if self._embedding_table_index_type == torch.int32
+            and self._embedding_table_offset_type == torch.int32
+            else torch.int64
+        )
+        offsets_dtype = (
+            torch.int32
+            if self._embedding_table_index_type == torch.int32
+            and self._embedding_table_offset_type == torch.int32
+            else torch.int64
+        )
+
         if features.variable_stride_per_key() and isinstance(
             self.emb_module,
             (
@@ -1910,18 +1958,38 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
                 SSDTableBatchedEmbeddingBags,
             ),
         ):
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                per_sample_weights=weights,
-                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
-            )
+
+            if torch._utils_internal.justknobs_check(
+                "pytorch/torchrec:int32_rollout_killswitch"
+            ):
+                return self.emb_module(
+                    indices=features.values().type(dtype=indices_dtype),
+                    offsets=features.offsets().type(dtype=offsets_dtype),
+                    per_sample_weights=weights,
+                    batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+                )
+            else:
+                return self.emb_module(
+                    indices=features.values().long(),
+                    offsets=features.offsets().long(),
+                    per_sample_weights=weights,
+                    batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+                )
         else:
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                per_sample_weights=weights,
-            )
+            if torch._utils_internal.justknobs_check(
+                "pytorch/torchrec:int32_rollout_killswitch"
+            ):
+                return self.emb_module(
+                    indices=features.values().type(dtype=indices_dtype),
+                    offsets=features.offsets().type(dtype=offsets_dtype),
+                    per_sample_weights=weights,
+                )
+            else:
+                return self.emb_module(
+                    indices=features.values().long(),
+                    offsets=features.offsets().long(),
+                    per_sample_weights=weights,
+                )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
     def state_dict(
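
Two behaviors the new call sites lean on, shown as a hedged sketch rather than a claim about this diff: Tensor.type(dtype=...) returns the original tensor when it already has the requested dtype, so the default int64 path should not pay for an extra copy; and torch._utils_internal.justknobs_check is, to the best of my understanding, a pass-through that returns True in open-source builds, with the real killswitch living in the internal JustKnobs config linked in the summary.

```python
import torch
import torch._utils_internal

values = torch.tensor([1, 2, 3], dtype=torch.int64)

# Casting to a different dtype produces a new tensor...
assert values.type(dtype=torch.int32).dtype == torch.int32
# ...while "casting" to the dtype a tensor already has returns the original
# object, so the default int64 configuration should add no extra copies.
assert values.type(dtype=torch.int64) is values

# Assumption: in OSS PyTorch builds, justknobs_check simply returns its
# default (True); the killswitch only bites where JustKnobs is wired up.
print(torch._utils_internal.justknobs_check("pytorch/torchrec:int32_rollout_killswitch"))
```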
