     PartiallyMaterializedTensor,
 )
 from torch import nn
+
 from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard
 from torchrec.distributed.comm import get_local_rank, get_node_group_size
 from torchrec.distributed.composable.table_batched_embedding_slice import (
@@ -1071,6 +1072,13 @@ def __init__(
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
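
As context for this first `__init__` hunk: the two new fused-params keys are plain dict entries that default to `torch.int64` when unset, so existing callers are unaffected. A minimal sketch of how a caller might opt in (the key names come from this diff; how `fused_params` reaches the config is assumed):

```python
import torch

# Hypothetical opt-in: both keys must be torch.int32 for the int32 path
# added below to activate; leaving either unset falls back to torch.int64.
fused_params = {
    "embedding_table_index_type": torch.int32,
    "embedding_table_offset_type": torch.int32,
}
```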
@@ -1113,6 +1121,25 @@ def init_parameters(self) -> None:
         )
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        if torch._utils_internal.justknobs_check(
+            "pytorch/torchrec:int32_rollout_killswitch"
+        ):
+            indices_dtype = (
+                torch.int32
+                if self._embedding_table_index_type == torch.int32
+                and self._embedding_table_offset_type == torch.int32
+                else torch.int64
+            )
+            offsets_dtype = (
+                torch.int32
+                if self._embedding_table_index_type == torch.int32
+                and self._embedding_table_offset_type == torch.int32
+                else torch.int64
+            )
+            return self.emb_module(
+                indices=features.values().type(dtype=indices_dtype),
+                offsets=features.offsets().type(dtype=offsets_dtype),
+            )
         return self.emb_module(
             indices=features.values().long(),
             offsets=features.offsets().long(),
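
The gating above only ever selects `torch.int32` when *both* configured types are int32; any mismatch degrades both indices and offsets to int64. A standalone sketch of the same decision, using a hypothetical helper name (`resolve_kjt_dtypes`) and a plain boolean in place of the JustKnobs check:

```python
from typing import Tuple

import torch


def resolve_kjt_dtypes(
    index_type: torch.dtype,
    offset_type: torch.dtype,
    rollout_enabled: bool,  # stand-in for the justknobs_check call
) -> Tuple[torch.dtype, torch.dtype]:
    # int32 only when the rollout knob is on AND both types agree on int32;
    # any other combination keeps the legacy int64 behavior.
    if rollout_enabled and index_type == offset_type == torch.int32:
        return torch.int32, torch.int32
    return torch.int64, torch.int64


assert resolve_kjt_dtypes(torch.int32, torch.int32, True) == (torch.int32, torch.int32)
assert resolve_kjt_dtypes(torch.int32, torch.int64, True) == (torch.int64, torch.int64)
assert resolve_kjt_dtypes(torch.int32, torch.int32, False) == (torch.int64, torch.int64)
```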
@@ -1857,6 +1884,13 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         for idx, table_config in enumerate(self._config.embedding_tables):
             self._local_rows.append(table_config.local_rows)
@@ -1902,6 +1936,20 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         weights = features.weights_or_none()
         if weights is not None and not torch.is_floating_point(weights):
             weights = None
+
+        indices_dtype = (
+            torch.int32
+            if self._embedding_table_index_type == torch.int32
+            and self._embedding_table_offset_type == torch.int32
+            else torch.int64
+        )
+        offsets_dtype = (
+            torch.int32
+            if self._embedding_table_index_type == torch.int32
+            and self._embedding_table_offset_type == torch.int32
+            else torch.int64
+        )
+
         if features.variable_stride_per_key() and isinstance(
             self.emb_module,
             (
@@ -1910,18 +1958,38 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
                 SSDTableBatchedEmbeddingBags,
             ),
         ):
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                per_sample_weights=weights,
-                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
-            )
+
+            if torch._utils_internal.justknobs_check(
+                "pytorch/torchrec:int32_rollout_killswitch"
+            ):
+                return self.emb_module(
+                    indices=features.values().type(dtype=indices_dtype),
+                    offsets=features.offsets().type(dtype=offsets_dtype),
+                    per_sample_weights=weights,
+                    batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+                )
+            else:
+                return self.emb_module(
+                    indices=features.values().long(),
+                    offsets=features.offsets().long(),
+                    per_sample_weights=weights,
+                    batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
+                )
         else:
-            return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
-                per_sample_weights=weights,
-            )
+            if torch._utils_internal.justknobs_check(
+                "pytorch/torchrec:int32_rollout_killswitch"
+            ):
+                return self.emb_module(
+                    indices=features.values().type(dtype=indices_dtype),
+                    offsets=features.offsets().type(dtype=offsets_dtype),
+                    per_sample_weights=weights,
+                )
+            else:
+                return self.emb_module(
+                    indices=features.values().long(),
+                    offsets=features.offsets().long(),
+                    per_sample_weights=weights,
+                )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
     def state_dict(
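
End to end, a minimal sketch of feeding int32 features through the patched `forward` (module construction is elided; `grouped_emb` is a hypothetical instance configured with the fused params shown earlier):

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([10, 20, 30], dtype=torch.int32),
    lengths=torch.tensor([1, 2], dtype=torch.int32),
)
# With both dtype keys set to torch.int32 and the killswitch open,
# forward() skips the unconditional .long() upcast, so indices and
# offsets reach the batched-embedding kernel at half the integer width.
# pooled = grouped_emb(features)
```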