@@ -440,7 +440,7 @@ def all2all_pooled_sync(
        input_split_sizes = [D_local_sum * B_rank for B_rank in batch_size_per_rank]
        qcomm_ctx = None

-    with record_function("## alltoall_fwd_single ##"):
+    with record_function("## all2all_pooled ##"):
        sharded_output_embeddings = AllToAllSingle.apply(
            sharded_input_embeddings,
            output_split_sizes,
@@ -558,7 +558,7 @@ def variable_batch_all2all_pooled_sync(
        for split in input_split_sizes
    ]

-    with record_function("## alltoall_fwd_single ##"):
+    with record_function("## variable_batch_all2all_pooled ##"):
        if pg._get_backend_name() == "custom":
            sharded_output_embeddings = torch.empty(
                sum(output_split_sizes),
@@ -674,7 +674,7 @@ def all2all_sequence_sync(

    local_T = lengths_after_sparse_data_all2all.shape[0]
    if local_T > 0:
-        with record_function("## alltoall_seq_embedding_fwd_permute ##"):
+        with record_function("## all2all_sequence_permute ##"):
            if not variable_batch_size:
                (
                    permuted_lengths_after_sparse_data_all2all,
@@ -719,7 +719,7 @@ def all2all_sequence_sync(
    else:
        qcomm_ctx = None

-    with record_function("## alltoall_seq_embedding_fwd_single ##"):
+    with record_function("## all2all_sequence ##"):
        sharded_output_embeddings = AllToAllSingle.apply(
            sharded_input_embeddings,
            output_splits,
@@ -989,7 +989,7 @@ def reduce_scatter_v_sync(
        input = rsi.codecs.forward.encode(input)

    if rsi.equal_splits:
-        with record_function("## reduce_scatter_base ##"):
+        with record_function("## reduce_scatter_v ##"):
            output = torch.ops.torchrec.reduce_scatter_tensor(
                input,
                reduceOp="sum",
@@ -998,7 +998,7 @@ def reduce_scatter_v_sync(
                gradient_division=get_gradient_division(),
            )
    else:
-        with record_function("## reduce_scatter_v_via_all_to_all_single ##"):
+        with record_function("## reduce_scatter_v (AllToAllSingle) ##"):
            input_splits = rsi.input_splits
            output_splits = [rsi.input_splits[rank]] * world_size
            # TODO(ivankobzarev): Replace with _functional_collectives.reduce_scatter_v when it is added
@@ -1197,7 +1197,7 @@ def forward(
            device=sharded_input_embeddings.device,
        )

-        with record_function("## alltoall_fwd_single ##"):
+        with record_function("## All2All_Pooled_fwd ##"):
            req = dist.all_to_all_single(
                output=sharded_output_embeddings,
                input=sharded_input_embeddings,
@@ -1218,7 +1218,6 @@ def forward(

    @staticmethod
    # pyre-fixme[2]: Parameter must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
    def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
        pg = ctx.pg
        my_rank = dist.get_rank(pg)
@@ -1360,7 +1359,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
            device=sharded_grad_output.device,
            dtype=sharded_grad_output.dtype,
        )
-        with record_function("## alltoall_bwd_single ##"):
+        with record_function("## All2All_Pooled_bwd ##"):
            req = dist.all_to_all_single(
                output=sharded_grad_input,
                input=sharded_grad_output,
@@ -1445,7 +1444,7 @@ def forward(
            device=sharded_input_embeddings.device,
        )

-        with record_function("## alltoall_fwd_single ##"):
+        with record_function("## Variable_Batch_All2All_Pooled_fwd ##"):
            req = dist.all_to_all_single(
                output=sharded_output_embeddings,
                input=sharded_input_embeddings,
@@ -1564,7 +1563,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
            device=sharded_grad_output.device,
            dtype=sharded_grad_output.dtype,
        )
-        with record_function("## alltoall_bwd_single ##"):
+        with record_function("## Variable_Batch_All2All_Pooled_bwd ##"):
            req = dist.all_to_all_single(
                output=sharded_grad_input,
                input=sharded_grad_output,
@@ -1605,7 +1604,7 @@ def forward(

        local_T = lengths_after_sparse_data_all2all.shape[0]
        if local_T > 0:
-            with record_function("## alltoall_seq_embedding_fwd_permute ##"):
+            with record_function("## All2All_Seq_fwd_permute ##"):
                if not variable_batch_size:
                    (
                        permuted_lengths_after_sparse_data_all2all,
@@ -1659,7 +1658,7 @@ def forward(
            device=sharded_input_embeddings.device,
        )

-        with record_function("## alltoall_seq_embedding_fwd_single ##"):
+        with record_function("## All2All_Seq_fwd ##"):
            req = dist.all_to_all_single(
                output=sharded_output_embeddings,
                input=sharded_input_embeddings,
@@ -1707,7 +1706,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
        myreq.dummy_tensor = None

        if permuted_lengths_after_sparse_data_all2all is not None:
-            with record_function("## alltoall_seq_embedding_bwd_permute ##"):
+            with record_function("## All2All_Seq_bwd_permute ##"):
                if not variable_batch_size:
                    _, sharded_grad_input, _ = torch.ops.fbgemm.permute_2D_sparse_data(
                        backward_recat_tensor,
@@ -1788,7 +1787,7 @@ def backward(ctx, sharded_grad_output: Tensor) -> Tuple[None, None, Tensor]:
            device=sharded_grad_output.device,
            dtype=sharded_grad_output.dtype,
        )
-        with record_function("## alltoall_seq_embedding_bwd_single ##"):
+        with record_function("## All2All_Seq_bwd ##"):
            req = dist.all_to_all_single(
                output=sharded_grad_input,
                input=sharded_grad_output.view(-1),
@@ -1822,7 +1821,7 @@ def forward(
            input = a2ai.codecs.forward.encode(input)

        output = input.new_empty(sum(output_split_sizes))
-        with record_function("## alltoallv_bwd_single ##"):
+        with record_function("## All2Allv_fwd ##"):
            req = dist.all_to_all_single(
                output,
                input,
@@ -1908,7 +1907,7 @@ def backward(ctx, *grad_outputs) -> Tuple[None, None, Tensor]:
        grad_outputs = [gout.contiguous().view([-1]) for gout in grad_outputs]
        grad_output = torch.cat(grad_outputs)
        grad_input = grad_output.new_empty([a2ai.B_global * sum(a2ai.D_local_list)])
-        with record_function("## alltoall_bwd_single ##"):
+        with record_function("## All2Allv_bwd ##"):
            req = dist.all_to_all_single(
                grad_input,
                grad_output,
@@ -1944,7 +1943,7 @@ def forward(
            dtype=inputs[my_rank].dtype,
            device=inputs[my_rank].device,
        )
-        with record_function("## reduce_scatter ##"):
+        with record_function("## ReduceScatter_fwd ##"):
            req = dist.reduce_scatter(
                output,
                list(inputs),
@@ -2023,7 +2022,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
            for in_size in rsi.input_sizes
        ]

-        with record_function("## reduce_scatter_bw (all_gather) ##"):
+        with record_function("## ReduceScatter_bwd (all_gather) ##"):
            req = dist.all_gather(
                grad_inputs,
                grad_output.contiguous(),
@@ -2051,7 +2050,7 @@ def forward(
        if rsi.codecs is not None:
            inputs = rsi.codecs.forward.encode(inputs)
        output = inputs.new_empty((inputs.size(0) // my_size, inputs.size(1)))
-        with record_function("## reduce_scatter_tensor ##"):
+        with record_function("## ReduceScatterBase_fwd (tensor) ##"):
            req = dist.reduce_scatter_tensor(
                output,
                inputs,
@@ -2119,7 +2118,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
        if rsi.codecs is not None:
            grad_output = rsi.codecs.backward.encode(grad_output)
        grad_inputs = grad_output.new_empty(rsi.input_sizes)
-        with record_function("## reduce_scatter_base_bw (all_gather) ##"):
+        with record_function("## ReduceScatterBase_bwd (all_gather) ##"):
            req = dist.all_gather_into_tensor(
                grad_inputs,
                grad_output.contiguous(),
@@ -2148,7 +2147,7 @@ def forward(
            input = agi.codecs.forward.encode(input)

        outputs = input.new_empty((input.size(0) * my_size, input.size(1)))
-        with record_function("## all_gather_into_tensor ##"):
+        with record_function("## AllGatherBase_fwd (into_tensor) ##"):
            req = dist.all_gather_into_tensor(
                outputs,
                input,
@@ -2216,7 +2215,7 @@ def backward(ctx, grad_outputs: Tensor) -> Tuple[None, None, Tensor]:
        if agi.codecs is not None:
            grad_outputs = agi.codecs.backward.encode(grad_outputs)
        grad_input = grad_outputs.new_empty(agi.input_size)
-        with record_function("## all_gather_base_bw (reduce_scatter) ##"):
+        with record_function("## AllGatherBase_bw (reduce_scatter_tensor) ##"):
            req = dist.reduce_scatter_tensor(
                grad_input,
                grad_outputs.contiguous(),
@@ -2250,15 +2249,15 @@ def forward(
        # Use dist.reduce_scatter_tensor when a vector reduce-scatter is not needed
        # else use dist.reduce_scatter which internally supports vector reduce-scatter
        if rsi.equal_splits:
-            with record_function("## reduce_scatter_tensor ##"):
+            with record_function("## ReduceScatterV_fwd (tensor) ##"):
                req = dist.reduce_scatter_tensor(
                    output,
                    input,
                    group=pg,
                    async_op=True,
                )
        else:
-            with record_function("## reduce_scatter_v ##"):
+            with record_function("## ReduceScatterV_fwd ##"):
                req = dist.reduce_scatter(
                    output,
                    list(torch.split(input, rsi.input_splits)),
@@ -2331,15 +2330,15 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
        grad_input = grad_output.new_empty(rsi.total_input_size)

        if rsi.equal_splits:
-            with record_function("## reduce_scatter_base_bw (all_gather) ##"):
+            with record_function("## ReduceScatterV_bwd (all_gather) ##"):
                req = dist.all_gather_into_tensor(
                    grad_input,
                    grad_output.contiguous(),
                    group=ctx.pg,
                    async_op=True,
                )
        else:
-            with record_function("## reduce_scatter_v_bw (all_gather_v) ##"):
+            with record_function("## ReduceScatterV_bwd (all_gather_v) ##"):
                req = dist.all_gather(
                    list(torch.split(grad_input, rsi.input_splits)),
                    grad_output.contiguous(),
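For context, a minimal standalone sketch (not part of this commit) of how the renamed record_function labels surface when profiling. The label string is taken from the diff above; the toy workload is a stand-in for the real collective call, assuming a CPU-only profiling run.

# Illustrative sketch only: the profiler APIs below are standard PyTorch,
# but the workload is a placeholder, not the TorchRec collective itself.
import torch
from torch.profiler import ProfilerActivity, profile, record_function

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("## all2all_pooled ##"):
        # Stand-in work; in TorchRec this region wraps the all_to_all_single call.
        torch.randn(1024, 128).sum()

# The renamed label appears as its own row in the profiler summary,
# which is what makes the more descriptive names easier to spot in traces.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))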