     Union,
 )
 
+import pyjk as justknobs
+
 import torch
 import torch.distributed as dist
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
@@ -949,6 +951,13 @@ def __init__(
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
@@ -991,6 +1000,23 @@ def init_parameters(self) -> None:
             )
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        if justknobs.check("pytorch/torchrec:int32_rollout_killswitch"):
+            indices_dtype = (
+                torch.int32
+                if self._embedding_table_index_type == torch.int32
+                and self._embedding_table_offset_type == torch.int32
+                else torch.int64
+            )
+            offsets_dtype = (
+                torch.int32
+                if self._embedding_table_index_type == torch.int32
+                and self._embedding_table_offset_type == torch.int32
+                else torch.int64
+            )
+            return self.emb_module(
+                indices=features.values().type(dtype=indices_dtype),
+                offsets=features.offsets().type(dtype=offsets_dtype),
+            )
         return self.emb_module(
             indices=features.values().long(),
             offsets=features.offsets().long(),
@@ -1754,6 +1780,13 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
@@ -1799,6 +1832,20 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         weights = features.weights_or_none()
         if weights is not None and not torch.is_floating_point(weights):
             weights = None
+
+        indices_dtype = (
+            torch.int32
+            if self._embedding_table_index_type == torch.int32
+            and self._embedding_table_offset_type == torch.int32
+            else torch.int64
+        )
+        offsets_dtype = (
+            torch.int32
+            if self._embedding_table_index_type == torch.int32
+            and self._embedding_table_offset_type == torch.int32
+            else torch.int64
+        )
+
         if features.variable_stride_per_key() and isinstance(
             self.emb_module,
             (
@@ -1807,18 +1854,33 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
                 SSDTableBatchedEmbeddingBags,
             ),
         ):
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                per_sample_weights=weights,
-                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
-            )
+            if justknobs.check("pytorch/torchrec:int32_rollout_killswitch"):
+                return self.emb_module(
+                    indices=features.values().type(dtype=indices_dtype),
+                    offsets=features.offsets().type(dtype=offsets_dtype),
+                    per_sample_weights=weights,
+                    batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+                )
+            else:
+                return self.emb_module(
+                    indices=features.values().long(),
+                    offsets=features.offsets().long(),
+                    per_sample_weights=weights,
+                    batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+                )
         else:
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                per_sample_weights=weights,
-            )
+            if justknobs.check("pytorch/torchrec:int32_rollout_killswitch"):
+                return self.emb_module(
+                    indices=features.values().type(dtype=indices_dtype),
+                    offsets=features.offsets().type(dtype=offsets_dtype),
+                    per_sample_weights=weights,
+                )
+            else:
+                return self.emb_module(
+                    indices=features.values().long(),
+                    offsets=features.offsets().long(),
+                    per_sample_weights=weights,
+                )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
     def state_dict(
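
For context, the int32 path in this diff activates only when the "pytorch/torchrec:int32_rollout_killswitch" knob is enabled and both fused params request int32; otherwise everything falls back to the existing int64 (.long()) behavior. Below is a minimal, standalone sketch of that dtype-selection rule. The fused-param keys and the knob name come from the diff; the select_index_dtypes helper and the example fused_params dict are hypothetical and only illustrate the logic, not TorchRec's actual configuration plumbing.

import torch

# Hypothetical fused_params a caller might attach to an embedding config
# to opt a table into 32-bit indices/offsets (assumption for illustration).
fused_params = {
    "embedding_table_index_type": torch.int32,
    "embedding_table_offset_type": torch.int32,
}

def select_index_dtypes(fused_params: dict) -> tuple:
    # Mirrors the selection added in forward(): int32 is used only when BOTH
    # the index type and the offset type are int32; otherwise int64.
    index_type = fused_params.get("embedding_table_index_type", torch.int64)
    offset_type = fused_params.get("embedding_table_offset_type", torch.int64)
    use_int32 = index_type == torch.int32 and offset_type == torch.int32
    dtype = torch.int32 if use_int32 else torch.int64
    return dtype, dtype

indices_dtype, offsets_dtype = select_index_dtypes(fused_params)
indices = torch.tensor([0, 3, 7]).type(dtype=indices_dtype)
offsets = torch.tensor([0, 1, 3]).type(dtype=offsets_dtype)

Requiring both types to be int32 before downcasting keeps indices and offsets consistent for the fbgemm kernels, and the killswitch lets the int64 behavior be restored without a code change if the rollout regresses.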