Commit 122b25d

basilwong authored and facebook-github-bot committed
batched_embedding_kernel int32 support behind jk
Summary:

### tl;dr

After this diff stack, int32 indices and offsets will be supported for FBGEMM embedding lookup kernels. This can be enabled via config on APS.

### Implementation

https://docs.google.com/document/d/1GoFghmJcDSGf6XhVkoTJs4C0jTemvpGe1fCNi6oQDRo/edit?usp=sharing

### Context

https://docs.google.com/document/d/1YVfxsafqXkxAAdRyXbjmSH4AEz3-6DBiTGjs1rT8ZHQ/edit?usp=sharing

### Diff-specific changes

This diff puts the ability to cast to int32 behind a jk killswitch, which we can turn off at any time in torchrec.

Differential Revision: D77843259
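For reference, a minimal standalone sketch of the dtype-selection rule this diff introduces. The helper names and the stubbed killswitch check are illustrative only; the real code calls `justknobs.check("pytorch/torchrec:int32_rollout_killswitch")` inside `forward()` and reads the two keys from the kernel's `fused_params`.

```python
from typing import Any, Dict, Tuple

import torch


def _killswitch_enabled() -> bool:
    # Stand-in for justknobs.check("pytorch/torchrec:int32_rollout_killswitch"),
    # which requires the internal pyjk package.
    return True


def select_index_offset_dtypes(
    fused_params: Dict[str, Any]
) -> Tuple[torch.dtype, torch.dtype]:
    # int32 is used only when BOTH the index type and the offset type are
    # configured as int32 and the killswitch is on; otherwise fall back to int64.
    index_type = fused_params.get("embedding_table_index_type", torch.int64)
    offset_type = fused_params.get("embedding_table_offset_type", torch.int64)
    if (
        _killswitch_enabled()
        and index_type == torch.int32
        and offset_type == torch.int32
    ):
        return torch.int32, torch.int32
    return torch.int64, torch.int64


# Example: a table group configured for int32 indices and offsets.
print(
    select_index_offset_dtypes(
        {
            "embedding_table_index_type": torch.int32,
            "embedding_table_offset_type": torch.int32,
        }
    )
)  # -> (torch.int32, torch.int32)
```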
1 parent 0919506 commit 122b25d

File tree

1 file changed

+73 -11 lines changed


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 73 additions & 11 deletions
@@ -27,6 +27,8 @@
     Union,
 )
 
+import pyjk as justknobs
+
 import torch
 import torch.distributed as dist
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
@@ -949,6 +951,13 @@ def __init__(
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
@@ -991,6 +1000,23 @@ def init_parameters(self) -> None:
         )
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        if justknobs.check("pytorch/torchrec:int32_rollout_killswitch"):
+            indices_dtype = (
+                torch.int32
+                if self._embedding_table_index_type == torch.int32
+                and self._embedding_table_offset_type == torch.int32
+                else torch.int64
+            )
+            offsets_dtype = (
+                torch.int32
+                if self._embedding_table_index_type == torch.int32
+                and self._embedding_table_offset_type == torch.int32
+                else torch.int64
+            )
+            return self.emb_module(
+                indices=features.values().type(dtype=indices_dtype),
+                offsets=features.offsets().type(dtype=offsets_dtype),
+            )
         return self.emb_module(
             indices=features.values().long(),
             offsets=features.offsets().long(),
@@ -1754,6 +1780,13 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
@@ -1799,6 +1832,20 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         weights = features.weights_or_none()
         if weights is not None and not torch.is_floating_point(weights):
             weights = None
+
+        indices_dtype = (
+            torch.int32
+            if self._embedding_table_index_type == torch.int32
+            and self._embedding_table_offset_type == torch.int32
+            else torch.int64
+        )
+        offsets_dtype = (
+            torch.int32
+            if self._embedding_table_index_type == torch.int32
+            and self._embedding_table_offset_type == torch.int32
+            else torch.int64
+        )
+
         if features.variable_stride_per_key() and isinstance(
             self.emb_module,
             (
@@ -1807,18 +1854,33 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
                 SSDTableBatchedEmbeddingBags,
             ),
         ):
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                per_sample_weights=weights,
-                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
-            )
+            if justknobs.check("pytorch/torchrec:int32_rollout_killswitch"):
+                return self.emb_module(
+                    indices=features.values().type(dtype=indices_dtype),
+                    offsets=features.offsets().type(dtype=offsets_dtype),
+                    per_sample_weights=weights,
+                    batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+                )
+            else:
+                return self.emb_module(
+                    indices=features.values().long(),
+                    offsets=features.offsets().long(),
+                    per_sample_weights=weights,
+                    batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+                )
         else:
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                per_sample_weights=weights,
-            )
+            if justknobs.check("pytorch/torchrec:int32_rollout_killswitch"):
+                return self.emb_module(
+                    indices=features.values().type(dtype=indices_dtype),
+                    offsets=features.offsets().type(dtype=offsets_dtype),
+                    per_sample_weights=weights,
+                )
+            else:
+                return self.emb_module(
+                    indices=features.values().long(),
+                    offsets=features.offsets().long(),
+                    per_sample_weights=weights,
+                )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
     def state_dict(
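As a usage note, a hedged sketch of how a caller might opt into the new path once the killswitch is on. The two dict keys are exactly the ones this diff reads from `fused_params`; how `fused_params` reaches the kernel (e.g. through a sharder or planner constraints) is not shown in this diff and is assumed here.

```python
# Illustrative only: the fused_params keys below are the ones read by this
# diff; plumbing them through a sharder/planner is assumed, not shown.
import torch

fused_params = {
    "embedding_table_index_type": torch.int32,
    "embedding_table_offset_type": torch.int32,
}

# Inside forward(), with the killswitch on and both keys set to int32, the
# kernel casts indices/offsets as below instead of calling .long():
values = torch.tensor([1, 2, 3], dtype=torch.int64)
offsets = torch.tensor([0, 1, 3], dtype=torch.int64)
indices_i32 = values.type(dtype=torch.int32)
offsets_i32 = offsets.type(dtype=torch.int32)
assert indices_i32.dtype == torch.int32 and offsets_i32.dtype == torch.int32
```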
