Skip to content

Commit 79c535e

Browse files
authored
add 16bits x 8bits grouped gemm UTs (#495)
Add 16-bit x 8-bit grouped GEMM unit tests that exercise the direct conversion from 8 bits (float_e5m2) to 16 bits (half_t).
2 parents 44c3127 + 1539ce1 commit 79c535e

File tree

4 files changed

+106
-4
lines changed

4 files changed

+106
-4
lines changed

include/cutlass/gemm/collective/builders/xe_mma_builder.inl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ struct CollectiveBuilder<
161161
KernelScheduleType,
162162
cute::enable_if_t<
163163
cute::is_any_of_v<KernelScheduleType, KernelScheduleAuto, KernelXe, KernelXeCooperative, KernelXePtrArrayCooperative> &&
164-
cute::is_any_of_v<ElementA, bfloat16_t, half_t, cute::int8_t> &&
165-
cute::is_any_of_v<ElementB, bfloat16_t, half_t, cute::int8_t, cute::uint4_t>
164+
cute::is_any_of_v<ElementA, bfloat16_t, half_t, cute::float_e5m2_t, cute::float_e4m3_t, cute::int8_t> &&
165+
cute::is_any_of_v<ElementB, bfloat16_t, half_t, cute::float_e5m2_t, cute::float_e4m3_t, cute::int8_t, cute::uint4_t>
166166
>
167167
>{
168168

test/unit/gemm/device/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ if(CUTLASS_ENABLE_SYCL)
7979
xe_gemm_fp16_s8_fp32_tensor_op_fp32_group_gemm.cpp
8080
xe_gemm_bf16_u4_fp32_tensor_op_fp32_group_gemm.cpp
8181
xe_gemm_fp16_u4_fp32_tensor_op_fp32_group_gemm.cpp
82+
xe_gemm_fp16_fp8_fp32_tensor_op_fp32_group_gemm.cpp
8283
)
8384

8485
add_custom_target(

test/unit/gemm/device/default_gemm_group_configuration.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ struct DefaultGemmGroupConfiguration<
7171
ElementOutput>
7272
{
7373

74-
static_assert(cute::is_any_of_v<ElementA, bfloat16_t, half_t, int8_t>, "ElementA needs to be of 16 or 8 bit type");
75-
static_assert(cute::is_any_of_v<ElementB, bfloat16_t, half_t, int8_t, uint4_t>, "ElementB needs to be of 16, 8 or 4 bit type");
74+
static_assert(cute::is_any_of_v<ElementA, bfloat16_t, half_t, int8_t, float_e5m2_t, float_e4m3_t>, "ElementA needs to be of 16 or 8 bit type");
75+
static_assert(cute::is_any_of_v<ElementB, bfloat16_t, half_t, int8_t, float_e5m2_t, float_e4m3_t, uint4_t>, "ElementB needs to be of 16, 8 or 4 bit type");
7676
using TileShape = Shape<_256, _256, _32>;
7777

7878
using CollectiveMainloop = typename gemm::collective::CollectiveBuilder<
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/***************************************************************************************************
2+
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
32+
/*! \file
33+
\brief Tests for Xe grouped GEMM with fp16 (half_t) A, fp8 (float_e5m2_t) B and fp32 C/accumulator
34+
*/
35+
36+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
37+
#include "cutlass/gemm/kernel/gemm_universal.hpp"
38+
#include "cutlass/gemm/group_array_problem_shape.hpp"
39+
40+
#include "default_gemm_configuration.hpp"
41+
#include "default_gemm_group_configuration.hpp"
42+
#include "gemm_testbed_3x_ptr_array.hpp"
43+
44+
namespace cutlass {
45+
namespace {
46+
// Type bundle for the Xe grouped-GEMM tests in this file: a half_t A operand
// combined with a float_e5m2_t B operand, with ElementC and ElementAccumulator
// both float and C stored row-major. LayoutA / LayoutB are left as template
// parameters so each TEST can choose its own layout pair; the device-level
// type to run is exposed as the nested alias `Gemm`.
template <typename LayoutA, typename LayoutB>
47+
struct XE_Device_Gemm_fp16_fp8_f32_tensor_op_f32_group_gemm {
48+
using ProblemShape = gemm::GroupProblemShape<cute::Shape<int,int,int>>; // <M,N,K> per group
49+
using ElementA = cute::half_t;  // 16-bit operand
50+
using ElementB = cute::float_e5m2_t;  // 8-bit operand, narrower than ElementA
51+
using ElementC = float;
52+
using ElementAccumulator = float;
53+
using LayoutC = layout::RowMajor;
54+
55+
// Default grouped-GEMM configuration for tensor-op on Intel Xe, instantiated
// with the element and layout types above. It supplies the CollectiveMainloop
// and CollectiveEpilogue types consumed by GemmKernel below.
using Config = gemm::device::DefaultGemmGroupConfiguration<
56+
arch::OpClassTensorOp, arch::IntelXe,
57+
ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator>;
58+
59+
// Kernel type: the grouped problem shape plus Config's collective mainloop
// and epilogue, with gemm::GroupScheduler passed as the scheduler tag.
using GemmKernel = gemm::kernel::GemmUniversal<
60+
ProblemShape,
61+
typename Config::CollectiveMainloop,
62+
typename Config::CollectiveEpilogue,
63+
gemm::GroupScheduler
64+
>;
65+
66+
// Adapter around GemmKernel; this is the type the TESTs pass to TestAll.
using Gemm = gemm::device::GemmUniversalAdapter<GemmKernel>;
67+
};
68+
69+
// Layout pair under test: A RowMajor, B RowMajor.
// Instantiates the grouped GEMM for this pair and requires the testbed's
// TestAll to return true for both argument pairs below.
// NOTE(review): the two doubles are presumably alpha and beta (the second
// call would then cover beta == 0) -- confirm against
// gemm_testbed_3x_ptr_array.hpp.
TEST(XE_Device_Gemm_fp16t_fp8t_f32t_tensor_op_f32_group_gemm, 256x256x32) {
70+
using LayoutA = layout::RowMajor;
71+
using LayoutB = layout::RowMajor;
72+
using Gemm = XE_Device_Gemm_fp16_fp8_f32_tensor_op_f32_group_gemm<LayoutA, LayoutB>::Gemm;
73+
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>(1.0, 1.0));
74+
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>(1.0, 0.0));
75+
}
76+
77+
// Layout pair under test: A ColumnMajor, B RowMajor.
// Instantiates the grouped GEMM for this pair and requires the testbed's
// TestAll to return true for both argument pairs below.
// NOTE(review): the two doubles are presumably alpha and beta (the second
// call would then cover beta == 0) -- confirm against
// gemm_testbed_3x_ptr_array.hpp.
TEST(XE_Device_Gemm_fp16n_fp8t_f32t_tensor_op_f32_group_gemm, 256x256x32) {
78+
using LayoutA = layout::ColumnMajor;
79+
using LayoutB = layout::RowMajor;
80+
using Gemm = XE_Device_Gemm_fp16_fp8_f32_tensor_op_f32_group_gemm<LayoutA, LayoutB>::Gemm;
81+
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>(1.0, 1.0));
82+
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>(1.0, 0.0));
83+
}
84+
85+
// Layout pair under test: A RowMajor, B ColumnMajor.
// Instantiates the grouped GEMM for this pair and requires the testbed's
// TestAll to return true for both argument pairs below.
// NOTE(review): the two doubles are presumably alpha and beta (the second
// call would then cover beta == 0) -- confirm against
// gemm_testbed_3x_ptr_array.hpp.
TEST(XE_Device_Gemm_fp16t_fp8n_f32t_tensor_op_f32_group_gemm, 256x256x32) {
86+
using LayoutA = layout::RowMajor;
87+
using LayoutB = layout::ColumnMajor;
88+
using Gemm = XE_Device_Gemm_fp16_fp8_f32_tensor_op_f32_group_gemm<LayoutA, LayoutB>::Gemm;
89+
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>(1.0, 1.0));
90+
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>(1.0, 0.0));
91+
}
92+
93+
// Layout pair under test: A ColumnMajor, B ColumnMajor.
// Instantiates the grouped GEMM for this pair and requires the testbed's
// TestAll to return true for both argument pairs below.
// NOTE(review): the two doubles are presumably alpha and beta (the second
// call would then cover beta == 0) -- confirm against
// gemm_testbed_3x_ptr_array.hpp.
TEST(XE_Device_Gemm_fp16n_fp8n_f32t_tensor_op_f32_group_gemm, 256x256x32) {
94+
using LayoutA = layout::ColumnMajor;
95+
using LayoutB = layout::ColumnMajor;
96+
using Gemm = XE_Device_Gemm_fp16_fp8_f32_tensor_op_f32_group_gemm<LayoutA, LayoutB>::Gemm;
97+
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>(1.0, 1.0));
98+
EXPECT_TRUE(test::gemm::device::TestAll<Gemm>(1.0, 0.0));
99+
}
100+
}
101+
} // namespace cutlass

0 commit comments

Comments
 (0)