Skip to content

Commit 32e15ba

Browse files
authored
Add Grouped GEMM for Mixed Dtype (#457)
This PR adds Grouped GEMM support for mixed precision GEMM.
2 parents ce061da + e9d1004 commit 32e15ba

21 files changed

+2664
-108
lines changed

applications/dual_gemm/collective/xe_dual_gemm_mma.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ struct DualGemmMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, ElementA_
9898

9999
using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast<ElementA const*>(nullptr)), make_shape(0,0,0), StrideA{})); //(m, k)
100100
using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast<ElementB const*>(nullptr)), make_shape(0,0,0), StrideB{})); //(n, k)
101-
101+
using MainloopTensors = cute::tuple<TensorMKL, TensorNKL>;
102102
using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
103103

104104
// Host side kernel arguments

applications/dual_gemm/kernel/xe_dual_gemm.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class DualGemm
125125
126126
using TensorMKL = typename DualGemmMainloop::TensorMKL;
127127
using TensorNKL = typename DualGemmMainloop::TensorNKL;
128-
128+
using MainloopTensors = cute::tuple<TensorMKL, TensorNKL>;
129129
using TensorMK = decltype(TensorMKL{}(_, _, 0));
130130
using TensorNK = decltype(TensorNKL{}(_, _, 0));
131131
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
/***************************************************************************************************
2+
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
/*! \file
    \brief CUTLASS Intel BMG Grouped Gemm with mixed input types

    This example demonstrates how to dispatch a mixed-precision GEMM (int8 with bfloat16 or half_t)
    on BMG, with optional dequantization. The GemmMode enum describes the 3 modes of operation:

    - ConvertOnly:                  Narrower type is simply converted to the wider type before MMA
    - ConvertAndScale:              Narrower type is converted to wider type, then scaled
    - ConvertAndScaleWithZeroPoint: Narrower type is converted to wider type, scaled and offset

    Requirements:
    - dequantization group size (options.g) must be a multiple of the k-block size
    - scales & zeros must be MN-major

    The MMA operation itself takes bfloat16 input for both A and B, and so the narrower type is
    first upcasted (inside the mainloop) prior to being passed into the MMA atom.

    Verification for this example is performed against a standard reference GEMM in the wider type.
    The narrow-type input data are upcasted (or dequantized) externally before executing the
    reference GEMM.

    Note: due to a bug in the IGC compiler, it's currently necessary to build this example with the
    following environment variable set (CMake handles this for AOT compilation; for JIT, please set
    this in your environment):

    export IGC_allowDecompose2DBlockFuncs=0

    To build & run this example (from your build dir):

    $ ninja 10_bmg_grouped_gemm_bf16_s8
    $ ./examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_bf16_s8
    $ ninja 10_bmg_grouped_gemm_f16_s8_tensorwise
    $ ./examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_f16_s8_tensorwise

    Call with `--help` for information about available options
*/
67+
68+
#include "bmg_grouped_gemm_mixed_dtype_runner.hpp"
69+
70+
///////////////////////////////////////////////////////////////////////////////////////////////////
71+
72+
int main(int argc, const char** argv)
73+
{
74+
//
75+
// Parse options
76+
//
77+
78+
Options options;
79+
80+
options.parse(argc, argv);
81+
82+
if (options.help) {
83+
options.print_usage(std::cout) << std::endl;
84+
return 0;
85+
}
86+
87+
if (options.error) {
88+
std::cerr << "Aborting execution." << std::endl;
89+
return -1;
90+
}
91+
92+
//
93+
// Run examples
94+
//
95+
96+
// The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This
97+
// information is used by the underlying kernel.
98+
cutlass::KernelHardwareInfo hw_info;
99+
100+
// Change device_id to another value if you are running on a machine with multiple GPUs and wish
101+
// to use a GPU other than that with device ID 0.
102+
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
103+
104+
// The code section below describes datatype for input, output matrices and computation between
105+
// elements in input matrices.
106+
using ElementAccumulator = float; // <- data type of accumulator
107+
using ElementComputeEpilogue = float; // <- data type of epilogue operations
108+
using ElementInputA = cutlass::QUANT_TYPE; // <- data type of elements in input matrix A
109+
using ElementInputB = cutlass::MMA_TYPE; // <- data type of elements in input matrix B
110+
using ElementOutput = float; // <- data type of elements in output matrix D
111+
112+
using LayoutA = cutlass::layout::RowMajor;
113+
using LayoutB = cutlass::layout::RowMajor;
114+
using LayoutC = cutlass::layout::RowMajor;
115+
using LayoutD = cutlass::layout::RowMajor;
116+
117+
using ElementZero = cutlass::MMA_TYPE;
118+
using ElementScale = cutlass::MMA_TYPE;
119+
using StrideScale = cute::Stride<_1, int64_t, int64_t>;
120+
using StrideZero = StrideScale;
121+
122+
using GmemTiledCopyA = XE_2D_U8x32x32_LD_N; // U8 (1-byte) block copy for A (narrower type)
123+
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; // U16 (2-byte) block copy for B (wider type)
124+
static_assert(sizeof(ElementInputA) == 1, "ElementA width must match GmemTiledCopyA U8");
125+
126+
// Workgroup-level tile
127+
using TileShape = Shape<_256, _256, _32>;
128+
129+
// Although this is a mixed type example, the actual MMA accepts bf16 input for both A and B:
130+
using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
131+
typename TiledMMAHelper<MMA_Atom<typename helpers::MMAOp<cutlass::MMA_TYPE>::type>, Layout<TileShape>,
132+
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
133+
134+
constexpr int PipelineStages = 3; // prefetch 3 iters of data for A and B
135+
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupMixedPrecision<PipelineStages>;
136+
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
137+
138+
// Default (Linear Combination) epilogue
139+
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
140+
ElementAccumulator, ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
141+
142+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<EpilogueDispatchPolicy, EpilogueOp, TileShape,
143+
decltype(tile_shape(TiledMma()))>;
144+
using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
145+
EpilogueDispatchPolicy,
146+
TileShape,
147+
ElementAccumulator,
148+
cutlass::gemm::TagToStrideC_t<LayoutC*>,
149+
ElementOutput,
150+
cutlass::gemm::TagToStrideC_t<LayoutD*>,
151+
FusionCallBacks,
152+
XE_2D_U32x8x16_LD_N,
153+
void, void,
154+
XE_2D_U32x8x16_ST_N,
155+
void, void>;
156+
157+
// Use the helpers to avoid template arg repetition
158+
using GemmAdapterBuilder = helpers::MixedGemmUniversalAdapterBuilder<ProblemShape, CollectiveEpilogue>;
159+
160+
using MixedBuilderQuantA =
161+
helpers::MixedCollectiveMmaBuilder<GEMMDispatchPolicy, TileShape,
162+
cutlass::gemm::TagToStrideA_t<LayoutA*>,
163+
cutlass::gemm::TagToStrideB_t<LayoutB*>,
164+
TiledMma, GmemTiledCopyA, GmemTiledCopyB>;
165+
166+
using MixedBuilderQuantB =
167+
helpers::MixedCollectiveMmaBuilder<GEMMDispatchPolicy, TileShape,
168+
cutlass::gemm::TagToStrideA_t<LayoutA*>,
169+
cutlass::gemm::TagToStrideB_t<LayoutB*>,
170+
TiledMma, GmemTiledCopyB, GmemTiledCopyA>;
171+
172+
// A-narrow Mainloop & GemmUniversalAdapter
173+
using MainloopAConvertOnly =
174+
MixedBuilderQuantA::CollectiveMma<cute::tuple<ElementInputA>,
175+
ElementInputB>;
176+
using GemmAConvertOnly =
177+
GemmAdapterBuilder::GemmUniversalAdapter<MainloopAConvertOnly>;
178+
179+
using MainloopAConvertAndScale = MixedBuilderQuantA::CollectiveMma<
180+
cute::tuple<ElementInputA, ElementScale, StrideScale*>, ElementInputB>;
181+
using GemmAConvertAndScale =
182+
GemmAdapterBuilder::GemmUniversalAdapter<MainloopAConvertAndScale>;
183+
184+
using MainloopAConvertAndScaleWithZeroPoint =
185+
MixedBuilderQuantA::CollectiveMma<
186+
cute::tuple<ElementInputA, ElementScale, StrideScale*, ElementZero, StrideZero*>, ElementInputB>;
187+
using GemmAConvertAndScaleWithZeroPoint =
188+
GemmAdapterBuilder::GemmUniversalAdapter<
189+
MainloopAConvertAndScaleWithZeroPoint>;
190+
191+
// B-narrow Mainloop & GemmUniversalAdapter
192+
using MainloopBConvertOnly =
193+
MixedBuilderQuantB::CollectiveMma<ElementInputB,
194+
cute::tuple<ElementInputA>>;
195+
using GemmBConvertOnly =
196+
GemmAdapterBuilder::GemmUniversalAdapter<MainloopBConvertOnly>;
197+
198+
using MainloopBConvertAndScale = MixedBuilderQuantB::CollectiveMma<
199+
ElementInputB, cute::tuple<ElementInputA, ElementScale, StrideScale*>>;
200+
using GemmBConvertAndScale =
201+
GemmAdapterBuilder::GemmUniversalAdapter<MainloopBConvertAndScale>;
202+
203+
using MainloopBConvertAndScaleWithZeroPoint =
204+
MixedBuilderQuantB::CollectiveMma<
205+
ElementInputB, cute::tuple<ElementInputA, ElementScale, StrideScale*, ElementZero, StrideZero*>>;
206+
using GemmBConvertAndScaleWithZeroPoint =
207+
GemmAdapterBuilder::GemmUniversalAdapter<
208+
MainloopBConvertAndScaleWithZeroPoint>;
209+
210+
if(options.a_narrower){
211+
std::cout << "Setting A as narrower type" << std::endl;
212+
if(options.mode == GemmMode::ConvertOnly) {
213+
std::cout << "Running in ConvertOnly mode." << std::endl;
214+
CUTLASS_CHECK(ExampleRunner<GemmAConvertOnly>{}.run(options, hw_info));
215+
} else if(options.mode == GemmMode::ConvertAndScale){
216+
std::cout << "Running in ConvertAndScale mode." << std::endl;
217+
CUTLASS_CHECK(ExampleRunner<GemmAConvertAndScale>{}.run(options, hw_info));
218+
} else {
219+
std::cout << "Running in ConvertAndScaleWithZeroPoint mode." << std::endl;
220+
CUTLASS_CHECK(ExampleRunner<GemmAConvertAndScaleWithZeroPoint>{}.run(options, hw_info));
221+
}
222+
} else {
223+
std::cout << "Setting B as narrower type" << std::endl;
224+
if(options.mode == GemmMode::ConvertOnly) {
225+
std::cout << "Running in ConvertOnly mode." << std::endl;
226+
CUTLASS_CHECK(ExampleRunner<GemmBConvertOnly>{}.run(options, hw_info));
227+
} else if(options.mode == GemmMode::ConvertAndScale){
228+
std::cout << "Running in ConvertAndScale mode." << std::endl;
229+
CUTLASS_CHECK(ExampleRunner<GemmBConvertAndScale>{}.run(options, hw_info));
230+
} else {
231+
std::cout << "Running in ConvertAndScaleWithZeroPoint mode." << std::endl;
232+
CUTLASS_CHECK(ExampleRunner<GemmBConvertAndScaleWithZeroPoint>{}.run(options, hw_info));
233+
}
234+
}
235+
236+
return 0;
237+
}

0 commit comments

Comments
 (0)