[BENCHMARK] Reuse CUTLASS's gemm configuration file #4720

Merged · 4 commits · Jul 29, 2025
2 changes: 2 additions & 0 deletions benchmarks/cmake/FindCUTLASSLibrary.cmake
@@ -28,6 +28,8 @@ if (NOT CUTLASSLibrary_FOUND)
  set(CUTLASSLibrary_INCLUDE_DIR "${CUTLASSLibrary_SOURCE_DIR}/include" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
  set(CUTLASSLibrary_INCLUDE_TOOL_DIR "${CUTLASSLibrary_SOURCE_DIR}/tools/util/include" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
  set(CUTLASSLibrary_INCLUDE_APPLICATION_DIR "${CUTLASSLibrary_SOURCE_DIR}/applications" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
+  set(CUTLASSLibrary_INCLUDE_BENCHMARK_DIR "${CUTLASSLibrary_SOURCE_DIR}/benchmarks" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")
+  set(CUTLASSLibrary_BENCHMARK_CONFIG_DIR "${CUTLASSLibrary_SOURCE_DIR}/benchmarks/device/pvc/input_files" CACHE INTERNAL "CUTLASSLibrary_SOURCE_DIR")

find_package_handle_standard_args(
  CUTLASSLibrary
31 changes: 30 additions & 1 deletion benchmarks/cutlass_kernel/CMakeLists.txt
@@ -7,16 +7,45 @@ set(CUTLASS_KERNEL_FLAGS ${CUTLASS_KERNEL_FLAGS}
  -Xs "-options \"-igc_opts 'VISAOptions=-perfmodel,VectorAliasBBThreshold=1000,ExtraOCLOptions=-cl-intel-256-GRF-per-thread'\" -options -ze-opt-large-register-file"
)

+# Path to the configuration tool
+set(CONFIG_TOOL ${CMAKE_CURRENT_SOURCE_DIR}/config-tool.py)
+
+# Input and output files
+# The name of this file must be kept in sync with the best known CUTLASS config.
+# TODO: Re-enable gemm config input to come from `CUTLASSLibrary_BENCHMARK_CONFIG_DIR`
+# set(GEMM_CONFIG_INPUT ${CUTLASSLibrary_BENCHMARK_CONFIG_DIR}/input_gemm.in)
+set(GEMM_CONFIG_INPUT ${CMAKE_CURRENT_SOURCE_DIR}/gemm/input_gemm.in)
+set(GEMM_CONFIG_OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/gemm_table.hpp)
+set(GEMM_CONFIG_NAME "gemm_config")
+
+# Use a custom command to generate a C++ header with the configuration table
+# from the CUTLASS benchmark configuration.
+add_custom_command(
+  OUTPUT ${GEMM_CONFIG_OUTPUT}
+  COMMAND ${CMAKE_COMMAND} -E echo "Generating GEMM config header..."
+  COMMAND ${Python3_EXECUTABLE} ${CONFIG_TOOL} ${GEMM_CONFIG_INPUT} -o ${GEMM_CONFIG_OUTPUT} --name ${GEMM_CONFIG_NAME}
+  DEPENDS ${GEMM_CONFIG_INPUT} ${CONFIG_TOOL}
+  COMMENT "Generate GEMM configuration"
+  VERBATIM
+)
+
+# Create a target that other targets can depend on
+add_custom_target(generate_gemm_config DEPENDS ${GEMM_CONFIG_OUTPUT})

Python3_add_library(cutlass_kernel MODULE WITH_SOABI python_main.cpp)

target_compile_options(cutlass_kernel PRIVATE "-fsycl" "-fsycl-targets=intel_gpu_pvc,intel_gpu_bmg_g21" "-fpreview-breaking-changes")
target_compile_options(cutlass_kernel PRIVATE "-DCUTLASS_ENABLE_SYCL")
target_compile_options(cutlass_kernel PRIVATE "-DSYCL_INTEL_TARGET")
+target_compile_definitions(cutlass_kernel PRIVATE GEMM_CONFIG_HEADER=\"${GEMM_CONFIG_OUTPUT}\")
+target_compile_definitions(cutlass_kernel PRIVATE GEMM_CONFIG_NAME=\"${GEMM_CONFIG_NAME}\")

target_link_options(cutlass_kernel PRIVATE ${CUTLASS_KERNEL_FLAGS})
target_link_libraries(cutlass_kernel PUBLIC ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})

-target_include_directories(cutlass_kernel PUBLIC "${CUTLASSLibrary_INCLUDE_DIR}" "${CUTLASSLibrary_INCLUDE_TOOL_DIR}" "${CUTLASSLibrary_INCLUDE_APPLICATION_DIR}")
+target_include_directories(cutlass_kernel PUBLIC "${CUTLASSLibrary_INCLUDE_DIR}" "${CUTLASSLibrary_INCLUDE_TOOL_DIR}" "${CUTLASSLibrary_INCLUDE_APPLICATION_DIR}" "${CUTLASSLibrary_INCLUDE_BENCHMARK_DIR}")

+add_dependencies(cutlass_kernel generate_gemm_config)

add_subdirectory(gemm)
add_subdirectory(attention)
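At build time, the custom command above boils down to an invocation along these lines (paths are illustrative; the real values are resolved from the CMake variables):

    python3 .../benchmarks/cutlass_kernel/config-tool.py \
        .../benchmarks/cutlass_kernel/gemm/input_gemm.in \
        -o <build-dir>/gemm_table.hpp --name gemm_config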
9 changes: 5 additions & 4 deletions benchmarks/cutlass_kernel/attention/attention.hpp
@@ -204,10 +204,11 @@ using FARunPtr = int (*)(const at::Tensor &Q, const at::Tensor &K,
                         int SeqLengthKV, int HeadSizeQK, int HeadSizeVO,
                         float sm_scale);

-auto attention(const at::Tensor &Q, const at::Tensor &K, const at::Tensor &V,
-               at::Tensor &O, int Batch, int NumHeadsQ, int NumHeadsKV,
-               int SeqLengthQO, int SeqLengthKV, int HeadSizeQK, int HeadSizeVO,
-               bool Causal, float sm_scale) -> int {
+auto attention_kernel(const at::Tensor &Q, const at::Tensor &K,
+                      const at::Tensor &V, at::Tensor &O, int Batch,
+                      int NumHeadsQ, int NumHeadsKV, int SeqLengthQO,
+                      int SeqLengthKV, int HeadSizeQK, int HeadSizeVO,
+                      bool Causal, float sm_scale) -> int {
  constexpr int PipelineStages = 2;
  FARunPtr f = nullptr;

54 changes: 54 additions & 0 deletions benchmarks/cutlass_kernel/config-tool.py
@@ -0,0 +1,54 @@
#!/usr/bin/env python3

import argparse
import re
import sys


def build_config_map(file_paths):
    config_map = {}
    pattern = re.compile(r'^(?P<name>\S+).*?--l=(?P<l>\d+)\s+--m=(?P<m>\d+)\s+--k=(?P<k>\d+)\s+--n=(?P<n>\d+)')

    for path in file_paths:
        try:
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    match = pattern.match(line.strip())
                    if match:
                        name = match.group('name')
                        l = int(match.group('l'))
                        m = int(match.group('m'))
                        k = int(match.group('k'))
                        n = int(match.group('n'))
                        config_map[(l, m, n, k)] = name
        except IOError as e:
            print(f'Error reading {path}: {e}', file=sys.stderr)

    return config_map


def main():
    parser = argparse.ArgumentParser(description='Parse GEMM benchmark files and generate C++ table.')
    parser.add_argument('-o', '--output', required=True, help='Output file path')
    parser.add_argument('--name', required=True, help='Name identifier for logging or grouping')
    parser.add_argument('inputs', nargs='+', help='Input file(s) with GEMM benchmark data')

    args = parser.parse_args()

    config_map = build_config_map(args.inputs)

    try:
        with open(args.output, 'w', encoding='utf-8') as outfile:
            outfile.write('// This file was auto-generated, do not edit!\n\n')
            outfile.write(
                f'static constexpr std::array<std::pair<Dim, GemmRunPtr>, {len(config_map)}> {args.name} = {{{{\n')
            for (l, m, n, k), name in config_map.items():
                outfile.write(f'{{ {{ {l}, {m}, {n}, {k} }}, &gemm_run<{name}> }},\n')
            outfile.write('}};\n')
    except IOError as e:
        print(f'Error writing output file: {e}', file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
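For illustration, running the tool over just the first two lines of the input_gemm.in added below would emit roughly this header (a sketch; the real gemm_table.hpp is generated at build time). Two properties of the script are worth noting: keys are stored in (l, m, n, k) order even though the flags appear as --l --m --k --n, and because config_map is a dict keyed by shape, a later line with a duplicate shape silently overrides an earlier one:

    // This file was auto-generated, do not edit!

    static constexpr std::array<std::pair<Dim, GemmRunPtr>, 2> gemm_config = {{
    { { 1, 1, 13824, 5120 }, &gemm_run<PvcGemmBF16BF16FP32_RRR_1> },
    { { 1, 4, 12288, 4096 }, &gemm_run<PvcGemmBF16BF16FP32_RRR_1> },
    }};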
116 changes: 26 additions & 90 deletions benchmarks/cutlass_kernel/gemm/gemm.hpp
@@ -9,69 +9,27 @@
#include <exception>
#include <iostream>

+#define CUTLASS_CREATE_GEMM_BENCHMARK(x)
+#define CUTLASS_BENCHMARK(x)
+#include "gemm/benchmarks_sycl.hpp"
+#include "gemm/gemm_configuration_sycl.hpp"

////////////////////////////////////////////////////////////////////////////////
// PRIVATE FUNCTION
////////////////////////////////////////////////////////////////////////////////

-template <typename TileShape>
+template <typename GemmConfig>
static auto gemm_run(const at::Tensor &A, const at::Tensor &B, at::Tensor &C,
                     const int M, const int N, const int K, const int L)
    -> int {
  RECORD_FUNCTION("cutlass gemm", {});

-  using ElementAccumulator = float;
-  using ElementComputeEpilogue = float;
-  using ElementInputA = cutlass::bfloat16_t;
-  using ElementInputB = cutlass::bfloat16_t;
-  using ElementOutput = float;
-
-  using LayoutA = typename cutlass::layout::RowMajor;
-  using LayoutB = typename cutlass::layout::RowMajor;
-  using LayoutC = typename cutlass::layout::RowMajor;
-  using LayoutD = typename cutlass::layout::RowMajor;
-
-  constexpr int AlignmentA = sizeof(ElementInputA);
-  constexpr int AlignmentB = sizeof(ElementInputB);
-  constexpr int AlignmentC = sizeof(ElementAccumulator);
-  constexpr int AlignmentD = sizeof(ElementOutput);
-
-  /// MAIN LOOP ///
-
-  using CollectiveMainloop =
-      typename cutlass::gemm::collective::CollectiveBuilder<
-          cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, ElementInputA,
-          LayoutA, AlignmentA, ElementInputB, LayoutB, AlignmentB,
-          ElementAccumulator, TileShape,
-          cute::Shape<cute::_1, cute::_1, cute::_1>,
-          cutlass::gemm::collective::StageCountAuto,
-          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
-
-  /// EPILOGUE LOOP ///
-
-  using EpilogueOp = typename cutlass::epilogue::fusion::LinCombEltAct<
-      cutlass::epilogue::thread::ReLu, ElementOutput, ElementComputeEpilogue,
-      ElementAccumulator, ElementAccumulator,
-      cutlass::FloatRoundStyle::round_to_nearest>;
-  using CollectiveEpilogue =
-      typename cutlass::epilogue::collective::CollectiveBuilder<
-          cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, TileShape,
-          cute::Shape<cute::_1, cute::_1, cute::_1>,
-          cutlass::epilogue::collective::EpilogueTileAuto,
-          ElementComputeEpilogue, ElementAccumulator, ElementAccumulator,
-          LayoutC, AlignmentC, ElementOutput, LayoutD, AlignmentD,
-          cutlass::epilogue::collective::EpilogueScheduleAuto,
-          EpilogueOp>::CollectiveOp;
-
-  /// GEMM ///
-
-  using GemmKernel = typename cutlass::gemm::kernel::GemmUniversal<
-      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>;
-
-  /// GEMM INVOCATION ///
-
  try {
-    using Gemm =
-        typename cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+    using Gemm = typename GemmConfig::Gemm;
    typename Gemm::Arguments arguments;

    /// Buffer Initialization

@@ -107,15 +65,17 @@ static auto gemm_run(const at::Tensor &A, const at::Tensor &B, at::Tensor &C,
"Query result for SM count per device: " << hw_info.sm_count);
}

arguments = {cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{_A, stride_A, _B, stride_B},
{{ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
nullptr,
stride_C,
_C,
stride_D},
hw_info};
arguments = GemmConfig::defaultArguments();
arguments.mode = cutlass::gemm::GemmUniversalMode::kGemm;
arguments.problem_shape = problem_size;
arguments.mainloop = {_A, stride_A, _B, stride_B};
arguments.epilogue = {
{ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
nullptr,
stride_C,
_C,
stride_D};
arguments.hw_info = hw_info;

Gemm gemm_op;

@@ -148,43 +108,19 @@ using GemmRunPtr = int (*)(const at::Tensor &A, const at::Tensor &B,
                           at::Tensor &C, const int M, const int N, const int K,
                           const int L);

-/// Each entry associates a specific problem dimension to their corresponding
-/// tile shape. For more details, see:
-/// https://github.com/codeplaysoftware/cutlass-sycl/tree/sycl-develop/benchmarks
-
-// clang-format off
-static constexpr std::array<std::pair<Dim, GemmRunPtr>, 18> tile_map = {{
-  { { 1, 1024, 8192, 28672 }, &gemm_run<cute::Shape<cute::_128, cute::_512, cute::_32>> },
-  { { 32, 4096, 128, 4096 }, &gemm_run<cute::Shape<cute::_256, cute::_128, cute::_32>> },
-  { { 4096, 8, 16384, 128 }, &gemm_run<cute::Shape<cute::_128, cute::_256, cute::_16>> },
-  { { 4096, 8, 128, 16384 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-  { { 1, 1, 1024, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_64, cute::_32>> },
-  { { 1, 1, 4096, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-  { { 1, 1, 6144, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-  { { 1, 1, 14336, 4096 }, &gemm_run<cute::Shape<cute::_64, cute::_256, cute::_32>> },
-  { { 1, 1, 28672, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_128, cute::_32>> },
-  { { 1, 1, 128256, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_512, cute::_32>> },
-  { { 1, 1, 4096, 14336 }, &gemm_run<cute::Shape<cute::_8, cute::_128, cute::_32>> },
-  { { 1, 8, 1024, 4096 }, &gemm_run<cute::Shape<cute::_8, cute::_64, cute::_32>> },
-  { { 1, 8, 4096, 4096 }, &gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>> },
-  { { 1, 8, 6144, 4096 }, &gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>> },
-  { { 1, 8, 14336, 4096 }, &gemm_run<cute::Shape<cute::_64, cute::_256, cute::_32>> },
-  { { 1, 8, 28672, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_128, cute::_32>> },
-  { { 1, 8, 128256, 4096 }, &gemm_run<cute::Shape<cute::_32, cute::_512, cute::_32>> },
-  { { 1, 8, 4096, 14336 }, &gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>> },
-}};
-// clang-format on
-
-auto gemm(const at::Tensor &A, const at::Tensor &B, at::Tensor &C, const int M,
-          const int N, const int K, const int L) -> int {
+// Includes the table mapping problem shape to best config from the header
+// generated by the configuration tool from the CUTLASS config file.
+#include GEMM_CONFIG_HEADER
+
+auto gemm_kernel(const at::Tensor &A, const at::Tensor &B, at::Tensor &C,
+                 const int M, const int N, const int K, const int L) -> int {
  const Dim test_case{L, M, N, K};

-  for (auto const &kv : tile_map) {
+  for (auto const &kv : gemm_config) {
    if (test_case == kv.first) {
      return kv.second(A, B, C, M, N, K, L);
    }
  }

-  return gemm_run<cute::Shape<cute::_256, cute::_256, cute::_32>>(A, B, C, M, N,
-                                                                  K, L);
+  return gemm_run<PvcGemmBF16BF16FP32_RRR_1>(A, B, C, M, N, K, L);
}
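The rewritten gemm_run only requires that each named configuration expose a Gemm adapter type plus a defaultArguments() factory, which is what GemmConfig::Gemm and GemmConfig::defaultArguments() above rely on. A minimal sketch of that contract (ExampleConfig and SomeGemmKernel are hypothetical stand-ins; the real PvcGemmBF16BF16FP32_RRR_* structs come from CUTLASS's gemm/benchmarks_sycl.hpp):

    // Hypothetical shape of a CUTLASS benchmark configuration struct.
    struct ExampleConfig {
      // Fully-built kernel adapter, replacing the CollectiveBuilder
      // boilerplate that this PR deletes from gemm_run.
      using Gemm = cutlass::gemm::device::GemmUniversalAdapter<SomeGemmKernel>;

      // Arguments pre-filled with the benchmark's defaults; gemm_run then
      // overwrites mode, problem_shape, mainloop, epilogue and hw_info.
      static typename Gemm::Arguments defaultArguments() { return {}; }
    };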
21 changes: 21 additions & 0 deletions benchmarks/cutlass_kernel/gemm/input_gemm.in
@@ -0,0 +1,21 @@
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=13824
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=4096 --n=12288
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192
PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=28672 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=3072 --k=4096 --n=3072
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=8192 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128
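Each line pairs a CUTLASS benchmark configuration name with a problem shape; config-tool.py extracts both with the regex shown earlier, so the --l --m --k --n flag order is load-bearing. A quick check of the pattern against the first line (illustrative):

    import re

    pattern = re.compile(r'^(?P<name>\S+).*?--l=(?P<l>\d+)\s+--m=(?P<m>\d+)\s+--k=(?P<k>\d+)\s+--n=(?P<n>\d+)')
    line = 'PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=13824'
    match = pattern.match(line)
    print(match.group('name'), match.group('l'), match.group('m'), match.group('k'), match.group('n'))
    # -> PvcGemmBF16BF16FP32_RRR_1 1 1 5120 13824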
4 changes: 2 additions & 2 deletions benchmarks/cutlass_kernel/python_main.cpp
@@ -24,6 +24,6 @@
////////////////////////////////////////////////////////////////////////////////

PYBIND11_MODULE(cutlass_kernel, m) {
m.def("gemm", &gemm, "gemm (CUTLASS)");
m.def("attention", &attention, "attention (CUTLASS)");
m.def("gemm", &gemm_kernel, "gemm (CUTLASS)");
m.def("attention", &attention_kernel, "attention (CUTLASS)");
}
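As a usage sketch of the rebound entry points (the xpu device string and bf16/fp32 dtypes are assumptions inferred from the BF16BF16FP32 configuration names; the Python-visible function names are unchanged):

    import torch
    import cutlass_kernel  # module built by Python3_add_library above

    M, N, K, L = 4096, 4096, 4096, 1  # shape present in input_gemm.in
    A = torch.randn(M, K, dtype=torch.bfloat16, device='xpu')
    B = torch.randn(K, N, dtype=torch.bfloat16, device='xpu')
    C = torch.zeros(M, N, dtype=torch.float32, device='xpu')

    # Exact-shape hits dispatch through the generated gemm_config table;
    # any other shape falls back to PvcGemmBF16BF16FP32_RRR_1.
    cutlass_kernel.gemm(A, B, C, M, N, K, L)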