Conversation

wuxun-zhang

In the current implementation, if users want a bf16-in, bf16-out GEMM, the MMA atom XE_8x16x16_BF16BF16BF16BF16_TT is used and accumulation happens in bf16. To preserve accuracy, fp32 accumulation is needed. This PR supports that by adding a dtype conversion in the epilogue. It adds the assumption that C and D have the same dtype.

@wuxun-zhang
Author

@rolandschulz @tdeng5 Could you please review this? Thanks.

@tdeng5 tdeng5 requested a review from taozha2 August 29, 2025 05:26
@taozha2 taozha2 requested a review from jiyang1011 August 29, 2025 05:44
@wuxun-zhang
Author

@jiyang1011 Could you please review this? Thanks.

@jiyang1011

https://github.com/intel/cutlass-sycl/blob/main/test/unit/gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp16.cpp

This is the unit test for a 16-bit-accumulator MMA. If ElementAccumulator is hard-coded to float there, that is clearly not reasonable.

Previously, a bf16-in, bf16-out GEMM used a bf16 accumulator; after this patch,
an fp32 accumulator is used for better accuracy.
There is an assumption that C and D have the same dtype.

Signed-off-by: Wuxun Zhang <[email protected]>
@wuxun-zhang
Author

https://github.com/intel/cutlass-sycl/blob/main/test/unit/gemm/device/xe_gemm_fp16_fp16_fp16_tensor_op_fp16.cpp

This is the unit test for a 16-bit-accumulator MMA. If ElementAccumulator is hard-coded to float there, that is clearly not reasonable.

A different accumulator dtype is now supported. Please check the latest commits.

@wuxun-zhang
Author

@jiyang1011 @taozha2 Could you please help trigger the CI tests here?

@taozha2 taozha2 requested a review from rolandschulz September 4, 2025 00:47

  using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
- using ElementC = typename Gemm::ElementC;
+ using ElementC = typename CollectiveEpilogue::ElementOutput;


I think this change is wrong. Why change this? Did you check that it is correct when the dtypes of C and the output differ?


I think this issue can be detected by the pre-CI checks.

Author

It assumes C and D have the same dtype here. I also think we need to support more dtype combinations.


Why change it if you assume C and D are the same? We should add a static_assert if it only works under that assumption.

@sanchitintel

sanchitintel commented Sep 12, 2025

if users want a bf16-in, bf16-out GEMM, the MMA atom XE_8x16x16_BF16BF16BF16BF16_TT is used and accumulation happens in bf16. To preserve accuracy, fp32 accumulation is needed. This PR supports that by adding a dtype conversion in the epilogue.

XE_8x16x16_F32BF16BF16F32_TT already supports FP32 accumulation.
Support for changing output dtype to BF16 in epilogue is already present, except for GroupedGEMM.
I opened #505 & #506 to illustrate (you may have to manually compare the code with a diff tool such as BeyondCompare).

It adds the assumption that C and D have the same dtype.

I see now that this is the new feature.

Thank you!

Comment on lines 93 to 94
using ElementC = typename FusionCallbacks::ElementSource;
using ElementAccumulator = ElementC_;

@sanchitintel sanchitintel Sep 12, 2025


This config did not work with the latest commit of this branch. Is it currently supported? Thanks!

[screenshot omitted]


using TiledMma = typename CollectiveMainloop::TiledMma;

using EpilogueOp = epilogue::fusion::LinearCombination<float, float>;

@sanchitintel sanchitintel Sep 12, 2025


Currently, the BF16 A/B with FP32 C and BF16 D/output case is supported in the main branch, but the epilogue::fusion::LinearCombination usage at this line is non-intuitive: the main branch uses the API in a hacky way that deviates from its intended/documented use, since the first template parameter is meant to be the output dtype.

Currently, the unwritten/implicit contract for this code in the current main branch seems to be:

  1. intuitively thinking of it as computing D = alpha * Accum + beta * C in Float,
  2. and then setting the correct ElementD parameter in the cutlass::epilogue::collective::CollectiveEpilogue can be thought of as converting to the correct output dtype (which is ElementOutput in this file).

It seems that when this PR is ready, it will rectify the usage of cutlass::epilogue::fusion::LinearCombination in this repo.

Thanks!
