Support fp32 accumulation for bf16 gemm and grouped gemm #482
@rolandschulz @tdeng5 Could you please review this? Thanks.
@jiyang1011 Could you please review this? Thanks.
82f8e45 to 838aae9
The 16-bit accumulator MMA is what the UT covers. Suppose ElementAccumulator were hard-coded to float: that is obviously not reasonable.
Previously, a bf16-in, bf16-out gemm used a bf16 accumulator; after this patch, an fp32 accumulator is used for better accuracy. This assumes that C and D have the same dtype. Signed-off-by: Wuxun Zhang <[email protected]>
dd3a205 to 0d733e8
It now supports a different accumulator dtype. Please check the latest commits.
@jiyang1011 @taozha2 Can you please help trigger the CI tests here?
using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
using ElementC = typename Gemm::ElementC;
using ElementC = typename CollectiveEpilogue::ElementOutput;
I think this change is wrong. Why change this? Did you check that it is still correct when the dtypes of C and the output differ?
I think this issue can be detected by the pre-CI checks.
It assumes that C and D have the same dtype here. I also think we need to support more dtype combinations.
Why change it if you assume C and D are the same? We should add a static_assert if it only works under that assumption.
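A minimal sketch of the kind of static_assert suggested here. EpilogueSketch and same_cd_dtype are hypothetical names used only for illustration; they are not the CUTLASS epilogue class or anything taken from this PR.

```cpp
#include <type_traits>

// Hypothetical helper: true when the C (source) and D (output) dtypes match.
template <class ElementC, class ElementD>
constexpr bool same_cd_dtype() {
  return std::is_same<ElementC, ElementD>::value;
}

// Hypothetical stand-in for an epilogue that only works when C and D share
// a dtype. Instantiating it with mismatched types fails at compile time
// with a readable message instead of silently producing wrong results.
template <class ElementC_, class ElementD_>
struct EpilogueSketch {
  using ElementC = ElementC_;
  using ElementD = ElementD_;

  static_assert(same_cd_dtype<ElementC, ElementD>(),
                "This epilogue path assumes C and D have the same dtype");
};
```

With this in place, something like EpilogueSketch<float, float> compiles, while EpilogueSketch<float, double> is rejected by the compiler, so the assumption is enforced rather than left implicit.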
I see now that this is the new feature. Thank you!
using ElementC = typename FusionCallbacks::ElementSource;
using ElementAccumulator = ElementC_;
using TiledMma = typename CollectiveMainloop::TiledMma;
using EpilogueOp = epilogue::fusion::LinearCombination<float, float>;
Currently, while the case of BF16 A, B and FP32 C with BF16 D/output is supported in the main branch, the epilogue::fusion::LinearCombination API usage at this line is non-intuitive: the main branch uses a hacky approach that deviates from the intended/documented use of this API, since its first template parameter is intended to be the output dtype.
Currently, the unwritten/implicit contract for this code in the main branch seems to be:
- intuitively thinking of it as computing D = alpha * Accum + beta * C in float,
- and then setting the correct ElementD parameter in the cutlass::epilogue::collective::CollectiveEpilogue, which can be thought of as converting to the correct output dtype (which is ElementOutput in this file).
It seems that once this PR is ready, it will rectify the API usage of cutlass::epilogue::fusion::LinearCombination in this repo.
Thanks!
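The implicit contract described in this comment (compute in float, convert to the output dtype as a separate step) can be sketched on the host as follows. This is only an illustration under stated assumptions: round_to_bf16 is a hypothetical software emulation of bf16 rounding (round to nearest even, NaN handling omitted), and neither function is part of CUTLASS.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical bf16 emulation: keep the upper 16 bits of an IEEE-754 float,
// rounding to nearest even. The result is returned as a float whose low
// 16 bits are zero. NaN handling is omitted for brevity.
inline float round_to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  const std::uint32_t lsb = (bits >> 16) & 1u;  // lowest bit that is kept
  bits += 0x7FFFu + lsb;                        // round to nearest, ties to even
  bits &= 0xFFFF0000u;                          // drop the low 16 bits
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

// Step 1: D = alpha * Accum + beta * C, computed entirely in float.
inline float epilogue_fp32_compute(float alpha, float accum, float beta, float c) {
  return alpha * accum + beta * c;
}

// Step 2: convert the float result to the output dtype (bf16 in this sketch).
inline float epilogue_to_bf16(float alpha, float accum, float beta, float c) {
  return round_to_bf16(epilogue_fp32_compute(alpha, accum, beta, c));
}
```

For example, with alpha = beta = 1, Accum = 1 and C = 2^-8, the float result 1.00390625 is exact, and only the final conversion rounds it to 1.0, which is the nearest bf16 value under ties-to-even.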
In the current implementation, if users want a bf16-in, bf16-out gemm, the MMA atom XE_8x16x16_BF16BF16BF16BF16_TT is used and accumulation happens in bf16. To preserve good accuracy, however, fp32 accumulation is needed. This PR supports that by adding a dtype conversion in the epilogue. It adds the assumption that C and D have the same dtype.
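To illustrate why fp32 accumulation matters for accuracy, here is a small host-side sketch comparing a dot product accumulated in emulated bf16 against one accumulated in fp32 and converted to bf16 once at the end. It is not the PR's kernel code: round_to_bf16 and both dot_* functions are hypothetical, and real MMA hardware may round intermediate results differently.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical bf16 emulation: round a float to the nearest value that is
// representable with a bf16 significand (ties to even). NaN handling omitted.
inline float round_to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  const std::uint32_t lsb = (bits >> 16) & 1u;
  bits += 0x7FFFu + lsb;
  bits &= 0xFFFF0000u;
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

// Accumulate in bf16: every partial sum is rounded back to bf16.
inline float dot_bf16_accum(const float* a, const float* b, int k) {
  float acc = 0.0f;
  for (int i = 0; i < k; ++i) {
    acc = round_to_bf16(acc + round_to_bf16(a[i] * b[i]));
  }
  return acc;
}

// Accumulate in fp32 and convert to bf16 only once, in the "epilogue".
inline float dot_fp32_accum_bf16_out(const float* a, const float* b, int k) {
  float acc = 0.0f;
  for (int i = 0; i < k; ++i) {
    acc += a[i] * b[i];
  }
  return round_to_bf16(acc);
}
```

With 1024 elements that are all 1.0 in both inputs, the bf16 accumulator in this sketch stalls at 256, because 256 + 1 rounds back to 256 with an 8-bit significand, while the fp32 accumulator reaches 1024, which converts to bf16 exactly.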