[Kernels] Enable Torch Symmetric Memory All-Reduce By Default #24111
Conversation
Add benchmark Signed-off-by: ilmarkov <[email protected]>
@ilmarkov can torch symm memory also be used to write Alltoall/AllGather/ReduceScatter? We currently see that those comm ops are quite slow when Attention DP is used, and I wonder if we can easily extend these AR implementations for those ops. Thanks! cc @weireweire
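For context on that question, below is a minimal sketch of how the torch symmetric-memory handle exposes peer buffers, which is the building block one would start from when writing AllGather/ReduceScatter-style collectives on top of symm mem. It assumes a recent PyTorch (roughly 2.5+) `_symmetric_memory` API and a `torchrun` launch with one rank per GPU; it is not vLLM's implementation.

```python
# Minimal sketch, assuming torch >= 2.5 symmetric-memory APIs; run with torchrun, one rank per GPU.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)
group_name = dist.group.WORLD.group_name

numel = 1 << 20
# Allocate a buffer every rank can map, then exchange handles across the group.
buf = symm_mem.empty(numel, dtype=torch.bfloat16, device="cuda")
hdl = symm_mem.rendezvous(buf, group_name)

# Naive all-gather built from direct peer-buffer reads: each rank writes its shard
# into its own symmetric buffer, then reads every peer's buffer through the handle.
buf.fill_(float(rank))
hdl.barrier()  # make local writes visible before peers read
gathered = torch.stack(
    [hdl.get_buffer(r, (numel,), torch.bfloat16).clone() for r in range(world)]
)
hdl.barrier()  # do not reuse the buffer until all peers have finished reading

dist.destroy_process_group()
```

A production ReduceScatter/AllGather would additionally overlap the peer copies and, on NVLS-capable GPUs, use the multicast path rather than per-peer reads.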
Nice work, LGTM!
# add options for testing
force_multimem: Optional[bool] = None,
max_size_override: Optional[int] = None):
nit: these don't actually look to be used?
Enable torch symm memory for TP allreduce by default.
Add a testing option to custom allreduce to choose between the one-shot and two-shot algorithms.
Add a benchmark to compare NCCL, custom allreduce, and torch symm mem allreduce (a sketch of the symm mem primitive follows below).
Dispatching between the algorithms is based on input size on CUDA Hopper and Blackwell devices, so that the best-performing of the existing allreduce algorithms is used for each size range.
E2E results are presented in the original PR #20759
(Up to 10% TTFT improvement for TP=8)
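For readers unfamiliar with the primitive, here is a rough sketch of what "torch symm mem allreduce" refers to, assuming the torch >= 2.5 `symm_mem` ops. The `force_multimem` toggle loosely mirrors the testing option quoted in the review snippet above; it is illustrative, not vLLM's exact signature, and a real communicator would reuse a pre-rendezvoused buffer instead of allocating one per call.

```python
# Hedged sketch, not vLLM's implementation: the torch symmetric-memory all-reduce
# ops this PR turns on by default, plus a toggle between the two variants.
import torch
import torch.distributed._symmetric_memory as symm_mem

def symm_mem_all_reduce(inp: torch.Tensor, group_name: str,
                        force_multimem: bool = False) -> torch.Tensor:
    # Stage the input in a symmetric-memory buffer shared across the TP group.
    # (A real communicator would allocate and rendezvous this buffer once, up front.)
    buf = symm_mem.empty(inp.numel(), dtype=inp.dtype, device=inp.device)
    symm_mem.rendezvous(buf, group_name)
    buf.copy_(inp.view(-1))

    if force_multimem:
        # One-shot, multicast (NVLS) based variant available on Hopper/Blackwell.
        torch.ops.symm_mem.multimem_all_reduce_(buf, "sum", group_name)
    else:
        # Two-shot variant (reduce-scatter then all-gather over the symmetric buffer).
        torch.ops.symm_mem.two_shot_all_reduce_(buf, "sum", group_name)
    return buf.view_as(inp).clone()
```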
Isolated primitives benchmark results:
H100
In the TP=4 case, for input sizes between 1 and 32 MB we could use symm_mem_multimem, but its performance is close to the CA two-stage algorithm, so we use CA for all inputs below 32 MB. TP=8 results to be updated.
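A sketch of what size-based dispatch along those lines could look like; the 32 MB cutoff is taken from the H100/TP=4 note above, and `ca_comm` / `symm_mem_comm` are hypothetical placeholders rather than vLLM's actual attribute names.

```python
# Illustrative dispatch only; the threshold and communicator objects are assumptions.
import torch
import torch.distributed as dist

CA_MAX_SIZE_BYTES = 32 * 1024 * 1024  # below this, custom allreduce (CA) wins on H100/TP=4

ca_comm = None        # hypothetical custom-allreduce communicator
symm_mem_comm = None  # hypothetical torch symm-mem communicator

def tp_all_reduce(inp: torch.Tensor) -> torch.Tensor:
    nbytes = inp.numel() * inp.element_size()
    if ca_comm is not None and nbytes <= CA_MAX_SIZE_BYTES:
        return ca_comm.all_reduce(inp)        # CA two-stage for small/medium inputs
    if symm_mem_comm is not None:
        return symm_mem_comm.all_reduce(inp)  # torch symm mem for larger inputs
    out = inp.clone()
    dist.all_reduce(out)                      # NCCL fallback
    return out
```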
B200