@ilmarkov commented Sep 2, 2025

Enable torch symm memory for TP allreduce by default.
Add a testing option to custom allreduce to choose between the one-shot and two-shot algorithms.
Add a benchmark comparing NCCL, custom allreduce, and torch symm mem allreduce.
Algorithms are dispatched based on input size on CUDA Hopper and Blackwell devices, selecting the best-performing of the existing allreduce algorithms.

E2E results are presented in the original PR #20759 (up to 10% TTFT improvement for TP=8).
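
The size- and world-size-based dispatch can be sketched roughly as follows. This is an illustrative sketch, not the actual vLLM implementation: the function name and the thresholds are hypothetical, loosely read off the H100 tables below, and the real dispatch also differs per device generation.

```python
# Illustrative sketch of size-based allreduce dispatch (hypothetical
# thresholds loosely derived from the benchmark tables in this PR;
# not the actual vLLM code).
MB = 1024 * 1024

def pick_allreduce_algo(num_bytes: int, world_size: int) -> str:
    """Pick an allreduce backend for a given input size and world size."""
    if world_size == 2:
        # One-shot custom allreduce wins across all measured sizes at TP=2 on H100.
        return "ca_1stage"
    if world_size in (4, 6, 8):
        if num_bytes <= 1 * MB:
            return "ca_1stage"          # latency-bound regime
        if num_bytes <= 32 * MB:
            return "symm_mem_multimem"  # medium sizes favor multimem
        return "pynccl"                 # large inputs: fall back to NCCL
    return "pynccl"                     # unsupported world sizes: NCCL
```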

Isolated primitives benchmark results:

H100

TP=2

All times are in milliseconds (ms) per allreduce operation
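
The "Tensor Size" column in the tables is just shape times element size. A small helper (assuming bfloat16 at 2 bytes per element, per the data type above) reproduces it:

```python
def tensor_size_mb(shape, bytes_per_elem=2):
    """Size in MB of a tensor with the given shape (bfloat16 = 2 bytes)."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * bytes_per_elem / (1024 * 1024)
```

For example, `tensor_size_mb((32, 8192))` gives 0.5 MB and `tensor_size_mb((3062, 8192))` gives 47.84 MB, matching the rows below.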

==================================================================================================================================
Device Communicator Benchmark Results
World Size: 2, Data Type: torch.bfloat16, Hidden Size: 8192
==================================================================================================================================
Tensor Shape        Tensor Size    ca_1stage           ca_2stage           pynccl              symm_mem_multimem   symm_mem_two_shot   Best (Speedup vs PyNccl)      
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
(32, 8192)          0.50 MB        0.007               0.012               0.011               0.014               0.014               ca_1stage (1.63x)             
(64, 8192)          1.00 MB        0.009               0.014               0.017               0.016               0.016               ca_1stage (1.90x)             
(96, 8192)          1.50 MB        0.011               0.017               0.020               0.020               0.018               ca_1stage (1.85x)             
(128, 8192)         2.00 MB        0.013               0.019               0.037               0.023               0.020               ca_1stage (2.82x)             
(192, 8192)         3.00 MB        0.016               0.023               0.038               0.028               0.024               ca_1stage (2.29x)             
(256, 8192)         4.00 MB        0.021               0.028               0.040               0.034               0.028               ca_1stage (1.96x)             
(1024, 8192)        16.00 MB       0.070               0.080               0.087               0.112               0.084               ca_1stage (1.23x)             
(2048, 8192)        32.00 MB       0.128               0.167               0.150               0.233               0.161               ca_1stage (1.17x)             
(3062, 8192)        47.84 MB       0.188               0.246               0.203               0.346               0.236               ca_1stage (1.08x)             
(4096, 8192)        64.00 MB       0.250               0.326               0.262               0.463               0.312               ca_1stage (1.05x)             
==================================================================================================================================
TP=4

Tensor Shape        Tensor Size    ca_1stage           ca_2stage           pynccl              symm_mem_multimem   symm_mem_two_shot   Best (Speedup vs PyNccl)      
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
(32, 8192)          0.50 MB        0.010               0.013               0.016               0.013               0.016               ca_1stage (1.61x)             
(64, 8192)          1.00 MB        0.015               0.016               0.022               0.016               0.018               ca_1stage (1.48x)             
(96, 8192)          1.50 MB        0.020               0.020               0.025               0.018               0.022               symm_mem_multimem (1.39x)     
(128, 8192)         2.00 MB        0.025               0.021               0.032               0.021               0.025               symm_mem_multimem (1.53x)     
(192, 8192)         3.00 MB        0.034               0.026               0.039               0.026               0.029               symm_mem_multimem (1.50x)     
(256, 8192)         4.00 MB        0.044               0.032               0.048               0.031               0.037               symm_mem_multimem (1.55x)     
(1024, 8192)        16.00 MB       0.157               0.097               0.114               0.100               0.119               ca_2stage (1.17x)             
(2048, 8192)        32.00 MB       0.307               0.184               0.188               0.191               0.229               ca_2stage (1.02x)             
(3062, 8192)        47.84 MB       0.457               0.272               0.258               0.280               0.333               pynccl (1.00x)                
(4096, 8192)        64.00 MB       0.608               0.360               0.341               0.372               0.445               pynccl (1.00x)                
==================================================================================================================================

In the TP=4 case, for input sizes between 1 MB and 32 MB we could use symm_mem_multimem, but its performance is close to that of two-stage CA, so we use CA for all inputs below 32 MB.

TP=6

Tensor Shape        Tensor Size    ca_1stage           ca_2stage           pynccl              symm_mem_multimem   symm_mem_two_shot   Best (Speedup vs PyNccl)      
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
(32, 8192)          0.50 MB        0.013               0.015               0.023               0.013               0.017               ca_1stage (1.71x)             
(64, 8192)          1.00 MB        0.021               0.017               0.025               0.016               0.021               symm_mem_multimem (1.60x)     
(96, 8192)          1.50 MB        0.030               0.020               0.032               0.018               0.024               symm_mem_multimem (1.75x)     
(128, 8192)         2.00 MB        0.038               0.026               0.043               0.021               0.029               symm_mem_multimem (2.05x)     
(192, 8192)         3.00 MB        0.053               0.029               0.047               0.025               0.035               symm_mem_multimem (1.84x)     
(256, 8192)         4.00 MB        0.070               0.037               0.053               0.031               0.041               symm_mem_multimem (1.71x)     
(1024, 8192)        16.00 MB       0.267               0.110               0.131               0.098               0.131               symm_mem_multimem (1.34x)     
(2048, 8192)        32.00 MB       0.520               0.212               0.212               0.189               0.250               symm_mem_multimem (1.12x)     
(3062, 8192)        47.84 MB       0.786               0.318               0.296               0.275               0.367               symm_mem_multimem (1.07x)     
(4096, 8192)        64.00 MB       1.040               0.417               0.344               0.366               0.486               pynccl (1.00x)                
==================================================================================================================================

TP=8 to be updated

B200

TP=2

Tensor Shape        Tensor Size    ca_1stage           ca_2stage           pynccl              symm_mem_multimem   symm_mem_two_shot   Best (Speedup vs PyNccl)      
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
(32, 8192)          0.50 MB        0.009               0.013               0.014               0.016               0.016               ca_1stage (1.54x)             
(64, 8192)          1.00 MB        0.013               0.018               0.015               0.021               0.018               ca_1stage (1.23x)             
(96, 8192)          1.50 MB        0.015               0.021               0.017               0.025               0.019               ca_1stage (1.10x)             
(128, 8192)         2.00 MB        0.019               0.025               0.019               0.029               0.022               ca_1stage (1.02x)             
(192, 8192)         3.00 MB        0.026               0.031               0.025               0.038               0.025               symm_mem_two_shot (1.02x)     
(256, 8192)         4.00 MB        0.032               0.038               0.037               0.047               0.029               symm_mem_two_shot (1.30x)     
(1024, 8192)        16.00 MB       0.113               0.119               0.054               0.150               0.072               pynccl (1.00x)                
(2048, 8192)        32.00 MB       0.218               0.233               0.085               0.293               0.133               pynccl (1.00x)                
(3062, 8192)        47.84 MB       0.326               0.360               0.118               0.440               0.197               pynccl (1.00x)                
(4096, 8192)        64.00 MB       0.427               0.496               0.156               0.615               0.270               pynccl (1.00x)
==================================================================================================================================
TP=4

Tensor Shape        Tensor Size    ca_1stage           ca_2stage           pynccl              symm_mem_multimem   symm_mem_two_shot   Best (Speedup vs PyNccl)      
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
(32, 8192)          0.50 MB        0.010               0.017               0.017               0.016               0.018               ca_1stage (1.68x)             
(64, 8192)          1.00 MB        0.015               0.018               0.019               0.018               0.019               ca_1stage (1.33x)             
(96, 8192)          1.50 MB        0.019               0.027               0.022               0.021               0.019               ca_1stage (1.19x)             
(128, 8192)         2.00 MB        0.022               0.027               0.026               0.023               0.023               ca_1stage (1.14x)             
(192, 8192)         3.00 MB        0.030               0.035               0.031               0.028               0.025               symm_mem_two_shot (1.25x)     
(256, 8192)         4.00 MB        0.038               0.043               0.037               0.033               0.029               symm_mem_two_shot (1.29x)     
(1024, 8192)        16.00 MB       0.126               0.131               0.076               0.091               0.064               symm_mem_two_shot (1.19x)     
(2048, 8192)        32.00 MB       0.246               0.243               0.121               0.173               0.116               symm_mem_two_shot (1.04x)     
(3062, 8192)        47.84 MB       0.365               0.363               0.151               0.256               0.169               pynccl (1.00x)                
(4096, 8192)        64.00 MB       0.490               0.482               0.189               0.346               0.226               pynccl (1.00x)
==================================================================================================================================
TP=6

Tensor Shape        Tensor Size    ca_1stage           ca_2stage           pynccl              symm_mem_multimem   symm_mem_two_shot   Best (Speedup vs PyNccl)      
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
(32, 8192)          0.50 MB        0.012               0.022               0.025               0.016               0.023               ca_1stage (2.14x)             
(64, 8192)          1.00 MB        0.017               0.022               0.026               0.018               0.023               ca_1stage (1.58x)             
(96, 8192)          1.50 MB        0.022               0.023               0.027               0.019               0.024               symm_mem_multimem (1.45x)     
(128, 8192)         2.00 MB        0.026               0.035               0.031               0.021               0.026               symm_mem_multimem (1.45x)     
(192, 8192)         3.00 MB        0.035               0.036               0.035               0.024               0.035               symm_mem_multimem (1.43x)     
(256, 8192)         4.00 MB        0.044               0.048               0.052               0.029               0.037               symm_mem_multimem (1.81x)     
(1024, 8192)        16.00 MB       0.149               0.137               0.102               0.070               0.099               symm_mem_multimem (1.46x)     
(2048, 8192)        32.00 MB       0.291               0.249               0.180               0.132               0.175               symm_mem_multimem (1.37x)     
(3062, 8192)        47.84 MB       0.427               0.372               0.241               0.192               0.254               symm_mem_multimem (1.25x)     
(4096, 8192)        64.00 MB       0.568               0.487               0.310               0.259               0.334               symm_mem_multimem (1.20x)
==================================================================================================================================
TP=8

Tensor Shape        Tensor Size    ca_1stage           ca_2stage           pynccl              symm_mem_multimem   symm_mem_two_shot   Best (Speedup vs PyNccl)      
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
(32, 8192)          0.50 MB        0.013               0.025               0.028               0.016               0.023               ca_1stage (2.12x)             
(64, 8192)          1.00 MB        0.020               0.026               0.030               0.017               0.024               symm_mem_multimem (1.78x)     
(96, 8192)          1.50 MB        0.026               0.027               0.032               0.019               0.024               symm_mem_multimem (1.68x)     
(128, 8192)         2.00 MB        0.031               0.028               0.034               0.020               0.025               symm_mem_multimem (1.74x)     
(192, 8192)         3.00 MB        0.043               0.044               0.046               0.023               0.026               symm_mem_multimem (2.04x)     
(256, 8192)         4.00 MB        0.056               0.046               0.054               0.026               0.035               symm_mem_multimem (2.08x)     
(1024, 8192)        16.00 MB       0.197               0.144               0.101               0.060               0.080               symm_mem_multimem (1.68x)     
(2048, 8192)        32.00 MB       0.389               0.261               0.169               0.109               0.142               symm_mem_multimem (1.55x)     
(3062, 8192)        47.84 MB       0.564               0.383               0.222               0.160               0.208               symm_mem_multimem (1.39x)     
(4096, 8192)        64.00 MB       0.748               0.499               0.278               0.218               0.284               symm_mem_multimem (1.28x)
==================================================================================================================================
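
A minimal timing loop of the kind such benchmarks use might look like the sketch below. This is an assumption about the harness shape, not the PR's actual benchmark code: the real benchmark would synchronize the GPU (or use CUDA events) around each collective, whereas this just times a callable.

```python
import time

def bench_ms(fn, warmup=5, iters=50):
    """Return mean wall-clock time per call of fn, in milliseconds.

    Sketch only: a GPU collective benchmark must synchronize the device
    (e.g. CUDA events) so the timed region covers the actual kernel.
    """
    for _ in range(warmup):   # warmup iterations excluded from timing
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) * 1000 / iters
```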

Add benchmark

Signed-off-by: ilmarkov <[email protected]>
@mergify bot added the performance label Sep 2, 2025
@ilmarkov marked this pull request as ready for review September 3, 2025 12:29
@robertgshaw2-redhat changed the title Enable torch symmetric memory be default [Kernels][AR] Enable torch symmetric memory be default Sep 3, 2025
@robertgshaw2-redhat changed the title [Kernels][AR] Enable torch symmetric memory be default [Kernels][AR] Enable Torch Symmetric Memory By Default Sep 3, 2025

@nvpohanh commented Sep 5, 2025

@ilmarkov can torch symm memory also be used to implement Alltoall/AllGather/ReduceScatter? We currently see that those comm ops are quite slow when Attention DP is used, and I wonder if we can easily extend these AR implementations to those ops. Thanks!

cc @weireweire


@ilmarkov commented Sep 6, 2025

@nvpohanh Yes, torch symm mem supports these primitives (link). I haven't benchmarked them, though.
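
For context, the two-shot allreduce measured in the tables above is conceptually a reduce-scatter followed by an all-gather, which is why the same symmetric-memory machinery covers those primitives too. A rank-simulation sketch of that decomposition (pure Python, no GPUs; purely illustrative, not the symm mem kernel):

```python
def two_shot_allreduce(rank_inputs):
    """Simulate a two-shot (reduce-scatter + all-gather) allreduce.

    rank_inputs[r] is rank r's local vector; every rank ends up with
    the elementwise sum across ranks.
    """
    world = len(rank_inputs)
    n = len(rank_inputs[0])
    shard = n // world  # assumes n is divisible by the world size
    # Stage 1: reduce-scatter -- each rank reduces only its own shard.
    shards = []
    for r in range(world):
        lo = r * shard
        shards.append([sum(inp[lo + i] for inp in rank_inputs)
                       for i in range(shard)])
    # Stage 2: all-gather -- every rank collects all reduced shards.
    full = [x for s in shards for x in s]
    return [list(full) for _ in range(world)]
```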

@mgoin changed the title [Kernels][AR] Enable Torch Symmetric Memory By Default [Kernels] Enable Torch Symmetric Memory All-Reduce By Default Sep 10, 2025
@mgoin left a comment

Nice work, LGTM!

Comment on lines +34 to +36
# add options for testing
force_multimem: Optional[bool] = None,
max_size_override: Optional[int] = None):

nit: these don't actually look to be used?

@mgoin added the ready label Sep 10, 2025