diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp
index a6bd3447..c11969ae 100644
--- a/csrc/apis/gemm.hpp
+++ b/csrc/apis/gemm.hpp
@@ -169,14 +169,17 @@ static void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
-static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
+static std::optional<std::pair<int, int>> m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
                                          const std::pair<torch::Tensor, torch::Tensor>& b,
                                          const torch::Tensor& d, const torch::Tensor& masked_m,
                                          const int& expected_m,
                                          std::optional<std::tuple<int, int, int>> recipe,
                                          const std::string& compiled_dims,
-                                         const bool& disable_ue8m0_cast) {
+                                         const bool& disable_ue8m0_cast,
+                                         const int& max_block_n,
+                                         const bool& enable_overlap,
+                                         const std::optional<torch::Tensor>& signal) {
     // Shape must be `[G, M, K] @ [G, N, K].mT`
     const auto& major_a = get_major_type_ab(a.first);
     const auto& major_b = get_major_type_ab(b.first);
@@ -196,6 +199,12 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
     const auto& arch_major = device_runtime->get_arch_major();
+    std::optional<std::pair<int, int>> result = std::nullopt;
     if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
-        sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
-                                            num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
+        result = sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
+                                                     num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims,
+                                                     max_block_n, enable_overlap, signal);
     } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
         sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
                                              num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
@@ -219,6 +230,7 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
+    return result;
 }
@@ -436,7 +448,9 @@ static void register_apis(pybind11::module_& m) {
     m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked,
           py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
           py::arg("expected_m"), py::arg("recipe") = std::nullopt,
-          py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
+          py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false,
+          py::arg("max_block_n") = 256, py::arg("enable_overlap") = false,
+          py::arg("signal") = std::nullopt);
     m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
           py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), py::arg("ks_tensor"),
           py::arg("c") = std::nullopt,
diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp
index 681e6546..9eec1e84 100644
--- a/csrc/jit_kernels/heuristics/common.hpp
+++ b/csrc/jit_kernels/heuristics/common.hpp
@@ -61,6 +61,7 @@ struct GemmConfig {
     cute::UMMA::Major major_b;
     bool with_accumulation;
     int block_m, block_n, block_k;
+    int signal_threshold;
     int num_stages, num_last_stages;
 
     // Templated device configs
@@ -71,6 +72,8 @@ struct GemmConfig {
     MulticastConfig multicast_config;
     SharedMemoryConfig smem_config;
     ThreadConfig thread_config;
+
+    bool enable_overlap;
 };
 
 static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
@@ -146,7 +149,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
                                   const int& m, const int& n, const int& k, const int& num_groups,
                                   const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                   const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
-                                  const bool& with_accumulation, const int& num_sms) {
+                                  const bool& with_accumulation, const int& num_sms,
+                                  const int& max_block_n = 256, const bool& enable_overlap = false) {
     DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16);
     DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
@@ -158,7 +162,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
     if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
         block_ms = std::vector{64, 128};
     std::vector<int> block_ns;
-    for (int i = 16; i <= 256; i += 16)
+    for (int i = 16; i <= max_block_n; i += 16)
         block_ns.push_back(i);
 
     // K block size is selected in a fixed manner
@@ -269,6 +273,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
         .block_m = best_block_m,
         .block_n = best_block_n,
         .block_k = block_k,
+        .signal_threshold = ceil_div(n, best_block_n),
         .num_stages = best_num_stages,
         .num_last_stages = ceil_div(k, block_k) % best_num_stages,
         .num_sms = num_min_sms,
@@ -276,7 +281,8 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
         .multicast_config = best_multicast_config,
         // ReSharper disable once CppLocalVariableMightNotBeInitialized
         .smem_config = best_smem_config,
-        .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n)
+        .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n),
+        .enable_overlap = enable_overlap
     };
 
     // Only SM100 BF16 kernels support tensor core control
diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
index 3afc2d33..49b4c9ad 100644
--- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
+++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
@@ -3,7 +3,6 @@
 #include 
 
 #include "../../jit/compiler.hpp"
-#include "../../jit/device_runtime.hpp"
 #include "../../jit/kernel_runtime.hpp"
 #include "../../utils/exception.hpp"
 #include "../../utils/format.hpp"
@@ -21,7 +20,7 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime
         GemmConfig gemm_config;
         LaunchArgs launch_args;
-        void *sfb, *grouped_layout;
+        void *sfb, *grouped_layout, *signal;
         CUtensorMap tensor_map_a;
         CUtensorMap tensor_map_b;
         CUtensorMap tensor_map_d;
@@ -43,7 +42,7 @@ static void __instantiate_kernel() {{
         {}, {},
         {}, {},
         {}, {},
-        {}, {}
+        {}, {}, {}
     >);
 }};
)",
@@ -55,13 +54,13 @@ static void __instantiate_kernel() {{
         args.gemm_config.num_stages, args.gemm_config.num_last_stages,
         args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
         args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
-        args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
+        args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), args.gemm_config.enable_overlap);
     }
 
     static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
         // TODO: optimize `args` copy
         DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
-            args.sfb, args.grouped_layout,
+            args.sfb, args.grouped_layout, args.signal,
             args.m, args.n, args.k,
             args.tensor_map_a, args.tensor_map_b,
             args.tensor_map_d, args.tensor_map_sfa));
@@ -117,6 +116,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                             config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = nullptr,
+        .signal = nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -140,7 +140,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
     const auto& aligned_k = align(k, 128);
     const auto& config = get_best_config(
         GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
-        m, n, k, 1, major_a, major_b,
+        m, n, k, num_groups, major_a, major_b,
         torch::kFloat8_e4m3fn, d.scalar_type(), false,
         device_runtime->get_num_sms());
@@ -176,6 +176,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
                             config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = m_indices.data_ptr(),
+        .signal = nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -186,14 +187,17 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
     SM90FP8Gemm1D2DRuntime::launch(runtime, args);
 }
 
-static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
+static std::optional<std::pair<int, int>> sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                 const torch::Tensor& b, const torch::Tensor& sfb,
                                                 const torch::Tensor& d, const torch::Tensor& masked_m,
                                                 const int& num_groups, const int& m, const int& n, const int& k,
                                                 const int& expected_m,
                                                 const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
-                                                const std::string& compiled_dims) {
+                                                const std::string& compiled_dims,
+                                                const int& max_block_n,
+                                                const bool& enable_overlap,
+                                                const std::optional<torch::Tensor>& signal) {
     const auto& aligned_k = align(k, 128);
     DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
     DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
@@ -202,7 +206,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
         GemmType::MGroupedMasked, KernelType::Kernel1D2D,
         expected_m, n, k, num_groups, major_a, major_b,
         torch::kFloat8_e4m3fn, d.scalar_type(), false,
-        device_runtime->get_num_sms());
+        device_runtime->get_num_sms(), max_block_n, enable_overlap);
 
     // Requires no TMA splits
     DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
@@ -236,6 +240,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
                             config.multicast_config.num_multicast),
         .sfb = sfb.data_ptr(),
         .grouped_layout = masked_m.data_ptr(),
+        .signal = enable_overlap ? signal.value().data_ptr() : nullptr,
         .tensor_map_a = tensor_map_a,
         .tensor_map_b = tensor_map_b,
         .tensor_map_d = tensor_map_d,
@@ -244,6 +249,9 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
     const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
     const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
     SM90FP8Gemm1D2DRuntime::launch(runtime, args);
+    return enable_overlap ?
+        std::optional(std::make_pair(config.block_m, config.signal_threshold)) :
+        std::nullopt;
 }
 
 } // namespace deep_gemm
diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh
index fc84b696..3a9cdc06 100644
--- a/deep_gemm/include/deep_gemm/common/utils.cuh
+++ b/deep_gemm/include/deep_gemm/common/utils.cuh
@@ -144,6 +144,16 @@ __device__ __forceinline__ void prefetch_l1(void *ptr) {
     asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
 }
 
+__device__ __forceinline__ void store_wait() {
+    asm volatile("cp.async.bulk.wait_group 0;\n" ::: "memory");
+}
+
+__device__ __forceinline__ int atomic_add_release_global(int* addr, int value) {
+    int ret;
+    asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(value));
+    return ret;
+}
+
 template 
 struct Vectorized {
     static auto zeros() {
diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
index 5a65d69e..11f611fb 100644
--- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
+++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
@@ -36,9 +36,9 @@ template 
-          uint32_t kNumSMs, GemmType kGemmType>
+          uint32_t kNumSMs, GemmType kGemmType, bool kEnableOverlap>
 __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
-sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
+sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, int* signal,
                         uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
                         const __grid_constant__ cute::TmaDescriptor tensor_map_a,
                         const __grid_constant__ cute::TmaDescriptor tensor_map_b,
@@ -428,6 +428,18 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
                 cute::tma_store_arrive();
             }
             __syncwarp();
+
+            if constexpr (kEnableOverlap) {
+                if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
+                    store_wait();
+                }
+
+                cutlass::arch::NamedBarrier(kNumMathThreads).sync();
+
+                if (threadIdx.x == 0) {
+                    atomic_add_release_global(signal + scheduler.current_group_idx * ceil_div(shape_m, BLOCK_M) + m_block_idx, 1);
+                }
+            }
         }
     }
 #else
diff --git a/deep_gemm/testing/numeric.py b/deep_gemm/testing/numeric.py
index d06a03b9..59f6acd2 100644
--- a/deep_gemm/testing/numeric.py
+++ b/deep_gemm/testing/numeric.py
@@ -1,6 +1,20 @@
 import torch
 from typing import Iterable
 
+def check_signal(num_local_expert, max_m, block_m, threshold, signal, masked_m):
+    ceil_div = lambda a, b: (a + b - 1) // b
+
+    expert_len = max_m // block_m
+    for expert in range(num_local_expert):
+        mask = masked_m[expert]
+        start = expert * expert_len
+        end = expert * expert_len + expert_len
+        valid_len = ceil_div(mask, block_m)
+        for i in range(start, end):
+            if i < start + valid_len:
+                assert signal[i] == threshold, f'{i=}, {signal[i]=}, {threshold=}'
+            else:
+                assert signal[i] == 0, f'{i=}, {signal[i]=}'
 
 def calc_diff(x: torch.Tensor, y: torch.Tensor):
     x, y = x.double(), y.double()
diff --git a/tests/generators.py b/tests/generators.py
index 82cdbdcc..8f0cce99 100644
--- a/tests/generators.py
+++ b/tests/generators.py
@@ -88,9 +88,10 @@ def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator:
 def enumerate_m_grouped_masked() -> Generator:
     max_m = 4096
     for kernel_type in get_kernel_types():
-        for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
-            for n, k in ((4096, 7168), (7168, 2048), ):
-                yield kernel_type, num_groups, max_m, m, n, k
+        for enable_overlap in (False, True):
+            for num_groups, m in ((1, 1024), (2, 512), (4, 256), (16, 64), (16, 32)):
+                for n, k in ((4096, 7168), (7168, 2048), ):
+                    yield kernel_type, enable_overlap, num_groups, max_m, m, n, k
 
 
 def enumerate_k_grouped_contiguous():
@@ -191,7 +192,7 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n:
 
 def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int,
                               use_ue8m0: bool = False, use_bf16: bool = False):
-                              use_ue8m0: bool = False, use_bf16: bool = False):
+                              use_ue8m0: bool = False, use_bf16: bool = False, enable_overlap: bool = False):
     a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
     b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
@@ -211,7 +212,10 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group:
         a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0)
         b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0)
 
-    return a_fp8, b_fp8, masked_m, d, ref_d
+    max_signal_size = num_groups * ceil_div(max_m, 64)
+    signal = torch.zeros(max_signal_size, dtype=torch.int32, device='cuda') if enable_overlap else None
+
+    return a_fp8, b_fp8, masked_m, d, ref_d, signal
 
 
 def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int], use_ue8m0: bool):
@@ -233,3 +237,4 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int]
     a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0)
     b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0)
     return k, a_fp8, b_fp8, c, d, ref_d
+
diff --git a/tests/test_fp8.py b/tests/test_fp8.py
index 0c7d3cea..04500505 100644
--- a/tests/test_fp8.py
+++ b/tests/test_fp8.py
@@ -6,7 +6,8 @@ import deep_gemm
 from deep_gemm.testing import (
     bench, bench_kineto,
-    calc_diff, count_bytes
+    calc_diff, count_bytes,
+    check_signal,
 )
 
 from generators import (
@@ -97,30 +98,35 @@ def test_m_grouped_gemm_masked() -> None:
     print('Testing m-grouped masked GEMM:')
 
     # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
-    for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked():
+    for kernel_type, enable_overlap, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked():
         kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
         use_ue8m0 = get_ue8m0_usage(kernel_type)
         disable_ue8m0_cast = not use_ue8m0
 
         # Test correctness
         for i in range(10):
-            a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
-            deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
+            a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0, enable_overlap=enable_overlap)
+            result = deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal)
+
+            if enable_overlap:
+                block_m, threshold = result
+                check_signal(num_groups, max_m, block_m, threshold, signal, masked_m)
+
             for j in range(num_groups):
                 diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()])
                 assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'
 
         # Construct full cases
-        a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
+        a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0, enable_overlap=enable_overlap)
 
         # noinspection PyShadowingNames
         def test_func():
-            deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
+            deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, enable_overlap=enable_overlap, signal=signal)
 
         # Test performance with fixed shapes
         valid_m = masked_m.sum().item()
         t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
-        print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): '
+        print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}, enable_overlap={enable_overlap}): '
               f'{t * 1e6:4.0f} us | '
               f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | '
               f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s')
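
Usage note (not part of the diff): the sketch below shows how a caller might drive the new overlap path end-to-end, based on the test flow in `tests/test_fp8.py` and the helpers in `tests/generators.py` above. The shapes are illustrative, and the host-side check at the end only mirrors `check_signal`; a real consumer would poll the `signal` buffer from another device-side kernel rather than from Python.

```python
import torch
import deep_gemm
from generators import generate_m_grouped_masked  # helper from tests/ in this diff

# Illustrative shapes; any combination yielded by enumerate_m_grouped_masked() works.
num_groups, max_m, expected_m_per_group, n, k = 4, 4096, 256, 4096, 7168
a, b, masked_m, d, ref_d, signal = generate_m_grouped_masked(
    num_groups, max_m, expected_m_per_group, n, k, enable_overlap=True)

# With `enable_overlap=True`, the call returns `(block_m, signal_threshold)`, where
# `signal_threshold = ceil_div(n, block_n)` is the number of N-blocks each
# (group, m-block) pair accumulates before its output rows are fully stored.
block_m, threshold = deep_gemm.m_grouped_fp8_gemm_nt_masked(
    a, b, d, masked_m, expected_m_per_group,
    disable_ue8m0_cast=True, enable_overlap=True, signal=signal)

# The kernel increments `signal[group * ceil_div(max_m, block_m) + m_block]` once per
# N-block after its TMA store completes, so a slot reaching `threshold` means the
# corresponding `block_m` rows of `d` are ready for a consumer.
torch.cuda.synchronize()  # host-side check only; a device consumer would spin on the slot instead
blocks_per_group = (max_m + block_m - 1) // block_m
for g in range(num_groups):
    for mb in range((int(masked_m[g]) + block_m - 1) // block_m):
        assert signal[g * blocks_per_group + mb].item() == threshold
```

The `cp.async.bulk.wait_group 0` inside `store_wait()` together with the `atom.add.release.gpu` increment ensures the signal is only published after the block's TMA store has completed, which is what makes this handshake safe for overlapping a consumer with the GEMM.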